Files
harmony/harmony_secret/src/store/zitadel.rs

440 lines
14 KiB
Rust

use log::{debug, info};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::time::Duration;
/// Authentication session produced by the Zitadel device flow.
///
/// Instances are serialized to JSON by `save_session` and read back by
/// `load_session`, so the field names are part of the on-disk cache format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcSession {
    /// Token handed to callers: the OpenBao client token when a JWT exchange
    /// was performed, otherwise the OIDC access token itself.
    pub openbao_token: String,
    /// Lifetime of `openbao_token` as reported when it was issued
    /// (presumably seconds — the exchange logs it with an `s` suffix).
    pub openbao_token_ttl: u64,
    /// Whether the token was reported as renewable; `true` when no JWT
    /// exchange took place.
    pub openbao_renewable: bool,
    /// Refresh token from the OIDC token response, if one was returned.
    pub refresh_token: Option<String>,
    /// ID token from the OIDC token response, if one was returned.
    pub id_token: Option<String>,
    /// Unix timestamp (seconds) at which the session expires. `None` means
    /// the session is never reported as expired.
    pub expires_at: Option<i64>,
}
impl OidcSession {
    /// Current wall-clock time as whole seconds since the Unix epoch.
    ///
    /// Falls back to `0` if the system clock reports a time before the epoch.
    fn unix_now() -> i64 {
        match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
            Ok(elapsed) => elapsed.as_secs() as i64,
            Err(_) => 0,
        }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    ///
    /// A session without an `expires_at` value is never reported as expired.
    pub fn is_expired(&self) -> bool {
        match self.expires_at {
            Some(deadline) => deadline <= Self::unix_now(),
            None => false,
        }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    ///
    /// The `_ttl` argument is accepted but not used; the check is the same as
    /// the one performed by `is_expired`.
    pub fn is_openbao_token_expired(&self, _ttl: u64) -> bool {
        self.expires_at
            .map_or(false, |deadline| deadline <= Self::unix_now())
    }
}
/// Body returned by the `/oauth/v2/device_authorization` endpoint.
#[derive(Debug, Deserialize)]
struct DeviceAuthorizationResponse {
    /// Opaque code sent back to the token endpoint while polling.
    device_code: String,
    /// Short code the user types on the verification page.
    user_code: String,
    /// Page the user must open to enter `user_code`.
    verification_uri: String,
    /// Optional link with the user code already embedded.
    verification_uri_complete: Option<String>,
    /// Lifetime of `device_code`, as reported by the server.
    expires_in: u64,
    /// Polling interval in seconds. RFC 8628 marks this field as optional and
    /// mandates a default of 5 when it is absent, so a missing value must not
    /// make deserialization fail.
    #[serde(default = "default_device_poll_interval")]
    interval: u64,
}

/// Polling interval (seconds) to use when the server omits `interval`.
fn default_device_poll_interval() -> u64 {
    5
}
/// Successful body returned by the `/oauth/v2/token` endpoint.
///
/// Every optional field falls back to `None` when the server leaves it out.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// OIDC access token.
    access_token: String,
    /// Lifetime of the access token, when reported.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Refresh token, when one was issued.
    #[serde(default)]
    refresh_token: Option<String>,
    /// ID token, when one was issued.
    #[serde(default)]
    id_token: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
error: String,
#[serde(rename = "error_description")]
error_description: Option<String>,
}
/// Location of the on-disk session cache.
///
/// The file name is `oidc_session_` followed by the 16-digit hex hash of the
/// fixed key `"zitadel-oidc"`. The file lives under
/// `<data dir>/harmony/secrets/`, or directly under `/tmp` when no base
/// directories can be determined.
fn get_session_cache_path() -> PathBuf {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    "zitadel-oidc".hash(&mut state);
    let file_name = format!("oidc_session_{:016x}", state.finish());

    match directories::BaseDirs::new() {
        Some(base) => base
            .data_dir()
            .join("harmony")
            .join("secrets")
            .join(file_name),
        None => PathBuf::from(format!("/tmp/{file_name}")),
    }
}
/// Reads the cached session from disk and parses it.
///
/// # Errors
///
/// Returns a message naming the cache path when the file cannot be read or
/// its contents are not a valid JSON-encoded [`OidcSession`].
fn load_session() -> Result<OidcSession, String> {
    let path = get_session_cache_path();

    let raw = match fs::read_to_string(&path) {
        Ok(contents) => contents,
        Err(e) => return Err(format!("Could not load session from {path:?}: {e}")),
    };

    serde_json::from_str(&raw)
        .map_err(|e| format!("Could not deserialize session from {path:?}: {e}"))
}
/// Writes `session` to the on-disk cache as pretty-printed JSON.
///
/// The parent directory is created when missing. On Unix the file is
/// restricted to mode `0o600` because it contains bearer tokens.
///
/// # Errors
///
/// Returns a descriptive message when serialization, directory creation,
/// opening, permission tightening or writing fails.
fn save_session(session: &OidcSession) -> Result<(), String> {
    let path = get_session_cache_path();

    // Serialize before touching the file so a serialization failure cannot
    // leave a truncated, empty cache behind.
    let json = serde_json::to_string_pretty(session)
        .map_err(|e| format!("Could not serialize session: {e}"))?;

    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)
            .map_err(|e| format!("Could not create session directory: {e}"))?;
    }

    #[cfg(unix)]
    {
        use std::io::Write;
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = fs::OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(&path)
            .map_err(|e| format!("Could not open session file: {e}"))?;

        // `mode` only takes effect when the file is newly created; tighten the
        // permissions explicitly so a pre-existing file is restricted too.
        file.set_permissions(fs::Permissions::from_mode(0o600))
            .map_err(|e| format!("Could not restrict session file permissions: {e}"))?;

        file.write_all(json.as_bytes())
            .map_err(|e| format!("Could not write session file: {e}"))?;
    }

    #[cfg(not(unix))]
    {
        fs::write(&path, json).map_err(|e| format!("Could not write session file: {e}"))?;
    }

    Ok(())
}
/// Client for Zitadel's OAuth device authorization flow, with an optional
/// exchange of the resulting id_token for an OpenBao client token.
pub struct ZitadelOidcAuth {
    /// Base URL of the Zitadel instance; the `/oauth/v2/...` endpoint paths
    /// are appended to it.
    sso_url: String,
    /// OAuth client id sent with the device authorization and token requests.
    client_id: String,
    /// When `true`, the HTTP client accepts invalid TLS certificates.
    skip_tls: bool,
    /// OpenBao URL for JWT exchange. When set, the id_token from Zitadel is
    /// exchanged for a real OpenBao client token via `/v1/auth/{mount}/login`.
    openbao_url: Option<String>,
    /// Auth mount used in the OpenBao login path; `"jwt"` is used when unset.
    jwt_auth_mount: Option<String>,
    /// Role sent with the OpenBao login request; `"harmony-developer"` is used
    /// when unset.
    jwt_role: Option<String>,
}
impl ZitadelOidcAuth {
/// Creates an authenticator from its configuration values.
///
/// * `sso_url` - base URL of the Zitadel instance.
/// * `client_id` - OAuth client id used for the device flow.
/// * `skip_tls` - accept invalid TLS certificates when `true`.
/// * `openbao_url` - when `Some`, the id_token is exchanged for an OpenBao
///   client token at this URL.
/// * `jwt_auth_mount` - OpenBao auth mount; `"jwt"` is used when `None`.
/// * `jwt_role` - OpenBao role; `"harmony-developer"` is used when `None`.
pub fn new(
    sso_url: String,
    client_id: String,
    skip_tls: bool,
    openbao_url: Option<String>,
    jwt_auth_mount: Option<String>,
    jwt_role: Option<String>,
) -> Self {
    // No validation happens here; the values are stored as given.
    Self {
        sso_url,
        client_id,
        skip_tls,
        openbao_url,
        jwt_auth_mount,
        jwt_role,
    }
}
pub async fn authenticate(&self) -> Result<OidcSession, String> {
if let Ok(session) = load_session() {
if !session.is_expired() {
info!("ZITADEL_OIDC: Using cached session");
return Ok(session);
}
}
info!("ZITADEL_OIDC: Starting device authorization flow");
let device_code = self.request_device_code().await?;
self.print_verification_instructions(&device_code);
let token_response = self
.poll_for_token(&device_code, device_code.interval)
.await?;
let session = self.process_token_response(token_response).await?;
let _ = save_session(&session);
Ok(session)
}
/// Builds the HTTP client used for all requests made by this type.
///
/// Invalid TLS certificates are accepted when `skip_tls` is set. When
/// `sso_url` parses and has a host, that host is pinned to `127.0.0.1` on the
/// URL's port (443 for `https`, 80 otherwise when no port is given).
///
/// # Errors
///
/// Returns a message when the underlying client cannot be constructed.
fn http_client(&self) -> Result<reqwest::Client, String> {
    let mut builder = reqwest::Client::builder();
    if self.skip_tls {
        builder = builder.danger_accept_invalid_certs(true);
    }

    // Resolve the SSO hostname to 127.0.0.1 so the device flow works
    // without /etc/hosts entries. This preserves the correct Host header
    // (which Zitadel validates against ExternalDomain) while routing
    // through the local k3d/traefik ingress.
    let sso = reqwest::Url::parse(&self.sso_url).ok();
    let local_override = sso.as_ref().and_then(|url| {
        let default_port = if url.scheme() == "https" { 443 } else { 80 };
        let port = url.port().unwrap_or(default_port);
        url.host_str()
            .map(|host| (host, std::net::SocketAddr::from(([127, 0, 0, 1], port))))
    });
    if let Some((host, addr)) = local_override {
        builder = builder.resolve(host, addr);
    }

    match builder.build() {
        Ok(client) => Ok(client),
        Err(e) => Err(format!("Failed to build HTTP client: {e}")),
    }
}
/// Starts the device flow by requesting a device code from Zitadel.
///
/// Posts the client id and the `openid email profile offline_access` scope to
/// `{sso_url}/oauth/v2/device_authorization`.
///
/// # Errors
///
/// Returns a message when the request cannot be sent, when the server answers
/// with a non-success status (the status and body are included), or when the
/// response body cannot be parsed.
async fn request_device_code(&self) -> Result<DeviceAuthorizationResponse, String> {
    let client = self.http_client()?;
    let params = [
        ("client_id", self.client_id.as_str()),
        ("scope", "openid email profile offline_access"),
    ];

    let response = client
        .post(format!("{}/oauth/v2/device_authorization", self.sso_url))
        .form(&params)
        .send()
        .await
        .map_err(|e| format!("Device authorization request failed: {e}"))?;

    // Surface the server's own error instead of a misleading JSON parse
    // failure when the endpoint rejects the request.
    let status = response.status();
    if !status.is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(format!(
            "Device authorization request failed ({status}): {body}"
        ));
    }

