JSTprove 1.0.0__py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jstprove-1.0.0.dist-info/METADATA +397 -0
- jstprove-1.0.0.dist-info/RECORD +81 -0
- jstprove-1.0.0.dist-info/WHEEL +6 -0
- jstprove-1.0.0.dist-info/entry_points.txt +2 -0
- jstprove-1.0.0.dist-info/licenses/LICENSE +26 -0
- jstprove-1.0.0.dist-info/top_level.txt +1 -0
- python/__init__.py +0 -0
- python/core/__init__.py +3 -0
- python/core/binaries/__init__.py +0 -0
- python/core/binaries/expander-exec +0 -0
- python/core/binaries/onnx_generic_circuit_1-0-0 +0 -0
- python/core/circuit_models/__init__.py +0 -0
- python/core/circuit_models/generic_onnx.py +231 -0
- python/core/circuit_models/simple_circuit.py +133 -0
- python/core/circuits/__init__.py +0 -0
- python/core/circuits/base.py +1000 -0
- python/core/circuits/errors.py +188 -0
- python/core/circuits/zk_model_base.py +25 -0
- python/core/model_processing/__init__.py +0 -0
- python/core/model_processing/converters/__init__.py +0 -0
- python/core/model_processing/converters/base.py +143 -0
- python/core/model_processing/converters/onnx_converter.py +1181 -0
- python/core/model_processing/errors.py +147 -0
- python/core/model_processing/onnx_custom_ops/__init__.py +16 -0
- python/core/model_processing/onnx_custom_ops/conv.py +111 -0
- python/core/model_processing/onnx_custom_ops/custom_helpers.py +56 -0
- python/core/model_processing/onnx_custom_ops/gemm.py +91 -0
- python/core/model_processing/onnx_custom_ops/maxpool.py +79 -0
- python/core/model_processing/onnx_custom_ops/onnx_helpers.py +173 -0
- python/core/model_processing/onnx_custom_ops/relu.py +43 -0
- python/core/model_processing/onnx_quantizer/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/exceptions.py +168 -0
- python/core/model_processing/onnx_quantizer/layers/__init__.py +0 -0
- python/core/model_processing/onnx_quantizer/layers/base.py +396 -0
- python/core/model_processing/onnx_quantizer/layers/constant.py +118 -0
- python/core/model_processing/onnx_quantizer/layers/conv.py +180 -0
- python/core/model_processing/onnx_quantizer/layers/gemm.py +171 -0
- python/core/model_processing/onnx_quantizer/layers/maxpool.py +140 -0
- python/core/model_processing/onnx_quantizer/layers/relu.py +76 -0
- python/core/model_processing/onnx_quantizer/onnx_op_quantizer.py +200 -0
- python/core/model_templates/__init__.py +0 -0
- python/core/model_templates/circuit_template.py +57 -0
- python/core/utils/__init__.py +0 -0
- python/core/utils/benchmarking_helpers.py +163 -0
- python/core/utils/constants.py +4 -0
- python/core/utils/errors.py +117 -0
- python/core/utils/general_layer_functions.py +268 -0
- python/core/utils/helper_functions.py +1138 -0
- python/core/utils/model_registry.py +166 -0
- python/core/utils/scratch_tests.py +66 -0
- python/core/utils/witness_utils.py +291 -0
- python/frontend/__init__.py +0 -0
- python/frontend/cli.py +115 -0
- python/frontend/commands/__init__.py +17 -0
- python/frontend/commands/args.py +100 -0
- python/frontend/commands/base.py +199 -0
- python/frontend/commands/bench/__init__.py +54 -0
- python/frontend/commands/bench/list.py +42 -0
- python/frontend/commands/bench/model.py +172 -0
- python/frontend/commands/bench/sweep.py +212 -0
- python/frontend/commands/compile.py +58 -0
- python/frontend/commands/constants.py +5 -0
- python/frontend/commands/model_check.py +53 -0
- python/frontend/commands/prove.py +50 -0
- python/frontend/commands/verify.py +73 -0
- python/frontend/commands/witness.py +64 -0
- python/scripts/__init__.py +0 -0
- python/scripts/benchmark_runner.py +833 -0
- python/scripts/gen_and_bench.py +482 -0
- python/tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/__init__.py +0 -0
- python/tests/circuit_e2e_tests/circuit_model_developer_test.py +1158 -0
- python/tests/circuit_e2e_tests/helper_fns_for_tests.py +190 -0
- python/tests/circuit_e2e_tests/other_e2e_test.py +217 -0
- python/tests/circuit_parent_classes/__init__.py +0 -0
- python/tests/circuit_parent_classes/test_circuit.py +969 -0
- python/tests/circuit_parent_classes/test_onnx_converter.py +201 -0
- python/tests/circuit_parent_classes/test_ort_custom_layers.py +116 -0
- python/tests/test_cli.py +1021 -0
- python/tests/utils_testing/__init__.py +0 -0
- python/tests/utils_testing/test_helper_functions.py +891 -0
|
@@ -0,0 +1,190 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Generator, Mapping, Sequence, TypeAlias, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pytest
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from python.core.circuits.zk_model_base import ZKModelBase
|
|
12
|
+
from python.core.utils.helper_functions import CircuitExecutionConfig, RunType
|
|
13
|
+
|
|
14
|
+
# Marker strings for witness-generation results.
# NOTE(review): no consumer is visible in this module — presumably tests
# match these against captured CLI output; keep the wording byte-exact.
GOOD_OUTPUT = ["Witness Generated"]
BAD_OUTPUT = [
    "Witness generation failed",
    "Outputs generated do not match outputs supplied",
]

