You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

262 lines
5.8 KiB
Go

package resolver
import (
"golox/ast"
"golox/errors"
"golox/interpreter"
"golox/token"
)
// FuncType identifies the kind of function body the resolver is currently
// working inside of.
type FuncType int

// The zero value, NoneFuncType, is the state a new Resolver starts in and the
// one that makes a return statement an error.
const (
	// NoneFuncType means the resolver is not inside any function body.
	NoneFuncType FuncType = iota
	// FunctionFuncType marks the body of a function declaration.
	FunctionFuncType
	// MethodFuncType marks the body of a method.
	MethodFuncType
	// InitializerFuncType marks the body of an initializer.
	InitializerFuncType
)
// Resolver is an AST visitor that tracks lexical scopes and tells the
// interpreter how many scopes separate each local variable reference from the
// scope that declares it.
type Resolver struct {
	// errLogger receives the errors found while resolving.
	errLogger errors.Logger

	// interpreter is notified, through its Resolve method, of the scope
	// distance of every local variable that is found.
	interpreter *interpreter.Interpreter

	// scopes is the stack of open scopes, innermost last. Each scope maps a
	// variable name to false while it is only declared and to true once it
	// has been defined.
	scopes []map[string]bool

	// currFunc is the kind of function whose body is being resolved.
	currFunc FuncType
}
// New returns a Resolver that reports scope distances to i and errors to el.
// The resolver starts with an empty scope stack and outside of any function.
func New(i *interpreter.Interpreter, el errors.Logger) *Resolver {
	return &Resolver{
		errLogger:   el,
		interpreter: i,
		scopes:      make([]map[string]bool, 0),
		currFunc:    NoneFuncType,
	}
}
// VisitBlockStmt resolves the statements of a block inside a scope of their
// own; the scope is opened before the first statement and closed after the
// last one. It always returns nil.
func (r *Resolver) VisitBlockStmt(bs *ast.BlockStmt) any {
	r.beginScope()
	for _, stmt := range bs.Statements {
		r.resolveStmt(stmt)
	}
	r.endScope()
	return nil
}
// Resolve resolves every statement in stmts, in order.
func (r *Resolver) Resolve(stmts []ast.Stmt) {
	for i := range stmts {
		r.resolveStmt(stmts[i])
	}
}
// resolveStmt dispatches stmt to the matching Visit method by having it
// accept the resolver as its visitor.
func (r *Resolver) resolveStmt(stmt ast.Stmt) {
	stmt.Accept(r)
}
// resolveExpr dispatches expr to the matching Visit method by having it
// accept the resolver as its visitor.
func (r *Resolver) resolveExpr(expr ast.Expr) {
	expr.Accept(r)
}
// beginScope pushes a new, empty scope; it becomes the innermost one.
func (r *Resolver) beginScope() {
	scope := make(map[string]bool)
	r.scopes = append(r.scopes, scope)
}
// endScope pops the innermost scope. It must be paired with an earlier
// beginScope: slicing an empty stack panics.
func (r *Resolver) endScope() {
	last := len(r.scopes) - 1
	r.scopes = r.scopes[:last]
}
// VisitVarStmt resolves a variable declaration. The name is declared first,
// the initializer (if any) is resolved while the name is still undefined, and
// only then is the name marked as defined. It always returns nil.
func (r *Resolver) VisitVarStmt(vs *ast.VarStmt) any {
	name := &vs.Name
	r.declare(name)
	if initializer := vs.Initializer; initializer != nil {
		r.resolveExpr(initializer)
	}
	r.define(name)
	return nil
}
// declare records name in the innermost scope as declared but not yet
// defined. It does nothing when no scope is open. If the innermost scope
// already contains the name an error is logged, and the entry is still reset
// to the "not defined" state.
func (r *Resolver) declare(name *token.Token) {
	depth := len(r.scopes)
	if depth == 0 {
		return
	}

	innermost := r.scopes[depth-1]
	if _, exists := innermost[name.Lexeme]; exists {
		r.errLogger.ErrorAtToken(*name, "Variable with this name already declared in this scope.")
	}
	innermost[name.Lexeme] = false
}
// define marks name as fully defined in the innermost scope. It does nothing
// when no scope is open.
func (r *Resolver) define(name *token.Token) {
	if depth := len(r.scopes); depth > 0 {
		r.scopes[depth-1][name.Lexeme] = true
	}
}
// VisitVariableExpr resolves a variable read. If the innermost scope holds
// the name in the declared-but-not-defined state, the variable is being read
// from its own initializer and an error is logged. The expression is then
// resolved against the scope stack in every case. It always returns nil.
func (r *Resolver) VisitVariableExpr(ve *ast.VariableExpr) any {
	if depth := len(r.scopes); depth > 0 {
		// A single comma-ok lookup yields both facts: whether the name is
		// present in the scope and whether it has been defined.
		if defined, declared := r.scopes[depth-1][ve.Name.Lexeme]; declared && !defined {
			r.errLogger.ErrorAtToken(ve.Name, "Cannot read local variable in its own initializer.")
		}
	}
	r.resolveLocal(ve, &ve.Name)
	return nil
}
// resolveLocal searches the scope stack from the innermost scope outwards for
// name. On the first match it reports expr to the interpreter together with
// the number of scopes between the innermost scope and the matching one
// (0 for the innermost). If no open scope contains the name, nothing is
// reported.
func (r *Resolver) resolveLocal(expr ast.Expr, name *token.Token) {
	for depth := 0; depth < len(r.scopes); depth++ {
		scope := r.scopes[len(r.scopes)-1-depth]
		if _, found := scope[name.Lexeme]; found {
			r.interpreter.Resolve(expr, depth)
			return
		}
	}
}
// VisitAssignExpr resolves the assigned value first and then the variable
// being assigned to. It always returns nil.
func (r *Resolver) VisitAssignExpr(ae *ast.AssignExpr) any {
	value := ae.Value
	r.resolveExpr(value)

	target := &ae.Name
	r.resolveLocal(ae, target)
	return nil
}
// VisitFunctionStmt resolves a function declaration. The function name is
// declared and defined before the body is resolved, so the name is already
// in scope while the body is processed. It always returns nil.
func (r *Resolver) VisitFunctionStmt(fs *ast.FunctionStmt) any {
	name := &fs.Name
	r.declare(name)
	r.define(name)

	r.resolveFunction(fs, FunctionFuncType)
	return nil
}
// resolveFunction resolves the parameters and body of fs inside a new scope.
// While doing so the resolver's current function kind is set to t; the
// previous kind is put back before returning.
func (r *Resolver) resolveFunction(fs *ast.FunctionStmt, t FuncType) {
	outer := r.currFunc
	r.currFunc = t

	r.beginScope()
	// Parameters share one scope with the statements of the body.
	for i := range fs.Params {
		param := &fs.Params[i]
		r.declare(param)
		r.define(param)
	}
	r.Resolve(fs.Body)
	r.endScope()

	r.currFunc = outer
}
// VisitExpressionStmt resolves the expression wrapped by an expression
// statement. It always returns nil.
func (r *Resolver) VisitExpressionStmt(es *ast.ExpressionStmt) any {
	wrapped := es.Expression
	r.resolveExpr(wrapped)
	return nil
}
// VisitIfStmt resolves the condition, the then branch and, when present, the
// else branch of an if statement. It always returns nil.
func (r *Resolver) VisitIfStmt(is *ast.IfStmt) any {
	r.resolveExpr(is.Condition)
	r.resolveStmt(is.ThenBranch)
	if is.ElseBranch == nil {
		return nil
	}
	r.resolveStmt(is.ElseBranch)
	return nil
}
// VisitPrintStmt resolves the expression whose value is to be printed.
// It always returns nil.
func (r *Resolver) VisitPrintStmt(ps *ast.PrintStmt) any {
	printed := ps.Expression
	r.resolveExpr(printed)
	return nil
}
// VisitReturnStmt logs an error when the return statement is not inside any
// function, and resolves the returned value if there is one. Resolution of
// the value continues even after the error. It always returns nil.
func (r *Resolver) VisitReturnStmt(rs *ast.ReturnStmt) any {
	if r.currFunc == NoneFuncType {
		r.errLogger.ErrorAtToken(rs.Keyword, "Cannot return from top-level code.")
	}

	// A bare return carries nothing to resolve.
	if rs.Value == nil {
		return nil
	}
	r.resolveExpr(rs.Value)
	return nil
}
// VisitWhileStmt resolves the loop condition followed by the loop body.
// It always returns nil.
func (r *Resolver) VisitWhileStmt(ws *ast.WhileStmt) any {
	condition, body := ws.Condition, ws.Body
	r.resolveExpr(condition)
	r.resolveStmt(body)
	return nil
}
// VisitErrorStmt is a no-op: an error statement has nothing to resolve.
// It always returns nil.
func (r *Resolver) VisitErrorStmt(es *ast.ErrorStmt) any {
	return nil
}
// VisitBinaryExpr resolves the left operand and then the right operand of a
// binary expression. It always returns nil.
func (r *Resolver) VisitBinaryExpr(be *ast.BinaryExpr) any {
	lhs, rhs := be.Left, be.Right
	r.resolveExpr(lhs)
	r.resolveExpr(rhs)
	return nil
}
// VisitCallExpr resolves the callee of a call and then each argument in
// order. It always returns nil.
func (r *Resolver) VisitCallExpr(ce *ast.CallExpr) any {
	r.resolveExpr(ce.Callee)
	for _, argument := range ce.Arguments {
		r.resolveExpr(argument)
	}
	return nil
}
// VisitGroupingExpr resolves the expression enclosed by a grouping.
// It always returns nil.
func (r *Resolver) VisitGroupingExpr(ge *ast.GroupingExpr) any {
	inner := ge.Expression
	r.resolveExpr(inner)
	return nil
}
// VisitLiteralExpr is a no-op: a literal has nothing to resolve.
// It always returns nil.
func (r *Resolver) VisitLiteralExpr(le *ast.LiteralExpr) any {
	return nil
}
// VisitLogicalExpr resolves both operands of a logical expression, left
// first; neither operand is skipped. It always returns nil.
func (r *Resolver) VisitLogicalExpr(le *ast.LogicalExpr) any {
	lhs, rhs := le.Left, le.Right
	r.resolveExpr(lhs)
	r.resolveExpr(rhs)
	return nil
}
// VisitUnaryExpr resolves the operand of a unary expression.
// It always returns nil.
func (r *Resolver) VisitUnaryExpr(ue *ast.UnaryExpr) any {
	operand := ue.Right
	r.resolveExpr(operand)
	return nil
}
// VisitErrorExpr is a no-op: an error expression has nothing to resolve.
// It always returns nil.
func (r *Resolver) VisitErrorExpr(ee *ast.ErrorExpr) any {
	return nil
}