koava 0.1.0__tar.gz → 0.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {koava-0.1.0 → koava-0.1.1}/.github/workflows/koava-release.yml +51 -49
  2. {koava-0.1.0 → koava-0.1.1}/Cargo.lock +1 -1
  3. {koava-0.1.0 → koava-0.1.1}/Cargo.toml +1 -1
  4. {koava-0.1.0 → koava-0.1.1}/PKG-INFO +1 -1
  5. {koava-0.1.0 → koava-0.1.1}/pyproject.toml +1 -1
  6. {koava-0.1.0 → koava-0.1.1}/src/auth.rs +2 -14
  7. {koava-0.1.0 → koava-0.1.1}/src/cli.rs +16 -18
  8. {koava-0.1.0 → koava-0.1.1}/src/config.rs +37 -0
  9. {koava-0.1.0 → koava-0.1.1}/src/model.rs +7 -8
  10. {koava-0.1.0 → koava-0.1.1}/src/upload.rs +1 -1
  11. {koava-0.1.0 → koava-0.1.1}/src/utils.rs +49 -59
  12. {koava-0.1.0 → koava-0.1.1}/.github/workflows/ci.yml +0 -0
  13. {koava-0.1.0 → koava-0.1.1}/.gitignore +0 -0
  14. {koava-0.1.0 → koava-0.1.1}/FEATURES.md +0 -0
  15. {koava-0.1.0 → koava-0.1.1}/LICENSE +0 -0
  16. {koava-0.1.0 → koava-0.1.1}/MANIFEST.in +0 -0
  17. {koava-0.1.0 → koava-0.1.1}/README.md +0 -0
  18. {koava-0.1.0 → koava-0.1.1}/custom_config.json +0 -0
  19. {koava-0.1.0 → koava-0.1.1}/proptest-regressions/config.txt +0 -0
  20. {koava-0.1.0 → koava-0.1.1}/proptest-regressions/store.txt +0 -0
  21. {koava-0.1.0 → koava-0.1.1}/protocol/Cargo.toml +0 -0
  22. {koava-0.1.0 → koava-0.1.1}/protocol/src/api/auth.rs +0 -0
  23. {koava-0.1.0 → koava-0.1.1}/protocol/src/api/mod.rs +0 -0
  24. {koava-0.1.0 → koava-0.1.1}/protocol/src/api/model.rs +0 -0
  25. {koava-0.1.0 → koava-0.1.1}/protocol/src/common/auth.rs +0 -0
  26. {koava-0.1.0 → koava-0.1.1}/protocol/src/common/key.rs +0 -0
  27. {koava-0.1.0 → koava-0.1.1}/protocol/src/common/mod.rs +0 -0
  28. {koava-0.1.0 → koava-0.1.1}/protocol/src/common/model.rs +0 -0
  29. {koava-0.1.0 → koava-0.1.1}/protocol/src/lib.rs +0 -0
  30. {koava-0.1.0 → koava-0.1.1}/src/client.rs +0 -0
  31. {koava-0.1.0 → koava-0.1.1}/src/encrypt.rs +0 -0
  32. {koava-0.1.0 → koava-0.1.1}/src/error.rs +0 -0
  33. {koava-0.1.0 → koava-0.1.1}/src/file.rs +0 -0
  34. {koava-0.1.0 → koava-0.1.1}/src/huggingface.rs +0 -0
  35. {koava-0.1.0 → koava-0.1.1}/src/key.rs +0 -0
  36. {koava-0.1.0 → koava-0.1.1}/src/main.rs +0 -0
  37. {koava-0.1.0 → koava-0.1.1}/src/policy.rs +0 -0
  38. {koava-0.1.0 → koava-0.1.1}/src/push.rs +0 -0
  39. {koava-0.1.0 → koava-0.1.1}/src/security.rs +0 -0
  40. {koava-0.1.0 → koava-0.1.1}/src/store.rs +0 -0
  41. {koava-0.1.0 → koava-0.1.1}/src/templates.rs +0 -0
  42. {koava-0.1.0 → koava-0.1.1}/src/tests/mocks.rs +0 -0
  43. {koava-0.1.0 → koava-0.1.1}/src/tests/mod.rs +0 -0
  44. {koava-0.1.0 → koava-0.1.1}/src/tests/utils.rs +0 -0
  45. {koava-0.1.0 → koava-0.1.1}/src/ui.rs +0 -0
  46. {koava-0.1.0 → koava-0.1.1}/templates/KOALAVAULT_PROPRIETARY_LICENSE.txt +0 -0
  47. {koava-0.1.0 → koava-0.1.1}/templates/README_ENCRYPTED_MODEL.md +0 -0
