liter_llm 1.0.0.pre.rc.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/README.md +239 -0
- data/ext/liter_llm_rb/extconf.rb +65 -0
- data/ext/liter_llm_rb/native/.cargo/config.toml +23 -0
- data/ext/liter_llm_rb/native/Cargo.lock +3713 -0
- data/ext/liter_llm_rb/native/Cargo.toml +32 -0
- data/ext/liter_llm_rb/native/build.rs +15 -0
- data/ext/liter_llm_rb/native/src/lib.rs +1079 -0
- data/lib/liter_llm.rb +8 -0
- data/sig/liter_llm.rbs +416 -0
- data/vendor/Cargo.toml +54 -0
- data/vendor/liter-llm/Cargo.toml +92 -0
- data/vendor/liter-llm/README.md +252 -0
- data/vendor/liter-llm/schemas/pricing.json +40 -0
- data/vendor/liter-llm/schemas/providers.json +1662 -0
- data/vendor/liter-llm/src/auth/azure_ad.rs +264 -0
- data/vendor/liter-llm/src/auth/bedrock_sts.rs +353 -0
- data/vendor/liter-llm/src/auth/mod.rs +68 -0
- data/vendor/liter-llm/src/auth/vertex_oauth.rs +353 -0
- data/vendor/liter-llm/src/client/config.rs +351 -0
- data/vendor/liter-llm/src/client/managed.rs +622 -0
- data/vendor/liter-llm/src/client/mod.rs +864 -0
- data/vendor/liter-llm/src/cost.rs +212 -0
- data/vendor/liter-llm/src/error.rs +190 -0
- data/vendor/liter-llm/src/http/eventstream.rs +860 -0
- data/vendor/liter-llm/src/http/mod.rs +12 -0
- data/vendor/liter-llm/src/http/request.rs +438 -0
- data/vendor/liter-llm/src/http/retry.rs +72 -0
- data/vendor/liter-llm/src/http/streaming.rs +289 -0
- data/vendor/liter-llm/src/lib.rs +37 -0
- data/vendor/liter-llm/src/provider/anthropic.rs +2250 -0
- data/vendor/liter-llm/src/provider/azure.rs +579 -0
- data/vendor/liter-llm/src/provider/bedrock.rs +1543 -0
- data/vendor/liter-llm/src/provider/cohere.rs +654 -0
- data/vendor/liter-llm/src/provider/custom.rs +404 -0
- data/vendor/liter-llm/src/provider/google_ai.rs +281 -0
- data/vendor/liter-llm/src/provider/mistral.rs +188 -0
- data/vendor/liter-llm/src/provider/mod.rs +616 -0
- data/vendor/liter-llm/src/provider/vertex.rs +1504 -0
- data/vendor/liter-llm/src/tests.rs +1425 -0
- data/vendor/liter-llm/src/tokenizer.rs +281 -0
- data/vendor/liter-llm/src/tower/budget.rs +599 -0
- data/vendor/liter-llm/src/tower/cache.rs +502 -0
- data/vendor/liter-llm/src/tower/cache_opendal.rs +270 -0
- data/vendor/liter-llm/src/tower/cooldown.rs +231 -0
- data/vendor/liter-llm/src/tower/cost.rs +404 -0
- data/vendor/liter-llm/src/tower/fallback.rs +121 -0
- data/vendor/liter-llm/src/tower/health.rs +219 -0
- data/vendor/liter-llm/src/tower/hooks.rs +369 -0
- data/vendor/liter-llm/src/tower/mod.rs +77 -0
- data/vendor/liter-llm/src/tower/rate_limit.rs +300 -0
- data/vendor/liter-llm/src/tower/router.rs +436 -0
- data/vendor/liter-llm/src/tower/service.rs +181 -0
- data/vendor/liter-llm/src/tower/tests.rs +539 -0
- data/vendor/liter-llm/src/tower/tests_common.rs +252 -0
- data/vendor/liter-llm/src/tower/tracing.rs +209 -0
- data/vendor/liter-llm/src/tower/types.rs +170 -0
- data/vendor/liter-llm/src/types/audio.rs +52 -0
- data/vendor/liter-llm/src/types/batch.rs +77 -0
- data/vendor/liter-llm/src/types/chat.rs +214 -0
- data/vendor/liter-llm/src/types/common.rs +244 -0
- data/vendor/liter-llm/src/types/embedding.rs +84 -0
- data/vendor/liter-llm/src/types/files.rs +58 -0
- data/vendor/liter-llm/src/types/image.rs +40 -0
- data/vendor/liter-llm/src/types/mod.rs +27 -0
- data/vendor/liter-llm/src/types/models.rs +21 -0
- data/vendor/liter-llm/src/types/moderation.rs +80 -0
- data/vendor/liter-llm/src/types/ocr.rs +87 -0
- data/vendor/liter-llm/src/types/rerank.rs +46 -0
- data/vendor/liter-llm/src/types/responses.rs +55 -0
- data/vendor/liter-llm/src/types/search.rs +45 -0
- data/vendor/liter-llm/tests/contract.rs +332 -0
- data/vendor/liter-llm-ffi/Cargo.toml +30 -0
- data/vendor/liter-llm-ffi/build.rs +66 -0
- data/vendor/liter-llm-ffi/cbindgen.toml +60 -0
- data/vendor/liter-llm-ffi/liter_llm.h +850 -0
- data/vendor/liter-llm-ffi/src/lib.rs +2488 -0
- metadata +286 -0
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
//! Azure AD OAuth2 credential provider (client-credentials flow).
|
|
2
|
+
//!
|
|
3
|
+
//! Exchanges client credentials for a bearer token via the Microsoft Identity
|
|
4
|
+
//! Platform v2.0 token endpoint. Tokens are cached and refreshed automatically
|
|
5
|
+
//! when they are within 5 minutes of expiry.
|
|
6
|
+
//!
|
|
7
|
+
//! # Environment variables
|
|
8
|
+
//!
|
|
9
|
+
//! | Variable | Description |
|
|
10
|
+
//! |----------|-------------|
|
|
11
|
+
//! | `AZURE_TENANT_ID` | Azure AD tenant ID |
|
|
12
|
+
//! | `AZURE_CLIENT_ID` | Application (client) ID |
|
|
13
|
+
//! | `AZURE_CLIENT_SECRET` | Client secret value |
|
|
14
|
+
//! | `AZURE_AD_TOKEN` | Static bearer token (skips OAuth flow) |
|
|
15
|
+
//! | `AZURE_AD_SCOPE` | OAuth scope (defaults to `https://cognitiveservices.azure.com/.default`) |
|
|
16
|
+
|
|
17
|
+
use std::sync::Arc;
|
|
18
|
+
use std::time::Instant;
|
|
19
|
+
|
|
20
|
+
use secrecy::{ExposeSecret, SecretString};
|
|
21
|
+
use tokio::sync::RwLock;
|
|
22
|
+
|
|
23
|
+
use super::{Credential, CredentialProvider, StaticTokenProvider};
|
|
24
|
+
use crate::client::BoxFuture;
|
|
25
|
+
use crate::error::LiterLlmError;
|
|
26
|
+
|
|
27
|
+
/// OAuth2 scope requested when the caller does not supply one: the `.default`
/// scope of Azure Cognitive Services (which includes Azure OpenAI).
const DEFAULT_SCOPE: &str = "https://cognitiveservices.azure.com/.default";

