Add functions

main
oabrivard 1 year ago
parent 41bf68b685
commit 7da761a4de

3
.gitignore vendored

@ -35,4 +35,7 @@ bin/
# Go workspace file
go.work
# Go tasks
.task/
# <--- Go

@ -1 +1 @@
ef159c739c079a54bc5f6af10e2ea025
8ad203c83fd99fe71ecaa9da9795cc32

@ -6,6 +6,7 @@ type ExprVisitor[T any] interface {
VisitErrorExpr(ee *ErrorExpr) T
VisitAssignExpr(ae *AssignExpr) T
VisitBinaryExpr(be *BinaryExpr) T
VisitCallExpr(ce *CallExpr) T
VisitGroupingExpr(ge *GroupingExpr) T
VisitLiteralExpr(le *LiteralExpr) T
VisitLogicalExpr(le *LogicalExpr) T
@ -44,6 +45,16 @@ func (be *BinaryExpr) Accept(v ExprVisitor[any]) any {
return v.VisitBinaryExpr(be)
}
type CallExpr struct {
Callee Expr
Paren token.Token
Arguments []Expr
}
func (ce *CallExpr) Accept(v ExprVisitor[any]) any {
return v.VisitCallExpr(ce)
}
type GroupingExpr struct {
Expression Expr
}

@ -33,6 +33,25 @@ func (ap *Printer) VisitExpressionStmt(stmt *ExpressionStmt) any {
return stmt.Expression.Accept(ap)
}
func (ap *Printer) VisitFunctionStmt(stmt *FunctionStmt) any {
str := "(fun " + stmt.Name.Lexeme + "("
for i, param := range stmt.Params {
if i > 0 {
str += ", "
}
str += param.Lexeme
}
str += ")"
if stmt.Body != nil {
str += " " + ap.bracesizeStmt(stmt.Body...)
}
return str + ")"
}
func (ap *Printer) VisitPrintStmt(stmt *PrintStmt) any {
return ap.parenthesizeExpr("print", stmt.Expression)
}
@ -85,6 +104,19 @@ func (ap *Printer) VisitBinaryExpr(expr *BinaryExpr) any {
return ap.parenthesizeExpr(expr.Operator.Lexeme, expr.Left, expr.Right)
}
func (ap *Printer) VisitCallExpr(expr *CallExpr) any {
str := expr.Callee.Accept(ap).(string) + "("
for i, arg := range expr.Arguments {
if i > 0 {
str += ", "
}
str += arg.Accept(ap).(string)
}
return str + ")"
}
func (ap *Printer) VisitGroupingExpr(expr *GroupingExpr) any {
return ap.parenthesizeExpr("group", expr.Expression)
}

@ -63,6 +63,22 @@ func TestPrintExpr(t *testing.T) {
},
expected: "(= foo 42)",
},
{
name: "Variable expression",
expr: &VariableExpr{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"},
},
expected: "foo",
},
{
name: "Call expression",
expr: &CallExpr{
Callee: &VariableExpr{Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"}},
Paren: token.Token{Type: token.LEFT_PAREN, Lexeme: "("},
Arguments: []Expr{&LiteralExpr{Value: 1}, &LiteralExpr{Value: 2}},
},
expected: "foo(1, 2)",
},
}
for _, tt := range tests {
@ -119,6 +135,64 @@ func TestPrintStmts(t *testing.T) {
},
expected: "(var foo 42)\n",
},
{
name: "Function statement",
stmts: []Stmt{
&FunctionStmt{
Name: token.Token{Type: token.IDENTIFIER, Lexeme: "foo"},
Params: []token.Token{{Type: token.IDENTIFIER, Lexeme: "bar"}},
Body: []Stmt{
&ExpressionStmt{
Expression: &LiteralExpr{Value: 42},
},
},
},
},
expected: "(fun foo(bar) {\n 42\n})\n",
},
{
name: "Block statement",
stmts: []Stmt{
&BlockStmt{
Statements: []Stmt{
&ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
&ExpressionStmt{
Expression: &LiteralExpr{Value: 2},
},
},
},
},
expected: "{\n 1\n 2\n}\n",
},
{
name: "If statement",
stmts: []Stmt{
&IfStmt{
Condition: &LiteralExpr{Value: true},
ThenBranch: &ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
ElseBranch: &ExpressionStmt{
Expression: &LiteralExpr{Value: 2},
},
},
},
expected: "(if true {\n 1\n} else {\n 2\n})\n",
},
{
name: "While statement",
stmts: []Stmt{
&WhileStmt{
Condition: &LiteralExpr{Value: true},
Body: &ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
},
},
expected: "(while true {\n 1\n})\n",
},
}
for _, tt := range tests {
@ -131,58 +205,3 @@ func TestPrintStmts(t *testing.T) {
})
}
}
func TestPrintBlockStmt(t *testing.T) {
stmt := &BlockStmt{
Statements: []Stmt{
&ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
&ExpressionStmt{
Expression: &LiteralExpr{Value: 2},
},
},
}
printer := NewPrinter()
result := printer.PrintStmts([]Stmt{stmt})
expected := "{\n 1\n 2\n}\n"
if result != expected {
t.Errorf("expected %v, got %v", expected, result)
}
}
func TestPrintIfStmt(t *testing.T) {
stmt := &IfStmt{
Condition: &LiteralExpr{Value: true},
ThenBranch: &ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
ElseBranch: &ExpressionStmt{
Expression: &LiteralExpr{Value: 2},
},
}
printer := NewPrinter()
result := printer.PrintStmts([]Stmt{stmt})
expected := "(if true {\n 1\n} else {\n 2\n})\n"
if result != expected {
t.Errorf("expected %v, got %v", expected, result)
}
}
func TestPrintWhileStmt(t *testing.T) {
stmt := &WhileStmt{
Condition: &LiteralExpr{Value: true},
Body: &ExpressionStmt{
Expression: &LiteralExpr{Value: 1},
},
}
printer := NewPrinter()
result := printer.PrintStmts([]Stmt{stmt})
expected := "(while true {\n 1\n})\n"
if result != expected {
t.Errorf("expected %v, got %v", expected, result)
}
}

@ -6,6 +6,7 @@ type StmtVisitor[T any] interface {
VisitErrorStmt(es *ErrorStmt) T
VisitBlockStmt(bs *BlockStmt) T
VisitExpressionStmt(es *ExpressionStmt) T
VisitFunctionStmt(fs *FunctionStmt) T
VisitIfStmt(is *IfStmt) T
VisitPrintStmt(ps *PrintStmt) T
VisitVarStmt(vs *VarStmt) T
@ -40,6 +41,16 @@ func (es *ExpressionStmt) Accept(v StmtVisitor[any]) any {
return v.VisitExpressionStmt(es)
}
type FunctionStmt struct {
Name token.Token
Params []token.Token
Body []Stmt
}
func (fs *FunctionStmt) Accept(v StmtVisitor[any]) any {
return v.VisitFunctionStmt(fs)
}
type IfStmt struct {
Condition Expr
ThenBranch Stmt

@ -25,6 +25,7 @@ func main() {
"Error : Value string",
"Assign : Name token.Token, Value Expr",
"Binary : Left Expr, Operator token.Token, Right Expr",
"Call : Callee Expr, Paren token.Token, Arguments []Expr",
"Grouping : Expression Expr",
"Literal : Value any",
"Logical : Left Expr, Operator token.Token, Right Expr",
@ -36,6 +37,7 @@ func main() {
"Error : Value string",
"Block : Statements []Stmt",
"Expression : Expression Expr",
"Function : Name token.Token, Params []token.Token, Body []Stmt",
"If : Condition Expr, ThenBranch Stmt, ElseBranch Stmt",
"Print : Expression Expr",
"Var : Name token.Token, Initializer Expr",

@ -0,0 +1,73 @@
package interpreter
import "golox/ast"
// callable is an interface for callables like functions and methods.
type callable interface {
arity() int
call(i *Interpreter, arguments []any) any
}
// nativeCallable is a struct that implements the callable interface.
type nativeCallable struct {
arityFn func() int
callFn func(interpreter *Interpreter, args []any) any
}
// newNativeCallable creates a new nativeCallable.
func newNativeCallable(arityFn func() int, callFn func(interpreter *Interpreter, args []any) any) *nativeCallable {
return &nativeCallable{
arityFn: arityFn,
callFn: callFn,
}
}
// arity returns the number of arguments the callable expects.
func (n *nativeCallable) arity() int {
return n.arityFn()
}
// call calls the callable with the given arguments.
func (n *nativeCallable) call(i *Interpreter, arguments []any) any {
return n.callFn(i, arguments)
}
// String returns a string representation of the callable.
func (n *nativeCallable) String() string {
return "<native fn>"
}
// function is a struct that implements the callable interface.
type function struct {
declaration *ast.FunctionStmt
}
// newFunction creates a new function.
func newFunction(declaration *ast.FunctionStmt) *function {
return &function{
declaration: declaration,
}
}
// arity returns the number of arguments the function expects.
func (f *function) arity() int {
return len(f.declaration.Params)
}
// call calls the function with the given arguments.
func (f *function) call(i *Interpreter, arguments []any) any {
env := newEnvironment(i.globals)
for i, param := range f.declaration.Params {
env.define(param.Lexeme, arguments[i])
}
i.executeBlock(f.declaration.Body, env)
return nil
}
// String returns a string representation of the function.
func (f *function) String() string {
return "<fn " + f.declaration.Name.Lexeme + ">"
}

@ -6,17 +6,30 @@ import (
"golox/errors"
"golox/token"
"strings"
"time"
)
// Interpreter interprets the AST.
type Interpreter struct {
errLogger errors.Logger
env *environment
globals *environment
}
// New creates a new Interpreter.
func New(el errors.Logger) *Interpreter {
return &Interpreter{el, newEnvironment(nil)}
globals := newEnvironment(nil)
clockCallable := newNativeCallable(
func() int { return 0 },
func(interpreter *Interpreter, args []any) any {
t := time.Now().UnixNano() / int64(time.Millisecond)
return float64(t)
})
globals.define("clock", clockCallable)
return &Interpreter{el, globals, globals}
}
// Interpret interprets the AST.
@ -39,6 +52,13 @@ func (i *Interpreter) VisitExpressionStmt(es *ast.ExpressionStmt) any {
return nil
}
// VisitFunctionStmt visits a function statement.
func (i *Interpreter) VisitFunctionStmt(fs *ast.FunctionStmt) any {
function := newFunction(fs)
i.env.define(fs.Name.Lexeme, function)
return nil
}
// VisitIfStmt visits an if statement.
func (i *Interpreter) VisitIfStmt(is *ast.IfStmt) any {
if isTruthy(i.evaluate(is.Condition)) {
@ -199,6 +219,26 @@ func (i *Interpreter) VisitBinaryExpr(b *ast.BinaryExpr) any {
panic(fmt.Sprintf("Unknown binary operator '%s' [line %d]", b.Operator.Lexeme, b.Operator.Line))
}
// VisitCallExpr visits a CallExpr.
func (i *Interpreter) VisitCallExpr(c *ast.CallExpr) any {
callee := i.evaluate(c.Callee)
var arguments []any
for _, argument := range c.Arguments {
arguments = append(arguments, i.evaluate(argument))
}
if f, ok := callee.(callable); ok {
if len(arguments) != f.arity() {
panic(fmt.Sprintf("Expected %d arguments but got %d [line %d]", f.arity(), len(arguments), c.Paren.Line))
}
return f.call(i, arguments)
}
panic(fmt.Sprintf("Can only call functions and classes [line %d]", c.Paren.Line))
}
// VisitVariableExpr visits a VariableExpr.
func (i *Interpreter) VisitVariableExpr(v *ast.VariableExpr) any {
return i.env.get(v.Name.Lexeme)

@ -53,8 +53,11 @@ func (p *Parser) Parse() []ast.Stmt {
return stmts
}
// declaration → varDecl | statement ;
// declaration → funDecl | varDecl | statement ;
func (p *Parser) declaration() ast.Stmt {
if p.match(token.FUN) {
return p.function("function")
}
if p.match(token.VAR) {
return p.varDeclaration()
}
@ -236,6 +239,55 @@ func (p *Parser) expressionStatement() ast.Stmt {
return &ast.ExpressionStmt{Expression: expr}
}
// function → "fun" IDENTIFIER "(" parameters? ")" block ;
func (p *Parser) function(kind string) ast.Stmt {
err := p.consume(token.IDENTIFIER, "Expect "+kind+" name.")
if err != nil {
return p.fromErrorExpr(err)
}
name := p.previous()
err = p.consume(token.LEFT_PAREN, "Expect '(' after "+kind+" name.")
if err != nil {
return p.fromErrorExpr(err)
}
parameters := []token.Token{}
if !p.check(token.RIGHT_PAREN) {
for {
if len(parameters) >= 255 {
p.newErrorExpr(p.peek(), "Cannot have more than 255 parameters.")
}
err = p.consume(token.IDENTIFIER, "Expect parameter name.")
if err != nil {
return p.fromErrorExpr(err)
}
parameters = append(parameters, p.previous())
if !p.match(token.COMMA) {
break
}
}
}
err = p.consume(token.RIGHT_PAREN, "Expect ')' after parameters.")
if err != nil {
return p.fromErrorExpr(err)
}
err = p.consume(token.LEFT_BRACE, "Expect '{' before "+kind+" body.")
if err != nil {
return p.fromErrorExpr(err)
}
if body, ok := p.blockStatement().(*ast.BlockStmt); ok {
return &ast.FunctionStmt{Name: name, Params: parameters, Body: body.Statements}
}
return p.newErrorStmt(p.peek(), "Expected block statement.")
}
// expression → assignment ;
func (p *Parser) expression() ast.Expr {
return p.assignment()
@ -337,7 +389,7 @@ func (p *Parser) factor() ast.Expr {
return expr
}
// unary → ( "!" | "-" ) unary | primary ;
// unary → ( "!" | "-" ) unary | call ;
func (p *Parser) unary() ast.Expr {
if p.match(token.BANG, token.MINUS) {
operator := p.previous()
@ -345,7 +397,47 @@ func (p *Parser) unary() ast.Expr {
return &ast.UnaryExpr{Operator: operator, Right: right}
}
return p.primary()
return p.call()
}
// call → primary ( "(" arguments? ")" )* ;
func (p *Parser) call() ast.Expr {
expr := p.primary()
for {
if p.match(token.LEFT_PAREN) {
expr = p.finishCall(expr)
} else {
break
}
}
return expr
}
// finishCall finishes the call expression.
func (p *Parser) finishCall(callee ast.Expr) ast.Expr {
arguments := []ast.Expr{}
if !p.check(token.RIGHT_PAREN) {
for {
if len(arguments) >= 255 {
p.newErrorExpr(p.peek(), "Cannot have more than 255 arguments.")
}
arguments = append(arguments, p.expression())
if !p.match(token.COMMA) {
break
}
}
}
err := p.consume(token.RIGHT_PAREN, "Expect ')' after arguments.")
if err != nil {
return p.newErrorExpr(p.peek(), err.Value)
}
return &ast.CallExpr{Callee: callee, Paren: p.previous(), Arguments: arguments}
}
// primary → NUMBER | STRING | "true" | "false" | "nil" | "(" expression ")" | IDENTIFIER;
@ -439,6 +531,12 @@ func (p *Parser) fromErrorExpr(ee *ast.ErrorExpr) *ast.ErrorStmt {
return &ast.ErrorStmt{Value: ee.Value}
}
// newErrorStmt creates a new ErrorStmt and reports the error.
func (p *Parser) newErrorStmt(t token.Token, message string) *ast.ErrorStmt {
p.errLogger.ErrorAtToken(t, message)
return &ast.ErrorStmt{Value: message}
}
// synchronize synchronizes the parser after an error.
// It skips tokens until it finds a statement boundary.
func (p *Parser) synchronize() {

13
testdata/sayhi.lox vendored

@ -0,0 +1,13 @@
var t1 = clock();
fun sayHi(first, last) {
print "Hi, " + first + " " + last + "!";
}
sayHi("Dear", "Reader");
var t2 = clock();
print t1;
print t2;
print t1 - t2;
Loading…
Cancel
Save