Added Stmt class
parent
cc4aa58bff
commit
72f48b75bb
@ -1,46 +1,54 @@
|
|||||||
import java.util.*
|
import java.util.*
|
||||||
|
|
||||||
val types = listOf(
|
fun defineAst(baseName: String, types: List<String>) {
|
||||||
|
println("package fr.celticinfo.lox")
|
||||||
|
|
||||||
|
println()
|
||||||
|
println("interface ${baseName}Visitor<R> {")
|
||||||
|
for (type in types) {
|
||||||
|
val parts = type.split(":")
|
||||||
|
val name = parts[0].trim()
|
||||||
|
println(" fun visit$name(${baseName.lowercase(Locale.getDefault())}: $name): R")
|
||||||
|
}
|
||||||
|
println("}")
|
||||||
|
|
||||||
|
println()
|
||||||
|
println("sealed class $baseName {")
|
||||||
|
println(" abstract fun <R> accept(visitor: ${baseName}Visitor<R>): R")
|
||||||
|
println("}")
|
||||||
|
for (type in types) {
|
||||||
|
val parts = type.split(":")
|
||||||
|
val name = parts[0].trim()
|
||||||
|
val fields = parts[1].trim().split(",").map { it.trim() }
|
||||||
|
println()
|
||||||
|
println("data class $name(")
|
||||||
|
for (field in fields) {
|
||||||
|
val fparts = field.split(" ")
|
||||||
|
val ftype = fparts[0]
|
||||||
|
val fname = fparts[1]
|
||||||
|
val sep = if (field == fields.last()) "" else ","
|
||||||
|
println(" val $fname: $ftype$sep")
|
||||||
|
}
|
||||||
|
println(") : ${baseName}() {")
|
||||||
|
println(" override fun <R> accept(visitor: ${baseName}Visitor<R>): R {")
|
||||||
|
println(" return visitor.visit$name(this)")
|
||||||
|
println(" }")
|
||||||
|
println("}")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
val exprTypes = listOf(
|
||||||
"Binary : Expr left, Token operator, Expr right",
|
"Binary : Expr left, Token operator, Expr right",
|
||||||
"Grouping : Expr expression",
|
"Grouping : Expr expression",
|
||||||
"Literal : Any? value",
|
"Literal : Any? value",
|
||||||
"Unary : Token operator, Expr right"
|
"Unary : Token operator, Expr right"
|
||||||
)
|
)
|
||||||
|
defineAst("Expr", exprTypes)
|
||||||
|
|
||||||
println("package fr.celticinfo.lox")
|
val stmtTypes = listOf(
|
||||||
|
"Expression : Expr expression",
|
||||||
println()
|
"Print : Expr expression"
|
||||||
println("interface ExprVisitor<R> {")
|
)
|
||||||
for (type in types) {
|
defineAst("Stmt", stmtTypes)
|
||||||
val parts = type.split(":")
|
|
||||||
val name = parts[0].trim()
|
|
||||||
println(" fun visit$name(${name.lowercase(Locale.getDefault())}: $name): R")
|
|
||||||
}
|
|
||||||
println("}")
|
|
||||||
|
|
||||||
println()
|
|
||||||
println("""/**
|
|
||||||
* The Expr class represents the different types of expressions that can be parsed by the Parser.
|
|
||||||
*/""".trimIndent())
|
|
||||||
println("sealed class Expr {")
|
|
||||||
println(" abstract fun <R> accept(visitor: ExprVisitor<R>): R")
|
|
||||||
println("}")
|
|
||||||
for (type in types) {
|
|
||||||
val parts = type.split(":")
|
|
||||||
val name = parts[0].trim()
|
|
||||||
val fields = parts[1].trim().split(",").map { it.trim() }
|
|
||||||
println()
|
|
||||||
println("data class $name(")
|
|
||||||
for (field in fields) {
|
|
||||||
val fparts = field.split(" ")
|
|
||||||
val ftype = fparts[0]
|
|
||||||
val fname = fparts[1]
|
|
||||||
val sep = if (field == fields.last()) "" else ","
|
|
||||||
println(" val $fname: $ftype$sep")
|
|
||||||
}
|
|
||||||
println(") : Expr() {")
|
|
||||||
println(" override fun <R> accept(visitor: ExprVisitor<R>): R {")
|
|
||||||
println(" return visitor.visit$name(this)")
|
|
||||||
println(" }")
|
|
||||||
println("}")
|
|
||||||
}
|
|
||||||
|
|||||||
@ -0,0 +1,32 @@
|
|||||||
|
package fr.celticinfo.lox
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The StmtVisitor interface is used to visit the different types of statements that can be parsed by the Parser.
|
||||||
|
*/
|
||||||
|
interface StmtVisitor<R> {
|
||||||
|
fun visitExpression(stmt: Expression): R
|
||||||
|
fun visitPrint(stmt: Print): R
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The Stmt class represents the different types of statements that can be parsed by the Parser.
|
||||||
|
*/
|
||||||
|
sealed class Stmt {
|
||||||
|
abstract fun <R> accept(visitor: StmtVisitor<R>): R
|
||||||
|
}
|
||||||
|
|
||||||
|
data class Expression(
|
||||||
|
val expression: Expr
|
||||||
|
) : Stmt() {
|
||||||
|
override fun <R> accept(visitor: StmtVisitor<R>): R {
|
||||||
|
return visitor.visitExpression(this)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data class Print(
|
||||||
|
val expression: Expr
|
||||||
|
) : Stmt() {
|
||||||
|
override fun <R> accept(visitor: StmtVisitor<R>): R {
|
||||||
|
return visitor.visitPrint(this)
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,46 +1,60 @@
|
|||||||
package fr.celticinfo.lox
|
package fr.celticinfo.lox
|
||||||
|
|
||||||
|
import fr.celticinfo.loxext.RpnPrinter
|
||||||
import org.junit.jupiter.api.Test
|
import org.junit.jupiter.api.Test
|
||||||
import org.junit.jupiter.api.Assertions.*
|
import kotlin.test.*
|
||||||
|
|
||||||
class InterpreterTest {
|
class InterpreterTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `validate interpreter`() {
|
fun `validate interpreter`() {
|
||||||
val code = """
|
val code = """
|
||||||
(1 + 2 * 3 - 5) / 2
|
1 + 2 * 3 - 4 / 5;
|
||||||
""".trimIndent()
|
""".trimIndent()
|
||||||
val scanner = Scanner(code)
|
val scanner = Scanner(code)
|
||||||
val tokens = scanner.scanTokens()
|
val tokens = scanner.scanTokens()
|
||||||
val parser = Parser(tokens)
|
val parser = Parser(tokens)
|
||||||
val expr = parser.parse()
|
val statements = parser.parse()
|
||||||
val value = Interpreter().interpret(expr!!)
|
assertEquals(1, statements.size)
|
||||||
assertEquals(1.0, value)
|
val stmt = statements.first()
|
||||||
|
assertTrue(stmt is Expression)
|
||||||
|
val expr = stmt.expression
|
||||||
|
assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr))
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `Division by zero should raise error`() {
|
fun `Division by zero should raise error`() {
|
||||||
val code = """
|
val code = """
|
||||||
1 / 0
|
1 / 0;
|
||||||
""".trimIndent()
|
""".trimIndent()
|
||||||
val scanner = Scanner(code)
|
val scanner = Scanner(code)
|
||||||
val tokens = scanner.scanTokens()
|
val tokens = scanner.scanTokens()
|
||||||
val parser = Parser(tokens)
|
val parser = Parser(tokens)
|
||||||
val expr = parser.parse()
|
val statements = parser.parse()
|
||||||
val value = Interpreter().interpret(expr!!)
|
assertEquals(1, statements.size)
|
||||||
assertNull(value)
|
|
||||||
|
assertFailsWith<RuntimeError>(
|
||||||
|
block = {
|
||||||
|
Interpreter().interpret(statements)
|
||||||
|
}
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
fun `Invalid type raise error`() {
|
fun `Invalid type raise error`() {
|
||||||
val code = """
|
val code = """
|
||||||
1 + false
|
1 + false;
|
||||||
""".trimIndent()
|
""".trimIndent()
|
||||||
val scanner = Scanner(code)
|
val scanner = Scanner(code)
|
||||||
val tokens = scanner.scanTokens()
|
val tokens = scanner.scanTokens()
|
||||||
val parser = Parser(tokens)
|
val parser = Parser(tokens)
|
||||||
val expr = parser.parse()
|
val statements = parser.parse()
|
||||||
val value = Interpreter().interpret(expr!!)
|
assertEquals(1, statements.size)
|
||||||
assertNull(value)
|
|
||||||
|
assertFailsWith<RuntimeError>(
|
||||||
|
block = {
|
||||||
|
Interpreter().interpret(statements)
|
||||||
|
}
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Loading…
Reference in New Issue