JSTprove 1.2.0__py3-none-macosx_11_0_arm64.whl → 1.4.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/METADATA +1 -1
- {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/RECORD +32 -26
- python/core/binaries/onnx_generic_circuit_1-4-0 +0 -0
- python/core/circuits/base.py +29 -12
- python/core/circuits/errors.py +1 -2
- python/core/model_processing/converters/base.py +3 -3
- python/core/model_processing/onnx_custom_ops/__init__.py +5 -4
- python/core/model_processing/onnx_quantizer/exceptions.py +2 -2
- python/core/model_processing/onnx_quantizer/layers/base.py +79 -2
- python/core/model_processing/onnx_quantizer/layers/clip.py +92 -0
- python/core/model_processing/onnx_quantizer/layers/max.py +49 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +79 -4
- python/core/model_processing/onnx_quantizer/layers/min.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -0
- python/core/model_templates/circuit_template.py +48 -38
- python/core/utils/errors.py +1 -1
- python/core/utils/scratch_tests.py +29 -23
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +18 -14
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +11 -13
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +35 -53
- python/tests/onnx_quantizer_tests/layers/base.py +1 -3
- python/tests/onnx_quantizer_tests/layers/clip_config.py +127 -0
- python/tests/onnx_quantizer_tests/layers/max_config.py +100 -0
- python/tests/onnx_quantizer_tests/layers/maxpool_config.py +106 -0
- python/tests/onnx_quantizer_tests/layers/min_config.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +6 -5
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +6 -1
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +17 -8
- python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/WHEEL +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.4.0.dist-info}/top_level.txt +0 -0
|
@@ -3,9 +3,12 @@ from __future__ import annotations
|
|
|
3
3
|
from typing import TYPE_CHECKING
|
|
4
4
|
|
|
5
5
|
if TYPE_CHECKING:
|
|
6
|
+
from typing import ClassVar
|
|
7
|
+
|
|
6
8
|
import onnx
|
|
7
9
|
|
|
8
10
|
from python.core.model_processing.onnx_custom_ops.onnx_helpers import (
|
|
11
|
+
extract_attributes,
|
|
9
12
|
get_attribute_ints,
|
|
10
13
|
)
|
|
11
14
|
from python.core.model_processing.onnx_quantizer.exceptions import InvalidParamError
|
|
@@ -21,6 +24,12 @@ class QuantizeMaxpool(QuantizerBase):
|
|
|
21
24
|
USE_WB = False
|
|
22
25
|
USE_SCALING = False
|
|
23
26
|
|
|
27
|
+
DEFAULT_ATTRS: ClassVar = {
|
|
28
|
+
"dilations": [1],
|
|
29
|
+
"pads": [0],
|
|
30
|
+
"strides": [1],
|
|
31
|
+
}
|
|
32
|
+
|
|
24
33
|
|
|
25
34
|
class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
26
35
|
"""
|
|
@@ -72,6 +81,35 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
72
81
|
InvalidParamError: If any requirement is not met.
|
|
73
82
|
"""
|
|
74
83
|
_ = initializer_map
|
|
84
|
+
attributes = extract_attributes(node)
|
|
85
|
+
ceil_mode = attributes.get("ceil_mode", None)
|
|
86
|
+
auto_pad = attributes.get("auto_pad", None)
|
|
87
|
+
storage_order = attributes.get("storage_order", None)
|
|
88
|
+
|
|
89
|
+
if ceil_mode != 0 and ceil_mode is not None:
|
|
90
|
+
raise InvalidParamError(
|
|
91
|
+
node.name,
|
|
92
|
+
node.op_type,
|
|
93
|
+
"ceil_mode must be 0",
|
|
94
|
+
"ceil_mode",
|
|
95
|
+
"0",
|
|
96
|
+
)
|
|
97
|
+
if auto_pad != "NOTSET" and auto_pad is not None:
|
|
98
|
+
raise InvalidParamError(
|
|
99
|
+
node.name,
|
|
100
|
+
node.op_type,
|
|
101
|
+
"auto_pad must be NOTSET",
|
|
102
|
+
"auto_pad",
|
|
103
|
+
"NOTSET",
|
|
104
|
+
)
|
|
105
|
+
if storage_order != 0 and storage_order is not None:
|
|
106
|
+
raise InvalidParamError(
|
|
107
|
+
node.name,
|
|
108
|
+
node.op_type,
|
|
109
|
+
"storage_order must be 0",
|
|
110
|
+
"storage_order",
|
|
111
|
+
"0",
|
|
112
|
+
)
|
|
75
113
|
self.check_all_params_exist(node)
|
|
76
114
|
self.check_params_size(node)
|
|
77
115
|
self.check_pool_pads(node)
|
|
@@ -85,8 +123,8 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
85
123
|
Raises:
|
|
86
124
|
InvalidParamError: If shape requirement is not met.
|
|
87
125
|
"""
|
|
88
|
-
|
|
89
|
-
|
|
126
|
+
required_attrs = ["kernel_shape"]
|
|
127
|
+
|
|
90
128
|
self.validate_required_attrs(node, required_attrs)
|
|
91
129
|
|
|
92
130
|
# Check dimension of kernel
|
|
@@ -121,11 +159,23 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
121
159
|
|
|
122
160
|
def check_pool_pads(self: MaxpoolQuantizer, node: onnx.NodeProto) -> None:
|
|
123
161
|
kernel_shape = get_attribute_ints(node, "kernel_shape", default=[])
|
|
124
|
-
|
|
162
|
+
pads_raw = get_attribute_ints(
|
|
163
|
+
node,
|
|
164
|
+
"pads",
|
|
165
|
+
default=self.DEFAULT_ATTRS.get("pads", None),
|
|
166
|
+
)
|
|
167
|
+
pads = self.adjust_pads(node, pads_raw)
|
|
168
|
+
|
|
125
169
|
if pads is None:
|
|
126
170
|
return
|
|
127
171
|
num_dims = len(kernel_shape)
|
|
128
|
-
|
|
172
|
+
|
|
173
|
+
if len(pads) == 1:
|
|
174
|
+
pads = pads * 2 * num_dims
|
|
175
|
+
elif len(pads) == num_dims:
|
|
176
|
+
# If only beginning pads given, repeat for end pads
|
|
177
|
+
pads = pads + pads
|
|
178
|
+
elif len(pads) != num_dims * 2:
|
|
129
179
|
raise InvalidParamError(
|
|
130
180
|
node.name,
|
|
131
181
|
node.op_type,
|
|
@@ -148,3 +198,28 @@ class MaxpoolQuantizer(BaseOpQuantizer, QuantizeMaxpool):
|
|
|
148
198
|
node.op_type,
|
|
149
199
|
f"pads[{dim + num_dims}]={pad_after} >= kernel[{dim}]={kernel}",
|
|
150
200
|
)
|
|
201
|
+
|
|
202
|
+
def adjust_pads(
|
|
203
|
+
self: MaxpoolQuantizer,
|
|
204
|
+
node: onnx.NodeProto,
|
|
205
|
+
pads_raw: str | int | list[int] | None,
|
|
206
|
+
) -> list[int]:
|
|
207
|
+
if pads_raw is None:
|
|
208
|
+
pads: list[int] = []
|
|
209
|
+
elif isinstance(pads_raw, str):
|
|
210
|
+
# single string, could be "0" or "1 2"
|
|
211
|
+
pads = [int(x) for x in pads_raw.split()]
|
|
212
|
+
elif isinstance(pads_raw, int):
|
|
213
|
+
# single integer
|
|
214
|
+
pads = [pads_raw]
|
|
215
|
+
elif isinstance(pads_raw, (list, tuple)):
|
|
216
|
+
# already a list of numbers (may be strings)
|
|
217
|
+
pads = [int(x) for x in pads_raw]
|
|
218
|
+
else:
|
|
219
|
+
raise InvalidParamError(
|
|
220
|
+
node.name,
|
|
221
|
+
node.op_type,
|
|
222
|
+
f"Cannot parse pads: {pads_raw}",
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
return pads
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING, ClassVar
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
import onnx
|
|
7
|
+
|
|
8
|
+
from python.core.model_processing.onnx_quantizer.layers.base import (
|
|
9
|
+
BaseOpQuantizer,
|
|
10
|
+
QuantizerBase,
|
|
11
|
+
ScaleConfig,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class QuantizeMin(QuantizerBase):
|
|
16
|
+
OP_TYPE = "Min"
|
|
17
|
+
DOMAIN = "" # standard ONNX domain
|
|
18
|
+
USE_WB = True # let framework wire inputs/outputs normally
|
|
19
|
+
USE_SCALING = False # passthrough: no internal scaling
|
|
20
|
+
SCALE_PLAN: ClassVar = {1: 1} # elementwise arity plan
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class MinQuantizer(BaseOpQuantizer, QuantizeMin):
|
|
24
|
+
"""
|
|
25
|
+
Passthrough quantizer for elementwise Min.
|
|
26
|
+
We rely on the converter to quantize graph inputs; no extra scaling here.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
def __init__(
|
|
30
|
+
self: MinQuantizer,
|
|
31
|
+
new_initializers: list[onnx.TensorProto] | None = None,
|
|
32
|
+
) -> None:
|
|
33
|
+
super().__init__()
|
|
34
|
+
if new_initializers is not None:
|
|
35
|
+
self.new_initializers = new_initializers
|
|
36
|
+
|
|
37
|
+
def quantize(
|
|
38
|
+
self: MinQuantizer,
|
|
39
|
+
node: onnx.NodeProto,
|
|
40
|
+
graph: onnx.GraphProto,
|
|
41
|
+
scale_config: ScaleConfig,
|
|
42
|
+
initializer_map: dict[str, onnx.TensorProto],
|
|
43
|
+
) -> list[onnx.NodeProto]:
|
|
44
|
+
# Delegate to QuantizerBase's generic passthrough implementation.
|
|
45
|
+
return QuantizeMin.quantize(self, node, graph, scale_config, initializer_map)
|
|
46
|
+
|
|
47
|
+
def check_supported(
|
|
48
|
+
self: MinQuantizer,
|
|
49
|
+
node: onnx.NodeProto,
|
|
50
|
+
initializer_map: dict[str, onnx.TensorProto] | None = None,
|
|
51
|
+
) -> None:
|
|
52
|
+
# Min has no attributes; elementwise, variadic ≥ 1 input per ONNX spec.
|
|
53
|
+
# We mirror Add/Max broadcasting behavior; no extra checks here.
|
|
54
|
+
_ = node, initializer_map
|
|
@@ -20,12 +20,15 @@ from python.core.model_processing.onnx_quantizer.layers.base import (
|
|
|
20
20
|
from python.core.model_processing.onnx_quantizer.layers.batchnorm import (
|
|
21
21
|
BatchnormQuantizer,
|
|
22
22
|
)
|
|
23
|
+
from python.core.model_processing.onnx_quantizer.layers.clip import ClipQuantizer
|
|
23
24
|
from python.core.model_processing.onnx_quantizer.layers.constant import (
|
|
24
25
|
ConstantQuantizer,
|
|
25
26
|
)
|
|
26
27
|
from python.core.model_processing.onnx_quantizer.layers.conv import ConvQuantizer
|
|
27
28
|
from python.core.model_processing.onnx_quantizer.layers.gemm import GemmQuantizer
|
|
29
|
+
from python.core.model_processing.onnx_quantizer.layers.max import MaxQuantizer
|
|
28
30
|
from python.core.model_processing.onnx_quantizer.layers.maxpool import MaxpoolQuantizer
|
|
31
|
+
from python.core.model_processing.onnx_quantizer.layers.min import MinQuantizer
|
|
29
32
|
from python.core.model_processing.onnx_quantizer.layers.mul import MulQuantizer
|
|
30
33
|
from python.core.model_processing.onnx_quantizer.layers.relu import ReluQuantizer
|
|
31
34
|
from python.core.model_processing.onnx_quantizer.layers.sub import SubQuantizer
|
|
@@ -74,6 +77,7 @@ class ONNXOpQuantizer:
|
|
|
74
77
|
|
|
75
78
|
# Register handlers
|
|
76
79
|
self.register("Add", AddQuantizer(self.new_initializers))
|
|
80
|
+
self.register("Clip", ClipQuantizer(self.new_initializers))
|
|
77
81
|
self.register("Sub", SubQuantizer(self.new_initializers))
|
|
78
82
|
self.register("Mul", MulQuantizer(self.new_initializers))
|
|
79
83
|
self.register("Conv", ConvQuantizer(self.new_initializers))
|
|
@@ -83,6 +87,8 @@ class ONNXOpQuantizer:
|
|
|
83
87
|
self.register("Constant", ConstantQuantizer())
|
|
84
88
|
self.register("MaxPool", MaxpoolQuantizer())
|
|
85
89
|
self.register("Flatten", PassthroughQuantizer())
|
|
90
|
+
self.register("Max", MaxQuantizer(self.new_initializers))
|
|
91
|
+
self.register("Min", MinQuantizer(self.new_initializers))
|
|
86
92
|
self.register("BatchNormalization", BatchnormQuantizer(self.new_initializers))
|
|
87
93
|
|
|
88
94
|
def register(
|
|
@@ -1,57 +1,67 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from secrets import randbelow
|
|
4
|
+
|
|
2
5
|
from python.core.circuits.base import Circuit
|
|
3
|
-
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
class SimpleCircuit(Circuit):
|
|
6
|
-
|
|
7
|
-
Note: This template is irrelevant if using the
|
|
8
|
-
The template only helps developers if they choose to incorporate other circuit
|
|
9
|
+
"""
|
|
10
|
+
Note: This template is irrelevant if using the ONNX circuit builder.
|
|
11
|
+
The template only helps developers if they choose to incorporate other circuit
|
|
12
|
+
builders into the framework.
|
|
9
13
|
|
|
10
|
-
To begin, we need to specify some basic attributes surrounding the circuit we will
|
|
11
|
-
|
|
12
|
-
name - name of the rust bin to be run by the circuit.
|
|
14
|
+
To begin, we need to specify some basic attributes surrounding the circuit we will
|
|
15
|
+
be using.
|
|
13
16
|
|
|
14
|
-
|
|
15
|
-
|
|
17
|
+
- `required_keys`: the variables in the input dictionary (and input file).
|
|
18
|
+
- `name`: name of the Rust bin to be run by the circuit.
|
|
19
|
+
- `scale_base`: base of the scaling applied to each value.
|
|
20
|
+
- `scale_exponent`: exponent applied to the base to get the scaling factor.
|
|
21
|
+
Scaling factor will be multiplied by each input.
|
|
16
22
|
|
|
17
|
-
Other default inputs can be defined below
|
|
18
|
-
|
|
19
|
-
|
|
23
|
+
Other default inputs can be defined below.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
def __init__(self, file_name: str | None = None) -> None:
|
|
20
27
|
# Initialize the base class
|
|
21
28
|
super().__init__()
|
|
22
|
-
|
|
29
|
+
self.file_name = file_name
|
|
30
|
+
|
|
23
31
|
# Circuit-specific parameters
|
|
24
32
|
self.required_keys = ["input_a", "input_b", "nonce"]
|
|
25
33
|
self.name = "simple_circuit" # Use exact name that matches the binary
|
|
26
34
|
|
|
27
35
|
self.scale_exponent = 1
|
|
28
36
|
self.scale_base = 1
|
|
29
|
-
|
|
37
|
+
|
|
30
38
|
self.input_a = 100
|
|
31
39
|
self.input_b = 200
|
|
32
|
-
self.nonce =
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
40
|
+
self.nonce = randbelow(10_000)
|
|
41
|
+
|
|
42
|
+
def get_inputs(self) -> dict[str, int]:
|
|
43
|
+
"""
|
|
44
|
+
Specify the inputs to the circuit, based on what was specified
|
|
45
|
+
in `__init__`.
|
|
46
|
+
"""
|
|
47
|
+
return {
|
|
48
|
+
"input_a": self.input_a,
|
|
49
|
+
"input_b": self.input_b,
|
|
50
|
+
"nonce": self.nonce,
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
def get_outputs(self, inputs: dict[str, int] | None = None) -> int:
|
|
44
54
|
"""
|
|
45
55
|
Compute the output of the circuit.
|
|
46
|
-
|
|
56
|
+
|
|
57
|
+
This is overwritten from the base class to ensure computation happens
|
|
58
|
+
only once.
|
|
47
59
|
"""
|
|
48
|
-
if inputs
|
|
49
|
-
inputs = {
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
# def format_outputs(self, outputs):
|
|
57
|
-
# return {"output": outputs.long().tolist()}
|
|
60
|
+
if inputs is None:
|
|
61
|
+
inputs = {
|
|
62
|
+
"input_a": self.input_a,
|
|
63
|
+
"input_b": self.input_b,
|
|
64
|
+
"nonce": self.nonce,
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
return inputs["input_a"] + inputs["input_b"]
|
python/core/utils/errors.py
CHANGED
|
@@ -30,7 +30,7 @@ class FileCacheError(CircuitExecutionError):
|
|
|
30
30
|
class ProofBackendError(CircuitExecutionError):
|
|
31
31
|
"""Raised when a Cargo command fails."""
|
|
32
32
|
|
|
33
|
-
def __init__(
|
|
33
|
+
def __init__(
|
|
34
34
|
self: ProofBackendError,
|
|
35
35
|
message: str,
|
|
36
36
|
command: list[str] | None = None,
|
|
@@ -1,66 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import onnx
|
|
2
|
-
from onnx import TensorProto, helper, shape_inference
|
|
3
|
-
from onnx import numpy_helper
|
|
4
|
-
from onnx import load, save
|
|
4
|
+
from onnx import TensorProto, helper, load, shape_inference
|
|
5
5
|
from onnx.utils import extract_model
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
|
|
8
|
+
def prune_model(
|
|
9
|
+
model_path: str,
|
|
10
|
+
output_names: list[str],
|
|
11
|
+
save_path: str,
|
|
12
|
+
) -> None:
|
|
13
|
+
"""Extract a sub-model with the same inputs and new outputs."""
|
|
8
14
|
model = load(model_path)
|
|
9
15
|
|
|
10
|
-
# Provide model input names and the new desired output names
|
|
16
|
+
# Provide model input names and the new desired output names.
|
|
11
17
|
input_names = [i.name for i in model.graph.input]
|
|
12
18
|
|
|
13
19
|
extract_model(
|
|
14
20
|
input_path=model_path,
|
|
15
21
|
output_path=save_path,
|
|
16
22
|
input_names=input_names,
|
|
17
|
-
output_names=output_names
|
|
23
|
+
output_names=output_names,
|
|
18
24
|
)
|
|
19
25
|
|
|
20
|
-
print(f"Pruned model saved to {save_path}")
|
|
26
|
+
print(f"Pruned model saved to {save_path}") # noqa: T201
|
|
21
27
|
|
|
22
28
|
|
|
23
|
-
def cut_model(
|
|
29
|
+
def cut_model(
|
|
30
|
+
model_path: str,
|
|
31
|
+
output_names: list[str],
|
|
32
|
+
save_path: str,
|
|
33
|
+
) -> None:
|
|
34
|
+
"""Replace the graph outputs with the tensors named in `output_names`."""
|
|
24
35
|
model = onnx.load(model_path)
|
|
25
36
|
model = shape_inference.infer_shapes(model)
|
|
26
37
|
|
|
27
38
|
graph = model.graph
|
|
28
39
|
|
|
29
|
-
# Remove all current outputs one by one (cannot use .clear() or assignment)
|
|
30
|
-
while
|
|
40
|
+
# Remove all current outputs one by one (cannot use .clear() or assignment).
|
|
41
|
+
while graph.output:
|
|
31
42
|
graph.output.pop()
|
|
32
43
|
|
|
33
|
-
# Add new outputs
|
|
44
|
+
# Add new outputs.
|
|
34
45
|
for name in output_names:
|
|
35
|
-
# Look in value_info, input, or output
|
|
46
|
+
# Look in value_info, input, or output.
|
|
36
47
|
candidates = list(graph.value_info) + list(graph.input) + list(graph.output)
|
|
37
48
|
value_info = next((vi for vi in candidates if vi.name == name), None)
|
|
38
49
|
if value_info is None:
|
|
39
|
-
|
|
50
|
+
msg = f"Tensor {name} not found in model graph."
|
|
51
|
+
raise ValueError(msg)
|
|
40
52
|
|
|
41
53
|
elem_type = value_info.type.tensor_type.elem_type
|
|
42
54
|
shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
|
|
43
55
|
new_output = helper.make_tensor_value_info(name, elem_type, shape)
|
|
44
56
|
graph.output.append(new_output)
|
|
57
|
+
|
|
45
58
|
for output in graph.output:
|
|
46
|
-
print(output)
|
|
59
|
+
print(output) # noqa: T201
|
|
47
60
|
if output.name == "/conv1/Conv_output_0":
|
|
48
61
|
output.type.tensor_type.elem_type = TensorProto.INT64
|
|
49
62
|
|
|
50
63
|
onnx.save(model, save_path)
|
|
51
|
-
print(f"Saved cut model with outputs {output_names} to {save_path}")
|
|
64
|
+
print(f"Saved cut model with outputs {output_names} to {save_path}") # noqa: T201
|
|
52
65
|
|
|
53
66
|
|
|
54
67
|
if __name__ == "__main__":
|
|
55
|
-
# /conv1/Conv_output_0
|
|
56
|
-
# prune_model(
|
|
57
|
-
# model_path="models_onnx/doom.onnx",
|
|
58
|
-
# output_names=["/Relu_2_output_0"], # replace with your intermediate tensor
|
|
59
|
-
# save_path= "models_onnx/test_doom_cut.onnx"
|
|
60
|
-
# )
|
|
61
|
-
# cut_model("models_onnx/doom.onnx",["/Relu_2_output_0"], "test_doom_after_conv.onnx")
|
|
62
68
|
prune_model(
|
|
63
69
|
model_path="models_onnx/doom.onnx",
|
|
64
70
|
output_names=["/Relu_3_output_0"], # replace with your intermediate tensor
|
|
65
|
-
save_path=
|
|
71
|
+
save_path="models_onnx/test_doom_cut.onnx",
|
|
66
72
|
)
|
|
@@ -2,7 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from collections.abc import Generator
|
|
9
|
+
|
|
6
10
|
|
|
7
11
|
import pytest
|
|
8
12
|
import torch
|
|
@@ -35,7 +39,7 @@ OUTPUTTWICE = 2
|
|
|
35
39
|
OUTPUTTHREETIMES = 3
|
|
36
40
|
|
|
37
41
|
|
|
38
|
-
@pytest.mark.e2e
|
|
42
|
+
@pytest.mark.e2e
|
|
39
43
|
def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
|
|
40
44
|
# Here you could just check that circuit file exists
|
|
41
45
|
circuit_compile_results[model_fixture["model"]] = False
|
|
@@ -43,7 +47,7 @@ def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
|
|
|
43
47
|
circuit_compile_results[model_fixture["model"]] = True
|
|
44
48
|
|
|
45
49
|
|
|
46
|
-
@pytest.mark.e2e
|
|
50
|
+
@pytest.mark.e2e
|
|
47
51
|
def test_witness_dev(
|
|
48
52
|
model_fixture: dict[str, Any],
|
|
49
53
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -89,7 +93,7 @@ def test_witness_dev(
|
|
|
89
93
|
witness_generated_results[model_fixture["model"]] = True
|
|
90
94
|
|
|
91
95
|
|
|
92
|
-
@pytest.mark.e2e
|
|
96
|
+
@pytest.mark.e2e
|
|
93
97
|
def test_witness_wrong_outputs_dev(
|
|
94
98
|
model_fixture: dict[str, Any],
|
|
95
99
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -145,7 +149,7 @@ def test_witness_wrong_outputs_dev(
|
|
|
145
149
|
), f"Expected '{output}' in stdout, but it was not found."
|
|
146
150
|
|
|
147
151
|
|
|
148
|
-
@pytest.mark.e2e
|
|
152
|
+
@pytest.mark.e2e
|
|
149
153
|
def test_witness_prove_verify_true_inputs_dev(
|
|
150
154
|
model_fixture: dict[str, Any],
|
|
151
155
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -239,7 +243,7 @@ def test_witness_prove_verify_true_inputs_dev(
|
|
|
239
243
|
), "Expected 'Verified' in stdout three times, but it was not found."
|
|
240
244
|
|
|
241
245
|
|
|
242
|
-
@pytest.mark.e2e
|
|
246
|
+
@pytest.mark.e2e
|
|
243
247
|
def test_witness_prove_verify_true_inputs_dev_expander_call(
|
|
244
248
|
model_fixture: dict[str, Any],
|
|
245
249
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -339,7 +343,7 @@ def test_witness_prove_verify_true_inputs_dev_expander_call(
|
|
|
339
343
|
assert stdout.count("proving") == 1, "Expected 'proving' but it was not found."
|
|
340
344
|
|
|
341
345
|
|
|
342
|
-
@pytest.mark.e2e
|
|
346
|
+
@pytest.mark.e2e
|
|
343
347
|
def test_witness_read_after_write_json(
|
|
344
348
|
model_fixture: dict[str, Any],
|
|
345
349
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -421,7 +425,7 @@ def test_witness_read_after_write_json(
|
|
|
421
425
|
), "Input JSON read is not identical to what was written"
|
|
422
426
|
|
|
423
427
|
|
|
424
|
-
@pytest.mark.e2e
|
|
428
|
+
@pytest.mark.e2e
|
|
425
429
|
def test_witness_fresh_compile_dev(
|
|
426
430
|
model_fixture: dict[str, Any],
|
|
427
431
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -467,7 +471,7 @@ def test_witness_fresh_compile_dev(
|
|
|
467
471
|
|
|
468
472
|
|
|
469
473
|
# Use once fixed input shape read in rust
|
|
470
|
-
@pytest.mark.e2e
|
|
474
|
+
@pytest.mark.e2e
|
|
471
475
|
def test_witness_incorrect_input_shape(
|
|
472
476
|
model_fixture: dict[str, Any],
|
|
473
477
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -555,7 +559,7 @@ def test_witness_incorrect_input_shape(
|
|
|
555
559
|
), f"Did not expect '{output}' in stdout, but it was found."
|
|
556
560
|
|
|
557
561
|
|
|
558
|
-
@pytest.mark.e2e
|
|
562
|
+
@pytest.mark.e2e
|
|
559
563
|
def test_witness_unscaled(
|
|
560
564
|
model_fixture: dict[str, Any],
|
|
561
565
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -666,7 +670,7 @@ def test_witness_unscaled(
|
|
|
666
670
|
)
|
|
667
671
|
|
|
668
672
|
|
|
669
|
-
@pytest.mark.e2e
|
|
673
|
+
@pytest.mark.e2e
|
|
670
674
|
def test_witness_unscaled_and_incorrect_shape_input(
|
|
671
675
|
model_fixture: dict[str, Any],
|
|
672
676
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -790,7 +794,7 @@ def test_witness_unscaled_and_incorrect_shape_input(
|
|
|
790
794
|
)
|
|
791
795
|
|
|
792
796
|
|
|
793
|
-
@pytest.mark.e2e
|
|
797
|
+
@pytest.mark.e2e
|
|
794
798
|
def test_witness_unscaled_and_incorrect_and_bad_named_input( # noqa: PLR0915
|
|
795
799
|
model_fixture: dict[str, Any],
|
|
796
800
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -937,7 +941,7 @@ def test_witness_unscaled_and_incorrect_and_bad_named_input( # noqa: PLR0915
|
|
|
937
941
|
)
|
|
938
942
|
|
|
939
943
|
|
|
940
|
-
@pytest.mark.e2e
|
|
944
|
+
@pytest.mark.e2e
|
|
941
945
|
def test_witness_wrong_name(
|
|
942
946
|
model_fixture: dict[str, Any],
|
|
943
947
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -1063,7 +1067,7 @@ def add_to_first_scalar(data: list, delta: float = 0.1) -> bool:
|
|
|
1063
1067
|
return False
|
|
1064
1068
|
|
|
1065
1069
|
|
|
1066
|
-
@pytest.mark.e2e
|
|
1070
|
+
@pytest.mark.e2e
|
|
1067
1071
|
def test_witness_prove_verify_false_inputs_dev(
|
|
1068
1072
|
model_fixture: dict[str, Any],
|
|
1069
1073
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Generator, Mapping, Sequence
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import TYPE_CHECKING, Any,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, TypeAlias
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytest
|
|
@@ -64,7 +65,7 @@ def model_fixture(
|
|
|
64
65
|
}
|
|
65
66
|
|
|
66
67
|
|
|
67
|
-
@pytest.fixture
|
|
68
|
+
@pytest.fixture
|
|
68
69
|
def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
69
70
|
witness_path = tmp_path / "temp_witness.txt"
|
|
70
71
|
# Give it to the test
|
|
@@ -75,7 +76,7 @@ def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
75
76
|
witness_path.unlink()
|
|
76
77
|
|
|
77
78
|
|
|
78
|
-
@pytest.fixture
|
|
79
|
+
@pytest.fixture
|
|
79
80
|
def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
80
81
|
input_path = tmp_path / "temp_input.txt"
|
|
81
82
|
# Give it to the test
|
|
@@ -86,7 +87,7 @@ def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
86
87
|
input_path.unlink()
|
|
87
88
|
|
|
88
89
|
|
|
89
|
-
@pytest.fixture
|
|
90
|
+
@pytest.fixture
|
|
90
91
|
def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
91
92
|
output_path = tmp_path / "temp_output.txt"
|
|
92
93
|
# Give it to the test
|
|
@@ -97,7 +98,7 @@ def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
97
98
|
output_path.unlink()
|
|
98
99
|
|
|
99
100
|
|
|
100
|
-
@pytest.fixture
|
|
101
|
+
@pytest.fixture
|
|
101
102
|
def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
102
103
|
output_path = tmp_path / "temp_proof.txt"
|
|
103
104
|
# Give it to the test
|
|
@@ -108,13 +109,10 @@ def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
108
109
|
output_path.unlink()
|
|
109
110
|
|
|
110
111
|
|
|
111
|
-
ScalarOrTensor: TypeAlias =
|
|
112
|
-
NestedArray: TypeAlias =
|
|
113
|
-
ScalarOrTensor
|
|
114
|
-
|
|
115
|
-
tuple["NestedArray"],
|
|
116
|
-
np.ndarray,
|
|
117
|
-
]
|
|
112
|
+
ScalarOrTensor: TypeAlias = int | float | torch.Tensor
|
|
113
|
+
NestedArray: TypeAlias = (
|
|
114
|
+
ScalarOrTensor | list["NestedArray"] | tuple["NestedArray"] | np.ndarray
|
|
115
|
+
)
|
|
118
116
|
|
|
119
117
|
|
|
120
118
|
def add_1_to_first_element(x: NestedArray) -> NestedArray:
|
|
@@ -137,7 +135,7 @@ def add_1_to_first_element(x: NestedArray) -> NestedArray:
|
|
|
137
135
|
circuit_compile_results = {}
|
|
138
136
|
witness_generated_results = {}
|
|
139
137
|
|
|
140
|
-
Nested: TypeAlias =
|
|
138
|
+
Nested: TypeAlias = float | Mapping[str, "Nested"] | Sequence["Nested"]
|
|
141
139
|
|
|
142
140
|
|
|
143
141
|
def contains_float(obj: Nested) -> bool:
|