Add class declarations

main
Olivier Abrivard 2 years ago
parent 37be7dfab3
commit 85c850a16d

@ -53,6 +53,7 @@ defineAst("Expr", exprTypes)
val stmtTypes = listOf( val stmtTypes = listOf(
"Block : List<Stmt?> statements", "Block : List<Stmt?> statements",
"ClassStmt : Token name, List<Function> methods",
"Expression : Expr expression", "Expression : Expr expression",
"Function : Token name, List<Token> params, List<Stmt?> body", "Function : Token name, List<Token> params, List<Stmt?> body",
"If : Expr condition, Stmt thenBranch, Stmt? elseBranch", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch",

@ -46,20 +46,10 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
executeBlock(stmt.statements, Environment(environment)) executeBlock(stmt.statements, Environment(environment))
} }
fun executeBlock(statements: List<Stmt?>, environment: Environment) { override fun visitClassStmt(stmt: ClassStmt) {
val previous = this.environment environment.define(stmt.name.lexeme, null)
try { val klass = LoxClass(stmt.name.lexeme)
this.environment = environment environment.assign(stmt.name, klass)
for (statement in statements) {
execute(statement)
}
} finally {
this.environment = previous
}
}
private fun evaluate(expr: Expr?): Any? {
return expr?.accept(this)
} }
override fun visitExpression(stmt: Expression) { override fun visitExpression(stmt: Expression) {
@ -215,6 +205,22 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
return lookUpVariable(expr.name, expr) return lookUpVariable(expr.name, expr)
} }
fun executeBlock(statements: List<Stmt?>, environment: Environment) {
val previous = this.environment
try {
this.environment = environment
for (statement in statements) {
execute(statement)
}
} finally {
this.environment = previous
}
}
private fun evaluate(expr: Expr?): Any? {
return expr?.accept(this)
}
private fun lookUpVariable(name: Token, expr: Expr): Any? { private fun lookUpVariable(name: Token, expr: Expr): Any? {
val distance = locals[expr] val distance = locals[expr]
return if (distance != null) { return if (distance != null) {

@ -0,0 +1,5 @@
package fr.celticinfo.lox
class LoxClass(val name: String) {
override fun toString() = name
}

@ -22,6 +22,7 @@ class Parser(private val tokens: List<Token>) {
private fun declaration(): Stmt? { private fun declaration(): Stmt? {
return try { return try {
when { when {
match(CLASS) -> classDeclaration()
match(FUN) -> function("function") match(FUN) -> function("function")
match(VAR) -> varDeclaration() match(VAR) -> varDeclaration()
else -> statement() else -> statement()
@ -32,7 +33,21 @@ class Parser(private val tokens: List<Token>) {
} }
} }
private fun function(kind: String): Stmt { private fun classDeclaration(): ClassStmt {
val name = consume(IDENTIFIER, "Expect class name.")
consume(LEFT_BRACE, "Expect '{' before class body.")
val methods: MutableList<Function> = ArrayList()
while (!check(RIGHT_BRACE) && !isAtEnd()) {
methods.add(function("method"))
}
consume(RIGHT_BRACE, "Expect '}' after class body.")
return ClassStmt(name, methods)
}
private fun function(kind: String): Function {
val name = consume(IDENTIFIER, "Expect $kind name.") val name = consume(IDENTIFIER, "Expect $kind name.")
consume(LEFT_PAREN, "Expect '(' after $kind name.") consume(LEFT_PAREN, "Expect '(' after $kind name.")
val parameters: MutableList<Token> = ArrayList() val parameters: MutableList<Token> = ArrayList()
@ -51,7 +66,7 @@ class Parser(private val tokens: List<Token>) {
return Function(name, parameters, body) return Function(name, parameters, body)
} }
private fun varDeclaration(): Stmt { private fun varDeclaration(): Var {
val name = consume(IDENTIFIER, "Expect variable name.") val name = consume(IDENTIFIER, "Expect variable name.")
var initializer: Expr? = null var initializer: Expr? = null
@ -63,7 +78,7 @@ class Parser(private val tokens: List<Token>) {
return Var(name, initializer) return Var(name, initializer)
} }
private fun whileStatement(): Stmt { private fun whileStatement(): While {
consume(LEFT_PAREN, "Expect '(' after 'while'.") consume(LEFT_PAREN, "Expect '(' after 'while'.")
val condition = expression() val condition = expression()
consume(RIGHT_PAREN, "Expect ')' after condition.") consume(RIGHT_PAREN, "Expect ')' after condition.")
@ -120,7 +135,7 @@ class Parser(private val tokens: List<Token>) {
return body return body
} }
private fun ifStatement(): Stmt { private fun ifStatement(): If {
consume(LEFT_PAREN, "Expect '(' after 'if'.") consume(LEFT_PAREN, "Expect '(' after 'if'.")
val condition = expression() val condition = expression()
consume(RIGHT_PAREN, "Expect ')' after if condition.") consume(RIGHT_PAREN, "Expect ')' after if condition.")
@ -135,20 +150,20 @@ class Parser(private val tokens: List<Token>) {
return If(condition, thenBranch, elseBranch) return If(condition, thenBranch, elseBranch)
} }
private fun printStatement(): Stmt { private fun printStatement(): Print {
val value = expression() val value = expression()
consume(SEMICOLON, "Expect ';' after value.") consume(SEMICOLON, "Expect ';' after value.")
return Print(value) return Print(value)
} }
private fun returnStatement(): Stmt { private fun returnStatement(): Return {
val keyword = previous() val keyword = previous()
val value = if (!check(SEMICOLON)) expression() else null val value = if (!check(SEMICOLON)) expression() else null
consume(SEMICOLON, "Expect ';' after return value.") consume(SEMICOLON, "Expect ';' after return value.")
return Return(keyword, value) return Return(keyword, value)
} }
private fun expressionStatement(): Stmt { private fun expressionStatement(): Expression {
val value = expression() val value = expression()
consume(SEMICOLON, "Expect ';' after expression.") consume(SEMICOLON, "Expect ';' after expression.")
return Expression(value) return Expression(value)

@ -17,6 +17,11 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
endScope() endScope()
} }
override fun visitClassStmt(stmt: ClassStmt) {
declare(stmt.name)
define(stmt.name)
}
override fun visitExpression(stmt: Expression) { override fun visitExpression(stmt: Expression) {
resolve(stmt.expression) resolve(stmt.expression)
} }

@ -5,6 +5,7 @@ package fr.celticinfo.lox
*/ */
interface StmtVisitor<R> { interface StmtVisitor<R> {
fun visitBlock(stmt: Block): R fun visitBlock(stmt: Block): R
fun visitClassStmt(stmt: ClassStmt): R
fun visitExpression(stmt: Expression): R fun visitExpression(stmt: Expression): R
fun visitFunction(stmt: Function): R fun visitFunction(stmt: Function): R
fun visitIf(stmt: If): R fun visitIf(stmt: If): R
@ -29,6 +30,15 @@ data class Block(
} }
} }
data class ClassStmt(
val name: Token,
val methods: List<Function>
) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R {
return visitor.visitClassStmt(this)
}
}
data class Expression( data class Expression(
val expression: Expr val expression: Expr
) : Stmt() { ) : Stmt() {

@ -769,4 +769,40 @@ var a = "global";
System.setOut(standardOut) System.setOut(standardOut)
} }
} }
@Test
fun `valid code with class declaration`() {
val standardOut = System.out
val outputStreamCaptor = ByteArrayOutputStream()
System.setOut(PrintStream(outputStreamCaptor))
try {
val code = """
class DevonshireCream {
serveOn() {
return "Scones";
}
}
print DevonshireCream;
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim()
assertEquals("DevonshireCream", output)
} finally {
System.setOut(standardOut)
}
}
} }

@ -234,4 +234,24 @@ class ParserTest {
val printStmt = stmt as Print val printStmt = stmt as Print
assertTrue(printStmt.expression is Call) assertTrue(printStmt.expression is Call)
} }
@Test
fun `valid code with class declaration`() {
val code = """
class DevonshireCream {
serveOn() {
return "Scones";
}
}
print DevonshireCream;
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2,statements.size)
val stmt = statements[0]
assertTrue(stmt is ClassStmt)
}
} }
Loading…
Cancel
Save