/// Safety margin in seconds: a cached token is treated as expired once less
/// than this much of its lifetime remains.
const EXPIRY_BUFFER_SECS: u64 = 300;
|
|
32
|
+
|
|
33
|
+
/// Cached token and its acquisition timestamp + lifetime.
|
|
34
|
+
struct CachedToken {
|
|
35
|
+
token: SecretString,
|
|
36
|
+
acquired_at: Instant,
|
|
37
|
+
expires_in_secs: u64,
|
|
38
|
+
}
|
|
39
|
+
|
|
40
|
+
impl CachedToken {
|
|
41
|
+
/// Returns `true` if the token is still valid with the safety buffer.
|
|
42
|
+
fn is_valid(&self) -> bool {
|
|
43
|
+
let elapsed = self.acquired_at.elapsed().as_secs();
|
|
44
|
+
elapsed + EXPIRY_BUFFER_SECS < self.expires_in_secs
|
|
45
|
+
}
|
|
46
|
+
}
|
|
47
|
+
|
|
48
|
+
/// Azure AD OAuth2 credential provider using the client-credentials grant.
|
|
49
|
+
///
|
|
50
|
+
/// Obtains bearer tokens from `https://login.microsoftonline.com/{tenant}/oauth2/v2.0/token`
|
|
51
|
+
/// and caches them until they are within 5 minutes of expiry.
|
|
52
|
+
pub struct AzureAdCredentialProvider {
|
|
53
|
+
tenant_id: String,
|
|
54
|
+
client_id: String,
|
|
55
|
+
client_secret: SecretString,
|
|
56
|
+
scope: String,
|
|
57
|
+
cached: RwLock<Option<CachedToken>>,
|
|
58
|
+
http_client: reqwest::Client,
|
|
59
|
+
}
|
|
60
|
+
|
|
61
|
+
impl AzureAdCredentialProvider {
|
|
62
|
+
/// Create a new provider with explicit credentials.
|
|
63
|
+
///
|
|
64
|
+
/// Uses the default scope `https://cognitiveservices.azure.com/.default`.
|
|
65
|
+
#[must_use]
|
|
66
|
+
pub fn new(tenant_id: impl Into<String>, client_id: impl Into<String>, client_secret: SecretString) -> Self {
|
|
67
|
+
Self {
|
|
68
|
+
tenant_id: tenant_id.into(),
|
|
69
|
+
client_id: client_id.into(),
|
|
70
|
+
client_secret,
|
|
71
|
+
scope: DEFAULT_SCOPE.to_owned(),
|
|
72
|
+
cached: RwLock::new(None),
|
|
73
|
+
http_client: reqwest::Client::new(),
|
|
74
|
+
}
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
/// Override the OAuth2 scope (default: `https://cognitiveservices.azure.com/.default`).
|
|
78
|
+
#[must_use]
|
|
79
|
+
pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
|
|
80
|
+
self.scope = scope.into();
|
|
81
|
+
self
|
|
82
|
+
}
|
|
83
|
+
|
|
84
|
+
/// Override the HTTP client used for token requests.
|
|
85
|
+
#[must_use]
|
|
86
|
+
pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
|
|
87
|
+
self.http_client = client;
|
|
88
|
+
self
|
|
89
|
+
}
|
|
90
|
+
|
|
91
|
+
/// Create a provider from environment variables.
|
|
92
|
+
///
|
|
93
|
+
/// If `AZURE_AD_TOKEN` is set, returns a [`StaticTokenProvider`] instead
|
|
94
|
+
/// (no OAuth flow needed).
|
|
95
|
+
///
|
|
96
|
+
/// # Errors
|
|
97
|
+
///
|
|
98
|
+
/// Returns [`LiterLlmError::Authentication`] if required environment
|
|
99
|
+
/// variables are missing.
|
|
100
|
+
pub fn from_env() -> Result<Arc<dyn CredentialProvider>, LiterLlmError> {
|
|
101
|
+
// Fast path: static token from environment.
|
|
102
|
+
if let Ok(token) = std::env::var("AZURE_AD_TOKEN") {
|
|
103
|
+
return Ok(Arc::new(StaticTokenProvider::new(SecretString::from(token))));
|
|
104
|
+
}
|
|
105
|
+
|
|
106
|
+
let tenant_id = env_var_required("AZURE_TENANT_ID")?;
|
|
107
|
+
let client_id = env_var_required("AZURE_CLIENT_ID")?;
|
|
108
|
+
let client_secret = SecretString::from(env_var_required("AZURE_CLIENT_SECRET")?);
|
|
109
|
+
|
|
110
|
+
let mut provider = Self::new(tenant_id, client_id, client_secret);
|
|
111
|
+
|
|
112
|
+
if let Ok(scope) = std::env::var("AZURE_AD_SCOPE") {
|
|
113
|
+
provider.scope = scope;
|
|
114
|
+
}
|
|
115
|
+
|
|
116
|
+
Ok(Arc::new(provider))
|
|
117
|
+
}
|
|
118
|
+
|
|
119
|
+
/// Exchange client credentials for an access token.
|
|
120
|
+
async fn fetch_token(&self) -> Result<CachedToken, LiterLlmError> {
|
|
121
|
+
let url = format!("https://login.microsoftonline.com/{}/oauth2/v2.0/token", self.tenant_id);
|
|
122
|
+
|
|
123
|
+
let resp = self
|
|
124
|
+
.http_client
|
|
125
|
+
.post(&url)
|
|
126
|
+
.form(&[
|
|
127
|
+
("grant_type", "client_credentials"),
|
|
128
|
+
("client_id", &self.client_id),
|
|
129
|
+
("client_secret", self.client_secret.expose_secret()),
|
|
130
|
+
("scope", &self.scope),
|
|
131
|
+
])
|
|
132
|
+
.send()
|
|
133
|
+
.await
|
|
134
|
+
.map_err(|e| LiterLlmError::Authentication {
|
|
135
|
+
message: format!("Azure AD token request failed: {e}"),
|
|
136
|
+
})?;
|
|
137
|
+
|
|
138
|
+
let status = resp.status();
|
|
139
|
+
let body = resp.text().await.map_err(|e| LiterLlmError::Authentication {
|
|
140
|
+
message: format!("Azure AD token response unreadable: {e}"),
|
|
141
|
+
})?;
|
|
142
|
+
|
|
143
|
+
if !status.is_success() {
|
|
144
|
+
return Err(LiterLlmError::Authentication {
|
|
145
|
+
message: format!("Azure AD token request returned {status}: {body}"),
|
|
146
|
+
});
|
|
147
|
+
}
|
|
148
|
+
|
|
149
|
+
let parsed: TokenResponse = serde_json::from_str(&body).map_err(|e| LiterLlmError::Authentication {
|
|
150
|
+
message: format!("Azure AD token response parse error: {e}"),
|
|
151
|
+
})?;
|
|
152
|
+
|
|
153
|
+
Ok(CachedToken {
|
|
154
|
+
token: SecretString::from(parsed.access_token),
|
|
155
|
+
acquired_at: Instant::now(),
|
|
156
|
+
expires_in_secs: parsed.expires_in,
|
|
157
|
+
})
|
|
158
|
+
}
|
|
159
|
+
}
|
|
160
|
+
|
|
161
|
+
impl CredentialProvider for AzureAdCredentialProvider {
|
|
162
|
+
fn resolve(&self) -> BoxFuture<'_, Credential> {
|
|
163
|
+
Box::pin(async move {
|
|
164
|
+
// Fast path: read lock to check cache.
|
|
165
|
+
{
|
|
166
|
+
let guard = self.cached.read().await;
|
|
167
|
+
if let Some(ref cached) = *guard
|
|
168
|
+
&& cached.is_valid()
|
|
169
|
+
{
|
|
170
|
+
return Ok(Credential::BearerToken(cached.token.clone()));
|
|
171
|
+
}
|
|
172
|
+
}
|
|
173
|
+
|
|
174
|
+
// Slow path: write lock to refresh.
|
|
175
|
+
let mut guard = self.cached.write().await;
|
|
176
|
+
|
|
177
|
+
// Double-check after acquiring write lock (another task may have refreshed).
|
|
178
|
+
if let Some(ref cached) = *guard
|
|
179
|
+
&& cached.is_valid()
|
|
180
|
+
{
|
|
181
|
+
return Ok(Credential::BearerToken(cached.token.clone()));
|
|
182
|
+
}
|
|
183
|
+
|
|
184
|
+
let fresh = self.fetch_token().await?;
|
|
185
|
+
let token = fresh.token.clone();
|
|
186
|
+
*guard = Some(fresh);
|
|
187
|
+
|
|
188
|
+
Ok(Credential::BearerToken(token))
|
|
189
|
+
})
|
|
190
|
+
}
|
|
191
|
+
}
|
|
192
|
+
|
|
193
|
+
/// Minimal deserialization of the Azure AD token response.
|
|
194
|
+
#[derive(serde::Deserialize)]
|
|
195
|
+
struct TokenResponse {
|
|
196
|
+
access_token: String,
|
|
197
|
+
expires_in: u64,
|
|
198
|
+
}
|
|
199
|
+
|
|
200
|
+
/// Read a required environment variable, returning an auth error if missing.
|
|
201
|
+
fn env_var_required(name: &str) -> Result<String, LiterLlmError> {
|
|
202
|
+
std::env::var(name).map_err(|_| LiterLlmError::Authentication {
|
|
203
|
+
message: format!("missing required environment variable: {name}"),
|
|
204
|
+
})
|
|
205
|
+
}
|
|
206
|
+
|
|
207
|
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache entry acquired "now" with the given lifetime in seconds.
    ///
    /// The tests vary the lifetime instead of back-dating `acquired_at`: no
    /// `Duration` subtraction is needed, which avoids panics on Windows where
    /// `Instant` uptime may be < 1h.
    fn entry_with_lifetime(expires_in_secs: u64) -> CachedToken {
        CachedToken {
            token: SecretString::from("test-token".to_owned()),
            acquired_at: Instant::now(),
            expires_in_secs,
        }
    }

    /// Builds a provider with placeholder credentials.
    fn sample_provider() -> AzureAdCredentialProvider {
        AzureAdCredentialProvider::new("tenant", "client", SecretString::from("secret".to_owned()))
    }

    #[test]
    fn cached_token_validity() {
        let fresh = entry_with_lifetime(3600);
        assert!(fresh.is_valid());
    }

    #[test]
    fn cached_token_expired() {
        // A token with zero lifetime is immediately expired.
        let dead = entry_with_lifetime(0);
        assert!(!dead.is_valid());
    }

    #[test]
    fn cached_token_within_buffer() {
        // 200s lifetime is within the 300s expiry buffer, so the token is invalid.
        let short_lived = entry_with_lifetime(200);
        assert!(!short_lived.is_valid());
    }

    #[test]
    fn default_scope() {
        assert_eq!(sample_provider().scope, DEFAULT_SCOPE);
    }

    #[test]
    fn with_scope_override() {
        let overridden = sample_provider().with_scope("https://custom.scope/.default");
        assert_eq!(overridden.scope, "https://custom.scope/.default");
    }

    #[tokio::test]
    #[ignore] // Requires network access and valid Azure AD credentials.
    async fn live_azure_ad_token_exchange() {
        let provider = AzureAdCredentialProvider::from_env().expect("Azure AD env vars not set");
        let resolved = provider.resolve().await.expect("token exchange failed");
        assert!(matches!(resolved, Credential::BearerToken(_)));
    }
}
|
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
//! AWS STS Web Identity credential provider for Bedrock.
|
|
2
|
+
//!
|
|
3
|
+
//! Exchanges a web identity token (JWT from a file) for temporary AWS
|
|
4
|
+
//! credentials via the STS `AssumeRoleWithWebIdentity` API. This is the
|
|
5
|
+
//! standard authentication flow for EKS pods using IAM Roles for Service
|
|
6
|
+
//! Accounts (IRSA) and for other OIDC federation scenarios.
|
|
7
|
+
//!
|
|
8
|
+
//! Credentials are cached and refreshed automatically when they are within
|
|
9
|
+
//! 5 minutes of expiry.
|
|
10
|
+
//!
|
|
11
|
+
//! # Environment variables
|
|
12
|
+
//!
|
|
13
|
+
//! | Variable | Description |
|
|
14
|
+
//! |----------|-------------|
|
|
15
|
+
//! | `AWS_ROLE_ARN` | ARN of the IAM role to assume |
|
|
16
|
+
//! | `AWS_WEB_IDENTITY_TOKEN_FILE` | Path to a file containing the OIDC JWT |
|
|
17
|
+
//! | `AWS_ROLE_SESSION_NAME` | Session name (defaults to `liter-llm-session`) |
|
|
18
|
+
//! | `AWS_REGION` or `AWS_DEFAULT_REGION` | AWS region (defaults to `us-east-1`) |
|
|
19
|
+
|
|
20
|
+
use std::path::PathBuf;
|
|
21
|
+
use std::time::Instant;
|
|
22
|
+
|
|
23
|
+
use secrecy::SecretString;
|
|
24
|
+
use tokio::sync::RwLock;
|
|
25
|
+
|
|
26
|
+
use super::{Credential, CredentialProvider};
|
|
27
|
+
use crate::client::BoxFuture;
|
|
28
|
+
use crate::error::LiterLlmError;
|
|
29
|
+
|
|
30
|
+
/// Role session name used when `AWS_ROLE_SESSION_NAME` is not set.
const DEFAULT_SESSION_NAME: &str = "liter-llm-session";

