Added Stmt class

main
oabrivard 2 years ago
parent cc4aa58bff
commit 72f48b75bb

@ -1,29 +1,20 @@
import java.util.* import java.util.*
val types = listOf( fun defineAst(baseName: String, types: List<String>) {
"Binary : Expr left, Token operator, Expr right",
"Grouping : Expr expression",
"Literal : Any? value",
"Unary : Token operator, Expr right"
)
println("package fr.celticinfo.lox") println("package fr.celticinfo.lox")
println() println()
println("interface ExprVisitor<R> {") println("interface ${baseName}Visitor<R> {")
for (type in types) { for (type in types) {
val parts = type.split(":") val parts = type.split(":")
val name = parts[0].trim() val name = parts[0].trim()
println(" fun visit$name(${name.lowercase(Locale.getDefault())}: $name): R") println(" fun visit$name(${baseName.lowercase(Locale.getDefault())}: $name): R")
} }
println("}") println("}")
println() println()
println("""/** println("sealed class $baseName {")
* The Expr class represents the different types of expressions that can be parsed by the Parser. println(" abstract fun <R> accept(visitor: ${baseName}Visitor<R>): R")
*/""".trimIndent())
println("sealed class Expr {")
println(" abstract fun <R> accept(visitor: ExprVisitor<R>): R")
println("}") println("}")
for (type in types) { for (type in types) {
val parts = type.split(":") val parts = type.split(":")
@ -38,9 +29,26 @@ for (type in types) {
val sep = if (field == fields.last()) "" else "," val sep = if (field == fields.last()) "" else ","
println(" val $fname: $ftype$sep") println(" val $fname: $ftype$sep")
} }
println(") : Expr() {") println(") : ${baseName}() {")
println(" override fun <R> accept(visitor: ExprVisitor<R>): R {") println(" override fun <R> accept(visitor: ${baseName}Visitor<R>): R {")
println(" return visitor.visit$name(this)") println(" return visitor.visit$name(this)")
println(" }") println(" }")
println("}") println("}")
} }
}
val exprTypes = listOf(
"Binary : Expr left, Token operator, Expr right",
"Grouping : Expr expression",
"Literal : Any? value",
"Unary : Token operator, Expr right"
)
defineAst("Expr", exprTypes)
val stmtTypes = listOf(
"Expression : Expr expression",
"Print : Expr expression"
)
defineAst("Stmt", stmtTypes)

