package proxy

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gofiber/fiber/v3"
	"knowfoolery/backend/shared/infra/auth/zitadel"
	"knowfoolery/backend/shared/infra/observability/logging"
	"knowfoolery/backend/shared/infra/utils/httputil"
)

// ReverseProxy relays inbound requests to upstream services through a shared
// HTTP client. The logger is optional; the code that uses it checks for nil.
type ReverseProxy struct {
	client *http.Client
	logger *logging.Logger
}

// New returns a ReverseProxy backed by a fresh HTTP client that uses timeout
// as its overall request timeout. A zero or negative timeout falls back to
// three seconds.
func New(timeout time.Duration, logger *logging.Logger) *ReverseProxy {
	effective := timeout
	if effective <= 0 {
		effective = 3 * time.Second
	}
	return NewWithClient(&http.Client{Timeout: effective}, logger)
}

// NewWithClient returns a ReverseProxy that sends upstream requests through
// client. A nil client is replaced by a default client with a three second
// timeout.
func NewWithClient(client *http.Client, logger *logging.Logger) *ReverseProxy {
	rp := &ReverseProxy{logger: logger}
	if client != nil {
		rp.client = client
		return rp
	}
	rp.client = &http.Client{Timeout: 3 * time.Second}
	return rp
}

// Forward sends the current request to upstreamBaseURL with targetPath.
func (p *ReverseProxy) Forward(c fiber.Ctx, upstreamBaseURL string, targetPath string) error { base, err := url.Parse(upstreamBaseURL) if err != nil { return httputil.InternalError(c, "Invalid upstream configuration") } reqURL := *base reqURL.Path = joinPath(base.Path, targetPath) reqURL.RawQuery = string(c.Request().URI().QueryString()) proxyReq, err := http.NewRequestWithContext( c.Context(), c.Method(), reqURL.String(), bytes.NewReader(c.Body()), ) if err != nil { return httputil.InternalError(c, "Failed to build upstream request") } copyRequestHeaders(c, proxyReq) injectForwardedHeaders(c, proxyReq) injectUserHeaders(c, proxyReq) resp, err := p.client.Do(proxyReq) if err != nil { if p.logger != nil { p.logger.WithError(err).Warn("upstream request failed") } return c.Status(fiber.StatusBadGateway).JSON(httputil.NewErrorResponse( "UPSTREAM_UNAVAILABLE", "Upstream service unavailable", err.Error(), )) } defer resp.Body.Close() copyResponseHeaders(c, resp) c.Status(resp.StatusCode) body, err := io.ReadAll(resp.Body) if err != nil { if p.logger != nil { p.logger.WithError(err).Warn("failed to read upstream response") } return c.Status(fiber.StatusBadGateway).JSON(httputil.NewErrorResponse( "UPSTREAM_RESPONSE_ERROR", "Failed to read upstream response", err.Error(), )) } if len(body) == 0 { return nil } return c.Send(body) } func joinPath(basePath string, reqPath string) string { bp := strings.TrimRight(basePath, "/") rp := reqPath if rp == "" { rp = "/" } if !strings.HasPrefix(rp, "/") { rp = "/" + rp } if bp == "" { return rp } return bp + rp } func copyRequestHeaders(c fiber.Ctx, req *http.Request) { c.Request().Header.VisitAll(func(k []byte, v []byte) { name := string(k) if isHopByHopHeader(name) || strings.EqualFold(name, "Host") { return } req.Header.Add(name, string(v)) }) } func injectForwardedHeaders(c fiber.Ctx, req *http.Request) { clientIP := c.IP() xff := c.Get("X-Forwarded-For") if xff == "" { req.Header.Set("X-Forwarded-For", clientIP) } else { 
req.Header.Set("X-Forwarded-For", xff+", "+clientIP) } req.Header.Set("X-Forwarded-Proto", c.Protocol()) if reqID := requestID(c); reqID != "" { req.Header.Set("X-Request-ID", reqID) } } func injectUserHeaders(c fiber.Ctx, req *http.Request) { if userID := localString(c, string(zitadel.ContextKeyUserID)); userID != "" { req.Header.Set("X-User-ID", userID) } if email := localString(c, string(zitadel.ContextKeyUserEmail)); email != "" { req.Header.Set("X-User-Email", email) } if roles := c.Locals(string(zitadel.ContextKeyUserRoles)); roles != nil { if vals, ok := roles.([]string); ok { req.Header.Set("X-User-Roles", strings.Join(vals, ",")) } } if mfa := c.Locals(string(zitadel.ContextKeyMFAVerified)); mfa != nil { if verified, ok := mfa.(bool); ok { req.Header.Set("X-User-MFA-Verified", fmt.Sprintf("%t", verified)) } } } func copyResponseHeaders(c fiber.Ctx, resp *http.Response) { for name, values := range resp.Header { if isHopByHopHeader(name) { continue } for _, v := range values { c.Append(name, v) } } } func isHopByHopHeader(name string) bool { switch strings.ToLower(name) { case "connection", "keep-alive", "proxy-authenticate", "proxy-authorization", "te", "trailer", "transfer-encoding", "upgrade": return true default: return false } } func localString(c fiber.Ctx, key string) string { if raw := c.Locals(key); raw != nil { if s, ok := raw.(string); ok { return s } } return "" } func requestID(c fiber.Ctx) string { if raw := c.Locals("request_id"); raw != nil { if s, ok := raw.(string); ok { return s } } return "" }