# Expected lengths of the parametrized tuples consumed by ``model_fixture``:
# (name, loader, args-or-kwargs) vs (name, loader, args, kwargs).
NUMPARAMS3 = 3
NUMPARAMS4 = 4
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture(scope="module")
def model_fixture(
    request: pytest.FixtureRequest,
    tmp_path_factory: pytest.TempPathFactory,
) -> dict[str, Any]:
    """Instantiate and compile one parametrized circuit model.

    ``request.param`` is read both by attribute (``.name``, ``.loader``)
    and by position (``param[2]``, ``param[3]``) — presumably a namedtuple
    of (name, loader, args-or-kwargs[, kwargs]); TODO confirm against the
    parametrization site.

    Returns:
        Dict with the model instance, its class, and the artifact paths
        produced during compilation.
    """
    param = request.param
    name = f"{param.name}"
    model_class = param.loader
    # Default: construct the model with no extra arguments.
    args, kwargs = (), {}

    # A third entry is either keyword args (dict) or positional args.
    if len(param) == NUMPARAMS3:
        if isinstance(param[2], dict):
            kwargs = param[2]
        else:
            args = param[2]
    # Four entries: args and kwargs supplied explicitly.
    elif len(param) == NUMPARAMS4:
        args, kwargs = param[2], param[3]

    # Module-scoped temp dir so all tests sharing this fixture reuse the
    # same compiled artifacts.
    temp_dir = tmp_path_factory.mktemp(name)
    circuit_path = temp_dir / f"{name}_circuit.txt"
    quantized_path = temp_dir / f"{name}_quantized.pt"

    model = model_class(*args, **kwargs)

    # Compile the circuit once up front. NOTE(review): circuit_path is
    # stringified while quantized_path is passed as a Path — confirm
    # CircuitExecutionConfig accepts both forms.
    model.base_testing(
        CircuitExecutionConfig(
            run_type=RunType.COMPILE_CIRCUIT,
            dev_mode=True,
            circuit_path=str(circuit_path),
            quantized_path=quantized_path,
        ),
    )

    return {
        "name": name,
        "model_class": model_class,
        "circuit_path": circuit_path,
        "temp_dir": temp_dir,
        "model": model,
        "quantized_model": quantized_path,
    }
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
@pytest.fixture()
def temp_witness_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a per-test path for a witness file; delete it on teardown.

    The file is never created here — the test decides whether to write it,
    so teardown must tolerate a missing file.
    """
    witness_path = tmp_path / "temp_witness.txt"
    # Give it to the test
    yield witness_path

    # After the test is done, remove it. ``missing_ok`` replaces the old
    # unbound ``Path.exists(witness_path)`` pre-check (unidiomatic, and a
    # check-then-unlink race). Also fixed the annotation: pytest's
    # ``tmp_path`` fixture supplies a ``Path``, not ``str``.
    witness_path.unlink(missing_ok=True)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@pytest.fixture()
def temp_input_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a per-test path for an input file; delete it on teardown.

    The file is never created here — the test decides whether to write it,
    so teardown must tolerate a missing file.
    """
    input_path = tmp_path / "temp_input.txt"
    # Give it to the test
    yield input_path

    # After the test is done, remove it. ``missing_ok`` replaces the old
    # unbound ``Path.exists(input_path)`` pre-check (unidiomatic, and a
    # check-then-unlink race). Annotation fixed: ``tmp_path`` is a ``Path``.
    input_path.unlink(missing_ok=True)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
