onnx-diagnostic 0.7.13__py3-none-any.whl → 0.7.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.7.13"
+__version__ = "0.7.15"
 __author__ = "Xavier Dupré"
@@ -400,12 +400,17 @@ def get_parser_validate() -> ArgumentParser:
 
             position_ids is usually not needed; it can be removed by adding:
 
-                --drop position_ids
+                --drop position_ids
 
             The behaviour may be modified compared to the original configuration;
             the following argument sets rope_scaling to dynamic:
 
-                --mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
+                --mop \"rope_scaling={'rope_type': 'dynamic', 'factor': 10.0}\""
+
+            You can profile the command line by running:
+
+                pyinstrument -m onnx_diagnostic validate ...
+                pyinstrument -r html -o profile.html -m onnx_diagnostic validate ...
             """
         ),
         formatter_class=RawTextHelpFormatter,
@@ -548,6 +553,12 @@ def get_parser_validate() -> ArgumentParser:
         action=BooleanOptionalAction,
         help="Enables onnxruntime logging when the session is created",
     )
+    parser.add_argument(
+        "--quiet-input-sets",
+        default="",
+        help="Avoids raising an exception when an input set does not work with "
+        "the exported model, example: --quiet-input-sets=inputs,inputs22",
+    )
     return parser
 
 
@@ -609,6 +620,7 @@ def _cmd_validate(argv: List[Any]):
         warmup=args.warmup,
         inputs2=args.inputs2,
         ort_logs=args.ort_logs,
+        quiet_input_sets=set(args.quiet_input_sets.split(",")),
         output_names=(
             None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
         ),
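Taken together, the two hunks above wire the new flag into the validation entry point. A hypothetical invocation (the `...` stands for the usual model arguments, as in the epilog above) would be:

    python -m onnx_diagnostic validate ... --quiet-input-sets=inputs,inputs22

One subtlety of the new parsing line, since the flag defaults to the empty string: `"".split(",")` returns `[""]`, so the default yields the set `{""}` rather than an empty set. This is harmless as long as no input set is named with the empty string:

    >>> set("".split(","))
    {''}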
@@ -829,7 +841,7 @@ def get_parser_agg() -> ArgumentParser:
         "n_model_pass,n_model_faster,"
         "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
         "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
-        "n_node_layer_normalization,n_node_layer_normalization23,"
+        "n_node_gqa,n_node_layer_normalization,n_node_layer_normalization23,"
         "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
         "n_node_constant,n_node_shape,n_node_expand,"
         "n_node_function,n_node_initializer,n_node_scatter,"
@@ -108,7 +108,7 @@ def flatten_unflatten_for_dynamic_shapes(
 
 def is_cache_dynamic_registered(fast: bool = False) -> bool:
     """
-    Tells class :class:`transformers.cache_utils.DynamicCache` can be
+    Tells if class :class:`transformers.cache_utils.DynamicCache` can be
     serialized and deserialized. Only then, :func:`torch.export.export`
     can export a model.
 
@@ -95,7 +95,8 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
     mod_name = cls.__module__
     mod = importlib.import_module(mod_name)
     source = inspect.getsource(mod)
-    reg = re.compile("config: ([A-Za-z0-9]+)")
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
     fall = reg.findall(source)
     if len(fall) == 0:
         assert not exc, (
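A quick check of what the tightened pattern does and does not capture; this is a standalone sketch, not code from the package:

    import re

    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
    source = "config: LlamaConfig\n    config: Optional[LlamaConfig]"
    # The [^O] guard rejects annotations starting with 'O', so the
    # Optional[...] form is skipped and only the plain class name is kept.
    print(reg.findall(source))  # ['LlamaConfig']

Note the guard also skips any config class whose own name starts with 'O', a trade-off the comment in the hunk accepts implicitly.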
@@ -1167,7 +1167,7 @@ class CubeLogs:
             df.to_excel(
                 writer,
                 sheet_name=name,
-                freeze_panes=(df.columns.nlevels + df.index.nlevels, df.index.nlevels),
+                freeze_panes=(df.columns.nlevels + 1, df.index.nlevels),
             )
             f_highlights[name] = tview.f_highlight
             if tview.plots:
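For context on the freeze_panes fix (here and in the two later hunks touching sbs_raw and sbs_agg): pandas' `DataFrame.to_excel` takes `freeze_panes=(rows, cols)`, the count of top rows and left columns to keep pinned while scrolling. The `+ 1` presumably accounts for the extra spacer row pandas writes under a multi-level column header, instead of freezing `index.nlevels` data rows as before. A minimal sketch, assuming openpyxl is installed:

    import pandas as pd

    df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
    # Freeze the header row(s) plus one extra row, and the index column(s).
    df.to_excel("out.xlsx", freeze_panes=(df.columns.nlevels + 1, df.index.nlevels))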
@@ -1210,7 +1210,7 @@ class CubeLogs:
                 for k, v in sbs.items():
                     print(f"[CubeLogs.to_excel] sbs {k}: {v}")
             name = "∧".join(sbs)
-            sbs_raw, sbs_agg = self.sbs(sbs)
+            sbs_raw, sbs_agg, sbs_col = self.sbs(sbs)
             if verbose:
                 print(f"[CubeLogs.to_excel] add sheet {name!r} with shape {sbs_raw.shape}")
                 print(
@@ -1222,7 +1222,7 @@ class CubeLogs:
                 writer,
                 sheet_name=name,
                 freeze_panes=(
-                    sbs_raw.columns.nlevels + sbs_raw.index.nlevels,
+                    sbs_raw.columns.nlevels + 1,
                     sbs_raw.index.nlevels,
                 ),
             )
@@ -1230,10 +1230,18 @@ class CubeLogs:
                 writer,
                 sheet_name=f"{name}-AGG",
                 freeze_panes=(
-                    sbs_agg.columns.nlevels + sbs_agg.index.nlevels,
+                    sbs_agg.columns.nlevels + 1,
                     sbs_agg.index.nlevels,
                 ),
             )
+            sbs_col.to_excel(
+                writer,
+                sheet_name=f"{name}-COL",
+                freeze_panes=(
+                    sbs_col.columns.nlevels + 1,
+                    sbs_col.index.nlevels,
+                ),
+            )
 
         if plots:
             from openpyxl.drawing.image import Image
@@ -1314,7 +1322,7 @@ class CubeLogs:
 
     def sbs(
         self, configs: Dict[str, Dict[str, Any]], column_name: str = "CONF"
-    ) -> Tuple[pandas.DataFrame, pandas.DataFrame]:
+    ) -> Tuple[pandas.DataFrame, pandas.DataFrame, pandas.DataFrame]:
         """
         Creates a side-by-side for two configurations.
         Every configuration is a dictionary column:value which filters in
@@ -1325,7 +1333,7 @@ class CubeLogs:
         :param configs: example
             ``dict(CFA=dict(exporter="E1", opt="O"), CFB=dict(exporter="E2", opt="O"))``
         :param column_name: column to add with the name of the configuration
-        :return: data and aggregated date
+        :return: data, aggregated data, and data with a row per model
         """
         assert (
             len(configs) >= 2
@@ -1433,6 +1441,8 @@ class CubeLogs:
                 _mkc(m, f"{n1}<{n2}"): (si < sj).astype(int),
                 _mkc(m, f"{n1}=={n2}"): (si == sj).astype(int),
                 _mkc(m, f"{n1}>{n2}"): (si > sj).astype(int),
+                _mkc(m, f"{n1}*({n1}∧{n2})"): si * (~sinan & ~sjnan).astype(float),
+                _mkc(m, f"{n2}*({n1}∧{n2})"): sj * (~sinan & ~sjnan).astype(float),
             }
         )
         nas.columns.names = view_res.columns.names
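The two new columns multiply each side by an indicator that both configurations produced a value, so later aggregates compare the two configurations on the same set of rows. A small standalone illustration, with si/sj standing for the two metric series and sinan/sjnan for their isna() masks as in the surrounding code:

    import pandas as pd

    si = pd.Series([1.0, 2.0, None])
    sj = pd.Series([4.0, None, 6.0])
    sinan, sjnan = si.isna(), sj.isna()
    both = (~sinan & ~sjnan).astype(float)  # 1.0 where both sides have a value
    # Rows missing on the other side become 0 and drop out of a sum;
    # rows missing on this side stay NaN (NaN * 0 is NaN).
    print(si * both)  # [1.0, 0.0, NaN]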
@@ -1452,13 +1462,11 @@ class CubeLogs:
         }
         flat = view_res.groupby(self.time).agg(aggs)
         flat = flat.stack("METRICS", future_stack=True)
-        return res, flat
+        return res, flat, view_res.T.sort_index().T
 
 
 class CubeLogsPerformance(CubeLogs):
-    """
-    Processes logs coming from experiments.
-    """
+    """Processes logs coming from experiments."""
 
     def __init__(
         self,
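The third returned frame, `view_res.T.sort_index().T`, uses a double-transpose idiom that sorts the columns; `sort_index(axis=1)` would achieve the same result without the two copies:

    import pandas as pd

    df = pd.DataFrame([[1, 2]], columns=["b", "a"])
    print(df.T.sort_index().T.columns.tolist())    # ['a', 'b']
    print(df.sort_index(axis=1).columns.tolist())  # ['a', 'b']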
@@ -1511,20 +1519,25 @@ class CubeLogsPerformance(CubeLogs):
             "n_model_faster2x",
             "n_model_faster3x",
             "n_model_faster4x",
+            "n_model_faster5x",
             "n_node_attention",
             "n_node_attention23",
-            "n_node_rotary_embedding",
-            "n_node_rotary_embedding23",
-            "n_node_layer_normalization",
-            "n_node_layer_normalization23",
+            "n_node_causal_mask",
+            "n_node_constant",
             "n_node_control_flow",
-            "n_node_scatter",
+            "n_node_expand",
             "n_node_function",
+            "n_node_gqa",
             "n_node_initializer",
             "n_node_initializer_small",
-            "n_node_constant",
+            "n_node_layer_normalization",
+            "n_node_layer_normalization23",
+            "n_node_reshape",
+            "n_node_rotary_embedding",
+            "n_node_rotary_embedding23",
+            "n_node_scatter",
+            "n_node_sequence",
             "n_node_shape",
-            "n_node_expand",
             "onnx_n_nodes_no_cst",
             "peak_gpu_torch",
             "peak_gpu_nvidia",
@@ -1690,6 +1703,11 @@ class CubeLogsPerformance(CubeLogs):
                 "time_latency",
                 gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
             ),
+            n_model_faster5x=lambda df: gpreserve(
+                df,
+                "time_latency",
+                gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 4.98,
+            ),
             n_node_attention23=lambda df: gpreserve(
                 df, "time_latency_eager", gdf(df, "op_onnx__Attention")
             ),
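Two details of the new n_model_faster5x formula mirror the existing 2x/3x/4x ones: the np.inf fallback makes a missing time_latency fail the comparison, so models without a measurement are never counted as faster, and the 4.98 threshold (rather than 5.0) presumably leaves a little slack for timing noise just below the round multiple.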
@@ -1720,6 +1738,11 @@ class CubeLogsPerformance(CubeLogs):
                 + gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
                 + gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
             ),
+            n_node_gqa=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0),
+            ),
             n_node_layer_normalization=lambda df: gpreserve(
                 df,
                 "time_latency_eager",
@@ -1764,9 +1787,22 @@ class CubeLogsPerformance(CubeLogs):
             n_node_shape=lambda df: gpreserve(
                 df, "time_latency_eager", gdf(df, "op_onnx__Shape")
             ),
+            n_node_reshape=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__Reshape")
+            ),
             n_node_expand=lambda df: gpreserve(
                 df, "time_latency_eager", gdf(df, "op_onnx__Expand")
             ),
+            n_node_causal_mask=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__CausalMask", 0),
+            ),
+            n_node_sequence=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__SequenceAt", 0) + gdf(df, "op_onnx__SplitToSequence", 0),
+            ),
         )
         assert (
             formula in lambdas
@@ -3,8 +3,6 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .torch_helper import to_numpy
-from .cache_helper import is_cache_dynamic_registered
 
 
 def name_type_to_onnx_dtype(name: str) -> int:
@@ -49,7 +47,7 @@ def make_feeds(
     assert (
         not check_flatten
         or not all(isinstance(obj, torch.Tensor) for obj in flat)
-        or not is_cache_dynamic_registered(fast=True)
+        # or not is_cache_dynamic_registered(fast=True)
         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
     ), (
         f"Unexpected number of flattened objects, "
@@ -57,6 +55,8 @@ def make_feeds(
         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
     )
     if use_numpy:
+        from .torch_helper import to_numpy
+
         flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
     names = (
         [i.name for i in proto.graph.input]
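Moving the to_numpy import inside the use_numpy branch (matching its removal from the module header in the earlier hunk) defers the dependency until it is actually needed, a common way to cut import time or break an import cycle. A minimal standalone sketch of the pattern, not package code:

    def mean(values, use_numpy=False):
        if use_numpy:
            # Deferred import: the module is only loaded on the branch
            # that needs it, avoiding its import cost (or a cycle) otherwise.
            import numpy as np

            return float(np.mean(values))
        return sum(values) / len(values)

    print(mean([1.0, 2.0, 3.0]))                   # 2.0
    print(mean([1.0, 2.0, 3.0], use_numpy=True))   # 2.0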
@@ -186,12 +186,13 @@ def _get_inputs_gemma3(
         f"total_sequence_length={total_sequence_length} != 860 "
         f"for model {model.__class__.__name__}"
     )
-    assert (
-        head_dim == 256
-    ), f"head_dim={head_dim} != 256 for model {model.__class__.__name__}"
+    assert head_dim in (
+        256,
+        32,
+    ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}"
     assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
-    assert num_key_value_heads == 4, (
-        f"num_key_value_heads={num_key_value_heads} != 256 "
+    assert num_key_value_heads in (1, 4), (
+        f"num_key_value_heads={num_key_value_heads} not in (1, 4) "
         f"for this model {model.__class__.__name__}"
     )
 
@@ -19,6 +19,9 @@ __TASK__ = "text-generation"
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     check_hasattr(
         config,
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
@@ -284,6 +287,21 @@ def get_inputs(
         add_second_input=0,
         **kwargs,
     )["inputs"]
+    res["inputs_batch1"] = get_inputs(
+        model=model,
+        config=config,
+        dummy_max_token_id=dummy_max_token_id,
+        num_hidden_layers=num_hidden_layers,
+        batch_size=1,
+        sequence_length=sequence_length,
+        sequence_length2=sequence_length2,
+        dynamic_rope=dynamic_rope,
+        num_key_value_heads=num_key_value_heads,
+        head_dim=head_dim,
+        cls_cache=cls_cache,
+        add_second_input=0,
+        **kwargs,
+    )["inputs"]
     return res
 
 
@@ -293,6 +311,9 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    if hasattr(config, "text_config"):
+        # The model is probably a mixture of models used only for text.
+        config = config.text_config
     if config is not None:
         check_hasattr(
             config,
@@ -676,7 +676,13 @@ def run_exporter(
 
     if dynamic and len(inputs) > 1:
         for index, i in enumerate(inputs):
-            expected = model(*_clone(i))
+            if quiet:
+                try:
+                    expected = model(*_clone(i))
+                except Exception as e:
+                    return dict(error=str(e), success=0, error_step=f"run0.{index}")
+            else:
+                expected = model(*_clone(i))
             try:
                 got = mod(*i)
             except Exception as e:
@@ -353,12 +353,9 @@ class ControlFlowCondNonZero(torch.nn.Module):
 
 
 class ControlFlowCondIdentity_153832(torch.nn.Module):
-    """
-    `#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
-    """
+    """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
 
     def forward(self, x, y):
-
         def branch_cond_then_1(x):
             x = torch.abs(x) + 1
             return x
@@ -340,6 +340,7 @@ def torch_export_patches(
     ###############
 
     if patch_torch:
+        from torch.fx.experimental.symbolic_shapes import ShapeEnv
         from .patches.patch_torch import (
             patched_infer_size,
             patched_vmap,
@@ -347,6 +348,9 @@ def torch_export_patches(
             patched__constrain_user_specified_dimhint_range,
             _catch_produce_guards_and_solve_constraints,
             patch__check_input_constraints_for_graph,
+            patched__broadcast_in_dim_meta,
+            patched__maybe_broadcast,
+            patched_ShapeEnv,
         )
 
         if verbose:
@@ -383,6 +387,20 @@ def torch_export_patches(
             patched__constrain_user_specified_dimhint_range
         )
 
+        # torch._prims._broadcast_in_dim_meta
+        f_broadcast_in_dim = torch._prims.broadcast_in_dim
+        f__broadcast_in_dim_meta = torch._prims._broadcast_in_dim_meta
+        torch._prims._broadcast_in_dim_meta = patched__broadcast_in_dim_meta
+        torch._prims.broadcast_in_dim = patched__broadcast_in_dim_meta
+
+        # torch._refs._maybe_broadcast
+        f__maybe_broadcast = torch._refs._maybe_broadcast
+        torch._refs._maybe_broadcast = patched__maybe_broadcast
+
+        # ShapeEnv
+        f_shape_env__evaluate_expr = ShapeEnv._evaluate_expr
+        ShapeEnv._evaluate_expr = patched_ShapeEnv._evaluate_expr
+
     # torch._export.non_strict_utils.produce_guards_and_solve_constraints
     if patch_torch and catch_constraints:
         if verbose:
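Each patched callable is first saved under an f_-prefixed name so the restore hunk further down can undo the patch. A minimal sketch of the same save/patch/restore discipline on a toy target:

    import math

    f_sqrt = math.sqrt                    # save the original
    math.sqrt = lambda x: f_sqrt(abs(x))  # install the patch
    try:
        print(math.sqrt(-4.0))  # 2.0 while the patch is active
    finally:
        math.sqrt = f_sqrt                # always restore

Note the hunk assigns patched__broadcast_in_dim_meta to both torch._prims._broadcast_in_dim_meta and torch._prims.broadcast_in_dim, even though the saved originals differ; the restore hunk below puts each original back under its own name.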
@@ -404,10 +422,7 @@ def torch_export_patches(
             )
         )
 
-    if stop_if_static:
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from .patches.patch_torch import patched_ShapeEnv
-
+    if patch_torch and stop_if_static:
         ShapeEnv._log_guard_remember = ShapeEnv._log_guard
 
         if verbose:
@@ -584,6 +599,10 @@ def torch_export_patches(
         torch._export.non_strict_utils._constrain_user_specified_dimhint_range = (
             f___constrain_user_specified_dimhint_range
         )
+        torch._prims._broadcast_in_dim_meta = f__broadcast_in_dim_meta
+        torch._prims.broadcast_in_dim = f_broadcast_in_dim
+        torch._refs._maybe_broadcast = f__maybe_broadcast
+        ShapeEnv._evaluate_expr = f_shape_env__evaluate_expr
 
         if verbose:
             print("[torch_export_patches] restored pytorch functions")
@@ -723,9 +742,7 @@ def torch_export_patches(
 
 
 def replacement_before_exporting(args: Any) -> Any:
-    """
-    Does replacements on the given inputs if needed.
-    """
+    """Does replacements on the given inputs if needed."""
     if args is None:
         return None
     if isinstance(args, (int, float)):
@@ -12,17 +12,26 @@ from transformers.cache_utils import (
     StaticCache,
 )
 
-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-
 from ..helpers import string_type
 from .serialization import _lower_name_with_
 
 PATCH_OF_PATCHES: Set[Any] = set()
 
 
+def get_mamba_cache_cls() -> type:
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+
+        return MambaCache
+    except ImportError:
+        try:
+            from transformers.cache_utils import MambaCache
+
+            return MambaCache
+        except ImportError:
+            return None
+
+
 def register_class_serialization(
     cls,
     f_flatten: Callable,
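get_mamba_cache_cls replaces the module-level try/except import: MambaCache moved between transformers releases, so the helper returns the class from whichever location still provides it, or None when neither import succeeds (the annotation is thus effectively Optional[type], and the callers below guard on the result). A generalized sketch of the same probe-several-locations pattern, with find_class a hypothetical helper:

    import importlib

    def find_class(candidates):
        # Try historical import locations in order; None when the class is gone.
        for mod_name, attr in candidates:
            try:
                return getattr(importlib.import_module(mod_name), attr)
            except (ImportError, AttributeError):
                continue
        return None

    MambaCache = find_class([
        ("transformers.models.mamba.modeling_mamba", "MambaCache"),
        ("transformers.cache_utils", "MambaCache"),
    ])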
@@ -203,13 +212,6 @@ def serialization_functions(
             # f_check=make_dynamic_cache([(torch.rand((4, 4, 4)), torch.rand((4, 4, 4)))]),
             verbose=verbose,
         ),
-        MambaCache: lambda verbose=verbose: register_class_serialization(
-            MambaCache,
-            flatten_mamba_cache,
-            unflatten_mamba_cache,
-            flatten_with_keys_mamba_cache,
-            verbose=verbose,
-        ),
         EncoderDecoderCache: lambda verbose=verbose: register_class_serialization(
             EncoderDecoderCache,
             flatten_encoder_decoder_cache,
@@ -232,6 +234,17 @@ def serialization_functions(
             verbose=verbose,
         ),
     }
+    MambaCache = get_mamba_cache_cls()
+    if MambaCache:
+        transformers_classes[MambaCache] = (
+            lambda verbose=verbose: register_class_serialization(
+                MambaCache,
+                flatten_mamba_cache,
+                unflatten_mamba_cache,
+                flatten_with_keys_mamba_cache,
+                verbose=verbose,
+            )
+        )
     classes.update(transformers_classes)
 
     if patch_diffusers:
  if patch_diffusers:
@@ -287,7 +300,12 @@ def unregister_class_serialization(cls: type, verbose: int = 0):
287
300
 
288
301
  def unregister_cache_serialization(undo: Dict[str, bool], verbose: int = 0):
289
302
  """Undo all registrations."""
290
- cls_ensemble = {MambaCache, DynamicCache, EncoderDecoderCache} | set(undo)
303
+ MambaCache = get_mamba_cache_cls()
304
+ cls_ensemble = (
305
+ {DynamicCache, EncoderDecoderCache}
306
+ | set(undo)
307
+ | ({MambaCache} if MambaCache else set())
308
+ )
291
309
  for cls in cls_ensemble:
292
310
  if undo.get(cls.__name__, False):
293
311
  unregister_class_serialization(cls, verbose)