@@ -177,12 +177,58 @@ jobs:
177
177
  echo "Is Release: $IS_RELEASE"
178
178
  echo "=============================="
179
179
 
180
- - name: Update versions in files
180
+ - name: Check versions in files
181
181
  run: |
182
- VERSION="${{ steps.config.outputs.version }}"
183
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
184
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
185
- echo "Updated pyproject.toml and Cargo.toml version to $VERSION"
182
+ EVENT_NAME="${{ github.event_name }}"
183
+ REF_TYPE="${{ github.ref_type }}"
184
+ IS_TAG_PUSH=false
185
+
186
+ if [ "$EVENT_NAME" = "push" ] && [ "$REF_TYPE" = "tag" ]; then
187
+ IS_TAG_PUSH=true
188
+ TAG_NAME="${{ github.ref_name }}"
189
+ TAG_VERSION="${TAG_NAME#v}"
190
+ fi
191
+
192
+ # Extract version from Cargo.toml
193
+ CARGO_VERSION=$(grep -E '^version = ' Cargo.toml | sed -E 's/^version = "([^"]+)"/\1/')
194
+
195
+ # Extract version from pyproject.toml
196
+ PYPROJECT_VERSION=$(grep -E '^version = ' pyproject.toml | sed -E 's/^version = "([^"]+)"/\1/')
197
+
198
+ echo "=== Version Check ==="
199
+ echo "Cargo.toml version: $CARGO_VERSION"
200
+ echo "pyproject.toml version: $PYPROJECT_VERSION"
201
+
202
+ if [ "$IS_TAG_PUSH" = "true" ]; then
203
+ echo "Tag version: $TAG_VERSION"
204
+ echo ""
205
+ echo "Checking three-way version match (Cargo.toml, pyproject.toml, and tag)..."
206
+
207
+ # Check if all three versions match
208
+ if [ "$CARGO_VERSION" != "$PYPROJECT_VERSION" ]; then
209
+ echo "::error::Version mismatch: Cargo.toml has '$CARGO_VERSION' but pyproject.toml has '$PYPROJECT_VERSION'"
210
+ exit 1
211
+ fi
212
+
213
+ if [ "$CARGO_VERSION" != "$TAG_VERSION" ]; then
214
+ echo "::error::Version mismatch: Cargo.toml/pyproject.toml have '$CARGO_VERSION' but tag has '$TAG_VERSION'"
215
+ exit 1
216
+ fi
217
+
218
+ echo "✓ All three versions match: $CARGO_VERSION"
219
+ else
220
+ echo ""
221
+ echo "Checking two-way version match (Cargo.toml and pyproject.toml)..."
222
+
223
+ # Check if Cargo.toml and pyproject.toml versions match
224
+ if [ "$CARGO_VERSION" != "$PYPROJECT_VERSION" ]; then
225
+ echo "::error::Version mismatch: Cargo.toml has '$CARGO_VERSION' but pyproject.toml has '$PYPROJECT_VERSION'"
226
+ exit 1
227
+ fi
228
+
229
+ echo "✓ Both files have the same version: $CARGO_VERSION"
230
+ fi
231
+ echo "===================="
186
232
 
187
233
  build-linux-x86_64:
