You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

427 lines
11 KiB
Go

package ent
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
domain "knowfoolery/backend/services/user-service/internal/domain/user"
sharederrors "knowfoolery/backend/shared/domain/errors"
sharedtypes "knowfoolery/backend/shared/domain/types"
sharedpostgres "knowfoolery/backend/shared/infra/database/postgres"
)
// UserRepository implements user storage on PostgreSQL.
//
// It holds no state besides the shared client; the repository methods issue
// their queries and transactions through client.Pool.
type UserRepository struct {
// client supplies the PostgreSQL connection pool (client.Pool) used for all queries.
client *sharedpostgres.Client
}
// NewUserRepository builds a UserRepository backed by the given shared
// PostgreSQL client. The client pointer is stored as-is; it is neither
// validated nor copied.
func NewUserRepository(client *sharedpostgres.Client) *UserRepository {
	repo := new(UserRepository)
	repo.client = client
	return repo
}
// EnsureSchema creates the users and user_audit_log tables and their indexes
// when they are missing. Every statement uses IF NOT EXISTS, so the call is
// safe to repeat. Tables that already exist are left untouched, so column or
// constraint changes made here do not reach databases that already have
// them.
//
// The whole DDL script is sent in one Exec call with no arguments (pgx runs
// argument-less Exec calls over the simple query protocol, which accepts
// multiple statements).
//
// NOTE(review): the UNIQUE constraints on email and zitadel_user_id are not
// partial, so they also cover soft-deleted rows. A soft-deleted user's email
// or identity id therefore still blocks a new insert with the same value.
func (r *UserRepository) EnsureSchema(ctx context.Context) error {
	const schemaDDL = `
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY,
zitadel_user_id VARCHAR(128) UNIQUE,
email VARCHAR(320) NOT NULL UNIQUE,
email_verified BOOLEAN NOT NULL DEFAULT FALSE,
display_name VARCHAR(50) NOT NULL,
consent_version VARCHAR(32) NOT NULL,
consent_given_at TIMESTAMPTZ NOT NULL,
consent_source VARCHAR(32) NOT NULL DEFAULT 'web',
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
deleted_at TIMESTAMPTZ NULL
);
CREATE INDEX IF NOT EXISTS idx_users_deleted_at ON users (deleted_at);
CREATE INDEX IF NOT EXISTS idx_users_created_at ON users (created_at DESC);
CREATE TABLE IF NOT EXISTS user_audit_log (
id UUID PRIMARY KEY,
actor_user_id VARCHAR(128),
target_user_id UUID NOT NULL,
action VARCHAR(64) NOT NULL,
metadata_json JSONB NOT NULL DEFAULT '{}'::jsonb,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
);
CREATE INDEX IF NOT EXISTS idx_user_audit_log_target_user_id ON user_audit_log (target_user_id, created_at DESC);
`
	if _, err := r.client.Pool.Exec(ctx, schemaDDL); err != nil {
		return err
	}
	return nil
}
// Create inserts user as a new row and returns the record as stored.
//
// The row gets a freshly generated UUID (user.ID is ignored). Both
// created_at and updated_at are set to the current UTC time. An empty or
// whitespace-only ZitadelUserID is stored as NULL. The email is stored
// lower-cased, matching the normalization GetByEmail applies to its lookup
// key.
//
// It returns a CodeUserAlreadyExists error when the insert violates a unique
// constraint (the email and zitadel_user_id columns are UNIQUE in
// EnsureSchema). Any other database error is returned unchanged.
func (r *UserRepository) Create(ctx context.Context, user *domain.User) (*domain.User, error) {
	id := uuid.NewString()
	now := time.Now().UTC()
	const q = `
INSERT INTO users (
id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at
)
VALUES ($1,$2,$3,$4,$5,$6,$7,$8,$9,$10)
RETURNING id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at`
	row := r.client.Pool.QueryRow(ctx, q,
		id,
		nullIfEmpty(user.ZitadelUserID),
		// GetByEmail lower-cases its argument and compares with "=".
		// Storing the email as given would make mixed-case addresses
		// unfindable, and the case-sensitive UNIQUE constraint would admit
		// case-variant duplicates.
		strings.ToLower(user.Email),
		user.EmailVerified,
		user.DisplayName,
		user.ConsentVersion,
		user.ConsentGivenAt,
		user.ConsentSource,
		now,
		now,
	)
	created, err := scanUser(row)
	if err != nil {
		if isUniqueViolation(err) {
			return nil, sharederrors.Wrap(sharederrors.CodeUserAlreadyExists, "user already exists", err)
		}
		return nil, err
	}
	return created, nil
}
// GetByID loads the user whose primary key is id. Soft-deleted rows
// (deleted_at set) are treated as absent.
//
// When no live row matches, it returns a CodeUserNotFound error that wraps
// pgx.ErrNoRows. Any other query or scan error is returned unchanged.
func (r *UserRepository) GetByID(ctx context.Context, id string) (*domain.User, error) {
	const query = `
SELECT id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at
FROM users
WHERE id=$1 AND deleted_at IS NULL`
	found, err := scanUser(r.client.Pool.QueryRow(ctx, query, id))
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, sharederrors.Wrap(sharederrors.CodeUserNotFound, "user not found", err)
	default:
		return nil, err
	}
}
// GetByEmail loads the live (not soft-deleted) user with the given email.
// The argument is lower-cased before the exact-match comparison, so only
// rows whose stored email is lower-case can match.
//
// When no row matches, it returns a CodeUserNotFound error that wraps
// pgx.ErrNoRows. Other errors are returned unchanged.
func (r *UserRepository) GetByEmail(ctx context.Context, email string) (*domain.User, error) {
	const query = `
SELECT id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at
FROM users
WHERE email=$1 AND deleted_at IS NULL`
	lookupKey := strings.ToLower(email)
	found, err := scanUser(r.client.Pool.QueryRow(ctx, query, lookupKey))
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, sharederrors.Wrap(sharederrors.CodeUserNotFound, "user not found", err)
	default:
		return nil, err
	}
}
// GetByZitadelUserID loads the live (not soft-deleted) user linked to the
// given external identity id (the zitadel_user_id column).
//
// When no row matches, it returns a CodeUserNotFound error that wraps
// pgx.ErrNoRows. Other errors are returned unchanged.
func (r *UserRepository) GetByZitadelUserID(ctx context.Context, zitadelUserID string) (*domain.User, error) {
	const query = `
SELECT id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at
FROM users
WHERE zitadel_user_id=$1 AND deleted_at IS NULL`
	found, err := scanUser(r.client.Pool.QueryRow(ctx, query, zitadelUserID))
	switch {
	case err == nil:
		return found, nil
	case errors.Is(err, pgx.ErrNoRows):
		return nil, sharederrors.Wrap(sharederrors.CodeUserNotFound, "user not found", err)
	default:
		return nil, err
	}
}
// UpdateProfile updates the live user id. It overwrites the display name and
// the consent record (version, given-at time, source), sets updated_at to
// NOW(), and returns the updated row.
//
// When the user does not exist or is soft-deleted, it returns a
// CodeUserNotFound error that wraps pgx.ErrNoRows. Other errors are returned
// unchanged.
func (r *UserRepository) UpdateProfile(
	ctx context.Context,
	id string,
	displayName string,
	consent domain.ConsentRecord,
) (*domain.User, error) {
	const query = `
UPDATE users
SET display_name=$2,
consent_version=$3,
consent_given_at=$4,
consent_source=$5,
updated_at=NOW()
WHERE id=$1 AND deleted_at IS NULL
RETURNING id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at`
	updated, err := scanUser(r.client.Pool.QueryRow(ctx, query,
		id, displayName, consent.Version, consent.GivenAt, consent.Source,
	))
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, sharederrors.Wrap(sharederrors.CodeUserNotFound, "user not found", err)
	}
	if err != nil {
		return nil, err
	}
	return updated, nil
}
// MarkEmailVerified sets email_verified to true on the live user id, bumps
// updated_at to NOW(), and returns the updated row.
//
// When the user does not exist or is soft-deleted, it returns a
// CodeUserNotFound error that wraps pgx.ErrNoRows. Other errors are returned
// unchanged.
func (r *UserRepository) MarkEmailVerified(ctx context.Context, id string) (*domain.User, error) {
	const query = `
UPDATE users
SET email_verified=true, updated_at=NOW()
WHERE id=$1 AND deleted_at IS NULL
RETURNING id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at`
	updated, err := scanUser(r.client.Pool.QueryRow(ctx, query, id))
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, sharederrors.Wrap(sharederrors.CodeUserNotFound, "user not found", err)
	}
	if err != nil {
		return nil, err
	}
	return updated, nil
}
// SoftDelete marks the live user id as deleted, setting deleted_at and
// updated_at to NOW(). In the same transaction it inserts an
// AuditActionGDPRDelete audit entry attributed to actorUserID; a blank
// actorUserID is stored as NULL.
//
// It returns a CodeUserNotFound error when no live user has that id, and
// nothing is committed in that case. Any database error also leaves the
// transaction uncommitted.
func (r *UserRepository) SoftDelete(ctx context.Context, id string, actorUserID string) error {
	tx, err := r.client.Pool.Begin(ctx)
	if err != nil {
		return err
	}
	// The deferred Rollback runs on every exit path. After a successful
	// Commit it only reports that the transaction is already closed, which
	// is why its error is discarded.
	defer func() { _ = tx.Rollback(ctx) }()

	const markDeleted = `UPDATE users SET deleted_at=NOW(), updated_at=NOW() WHERE id=$1 AND deleted_at IS NULL`
	tag, err := tx.Exec(ctx, markDeleted, id)
	if err != nil {
		return err
	}
	if tag.RowsAffected() == 0 {
		return sharederrors.Wrap(sharederrors.CodeUserNotFound, "user not found", nil)
	}

	const insertAudit = `
INSERT INTO user_audit_log (id, actor_user_id, target_user_id, action, metadata_json, created_at)
VALUES ($1,$2,$3,$4,$5,NOW())`
	if _, err := tx.Exec(ctx, insertAudit,
		uuid.NewString(),
		nullIfEmpty(actorUserID),
		id,
		domain.AuditActionGDPRDelete,
		`{"operation":"gdpr_delete"}`,
	); err != nil {
		return err
	}
	return tx.Commit(ctx)
}
// userListLikeEscaper escapes LIKE metacharacters (backslash, '%' and '_') so
// that user-supplied filter text is matched literally. PostgreSQL uses
// backslash as the default LIKE escape character, so the generated SQL needs
// no ESCAPE clause. strings.Replacer makes a single pass and does not rescan
// its own output, so the backslashes it inserts are not escaped again.
var userListLikeEscaper = strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)

