From 41bf68b68574fa409eed5cf9133314ec959ec410 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Mon, 25 Nov 2024 10:48:41 +0100 Subject: [PATCH] Add for loops --- parser/parser.go | 57 ++++++++++++++++++++++++++++++++++++++++++- parser/parser_test.go | 52 +++++++++++++++++++++++++++++++++++++++ testdata/fibo.lox | 8 ++++++ 3 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 testdata/fibo.lox diff --git a/parser/parser.go b/parser/parser.go index 20b358c..b5d77f4 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -83,8 +83,11 @@ func (p *Parser) varDeclaration() ast.Stmt { return &ast.VarStmt{Name: name, Initializer: initializer} } -// statement → exprStmt | ifStmt | printStmt | whileStmt | block ; +// statement → exprStmt | forStmt | ifStmt | printStmt | whileStmt | block ; func (p *Parser) statement() ast.Stmt { + if p.match(token.FOR) { + return p.forStatement() + } if p.match(token.IF) { return p.ifStatement() } @@ -101,6 +104,58 @@ func (p *Parser) statement() ast.Stmt { return p.expressionStatement() } +// forStmt → "for" "(" ( varDecl | exprStmt | ";" ) expression? ";" expression? ")" statement ; +func (p *Parser) forStatement() ast.Stmt { + err := p.consume(token.LEFT_PAREN, "Expect '(' after 'for'.") + if err != nil { + return p.fromErrorExpr(err) + } + + var initializer ast.Stmt + if p.match(token.SEMICOLON) { + initializer = nil + } else if p.match(token.VAR) { + initializer = p.varDeclaration() + } else { + initializer = p.expressionStatement() + } + + var condition ast.Expr + if !p.check(token.SEMICOLON) { + condition = p.expression() + } + err = p.consume(token.SEMICOLON, "Expect ';' after loop condition.") + if err != nil { + return p.fromErrorExpr(err) + } + + var increment ast.Expr + if !p.check(token.RIGHT_PAREN) { + increment = p.expression() + } + err = p.consume(token.RIGHT_PAREN, "Expect ')' after for clauses.") + if err != nil { + return p.fromErrorExpr(err) + } + + body := p.statement() + + if increment != nil { + body = &ast.BlockStmt{Statements: []ast.Stmt{body, &ast.ExpressionStmt{Expression: increment}}} + } + + if condition == nil { + condition = &ast.LiteralExpr{Value: true} + } + body = &ast.WhileStmt{Condition: condition, Body: body} + + if initializer != nil { + body = &ast.BlockStmt{Statements: []ast.Stmt{initializer, body}} + } + + return body +} + // ifStmt → "if" "(" expression ")" statement ( "else" statement )? ; func (p *Parser) ifStatement() ast.Stmt { err := p.consume(token.LEFT_PAREN, "Expect '(' after 'if'.") diff --git a/parser/parser_test.go b/parser/parser_test.go index 326adda..0ff0557 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -430,3 +430,55 @@ func TestParseWhileStatement(t *testing.T) { }) } } + +func TestParseForStatement(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple for statement", + tokens: []token.Token{ + {Type: token.FOR, Lexeme: "for"}, + {Type: token.LEFT_PAREN, Lexeme: "("}, + {Type: token.VAR, Lexeme: "var"}, + {Type: token.IDENTIFIER, Lexeme: "i"}, + {Type: token.EQUAL, Lexeme: "="}, + {Type: token.NUMBER, Lexeme: "0", Literal: 0.0}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.IDENTIFIER, Lexeme: "i"}, + {Type: token.LESS, Lexeme: "<"}, + {Type: token.NUMBER, Lexeme: "10", Literal: 10.0}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.IDENTIFIER, Lexeme: "i"}, + {Type: token.EQUAL, Lexeme: "="}, + {Type: token.IDENTIFIER, Lexeme: "i"}, + {Type: token.PLUS, Lexeme: "+"}, + {Type: token.NUMBER, Lexeme: "1", Literal: 1.0}, + {Type: token.RIGHT_PAREN, Lexeme: ")"}, + {Type: token.PRINT, Lexeme: "print"}, + {Type: token.IDENTIFIER, Lexeme: "i"}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "{\n (var i 0)\n (while (< i 10) {\n {\n (print i)\n (= i (+ i 1))\n}\n})\n}\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +} diff --git a/testdata/fibo.lox b/testdata/fibo.lox new file mode 100644 index 0000000..91d3786 --- /dev/null +++ b/testdata/fibo.lox @@ -0,0 +1,8 @@ +var a = 0; +var temp; + +for (var b = 1; a < 10000; b = temp + b) { + print a; + temp = a; + a = b; +}