From 2aae58b6709ed23f3708a24a9f0260a9106d4d5d Mon Sep 17 00:00:00 2001 From: oabrivard Date: Sat, 9 Nov 2024 02:23:46 +0100 Subject: [PATCH] Add expression parser --- .task/checksum/astgen | 1 + Taskfile.yml | 44 ++++++++ ast/printer.go | 50 +++++++++ build.sh | 2 - cmd/astgen/main.go | 81 ++++++++++++++ main.go => cmd/golox/main.go | 5 +- errors/logger.go | 8 ++ lox/lox.go | 53 +++++++-- lox/lox_test.go | 15 ++- parser/expr.go | 59 ++++++++++ parser/parser.go | 207 +++++++++++++++++++++++++++++++++++ scanner/scanner.go | 32 +++--- scanner/scanner_test.go | 32 ++++-- 13 files changed, 543 insertions(+), 46 deletions(-) create mode 100644 .task/checksum/astgen create mode 100644 Taskfile.yml create mode 100644 ast/printer.go delete mode 100755 build.sh create mode 100644 cmd/astgen/main.go rename main.go => cmd/golox/main.go (78%) create mode 100644 errors/logger.go create mode 100644 parser/expr.go create mode 100644 parser/parser.go diff --git a/.task/checksum/astgen b/.task/checksum/astgen new file mode 100644 index 0000000..8582cb7 --- /dev/null +++ b/.task/checksum/astgen @@ -0,0 +1 @@ +5f28a0838ca9382a947da5bd52756cf3 diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..65f86b1 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,44 @@ +# https://taskfile.dev + +version: '3' + +vars: + APP_NAME: golox + +tasks: + default: + cmds: + - task: run + + run: + desc: "Run the interpreter" + deps: [build] + cmds: + - ./bin/{{.APP_NAME}} {{.CLI_ARGS}} + + build: + desc: "Build the interpreter" + deps: [astgen] + cmds: + - go build -o bin/{{.APP_NAME}} cmd/golox/main.go + + test: + desc: "Run tests" + deps: [astgen] + cmds: + - go test -v ./... 
+ + astgen: + desc: "Generate AST nodes" + cmds: + - go run cmd/astgen/main.go "./parser" + sources: + - cmd/astgen/main.go + generates: + - parser/expr.go + + clean: + desc: "Clean up" + cmds: + - rm -rf bin/* + diff --git a/ast/printer.go b/ast/printer.go new file mode 100644 index 0000000..fdd6ca9 --- /dev/null +++ b/ast/printer.go @@ -0,0 +1,50 @@ +package ast + +import ( + "fmt" + "golox/parser" +) + +type Printer struct { +} + +func New() *Printer { + return &Printer{} +} + +func (ap *Printer) Print(expr parser.Expr) string { + return expr.Accept(ap).(string) +} + +func (ap *Printer) VisitBinaryExpr(expr *parser.BinaryExpr) any { + return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right) +} + +func (ap *Printer) VisitGroupingExpr(expr *parser.GroupingExpr) any { + return ap.parenthesize("group", expr.Expression) +} + +func (ap *Printer) VisitLiteralExpr(expr *parser.LiteralExpr) any { + if expr.Value == nil { + return "nil" + } + return fmt.Sprint(expr.Value) +} + +func (ap *Printer) VisitUnaryExpr(expr *parser.UnaryExpr) any { + return ap.parenthesize(expr.Operator.Lexeme, expr.Right) +} + +func (ap *Printer) VisitErrorExpr(expr *parser.ErrorExpr) any { + return expr.Value +} + +func (ap *Printer) parenthesize(name string, exprs ...parser.Expr) string { + str := "(" + name + + for _, expr := range exprs { + str += " " + expr.Accept(ap).(string) + } + + return str + ")" +} diff --git a/build.sh b/build.sh deleted file mode 100755 index cc61283..0000000 --- a/build.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/zsh -go build -o bin/golox diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go new file mode 100644 index 0000000..d7ee37d --- /dev/null +++ b/cmd/astgen/main.go @@ -0,0 +1,81 @@ +package main + +import ( + "fmt" + "os" + "strings" +) + +func main() { + if len(os.Args) != 2 { + fmt.Println("Usage: astgen ") + os.Exit(64) + } + + d := os.Args[1] + + if _, err := os.Stat(d); os.IsNotExist(err) { + fmt.Println("Directory does not exist") + 
os.Exit(74) + } + + fmt.Println("Generating AST classes in", d) + + defineAst(d, "Expr", []string{ + "Error : Value string", + "Binary : Left Expr, Operator token.Token, Right Expr", + "Grouping : Expression Expr", + "Literal : Value any", + "Unary : Operator token.Token, Right Expr", + }) +} + +func defineAst(outputDir, baseName string, types []string) { + path := outputDir + "/" + strings.ToLower(baseName) + ".go" + + file, err := os.Create(path) + if err != nil { + fmt.Println("Error creating file", path) + os.Exit(74) + } + defer file.Close() + + file.WriteString("package parser\n\n") + file.WriteString("import \"golox/token\"\n\n") + defineVisitor(file, baseName, types) + + file.WriteString("type " + baseName + " interface {\n") + file.WriteString(" Accept(visitor " + baseName + "Visitor[any]) any\n") + file.WriteString("}\n\n") + + for _, t := range types { + defineType(file, baseName, t) + } +} + +func defineVisitor(file *os.File, baseName string, types []string) { + file.WriteString("type " + baseName + "Visitor[T any] interface {\n") + + for _, t := range types { + typeName := strings.TrimSpace(t[:strings.Index(t, ":")-1]) + file.WriteString(" Visit" + typeName + baseName + "(" + strings.ToLower(typeName) + " *" + typeName + baseName + ") T\n") + } + + file.WriteString("}\n\n") +} + +func defineType(file *os.File, baseName, typeString string) { + typeName := strings.TrimSpace(typeString[:strings.Index(typeString, ":")-1]) + fields := strings.TrimSpace(typeString[strings.Index(typeString, ":")+1:]) + + file.WriteString("type " + typeName + baseName + " struct {\n") + + for _, field := range strings.Split(fields, ", ") { + file.WriteString(" " + field + "\n") + } + + file.WriteString("}\n\n") + file.WriteString("func (t *" + typeName + baseName + ") Accept(visitor " + baseName + "Visitor[any]) any {\n") + file.WriteString(" return visitor.Visit" + typeName + baseName + "(t)\n") + file.WriteString("}\n\n") +} diff --git a/main.go b/cmd/golox/main.go 
similarity index 78% rename from main.go rename to cmd/golox/main.go index b539208..71c08b1 100644 --- a/main.go +++ b/cmd/golox/main.go @@ -9,13 +9,14 @@ import ( func main() { nbArgs := len(os.Args) + l := lox.New() if nbArgs > 2 { fmt.Println("Usage: golox [script]") os.Exit(64) } else if nbArgs == 2 { - lox.RunFile(os.Args[1]) + l.RunFile(os.Args[1]) } else { - lox.RunPrompt() + l.RunPrompt() } } diff --git a/errors/logger.go b/errors/logger.go new file mode 100644 index 0000000..5f1b84a --- /dev/null +++ b/errors/logger.go @@ -0,0 +1,8 @@ +package errors + +import "golox/token" + +type Logger interface { + Error(line int, message string) + ErrorAtToken(t token.Token, message string) +} diff --git a/lox/lox.go b/lox/lox.go index 34ade77..99b632f 100644 --- a/lox/lox.go +++ b/lox/lox.go @@ -3,26 +3,36 @@ package lox import ( "bufio" "fmt" + "golox/ast" + "golox/parser" + "golox/scanner" + "golox/token" "os" ) -var hadError = false +type Lox struct { + hadError bool +} + +func New() *Lox { + return &Lox{hadError: false} +} -func RunFile(path string) { +func (l *Lox) RunFile(path string) { bytes, err := os.ReadFile(path) if err != nil { fmt.Println("Error reading file", path) os.Exit(74) } - run(string(bytes)) + l.run(string(bytes)) - if hadError { + if l.hadError { os.Exit(65) } } -func RunPrompt() { +func (l *Lox) RunPrompt() { reader := bufio.NewReader(os.Stdin) for { @@ -38,19 +48,38 @@ func RunPrompt() { break } - run(line) - hadError = false + l.run(line) + l.hadError = false } } -func Error(line int, message string) { - report(line, "", message) +func (l *Lox) Error(line int, message string) { + l.report(line, "", message) } -func report(line int, where string, message string) { +func (l *Lox) ErrorAtToken(t token.Token, message string) { + if t.Type == token.EOF { + l.report(t.Line, " at end", message) + } else { + l.report(t.Line, " at '"+t.Lexeme+"'", message) + } +} + +func (l *Lox) report(line int, where string, message string) { fmt.Printf("[line %d] 
Error %s: %s\n", line, where, message) + l.hadError = true } -func run(source string) { - fmt.Println(source) +func (l *Lox) run(source string) { + scanner := scanner.New(source, l) + tokens := scanner.ScanTokens() + parser := parser.New(tokens, l) + expr := parser.Parse() + + if l.hadError { + return + } + + p := ast.New() + fmt.Println(p.Print(expr)) } diff --git a/lox/lox_test.go b/lox/lox_test.go index f5e4077..c6ff114 100644 --- a/lox/lox_test.go +++ b/lox/lox_test.go @@ -25,7 +25,8 @@ func TestRun(t *testing.T) { }() source := "print('Hello, World!');" - run(source) + l := New() + l.run(source) // back to normal state w.Close() @@ -69,7 +70,8 @@ func TestRunFile(t *testing.T) { if err := tmpfile.Close(); err != nil { t.Fatal(err) } - RunFile(tmpfile.Name()) + l := New() + l.RunFile(tmpfile.Name()) // back to normal state w.Close() @@ -101,7 +103,8 @@ func TestError(t *testing.T) { line := 1 message := "Unexpected character." - Error(line, message) + l := New() + l.Error(line, message) // back to normal state w.Close() @@ -134,7 +137,8 @@ func TestReport(t *testing.T) { line := 1 where := "at 'foo'" message := "Unexpected character." 
- report(line, where, message) + l := New() + l.report(line, where, message) // back to normal state w.Close() @@ -166,7 +170,8 @@ func TestRunPrompt(t *testing.T) { wIn.Write([]byte(input)) wIn.Close() - RunPrompt() + l := New() + l.RunPrompt() wOut.Close() os.Stdin = oldStdin diff --git a/parser/expr.go b/parser/expr.go new file mode 100644 index 0000000..c6ba2f0 --- /dev/null +++ b/parser/expr.go @@ -0,0 +1,59 @@ +package parser + +import "golox/token" + +type ExprVisitor[T any] interface { + VisitErrorExpr(error *ErrorExpr) T + VisitBinaryExpr(binary *BinaryExpr) T + VisitGroupingExpr(grouping *GroupingExpr) T + VisitLiteralExpr(literal *LiteralExpr) T + VisitUnaryExpr(unary *UnaryExpr) T +} + +type Expr interface { + Accept(visitor ExprVisitor[any]) any +} + +type ErrorExpr struct { + Value string +} + +func (t *ErrorExpr) Accept(visitor ExprVisitor[any]) any { + return visitor.VisitErrorExpr(t) +} + +type BinaryExpr struct { + Left Expr + Operator token.Token + Right Expr +} + +func (t *BinaryExpr) Accept(visitor ExprVisitor[any]) any { + return visitor.VisitBinaryExpr(t) +} + +type GroupingExpr struct { + Expression Expr +} + +func (t *GroupingExpr) Accept(visitor ExprVisitor[any]) any { + return visitor.VisitGroupingExpr(t) +} + +type LiteralExpr struct { + Value any +} + +func (t *LiteralExpr) Accept(visitor ExprVisitor[any]) any { + return visitor.VisitLiteralExpr(t) +} + +type UnaryExpr struct { + Operator token.Token + Right Expr +} + +func (t *UnaryExpr) Accept(visitor ExprVisitor[any]) any { + return visitor.VisitUnaryExpr(t) +} + diff --git a/parser/parser.go b/parser/parser.go new file mode 100644 index 0000000..b85fb23 --- /dev/null +++ b/parser/parser.go @@ -0,0 +1,207 @@ +/* Description: This file contains the recursive descent parser implementation. + * The parser is responsible for parsing the tokens generated by the scanner. 
+ * + * The grammar is as follows: + * expression → equality ; + * equality → comparison ( ( "!=" | "==" ) comparison )* ; + * comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ; + * term → factor ( ( "-" | "+" ) factor )* ; + * factor → unary ( ( "/" | "*" ) unary )* ; + * unary → ( "!" | "-" ) unary + * | primary ; + * primary → NUMBER | STRING | "true" | "false" | "nil" + * | "(" expression ")" ; + */ +package parser + +import ( + "golox/errors" + "golox/token" +) + +// Parser is a recursive descent parser. +type Parser struct { + tokens []token.Token + current int + errLogger errors.Logger +} + +// New creates a new Parser. +func New(tokens []token.Token, el errors.Logger) *Parser { + return &Parser{tokens: tokens, current: 0, errLogger: el} +} + +// Parse parses the tokens and returns the AST. +func (p *Parser) Parse() Expr { + return p.expression() +} + +// expression → equality ; +func (p *Parser) expression() Expr { + return p.equality() +} + +// equality → comparison ( ( "!=" | "==" ) comparison )* ; +func (p *Parser) equality() Expr { + expr := p.comparison() + + for p.match(token.BANG_EQUAL, token.EQUAL_EQUAL) { + operator := p.previous() + right := p.comparison() + expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + } + + return expr +} + +// comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ; +func (p *Parser) comparison() Expr { + expr := p.term() + + for p.match(token.GREATER, token.GREATER_EQUAL, token.LESS, token.LESS_EQUAL) { + operator := p.previous() + right := p.term() + expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + } + + return expr +} + +// term → factor ( ( "-" | "+" ) factor )* ; +func (p *Parser) term() Expr { + expr := p.factor() + + for p.match(token.MINUS, token.PLUS) { + operator := p.previous() + right := p.factor() + expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + } + + return expr +} + +// factor → unary ( ( "/" | "*" ) unary )* ; +func (p *Parser) factor() Expr { + 
expr := p.unary() + + for p.match(token.SLASH, token.STAR) { + operator := p.previous() + right := p.unary() + expr = &BinaryExpr{Left: expr, Operator: operator, Right: right} + } + + return expr +} + +// unary → ( "!" | "-" ) unary | primary ; +func (p *Parser) unary() Expr { + if p.match(token.BANG, token.MINUS) { + operator := p.previous() + right := p.unary() + return &UnaryExpr{Operator: operator, Right: right} + } + + return p.primary() +} + +// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ; +func (p *Parser) primary() Expr { + switch { + case p.match(token.FALSE): + return &LiteralExpr{Value: false} + case p.match(token.TRUE): + return &LiteralExpr{Value: true} + case p.match(token.NIL): + return &LiteralExpr{Value: nil} + case p.match(token.NUMBER, token.STRING): + return &LiteralExpr{Value: p.previous().Literal} + case p.match(token.LEFT_PAREN): + expr := p.expression() + err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.") + if err != nil { + return err + } + return &GroupingExpr{Expression: expr} + } + + return p.newErrorExpr(p.peek(), "Expect expression.") +} + +// match checks if the current token is any of the given types. +func (p *Parser) match(types ...token.TokenType) bool { + for _, t := range types { + if p.check(t) { + p.advance() + return true + } + } + + return false +} + +// consume consumes the current token if it is of the given type. +func (p *Parser) consume(tt token.TokenType, message string) *ErrorExpr { + if p.check(tt) { + p.advance() + return nil + } + + return p.newErrorExpr(p.peek(), message) +} + +// check checks if the current token is of the given type. +func (p *Parser) check(tt token.TokenType) bool { + if p.isAtEnd() { + return false + } + + return p.peek().Type == tt +} + +// advance advances the current token and returns the previous token. 
+func (p *Parser) advance() token.Token { + if !p.isAtEnd() { + p.current++ + } + + return p.previous() +} + +// isAtEnd checks if the parser has reached the end of the tokens. +func (p *Parser) isAtEnd() bool { + return p.peek().Type == token.EOF +} + +// peek returns the current token. +func (p *Parser) peek() token.Token { + return p.tokens[p.current] +} + +// previous returns the previous token. +func (p *Parser) previous() token.Token { + return p.tokens[p.current-1] +} + +// newErrorExpr creates a new ErrorExpr and reports the error. +func (p *Parser) newErrorExpr(t token.Token, message string) *ErrorExpr { + p.errLogger.ErrorAtToken(t, message) + return &ErrorExpr{Value: message} +} + +// synchronize synchronizes the parser after an error. +func (p *Parser) synchronize() { + p.advance() + + for !p.isAtEnd() { + if p.previous().Type == token.SEMICOLON { + return + } + + switch p.peek().Type { + case token.CLASS, token.FUN, token.VAR, token.FOR, token.IF, token.WHILE, token.PRINT, token.RETURN: + return + } + + p.advance() + } +} diff --git a/scanner/scanner.go b/scanner/scanner.go index 120dfbc..103635c 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -1,7 +1,7 @@ package scanner import ( - "golox/lox" + "golox/errors" "golox/token" "strconv" ) @@ -9,21 +9,23 @@ import ( // Scanner is a struct that holds the source code, the start and current position // of the scanner, the current line, and the tokens that have been scanned. type Scanner struct { - source string - start int - current int - line int - tokens []token.Token + source string + start int + current int + line int + tokens []token.Token + errLogger errors.Logger } // New creates a new Scanner struct with the given source code. -func New(source string) *Scanner { +func New(source string, el errors.Logger) *Scanner { return &Scanner{ - source: source, // The source code to scan. - start: 0, // The start position of the scanner. - current: 0, // The current position of the scanner. 
- line: 1, // The current line number. - tokens: []token.Token{}, // The tokens that have been scanned. + source: source, // The source code to scan. + start: 0, // The start position of the scanner. + current: 0, // The current position of the scanner. + line: 1, // The current line number. + tokens: []token.Token{}, // The tokens that have been scanned. + errLogger: el, // The error logger. } } @@ -113,7 +115,7 @@ func (s *Scanner) scanToken() { if isAlpha(c) { s.identifier() } else { - lox.Error(s.line, "Unexpected character.") + s.errLogger.Error(s.line, "Unexpected character.") } } } @@ -151,7 +153,7 @@ func (s *Scanner) number() { f, err := strconv.ParseFloat(s.source[s.start:s.current], 64) if err != nil { - lox.Error(s.line, "Could not parse number.") + s.errLogger.Error(s.line, "Could not parse number.") return } @@ -168,7 +170,7 @@ func (s *Scanner) string() { } if s.isAtEnd() { - lox.Error(s.line, "Unterminated string.") + s.errLogger.Error(s.line, "Unterminated string.") return } diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index 28c3439..004baae 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -5,6 +5,18 @@ import ( "testing" ) +type errorLogger struct{} + +func (el *errorLogger) Error(line int, message string) { +} + +func (el *errorLogger) ErrorAtToken(t token.Token, message string) { +} + +func newErrorLogger() *errorLogger { + return &errorLogger{} +} + func TestScanTokens(t *testing.T) { tests := []struct { name string @@ -76,7 +88,7 @@ func TestScanTokens(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) tokens := scanner.ScanTokens() if len(tokens) != len(tt.tokens)+1 { // +1 for EOF token @@ -108,7 +120,7 @@ func TestIsAtEnd(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) if got := scanner.isAtEnd(); 
got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -129,7 +141,7 @@ func TestMatch(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) if got := scanner.match(tt.char); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -149,7 +161,7 @@ func TestPeek(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) if got := scanner.peek(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -169,7 +181,7 @@ func TestPeekNext(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) if got := scanner.peekNext(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -188,7 +200,7 @@ func TestAdvance(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) if got := scanner.advance(); got != tt.expected { t.Errorf("expected %v, got %v", tt.expected, got) } @@ -246,7 +258,7 @@ func TestString(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) scanner.advance() // Move to the first character of the string scanner.string() if tt.expected == "" { @@ -276,7 +288,7 @@ func TestNumber(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) scanner.number() if tt.expected == 0 { if len(scanner.tokens) != 0 { @@ -328,7 +340,7 @@ func TestScanToken(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) 
scanner.scanToken() if len(scanner.tokens) > 0 { if scanner.tokens[0].Type != tt.expected { @@ -368,7 +380,7 @@ func TestIdentifier(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - scanner := New(tt.source) + scanner := New(tt.source, newErrorLogger()) scanner.identifier() if len(scanner.tokens) != 1 { t.Fatalf("expected 1 token, got %d", len(scanner.tokens))