@ -1,11 +1,13 @@
package fr.celticinfo.lox package fr.celticinfo.lox
/**
* The ExprVisitor interface is used to visit the different types of expressions that can be parsed by the Parser.
*/
interface ExprVisitor<R> { interface ExprVisitor<R> {
fun visitBinary(binary: Binary): R fun visitBinary(expr: Binary): R
fun visitGrouping(grouping: Grouping): R fun visitGrouping(expr: Grouping): R
fun visitLiteral(literal: Literal): R fun visitLiteral(expr: Literal): R
fun visitUnary(unary: Unary): R fun visitUnary(expr: Unary): R
} }
/** /**

@ -2,41 +2,54 @@ package fr.celticinfo.lox
import fr.celticinfo.lox.TokenType.* import fr.celticinfo.lox.TokenType.*
class Interpreter: ExprVisitor<Any?>{ class Interpreter: ExprVisitor<Any?>, StmtVisitor<Unit>{
fun interpret(expr: Expr): Any? { fun interpret(statements: List<Stmt>) {
return try { try {
val value = evaluate(expr) for (statement in statements) {
println(stringify(value)) execute(statement)
value }
} catch (error: RuntimeError) { } catch (error: RuntimeError) {
Lox.runtimeError(error) Lox.runtimeError(error)
null throw error
} }
} }
private fun execute(stmt: Stmt) {
stmt.accept(this)
}
private fun evaluate(expr: Expr): Any? { private fun evaluate(expr: Expr): Any? {
return expr.accept(this) return expr.accept(this)
} }
override fun visitBinary(binary: Binary): Any? { override fun visitExpression(stmt: Expression) {
val left = evaluate(binary.left) evaluate(stmt.expression)
val right = evaluate(binary.right) }
override fun visitPrint(stmt: Print) {
val value = evaluate(stmt.expression)
println(stringify(value))
}
override fun visitBinary(expr: Binary): Any? {
val left = evaluate(expr.left)
val right = evaluate(expr.right)
return when (binary.operator.type) { return when (expr.operator.type) {
MINUS -> { MINUS -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
left as Double - right as Double left as Double - right as Double
} }
SLASH -> { SLASH -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
if (right == 0.0) { if (right == 0.0) {
throw RuntimeError(binary.operator, "Division by zero") throw RuntimeError(expr.operator, "Division by zero")
} }
left as Double / right as Double left as Double / right as Double
} }
STAR -> { STAR -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
left as Double * right as Double left as Double * right as Double
} }
PLUS -> { PLUS -> {
@ -45,23 +58,23 @@ class Interpreter: ExprVisitor<Any?>{
} else if (left is String && right is String) { } else if (left is String && right is String) {
left + right left + right
} else { } else {
throw RuntimeError(binary.operator, "Operands must be two numbers or two strings") throw RuntimeError(expr.operator, "Operands must be two numbers or two strings")
} }
} }
GREATER -> { GREATER -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
left as Double > right as Double left as Double > right as Double
} }
GREATER_EQUAL -> { GREATER_EQUAL -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
left as Double >= right as Double left as Double >= right as Double
} }
LESS -> { LESS -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
(left as Double) < right as Double (left as Double) < right as Double
} }
LESS_EQUAL -> { LESS_EQUAL -> {
checkNumberOperands(binary.operator, left, right) checkNumberOperands(expr.operator, left, right)
left as Double <= right as Double left as Double <= right as Double
} }
BANG_EQUAL -> return !isEqual(left, right) BANG_EQUAL -> return !isEqual(left, right)
@ -70,20 +83,20 @@ class Interpreter: ExprVisitor<Any?>{
} }
} }
override fun visitGrouping(grouping: Grouping): Any? { override fun visitGrouping(expr: Grouping): Any? {
return evaluate(grouping.expression) return evaluate(expr.expression)
} }
override fun visitLiteral(literal: Literal): Any? { override fun visitLiteral(expr: Literal): Any? {
return literal.value return expr.value
} }
override fun visitUnary(unary: Unary): Any? { override fun visitUnary(expr: Unary): Any? {
val right = evaluate(unary.right) val right = evaluate(expr.right)
return when (unary.operator.type) { return when (expr.operator.type) {
MINUS -> { MINUS -> {
checkNumberOperand(unary.operator, right) checkNumberOperand(expr.operator, right)
-(right as Double) -(right as Double)
} }
BANG -> !isTruthy(right) BANG -> !isTruthy(right)

@ -40,12 +40,12 @@ object Lox {
val scanner = Scanner(source) val scanner = Scanner(source)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expression = parser.parse() val statements = parser.parse()
// Stop if there was a syntax error. // Stop if there was a syntax error.
if (hadError) return if (hadError) return
interpreter.interpret(expression!!) interpreter.interpret(statements)
} }
fun error(line: Int, s: String) { fun error(line: Int, s: String) {

@ -2,6 +2,7 @@ package fr.celticinfo.lox
import fr.celticinfo.lox.TokenType.* import fr.celticinfo.lox.TokenType.*
/** /**
* The Parser class is responsible for parsing the tokens produced by the Scanner into an abstract syntax tree (AST). * The Parser class is responsible for parsing the tokens produced by the Scanner into an abstract syntax tree (AST).
* It is a recursive descent parser that uses a series of methods to parse different parts of the grammar. * It is a recursive descent parser that uses a series of methods to parse different parts of the grammar.
@ -9,18 +10,32 @@ import fr.celticinfo.lox.TokenType.*
class Parser(private val tokens: List<Token>) { class Parser(private val tokens: List<Token>) {
private var current = 0 private var current = 0
fun parse(): Expr? { fun parse(): List<Stmt> {
return try { val statements: MutableList<Stmt> = ArrayList()
val result = expression() while (!isAtEnd()) {
statements.add(statement())
}
if (!isAtEnd()) { return statements
throw error(peek(), "Expect end of expression.")
} }
result private fun statement(): Stmt {
} catch (error: ParseError) { return when {
null match(PRINT) -> printStatement()
else -> expressionStatement()
}
} }
private fun printStatement(): Stmt {
val value = expression()
consume(SEMICOLON, "Expect ';' after value.")
return Print(value)
}
private fun expressionStatement(): Stmt {
val value = expression()
consume(SEMICOLON, "Expect ';' after value.")
return Expression(value)
} }
private fun expression(): Expr { private fun expression(): Expr {

@ -0,0 +1,32 @@
package fr.celticinfo.lox
/**
* The StmtVisitor interface is used to visit the different types of statements that can be parsed by the Parser.
*/
interface StmtVisitor<R> {
fun visitExpression(stmt: Expression): R
fun visitPrint(stmt: Print): R
}
/**
* The Stmt class represents the different types of statements that can be parsed by the Parser.
*/
sealed class Stmt {
abstract fun <R> accept(visitor: StmtVisitor<R>): R
}
data class Expression(
val expression: Expr
) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R {
return visitor.visitExpression(this)
}
}
data class Print(
val expression: Expr
) : Stmt() {
override fun <R> accept(visitor: StmtVisitor<R>): R {
return visitor.visitPrint(this)
}
}

