jstprove-1.0.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
python/scripts/gen_and_bench.py
ADDED

@@ -0,0 +1,482 @@
# python/scripts/gen_and_bench.py
# ruff: noqa: S603, T201, RUF002

"""
Generate simple CNN ONNX models and benchmark JSTprove.

Typical usages:

Depth sweep (vary conv depth, fixed input size):
    jst bench \
        --sweep depth \
        --depth-min 1 \
        --depth-max 16 \
        --input-hw 56 \
        --iterations 3 \
        --results benchmarking/depth_sweep.jsonl

Breadth sweep (vary input resolution, fixed conv depth):
    jst bench \
        --sweep breadth \
        --arch-depth 5 \
        --input-hw-list 28,56,84,112 \
        --iterations 3 \
        --results benchmarking/breadth_sweep.jsonl \
        --pool-cap 2 --conv-out-ch 16 --fc-hidden 256

Output locations (unless overridden via --onnx-dir / --inputs-dir):
    depth   → python/models/models_onnx/depth   ; python/models/inputs/depth
    breadth → python/models/models_onnx/breadth ; python/models/inputs/breadth

Each model is benchmarked by python.scripts.benchmark_runner and a row is
appended to the JSONL file passed via --results.
"""

from __future__ import annotations

# --- Standard library --------------------------------------------------------
import argparse
import json
import math
import subprocess
from pathlib import Path
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Sequence


# --- Third-party -------------------------------------------------------------
import torch
import torch.nn.functional as F  # noqa: N812
from torch import nn

# -----------------------------------------------------------------------------
# Planning helpers
# -----------------------------------------------------------------------------


def _max_pools_allowed(input_hw: int, stop_at_hw: int) -> int:
    """
    Given an input size H=W=input_hw, return how many 2×2/stride-2 pools
    can be applied while keeping H >= stop_at_hw.
    """
    two = 2
    if input_hw <= 0 or stop_at_hw <= 0:
        return 0
    pools = 0
    h = input_hw
    while h >= two and (h // two) >= stop_at_hw:
        pools += 1
        h //= two
    return pools


def plan_for_depth(
    d: int,
    *,
    input_hw: int = 56,
    base_fc: int = 1,
    pool_cap: int | None = None,
    stop_at_hw: int | None = 7,
) -> list[str]:
    """
    Build a symbolic plan for `d` conv blocks followed by FC layers.

    Layout:
      - First K blocks: conv → relu → maxpool2d_k2_s2 (K = min(d, allowed_pools))
      - Remaining blocks: conv → relu
      - Tail: reshape → (fc → relu) × base_fc → final

    Pooling policy:
      - If pool_cap is provided, cap pooling at that many initial blocks.
      - Else if stop_at_hw is provided, allow pooling while H >= stop_at_hw.
      - Else fall back to floor(log2(H)).
    """
    if pool_cap is not None:
        allowed_pools = max(0, int(pool_cap))
    elif stop_at_hw is not None:
        allowed_pools = _max_pools_allowed(input_hw, stop_at_hw)
    else:
        allowed_pools = int(math.log2(max(1, input_hw)))

    pools = min(d, allowed_pools)
    conv_only = max(0, d - pools)

    plan: list[str] = []
    # pooled blocks
    for i in range(1, pools + 1):
        plan += [f"conv{i}", "relu", "maxpool2d_k2_s2"]
    # non-pooled blocks
    for j in range(pools + 1, pools + conv_only + 1):
        plan += [f"conv{j}", "relu"]

    # FC tail
    plan += ["reshape"]
    for k in range(1, base_fc + 1):
        plan += [f"fc{k}", "relu"]
    plan += ["final"]
    return plan


def count_layers(plan: Sequence[str]) -> tuple[int, int, int, int]:
    """Return a summary tuple (num_convs, num_pools, num_fcs, num_relus)."""
    c = sum(1 for t in plan if t.startswith("conv"))
    p = sum(1 for t in plan if t.startswith("maxpool"))
    f = sum(1 for t in plan if t.startswith("fc"))
    r = sum(1 for t in plan if t == "relu")
    return c, p, f, r


# -----------------------------------------------------------------------------
# Torch model that consumes the plan
# -----------------------------------------------------------------------------


class CNNDemo(nn.Module):
    """
    Minimal CNN whose structure is defined by a symbolic plan.
    Uses fixed conv hyperparameters and a configurable FC head.
    """

    def __init__(  # noqa: PLR0913
        self: CNNDemo,
        layers: Sequence[str],
        *,
        in_ch: int = 4,
        conv_out_ch: int = 16,
        conv_kernel: int = 3,
        conv_stride: int = 1,
        conv_pad: int = 1,
        fc_hidden: int = 128,
        n_actions: int = 10,
        input_shape: tuple[int, int, int, int] = (1, 4, 56, 56),
    ) -> None:
        super().__init__()
        self.layers_plan = list(layers)
        _ = in_ch
        _, C, H, W = input_shape  # noqa: N806
        cur_c, cur_h, cur_w = C, H, W

        self.convs = nn.ModuleList()
        self.fcs = nn.ModuleList()
        self.pools = nn.ModuleList()

        next_fc_in = None
        for tok in self.layers_plan:
            if tok.startswith("conv"):
                conv = nn.Conv2d(
                    in_channels=cur_c,
                    out_channels=conv_out_ch,
                    kernel_size=conv_kernel,
                    stride=conv_stride,
                    padding=conv_pad,
                )
                self.convs.append(conv)
                cur_c = conv_out_ch
                cur_h = (cur_h + 2 * conv_pad - conv_kernel) // conv_stride + 1
                cur_w = (cur_w + 2 * conv_pad - conv_kernel) // conv_stride + 1
            elif tok == "relu":
                pass
            elif tok.startswith("maxpool"):
                pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
                self.pools.append(pool)
                cur_h = (cur_h - 2) // 2 + 1
                cur_w = (cur_w - 2) // 2 + 1
            elif tok == "reshape":
                next_fc_in = cur_c * cur_h * cur_w
            elif tok.startswith("fc") or tok == "final":
                if next_fc_in is None:
                    next_fc_in = cur_c * cur_h * cur_w
                out_features = n_actions if tok == "final" else fc_hidden
                self.fcs.append(nn.Linear(next_fc_in, out_features))
                next_fc_in = out_features
            else:
                msg = f"Unknown token: {tok}"
                raise ValueError(msg)

        self._ci = self._fi = self._pi = 0

    def forward(self: CNNDemo, x: torch.Tensor) -> torch.Tensor:
        """Execute the plan in order."""
        self._ci = self._fi = self._pi = 0
        for tok in self.layers_plan:
            if tok.startswith("conv"):
                x = self.convs[self._ci](x)
                self._ci += 1
            elif tok == "relu":
                x = F.relu(x)
            elif tok.startswith("maxpool"):
                x = self.pools[self._pi](x)
                self._pi += 1
            elif tok == "reshape":
                x = x.reshape(x.shape[0], -1)
            elif tok.startswith("fc") or tok == "final":
                x = self.fcs[self._fi](x)
                self._fi += 1
            else:
                msg = f"Unknown token: {tok}"
                raise ValueError(msg)
        return x


# -----------------------------------------------------------------------------
# Export / inputs / benchmark shim
# -----------------------------------------------------------------------------


def export_onnx(
    model: nn.Module,
    onnx_path: Path,
    input_shape: tuple[int] = (1, 4, 56, 56),
) -> None:
    """Export a Torch model to ONNX and ensure the directory exists."""
    onnx_path.parent.mkdir(parents=True, exist_ok=True)
    model.eval()
    dummy = torch.zeros(*input_shape)
    torch.onnx.export(
        model,
        dummy,
        onnx_path.as_posix(),
        input_names=["input"],
        output_names=["output"],
        opset_version=13,
        do_constant_folding=True,
        dynamic_axes=None,
    )


def write_input_json(json_path: Path, input_shape: tuple[int] = (1, 4, 28, 28)) -> None:
    """Write a zero-valued input tensor to JSON alongside its [N,C,H,W] shape."""
    json_path.parent.mkdir(parents=True, exist_ok=True)
    n, c, h, w = input_shape
    arr = [0.0] * (n * c * h * w)
    with json_path.open("w", encoding="utf-8") as f:
        json.dump({"input": arr, "shape": [n, c, h, w]}, f)


def run_bench(
    onnx_path: Path,
    input_json: Path,
    iterations: int,
    results_jsonl: Path,
) -> int:
    """
    Invoke the benchmark runner module as a subprocess.
    Returns the exit code (0 on success).
    """
    cmd = [
        "python",
        "-m",
        "python.scripts.benchmark_runner",
        "--model",
        onnx_path.as_posix(),
        "--input",
        input_json.as_posix(),
        "--iterations",
        str(iterations),
        "--output",
        results_jsonl.as_posix(),
        "--summarize",
    ]
    return subprocess.run(cmd, check=False, shell=False).returncode


# -----------------------------------------------------------------------------
# CLI
# -----------------------------------------------------------------------------


def _parse_int_list(s: str) -> list[int]:
    """
    Parse either a comma list "28,56,84" or a range "start:stop[:step]".
    The range is inclusive of stop. Non-positive values are filtered out.
    """
    three = 3
    s = s.strip()
    if ":" in s:
        parts = [int(x) for x in s.split(":")]
        if len(parts) not in (2, 3):
            msg = "range syntax must be start:stop[:step]"
            raise ValueError(msg)
        start, stop = parts[0], parts[1]
        step = parts[2] if len(parts) == three else 1
        if step == 0:
            msg = "step must be nonzero"
            raise ValueError(msg)
        out = list(range(start, stop + (1 if step > 0 else -1), step))
        return [x for x in out if x > 0]
    return [int(x) for x in s.split(",") if x.strip()]


def _resolve_output_dirs(
    sweep: str,
    onnx_dir_arg: str | None,
    inputs_dir_arg: str | None,
) -> tuple[Path, Path]:
    """
    Choose output directories from the sweep type unless explicitly overridden.

      depth   → python/models/models_onnx/depth   ; python/models/inputs/depth
      breadth → python/models/models_onnx/breadth ; python/models/inputs/breadth
    """
    sub = sweep if sweep in ("depth", "breadth") else "depth"
    default_onnx = Path(f"python/models/models_onnx/{sub}")
    default_inputs = Path(f"python/models/inputs/{sub}")
    onnx_dir = Path(onnx_dir_arg) if onnx_dir_arg else default_onnx
    inputs_dir = Path(inputs_dir_arg) if inputs_dir_arg else default_inputs
    return onnx_dir, inputs_dir


def main() -> None:  # noqa: PLR0915
    """Argument parsing and sweep orchestration."""
    ap = argparse.ArgumentParser(
        description="Depth or breadth sweep for simple LeNet-like CNNs.",
    )
    # depth controls
    ap.add_argument("--depth-min", type=int, default=1)
    ap.add_argument("--depth-max", type=int, default=12)
    ap.add_argument("--iterations", type=int, default=3)
    ap.add_argument("--results", default=None)
    ap.add_argument("--onnx-dir", default=None)
    ap.add_argument("--inputs-dir", default=None)
    ap.add_argument("--n-actions", type=int, default=10)

    # sweep mode + breadth options
    ap.add_argument(
        "--sweep",
        choices=["depth", "breadth"],
        default="depth",
        help="depth: vary number of conv blocks; "
        "breadth: vary input size at fixed depth",
    )
    ap.add_argument(
        "--arch-depth",
        type=int,
        default=5,
        help="(breadth) conv blocks used for all inputs",
    )
    ap.add_argument(
        "--input-hw",
        type=int,
        default=56,
        help="(depth) input H=W when varying depth",
    )
    ap.add_argument(
        "--input-hw-list",
        type=str,
        default="28,56,84,112",
        help="(breadth) comma list or start:stop[:step], e.g. "
        "'28,56,84' or '32:160:32'",
    )
    ap.add_argument(
        "--pool-cap",
        type=int,
        default=2,
        help="cap the number of initial maxpool blocks",
    )
    ap.add_argument(
        "--stop-at-hw",
        type=int,
        default=None,
        help="allow pooling while H >= this (if pool-cap unset)",
    )
    ap.add_argument("--conv-out-ch", type=int, default=16)
    ap.add_argument("--fc-hidden", type=int, default=128)
    ap.add_argument(
        "--tag",
        type=str,
        default="",
        help="optional tag added to filenames",
    )

    args = ap.parse_args()

    # Ensure output dirs and a robust default results path
    onnx_dir, in_dir = _resolve_output_dirs(args.sweep, args.onnx_dir, args.inputs_dir)

    # Dynamic default for results (also handles empty string)
    if args.results is None or not str(args.results).strip():
        # default results path based on sweep type
        results = Path("benchmarking") / f"{args.sweep}_sweep.jsonl"
    else:
        results = Path(args.results)

    results.parent.mkdir(parents=True, exist_ok=True)

    onnx_dir, in_dir = _resolve_output_dirs(args.sweep, args.onnx_dir, args.inputs_dir)

    if args.sweep == "depth":
        input_shape = (1, 4, args.input_hw, args.input_hw)
        for d in range(args.depth_min, args.depth_max + 1):
            plan = plan_for_depth(
                d=d,
                input_hw=args.input_hw,
                base_fc=1,
                pool_cap=args.pool_cap,
                stop_at_hw=args.stop_at_hw,
            )
            C, P, Fc, R = count_layers(plan)  # noqa: N806
            uid = f"depth_d{d}_c{C}_p{P}_f{Fc}_r{R}"
            if args.tag:
                uid = f"{uid}_{args.tag}"

            onnx_path = onnx_dir / f"{uid}.onnx"
            input_json = in_dir / f"{uid}_input.json"

            model = CNNDemo(
                plan,
                input_shape=input_shape,
                n_actions=args.n_actions,
                conv_out_ch=args.conv_out_ch,
                fc_hidden=args.fc_hidden,
            )
            export_onnx(model, onnx_path, input_shape=input_shape)
            write_input_json(input_json, input_shape=input_shape)
            print(f"[gen] d={d} :: C={C}, P={P}, F={Fc}, R={R} -> {onnx_path.name}")

            rc = run_bench(onnx_path, input_json, args.iterations, results)
            if rc != 0:
                print(f"[warn] benchmark rc={rc} for depth={d}")
    else:
        # breadth sweep: fixed architecture depth; vary input sizes
        sizes = _parse_int_list(args.input_hw_list)
        d = int(args.arch_depth)
        for hw in sizes:
            input_shape = (1, 4, hw, hw)
            plan = plan_for_depth(
                d=d,
                input_hw=hw,
                base_fc=1,
                pool_cap=args.pool_cap,
                stop_at_hw=args.stop_at_hw,
            )
            C, P, Fc, R = count_layers(plan)  # noqa: N806
            uid = f"breadth_h{hw}_d{d}_c{C}_p{P}_f{Fc}_r{R}"
            if args.tag:
                uid = f"{uid}_{args.tag}"

            onnx_path = onnx_dir / f"{uid}.onnx"
            input_json = in_dir / f"{uid}_input.json"

            model = CNNDemo(
                plan,
                input_shape=input_shape,
                n_actions=args.n_actions,
                conv_out_ch=args.conv_out_ch,
                fc_hidden=args.fc_hidden,
            )
            export_onnx(model, onnx_path, input_shape=input_shape)
            write_input_json(input_json, input_shape=input_shape)
            print(
                f"[gen] H=W={hw} :: d={d} | C={C}, P={P}, F={Fc}, R={R} "
                f"-> {onnx_path.name}",
            )

            rc = run_bench(onnx_path, input_json, args.iterations, results)
            if rc != 0:
                print(f"[warn] benchmark rc={rc} for hw={hw}")


if __name__ == "__main__":
    main()
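For orientation, below is a minimal sketch of how the pieces in this file compose. It assumes the wheel's `python` package (per top_level.txt) is importable from the environment and that torch is installed; it only calls functions defined in the diff above.

# Hedged sketch: assumes `python.scripts.gen_and_bench` imports as shipped in the wheel.
import torch

from python.scripts.gen_and_bench import CNNDemo, count_layers, plan_for_depth

# Build a symbolic plan: 3 conv blocks, pooling capped at the first 2 blocks.
plan = plan_for_depth(3, input_hw=56, pool_cap=2)
# -> ['conv1', 'relu', 'maxpool2d_k2_s2', 'conv2', 'relu', 'maxpool2d_k2_s2',
#     'conv3', 'relu', 'reshape', 'fc1', 'relu', 'final']
print(count_layers(plan))  # (3, 2, 1, 4): convs, pools, fcs, relus

# The plan drives the module layout; the default head maps to n_actions=10 outputs.
model = CNNDemo(plan, input_shape=(1, 4, 56, 56))
out = model(torch.zeros(1, 4, 56, 56))
print(out.shape)  # torch.Size([1, 10])

In main(), each sweep point follows the same flow: export_onnx and write_input_json persist the model and a zero-valued input, and run_bench hands both to python.scripts.benchmark_runner, which appends a row to the JSONL file passed via --results.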
python/tests/__init__.py
ADDED
File without changes