onnx-diagnostic 0.7.10__py3-none-any.whl → 0.7.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.7.10"
+__version__ = "0.7.12"
 __author__ = "Xavier Dupré"
@@ -474,7 +474,7 @@ def get_parser_validate() -> ArgumentParser:
     )
     parser.add_argument(
         "--runtime",
-        choices=["onnxruntime", "torch", "ref"],
+        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
         default="onnxruntime",
         help="onnx runtime to use, `onnxruntime` by default",
     )
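
Note: the two added `orteval` choices extend the runtimes accepted by `--runtime`; argparse rejects anything outside the list. A minimal standalone sketch of that behavior (the parser here is a stand-in, not the package's real one):

    from argparse import ArgumentParser

    p = ArgumentParser()
    p.add_argument(
        "--runtime",
        choices=["onnxruntime", "torch", "ref", "orteval", "orteval10"],
        default="onnxruntime",
    )
    assert p.parse_args([]).runtime == "onnxruntime"
    assert p.parse_args(["--runtime", "orteval"]).runtime == "orteval"
    # p.parse_args(["--runtime", "other"]) would exit with "invalid choice"
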
@@ -542,6 +542,12 @@ def get_parser_validate() -> ArgumentParser:
         "the onnx exporter should use.",
         default="",
     )
+    parser.add_argument(
+        "--ort-logs",
+        default=False,
+        action=BooleanOptionalAction,
+        help="Enables onnxruntime logging when the session is created",
+    )
     return parser


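
Note: `BooleanOptionalAction` (argparse, Python 3.9+) automatically generates the negated flag, so the new option can be toggled both ways. A standalone check:

    from argparse import ArgumentParser, BooleanOptionalAction

    p = ArgumentParser()
    p.add_argument("--ort-logs", default=False, action=BooleanOptionalAction)
    assert p.parse_args([]).ort_logs is False
    assert p.parse_args(["--ort-logs"]).ort_logs is True
    assert p.parse_args(["--no-ort-logs"]).ort_logs is False
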
@@ -575,6 +581,7 @@ def _cmd_validate(argv: List[Any]):
    ):
        print(f"validate - unsupported args: export={args.export!r}, opt={args.opt!r}")
        return
+    patch_dict = args.patch if isinstance(args.patch, dict) else {"patch": args.patch}
    summary, _data = validate_model(
        model_id=args.mid,
        task=args.task,
@@ -585,8 +592,8 @@ def _cmd_validate(argv: List[Any]):
        use_pretrained=args.trained,
        dtype=args.dtype,
        device=args.device,
-        patch=args.patch,
-        rewrite=args.rewrite,
+        patch=patch_dict,
+        rewrite=args.rewrite and patch_dict.get("patch", True),
        stop_if_static=args.stop_if_static,
        optimization=args.opt,
        exporter=args.export,
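
Note: the normalization keeps `--patch` usable as a plain boolean while `validate_model` now receives a dictionary, and rewriting is silently disabled when patching is off. A small sketch of the semantics (values are hypothetical):

    def normalize(patch, rewrite):
        patch_dict = patch if isinstance(patch, dict) else {"patch": patch}
        return patch_dict, rewrite and patch_dict.get("patch", True)

    assert normalize(True, True) == ({"patch": True}, True)
    assert normalize(False, True) == ({"patch": False}, False)
    # a dict without the "patch" key leaves rewriting enabled
    assert normalize({"patch_torch": True}, True) == ({"patch_torch": True}, True)
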
@@ -601,6 +608,7 @@ def _cmd_validate(argv: List[Any]):
        repeat=args.repeat,
        warmup=args.warmup,
        inputs2=args.inputs2,
+        ort_logs=args.ort_logs,
        output_names=(
            None if len(args.outnames.strip()) < 2 else args.outnames.strip().split(",")
        ),
@@ -820,6 +828,8 @@ def get_parser_agg() -> ArgumentParser:
        "n_model_running,n_model_acc01,n_model_acc001,n_model_dynamic,"
        "n_model_pass,n_model_faster,"
        "n_model_faster2x,n_model_faster3x,n_model_faster4x,n_node_attention,"
+        "n_node_attention23,n_node_rotary_embedding,n_node_rotary_embedding23,"
+        "n_node_layer_normalization,n_node_layer_normalization23,"
        "peak_gpu_torch,peak_gpu_nvidia,n_node_control_flow,"
        "n_node_constant,n_node_shape,n_node_expand,"
        "n_node_function,n_node_initializer,n_node_scatter,"
@@ -4,11 +4,6 @@ import torch
 import transformers
 import transformers.cache_utils

-try:
-    from transformers.models.mamba.modeling_mamba import MambaCache
-except ImportError:
-    from transformers.cache_utils import MambaCache
-

 class CacheKeyValue:
     """
@@ -354,8 +349,15 @@ def make_encoder_decoder_cache(
    )


-def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -> MambaCache:
+def make_mamba_cache(
+    key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
+) -> "MambaCache":  # noqa: F821
    "Creates a ``MambaCache``."
+    # The import is moved here because it is slow.
+    try:
+        from transformers.models.mamba.modeling_mamba import MambaCache
+    except ImportError:
+        from transformers.cache_utils import MambaCache
    dtype = key_value_pairs[0][0].dtype

    class _config:
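
Note: moving the import inside `make_mamba_cache` defers the slow transformers submodule load until the function is first called, so importing the helper module itself stays fast. A usage sketch (the tensor shapes below are made up; real ones depend on the model):

    import torch

    pairs = [(torch.zeros(2, 8, 16), torch.zeros(2, 8, 16))]
    cache = make_mamba_cache(pairs)  # MambaCache is imported here, on first call
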
@@ -285,7 +285,8 @@ class CubePlot:
        nn = df.shape[1] // n_cols
        nn += int(df.shape[1] % n_cols != 0)
        ratio = float(os.environ.get("FIGSIZEH", "1"))
-        fig, axs = plt.subplots(nn, n_cols, figsize=(6 * n_cols, nn * df.shape[0] / 3 * ratio))
+        figsize = (6 * n_cols, nn * (2.5 + df.shape[0] / 15) * ratio)
+        fig, axs = plt.subplots(nn, n_cols, figsize=figsize)
        pos = 0
        imgs = []
        for c in self._make_loop(df.columns, verbose):
@@ -332,10 +333,12 @@ class CubePlot:
        n_cols = len(groups)

        title_suffix = f"\n{title_suffix}" if title_suffix else ""
+        ratio = float(os.environ.get("FIGSIZEH", "1"))
+        figsize = (5 * n_cols, max(len(g) for g in groups) * (2 + df.shape[1] / 2) * ratio)
        fig, axs = plt.subplots(
            df.shape[1],
            n_cols,
-            figsize=(5 * n_cols, max(len(g) for g in groups) * df.shape[1] / 2),
+            figsize=figsize,
            sharex=True,
            sharey="row" if n_cols > 1 else False,
        )
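
Note: both figure heights now combine a constant base per subplot with a slower-growing term, instead of scaling linearly with the row count alone, which keeps tall frames from producing huge figures. A quick comparison for the first hunk, assuming 40 rows, nn = 3 and ratio = 1:

    rows, nn, ratio = 40, 3, 1.0
    old_height = nn * rows / 3 * ratio           # 40.0
    new_height = nn * (2.5 + rows / 15) * ratio  # 15.5
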
@@ -877,7 +880,11 @@ class CubeLogs:
            print(f"[CubeLogs.view] key_columns={key_columns}")
        g = data[[*key_index, *key_columns]].copy()
        g["count"] = 1
-        r = g.groupby([*key_index, *key_columns], dropna=False).sum()
+        r = (
+            g.copy()
+            if not key_index and not key_columns
+            else g.groupby([*key_index, *key_columns], dropna=False).sum()
+        )
        not_unique = r[r["count"] > 1]
        assert not_unique.shape[0] == 0, (
            f"view_def.name={view_def.name!r}, "
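
Note: the guard works around a pandas limitation: `groupby` raises when given an empty list of keys, which happens when a view defines neither index nor column keys. A minimal reproduction:

    import pandas

    g = pandas.DataFrame({"count": [1, 1]})
    try:
        g.groupby([], dropna=False).sum()
    except ValueError as e:
        print(e)  # No group keys passed!
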
@@ -1505,6 +1512,11 @@ class CubeLogsPerformance(CubeLogs):
            "n_model_faster3x",
            "n_model_faster4x",
            "n_node_attention",
+            "n_node_attention23",
+            "n_node_rotary_embedding",
+            "n_node_rotary_embedding23",
+            "n_node_layer_normalization",
+            "n_node_layer_normalization23",
            "n_node_control_flow",
            "n_node_scatter",
            "n_node_function",
@@ -1568,7 +1580,9 @@ class CubeLogsPerformance(CubeLogs):

        def gdf(df, cname, default_value=np.nan):
            if cname in df.columns:
-                return df[cname]
+                if np.isnan(default_value):
+                    return df[cname]
+                return df[cname].fillna(default_value)
            return pandas.Series(default_value, index=df.index)

        def ghas_value(df, cname):
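
Note: `gdf` now also fills NaN values inside an existing column when an explicit default is given; previously the default only mattered for absent columns. The three cases, with the helper copied out for a standalone run:

    import numpy as np
    import pandas

    def gdf(df, cname, default_value=np.nan):
        if cname in df.columns:
            if np.isnan(default_value):
                return df[cname]
            return df[cname].fillna(default_value)
        return pandas.Series(default_value, index=df.index)

    df = pandas.DataFrame({"a": [1.0, np.nan]})
    print(gdf(df, "a").tolist())     # [1.0, nan] -> column returned untouched
    print(gdf(df, "a", 0).tolist())  # [1.0, 0.0] -> NaN filled with the default
    print(gdf(df, "b", 0).tolist())  # [0.0, 0.0] -> missing column, constant series
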
@@ -1676,15 +1690,54 @@ class CubeLogsPerformance(CubeLogs):
                "time_latency",
                gdf(df, "time_latency_eager") > gdf(df, "time_latency", np.inf) * 3.98,
            ),
+            n_node_attention23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__Attention")
+            ),
+            n_node_rotary_embedding23=lambda df: gpreserve(
+                df, "time_latency_eager", gdf(df, "op_onnx__RotaryEmbedding")
+            ),
+            n_node_layer_normalization23=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx__LayerNormalization", 0)
+                + gdf(df, "op_onnx__RMSNormalization", 0)
+                + gdf(df, "op_onnx__BatchNormlization", 0)
+                + gdf(df, "op_onnx__InstanceNormlization", 0)
+                + gdf(df, "op_onnx__GroupNormalization", 0),
+            ),
            n_node_attention=lambda df: gpreserve(
                df,
-                "op_onnx_com.microsoft_Attention",
-                gdf(df, "op_onnx_com.microsoft_Attention")
-                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention"),
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_Attention", 0)
+                + gdf(df, "op_onnx_com.microsoft_MultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PackedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_GroupQueryAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_PagedAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_LongformerAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedSelfAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_DecoderMaskedMultiHeadAttention", 0)
+                + gdf(df, "op_onnx_com.microsoft_SparseAttention", 0),
+            ),
+            n_node_layer_normalization=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_EmbedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_LayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SkipSimplifiedLayerNormalization", 0)
+                + gdf(df, "op_onnx_com.microsoft_SimplifiedLayerNormalization", 0),
+            ),
+            n_node_rotary_embedding=lambda df: gpreserve(
+                df,
+                "time_latency_eager",
+                gdf(df, "op_onnx_com.microsoft_GemmaRotaryEmbedding", 0)
+                + gdf(df, "op_onnx_com.microsoft_RotaryEmbedding", 0),
            ),
            n_node_control_flow=lambda df: gpreserve(
                df,
-                "op_onnx__If",
+                "time_latency_eager",
                (
                    gdf(df, "op_onnx__If", 0)
                    + gdf(df, "op_onnx__Scan", 0)
@@ -1693,7 +1746,7 @@ class CubeLogsPerformance(CubeLogs):
            ),
            n_node_scatter=lambda df: gpreserve(
                df,
-                "op_onnx__ScatterND",
+                "time_latency_eager",
                gdf(df, "op_onnx__ScatterND", 0) + gdf(df, "op_onnx__ScatterElements", 0),
            ),
            n_node_function=lambda df: gpreserve(
@@ -1706,13 +1759,13 @@ class CubeLogsPerformance(CubeLogs):
                df, "onnx_n_initializer", gdf(df, "onnx_n_initializer")
            ),
            n_node_constant=lambda df: gpreserve(
-                df, "op_onnx__Constant", gdf(df, "op_onnx__Constant")
+                df, "time_latency_eager", gdf(df, "op_onnx__Constant")
            ),
            n_node_shape=lambda df: gpreserve(
-                df, "op_onnx__Shape", gdf(df, "op_onnx__Shape")
+                df, "time_latency_eager", gdf(df, "op_onnx__Shape")
            ),
            n_node_expand=lambda df: gpreserve(
-                df, "op_onnx__Expand", gdf(df, "op_onnx__Expand")
+                df, "time_latency_eager", gdf(df, "op_onnx__Expand")
            ),
        )
        assert (
@@ -3,7 +3,6 @@ import numpy as np
 import onnx
 import torch
 from .helper import string_type, flatten_object
-from .onnx_helper import dtype_to_tensor_dtype
 from .cache_helper import is_cache_dynamic_registered


@@ -23,6 +22,7 @@ def make_feeds(
    use_numpy: bool = False,
    copy: bool = False,
    check_flatten: bool = True,
+    is_modelbuilder: bool = False,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce feeds expected
@@ -35,10 +35,15 @@ def make_feeds(
        by ``OrtValue``
    :param check_flatten: if True, checks the ``torch.utils._pytree.tree_flatten``
        returns the same number of outputs
+    :param is_modelbuilder: if True, the exporter is ModelBuilder and we need to reorder
+        the past_key_values inputs to match the expected order and to drop position_ids
    :return: feeds dictionary
    """
-    # position_ids is a special case because ModelBuilder does not usually use it.
-    # We use types to detect the best inputs.
+    # NOTE: position_ids is a special case because ModelBuilder does not usually use it:
+    # it is fused into the rotary embedding in GQA.
+    if is_modelbuilder and isinstance(inputs, dict):
+        inputs.pop("position_ids", None)  # Remove 'position_ids' if present.
+
    flat = flatten_object(inputs, drop_keys=True)
    assert (
        not check_flatten
@@ -76,39 +81,6 @@ def make_feeds(
        f"\n-- inputs={string_type(inputs, with_shape=True)}"
        f"\n-- names={names}"
    )
-    if len(names) < len(flat) and (
-        isinstance(proto, onnx.ModelProto) or hasattr(proto, "get_inputs")
-    ):
-
-        typed_names = (
-            [(i.name, i.type.tensor_type.elem_type) for i in proto.graph.input]
-            if isinstance(proto, onnx.ModelProto)
-            else [(i.name, name_type_to_onnx_dtype(i.type)) for i in proto.get_inputs()]
-        )
-
-        new_flat = []
-        pos = 0
-        for _name, dtype in typed_names:
-            assert isinstance(
-                dtype, int
-            ), f"Unexpected value for dtype={dtype!r}, type(proto)={type(proto)}"
-            itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            while dtype != itype:
-                pos += 1
-                if pos >= len(flat):
-                    break
-                itype = dtype_to_tensor_dtype(flat[pos].dtype)
-            if pos >= len(flat):
-                break
-            new_flat.append(flat[pos])
-            pos += 1
-        assert len(new_flat) == len(names), (
-            f"Unable to align expected input {names} with the given input, "
-            f"type(proto)={type(proto)}"
-            f"\n-- inputs: {string_type(inputs, with_shape=True)}"
-            f"\n-- typed_names: {typed_names}"
-        )
-        flat = new_flat

    if copy:
        flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
@@ -122,4 +94,49 @@ def make_feeds(
        elif isinstance(i, float):
            i = np.array(i, dtype=np.float32)
        new_flat.append(i)
+
+    # NOTE: ModelBuilder uses a different order for past_key_values;
+    # we need to reorder them to match the expected order.
+    if is_modelbuilder:
+        # We assume "past_key_values" appears in the input names when the
+        # exporter is ModelBuilder.
+        non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
+        past_kv_names = [n for n in names if "past_key_values" in n]
+        reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
+        names = non_past_kv_input_names + reorder_past_kv_names
    return dict(zip(names, new_flat))
+
+
+def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
+    """
+    Reorders the past_kvs for ModelBuilder to match the order expected
+    by PyTorch exported models.
+
+    .. note::
+        This function can take either the names or the actual tensors
+        as long as they are in a list.
+
+    Conceptually,
+
+    From::
+
+        [past_key_values.0.key, past_key_values.0.value,
+         past_key_values.1.key, past_key_values.1.value, ...]
+
+    To::
+
+        [past_key_values.0.key, past_key_values.1.key,
+         ..., past_key_values.0.value, past_key_values.1.value, ...]
+
+    :param past_kv: list of flattened inputs
+    :return: reordered list of flattened inputs
+    """
+    total_len = len(past_kv)
+    if total_len % 2 != 0:
+        raise ValueError("The length of past_key_values should be even.")
+    keys = []
+    values = []
+    for i in range(0, total_len, 2):
+        keys.append(past_kv[i])
+        values.append(past_kv[i + 1])
+    return keys + values
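
Note: a quick check of the reordering on names; the function behaves identically on tensors:

    names = [
        "past_key_values.0.key", "past_key_values.0.value",
        "past_key_values.1.key", "past_key_values.1.value",
    ]
    print(reorder_modelbuilder_cache_to_torch(names))
    # ['past_key_values.0.key', 'past_key_values.1.key',
    #  'past_key_values.0.value', 'past_key_values.1.value']
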
@@ -5,6 +5,8 @@ from . import (
    fill_mask,
    image_classification,
    image_text_to_text,
+    image_to_video,
+    mask_generation,
    mixture_of_expert,
    object_detection,
    sentence_similarity,
@@ -14,7 +16,6 @@ from . import (
    text_to_image,
    text2text_generation,
    zero_shot_image_classification,
-    mask_generation,
)

__TASKS__ = [
@@ -23,6 +24,8 @@ __TASKS__ = [
    fill_mask,
    image_classification,
    image_text_to_text,
+    image_to_video,
+    mask_generation,
    mixture_of_expert,
    object_detection,
    sentence_similarity,
@@ -32,7 +35,6 @@ __TASKS__ = [
    text_to_image,
    text2text_generation,
    zero_shot_image_classification,
-    mask_generation,
]


@@ -0,0 +1,127 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "image-to-video"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
+        # We cannot reduce.
+        return {}
+    check_hasattr(config, ("num_hidden_layers", "num_layers"))
+    kwargs = {}
+    if hasattr(config, "num_layers"):
+        kwargs["num_layers"] = min(config.num_layers, nhl())
+    if hasattr(config, "num_hidden_layers"):
+        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())
+
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    text_embed_dim: int,
+    latent_channels: int,
+    batch_size: int = 2,
+    image_height: int = 704,
+    image_width: int = 1280,
+    latent_frames: int = 1,
+    text_maxlen: int = 512,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-video``.
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    latent_height = image_height // 8
+    latent_width = image_width // 8
+    dtype = torch.float32
+
+    inputs = dict(
+        hidden_states=torch.randn(
+            batch_size,
+            latent_channels,
+            latent_frames,
+            latent_height,
+            latent_width,
+            dtype=dtype,
+        ),
+        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
+        encoder_hidden_states=torch.randn(
+            batch_size, text_maxlen, text_embed_dim, dtype=dtype
+        ),
+        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
+        fps=torch.tensor([16] * batch_size, dtype=dtype),
+        condition_mask=torch.randn(
+            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
+        ),
+    )
+    shapes = dict(
+        hidden_states={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+        timestep={0: "batch_size"},
+        encoder_hidden_states={0: "batch_size"},
+        padding_mask={0: "batch_size", 2: "height", 3: "width"},
+        fps={0: "batch_size"},
+        condition_mask={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            text_embed_dim=text_embed_dim,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            image_height=image_height,
+            image_width=image_width,
+            latent_frames=latent_frames,
+            text_maxlen=text_maxlen,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "in_channels", "text_embed_dim")
+    kwargs = dict(
+        text_embed_dim=1024 if config is None else config.text_embed_dim,
+        latent_channels=16 if config is None else config.in_channels - 1,
+        batch_size=1,
+        image_height=8 * 50,
+        image_width=8 * 80,
+        latent_frames=1,
+        text_maxlen=512,
+    )
+    return kwargs, get_inputs
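
Note: a usage sketch for the new task module, assuming no configuration is available so `random_input_kwargs` picks the typical dimensions; `model` and `config` are only forwarded by `get_inputs`, so None works here:

    kwargs, fn = random_input_kwargs(None)
    data = fn(model=None, config=None, **kwargs)
    print(sorted(data))  # ['dynamic_shapes', 'inputs', 'inputs2']
    print(data["inputs"]["hidden_states"].shape)
    # torch.Size([1, 16, 1, 50, 80]): batch, latent_channels, frames, h/8, w/8
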
@@ -254,6 +254,17 @@ def torch_export_patches(
    may appear ``AssertionError: Mutating module attribute _seen_tokens during export.``.
    It can be avoided by setting ``strict=False`` when calling :func:`torch.export.export`.
    """
+    if verbose:
+        print(f"[torch_export_patches] patch_sympy={patch_sympy!r}")
+        print(f" . patch_torch={patch_torch!r}")
+        print(f" . patch_transformers={patch_transformers!r}")
+        print(f" . patch_diffusers={patch_diffusers!r}")
+        print(f" . catch_constraints={catch_constraints!r}")
+        print(f" . stop_if_static={stop_if_static!r}")
+        print(f" . patch={patch!r}")
+        print(f" . custom_patches={custom_patches!r}")
+        print(f"[torch_export_patches] dump_rewriting={dump_rewriting!r}")
+
    if rewrite:
        from .patch_module import torch_export_rewrite

@@ -35,6 +35,9 @@ except ImportError:
    from ...ext_test_case import has_transformers
    from ...helpers.torch_helper import is_torchdynamo_exporting

+patch_is_initialized = pv.Version(transformers.__version__) > pv.Version("4.56.99")
+
+
if patch_masking_utils:
    # Introduced in 4.52
    from transformers.masking_utils import (
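
Note: comparing against 4.56.99 rather than 4.57 makes the flag true for 4.57 pre-releases as well, since dev and rc builds of 4.57 sort below 4.57.0 but above 4.56.99:

    import packaging.version as pv

    assert pv.Version("4.57.0") > pv.Version("4.56.99")
    assert pv.Version("4.57.0.dev0") > pv.Version("4.56.99")
    assert not (pv.Version("4.57.0.dev0") > pv.Version("4.57"))
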
@@ -213,6 +216,8 @@ if patch_DynamicLayer:
            new_shape[-2] = 0
            self.keys = torch.empty(new_shape, dtype=self.dtype, device=self.device)
            self.values = torch.empty(new_shape, dtype=self.dtype, device=self.device)
+            if patch_is_initialized:
+                self.is_initialized = True


 def _patch_make_causal_mask(