Add global variables and environments

main
oabrivard 1 year ago
parent b97f44db2d
commit 5eacc29825

@ -1 +1 @@
bf6f2cc6973975f2bcc789463842856a ea7730697fb5d7edf1961630c5f68b96

@ -8,6 +8,7 @@ type ExprVisitor[T any] interface {
VisitGroupingExpr(ge *GroupingExpr) T VisitGroupingExpr(ge *GroupingExpr) T
VisitLiteralExpr(le *LiteralExpr) T VisitLiteralExpr(le *LiteralExpr) T
VisitUnaryExpr(ue *UnaryExpr) T VisitUnaryExpr(ue *UnaryExpr) T
VisitVariableExpr(ve *VariableExpr) T
} }
type Expr interface { type Expr interface {
@ -57,3 +58,11 @@ func (ue *UnaryExpr) Accept(v ExprVisitor[any]) any {
return v.VisitUnaryExpr(ue) return v.VisitUnaryExpr(ue)
} }
type VariableExpr struct {
Name token.Token
}
func (ve *VariableExpr) Accept(v ExprVisitor[any]) any {
return v.VisitVariableExpr(ve)
}

@ -41,6 +41,10 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any {
return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer) return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer)
} }
func (ap *Printer) VisitVariableExpr(expr *VariableExpr) any {
return expr.Name.Lexeme
}
func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any { func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any {
return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right)
} }
@ -68,6 +72,9 @@ func (ap *Printer) parenthesize(name string, exprs ...Expr) string {
str := "(" + name str := "(" + name
for _, expr := range exprs { for _, expr := range exprs {
if expr == nil {
continue
}
str += " " + expr.Accept(ap).(string) str += " " + expr.Accept(ap).(string)
} }

@ -27,6 +27,7 @@ func main() {
"Grouping : Expression Expr", "Grouping : Expression Expr",
"Literal : Value any", "Literal : Value any",
"Unary : Operator token.Token, Right Expr", "Unary : Operator token.Token, Right Expr",
"Variable : Name token.Token",
}) })
defineAst(d, "Stmt", []string{ defineAst(d, "Stmt", []string{

@ -0,0 +1,22 @@
package interpreter
type environment struct {
values map[string]any
}
func newEnvironment() *environment {
return &environment{values: make(map[string]any)}
}
func (e *environment) define(name string, value any) {
e.values[name] = value
}
func (e *environment) get(name string) any {
value, ok := e.values[name]
if !ok {
panic("Undefined variable '" + name + "'.")
}
return value
}

@ -11,11 +11,12 @@ import (
// Interpreter interprets the AST. // Interpreter interprets the AST.
type Interpreter struct { type Interpreter struct {
errLogger errors.Logger errLogger errors.Logger
env *environment
} }
// New creates a new Interpreter. // New creates a new Interpreter.
func New(el errors.Logger) *Interpreter { func New(el errors.Logger) *Interpreter {
return &Interpreter{el} return &Interpreter{el, newEnvironment()}
} }
// Interpret interprets the AST. // Interpret interprets the AST.
@ -47,6 +48,12 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any {
// VisitVarStmt visits a var statement. // VisitVarStmt visits a var statement.
func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any { func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any {
var value any
if vs.Initializer != nil {
value = i.evaluate(vs.Initializer)
}
i.env.define(vs.Name.Lexeme, value)
return nil return nil
} }
@ -142,6 +149,11 @@ func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any {
panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line)) panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line))
} }
// VisitVariableExpr visits a VariableExpr.
func (i *Interpreter) VisitVariableExpr(v *ast.VariableExpr) any {
return i.env.get(v.Name.Lexeme)
}
// checkNumberOperands checks if the operands are numbers. // checkNumberOperands checks if the operands are numbers.
func checkNumberOperands(operator token.Token, operands ...any) { func checkNumberOperands(operator token.Token, operands ...any) {
for _, operand := range operands { for _, operand := range operands {

@ -1,16 +1,18 @@
// FILE: interpreter_test.go // FILE: interpreter_test.go
package interpreter_test package interpreter
import ( import (
"bytes"
"golox/ast" "golox/ast"
"golox/errors" "golox/errors"
"golox/interpreter"
"golox/token" "golox/token"
"io"
"os"
"testing" "testing"
) )
func TestInterpretLiteralExpr(t *testing.T) { func TestInterpretLiteralExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42} literal := &ast.LiteralExpr{Value: 42}
result := i.VisitLiteralExpr(literal) result := i.VisitLiteralExpr(literal)
@ -20,7 +22,7 @@ func TestInterpretLiteralExpr(t *testing.T) {
} }
func TestInterpretGroupingExpr(t *testing.T) { func TestInterpretGroupingExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42} literal := &ast.LiteralExpr{Value: 42}
grouping := &ast.GroupingExpr{Expression: literal} grouping := &ast.GroupingExpr{Expression: literal}
@ -31,7 +33,7 @@ func TestInterpretGroupingExpr(t *testing.T) {
} }
func TestInterpretUnaryExpr(t *testing.T) { func TestInterpretUnaryExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0} literal := &ast.LiteralExpr{Value: 42.0}
unary := &ast.UnaryExpr{ unary := &ast.UnaryExpr{
Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, Operator: token.Token{Type: token.MINUS, Lexeme: "-"},
@ -45,7 +47,7 @@ func TestInterpretUnaryExpr(t *testing.T) {
} }
func TestInterpretUnaryExprBang(t *testing.T) { func TestInterpretUnaryExprBang(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: true} literal := &ast.LiteralExpr{Value: true}
unary := &ast.UnaryExpr{ unary := &ast.UnaryExpr{
Operator: token.Token{Type: token.BANG, Lexeme: "!"}, Operator: token.Token{Type: token.BANG, Lexeme: "!"},
@ -59,7 +61,7 @@ func TestInterpretUnaryExprBang(t *testing.T) {
} }
func TestInterpretErrorExpr(t *testing.T) { func TestInterpretErrorExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
errorExpr := &ast.ErrorExpr{Value: "error"} errorExpr := &ast.ErrorExpr{Value: "error"}
defer func() { defer func() {
@ -72,7 +74,7 @@ func TestInterpretErrorExpr(t *testing.T) {
} }
func TestInterpretExpr(t *testing.T) { func TestInterpretExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0} literal := &ast.LiteralExpr{Value: 42.0}
defer func() { defer func() {
@ -88,7 +90,7 @@ func TestInterpretExpr(t *testing.T) {
} }
func TestInterpretBinaryExpr(t *testing.T) { func TestInterpretBinaryExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0} right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -104,7 +106,7 @@ func TestInterpretBinaryExpr(t *testing.T) {
} }
func TestInterpretBinaryExprDivisionByZero(t *testing.T) { func TestInterpretBinaryExprDivisionByZero(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 0.0} right := &ast.LiteralExpr{Value: 0.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -123,7 +125,7 @@ func TestInterpretBinaryExprDivisionByZero(t *testing.T) {
} }
func TestInterpretBinaryExprAddition(t *testing.T) { func TestInterpretBinaryExprAddition(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0} right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -139,7 +141,7 @@ func TestInterpretBinaryExprAddition(t *testing.T) {
} }
func TestInterpretBinaryExprSubtraction(t *testing.T) { func TestInterpretBinaryExprSubtraction(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0} right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -155,7 +157,7 @@ func TestInterpretBinaryExprSubtraction(t *testing.T) {
} }
func TestInterpretBinaryExprStringConcatenation(t *testing.T) { func TestInterpretBinaryExprStringConcatenation(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: "foo"} left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: "bar"} right := &ast.LiteralExpr{Value: "bar"}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -171,7 +173,7 @@ func TestInterpretBinaryExprStringConcatenation(t *testing.T) {
} }
func TestInterpretBinaryExprInvalidOperands(t *testing.T) { func TestInterpretBinaryExprInvalidOperands(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: "foo"} left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -190,7 +192,7 @@ func TestInterpretBinaryExprInvalidOperands(t *testing.T) {
} }
func TestInterpretBinaryExprComparison(t *testing.T) { func TestInterpretBinaryExprComparison(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0} right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -206,7 +208,7 @@ func TestInterpretBinaryExprComparison(t *testing.T) {
} }
func TestInterpretBinaryExprComparisonEqual(t *testing.T) { func TestInterpretBinaryExprComparisonEqual(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -222,7 +224,7 @@ func TestInterpretBinaryExprComparisonEqual(t *testing.T) {
} }
func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) { func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0} right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -238,7 +240,7 @@ func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) {
} }
func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) { func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: "foo"} left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: 42.0} right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -257,7 +259,7 @@ func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) {
} }
func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger()) i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0} left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0} right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{ binary := &ast.BinaryExpr{
@ -274,3 +276,65 @@ func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) {
i.VisitBinaryExpr(binary) i.VisitBinaryExpr(binary)
} }
func TestInterpretExprStatement(t *testing.T) {
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0}
exprStmt := &ast.ExpressionStmt{Expression: literal}
result := i.VisitExpressionStmt(exprStmt)
if result != nil {
t.Errorf("expected nil, got %v", result)
}
}
func TestInterpretPrintStatement(t *testing.T) {
old := os.Stdout // keep backup of the real stdout
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
os.Stdout = w
outC := make(chan string)
// copy the output in a separate goroutine so printing can't block indefinitely
go func() {
var buf bytes.Buffer
io.Copy(&buf, r)
outC <- buf.String()
}()
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0}
printStmt := &ast.PrintStmt{Expression: literal}
result := i.VisitPrintStmt(printStmt)
if result != nil {
t.Errorf("expected nil, got %v", result)
}
// back to normal state
w.Close()
os.Stdout = old // restoring the real stdout
out := <-outC
// reading our temp stdout
expected := "42\n"
if out != expected {
t.Errorf("run() = %v; want %v", out, expected)
}
}
func TestInterpretVarStatement(t *testing.T) {
i := New(errors.NewMockErrorLogger())
varStmt := &ast.VarStmt{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"},
Initializer: &ast.LiteralExpr{Value: 42.0},
}
i.VisitVarStmt(varStmt)
result := i.env.get("foo")
if result != 42.0 {
t.Errorf("expected 42, got %v", result)
}
}

