diff --git a/.task/checksum/astgen b/.task/checksum/astgen deleted file mode 100644 index 69cf4bb..0000000 --- a/.task/checksum/astgen +++ /dev/null @@ -1 +0,0 @@ -8ad203c83fd99fe71ecaa9da9795cc32 diff --git a/ast/printer.go b/ast/printer.go index 2d89766..6a17017 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -78,6 +78,16 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any { return ap.parenthesizeExpr("var", &LiteralExpr{stmt.Name}, stmt.Initializer) } +func (ap *Printer) VisitReturnStmt(stmt *ReturnStmt) any { + str := "(return" + + if stmt.Value != nil { + str += " " + stmt.Value.Accept(ap).(string) + } + + return str + ")" +} + func (ap *Printer) VisitWhileStmt(stmt *WhileStmt) any { str := "(while" diff --git a/ast/printer_test.go b/ast/printer_test.go index 770c9d0..a899f18 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -193,6 +193,16 @@ func TestPrintStmts(t *testing.T) { }, expected: "(while true {\n 1\n})\n", }, + { + name: "Return statement", + stmts: []Stmt{ + &ReturnStmt{ + Keyword: token.Token{Type: token.RETURN, Lexeme: "return"}, + Value: &LiteralExpr{Value: 42}, + }, + }, + expected: "(return 42)\n", + }, } for _, tt := range tests { diff --git a/ast/stmt.go b/ast/stmt.go index 2e84bb1..c55539c 100644 --- a/ast/stmt.go +++ b/ast/stmt.go @@ -9,6 +9,7 @@ type StmtVisitor[T any] interface { VisitFunctionStmt(fs *FunctionStmt) T VisitIfStmt(is *IfStmt) T VisitPrintStmt(ps *PrintStmt) T + VisitReturnStmt(rs *ReturnStmt) T VisitVarStmt(vs *VarStmt) T VisitWhileStmt(ws *WhileStmt) T } @@ -69,6 +70,15 @@ func (ps *PrintStmt) Accept(v StmtVisitor[any]) any { return v.VisitPrintStmt(ps) } +type ReturnStmt struct { + Keyword token.Token + Value Expr +} + +func (rs *ReturnStmt) Accept(v StmtVisitor[any]) any { + return v.VisitReturnStmt(rs) +} + type VarStmt struct { Name token.Token Initializer Expr diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index 56f4bd9..ce2046b 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -40,6 +40,7 @@ func main() { "Function : Name token.Token, Params []token.Token, Body []Stmt", "If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt", "Print : Expression Expr", + "Return : Keyword token.Token, Value Expr", "Var : Name token.Token, Initializer Expr", "While : Condition Expr, Body Stmt", }) diff --git a/interpreter/callable.go b/interpreter/callable.go index 52b4217..7a7a904 100644 --- a/interpreter/callable.go +++ b/interpreter/callable.go @@ -55,13 +55,23 @@ func (f *function) arity() int { } // call calls the function with the given arguments. -func (f *function) call(i *Interpreter, arguments []any) any { +func (f *function) call(i *Interpreter, arguments []any) (result any) { env := newEnvironment(i.globals) for i, param := range f.declaration.Params { env.define(param.Lexeme, arguments[i]) } + defer func() { + if r := recover(); r != nil { + if e, ok := r.(ReturnValue); ok { + result = e.Value + } else { + panic(r) + } + } + }() + i.executeBlock(f.declaration.Body, env) return nil diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index c4ab809..696dbfc 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -9,6 +9,11 @@ import ( "time" ) +// ReturnValue is a struct that holds a return value when a return statement is encountered. +type ReturnValue struct { + Value any +} + // Interpreter interprets the AST. type Interpreter struct { errLogger errors.Logger @@ -77,6 +82,16 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any { return nil } +// VisitReturnStmt visits a return statement. +func (i *Interpreter) VisitReturnStmt(rs *ast.ReturnStmt) any { + var value any + if rs.Value != nil { + value = i.evaluate(rs.Value) + } + + panic(ReturnValue{Value: value}) +} + // VisitWhileStmt visits a while statement. func (i *Interpreter) VisitWhileStmt(ws *ast.WhileStmt) any { for isTruthy(i.evaluate(ws.Condition)) { diff --git a/parser/parser.go b/parser/parser.go index 564b8fe..e486e94 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -86,7 +86,7 @@ func (p *Parser) varDeclaration() ast.Stmt { return &ast.VarStmt{Name: name, Initializer: initializer} } -// statement → exprStmt | forStmt | ifStmt | printStmt | whileStmt | block ; +// statement → exprStmt | forStmt | ifStmt | printStmt | returnStmt | whileStmt | block ; func (p *Parser) statement() ast.Stmt { if p.match(token.FOR) { return p.forStatement() @@ -97,6 +97,9 @@ func (p *Parser) statement() ast.Stmt { if p.match(token.PRINT) { return p.printStatement() } + if p.match(token.RETURN) { + return p.returnStatement() + } if p.match(token.WHILE) { return p.whileStatement() } @@ -193,6 +196,22 @@ func (p *Parser) printStatement() ast.Stmt { return &ast.PrintStmt{Expression: expr} } +// returnStmt → "return" expression? ";" ; +func (p *Parser) returnStatement() ast.Stmt { + keyword := p.previous() + var value ast.Expr + if !p.check(token.SEMICOLON) { + value = p.expression() + } + + err := p.consume(token.SEMICOLON, "Expect ';' after return value.") + if err != nil { + return p.fromErrorExpr(err) + } + + return &ast.ReturnStmt{Keyword: keyword, Value: value} +} + // whileStmt → "while" "(" expression ")" statement ; func (p *Parser) whileStatement() ast.Stmt { err := p.consume(token.LEFT_PAREN, "Expect '(' after 'while'.") diff --git a/parser/parser_test.go b/parser/parser_test.go index 0ff0557..04034cf 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -188,6 +188,50 @@ func TestParsePrintStmt(t *testing.T) { } } +func TestParseReturnStmt(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple return statement", + tokens: []token.Token{ + {Type: token.RETURN, Lexeme: "return"}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(return)\n", + }, + { + name: "return statement with value", + tokens: []token.Token{ + {Type: token.RETURN, Lexeme: "return"}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(return 42)\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} + func TestParseVarStmt(t *testing.T) { tests := []struct { name string diff --git a/testdata/fiborec.lox b/testdata/fiborec.lox new file mode 100644 index 0000000..d162bcc --- /dev/null +++ b/testdata/fiborec.lox @@ -0,0 +1,8 @@ +fun fib(n) { + if (n <= 1) return n; + return fib(n - 2) + fib(n - 1); +} + +for (var i = 0; i < 20; i = i + 1) { + print fib(i); +}