Add function declaration

main
Olivier Abrivard 1 year ago
parent 0dcf8dd26c
commit f24f8729e0

@ -54,6 +54,7 @@ defineAst("Expr", exprTypes)
val stmtTypes = listOf( val stmtTypes = listOf(
"Block : List<Stmt?> statements", "Block : List<Stmt?> statements",
"Expression : Expr expression", "Expression : Expr expression",
"Function : Token name, List<Token> params, List<Stmt?> body",
"If : Expr condition, Stmt thenBranch, Stmt? elseBranch", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch",
"Print : Expr expression", "Print : Expr expression",
"Var : Token name, Expr? initializer", "Var : Token name, Expr? initializer",

@ -22,6 +22,7 @@ class Parser(private val tokens: List<Token>) {
private fun declaration(): Stmt? { private fun declaration(): Stmt? {
return try { return try {
when { when {
match(FUN) -> function("function")
match(VAR) -> varDeclaration() match(VAR) -> varDeclaration()
else -> statement() else -> statement()
} }
@ -31,6 +32,25 @@ class Parser(private val tokens: List<Token>) {
} }
} }
private fun function(kind: String): Stmt {
val name = consume(IDENTIFIER, "Expect $kind name.")
consume(LEFT_PAREN, "Expect '(' after $kind name.")
val parameters: MutableList<Token> = ArrayList()
if (!check(RIGHT_PAREN)) {
do {
if (parameters.size >= 255) {
error(peek(), "Cannot have more than 255 parameters.")
}
parameters.add(consume(IDENTIFIER, "Expect parameter name."))
} while (match(COMMA))
}
consume(RIGHT_PAREN, "Expect ')' after parameters.")
consume(LEFT_BRACE, "Expect '{' before $kind body.")
val body = blockStatement()
return Function(name, parameters, body)
}
private fun varDeclaration(): Stmt { private fun varDeclaration(): Stmt {
val name = consume(IDENTIFIER, "Expect variable name.") val name = consume(IDENTIFIER, "Expect variable name.")

@ -6,6 +6,7 @@ package fr.celticinfo.lox
interface StmtVisitor<R> { interface StmtVisitor<R> {
fun visitBlock(stmt: Block): R fun visitBlock(stmt: Block): R
fun visitExpression(stmt: Expression): R fun visitExpression(stmt: Expression): R
fun visitFunction(stmt: Function): R
fun visitIf(stmt: If): R fun visitIf(stmt: If): R
fun visitPrint(stmt: Print): R fun visitPrint(stmt: Print): R
fun visitVar(stmt: Var): R fun visitVar(stmt: Var): R
@ -35,6 +36,16 @@ data class Expression(
} }
} }
data class Function(
val name: Token,
val params: List<Token>,
val body: List<Stmt?>
) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R {
return visitor.visitFunction(this)
}
}
data class If( data class If(
val condition: Expr, val condition: Expr,
val thenBranch: Stmt, val thenBranch: Stmt,

@ -195,4 +195,24 @@ class ParserTest {
val blockStatements = block.statements val blockStatements = block.statements
assertEquals(2,blockStatements.size) assertEquals(2,blockStatements.size)
} }
@Test
fun `valid code with function declaration`() {
val code = """
fun add(a, b) {
return a + b;
}
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(1,statements.size)
val stmt = statements.first()
assertTrue(stmt is Function)
val function = stmt as Function
assertEquals("add",function.name.lexeme)
assertEquals(2,function.params.size)
assertEquals(1,function.body.size)
}
} }
Loading…
Cancel
Save