diff --git a/.task/checksum/astgen b/.task/checksum/astgen index aeceed0..10f1d8c 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -bf6f2cc6973975f2bcc789463842856a +ea7730697fb5d7edf1961630c5f68b96 diff --git a/ast/expr.go b/ast/expr.go index e7fd652..a4ddc0b 100644 --- a/ast/expr.go +++ b/ast/expr.go @@ -8,6 +8,7 @@ type ExprVisitor[T any] interface { VisitGroupingExpr(ge *GroupingExpr) T VisitLiteralExpr(le *LiteralExpr) T VisitUnaryExpr(ue *UnaryExpr) T + VisitVariableExpr(ve *VariableExpr) T } type Expr interface { @@ -57,3 +58,11 @@ func (ue *UnaryExpr) Accept(v ExprVisitor[any]) any { return v.VisitUnaryExpr(ue) } +type VariableExpr struct { + Name token.Token +} + +func (ve *VariableExpr) Accept(v ExprVisitor[any]) any { + return v.VisitVariableExpr(ve) +} + diff --git a/ast/printer.go b/ast/printer.go index 36bd731..985a7b4 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -41,6 +41,10 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer) } +func (ap *Printer) VisitVariableExpr(expr *VariableExpr) any { + return expr.Name.Lexeme +} + func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any { return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) } @@ -68,6 +72,9 @@ func (ap *Printer) parenthesize(name string, exprs ...Expr) string { str := "(" + name for _, expr := range exprs { + if expr == nil { + continue + } str += " " + expr.Accept(ap).(string) } diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index 0f07a87..0ca5b56 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -27,6 +27,7 @@ func main() { "Grouping : Expression Expr", "Literal : Value any", "Unary : Operator token.Token, Right Expr", + "Variable : Name token.Token", }) defineAst(d, "Stmt", []string{ diff --git a/interpreter/environment.go b/interpreter/environment.go new file mode 100644 index 0000000..e5a9c21 --- /dev/null +++ b/interpreter/environment.go @@ -0,0 +1,22 @@ +package interpreter + +type environment struct { + values map[string]any +} + +func newEnvironment() *environment { + return &environment{values: make(map[string]any)} +} + +func (e *environment) define(name string, value any) { + e.values[name] = value +} + +func (e *environment) get(name string) any { + value, ok := e.values[name] + if !ok { + panic("Undefined variable '" + name + "'.") + } + + return value +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 3ee4360..0897339 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -11,11 +11,12 @@ import ( // Interpreter interprets the AST. type Interpreter struct { errLogger errors.Logger + env *environment } // New creates a new Interpreter. func New(el errors.Logger) *Interpreter { - return &Interpreter{el} + return &Interpreter{el, newEnvironment()} } // Interpret interprets the AST. @@ -47,6 +48,12 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any { // VisitVarStmt visits a var statement. func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any { + var value any + if vs.Initializer != nil { + value = i.evaluate(vs.Initializer) + } + + i.env.define(vs.Name.Lexeme, value) return nil } @@ -142,6 +149,11 @@ func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any { panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line)) } +// VisitVariableExpr visits a VariableExpr. +func (i *Interpreter) VisitVariableExpr(v *ast.VariableExpr) any { + return i.env.get(v.Name.Lexeme) +} + // checkNumberOperands checks if the operands are numbers. func checkNumberOperands(operator token.Token, operands ...any) { for _, operand := range operands { diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index ecf78aa..6d51431 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -1,16 +1,18 @@ // FILE: interpreter_test.go -package interpreter_test +package interpreter import ( + "bytes" "golox/ast" "golox/errors" - "golox/interpreter" "golox/token" + "io" + "os" "testing" ) func TestInterpretLiteralExpr(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42} result := i.VisitLiteralExpr(literal) @@ -20,7 +22,7 @@ func TestInterpretLiteralExpr(t *testing.T) { } func TestInterpretGroupingExpr(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42} grouping := &ast.GroupingExpr{Expression: literal} @@ -31,7 +33,7 @@ func TestInterpretGroupingExpr(t *testing.T) { } func TestInterpretUnaryExpr(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42.0} unary := &ast.UnaryExpr{ Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, @@ -45,7 +47,7 @@ func TestInterpretUnaryExpr(t *testing.T) { } func TestInterpretUnaryExprBang(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: true} unary := &ast.UnaryExpr{ Operator: token.Token{Type: token.BANG, Lexeme: "!"}, @@ -59,7 +61,7 @@ func TestInterpretUnaryExprBang(t *testing.T) { } func TestInterpretErrorExpr(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) errorExpr := &ast.ErrorExpr{Value: "error"} defer func() { @@ -72,7 +74,7 @@ func TestInterpretErrorExpr(t *testing.T) { } func TestInterpretExpr(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42.0} defer func() { @@ -88,7 +90,7 @@ func TestInterpretExpr(t *testing.T) { } func TestInterpretBinaryExpr(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -104,7 +106,7 @@ func TestInterpretBinaryExpr(t *testing.T) { } func TestInterpretBinaryExprDivisionByZero(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 0.0} binary := &ast.BinaryExpr{ @@ -123,7 +125,7 @@ func TestInterpretBinaryExprDivisionByZero(t *testing.T) { } func TestInterpretBinaryExprAddition(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -139,7 +141,7 @@ func TestInterpretBinaryExprAddition(t *testing.T) { } func TestInterpretBinaryExprSubtraction(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -155,7 +157,7 @@ func TestInterpretBinaryExprSubtraction(t *testing.T) { } func TestInterpretBinaryExprStringConcatenation(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: "foo"} right := &ast.LiteralExpr{Value: "bar"} binary := &ast.BinaryExpr{ @@ -171,7 +173,7 @@ func TestInterpretBinaryExprStringConcatenation(t *testing.T) { } func TestInterpretBinaryExprInvalidOperands(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: "foo"} right := &ast.LiteralExpr{Value: 42.0} binary := &ast.BinaryExpr{ @@ -190,7 +192,7 @@ func TestInterpretBinaryExprInvalidOperands(t *testing.T) { } func TestInterpretBinaryExprComparison(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -206,7 +208,7 @@ func TestInterpretBinaryExprComparison(t *testing.T) { } func TestInterpretBinaryExprComparisonEqual(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 42.0} binary := &ast.BinaryExpr{ @@ -222,7 +224,7 @@ func TestInterpretBinaryExprComparisonEqual(t *testing.T) { } func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -238,7 +240,7 @@ func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) { } func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: "foo"} right := &ast.LiteralExpr{Value: 42.0} binary := &ast.BinaryExpr{ @@ -257,7 +259,7 @@ func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) { } func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { - i := interpreter.New(errors.NewMockErrorLogger()) + i := New(errors.NewMockErrorLogger()) left := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 2.0} binary := &ast.BinaryExpr{ @@ -274,3 +276,65 @@ func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { i.VisitBinaryExpr(binary) } + +func TestInterpretExprStatement(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + literal := &ast.LiteralExpr{Value: 42.0} + exprStmt := &ast.ExpressionStmt{Expression: literal} + + result := i.VisitExpressionStmt(exprStmt) + if result != nil { + t.Errorf("expected nil, got %v", result) + } +} + +func TestInterpretPrintStatement(t *testing.T) { + old := os.Stdout // keep backup of the real stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = w + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + i := New(errors.NewMockErrorLogger()) + literal := &ast.LiteralExpr{Value: 42.0} + printStmt := &ast.PrintStmt{Expression: literal} + + result := i.VisitPrintStmt(printStmt) + if result != nil { + t.Errorf("expected nil, got %v", result) + } + + // back to normal state + w.Close() + os.Stdout = old // restoring the real stdout + out := <-outC + + // reading our temp stdout + expected := "42\n" + if out != expected { + t.Errorf("run() = %v; want %v", out, expected) + } +} + +func TestInterpretVarStatement(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + varStmt := &ast.VarStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Initializer: &ast.LiteralExpr{Value: 42.0}, + } + + i.VisitVarStmt(varStmt) + result := i.env.get("foo") + if result != 42.0 { + t.Errorf("expected 42, got %v", result) + } +} diff --git a/parser/parser.go b/parser/parser.go index 154ba22..cbdb0e4 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -42,7 +42,7 @@ func (p *Parser) Parse() []ast.Stmt { stmts := []ast.Stmt{} for !p.isAtEnd() { - stmt := p.statement() + stmt := p.declaration() if _, ok := stmt.(*ast.ErrorStmt); ok { p.synchronize() } else { @@ -53,6 +53,36 @@ func (p *Parser) Parse() []ast.Stmt { return stmts } +// declaration → varDecl | statement ; +func (p *Parser) declaration() ast.Stmt { + if p.match(token.VAR) { + return p.varDeclaration() + } + + return p.statement() +} + +// varDecl → "var" IDENTIFIER ( "=" expression )? ";" ; +func (p *Parser) varDeclaration() ast.Stmt { + err := p.consume(token.IDENTIFIER, "Expect variable name.") + if err != nil { + return p.fromErrorExpr(err) + } + name := p.previous() + + var initializer ast.Expr + if p.match(token.EQUAL) { + initializer = p.expression() + } + + err = p.consume(token.SEMICOLON, "Expect ';' after variable declaration.") + if err != nil { + return p.fromErrorExpr(err) + } + + return &ast.VarStmt{Name: name, Initializer: initializer} +} + // statement → exprStmt | printStmt ; func (p *Parser) statement() ast.Stmt { if p.match(token.PRINT) { @@ -152,7 +182,7 @@ func (p *Parser) unary() ast.Expr { return p.primary() } -// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ; +// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER; func (p *Parser) primary() ast.Expr { switch { case p.match(token.FALSE): @@ -163,6 +193,8 @@ func (p *Parser) primary() ast.Expr { return &ast.LiteralExpr{Value: nil} case p.match(token.NUMBER, token.STRING): return &ast.LiteralExpr{Value: p.previous().Literal} + case p.match(token.IDENTIFIER): + return &ast.VariableExpr{Name: p.previous()} case p.match(token.LEFT_PAREN): expr := p.expression() err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.") diff --git a/parser/parser_test.go b/parser/parser_test.go index b2cd5ef..e7b5b71 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -176,3 +176,50 @@ func TestParsePrintStmt(t *testing.T) { }) } } + +func TestParseVarStmt(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple var statement", + tokens: []token.Token{ + {Type: token.VAR, Lexeme: "var"}, + {Type: token.IDENTIFIER, Lexeme: "foo"}, + {Type: token.EQUAL, Lexeme: "="}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(var foo 42)\n", + }, + { + name: "simple var statement", + tokens: []token.Token{ + {Type: token.VAR, Lexeme: "var"}, + {Type: token.IDENTIFIER, Lexeme: "foo"}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(var foo)\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +}