diff --git a/src/main/ExprGenerator.ws.kts b/src/main/ExprGenerator.ws.kts index 795bab2..e6b7514 100644 --- a/src/main/ExprGenerator.ws.kts +++ b/src/main/ExprGenerator.ws.kts @@ -54,6 +54,7 @@ defineAst("Expr", exprTypes) val stmtTypes = listOf( "Block : List statements", "Expression : Expr expression", + "Function : Token name, List params, List body", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch", "Print : Expr expression", "Var : Token name, Expr? initializer", diff --git a/src/main/fr/celticinfo/lox/Parser.kt b/src/main/fr/celticinfo/lox/Parser.kt index 7774195..2ac7278 100644 --- a/src/main/fr/celticinfo/lox/Parser.kt +++ b/src/main/fr/celticinfo/lox/Parser.kt @@ -22,6 +22,7 @@ class Parser(private val tokens: List) { private fun declaration(): Stmt? { return try { when { + match(FUN) -> function("function") match(VAR) -> varDeclaration() else -> statement() } @@ -30,7 +31,26 @@ class Parser(private val tokens: List) { return null } } - + + private fun function(kind: String): Stmt { + val name = consume(IDENTIFIER, "Expect $kind name.") + consume(LEFT_PAREN, "Expect '(' after $kind name.") + val parameters: MutableList = ArrayList() + if (!check(RIGHT_PAREN)) { + do { + if (parameters.size >= 255) { + error(peek(), "Cannot have more than 255 parameters.") + } + parameters.add(consume(IDENTIFIER, "Expect parameter name.")) + } while (match(COMMA)) + } + consume(RIGHT_PAREN, "Expect ')' after parameters.") + + consume(LEFT_BRACE, "Expect '{' before $kind body.") + val body = blockStatement() + return Function(name, parameters, body) + } + private fun varDeclaration(): Stmt { val name = consume(IDENTIFIER, "Expect variable name.") diff --git a/src/main/fr/celticinfo/lox/Stmt.kt b/src/main/fr/celticinfo/lox/Stmt.kt index ca40b5c..9ac9c26 100644 --- a/src/main/fr/celticinfo/lox/Stmt.kt +++ b/src/main/fr/celticinfo/lox/Stmt.kt @@ -6,6 +6,7 @@ package fr.celticinfo.lox interface StmtVisitor { fun visitBlock(stmt: Block): R fun visitExpression(stmt: Expression): R + fun visitFunction(stmt: Function): R fun visitIf(stmt: If): R fun visitPrint(stmt: Print): R fun visitVar(stmt: Var): R @@ -35,6 +36,16 @@ data class Expression( } } +data class Function( + val name: Token, + val params: List, + val body: List +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitFunction(this) + } +} + data class If( val condition: Expr, val thenBranch: Stmt, diff --git a/src/test/fr/celticinfo/lox/ParserTest.kt b/src/test/fr/celticinfo/lox/ParserTest.kt index 888e101..21ee564 100644 --- a/src/test/fr/celticinfo/lox/ParserTest.kt +++ b/src/test/fr/celticinfo/lox/ParserTest.kt @@ -195,4 +195,24 @@ class ParserTest { val blockStatements = block.statements assertEquals(2,blockStatements.size) } + + @Test + fun `valid code with function declaration`() { + val code = """ + fun add(a, b) { + return a + b; + } + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(1,statements.size) + val stmt = statements.first() + assertTrue(stmt is Function) + val function = stmt as Function + assertEquals("add",function.name.lexeme) + assertEquals(2,function.params.size) + assertEquals(1,function.body.size) + } } \ No newline at end of file