JSTprove 1.0.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +5 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Argument specifications for CLI commands."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
import argparse
|
|
9
|
+
|
|
10
|
+
from dataclasses import dataclass, field
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass(frozen=True)
class ArgSpec:
    """Specification for a command-line argument.

    Each spec always registers a ``flag`` option. When ``short`` is non-empty
    it additionally registers an optional positional "twin" whose value is
    stored under :attr:`positional`, so the same value can be passed either
    way on the command line.

    Attributes:
        name: Namespace attribute the flag value is stored under.
        flag: Long option string, e.g. ``"--model-path"``.
        help_text: Help string shown for the argument.
        short: Optional short option (e.g. ``"-m"``); enables the positional twin.
        arg_type: Optional ``type=`` converter forwarded to argparse.
        extra_kwargs: Extra keyword arguments forwarded to ``add_argument``
            (e.g. ``action``, ``default``, ``choices``).
    """

    name: str
    flag: str
    help_text: str
    short: str = ""
    arg_type: type | None = None
    extra_kwargs: dict[str, Any] = field(default_factory=dict)

    @property
    def positional(self) -> str:
        """Return the namespace attribute used by the positional twin."""
        return f"pos_{self.name}"

    def add_to_parser(
        self,
        parser: argparse.ArgumentParser,
        help_override: str | None = None,
    ) -> None:
        """Add the flag (and, if ``short`` is set, the positional twin) to *parser*.

        Args:
            parser: Parser to register the argument(s) on.
            help_override: Help text to use instead of ``help_text`` (a falsy
                value falls back to ``help_text``).
        """
        kwargs: dict[str, Any] = {
            "help": help_override or self.help_text,
            **self.extra_kwargs,
        }
        if self.arg_type is not None:
            kwargs["type"] = self.arg_type

        # Pin the flag's destination to ``self.name``. Validators and
        # ``append_args_from_namespace`` look the value up by ``name``, so it
        # must not depend on argparse deriving the dest from ``flag``.
        # An explicit ``dest`` in ``extra_kwargs`` still takes precedence.
        flag_kwargs: dict[str, Any] = {"dest": self.name, **kwargs}

        if self.short:
            # Optional positional twin, displayed as ``name`` in usage/help.
            parser.add_argument(
                self.positional,
                nargs="?",
                metavar=self.name,
                **kwargs,
            )
            parser.add_argument(
                self.short,
                self.flag,
                **flag_kwargs,
            )
        else:
            parser.add_argument(
                self.flag,
                **flag_kwargs,
            )
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
# Shared path arguments for the CLI commands. Every spec here sets ``short``,
# so ``ArgSpec.add_to_parser`` also registers an optional positional twin
# stored under ``pos_<name>``.

# Input: the source ONNX model.
MODEL_PATH = ArgSpec(
    name="model_path",
    short="-m",
    flag="--model-path",
    help_text="Path to the original ONNX model.",
)

# Input: a previously compiled circuit.
CIRCUIT_PATH = ArgSpec(
    name="circuit_path",
    short="-c",
    flag="--circuit-path",
    help_text="Path to the compiled circuit.",
)

# Input: JSON file with model inputs.
INPUT_PATH = ArgSpec(
    name="input_path",
    short="-i",
    flag="--input-path",
    help_text="Path to input JSON.",
)

# Output: JSON file receiving model outputs.
OUTPUT_PATH = ArgSpec(
    name="output_path",
    short="-o",
    flag="--output-path",
    help_text="Path to write model outputs JSON.",
)

# Output: witness file.
WITNESS_PATH = ArgSpec(
    name="witness_path",
    short="-w",
    flag="--witness-path",
    help_text="Path to write witness.",
)

