From 37be7dfab30d9402c13ef1e4a02d89c2f7718262 Mon Sep 17 00:00:00 2001 From: Olivier Abrivard Date: Mon, 16 Sep 2024 10:04:29 +0200 Subject: [PATCH] Interpreting resolved variables --- src/main/fr/celticinfo/lox/Environment.kt | 16 ++ src/main/fr/celticinfo/lox/Interpreter.kt | 23 +- src/main/fr/celticinfo/lox/Lox.kt | 15 ++ src/main/fr/celticinfo/lox/Resolver.kt | 23 +- src/test/fr/celticinfo/lox/InterpreterTest.kt | 198 +++++++++++++++++- 5 files changed, 261 insertions(+), 14 deletions(-) diff --git a/src/main/fr/celticinfo/lox/Environment.kt b/src/main/fr/celticinfo/lox/Environment.kt index 8b9225b..f8551e8 100644 --- a/src/main/fr/celticinfo/lox/Environment.kt +++ b/src/main/fr/celticinfo/lox/Environment.kt @@ -31,6 +31,10 @@ class Environment { throw RuntimeError(name, "Undefined variable '${name.lexeme}'.") } + fun getAt(distance: Int, name: String): Any? { + return ancestor(distance).values[name] + } + fun assign(name: Token, value: Any?) { if (values.containsKey(name.lexeme)) { values[name.lexeme] = value @@ -44,4 +48,16 @@ class Environment { throw RuntimeError(name, "Undefined variable '${name.lexeme}'.") } + + fun assignAt(distance: Int, name: Token, value: Any?) { + ancestor(distance).values[name.lexeme] = value + } + + private fun ancestor(distance: Int): Environment { + var environment = this + for (i in 0 until distance) { + environment = environment.enclosing!! + } + return environment + } } \ No newline at end of file diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index a662398..1ed593d 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -5,6 +5,7 @@ import fr.celticinfo.lox.TokenType.* class Interpreter: ExprVisitor, StmtVisitor{ var globals = Environment() private var environment = globals + private val locals = mutableMapOf() constructor() { globals.define("clock", object : LoxCallable { @@ -33,6 +34,10 @@ class Interpreter: ExprVisitor, StmtVisitor{ } } + fun resolve(expr: Expr, depth: Int) { + locals[expr] = depth + } + private fun execute(stmt: Stmt?) { stmt?.accept(this) } @@ -97,7 +102,12 @@ class Interpreter: ExprVisitor, StmtVisitor{ override fun visitAssign(expr: Assign): Any? { val value = evaluate(expr.value) - environment.assign(expr.name, value) + val distance = locals[expr] + if (distance != null) { + environment.assignAt(distance, expr.name, value) + } else { + globals.assign(expr.name, value) + } return value } @@ -202,7 +212,16 @@ class Interpreter: ExprVisitor, StmtVisitor{ } override fun visitVariable(expr: Variable): Any? { - return environment.get(expr.name) + return lookUpVariable(expr.name, expr) + } + + private fun lookUpVariable(name: Token, expr: Expr): Any? { + val distance = locals[expr] + return if (distance != null) { + environment.getAt(distance, name.lexeme) + } else { + globals.get(name) + } } private fun isTruthy(obj: Any?): Boolean { diff --git a/src/main/fr/celticinfo/lox/Lox.kt b/src/main/fr/celticinfo/lox/Lox.kt index 7d3d94e..7169a3a 100644 --- a/src/main/fr/celticinfo/lox/Lox.kt +++ b/src/main/fr/celticinfo/lox/Lox.kt @@ -45,6 +45,12 @@ object Lox { // Stop if there was a syntax error. if (hadError) return + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + // Stop if there was a resolution error. + if (hadError) return; + interpreter.interpret(statements) } @@ -65,6 +71,15 @@ object Lox { hadRuntimeError = true } + fun hadError() = hadError + + fun hadRuntimeError() = hadRuntimeError + + fun resetError() { + hadError = false + hadRuntimeError = false + } + private fun report(line: Int, where: String, message: String) { System.err.println("[line $line] Error$where: $message") hadError = true diff --git a/src/main/fr/celticinfo/lox/Resolver.kt b/src/main/fr/celticinfo/lox/Resolver.kt index 98218e8..34d80f4 100644 --- a/src/main/fr/celticinfo/lox/Resolver.kt +++ b/src/main/fr/celticinfo/lox/Resolver.kt @@ -5,6 +5,7 @@ package fr.celticinfo.lox */ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVisitor { private val scopes = mutableListOf>() + private var currentFunctionType = FunctionType.NONE init { scopes.add(mutableMapOf()) @@ -23,7 +24,7 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi override fun visitFunction(stmt: Function) { declare(stmt.name) define(stmt.name) - resolveFunction(stmt) + resolveFunction(stmt, FunctionType.FUNCTION) } override fun visitIf(stmt: If) { @@ -37,6 +38,10 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi } override fun visitReturn(stmt: Return) { + if (currentFunctionType == FunctionType.NONE) { + Lox.error(stmt.keyword, "Cannot return from top-level code.") + } + stmt.value?.let { resolve(it) } } @@ -90,7 +95,7 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi resolve(expr.right) } - private fun resolve(statements: List) { + fun resolve(statements: List) { for (statement in statements) { resolve(statement) } @@ -135,7 +140,10 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi } } - private fun resolveFunction(stmt: Function) { + private fun resolveFunction(stmt: Function, type: FunctionType) { + val enclosingFunctionType = currentFunctionType + currentFunctionType = type + beginScope() for (param in stmt.params) { declare(param) @@ -143,5 +151,12 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi } resolve(stmt.body) endScope() + + currentFunctionType = enclosingFunctionType } -} \ No newline at end of file +} + +enum class FunctionType { + NONE, + FUNCTION +} diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index 53ec82e..b61d4d0 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -10,6 +10,11 @@ import kotlin.test.Test class InterpreterTest { + @BeforeTest + fun setUp() { + Lox.resetError() + } + @Test fun `validate interpreter`() { val code = """ @@ -195,7 +200,10 @@ print c; val statements = parser.parse() assertEquals(7, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + val resolver = Resolver(interpreter) + resolver.resolve(statements) + interpreter.interpret(statements) val output = outputStreamCaptor.toString().trim() assertEquals("inner a\nouter b\nglobal c\nouter a\nouter b\nglobal c\nglobal a\nglobal b\nglobal c", output) } finally { @@ -331,7 +339,13 @@ for (var b = 1; a < 10000; b = temp + b) { val statements = parser.parse() assertEquals(3, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() assertEquals("0\n" + "1\n" + @@ -408,7 +422,12 @@ sayHi("Dear", "Reader"); val statements = parser.parse() assertEquals(2, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) val output = outputStreamCaptor.toString().trim() assertEquals("Hi, Dear Reader!", output) } finally { @@ -438,7 +457,13 @@ print fib(10); val statements = parser.parse() assertEquals(2, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() assertEquals("55", output) } finally { @@ -470,7 +495,12 @@ print fib(10); val statements = parser.parse() assertEquals(2, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) val output = outputStreamCaptor.toString().trim() assertEquals("55", output) } finally { @@ -504,7 +534,13 @@ print fib(10); val statements = parser.parse() assertEquals(2, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() assertEquals("55", output) } finally { @@ -542,7 +578,12 @@ counter(); val statements = parser.parse() assertEquals(5, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) val output = outputStreamCaptor.toString().trim() assertEquals("1\n2\n3", output) } finally { @@ -580,11 +621,152 @@ counter(); val statements = parser.parse() assertEquals(5, statements.size) - Interpreter().interpret(statements) + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) val output = outputStreamCaptor.toString().trim() assertEquals("1\n1\n1", output) } finally { System.setOut(standardOut) } } + + @Test + fun `Initiate a variable with the same name as a function should not work`() { + val code = """ +fun a() { + return 1; + +} + +var a = a(); + +print a; + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + val interpreter = Interpreter() + val resolver = Resolver(interpreter) + resolver.resolve(statements) + assert(Lox.hadError()) + } + + @Test + fun `Initiate a variable with the same name from an outer scope variable should not work`() { + val code = """ +var a = 1; +{ + var a = a; + print a; +} + """ + assert(!Lox.hadError()) + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + val interpreter = Interpreter() + val resolver = Resolver(interpreter) + resolver.resolve(statements) + assert(Lox.hadError()) + } + + @Test + fun `Initiate a variable with the same name from an outer scope function should not work`() { + val code = """ +fun a() { + return 1; +} + +{ + var a = a; + print a; +} + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + val interpreter = Interpreter() + val resolver = Resolver(interpreter) + resolver.resolve(statements) + assert(Lox.hadError()) + } + + @Test + fun `Initiate a variable with the same name from an outer scope function should not work with shadowing`() { + val code = """ +fun a() { + return 1; +} + +{ + fun a() { + return 2; + } + var a = a; + print a; +} + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + val interpreter = Interpreter() + val resolver = Resolver(interpreter) + resolver.resolve(statements) + assert(Lox.hadError()) + } + + @Test + fun `Closures should work with shadowing and outer scope variables with the same name`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +var a = "global"; +{ + fun showA() { + print a; + } + + showA(); + var a = "block"; + showA(); +} + """ + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("global\nglobal", output) + } finally { + System.setOut(standardOut) + } + } }