onnx-diagnostic 0.7.14__py3-none-any.whl → 0.7.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +156 -47
  3. onnx_diagnostic/export/dynamic_shapes.py +6 -6
  4. onnx_diagnostic/export/shape_helper.py +124 -6
  5. onnx_diagnostic/ext_test_case.py +5 -1
  6. onnx_diagnostic/helpers/cache_helper.py +68 -42
  7. onnx_diagnostic/helpers/config_helper.py +2 -1
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +153 -0
  9. onnx_diagnostic/helpers/helper.py +3 -0
  10. onnx_diagnostic/helpers/rt_helper.py +3 -3
  11. onnx_diagnostic/tasks/image_text_to_text.py +7 -6
  12. onnx_diagnostic/tasks/text_generation.py +7 -4
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +69 -11
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +31 -13
  15. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +109 -18
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +133 -28
  17. onnx_diagnostic/torch_models/code_sample.py +343 -0
  18. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +38 -0
  19. onnx_diagnostic/torch_models/hghub/model_inputs.py +7 -3
  20. onnx_diagnostic/torch_models/validate.py +73 -29
  21. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/METADATA +6 -6
  22. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/RECORD +25 -23
  23. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/WHEEL +0 -0
  24. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/licenses/LICENSE.txt +0 -0
  25. {onnx_diagnostic-0.7.14.dist-info → onnx_diagnostic-0.7.16.dist-info}/top_level.txt +0 -0
@@ -25,8 +25,8 @@ def retrieve_stacktrace():
 
 def _catch_produce_guards_and_solve_constraints(
     previous_function: Callable,
-    fake_mode: "FakeTensorMode",  # noqa: F821
-    gm: "torch.fx.GraphModule",  # noqa: F821
+    fake_mode: FakeTensorMode,
+    gm: torch.fx.GraphModule,
     dynamic_shapes: Union[Dict[str, Any], Tuple[Any], List[Any], None],
     equalities_inputs: "EqualityConstraint",  # noqa: F821
     original_signature: inspect.Signature,
@@ -88,7 +88,7 @@ def patch__check_input_constraints_for_graph(
 
 def patched_infer_size(a, b):
     """Patches ``torch._subclasses.fake_impls.infer_size``."""
-    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
+    from torch.fx.experimental.symbolic_shapes import guard_or_false
 
     dimsA = len(a)
     dimsB = len(b)
@@ -113,19 +113,19 @@ def patched_infer_size(a, b):
         # were not the case, we'd need to write this using torch.sym_or() or
         # something like that).
         try:
-            b1 = guard_size_oblivious(sizeA == 1)
+            b1 = guard_or_false(sizeA == 1)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b1 = False
         try:
-            b2 = guard_size_oblivious(sizeB == 1)
+            b2 = guard_or_false(sizeB == 1)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b2 = False
         try:
-            b3 = guard_size_oblivious(sizeA == sizeB)
+            b3 = guard_or_false(sizeA == sizeB)
         except torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode:
             b3 = False
         if b1 or b2 or b3:
-            expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
+            expandedSizes[i] = sizeB if guard_or_false(sizeA == 1) else sizeA
         else:
             # PATCHED: generic case, the dimension is known, no need to assert
             expandedSizes[i] = torch.sym_max(sizeA, sizeB)
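Aside (not part of the diff): the patched branch above falls back to torch.sym_max instead of asserting equality when neither size is statically known to be 1 or equal. A minimal concrete-int analogue of that rule, with plain Python ints standing in for symbolic sizes:

# Sketch of the broadcast rule used by patched_infer_size, with concrete ints.
def infer_broadcast_dim(sizeA: int, sizeB: int) -> int:
    if sizeA == 1 or sizeB == 1 or sizeA == sizeB:
        # same outcome as the b1/b2/b3 branch above
        return sizeB if sizeA == 1 else sizeA
    # patched generic case: no assertion, keep the max (torch.sym_max on symbolic ints)
    return max(sizeA, sizeB)

print(infer_broadcast_dim(1, 7))  # 7
print(infer_broadcast_dim(5, 5))  # 5
print(infer_broadcast_dim(3, 7))  # 7, roughly where the unpatched code would guard sizeA == sizeB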
@@ -137,7 +137,6 @@ def patched__broadcast_shapes(*_shapes):
     from functools import reduce
     from torch._prims_common import IntLike
     from torch.fx.experimental.symbolic_shapes import (
-        guard_size_oblivious,
         guard_or_false,
         is_nested_int,
     )
@@ -174,13 +173,15 @@ def patched__broadcast_shapes(*_shapes):
                continue
            # PATCHED: two cases, if == for sure, no broadcast,
            # otherwise maybe broadcast with max(dimensions)
-            if guard_size_oblivious(common_shape[idx] == 1):
+            if guard_or_false(common_shape[idx] != 1):
+                pass
+            elif guard_or_false(common_shape[idx] == 1) or guard_or_false(shape[idx] != 1):
                if shape[idx] < 0:
                    raise ValueError(
                        "Attempting to broadcast a dimension with negative length!"
                    )
                common_shape[idx] = shape[idx]
-            elif guard_size_oblivious(shape[idx] != 1):
+            else:
                common_shape[idx] = torch.sym_max(common_shape[idx], shape[idx])
 
    return common_shape
@@ -360,6 +361,10 @@ class patched_ShapeEnv:
            },
        )
 
+        for source in self.var_to_sources.get(a, []):
+            if user_tb:
+                self.specialization_stacks[source] = user_tb
+
        # PATCHED: removed lines
        # if config.print_specializations:
        #     self.log.warning(
@@ -973,15 +978,101 @@ def patched__broadcast_in_dim_meta(
                    new_strides.append(a.stride()[original_idx])
                else:
                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
            else:
-                # PATCHED: disabled this check
-                # torch._check(
-                #     a.shape[original_idx] == shape[idx],
-                #     lambda idx=idx, original_idx=original_idx: (
-                #         f"non-broadcasting semantics require "
-                #         f"{a.shape[original_idx]} == {shape[idx]}"
-                #     ),
-                # )
+                # This checks generates the following issue:
+                # non-broadcasting semantics require s3 == Max(s10, s3), False,
+                # guard_or_false(a.shape[idx]==1)=False, a.stride()=(1, 2),
+                # idx=1, a.shape=torch.Size([2, s3]), shape=[2, Max(s10, s3)],
+                # original_idx=1
+                torch._check(
+                    a.shape[original_idx] == shape[idx],
+                    lambda idx=idx, original_idx=original_idx: (
+                        f"non-broadcasting semantics require "
+                        f"{a.shape[original_idx]} == {shape[idx]}, "
+                        f"{guard_or_false(a.shape[idx] != 1)}, "
+                        f"guard_or_false(a.shape[idx]==1)="
+                        f"{guard_or_false(a.shape[idx] == 1)}, "
+                        f"a.stride()={a.stride()}, idx={idx}, a.shape={a.shape}, "
+                        f"shape={shape}, original_idx={original_idx}"
+                    ),
+                )
+                new_strides.append(a.stride()[original_idx])
+            original_idx = original_idx + 1
+        else:
+            if guard_or_true(shape[idx] != 1):
+                # consistent with previous use of guard_size_oblivious
+                new_strides.append(0)
+            elif original_idx == a.ndim:
+                new_strides.append(1)
+            else:
+                new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
+
+    return a.as_strided(shape, new_strides, a.storage_offset())
+
+
+def patched__broadcast_in_dim_meta_level_2(
+    a: torch._prims_common.TensorLikeType,
+    shape: torch._prims_common.ShapeType,
+    broadcast_dimensions: Sequence[int],
+):
+    """Patches ``torch._prims._broadcast_in_dim_meta``."""
+    from torch.fx.experimental.symbolic_shapes import (
+        guard_or_false,
+        guard_or_true,
+        sym_or,
+    )
+
+    # Type checks
+    assert isinstance(a, torch._prims_common.TensorLike)
+    assert isinstance(shape, Sequence)
+    assert isinstance(broadcast_dimensions, Sequence)
+
+    # every dimension must be accounted for
+    assert a.ndim == len(broadcast_dimensions)
+
+    # broadcast shape must have weakly more dimensions
+    assert len(shape) >= a.ndim
+
+    # broadcast_dimensions must be an ascending sequence
+    # (no relative reordering of dims) of integers and
+    # each dimension must be within the new shape
+    def _greater_than_reduce(acc, x):
+        assert isinstance(x, (int, torch.export.Dim)), f"unexpected type {type(x)} for x"
+        assert x > acc
+        assert x < len(shape)
+
+        return x
+
+    reduce(_greater_than_reduce, broadcast_dimensions, -1)
+
+    # shape must be broadcastable to
+    for idx, new_idx in enumerate(broadcast_dimensions):
+        torch._check(
+            sym_or(a.shape[idx] == 1, shape[new_idx] == a.shape[idx]),
+            lambda idx=idx, new_idx=new_idx: (
+                f"{a.shape[idx]} must be broadcastable to {shape[new_idx]}"
+            ),
+        )
+
+    new_strides = []
+    original_idx = 0
+    for idx in range(len(shape)):
+        if idx in broadcast_dimensions:
+            # Assigns a stride of zero to dimensions
+            # which were actually broadcast
+            if guard_or_false(a.shape[original_idx] == 1):
+                if guard_or_false(a.shape[original_idx] == shape[idx]):
+                    new_strides.append(a.stride()[original_idx])
+                else:
+                    new_strides.append(0)
+            # PATCHED: disabled this check
+            elif guard_or_false(a.shape[original_idx] != 1):
+                new_strides.append(a.stride()[original_idx])
+            else:
+                # PATCHED: torch._check was removed
                new_strides.append(a.stride()[original_idx])
            original_idx = original_idx + 1
        else:
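Aside (not part of the diff): both variants above keep the core broadcast_in_dim rule that a dimension which is actually broadcast gets stride 0. A small concrete illustration of that rule with a regular tensor:

import torch

a = torch.arange(6.0).reshape(2, 3)                # strides (3, 1)
b = torch.broadcast_to(a.unsqueeze(1), (2, 4, 3))  # dim 1 is broadcast from size 1
print(b.shape, b.stride())                         # torch.Size([2, 4, 3]) (3, 0, 1)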
@@ -1019,6 +1019,26 @@ def patched__compute_dynamic_ntk_parameters(
     return inv_freq, attention_factor
 
 
+def _get_rope_init_fn(self, layer_type=None) -> Callable:
+    if hasattr(self, "rope_init_fn"):
+        # transformers<=5.0
+        rope_init_fn = (
+            patched__compute_dynamic_ntk_parameters
+            if self.rope_init_fn
+            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
+            else self.rope_init_fn
+        )
+        return rope_init_fn
+
+    rope_type = self.rope_type if layer_type is None else self.rope_type[layer_type]
+    rope_init_fn = self.compute_default_rope_parameters
+    if rope_type != "default":
+        rope_init_fn = transformers.modeling_rope_utils.ROPE_INIT_FUNCTIONS[self.rope_type]
+    if rope_init_fn is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters:
+        return patched__compute_dynamic_ntk_parameters
+    return rope_init_fn
+
+
 def patched_dynamic_rope_update(rope_forward):
     """manual patch: ``[patch:transformers.modeling_rope_utils.dynamic_rope_update]``
 
@@ -1082,22 +1102,27 @@ def patched_dynamic_rope_update(rope_forward):
 
     """
 
-    def longrope_frequency_update(self, position_ids, device):
+    def longrope_frequency_update(self, position_ids, device, layer_type=None):
        # It is no use to patch the function after the model is created
        # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
        # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
        seq_len = torch.max(position_ids) + 1
        if hasattr(self.config, "original_max_position_embeddings"):
            original_max_position_embeddings = self.config.original_max_position_embeddings
        else:
            original_max_position_embeddings = self.config.max_position_embeddings
+
+        if layer_type is None:
+            # rope_type = self.rope_type
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
        # At export time, seq_len is unknown.
        long_inv_freq, _ = rope_init_fn(
            self.config, device, seq_len=original_max_position_embeddings + 1
@@ -1112,13 +1137,13 @@ def patched_dynamic_rope_update(rope_forward):
            (lambda x, y: y.clone()),
            [long_inv_freq, original_inv_freq],
        )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
        # if seq_len > original_max_position_embeddings:
        #     self.inv_freq = self.long_inv_freq
        # else:
        #     self.inv_freq = self.original_inv_freq
 
-    def dynamic_frequency_update(self, position_ids, device):
+    def dynamic_frequency_update(self, position_ids, device, layer_type=None):
        # constructor:
        # - self.max_seq_len_cached = config.max_position_embeddings
        # - self.original_max_seq_len = config.max_position_embeddings
@@ -1128,12 +1153,7 @@ def patched_dynamic_rope_update(rope_forward):
        # as rope_init_fn is an attribute set to one function when the model
        # is created and when no patch is applied yet.
        # So we select the patched version here.
-        rope_init_fn = (
-            patched__compute_dynamic_ntk_parameters
-            if self.rope_init_fn
-            is transformers.modeling_rope_utils._compute_dynamic_ntk_parameters
-            else self.rope_init_fn
-        )
+        rope_init_fn = _get_rope_init_fn(self, layer_type=layer_type)
 
        # This behaviour is difficult to translate.
        # The sequence always grows.
@@ -1162,6 +1182,19 @@ def patched_dynamic_rope_update(rope_forward):
            self.config, device, seq_len=seq_len
        )
 
+        if layer_type is None:
+            # rope_type = self.rope_type
+            # max_seq_len_cached = self.max_seq_len_cached
+            original_inv_freq = self.original_inv_freq
+            prefix = ""
+        else:
+            # rope_type = self.rope_type[layer_type]
+            # max_seq_len_cached = getattr(
+            #     self, f"{layer_type}_max_seq_len_cached", self.max_seq_len_cached
+            # )
+            original_inv_freq = getattr(self, f"{layer_type}_original_inv_freq")
+            prefix = f"{layer_type}_"
+
        # Second test to translate.
        # Let's keep in mind, self.max_seq_len_cached = seq_len is likely to be True.
        # But in that case the following condition is a way to restore the original cache.
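Aside (not part of the diff): the prefix convention above stores per-layer-type buffers under names such as "<layer_type>_inv_freq"; the concrete layer_type value below is hypothetical and depends on the model config:

import types

rope = types.SimpleNamespace(sliding_attention_original_inv_freq=[0.5, 0.25])
layer_type = "sliding_attention"  # hypothetical layer type
prefix = f"{layer_type}_"
# restore the original cache under the prefixed name, as the patch does
setattr(rope, f"{prefix}inv_freq", getattr(rope, f"{layer_type}_original_inv_freq"))
print(rope.sliding_attention_inv_freq)  # [0.5, 0.25]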
@@ -1183,15 +1216,26 @@ def patched_dynamic_rope_update(rope_forward):
            (lambda x, y: y.clone()),
            [long_inv_freq, original_inv_freq],
        )
-        self.inv_freq = inv_freq
+        setattr(self, f"{prefix}inv_freq", inv_freq)
 
    @wraps(rope_forward)
-    def wrapper(self, x, position_ids):
+    def wrapper(self, x, position_ids, layer_type=None):
+        if layer_type is None:
+            if "dynamic" in self.rope_type:
+                dynamic_frequency_update(self, position_ids, device=x.device)
+            elif self.rope_type == "longrope":
+                longrope_frequency_update(self, position_ids, device=x.device)
+            return rope_forward(self, x, position_ids)
+
        if "dynamic" in self.rope_type:
-            dynamic_frequency_update(self, position_ids, device=x.device)
+            dynamic_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
        elif self.rope_type == "longrope":
-            longrope_frequency_update(self, position_ids, device=x.device)
-        return rope_forward(self, x, position_ids)
+            longrope_frequency_update(
+                self, position_ids, device=x.device, layer_type=layer_type
+            )
+        return rope_forward(self, x, position_ids, layer_type=layer_type)
 
    return wrapper
 
@@ -1232,6 +1276,60 @@ def common_eager_attention_forward(
     return attn_output, attn_weights
 
 
+def patched_sdpa_attention_forward(
+    module: torch.nn.Module,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor],
+    dropout: float = 0.0,
+    scaling: Optional[float] = None,
+    is_causal: Optional[bool] = None,
+    **kwargs,
+) -> tuple[torch.Tensor, None]:
+    """[patch:transformers.integrations.sdpa_attention.sdpa_attention_forward]"""
+    assert not kwargs.get("output_attentions", False), (
+        "`sdpa` attention does not support `output_attentions=True`."
+        " Please set your attention to `eager` if you want any of these features."
+    )
+    sdpa_kwargs = {}
+    if hasattr(module, "num_key_value_groups"):
+        if not transformers.integrations.sdpa_attention.use_gqa_in_sdpa(attention_mask, key):
+            key = transformers.integrations.sdpa_attention.repeat_kv(
+                key, module.num_key_value_groups
+            )
+            value = transformers.integrations.sdpa_attention.repeat_kv(
+                value, module.num_key_value_groups
+            )
+        else:
+            sdpa_kwargs = {"enable_gqa": True}
+
+    if attention_mask is not None and attention_mask.ndim == 4:
+        attention_mask = attention_mask[:, :, :, : key.shape[-2]]
+
+    is_causal = is_causal if is_causal is not None else getattr(module, "is_causal", True)
+    # PATCHED: remove the test query.shape[2] > 1
+    # is_causal = query.shape[2] > 1 and attention_mask is None and is_causal
+    is_causal = attention_mask is None and is_causal
+
+    torch._check(
+        attention_mask is None or attention_mask.shape[3] == key.shape[2],
+        "Attention mask shape incompatible with key shape.",
+    )
+    attn_output = torch.nn.functional.scaled_dot_product_attention(
+        query,
+        key,
+        value,
+        attn_mask=attention_mask,
+        dropout_p=dropout,
+        scale=scaling,
+        is_causal=is_causal,
+        **sdpa_kwargs,
+    )
+    attn_output = attn_output.transpose(1, 2).contiguous()
+    return attn_output, None
+
+
 def patched_model_bart_eager_attention_forward(
     module: torch.nn.Module,
     query: torch.Tensor,
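Aside (not part of the diff): a short, self-contained sketch of the call the patched function ends with, showing the convention it relies on: an explicit 4D mask and is_causal are mutually exclusive, and the output is transposed back to (batch, seq, heads, head_dim):

import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 5, 16)   # (batch, heads, q_len, head_dim)
k = torch.randn(1, 8, 7, 16)
v = torch.randn(1, 8, 7, 16)
mask = torch.ones(1, 1, 5, 7, dtype=torch.bool)  # broadcastable (batch, heads, q_len, kv_len)

# With a mask present, the patch sets is_causal=False before calling SDPA.
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=False)
out = out.transpose(1, 2).contiguous()
print(out.shape)  # torch.Size([1, 5, 8, 16])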
@@ -1287,12 +1385,18 @@ class common_RotaryEmbedding(torch.nn.Module):
    # @torch.no_grad()
    # PATCHED: the decorator
    @patched_dynamic_rope_update
-    def forward(self, x, position_ids):
+    def forward(self, x, position_ids, layer_type=None):
+        if layer_type is not None:
+            # transformers>=5.0
+            inv_freq = getattr(self, f"{layer_type}_inv_freq")
+            attention_scaling = getattr(self, f"{layer_type}_attention_scaling")
+        else:
+            # transformers<5.0
+            inv_freq = self.inv_freq
+            attention_scaling = self.attention_scaling
+
        inv_freq_expanded = (
-            self.inv_freq[None, :, None]
-            .float()
-            .expand(position_ids.shape[0], -1, 1)
-            .to(x.device)
+            inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        )
        position_ids_expanded = position_ids[:, None, :].float()
 
@@ -1304,8 +1408,8 @@ class common_RotaryEmbedding(torch.nn.Module):
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos() * self.attention_scaling
-            sin = emb.sin() * self.attention_scaling
+            cos = emb.cos() * attention_scaling
+            sin = emb.sin() * attention_scaling
 
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
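Aside (not part of the diff): the forward above computes the cos/sin caches as the outer product of inv_freq and position_ids, duplicated along the last axis. A standalone sketch with made-up sizes (attention_scaling assumed to be 1.0):

import torch

inv_freq = 1.0 / (10000.0 ** (torch.arange(0, 8, 2).float() / 8))  # (4,)
position_ids = torch.arange(6)[None, :]                            # (1, 6)

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, 6, 4)
emb = torch.cat((freqs, freqs), dim=-1)                              # (1, 6, 8)
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)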
 
@@ -1380,7 +1484,8 @@ class patched_IdeficsEmbedding(torch.nn.Module):
 
        def _set_cos_sin_cache_then(x, inv_freq, seq_len, _cos_cached, _sin_cached):
            t = torch.arange(seq_len, device=x.device, dtype=torch.int64).type_as(inv_freq)
-            freqs = torch.einsum("i,j->ij", t, inv_freq)
+            # freqs = torch.einsum("i,j->ij", t, inv_freq)
+            freqs = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
            emb = torch.cat((freqs, freqs), dim=-1)
            return emb.cos().to(x.dtype), emb.sin().to(x.dtype)
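Aside (not part of the diff): the replacement above is a plain outer product and matches the einsum it comments out:

import torch

t = torch.arange(5, dtype=torch.float32)
inv_freq = torch.rand(4)

freqs_einsum = torch.einsum("i,j->ij", t, inv_freq)
freqs_outer = t.reshape((-1, 1)) * inv_freq.reshape((1, -1))
print(torch.allclose(freqs_einsum, freqs_outer))  # True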