From b97f44db2d6c93a9486a925d7c1d3f1de99bf58b Mon Sep 17 00:00:00 2001 From: oabrivard Date: Fri, 15 Nov 2024 10:36:17 +0100 Subject: [PATCH] Add print and expression statements --- .task/checksum/astgen | 2 +- Taskfile.yml | 3 +- ast/expr.go | 30 +++++------ ast/printer.go | 28 +++++++++- ast/printer_test.go | 41 ++++++++++++++- ast/stmt.go | 48 +++++++++++++++++ cmd/astgen/main.go | 15 ++++-- interpreter/interpreter.go | 41 ++++++++++++++- interpreter/interpreter_test.go | 35 +++++++------ lox/lox.go | 13 +++-- lox/lox_test.go | 6 +-- parser/parser.go | 61 +++++++++++++++++++-- parser/parser_test.go | 93 +++++++++++++++++++++++++++++++-- 13 files changed, 359 insertions(+), 57 deletions(-) create mode 100644 ast/stmt.go diff --git a/.task/checksum/astgen b/.task/checksum/astgen index bbd454f..aeceed0 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -f630c7b1cf9f3b55f1467ac4c8704152 +bf6f2cc6973975f2bcc789463842856a diff --git a/Taskfile.yml b/Taskfile.yml index dd130a2..b94339d 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -26,7 +26,7 @@ tasks: desc: "Run tests" deps: [astgen] cmds: - - go test -v ./... + - go test ./... {{.CLI_ARGS}} astgen: desc: "Generate AST nodes" @@ -36,6 +36,7 @@ tasks: - cmd/astgen/main.go generates: - ast/expr.go + - ast/stmt.go clean: desc: "Clean up" diff --git a/ast/expr.go b/ast/expr.go index cd6fc3d..e7fd652 100644 --- a/ast/expr.go +++ b/ast/expr.go @@ -3,11 +3,11 @@ package ast import "golox/token" type ExprVisitor[T any] interface { - VisitErrorExpr(error *ErrorExpr) T - VisitBinaryExpr(binary *BinaryExpr) T - VisitGroupingExpr(grouping *GroupingExpr) T - VisitLiteralExpr(literal *LiteralExpr) T - VisitUnaryExpr(unary *UnaryExpr) T + VisitErrorExpr(ee *ErrorExpr) T + VisitBinaryExpr(be *BinaryExpr) T + VisitGroupingExpr(ge *GroupingExpr) T + VisitLiteralExpr(le *LiteralExpr) T + VisitUnaryExpr(ue *UnaryExpr) T } type Expr interface { @@ -18,8 +18,8 @@ type ErrorExpr struct { Value string } -func (t *ErrorExpr) Accept(visitor ExprVisitor[any]) any { - return visitor.VisitErrorExpr(t) +func (ee *ErrorExpr) Accept(v ExprVisitor[any]) any { + return v.VisitErrorExpr(ee) } type BinaryExpr struct { @@ -28,24 +28,24 @@ type BinaryExpr struct { Right Expr } -func (t *BinaryExpr) Accept(visitor ExprVisitor[any]) any { - return visitor.VisitBinaryExpr(t) +func (be *BinaryExpr) Accept(v ExprVisitor[any]) any { + return v.VisitBinaryExpr(be) } type GroupingExpr struct { Expression Expr } -func (t *GroupingExpr) Accept(visitor ExprVisitor[any]) any { - return visitor.VisitGroupingExpr(t) +func (ge *GroupingExpr) Accept(v ExprVisitor[any]) any { + return v.VisitGroupingExpr(ge) } type LiteralExpr struct { Value any } -func (t *LiteralExpr) Accept(visitor ExprVisitor[any]) any { - return visitor.VisitLiteralExpr(t) +func (le *LiteralExpr) Accept(v ExprVisitor[any]) any { + return v.VisitLiteralExpr(le) } type UnaryExpr struct { @@ -53,7 +53,7 @@ type UnaryExpr struct { Right Expr } -func (t *UnaryExpr) Accept(visitor ExprVisitor[any]) any { - return visitor.VisitUnaryExpr(t) +func (ue *UnaryExpr) Accept(v ExprVisitor[any]) any { + return v.VisitUnaryExpr(ue) } diff --git a/ast/printer.go b/ast/printer.go index 343d6df..36bd731 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -11,10 +11,36 @@ func NewPrinter() *Printer { return &Printer{} } -func (ap *Printer) Print(expr Expr) string { +func (ap *Printer) PrintStmts(stmts []Stmt) string { + str := "" + + for _, stmt := range stmts { + str += stmt.Accept(ap).(string) + "\n" + } + + return str +} + +func (ap *Printer) PrintExpr(expr Expr) string { return expr.Accept(ap).(string) } +func (ap *Printer) VisitErrorStmt(stmt *ErrorStmt) any { + return stmt.Value +} + +func (ap *Printer) VisitExpressionStmt(stmt *ExpressionStmt) any { + return stmt.Expression.Accept(ap) +} + +func (ap *Printer) VisitPrintStmt(stmt *PrintStmt) any { + return ap.parenthesize("print", stmt.Expression) +} + +func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { + return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer) +} + func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any { return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) } diff --git a/ast/printer_test.go b/ast/printer_test.go index ec2c8ce..4ea1241 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestPrinter(t *testing.T) { +func TestPrintExpr(t *testing.T) { tests := []struct { name string expr Expr @@ -55,7 +55,44 @@ func TestPrinter(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { printer := NewPrinter() - result := printer.Print(tt.expr) + result := printer.PrintExpr(tt.expr) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestPrintStmts(t *testing.T) { + tests := []struct { + name string + stmts []Stmt + expected string + }{ + { + name: "Print statement", + stmts: []Stmt{ + &PrintStmt{ + Expression: &LiteralExpr{Value: 42}, + }, + }, + expected: "(print 42)\n", + }, + { + name: "Expression statement", + stmts: []Stmt{ + &ExpressionStmt{ + Expression: &LiteralExpr{Value: 42}, + }, + }, + expected: "42\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + printer := NewPrinter() + result := printer.PrintStmts(tt.stmts) if result != tt.expected { t.Errorf("expected %v, got %v", tt.expected, result) } diff --git a/ast/stmt.go b/ast/stmt.go new file mode 100644 index 0000000..2f80ec8 --- /dev/null +++ b/ast/stmt.go @@ -0,0 +1,48 @@ +package ast + +import "golox/token" + +type StmtVisitor[T any] interface { + VisitErrorStmt(es *ErrorStmt) T + VisitExpressionStmt(es *ExpressionStmt) T + VisitPrintStmt(ps *PrintStmt) T + VisitVarStmt(vs *VarStmt) T +} + +type Stmt interface { + Accept(visitor StmtVisitor[any]) any +} + +type ErrorStmt struct { + Value string +} + +func (es *ErrorStmt) Accept(v StmtVisitor[any]) any { + return v.VisitErrorStmt(es) +} + +type ExpressionStmt struct { + Expression Expr +} + +func (es *ExpressionStmt) Accept(v StmtVisitor[any]) any { + return v.VisitExpressionStmt(es) +} + +type PrintStmt struct { + Expression Expr +} + +func (ps *PrintStmt) Accept(v StmtVisitor[any]) any { + return v.VisitPrintStmt(ps) +} + +type VarStmt struct { + Name token.Token + Initializer Expr +} + +func (vs *VarStmt) Accept(v StmtVisitor[any]) any { + return v.VisitVarStmt(vs) +} + diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index 982f3cd..0f07a87 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -28,6 +28,13 @@ func main() { "Literal : Value any", "Unary : Operator token.Token, Right Expr", }) + + defineAst(d, "Stmt", []string{ + "Error : Value string", + "Expression : Expression Expr", + "Print : Expression Expr", + "Var : Name token.Token, Initializer Expr", + }) } func defineAst(outputDir, baseName string, types []string) { @@ -58,7 +65,8 @@ func defineVisitor(file *os.File, baseName string, types []string) { for _, t := range types { typeName := strings.TrimSpace(t[:strings.Index(t, ":")-1]) - file.WriteString(" Visit" + typeName + baseName + "(" + strings.ToLower(typeName) + " *" + typeName + baseName + ") T\n") + paramName := strings.ToLower(typeName[:1]) + strings.ToLower(baseName[:1]) + file.WriteString(" Visit" + typeName + baseName + "(" + paramName + " *" + typeName + baseName + ") T\n") } file.WriteString("}\n\n") @@ -74,8 +82,9 @@ func defineType(file *os.File, baseName, typeString string) { file.WriteString(" " + field + "\n") } + varName := strings.ToLower(typeName[:1]) + strings.ToLower(baseName[:1]) file.WriteString("}\n\n") - file.WriteString("func (t *" + typeName + baseName + ") Accept(visitor " + baseName + "Visitor[any]) any {\n") - file.WriteString(" return visitor.Visit" + typeName + baseName + "(t)\n") + file.WriteString("func (" + varName + " *" + typeName + baseName + ") Accept(v " + baseName + "Visitor[any]) any {\n") + file.WriteString(" return v.Visit" + typeName + baseName + "(" + varName + ")\n") file.WriteString("}\n\n") } diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 0f34de5..3ee4360 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -14,8 +14,40 @@ type Interpreter struct { } // New creates a new Interpreter. -func New() *Interpreter { - return &Interpreter{} +func New(el errors.Logger) *Interpreter { + return &Interpreter{el} +} + +// Interpret interprets the AST. +func (i *Interpreter) Interpret(stmts []ast.Stmt) { + defer i.afterPanic() + + for _, stmt := range stmts { + i.execute(stmt) + } +} + +// VisitErrorStmt visits an error statement. +func (i *Interpreter) VisitErrorStmt(es *ast.ErrorStmt) any { + panic(es.Value) +} + +// VisitExpressionStmt visits an expression. +func (i *Interpreter) VisitExpressionStmt(es *ast.ExpressionStmt) any { + i.evaluate(es.Expression) + return nil +} + +// VisitPrintStmt visits a print statement. +func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any { + value := i.evaluate(ps.Expression) + fmt.Println(stringify(value)) + return nil +} + +// VisitVarStmt visits a var statement. +func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any { + return nil } // Interpret interprets the AST. @@ -150,6 +182,11 @@ func (i *Interpreter) evaluate(e ast.Expr) any { return e.Accept(i) } +// execute executes a statement. +func (i *Interpreter) execute(s ast.Stmt) { + s.Accept(i) +} + // stringify returns a string representation of a value. func stringify(v any) string { if v == nil { diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 412554f..ecf78aa 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -3,13 +3,14 @@ package interpreter_test import ( "golox/ast" + "golox/errors" "golox/interpreter" "golox/token" "testing" ) func TestInterpretLiteralExpr(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42} result := i.VisitLiteralExpr(literal) @@ -19,7 +20,7 @@ func TestInterpretLiteralExpr(t *testing.T) { } func TestInterpretGroupingExpr(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42} grouping := &ast.GroupingExpr{Expression: literal} @@ -30,7 +31,7 @@ func TestInterpretGroupingExpr(t *testing.T) { } func TestInterpretUnaryExpr(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42.0} unary := &ast.UnaryExpr{ Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, @@ -44,7 +45,7 @@ func TestInterpretUnaryExpr(t *testing.T) { } func TestInterpretUnaryExprBang(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: true} unary := &ast.UnaryExpr{ Operator: token.Token{Type: token.BANG, Lexeme: "!"}, @@ -58,7 +59,7 @@ func TestInterpretUnaryExprBang(t *testing.T) { } func TestInterpretErrorExpr(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) errorExpr := &ast.ErrorExpr{Value: "error"} defer func() { @@ -71,7 +72,7 @@ func TestInterpretErrorExpr(t *testing.T) { } func TestInterpretExpr(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42.0} defer func() { @@ -87,7 +88,7 @@ func TestInterpretExpr(t *testing.T) { } func TestInterpretBinaryExpr(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -103,7 +104,7 @@ func TestInterpretBinaryExpr(t *testing.T) { } func TestInterpretBinaryExprDivisionByZero(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 0.0} binary := &ast.BinaryExpr{ @@ -122,7 +123,7 @@ func TestInterpretBinaryExprDivisionByZero(t *testing.T) { } func TestInterpretBinaryExprAddition(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -138,7 +139,7 @@ func TestInterpretBinaryExprAddition(t *testing.T) { } func TestInterpretBinaryExprSubtraction(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -154,7 +155,7 @@ func TestInterpretBinaryExprSubtraction(t *testing.T) { } func TestInterpretBinaryExprStringConcatenation(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: "foo"} right := &ast.LiteralExpr{Value: "bar"} binary := &ast.BinaryExpr{ @@ -170,7 +171,7 @@ func TestInterpretBinaryExprStringConcatenation(t *testing.T) { } func TestInterpretBinaryExprInvalidOperands(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: "foo"} right := &ast.LiteralExpr{Value: 42.0} binary := &ast.BinaryExpr{ @@ -189,7 +190,7 @@ func TestInterpretBinaryExprInvalidOperands(t *testing.T) { } func TestInterpretBinaryExprComparison(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -205,7 +206,7 @@ func TestInterpretBinaryExprComparison(t *testing.T) { } func TestInterpretBinaryExprComparisonEqual(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 42.0} binary := &ast.BinaryExpr{ @@ -221,7 +222,7 @@ func TestInterpretBinaryExprComparisonEqual(t *testing.T) { } func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -237,7 +238,7 @@ func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) { } func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: "foo"} right := &ast.LiteralExpr{Value: 42.0} binary := &ast.BinaryExpr{ @@ -256,7 +257,7 @@ func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) { } func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { - i := interpreter.New() + i := interpreter.New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ diff --git a/lox/lox.go b/lox/lox.go index dfb6bbb..56f1999 100644 --- a/lox/lox.go +++ b/lox/lox.go @@ -17,11 +17,14 @@ type Lox struct { } func New() *Lox { - return &Lox{ + l := &Lox{ hadError: false, hadRuntimeError: false, - interpreter: interpreter.New(), } + + l.interpreter = interpreter.New(l) + + return l } func (l *Lox) RunFile(path string) { @@ -90,12 +93,12 @@ func (l *Lox) run(source string) { scanner := scanner.New(source, l) tokens := scanner.ScanTokens() parser := parser.New(tokens, l) - expr := parser.Parse() + stmts := parser.Parse() + // Stop if there was a syntax error. if l.hadError { return } - s := l.interpreter.InterpretExpr(expr) - fmt.Println(s) + l.interpreter.Interpret(stmts) } diff --git a/lox/lox_test.go b/lox/lox_test.go index 6c9c0c3..f09584f 100644 --- a/lox/lox_test.go +++ b/lox/lox_test.go @@ -24,7 +24,7 @@ func TestRun(t *testing.T) { outC <- buf.String() }() - source := "1+4/2" + source := "print 1+4/2;" l := New() l.run(source) @@ -63,7 +63,7 @@ func TestRunFile(t *testing.T) { } defer os.Remove(tmpfile.Name()) - content := "1+4/2" + content := "print 1+4/2;" if _, err := tmpfile.Write([]byte(content)); err != nil { t.Fatal(err) } @@ -166,7 +166,7 @@ func TestRunPrompt(t *testing.T) { outC <- buf.String() }() - input := "1+4/2\n\n" + input := "print 1+4/2;\n\n" wIn.Write([]byte(input)) wIn.Close() diff --git a/parser/parser.go b/parser/parser.go index bacb271..154ba22 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -2,15 +2,20 @@ * The parser is responsible for parsing the tokens generated by the scanner. * * The grammar is as follows: + * program → statement* EOF ; + * statement → exprStmt + * | printStmt ; + * exprStmt → expression ";" ; + * printStmt → "print" expression ";" ; * expression → equality ; * equality → comparison ( ( "!=" | "==" ) comparison )* ; * comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ; * term → factor ( ( "-" | "+" ) factor )* ; * factor → unary ( ( "/" | "*" ) unary )* ; * unary → ( "!" | "-" ) unary - * | primary ; + * | primary ; * primary → NUMBER | STRING | "true" | "false" | "nil" - * | "(" expression ")" ; + * | "(" expression ")" ; */ package parser @@ -33,8 +38,50 @@ func New(tokens []token.Token, el errors.Logger) *Parser { } // Parse parses the tokens and returns the AST. -func (p *Parser) Parse() ast.Expr { - return p.expression() +func (p *Parser) Parse() []ast.Stmt { + stmts := []ast.Stmt{} + + for !p.isAtEnd() { + stmt := p.statement() + if _, ok := stmt.(*ast.ErrorStmt); ok { + p.synchronize() + } else { + stmts = append(stmts, stmt) + } + } + + return stmts +} + +// statement → exprStmt | printStmt ; +func (p *Parser) statement() ast.Stmt { + if p.match(token.PRINT) { + return p.printStatement() + } + + return p.expressionStatement() +} + +// exprStmt → expression ";" ; +func (p *Parser) expressionStatement() ast.Stmt { + expr := p.expression() + err := p.consume(token.SEMICOLON, "Expect ';' after expression.") + if err != nil { + return p.fromErrorExpr(err) + } + + return &ast.ExpressionStmt{Expression: expr} +} + +// printStmt → "print" expression ";" ; +func (p *Parser) printStatement() ast.Stmt { + expr := p.expression() + err := p.consume(token.SEMICOLON, "Expect ';' after value.") + if err != nil { + return p.fromErrorExpr(err) + } + + return &ast.PrintStmt{Expression: expr} } // expression → equality ; @@ -189,7 +236,13 @@ func (p *Parser) newErrorExpr(t token.Token, message string) *ast.ErrorExpr { return &ast.ErrorExpr{Value: message} } +// fromErrorExpr creates an ErrorStmt from an ErrorExpr. +func (p *Parser) fromErrorExpr(ee *ast.ErrorExpr) *ast.ErrorStmt { + return &ast.ErrorStmt{Value: ee.Value} +} + // synchronize synchronizes the parser after an error. +// It skips tokens until it finds a statement boundary. func (p *Parser) synchronize() { p.advance() diff --git a/parser/parser_test.go b/parser/parser_test.go index 23e1930..b2cd5ef 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -7,7 +7,7 @@ import ( "testing" ) -func TestParser(t *testing.T) { +func TestExpressionParsing(t *testing.T) { tests := []struct { name string tokens []token.Token @@ -19,6 +19,7 @@ func TestParser(t *testing.T) { {Type: token.NUMBER, Literal: 1}, {Type: token.PLUS, Lexeme: "+"}, {Type: token.NUMBER, Literal: 2}, + {Type: token.SEMICOLON, Lexeme: ";"}, {Type: token.EOF}, }, expected: "(+ 1 2)", @@ -28,6 +29,7 @@ func TestParser(t *testing.T) { tokens: []token.Token{ {Type: token.MINUS, Lexeme: "-"}, {Type: token.NUMBER, Literal: 123}, + {Type: token.SEMICOLON, Lexeme: ";"}, {Type: token.EOF}, }, expected: "(- 123)", @@ -40,6 +42,7 @@ func TestParser(t *testing.T) { {Type: token.PLUS, Lexeme: "+"}, {Type: token.NUMBER, Literal: 2}, {Type: token.RIGHT_PAREN, Lexeme: ")"}, + {Type: token.SEMICOLON, Lexeme: ";"}, {Type: token.EOF}, }, expected: "(group (+ 1 2))", @@ -50,6 +53,7 @@ func TestParser(t *testing.T) { {Type: token.NUMBER, Literal: 1}, {Type: token.GREATER, Lexeme: ">"}, {Type: token.NUMBER, Literal: 2}, + {Type: token.SEMICOLON, Lexeme: ";"}, {Type: token.EOF}, }, expected: "(> 1 2)", @@ -60,6 +64,7 @@ func TestParser(t *testing.T) { {Type: token.NUMBER, Literal: 1}, {Type: token.EQUAL_EQUAL, Lexeme: "=="}, {Type: token.NUMBER, Literal: 2}, + {Type: token.SEMICOLON, Lexeme: ";"}, {Type: token.EOF}, }, expected: "(== 1 2)", @@ -71,6 +76,7 @@ func TestParser(t *testing.T) { {Type: token.NUMBER, Literal: 1}, {Type: token.PLUS, Lexeme: "+"}, {Type: token.NUMBER, Literal: 2}, + {Type: token.SEMICOLON, Lexeme: ";"}, {Type: token.EOF}, }, expected: "Expect ')' after expression.", @@ -80,9 +86,90 @@ func TestParser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { parser := New(tt.tokens, errors.NewMockErrorLogger()) - expr := parser.Parse() + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + stmt := stmts[0] + var expr ast.Expr + if es, ok := stmt.(*ast.ExpressionStmt); !ok { + t.Errorf("expected ExprStmt, got %T", stmt) + } else { + expr = es.Expression + } + + ap := ast.NewPrinter() + s := ap.PrintExpr(expr) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} + +func TestParseExpressionStmt(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple expression statement", + tokens: []token.Token{ + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "42\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} + +func TestParsePrintStmt(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple print statement", + tokens: []token.Token{ + {Type: token.PRINT, Lexeme: "print"}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(print 42)\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + ap := ast.NewPrinter() - s := ap.Print(expr) + s := ap.PrintStmts(stmts) if s != tt.expected { t.Errorf("expected %v, got %v", tt.expected, s) }