188
234
  needs: prepare
@@ -211,12 +257,6 @@ jobs:
211
257
  - name: Install maturin
212
258
  run: pip install maturin
213
259
 
214
- - name: Update versions
215
- run: |
216
- VERSION="${{ needs.prepare.outputs.version }}"
217
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
218
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
219
-
220
260
  - name: Build for Linux x86_64
221
261
  run: maturin build --release --target x86_64-unknown-linux-gnu --features cert-pinning
222
262
 
@@ -270,12 +310,6 @@ jobs:
270
310
  - name: Install maturin
271
311
  run: pip install maturin
272
312
 
273
- - name: Update versions
274
- run: |
275
- VERSION="${{ needs.prepare.outputs.version }}"
276
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
277
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
278
-
279
313
  - name: Build for Linux ARM64
280
314
  run: maturin build --release --target aarch64-unknown-linux-gnu --features cert-pinning
281
315
 
@@ -307,12 +341,6 @@ jobs:
307
341
  - name: Install maturin
308
342
  run: pip install maturin
309
343
 
310
- - name: Update versions
311
- run: |
312
- VERSION="${{ needs.prepare.outputs.version }}"
313
- sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
314
- sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
315
-
316
344
  - name: Build for macOS x86_64
317
345
  run: maturin build --release --target x86_64-apple-darwin --features cert-pinning
318
346
 
@@ -343,12 +371,6 @@ jobs:
343
371
 
344
372
  - name: Install maturin
345
373
  run: pip install maturin
346
-
347
- - name: Update versions
348
- run: |
349
- VERSION="${{ needs.prepare.outputs.version }}"
350
- sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
351
- sed -i '' "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
352
374
 
353
375
  - name: Build for macOS ARM64
354
376
  run: maturin build --release --target aarch64-apple-darwin --features cert-pinning
@@ -380,13 +402,6 @@ jobs:
380
402
 
381
403
  - name: Install maturin
382
404
  run: pip install maturin
