You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
204 lines
4.7 KiB
Go
204 lines
4.7 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gofiber/fiber/v3"
|
|
|
|
"knowfoolery/backend/shared/infra/auth/zitadel"
|
|
"knowfoolery/backend/shared/infra/observability/logging"
|
|
"knowfoolery/backend/shared/infra/utils/httputil"
|
|
)
|
|
|
|
// ReverseProxy forwards incoming requests to configured upstream services.
|
|
type ReverseProxy struct {
|
|
client *http.Client
|
|
logger *logging.Logger
|
|
}
|
|
|
|
// New creates a reverse proxy with the provided timeout.
|
|
func New(timeout time.Duration, logger *logging.Logger) *ReverseProxy {
|
|
if timeout <= 0 {
|
|
timeout = 3 * time.Second
|
|
}
|
|
|
|
return &ReverseProxy{
|
|
client: &http.Client{Timeout: timeout},
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// NewWithClient creates a reverse proxy using a custom HTTP client.
|
|
func NewWithClient(client *http.Client, logger *logging.Logger) *ReverseProxy {
|
|
if client == nil {
|
|
client = &http.Client{Timeout: 3 * time.Second}
|
|
}
|
|
return &ReverseProxy{
|
|
client: client,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// Forward sends the current request to upstreamBaseURL with targetPath.
|
|
func (p *ReverseProxy) Forward(c fiber.Ctx, upstreamBaseURL string, targetPath string) error {
|
|
base, err := url.Parse(upstreamBaseURL)
|
|
if err != nil {
|
|
return httputil.InternalError(c, "Invalid upstream configuration")
|
|
}
|
|
|
|
reqURL := *base
|
|
reqURL.Path = joinPath(base.Path, targetPath)
|
|
reqURL.RawQuery = string(c.Request().URI().QueryString())
|
|
|
|
proxyReq, err := http.NewRequestWithContext(
|
|
c.Context(),
|
|
c.Method(),
|
|
reqURL.String(),
|
|
bytes.NewReader(c.Body()),
|
|
)
|
|
if err != nil {
|
|
return httputil.InternalError(c, "Failed to build upstream request")
|
|
}
|
|
|
|
copyRequestHeaders(c, proxyReq)
|
|
injectForwardedHeaders(c, proxyReq)
|
|
injectUserHeaders(c, proxyReq)
|
|
|
|
resp, err := p.client.Do(proxyReq)
|
|
if err != nil {
|
|
if p.logger != nil {
|
|
p.logger.WithError(err).Warn("upstream request failed")
|
|
}
|
|
return c.Status(fiber.StatusBadGateway).JSON(httputil.NewErrorResponse(
|
|
"UPSTREAM_UNAVAILABLE",
|
|
"Upstream service unavailable",
|
|
err.Error(),
|
|
))
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
copyResponseHeaders(c, resp)
|
|
c.Status(resp.StatusCode)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
if p.logger != nil {
|
|
p.logger.WithError(err).Warn("failed to read upstream response")
|
|
}
|
|
return c.Status(fiber.StatusBadGateway).JSON(httputil.NewErrorResponse(
|
|
"UPSTREAM_RESPONSE_ERROR",
|
|
"Failed to read upstream response",
|
|
err.Error(),
|
|
))
|
|
}
|
|
|
|
if len(body) == 0 {
|
|
return nil
|
|
}
|
|
return c.Send(body)
|
|
}
|
|
|
|
// joinPath concatenates basePath and reqPath into a single absolute path.
// Trailing slashes are stripped from basePath. reqPath always begins with
// exactly the slash it was given, or one added when missing, and an empty
// reqPath becomes "/".
func joinPath(basePath string, reqPath string) string {
	prefix := strings.TrimRight(basePath, "/")

	suffix := reqPath
	switch {
	case suffix == "":
		suffix = "/"
	case !strings.HasPrefix(suffix, "/"):
		suffix = "/" + suffix
	}

	// An empty prefix simply yields suffix.
	return prefix + suffix
}
|
|
|
|
func copyRequestHeaders(c fiber.Ctx, req *http.Request) {
|
|
c.Request().Header.VisitAll(func(k []byte, v []byte) {
|
|
name := string(k)
|
|
if isHopByHopHeader(name) || strings.EqualFold(name, "Host") {
|
|
return
|
|
}
|
|
req.Header.Add(name, string(v))
|
|
})
|
|
}
|
|
|
|
func injectForwardedHeaders(c fiber.Ctx, req *http.Request) {
|
|
clientIP := c.IP()
|
|
xff := c.Get("X-Forwarded-For")
|
|
if xff == "" {
|
|
req.Header.Set("X-Forwarded-For", clientIP)
|
|
} else {
|
|
req.Header.Set("X-Forwarded-For", xff+", "+clientIP)
|
|
}
|
|
req.Header.Set("X-Forwarded-Proto", c.Protocol())
|
|
|
|
if reqID := requestID(c); reqID != "" {
|
|
req.Header.Set("X-Request-ID", reqID)
|
|
}
|
|
}
|
|
|
|
func injectUserHeaders(c fiber.Ctx, req *http.Request) {
|
|
if userID := localString(c, string(zitadel.ContextKeyUserID)); userID != "" {
|
|
req.Header.Set("X-User-ID", userID)
|
|
}
|
|
if email := localString(c, string(zitadel.ContextKeyUserEmail)); email != "" {
|
|
req.Header.Set("X-User-Email", email)
|
|
}
|
|
if roles := c.Locals(string(zitadel.ContextKeyUserRoles)); roles != nil {
|
|
if vals, ok := roles.([]string); ok {
|
|
req.Header.Set("X-User-Roles", strings.Join(vals, ","))
|
|
}
|
|
}
|
|
if mfa := c.Locals(string(zitadel.ContextKeyMFAVerified)); mfa != nil {
|
|
if verified, ok := mfa.(bool); ok {
|
|
req.Header.Set("X-User-MFA-Verified", fmt.Sprintf("%t", verified))
|
|
}
|
|
}
|
|
}
|
|
|
|
func copyResponseHeaders(c fiber.Ctx, resp *http.Response) {
|
|
for name, values := range resp.Header {
|
|
if isHopByHopHeader(name) {
|
|
continue
|
|
}
|
|
for _, v := range values {
|
|
c.Append(name, v)
|
|
}
|
|
}
|
|
}
|
|
|
|
// isHopByHopHeader reports whether name, matched case-insensitively, is a
// hop-by-hop header that a proxy must not forward in either direction.
//
// The set is the conventional one from RFC 2616 §13.5.1 and RFC 9110 §7.6.1,
// plus the legacy Proxy-Connection header still sent by some clients. It
// matches the list Go's net/http/httputil.ReverseProxy strips.
func isHopByHopHeader(name string) bool {
	switch strings.ToLower(name) {
	case "connection", "proxy-connection", "keep-alive", "proxy-authenticate",
		"proxy-authorization", "te", "trailer", "transfer-encoding", "upgrade":
		return true
	default:
		return false
	}
}
|
|
|
|
func localString(c fiber.Ctx, key string) string {
|
|
if raw := c.Locals(key); raw != nil {
|
|
if s, ok := raw.(string); ok {
|
|
return s
|
|
}
|
|
}
|
|
return ""
|
|
}
|
|
|
|
func requestID(c fiber.Ctx) string {
|
|
if raw := c.Locals("request_id"); raw != nil {
|
|
if s, ok := raw.(string); ok {
|
|
return s
|
|
}
|
|
}
|
|
return ""
|
|
}
|