onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
1
+ from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
2
2
  import numpy as np
3
3
  from onnx import (
4
4
  AttributeProto,
@@ -6,6 +6,7 @@ from onnx import (
6
6
  FunctionProto,
7
7
  ModelProto,
8
8
  NodeProto,
9
+ TensorProto,
9
10
  TypeProto,
10
11
  ValueInfoProto,
11
12
  helper as oh,
@@ -16,7 +17,14 @@ from onnx import (
16
17
  from onnx.defs import onnx_opset_version
17
18
  import onnxruntime
18
19
  from ..helpers import string_type
19
- from ..helpers.onnx_helper import pretty_onnx, dtype_to_tensor_dtype, to_array_extended
20
+ from ..helpers.onnx_helper import (
21
+ get_hidden_inputs,
22
+ dtype_to_tensor_dtype,
23
+ np_dtype_to_tensor_dtype,
24
+ to_array_extended,
25
+ pretty_onnx,
26
+ )
27
+ from ..helpers.torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
20
28
  from ..helpers.ort_session import (
21
29
  InferenceSessionForTorch,
22
30
  InferenceSessionForNumpy,
@@ -31,6 +39,54 @@ PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
31
39
  Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]
32
40
 
33
41
 
42
+ class OnnxList(list):
43
+ """Defines a list for the runtime."""
44
+
45
+ def __init__(self, itype: Union[list, int]):
46
+ super().__init__()
47
+ if isinstance(itype, int):
48
+ self.itype = itype
49
+ self.dtype = onnx_dtype_to_torch_dtype(itype)
50
+ else:
51
+ assert itype, "The list cannot be created with an empty list."
52
+ self.itype = (
53
+ np_dtype_to_tensor_dtype(itype[0].dtype)
54
+ if isinstance(itype[0], np.ndarray)
55
+ else torch_dtype_to_onnx_dtype(itype[0].dtype)
56
+ )
57
+ self.extend(itype)
58
+ self.dtype = itype[0].dtype
59
+ self.shape = "OnnxList"
60
+
61
+ def get_device(self):
62
+ "Returns the device of the first tensor."
63
+ assert len(self) > 0, "Cannot access the device for an empty list."
64
+ return self[0].get_device() if hasattr(self[0], "get_device") else -1
65
+
66
+ def numpy(self):
67
+ "Creates a new list with all tensors on numpy or self it is already the case."
68
+ if all(isinstance(v, np.ndarray) for v in self):
69
+ return self
70
+ return OnnxList([v.detach().cpu().numpy() for v in self])
71
+
72
+ def to(self, tensor_like) -> "OnnxList":
73
+ "Creates a new list with all tensors on numpy or pytorch depending on `tensor_like`."
74
+ if isinstance(tensor_like, np.ndarray):
75
+ return self
76
+ import torch
77
+
78
+ return OnnxList(
79
+ [
80
+ torch.from_numpy(t).to(tensor_like.device) if isinstance(t, np.ndarray) else t
81
+ for t in self
82
+ ]
83
+ )
84
+
85
+ def clone(self) -> "OnnxList":
86
+ "Clone (torch)."
87
+ return OnnxList([t.clone() for t in self]) if len(self) > 0 else OnnxList(self.itype)
88
+
89
+
34
90
  class OnnxruntimeEvaluator:
35
91
  """
36
92
  This class loads an onnx model and the executes one by one the nodes
@@ -54,6 +110,9 @@ class OnnxruntimeEvaluator:
54
110
  :param whole: if True, do not split node by node
55
111
  :param torch_or_numpy: force the use of one of them, True for torch,
56
112
  False for numpy, None to let the class choose
113
+ :param dump_onnx_model: dumps the temporary onnx model created if whole is True
114
+ :param function_kwargs: a FunctionProto may have parameters,
115
+ this contains the values of them
57
116
  """
58
117
 
59
118
  def __init__(
@@ -77,6 +136,8 @@ class OnnxruntimeEvaluator:
77
136
  opsets: Optional[Union[int, Dict[str, int]]] = None,
78
137
  whole: bool = False,
79
138
  torch_or_numpy: Optional[bool] = None,
139
+ function_kwargs: Optional[Dict[str, Any]] = None,
140
+ dump_onnx_model: Optional[str] = None,
80
141
  ):
81
142
  if isinstance(proto, str):
82
143
  self.proto: Proto = load(proto)
@@ -90,6 +151,9 @@ class OnnxruntimeEvaluator:
90
151
  assert isinstance(
91
152
  self.proto, PROTO
92
153
  ), f"Unexpected type for self.proto {type(self.proto)}"
154
+ assert (
155
+ whole or not dump_onnx_model
156
+ ), f"whole must be True for dump_onnx_model={dump_onnx_model!r}"
93
157
 
94
158
  self._cache: Dict[
95
159
  Any, Tuple[Proto, Union["OnnxruntimeEvaluator", _InferenceSession]] # noqa: UP037
@@ -109,6 +173,8 @@ class OnnxruntimeEvaluator:
109
173
  use_training_api=use_training_api,
110
174
  )
111
175
  self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor
176
+ self.function_kwargs = function_kwargs
177
+ self.dump_onnx_model = dump_onnx_model
112
178
 
113
179
  self.verbose = verbose
114
180
  self.torch_or_numpy = torch_or_numpy
@@ -199,6 +265,8 @@ class OnnxruntimeEvaluator:
199
265
  def _log_arg(self, a: Any) -> Any:
200
266
  if isinstance(a, (str, int, float)):
201
267
  return a
268
+ if isinstance(a, OnnxList):
269
+ return string_type(a)
202
270
  device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
203
271
  if hasattr(a, "shape"):
204
272
  prefix = "A:" if hasattr(a, "astype") else "T:"
@@ -221,6 +289,12 @@ class OnnxruntimeEvaluator:
221
289
  def _is_local_function(self, node: NodeProto) -> bool:
222
290
  return (node.domain, node.op_type) in self.local_functions
223
291
 
292
+ def _run_init(self, feed_inputs):
293
+ if self.sess_ is None:
294
+ assert self.proto, "self.proto is empty"
295
+ _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
296
+ return self.sess_
297
+
224
298
  def run(
225
299
  self,
226
300
  outputs: Optional[List[str]],
@@ -244,9 +318,7 @@ class OnnxruntimeEvaluator:
244
318
  """
245
319
  if self.rt_nodes_ is None:
246
320
  # runs a whole
247
- if self.sess_ is None:
248
- assert self.proto, "self.proto is empty"
249
- _, self.sess_ = self._get_sess(self.proto, list(feed_inputs.values()))
321
+ self._run_init(feed_inputs)
250
322
  assert self.sess_, "mypy not happy"
251
323
  return self.sess_.run(outputs, feed_inputs)
252
324
  if outputs is None:
@@ -273,14 +345,16 @@ class OnnxruntimeEvaluator:
273
345
  if node.op_type == "If" and node.domain == "":
274
346
  outputs = self._run_if(node, inputs, results)
275
347
  elif node.op_type in {"Scan", "Loop"} and node.domain == "":
276
- outputs = self._run_scan(node, inputs, results)
348
+ outputs = self._run_scan_or_loop(node, inputs, results)
277
349
  elif self._is_local_function(node):
278
350
  outputs = self._run_local(node, inputs, results)
279
351
  else:
280
352
  outputs = self._run(node, inputs, results)
281
- for name, value in zip(node.output, outputs):
282
- if name == "":
283
- continue
353
+ node_output = [o for o in node.output if o]
354
+ assert len(node_output) == len(
355
+ outputs
356
+ ), f"Length mismatch between node output={node.output} and outputs={outputs}"
357
+ for name, value in zip(node_output, outputs):
284
358
  self._log(2, " + %s: %s", name, value) # type: ignore[arg-type]
285
359
  assert isinstance(name, str), f"unexpected type for name {type(name)}"
286
360
  results[name] = value
@@ -355,11 +429,12 @@ class OnnxruntimeEvaluator:
355
429
  nodes: Sequence[NodeProto],
356
430
  vinputs: Sequence[ValueInfoProto],
357
431
  voutputs: Sequence[ValueInfoProto],
432
+ functions: Optional[Sequence[FunctionProto]] = None,
358
433
  ) -> ModelProto:
359
434
  onx = oh.make_model(
360
435
  oh.make_graph(nodes, "-", vinputs, voutputs),
361
436
  ir_version=getattr(self.proto, "ir_version", self.ir_version),
362
- functions=getattr(self.proto, "functions", None),
437
+ functions=[*getattr(self.proto, "functions", []), *(functions or [])],
363
438
  )
364
439
  del onx.opset_import[:]
365
440
  if hasattr(self.proto, "opset_import"):
@@ -373,51 +448,61 @@ class OnnxruntimeEvaluator:
373
448
  )
374
449
  else:
375
450
  onx.opset_import.append(oh.make_opsetid("", onnx_opset_version()))
451
+ opsets = {d.domain: d.version for d in onx.opset_import}
452
+ add = {}
453
+ for node in self.enumerate_nodes(onx.graph.node):
454
+ if node.domain and node.domain not in opsets and node.domain not in add:
455
+ add[node.domain] = 1
456
+ onx.opset_import.extend([oh.make_opsetid(k, v) for k, v in add.items()])
376
457
 
377
458
  # That helps fixing bugs.
378
459
  onx = shi.infer_shapes(onx)
379
460
  return onx
380
461
 
381
- @classmethod
382
- def _get_hidden_inputs(self, graph: GraphProto) -> Set[str]:
383
- """
384
- Returns the hidden inputs (inputs coming from an upper context)
385
- used by a subgraph.
386
- """
387
- hidden = set()
388
- memo = set(i.name for i in graph.initializer)
389
- memo |= set(i.name for i in graph.sparse_initializer)
390
- for node in graph.node:
391
- for i in node.input:
392
- if i not in memo:
393
- hidden.add(i)
394
- for att in node.attribute:
395
- if att.type == AttributeProto.GRAPH and att.g:
396
- hid = self._get_hidden_inputs(att.g)
397
- less = set(h for h in hid if h not in memo)
398
- hidden |= less
399
- memo |= set(node.output)
400
- return hidden
462
+ def _make_model_outputs(
463
+ self, node: NodeProto, inputs: List[ValueInfoProto]
464
+ ) -> Tuple[List[NodeProto], List[ValueInfoProto]]:
465
+ return [], [oh.make_value_info(o, TypeProto()) for o in node.output if o]
466
+
467
+ def enumerate_nodes(self, nodes: List[NodeProto]) -> Iterator[NodeProto]:
468
+ "Enumerates nodes recursively."
469
+ for node in nodes:
470
+ if node.op_type in {"Scan", "If", "Loop"}:
471
+ for att in node.attribute:
472
+ if att.type == AttributeProto.GRAPH:
473
+ yield from self.enumerate_nodes(att.g.node)
474
+ yield node
401
475
 
402
476
  @classmethod
403
- def _get_hidden_node_inputs(self, node: NodeProto) -> Set[str]:
404
- """Calls multiple _get_hidden_inputs on every attribute."""
477
+ def _get_hidden_node_inputs(cls, node: NodeProto) -> Set[str]:
478
+ """Calls multiple get_hidden_inputs on every attribute."""
405
479
  if node.op_type not in {"Loop", "Scan", "If"}:
406
480
  return set()
407
481
  hidden = set()
408
482
  for att in node.attribute:
409
483
  if att.type == AttributeProto.GRAPH:
410
- hidden |= self._get_hidden_inputs(att.g)
484
+ hidden |= get_hidden_inputs(att.g)
411
485
  return hidden - (hidden & set(node.input))
412
486
 
413
487
  def _get_sess(
414
488
  self, node: Union[ModelProto, NodeProto], inputs: List[Any]
415
489
  ) -> Tuple[ModelProto, _InferenceSession]:
490
+ on_cpu = None
416
491
  if isinstance(node, ModelProto):
417
492
  onx = node
418
493
  else:
494
+ functions = []
495
+ if isinstance(node, FunctionProto):
496
+ functions.append(node)
497
+ node = oh.make_node(
498
+ node.name,
499
+ list(node.input),
500
+ list(node.output),
501
+ domain=node.domain,
502
+ **(self.function_kwargs or {}),
503
+ )
419
504
  assert isinstance(node, NodeProto), f"Unexpected type {type(node)} for node"
420
- if node.op_type == "Constant":
505
+ if node.op_type == "Constant" and node.domain == "":
421
506
  # We force the type to be a boolean.
422
507
  ref = ExtendedReferenceEvaluator(node)
423
508
  cst = ref.run(None, {})[0]
@@ -427,6 +512,19 @@ class OnnxruntimeEvaluator:
427
512
  node.output[0], dtype_to_tensor_dtype(cst.dtype), cst.shape
428
513
  )
429
514
  ]
515
+ prenodes = [] # type: ignore[var-annotated]
516
+ elif node.op_type == "ConcatFromSequence" and node.domain == "":
517
+ # We force the type to be a boolean.
518
+ vinputs = [
519
+ oh.make_value_info(
520
+ node.input[0],
521
+ type_proto=oh.make_sequence_type_proto(
522
+ oh.make_tensor_type_proto(elem_type=inputs[0].itype, shape=None)
523
+ ),
524
+ )
525
+ ]
526
+ voutputs = [oh.make_tensor_value_info(node.output[0], inputs[0].itype, None)]
527
+ prenodes = [] # type: ignore[var-annotated]
430
528
  else:
431
529
  unique_names = set()
432
530
  vinputs = []
@@ -440,18 +538,35 @@ class OnnxruntimeEvaluator:
440
538
  vinputs.append(value)
441
539
 
442
540
  # no need to run shape inference
443
- voutputs = [oh.make_value_info(o, TypeProto()) for o in node.output]
541
+ prenodes, voutputs = self._make_model_outputs(node, vinputs)
444
542
 
445
- onx = self._make_model_proto([node], vinputs, voutputs)
543
+ onx = self._make_model_proto(
544
+ [*prenodes, node], vinputs, voutputs, functions=functions
545
+ )
546
+ if node.op_type in {"Shape", "Size"}:
547
+ on_cpu = True
446
548
 
549
+ if self.dump_onnx_model:
550
+ onnx_save(
551
+ onx, self.dump_onnx_model, save_as_external_data=len(onx.graph.node) > 100
552
+ )
447
553
  cls = (
448
554
  InferenceSessionForNumpy
449
555
  if any(isinstance(i, np.ndarray) for i in inputs)
450
556
  and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
451
557
  else InferenceSessionForTorch
452
558
  )
559
+ if (
560
+ "providers" not in self.session_kwargs or not self.session_kwargs["providers"]
561
+ ) and any(hasattr(t, "is_cuda") and t.is_cuda for t in inputs):
562
+ sess_kwargs = self.session_kwargs.copy()
563
+ sess_kwargs["providers"] = ["CUDAExecutionProvider"]
564
+ else:
565
+ sess_kwargs = self.session_kwargs or {}
566
+ if on_cpu and "CUDAExecutionProvider" in (sess_kwargs.get("providers", []) or []):
567
+ sess_kwargs["cpu_outputs"] = True
453
568
  try:
454
- sess = cls(onx, **self.session_kwargs)
569
+ sess = cls(onx, **sess_kwargs)
455
570
  except (
456
571
  onnxruntime.capi.onnxruntime_pybind11_state.Fail,
457
572
  onnxruntime.capi.onnxruntime_pybind11_state.InvalidGraph,
@@ -473,15 +588,29 @@ class OnnxruntimeEvaluator:
473
588
  if i == "" or i in unique_names:
474
589
  continue
475
590
  unique_names.add(i)
476
- value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
591
+ if isinstance(it, OnnxList):
592
+ value = oh.make_value_info(
593
+ i,
594
+ type_proto=oh.make_sequence_type_proto(
595
+ oh.make_tensor_type_proto(
596
+ elem_type=dtype_to_tensor_dtype(it.dtype), shape=None
597
+ )
598
+ ),
599
+ )
600
+ else:
601
+ value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(it.dtype), it.shape)
477
602
  vinputs.append(value)
478
603
 
479
- reduced_set = self._get_hidden_inputs(g)
604
+ reduced_set = get_hidden_inputs(g)
480
605
  for i, v in context.items():
481
606
  if i in reduced_set and i not in unique_names:
482
607
  unique_names.add(i)
483
608
  value = oh.make_tensor_value_info(i, dtype_to_tensor_dtype(v.dtype), v.shape)
484
609
  vinputs.append(value)
610
+ assert len(reduced_set & set(context)) == len(reduced_set), (
611
+ f"Missing hidden inputs {sorted(reduced_set)} from context={sorted(context)} "
612
+ f"(len(inputs)={len([i for i in inputs if i])}) for node {pretty_onnx(node)}"
613
+ )
485
614
  return vinputs
486
615
 
487
616
  def _get_sess_if(
@@ -530,6 +659,14 @@ class OnnxruntimeEvaluator:
530
659
 
531
660
  def _run(self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]) -> List[Any]:
532
661
  """Runs a node."""
662
+ if node.op_type[0] == "S":
663
+ if node.op_type == "SequenceEmpty":
664
+ dtype = TensorProto.FLOAT
665
+ for att in node.attribute:
666
+ if att.name == "dtype":
667
+ dtype = att.i
668
+ return [OnnxList(itype=dtype)]
669
+
533
670
  types = [(None if a is None else (a.dtype, a.shape)) for a in inputs]
534
671
  key = (id(node), *types)
535
672
  if key in self._cache:
@@ -538,13 +675,31 @@ class OnnxruntimeEvaluator:
538
675
  onx, sess = self._get_sess(node, inputs)
539
676
  self._cache[key] = onx, sess
540
677
 
541
- feeds = dict(zip(node.input, inputs))
542
- if "" in feeds:
543
- feeds[""] = np.array([0], dtype=np.float32)
544
-
678
+ feeds = {}
679
+ for i, val in zip(node.input, inputs):
680
+ if i == "":
681
+ assert (
682
+ val is None
683
+ ), f"input name={i!r} but val={string_type(val, with_shape=True)}"
684
+ continue
685
+ feeds[i] = val
545
686
  assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
687
+
688
+ if node.op_type[0] == "C":
689
+ if node.op_type == "ConcatFromSequence":
690
+ res = sess.sess.run(None, self.feeds_to_numpy(feeds)) # type: ignore[union-attr]
691
+ if isinstance(inputs[0][0], np.ndarray):
692
+ return list(res)
693
+ import torch
694
+
695
+ return [torch.from_numpy(r).to(inputs[0][0].device) for r in res]
696
+
546
697
  outputs = list(sess.run(None, feeds))
547
698
  assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
699
+ assert not any(type(v) is list for v in outputs), (
700
+ f"One output type is a list, this should not be allowed, "
701
+ f"node.op_type={node.op_type}, feeds={string_type(feeds, with_shape=True)}"
702
+ )
548
703
  return outputs
549
704
 
550
705
  def _run_if(
@@ -570,7 +725,7 @@ class OnnxruntimeEvaluator:
570
725
  assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
571
726
  return outputs
572
727
 
573
- def _get_sess_scan(
728
+ def _get_sess_scan_or_loop(
574
729
  self, node: NodeProto, branch: str, inputs: List[Any], context: Dict[str, Any]
575
730
  ) -> Tuple[ModelProto, "OnnxruntimeEvaluator"]:
576
731
  g = None
@@ -605,10 +760,26 @@ class OnnxruntimeEvaluator:
605
760
  )
606
761
  return onx, sess
607
762
 
608
- def _run_scan(
763
+ def feeds_to_numpy(self, feeds):
764
+ new_feeds = {}
765
+ for k, v in feeds.items():
766
+ if hasattr(v, "detach"):
767
+ new_feeds[k] = v.detach().cpu().numpy()
768
+ elif isinstance(v, OnnxList):
769
+ new_feeds[k] = v.numpy()
770
+ else:
771
+ new_feeds[k] = v
772
+ return new_feeds
773
+
774
+ def _run_scan_or_loop(
609
775
  self, node: NodeProto, inputs: List[Any], results: Dict[str, Any]
610
776
  ) -> List[Any]:
611
777
  """Runs a node Scan."""
778
+ assert not any(type(i) is list for i in inputs), (
779
+ f"One input is a list but it should an OnnxList, "
780
+ f"node.op_type={node.op_type!r}, node.input={node.input}, "
781
+ f"inputs={string_type(inputs, with_shape=True)}"
782
+ )
612
783
  feeds = dict(zip(node.input, inputs))
613
784
  feeds.update(results)
614
785
  name = "body"
@@ -616,10 +787,21 @@ class OnnxruntimeEvaluator:
616
787
  if key in self._cache:
617
788
  sess = self._cache[key][1]
618
789
  else:
619
- self._cache[key] = _onx, sess = self._get_sess_scan(node, name, inputs, results)
790
+ self._cache[key] = _onx, sess = self._get_sess_scan_or_loop(
791
+ node, name, inputs, results
792
+ )
620
793
 
621
794
  assert hasattr(sess, "run"), f"Missing method run for type {type(sess)}"
622
795
  feeds = {name: results[name] for name in sess.input_names}
796
+ if node.op_type == "Loop" and any(isinstance(v, OnnxList) for v in feeds.values()):
797
+ # This operator uses sequence. onnxruntime does not play well with sequence.
798
+ sess._run_init(feeds) # type: ignore[union-attr]
799
+ outputs = sess.sess_.sess.run(None, self.feeds_to_numpy(feeds)) # type: ignore[union-attr]
800
+ return [
801
+ (OnnxList(v).to(feeds[node.input[0]]) if isinstance(v, list) else v)
802
+ for v in outputs
803
+ ]
804
+
623
805
  outputs = sess.run(None, feeds)
624
806
  assert isinstance(outputs, list), f"Unexpected type for outputs {type(outputs)}"
625
807
  return outputs
@@ -1,10 +1,6 @@
1
1
  from typing import Any, Callable, Dict, Optional, Tuple
2
2
  import torch
3
- from ..helpers.config_helper import (
4
- update_config,
5
- check_hasattr,
6
- default_num_hidden_layers as nhl,
7
- )
3
+ from ..helpers.config_helper import update_config, check_hasattr
8
4
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
9
5
 
10
6
 
@@ -13,8 +9,9 @@ __TASK__ = "feature-extraction"
13
9
 
14
10
  def reduce_model_config(config: Any) -> Dict[str, Any]:
15
11
  """Reduces a model size."""
16
- check_hasattr(config, "num_hidden_layers")
17
- kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
12
+ check_hasattr(config, "vocab_size")
13
+ # Bart architecture does not like too much that the number of layers is changed.
14
+ kwargs = dict(vocab_size=2056)
18
15
  update_config(config, kwargs)
19
16
  return kwargs
20
17
 
@@ -25,7 +22,8 @@ def get_inputs(
25
22
  batch_size: int,
26
23
  sequence_length: int,
27
24
  dummy_max_token_id: int,
28
- sequence_length2: int = 3,
25
+ past_length: int = 30,
26
+ past_length2: int = 4,
29
27
  decoder_attention_heads: Optional[int] = None,
30
28
  encoder_attention_heads: Optional[int] = None,
31
29
  encoder_ffn_dim: Optional[int] = None,
@@ -73,13 +71,13 @@ def get_inputs(
73
71
  torch.randn(
74
72
  batch_size,
75
73
  encoder_attention_heads,
76
- sequence_length,
74
+ past_length,
77
75
  encoder_ffn_dim,
78
76
  ),
79
77
  torch.randn(
80
78
  batch_size,
81
79
  encoder_attention_heads,
82
- sequence_length,
80
+ past_length,
83
81
  encoder_ffn_dim,
84
82
  ),
85
83
  )
@@ -92,13 +90,13 @@ def get_inputs(
92
90
  torch.randn(
93
91
  batch_size,
94
92
  decoder_attention_heads,
95
- sequence_length2,
93
+ past_length2,
96
94
  decoder_ffn_dim,
97
95
  ),
98
96
  torch.randn(
99
97
  batch_size,
100
98
  decoder_attention_heads,
101
- sequence_length2,
99
+ past_length2,
102
100
  decoder_ffn_dim,
103
101
  ),
104
102
  )
@@ -124,7 +122,8 @@ def get_inputs(
124
122
  batch_size=batch_size + 1,
125
123
  sequence_length=sequence_length + add_second_input,
126
124
  dummy_max_token_id=dummy_max_token_id,
127
- sequence_length2=sequence_length2,
125
+ past_length=past_length,
126
+ past_length2=past_length2,
128
127
  decoder_attention_heads=decoder_attention_heads,
129
128
  encoder_attention_heads=encoder_attention_heads,
130
129
  encoder_ffn_dim=encoder_ffn_dim,
@@ -146,7 +145,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
146
145
  check_hasattr(config, "vocab_size")
147
146
  kwargs = dict(
148
147
  batch_size=2,
149
- sequence_length=30,
148
+ sequence_length=12,
149
+ past_length=30,
150
+ past_length2=4,
150
151
  dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
151
152
  )
152
153
  for att in [