383
-
384
- - name: Update versions
385
- shell: pwsh
386
- run: |
387
- $version = "${{ needs.prepare.outputs.version }}"
388
- (Get-Content pyproject.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content pyproject.toml
389
- (Get-Content Cargo.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content Cargo.toml
390
405
 
391
406
  - name: Build for Windows x64
392
407
  run: maturin build --release --target x86_64-pc-windows-msvc --features cert-pinning
@@ -418,13 +433,6 @@ jobs:
418
433
 
419
434
  - name: Install maturin
420
435
  run: pip install maturin
421
-
422
- - name: Update versions
423
- shell: pwsh
424
- run: |
425
- $version = "${{ needs.prepare.outputs.version }}"
426
- (Get-Content pyproject.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content pyproject.toml
427
- (Get-Content Cargo.toml) -replace '^version = ".*"', "version = `"$version`"" | Set-Content Cargo.toml
428
436
 
429
437
  - name: Build for Windows ARM64
430
438
  run: maturin build --release --target aarch64-pc-windows-msvc --features cert-pinning
@@ -458,12 +466,6 @@ jobs:
458
466
  - name: Install maturin
459
467
  run: pip install maturin
460
468
 
461
- - name: Update versions
462
- run: |
463
- VERSION="${{ needs.prepare.outputs.version }}"
464
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" pyproject.toml
465
- sed -i "s/^version = \".*\"/version = \"$VERSION\"/" Cargo.toml
466
-
467
469
  - name: Build source distribution (sdist)
468
470
  run: maturin sdist
469
471
 
@@ -1567,7 +1567,7 @@ dependencies = [
1567
1567
 
1568
1568
  [[package]]
1569
1569
  name = "koava"
1570
- version = "0.1.0"
1570
+ version = "0.1.1"
1571
1571
  dependencies = [
1572
1572
  "anyhow",
1573
1573
  "base64 0.22.1",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "koava"
3
- version = "0.1.0"
3
+ version = "0.1.1"
4
4
  edition = "2021"
5
5
  description = "KoalaVault model converter tool for producers"
6
6
  authors = ["KoalaVault Team"]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: koava
3
- Version: 0.1.0
3
+ Version: 0.1.1
4
4
  Classifier: License :: OSI Approved :: Apache Software License
5
5
  Classifier: Programming Language :: Python
6
6
  Classifier: Programming Language :: Python :: 3
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "koava"
7
- version = "0.1.0"
7
+ version = "0.1.1"
8
8
  description = "KoalaVault model converter tool for producers"
9
9
  readme = "README.md"
10
10
  license = {file = "LICENSE"}
@@ -210,13 +210,7 @@ impl AuthService {
210
210
 
211
211
  if let Some(username) = &username_opt {
212
212
  // Normalize endpoint to include /api
213
- let normalized_endpoint = if self.config.endpoint.ends_with("/api") {
214
- self.config.endpoint.clone()
215
- } else if self.config.endpoint.ends_with("/") {
216
- format!("{}api", self.config.endpoint)
217
- } else {
218
- format!("{}/api", self.config.endpoint)
219
- };
213
+ let normalized_endpoint = self.config.get_api_endpoint();
220
214
 
221
215
  // Build HTTP client (no proxy for localhost)
222
216
  let mut builder = reqwest::Client::builder()
@@ -281,13 +275,7 @@ impl AuthService {
281
275
  .map_err(|e| KoavaError::config(e.to_string()))?;
282
276
 
283
277
  // Normalize endpoint to include /api
284
- let normalized_endpoint = if self.config.endpoint.ends_with("/api") {
285
- self.config.endpoint.clone()
286
- } else if self.config.endpoint.ends_with("/") {
287
- format!("{}api", self.config.endpoint)
288
- } else {
289
- format!("{}/api", self.config.endpoint)
290
- };
278
+ let normalized_endpoint = self.config.get_api_endpoint();
291
279
 
292
280
  // Disable proxy for localhost to avoid corporate/system proxy causing 502
293
281
  let mut builder = reqwest::Client::builder()
@@ -40,6 +40,16 @@ impl CliHandler {
40
40
  }
41
41
  }
42
42
 
43
+ /// Get the authenticated client and config
44
+ async fn get_authenticated_client(
45
+ &self,
46
+ ) -> Result<(Config, std::sync::Arc<crate::HttpClient>)> {
47
+ let config = self.load_config().await?;
48
+ let auth_service = crate::auth::AuthService::new(config.clone());
49
+ let client = auth_service.get_authenticated_client().await?;
50
+ Ok((config, client))
51
+ }
52
+
43
53
  /// Execute a CLI command
44
54
  pub async fn execute(&mut self, command: Commands) -> Result<()> {
45
55
  match command {
@@ -59,9 +69,7 @@ impl CliHandler {
59
69
 
60
70
  /// Handle encrypt command
61
71
  async fn handle_encrypt(&mut self, args: EncryptArgs) -> Result<()> {
62
- let config = self.load_config().await?;
63
- let auth_service = crate::auth::AuthService::new(config.clone());
64
- let client = auth_service.get_authenticated_client().await?;
72
+ let (config, client) = self.get_authenticated_client().await?;
65
73
  let encrypt_service = EncryptService::new(config);
66
74
  encrypt_service.encrypt(&*client, args).await
67
75
  }
@@ -75,18 +83,14 @@ impl CliHandler {
75
83
 
76
84
  /// Handle upload command - upload encrypted model to server
77
85
  async fn handle_upload(&mut self, args: UploadArgs) -> Result<()> {
78
- let config = self.load_config().await?;
79
- let auth_service = crate::auth::AuthService::new(config);
80
- let client = auth_service.get_authenticated_client().await?;
86
+ let (_config, client) = self.get_authenticated_client().await?;
81
87
  let service = crate::model::ModelService::new();
82
88
  service.upload(client, args).await
83
89
  }
84
90
 
85
91
  /// Handle remove command
86
92
  async fn handle_remove(&mut self, args: RemoveArgs) -> Result<()> {
87
- let config = self.load_config().await?;
88
- let auth_service = crate::auth::AuthService::new(config);
89
- let client = auth_service.get_authenticated_client().await?;
93
+ let (_config, client) = self.get_authenticated_client().await?;
90
94
  let service = crate::model::ModelService::new();
91
95
  service.remove(client, args).await
92
96
  }
@@ -147,27 +151,21 @@ impl CliHandler {
147
151
 
148
152
  /// Handle list command - list files for a model on server
149
153
  async fn handle_list(&mut self, args: ListArgs) -> Result<()> {
150
- let config = self.load_config().await?;
151
- let auth_service = crate::auth::AuthService::new(config);
152
- let client = auth_service.get_authenticated_client().await?;
154
+ let (_config, client) = self.get_authenticated_client().await?;
153
155
  let service = crate::model::ModelService::new();
154
156
  service.list(client, args).await
155
157
  }
156
158
 
157
159
  /// Handle create command
158
160
  async fn handle_create(&mut self, args: CreateArgs) -> Result<()> {
159
- let config = self.load_config().await?;
160
- let auth_service = crate::auth::AuthService::new(config);
161
- let client = auth_service.get_authenticated_client().await?;
161
+ let (_config, client) = self.get_authenticated_client().await?;
162
162
  let service = crate::model::ModelService::new();
163
163
  service.create(client, args).await
164
164
  }
165
165
 
166
166
  /// Handle push command: create -> encrypt -> upload -> hf create repo -> upload to hf -> update model
167
167
  async fn handle_push(&mut self, args: crate::PushArgs) -> Result<()> {
168
- let config = self.load_config().await?;
169
- let auth_service = crate::auth::AuthService::new(config.clone());
170
- let client = auth_service.get_authenticated_client().await?;
168
+ let (config, client) = self.get_authenticated_client().await?;
171
169
  let mut service = crate::push::PushService::new(config);
172
170
  service.push(client, args).await
173
171
  }
@@ -203,6 +203,18 @@ impl Config {
203
203
  crate::security::verify_certificate_pinning(&self.endpoint, &self.server_public_key).await
204
204
  }
205
205
 
206
+ /// Get the API base endpoint (ensuring /api suffix and proper scheme)
207
+ pub fn get_api_endpoint(&self) -> String {
208
+ let base = self.endpoint_url("");
209
+ let base = base.trim_end_matches('/');
210
+
211
+ if base.ends_with("/api") {
212
+ base.to_string()
213
+ } else {
214
+ format!("{}/api", base)
215
+ }
216
+ }
217
+
206
218
  /// Get the full URL for an endpoint
207
219
  pub fn endpoint_url(&self, endpoint: &str) -> String {
208
220
  let endpoint = endpoint.strip_prefix('/').unwrap_or(endpoint);
@@ -645,4 +657,29 @@ mod tests {
645
657
  config.endpoint = "api.test.com".to_string();
646
658
  assert_eq!(config.endpoint_url(""), "https://api.test.com/");
647
659
  }
660
+
661
+ #[test]
662
+ fn test_get_api_endpoint() {
663
+ let mut config = Config::default();
664
+
665
+ // Case 1: No /api suffix
666
+ config.endpoint = "https://example.com".to_string();
667
+ assert_eq!(config.get_api_endpoint(), "https://example.com/api");
668
+
669
+ // Case 2: With /api suffix
670
+ config.endpoint = "https://example.com/api".to_string();
671
+ assert_eq!(config.get_api_endpoint(), "https://example.com/api");
672
+
673
+ // Case 3: With trailing slash
674
+ config.endpoint = "https://example.com/".to_string();
675
+ assert_eq!(config.get_api_endpoint(), "https://example.com/api");
676
+
677
+ // Case 4: With /api/ suffix
678
+ config.endpoint = "https://example.com/api/".to_string();
679
+ assert_eq!(config.get_api_endpoint(), "https://example.com/api");
680
+
681
+ // Case 5: Missing scheme (adds https)
682
+ config.endpoint = "example.com".to_string();
683
+ assert_eq!(config.get_api_endpoint(), "https://example.com/api");
684
+ }
648
685
  }
@@ -241,25 +241,24 @@ pub async fn encrypt_safetensors_file(
241
241
  /// Extract non-reserved metadata from safetensors file header
242
242
  /// This parses the JSON header and filters out reserved fields (starting with "__")
243
243
  fn extract_metadata_from_header(file_content: &[u8]) -> Result<Option<HashMap<String, String>>> {
244
- const N_LEN: usize = 8;
245
-
246
- // Read header length (first 8 bytes)
247
- if file_content.len() < N_LEN {
244
+ // Read header length
245
+ if file_content.len() < CryptoUtils::HEADER_LENGTH_SIZE {
248
246
  return Ok(None);
249
247
  }
250
248
 
251
249
  let header_len = u64::from_le_bytes(
252
- file_content[..N_LEN]
250
+ file_content[..CryptoUtils::HEADER_LENGTH_SIZE]
253
251
  .try_into()
254
252
  .map_err(|_| KoavaError::io("Header parsing", "Invalid header length"))?,
255
253
  ) as usize;
256
254
 
257
- if file_content.len() < N_LEN + header_len {
255
+ if file_content.len() < CryptoUtils::HEADER_LENGTH_SIZE + header_len {
258
256
  return Ok(None);
259
257
  }
260
258
 
261
259
  // Parse JSON header
262
- let header_bytes = &file_content[N_LEN..N_LEN + header_len];
260
+ let header_bytes = &file_content
261
+ [CryptoUtils::HEADER_LENGTH_SIZE..CryptoUtils::HEADER_LENGTH_SIZE + header_len];
263
262
  let header_str = std::str::from_utf8(header_bytes)
264
263
  .map_err(|e| KoavaError::io("Header parsing", format!("Invalid UTF-8: {}", e)))?;
265
264
 
@@ -267,7 +266,7 @@ fn extract_metadata_from_header(file_content: &[u8]) -> Result<Option<HashMap<St
267
266
  .map_err(|e| KoavaError::serialization(format!("Invalid JSON in header: {}", e)))?;
268
267
 
269
268
  // Extract __metadata__ field if present
270
- if let Some(metadata_obj) = header_json.get("__metadata__") {
269
+ if let Some(metadata_obj) = header_json.get(CryptoUtils::METADATA_KEY) {
271
270
  if let Some(metadata_map) = metadata_obj.as_object() {
272
271
  // Filter out reserved fields (starting with "__")
273
272
  let mut filtered_metadata = HashMap::new();
@@ -215,7 +215,7 @@ impl<C: ApiClient + ?Sized> UploadService<C> {
215
215
  continue;
216
216
  }
217
217
 
218
- match CryptoUtils::extract_safetensors_header(&file.path) {
218
+ match CryptoUtils::extract_safetensors_header(&file.path).await {
219
219
  Ok(header_data) => {
220
220
  let file_info = FileInfo {
221
221
  id: None,
@@ -3,8 +3,6 @@
3
3
  use base64::{engine::general_purpose, Engine};
4
4
  use serde::{Deserialize, Serialize};
5
5
  use sha2::{Digest, Sha256};
6
- use std::fs::File;
7
- use std::io::{BufReader, Read};
8
6
  use std::path::Path;
9
7
 
10
8
  use crate::error::{KoavaError, Result};
@@ -46,6 +44,11 @@ pub struct FileInfo {
46
44
  pub struct CryptoUtils;
47
45
 
48
46
  impl CryptoUtils {
47
+ pub const HEADER_LENGTH_SIZE: usize = 8;
48
+ pub const MAX_HEADER_SIZE: usize = 1024 * 1024;
49
+ pub const METADATA_KEY: &str = "__metadata__";
50
+ pub const ENCRYPTION_KEY: &str = "__encryption__";
51
+
49
52
  /// Calculate SHA256 hash of a string (used for header hashing)
50
53
  pub fn calculate_sha256_hash(data: &str) -> String {
51
54
  let mut hasher = Sha256::new();
@@ -60,34 +63,48 @@ impl CryptoUtils {
60
63
  hex::encode(hasher.finalize())
61
64
  }
62
65
 
63
- /// Extract header data from a Safetensors file
64
- /// This reads the first 8 bytes (header length) + header JSON and encodes as base64
65
- pub fn extract_safetensors_header<P: AsRef<Path>>(file_path: P) -> Result<String> {
66
+ /// Read the safetensors header raw bytes (excluding length prefix).
67
+ /// This function handles the 8-byte length reading and size validation.
68
+ pub async fn read_safetensors_header_raw<P: AsRef<Path>>(file_path: P) -> Result<Vec<u8>> {
66
69
  let file_path = file_path.as_ref();
67
- let mut file = BufReader::new(
68
- File::open(file_path)
69
- .map_err(|e| KoavaError::io("File open", format!("Failed to open file: {}", e)))?,
70
- );
70
+ let mut file = tokio::fs::File::open(file_path)
71
+ .await
72
+ .map_err(|e| KoavaError::io("File open", format!("Failed to open file: {}", e)))?;
71
73
 
72
- // Read header length (first 8 bytes, little endian)
73
- let mut header_len_bytes = [0u8; 8];
74
- file.read_exact(&mut header_len_bytes)
75
- .map_err(|e| KoavaError::crypto(format!("Failed to read header length: {}", e)))?;
74
+ let mut header_len_bytes = [0u8; Self::HEADER_LENGTH_SIZE];
75
+ use tokio::io::AsyncReadExt;
76
+ file.read_exact(&mut header_len_bytes).await.map_err(|e| {
77
+ KoavaError::io(
78
+ "Header read",
79
+ format!("Failed to read header length: {}", e),
80
+ )
81
+ })?;
76
82
 
77
83
  let header_len = u64::from_le_bytes(header_len_bytes) as usize;
78
84
 
79
- if header_len > 1024 * 1024 {
80
- // 1MB limit for safety
81
- return Err(KoavaError::crypto("Header too large"));
85
+ if header_len > Self::MAX_HEADER_SIZE {
86
+ return Err(KoavaError::validation(
87
+ "Header too large (exceeds 1MB limit)",
88
+ ));
82
89
  }
83
90
 
84
- // Read header JSON
85
91
  let mut header_json_bytes = vec![0u8; header_len];
86
- file.read_exact(&mut header_json_bytes)
87
- .map_err(|e| KoavaError::crypto(format!("Failed to read header JSON: {}", e)))?;
92
+ file.read_exact(&mut header_json_bytes).await.map_err(|e| {
93
+ KoavaError::io("Header read", format!("Failed to read header JSON: {}", e))
94
+ })?;
95
+
96
+ Ok(header_json_bytes)
97
+ }
98
+
99
+ /// Extract header data from a Safetensors file
100
+ /// This reads the first 8 bytes (header length) + header JSON and encodes as base64
101
+ pub async fn extract_safetensors_header<P: AsRef<Path>>(file_path: P) -> Result<String> {
102
+ let header_json_bytes = Self::read_safetensors_header_raw(file_path).await?;
103
+ let header_len = header_json_bytes.len();
104
+ let header_len_bytes = (header_len as u64).to_le_bytes();
88
105
 
89
106
  // Combine header length + header JSON
90
- let mut header_data = Vec::with_capacity(8 + header_len);
107
+ let mut header_data = Vec::with_capacity(Self::HEADER_LENGTH_SIZE + header_len);
91
108
  header_data.extend_from_slice(&header_len_bytes);
92
109
  header_data.extend_from_slice(&header_json_bytes);
93
110
 
@@ -126,37 +143,7 @@ impl CryptoUtils {
126
143
  /// Detect if a safetensors file is encrypted by checking its header metadata
127
144
  /// This function only reads the file header portion, not the entire file
128
145
  pub async fn detect_safetensors_encryption<P: AsRef<Path>>(file_path: P) -> Result<bool> {
129
- let file_path = file_path.as_ref();
130
-
131
- // Open file and read only the header portion
132
- let mut file = tokio::fs::File::open(file_path)
133
- .await
134
- .map_err(|e| KoavaError::io("File open", format!("Failed to open file: {}", e)))?;
135
-
136
- // Read header length (first 8 bytes, little endian)
137
- let mut header_len_bytes = [0u8; 8];
138
- use tokio::io::AsyncReadExt;
139
- file.read_exact(&mut header_len_bytes).await.map_err(|e| {
140
- KoavaError::io(
141
- "Header read",
142
- format!("Failed to read header length: {}", e),
143
- )
144
- })?;
145
-
146
- let header_len = u64::from_le_bytes(header_len_bytes) as usize;
147
-
148
- if header_len > 1024 * 1024 {
149
- // 1MB limit for safety
150
- return Err(KoavaError::validation(
151
- "Header too large (exceeds 1MB limit)",
152
- ));
153
- }
154
-
155
- // Read header JSON
156
- let mut header_json_bytes = vec![0u8; header_len];
157
- file.read_exact(&mut header_json_bytes).await.map_err(|e| {
158
- KoavaError::io("Header read", format!("Failed to read header JSON: {}", e))
159
- })?;
146
+ let header_json_bytes = Self::read_safetensors_header_raw(file_path).await?;
160
147
 
161
148
  // Parse the header JSON to check for encryption metadata
162
149
  let header_json: serde_json::Value =
@@ -165,9 +152,9 @@ impl CryptoUtils {
165
152
  })?;
166
153
 
167
154
  // Check if the file contains encryption metadata
168
- if let Some(metadata) = header_json.get("__metadata__") {
155
+ if let Some(metadata) = header_json.get(Self::METADATA_KEY) {
169
156
  if let Some(metadata_obj) = metadata.as_object() {
170
- if metadata_obj.contains_key("__encryption__") {
157
+ if metadata_obj.contains_key(Self::ENCRYPTION_KEY) {
171
158
  return Ok(true);
172
159
  }
173
160
  }
@@ -234,12 +221,14 @@ mod tests {
234
221
  assert!(result.is_err());
235
222
  }
236
223
 
237
- #[test]
238
- fn test_extract_safetensors_header() {
224
+ #[tokio::test]
225
+ async fn test_extract_safetensors_header() {
239
226
  let temp_dir = create_temp_dir();
240
227
  let file_path = create_mock_safetensors_file(&temp_dir, "model.safetensors", false);
241
228
 
242
- let header_b64 = CryptoUtils::extract_safetensors_header(&file_path).unwrap();
229
+ let header_b64 = CryptoUtils::extract_safetensors_header(&file_path)
230
+ .await
231
+ .unwrap();
243
232
 
244
233
  // Verify it's valid base64
245
234
  let decoded = CryptoUtils::decode_base64(&header_b64).unwrap();
@@ -248,9 +237,10 @@ mod tests {
248
237
  assert_eq!(decoded.len() >= 8, true);
249
238
  }
250
239
 
251
- #[test]
252
- fn test_extract_safetensors_header_nonexistent_file() {
253
- let result = CryptoUtils::extract_safetensors_header("/nonexistent/file.safetensors");
240
+ #[tokio::test]
241
+ async fn test_extract_safetensors_header_nonexistent_file() {
242
+ let result =
243
+ CryptoUtils::extract_safetensors_header("/nonexistent/file.safetensors").await;
254
244
  assert!(result.is_err());
255
245
  }
256
246
 
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes