Fix expression parser

main
oabrivard 1 year ago
parent 2aae58b670
commit e0e921602d

@ -1 +1 @@
5f28a0838ca9382a947da5bd52756cf3 f630c7b1cf9f3b55f1467ac4c8704152

@ -31,11 +31,11 @@ tasks:
astgen: astgen:
desc: "Generate AST nodes" desc: "Generate AST nodes"
cmds: cmds:
- go run cmd/astgen/main.go "./parser" - go run cmd/astgen/main.go "./ast"
sources: sources:
- cmd/astgen/main.go - cmd/astgen/main.go
generates: generates:
- parser/expr.go - ast/expr.go
clean: clean:
desc: "Clean up" desc: "Clean up"

@ -1,4 +1,4 @@
package parser package ast
import "golox/token" import "golox/token"

@ -2,44 +2,43 @@ package ast
import ( import (
"fmt" "fmt"
"golox/parser"
) )
type Printer struct { type Printer struct {
} }
func New() *Printer { func NewPrinter() *Printer {
return &Printer{} return &Printer{}
} }
func (ap *Printer) Print(expr parser.Expr) string { func (ap *Printer) Print(expr Expr) string {
return expr.Accept(ap).(string) return expr.Accept(ap).(string)
} }
func (ap *Printer) VisitBinaryExpr(expr *parser.BinaryExpr) any { func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any {
return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right)
} }
func (ap *Printer) VisitGroupingExpr(expr *parser.GroupingExpr) any { func (ap *Printer) VisitGroupingExpr(expr *GroupingExpr) any {
return ap.parenthesize("group", expr.Expression) return ap.parenthesize("group", expr.Expression)
} }
func (ap *Printer) VisitLiteralExpr(expr *parser.LiteralExpr) any { func (ap *Printer) VisitLiteralExpr(expr *LiteralExpr) any {
if expr.Value == nil { if expr.Value == nil {
return "nil" return "nil"
} }
return fmt.Sprint(expr.Value) return fmt.Sprint(expr.Value)
} }
func (ap *Printer) VisitUnaryExpr(expr *parser.UnaryExpr) any { func (ap *Printer) VisitUnaryExpr(expr *UnaryExpr) any {
return ap.parenthesize(expr.Operator.Lexeme, expr.Right) return ap.parenthesize(expr.Operator.Lexeme, expr.Right)
} }
func (ap *Printer) VisitErrorExpr(expr *parser.ErrorExpr) any { func (ap *Printer) VisitErrorExpr(expr *ErrorExpr) any {
return expr.Value return expr.Value
} }
func (ap *Printer) parenthesize(name string, exprs ...parser.Expr) string { func (ap *Printer) parenthesize(name string, exprs ...Expr) string {
str := "(" + name str := "(" + name
for _, expr := range exprs { for _, expr := range exprs {

@ -0,0 +1,64 @@
package ast
import (
"golox/token"
"testing"
)
func TestPrinter(t *testing.T) {
tests := []struct {
name string
expr Expr
expected string
}{
{
name: "Binary expression",
expr: &BinaryExpr{
Left: &LiteralExpr{Value: 1},
Operator: token.Token{Type: token.PLUS, Lexeme: "+"},
Right: &LiteralExpr{Value: 2},
},
expected: "(+ 1 2)",
},
{
name: "Grouping expression",
expr: &GroupingExpr{
Expression: &LiteralExpr{Value: 1},
},
expected: "(group 1)",
},
{
name: "Literal expression",
expr: &LiteralExpr{Value: 123},
expected: "123",
},
{
name: "Unary expression",
expr: &UnaryExpr{
Operator: token.Token{Type: token.MINUS, Lexeme: "-"},
Right: &LiteralExpr{Value: 123},
},
expected: "(- 123)",
},
{
name: "Nil literal expression",
expr: &LiteralExpr{Value: nil},
expected: "nil",
},
{
name: "Error expression",
expr: &ErrorExpr{Value: "error"},
expected: "error",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
printer := NewPrinter()
result := printer.Print(tt.expr)
if result != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, result)
}
})
}
}

@ -40,7 +40,7 @@ func defineAst(outputDir, baseName string, types []string) {
} }
defer file.Close() defer file.Close()
file.WriteString("package parser\n\n") file.WriteString("package ast\n\n")
file.WriteString("import \"golox/token\"\n\n") file.WriteString("import \"golox/token\"\n\n")
defineVisitor(file, baseName, types) defineVisitor(file, baseName, types)

@ -80,6 +80,6 @@ func (l *Lox) run(source string) {
return return
} }
p := ast.New() p := ast.NewPrinter()
fmt.Println(p.Print(expr)) fmt.Println(p.Print(expr))
} }

