// Files: cg_api_secure-webshare/crates/cgcx-server/src/main.rs (845 lines, 31 KiB, Rust)

use axum::{
body::Body,
extract::{Path, Query, State},
http::{header, HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
routing::{get, post},
Json, Router,
};
use cgcx_config::Config;
use cgcx_core::{ContentId, CgcxError};
use cgcx_crypto::{unwrap_content_key, DecryptStream, MasterKey};
use cgcx_db::{Database, ContentRepo, ContentFileRepo};
use cgcx_storage::Storage;
use serde::{Deserialize, Serialize};
use std::net::IpAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::AsyncReadExt;
use tower_http::{
catch_panic::CatchPanicLayer,
compression::CompressionLayer,
cors::{AllowOrigin, CorsLayer},
services::ServeDir,
timeout::TimeoutLayer,
trace::TraceLayer,
};
use tracing::{info, warn};
use sodiumoxide::crypto::secretstream::xchacha20poly1305::Tag::Final as TagFinal;
/// Shared state handed to every handler through axum's `State` extractor.
/// The large members sit behind `Arc`, so cloning the state is cheap.
#[derive(Clone)]
struct AppState {
/// Metadata database (opened on `data/db.sqlite` in `main`).
db: Arc<Database>,
/// File storage built from the configured `storage.paths`.
storage: Arc<Storage>,
/// Loaded and validated application configuration.
config: Arc<Config>,
/// Master key used to unwrap each file's wrapped content key.
master_key: Arc<MasterKey>,
/// HMAC key for `cgcx_pw` cookies; `main` derives it as the BLAKE3 hash of the master key bytes.
cookie_secret: Vec<u8>,
/// Canonicalized storage directories (media, documents, text, temp). `serve_file`
/// refuses any stored file whose canonical path is outside all of them.
allowed_roots: Arc<Vec<std::path::PathBuf>>,
}
/// JSON body of `GET /api/health`.
#[derive(Serialize)]
struct HealthResponse {
/// Set to `"ok"` by the `health` handler.
status: String,
}
/// JSON body of `GET /api/content/{cxid}`, built by `get_metadata`.
#[derive(Serialize)]
struct ContentMetadata {
/// Content id rendered as text.
cxid: String,
/// One entry per file recorded for this content.
files: Vec<FileMetadata>,
/// Whether a password hash is stored, i.e. whether reads require a password.
has_password: bool,
/// View limit; `None` means no limit is enforced.
max_views: Option<u64>,
/// Views counted so far (`view_count` of the content record).
current_views: u64,
/// Whether `?download=true` (attachment disposition) is permitted.
allow_download: bool,
/// Creation timestamp, RFC 3339 formatted.
created_at: String,
}
/// Per-file entry inside `ContentMetadata`.
#[derive(Serialize)]
struct FileMetadata {
/// File index, as used in `/api/content/{cxid}/file/{file_idx}`.
idx: u32,
/// File name as recorded (`original_name`).
name: String,
/// Recorded MIME type; also sent as `Content-Type` when the file is served.
mime: String,
/// Size in bytes of the decrypted body (used as the full `Content-Length`).
size: u64,
/// Rendering flags passed through from the file record; their meaning is defined outside this file.
render_flags: u32,
}
/// JSON body of `POST /api/content/{cxid}/verify-password`.
#[derive(Deserialize)]
struct VerifyPasswordRequest {
/// Plaintext password, checked against the stored Argon2 hash.
password: String,
}
/// Query string accepted by the file endpoint (`?download=...&sc=...`).
#[derive(Deserialize)]
struct FileQuery {
/// Ask for `Content-Disposition: attachment`; absent means `false` (inline).
#[serde(default)]
download: bool,
/// Optional plaintext password, accepted as an alternative to the `cgcx_pw` cookie.
#[serde(rename = "sc", default)]
sc: Option<String>,
}
/// Query string accepted by the metadata endpoint (`?sc=...`).
#[derive(Deserialize, Default)]
struct ScQuery {
/// Optional plaintext password, accepted as an alternative to the `cgcx_pw` cookie.
#[serde(rename = "sc", default)]
sc: Option<String>,
}
/// One byte range taken from a `Range: bytes=...` header (see `parse_range`).
struct ByteRange {
/// Offset of the first byte (inclusive).
start: u64,
/// Offset of the last byte (inclusive); `None` means through the last byte of the file.
end: Option<u64>,
}
struct AppError(CgcxError);
impl From<CgcxError> for AppError {
fn from(e: CgcxError) -> Self {
Self(e)
}
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, msg) = match &self.0 {
CgcxError::NotFound => (StatusCode::NOT_FOUND, "Not found"),
CgcxError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized"),
CgcxError::Forbidden => (StatusCode::FORBIDDEN, "Forbidden"),
CgcxError::BadRequest(_) => (StatusCode::BAD_REQUEST, "Bad request"),
CgcxError::InvalidContentId(_) => (StatusCode::BAD_REQUEST, "Bad request"),
CgcxError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, "Rate limited"),
CgcxError::InsufficientStorage => (StatusCode::INSUFFICIENT_STORAGE, "Insufficient storage"),
other => {
tracing::error!("Internal server error: {}", other);
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error")
}
};
let body = serde_json::json!({ "error": msg });
(status, [(header::CONTENT_TYPE, "application/json")], body.to_string()).into_response()
}
}
type AppResult<T> = Result<T, AppError>;
/// Server entry point. It:
/// - loads the configuration, database, storage and master key;
/// - builds the router with its rate-limit, compression, security-header, tracing,
///   timeout, panic and CORS layers;
/// - starts the daily orphan sweeper;
/// - serves HTTP on `server.bind_address:server.port`.
#[tokio::main]
async fn main() -> cgcx_core::Result<()> {
    tracing_subscriber::fmt::init();
    // Log panics so we can diagnose 500s that CatchPanicLayer swallows.
    std::panic::set_hook(Box::new(|info| {
        let msg = if let Some(s) = info.payload().downcast_ref::<&str>() {
            s.to_string()
        } else if let Some(s) = info.payload().downcast_ref::<String>() {
            s.clone()
        } else {
            "unknown panic payload".to_string()
        };
        let location = info
            .location()
            .map(|l| format!("{}:{}", l.file(), l.line()))
            .unwrap_or_default();
        tracing::error!("PANIC at {}: {}", location, msg);
    }));
    let config = Arc::new(Config::load()?);
    config.validate()?;
    // Fail fast with the real I/O error; without `data/` the database open fails anyway.
    tokio::fs::create_dir_all("data").await.map_err(CgcxError::Io)?;
    let db = Arc::new(Database::open("data/db.sqlite")?);
    db.run_migrations().await?;
    let storage = Arc::new(Storage::new(config.storage.paths.clone()));
    storage.ensure_dirs().await?;
    let master_key = match &config.crypto.aes_master_key_source {
        cgcx_config::KeySource::Env { var } => MasterKey::load_from_env(var)?,
        cgcx_config::KeySource::File { path } => MasterKey::load_from_file(path)?,
    };
    master_key.log_startup(false);
    // Cookie-signing key derived from (not equal to) the master key.
    let cookie_secret = blake3::hash(master_key.as_bytes()).as_bytes().to_vec();
    // Canonical storage roots; `serve_file` refuses files that resolve outside them.
    let allowed_roots = Arc::new(vec![
        tokio::fs::canonicalize(&config.storage.paths.media).await.map_err(CgcxError::Io)?,
        tokio::fs::canonicalize(&config.storage.paths.documents).await.map_err(CgcxError::Io)?,
        tokio::fs::canonicalize(&config.storage.paths.text).await.map_err(CgcxError::Io)?,
        tokio::fs::canonicalize(&config.storage.paths.temp).await.map_err(CgcxError::Io)?,
    ]);
    let state = AppState {
        db,
        storage,
        config: config.clone(),
        master_key: Arc::new(master_key),
        cookie_secret,
        allowed_roots,
    };

    // General limiter. `.max(1)` keeps a configured 0 from panicking on `Duration / 0`.
    let mut governor_builder = tower_governor::governor::GovernorConfigBuilder::default();
    let mut governor_builder = governor_builder.key_extractor(CgcxKeyExtractor);
    governor_builder
        .period(Duration::from_secs(60) / config.rate_limiting.requests_per_minute.max(1));
    governor_builder.burst_size(config.rate_limiting.burst);
    let governor_conf = governor_builder
        .finish()
        .expect("invalid general rate limit config");

    // Separate, stricter limiter for password attempts.
    let mut password_governor_builder = tower_governor::governor::GovernorConfigBuilder::default();
    let mut password_governor_builder = password_governor_builder.key_extractor(CgcxKeyExtractor);
    password_governor_builder.period(
        Duration::from_secs(60) / config.rate_limiting.password_attempts_per_minute.max(1),
    );
    password_governor_builder.burst_size(3);
    let password_governor_conf = password_governor_builder
        .finish()
        .expect("invalid password rate limit config");
    let password_route = Router::new()
        .route("/api/content/{cxid}/verify-password", post(verify_password))
        .layer(tower_governor::GovernorLayer {
            config: Arc::new(password_governor_conf),
        });

    let static_service = ServeDir::new("frontend/dist/assets");
    let cors = build_cors_layer(&config.server.base_url);
    let compression = CompressionLayer::new().compress_when(
        |_status: axum::http::StatusCode,
         _version: axum::http::Version,
         headers: &axum::http::HeaderMap,
         _extensions: &axum::http::Extensions| {
            headers
                .get(header::CONTENT_TYPE)
                .and_then(|v| v.to_str().ok())
                .map(|ct| {
                    ct.starts_with("text/html")
                        || ct.starts_with("text/css")
                        || ct.starts_with("application/json")
                        || ct.starts_with("text/plain")
                })
                .unwrap_or(false)
        },
    );
    let app = Router::new()
        .route("/api/health", get(health))
        .route("/api/content/{cxid}", get(get_metadata))
        .route("/api/content/{cxid}/file/{file_idx}", get(serve_file))
        .merge(password_route)
        .nest_service("/assets", static_service)
        .fallback(fallback)
        .layer(tower_governor::GovernorLayer {
            config: Arc::new(governor_conf),
        })
        .layer(compression)
        .layer(axum::middleware::from_fn(security_headers))
        .layer(TraceLayer::new_for_http())
        .layer(TimeoutLayer::with_status_code(
            StatusCode::REQUEST_TIMEOUT,
            Duration::from_secs(30),
        ))
        .layer(CatchPanicLayer::new())
        .layer(cors)
        .with_state(state.clone());

    spawn_orphan_sweeper(state.db.clone(), state.storage.clone(), (*state.config).clone());

    let addr = format!("{}:{}", config.server.bind_address, config.server.port);
    info!("Server listening on http://{}", addr);
    let listener = tokio::net::TcpListener::bind(&addr).await.map_err(CgcxError::Io)?;
    // Attach the peer address to every request. Without this the
    // `ConnectInfo<SocketAddr>` extension is never present, and `CgcxKeyExtractor`
    // cannot key rate limits on the real client address.
    axum::serve(
        listener,
        app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
    )
    .await
    .map_err(CgcxError::Io)?;
    Ok(())
}

