onnx-diagnostic 0.7.0__py3-none-any.whl → 0.7.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +213 -5
  3. onnx_diagnostic/export/dynamic_shapes.py +48 -20
  4. onnx_diagnostic/export/shape_helper.py +126 -0
  5. onnx_diagnostic/ext_test_case.py +31 -0
  6. onnx_diagnostic/helpers/cache_helper.py +42 -20
  7. onnx_diagnostic/helpers/config_helper.py +16 -1
  8. onnx_diagnostic/helpers/log_helper.py +1561 -177
  9. onnx_diagnostic/helpers/torch_helper.py +6 -2
  10. onnx_diagnostic/tasks/__init__.py +2 -0
  11. onnx_diagnostic/tasks/image_text_to_text.py +69 -18
  12. onnx_diagnostic/tasks/text_generation.py +17 -8
  13. onnx_diagnostic/tasks/text_to_image.py +91 -0
  14. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  15. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +144 -349
  16. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +87 -7
  17. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  18. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  19. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  20. onnx_diagnostic/torch_models/hghub/hub_api.py +73 -5
  21. onnx_diagnostic/torch_models/hghub/hub_data.py +7 -2
  22. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  23. onnx_diagnostic/torch_models/hghub/model_inputs.py +74 -14
  24. onnx_diagnostic/torch_models/validate.py +45 -16
  25. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/RECORD +29 -24
  27. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.7.0.dist-info → onnx_diagnostic-0.7.2.dist-info}/top_level.txt +0 -0

onnx_diagnostic/helpers/torch_helper.py

@@ -735,7 +735,8 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
                     [t.to(to_value) for t in value.key_cache],
                     [t.to(to_value) for t in value.value_cache],
                 )
-            )
+            ),
+            max_cache_len=value.max_cache_len,
         )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(

@@ -784,7 +785,10 @@ def torch_deepcopy(value: Any) -> Any:
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
         )
     if value.__class__.__name__ == "StaticCache":
-        return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
+        return make_static_cache(
+            torch_deepcopy(list(zip(value.key_cache, value.value_cache))),
+            max_cache_len=value.max_cache_len,
+        )
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
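
Both hunks above fix the same omission: a StaticCache was rebuilt without its original max_cache_len, so the copy's capacity could differ from the source cache. A minimal sketch of the repaired behavior (the import paths are an assumption based on this package's layout):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    # Build a StaticCache holding one layer of (key, value) tensors.
    cache = make_static_cache(
        [(torch.randn(2, 8, 16, 32), torch.randn(2, 8, 16, 32))],
        max_cache_len=16,
    )
    copy = torch_deepcopy(cache)
    # With 0.7.2 the copy keeps the original capacity.
    assert copy.max_cache_len == cache.max_cache_len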

onnx_diagnostic/tasks/__init__.py

@@ -11,6 +11,7 @@ from . import (
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 )

@@ -27,6 +28,7 @@ __TASKS__ = [
     summarization,
     text_classification,
     text_generation,
+    text_to_image,
     text2text_generation,
     zero_shot_image_classification,
 ]

onnx_diagnostic/tasks/image_text_to_text.py

@@ -96,10 +96,10 @@ def get_inputs(
                 for i in range(num_hidden_layers)
             ]
         ),
-        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
             torch.int64
         ),
-        pixel_values=torch.ones((batch_size, n_images, num_channels, width, height)).to(
+        image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
             torch.int64
         ),
     )

@@ -132,16 +132,30 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_key_value_heads", "num_attention_heads"),
-            "intermediate_size",
-            "hidden_size",
-            "vision_config",
-        )
+        if hasattr(config, "text_config"):
+            check_hasattr(
+                config.text_config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+            )
+            check_hasattr(config, "vision_config")
+            text_config = True
+        else:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "vision_config",
+            )
+            text_config = False
         check_hasattr(config.vision_config, "image_size", "num_channels")
     kwargs = dict(
         batch_size=2,

@@ -150,17 +164,54 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         head_dim=(
             16
             if config is None
-            else getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
+            else getattr(
+                config,
+                "head_dim",
+                (config.text_config.hidden_size if text_config else config.hidden_size)
+                // (
+                    config.text_config.num_attention_heads
+                    if text_config
+                    else config.num_attention_heads
+                ),
+            )
+        ),
+        dummy_max_token_id=(
+            31999
+            if config is None
+            else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
+        ),
+        num_hidden_layers=(
+            4
+            if config is None
+            else (
+                config.text_config.num_hidden_layers
+                if text_config
+                else config.num_hidden_layers
+            )
         ),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=4 if config is None else config.num_hidden_layers,
         num_key_value_heads=(
             8
             if config is None
-            else _pick(config, "num_key_value_heads", "num_attention_heads")
+            else (
+                _pick(config.text_config, "num_key_value_heads", "num_attention_heads")
+                if text_config
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            )
+        ),
+        intermediate_size=(
+            1024
+            if config is None
+            else (
+                config.text_config.intermediate_size
+                if text_config
+                else config.intermediate_size
+            )
+        ),
+        hidden_size=(
+            512
+            if config is None
+            else (config.text_config.hidden_size if text_config else config.hidden_size)
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
-        hidden_size=512 if config is None else config.hidden_size,
         width=224 if config is None else config.vision_config.image_size,
         height=224 if config is None else config.vision_config.image_size,
         num_channels=3 if config is None else config.vision_config.num_channels,
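
The rewritten random_input_kwargs handles multimodal configurations that nest their text parameters under config.text_config (as newer transformers releases do) while keeping the flat-config path. A condensed sketch of the same lookup pattern (the helper name is hypothetical, not part of the package):

    def _text_attr(config, name):
        # Prefer config.text_config.<name> when present, else config.<name>.
        sub = getattr(config, "text_config", None)
        return getattr(sub, name) if sub is not None else getattr(config, name)

    head_dim = getattr(
        config,
        "head_dim",
        _text_attr(config, "hidden_size") // _text_attr(config, "num_attention_heads"),
    )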

onnx_diagnostic/tasks/text_generation.py

@@ -109,7 +109,7 @@ def get_inputs(
     sequence_length2 = seq_length_multiple

     shapes = {
-        "input_ids": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "input_ids": {0: batch, 1: "sequence_length"},
         "attention_mask": {
             0: batch,
             1: "cache+seq",  # cache_length + seq_length

@@ -176,8 +176,10 @@ def get_inputs(
         "attention_mask": {0: batch, 2: "seq"},
         "cache_position": {0: "seq"},
         "past_key_values": [
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            # [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
+            [{0: batch} for _ in range(num_hidden_layers)],
         ],
     }
     inputs = dict(

@@ -188,18 +190,25 @@ def get_inputs(
             (batch_size, num_key_value_heads, sequence_length2, head_dim)
         ).to(torch.bool),
         cache_position=torch.arange(sequence_length2).to(torch.int64),
-        past_key_values=make_cache(
+        past_key_values=make_static_cache(
             [
                 (
                     torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, head_dim
+                        batch_size,
+                        num_key_value_heads,
+                        sequence_length + sequence_length2,
+                        head_dim,
                     ),
                     torch.randn(
-                        batch_size, num_key_value_heads, sequence_length, head_dim
+                        batch_size,
+                        num_key_value_heads,
+                        sequence_length + sequence_length2,
+                        head_dim,
                     ),
                 )
                 for i in range(num_hidden_layers)
-            ]
+            ],
+            max_cache_len=max(sequence_length + sequence_length2, head_dim),
         ),
     )
     else:

@@ -230,7 +239,7 @@ def get_inputs(
         position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
         .to(torch.int64)
         .expand((batch_size, -1)),
-        past_key_values=make_cache(
+        past_key_values=make_cache(  # type: ignore[operator]
             [
                 (
                     torch.randn(

onnx_diagnostic/tasks/text_to_image.py (new file)

@@ -0,0 +1,91 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr, pick
+
+__TASK__ = "text-to-image"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "sample_size", "cross_attention_dim")
+    kwargs = dict(
+        sample_size=min(config["sample_size"], 32),
+        cross_attention_dim=min(config["cross_attention_dim"], 64),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    cache_length: int,
+    in_channels: int,
+    sample_size: int,
+    cross_attention_dim: int,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-to-image``.
+    Example:
+
+    ::
+
+        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
+        timestep:T7s=101
+        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    shapes = {
+        "sample": {0: batch},
+        "timestep": {},
+        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
+    }
+    inputs = dict(
+        sample=torch.randn((batch_size, sequence_length, sample_size, sample_size)).to(
+            torch.float32
+        ),
+        timestep=torch.tensor([101], dtype=torch.int64),
+        encoder_hidden_states=torch.randn(
+            (batch_size, sequence_length, cross_attention_dim)
+        ).to(torch.float32),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length,
+            cache_length=cache_length + 1,
+            in_channels=in_channels,
+            sample_size=sample_size,
+            cross_attention_dim=cross_attention_dim,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=pick(config, "in_channels", 4),
+        cache_length=77,
+        in_channels=pick(config, "in_channels", 4),
+        sample_size=pick(config, "sample_size", 32),
+        cross_attention_dim=pick(config, "cross_attention_dim", 64),
+    )
+    return kwargs, get_inputs
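
The new module follows the contract shared by the other task modules: random_input_kwargs returns default dimensions plus the get_inputs callable. A usage sketch under that assumption (it also assumes pick falls back to its default when config is None, which the code above relies on):

    from onnx_diagnostic.tasks import text_to_image

    kwargs, fct = text_to_image.random_input_kwargs(None)  # None -> typical dims
    data = fct(model=None, config=None, add_second_input=True, **kwargs)
    inputs, shapes = data["inputs"], data["dynamic_shapes"]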

onnx_diagnostic/torch_export_patches/onnx_export_errors.py

@@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo

 @contextlib.contextmanager
 def register_additional_serialization_functions(
-    patch_transformers: bool = False, verbose: int = 0
+    patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
 ) -> Callable:
     """The necessary modifications to run the fx Graph."""
-    fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
-    done = register_cache_serialization(verbose=verbose)
+    fct_callable = (
+        replacement_before_exporting
+        if patch_transformers or patch_diffusers
+        else (lambda x: x)
+    )
+    done = register_cache_serialization(
+        patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
+    )
     try:
         yield fct_callable
     finally:
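
With the new flag, cache serialization can be registered for diffusers classes without enabling any transformers patch. A hypothetical usage sketch (the import path follows the package's public API):

    from onnx_diagnostic.torch_export_patches import (
        register_additional_serialization_functions,
    )

    # Register (de)serialization for diffusers caches only, no other patches.
    with register_additional_serialization_functions(patch_diffusers=True) as f:
        ...  # torch.export.export can now flatten the registered classes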

@@ -150,6 +156,7 @@ def torch_export_patches(
     patch_sympy: bool = True,
     patch_torch: bool = True,
     patch_transformers: bool = False,
+    patch_diffusers: bool = False,
     catch_constraints: bool = True,
     stop_if_static: int = 0,
     verbose: int = 0,

@@ -165,6 +172,7 @@ def torch_export_patches(
     :param patch_sympy: fix missing method ``name`` for IntegerConstant
     :param patch_torch: patches :epkg:`torch` with supported implementation
     :param patch_transformers: patches :epkg:`transformers` with supported implementation
+    :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
     :param catch_constraints: catch constraints related to dynamic shapes,
         as a result, some dynamic dimension may turn into static ones,
         the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``

@@ -174,8 +182,8 @@ def torch_export_patches(
         and show a stack trace indicating the exact location of the issue,
         ``if stop_if_static > 1``, more methods are replace to catch more
         issues
-    :param patch: if False, disable all patches except the registration of
-        serialization function
+    :param patch: if False, disable all patches but keeps the registration of
+        serialization functions if other patch functions are enabled
     :param custom_patches: to apply custom patches,
         every patched class must define static attributes
         ``_PATCHES_``, ``_PATCHED_CLASS_``

@@ -249,6 +257,7 @@ def torch_export_patches(
         patch_sympy=patch_sympy,
         patch_torch=patch_torch,
         patch_transformers=patch_transformers,
+        patch_diffusers=patch_diffusers,
         catch_constraints=catch_constraints,
         stop_if_static=stop_if_static,
         verbose=verbose,

@@ -261,7 +270,11 @@ def torch_export_patches(
         pass
     elif not patch:
         fct_callable = lambda x: x  # noqa: E731
-        done = register_cache_serialization(verbose=verbose)
+        done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )
         try:
             yield fct_callable
         finally:

@@ -281,7 +294,11 @@ def torch_export_patches(
         # caches
         ########

-        cache_done = register_cache_serialization(verbose=verbose)
+        cache_done = register_cache_serialization(
+            patch_transformers=patch_transformers,
+            patch_diffusers=patch_diffusers,
+            verbose=verbose,
+        )

         #############
         # patch sympy
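
Taken together, torch_export_patches now threads patch_diffusers down to every register_cache_serialization call, including the patch=False and register-only code paths. A hedged end-to-end sketch (model and inputs are placeholders defined elsewhere):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    with torch_export_patches(patch_transformers=True, patch_diffusers=True) as f:
        ep = torch.export.export(model, (f(inputs),))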