@ -24,7 +24,7 @@ func TestRun(t *testing.T) {
outC <- buf.String() outC <- buf.String()
}() }()
source := "print('Hello, World!');" source := "1+4/2"
l := New() l := New()
l.run(source) l.run(source)
@ -34,7 +34,7 @@ func TestRun(t *testing.T) {
out := <-outC out := <-outC
// reading our temp stdout // reading our temp stdout
expected := source + "\n" expected := "(+ 1 (/ 4 2))\n"
if out != expected { if out != expected {
t.Errorf("run() = %v; want %v", out, expected) t.Errorf("run() = %v; want %v", out, expected)
} }
@ -63,7 +63,7 @@ func TestRunFile(t *testing.T) {
} }
defer os.Remove(tmpfile.Name()) defer os.Remove(tmpfile.Name())
content := "print('Hello, World!');" content := "1+4/2"
if _, err := tmpfile.Write([]byte(content)); err != nil { if _, err := tmpfile.Write([]byte(content)); err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -79,7 +79,7 @@ func TestRunFile(t *testing.T) {
out := <-outC out := <-outC
// reading our temp stdout // reading our temp stdout
expected := "print('Hello, World!');\n" expected := "(+ 1 (/ 4 2))\n"
if out != expected { if out != expected {
t.Errorf("RunFile() = %v; want %v", out, expected) t.Errorf("RunFile() = %v; want %v", out, expected)
} }
@ -166,7 +166,7 @@ func TestRunPrompt(t *testing.T) {
outC <- buf.String() outC <- buf.String()
}() }()
input := "print('Hello, World!');\n\n" input := "1+4/2\n\n"
wIn.Write([]byte(input)) wIn.Write([]byte(input))
wIn.Close() wIn.Close()
@ -178,7 +178,7 @@ func TestRunPrompt(t *testing.T) {
os.Stdout = oldStdout os.Stdout = oldStdout
out := <-outC out := <-outC
expected := "> print('Hello, World!');\n\n> " expected := "> (+ 1 (/ 4 2))\n> "
if out != expected { if out != expected {
t.Errorf("RunPrompt() = %v; want %v", out, expected) t.Errorf("RunPrompt() = %v; want %v", out, expected)
} }

@ -15,6 +15,7 @@
package parser package parser
import ( import (
"golox/ast"
"golox/errors" "golox/errors"
"golox/token" "golox/token"
) )
@ -32,96 +33,96 @@ func New(tokens []token.Token, el errors.Logger) *Parser {
} }
// Parse parses the tokens and returns the AST. // Parse parses the tokens and returns the AST.
func (p *Parser) Parse() Expr { func (p *Parser) Parse() ast.Expr {
return p.expression() return p.expression()
} }
// expression → equality ; // expression → equality ;
func (p *Parser) expression() Expr { func (p *Parser) expression() ast.Expr {
return p.equality() return p.equality()
} }
// equality → comparison ( ( "!=" | "==" ) comparison )* ; // equality → comparison ( ( "!=" | "==" ) comparison )* ;
func (p *Parser) equality() Expr { func (p *Parser) equality() ast.Expr {
expr := p.comparison() expr := p.comparison()
for p.match(token.BANG_EQUAL, token.EQUAL_EQUAL) { for p.match(token.BANG_EQUAL, token.EQUAL_EQUAL) {
operator := p.previous() operator := p.previous()
right := p.comparison() right := p.comparison()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right}
} }
return expr return expr
} }
// comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ; // comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ;
func (p *Parser) comparison() Expr { func (p *Parser) comparison() ast.Expr {
expr := p.term() expr := p.term()
for p.match(token.GREATER, token.GREATER_EQUAL, token.LESS, token.LESS_EQUAL) { for p.match(token.GREATER, token.GREATER_EQUAL, token.LESS, token.LESS_EQUAL) {
operator := p.previous() operator := p.previous()
right := p.term() right := p.term()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right}
} }
return expr return expr
} }
// term → factor ( ( "-" | "+" ) factor )* ; // term → factor ( ( "-" | "+" ) factor )* ;
func (p *Parser) term() Expr { func (p *Parser) term() ast.Expr {
expr := p.factor() expr := p.factor()
for p.match(token.MINUS, token.PLUS) { for p.match(token.MINUS, token.PLUS) {
operator := p.previous() operator := p.previous()
right := p.factor() right := p.factor()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right}
} }
return expr return expr
} }
// factor → unary ( ( "/" | "*" ) unary )* ; // factor → unary ( ( "/" | "*" ) unary )* ;
func (p *Parser) factor() Expr { func (p *Parser) factor() ast.Expr {
expr := p.unary() expr := p.unary()
for p.match(token.SLASH, token.STAR) { for p.match(token.SLASH, token.STAR) {
operator := p.previous() operator := p.previous()
right := p.unary() right := p.unary()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} expr = &ast.BinaryExpr{Left: expr, Operator: operator, Right: right}
} }
return expr return expr
} }
// unary → ( "!" | "-" ) unary | primary ; // unary → ( "!" | "-" ) unary | primary ;
func (p *Parser) unary() Expr { func (p *Parser) unary() ast.Expr {
if p.match(token.BANG, token.MINUS) { if p.match(token.BANG, token.MINUS) {
operator := p.previous() operator := p.previous()
right := p.unary() right := p.unary()
return &UnaryExpr{Operator: operator, Right: right} return &ast.UnaryExpr{Operator: operator, Right: right}
} }
return p.primary() return p.primary()
} }
// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ; // primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ;
func (p *Parser) primary() Expr { func (p *Parser) primary() ast.Expr {
switch { switch {
case p.match(token.FALSE): case p.match(token.FALSE):
return &LiteralExpr{Value: false} return &ast.LiteralExpr{Value: false}
case p.match(token.TRUE): case p.match(token.TRUE):
return &LiteralExpr{Value: true} return &ast.LiteralExpr{Value: true}
case p.match(token.NIL): case p.match(token.NIL):
return &LiteralExpr{Value: nil} return &ast.LiteralExpr{Value: nil}
case p.match(token.NUMBER, token.STRING): case p.match(token.NUMBER, token.STRING):
return &LiteralExpr{Value: p.previous().Literal} return &ast.LiteralExpr{Value: p.previous().Literal}
case p.match(token.LEFT_PAREN): case p.match(token.LEFT_PAREN):
expr := p.expression() expr := p.expression()
err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.") err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.")
if err != nil { if err != nil {
return err return err
} }
return &GroupingExpr{Expression: expr} return &ast.GroupingExpr{Expression: expr}
} }
return p.newErrorExpr(p.peek(), "Expect expression.") return p.newErrorExpr(p.peek(), "Expect expression.")
@ -140,7 +141,7 @@ func (p *Parser) match(types ...token.TokenType) bool {
} }
// consume consumes the current token if it is of the given type. // consume consumes the current token if it is of the given type.
func (p *Parser) consume(tt token.TokenType, message string) *ErrorExpr { func (p *Parser) consume(tt token.TokenType, message string) *ast.ErrorExpr {
if p.check(tt) { if p.check(tt) {
p.advance() p.advance()
return nil return nil
@ -183,9 +184,9 @@ func (p *Parser) previous() token.Token {
} }
// newErrorExpr creates a new ErrorExpr and reports the error. // newErrorExpr creates a new ErrorExpr and reports the error.
func (p *Parser) newErrorExpr(t token.Token, message string) *ErrorExpr { func (p *Parser) newErrorExpr(t token.Token, message string) *ast.ErrorExpr {
p.errLogger.ErrorAtToken(t, message) p.errLogger.ErrorAtToken(t, message)
return &ErrorExpr{Value: message} return &ast.ErrorExpr{Value: message}
} }
// synchronize synchronizes the parser after an error. // synchronize synchronizes the parser after an error.

@ -0,0 +1,91 @@
package parser
import (
"golox/ast"
"golox/token"
"testing"
)
type mockErrorLogger struct{}
func (el *mockErrorLogger) Error(line int, message string) {
}
func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) {
}
func newMockErrorLogger() *mockErrorLogger {
return &mockErrorLogger{}
}
func TestParser(t *testing.T) {
tests := []struct {
name string
tokens []token.Token
expected string
}{
{
name: "Simple expression",
tokens: []token.Token{
{Type: token.NUMBER, Literal: 1},
{Type: token.PLUS, Lexeme: "+"},
{Type: token.NUMBER, Literal: 2},
{Type: token.EOF},
},
expected: "(+ 1 2)",
},
{
name: "Unary expression",
tokens: []token.Token{
{Type: token.MINUS, Lexeme: "-"},
{Type: token.NUMBER, Literal: 123},
{Type: token.EOF},
},
expected: "(- 123)",
},
{
name: "Grouping expression",
tokens: []token.Token{
{Type: token.LEFT_PAREN, Lexeme: "("},
{Type: token.NUMBER, Literal: 1},
{Type: token.PLUS, Lexeme: "+"},
{Type: token.NUMBER, Literal: 2},
{Type: token.RIGHT_PAREN, Lexeme: ")"},
{Type: token.EOF},
},
expected: "(group (+ 1 2))",
},
{
name: "Comparison expression",
tokens: []token.Token{
{Type: token.NUMBER, Literal: 1},
{Type: token.GREATER, Lexeme: ">"},
{Type: token.NUMBER, Literal: 2},
{Type: token.EOF},
},
expected: "(> 1 2)",
},
{
name: "Equality expression",
tokens: []token.Token{
{Type: token.NUMBER, Literal: 1},
{Type: token.EQUAL_EQUAL, Lexeme: "=="},
{Type: token.NUMBER, Literal: 2},
{Type: token.EOF},
},
expected: "(== 1 2)",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := New(tt.tokens, newMockErrorLogger())
expr := parser.Parse()
ap := ast.NewPrinter()
s := ap.Print(expr)
if s != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, s)
}
})
}
}

