koava 0.1.0__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {koava-0.1.0 → koava-0.1.1}/.github/workflows/koava-release.yml +51 -49
- {koava-0.1.0 → koava-0.1.1}/Cargo.lock +1 -1
- {koava-0.1.0 → koava-0.1.1}/Cargo.toml +1 -1
- {koava-0.1.0 → koava-0.1.1}/PKG-INFO +1 -1
- {koava-0.1.0 → koava-0.1.1}/pyproject.toml +1 -1
- {koava-0.1.0 → koava-0.1.1}/src/auth.rs +2 -14
- {koava-0.1.0 → koava-0.1.1}/src/cli.rs +16 -18
- {koava-0.1.0 → koava-0.1.1}/src/config.rs +37 -0
- {koava-0.1.0 → koava-0.1.1}/src/model.rs +7 -8
- {koava-0.1.0 → koava-0.1.1}/src/upload.rs +1 -1
- {koava-0.1.0 → koava-0.1.1}/src/utils.rs +49 -59
- {koava-0.1.0 → koava-0.1.1}/.github/workflows/ci.yml +0 -0
- {koava-0.1.0 → koava-0.1.1}/.gitignore +0 -0
- {koava-0.1.0 → koava-0.1.1}/FEATURES.md +0 -0
- {koava-0.1.0 → koava-0.1.1}/LICENSE +0 -0
- {koava-0.1.0 → koava-0.1.1}/MANIFEST.in +0 -0
- {koava-0.1.0 → koava-0.1.1}/README.md +0 -0
- {koava-0.1.0 → koava-0.1.1}/custom_config.json +0 -0
- {koava-0.1.0 → koava-0.1.1}/proptest-regressions/config.txt +0 -0
- {koava-0.1.0 → koava-0.1.1}/proptest-regressions/store.txt +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/Cargo.toml +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/api/auth.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/api/mod.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/api/model.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/common/auth.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/common/key.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/common/mod.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/common/model.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/protocol/src/lib.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/client.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/encrypt.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/error.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/file.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/huggingface.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/key.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/main.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/policy.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/push.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/security.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/store.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/templates.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/tests/mocks.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/tests/mod.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/tests/utils.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/src/ui.rs +0 -0
- {koava-0.1.0 → koava-0.1.1}/templates/KOALAVAULT_PROPRIETARY_LICENSE.txt +0 -0
- {koava-0.1.0 → koava-0.1.1}/templates/README_ENCRYPTED_MODEL.md +0 -0
|
@@ -177,12 +177,58 @@ jobs:
|
|
|
177
177
|
echo "Is Release: $IS_RELEASE"
|
|
178
178
|
echo "=============================="
|
|
179
179
|
|
|
180
|
-
- name:
|
|
180
|
+
- name: Check versions in files
|
|
181
181
|
run: |
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
182
|
+
EVENT_NAME="${{ github.event_name }}"
|
|
183
|
+
REF_TYPE="${{ github.ref_type }}"
|
|
184
|
+
IS_TAG_PUSH=false
|
|
185
|
+
|
|
186
|
+
if [ "$EVENT_NAME" = "push" ] && [ "$REF_TYPE" = "tag" ]; then
|
|
187
|
+
IS_TAG_PUSH=true
|
|
188
|
+
TAG_NAME="${{ github.ref_name }}"
|
|
189
|
+
TAG_VERSION="${TAG_NAME#v}"
|
|
190
|
+
fi
|
|
191
|
+
|
|
192
|
+
# Extract version from Cargo.toml
|
|
193
|
+
CARGO_VERSION=$(grep -E '^version = ' Cargo.toml | sed -E 's/^version = "([^"]+)"/\1/')
|
|
194
|
+
|
|
195
|
+
# Extract version from pyproject.toml
|
|
196
|
+
PYPROJECT_VERSION=$(grep -E '^version = ' pyproject.toml | sed -E 's/^version = "([^"]+)"/\1/')
|
|
197
|
+
|
|
198
|
+
echo "=== Version Check ==="
|
|
199
|
+
echo "Cargo.toml version: $CARGO_VERSION"
|
|
200
|
+
echo "pyproject.toml version: $PYPROJECT_VERSION"
|
|
201
|
+
|
|
202
|
+
if [ "$IS_TAG_PUSH" = "true" ]; then
|
|
203
|
+
echo "Tag version: $TAG_VERSION"
|
|
204
|
+
echo ""
|
|
205
|
+
echo "Checking three-way version match (Cargo.toml, pyproject.toml, and tag)..."
|
|
206
|
+
|
|
207
|
+
# Check if all three versions match
|
|
208
|
+
if [ "$CARGO_VERSION" != "$PYPROJECT_VERSION" ]; then
|
|
209
|
+
echo "::error::Version mismatch: Cargo.toml has '$CARGO_VERSION' but pyproject.toml has '$PYPROJECT_VERSION'"
|
|
210
|
+
exit 1
|
|
211
|
+
fi
|
|
212
|
+
|
|
213
|
+
if [ "$CARGO_VERSION" != "$TAG_VERSION" ]; then
|
|
214
|
+
echo "::error::Version mismatch: Cargo.toml/pyproject.toml have '$CARGO_VERSION' but tag has '$TAG_VERSION'"
|
|
215
|
+
exit 1
|
|
216
|
+
fi
|
|
217
|
+
|
|
218
|
+
echo "✓ All three versions match: $CARGO_VERSION"
|
|
219
|
+
else
|
|
220
|
+
echo ""
|
|
221
|
+
echo "Checking two-way version match (Cargo.toml and pyproject.toml)..."
|
|
222
|
+
|
|
223
|
+
# Check if Cargo.toml and pyproject.toml versions match
|
|
224
|
+
if [ "$CARGO_VERSION" != "$PYPROJECT_VERSION" ]; then
|
|
225
|
+
echo "::error::Version mismatch: Cargo.toml has '$CARGO_VERSION' but pyproject.toml has '$PYPROJECT_VERSION'"
|
|
226
|
+
exit 1
|
|
227
|
+
fi
|
|
228
|
+
|
|
229
|
+
echo "✓ Both files have the same version: $CARGO_VERSION"
|
|
230
|
+
fi
|
|
231
|
+
echo "===================="
|
|
186
232
|
|
|
187
233
|
build-linux-x86_64:
|
|
188
234
|
needs: prepare
|
|
@@ -211,12 +257,6 @@ jobs:
|
|
|
211
257
|
- name: Install maturin
|
|
212
258
|
run: pip install maturin
|
|
213
259
|
|
|
214
|
-
- name: Update versions
|
|
215
|
-
run: |
|
|
216
|
-
VERSION="${{ needs.prepare.outputs.version }}"
|
|
217
|
-
sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
|
218
|
-
sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
|
|
219
|
-
|
|
220
260
|
- name: Build for Linux x86_64
|
|
221
261
|
run: maturin build --release --target x86_64-unknown-linux-gnu --features cert-pinning
|
|
222
262
|
|
|
@@ -270,12 +310,6 @@ jobs:
|
|
|
270
310
|
- name: Install maturin
|
|
271
311
|
run: pip install maturin
|
|
272
312
|
|
|
273
|
-
- name: Update versions
|
|
274
|
-
run: |
|
|
275
|
-
VERSION="${{ needs.prepare.outputs.version }}"
|
|
276
|
-
sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
|
277
|
-
sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
|
|
278
|
-
|
|
279
313
|
- name: Build for Linux ARM64
|
|
280
314
|
run: maturin build --release --target aarch64-unknown-linux-gnu --features cert-pinning
|
|
281
315
|
|
|
@@ -307,12 +341,6 @@ jobs:
|
|
|
307
341
|
- name: Install maturin
|
|
308
342
|
run: pip install maturin
|
|
309
343
|
|
|
310
|
-
- name: Update versions
|
|
311
|
-
run: |
|
|
312
|
-
VERSION="${{ needs.prepare.outputs.version }}"
|
|
313
|
-
sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
|
314
|
-
sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
|
|
315
|
-
|
|
316
344
|
- name: Build for macOS x86_64
|
|
317
345
|
run: maturin build --release --target x86_64-apple-darwin --features cert-pinning
|
|
318
346
|
|
|
@@ -343,12 +371,6 @@ jobs:
|
|
|
343
371
|
|
|
344
372
|
- name: Install maturin
|
|
345
373
|
run: pip install maturin
|
|
346
|
-
|
|
347
|
-
- name: Update versions
|
|
348
|
-
run: |
|
|
349
|
-
VERSION="${{ needs.prepare.outputs.version }}"
|
|
350
|
-
sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
|
351
|
-
sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
|
|
352
374
|
|
|
353
375
|
- name: Build for macOS ARM64
|
|
354
376
|
run: maturin build --release --target aarch64-apple-darwin --features cert-pinning
|
|
@@ -380,13 +402,6 @@ jobs:
|
|
|
380
402
|
|
|
381
403
|
- name: Install maturin
|
|
382
404
|
run: pip install maturin
|
|
383
|
-
|
|
384
|
-
- name: Update versions
|
|
385
|
-
shell: pwsh
|
|
386
|
-
run: |
|
|
387
|
-
$version = "${{ needs.prepare.outputs.version }}"
|
|
388
|
-
(Get-Content pyproject.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content pyproject.toml
|
|
389
|
-
(Get-Content Cargo.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content Cargo.toml
|
|
390
405
|
|
|
391
406
|
- name: Build for Windows x64
|
|
392
407
|
run: maturin build --release --target x86_64-pc-windows-msvc --features cert-pinning
|
|
@@ -418,13 +433,6 @@ jobs:
|
|
|
418
433
|
|
|
419
434
|
- name: Install maturin
|
|
420
435
|
run: pip install maturin
|
|
421
|
-
|
|
422
|
-
- name: Update versions
|
|
423
|
-
shell: pwsh
|
|
424
|
-
run: |
|
|
425
|
-
$version = "${{ needs.prepare.outputs.version }}"
|
|
426
|
-
(Get-Content pyproject.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content pyproject.toml
|
|
427
|
-
(Get-Content Cargo.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content Cargo.toml
|
|
428
436
|
|
|
429
437
|
- name: Build for Windows ARM64
|
|
430
438
|
run: maturin build --release --target aarch64-pc-windows-msvc --features cert-pinning
|
|
@@ -458,12 +466,6 @@ jobs:
|
|
|
458
466
|
- name: Install maturin
|
|
459
467
|
run: pip install maturin
|
|
460
468
|
|
|
461
|
-
- name: Update versions
|
|
462
|
-
run: |
|
|
463
|
-
VERSION="${{ needs.prepare.outputs.version }}"
|
|
464
|
-
sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
|
|
465
|
-
sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
|
|
466
|
-
|
|
467
469
|
- name: Build source distribution (sdist)
|
|
468
470
|
run: maturin sdist
|
|
469
471
|
|
|
@@ -210,13 +210,7 @@ impl AuthService {
|
|
|
210
210
|
|
|
211
211
|
if let Some(username) = &username_opt {
|
|
212
212
|
// Normalize endpoint to include /api
|
|
213
|
-
let normalized_endpoint =
|
|
214
|
-
self.config.endpoint.clone()
|
|
215
|
-
} else if self.config.endpoint.ends_with("/") {
|
|
216
|
-
format!("{}api", self.config.endpoint)
|
|
217
|
-
} else {
|
|
218
|
-
format!("{}/api", self.config.endpoint)
|
|
219
|
-
};
|
|
213
|
+
let normalized_endpoint = self.config.get_api_endpoint();
|
|
220
214
|
|
|
221
215
|
// Build HTTP client (no proxy for localhost)
|
|
222
216
|
let mut builder = reqwest::Client::builder()
|
|
@@ -281,13 +275,7 @@ impl AuthService {
|
|
|
281
275
|
.map_err(|e| KoavaError::config(e.to_string()))?;
|
|
282
276
|
|
|
283
277
|
// Normalize endpoint to include /api
|
|
284
|
-
let normalized_endpoint =
|
|
285
|
-
self.config.endpoint.clone()
|
|
286
|
-
} else if self.config.endpoint.ends_with("/") {
|
|
287
|
-
format!("{}api", self.config.endpoint)
|
|
288
|
-
} else {
|
|
289
|
-
format!("{}/api", self.config.endpoint)
|
|
290
|
-
};
|
|
278
|
+
let normalized_endpoint = self.config.get_api_endpoint();
|
|
291
279
|
|
|
292
280
|
// Disable proxy for localhost to avoid corporate/system proxy causing 502
|
|
293
281
|
let mut builder = reqwest::Client::builder()
|
|
@@ -40,6 +40,16 @@ impl CliHandler {
|
|
|
40
40
|
}
|
|
41
41
|
}
|
|
42
42
|
|
|
43
|
+
/// Get the authenticated client and config
|
|
44
|
+
async fn get_authenticated_client(
|
|
45
|
+
&self,
|
|
46
|
+
) -> Result<(Config, std::sync::Arc<crate::HttpClient>)> {
|
|
47
|
+
let config = self.load_config().await?;
|
|
48
|
+
let auth_service = crate::auth::AuthService::new(config.clone());
|
|
49
|
+
let client = auth_service.get_authenticated_client().await?;
|
|
50
|
+
Ok((config, client))
|
|
51
|
+
}
|
|
52
|
+
|
|
43
53
|
/// Execute a CLI command
|
|
44
54
|
pub async fn execute(&mut self, command: Commands) -> Result<()> {
|
|
45
55
|
match command {
|
|
@@ -59,9 +69,7 @@ impl CliHandler {
|
|
|
59
69
|
|
|
60
70
|
/// Handle encrypt command
|
|
61
71
|
async fn handle_encrypt(&mut self, args: EncryptArgs) -> Result<()> {
|
|
62
|
-
let config = self.
|
|
63
|
-
let auth_service = crate::auth::AuthService::new(config.clone());
|
|
64
|
-
let client = auth_service.get_authenticated_client().await?;
|
|
72
|
+
let (config, client) = self.get_authenticated_client().await?;
|
|
65
73
|
let encrypt_service = EncryptService::new(config);
|
|
66
74
|
encrypt_service.encrypt(&*client, args).await
|
|
67
75
|
}
|
|
@@ -75,18 +83,14 @@ impl CliHandler {
|
|
|
75
83
|
|
|
76
84
|
/// Handle upload command - upload encrypted model to server
|
|
77
85
|
async fn handle_upload(&mut self, args: UploadArgs) -> Result<()> {
|
|
78
|
-
let
|
|
79
|
-
let auth_service = crate::auth::AuthService::new(config);
|
|
80
|
-
let client = auth_service.get_authenticated_client().await?;
|
|
86
|
+
let (_config, client) = self.get_authenticated_client().await?;
|
|
81
87
|
let service = crate::model::ModelService::new();
|
|
82
88
|
service.upload(client, args).await
|
|
83
89
|
}
|
|
84
90
|
|
|
85
91
|
/// Handle remove command
|
|
86
92
|
async fn handle_remove(&mut self, args: RemoveArgs) -> Result<()> {
|
|
87
|
-
let
|
|
88
|
-
let auth_service = crate::auth::AuthService::new(config);
|
|
89
|
-
let client = auth_service.get_authenticated_client().await?;
|
|
93
|
+
let (_config, client) = self.get_authenticated_client().await?;
|
|
90
94
|
let service = crate::model::ModelService::new();
|
|
91
95
|
service.remove(client, args).await
|
|
92
96
|
}
|
|
@@ -147,27 +151,21 @@ impl CliHandler {
|
|
|
147
151
|
|
|
148
152
|
/// Handle list command - list files for a model on server
|
|
149
153
|
async fn handle_list(&mut self, args: ListArgs) -> Result<()> {
|
|
150
|
-
let
|
|
151
|
-
let auth_service = crate::auth::AuthService::new(config);
|
|
152
|
-
let client = auth_service.get_authenticated_client().await?;
|
|
154
|
+
let (_config, client) = self.get_authenticated_client().await?;
|
|
153
155
|
let service = crate::model::ModelService::new();
|
|
154
156
|
service.list(client, args).await
|
|
155
157
|
}
|
|
156
158
|
|
|
157
159
|
/// Handle create command
|
|
158
160
|
async fn handle_create(&mut self, args: CreateArgs) -> Result<()> {
|
|
159
|
-
let
|
|
160
|
-
let auth_service = crate::auth::AuthService::new(config);
|
|
161
|
-
let client = auth_service.get_authenticated_client().await?;
|
|
161
|
+
let (_config, client) = self.get_authenticated_client().await?;
|
|
162
162
|
let service = crate::model::ModelService::new();
|
|
163
163
|
service.create(client, args).await
|
|
164
164
|
}
|
|
165
165
|
|
|
166
166
|
/// Handle push command: create -> encrypt -> upload -> hf create repo -> upload to hf -> update model
|
|
167
167
|
async fn handle_push(&mut self, args: crate::PushArgs) -> Result<()> {
|
|
168
|
-
let config = self.
|
|
169
|
-
let auth_service = crate::auth::AuthService::new(config.clone());
|
|
170
|
-
let client = auth_service.get_authenticated_client().await?;
|
|
168
|
+
let (config, client) = self.get_authenticated_client().await?;
|
|
171
169
|
let mut service = crate::push::PushService::new(config);
|
|
172
170
|
service.push(client, args).await
|
|
173
171
|
}
|
|
@@ -203,6 +203,18 @@ impl Config {
|
|
|
203
203
|
crate::security::verify_certificate_pinning(&self.endpoint, &self.server_public_key).await
|
|
204
204
|
}
|
|
205
205
|
|
|
206
|
+
/// Get the API base endpoint (ensuring /api suffix and proper scheme)
|
|
207
|
+
pub fn get_api_endpoint(&self) -> String {
|
|
208
|
+
let base = self.endpoint_url("");
|
|
209
|
+
let base = base.trim_end_matches('/');
|
|
210
|
+
|
|
211
|
+
if base.ends_with("/api") {
|
|
212
|
+
base.to_string()
|
|
213
|
+
} else {
|
|
214
|
+
format!("{}/api", base)
|
|
215
|
+
}
|
|
216
|
+
}
|
|
217
|
+
|
|
206
218
|
/// Get the full URL for an endpoint
|
|
207
219
|
pub fn endpoint_url(&self, endpoint: &str) -> String {
|
|
208
220
|
let endpoint = endpoint.strip_prefix('/').unwrap_or(endpoint);
|
|
@@ -645,4 +657,29 @@ mod tests {
|
|
|
645
657
|
config.endpoint = "api.test.com".to_string();
|
|
646
658
|
assert_eq!(config.endpoint_url(""), "https://api.test.com/");
|
|
647
659
|
}
|
|
660
|
+
|
|
661
|
+
#[test]
|
|
662
|
+
fn test_get_api_endpoint() {
|
|
663
|
+
let mut config = Config::default();
|
|
664
|
+
|
|
665
|
+
// Case 1: No /api suffix
|
|
666
|
+
config.endpoint = "https://example.com".to_string();
|
|
667
|
+
assert_eq!(config.get_api_endpoint(), "https://example.com/api");
|
|
668
|
+
|
|
669
|
+
// Case 2: With /api suffix
|
|
670
|
+
config.endpoint = "https://example.com/api".to_string();
|
|
671
|
+
assert_eq!(config.get_api_endpoint(), "https://example.com/api");
|
|
672
|
+
|
|
673
|
+
// Case 3: With trailing slash
|
|
674
|
+
config.endpoint = "https://example.com/".to_string();
|
|
675
|
+
assert_eq!(config.get_api_endpoint(), "https://example.com/api");
|
|
676
|
+
|
|
677
|
+
// Case 4: With /api/ suffix
|
|
678
|
+
config.endpoint = "https://example.com/api/".to_string();
|
|
679
|
+
assert_eq!(config.get_api_endpoint(), "https://example.com/api");
|
|
680
|
+
|
|
681
|
+
// Case 5: Missing scheme (adds https)
|
|
682
|
+
config.endpoint = "example.com".to_string();
|
|
683
|
+
assert_eq!(config.get_api_endpoint(), "https://example.com/api");
|
|
684
|
+
}
|
|
648
685
|
}
|
|
@@ -241,25 +241,24 @@ pub async fn encrypt_safetensors_file(
|
|
|
241
241
|
/// Extract non-reserved metadata from safetensors file header
|
|
242
242
|
/// This parses the JSON header and filters out reserved fields (starting with "__")
|
|
243
243
|
fn extract_metadata_from_header(file_content: &[u8]) -> Result<Option<HashMap<String, String>>> {
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
// Read header length (first 8 bytes)
|
|
247
|
-
if file_content.len() < N_LEN {
|
|
244
|
+
// Read header length
|
|
245
|
+
if file_content.len() < CryptoUtils::HEADER_LENGTH_SIZE {
|
|
248
246
|
return Ok(None);
|
|
249
247
|
}
|
|
250
248
|
|
|
251
249
|
let header_len = u64::from_le_bytes(
|
|
252
|
-
file_content[..
|
|
250
|
+
file_content[..CryptoUtils::HEADER_LENGTH_SIZE]
|
|
253
251
|
.try_into()
|
|
254
252
|
.map_err(|_| KoavaError::io("Header parsing", "Invalid header length"))?,
|
|
255
253
|
) as usize;
|
|
256
254
|
|
|
257
|
-
if file_content.len() <
|
|
255
|
+
if file_content.len() < CryptoUtils::HEADER_LENGTH_SIZE + header_len {
|
|
258
256
|
return Ok(None);
|
|
259
257
|
}
|
|
260
258
|
|
|
261
259
|
// Parse JSON header
|
|
262
|
-
let header_bytes = &file_content
|
|
260
|
+
let header_bytes = &file_content
|
|
261
|
+
[CryptoUtils::HEADER_LENGTH_SIZE..CryptoUtils::HEADER_LENGTH_SIZE + header_len];
|
|
263
262
|
let header_str = std::str::from_utf8(header_bytes)
|
|
264
263
|
.map_err(|e| KoavaError::io("Header parsing", format!("Invalid UTF-8: {}", e)))?;
|
|
265
264
|
|
|
@@ -267,7 +266,7 @@ fn extract_metadata_from_header(file_content: &[u8]) -> Result<Option<HashMap<St
|
|
|
267
266
|
.map_err(|e| KoavaError::serialization(format!("Invalid JSON in header: {}", e)))?;
|
|
268
267
|
|
|
269
268
|
// Extract __metadata__ field if present
|
|
270
|
-
if let Some(metadata_obj) = header_json.get(
|
|
269
|
+
if let Some(metadata_obj) = header_json.get(CryptoUtils::METADATA_KEY) {
|
|
271
270
|
if let Some(metadata_map) = metadata_obj.as_object() {
|
|
272
271
|
// Filter out reserved fields (starting with "__")
|
|
273
272
|
let mut filtered_metadata = HashMap::new();
|
|
@@ -215,7 +215,7 @@ impl<C: ApiClient + ?Sized> UploadService<C> {
|
|
|
215
215
|
continue;
|
|
216
216
|
}
|
|
217
217
|
|
|
218
|
-
match CryptoUtils::extract_safetensors_header(&file.path) {
|
|
218
|
+
match CryptoUtils::extract_safetensors_header(&file.path).await {
|
|
219
219
|
Ok(header_data) => {
|
|
220
220
|
let file_info = FileInfo {
|
|
221
221
|
id: None,
|
|
@@ -3,8 +3,6 @@
|
|
|
3
3
|
use base64::{engine::general_purpose, Engine};
|
|
4
4
|
use serde::{Deserialize, Serialize};
|
|
5
5
|
use sha2::{Digest, Sha256};
|
|
6
|
-
use std::fs::File;
|
|
7
|
-
use std::io::{BufReader, Read};
|
|
8
6
|
use std::path::Path;
|
|
9
7
|
|
|
10
8
|
use crate::error::{KoavaError, Result};
|
|
@@ -46,6 +44,11 @@ pub struct FileInfo {
|
|
|
46
44
|
pub struct CryptoUtils;
|
|
47
45
|
|
|
48
46
|
impl CryptoUtils {
|
|
47
|
+
pub const HEADER_LENGTH_SIZE: usize = 8;
|
|
48
|
+
pub const MAX_HEADER_SIZE: usize = 1024 * 1024;
|
|
49
|
+
pub const METADATA_KEY: &str = "__metadata__";
|
|
50
|
+
pub const ENCRYPTION_KEY: &str = "__encryption__";
|
|
51
|
+
|
|
49
52
|
/// Calculate SHA256 hash of a string (used for header hashing)
|
|
50
53
|
pub fn calculate_sha256_hash(data: &str) -> String {
|
|
51
54
|
let mut hasher = Sha256::new();
|
|
@@ -60,34 +63,48 @@ impl CryptoUtils {
|
|
|
60
63
|
hex::encode(hasher.finalize())
|
|
61
64
|
}
|
|
62
65
|
|
|
63
|
-
///
|
|
64
|
-
/// This
|
|
65
|
-
pub fn
|
|
66
|
+
/// Read the safetensors header raw bytes (excluding length prefix).
|
|
67
|
+
/// This function handles the 8-byte length reading and size validation.
|
|
68
|
+
pub async fn read_safetensors_header_raw<P: AsRef<Path>>(file_path: P) -> Result<Vec<u8>> {
|
|
66
69
|
let file_path = file_path.as_ref();
|
|
67
|
-
let mut file =
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
);
|
|
70
|
+
let mut file = tokio::fs::File::open(file_path)
|
|
71
|
+
.await
|
|
72
|
+
.map_err(|e| KoavaError::io("File open", format!("Failed to open file: {}", e)))?;
|
|
71
73
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
file.read_exact(&mut header_len_bytes)
|
|
75
|
-
|
|
74
|
+
let mut header_len_bytes = [0u8; Self::HEADER_LENGTH_SIZE];
|
|
75
|
+
use tokio::io::AsyncReadExt;
|
|
76
|
+
file.read_exact(&mut header_len_bytes).await.map_err(|e| {
|
|
77
|
+
KoavaError::io(
|
|
78
|
+
"Header read",
|
|
79
|
+
format!("Failed to read header length: {}", e),
|
|
80
|
+
)
|
|
81
|
+
})?;
|
|
76
82
|
|
|
77
83
|
let header_len = u64::from_le_bytes(header_len_bytes) as usize;
|
|
78
84
|
|
|
79
|
-
if header_len >
|
|
80
|
-
|
|
81
|
-
|
|
85
|
+
if header_len > Self::MAX_HEADER_SIZE {
|
|
86
|
+
return Err(KoavaError::validation(
|
|
87
|
+
"Header too large (exceeds 1MB limit)",
|
|
88
|
+
));
|
|
82
89
|
}
|
|
83
90
|
|
|
84
|
-
// Read header JSON
|
|
85
91
|
let mut header_json_bytes = vec![0u8; header_len];
|
|
86
|
-
file.read_exact(&mut header_json_bytes)
|
|
87
|
-
|
|
92
|
+
file.read_exact(&mut header_json_bytes).await.map_err(|e| {
|
|
93
|
+
KoavaError::io("Header read", format!("Failed to read header JSON: {}", e))
|
|
94
|
+
})?;
|
|
95
|
+
|
|
96
|
+
Ok(header_json_bytes)
|
|
97
|
+
}
|
|
98
|
+
|
|
99
|
+
/// Extract header data from a Safetensors file
|
|
100
|
+
/// This reads the first 8 bytes (header length) + header JSON and encodes as base64
|
|
101
|
+
pub async fn extract_safetensors_header<P: AsRef<Path>>(file_path: P) -> Result<String> {
|
|
102
|
+
let header_json_bytes = Self::read_safetensors_header_raw(file_path).await?;
|
|
103
|
+
let header_len = header_json_bytes.len();
|
|
104
|
+
let header_len_bytes = (header_len as u64).to_le_bytes();
|
|
88
105
|
|
|
89
106
|
// Combine header length + header JSON
|
|
90
|
-
let mut header_data = Vec::with_capacity(
|
|
107
|
+
let mut header_data = Vec::with_capacity(Self::HEADER_LENGTH_SIZE + header_len);
|
|
91
108
|
header_data.extend_from_slice(&header_len_bytes);
|
|
92
109
|
header_data.extend_from_slice(&header_json_bytes);
|
|
93
110
|
|
|
@@ -126,37 +143,7 @@ impl CryptoUtils {
|
|
|
126
143
|
/// Detect if a safetensors file is encrypted by checking its header metadata
|
|
127
144
|
/// This function only reads the file header portion, not the entire file
|
|
128
145
|
pub async fn detect_safetensors_encryption<P: AsRef<Path>>(file_path: P) -> Result<bool> {
|
|
129
|
-
let
|
|
130
|
-
|
|
131
|
-
// Open file and read only the header portion
|
|
132
|
-
let mut file = tokio::fs::File::open(file_path)
|
|
133
|
-
.await
|
|
134
|
-
.map_err(|e| KoavaError::io("File open", format!("Failed to open file: {}", e)))?;
|
|
135
|
-
|
|
136
|
-
// Read header length (first 8 bytes, little endian)
|
|
137
|
-
let mut header_len_bytes = [0u8; 8];
|
|
138
|
-
use tokio::io::AsyncReadExt;
|
|
139
|
-
file.read_exact(&mut header_len_bytes).await.map_err(|e| {
|
|
140
|
-
KoavaError::io(
|
|
141
|
-
"Header read",
|
|
142
|
-
format!("Failed to read header length: {}", e),
|
|
143
|
-
)
|
|
144
|
-
})?;
|
|
145
|
-
|
|
146
|
-
let header_len = u64::from_le_bytes(header_len_bytes) as usize;
|
|
147
|
-
|
|
148
|
-
if header_len > 1024 * 1024 {
|
|
149
|
-
// 1MB limit for safety
|
|
150
|
-
return Err(KoavaError::validation(
|
|
151
|
-
"Header too large (exceeds 1MB limit)",
|
|
152
|
-
));
|
|
153
|
-
}
|
|
154
|
-
|
|
155
|
-
// Read header JSON
|
|
156
|
-
let mut header_json_bytes = vec![0u8; header_len];
|
|
157
|
-
file.read_exact(&mut header_json_bytes).await.map_err(|e| {
|
|
158
|
-
KoavaError::io("Header read", format!("Failed to read header JSON: {}", e))
|
|
159
|
-
})?;
|
|
146
|
+
let header_json_bytes = Self::read_safetensors_header_raw(file_path).await?;
|
|
160
147
|
|
|
161
148
|
// Parse the header JSON to check for encryption metadata
|
|
162
149
|
let header_json: serde_json::Value =
|
|
@@ -165,9 +152,9 @@ impl CryptoUtils {
|
|
|
165
152
|
})?;
|
|
166
153
|
|
|
167
154
|
// Check if the file contains encryption metadata
|
|
168
|
-
if let Some(metadata) = header_json.get(
|
|
155
|
+
if let Some(metadata) = header_json.get(Self::METADATA_KEY) {
|
|
169
156
|
if let Some(metadata_obj) = metadata.as_object() {
|
|
170
|
-
if metadata_obj.contains_key(
|
|
157
|
+
if metadata_obj.contains_key(Self::ENCRYPTION_KEY) {
|
|
171
158
|
return Ok(true);
|
|
172
159
|
}
|
|
173
160
|
}
|
|
@@ -234,12 +221,14 @@ mod tests {
|
|
|
234
221
|
assert!(result.is_err());
|
|
235
222
|
}
|
|
236
223
|
|
|
237
|
-
#[test]
|
|
238
|
-
fn test_extract_safetensors_header() {
|
|
224
|
+
#[tokio::test]
|
|
225
|
+
async fn test_extract_safetensors_header() {
|
|
239
226
|
let temp_dir = create_temp_dir();
|
|
240
227
|
let file_path = create_mock_safetensors_file(&temp_dir, "model.safetensors", false);
|
|
241
228
|
|
|
242
|
-
let header_b64 = CryptoUtils::extract_safetensors_header(&file_path)
|
|
229
|
+
let header_b64 = CryptoUtils::extract_safetensors_header(&file_path)
|
|
230
|
+
.await
|
|
231
|
+
.unwrap();
|
|
243
232
|
|
|
244
233
|
// Verify it's valid base64
|
|
245
234
|
let decoded = CryptoUtils::decode_base64(&header_b64).unwrap();
|
|
@@ -248,9 +237,10 @@ mod tests {
|
|
|
248
237
|
assert_eq!(decoded.len() >= 8, true);
|
|
249
238
|
}
|
|
250
239
|
|
|
251
|
-
#[test]
|
|
252
|
-
fn test_extract_safetensors_header_nonexistent_file() {
|
|
253
|
-
let result =
|
|
240
|
+
#[tokio::test]
|
|
241
|
+
async fn test_extract_safetensors_header_nonexistent_file() {
|
|
242
|
+
let result =
|
|
243
|
+
CryptoUtils::extract_safetensors_header("/nonexistent/file.safetensors").await;
|
|
254
244
|
assert!(result.is_err());
|
|
255
245
|
}
|
|
256
246
|
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|