onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
- from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
  import numpy as np
  from onnx import (
      AttributeProto,
@@ -6,6 +6,7 @@ from onnx import (
      FunctionProto,
      ModelProto,
      NodeProto,
+     TensorProto,
      TypeProto,
      ValueInfoProto,
      helper as oh,
@@ -16,7 +17,13 @@ from onnx import (
  from onnx.defs import onnx_opset_version
  import onnxruntime
  from ..helpers import string_type
- from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
+ from ..helpers.onnx_helper import (
+     pretty_onnx,
+     dtype_to_tensor_dtype,
+     to_array_extended,
+     np_dtype_to_tensor_dtype,
+ )
+ from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
  from ..helpers.ort_session import (
      InferenceSessionForTorch,
      InferenceSessionForNumpy,
@@ -31,6 +38,54 @@ PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
  Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]


+ class OnnxList(list):
+     """Defines a list for the runtime."""
+
+     def __init__(self, itype: Union[list, int]):
+         super().__init__()
+         if isinstance(itype, int):
+             self.itype = itype
+             self.dtype = onnx_dtype_to_torch_dtype(itype)
+         else:
+             assert itype, "The list cannot be created with an empty list."
+             self.itype = (
+                 np_dtype_to_tensor_dtype(itype[0].dtype)
+                 if isinstance(itype[0], np.ndarray)
+                 else torch_dtype_to_onnx_dtype(itype[0].dtype)
+             )
+             self.extend(itype)
+             self.dtype = itype[0].dtype
+         self.shape = "OnnxList"
+
+     def get_device(self):
+         "Returns the device of the first tensor."
+         assert len(self) > 0, "Cannot access the device for an empty list."
+         return self[0].get_device() if hasattr(self[0], "get_device") else -1
+
+     def numpy(self):
+         "Creates a new list with all tensors as numpy arrays, or self if it is already the case."
+         if all(isinstance(v, np.ndarray) for v in self):
+             return self
+         return OnnxList([v.detach().cpu().numpy() for v in self])
+
+     def to(self, tensor_like) -> "OnnxList":
+         "Creates a new list with all tensors as numpy or pytorch depending on `tensor_like`."
+         if isinstance(tensor_like, np.ndarray):
+             return self
+         import torch
+
+         return OnnxList(
+             [
+                 torch.from_numpy(t).to(tensor_like.device) if isinstance(t, np.ndarray) else t
+                 for t in self
+             ]
+         )
+
+     def clone(self) -> "OnnxList":
+         "Clone (torch)."
+         return OnnxList([t.clone() for t in self]) if len(self) > 0 else OnnxList(self.itype)
+
+
  class OnnxruntimeEvaluator:
      """
      This class loads an onnx model and then executes the nodes one by one
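Note: a minimal usage sketch of the new OnnxList class (not part of the diff), assuming it is importable from onnx_diagnostic.reference.ort_evaluator where the hunk above defines it:

    import numpy as np
    from onnx import TensorProto
    from onnx_diagnostic.reference.ort_evaluator import OnnxList

    empty = OnnxList(itype=TensorProto.FLOAT)                # empty sequence, float32 elements
    filled = OnnxList([np.zeros((2, 3), dtype=np.float32)])  # element type inferred from the array
    assert filled.itype == TensorProto.FLOAT
    assert filled.numpy() is filled                          # already numpy: returned as-is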
@@ -54,6 +109,9 @@ class OnnxruntimeEvaluator:
      :param whole: if True, do not split node by node
      :param torch_or_numpy: force the use of one of them, True for torch,
          False for numpy, None to let the class choose
+     :param dump_onnx_model: dumps the temporary onnx model created when whole is True
+     :param function_kwargs: a FunctionProto may have attribute parameters;
+         this dictionary holds their values
      """

      def __init__(
@@ -77,6 +135,8 @@ class OnnxruntimeEvaluator:
          opsets: Optional[Union[int, Dict[str, int]]] = None,
          whole: bool = False,
          torch_or_numpy: Optional[bool] = None,
+         function_kwargs: Optional[Dict[str, Any]] = None,
+         dump_onnx_model: Optional[str] = None,
      ):
          if isinstance(proto, str):
              self.proto: Proto = load(proto)
@@ -90,6 +150,9 @@ class OnnxruntimeEvaluator:
          assert isinstance(
              self.proto, PROTO
          ), f"Unexpected type for self.proto {type(self.proto)}"
+         assert (
+             whole or not dump_onnx_model
+         ), f"whole must be True for dump_onnx_model={dump_onnx_model!r}"

          self._cache: Dict[
              Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]]  # noqa: UP037
@@ -109,6 +172,8 @@ class OnnxruntimeEvaluator:
              use_training_api=use_training_api,
          )
          self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
+         self.function_kwargs = function_kwargs
+         self.dump_onnx_model = dump_onnx_model

          self.verbose = verbose
          self.torch_or_numpy = torch_or_numpy
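Note: a hedged sketch of how the two new constructor options above are meant to be used; the model path and the "alpha" function parameter are placeholders, not taken from the diff:

    from onnx_diagnostic.reference.ort_evaluator import OnnxruntimeEvaluator

    ev = OnnxruntimeEvaluator(
        "model.onnx",                         # placeholder path
        whole=True,                           # required: dump_onnx_model asserts whole=True
        dump_onnx_model="dumped_model.onnx",  # keeps the temporary model for inspection
        function_kwargs={"alpha": 0.5},       # attribute values for a FunctionProto
    )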
@@ -199,6 +264,8 @@ class OnnxruntimeEvaluator:
      def _log_arg(self, a: Any) -> Any:
          if isinstance(a, (str, int, float)):
              return a
+         if isinstance(a, OnnxList):
+             return string_type(a)
          device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
          if hasattr(a, "shape"):
              prefix = "A:" if hasattr(a, "astype") else "T:"
@@ -221,6 +288,12 @@ class OnnxruntimeEvaluator:
      def _is_local_function(self, node: NodeProto) -> bool:
          return (node.domain, node.op_type) in self.local_functions

+     def _run_init(self, feed_inputs):
+         if self.sess_ is None:
+             assert self.proto, "self.proto is empty"
+             _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+         return self.sess_
+
      def run(
          self,
          outputs: Optional[List[str]],
@@ -244,9 +317,7 @@ class OnnxruntimeEvaluator:
          """
          if self.rt_nodes_ is None:
              # runs the model as a whole
-             if self.sess_ is None:
-                 assert self.proto, "self.proto is empty"
-                 _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
+             self._run_init(feed_inputs)
              assert self.sess_, "mypy not happy"
              return self.sess_.run(outputs, feed_inputs)
          if outputs is None:
@@ -273,14 +344,16 @@ class OnnxruntimeEvaluator:
              if node.op_type == "If" and node.domain == "":
                  outputs = self._run_if(node, inputs, results)
              elif node.op_type in {"Scan", "Loop"} and node.domain == "":
-                 outputs = self._run_scan(node, inputs, results)
+                 outputs = self._run_scan_or_loop(node, inputs, results)
              elif self._is_local_function(node):
                  outputs = self._run_local(node, inputs, results)
              else:
                  outputs = self._run(node, inputs, results)
-             for name, value in zip(node.output, outputs):
-                 if name == "":
-                     continue
+             node_output = [o for o in node.output if o]
+             assert len(node_output) == len(
+                 outputs
+             ), f"Length mismatch between node output={node.output} and outputs={outputs}"
+             for name, value in zip(node_output, outputs):
                  self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                  assert isinstance(name, str), f"unexpected type for name {type(name)}"
                  results[name] = value
@@ -355,11 +428,12 @@ class OnnxruntimeEvaluator:
          nodes: Sequence[NodeProto],
          vinputs: Sequence[ValueInfoProto],
          voutputs: Sequence[ValueInfoProto],
+         functions: Optional[Sequence[FunctionProto]] = None,
      ) -> ModelProto:
          onx = oh.make_model(
              oh.make_graph(nodes, "-", vinputs, voutputs),
              ir_version=getattr(self.proto, "ir_version", self.ir_version),
-             functions=getattr(self.proto, "functions", None),
+             functions=[*getattr(self.proto, "functions", []), *(functions or [])],
          )
          del onx.opset_import[:]
          if hasattr(self.proto, "opset_import"):
@@ -373,51 +447,85 @@ class OnnxruntimeEvaluator:
              )
          else:
              onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
+         opsets = {d.domain: d.version for d in onx.opset_import}
+         add = {}
+         for node in self.enumerate_nodes(onx.graph.node):
+             if node.domain and node.domain not in opsets and node.domain not in add:
+                 add[node.domain] = 1
+         onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])

          # That helps fix bugs.
          onx = shi.infer_shapes(onx)
          return onx

+     def _make_model_outputs(
+         self, node: NodeProto, inputs: List[ValueInfoProto]
+     ) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
+         return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
+
+     def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
+         "Enumerates nodes recursively."
+         for node in nodes:
+             if node.op_type in {"Scan", "If", "Loop"}:
+                 for att in node.attribute:
+                     if att.type == AttributeProto.GRAPH:
+                         yield from self.enumerate_nodes(att.g.node)
+             yield node
+
      @classmethod
-     def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
+     def _get_hidden_inputs(cls, graph: GraphProto) -> Set[str]:
          """
          Returns the hidden inputs (inputs coming from an upper context)
          used by a subgraph.
          """
          hidden = set()
-         memo = set(i.name for i in graph.initializer)
-         memo |= set(i.name for i in graph.sparse_initializer)
+         memo = (
+             {i.name for i in graph.initializer}
+             | {i.name for i in graph.sparse_initializer}
+             | {i.name for i in graph.input}
+         )
          for node in graph.node:
              for i in node.input:
                  if i not in memo:
                      hidden.add(i)
              for att in node.attribute:
                  if att.type == AttributeProto.GRAPH and att.g:
-                     hid = self._get_hidden_inputs(att.g)
+                     hid = cls._get_hidden_inputs(att.g)
                      less = set(h for h in hid if h not in memo)
                      hidden |= less
              memo |= set(node.output)
          return hidden

      @classmethod
-     def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
+     def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
          """Calls _get_hidden_inputs on every graph attribute."""
          if node.op_type not in {"Loop", "Scan", "If"}:
              return set()
          hidden = set()
          for att in node.attribute:
              if att.type == AttributeProto.GRAPH:
-                 hidden |= self._get_hidden_inputs(att.g)
+                 hidden |= cls._get_hidden_inputs(att.g)
          return hidden - (hidden & set(node.input))

      def _get_sess(
          self, node: Union[ModelProto, NodeProto], inputs: List[Any]
      ) -> Tuple[ModelProto, _InferenceSession]:
+         on_cpu = None
          if isinstance(node, ModelProto):
              onx = node
          else:
+             functions = []
+             if isinstance(node, FunctionProto):
+                 functions.append(node)
+                 node = oh.make_node(
+                     node.name,
+                     list(node.input),
+                     list(node.output),
+                     domain=node.domain,
+                     **(self.function_kwargs or {}),
+                 )
              assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
-             if node.op_type == "Constant":
+             if node.op_type == "Constant" and node.domain == "":
                  # We force the type to be a boolean.
                  ref = ExtendedReferenceEvaluator(node)
                  cst = ref.run(None, {})[0]
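Note: an illustrative sketch (not from the diff) of what _get_hidden_inputs computes; "outer" is consumed by the subgraph but produced in the enclosing graph, so it counts as a hidden input. The private classmethod is called directly for illustration only:

    import onnx.helper as oh
    from onnx import TensorProto
    from onnx_diagnostic.reference.ort_evaluator import OnnxruntimeEvaluator

    then_branch = oh.make_graph(
        [oh.make_node("Add", ["outer", "outer"], ["z"])],  # "outer" is never declared here
        "then_branch",
        [],  # no declared graph inputs
        [oh.make_tensor_value_info("z", TensorProto.FLOAT, None)],
    )
    assert OnnxruntimeEvaluator._get_hidden_inputs(then_branch) == {"outer"}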
@@ -427,6 +535,19 @@ class OnnxruntimeEvaluator:
                      node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
                  )
              ]
+             prenodes = []  # type: ignore[var-annotated]
+         elif node.op_type == "ConcatFromSequence" and node.domain == "":
+             # We force the input type to be a sequence of tensors.
+             vinputs = [
+                 oh.make_value_info(
+                     node.input[0],
+                     type_proto=oh.make_sequence_type_proto(
+                         oh.make_tensor_type_proto(elem_type=inputs[0].itype, shape=None)
+                     ),
+                 )
+             ]
+             voutputs = [oh.make_tensor_value_info(node.output[0], inputs[0].itype, None)]
+             prenodes = []  # type: ignore[var-annotated]
          else:
              unique_names = set()
              vinputs = []
@@ -440,18 +561,35 @@ class OnnxruntimeEvaluator:
                      vinputs.append(value)

              # no need to run shape inference
-             voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
+             prenodes, voutputs = self._make_model_outputs(node, vinputs)

-             onx = self._make_model_proto([node], vinputs, voutputs)
+             onx = self._make_model_proto(
+                 [*prenodes, node], vinputs, voutputs, functions=functions
+             )
+             if node.op_type in {"Shape", "Size"}:
+                 on_cpu = True

+         if self.dump_onnx_model:
+             onnx_save(
+                 onx, self.dump_onnx_model, save_as_external_data=len(onx.graph.node) > 100
+             )
          cls = (
              InferenceSessionForNumpy
              if any(isinstance(i, np.ndarray) for i in inputs)
              and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
              else InferenceSessionForTorch
          )
+         if (
+             "providers" not in self.session_kwargs or not self.session_kwargs["providers"]
+         ) and any(hasattr(t, "is_cuda") and t.is_cuda for t in inputs):
+             sess_kwargs = self.session_kwargs.copy()
+             sess_kwargs["providers"] = ["CUDAExecutionProvider"]
+         else:
+             sess_kwargs = self.session_kwargs or {}
+         if on_cpu and "CUDAExecutionProvider" in (sess_kwargs.get("providers", []) or []):
+             sess_kwargs["cpu_outputs"] = True
          try:
-             sess = cls(onx, **self.session_kwargs)
+             sess = cls(onx, **sess_kwargs)
          except (
              onnxruntime.capi.onnxruntime_pybind11_state.Fail,
              onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
@@ -473,7 +611,17 @@ class OnnxruntimeEvaluator:
              if i == "" or i in unique_names:
                  continue
              unique_names.add(i)
-             value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
+             if isinstance(it, OnnxList):
+                 value = oh.make_value_info(
+                     i,
+                     type_proto=oh.make_sequence_type_proto(
+                         oh.make_tensor_type_proto(
+                             elem_type=dtype_to_tensor_dtype(it.dtype), shape=None
+                         )
+                     ),
+                 )
+             else:
+                 value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
              vinputs.append(value)

          reduced_set = self._get_hidden_inputs(g)
@@ -482,6 +630,10 @@ class OnnxruntimeEvaluator:
              unique_names.add(i)
              value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
              vinputs.append(value)
+         assert len(reduced_set & set(context)) == len(reduced_set), (
+             f"Missing hidden inputs {sorted(reduced_set)} from context={sorted(context)} "
+             f"(len(inputs)={len([i for i in inputs if i])}) for node {pretty_onnx(node)}"
+         )
          return vinputs

      def _get_sess_if(
@@ -530,6 +682,14 @@ class OnnxruntimeEvaluator:

      def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
          """Runs a node."""
+         if node.op_type[0] == "S":
+             if node.op_type == "SequenceEmpty":
+                 dtype = TensorProto.FLOAT
+                 for att in node.attribute:
+                     if att.name == "dtype":
+                         dtype = att.i
+                 return [OnnxList(itype=dtype)]
+
          types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
          key = (id(node), *types)
          if key in self._cache:
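Note: a sketch of the SequenceEmpty shortcut above: no onnxruntime session is built, the node is answered directly with an empty OnnxList carrying the requested element type ("ev" stands for any OnnxruntimeEvaluator instance, an assumption for illustration):

    import onnx.helper as oh
    from onnx import TensorProto

    node = oh.make_node("SequenceEmpty", [], ["seq"], dtype=TensorProto.INT64)
    # ev._run(node, [], {}) -> [OnnxList(itype=TensorProto.INT64)]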
@@ -538,13 +698,31 @@ class OnnxruntimeEvaluator:
              onx, sess = self._get_sess(node, inputs)
              self._cache[key] = onx, sess

-         feeds = dict(zip(node.input, inputs))
-         if "" in feeds:
-             feeds[""] = np.array([0], dtype=np.float32)
-
+         feeds = {}
+         for i, val in zip(node.input, inputs):
+             if i == "":
+                 assert (
+                     val is None
+                 ), f"input name={i!r} but val={string_type(val, with_shape=True)}"
+                 continue
+             feeds[i] = val
          assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
+
+         if node.op_type[0] == "C":
+             if node.op_type == "ConcatFromSequence":
+                 res = sess.sess.run(None, self.feeds_to_numpy(feeds))  # type: ignore[union-attr]
+                 if isinstance(inputs[0][0], np.ndarray):
+                     return list(res)
+                 import torch
+
+                 return [torch.from_numpy(r).to(inputs[0][0].device) for r in res]
+
          outputs = list(sess.run(None, feeds))
          assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
+         assert not any(type(v) is list for v in outputs), (
+             f"One output type is a list, this should not be allowed, "
+             f"node.op_type={node.op_type}, feeds={string_type(feeds, with_shape=True)}"
+         )
          return outputs

      def _run_if(
@@ -570,7 +748,7 @@ class OnnxruntimeEvaluator:
          assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
          return outputs

-     def _get_sess_scan(
+     def _get_sess_scan_or_loop(
          self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
      ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
          g = None
@@ -605,10 +783,26 @@ class OnnxruntimeEvaluator:
          )
          return onx, sess

-     def _run_scan(
+     def feeds_to_numpy(self, feeds):
+         new_feeds = {}
+         for k, v in feeds.items():
+             if hasattr(v, "detach"):
+                 new_feeds[k] = v.detach().cpu().numpy()
+             elif isinstance(v, OnnxList):
+                 new_feeds[k] = v.numpy()
+             else:
+                 new_feeds[k] = v
+         return new_feeds
+
+     def _run_scan_or_loop(
          self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
      ) -> List[Any]:
          """Runs a Scan or Loop node."""
+         assert not any(type(i) is list for i in inputs), (
+             f"One input is a list but it should be an OnnxList, "
+             f"node.op_type={node.op_type!r}, node.input={node.input}, "
+             f"inputs={string_type(inputs, with_shape=True)}"
+         )
          feeds = dict(zip(node.input, inputs))
          feeds.update(results)
          name = "body"
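Note: a sketch of what feeds_to_numpy above does: torch tensors are detached to numpy, OnnxList values convert themselves, plain numpy arrays pass through ("ev" is a hypothetical OnnxruntimeEvaluator instance):

    import numpy as np
    import torch
    from onnx_diagnostic.reference.ort_evaluator import OnnxList

    feeds = {
        "x": torch.ones((2, 3)),                            # torch tensor -> ndarray
        "s": OnnxList([np.zeros((2,), dtype=np.float32)]),  # sequence -> OnnxList of ndarrays
    }
    # ev.feeds_to_numpy(feeds) -> {"x": ndarray(2, 3), "s": OnnxList([...])}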
@@ -616,10 +810,21 @@ class OnnxruntimeEvaluator:
          if key in self._cache:
              sess = self._cache[key][1]
          else:
-             self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
+             self._cache[key] = _onx, sess = self._get_sess_scan_or_loop(
+                 node, name, inputs, results
+             )

          assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
          feeds = {name: results[name] for name in sess.input_names}
+         if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()):
+             # This operator uses sequences; onnxruntime does not play well with them.
+             sess._run_init(feeds)  # type: ignore[union-attr]
+             outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds))  # type: ignore[union-attr]
+             return [
+                 (OnnxList(v).to(feeds[node.input[0]]) if isinstance(v, list) else v)
+                 for v in outputs
+             ]
+
          outputs = sess.run(None, feeds)
          assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
          return outputs
@@ -1,10 +1,6 @@
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
- from ..helpers.config_helper import (
-     update_config,
-     check_hasattr,
-     default_num_hidden_layers as nhl,
- )
+ from ..helpers.config_helper import update_config, check_hasattr
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache


@@ -13,8 +9,9 @@ __TASK__ = "feature-extraction"

  def reduce_model_config(config: Any) -> Dict[str, Any]:
      """Reduces a model size."""
-     check_hasattr(config, "num_hidden_layers")
-     kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
+     check_hasattr(config, "vocab_size")
+     # The Bart architecture does not react well when the number of layers is changed.
+     kwargs = dict(vocab_size=2056)
      update_config(config, kwargs)
      return kwargs
@@ -25,7 +22,8 @@ def get_inputs(
      batch_size: int,
      sequence_length: int,
      dummy_max_token_id: int,
-     sequence_length2: int = 3,
+     past_length: int = 30,
+     past_length2: int = 4,
      decoder_attention_heads: Optional[int] = None,
      encoder_attention_heads: Optional[int] = None,
      encoder_ffn_dim: Optional[int] = None,
@@ -73,13 +71,13 @@ def get_inputs(
              torch.randn(
                  batch_size,
                  encoder_attention_heads,
-                 sequence_length,
+                 past_length,
                  encoder_ffn_dim,
              ),
              torch.randn(
                  batch_size,
                  encoder_attention_heads,
-                 sequence_length,
+                 past_length,
                  encoder_ffn_dim,
              ),
          )
@@ -92,13 +90,13 @@ def get_inputs(
              torch.randn(
                  batch_size,
                  decoder_attention_heads,
-                 sequence_length2,
+                 past_length2,
                  decoder_ffn_dim,
              ),
              torch.randn(
                  batch_size,
                  decoder_attention_heads,
-                 sequence_length2,
+                 past_length2,
                  decoder_ffn_dim,
              ),
          )
@@ -124,7 +122,8 @@ def get_inputs(
          batch_size=batch_size + 1,
          sequence_length=sequence_length + add_second_input,
          dummy_max_token_id=dummy_max_token_id,
-         sequence_length2=sequence_length2,
+         past_length=past_length,
+         past_length2=past_length2,
          decoder_attention_heads=decoder_attention_heads,
          encoder_attention_heads=encoder_attention_heads,
          encoder_ffn_dim=encoder_ffn_dim,
@@ -146,7 +145,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
      check_hasattr(config, "vocab_size")
      kwargs = dict(
          batch_size=2,
-         sequence_length=30,
+         sequence_length=12,
+         past_length=30,
+         past_length2=4,
          dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
      )
      for att in [
@@ -311,7 +311,11 @@ def get_inputs_default(
          attention_mask=torch.cat(
              [
                  torch.ones((batch_size, sequence_length), dtype=torch.int64),
-                 input_ids.ne(pad_token_id).to(torch.int64),
+                 (
+                     torch.ones(input_ids.shape)
+                     if pad_token_id is None
+                     else input_ids.ne(pad_token_id)
+                 ).to(torch.int64),
              ],
              axis=-1,
          ),
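Note: a standalone sketch of the new pad_token_id handling above: when the configuration defines no pad token, the mask over input_ids falls back to all ones; shapes are illustrative only:

    import torch

    input_ids = torch.randint(0, 100, (2, 4))
    pad_token_id = None
    tail = (
        torch.ones(input_ids.shape) if pad_token_id is None else input_ids.ne(pad_token_id)
    ).to(torch.int64)
    assert tail.min() == 1  # every position attends when there is no pad token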