JSTprove 1.2.0__py3-none-macosx_11_0_arm64.whl → 1.3.0__py3-none-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of JSTprove might be problematic. Click here for more details.
- {jstprove-1.2.0.dist-info → jstprove-1.3.0.dist-info}/METADATA +1 -1
- {jstprove-1.2.0.dist-info → jstprove-1.3.0.dist-info}/RECORD +30 -24
- python/core/binaries/onnx_generic_circuit_1-3-0 +0 -0
- python/core/circuits/base.py +29 -12
- python/core/circuits/errors.py +1 -2
- python/core/model_processing/converters/base.py +3 -3
- python/core/model_processing/onnx_custom_ops/__init__.py +5 -4
- python/core/model_processing/onnx_quantizer/exceptions.py +2 -2
- python/core/model_processing/onnx_quantizer/layers/base.py +34 -0
- python/core/model_processing/onnx_quantizer/layers/clip.py +92 -0
- python/core/model_processing/onnx_quantizer/layers/max.py +49 -0
- python/core/model_processing/onnx_quantizer/layers/min.py +54 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +6 -0
- python/core/model_templates/circuit_template.py +48 -38
- python/core/utils/errors.py +1 -1
- python/core/utils/scratch_tests.py +29 -23
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +18 -14
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +11 -13
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +35 -53
- python/tests/onnx_quantizer_tests/layers/base.py +1 -3
- python/tests/onnx_quantizer_tests/layers/clip_config.py +127 -0
- python/tests/onnx_quantizer_tests/layers/max_config.py +100 -0
- python/tests/onnx_quantizer_tests/layers/min_config.py +94 -0
- python/tests/onnx_quantizer_tests/layers_tests/test_integration.py +6 -5
- python/tests/onnx_quantizer_tests/layers_tests/test_quantize.py +6 -1
- python/tests/onnx_quantizer_tests/test_registered_quantizers.py +17 -8
- python/core/binaries/onnx_generic_circuit_1-2-0 +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.3.0.dist-info}/WHEEL +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.3.0.dist-info}/entry_points.txt +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.3.0.dist-info}/licenses/LICENSE +0 -0
- {jstprove-1.2.0.dist-info → jstprove-1.3.0.dist-info}/top_level.txt +0 -0
|
@@ -1,57 +1,67 @@
|
|
|
1
|
-
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from secrets import randbelow
|
|
4
|
+
|
|
2
5
|
from python.core.circuits.base import Circuit
|
|
3
|
-
|
|
6
|
+
|
|
4
7
|
|
|
5
8
|
class SimpleCircuit(Circuit):
|
|
6
|
-
|
|
7
|
-
Note: This template is irrelevant if using the
|
|
8
|
-
The template only helps developers if they choose to incorporate other circuit
|
|
9
|
+
"""
|
|
10
|
+
Note: This template is irrelevant if using the ONNX circuit builder.
|
|
11
|
+
The template only helps developers if they choose to incorporate other circuit
|
|
12
|
+
builders into the framework.
|
|
9
13
|
|
|
10
|
-
To begin, we need to specify some basic attributes surrounding the circuit we will
|
|
11
|
-
|
|
12
|
-
name - name of the rust bin to be run by the circuit.
|
|
14
|
+
To begin, we need to specify some basic attributes surrounding the circuit we will
|
|
15
|
+
be using.
|
|
13
16
|
|
|
14
|
-
|
|
15
|
-
|
|
17
|
+
- `required_keys`: the variables in the input dictionary (and input file).
|
|
18
|
+
- `name`: name of the Rust bin to be run by the circuit.
|
|
19
|
+
- `scale_base`: base of the scaling applied to each value.
|
|
20
|
+
- `scale_exponent`: exponent applied to the base to get the scaling factor.
|
|
21
|
+
Scaling factor will be multiplied by each input.
|
|
16
22
|
|
|
17
|
-
Other default inputs can be defined below
|
|
18
|
-
|
|
19
|
-
|
|
23
|
+
Other default inputs can be defined below.
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
def __init__(self, file_name: str | None = None) -> None:
|
|
20
27
|
# Initialize the base class
|
|
21
28
|
super().__init__()
|
|
22
|
-
|
|
29
|
+
self.file_name = file_name
|
|
30
|
+
|
|
23
31
|
# Circuit-specific parameters
|
|
24
32
|
self.required_keys = ["input_a", "input_b", "nonce"]
|
|
25
33
|
self.name = "simple_circuit" # Use exact name that matches the binary
|
|
26
34
|
|
|
27
35
|
self.scale_exponent = 1
|
|
28
36
|
self.scale_base = 1
|
|
29
|
-
|
|
37
|
+
|
|
30
38
|
self.input_a = 100
|
|
31
39
|
self.input_b = 200
|
|
32
|
-
self.nonce =
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
40
|
+
self.nonce = randbelow(10_000)
|
|
41
|
+
|
|
42
|
+
def get_inputs(self) -> dict[str, int]:
|
|
43
|
+
"""
|
|
44
|
+
Specify the inputs to the circuit, based on what was specified
|
|
45
|
+
in `__init__`.
|
|
46
|
+
"""
|
|
47
|
+
return {
|
|
48
|
+
"input_a": self.input_a,
|
|
49
|
+
"input_b": self.input_b,
|
|
50
|
+
"nonce": self.nonce,
|
|
51
|
+
}
|
|
52
|
+
|
|
53
|
+
def get_outputs(self, inputs: dict[str, int] | None = None) -> int:
|
|
44
54
|
"""
|
|
45
55
|
Compute the output of the circuit.
|
|
46
|
-
|
|
56
|
+
|
|
57
|
+
This is overwritten from the base class to ensure computation happens
|
|
58
|
+
only once.
|
|
47
59
|
"""
|
|
48
|
-
if inputs
|
|
49
|
-
inputs = {
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
# def format_outputs(self, outputs):
|
|
57
|
-
# return {"output": outputs.long().tolist()}
|
|
60
|
+
if inputs is None:
|
|
61
|
+
inputs = {
|
|
62
|
+
"input_a": self.input_a,
|
|
63
|
+
"input_b": self.input_b,
|
|
64
|
+
"nonce": self.nonce,
|
|
65
|
+
}
|
|
66
|
+
|
|
67
|
+
return inputs["input_a"] + inputs["input_b"]
|
python/core/utils/errors.py
CHANGED
|
@@ -30,7 +30,7 @@ class FileCacheError(CircuitExecutionError):
|
|
|
30
30
|
class ProofBackendError(CircuitExecutionError):
|
|
31
31
|
"""Raised when a Cargo command fails."""
|
|
32
32
|
|
|
33
|
-
def __init__(
|
|
33
|
+
def __init__(
|
|
34
34
|
self: ProofBackendError,
|
|
35
35
|
message: str,
|
|
36
36
|
command: list[str] | None = None,
|
|
@@ -1,66 +1,72 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import onnx
|
|
2
|
-
from onnx import TensorProto, helper, shape_inference
|
|
3
|
-
from onnx import numpy_helper
|
|
4
|
-
from onnx import load, save
|
|
4
|
+
from onnx import TensorProto, helper, load, shape_inference
|
|
5
5
|
from onnx.utils import extract_model
|
|
6
6
|
|
|
7
|
-
|
|
7
|
+
|
|
8
|
+
def prune_model(
|
|
9
|
+
model_path: str,
|
|
10
|
+
output_names: list[str],
|
|
11
|
+
save_path: str,
|
|
12
|
+
) -> None:
|
|
13
|
+
"""Extract a sub-model with the same inputs and new outputs."""
|
|
8
14
|
model = load(model_path)
|
|
9
15
|
|
|
10
|
-
# Provide model input names and the new desired output names
|
|
16
|
+
# Provide model input names and the new desired output names.
|
|
11
17
|
input_names = [i.name for i in model.graph.input]
|
|
12
18
|
|
|
13
19
|
extract_model(
|
|
14
20
|
input_path=model_path,
|
|
15
21
|
output_path=save_path,
|
|
16
22
|
input_names=input_names,
|
|
17
|
-
output_names=output_names
|
|
23
|
+
output_names=output_names,
|
|
18
24
|
)
|
|
19
25
|
|
|
20
|
-
print(f"Pruned model saved to {save_path}")
|
|
26
|
+
print(f"Pruned model saved to {save_path}") # noqa: T201
|
|
21
27
|
|
|
22
28
|
|
|
23
|
-
def cut_model(
|
|
29
|
+
def cut_model(
|
|
30
|
+
model_path: str,
|
|
31
|
+
output_names: list[str],
|
|
32
|
+
save_path: str,
|
|
33
|
+
) -> None:
|
|
34
|
+
"""Replace the graph outputs with the tensors named in `output_names`."""
|
|
24
35
|
model = onnx.load(model_path)
|
|
25
36
|
model = shape_inference.infer_shapes(model)
|
|
26
37
|
|
|
27
38
|
graph = model.graph
|
|
28
39
|
|
|
29
|
-
# Remove all current outputs one by one (cannot use .clear() or assignment)
|
|
30
|
-
while
|
|
40
|
+
# Remove all current outputs one by one (cannot use .clear() or assignment).
|
|
41
|
+
while graph.output:
|
|
31
42
|
graph.output.pop()
|
|
32
43
|
|
|
33
|
-
# Add new outputs
|
|
44
|
+
# Add new outputs.
|
|
34
45
|
for name in output_names:
|
|
35
|
-
# Look in value_info, input, or output
|
|
46
|
+
# Look in value_info, input, or output.
|
|
36
47
|
candidates = list(graph.value_info) + list(graph.input) + list(graph.output)
|
|
37
48
|
value_info = next((vi for vi in candidates if vi.name == name), None)
|
|
38
49
|
if value_info is None:
|
|
39
|
-
|
|
50
|
+
msg = f"Tensor {name} not found in model graph."
|
|
51
|
+
raise ValueError(msg)
|
|
40
52
|
|
|
41
53
|
elem_type = value_info.type.tensor_type.elem_type
|
|
42
54
|
shape = [dim.dim_value for dim in value_info.type.tensor_type.shape.dim]
|
|
43
55
|
new_output = helper.make_tensor_value_info(name, elem_type, shape)
|
|
44
56
|
graph.output.append(new_output)
|
|
57
|
+
|
|
45
58
|
for output in graph.output:
|
|
46
|
-
print(output)
|
|
59
|
+
print(output) # noqa: T201
|
|
47
60
|
if output.name == "/conv1/Conv_output_0":
|
|
48
61
|
output.type.tensor_type.elem_type = TensorProto.INT64
|
|
49
62
|
|
|
50
63
|
onnx.save(model, save_path)
|
|
51
|
-
print(f"Saved cut model with outputs {output_names} to {save_path}")
|
|
64
|
+
print(f"Saved cut model with outputs {output_names} to {save_path}") # noqa: T201
|
|
52
65
|
|
|
53
66
|
|
|
54
67
|
if __name__ == "__main__":
|
|
55
|
-
# /conv1/Conv_output_0
|
|
56
|
-
# prune_model(
|
|
57
|
-
# model_path="models_onnx/doom.onnx",
|
|
58
|
-
# output_names=["/Relu_2_output_0"], # replace with your intermediate tensor
|
|
59
|
-
# save_path= "models_onnx/test_doom_cut.onnx"
|
|
60
|
-
# )
|
|
61
|
-
# cut_model("models_onnx/doom.onnx",["/Relu_2_output_0"], "test_doom_after_conv.onnx")
|
|
62
68
|
prune_model(
|
|
63
69
|
model_path="models_onnx/doom.onnx",
|
|
64
70
|
output_names=["/Relu_3_output_0"], # replace with your intermediate tensor
|
|
65
|
-
save_path=
|
|
71
|
+
save_path="models_onnx/test_doom_cut.onnx",
|
|
66
72
|
)
|
|
@@ -2,7 +2,11 @@ from __future__ import annotations
|
|
|
2
2
|
|
|
3
3
|
import json
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from collections.abc import Generator
|
|
9
|
+
|
|
6
10
|
|
|
7
11
|
import pytest
|
|
8
12
|
import torch
|
|
@@ -35,7 +39,7 @@ OUTPUTTWICE = 2
|
|
|
35
39
|
OUTPUTTHREETIMES = 3
|
|
36
40
|
|
|
37
41
|
|
|
38
|
-
@pytest.mark.e2e
|
|
42
|
+
@pytest.mark.e2e
|
|
39
43
|
def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
|
|
40
44
|
# Here you could just check that circuit file exists
|
|
41
45
|
circuit_compile_results[model_fixture["model"]] = False
|
|
@@ -43,7 +47,7 @@ def test_circuit_compiles(model_fixture: dict[str, Any]) -> None:
|
|
|
43
47
|
circuit_compile_results[model_fixture["model"]] = True
|
|
44
48
|
|
|
45
49
|
|
|
46
|
-
@pytest.mark.e2e
|
|
50
|
+
@pytest.mark.e2e
|
|
47
51
|
def test_witness_dev(
|
|
48
52
|
model_fixture: dict[str, Any],
|
|
49
53
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -89,7 +93,7 @@ def test_witness_dev(
|
|
|
89
93
|
witness_generated_results[model_fixture["model"]] = True
|
|
90
94
|
|
|
91
95
|
|
|
92
|
-
@pytest.mark.e2e
|
|
96
|
+
@pytest.mark.e2e
|
|
93
97
|
def test_witness_wrong_outputs_dev(
|
|
94
98
|
model_fixture: dict[str, Any],
|
|
95
99
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -145,7 +149,7 @@ def test_witness_wrong_outputs_dev(
|
|
|
145
149
|
), f"Expected '{output}' in stdout, but it was not found."
|
|
146
150
|
|
|
147
151
|
|
|
148
|
-
@pytest.mark.e2e
|
|
152
|
+
@pytest.mark.e2e
|
|
149
153
|
def test_witness_prove_verify_true_inputs_dev(
|
|
150
154
|
model_fixture: dict[str, Any],
|
|
151
155
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -239,7 +243,7 @@ def test_witness_prove_verify_true_inputs_dev(
|
|
|
239
243
|
), "Expected 'Verified' in stdout three times, but it was not found."
|
|
240
244
|
|
|
241
245
|
|
|
242
|
-
@pytest.mark.e2e
|
|
246
|
+
@pytest.mark.e2e
|
|
243
247
|
def test_witness_prove_verify_true_inputs_dev_expander_call(
|
|
244
248
|
model_fixture: dict[str, Any],
|
|
245
249
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -339,7 +343,7 @@ def test_witness_prove_verify_true_inputs_dev_expander_call(
|
|
|
339
343
|
assert stdout.count("proving") == 1, "Expected 'proving' but it was not found."
|
|
340
344
|
|
|
341
345
|
|
|
342
|
-
@pytest.mark.e2e
|
|
346
|
+
@pytest.mark.e2e
|
|
343
347
|
def test_witness_read_after_write_json(
|
|
344
348
|
model_fixture: dict[str, Any],
|
|
345
349
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -421,7 +425,7 @@ def test_witness_read_after_write_json(
|
|
|
421
425
|
), "Input JSON read is not identical to what was written"
|
|
422
426
|
|
|
423
427
|
|
|
424
|
-
@pytest.mark.e2e
|
|
428
|
+
@pytest.mark.e2e
|
|
425
429
|
def test_witness_fresh_compile_dev(
|
|
426
430
|
model_fixture: dict[str, Any],
|
|
427
431
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -467,7 +471,7 @@ def test_witness_fresh_compile_dev(
|
|
|
467
471
|
|
|
468
472
|
|
|
469
473
|
# Use once fixed input shape read in rust
|
|
470
|
-
@pytest.mark.e2e
|
|
474
|
+
@pytest.mark.e2e
|
|
471
475
|
def test_witness_incorrect_input_shape(
|
|
472
476
|
model_fixture: dict[str, Any],
|
|
473
477
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -555,7 +559,7 @@ def test_witness_incorrect_input_shape(
|
|
|
555
559
|
), f"Did not expect '{output}' in stdout, but it was found."
|
|
556
560
|
|
|
557
561
|
|
|
558
|
-
@pytest.mark.e2e
|
|
562
|
+
@pytest.mark.e2e
|
|
559
563
|
def test_witness_unscaled(
|
|
560
564
|
model_fixture: dict[str, Any],
|
|
561
565
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -666,7 +670,7 @@ def test_witness_unscaled(
|
|
|
666
670
|
)
|
|
667
671
|
|
|
668
672
|
|
|
669
|
-
@pytest.mark.e2e
|
|
673
|
+
@pytest.mark.e2e
|
|
670
674
|
def test_witness_unscaled_and_incorrect_shape_input(
|
|
671
675
|
model_fixture: dict[str, Any],
|
|
672
676
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -790,7 +794,7 @@ def test_witness_unscaled_and_incorrect_shape_input(
|
|
|
790
794
|
)
|
|
791
795
|
|
|
792
796
|
|
|
793
|
-
@pytest.mark.e2e
|
|
797
|
+
@pytest.mark.e2e
|
|
794
798
|
def test_witness_unscaled_and_incorrect_and_bad_named_input( # noqa: PLR0915
|
|
795
799
|
model_fixture: dict[str, Any],
|
|
796
800
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -937,7 +941,7 @@ def test_witness_unscaled_and_incorrect_and_bad_named_input( # noqa: PLR0915
|
|
|
937
941
|
)
|
|
938
942
|
|
|
939
943
|
|
|
940
|
-
@pytest.mark.e2e
|
|
944
|
+
@pytest.mark.e2e
|
|
941
945
|
def test_witness_wrong_name(
|
|
942
946
|
model_fixture: dict[str, Any],
|
|
943
947
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -1063,7 +1067,7 @@ def add_to_first_scalar(data: list, delta: float = 0.1) -> bool:
|
|
|
1063
1067
|
return False
|
|
1064
1068
|
|
|
1065
1069
|
|
|
1066
|
-
@pytest.mark.e2e
|
|
1070
|
+
@pytest.mark.e2e
|
|
1067
1071
|
def test_witness_prove_verify_false_inputs_dev(
|
|
1068
1072
|
model_fixture: dict[str, Any],
|
|
1069
1073
|
capsys: Generator[pytest.CaptureFixture[str], None, None],
|
|
@@ -1,7 +1,8 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
+
from collections.abc import Generator, Mapping, Sequence
|
|
3
4
|
from pathlib import Path
|
|
4
|
-
from typing import TYPE_CHECKING, Any,
|
|
5
|
+
from typing import TYPE_CHECKING, Any, TypeAlias
|
|
5
6
|
|
|
6
7
|
import numpy as np
|
|
7
8
|
import pytest
|
|
@@ -64,7 +65,7 @@ def model_fixture(
|
|
|
64
65
|
}
|
|
65
66
|
|
|
66
67
|
|
|
67
|
-
@pytest.fixture
|
|
68
|
+
@pytest.fixture
|
|
68
69
|
def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
69
70
|
witness_path = tmp_path / "temp_witness.txt"
|
|
70
71
|
# Give it to the test
|
|
@@ -75,7 +76,7 @@ def temp_witness_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
75
76
|
witness_path.unlink()
|
|
76
77
|
|
|
77
78
|
|
|
78
|
-
@pytest.fixture
|
|
79
|
+
@pytest.fixture
|
|
79
80
|
def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
80
81
|
input_path = tmp_path / "temp_input.txt"
|
|
81
82
|
# Give it to the test
|
|
@@ -86,7 +87,7 @@ def temp_input_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
86
87
|
input_path.unlink()
|
|
87
88
|
|
|
88
89
|
|
|
89
|
-
@pytest.fixture
|
|
90
|
+
@pytest.fixture
|
|
90
91
|
def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
91
92
|
output_path = tmp_path / "temp_output.txt"
|
|
92
93
|
# Give it to the test
|
|
@@ -97,7 +98,7 @@ def temp_output_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
97
98
|
output_path.unlink()
|
|
98
99
|
|
|
99
100
|
|
|
100
|
-
@pytest.fixture
|
|
101
|
+
@pytest.fixture
|
|
101
102
|
def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
102
103
|
output_path = tmp_path / "temp_proof.txt"
|
|
103
104
|
# Give it to the test
|
|
@@ -108,13 +109,10 @@ def temp_proof_file(tmp_path: str) -> Generator[Path, None, None]:
|
|
|
108
109
|
output_path.unlink()
|
|
109
110
|
|
|
110
111
|
|
|
111
|
-
ScalarOrTensor: TypeAlias =
|
|
112
|
-
NestedArray: TypeAlias =
|
|
113
|
-
ScalarOrTensor
|
|
114
|
-
|
|
115
|
-
tuple["NestedArray"],
|
|
116
|
-
np.ndarray,
|
|
117
|
-
]
|
|
112
|
+
ScalarOrTensor: TypeAlias = int | float | torch.Tensor
|
|
113
|
+
NestedArray: TypeAlias = (
|
|
114
|
+
ScalarOrTensor | list["NestedArray"] | tuple["NestedArray"] | np.ndarray
|
|
115
|
+
)
|
|
118
116
|
|
|
119
117
|
|
|
120
118
|
def add_1_to_first_element(x: NestedArray) -> NestedArray:
|
|
@@ -137,7 +135,7 @@ def add_1_to_first_element(x: NestedArray) -> NestedArray:
|
|
|
137
135
|
circuit_compile_results = {}
|
|
138
136
|
witness_generated_results = {}
|
|
139
137
|
|
|
140
|
-
Nested: TypeAlias =
|
|
138
|
+
Nested: TypeAlias = float | Mapping[str, "Nested"] | Sequence["Nested"]
|
|
141
139
|
|
|
142
140
|
|
|
143
141
|
def contains_float(obj: Nested) -> bool:
|
|
@@ -1,115 +1,97 @@
|
|
|
1
|
-
import
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
2
3
|
import numpy as np
|
|
3
|
-
import torch
|
|
4
4
|
import onnx
|
|
5
|
-
|
|
6
|
-
|
|
5
|
+
import pytest
|
|
6
|
+
import torch
|
|
7
|
+
from onnx import TensorProto, helper, shape_inference
|
|
7
8
|
|
|
8
9
|
from python.core.model_processing.converters.onnx_converter import ONNXConverter
|
|
9
|
-
from python.core.model_processing.onnx_custom_ops.onnx_helpers import extract_shape_dict
|
|
10
|
-
from python.core.model_processing.onnx_quantizer.onnx_op_quantizer import ONNXOpQuantizer
|
|
11
|
-
|
|
12
|
-
from onnxruntime import InferenceSession, SessionOptions
|
|
13
|
-
from onnxruntime_extensions import get_library_path, OrtPyFunction
|
|
14
|
-
from python.core.model_processing.onnx_custom_ops import conv
|
|
15
|
-
|
|
16
|
-
from python.core.model_processing.onnx_custom_ops.conv import int64_conv
|
|
17
|
-
from python.core.model_processing.onnx_custom_ops.gemm import int64_gemm7
|
|
18
10
|
|
|
19
11
|
|
|
20
12
|
@pytest.fixture
|
|
21
|
-
def tiny_conv_model_path(tmp_path):
|
|
13
|
+
def tiny_conv_model_path(tmp_path: Path) -> Path:
|
|
22
14
|
# Create input and output tensor info
|
|
23
|
-
input_tensor = helper.make_tensor_value_info(
|
|
24
|
-
output_tensor = helper.make_tensor_value_info(
|
|
15
|
+
input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 4, 4])
|
|
16
|
+
output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 2])
|
|
25
17
|
|
|
26
18
|
# Kernel weights (3x3 ones)
|
|
27
|
-
|
|
28
|
-
name=
|
|
19
|
+
w_init = helper.make_tensor(
|
|
20
|
+
name="W",
|
|
29
21
|
data_type=TensorProto.FLOAT,
|
|
30
22
|
dims=[1, 1, 3, 3],
|
|
31
|
-
vals=np.ones((1 * 1 * 3 * 3), dtype=np.float32).tolist()
|
|
23
|
+
vals=np.ones((1 * 1 * 3 * 3), dtype=np.float32).tolist(),
|
|
32
24
|
)
|
|
33
|
-
|
|
34
|
-
name=
|
|
25
|
+
z_init = helper.make_tensor(
|
|
26
|
+
name="Z",
|
|
35
27
|
data_type=TensorProto.FLOAT,
|
|
36
28
|
dims=[1],
|
|
37
|
-
vals=np.ones((
|
|
29
|
+
vals=np.ones((1), dtype=np.float32).tolist(),
|
|
38
30
|
)
|
|
39
31
|
|
|
40
32
|
# Conv node with no padding, stride 1
|
|
41
33
|
conv_node = helper.make_node(
|
|
42
|
-
|
|
43
|
-
inputs=[
|
|
44
|
-
outputs=[
|
|
34
|
+
"Conv",
|
|
35
|
+
inputs=["X", "W", "Z"],
|
|
36
|
+
outputs=["Y"],
|
|
45
37
|
kernel_shape=[3, 3],
|
|
46
38
|
pads=[0, 0, 0, 0],
|
|
47
39
|
strides=[1, 1],
|
|
48
|
-
dilations
|
|
40
|
+
dilations=[1, 1],
|
|
49
41
|
)
|
|
50
42
|
|
|
51
43
|
# Build graph and model
|
|
52
44
|
graph = helper.make_graph(
|
|
53
45
|
nodes=[conv_node],
|
|
54
|
-
name=
|
|
46
|
+
name="TinyConvGraph",
|
|
55
47
|
inputs=[input_tensor],
|
|
56
48
|
outputs=[output_tensor],
|
|
57
|
-
initializer=[
|
|
49
|
+
initializer=[w_init, z_init],
|
|
58
50
|
)
|
|
59
51
|
|
|
60
|
-
model = helper.make_model(graph, producer_name=
|
|
52
|
+
model = helper.make_model(graph, producer_name="tiny-conv-example")
|
|
61
53
|
|
|
62
54
|
# Save to a temporary file
|
|
63
55
|
model_path = tmp_path / "tiny_conv.onnx"
|
|
64
56
|
onnx.save(model, str(model_path))
|
|
65
57
|
|
|
66
|
-
return
|
|
58
|
+
return model_path
|
|
59
|
+
|
|
67
60
|
|
|
68
61
|
@pytest.mark.integration
|
|
69
|
-
def test_tiny_conv(tiny_conv_model_path):
|
|
62
|
+
def test_tiny_conv(tiny_conv_model_path: Path, tmp_path: Path) -> None:
|
|
70
63
|
path = tiny_conv_model_path
|
|
71
64
|
|
|
72
65
|
converter = ONNXConverter()
|
|
73
66
|
|
|
74
|
-
|
|
75
|
-
id_count = 0
|
|
67
|
+
# Load and validate original model
|
|
76
68
|
model = onnx.load(path)
|
|
77
|
-
# Fix, can remove this next line
|
|
78
|
-
onnx.checker.check_model(model)
|
|
79
|
-
|
|
80
|
-
# Check the model and print Y"s shape information
|
|
81
69
|
onnx.checker.check_model(model)
|
|
82
|
-
print(f"Before shape inference, the shape info of Y is:\n{model.graph.value_info}")
|
|
83
70
|
|
|
84
|
-
# Apply shape inference
|
|
71
|
+
# Apply shape inference and validate
|
|
85
72
|
inferred_model = shape_inference.infer_shapes(model)
|
|
86
|
-
|
|
87
|
-
# Check the model and print Y"s shape information
|
|
88
73
|
onnx.checker.check_model(inferred_model)
|
|
89
|
-
# print(f"After shape inference, the shape info of Y is:\n{inferred_model.graph.value_info}")
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
domain_to_version = {opset.domain: opset.version for opset in model.opset_import}
|
|
93
|
-
|
|
94
|
-
inferred_model = shape_inference.infer_shapes(model)
|
|
95
|
-
output_name_to_shape = extract_shape_dict(inferred_model)
|
|
96
|
-
id_count = 0
|
|
97
74
|
|
|
75
|
+
# Quantize and add custom domain
|
|
98
76
|
new_model = converter.quantize_model(model, 2, 21)
|
|
99
77
|
custom_domain = onnx.helper.make_operatorsetid(domain="ai.onnx.contrib", version=1)
|
|
100
78
|
new_model.opset_import.append(custom_domain)
|
|
101
79
|
onnx.checker.check_model(new_model)
|
|
102
80
|
|
|
103
|
-
|
|
81
|
+
# Save quantized model
|
|
82
|
+
out_path = tmp_path / "model_quant.onnx"
|
|
83
|
+
with out_path.open("wb") as f:
|
|
104
84
|
f.write(new_model.SerializeToString())
|
|
105
85
|
|
|
106
|
-
model
|
|
107
|
-
onnx.
|
|
86
|
+
# Reload quantized model to ensure it is valid
|
|
87
|
+
model_quant = onnx.load(str(out_path))
|
|
88
|
+
onnx.checker.check_model(model_quant)
|
|
108
89
|
|
|
90
|
+
# Prepare inputs and compare outputs
|
|
109
91
|
inputs = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
|
|
110
92
|
outputs_true = converter.run_model_onnx_runtime(path, inputs)
|
|
93
|
+
outputs_quant = converter.run_model_onnx_runtime(out_path, inputs)
|
|
111
94
|
|
|
112
|
-
outputs_quant = converter.run_model_onnx_runtime("model.onnx", inputs)
|
|
113
95
|
true = torch.tensor(np.array(outputs_true), dtype=torch.float32)
|
|
114
96
|
quant = torch.tensor(np.array(outputs_quant), dtype=torch.float32) / (2**21)
|
|
115
97
|
|
|
@@ -113,9 +113,7 @@ class LayerTestConfig:
|
|
|
113
113
|
# respect that; otherwise use original valid_inputs.
|
|
114
114
|
inputs = test_spec.input_overrides or self.valid_inputs
|
|
115
115
|
|
|
116
|
-
# Prepare attributes
|
|
117
|
-
attrs = {**self.valid_attributes, **test_spec.attr_overrides}
|
|
118
|
-
# Remove omitted attributes if specified
|
|
116
|
+
# Prepare attributes and remove omitted attributes if specified
|
|
119
117
|
attrs = {**self.valid_attributes, **test_spec.attr_overrides}
|
|
120
118
|
for key in getattr(test_spec, "omit_attrs", []):
|
|
121
119
|
attrs.pop(key, None)
|