You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

511 lines
18 KiB
Rust

//! Shared test utilities for integration tests.
//!
//! Provides `TestApp` which sets up a real Postgres database, runs migrations,
//! builds the full Axum router, and offers helper methods to send HTTP requests
//! through it using `tower::ServiceExt::oneshot`.
//!
//! ## Database strategy
//!
//! Each `TestApp` creates a unique temporary database (named with a UUID suffix)
//! on the Postgres instance pointed to by the `TEST_DATABASE_URL` environment
//! variable. The base URL should point to an existing database (e.g. `postgres`)
//! that can be used to issue `CREATE DATABASE` / `DROP DATABASE` commands.
//!
//! ## External service bypass
//!
//! Turnstile and Resend are configured with test bypass keys so that no
//! external HTTP calls are made during tests.
use axum::body::Body;
use axum::http::{Method, Request, StatusCode};
use axum::Router;
use http_body_util::BodyExt;
use sqlx::postgres::PgPoolOptions;
use sqlx::{Executor, PgPool};
use tower::ServiceExt;
use ai_synth_backend::app_state::AppState;
use ai_synth_backend::config::AppConfig;
use ai_synth_backend::db;
use ai_synth_backend::models::user::UserRole;
use ai_synth_backend::router::build_router;
use ai_synth_backend::services::auth;
use ai_synth_backend::services::email::TEST_API_KEY as EMAIL_TEST_KEY;
use ai_synth_backend::services::turnstile::TEST_SECRET_KEY as TURNSTILE_TEST_KEY;
/// A self-contained test application backed by a real Postgres database.
///
/// Constructed by `TestApp::new`, which creates a uniquely named database,
/// runs the migrations against it and builds the router on top of it.
/// The database is dropped again by `cleanup` or, as a fallback, by the
/// `Drop` implementation.
// NOTE(review): `dead_code` is presumably allowed because each integration
// test binary only uses a subset of these helpers — confirm.
#[allow(dead_code)]
pub struct TestApp {
/// The Axum router, ready to receive requests via `oneshot`.
/// The request helpers clone it for every call.
pub router: Router,
/// The Postgres connection pool (to the test database).
/// Exposed so tests can seed or inspect data directly.
pub pool: PgPool,
/// A pool connected to the *admin* database, used for DROP on cleanup.
/// This is the database named by `TEST_DATABASE_URL`; it is also the
/// connection that issued the `CREATE DATABASE`.
admin_pool: PgPool,
/// The name of the ephemeral test database.
/// Generated as `ai_synth_test_` followed by a random UUID.
db_name: String,
/// The application config used to build this test app.
/// External services are configured with their test bypass keys.
pub config: AppConfig,
}
#[allow(dead_code)]
impl TestApp {
/// Spin up a new test application with its own database.
///
/// Reads `TEST_DATABASE_URL` from the environment. This should be a Postgres
/// connection string pointing to an admin database (e.g. `postgres`).
///
/// Example:
/// ```text
/// TEST_DATABASE_URL=postgres://user:pass@localhost:5432/postgres
/// ```
pub async fn new() -> Self {
let base_url = std::env::var("TEST_DATABASE_URL").expect(
"TEST_DATABASE_URL must be set to run integration tests \
(e.g. postgres://user:pass@localhost:5432/postgres)",
);
// Generate a unique database name
let db_name = format!("ai_synth_test_{}", uuid::Uuid::new_v4().simple());
// Connect to the admin database to create the test database
let admin_pool = PgPoolOptions::new()
.max_connections(2)
.connect(&base_url)
.await
.expect("Failed to connect to admin database (TEST_DATABASE_URL)");
admin_pool
.execute(format!("CREATE DATABASE \"{db_name}\"").as_str())
.await
.expect("Failed to create test database");
// Build the connection URL for the test database
let test_db_url = {
// Replace the database name portion of the URL
let parts = base_url.rsplitn(2, '/').collect::<Vec<_>>();
if parts.len() == 2 {
format!("{}/{}", parts[1], db_name)
} else {
format!("{}/{}", base_url.trim_end_matches('/'), db_name)
}
};
// Connect to the test database
let pool = PgPoolOptions::new()
.max_connections(5)
.connect(&test_db_url)
.await
.expect("Failed to connect to test database");
// Run migrations
sqlx::migrate!("./migrations")
.run(&pool)
.await
.expect("Failed to run migrations on test database");
// Build a test config with bypassed external services
let config = AppConfig {
database_url: test_db_url,
master_encryption_key: std::sync::Arc::new("ab".repeat(32)), // 64 hex chars
app_url: "http://localhost:3000".into(),
port: 0,
static_dir: "/tmp/ai_synth_test_static".into(), // not used in API tests
resend_api_key: EMAIL_TEST_KEY.into(),
email_from: "test@example.com".into(),
turnstile_secret_key: TURNSTILE_TEST_KEY.into(),
turnstile_site_key: "test-site-key".into(),
generation_timeout_secs: 1800,
};
let http_client = reqwest::Client::new();
let state = AppState::new(config.clone(), pool.clone(), http_client);
// Create the static dir so ServeDir/ServeFile don't error
let _ = std::fs::create_dir_all(&config.static_dir);
let index_path = format!("{}/index.html", config.static_dir);
if !std::path::Path::new(&index_path).exists() {
let _ = std::fs::write(&index_path, "<html><body>test</body></html>");
}
let router = build_router(state, &config);
Self {
router,
pool,
admin_pool,
db_name,
config,
}
}
// ── Request helpers ──────────────────────────────────────────────
/// Issue an unauthenticated GET to `uri`.
///
/// Returns the response status together with the body decoded as JSON
/// (see `request` for how empty and non-JSON bodies are represented).
pub async fn get(&self, uri: &str) -> (StatusCode, serde_json::Value) {
    let (no_body, no_session) = (None, None);
    self.request(Method::GET, uri, no_body, no_session).await
}
/// Issue a GET to `uri` carrying `session_cookie` as the session cookie value.
///
/// Returns the response status and the JSON-decoded body.
pub async fn get_with_session(
    &self,
    uri: &str,
    session_cookie: &str,
) -> (StatusCode, serde_json::Value) {
    let session = Some(session_cookie);
    self.request(Method::GET, uri, None, session).await
}
/// Issue an unauthenticated POST to `uri` with `body` serialized as JSON.
///
/// The CSRF header is attached by `request` because POST is a mutating
/// method. Returns the response status and the JSON-decoded body.
pub async fn post(
    &self,
    uri: &str,
    body: &serde_json::Value,
) -> (StatusCode, serde_json::Value) {
    let payload = Some(body);
    self.request(Method::POST, uri, payload, None).await
}
/// Issue an authenticated POST to `uri` with `body` serialized as JSON.
///
/// The request carries both the CSRF header and `session_cookie` as the
/// session cookie value. Returns the response status and the JSON-decoded body.
pub async fn post_with_session(
    &self,
    uri: &str,
    body: &serde_json::Value,
    session_cookie: &str,
) -> (StatusCode, serde_json::Value) {
    let (payload, session) = (Some(body), Some(session_cookie));
    self.request(Method::POST, uri, payload, session).await
}
/// Issue an authenticated PUT to `uri` with `body` serialized as JSON.
///
/// The request carries both the CSRF header and `session_cookie` as the
/// session cookie value. Returns the response status and the JSON-decoded body.
pub async fn put_with_session(
    &self,
    uri: &str,
    body: &serde_json::Value,
    session_cookie: &str,
) -> (StatusCode, serde_json::Value) {
    let (payload, session) = (Some(body), Some(session_cookie));
    self.request(Method::PUT, uri, payload, session).await
}
/// Issue an authenticated, body-less DELETE to `uri`.
///
/// The request carries both the CSRF header and `session_cookie` as the
/// session cookie value. Returns the response status and the JSON-decoded body.
pub async fn delete_with_session(
    &self,
    uri: &str,
    session_cookie: &str,
) -> (StatusCode, serde_json::Value) {
    let session = Some(session_cookie);
    self.request(Method::DELETE, uri, None, session).await
}
/// Dispatch a caller-built request and return the response as text.
///
/// Yields the status code, the body decoded as UTF-8 (invalid sequences are
/// replaced rather than rejected) and all response headers. Intended for
/// endpoints that return non-JSON content (e.g. CSV export).
///
/// No headers are added here; the request is sent exactly as given.
pub async fn raw_request_text(
    &self,
    req: Request<Body>,
) -> (StatusCode, String, axum::http::HeaderMap) {
    let service = self.router.clone();
    let response = service
        .oneshot(req)
        .await
        .expect("Failed to send raw request");

    // Split head and body so status/headers can be returned without cloning.
    let (head, body) = response.into_parts();
    let collected = body
        .collect()
        .await
        .expect("Failed to read response body");
    let text = String::from_utf8_lossy(&collected.to_bytes()).into_owned();

    (head.status, text, head.headers)
}
/// Dispatch a caller-built request and return the response body untouched.
///
/// Yields the status code, the raw body bytes and all response headers.
/// Intended for endpoints that return binary content (e.g. PDF export).
///
/// No headers are added here; the request is sent exactly as given.
pub async fn raw_request_bytes(
    &self,
    req: Request<Body>,
) -> (StatusCode, Vec<u8>, axum::http::HeaderMap) {
    let service = self.router.clone();
    let response = service
        .oneshot(req)
        .await
        .expect("Failed to send raw request");

    // Split head and body so status/headers can be returned without cloning.
    let (head, body) = response.into_parts();
    let collected = body
        .collect()
        .await
        .expect("Failed to read response body");
    let payload = collected.to_bytes().to_vec();

    (head.status, payload, head.headers)
}
/// POST `body` as JSON to `uri` while leaving out the CSRF header.
///
/// Exists so tests can check that such requests are rejected. The response
/// body is decoded like in `request`: an empty body becomes `{}` and a body
/// that is not valid JSON becomes `{"raw": "<body text>"}`.
pub async fn post_without_csrf(
    &self,
    uri: &str,
    body: &serde_json::Value,
) -> (StatusCode, serde_json::Value) {
    let payload = serde_json::to_vec(body).unwrap();

    // No X-Requested-With header here: that omission is the point.
    let req = Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("Content-Type", "application/json")
        .body(Body::from(payload))
        .unwrap();

    let service = self.router.clone();
    let response = service.oneshot(req).await.expect("Failed to send request");

    let (head, response_body) = response.into_parts();
    let bytes = response_body
        .collect()
        .await
        .expect("Failed to read response body")
        .to_bytes();

    let json = if bytes.is_empty() {
        serde_json::json!({})
    } else {
        match serde_json::from_slice::<serde_json::Value>(&bytes) {
            Ok(value) => value,
            Err(_) => serde_json::json!({ "raw": String::from_utf8_lossy(&bytes).to_string() }),
        }
    };

    (head.status, json)
}
/// Shared implementation behind the JSON request helpers.
///
/// Builds a request for `method` and `uri` and sends it through a clone of
/// the router:
///
/// * `body`, when present, is serialized as JSON and sent with
///   `Content-Type: application/json`; otherwise the body is empty.
/// * POST, PUT, DELETE and PATCH get `X-Requested-With: XMLHttpRequest`,
///   the header used for CSRF protection.
/// * `session_cookie`, when present, is sent as the value of the cookie
///   named `auth::SESSION_COOKIE_NAME`.
///
/// Returns the status and the response body as JSON. An empty body becomes
/// `{}`; a body that is not valid JSON becomes `{"raw": "<body text>"}`.
async fn request(
    &self,
    method: Method,
    uri: &str,
    body: Option<&serde_json::Value>,
    session_cookie: Option<&str>,
) -> (StatusCode, serde_json::Value) {
    let payload = body.map(|value| serde_json::to_vec(value).unwrap());

    // Decide this before `method` is handed over to the builder.
    let needs_csrf = matches!(
        method,
        Method::POST | Method::PUT | Method::DELETE | Method::PATCH
    );

    let mut builder = Request::builder().method(method).uri(uri);
    if payload.is_some() {
        builder = builder.header("Content-Type", "application/json");
    }
    if needs_csrf {
        builder = builder.header("X-Requested-With", "XMLHttpRequest");
    }
    if let Some(token) = session_cookie {
        let cookie = format!("{}={}", auth::SESSION_COOKIE_NAME, token);
        builder = builder.header("Cookie", cookie);
    }
    let req = builder
        .body(Body::from(payload.unwrap_or_default()))
        .unwrap();

    let service = self.router.clone();
    let response = service.oneshot(req).await.expect("Failed to send request");

    let (head, response_body) = response.into_parts();
    let bytes = response_body
        .collect()
        .await
        .expect("Failed to read response body")
        .to_bytes();

    let json = if bytes.is_empty() {
        serde_json::json!({})
    } else {
        match serde_json::from_slice::<serde_json::Value>(&bytes) {
            Ok(value) => value,
            Err(_) => serde_json::json!({ "raw": String::from_utf8_lossy(&bytes).to_string() }),
        }
    };

    (head.status, json)
}
/// Dispatch a caller-built request and hand back the unprocessed response.
///
/// Use this when a test needs the response headers (e.g. `Set-Cookie`) or
/// wants to read the body itself. No headers are added to the request.
pub async fn raw_request(
    &self,
    req: Request<Body>,
) -> axum::http::Response<Body> {
    // `oneshot` consumes the service, so each call works on its own clone.
    let service = self.router.clone();
    service
        .oneshot(req)
        .await
        .expect("Failed to send raw request")
}
// ── Auth helpers ─────────────────────────────────────────────────
/// Insert a user with role `UserRole::User` straight into the test database,
/// bypassing the HTTP API, and return the new user's id.
///
/// # Panics
///
/// Panics if the insert fails.
pub async fn create_test_user(&self, email: &str) -> uuid::Uuid {
    db::users::create(&self.pool, email, None, UserRole::User)
        .await
        .expect("Failed to create test user")
        .id
}
/// Insert a regular user and open a session for them.
///
/// Returns the user's id and the session token; the token is the value to
/// pass as `session_cookie` to the `*_with_session` request helpers.
///
/// # Panics
///
/// Panics if the user or the session cannot be created.
pub async fn create_authenticated_user(&self, email: &str) -> (uuid::Uuid, String) {
    let id = self.create_test_user(email).await;
    let token = auth::create_session(&self.pool, id, None, None)
        .await
        .expect("Failed to create session");
    (id, token)
}
/// Insert a user with role `UserRole::Admin` and open a session for them.
///
/// Returns the admin's id and the session token; the token is the value to
/// pass as `session_cookie` to the `*_with_session` request helpers.
///
/// # Panics
///
/// Panics if the user or the session cannot be created.
pub async fn create_admin_user(&self, email: &str) -> (uuid::Uuid, String) {
    let admin_id = db::users::create(&self.pool, email, None, UserRole::Admin)
        .await
        .expect("Failed to create admin user")
        .id;
    let token = auth::create_session(&self.pool, admin_id, None, None)
        .await
        .expect("Failed to create session for admin user");
    (admin_id, token)
}
/// Register `email` through the public endpoint (POST /api/v1/auth/register).
///
/// Sends a fixed placeholder Turnstile token, so the call relies on the
/// Turnstile test bypass configured in `new`. Returns the endpoint's status
/// code and JSON body as-is; no magic link token is read back from the
/// database here.
pub async fn register_user_via_api(&self, email: &str) -> (StatusCode, serde_json::Value) {
    let registration = serde_json::json!({
        "email": email,
        "turnstile_token": "test-token"
    });
    let outcome = self.post("/api/v1/auth/register", &registration).await;
    outcome
}
/// Create a magic link for `email` by calling `auth::create_magic_link`
/// directly, without going through the HTTP API.
///
/// Returns whatever that function yields on success.
/// NOTE(review): presumably `Some(raw_token)` — the database only stores
/// token hashes, so this is how a test obtains a token it can verify — and
/// `None` when no link is issued for the address; confirm against
/// `auth::create_magic_link`.
///
/// # Panics
///
/// Panics if `auth::create_magic_link` returns an error.
pub async fn create_magic_link_for_email(&self, email: &str) -> Option<String> {
    let outcome = auth::create_magic_link(&self.pool, email).await;
    outcome.expect("Failed to create magic link")
}
// ── Synthesis helpers ────────────────────────────────────────────
/// Store a synthesis for `user_id` and `week` straight in the database,
/// skipping the LLM generation pipeline, and return the new row's id.
///
/// `sections_json` should be a `serde_json::Value` array of
/// `{title, items: [{title, url, summary}]}` objects.
///
/// # Panics
///
/// Panics if the insert fails.
pub async fn insert_test_synthesis(
    &self,
    user_id: uuid::Uuid,
    week: &str,
    sections_json: &serde_json::Value,
) -> uuid::Uuid {
    // A fresh random UUID is supplied for the fifth argument of
    // `db::syntheses::create`; its meaning is defined by that function.
    let random_id = uuid::Uuid::new_v4();
    db::syntheses::create(&self.pool, user_id, week, sections_json, random_id, None)
        .await
        .expect("Failed to insert test synthesis")
        .id
}
}
#[allow(dead_code)]
impl TestApp {
    /// Explicitly clean up the test database.
    ///
    /// Call this at the end of each test to ensure the ephemeral database is
    /// dropped. If not called, the `Drop` implementation will attempt cleanup
    /// but may be less reliable.
    ///
    /// Cleanup is best-effort: a failing statement never panics, but it is
    /// reported on stderr so that leaked databases do not go unnoticed.
    ///
    /// Because `self` is consumed, the `Drop` implementation still runs when
    /// this method returns and repeats the same statements. That is harmless,
    /// since the drop uses `IF EXISTS`.
    pub async fn cleanup(self) {
        // Close the test pool to release all of its connections.
        self.pool.close().await;

        // Force-disconnect any remaining connections; Postgres refuses to
        // drop a database that still has sessions attached.
        let terminate_sql = format!(
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid()",
            self.db_name
        );
        if let Err(err) = self.admin_pool.execute(terminate_sql.as_str()).await {
            eprintln!(
                "warning: failed to terminate connections to test database {}: {err}",
                self.db_name
            );
        }

        let drop_sql = format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name);
        if let Err(err) = self.admin_pool.execute(drop_sql.as_str()).await {
            eprintln!(
                "warning: failed to drop test database {}: {err}",
                self.db_name
            );
        }
    }
}
impl Drop for TestApp {
    /// Fallback cleanup for tests that did not call `cleanup`.
    ///
    /// Starts a detached thread with its own single-threaded Tokio runtime
    /// that closes the test pool, terminates leftover connections to the test
    /// database and drops it. Every error is ignored, and nothing waits for
    /// the thread to finish.
    fn drop(&mut self) {
        // Everything the background thread needs is prepared up front so the
        // closure owns its data and does not borrow from `self`.
        let test_pool = self.pool.clone();
        let admin_pool = self.admin_pool.clone();
        let terminate_sql = format!(
            "SELECT pg_terminate_backend(pid) FROM pg_stat_activity WHERE datname = '{}' AND pid <> pg_backend_pid()",
            self.db_name
        );
        let drop_sql = format!("DROP DATABASE IF EXISTS \"{}\"", self.db_name);

        // Fire-and-forget. The handle is deliberately not joined: joining
        // deadlocks when running inside a tokio runtime (the spawned thread's
        // block_on conflicts with the existing runtime's connection pool).
        std::thread::spawn(move || {
            let runtime = tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .unwrap();
            runtime.block_on(async move {
                test_pool.close().await;
                let _ = admin_pool.execute(terminate_sql.as_str()).await;
                let _ = admin_pool.execute(drop_sql.as_str()).await;
            });
        });
    }
}