# Output: proof file.
PROOF_PATH = ArgSpec(
    name="proof_path",
    short="-p",
    flag="--proof-path",
    help_text="Path to write proof.",
)
|
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import functools
|
|
5
|
+
import importlib
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import TYPE_CHECKING, Any, ClassVar
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from collections.abc import Callable
|
|
12
|
+
|
|
13
|
+
from python.frontend.commands.args import ArgSpec
|
|
14
|
+
|
|
15
|
+
from python.frontend.commands.constants import (
|
|
16
|
+
DEFAULT_CIRCUIT_CLASS,
|
|
17
|
+
DEFAULT_CIRCUIT_MODULE,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class HiddenPositionalHelpFormatter(argparse.HelpFormatter):
|
|
22
|
+
def _format_usage(
|
|
23
|
+
self,
|
|
24
|
+
usage: str | None,
|
|
25
|
+
actions: list,
|
|
26
|
+
groups: list,
|
|
27
|
+
prefix: str | None,
|
|
28
|
+
) -> str:
|
|
29
|
+
filtered_actions = [
|
|
30
|
+
action
|
|
31
|
+
for action in actions
|
|
32
|
+
if not (
|
|
33
|
+
isinstance(action, argparse._StoreAction) # noqa: SLF001
|
|
34
|
+
and action.dest.startswith("pos_")
|
|
35
|
+
)
|
|
36
|
+
]
|
|
37
|
+
return super()._format_usage(usage, filtered_actions, groups, prefix)
|
|
38
|
+
|
|
39
|
+
def _format_action(self, action: argparse.Action) -> str:
|
|
40
|
+
if isinstance(
|
|
41
|
+
action,
|
|
42
|
+
argparse._StoreAction, # noqa: SLF001
|
|
43
|
+
) and action.dest.startswith(
|
|
44
|
+
"pos_",
|
|
45
|
+
):
|
|
46
|
+
return ""
|
|
47
|
+
return super()._format_action(action)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
class BaseCommand(ABC):
    """Base class for CLI commands.

    Subclasses set ``name`` and ``help`` (optionally ``aliases``) and implement
    :meth:`configure_parser` and :meth:`run`. The static ``validate_*``
    decorators reconcile an :class:`ArgSpec`'s flag value with its positional
    twin (the flag wins when both are given) and check the resulting paths
    before the decorated ``run`` executes.
    """

    name: ClassVar[str]
    aliases: ClassVar[list[str]] = []
    help: ClassVar[str]

    @classmethod
    @abstractmethod
    def configure_parser(
        cls: type[BaseCommand],
        parser: argparse.ArgumentParser,
    ) -> None:
        """Configure the argument parser for this command."""

    @classmethod
    @abstractmethod
    def run(cls: type[BaseCommand], args: argparse.Namespace) -> None:
        """Execute the command."""

    @staticmethod
    def _resolve_value(args: argparse.Namespace, arg_spec: ArgSpec) -> Any:  # noqa: ANN401
        """Return the flag value for *arg_spec*, falling back to its positional twin.

        A flag value of ``None`` means "not given"; in that case the value
        stored under ``arg_spec.positional`` is returned (``None`` if absent).
        """
        flag_val = getattr(args, arg_spec.name, None)
        if flag_val is not None:
            return flag_val
        return getattr(args, arg_spec.positional, None)

    @staticmethod
    def validate_required(*required: ArgSpec) -> Callable:
        """Decorate ``run(cls, args)`` so every spec in *required* has a value.

        The merged flag-or-positional value is written back to ``args.<name>``.

        Raises:
            ValueError: A merged value is falsy (``None`` or empty string).
        """

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(cls: type[BaseCommand], args: argparse.Namespace) -> Any:  # noqa: ANN401
                for arg_spec in required:
                    merged = BaseCommand._resolve_value(args, arg_spec)
                    if not merged:
                        msg = f"Missing required argument: {arg_spec.name}"
                        raise ValueError(msg)
                    setattr(args, arg_spec.name, merged)
                return func(cls, args)

            return wrapper

        return decorator

    @staticmethod
    def validate_paths(*paths: ArgSpec) -> Callable:
        """Decorate ``run(cls, args)`` so every spec in *paths* names a readable file.

        The positional twin is merged in the same way as in
        :meth:`validate_required`, so the check works whichever of the two
        decorators is applied first. The merged value is written back to
        ``args.<name>``.

        Raises:
            FileNotFoundError / PermissionError: see :meth:`_ensure_file_exists`.
        """

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(cls: type[BaseCommand], args: argparse.Namespace) -> Any:  # noqa: ANN401
                for arg_spec in paths:
                    value = BaseCommand._resolve_value(args, arg_spec)
                    cls._ensure_file_exists(value)
                    setattr(args, arg_spec.name, value)
                return func(cls, args)

            return wrapper

        return decorator

    @staticmethod
    def validate_parent_paths(*paths: ArgSpec) -> Callable:
        """Decorate ``run(cls, args)`` to create parent directories for output paths.

        The positional twin is merged like in :meth:`validate_required`, and
        the merged value is written back to ``args.<name>``.
        """

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(cls: type[BaseCommand], args: argparse.Namespace) -> Any:  # noqa: ANN401
                for arg_spec in paths:
                    value = BaseCommand._resolve_value(args, arg_spec)
                    cls._ensure_parent_dir(value)
                    setattr(args, arg_spec.name, value)
                return func(cls, args)

            return wrapper

        return decorator

    @staticmethod
    def validate_optional_paths(*paths: ArgSpec) -> Callable:
        """Decorate ``run(cls, args)`` to check optional file paths when given.

        A spec whose merged value is ``None`` is skipped; otherwise the file
        must exist and be readable. The merged value (possibly ``None``) is
        written back to ``args.<name>``.
        """

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            def wrapper(cls: type[BaseCommand], args: argparse.Namespace) -> Any:  # noqa: ANN401
                for arg_spec in paths:
                    merged = BaseCommand._resolve_value(args, arg_spec)
                    if merged is not None:
                        cls._ensure_file_exists(merged)
                    setattr(args, arg_spec.name, merged)
                return func(cls, args)

            return wrapper

        return decorator

    @staticmethod
    def _ensure_file_exists(path: str) -> None:
        """Raise unless *path* is an existing regular file the caller can read.

        Raises:
            FileNotFoundError: *path* does not exist or is not a regular file.
            PermissionError: The calling user lacks read permission.
        """
        import os  # noqa: PLC0415

        p = Path(path)
        if not p.is_file():
            msg = f"Required file not found: {path}"
            raise FileNotFoundError(msg)
        # os.access tests read permission for the calling user (real uid/gid).
        # The previous ``st_mode & 0o444`` test only showed that *some* class
        # of user may read the file, e.g. it passed for an other-only 0o004.
        if not os.access(p, os.R_OK):
            msg = f"Cannot read file: {path}"
            raise PermissionError(msg)

    @staticmethod
    def _ensure_parent_dir(path: str) -> None:
        """Create the parent directory of *path* (and ancestors) if missing."""
        Path(path).parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def append_arg(cmd: list[str], flag: str, val: object | None) -> None:
        """Append ``flag str(val)`` to *cmd* unless *val* is None or a blank string."""
        if val is None:
            return
        if isinstance(val, str) and not val.strip():
            return
        cmd.extend([flag, str(val)])

    @staticmethod
    def append_args_from_specs(
        cmd: list[str],
        specs: tuple[tuple[ArgSpec, str], ...],
    ) -> None:
        """Append ``spec.flag value`` to *cmd* for every ``(spec, value)`` pair."""
        for spec, value in specs:
            cmd.extend([spec.flag, value])

    @staticmethod
    def append_args_from_namespace(
        cmd: list[str],
        args: argparse.Namespace,
        specs: tuple[ArgSpec, ...],
    ) -> None:
        """Append each spec's flag and its ``args.<name>`` value via :meth:`append_arg`."""
        for spec in specs:
            value = getattr(args, spec.name, None)
            BaseCommand.append_arg(cmd, spec.flag, value)

    @staticmethod
    def _build_circuit(model_name_hint: str | None = None) -> Any:  # noqa: ANN401
        """Instantiate the default circuit class.

        Several constructor signatures are tried in order (``model_name=``,
        ``name=``, positional, no arguments); the first that does not raise
        TypeError wins.

        Raises:
            RuntimeError: The class is missing from the module or no
                constructor form succeeded.
        """
        mod = importlib.import_module(DEFAULT_CIRCUIT_MODULE)
        try:
            cls = getattr(mod, DEFAULT_CIRCUIT_CLASS)
        except AttributeError as e:
            msg = (
                f"Default circuit class '{DEFAULT_CIRCUIT_CLASS}' "
                f"not found in '{DEFAULT_CIRCUIT_MODULE}'"
            )
            raise RuntimeError(msg) from e

        name = model_name_hint or "cli"

        # NOTE(review): a TypeError raised *inside* a constructor is also
        # treated as "wrong signature" and silently moves on to the next form.
        for attempt in (
            lambda: cls(model_name=name),
            lambda: cls(name=name),
            lambda: cls(name),
            lambda: cls(),
        ):
            try:
                return attempt()
            except TypeError:  # noqa: PERF203
                continue

        msg = f"Could not construct {cls.__name__} with/without name '{name}'"
        raise RuntimeError(msg)
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
from python.frontend.commands.base import BaseCommand
|
|
9
|
+
from python.frontend.commands.bench.list import ListCommand
|
|
10
|
+
from python.frontend.commands.bench.model import ModelCommand
|
|
11
|
+
from python.frontend.commands.bench.sweep import SweepCommand
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BenchCommand(BaseCommand):
    """Benchmark JSTprove models with various configurations.

    Acts as a dispatcher: each class in ``SUBCOMMANDS`` is registered as a
    subparser, and :meth:`run` forwards to the class whose name or alias was
    selected on the command line.
    """

    name: ClassVar[str] = "bench"
    aliases: ClassVar[list[str]] = []
    help: ClassVar[str] = "Benchmark JSTprove models with various configurations."

    SUBCOMMANDS: ClassVar[list[type[BaseCommand]]] = [
        ListCommand,
        ModelCommand,
        SweepCommand,
    ]

    @classmethod
    def configure_parser(
        cls: type[BenchCommand],
        parser: argparse.ArgumentParser,
    ) -> None:
        """Attach one required subparser per entry in ``SUBCOMMANDS``."""
        sub_actions = parser.add_subparsers(
            help="Benchmark subcommands",
            required=True,
            dest="bench_subcommand",
        )
        for sub in cls.SUBCOMMANDS:
            sub.configure_parser(
                sub_actions.add_parser(
                    sub.name,
                    aliases=sub.aliases,
                    help=sub.help,
                ),
            )

    @classmethod
    def run(cls: type[BenchCommand], args: argparse.Namespace) -> None:
        """Run the first subcommand whose name or alias matches ``args.bench_subcommand``.

        Raises:
            ValueError: No registered subcommand matches.
        """
        requested = args.bench_subcommand
        selected = next(
            (
                sub
                for sub in cls.SUBCOMMANDS
                if requested == sub.name or requested in sub.aliases
            ),
            None,
        )
        if selected is None:
            msg = f"Unknown bench subcommand: {args.bench_subcommand}"
            raise ValueError(msg)
        selected.run(args)
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
import argparse
|
|
7
|
+
|
|
8
|
+
from python.frontend.commands.args import ArgSpec
|
|
9
|
+
from python.frontend.commands.base import BaseCommand
|
|
10
|
+
|
|
11
|
+
# Boolean ``--list-models`` switch (off by default). It has no ``short``
# option, so no positional twin is registered for it.
LIST_MODELS = ArgSpec(
    name="list_models",
    help_text="List all available circuit models.",
    flag="--list-models",
    extra_kwargs={"action": "store_true", "default": False},
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ListCommand(BaseCommand):
    """List all available circuit models for benchmarking."""

    name: ClassVar[str] = "list"
    aliases: ClassVar[list[str]] = []
    help: ClassVar[str] = "List all available circuit models for benchmarking."

    @classmethod
    def configure_parser(
        cls: type[ListCommand],
        parser: argparse.ArgumentParser,
    ) -> None:
        """Register the ``--list-models`` switch on *parser*."""
        LIST_MODELS.add_to_parser(parser)

    @classmethod
    def run(cls: type[ListCommand], args: argparse.Namespace) -> None:  # noqa: ARG003
        """Print a header followed by one ``- <model>`` line per registry model."""
        # Deferred import: the registry module is only imported when this
        # command actually runs.
        from python.core.utils.model_registry import (  # noqa: PLC0415
            list_available_models,
        )

        registry_models = list_available_models()
        header = "\nAvailable Circuit Models:"
        print(header)  # noqa: T201
        for entry in registry_models:
            print(f"- {entry}")  # noqa: T201
|
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import sys
|
|
5
|
+
import tempfile
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
8
|
+
|
|
9
|
+
if TYPE_CHECKING:
|
|
10
|
+
import argparse
|
|
11
|
+
|
|
12
|
+
from python.core.utils.constants import MODEL_SOURCE_CLASS, MODEL_SOURCE_ONNX
|
|
13
|
+
from python.core.utils.helper_functions import (
|
|
14
|
+
ensure_parent_dir,
|
|
15
|
+
run_subprocess,
|
|
16
|
+
to_json,
|
|
17
|
+
)
|
|
18
|
+
from python.frontend.commands.args import ArgSpec
|
|
19
|
+
from python.frontend.commands.base import BaseCommand
|
|
20
|
+
|
|
21
|
+
# Registry sources accepted by ``--source``.
SOURCE_CHOICES: tuple[str, ...] = (MODEL_SOURCE_CLASS, MODEL_SOURCE_ONNX)

# Direct path to an ONNX file; ``short`` is set, so a positional twin
# (``pos_model_path``) is registered too.
BENCH_MODEL_PATH = ArgSpec(
    name="model_path",
    short="-m",
    flag="--model-path",
    help_text="Direct path to ONNX model file to benchmark.",
)

# Registry model name(s); repeatable, collected into a list (None if unused).
MODEL = ArgSpec(
    name="model",
    flag="--model",
    extra_kwargs={"action": "append", "default": None},
    help_text=(
        "Model name(s) from registry to benchmark. "
        "Use multiple times to test more than one."
    ),
)

# Restricts registry selection to one of SOURCE_CHOICES.
SOURCE = ArgSpec(
    name="source",
    flag="--source",
    extra_kwargs={"choices": list(SOURCE_CHOICES), "default": None},
    help_text="Restrict registry models to a specific source: class or onnx.",
)

# Number of end-to-end benchmark loops, parsed as an int.
ITERATIONS = ArgSpec(
    name="iterations",
    flag="--iterations",
    arg_type=int,
    help_text="E2E loops per model (default 5)",
)

# Destination JSONL file for benchmark results.
RESULTS = ArgSpec(
    name="results",
    flag="--results",
    help_text="Path to JSONL results (e.g., benchmarking/model_name.jsonl)",
)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class ModelCommand(BaseCommand):
    """Benchmark specific models from registry or file path.

    Either a single ONNX file (``--model-path`` or positionally) or registry
    models (``--model`` / ``--source``) are benchmarked. For each model an
    input JSON is generated in a temporary directory and the
    ``python.scripts.benchmark_runner`` module is launched via
    ``run_subprocess``.
    """

    name: ClassVar[str] = "model"
    aliases: ClassVar[list[str]] = []
    help: ClassVar[str] = "Benchmark specific models from registry or file path."

    DEFAULT_ITERATIONS: ClassVar[int] = 5
    SCRIPT_BENCHMARK_RUNNER: ClassVar[str] = "python.scripts.benchmark_runner"

    @classmethod
    def configure_parser(
        cls: type[ModelCommand],
        parser: argparse.ArgumentParser,
    ) -> None:
        """Register the model-selection, iteration and results arguments."""
        BENCH_MODEL_PATH.add_to_parser(parser)
        MODEL.add_to_parser(parser)
        SOURCE.add_to_parser(parser)
        ITERATIONS.add_to_parser(parser)
        RESULTS.add_to_parser(parser)

    @classmethod
    @BaseCommand.validate_optional_paths(BENCH_MODEL_PATH)
    def run(cls: type[ModelCommand], args: argparse.Namespace) -> None:
        """Benchmark a single file if a model path is given, else registry models.

        Raises:
            ValueError: None of ``--model-path``, ``--model`` or ``--source``
                was given, or ``--iterations`` is less than 1.
        """
        if args.model_path:
            name = Path(args.model_path).stem
            cls._run_bench_single_model(args, args.model_path, name)
            return

        if args.model or args.source:
            cls._run_bench_on_models(args)
            return

        msg = "Specify --model-path, --model, or --source"
        raise ValueError(msg)

    @classmethod
    def _run_bench_on_models(cls: type[ModelCommand], args: argparse.Namespace) -> None:
        """Benchmark every registry model selected by ``--model`` / ``--source``.

        Raises:
            ValueError: The registry returned no models for the selection.
        """
        from python.core.utils.model_registry import get_models_to_test  # noqa: PLC0415

        # Without --source, registry selection defaults to ONNX models.
        models = get_models_to_test(args.model, args.source or MODEL_SOURCE_ONNX)
        if not models:
            msg = "No models selected for benchmarking."
            raise ValueError(msg)

        for model_entry in models:
            instance = model_entry.loader()
            cls._run_bench_single_model(
                args,
                instance.model_file_name,
                model_entry.name,
            )

    @classmethod
    def _generate_model_input(
        cls: type[ModelCommand],
        model_path: str,
        input_file: Path,
        name: str,
    ) -> None:
        """Load *model_path* into the default circuit and write its inputs to *input_file*.

        Raises:
            RuntimeError: Loading the model or generating/writing inputs failed.
        """
        instance = BaseCommand._build_circuit()  # noqa: SLF001
        instance.model_file_name = model_path

        try:
            instance.load_model(model_path)
        except Exception as e:
            msg = f"Failed to load model {model_path}: {e}"
            raise RuntimeError(msg) from e

        try:
            inputs = instance.get_inputs()
            formatted_inputs = instance.format_inputs(inputs)
            to_json(formatted_inputs, str(input_file))
        except Exception as e:
            msg = f"Failed to generate input for {name}: {e}"
            raise RuntimeError(msg) from e

    @classmethod
    def _resolve_iterations(cls: type[ModelCommand], args: argparse.Namespace) -> str:
        """Return the iteration count for the runner as a string.

        ``None`` (flag not given) falls back to ``DEFAULT_ITERATIONS``. An
        explicit value below 1 is rejected instead of being silently replaced
        by the default (``0``) or forwarded to the runner (negatives).

        Raises:
            ValueError: ``--iterations`` is less than 1.
        """
        iterations = args.iterations
        if iterations is None:
            iterations = cls.DEFAULT_ITERATIONS
        if iterations < 1:
            msg = f"--iterations must be a positive integer, got {iterations}"
            raise ValueError(msg)
        return str(iterations)

    @classmethod
    def _run_bench_single_model(
        cls: type[ModelCommand],
        args: argparse.Namespace,
        model_path: str,
        name: str,
    ) -> None:
        """Generate inputs for *model_path* and run the benchmark runner on it.

        Results go to ``--results`` or, by default, ``benchmarking/<name>.jsonl``.
        Setting ``JSTPROVE_DEBUG=1`` prints the runner command first.
        """
        # Validate cheap arguments before the potentially slow model load.
        iterations = cls._resolve_iterations(args)

        with tempfile.TemporaryDirectory() as tmpdir:
            input_file = Path(tmpdir) / "input.json"

            cls._generate_model_input(model_path, input_file, name)

            results = args.results or f"benchmarking/{name}.jsonl"
            ensure_parent_dir(results)

            cmd = [
                sys.executable,
                "-m",
                cls.SCRIPT_BENCHMARK_RUNNER,
                "--model",
                model_path,
                "--input",
                str(input_file),
                "--iterations",
                iterations,
                "--output",
                results,
                "--summarize",
            ]
            if os.environ.get("JSTPROVE_DEBUG") == "1":
                print(f"[debug] bench {name} cmd:", " ".join(cmd))  # noqa: T201

            # Runs inside the ``with`` so the temporary input file still exists.
            run_subprocess(cmd)
|