onnx-diagnostic 0.8.10__py3-none-any.whl → 0.8.11__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +136 -140
- onnx_diagnostic/ci_models/export_phi4_mm.py +2 -4
- onnx_diagnostic/export/api.py +2 -4
- onnx_diagnostic/export/validate.py +2 -0
- onnx_diagnostic/ext_test_case.py +32 -15
- onnx_diagnostic/helpers/args_helper.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +0 -1
- onnx_diagnostic/helpers/cache_helper.py +6 -6
- onnx_diagnostic/helpers/doc_helper.py +7 -4
- onnx_diagnostic/helpers/graph_helper.py +6 -6
- onnx_diagnostic/helpers/log_helper.py +37 -14
- onnx_diagnostic/helpers/memory_peak.py +5 -1
- onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
- onnx_diagnostic/helpers/model_builder_helper.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +283 -110
- onnx_diagnostic/helpers/ort_session.py +0 -1
- onnx_diagnostic/helpers/torch_helper.py +8 -9
- onnx_diagnostic/investigate/__init__.py +0 -0
- onnx_diagnostic/investigate/input_observer.py +329 -0
- onnx_diagnostic/reference/evaluator.py +0 -1
- onnx_diagnostic/reference/ort_evaluator.py +0 -1
- onnx_diagnostic/reference/report_results_comparison.py +9 -3
- onnx_diagnostic/reference/torch_evaluator.py +5 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +0 -1
- onnx_diagnostic/torch_export_patches/__init__.py +0 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +44 -23
- onnx_diagnostic/torch_models/code_sample.py +5 -10
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
- onnx_diagnostic/torch_models/validate.py +1 -1
- onnx_diagnostic/torch_onnx/compare.py +0 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
- onnx_diagnostic/typing.py +15 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/RECORD +45 -43
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/WHEEL +1 -1
- onnx_diagnostic/api.py +0 -15
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.8.11.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,7 @@
 import functools
 import json
 import os
+import re
 import sys
 import warnings
 from typing import (
@@ -32,11 +33,10 @@ from onnx import (
     ValueInfoProto,
     load as onnx_load,
 )
+from ..typing import InferenceSessionLike, TensorLike

-TensorLike = Union[np.ndarray, "torch.Tensor"]  # noqa: F821

-
-def _make_stat(init: TensorProto) -> Dict[str, float]:
+def _make_stat(init: TensorProto) -> Dict[str, Any]:
     """
     Produces statistics.

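Note: `TensorLike` was previously defined inline here; 0.8.11 moves it, together with `InferenceSessionLike`, into the new `onnx_diagnostic/typing.py` module (+15 lines, listed above). The module's contents are not part of this diff; a minimal sketch of plausible definitions, reusing the alias removed above:

```python
# Sketch only -- the real onnx_diagnostic/typing.py is not shown in this diff.
from typing import Any, Protocol, Sequence, Union

import numpy as np

# Same definition as the line removed from onnx_helper.py above.
TensorLike = Union[np.ndarray, "torch.Tensor"]  # noqa: F821


class InferenceSessionLike(Protocol):
    """Assumed shape: anything exposing onnxruntime's InferenceSession.run."""

    def run(self, output_names: Any, input_feed: Any) -> Sequence[Any]: ...
```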
@@ -160,11 +160,11 @@ def _validate_graph(
     verbose: int = 0,
     watch: Optional[Set[str]] = None,
     path: Optional[Sequence[str]] = None,
-):
-    found = []
+) -> List[Union[NodeProto, TensorProto, ValueInfoProto]]:
+    found: List[Union[NodeProto, TensorProto, ValueInfoProto]] = []
     path = path or ["root"]
-    set_init =
-    set_input =
+    set_init = {i.name for i in g.initializer}
+    set_input = {i.name for i in g.input}
     existing |= set_init | set_input
     if watch and set_init & watch:
         if verbose:
@@ -215,18 +215,15 @@ def _validate_graph(
                 f"in {'/'.join(path)}/{node.op_type}[{node.name}]"
             )
             found.append(node)
-    out =
+    out = {o.name for o in g.output}
     ins = out & existing
-
-        raise AssertionError(
-            f"One output is missing, out={node.input}, existing={ins}, path={path}"
-        )
+    assert ins == out, f"One output is missing, out={node.input}, existing={ins}, path={path}"
     return found


 def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[str]] = None):
-    existing = set(g.input)
-    found = []
+    existing: Set[str] = set(g.input)
+    found: List[Union[NodeProto, TensorProto, ValueInfoProto]] = []
     for node in g.node:
         ins = set(node.input) & existing
         if ins != set(node.input):
@@ -240,7 +237,7 @@ def _validate_function(g: FunctionProto, verbose: int = 0, watch: Optional[Set[s
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
                 found.extend(
-                    _validate_graph(g, existing.copy(), path=[g.name], verbose=verbose)
+                    _validate_graph(att.g, existing.copy(), path=[g.name], verbose=verbose)
                 )
         existing |= set(node.output)
         if watch and set(node.output) & watch:
@@ -285,7 +282,7 @@ def check_model_ort(
     onx: ModelProto,
     providers: Optional[Union[str, List[Any]]] = None,
     dump_file: Optional[str] = None,
-) ->
+) -> InferenceSessionLike:
     """
     Loads a model with onnxruntime.

@@ -308,10 +305,9 @@ def check_model_ort(

     if isinstance(onx, str):
         try:
+            # pyrefly: ignore[bad-return]
             return InferenceSession(onx, providers=providers)
         except Exception as e:
-            import onnx
-
             if dump_file:
                 onnx.save(onx, dump_file)

@@ -319,8 +315,8 @@ def check_model_ort(
                 f"onnxruntime cannot load the model "
                 f"due to {e}\n{pretty_onnx(onnx.load(onx))}"
             )
-            return
     try:
+        # pyrefly: ignore[bad-return]
         return InferenceSession(onx.SerializeToString(), providers=providers)
     except Exception as e:
         if dump_file:
@@ -358,7 +354,17 @@ def onnx_dtype_name(itype: int, exc: bool = True) -> str:


 def pretty_onnx(
-    onx: Union[
+    onx: Union[
+        AttributeProto,
+        FunctionProto,
+        GraphProto,
+        ModelProto,
+        NodeProto,
+        onnx.SparseTensorProto,
+        TensorProto,
+        ValueInfoProto,
+        str,
+    ],
     with_attributes: bool = False,
     highlight: Optional[Set[str]] = None,
     shape_inference: bool = False,
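`pretty_onnx` now spells out every proto type it accepts. A quick usage sketch (the model below is a throwaway example; `shape_inference=True` is only valid for a `ModelProto`, which the assertion added in the next hunk enforces):

```python
import onnx.helper as oh
from onnx import TensorProto
from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

model = oh.make_model(
    oh.make_graph(
        [oh.make_node("Neg", ["X"], ["Y"])],
        "demo",
        [oh.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2])],
        [oh.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 2])],
    )
)
print(pretty_onnx(model, shape_inference=True))  # ModelProto: allowed
print(pretty_onnx(model.graph.node[0]))  # a bare NodeProto works too
```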
@@ -377,6 +383,9 @@ def pretty_onnx(
     assert onx is not None, "onx cannot be None"

     if shape_inference:
+        assert isinstance(
+            onx, ModelProto
+        ), f"shape inference only works for ModelProto, not {type(onx)}"
         onx = onnx.shape_inference.infer_shapes(onx)

     if isinstance(onx, ValueInfoProto):
@@ -447,6 +456,8 @@ def pretty_onnx(
         shape = "x".join(map(str, onx.dims))
         return f"TensorProto:{onx.data_type}:{shape}:{onx.name}"

+    assert not isinstance(onx, onnx.SparseTensorProto), "SparseTensorProto is not handled yet."
+
     try:
         from onnx_array_api.plotting.text_plot import onnx_simple_text_plot

@@ -538,12 +549,6 @@ def from_array_ml_dtypes(arr: TensorLike, name: Optional[str] = None) -> TensorP
     return tensor


-_STORAGE_TYPE = {
-    TensorProto.FLOAT16: np.int16,
-    TensorProto.BFLOAT16: np.int16,
-}
-
-
 def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> TensorProto:
     """
     Converts an array into a :class:`onnx.TensorProto`.
@@ -561,54 +566,9 @@ def from_array_extended(tensor: TensorLike, name: Optional[str] = None) -> Tenso
     ), f"Unable to convert type {type(tensor)} into TensorProto."
     return proto_from_tensor(tensor, name=name)

-
-
-
-        float8e4m3fn,
-        float8e4m3fnuz,
-        float8e5m2,
-        float8e5m2fnuz,
-    )
-    except ImportError:
-        bfloat16 = None
-
-    if bfloat16 is None:
-        return onh.from_array(tensor, name)
-
-    dt = tensor.dtype
-    if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
-        to = TensorProto.FLOAT8E4M3FN
-        dt_to = np.uint8
-    elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
-        to = TensorProto.FLOAT8E4M3FNUZ
-        dt_to = np.uint8
-    elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
-        to = TensorProto.FLOAT8E5M2
-        dt_to = np.uint8
-    elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
-        to = TensorProto.FLOAT8E5M2FNUZ
-        dt_to = np.uint8
-    elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
-        to = TensorProto.BFLOAT16
-        dt_to = np.uint16
-    else:
-        try:
-            import ml_dtypes
-        except ImportError:
-            ml_dtypes = None
-        if ml_dtypes is not None and (
-            tensor.dtype == ml_dtypes.bfloat16
-            or tensor.dtype == ml_dtypes.float8_e4m3fn
-            or tensor.dtype == ml_dtypes.float8_e4m3fnuz
-            or tensor.dtype == ml_dtypes.float8_e5m2
-            or tensor.dtype == ml_dtypes.float8_e5m2fnuz
-        ):
-            return from_array_ml_dtypes(tensor, name)
-        return onh.from_array(tensor, name)
-
-    t = onh.from_array(tensor.astype(dt_to), name)
-    t.data_type = to
-    return t
+    assert isinstance(tensor, np.ndarray)  # type checking
+    # pyrefly: ignore[bad-argument-type]
+    return onh.from_array(tensor, name)


 def to_array_extended(proto: TensorProto) -> TensorLike:
@@ -666,6 +626,7 @@ def onnx_dtype_to_np_dtype(itype: int) -> Any:
     )


+# pyrefly: ignore[unknown-name]
 def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F821
     """
     Converts a torch dtype or numpy dtype into a onnx element type.
@@ -679,6 +640,7 @@ def dtype_to_tensor_dtype(dt: Union[np.dtype, "torch.dtype"]) -> int:  # noqa: F
         pass
     from .torch_helper import torch_dtype_to_onnx_dtype

+    # pyrefly: ignore[bad-argument-type]
     return torch_dtype_to_onnx_dtype(dt)

@@ -779,6 +741,7 @@ def tensor_dtype_to_np_dtype(tensor_dtype: int) -> np.dtype:
             f"ml_dtypes can be used."
         ) from e

+    # pyrefly: ignore[bad-assignment]
     mapping: Dict[int, np.dtype] = {
         TensorProto.BFLOAT16: ml_dtypes.bfloat16,
         TensorProto.FLOAT8E4M3FN: ml_dtypes.float8_e4m3fn,
@@ -798,7 +761,7 @@ def iterator_initializer_constant(
     model: Union[FunctionProto, GraphProto, ModelProto],
     use_numpy: bool = True,
     prefix: str = "",
-) -> Iterator[Tuple[str,
+) -> Iterator[Tuple[str, TensorLike]]:  # noqa: F821
     """
     Iterates on iniatialiers and constant in an onnx model.

@@ -814,9 +777,12 @@ def iterator_initializer_constant(
     if prefix:
         prefix += "."
     for init in graph.initializer:
-
-
-
+        s = f"{prefix}{init.name}"
+        if use_numpy:
+            yield s, to_array_extended(init)
+        else:
+            # pyrefly: ignore[unbound-name]
+            yield s, to_tensor(init)
     nodes = graph.node
     name = graph.name
     if isinstance(model, ModelProto):
@@ -831,13 +797,15 @@ def iterator_initializer_constant(
         if node.op_type == "Constant" and node.domain == "":
             from ..reference import ExtendedReferenceEvaluator as Inference

-            if not use_numpy:
-                import torch
             sess = Inference(node)
             value = sess.run(None, {})[0]
-
-
-
+
+            if not use_numpy:
+                import torch
+
+                yield f"{prefix}{node.output[0]}", (torch.from_numpy(value))
+            else:
+                yield f"{prefix}{node.output[0]}", (value)

         if node.op_type in {"Loop", "Body", "Scan"}:
             for att in node.attribute:
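The two hunks above flesh out `iterator_initializer_constant`: initializers are now yielded as numpy arrays or (via `to_tensor`) as tensors, and `Constant` node outputs go through `torch.from_numpy` when `use_numpy=False`. A usage sketch, assuming any `ModelProto` with initializers or `Constant` nodes (the path is hypothetical):

```python
import onnx
from onnx_diagnostic.helpers.onnx_helper import iterator_initializer_constant

model = onnx.load("model.onnx")  # hypothetical path

# Default: numpy arrays.
for name, value in iterator_initializer_constant(model, use_numpy=True):
    print(name, value.shape)

# use_numpy=False: torch tensors instead.
for name, value in iterator_initializer_constant(model, use_numpy=False):
    print(name, type(value))
```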
@@ -870,7 +838,9 @@ def tensor_statistics(tensor: Union[np.ndarray, TensorProto]) -> Dict[str, Union
     from .helper import size_type

     if isinstance(tensor, TensorProto):
+        # pyrefly: ignore[bad-assignment]
         tensor = to_array_extended(tensor)
+    assert isinstance(tensor, np.ndarray)  # type checking
     itype = np_dtype_to_tensor_dtype(tensor.dtype)
     stat = dict(
         mean=float(tensor.mean()),
@@ -948,7 +918,7 @@ class NodeCoordinates:

     def __init__(
         self,
-        node: Union[
+        node: Union[TensorProto, NodeProto, onnx.SparseTensorProto, ValueInfoProto, str],
         path: Tuple[Tuple[int, str, str], ...],
     ):
         assert isinstance(path, tuple), f"Unexpected type {type(path)} for path"
@@ -968,9 +938,7 @@ class NodeCoordinates:


 class ResultFound:
-    """
-    Class returned by :func:`enumerate_results`.
-    """
+    """Class returned by :func:`enumerate_results`."""

     __slots__ = ("consumer", "name", "producer")

@@ -1060,9 +1028,9 @@ def enumerate_results(
             print(f"[enumerate_results] {indent}-- {r}")
             yield r
     for i in proto.sparse_initializer:
-        if i.name in name:
+        if i.values.name in name:
             r = ResultFound(
-                i.name,
+                i.values.name,
                 NodeCoordinates(i, tuple([*coordinates, (-1, "INIT", "")])),  # noqa: C409
                 None,
             )
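This fix, and the matching ones in the next hunks, follows from how ONNX names sparse initializers: a `SparseTensorProto` has no `name` field of its own; the name is carried by its `values` tensor. For example:

```python
import onnx.helper as oh
from onnx import TensorProto

values = oh.make_tensor("W", TensorProto.FLOAT, [2], [1.0, 2.0])
indices = oh.make_tensor("", TensorProto.INT64, [2], [0, 3])
sparse = oh.make_sparse_tensor(values, indices, dims=[4])
print(sparse.values.name)  # 'W' -- SparseTensorProto itself has no .name
```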
@@ -1165,9 +1133,9 @@ def shadowing_names(
         return shadowing_names(
             proto.node,
             verbose=verbose,
-            existing=
-
-
+            existing={i.name for i in proto.initializer}
+            | {i.values.name for i in proto.sparse_initializer}
+            | {i.name for i in proto.input if i.name},
             shadow_context=set(),
             post_shadow_context=set(),
         )
@@ -1201,9 +1169,9 @@ def shadowing_names(
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
                 g = att.g
-                shadow |=
-                shadow |=
-                shadow |=
+                shadow |= {i.name for i in g.input} & shadow_context
+                shadow |= {i.name for i in g.initializer} & shadow_context
+                shadow |= {i.values.name for i in g.sparse_initializer} & shadow_context
                 s, _ps, c = shadowing_names(
                     g.node, verbose=verbose, existing=existing, shadow_context=existing
                 )
@@ -1225,9 +1193,9 @@ def get_hidden_inputs(graph: onnx.GraphProto) -> Set[str]:
     """
     hidden = set()
     memo = (
-
-
-
+        {i.name for i in graph.initializer}
+        | {i.values.name for i in graph.sparse_initializer}
+        | {i.name for i in graph.input}
     )
     for node in graph.node:
         for i in node.input:
@@ -1353,7 +1321,6 @@ def make_submodel(
     Creates a model with the given list of nodes.
     It computes the minimum list of inputs needed for this model.
     The function assumes the nodes are sorted.
-    It does not handle yet subgraphs.

     :param nodes: list of nodes
     :param ir_version: ir version
@@ -1376,25 +1343,61 @@ def make_submodel(
             if att.type == onnx.AttributeProto.GRAPH:
                 not_known |= get_hidden_inputs(att.g)

-
+    return oh.make_model(
         oh.make_graph(
             nodes,
             "submodel",
             [_mkv_(n, *type_rank_fn(n)) for n in sorted(not_known) if n],
-            [_mkv_(n, *type_rank_fn(n)) for n in
+            [_mkv_(n, *type_rank_fn(n)) for n in output_names if n],
         ),
         ir_version=ir_version,
         opset_imports=opset_imports,
     )
-
+
+
+def make_subfunction(
+    name: str,
+    nodes: List[NodeProto],
+    opset_imports: Sequence[OperatorSetIdProto],
+    output_names: List[str],
+    domain: str = "local_function",
+) -> FunctionProto:
+    """
+    Creates a function with the given list of nodes.
+    It computes the minimum list of inputs needed for this model.
+    The function assumes the nodes are sorted.
+
+    :param name: function name
+    :param nodes: list of nodes
+    :param opset_imports: opset import
+    :param output_names: desired outputs
+    :param domain: function domain
+    :return: model proto
+    """
+    not_known: Set[str] = set()
+    for node in nodes[::-1]:
+        not_known -= {o for o in node.output if o}
+        not_known |= {i for i in node.input if i}
+        if node.op_type in {"Scan", "If", "Loop"}:
+            # there are hidden inputs
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)
+
+    return oh.make_function(
+        domain,
+        name,
+        nodes=nodes,
+        inputs=sorted(not_known),
+        outputs=output_names,
+        opset_imports=opset_imports,
+    )


 def get_tensor_shape(
     obj: Union[onnx.ValueInfoProto, onnx.TypeProto, onnx.TensorProto],
 ) -> Optional[List[Optional[Union[int, str]]]]:
-    """
-    Returns the shape if that makes sense for this object.
-    """
+    """Returns the shape if that makes sense for this object."""
     if isinstance(obj, ValueInfoProto):
         return get_tensor_shape(obj.type)
     elif not isinstance(obj, onnx.TypeProto):
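The new `make_subfunction` derives the minimal input list with a reverse scan: walking the nodes backwards, a node's outputs are dropped from the needed set and its inputs added, so whatever survives is produced nowhere inside the function and must become an input. The core of that scan in isolation, on two hypothetical chained nodes:

```python
import onnx.helper as oh

nodes = [
    oh.make_node("Add", ["x", "y"], ["t"]),
    oh.make_node("Mul", ["t", "z"], ["out"]),
]
not_known: set = set()
for node in nodes[::-1]:
    not_known -= {o for o in node.output if o}
    not_known |= {i for i in node.input if i}
print(sorted(not_known))  # ['x', 'y', 'z'] -- 't' is produced internally
```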
@@ -1512,9 +1515,6 @@ def onnx_remove_node_unused(
         if not ({o for o in node.output if o} & marked_set):
             removed.add(ind)

-    if not is_function:
-        initializers = [i for i in graph.initializer if i.name in marked]
-        sparse_initializers = [i for i in graph.sparse_initializer if i.name in marked]
     new_nodes = [node for i, node in enumerate(nodes) if i not in removed]

     # Finally create the new graph.
@@ -1529,13 +1529,16 @@ def onnx_remove_node_unused(
             attributes=graph.attribute,
             doc_string=graph.doc_string,
         )
+
+    initializers = [i for i in graph.initializer if i.name in marked]
+    sparse_initializers = [i for i in graph.sparse_initializer if i.values.name in marked]
     new_graph = oh.make_graph(
         new_nodes,
         graph.name,
         graph.input,
         graph.output,
         initializers,
-        sparse_initializers,
+        sparse_initializer=sparse_initializers,
     )
     new_graph.value_info.extend(graph.value_info)
     return new_graph
@@ -1549,7 +1552,7 @@ def select_model_inputs_outputs(
     overwrite: Optional[Dict[str, Any]] = None,
     remove_unused: bool = True,
     verbose: int = 0,
-):
+) -> ModelProto:
     """
     Takes a model and changes its outputs.

@@ -1709,6 +1712,7 @@ def select_model_inputs_outputs(
     )
     if remove_unused:
         graph = onnx_remove_node_unused(graph, recursive=False)
+    assert isinstance(graph, GraphProto)  # type checking
     onnx_model = oh.make_model(graph, functions=model.functions)
     onnx_model.ir_version = model.ir_version
     onnx_model.producer_name = model.producer_name
@@ -1727,3 +1731,172 @@ def select_model_inputs_outputs(
         op_set.version = oimp.version

     return onnx_model
+
+
+def _find_used_names(node_list, node_indices):
+    # find all the outputs the subset of nodes produces
+    possible_outputs = set()
+    for i_node in node_indices:
+        if not node_list[i_node]:
+            continue
+        possible_outputs |= {o for o in node_list[i_node].output if o}
+    # find all requires input from the other nodes
+    set_indices = set(node_indices)
+    not_known: Set[str] = set()
+    ranges = list(range(len(node_list)))
+    for i_node in ranges[::-1]:
+        if i_node in set_indices:
+            continue
+        node = node_list[i_node]
+        if not node:
+            continue
+        not_known -= {o for o in node.output if o}
+        not_known |= {i for i in node.input if i}
+        if node.op_type in {"Scan", "If", "Loop"}:
+            # there are hidden inputs
+            for att in node.attribute:
+                if att.type == onnx.AttributeProto.GRAPH:
+                    not_known |= get_hidden_inputs(att.g)
+    # output
+    selection = possible_outputs & not_known
+    assert selection, (
+        f"No output is needed, possible_outputs={sorted(possible_outputs)}, "
+        f"not_known={sorted(not_known)}"
+    )
+    return sorted(selection)
+
+
+def check_for_non_recursivity(
+    node_list: List[Optional[NodeProto]], inputs: Sequence[str], outputs: Sequence[str]
+):
+    """
+    We finally need to check that any of this output is not required
+    by one input from the function itself, that would mean one node
+    needs an output of the function and is also required by the function:
+    it is probably missing from the initial set.
+
+
+
+    :param node_list: list of nodes
+    :param inputs: input names to consider
+    :param outputs: output names which cannot be involved in input names
+    """
+    set_inputs = set(inputs)
+    set_outputs = set(outputs)
+    for node in node_list[::-1]:
+        if not node:
+            continue
+        si = set(node.output)
+        if si & set_inputs:
+            set_inputs |= set(node.input)
+            if node.op_type in {"Scan", "If", "Loop"}:
+                # there are hidden inputs
+                for att in node.attribute:
+                    if att.type == onnx.AttributeProto.GRAPH:
+                        set_inputs |= get_hidden_inputs(att.g)
+    if set_outputs & set_inputs:
+        raise ValueError(
+            f"Results {set_outputs & set_inputs} are needed for inputs {inputs} "
+            f"but also requires {outputs} which is not allowed."
+        )
+
+
+def make_model_with_local_functions(
+    model: ModelProto,
+    regex: str = ".*[.]layers[.][0-9]+[.]forward$",
+    domain: str = "local_function",
+    metadata_key_prefix: Union[str, Tuple[str, ...]] = ("namespace", "source["),
+    verbose: int = 0,
+) -> ModelProto:
+    """
+    Selects nodes based on a regular expression, using metadata
+    ``'namespace'``. It is going to look into every value
+    matching the regular expression and partition the nodes based
+    on the unique values the regular expression finds.
+    Every set of nodes it replaced by a call to a local function.
+
+    :param model: model proto
+    :param regex: regular expression
+    :param domain: function domain
+    :param metadata_keys: list of metadata keys to consider,
+        every value is split into multiple ones.
+    :param verbose: verbosity
+    :return: model proto
+    """
+    prefix = (
+        metadata_key_prefix
+        if isinstance(metadata_key_prefix, tuple)
+        else (metadata_key_prefix,)
+    )
+    reg = re.compile(regex)
+    unique_values = set()
+    unique: Dict[str, List[int]] = {}
+    for i, node in enumerate(model.graph.node):
+        selected = False
+        for data in node.metadata_props:
+            if data.key.startswith(prefix):
+                values = re.split("[,:]", data.value)
+                for v in values:
+                    if not v:
+                        continue
+                    if reg.match(v):
+                        if v not in unique:
+                            unique[v] = []
+                        unique[v].append(i)
+                        selected = True
+                        break
+                    unique_values.add(v)
+            if selected:
+                break
+    # sets of nodes.
+    if not unique:
+        if verbose:
+            print(f"[make_model_with_local_functions] no match in {sorted(unique_values)}")
+        return model
+
+    if verbose:
+        print(f"[make_model_with_local_functions] matched {len(unique)} partitions")
+    functions = []
+    new_nodes: List[Optional[NodeProto]] = list(model.graph.node)
+    for key, node_indices in unique.items():
+        function_name = key.strip().replace(".", "_")
+        if verbose:
+            print(
+                f"[make_model_with_local_functions] move {len(node_indices)} "
+                f"nodes in partition {function_name!r}"
+            )
+        outputs = _find_used_names(new_nodes, node_indices)
+        function_nodes = [new_nodes[i] for i in node_indices]
+        lf = make_subfunction(
+            function_name,
+            [n for n in function_nodes if n],
+            model.opset_import,
+            outputs,
+            domain=domain,
+        )
+        check_for_non_recursivity(new_nodes, lf.input, lf.output)
+        functions.append(lf)
+        maxi = max(node_indices)
+        for i in node_indices:
+            new_nodes[i] = None
+        new_nodes[maxi] = oh.make_node(lf.name, lf.input, lf.output, domain=lf.domain)
+
+    return oh.make_model(
+        oh.make_graph(
+            [n for n in new_nodes if n],
+            model.graph.name,
+            model.graph.input,
+            model.graph.output,
+            model.graph.initializer,
+            doc_string=model.graph.doc_string,
+            value_info=model.graph.value_info,
+            sparse_initializer=model.graph.sparse_initializer,
+        ),
+        ir_version=model.ir_version,
+        opset_imports=(
+            model.opset_import
+            if domain in {d.domain for d in model.opset_import}
+            else [*model.opset_import, oh.make_opsetid(domain, 1)]
+        ),
+        functions=[*model.functions, *functions],
+    )
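A usage sketch for the new `make_model_with_local_functions`: it groups nodes whose `namespace`/`source[` metadata matches the regular expression and folds each group into a local function, the default regex making one function per `...layers.N.forward` scope. The file path below is hypothetical:

```python
import onnx
from onnx_diagnostic.helpers.onnx_helper import make_model_with_local_functions

model = onnx.load("exported_model.onnx")  # hypothetical path
rewritten = make_model_with_local_functions(
    model,
    regex=".*[.]layers[.][0-9]+[.]forward$",  # the default
    domain="local_function",
    verbose=1,
)
print([f.name for f in rewritten.functions])
```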
@@ -19,12 +19,7 @@ from .cache_helper import (
     CacheKeyValue,
 )
 from .mini_onnx_builder import create_onnx_model_from_input_tensors
-from .onnx_helper import (
-    to_array_extended,
-    tensor_dtype_to_np_dtype,
-    _STORAGE_TYPE,
-    onnx_dtype_name,
-)
+from .onnx_helper import to_array_extended, tensor_dtype_to_np_dtype, onnx_dtype_name


 def proto_from_tensor(
@@ -84,13 +79,17 @@ def proto_from_tensor(
         byte_data = (ctypes.c_ubyte * numel * element_size).from_address(np_arr.data_ptr())
         tensor.raw_data = bytes(byte_data)
         if sys.byteorder == "big":
-
-
+            storage_type = {
+                onnx.TensorProto.FLOAT16: np.int16,
+                onnx.TensorProto.BFLOAT16: np.int16,
+            }
+            np_dtype = storage_type[tensor.data_type]  # type: ignore
+            np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)  # type: ignore
     else:
         tensor.raw_data = np_arr.tobytes()
         if sys.byteorder == "big":
             np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
-            np.
+            np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap(inplace=True)
     return tensor

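The big-endian branch above byteswaps float16/bfloat16 raw data through an int16 view, since numpy has no bfloat16 dtype to swap directly; any 2-byte integer view swaps the right units. The trick in isolation (a sketch; the real code works on `tensor.raw_data`):

```python
import numpy as np

# Two float16 values serialized little-endian, the ONNX on-disk convention.
raw = np.array([1.5, -2.0], dtype=np.float16).tobytes()

# Viewing the buffer as int16 swaps 2-byte units without interpreting them
# as floats, which also covers bfloat16 where numpy lacks a dtype.
swapped = np.frombuffer(raw, dtype=np.int16).byteswap().tobytes()
assert np.frombuffer(swapped, dtype=np.int16).byteswap().tobytes() == raw
```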