onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +387 -12
- onnx_diagnostic/export/api.py +91 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +3 -3
- onnx_diagnostic/export/onnx_plug.py +396 -0
- onnx_diagnostic/ext_test_case.py +92 -23
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +210 -0
- onnx_diagnostic/helpers/helper.py +90 -26
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +103 -1
- onnx_diagnostic/helpers/ort_session.py +37 -11
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +103 -6
- onnx_diagnostic/reference/ort_evaluator.py +233 -28
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/validate.py +50 -1
- onnx_diagnostic/torch_onnx/sbs.py +963 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
```diff
--- a/onnx_diagnostic/reference/ort_evaluator.py
+++ b/onnx_diagnostic/reference/ort_evaluator.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
 import numpy as np
 from onnx import (
     AttributeProto,
@@ -6,6 +6,7 @@ from onnx import (
     FunctionProto,
     ModelProto,
     NodeProto,
+    TensorProto,
     TypeProto,
     ValueInfoProto,
     helper as oh,
@@ -16,7 +17,13 @@ from onnx import (
 from onnx.defs import onnx_opset_version
 import onnxruntime
 from ..helpers import string_type
-from ..helpers.onnx_helper import
+from ..helpers.onnx_helper import (
+    pretty_onnx,
+    dtype_to_tensor_dtype,
+    to_array_extended,
+    np_dtype_to_tensor_dtype,
+)
+from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
 from ..helpers.ort_session import (
     InferenceSessionForTorch,
     InferenceSessionForNumpy,
@@ -31,6 +38,54 @@ PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]


+class OnnxList(list):
+    """Defines a list for the runtime."""
+
+    def __init__(self, itype: Union[list, int]):
+        super().__init__()
+        if isinstance(itype, int):
+            self.itype = itype
+            self.dtype = onnx_dtype_to_torch_dtype(itype)
+        else:
+            assert itype, "The list cannot be created with an empty list."
+            self.itype = (
+                np_dtype_to_tensor_dtype(itype[0].dtype)
+                if isinstance(itype[0], np.ndarray)
+                else torch_dtype_to_onnx_dtype(itype[0].dtype)
+            )
+            self.extend(itype)
+            self.dtype = itype[0].dtype
+        self.shape = "OnnxList"
+
+    def get_device(self):
+        "Returns the device of the first tensor."
+        assert len(self) > 0, "Cannot access the device for an empty list."
+        return self[0].get_device() if hasattr(self[0], "get_device") else -1
+
+    def numpy(self):
+        "Creates a new list with all tensors on numpy, or self if it is already the case."
+        if all(isinstance(v, np.ndarray) for v in self):
+            return self
+        return OnnxList([v.detach().cpu().numpy() for v in self])
+
+    def to(self, tensor_like) -> "OnnxList":
+        "Creates a new list with all tensors on numpy or pytorch depending on `tensor_like`."
+        if isinstance(tensor_like, np.ndarray):
+            return self
+        import torch
+
+        return OnnxList(
+            [
+                torch.from_numpy(t).to(tensor_like.device) if isinstance(t, np.ndarray) else t
+                for t in self
+            ]
+        )
+
+    def clone(self) -> "OnnxList":
+        "Clone (torch)."
+        return OnnxList([t.clone() for t in self]) if len(self) > 0 else OnnxList(self.itype)
+
+
 class OnnxruntimeEvaluator:
     """
     This class loads an onnx model and the executes one by one the nodes
```
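The `OnnxList` class added above is how the evaluator now represents ONNX sequence values between nodes. A minimal usage sketch of the two construction modes it supports (an ONNX element type for an empty sequence, or a non-empty list of tensors); the tensor values are invented for the example:

```python
import torch
from onnx import TensorProto
from onnx_diagnostic.reference.ort_evaluator import OnnxList

empty = OnnxList(itype=TensorProto.FLOAT)           # empty sequence, element type FLOAT
filled = OnnxList([torch.ones(2), torch.zeros(2)])  # itype inferred from the first tensor
as_np = filled.numpy()                              # new OnnxList of numpy arrays
back = as_np.to(torch.empty(0))                     # back to torch, on the argument's device
```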
```diff
@@ -54,6 +109,9 @@ class OnnxruntimeEvaluator:
     :param whole: if True, do not split node by node
     :param torch_or_numpy: force the use of one of them, True for torch,
         False for numpy, None to let the class choose
+    :param dump_onnx_model: dumps the temporary onnx model created if whole is True
+    :param function_kwargs: a FunctionProto may have parameters,
+        this contains their values
     """

     def __init__(
```
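A hedged sketch of how the two new constructor arguments documented above would be passed; `model.onnx`, the dump path, and the attribute values are placeholders, and every other parameter keeps its default:

```python
from onnx_diagnostic.reference.ort_evaluator import OnnxruntimeEvaluator

ev = OnnxruntimeEvaluator(
    "model.onnx",                       # placeholder path to an onnx model
    whole=True,                         # dump_onnx_model requires whole=True (asserted in a later hunk)
    dump_onnx_model="debug_dump.onnx",  # where the temporary model is written
    function_kwargs={"alpha": 0.5},     # hypothetical attribute values for a FunctionProto
)
```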
```diff
@@ -77,6 +135,8 @@ class OnnxruntimeEvaluator:
         opsets: Optional[Union[int, Dict[str, int]]] = None,
         whole: bool = False,
         torch_or_numpy: Optional[bool] = None,
+        function_kwargs: Optional[Dict[str, Any]] = None,
+        dump_onnx_model: Optional[str] = None,
     ):
         if isinstance(proto, str):
             self.proto: Proto = load(proto)
@@ -90,6 +150,9 @@ class OnnxruntimeEvaluator:
         assert isinstance(
             self.proto, PROTO
         ), f"Unexpected type for self.proto {type(self.proto)}"
+        assert (
+            whole or not dump_onnx_model
+        ), f"whole must be True for dump_onnx_model={dump_onnx_model!r}"

         self._cache: Dict[
             Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
@@ -109,6 +172,8 @@ class OnnxruntimeEvaluator:
             use_training_api=use_training_api,
         )
         self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
+        self.function_kwargs = function_kwargs
+        self.dump_onnx_model = dump_onnx_model

         self.verbose = verbose
         self.torch_or_numpy = torch_or_numpy
@@ -199,6 +264,8 @@ class OnnxruntimeEvaluator:
     def _log_arg(self, a: Any) -> Any:
         if isinstance(a, (str, int, float)):
             return a
+        if isinstance(a, OnnxList):
+            return string_type(a)
         device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
         if hasattr(a, "shape"):
             prefix = "A:" if hasattr(a, "astype") else "T:"
@@ -221,6 +288,12 @@ class OnnxruntimeEvaluator:
     def _is_local_function(self, node: NodeProto) -> bool:
         return (node.domain, node.op_type) in self.local_functions

+    def _run_init(self, feed_inputs):
+        if self.sess_ is None:
+            assert self.proto, "self.proto is empty"
+            _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+        return self.sess_
+
     def run(
         self,
         outputs: Optional[List[str]],
@@ -244,9 +317,7 @@ class OnnxruntimeEvaluator:
         """
         if self.rt_nodes_ is None:
             # runs a whole
-
-            assert self.proto, "self.proto is empty"
-            _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+            self._run_init(feed_inputs)
             assert self.sess_, "mypy not happy"
             return self.sess_.run(outputs, feed_inputs)
         if outputs is None:
@@ -273,14 +344,16 @@ class OnnxruntimeEvaluator:
             if node.op_type == "If" and node.domain == "":
                 outputs = self._run_if(node, inputs, results)
             elif node.op_type in {"Scan", "Loop"} and node.domain == "":
-                outputs = self.
+                outputs = self._run_scan_or_loop(node, inputs, results)
             elif self._is_local_function(node):
                 outputs = self._run_local(node, inputs, results)
             else:
                 outputs = self._run(node, inputs, results)
-            for
-
-
+            node_output = [o for o in node.output if o]
+            assert len(node_output) == len(
+                outputs
+            ), f"Length mismatch between node output={node.output} and outputs={outputs}"
+            for name, value in zip(node_output, outputs):
                 self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                 assert isinstance(name, str), f"unexpected type for name {type(name)}"
                 results[name] = value
@@ -355,11 +428,12 @@ class OnnxruntimeEvaluator:
         nodes: Sequence[NodeProto],
         vinputs: Sequence[ValueInfoProto],
         voutputs: Sequence[ValueInfoProto],
+        functions: Optional[Sequence[FunctionProto]] = None,
     ) -> ModelProto:
         onx = oh.make_model(
             oh.make_graph(nodes, "-", vinputs, voutputs),
             ir_version=getattr(self.proto, "ir_version", self.ir_version),
-            functions=getattr(self.proto, "functions",
+            functions=[*getattr(self.proto, "functions", []), *(functions or [])],
         )
         del onx.opset_import[:]
         if hasattr(self.proto, "opset_import"):
@@ -373,51 +447,85 @@ class OnnxruntimeEvaluator:
             )
         else:
             onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
+        opsets = {d.domain: d.version for d in onx.opset_import}
+        add = {}
+        for node in self.enumerate_nodes(onx.graph.node):
+            if node.domain and node.domain not in opsets and node.domain not in add:
+                add[node.domain] = 1
+        onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])

         # That helps fixing bugs.
         onx = shi.infer_shapes(onx)
         return onx

+    def _make_model_outputs(
+        self, node: NodeProto, inputs: List[ValueInfoProto]
+    ) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
+        return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
+
+    def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
+        "Enumerates nodes recursively."
+        for node in nodes:
+            if node.op_type in {"Scan", "If", "Loop"}:
+                for att in node.attribute:
+                    if att.type == AttributeProto.GRAPH:
+                        yield from self.enumerate_nodes(att.g.node)
+            yield node
+
     @classmethod
-    def _get_hidden_inputs(
+    def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
         """
         Returns the hidden inputs (inputs coming from an upper context)
         used by a subgraph.
         """
         hidden = set()
-        memo =
-
+        memo = (
+            {i.name for i in graph.initializer}
+            | {i.name for i in graph.sparse_initializer}
+            | {i.name for i in graph.input}
+        )
         for node in graph.node:
             for i in node.input:
                 if i not in memo:
                     hidden.add(i)
             for att in node.attribute:
                 if att.type == AttributeProto.GRAPH and att.g:
-                    hid =
+                    hid = cls._get_hidden_inputs(att.g)
                     less = set(h for h in hid if h not in memo)
                     hidden |= less
             memo |= set(node.output)
         return hidden

     @classmethod
-    def _get_hidden_node_inputs(
+    def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
         """Calls multiple _get_hidden_inputs on every attribute."""
         if node.op_type not in {"Loop", "Scan", "If"}:
             return set()
         hidden = set()
         for att in node.attribute:
             if att.type == AttributeProto.GRAPH:
-                hidden |=
+                hidden |= cls._get_hidden_inputs(att.g)
         return hidden - (hidden & set(node.input))

     def _get_sess(
         self, node: Union[ModelProto, NodeProto], inputs: List[Any]
     ) -> Tuple[ModelProto, _InferenceSession]:
+        on_cpu = None
         if isinstance(node, ModelProto):
             onx = node
         else:
+            functions = []
+            if isinstance(node, FunctionProto):
+                functions.append(node)
+                node = oh.make_node(
+                    node.name,
+                    list(node.input),
+                    list(node.output),
+                    domain=node.domain,
+                    **(self.function_kwargs or {}),
+                )
             assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
-            if node.op_type == "Constant":
+            if node.op_type == "Constant" and node.domain == "":
                 # We force the type to be a boolean.
                 ref = ExtendedReferenceEvaluator(node)
                 cst = ref.run(None, {})[0]
```
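To make `_get_hidden_inputs` concrete: a subgraph attached to an `If`, `Loop`, or `Scan` node may read names produced in the enclosing graph. A self-contained illustration using only the onnx helper API (the tiny graph is invented for the example):

```python
import onnx.helper as oh
from onnx import TensorProto

# "X" is read by the subgraph but is neither one of its inputs, nor an
# initializer, nor produced by one of its own nodes: it is a hidden input.
then_branch = oh.make_graph(
    [oh.make_node("Add", ["X", "one"], ["Z"])],
    "then",
    [],
    [oh.make_tensor_value_info("Z", TensorProto.FLOAT, None)],
    initializer=[oh.make_tensor("one", TensorProto.FLOAT, [1], [1.0])],
)
# OnnxruntimeEvaluator._get_hidden_inputs(then_branch) returns {"X"}.
```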
```diff
@@ -427,6 +535,19 @@ class OnnxruntimeEvaluator:
                         node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
                     )
                 ]
+                prenodes = []  # type: ignore[var-annotated]
+            elif node.op_type == "ConcatFromSequence" and node.domain == "":
+                # Declare the input as a tensor sequence, the output as a tensor of the same type.
+                vinputs = [
+                    oh.make_value_info(
+                        node.input[0],
+                        type_proto=oh.make_sequence_type_proto(
+                            oh.make_tensor_type_proto(elem_type=inputs[0].itype, shape=None)
+                        ),
+                    )
+                ]
+                voutputs = [oh.make_tensor_value_info(node.output[0], inputs[0].itype, None)]
+                prenodes = []  # type: ignore[var-annotated]
             else:
                 unique_names = set()
                 vinputs = []
@@ -440,18 +561,35 @@ class OnnxruntimeEvaluator:
                     vinputs.append(value)

                 # no need to run shape inference
-                voutputs =
+                prenodes, voutputs = self._make_model_outputs(node, vinputs)

-            onx = self._make_model_proto(
+            onx = self._make_model_proto(
+                [*prenodes, node], vinputs, voutputs, functions=functions
+            )
+            if node.op_type in {"Shape", "Size"}:
+                on_cpu = True

+        if self.dump_onnx_model:
+            onnx_save(
+                onx, self.dump_onnx_model, save_as_external_data=len(onx.graph.node) > 100
+            )
         cls = (
             InferenceSessionForNumpy
             if any(isinstance(i, np.ndarray) for i in inputs)
             and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
             else InferenceSessionForTorch
         )
+        if (
+            "providers" not in self.session_kwargs or not self.session_kwargs["providers"]
+        ) and any(hasattr(t, "is_cuda") and t.is_cuda for t in inputs):
+            sess_kwargs = self.session_kwargs.copy()
+            sess_kwargs["providers"] = ["CUDAExecutionProvider"]
+        else:
+            sess_kwargs = self.session_kwargs or {}
+        if on_cpu and "CUDAExecutionProvider" in (sess_kwargs.get("providers", []) or []):
+            sess_kwargs["cpu_outputs"] = True
         try:
-            sess = cls(onx, **
+            sess = cls(onx, **sess_kwargs)
         except (
             onnxruntime.capi.onnxruntime_pybind11_state.Fail,
             onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
@@ -473,7 +611,17 @@ class OnnxruntimeEvaluator:
                 if i == "" or i in unique_names:
                     continue
                 unique_names.add(i)
-
+                if isinstance(it, OnnxList):
+                    value = oh.make_value_info(
+                        i,
+                        type_proto=oh.make_sequence_type_proto(
+                            oh.make_tensor_type_proto(
+                                elem_type=dtype_to_tensor_dtype(it.dtype), shape=None
+                            )
+                        ),
+                    )
+                else:
+                    value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
                 vinputs.append(value)

             reduced_set = self._get_hidden_inputs(g)
@@ -482,6 +630,10 @@ class OnnxruntimeEvaluator:
                 unique_names.add(i)
                 value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
                 vinputs.append(value)
+        assert len(reduced_set & set(context)) == len(reduced_set), (
+            f"Missing hidden inputs {sorted(reduced_set)} from context={sorted(context)} "
+            f"(len(inputs)={len([i for i in inputs if i])}) for node {pretty_onnx(node)}"
+        )
         return vinputs

     def _get_sess_if(
@@ -530,6 +682,14 @@ class OnnxruntimeEvaluator:

     def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
         """Runs a node."""
+        if node.op_type[0] == "S":
+            if node.op_type == "SequenceEmpty":
+                dtype = TensorProto.FLOAT
+                for att in node.attribute:
+                    if att.name == "dtype":
+                        dtype = att.i
+                return [OnnxList(itype=dtype)]
+
         types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
         key = (id(node), *types)
         if key in self._cache:
```
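The `SequenceEmpty` short-circuit above never reaches onnxruntime: the evaluator answers directly with an empty `OnnxList` carrying the element type from the node's `dtype` attribute (FLOAT when the attribute is absent). A small sketch of the resulting value:

```python
from onnx import TensorProto
from onnx_diagnostic.reference.ort_evaluator import OnnxList

seq = OnnxList(itype=TensorProto.INT64)  # what _run returns for SequenceEmpty
assert len(seq) == 0 and seq.itype == TensorProto.INT64
```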
```diff
@@ -538,13 +698,31 @@ class OnnxruntimeEvaluator:
             onx, sess = self._get_sess(node, inputs)
             self._cache[key] = onx, sess

-        feeds =
-
-
-
+        feeds = {}
+        for i, val in zip(node.input, inputs):
+            if i == "":
+                assert (
+                    val is None
+                ), f"input name={i!r} but val={string_type(val, with_shape=True)}"
+                continue
+            feeds[i] = val
         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+
+        if node.op_type[0] == "C":
+            if node.op_type == "ConcatFromSequence":
+                res = sess.sess.run(None, self.feeds_to_numpy(feeds))  # type: ignore[union-attr]
+                if isinstance(inputs[0][0], np.ndarray):
+                    return list(res)
+                import torch
+
+                return [torch.from_numpy(r).to(inputs[0][0].device) for r in res]
+
         outputs = list(sess.run(None, feeds))
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+        assert not any(type(v) is list for v in outputs), (
+            f"One output type is a list, this should not be allowed, "
+            f"node.op_type={node.op_type}, feeds={string_type(feeds, with_shape=True)}"
+        )
         return outputs

     def _run_if(
@@ -570,7 +748,7 @@ class OnnxruntimeEvaluator:
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
         return outputs

-    def
+    def _get_sess_scan_or_loop(
         self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
     ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
         g = None
@@ -605,10 +783,26 @@ class OnnxruntimeEvaluator:
         )
         return onx, sess

-    def
+    def feeds_to_numpy(self, feeds):
+        new_feeds = {}
+        for k, v in feeds.items():
+            if hasattr(v, "detach"):
+                new_feeds[k] = v.detach().cpu().numpy()
+            elif isinstance(v, OnnxList):
+                new_feeds[k] = v.numpy()
+            else:
+                new_feeds[k] = v
+        return new_feeds
+
+    def _run_scan_or_loop(
         self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
     ) -> List[Any]:
         """Runs a node Scan."""
+        assert not any(type(i) is list for i in inputs), (
+            f"One input is a list but it should be an OnnxList, "
+            f"node.op_type={node.op_type!r}, node.input={node.input}, "
+            f"inputs={string_type(inputs, with_shape=True)}"
+        )
         feeds = dict(zip(node.input, inputs))
         feeds.update(results)
         name = "body"
@@ -616,10 +810,21 @@ class OnnxruntimeEvaluator:
         if key in self._cache:
             sess = self._cache[key][1]
         else:
-            self._cache[key] = _onx, sess = self.
+            self._cache[key] = _onx, sess = self._get_sess_scan_or_loop(
+                node, name, inputs, results
+            )

         assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
         feeds = {name: results[name] for name in sess.input_names}
+        if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()):
+            # This operator uses sequences; onnxruntime does not play well with sequences.
+            sess._run_init(feeds)  # type: ignore[union-attr]
+            outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds))  # type: ignore[union-attr]
+            return [
+                (OnnxList(v).to(feeds[node.input[0]]) if isinstance(v, list) else v)
+                for v in outputs
+            ]
+
         outputs = sess.run(None, feeds)
         assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
         return outputs
```
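The `Loop` special case exists because, as the code above assumes, onnxruntime hands ONNX sequence outputs back as plain Python lists of numpy arrays. A sketch of the final wrapping step under that assumption (the two arrays stand in for a sequence produced by the loop body):

```python
import numpy as np
import torch
from onnx_diagnostic.reference.ort_evaluator import OnnxList

raw = [np.zeros((2, 3), dtype=np.float32), np.ones((2, 3), dtype=np.float32)]
# Mirrors the return statement above: list outputs are wrapped into an OnnxList
# and moved next to a reference tensor (here a CPU torch tensor).
wrapped = OnnxList(raw).to(torch.empty(0))
assert all(isinstance(t, torch.Tensor) for t in wrapped)
```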
```diff
--- a/onnx_diagnostic/tasks/feature_extraction.py
+++ b/onnx_diagnostic/tasks/feature_extraction.py
@@ -1,10 +1,6 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import (
-    update_config,
-    check_hasattr,
-    default_num_hidden_layers as nhl,
-)
+from ..helpers.config_helper import update_config, check_hasattr
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache


@@ -13,8 +9,9 @@ __TASK__ = "feature-extraction"

 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
-    check_hasattr(config, "
-
+    check_hasattr(config, "vocab_size")
+    # The Bart architecture does not cope well with changing the number of layers.
+    kwargs = dict(vocab_size=2056)
     update_config(config, kwargs)
     return kwargs

```
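A hedged sketch of `reduce_model_config` in use; `BartConfig` is picked only because the new comment mentions Bart (any config exposing `vocab_size` would do), and the sketch assumes `update_config` applies the values in place:

```python
from transformers import BartConfig
from onnx_diagnostic.tasks.feature_extraction import reduce_model_config

config = BartConfig()
changed = reduce_model_config(config)  # returns {"vocab_size": 2056}
assert config.vocab_size == 2056       # assuming update_config mutates config
```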
```diff
@@ -25,7 +22,8 @@ def get_inputs(
     batch_size: int,
     sequence_length: int,
     dummy_max_token_id: int,
-
+    past_length: int = 30,
+    past_length2: int = 4,
     decoder_attention_heads: Optional[int] = None,
     encoder_attention_heads: Optional[int] = None,
     encoder_ffn_dim: Optional[int] = None,
@@ -73,13 +71,13 @@ def get_inputs(
             torch.randn(
                 batch_size,
                 encoder_attention_heads,
-
+                past_length,
                 encoder_ffn_dim,
             ),
             torch.randn(
                 batch_size,
                 encoder_attention_heads,
-
+                past_length,
                 encoder_ffn_dim,
             ),
         )
@@ -92,13 +90,13 @@ def get_inputs(
             torch.randn(
                 batch_size,
                 decoder_attention_heads,
-
+                past_length2,
                 decoder_ffn_dim,
             ),
             torch.randn(
                 batch_size,
                 decoder_attention_heads,
-
+                past_length2,
                 decoder_ffn_dim,
             ),
         )
```
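The hunks above thread the new `past_length` (encoder side) and `past_length2` (decoder side) parameters into the random past key/value tensors. Re-created standalone for clarity, with invented batch, head, and feed-forward sizes:

```python
import torch

batch_size, heads, ffn_dim = 2, 4, 64
past_length, past_length2 = 30, 4  # the new defaults
encoder_past = torch.randn(batch_size, heads, past_length, ffn_dim)
decoder_past = torch.randn(batch_size, heads, past_length2, ffn_dim)
```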
```diff
@@ -124,7 +122,8 @@ def get_inputs(
         batch_size=batch_size + 1,
         sequence_length=sequence_length + add_second_input,
         dummy_max_token_id=dummy_max_token_id,
-
+        past_length=past_length,
+        past_length2=past_length2,
         decoder_attention_heads=decoder_attention_heads,
         encoder_attention_heads=encoder_attention_heads,
         encoder_ffn_dim=encoder_ffn_dim,
@@ -146,7 +145,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     check_hasattr(config, "vocab_size")
     kwargs = dict(
         batch_size=2,
-        sequence_length=
+        sequence_length=12,
+        past_length=30,
+        past_length2=4,
         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
     )
     for att in [
```