/// Builds the CORS policy for the API. It allows:
/// - the configured public origin plus fixed local development origins (ports 5173
///   and 8090 on `127.0.0.1` and `localhost`);
/// - GET/POST/HEAD/OPTIONS with credentials;
/// - a one-day preflight cache.
///
/// # Panics
/// If `base_url` cannot be used as a header value.
fn build_cors_layer(base_url: &str) -> CorsLayer {
    let mut origins: Vec<HeaderValue> = vec![base_url.parse().expect("invalid server.base_url")];
    for origin in [
        "http://127.0.0.1:5173",
        "http://localhost:5173",
        "http://127.0.0.1:8090",
        "http://localhost:8090",
    ] {
        if let Ok(hv) = origin.parse::<HeaderValue>() {
            if !origins.contains(&hv) {
                origins.push(hv);
            }
        }
    }
    CorsLayer::new()
        .allow_origin(AllowOrigin::list(origins))
        .allow_methods([Method::GET, Method::POST, Method::HEAD, Method::OPTIONS])
        .allow_headers([
            header::CONTENT_TYPE,
            header::AUTHORIZATION,
            header::ACCEPT,
            header::ACCEPT_ENCODING,
            header::RANGE,
        ])
        .allow_credentials(true)
        .max_age(Duration::from_secs(86400))
}

