From 72f48b75bbbcbad256563811bfb582a673226788 Mon Sep 17 00:00:00 2001 From: oabrivard Date: Sun, 7 Jul 2024 16:29:32 +0200 Subject: [PATCH] Added Stmt class --- src/main/ExprGenerator.ws.kts | 84 ++++++++++--------- src/main/fr/celticinfo/lox/Expr.kt | 12 +-- src/main/fr/celticinfo/lox/Interpreter.kt | 69 ++++++++------- src/main/fr/celticinfo/lox/Lox.kt | 4 +- src/main/fr/celticinfo/lox/Parser.kt | 33 ++++++-- src/main/fr/celticinfo/lox/Stmt.kt | 32 +++++++ src/test/fr/celticinfo/lox/InterpreterTest.kt | 40 ++++++--- src/test/fr/celticinfo/lox/ParserTest.kt | 29 +++++-- 8 files changed, 199 insertions(+), 104 deletions(-) create mode 100644 src/main/fr/celticinfo/lox/Stmt.kt diff --git a/src/main/ExprGenerator.ws.kts b/src/main/ExprGenerator.ws.kts index d3ad083..d68f7f4 100644 --- a/src/main/ExprGenerator.ws.kts +++ b/src/main/ExprGenerator.ws.kts @@ -1,46 +1,54 @@ import java.util.* -val types = listOf( +fun defineAst(baseName: String, types: List) { + println("package fr.celticinfo.lox") + + println() + println("interface ${baseName}Visitor {") + for (type in types) { + val parts = type.split(":") + val name = parts[0].trim() + println(" fun visit$name(${baseName.lowercase(Locale.getDefault())}: $name): R") + } + println("}") + + println() + println("sealed class $baseName {") + println(" abstract fun accept(visitor: ${baseName}Visitor): R") + println("}") + for (type in types) { + val parts = type.split(":") + val name = parts[0].trim() + val fields = parts[1].trim().split(",").map { it.trim() } + println() + println("data class $name(") + for (field in fields) { + val fparts = field.split(" ") + val ftype = fparts[0] + val fname = fparts[1] + val sep = if (field == fields.last()) "" else "," + println(" val $fname: $ftype$sep") + } + println(") : ${baseName}() {") + println(" override fun accept(visitor: ${baseName}Visitor): R {") + println(" return visitor.visit$name(this)") + println(" }") + println("}") + } + + +} + +val exprTypes = listOf( "Binary : Expr left, Token operator, Expr right", "Grouping : Expr expression", "Literal : Any? value", "Unary : Token operator, Expr right" ) +defineAst("Expr", exprTypes) -println("package fr.celticinfo.lox") - -println() -println("interface ExprVisitor {") -for (type in types) { - val parts = type.split(":") - val name = parts[0].trim() - println(" fun visit$name(${name.lowercase(Locale.getDefault())}: $name): R") -} -println("}") - -println() -println("""/** - * The Expr class represents the different types of expressions that can be parsed by the Parser. - */""".trimIndent()) -println("sealed class Expr {") -println(" abstract fun accept(visitor: ExprVisitor): R") -println("}") -for (type in types) { - val parts = type.split(":") - val name = parts[0].trim() - val fields = parts[1].trim().split(",").map { it.trim() } - println() - println("data class $name(") - for (field in fields) { - val fparts = field.split(" ") - val ftype = fparts[0] - val fname = fparts[1] - val sep = if (field == fields.last()) "" else "," - println(" val $fname: $ftype$sep") - } - println(") : Expr() {") - println(" override fun accept(visitor: ExprVisitor): R {") - println(" return visitor.visit$name(this)") - println(" }") - println("}") -} +val stmtTypes = listOf( + "Expression : Expr expression", + "Print : Expr expression" +) +defineAst("Stmt", stmtTypes) diff --git a/src/main/fr/celticinfo/lox/Expr.kt b/src/main/fr/celticinfo/lox/Expr.kt index 596ce3d..ca40780 100644 --- a/src/main/fr/celticinfo/lox/Expr.kt +++ b/src/main/fr/celticinfo/lox/Expr.kt @@ -1,11 +1,13 @@ package fr.celticinfo.lox +/** + * The ExprVisitor interface is used to visit the different types of expressions that can be parsed by the Parser. + */ interface ExprVisitor { - fun visitBinary(binary: Binary): R - fun visitGrouping(grouping: Grouping): R - fun visitLiteral(literal: Literal): R - fun visitUnary(unary: Unary): R - + fun visitBinary(expr: Binary): R + fun visitGrouping(expr: Grouping): R + fun visitLiteral(expr: Literal): R + fun visitUnary(expr: Unary): R } /** diff --git a/src/main/fr/celticinfo/lox/Interpreter.kt b/src/main/fr/celticinfo/lox/Interpreter.kt index 77207f6..34c1a77 100644 --- a/src/main/fr/celticinfo/lox/Interpreter.kt +++ b/src/main/fr/celticinfo/lox/Interpreter.kt @@ -2,41 +2,54 @@ package fr.celticinfo.lox import fr.celticinfo.lox.TokenType.* -class Interpreter: ExprVisitor{ +class Interpreter: ExprVisitor, StmtVisitor{ - fun interpret(expr: Expr): Any? { - return try { - val value = evaluate(expr) - println(stringify(value)) - value + fun interpret(statements: List) { + try { + for (statement in statements) { + execute(statement) + } } catch (error: RuntimeError) { Lox.runtimeError(error) - null + throw error } } + private fun execute(stmt: Stmt) { + stmt.accept(this) + } + private fun evaluate(expr: Expr): Any? { return expr.accept(this) } - override fun visitBinary(binary: Binary): Any? { - val left = evaluate(binary.left) - val right = evaluate(binary.right) + override fun visitExpression(stmt: Expression) { + evaluate(stmt.expression) + } + + override fun visitPrint(stmt: Print) { + val value = evaluate(stmt.expression) + println(stringify(value)) + } + + override fun visitBinary(expr: Binary): Any? { + val left = evaluate(expr.left) + val right = evaluate(expr.right) - return when (binary.operator.type) { + return when (expr.operator.type) { MINUS -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) left as Double - right as Double } SLASH -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) if (right == 0.0) { - throw RuntimeError(binary.operator, "Division by zero") + throw RuntimeError(expr.operator, "Division by zero") } left as Double / right as Double } STAR -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) left as Double * right as Double } PLUS -> { @@ -45,23 +58,23 @@ class Interpreter: ExprVisitor{ } else if (left is String && right is String) { left + right } else { - throw RuntimeError(binary.operator, "Operands must be two numbers or two strings") + throw RuntimeError(expr.operator, "Operands must be two numbers or two strings") } } GREATER -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) left as Double > right as Double } GREATER_EQUAL -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) left as Double >= right as Double } LESS -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) (left as Double) < right as Double } LESS_EQUAL -> { - checkNumberOperands(binary.operator, left, right) + checkNumberOperands(expr.operator, left, right) left as Double <= right as Double } BANG_EQUAL -> return !isEqual(left, right) @@ -70,20 +83,20 @@ class Interpreter: ExprVisitor{ } } - override fun visitGrouping(grouping: Grouping): Any? { - return evaluate(grouping.expression) + override fun visitGrouping(expr: Grouping): Any? { + return evaluate(expr.expression) } - override fun visitLiteral(literal: Literal): Any? { - return literal.value + override fun visitLiteral(expr: Literal): Any? { + return expr.value } - override fun visitUnary(unary: Unary): Any? { - val right = evaluate(unary.right) + override fun visitUnary(expr: Unary): Any? { + val right = evaluate(expr.right) - return when (unary.operator.type) { + return when (expr.operator.type) { MINUS -> { - checkNumberOperand(unary.operator, right) + checkNumberOperand(expr.operator, right) -(right as Double) } BANG -> !isTruthy(right) diff --git a/src/main/fr/celticinfo/lox/Lox.kt b/src/main/fr/celticinfo/lox/Lox.kt index b102902..7d3d94e 100644 --- a/src/main/fr/celticinfo/lox/Lox.kt +++ b/src/main/fr/celticinfo/lox/Lox.kt @@ -40,12 +40,12 @@ object Lox { val scanner = Scanner(source) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expression = parser.parse() + val statements = parser.parse() // Stop if there was a syntax error. if (hadError) return - interpreter.interpret(expression!!) + interpreter.interpret(statements) } fun error(line: Int, s: String) { diff --git a/src/main/fr/celticinfo/lox/Parser.kt b/src/main/fr/celticinfo/lox/Parser.kt index 6f4b09d..fd510e0 100644 --- a/src/main/fr/celticinfo/lox/Parser.kt +++ b/src/main/fr/celticinfo/lox/Parser.kt @@ -2,6 +2,7 @@ package fr.celticinfo.lox import fr.celticinfo.lox.TokenType.* + /** * The Parser class is responsible for parsing the tokens produced by the Scanner into an abstract syntax tree (AST). * It is a recursive descent parser that uses a series of methods to parse different parts of the grammar. @@ -9,20 +10,34 @@ import fr.celticinfo.lox.TokenType.* class Parser(private val tokens: List) { private var current = 0 - fun parse(): Expr? { - return try { - val result = expression() + fun parse(): List { + val statements: MutableList = ArrayList() + while (!isAtEnd()) { + statements.add(statement()) + } - if (!isAtEnd()) { - throw error(peek(), "Expect end of expression.") - } + return statements + } - result - } catch (error: ParseError) { - null + private fun statement(): Stmt { + return when { + match(PRINT) -> printStatement() + else -> expressionStatement() } } + private fun printStatement(): Stmt { + val value = expression() + consume(SEMICOLON, "Expect ';' after value.") + return Print(value) + } + + private fun expressionStatement(): Stmt { + val value = expression() + consume(SEMICOLON, "Expect ';' after value.") + return Expression(value) + } + private fun expression(): Expr { return equality() } diff --git a/src/main/fr/celticinfo/lox/Stmt.kt b/src/main/fr/celticinfo/lox/Stmt.kt new file mode 100644 index 0000000..0c495ef --- /dev/null +++ b/src/main/fr/celticinfo/lox/Stmt.kt @@ -0,0 +1,32 @@ +package fr.celticinfo.lox + +/** + * The StmtVisitor interface is used to visit the different types of statements that can be parsed by the Parser. + */ +interface StmtVisitor { + fun visitExpression(stmt: Expression): R + fun visitPrint(stmt: Print): R +} + +/** + * The Stmt class represents the different types of statements that can be parsed by the Parser. + */ +sealed class Stmt { + abstract fun accept(visitor: StmtVisitor): R +} + +data class Expression( + val expression: Expr +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitExpression(this) + } +} + +data class Print( + val expression: Expr +) : Stmt() { + override fun accept(visitor: StmtVisitor): R { + return visitor.visitPrint(this) + } +} \ No newline at end of file diff --git a/src/test/fr/celticinfo/lox/InterpreterTest.kt b/src/test/fr/celticinfo/lox/InterpreterTest.kt index 811d8e8..45d376c 100644 --- a/src/test/fr/celticinfo/lox/InterpreterTest.kt +++ b/src/test/fr/celticinfo/lox/InterpreterTest.kt @@ -1,46 +1,60 @@ package fr.celticinfo.lox +import fr.celticinfo.loxext.RpnPrinter import org.junit.jupiter.api.Test -import org.junit.jupiter.api.Assertions.* +import kotlin.test.* class InterpreterTest { @Test fun `validate interpreter`() { val code = """ - (1 + 2 * 3 - 5) / 2 + 1 + 2 * 3 - 4 / 5; """.trimIndent() val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expr = parser.parse() - val value = Interpreter().interpret(expr!!) - assertEquals(1.0, value) + val statements = parser.parse() + assertEquals(1, statements.size) + val stmt = statements.first() + assertTrue(stmt is Expression) + val expr = stmt.expression + assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr)) } @Test fun `Division by zero should raise error`() { val code = """ - 1 / 0 + 1 / 0; """.trimIndent() val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expr = parser.parse() - val value = Interpreter().interpret(expr!!) - assertNull(value) + val statements = parser.parse() + assertEquals(1, statements.size) + + assertFailsWith( + block = { + Interpreter().interpret(statements) + } + ) } @Test fun `Invalid type raise error`() { val code = """ - 1 + false + 1 + false; """.trimIndent() val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expr = parser.parse() - val value = Interpreter().interpret(expr!!) - assertNull(value) + val statements = parser.parse() + assertEquals(1, statements.size) + + assertFailsWith( + block = { + Interpreter().interpret(statements) + } + ) } } \ No newline at end of file diff --git a/src/test/fr/celticinfo/lox/ParserTest.kt b/src/test/fr/celticinfo/lox/ParserTest.kt index d5c70f3..39a9108 100644 --- a/src/test/fr/celticinfo/lox/ParserTest.kt +++ b/src/test/fr/celticinfo/lox/ParserTest.kt @@ -2,20 +2,23 @@ package fr.celticinfo.lox import fr.celticinfo.loxext.RpnPrinter import org.junit.jupiter.api.Test -import org.junit.jupiter.api.Assertions.* +import kotlin.test.* class ParserTest { @Test fun `validate parser`() { val code = """ - 1 + 2 * 3 - 4 / 5 + 1 + 2 * 3 - 4 / 5; """.trimIndent() val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expr = parser.parse() - assertNotNull(expr) - assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr!!)) + val statements = parser.parse() + assertEquals(1,statements.size) + val stmt = statements.first() + assertTrue(stmt is Expression) + val expr = stmt.expression + assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr)) } @Test @@ -26,8 +29,12 @@ class ParserTest { val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expr = parser.parse() - assertNull(expr) + + assertFailsWith( + block = { + parser.parse() + } + ) } @Test @@ -38,7 +45,11 @@ class ParserTest { val scanner = Scanner(code) val tokens = scanner.scanTokens() val parser = Parser(tokens) - val expr = parser.parse() - assertNull(expr) + + assertFailsWith( + block = { + parser.parse() + } + ) } } \ No newline at end of file