diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index b763174..469dd42 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -50,7 +50,7 @@ class Interpreter: ExprVisitor, StmtVisitor{ environment.define(stmt.name.lexeme, null) val methods = stmt.methods.associate { method -> - val function = LoxFunction(method, environment) + val function = LoxFunction(method, environment, method.name.lexeme == "init") method.name.lexeme to function } diff --git a/src/main/fr/celticinfo/lox/LoxCallable.kt b/src/main/fr/celticinfo/lox/LoxCallable.kt index 9fbea39..98a7c2d 100644 --- a/src/main/fr/celticinfo/lox/LoxCallable.kt +++ b/src/main/fr/celticinfo/lox/LoxCallable.kt @@ -1,5 +1,8 @@ package fr.celticinfo.lox +/** + * The LoxCallable interface represents a callable entity in the Lox language. + */ interface LoxCallable { fun call(interpreter: Interpreter, arguments: List): Any? fun arity(): Int diff --git a/src/main/fr/celticinfo/lox/LoxClass.kt b/src/main/fr/celticinfo/lox/LoxClass.kt index 0325805..6a8a63f 100644 --- a/src/main/fr/celticinfo/lox/LoxClass.kt +++ b/src/main/fr/celticinfo/lox/LoxClass.kt @@ -1,8 +1,13 @@ package fr.celticinfo.lox +/** + * The LoxClass class represents a class in the Lox language. + */ class LoxClass(private val name: String, private val methods: Map) : LoxCallable { override fun call(interpreter: Interpreter, arguments: List): Any? { val instance = LoxInstance(this) + val initializer = findMethod("init") + initializer?.bind(instance)?.call(interpreter, arguments) return instance } @@ -10,7 +15,7 @@ class LoxClass(private val name: String, private val methods: Map): Any? { @@ -22,8 +26,16 @@ class LoxFunction( interpreter.executeBlock(declaration.body, environment) null } catch (returnValue: LoxReturn) { + if (isInitializer) { + return closure.getAt(0, "this") + } + returnValue.value } + + if (isInitializer) { + return closure.getAt(0, "this") + } } override fun arity() = declaration.params.size diff --git a/src/main/fr/celticinfo/lox/Resolver.kt b/src/main/fr/celticinfo/lox/Resolver.kt index d11d017..204337a 100644 --- a/src/main/fr/celticinfo/lox/Resolver.kt +++ b/src/main/fr/celticinfo/lox/Resolver.kt @@ -30,7 +30,7 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi scopes.last()["this"] = true for (method in stmt.methods) { - val declaration = FunctionType.METHOD + val declaration = if (method.name.lexeme == "init") FunctionType.INITIALIZER else FunctionType.METHOD resolveFunction(method, declaration) } endScope() @@ -63,7 +63,12 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi Lox.error(stmt.keyword, "Cannot return from top-level code.") } - stmt.value?.let { resolve(it) } + if (stmt.value != null) { + if (currentFunctionType == FunctionType.INITIALIZER) { + Lox.error(stmt.keyword, "Cannot return a value from an initializer.") + } + resolve(stmt.value) + } } override fun visitVar(stmt: Var) { @@ -197,6 +202,7 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi enum class FunctionType { NONE, FUNCTION, + INITIALIZER, METHOD } diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index eee9aad..1759c35 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -947,4 +947,46 @@ var a = "global"; System.setOut(standardOut) } } + + @Test + fun `valid code with constructor`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ + class Cake { + init(flavor) { + this.flavor = flavor; + } + + taste() { + var adjective = "delicious"; + print "The " + this.flavor + " cake is " + adjective + "!"; + } + } + + var cake = Cake("German chocolate"); + cake.taste(); + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("The German chocolate cake is delicious!", output) + } finally { + System.setOut(standardOut) + } + } } diff --git a/src/test/fr/celticinfo/lox/ParserTest.kt b/src/test/fr/celticinfo/lox/ParserTest.kt index 3aa9a6e..dc57997 100644 --- a/src/test/fr/celticinfo/lox/ParserTest.kt +++ b/src/test/fr/celticinfo/lox/ParserTest.kt @@ -336,4 +336,32 @@ class ParserTest { assertTrue(statements[2] is Expression) assertTrue(statements[3] is Expression) } + + @Test + fun `valid code with constructor`() { + val code = """ + class Cake { + init(flavor) { + this.flavor = flavor; + } + + taste() { + var adjective = "delicious"; + print "The " + this.flavor + " cake is " + adjective + "!"; + } + } + + var cake = Cake("German chocolate"); + cake.taste(); + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3,statements.size) + assertTrue(statements[0] is ClassStmt) + assertTrue(statements[1] is Var) + assertTrue(statements[2] is Expression) + } + } \ No newline at end of file