package resolver import ( "golox/ast" "golox/errors" "golox/interpreter" "golox/token" ) // FuncType represents the type of a function. type FuncType int const ( // NoneFuncType represents a function with no type. NoneFuncType FuncType = iota // FunctionFuncType represents a function. FunctionFuncType // MethodFuncType represents a method. MethodFuncType // InitializerFuncType represents an initializer. InitializerFuncType ) // Resolver implements variable resolution for the AST. type Resolver struct { errLogger errors.Logger interpreter *interpreter.Interpreter scopes []map[string]bool currFunc FuncType } // New creates a new Resolver. func New(i *interpreter.Interpreter, el errors.Logger) *Resolver { return &Resolver{el, i, []map[string]bool{}, NoneFuncType} } // VisitBlockStmt visits a block statement. func (r *Resolver) VisitBlockStmt(bs *ast.BlockStmt) any { r.beginScope() r.Resolve(bs.Statements) r.endScope() return nil } // Resolve resolves the variables for each statements in stmts. func (r *Resolver) Resolve(stmts []ast.Stmt) { for _, stmt := range stmts { r.resolveStmt(stmt) } } // resolveStmt resolves a statement. func (r *Resolver) resolveStmt(stmt ast.Stmt) { stmt.Accept(r) } // resolveExpr resolves an expression. func (r *Resolver) resolveExpr(expr ast.Expr) { expr.Accept(r) } // beginScope begins a new scope. func (r *Resolver) beginScope() { r.scopes = append(r.scopes, map[string]bool{}) } // endScope ends the current scope. func (r *Resolver) endScope() { r.scopes = r.scopes[:len(r.scopes)-1] } // visitVarStmt visits a variable statement. func (r *Resolver) VisitVarStmt(vs *ast.VarStmt) any { r.declare(&vs.Name) if vs.Initializer != nil { r.resolveExpr(vs.Initializer) } r.define(&vs.Name) return nil } // declare declares a variable in the current scope. func (r *Resolver) declare(name *token.Token) { if len(r.scopes) == 0 { return } scope := r.scopes[len(r.scopes)-1] if _, ok := scope[name.Lexeme]; ok { r.errLogger.ErrorAtToken(*name, "Variable with this name already declared in this scope.") } scope[name.Lexeme] = false } // define defines a variable in the current scope. func (r *Resolver) define(name *token.Token) { if len(r.scopes) == 0 { return } r.scopes[len(r.scopes)-1][name.Lexeme] = true } // VisitVariableExpr visits a variable expression. func (r *Resolver) VisitVariableExpr(ve *ast.VariableExpr) any { if len(r.scopes) > 0 { if _, ok := r.scopes[len(r.scopes)-1][ve.Name.Lexeme]; ok && !r.scopes[len(r.scopes)-1][ve.Name.Lexeme] { r.errLogger.ErrorAtToken(ve.Name, "Cannot read local variable in its own initializer.") } } r.resolveLocal(ve, &ve.Name) return nil } // resolveLocal resolves a local variable. func (r *Resolver) resolveLocal(expr ast.Expr, name *token.Token) { for i := len(r.scopes) - 1; i >= 0; i-- { if _, ok := r.scopes[i][name.Lexeme]; ok { r.interpreter.Resolve(expr, len(r.scopes)-1-i) return } } } // VisitAssignExpr visits an assignment expression. func (r *Resolver) VisitAssignExpr(ae *ast.AssignExpr) any { r.resolveExpr(ae.Value) r.resolveLocal(ae, &ae.Name) return nil } // VisitFunctionStmt visits a function statement. func (r *Resolver) VisitFunctionStmt(fs *ast.FunctionStmt) any { r.declare(&fs.Name) r.define(&fs.Name) r.resolveFunction(fs, FunctionFuncType) return nil } // resolveFunction resolves a function. func (r *Resolver) resolveFunction(fs *ast.FunctionStmt, t FuncType) { enclosingFunc := r.currFunc r.currFunc = t r.beginScope() for _, param := range fs.Params { r.declare(¶m) r.define(¶m) } r.Resolve(fs.Body) r.endScope() r.currFunc = enclosingFunc } // VisitExpressionStmt visits an expression statement. func (r *Resolver) VisitExpressionStmt(es *ast.ExpressionStmt) any { r.resolveExpr(es.Expression) return nil } // VisitIfStmt visits an if statement. func (r *Resolver) VisitIfStmt(is *ast.IfStmt) any { r.resolveExpr(is.Condition) r.resolveStmt(is.ThenBranch) if is.ElseBranch != nil { r.resolveStmt(is.ElseBranch) } return nil } // VisitPrintStmt visits a print statement. func (r *Resolver) VisitPrintStmt(ps *ast.PrintStmt) any { r.resolveExpr(ps.Expression) return nil } // VisitReturnStmt visits a return statement. func (r *Resolver) VisitReturnStmt(rs *ast.ReturnStmt) any { if r.currFunc == NoneFuncType { r.errLogger.ErrorAtToken(rs.Keyword, "Cannot return from top-level code.") } if rs.Value != nil { r.resolveExpr(rs.Value) } return nil } // VisitWhileStmt visits a while statement. func (r *Resolver) VisitWhileStmt(ws *ast.WhileStmt) any { r.resolveExpr(ws.Condition) r.resolveStmt(ws.Body) return nil } // VisitErrorStmt visits an error statement. func (r *Resolver) VisitErrorStmt(es *ast.ErrorStmt) any { return nil } // VisitBinaryExpr visits a binary expression. func (r *Resolver) VisitBinaryExpr(be *ast.BinaryExpr) any { r.resolveExpr(be.Left) r.resolveExpr(be.Right) return nil } // VisitCallExpr visits a call expression. func (r *Resolver) VisitCallExpr(ce *ast.CallExpr) any { r.resolveExpr(ce.Callee) for _, arg := range ce.Arguments { r.resolveExpr(arg) } return nil } // VisitGroupingExpr visits a grouping expression. func (r *Resolver) VisitGroupingExpr(ge *ast.GroupingExpr) any { r.resolveExpr(ge.Expression) return nil } // VisitLiteralExpr visits a literal expression. func (r *Resolver) VisitLiteralExpr(le *ast.LiteralExpr) any { return nil } // VisitLogicalExpr visits a logical expression. func (r *Resolver) VisitLogicalExpr(le *ast.LogicalExpr) any { r.resolveExpr(le.Left) r.resolveExpr(le.Right) return nil } // VisitUnaryExpr visits a unary expression. func (r *Resolver) VisitUnaryExpr(ue *ast.UnaryExpr) any { r.resolveExpr(ue.Right) return nil } // VisitErrorExpr visits an error statement. func (r *Resolver) VisitErrorExpr(ee *ast.ErrorExpr) any { return nil }