/// Spawns a detached task that runs `FilePipeline::cleanup_orphans` every 24 hours.
/// The first run happens one day after startup. A failed run is logged and the
/// loop continues to the next day.
fn spawn_orphan_sweeper(db: Arc<Database>, storage: Arc<Storage>, config: Config) {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(24 * 60 * 60));
        interval.tick().await; // skip immediate first tick
        loop {
            interval.tick().await;
            info!("Running daily orphan cleanup");
            let pipeline = cgcx_file_pipeline::FilePipeline::new(
                (*storage).clone(),
                (*db).clone(),
                config.clone(),
            );
            if let Err(e) = pipeline.cleanup_orphans().await {
                warn!("Orphan cleanup failed: {}", e);
            }
        }
    });
}
async fn fallback(uri: axum::http::Uri) -> Response {
let path = uri.path();
tracing::info!("fallback: path={}", path);
if path.starts_with("/api/") {
return (StatusCode::NOT_FOUND, Json(serde_json::json!({"error": "Not found"}))).into_response();
}
match tokio::fs::read_to_string("frontend/dist/index.html").await {
Ok(html) => (StatusCode::OK, [(header::CONTENT_TYPE, "text/html")], html).into_response(),
Err(_) => (StatusCode::NOT_FOUND, "Frontend not built").into_response(),
}
}
/// Middleware that runs the inner service, then forces a fixed set of hardening
/// headers onto its response: CSP, nosniff, frame denial, referrer policy,
/// permissions policy and HSTS. Any existing value for these names is replaced.
async fn security_headers(req: axum::http::Request<Body>, next: Next) -> Response {
    let mut response = next.run(req).await;
    let hardening = [
        (
            header::CONTENT_SECURITY_POLICY,
            "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline'; img-src 'self' data: blob:; media-src 'self' blob:; connect-src 'self'; frame-ancestors 'none'; base-uri 'self'; form-action 'self';",
        ),
        (header::X_CONTENT_TYPE_OPTIONS, "nosniff"),
        (header::X_FRAME_OPTIONS, "DENY"),
        (header::REFERRER_POLICY, "strict-origin-when-cross-origin"),
        (
            HeaderName::from_static("permissions-policy"),
            "accelerometer=(), camera=(), geolocation=(), gyroscope=(), magnetometer=(), microphone=(), payment=(), usb=()",
        ),
        (
            header::STRICT_TRANSPORT_SECURITY,
            "max-age=31536000; includeSubDomains; preload",
        ),
    ];
    let target = response.headers_mut();
    for (name, value) in hardening {
        target.insert(name, HeaderValue::from_static(value));
    }
    response
}
async fn health() -> impl IntoResponse {
tracing::info!("health");
axum::Json(HealthResponse {
status: "ok".into(),
})
}
/// Returns `true` when the request proves knowledge of the share's password.
///
/// Two proofs are accepted:
/// 1. The `sc` query value verifies against the stored Argon2 `password_hash`.
///    An unparsable hash or a wrong value simply falls through to step 2.
/// 2. Any `cgcx_pw=<value>` pair, in any `Cookie` header, passes
///    `verify_cookie` for `cxid` under `cookie_secret`.
fn password_from_request(
    headers: &HeaderMap,
    query_sc: Option<&str>,
    cxid: &str,
    password_hash: Option<&str>,
    cookie_secret: &[u8],
) -> bool {
    use argon2::{Argon2, PasswordHash, PasswordVerifier};

    let query_ok = match (query_sc, password_hash) {
        (Some(candidate), Some(stored)) => PasswordHash::new(stored)
            .map(|parsed| {
                Argon2::default()
                    .verify_password(candidate.as_bytes(), &parsed)
                    .is_ok()
            })
            .unwrap_or(false),
        _ => false,
    };
    if query_ok {
        return true;
    }

    headers
        .get_all(header::COOKIE)
        .iter()
        .filter_map(|value| value.to_str().ok())
        .flat_map(|raw| raw.split(';'))
        .filter_map(|pair| pair.trim().strip_prefix("cgcx_pw="))
        .any(|token| verify_cookie(cxid, token, cookie_secret))
}
async fn get_metadata(
State(state): State<AppState>,
Path(cxid): Path<String>,
Query(query): Query<ScQuery>,
headers: HeaderMap,
) -> AppResult<Response> {
tracing::info!("get_metadata: cxid={}", cxid);
let content_id = ContentId::try_from(cxid.as_str())?;
let repo = ContentRepo::new(state.db.conn());
let content = repo.get(&content_id).await?.ok_or(CgcxError::NotFound)?;
if content.status == cgcx_core::ContentStatus::Deleted || content.status == cgcx_core::ContentStatus::Blacklisted {
tracing::warn!("get_metadata returning NotFound for cxid={}", cxid);
return Err(CgcxError::NotFound.into());
}
if let Some(max) = content.max_views {
if content.view_count >= max {
return Ok(Response::builder()
.status(StatusCode::GONE)
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
}
}
if content.password_hash.is_some() {
if !password_from_request(&headers, query.sc.as_deref(), &cxid, content.password_hash.as_deref(), &state.cookie_secret) {
tracing::warn!("get_metadata returning Unauthorized for cxid={}", cxid);
return Err(CgcxError::Unauthorized.into());
}
}
let file_repo = ContentFileRepo::new(state.db.conn());
let files = file_repo.list_by_content(&content_id).await?;
let body = serde_json::to_vec(&ContentMetadata {
cxid: content.id.to_string(),
files: files.into_iter().map(|f| FileMetadata {
idx: f.file_index,
name: f.original_name,
mime: f.mime_type,
size: f.size_bytes,
render_flags: f.render_flags,
}).collect(),
has_password: content.password_hash.is_some(),
max_views: content.max_views,
current_views: content.view_count,
allow_download: content.allow_download,
created_at: content.created_at.to_rfc3339(),
}).map_err(|_| CgcxError::BadRequest("json serialization".into()))?;
Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(body))
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?)
}
async fn verify_password(
State(state): State<AppState>,
Path(cxid): Path<String>,
Json(req): Json<VerifyPasswordRequest>,
) -> AppResult<impl IntoResponse> {
tracing::info!("verify_password: cxid={}", cxid);
let content_id = ContentId::try_from(cxid.as_str())?;
let repo = ContentRepo::new(state.db.conn());
let content = repo.get(&content_id).await?.ok_or(CgcxError::NotFound)?;
let Some(hash) = content.password_hash else {
return Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
};
use argon2::{Argon2, PasswordHash, PasswordVerifier};
let parsed_hash = PasswordHash::new(&hash)
.map_err(|_| CgcxError::Crypto("invalid stored password hash".into()))?;
let valid = Argon2::default()
.verify_password(req.password.as_bytes(), &parsed_hash)
.is_ok();
if !valid {
tracing::warn!("verify_password returning Unauthorized for cxid={}", cxid);
return Err(CgcxError::Unauthorized.into());
}
let cookie_value = make_cookie_value(&cxid, &state.cookie_secret);
let cookie = format!(
"cgcx_pw={}; Max-Age=3600; SameSite=Strict; HttpOnly; Path=/",
cookie_value
);
Ok(Response::builder()
.status(StatusCode::NO_CONTENT)
.header(header::SET_COOKIE, cookie)
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?)
}
async fn serve_file(
State(state): State<AppState>,
Path((cxid, file_idx)): Path<(String, u32)>,
Query(query): Query<FileQuery>,
headers: HeaderMap,
) -> AppResult<impl IntoResponse> {
tracing::info!("serve_file: cxid={} file_idx={}", cxid, file_idx);
let content_id = ContentId::try_from(cxid.as_str())?;
let repo = ContentRepo::new(state.db.conn());
let content = repo.get(&content_id).await?.ok_or(CgcxError::NotFound)?;
if content.status == cgcx_core::ContentStatus::Deleted || content.status == cgcx_core::ContentStatus::Blacklisted {
tracing::warn!("serve_file returning NotFound for cxid={}", cxid);
return Err(CgcxError::NotFound.into());
}
if let Some(max) = content.max_views {
if content.view_count >= max {
return Ok(Response::builder()
.status(StatusCode::GONE)
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
}
}
if content.password_hash.is_some() {
if !password_from_request(&headers, query.sc.as_deref(), &cxid, content.password_hash.as_deref(), &state.cookie_secret) {
tracing::warn!("serve_file returning Unauthorized for cxid={}", cxid);
return Err(CgcxError::Unauthorized.into());
}
}
if query.download && !content.allow_download {
tracing::warn!("serve_file returning Forbidden (download not allowed) for cxid={}", cxid);
return Err(CgcxError::Forbidden.into());
}
let file_repo = ContentFileRepo::new(state.db.conn());
let files = file_repo.list_by_content(&content_id).await?;
let file = files.iter().find(|f| f.file_index == file_idx).ok_or(CgcxError::NotFound)?;
// Handle zero-size files early to avoid underflow in range parsing
if file.size_bytes == 0 {
let etag = format!("\"{}\"", hex::encode(&file.encrypted_hash));
let content_type = file.mime_type.clone();
let sanitized_name = sanitize_content_disposition(&file.original_name);
let disposition = if query.download && content.allow_download {
format!("attachment; filename=\"{}\"", sanitized_name)
} else {
format!("inline; filename=\"{}\"", sanitized_name)
};
return Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_DISPOSITION, disposition)
.header(header::ETAG, etag)
.header(header::CONTENT_LENGTH, "0")
.header(header::CACHE_CONTROL, "private, no-store, max-age=0")
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
}
// Path traversal validation
let canonical_path = tokio::fs::canonicalize(&file.stored_path).await
.map_err(|e| {
tracing::error!("canonicalize failed for {:?}: {}", file.stored_path, e);
CgcxError::Storage("invalid stored path".into())
})?;
if !state.allowed_roots.iter().any(|root| canonical_path.starts_with(root)) {
tracing::error!("Path traversal blocked: {:?}", canonical_path);
tracing::warn!("serve_file returning Forbidden (path traversal) for cxid={}", cxid);
return Err(CgcxError::Forbidden.into());
}
let etag = format!("\"{}\"", hex::encode(&file.encrypted_hash));
// If-None-Match check (skip increment)
if let Some(inm) = headers.get(header::IF_NONE_MATCH) {
if inm.to_str().ok().map(|s| s == etag).unwrap_or(false) {
return Ok(Response::builder()
.status(StatusCode::NOT_MODIFIED)
.header(header::ETAG, etag.clone())
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
}
}
// Parse Range header
let range = if let Some(range_hdr) = headers.get(header::RANGE) {
if let Some(hdr_str) = range_hdr.to_str().ok() {
match parse_range(hdr_str, file.size_bytes) {
Some(r) => Some(r),
None => {
return Ok(Response::builder()
.status(StatusCode::RANGE_NOT_SATISFIABLE)
.header(header::CONTENT_RANGE, format!("bytes */{}", file.size_bytes))
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
}
}
} else {
None
}
} else {
None
};
let is_range = range.is_some();
let is_conditional = headers.contains_key(header::IF_NONE_MATCH);
if !is_range && !is_conditional {
let new_views = repo.increment_views(&content_id).await?;
if let Some(max) = content.max_views {
if new_views >= max {
if !state.config.content.keep_content {
for f in &files {
if let Err(e) = tokio::fs::remove_file(&f.stored_path).await {
tracing::warn!("failed to remove file {:?}: {}", f.stored_path, e);
}
}
let _ = state.storage.delete_content_files(&content_id, "application/octet-stream").await;
}
repo.set_status(&content_id, cgcx_core::ContentStatus::Deleted).await?;
return Ok(Response::builder()
.status(StatusCode::GONE)
.body(Body::empty())
.map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?);
}
}
}
let content_type = file.mime_type.clone();
let sanitized_name = sanitize_content_disposition(&file.original_name);
let disposition = if query.download && content.allow_download {
format!("attachment; filename=\"{}\"", sanitized_name)
} else {
format!("inline; filename=\"{}\"", sanitized_name)
};
let (status, content_length, content_range) = if let Some(ref r) = range {
let end = r.end.unwrap_or(file.size_bytes - 1);
let len = end - r.start + 1;
let cr = format!("bytes {}-{}/{}", r.start, end, file.size_bytes);
(StatusCode::PARTIAL_CONTENT, len, Some(cr))
} else {
(StatusCode::OK, file.size_bytes, None)
};
let mut response = Response::builder()
.status(status)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_DISPOSITION, disposition)
.header(header::ETAG, etag.clone())
.header(header::CONTENT_LENGTH, content_length.to_string());
if file.mime_type.starts_with("video/") || file.mime_type.starts_with("audio/") {
response = response.header(header::ACCEPT_RANGES, "bytes");
}
if let Some(cr) = content_range {
response = response.header(header::CONTENT_RANGE, cr);
}
if content.password_hash.is_some() {
response = response.header(header::CACHE_CONTROL, "private, no-store, max-age=0");
} else {
response = response.header(header::CACHE_CONTROL, "private, max-age=60");
}
let (tx, rx) = tokio::sync::mpsc::channel::<Result<Vec<u8>, std::io::Error>>(4);
let path = file.stored_path.clone();
let master_key = state.master_key.clone();
let wrapped_key = file.encrypted_key_wrapped.clone();
let expected_hash = file.encrypted_hash.clone();
let file_size = file.size_bytes;
tokio::spawn(async move {
if let Err(e) = stream_decrypted_file(path, master_key, wrapped_key, tx, range, file_size, expected_hash).await {
warn!("stream error: {}", e);
}
});
let body_stream = tokio_stream::wrappers::ReceiverStream::new(rx);
let body = Body::from_stream(body_stream);
Ok(response.body(body).map_err(|e| CgcxError::Storage(format!("response build failed: {}", e)))?)
}
async fn stream_decrypted_file(
path: std::path::PathBuf,
master_key: Arc<MasterKey>,
wrapped_key: Vec<u8>,
tx: tokio::sync::mpsc::Sender<Result<Vec<u8>, std::io::Error>>,
_range: Option<ByteRange>,
_file_size: u64,
expected_hash: Vec<u8>,
) -> cgcx_core::Result<()> {
let mut file = tokio::fs::File::open(&path).await.map_err(|e| CgcxError::Storage(e.to_string()))?;
let mut header_buf = [0u8; 24];
file.read_exact(&mut header_buf).await.map_err(|e| CgcxError::Storage(e.to_string()))?;
let content_key = unwrap_content_key(&wrapped_key, &master_key)?;
let header = sodiumoxide::crypto::secretstream::xchacha20poly1305::Header::from_slice(&header_buf)
.ok_or_else(|| CgcxError::Crypto("invalid header".into()))?;
let mut decrypt_stream = DecryptStream::new(&content_key, &header)?;
let mut len_buf = [0u8; 4];
let mut saw_final = false;
loop {
if file.read_exact(&mut len_buf).await.is_err() {
break; // EOF at message boundary
}
let msg_len = u32::from_le_bytes(len_buf) as usize;
if msg_len > 50_000_000 {
let _ = tx.send(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "message too large"))).await;
return Err(CgcxError::Crypto("message length exceeds sanity bound".into()));
}
let mut msg_buf = vec![0u8; msg_len];
file.read_exact(&mut msg_buf).await.map_err(|e| CgcxError::Storage(e.to_string()))?;
match decrypt_stream.pull(&msg_buf) {
Ok((plaintext, tag)) => {
if tx.send(Ok(plaintext)).await.is_err() {
return Ok(()); // client disconnected
}
if tag == TagFinal {
saw_final = true;
break;
}
}
Err(e) => {
let _ = tx.send(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))).await;
return Err(e);
}
}
}
if !saw_final {
let _ = tx.send(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "stream ended without Final tag"))).await;
return Err(CgcxError::Crypto("stream ended without Final tag".into()));
}
let computed_hash = decrypt_stream.finalize().to_vec();
if computed_hash != expected_hash {
tracing::error!(target: "critical", "BLAKE3 integrity mismatch for file {:?}: expected {} got {}", path, hex::encode(&expected_hash), hex::encode(&computed_hash));
let _ = tx.send(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, "integrity check failed"))).await;
return Err(CgcxError::Crypto("BLAKE3 integrity mismatch".into()));
}
Ok(())
}
/// Parses a single-range `Range` header against a representation of `file_size`
/// bytes. Accepted forms are `bytes=a-b`, `bytes=a-` and `bytes=-n`.
///
/// Returns `None`, which `serve_file` answers with 416, for anything it cannot
/// satisfy:
/// - another unit, or multiple ranges;
/// - non-digit positions (RFC 9110 allows only DIGITs, so `+5` is refused);
/// - `start > end`, or a start at or past the end of the file;
/// - a zero-length suffix (`bytes=-0`);
/// - any range over an empty file.
///
/// An `end` past the last byte is clamped. A suffix longer than the file selects
/// the whole file.
fn parse_range(range_hdr: &str, file_size: u64) -> Option<ByteRange> {
    // No byte of an empty representation can be addressed.
    if file_size == 0 {
        return None;
    }
    let spec = range_hdr.strip_prefix("bytes=")?;
    // Basic version: only a single range is supported.
    if spec.contains(',') {
        return None;
    }
    let (start_str, end_str) = spec.split_once('-')?;
    let (start_str, end_str) = (start_str.trim(), end_str.trim());
    let parse_pos = |s: &str| -> Option<u64> {
        if s.bytes().all(|b| b.is_ascii_digit()) {
            s.parse().ok()
        } else {
            None
        }
    };
    let last = file_size - 1;
    match (start_str.is_empty(), end_str.is_empty()) {
        (true, true) => None,
        // Suffix form: the final `n` bytes; `n == 0` selects nothing.
        (true, false) => {
            let suffix_len = parse_pos(end_str)?;
            if suffix_len == 0 {
                return None;
            }
            Some(ByteRange {
                start: file_size.saturating_sub(suffix_len),
                end: Some(last),
            })
        }
        // Open-ended form: from `start` to the end of the file.
        (false, true) => {
            let start = parse_pos(start_str)?;
            (start < file_size).then_some(ByteRange { start, end: None })
        }
        (false, false) => {
            let start = parse_pos(start_str)?;
            let end = parse_pos(end_str)?;
            if start > end || start >= file_size {
                return None;
            }
            Some(ByteRange {
                start,
                end: Some(end.min(last)),
            })
        }
    }
}
/// Makes `name` safe to place inside a quoted `filename="..."` parameter.
/// Control characters are dropped, and `\` and `"` are backslash-escaped.
fn sanitize_content_disposition(name: &str) -> String {
    let mut escaped = String::with_capacity(name.len());
    for ch in name.chars().filter(|ch| !ch.is_control()) {
        if ch == '\\' || ch == '"' {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;

/// Server-side lifetime of a `cgcx_pw` cookie, in seconds. Matches the
/// `Max-Age=3600` that `verify_password` puts on the cookie.
const COOKIE_TTL_SECS: u64 = 3600;

/// Current Unix time in whole seconds (0 if the clock reads before the epoch).
fn unix_now_secs() -> u64 {
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0)
}

/// HMAC-SHA256 of `payload` keyed with `secret`.
fn hmac_cookie(payload: &str, secret: &[u8]) -> Vec<u8> {
    let mut mac = HmacSha256::new_from_slice(secret).expect("HMAC can take key of any size");
    mac.update(payload.as_bytes());
    mac.finalize().into_bytes().to_vec()
}

/// Builds the value of a `cgcx_pw` cookie proving that the password for `cxid`
/// was verified.
///
/// Format: standard base64 of `"{cxid}:{expiry}:"` followed by
/// `HMAC(secret, "{cxid}:{expiry}")`. Here `expiry` is a Unix timestamp
/// `COOKIE_TTL_SECS` from now. Signing the expiry makes the server enforce the
/// lifetime too, so a copied cookie value stops working after the TTL.
fn make_cookie_value(cxid: &str, secret: &[u8]) -> String {
    use base64::Engine;
    let signed = format!("{}:{}", cxid, unix_now_secs().saturating_add(COOKIE_TTL_SECS));
    let mac = hmac_cookie(&signed, secret);
    let mut raw = Vec::with_capacity(signed.len() + 1 + mac.len());
    raw.extend_from_slice(signed.as_bytes());
    raw.push(b':');
    raw.extend_from_slice(&mac);
    base64::engine::general_purpose::STANDARD.encode(&raw)
}

/// Checks a `cgcx_pw` cookie value produced by `make_cookie_value`.
///
/// Returns `true` only if all of the following hold:
/// - the value decodes;
/// - it names `cxid`;
/// - its expiry is still in the future;
/// - its MAC matches, compared in constant time.
///
/// Values in any other layout, including one without an expiry, are rejected.
fn verify_cookie(cxid: &str, cookie_value: &str, secret: &[u8]) -> bool {
    use base64::Engine;
    use subtle::ConstantTimeEq;
    let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(cookie_value) else {
        return false;
    };
    // Layout: cxid ':' expiry ':' mac. The binary MAC may contain ':', so split at
    // most twice. A cxid that itself contains ':' never matches below (fails closed).
    let mut parts = decoded.splitn(3, |&b| b == b':');
    let (Some(cxid_bytes), Some(expiry_bytes), Some(mac_bytes)) =
        (parts.next(), parts.next(), parts.next())
    else {
        return false;
    };
    if cxid_bytes != cxid.as_bytes() {
        return false;
    }
    // Expiry must be a plain decimal Unix timestamp that has not passed yet.
    let Ok(expiry_str) = std::str::from_utf8(expiry_bytes) else {
        return false;
    };
    if expiry_str.is_empty() || !expiry_str.bytes().all(|b| b.is_ascii_digit()) {
        return false;
    }
    let Ok(expiry) = expiry_str.parse::<u64>() else {
        return false;
    };
    if expiry <= unix_now_secs() {
        return false;
    }
    let expected = hmac_cookie(&format!("{}:{}", cxid, expiry_str), secret);
    if mac_bytes.len() != expected.len() {
        return false;
    }
    mac_bytes.ct_eq(&expected).into()
}
// Custom key extractor for tower_governor that never fails with UnableToExtractKey.
//
// Forwarding headers are trusted only when the TCP peer looks like a local reverse
// proxy (loopback, private, link-local or unique-local address). They are also used
// when the peer address is unknown because the server runs without `ConnectInfo`.
// A client connecting directly from a public address is keyed on that address, so
// it cannot choose its own bucket with a forged `X-Forwarded-For`.
//
// NOTE(review): a reverse proxy that reaches this server from a public address
// will have all its traffic keyed on the proxy IP. Such deployments need an
// explicit trusted-proxy list.
#[derive(Clone, Copy, Debug)]
struct CgcxKeyExtractor;

impl CgcxKeyExtractor {
    /// Whether `ip` is a non-public address that could plausibly be our own proxy.
    fn is_local_peer(ip: IpAddr) -> bool {
        match ip {
            IpAddr::V4(v4) => v4.is_loopback() || v4.is_private() || v4.is_link_local(),
            IpAddr::V6(v6) => {
                if let Some(v4) = v6.to_ipv4_mapped() {
                    return Self::is_local_peer(IpAddr::V4(v4));
                }
                let first = v6.segments()[0];
                v6.is_loopback()
                    || (first & 0xfe00) == 0xfc00 // fc00::/7 unique local
                    || (first & 0xffc0) == 0xfe80 // fe80::/10 link local
            }
        }
    }

    /// Value of header `name` as `&str`, if present and visible ASCII.
    fn header_str<'a, T>(req: &'a axum::http::Request<T>, name: &str) -> Option<&'a str> {
        req.headers().get(name).and_then(|v| v.to_str().ok())
    }

    /// Client IP reported by a reverse proxy. Prefers the right-most
    /// `X-Forwarded-For` entry that parses as an IP, which is typically the one
    /// appended by the nearest proxy. Otherwise uses `X-Real-Ip` if it parses.
    fn forwarded_ip<T>(req: &axum::http::Request<T>) -> Option<IpAddr> {
        Self::header_str(req, "x-forwarded-for")
            .and_then(|list| {
                list.rsplit(',')
                    .find_map(|entry| entry.trim().parse::<IpAddr>().ok())
            })
            .or_else(|| {
                Self::header_str(req, "x-real-ip").and_then(|s| s.trim().parse::<IpAddr>().ok())
            })
    }
}

impl tower_governor::key_extractor::KeyExtractor for CgcxKeyExtractor {
    type Key = String;
    fn extract<T>(&self, req: &axum::http::Request<T>) -> Result<Self::Key, tower_governor::errors::GovernorError> {
        let peer = req
            .extensions()
            .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
            .map(|info| info.0.ip());
        // 1. A directly connected public client: its own address is the only trustworthy key.
        if let Some(ip) = peer.filter(|ip| !Self::is_local_peer(*ip)) {
            return Ok(ip.to_string());
        }
        // 2. Behind a local reverse proxy (or peer unknown): use the forwarded client IP.
        if let Some(ip) = Self::forwarded_ip(req) {
            return Ok(ip.to_string());
        }
        // 3. Local peer that sent no usable forwarding header.
        if let Some(ip) = peer {
            return Ok(ip.to_string());
        }
        // 4. Fall back to User-Agent so different browsers get different buckets.
        if let Some(ua) = Self::header_str(req, "user-agent") {
            return Ok(format!("ua:{}", ua));
        }
        // 5. Ultimate fallback: global bucket.
        Ok("global".to_string())
    }
}