diff --git a/src/main/ExprGenerator.ws.kts b/src/main/ExprGenerator.ws.kts index 3f38478..1267cf6 100644 --- a/src/main/ExprGenerator.ws.kts +++ b/src/main/ExprGenerator.ws.kts @@ -48,6 +48,7 @@ val exprTypes = listOf( "Literal : Any? value", "Logical : Expr left, Token operator, Expr right", "Set : Expr obj, Token name, Expr value", + "Super : Token keyword, Token method", "This : Token keyword", "Unary : Token operator, Expr right", "Variable : Token name" diff --git a/src/main/fr/celticinfo/lox/AstPrinter.kt b/src/main/fr/celticinfo/lox/AstPrinter.kt index bd66cdc..37a6abf 100644 --- a/src/main/fr/celticinfo/lox/AstPrinter.kt +++ b/src/main/fr/celticinfo/lox/AstPrinter.kt @@ -37,6 +37,10 @@ class AstPrinter : ExprVisitor { return parenthesize("set", expr.obj, Literal(expr.name), expr.value) } + override fun visitSuper(expr: Super): String { + return parenthesize("super", Literal(expr.method)) + } + override fun visitThis(expr: This): String { return "this" } diff --git a/src/main/fr/celticinfo/lox/Environment.kt b/src/main/fr/celticinfo/lox/Environment.kt index f8551e8..61fde43 100644 --- a/src/main/fr/celticinfo/lox/Environment.kt +++ b/src/main/fr/celticinfo/lox/Environment.kt @@ -3,18 +3,9 @@ package fr.celticinfo.lox /** * The Environment class is used to store the variables and their values. */ -class Environment { - private val enclosing: Environment? +class Environment(val enclosing: Environment? = null) { private val values = mutableMapOf() - constructor() { - enclosing = null - } - - constructor(enclosing: Environment) { - this.enclosing = enclosing - } - fun define(name: String, value: Any?) { values[name] = value } diff --git a/src/main/fr/celticinfo/lox/Expr.kt b/src/main/fr/celticinfo/lox/Expr.kt index 90993ba..d3cc013 100644 --- a/src/main/fr/celticinfo/lox/Expr.kt +++ b/src/main/fr/celticinfo/lox/Expr.kt @@ -12,6 +12,7 @@ interface ExprVisitor { fun visitLiteral(expr: Literal): R fun visitLogical(expr: Logical): R fun visitSet(expr: Set): R + fun visitSuper(expr: Super): R fun visitThis(expr: This): R fun visitUnary(expr: Unary): R fun visitVariable(expr: Variable): R @@ -96,6 +97,15 @@ data class Set( } } +data class Super( + val keyword: Token, + val method: Token +) : Expr() { + override fun accept(visitor: ExprVisitor): R { + return visitor.visitSuper(this) + } +} + data class This( val keyword: Token ) : Expr() { diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index ea26682..b66376c 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -57,6 +57,11 @@ class Interpreter: ExprVisitor, StmtVisitor{ environment.define(stmt.name.lexeme, null) + if (superClass != null) { + environment = Environment(environment) + environment.define("super", superClass) + } + val methods = stmt.methods.associate { method -> val function = LoxFunction(method, environment, method.name.lexeme == "init") method.name.lexeme to function @@ -64,6 +69,10 @@ class Interpreter: ExprVisitor, StmtVisitor{ val klass = LoxClass(stmt.name.lexeme, superClass, methods) + if (superClass != null) { + environment = environment.enclosing!! + } + environment.assign(stmt.name, klass) } @@ -219,6 +228,15 @@ class Interpreter: ExprVisitor, StmtVisitor{ return value } + override fun visitSuper(expr: Super): Any? { + val distance = locals[expr] + val superClass = environment.getAt(distance!!, "super") as LoxClass + val obj = environment.getAt(distance - 1, "this") as LoxInstance + val method = superClass.findMethod(expr.method.lexeme) + ?: throw RuntimeError(expr.method, "Undefined property '${expr.method.lexeme}'") + return method.bind(obj) + } + override fun visitThis(expr: This): Any? { return lookUpVariable(expr.keyword, expr) } diff --git a/src/main/fr/celticinfo/lox/Parser.kt b/src/main/fr/celticinfo/lox/Parser.kt index e6ae396..af3ae40 100644 --- a/src/main/fr/celticinfo/lox/Parser.kt +++ b/src/main/fr/celticinfo/lox/Parser.kt @@ -343,6 +343,12 @@ class Parser(private val tokens: List) { match(TRUE) -> return Literal(true) match(NIL) -> return Literal(null) match(NUMBER, STRING) -> return Literal(previous().literal) + match(SUPER) -> { + val keyword = previous() + consume(DOT, "Expect '.' after 'super'.") + val method = consume(IDENTIFIER, "Expect superclass method name.") + return Super(keyword, method) + } match(THIS) -> return This(previous()) match(IDENTIFIER) -> return Variable(previous()) match(LEFT_PAREN) -> { diff --git a/src/main/fr/celticinfo/lox/Resolver.kt b/src/main/fr/celticinfo/lox/Resolver.kt index 5fb1178..05f0ec3 100644 --- a/src/main/fr/celticinfo/lox/Resolver.kt +++ b/src/main/fr/celticinfo/lox/Resolver.kt @@ -27,12 +27,18 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi define(stmt.name) if (stmt.superClass != null) { + currentClassType = ClassType.SUBCLASS if (stmt.name.lexeme == stmt.superClass.name.lexeme) { Lox.error(stmt.superClass.name, "A class cannot inherit from itself.") } resolve(stmt.superClass) } + if (stmt.superClass != null) { + beginScope() + scopes.last()["super"] = true + } + beginScope() scopes.last()["this"] = true @@ -42,6 +48,10 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi } endScope() + if (stmt.superClass != null) { + endScope() + } + currentClassType = enclosingClassType } @@ -133,6 +143,15 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi resolve(expr.obj) } + override fun visitSuper(expr: Super) { + if (currentClassType == ClassType.NONE) { + Lox.error(expr.keyword, "Cannot use 'super' outside of a class.") + } else if (currentClassType != ClassType.SUBCLASS) { + Lox.error(expr.keyword, "Cannot use 'super' in a class with no superclass.") + } + resolveLocal(expr, expr.keyword) + } + override fun visitThis(expr: This) { if (currentClassType == ClassType.NONE) { Lox.error(expr.keyword, "Cannot use 'this' outside of a class.") @@ -215,5 +234,6 @@ enum class FunctionType { enum class ClassType { NONE, - CLASS + CLASS, + SUBCLASS } \ No newline at end of file diff --git a/src/main/fr/celticinfo/loxext/RpnPrinter.kt b/src/main/fr/celticinfo/loxext/RpnPrinter.kt index 4324af9..5b2bf8a 100644 --- a/src/main/fr/celticinfo/loxext/RpnPrinter.kt +++ b/src/main/fr/celticinfo/loxext/RpnPrinter.kt @@ -39,6 +39,10 @@ class RpnPrinter : ExprVisitor { return stack("set", expr.obj, Literal(expr.name), expr.value) } + override fun visitSuper(expr: Super): String { + return stack("super", Literal(expr.method)) + } + override fun visitThis(expr: This): String { return "this" } diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index d3713dd..0310ee6 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -1026,4 +1026,48 @@ var a = "global"; } finally { System.setOut(standardOut) } - }} + } + + @Test + fun `Overriding a method should work`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ + class Doughnut { + cook() { + print "Fry until golden brown."; + } + } + + class BostonCream < Doughnut { + cook() { + super.cook(); + print "Pipe full of custard and coat with chocolate."; + } + } + + BostonCream().cook(); + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("Fry until golden brown.\nPipe full of custard and coat with chocolate.", output) + } finally { + System.setOut(standardOut) + } + } +}