Add while loops

main
oabrivard 1 year ago
parent 16958e9f25
commit e8b34eed57

@ -1 +1 @@
98e69ec471348f6ec8e672a1b8c9825d ef159c739c079a54bc5f6af10e2ea025

@ -8,6 +8,7 @@ type ExprVisitor[T any] interface {
VisitBinaryExpr(be *BinaryExpr) T VisitBinaryExpr(be *BinaryExpr) T
VisitGroupingExpr(ge *GroupingExpr) T VisitGroupingExpr(ge *GroupingExpr) T
VisitLiteralExpr(le *LiteralExpr) T VisitLiteralExpr(le *LiteralExpr) T
VisitLogicalExpr(le *LogicalExpr) T
VisitUnaryExpr(ue *UnaryExpr) T VisitUnaryExpr(ue *UnaryExpr) T
VisitVariableExpr(ve *VariableExpr) T VisitVariableExpr(ve *VariableExpr) T
} }
@ -59,6 +60,16 @@ func (le *LiteralExpr) Accept(v ExprVisitor[any]) any {
return v.VisitLiteralExpr(le) return v.VisitLiteralExpr(le)
} }
type LogicalExpr struct {
Left Expr
Operator token.Token
Right Expr
}
func (le *LogicalExpr) Accept(v ExprVisitor[any]) any {
return v.VisitLogicalExpr(le)
}
type UnaryExpr struct { type UnaryExpr struct {
Operator token.Token Operator token.Token
Right Expr Right Expr

@ -59,6 +59,20 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any {
return ap.parenthesizeExpr("var", &LiteralExpr{stmt.Name}, stmt.Initializer) return ap.parenthesizeExpr("var", &LiteralExpr{stmt.Name}, stmt.Initializer)
} }
func (ap *Printer) VisitWhileStmt(stmt *WhileStmt) any {
str := "(while"
if stmt.Condition != nil {
str += " " + stmt.Condition.Accept(ap).(string)
}
if stmt.Body != nil {
str += " " + ap.bracesizeStmt(stmt.Body)
}
return str + ")"
}
func (ap *Printer) VisitBlockStmt(stmt *BlockStmt) any { func (ap *Printer) VisitBlockStmt(stmt *BlockStmt) any {
return ap.bracesizeStmt(stmt.Statements...) return ap.bracesizeStmt(stmt.Statements...)
} }
@ -82,6 +96,10 @@ func (ap *Printer) VisitLiteralExpr(expr *LiteralExpr) any {
return fmt.Sprint(expr.Value) return fmt.Sprint(expr.Value)
} }
func (ap *Printer) VisitLogicalExpr(expr *LogicalExpr) any {
return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Left, expr.Right)
}
func (ap *Printer) VisitUnaryExpr(expr *UnaryExpr) any { func (ap *Printer) VisitUnaryExpr(expr *UnaryExpr) any {
return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Right) return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Right)
} }

@ -27,6 +27,11 @@ func TestPrintExpr(t *testing.T) {
}, },
expected: "(group 1)", expected: "(group 1)",
}, },
{
name: "Logical expression",
expr: &LogicalExpr{Left: &LiteralExpr{Value: 1}, Operator: token.Token{Type: token.AND, Lexeme: "and"}, Right: &LiteralExpr{Value: 2}},
expected: "(and 1 2)",
},
{ {
name: "Literal expression", name: "Literal expression",
expr: &LiteralExpr{Value: 123}, expr: &LiteralExpr{Value: 123},
@ -95,6 +100,25 @@ func TestPrintStmts(t *testing.T) {
}, },
expected: "42\n", expected: "42\n",
}, },
{
name: "Error statement",
stmts: []Stmt{
&ErrorStmt{
Value: "error",
},
},
expected: "error\n",
},
{
name: "Var statement",
stmts: []Stmt{
&VarStmt{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"},
Initializer: &LiteralExpr{Value: 42},
},
},
expected: "(var foo 42)\n",
},
} }
for _, tt := range tests { for _, tt := range tests {
@ -146,3 +170,19 @@ func TestPrintIfStmt(t *testing.T) {
t.Errorf("expected %v, got %v", expected, result) t.Errorf("expected %v, got %v", expected, result)
} }
} }
func TestPrintWhileStmt(t *testing.T) {
stmt := &WhileStmt{
Condition: &LiteralExpr{Value: true},
Body: &ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
}
printer := NewPrinter()
result := printer.PrintStmts([]Stmt{stmt})
expected := "(while true {\n 1\n})\n"
if result != expected {
t.Errorf("expected %v, got %v", expected, result)
}
}

