mod-trace 0.3.2__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {mod_trace-0.3.2 → mod_trace-0.4.0}/Cargo.lock +1 -1
- {mod_trace-0.3.2 → mod_trace-0.4.0}/Cargo.toml +1 -1
- {mod_trace-0.3.2 → mod_trace-0.4.0}/PKG-INFO +2 -2
- {mod_trace-0.3.2 → mod_trace-0.4.0}/README.md +1 -1
- mod_trace-0.4.0/examples/pytorch/README.md +35 -0
- mod_trace-0.4.0/examples/pytorch/generate_demo_models.py +35 -0
- mod_trace-0.4.0/examples/pytorch/mlp_v1.pt +0 -0
- mod_trace-0.4.0/examples/pytorch/mlp_v2.pt +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/pyproject.toml +1 -1
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/main.rs +441 -16
- mod_trace-0.4.0/src/pt.rs +398 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/.github/workflows/release.yml +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/.gitignore +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/LICENSE +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/benchmarks/tiny_pytorch.py +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/docs/ARCHITECTURE.md +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/docs/REAL_MODELS.md +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/docs/tensor-lab.md +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/broken_shape.json +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/README.md +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/clf_v1.txt +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/clf_v2.txt +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/lightgbm/generate_demo_models.py +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/make_sample_catboost.py +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/mlp.json +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/README.md +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/generate_demo_models.py +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_retrain_a.onnx +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_retrain_b.onnx +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_v1.onnx +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/onnx/mlp_v2.onnx +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/tiny_attention.json +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/examples/tiny_attention_plan.json +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/catboost_deep_diff.py +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/catboost_explain.py +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/cbm.rs +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/demo.rs +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/explain.rs +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/lgbm.rs +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/model.rs +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/onnx.rs +0 -0
- {mod_trace-0.3.2 → mod_trace-0.4.0}/src/tensor.rs +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: mod-trace
|
|
3
|
-
Version: 0.3.2
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Classifier: Programming Language :: Rust
|
|
5
5
|
Classifier: Programming Language :: Python :: 3
|
|
6
6
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
@@ -25,7 +25,7 @@ mod-trace is a small Rust CLI for answering a practical question:
|
|
|
25
25
|
What is inside this model file?
|
|
26
26
|
```
|
|
27
27
|
|
|
28
|
-
It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models,
|
|
28
|
+
It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models, ONNX `.onnx` graphs, and PyTorch `.pt`/`.pth` checkpoints, then report structure, size, parameters, operator mix, rough inference cost, and changes between versions. All formats are read natively — no Python, framework, or runtime needed (CatBoost `--deep` is the one optional exception). The PyTorch reader is static: it sizes/names tensors and fingerprints weights without decoding exact shapes.
|
|
29
29
|
|
|
30
30
|
The most useful command is `explain-diff`, which says in plain English what changed between two model versions:
|
|
31
31
|
|
|
@@ -8,7 +8,7 @@ mod-trace is a small Rust CLI for answering a practical question:
|
|
|
8
8
|
What is inside this model file?
|
|
9
9
|
```
|
|
10
10
|
|
|
11
|
-
It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models,
|
|
11
|
+
It can inspect real artifacts such as CatBoost `.cbm` files, LightGBM `.txt`/`.lgb` text models, ONNX `.onnx` graphs, and PyTorch `.pt`/`.pth` checkpoints, then report structure, size, parameters, operator mix, rough inference cost, and changes between versions. All formats are read natively — no Python, framework, or runtime needed (CatBoost `--deep` is the one optional exception). The PyTorch reader is static: it sizes/names tensors and fingerprints weights without decoding exact shapes.
|
|
12
12
|
|
|
13
13
|
The most useful command is `explain-diff`, which says in plain English what changed between two model versions:
|
|
14
14
|
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# PyTorch example models
|
|
2
|
+
|
|
3
|
+
Synthetic `torch.save` artifacts for trying `mod-trace` on PyTorch with **no
|
|
4
|
+
torch and no Python** — mod-trace reads the `.pt` zip (pickled structure + raw
|
|
5
|
+
tensor storages) statically.
|
|
6
|
+
|
|
7
|
+
| Files | What they show |
|
|
8
|
+
|-------|----------------|
|
|
9
|
+
| `mlp_v1.pt` vs `mlp_v2.pt` | Same 2-layer MLP, hidden size 32 → 64 (parameter count ~doubles, same layer names). |
|
|
10
|
+
|
|
11
|
+
## Try it
|
|
12
|
+
|
|
13
|
+
```bash
|
|
14
|
+
mod-trace inspect examples/pytorch/mlp_v1.pt
|
|
15
|
+
mod-trace explain-diff examples/pytorch/mlp_v1.pt examples/pytorch/mlp_v2.pt
|
|
16
|
+
mod-trace check --max-parameter-growth 30% examples/pytorch/mlp_v1.pt examples/pytorch/mlp_v2.pt
|
|
17
|
+
mod-trace inspect --json examples/pytorch/mlp_v1.pt
|
|
18
|
+
```
|
|
19
|
+
|
|
20
|
+
## What it reads (and what it doesn't)
|
|
21
|
+
|
|
22
|
+
Reads, statically: file size, tensor/storage count, **estimated parameter count**
|
|
23
|
+
(from storage bytes ÷ dtype), **dominant dtype**, **recovered parameter/layer
|
|
24
|
+
names** (`fc1.weight`, …), and fingerprints (a sampled weight fingerprint that
|
|
25
|
+
changes on a retrain/finetune).
|
|
26
|
+
|
|
27
|
+
Does **not** decode exact per-tensor shapes — that would need a full pickle
|
|
28
|
+
interpreter. Same static/heuristic philosophy as the CatBoost and ONNX readers.
|
|
29
|
+
|
|
30
|
+
## Regenerate
|
|
31
|
+
|
|
32
|
+
```bash
|
|
33
|
+
python -m pip install torch
|
|
34
|
+
python examples/pytorch/generate_demo_models.py
|
|
35
|
+
```
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""Generate the synthetic PyTorch demo models used by the README examples.
|
|
2
|
+
|
|
3
|
+
Fully synthetic (no real data). Run:
|
|
4
|
+
|
|
5
|
+
python -m pip install torch
|
|
6
|
+
python examples/pytorch/generate_demo_models.py
|
|
7
|
+
|
|
8
|
+
Produces, in this directory:
|
|
9
|
+
mlp_v1.pt / mlp_v2.pt -> same 2-layer MLP, different hidden size (32 vs 64)
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import os
|
|
13
|
+
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
|
|
17
|
+
HERE = os.path.dirname(os.path.abspath(__file__))
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Net(nn.Module):
|
|
21
|
+
def __init__(self, hidden):
|
|
22
|
+
super().__init__()
|
|
23
|
+
self.fc1 = nn.Linear(16, hidden)
|
|
24
|
+
self.fc2 = nn.Linear(hidden, 4)
|
|
25
|
+
|
|
26
|
+
def forward(self, x):
|
|
27
|
+
return self.fc2(torch.relu(self.fc1(x)))
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
if __name__ == "__main__":
|
|
31
|
+
torch.manual_seed(0)
|
|
32
|
+
torch.save(Net(32).state_dict(), os.path.join(HERE, "mlp_v1.pt"))
|
|
33
|
+
torch.manual_seed(1)
|
|
34
|
+
torch.save(Net(64).state_dict(), os.path.join(HERE, "mlp_v2.pt"))
|
|
35
|
+
print("wrote mlp_v1.pt and mlp_v2.pt")
|
|
Binary file
|
|
Binary file
|
|
@@ -4,6 +4,7 @@ mod explain;
|
|
|
4
4
|
mod lgbm;
|
|
5
5
|
mod model;
|
|
6
6
|
mod onnx;
|
|
7
|
+
mod pt;
|
|
7
8
|
mod tensor;
|
|
8
9
|
|
|
9
10
|
use std::collections::BTreeSet;
|
|
@@ -54,6 +55,7 @@ fn run() -> Result<(), String> {
|
|
|
54
55
|
Some("onnx") => onnx_cmd(&args.rest),
|
|
55
56
|
Some("catboost") | Some("cbm") => catboost_cmd(&args.rest),
|
|
56
57
|
Some("lightgbm") | Some("lgbm") => lgbm_cmd(&args.rest),
|
|
58
|
+
Some("pytorch") | Some("pt") => pt_cmd(&args.rest),
|
|
57
59
|
Some("validate") => validate_model_cmd(&args.rest),
|
|
58
60
|
Some("tensor-inspect") => inspect_model_cmd(&args.rest),
|
|
59
61
|
Some("run") => run_model_cmd(&args.rest),
|
|
@@ -93,6 +95,7 @@ struct BuiltInDoctorReport {
|
|
|
93
95
|
catboost_metadata: bool,
|
|
94
96
|
lightgbm_text: bool,
|
|
95
97
|
onnx_static_graph: bool,
|
|
98
|
+
pytorch_zip: bool,
|
|
96
99
|
json_tensor_plans: bool,
|
|
97
100
|
}
|
|
98
101
|
|
|
@@ -125,6 +128,7 @@ fn doctor_report() -> DoctorReport {
|
|
|
125
128
|
catboost_metadata: true,
|
|
126
129
|
lightgbm_text: true,
|
|
127
130
|
onnx_static_graph: true,
|
|
131
|
+
pytorch_zip: true,
|
|
128
132
|
json_tensor_plans: true,
|
|
129
133
|
},
|
|
130
134
|
optional_python: PythonDoctorReport {
|
|
@@ -168,6 +172,10 @@ fn print_doctor_report(report: &DoctorReport) {
|
|
|
168
172
|
" ONNX static graph: {}",
|
|
169
173
|
availability(report.built_in.onnx_static_graph)
|
|
170
174
|
);
|
|
175
|
+
println!(
|
|
176
|
+
" PyTorch .pt/.pth (zip): {}",
|
|
177
|
+
availability(report.built_in.pytorch_zip)
|
|
178
|
+
);
|
|
171
179
|
println!(
|
|
172
180
|
" JSON tensor plans: {}",
|
|
173
181
|
availability(report.built_in.json_tensor_plans)
|
|
@@ -192,15 +200,15 @@ fn print_doctor_report(report: &DoctorReport) {
|
|
|
192
200
|
println!();
|
|
193
201
|
println!("Available commands:");
|
|
194
202
|
println!(
|
|
195
|
-
" inspect .cbm/.lgb/.onnx/.json: {}",
|
|
203
|
+
" inspect .cbm/.lgb/.onnx/.pt/.json: {}",
|
|
196
204
|
available_unavailable(report.commands.inspect_artifacts)
|
|
197
205
|
);
|
|
198
206
|
println!(
|
|
199
|
-
" diff .cbm/.lgb/.onnx: {}",
|
|
207
|
+
" diff .cbm/.lgb/.onnx/.pt: {}",
|
|
200
208
|
available_unavailable(report.commands.diff_artifacts)
|
|
201
209
|
);
|
|
202
210
|
println!(
|
|
203
|
-
" explain-diff .cbm/.lgb/.onnx: {}",
|
|
211
|
+
" explain-diff .cbm/.lgb/.onnx/.pt: {}",
|
|
204
212
|
available_unavailable(report.commands.diff_artifacts)
|
|
205
213
|
);
|
|
206
214
|
println!(
|
|
@@ -301,6 +309,10 @@ fn inspect_cmd(args: &[String]) -> Result<(), String> {
|
|
|
301
309
|
Err("--deep inspection is currently only supported for CatBoost artifacts.".to_string())
|
|
302
310
|
}
|
|
303
311
|
ArtifactKind::Onnx => onnx_cmd(args),
|
|
312
|
+
ArtifactKind::PyTorch if deep => {
|
|
313
|
+
Err("--deep inspection is currently only supported for CatBoost artifacts.".to_string())
|
|
314
|
+
}
|
|
315
|
+
ArtifactKind::PyTorch => pt_cmd(args),
|
|
304
316
|
ArtifactKind::Json if deep => {
|
|
305
317
|
Err("--deep inspection is currently only supported for CatBoost artifacts.".to_string())
|
|
306
318
|
}
|
|
@@ -312,7 +324,7 @@ fn inspect_cmd(args: &[String]) -> Result<(), String> {
|
|
|
312
324
|
}
|
|
313
325
|
ArtifactKind::Json => inspect_model_cmd(args),
|
|
314
326
|
ArtifactKind::Unknown => Err(format!(
|
|
315
|
-
"unsupported artifact type for `{}`. Try .cbm, .lgb, .onnx, or .json.",
|
|
327
|
+
"unsupported artifact type for `{}`. Try .cbm, .lgb, .onnx, .pt/.pth, or .json.",
|
|
316
328
|
target
|
|
317
329
|
)),
|
|
318
330
|
}
|
|
@@ -378,6 +390,13 @@ fn diff_cmd(args: &[String]) -> Result<(), String> {
|
|
|
378
390
|
}
|
|
379
391
|
diff_onnx(&paths[0], &paths[1], json)
|
|
380
392
|
}
|
|
393
|
+
(ArtifactKind::PyTorch, ArtifactKind::PyTorch) => {
|
|
394
|
+
if deep {
|
|
395
|
+
println!("Note: --deep is currently only used for CatBoost artifacts.");
|
|
396
|
+
println!();
|
|
397
|
+
}
|
|
398
|
+
diff_pt(&paths[0], &paths[1], json)
|
|
399
|
+
}
|
|
381
400
|
(ArtifactKind::Json, ArtifactKind::Json) => Err(
|
|
382
401
|
"tensor plan diff is not supported yet. Use trace, compare, why, or validate instead."
|
|
383
402
|
.to_string(),
|
|
@@ -462,8 +481,9 @@ fn check_cmd(args: &[String]) -> Result<(), String> {
|
|
|
462
481
|
check_lgbm(&paths[0], &paths[1], &options)
|
|
463
482
|
}
|
|
464
483
|
(ArtifactKind::Onnx, ArtifactKind::Onnx) => check_onnx(&paths[0], &paths[1], &options),
|
|
484
|
+
(ArtifactKind::PyTorch, ArtifactKind::PyTorch) => check_pt(&paths[0], &paths[1], &options),
|
|
465
485
|
(ArtifactKind::Json, ArtifactKind::Json) => {
|
|
466
|
-
Err("check supports CatBoost, LightGBM, and ONNX artifacts, not tensor plan JSON."
|
|
486
|
+
Err("check supports CatBoost, LightGBM, ONNX, and PyTorch artifacts, not tensor plan JSON."
|
|
467
487
|
.to_string())
|
|
468
488
|
}
|
|
469
489
|
(left, right) => Err(format!(
|
|
@@ -715,6 +735,7 @@ fn explain_cmd(args: &[String]) -> Result<(), String> {
|
|
|
715
735
|
ArtifactKind::Onnx => return explain_onnx_cmd(target),
|
|
716
736
|
ArtifactKind::CatBoost => return explain_catboost_cmd(target),
|
|
717
737
|
ArtifactKind::LightGbm => return lgbm_cmd(&[target.to_string()]),
|
|
738
|
+
ArtifactKind::PyTorch => return pt_cmd(&[target.to_string()]),
|
|
718
739
|
ArtifactKind::Json => return explain_model_cmd(&[target.to_string()]),
|
|
719
740
|
ArtifactKind::Unknown => {}
|
|
720
741
|
}
|
|
@@ -1013,8 +1034,9 @@ fn explain_diff_cmd(args: &[String]) -> Result<(), String> {
|
|
|
1013
1034
|
(ArtifactKind::Onnx, ArtifactKind::Onnx) => explain_diff_onnx(old, new),
|
|
1014
1035
|
(ArtifactKind::LightGbm, ArtifactKind::LightGbm) => explain_diff_lgbm(old, new),
|
|
1015
1036
|
(ArtifactKind::CatBoost, ArtifactKind::CatBoost) => explain_diff_catboost(old, new),
|
|
1037
|
+
(ArtifactKind::PyTorch, ArtifactKind::PyTorch) => explain_diff_pt(old, new),
|
|
1016
1038
|
(left, right) => Err(format!(
|
|
1017
|
-
"explain-diff needs two artifacts of the same supported type (.onnx, .cbm, .lgb): {} vs {}",
|
|
1039
|
+
"explain-diff needs two artifacts of the same supported type (.onnx, .cbm, .lgb, .pt): {} vs {}",
|
|
1018
1040
|
left.label(),
|
|
1019
1041
|
right.label()
|
|
1020
1042
|
)),
|
|
@@ -1326,6 +1348,19 @@ fn explain_diff_catboost(old_path: &str, new_path: &str) -> Result<(), String> {
|
|
|
1326
1348
|
let old = cbm::inspect(old_path)?;
|
|
1327
1349
|
let new = cbm::inspect(new_path)?;
|
|
1328
1350
|
|
|
1351
|
+
let old_features = old.feature_candidates.iter().collect::<BTreeSet<_>>();
|
|
1352
|
+
let new_features = new.feature_candidates.iter().collect::<BTreeSet<_>>();
|
|
1353
|
+
let added = new_features
|
|
1354
|
+
.difference(&old_features)
|
|
1355
|
+
.map(|name| name.as_str())
|
|
1356
|
+
.collect::<Vec<_>>();
|
|
1357
|
+
let removed = old_features
|
|
1358
|
+
.difference(&new_features)
|
|
1359
|
+
.map(|name| name.as_str())
|
|
1360
|
+
.collect::<Vec<_>>();
|
|
1361
|
+
let names_known = !old.feature_candidates.is_empty() || !new.feature_candidates.is_empty();
|
|
1362
|
+
let config_same = catboost_training_config_same(&old, &new);
|
|
1363
|
+
|
|
1329
1364
|
println!("Model Change Explanation");
|
|
1330
1365
|
println!("------------------------");
|
|
1331
1366
|
println!("Type: CatBoost");
|
|
@@ -1334,21 +1369,68 @@ fn explain_diff_catboost(old_path: &str, new_path: &str) -> Result<(), String> {
|
|
|
1334
1369
|
println!();
|
|
1335
1370
|
println!("Architecture:");
|
|
1336
1371
|
println!(
|
|
1337
|
-
" Trees: {} -> {}",
|
|
1372
|
+
" Trees: {} -> {} ({})",
|
|
1338
1373
|
opt_num(old.iterations.map(|value| value as usize)),
|
|
1339
|
-
opt_num(new.iterations.map(|value| value as usize))
|
|
1374
|
+
opt_num(new.iterations.map(|value| value as usize)),
|
|
1375
|
+
match (old.iterations, new.iterations) {
|
|
1376
|
+
(Some(o), Some(n)) => growth_label(o as usize, n as usize),
|
|
1377
|
+
_ => "unknown".to_string(),
|
|
1378
|
+
}
|
|
1340
1379
|
);
|
|
1341
1380
|
println!(
|
|
1342
|
-
" Depth: {} -> {}",
|
|
1381
|
+
" Depth: {} -> {} ({})",
|
|
1343
1382
|
opt_num(old.depth.map(|value| value as usize)),
|
|
1344
|
-
opt_num(new.depth.map(|value| value as usize))
|
|
1383
|
+
opt_num(new.depth.map(|value| value as usize)),
|
|
1384
|
+
same_or_changed(old.depth == new.depth)
|
|
1345
1385
|
);
|
|
1386
|
+
let feature_note = if !names_known {
|
|
1387
|
+
"names not recovered".to_string()
|
|
1388
|
+
} else if added.is_empty() && removed.is_empty() {
|
|
1389
|
+
"same set".to_string()
|
|
1390
|
+
} else {
|
|
1391
|
+
format!("+{} added, -{} removed", added.len(), removed.len())
|
|
1392
|
+
};
|
|
1346
1393
|
println!(
|
|
1347
|
-
" Features (recovered): {} -> {}",
|
|
1394
|
+
" Features (recovered): {} -> {} ({})",
|
|
1348
1395
|
old.feature_candidates.len(),
|
|
1349
|
-
new.feature_candidates.len()
|
|
1396
|
+
new.feature_candidates.len(),
|
|
1397
|
+
feature_note
|
|
1350
1398
|
);
|
|
1399
|
+
print_lgbm_feature_list(" added: ", &added, 8);
|
|
1400
|
+
print_lgbm_feature_list(" removed:", &removed, 8);
|
|
1351
1401
|
println!();
|
|
1402
|
+
println!("Training config:");
|
|
1403
|
+
println!(
|
|
1404
|
+
" Loss: {} -> {} ({})",
|
|
1405
|
+
old.loss_function.as_deref().unwrap_or("unknown"),
|
|
1406
|
+
new.loss_function.as_deref().unwrap_or("unknown"),
|
|
1407
|
+
same_or_changed(old.loss_function == new.loss_function)
|
|
1408
|
+
);
|
|
1409
|
+
println!(
|
|
1410
|
+
" Eval metric: {} -> {} ({})",
|
|
1411
|
+
old.eval_metric.as_deref().unwrap_or("unknown"),
|
|
1412
|
+
new.eval_metric.as_deref().unwrap_or("unknown"),
|
|
1413
|
+
same_or_changed(old.eval_metric == new.eval_metric)
|
|
1414
|
+
);
|
|
1415
|
+
println!(
|
|
1416
|
+
" Learning rate: {} -> {} ({})",
|
|
1417
|
+
opt_float(old.learning_rate),
|
|
1418
|
+
opt_float(new.learning_rate),
|
|
1419
|
+
same_or_changed(old.learning_rate == new.learning_rate)
|
|
1420
|
+
);
|
|
1421
|
+
println!(
|
|
1422
|
+
" Grow policy: {} -> {} ({})",
|
|
1423
|
+
old.grow_policy.as_deref().unwrap_or("unknown"),
|
|
1424
|
+
new.grow_policy.as_deref().unwrap_or("unknown"),
|
|
1425
|
+
same_or_changed(old.grow_policy == new.grow_policy)
|
|
1426
|
+
);
|
|
1427
|
+
println!();
|
|
1428
|
+
println!(
|
|
1429
|
+
"File size: {} -> {} ({})",
|
|
1430
|
+
format_bytes(old.bytes),
|
|
1431
|
+
format_bytes(new.bytes),
|
|
1432
|
+
growth_label(old.bytes, new.bytes)
|
|
1433
|
+
);
|
|
1352
1434
|
match (old.estimated_leaf_values(), new.estimated_leaf_values()) {
|
|
1353
1435
|
(Some(o), Some(n)) => println!(
|
|
1354
1436
|
"Estimated leaf-slot growth: {}",
|
|
@@ -1363,7 +1445,20 @@ fn explain_diff_catboost(old_path: &str, new_path: &str) -> Result<(), String> {
|
|
|
1363
1445
|
println!("Learned state: unchanged");
|
|
1364
1446
|
}
|
|
1365
1447
|
println!();
|
|
1366
|
-
println!("
|
|
1448
|
+
println!("Summary:");
|
|
1449
|
+
let features_same = added.is_empty() && removed.is_empty();
|
|
1450
|
+
let descriptor = if !features_same {
|
|
1451
|
+
"feature set changed"
|
|
1452
|
+
} else if !config_same {
|
|
1453
|
+
"training config changed"
|
|
1454
|
+
} else if old.iterations != new.iterations || old.depth != new.depth {
|
|
1455
|
+
"same spec, retrained with different tree count/depth"
|
|
1456
|
+
} else {
|
|
1457
|
+
"same spec and features, retrained"
|
|
1458
|
+
};
|
|
1459
|
+
println!(" {descriptor}.");
|
|
1460
|
+
println!();
|
|
1461
|
+
println!("Note: heuristic summary; run `diff --deep` for exact split/leaf changes.");
|
|
1367
1462
|
|
|
1368
1463
|
Ok(())
|
|
1369
1464
|
}
|
|
@@ -1692,6 +1787,7 @@ enum ArtifactKind {
|
|
|
1692
1787
|
CatBoost,
|
|
1693
1788
|
LightGbm,
|
|
1694
1789
|
Onnx,
|
|
1790
|
+
PyTorch,
|
|
1695
1791
|
Json,
|
|
1696
1792
|
Unknown,
|
|
1697
1793
|
}
|
|
@@ -1702,6 +1798,7 @@ impl ArtifactKind {
|
|
|
1702
1798
|
ArtifactKind::CatBoost => "CatBoost",
|
|
1703
1799
|
ArtifactKind::LightGbm => "LightGBM",
|
|
1704
1800
|
ArtifactKind::Onnx => "ONNX",
|
|
1801
|
+
ArtifactKind::PyTorch => "PyTorch",
|
|
1705
1802
|
ArtifactKind::Json => "tensor plan JSON",
|
|
1706
1803
|
ArtifactKind::Unknown => "unknown",
|
|
1707
1804
|
}
|
|
@@ -1718,6 +1815,7 @@ fn artifact_kind(path: &str) -> ArtifactKind {
|
|
|
1718
1815
|
Some("cbm") => return ArtifactKind::CatBoost,
|
|
1719
1816
|
Some("lgb") => return ArtifactKind::LightGbm,
|
|
1720
1817
|
Some("onnx") => return ArtifactKind::Onnx,
|
|
1818
|
+
Some("pt") | Some("pth") => return ArtifactKind::PyTorch,
|
|
1721
1819
|
Some("json") => return ArtifactKind::Json,
|
|
1722
1820
|
_ => {}
|
|
1723
1821
|
}
|
|
@@ -2989,6 +3087,332 @@ fn check_lgbm(old_path: &str, new_path: &str, options: &CheckOptions) -> Result<
|
|
|
2989
3087
|
finish_check(&checks)
|
|
2990
3088
|
}
|
|
2991
3089
|
|
|
3090
|
+
fn pt_cmd(args: &[String]) -> Result<(), String> {
|
|
3091
|
+
if args.is_empty() {
|
|
3092
|
+
return Err(
|
|
3093
|
+
"usage: mod-trace pytorch [--json] [--limit 20] <model.pt|model.pth> [more...]"
|
|
3094
|
+
.to_string(),
|
|
3095
|
+
);
|
|
3096
|
+
}
|
|
3097
|
+
|
|
3098
|
+
let mut json = false;
|
|
3099
|
+
let mut limit = DEFAULT_LIMIT;
|
|
3100
|
+
let mut paths = Vec::new();
|
|
3101
|
+
let mut i = 0usize;
|
|
3102
|
+
while i < args.len() {
|
|
3103
|
+
match args[i].as_str() {
|
|
3104
|
+
"--json" => {
|
|
3105
|
+
json = true;
|
|
3106
|
+
i += 1;
|
|
3107
|
+
}
|
|
3108
|
+
"--deep" => {
|
|
3109
|
+
return Err(
|
|
3110
|
+
"--deep inspection is currently only supported for CatBoost artifacts."
|
|
3111
|
+
.to_string(),
|
|
3112
|
+
);
|
|
3113
|
+
}
|
|
3114
|
+
"--limit" => {
|
|
3115
|
+
let value = args
|
|
3116
|
+
.get(i + 1)
|
|
3117
|
+
.ok_or_else(|| "--limit needs a number".to_string())?;
|
|
3118
|
+
limit = value
|
|
3119
|
+
.parse::<usize>()
|
|
3120
|
+
.map_err(|err| format!("parse --limit: {err}"))?;
|
|
3121
|
+
i += 2;
|
|
3122
|
+
}
|
|
3123
|
+
value => {
|
|
3124
|
+
paths.push(value.to_string());
|
|
3125
|
+
i += 1;
|
|
3126
|
+
}
|
|
3127
|
+
}
|
|
3128
|
+
}
|
|
3129
|
+
|
|
3130
|
+
if paths.is_empty() {
|
|
3131
|
+
return Err(
|
|
3132
|
+
"usage: mod-trace pytorch [--json] [--limit 20] <model.pt|model.pth> [more...]"
|
|
3133
|
+
.to_string(),
|
|
3134
|
+
);
|
|
3135
|
+
}
|
|
3136
|
+
|
|
3137
|
+
if json {
|
|
3138
|
+
let reports = paths
|
|
3139
|
+
.iter()
|
|
3140
|
+
.map(|path| pt::inspect(path))
|
|
3141
|
+
.collect::<Result<Vec<_>, _>>()?;
|
|
3142
|
+
if reports.len() == 1 {
|
|
3143
|
+
print_json(&reports[0])?;
|
|
3144
|
+
} else {
|
|
3145
|
+
print_json(&reports)?;
|
|
3146
|
+
}
|
|
3147
|
+
return Ok(());
|
|
3148
|
+
}
|
|
3149
|
+
|
|
3150
|
+
for (index, path) in paths.iter().enumerate() {
|
|
3151
|
+
if index > 0 {
|
|
3152
|
+
println!();
|
|
3153
|
+
}
|
|
3154
|
+
let report = pt::inspect(path)?;
|
|
3155
|
+
print_pt_report(&report, limit);
|
|
3156
|
+
}
|
|
3157
|
+
|
|
3158
|
+
Ok(())
|
|
3159
|
+
}
|
|
3160
|
+
|
|
3161
|
+
fn print_pt_report(report: &pt::PtReport, limit: usize) {
|
|
3162
|
+
println!("PyTorch Model Summary");
|
|
3163
|
+
println!("---------------------");
|
|
3164
|
+
println!("Model: {}", report.path);
|
|
3165
|
+
println!("Format: {}", report.format());
|
|
3166
|
+
println!("File size: {}", format_bytes(report.bytes));
|
|
3167
|
+
if let Some(version) = report.torch_version.as_deref() {
|
|
3168
|
+
println!("Serialization version: {version}");
|
|
3169
|
+
}
|
|
3170
|
+
println!();
|
|
3171
|
+
println!("Structure:");
|
|
3172
|
+
println!(" Tensors (storages): {}", report.tensor_count);
|
|
3173
|
+
println!(
|
|
3174
|
+
" Parameters (est): {}",
|
|
3175
|
+
format_count_human(report.estimated_parameter_count as usize)
|
|
3176
|
+
);
|
|
3177
|
+
println!(
|
|
3178
|
+
" Parameter bytes: {}",
|
|
3179
|
+
format_bytes(report.total_parameter_bytes as usize)
|
|
3180
|
+
);
|
|
3181
|
+
print_optional("Dominant dtype", report.dominant_dtype.as_deref());
|
|
3182
|
+
println!();
|
|
3183
|
+
println!("Parameter-like Internals:");
|
|
3184
|
+
println!(
|
|
3185
|
+
" Full artifact fingerprint: {}",
|
|
3186
|
+
format_hex(report.file_fingerprint)
|
|
3187
|
+
);
|
|
3188
|
+
println!(
|
|
3189
|
+
" Learned-state fingerprint (sampled): {}",
|
|
3190
|
+
format_hex(report.learned_state_fingerprint)
|
|
3191
|
+
);
|
|
3192
|
+
println!(
|
|
3193
|
+
" Note: parameter count is estimated from storage bytes / dtype; tensor shapes are not decoded."
|
|
3194
|
+
);
|
|
3195
|
+
println!();
|
|
3196
|
+
println!("Recovered Parameter Names:");
|
|
3197
|
+
if report.param_names.is_empty() {
|
|
3198
|
+
println!(" none recovered from the pickle");
|
|
3199
|
+
} else {
|
|
3200
|
+
println!(" found: {}", report.param_names.len());
|
|
3201
|
+
for name in report.param_names.iter().take(limit) {
|
|
3202
|
+
println!(" {name}");
|
|
3203
|
+
}
|
|
3204
|
+
if report.param_names.len() > limit {
|
|
3205
|
+
println!(" ... {} more", report.param_names.len() - limit);
|
|
3206
|
+
}
|
|
3207
|
+
}
|
|
3208
|
+
}
|
|
3209
|
+
|
|
3210
|
+
fn diff_pt(old_path: &str, new_path: &str, json: bool) -> Result<(), String> {
|
|
3211
|
+
let old = pt::inspect(old_path)?;
|
|
3212
|
+
let new = pt::inspect(new_path)?;
|
|
3213
|
+
|
|
3214
|
+
if json {
|
|
3215
|
+
print_json(&serde_json::json!({
|
|
3216
|
+
"type": "pytorch",
|
|
3217
|
+
"old": old,
|
|
3218
|
+
"new": new,
|
|
3219
|
+
}))?;
|
|
3220
|
+
return Ok(());
|
|
3221
|
+
}
|
|
3222
|
+
|
|
3223
|
+
println!("Model Diff");
|
|
3224
|
+
println!("----------");
|
|
3225
|
+
println!("Type: PyTorch");
|
|
3226
|
+
println!("Old: {}", old.path);
|
|
3227
|
+
println!("New: {}", new.path);
|
|
3228
|
+
println!();
|
|
3229
|
+
println!("Structure:");
|
|
3230
|
+
print_diff_bytes("File size", old.bytes, new.bytes);
|
|
3231
|
+
print_diff_usize("Tensors (storages)", old.tensor_count, new.tensor_count);
|
|
3232
|
+
print_diff_usize(
|
|
3233
|
+
"Parameters (est)",
|
|
3234
|
+
old.estimated_parameter_count as usize,
|
|
3235
|
+
new.estimated_parameter_count as usize,
|
|
3236
|
+
);
|
|
3237
|
+
print_diff_bytes(
|
|
3238
|
+
"Parameter bytes",
|
|
3239
|
+
old.total_parameter_bytes as usize,
|
|
3240
|
+
new.total_parameter_bytes as usize,
|
|
3241
|
+
);
|
|
3242
|
+
print_diff_opt_str(
|
|
3243
|
+
"Dominant dtype",
|
|
3244
|
+
old.dominant_dtype.as_deref(),
|
|
3245
|
+
new.dominant_dtype.as_deref(),
|
|
3246
|
+
);
|
|
3247
|
+
println!();
|
|
3248
|
+
println!("Parameter-like Internals:");
|
|
3249
|
+
print_fingerprint_diff(
|
|
3250
|
+
"Full artifact fingerprint",
|
|
3251
|
+
Some(old.file_fingerprint),
|
|
3252
|
+
Some(new.file_fingerprint),
|
|
3253
|
+
);
|
|
3254
|
+
print_fingerprint_diff(
|
|
3255
|
+
"Learned-state fingerprint (sampled)",
|
|
3256
|
+
Some(old.learned_state_fingerprint),
|
|
3257
|
+
Some(new.learned_state_fingerprint),
|
|
3258
|
+
);
|
|
3259
|
+
println!();
|
|
3260
|
+
println!("Parameter Names:");
|
|
3261
|
+
print_diff_usize("Recovered names", old.param_names.len(), new.param_names.len());
|
|
3262
|
+
print_feature_name_changes(&old.param_names, &new.param_names, 12);
|
|
3263
|
+
|
|
3264
|
+
Ok(())
|
|
3265
|
+
}
|
|
3266
|
+
|
|
3267
|
+
fn explain_diff_pt(old_path: &str, new_path: &str) -> Result<(), String> {
|
|
3268
|
+
let old = pt::inspect(old_path)?;
|
|
3269
|
+
let new = pt::inspect(new_path)?;
|
|
3270
|
+
|
|
3271
|
+
let old_names = old.param_names.iter().collect::<BTreeSet<_>>();
|
|
3272
|
+
let new_names = new.param_names.iter().collect::<BTreeSet<_>>();
|
|
3273
|
+
let added = new_names
|
|
3274
|
+
.difference(&old_names)
|
|
3275
|
+
.map(|name| name.as_str())
|
|
3276
|
+
.collect::<Vec<_>>();
|
|
3277
|
+
let removed = old_names
|
|
3278
|
+
.difference(&new_names)
|
|
3279
|
+
.map(|name| name.as_str())
|
|
3280
|
+
.collect::<Vec<_>>();
|
|
3281
|
+
let names_known = !old.param_names.is_empty() || !new.param_names.is_empty();
|
|
3282
|
+
|
|
3283
|
+
println!("Model Change Explanation");
|
|
3284
|
+
println!("------------------------");
|
|
3285
|
+
println!("Type: PyTorch");
|
|
3286
|
+
println!("Old: {}", old.path);
|
|
3287
|
+
println!("New: {}", new.path);
|
|
3288
|
+
println!();
|
|
3289
|
+
println!("Architecture:");
|
|
3290
|
+
println!(
|
|
3291
|
+
" Tensors: {} -> {} ({})",
|
|
3292
|
+
old.tensor_count,
|
|
3293
|
+
new.tensor_count,
|
|
3294
|
+
growth_label(old.tensor_count, new.tensor_count)
|
|
3295
|
+
);
|
|
3296
|
+
println!(
|
|
3297
|
+
" Parameters (est): {} -> {} ({})",
|
|
3298
|
+
format_count_human(old.estimated_parameter_count as usize),
|
|
3299
|
+
format_count_human(new.estimated_parameter_count as usize),
|
|
3300
|
+
growth_label(
|
|
3301
|
+
old.estimated_parameter_count as usize,
|
|
3302
|
+
new.estimated_parameter_count as usize
|
|
3303
|
+
)
|
|
3304
|
+
);
|
|
3305
|
+
println!(
|
|
3306
|
+
" Dominant dtype: {} -> {} ({})",
|
|
3307
|
+
old.dominant_dtype.as_deref().unwrap_or("unknown"),
|
|
3308
|
+
new.dominant_dtype.as_deref().unwrap_or("unknown"),
|
|
3309
|
+
same_or_changed(old.dominant_dtype == new.dominant_dtype)
|
|
3310
|
+
);
|
|
3311
|
+
let name_note = if !names_known {
|
|
3312
|
+
"names not recovered".to_string()
|
|
3313
|
+
} else if added.is_empty() && removed.is_empty() {
|
|
3314
|
+
"same set".to_string()
|
|
3315
|
+
} else {
|
|
3316
|
+
format!("+{} added, -{} removed", added.len(), removed.len())
|
|
3317
|
+
};
|
|
3318
|
+
println!(
|
|
3319
|
+
" Param names: {} -> {} ({})",
|
|
3320
|
+
old.param_names.len(),
|
|
3321
|
+
new.param_names.len(),
|
|
3322
|
+
name_note
|
|
3323
|
+
);
|
|
3324
|
+
print_lgbm_feature_list(" added: ", &added, 8);
|
|
3325
|
+
print_lgbm_feature_list(" removed:", &removed, 8);
|
|
3326
|
+
println!();
|
|
3327
|
+
println!(
|
|
3328
|
+
"File size: {} -> {} ({})",
|
|
3329
|
+
format_bytes(old.bytes),
|
|
3330
|
+
format_bytes(new.bytes),
|
|
3331
|
+
growth_label(old.bytes, new.bytes)
|
|
3332
|
+
);
|
|
3333
|
+
println!();
|
|
3334
|
+
if old.learned_state_fingerprint != new.learned_state_fingerprint {
|
|
3335
|
+
println!("Learned state: CHANGED (sampled weight bytes differ - a real retrain/finetune)");
|
|
3336
|
+
} else {
|
|
3337
|
+
println!("Learned state: unchanged (sampled weight bytes identical)");
|
|
3338
|
+
}
|
|
3339
|
+
println!();
|
|
3340
|
+
println!("Summary:");
|
|
3341
|
+
let names_same = added.is_empty() && removed.is_empty();
|
|
3342
|
+
let descriptor = if !names_known {
|
|
3343
|
+
"weights compared (parameter names not recovered)"
|
|
3344
|
+
} else if !names_same {
|
|
3345
|
+
"parameter set changed (layers added/removed)"
|
|
3346
|
+
} else if old.estimated_parameter_count != new.estimated_parameter_count {
|
|
3347
|
+
"same layers, parameter count changed (resized)"
|
|
3348
|
+
} else {
|
|
3349
|
+
"same architecture, retrained/finetuned"
|
|
3350
|
+
};
|
|
3351
|
+
println!(
|
|
3352
|
+
" {descriptor}; params {}.",
|
|
3353
|
+
growth_label(
|
|
3354
|
+
old.estimated_parameter_count as usize,
|
|
3355
|
+
new.estimated_parameter_count as usize
|
|
3356
|
+
)
|
|
3357
|
+
);
|
|
3358
|
+
println!();
|
|
3359
|
+
println!(
|
|
3360
|
+
"Note: static read of the torch.save zip (no torch); shapes are not decoded and the weight fingerprint is sampled."
|
|
3361
|
+
);
|
|
3362
|
+
|
|
3363
|
+
Ok(())
|
|
3364
|
+
}
|
|
3365
|
+
|
|
3366
|
+
fn check_pt(old_path: &str, new_path: &str, options: &CheckOptions) -> Result<(), String> {
|
|
3367
|
+
if options.max_ops_growth_pct.is_some() {
|
|
3368
|
+
return Err("--max-ops-growth is only supported for ONNX artifacts.".to_string());
|
|
3369
|
+
}
|
|
3370
|
+
if options.fail_on_new_op {
|
|
3371
|
+
return Err("--fail-on-new-op is only supported for ONNX artifacts.".to_string());
|
|
3372
|
+
}
|
|
3373
|
+
if options.fail_on_training_config_change {
|
|
3374
|
+
return Err(
|
|
3375
|
+
"--fail-on-training-config-change is only supported for CatBoost/LightGBM artifacts."
|
|
3376
|
+
.to_string(),
|
|
3377
|
+
);
|
|
3378
|
+
}
|
|
3379
|
+
|
|
3380
|
+
let old = pt::inspect(old_path)?;
|
|
3381
|
+
let new = pt::inspect(new_path)?;
|
|
3382
|
+
let mut checks = Vec::new();
|
|
3383
|
+
|
|
3384
|
+
if let Some(max_pct) = options.max_size_growth_pct {
|
|
3385
|
+
checks.push(growth_check(
|
|
3386
|
+
"file size growth",
|
|
3387
|
+
old.bytes,
|
|
3388
|
+
new.bytes,
|
|
3389
|
+
max_pct,
|
|
3390
|
+
));
|
|
3391
|
+
}
|
|
3392
|
+
if let Some(max_pct) = options.max_parameter_growth_pct {
|
|
3393
|
+
checks.push(growth_check(
|
|
3394
|
+
"parameter growth",
|
|
3395
|
+
old.estimated_parameter_count as usize,
|
|
3396
|
+
new.estimated_parameter_count as usize,
|
|
3397
|
+
max_pct,
|
|
3398
|
+
));
|
|
3399
|
+
}
|
|
3400
|
+
if options.fail_on_feature_change {
|
|
3401
|
+
checks.push(boolean_check(
|
|
3402
|
+
"parameter names unchanged",
|
|
3403
|
+
old.param_names == new.param_names,
|
|
3404
|
+
format!(
|
|
3405
|
+
"{} -> {} recovered names",
|
|
3406
|
+
old.param_names.len(),
|
|
3407
|
+
new.param_names.len()
|
|
3408
|
+
),
|
|
3409
|
+
));
|
|
3410
|
+
}
|
|
3411
|
+
|
|
3412
|
+
print_check_report("PyTorch", &old.path, &new.path, &checks);
|
|
3413
|
+
finish_check(&checks)
|
|
3414
|
+
}
|
|
3415
|
+
|
|
2992
3416
|
fn onnx_cmd(args: &[String]) -> Result<(), String> {
|
|
2993
3417
|
if args.is_empty() {
|
|
2994
3418
|
return Err("usage: mod-trace onnx [--json] [--limit 20] <model.onnx>".to_string());
|
|
@@ -3488,13 +3912,14 @@ fn print_help() {
|
|
|
3488
3912
|
"mod-trace - inspect ML model artifacts without loading the framework\n\n\
|
|
3489
3913
|
Core usage:\n \
|
|
3490
3914
|
mod-trace doctor [--json]\n \
|
|
3491
|
-
mod-trace inspect [--deep] [--json] [--limit 20] <model.cbm|model.lgb|model.onnx|model.json>\n \
|
|
3492
|
-
mod-trace diff [--deep] [--json] <old-model> <new-model> (.cbm, .lgb/.txt
|
|
3915
|
+
mod-trace inspect [--deep] [--json] [--limit 20] <model.cbm|model.lgb|model.onnx|model.pt|model.json>\n \
|
|
3916
|
+
mod-trace diff [--deep] [--json] <old-model> <new-model> (.cbm, .lgb/.txt, .onnx, or .pt/.pth)\n \
|
|
3493
3917
|
mod-trace explain-diff <old-model> <new-model> (plain-English what changed: layers, params, cost, new ops)\n\n\
|
|
3494
3918
|
mod-trace check [--max-size-growth 20%] [--max-ops-growth 25%] [--max-parameter-growth 30%] [--fail-on-feature-change] [--fail-on-training-config-change] [--fail-on-new-op] <old-model> <new-model>\n\n\
|
|
3495
3919
|
Artifact inspectors:\n \
|
|
3496
3920
|
mod-trace catboost [--deep] [--json] [--limit 20] <model.cbm> [more.cbm...]\n \
|
|
3497
3921
|
mod-trace lightgbm [--json] [--limit 20] <model.lgb|model.txt> [more...]\n \
|
|
3922
|
+
mod-trace pytorch [--json] [--limit 20] <model.pt|model.pth> [more...]\n \
|
|
3498
3923
|
mod-trace onnx [--json] [--limit 20] <model.onnx>\n \
|
|
3499
3924
|
mod-trace explain <model.onnx|model.cbm>\n\n\
|
|
3500
3925
|
Tensor lab (secondary; see docs/tensor-lab.md):\n \
|
|
@@ -3831,7 +4256,7 @@ mod tests {
|
|
|
3831
4256
|
|
|
3832
4257
|
assert_eq!(
|
|
3833
4258
|
check_cmd(&args),
|
|
3834
|
-
Err("check supports CatBoost, LightGBM, and
|
|
4259
|
+
Err("check supports CatBoost, LightGBM, ONNX, and PyTorch artifacts, not tensor plan JSON."
|
|
3835
4260
|
.to_string())
|
|
3836
4261
|
);
|
|
3837
4262
|
}
|
|
@@ -0,0 +1,398 @@
|
|
|
1
|
+
use serde::Serialize;
|
|
2
|
+
use std::fs::File;
|
|
3
|
+
use std::io::{Read, Seek, SeekFrom};
|
|
4
|
+
use std::path::Path;
|
|
5
|
+
|
|
6
|
+
/// A static read of a PyTorch `torch.save` artifact (`.pt`/`.pth`).
|
|
7
|
+
///
|
|
8
|
+
/// Modern PyTorch saves an (uncompressed) ZIP of a pickled structure
|
|
9
|
+
/// (`data.pkl`) plus raw tensor storages under `.../data/N`. This reader walks
|
|
10
|
+
/// that zip without running Python or torch: it sizes the storages, recovers
|
|
11
|
+
/// parameter names and dtype from the pickle, and fingerprints the weights by
|
|
12
|
+
/// sampling. It does not decode exact per-tensor shapes (that needs a pickle VM).
|
|
13
|
+
#[derive(Serialize)]
|
|
14
|
+
pub struct PtReport {
|
|
15
|
+
pub path: String,
|
|
16
|
+
pub bytes: usize,
|
|
17
|
+
pub is_zip: bool,
|
|
18
|
+
pub tensor_count: usize,
|
|
19
|
+
pub total_parameter_bytes: u64,
|
|
20
|
+
pub estimated_parameter_count: u64,
|
|
21
|
+
pub dominant_dtype: Option<String>,
|
|
22
|
+
pub param_names: Vec<String>,
|
|
23
|
+
pub torch_version: Option<String>,
|
|
24
|
+
pub file_fingerprint: u64,
|
|
25
|
+
pub learned_state_fingerprint: u64,
|
|
26
|
+
}
|
|
27
|
+
|
|
28
|
+
impl PtReport {
|
|
29
|
+
pub fn format(&self) -> &'static str {
|
|
30
|
+
if self.is_zip {
|
|
31
|
+
"PyTorch (torch.save zip)"
|
|
32
|
+
} else {
|
|
33
|
+
"PyTorch (legacy pickle)"
|
|
34
|
+
}
|
|
35
|
+
}
|
|
36
|
+
}
|
|
37
|
+
|
|
38
|
+
/// Cheap sniff test on the first bytes of a file: `true` when they start with
/// the zip local-header signature `PK\x03\x04` or with the byte `0x80`.
pub fn looks_like_pt(head: &[u8]) -> bool {
    matches!(head, [b'P', b'K', 0x03, 0x04, ..] | [0x80, ..])
}
|
|
41
|
+
|
|
42
|
+
const SAMPLE_PER_STORAGE: u64 = 1 << 20; // 1 MiB sampled per tensor for the weight fingerprint
|
|
43
|
+
|
|
44
|
+
pub fn inspect(path: &str) -> Result<PtReport, String> {
|
|
45
|
+
let mut file = File::open(path).map_err(|err| format!("open {path}: {err}"))?;
|
|
46
|
+
let total = file
|
|
47
|
+
.metadata()
|
|
48
|
+
.map_err(|err| format!("stat {path}: {err}"))?
|
|
49
|
+
.len();
|
|
50
|
+
|
|
51
|
+
let mut magic = [0u8; 4];
|
|
52
|
+
let read = file.read(&mut magic).map_err(|err| format!("read {path}: {err}"))?;
|
|
53
|
+
if read >= 4 && &magic == b"PK\x03\x04" {
|
|
54
|
+
inspect_zip(&mut file, path, total)
|
|
55
|
+
} else if read >= 1 && magic[0] == 0x80 {
|
|
56
|
+
inspect_legacy(&mut file, path, total)
|
|
57
|
+
} else {
|
|
58
|
+
Err(format!(
|
|
59
|
+
"{path} does not look like a PyTorch file (no zip `PK` or pickle marker)"
|
|
60
|
+
))
|
|
61
|
+
}
|
|
62
|
+
}
|
|
63
|
+
|
|
64
|
+
/// Read exactly `len` bytes starting at absolute `offset`.
///
/// `len` often comes straight from archive headers, so the buffer is grown as
/// data actually arrives instead of being allocated up front: a corrupt header
/// announcing a huge size yields an error rather than an allocation failure.
///
/// # Errors
///
/// Returns a message when seeking or reading fails, or when fewer than `len`
/// bytes are available from `offset`.
fn read_at(file: &mut File, offset: u64, len: usize) -> Result<Vec<u8>, String> {
    // Upper bound on the speculative pre-allocation; larger reads grow on demand.
    const PREALLOC_CAP: usize = 1 << 20;

    file.seek(SeekFrom::Start(offset))
        .map_err(|err| format!("seek: {err}"))?;

    let mut buf = Vec::with_capacity(len.min(PREALLOC_CAP));
    file.by_ref()
        .take(len as u64)
        .read_to_end(&mut buf)
        .map_err(|err| format!("read: {err}"))?;

    if buf.len() != len {
        return Err(format!(
            "read: failed to fill whole buffer (wanted {len} bytes at offset {offset}, got {})",
            buf.len()
        ));
    }
    Ok(buf)
}
|
|
71
|
+
|
|
72
|
+
/// Decode a little-endian `u16` from `b` at byte offset `o`.
/// Panics if fewer than two bytes are available there.
fn le_u16(b: &[u8], o: usize) -> u16 {
    let mut raw = [0u8; 2];
    raw.copy_from_slice(&b[o..o + 2]);
    u16::from_le_bytes(raw)
}
|
|
75
|
+
/// Decode a little-endian `u32` from `b` at byte offset `o`.
/// Panics if fewer than four bytes are available there.
fn le_u32(b: &[u8], o: usize) -> u32 {
    let mut raw = [0u8; 4];
    raw.copy_from_slice(&b[o..o + 4]);
    u32::from_le_bytes(raw)
}
|
|
78
|
+
/// Decode a little-endian `u64` from `b` at byte offset `o`.
/// Panics if fewer than eight bytes are available there.
fn le_u64(b: &[u8], o: usize) -> u64 {
    let mut raw = [0u8; 8];
    raw.copy_from_slice(&b[o..o + 8]);
    u64::from_le_bytes(raw)
}
|
|
83
|
+
|
|
84
|
+
/// Offset where an entry's data begins, by reading its local file header.
|
|
85
|
+
fn entry_data_start(file: &mut File, local_offset: u64) -> Result<u64, String> {
|
|
86
|
+
let lh = read_at(file, local_offset, 30)?;
|
|
87
|
+
if &lh[0..4] != b"PK\x03\x04" {
|
|
88
|
+
return Err("bad local header".to_string());
|
|
89
|
+
}
|
|
90
|
+
let name_len = le_u16(&lh, 26) as u64;
|
|
91
|
+
let extra_len = le_u16(&lh, 28) as u64;
|
|
92
|
+
Ok(local_offset + 30 + name_len + extra_len)
|
|
93
|
+
}
|
|
94
|
+
|
|
95
|
+
/// Index of the last occurrence of the 4-byte signature `sig` in `buf`, or
/// `None` when it does not occur (including when `buf` is shorter than 4 bytes).
fn find_sig_rev(buf: &[u8], sig: &[u8; 4]) -> Option<usize> {
    buf.windows(sig.len()).rposition(|window| window == sig)
}
|
|
110
|
+
|
|
111
|
+
fn inspect_zip(file: &mut File, path: &str, total: u64) -> Result<PtReport, String> {
|
|
112
|
+
// Torch zips use data descriptors, so sizes live in the central directory.
|
|
113
|
+
let tail_len = total.min(65_557) as usize;
|
|
114
|
+
let tail = read_at(file, total - tail_len as u64, tail_len)?;
|
|
115
|
+
let eocd_rel =
|
|
116
|
+
find_sig_rev(&tail, b"PK\x05\x06").ok_or_else(|| format!("{path}: no zip EOCD record"))?;
|
|
117
|
+
let eocd = &tail[eocd_rel..];
|
|
118
|
+
|
|
119
|
+
let mut cd_size = le_u32(eocd, 12) as u64;
|
|
120
|
+
let mut cd_offset = le_u32(eocd, 16) as u64;
|
|
121
|
+
|
|
122
|
+
// ZIP64 when fields are saturated (models > 4 GiB).
|
|
123
|
+
if cd_offset == 0xFFFF_FFFF || cd_size == 0xFFFF_FFFF {
|
|
124
|
+
if eocd_rel >= 20 {
|
|
125
|
+
let loc = &tail[eocd_rel - 20..];
|
|
126
|
+
if &loc[0..4] == b"PK\x06\x07" {
|
|
127
|
+
let z64_off = le_u64(loc, 8);
|
|
128
|
+
let z64 = read_at(file, z64_off, 56)?;
|
|
129
|
+
if &z64[0..4] == b"PK\x06\x06" {
|
|
130
|
+
cd_size = le_u64(&z64, 40);
|
|
131
|
+
cd_offset = le_u64(&z64, 48);
|
|
132
|
+
}
|
|
133
|
+
}
|
|
134
|
+
}
|
|
135
|
+
}
|
|
136
|
+
|
|
137
|
+
let cd = read_at(file, cd_offset, cd_size as usize)?;
|
|
138
|
+
|
|
139
|
+
let mut pkl = Vec::new();
|
|
140
|
+
let mut torch_version = None;
|
|
141
|
+
let mut tensor_count = 0usize;
|
|
142
|
+
let mut total_parameter_bytes = 0u64;
|
|
143
|
+
let mut struct_hash = FNV_OFFSET;
|
|
144
|
+
let mut weight_hash = FNV_OFFSET;
|
|
145
|
+
|
|
146
|
+
let mut o = 0usize;
|
|
147
|
+
while o + 46 <= cd.len() && &cd[o..o + 4] == b"PK\x01\x02" {
|
|
148
|
+
let comp32 = le_u32(&cd, o + 20);
|
|
149
|
+
let mut uncomp_size = le_u32(&cd, o + 24) as u64;
|
|
150
|
+
let name_len = le_u16(&cd, o + 28) as usize;
|
|
151
|
+
let extra_len = le_u16(&cd, o + 30) as usize;
|
|
152
|
+
let comment_len = le_u16(&cd, o + 32) as usize;
|
|
153
|
+
let mut local_offset = le_u32(&cd, o + 42) as u64;
|
|
154
|
+
let name = String::from_utf8_lossy(&cd[o + 46..o + 46 + name_len]).to_string();
|
|
155
|
+
let extra = &cd[o + 46 + name_len..o + 46 + name_len + extra_len];
|
|
156
|
+
|
|
157
|
+
// ZIP64 extended info: present fields appear in order uncomp, comp, offset.
|
|
158
|
+
if uncomp_size == 0xFFFF_FFFF || local_offset == 0xFFFF_FFFF {
|
|
159
|
+
let mut e = 0usize;
|
|
160
|
+
while e + 4 <= extra.len() {
|
|
161
|
+
let id = le_u16(extra, e);
|
|
162
|
+
let sz = le_u16(extra, e + 2) as usize;
|
|
163
|
+
if id == 0x0001 {
|
|
164
|
+
let mut f = e + 4;
|
|
165
|
+
if uncomp_size == 0xFFFF_FFFF && f + 8 <= extra.len() {
|
|
166
|
+
uncomp_size = le_u64(extra, f);
|
|
167
|
+
f += 8;
|
|
168
|
+
}
|
|
169
|
+
if comp32 == 0xFFFF_FFFF && f + 8 <= extra.len() {
|
|
170
|
+
f += 8; // skip compressed size
|
|
171
|
+
}
|
|
172
|
+
if local_offset == 0xFFFF_FFFF && f + 8 <= extra.len() {
|
|
173
|
+
local_offset = le_u64(extra, f);
|
|
174
|
+
}
|
|
175
|
+
break;
|
|
176
|
+
}
|
|
177
|
+
e += 4 + sz;
|
|
178
|
+
}
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
let base = name.rsplit('/').next().unwrap_or(&name);
|
|
182
|
+
struct_hash = fnv_update(struct_hash, name.as_bytes());
|
|
183
|
+
struct_hash = fnv_update(struct_hash, &uncomp_size.to_le_bytes());
|
|
184
|
+
|
|
185
|
+
if name.ends_with("data.pkl") {
|
|
186
|
+
let start = entry_data_start(file, local_offset)?;
|
|
187
|
+
pkl = read_at(file, start, uncomp_size as usize)?;
|
|
188
|
+
} else if name.ends_with("/version") || name == "version" {
|
|
189
|
+
let start = entry_data_start(file, local_offset)?;
|
|
190
|
+
let v = read_at(file, start, uncomp_size.min(32) as usize)?;
|
|
191
|
+
torch_version = Some(String::from_utf8_lossy(&v).trim().to_string());
|
|
192
|
+
} else if name.contains("/data/") && !base.is_empty() && base.bytes().all(|b| b.is_ascii_digit())
|
|
193
|
+
{
|
|
194
|
+
tensor_count += 1;
|
|
195
|
+
total_parameter_bytes += uncomp_size;
|
|
196
|
+
let sample = uncomp_size.min(SAMPLE_PER_STORAGE) as usize;
|
|
197
|
+
if sample > 0 {
|
|
198
|
+
let start = entry_data_start(file, local_offset)?;
|
|
199
|
+
let bytes = read_at(file, start, sample)?;
|
|
200
|
+
weight_hash = fnv_update(weight_hash, &bytes);
|
|
201
|
+
}
|
|
202
|
+
weight_hash = fnv_update(weight_hash, &uncomp_size.to_le_bytes());
|
|
203
|
+
}
|
|
204
|
+
|
|
205
|
+
o += 46 + name_len + extra_len + comment_len;
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
if pkl.is_empty() {
|
|
209
|
+
return Err(format!("{path}: no data.pkl found inside the archive"));
|
|
210
|
+
}
|
|
211
|
+
|
|
212
|
+
let param_names = recover_param_names(&pkl);
|
|
213
|
+
let (dtype, elem_size) = dominant_dtype(&pkl);
|
|
214
|
+
let estimated_parameter_count = total_parameter_bytes / elem_size;
|
|
215
|
+
|
|
216
|
+
let file_fingerprint = fnv_update(fnv_update(FNV_OFFSET, &pkl), &struct_hash.to_le_bytes());
|
|
217
|
+
|
|
218
|
+
Ok(PtReport {
|
|
219
|
+
path: file_name(path),
|
|
220
|
+
bytes: total as usize,
|
|
221
|
+
is_zip: true,
|
|
222
|
+
tensor_count,
|
|
223
|
+
total_parameter_bytes,
|
|
224
|
+
estimated_parameter_count,
|
|
225
|
+
dominant_dtype: dtype,
|
|
226
|
+
param_names,
|
|
227
|
+
torch_version,
|
|
228
|
+
file_fingerprint,
|
|
229
|
+
learned_state_fingerprint: weight_hash,
|
|
230
|
+
})
|
|
231
|
+
}
|
|
232
|
+
|
|
233
|
+
fn inspect_legacy(file: &mut File, path: &str, total: u64) -> Result<PtReport, String> {
|
|
234
|
+
// Legacy (non-zip) pickle: we can only scan strings + fingerprint the bytes.
|
|
235
|
+
let bytes = read_at(file, 0, total as usize)?;
|
|
236
|
+
let param_names = recover_param_names(&bytes);
|
|
237
|
+
let (dtype, _) = dominant_dtype(&bytes);
|
|
238
|
+
Ok(PtReport {
|
|
239
|
+
path: file_name(path),
|
|
240
|
+
bytes: total as usize,
|
|
241
|
+
is_zip: false,
|
|
242
|
+
tensor_count: 0,
|
|
243
|
+
total_parameter_bytes: 0,
|
|
244
|
+
estimated_parameter_count: 0,
|
|
245
|
+
dominant_dtype: dtype,
|
|
246
|
+
param_names,
|
|
247
|
+
torch_version: None,
|
|
248
|
+
file_fingerprint: fnv_update(FNV_OFFSET, &bytes),
|
|
249
|
+
learned_state_fingerprint: fnv_update(FNV_OFFSET, &bytes),
|
|
250
|
+
})
|
|
251
|
+
}
|
|
252
|
+
|
|
253
|
+
/// Final component of `path` as an owned string; falls back to `path` itself
/// when there is no final component or it is not valid UTF-8.
fn file_name(path: &str) -> String {
    let leaf = Path::new(path).file_name().and_then(|os| os.to_str());
    match leaf {
        Some(name) => name.to_owned(),
        None => path.to_owned(),
    }
}
|
|
260
|
+
|
|
261
|
+
/// Pull UTF-8 strings out of pickle BINUNICODE (0x58) / SHORT_BINUNICODE (0x8c)
|
|
262
|
+
/// opcodes and keep the ones that look like state_dict keys.
|
|
263
|
+
fn recover_param_names(pkl: &[u8]) -> Vec<String> {
|
|
264
|
+
let mut names = Vec::new();
|
|
265
|
+
let mut seen = std::collections::BTreeSet::new();
|
|
266
|
+
let mut i = 0usize;
|
|
267
|
+
while i < pkl.len() {
|
|
268
|
+
let (text, next) = match pkl[i] {
|
|
269
|
+
0x8c if i + 2 <= pkl.len() => {
|
|
270
|
+
let n = pkl[i + 1] as usize;
|
|
271
|
+
if i + 2 + n <= pkl.len() {
|
|
272
|
+
(std::str::from_utf8(&pkl[i + 2..i + 2 + n]).ok(), i + 2 + n)
|
|
273
|
+
} else {
|
|
274
|
+
(None, i + 1)
|
|
275
|
+
}
|
|
276
|
+
}
|
|
277
|
+
0x58 if i + 5 <= pkl.len() => {
|
|
278
|
+
let n = le_u32(pkl, i + 1) as usize;
|
|
279
|
+
if i + 5 + n <= pkl.len() {
|
|
280
|
+
(std::str::from_utf8(&pkl[i + 5..i + 5 + n]).ok(), i + 5 + n)
|
|
281
|
+
} else {
|
|
282
|
+
(None, i + 1)
|
|
283
|
+
}
|
|
284
|
+
}
|
|
285
|
+
_ => (None, i + 1),
|
|
286
|
+
};
|
|
287
|
+
if let Some(text) = text {
|
|
288
|
+
if looks_like_param_name(text) && seen.insert(text.to_string()) {
|
|
289
|
+
names.push(text.to_string());
|
|
290
|
+
}
|
|
291
|
+
}
|
|
292
|
+
i = next;
|
|
293
|
+
}
|
|
294
|
+
names
|
|
295
|
+
}
|
|
296
|
+
|
|
297
|
+
/// Heuristic filter for strings that look like state_dict keys.
///
/// Accepts values of 2 to 200 bytes made only of ASCII letters, digits, `_`
/// and `.`, which additionally either contain a `.` or end with one of
/// `weight`, `bias`, `running_mean`, `running_var`.
fn looks_like_param_name(value: &str) -> bool {
    const KNOWN_SUFFIXES: [&str; 4] = ["weight", "bias", "running_mean", "running_var"];

    let length_ok = (2..=200).contains(&value.len());
    let charset_ok = value
        .bytes()
        .all(|b| matches!(b, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_' | b'.'));

    length_ok
        && charset_ok
        && (value.contains('.') || KNOWN_SUFFIXES.iter().any(|suffix| value.ends_with(*suffix)))
}
|
|
313
|
+
|
|
314
|
+
fn dominant_dtype(pkl: &[u8]) -> (Option<String>, u64) {
|
|
315
|
+
// (storage token, dtype name, element size)
|
|
316
|
+
let candidates: [(&[u8], &str, u64); 8] = [
|
|
317
|
+
(b"DoubleStorage", "float64", 8),
|
|
318
|
+
(b"BFloat16Storage", "bfloat16", 2),
|
|
319
|
+
(b"HalfStorage", "float16", 2),
|
|
320
|
+
(b"FloatStorage", "float32", 4),
|
|
321
|
+
(b"LongStorage", "int64", 8),
|
|
322
|
+
(b"IntStorage", "int32", 4),
|
|
323
|
+
(b"ByteStorage", "uint8", 1),
|
|
324
|
+
(b"BoolStorage", "bool", 1),
|
|
325
|
+
];
|
|
326
|
+
let mut best: Option<(&str, u64, usize)> = None;
|
|
327
|
+
for (token, name, size) in candidates {
|
|
328
|
+
let count = count_occurrences(pkl, token);
|
|
329
|
+
if count > 0 && best.map(|(_, _, c)| count > c).unwrap_or(true) {
|
|
330
|
+
best = Some((name, size, count));
|
|
331
|
+
}
|
|
332
|
+
}
|
|
333
|
+
match best {
|
|
334
|
+
Some((name, size, _)) => (Some(name.to_string()), size),
|
|
335
|
+
None => (None, 4),
|
|
336
|
+
}
|
|
337
|
+
}
|
|
338
|
+
|
|
339
|
+
/// Count non-overlapping occurrences of `needle` in `haystack`, scanning left
/// to right. An empty needle counts as zero occurrences.
fn count_occurrences(haystack: &[u8], needle: &[u8]) -> usize {
    if needle.is_empty() {
        return 0;
    }

    let mut hits = 0;
    let mut rest = haystack;
    while rest.len() >= needle.len() {
        if rest.starts_with(needle) {
            hits += 1;
            // Skip the whole match so occurrences never overlap.
            rest = &rest[needle.len()..];
        } else {
            rest = &rest[1..];
        }
    }
    hits
}
|
|
355
|
+
|
|
356
|
+
const FNV_OFFSET: u64 = 0xcbf29ce484222325;
|
|
357
|
+
const FNV_PRIME: u64 = 0x100000001b3;
|
|
358
|
+
|
|
359
|
+
fn fnv_update(mut hash: u64, data: &[u8]) -> u64 {
|
|
360
|
+
for byte in data {
|
|
361
|
+
hash ^= u64::from(*byte);
|
|
362
|
+
hash = hash.wrapping_mul(FNV_PRIME);
|
|
363
|
+
}
|
|
364
|
+
hash
|
|
365
|
+
}
|
|
366
|
+
|
|
367
|
+
#[cfg(test)]
mod tests {
    use super::{count_occurrences, looks_like_param_name, recover_param_names};

    /// Dotted layer paths are accepted; bare words and strings with spaces are not.
    #[test]
    fn param_name_filter() {
        let accepted = ["fc1.weight", "encoder.layers.3.attn.bias", "classifier.weight"];
        for name in accepted.iter() {
            assert!(looks_like_param_name(name));
        }

        let rejected = ["cpu", "storage", "has space"];
        for name in rejected.iter() {
            assert!(!looks_like_param_name(name));
        }
    }

    /// Both string opcodes are decoded and names come back in stream order.
    #[test]
    fn recovers_binunicode_names() {
        let mut pkl = Vec::new();
        // SHORT_BINUNICODE (0x8c): one-byte length, then the text.
        pkl.extend_from_slice(&[0x8c, 10]);
        pkl.extend_from_slice(b"fc1.weight");
        // BINUNICODE (0x58): four-byte little-endian length, then the text.
        pkl.push(0x58);
        pkl.extend_from_slice(&8u32.to_le_bytes());
        pkl.extend_from_slice(b"fc1.bias");

        let names = recover_param_names(&pkl);
        assert_eq!(names, vec!["fc1.weight", "fc1.bias"]);
    }

    /// Back-to-back tokens are each counted; absent tokens count as zero.
    #[test]
    fn counts_tokens() {
        let token = b"FloatStorage";
        assert_eq!(count_occurrences(b"FloatStorageFloatStorage", token), 2);
        assert_eq!(count_occurrences(b"none here", token), 0);
    }
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|