/// Region used when neither `AWS_REGION` nor `AWS_DEFAULT_REGION` is set.
const DEFAULT_REGION: &str = "us-east-1";

/// Safety margin in seconds: cached credentials are treated as expired once
/// less than this much of their lifetime remains.
const EXPIRY_BUFFER_SECS: u64 = 300;

/// Credential lifetime requested from STS, in seconds (1 hour).
const DEFAULT_DURATION_SECS: u64 = 3600;
|
|
41
|
+
|
|
42
|
+
/// Cached temporary credentials.
|
|
43
|
+
struct CachedCredentials {
|
|
44
|
+
access_key_id: SecretString,
|
|
45
|
+
secret_access_key: SecretString,
|
|
46
|
+
session_token: SecretString,
|
|
47
|
+
acquired_at: Instant,
|
|
48
|
+
expires_in_secs: u64,
|
|
49
|
+
}
|
|
50
|
+
|
|
51
|
+
impl CachedCredentials {
|
|
52
|
+
/// Returns `true` if the credentials are still valid with the safety buffer.
|
|
53
|
+
fn is_valid(&self) -> bool {
|
|
54
|
+
let elapsed = self.acquired_at.elapsed().as_secs();
|
|
55
|
+
elapsed + EXPIRY_BUFFER_SECS < self.expires_in_secs
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
|
|
59
|
+
/// AWS STS Web Identity credential provider.
|
|
60
|
+
///
|
|
61
|
+
/// Reads a JWT from a token file, sends it to the STS
|
|
62
|
+
/// `AssumeRoleWithWebIdentity` endpoint, and returns temporary AWS credentials
|
|
63
|
+
/// suitable for SigV4 signing.
|
|
64
|
+
pub struct WebIdentityCredentialProvider {
|
|
65
|
+
role_arn: String,
|
|
66
|
+
token_file: PathBuf,
|
|
67
|
+
session_name: String,
|
|
68
|
+
region: String,
|
|
69
|
+
cached: RwLock<Option<CachedCredentials>>,
|
|
70
|
+
http_client: reqwest::Client,
|
|
71
|
+
}
|
|
72
|
+
|
|
73
|
+
impl WebIdentityCredentialProvider {
|
|
74
|
+
/// Create a new provider with explicit parameters.
|
|
75
|
+
#[must_use]
|
|
76
|
+
pub fn new(
|
|
77
|
+
role_arn: impl Into<String>,
|
|
78
|
+
token_file: impl Into<PathBuf>,
|
|
79
|
+
session_name: impl Into<String>,
|
|
80
|
+
region: impl Into<String>,
|
|
81
|
+
) -> Self {
|
|
82
|
+
Self {
|
|
83
|
+
role_arn: role_arn.into(),
|
|
84
|
+
token_file: token_file.into(),
|
|
85
|
+
session_name: session_name.into(),
|
|
86
|
+
region: region.into(),
|
|
87
|
+
cached: RwLock::new(None),
|
|
88
|
+
http_client: reqwest::Client::new(),
|
|
89
|
+
}
|
|
90
|
+
}
|
|
91
|
+
|
|
92
|
+
/// Create a provider from standard AWS environment variables.
|
|
93
|
+
///
|
|
94
|
+
/// # Errors
|
|
95
|
+
///
|
|
96
|
+
/// Returns [`LiterLlmError::Authentication`] if `AWS_ROLE_ARN` or
|
|
97
|
+
/// `AWS_WEB_IDENTITY_TOKEN_FILE` are not set.
|
|
98
|
+
pub fn from_env() -> Result<Self, LiterLlmError> {
|
|
99
|
+
let role_arn = env_var_required("AWS_ROLE_ARN")?;
|
|
100
|
+
let token_file = env_var_required("AWS_WEB_IDENTITY_TOKEN_FILE")?;
|
|
101
|
+
|
|
102
|
+
let session_name = std::env::var("AWS_ROLE_SESSION_NAME").unwrap_or_else(|_| DEFAULT_SESSION_NAME.to_owned());
|
|
103
|
+
|
|
104
|
+
let region = std::env::var("AWS_REGION")
|
|
105
|
+
.or_else(|_| std::env::var("AWS_DEFAULT_REGION"))
|
|
106
|
+
.unwrap_or_else(|_| DEFAULT_REGION.to_owned());
|
|
107
|
+
|
|
108
|
+
Ok(Self::new(role_arn, token_file, session_name, region))
|
|
109
|
+
}
|
|
110
|
+
|
|
111
|
+
/// Override the HTTP client used for STS requests.
|
|
112
|
+
#[must_use]
|
|
113
|
+
pub fn with_http_client(mut self, client: reqwest::Client) -> Self {
|
|
114
|
+
self.http_client = client;
|
|
115
|
+
self
|
|
116
|
+
}
|
|
117
|
+
|
|
118
|
+
/// Read the web identity token from the token file and exchange it for
|
|
119
|
+
/// temporary AWS credentials.
|
|
120
|
+
async fn fetch_credentials(&self) -> Result<CachedCredentials, LiterLlmError> {
|
|
121
|
+
let token = tokio::fs::read_to_string(&self.token_file)
|
|
122
|
+
.await
|
|
123
|
+
.map_err(|e| LiterLlmError::Authentication {
|
|
124
|
+
message: format!(
|
|
125
|
+
"failed to read web identity token file {}: {e}",
|
|
126
|
+
self.token_file.display()
|
|
127
|
+
),
|
|
128
|
+
})?;
|
|
129
|
+
let token = token.trim();
|
|
130
|
+
|
|
131
|
+
let url = format!("https://sts.{}.amazonaws.com/", self.region);
|
|
132
|
+
|
|
133
|
+
let resp = self
|
|
134
|
+
.http_client
|
|
135
|
+
.post(&url)
|
|
136
|
+
.header("Content-Type", "application/x-www-form-urlencoded")
|
|
137
|
+
.form(&[
|
|
138
|
+
("Action", "AssumeRoleWithWebIdentity"),
|
|
139
|
+
("Version", "2011-06-15"),
|
|
140
|
+
("RoleArn", &self.role_arn),
|
|
141
|
+
("RoleSessionName", &self.session_name),
|
|
142
|
+
("WebIdentityToken", token),
|
|
143
|
+
("DurationSeconds", &DEFAULT_DURATION_SECS.to_string()),
|
|
144
|
+
])
|
|
145
|
+
.send()
|
|
146
|
+
.await
|
|
147
|
+
.map_err(|e| LiterLlmError::Authentication {
|
|
148
|
+
message: format!("STS AssumeRoleWithWebIdentity request failed: {e}"),
|
|
149
|
+
})?;
|
|
150
|
+
|
|
151
|
+
let status = resp.status();
|
|
152
|
+
let body = resp.text().await.map_err(|e| LiterLlmError::Authentication {
|
|
153
|
+
message: format!("STS response unreadable: {e}"),
|
|
154
|
+
})?;
|
|
155
|
+
|
|
156
|
+
if !status.is_success() {
|
|
157
|
+
return Err(LiterLlmError::Authentication {
|
|
158
|
+
message: format!("STS AssumeRoleWithWebIdentity returned {status}: {body}"),
|
|
159
|
+
});
|
|
160
|
+
}
|
|
161
|
+
|
|
162
|
+
// Parse the XML response to extract credentials.
|
|
163
|
+
let creds = parse_sts_response(&body)?;
|
|
164
|
+
|
|
165
|
+
Ok(CachedCredentials {
|
|
166
|
+
access_key_id: SecretString::from(creds.access_key_id),
|
|
167
|
+
secret_access_key: SecretString::from(creds.secret_access_key),
|
|
168
|
+
session_token: SecretString::from(creds.session_token),
|
|
169
|
+
acquired_at: Instant::now(),
|
|
170
|
+
expires_in_secs: DEFAULT_DURATION_SECS,
|
|
171
|
+
})
|
|
172
|
+
}
|
|
173
|
+
}
|
|
174
|
+
|
|
175
|
+
impl CredentialProvider for WebIdentityCredentialProvider {
|
|
176
|
+
fn resolve(&self) -> BoxFuture<'_, Credential> {
|
|
177
|
+
Box::pin(async move {
|
|
178
|
+
// Fast path: read lock to check cache.
|
|
179
|
+
{
|
|
180
|
+
let guard = self.cached.read().await;
|
|
181
|
+
if let Some(ref cached) = *guard
|
|
182
|
+
&& cached.is_valid()
|
|
183
|
+
{
|
|
184
|
+
return Ok(Credential::AwsCredentials {
|
|
185
|
+
access_key_id: cached.access_key_id.clone(),
|
|
186
|
+
secret_access_key: cached.secret_access_key.clone(),
|
|
187
|
+
session_token: Some(cached.session_token.clone()),
|
|
188
|
+
});
|
|
189
|
+
}
|
|
190
|
+
}
|
|
191
|
+
|
|
192
|
+
// Slow path: write lock to refresh.
|
|
193
|
+
let mut guard = self.cached.write().await;
|
|
194
|
+
|
|
195
|
+
// Double-check after acquiring write lock.
|
|
196
|
+
if let Some(ref cached) = *guard
|
|
197
|
+
&& cached.is_valid()
|
|
198
|
+
{
|
|
199
|
+
return Ok(Credential::AwsCredentials {
|
|
200
|
+
access_key_id: cached.access_key_id.clone(),
|
|
201
|
+
secret_access_key: cached.secret_access_key.clone(),
|
|
202
|
+
session_token: Some(cached.session_token.clone()),
|
|
203
|
+
});
|
|
204
|
+
}
|
|
205
|
+
|
|
206
|
+
let fresh = self.fetch_credentials().await?;
|
|
207
|
+
let credential = Credential::AwsCredentials {
|
|
208
|
+
access_key_id: fresh.access_key_id.clone(),
|
|
209
|
+
secret_access_key: fresh.secret_access_key.clone(),
|
|
210
|
+
session_token: Some(fresh.session_token.clone()),
|
|
211
|
+
};
|
|
212
|
+
*guard = Some(fresh);
|
|
213
|
+
|
|
214
|
+
Ok(credential)
|
|
215
|
+
})
|
|
216
|
+
}
|
|
217
|
+
}
|
|
218
|
+
|
|
219
|
+
/// The three credential values pulled out of an STS response, still as plain
/// strings.
#[derive(Debug)]
struct StsCredentials {
    /// Text of the `<AccessKeyId>` element.
    access_key_id: String,
    /// Text of the `<SecretAccessKey>` element.
    secret_access_key: String,
    /// Text of the `<SessionToken>` element.
    session_token: String,
}
|
|
226
|
+
|
|
227
|
+
/// Extract credential fields from the STS XML response using simple string
|
|
228
|
+
/// matching. We avoid pulling in a full XML parser for three fixed elements.
|
|
229
|
+
fn parse_sts_response(xml: &str) -> Result<StsCredentials, LiterLlmError> {
|
|
230
|
+
let access_key_id = extract_xml_element(xml, "AccessKeyId")?;
|
|
231
|
+
let secret_access_key = extract_xml_element(xml, "SecretAccessKey")?;
|
|
232
|
+
let session_token = extract_xml_element(xml, "SessionToken")?;
|
|
233
|
+
|
|
234
|
+
Ok(StsCredentials {
|
|
235
|
+
access_key_id,
|
|
236
|
+
secret_access_key,
|
|
237
|
+
session_token,
|
|
238
|
+
})
|
|
239
|
+
}
|
|
240
|
+
|
|
241
|
+
/// Extract the text content of a simple XML element `<tag>value</tag>`.
|
|
242
|
+
fn extract_xml_element(xml: &str, tag: &str) -> Result<String, LiterLlmError> {
|
|
243
|
+
let open = format!("<{tag}>");
|
|
244
|
+
let close = format!("</{tag}>");
|
|
245
|
+
|
|
246
|
+
let start = xml.find(&open).ok_or_else(|| LiterLlmError::Authentication {
|
|
247
|
+
message: format!("STS response missing <{tag}> element"),
|
|
248
|
+
})? + open.len();
|
|
249
|
+
|
|
250
|
+
let end = xml[start..].find(&close).ok_or_else(|| LiterLlmError::Authentication {
|
|
251
|
+
message: format!("STS response missing </{tag}> element"),
|
|
252
|
+
})? + start;
|
|
253
|
+
|
|
254
|
+
Ok(xml[start..end].to_owned())
|
|
255
|
+
}
|
|
256
|
+
|
|
257
|
+
/// Read a required environment variable, returning an auth error if missing.
|
|
258
|
+
fn env_var_required(name: &str) -> Result<String, LiterLlmError> {
|
|
259
|
+
std::env::var(name).map_err(|_| LiterLlmError::Authentication {
|
|
260
|
+
message: format!("missing required environment variable: {name}"),
|
|
261
|
+
})
|
|
262
|
+
}
|
|
263
|
+
|
|
264
|
+
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a cache entry acquired "now" with the given lifetime in seconds.
    ///
    /// The lifetime is varied instead of back-dating `acquired_at`, which
    /// avoids a `Duration` subtraction panic on Windows.
    fn entry_with_lifetime(expires_in_secs: u64) -> CachedCredentials {
        CachedCredentials {
            access_key_id: SecretString::from("AKIA...".to_owned()),
            secret_access_key: SecretString::from("secret".to_owned()),
            session_token: SecretString::from("token".to_owned()),
            acquired_at: Instant::now(),
            expires_in_secs,
        }
    }

