Add return statement

main
oabrivard 1 year ago
parent 7da761a4de
commit 4c92f8e4a0

@ -1 +0,0 @@
8ad203c83fd99fe71ecaa9da9795cc32

@ -78,6 +78,16 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any {
return ap.parenthesizeExpr("var", &LiteralExpr{stmt.Name}, stmt.Initializer)
}
func (ap *Printer) VisitReturnStmt(stmt *ReturnStmt) any {
str := "(return"
if stmt.Value != nil {
str += " " + stmt.Value.Accept(ap).(string)
}
return str + ")"
}
func (ap *Printer) VisitWhileStmt(stmt *WhileStmt) any {
str := "(while"

@ -193,6 +193,16 @@ func TestPrintStmts(t *testing.T) {
},
expected: "(while true {\n 1\n})\n",
},
{
name: "Return statement",
stmts: []Stmt{
&ReturnStmt{
Keyword: token.Token{Type: token.RETURN, Lexeme: "return"},
Value: &LiteralExpr{Value: 42},
},
},
expected: "(return 42)\n",
},
}
for _, tt := range tests {

@ -9,6 +9,7 @@ type StmtVisitor[T any] interface {
VisitFunctionStmt(fs *FunctionStmt) T
VisitIfStmt(is *IfStmt) T
VisitPrintStmt(ps *PrintStmt) T
VisitReturnStmt(rs *ReturnStmt) T
VisitVarStmt(vs *VarStmt) T
VisitWhileStmt(ws *WhileStmt) T
}
@ -69,6 +70,15 @@ func (ps *PrintStmt) Accept(v StmtVisitor[any]) any {
return v.VisitPrintStmt(ps)
}
type ReturnStmt struct {
Keyword token.Token
Value Expr
}
func (rs *ReturnStmt) Accept(v StmtVisitor[any]) any {
return v.VisitReturnStmt(rs)
}
type VarStmt struct {
Name token.Token
Initializer Expr

@ -40,6 +40,7 @@ func main() {
"Function : Name token.Token, Params []token.Token, Body []Stmt",
"If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt",
"Print : Expression Expr",
"Return : Keyword token.Token, Value Expr",
"Var : Name token.Token, Initializer Expr",
"While : Condition Expr, Body Stmt",
})

@ -55,13 +55,23 @@ func (f *function) arity() int {
}
// call calls the function with the given arguments.
func (f *function) call(i *Interpreter, arguments []any) any {
func (f *function) call(i *Interpreter, arguments []any) (result any) {
env := newEnvironment(i.globals)
for i, param := range f.declaration.Params {
env.define(param.Lexeme, arguments[i])
}
defer func() {
if r := recover(); r != nil {
if e, ok := r.(ReturnValue); ok {
result = e.Value
} else {
panic(r)
}
}
}()
i.executeBlock(f.declaration.Body, env)
return nil

@ -9,6 +9,11 @@ import (
"time"
)
// ReturnValue is a struct that holds a return value when a return statement is encountered.
type ReturnValue struct {
Value any
}
// Interpreter interprets the AST.
type Interpreter struct {
errLogger errors.Logger
@ -77,6 +82,16 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any {
return nil
}
// VisitReturnStmt visits a return statement.
func (i *Interpreter) VisitReturnStmt(rs *ast.ReturnStmt) any {
var value any
if rs.Value != nil {
value = i.evaluate(rs.Value)
}
panic(ReturnValue{Value: value})
}
// VisitWhileStmt visits a while statement.
func (i *Interpreter) VisitWhileStmt(ws *ast.WhileStmt) any {
for isTruthy(i.evaluate(ws.Condition)) {

@ -86,7 +86,7 @@ func (p *Parser) varDeclaration() ast.Stmt {
return &ast.VarStmt{Name: name, Initializer: initializer}
}
// statement → exprStmt | forStmt | ifStmt | printStmt | whileStmt | block ;
// statement → exprStmt | forStmt | ifStmt | printStmt | returnStmt | whileStmt | block ;
func (p *Parser) statement() ast.Stmt {
if p.match(token.FOR) {
return p.forStatement()
@ -97,6 +97,9 @@ func (p *Parser) statement() ast.Stmt {
if p.match(token.PRINT) {
return p.printStatement()
}
if p.match(token.RETURN) {
return p.returnStatement()
}
if p.match(token.WHILE) {
return p.whileStatement()
}
@ -193,6 +196,22 @@ func (p *Parser) printStatement() ast.Stmt {
return &ast.PrintStmt{Expression: expr}
}
// returnStmt → "return" expression? ";" ;
func (p *Parser) returnStatement() ast.Stmt {
keyword := p.previous()
var value ast.Expr
if !p.check(token.SEMICOLON) {
value = p.expression()
}
err := p.consume(token.SEMICOLON, "Expect ';' after return value.")
if err != nil {
return p.fromErrorExpr(err)
}
return &ast.ReturnStmt{Keyword: keyword, Value: value}
}
// whileStmt → "while" "(" expression ")" statement ;
func (p *Parser) whileStatement() ast.Stmt {
err := p.consume(token.LEFT_PAREN, "Expect '(' after 'while'.")

@ -188,6 +188,50 @@ func TestParsePrintStmt(t *testing.T) {
}
}
func TestParseReturnStmt(t *testing.T) {
tests := []struct {
name string
tokens []token.Token
expected string
}{
{
name: "simple return statement",
tokens: []token.Token{
{Type: token.RETURN, Lexeme: "return"},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(return)\n",
},
{
name: "return statement with value",
tokens: []token.Token{
{Type: token.RETURN, Lexeme: "return"},
{Type: token.NUMBER, Lexeme: "42", Literal: 42},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(return 42)\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := New(tt.tokens, errors.NewMockErrorLogger())
stmts := parser.Parse()
if len(stmts) != 1 {
t.Fatalf("expected 1 statement, got %d", len(stmts))
}
ap := ast.NewPrinter()
s := ap.PrintStmts(stmts)
if s != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, s)
}
})
}
}
func TestParseVarStmt(t *testing.T) {
tests := []struct {
name string

@ -0,0 +1,8 @@
fun fib(n) {
if (n <= 1) return n;
return fib(n - 2) + fib(n - 1);
}
for (var i = 0; i < 20; i = i + 1) {
print fib(i);
}
Loading…
Cancel
Save