Add expression parser

main
oabrivard 1 year ago
parent d8050c1699
commit 2aae58b670

@ -0,0 +1 @@
5f28a0838ca9382a947da5bd52756cf3

@ -0,0 +1,44 @@
# https://taskfile.dev
version: '3'
vars:
APP_NAME: golox
tasks:
default:
cmds:
- task: run
run:
desc: "Run the interpreter"
deps: [build]
cmds:
- ./bin/{{.APP_NAME}} {{.CLI_ARGS}}
build:
desc: "Build the interpreter"
deps: [astgen]
cmds:
- go build -o bin/{{.APP_NAME}} cmd/golox/main.go
test:
desc: "Run tests"
deps: [astgen]
cmds:
- go test -v ./...
astgen:
desc: "Generate AST nodes"
cmds:
- go run cmd/astgen/main.go "./parser"
sources:
- cmd/astgen/main.go
generates:
- parser/expr.go
clean:
desc: "Clean up"
cmds:
- rm -rf bin/*

@ -0,0 +1,50 @@
package ast
import (
"fmt"
"golox/parser"
)
type Printer struct {
}
func New() *Printer {
return &Printer{}
}
func (ap *Printer) Print(expr parser.Expr) string {
return expr.Accept(ap).(string)
}
func (ap *Printer) VisitBinaryExpr(expr *parser.BinaryExpr) any {
return ap.parenthesize(expr.Operator.Lexeme, expr.Left, expr.Right)
}
func (ap *Printer) VisitGroupingExpr(expr *parser.GroupingExpr) any {
return ap.parenthesize("group", expr.Expression)
}
func (ap *Printer) VisitLiteralExpr(expr *parser.LiteralExpr) any {
if expr.Value == nil {
return "nil"
}
return fmt.Sprint(expr.Value)
}
func (ap *Printer) VisitUnaryExpr(expr *parser.UnaryExpr) any {
return ap.parenthesize(expr.Operator.Lexeme, expr.Right)
}
func (ap *Printer) VisitErrorExpr(expr *parser.ErrorExpr) any {
return expr.Value
}
func (ap *Printer) parenthesize(name string, exprs ...parser.Expr) string {
str := "(" + name
for _, expr := range exprs {
str += " " + expr.Accept(ap).(string)
}
return str + ")"
}

@ -1,2 +0,0 @@
#!/bin/zsh
go build -o bin/golox

@ -0,0 +1,81 @@
package main
import (
"fmt"
"os"
"strings"
)
func main() {
if len(os.Args) != 2 {
fmt.Println("Usage: astgen <output_directory>")
os.Exit(64)
}
d := os.Args[1]
if _, err := os.Stat(d); os.IsNotExist(err) {
fmt.Println("Directory does not exist")
os.Exit(74)
}
fmt.Println("Generating AST classes in", d)
defineAst(d, "Expr", []string{
"Error : Value string",
"Binary : Left Expr, Operator token.Token, Right Expr",
"Grouping : Expression Expr",
"Literal : Value any",
"Unary : Operator token.Token, Right Expr",
})
}
func defineAst(outputDir, baseName string, types []string) {
path := outputDir + "/" + strings.ToLower(baseName) + ".go"
file, err := os.Create(path)
if err != nil {
fmt.Println("Error creating file", path)
os.Exit(74)
}
defer file.Close()
file.WriteString("package parser\n\n")
file.WriteString("import \"golox/token\"\n\n")
defineVisitor(file, baseName, types)
file.WriteString("type " + baseName + " interface {\n")
file.WriteString(" Accept(visitor " + baseName + "Visitor[any]) any\n")
file.WriteString("}\n\n")
for _, t := range types {
defineType(file, baseName, t)
}
}
func defineVisitor(file *os.File, baseName string, types []string) {
file.WriteString("type " + baseName + "Visitor[T any] interface {\n")
for _, t := range types {
typeName := strings.TrimSpace(t[:strings.Index(t, ":")-1])
file.WriteString(" Visit" + typeName + baseName + "(" + strings.ToLower(typeName) + " *" + typeName + baseName + ") T\n")
}
file.WriteString("}\n\n")
}
func defineType(file *os.File, baseName, typeString string) {
typeName := strings.TrimSpace(typeString[:strings.Index(typeString, ":")-1])
fields := strings.TrimSpace(typeString[strings.Index(typeString, ":")+1:])
file.WriteString("type " + typeName + baseName + " struct {\n")
for _, field := range strings.Split(fields, ", ") {
file.WriteString(" " + field + "\n")
}
file.WriteString("}\n\n")
file.WriteString("func (t *" + typeName + baseName + ") Accept(visitor " + baseName + "Visitor[any]) any {\n")
file.WriteString(" return visitor.Visit" + typeName + baseName + "(t)\n")
file.WriteString("}\n\n")
}

@ -9,13 +9,14 @@ import (
func main() {
nbArgs := len(os.Args)
l := lox.New()
if nbArgs > 2 {
fmt.Println("Usage: golox [script]")
os.Exit(64)
} else if nbArgs == 2 {
lox.RunFile(os.Args[1])
l.RunFile(os.Args[1])
} else {
lox.RunPrompt()
l.RunPrompt()
}
}

@ -0,0 +1,8 @@
package errors
import "golox/token"
type Logger interface {
Error(line int, message string)
ErrorAtToken(t token.Token, message string)
}

@ -3,26 +3,36 @@ package lox
import (
"bufio"
"fmt"
"golox/ast"
"golox/parser"
"golox/scanner"
"golox/token"
"os"
)
var hadError = false
type Lox struct {
hadError bool
}
func New() *Lox {
return &Lox{hadError: false}
}
func RunFile(path string) {
func (l *Lox) RunFile(path string) {
bytes, err := os.ReadFile(path)
if err != nil {
fmt.Println("Error reading file", path)
os.Exit(74)
}
run(string(bytes))
l.run(string(bytes))
if hadError {
if l.hadError {
os.Exit(65)
}
}
func RunPrompt() {
func (l *Lox) RunPrompt() {
reader := bufio.NewReader(os.Stdin)
for {
@ -38,19 +48,38 @@ func RunPrompt() {
break
}
run(line)
hadError = false
l.run(line)
l.hadError = false
}
}
func Error(line int, message string) {
report(line, "", message)
func (l *Lox) Error(line int, message string) {
l.report(line, "", message)
}
func report(line int, where string, message string) {
func (l *Lox) ErrorAtToken(t token.Token, message string) {
if t.Type == token.EOF {
l.report(t.Line, " at end", message)
} else {
l.report(t.Line, " at '"+t.Lexeme+"'", message)
}
}
func (l *Lox) report(line int, where string, message string) {
fmt.Printf("[line %d] Error %s: %s\n", line, where, message)
l.hadError = true
}
func run(source string) {
fmt.Println(source)
func (l *Lox) run(source string) {
scanner := scanner.New(source, l)
tokens := scanner.ScanTokens()
parser := parser.New(tokens, l)
expr := parser.Parse()
if l.hadError {
return
}
p := ast.New()
fmt.Println(p.Print(expr))
}

@ -25,7 +25,8 @@ func TestRun(t *testing.T) {
}()
source := "print('Hello, World!');"
run(source)
l := New()
l.run(source)
// back to normal state
w.Close()
@ -69,7 +70,8 @@ func TestRunFile(t *testing.T) {
if err := tmpfile.Close(); err != nil {
t.Fatal(err)
}
RunFile(tmpfile.Name())
l := New()
l.RunFile(tmpfile.Name())
// back to normal state
w.Close()
@ -101,7 +103,8 @@ func TestError(t *testing.T) {
line := 1
message := "Unexpected character."
Error(line, message)
l := New()
l.Error(line, message)
// back to normal state
w.Close()
@ -134,7 +137,8 @@ func TestReport(t *testing.T) {
line := 1
where := "at 'foo'"
message := "Unexpected character."
report(line, where, message)
l := New()
l.report(line, where, message)
// back to normal state
w.Close()
@ -166,7 +170,8 @@ func TestRunPrompt(t *testing.T) {
wIn.Write([]byte(input))
wIn.Close()
RunPrompt()
l := New()
l.RunPrompt()
wOut.Close()
os.Stdin = oldStdin

@ -0,0 +1,59 @@
package parser
import "golox/token"
type ExprVisitor[T any] interface {
VisitErrorExpr(error *ErrorExpr) T
VisitBinaryExpr(binary *BinaryExpr) T
VisitGroupingExpr(grouping *GroupingExpr) T
VisitLiteralExpr(literal *LiteralExpr) T
VisitUnaryExpr(unary *UnaryExpr) T
}
type Expr interface {
Accept(visitor ExprVisitor[any]) any
}
type ErrorExpr struct {
Value string
}
func (t *ErrorExpr) Accept(visitor ExprVisitor[any]) any {
return visitor.VisitErrorExpr(t)
}
type BinaryExpr struct {
Left Expr
Operator token.Token
Right Expr
}
func (t *BinaryExpr) Accept(visitor ExprVisitor[any]) any {
return visitor.VisitBinaryExpr(t)
}
type GroupingExpr struct {
Expression Expr
}
func (t *GroupingExpr) Accept(visitor ExprVisitor[any]) any {
return visitor.VisitGroupingExpr(t)
}
type LiteralExpr struct {
Value any
}
func (t *LiteralExpr) Accept(visitor ExprVisitor[any]) any {
return visitor.VisitLiteralExpr(t)
}
type UnaryExpr struct {
Operator token.Token
Right Expr
}
func (t *UnaryExpr) Accept(visitor ExprVisitor[any]) any {
return visitor.VisitUnaryExpr(t)
}

@ -0,0 +1,207 @@
/* Description: This file contains the recursivde descent parser implementation.
* The parser is responsible for parsing the tokens generated by the scanner.
*
* The grammar is as follows:
* expression equality ;
* equality comparison ( ( "!=" | "==" ) comparison )* ;
* comparison term ( ( ">" | ">=" | "<" | "<=" ) term )* ;
* term factor ( ( "-" | "+" ) factor )* ;
* factor unary ( ( "/" | "*" ) unary )* ;
* unary ( "!" | "-" ) unary
* | primary ;
* primary NUMBER | STRING | "true" | "false" | "nil"
* | "(" expression ")" ;
*/
package parser
import (
"golox/errors"
"golox/token"
)
// Parser is a recursive descent parser.
type Parser struct {
tokens []token.Token
current int
errLogger errors.Logger
}
// New creates a new Parser.
func New(tokens []token.Token, el errors.Logger) *Parser {
return &Parser{tokens: tokens, current: 0, errLogger: el}
}
// Parse parses the tokens and returns the AST.
func (p *Parser) Parse() Expr {
return p.expression()
}
// expression → equality ;
func (p *Parser) expression() Expr {
return p.equality()
}
// equality → comparison ( ( "!=" | "==" ) comparison )* ;
func (p *Parser) equality() Expr {
expr := p.comparison()
for p.match(token.BANG_EQUAL, token.EQUAL_EQUAL) {
operator := p.previous()
right := p.comparison()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right}
}
return expr
}
// comparison → term ( ( ">" | ">=" | "<" | "<=" ) term )* ;
func (p *Parser) comparison() Expr {
expr := p.term()
for p.match(token.GREATER, token.GREATER_EQUAL, token.LESS, token.LESS_EQUAL) {
operator := p.previous()
right := p.term()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right}
}
return expr
}
// term → factor ( ( "-" | "+" ) factor )* ;
func (p *Parser) term() Expr {
expr := p.factor()
for p.match(token.MINUS, token.PLUS) {
operator := p.previous()
right := p.factor()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right}
}
return expr
}
// factor → unary ( ( "/" | "*" ) unary )* ;
func (p *Parser) factor() Expr {
expr := p.unary()
for p.match(token.SLASH, token.STAR) {
operator := p.previous()
right := p.unary()
expr = &BinaryExpr{Left: expr, Operator: operator, Right: right}
}
return expr
}
// unary → ( "!" | "-" ) unary | primary ;
func (p *Parser) unary() Expr {
if p.match(token.BANG, token.MINUS) {
operator := p.previous()
right := p.unary()
return &UnaryExpr{Operator: operator, Right: right}
}
return p.primary()
}
// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" ;
func (p *Parser) primary() Expr {
switch {
case p.match(token.FALSE):
return &LiteralExpr{Value: false}
case p.match(token.TRUE):
return &LiteralExpr{Value: true}
case p.match(token.NIL):
return &LiteralExpr{Value: nil}
case p.match(token.NUMBER, token.STRING):
return &LiteralExpr{Value: p.previous().Literal}
case p.match(token.LEFT_PAREN):
expr := p.expression()
err := p.consume(token.RIGHT_PAREN, "Expect ')' after expression.")
if err != nil {
return err
}
return &GroupingExpr{Expression: expr}
}
return p.newErrorExpr(p.peek(), "Expect expression.")
}
// match checks if the current token is any of the given types.
func (p *Parser) match(types ...token.TokenType) bool {
for _, t := range types {
if p.check(t) {
p.advance()
return true
}
}
return false
}
// consume consumes the current token if it is of the given type.
func (p *Parser) consume(tt token.TokenType, message string) *ErrorExpr {
if p.check(tt) {
p.advance()
return nil
}
return p.newErrorExpr(p.peek(), message)
}
// check checks if the current token is of the given type.
func (p *Parser) check(tt token.TokenType) bool {
if p.isAtEnd() {
return false
}
return p.peek().Type == tt
}
// advance advances the current token and returns the previous token.
func (p *Parser) advance() token.Token {
if !p.isAtEnd() {
p.current++
}
return p.previous()
}
// isAtEnd checks if the parser has reached the end of the tokens.
func (p *Parser) isAtEnd() bool {
return p.peek().Type == token.EOF
}
// peek returns the current token.
func (p *Parser) peek() token.Token {
return p.tokens[p.current]
}
// previous returns the previous token.
func (p *Parser) previous() token.Token {
return p.tokens[p.current-1]
}
// newErrorExpr creates a new ErrorExpr and reports the error.
func (p *Parser) newErrorExpr(t token.Token, message string) *ErrorExpr {
p.errLogger.ErrorAtToken(t, message)
return &ErrorExpr{Value: message}
}
// synchronize synchronizes the parser after an error.
func (p *Parser) synchronize() {
p.advance()
for !p.isAtEnd() {
if p.previous().Type == token.SEMICOLON {
return
}
switch p.peek().Type {
case token.CLASS, token.FUN, token.VAR, token.FOR, token.IF, token.WHILE, token.PRINT, token.RETURN:
return
}
p.advance()
}
}

