onnx-diagnostic 0.7.1__py3-none-any.whl → 0.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +22 -5
  3. onnx_diagnostic/ext_test_case.py +31 -0
  4. onnx_diagnostic/helpers/cache_helper.py +23 -12
  5. onnx_diagnostic/helpers/config_helper.py +16 -1
  6. onnx_diagnostic/helpers/log_helper.py +308 -83
  7. onnx_diagnostic/helpers/rt_helper.py +11 -1
  8. onnx_diagnostic/helpers/torch_helper.py +7 -3
  9. onnx_diagnostic/tasks/__init__.py +2 -0
  10. onnx_diagnostic/tasks/text_generation.py +17 -8
  11. onnx_diagnostic/tasks/text_to_image.py +91 -0
  12. onnx_diagnostic/torch_export_patches/eval/__init__.py +3 -1
  13. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +24 -7
  14. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +148 -351
  15. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +89 -10
  16. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  17. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  18. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +259 -0
  19. onnx_diagnostic/torch_models/hghub/hub_api.py +15 -4
  20. onnx_diagnostic/torch_models/hghub/hub_data.py +1 -0
  21. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +28 -0
  22. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -5
  23. onnx_diagnostic/torch_models/validate.py +36 -12
  24. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/METADATA +26 -1
  25. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/RECORD +28 -24
  26. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/WHEEL +0 -0
  27. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/licenses/LICENSE.txt +0 -0
  28. {onnx_diagnostic-0.7.1.dist-info → onnx_diagnostic-0.7.3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,91 @@
1
+ from typing import Any, Callable, Dict, Optional, Tuple
2
+ import torch
3
+ from ..helpers.config_helper import update_config, check_hasattr, pick
4
+
5
# Name of the task whose inputs this module generates.
__TASK__: str = "text-to-image"
6
+
7
+
8
def reduce_model_config(config: Any) -> Dict[str, Any]:
    """
    Shrinks the dimensions of a ``text-to-image`` model configuration.

    ``sample_size`` is capped at 32 and ``cross_attention_dim`` at 64.
    The capped values are then passed to ``update_config`` together with
    the configuration.

    :param config: configuration providing ``sample_size`` and
        ``cross_attention_dim`` (checked with ``check_hasattr``, then read
        by subscript)
    :return: the capped values keyed by configuration field name
    """
    check_hasattr(config, "sample_size", "cross_attention_dim")
    caps = (("sample_size", 32), ("cross_attention_dim", 64))
    reduced = {name: min(config[name], cap) for name, cap in caps}
    update_config(config, reduced)
    return reduced
17
+
18
+
19
def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    batch_size: int,
    sequence_length: int,
    cache_length: int,
    in_channels: int,
    sample_size: int,
    cross_attention_dim: int,
    add_second_input: bool = False,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``text-to-image``.

    :param model: model the inputs are meant for, not used to build them
    :param config: model configuration, not used to build them
    :param batch_size: first dimension of ``sample`` and ``encoder_hidden_states``
    :param sequence_length: accepted for signature compatibility with the
        other tasks; it does not determine any dimension here
    :param cache_length: second dimension of ``encoder_hidden_states``
        (declared dynamic as ``encoder_length``)
    :param in_channels: second dimension of ``sample``
    :param sample_size: third and fourth dimensions of ``sample``
    :param cross_attention_dim: last dimension of ``encoder_hidden_states``
    :param add_second_input: if True, adds a second set of inputs under key
        ``inputs2`` with a batch size and an encoder length both increased by
        one, so that the dynamic dimensions take a different value
    :return: dictionary with keys ``inputs``, ``dynamic_shapes`` and,
        optionally, ``inputs2``

    Example:

    ::

        sample:T10s2x4x96x96[-3.7734375,4.359375:A-0.043463995395642184]
        timestep:T7s=101
        encoder_hidden_states:T10s2x77x1024[-6.58203125,13.0234375:A-0.16780663634440257]
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    batch = "batch"
    shapes = {
        "sample": {0: batch},
        "timestep": {},
        "encoder_hidden_states": {0: batch, 1: "encoder_length"},
    }
    inputs = dict(
        # The channel dimension must match the model's in_channels
        # (it previously used sequence_length, leaving in_channels unused).
        sample=torch.randn(
            (batch_size, in_channels, sample_size, sample_size), dtype=torch.float32
        ),
        timestep=torch.tensor([101], dtype=torch.int64),
        # The second dimension is declared dynamic ("encoder_length") and is
        # driven by cache_length, so that inputs2 (cache_length + 1) actually
        # exercises that dynamic dimension.
        encoder_hidden_states=torch.randn(
            (batch_size, cache_length, cross_attention_dim), dtype=torch.float32
        ),
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    if add_second_input:
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            batch_size=batch_size + 1,
            sequence_length=sequence_length,
            cache_length=cache_length + 1,
            in_channels=in_channels,
            sample_size=sample_size,
            cross_attention_dim=cross_attention_dim,
            **kwargs,
        )["inputs"]
    return res
73
+
74
+
75
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Returns the keyword arguments to pass to :func:`get_inputs`, together
    with :func:`get_inputs` itself.

    Dimensions are taken from *config* through ``pick``. When the
    configuration is None, typical dimensions are selected instead.

    :param config: model configuration or None
    :return: tuple ``(kwargs, get_inputs)``
    """
    if config is not None:
        check_hasattr(config, "sample_size", "cross_attention_dim", "in_channels")
    dims: Dict[str, Any] = {"batch_size": 2}
    dims["sequence_length"] = pick(config, "in_channels", 4)
    dims["cache_length"] = 77
    dims["in_channels"] = pick(config, "in_channels", 4)
    dims["sample_size"] = pick(config, "sample_size", 32)
    dims["cross_attention_dim"] = pick(config, "cross_attention_dim", 64)
    return dims, get_inputs
@@ -337,7 +337,7 @@ def _make_exporter_onnx(
337
337
  from experimental_experiment.torch_interpreter import to_onnx, ExportOptions
338
338
 
339
339
  opts = {}
340
- opts["strict"] = "-nostrict" not in exporter
340
+ opts["strict"] = "-strict" in exporter
341
341
  opts["fallback"] = "-fallback" in exporter
342
342
  opts["tracing"] = "-tracing" in exporter
343
343
  opts["jit"] = "-jit" in exporter
@@ -520,6 +520,8 @@ def run_exporter(
520
520
  return res
521
521
 
522
522
  onx, builder = res
523
+ base["onx"] = onx
524
+ base["builder"] = builder
523
525
  if verbose >= 9:
524
526
  print("[run_exporter] onnx model")
525
527
  print(
@@ -134,11 +134,17 @@ def unpatch_module_or_classes(mod, info: Dict[type, Dict[type, Callable]], verbo
134
134
 
135
135
  @contextlib.contextmanager
136
136
  def register_additional_serialization_functions(
137
- patch_transformers: bool = False, verbose: int = 0
137
+ patch_transformers: bool = False, patch_diffusers: bool = False, verbose: int = 0
138
138
  ) -> Callable:
139
139
  """The necessary modifications to run the fx Graph."""
140
- fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
141
- done = register_cache_serialization(verbose=verbose)
140
+ fct_callable = (
141
+ replacement_before_exporting
142
+ if patch_transformers or patch_diffusers
143
+ else (lambda x: x)
144
+ )
145
+ done = register_cache_serialization(
146
+ patch_transformers=patch_transformers, patch_diffusers=patch_diffusers, verbose=verbose
147
+ )
142
148
  try:
143
149
  yield fct_callable
144
150
  finally:
@@ -150,6 +156,7 @@ def torch_export_patches(
150
156
  patch_sympy: bool = True,
151
157
  patch_torch: bool = True,
152
158
  patch_transformers: bool = False,
159
+ patch_diffusers: bool = False,
153
160
  catch_constraints: bool = True,
154
161
  stop_if_static: int = 0,
155
162
  verbose: int = 0,
@@ -165,6 +172,7 @@ def torch_export_patches(
165
172
  :param patch_sympy: fix missing method ``name`` for IntegerConstant
166
173
  :param patch_torch: patches :epkg:`torch` with supported implementation
167
174
  :param patch_transformers: patches :epkg:`transformers` with supported implementation
175
+ :param patch_diffusers: patches :epkg:`diffusers` with supported implementation
168
176
  :param catch_constraints: catch constraints related to dynamic shapes,
169
177
  as a result, some dynamic dimension may turn into static ones,
170
178
  the environment variable ``SKIP_SOLVE_CONSTRAINTS=0``
@@ -174,8 +182,8 @@ def torch_export_patches(
174
182
  and show a stack trace indicating the exact location of the issue,
175
183
  ``if stop_if_static > 1``, more methods are replace to catch more
176
184
  issues
177
- :param patch: if False, disable all patches except the registration of
178
- serialization function
185
+ :param patch: if False, disable all patches but keeps the registration of
186
+ serialization functions if other patch functions are enabled
179
187
  :param custom_patches: to apply custom patches,
180
188
  every patched class must define static attributes
181
189
  ``_PATCHES_``, ``_PATCHED_CLASS_``
@@ -249,6 +257,7 @@ def torch_export_patches(
249
257
  patch_sympy=patch_sympy,
250
258
  patch_torch=patch_torch,
251
259
  patch_transformers=patch_transformers,
260
+ patch_diffusers=patch_diffusers,
252
261
  catch_constraints=catch_constraints,
253
262
  stop_if_static=stop_if_static,
254
263
  verbose=verbose,
@@ -261,7 +270,11 @@ def torch_export_patches(
261
270
  pass
262
271
  elif not patch:
263
272
  fct_callable = lambda x: x # noqa: E731
264
- done = register_cache_serialization(verbose=verbose)
273
+ done = register_cache_serialization(
274
+ patch_transformers=patch_transformers,
275
+ patch_diffusers=patch_diffusers,
276
+ verbose=verbose,
277
+ )
265
278
  try:
266
279
  yield fct_callable
267
280
  finally:
@@ -281,7 +294,11 @@ def torch_export_patches(
281
294
  # caches
282
295
  ########
283
296
 
284
- cache_done = register_cache_serialization(verbose=verbose)
297
+ cache_done = register_cache_serialization(
298
+ patch_transformers=patch_transformers,
299
+ patch_diffusers=patch_diffusers,
300
+ verbose=verbose,
301
+ )
285
302
 
286
303
  #############
287
304
  # patch sympy