onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +18 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/ext_test_case.py +3 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/doc_helper.py +143 -0
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/model_builder_helper.py +24 -8
- onnx_diagnostic/helpers/rt_helper.py +5 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +648 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +133 -16
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/RECORD +39 -23
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.2.dist-info}/top_level.txt +0 -0
|
@@ -4,6 +4,7 @@ import os
|
|
|
4
4
|
import sys
|
|
5
5
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
6
6
|
import time
|
|
7
|
+
import numpy as np
|
|
7
8
|
import onnx
|
|
8
9
|
import onnxscript
|
|
9
10
|
import onnxscript.rewriter.ort_fusions as ort_fusions
|
|
@@ -17,6 +18,7 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
|
17
18
|
from ..tasks import random_input_kwargs
|
|
18
19
|
from ..torch_export_patches import torch_export_patches
|
|
19
20
|
from ..torch_export_patches.patch_inputs import use_dyn_not_str
|
|
21
|
+
from ..reference import TorchOnnxEvaluator
|
|
20
22
|
from .hghub import get_untrained_model_with_inputs
|
|
21
23
|
|
|
22
24
|
|
|
@@ -192,11 +194,16 @@ def _quiet_or_not_quiet(
|
|
|
192
194
|
summary: Dict[str, Any],
|
|
193
195
|
data: Optional[Dict[str, Any]],
|
|
194
196
|
fct: Callable,
|
|
197
|
+
repeat: int = 1,
|
|
198
|
+
warmup: int = 0,
|
|
195
199
|
) -> Any:
|
|
196
200
|
begin = time.perf_counter()
|
|
197
201
|
if quiet:
|
|
198
202
|
try:
|
|
199
|
-
|
|
203
|
+
res = fct()
|
|
204
|
+
summary[f"time_{suffix}"] = time.perf_counter() - begin
|
|
205
|
+
if warmup + repeat == 1:
|
|
206
|
+
return res
|
|
200
207
|
except Exception as e:
|
|
201
208
|
summary[f"ERR_{suffix}"] = str(e)
|
|
202
209
|
summary[f"time_{suffix}"] = time.perf_counter() - begin
|
|
@@ -204,11 +211,45 @@ def _quiet_or_not_quiet(
|
|
|
204
211
|
return {f"ERR_{suffix}": e}
|
|
205
212
|
data[f"ERR_{suffix}"] = e
|
|
206
213
|
return None
|
|
207
|
-
|
|
214
|
+
else:
|
|
215
|
+
res = fct()
|
|
208
216
|
summary[f"time_{suffix}"] = time.perf_counter() - begin
|
|
217
|
+
if warmup + repeat > 1:
|
|
218
|
+
if suffix == "run":
|
|
219
|
+
res = torch_deepcopy(res)
|
|
220
|
+
summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
|
|
221
|
+
summary[f"{suffix}_warmup"] = warmup
|
|
222
|
+
summary[f"{suffix}_repeat"] = repeat
|
|
223
|
+
for _w in range(max(0, warmup - 1)):
|
|
224
|
+
t = fct()
|
|
225
|
+
summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
|
|
226
|
+
summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
|
|
227
|
+
times = []
|
|
228
|
+
for _r in range(repeat):
|
|
229
|
+
begin = time.perf_counter()
|
|
230
|
+
t = fct()
|
|
231
|
+
times.append(time.perf_counter() - begin)
|
|
232
|
+
a = np.array(times)
|
|
233
|
+
summary[f"time_{suffix}_latency"] = a.mean()
|
|
234
|
+
summary[f"time_{suffix}_latency_std"] = a.std()
|
|
235
|
+
summary[f"time_{suffix}_latency_min"] = a.min()
|
|
236
|
+
summary[f"time_{suffix}_latency_min"] = a.max()
|
|
209
237
|
return res
|
|
210
238
|
|
|
211
239
|
|
|
240
|
+
def shrink_config(cfg: Dict[str, Any], max_length: int = 50) -> Dict[str, Any]:
    """
    Shrinks the configuration before it gets added to the information to log.

    Every container value (list, tuple, set, dict) holding ``max_length``
    elements or more is replaced by a short placeholder: ``["..."]`` for a
    list, ``("...",)`` for a tuple and the string ``"..."`` otherwise.
    Any other value, including a long string, is kept unchanged.
    Nested containers are not inspected.

    :param cfg: configuration as a dictionary (such as the output of ``to_dict()``)
    :param max_length: containers with at least this number of elements are replaced
    :return: a new dictionary, *cfg* is not modified
    """

    def _shrink(value: Any) -> Any:
        if not isinstance(value, (list, tuple, set, dict)) or len(value) < max_length:
            return value
        # Build the placeholder explicitly: ``list("...")`` would split the string
        # into three '.' items and a namedtuple cannot be built from one argument.
        if isinstance(value, list):
            return ["..."]
        if isinstance(value, tuple):
            return ("...",)
        return "..."

    return {k: _shrink(v) for k, v in cfg.items()}
|
|
251
|
+
|
|
252
|
+
|
|
212
253
|
def validate_model(
|
|
213
254
|
model_id: str,
|
|
214
255
|
task: Optional[str] = None,
|
|
@@ -231,9 +272,14 @@ def validate_model(
|
|
|
231
272
|
model_options: Optional[Dict[str, Any]] = None,
|
|
232
273
|
subfolder: Optional[str] = None,
|
|
233
274
|
opset: Optional[int] = None,
|
|
275
|
+
runtime: str = "onnxruntime",
|
|
276
|
+
repeat: int = 1,
|
|
277
|
+
warmup: int = 0,
|
|
234
278
|
) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
|
|
235
279
|
"""
|
|
236
280
|
Validates a model.
|
|
281
|
+
The function can also be called through the command line
|
|
282
|
+
:ref:`l-cmd-validate`.
|
|
237
283
|
|
|
238
284
|
:param model_id: model id to validate
|
|
239
285
|
:param task: task used to generate the necessary inputs,
|
|
@@ -241,7 +287,8 @@ def validate_model(
|
|
|
241
287
|
if it can be determined
|
|
242
288
|
:param do_run: checks the model works with the defined inputs
|
|
243
289
|
:param exporter: exporter the model using this exporter,
|
|
244
|
-
available list: ``export-strict``, ``export-nostrict``,
|
|
290
|
+
available list: ``export-strict``, ``export-nostrict``, ...
|
|
291
|
+
see below
|
|
245
292
|
:param do_same: checks the discrepancies of the exported model
|
|
246
293
|
:param verbose: verbosity level
|
|
247
294
|
:param dtype: uses this dtype to check the model
|
|
@@ -267,6 +314,10 @@ def validate_model(
|
|
|
267
314
|
``num_hidden_layers`` or ``attn_implementation``
|
|
268
315
|
:param subfolder: version or subfolders to uses when retrieving a model id
|
|
269
316
|
:param opset: onnx opset to use for the conversion
|
|
317
|
+
:param runtime: onnx runtime to use to check about discrepancies,
|
|
318
|
+
only if `do_run` is true
|
|
319
|
+
:param repeat: number of time to measure the model
|
|
320
|
+
:param warmup: warmup the model first
|
|
270
321
|
:return: two dictionaries, one with some metrics,
|
|
271
322
|
another one with whatever the function produces
|
|
272
323
|
|
|
@@ -274,6 +325,20 @@ def validate_model(
|
|
|
274
325
|
information:
|
|
275
326
|
|
|
276
327
|
* ``PRINT_CONFIG``: prints the model configuration
|
|
328
|
+
|
|
329
|
+
The following exporters are available:
|
|
330
|
+
|
|
331
|
+
* ``export-nostrict``: run :func:`torch.export.export` (..., strict=False)
|
|
332
|
+
* ``onnx-dynamo``: run :func:`torch.onnx.export` (..., dynamo=True),
|
|
333
|
+
models can be optimized with ``optimization`` in ``("ir", "os_ort")``
|
|
334
|
+
* ``modelbuilder``: use :epkg:`ModelBuilder` to builds the onnx model
|
|
335
|
+
* ``custom``: custom exporter (see :epkg:`experimental-experiment`),
|
|
336
|
+
models can be optimized with ``optimization`` in
|
|
337
|
+
``("default", "default+onnxruntime", "default+os_ort", "default+onnxruntime+os_ort")``
|
|
338
|
+
|
|
339
|
+
The default runtime, :epkg:`onnxruntime` is used to validate a model and check the
|
|
340
|
+
exported model returns the same outputs as the original one, otherwise,
|
|
341
|
+
:class:`onnx_diagnostic.reference.TorchOnnxEvaluator` is used.
|
|
277
342
|
"""
|
|
278
343
|
assert (
|
|
279
344
|
not rewrite or patch
|
|
@@ -295,6 +360,7 @@ def validate_model(
|
|
|
295
360
|
version_ortfusiontype=ortfusiontype or "",
|
|
296
361
|
version_stop_if_static=str(stop_if_static),
|
|
297
362
|
version_exporter=exporter or "",
|
|
363
|
+
version_runtime=runtime,
|
|
298
364
|
)
|
|
299
365
|
)
|
|
300
366
|
if opset:
|
|
@@ -436,7 +502,9 @@ def validate_model(
|
|
|
436
502
|
if summary["model_module"] in sys.modules:
|
|
437
503
|
summary["model_file"] = str(sys.modules[summary["model_module"]].__file__) # type: ignore[index]
|
|
438
504
|
summary["model_config_class"] = data["configuration"].__class__.__name__
|
|
439
|
-
summary["model_config"] = str(data["configuration"].to_dict()).replace(
|
|
505
|
+
summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace(
|
|
506
|
+
" ", ""
|
|
507
|
+
)
|
|
440
508
|
summary["model_id"] = model_id
|
|
441
509
|
|
|
442
510
|
if verbose:
|
|
@@ -460,7 +528,13 @@ def validate_model(
|
|
|
460
528
|
model = data["model"]
|
|
461
529
|
|
|
462
530
|
expected = _quiet_or_not_quiet(
|
|
463
|
-
quiet,
|
|
531
|
+
quiet,
|
|
532
|
+
"run",
|
|
533
|
+
summary,
|
|
534
|
+
data,
|
|
535
|
+
(lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
|
|
536
|
+
repeat=repeat,
|
|
537
|
+
warmup=warmup,
|
|
464
538
|
)
|
|
465
539
|
if "ERR_run" in summary:
|
|
466
540
|
return summary, data
|
|
@@ -522,7 +596,7 @@ def validate_model(
|
|
|
522
596
|
|
|
523
597
|
disc = max_diff(data["expected"], expected)
|
|
524
598
|
for k, v in disc.items():
|
|
525
|
-
summary[f"disc_patched_{k}"] = v
|
|
599
|
+
summary[f"disc_patched_{k}"] = str(v)
|
|
526
600
|
if verbose:
|
|
527
601
|
print("[validate_model] done (patched run)")
|
|
528
602
|
print(f"[validate_model] patched discrepancies={string_diff(disc)}")
|
|
@@ -618,7 +692,14 @@ def validate_model(
|
|
|
618
692
|
return summary, data
|
|
619
693
|
|
|
620
694
|
if do_run:
|
|
621
|
-
summary_valid, data = validate_onnx_model(
|
|
695
|
+
summary_valid, data = validate_onnx_model(
|
|
696
|
+
data=data,
|
|
697
|
+
quiet=quiet,
|
|
698
|
+
verbose=verbose,
|
|
699
|
+
runtime=runtime,
|
|
700
|
+
repeat=repeat,
|
|
701
|
+
warmup=warmup,
|
|
702
|
+
)
|
|
622
703
|
summary.update(summary_valid)
|
|
623
704
|
|
|
624
705
|
if ortfusiontype and "onnx_filename" in data:
|
|
@@ -671,7 +752,13 @@ def validate_model(
|
|
|
671
752
|
|
|
672
753
|
if do_run:
|
|
673
754
|
summary_valid, data = validate_onnx_model(
|
|
674
|
-
data=data,
|
|
755
|
+
data=data,
|
|
756
|
+
quiet=quiet,
|
|
757
|
+
verbose=verbose,
|
|
758
|
+
flavour=flavour,
|
|
759
|
+
runtime=runtime,
|
|
760
|
+
repeat=repeat,
|
|
761
|
+
warmup=warmup,
|
|
675
762
|
)
|
|
676
763
|
summary.update(summary_valid)
|
|
677
764
|
|
|
@@ -883,6 +970,9 @@ def validate_onnx_model(
|
|
|
883
970
|
quiet: bool = False,
|
|
884
971
|
verbose: int = 0,
|
|
885
972
|
flavour: Optional[str] = None,
|
|
973
|
+
runtime: str = "onnxruntime",
|
|
974
|
+
repeat: int = 1,
|
|
975
|
+
warmup: int = 0,
|
|
886
976
|
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
|
887
977
|
"""
|
|
888
978
|
Verifies that an onnx model produces the same
|
|
@@ -895,6 +985,9 @@ def validate_onnx_model(
|
|
|
895
985
|
:param quiet: catch exception or not
|
|
896
986
|
:param verbose: verbosity
|
|
897
987
|
:param flavour: use a different version of the inputs
|
|
988
|
+
:param runtime: onnx runtime to use, onnxruntime or torch
|
|
989
|
+
:param repeat: run that number of times the model
|
|
990
|
+
:param warmup: warmup the model
|
|
898
991
|
:return: two dictionaries, one with some metrics,
|
|
899
992
|
another one with whatever the function produces
|
|
900
993
|
"""
|
|
@@ -936,18 +1029,28 @@ def validate_onnx_model(
|
|
|
936
1029
|
f"{providers}..., flavour={flavour!r}"
|
|
937
1030
|
)
|
|
938
1031
|
|
|
1032
|
+
cls_runtime = (
|
|
1033
|
+
(
|
|
1034
|
+
lambda model, providers: onnxruntime.InferenceSession(
|
|
1035
|
+
(model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
|
|
1036
|
+
providers=providers,
|
|
1037
|
+
)
|
|
1038
|
+
)
|
|
1039
|
+
if runtime == "onnxruntime"
|
|
1040
|
+
else (
|
|
1041
|
+
lambda model, providers: TorchOnnxEvaluator(
|
|
1042
|
+
model, providers=providers, verbose=max(verbose - 1, 0)
|
|
1043
|
+
)
|
|
1044
|
+
)
|
|
1045
|
+
)
|
|
939
1046
|
sess = _quiet_or_not_quiet(
|
|
940
1047
|
quiet,
|
|
941
|
-
_mk("
|
|
1048
|
+
_mk("onnx_ort_create"),
|
|
942
1049
|
summary,
|
|
943
1050
|
data,
|
|
944
|
-
(
|
|
945
|
-
lambda source=source, providers=providers: onnxruntime.InferenceSession(
|
|
946
|
-
source, providers=providers
|
|
947
|
-
)
|
|
948
|
-
),
|
|
1051
|
+
(lambda source=source, providers=providers: cls_runtime(source, providers)),
|
|
949
1052
|
)
|
|
950
|
-
if f"ERR_{_mk('
|
|
1053
|
+
if f"ERR_{_mk('onnx_ort_create')}" in summary:
|
|
951
1054
|
return summary, data
|
|
952
1055
|
|
|
953
1056
|
data[_mk("onnx_ort_sess")] = sess
|
|
@@ -975,6 +1078,8 @@ def validate_onnx_model(
|
|
|
975
1078
|
summary,
|
|
976
1079
|
data,
|
|
977
1080
|
(lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
|
|
1081
|
+
repeat=repeat,
|
|
1082
|
+
warmup=warmup,
|
|
978
1083
|
)
|
|
979
1084
|
if f"ERR_{_mk('time_onnx_ort_run')}" in summary:
|
|
980
1085
|
return summary, data
|
|
@@ -1051,7 +1156,7 @@ def call_torch_export_onnx(
|
|
|
1051
1156
|
dynamo=False,
|
|
1052
1157
|
dynamic_axes={
|
|
1053
1158
|
k: v
|
|
1054
|
-
for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)
|
|
1159
|
+
for k, v in CoupleInputsDynamicShapes(args, kwargs, ds) # type: ignore[arg-type]
|
|
1055
1160
|
.replace_by_string()
|
|
1056
1161
|
.items()
|
|
1057
1162
|
if isinstance(v, dict)
|
|
@@ -1229,6 +1334,13 @@ def call_torch_export_custom(
|
|
|
1229
1334
|
"custom-nostrict",
|
|
1230
1335
|
"custom-nostrict-default",
|
|
1231
1336
|
"custom-nostrict-all",
|
|
1337
|
+
"custom-inline",
|
|
1338
|
+
"custom-strict-inline",
|
|
1339
|
+
"custom-strict-default-inline",
|
|
1340
|
+
"custom-strict-all-inline",
|
|
1341
|
+
"custom-nostrict-inline",
|
|
1342
|
+
"custom-nostrict-default-inline",
|
|
1343
|
+
"custom-nostrict-all-inline",
|
|
1232
1344
|
}
|
|
1233
1345
|
assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
|
|
1234
1346
|
assert "model" in data, f"model is missing from data: {sorted(data)}"
|
|
@@ -1269,6 +1381,10 @@ def call_torch_export_custom(
|
|
|
1269
1381
|
),
|
|
1270
1382
|
save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
|
|
1271
1383
|
)
|
|
1384
|
+
inline = "-inline" in exporter
|
|
1385
|
+
if inline:
|
|
1386
|
+
export_options.aten_as_function = set()
|
|
1387
|
+
|
|
1272
1388
|
options = OptimizationOptions(patterns=optimization) if optimization else None
|
|
1273
1389
|
model = data["model"]
|
|
1274
1390
|
kws = dict(
|
|
@@ -1279,6 +1395,7 @@ def call_torch_export_custom(
|
|
|
1279
1395
|
large_model=True,
|
|
1280
1396
|
return_optimize_report=True,
|
|
1281
1397
|
verbose=max(verbose - 2, 0),
|
|
1398
|
+
inline=inline,
|
|
1282
1399
|
)
|
|
1283
1400
|
if opset:
|
|
1284
1401
|
kws["target_opset"] = opset
|
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
from typing import Any, Dict, List, Optional, Set, Tuple, Union
|
|
3
|
+
import onnx
|
|
4
|
+
import torch
|
|
5
|
+
from ..api import TensorLike
|
|
6
|
+
from ..helpers import string_type
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class RuntimeValueKind(enum.IntEnum):
    """
    Kind of result: produced by a node (RESULT), an initializer (INITIALIZER),
    a graph input (INPUT) or a graph output (OUTPUT).
    """

    RESULT = 1
    INITIALIZER = 3
    INPUT = 5
    OUTPUT = 9

    def to_str(self) -> str:
        """
        Returns the member name (``"RESULT"``, ``"INPUT"``, ...).

        The name is read from the member itself. Scanning the class ``__dict__``
        for an entry equal to the value can match a non-member integer attribute
        first, such as ``__firstlineno__``, which Python 3.13 stores in every
        class namespace.
        """
        return self.name
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class RuntimeDevice(enum.IntEnum):
    "Device definition"

    UNKNOWN = 0
    NEW = 1
    CPU = 2
    CUDA = 4

    def to_str(self) -> str:
        """
        Returns the member name (``"UNKNOWN"``, ``"CPU"``, ...).

        The name is read from the member itself. Scanning the class ``__dict__``
        for an entry equal to the value can match a non-member integer attribute
        first, such as ``__firstlineno__``, which Python 3.13 stores in every
        class namespace.
        """
        return self.name
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RuntimeValue:
    """
    Describes a value used during the execution of a model.

    It stores the value name, its kind (see ``RuntimeValueKind``), the index
    of the node creating it (``created``, -1 for inputs and initializers),
    the indices of the first and last nodes using it, whether it holds
    a shape (``is_shape``) and, once known, its dtype, shape, device and value.
    """

    def __init__(
        self,
        name: str,
        dtype: Optional[Any] = None,
        shape: Optional[Tuple[Union[str, int], ...]] = None,
        value: Optional[Any] = None,
        first_used: Optional[int] = None,
        last_used: Optional[int] = None,
        created: Optional[int] = None,
        is_shape: Optional[bool] = None,
        kind: Optional[RuntimeValueKind] = None,
        device: Optional[RuntimeDevice] = None,
    ):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.value = value
        self.first_used = first_used
        self.last_used = last_used
        self.created = created
        self.is_shape = is_shape
        self.kind = kind
        self.device = device

    @staticmethod
    def _to_text(obj: Any) -> str:
        "Displays an enum through its ``to_str`` method, anything else with ``str``."
        return obj.to_str() if hasattr(obj, "to_str") else str(obj)

    @staticmethod
    def _describe_value(value: Any) -> str:
        """
        Describes a stored value with its own ``string_type`` method if it has one.
        Otherwise (e.g. a plain ``torch.Tensor``) it falls back to
        the helper ``string_type``.
        """
        if hasattr(value, "string_type"):
            return value.string_type()
        return string_type(value, with_shape=True)

    def __repr__(self) -> str:
        "usual"
        ad = {}
        for att in [
            "name",
            "dtype",
            "shape",
            "first_used",
            "last_used",
            "is_shape",
            "kind",
            "created",
            "device",
        ]:
            v = getattr(self, att)
            if v is not None:
                ad[att] = v
        if self.value is not None:
            ad["value"] = self._describe_value(self.value)
        msg = ", ".join(f"{name}={self._to_text(t)}" for name, t in ad.items())
        return f"{self.__class__.__name__}({msg})"

    @property
    def has_value(self) -> bool:
        "Tells if value is specified."
        return self.value is not None

    def string_type(self) -> str:
        """
        Returns a string describing the value.

        Enums (kind, device) are displayed with their names, as ``__repr__`` does.
        The stored value may be any object ``__repr__`` accepts, including
        a ``torch.Tensor`` without a ``string_type`` method.
        """
        rows = []
        if self.shape is not None:
            rows.append(f"shape={self.shape}")
        if self.is_shape is not None:
            rows.append(f"is_shape={self.is_shape}")
        if self.device is not None:
            rows.append(f"device={self._to_text(self.device)}")
        text = f", {', '.join(rows)}" if rows else ""
        kind = self._to_text(self.kind)
        if self.value is None:
            return (
                f"RuntimeValue(name={self.name!r}{text}"
                f", dtype={self.dtype}, kind={kind})"
            )
        return (
            f"RuntimeValue(name={self.name!r}, "
            f"kind={kind}{text}, value={self._describe_value(self.value)})"
        )

    def set_value(self, value: Union[torch.Tensor, TensorLike]):
        """
        Sets the value and records its dtype and shape.

        The first stored value fixes ``dtype``; any later value must have the
        same one. ``shape`` is refreshed on every call and set to None when
        the value reports itself as a sequence (``value.is_sequence()``).

        :param value: tensor or tensor-like object, must not be None
            (use ``clean_value`` to drop the value)
        """
        assert value is not None, "Use clean_value to set a value to None"
        self.value = value
        is_sequence = hasattr(value, "is_sequence") and value.is_sequence()
        # ``is not None`` rather than truthiness: a numpy dtype has len() == 0
        # and evaluates to False, which would silently skip the check.
        if self.dtype is not None:
            assert self.dtype == value.dtype, (
                f"Unexpected dtype={value.dtype}, previous dtype was {self.dtype}, "
                f"is_sequence={is_sequence}"
            )
        else:
            self.dtype = value.dtype
        self.shape = None if is_sequence else tuple(map(int, value.shape))

    def clean_value(self):
        """Sets value to None."""
        self.value = None

    @property
    def is_output(self) -> bool:
        "Tells if it is an output."
        return self.kind == RuntimeValueKind.OUTPUT

    @property
    def is_input(self) -> bool:
        "Tells if it is an input."
        return self.kind == RuntimeValueKind.INPUT

    @property
    def is_initializer(self) -> bool:
        "Tells if it is an initializer."
        return self.kind == RuntimeValueKind.INITIALIZER
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
    """
    Returns the hidden inputs (inputs coming from an upper context)
    used by a subgraph, including those used by its own nested subgraphs.

    A name is hidden when a node consumes it but it is not one of the
    following:

    * a graph input,
    * an initializer (dense or sparse),
    * an output of a previous node.

    Empty names denote omitted optional inputs in ONNX and are never reported.
    Nodes are visited in order, so a name produced by a later node is
    reported as hidden (the graph is expected to be topologically sorted).

    :param graph: subgraph to inspect
    :return: set of names the subgraph expects to find in its context
    """
    hidden: Set[str] = set()
    memo = (
        {i.name for i in graph.initializer}
        # A SparseTensorProto has no ``name`` field; its name is ``values.name``.
        | {i.values.name for i in graph.sparse_initializer}
        | {i.name for i in graph.input}
    )
    for node in graph.node:
        hidden |= {i for i in node.input if i and i not in memo}
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                subgraphs = [att.g]
            elif att.type == onnx.AttributeProto.GRAPHS:
                subgraphs = list(att.graphs)
            else:
                continue
            for sub in subgraphs:
                hidden |= {h for h in get_hidden_inputs(sub) if h not in memo}
        memo |= set(node.output)
    return hidden
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def set_is_shape(
    node: onnx.NodeProto, values: Dict[str, RuntimeValue], drop: Optional[Set[str]] = None
) -> List[str]:
    """
    Sets attribute ``is_shape`` for outputs of a node.

    Rules:

    * A node without input modifies nothing.
    * ``Shape`` and ``Size`` (default domain) produce a shape.
    * If at least one considered input is a shape, the single output of the
      node is a shape when the first considered input is one. Otherwise
      nothing is modified.
    * If no considered input is a shape, no output is a shape.

    :param node: node to process
    :param values: stored results, values in this dictionary are updated
    :param drop: variables not to consider because they come from the graph
        holding this subgraph
    :return: list of modified results
    """
    if not node.input:
        # Constant
        return []
    drop = drop or set()
    if node.op_type in ("Shape", "Size") and node.domain == "":
        values[node.output[0]].is_shape = True
        return [node.output[0]]
    # An empty name is an omitted optional input: there is nothing to look up.
    is_shapes = [values[i].is_shape for i in node.input if i and i not in drop]
    if any(is_shapes):
        if is_shapes[0] and len(node.output) == 1:
            values[node.output[0]].is_shape = True
            return [node.output[0]]
        return []
    # Omitted optional outputs (empty names) are skipped as well.
    modified = [o for o in node.output if o]
    for o in modified:
        values[o].is_shape = False
    return modified
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def first_used_last_used(
    proto: Union[onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
    constant_as_initializer: bool = False,
) -> Dict[str, RuntimeValue]:
    """
    Builds first used, last used information for every result
    in the model.

    Node indices are used as time steps:

    * Inputs and initializers are created at -1.
    * Every node output is created at the index of its node.
    * Graph outputs are last used at ``len(nodes)``.
    * A name read by a subgraph counts as used by the node holding
      that subgraph.
    * Empty names (omitted optional inputs) are ignored.
    * A GraphProto may read names from its enclosing graph; those are
      tolerated and skipped. A ModelProto or FunctionProto must not have
      unknown inputs.

    :param proto: model, graph or function
    :param constant_as_initializer: outputs of node Constant is tagged as INITIALIZER
    :return: dictionary of RuntimeValue
    """
    values: Dict[str, RuntimeValue] = {}
    if isinstance(proto, (onnx.ModelProto, onnx.GraphProto)):
        graph = proto.graph if isinstance(proto, onnx.ModelProto) else proto
        initializer = graph.initializer
        sparse_initializer = graph.sparse_initializer
        _input, output, _node = graph.input, graph.output, graph.node
        # Only a (sub)graph may consume names defined in an enclosing context.
        allow_unknown = isinstance(proto, onnx.GraphProto)
    else:
        # FunctionProto: no initializer, inputs and outputs are plain strings.
        initializer, sparse_initializer = [], []
        _input, output, _node = proto.input, proto.output, proto.node
        allow_unknown = False

    for init in initializer:
        values[init.name] = RuntimeValue(
            init.name, kind=RuntimeValueKind.INITIALIZER, created=-1
        )
    for init in sparse_initializer:
        # A SparseTensorProto has no ``name`` field; its name is ``values.name``.
        name = init.values.name
        values[name] = RuntimeValue(name, created=-1, kind=RuntimeValueKind.INITIALIZER)
    for inp in _input:
        n = inp if isinstance(inp, str) else inp.name
        values[n] = RuntimeValue(n, created=-1, kind=RuntimeValueKind.INPUT)

    # Names set_is_shape must ignore: "" (omitted optional input) and
    # the names coming from an enclosing context.
    drop: Set[str] = {""}

    def _mark_used(name: str, it: int) -> None:
        # Records that node ``it`` reads ``name``.
        if not name:
            return
        if name not in values:
            assert allow_unknown, f"Input {name!r} is unknown."
            # This input comes from a context and the model is a GraphProto
            drop.add(name)
            return
        if values[name].first_used is None:
            values[name].first_used = it
        values[name].last_used = it

    for it, node in enumerate(_node):
        for i in node.input:
            _mark_used(i, it)
        for att in node.attribute:
            if att.type == onnx.AttributeProto.GRAPH:
                subgraphs = [att.g]
            elif att.type == onnx.AttributeProto.GRAPHS:
                subgraphs = list(att.graphs)
            else:
                continue
            for sub in subgraphs:
                for n in get_hidden_inputs(sub):
                    _mark_used(n, it)
        is_constant = node.op_type == "Constant" and node.domain == ""
        kind = (
            RuntimeValueKind.INITIALIZER
            if is_constant and constant_as_initializer
            else RuntimeValueKind.RESULT
        )
        for o in node.output:
            values[o] = RuntimeValue(o, created=it, kind=kind)
        set_is_shape(node, values, drop=drop)

    for out in output:
        n = out if isinstance(out, str) else out.name
        values[n].kind = RuntimeValueKind.OUTPUT
        values[n].last_used = len(_node)
    return values
|