@ -1,46 +1,60 @@
package fr.celticinfo.lox package fr.celticinfo.lox
import fr.celticinfo.loxext.RpnPrinter
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Assertions.* import kotlin.test.*
class InterpreterTest { class InterpreterTest {
@Test @Test
fun `validate interpreter`() { fun `validate interpreter`() {
val code = """ val code = """
(1 + 2 * 3 - 5) / 2 1 + 2 * 3 - 4 / 5;
""".trimIndent() """.trimIndent()
val scanner = Scanner(code) val scanner = Scanner(code)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expr = parser.parse() val statements = parser.parse()
val value = Interpreter().interpret(expr!!) assertEquals(1, statements.size)
assertEquals(1.0, value) val stmt = statements.first()
assertTrue(stmt is Expression)
val expr = stmt.expression
assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr))
} }
@Test @Test
fun `Division by zero should raise error`() { fun `Division by zero should raise error`() {
val code = """ val code = """
1 / 0 1 / 0;
""".trimIndent() """.trimIndent()
val scanner = Scanner(code) val scanner = Scanner(code)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expr = parser.parse() val statements = parser.parse()
val value = Interpreter().interpret(expr!!) assertEquals(1, statements.size)
assertNull(value)
assertFailsWith<RuntimeError>(
block = {
Interpreter().interpret(statements)
}
)
} }
@Test @Test
fun `Invalid type raise error`() { fun `Invalid type raise error`() {
val code = """ val code = """
1 + false 1 + false;
""".trimIndent() """.trimIndent()
val scanner = Scanner(code) val scanner = Scanner(code)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expr = parser.parse() val statements = parser.parse()
val value = Interpreter().interpret(expr!!) assertEquals(1, statements.size)
assertNull(value)
assertFailsWith<RuntimeError>(
block = {
Interpreter().interpret(statements)
}
)
} }
} }

@ -2,20 +2,23 @@ package fr.celticinfo.lox
import fr.celticinfo.loxext.RpnPrinter import fr.celticinfo.loxext.RpnPrinter
import org.junit.jupiter.api.Test import org.junit.jupiter.api.Test
import org.junit.jupiter.api.Assertions.* import kotlin.test.*
class ParserTest { class ParserTest {
@Test @Test
fun `validate parser`() { fun `validate parser`() {
val code = """ val code = """
1 + 2 * 3 - 4 / 5 1 + 2 * 3 - 4 / 5;
""".trimIndent() """.trimIndent()
val scanner = Scanner(code) val scanner = Scanner(code)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expr = parser.parse() val statements = parser.parse()
assertNotNull(expr) assertEquals(1,statements.size)
assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr!!)) val stmt = statements.first()
assertTrue(stmt is Expression)
val expr = stmt.expression
assertEquals("1.0 2.0 3.0 * + 4.0 5.0 / -", RpnPrinter().print(expr))
} }
@Test @Test
@ -26,8 +29,12 @@ class ParserTest {
val scanner = Scanner(code) val scanner = Scanner(code)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expr = parser.parse()
assertNull(expr) assertFailsWith<Parser.ParseError>(
block = {
parser.parse()
}
)
} }
@Test @Test
@ -38,7 +45,11 @@ class ParserTest {
val scanner = Scanner(code) val scanner = Scanner(code)
val tokens = scanner.scanTokens() val tokens = scanner.scanTokens()
val parser = Parser(tokens) val parser = Parser(tokens)
val expr = parser.parse()
assertNull(expr) assertFailsWith<Parser.ParseError>(
block = {
parser.parse()
}
)
} }
} }
Loading…
Cancel
Save