Interpreting resolved variables

main
Olivier Abrivard 2 years ago
parent 948be0b4c6
commit 37be7dfab3

@ -31,6 +31,10 @@ class Environment {
throw RuntimeError(name, "Undefined variable '${name.lexeme}'.") throw RuntimeError(name, "Undefined variable '${name.lexeme}'.")
} }
fun getAt(distance: Int, name: String): Any? {
return ancestor(distance).values[name]
}
fun assign(name: Token, value: Any?) { fun assign(name: Token, value: Any?) {
if (values.containsKey(name.lexeme)) { if (values.containsKey(name.lexeme)) {
values[name.lexeme] = value values[name.lexeme] = value
@ -44,4 +48,16 @@ class Environment {
throw RuntimeError(name, "Undefined variable '${name.lexeme}'.") throw RuntimeError(name, "Undefined variable '${name.lexeme}'.")
} }
fun assignAt(distance: Int, name: Token, value: Any?) {
ancestor(distance).values[name.lexeme] = value
}
private fun ancestor(distance: Int): Environment {
var environment = this
for (i in 0 until distance) {
environment = environment.enclosing!!
}
return environment
}
} }

@ -5,6 +5,7 @@ import fr.celticinfo.lox.TokenType.*
class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
var globals = Environment() var globals = Environment()
private var environment = globals private var environment = globals
private val locals = mutableMapOf<Expr, Int>()
constructor() { constructor() {
globals.define("clock", object : LoxCallable { globals.define("clock", object : LoxCallable {
@ -33,6 +34,10 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
} }
} }
fun resolve(expr: Expr, depth: Int) {
locals[expr] = depth
}
private fun execute(stmt: Stmt?) { private fun execute(stmt: Stmt?) {
stmt?.accept(this) stmt?.accept(this)
} }
@ -97,7 +102,12 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
override fun visitAssign(expr: Assign): Any? { override fun visitAssign(expr: Assign): Any? {
val value = evaluate(expr.value) val value = evaluate(expr.value)
environment.assign(expr.name, value) val distance = locals[expr]
if (distance != null) {
environment.assignAt(distance, expr.name, value)
} else {
globals.assign(expr.name, value)
}
return value return value
} }
@ -202,7 +212,16 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
} }
override fun visitVariable(expr: Variable): Any? { override fun visitVariable(expr: Variable): Any? {
return environment.get(expr.name) return lookUpVariable(expr.name, expr)
}
private fun lookUpVariable(name: Token, expr: Expr): Any? {
val distance = locals[expr]
return if (distance != null) {
environment.getAt(distance, name.lexeme)
} else {
globals.get(name)
}
} }
private fun isTruthy(obj: Any?): Boolean { private fun isTruthy(obj: Any?): Boolean {

@ -45,6 +45,12 @@ object Lox {
// Stop if there was a syntax error. // Stop if there was a syntax error.
if (hadError) return if (hadError) return
val resolver = Resolver(interpreter)
resolver.resolve(statements)
// Stop if there was a resolution error.
if (hadError) return;
interpreter.interpret(statements) interpreter.interpret(statements)
} }
@ -65,6 +71,15 @@ object Lox {
hadRuntimeError = true hadRuntimeError = true
} }
fun hadError() = hadError
fun hadRuntimeError() = hadRuntimeError
fun resetError() {
hadError = false
hadRuntimeError = false
}
private fun report(line: Int, where: String, message: String) { private fun report(line: Int, where: String, message: String) {
System.err.println("[line $line] Error$where: $message") System.err.println("[line $line] Error$where: $message")
hadError = true hadError = true

@ -5,6 +5,7 @@ package fr.celticinfo.lox
*/ */
class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVisitor<Unit> { class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVisitor<Unit> {
private val scopes = mutableListOf<MutableMap<String, Boolean>>() private val scopes = mutableListOf<MutableMap<String, Boolean>>()
private var currentFunctionType = FunctionType.NONE
init { init {
scopes.add(mutableMapOf()) scopes.add(mutableMapOf())
@ -23,7 +24,7 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
override fun visitFunction(stmt: Function) { override fun visitFunction(stmt: Function) {
declare(stmt.name) declare(stmt.name)
define(stmt.name) define(stmt.name)
resolveFunction(stmt) resolveFunction(stmt, FunctionType.FUNCTION)
} }
override fun visitIf(stmt: If) { override fun visitIf(stmt: If) {
@ -37,6 +38,10 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
} }
override fun visitReturn(stmt: Return) { override fun visitReturn(stmt: Return) {
if (currentFunctionType == FunctionType.NONE) {
Lox.error(stmt.keyword, "Cannot return from top-level code.")
}
stmt.value?.let { resolve(it) } stmt.value?.let { resolve(it) }
} }
@ -90,7 +95,7 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
resolve(expr.right) resolve(expr.right)
} }
private fun resolve(statements: List<Stmt?>) { fun resolve(statements: List<Stmt?>) {
for (statement in statements) { for (statement in statements) {
resolve(statement) resolve(statement)
} }
@ -135,7 +140,10 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
} }
} }
private fun resolveFunction(stmt: Function) { private fun resolveFunction(stmt: Function, type: FunctionType) {
val enclosingFunctionType = currentFunctionType
currentFunctionType = type
beginScope() beginScope()
for (param in stmt.params) { for (param in stmt.params) {
declare(param) declare(param)
@ -143,5 +151,12 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
} }
resolve(stmt.body) resolve(stmt.body)
endScope() endScope()
currentFunctionType = enclosingFunctionType
} }
} }
enum class FunctionType {
NONE,
FUNCTION
}

@ -10,6 +10,11 @@ import kotlin.test.Test
class InterpreterTest { class InterpreterTest {
@BeforeTest
fun setUp() {
Lox.resetError()
}
@Test @Test
fun `validate interpreter`() { fun `validate interpreter`() {
val code = """ val code = """
@ -195,7 +200,10 @@ print c;
val statements = parser.parse() val statements = parser.parse()
assertEquals(7, statements.size) assertEquals(7, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("inner a\nouter b\nglobal c\nouter a\nouter b\nglobal c\nglobal a\nglobal b\nglobal c", output) assertEquals("inner a\nouter b\nglobal c\nouter a\nouter b\nglobal c\nglobal a\nglobal b\nglobal c", output)
} finally { } finally {
@ -331,7 +339,13 @@ for (var b = 1; a < 10000; b = temp + b) {
val statements = parser.parse() val statements = parser.parse()
assertEquals(3, statements.size) assertEquals(3, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("0\n" + assertEquals("0\n" +
"1\n" + "1\n" +
@ -408,7 +422,12 @@ sayHi("Dear", "Reader");
val statements = parser.parse() val statements = parser.parse()
assertEquals(2, statements.size) assertEquals(2, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("Hi, Dear Reader!", output) assertEquals("Hi, Dear Reader!", output)
} finally { } finally {
@ -438,7 +457,13 @@ print fib(10);
val statements = parser.parse() val statements = parser.parse()
assertEquals(2, statements.size) assertEquals(2, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("55", output) assertEquals("55", output)
} finally { } finally {
@ -470,7 +495,12 @@ print fib(10);
val statements = parser.parse() val statements = parser.parse()
assertEquals(2, statements.size) assertEquals(2, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("55", output) assertEquals("55", output)
} finally { } finally {
@ -504,7 +534,13 @@ print fib(10);
val statements = parser.parse() val statements = parser.parse()
assertEquals(2, statements.size) assertEquals(2, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("55", output) assertEquals("55", output)
} finally { } finally {
@ -542,7 +578,12 @@ counter();
val statements = parser.parse() val statements = parser.parse()
assertEquals(5, statements.size) assertEquals(5, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("1\n2\n3", output) assertEquals("1\n2\n3", output)
} finally { } finally {
@ -580,11 +621,152 @@ counter();
val statements = parser.parse() val statements = parser.parse()
assertEquals(5, statements.size) assertEquals(5, statements.size)
Interpreter().interpret(statements) val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim() val output = outputStreamCaptor.toString().trim()
assertEquals("1\n1\n1", output) assertEquals("1\n1\n1", output)
} finally { } finally {
System.setOut(standardOut) System.setOut(standardOut)
} }
} }
@Test
fun `Initiate a variable with the same name as a function should not work`() {
val code = """
fun a() {
return 1;
}
var a = a();
print a;
"""
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(3, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
assert(Lox.hadError())
}
@Test
fun `Initiate a variable with the same name from an outer scope variable should not work`() {
val code = """
var a = 1;
{
var a = a;
print a;
}
"""
assert(!Lox.hadError())
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
assert(Lox.hadError())
}
@Test
fun `Initiate a variable with the same name from an outer scope function should not work`() {
val code = """
fun a() {
return 1;
}
{
var a = a;
print a;
}
"""
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
assert(Lox.hadError())
}
@Test
fun `Initiate a variable with the same name from an outer scope function should not work with shadowing`() {
val code = """
fun a() {
return 1;
}
{
fun a() {
return 2;
}
var a = a;
print a;
}
"""
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
assert(Lox.hadError())
}
@Test
fun `Closures should work with shadowing and outer scope variables with the same name`() {
val standardOut = System.out
val outputStreamCaptor = ByteArrayOutputStream()
System.setOut(PrintStream(outputStreamCaptor))
try {
val code = """
var a = "global";
{
fun showA() {
print a;
}
showA();
var a = "block";
showA();
}
"""
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim()
assertEquals("global\nglobal", output)
} finally {
System.setOut(standardOut)
}
}
} }

Loading…
Cancel
Save