@ -1,7 +1,7 @@
package scanner
import (
"golox/lox"
"golox/errors"
"golox/token"
"strconv"
)
@ -14,16 +14,18 @@ type Scanner struct {
current int
line int
tokens []token.Token
errLogger errors.Logger
}
// New creates a new Scanner struct with the given source code.
func New(source string) *Scanner {
func New(source string, el errors.Logger) *Scanner {
return &Scanner{
source: source, // The source code to scan.
start: 0, // The start position of the scanner.
current: 0, // The current position of the scanner.
line: 1, // The current line number.
tokens: []token.Token{}, // The tokens that have been scanned.
errLogger: el, // The error logger.
}
}
@ -113,7 +115,7 @@ func (s *Scanner) scanToken() {
if isAlpha(c) {
s.identifier()
} else {
lox.Error(s.line, "Unexpected character.")
s.errLogger.Error(s.line, "Unexpected character.")
}
}
}
@ -151,7 +153,7 @@ func (s *Scanner) number() {
f, err := strconv.ParseFloat(s.source[s.start:s.current], 64)
if err != nil {
lox.Error(s.line, "Could not parse number.")
s.errLogger.Error(s.line, "Could not parse number.")
return
}
@ -168,7 +170,7 @@ func (s *Scanner) string() {
}
if s.isAtEnd() {
lox.Error(s.line, "Unterminated string.")
s.errLogger.Error(s.line, "Unterminated string.")
return
}

@ -5,6 +5,18 @@ import (
"testing"
)
type errorLogger struct{}
func (el *errorLogger) Error(line int, message string) {
}
func (el *errorLogger) ErrorAtToken(t token.Token, message string) {
}
func newErrorLogger() *errorLogger {
return &errorLogger{}
}
func TestScanTokens(t *testing.T) {
tests := []struct {
name string
@ -76,7 +88,7 @@ func TestScanTokens(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
tokens := scanner.ScanTokens()
if len(tokens) != len(tt.tokens)+1 { // +1 for EOF token
@ -108,7 +120,7 @@ func TestIsAtEnd(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
if got := scanner.isAtEnd(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got)
}
@ -129,7 +141,7 @@ func TestMatch(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
if got := scanner.match(tt.char); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got)
}
@ -149,7 +161,7 @@ func TestPeek(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
if got := scanner.peek(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got)
}
@ -169,7 +181,7 @@ func TestPeekNext(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
if got := scanner.peekNext(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got)
}
@ -188,7 +200,7 @@ func TestAdvance(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
if got := scanner.advance(); got != tt.expected {
t.Errorf("expected %v, got %v", tt.expected, got)
}
@ -246,7 +258,7 @@ func TestString(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
scanner.advance() // Move to the first character of the string
scanner.string()
if tt.expected == "" {
@ -276,7 +288,7 @@ func TestNumber(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
scanner.number()
if tt.expected == 0 {
if len(scanner.tokens) != 0 {
@ -328,7 +340,7 @@ func TestScanToken(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
scanner.scanToken()
if len(scanner.tokens) > 0 {
if scanner.tokens[0].Type != tt.expected {
@ -368,7 +380,7 @@ func TestIdentifier(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
scanner := New(tt.source)
scanner := New(tt.source, newErrorLogger())
scanner.identifier()
if len(scanner.tokens) != 1 {
t.Fatalf("expected 1 token, got %d", len(scanner.tokens))

Loading…
Cancel
Save