From 1b91ca792e7da3a16672fa7bdb0865ae254d9ab1 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Thu, 21 Nov 2024 10:41:23 +0100 Subject: [PATCH] Add assignment --- .task/checksum/astgen | 2 +- ast/expr.go | 10 +++++++ ast/printer.go | 4 +++ ast/printer_test.go | 8 ++++++ cmd/astgen/main.go | 1 + interpreter/environment.go | 14 +++++++++ interpreter/interpreter.go | 7 +++++ interpreter/interpreter_test.go | 50 +++++++++++++++++++++++++++++++++ parser/parser.go | 22 +++++++++++++-- parser/parser_test.go | 36 ++++++++++++++++++++++++ 10 files changed, 151 insertions(+), 3 deletions(-) diff --git a/.task/checksum/astgen b/.task/checksum/astgen index 10f1d8c..cbdfebb 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -ea7730697fb5d7edf1961630c5f68b96 +4ade8f4433feb634e5c96f97b05adfc1 diff --git a/ast/expr.go b/ast/expr.go index a4ddc0b..44bdbad 100644 --- a/ast/expr.go +++ b/ast/expr.go @@ -4,6 +4,7 @@ import "golox/token" type ExprVisitor[T any] interface { VisitErrorExpr(ee *ErrorExpr) T + VisitAssignExpr(ae *AssignExpr) T VisitBinaryExpr(be *BinaryExpr) T VisitGroupingExpr(ge *GroupingExpr) T VisitLiteralExpr(le *LiteralExpr) T @@ -23,6 +24,15 @@ func (ee *ErrorExpr) Accept(v ExprVisitor[any]) any { return v.VisitErrorExpr(ee) } +type AssignExpr struct { + Name token.Token + Value Expr +} + +func (ae *AssignExpr) Accept(v ExprVisitor[any]) any { + return v.VisitAssignExpr(ae) +} + type BinaryExpr struct { Left Expr Operator token.Token diff --git a/ast/printer.go b/ast/printer.go index 985a7b4..3ec69e4 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -68,6 +68,10 @@ func (ap *Printer) VisitErrorExpr(expr *ErrorExpr) any { return expr.Value } +func (ap *Printer) VisitAssignExpr(expr *AssignExpr) any { + return ap.parenthesize("=", &VariableExpr{expr.Name}, expr.Value) +} + func (ap *Printer) parenthesize(name string, exprs ...Expr) string { str := "(" + name diff --git a/ast/printer_test.go b/ast/printer_test.go index 4ea1241..a6e6b5e 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -50,6 +50,14 @@ func TestPrintExpr(t *testing.T) { expr: &ErrorExpr{Value: "error"}, expected: "error", }, + { + name: "Assign expression", + expr: &AssignExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Value: &LiteralExpr{Value: 42}, + }, + expected: "(= foo 42)", + }, } for _, tt := range tests { diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index 0ca5b56..0384fc8 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -23,6 +23,7 @@ func main() { defineAst(d, "Expr", []string{ "Error : Value string", + "Assign : Name token.Token, Value Expr", "Binary : Left Expr, Operator token.Token, Right Expr", "Grouping : Expression Expr", "Literal : Value any", diff --git a/interpreter/environment.go b/interpreter/environment.go index e5a9c21..e97137d 100644 --- a/interpreter/environment.go +++ b/interpreter/environment.go @@ -1,17 +1,21 @@ package interpreter +// environment represents the environment in which the interpreter operates. type environment struct { values map[string]any } +// newEnvironment creates a new environment. func newEnvironment() *environment { return &environment{values: make(map[string]any)} } +// define defines a new variable in the environment. func (e *environment) define(name string, value any) { e.values[name] = value } +// get gets the value of a variable in the environment. func (e *environment) get(name string) any { value, ok := e.values[name] if !ok { @@ -20,3 +24,13 @@ func (e *environment) get(name string) any { return value } + +// assign assigns a new value to a variable in the environment. +func (e *environment) assign(name string, value any) { + _, ok := e.values[name] + if !ok { + panic("Undefined variable '" + name + "'.") + } + + e.values[name] = value +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 0897339..4716a1d 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -70,6 +70,13 @@ func (i *Interpreter) VisitErrorExpr(e *ast.ErrorExpr) any { panic(e.Value) } +// VisitAssignExpr visits an AssignExpr. +func (i *Interpreter) VisitAssignExpr(a *ast.AssignExpr) any { + value := i.evaluate(a.Value) + i.env.assign(a.Name.Lexeme, value) + return value +} + // VisitLiteralExpr visits a LiteralExpr. func (i *Interpreter) VisitLiteralExpr(l *ast.LiteralExpr) any { return l.Value diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 6d51431..c7b73eb 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -277,6 +277,19 @@ func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { i.VisitBinaryExpr(binary) } +func TestInterpretErrorStatement(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + errorStmt := &ast.ErrorStmt{Value: "error"} + + defer func() { + if r := recover(); r != "error" { + t.Errorf("expected panic with 'error', got %v", r) + } + }() + + i.VisitErrorStmt(errorStmt) +} + func TestInterpretExprStatement(t *testing.T) { i := New(errors.NewMockErrorLogger()) literal := &ast.LiteralExpr{Value: 42.0} @@ -338,3 +351,40 @@ func TestInterpretVarStatement(t *testing.T) { t.Errorf("expected 42, got %v", result) } } + +func TestInterpretVarStatementNoInitializer(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + varStmt := &ast.VarStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + } + + i.VisitVarStmt(varStmt) + result := i.env.get("foo") + if result != nil { + t.Errorf("expected nil, got %v", result) + } +} + +func TestInterpretAssignment(t *testing.T) { + i := New(errors.NewMockErrorLogger()) + varStmt := &ast.VarStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + } + + i.VisitVarStmt(varStmt) + result := i.env.get("foo") + if result != nil { + t.Errorf("expected nil, got %v", result) + } + + assign := &ast.AssignExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Value: &ast.LiteralExpr{Value: 42.0}, + } + + i.VisitAssignExpr(assign) + result = i.env.get("foo") + if result != 42.0 { + t.Errorf("expected 42, got %v", result) + } +} diff --git a/parser/parser.go b/parser/parser.go index cbdb0e4..ec84297 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -114,9 +114,27 @@ func (p *Parser) printStatement() ast.Stmt { return &ast.PrintStmt{Expression: expr} } -// expression → equality ; +// expression → assignment ; func (p *Parser) expression() ast.Expr { - return p.equality() + return p.assignment() +} + +// assignment → IDENTIFIER "=" assignment | equality ; +func (p *Parser) assignment() ast.Expr { + expr := p.equality() + + if p.match(token.EQUAL) { + equals := p.previous() + value := p.assignment() + + if v, ok := expr.(*ast.VariableExpr); ok { + return &ast.AssignExpr{Name: v.Name, Value: value} + } + + p.newErrorExpr(equals, "Invalid assignment target.") + } + + return expr } // equality → comparison ( ( "!=" | "==" ) comparison )* ; diff --git a/parser/parser_test.go b/parser/parser_test.go index e7b5b71..51b105a 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -223,3 +223,39 @@ func TestParseVarStmt(t *testing.T) { }) } } + +func TestParseAssignment(t *testing.T) { + tests := []struct { + name string + tokens []token.Token + expected string + }{ + { + name: "simple assignment", + tokens: []token.Token{ + {Type: token.IDENTIFIER, Lexeme: "foo"}, + {Type: token.EQUAL, Lexeme: "="}, + {Type: token.NUMBER, Lexeme: "42", Literal: 42}, + {Type: token.SEMICOLON, Lexeme: ";"}, + {Type: token.EOF}, + }, + expected: "(= foo 42)\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := New(tt.tokens, errors.NewMockErrorLogger()) + stmts := parser.Parse() + if len(stmts) != 1 { + t.Fatalf("expected 1 statement, got %d", len(stmts)) + } + + ap := ast.NewPrinter() + s := ap.PrintStmts(stmts) + if s != tt.expected { + t.Errorf("expected %v, got %v", tt.expected, s) + } + }) + } +}