diff --git a/.task/checksum/astgen b/.task/checksum/astgen index cbdfebb..b08fd33 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -4ade8f4433feb634e5c96f97b05adfc1 +53f18a0e7ed4ab1114f57c8bbdaf9abd diff --git a/ast/printer.go b/ast/printer.go index 3ec69e4..800034e 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -41,6 +41,16 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer) } +func (ap *Printer) VisitBlockStmt(stmt *BlockStmt) any { + str := "(block\n" + + for _, s := range stmt.Statements { + str += " " + s.Accept(ap).(string) + "\n" + } + + return str + ")" +} + func (ap *Printer) VisitVariableExpr(expr *VariableExpr) any { return expr.Name.Lexeme } diff --git a/ast/printer_test.go b/ast/printer_test.go index a6e6b5e..72cdf33 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -107,3 +107,23 @@ func TestPrintStmts(t *testing.T) { }) } } + +func TestPrintBlockStmt(t *testing.T) { + stmt := &BlockStmt{ + Statements: []Stmt{ + &ExpressionStmt{ + Expression: &LiteralExpr{Value: 1}, + }, + &ExpressionStmt{ + Expression: &LiteralExpr{Value: 2}, + }, + }, + } + + printer := NewPrinter() + result := printer.PrintStmts([]Stmt{stmt}) + expected := "(block\n 1\n 2\n)\n" + if result != expected { + t.Errorf("expected %v, got %v", expected, result) + } +} diff --git a/ast/stmt.go b/ast/stmt.go index 2f80ec8..3ceeeea 100644 --- a/ast/stmt.go +++ b/ast/stmt.go @@ -4,6 +4,7 @@ import "golox/token" type StmtVisitor[T any] interface { VisitErrorStmt(es *ErrorStmt) T + VisitBlockStmt(bs *BlockStmt) T VisitExpressionStmt(es *ExpressionStmt) T VisitPrintStmt(ps *PrintStmt) T VisitVarStmt(vs *VarStmt) T @@ -21,6 +22,14 @@ func (es *ErrorStmt) Accept(v StmtVisitor[any]) any { return v.VisitErrorStmt(es) } +type BlockStmt struct { + Statements []Stmt +} + +func (bs *BlockStmt) Accept(v StmtVisitor[any]) any { + return v.VisitBlockStmt(bs) +} + type ExpressionStmt struct { Expression Expr } diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index 0384fc8..6fa24f7 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -33,6 +33,7 @@ func main() { defineAst(d, "Stmt", []string{ "Error : Value string", + "Block : Statements []Stmt", "Expression : Expression Expr", "Print : Expression Expr", "Var : Name token.Token, Initializer Expr", diff --git a/interpreter/environment.go b/interpreter/environment.go index e97137d..51b895f 100644 --- a/interpreter/environment.go +++ b/interpreter/environment.go @@ -2,12 +2,14 @@ package interpreter // environment represents the environment in which the interpreter operates. type environment struct { - values map[string]any + values map[string]any + enclosing *environment } // newEnvironment creates a new environment. -func newEnvironment() *environment { - return &environment{values: make(map[string]any)} +// e is the enclosing environment and is nil for the global environment. +func newEnvironment(e *environment) *environment { + return &environment{values: make(map[string]any), enclosing: e} } // define defines a new variable in the environment. @@ -19,6 +21,10 @@ func (e *environment) define(name string, value any) { func (e *environment) get(name string) any { value, ok := e.values[name] if !ok { + if e.enclosing != nil { + return e.enclosing.get(name) + } + panic("Undefined variable '" + name + "'.") } @@ -29,6 +35,11 @@ func (e *environment) get(name string) any { func (e *environment) assign(name string, value any) { _, ok := e.values[name] if !ok { + if e.enclosing != nil { + e.enclosing.assign(name, value) + return + } + panic("Undefined variable '" + name + "'.") } diff --git a/interpreter/environment_test.go b/interpreter/environment_test.go new file mode 100644 index 0000000..3c72f0a --- /dev/null +++ b/interpreter/environment_test.go @@ -0,0 +1,84 @@ +package interpreter + +import ( + "testing" +) + +func TestDefineAndGet(t *testing.T) { + env := newEnvironment(nil) + env.define("x", 42) + + value := env.get("x") + if value != 42 { + t.Errorf("expected 42, got %v", value) + } +} + +func TestGetUndefinedVariable(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic for undefined variable") + } + }() + + env := newEnvironment(nil) + env.get("x") +} + +func TestAssign(t *testing.T) { + env := newEnvironment(nil) + env.define("x", 42) + env.assign("x", 43) + + value := env.get("x") + if value != 43 { + t.Errorf("expected 43, got %v", value) + } +} + +func TestAssignUndefinedVariable(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic for assigning undefined variable") + } + }() + + env := newEnvironment(nil) + env.assign("x", 43) +} + +func TestEnclosingEnvironmentGet(t *testing.T) { + global := newEnvironment(nil) + global.define("x", 42) + + local := newEnvironment(global) + value := local.get("x") + if value != 42 { + t.Errorf("expected 42, got %v", value) + } +} + +func TestEnclosingEnvironmentAssign(t *testing.T) { + global := newEnvironment(nil) + global.define("x", 42) + + local := newEnvironment(global) + local.assign("x", 43) + + value := global.get("x") + if value != 43 { + t.Errorf("expected 43, got %v", value) + } +} + +func TestEnclosingEnvironmentAssignUndefined(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expected panic for assigning undefined variable in enclosing environment") + } + }() + + global := newEnvironment(nil) + local := newEnvironment(global) + local.assign("x", 43) +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 4716a1d..85877f0 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -16,7 +16,7 @@ type Interpreter struct { // New creates a new Interpreter. func New(el errors.Logger) *Interpreter { - return &Interpreter{el, newEnvironment()} + return &Interpreter{el, newEnvironment(nil)} } // Interpret interprets the AST. @@ -57,6 +57,12 @@ func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any { return nil } +// VisitBlockStmt visits a block statement. +func (i *Interpreter) VisitBlockStmt(bs *ast.BlockStmt) any { + i.executeBlock(bs.Statements, newEnvironment(i.env)) + return nil +} + // Interpret interprets the AST. func (i *Interpreter) InterpretExpr(expr ast.Expr) string { defer i.afterPanic() @@ -206,6 +212,20 @@ func (i *Interpreter) execute(s ast.Stmt) { s.Accept(i) } +// executeBlock executes a block of statements. +func (i *Interpreter) executeBlock(stmts []ast.Stmt, env *environment) { + previous := i.env + defer func() { + i.env = previous + }() + + i.env = env + + for _, stmt := range stmts { + i.execute(stmt) + } +} + // stringify returns a string representation of a value. func stringify(v any) string { if v == nil { diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index c7b73eb..6bb0ea1 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -388,3 +388,70 @@ func TestInterpretAssignment(t *testing.T) { t.Errorf("expected 42, got %v", result) } } + +func TestInterpretAssignmentUndefinedVariable(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + assign := &ast.AssignExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Value: &ast.LiteralExpr{Value: 42.0}, + } + + defer func() { + if r := recover(); r != "Undefined variable 'foo'." { + t.Errorf("expected panic with 'undefined variable', got %v", r) + } + }() + + i.VisitAssignExpr(assign) +} + +func TestInterpretBlockStatement(t *testing.T) { + old := os.Stdout // keep backup of the real stdout + r, w, err := os.Pipe() + if err != nil { + t.Fatal(err) + } + os.Stdout = w + + outC := make(chan string) + // copy the output in a separate goroutine so printing can't block indefinitely + go func() { + var buf bytes.Buffer + io.Copy(&buf, r) + outC <- buf.String() + }() + + // begin unit test + i := New(errors.NewMockErrorLogger()) + block := &ast.BlockStmt{ + Statements: []ast.Stmt{ + &ast.VarStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Initializer: &ast.LiteralExpr{Value: 42.0}, + }, + &ast.PrintStmt{ + Expression: &ast.VariableExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + }, + }, + }, + } + + i.VisitBlockStmt(block) + _, found := i.env.values["foo"] + if found { + t.Errorf("expected to not find 'foo' in environment") + } + // end unit test + + // back to normal state + w.Close() + os.Stdout = old // restoring the real stdout + out := <-outC + + // reading our temp stdout + expected := "42\n" + if out != expected { + t.Errorf("run() = %v; want %v", out, expected) + } +} diff --git a/parser/parser.go b/parser/parser.go index ec84297..ed74b24 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -83,15 +83,34 @@ func (p *Parser) varDeclaration() ast.Stmt { return &ast.VarStmt{Name: name, Initializer: initializer} } -// statement → exprStmt | printStmt ; +// statement → exprStmt | printStmt | block ; func (p *Parser) statement() ast.Stmt { if p.match(token.PRINT) { return p.printStatement() } + if p.match(token.LEFT_BRACE) { + return p.blockStatement() + } return p.expressionStatement() } +// block → "{" declaration* "}" ; +func (p *Parser) blockStatement() ast.Stmt { + statements := []ast.Stmt{} + + for !p.check(token.RIGHT_BRACE) && !p.isAtEnd() { + statements = append(statements, p.declaration()) + } + + err := p.consume(token.RIGHT_BRACE, "Expect '}' after block.") + if err != nil { + return p.fromErrorExpr(err) + } + + return &ast.BlockStmt{Statements: statements} +} + // exprStmt → expression ";" ; func (p *Parser) expressionStatement() ast.Stmt { expr := p.expression() diff --git a/parser/parser_test.go b/parser/parser_test.go index 51b105a..875b0db 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -259,3 +259,42 @@ func TestParseAssignment(t *testing.T) { }) } } + +func TestParseBlockStatement(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple block statement", + tokens: []token.Token{ + {Type: token.LEFT_BRACE, Lexeme: "{"}, + {Type: token.VAR, Lexeme: "var"}, + {Type: token.IDENTIFIER, Lexeme: "foo"}, + {Type: token.EQUAL, Lexeme: "="}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.RIGHT_BRACE, Lexeme: "}"}, + {Type: token.EOF}, + }, + expected: "(block\n (var foo 42)\n)\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +}