Add expression interpreter

main
oabrivard 1 year ago
parent e0e921602d
commit 95ab9d78d5

@ -5,4 +5,20 @@ import "golox/token"
type Logger interface { type Logger interface {
Error(line int, message string) Error(line int, message string)
ErrorAtToken(t token.Token, message string) ErrorAtToken(t token.Token, message string)
RuntimeError(message string)
}
type mockErrorLogger struct{}
func (el *mockErrorLogger) Error(line int, message string) {
}
func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) {
}
func (el *mockErrorLogger) RuntimeError(message string) {
}
func NewMockErrorLogger() *mockErrorLogger {
return &mockErrorLogger{}
} }

@ -0,0 +1,185 @@
package interpreter
import (
"fmt"
"golox/ast"
"golox/errors"
"golox/token"
"strings"
)
// Interpreter interprets the AST.
type Interpreter struct {
errLogger errors.Logger
}
// New creates a new Interpreter.
func New() *Interpreter {
return &Interpreter{}
}
// Interpret interprets the AST.
func (i *Interpreter) InterpretExpr(expr ast.Expr) string {
defer i.afterPanic()
value := i.evaluate(expr)
return stringify(value)
}
// VisitErrorExpr visits an ErrorExpr.
func (i *Interpreter) VisitErrorExpr(e *ast.ErrorExpr) any {
panic(e.Value)
}
// VisitLiteralExpr visits a LiteralExpr.
func (i *Interpreter) VisitLiteralExpr(l *ast.LiteralExpr) any {
return l.Value
}
// VisitGroupingExpr visits a GroupingExpr.
func (i *Interpreter) VisitGroupingExpr(g *ast.GroupingExpr) any {
return i.evaluate(g.Expression)
}
// VisitUnaryExpr visits a UnaryExpr.
func (i *Interpreter) VisitUnaryExpr(u *ast.UnaryExpr) any {
right := i.evaluate(u.Right)
switch u.Operator.Type {
case token.MINUS:
checkNumberOperands(u.Operator, right)
return -right.(float64)
case token.BANG:
return !isTruthy(right)
}
return nil
}
// VisitBinaryExpr visits a BinaryExpr.
func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any {
left := i.evaluate(b.Left)
right := i.evaluate(b.Right)
switch b.Operator.Type {
case token.MINUS:
checkNumberOperands(b.Operator, left, right)
return left.(float64) - right.(float64)
case token.SLASH:
checkNumberOperands(b.Operator, left, right)
denominator := right.(float64)
if denominator == 0.0 {
panic(fmt.Sprintf("Division by zero [line %d]", b.Operator.Line))
}
return left.(float64) / denominator
case token.STAR:
checkNumberOperands(b.Operator, left, right)
return left.(float64) * right.(float64)
case token.PLUS:
if l, ok := left.(float64); ok {
if r, ok := right.(float64); ok {
return l + r
}
}
if l, ok := left.(string); ok {
if r, ok := right.(string); ok {
return l + r
}
}
panic(fmt.Sprintf("Operands must be two numbers or two strings [line %d]", b.Operator.Line))
case token.GREATER:
checkNumberOperands(b.Operator, left, right)
return left.(float64) > right.(float64)
case token.GREATER_EQUAL:
checkNumberOperands(b.Operator, left, right)
return left.(float64) >= right.(float64)
case token.LESS:
checkNumberOperands(b.Operator, left, right)
return left.(float64) < right.(float64)
case token.LESS_EQUAL:
checkNumberOperands(b.Operator, left, right)
return left.(float64) <= right.(float64)
case token.BANG_EQUAL:
return !isEqual(left, right)
case token.EQUAL_EQUAL:
return isEqual(left, right)
}
panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line))
}
// checkNumberOperands checks if the operands are numbers.
func checkNumberOperands(operator token.Token, operands ...any) {
for _, operand := range operands {
if _, ok := operand.(float64); !ok {
panic(fmt.Sprintf("Operands of operator '%s' must be numbers [line %d]", operator.Lexeme, operator.Line))
}
}
}
// isTruthy checks if a value is truthy.
func isTruthy(v any) bool {
if v == nil {
return false
}
if b, ok := v.(bool); ok {
return b
}
return true
}
// isEqual checks if two values are equal.
func isEqual(a, b any) bool {
if a == nil && b == nil {
return true
}
if a == nil {
return false
}
return a == b
}
// evaluate evaluates an expression.
func (i *Interpreter) evaluate(e ast.Expr) any {
return e.Accept(i)
}
// stringify returns a string representation of a value.
func stringify(v any) string {
if v == nil {
return "nil"
}
if b, ok := v.(bool); ok {
if b {
return "true"
}
return "false"
}
s := fmt.Sprintf("%v", v)
if f, ok := v.(float64); ok {
if strings.HasSuffix(s, ".0") {
return fmt.Sprintf("%d", int(f))
}
return fmt.Sprintf("%g", f)
}
return s
}
// afterPanic handles a panic.
func (i *Interpreter) afterPanic() {
if r := recover(); r != nil {
i.errLogger.RuntimeError(r.(string))
}
}