    response
        .json::<DeviceAuthorizationResponse>()
        .await
        .map_err(|e| format!("Failed to parse device authorization response: {e}"))
}
/// Prints the browser instructions for the device flow to standard output.
///
/// Shows the verification URL and user code, plus the pre-filled direct link
/// when the server supplied one, framed by separator lines.
fn print_verification_instructions(&self, code: &DeviceAuthorizationResponse) {
    const RULE: &str = "=================================================";

    println!();
    println!("{RULE}");
    println!("[Harmony] To authenticate with Zitadel, open your browser:");
    println!(" {}", code.verification_uri);
    println!();
    println!(" and enter code: {}", code.user_code);

    if let Some(direct_link) = code.verification_uri_complete.as_deref() {
        println!();
        println!(" Or visit this direct link (code is pre-filled):");
        println!(" {direct_link}");
    }

    println!("{RULE}");
    println!();
}
async fn poll_for_token(
&self,
code: &DeviceAuthorizationResponse,
interval_secs: u64,
) -> Result<TokenResponse, String> {
let client = self.http_client()?;
let params = [
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
("device_code", code.device_code.as_str()),
("client_id", self.client_id.as_str()),
];
let interval = Duration::from_secs(interval_secs.max(5));
let max_attempts = (code.expires_in / interval_secs).max(60) as usize;
for attempt in 0..max_attempts {
tokio::time::sleep(interval).await;
let response = client
.post(format!("{}/oauth/v2/token", self.sso_url))
.form(&params)
.send()
.await
.map_err(|e| format!("Token request failed: {e}"))?;
let status = response.status();
let body = response
.text()
.await
.map_err(|e| format!("Failed to read response body: {e}"))?;
if status == 400 {
if let Ok(error) = serde_json::from_str::<TokenErrorResponse>(&body) {
match error.error.as_str() {
"authorization_pending" => {
debug!("ZITADEL_OIDC: authorization_pending (attempt {})", attempt);
continue;
}
"slow_down" => {
debug!("ZITADEL_OIDC: slow_down, increasing interval");
tokio::time::sleep(Duration::from_secs(5)).await;
continue;
}
"expired_token" => {
return Err(
"Device code expired. Please restart authentication.".to_string()
);
}
"access_denied" => {
return Err("Access denied by user.".to_string());
}
_ => {
return Err(format!(
"OAuth error: {} - {}",
error.error,
error.error_description.unwrap_or_default()
));
}
}
}
}
return serde_json::from_str(&body)
.map_err(|e| format!("Failed to parse token response: {e}"));
}
Err("Token polling timed out".to_string())
}
/// Exchanges a Zitadel id_token for an OpenBao client token.
///
/// Posts `{"role": ..., "jwt": ...}` to `{openbao_url}/v1/auth/{mount}/login`,
/// using mount `"jwt"` and role `"harmony-developer"` when none are
/// configured.
///
/// Returns `(client_token, lease_duration, renewable)`. A missing
/// `lease_duration` defaults to `14400` and a missing `renewable` to `true`.
///
/// # Errors
///
/// Returns a message when no OpenBao URL is configured, when the request
/// fails or is answered with a non-success status, or when the response lacks
/// `auth` or `auth.client_token`.
async fn exchange_jwt_for_openbao_token(
    &self,
    id_token: &str,
) -> Result<(String, u64, bool), String> {
    let base_url = self
        .openbao_url
        .as_deref()
        .ok_or("No OpenBao URL configured for JWT exchange")?;
    let mount = self.jwt_auth_mount.as_deref().unwrap_or("jwt");
    let role = self.jwt_role.as_deref().unwrap_or("harmony-developer");
    let client = self.http_client()?;

    let login_url = format!("{}/v1/auth/{}/login", base_url, mount);
    debug!(
        "ZITADEL_OIDC: Exchanging id_token for OpenBao token via {}",
        login_url
    );

    let payload = serde_json::json!({
        "role": role,
        "jwt": id_token
    });
    let response = client
        .post(&login_url)
        .json(&payload)
        .send()
        .await
        .map_err(|e| format!("JWT exchange request failed: {e}"))?;

    if !response.status().is_success() {
        let body = response.text().await.unwrap_or_default();
        return Err(format!("JWT exchange failed: {body}"));
    }

    let parsed: serde_json::Value = response
        .json()
        .await
        .map_err(|e| format!("Failed to parse JWT exchange response: {e}"))?;

    let auth = match parsed.get("auth") {
        Some(auth) => auth,
        None => return Err("No 'auth' in JWT exchange response".to_string()),
    };
    let client_token = match auth.get("client_token").and_then(|v| v.as_str()) {
        Some(token) => token.to_string(),
        None => return Err("No client_token in JWT exchange response".to_string()),
    };
    let lease_duration = auth
        .get("lease_duration")
        .and_then(serde_json::Value::as_u64)
        .unwrap_or(14400);
    let renewable = auth
        .get("renewable")
        .and_then(serde_json::Value::as_bool)
        .unwrap_or(true);

    info!(
        "ZITADEL_OIDC: JWT exchange successful (ttl={}s)",
        lease_duration
    );
    Ok((client_token, lease_duration, renewable))
}
/// Turns a token endpoint response into an [`OidcSession`].
///
/// When an OpenBao URL is configured, the id_token is exchanged for an OpenBao
/// client token and that token's TTL and renewable flag are used. Otherwise
/// the OIDC access token is used directly, with `expires_in` (3600 when
/// absent) as TTL and `renewable` set to `true`.
///
/// `expires_at` is the current Unix time plus the TTL, saturating at
/// `i64::MAX` so an oversized TTL from the server cannot wrap or overflow.
///
/// # Errors
///
/// Returns a message when a JWT exchange is required but the response has no
/// id_token, or when the exchange itself fails.
async fn process_token_response(&self, response: TokenResponse) -> Result<OidcSession, String> {
    let (openbao_token, ttl, renewable) = if self.openbao_url.is_some() {
        let id_token = response
            .id_token
            .as_deref()
            .ok_or("No id_token in OIDC response (required for JWT exchange)")?;
        self.exchange_jwt_for_openbao_token(id_token).await?
    } else {
        // `response` is owned, so the access token can be moved out directly.
        (
            response.access_token,
            response.expires_in.unwrap_or(3600),
            true,
        )
    };

    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .unwrap_or(0);
    // The TTL is server-supplied: convert and add without wrapping.
    let expires_at = now.saturating_add(i64::try_from(ttl).unwrap_or(i64::MAX));

    Ok(OidcSession {
        openbao_token,
        openbao_token_ttl: ttl,
        openbao_renewable: renewable,
        refresh_token: response.refresh_token,
        id_token: response.id_token,
        expires_at: Some(expires_at),
    })
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session that differs from the other fixtures only in its
    /// expiry timestamp.
    fn session_expiring_at(expires_at: i64) -> OidcSession {
        OidcSession {
            openbao_token: "test".to_string(),
            openbao_token_ttl: 3600,
            openbao_renewable: true,
            refresh_token: None,
            id_token: None,
            expires_at: Some(expires_at),
        }
    }

    /// A deadline at the epoch is expired; one an hour from now is not.
    #[test]
    fn test_oidc_session_is_expired() {
        assert!(session_expiring_at(0).is_expired());

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs() as i64)
            .unwrap_or(0);
        assert!(!session_expiring_at(now + 3600).is_expired());
    }
}