@ -310,32 +310,33 @@ func TestScanToken(t *testing.T) {
name string name string
source string source string
expected token.TokenType expected token.TokenType
lexeme string
}{ }{
{"Left paren", "(", token.LEFT_PAREN}, {"Left paren", "(", token.LEFT_PAREN, "("},
{"Right paren", ")", token.RIGHT_PAREN}, {"Right paren", ")", token.RIGHT_PAREN, ")"},
{"Left brace", "{", token.LEFT_BRACE}, {"Left brace", "{", token.LEFT_BRACE, "{"},
{"Right brace", "}", token.RIGHT_BRACE}, {"Right brace", "}", token.RIGHT_BRACE, "}"},
{"Comma", ",", token.COMMA}, {"Comma", ",", token.COMMA, ","},
{"Dot", ".", token.DOT}, {"Dot", ".", token.DOT, "."},
{"Minus", "-", token.MINUS}, {"Minus", "-", token.MINUS, "-"},
{"Plus", "+", token.PLUS}, {"Plus", "+", token.PLUS, "+"},
{"Semicolon", ";", token.SEMICOLON}, {"Semicolon", ";", token.SEMICOLON, ";"},
{"Star", "*", token.STAR}, {"Star", "*", token.STAR, "*"},
{"Bang", "!", token.BANG}, {"Bang", "!", token.BANG, "!"},
{"Bang equal", "!=", token.BANG_EQUAL}, {"Bang equal", "!=", token.BANG_EQUAL, "!="},
{"Equal", "=", token.EQUAL}, {"Equal", "=", token.EQUAL, "="},
{"Equal equal", "==", token.EQUAL_EQUAL}, {"Equal equal", "==", token.EQUAL_EQUAL, "=="},
{"Less", "<", token.LESS}, {"Less", "<", token.LESS, "<"},
{"Less equal", "<=", token.LESS_EQUAL}, {"Less equal", "<=", token.LESS_EQUAL, "<="},
{"Greater", ">", token.GREATER}, {"Greater", ">", token.GREATER, ">"},
{"Greater equal", ">=", token.GREATER_EQUAL}, {"Greater equal", ">=", token.GREATER_EQUAL, ">="},
{"Slash", "/", token.SLASH}, {"Slash", "/", token.SLASH, "/"},
{"Comment", "// comment\n", token.EOF}, {"Comment", "// comment\n", token.EOF, ""},
{"Whitespace", " \r\t\n", token.EOF}, {"Whitespace", " \r\t\n", token.EOF, ""},
{"String", `"hello"`, token.STRING}, {"String", `"hello"`, token.STRING, `"hello"`},
{"Number", "123", token.NUMBER}, {"Number", "123", token.NUMBER, "123"},
{"Identifier", "var", token.VAR}, {"Identifier", "var", token.VAR, "var"},
{"Unexpected character", "@", token.EOF}, {"Unexpected character", "@", token.EOF, ""},
} }
for _, tt := range tests { for _, tt := range tests {
@ -346,6 +347,9 @@ func TestScanToken(t *testing.T) {
if scanner.tokens[0].Type != tt.expected { if scanner.tokens[0].Type != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, scanner.tokens[0].Type) t.Errorf("expected %v, got %v", tt.expected, scanner.tokens[0].Type)
} }
if scanner.tokens[0].Lexeme != tt.lexeme {
t.Errorf("expected %v, got %v", tt.lexeme, scanner.tokens[0].Lexeme)
}
} else if tt.expected != token.EOF { } else if tt.expected != token.EOF {
t.Errorf("expected %v, got no tokens", tt.expected) t.Errorf("expected %v, got no tokens", tt.expected)
} }

Loading…
Cancel
Save