@ -9,6 +9,7 @@ type StmtVisitor[T any] interface {
VisitIfStmt(is *IfStmt) T VisitIfStmt(is *IfStmt) T
VisitPrintStmt(ps *PrintStmt) T VisitPrintStmt(ps *PrintStmt) T
VisitVarStmt(vs *VarStmt) T VisitVarStmt(vs *VarStmt) T
VisitWhileStmt(ws *WhileStmt) T
} }
type Stmt interface { type Stmt interface {
@ -66,3 +67,12 @@ func (vs *VarStmt) Accept(v StmtVisitor[any]) any {
return v.VisitVarStmt(vs) return v.VisitVarStmt(vs)
} }
type WhileStmt struct {
Condition Expr
Body Stmt
}
func (ws *WhileStmt) Accept(v StmtVisitor[any]) any {
return v.VisitWhileStmt(ws)
}

@ -27,6 +27,7 @@ func main() {
"Binary : Left Expr, Operator token.Token, Right Expr", "Binary : Left Expr, Operator token.Token, Right Expr",
"Grouping : Expression Expr", "Grouping : Expression Expr",
"Literal : Value any", "Literal : Value any",
"Logical : Left Expr, Operator token.Token, Right Expr",
"Unary : Operator token.Token, Right Expr", "Unary : Operator token.Token, Right Expr",
"Variable : Name token.Token", "Variable : Name token.Token",
}) })
@ -38,6 +39,7 @@ func main() {
"If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt", "If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt",
"Print : Expression Expr", "Print : Expression Expr",
"Var : Name token.Token, Initializer Expr", "Var : Name token.Token, Initializer Expr",
"While : Condition Expr, Body Stmt",
}) })
} }

@ -57,6 +57,15 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any {
return nil return nil
} }
// VisitWhileStmt visits a while statement.
func (i *Interpreter) VisitWhileStmt(ws *ast.WhileStmt) any {
for isTruthy(i.evaluate(ws.Condition)) {
i.execute(ws.Body)
}
return nil
}
// VisitVarStmt visits a var statement. // VisitVarStmt visits a var statement.
func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any { func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any {
var value any var value any
@ -99,6 +108,23 @@ func (i *Interpreter) VisitLiteralExpr(l *ast.LiteralExpr) any {
return l.Value return l.Value
} }
// VisitLogicalExpr visits a LogicalExpr.
func (i *Interpreter) VisitLogicalExpr(l *ast.LogicalExpr) any {
left := i.evaluate(l.Left)
if l.Operator.Type == token.OR {
if isTruthy(left) {
return left
}
} else {
if !isTruthy(left) {
return left
}
}
return i.evaluate(l.Right)
}
// VisitGroupingExpr visits a GroupingExpr. // VisitGroupingExpr visits a GroupingExpr.
func (i *Interpreter) VisitGroupingExpr(g *ast.GroupingExpr) any { func (i *Interpreter) VisitGroupingExpr(g *ast.GroupingExpr) any {
return i.evaluate(g.Expression) return i.evaluate(g.Expression)

@ -277,6 +277,22 @@ func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) {
i.VisitBinaryExpr(binary) i.VisitBinaryExpr(binary)
} }
func TestInterpretLogicalExpr(t *testing.T) {
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: true}
right := &ast.LiteralExpr{Value: false}
logical := &ast.LogicalExpr{
Left: left,
Operator: token.Token{Type: token.AND, Lexeme: "and"},
Right: right,
}
result := i.VisitLogicalExpr(logical)
if result != false {
t.Errorf("expected false, got %v", result)
}
}
func TestInterpretErrorStatement(t *testing.T) { func TestInterpretErrorStatement(t *testing.T) {
i := New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
errorStmt := &ast.ErrorStmt{Value: "error"} errorStmt := &ast.ErrorStmt{Value: "error"}
@ -538,3 +554,68 @@ func TestInterpretIfStatementElseBranch(t *testing.T) {
t.Errorf("run() = %v; want %v", out, expected) t.Errorf("run() = %v; want %v", out, expected)
} }
} }
func TestInterpretWhileStatement(t *testing.T) {
old := os.Stdout // keep backup of the real stdout
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
os.Stdout = w
outC := make(chan string)
// copy the output in a separate goroutine so printing can't block indefinitely
go func() {
var buf bytes.Buffer
io.Copy(&buf, r)
outC <- buf.String()
}()
// begin unit test
i := New(errors.NewMockErrorLogger())
varStmt := &ast.VarStmt{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"},
Initializer: &ast.LiteralExpr{Value: 3.0},
}
i.VisitVarStmt(varStmt)
whileStmt := &ast.WhileStmt{
Condition: &ast.BinaryExpr{
Left: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}},
Operator: token.Token{Type: token.GREATER, Lexeme: ">"},
Right: &ast.LiteralExpr{Value: 0.0},
},
Body: &ast.BlockStmt{
Statements: []ast.Stmt{
&ast.PrintStmt{
Expression: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}},
},
&ast.ExpressionStmt{
Expression: &ast.AssignExpr{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"},
Value: &ast.BinaryExpr{
Left: &ast.VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "i"}},
Operator: token.Token{Type: token.MINUS, Lexeme: "-"},
Right: &ast.LiteralExpr{Value: 1.0},
},
},
},
},
},
}
i.VisitWhileStmt(whileStmt)
// end unit test
// back to normal state
w.Close()
os.Stdout = old // restoring the real stdout
out := <-outC
// reading our temp stdout
expected := "3\n2\n1\n"
if out != expected {
t.Errorf("run() = %v; want %v", out, expected)
}
}

