diff --git a/errors/logger.go b/errors/logger.go index 5f1b84a..8dfb287 100644 --- a/errors/logger.go +++ b/errors/logger.go @@ -5,4 +5,20 @@ import "golox/token" type Logger interface { Error(line int, message string) ErrorAtToken(t token.Token, message string) + RuntimeError(message string) +} + +type mockErrorLogger struct{} + +func (el *mockErrorLogger) Error(line int, message string) { +} + +func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) { +} + +func (el *mockErrorLogger) RuntimeError(message string) { +} + +func NewMockErrorLogger() *mockErrorLogger { + return &mockErrorLogger{} } diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go new file mode 100644 index 0000000..0f34de5 --- /dev/null +++ b/interpreter/interpreter.go @@ -0,0 +1,185 @@ +package interpreter + +import ( + "fmt" + "golox/ast" + "golox/errors" + "golox/token" + "strings" +) + +// Interpreter interprets the AST. +type Interpreter struct { + errLogger errors.Logger +} + +// New creates a new Interpreter. +func New() *Interpreter { + return &Interpreter{} +} + +// Interpret interprets the AST. +func (i *Interpreter) InterpretExpr(expr ast.Expr) string { + defer i.afterPanic() + + value := i.evaluate(expr) + return stringify(value) +} + +// VisitErrorExpr visits an ErrorExpr. +func (i *Interpreter) VisitErrorExpr(e *ast.ErrorExpr) any { + panic(e.Value) +} + +// VisitLiteralExpr visits a LiteralExpr. +func (i *Interpreter) VisitLiteralExpr(l *ast.LiteralExpr) any { + return l.Value +} + +// VisitGroupingExpr visits a GroupingExpr. +func (i *Interpreter) VisitGroupingExpr(g *ast.GroupingExpr) any { + return i.evaluate(g.Expression) +} + +// VisitUnaryExpr visits a UnaryExpr. +func (i *Interpreter) VisitUnaryExpr(u *ast.UnaryExpr) any { + right := i.evaluate(u.Right) + + switch u.Operator.Type { + case token.MINUS: + checkNumberOperands(u.Operator, right) + return -right.(float64) + case token.BANG: + return !isTruthy(right) + } + + return nil +} + +// VisitBinaryExpr visits a BinaryExpr. +func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any { + left := i.evaluate(b.Left) + right := i.evaluate(b.Right) + + switch b.Operator.Type { + case token.MINUS: + checkNumberOperands(b.Operator, left, right) + return left.(float64) - right.(float64) + case token.SLASH: + checkNumberOperands(b.Operator, left, right) + denominator := right.(float64) + if denominator == 0.0 { + panic(fmt.Sprintf("Division by zero [line %d]", b.Operator.Line)) + } + return left.(float64) / denominator + case token.STAR: + checkNumberOperands(b.Operator, left, right) + return left.(float64) * right.(float64) + case token.PLUS: + if l, ok := left.(float64); ok { + if r, ok := right.(float64); ok { + return l + r + } + } + + if l, ok := left.(string); ok { + if r, ok := right.(string); ok { + return l + r + } + } + + panic(fmt.Sprintf("Operands must be two numbers or two strings [line %d]", b.Operator.Line)) + case token.GREATER: + checkNumberOperands(b.Operator, left, right) + return left.(float64) > right.(float64) + case token.GREATER_EQUAL: + checkNumberOperands(b.Operator, left, right) + return left.(float64) >= right.(float64) + case token.LESS: + checkNumberOperands(b.Operator, left, right) + return left.(float64) < right.(float64) + case token.LESS_EQUAL: + checkNumberOperands(b.Operator, left, right) + return left.(float64) <= right.(float64) + case token.BANG_EQUAL: + return !isEqual(left, right) + case token.EQUAL_EQUAL: + return isEqual(left, right) + } + + panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line)) +} + +// checkNumberOperands checks if the operands are numbers. +func checkNumberOperands(operator token.Token, operands ...any) { + for _, operand := range operands { + if _, ok := operand.(float64); !ok { + panic(fmt.Sprintf("Operands of operator '%s' must be numbers [line %d]", operator.Lexeme, operator.Line)) + } + } +} + +// isTruthy checks if a value is truthy. +func isTruthy(v any) bool { + if v == nil { + return false + } + + if b, ok := v.(bool); ok { + return b + } + + return true +} + +// isEqual checks if two values are equal. +func isEqual(a, b any) bool { + if a == nil && b == nil { + return true + } + + if a == nil { + return false + } + + return a == b +} + +// evaluate evaluates an expression. +func (i *Interpreter) evaluate(e ast.Expr) any { + return e.Accept(i) +} + +// stringify returns a string representation of a value. +func stringify(v any) string { + if v == nil { + return "nil" + } + + if b, ok := v.(bool); ok { + if b { + return "true" + } + + return "false" + } + + s := fmt.Sprintf("%v", v) + + if f, ok := v.(float64); ok { + if strings.HasSuffix(s, ".0") { + return fmt.Sprintf("%d", int(f)) + } + + return fmt.Sprintf("%g", f) + } + + return s +} + +// afterPanic handles a panic. +func (i *Interpreter) afterPanic() { + if r := recover(); r != nil { + i.errLogger.RuntimeError(r.(string)) + } +} diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go new file mode 100644 index 0000000..412554f --- /dev/null +++ b/interpreter/interpreter_test.go @@ -0,0 +1,275 @@ +// FILE: interpreter_test.go +package interpreter_test + +import ( + "golox/ast" + "golox/interpreter" + "golox/token" + "testing" +) + +func TestInterpretLiteralExpr(t *testing.T) { + i := interpreter.New() + literal := &ast.LiteralExpr{Value: 42} + + result := i.VisitLiteralExpr(literal) + if result != 42 { + t.Errorf("expected 42, got %v", result) + } +} + +func TestInterpretGroupingExpr(t *testing.T) { + i := interpreter.New() + literal := &ast.LiteralExpr{Value: 42} + grouping := &ast.GroupingExpr{Expression: literal} + + result := i.VisitGroupingExpr(grouping) + if result != 42 { + t.Errorf("expected 42, got %v", result) + } +} + +func TestInterpretUnaryExpr(t *testing.T) { + i := interpreter.New() + literal := &ast.LiteralExpr{Value: 42.0} + unary := &ast.UnaryExpr{ + Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, + Right: literal, + } + + result := i.VisitUnaryExpr(unary) + if result != -42.0 { + t.Errorf("expected -42, got %v", result) + } +} + +func TestInterpretUnaryExprBang(t *testing.T) { + i := interpreter.New() + literal := &ast.LiteralExpr{Value: true} + unary := &ast.UnaryExpr{ + Operator: token.Token{Type: token.BANG, Lexeme: "!"}, + Right: literal, + } + + result := i.VisitUnaryExpr(unary) + if result != false { + t.Errorf("expected false, got %v", result) + } +} + +func TestInterpretErrorExpr(t *testing.T) { + i := interpreter.New() + errorExpr := &ast.ErrorExpr{Value: "error"} + + defer func() { + if r := recover(); r != "error" { + t.Errorf("expected panic with 'error', got %v", r) + } + }() + + i.VisitErrorExpr(errorExpr) +} + +func TestInterpretExpr(t *testing.T) { + i := interpreter.New() + literal := &ast.LiteralExpr{Value: 42.0} + + defer func() { + if r := recover(); r != nil { + t.Errorf("unexpected panic: %v", r) + } + }() + + result := i.InterpretExpr(literal) + if result != "42" { + t.Errorf("expected '42', got %v", result) + } +} + +func TestInterpretBinaryExpr(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 2.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.STAR, Lexeme: "*"}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != 84.0 { + t.Errorf("expected 84, got %v", result) + } +} + +func TestInterpretBinaryExprDivisionByZero(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 0.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.SLASH, Lexeme: "/"}, + Right: right, + } + + defer func() { + if r := recover(); r != "Division by zero [line 0]" { + t.Errorf("expected panic with 'division by zero', got %v", r) + } + }() + + i.VisitBinaryExpr(binary) +} + +func TestInterpretBinaryExprAddition(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 2.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.PLUS, Lexeme: "+"}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != 44.0 { + t.Errorf("expected 44, got %v", result) + } +} + +func TestInterpretBinaryExprSubtraction(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 2.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.MINUS, Lexeme: "-"}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != 40.0 { + t.Errorf("expected 40, got %v", result) + } +} + +func TestInterpretBinaryExprStringConcatenation(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: "foo"} + right := &ast.LiteralExpr{Value: "bar"} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.PLUS, Lexeme: "+"}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != "foobar" { + t.Errorf("expected 'foobar', got %v", result) + } +} + +func TestInterpretBinaryExprInvalidOperands(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: "foo"} + right := &ast.LiteralExpr{Value: 42.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.PLUS, Lexeme: "+"}, + Right: right, + } + + defer func() { + if r := recover(); r != "Operands must be two numbers or two strings [line 0]" { + t.Errorf("expected panic with 'operands must be two numbers or two strings', got %v", r) + } + }() + + i.VisitBinaryExpr(binary) +} + +func TestInterpretBinaryExprComparison(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 2.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.GREATER, Lexeme: ">"}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != true { + t.Errorf("expected true, got %v", result) + } +} + +func TestInterpretBinaryExprComparisonEqual(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 42.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.EQUAL_EQUAL, Lexeme: "=="}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != true { + t.Errorf("expected true, got %v", result) + } +} + +func TestInterpretBinaryExprComparisonNotEqual(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 2.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.BANG_EQUAL, Lexeme: "!="}, + Right: right, + } + + result := i.VisitBinaryExpr(binary) + if result != true { + t.Errorf("expected true, got %v", result) + } +} + +func TestInterpretBinaryExprComparisonInvalidOperands(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: "foo"} + right := &ast.LiteralExpr{Value: 42.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.GREATER, Lexeme: ">"}, + Right: right, + } + + defer func() { + if r := recover(); r != "Operands of operator '>' must be numbers [line 0]" { + t.Errorf("expected panic with 'operands must be numbers', got %v", r) + } + }() + + i.VisitBinaryExpr(binary) +} + +func TestInterpretBinaryExprInvalidOperatorType(t *testing.T) { + i := interpreter.New() + left := &ast.LiteralExpr{Value: 42.0} + right := &ast.LiteralExpr{Value: 2.0} + binary := &ast.BinaryExpr{ + Left: left, + Operator: token.Token{Type: token.EOF, Lexeme: ""}, + Right: right, + } + + defer func() { + if r := recover(); r != "Unknown binary operator '' [line 0]" { + t.Errorf("expected panic with 'unknown operator type', got %v", r) + } + }() + + i.VisitBinaryExpr(binary) +} diff --git a/lox/lox.go b/lox/lox.go index 0fb1158..dfb6bbb 100644 --- a/lox/lox.go +++ b/lox/lox.go @@ -3,7 +3,7 @@ package lox import ( "bufio" "fmt" - "golox/ast" + "golox/interpreter" "golox/parser" "golox/scanner" "golox/token" @@ -11,11 +11,17 @@ import ( ) type Lox struct { - hadError bool + hadError bool + hadRuntimeError bool + interpreter *interpreter.Interpreter } func New() *Lox { - return &Lox{hadError: false} + return &Lox{ + hadError: false, + hadRuntimeError: false, + interpreter: interpreter.New(), + } } func (l *Lox) RunFile(path string) { @@ -30,6 +36,10 @@ func (l *Lox) RunFile(path string) { if l.hadError { os.Exit(65) } + + if l.hadRuntimeError { + os.Exit(70) + } } func (l *Lox) RunPrompt() { @@ -54,10 +64,12 @@ func (l *Lox) RunPrompt() { } func (l *Lox) Error(line int, message string) { + l.hadError = true l.report(line, "", message) } func (l *Lox) ErrorAtToken(t token.Token, message string) { + l.hadError = true if t.Type == token.EOF { l.report(t.Line, " at end", message) } else { @@ -65,9 +77,13 @@ func (l *Lox) ErrorAtToken(t token.Token, message string) { } } +func (l *Lox) RuntimeError(message string) { + l.hadRuntimeError = true + fmt.Println(message) +} + func (l *Lox) report(line int, where string, message string) { fmt.Printf("[line %d] Error %s: %s\n", line, where, message) - l.hadError = true } func (l *Lox) run(source string) { @@ -80,6 +96,6 @@ func (l *Lox) run(source string) { return } - p := ast.NewPrinter() - fmt.Println(p.Print(expr)) + s := l.interpreter.InterpretExpr(expr) + fmt.Println(s) } diff --git a/lox/lox_test.go b/lox/lox_test.go index 9e23514..6c9c0c3 100644 --- a/lox/lox_test.go +++ b/lox/lox_test.go @@ -34,7 +34,7 @@ func TestRun(t *testing.T) { out := <-outC // reading our temp stdout - expected := "(+ 1 (/ 4 2))\n" + expected := "3\n" if out != expected { t.Errorf("run() = %v; want %v", out, expected) } @@ -79,7 +79,7 @@ func TestRunFile(t *testing.T) { out := <-outC // reading our temp stdout - expected := "(+ 1 (/ 4 2))\n" + expected := "3\n" if out != expected { t.Errorf("RunFile() = %v; want %v", out, expected) } @@ -178,7 +178,7 @@ func TestRunPrompt(t *testing.T) { os.Stdout = oldStdout out := <-outC - expected := "> (+ 1 (/ 4 2))\n> " + expected := "> 3\n> " if out != expected { t.Errorf("RunPrompt() = %v; want %v", out, expected) } diff --git a/parser/parser_test.go b/parser/parser_test.go index 9974b0e..23e1930 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -2,22 +2,11 @@ package parser import ( "golox/ast" + "golox/errors" "golox/token" "testing" ) -type mockErrorLogger struct{} - -func (el *mockErrorLogger) Error(line int, message string) { -} - -func (el *mockErrorLogger) ErrorAtToken(t token.Token, message string) { -} - -func newMockErrorLogger() *mockErrorLogger { - return &mockErrorLogger{} -} - func TestParser(t *testing.T) { tests := []struct { name string @@ -75,11 +64,22 @@ func TestParser(t *testing.T) { }, expected: "(== 1 2)", }, + { + name: "Parsing error - missing right parenthesis", + tokens: []token.Token{ + {Type: token.LEFT_PAREN, Lexeme: "("}, + {Type: token.NUMBER, Literal: 1}, + {Type: token.PLUS, Lexeme: "+"}, + {Type: token.NUMBER, Literal: 2}, + {Type: token.EOF}, + }, + expected: "Expect ')' after expression.", + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - parser := New(tt.tokens, newMockErrorLogger()) + parser := New(tt.tokens, errors.NewMockErrorLogger()) expr := parser.Parse() ap := ast.NewPrinter() s := ap.Print(expr) diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 824866e..88cc3be 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -1,22 +1,11 @@ package scanner import ( + "golox/errors" "golox/token" "testing" ) -type errorLogger struct{} - -func (el *errorLogger) Error(line int, message string) { -} - -func (el *errorLogger) ErrorAtToken(t token.Token, message string) { -} - -func newErrorLogger() *errorLogger { - return &errorLogger{} -} - func TestScanTokens(t *testing.T) { tests := []struct { name string @@ -88,7 +77,7 @@ func TestScanTokens(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) tokens := scanner.ScanTokens() if len(tokens) != len(tt.tokens)+1 { // +1 for EOF token @@ -120,7 +109,7 @@ func TestIsAtEnd(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) if got := scanner.isAtEnd(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -141,7 +130,7 @@ func TestMatch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) if got := scanner.match(tt.char); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -161,7 +150,7 @@ func TestPeek(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) if got := scanner.peek(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -181,7 +170,7 @@ func TestPeekNext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) if got := scanner.peekNext(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -200,7 +189,7 @@ func TestAdvance(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) if got := scanner.advance(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -258,7 +247,7 @@ func TestString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) scanner.advance() // Move to the first character of the string scanner.string() if tt.expected == "" { @@ -288,7 +277,7 @@ func TestNumber(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) scanner.number() if tt.expected == 0 { if len(scanner.tokens) != 0 { @@ -341,7 +330,7 @@ func TestScanToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) scanner.scanToken() if len(scanner.tokens) > 0 { if scanner.tokens[0].Type != tt.expected { @@ -384,7 +373,7 @@ func TestIdentifier(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source, newErrorLogger()) + scanner := New(tt.source, errors.NewMockErrorLogger()) scanner.identifier() if len(scanner.tokens) != 1 { t.Fatalf("expected 1 token, got %d", len(scanner.tokens))