onnx-diagnostic 0.8.4__py3-none-any.whl → 0.8.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
@@ -2,6 +2,7 @@ import ast
 import enum
 import inspect
 import itertools
+import json
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -1373,11 +1374,7 @@ def max_diff(
         if hist:
             if isinstance(hist, bool):
                 hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
-            ind = np.digitize(diff.reshape((-1,)), hist, right=True)
-            cou = np.bincount(ind, minlength=ind.shape[0] + 1)
-            res["rep"] = dict(
-                zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))])
-            )
+            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
         return res  # type: ignore

     if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
@@ -1493,27 +1490,11 @@ def max_diff(
             dev=dev,
         )
         if hist:
-            if isinstance(hist, list) and len(hist) == 1:
-                res["rep"] = {f">{hist[0]}": (diff > hist[0]).sum().item()}
-            elif isinstance(hist, list) and len(hist) == 2:
-                res["rep"] = {
-                    f">{hist[0]}": (diff > hist[0]).sum().item(),
-                    f">{hist[1]}": (diff > hist[1]).sum().item(),
-                }
-            else:
-                if isinstance(hist, bool):
-                    hist = torch.tensor(
-                        [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
-                    )
-                hist = torch.tensor(hist).to(diff.device)
-                ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
-                cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
-                res["rep"] = dict(
-                    zip(
-                        [f">{x}" for x in hist],
-                        [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
-                    )
+            if isinstance(hist, bool):
+                hist = torch.tensor(
+                    [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
                 )
+            res["rep"] = {f">{h}": (diff > h).sum().item() for h in hist}
         return res  # type: ignore

     if isinstance(expected, int) and isinstance(got, torch.Tensor):
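
Both the NumPy and the torch branches of max_diff now build the histogram report the same way: for each threshold in hist, count how many absolute differences exceed it. A minimal sketch of that counting logic with made-up values (not a call into the library):

import numpy as np

# Made-up absolute differences; the thresholds mirror the default list above.
diff = np.array([0.0, 0.002, 0.02, 0.2, 2.0])
hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)

# One entry per threshold: the number of elements strictly greater than it.
rep = {f">{h}": (diff > h).sum().item() for h in hist}
print(rep)
# {'>0.0': 4, '>0.0001': 4, '>0.001': 4, '>0.01': 3, '>0.1': 2, '>1.0': 1, '>10.0': 0, '>100.0': 0}
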
@@ -1750,8 +1731,26 @@ def max_diff(
     )


-def string_diff(diff: Dict[str, Any]) -> str:
-    """Renders discrepancies return by :func:`max_diff` into one string."""
+def string_diff(diff: Dict[str, Any], js: bool = False, ratio: bool = False, **kwargs) -> str:
+    """
+    Renders discrepancies return by :func:`max_diff` into one string.
+
+    :param diff: differences
+    :param js: json format
+    :param ratio: display mismatch ratio
+    :param kwargs: addition values to add in the json format
+    """
+    if js:
+        if "rep" in diff:
+            rep = diff["rep"]
+            diff = {**{k: v for k, v in diff.items() if k != "rep"}, **rep}
+            if ratio:
+                for k, v in rep.items():
+                    diff[f"%{k}"] = v / diff["n"]
+        diff["mean"] = diff["sum"] / diff["n"]
+        diff.update(kwargs)
+        return json.dumps(diff)
+
     # dict(abs=, rel=, sum=, n=n_diff, dnan=)
     if "dev" in diff:
         ddiff = {k: v for k, v in diff.items() if k != "dev"}
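
A rough standalone sketch of what the new js/ratio path of string_diff produces, applied to a hand-built dictionary shaped like the comment above (abs, rel, sum, n, dnan, plus the optional rep histogram); the values are made up and the snippet mirrors the logic rather than calling the library:

import json

diff = {
    "abs": 0.02, "rel": 0.5, "sum": 3.2, "n": 160, "dnan": 0,
    "rep": {">0.001": 40, ">0.01": 5},
}

# js=True: flatten "rep" into the top level; ratio=True additionally adds one
# "%>threshold" ratio per rep entry; the mean discrepancy is sum / n.
rep = diff.pop("rep")
flat = {**diff, **rep}
flat.update({f"%{k}": v / flat["n"] for k, v in rep.items()})
flat["mean"] = flat["sum"] / flat["n"]
print(json.dumps(flat))
# ... ">0.001": 40, ">0.01": 5, "%>0.001": 0.25, "%>0.01": 0.03125, "mean": 0.02}
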
@@ -818,6 +818,7 @@ def torch_export_patches(
     rewrite: Optional[List[Callable]] = None,
     dump_rewriting: Optional[str] = None,
     patch_details: Optional[PatchDetails] = None,
+    profile: Optional[str] = None,
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.
@@ -850,6 +851,8 @@ def torch_export_patches(
     :param dump_rewriting: dumps rewriting information in file beginning with that prefix
     :param patch_details: if specified, this class is used to stored every rewritten done.
     :param verbose: to show which patches is applied
+    :param profile: starts profiling whatever is called inside the context manager,
+        output the profiling into a text file

     The list of available patches.

@@ -1017,10 +1020,23 @@ def torch_export_patches(
     if verbose:
         print("[torch_export_patches] done patching")

+    if profile:
+        from pyinstrument import Profiler
+
+        profiler = Profiler()
+        profiler.start()
+    else:
+        profiler = None
+
     try:
         yield fct_callable
     finally:

+        if profiler:
+            profiler.stop()
+            with open(profile, "w") as f:
+                f.write(profiler.output_html())
+
         # unpatch

         if verbose:
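
A hedged usage sketch of the new profile argument (the import path and the toy model are assumptions, not taken from the diff). Since the context manager writes profiler.output_html(), an .html file name is the natural choice, and pyinstrument must be installed:

import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches  # assumed import path

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x.abs() + 1

# profile=... starts a pyinstrument Profiler when the context is entered and
# writes the HTML report to that file when the context exits.
with torch_export_patches(profile="export_profile.html"):
    ep = torch.export.export(TinyModel(), (torch.randn(2, 3),))
print(ep)
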
@@ -1,3 +1,4 @@
+import inspect
 from typing import Callable, List, Optional, Tuple
 import torch

@@ -19,6 +20,12 @@ if patch_masking_utils:
        prepare_padding_mask,
    )

+    _prepare_padding_mask_kwargs = (
+        dict(_slice=False)
+        if "_slice" in inspect.signature(prepare_padding_mask).parameters
+        else {}
+    )
+
    try:
        # transformers>=5.0
        from transformers.masking_utils import (
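
The signature probe above in isolation: _slice=False is only forwarded when the installed prepare_padding_mask actually declares that parameter, so the patch keeps working with transformers versions whose helper does not accept _slice. A self-contained sketch with a stand-in function (the real helper lives in transformers.masking_utils):

import inspect

def prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=True):
    # Stand-in for the transformers helper; only its signature matters here.
    return attention_mask, _slice

_prepare_padding_mask_kwargs = (
    dict(_slice=False)
    if "_slice" in inspect.signature(prepare_padding_mask).parameters
    else {}
)
print(_prepare_padding_mask_kwargs)  # {'_slice': False} because the stand-in declares _slice
print(prepare_padding_mask(None, 8, 0, **_prepare_padding_mask_kwargs))  # (None, False)
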
@@ -132,7 +139,9 @@ if patch_masking_utils:
    ) -> Optional[torch.Tensor]:
        """manual patch for function ``transformers.masking_utils.sdpa_mask_recent_torch``."""
        q_length = cache_position.shape[0]
-        padding_mask = prepare_padding_mask(attention_mask, kv_length, kv_offset, _slice=False)
+        padding_mask = prepare_padding_mask(
+            attention_mask, kv_length, kv_offset, **_prepare_padding_mask_kwargs
+        )
        if allow_is_causal_skip and _ignore_causal_mask_sdpa(
            padding_mask, q_length, kv_length, kv_offset, local_size
        ):
@@ -24,7 +24,7 @@ if patch_qwen2_5:

    onnx_plugs_op = onnxscript.values.Opset("onnx_plug", 1)
    op = onnxscript.opset22
-    op24 = onnxscript.onnx_opset.opset24
+    op23 = onnxscript.onnx_opset.opset23
    msft_op = onnxscript.values.Opset("com.microsoft", 1)
    STOPAT = (
        int(os.environ.get("STOPAT", None))
@@ -101,7 +101,7 @@ if patch_qwen2_5:
        return attn_output_4d

    @onnxscript.script(opset=onnx_plugs_op)
-    def LoopAttention24(
+    def LoopAttention23(
        query_states,
        key_states,
        value_states,
@@ -109,26 +109,26 @@ if patch_qwen2_5:
        scaling: float = 0.11180339887498948,
        num_heads: int = 16,
    ):
-        to_3d_shape = op24.Constant(value_ints=[0, 0, -1])
-        query_transposed = op24.Transpose(query_states, perm=[0, 2, 1, 3])
-        output_shape = op24.Shape(query_transposed)
-        query_3d = op24.Reshape(query_transposed, to_3d_shape)
-        value_3d = op24.Reshape(op24.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
-        key_3d = op24.Reshape(op24.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
-        cu_seqlens = op24.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
-        num_patches = op24.Size(cu_seqlens) - 1
-        seq_axis = op24.Constant(value_ints=[1])
-        seq_axis_int32 = op24.Cast(seq_axis, to=onnx.TensorProto.INT32)
-        seq_attn = op24.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
+        to_3d_shape = op23.Constant(value_ints=[0, 0, -1])
+        query_transposed = op23.Transpose(query_states, perm=[0, 2, 1, 3])
+        output_shape = op23.Shape(query_transposed)
+        query_3d = op23.Reshape(query_transposed, to_3d_shape)
+        value_3d = op23.Reshape(op23.Transpose(value_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        key_3d = op23.Reshape(op23.Transpose(key_states, perm=[0, 2, 1, 3]), to_3d_shape)
+        cu_seqlens = op23.Cast(cu_seqlens, to=onnx.TensorProto.INT32)
+        num_patches = op23.Size(cu_seqlens) - 1
+        seq_axis = op23.Constant(value_ints=[1])
+        seq_axis_int32 = op23.Cast(seq_axis, to=onnx.TensorProto.INT32)
+        seq_attn = op23.SequenceEmpty(dtype=onnx.TensorProto.FLOAT)
        for i_patch in range(num_patches):
-            i_1d = op24.Reshape(i_patch, [1])
+            i_1d = op23.Reshape(i_patch, [1])
            i_plus_1_1d = i_1d + 1
-            start = op24.Gather(cu_seqlens, i_1d, axis=0)
-            end = op24.Gather(cu_seqlens, i_plus_1_1d, axis=0)
-            query_i = op24.Slice(query_3d, start, end, seq_axis_int32)
-            key_i = op24.Slice(key_3d, start, end, seq_axis_int32)
-            value_i = op24.Slice(value_3d, start, end, seq_axis_int32)
-            mha_output = op24.Attention(
+            start = op23.Gather(cu_seqlens, i_1d, axis=0)
+            end = op23.Gather(cu_seqlens, i_plus_1_1d, axis=0)
+            query_i = op23.Slice(query_3d, start, end, seq_axis_int32)
+            key_i = op23.Slice(key_3d, start, end, seq_axis_int32)
+            value_i = op23.Slice(value_3d, start, end, seq_axis_int32)
+            mha_output = op23.Attention(
                query_i,
                key_i,
                value_i,
@@ -137,9 +137,9 @@ if patch_qwen2_5:
                kv_num_heads=num_heads,
                softmax_precision=onnx.TensorProto.FLOAT,
            )
-            seq_attn = op24.SequenceInsert(seq_attn, mha_output)
-        attn_output = op24.ConcatFromSequence(seq_attn, axis=1)
-        attn_output_4d = op24.Reshape(attn_output, output_shape)
+            seq_attn = op23.SequenceInsert(seq_attn, mha_output)
+        attn_output = op23.ConcatFromSequence(seq_attn, axis=1)
+        attn_output_4d = op23.Reshape(attn_output, output_shape)
        return attn_output_4d

    @onnxscript.script(opset=onnx_plugs_op)
@@ -256,20 +256,24 @@ if patch_qwen2_5:
        return attn_output

    def qwen_version_selector(opset: int, *args: torch.Tensor) -> Tuple[str, torch.dtype]:
-        first_tensor = next(a for a in args if a is not None)
-        dtype = first_tensor.dtype
+        first_float_tensor = next(
+            a
+            for a in args
+            if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
+        )
+        dtype = first_float_tensor.dtype
        strategy = patched_Qwen2_5_VLVisionAttention.STRATEGY_FOR_ATTENTION()
        itype = torch_dtype_to_onnx_dtype(dtype)
        if strategy is not None:
            return strategy, itype
        if dtype == torch.float32 or itype == onnx.TensorProto.FLOAT:
-            if opset >= 24:
-                return "LOOPA24", itype
+            if opset >= 23:
+                return "LOOPA23", itype
            return "LOOPMHA", itype
        if dtype == torch.float16 or itype == onnx.TensorProto.FLOAT16:
            # first_tensor may be a SymbolicTensor (onnx).
            # is_cuda is not available.
-            if hasattr(first_tensor, "is_cuda") and first_tensor.is_cuda:
+            if hasattr(first_float_tensor, "is_cuda") and first_float_tensor.is_cuda:
                return "PACKED", itype
            return "LOOPMHA", itype
        raise AssertionError(
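
qwen_version_selector no longer keys its decision on the first non-None argument, which can be an integer tensor, but on the first floating-point tensor. A small illustration with made-up tensors (the argument order is illustrative only):

import torch

# An int64 tensor first (something like cu_seqlens), then a float16 tensor:
# the old rule would have picked int64 here, the new one skips it.
args = (torch.tensor([0, 4, 8]), None, torch.randn(8, 2, dtype=torch.float16))

first_float_tensor = next(
    a
    for a in args
    if a is not None and a.dtype in {torch.float16, torch.float32, torch.bfloat16}
)
print(first_float_tensor.dtype)  # torch.float16
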
@@ -288,9 +292,9 @@ if patch_qwen2_5:
        ("PACKED", onnx.TensorProto.FLOAT16): _add_com_microsoft_opset(
            PackedAttention.to_function_proto()
        ),
-        ("LOOPA24", onnx.TensorProto.FLOAT): LoopAttention24.to_function_proto(),
-        ("LOOPA24", onnx.TensorProto.FLOAT16): _update_sequence_type(
-            onnx.TensorProto.FLOAT16, LoopAttention24.to_function_proto()
+        ("LOOPA23", onnx.TensorProto.FLOAT): LoopAttention23.to_function_proto(),
+        ("LOOPA23", onnx.TensorProto.FLOAT16): _update_sequence_type(
+            onnx.TensorProto.FLOAT16, LoopAttention23.to_function_proto()
        ),
        ("LOOPMHA", onnx.TensorProto.FLOAT): _add_com_microsoft_opset(
            LoopMHAAttention.to_function_proto()
@@ -733,3 +737,71 @@ if patch_qwen2_5:
        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
+
+    class patched_Qwen2_5_VLModel:
+        _PATCHES_ = ["get_placeholder_mask"]
+        _PATCHED_CLASS_ = transformers.models.qwen2_5_vl.modeling_qwen2_5_vl.Qwen2_5_VLModel
+
+        def get_placeholder_mask(
+            self,
+            input_ids: torch.LongTensor,
+            inputs_embeds: torch.FloatTensor,
+            image_features: Optional[torch.FloatTensor] = None,
+            video_features: Optional[torch.FloatTensor] = None,
+        ):
+            if input_ids is None:
+                special_image_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.image_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_image_mask = special_image_mask.all(-1)
+                special_video_mask = inputs_embeds == self.get_input_embeddings()(
+                    torch.tensor(
+                        self.config.video_token_id,
+                        dtype=torch.long,
+                        device=inputs_embeds.device,
+                    )
+                )
+                special_video_mask = special_video_mask.all(-1)
+            else:
+                special_image_mask = input_ids == self.config.image_token_id
+                special_video_mask = input_ids == self.config.video_token_id
+
+            special_image_mask = (
+                special_image_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+
+            # PATCHED: we should use torch._check
+            # but this fails for compilation. It cannot be verified with FakeTensors
+            # torch._check(
+            #     image_features is None
+            #     or inputs_embeds[special_image_mask].numel() == image_features.numel(),
+            #     lambda: (
+            #         f"Image features and image tokens do not match: tokens: "
+            #         f"{special_image_mask.sum()}, features {image_features.shape[0]}"
+            #     ),
+            # )
+
+            special_video_mask = (
+                special_video_mask.unsqueeze(-1)
+                .expand_as(inputs_embeds)
+                .to(inputs_embeds.device)
+            )
+
+            # PATCHED: we should use torch._check
+            # but this fails for compilation. It cannot be verified with FakeTensors
+            # torch._check(
+            #     video_features is None
+            #     or inputs_embeds[special_video_mask].numel() == video_features.numel(),
+            #     lambda: (
+            #         f"Videos features and video tokens do not match: tokens: "
+            #         f"{special_video_mask.sum()}, features {video_features.shape[0]}"
+            #     ),
+            # )
+
+            return special_image_mask, special_video_mask
@@ -77,6 +77,7 @@ if patch_qwen2_5:
        patched_Qwen2_5_VisionTransformerPretrainedModel,
        patched_Qwen2_5_VLVisionAttentionOneIteration,
        patched_Qwen2_5_VLVisionAttention,
+        patched_Qwen2_5_VLModel,
        PLUGS as PLUGS_Qwen25,
    )

@@ -55,6 +55,7 @@ Automatically generated:
 import base64
 import json
 import textwrap
+from typing import Any
 import transformers

 null = None
@@ -62,6 +63,22 @@ true = True
 false = False


+def _enforce_default(config_type: type, **kwargs) -> Any:
+    config = config_type(**kwargs)
+    for name in [
+        *[k for k in kwargs if k.endswith("_token_id")],
+        "attention_dropout",
+        "hidden_size",
+        "hidden_act",
+        "intermediate_size",
+        "max_position_embeddings",
+        "vocab_size",
+    ]:
+        if name in kwargs and (not hasattr(config, name) or getattr(config, name) is None):
+            setattr(config, name, kwargs[name])
+    return config
+
+
 def _ccached_arnir0_tiny_LLM():
     "arnir0/Tiny-LLM"
     return transformers.LlamaConfig(
@@ -4691,7 +4708,8 @@ def _ccached_zai_glm_45():

 def _ccached_microsoft_phi3_mini_128k_instruct():
     "microsoft/Phi-3-mini-128k-instruct"
-    return transformers.Phi3Config(
+    return _enforce_default(
+        transformers.Phi3Config,
         **{
             "_name_or_path": "Phi-3-mini-128k-instruct",
             "architectures": ["Phi3ForCausalLM"],
@@ -4827,13 +4845,14 @@ def _ccached_microsoft_phi3_mini_128k_instruct():
             "use_cache": true,
             "attention_bias": false,
             "vocab_size": 32064,
-        }
+        },
     )


 def _ccached_google_gemma_3_4b_it_like():
     "google/gemma-3-4b-it"
-    return transformers.Gemma3Config(
+    return _enforce_default(
+        transformers.Gemma3Config,
         **{
             "architectures": ["Gemma3ForConditionalGeneration"],
             "boi_token_index": 255999,
@@ -4863,13 +4882,14 @@ def _ccached_google_gemma_3_4b_it_like():
                 "patch_size": 14,
                 "vision_use_head": false,
             },
-        }
+        },
     )


 def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
     "hf-internal-testing/tiny-random-Gemma3ForCausalLM"
-    return transformers.Gemma3TextConfig(
+    return _enforce_default(
+        transformers.Gemma3TextConfig,
         **{
             "architectures": ["Gemma3ForCausalLM"],
             "attention_bias": false,
@@ -4901,13 +4921,14 @@ def _ccached_hf_internal_testing_tiny_random_gemma3_for_causal_lm():
             "transformers_version": "4.52.0.dev0",
             "use_cache": true,
             "vocab_size": 262144,
-        }
+        },
     )


 def _ccached_qwen_qwen2_5_vl_7b_instruct():
     "Qwen/Qwen2.5-VL-7B-Instruct"
-    return transformers.Qwen2_5_VLConfig(
+    return _enforce_default(
+        transformers.Qwen2_5_VLConfig,
         **{
             "architectures": ["Qwen2_5_VLForConditionalGeneration"],
             "attention_dropout": 0.0,
@@ -4954,5 +4975,5 @@ def _ccached_qwen_qwen2_5_vl_7b_instruct():
             },
             "rope_scaling": {"type": "mrope", "mrope_section": [16, 24, 24]},
             "vocab_size": 152064,
-        }
+        },
     )