diff --git a/src/main/ExprGenerator.ws.kts b/src/main/ExprGenerator.ws.kts index f808b5c..0edcf7e 100644 --- a/src/main/ExprGenerator.ws.kts +++ b/src/main/ExprGenerator.ws.kts @@ -52,7 +52,8 @@ val exprTypes = listOf( defineAst("Expr", exprTypes) val stmtTypes = listOf( - "Block : List statements", + "Block : List statements", + "ClassStmt : Token name, List methods", "Expression : Expr expression", "Function : Token name, List params, List body", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch", diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index 1ed593d..cb782c3 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -46,20 +46,10 @@ class Interpreter: ExprVisitor, StmtVisitor{ executeBlock(stmt.statements, Environment(environment)) } - fun executeBlock(statements: List, environment: Environment) { - val previous = this.environment - try { - this.environment = environment - for (statement in statements) { - execute(statement) - } - } finally { - this.environment = previous - } - } - - private fun evaluate(expr: Expr?): Any? { - return expr?.accept(this) + override fun visitClassStmt(stmt: ClassStmt) { + environment.define(stmt.name.lexeme, null) + val klass = LoxClass(stmt.name.lexeme) + environment.assign(stmt.name, klass) } override fun visitExpression(stmt: Expression) { @@ -215,6 +205,22 @@ class Interpreter: ExprVisitor, StmtVisitor{ return lookUpVariable(expr.name, expr) } + fun executeBlock(statements: List, environment: Environment) { + val previous = this.environment + try { + this.environment = environment + for (statement in statements) { + execute(statement) + } + } finally { + this.environment = previous + } + } + + private fun evaluate(expr: Expr?): Any? { + return expr?.accept(this) + } + private fun lookUpVariable(name: Token, expr: Expr): Any? { val distance = locals[expr] return if (distance != null) { diff --git a/src/main/fr/celticinfo/lox/LoxClass.kt b/src/main/fr/celticinfo/lox/LoxClass.kt new file mode 100644 index 0000000..35fb8c7 --- /dev/null +++ b/src/main/fr/celticinfo/lox/LoxClass.kt @@ -0,0 +1,5 @@ +package fr.celticinfo.lox + +class LoxClass(val name: String) { + override fun toString() = name +} \ No newline at end of file diff --git a/src/main/fr/celticinfo/lox/Parser.kt b/src/main/fr/celticinfo/lox/Parser.kt index 4e164dd..6570af8 100644 --- a/src/main/fr/celticinfo/lox/Parser.kt +++ b/src/main/fr/celticinfo/lox/Parser.kt @@ -22,6 +22,7 @@ class Parser(private val tokens: List) { private fun declaration(): Stmt? { return try { when { + match(CLASS) -> classDeclaration() match(FUN) -> function("function") match(VAR) -> varDeclaration() else -> statement() @@ -32,7 +33,21 @@ class Parser(private val tokens: List) { } } - private fun function(kind: String): Stmt { + private fun classDeclaration(): ClassStmt { + val name = consume(IDENTIFIER, "Expect class name.") + consume(LEFT_BRACE, "Expect '{' before class body.") + + val methods: MutableList = ArrayList() + while (!check(RIGHT_BRACE) && !isAtEnd()) { + methods.add(function("method")) + } + + consume(RIGHT_BRACE, "Expect '}' after class body.") + + return ClassStmt(name, methods) + } + + private fun function(kind: String): Function { val name = consume(IDENTIFIER, "Expect $kind name.") consume(LEFT_PAREN, "Expect '(' after $kind name.") val parameters: MutableList = ArrayList() @@ -51,7 +66,7 @@ class Parser(private val tokens: List) { return Function(name, parameters, body) } - private fun varDeclaration(): Stmt { + private fun varDeclaration(): Var { val name = consume(IDENTIFIER, "Expect variable name.") var initializer: Expr? = null @@ -63,7 +78,7 @@ class Parser(private val tokens: List) { return Var(name, initializer) } - private fun whileStatement(): Stmt { + private fun whileStatement(): While { consume(LEFT_PAREN, "Expect '(' after 'while'.") val condition = expression() consume(RIGHT_PAREN, "Expect ')' after condition.") @@ -120,7 +135,7 @@ class Parser(private val tokens: List) { return body } - private fun ifStatement(): Stmt { + private fun ifStatement(): If { consume(LEFT_PAREN, "Expect '(' after 'if'.") val condition = expression() consume(RIGHT_PAREN, "Expect ')' after if condition.") @@ -135,20 +150,20 @@ class Parser(private val tokens: List) { return If(condition, thenBranch, elseBranch) } - private fun printStatement(): Stmt { + private fun printStatement(): Print { val value = expression() consume(SEMICOLON, "Expect ';' after value.") return Print(value) } - private fun returnStatement(): Stmt { + private fun returnStatement(): Return { val keyword = previous() val value = if (!check(SEMICOLON)) expression() else null consume(SEMICOLON, "Expect ';' after return value.") return Return(keyword, value) } - private fun expressionStatement(): Stmt { + private fun expressionStatement(): Expression { val value = expression() consume(SEMICOLON, "Expect ';' after expression.") return Expression(value) diff --git a/src/main/fr/celticinfo/lox/Resolver.kt b/src/main/fr/celticinfo/lox/Resolver.kt index 34d80f4..570b85b 100644 --- a/src/main/fr/celticinfo/lox/Resolver.kt +++ b/src/main/fr/celticinfo/lox/Resolver.kt @@ -17,6 +17,11 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor, StmtVi endScope() } + override fun visitClassStmt(stmt: ClassStmt) { + declare(stmt.name) + define(stmt.name) + } + override fun visitExpression(stmt: Expression) { resolve(stmt.expression) } diff --git a/src/main/fr/celticinfo/lox/Stmt.kt b/src/main/fr/celticinfo/lox/Stmt.kt index 9db2f55..0efc957 100644 --- a/src/main/fr/celticinfo/lox/Stmt.kt +++ b/src/main/fr/celticinfo/lox/Stmt.kt @@ -5,6 +5,7 @@ package fr.celticinfo.lox */ interface StmtVisitor { fun visitBlock(stmt: Block): R + fun visitClassStmt(stmt: ClassStmt): R fun visitExpression(stmt: Expression): R fun visitFunction(stmt: Function): R fun visitIf(stmt: If): R @@ -29,6 +30,15 @@ data class Block( } } +data class ClassStmt( + val name: Token, + val methods: List +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitClassStmt(this) + } +} + data class Expression( val expression: Expr ) : Stmt() { diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index b61d4d0..33bccf3 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -769,4 +769,40 @@ var a = "global"; System.setOut(standardOut) } } + + @Test + fun `valid code with class declaration`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ + class DevonshireCream { + serveOn() { + return "Scones"; + } + } + + print DevonshireCream; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2, statements.size) + + val interpreter = Interpreter() + + val resolver = Resolver(interpreter) + resolver.resolve(statements) + + interpreter.interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("DevonshireCream", output) + } finally { + System.setOut(standardOut) + } + } } diff --git a/src/test/fr/celticinfo/lox/ParserTest.kt b/src/test/fr/celticinfo/lox/ParserTest.kt index 32ff540..756597a 100644 --- a/src/test/fr/celticinfo/lox/ParserTest.kt +++ b/src/test/fr/celticinfo/lox/ParserTest.kt @@ -234,4 +234,24 @@ class ParserTest { val printStmt = stmt as Print assertTrue(printStmt.expression is Call) } + + @Test + fun `valid code with class declaration`() { + val code = """ + class DevonshireCream { + serveOn() { + return "Scones"; + } + } + + print DevonshireCream; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(2,statements.size) + val stmt = statements[0] + assertTrue(stmt is ClassStmt) + } } \ No newline at end of file