    #[test]
    fn cached_credentials_validity() {
        let fresh = entry_with_lifetime(3600);
        assert!(fresh.is_valid());
    }

    #[test]
    fn cached_credentials_expired() {
        // Zero lifetime is immediately expired.
        let dead = entry_with_lifetime(0);
        assert!(!dead.is_valid());
    }

    #[test]
    fn parse_sts_xml_response() {
        let xml = r#"
<AssumeRoleWithWebIdentityResponse xmlns="https://sts.amazonaws.com/doc/2011-06-15/">
<AssumeRoleWithWebIdentityResult>
<Credentials>
<AccessKeyId>AKIAIOSFODNN7EXAMPLE</AccessKeyId>
<SecretAccessKey>wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY</SecretAccessKey>
<SessionToken>FwoGZXIvYXdzEBYaDGlY...</SessionToken>
<Expiration>2024-01-01T00:00:00Z</Expiration>
</Credentials>
</AssumeRoleWithWebIdentityResult>
</AssumeRoleWithWebIdentityResponse>
"#;

        let parsed = parse_sts_response(xml).expect("should parse");
        assert_eq!(parsed.access_key_id, "AKIAIOSFODNN7EXAMPLE");
        assert_eq!(parsed.secret_access_key, "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY");
        assert_eq!(parsed.session_token, "FwoGZXIvYXdzEBYaDGlY...");
    }