@ -0,0 +1,275 @@
// FILE: interpreter_test.go
package interpreter_test
import (
"golox/ast"
"golox/interpreter"
"golox/token"
"testing"
)
func TestInterpretLiteralExpr(t *testing.T) {
i := interpreter.New()
literal := &ast.LiteralExpr{Value: 42}
result := i.VisitLiteralExpr(literal)
if result != 42 {
t.Errorf("expected 42, got %v", result)
}
}
func TestInterpretGroupingExpr(t *testing.T) {
i := interpreter.New()
literal := &ast.LiteralExpr{Value: 42}
grouping := &ast.GroupingExpr{Expression: literal}
result := i.VisitGroupingExpr(grouping)
if result != 42 {
t.Errorf("expected 42, got %v", result)
}
}
func TestInterpretUnaryExpr(t *testing.T) {
i := interpreter.New()
literal := &ast.LiteralExpr{Value: 42.0}
unary := &ast.UnaryExpr{
Operator: token.Token{Type: token.MINUS, Lexeme: "-"},
Right: literal,
}
result := i.VisitUnaryExpr(unary)
if result != -42.0 {
t.Errorf("expected -42, got %v", result)
}
}
func TestInterpretUnaryExprBang(t *testing.T) {
i := interpreter.New()
literal := &ast.LiteralExpr{Value: true}
unary := &ast.UnaryExpr{
Operator: token.Token{Type: token.BANG, Lexeme: "!"},
Right: literal,
}
result := i.VisitUnaryExpr(unary)
if result != false {
t.Errorf("expected false, got %v", result)
}
}
func TestInterpretErrorExpr(t *testing.T) {
i := interpreter.New()
errorExpr := &ast.ErrorExpr{Value: "error"}
defer func() {
if r := recover(); r != "error" {
t.Errorf("expected panic with 'error', got %v", r)
}
}()
i.VisitErrorExpr(errorExpr)
}
func TestInterpretExpr(t *testing.T) {
i := interpreter.New()
literal := &ast.LiteralExpr{Value: 42.0}
defer func() {
if r := recover(); r != nil {
t.Errorf("unexpected panic: %v", r)
}
}()
result := i.InterpretExpr(literal)
if result != "42" {
t.Errorf("expected '42', got %v", result)
}
}
func TestInterpretBinaryExpr(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.STAR, Lexeme: "*"},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != 84.0 {
t.Errorf("expected 84, got %v", result)
}
}
func TestInterpretBinaryExprDivisionByZero(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 0.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.SLASH, Lexeme: "/"},
Right: right,
}
defer func() {
if r := recover(); r != "Division by zero [line 0]" {
t.Errorf("expected panic with 'division by zero', got %v", r)
}
}()
i.VisitBinaryExpr(binary)
}
func TestInterpretBinaryExprAddition(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.PLUS, Lexeme: "+"},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != 44.0 {
t.Errorf("expected 44, got %v", result)
}
}
func TestInterpretBinaryExprSubtraction(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.MINUS, Lexeme: "-"},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != 40.0 {
t.Errorf("expected 40, got %v", result)
}
}
func TestInterpretBinaryExprStringConcatenation(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: "bar"}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.PLUS, Lexeme: "+"},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != "foobar" {
t.Errorf("expected 'foobar', got %v", result)
}
}
func TestInterpretBinaryExprInvalidOperands(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.PLUS, Lexeme: "+"},
Right: right,
}
defer func() {
if r := recover(); r != "Operands must be two numbers or two strings [line 0]" {
t.Errorf("expected panic with 'operands must be two numbers or two strings', got %v", r)
}
}()
i.VisitBinaryExpr(binary)
}
func TestInterpretBinaryExprComparison(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.GREATER, Lexeme: ">"},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != true {
t.Errorf("expected true, got %v", result)
}
}
func TestInterpretBinaryExprComparisonEqual(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.EQUAL_EQUAL, Lexeme: "=="},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != true {
t.Errorf("expected true, got %v", result)
}
}
func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.BANG_EQUAL, Lexeme: "!="},
Right: right,
}
result := i.VisitBinaryExpr(binary)
if result != true {
t.Errorf("expected true, got %v", result)
}
}
func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: "foo"}
right := &ast.LiteralExpr{Value: 42.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.GREATER, Lexeme: ">"},
Right: right,
}
defer func() {
if r := recover(); r != "Operands of operator '>' must be numbers [line 0]" {
t.Errorf("expected panic with 'operands must be numbers', got %v", r)
}
}()
i.VisitBinaryExpr(binary)
}
func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) {
i := interpreter.New()
left := &ast.LiteralExpr{Value: 42.0}
right := &ast.LiteralExpr{Value: 2.0}
binary := &ast.BinaryExpr{
Left: left,
Operator: token.Token{Type: token.EOF, Lexeme: ""},
Right: right,
}
defer func() {
if r := recover(); r != "Unknown binary operator '' [line 0]" {
t.Errorf("expected panic with 'unknown operator type', got %v", r)
}
}()
i.VisitBinaryExpr(binary)
}

