onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/reference/ort_evaluator.py

@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 from onnx import (
     AttributeProto,
@@ -6,6 +6,7 @@ from onnx import (
     FunctionProto,
     ModelProto,
     NodeProto,
+    TensorProto,
     TypeProto,
     ValueInfoProto,
     helper as oh,
@@ -16,7 +17,14 @@ from onnx import (
 from onnx.defs import onnx_opset_version
 import onnxruntime
 from ..helpers import string_type
-from ..helpers.onnx_helper import
+from ..helpers.onnx_helper import (
+    get_hidden_inputs,
+    dtype_to_tensor_dtype,
+    np_dtype_to_tensor_dtype,
+    to_array_extended,
+    pretty_onnx,
+)
+from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
     InferenceSessionForTorch,
     InferenceSessionForNumpy,
@@ -31,6 +39,54 @@ PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]


+class OnnxList(list):
+    """Defines a list for the runtime."""
+
+    def __init__(self, itype: Union[list, int]):
+        super().__init__()
+        if isinstance(itype, int):
+            self.itype = itype
+            self.dtype = onnx_dtype_to_torch_dtype(itype)
+        else:
+            assert itype, "The list cannot be created with an empty list."
+            self.itype = (
+                np_dtype_to_tensor_dtype(itype[0].dtype)
+                if isinstance(itype[0], np.ndarray)
+                else torch_dtype_to_onnx_dtype(itype[0].dtype)
+            )
+            self.extend(itype)
+            self.dtype = itype[0].dtype
+        self.shape = "OnnxList"
+
+    def get_device(self):
+        "Returns the device of the first tensor."
+        assert len(self) > 0, "Cannot access the device for an empty list."
+        return self[0].get_device() if hasattr(self[0], "get_device") else -1
+
+    def numpy(self):
+        "Creates a new list with all tensors on numpy, or self if it is already the case."
+        if all(isinstance(v, np.ndarray) for v in self):
+            return self
+        return OnnxList([v.detach().cpu().numpy() for v in self])
+
+    def to(self, tensor_like) -> "OnnxList":
+        "Creates a new list with all tensors on numpy or pytorch depending on `tensor_like`."
+        if isinstance(tensor_like, np.ndarray):
+            return self
+        import torch
+
+        return OnnxList(
+            [
+                torch.from_numpy(t).to(tensor_like.device) if isinstance(t, np.ndarray) else t
+                for t in self
+            ]
+        )
+
+    def clone(self) -> "OnnxList":
+        "Clone (torch)."
+        return OnnxList([t.clone() for t in self]) if len(self) > 0 else OnnxList(self.itype)
+
+
 class OnnxruntimeEvaluator:
     """
     This class loads an onnx model and then executes one by one the nodes
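The new `OnnxList` class above is how the evaluator represents ONNX sequence values at runtime. A minimal sketch of its two construction modes, assuming `OnnxList` is importable from `onnx_diagnostic.reference.ort_evaluator` where this diff defines it:

    import numpy as np
    from onnx import TensorProto
    from onnx_diagnostic.reference.ort_evaluator import OnnxList

    # An empty sequence typed by its ONNX element type.
    empty = OnnxList(TensorProto.FLOAT)
    print(empty.itype)   # 1 (TensorProto.FLOAT)
    print(empty.dtype)   # torch.float32, via onnx_dtype_to_torch_dtype

    # A sequence seeded with tensors: itype is inferred from the first element.
    filled = OnnxList([np.zeros((2, 3), dtype=np.float32)])
    print(filled.itype)              # 1 again, via np_dtype_to_tensor_dtype
    print(filled.numpy() is filled)  # True: the list is already numpy-backed

The constant `shape` attribute (`"OnnxList"`) keeps the `(a.dtype, a.shape)` cache keys built later in `_run` usable for sequence inputs as well as tensors.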
@@ -54,6 +110,9 @@ class OnnxruntimeEvaluator:
     :param whole: if True, do not split node by node
     :param torch_or_numpy: force the use of one of them, True for torch,
         False for numpy, None to let the class choose
+    :param dump_onnx_model: dumps the temporary onnx model created if whole is True
+    :param function_kwargs: a FunctionProto may have parameters,
+        this contains their values
     """

     def __init__(
@@ -77,6 +136,8 @@
         opsets: Optional[Union[int, Dict[str, int]]] = None,
         whole: bool = False,
         torch_or_numpy: Optional[bool] = None,
+        function_kwargs: Optional[Dict[str, Any]] = None,
+        dump_onnx_model: Optional[str] = None,
     ):
         if isinstance(proto, str):
             self.proto: Proto = load(proto)
@@ -90,6 +151,9 @@
         assert isinstance(
             self.proto, PROTO
         ), f"Unexpected type for self.proto {type(self.proto)}"
+        assert (
+            whole or not dump_onnx_model
+        ), f"whole must be True for dump_onnx_model={dump_onnx_model!r}"

         self._cache: Dict[
             Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
@@ -109,6 +173,8 @@
             use_training_api=use_training_api,
         )
         self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
+        self.function_kwargs = function_kwargs
+        self.dump_onnx_model = dump_onnx_model

         self.verbose = verbose
         self.torch_or_numpy = torch_or_numpy
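A hedged sketch of the two new constructor arguments wired in above; the file names and the `epsilon` attribute are placeholders, not values taken from the package:

    from onnx_diagnostic.reference.ort_evaluator import OnnxruntimeEvaluator

    ev = OnnxruntimeEvaluator(
        "model.onnx",                 # placeholder model path
        whole=True,                   # the assert above requires whole=True here
        dump_onnx_model="dump.onnx",  # where the temporary model gets saved
        function_kwargs={"epsilon": 1e-5},  # hypothetical FunctionProto attribute values
    )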
@@ -199,6 +265,8 @@
     def _log_arg(self, a: Any) -> Any:
         if isinstance(a, (str, int, float)):
             return a
+        if isinstance(a, OnnxList):
+            return string_type(a)
         device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
         if hasattr(a, "shape"):
             prefix = "A:" if hasattr(a, "astype") else "T:"
@@ -221,6 +289,12 @@
     def _is_local_function(self, node: NodeProto) -> bool:
         return (node.domain, node.op_type) in self.local_functions

+    def _run_init(self, feed_inputs):
+        if self.sess_ is None:
+            assert self.proto, "self.proto is empty"
+            _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+        return self.sess_
+
     def run(
         self,
         outputs: Optional[List[str]],
@@ -244,9 +318,7 @@
         """
         if self.rt_nodes_ is None:
             # runs a whole
-
-            assert self.proto, "self.proto is empty"
-            _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+            self._run_init(feed_inputs)
             assert self.sess_, "mypy not happy"
             return self.sess_.run(outputs, feed_inputs)
         if outputs is None:
@@ -273,14 +345,16 @@
             if node.op_type == "If" and node.domain == "":
                 outputs = self._run_if(node, inputs, results)
             elif node.op_type in {"Scan", "Loop"} and node.domain == "":
-                outputs = self.
+                outputs = self._run_scan_or_loop(node, inputs, results)
             elif self._is_local_function(node):
                 outputs = self._run_local(node, inputs, results)
             else:
                 outputs = self._run(node, inputs, results)
-            for
-
-
+            node_output = [o for o in node.output if o]
+            assert len(node_output) == len(
+                outputs
+            ), f"Length mismatch between node output={node.output} and outputs={outputs}"
+            for name, value in zip(node_output, outputs):
                 self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                 assert isinstance(name, str), f"unexpected type for name {type(name)}"
                 results[name] = value
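The `node_output = [o for o in node.output if o]` filter above exists because optional ONNX outputs are encoded as empty names, so a runtime can return fewer values than `node.output` lists. A standalone illustration:

    import onnx.helper as oh

    # The second output (Mean) is skipped by naming it "".
    node = oh.make_node("LayerNormalization", ["x", "scale"], ["y", "", "inv_std"])
    print([o for o in node.output if o])  # ['y', 'inv_std']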
@@ -355,11 +429,12 @@
         nodes: Sequence[NodeProto],
         vinputs: Sequence[ValueInfoProto],
         voutputs: Sequence[ValueInfoProto],
+        functions: Optional[Sequence[FunctionProto]] = None,
     ) -> ModelProto:
         onx = oh.make_model(
             oh.make_graph(nodes, "-", vinputs, voutputs),
             ir_version=getattr(self.proto, "ir_version", self.ir_version),
-            functions=getattr(self.proto, "functions",
+            functions=[*getattr(self.proto, "functions", []), *(functions or [])],
         )
         del onx.opset_import[:]
         if hasattr(self.proto, "opset_import"):
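`oh.make_model` accepts a `functions` argument, which is how the extra `FunctionProto` objects collected by `_get_sess` end up inside the temporary model. A minimal standalone sketch with a made-up `custom.MyFn` function:

    import onnx.helper as oh
    from onnx import TensorProto

    fn = oh.make_function(
        "custom",                                  # domain
        "MyFn",                                    # function name
        ["x"],
        ["y"],
        [oh.make_node("Identity", ["x"], ["y"])],
        opset_imports=[oh.make_opsetid("", 18)],
    )
    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("MyFn", ["x"], ["y"], domain="custom")],
            "g",
            [oh.make_tensor_value_info("x", TensorProto.FLOAT, None)],
            [oh.make_tensor_value_info("y", TensorProto.FLOAT, None)],
        ),
        opset_imports=[oh.make_opsetid("", 18), oh.make_opsetid("custom", 1)],
        functions=[fn],
    )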
@@ -373,51 +448,61 @@
             )
         else:
             onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
+        opsets = {d.domain: d.version for d in onx.opset_import}
+        add = {}
+        for node in self.enumerate_nodes(onx.graph.node):
+            if node.domain and node.domain not in opsets and node.domain not in add:
+                add[node.domain] = 1
+        onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])

         # That helps fixing bugs.
         onx = shi.infer_shapes(onx)
         return onx

-
-
-
-
-
-
-
-
-
-
-
-
-
-        for att in node.attribute:
-            if att.type == AttributeProto.GRAPH and att.g:
-                hid = self._get_hidden_inputs(att.g)
-                less = set(h for h in hid if h not in memo)
-                hidden |= less
-            memo |= set(node.output)
-        return hidden
+    def _make_model_outputs(
+        self, node: NodeProto, inputs: List[ValueInfoProto]
+    ) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
+        return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
+
+    def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
+        "Enumerates nodes recursively."
+        for node in nodes:
+            if node.op_type in {"Scan", "If", "Loop"}:
+                for att in node.attribute:
+                    if att.type == AttributeProto.GRAPH:
+                        yield from self.enumerate_nodes(att.g.node)
+            yield node

     @classmethod
-    def _get_hidden_node_inputs(
-        """Calls multiple
+    def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
+        """Calls multiple get_hidden_inputs on every attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |=
+                hidden |= get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
         self, node: Union[ModelProto, NodeProto], inputs: List[Any]
     ) -> Tuple[ModelProto, _InferenceSession]:
+        on_cpu = None
         if isinstance(node, ModelProto):
             onx = node
         else:
+            functions = []
+            if isinstance(node, FunctionProto):
+                functions.append(node)
+                node = oh.make_node(
+                    node.name,
+                    list(node.input),
+                    list(node.output),
+                    domain=node.domain,
+                    **(self.function_kwargs or {}),
+                )
             assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
-        if node.op_type == "Constant":
+        if node.op_type == "Constant" and node.domain == "":
             # We force the type to be a boolean.
             ref = ExtendedReferenceEvaluator(node)
             cst = ref.run(None, {})[0]
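`enumerate_nodes` yields the nodes of `Scan`/`If`/`Loop` subgraphs before the node that carries them, which is what the opset-collection loop in `_make_model_proto` relies on to find custom domains hidden inside control flow. A sketch; since the method never reads `self`, it is called unbound here:

    import onnx.helper as oh
    from onnx_diagnostic.reference.ort_evaluator import OnnxruntimeEvaluator

    body = oh.make_graph(
        [oh.make_node("Neg", ["x"], ["y"], domain="custom.domain")],
        "body",
        [],
        [],
    )
    # A structurally incomplete Loop: enough to enumerate, never executed.
    loop = oh.make_node("Loop", ["m", "cond"], ["out"], body=body)
    ops = [n.op_type for n in OnnxruntimeEvaluator.enumerate_nodes(None, [loop])]
    print(ops)  # ['Neg', 'Loop']: subgraph nodes first, then the carrier node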
@@ -427,6 +512,19 @@
                     node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
                 )
             ]
+            prenodes = []  # type: ignore[var-annotated]
+        elif node.op_type == "ConcatFromSequence" and node.domain == "":
+            # We force the type to be a boolean.
+            vinputs = [
+                oh.make_value_info(
+                    node.input[0],
+                    type_proto=oh.make_sequence_type_proto(
+                        oh.make_tensor_type_proto(elem_type=inputs[0].itype, shape=None)
+                    ),
+                )
+            ]
+            voutputs = [oh.make_tensor_value_info(node.output[0], inputs[0].itype, None)]
+            prenodes = []  # type: ignore[var-annotated]
         else:
             unique_names = set()
             vinputs = []
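The `ConcatFromSequence` branch above declares its input with a sequence `TypeProto` rather than a tensor one. The two `onnx.helper` calls involved, shown standalone:

    import onnx.helper as oh
    from onnx import TensorProto

    seq_vi = oh.make_value_info(
        "seq_in",
        type_proto=oh.make_sequence_type_proto(
            oh.make_tensor_type_proto(elem_type=TensorProto.FLOAT, shape=None)
        ),
    )
    print(seq_vi.type.WhichOneof("value"))  # 'sequence_type'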
@@ -440,18 +538,35 @@
                 vinputs.append(value)

             # no need to run shape inference
-            voutputs =
+            prenodes, voutputs = self._make_model_outputs(node, vinputs)

-        onx = self._make_model_proto(
+        onx = self._make_model_proto(
+            [*prenodes, node], vinputs, voutputs, functions=functions
+        )
+        if node.op_type in {"Shape", "Size"}:
+            on_cpu = True

+        if self.dump_onnx_model:
+            onnx_save(
+                onx, self.dump_onnx_model, save_as_external_data=len(onx.graph.node) > 100
+            )
         cls = (
             InferenceSessionForNumpy
             if any(isinstance(i, np.ndarray) for i in inputs)
             and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
             else InferenceSessionForTorch
         )
+        if (
+            "providers" not in self.session_kwargs or not self.session_kwargs["providers"]
+        ) and any(hasattr(t, "is_cuda") and t.is_cuda for t in inputs):
+            sess_kwargs = self.session_kwargs.copy()
+            sess_kwargs["providers"] = ["CUDAExecutionProvider"]
+        else:
+            sess_kwargs = self.session_kwargs or {}
+        if on_cpu and "CUDAExecutionProvider" in (sess_kwargs.get("providers", []) or []):
+            sess_kwargs["cpu_outputs"] = True
         try:
-            sess = cls(onx, **
+            sess = cls(onx, **sess_kwargs)
         except (
             onnxruntime.capi.onnxruntime_pybind11_state.Fail,
             onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
@@ -473,15 +588,29 @@
             if i == "" or i in unique_names:
                 continue
             unique_names.add(i)
-
+            if isinstance(it, OnnxList):
+                value = oh.make_value_info(
+                    i,
+                    type_proto=oh.make_sequence_type_proto(
+                        oh.make_tensor_type_proto(
+                            elem_type=dtype_to_tensor_dtype(it.dtype), shape=None
+                        )
+                    ),
+                )
+            else:
+                value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
             vinputs.append(value)

-        reduced_set =
+        reduced_set = get_hidden_inputs(g)
         for i, v in context.items():
             if i in reduced_set and i not in unique_names:
                 unique_names.add(i)
                 value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
                 vinputs.append(value)
+        assert len(reduced_set & set(context)) == len(reduced_set), (
+            f"Missing hidden inputs {sorted(reduced_set)} from context={sorted(context)} "
+            f"(len(inputs)={len([i for i in inputs if i])}) for node {pretty_onnx(node)}"
+        )
         return vinputs

     def _get_sess_if(
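`get_hidden_inputs` (imported at the top of the file) returns the names a subgraph reads from its outer scope, and the new assert verifies the surrounding `context` supplies all of them. A sketch of the expected behavior:

    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_diagnostic.helpers.onnx_helper import get_hidden_inputs

    body = oh.make_graph(
        [oh.make_node("Add", ["x", "outer"], ["y"])],
        "body",
        [oh.make_tensor_value_info("x", TensorProto.FLOAT, None)],
        [oh.make_tensor_value_info("y", TensorProto.FLOAT, None)],
    )
    print(get_hidden_inputs(body))  # expected: {'outer'}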
@@ -530,6 +659,14 @@

     def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
         """Runs a node."""
+        if node.op_type[0] == "S":
+            if node.op_type == "SequenceEmpty":
+                dtype = TensorProto.FLOAT
+                for att in node.attribute:
+                    if att.name == "dtype":
+                        dtype = att.i
+                return [OnnxList(itype=dtype)]
+
         types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
         key = (id(node), *types)
         if key in self._cache:
@@ -538,13 +675,31 @@
             onx, sess = self._get_sess(node, inputs)
             self._cache[key] = onx, sess

-        feeds =
-
-
-
+        feeds = {}
+        for i, val in zip(node.input, inputs):
+            if i == "":
+                assert (
+                    val is None
+                ), f"input name={i!r} but val={string_type(val, with_shape=True)}"
+                continue
+            feeds[i] = val
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+
+        if node.op_type[0] == "C":
+            if node.op_type == "ConcatFromSequence":
+                res = sess.sess.run(None, self.feeds_to_numpy(feeds))  # type: ignore[union-attr]
+                if isinstance(inputs[0][0], np.ndarray):
+                    return list(res)
+                import torch
+
+                return [torch.from_numpy(r).to(inputs[0][0].device) for r in res]
+
         outputs = list(sess.run(None, feeds))
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+        assert not any(type(v) is list for v in outputs), (
+            f"One output type is a list, this should not be allowed, "
+            f"node.op_type={node.op_type}, feeds={string_type(feeds, with_shape=True)}"
+        )
         return outputs

     def _run_if(
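The `SequenceEmpty` shortcut above never reaches onnxruntime: it reads the optional `dtype` attribute (defaulting to float) and returns a typed, empty `OnnxList`. A sketch of the node it intercepts:

    import onnx.helper as oh
    from onnx import TensorProto

    node = oh.make_node("SequenceEmpty", [], ["seq"], dtype=TensorProto.DOUBLE)
    # For this node, _run returns [OnnxList(itype=TensorProto.DOUBLE)],
    # sidestepping onnxruntime's weaker support for sequence values.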
@@ -570,7 +725,7 @@
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
         return outputs

-    def
+    def _get_sess_scan_or_loop(
         self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
     ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
         g = None
@@ -605,10 +760,26 @@
         )
         return onx, sess

-    def
+    def feeds_to_numpy(self, feeds):
+        new_feeds = {}
+        for k, v in feeds.items():
+            if hasattr(v, "detach"):
+                new_feeds[k] = v.detach().cpu().numpy()
+            elif isinstance(v, OnnxList):
+                new_feeds[k] = v.numpy()
+            else:
+                new_feeds[k] = v
+        return new_feeds
+
+    def _run_scan_or_loop(
         self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
     ) -> List[Any]:
         """Runs a node Scan."""
+        assert not any(type(i) is list for i in inputs), (
+            f"One input is a list but it should be an OnnxList, "
+            f"node.op_type={node.op_type!r}, node.input={node.input}, "
+            f"inputs={string_type(inputs, with_shape=True)}"
+        )
         feeds = dict(zip(node.input, inputs))
         feeds.update(results)
         name = "body"
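A hedged sketch of `feeds_to_numpy`: torch tensors are detached to numpy, `OnnxList` feeds are converted element-wise, and everything else passes through. The identity model exists only to obtain an evaluator instance and is not part of the package:

    import numpy as np
    import onnx.helper as oh
    import torch
    from onnx import TensorProto
    from onnx_diagnostic.reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator

    model = oh.make_model(
        oh.make_graph(
            [oh.make_node("Identity", ["x"], ["y"])],
            "g",
            [oh.make_tensor_value_info("x", TensorProto.FLOAT, [2])],
            [oh.make_tensor_value_info("y", TensorProto.FLOAT, [2])],
        )
    )
    ev = OnnxruntimeEvaluator(model)
    feeds = {"t": torch.ones(2), "seq": OnnxList([torch.zeros(3)]), "a": np.zeros(1)}
    as_np = ev.feeds_to_numpy(feeds)
    print({k: type(v).__name__ for k, v in as_np.items()})
    # {'t': 'ndarray', 'seq': 'OnnxList', 'a': 'ndarray'}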
@@ -616,10 +787,21 @@
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] = _onx, sess = self.
+            self._cache[key] = _onx, sess = self._get_sess_scan_or_loop(
+                node, name, inputs, results
+            )

         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
+        if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()):
+            # This operator uses sequences. onnxruntime does not play well with sequences.
+            sess._run_init(feeds)  # type: ignore[union-attr]
+            outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds))  # type: ignore[union-attr]
+            return [
+                (OnnxList(v).to(feeds[node.input[0]]) if isinstance(v, list) else v)
+                for v in outputs
+            ]
+
         outputs = sess.run(None, feeds)
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
         return outputs
onnx_diagnostic/tasks/feature_extraction.py

@@ -1,10 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache


@@ -13,8 +9,9 @@ __TASK__ = "feature-extraction"

 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "
-
+    check_hasattr(config, "vocab_size")
+    # The Bart architecture does not like it too much when the number of layers is changed.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs

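A sketch of the reduced configuration; the stand-in config class is illustrative, not from the package, and it assumes `update_config` writes the values back onto the config object as its name suggests:

    from onnx_diagnostic.tasks.feature_extraction import reduce_model_config

    class Cfg:  # stand-in for a transformers config
        vocab_size = 32000

    cfg = Cfg()
    print(reduce_model_config(cfg))  # {'vocab_size': 2056}
    # update_config mutates cfg, so cfg.vocab_size should now be 2056 as well.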
@@ -25,7 +22,8 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-
+    past_length: int = 30,
+    past_length2: int = 4,
     decoder_attention_heads: Optional[int] = None,
     encoder_attention_heads: Optional[int] = None,
     encoder_ffn_dim: Optional[int] = None,
@@ -73,13 +71,13 @@ def get_inputs(
             torch.randn(
                 batch_size,
                 encoder_attention_heads,
-
+                past_length,
                 encoder_ffn_dim,
             ),
             torch.randn(
                 batch_size,
                 encoder_attention_heads,
-
+                past_length,
                 encoder_ffn_dim,
             ),
         )
@@ -92,13 +90,13 @@ def get_inputs(
             torch.randn(
                 batch_size,
                 decoder_attention_heads,
-
+                past_length2,
                 decoder_ffn_dim,
             ),
             torch.randn(
                 batch_size,
                 decoder_attention_heads,
-
+                past_length2,
                 decoder_ffn_dim,
             ),
         )
@@ -124,7 +122,8 @@
         batch_size=batch_size + 1,
         sequence_length=sequence_length + add_second_input,
         dummy_max_token_id=dummy_max_token_id,
-
+        past_length=past_length,
+        past_length2=past_length2,
         decoder_attention_heads=decoder_attention_heads,
         encoder_attention_heads=encoder_attention_heads,
         encoder_ffn_dim=encoder_ffn_dim,
@@ -146,7 +145,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
     for att in [
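A sketch of the new defaults, assuming `check_hasattr` tolerates a `None` config as the `31999` fallback above suggests:

    from onnx_diagnostic.tasks.feature_extraction import random_input_kwargs

    kwargs, get_inputs_fn = random_input_kwargs(None)
    print(kwargs["sequence_length"], kwargs["past_length"], kwargs["past_length2"])
    # 12 30 4
    print(kwargs["dummy_max_token_id"])  # 31999 when config is None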