onnx-diagnostic 0.8.7__py3-none-any.whl → 0.8.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/ci_models/export_phi4_mm.py +1 -1
  3. onnx_diagnostic/doc.py +258 -8
  4. onnx_diagnostic/export/api.py +755 -5
  5. onnx_diagnostic/export/dynamic_shapes.py +61 -4
  6. onnx_diagnostic/export/shape_helper.py +1 -8
  7. onnx_diagnostic/helpers/cache_helper.py +98 -21
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +26 -5
  9. onnx_diagnostic/helpers/helper.py +36 -6
  10. onnx_diagnostic/helpers/onnx_helper.py +7 -0
  11. onnx_diagnostic/helpers/ort_session.py +5 -0
  12. onnx_diagnostic/helpers/rt_helper.py +14 -1
  13. onnx_diagnostic/helpers/torch_helper.py +22 -9
  14. onnx_diagnostic/tasks/image_text_to_text.py +8 -5
  15. onnx_diagnostic/tasks/text_generation.py +17 -17
  16. onnx_diagnostic/torch_export_patches/eval/__init__.py +1 -1
  17. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +62 -38
  18. onnx_diagnostic/torch_export_patches/patch_details.py +3 -3
  19. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +14 -5
  20. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
  21. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -9
  22. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +42 -30
  23. onnx_diagnostic/torch_models/validate.py +48 -0
  24. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/METADATA +3 -1
  25. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/RECORD +28 -28
  26. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.8.7.dist-info → onnx_diagnostic-0.8.9.dist-info}/top_level.txt +0 -0
@@ -1,7 +1,7 @@
  import itertools
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
- from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
+ from ..helpers.cache_helper import make_dynamic_cache, get_make_hybrid_cache
  from ..helpers.config_helper import (
      update_config,
      check_hasattr,
@@ -172,10 +172,10 @@ def _get_inputs_gemma3(
      assert expected & set(
          dummies
      ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
-     assert sequence_length == dummies["input_ids"].shape[-1], (
-         f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
-         f"model class {model.__class__.__name__}"
-     )
+     # assert sequence_length == dummies["input_ids"].shape[-1], (
+     #     f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+     #     f"model class {model.__class__.__name__}"
+     # )
      assert batch_size == dummies["input_ids"].shape[0], (
          f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
          f"model class {model.__class__.__name__}"
@@ -200,6 +200,9 @@ def _get_inputs_gemma3(

      _check_()

+     make_hybrid_cache = get_make_hybrid_cache()
+     assert make_hybrid_cache is not None, "not implemented when make_hybrid_cache is missing"
+
      inputs = dict(
          input_ids=dummies["input_ids"],
          token_type_ids=dummies["token_type_ids"],
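
Note: `get_make_hybrid_cache` replaces the direct import of `make_hybrid_cache` so this module still imports on transformers versions that do not ship `HybridCache`; the new assert shows the factory is expected to be `None` in that case. A minimal sketch of such a probe, assuming only that contract (the body is illustrative, not the package's exact implementation):

    def get_make_hybrid_cache():
        # return the concrete factory when transformers still ships HybridCache,
        # None otherwise, so callers can fail with a clear message
        try:
            from transformers.cache_utils import HybridCache  # noqa: F401
        except ImportError:
            return None
        return make_hybrid_cache  # the factory defined alongside
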
@@ -1,11 +1,6 @@
  from typing import Any, Callable, Dict, Optional, Tuple, Union
  import torch
- from ..helpers.cache_helper import (
-     make_dynamic_cache,
-     make_mamba_cache,
-     make_sliding_window_cache,
-     make_static_cache,
- )
+ from ..helpers.cache_helper import make_dynamic_cache, make_mamba_cache, make_static_cache
  from ..helpers.config_helper import (
      update_config,
      check_hasattr,
@@ -187,17 +182,22 @@ def get_inputs(
          if cls_cache is None or isinstance(cls_cache, str)
          else cls_cache.__name__
      )
-     make_caches = {
-         "DynamicCache": make_dynamic_cache,
-         "SlidingWindowCache": make_sliding_window_cache,
-         "StaticCache": make_static_cache,
-     }
-     assert cache_name is None or cache_name in make_caches, (
-         f"Unable to handle cls_cache={cache_name!r}, it should be in "
-         f"{sorted(make_caches)}"
-     )
-     make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
-     is_static = cache_name == "StaticCache"
+     if cache_name == "DynamicSlidingWindowCache":
+         from ..helpers.cache_helper import make_sliding_window_cache
+
+         make_cache = make_sliding_window_cache
+         is_static = False
+     else:
+         make_caches = {
+             "DynamicCache": make_dynamic_cache,
+             "StaticCache": make_static_cache,
+         }
+         assert cache_name is None or cache_name in make_caches, (
+             f"Unable to handle cls_cache={cache_name!r}, it should be in "
+             f"{sorted(make_caches)}"
+         )
+         make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]  # type: ignore[assignment]
+         is_static = cache_name == "StaticCache"

      if is_static:
          # static
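
Note: the dispatch above only selects a factory; a commented `f_check` later in this diff shows how `make_dynamic_cache` is called. A usage sketch (tensor sizes are arbitrary):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    # one (key, value) pair per layer
    cache = make_dynamic_cache([(torch.rand(4, 4, 4), torch.rand(4, 4, 4))])
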
@@ -521,7 +521,7 @@ def run_exporter(
      :param exporter: exporter
      :param cls_model: model class to create
      :param inputs: list of inputs to try
-     :param dynamic: use dynamic shape or not
+     :param dynamic: use dynamic shapes or not
      :param quiet: raise exception or not
      :param verbose: verbosity
      :return: results
@@ -7,15 +7,9 @@ import transformers
  from transformers.cache_utils import DynamicCache, StaticCache

  try:
-     from transformers.cache_utils import (
-         EncoderDecoderCache,
-         HybridCache,
-         SlidingWindowCache,
-     )
+     from transformers.cache_utils import EncoderDecoderCache
  except ImportError:
      EncoderDecoderCache = None
-     HybridCache = None
-     SlidingWindowCache = None
  from ..helpers import string_type
  from .serialization import _lower_name_with_

@@ -36,6 +30,24 @@ def get_mamba_cache_cls() -> type:
      return None


+ def get_hybrid_cache_cls() -> type:
+     try:
+         from transformers.cache_utils import HybridCache
+
+         return HybridCache
+     except ImportError:
+         return None
+
+
+ def get_sliding_window_cache_cls() -> type:
+     try:
+         from transformers.cache_utils import SlidingWindowCache
+
+         return SlidingWindowCache
+     except ImportError:
+         return None
+
+
  def register_class_serialization(
      cls,
      f_flatten: Callable,
@@ -179,18 +191,9 @@ def serialization_functions(
      flatten_dynamic_cache,
      unflatten_dynamic_cache,
      flatten_with_keys_dynamic_cache,
-     flatten_hybrid_cache,
-     unflatten_hybrid_cache,
-     flatten_with_keys_hybrid_cache,
-     flatten_mamba_cache,
-     unflatten_mamba_cache,
-     flatten_with_keys_mamba_cache,
      flatten_encoder_decoder_cache,
      unflatten_encoder_decoder_cache,
      flatten_with_keys_encoder_decoder_cache,
-     flatten_sliding_window_cache,
-     unflatten_sliding_window_cache,
-     flatten_with_keys_sliding_window_cache,
      flatten_static_cache,
      unflatten_static_cache,
      flatten_with_keys_static_cache,
@@ -208,14 +211,6 @@ def serialization_functions(
          # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
          verbose=verbose,
      ),
-     HybridCache: lambda verbose=verbose: register_class_serialization(
-         HybridCache,
-         flatten_hybrid_cache,
-         unflatten_hybrid_cache,
-         flatten_with_keys_hybrid_cache,
-         # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
-         verbose=verbose,
-     ),
      EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
          EncoderDecoderCache,
          flatten_encoder_decoder_cache,
@@ -223,13 +218,6 @@ def serialization_functions(
          flatten_with_keys_encoder_decoder_cache,
          verbose=verbose,
      ),
-     SlidingWindowCache: lambda verbose=verbose: register_class_serialization(
-         SlidingWindowCache,
-         flatten_sliding_window_cache,
-         unflatten_sliding_window_cache,
-         flatten_with_keys_sliding_window_cache,
-         verbose=verbose,
-     ),
      StaticCache: lambda verbose=verbose: register_class_serialization(
          StaticCache,
          flatten_static_cache,
@@ -240,6 +228,12 @@ def serialization_functions(
      }
      MambaCache = get_mamba_cache_cls()
      if MambaCache:
+         from .serialization.transformers_impl import (
+             flatten_mamba_cache,
+             unflatten_mamba_cache,
+             flatten_with_keys_mamba_cache,
+         )
+
          transformers_classes[MambaCache] = (
              lambda verbose=verbose: register_class_serialization(
                  MambaCache,
@@ -249,6 +243,42 @@ def serialization_functions(
              verbose=verbose,
          )
      )
+     HybridCache = get_hybrid_cache_cls()
+     if HybridCache:
+         from .serialization.transformers_impl import (
+             flatten_hybrid_cache,
+             unflatten_hybrid_cache,
+             flatten_with_keys_hybrid_cache,
+         )
+
+         transformers_classes[HybridCache] = (
+             lambda verbose=verbose: register_class_serialization(
+                 HybridCache,
+                 flatten_hybrid_cache,
+                 unflatten_hybrid_cache,
+                 flatten_with_keys_hybrid_cache,
+                 verbose=verbose,
+             )
+         )
+
+     SlidingWindowCache = get_sliding_window_cache_cls()
+     if SlidingWindowCache:
+         from .serialization.transformers_impl import (
+             flatten_sliding_window_cache,
+             unflatten_sliding_window_cache,
+             flatten_with_keys_sliding_window_cache,
+         )
+
+         transformers_classes[SlidingWindowCache] = (
+             lambda verbose=verbose: register_class_serialization(
+                 SlidingWindowCache,
+                 flatten_sliding_window_cache,
+                 unflatten_sliding_window_cache,
+                 flatten_with_keys_sliding_window_cache,
+                 verbose=verbose,
+             )
+         )
+
      classes.update(transformers_classes)

      if patch_diffusers:
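
Note: the flatten/unflatten/flatten_with_keys triplets registered here follow torch's pytree protocol, which is what lets `torch.export` traverse cache objects as inputs. A self-contained sketch of that protocol on recent torch versions, using a hypothetical `KVCache` stand-in rather than a real transformers class:

    import torch
    import torch.utils._pytree as pytree

    class KVCache:
        # hypothetical stand-in for a key/value cache container
        def __init__(self, keys, values):
            self.keys, self.values = keys, values

    def _flatten(cache):
        # children first, then an opaque context (unused here)
        return [cache.keys, cache.values], None

    def _unflatten(children, context):
        return KVCache(*children)

    pytree.register_pytree_node(KVCache, _flatten, _unflatten)

    leaves, spec = pytree.tree_flatten(KVCache(torch.rand(2, 3), torch.rand(2, 3)))
    rebuilt = pytree.tree_unflatten(leaves, spec)
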
@@ -303,13 +333,7 @@ def unregister_class_serialization(cls: type, verbose: int = 0):


  def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
-     """Undo all registrations."""
-     MambaCache = get_mamba_cache_cls()
-     cls_ensemble = (
-         {DynamicCache, EncoderDecoderCache}
-         | set(undo)
-         | ({MambaCache} if MambaCache else set())
-     )
+     cls_ensemble = {DynamicCache, EncoderDecoderCache} | set(undo)
      for cls in cls_ensemble:
          if undo.get(cls.__name__, False):
              unregister_class_serialization(cls, verbose)
@@ -191,7 +191,7 @@ class PatchDetails:
          ep = torch.export.export(
              model, (), kwargs=inputs, dynamic_shapes=use_dyn_not_str(ds)
          )
-         patches = details.patches_involded_in_graph(ep.graph)
+         patches = details.patches_involved_in_graph(ep.graph)
          report = details.make_report(patches, format="rst")
          print(report)
      """
@@ -235,7 +235,7 @@ class PatchDetails:
          """Returns the data for a dataframe."""
          return [p.to_dict() for p in self.patched]

-     def patches_involded_in_graph(
+     def patches_involved_in_graph(
          self, graph: "torch.fx.Graph"  # noqa: F821
      ) -> List[Tuple[PatchInfo, List["torch.fx.Node"]]]:  # noqa: F821
          """
@@ -322,7 +322,7 @@ class PatchDetails:
          """
          Creates a report based on the involved patches.

-         :param patches: from method :meth:`patches_involded_in_graph`
+         :param patches: from method :meth:`patches_involved_in_graph`
          :param format: format of the report
          :return: report
          """
@@ -22,13 +22,22 @@ if patch_DynamicLayer:
          _PATCHES_ = ["lazy_initialization"]
          _PATCHED_CLASS_ = DynamicLayer

-         def lazy_initialization(self, key_states: torch.Tensor):
+         def lazy_initialization(
+             self, key_states: torch.Tensor, value_states: torch.Tensor = None
+         ):
              self.dtype, self.device = key_states.dtype, key_states.device
-             new_shape = list(key_states.shape)
-             new_shape[-2] = 0
+             assert (
+                 hasattr(key_states, "shape") and key_states is not None
+             ), f"Attribute 'shape' is wrong for type {type(key_states)}"
+             like = torch.narrow(key_states, dim=-2, start=0, length=0)
              # PATCHED: used a tensor with an empty shape and not en empty list to initialize
-             self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
-             self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+             if isinstance(key_states, torch._subclasses.fake_tensor.FakeTensor):
+                 with key_states.fake_mode:
+                     self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                     self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
+             else:
+                 self.keys = torch.empty_like(like, dtype=self.dtype, device=self.device)
+                 self.values = torch.empty_like(like, dtype=self.dtype, device=self.device)
              if patch_is_initialized:
                  self.is_initialized = True

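Note: the patched `lazy_initialization` derives the empty key/value buffers from the incoming tensor instead of a Python shape list; `torch.narrow(..., length=0)` zeroes the sequence axis while preserving every other dimension (including symbolic ones during export), and `torch.empty_like` then inherits that metadata. A quick eager-mode illustration:

    import torch

    key_states = torch.rand(2, 4, 7, 16)  # (batch, heads, seq_len, head_dim)
    like = torch.narrow(key_states, dim=-2, start=0, length=0)
    keys = torch.empty_like(like, dtype=key_states.dtype, device=key_states.device)
    print(keys.shape)  # torch.Size([2, 4, 0, 16]): seq axis empty, the rest preserved
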
@@ -214,7 +214,7 @@ def patched_dynamic_rope_update(rope_forward):
              cond,
              (lambda x, y: x.clone()),
              (lambda x, y: y.clone()),
-             [long_inv_freq, original_inv_freq],
+             [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
          )
          setattr(self, f"{prefix}inv_freq", inv_freq)
          # if seq_len > original_max_position_embeddings:
@@ -293,7 +293,7 @@ def patched_dynamic_rope_update(rope_forward):
              cond,
              (lambda x, y: x.clone()),
              (lambda x, y: y.clone()),
-             [long_inv_freq, original_inv_freq],
+             [long_inv_freq.to(original_inv_freq.dtype), original_inv_freq],
          )
          setattr(self, f"{prefix}inv_freq", inv_freq)

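Note: both rope hunks apply the same fix: when exported, `torch.cond` requires its two branches to return tensors with matching metadata, and the freshly computed `long_inv_freq` may not share `original_inv_freq`'s dtype. A minimal reproduction of the aligned form (values are placeholders):

    import torch

    a = torch.linspace(0, 1, 8, dtype=torch.float64)  # e.g. a recomputed inv_freq
    b = torch.linspace(0, 1, 8, dtype=torch.float32)  # e.g. the buffer kept on the module
    inv_freq = torch.cond(
        torch.tensor(True),
        (lambda x, y: x.clone()),
        (lambda x, y: y.clone()),
        [a.to(b.dtype), b],  # cast so both branches yield the same dtype
    )
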
@@ -541,14 +541,17 @@ class patched_ShapeEnv:
          # oblivious_var_to_val will be defined iff we have sizes
          # with DimDynamic.OBLIVIOUS_SIZE type.
          # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113
+         var_to_val = getattr(
+             self,
+             "unbacked_var_to_val",
+             getattr(self, "oblivious_var_to_val", False),
+         )
          if (
-             self.oblivious_var_to_val
-             and not (
-                 correct_hint := orig_expr.xreplace(self.oblivious_var_to_val)
-             ).free_symbols
+             var_to_val
+             and not (correct_hint := orig_expr.xreplace(var_to_val)).free_symbols
              and not (
                  counterfactual_hint := orig_expr.xreplace(
-                     {k: max(2, v) for k, v in self.oblivious_var_to_val.items()}
+                     {k: max(2, v) for k, v in var_to_val.items()}
                  )
              ).free_symbols
              and correct_hint == counterfactual_hint
@@ -571,11 +574,11 @@ class patched_ShapeEnv:
          # and if they pass we add a runtime assertions and continue.
          if (
              not ok
-             and self.unbacked_var_to_val
+             and var_to_val
              and not (
-                 unsound_result := orig_expr.xreplace(
-                     self.unbacked_var_to_val
-                 ).xreplace(self.var_to_val)
+                 unsound_result := orig_expr.xreplace(var_to_val).xreplace(
+                     var_to_val
+                 )
              ).free_symbols
          ):
              # pyrefly: ignore # unbound-name
@@ -1,13 +1,7 @@
  import itertools
  from typing import Any, Callable, List, Set, Tuple
  import torch
- from transformers.cache_utils import (
-     Cache,
-     DynamicCache,
-     EncoderDecoderCache,
-     HybridCache,
-     StaticCache,
- )
+ from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache

  try:
      from transformers.cache_utils import SlidingWindowCache
@@ -15,18 +9,17 @@ except ImportError:
      SlidingWindowCache = None


+ try:
+     from transformers.cache_utils import HybridCache
+ except ImportError:
+     HybridCache = None
+
  try:
      from transformers.models.mamba.modeling_mamba import MambaCache
  except ImportError:
      from transformers.cache_utils import MambaCache
  from transformers.modeling_outputs import BaseModelOutput
- from ...helpers.cache_helper import (
-     make_dynamic_cache,
-     make_hybrid_cache,
-     make_sliding_window_cache,
-     make_static_cache,
-     CacheKeyValue,
- )
+ from ...helpers.cache_helper import make_dynamic_cache, make_static_cache, CacheKeyValue
  from . import make_serialization_function_for_dataclass

@@ -78,6 +71,14 @@ def flatten_dynamic_cache(
      dynamic_cache: DynamicCache,
  ) -> Tuple[List[Any], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+     assert (
+         not hasattr(dynamic_cache, "layers")
+         or not dynamic_cache.layers
+         or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
+     ), (
+         f"The serialization does not work yet on other layers "
+         f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
+     )
      return _flatten_key_value_cache(dynamic_cache)

@@ -85,6 +86,14 @@ def flatten_with_keys_dynamic_cache(
      dynamic_cache: DynamicCache,
  ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
      """Serializes a :class:`transformers.cache_utils.DynamicCache` with python objects."""
+     assert (
+         not hasattr(dynamic_cache, "layers")
+         or not dynamic_cache.layers
+         or all(lay.__class__.__name__ == "DynamicLayer" for lay in dynamic_cache.layers)
+     ), (
+         f"The serialization does not work yet on other layers "
+         f"than DynamicLayer, but layers={[lay.__class__ for lay in dynamic_cache.layers]}"
+     )
      return _flatten_with_keys_cache(dynamic_cache)


@@ -99,26 +108,27 @@ def unflatten_dynamic_cache(
  # HybridCache
  #############

+ if HybridCache:

- def flatten_hybrid_cache(
-     cache: HybridCache,
- ) -> Tuple[List[Any], torch.utils._pytree.Context]:
-     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-     return _flatten_key_value_cache(cache)
-
+     def flatten_hybrid_cache(
+         cache: HybridCache,
+     ) -> Tuple[List[Any], torch.utils._pytree.Context]:
+         """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+         return _flatten_key_value_cache(cache)

- def flatten_with_keys_hybrid_cache(
-     cache: HybridCache,
- ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
-     """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
-     return _flatten_with_keys_cache(cache)
+     def flatten_with_keys_hybrid_cache(
+         cache: HybridCache,
+     ) -> Tuple[List[Tuple[torch.utils._pytree.KeyEntry, Any]], torch.utils._pytree.Context]:
+         """Serializes a :class:`transformers.cache_utils.HybridCache` with python objects."""
+         return _flatten_with_keys_cache(cache)

+     def unflatten_hybrid_cache(
+         values: List[Any], context: torch.utils._pytree.Context, output_type=None
+     ) -> HybridCache:
+         """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
+         from ...helpers.cache_helper import make_hybrid_cache

- def unflatten_hybrid_cache(
-     values: List[Any], context: torch.utils._pytree.Context, output_type=None
- ) -> HybridCache:
-     """Restores a :class:`transformers.cache_utils.HybridCache` from python objects."""
-     return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)
+         return _unflatten_cache(make_hybrid_cache, values, context, output_type=output_type)


  #############
@@ -190,6 +200,8 @@ if SlidingWindowCache:
          Restores a :class:`transformers.cache_utils.SlidingWindowCache`
          from python objects.
          """
+         from ...helpers.cache_helper import make_sliding_window_cache
+
          return _unflatten_cache(
              make_sliding_window_cache, values, context, output_type=output_type
          )
@@ -1771,6 +1771,10 @@ def validate_onnx_model(
      if os.environ.get("DUMPORTOPT", "") in ("1", "true", "True"):
          opts = onnxruntime.SessionOptions()
          opts.optimized_model_filepath = f"{data['onnx_filename']}.rtopt.onnx"
+         opts.add_session_config_entry(
+             "session.optimized_model_external_initializers_file_name",
+             f"{os.path.split(data['onnx_filename'])[0]}.rtopt.data",
+         )
          if verbose:
              print(
                  f"[validate_onnx_model] saved optimized onnxruntime "
@@ -2326,6 +2330,7 @@ def call_torch_export_custom(
          "custom-dec",
          "custom-decall",
          "custom-fake",
+         "custom-tracing",
      }
      assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
      assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -2338,11 +2343,16 @@ def call_torch_export_custom(
          f"Options strict cannot be specified in the exporter name {exporter!r} "
          f"and in the options {exporter_options}"
      )
+     assert ("-tracing" not in exporter) or ("tracing" not in exporter_options), (
+         f"Options tracing cannot be specified in the exporter name {exporter!r} "
+         f"and in the options {exporter_options}"
+     )
      summary: Dict[str, Union[str, int, float]] = {}
      strict = "-strict" in exporter or exporter_options.pop("strict", False)
      args, kwargs = split_args_kwargs(data["inputs_export"])
      ds = data.get("dynamic_shapes", None)
      fake = "-fake" in exporter or exporter_options.pop("fake", False)
+     tracing = "-tracing" in exporter or exporter_options.pop("tracing", False)
      if fake:
          from onnx_diagnostic.export.shape_helper import make_fake_with_dynamic_dimensions

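Note: `custom-tracing` follows the same convention as `-strict` and `-fake`: a flag may be spelled either as a suffix of the exporter name or as an entry of `exporter_options`, never both. The idiom in isolation:

    exporter, exporter_options = "custom-tracing", {}

    # reject the ambiguous case before consuming the option
    assert ("-tracing" not in exporter) or ("tracing" not in exporter_options)
    tracing = "-tracing" in exporter or exporter_options.pop("tracing", False)
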
@@ -2366,6 +2376,7 @@ def call_torch_export_custom(
      summary["export_exporter"] = exporter
      summary["export_optimization"] = optimization or ""
      summary["export_strict"] = strict
+     summary["export_tracing"] = tracing
      summary["export_fake"] = fake
      summary["export_args"] = string_type(args, with_shape=True)
      summary["export_kwargs"] = string_type(kwargs, with_shape=True)
@@ -2388,6 +2399,7 @@ def call_torch_export_custom(
          )
      )
      large_model = bool(exporter_options.pop("large_model", True))
+     exporter_options.pop("tracing", False)
      return_optimize_report = bool(exporter_options.pop("return_optimize_report", True))
      export_modules_as_functions = bool(
          exporter_options.pop("export_modules_as_functions", False)
@@ -2401,6 +2413,7 @@ def call_torch_export_custom(
      summary["export_external_threshold"] = str(external_threshold)

      export_options = ExportOptions(
+         tracing=tracing,
          strict=strict,
          decomposition_table=decomposition_table,
          save_ep=(
@@ -2445,6 +2458,41 @@ def call_torch_export_custom(
              )
          ),
      )
+     if "optimization" in opt_stats and dump_folder:
+         import pandas
+
+         pattern_stats = []
+         for k, v in opt_stats.items():
+             if "time" in k:
+                 pattern_stats.append(dict(level="main", pattern=k, time_in=v))
+         pattern_stats.extend(
+             [{**obs, "level": "detailed"} for obs in opt_stats["optimization"]]
+         )
+         stat_filename = os.path.join(dump_folder, "optimization_stats.xlsx")
+         df = pandas.DataFrame(pattern_stats)
+         df.to_excel(stat_filename, index=False)
+         cols = [
+             c
+             for c in [
+                 "level",
+                 "pattern",
+                 "time_in",
+                 "iteration",
+                 "inlined",
+                 "removed",
+                 "added",
+                 "instances",
+                 "changed",
+                 "scale",
+             ]
+             if c in df.columns
+         ]
+         agg = {k: "sum" for k in cols if k not in ("level", "pattern")}
+         agg.update(dict(iteration="max", instances="mean"))
+         agg = {k: v for k, v in agg.items() if k in df.columns}
+         stat_filename = os.path.join(dump_folder, "optimization_stats.agg.xlsx")
+         df[cols].groupby(["level", "pattern"]).agg(agg).to_excel(stat_filename)
+
      if "ERR_export_onnx_c" in summary:
          return summary, data

@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: onnx-diagnostic
- Version: 0.8.7
+ Version: 0.8.9
  Summary: Tools to help converting pytorch models into ONNX.
  Home-page: https://github.com/sdpython/onnx-diagnostic
  Author: Xavier Dupré
@@ -90,6 +90,8 @@ Enlightening Examples

  * `Export microsoft/phi-2
    <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
+ * `Export a LLM through method generate (with Tiny-LLM)
+   <https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_method_generate.html>`_

  **Torch Export**