@ -3,7 +3,7 @@ package lox
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"golox/ast" "golox/interpreter"
"golox/parser" "golox/parser"
"golox/scanner" "golox/scanner"
"golox/token" "golox/token"
@ -12,10 +12,16 @@ import (
type Lox struct { type Lox struct {
hadError bool hadError bool
hadRuntimeError bool
interpreter *interpreter.Interpreter
} }
func New() *Lox { func New() *Lox {
return &Lox{hadError: false} return &Lox{
hadError: false,
hadRuntimeError: false,
interpreter: interpreter.New(),
}
} }
func (l *Lox) RunFile(path string) { func (l *Lox) RunFile(path string) {
@ -30,6 +36,10 @@ func (l *Lox) RunFile(path string) {
if l.hadError { if l.hadError {
os.Exit(65) os.Exit(65)
} }
if l.hadRuntimeError {
os.Exit(70)
}
} }
func (l *Lox) RunPrompt() { func (l *Lox) RunPrompt() {
@ -54,10 +64,12 @@ func (l *Lox) RunPrompt() {
} }
func (l *Lox) Error(line int, message string) { func (l *Lox) Error(line int, message string) {
l.hadError = true
l.report(line, "", message) l.report(line, "", message)
} }
func (l *Lox) ErrorAtToken(t token.Token, message string) { func (l *Lox) ErrorAtToken(t token.Token, message string) {
l.hadError = true
if t.Type == token.EOF { if t.Type == token.EOF {
l.report(t.Line, " at end", message) l.report(t.Line, " at end", message)
} else { } else {
@ -65,9 +77,13 @@ func (l *Lox) ErrorAtToken(t token.Token, message string) {
} }
} }
func (l *Lox) RuntimeError(message string) {
l.hadRuntimeError = true
fmt.Println(message)
}
func (l *Lox) report(line int, where string, message string) { func (l *Lox) report(line int, where string, message string) {
fmt.Printf("[line %d] Error %s: %s\n", line, where, message) fmt.Printf("[line %d] Error %s: %s\n", line, where, message)
l.hadError = true
} }
func (l *Lox) run(source string) { func (l *Lox) run(source string) {
@ -80,6 +96,6 @@ func (l *Lox) run(source string) {
return return
} }
p := ast.NewPrinter() s := l.interpreter.InterpretExpr(expr)
fmt.Println(p.Print(expr)) fmt.Println(s)
} }

@ -34,7 +34,7 @@ func TestRun(t *testing.T) {
out := <-outC out := <-outC
// reading our temp stdout // reading our temp stdout
expected := "(+ 1 (/ 4 2))\n" expected := "3\n"
if out != expected { if out != expected {
t.Errorf("run() = %v; want %v", out, expected) t.Errorf("run() = %v; want %v", out, expected)
} }
@ -79,7 +79,7 @@ func TestRunFile(t *testing.T) {
out := <-outC out := <-outC
// reading our temp stdout // reading our temp stdout
expected := "(+ 1 (/ 4 2))\n" expected := "3\n"
if out != expected { if out != expected {
t.Errorf("RunFile() = %v; want %v", out, expected) t.Errorf("RunFile() = %v; want %v", out, expected)
} }
@ -178,7 +178,7 @@ func TestRunPrompt(t *testing.T) {
os.Stdout = oldStdout os.Stdout = oldStdout
out := <-outC out := <-outC
expected := "> (+ 1 (/ 4 2))\n> " expected := "> 3\n> "
if out != expected { if out != expected {
t.Errorf("RunPrompt() = %v; want %v", out, expected) t.Errorf("RunPrompt() = %v; want %v", out, expected)
} }

@ -2,22 +2,11 @@ package parser
import ( import (
"golox/ast" "golox/ast"
"golox/errors"
"golox/token" "golox/token"
"testing" "testing"
) )
type mockErrorLogger struct{}
func (el *mockErrorLogger) Error(line int, message string) {
}
func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) {
}
func newMockErrorLogger() *mockErrorLogger {
return &mockErrorLogger{}
}
func TestParser(t *testing.T) { func TestParser(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -75,11 +64,22 @@ func TestParser(t *testing.T) {
}, },
expected: "(== 1 2)", expected: "(== 1 2)",
}, },
{
name: "Parsing error - missing right parenthesis",
tokens: []token.Token{
{Type: token.LEFT_PAREN, Lexeme: "("},
{Type: token.NUMBER, Literal: 1},
{Type: token.PLUS, Lexeme: "+"},
{Type: token.NUMBER, Literal: 2},
{Type: token.EOF},
},
expected: "Expect ')' after expression.",
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
parser := New(tt.tokens, newMockErrorLogger()) parser := New(tt.tokens, errors.NewMockErrorLogger())
expr := parser.Parse() expr := parser.Parse()
ap := ast.NewPrinter() ap := ast.NewPrinter()
s := ap.Print(expr) s := ap.Print(expr)

@ -1,22 +1,11 @@
package scanner package scanner
import ( import (
"golox/errors"
"golox/token" "golox/token"
"testing" "testing"
) )
type errorLogger struct{}
func (el *errorLogger) Error(line int, message string) {
}
func (el *errorLogger) ErrorAtToken(t token.Token, message string) {
}
func newErrorLogger() *errorLogger {
return &errorLogger{}
}
func TestScanTokens(t *testing.T) { func TestScanTokens(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -88,7 +77,7 @@ func TestScanTokens(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
tokens := scanner.ScanTokens() tokens := scanner.ScanTokens()
if len(tokens) != len(tt.tokens)+1 { // +1 for EOF token if len(tokens) != len(tt.tokens)+1 { // +1 for EOF token
@ -120,7 +109,7 @@ func TestIsAtEnd(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
if got := scanner.isAtEnd(); got != tt.expected { if got := scanner.isAtEnd(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got) t.Errorf("expected %v, got %v", tt.expected, got)
} }
@ -141,7 +130,7 @@ func TestMatch(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
if got := scanner.match(tt.char); got != tt.expected { if got := scanner.match(tt.char); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got) t.Errorf("expected %v, got %v", tt.expected, got)
} }
@ -161,7 +150,7 @@ func TestPeek(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
if got := scanner.peek(); got != tt.expected { if got := scanner.peek(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got) t.Errorf("expected %v, got %v", tt.expected, got)
} }
@ -181,7 +170,7 @@ func TestPeekNext(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
if got := scanner.peekNext(); got != tt.expected { if got := scanner.peekNext(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got) t.Errorf("expected %v, got %v", tt.expected, got)
} }
@ -200,7 +189,7 @@ func TestAdvance(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
if got := scanner.advance(); got != tt.expected { if got := scanner.advance(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got) t.Errorf("expected %v, got %v", tt.expected, got)
} }
@ -258,7 +247,7 @@ func TestString(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
scanner.advance() // Move to the first character of the string scanner.advance() // Move to the first character of the string
scanner.string() scanner.string()
if tt.expected == "" { if tt.expected == "" {
@ -288,7 +277,7 @@ func TestNumber(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
scanner.number() scanner.number()
if tt.expected == 0 { if tt.expected == 0 {
if len(scanner.tokens) != 0 { if len(scanner.tokens) != 0 {
@ -341,7 +330,7 @@ func TestScanToken(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
scanner.scanToken() scanner.scanToken()
if len(scanner.tokens) > 0 { if len(scanner.tokens) > 0 {
if scanner.tokens[0].Type != tt.expected { if scanner.tokens[0].Type != tt.expected {
@ -384,7 +373,7 @@ func TestIdentifier(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source, newErrorLogger()) scanner := New(tt.source, errors.NewMockErrorLogger())
scanner.identifier() scanner.identifier()
if len(scanner.tokens) != 1 { if len(scanner.tokens) != 1 {
t.Fatalf("expected 1 token, got %d", len(scanner.tokens)) t.Fatalf("expected 1 token, got %d", len(scanner.tokens))

Loading…
Cancel
Save