diff --git a/.task/checksum/astgen b/.task/checksum/astgen index 8582cb7..bbd454f 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -5f28a0838ca9382a947da5bd52756cf3 +f630c7b1cf9f3b55f1467ac4c8704152 diff --git a/Taskfile.yml b/Taskfile.yml index 65f86b1..dd130a2 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -31,11 +31,11 @@ tasks: astgen: desc: "Generate AST nodes" cmds: - - go run cmd/astgen/main.go "./parser" + - go run cmd/astgen/main.go "./ast" sources: - cmd/astgen/main.go generates: - - parser/expr.go + - ast/expr.go clean: desc: "Clean up" diff --git a/parser/expr.go b/ast/expr.go similarity index 98% rename from parser/expr.go rename to ast/expr.go index c6ba2f0..cd6fc3d 100644 --- a/parser/expr.go +++ b/ast/expr.go @@ -1,4 +1,4 @@ -package parser +package ast import "golox/token" diff --git a/ast/printer.go b/ast/printer.go index fdd6ca9..343d6df 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -2,44 +2,43 @@ package ast import ( "fmt" - "golox/parser" ) type Printer struct { } -func New() *Printer { +func NewPrinter() *Printer { return &Printer{} } -func (ap *Printer) Print(expr parser.Expr) string { +func (ap *Printer) Print(expr Expr) string { return expr.Accept(ap).(string) } -func (ap *Printer) VisitBinaryExpr(expr *parser.BinaryExpr) any { +func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any { return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) } -func (ap *Printer) VisitGroupingExpr(expr *parser.GroupingExpr) any { +func (ap *Printer) VisitGroupingExpr(expr *GroupingExpr) any { return ap.parenthesize("group", expr.Expression) } -func (ap *Printer) VisitLiteralExpr(expr *parser.LiteralExpr) any { +func (ap *Printer) VisitLiteralExpr(expr *LiteralExpr) any { if expr.Value == nil { return "nil" } return fmt.Sprint(expr.Value) } -func (ap *Printer) VisitUnaryExpr(expr *parser.UnaryExpr) any { +func (ap *Printer) VisitUnaryExpr(expr *UnaryExpr) any { return ap.parenthesize(expr.Operator.Lexeme, expr.Right) } -func (ap *Printer) VisitErrorExpr(expr *parser.ErrorExpr) any { +func (ap *Printer) VisitErrorExpr(expr *ErrorExpr) any { return expr.Value } -func (ap *Printer) parenthesize(name string, exprs ...parser.Expr) string { +func (ap *Printer) parenthesize(name string, exprs ...Expr) string { str := "(" + name for _, expr := range exprs { diff --git a/ast/printer_test.go b/ast/printer_test.go new file mode 100644 index 0000000..ec2c8ce --- /dev/null +++ b/ast/printer_test.go @@ -0,0 +1,64 @@ +package ast + +import ( + "golox/token" + "testing" +) + +func TestPrinter(t *testing.T) { + tests := []struct { + name string + expr Expr + expected string + }{ + { + name: "Binary expression", + expr: &BinaryExpr{ + Left: &LiteralExpr{Value: 1}, + Operator: token.Token{Type: token.PLUS, Lexeme: "+"}, + Right: &LiteralExpr{Value: 2}, + }, + expected: "(+ 1 2)", + }, + { + name: "Grouping expression", + expr: &GroupingExpr{ + Expression: &LiteralExpr{Value: 1}, + }, + expected: "(group 1)", + }, + { + name: "Literal expression", + expr: &LiteralExpr{Value: 123}, + expected: "123", + }, + { + name: "Unary expression", + expr: &UnaryExpr{ + Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, + Right: &LiteralExpr{Value: 123}, + }, + expected: "(- 123)", + }, + { + name: "Nil literal expression", + expr: &LiteralExpr{Value: nil}, + expected: "nil", + }, + { + name: "Error expression", + expr: &ErrorExpr{Value: "error"}, + expected: "error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + printer := NewPrinter() + result := printer.Print(tt.expr) + if result != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, result) + } + }) + } +} diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index d7ee37d..982f3cd 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -40,7 +40,7 @@ func defineAst(outputDir, baseName string, types []string) { } defer file.Close() - file.WriteString("package parser\n\n") + file.WriteString("package ast\n\n") file.WriteString("import \"golox/token\"\n\n") defineVisitor(file, baseName, types) diff --git a/lox/lox.go b/lox/lox.go index 99b632f..0fb1158 100644 --- a/lox/lox.go +++ b/lox/lox.go @@ -80,6 +80,6 @@ func (l *Lox) run(source string) { return } - p := ast.New() + p := ast.NewPrinter() fmt.Println(p.Print(expr)) } diff --git a/lox/lox_test.go b/lox/lox_test.go index c6ff114..9e23514 100644 --- a/lox/lox_test.go +++ b/lox/lox_test.go @@ -24,7 +24,7 @@ func TestRun(t *testing.T) { outC <- buf.String() }() - source := "print('Hello, World!');" + source := "1+4/2" l := New() l.run(source) @@ -34,7 +34,7 @@ func TestRun(t *testing.T) { out := <-outC // reading our temp stdout - expected := source + "\n" + expected := "(+ 1 (/ 4 2))\n" if out != expected { t.Errorf("run() = %v; want %v", out, expected) } @@ -63,7 +63,7 @@ func TestRunFile(t *testing.T) { } defer os.Remove(tmpfile.Name()) - content := "print('Hello, World!');" + content := "1+4/2" if _, err := tmpfile.Write([]byte(content)); err != nil { t.Fatal(err) } @@ -79,7 +79,7 @@ func TestRunFile(t *testing.T) { out := <-outC // reading our temp stdout - expected := "print('Hello, World!');\n" + expected := "(+ 1 (/ 4 2))\n" if out != expected { t.Errorf("RunFile() = %v; want %v", out, expected) } @@ -166,7 +166,7 @@ func TestRunPrompt(t *testing.T) { outC <- buf.String() }() - input := "print('Hello, World!');\n\n" + input := "1+4/2\n\n" wIn.Write([]byte(input)) wIn.Close() @@ -178,7 +178,7 @@ func TestRunPrompt(t *testing.T) { os.Stdout = oldStdout out := <-outC - expected := "> print('Hello, World!');\n\n> " + expected := "> (+ 1 (/ 4 2))\n> " if out != expected { t.Errorf("RunPrompt() = %v; want %v", out, expected) } diff --git a/parser/parser.go b/parser/parser.go index b85fb23..bacb271 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -15,6 +15,7 @@ package parser import ( + "golox/ast" "golox/errors" "golox/token" ) @@ -32,96 +33,96 @@ func New(tokens []token.Token, el errors.Logger) *Parser { } // Parse parses the tokens and returns the AST. -func (p *Parser) Parse() Expr { +func (p *Parser) Parse() ast.Expr { return p.expression() } // expression → equality ; -func (p *Parser) expression() Expr { +func (p *Parser) expression() ast.Expr { return p.equality() } // equality → comparison ( ( "!=" | "==" ) comparison )* ; -func (p *Parser) equality() Expr { +func (p *Parser) equality() ast.Expr { expr := p.comparison() for p.match(token.BANG_EQUAL, token.EQUAL_EQUAL) { operator := p.previous() right := p.comparison() - expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right} } return expr } // comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ; -func (p *Parser) comparison() Expr { +func (p *Parser) comparison() ast.Expr { expr := p.term() for p.match(token.GREATER, token.GREATER_EQUAL, token.LESS, token.LESS_EQUAL) { operator := p.previous() right := p.term() - expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right} } return expr } // term → factor ( ( "-" | "+" ) factor )* ; -func (p *Parser) term() Expr { +func (p *Parser) term() ast.Expr { expr := p.factor() for p.match(token.MINUS, token.PLUS) { operator := p.previous() right := p.factor() - expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right} } return expr } // factor → unary ( ( "/" | "*" ) unary )* ; -func (p *Parser) factor() Expr { +func (p *Parser) factor() ast.Expr { expr := p.unary() for p.match(token.SLASH, token.STAR) { operator := p.previous() right := p.unary() - expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right} } return expr } // unary → ( "!" | "-" ) unary | primary ; -func (p *Parser) unary() Expr { +func (p *Parser) unary() ast.Expr { if p.match(token.BANG, token.MINUS) { operator := p.previous() right := p.unary() - return &UnaryExpr{Operator: operator, Right: right} + return &ast.UnaryExpr{Operator: operator, Right: right} } return p.primary() } // primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ; -func (p *Parser) primary() Expr { +func (p *Parser) primary() ast.Expr { switch { case p.match(token.FALSE): - return &LiteralExpr{Value: false} + return &ast.LiteralExpr{Value: false} case p.match(token.TRUE): - return &LiteralExpr{Value: true} + return &ast.LiteralExpr{Value: true} case p.match(token.NIL): - return &LiteralExpr{Value: nil} + return &ast.LiteralExpr{Value: nil} case p.match(token.NUMBER, token.STRING): - return &LiteralExpr{Value: p.previous().Literal} + return &ast.LiteralExpr{Value: p.previous().Literal} case p.match(token.LEFT_PAREN): expr := p.expression() err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.") if err != nil { return err } - return &GroupingExpr{Expression: expr} + return &ast.GroupingExpr{Expression: expr} } return p.newErrorExpr(p.peek(), "Expect expression.") @@ -140,7 +141,7 @@ func (p *Parser) match(types ...token.TokenType) bool { } // consume consumes the current token if it is of the given type. -func (p *Parser) consume(tt token.TokenType, message string) *ErrorExpr { +func (p *Parser) consume(tt token.TokenType, message string) *ast.ErrorExpr { if p.check(tt) { p.advance() return nil @@ -183,9 +184,9 @@ func (p *Parser) previous() token.Token { } // newErrorExpr creates a new ErrorExpr and reports the error. -func (p *Parser) newErrorExpr(t token.Token, message string) *ErrorExpr { +func (p *Parser) newErrorExpr(t token.Token, message string) *ast.ErrorExpr { p.errLogger.ErrorAtToken(t, message) - return &ErrorExpr{Value: message} + return &ast.ErrorExpr{Value: message} } // synchronize synchronizes the parser after an error. diff --git a/parser/parser_test.go b/parser/parser_test.go new file mode 100644 index 0000000..9974b0e --- /dev/null +++ b/parser/parser_test.go @@ -0,0 +1,91 @@ +package parser + +import ( + "golox/ast" + "golox/token" + "testing" +) + +type mockErrorLogger struct{} + +func (el *mockErrorLogger) Error(line int, message string) { +} + +func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) { +} + +func newMockErrorLogger() *mockErrorLogger { + return &mockErrorLogger{} +} + +func TestParser(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "Simple expression", + tokens: []token.Token{ + {Type: token.NUMBER, Literal: 1}, + {Type: token.PLUS, Lexeme: "+"}, + {Type: token.NUMBER, Literal: 2}, + {Type: token.EOF}, + }, + expected: "(+ 1 2)", + }, + { + name: "Unary expression", + tokens: []token.Token{ + {Type: token.MINUS, Lexeme: "-"}, + {Type: token.NUMBER, Literal: 123}, + {Type: token.EOF}, + }, + expected: "(- 123)", + }, + { + name: "Grouping expression", + tokens: []token.Token{ + {Type: token.LEFT_PAREN, Lexeme: "("}, + {Type: token.NUMBER, Literal: 1}, + {Type: token.PLUS, Lexeme: "+"}, + {Type: token.NUMBER, Literal: 2}, + {Type: token.RIGHT_PAREN, Lexeme: ")"}, + {Type: token.EOF}, + }, + expected: "(group (+ 1 2))", + }, + { + name: "Comparison expression", + tokens: []token.Token{ + {Type: token.NUMBER, Literal: 1}, + {Type: token.GREATER, Lexeme: ">"}, + {Type: token.NUMBER, Literal: 2}, + {Type: token.EOF}, + }, + expected: "(> 1 2)", + }, + { + name: "Equality expression", + tokens: []token.Token{ + {Type: token.NUMBER, Literal: 1}, + {Type: token.EQUAL_EQUAL, Lexeme: "=="}, + {Type: token.NUMBER, Literal: 2}, + {Type: token.EOF}, + }, + expected: "(== 1 2)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, newMockErrorLogger()) + expr := parser.Parse() + ap := ast.NewPrinter() + s := ap.Print(expr) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 004baae..824866e 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -310,32 +310,33 @@ func TestScanToken(t *testing.T) { name string source string expected token.TokenType + lexeme string }{ - {"Left paren", "(", token.LEFT_PAREN}, - {"Right paren", ")", token.RIGHT_PAREN}, - {"Left brace", "{", token.LEFT_BRACE}, - {"Right brace", "}", token.RIGHT_BRACE}, - {"Comma", ",", token.COMMA}, - {"Dot", ".", token.DOT}, - {"Minus", "-", token.MINUS}, - {"Plus", "+", token.PLUS}, - {"Semicolon", ";", token.SEMICOLON}, - {"Star", "*", token.STAR}, - {"Bang", "!", token.BANG}, - {"Bang equal", "!=", token.BANG_EQUAL}, - {"Equal", "=", token.EQUAL}, - {"Equal equal", "==", token.EQUAL_EQUAL}, - {"Less", "<", token.LESS}, - {"Less equal", "<=", token.LESS_EQUAL}, - {"Greater", ">", token.GREATER}, - {"Greater equal", ">=", token.GREATER_EQUAL}, - {"Slash", "/", token.SLASH}, - {"Comment", "// comment\n", token.EOF}, - {"Whitespace", " \r\t\n", token.EOF}, - {"String", `"hello"`, token.STRING}, - {"Number", "123", token.NUMBER}, - {"Identifier", "var", token.VAR}, - {"Unexpected character", "@", token.EOF}, + {"Left paren", "(", token.LEFT_PAREN, "("}, + {"Right paren", ")", token.RIGHT_PAREN, ")"}, + {"Left brace", "{", token.LEFT_BRACE, "{"}, + {"Right brace", "}", token.RIGHT_BRACE, "}"}, + {"Comma", ",", token.COMMA, ","}, + {"Dot", ".", token.DOT, "."}, + {"Minus", "-", token.MINUS, "-"}, + {"Plus", "+", token.PLUS, "+"}, + {"Semicolon", ";", token.SEMICOLON, ";"}, + {"Star", "*", token.STAR, "*"}, + {"Bang", "!", token.BANG, "!"}, + {"Bang equal", "!=", token.BANG_EQUAL, "!="}, + {"Equal", "=", token.EQUAL, "="}, + {"Equal equal", "==", token.EQUAL_EQUAL, "=="}, + {"Less", "<", token.LESS, "<"}, + {"Less equal", "<=", token.LESS_EQUAL, "<="}, + {"Greater", ">", token.GREATER, ">"}, + {"Greater equal", ">=", token.GREATER_EQUAL, ">="}, + {"Slash", "/", token.SLASH, "/"}, + {"Comment", "// comment\n", token.EOF, ""}, + {"Whitespace", " \r\t\n", token.EOF, ""}, + {"String", `"hello"`, token.STRING, `"hello"`}, + {"Number", "123", token.NUMBER, "123"}, + {"Identifier", "var", token.VAR, "var"}, + {"Unexpected character", "@", token.EOF, ""}, } for _, tt := range tests { @@ -346,6 +347,9 @@ func TestScanToken(t *testing.T) { if scanner.tokens[0].Type != tt.expected { t.Errorf("expected %v, got %v", tt.expected, scanner.tokens[0].Type) } + if scanner.tokens[0].Lexeme != tt.lexeme { + t.Errorf("expected %v, got %v", tt.lexeme, scanner.tokens[0].Lexeme) + } } else if tt.expected != token.EOF { t.Errorf("expected %v, got no tokens", tt.expected) }