From 16958e9f25d62504c28e6da615d356adcc1817bf Mon Sep 17 00:00:00 2001 From: oabrivard Date: Mon, 25 Nov 2024 09:17:03 +0100 Subject: [PATCH] Add conditional execution --- .task/checksum/astgen | 2 +- ast/printer.go | 48 ++++++++++++++----- ast/printer_test.go | 21 ++++++++- ast/stmt.go | 11 +++++ cmd/astgen/main.go | 1 + interpreter/interpreter.go | 11 +++++ interpreter/interpreter_test.go | 83 ++++++++++++++++++++++++++++++++ parser/parser.go | 50 +++++++++++++++----- parser/parser_test.go | 84 ++++++++++++++++++++++++++++++++- 9 files changed, 283 insertions(+), 28 deletions(-) diff --git a/.task/checksum/astgen b/.task/checksum/astgen index b08fd33..63cec90 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -53f18a0e7ed4ab1114f57c8bbdaf9abd +98e69ec471348f6ec8e672a1b8c9825d diff --git a/ast/printer.go b/ast/printer.go index 800034e..4fa84e5 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -34,33 +34,45 @@ func (ap *Printer) VisitExpressionStmt(stmt *ExpressionStmt) any { } func (ap *Printer) VisitPrintStmt(stmt *PrintStmt) any { - return ap.parenthesize("print", stmt.Expression) + return ap.parenthesizeExpr("print", stmt.Expression) } -func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { - return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer) -} +func (ap *Printer) VisitIfStmt(stmt *IfStmt) any { + str := "(if" -func (ap *Printer) VisitBlockStmt(stmt *BlockStmt) any { - str := "(block\n" + if stmt.Condition != nil { + str += " " + stmt.Condition.Accept(ap).(string) + } - for _, s := range stmt.Statements { - str += " " + s.Accept(ap).(string) + "\n" + if stmt.ThenBranch != nil { + str += " " + ap.bracesizeStmt(stmt.ThenBranch) + } + + if stmt.ElseBranch != nil { + str += " else " + ap.bracesizeStmt(stmt.ElseBranch) } return str + ")" } +func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { + return ap.parenthesizeExpr("var", &LiteralExpr{stmt.Name}, stmt.Initializer) +} + +func (ap *Printer) VisitBlockStmt(stmt *BlockStmt) any { + return ap.bracesizeStmt(stmt.Statements...) +} + func (ap *Printer) VisitVariableExpr(expr *VariableExpr) any { return expr.Name.Lexeme } func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any { - return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) + return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Left, expr.Right) } func (ap *Printer) VisitGroupingExpr(expr *GroupingExpr) any { - return ap.parenthesize("group", expr.Expression) + return ap.parenthesizeExpr("group", expr.Expression) } func (ap *Printer) VisitLiteralExpr(expr *LiteralExpr) any { @@ -71,7 +83,7 @@ func (ap *Printer) VisitLiteralExpr(expr *LiteralExpr) any { } func (ap *Printer) VisitUnaryExpr(expr *UnaryExpr) any { - return ap.parenthesize(expr.Operator.Lexeme, expr.Right) + return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Right) } func (ap *Printer) VisitErrorExpr(expr *ErrorExpr) any { @@ -79,10 +91,10 @@ func (ap *Printer) VisitErrorExpr(expr *ErrorExpr) any { } func (ap *Printer) VisitAssignExpr(expr *AssignExpr) any { - return ap.parenthesize("=", &VariableExpr{expr.Name}, expr.Value) + return ap.parenthesizeExpr("=", &VariableExpr{expr.Name}, expr.Value) } -func (ap *Printer) parenthesize(name string, exprs ...Expr) string { +func (ap *Printer) parenthesizeExpr(name string, exprs ...Expr) string { str := "(" + name for _, expr := range exprs { @@ -94,3 +106,13 @@ func (ap *Printer) parenthesize(name string, exprs ...Expr) string { return str + ")" } + +func (ap *Printer) bracesizeStmt(stmts ...Stmt) string { + str := "{\n" + + for _, s := range stmts { + str += " " + s.Accept(ap).(string) + "\n" + } + + return str + "}" +} diff --git a/ast/printer_test.go b/ast/printer_test.go index 72cdf33..2f0ddf9 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -122,7 +122,26 @@ func TestPrintBlockStmt(t *testing.T) { printer := NewPrinter() result := printer.PrintStmts([]Stmt{stmt}) - expected := "(block\n 1\n 2\n)\n" + expected := "{\n 1\n 2\n}\n" + if result != expected { + t.Errorf("expected %v, got %v", expected, result) + } +} + +func TestPrintIfStmt(t *testing.T) { + stmt := &IfStmt{ + Condition: &LiteralExpr{Value: true}, + ThenBranch: &ExpressionStmt{ + Expression: &LiteralExpr{Value: 1}, + }, + ElseBranch: &ExpressionStmt{ + Expression: &LiteralExpr{Value: 2}, + }, + } + + printer := NewPrinter() + result := printer.PrintStmts([]Stmt{stmt}) + expected := "(if true {\n 1\n} else {\n 2\n})\n" if result != expected { t.Errorf("expected %v, got %v", expected, result) } diff --git a/ast/stmt.go b/ast/stmt.go index 3ceeeea..eb65299 100644 --- a/ast/stmt.go +++ b/ast/stmt.go @@ -6,6 +6,7 @@ type StmtVisitor[T any] interface { VisitErrorStmt(es *ErrorStmt) T VisitBlockStmt(bs *BlockStmt) T VisitExpressionStmt(es *ExpressionStmt) T + VisitIfStmt(is *IfStmt) T VisitPrintStmt(ps *PrintStmt) T VisitVarStmt(vs *VarStmt) T } @@ -38,6 +39,16 @@ func (es *ExpressionStmt) Accept(v StmtVisitor[any]) any { return v.VisitExpressionStmt(es) } +type IfStmt struct { + Condition Expr + ThenBranch Stmt + ElseBranch Stmt +} + +func (is *IfStmt) Accept(v StmtVisitor[any]) any { + return v.VisitIfStmt(is) +} + type PrintStmt struct { Expression Expr } diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index 6fa24f7..ab3a86d 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -35,6 +35,7 @@ func main() { "Error : Value string", "Block : Statements []Stmt", "Expression : Expression Expr", + "If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt", "Print : Expression Expr", "Var : Name token.Token, Initializer Expr", }) diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 85877f0..4a5560c 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -39,6 +39,17 @@ func (i *Interpreter) VisitExpressionStmt(es *ast.ExpressionStmt) any { return nil } +// VisitIfStmt visits an if statement. +func (i *Interpreter) VisitIfStmt(is *ast.IfStmt) any { + if isTruthy(i.evaluate(is.Condition)) { + i.execute(is.ThenBranch) + } else if is.ElseBranch != nil { + i.execute(is.ElseBranch) + } + + return nil +} + // VisitPrintStmt visits a print statement. func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any { value := i.evaluate(ps.Expression) diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 6bb0ea1..b58cbed 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -455,3 +455,86 @@ func TestInterpretBlockStatement(t *testing.T) { t.Errorf("run() = %v; want %v", out, expected) } } + +func TestInterpretIfStatement(t *testing.T) { + old := os.Stdout // keep backup of the real stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = w + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + // begin unit test + i := New(errors.NewMockErrorLogger()) + ifStmt := &ast.IfStmt{ + Condition: &ast.LiteralExpr{Value: true}, + ThenBranch: &ast.PrintStmt{ + Expression: &ast.LiteralExpr{Value: 42.0}, + }, + } + + i.VisitIfStmt(ifStmt) + // end unit test + + // back to normal state + w.Close() + os.Stdout = old // restoring the real stdout + out := <-outC + + // reading our temp stdout + expected := "42\n" + if out != expected { + t.Errorf("run() = %v; want %v", out, expected) + } +} + +func TestInterpretIfStatementElseBranch(t *testing.T) { + old := os.Stdout // keep backup of the real stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = w + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + // begin unit test + i := New(errors.NewMockErrorLogger()) + ifStmt := &ast.IfStmt{ + Condition: &ast.LiteralExpr{Value: false}, + ThenBranch: &ast.PrintStmt{ + Expression: &ast.LiteralExpr{Value: 42.0}, + }, + ElseBranch: &ast.PrintStmt{ + Expression: &ast.LiteralExpr{Value: 24.0}, + }, + } + + i.VisitIfStmt(ifStmt) + // end unit test + + // back to normal state + w.Close() + os.Stdout = old // restoring the real stdout + out := <-outC + + // reading our temp stdout + expected := "24\n" + if out != expected { + t.Errorf("run() = %v; want %v", out, expected) + } +} diff --git a/parser/parser.go b/parser/parser.go index ed74b24..1fb5b33 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -83,8 +83,11 @@ func (p *Parser) varDeclaration() ast.Stmt { return &ast.VarStmt{Name: name, Initializer: initializer} } -// statement → exprStmt | printStmt | block ; +// statement → exprStmt | ifStmt | printStmt | block ; func (p *Parser) statement() ast.Stmt { + if p.match(token.IF) { + return p.ifStatement() + } if p.match(token.PRINT) { return p.printStatement() } @@ -95,6 +98,40 @@ func (p *Parser) statement() ast.Stmt { return p.expressionStatement() } +// ifStmt → "if" "(" expression ")" statement ( "else" statement )? ; +func (p *Parser) ifStatement() ast.Stmt { + err := p.consume(token.LEFT_PAREN, "Expect '(' after 'if'.") + if err != nil { + return p.fromErrorExpr(err) + } + + condition := p.expression() + + err = p.consume(token.RIGHT_PAREN, "Expect ')' after if condition.") + if err != nil { + return p.fromErrorExpr(err) + } + + thenBranch := p.statement() + var elseBranch ast.Stmt + if p.match(token.ELSE) { + elseBranch = p.statement() + } + + return &ast.IfStmt{Condition: condition, ThenBranch: thenBranch, ElseBranch: elseBranch} +} + +// printStmt → "print" expression ";" ; +func (p *Parser) printStatement() ast.Stmt { + expr := p.expression() + err := p.consume(token.SEMICOLON, "Expect ';' after value.") + if err != nil { + return p.fromErrorExpr(err) + } + + return &ast.PrintStmt{Expression: expr} +} + // block → "{" declaration* "}" ; func (p *Parser) blockStatement() ast.Stmt { statements := []ast.Stmt{} @@ -122,17 +159,6 @@ func (p *Parser) expressionStatement() ast.Stmt { return &ast.ExpressionStmt{Expression: expr} } -// printStmt → "print" expression ";" ; -func (p *Parser) printStatement() ast.Stmt { - expr := p.expression() - err := p.consume(token.SEMICOLON, "Expect ';' after value.") - if err != nil { - return p.fromErrorExpr(err) - } - - return &ast.PrintStmt{Expression: expr} -} - // expression → assignment ; func (p *Parser) expression() ast.Expr { return p.assignment() diff --git a/parser/parser_test.go b/parser/parser_test.go index 875b0db..d1d6dba 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -278,7 +278,89 @@ func TestParseBlockStatement(t *testing.T) { {Type: token.RIGHT_BRACE, Lexeme: "}"}, {Type: token.EOF}, }, - expected: "(block\n (var foo 42)\n)\n", + expected: "{\n (var foo 42)\n}\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} + +func TestParseIfStatement(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple if statement", + tokens: []token.Token{ + {Type: token.IF, Lexeme: "if"}, + {Type: token.LEFT_PAREN, Lexeme: "("}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.RIGHT_PAREN, Lexeme: ")"}, + {Type: token.PRINT, Lexeme: "print"}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(if 42 {\n (print 42)\n})\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} + +func TestParseIfElseStatement(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple if else statement", + tokens: []token.Token{ + {Type: token.IF, Lexeme: "if"}, + {Type: token.LEFT_PAREN, Lexeme: "("}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.RIGHT_PAREN, Lexeme: ")"}, + {Type: token.PRINT, Lexeme: "print"}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.ELSE, Lexeme: "else"}, + {Type: token.PRINT, Lexeme: "print"}, + {Type: token.NUMBER, Lexeme: "24", Literal: 24}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(if 42 {\n (print 42)\n} else {\n (print 24)\n})\n", }, }