From 714cf80229cc5f73a5182c81027f0a5fde92767d Mon Sep 17 00:00:00 2001 From: Olivier Abrivard Date: Fri, 6 Sep 2024 10:11:17 +0200 Subject: [PATCH] Add local functions and closure --- src/main/fr/celticinfo/lox/Interpreter.kt | 2 +- src/main/fr/celticinfo/lox/LoxFunction.kt | 28 ++--- src/test/fr/celticinfo/lox/InterpreterTest.kt | 109 ++++++++++++++++++ 3 files changed, 121 insertions(+), 18 deletions(-) diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index cdcfa46..a662398 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -62,7 +62,7 @@ class Interpreter: ExprVisitor, StmtVisitor{ } override fun visitFunction(stmt: Function) { - val function = LoxFunction(stmt) + val function = LoxFunction(stmt, environment) environment.define(stmt.name.lexeme, function) } diff --git a/src/main/fr/celticinfo/lox/LoxFunction.kt b/src/main/fr/celticinfo/lox/LoxFunction.kt index b0ef32b..d7b0e2e 100644 --- a/src/main/fr/celticinfo/lox/LoxFunction.kt +++ b/src/main/fr/celticinfo/lox/LoxFunction.kt @@ -1,34 +1,28 @@ package fr.celticinfo.lox -class LoxFunction : LoxCallable { - private val declaration: Function - - constructor(declaration: Function) { - this.declaration = declaration - } +class LoxFunction( + private val declaration: Function, + private val closure: Environment +) : LoxCallable { override fun call(interpreter: Interpreter, arguments: List): Any? { - val environment = Environment(interpreter.globals) + val environment = Environment(closure) for (i in declaration.params.indices) { environment.define(declaration.params[i].lexeme, arguments[i]) } - try { + return try { interpreter.executeBlock(declaration.body, environment) + null } catch (returnValue: LoxReturn) { - return returnValue.value + returnValue.value } - - return null } - override fun arity(): Int { - return declaration.params.size - } + override fun arity() = declaration.params.size + + override fun toString() = "" - override fun toString(): String { - return "" - } } \ No newline at end of file diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index 7583a6c..53ec82e 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -477,5 +477,114 @@ print fib(10); System.setOut(standardOut) } } + + @Test + fun `Function should work with return statement in block with shadowing`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +fun fib(n) { + var a = 1; + if (n <= 1) { + return n; + } + return fib(n - 1) + fib(n - 2); + } +print fib(10); + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + Interpreter().interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("55", output) + } finally { + System.setOut(standardOut) + } + } + + + @Test + fun `Closures should work`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +fun makeCounter() { + var i = 0; + fun count() { + i = i + 1; + print i; + } + return count; +} + +var counter = makeCounter(); +counter(); +counter(); +counter(); + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(5, statements.size) + + Interpreter().interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("1\n2\n3", output) + } finally { + System.setOut(standardOut) + } + } + + @Test + fun `Closures should work with shadowing`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +fun makeCounter() { + var i = 0; + fun count() { + var i = 0; + i = i + 1; + print i; + } + return count; +} + +var counter = makeCounter(); +counter(); +counter(); +counter(); + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(5, statements.size) + + Interpreter().interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("1\n1\n1", output) + } finally { + System.setOut(standardOut) + } + } +}