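//! Zitadel OIDC authentication for Harmony.
//!
//! Runs the OAuth 2.0 device authorization flow against a Zitadel instance,
//! optionally exchanges the resulting ID token for an OpenBao token through
//! OpenBao's JWT auth method, and caches the session on disk between runs.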
use log::{debug, info};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::fs;
|
|
use std::path::PathBuf;
|
|
use std::time::Duration;
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct OidcSession {
|
|
pub openbao_token: String,
|
|
pub openbao_token_ttl: u64,
|
|
pub openbao_renewable: bool,
|
|
pub refresh_token: Option<String>,
|
|
pub id_token: Option<String>,
|
|
pub expires_at: Option<i64>,
|
|
}
|
|
|
|
impl OidcSession {
|
|
pub fn is_expired(&self) -> bool {
|
|
if let Some(expires_at) = self.expires_at {
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.map(|d| d.as_secs() as i64)
|
|
.unwrap_or(0);
|
|
expires_at <= now
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
pub fn is_openbao_token_expired(&self, _ttl: u64) -> bool {
|
|
if let Some(expires_at) = self.expires_at {
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.map(|d| d.as_secs() as i64)
|
|
.unwrap_or(0);
|
|
expires_at <= now
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct DeviceAuthorizationResponse {
|
|
device_code: String,
|
|
user_code: String,
|
|
verification_uri: String,
|
|
#[serde(rename = "verification_uri_complete")]
|
|
verification_uri_complete: Option<String>,
|
|
expires_in: u64,
|
|
interval: u64,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct TokenResponse {
|
|
access_token: String,
|
|
#[serde(rename = "expires_in", default)]
|
|
expires_in: Option<u64>,
|
|
#[serde(rename = "refresh_token", default)]
|
|
refresh_token: Option<String>,
|
|
#[serde(rename = "id_token", default)]
|
|
id_token: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
struct TokenErrorResponse {
|
|
error: String,
|
|
#[serde(rename = "error_description")]
|
|
error_description: Option<String>,
|
|
}
|
|
|
|
fn get_session_cache_path() -> PathBuf {
|
|
let hash = {
|
|
use std::collections::hash_map::DefaultHasher;
|
|
use std::hash::{Hash, Hasher};
|
|
let mut hasher = DefaultHasher::new();
|
|
"zitadel-oidc".hash(&mut hasher);
|
|
format!("{:016x}", hasher.finish())
|
|
};
|
|
directories::BaseDirs::new()
|
|
.map(|dirs| {
|
|
dirs.data_dir()
|
|
.join("harmony")
|
|
.join("secrets")
|
|
.join(format!("oidc_session_{hash}"))
|
|
})
|
|
.unwrap_or_else(|| PathBuf::from(format!("/tmp/oidc_session_{hash}")))
|
|
}
|
|
|
|
fn load_session() -> Result<OidcSession, String> {
|
|
let path = get_session_cache_path();
|
|
serde_json::from_str(
|
|
&fs::read_to_string(&path)
|
|
.map_err(|e| format!("Could not load session from {path:?}: {e}"))?,
|
|
)
|
|
.map_err(|e| format!("Could not deserialize session from {path:?}: {e}"))
|
|
}
|
|
|
|
fn save_session(session: &OidcSession) -> Result<(), String> {
|
|
let path = get_session_cache_path();
|
|
if let Some(parent) = path.parent() {
|
|
fs::create_dir_all(parent)
|
|
.map_err(|e| format!("Could not create session directory: {e}"))?;
|
|
}
|
|
#[cfg(unix)]
|
|
{
|
|
use std::os::unix::fs::OpenOptionsExt;
|
|
let mut file = fs::OpenOptions::new()
|
|
.write(true)
|
|
.create(true)
|
|
.truncate(true)
|
|
.mode(0o600)
|
|
.open(&path)
|
|
.map_err(|e| format!("Could not open session file: {e}"))?;
|
|
use std::io::Write;
|
|
file.write_all(
|
|
serde_json::to_string_pretty(session)
|
|
.map_err(|e| format!("Could not serialize session: {e}"))?
|
|
.as_bytes(),
|
|
)
|
|
.map_err(|e| format!("Could not write session file: {e}"))?;
|
|
}
|
|
#[cfg(not(unix))]
|
|
{
|
|
fs::write(
|
|
&path,
|
|
serde_json::to_string_pretty(session).map_err(|e| e.to_string())?,
|
|
)
|
|
.map_err(|e| format!("Could not write session file: {e}"))?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Client for Zitadel's OAuth 2.0 device authorization flow, with an optional
/// exchange of the resulting ID token for an OpenBao token.
pub struct ZitadelOidcAuth {
    /// Base URL of the Zitadel instance; `/oauth/v2/...` paths are appended to it.
    sso_url: String,
    /// OAuth client ID sent with the device and token requests.
    client_id: String,
    /// Accept invalid TLS certificates on every HTTP request this client makes.
    skip_tls: bool,
    /// OpenBao URL for the JWT exchange. When set, the id_token from Zitadel is
    /// exchanged for a real OpenBao client token via `/v1/auth/{mount}/login`.
    openbao_url: Option<String>,
    /// OpenBao JWT auth mount; `jwt` is used when unset.
    jwt_auth_mount: Option<String>,
    /// OpenBao JWT role to log in with; `harmony-developer` is used when unset.
    jwt_role: Option<String>,
}
|
|
|
|
impl ZitadelOidcAuth {
|
|
pub fn new(
|
|
sso_url: String,
|
|
client_id: String,
|
|
skip_tls: bool,
|
|
openbao_url: Option<String>,
|
|
jwt_auth_mount: Option<String>,
|
|
jwt_role: Option<String>,
|
|
) -> Self {
|
|
Self {
|
|
sso_url,
|
|
client_id,
|
|
skip_tls,
|
|
openbao_url,
|
|
jwt_auth_mount,
|
|
jwt_role,
|
|
}
|
|
}
|
|
|
|
pub async fn authenticate(&self) -> Result<OidcSession, String> {
|
|
if let Ok(session) = load_session() {
|
|
if !session.is_expired() {
|
|
info!("ZITADEL_OIDC: Using cached session");
|
|
return Ok(session);
|
|
}
|
|
}
|
|
|
|
info!("ZITADEL_OIDC: Starting device authorization flow");
|
|
|
|
let device_code = self.request_device_code().await?;
|
|
self.print_verification_instructions(&device_code);
|
|
let token_response = self
|
|
.poll_for_token(&device_code, device_code.interval)
|
|
.await?;
|
|
let session = self.process_token_response(token_response).await?;
|
|
let _ = save_session(&session);
|
|
Ok(session)
|
|
}
|
|
|
|
fn http_client(&self) -> Result<reqwest::Client, String> {
|
|
let mut builder = reqwest::Client::builder();
|
|
if self.skip_tls {
|
|
builder = builder.danger_accept_invalid_certs(true);
|
|
}
|
|
|
|
// Resolve the SSO hostname to 127.0.0.1 so the device flow works
|
|
// without /etc/hosts entries. This preserves the correct Host header
|
|
// (which Zitadel validates against ExternalDomain) while routing
|
|
// through the local k3d/traefik ingress.
|
|
if let Ok(url) = reqwest::Url::parse(&self.sso_url) {
|
|
if let Some(host) = url.host_str() {
|
|
let port = url
|
|
.port()
|
|
.unwrap_or(if url.scheme() == "https" { 443 } else { 80 });
|
|
let addr = std::net::SocketAddr::from(([127, 0, 0, 1], port));
|
|
builder = builder.resolve(host, addr);
|
|
}
|
|
}
|
|
|
|
builder
|
|
.build()
|
|
.map_err(|e| format!("Failed to build HTTP client: {e}"))
|
|
}
|
|
|
|
async fn request_device_code(&self) -> Result<DeviceAuthorizationResponse, String> {
|
|
let client = self.http_client()?;
|
|
let params = [
|
|
("client_id", self.client_id.as_str()),
|
|
("scope", "openid email profile offline_access"),
|
|
];
|
|
|
|
let response = client
|
|
.post(format!("{}/oauth/v2/device_authorization", self.sso_url))
|
|
.form(¶ms)
|
|
.send()
|
|
.await
|
|
.map_err(|e| format!("Device authorization request failed: {e}"))?;
|
|
|
|
response
|
|
.json::<DeviceAuthorizationResponse>()
|
|
.await
|
|
.map_err(|e| format!("Failed to parse device authorization response: {e}"))
|
|
}
|
|
|
|
fn print_verification_instructions(&self, code: &DeviceAuthorizationResponse) {
|
|
println!();
|
|
println!("=================================================");
|
|
println!("[Harmony] To authenticate with Zitadel, open your browser:");
|
|
println!(" {}", code.verification_uri);
|
|
println!();
|
|
println!(" and enter code: {}", code.user_code);
|
|
if let Some(ref complete_url) = code.verification_uri_complete {
|
|
println!();
|
|
println!(" Or visit this direct link (code is pre-filled):");
|
|
println!(" {}", complete_url);
|
|
}
|
|
println!("=================================================");
|
|
println!();
|
|
}
|
|
|
|
async fn poll_for_token(
|
|
&self,
|
|
code: &DeviceAuthorizationResponse,
|
|
interval_secs: u64,
|
|
) -> Result<TokenResponse, String> {
|
|
let client = self.http_client()?;
|
|
let params = [
|
|
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
|
|
("device_code", code.device_code.as_str()),
|
|
("client_id", self.client_id.as_str()),
|
|
];
|
|
|
|
let interval = Duration::from_secs(interval_secs.max(5));
|
|
let max_attempts = (code.expires_in / interval_secs).max(60) as usize;
|
|
|
|
for attempt in 0..max_attempts {
|
|
tokio::time::sleep(interval).await;
|
|
|
|
let response = client
|
|
.post(format!("{}/oauth/v2/token", self.sso_url))
|
|
.form(¶ms)
|
|
.send()
|
|
.await
|
|
.map_err(|e| format!("Token request failed: {e}"))?;
|
|
|
|
let status = response.status();
|
|
let body = response
|
|
.text()
|
|
.await
|
|
.map_err(|e| format!("Failed to read response body: {e}"))?;
|
|
|
|
if status == 400 {
|
|
if let Ok(error) = serde_json::from_str::<TokenErrorResponse>(&body) {
|
|
match error.error.as_str() {
|
|
"authorization_pending" => {
|
|
debug!("ZITADEL_OIDC: authorization_pending (attempt {})", attempt);
|
|
continue;
|
|
}
|
|
"slow_down" => {
|
|
debug!("ZITADEL_OIDC: slow_down, increasing interval");
|
|
tokio::time::sleep(Duration::from_secs(5)).await;
|
|
continue;
|
|
}
|
|
"expired_token" => {
|
|
return Err(
|
|
"Device code expired. Please restart authentication.".to_string()
|
|
);
|
|
}
|
|
"access_denied" => {
|
|
return Err("Access denied by user.".to_string());
|
|
}
|
|
_ => {
|
|
return Err(format!(
|
|
"OAuth error: {} - {}",
|
|
error.error,
|
|
error.error_description.unwrap_or_default()
|
|
));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
return serde_json::from_str(&body)
|
|
.map_err(|e| format!("Failed to parse token response: {e}"));
|
|
}
|
|
|
|
Err("Token polling timed out".to_string())
|
|
}
|
|
|
|
async fn exchange_jwt_for_openbao_token(
|
|
&self,
|
|
id_token: &str,
|
|
) -> Result<(String, u64, bool), String> {
|
|
let openbao_url = self
|
|
.openbao_url
|
|
.as_ref()
|
|
.ok_or("No OpenBao URL configured for JWT exchange")?;
|
|
let mount = self.jwt_auth_mount.as_deref().unwrap_or("jwt");
|
|
let role = self.jwt_role.as_deref().unwrap_or("harmony-developer");
|
|
|
|
let client = self.http_client()?;
|
|
let url = format!("{}/v1/auth/{}/login", openbao_url, mount);
|
|
|
|
debug!(
|
|
"ZITADEL_OIDC: Exchanging id_token for OpenBao token via {}",
|
|
url
|
|
);
|
|
|
|
let response = client
|
|
.post(&url)
|
|
.json(&serde_json::json!({
|
|
"role": role,
|
|
"jwt": id_token
|
|
}))
|
|
.send()
|
|
.await
|
|
.map_err(|e| format!("JWT exchange request failed: {e}"))?;
|
|
|
|
if !response.status().is_success() {
|
|
let body = response.text().await.unwrap_or_default();
|
|
return Err(format!("JWT exchange failed: {body}"));
|
|
}
|
|
|
|
let body: serde_json::Value = response
|
|
.json()
|
|
.await
|
|
.map_err(|e| format!("Failed to parse JWT exchange response: {e}"))?;
|
|
|
|
let auth = body
|
|
.get("auth")
|
|
.ok_or("No 'auth' in JWT exchange response")?;
|
|
let client_token = auth
|
|
.get("client_token")
|
|
.and_then(|v| v.as_str())
|
|
.ok_or("No client_token in JWT exchange response")?
|
|
.to_string();
|
|
let lease_duration = auth
|
|
.get("lease_duration")
|
|
.and_then(|v| v.as_u64())
|
|
.unwrap_or(14400);
|
|
let renewable = auth
|
|
.get("renewable")
|
|
.and_then(|v| v.as_bool())
|
|
.unwrap_or(true);
|
|
|
|
info!(
|
|
"ZITADEL_OIDC: JWT exchange successful (ttl={}s)",
|
|
lease_duration
|
|
);
|
|
Ok((client_token, lease_duration, renewable))
|
|
}
|
|
|
|
async fn process_token_response(&self, response: TokenResponse) -> Result<OidcSession, String> {
|
|
let (openbao_token, ttl, renewable) = if self.openbao_url.is_some() {
|
|
let id_token = response
|
|
.id_token
|
|
.as_deref()
|
|
.ok_or("No id_token in OIDC response (required for JWT exchange)")?;
|
|
self.exchange_jwt_for_openbao_token(id_token).await?
|
|
} else {
|
|
(
|
|
response.access_token.clone(),
|
|
response.expires_in.unwrap_or(3600),
|
|
true,
|
|
)
|
|
};
|
|
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.map(|d| d.as_secs() as i64)
|
|
.unwrap_or(0);
|
|
|
|
Ok(OidcSession {
|
|
openbao_token,
|
|
openbao_token_ttl: ttl,
|
|
openbao_renewable: renewable,
|
|
refresh_token: response.refresh_token,
|
|
id_token: response.id_token,
|
|
expires_at: Some(now + ttl as i64),
|
|
})
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session that varies only in `expires_at`.
    fn session_expiring_at(expires_at: Option<i64>) -> OidcSession {
        OidcSession {
            openbao_token: "test".to_string(),
            openbao_token_ttl: 3600,
            openbao_renewable: true,
            refresh_token: None,
            id_token: None,
            expires_at,
        }
    }

    #[test]
    fn test_oidc_session_is_expired() {
        // The Unix epoch is always in the past.
        assert!(session_expiring_at(Some(0)).is_expired());

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs() as i64)
            .unwrap_or(0);
        // One hour from now is still valid.
        assert!(!session_expiring_at(Some(now + 3600)).is_expired());
    }
}
|