    #[test]
    fn parse_sts_xml_missing_element() {
        let xml = r"<Response><AccessKeyId>AKIA</AccessKeyId></Response>";
        let failure = parse_sts_response(xml).unwrap_err();
        assert!(failure.to_string().contains("SecretAccessKey"));
    }

    #[test]
    fn extract_xml_element_success() {
        let value = extract_xml_element("<Root><Foo>bar</Foo></Root>", "Foo").expect("should work");
        assert_eq!(value, "bar");
    }

    #[test]
    fn extract_xml_element_missing_open() {
        let failure = extract_xml_element("<Root></Root>", "Missing").unwrap_err();
        assert!(failure.to_string().contains("<Missing>"));
    }

    #[test]
    fn constructor_defaults() {
        let role = "arn:aws:iam::123456789012:role/TestRole";
        let provider = WebIdentityCredentialProvider::new(role, "/var/run/secrets/token", "test-session", "eu-west-1");
        assert_eq!(provider.role_arn, "arn:aws:iam::123456789012:role/TestRole");
        assert_eq!(provider.session_name, "test-session");
        assert_eq!(provider.region, "eu-west-1");
    }

    #[tokio::test]
    #[ignore] // Requires network access and valid AWS OIDC credentials.
    async fn live_sts_web_identity_exchange() {
        let provider = WebIdentityCredentialProvider::from_env().expect("AWS env vars not set");
        let resolved = provider.resolve().await.expect("STS exchange failed");
        assert!(matches!(resolved, Credential::AwsCredentials { .. }));
    }
}
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
#[cfg(feature = "azure-auth")]
|
|
2
|
+
pub mod azure_ad;
|
|
3
|
+
#[cfg(feature = "bedrock-auth")]
|
|
4
|
+
pub mod bedrock_sts;
|
|
5
|
+
#[cfg(feature = "vertex-auth")]
|
|
6
|
+
pub mod vertex_oauth;
|
|
7
|
+
|
|
8
|
+
use std::sync::Arc;
|
|
9
|
+
|
|
10
|
+
use secrecy::SecretString;
|
|
11
|
+
|
|
12
|
+
use crate::client::BoxFuture;
|
|
13
|
+
|
|
14
|
+
/// Dynamic credential provider for providers that use token-based auth
|
|
15
|
+
/// (Azure AD, Vertex OAuth2) or refreshable credentials (AWS STS).
|
|
16
|
+
///
|
|
17
|
+
/// Implementations handle caching, refresh, and expiry internally.
|
|
18
|
+
/// The client calls `resolve()` before each request when a credential
|
|
19
|
+
/// provider is configured.
|
|
20
|
+
pub trait CredentialProvider: Send + Sync {
|
|
21
|
+
/// Retrieve a valid credential.
|
|
22
|
+
///
|
|
23
|
+
/// Implementations should cache credentials and only refresh when
|
|
24
|
+
/// expired or about to expire.
|
|
25
|
+
fn resolve(&self) -> BoxFuture<'_, Credential>;
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
/// Blanket implementation so `Arc<dyn CredentialProvider>` is itself a
|
|
29
|
+
/// `CredentialProvider`, making it convenient to share providers across
|
|
30
|
+
/// clients.
|
|
31
|
+
impl CredentialProvider for Arc<dyn CredentialProvider> {
|
|
32
|
+
fn resolve(&self) -> BoxFuture<'_, Credential> {
|
|
33
|
+
(**self).resolve()
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
|
|
37
|
+
/// A resolved credential ready for use in request authentication.
|
|
38
|
+
#[derive(Debug, Clone)]
|
|
39
|
+
pub enum Credential {
|
|
40
|
+
/// Bearer token (Azure AD, Vertex OAuth2, generic OIDC).
|
|
41
|
+
BearerToken(SecretString),
|
|
42
|
+
/// AWS credentials for SigV4 signing.
|
|
43
|
+
AwsCredentials {
|
|
44
|
+
access_key_id: SecretString,
|
|
45
|
+
secret_access_key: SecretString,
|
|
46
|
+
session_token: Option<SecretString>,
|
|
47
|
+
},
|
|
48
|
+
}
|
|
49
|
+
|
|
50
|
+
/// A static credential provider that always returns the same bearer token.
|
|
51
|
+
/// Useful for testing or when tokens are managed externally.
|
|
52
|
+
pub struct StaticTokenProvider {
|
|
53
|
+
token: SecretString,
|
|
54
|
+
}
|
|
55
|
+
|
|
56
|
+
impl StaticTokenProvider {
|
|
57
|
+
#[must_use]
|
|
58
|
+
pub fn new(token: SecretString) -> Self {
|
|
59
|
+
Self { token }
|
|
60
|
+
}
|
|
61
|
+
}
|
|
62
|
+
|
|
63
|
+
impl CredentialProvider for StaticTokenProvider {
|
|
64
|
+
fn resolve(&self) -> BoxFuture<'_, Credential> {
|
|
65
|
+
let token = self.token.clone();
|
|
66
|
+
Box::pin(async move { Ok(Credential::BearerToken(token)) })
|
|
67
|
+
}
|
|
68
|
+
}
|