From e8b34eed57343039a5ff6dfcebf7cdd99d72b8b8 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Mon, 25 Nov 2024 10:25:13 +0100 Subject: [PATCH] Add while loops --- .task/checksum/astgen | 2 +- ast/expr.go | 11 +++++ ast/printer.go | 18 ++++++++ ast/printer_test.go | 40 ++++++++++++++++ ast/stmt.go | 10 ++++ cmd/astgen/main.go | 2 + interpreter/interpreter.go | 26 +++++++++++ interpreter/interpreter_test.go | 81 +++++++++++++++++++++++++++++++++ parser/parser.go | 54 ++++++++++++++++++++-- parser/parser_test.go | 50 ++++++++++++++++++++ 10 files changed, 290 insertions(+), 4 deletions(-) diff --git a/.task/checksum/astgen b/.task/checksum/astgen index 63cec90..e2d05de 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -98e69ec471348f6ec8e672a1b8c9825d +ef159c739c079a54bc5f6af10e2ea025 diff --git a/ast/expr.go b/ast/expr.go index 44bdbad..3a332d4 100644 --- a/ast/expr.go +++ b/ast/expr.go @@ -8,6 +8,7 @@ type ExprVisitor[T any] interface { VisitBinaryExpr(be *BinaryExpr) T VisitGroupingExpr(ge *GroupingExpr) T VisitLiteralExpr(le *LiteralExpr) T + VisitLogicalExpr(le *LogicalExpr) T VisitUnaryExpr(ue *UnaryExpr) T VisitVariableExpr(ve *VariableExpr) T } @@ -59,6 +60,16 @@ func (le *LiteralExpr) Accept(v ExprVisitor[any]) any { return v.VisitLiteralExpr(le) } +type LogicalExpr struct { + Left Expr + Operator token.Token + Right Expr +} + +func (le *LogicalExpr) Accept(v ExprVisitor[any]) any { + return v.VisitLogicalExpr(le) +} + type UnaryExpr struct { Operator token.Token Right Expr diff --git a/ast/printer.go b/ast/printer.go index 4fa84e5..6d14322 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -59,6 +59,20 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { return ap.parenthesizeExpr("var", &LiteralExpr{stmt.Name}, stmt.Initializer) } +func (ap *Printer) VisitWhileStmt(stmt *WhileStmt) any { + str := "(while" + + if stmt.Condition != nil { + str += " " + stmt.Condition.Accept(ap).(string) + } + + if stmt.Body != nil { + str += " " + ap.bracesizeStmt(stmt.Body) + } + + return str + ")" +} + func (ap *Printer) VisitBlockStmt(stmt *BlockStmt) any { return ap.bracesizeStmt(stmt.Statements...) } @@ -82,6 +96,10 @@ func (ap *Printer) VisitLiteralExpr(expr *LiteralExpr) any { return fmt.Sprint(expr.Value) } +func (ap *Printer) VisitLogicalExpr(expr *LogicalExpr) any { + return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Left, expr.Right) +} + func (ap *Printer) VisitUnaryExpr(expr *UnaryExpr) any { return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Right) } diff --git a/ast/printer_test.go b/ast/printer_test.go index 2f0ddf9..a75b02c 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -27,6 +27,11 @@ func TestPrintExpr(t *testing.T) { }, expected: "(group 1)", }, + { + name: "Logical expression", + expr: &LogicalExpr{Left: &LiteralExpr{Value: 1}, Operator: token.Token{Type: token.AND, Lexeme: "and"}, Right: &LiteralExpr{Value: 2}}, + expected: "(and 1 2)", + }, { name: "Literal expression", expr: &LiteralExpr{Value: 123}, @@ -95,6 +100,25 @@ func TestPrintStmts(t *testing.T) { }, expected: "42\n", }, + { + name: "Error statement", + stmts: []Stmt{ + &ErrorStmt{ + Value: "error", + }, + }, + expected: "error\n", + }, + { + name: "Var statement", + stmts: []Stmt{ + &VarStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Initializer: &LiteralExpr{Value: 42}, + }, + }, + expected: "(var foo 42)\n", + }, } for _, tt := range tests { @@ -146,3 +170,19 @@ func TestPrintIfStmt(t *testing.T) { t.Errorf("expected %v, got %v", expected, result) } } + +func TestPrintWhileStmt(t *testing.T) { + stmt := &WhileStmt{ + Condition: &LiteralExpr{Value: true}, + Body: &ExpressionStmt{ + Expression: &LiteralExpr{Value: 1}, + }, + } + + printer := NewPrinter() + result := printer.PrintStmts([]Stmt{stmt}) + expected := "(while true {\n 1\n})\n" + if result != expected { + t.Errorf("expected %v, got %v", expected, result) + } +} diff --git a/ast/stmt.go b/ast/stmt.go index eb65299..e5944f3 100644 --- a/ast/stmt.go +++ b/ast/stmt.go @@ -9,6 +9,7 @@ type StmtVisitor[T any] interface { VisitIfStmt(is *IfStmt) T VisitPrintStmt(ps *PrintStmt) T VisitVarStmt(vs *VarStmt) T + VisitWhileStmt(ws *WhileStmt) T } type Stmt interface { @@ -66,3 +67,12 @@ func (vs *VarStmt) Accept(v StmtVisitor[any]) any { return v.VisitVarStmt(vs) } +type WhileStmt struct { + Condition Expr + Body Stmt +} + +func (ws *WhileStmt) Accept(v StmtVisitor[any]) any { + return v.VisitWhileStmt(ws) +} + diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index ab3a86d..ab529cd 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -27,6 +27,7 @@ func main() { "Binary : Left Expr, Operator token.Token, Right Expr", "Grouping : Expression Expr", "Literal : Value any", + "Logical : Left Expr, Operator token.Token, Right Expr", "Unary : Operator token.Token, Right Expr", "Variable : Name token.Token", }) @@ -38,6 +39,7 @@ func main() { "If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt", "Print : Expression Expr", "Var : Name token.Token, Initializer Expr", + "While : Condition Expr, Body Stmt", }) } diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 4a5560c..79754db 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -57,6 +57,15 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any { return nil } +// VisitWhileStmt visits a while statement. +func (i *Interpreter) VisitWhileStmt(ws *ast.WhileStmt) any { + for isTruthy(i.evaluate(ws.Condition)) { + i.execute(ws.Body) + } + + return nil +} + // VisitVarStmt visits a var statement. func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any { var value any @@ -99,6 +108,23 @@ func (i *Interpreter) VisitLiteralExpr(l *ast.LiteralExpr) any { return l.Value } +// VisitLogicalExpr visits a LogicalExpr. +func (i *Interpreter) VisitLogicalExpr(l *ast.LogicalExpr) any { + left := i.evaluate(l.Left) + + if l.Operator.Type == token.OR { + if isTruthy(left) { + return left + } + } else { + if !isTruthy(left) { + return left + } + } + + return i.evaluate(l.Right) +} + // VisitGroupingExpr visits a GroupingExpr. func (i *Interpreter) VisitGroupingExpr(g *ast.GroupingExpr) any { return i.evaluate(g.Expression) diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index b58cbed..4f30a02 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -277,6 +277,22 @@ func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { i.VisitBinaryExpr(binary) } +func TestInterpretLogicalExpr(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + left := &ast.LiteralExpr{Value: true} + right := &ast.LiteralExpr{Value: false} + logical := &ast.LogicalExpr{ + Left: left, + Operator: token.Token{Type: token.AND, Lexeme: "and"}, + Right: right, + } + + result := i.VisitLogicalExpr(logical) + if result != false { + t.Errorf("expected false, got %v", result) + } +} + func TestInterpretErrorStatement(t *testing.T) { i := New(errors.NewMockErrorLogger()) errorStmt := &ast.ErrorStmt{Value: "error"} @@ -538,3 +554,68 @@ func TestInterpretIfStatementElseBranch(t *testing.T) { t.Errorf("run() = %v; want %v", out, expected) } } + +func TestInterpretWhileStatement(t *testing.T) { + old := os.Stdout // keep backup of the real stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = w + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + // begin unit test + i := New(errors.NewMockErrorLogger()) + varStmt := &ast.VarStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}, + Initializer: &ast.LiteralExpr{Value: 3.0}, + } + + i.VisitVarStmt(varStmt) + + whileStmt := &ast.WhileStmt{ + Condition: &ast.BinaryExpr{ + Left: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}}, + Operator: token.Token{Type: token.GREATER, Lexeme: ">"}, + Right: &ast.LiteralExpr{Value: 0.0}, + }, + Body: &ast.BlockStmt{ + Statements: []ast.Stmt{ + &ast.PrintStmt{ + Expression: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}}, + }, + &ast.ExpressionStmt{ + Expression: &ast.AssignExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}, + Value: &ast.BinaryExpr{ + Left: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}}, + Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, + Right: &ast.LiteralExpr{Value: 1.0}, + }, + }, + }, + }, + }, + } + + i.VisitWhileStmt(whileStmt) + // end unit test + + // back to normal state + w.Close() + os.Stdout = old // restoring the real stdout + out := <-outC + + // reading our temp stdout + expected := "3\n2\n1\n" + if out != expected { + t.Errorf("run() = %v; want %v", out, expected) + } +} diff --git a/parser/parser.go b/parser/parser.go index 1fb5b33..20b358c 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -83,7 +83,7 @@ func (p *Parser) varDeclaration() ast.Stmt { return &ast.VarStmt{Name: name, Initializer: initializer} } -// statement → exprStmt | ifStmt | printStmt | block ; +// statement → exprStmt | ifStmt | printStmt | whileStmt | block ; func (p *Parser) statement() ast.Stmt { if p.match(token.IF) { return p.ifStatement() @@ -91,6 +91,9 @@ func (p *Parser) statement() ast.Stmt { if p.match(token.PRINT) { return p.printStatement() } + if p.match(token.WHILE) { + return p.whileStatement() + } if p.match(token.LEFT_BRACE) { return p.blockStatement() } @@ -132,6 +135,25 @@ func (p *Parser) printStatement() ast.Stmt { return &ast.PrintStmt{Expression: expr} } +// whileStmt → "while" "(" expression ")" statement ; +func (p *Parser) whileStatement() ast.Stmt { + err := p.consume(token.LEFT_PAREN, "Expect '(' after 'while'.") + if err != nil { + return p.fromErrorExpr(err) + } + + condition := p.expression() + + err = p.consume(token.RIGHT_PAREN, "Expect ')' after while condition.") + if err != nil { + return p.fromErrorExpr(err) + } + + body := p.statement() + + return &ast.WhileStmt{Condition: condition, Body: body} +} + // block → "{" declaration* "}" ; func (p *Parser) blockStatement() ast.Stmt { statements := []ast.Stmt{} @@ -164,9 +186,9 @@ func (p *Parser) expression() ast.Expr { return p.assignment() } -// assignment → IDENTIFIER "=" assignment | equality ; +// assignment → IDENTIFIER "=" assignment | logic_or ; func (p *Parser) assignment() ast.Expr { - expr := p.equality() + expr := p.or() if p.match(token.EQUAL) { equals := p.previous() @@ -182,6 +204,32 @@ func (p *Parser) assignment() ast.Expr { return expr } +// logic_or → logic_and ( "or" logic_and )* ; +func (p *Parser) or() ast.Expr { + expr := p.and() + + for p.match(token.OR) { + operator := p.previous() + right := p.and() + expr = &ast.LogicalExpr{Left: expr, Operator: operator, Right: right} + } + + return expr +} + +// logic_and → equality ( "and" equality )* ; +func (p *Parser) and() ast.Expr { + expr := p.equality() + + for p.match(token.AND) { + operator := p.previous() + right := p.equality() + expr = &ast.LogicalExpr{Left: expr, Operator: operator, Right: right} + } + + return expr +} + // equality → comparison ( ( "!=" | "==" ) comparison )* ; func (p *Parser) equality() ast.Expr { expr := p.comparison() diff --git a/parser/parser_test.go b/parser/parser_test.go index d1d6dba..326adda 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -58,6 +58,17 @@ func TestExpressionParsing(t *testing.T) { }, expected: "(> 1 2)", }, + { + name: "Logical expression", + tokens: []token.Token{ + {Type: token.NUMBER, Literal: 1}, + {Type: token.GREATER, Lexeme: "and"}, + {Type: token.NUMBER, Literal: 2}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(and 1 2)", + }, { name: "Equality expression", tokens: []token.Token{ @@ -380,3 +391,42 @@ func TestParseIfElseStatement(t *testing.T) { }) } } + +func TestParseWhileStatement(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple while statement", + tokens: []token.Token{ + {Type: token.WHILE, Lexeme: "while"}, + {Type: token.LEFT_PAREN, Lexeme: "("}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.RIGHT_PAREN, Lexeme: ")"}, + {Type: token.PRINT, Lexeme: "print"}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(while 42 {\n (print 42)\n})\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +}