diff --git a/.gitignore b/.gitignore index 78a0187..a98fadf 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,7 @@ bin/ # Go workspace file go.work +# Go tasks +.task/ + # <--- Go diff --git a/.task/checksum/astgen b/.task/checksum/astgen index e2d05de..69cf4bb 100644 --- a/.task/checksum/astgen +++ b/.task/checksum/astgen @@ -1 +1 @@ -ef159c739c079a54bc5f6af10e2ea025 +8ad203c83fd99fe71ecaa9da9795cc32 diff --git a/ast/expr.go b/ast/expr.go index 3a332d4..4400ebf 100644 --- a/ast/expr.go +++ b/ast/expr.go @@ -6,6 +6,7 @@ type ExprVisitor[T any] interface { VisitErrorExpr(ee *ErrorExpr) T VisitAssignExpr(ae *AssignExpr) T VisitBinaryExpr(be *BinaryExpr) T + VisitCallExpr(ce *CallExpr) T VisitGroupingExpr(ge *GroupingExpr) T VisitLiteralExpr(le *LiteralExpr) T VisitLogicalExpr(le *LogicalExpr) T @@ -44,6 +45,16 @@ func (be *BinaryExpr) Accept(v ExprVisitor[any]) any { return v.VisitBinaryExpr(be) } +type CallExpr struct { + Callee Expr + Paren token.Token + Arguments []Expr +} + +func (ce *CallExpr) Accept(v ExprVisitor[any]) any { + return v.VisitCallExpr(ce) +} + type GroupingExpr struct { Expression Expr } diff --git a/ast/printer.go b/ast/printer.go index 6d14322..2d89766 100644 --- a/ast/printer.go +++ b/ast/printer.go @@ -33,6 +33,25 @@ func (ap *Printer) VisitExpressionStmt(stmt *ExpressionStmt) any { return stmt.Expression.Accept(ap) } +func (ap *Printer) VisitFunctionStmt(stmt *FunctionStmt) any { + str := "(fun " + stmt.Name.Lexeme + "(" + + for i, param := range stmt.Params { + if i > 0 { + str += ", " + } + str += param.Lexeme + + } + + str += ")" + if stmt.Body != nil { + str += " " + ap.bracesizeStmt(stmt.Body...) + } + + return str + ")" +} + func (ap *Printer) VisitPrintStmt(stmt *PrintStmt) any { return ap.parenthesizeExpr("print", stmt.Expression) } @@ -85,6 +104,19 @@ func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any { return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Left, expr.Right) } +func (ap *Printer) VisitCallExpr(expr *CallExpr) any { + str := expr.Callee.Accept(ap).(string) + "(" + + for i, arg := range expr.Arguments { + if i > 0 { + str += ", " + } + str += arg.Accept(ap).(string) + } + + return str + ")" +} + func (ap *Printer) VisitGroupingExpr(expr *GroupingExpr) any { return ap.parenthesizeExpr("group", expr.Expression) } diff --git a/ast/printer_test.go b/ast/printer_test.go index a75b02c..770c9d0 100644 --- a/ast/printer_test.go +++ b/ast/printer_test.go @@ -63,6 +63,22 @@ func TestPrintExpr(t *testing.T) { }, expected: "(= foo 42)", }, + { + name: "Variable expression", + expr: &VariableExpr{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + }, + expected: "foo", + }, + { + name: "Call expression", + expr: &CallExpr{ + Callee: &VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}}, + Paren: token.Token{Type: token.LEFT_PAREN, Lexeme: "("}, + Arguments: []Expr{&LiteralExpr{Value: 1}, &LiteralExpr{Value: 2}}, + }, + expected: "foo(1, 2)", + }, } for _, tt := range tests { @@ -119,6 +135,64 @@ func TestPrintStmts(t *testing.T) { }, expected: "(var foo 42)\n", }, + { + name: "Function statement", + stmts: []Stmt{ + &FunctionStmt{ + Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}, + Params: []token.Token{{Type: token.IDENTIFIER, Lexeme: "bar"}}, + Body: []Stmt{ + &ExpressionStmt{ + Expression: &LiteralExpr{Value: 42}, + }, + }, + }, + }, + expected: "(fun foo(bar) {\n 42\n})\n", + }, + { + name: "Block statement", + stmts: []Stmt{ + &BlockStmt{ + Statements: []Stmt{ + &ExpressionStmt{ + Expression: &LiteralExpr{Value: 1}, + }, + &ExpressionStmt{ + Expression: &LiteralExpr{Value: 2}, + }, + }, + }, + }, + expected: "{\n 1\n 2\n}\n", + }, + { + name: "If statement", + stmts: []Stmt{ + &IfStmt{ + Condition: &LiteralExpr{Value: true}, + ThenBranch: &ExpressionStmt{ + Expression: &LiteralExpr{Value: 1}, + }, + ElseBranch: &ExpressionStmt{ + Expression: &LiteralExpr{Value: 2}, + }, + }, + }, + expected: "(if true {\n 1\n} else {\n 2\n})\n", + }, + { + name: "While statement", + stmts: []Stmt{ + &WhileStmt{ + Condition: &LiteralExpr{Value: true}, + Body: &ExpressionStmt{ + Expression: &LiteralExpr{Value: 1}, + }, + }, + }, + expected: "(while true {\n 1\n})\n", + }, } for _, tt := range tests { @@ -131,58 +205,3 @@ func TestPrintStmts(t *testing.T) { }) } } - -func TestPrintBlockStmt(t *testing.T) { - stmt := &BlockStmt{ - Statements: []Stmt{ - &ExpressionStmt{ - Expression: &LiteralExpr{Value: 1}, - }, - &ExpressionStmt{ - Expression: &LiteralExpr{Value: 2}, - }, - }, - } - - printer := NewPrinter() - result := printer.PrintStmts([]Stmt{stmt}) - expected := "{\n 1\n 2\n}\n" - if result != expected { - t.Errorf("expected %v, got %v", expected, result) - } -} - -func TestPrintIfStmt(t *testing.T) { - stmt := &IfStmt{ - Condition: &LiteralExpr{Value: true}, - ThenBranch: &ExpressionStmt{ - Expression: &LiteralExpr{Value: 1}, - }, - ElseBranch: &ExpressionStmt{ - Expression: &LiteralExpr{Value: 2}, - }, - } - - printer := NewPrinter() - result := printer.PrintStmts([]Stmt{stmt}) - expected := "(if true {\n 1\n} else {\n 2\n})\n" - if result != expected { - t.Errorf("expected %v, got %v", expected, result) - } -} - -func TestPrintWhileStmt(t *testing.T) { - stmt := &WhileStmt{ - Condition: &LiteralExpr{Value: true}, - Body: &ExpressionStmt{ - Expression: &LiteralExpr{Value: 1}, - }, - } - - printer := NewPrinter() - result := printer.PrintStmts([]Stmt{stmt}) - expected := "(while true {\n 1\n})\n" - if result != expected { - t.Errorf("expected %v, got %v", expected, result) - } -} diff --git a/ast/stmt.go b/ast/stmt.go index e5944f3..2e84bb1 100644 --- a/ast/stmt.go +++ b/ast/stmt.go @@ -6,6 +6,7 @@ type StmtVisitor[T any] interface { VisitErrorStmt(es *ErrorStmt) T VisitBlockStmt(bs *BlockStmt) T VisitExpressionStmt(es *ExpressionStmt) T + VisitFunctionStmt(fs *FunctionStmt) T VisitIfStmt(is *IfStmt) T VisitPrintStmt(ps *PrintStmt) T VisitVarStmt(vs *VarStmt) T @@ -40,6 +41,16 @@ func (es *ExpressionStmt) Accept(v StmtVisitor[any]) any { return v.VisitExpressionStmt(es) } +type FunctionStmt struct { + Name token.Token + Params []token.Token + Body []Stmt +} + +func (fs *FunctionStmt) Accept(v StmtVisitor[any]) any { + return v.VisitFunctionStmt(fs) +} + type IfStmt struct { Condition Expr ThenBranch Stmt diff --git a/cmd/astgen/main.go b/cmd/astgen/main.go index ab529cd..56f4bd9 100644 --- a/cmd/astgen/main.go +++ b/cmd/astgen/main.go @@ -25,6 +25,7 @@ func main() { "Error : Value string", "Assign : Name token.Token, Value Expr", "Binary : Left Expr, Operator token.Token, Right Expr", + "Call : Callee Expr, Paren token.Token, Arguments []Expr", "Grouping : Expression Expr", "Literal : Value any", "Logical : Left Expr, Operator token.Token, Right Expr", @@ -36,6 +37,7 @@ func main() { "Error : Value string", "Block : Statements []Stmt", "Expression : Expression Expr", + "Function : Name token.Token, Params []token.Token, Body []Stmt", "If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt", "Print : Expression Expr", "Var : Name token.Token, Initializer Expr", diff --git a/interpreter/callable.go b/interpreter/callable.go new file mode 100644 index 0000000..52b4217 --- /dev/null +++ b/interpreter/callable.go @@ -0,0 +1,73 @@ +package interpreter + +import "golox/ast" + +// callable is an interface for callables like functions and methods. +type callable interface { + arity() int + call(i *Interpreter, arguments []any) any +} + +// nativeCallable is a struct that implements the callable interface. +type nativeCallable struct { + arityFn func() int + callFn func(interpreter *Interpreter, args []any) any +} + +// newNativeCallable creates a new nativeCallable. +func newNativeCallable(arityFn func() int, callFn func(interpreter *Interpreter, args []any) any) *nativeCallable { + return &nativeCallable{ + arityFn: arityFn, + callFn: callFn, + } +} + +// arity returns the number of arguments the callable expects. +func (n *nativeCallable) arity() int { + return n.arityFn() +} + +// call calls the callable with the given arguments. +func (n *nativeCallable) call(i *Interpreter, arguments []any) any { + return n.callFn(i, arguments) +} + +// String returns a string representation of the callable. +func (n *nativeCallable) String() string { + return "" +} + +// function is a struct that implements the callable interface. +type function struct { + declaration *ast.FunctionStmt +} + +// newFunction creates a new function. +func newFunction(declaration *ast.FunctionStmt) *function { + return &function{ + declaration: declaration, + } +} + +// arity returns the number of arguments the function expects. +func (f *function) arity() int { + return len(f.declaration.Params) +} + +// call calls the function with the given arguments. +func (f *function) call(i *Interpreter, arguments []any) any { + env := newEnvironment(i.globals) + + for i, param := range f.declaration.Params { + env.define(param.Lexeme, arguments[i]) + } + + i.executeBlock(f.declaration.Body, env) + + return nil +} + +// String returns a string representation of the function. +func (f *function) String() string { + return "" +} diff --git a/interpreter/interpreter.go b/interpreter/interpreter.go index 79754db..c4ab809 100644 --- a/interpreter/interpreter.go +++ b/interpreter/interpreter.go @@ -6,17 +6,30 @@ import ( "golox/errors" "golox/token" "strings" + "time" ) // Interpreter interprets the AST. type Interpreter struct { errLogger errors.Logger env *environment + globals *environment } // New creates a new Interpreter. func New(el errors.Logger) *Interpreter { - return &Interpreter{el, newEnvironment(nil)} + globals := newEnvironment(nil) + + clockCallable := newNativeCallable( + func() int { return 0 }, + func(interpreter *Interpreter, args []any) any { + t := time.Now().UnixNano() / int64(time.Millisecond) + return float64(t) + }) + + globals.define("clock", clockCallable) + + return &Interpreter{el, globals, globals} } // Interpret interprets the AST. @@ -39,6 +52,13 @@ func (i *Interpreter) VisitExpressionStmt(es *ast.ExpressionStmt) any { return nil } +// VisitFunctionStmt visits a function statement. +func (i *Interpreter) VisitFunctionStmt(fs *ast.FunctionStmt) any { + function := newFunction(fs) + i.env.define(fs.Name.Lexeme, function) + return nil +} + // VisitIfStmt visits an if statement. func (i *Interpreter) VisitIfStmt(is *ast.IfStmt) any { if isTruthy(i.evaluate(is.Condition)) { @@ -199,6 +219,26 @@ func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any { panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line)) } +// VisitCallExpr visits a CallExpr. +func (i *Interpreter) VisitCallExpr(c *ast.CallExpr) any { + callee := i.evaluate(c.Callee) + + var arguments []any + for _, argument := range c.Arguments { + arguments = append(arguments, i.evaluate(argument)) + } + + if f, ok := callee.(callable); ok { + if len(arguments) != f.arity() { + panic(fmt.Sprintf("Expected %d arguments but got %d [line %d]", f.arity(), len(arguments), c.Paren.Line)) + } + + return f.call(i, arguments) + } + + panic(fmt.Sprintf("Can only call functions and classes [line %d]", c.Paren.Line)) +} + // VisitVariableExpr visits a VariableExpr. func (i *Interpreter) VisitVariableExpr(v *ast.VariableExpr) any { return i.env.get(v.Name.Lexeme) diff --git a/parser/parser.go b/parser/parser.go index b5d77f4..564b8fe 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -53,8 +53,11 @@ func (p *Parser) Parse() []ast.Stmt { return stmts } -// declaration → varDecl | statement ; +// declaration → funDecl | varDecl | statement ; func (p *Parser) declaration() ast.Stmt { + if p.match(token.FUN) { + return p.function("function") + } if p.match(token.VAR) { return p.varDeclaration() } @@ -236,6 +239,55 @@ func (p *Parser) expressionStatement() ast.Stmt { return &ast.ExpressionStmt{Expression: expr} } +// function → "fun" IDENTIFIER "(" parameters? ")" block ; +func (p *Parser) function(kind string) ast.Stmt { + err := p.consume(token.IDENTIFIER, "Expect "+kind+" name.") + if err != nil { + return p.fromErrorExpr(err) + } + name := p.previous() + + err = p.consume(token.LEFT_PAREN, "Expect '(' after "+kind+" name.") + if err != nil { + return p.fromErrorExpr(err) + } + + parameters := []token.Token{} + if !p.check(token.RIGHT_PAREN) { + for { + if len(parameters) >= 255 { + p.newErrorExpr(p.peek(), "Cannot have more than 255 parameters.") + } + + err = p.consume(token.IDENTIFIER, "Expect parameter name.") + if err != nil { + return p.fromErrorExpr(err) + } + + parameters = append(parameters, p.previous()) + if !p.match(token.COMMA) { + break + } + } + } + + err = p.consume(token.RIGHT_PAREN, "Expect ')' after parameters.") + if err != nil { + return p.fromErrorExpr(err) + } + + err = p.consume(token.LEFT_BRACE, "Expect '{' before "+kind+" body.") + if err != nil { + return p.fromErrorExpr(err) + } + + if body, ok := p.blockStatement().(*ast.BlockStmt); ok { + return &ast.FunctionStmt{Name: name, Params: parameters, Body: body.Statements} + } + + return p.newErrorStmt(p.peek(), "Expected block statement.") +} + // expression → assignment ; func (p *Parser) expression() ast.Expr { return p.assignment() @@ -337,7 +389,7 @@ func (p *Parser) factor() ast.Expr { return expr } -// unary → ( "!" | "-" ) unary | primary ; +// unary → ( "!" | "-" ) unary | call ; func (p *Parser) unary() ast.Expr { if p.match(token.BANG, token.MINUS) { operator := p.previous() @@ -345,7 +397,47 @@ func (p *Parser) unary() ast.Expr { return &ast.UnaryExpr{Operator: operator, Right: right} } - return p.primary() + return p.call() +} + +// call → primary ( "(" arguments? ")" )* ; +func (p *Parser) call() ast.Expr { + expr := p.primary() + + for { + if p.match(token.LEFT_PAREN) { + expr = p.finishCall(expr) + } else { + break + } + } + + return expr +} + +// finishCall finishes the call expression. +func (p *Parser) finishCall(callee ast.Expr) ast.Expr { + arguments := []ast.Expr{} + + if !p.check(token.RIGHT_PAREN) { + for { + if len(arguments) >= 255 { + p.newErrorExpr(p.peek(), "Cannot have more than 255 arguments.") + } + + arguments = append(arguments, p.expression()) + if !p.match(token.COMMA) { + break + } + } + } + + err := p.consume(token.RIGHT_PAREN, "Expect ')' after arguments.") + if err != nil { + return p.newErrorExpr(p.peek(), err.Value) + } + + return &ast.CallExpr{Callee: callee, Paren: p.previous(), Arguments: arguments} } // primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER; @@ -439,6 +531,12 @@ func (p *Parser) fromErrorExpr(ee *ast.ErrorExpr) *ast.ErrorStmt { return &ast.ErrorStmt{Value: ee.Value} } +// newErrorStmt creates a new ErrorStmt and reports the error. +func (p *Parser) newErrorStmt(t token.Token, message string) *ast.ErrorStmt { + p.errLogger.ErrorAtToken(t, message) + return &ast.ErrorStmt{Value: message} +} + // synchronize synchronizes the parser after an error. // It skips tokens until it finds a statement boundary. func (p *Parser) synchronize() { diff --git a/testdata/sayhi.lox b/testdata/sayhi.lox new file mode 100644 index 0000000..72be92f --- /dev/null +++ b/testdata/sayhi.lox @@ -0,0 +1,13 @@ +var t1 = clock(); + +fun sayHi(first, last) { + print "Hi, " + first + " " + last + "!"; +} + +sayHi("Dear", "Reader"); + +var t2 = clock(); + +print t1; +print t2; +print t1 - t2;