mod-trace 0.3.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42) hide show
  1. {mod_trace-0.3.2 → mod_trace-0.4.0}/Cargo.lock +1 -1
  2. {mod_trace-0.3.2 → mod_trace-0.4.0}/Cargo.toml +1 -1
  3. {mod_trace-0.3.2 → mod_trace-0.4.0}/PKG-INFO +2 -2
  4. {mod_trace-0.3.2 → mod_trace-0.4.0}/README.md +1 -1
  5. mod_trace-0.4.0/examples/pytorch/README.md +35 -0
  6. mod_trace-0.4.0/examples/pytorch/generate_demo_models.py +35 -0
  7. mod_trace-0.4.0/examples/pytorch/mlp_v1.pt +0 -0
  8. mod_trace-0.4.0/examples/pytorch/mlp_v2.pt +0 -0
  9. {mod_trace-0.3.2 → mod_trace-0.4.0}/pyproject.toml +1 -1
  10. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/main.rs +441 -16
  11. mod_trace-0.4.0/src/pt.rs +398 -0
  12. {mod_trace-0.3.2 → mod_trace-0.4.0}/.github/workflows/release.yml +0 -0
  13. {mod_trace-0.3.2 → mod_trace-0.4.0}/.gitignore +0 -0
  14. {mod_trace-0.3.2 → mod_trace-0.4.0}/LICENSE +0 -0
  15. {mod_trace-0.3.2 → mod_trace-0.4.0}/benchmarks/tiny_pytorch.py +0 -0
  16. {mod_trace-0.3.2 → mod_trace-0.4.0}/docs/ARCHITECTURE.md +0 -0
  17. {mod_trace-0.3.2 → mod_trace-0.4.0}/docs/REAL_MODELS.md +0 -0
  18. {mod_trace-0.3.2 → mod_trace-0.4.0}/docs/tensor-lab.md +0 -0
  19. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/broken_shape.json +0 -0
  20. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/README.md +0 -0
  21. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/clf_v1.txt +0 -0
  22. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/clf_v2.txt +0 -0
  23. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/generate_demo_models.py +0 -0
  24. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/make_sample_catboost.py +0 -0
  25. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/mlp.json +0 -0
  26. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/README.md +0 -0
  27. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/generate_demo_models.py +0 -0
  28. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_retrain_a.onnx +0 -0
  29. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_retrain_b.onnx +0 -0
  30. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_v1.onnx +0 -0
  31. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_v2.onnx +0 -0
  32. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/tiny_attention.json +0 -0
  33. {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/tiny_attention_plan.json +0 -0
  34. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/catboost_deep_diff.py +0 -0
  35. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/catboost_explain.py +0 -0
  36. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/cbm.rs +0 -0
  37. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/demo.rs +0 -0
  38. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/explain.rs +0 -0
  39. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/lgbm.rs +0 -0
  40. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/model.rs +0 -0
  41. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/onnx.rs +0 -0
  42. {mod_trace-0.3.2 → mod_trace-0.4.0}/src/tensor.rs +0 -0
@@ -16,7 +16,7 @@ checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8"
16
16
 
17
17
  [[package]]
18
18
  name = "mod-trace"
19
- version = "0.3.2"
19
+ version = "0.4.0"
20
20
  dependencies = [
21
21
  "serde",
22
22
  "serde_json",
@@ -1,6 +1,6 @@
1
1
  [package]
2
2
  name = "mod-trace"
3
- version = "0.3.2"
3
+ version = "0.4.0"
4
4
  edition = "2024"
5
5
  description = "Rust CLI for inspecting ML model artifacts without loading the framework"
6
6
  license = "MIT"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: mod-trace
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Classifier: Programming Language :: Rust
5
5
  Classifier: Programming Language :: Python :: 3
6
6
  Classifier: Programming Language :: Python :: 3 :: Only
@@ -25,7 +25,7 @@ mod-trace is a small Rust CLI for answering a practical question:
25
25
  What is inside this model file?
26
26
  ```
27
27
 
28
- It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models, and ONNX `.onnx` graphs, then report structure, size, parameters, operator mix, rough inference cost, and changes between versions. CatBoost, LightGBM, and ONNX are all read natively — no Python, framework, or runtime needed (CatBoost `--deep` is the one optional exception).
28
+ It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models, ONNX `.onnx` graphs, and PyTorch `.pt`/`.pth` checkpoints, then report structure, size, parameters, operator mix, rough inference cost, and changes between versions. All formats are read natively — no Python, framework, or runtime needed (CatBoost `--deep` is the one optional exception). The PyTorch reader is static: it sizes/names tensors and fingerprints weights without decoding exact shapes.
29
29
 
30
30
  The most useful command is `explain-diff`, which says in plain English what changed between two model versions:
31
31
 
@@ -8,7 +8,7 @@ mod-trace is a small Rust CLI for answering a practical question:
8
8
  What is inside this model file?
9
9
  ```
10
10
 
11
- It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models, and ONNX `.onnx` graphs, then report structure, size, parameters, operator mix, rough inference cost, and changes between versions. CatBoost, LightGBM, and ONNX are all read natively — no Python, framework, or runtime needed (CatBoost `--deep` is the one optional exception).
11
+ It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models, ONNX `.onnx` graphs, and PyTorch `.pt`/`.pth` checkpoints, then report structure, size, parameters, operator mix, rough inference cost, and changes between versions. All formats are read natively — no Python, framework, or runtime needed (CatBoost `--deep` is the one optional exception). The PyTorch reader is static: it sizes/names tensors and fingerprints weights without decoding exact shapes.
12
12
 
13
13
  The most useful command is `explain-diff`, which says in plain English what changed between two model versions:
14
14
 
@@ -0,0 +1,35 @@
1
+ # PyTorch example models
2
+
3
+ Synthetic `torch.save` artifacts for trying `mod-trace` on PyTorch with **no
4
+ torch and no Python** — mod-trace reads the `.pt` zip (pickled structure + raw
5
+ tensor storages) statically.
6
+
7
+ | Files | What they show |
8
+ |-------|----------------|
9
+ | `mlp_v1.pt` vs `mlp_v2.pt` | Same 2-layer MLP, hidden size 32 → 64 (parameter count ~doubles, same layer names). |
10
+
11
+ ## Try it
12
+
13
+ ```bash
14
+ mod-trace inspect examples/pytorch/mlp_v1.pt
15
+ mod-trace explain-diff examples/pytorch/mlp_v1.pt examples/pytorch/mlp_v2.pt
16
+ mod-trace check --max-parameter-growth 30% examples/pytorch/mlp_v1.pt examples/pytorch/mlp_v2.pt
17
+ mod-trace inspect --json examples/pytorch/mlp_v1.pt
18
+ ```
19
+
20
+ ## What it reads (and what it doesn't)
21
+
22
+ Reads, statically: file size, tensor/storage count, **estimated parameter count**
23
+ (from storage bytes ÷ dtype), **dominant dtype**, **recovered parameter/layer
24
+ names** (`fc1.weight`, …), and fingerprints (a sampled weight fingerprint that
25
+ changes on a retrain/finetune).
26
+
27
+ Does **not** decode exact per-tensor shapes — that would need a full pickle
28
+ interpreter. Same static/heuristic philosophy as the CatBoost and ONNX readers.
29
+
30
+ ## Regenerate
31
+
32
+ ```bash
33
+ python -m pip install torch
34
+ python examples/pytorch/generate_demo_models.py
35
+ ```
@@ -0,0 +1,35 @@
1
+ """Generate the synthetic PyTorch demo models used by the README examples.
2
+
3
+ Fully synthetic (no real data). Run:
4
+
5
+ python -m pip install torch
6
+ python examples/pytorch/generate_demo_models.py
7
+
8
+ Produces, in this directory:
9
+ mlp_v1.pt / mlp_v2.pt -> same 2-layer MLP, different hidden size (32 vs 64)
10
+ """
11
+
12
+ import os
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+
17
+ HERE = os.path.dirname(os.path.abspath(__file__))
18
+
19
+
20
class Net(nn.Module):
    """Two-layer MLP (16 inputs -> ``hidden`` units -> 4 outputs) for the demo checkpoints."""

    def __init__(self, hidden):
        super().__init__()
        # fc1 is built before fc2 so that a fixed torch seed yields the same
        # initial weights on every run.
        self.fc1 = nn.Linear(in_features=16, out_features=hidden)
        self.fc2 = nn.Linear(in_features=hidden, out_features=4)

    def forward(self, x):
        """Apply fc1, a ReLU, then fc2."""
        hidden_activations = self.fc1(x).relu()
        return self.fc2(hidden_activations)
28
+
29
+
30
if __name__ == "__main__":
    # Each entry is (torch seed, hidden size, output file name); the seed is set
    # immediately before the model is built so the saved weights are reproducible.
    for seed, hidden, filename in ((0, 32, "mlp_v1.pt"), (1, 64, "mlp_v2.pt")):
        torch.manual_seed(seed)
        torch.save(Net(hidden).state_dict(), os.path.join(HERE, filename))
    print("wrote mlp_v1.pt and mlp_v2.pt")
@@ -4,7 +4,7 @@ build-backend = "maturin"
4
4
 
5
5
  [project]
6
6
  name = "mod-trace"
7
- version = "0.3.2"
7
+ version = "0.4.0"
8
8
  description = "Rust CLI for inspecting ML model artifacts without loading the framework"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.9"
@@ -4,6 +4,7 @@ mod explain;
4
4
  mod lgbm;
5
5
  mod model;
6
6
  mod onnx;
7
+ mod pt;
7
8
  mod tensor;
8
9
 
9
10
  use std::collections::BTreeSet;
@@ -54,6 +55,7 @@ fn run() -> Result<(), String> {
54
55
  Some("onnx") => onnx_cmd(&args.rest),
55
56
  Some("catboost") | Some("cbm") => catboost_cmd(&args.rest),
56
57
  Some("lightgbm") | Some("lgbm") => lgbm_cmd(&args.rest),
58
+ Some("pytorch") | Some("pt") => pt_cmd(&args.rest),
57
59
  Some("validate") => validate_model_cmd(&args.rest),
58
60
  Some("tensor-inspect") => inspect_model_cmd(&args.rest),
59
61
  Some("run") => run_model_cmd(&args.rest),
@@ -93,6 +95,7 @@ struct BuiltInDoctorReport {
93
95
  catboost_metadata: bool,
94
96
  lightgbm_text: bool,
95
97
  onnx_static_graph: bool,
98
+ pytorch_zip: bool,
96
99
  json_tensor_plans: bool,
97
100
  }
98
101
 
@@ -125,6 +128,7 @@ fn doctor_report() -> DoctorReport {
125
128
  catboost_metadata: true,
126
129
  lightgbm_text: true,
127
130
  onnx_static_graph: true,
131
+ pytorch_zip: true,
128
132
  json_tensor_plans: true,
129
133
  },
130
134
  optional_python: PythonDoctorReport {
@@ -168,6 +172,10 @@ fn print_doctor_report(report: &DoctorReport) {
168
172
  " ONNX static graph: {}",
169
173
  availability(report.built_in.onnx_static_graph)
170
174
  );
175
+ println!(
176
+ " PyTorch .pt/.pth (zip): {}",
177
+ availability(report.built_in.pytorch_zip)
178
+ );
171
179
  println!(
172
180
  " JSON tensor plans: {}",
173
181
  availability(report.built_in.json_tensor_plans)
@@ -192,15 +200,15 @@ fn print_doctor_report(report: &DoctorReport) {
192
200
  println!();
193
201
  println!("Available commands:");
194
202
  println!(
195
- " inspect .cbm/.lgb/.onnx/.json: {}",
203
+ " inspect .cbm/.lgb/.onnx/.pt/.json: {}",
196
204
  available_unavailable(report.commands.inspect_artifacts)
197
205
  );
198
206
  println!(
199
- " diff .cbm/.lgb/.onnx: {}",
207
+ " diff .cbm/.lgb/.onnx/.pt: {}",
200
208
  available_unavailable(report.commands.diff_artifacts)
201
209
  );
202
210
  println!(
203
- " explain-diff .cbm/.lgb/.onnx: {}",
211
+ " explain-diff .cbm/.lgb/.onnx/.pt: {}",
204
212
  available_unavailable(report.commands.diff_artifacts)
205
213
  );
206
214
  println!(
@@ -301,6 +309,10 @@ fn inspect_cmd(args: &[String]) -> Result<(), String> {
301
309
  Err("--deep inspection is currently only supported for CatBoost artifacts.".to_string())
302
310
  }
303
311
  ArtifactKind::Onnx => onnx_cmd(args),
312
+ ArtifactKind::PyTorch if deep => {
313
+ Err("--deep inspection is currently only supported for CatBoost artifacts.".to_string())
314
+ }
315
+ ArtifactKind::PyTorch => pt_cmd(args),
304
316
  ArtifactKind::Json if deep => {
305
317
  Err("--deep inspection is currently only supported for CatBoost artifacts.".to_string())
306
318
  }
@@ -312,7 +324,7 @@ fn inspect_cmd(args: &[String]) -> Result<(), String> {
312
324
  }
313
325
  ArtifactKind::Json => inspect_model_cmd(args),
314
326
  ArtifactKind::Unknown => Err(format!(
315
- "unsupported artifact type for `{}`. Try .cbm, .lgb, .onnx, or .json.",
327
+ "unsupported artifact type for `{}`. Try .cbm, .lgb, .onnx, .pt/.pth, or .json.",
316
328
  target
317
329
  )),
318
330
  }
@@ -378,6 +390,13 @@ fn diff_cmd(args: &[String]) -> Result<(), String> {
378
390
  }
379
391
  diff_onnx(&paths[0], &paths[1], json)
380
392
  }
393
+ (ArtifactKind::PyTorch, ArtifactKind::PyTorch) => {
394
+ if deep {
395
+ println!("Note: --deep is currently only used for CatBoost artifacts.");
396
+ println!();
397
+ }
398
+ diff_pt(&paths[0], &paths[1], json)
399
+ }
381
400
  (ArtifactKind::Json, ArtifactKind::Json) => Err(
382
401
  "tensor plan diff is not supported yet. Use trace, compare, why, or validate instead."
383
402
  .to_string(),
@@ -462,8 +481,9 @@ fn check_cmd(args: &[String]) -> Result<(), String> {
462
481
  check_lgbm(&paths[0], &paths[1], &options)
463
482
  }
464
483
  (ArtifactKind::Onnx, ArtifactKind::Onnx) => check_onnx(&paths[0], &paths[1], &options),
484
+ (ArtifactKind::PyTorch, ArtifactKind::PyTorch) => check_pt(&paths[0], &paths[1], &options),
465
485
  (ArtifactKind::Json, ArtifactKind::Json) => {
466
- Err("check supports CatBoost, LightGBM, and ONNX artifacts, not tensor plan JSON."
486
+ Err("check supports CatBoost, LightGBM, ONNX, and PyTorch artifacts, not tensor plan JSON."
467
487
  .to_string())
468
488
  }
469
489
  (left, right) => Err(format!(
@@ -715,6 +735,7 @@ fn explain_cmd(args: &[String]) -> Result<(), String> {
715
735
  ArtifactKind::Onnx => return explain_onnx_cmd(target),
716
736
  ArtifactKind::CatBoost => return explain_catboost_cmd(target),
717
737
  ArtifactKind::LightGbm => return lgbm_cmd(&[target.to_string()]),
738
+ ArtifactKind::PyTorch => return pt_cmd(&[target.to_string()]),
718
739
  ArtifactKind::Json => return explain_model_cmd(&[target.to_string()]),
719
740
  ArtifactKind::Unknown => {}
720
741
  }
@@ -1013,8 +1034,9 @@ fn explain_diff_cmd(args: &[String]) -> Result<(), String> {
1013
1034
  (ArtifactKind::Onnx, ArtifactKind::Onnx) => explain_diff_onnx(old, new),
1014
1035
  (ArtifactKind::LightGbm, ArtifactKind::LightGbm) => explain_diff_lgbm(old, new),
1015
1036
  (ArtifactKind::CatBoost, ArtifactKind::CatBoost) => explain_diff_catboost(old, new),
1037
+ (ArtifactKind::PyTorch, ArtifactKind::PyTorch) => explain_diff_pt(old, new),
1016
1038
  (left, right) => Err(format!(
1017
- "explain-diff needs two artifacts of the same supported type (.onnx, .cbm, .lgb): {} vs {}",
1039
+ "explain-diff needs two artifacts of the same supported type (.onnx, .cbm, .lgb, .pt): {} vs {}",
1018
1040
  left.label(),
1019
1041
  right.label()
1020
1042
  )),
@@ -1326,6 +1348,19 @@ fn explain_diff_catboost(old_path: &str, new_path: &str) -> Result<(), String> {
1326
1348
  let old = cbm::inspect(old_path)?;
1327
1349
  let new = cbm::inspect(new_path)?;
1328
1350
 
1351
+ let old_features = old.feature_candidates.iter().collect::<BTreeSet<_>>();
1352
+ let new_features = new.feature_candidates.iter().collect::<BTreeSet<_>>();
1353
+ let added = new_features
1354
+ .difference(&old_features)
1355
+ .map(|name| name.as_str())
1356
+ .collect::<Vec<_>>();
1357
+ let removed = old_features
1358
+ .difference(&new_features)
1359
+ .map(|name| name.as_str())
1360
+ .collect::<Vec<_>>();
1361
+ let names_known = !old.feature_candidates.is_empty() || !new.feature_candidates.is_empty();
1362
+ let config_same = catboost_training_config_same(&old, &new);
1363
+
1329
1364
  println!("Model Change Explanation");
1330
1365
  println!("------------------------");
1331
1366
  println!("Type: CatBoost");
@@ -1334,21 +1369,68 @@ fn explain_diff_catboost(old_path: &str, new_path: &str) -> Result<(), String> {
1334
1369
  println!();
1335
1370
  println!("Architecture:");
1336
1371
  println!(
1337
- " Trees: {} -> {}",
1372
+ " Trees: {} -> {} ({})",
1338
1373
  opt_num(old.iterations.map(|value| value as usize)),
1339
- opt_num(new.iterations.map(|value| value as usize))
1374
+ opt_num(new.iterations.map(|value| value as usize)),
1375
+ match (old.iterations, new.iterations) {
1376
+ (Some(o), Some(n)) => growth_label(o as usize, n as usize),
1377
+ _ => "unknown".to_string(),
1378
+ }
1340
1379
  );
1341
1380
  println!(
1342
- " Depth: {} -> {}",
1381
+ " Depth: {} -> {} ({})",
1343
1382
  opt_num(old.depth.map(|value| value as usize)),
1344
- opt_num(new.depth.map(|value| value as usize))
1383
+ opt_num(new.depth.map(|value| value as usize)),
1384
+ same_or_changed(old.depth == new.depth)
1345
1385
  );
1386
+ let feature_note = if !names_known {
1387
+ "names not recovered".to_string()
1388
+ } else if added.is_empty() && removed.is_empty() {
1389
+ "same set".to_string()
1390
+ } else {
1391
+ format!("+{} added, -{} removed", added.len(), removed.len())
1392
+ };
1346
1393
  println!(
1347
- " Features (recovered): {} -> {}",
1394
+ " Features (recovered): {} -> {} ({})",
1348
1395
  old.feature_candidates.len(),
1349
- new.feature_candidates.len()
1396
+ new.feature_candidates.len(),
1397
+ feature_note
1350
1398
  );
1399
+ print_lgbm_feature_list(" added: ", &added, 8);
1400
+ print_lgbm_feature_list(" removed:", &removed, 8);
1351
1401
  println!();
1402
+ println!("Training config:");
1403
+ println!(
1404
+ " Loss: {} -> {} ({})",
1405
+ old.loss_function.as_deref().unwrap_or("unknown"),
1406
+ new.loss_function.as_deref().unwrap_or("unknown"),
1407
+ same_or_changed(old.loss_function == new.loss_function)
1408
+ );
1409
+ println!(
1410
+ " Eval metric: {} -> {} ({})",
1411
+ old.eval_metric.as_deref().unwrap_or("unknown"),
1412
+ new.eval_metric.as_deref().unwrap_or("unknown"),
1413
+ same_or_changed(old.eval_metric == new.eval_metric)
1414
+ );
1415
+ println!(
1416
+ " Learning rate: {} -> {} ({})",
1417
+ opt_float(old.learning_rate),
1418
+ opt_float(new.learning_rate),
1419
+ same_or_changed(old.learning_rate == new.learning_rate)
1420
+ );
1421
+ println!(
1422
+ " Grow policy: {} -> {} ({})",
1423
+ old.grow_policy.as_deref().unwrap_or("unknown"),
1424
+ new.grow_policy.as_deref().unwrap_or("unknown"),
1425
+ same_or_changed(old.grow_policy == new.grow_policy)
1426
+ );
1427
+ println!();
1428
+ println!(
1429
+ "File size: {} -> {} ({})",
1430
+ format_bytes(old.bytes),
1431
+ format_bytes(new.bytes),
1432
+ growth_label(old.bytes, new.bytes)
1433
+ );
1352
1434
  match (old.estimated_leaf_values(), new.estimated_leaf_values()) {
1353
1435
  (Some(o), Some(n)) => println!(
1354
1436
  "Estimated leaf-slot growth: {}",
@@ -1363,7 +1445,20 @@ fn explain_diff_catboost(old_path: &str, new_path: &str) -> Result<(), String> {
1363
1445
  println!("Learned state: unchanged");
1364
1446
  }
1365
1447
  println!();
1366
- println!("Note: CatBoost internals are summarized; run `diff --deep` for exact split/leaf changes.");
1448
+ println!("Summary:");
1449
+ let features_same = added.is_empty() && removed.is_empty();
1450
+ let descriptor = if !features_same {
1451
+ "feature set changed"
1452
+ } else if !config_same {
1453
+ "training config changed"
1454
+ } else if old.iterations != new.iterations || old.depth != new.depth {
1455
+ "same spec, retrained with different tree count/depth"
1456
+ } else {
1457
+ "same spec and features, retrained"
1458
+ };
1459
+ println!(" {descriptor}.");
1460
+ println!();
1461
+ println!("Note: heuristic summary; run `diff --deep` for exact split/leaf changes.");
1367
1462
 
1368
1463
  Ok(())
1369
1464
  }
@@ -1692,6 +1787,7 @@ enum ArtifactKind {
1692
1787
  CatBoost,
1693
1788
  LightGbm,
1694
1789
  Onnx,
1790
+ PyTorch,
1695
1791
  Json,
1696
1792
  Unknown,
1697
1793
  }
@@ -1702,6 +1798,7 @@ impl ArtifactKind {
1702
1798
  ArtifactKind::CatBoost => "CatBoost",
1703
1799
  ArtifactKind::LightGbm => "LightGBM",
1704
1800
  ArtifactKind::Onnx => "ONNX",
1801
+ ArtifactKind::PyTorch => "PyTorch",
1705
1802
  ArtifactKind::Json => "tensor plan JSON",
1706
1803
  ArtifactKind::Unknown => "unknown",
1707
1804
  }
@@ -1718,6 +1815,7 @@ fn artifact_kind(path: &str) -> ArtifactKind {
1718
1815
  Some("cbm") => return ArtifactKind::CatBoost,
1719
1816
  Some("lgb") => return ArtifactKind::LightGbm,
1720
1817
  Some("onnx") => return ArtifactKind::Onnx,
1818
+ Some("pt") | Some("pth") => return ArtifactKind::PyTorch,
1721
1819
  Some("json") => return ArtifactKind::Json,
1722
1820
  _ => {}
1723
1821
  }
@@ -2989,6 +3087,332 @@ fn check_lgbm(old_path: &str, new_path: &str, options: &CheckOptions) -> Result<
2989
3087
  finish_check(&checks)
2990
3088
  }
2991
3089
 
3090
+ fn pt_cmd(args: &[String]) -> Result<(), String> {
3091
+ if args.is_empty() {
3092
+ return Err(
3093
+ "usage: mod-trace pytorch [--json] [--limit 20] <model.pt|model.pth> [more...]"
3094
+ .to_string(),
3095
+ );
3096
+ }
3097
+
3098
+ let mut json = false;
3099
+ let mut limit = DEFAULT_LIMIT;
3100
+ let mut paths = Vec::new();
3101
+ let mut i = 0usize;
3102
+ while i < args.len() {
3103
+ match args[i].as_str() {
3104
+ "--json" => {
3105
+ json = true;
3106
+ i += 1;
3107
+ }
3108
+ "--deep" => {
3109
+ return Err(
3110
+ "--deep inspection is currently only supported for CatBoost artifacts."
3111
+ .to_string(),
3112
+ );
3113
+ }
3114
+ "--limit" => {
3115
+ let value = args
3116
+ .get(i + 1)
3117
+ .ok_or_else(|| "--limit needs a number".to_string())?;
3118
+ limit = value
3119
+ .parse::<usize>()
3120
+ .map_err(|err| format!("parse --limit: {err}"))?;
3121
+ i += 2;
3122
+ }
3123
+ value => {
3124
+ paths.push(value.to_string());
3125
+ i += 1;
3126
+ }
3127
+ }
3128
+ }
3129
+
3130
+ if paths.is_empty() {
3131
+ return Err(
3132
+ "usage: mod-trace pytorch [--json] [--limit 20] <model.pt|model.pth> [more...]"
3133
+ .to_string(),
3134
+ );
3135
+ }
3136
+
3137
+ if json {
3138
+ let reports = paths
3139
+ .iter()
3140
+ .map(|path| pt::inspect(path))
3141
+ .collect::<Result<Vec<_>, _>>()?;
3142
+ if reports.len() == 1 {
3143
+ print_json(&reports[0])?;
3144
+ } else {
3145
+ print_json(&reports)?;
3146
+ }
3147
+ return Ok(());
3148
+ }
3149
+
3150
+ for (index, path) in paths.iter().enumerate() {
3151
+ if index > 0 {
3152
+ println!();
3153
+ }
3154
+ let report = pt::inspect(path)?;
3155
+ print_pt_report(&report, limit);
3156
+ }
3157
+
3158
+ Ok(())
3159
+ }
3160
+
3161
+ fn print_pt_report(report: &pt::PtReport, limit: usize) {
3162
+ println!("PyTorch Model Summary");
3163
+ println!("---------------------");
3164
+ println!("Model: {}", report.path);
3165
+ println!("Format: {}", report.format());
3166
+ println!("File size: {}", format_bytes(report.bytes));
3167
+ if let Some(version) = report.torch_version.as_deref() {
3168
+ println!("Serialization version: {version}");
3169
+ }
3170
+ println!();
3171
+ println!("Structure:");
3172
+ println!(" Tensors (storages): {}", report.tensor_count);
3173
+ println!(
3174
+ " Parameters (est): {}",
3175
+ format_count_human(report.estimated_parameter_count as usize)
3176
+ );
3177
+ println!(
3178
+ " Parameter bytes: {}",
3179
+ format_bytes(report.total_parameter_bytes as usize)
3180
+ );
3181
+ print_optional("Dominant dtype", report.dominant_dtype.as_deref());
3182
+ println!();
3183
+ println!("Parameter-like Internals:");
3184
+ println!(
3185
+ " Full artifact fingerprint: {}",
3186
+ format_hex(report.file_fingerprint)
3187
+ );
3188
+ println!(
3189
+ " Learned-state fingerprint (sampled): {}",
3190
+ format_hex(report.learned_state_fingerprint)
3191
+ );
3192
+ println!(
3193
+ " Note: parameter count is estimated from storage bytes / dtype; tensor shapes are not decoded."
3194
+ );
3195
+ println!();
3196
+ println!("Recovered Parameter Names:");
3197
+ if report.param_names.is_empty() {
3198
+ println!(" none recovered from the pickle");
3199
+ } else {
3200
+ println!(" found: {}", report.param_names.len());
3201
+ for name in report.param_names.iter().take(limit) {
3202
+ println!(" {name}");
3203
+ }
3204
+ if report.param_names.len() > limit {
3205
+ println!(" ... {} more", report.param_names.len() - limit);
3206
+ }
3207
+ }
3208
+ }
3209
+
3210
+ fn diff_pt(old_path: &str, new_path: &str, json: bool) -> Result<(), String> {
3211
+ let old = pt::inspect(old_path)?;
3212
+ let new = pt::inspect(new_path)?;
3213
+
3214
+ if json {
3215
+ print_json(&serde_json::json!({
3216
+ "type": "pytorch",
3217
+ "old": old,
3218
+ "new": new,
3219
+ }))?;
3220
+ return Ok(());
3221
+ }
3222
+
3223
+ println!("Model Diff");
3224
+ println!("----------");
3225
+ println!("Type: PyTorch");
3226
+ println!("Old: {}", old.path);
3227
+ println!("New: {}", new.path);
3228
+ println!();
3229
+ println!("Structure:");
3230
+ print_diff_bytes("File size", old.bytes, new.bytes);
3231
+ print_diff_usize("Tensors (storages)", old.tensor_count, new.tensor_count);
3232
+ print_diff_usize(
3233
+ "Parameters (est)",
3234
+ old.estimated_parameter_count as usize,
3235
+ new.estimated_parameter_count as usize,
3236
+ );
3237
+ print_diff_bytes(
3238
+ "Parameter bytes",
3239
+ old.total_parameter_bytes as usize,
3240
+ new.total_parameter_bytes as usize,
3241
+ );
3242
+ print_diff_opt_str(
3243
+ "Dominant dtype",
3244
+ old.dominant_dtype.as_deref(),
3245
+ new.dominant_dtype.as_deref(),
3246
+ );
3247
+ println!();
3248
+ println!("Parameter-like Internals:");
3249
+ print_fingerprint_diff(
3250
+ "Full artifact fingerprint",
3251
+ Some(old.file_fingerprint),
3252
+ Some(new.file_fingerprint),
3253
+ );
3254
+ print_fingerprint_diff(
3255
+ "Learned-state fingerprint (sampled)",
3256
+ Some(old.learned_state_fingerprint),
3257
+ Some(new.learned_state_fingerprint),
3258
+ );
3259
+ println!();
3260
+ println!("Parameter Names:");
3261
+ print_diff_usize("Recovered names", old.param_names.len(), new.param_names.len());
3262
+ print_feature_name_changes(&old.param_names, &new.param_names, 12);
3263
+
3264
+ Ok(())
3265
+ }
3266
+
3267
+ fn explain_diff_pt(old_path: &str, new_path: &str) -> Result<(), String> {
3268
+ let old = pt::inspect(old_path)?;
3269
+ let new = pt::inspect(new_path)?;
3270
+
3271
+ let old_names = old.param_names.iter().collect::<BTreeSet<_>>();
3272
+ let new_names = new.param_names.iter().collect::<BTreeSet<_>>();
3273
+ let added = new_names
3274
+ .difference(&old_names)
3275
+ .map(|name| name.as_str())
3276
+ .collect::<Vec<_>>();
3277
+ let removed = old_names
3278
+ .difference(&new_names)
3279
+ .map(|name| name.as_str())
3280
+ .collect::<Vec<_>>();
3281
+ let names_known = !old.param_names.is_empty() || !new.param_names.is_empty();
3282
+
3283
+ println!("Model Change Explanation");
3284
+ println!("------------------------");
3285
+ println!("Type: PyTorch");
3286
+ println!("Old: {}", old.path);
3287
+ println!("New: {}", new.path);
3288
+ println!();
3289
+ println!("Architecture:");
3290
+ println!(
3291
+ " Tensors: {} -> {} ({})",
3292
+ old.tensor_count,
3293
+ new.tensor_count,
3294
+ growth_label(old.tensor_count, new.tensor_count)
3295
+ );
3296
+ println!(
3297
+ " Parameters (est): {} -> {} ({})",
3298
+ format_count_human(old.estimated_parameter_count as usize),
3299
+ format_count_human(new.estimated_parameter_count as usize),
3300
+ growth_label(
3301
+ old.estimated_parameter_count as usize,
3302
+ new.estimated_parameter_count as usize
3303
+ )
3304
+ );
3305
+ println!(
3306
+ " Dominant dtype: {} -> {} ({})",
3307
+ old.dominant_dtype.as_deref().unwrap_or("unknown"),
3308
+ new.dominant_dtype.as_deref().unwrap_or("unknown"),
3309
+ same_or_changed(old.dominant_dtype == new.dominant_dtype)
3310
+ );
3311
+ let name_note = if !names_known {
3312
+ "names not recovered".to_string()
3313
+ } else if added.is_empty() && removed.is_empty() {
3314
+ "same set".to_string()
3315
+ } else {
3316
+ format!("+{} added, -{} removed", added.len(), removed.len())
3317
+ };
3318
+ println!(
3319
+ " Param names: {} -> {} ({})",
3320
+ old.param_names.len(),
3321
+ new.param_names.len(),
3322
+ name_note
3323
+ );
3324
+ print_lgbm_feature_list(" added: ", &added, 8);
3325
+ print_lgbm_feature_list(" removed:", &removed, 8);
3326
+ println!();
3327
+ println!(
3328
+ "File size: {} -> {} ({})",
3329
+ format_bytes(old.bytes),
3330
+ format_bytes(new.bytes),
3331
+ growth_label(old.bytes, new.bytes)
3332
+ );
3333
+ println!();
3334
+ if old.learned_state_fingerprint != new.learned_state_fingerprint {
3335
+ println!("Learned state: CHANGED (sampled weight bytes differ - a real retrain/finetune)");
3336
+ } else {
3337
+ println!("Learned state: unchanged (sampled weight bytes identical)");
3338
+ }
3339
+ println!();
3340
+ println!("Summary:");
3341
+ let names_same = added.is_empty() && removed.is_empty();
3342
+ let descriptor = if !names_known {
3343
+ "weights compared (parameter names not recovered)"
3344
+ } else if !names_same {
3345
+ "parameter set changed (layers added/removed)"
3346
+ } else if old.estimated_parameter_count != new.estimated_parameter_count {
3347
+ "same layers, parameter count changed (resized)"
3348
+ } else {
3349
+ "same architecture, retrained/finetuned"
3350
+ };
3351
+ println!(
3352
+ " {descriptor}; params {}.",
3353
+ growth_label(
3354
+ old.estimated_parameter_count as usize,
3355
+ new.estimated_parameter_count as usize
3356
+ )
3357
+ );
3358
+ println!();
3359
+ println!(
3360
+ "Note: static read of the torch.save zip (no torch); shapes are not decoded and the weight fingerprint is sampled."
3361
+ );
3362
+
3363
+ Ok(())
3364
+ }
3365
+
3366
+ fn check_pt(old_path: &str, new_path: &str, options: &CheckOptions) -> Result<(), String> {
3367
+ if options.max_ops_growth_pct.is_some() {
3368
+ return Err("--max-ops-growth is only supported for ONNX artifacts.".to_string());
3369
+ }
3370
+ if options.fail_on_new_op {
3371
+ return Err("--fail-on-new-op is only supported for ONNX artifacts.".to_string());
3372
+ }
3373
+ if options.fail_on_training_config_change {
3374
+ return Err(
3375
+ "--fail-on-training-config-change is only supported for CatBoost/LightGBM artifacts."
3376
+ .to_string(),
3377
+ );
3378
+ }
3379
+
3380
+ let old = pt::inspect(old_path)?;
3381
+ let new = pt::inspect(new_path)?;
3382
+ let mut checks = Vec::new();
3383
+
3384
+ if let Some(max_pct) = options.max_size_growth_pct {
3385
+ checks.push(growth_check(
3386
+ "file size growth",
3387
+ old.bytes,
3388
+ new.bytes,
3389
+ max_pct,
3390
+ ));
3391
+ }
3392
+ if let Some(max_pct) = options.max_parameter_growth_pct {
3393
+ checks.push(growth_check(
3394
+ "parameter growth",
3395
+ old.estimated_parameter_count as usize,
3396
+ new.estimated_parameter_count as usize,
3397
+ max_pct,
3398
+ ));
3399
+ }
3400
+ if options.fail_on_feature_change {
3401
+ checks.push(boolean_check(
3402
+ "parameter names unchanged",
3403
+ old.param_names == new.param_names,
3404
+ format!(
3405
+ "{} -> {} recovered names",
3406
+ old.param_names.len(),
3407
+ new.param_names.len()
3408
+ ),
3409
+ ));
3410
+ }
3411
+
3412
+ print_check_report("PyTorch", &old.path, &new.path, &checks);
3413
+ finish_check(&checks)
3414
+ }
3415
+
2992
3416
  fn onnx_cmd(args: &[String]) -> Result<(), String> {
2993
3417
  if args.is_empty() {
2994
3418
  return Err("usage: mod-trace onnx [--json] [--limit 20] <model.onnx>".to_string());
@@ -3488,13 +3912,14 @@ fn print_help() {
3488
3912
  "mod-trace - inspect ML model artifacts without loading the framework\n\n\
3489
3913
  Core usage:\n \
3490
3914
  mod-trace doctor [--json]\n \
3491
- mod-trace inspect [--deep] [--json] [--limit 20] <model.cbm|model.lgb|model.onnx|model.json>\n \
3492
- mod-trace diff [--deep] [--json] <old-model> <new-model> (.cbm, .lgb/.txt LightGBM, or .onnx)\n \
3915
+ mod-trace inspect [--deep] [--json] [--limit 20] <model.cbm|model.lgb|model.onnx|model.pt|model.json>\n \
3916
+ mod-trace diff [--deep] [--json] <old-model> <new-model> (.cbm, .lgb/.txt, .onnx, or .pt/.pth)\n \
3493
3917
  mod-trace explain-diff <old-model> <new-model> (plain-English what changed: layers, params, cost, new ops)\n\n\
3494
3918
  mod-trace check [--max-size-growth 20%] [--max-ops-growth 25%] [--max-parameter-growth 30%] [--fail-on-feature-change] [--fail-on-training-config-change] [--fail-on-new-op] <old-model> <new-model>\n\n\
3495
3919
  Artifact inspectors:\n \
3496
3920
  mod-trace catboost [--deep] [--json] [--limit 20] <model.cbm> [more.cbm...]\n \
3497
3921
  mod-trace lightgbm [--json] [--limit 20] <model.lgb|model.txt> [more...]\n \
3922
+ mod-trace pytorch [--json] [--limit 20] <model.pt|model.pth> [more...]\n \
3498
3923
  mod-trace onnx [--json] [--limit 20] <model.onnx>\n \
3499
3924
  mod-trace explain <model.onnx|model.cbm>\n\n\
3500
3925
  Tensor lab (secondary; see docs/tensor-lab.md):\n \
@@ -3831,7 +4256,7 @@ mod tests {
3831
4256
 
3832
4257
  assert_eq!(
3833
4258
  check_cmd(&args),
3834
- Err("check supports CatBoost, LightGBM, and ONNX artifacts, not tensor plan JSON."
4259
+ Err("check supports CatBoost, LightGBM, ONNX, and PyTorch artifacts, not tensor plan JSON."
3835
4260
  .to_string())
3836
4261
  );
3837
4262
  }
@@ -0,0 +1,398 @@
1
+ use serde::Serialize;
2
+ use std::fs::File;
3
+ use std::io::{Read, Seek, SeekFrom};
4
+ use std::path::Path;
5
+
6
+ /// A static read of a PyTorch `torch.save` artifact (`.pt`/`.pth`).
7
+ ///
8
+ /// Modern PyTorch saves an (uncompressed) ZIP of a pickled structure
9
+ /// (`data.pkl`) plus raw tensor storages under `.../data/N`. This reader walks
10
+ /// that zip without running Python or torch: it sizes the storages, recovers
11
+ /// parameter names and dtype from the pickle, and fingerprints the weights by
12
+ /// sampling. It does not decode exact per-tensor shapes (that needs a pickle VM).
13
+ #[derive(Serialize)]
14
+ pub struct PtReport {
15
+ pub path: String,
16
+ pub bytes: usize,
17
+ pub is_zip: bool,
18
+ pub tensor_count: usize,
19
+ pub total_parameter_bytes: u64,
20
+ pub estimated_parameter_count: u64,
21
+ pub dominant_dtype: Option<String>,
22
+ pub param_names: Vec<String>,
23
+ pub torch_version: Option<String>,
24
+ pub file_fingerprint: u64,
25
+ pub learned_state_fingerprint: u64,
26
+ }
27
+
28
+ impl PtReport {
29
+ pub fn format(&self) -> &'static str {
30
+ if self.is_zip {
31
+ "PyTorch (torch.save zip)"
32
+ } else {
33
+ "PyTorch (legacy pickle)"
34
+ }
35
+ }
36
+ }
37
+
38
/// Cheap sniff test on the first bytes of a file.
///
/// Returns `true` when `head` begins with the zip local-header signature
/// (`PK\x03\x04`) or with the byte `0x80`, which is how this module recognises
/// a bare pickle stream.
pub fn looks_like_pt(head: &[u8]) -> bool {
    matches!(head, [0x80, ..] | [b'P', b'K', 0x03, 0x04, ..])
}
41
+
42
/// Upper bound on the bytes read from each tensor storage when computing the
/// weight fingerprint: 1 MiB per tensor.
const SAMPLE_PER_STORAGE: u64 = 1024 * 1024;
43
+
44
+ pub fn inspect(path: &str) -> Result<PtReport, String> {
45
+ let mut file = File::open(path).map_err(|err| format!("open {path}: {err}"))?;
46
+ let total = file
47
+ .metadata()
48
+ .map_err(|err| format!("stat {path}: {err}"))?
49
+ .len();
50
+
51
+ let mut magic = [0u8; 4];
52
+ let read = file.read(&mut magic).map_err(|err| format!("read {path}: {err}"))?;
53
+ if read >= 4 && &magic == b"PK\x03\x04" {
54
+ inspect_zip(&mut file, path, total)
55
+ } else if read >= 1 && magic[0] == 0x80 {
56
+ inspect_legacy(&mut file, path, total)
57
+ } else {
58
+ Err(format!(
59
+ "{path} does not look like a PyTorch file (no zip `PK` or pickle marker)"
60
+ ))
61
+ }
62
+ }
63
+
64
/// Reads exactly `len` bytes starting at absolute `offset`.
///
/// `len` frequently comes straight from fields inside the file being
/// inspected, so the buffer is never allocated at the requested size up
/// front. The read is capped with [`Read::take`] and grows only as bytes
/// actually arrive, so a corrupt or hostile length yields an error instead of
/// a huge allocation or an abort.
///
/// # Errors
///
/// Returns a message if seeking or reading fails, or if the file ends before
/// `len` bytes could be read.
fn read_at(file: &mut File, offset: u64, len: usize) -> Result<Vec<u8>, String> {
    // Pre-reserving is only an optimisation; bound it so it cannot be abused.
    const MAX_PREALLOC: usize = 1 << 20;

    file.seek(SeekFrom::Start(offset))
        .map_err(|err| format!("seek: {err}"))?;

    let mut buf = Vec::with_capacity(len.min(MAX_PREALLOC));
    let got = file
        .by_ref()
        .take(len as u64)
        .read_to_end(&mut buf)
        .map_err(|err| format!("read: {err}"))?;

    if got != len {
        return Err(format!(
            "read: wanted {len} bytes at offset {offset} but the file ended after {got}"
        ));
    }
    Ok(buf)
}
71
+
72
/// Decodes a little-endian `u16` from `b` at byte offset `o`.
///
/// Panics if fewer than two bytes are available at `o`.
fn le_u16(b: &[u8], o: usize) -> u16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&b[o..o + 2]);
    u16::from_le_bytes(raw)
}
75
/// Decodes a little-endian `u32` from `b` at byte offset `o`.
///
/// Panics if fewer than four bytes are available at `o`.
fn le_u32(b: &[u8], o: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&b[o..o + 4]);
    u32::from_le_bytes(raw)
}
78
/// Decodes a little-endian `u64` from `b` at byte offset `o`.
///
/// Panics if fewer than eight bytes are available at `o`.
fn le_u64(b: &[u8], o: usize) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&b[o..o + 8]);
    u64::from_le_bytes(raw)
}
83
+
84
+ /// Offset where an entry's data begins, by reading its local file header.
85
+ fn entry_data_start(file: &mut File, local_offset: u64) -> Result<u64, String> {
86
+ let lh = read_at(file, local_offset, 30)?;
87
+ if &lh[0..4] != b"PK\x03\x04" {
88
+ return Err("bad local header".to_string());
89
+ }
90
+ let name_len = le_u16(&lh, 26) as u64;
91
+ let extra_len = le_u16(&lh, 28) as u64;
92
+ Ok(local_offset + 30 + name_len + extra_len)
93
+ }
94
+
95
/// Returns the index of the last occurrence of the 4-byte `sig` in `buf`,
/// or `None` when it does not occur (including when `buf` is shorter than
/// four bytes).
fn find_sig_rev(buf: &[u8], sig: &[u8; 4]) -> Option<usize> {
    buf.windows(sig.len())
        .rposition(|window| window == &sig[..])
}
110
+
111
+ fn inspect_zip(file: &mut File, path: &str, total: u64) -> Result<PtReport, String> {
112
+ // Torch zips use data descriptors, so sizes live in the central directory.
113
+ let tail_len = total.min(65_557) as usize;
114
+ let tail = read_at(file, total - tail_len as u64, tail_len)?;
115
+ let eocd_rel =
116
+ find_sig_rev(&tail, b"PK\x05\x06").ok_or_else(|| format!("{path}: no zip EOCD record"))?;
117
+ let eocd = &tail[eocd_rel..];
118
+
119
+ let mut cd_size = le_u32(eocd, 12) as u64;
120
+ let mut cd_offset = le_u32(eocd, 16) as u64;
121
+
122
+ // ZIP64 when fields are saturated (models > 4 GiB).
123
+ if cd_offset == 0xFFFF_FFFF || cd_size == 0xFFFF_FFFF {
124
+ if eocd_rel >= 20 {
125
+ let loc = &tail[eocd_rel - 20..];
126
+ if &loc[0..4] == b"PK\x06\x07" {
127
+ let z64_off = le_u64(loc, 8);
128
+ let z64 = read_at(file, z64_off, 56)?;
129
+ if &z64[0..4] == b"PK\x06\x06" {
130
+ cd_size = le_u64(&z64, 40);
131
+ cd_offset = le_u64(&z64, 48);
132
+ }
133
+ }
134
+ }
135
+ }
136
+
137
+ let cd = read_at(file, cd_offset, cd_size as usize)?;
138
+
139
+ let mut pkl = Vec::new();
140
+ let mut torch_version = None;
141
+ let mut tensor_count = 0usize;
142
+ let mut total_parameter_bytes = 0u64;
143
+ let mut struct_hash = FNV_OFFSET;
144
+ let mut weight_hash = FNV_OFFSET;
145
+
146
+ let mut o = 0usize;
147
+ while o + 46 <= cd.len() && &cd[o..o + 4] == b"PK\x01\x02" {
148
+ let comp32 = le_u32(&cd, o + 20);
149
+ let mut uncomp_size = le_u32(&cd, o + 24) as u64;
150
+ let name_len = le_u16(&cd, o + 28) as usize;
151
+ let extra_len = le_u16(&cd, o + 30) as usize;
152
+ let comment_len = le_u16(&cd, o + 32) as usize;
153
+ let mut local_offset = le_u32(&cd, o + 42) as u64;
154
+ let name = String::from_utf8_lossy(&cd[o + 46..o + 46 + name_len]).to_string();
155
+ let extra = &cd[o + 46 + name_len..o + 46 + name_len + extra_len];
156
+
157
+ // ZIP64 extended info: present fields appear in order uncomp, comp, offset.
158
+ if uncomp_size == 0xFFFF_FFFF || local_offset == 0xFFFF_FFFF {
159
+ let mut e = 0usize;
160
+ while e + 4 <= extra.len() {
161
+ let id = le_u16(extra, e);
162
+ let sz = le_u16(extra, e + 2) as usize;
163
+ if id == 0x0001 {
164
+ let mut f = e + 4;
165
+ if uncomp_size == 0xFFFF_FFFF && f + 8 <= extra.len() {
166
+ uncomp_size = le_u64(extra, f);
167
+ f += 8;
168
+ }
169
+ if comp32 == 0xFFFF_FFFF && f + 8 <= extra.len() {
170
+ f += 8; // skip compressed size
171
+ }
172
+ if local_offset == 0xFFFF_FFFF && f + 8 <= extra.len() {
173
+ local_offset = le_u64(extra, f);
174
+ }
175
+ break;
176
+ }
177
+ e += 4 + sz;
178
+ }
179
+ }
180
+
181
+ let base = name.rsplit('/').next().unwrap_or(&name);
182
+ struct_hash = fnv_update(struct_hash, name.as_bytes());
183
+ struct_hash = fnv_update(struct_hash, &uncomp_size.to_le_bytes());
184
+
185
+ if name.ends_with("data.pkl") {
186
+ let start = entry_data_start(file, local_offset)?;
187
+ pkl = read_at(file, start, uncomp_size as usize)?;
188
+ } else if name.ends_with("/version") || name == "version" {
189
+ let start = entry_data_start(file, local_offset)?;
190
+ let v = read_at(file, start, uncomp_size.min(32) as usize)?;
191
+ torch_version = Some(String::from_utf8_lossy(&v).trim().to_string());
192
+ } else if name.contains("/data/") && !base.is_empty() && base.bytes().all(|b| b.is_ascii_digit())
193
+ {
194
+ tensor_count += 1;
195
+ total_parameter_bytes += uncomp_size;
196
+ let sample = uncomp_size.min(SAMPLE_PER_STORAGE) as usize;
197
+ if sample > 0 {
198
+ let start = entry_data_start(file, local_offset)?;
199
+ let bytes = read_at(file, start, sample)?;
200
+ weight_hash = fnv_update(weight_hash, &bytes);
201
+ }
202
+ weight_hash = fnv_update(weight_hash, &uncomp_size.to_le_bytes());
203
+ }
204
+
205
+ o += 46 + name_len + extra_len + comment_len;
206
+ }
207
+
208
+ if pkl.is_empty() {
209
+ return Err(format!("{path}: no data.pkl found inside the archive"));
210
+ }
211
+
212
+ let param_names = recover_param_names(&pkl);
213
+ let (dtype, elem_size) = dominant_dtype(&pkl);
214
+ let estimated_parameter_count = total_parameter_bytes / elem_size;
215
+
216
+ let file_fingerprint = fnv_update(fnv_update(FNV_OFFSET, &pkl), &struct_hash.to_le_bytes());
217
+
218
+ Ok(PtReport {
219
+ path: file_name(path),
220
+ bytes: total as usize,
221
+ is_zip: true,
222
+ tensor_count,
223
+ total_parameter_bytes,
224
+ estimated_parameter_count,
225
+ dominant_dtype: dtype,
226
+ param_names,
227
+ torch_version,
228
+ file_fingerprint,
229
+ learned_state_fingerprint: weight_hash,
230
+ })
231
+ }
232
+
233
+ fn inspect_legacy(file: &mut File, path: &str, total: u64) -> Result<PtReport, String> {
234
+ // Legacy (non-zip) pickle: we can only scan strings + fingerprint the bytes.
235
+ let bytes = read_at(file, 0, total as usize)?;
236
+ let param_names = recover_param_names(&bytes);
237
+ let (dtype, _) = dominant_dtype(&bytes);
238
+ Ok(PtReport {
239
+ path: file_name(path),
240
+ bytes: total as usize,
241
+ is_zip: false,
242
+ tensor_count: 0,
243
+ total_parameter_bytes: 0,
244
+ estimated_parameter_count: 0,
245
+ dominant_dtype: dtype,
246
+ param_names,
247
+ torch_version: None,
248
+ file_fingerprint: fnv_update(FNV_OFFSET, &bytes),
249
+ learned_state_fingerprint: fnv_update(FNV_OFFSET, &bytes),
250
+ })
251
+ }
252
+
253
/// Returns the final component of `path` as an owned string.
///
/// Falls back to the input unchanged when the path has no final component
/// (for example `..`) or when that component is not valid UTF-8.
fn file_name(path: &str) -> String {
    let last_component = Path::new(path)
        .file_name()
        .and_then(|name| name.to_str());
    match last_component {
        Some(name) => name.to_owned(),
        None => path.to_owned(),
    }
}
260
+
261
+ /// Pull UTF-8 strings out of pickle BINUNICODE (0x58) / SHORT_BINUNICODE (0x8c)
262
+ /// opcodes and keep the ones that look like state_dict keys.
263
+ fn recover_param_names(pkl: &[u8]) -> Vec<String> {
264
+ let mut names = Vec::new();
265
+ let mut seen = std::collections::BTreeSet::new();
266
+ let mut i = 0usize;
267
+ while i < pkl.len() {
268
+ let (text, next) = match pkl[i] {
269
+ 0x8c if i + 2 <= pkl.len() => {
270
+ let n = pkl[i + 1] as usize;
271
+ if i + 2 + n <= pkl.len() {
272
+ (std::str::from_utf8(&pkl[i + 2..i + 2 + n]).ok(), i + 2 + n)
273
+ } else {
274
+ (None, i + 1)
275
+ }
276
+ }
277
+ 0x58 if i + 5 <= pkl.len() => {
278
+ let n = le_u32(pkl, i + 1) as usize;
279
+ if i + 5 + n <= pkl.len() {
280
+ (std::str::from_utf8(&pkl[i + 5..i + 5 + n]).ok(), i + 5 + n)
281
+ } else {
282
+ (None, i + 1)
283
+ }
284
+ }
285
+ _ => (None, i + 1),
286
+ };
287
+ if let Some(text) = text {
288
+ if looks_like_param_name(text) && seen.insert(text.to_string()) {
289
+ names.push(text.to_string());
290
+ }
291
+ }
292
+ i = next;
293
+ }
294
+ names
295
+ }
296
+
297
/// Heuristic: does `value` look like a state_dict key?
///
/// It must be 2 to 200 bytes long, consist only of ASCII letters, digits, `_`
/// and `.`, and either contain a `.` or end with one of the common parameter
/// suffixes (`weight`, `bias`, `running_mean`, `running_var`).
fn looks_like_param_name(value: &str) -> bool {
    const SUFFIXES: [&str; 4] = ["weight", "bias", "running_mean", "running_var"];

    let length_ok = (2..=200).contains(&value.len());
    let charset_ok = value.bytes().all(|b| {
        matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.')
    });

    length_ok
        && charset_ok
        && (value.contains('.') || SUFFIXES.iter().any(|suffix| value.ends_with(*suffix)))
}
313
+
314
+ fn dominant_dtype(pkl: &[u8]) -> (Option<String>, u64) {
315
+ // (storage token, dtype name, element size)
316
+ let candidates: [(&[u8], &str, u64); 8] = [
317
+ (b"DoubleStorage", "float64", 8),
318
+ (b"BFloat16Storage", "bfloat16", 2),
319
+ (b"HalfStorage", "float16", 2),
320
+ (b"FloatStorage", "float32", 4),
321
+ (b"LongStorage", "int64", 8),
322
+ (b"IntStorage", "int32", 4),
323
+ (b"ByteStorage", "uint8", 1),
324
+ (b"BoolStorage", "bool", 1),
325
+ ];
326
+ let mut best: Option<(&str, u64, usize)> = None;
327
+ for (token, name, size) in candidates {
328
+ let count = count_occurrences(pkl, token);
329
+ if count > 0 && best.map(|(_, _, c)| count > c).unwrap_or(true) {
330
+ best = Some((name, size, count));
331
+ }
332
+ }
333
+ match best {
334
+ Some((name, size, _)) => (Some(name.to_string()), size),
335
+ None => (None, 4),
336
+ }
337
+ }
338
+
339
/// Counts non-overlapping occurrences of `needle` in `haystack`, scanning
/// left to right. An empty needle, or one longer than the haystack, yields 0.
fn count_occurrences(haystack: &[u8], needle: &[u8]) -> usize {
    if needle.is_empty() {
        return 0;
    }

    let mut hits = 0usize;
    let mut rest = haystack;
    while rest.len() >= needle.len() {
        if rest.starts_with(needle) {
            hits += 1;
            // Skip the whole match so occurrences never overlap.
            rest = &rest[needle.len()..];
        } else {
            rest = &rest[1..];
        }
    }
    hits
}
355
+
356
/// Starting value for a fresh 64-bit FNV hash.
const FNV_OFFSET: u64 = 0xcbf2_9ce4_8422_2325;
/// Multiplier applied after each byte is mixed in.
const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;

/// Folds `data` into a running hash, byte by byte: XOR the byte in, then
/// multiply by `FNV_PRIME` with wrapping arithmetic.
///
/// Pass `FNV_OFFSET` as `hash` to start a new hash, or a previous result to
/// continue one; feeding chunks in sequence equals hashing their concatenation.
fn fnv_update(hash: u64, data: &[u8]) -> u64 {
    data.iter()
        .fold(hash, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
}
366
+
367
#[cfg(test)]
mod tests {
    use super::{count_occurrences, looks_like_param_name, recover_param_names};

    /// Dotted state_dict keys pass; device names, bare words and strings with
    /// spaces do not.
    #[test]
    fn param_name_filter() {
        for accepted in ["fc1.weight", "encoder.layers.3.attn.bias", "classifier.weight"] {
            assert!(looks_like_param_name(accepted));
        }
        for rejected in ["cpu", "storage", "has space"] {
            assert!(!looks_like_param_name(rejected));
        }
    }

    /// Both string opcodes are decoded and names come back in stream order.
    #[test]
    fn recovers_binunicode_names() {
        // SHORT_BINUNICODE(0x8c) len=10 "fc1.weight", then BINUNICODE(0x58) len=8 "fc1.bias"
        let pkl = [
            &[0x8c, 10][..],
            &b"fc1.weight"[..],
            &[0x58][..],
            &8u32.to_le_bytes()[..],
            &b"fc1.bias"[..],
        ]
        .concat();
        let names = recover_param_names(&pkl);
        assert_eq!(names, vec!["fc1.weight", "fc1.bias"]);
    }

    /// Back-to-back tokens are each counted; absent tokens count as zero.
    #[test]
    fn counts_tokens() {
        let cases: [(&[u8], usize); 2] = [(b"FloatStorageFloatStorage", 2), (b"none here", 0)];
        for (haystack, expected) in cases {
            assert_eq!(count_occurrences(haystack, b"FloatStorage"), expected);
        }
    }
}
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes