onnx-diagnostic 0.7.3__py3-none-any.whl → 0.7.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +82 -12
  3. onnx_diagnostic/export/shape_helper.py +71 -0
  4. onnx_diagnostic/helpers/_log_helper.py +461 -0
  5. onnx_diagnostic/helpers/cache_helper.py +11 -1
  6. onnx_diagnostic/helpers/log_helper.py +404 -315
  7. onnx_diagnostic/reference/ops/op_cast_like.py +12 -8
  8. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  9. onnx_diagnostic/tasks/feature_extraction.py +92 -7
  10. onnx_diagnostic/tasks/fill_mask.py +6 -2
  11. onnx_diagnostic/tasks/image_classification.py +7 -3
  12. onnx_diagnostic/tasks/image_text_to_text.py +6 -2
  13. onnx_diagnostic/tasks/mixture_of_expert.py +1 -1
  14. onnx_diagnostic/tasks/object_detection.py +7 -3
  15. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  16. onnx_diagnostic/tasks/summarization.py +6 -2
  17. onnx_diagnostic/tasks/text2text_generation.py +8 -4
  18. onnx_diagnostic/tasks/text_classification.py +6 -2
  19. onnx_diagnostic/tasks/text_generation.py +5 -3
  20. onnx_diagnostic/tasks/text_to_image.py +6 -2
  21. onnx_diagnostic/tasks/zero_shot_image_classification.py +6 -2
  22. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +63 -7
  23. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +188 -51
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +6 -1
  25. onnx_diagnostic/torch_models/validate.py +49 -10
  26. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/METADATA +1 -1
  27. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/RECORD +30 -29
  28. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/WHEEL +0 -0
  29. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/licenses/LICENSE.txt +0 -0
  30. {onnx_diagnostic-0.7.3.dist-info → onnx_diagnostic-0.7.5.dist-info}/top_level.txt +0 -0
onnx_diagnostic/reference/ops/op_cast_like.py
@@ -1,13 +1,17 @@
  from onnx.onnx_pb import TensorProto
  from onnx.reference.op_run import OpRun
- from onnx.reference.ops.op_cast import (
-     bfloat16,
-     cast_to,
-     float8e4m3fn,
-     float8e4m3fnuz,
-     float8e5m2,
-     float8e5m2fnuz,
- )
+
+ try:
+     from onnx.reference.ops.op_cast import (
+         bfloat16,
+         cast_to,
+         float8e4m3fn,
+         float8e4m3fnuz,
+         float8e5m2,
+         float8e5m2fnuz,
+     )
+ except ImportError:
+     from onnx.reference.ops.op_cast import cast_to
  from ...helpers.onnx_helper import np_dtype_to_tensor_dtype


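The new guarded import keeps the reference operator usable with onnx versions whose reference implementation does not expose the low-precision cast helpers. A minimal sketch of the same fallback pattern, assuming only that `cast_to` is always importable while the specialized helpers may be missing (the `HAS_LOW_PRECISION` flag is ours, not the package's):

```python
# Guarded-import sketch: optional helpers simply disable a feature when absent.
try:
    # Available only in some onnx versions (assumption for this sketch).
    from onnx.reference.ops.op_cast import bfloat16, cast_to  # noqa: F401
    HAS_LOW_PRECISION = True
except ImportError:
    # Fallback: the generic cast helper is imported in both branches of the diff.
    from onnx.reference.ops.op_cast import cast_to  # noqa: F401
    HAS_LOW_PRECISION = False

print("low-precision cast helpers available:", HAS_LOW_PRECISION)
```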
onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -33,7 +33,7 @@ def get_inputs(
      head_dim: int,
      batch_size: int = 2,
      sequence_length: int = 30,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -132,6 +132,9 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
@@ -144,7 +147,8 @@ def get_inputs(
              decoder_layers=decoder_layers,
              head_dim=head_dim,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
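This change repeats across most task modules: `add_second_input` moves from a boolean flag to an integer that also serves as the length offset for the second input set, and the recursive call passes `add_second_input=0` so no third set is generated. A minimal sketch of just that control flow, using a hypothetical `make_inputs` stand-in for the task-specific `get_inputs` functions:

```python
from typing import Any, Dict


def make_inputs(batch_size: int, sequence_length: int, add_second_input: int = 1) -> Dict[str, Any]:
    # Hypothetical stand-in for a task's get_inputs function.
    inputs = {"batch_size": batch_size, "sequence_length": sequence_length}
    res: Dict[str, Any] = {"inputs": inputs}
    if add_second_input:
        assert add_second_input > 0, f"Not implemented for add_second_input={add_second_input}."
        # Second set: bigger batch, sequence length shifted by add_second_input;
        # add_second_input=0 stops the recursion after one extra set.
        res["inputs2"] = make_inputs(
            batch_size=batch_size + 1,
            sequence_length=sequence_length + add_second_input,
            add_second_input=0,
        )["inputs"]
    return res


print(make_inputs(2, 30, add_second_input=3))
# {'inputs': {'batch_size': 2, 'sequence_length': 30},
#  'inputs2': {'batch_size': 3, 'sequence_length': 33}}
```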
onnx_diagnostic/tasks/feature_extraction.py
@@ -1,17 +1,15 @@
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
  from ..helpers.config_helper import update_config, check_hasattr
+ from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

  __TASK__ = "feature-extraction"


  def reduce_model_config(config: Any) -> Dict[str, Any]:
      """Reduces a model size."""
-     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
-     kwargs = dict(
-         num_hidden_layers=min(config.num_hidden_layers, 2),
-         num_attention_heads=min(config.num_attention_heads, 4),
-     )
+     check_hasattr(config, "num_hidden_layers")
+     kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
      update_config(config, kwargs)
      return kwargs

@@ -22,7 +20,13 @@ def get_inputs(
      batch_size: int,
      sequence_length: int,
      dummy_max_token_id: int,
-     add_second_input: bool = False,
+     sequence_length2: int = 3,
+     decoder_attention_heads: Optional[int] = None,
+     encoder_attention_heads: Optional[int] = None,
+     encoder_ffn_dim: Optional[int] = None,
+     decoder_ffn_dim: Optional[int] = None,
+     num_hidden_layers: Optional[int] = None,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -50,14 +54,84 @@ def get_inputs(
          ),
          attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
      )
+     if (
+         encoder_attention_heads
+         and decoder_attention_heads
+         and encoder_ffn_dim
+         and decoder_ffn_dim
+         and num_hidden_layers
+     ):
+         inputs["past_key_values"] = make_encoder_decoder_cache(
+             make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(
+                             batch_size,
+                             encoder_attention_heads,
+                             sequence_length,
+                             encoder_ffn_dim,
+                         ),
+                         torch.randn(
+                             batch_size,
+                             encoder_attention_heads,
+                             sequence_length,
+                             encoder_ffn_dim,
+                         ),
+                     )
+                     for i in range(num_hidden_layers)
+                 ]
+             ),
+             make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(
+                             batch_size,
+                             decoder_attention_heads,
+                             sequence_length2,
+                             decoder_ffn_dim,
+                         ),
+                         torch.randn(
+                             batch_size,
+                             decoder_attention_heads,
+                             sequence_length2,
+                             decoder_ffn_dim,
+                         ),
+                     )
+                     for i in range(num_hidden_layers)
+                 ]
+             ),
+         )
+         cache_length = "cache_length_key"
+         cache_length2 = "cache_length_val"
+         shapes["past_key_values"] = [ # type: ignore[assignment]
+             [
+                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+             ],
+             [
+                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
+             ],
+         ]
+
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              dummy_max_token_id=dummy_max_token_id,
+             sequence_length2=sequence_length2,
+             decoder_attention_heads=decoder_attention_heads,
+             encoder_attention_heads=encoder_attention_heads,
+             encoder_ffn_dim=encoder_ffn_dim,
+             decoder_ffn_dim=decoder_ffn_dim,
+             num_hidden_layers=num_hidden_layers,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
@@ -76,4 +150,15 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
          sequence_length=30,
          dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
      )
+     for att in [
+         "decoder_attention_heads",
+         "encoder_attention_heads",
+         "encoder_ffn_dim",
+         "decoder_ffn_dim",
+         "num_hidden_layers",
+     ]:
+         if hasattr(config, att):
+             kwargs[att] = getattr(config, att)
+     kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
+     print(kwargs)
      return kwargs, get_inputs
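The feature-extraction task can now also feed a `past_key_values` input built from the package's cache helpers, together with dynamic-shape specs keyed by string dimension names. A rough sketch of how such a cache could be assembled with `make_dynamic_cache` and `make_encoder_decoder_cache` (the imports mirror the diff above); the tensor sizes below are illustrative, not the package's defaults:

```python
import torch

# Helpers shipped with onnx_diagnostic (see the import added in the diff).
from onnx_diagnostic.helpers.cache_helper import (
    make_dynamic_cache,
    make_encoder_decoder_cache,
)

batch_size, num_layers = 2, 2
encoder_heads, decoder_heads = 4, 4
seq_len, seq_len2 = 30, 3
head_size = 64  # illustrative last dimension

past_key_values = make_encoder_decoder_cache(
    # Self-attention cache: one (key, value) tensor pair per layer.
    make_dynamic_cache(
        [
            (
                torch.randn(batch_size, encoder_heads, seq_len, head_size),
                torch.randn(batch_size, encoder_heads, seq_len, head_size),
            )
            for _ in range(num_layers)
        ]
    ),
    # Cross-attention cache, with its own (shorter) cache length.
    make_dynamic_cache(
        [
            (
                torch.randn(batch_size, decoder_heads, seq_len2, head_size),
                torch.randn(batch_size, decoder_heads, seq_len2, head_size),
            )
            for _ in range(num_layers)
        ]
    ),
)
print(type(past_key_values))
```

The matching `dynamic_shapes` entry in the diff marks axis 0 as the batch dimension and axis 2 as a named cache length for every key and value tensor, with separate names for the self-attention and cross-attention caches.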
onnx_diagnostic/tasks/fill_mask.py
@@ -22,7 +22,7 @@ def get_inputs(
      batch_size: int,
      sequence_length: int,
      dummy_max_token_id: int,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -54,12 +54,16 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              dummy_max_token_id=dummy_max_token_id,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/image_classification.py
@@ -34,7 +34,7 @@ def get_inputs(
      input_channels: int,
      batch_size: int = 2,
      dynamic_rope: bool = False,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -75,14 +75,18 @@ def get_inputs(
          shapes["interpolate_pos_encoding"] = None # type: ignore[assignment]
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
-             input_width=input_width + 1,
-             input_height=input_height + 1,
+             input_width=input_width + add_second_input,
+             input_height=input_height + add_second_input,
              input_channels=input_channels,
              batch_size=batch_size + 1,
              dynamic_rope=dynamic_rope,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/image_text_to_text.py
@@ -32,7 +32,7 @@ def get_inputs(
      sequence_length2: int = 3,
      n_images: int = 2,
      dynamic_rope: bool = False,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -105,6 +105,9 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
@@ -116,10 +119,11 @@ def get_inputs(
              height=height,
              num_channels=num_channels,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              sequence_length2=sequence_length2 + 1,
              n_images=n_images + 1,
              dynamic_rope=dynamic_rope,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/mixture_of_expert.py
@@ -41,7 +41,7 @@ def get_inputs(
      sequence_length2: int = 3,
      n_images: int = 2,
      dynamic_rope: bool = False,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
onnx_diagnostic/tasks/object_detection.py
@@ -27,7 +27,7 @@ def get_inputs(
      input_channels: int,
      batch_size: int = 2,
      dynamic_rope: bool = False,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -65,14 +65,18 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
-             input_width=input_width + 1,
-             input_height=input_height + 1,
+             input_width=input_width + add_second_input,
+             input_height=input_height + add_second_input,
              input_channels=input_channels,
              batch_size=batch_size + 1,
              dynamic_rope=dynamic_rope,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/sentence_similarity.py
@@ -22,7 +22,7 @@ def get_inputs(
      batch_size: int,
      sequence_length: int,
      dummy_max_token_id: int,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -54,12 +54,16 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              dummy_max_token_id=dummy_max_token_id,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/summarization.py
@@ -29,7 +29,7 @@ def get_inputs(
      batch_size: int = 2,
      sequence_length: int = 30,
      sequence_length2: int = 3,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -144,6 +144,9 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
@@ -154,8 +157,9 @@ def get_inputs(
              head_dim_encoder=head_dim_encoder,
              head_dim_decoder=head_dim_decoder,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              sequence_length2=sequence_length2 + 1,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/text2text_generation.py
@@ -30,7 +30,7 @@ def get_inputs(
      batch_size: int = 2,
      sequence_length: int = 30,
      sequence_length2: int = 3,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -69,8 +69,8 @@ def get_inputs(
      ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
      batch = torch.export.Dim("batch", min=1, max=1024)
      seq_length = "seq_length" # torch.export.Dim("seq_length", min=1, max=4096)
-     cache_length = "cache_length_key" # torch.export.Dim("cache_length", min=1, max=4096)
-     cache_length2 = "cache_length_val" # torch.export.Dim("cache_length2", min=1, max=4096)
+     cache_length = "cache_length_key"
+     cache_length2 = "cache_length_val"

      shapes = {
          "input_ids": {0: batch, 1: seq_length},
@@ -149,6 +149,9 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
@@ -160,8 +163,9 @@ def get_inputs(
              head_dim_decoder=head_dim_decoder,
              encoder_dim=encoder_dim,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              sequence_length2=sequence_length2 + 1,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/text_classification.py
@@ -22,7 +22,7 @@ def get_inputs(
      batch_size: int,
      sequence_length: int,
      dummy_max_token_id: int,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -54,12 +54,16 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              dummy_max_token_id=dummy_max_token_id,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/text_generation.py
@@ -72,7 +72,7 @@ def get_inputs(
      num_key_value_heads: Optional[int] = None,
      head_dim: Optional[int] = None,
      cls_cache: Optional[Union[type, str]] = None,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -260,13 +260,15 @@ def get_inputs(
              config=config,
              dummy_max_token_id=dummy_max_token_id,
              num_hidden_layers=num_hidden_layers,
-             batch_size=batch_size + 1,
+             batch_size=(batch_size + 1) if add_second_input > 0 else 1,
              sequence_length=sequence_length + 1,
-             sequence_length2=sequence_length2 + 1,
+             sequence_length2=sequence_length2
+             + (add_second_input if add_second_input > 0 else -add_second_input),
              dynamic_rope=dynamic_rope,
              num_key_value_heads=num_key_value_heads,
              head_dim=head_dim,
              cls_cache=cls_cache,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
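Unlike the other tasks, text_generation does not assert `add_second_input > 0`: a negative value appears to request a second input set with the batch size forced to 1 while `sequence_length2` is still offset by the absolute value. A small worked sketch of just that branch logic (the function name and sample values are ours; only the expressions are taken from the diff above):

```python
def second_input_dims(batch_size: int, sequence_length2: int, add_second_input: int):
    # Branch logic copied from the get_inputs recursion in text_generation.py.
    batch_size2 = (batch_size + 1) if add_second_input > 0 else 1
    sequence_length2_new = sequence_length2 + (
        add_second_input if add_second_input > 0 else -add_second_input
    )
    return batch_size2, sequence_length2_new


print(second_input_dims(2, 3, add_second_input=1))   # (3, 4): bigger batch, longer cache
print(second_input_dims(2, 3, add_second_input=-2))  # (1, 5): batch forced to 1, offset by |value|
```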
onnx_diagnostic/tasks/text_to_image.py
@@ -25,7 +25,7 @@ def get_inputs(
      in_channels: int,
      sample_size: int,
      cross_attention_dim: int,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -58,15 +58,19 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
              batch_size=batch_size + 1,
              sequence_length=sequence_length,
-             cache_length=cache_length + 1,
+             cache_length=cache_length + add_second_input,
              in_channels=in_channels,
              sample_size=sample_size,
              cross_attention_dim=cross_attention_dim,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/tasks/zero_shot_image_classification.py
@@ -34,7 +34,7 @@ def get_inputs(
      input_height: int = 224,
      input_channels: int = 3,
      batch_size_image=3,
-     add_second_input: bool = False,
+     add_second_input: int = 1,
      **kwargs, # unused
  ):
      """
@@ -87,16 +87,20 @@ def get_inputs(
      )
      res = dict(inputs=inputs, dynamic_shapes=shapes)
      if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
          res["inputs2"] = get_inputs(
              model=model,
              config=config,
              dummy_max_token_id=dummy_max_token_id,
              batch_size=batch_size + 1,
-             sequence_length=sequence_length + 1,
+             sequence_length=sequence_length + add_second_input,
              input_width=input_width,
              input_height=input_height,
              input_channels=input_channels,
              batch_size_image=batch_size_image + 1,
+             add_second_input=0,
              **kwargs,
          )["inputs"]
      return res
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
@@ -16,6 +16,8 @@ def get_function(name: str) -> Tuple[type, Callable]:
      module_name = ".".join(spl[:-1])
      fname = spl[-1]
      mod = importlib.import_module(module_name)
+     if not hasattr(mod, fname):
+         return None, None
      return mod, getattr(mod, fname)


@@ -33,12 +35,16 @@ def get_patches(mod, verbose: int = 0) -> Tuple[str, List[Any]]:
          doc = v.__doc__.lstrip()
          if doc.startswith("manual patch"):
              continue
-         reg = re.compile("[[]patch:([a-z_A-Z.]+)[]]")
+         reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
          fall = reg.findall(doc)
          assert (
              len(fall) == 1
          ), f"Unable to find patching information for {v} in \n{doc}"
          fmod, f = get_function(fall[0])
+         if fmod is None and f is None:
+             # The function does not exist in this version of transformers.
+             # No patch is needed.
+             continue
          to_patch.append({"module": fmod, "function": f, "patch": v})

      name = mod.__name__
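With this change, a patch whose `[patch:...]` docstring marker points at a function that no longer exists in the installed transformers is silently skipped instead of raising. A rough, self-contained sketch of that lookup-and-skip behaviour (this `get_function` uses `rpartition` instead of the package's `split`, and the docstring and targets are invented for illustration):

```python
import importlib
import re


def get_function(name):
    # Mirrors the behaviour added in the diff: return (None, None) instead of
    # raising when the target attribute is missing from its module.
    module_name, _, fname = name.rpartition(".")
    mod = importlib.import_module(module_name)
    if not hasattr(mod, fname):
        return None, None
    return mod, getattr(mod, fname)


# The docstring marker format follows the regex in the diff above.
reg = re.compile("[\\[]patch:([a-z_A-Z.]+)[\\]]")
doc = "patched version of sqrt [patch:math.sqrt]"
(target,) = reg.findall(doc)

mod, func = get_function(target)               # exists -> (module, function)
missing = get_function("math.does_not_exist")  # absent -> (None, None), patch skipped
print(target, func, missing)
```

The regex itself only changes from `[[]...[]]` to `[\[]...[\]]`; both match a literal bracket, the new form just escapes it explicitly.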
@@ -420,7 +426,11 @@ def torch_export_patches(
              patch_transformers_list, verbose=verbose
          )

-         if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+         if (
+             masking_utils
+             and patch_transformers_list.patch_masking_utils
+             and hasattr(masking_utils, "_vmap_for_bhqkv")
+         ):
              if verbose:
                  print(
                      "[torch_export_patches] patches "
@@ -429,6 +439,27 @@ def torch_export_patches(
              f_transformers__vmap_for_bhqkv = masking_utils._vmap_for_bhqkv
              masking_utils._vmap_for_bhqkv = patch_transformers_list.patched__vmap_for_bhqkv

+         if (
+             masking_utils
+             and patch_transformers_list.patch_masking_utils
+             and hasattr(masking_utils, "eager_mask")
+         ):
+             if verbose:
+                 print(
+                     "[torch_export_patches] patches "
+                     "transformers.masking_utils.eager_mask"
+                 )
+             f_transformers_eager_mask = masking_utils.eager_mask
+             masking_utils.eager_mask = patch_transformers_list.patched_eager_mask
+             if (
+                 "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                 and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+                 == f_transformers_eager_mask
+             ):
+                 masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                     patch_transformers_list.patched_eager_mask
+                 )
+
          if custom_patches:
              if verbose:
                  print("[torch_export_patches] applies custom patches")
@@ -511,7 +542,7 @@ def torch_export_patches(

          if custom_patches:
              if verbose:
-                 print("[torch_export_patches] unpatch custom patches")
+                 print("[torch_export_patches] unpatches custom patches")
              unpatch_module_or_classes(
                  custom_patches, revert_custom_patches_info, verbose=verbose
              )
@@ -526,18 +557,43 @@ def torch_export_patches(
              except ImportError:
                  masking_utils = None
              if verbose:
-                 print("[torch_export_patches] unpatch transformers")
+                 print("[torch_export_patches] unpatches transformers")
              unpatch_module_or_classes(
                  patch_transformers_list, revert_patches_info, verbose=verbose
              )

-             if masking_utils and hasattr(masking_utils, "_vmap_for_bhqkv"):
+             if (
+                 masking_utils
+                 and patch_transformers_list.patch_masking_utils
+                 and hasattr(masking_utils, "_vmap_for_bhqkv")
+             ):
+                 masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
                  if verbose:
                      print(
-                         "[torch_export_patches] unpatch "
+                         "[torch_export_patches] restored "
                          "transformers.masking_utils._vmap_for_bhqkv"
                      )
-                 masking_utils._vmap_for_bhqkv = f_transformers__vmap_for_bhqkv
+
+             if (
+                 masking_utils
+                 and patch_transformers_list.patch_masking_utils
+                 and hasattr(masking_utils, "eager_mask")
+             ):
+                 f_transformers_eager_mask = masking_utils.eager_mask
+                 masking_utils.eager_mask = f_transformers_eager_mask
+                 if (
+                     "eager" in masking_utils.ALL_MASK_ATTENTION_FUNCTIONS
+                     and masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"]
+                     == patch_transformers_list.patched_eager_mask
+                 ):
+                     masking_utils.ALL_MASK_ATTENTION_FUNCTIONS["eager"] = (
+                         f_transformers_eager_mask
+                     )
+                 if verbose:
+                     print(
+                         "[torch_export_patches] restored "
+                         "transformers.masking_utils.eager_mask"
+                     )

      ########
      # caches
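All of these patches are applied and reverted by the `torch_export_patches` context manager defined in this file. A typical usage sketch; the import path follows the package layout above, but the exact keyword arguments (`patch_transformers`, `verbose`) are assumptions that may differ between versions:

```python
import torch
from onnx_diagnostic.torch_export_patches import torch_export_patches


class Tiny(torch.nn.Module):
    def forward(self, x):
        return x.sigmoid() * 2


# The patches (including the masking_utils ones above) are active only inside
# the with-block and are reverted on exit, even if the export fails.
with torch_export_patches(patch_transformers=True, verbose=1):
    ep = torch.export.export(Tiny(), (torch.randn(2, 3),))
print(ep.graph)
```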