@pytest.fixture()
def temp_output_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a per-test path for an output file; delete it on teardown.

    The file is never created here — the test decides whether to write it,
    so teardown must tolerate a missing file.
    """
    output_path = tmp_path / "temp_output.txt"
    # Give it to the test
    yield output_path

    # After the test is done, remove it. ``missing_ok`` replaces the old
    # unbound ``Path.exists(output_path)`` pre-check (unidiomatic, and a
    # check-then-unlink race). Annotation fixed: ``tmp_path`` is a ``Path``.
    output_path.unlink(missing_ok=True)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
@pytest.fixture()
def temp_proof_file(tmp_path: Path) -> Generator[Path, None, None]:
    """Yield a per-test path for a proof file; delete it on teardown.

    The file is never created here — the test decides whether to write it,
    so teardown must tolerate a missing file.
    """
    # Renamed local from the copy-pasted ``output_path`` for clarity.
    proof_path = tmp_path / "temp_proof.txt"
    # Give it to the test
    yield proof_path

    # After the test is done, remove it. ``missing_ok`` replaces the old
    # unbound ``Path.exists(...)`` pre-check (unidiomatic, and a
    # check-then-unlink race). Annotation fixed: ``tmp_path`` is a ``Path``.
    proof_path.unlink(missing_ok=True)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
# A single numeric value, or a torch tensor of them.
ScalarOrTensor: TypeAlias = Union[int, float, torch.Tensor]
# Arbitrarily nested lists/tuples/ndarrays of scalars or tensors.
# NOTE(review): ``tuple["NestedArray"]`` types a fixed-length 1-tuple;
# ``tuple["NestedArray", ...]`` was likely intended — confirm and fix.
NestedArray: TypeAlias = Union[
    ScalarOrTensor,
    list["NestedArray"],
    tuple["NestedArray"],
    np.ndarray,
]
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def add_1_to_first_element(x: NestedArray) -> NestedArray:
    """Return a copy of *x* whose first (flattened) element is one larger.

    Scalars come back incremented; tensors are cloned so the caller's data
    is untouched; lists/tuples/ndarrays come back as plain lists with the
    recursion applied to their first item. Anything else raises TypeError.
    """
    if isinstance(x, torch.Tensor):
        # Never mutate the caller's tensor in place.
        bumped = x.clone()
        bumped.view(-1)[0] += 1
        return bumped
    if isinstance(x, (int, float)):
        return x + 1
    if isinstance(x, (list, tuple, np.ndarray)):
        items = list(x)
        items[0] = add_1_to_first_element(items[0])
        return items
    msg = f"Unsupported type for get_outputs patch: {type(x)}"
    raise TypeError(msg)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
# Define models to be tested
# Module-level result caches, keyed by model instance, consulted by the
# ``check_model_compiles`` / ``check_witness_generated`` fixtures.
# NOTE(review): nothing in this chunk writes to them — presumably the
# first compile/witness test for each model records True/False here.
circuit_compile_results = {}
witness_generated_results = {}

# A float, or arbitrarily nested mappings/sequences of floats.
Nested: TypeAlias = Union[float, Mapping[str, "Nested"], Sequence["Nested"]]
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def contains_float(obj: Nested) -> bool:
    """Return True if *obj* is, or recursively contains, a Python float.

    Mappings are searched by value; lists and tuples element-wise. Any
    other leaf (int, str, ...) counts as False — note strings are not
    treated as sequences here.
    """
    if isinstance(obj, float):
        return True
    if isinstance(obj, dict):
        return any(contains_float(v) for v in obj.values())
    # Bug fix: the ``Nested`` alias admits any Sequence, but tuples were
    # previously never inspected and always reported False.
    if isinstance(obj, (list, tuple)):
        return any(contains_float(i) for i in obj)
    return False
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
@pytest.fixture(scope="module")
def check_model_compiles(model_fixture: dict[str, Any]) -> bool | None:
    """Skip the requesting test when circuit compilation already failed.

    Looks up the shared ``circuit_compile_results`` cache keyed by model
    instance: ``False`` triggers a skip, ``None`` means no result recorded
    yet. NOTE(review): nothing in this module writes to the cache —
    presumably the first compile test does; confirm. (Return annotation
    corrected from ``None``: the function returns the cached result.)
    """
    result = circuit_compile_results.get(model_fixture["model"])
    if result is False:
        pytest.skip(
            f"Skipping because the first test failed for: {model_fixture['model']}",
        )
    return result
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
@pytest.fixture(scope="module")
def check_witness_generated(model_fixture: dict[str, Any]) -> bool | None:
    """Skip the requesting test when witness generation already failed.

    Looks up the shared ``witness_generated_results`` cache keyed by model
    instance: ``False`` triggers a skip, ``None`` means no result recorded
    yet. NOTE(review): nothing in this module writes to the cache —
    presumably the first witness test does; confirm. (Return annotation
    corrected from ``None``: the function returns the cached result.)
    """
    result = witness_generated_results.get(model_fixture["model"])
    if result is False:
        pytest.skip(
            f"Skipping because the first test failed for: {model_fixture['model']}",
        )
    return result
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def assert_very_close(
    inputs_1: np.array,
    inputs_2: np.array,
    model: ZKModelBase,
) -> None:
    """Assert that two keyed collections of values match after de-scaling.

    Each value is divided by ``model.scale_base ** model.scale_exponent``
    before comparison with ``torch.isclose``.

    NOTE(review): the ``np.array`` annotations look wrong — the body
    iterates one argument and indexes both with the yielded key
    (``inputs_1[i]``, ``inputs_2[i]``), so these are presumably mappings
    from name to array-like; confirm with callers and fix the hints.
    NOTE(review): ``rtol=1e-8`` relies on torch's default ``atol`` for
    values near zero.
    """
    for i in inputs_1:
        x = torch.div(
            torch.as_tensor(inputs_1[i]),
            model.scale_base**model.scale_exponent,
        )
        y = torch.div(
            torch.as_tensor(inputs_2[i]),
            model.scale_base**model.scale_exponent,
        )

        assert torch.isclose(x, y, rtol=1e-8).all()
|
|
@@ -0,0 +1,217 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import subprocess
|
|
3
|
+
import sys
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Generator
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import onnx
|
|
9
|
+
import pytest
|
|
10
|
+
from onnx import helper, numpy_helper
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def create_simple_gemm_onnx_model(
    input_size: int,
    output_size: int,
    model_path: Path,
) -> None:
    """Create a simple ONNX model with a single GEMM layer.

    Builds ``output = input @ weight.T + bias`` with randomly initialized
    parameters and writes the serialized model to *model_path*.
    """
    # Graph I/O: one float input / one float output, batch size 1.
    inp = helper.make_tensor_value_info(
        "input",
        onnx.TensorProto.FLOAT,
        [1, input_size],
    )
    out = helper.make_tensor_value_info(
        "output",
        onnx.TensorProto.FLOAT,
        [1, output_size],
    )

    # Random initializers; keep the draw order (weight, then bias) so a
    # seeded generator would reproduce the original values.
    rng = np.random.default_rng()
    w_init = numpy_helper.from_array(
        rng.standard_normal((output_size, input_size)).astype(np.float32),
        name="weight",
    )
    b_init = numpy_helper.from_array(
        rng.standard_normal((output_size,)).astype(np.float32),
        name="bias",
    )

    # transB=1 transposes the stored (out, in) weight at multiply time.
    node = helper.make_node(
        "Gemm",
        inputs=["input", "weight", "bias"],
        outputs=["output"],
        alpha=1.0,
        beta=1.0,
        transB=1,
    )

    graph = helper.make_graph(
        [node],
        "simple_gemm",
        [inp],
        [out],
        [w_init, b_init],
    )

    # Assemble and persist the model.
    model = helper.make_model(graph, producer_name="simple_gemm_creator")
    onnx.save(model, str(model_path))
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@pytest.mark.e2e()
def test_parallel_compile_and_witness_two_simple_models(  # noqa: PLR0915
    tmp_path: str,
    capsys: Generator[pytest.CaptureFixture[str], None, None],
) -> None:
    """Test compiling and running witness
    for two different simple ONNX models in parallel.

    Builds two single-GEMM models of different shapes, compiles each via
    the CLI, then runs the two witness commands as concurrent subprocesses
    and checks the produced artifacts.
    """
    # Create two simple ONNX models with different shapes
    model1_path = Path(tmp_path) / "simple_gemm1.onnx"
    model2_path = Path(tmp_path) / "simple_gemm2.onnx"
    model1_input_size = 4
    model1_output_size = 2

    model2_input_size = 6
    model2_output_size = 3

    create_simple_gemm_onnx_model(model1_input_size, model1_output_size, model1_path)
    create_simple_gemm_onnx_model(model2_input_size, model2_output_size, model2_path)

    # Define paths for artifacts
    circuit1_path = Path(tmp_path) / "circuit1.txt"
    circuit2_path = Path(tmp_path) / "circuit2.txt"

    # Compile both models (list-form argv, shell=False)
    compile_cmd1 = [
        sys.executable,
        "-m",
        "python.frontend.cli",
        "compile",
        "-m",
        str(model1_path),
        "-c",
        str(circuit1_path),
    ]
    compile_cmd2 = [
        sys.executable,
        "-m",
        "python.frontend.cli",
        "compile",
        "-m",
        str(model2_path),
        "-c",
        str(circuit2_path),
    ]

    # Run compile commands sequentially; only the witness step is parallel.
    result1 = subprocess.run(
        compile_cmd1,  # noqa: S603
        capture_output=True,
        text=True,
        check=False,
    )
    assert result1.returncode == 0, f"Compile failed for model1: {result1.stderr}"

    result2 = subprocess.run(
        compile_cmd2,  # noqa: S603
        capture_output=True,
        text=True,
        check=False,
    )
    assert result2.returncode == 0, f"Compile failed for model2: {result2.stderr}"

    # Create input files (comments previously claimed 10/20 inputs — the
    # actual sizes are 4 and 6).
    input1_data = {"input": [1.0] * model1_input_size}  # 4 inputs
    input2_data = {"input": [1.0] * model2_input_size}  # 6 inputs

    input1_path = Path(tmp_path) / "input1.json"
    input2_path = Path(tmp_path) / "input2.json"

    # Idiom fix: call ``open`` on the Path instance, not unbound Path.open.
    with input1_path.open("w") as f:
        json.dump(input1_data, f)
    with input2_path.open("w") as f:
        json.dump(input2_data, f)

    # Define output and witness paths
    output1_path = Path(tmp_path) / "output1.json"
    witness1_path = Path(tmp_path) / "witness1.bin"
    output2_path = Path(tmp_path) / "output2.json"
    witness2_path = Path(tmp_path) / "witness2.bin"

    # Witness commands, run in parallel below
    witness_cmd1 = [
        sys.executable,
        "-m",
        "python.frontend.cli",
        "witness",
        "-c",
        str(circuit1_path),
        "-i",
        str(input1_path),
        "-o",
        str(output1_path),
        "-w",
        str(witness1_path),
    ]
    witness_cmd2 = [
        sys.executable,
        "-m",
        "python.frontend.cli",
        "witness",
        "-c",
        str(circuit2_path),
        "-i",
        str(input2_path),
        "-o",
        str(output2_path),
        "-w",
        str(witness2_path),
    ]

    # Start both processes — this concurrency is what the test exercises.
    proc1 = subprocess.Popen(witness_cmd1)  # noqa: S603
    proc2 = subprocess.Popen(witness_cmd2)  # noqa: S603

    # Wait for both to complete
    proc1.wait()
    proc2.wait()

    # Check return codes
    assert proc1.returncode == 0, "Witness failed for model1"
    assert proc2.returncode == 0, "Witness failed for model2"

    # Verify outputs exist
    assert output1_path.exists(), "Output1 file not generated"
    assert output2_path.exists(), "Output2 file not generated"
    assert witness1_path.exists(), "Witness1 file not generated"
    assert witness2_path.exists(), "Witness2 file not generated"

    # Check output contents (should have the correct shapes)
    with output1_path.open() as f:
        output1 = json.load(f)
    with output2_path.open() as f:
        output2 = json.load(f)

    # Model1: input 4 -> output 2. Bug fix: the ``f" got ..."`` halves of
    # these messages were standalone dead statements after the assert, so
    # failure output was silently truncated — fold them into the message.
    assert "output" in output1, "Output1 missing 'output' key"
    assert len(output1["output"]) == model1_output_size, (
        f"Output1 should have {model1_output_size} elements,"
        f" got {len(output1['output'])}"
    )

    # Model2: input 6 -> output 3
    assert "output" in output2, "Output2 missing 'output' key"
    assert len(output2["output"]) == model2_output_size, (
        f"Output2 should have {model2_output_size} elements,"
        f" got {len(output2['output'])}"
    )
|
|
File without changes
|