You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

204 lines
4.7 KiB
Go

package proxy
import (
"bytes"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gofiber/fiber/v3"
"knowfoolery/backend/shared/infra/auth/zitadel"
"knowfoolery/backend/shared/infra/observability/logging"
"knowfoolery/backend/shared/infra/utils/httputil"
)
// ReverseProxy relays incoming Fiber requests to upstream HTTP services
// through a single shared *http.Client. The logger is optional: when it is
// nil, upstream failures are not logged.
type ReverseProxy struct {
	client *http.Client    // outbound client used for every forwarded request
	logger *logging.Logger // may be nil
}
// New builds a ReverseProxy whose HTTP client applies the given timeout to
// each upstream request. A zero or negative timeout falls back to 3 seconds.
func New(timeout time.Duration, logger *logging.Logger) *ReverseProxy {
	effective := timeout
	if effective <= 0 {
		effective = 3 * time.Second
	}
	client := &http.Client{Timeout: effective}
	return &ReverseProxy{client: client, logger: logger}
}
// NewWithClient builds a ReverseProxy around a caller-supplied HTTP client,
// such as one with a custom Transport. A nil client is replaced by a default
// client with a 3-second timeout.
func NewWithClient(client *http.Client, logger *logging.Logger) *ReverseProxy {
	p := &ReverseProxy{client: client, logger: logger}
	if p.client == nil {
		p.client = &http.Client{Timeout: 3 * time.Second}
	}
	return p
}
// Forward proxies the current request to upstreamBaseURL. It joins
// targetPath onto the base URL's path and keeps the incoming query string.
//
// Incoming request headers are copied (minus hop-by-hop headers and Host).
// Forwarding headers and authenticated-user headers are then injected. The
// upstream status, headers and body are written back to c.
//
// Failures are reported to the client rather than returned as Go errors:
//   - an unparsable upstream URL or request-build failure is answered via
//     httputil.InternalError;
//   - a transport failure or an unreadable upstream body is answered with
//     502 Bad Gateway and a JSON error payload.
func (p *ReverseProxy) Forward(c fiber.Ctx, upstreamBaseURL string, targetPath string) error {
	base, err := url.Parse(upstreamBaseURL)
	if err != nil {
		return httputil.InternalError(c, "Invalid upstream configuration")
	}

	reqURL := *base
	reqURL.Path = joinPath(base.Path, targetPath)
	reqURL.RawQuery = string(c.Request().URI().QueryString())

	// BodyRaw yields the body exactly as received. Body() would decode any
	// Content-Encoding, but copyRequestHeaders forwards that header unchanged.
	// The upstream would then receive a decoded body labelled as encoded.
	proxyReq, err := http.NewRequestWithContext(
		c.Context(),
		c.Method(),
		reqURL.String(),
		bytes.NewReader(c.BodyRaw()),
	)
	if err != nil {
		return httputil.InternalError(c, "Failed to build upstream request")
	}

	copyRequestHeaders(c, proxyReq)
	injectForwardedHeaders(c, proxyReq)
	injectUserHeaders(c, proxyReq)

	resp, err := p.client.Do(proxyReq)
	if err != nil {
		if p.logger != nil {
			p.logger.WithError(err).Warn("upstream request failed")
		}
		return c.Status(fiber.StatusBadGateway).JSON(httputil.NewErrorResponse(
			"UPSTREAM_UNAVAILABLE",
			"Upstream service unavailable",
			err.Error(),
		))
	}
	defer resp.Body.Close()

	// Read the whole upstream body before touching the client response.
	// Otherwise a read failure would emit the 502 JSON error with the
	// upstream's headers (Content-Encoding, Content-Type, ...) already attached.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		if p.logger != nil {
			p.logger.WithError(err).Warn("failed to read upstream response")
		}
		return c.Status(fiber.StatusBadGateway).JSON(httputil.NewErrorResponse(
			"UPSTREAM_RESPONSE_ERROR",
			"Failed to read upstream response",
			err.Error(),
		))
	}

	copyResponseHeaders(c, resp)
	c.Status(resp.StatusCode)
	if len(body) == 0 {
		return nil
	}
	return c.Send(body)
}
// joinPath appends reqPath to basePath with exactly one "/" between them.
// Trailing slashes are stripped from basePath, and reqPath is given a
// leading "/" if it lacks one. An empty reqPath therefore becomes "/".
func joinPath(basePath string, reqPath string) string {
	suffix := reqPath
	if !strings.HasPrefix(suffix, "/") {
		suffix = "/" + suffix
	}
	return strings.TrimRight(basePath, "/") + suffix
}
// copyRequestHeaders copies every incoming request header onto req, except
// hop-by-hop headers and Host. Repeated headers are added, not overwritten.
func copyRequestHeaders(c fiber.Ctx, req *http.Request) {
	c.Request().Header.VisitAll(func(key, value []byte) {
		name := string(key)
		skip := strings.EqualFold(name, "Host") || isHopByHopHeader(name)
		if !skip {
			// string(value) makes a copy, so req keeps no reference to the
			// byte slice handed to the visitor.
			req.Header.Add(name, string(value))
		}
	})
}
// injectForwardedHeaders sets the standard proxy headers on req:
//   - X-Forwarded-For: the client IP, appended to any chain the client sent;
//   - X-Forwarded-Proto: the request scheme ("http" or "https");
//   - X-Request-ID: the request ID from c.Locals, when one is set.
func injectForwardedHeaders(c fiber.Ctx, req *http.Request) {
	clientIP := c.IP()
	xff := clientIP
	if prior := c.Get("X-Forwarded-For"); prior != "" {
		xff = prior + ", " + clientIP
	}
	req.Header.Set("X-Forwarded-For", xff)

	// In Fiber v3, Protocol() reports the HTTP version (e.g. "HTTP/1.1").
	// X-Forwarded-Proto must carry the URL scheme, which comes from Scheme().
	req.Header.Set("X-Forwarded-Proto", c.Scheme())

	if reqID := requestID(c); reqID != "" {
		req.Header.Set("X-Request-ID", reqID)
	}
}
// injectUserHeaders passes the caller's identity to the upstream as X-User-*
// headers. The values come from c.Locals under the zitadel context keys,
// presumably populated by the zitadel auth middleware (verify against the
// router setup).
//
// Each header is set only when its local is present with the expected type:
// a non-empty string for ID/email, []string for roles and bool for MFA.
// Roles are sent comma-separated; the MFA flag is sent as "true" or "false".
//
// Any X-User-* header the client sent itself is removed first.
// copyRequestHeaders forwards all incoming headers, so without this step an
// unauthenticated caller could set X-User-ID directly. That spoofed value
// would reach the upstream whenever no identity locals are present.
func injectUserHeaders(c fiber.Ctx, req *http.Request) {
	for _, name := range [...]string{
		"X-User-ID",
		"X-User-Email",
		"X-User-Roles",
		"X-User-MFA-Verified",
	} {
		req.Header.Del(name)
	}

	if userID := localString(c, string(zitadel.ContextKeyUserID)); userID != "" {
		req.Header.Set("X-User-ID", userID)
	}
	if email := localString(c, string(zitadel.ContextKeyUserEmail)); email != "" {
		req.Header.Set("X-User-Email", email)
	}
	if roles, ok := c.Locals(string(zitadel.ContextKeyUserRoles)).([]string); ok {
		req.Header.Set("X-User-Roles", strings.Join(roles, ","))
	}
	if verified, ok := c.Locals(string(zitadel.ContextKeyMFAVerified)).(bool); ok {
		req.Header.Set("X-User-MFA-Verified", fmt.Sprintf("%t", verified))
	}
}
// copyResponseHeaders copies the upstream response's end-to-end headers onto
// the client response, skipping hop-by-hop headers.
//
// Most headers go through c.Append, which merges repeated values into one
// comma-separated line. Set-Cookie must never be merged this way (RFC 6265;
// RFC 7230 section 3.2.2), because cookie attributes such as Expires contain
// commas. Each Set-Cookie value is therefore added as its own header line
// through the underlying fasthttp response header.
func copyResponseHeaders(c fiber.Ctx, resp *http.Response) {
	for name, values := range resp.Header {
		if isHopByHopHeader(name) {
			continue
		}
		setCookie := strings.EqualFold(name, "Set-Cookie")
		for _, v := range values {
			if setCookie {
				c.Response().Header.Add(name, v)
				continue
			}
			c.Append(name, v)
		}
	}
}
// isHopByHopHeader reports whether name (case-insensitive) is a hop-by-hop
// header that a proxy must not forward. The list mirrors the set stripped by
// Go's net/http/httputil.ReverseProxy, including the non-standard but widely
// sent Proxy-Connection.
func isHopByHopHeader(name string) bool {
	switch strings.ToLower(name) {
	case "connection", "proxy-connection", "keep-alive", "proxy-authenticate",
		"proxy-authorization", "te", "trailer", "transfer-encoding", "upgrade":
		return true
	}
	return false
}
// localString returns the value stored in c.Locals under key when it is a
// string. It returns "" when the key is unset or holds another type.
func localString(c fiber.Ctx, key string) string {
	s, _ := c.Locals(key).(string) // comma-ok: nil or non-string yields ""
	return s
}
// requestID returns the request ID stored in c.Locals under "request_id".
// It returns "" when the ID is unset or is not a string.
func requestID(c fiber.Ctx) string {
	return localString(c, "request_id")
}