onnx-diagnostic 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/ci_models/export_phi4_mm.py +1 -1
  3. onnx_diagnostic/doc.py +258 -8
  4. onnx_diagnostic/export/api.py +755 -5
  5. onnx_diagnostic/export/dynamic_shapes.py +61 -4
  6. onnx_diagnostic/export/shape_helper.py +1 -8
  7. onnx_diagnostic/helpers/cache_helper.py +98 -21
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  9. onnx_diagnostic/helpers/helper.py +36 -6
  10. onnx_diagnostic/helpers/onnx_helper.py +7 -0
  11. onnx_diagnostic/helpers/ort_session.py +5 -0
  12. onnx_diagnostic/helpers/rt_helper.py +14 -1
  13. onnx_diagnostic/helpers/torch_helper.py +22 -9
  14. onnx_diagnostic/tasks/image_text_to_text.py +8 -5
  15. onnx_diagnostic/tasks/text_generation.py +17 -17
  16. onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
  18. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  19. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
  22. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
  23. onnx_diagnostic/torch_models/validate.py +48 -0
  24. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +3 -1
  25. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +28 -28
  26. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
@@ -329,7 +329,7 @@ class CoupleInputsDynamicShapes:
          if type(inputs) in (tuple, list, dict):
              # Type must be strict, some custom classes can inherit from those.
              assert type(inputs) is type(ds), (
-                 f"Input type and dynamic shape type mush match but "
+                 f"Input type and dynamic shapes type mush match but "
                  f"type(inputs)={type(inputs)}, type(ds)={type(ds)}, "
                  f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
              )
@@ -352,6 +352,19 @@ class CoupleInputsDynamicShapes:
              else None
          )
          assert type(inputs) is dict, f"Unexpected type for inputs {type(inputs)}"
+         if set(inputs) != set(ds):
+             not_in_ds = {k for k in inputs if k not in ds}
+             not_in_inputs = {k for k in ds if k not in inputs}
+             assert not_in_inputs == {"kwargs"} and set(ds["kwargs"]) == not_in_ds, (
+                 f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
+                 f"inputs={string_type(inputs, with_shape=True)}, ds={ds}, "
+                 f"not_in_ds={not_in_ds}, not_in_inputs={not_in_inputs}"
+             )
+             # Tweak...
+             kws = ds["kwargs"]
+             del ds["kwargs"]
+             ds.update(kws)
+
          assert set(inputs) == set(ds), (
              f"Keys mismatch between inputs {set(inputs)} and ds={set(ds)}, "
              f"inputs={string_type(inputs, with_shape=True)}, ds={ds}"
@@ -366,13 +379,15 @@ class CoupleInputsDynamicShapes:
              return dvalue if dvalue else None

          # A custom class.
-         assert inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
+         assert inputs is None or inputs.__class__ in torch.utils._pytree.SUPPORTED_NODES, (
              f"Class {inputs.__class__.__name__!r} was not registered using "
              f"torch.utils._pytree.register_pytree_node, it is not possible to "
              f"map this class with the given dynamic shapes."
          )
          if flatten_unflatten:
              flatunflat = flatten_unflatten_for_dynamic_shapes(inputs)
+             if isinstance(flatunflat, (list, tuple, dict)) and len(flatunflat) == 0:
+                 return flatunflat
              res = cls._generic_walker_step(
                  processor, flatunflat, ds, flatten_unflatten=flatten_unflatten
              )
@@ -667,6 +682,11 @@ class ModelInputs:
              if self.signature
              else None
          )
+         self.forward_parameters_kinds = (
+             {p.name: p.kind for p in self.signature.parameters.values()}
+             if self.signature
+             else None
+         )
          self.forward_ordered_parameter_names = (
              list(self.signature.parameters) if self.signature else None
          )
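The new forward_parameters_kinds attribute records inspect parameter kinds for the wrapped forward. A standalone sketch of the same mapping (the forward signature is made up):

    import inspect

    def forward(input_ids, attention_mask=None, **kwargs):
        return input_ids

    sig = inspect.signature(forward)
    kinds = {p.name: p.kind for p in sig.parameters.values()}
    names = list(sig.parameters)  # what forward_ordered_parameter_names stores
    # kinds["kwargs"] is inspect.Parameter.VAR_KEYWORD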
@@ -947,6 +967,8 @@ class ModelInputs:
          """
          Guesses the dynamic shapes for that module from two execution.
          If there is only one execution, then that would be static dimensions.
+         If the model signature is available, the kwargs are reordered following
+         the signature order, otherwise it follows the order given in the inputs.

          :param auto: if auto is True, use ``torch.export.Dim.AUTO`` for any
              dimension if the number of inputs is one,
@@ -973,7 +995,13 @@ class ModelInputs:
              len(s1) == 1
          ), f"Different numbers of positional arguments {s1} for {self.full_name}"
          s2 = set(tuple(sorted(set(i[1]))) for i in self.inputs)
-         assert len(s2) == 1, f"Different named arguments {s2} for {self.full_name}"
+         assert len(s2) > 0, f"empty {s2} for {self.full_name}"
+         if len(s2) > 1:
+             # We need to keep the largest set of inputs, the one including all the others.
+             sum_s2 = set()
+             for s in s2:
+                 sum_s2 |= set(s)
+             s2 = {tuple(sum_s2)}
          args = []
          kwargs = {}
          for i in range(s1.pop()):
@@ -993,12 +1021,31 @@ class ModelInputs:
                  f"\ninputs[1]={string_type(self.inputs[1], with_shape=True)}"
              )

-             objs = [_[1][name] for _ in self.inputs]
+             objs = [_[1][name] for _ in self.inputs if name in _[1]]
              kwargs[name] = self.guess_dynamic_shape_object(
                  *objs,
                  auto=auto if isinstance(auto, bool) else f"{auto}_{i}I",
                  msg=lambda name=name: f" failing input {name!r}",
              )
+         # reordering
+         if kwargs:
+             if self.forward_ordered_parameter_names:
+                 kwargs1 = {
+                     p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+                 }
+                 kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+             else:
+                 # We reorder the same the way the input were given.
+                 use = None
+                 params = set(kwargs)
+                 for _args, kws in self.inputs:
+                     if set(kws) == params:
+                         use = kws
+                         break
+                 if use:
+                     ordered = list(use)
+                     kwargs = {k: kwargs[k] for k in ordered}
+
          return tuple(args), kwargs

      def move_to_kwargs(
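The reordering added here follows the pattern below; a small standalone sketch with an illustrative signature and guessed shapes (not the library's own objects):

    import inspect

    def forward(input_ids, attention_mask=None, past_key_values=None):
        return input_ids

    kwargs = {"past_key_values": {0: "batch"}, "input_ids": {0: "batch"}}
    ordered_names = list(inspect.signature(forward).parameters)
    kwargs1 = {p: kwargs[p] for p in ordered_names if p in kwargs}
    # anything not named in the signature keeps its original position at the end
    kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
    # {'input_ids': {0: 'batch'}, 'past_key_values': {0: 'batch'}}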
@@ -1061,6 +1108,16 @@ class ModelInputs:
              f"and kwargs={set(kwargs)}, "
              f"forward_ordered_parameter_names={self.forward_ordered_parameter_names}"
          )
+         if kwargs is not None and self.forward_ordered_parameter_names:
+             kwargs1 = {
+                 p: kwargs[p] for p in self.forward_ordered_parameter_names if p in kwargs
+             }
+             kwargs = {**kwargs1, **{k: v for k, v in kwargs.items() if k not in kwargs1}}
+         if kw_dyn is not None and self.forward_ordered_parameter_names:
+             kw_dyn1 = {
+                 p: kw_dyn[p] for p in self.forward_ordered_parameter_names if p in kw_dyn
+             }
+             kw_dyn = {**kw_dyn1, **{k: v for k, v in kw_dyn.items() if k not in kw_dyn1}}
          return args, kwargs, (tuple(), kw_dyn)

      def validate_inputs_for_export(
@@ -47,7 +47,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
          make_dynamic_cache,
          make_encoder_decoder_cache,
          make_mamba_cache,
-         make_sliding_window_cache,
          make_static_cache,
      )
      from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
@@ -77,13 +76,6 @@ def all_dynamic_shapes_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
              ]
          ),
      ),
-     make_sliding_window_cache(
-         [
-             (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-             (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-             (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
-         ]
-     ),
      make_static_cache(
          [
              (torch.rand((4, 5, 6, 7)), torch.rand((4, 5, 6, 7))),
@@ -210,6 +202,7 @@ def make_fake_with_dynamic_dimensions(
      This uses function :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`.
      Parameter ``existing`` is used to reused the same object when the dynamic
      dimension is given the same name as another one.
+     This function works with caches only if ``transformers>=4.57``.

      A simple tensor:

@@ -19,7 +19,7 @@ class CacheKeyValue:
          capi.value_cache
      """

-     def __init__(self, cache=None):
+     def __init__(self, cache=None, cls_layers=None):
          if hasattr(cache, "layers"):
              layers = [
                  layer
@@ -28,32 +28,52 @@
              ]
              self.key_cache = [layer.keys for layer in layers]
              self.value_cache = [layer.values for layer in layers]
-             if None in self.key_cache or None in self.value_cache:
-                 from .helper import string_type
-
-                 raise AssertionError(
-                     f"issue with key_cache={string_type(self.key_cache)}, "
-                     f"or value_cache={string_type(self.value_cache)}, "
-                     f"cache.layers={string_type(cache.layers)}"
-                 )
+             assert (
+                 cls_layers is None
+             ), f"cache is {type(cache)}, cannot specify cls_layers={cls_layers}"
+             self.cls_layers = [type(lay) for lay in cache.layers]
          elif cache is not None and hasattr(cache, "key_cache"):
              self.key_cache = cache.key_cache
              self.value_cache = cache.value_cache
+             self.cls_layers = cls_layers
+         elif (
+             cache is not None
+             and isinstance(cache, list)
+             and all(isinstance(t, torch.Tensor) for t in cache)
+         ):
+             self.key_cache = cache[::2]
+             self.value_cache = cache[1::2]
+             self.cls_layers = cls_layers
          elif cache is None:
              self.key_cache = None
              self.value_cache = None
+             self.cls_layers = cls_layers
          else:
              raise NotImplementedError(f"type(cache)={type(cache)}")

      def make_dynamic_cache(self):
          """Does the reverse operation."""
-         return make_dynamic_cache(list(zip(self.key_cache, self.value_cache)))
+         return make_dynamic_cache(
+             list(zip(self.key_cache, self.value_cache)), cls_layers=self.cls_layers
+         )

      @property
      def n_layers(self) -> int:
          """Returns the number of layers."""
          return len(self.key_cache) if self.key_cache else 0

+     def __len__(self) -> int:
+         "Returns the number of tensors."
+         return len(self.key_cache) + len(self.value_cache)
+
+     def aslist(self) -> List[torch.Tensor]:
+         "Returns tensors in a list."
+         res = []
+         for i in range(self.n_layers):
+             res.append(self.key_cache[i])
+             res.append(self.value_cache[i])
+         return res
+

  def flatten_unflatten_for_dynamic_shapes(
      obj: Any,
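A short sketch of the new list-of-tensors constructor and the aslist()/__len__ helpers added above; it only needs torch and assumes onnx-diagnostic 0.8.9 is installed:

    import torch
    from onnx_diagnostic.helpers.cache_helper import CacheKeyValue

    flat = [torch.rand(2, 4, 3, 8) for _ in range(4)]  # k0, v0, k1, v1
    capi = CacheKeyValue(flat)  # keys = flat[::2], values = flat[1::2]
    assert capi.n_layers == 2 and len(capi) == 4
    assert all(a is b for a, b in zip(capi.aslist(), flat))  # same interleaved order back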
@@ -164,12 +184,16 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):

      def make_dynamic_cache(
          key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+         cls_layers: Optional[Union[str, List[type]]] = None,
      ) -> transformers.cache_utils.DynamicCache:
          """
          Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
          This version is valid for ``transformers >= 4.50``.

          :param key_value_pairs: list of pairs of (key, values)
+         :param cls_layers: to select the appropriate class to use on each layer,
+             if specified, sliding_window is ignored, it can be a string
+             if all layers are expected to follow the same class
          :return: :class:`transformers.cache_utils.DynamicCache`

          Example:
@@ -200,15 +224,49 @@
          are supported.
          """
          key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
+         cls_kwargs = {}
+         if isinstance(cls_layers, str):
+             assert hasattr(
+                 transformers.cache_utils, cls_layers
+             ), f"Unable to find class {cls_layers!r} in transformers.cache_utils"
+             cls_layer = getattr(transformers.cache_utils, cls_layers)
+             if cls_layers == "DynamicSlidingWindowLayer":
+                 cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                 assert isinstance(
+                     cls_kwargs["sliding_window"], int
+                 ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+         elif cls_layers is not None:
+             unique = set(cls_layers)
+             assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}"
+             cls_layer = unique.pop()
+             if (
+                 hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer")
+                 and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer
+             ):
+                 from .helper import string_type
+
+                 assert key_value_pairs and key_value_pairs[0], (
+                     f"not implemented for key_value_pairs="
+                     f"{string_type(key_value_pairs, with_shape=True)}"
+                 )
+                 cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
+                 assert isinstance(
+                     cls_kwargs["sliding_window"], int
+                 ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
+         else:
+             cls_layer = (
+                 transformers.cache_utils.DynamicLayer
+                 if hasattr(transformers.cache_utils, "DynamicLayer")
+                 else None
+             )
+
          if (
              key_value_pairs
              and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
              and pv.Version(transformers.__version__) >= pv.Version("4.56")
          ):
              cache = transformers.cache_utils.DynamicCache()
-             cache.layers.extend(
-                 [transformers.cache_utils.DynamicLayer() for _ in key_value_pairs]
-             )
+             cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
              for i, layer in enumerate(cache.layers):
                  k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                  layer.dtype = k.dtype
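A usage sketch of the new cls_layers parameter; it assumes a transformers version whose transformers.cache_utils exposes DynamicLayer and DynamicSlidingWindowLayer (tensor shapes are illustrative):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    pairs = [(torch.rand(2, 4, 6, 8), torch.rand(2, 4, 6, 8)) for _ in range(2)]
    cache = make_dynamic_cache(pairs)  # default: every layer is a DynamicLayer
    # sliding-window layers; sliding_window is taken from the key shape (dimension 2)
    swa = make_dynamic_cache(pairs, cls_layers="DynamicSlidingWindowLayer")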
@@ -222,14 +280,21 @@
              )
              return finalize_cache(cache)

-         cache = transformers.cache_utils.DynamicCache(key_value_pairs)
-         if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
-             # The cache constructor contains the two following lines
-             # (in cache_utils.py) which append empty layers when the cache is
-             # initialized. We need to remove them.
-             # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
-             # self.append_new_layers(self.num_hidden_layers - 1)
-             cache.layers[:] = cache.layers[-len(key_value_pairs) :]
+         cache = transformers.cache_utils.DynamicCache()
+         if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer:
+             cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
+             for i, layer in enumerate(cache.layers):
+                 layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
+                 layer.is_initialized = True
+         else:
+             cache = transformers.cache_utils.DynamicCache(key_value_pairs)
+             if hasattr(cache, "layers") and len(key_value_pairs) < len(cache.layers):
+                 # The cache constructor contains the two following lines
+                 # (in cache_utils.py) which append empty layers when the cache is
+                 # initialized. We need to remove them.
+                 # self.num_hidden_layers = getattr(config, "num_hidden_layers", 1)
+                 # self.append_new_layers(self.num_hidden_layers - 1)
+                 cache.layers[:] = cache.layers[-len(key_value_pairs) :]
          assert not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers), (
              f"Unexpected number of layers in the cache ({len(cache.layers)}), "
              f"{len(key_value_pairs)} expected."
@@ -240,6 +305,7 @@ else:

      def make_dynamic_cache(
          key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
+         cls_layers: Optional[Union[str, List[type]]] = None,
      ) -> transformers.cache_utils.DynamicCache:
          """
          Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -271,6 +337,7 @@ else:
              )
              print(string_type(past_key_values, with_shape=True))
          """
+         assert not cls_layers, "cls_layers cannot be used for transformers<5."
          key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
          cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
          for i, (key, value) in enumerate(key_value_pairs):
@@ -516,9 +583,13 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
          )
          return finalize_cache(cache)

+     def get_make_hybrid_cache():
+         return make_sliding_window_cache
+
  else:
      make_sliding_window_cache = None  # type: ignore[assignment]

+

  if hasattr(transformers.cache_utils, "HybridCache"):
      def make_hybrid_cache(
@@ -680,9 +751,15 @@ if hasattr(transformers.cache_utils, "HybridCache"):
          )
          return finalize_cache(cache)

+     def get_make_hybrid_cache():
+         return make_hybrid_cache
+
  else:
      make_hybrid_cache = None  # type: ignore[assignment]

+     def get_make_hybrid_cache():
+         return None
+

  def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
      """
@@ -105,6 +105,8 @@ class FakeTensorContext:
          reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
              axis=tuple(sorted(sh)), keepdim=True
          )
+         if len(reduced_tensor.shape) == 0 == len(new_shape):
+             return reduced_tensor
          return reduced_tensor.expand(*new_shape)

      def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
@@ -144,19 +146,22 @@
          """
          See
          :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
+         If caches are used, it requires ``transformers>=4.57``.
          """
          if x is None:
              return None, None
-         if isinstance(x, (list, tuple)):
+         if type(x) in (list, tuple):
              return x.__class__(
                  [
                      self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
                      for i, ds in zip(x, dynamic_shapes)
                  ]
              )
-         if isinstance(x, dict):
+         if type(x) is dict:
              return {
-                 k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
+                 k: self.make_fake_with_dynamic_dimensions(
+                     v, dynamic_shapes=dynamic_shapes[k] if dynamic_shapes else None
+                 )
                  for k, v in x.items()
              }
          if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
@@ -187,6 +192,17 @@
                  x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
              )
              return x
+         if x.__class__.__name__ == "BaseModelOutput":
+             assert (
+                 list(x.keys()) == ["last_hidden_state"] and x.last_hidden_state is not None
+             ), (
+                 f"Field 'last_hidden_state' is empty for {type(x)} or other fields "
+                 f"{list(x.keys())} are used."
+             )
+             x.last_hidden_state = self.make_fake_with_dynamic_dimensions(
+                 x.last_hidden_state, dynamic_shapes=dynamic_shapes[0]
+             )
+             return x
          if hasattr(x, "shape"):
              assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
                  f"dynamic_shapes must be a dictionary at this stage but "
@@ -197,9 +213,11 @@
              for idim, dim in enumerate(x.shape):
                  if dynamic_shapes is not None and idim in dynamic_shapes:
                      s = dynamic_shapes[idim]
+                     if s.__class__.__name__ == "Dim":
+                         s = s.__name__
                      assert isinstance(s, str), (
                          f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
-                         f"at index {idim}"
+                         f"at index {idim}, self._mapping_str={self._mapping_str}"
                      )
                      if s in self._mapping_str:
                          dim = self._mapping_str[s]
@@ -217,10 +235,13 @@

              x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)

-             t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+             t = self.fake_reshape(x, dynamic_shapes) if dynamic_shapes else x  # type: ignore[arg-type]
              assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
              assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
              return t
+         if isinstance(x, (int, bool, float)):
+             # It is a constant, we don't change that.
+             return x
          from ..helpers import string_type

          raise TypeError(
@@ -1,7 +1,6 @@
  import ast
  import enum
  import inspect
- import itertools
  import json
  from dataclasses import is_dataclass, fields
  from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
@@ -704,9 +703,35 @@ def string_type(
      if obj.__class__.__name__ == "VirtualTensor":
          if verbose:
              print(f"[string_type] TT4:{type(obj)}")
+
+         def _torch_sym_int_to_str(value: "torch.SymInt") -> Union[int, str]:  # noqa: F821
+             if isinstance(value, str):
+                 return value
+             if hasattr(value, "node") and isinstance(value.node, str):
+                 return f"{value.node}"
+
+             from torch.fx.experimental.sym_node import SymNode
+
+             if hasattr(value, "node") and isinstance(value.node, SymNode):
+                 # '_expr' is safer than expr
+                 return str(value.node._expr).replace(" ", "")
+
+             try:
+                 val_int = int(value)
+                 return val_int
+             except (
+                 TypeError,
+                 ValueError,
+                 AttributeError,
+                 torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode,
+             ):
+                 pass
+
+             raise AssertionError(f"Unable to convert {value!r} into string")
+
          return (
              f"{obj.__class__.__name__}(name={obj.name!r}, "
-             f"dtype={obj.dtype}, shape={obj.shape})"
+             f"dtype={obj.dtype}, shape={tuple(_torch_sym_int_to_str(_) for _ in obj.shape)})"
          )

      if obj.__class__.__name__ == "KeyValuesWrapper":
@@ -775,6 +800,9 @@
              print(f"[string_type] TT8:{type(obj)}")
          return repr(obj).replace(" ", "").replace("\n", " ")

+     if isinstance(obj, torch.fx.proxy.Proxy):
+         return repr(obj)
+
      if ignore:
          if verbose:
              print(f"[string_type] CACHE4:{type(obj)}")
@@ -962,15 +990,17 @@ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
      if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
          from .cache_helper import CacheKeyValue

-         kc = CacheKeyValue(x)
-         return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
+         return CacheKeyValue(x).aslist()

      if x.__class__.__name__ == "EncoderDecoderCache":
-         res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
+         res = [
+             *flatten_object(x.self_attention_cache),
+             *flatten_object(x.cross_attention_cache),
+         ]
          return tuple(res)
      if x.__class__.__name__ == "MambaCache":
          if isinstance(x.conv_states, list):
-             res = flatten_object(x.conv_states) + flatten_object(x.ssm_states)
+             res = [*flatten_object(x.conv_states), *flatten_object(x.ssm_states)]
              return tuple(res)
          return (x.conv_states, x.ssm_states)
      if hasattr(x, "to_tuple"):
@@ -28,6 +28,7 @@ from onnx import (
      NodeProto,
      OperatorSetIdProto,
      TensorProto,
+     TypeProto,
      ValueInfoProto,
      load as onnx_load,
  )
@@ -385,6 +386,12 @@ def pretty_onnx(
          shape_str = ",".join(map(str, shape))
          return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}] {name}"

+     if isinstance(onx, TypeProto):
+         itype = onx.tensor_type.elem_type
+         shape = tuple((d.dim_param or d.dim_value) for d in onx.tensor_type.shape.dim)
+         shape_str = ",".join(map(str, shape))
+         return f"{onnx_dtype_name(itype, exc=False)}[{shape_str}]"
+
      if isinstance(onx, AttributeProto):
          att = onx
          if att.type == AttributeProto.INT:
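A quick sketch of the new TypeProto branch; the printed form is inferred from the code above, not verified against a run:

    from onnx import TensorProto
    from onnx.helper import make_tensor_type_proto
    from onnx_diagnostic.helpers.onnx_helper import pretty_onnx

    tp = make_tensor_type_proto(TensorProto.FLOAT, ["batch", 8])
    print(pretty_onnx(tp))  # expected to look like FLOAT[batch,8]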
@@ -1,3 +1,4 @@
+ import os
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  import onnx
  import numpy as np
@@ -76,6 +77,10 @@ class _InferenceSession:
              session_options.enable_profiling = enable_profiling
          if optimized_model_filepath:
              session_options.optimized_model_filepath = optimized_model_filepath
+             session_options.add_session_config_entry(
+                 "session.optimized_model_external_initializers_file_name",
+                 f"{os.path.splitext(os.path.split(optimized_model_filepath)[-1])[0]}.data",
+             )
          if log_severity_level is not None:
              session_options.log_severity_level = log_severity_level
          if log_verbosity_level is not None:
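For context, this is what the new session option amounts to in plain onnxruntime (file names are illustrative):

    import os
    import onnxruntime as ort

    optimized_model_filepath = "model.optimized.onnx"
    so = ort.SessionOptions()
    so.optimized_model_filepath = optimized_model_filepath
    # large initializers of the optimized model go to a side .data file instead of the .onnx
    so.add_session_config_entry(
        "session.optimized_model_external_initializers_file_name",
        f"{os.path.splitext(os.path.split(optimized_model_filepath)[-1])[0]}.data",
    )
    sess = ort.InferenceSession("model.onnx", so, providers=["CPUExecutionProvider"])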
@@ -41,7 +41,20 @@ def make_feeds(
      """
      # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
      # because it's fued into rotary embedding in GQA.
-     if is_modelbuilder and isinstance(inputs, dict):
+     if is_modelbuilder and isinstance(inputs, dict) and "position_ids" in inputs:
+         position_ids = inputs["position_ids"]  # type: ignore[valid-type]
+         # We just check position_ids are contiguous.
+         assert isinstance(position_ids, torch.Tensor) and (
+             (
+                 (position_ids - position_ids.min())
+                 == torch.tensor(list(range(position_ids.shape[-1]))).unsqueeze(0)
+             )
+             .max()
+             .item()
+         ), (
+             f"ModelBuilder does not support position_ids={position_ids}, "
+             f"inputs={string_type(inputs, with_shape=True)}"
+         )
          inputs.pop("position_ids", None)  # Ensure 'position_ids' absent before removing.

      flat = flatten_object(inputs, drop_keys=True)
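A small sketch of the contiguity property the new assertion looks at (illustrative tensors only):

    import torch

    position_ids = torch.tensor([[3, 4, 5, 6]])  # contiguous, possibly offset
    expected = torch.arange(position_ids.shape[-1]).unsqueeze(0)
    print(((position_ids - position_ids.min()) == expected).all())  # tensor(True)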
@@ -15,9 +15,6 @@ from .helper import string_type, size_type
  from .cache_helper import (
      make_dynamic_cache,
      make_encoder_decoder_cache,
-     make_hybrid_cache,
-     make_sliding_window_cache,
-     make_mamba_cache,
      make_static_cache,
      CacheKeyValue,
  )
@@ -769,10 +766,22 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
          return {to_any(t, to_value) for t in value}
      if type(value) is dict:
          return {k: to_any(t, to_value) for k, t in value.items()}
-     if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
-         make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
+     if value.__class__.__name__ == "DynamicCache":
          cc = CacheKeyValue(value)
-         return make[value.__class__.__name__](  # type: ignore[operator]
+         return make_dynamic_cache(
+             list(
+                 zip(
+                     [t.to(to_value) if t is not None else t for t in cc.key_cache],
+                     [t.to(to_value) if t is not None else t for t in cc.value_cache],
+                 )
+             ),
+             cls_layers=cc.cls_layers,
+         )
+     if value.__class__.__name__ == "HybridCache":
+         from .cache_helper import make_hybrid_cache
+
+         cc = CacheKeyValue(value)
+         return make_hybrid_cache(
              list(
                  zip(
                      [t.to(to_value) if t is not None else t for t in cc.key_cache],
@@ -843,7 +852,9 @@ def torch_deepcopy(value: Any) -> Any:
          from .cache_helper import CacheKeyValue

          ca = CacheKeyValue(value)
-         return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
+         return make_dynamic_cache(
+             torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
+         )
      if value.__class__.__name__ == "StaticCache":
          from .cache_helper import CacheKeyValue

@@ -858,12 +869,12 @@
              max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
          )
      if value.__class__.__name__ == "HybridCache":
-         from .cache_helper import CacheKeyValue
+         from .cache_helper import CacheKeyValue, make_hybrid_cache

          ca = CacheKeyValue(value)
          return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
      if value.__class__.__name__ == "SlidingWindowCache":
-         from .cache_helper import CacheKeyValue
+         from .cache_helper import CacheKeyValue, make_sliding_window_cache

          ca = CacheKeyValue(value)
          return make_sliding_window_cache(
@@ -875,6 +886,8 @@ def torch_deepcopy(value: Any) -> Any:
              torch_deepcopy(value.cross_attention_cache),
          )
      if value.__class__.__name__ == "MambaCache":
+         from .cache_helper import make_mamba_cache
+
          return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))

      if value.__class__ in torch.utils._pytree.SUPPORTED_NODES: