diff --git a/interpreter/callable.go b/interpreter/callable.go index 7a7a904..4422715 100644 --- a/interpreter/callable.go +++ b/interpreter/callable.go @@ -40,12 +40,14 @@ func (n *nativeCallable) String() string { // function is a struct that implements the callable interface. type function struct { declaration *ast.FunctionStmt + closure *environment } // newFunction creates a new function. -func newFunction(declaration *ast.FunctionStmt) *function { +func newFunction(declaration *ast.FunctionStmt, closure *environment) *function { return &function{ declaration: declaration, + closure: closure, } } @@ -56,7 +58,7 @@ func (f *function) arity() int { // call calls the function with the given arguments. func (f *function) call(i *Interpreter, arguments []any) (result any) { - env := newEnvironment(i.globals) + env := newEnvironment(f.closure) for i, param := range f.declaration.Params { env.define(param.Lexeme, arguments[i]) diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 696dbfc..2cdf2ce 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -59,7 +59,7 @@ func (i *Interpreter) VisitExpressionStmt(es *ast.ExpressionStmt) any { // VisitFunctionStmt visits a function statement. func (i *Interpreter) VisitFunctionStmt(fs *ast.FunctionStmt) any { - function := newFunction(fs) + function := newFunction(fs, i.env) i.env.define(fs.Name.Lexeme, function) return nil } diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 4f30a02..e334d1d 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -619,3 +619,80 @@ func TestInterpretWhileStatement(t *testing.T) { t.Errorf("run() = %v; want %v", out, expected) } } + +func TestInterpretFunctionStatement(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + functionStmt := &ast.FunctionStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Params: []token.Token{}, + Body: []ast.Stmt{&ast.BlockStmt{Statements: []ast.Stmt{}}}, + } + + i.VisitFunctionStmt(functionStmt) + result := i.env.get("foo") + if result == nil { + t.Errorf("expected function, got nil") + } +} + +func TestInterpretFunctionCall(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + functionStmt := &ast.FunctionStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Params: []token.Token{}, + Body: []ast.Stmt{&ast.BlockStmt{Statements: []ast.Stmt{&ast.PrintStmt{Expression: &ast.LiteralExpr{Value: 42.0}}}}}, + } + + i.VisitFunctionStmt(functionStmt) + + callExpr := &ast.CallExpr{ + Callee: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}}, + Arguments: []ast.Expr{}, + } + + i.VisitCallExpr(callExpr) +} + +func TestInterpretFunctionCallWithArguments(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + functionStmt := &ast.FunctionStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Params: []token.Token{{Type: token.IDENTIFIER, Lexeme: "a"}}, + Body: []ast.Stmt{&ast.BlockStmt{Statements: []ast.Stmt{&ast.PrintStmt{Expression: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "a"}}}}}}, + } + + i.VisitFunctionStmt(functionStmt) + + callExpr := &ast.CallExpr{ + Callee: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}}, + Arguments: []ast.Expr{ + &ast.LiteralExpr{Value: 42.0}, + }, + } + + i.VisitCallExpr(callExpr) +} + +func TestInterpretFunctionCallWithWrongNumberOfArguments(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + functionStmt := &ast.FunctionStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Params: []token.Token{{Type: token.IDENTIFIER, Lexeme: "a"}}, + Body: []ast.Stmt{&ast.BlockStmt{Statements: []ast.Stmt{}}}, + } + + i.VisitFunctionStmt(functionStmt) + + callExpr := &ast.CallExpr{ + Callee: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}}, + Arguments: []ast.Expr{}, + } + + defer func() { + if r := recover(); r != "Expected 1 arguments but got 0 [line 0]" { + t.Errorf("expected panic with 'expected 1 arguments but got 0', got %v", r) + } + }() + + i.VisitCallExpr(callExpr) +} diff --git a/testdata/counter.lox b/testdata/counter.lox new file mode 100644 index 0000000..4e89bf1 --- /dev/null +++ b/testdata/counter.lox @@ -0,0 +1,13 @@ +fun makeCounter() { + var i = 0; + fun count() { + i = i + 1; + print i; + } + + return count; +} + +var counter = makeCounter(); +counter(); // "1". +counter(); // "2".