Add global variables and environments

main
oabrivard 1 year ago
parent b97f44db2d
commit 5eacc29825

@ -1 +1 @@
bf6f2cc6973975f2bcc789463842856a
ea7730697fb5d7edf1961630c5f68b96

@ -8,6 +8,7 @@ type ExprVisitor[T any] interface {
VisitGroupingExpr(ge *GroupingExpr) T
VisitLiteralExpr(le *LiteralExpr) T
VisitUnaryExpr(ue *UnaryExpr) T
VisitVariableExpr(ve *VariableExpr) T
}
type Expr interface {
@ -57,3 +58,11 @@ func (ue *UnaryExpr) Accept(v ExprVisitor[any]) any {
return v.VisitUnaryExpr(ue)
}
type VariableExpr struct {
Name token.Token
}
func (ve *VariableExpr) Accept(v ExprVisitor[any]) any {
return v.VisitVariableExpr(ve)
}

@ -41,6 +41,10 @@ func (ap *Printer) VisitVarStmt(stmt *VarStmt) any {
return ap.parenthesize("var", &LiteralExpr{stmt.Name}, stmt.Initializer)
}
func (ap *Printer) VisitVariableExpr(expr *VariableExpr) any {
return expr.Name.Lexeme
}
func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any {
return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right)
}
@ -68,6 +72,9 @@ func (ap *Printer) parenthesize(name string, exprs ...Expr) string {
str := "(" + name
for _, expr := range exprs {
if expr == nil {
continue
}
str += " " + expr.Accept(ap).(string)
}

@ -27,6 +27,7 @@ func main() {
"Grouping : Expression Expr",
"Literal : Value any",
"Unary : Operator token.Token, Right Expr",
"Variable : Name token.Token",
})
defineAst(d, "Stmt", []string{

@ -0,0 +1,22 @@
package interpreter
type environment struct {
values map[string]any
}
func newEnvironment() *environment {
return &environment{values: make(map[string]any)}
}
func (e *environment) define(name string, value any) {
e.values[name] = value
}
func (e *environment) get(name string) any {
value, ok := e.values[name]
if !ok {
panic("Undefined variable '" + name + "'.")
}
return value
}

@ -11,11 +11,12 @@ import (
// Interpreter interprets the AST.
type Interpreter struct {
errLogger errors.Logger
env *environment
}
// New creates a new Interpreter.
func New(el errors.Logger) *Interpreter {
return &Interpreter{el}
return &Interpreter{el, newEnvironment()}
}
// Interpret interprets the AST.
@ -47,6 +48,12 @@ func (i *Interpreter) VisitPrintStmt(ps *ast.PrintStmt) any {
// VisitVarStmt visits a var statement.
func (i *Interpreter) VisitVarStmt(vs *ast.VarStmt) any {
var value any
if vs.Initializer != nil {
value = i.evaluate(vs.Initializer)
}
i.env.define(vs.Name.Lexeme, value)
return nil
}
@ -142,6 +149,11 @@ func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any {
panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line))
}
// VisitVariableExpr visits a VariableExpr.
func (i *Interpreter) VisitVariableExpr(v *ast.VariableExpr) any {
return i.env.get(v.Name.Lexeme)
}
// checkNumberOperands checks if the operands are numbers.
func checkNumberOperands(operator token.Token, operands ...any) {
for _, operand := range operands {

@ -1,16 +1,18 @@
// FILE: interpreter_test.go
package interpreter_test
package interpreter
import (
"bytes"
"golox/ast"
"golox/errors"
"golox/interpreter"
"golox/token"
"io"
"os"
"testing"
)
func TestInterpretLiteralExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42}
result := i.VisitLiteralExpr(literal)
@ -20,7 +22,7 @@ func TestInterpretLiteralExpr(t *testing.T) {
}
func TestInterpretGroupingExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42}
grouping := &ast.GroupingExpr{Expression: literal}
@ -31,7 +33,7 @@ func TestInterpretGroupingExpr(t *testing.T) {
}
func TestInterpretUnaryExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0}
unary := &ast.UnaryExpr{
Operator: token.Token{Type: token.MINUS, Lexeme: "-"},
@ -45,7 +47,7 @@ func TestInterpretUnaryExpr(t *testing.T) {
}
func TestInterpretUnaryExprBang(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: true}
unary := &ast.UnaryExpr{
Operator: token.Token{Type: token.BANG, Lexeme: "!"},
@ -59,7 +61,7 @@ func TestInterpretUnaryExprBang(t *testing.T) {
}
func TestInterpretErrorExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
errorExpr := &ast.ErrorExpr{Value: "error"}
defer func() {
@ -72,7 +74,7 @@ func TestInterpretErrorExpr(t *testing.T) {
}
func TestInterpretExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0}
defer func() {
@ -88,7 +90,7 @@ func TestInterpretExpr(t *testing.T) {
}
func TestInterpretBinaryExpr(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
@ -104,7 +106,7 @@ func TestInterpretBinaryExpr(t *testing.T) {
}
func TestInterpretBinaryExprDivisionByZero(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 0.0}
binary := &ast.BinaryExpr{
@ -123,7 +125,7 @@ func TestInterpretBinaryExprDivisionByZero(t *testing.T) {
}
func TestInterpretBinaryExprAddition(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
@ -139,7 +141,7 @@ func TestInterpretBinaryExprAddition(t *testing.T) {
}
func TestInterpretBinaryExprSubtraction(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
@ -155,7 +157,7 @@ func TestInterpretBinaryExprSubtraction(t *testing.T) {
}
func TestInterpretBinaryExprStringConcatenation(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: "bar"}
binary := &ast.BinaryExpr{
@ -171,7 +173,7 @@ func TestInterpretBinaryExprStringConcatenation(t *testing.T) {
}
func TestInterpretBinaryExprInvalidOperands(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{
@ -190,7 +192,7 @@ func TestInterpretBinaryExprInvalidOperands(t *testing.T) {
}
func TestInterpretBinaryExprComparison(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
@ -206,7 +208,7 @@ func TestInterpretBinaryExprComparison(t *testing.T) {
}
func TestInterpretBinaryExprComparisonEqual(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{
@ -222,7 +224,7 @@ func TestInterpretBinaryExprComparisonEqual(t *testing.T) {
}
func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
@ -238,7 +240,7 @@ func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) {
}
func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{
@ -257,7 +259,7 @@ func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) {
}
func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) {
i := interpreter.New(errors.NewMockErrorLogger())
i := New(errors.NewMockErrorLogger())
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
@ -274,3 +276,65 @@ func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) {
i.VisitBinaryExpr(binary)
}
func TestInterpretExprStatement(t *testing.T) {
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0}
exprStmt := &ast.ExpressionStmt{Expression: literal}
result := i.VisitExpressionStmt(exprStmt)
if result != nil {
t.Errorf("expected nil, got %v", result)
}
}
func TestInterpretPrintStatement(t *testing.T) {
old := os.Stdout // keep backup of the real stdout
r, w, err := os.Pipe()
if err != nil {
t.Fatal(err)
}
os.Stdout = w
outC := make(chan string)
// copy the output in a separate goroutine so printing can't block indefinitely
go func() {
var buf bytes.Buffer
io.Copy(&buf, r)
outC <- buf.String()
}()
i := New(errors.NewMockErrorLogger())
literal := &ast.LiteralExpr{Value: 42.0}
printStmt := &ast.PrintStmt{Expression: literal}
result := i.VisitPrintStmt(printStmt)
if result != nil {
t.Errorf("expected nil, got %v", result)
}
// back to normal state
w.Close()
os.Stdout = old // restoring the real stdout
out := <-outC
// reading our temp stdout
expected := "42\n"
if out != expected {
t.Errorf("run() = %v; want %v", out, expected)
}
}
func TestInterpretVarStatement(t *testing.T) {
i := New(errors.NewMockErrorLogger())
varStmt := &ast.VarStmt{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"},
Initializer: &ast.LiteralExpr{Value: 42.0},
}
i.VisitVarStmt(varStmt)
result := i.env.get("foo")
if result != 42.0 {
t.Errorf("expected 42, got %v", result)
}
}

@ -42,7 +42,7 @@ func (p *Parser) Parse() []ast.Stmt {
stmts := []ast.Stmt{}
for !p.isAtEnd() {
stmt := p.statement()
stmt := p.declaration()
if _, ok := stmt.(*ast.ErrorStmt); ok {
p.synchronize()
} else {
@ -53,6 +53,36 @@ func (p *Parser) Parse() []ast.Stmt {
return stmts
}
// declaration → varDecl | statement ;
func (p *Parser) declaration() ast.Stmt {
if p.match(token.VAR) {
return p.varDeclaration()
}
return p.statement()
}
// varDecl → "var" IDENTIFIER ( "=" expression )? ";" ;
func (p *Parser) varDeclaration() ast.Stmt {
err := p.consume(token.IDENTIFIER, "Expect variable name.")
if err != nil {
return p.fromErrorExpr(err)
}
name := p.previous()
var initializer ast.Expr
if p.match(token.EQUAL) {
initializer = p.expression()
}
err = p.consume(token.SEMICOLON, "Expect ';' after variable declaration.")
if err != nil {
return p.fromErrorExpr(err)
}
return &ast.VarStmt{Name: name, Initializer: initializer}
}
// statement → exprStmt | printStmt ;
func (p *Parser) statement() ast.Stmt {
if p.match(token.PRINT) {
@ -152,7 +182,7 @@ func (p *Parser) unary() ast.Expr {
return p.primary()
}
// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ;
// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER;
func (p *Parser) primary() ast.Expr {
switch {
case p.match(token.FALSE):
@ -163,6 +193,8 @@ func (p *Parser) primary() ast.Expr {
return &ast.LiteralExpr{Value: nil}
case p.match(token.NUMBER, token.STRING):
return &ast.LiteralExpr{Value: p.previous().Literal}
case p.match(token.IDENTIFIER):
return &ast.VariableExpr{Name: p.previous()}
case p.match(token.LEFT_PAREN):
expr := p.expression()
err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.")

@ -176,3 +176,50 @@ func TestParsePrintStmt(t *testing.T) {
})
}
}
func TestParseVarStmt(t *testing.T) {
tests := []struct {
name string
tokens []token.Token
expected string
}{
{
name: "simple var statement",
tokens: []token.Token{
{Type: token.VAR, Lexeme: "var"},
{Type: token.IDENTIFIER, Lexeme: "foo"},
{Type: token.EQUAL, Lexeme: "="},
{Type: token.NUMBER, Lexeme: "42", Literal: 42},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(var foo 42)\n",
},
{
name: "simple var statement",
tokens: []token.Token{
{Type: token.VAR, Lexeme: "var"},
{Type: token.IDENTIFIER, Lexeme: "foo"},
{Type: token.SEMICOLON, Lexeme: ";"},
{Type: token.EOF},
},
expected: "(var foo)\n",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := New(tt.tokens, errors.NewMockErrorLogger())
stmts := parser.Parse()
if len(stmts) != 1 {
t.Fatalf("expected 1 statement, got %d", len(stmts))
}
ap := ast.NewPrinter()
s := ap.PrintStmts(stmts)
if s != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, s)
}
})
}
}

Loading…
Cancel
Save