Add class declarations

main
Olivier Abrivard 2 years ago
parent 37be7dfab3
commit 85c850a16d

@ -52,7 +52,8 @@ val exprTypes = listOf(
defineAst("Expr", exprTypes)
val stmtTypes = listOf(
"Block : List<Stmt?> statements",
"Block : List<Stmt?> statements",
"ClassStmt : Token name, List<Function> methods",
"Expression : Expr expression",
"Function : Token name, List<Token> params, List<Stmt?> body",
"If : Expr condition, Stmt thenBranch, Stmt? elseBranch",

@ -46,20 +46,10 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
executeBlock(stmt.statements, Environment(environment))
}
fun executeBlock(statements: List<Stmt?>, environment: Environment) {
val previous = this.environment
try {
this.environment = environment
for (statement in statements) {
execute(statement)
}
} finally {
this.environment = previous
}
}
private fun evaluate(expr: Expr?): Any? {
return expr?.accept(this)
override fun visitClassStmt(stmt: ClassStmt) {
environment.define(stmt.name.lexeme, null)
val klass = LoxClass(stmt.name.lexeme)
environment.assign(stmt.name, klass)
}
override fun visitExpression(stmt: Expression) {
@ -215,6 +205,22 @@ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
return lookUpVariable(expr.name, expr)
}
fun executeBlock(statements: List<Stmt?>, environment: Environment) {
val previous = this.environment
try {
this.environment = environment
for (statement in statements) {
execute(statement)
}
} finally {
this.environment = previous
}
}
private fun evaluate(expr: Expr?): Any? {
return expr?.accept(this)
}
private fun lookUpVariable(name: Token, expr: Expr): Any? {
val distance = locals[expr]
return if (distance != null) {

@ -0,0 +1,5 @@
package fr.celticinfo.lox
class LoxClass(val name: String) {
override fun toString() = name
}

@ -22,6 +22,7 @@ class Parser(private val tokens: List<Token>) {
private fun declaration(): Stmt? {
return try {
when {
match(CLASS) -> classDeclaration()
match(FUN) -> function("function")
match(VAR) -> varDeclaration()
else -> statement()
@ -32,7 +33,21 @@ class Parser(private val tokens: List<Token>) {
}
}
private fun function(kind: String): Stmt {
private fun classDeclaration(): ClassStmt {
val name = consume(IDENTIFIER, "Expect class name.")
consume(LEFT_BRACE, "Expect '{' before class body.")
val methods: MutableList<Function> = ArrayList()
while (!check(RIGHT_BRACE) && !isAtEnd()) {
methods.add(function("method"))
}
consume(RIGHT_BRACE, "Expect '}' after class body.")
return ClassStmt(name, methods)
}
private fun function(kind: String): Function {
val name = consume(IDENTIFIER, "Expect $kind name.")
consume(LEFT_PAREN, "Expect '(' after $kind name.")
val parameters: MutableList<Token> = ArrayList()
@ -51,7 +66,7 @@ class Parser(private val tokens: List<Token>) {
return Function(name, parameters, body)
}
private fun varDeclaration(): Stmt {
private fun varDeclaration(): Var {
val name = consume(IDENTIFIER, "Expect variable name.")
var initializer: Expr? = null
@ -63,7 +78,7 @@ class Parser(private val tokens: List<Token>) {
return Var(name, initializer)
}
private fun whileStatement(): Stmt {
private fun whileStatement(): While {
consume(LEFT_PAREN, "Expect '(' after 'while'.")
val condition = expression()
consume(RIGHT_PAREN, "Expect ')' after condition.")
@ -120,7 +135,7 @@ class Parser(private val tokens: List<Token>) {
return body
}
private fun ifStatement(): Stmt {
private fun ifStatement(): If {
consume(LEFT_PAREN, "Expect '(' after 'if'.")
val condition = expression()
consume(RIGHT_PAREN, "Expect ')' after if condition.")
@ -135,20 +150,20 @@ class Parser(private val tokens: List<Token>) {
return If(condition, thenBranch, elseBranch)
}
private fun printStatement(): Stmt {
private fun printStatement(): Print {
val value = expression()
consume(SEMICOLON, "Expect ';' after value.")
return Print(value)
}
private fun returnStatement(): Stmt {
private fun returnStatement(): Return {
val keyword = previous()
val value = if (!check(SEMICOLON)) expression() else null
consume(SEMICOLON, "Expect ';' after return value.")
return Return(keyword, value)
}
private fun expressionStatement(): Stmt {
private fun expressionStatement(): Expression {
val value = expression()
consume(SEMICOLON, "Expect ';' after expression.")
return Expression(value)

@ -17,6 +17,11 @@ class Resolver(private val interpreter: Interpreter) : ExprVisitor<Unit>, StmtVi
endScope()
}
override fun visitClassStmt(stmt: ClassStmt) {
declare(stmt.name)
define(stmt.name)
}
override fun visitExpression(stmt: Expression) {
resolve(stmt.expression)
}

@ -5,6 +5,7 @@ package fr.celticinfo.lox
*/
interface StmtVisitor<R> {
fun visitBlock(stmt: Block): R
fun visitClassStmt(stmt: ClassStmt): R
fun visitExpression(stmt: Expression): R
fun visitFunction(stmt: Function): R
fun visitIf(stmt: If): R
@ -29,6 +30,15 @@ data class Block(
}
}
data class ClassStmt(
val name: Token,
val methods: List<Function>
) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R {
return visitor.visitClassStmt(this)
}
}
data class Expression(
val expression: Expr
) : Stmt() {

@ -769,4 +769,40 @@ var a = "global";
System.setOut(standardOut)
}
}
@Test
fun `valid code with class declaration`() {
val standardOut = System.out
val outputStreamCaptor = ByteArrayOutputStream()
System.setOut(PrintStream(outputStreamCaptor))
try {
val code = """
class DevonshireCream {
serveOn() {
return "Scones";
}
}
print DevonshireCream;
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2, statements.size)
val interpreter = Interpreter()
val resolver = Resolver(interpreter)
resolver.resolve(statements)
interpreter.interpret(statements)
val output = outputStreamCaptor.toString().trim()
assertEquals("DevonshireCream", output)
} finally {
System.setOut(standardOut)
}
}
}

@ -234,4 +234,24 @@ class ParserTest {
val printStmt = stmt as Print
assertTrue(printStmt.expression is Call)
}
@Test
fun `valid code with class declaration`() {
val code = """
class DevonshireCream {
serveOn() {
return "Scones";
}
}
print DevonshireCream;
""".trimIndent()
val scanner = Scanner(code)
val tokens = scanner.scanTokens()
val parser = Parser(tokens)
val statements = parser.parse()
assertEquals(2,statements.size)
val stmt = statements[0]
assertTrue(stmt is ClassStmt)
}
}
Loading…
Cancel
Save