Add superclasses and subclasses

main
Olivier Abrivard 1 year ago
parent 2fec77db1f
commit fa461db4f2

@ -56,7 +56,7 @@ defineAst("Expr", exprTypes)
val stmtTypes = listOf( val stmtTypes = listOf(
"Block : List<Stmt?> statements", "Block : List<Stmt?> statements",
"ClassStmt : Token name, List<Function> methods", "ClassStmt : Token name, Variable? superClass, List<Function> methods",
"Expression : Expr expression", "Expression : Expr expression",
"Function : Token name, List<Token> params, List<Stmt?> body", "Function : Token name, List<Token> params, List<Stmt?> body",
"If : Expr condition, Stmt thenBranch, Stmt? elseBranch", "If : Expr condition, Stmt thenBranch, Stmt? elseBranch",

@ -47,6 +47,14 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
} }
override fun visitClassStmt(stmt: ClassStmt) { override fun visitClassStmt(stmt: ClassStmt) {
val superClass = stmt.superClass?.let {
val evaluatedSuperClass = evaluate(it)
if (evaluatedSuperClass !is LoxClass) {
throw RuntimeError(it.name, "Superclass must be a class")
}
evaluatedSuperClass
}
environment.define(stmt.name.lexeme, null) environment.define(stmt.name.lexeme, null)
val methods = stmt.methods.associate { method -> val methods = stmt.methods.associate { method ->
@ -54,7 +62,7 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
method.name.lexeme to function method.name.lexeme to function
} }
val klass = LoxClass(stmt.name.lexeme, methods) val klass = LoxClass(stmt.name.lexeme, superClass, methods)
environment.assign(stmt.name, klass) environment.assign(stmt.name, klass)
} }

@ -3,7 +3,7 @@ package fr.celticinfo.lox
/** /**
* The LoxClass class represents a class in the Lox language. * The LoxClass class represents a class in the Lox language.
*/ */
class LoxClass(private val name: String, private val methods: Map<String, LoxFunction>) : LoxCallable { class LoxClass(private val name: String, private val superClass: LoxClass?, private val methods: Map<String, LoxFunction>) : LoxCallable {
override fun call(interpreter: Interpreter, arguments: List<Any?>): Any? { override fun call(interpreter: Interpreter, arguments: List<Any?>): Any? {
val instance = LoxInstance(this) val instance = LoxInstance(this)
val initializer = findMethod("init") val initializer = findMethod("init")

@ -35,6 +35,13 @@ class Parser(private val tokens: List<Token>) {
private fun classDeclaration(): ClassStmt { private fun classDeclaration(): ClassStmt {
val name = consume(IDENTIFIER, "Expect class name.") val name = consume(IDENTIFIER, "Expect class name.")
var superClass: Variable? = null
if (match(LESS)) {
consume(IDENTIFIER, "Expect superclass name.")
superClass = Variable(previous())
}
consume(LEFT_BRACE, "Expect '{' before class body.") consume(LEFT_BRACE, "Expect '{' before class body.")
val methods: MutableList<Function> = ArrayList() val methods: MutableList<Function> = ArrayList()
@ -44,7 +51,7 @@ class Parser(private val tokens: List<Token>) {
consume(RIGHT_BRACE, "Expect '}' after class body.") consume(RIGHT_BRACE, "Expect '}' after class body.")
return ClassStmt(name, methods) return ClassStmt(name, superClass, methods)
} }
private fun function(kind: String): Function { private fun function(kind: String): Function {

@ -26,6 +26,13 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
declare(stmt.name) declare(stmt.name)
define(stmt.name) define(stmt.name)
if (stmt.superClass != null) {
if (stmt.name.lexeme == stmt.superClass.name.lexeme) {
Lox.error(stmt.superClass.name, "A class cannot inherit from itself.")
}
resolve(stmt.superClass)
}
beginScope() beginScope()
scopes.last()["this"] = true scopes.last()["this"] = true

@ -32,6 +32,7 @@ data class Block(
data class ClassStmt( data class ClassStmt(
val name: Token, val name: Token,
val superClass: Variable?,
val methods: List<Function> val methods: List<Function>
) : Stmt() { ) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R { override fun <R> accept(visitor: StmtVisitor<R>): R {

@ -364,4 +364,31 @@ class ParserTest {
assertTrue(statements[2] is Expression) assertTrue(statements[2] is Expression)
} }
@Test
fun `valid code with inheritance`() {
val code = """
class Doughnut {
cook() {
print "Fry until golden brown.";
}
}
class BostonCream < Doughnut {
cook() {
super.cook();
print "Pipe full of custard and coat with chocolate.";
}
}
BostonCream().cook();
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(3,statements.size)
assertTrue(statements[0] is ClassStmt)
assertTrue(statements[1] is ClassStmt)
assertTrue(statements[2] is Expression)
}
} }
Loading…
Cancel
Save