Compare commits
11 Commits
feat/broca
...
feat/harmo
| Author | SHA1 | Date | |
|---|---|---|---|
| ccc26e07eb | |||
| 9a67bcc96f | |||
| a377fc1404 | |||
| c9977fee12 | |||
| 64bf585e07 | |||
| 44e2c45435 | |||
| cdccbc8939 | |||
| 9830971d05 | |||
| e1183ef6de | |||
| 8499f4d1b7 | |||
| 67c3265286 |
@@ -15,4 +15,4 @@ jobs:
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run check script
|
||||
run: bash check.sh
|
||||
run: bash build/check.sh
|
||||
|
||||
868
Cargo.lock
generated
868
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -23,6 +23,7 @@ members = [
|
||||
"harmony_agent/deploy",
|
||||
"harmony_node_readiness",
|
||||
"harmony-k8s",
|
||||
"harmony_assets",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
@@ -37,6 +38,7 @@ derive-new = "0.7"
|
||||
async-trait = "0.1"
|
||||
tokio = { version = "1.40", features = [
|
||||
"io-std",
|
||||
"io-util",
|
||||
"fs",
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
@@ -73,6 +75,7 @@ base64 = "0.22.1"
|
||||
tar = "0.4.44"
|
||||
lazy_static = "1.5.0"
|
||||
directories = "6.0.0"
|
||||
futures-util = "0.3"
|
||||
thiserror = "2.0.14"
|
||||
serde = { version = "1.0.209", features = ["derive", "rc"] }
|
||||
serde_json = "1.0.127"
|
||||
@@ -86,3 +89,4 @@ reqwest = { version = "0.12", features = [
|
||||
"json",
|
||||
], default-features = false }
|
||||
assertor = "0.0.4"
|
||||
tokio-test = "0.4"
|
||||
|
||||
@@ -2,13 +2,14 @@ use std::collections::HashMap;
|
||||
|
||||
use k8s_openapi::api::{
|
||||
apps::v1::Deployment,
|
||||
core::v1::{Node, ServiceAccount},
|
||||
core::v1::{Namespace, Node, ServiceAccount},
|
||||
};
|
||||
use k8s_openapi::apiextensions_apiserver::pkg::apis::apiextensions::v1::CustomResourceDefinition;
|
||||
use kube::api::ApiResource;
|
||||
use kube::{
|
||||
Error, Resource,
|
||||
api::{Api, DynamicObject, GroupVersionKind, ListParams, ObjectList},
|
||||
core::ErrorResponse,
|
||||
runtime::conditions,
|
||||
runtime::wait::await_condition,
|
||||
};
|
||||
@@ -313,4 +314,65 @@ impl K8sClient {
|
||||
) -> Result<ObjectList<Node>, Error> {
|
||||
self.list_resources(None, list_params).await
|
||||
}
|
||||
|
||||
pub async fn namespace_exists(&self, name: &str) -> Result<bool, Error> {
|
||||
let api: Api<Namespace> = Api::all(self.client.clone());
|
||||
match api.get_opt(name).await? {
|
||||
Some(_) => Ok(true),
|
||||
None => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn create_namespace(&self, name: &str) -> Result<Namespace, Error> {
|
||||
let namespace = Namespace {
|
||||
metadata: k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta {
|
||||
name: Some(name.to_string()),
|
||||
..Default::default()
|
||||
},
|
||||
..Default::default()
|
||||
};
|
||||
let api: Api<Namespace> = Api::all(self.client.clone());
|
||||
api.create(&kube::api::PostParams::default(), &namespace)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn wait_for_namespace(
|
||||
&self,
|
||||
name: &str,
|
||||
timeout: Option<Duration>,
|
||||
) -> Result<(), Error> {
|
||||
let api: Api<Namespace> = Api::all(self.client.clone());
|
||||
let timeout = timeout.unwrap_or(Duration::from_secs(60));
|
||||
let start = std::time::Instant::now();
|
||||
|
||||
loop {
|
||||
if start.elapsed() > timeout {
|
||||
return Err(Error::Api(ErrorResponse {
|
||||
status: "Timeout".to_string(),
|
||||
message: format!("Namespace '{}' not ready within timeout", name),
|
||||
reason: "Timeout".to_string(),
|
||||
code: 408,
|
||||
}));
|
||||
}
|
||||
|
||||
match api.get_opt(name).await? {
|
||||
Some(ns) => {
|
||||
if let Some(status) = ns.status {
|
||||
if status.phase == Some("Active".to_string()) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err(Error::Api(ErrorResponse {
|
||||
status: "NotFound".to_string(),
|
||||
message: format!("Namespace '{}' not found", name),
|
||||
reason: "NotFound".to_string(),
|
||||
code: 404,
|
||||
}));
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,10 +267,16 @@ pub(crate) fn harmony_load_balancer_service_to_haproxy_xml(
|
||||
SSL::Default => "".into(),
|
||||
SSL::Other(other) => other.as_str().into(),
|
||||
};
|
||||
let path_without_query = path.split_once('?').map_or(path.as_str(), |(p, _)| p);
|
||||
let (port, port_name) = match port {
|
||||
Some(port) => (Some(port.to_string()), port.to_string()),
|
||||
None => (None, "serverport".to_string()),
|
||||
};
|
||||
|
||||
let haproxy_check = HAProxyHealthCheck {
|
||||
name: format!("HTTP_{http_method}_{path}"),
|
||||
name: format!("HTTP_{http_method}_{path_without_query}_{port_name}"),
|
||||
uuid: Uuid::new_v4().to_string(),
|
||||
http_method: http_method.to_string().into(),
|
||||
http_method: http_method.to_string().to_lowercase().into(),
|
||||
health_check_type: "http".to_string(),
|
||||
http_uri: path.clone().into(),
|
||||
interval: "2s".to_string(),
|
||||
@@ -314,7 +320,10 @@ pub(crate) fn harmony_load_balancer_service_to_haproxy_xml(
|
||||
let mut backend = HAProxyBackend {
|
||||
uuid: Uuid::new_v4().to_string(),
|
||||
enabled: 1,
|
||||
name: format!("backend_{}", service.listening_port),
|
||||
name: format!(
|
||||
"backend_{}",
|
||||
service.listening_port.to_string().replace(':', "_")
|
||||
),
|
||||
algorithm: "roundrobin".to_string(),
|
||||
random_draws: Some(2),
|
||||
stickiness_expire: "30m".to_string(),
|
||||
@@ -346,10 +355,22 @@ pub(crate) fn harmony_load_balancer_service_to_haproxy_xml(
|
||||
let frontend = Frontend {
|
||||
uuid: uuid::Uuid::new_v4().to_string(),
|
||||
enabled: 1,
|
||||
name: format!("frontend_{}", service.listening_port),
|
||||
name: format!(
|
||||
"frontend_{}",
|
||||
service.listening_port.to_string().replace(':', "_")
|
||||
),
|
||||
bind: service.listening_port.to_string(),
|
||||
mode: "tcp".to_string(), // TODO do not depend on health check here
|
||||
default_backend: Some(backend.uuid.clone()),
|
||||
stickiness_expire: "30m".to_string().into(),
|
||||
stickiness_size: "50k".to_string().into(),
|
||||
stickiness_conn_rate_period: "10s".to_string().into(),
|
||||
stickiness_sess_rate_period: "10s".to_string().into(),
|
||||
stickiness_http_req_rate_period: "10s".to_string().into(),
|
||||
stickiness_http_err_rate_period: "10s".to_string().into(),
|
||||
stickiness_bytes_in_rate_period: "1m".to_string().into(),
|
||||
stickiness_bytes_out_rate_period: "1m".to_string().into(),
|
||||
ssl_hsts_max_age: 15768000,
|
||||
..Default::default()
|
||||
};
|
||||
info!("HAPRoxy frontend and backend mode currently hardcoded to tcp");
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||
use serde::Serialize;
|
||||
use std::str::FromStr;
|
||||
|
||||
use non_blank_string_rs::NonBlankString;
|
||||
|
||||
use crate::interpret::Interpret;
|
||||
use crate::modules::helm::chart::HelmChartScore;
|
||||
use crate::modules::k8s::apps::crd::{Subscription, SubscriptionSpec};
|
||||
use crate::modules::k8s::resource::K8sResourceScore;
|
||||
use crate::score::Score;
|
||||
use crate::topology::{K8sclient, Topology};
|
||||
use crate::topology::{HelmCommand, K8sclient, Topology};
|
||||
|
||||
/// Install the CloudNativePg (CNPG) Operator via an OperatorHub `Subscription`.
|
||||
///
|
||||
@@ -100,3 +104,41 @@ impl<T: Topology + K8sclient> Score<T> for CloudNativePgOperatorScore {
|
||||
format!("CloudNativePgOperatorScore({})", self.namespace)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct CloudNativePgOperatorHelmScore {
|
||||
pub namespace: String,
|
||||
}
|
||||
|
||||
impl Default for CloudNativePgOperatorHelmScore {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
namespace: "cnpg-system".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Topology + K8sclient + HelmCommand + 'static> Score<T> for CloudNativePgOperatorHelmScore {
|
||||
fn name(&self) -> String {
|
||||
format!("CloudNativePgOperatorHelmScore({})", self.namespace)
|
||||
}
|
||||
|
||||
fn create_interpret(&self) -> Box<dyn Interpret<T>> {
|
||||
let cnpg_helm_score = HelmChartScore {
|
||||
namespace: Some(NonBlankString::from_str(&self.namespace).unwrap()),
|
||||
release_name: NonBlankString::from_str("cloudnative-pg").unwrap(),
|
||||
chart_name: NonBlankString::from_str(
|
||||
"oci://ghcr.io/cloudnative-pg/charts/cloudnative-pg",
|
||||
)
|
||||
.unwrap(),
|
||||
chart_version: None,
|
||||
values_overrides: None,
|
||||
values_yaml: None,
|
||||
create_namespace: true,
|
||||
install_only: true,
|
||||
repository: None,
|
||||
};
|
||||
|
||||
cnpg_helm_score.create_interpret()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,24 +1,35 @@
|
||||
use crate::data::Version;
|
||||
use crate::interpret::{Interpret, InterpretError, InterpretName, InterpretStatus, Outcome};
|
||||
use crate::inventory::Inventory;
|
||||
|
||||
use crate::modules::k8s::resource::K8sResourceScore;
|
||||
use crate::modules::postgresql::capability::PostgreSQLConfig;
|
||||
use crate::modules::postgresql::cnpg::{
|
||||
Bootstrap, Cluster, ClusterSpec, ExternalCluster, Initdb, PgBaseBackup, ReplicaSpec,
|
||||
SecretKeySelector, Storage,
|
||||
};
|
||||
use crate::modules::postgresql::operator::{
|
||||
CloudNativePgOperatorHelmScore, CloudNativePgOperatorScore,
|
||||
};
|
||||
use crate::score::Score;
|
||||
use crate::topology::{K8sclient, Topology};
|
||||
use crate::topology::{HelmCommand, K8sclient, Topology};
|
||||
use async_trait::async_trait;
|
||||
use harmony_k8s::KubernetesDistribution;
|
||||
use harmony_types::id::Id;
|
||||
use k8s_openapi::ByteString;
|
||||
use k8s_openapi::api::core::v1::Secret;
|
||||
use k8s_openapi::api::core::v1::{Pod, Secret};
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||
use log::{info, warn};
|
||||
use serde::Serialize;
|
||||
|
||||
/// Deploys an opinionated, highly available PostgreSQL cluster managed by CNPG.
|
||||
///
|
||||
/// This score automatically ensures the CloudNativePG (CNPG) operator is installed
|
||||
/// before creating the Cluster CRD. The installation method depends on the Kubernetes
|
||||
/// distribution:
|
||||
///
|
||||
/// - **OpenShift/OKD**: Uses OperatorHub Subscription via `CloudNativePgOperatorScore`
|
||||
/// - **K3s/Other**: Uses Helm chart via `CloudNativePgOperatorHelmScore`
|
||||
///
|
||||
/// # Usage
|
||||
/// ```
|
||||
/// use harmony::modules::postgresql::PostgreSQLScore;
|
||||
@@ -26,12 +37,7 @@ use serde::Serialize;
|
||||
/// ```
|
||||
///
|
||||
/// # Limitations (Happy Path)
|
||||
/// - Requires CNPG operator installed (use CloudNativePgOperatorScore).
|
||||
/// - No backups, monitoring, extensions configured.
|
||||
///
|
||||
/// TODO : refactor this to declare a clean dependency on cnpg operator. Then cnpg operator will
|
||||
/// self-deploy either using operatorhub or helm chart depending on k8s flavor. This is cnpg
|
||||
/// specific behavior
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct K8sPostgreSQLScore {
|
||||
pub config: PostgreSQLConfig,
|
||||
@@ -56,7 +62,7 @@ impl K8sPostgreSQLScore {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Topology + K8sclient> Score<T> for K8sPostgreSQLScore {
|
||||
impl<T: Topology + K8sclient + HelmCommand + 'static> Score<T> for K8sPostgreSQLScore {
|
||||
fn create_interpret(&self) -> Box<dyn Interpret<T>> {
|
||||
Box::new(K8sPostgreSQLInterpret {
|
||||
config: self.config.clone(),
|
||||
@@ -73,13 +79,127 @@ pub struct K8sPostgreSQLInterpret {
|
||||
config: PostgreSQLConfig,
|
||||
}
|
||||
|
||||
impl K8sPostgreSQLInterpret {
|
||||
async fn ensure_namespace<T: Topology + K8sclient>(
|
||||
&self,
|
||||
topology: &T,
|
||||
) -> Result<(), InterpretError> {
|
||||
let k8s_client = topology
|
||||
.k8s_client()
|
||||
.await
|
||||
.map_err(|e| InterpretError::new(format!("Failed to get k8s client: {}", e)))?;
|
||||
|
||||
let namespace_name = &self.config.namespace;
|
||||
|
||||
if k8s_client
|
||||
.namespace_exists(namespace_name)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InterpretError::new(format!(
|
||||
"Failed to check namespace '{}': {}",
|
||||
namespace_name, e
|
||||
))
|
||||
})?
|
||||
{
|
||||
info!("Namespace '{}' already exists", namespace_name);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Creating namespace '{}'", namespace_name);
|
||||
k8s_client
|
||||
.create_namespace(namespace_name)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InterpretError::new(format!(
|
||||
"Failed to create namespace '{}': {}",
|
||||
namespace_name, e
|
||||
))
|
||||
})?;
|
||||
|
||||
k8s_client
|
||||
.wait_for_namespace(namespace_name, Some(std::time::Duration::from_secs(30)))
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InterpretError::new(format!("Namespace '{}' not ready: {}", namespace_name, e))
|
||||
})?;
|
||||
|
||||
info!("Namespace '{}' is ready", namespace_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ensure_cnpg_operator<T: Topology + K8sclient + HelmCommand + 'static>(
|
||||
&self,
|
||||
topology: &T,
|
||||
) -> Result<(), InterpretError> {
|
||||
let k8s_client = topology
|
||||
.k8s_client()
|
||||
.await
|
||||
.map_err(|e| InterpretError::new(format!("Failed to get k8s client: {}", e)))?;
|
||||
|
||||
let pods = k8s_client
|
||||
.list_all_resources_with_labels::<Pod>("app.kubernetes.io/name=cloudnative-pg")
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InterpretError::new(format!("Failed to list CNPG operator pods: {}", e))
|
||||
})?;
|
||||
|
||||
if !pods.is_empty() {
|
||||
info!("CNPG operator is already installed");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
warn!("CNPG operator not found, installing...");
|
||||
let distro = k8s_client.get_k8s_distribution().await.map_err(|e| {
|
||||
InterpretError::new(format!("Failed to detect k8s distribution: {}", e))
|
||||
})?;
|
||||
|
||||
match distro {
|
||||
KubernetesDistribution::OpenshiftFamily => {
|
||||
info!("Installing CNPG operator via OperatorHub Subscription");
|
||||
let score = CloudNativePgOperatorScore::default_openshift();
|
||||
score
|
||||
.interpret(&Inventory::empty(), topology)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InterpretError::new(format!("Failed to install CNPG operator: {}", e))
|
||||
})?;
|
||||
}
|
||||
KubernetesDistribution::K3sFamily | KubernetesDistribution::Default => {
|
||||
info!("Installing CNPG operator via Helm chart");
|
||||
let score = CloudNativePgOperatorHelmScore::default();
|
||||
score
|
||||
.interpret(&Inventory::empty(), topology)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
InterpretError::new(format!("Failed to install CNPG operator: {}", e))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
k8s_client
|
||||
.wait_until_deployment_ready(
|
||||
"cloudnative-pg",
|
||||
Some("cnpg-system"),
|
||||
Some(std::time::Duration::from_secs(120)),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| InterpretError::new(format!("CNPG operator not ready: {}", e)))?;
|
||||
|
||||
info!("CNPG operator is ready");
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl<T: Topology + K8sclient> Interpret<T> for K8sPostgreSQLInterpret {
|
||||
impl<T: Topology + K8sclient + HelmCommand + 'static> Interpret<T> for K8sPostgreSQLInterpret {
|
||||
async fn execute(
|
||||
&self,
|
||||
inventory: &Inventory,
|
||||
topology: &T,
|
||||
) -> Result<Outcome, InterpretError> {
|
||||
self.ensure_cnpg_operator(topology).await?;
|
||||
self.ensure_namespace(topology).await?;
|
||||
|
||||
match &self.config.role {
|
||||
super::capability::PostgreSQLClusterRole::Primary => {
|
||||
let metadata = ObjectMeta {
|
||||
|
||||
56
harmony_assets/Cargo.toml
Normal file
56
harmony_assets/Cargo.toml
Normal file
@@ -0,0 +1,56 @@
|
||||
[package]
|
||||
name = "harmony_assets"
|
||||
edition = "2024"
|
||||
version.workspace = true
|
||||
readme.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[lib]
|
||||
name = "harmony_assets"
|
||||
|
||||
[[bin]]
|
||||
name = "harmony_assets"
|
||||
path = "src/cli/mod.rs"
|
||||
required-features = ["cli"]
|
||||
|
||||
[features]
|
||||
default = ["blake3"]
|
||||
sha256 = ["dep:sha2"]
|
||||
blake3 = ["dep:blake3"]
|
||||
s3 = [
|
||||
"dep:aws-sdk-s3",
|
||||
"dep:aws-config",
|
||||
]
|
||||
cli = [
|
||||
"dep:clap",
|
||||
"dep:indicatif",
|
||||
"dep:inquire",
|
||||
]
|
||||
reqwest = ["dep:reqwest"]
|
||||
|
||||
[dependencies]
|
||||
log.workspace = true
|
||||
tokio.workspace = true
|
||||
thiserror.workspace = true
|
||||
directories.workspace = true
|
||||
sha2 = { version = "0.10", optional = true }
|
||||
blake3 = { version = "1.5", optional = true }
|
||||
reqwest = { version = "0.12", optional = true, default-features = false, features = ["stream", "rustls-tls"] }
|
||||
futures-util.workspace = true
|
||||
async-trait.workspace = true
|
||||
url.workspace = true
|
||||
|
||||
# CLI only
|
||||
clap = { version = "4.5", features = ["derive"], optional = true }
|
||||
indicatif = { version = "0.18", optional = true }
|
||||
inquire = { version = "0.7", optional = true }
|
||||
|
||||
# S3 only
|
||||
aws-sdk-s3 = { version = "1", optional = true }
|
||||
aws-config = { version = "1", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile.workspace = true
|
||||
httptest = "0.16"
|
||||
pretty_assertions.workspace = true
|
||||
tokio-test.workspace = true
|
||||
80
harmony_assets/src/asset.rs
Normal file
80
harmony_assets/src/asset.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use crate::hash::ChecksumAlgo;
|
||||
use std::path::PathBuf;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Asset {
|
||||
pub url: Url,
|
||||
pub checksum: String,
|
||||
pub checksum_algo: ChecksumAlgo,
|
||||
pub file_name: String,
|
||||
pub size: Option<u64>,
|
||||
}
|
||||
|
||||
impl Asset {
|
||||
pub fn new(url: Url, checksum: String, checksum_algo: ChecksumAlgo, file_name: String) -> Self {
|
||||
Self {
|
||||
url,
|
||||
checksum,
|
||||
checksum_algo,
|
||||
file_name,
|
||||
size: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_size(mut self, size: u64) -> Self {
|
||||
self.size = Some(size);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn formatted_checksum(&self) -> String {
|
||||
crate::hash::format_checksum(&self.checksum, self.checksum_algo.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LocalCache {
|
||||
pub base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl LocalCache {
|
||||
pub fn new(base_dir: PathBuf) -> Self {
|
||||
Self { base_dir }
|
||||
}
|
||||
|
||||
pub fn path_for(&self, asset: &Asset) -> PathBuf {
|
||||
let prefix = &asset.checksum[..16.min(asset.checksum.len())];
|
||||
self.base_dir.join(prefix).join(&asset.file_name)
|
||||
}
|
||||
|
||||
pub fn cache_key_dir(&self, asset: &Asset) -> PathBuf {
|
||||
let prefix = &asset.checksum[..16.min(asset.checksum.len())];
|
||||
self.base_dir.join(prefix)
|
||||
}
|
||||
|
||||
pub async fn ensure_dir(&self, asset: &Asset) -> Result<(), crate::errors::AssetError> {
|
||||
let dir = self.cache_key_dir(asset);
|
||||
tokio::fs::create_dir_all(&dir)
|
||||
.await
|
||||
.map_err(|e| crate::errors::AssetError::IoError(e))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LocalCache {
|
||||
fn default() -> Self {
|
||||
let base_dir = directories::ProjectDirs::from("io", "NationTech", "Harmony")
|
||||
.map(|dirs| dirs.cache_dir().join("assets"))
|
||||
.unwrap_or_else(|| PathBuf::from("/tmp/harmony_assets"));
|
||||
Self::new(base_dir)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StoredAsset {
|
||||
pub url: Url,
|
||||
pub checksum: String,
|
||||
pub checksum_algo: ChecksumAlgo,
|
||||
pub size: u64,
|
||||
pub key: String,
|
||||
}
|
||||
25
harmony_assets/src/cli/checksum.rs
Normal file
25
harmony_assets/src/cli/checksum.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub struct ChecksumArgs {
|
||||
pub path: String,
|
||||
#[arg(short, long, default_value = "blake3")]
|
||||
pub algo: String,
|
||||
}
|
||||
|
||||
pub async fn execute(args: ChecksumArgs) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use harmony_assets::{ChecksumAlgo, checksum_for_path};
|
||||
|
||||
let path = std::path::Path::new(&args.path);
|
||||
if !path.exists() {
|
||||
eprintln!("Error: File not found: {}", args.path);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let algo = ChecksumAlgo::from_str(&args.algo)?;
|
||||
let checksum = checksum_for_path(path, algo.clone()).await?;
|
||||
|
||||
println!("{}:{} {}", algo.name(), checksum, args.path);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
82
harmony_assets/src/cli/download.rs
Normal file
82
harmony_assets/src/cli/download.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub struct DownloadArgs {
|
||||
pub url: String,
|
||||
pub checksum: String,
|
||||
#[arg(short, long)]
|
||||
pub output: Option<String>,
|
||||
#[arg(short, long, default_value = "blake3")]
|
||||
pub algo: String,
|
||||
}
|
||||
|
||||
pub async fn execute(args: DownloadArgs) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use harmony_assets::{
|
||||
Asset, AssetStore, ChecksumAlgo, LocalCache, LocalStore, verify_checksum,
|
||||
};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use url::Url;
|
||||
|
||||
let url = Url::parse(&args.url).map_err(|e| format!("Invalid URL: {}", e))?;
|
||||
|
||||
let file_name = args
|
||||
.output
|
||||
.or_else(|| {
|
||||
std::path::Path::new(&args.url)
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.map(|s| s.to_string())
|
||||
})
|
||||
.unwrap_or_else(|| "download".to_string());
|
||||
|
||||
let algo = ChecksumAlgo::from_str(&args.algo)?;
|
||||
let asset = Asset::new(url, args.checksum.clone(), algo.clone(), file_name);
|
||||
|
||||
let cache = LocalCache::default();
|
||||
|
||||
println!("Downloading: {}", asset.url);
|
||||
println!("Checksum: {}:{}", algo.name(), args.checksum);
|
||||
println!("Cache dir: {:?}", cache.base_dir);
|
||||
|
||||
let total_size = asset.size.unwrap_or(0);
|
||||
let pb = if total_size > 0 {
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("{spinner:.green} [{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({bytes_per_sec})")?
|
||||
.progress_chars("=>-"),
|
||||
);
|
||||
Some(pb)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let progress_fn: Box<dyn Fn(u64, Option<u64>) + Send> = Box::new({
|
||||
let pb = pb.clone();
|
||||
move |bytes, _total| {
|
||||
if let Some(ref pb) = pb {
|
||||
pb.set_position(bytes);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let store = LocalStore::default();
|
||||
let result = store.fetch(&asset, &cache, Some(progress_fn)).await;
|
||||
|
||||
if let Some(pb) = pb {
|
||||
pb.finish();
|
||||
}
|
||||
|
||||
match result {
|
||||
Ok(path) => {
|
||||
verify_checksum(&path, &args.checksum, algo).await?;
|
||||
println!("\nDownloaded to: {:?}", path);
|
||||
println!("Checksum verified OK");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Download failed: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
49
harmony_assets/src/cli/mod.rs
Normal file
49
harmony_assets/src/cli/mod.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
pub mod checksum;
|
||||
pub mod download;
|
||||
pub mod upload;
|
||||
pub mod verify;
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(
|
||||
name = "harmony_assets",
|
||||
version,
|
||||
about = "Asset management CLI for downloading, uploading, and verifying large binary assets"
|
||||
)]
|
||||
pub struct Cli {
|
||||
#[command(subcommand)]
|
||||
pub command: Commands,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
pub enum Commands {
|
||||
Upload(upload::UploadArgs),
|
||||
Download(download::DownloadArgs),
|
||||
Checksum(checksum::ChecksumArgs),
|
||||
Verify(verify::VerifyArgs),
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
log::info!("Starting harmony_assets CLI");
|
||||
|
||||
let cli = Cli::parse();
|
||||
|
||||
match cli.command {
|
||||
Commands::Upload(args) => {
|
||||
upload::execute(args).await?;
|
||||
}
|
||||
Commands::Download(args) => {
|
||||
download::execute(args).await?;
|
||||
}
|
||||
Commands::Checksum(args) => {
|
||||
checksum::execute(args).await?;
|
||||
}
|
||||
Commands::Verify(args) => {
|
||||
verify::execute(args).await?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
166
harmony_assets/src/cli/upload.rs
Normal file
166
harmony_assets/src/cli/upload.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
use clap::Parser;
|
||||
use harmony_assets::{S3Config, S3Store, checksum_for_path_with_progress};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use std::path::Path;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub struct UploadArgs {
|
||||
pub source: String,
|
||||
pub key: Option<String>,
|
||||
#[arg(short, long)]
|
||||
pub content_type: Option<String>,
|
||||
#[arg(short, long, default_value_t = true)]
|
||||
pub public_read: bool,
|
||||
#[arg(short, long)]
|
||||
pub endpoint: Option<String>,
|
||||
#[arg(short, long)]
|
||||
pub bucket: Option<String>,
|
||||
#[arg(short, long)]
|
||||
pub region: Option<String>,
|
||||
#[arg(short, long)]
|
||||
pub access_key_id: Option<String>,
|
||||
#[arg(short, long)]
|
||||
pub secret_access_key: Option<String>,
|
||||
#[arg(short, long, default_value = "blake3")]
|
||||
pub algo: String,
|
||||
#[arg(short, long, default_value_t = false)]
|
||||
pub yes: bool,
|
||||
}
|
||||
|
||||
pub async fn execute(args: UploadArgs) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let source_path = Path::new(&args.source);
|
||||
if !source_path.exists() {
|
||||
eprintln!("Error: File not found: {}", args.source);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let key = args.key.unwrap_or_else(|| {
|
||||
source_path
|
||||
.file_name()
|
||||
.and_then(|n| n.to_str())
|
||||
.unwrap_or("upload")
|
||||
.to_string()
|
||||
});
|
||||
|
||||
let metadata = tokio::fs::metadata(source_path)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to read file metadata: {}", e))?;
|
||||
let total_size = metadata.len();
|
||||
|
||||
let endpoint = args
|
||||
.endpoint
|
||||
.or_else(|| std::env::var("S3_ENDPOINT").ok())
|
||||
.unwrap_or_default();
|
||||
let bucket = args
|
||||
.bucket
|
||||
.or_else(|| std::env::var("S3_BUCKET").ok())
|
||||
.unwrap_or_else(|| {
|
||||
inquire::Text::new("S3 Bucket name:")
|
||||
.with_default("harmony-assets")
|
||||
.prompt()
|
||||
.unwrap()
|
||||
});
|
||||
let region = args
|
||||
.region
|
||||
.or_else(|| std::env::var("S3_REGION").ok())
|
||||
.unwrap_or_else(|| {
|
||||
inquire::Text::new("S3 Region:")
|
||||
.with_default("us-east-1")
|
||||
.prompt()
|
||||
.unwrap()
|
||||
});
|
||||
let access_key_id = args
|
||||
.access_key_id
|
||||
.or_else(|| std::env::var("AWS_ACCESS_KEY_ID").ok());
|
||||
let secret_access_key = args
|
||||
.secret_access_key
|
||||
.or_else(|| std::env::var("AWS_SECRET_ACCESS_KEY").ok());
|
||||
|
||||
let config = S3Config {
|
||||
endpoint: if endpoint.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(endpoint)
|
||||
},
|
||||
bucket: bucket.clone(),
|
||||
region: region.clone(),
|
||||
access_key_id,
|
||||
secret_access_key,
|
||||
public_read: args.public_read,
|
||||
};
|
||||
|
||||
println!("Upload Configuration:");
|
||||
println!(" Source: {}", args.source);
|
||||
println!(" S3 Key: {}", key);
|
||||
println!(" Bucket: {}", bucket);
|
||||
println!(" Region: {}", region);
|
||||
println!(
|
||||
" Size: {} bytes ({} MB)",
|
||||
total_size,
|
||||
total_size as f64 / 1024.0 / 1024.0
|
||||
);
|
||||
println!();
|
||||
|
||||
if !args.yes {
|
||||
let confirm = inquire::Confirm::new("Proceed with upload?")
|
||||
.with_default(true)
|
||||
.prompt()?;
|
||||
if !confirm {
|
||||
println!("Upload cancelled.");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
let store = S3Store::new(config)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to initialize S3 client: {}", e))?;
|
||||
|
||||
println!("Computing checksum while uploading...\n");
|
||||
|
||||
let pb = ProgressBar::new(total_size);
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template("{spinner:.green} [{elapsed_precise}] [{bar:40}] {bytes}/{total_bytes} ({bytes_per_sec})")?
|
||||
.progress_chars("=>-"),
|
||||
);
|
||||
|
||||
{
|
||||
let algo = harmony_assets::ChecksumAlgo::from_str(&args.algo)?;
|
||||
let rt = tokio::runtime::Handle::current();
|
||||
let pb_clone = pb.clone();
|
||||
let _checksum = rt.block_on(checksum_for_path_with_progress(
|
||||
source_path,
|
||||
algo,
|
||||
|read, _total| {
|
||||
pb_clone.set_position(read);
|
||||
},
|
||||
))?;
|
||||
}
|
||||
|
||||
pb.set_position(total_size);
|
||||
|
||||
let result = store
|
||||
.store(source_path, &key, args.content_type.as_deref())
|
||||
.await;
|
||||
|
||||
pb.finish();
|
||||
|
||||
match result {
|
||||
Ok(asset) => {
|
||||
println!("\nUpload complete!");
|
||||
println!(" URL: {}", asset.url);
|
||||
println!(
|
||||
" Checksum: {}:{}",
|
||||
asset.checksum_algo.name(),
|
||||
asset.checksum
|
||||
);
|
||||
println!(" Size: {} bytes", asset.size);
|
||||
println!(" Key: {}", asset.key);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Upload failed: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
32
harmony_assets/src/cli/verify.rs
Normal file
32
harmony_assets/src/cli/verify.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use clap::Parser;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
pub struct VerifyArgs {
|
||||
pub path: String,
|
||||
pub expected: String,
|
||||
#[arg(short, long, default_value = "blake3")]
|
||||
pub algo: String,
|
||||
}
|
||||
|
||||
pub async fn execute(args: VerifyArgs) -> Result<(), Box<dyn std::error::Error>> {
|
||||
use harmony_assets::{ChecksumAlgo, verify_checksum};
|
||||
|
||||
let path = std::path::Path::new(&args.path);
|
||||
if !path.exists() {
|
||||
eprintln!("Error: File not found: {}", args.path);
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
let algo = ChecksumAlgo::from_str(&args.algo)?;
|
||||
|
||||
match verify_checksum(path, &args.expected, algo).await {
|
||||
Ok(()) => {
|
||||
println!("Checksum verified OK");
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Verification FAILED: {}", e);
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
}
|
||||
37
harmony_assets/src/errors.rs
Normal file
37
harmony_assets/src/errors.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use std::path::PathBuf;
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AssetError {
|
||||
#[error("File not found: {0}")]
|
||||
FileNotFound(PathBuf),
|
||||
|
||||
#[error("Checksum mismatch for '{path}': expected {expected}, got {actual}")]
|
||||
ChecksumMismatch {
|
||||
path: PathBuf,
|
||||
expected: String,
|
||||
actual: String,
|
||||
},
|
||||
|
||||
#[error("Checksum algorithm not available: {0}. Enable the corresponding feature flag.")]
|
||||
ChecksumAlgoNotAvailable(String),
|
||||
|
||||
#[error("Download failed: {0}")]
|
||||
DownloadFailed(String),
|
||||
|
||||
#[error("S3 error: {0}")]
|
||||
S3Error(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
IoError(#[from] std::io::Error),
|
||||
|
||||
#[cfg(feature = "reqwest")]
|
||||
#[error("HTTP error: {0}")]
|
||||
HttpError(#[from] reqwest::Error),
|
||||
|
||||
#[error("Store error: {0}")]
|
||||
StoreError(String),
|
||||
|
||||
#[error("Configuration error: {0}")]
|
||||
ConfigError(String),
|
||||
}
|
||||
233
harmony_assets/src/hash.rs
Normal file
233
harmony_assets/src/hash.rs
Normal file
@@ -0,0 +1,233 @@
|
||||
use crate::errors::AssetError;
|
||||
use std::path::Path;
|
||||
|
||||
#[cfg(feature = "blake3")]
|
||||
use blake3::Hasher as B3Hasher;
|
||||
#[cfg(feature = "sha256")]
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ChecksumAlgo {
|
||||
BLAKE3,
|
||||
SHA256,
|
||||
}
|
||||
|
||||
impl Default for ChecksumAlgo {
|
||||
fn default() -> Self {
|
||||
#[cfg(feature = "blake3")]
|
||||
return ChecksumAlgo::BLAKE3;
|
||||
#[cfg(not(feature = "blake3"))]
|
||||
return ChecksumAlgo::SHA256;
|
||||
}
|
||||
}
|
||||
|
||||
impl ChecksumAlgo {
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
ChecksumAlgo::BLAKE3 => "blake3",
|
||||
ChecksumAlgo::SHA256 => "sha256",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_str(s: &str) -> Result<Self, AssetError> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"blake3" | "b3" => Ok(ChecksumAlgo::BLAKE3),
|
||||
"sha256" | "sha-256" => Ok(ChecksumAlgo::SHA256),
|
||||
_ => Err(AssetError::ChecksumAlgoNotAvailable(s.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ChecksumAlgo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.name())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn checksum_for_file<R>(reader: R, algo: ChecksumAlgo) -> Result<String, AssetError>
|
||||
where
|
||||
R: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
match algo {
|
||||
#[cfg(feature = "blake3")]
|
||||
ChecksumAlgo::BLAKE3 => {
|
||||
let mut hasher = B3Hasher::new();
|
||||
let mut reader = reader;
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = tokio::io::AsyncReadExt::read(&mut reader, &mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
}
|
||||
Ok(hasher.finalize().to_hex().to_string())
|
||||
}
|
||||
#[cfg(not(feature = "blake3"))]
|
||||
ChecksumAlgo::BLAKE3 => Err(AssetError::ChecksumAlgoNotAvailable("blake3".to_string())),
|
||||
#[cfg(feature = "sha256")]
|
||||
ChecksumAlgo::SHA256 => {
|
||||
let mut hasher = Sha256::new();
|
||||
let mut reader = reader;
|
||||
let mut buf = vec![0u8; 65536];
|
||||
loop {
|
||||
let n = tokio::io::AsyncReadExt::read(&mut reader, &mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
#[cfg(not(feature = "sha256"))]
|
||||
ChecksumAlgo::SHA256 => Err(AssetError::ChecksumAlgoNotAvailable("sha256".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn checksum_for_path(path: &Path, algo: ChecksumAlgo) -> Result<String, AssetError> {
|
||||
let file = tokio::fs::File::open(path)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
let reader = tokio::io::BufReader::with_capacity(65536, file);
|
||||
checksum_for_file(reader, algo).await
|
||||
}
|
||||
|
||||
pub async fn checksum_for_path_with_progress<F>(
|
||||
path: &Path,
|
||||
algo: ChecksumAlgo,
|
||||
mut progress: F,
|
||||
) -> Result<String, AssetError>
|
||||
where
|
||||
F: FnMut(u64, Option<u64>) + Send,
|
||||
{
|
||||
let file = tokio::fs::File::open(path)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
let metadata = file.metadata().await.map_err(|e| AssetError::IoError(e))?;
|
||||
let total = Some(metadata.len());
|
||||
let reader = tokio::io::BufReader::with_capacity(65536, file);
|
||||
|
||||
match algo {
|
||||
#[cfg(feature = "blake3")]
|
||||
ChecksumAlgo::BLAKE3 => {
|
||||
let mut hasher = B3Hasher::new();
|
||||
let mut reader = reader;
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut read: u64 = 0;
|
||||
loop {
|
||||
let n = tokio::io::AsyncReadExt::read(&mut reader, &mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
read += n as u64;
|
||||
progress(read, total);
|
||||
}
|
||||
Ok(hasher.finalize().to_hex().to_string())
|
||||
}
|
||||
#[cfg(not(feature = "blake3"))]
|
||||
ChecksumAlgo::BLAKE3 => Err(AssetError::ChecksumAlgoNotAvailable("blake3".to_string())),
|
||||
#[cfg(feature = "sha256")]
|
||||
ChecksumAlgo::SHA256 => {
|
||||
let mut hasher = Sha256::new();
|
||||
let mut reader = reader;
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut read: u64 = 0;
|
||||
loop {
|
||||
let n = tokio::io::AsyncReadExt::read(&mut reader, &mut buf).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
hasher.update(&buf[..n]);
|
||||
read += n as u64;
|
||||
progress(read, total);
|
||||
}
|
||||
Ok(format!("{:x}", hasher.finalize()))
|
||||
}
|
||||
#[cfg(not(feature = "sha256"))]
|
||||
ChecksumAlgo::SHA256 => Err(AssetError::ChecksumAlgoNotAvailable("sha256".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn verify_checksum(
|
||||
path: &Path,
|
||||
expected: &str,
|
||||
algo: ChecksumAlgo,
|
||||
) -> Result<(), AssetError> {
|
||||
let actual = checksum_for_path(path, algo).await?;
|
||||
let expected_clean = expected
|
||||
.trim_start_matches("blake3:")
|
||||
.trim_start_matches("sha256:")
|
||||
.trim_start_matches("b3:")
|
||||
.trim_start_matches("sha-256:");
|
||||
if actual != expected_clean {
|
||||
return Err(AssetError::ChecksumMismatch {
|
||||
path: path.to_path_buf(),
|
||||
expected: expected_clean.to_string(),
|
||||
actual,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn format_checksum(checksum: &str, algo: ChecksumAlgo) -> String {
|
||||
format!("{}:{}", algo.name(), checksum)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::io::Write;
|
||||
use tempfile::NamedTempFile;
|
||||
|
||||
async fn create_temp_file(content: &[u8]) -> NamedTempFile {
|
||||
let mut file = NamedTempFile::new().unwrap();
|
||||
file.write_all(content).unwrap();
|
||||
file.flush().unwrap();
|
||||
file
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_checksum_blake3() {
|
||||
let file = create_temp_file(b"hello world").await;
|
||||
let checksum = checksum_for_path(file.path(), ChecksumAlgo::BLAKE3)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
checksum,
|
||||
"d74981efa70a0c880b8d8c1985d075dbcbf679b99a5f9914e5aaf96b831a9e24"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verify_checksum_success() {
|
||||
let file = create_temp_file(b"hello world").await;
|
||||
let checksum = checksum_for_path(file.path(), ChecksumAlgo::BLAKE3)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = verify_checksum(file.path(), &checksum, ChecksumAlgo::BLAKE3).await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_verify_checksum_failure() {
|
||||
let file = create_temp_file(b"hello world").await;
|
||||
let result = verify_checksum(
|
||||
file.path(),
|
||||
"blake3:0000000000000000000000000000000000000000000000000000000000000000",
|
||||
ChecksumAlgo::BLAKE3,
|
||||
)
|
||||
.await;
|
||||
assert!(matches!(result, Err(AssetError::ChecksumMismatch { .. })));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_checksum_with_prefix() {
|
||||
let file = create_temp_file(b"hello world").await;
|
||||
let checksum = checksum_for_path(file.path(), ChecksumAlgo::BLAKE3)
|
||||
.await
|
||||
.unwrap();
|
||||
let formatted = format_checksum(&checksum, ChecksumAlgo::BLAKE3);
|
||||
assert!(formatted.starts_with("blake3:"));
|
||||
}
|
||||
}
|
||||
14
harmony_assets/src/lib.rs
Normal file
14
harmony_assets/src/lib.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
pub mod asset;
|
||||
pub mod errors;
|
||||
pub mod hash;
|
||||
pub mod store;
|
||||
|
||||
pub use asset::{Asset, LocalCache, StoredAsset};
|
||||
pub use errors::AssetError;
|
||||
pub use hash::{ChecksumAlgo, checksum_for_path, checksum_for_path_with_progress, verify_checksum};
|
||||
pub use store::AssetStore;
|
||||
|
||||
#[cfg(feature = "s3")]
|
||||
pub use store::{S3Config, S3Store};
|
||||
|
||||
pub use store::local::LocalStore;
|
||||
137
harmony_assets/src/store/local.rs
Normal file
137
harmony_assets/src/store/local.rs
Normal file
@@ -0,0 +1,137 @@
|
||||
use crate::asset::{Asset, LocalCache};
|
||||
use crate::errors::AssetError;
|
||||
use crate::store::AssetStore;
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use url::Url;
|
||||
|
||||
#[cfg(feature = "reqwest")]
|
||||
use crate::hash::verify_checksum;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct LocalStore {
|
||||
base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl LocalStore {
|
||||
pub fn new(base_dir: PathBuf) -> Self {
|
||||
Self { base_dir }
|
||||
}
|
||||
|
||||
pub fn with_cache(cache: LocalCache) -> Self {
|
||||
Self {
|
||||
base_dir: cache.base_dir.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn base_dir(&self) -> &PathBuf {
|
||||
&self.base_dir
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LocalStore {
|
||||
fn default() -> Self {
|
||||
Self::new(LocalCache::default().base_dir)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl AssetStore for LocalStore {
|
||||
#[cfg(feature = "reqwest")]
|
||||
async fn fetch(
|
||||
&self,
|
||||
asset: &Asset,
|
||||
cache: &LocalCache,
|
||||
progress: Option<Box<dyn Fn(u64, Option<u64>) + Send>>,
|
||||
) -> Result<PathBuf, AssetError> {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
let dest_path = cache.path_for(asset);
|
||||
|
||||
if dest_path.exists() {
|
||||
let verification =
|
||||
verify_checksum(&dest_path, &asset.checksum, asset.checksum_algo.clone()).await;
|
||||
if verification.is_ok() {
|
||||
log::debug!("Asset already cached at {:?}", dest_path);
|
||||
return Ok(dest_path);
|
||||
} else {
|
||||
log::warn!("Cached file failed checksum verification, re-downloading");
|
||||
tokio::fs::remove_file(&dest_path)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
}
|
||||
}
|
||||
|
||||
cache.ensure_dir(asset).await?;
|
||||
|
||||
log::info!("Downloading asset from {}", asset.url);
|
||||
let client = reqwest::Client::new();
|
||||
let response = client
|
||||
.get(asset.url.as_str())
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AssetError::DownloadFailed(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(AssetError::DownloadFailed(format!(
|
||||
"HTTP {}: {}",
|
||||
response.status(),
|
||||
asset.url
|
||||
)));
|
||||
}
|
||||
|
||||
let total_size = response.content_length();
|
||||
|
||||
let mut file = tokio::fs::File::create(&dest_path)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut downloaded: u64 = 0;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| AssetError::DownloadFailed(e.to_string()))?;
|
||||
tokio::io::AsyncWriteExt::write_all(&mut file, &chunk)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
downloaded += chunk.len() as u64;
|
||||
if let Some(ref p) = progress {
|
||||
p(downloaded, total_size);
|
||||
}
|
||||
}
|
||||
|
||||
tokio::io::AsyncWriteExt::flush(&mut file)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
|
||||
drop(file);
|
||||
|
||||
verify_checksum(&dest_path, &asset.checksum, asset.checksum_algo.clone()).await?;
|
||||
|
||||
log::info!("Asset downloaded and verified: {:?}", dest_path);
|
||||
Ok(dest_path)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "reqwest"))]
|
||||
async fn fetch(
|
||||
&self,
|
||||
_asset: &Asset,
|
||||
_cache: &LocalCache,
|
||||
_progress: Option<Box<dyn Fn(u64, Option<u64>) + Send>>,
|
||||
) -> Result<PathBuf, AssetError> {
|
||||
Err(AssetError::DownloadFailed(
|
||||
"HTTP downloads not available. Enable the 'reqwest' feature.".to_string(),
|
||||
))
|
||||
}
|
||||
|
||||
async fn exists(&self, key: &str) -> Result<bool, AssetError> {
|
||||
let path = self.base_dir.join(key);
|
||||
Ok(path.exists())
|
||||
}
|
||||
|
||||
fn url_for(&self, key: &str) -> Result<Url, AssetError> {
|
||||
let path = self.base_dir.join(key);
|
||||
Url::from_file_path(&path)
|
||||
.map_err(|_| AssetError::StoreError("Could not convert path to file URL".to_string()))
|
||||
}
|
||||
}
|
||||
27
harmony_assets/src/store/mod.rs
Normal file
27
harmony_assets/src/store/mod.rs
Normal file
@@ -0,0 +1,27 @@
|
||||
use crate::asset::{Asset, LocalCache};
|
||||
use crate::errors::AssetError;
|
||||
use async_trait::async_trait;
|
||||
use std::path::PathBuf;
|
||||
use url::Url;
|
||||
|
||||
pub mod local;
|
||||
|
||||
#[cfg(feature = "s3")]
|
||||
pub mod s3;
|
||||
|
||||
#[async_trait]
|
||||
pub trait AssetStore: Send + Sync {
|
||||
async fn fetch(
|
||||
&self,
|
||||
asset: &Asset,
|
||||
cache: &LocalCache,
|
||||
progress: Option<Box<dyn Fn(u64, Option<u64>) + Send>>,
|
||||
) -> Result<PathBuf, AssetError>;
|
||||
|
||||
async fn exists(&self, key: &str) -> Result<bool, AssetError>;
|
||||
|
||||
fn url_for(&self, key: &str) -> Result<Url, AssetError>;
|
||||
}
|
||||
|
||||
#[cfg(feature = "s3")]
|
||||
pub use s3::{S3Config, S3Store};
|
||||
235
harmony_assets/src/store/s3.rs
Normal file
235
harmony_assets/src/store/s3.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
use crate::asset::StoredAsset;
|
||||
use crate::errors::AssetError;
|
||||
use crate::hash::ChecksumAlgo;
|
||||
use async_trait::async_trait;
|
||||
use aws_sdk_s3::Client as S3Client;
|
||||
use aws_sdk_s3::primitives::ByteStream;
|
||||
use aws_sdk_s3::types::ObjectCannedAcl;
|
||||
use std::path::Path;
|
||||
use url::Url;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct S3Config {
|
||||
pub endpoint: Option<String>,
|
||||
pub bucket: String,
|
||||
pub region: String,
|
||||
pub access_key_id: Option<String>,
|
||||
pub secret_access_key: Option<String>,
|
||||
pub public_read: bool,
|
||||
}
|
||||
|
||||
impl Default for S3Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
endpoint: None,
|
||||
bucket: String::new(),
|
||||
region: String::from("us-east-1"),
|
||||
access_key_id: None,
|
||||
secret_access_key: None,
|
||||
public_read: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct S3Store {
|
||||
client: S3Client,
|
||||
config: S3Config,
|
||||
}
|
||||
|
||||
impl S3Store {
|
||||
pub async fn new(config: S3Config) -> Result<Self, AssetError> {
|
||||
let mut cfg_builder = aws_config::defaults(aws_config::BehaviorVersion::latest());
|
||||
|
||||
if let Some(ref endpoint) = config.endpoint {
|
||||
cfg_builder = cfg_builder.endpoint_url(endpoint);
|
||||
}
|
||||
|
||||
let cfg = cfg_builder.load().await;
|
||||
let client = S3Client::new(&cfg);
|
||||
|
||||
Ok(Self { client, config })
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &S3Config {
|
||||
&self.config
|
||||
}
|
||||
|
||||
fn public_url(&self, key: &str) -> Result<Url, AssetError> {
|
||||
let url_str = if let Some(ref endpoint) = self.config.endpoint {
|
||||
format!(
|
||||
"{}/{}/{}",
|
||||
endpoint.trim_end_matches('/'),
|
||||
self.config.bucket,
|
||||
key
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"https://{}.s3.{}.amazonaws.com/{}",
|
||||
self.config.bucket, self.config.region, key
|
||||
)
|
||||
};
|
||||
Url::parse(&url_str).map_err(|e| AssetError::S3Error(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn store(
|
||||
&self,
|
||||
source: &Path,
|
||||
key: &str,
|
||||
content_type: Option<&str>,
|
||||
) -> Result<StoredAsset, AssetError> {
|
||||
let metadata = tokio::fs::metadata(source)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
let size = metadata.len();
|
||||
|
||||
let checksum = crate::checksum_for_path(source, ChecksumAlgo::default())
|
||||
.await
|
||||
.map_err(|e| AssetError::StoreError(e.to_string()))?;
|
||||
|
||||
let body = ByteStream::from_path(source).await.map_err(|e| {
|
||||
AssetError::IoError(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
e.to_string(),
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut put_builder = self
|
||||
.client
|
||||
.put_object()
|
||||
.bucket(&self.config.bucket)
|
||||
.key(key)
|
||||
.body(body)
|
||||
.content_length(size as i64)
|
||||
.metadata("checksum", &checksum);
|
||||
|
||||
if self.config.public_read {
|
||||
put_builder = put_builder.acl(ObjectCannedAcl::PublicRead);
|
||||
}
|
||||
|
||||
if let Some(ct) = content_type {
|
||||
put_builder = put_builder.content_type(ct);
|
||||
}
|
||||
|
||||
put_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AssetError::S3Error(e.to_string()))?;
|
||||
|
||||
Ok(StoredAsset {
|
||||
url: self.public_url(key)?,
|
||||
checksum,
|
||||
checksum_algo: ChecksumAlgo::default(),
|
||||
size,
|
||||
key: key.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
use crate::store::AssetStore;
|
||||
use crate::{Asset, LocalCache};
|
||||
|
||||
#[async_trait]
|
||||
impl AssetStore for S3Store {
|
||||
async fn fetch(
|
||||
&self,
|
||||
asset: &Asset,
|
||||
cache: &LocalCache,
|
||||
progress: Option<Box<dyn Fn(u64, Option<u64>) + Send>>,
|
||||
) -> Result<std::path::PathBuf, AssetError> {
|
||||
let dest_path = cache.path_for(asset);
|
||||
|
||||
if dest_path.exists() {
|
||||
let verification =
|
||||
crate::verify_checksum(&dest_path, &asset.checksum, asset.checksum_algo.clone())
|
||||
.await;
|
||||
if verification.is_ok() {
|
||||
log::debug!("Asset already cached at {:?}", dest_path);
|
||||
return Ok(dest_path);
|
||||
}
|
||||
}
|
||||
|
||||
cache.ensure_dir(asset).await?;
|
||||
|
||||
log::info!(
|
||||
"Downloading asset from s3://{}/{}",
|
||||
self.config.bucket,
|
||||
asset.url
|
||||
);
|
||||
|
||||
let key = extract_s3_key(&asset.url, &self.config.bucket)?;
|
||||
let obj = self
|
||||
.client
|
||||
.get_object()
|
||||
.bucket(&self.config.bucket)
|
||||
.key(&key)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| AssetError::S3Error(e.to_string()))?;
|
||||
|
||||
let total_size = obj.content_length.unwrap_or(0) as u64;
|
||||
let mut file = tokio::fs::File::create(&dest_path)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
|
||||
let mut stream = obj.body;
|
||||
let mut downloaded: u64 = 0;
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
let chunk = chunk_result.map_err(|e| AssetError::S3Error(e.to_string()))?;
|
||||
tokio::io::AsyncWriteExt::write_all(&mut file, &chunk)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
downloaded += chunk.len() as u64;
|
||||
if let Some(ref p) = progress {
|
||||
p(downloaded, Some(total_size));
|
||||
}
|
||||
}
|
||||
|
||||
tokio::io::AsyncWriteExt::flush(&mut file)
|
||||
.await
|
||||
.map_err(|e| AssetError::IoError(e))?;
|
||||
|
||||
drop(file);
|
||||
|
||||
crate::verify_checksum(&dest_path, &asset.checksum, asset.checksum_algo.clone()).await?;
|
||||
|
||||
Ok(dest_path)
|
||||
}
|
||||
|
||||
async fn exists(&self, key: &str) -> Result<bool, AssetError> {
|
||||
match self
|
||||
.client
|
||||
.head_object()
|
||||
.bucket(&self.config.bucket)
|
||||
.key(key)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => {
|
||||
let err_str = e.to_string();
|
||||
if err_str.contains("NoSuchKey") || err_str.contains("NotFound") {
|
||||
Ok(false)
|
||||
} else {
|
||||
Err(AssetError::S3Error(err_str))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn url_for(&self, key: &str) -> Result<Url, AssetError> {
|
||||
self.public_url(key)
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_s3_key(url: &Url, bucket: &str) -> Result<String, AssetError> {
|
||||
let path = url.path().trim_start_matches('/');
|
||||
if let Some(stripped) = path.strip_prefix(&format!("{}/", bucket)) {
|
||||
Ok(stripped.to_string())
|
||||
} else if path == bucket {
|
||||
Ok(String::new())
|
||||
} else {
|
||||
Ok(path.to_string())
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ use directories::ProjectDirs;
|
||||
use interactive_parse::InteractiveParseObj;
|
||||
use log::debug;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use serde::{Serialize, de::DeserializeOwned};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use thiserror::Error;
|
||||
@@ -76,12 +76,11 @@ impl ConfigManager {
|
||||
pub async fn get<T: Config>(&self) -> Result<T, ConfigError> {
|
||||
for source in &self.sources {
|
||||
if let Some(value) = source.get(T::KEY).await? {
|
||||
let config: T = serde_json::from_value(value).map_err(|e| {
|
||||
ConfigError::Deserialization {
|
||||
let config: T =
|
||||
serde_json::from_value(value).map_err(|e| ConfigError::Deserialization {
|
||||
key: T::KEY.to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?;
|
||||
})?;
|
||||
debug!("Retrieved config for key {} from source", T::KEY);
|
||||
return Ok(config);
|
||||
}
|
||||
@@ -95,17 +94,20 @@ impl ConfigManager {
|
||||
match self.get::<T>().await {
|
||||
Ok(config) => Ok(config),
|
||||
Err(ConfigError::NotFound { .. }) => {
|
||||
let config = T::parse_to_obj()
|
||||
.map_err(|e| ConfigError::PromptError(e.to_string()))?;
|
||||
let config =
|
||||
T::parse_to_obj().map_err(|e| ConfigError::PromptError(e.to_string()))?;
|
||||
|
||||
for source in &self.sources {
|
||||
if let Err(e) = source
|
||||
.set(T::KEY, &serde_json::to_value(&config).map_err(|e| {
|
||||
ConfigError::Serialization {
|
||||
key: T::KEY.to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?)
|
||||
.set(
|
||||
T::KEY,
|
||||
&serde_json::to_value(&config).map_err(|e| {
|
||||
ConfigError::Serialization {
|
||||
key: T::KEY.to_string(),
|
||||
source: e,
|
||||
}
|
||||
})?,
|
||||
)
|
||||
.await
|
||||
{
|
||||
debug!("Failed to save config to source: {e}");
|
||||
@@ -175,8 +177,35 @@ mod tests {
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Mutex;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
static TEST_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
static ENV_LOCK: Mutex<()> = Mutex::new(());
|
||||
|
||||
fn setup_env_vars(key: &str, value: Option<&str>) -> String {
|
||||
let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
|
||||
let env_var = format!("HARMONY_CONFIG_{}_{}", key, id);
|
||||
|
||||
unsafe {
|
||||
if let Some(v) = value {
|
||||
std::env::set_var(&env_var, v);
|
||||
} else {
|
||||
std::env::remove_var(&env_var);
|
||||
}
|
||||
}
|
||||
|
||||
env_var
|
||||
}
|
||||
|
||||
fn run_in_isolated_env<F>(f: F)
|
||||
where
|
||||
F: FnOnce() + Send + 'static,
|
||||
{
|
||||
let handle = std::thread::spawn(f);
|
||||
handle.join().expect("Test thread panicked");
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq)]
|
||||
struct TestConfig {
|
||||
name: String,
|
||||
@@ -339,18 +368,14 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_env_source_reads_from_environment() {
|
||||
unsafe {
|
||||
std::env::set_var(
|
||||
"HARMONY_CONFIG_TestConfig",
|
||||
r#"{"name":"from_env","count":7}"#,
|
||||
);
|
||||
}
|
||||
let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let env_var = setup_env_vars("TestConfig", Some(r#"{"name":"from_env","count":7}"#));
|
||||
|
||||
let source = EnvSource;
|
||||
let result = source.get("TestConfig").await;
|
||||
let result = source.get(&env_var.replace("HARMONY_CONFIG_", "")).await;
|
||||
|
||||
unsafe {
|
||||
std::env::remove_var("HARMONY_CONFIG_TestConfig");
|
||||
std::env::remove_var(&env_var);
|
||||
}
|
||||
|
||||
let value = result.unwrap().unwrap();
|
||||
@@ -361,26 +386,32 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_env_source_returns_none_when_not_set() {
|
||||
unsafe {
|
||||
std::env::remove_var("HARMONY_CONFIG_TestConfig");
|
||||
}
|
||||
let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
|
||||
run_in_isolated_env(|| {
|
||||
let env_var = setup_env_vars("TestConfig", None);
|
||||
|
||||
let source = EnvSource;
|
||||
let result = source.get("TestConfig").await.unwrap();
|
||||
assert!(result.is_none());
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap();
|
||||
rt.block_on(async {
|
||||
let source = EnvSource;
|
||||
let result = source.get(&env_var.replace("HARMONY_CONFIG_", "")).await;
|
||||
assert!(result.unwrap().is_none());
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_env_source_returns_error_for_invalid_json() {
|
||||
unsafe {
|
||||
std::env::set_var("HARMONY_CONFIG_TestConfig", "not valid json");
|
||||
}
|
||||
let _lock = ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner());
|
||||
let env_var = setup_env_vars("TestConfig", Some("not valid json"));
|
||||
|
||||
let source = EnvSource;
|
||||
let result = source.get("TestConfig").await;
|
||||
let result = source.get(&env_var.replace("HARMONY_CONFIG_", "")).await;
|
||||
|
||||
unsafe {
|
||||
std::env::remove_var("HARMONY_CONFIG_TestConfig");
|
||||
std::env::remove_var(&env_var);
|
||||
}
|
||||
|
||||
assert!(result.is_err());
|
||||
@@ -438,4 +469,4 @@ mod tests {
|
||||
|
||||
assert_eq!(parsed, config);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use async_trait::async_trait;
|
||||
use crate::{ConfigError, ConfigSource};
|
||||
use async_trait::async_trait;
|
||||
|
||||
pub struct EnvSource;
|
||||
|
||||
@@ -11,16 +11,14 @@ fn env_key_for(config_key: &str) -> String {
|
||||
impl ConfigSource for EnvSource {
|
||||
async fn get(&self, key: &str) -> Result<Option<serde_json::Value>, ConfigError> {
|
||||
let env_key = env_key_for(key);
|
||||
|
||||
|
||||
match std::env::var(&env_key) {
|
||||
Ok(value) => {
|
||||
serde_json::from_str(&value)
|
||||
.map(Some)
|
||||
.map_err(|e| ConfigError::EnvError(format!(
|
||||
"Invalid JSON in environment variable {}: {}",
|
||||
env_key, e
|
||||
)))
|
||||
}
|
||||
Ok(value) => serde_json::from_str(&value).map(Some).map_err(|e| {
|
||||
ConfigError::EnvError(format!(
|
||||
"Invalid JSON in environment variable {}: {}",
|
||||
env_key, e
|
||||
))
|
||||
}),
|
||||
Err(std::env::VarError::NotPresent) => Ok(None),
|
||||
Err(e) => Err(ConfigError::EnvError(format!(
|
||||
"Failed to read environment variable {}: {}",
|
||||
@@ -31,12 +29,11 @@ impl ConfigSource for EnvSource {
|
||||
|
||||
async fn set(&self, key: &str, value: &serde_json::Value) -> Result<(), ConfigError> {
|
||||
let env_key = env_key_for(key);
|
||||
let json_string = serde_json::to_string(value)
|
||||
.map_err(|e| ConfigError::Serialization {
|
||||
key: key.to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
let json_string = serde_json::to_string(value).map_err(|e| ConfigError::Serialization {
|
||||
key: key.to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
// SAFETY: Setting environment variables is generally safe in single-threaded contexts.
|
||||
// In multi-threaded contexts, this could cause races, but is acceptable for this use case
|
||||
// as config is typically set once at startup.
|
||||
@@ -45,4 +42,4 @@ impl ConfigSource for EnvSource {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,14 +26,15 @@ impl LocalFileSource {
|
||||
impl ConfigSource for LocalFileSource {
|
||||
async fn get(&self, key: &str) -> Result<Option<serde_json::Value>, ConfigError> {
|
||||
let path = self.file_path_for(key);
|
||||
|
||||
|
||||
match fs::read(&path).await {
|
||||
Ok(contents) => {
|
||||
let value: serde_json::Value = serde_json::from_slice(&contents)
|
||||
.map_err(|e| ConfigError::Deserialization {
|
||||
let value: serde_json::Value = serde_json::from_slice(&contents).map_err(|e| {
|
||||
ConfigError::Deserialization {
|
||||
key: key.to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
}
|
||||
})?;
|
||||
Ok(Some(value))
|
||||
}
|
||||
Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
|
||||
@@ -46,16 +47,16 @@ impl ConfigSource for LocalFileSource {
|
||||
|
||||
async fn set(&self, key: &str, value: &serde_json::Value) -> Result<(), ConfigError> {
|
||||
fs::create_dir_all(&self.base_path).await?;
|
||||
|
||||
|
||||
let path = self.file_path_for(key);
|
||||
let contents = serde_json::to_string_pretty(value)
|
||||
.map_err(|e| ConfigError::Serialization {
|
||||
let contents =
|
||||
serde_json::to_string_pretty(value).map_err(|e| ConfigError::Serialization {
|
||||
key: key.to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
|
||||
fs::write(&path, contents).await?;
|
||||
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pub mod env;
|
||||
pub mod local_file;
|
||||
pub mod prompt;
|
||||
pub mod store;
|
||||
pub mod store;
|
||||
|
||||
@@ -18,7 +18,9 @@ impl PromptSource {
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn with_writer(writer: Arc<dyn std::io::Write + Send + Sync>) -> Self {
|
||||
Self { writer: Some(writer) }
|
||||
Self {
|
||||
writer: Some(writer),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -45,4 +47,4 @@ where
|
||||
{
|
||||
let _guard = PROMPT_MUTEX.lock().await;
|
||||
f.await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,8 +19,8 @@ impl<S: SecretStore + 'static> ConfigSource for StoreSource<S> {
|
||||
async fn get(&self, key: &str) -> Result<Option<serde_json::Value>, ConfigError> {
|
||||
match self.store.get_raw(&self.namespace, key).await {
|
||||
Ok(bytes) => {
|
||||
let value: serde_json::Value = serde_json::from_slice(&bytes)
|
||||
.map_err(|e| ConfigError::Deserialization {
|
||||
let value: serde_json::Value =
|
||||
serde_json::from_slice(&bytes).map_err(|e| ConfigError::Deserialization {
|
||||
key: key.to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
@@ -36,10 +36,10 @@ impl<S: SecretStore + 'static> ConfigSource for StoreSource<S> {
|
||||
key: key.to_string(),
|
||||
source: e,
|
||||
})?;
|
||||
|
||||
|
||||
self.store
|
||||
.set_raw(&self.namespace, key, &bytes)
|
||||
.await
|
||||
.map_err(ConfigError::StoreError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use proc_macro::TokenStream;
|
||||
use proc_macro_crate::{crate_name, FoundCrate};
|
||||
use proc_macro_crate::{FoundCrate, crate_name};
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, DeriveInput, Ident};
|
||||
use syn::{DeriveInput, Ident, parse_macro_input};
|
||||
|
||||
#[proc_macro_derive(Config)]
|
||||
pub fn derive_config(input: TokenStream) -> TokenStream {
|
||||
|
||||
Reference in New Issue
Block a user