// userListContainsPattern turns raw filter text into a lower-cased,
// metacharacter-escaped "contains" pattern (%text%) for LIKE.
func userListContainsPattern(s string) string {
	return "%" + userListLikeEscaper.Replace(strings.ToLower(s)) + "%"
}

// List returns one page of users matching filter, newest first (created_at
// DESC). It also returns the total number of matching users before
// pagination is applied.
//
// Filtering:
//   - soft-deleted users are excluded unless filter.IncludeDeleted is set;
//   - Email and DisplayName are case-insensitive substring matches, and the
//     filter text is matched literally ('%' and '_' are not wildcards);
//   - CreatedAfter and CreatedBefore are inclusive bounds on created_at.
//
// The count and the page are read by two separate queries. If iteration
// fails part-way, the rows collected so far are returned together with
// rows.Err().
func (r *UserRepository) List(
	ctx context.Context,
	pagination sharedtypes.Pagination,
	filter domain.ListFilter,
) ([]*domain.User, int64, error) {
	clauses := make([]string, 0)
	args := make([]interface{}, 0)
	// addFilter binds arg and appends a predicate that references it by its
	// 1-based placeholder number ($N).
	addFilter := func(format string, arg interface{}) {
		args = append(args, arg)
		clauses = append(clauses, fmt.Sprintf(format, len(args)))
	}
	if !filter.IncludeDeleted {
		clauses = append(clauses, "deleted_at IS NULL")
	}
	if filter.Email != "" {
		addFilter("LOWER(email) LIKE $%d", userListContainsPattern(filter.Email))
	}
	if filter.DisplayName != "" {
		addFilter("LOWER(display_name) LIKE $%d", userListContainsPattern(filter.DisplayName))
	}
	if filter.CreatedAfter != nil {
		addFilter("created_at >= $%d", *filter.CreatedAfter)
	}
	if filter.CreatedBefore != nil {
		addFilter("created_at <= $%d", *filter.CreatedBefore)
	}
	whereSQL := ""
	if len(clauses) > 0 {
		whereSQL = " WHERE " + strings.Join(clauses, " AND ")
	}

	countQuery := "SELECT COUNT(*) FROM users" + whereSQL
	var total int64
	if err := r.client.Pool.QueryRow(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, err
	}

	// LIMIT and OFFSET are bound after the filter arguments, so they take
	// the last two placeholder numbers.
	args = append(args, pagination.Limit(), pagination.Offset())
	query := `
SELECT id, zitadel_user_id, email, email_verified, display_name,
consent_version, consent_given_at, consent_source,
created_at, updated_at, deleted_at
FROM users` + whereSQL + fmt.Sprintf(" ORDER BY created_at DESC LIMIT $%d OFFSET $%d", len(args)-1, len(args))
	rows, err := r.client.Pool.Query(ctx, query, args...)
	if err != nil {
		return nil, 0, err
	}
	defer rows.Close()

	// Non-nil so an empty page stays an empty slice rather than nil.
	out := make([]*domain.User, 0)
	for rows.Next() {
		u, err := scanUser(rows)
		if err != nil {
			return nil, 0, err
		}
		out = append(out, u)
	}
	return out, total, rows.Err()
}
// AuditLogsByUserID returns every audit entry whose target is user id,
// newest first. MetadataJSON holds the JSONB column's text form. A NULL
// actor_user_id leaves ActorUserID empty. On success the slice is non-nil,
// even when there are no entries.
func (r *UserRepository) AuditLogsByUserID(ctx context.Context, id string) ([]domain.AuditLogEntry, error) {
	const query = `
SELECT id, actor_user_id, target_user_id, action, metadata_json::text, created_at
FROM user_audit_log
WHERE target_user_id=$1
ORDER BY created_at DESC`
	rows, err := r.client.Pool.Query(ctx, query, id)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	entries := []domain.AuditLogEntry{}
	for rows.Next() {
		var (
			e     domain.AuditLogEntry
			actor sql.NullString
		)
		scanErr := rows.Scan(&e.ID, &actor, &e.TargetUserID, &e.Action, &e.MetadataJSON, &e.CreatedAt)
		if scanErr != nil {
			return nil, scanErr
		}
		if actor.Valid {
			e.ActorUserID = actor.String
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// WriteAuditLog inserts entry into user_audit_log.
//
// Missing fields are defaulted before the insert:
//   - ID gets a new UUID;
//   - Action becomes "unknown";
//   - MetadataJSON becomes "{}".
//
// A blank ActorUserID is stored as NULL. created_at is always set by the
// database (NOW()), so entry.CreatedAt is not used. entry is passed by value,
// so the caller's copy is never modified.
func (r *UserRepository) WriteAuditLog(ctx context.Context, entry domain.AuditLogEntry) error {
	entryID := entry.ID
	if entryID == "" {
		entryID = uuid.NewString()
	}
	action := entry.Action
	if action == "" {
		action = "unknown"
	}
	metadata := entry.MetadataJSON
	if metadata == "" {
		metadata = "{}"
	}
	const insertAudit = `
INSERT INTO user_audit_log (id, actor_user_id, target_user_id, action, metadata_json, created_at)
VALUES ($1,$2,$3,$4,$5,NOW())`
	_, err := r.client.Pool.Exec(ctx, insertAudit,
		entryID,
		nullIfEmpty(entry.ActorUserID),
		entry.TargetUserID,
		action,
		metadata,
	)
	return err
}
// nullIfEmpty converts a blank string (empty or whitespace-only) to a nil
// query argument, which the driver sends as SQL NULL. Any other string is
// returned unchanged; it is not trimmed.
func nullIfEmpty(v string) interface{} {
	if trimmed := strings.TrimSpace(v); trimmed != "" {
		return v
	}
	return nil
}
// scanUser reads one users row into a new domain.User. The row must contain
// exactly these columns, in order: id, zitadel_user_id, email,
// email_verified, display_name, consent_version, consent_given_at,
// consent_source, created_at, updated_at, deleted_at.
//
// A NULL zitadel_user_id leaves ZitadelUserID empty, and a NULL deleted_at
// leaves DeletedAt nil. Any value with a variadic Scan method works, so the
// same helper serves a single QueryRow result and each row of a Query
// iterator. On error it returns nil and the scan error.
func scanUser(scanner interface {
	Scan(dest ...interface{}) error
}) (*domain.User, error) {
	var (
		zitadelID sql.NullString
		deletedAt *time.Time
	)
	u := new(domain.User)
	err := scanner.Scan(
		&u.ID, &zitadelID, &u.Email, &u.EmailVerified, &u.DisplayName,
		&u.ConsentVersion, &u.ConsentGivenAt, &u.ConsentSource,
		&u.CreatedAt, &u.UpdatedAt, &deletedAt,
	)
	if err != nil {
		return nil, err
	}
	if zitadelID.Valid {
		u.ZitadelUserID = zitadelID.String
	}
	u.DeletedAt = deletedAt
	return u, nil
}
// isUniqueViolation reports whether err looks like a PostgreSQL
// unique-constraint violation (SQLSTATE 23505, unique_violation).
//
// The check works on the error text, so this file needs no pgconn import.
// pgx renders a server error as "SEVERITY: message (SQLSTATE code)", and the
// SQLSTATE part does not depend on the server's lc_messages locale. The
// English "duplicate key" message text is kept as a fallback for errors that
// carry no SQLSTATE suffix. Wrapped errors match too, because wrapping keeps
// the inner message in Error().
func isUniqueViolation(err error) bool {
	if err == nil {
		return false
	}
	msg := strings.ToLower(err.Error())
	return strings.Contains(msg, "sqlstate 23505") || strings.Contains(msg, "duplicate key")
}