@ -42,7 +42,7 @@ func (p *Parser) Parse() []ast.Stmt {
stmts := []ast.Stmt{} stmts := []ast.Stmt{}
for !p.isAtEnd() { for !p.isAtEnd() {
stmt := p.statement() stmt := p.declaration()
if _, ok := stmt.(*ast.ErrorStmt); ok { if _, ok := stmt.(*ast.ErrorStmt); ok {
p.synchronize() p.synchronize()
} else { } else {
@ -53,6 +53,36 @@ func (p *Parser) Parse() []ast.Stmt {
return stmts return stmts
} }
// declaration → varDecl | statement ;
func (p *Parser) declaration() ast.Stmt {
if p.match(token.VAR) {
return p.varDeclaration()
}
return p.statement()
}
// varDecl → "var" IDENTIFIER ( "=" expression )? ";" ;
func (p *Parser) varDeclaration() ast.Stmt {
err := p.consume(token.IDENTIFIER, "Expect variable name.")
if err != nil {
return p.fromErrorExpr(err)
}
name := p.previous()
var initializer ast.Expr
if p.match(token.EQUAL) {
initializer = p.expression()
}
err = p.consume(token.SEMICOLON, "Expect ';' after variable declaration.")
if err != nil {
return p.fromErrorExpr(err)
}
return &ast.VarStmt{Name: name, Initializer: initializer}
}
// statement → exprStmt | printStmt ; // statement → exprStmt | printStmt ;
func (p *Parser) statement() ast.Stmt { func (p *Parser) statement() ast.Stmt {
if p.match(token.PRINT) { if p.match(token.PRINT) {
@ -152,7 +182,7 @@ func (p *Parser) unary() ast.Expr {
return p.primary() return p.primary()
} }
// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ; // primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER;
func (p *Parser) primary() ast.Expr { func (p *Parser) primary() ast.Expr {
switch { switch {
case p.match(token.FALSE): case p.match(token.FALSE):
@ -163,6 +193,8 @@ func (p *Parser) primary() ast.Expr {
return &ast.LiteralExpr{Value: nil} return &ast.LiteralExpr{Value: nil}
case p.match(token.NUMBER, token.STRING): case p.match(token.NUMBER, token.STRING):
return &ast.LiteralExpr{Value: p.previous().Literal} return &ast.LiteralExpr{Value: p.previous().Literal}
case p.match(token.IDENTIFIER):
return &ast.VariableExpr{Name: p.previous()}
case p.match(token.LEFT_PAREN): case p.match(token.LEFT_PAREN):
expr := p.expression() expr := p.expression()
err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.") err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.")

@ -176,3 +176,50 @@ func TestParsePrintStmt(t *testing.T) {
}) })
} }
} }
func TestParseVarStmt(t *testing.T) {
tests := []struct {
name string
tokens []token.Token
expected string
}{
{
name: "simple var statement",
tokens: []token.Token{
{Type: token.VAR, Lexeme: "var"},
{Type: token.IDENTIFIER, Lexeme: "foo"},
{Type: token.EQUAL, Lexeme: "="},
{Type: token.NUMBER, Lexeme: "42", Literal: 42},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(var foo 42)\n",
},
{
name: "simple var statement",
tokens: []token.Token{
{Type: token.VAR, Lexeme: "var"},
{Type: token.IDENTIFIER, Lexeme: "foo"},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(var foo)\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := New(tt.tokens, errors.NewMockErrorLogger())
stmts := parser.Parse()
if len(stmts) != 1 {
t.Fatalf("expected 1 statement, got %d", len(stmts))
}
ap := ast.NewPrinter()
s := ap.PrintStmts(stmts)
if s != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, s)
}
})
}
}

Loading…
Cancel
Save