diff --git a/errors/logger.go b/errors/logger.go index 8dfb287..3b6953f 100644 --- a/errors/logger.go +++ b/errors/logger.go @@ -8,17 +8,23 @@ type Logger interface { RuntimeError(message string) } -type mockErrorLogger struct{} +type mockErrorLogger struct { + Errors []string + RuntimeErrors []string +} func (el *mockErrorLogger) Error(line int, message string) { + el.Errors = append(el.Errors, message) } func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) { + el.Errors = append(el.Errors, message) } func (el *mockErrorLogger) RuntimeError(message string) { + el.RuntimeErrors = append(el.RuntimeErrors, message) } func NewMockErrorLogger() *mockErrorLogger { - return &mockErrorLogger{} + return &mockErrorLogger{[]string{}, []string{}} } diff --git a/interpreter/environment.go b/interpreter/environment.go index 51b895f..a649500 100644 --- a/interpreter/environment.go +++ b/interpreter/environment.go @@ -17,6 +17,21 @@ func (e *environment) define(name string, value any) { e.values[name] = value } +// ancestor returns the environment at a given distance. +func (e *environment) ancestor(distance int) *environment { + env := e + for i := 0; i < distance; i++ { + env = env.enclosing + } + + return env +} + +// getAt gets the value of a variable at a given distance. +func (e *environment) getAt(distance int, name string) any { + return e.ancestor(distance).values[name] +} + // get gets the value of a variable in the environment. func (e *environment) get(name string) any { value, ok := e.values[name] @@ -31,6 +46,11 @@ func (e *environment) get(name string) any { return value } +// assignAt assigns a new value to a variable at a given distance. +func (e *environment) assignAt(distance int, name string, value any) { + e.ancestor(distance).values[name] = value +} + // assign assigns a new value to a variable in the environment. func (e *environment) assign(name string, value any) { _, ok := e.values[name] diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 2cdf2ce..9b350a7 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -19,6 +19,7 @@ type Interpreter struct { errLogger errors.Logger env *environment globals *environment + locals map[ast.Expr]int } // New creates a new Interpreter. @@ -34,7 +35,7 @@ func New(el errors.Logger) *Interpreter { globals.define("clock", clockCallable) - return &Interpreter{el, globals, globals} + return &Interpreter{el, globals, globals, make(map[ast.Expr]int)} } // Interpret interprets the AST. @@ -134,7 +135,14 @@ func (i *Interpreter) VisitErrorExpr(e *ast.ErrorExpr) any { // VisitAssignExpr visits an AssignExpr. func (i *Interpreter) VisitAssignExpr(a *ast.AssignExpr) any { value := i.evaluate(a.Value) - i.env.assign(a.Name.Lexeme, value) + + depth, ok := i.locals[a] + if ok { + i.env.assignAt(depth, a.Name.Lexeme, value) + } else { + i.globals.assign(a.Name.Lexeme, value) + } + return value } @@ -256,7 +264,17 @@ func (i *Interpreter) VisitCallExpr(c *ast.CallExpr) any { // VisitVariableExpr visits a VariableExpr. func (i *Interpreter) VisitVariableExpr(v *ast.VariableExpr) any { - return i.env.get(v.Name.Lexeme) + return i.lookUpVariable(&v.Name, v) +} + +// lookUpVariable looks up a variable. +func (i *Interpreter) lookUpVariable(name *token.Token, expr ast.Expr) any { + depth, ok := i.locals[expr] + if ok { + return i.env.getAt(depth, name.Lexeme) + } + + return i.globals.get(name.Lexeme) } // checkNumberOperands checks if the operands are numbers. @@ -304,6 +322,11 @@ func (i *Interpreter) execute(s ast.Stmt) { s.Accept(i) } +// resolve resolves the variables depths of an expression +func (i *Interpreter) Resolve(e ast.Expr, depth int) { + i.locals[e] = depth +} + // executeBlock executes a block of statements. func (i *Interpreter) executeBlock(stmts []ast.Stmt, env *environment) { previous := i.env diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index e334d1d..c2719de 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -437,22 +437,27 @@ func TestInterpretBlockStatement(t *testing.T) { outC <- buf.String() }() + le := &ast.LiteralExpr{Value: 42.0} + ve := &ast.VariableExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + } + // begin unit test i := New(errors.NewMockErrorLogger()) block := &ast.BlockStmt{ Statements: []ast.Stmt{ &ast.VarStmt{ Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, - Initializer: &ast.LiteralExpr{Value: 42.0}, + Initializer: le, }, &ast.PrintStmt{ - Expression: &ast.VariableExpr{ - Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, - }, + Expression: ve, }, }, } + i.locals[le] = 0 + i.locals[ve] = 0 i.VisitBlockStmt(block) _, found := i.env.values["foo"] if found { @@ -655,10 +660,13 @@ func TestInterpretFunctionCall(t *testing.T) { func TestInterpretFunctionCallWithArguments(t *testing.T) { i := New(errors.NewMockErrorLogger()) + + ve := &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "a"}} + functionStmt := &ast.FunctionStmt{ Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, Params: []token.Token{{Type: token.IDENTIFIER, Lexeme: "a"}}, - Body: []ast.Stmt{&ast.BlockStmt{Statements: []ast.Stmt{&ast.PrintStmt{Expression: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "a"}}}}}}, + Body: []ast.Stmt{&ast.BlockStmt{Statements: []ast.Stmt{&ast.PrintStmt{Expression: ve}}}}, } i.VisitFunctionStmt(functionStmt) @@ -670,6 +678,7 @@ func TestInterpretFunctionCallWithArguments(t *testing.T) { }, } + i.locals[ve] = 0 i.VisitCallExpr(callExpr) } diff --git a/lox/lox.go b/lox/lox.go index 56f1999..4dee108 100644 --- a/lox/lox.go +++ b/lox/lox.go @@ -5,6 +5,7 @@ import ( "fmt" "golox/interpreter" "golox/parser" + "golox/resolver" "golox/scanner" "golox/token" "os" @@ -100,5 +101,13 @@ func (l *Lox) run(source string) { return } + resolver := resolver.New(l.interpreter, l) + resolver.Resolve(stmts) + + // Stop if there was a resolution error. + if l.hadError { + return + } + l.interpreter.Interpret(stmts) } diff --git a/resolver/resolver.go b/resolver/resolver.go new file mode 100644 index 0000000..9ff3d8f --- /dev/null +++ b/resolver/resolver.go @@ -0,0 +1,261 @@ +package resolver + +import ( + "golox/ast" + "golox/errors" + "golox/interpreter" + "golox/token" +) + +// FuncType represents the type of a function. +type FuncType int + +const ( + // NoneFuncType represents a function with no type. + NoneFuncType FuncType = iota + // FunctionFuncType represents a function. + FunctionFuncType + // MethodFuncType represents a method. + MethodFuncType + // InitializerFuncType represents an initializer. + InitializerFuncType +) + +// Resolver implements variable resolution for the AST. +type Resolver struct { + errLogger errors.Logger + interpreter *interpreter.Interpreter + scopes []map[string]bool + currFunc FuncType +} + +// New creates a new Resolver. +func New(i *interpreter.Interpreter, el errors.Logger) *Resolver { + return &Resolver{el, i, []map[string]bool{}, NoneFuncType} +} + +// VisitBlockStmt visits a block statement. +func (r *Resolver) VisitBlockStmt(bs *ast.BlockStmt) any { + r.beginScope() + r.Resolve(bs.Statements) + r.endScope() + return nil +} + +// Resolve resolves the variables for each statements in stmts. +func (r *Resolver) Resolve(stmts []ast.Stmt) { + for _, stmt := range stmts { + r.resolveStmt(stmt) + } +} + +// resolveStmt resolves a statement. +func (r *Resolver) resolveStmt(stmt ast.Stmt) { + stmt.Accept(r) +} + +// resolveExpr resolves an expression. +func (r *Resolver) resolveExpr(expr ast.Expr) { + expr.Accept(r) +} + +// beginScope begins a new scope. +func (r *Resolver) beginScope() { + r.scopes = append(r.scopes, map[string]bool{}) +} + +// endScope ends the current scope. +func (r *Resolver) endScope() { + r.scopes = r.scopes[:len(r.scopes)-1] +} + +// visitVarStmt visits a variable statement. +func (r *Resolver) VisitVarStmt(vs *ast.VarStmt) any { + r.declare(&vs.Name) + if vs.Initializer != nil { + r.resolveExpr(vs.Initializer) + } + r.define(&vs.Name) + + return nil +} + +// declare declares a variable in the current scope. +func (r *Resolver) declare(name *token.Token) { + if len(r.scopes) == 0 { + return + } + + scope := r.scopes[len(r.scopes)-1] + if _, ok := scope[name.Lexeme]; ok { + r.errLogger.ErrorAtToken(*name, "Variable with this name already declared in this scope.") + } + scope[name.Lexeme] = false +} + +// define defines a variable in the current scope. +func (r *Resolver) define(name *token.Token) { + if len(r.scopes) == 0 { + return + } + + r.scopes[len(r.scopes)-1][name.Lexeme] = true +} + +// VisitVariableExpr visits a variable expression. +func (r *Resolver) VisitVariableExpr(ve *ast.VariableExpr) any { + if len(r.scopes) > 0 { + if _, ok := r.scopes[len(r.scopes)-1][ve.Name.Lexeme]; ok && !r.scopes[len(r.scopes)-1][ve.Name.Lexeme] { + r.errLogger.ErrorAtToken(ve.Name, "Cannot read local variable in its own initializer.") + } + } + + r.resolveLocal(ve, &ve.Name) + + return nil +} + +// resolveLocal resolves a local variable. +func (r *Resolver) resolveLocal(expr ast.Expr, name *token.Token) { + for i := len(r.scopes) - 1; i >= 0; i-- { + if _, ok := r.scopes[i][name.Lexeme]; ok { + r.interpreter.Resolve(expr, len(r.scopes)-1-i) + return + } + } +} + +// VisitAssignExpr visits an assignment expression. +func (r *Resolver) VisitAssignExpr(ae *ast.AssignExpr) any { + r.resolveExpr(ae.Value) + r.resolveLocal(ae, &ae.Name) + + return nil +} + +// VisitFunctionStmt visits a function statement. +func (r *Resolver) VisitFunctionStmt(fs *ast.FunctionStmt) any { + r.declare(&fs.Name) + r.define(&fs.Name) + + r.resolveFunction(fs, FunctionFuncType) + + return nil +} + +// resolveFunction resolves a function. +func (r *Resolver) resolveFunction(fs *ast.FunctionStmt, t FuncType) { + enclosingFunc := r.currFunc + r.currFunc = t + + r.beginScope() + for _, param := range fs.Params { + r.declare(¶m) + r.define(¶m) + } + + r.Resolve(fs.Body) + r.endScope() + + r.currFunc = enclosingFunc +} + +// VisitExpressionStmt visits an expression statement. +func (r *Resolver) VisitExpressionStmt(es *ast.ExpressionStmt) any { + r.resolveExpr(es.Expression) + return nil +} + +// VisitIfStmt visits an if statement. +func (r *Resolver) VisitIfStmt(is *ast.IfStmt) any { + r.resolveExpr(is.Condition) + r.resolveStmt(is.ThenBranch) + if is.ElseBranch != nil { + r.resolveStmt(is.ElseBranch) + } + + return nil +} + +// VisitPrintStmt visits a print statement. +func (r *Resolver) VisitPrintStmt(ps *ast.PrintStmt) any { + r.resolveExpr(ps.Expression) + + return nil +} + +// VisitReturnStmt visits a return statement. +func (r *Resolver) VisitReturnStmt(rs *ast.ReturnStmt) any { + if r.currFunc == NoneFuncType { + r.errLogger.ErrorAtToken(rs.Keyword, "Cannot return from top-level code.") + } + + if rs.Value != nil { + r.resolveExpr(rs.Value) + } + + return nil +} + +// VisitWhileStmt visits a while statement. +func (r *Resolver) VisitWhileStmt(ws *ast.WhileStmt) any { + r.resolveExpr(ws.Condition) + r.resolveStmt(ws.Body) + + return nil +} + +// VisitErrorStmt visits an error statement. +func (r *Resolver) VisitErrorStmt(es *ast.ErrorStmt) any { + return nil +} + +// VisitBinaryExpr visits a binary expression. +func (r *Resolver) VisitBinaryExpr(be *ast.BinaryExpr) any { + r.resolveExpr(be.Left) + r.resolveExpr(be.Right) + + return nil +} + +// VisitCallExpr visits a call expression. +func (r *Resolver) VisitCallExpr(ce *ast.CallExpr) any { + r.resolveExpr(ce.Callee) + for _, arg := range ce.Arguments { + r.resolveExpr(arg) + } + + return nil +} + +// VisitGroupingExpr visits a grouping expression. +func (r *Resolver) VisitGroupingExpr(ge *ast.GroupingExpr) any { + r.resolveExpr(ge.Expression) + + return nil +} + +// VisitLiteralExpr visits a literal expression. +func (r *Resolver) VisitLiteralExpr(le *ast.LiteralExpr) any { + return nil +} + +// VisitLogicalExpr visits a logical expression. +func (r *Resolver) VisitLogicalExpr(le *ast.LogicalExpr) any { + r.resolveExpr(le.Left) + r.resolveExpr(le.Right) + + return nil +} + +// VisitUnaryExpr visits a unary expression. +func (r *Resolver) VisitUnaryExpr(ue *ast.UnaryExpr) any { + r.resolveExpr(ue.Right) + + return nil +} + +// VisitErrorExpr visits an error statement. +func (r *Resolver) VisitErrorExpr(ee *ast.ErrorExpr) any { + return nil +} diff --git a/resolver/resolver_test.go b/resolver/resolver_test.go new file mode 100644 index 0000000..562febd --- /dev/null +++ b/resolver/resolver_test.go @@ -0,0 +1,276 @@ +package resolver + +import ( + "testing" + + "golox/ast" + "golox/errors" + "golox/interpreter" + "golox/token" +) + +func TestResolver_VisitBlockStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + blockStmt := &ast.BlockStmt{ + Statements: []ast.Stmt{}, + } + + resolver.VisitBlockStmt(blockStmt) + + if len(resolver.scopes) != 0 { + t.Errorf("expected scopes to be empty, got %v", resolver.scopes) + } +} + +func TestResolver_VisitVarStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + varStmt := &ast.VarStmt{ + Name: token.Token{Lexeme: "a"}, + Initializer: nil, + } + + resolver.VisitVarStmt(varStmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitVariableExpr(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + varExpr := &ast.VariableExpr{ + Name: token.Token{Lexeme: "a"}, + } + + resolver.VisitVariableExpr(varExpr) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitAssignExpr(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + assignExpr := &ast.AssignExpr{ + Name: token.Token{Lexeme: "a"}, + Value: &ast.LiteralExpr{Value: 42}, + } + + resolver.VisitAssignExpr(assignExpr) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitFunctionStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + funcStmt := &ast.FunctionStmt{ + Name: token.Token{Lexeme: "foo"}, + Params: []token.Token{}, + Body: []ast.Stmt{}, + } + + resolver.VisitFunctionStmt(funcStmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitReturnStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + returnStmt := &ast.ReturnStmt{ + Keyword: token.Token{Lexeme: "return"}, + Value: &ast.LiteralExpr{Value: 42}, + } + + resolver.currFunc = FunctionFuncType + resolver.VisitReturnStmt(returnStmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitIfStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + ifStmt := &ast.IfStmt{ + Condition: &ast.LiteralExpr{Value: true}, + ThenBranch: &ast.BlockStmt{ + Statements: []ast.Stmt{}, + }, + ElseBranch: nil, + } + + resolver.VisitIfStmt(ifStmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitWhileStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + whileStmt := &ast.WhileStmt{ + Condition: &ast.LiteralExpr{Value: true}, + Body: &ast.BlockStmt{ + Statements: []ast.Stmt{}, + }, + } + + resolver.VisitWhileStmt(whileStmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_VisitExpressionStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + exprStmt := &ast.ExpressionStmt{ + Expression: &ast.LiteralExpr{Value: 42}, + } + + resolver.VisitExpressionStmt(exprStmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_Resolve(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + stmts := []ast.Stmt{ + &ast.BlockStmt{ + Statements: []ast.Stmt{ + &ast.VarStmt{ + Name: token.Token{Lexeme: "a"}, + Initializer: nil, + }, + &ast.VarStmt{ + Name: token.Token{Lexeme: "a"}, + Initializer: nil, + }, + }, + }, + } + + resolver.Resolve(stmts) + + if len(logger.Errors) != 1 { + t.Errorf("expected 1 error, got %v", logger.Errors) + } +} + +func TestResolver_ResolveStmt(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + stmt := &ast.VarStmt{ + Name: token.Token{Lexeme: "a"}, + Initializer: nil, + } + + resolver.resolveStmt(stmt) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_ResolveExpr(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + expr := &ast.LiteralExpr{Value: 42} + + resolver.resolveExpr(expr) + + if len(logger.Errors) != 0 { + t.Errorf("expected no errors, got %v", logger.Errors) + } +} + +func TestResolver_Declare(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + resolver.beginScope() + resolver.declare(&token.Token{Lexeme: "a"}) + + if len(resolver.scopes) != 1 { + t.Errorf("expected 1 scope, got %v", resolver.scopes) + } +} + +func TestResolver_Define(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + resolver.beginScope() + resolver.define(&token.Token{Lexeme: "a"}) + + if len(resolver.scopes) != 1 { + t.Errorf("expected 1 scope, got %v", resolver.scopes) + } +} + +func TestResolver_BeginScope(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + resolver.beginScope() + + if len(resolver.scopes) != 1 { + t.Errorf("expected 1 scope, got %v", resolver.scopes) + } +} + +func TestResolver_EndScope(t *testing.T) { + logger := errors.NewMockErrorLogger() + interp := &interpreter.Interpreter{} + resolver := New(interp, logger) + + resolver.beginScope() + resolver.endScope() + + if len(resolver.scopes) != 0 { + t.Errorf("expected 0 scopes, got %v", resolver.scopes) + } +} diff --git a/testdata/closure.lox b/testdata/closure.lox new file mode 100644 index 0000000..f826c86 --- /dev/null +++ b/testdata/closure.lox @@ -0,0 +1,10 @@ +var a = "global"; +{ + fun showA() { + print a; + } + + showA(); + var a = "block"; + showA(); +}