onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -3
  3. onnx_diagnostic/ci_models/ci_helpers.py +12 -7
  4. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  5. onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
  6. onnx_diagnostic/export/api.py +295 -5
  7. onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
  8. onnx_diagnostic/export/dynamic_shapes.py +45 -3
  9. onnx_diagnostic/export/shape_helper.py +1 -0
  10. onnx_diagnostic/ext_test_case.py +9 -2
  11. onnx_diagnostic/helpers/bench_run.py +1 -1
  12. onnx_diagnostic/helpers/cache_helper.py +0 -8
  13. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  14. onnx_diagnostic/helpers/helper.py +30 -1
  15. onnx_diagnostic/helpers/log_helper.py +1 -3
  16. onnx_diagnostic/helpers/optim_helper.py +116 -0
  17. onnx_diagnostic/helpers/ort_session.py +5 -0
  18. onnx_diagnostic/tasks/image_text_to_text.py +19 -9
  19. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  20. onnx_diagnostic/tasks/text_generation.py +3 -0
  21. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
  22. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  23. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  24. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
  31. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  32. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  33. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  34. onnx_diagnostic/torch_models/validate.py +48 -0
  35. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/METADATA +3 -1
  36. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/RECORD +39 -36
  37. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/WHEEL +0 -0
  38. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/licenses/LICENSE.txt +0 -0
  39. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.8.dist-info}/top_level.txt +0 -0
@@ -352,6 +352,19 @@ class CoupleInputsDynamicShapes:
             else None
         )
         assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+        if set(inputs) != set(ds):
+            not_in_ds = {k for k in inputs if k not in ds}
+            not_in_inputs = {k for k in ds if k not in inputs}
+            assert not_in_inputs == {"kwargs"} and set(ds["kwargs"]) == not_in_ds, (
+                f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
+                f"inputs={string_type(inputs, with_shape=True)}, ds={ds}, "
+                f"not_in_ds={not_in_ds}, not_in_inputs={not_in_inputs}"
+            )
+            # Tweak: move the keys nested under ds["kwargs"] up one level.
+            kws = ds["kwargs"]
+            del ds["kwargs"]
+            ds.update(kws)
+
         assert set(inputs) == set(ds), (
             f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
             f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
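A minimal standalone sketch of the tweak above (the key names are hypothetical): when the dynamic shapes carry a nested "kwargs" entry, its content is merged back into the top level so both dictionaries end up with the same keys:

    ds = {"input_ids": {0: "batch"}, "kwargs": {"attention_mask": {0: "batch"}}}
    kws = ds["kwargs"]
    del ds["kwargs"]
    ds.update(kws)
    assert ds == {"input_ids": {0: "batch"}, "attention_mask": {0: "batch"}}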
@@ -366,13 +379,15 @@ class CoupleInputsDynamicShapes:
             return dvalue if dvalue else None

         # A custom class.
-        assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
+        assert inputs is None or inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
             f"Class {inputs.__class__.__name__!r} was not registered using "
             f"torch.utils._pytree.register_pytree_node, it is not possible to "
             f"map this class with the given dynamic shapes."
         )
         if flatten_unflatten:
             flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
+            if isinstance(flatunflat, (list, tuple, dict)) and len(flatunflat) == 0:
+                return flatunflat
             res = cls._generic_walker_step(
                 processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
             )
@@ -667,6 +682,11 @@ class ModelInputs:
             if self.signature
             else None
         )
+        self.forward_parameters_kinds = (
+            {p.name: p.kind for p in self.signature.parameters.values()}
+            if self.signature
+            else None
+        )
         self.forward_ordered_parameter_names = (
             list(self.signature.parameters) if self.signature else None
         )
@@ -973,7 +993,13 @@ class ModelInputs:
             len(s1) == 1
         ), f"Different numbers of positional arguments {s1} for {self.full_name}"
         s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
-        assert len(s2) == 1, f"Different named arguments {s2} for {self.full_name}"
+        assert len(s2) > 0, f"empty {s2} for {self.full_name}"
+        if len(s2) > 1:
+            # We need to keep the largest set of inputs, the one including all the others.
+            sum_s2 = set()
+            for s in s2:
+                sum_s2 |= set(s)
+            s2 = {tuple(sum_s2)}
         args = []
         kwargs = {}
         for i in range(s1.pop()):
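The relaxed branch keeps the union of all observed named-argument sets instead of requiring every call to use the same names; a standalone sketch with made-up argument names:

    s2 = {("input_ids",), ("input_ids", "attention_mask")}
    sum_s2 = set()
    for s in s2:
        sum_s2 |= set(s)
    s2 = {tuple(sum_s2)}  # one tuple covering every argument seen in any call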
@@ -993,12 +1019,18 @@ class ModelInputs:
                 f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
             )

-            objs = [_[1][name] for _ in self.inputs]
+            objs = [_[1][name] for _ in self.inputs if name in _[1]]
             kwargs[name] = self.guess_dynamic_shape_object(
                 *objs,
                 auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
                 msg=lambda name=name: f" failing input {name!r}",
             )
+        # reordering: follow the order of the forward signature
+        if kwargs is not None and self.forward_ordered_parameter_names:
+            kwargs1 = {
+                p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+            }
+            kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
         return tuple(args), kwargs

     def move_to_kwargs(
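The reordering idiom introduced here (and reused in move_to_kwargs below) lists the parameters known to the forward signature first, in signature order, then appends any leftover keys; a standalone sketch with hypothetical names:

    forward_ordered_parameter_names = ["input_ids", "attention_mask", "past_key_values"]
    kwargs = {"past_key_values": "pkv", "input_ids": "ids", "extra": 1}
    kwargs1 = {p: kwargs[p] for p in forward_ordered_parameter_names if p in kwargs}
    kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
    assert list(kwargs) == ["input_ids", "past_key_values", "extra"]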
@@ -1061,6 +1093,16 @@ class ModelInputs:
             f"and kwargs={set(kwargs)}, "
             f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
         )
+        if kwargs is not None and self.forward_ordered_parameter_names:
+            kwargs1 = {
+                p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+            }
+            kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+        if kw_dyn is not None and self.forward_ordered_parameter_names:
+            kw_dyn1 = {
+                p: kw_dyn[p] for p in self.forward_ordered_parameter_names if p in kw_dyn
+            }
+            kw_dyn = {**kw_dyn1, **{k: v for k, v in kw_dyn.items() if k not in kw_dyn1}}
         return args, kwargs, (tuple(), kw_dyn)

     def validate_inputs_for_export(
@@ -210,6 +210,7 @@ def make_fake_with_dynamic_dimensions(
     This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
     Parameter ``existing`` is used to reuse the same object when the dynamic
     dimension is given the same name as another one.
+    This function works with caches only if ``transformers>=4.57``.

     A simple tensor:

@@ -1267,6 +1267,7 @@ class ExtTestCase(unittest.TestCase):
         :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
         """
         from .helpers import string_type, string_diff, max_diff
+        from .helpers.torch_helper import torch_deepcopy
         from .helpers.rt_helper import make_feeds
         from .helpers.ort_session import InferenceSessionForTorch

@@ -1283,6 +1284,12 @@ class ExtTestCase(unittest.TestCase):
             model_file = proto
             name = proto
             proto = onnx.load(name)
+        elif hasattr(proto, "save"):
+            name = f"{test_name}.onnx"
+            proto.save(name)
+            proto = onnx.load(name)
+        elif hasattr(proto, "model_proto"):
+            proto = proto.model_proto
         elif not self.unit_test_going():
             assert isinstance(
                 proto, onnx.ModelProto
@@ -1341,9 +1348,9 @@ class ExtTestCase(unittest.TestCase):
         if copy_inputs:
             expected = [
                 (
-                    model(*copy.deepcopy(inp))
+                    model(*torch_deepcopy(inp))
                     if isinstance(inp, tuple)
-                    else model(**copy.deepcopy(inp))
+                    else model(**torch_deepcopy(inp))
                 )
                 for inp in inputs
             ]
@@ -20,7 +20,7 @@ class BenchmarkError(RuntimeError):


 def _clean_string(s: str) -> str:
-    cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {","}]
+    cleaned = [c for c in s if 32 <= ord(c) < 127 and c not in {",", ":"}]
     return "".join(cleaned)


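With the colon added to the excluded characters, the helper now drops both separators along with anything outside printable ASCII; for instance:

    assert _clean_string("cpu: 1,234\tms") == "cpu 1234ms"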
@@ -28,14 +28,6 @@ class CacheKeyValue:
             ]
             self.key_cache = [layer.keys for layer in layers]
             self.value_cache = [layer.values for layer in layers]
-            if None in self.key_cache or None in self.value_cache:
-                from .helper import string_type
-
-                raise AssertionError(
-                    f"issue with key_cache={string_type(self.key_cache)}, "
-                    f"or value_cache={string_type(self.value_cache)}, "
-                    f"cache.layers={string_type(cache.layers)}"
-                )
         elif cache is not None and hasattr(cache, "key_cache"):
             self.key_cache = cache.key_cache
             self.value_cache = cache.value_cache
@@ -105,6 +105,8 @@ class FakeTensorContext:
         reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
             axis=tuple(sorted(sh)), keepdim=True
         )
+        if len(reduced_tensor.shape) == 0 == len(new_shape):
+            return reduced_tensor
         return reduced_tensor.expand(*new_shape)

     def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
@@ -144,19 +146,22 @@ class FakeTensorContext:
         """
         See
         :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
+        If caches are used, it requires ``transformers>=4.57``.
         """
         if x is None:
             return None, None
-        if isinstance(x, (list, tuple)):
+        if type(x) in (list, tuple):
             return x.__class__(
                 [
                     self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
                     for i, ds in zip(x, dynamic_shapes)
                 ]
             )
-        if isinstance(x, dict):
+        if type(x) is dict:
             return {
-                k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
+                k: self.make_fake_with_dynamic_dimensions(
+                    v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
+                )
                 for k, v in x.items()
             }
         if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
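Switching from isinstance to exact type checks is presumably meant to keep dict subclasses (such as transformers' ModelOutput classes, which derive from OrderedDict) out of the plain-dict branch; that rationale is an assumption here, but the mechanics are easy to verify standalone:

    from collections import OrderedDict

    class MyOutput(OrderedDict):  # hypothetical stand-in for an OrderedDict subclass
        pass

    x = MyOutput()
    assert isinstance(x, dict)   # the old check also matched subclasses
    assert type(x) is not dict   # the new check only matches exact dicts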
@@ -187,6 +192,17 @@ class FakeTensorContext:
                 x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
             )
             return x
+        if x.__class__.__name__ == "BaseModelOutput":
+            assert (
+                list(x.keys()) == ["last_hidden_state"] and x.last_hidden_state is not None
+            ), (
+                f"Field 'last_hidden_state' is empty for {type(x)} or other fields "
+                f"{list(x.keys())} are used."
+            )
+            x.last_hidden_state = self.make_fake_with_dynamic_dimensions(
+                x.last_hidden_state, dynamic_shapes=dynamic_shapes[0]
+            )
+            return x
         if hasattr(x, "shape"):
             assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
                 f"dynamic_shapes must be a dictionary at this stage but "
@@ -197,9 +213,11 @@ class FakeTensorContext:
             for idim, dim in enumerate(x.shape):
                 if dynamic_shapes is not None and idim in dynamic_shapes:
                     s = dynamic_shapes[idim]
+                    if s.__class__.__name__ == "Dim":
+                        s = s.__name__
                     assert isinstance(s, str), (
                         f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
-                        f"at index {idim}"
+                        f"at index {idim}, self._mapping_str={self._mapping_str}"
                     )
                     if s in self._mapping_str:
                         dim = self._mapping_str[s]
@@ -217,10 +235,13 @@ class FakeTensorContext:

             x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)

-            t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+            t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x  # type: ignore[arg-type]
             assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
             assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
             return t
+        if isinstance(x, (int, bool, float)):
+            # It is a constant, we don't change that.
+            return x
         from ..helpers import string_type

         raise TypeError(
@@ -704,9 +704,35 @@ def string_type(
     if obj.__class__.__name__ == "VirtualTensor":
         if verbose:
             print(f"[string_type] TT4:{type(obj)}")
+
+        def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]:  # noqa: F821
+            if isinstance(value, str):
+                return value
+            if hasattr(value, "node") and isinstance(value.node, str):
+                return f"{value.node}"
+
+            from torch.fx.experimental.sym_node import SymNode
+
+            if hasattr(value, "node") and isinstance(value.node, SymNode):
+                # '_expr' is safer than 'expr'
+                return str(value.node._expr).replace(" ", "")
+
+            try:
+                val_int = int(value)
+                return val_int
+            except (
+                TypeError,
+                ValueError,
+                AttributeError,
+                torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
+            ):
+                pass
+
+            raise AssertionError(f"Unable to convert {value!r} into string")
+
         return (
             f"{obj.__class__.__name__}(name={obj.name!r}, "
-            f"dtype={obj.dtype}, shape={obj.shape})"
+            f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})"
         )

     if obj.__class__.__name__ == "KeyValuesWrapper":
@@ -775,6 +801,9 @@ def string_type(
         print(f"[string_type] TT8:{type(obj)}")
         return repr(obj).replace(" ", "").replace("\n", " ")

+    if isinstance(obj, torch.fx.proxy.Proxy):
+        return repr(obj)
+
     if ignore:
         if verbose:
             print(f"[string_type] CACHE4:{type(obj)}")
@@ -1921,9 +1921,7 @@ class CubeLogsPerformance(CubeLogs):
             return lambdas[formula]

         if formula == "onnx_n_nodes_no_cst":
-            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
-                df, "op_onnx__Constant", 0
-            ).fillna(0)
+            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(df, "op_onnx__Constant", 0)
         if formula == "peak_gpu_torch":
             return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
         if formula == "peak_gpu_nvidia":
@@ -0,0 +1,116 @@
+from typing import Optional, Union
+import pprint
+import onnx
+
+
+def optimize_model(
+    algorithm: str,
+    model: Union[onnx.ModelProto, str],
+    output: Optional[str] = None,
+    processor: Optional[str] = None,
+    infer_shapes: bool = True,
+    remove_shape_info: bool = False,
+    verbose: int = 1,
+):
+    """
+    Optimizes an onnx model by fusing nodes. It looks for patterns in the graph
+    and replaces them with the corresponding fused nodes. It also does basic
+    optimizations such as removing identity nodes or unused nodes.
+
+    :param algorithm: algorithm to choose
+    :param model: model to optimize, as a proto or a filename
+    :param output: if not empty, the optimized model is saved
+    :param processor: optimizations are done for this processor
+    :param infer_shapes: infer shapes before optimizing, this might not be
+        available for all algorithms
+    :param remove_shape_info: remove shape information before saving the model
+    :param verbose: verbosity level
+    :return: optimized model
+
+    The goal is to make the model faster.
+    The set of patterns to apply depends on the algorithm.
+    It is possible to show statistics or to disable a particular pattern.
+    Here are some environment variables which can be used to trigger
+    these displays.
+
+    Available for algorithms ``default`` and ``default+onnxruntime``:
+
+    - ``DROPPATTERN=<pattern1,pattern2,...>``: do not apply
+      those patterns when optimizing a model
+    - ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
+    - ``PATTERN=<pattern1,pattern2,...>``: increases the verbosity
+      for specific patterns to understand why a pattern was not applied,
+      this shows which line is rejecting a pattern if it seems one pattern was missed
+    """
+    if isinstance(model, str):
+        if verbose:
+            print(f"[optimize_model] load {model!r}")
+        proto = onnx.load(model)
+        if verbose:
+            print("[optimize_model] done loading.")
+    else:
+        proto = model
+
+    if verbose:
+        print(f"[optimize_model] optimize with {algorithm!r}")
+    if algorithm in {"default", "default+onnxruntime"}:
+        from experimental_experiment.xoptim import get_pattern_list
+        from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+        pats = get_pattern_list(algorithm)
+
+        gr = GraphBuilder(
+            proto,
+            infer_shapes_options=infer_shapes,
+            optimization_options=OptimizationOptions(
+                patterns=pats,
+                verbose=verbose,
+                remove_unused=True,
+                constant_folding=True,
+                remove_identity=True,
+                max_iter=max(100, len(proto.graph.node) // 2),
+                processor=processor or "CPU",
+            ),
+        )
+        if verbose:
+            print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
+            print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
+        opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
+        if verbose:
+            print("[optimize_model] optimization report")
+            pprint.pprint(report)
+            print("[optimize_model] done")
+
+    elif algorithm == "slim":
+        import onnxslim
+
+        opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
+    elif algorithm in {"ir", "os_ort"}:
+        import onnx_ir
+        import onnxscript.optimizer
+        from onnxscript.rewriter.ort_fusions import optimize_for_ort
+
+        model_ir = onnx_ir.from_proto(proto)
+        if algorithm == "ir":
+            onnxscript.optimizer.optimize(model_ir)
+        else:
+            optimize_for_ort(model_ir)
+        opt_onx = onnx_ir.serde.serialize_model(model_ir)
+
+    del proto
+    if verbose:
+        print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
+    if remove_shape_info:
+        if verbose:
+            print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
+        del opt_onx.graph.value_info[:]
+        if verbose:
+            print("[optimize_model] done removing shape info")
+
+    if output:
+        if verbose:
+            print(f"[optimize_model] save file into {output!r}")
+        onnx.save(opt_onx, output, save_as_external_data=True)
+        if verbose:
+            print("[optimize_model] done saving")
+    return opt_onx
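For context, a minimal usage sketch of this new helper (the file names are hypothetical; the "slim" path only requires onnxslim to be installed):

    from onnx_diagnostic.helpers.optim_helper import optimize_model

    opt = optimize_model("slim", "model.onnx", output="model.slim.onnx", verbose=0)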
@@ -1,3 +1,4 @@
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import onnx
 import numpy as np
@@ -76,6 +77,10 @@ class _InferenceSession:
             session_options.enable_profiling = enable_profiling
         if optimized_model_filepath:
             session_options.optimized_model_filepath = optimized_model_filepath
+            session_options.add_session_config_entry(
+                "session.optimized_model_external_initializers_file_name",
+                f"{os.path.splitext(os.path.split(optimized_model_filepath)[-1])[0]}.data",
+            )
         if log_severity_level is not None:
             session_options.log_severity_level = log_severity_level
         if log_verbosity_level is not None:
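The new session config entry names the external-data file after the optimized model; a standalone check of that expression (the path is hypothetical):

    import os

    optimized_model_filepath = "dump/model.optimized.onnx"
    name = f"{os.path.splitext(os.path.split(optimized_model_filepath)[-1])[0]}.data"
    assert name == "model.optimized.data"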
@@ -13,6 +13,10 @@ from .data import get_data
 __TASK__ = "image-text-to-text"


+def should_have_vision_config(config):
+    return config.architectures != ["FuyuForCausalLM"]
+
+
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
@@ -168,10 +172,10 @@ def _get_inputs_gemma3(
     assert expected & set(
         dummies
     ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
-    assert sequence_length == dummies["input_ids"].shape[-1], (
-        f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
-        f"model class {model.__class__.__name__}"
-    )
+    # assert sequence_length == dummies["input_ids"].shape[-1], (
+    #     f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+    #     f"model class {model.__class__.__name__}"
+    # )
     assert batch_size == dummies["input_ids"].shape[0], (
         f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
         f"model class {model.__class__.__name__}"
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+        if should_have_vision_config(config):
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-    check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+    if should_have_vision_config(config):
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         head_dim=(
             16
@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ),
         width=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         height=(
             224
-            if config is None or not hasattr(config.vision_config, "image_size")
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         num_channels=(
             3
-            if config is None
+            if config is None or not should_have_vision_config(config)
             else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
         ),
         pad_token_id=(
@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
+        default_layer_types = [
+            "sliding_attention",
+            "full_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        config.encoder.num_hidden_layers = 4
+        config.encoder.layer_types = (
+            default_layer_types if config is None else config.encoder.layer_types[:4]
+        )
+        config.decoder.num_hidden_layers = 4
+        config.decoder.layer_types = (
+            default_layer_types if config is None else config.decoder.layer_types[:4]
+        )
+
     update_config(config, kwargs)
     return kwargs

@@ -177,55 +193,75 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:

     If the configuration is None, the function selects typical dimensions.
     """
+    path = 1
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-    # exceptions = {
-    #     "PLBartForConditionalGeneration": (
-    #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-    #     )
-    # }
-    kwargs = dict(
-        batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
-        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
-    )
+        if hasattr(config, "num_attention_heads"):
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_hidden_layers", "num_layers"),
+                ("n_positions", "d_model"),
+                (
+                    "num_key_value_heads",
+                    "num_heads",
+                    ("decoder_attention_heads", "encoder_attention_heads"),
+                ),
+            )
+        else:
+            check_hasattr(config, "encoder", "decoder")
+            path = 2
+
+    if path == 1:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            head_dim_encoder=(
+                16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim")
+            ),
+            head_dim_decoder=(
+                16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim")
+            ),
+            dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+            num_hidden_layers=(
+                8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+            ),
+            num_key_value_heads_encoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "encoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            num_key_value_heads_decoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "decoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            dummy_max_token_id=config.encoder.vocab_size - 1,
+            num_key_value_heads_encoder=config.encoder.num_key_value_heads,
+            num_key_value_heads_decoder=config.decoder.num_key_value_heads,
+            num_hidden_layers=len(config.encoder.layer_types),
+            head_dim_encoder=config.encoder.head_dim,
+            head_dim_decoder=config.decoder.head_dim,
+            encoder_dim=256,
+        )
+
     return kwargs, get_inputs
@@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             state_size=8 if config is None else getattr(config, "state_size", None),
             conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
         )
+    elif config.__class__.__name__ == "FunnelConfig":
+        # does not support num_hidden_layers
+        kwargs = dict()
     else:
         kwargs = dict(
             head_dim=getattr(