Add Return statement

main
Olivier Abrivard 1 year ago
parent 20033fd74a
commit 3ac911b666

@ -57,6 +57,7 @@ val stmtTypes = listOf(
"Function : Token name, List<Token> params, List<Stmt?> body", "Function : Token name, List<Token> params, List<Stmt?> body",
"If : Expr condition, Stmt thenBranch, Stmt? elseBranch", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch",
"Print : Expr expression", "Print : Expr expression",
"Return : Token keyword, Expr? value",
"Var : Token name, Expr? initializer", "Var : Token name, Expr? initializer",
"While : Expr condition, Stmt body" "While : Expr condition, Stmt body"
) )

@ -85,6 +85,11 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
println(stringify(value)) println(stringify(value))
} }
override fun visitReturn(stmt: Return) {
val value = stmt.value?.let { evaluate(it) }
throw LoxReturn(value)
}
override fun visitVar(stmt: Var) { override fun visitVar(stmt: Var) {
val value = evaluate(stmt.initializer) val value = evaluate(stmt.initializer)
environment.define(stmt.name.lexeme, value) environment.define(stmt.name.lexeme, value)

@ -9,10 +9,17 @@ class LoxFunction : LoxCallable {
override fun call(interpreter: Interpreter, arguments: List<Any?>): Any? { override fun call(interpreter: Interpreter, arguments: List<Any?>): Any? {
val environment = Environment(interpreter.globals) val environment = Environment(interpreter.globals)
for (i in declaration.params.indices) { for (i in declaration.params.indices) {
environment.define(declaration.params[i].lexeme, arguments[i]) environment.define(declaration.params[i].lexeme, arguments[i])
} }
interpreter.executeBlock(declaration.body, environment)
try {
interpreter.executeBlock(declaration.body, environment)
} catch (returnValue: LoxReturn) {
return returnValue.value
}
return null return null
} }

@ -0,0 +1,4 @@
package fr.celticinfo.lox
class LoxReturn(val value: Any?) : RuntimeException(null, null, false, false) {
}

@ -77,6 +77,7 @@ class Parser(private val tokens: List<Token>) {
match(FOR) -> forStatement() match(FOR) -> forStatement()
match(IF) -> ifStatement() match(IF) -> ifStatement()
match(PRINT) -> printStatement() match(PRINT) -> printStatement()
match(RETURN) -> returnStatement()
match(WHILE) -> whileStatement() match(WHILE) -> whileStatement()
match(LEFT_BRACE) -> Block(blockStatement()) match(LEFT_BRACE) -> Block(blockStatement())
else -> expressionStatement() else -> expressionStatement()
@ -140,6 +141,13 @@ class Parser(private val tokens: List<Token>) {
return Print(value) return Print(value)
} }
private fun returnStatement(): Stmt {
val keyword = previous()
val value = if (!check(SEMICOLON)) expression() else null
consume(SEMICOLON, "Expect ';' after return value.")
return Return(keyword, value)
}
private fun expressionStatement(): Stmt { private fun expressionStatement(): Stmt {
val value = expression() val value = expression()
consume(SEMICOLON, "Expect ';' after expression.") consume(SEMICOLON, "Expect ';' after expression.")

@ -9,6 +9,7 @@ interface StmtVisitor<R> {
fun visitFunction(stmt: Function): R fun visitFunction(stmt: Function): R
fun visitIf(stmt: If): R fun visitIf(stmt: If): R
fun visitPrint(stmt: Print): R fun visitPrint(stmt: Print): R
fun visitReturn(stmt: Return): R
fun visitVar(stmt: Var): R fun visitVar(stmt: Var): R
fun visitWhile(stmt: While): R fun visitWhile(stmt: While): R
} }
@ -63,6 +64,15 @@ data class Print(
} }
} }
data class Return(
val keyword: Token,
val value: Expr?
) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R {
return visitor.visitReturn(this)
}
}
data class Var( data class Var(
val name: Token, val name: Token,
val initializer: Expr? val initializer: Expr?

@ -415,4 +415,67 @@ sayHi("Dear", "Reader");
System.setOut(standardOut) System.setOut(standardOut)
} }
} }
}
@Test
fun `Function should work with return statement`() {
val standardOut = System.out
val outputStreamCaptor = ByteArrayOutputStream()
System.setOut(PrintStream(outputStreamCaptor))
try {
val code = """
fun fib(n) {
if (n <= 1) return n;
return fib(n - 1) + fib(n - 2);
}
print fib(10);
"""
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
Interpreter().interpret(statements)
val output = outputStreamCaptor.toString().trim()
assertEquals("55", output)
} finally {
System.setOut(standardOut)
}
}
@Test
fun `Function should work with return statement in block`() {
val standardOut = System.out
val outputStreamCaptor = ByteArrayOutputStream()
System.setOut(PrintStream(outputStreamCaptor))
try {
val code = """
fun fib(n) {
if (n <= 1) {
return n;
}
return fib(n - 1) + fib(n - 2);
}
print fib(10);
"""
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
Interpreter().interpret(statements)
val output = outputStreamCaptor.toString().trim()
assertEquals("55", output)
} finally {
System.setOut(standardOut)
}
}
}

@ -215,4 +215,23 @@ class ParserTest {
assertEquals(2,function.params.size) assertEquals(2,function.params.size)
assertEquals(1,function.body.size) assertEquals(1,function.body.size)
} }
@Test
fun `valid code with function call`() {
val code = """
fun add(a, b) {
return a + b;
}
print add(1, 2);
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2,statements.size)
val stmt = statements[1]
assertTrue(stmt is Print)
val printStmt = stmt as Print
assertTrue(printStmt.expression is Call)
}
} }
Loading…
Cancel
Save