@ -83,7 +83,7 @@ func (p *Parser) varDeclaration() ast.Stmt {
return &ast.VarStmt{Name: name, Initializer: initializer} return &ast.VarStmt{Name: name, Initializer: initializer}
} }
// statement → exprStmt | ifStmt | printStmt | block ; // statement → exprStmt | ifStmt | printStmt | whileStmt | block ;
func (p *Parser) statement() ast.Stmt { func (p *Parser) statement() ast.Stmt {
if p.match(token.IF) { if p.match(token.IF) {
return p.ifStatement() return p.ifStatement()
@ -91,6 +91,9 @@ func (p *Parser) statement() ast.Stmt {
if p.match(token.PRINT) { if p.match(token.PRINT) {
return p.printStatement() return p.printStatement()
} }
if p.match(token.WHILE) {
return p.whileStatement()
}
if p.match(token.LEFT_BRACE) { if p.match(token.LEFT_BRACE) {
return p.blockStatement() return p.blockStatement()
} }
@ -132,6 +135,25 @@ func (p *Parser) printStatement() ast.Stmt {
return &ast.PrintStmt{Expression: expr} return &ast.PrintStmt{Expression: expr}
} }
// whileStmt → "while" "(" expression ")" statement ;
func (p *Parser) whileStatement() ast.Stmt {
err := p.consume(token.LEFT_PAREN, "Expect '(' after 'while'.")
if err != nil {
return p.fromErrorExpr(err)
}
condition := p.expression()
err = p.consume(token.RIGHT_PAREN, "Expect ')' after while condition.")
if err != nil {
return p.fromErrorExpr(err)
}
body := p.statement()
return &ast.WhileStmt{Condition: condition, Body: body}
}
// block → "{" declaration* "}" ; // block → "{" declaration* "}" ;
func (p *Parser) blockStatement() ast.Stmt { func (p *Parser) blockStatement() ast.Stmt {
statements := []ast.Stmt{} statements := []ast.Stmt{}
@ -164,9 +186,9 @@ func (p *Parser) expression() ast.Expr {
return p.assignment() return p.assignment()
} }
// assignment → IDENTIFIER "=" assignment | equality ; // assignment → IDENTIFIER "=" assignment | logic_or ;
func (p *Parser) assignment() ast.Expr { func (p *Parser) assignment() ast.Expr {
expr := p.equality() expr := p.or()
if p.match(token.EQUAL) { if p.match(token.EQUAL) {
equals := p.previous() equals := p.previous()
@ -182,6 +204,32 @@ func (p *Parser) assignment() ast.Expr {
return expr return expr
} }
// logic_or → logic_and ( "or" logic_and )* ;
func (p *Parser) or() ast.Expr {
expr := p.and()
for p.match(token.OR) {
operator := p.previous()
right := p.and()
expr = &ast.LogicalExpr{Left: expr, Operator: operator, Right: right}
}
return expr
}
// logic_and → equality ( "and" equality )* ;
func (p *Parser) and() ast.Expr {
expr := p.equality()
for p.match(token.AND) {
operator := p.previous()
right := p.equality()
expr = &ast.LogicalExpr{Left: expr, Operator: operator, Right: right}
}
return expr
}
// equality → comparison ( ( "!=" | "==" ) comparison )* ; // equality → comparison ( ( "!=" | "==" ) comparison )* ;
func (p *Parser) equality() ast.Expr { func (p *Parser) equality() ast.Expr {
expr := p.comparison() expr := p.comparison()

@ -58,6 +58,17 @@ func TestExpressionParsing(t *testing.T) {
}, },
expected: "(> 1 2)", expected: "(> 1 2)",
}, },
{
name: "Logical expression",
tokens: []token.Token{
{Type: token.NUMBER, Literal: 1},
{Type: token.GREATER, Lexeme: "and"},
{Type: token.NUMBER, Literal: 2},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(and 1 2)",
},
{ {
name: "Equality expression", name: "Equality expression",
tokens: []token.Token{ tokens: []token.Token{
@ -380,3 +391,42 @@ func TestParseIfElseStatement(t *testing.T) {
}) })
} }
} }
func TestParseWhileStatement(t *testing.T) {
tests := []struct {
name string
tokens []token.Token
expected string
}{
{
name: "simple while statement",
tokens: []token.Token{
{Type: token.WHILE, Lexeme: "while"},
{Type: token.LEFT_PAREN, Lexeme: "("},
{Type: token.NUMBER, Lexeme: "42", Literal: 42},
{Type: token.RIGHT_PAREN, Lexeme: ")"},
{Type: token.PRINT, Lexeme: "print"},
{Type: token.NUMBER, Lexeme: "42", Literal: 42},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(while 42 {\n (print 42)\n})\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := New(tt.tokens, errors.NewMockErrorLogger())
stmts := parser.Parse()
if len(stmts) != 1 {
t.Fatalf("expected 1 statement, got %d", len(stmts))
}
ap := ast.NewPrinter()
s := ap.PrintStmts(stmts)
if s != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, s)
}
})
}
}

Loading…
Cancel
Save