From 16e28c07b51bee782dca66e4c7b12d4fda6c0966 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Thu, 25 Jul 2024 11:08:49 +0200 Subject: [PATCH] Implemented Blocks and Assignments --- src/main/ExprGenerator.ws.kts | 8 +- src/main/fr/celticinfo/lox/AstPrinter.kt | 24 ++- src/main/fr/celticinfo/lox/Environment.kt | 47 ++++++ src/main/fr/celticinfo/lox/Expr.kt | 21 ++- src/main/fr/celticinfo/lox/Interpreter.kt | 42 ++++- src/main/fr/celticinfo/lox/Parser.kt | 66 +++++++- src/main/fr/celticinfo/lox/Stmt.kt | 19 +++ src/main/fr/celticinfo/loxext/RpnPrinter.kt | 24 ++- src/test/fr/celticinfo/lox/InterpreterTest.kt | 146 +++++++++++++++++- src/test/fr/celticinfo/lox/ParserTest.kt | 94 +++++++++-- 10 files changed, 451 insertions(+), 40 deletions(-) create mode 100644 src/main/fr/celticinfo/lox/Environment.kt diff --git a/src/main/ExprGenerator.ws.kts b/src/main/ExprGenerator.ws.kts index d68f7f4..5e90e2b 100644 --- a/src/main/ExprGenerator.ws.kts +++ b/src/main/ExprGenerator.ws.kts @@ -40,15 +40,19 @@ fun defineAst(baseName: String, types: List) { } val exprTypes = listOf( + "Assign : Token name, Expr value", "Binary : Expr left, Token operator, Expr right", "Grouping : Expr expression", "Literal : Any? value", - "Unary : Token operator, Expr right" + "Unary : Token operator, Expr right", + "Variable : Token name" ) defineAst("Expr", exprTypes) val stmtTypes = listOf( + "Block : List statements", "Expression : Expr expression", - "Print : Expr expression" + "Print : Expr expression", + "Var : Token name, Expr? initializer" ) defineAst("Stmt", stmtTypes) diff --git a/src/main/fr/celticinfo/lox/AstPrinter.kt b/src/main/fr/celticinfo/lox/AstPrinter.kt index 5692f34..68cb927 100644 --- a/src/main/fr/celticinfo/lox/AstPrinter.kt +++ b/src/main/fr/celticinfo/lox/AstPrinter.kt @@ -5,22 +5,30 @@ class AstPrinter : ExprVisitor { return expr.accept(this) } - override fun visitBinary(binary: Binary): String { - return parenthesize(binary.operator.lexeme, binary.left, binary.right) + override fun visitAssign(expr: Assign): String { + return parenthesize("=", expr) } - override fun visitGrouping(grouping: Grouping): String { - return parenthesize("group", grouping.expression) + override fun visitBinary(expr: Binary): String { + return parenthesize(expr.operator.lexeme, expr.left, expr.right) } - override fun visitLiteral(literal: Literal): String { - return literal.value?.toString() ?: "nil" + override fun visitGrouping(expr: Grouping): String { + return parenthesize("group", expr.expression) } - override fun visitUnary(unary: Unary): String { - return parenthesize(unary.operator.lexeme, unary.right) + override fun visitLiteral(expr: Literal): String { + return expr.value?.toString() ?: "nil" } + override fun visitUnary(expr: Unary): String { + return parenthesize(expr.operator.lexeme, expr.right) + } + + override fun visitVariable(expr: Variable): String { + return expr.name.lexeme + } + private fun parenthesize(name: String, vararg exprs: Expr): String { val builder = StringBuilder() builder.append("(").append(name) diff --git a/src/main/fr/celticinfo/lox/Environment.kt b/src/main/fr/celticinfo/lox/Environment.kt new file mode 100644 index 0000000..8b9225b --- /dev/null +++ b/src/main/fr/celticinfo/lox/Environment.kt @@ -0,0 +1,47 @@ +package fr.celticinfo.lox + +/** + * The Environment class is used to store the variables and their values. + */ +class Environment { + private val enclosing: Environment? + private val values = mutableMapOf() + + constructor() { + enclosing = null + } + + constructor(enclosing: Environment) { + this.enclosing = enclosing + } + + fun define(name: String, value: Any?) { + values[name] = value + } + + fun get(name: Token): Any? { + if (values.containsKey(name.lexeme)) { + return values[name.lexeme] + } + + if (enclosing != null) { + return enclosing.get(name) + } + + throw RuntimeError(name, "Undefined variable '${name.lexeme}'.") + } + + fun assign(name: Token, value: Any?) { + if (values.containsKey(name.lexeme)) { + values[name.lexeme] = value + return + } + + if (enclosing != null) { + enclosing.assign(name, value) + return + } + + throw RuntimeError(name, "Undefined variable '${name.lexeme}'.") + } +} \ No newline at end of file diff --git a/src/main/fr/celticinfo/lox/Expr.kt b/src/main/fr/celticinfo/lox/Expr.kt index ca40780..a213d3f 100644 --- a/src/main/fr/celticinfo/lox/Expr.kt +++ b/src/main/fr/celticinfo/lox/Expr.kt @@ -4,10 +4,12 @@ package fr.celticinfo.lox * The ExprVisitor interface is used to visit the different types of expressions that can be parsed by the Parser. */ interface ExprVisitor { + fun visitAssign(expr: Assign): R fun visitBinary(expr: Binary): R fun visitGrouping(expr: Grouping): R fun visitLiteral(expr: Literal): R fun visitUnary(expr: Unary): R + fun visitVariable(expr: Variable): R } /** @@ -17,6 +19,15 @@ sealed class Expr { abstract fun accept(visitor: ExprVisitor): R } +data class Assign( + val name: Token, + val value: Expr +) : Expr() { + override fun accept(visitor: ExprVisitor): R { + return visitor.visitAssign(this) + } +} + data class Binary( val left: Expr, val operator: Token, val right: Expr ) : Expr() { @@ -47,4 +58,12 @@ data class Unary( override fun accept(visitor: ExprVisitor): R { return visitor.visitUnary(this) } -} \ No newline at end of file +} + +data class Variable( + val name: Token +) : Expr() { + override fun accept(visitor: ExprVisitor): R { + return visitor.visitVariable(this) + } +} diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index 34c1a77..09dd715 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -3,8 +3,9 @@ package fr.celticinfo.lox import fr.celticinfo.lox.TokenType.* class Interpreter: ExprVisitor, StmtVisitor{ + private var environment = Environment() - fun interpret(statements: List) { + fun interpret(statements: List) { try { for (statement in statements) { execute(statement) @@ -15,12 +16,28 @@ class Interpreter: ExprVisitor, StmtVisitor{ } } - private fun execute(stmt: Stmt) { - stmt.accept(this) + private fun execute(stmt: Stmt?) { + stmt?.accept(this) } - private fun evaluate(expr: Expr): Any? { - return expr.accept(this) + override fun visitBlock(stmt: Block) { + executeBlock(stmt.statements, Environment(environment)) + } + + private fun executeBlock(statements: List, environment: Environment) { + val previous = this.environment + try { + this.environment = environment + for (statement in statements) { + execute(statement) + } + } finally { + this.environment = previous + } + } + + private fun evaluate(expr: Expr?): Any? { + return expr?.accept(this) } override fun visitExpression(stmt: Expression) { @@ -32,6 +49,17 @@ class Interpreter: ExprVisitor, StmtVisitor{ println(stringify(value)) } + override fun visitVar(stmt: Var) { + val value = evaluate(stmt.initializer) + environment.define(stmt.name.lexeme, value) + } + + override fun visitAssign(expr: Assign): Any? { + val value = evaluate(expr.value) + environment.assign(expr.name, value) + return value + } + override fun visitBinary(expr: Binary): Any? { val left = evaluate(expr.left) val right = evaluate(expr.right) @@ -104,6 +132,10 @@ class Interpreter: ExprVisitor, StmtVisitor{ } } + override fun visitVariable(expr: Variable): Any? { + return environment.get(expr.name) + } + private fun isTruthy(obj: Any?): Boolean { return when (obj) { null -> false diff --git a/src/main/fr/celticinfo/lox/Parser.kt b/src/main/fr/celticinfo/lox/Parser.kt index fd510e0..a3b8548 100644 --- a/src/main/fr/celticinfo/lox/Parser.kt +++ b/src/main/fr/celticinfo/lox/Parser.kt @@ -10,18 +10,44 @@ import fr.celticinfo.lox.TokenType.* class Parser(private val tokens: List) { private var current = 0 - fun parse(): List { - val statements: MutableList = ArrayList() + fun parse(): List { + val statements: MutableList = ArrayList() while (!isAtEnd()) { - statements.add(statement()) + statements.add(declaration()) } return statements } + private fun declaration(): Stmt? { + return try { + when { + match(VAR) -> varDeclaration() + else -> statement() + } + } catch (error: ParseError) { + synchronize() + return null + } + } + + private fun varDeclaration(): Stmt { + val name = consume(IDENTIFIER, "Expect variable name.") + + var initializer: Expr? = null + if (match(EQUAL)) { + initializer = expression() + } + + consume(SEMICOLON, "Expect ';' after variable declaration.") + return Var(name, initializer) + } + + private fun statement(): Stmt { return when { match(PRINT) -> printStatement() + match(LEFT_BRACE) -> Block(blockStatement()) else -> expressionStatement() } } @@ -34,12 +60,41 @@ class Parser(private val tokens: List) { private fun expressionStatement(): Stmt { val value = expression() - consume(SEMICOLON, "Expect ';' after value.") + consume(SEMICOLON, "Expect ';' after expression.") return Expression(value) } + private fun blockStatement(): List { + val statements: MutableList = ArrayList() + + while (!check(RIGHT_BRACE) && !isAtEnd()) { + statements.add(declaration()) + } + + consume(RIGHT_BRACE, "Expect '}' after block.") + return statements + } + private fun expression(): Expr { - return equality() + return assignment() + } + + private fun assignment(): Expr { + val expr = equality() + + if (match(EQUAL)) { + val equals = previous() + val value = assignment() + + if (expr is Variable) { + val name = expr.name + return Assign(name, value) + } + + error(equals, "Invalid assignment target.") + } + + return expr } private fun equality(): Expr { @@ -116,6 +171,7 @@ class Parser(private val tokens: List) { match(TRUE) -> return Literal(true) match(NIL) -> return Literal(null) match(NUMBER, STRING) -> return Literal(previous().literal) + match(IDENTIFIER) -> return Variable(previous()) match(LEFT_PAREN) -> { val expr = expression() consume(RIGHT_PAREN, "Expect ')' after expression.") diff --git a/src/main/fr/celticinfo/lox/Stmt.kt b/src/main/fr/celticinfo/lox/Stmt.kt index 0c495ef..51564fe 100644 --- a/src/main/fr/celticinfo/lox/Stmt.kt +++ b/src/main/fr/celticinfo/lox/Stmt.kt @@ -4,8 +4,10 @@ package fr.celticinfo.lox * The StmtVisitor interface is used to visit the different types of statements that can be parsed by the Parser. */ interface StmtVisitor { + fun visitBlock(stmt: Block): R fun visitExpression(stmt: Expression): R fun visitPrint(stmt: Print): R + fun visitVar(stmt: Var): R } /** @@ -15,6 +17,14 @@ sealed class Stmt { abstract fun accept(visitor: StmtVisitor): R } +data class Block( + val statements: List +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitBlock(this) + } +} + data class Expression( val expression: Expr ) : Stmt() { @@ -29,4 +39,13 @@ data class Print( override fun accept(visitor: StmtVisitor): R { return visitor.visitPrint(this) } +} + +data class Var( + val name: Token, + val initializer: Expr? +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitVar(this) + } } \ No newline at end of file diff --git a/src/main/fr/celticinfo/loxext/RpnPrinter.kt b/src/main/fr/celticinfo/loxext/RpnPrinter.kt index 525e49f..6fb7f79 100644 --- a/src/main/fr/celticinfo/loxext/RpnPrinter.kt +++ b/src/main/fr/celticinfo/loxext/RpnPrinter.kt @@ -7,20 +7,28 @@ class RpnPrinter : ExprVisitor { return expr.accept(this) } - override fun visitBinary(binary: Binary): String { - return stack(binary.operator.lexeme, binary.left, binary.right) + override fun visitAssign(expr: Assign): String { + return stack("=", expr) } - override fun visitGrouping(grouping: Grouping): String { - return stack("", grouping.expression) + override fun visitBinary(expr: Binary): String { + return stack(expr.operator.lexeme, expr.left, expr.right) } - override fun visitLiteral(literal: Literal): String { - return literal.value?.toString() ?: "nil" + override fun visitGrouping(expr: Grouping): String { + return stack("", expr.expression) } - override fun visitUnary(unary: Unary): String { - return stack(unary.operator.lexeme, unary.right) + override fun visitLiteral(expr: Literal): String { + return expr.value?.toString() ?: "nil" + } + + override fun visitUnary(expr: Unary): String { + return stack(expr.operator.lexeme, expr.right) + } + + override fun visitVariable(expr: Variable): String { + return expr.name.lexeme } private fun stack(name: String, vararg exprs: Expr): String { diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index 45d376c..0dff9da 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -1,8 +1,11 @@ package fr.celticinfo.lox import fr.celticinfo.loxext.RpnPrinter -import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertDoesNotThrow +import java.io.ByteArrayOutputStream +import java.io.PrintStream import kotlin.test.* +import kotlin.test.Test class InterpreterTest { @@ -57,4 +60,145 @@ class InterpreterTest { } ) } + + @Test + fun `Shadowing a variable should not raise error`() { + val code = """ + var a = 1; + { + var a = 2; + } + print a; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + assertDoesNotThrow { + Interpreter().interpret(statements) + } + } + + @Test + fun `Variable reassignment should not raise error`() { + val code = """ + var a = 1; + a = 2; + print a; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + assertDoesNotThrow { + Interpreter().interpret(statements) + } + } + + @Test + fun `Variable reassignment with different type should not raise error`() { + val code = """ + var a = 1; + a = false; + print a; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + assertDoesNotThrow { + Interpreter().interpret(statements) + } + } + + @Test + fun `Variable shadowing should not raise error`() { + val code = """ + var a = 1; + { + var a = false; + print a; + } + print a; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + assertDoesNotThrow { + Interpreter().interpret(statements) + } + } + + @Test + fun `Variable shadowing with different type should not raise error`() { + val code = """ + var a = 1; + { + var a = false; + print a; + } + print a; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3, statements.size) + + assertDoesNotThrow { + Interpreter().interpret(statements) + } + } + + @Test + fun `Variable shadowing should work with block`() { + val standardOut = System.out + val outputStreamCaptor = ByteArrayOutputStream() + + System.setOut(PrintStream(outputStreamCaptor)) + + try { + val code = """ +var a = "global a"; +var b = "global b"; +var c = "global c"; +{ + var a = "outer a"; + var b = "outer b"; + { + var a = "inner a"; + print a; + print b; + print c; + } + print a; + print b; + print c; +} +print a; +print b; +print c; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(7, statements.size) + + Interpreter().interpret(statements) + val output = outputStreamCaptor.toString().trim() + assertEquals("inner a\nouter b\nglobal c\nouter a\nouter b\nglobal c\nglobal a\nglobal b\nglobal c", output) + } finally { + System.setOut(standardOut) + } + } } \ No newline at end of file diff --git a/src/test/fr/celticinfo/lox/ParserTest.kt b/src/test/fr/celticinfo/lox/ParserTest.kt index 39a9108..09674a0 100644 --- a/src/test/fr/celticinfo/lox/ParserTest.kt +++ b/src/test/fr/celticinfo/lox/ParserTest.kt @@ -1,8 +1,9 @@ package fr.celticinfo.lox import fr.celticinfo.loxext.RpnPrinter -import org.junit.jupiter.api.Test +import org.junit.jupiter.api.assertAll import kotlin.test.* +import kotlin.test.Test class ParserTest { @Test @@ -29,12 +30,10 @@ class ParserTest { val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - - assertFailsWith( - block = { - parser.parse() - } - ) + val statements = parser.parse() + assertEquals(1,statements.size) + val stmt = statements.first() + assertNull(stmt) } @Test @@ -45,11 +44,86 @@ class ParserTest { val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(1,statements.size) + val stmt = statements.first() + assertNull(stmt) + } + + @Test + fun `valid code with multiple statements`() { + val code = """ + var a = 1; + var b = 2; + print a + b; + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(3,statements.size) + assertAll( + { assertTrue(statements[0] is Var) }, + { assertTrue(statements[1] is Var) }, + { assertTrue(statements[2] is Print) } + ) + } - assertFailsWith( - block = { - parser.parse() + @Test + fun `valid code with block`() { + val code = """ + { + var a = 1; + var b = 2; + print a + b; } + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(1,statements.size) + val stmt = statements.first() + assertTrue(stmt is Block) + val block = stmt.statements + assertEquals(3,block.size) + assertAll( + { assertTrue(block[0] is Var) }, + { assertTrue(block[1] is Var) }, + { assertTrue(block[2] is Print) } + ) + } + + @Test + fun `valid code with block and nested block`() { + val code = """ + { + var a = 1; + { + var b = 2; + print a + b; + } + } + """.trimIndent() + val scanner = Scanner(code) + val tokens = scanner.scanTokens() + val parser = Parser(tokens) + val statements = parser.parse() + assertEquals(1,statements.size) + val stmt = statements.first() + assertTrue(stmt is Block) + val block = stmt.statements + assertEquals(2,block.size) + assertAll( + { assertTrue(block[0] is Var) }, + { assertTrue(block[1] is Block) } + ) + val nestedBlock = block[1] as Block + val nestedStatements = nestedBlock.statements + assertEquals(2,nestedStatements.size) + assertAll( + { assertTrue(nestedStatements[0] is Var) }, + { assertTrue(nestedStatements[1] is Print) } ) } } \ No newline at end of file