From 3ac911b6662b8331f92eb49345f9827e3354f900 Mon Sep 17 00:00:00 2001 From: Olivier Abrivard Date: Fri, 6 Sep 2024 09:39:14 +0200 Subject: [PATCH] Add Return statement --- src/main/ExprGenerator.ws.kts | 1 + src/main/fr/celticinfo/lox/Interpreter.kt | 5 ++ src/main/fr/celticinfo/lox/LoxFunction.kt | 9 ++- src/main/fr/celticinfo/lox/LoxReturn.kt | 4 ++ src/main/fr/celticinfo/lox/Parser.kt | 8 +++ src/main/fr/celticinfo/lox/Stmt.kt | 10 +++ src/test/fr/celticinfo/lox/InterpreterTest.kt | 65 ++++++++++++++++++- src/test/fr/celticinfo/lox/ParserTest.kt | 19 ++++++ 8 files changed, 119 insertions(+), 2 deletions(-) create mode 100644 src/main/fr/celticinfo/lox/LoxReturn.kt diff --git a/src/main/ExprGenerator.ws.kts b/src/main/ExprGenerator.ws.kts index e6b7514..f808b5c 100644 --- a/src/main/ExprGenerator.ws.kts +++ b/src/main/ExprGenerator.ws.kts @@ -57,6 +57,7 @@ val stmtTypes = listOf( "Function : Token name, List params, List body", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch", "Print : Expr expression", + "Return : Token keyword, Expr? value", "Var : Token name, Expr? initializer", "While : Expr condition, Stmt body" ) diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index a7590ea..cdcfa46 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -85,6 +85,11 @@ class Interpreter: ExprVisitor, StmtVisitor{ println(stringify(value)) } + override fun visitReturn(stmt: Return) { + val value = stmt.value?.let { evaluate(it) } + throw LoxReturn(value) + } + override fun visitVar(stmt: Var) { val value = evaluate(stmt.initializer) environment.define(stmt.name.lexeme, value) diff --git a/src/main/fr/celticinfo/lox/LoxFunction.kt b/src/main/fr/celticinfo/lox/LoxFunction.kt index bd8b081..b0ef32b 100644 --- a/src/main/fr/celticinfo/lox/LoxFunction.kt +++ b/src/main/fr/celticinfo/lox/LoxFunction.kt @@ -9,10 +9,17 @@ class LoxFunction : LoxCallable { override fun call(interpreter: Interpreter, arguments: List): Any? { val environment = Environment(interpreter.globals) + for (i in declaration.params.indices) { environment.define(declaration.params[i].lexeme, arguments[i]) } - interpreter.executeBlock(declaration.body, environment) + + try { + interpreter.executeBlock(declaration.body, environment) + } catch (returnValue: LoxReturn) { + return returnValue.value + } + return null } diff --git a/src/main/fr/celticinfo/lox/LoxReturn.kt b/src/main/fr/celticinfo/lox/LoxReturn.kt new file mode 100644 index 0000000..47c02c9 --- /dev/null +++ b/src/main/fr/celticinfo/lox/LoxReturn.kt @@ -0,0 +1,4 @@ +package fr.celticinfo.lox + +class LoxReturn(val value: Any?) : RuntimeException(null, null, false, false) { +} \ No newline at end of file diff --git a/src/main/fr/celticinfo/lox/Parser.kt b/src/main/fr/celticinfo/lox/Parser.kt index 2ac7278..4e164dd 100644 --- a/src/main/fr/celticinfo/lox/Parser.kt +++ b/src/main/fr/celticinfo/lox/Parser.kt @@ -77,6 +77,7 @@ class Parser(private val tokens: List) { match(FOR) -> forStatement() match(IF) -> ifStatement() match(PRINT) -> printStatement() + match(RETURN) -> returnStatement() match(WHILE) -> whileStatement() match(LEFT_BRACE) -> Block(blockStatement()) else -> expressionStatement() @@ -140,6 +141,13 @@ class Parser(private val tokens: List) { return Print(value) } + private fun returnStatement(): Stmt { + val keyword = previous() + val value = if (!check(SEMICOLON)) expression() else null + consume(SEMICOLON, "Expect ';' after return value.") + return Return(keyword, value) + } + private fun expressionStatement(): Stmt { val value = expression() consume(SEMICOLON, "Expect ';' after expression.") diff --git a/src/main/fr/celticinfo/lox/Stmt.kt b/src/main/fr/celticinfo/lox/Stmt.kt index 9ac9c26..9db2f55 100644 --- a/src/main/fr/celticinfo/lox/Stmt.kt +++ b/src/main/fr/celticinfo/lox/Stmt.kt @@ -9,6 +9,7 @@ interface StmtVisitor { fun visitFunction(stmt: Function): R fun visitIf(stmt: If): R fun visitPrint(stmt: Print): R + fun visitReturn(stmt: Return): R fun visitVar(stmt: Var): R fun visitWhile(stmt: While): R } @@ -63,6 +64,15 @@ data class Print( } } +data class Return( + val keyword: Token, + val value: Expr? +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitReturn(this) + } +} + data class Var( val name: Token, val initializer: Expr? diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index cdb2e27..7583a6c 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -415,4 +415,67 @@ sayHi("Dear", "Reader"); System.setOut(standardOut) } } -} \ No newline at end of file + + @Test + fun `Function should work with return statement`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +fun fib(n) { + if (n <= 1) return n; + return fib(n - 1) + fib(n - 2); +} + +print fib(10); + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + Interpreter().interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("55", output) + } finally { + System.setOut(standardOut) + } + } + + @Test + fun `Function should work with return statement in block`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +fun fib(n) { + if (n <= 1) { + return n; + } + return fib(n - 1) + fib(n - 2); +} + +print fib(10); + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + Interpreter().interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("55", output) + } finally { + System.setOut(standardOut) + } + } +} + diff --git a/src/test/fr/celticinfo/lox/ParserTest.kt b/src/test/fr/celticinfo/lox/ParserTest.kt index 21ee564..32ff540 100644 --- a/src/test/fr/celticinfo/lox/ParserTest.kt +++ b/src/test/fr/celticinfo/lox/ParserTest.kt @@ -215,4 +215,23 @@ class ParserTest { assertEquals(2,function.params.size) assertEquals(1,function.body.size) } + + @Test + fun `valid code with function call`() { + val code = """ + fun add(a, b) { + return a + b; + } + print add(1, 2); + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2,statements.size) + val stmt = statements[1] + assertTrue(stmt is Print) + val printStmt = stmt as Print + assertTrue(printStmt.expression is Call) + } } \ No newline at end of file