onnx-diagnostic 0.4.2__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.4.2"
6
+ __version__ = "0.4.3"
7
7
  __author__ = "Xavier Dupré"
@@ -1,4 +1,8 @@
1
1
  from .onnx_export_errors import (
2
- bypass_export_some_errors,
2
+ torch_export_patches,
3
3
  register_additional_serialization_functions,
4
4
  )
5
+
6
+
7
+ # bypass_export_some_errors is the first name given to the patches.
8
+ bypass_export_some_errors = torch_export_patches # type: ignore
@@ -93,7 +93,7 @@ def register_additional_serialization_functions(
93
93
 
94
94
 
95
95
  @contextlib.contextmanager
96
- def bypass_export_some_errors(
96
+ def torch_export_patches(
97
97
  patch_sympy: bool = True,
98
98
  patch_torch: bool = True,
99
99
  patch_transformers: bool = False,
@@ -145,13 +145,13 @@ def bypass_export_some_errors(
145
145
 
146
146
  ::
147
147
 
148
- with bypass_export_some_errors(patch_transformers=True) as modificator:
148
+ with torch_export_patches(patch_transformers=True) as modificator:
149
149
  inputs = modificator(inputs)
150
150
  onx = to_onnx(..., inputs, ...)
151
151
 
152
152
  ::
153
153
 
154
- with bypass_export_some_errors(patch_transformers=True) as modificator:
154
+ with torch_export_patches(patch_transformers=True) as modificator:
155
155
  inputs = modificator(inputs)
156
156
  onx = torch.onnx.export(..., inputs, ...)
157
157
 
@@ -159,7 +159,7 @@ def bypass_export_some_errors(
159
159
 
160
160
  ::
161
161
 
162
- with bypass_export_some_errors(patch_transformers=True) as modificator:
162
+ with torch_export_patches(patch_transformers=True) as modificator:
163
163
  inputs = modificator(inputs)
164
164
  ep = torch.export.export(..., inputs, ...)
165
165
 
@@ -190,7 +190,7 @@ def bypass_export_some_errors(
190
190
 
191
191
  if verbose:
192
192
  print(
193
- "[bypass_export_some_errors] replace torch.jit.isinstance, "
193
+ "[torch_export_patches] replace torch.jit.isinstance, "
194
194
  "torch._dynamo.mark_static_address"
195
195
  )
196
196
 
@@ -210,8 +210,8 @@ def bypass_export_some_errors(
210
210
  f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
211
211
 
212
212
  if verbose:
213
- print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}")
214
- print("[bypass_export_some_errors] patch sympy")
213
+ print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
214
+ print("[torch_export_patches] patch sympy")
215
215
 
216
216
  sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
217
217
 
@@ -228,9 +228,9 @@ def bypass_export_some_errors(
228
228
  )
229
229
 
230
230
  if verbose:
231
- print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}")
232
- print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}")
233
- print("[bypass_export_some_errors] patch pytorch")
231
+ print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
232
+ print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
233
+ print("[torch_export_patches] patch pytorch")
234
234
 
235
235
  # torch.jit.isinstance
236
236
  f_jit_isinstance = torch.jit.isinstance
@@ -252,7 +252,7 @@ def bypass_export_some_errors(
252
252
  # torch._export.non_strict_utils.produce_guards_and_solve_constraints
253
253
  if catch_constraints:
254
254
  if verbose:
255
- print("[bypass_export_some_errors] modifies shape constraints")
255
+ print("[torch_export_patches] modifies shape constraints")
256
256
  f_produce_guards_and_solve_constraints = (
257
257
  torch._export.non_strict_utils.produce_guards_and_solve_constraints
258
258
  )
@@ -277,22 +277,20 @@ def bypass_export_some_errors(
277
277
  ShapeEnv._log_guard_remember = ShapeEnv._log_guard
278
278
 
279
279
  if verbose:
280
- print(
281
- "[bypass_export_some_errors] assert when a dynamic dimension turns static"
282
- )
283
- print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")
280
+ print("[torch_export_patches] assert when a dynamic dimension turns static")
281
+ print("[torch_export_patches] replaces ShapeEnv._set_replacement")
284
282
 
285
283
  f_shape_env__set_replacement = ShapeEnv._set_replacement
286
284
  ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
287
285
 
288
286
  if verbose:
289
- print("[bypass_export_some_errors] replaces ShapeEnv._log_guard")
287
+ print("[torch_export_patches] replaces ShapeEnv._log_guard")
290
288
  f_shape_env__log_guard = ShapeEnv._log_guard
291
289
  ShapeEnv._log_guard = patched_ShapeEnv._log_guard
292
290
 
293
291
  if stop_if_static > 1:
294
292
  if verbose:
295
- print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
293
+ print("[torch_export_patches] replaces ShapeEnv._check_frozen")
296
294
  f_shape_env__check_frozen = ShapeEnv._check_frozen
297
295
  ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
298
296
 
@@ -305,7 +303,7 @@ def bypass_export_some_errors(
305
303
  import transformers
306
304
 
307
305
  print(
308
- f"[bypass_export_some_errors] transformers.__version__="
306
+ f"[torch_export_patches] transformers.__version__="
309
307
  f"{transformers.__version__!r}"
310
308
  )
311
309
  revert_patches_info = patch_module_or_classes(
@@ -314,7 +312,7 @@ def bypass_export_some_errors(
314
312
 
315
313
  if custom_patches:
316
314
  if verbose:
317
- print("[bypass_export_some_errors] applies custom patches")
315
+ print("[torch_export_patches] applies custom patches")
318
316
  revert_custom_patches_info = patch_module_or_classes(
319
317
  custom_patches, verbose=verbose
320
318
  )
@@ -326,7 +324,7 @@ def bypass_export_some_errors(
326
324
  fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
327
325
 
328
326
  if verbose:
329
- print("[bypass_export_some_errors] done patching")
327
+ print("[torch_export_patches] done patching")
330
328
 
331
329
  try:
332
330
  yield fct_callable
@@ -336,7 +334,7 @@ def bypass_export_some_errors(
336
334
  #######
337
335
 
338
336
  if verbose:
339
- print("[bypass_export_some_errors] remove patches")
337
+ print("[torch_export_patches] remove patches")
340
338
 
341
339
  if patch_sympy:
342
340
  # tracked by https://github.com/pytorch/pytorch/issues/143494
@@ -346,7 +344,7 @@ def bypass_export_some_errors(
346
344
  delattr(sympy.core.numbers.IntegerConstant, "name")
347
345
 
348
346
  if verbose:
349
- print("[bypass_export_some_errors] restored sympy functions")
347
+ print("[torch_export_patches] restored sympy functions")
350
348
 
351
349
  #######
352
350
  # torch
@@ -362,22 +360,22 @@ def bypass_export_some_errors(
362
360
  torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
363
361
 
364
362
  if verbose:
365
- print("[bypass_export_some_errors] restored pytorch functions")
363
+ print("[torch_export_patches] restored pytorch functions")
366
364
 
367
365
  if stop_if_static:
368
366
  if verbose:
369
- print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
367
+ print("[torch_export_patches] restored ShapeEnv._set_replacement")
370
368
 
371
369
  ShapeEnv._set_replacement = f_shape_env__set_replacement
372
370
 
373
371
  if verbose:
374
- print("[bypass_export_some_errors] restored ShapeEnv._log_guard")
372
+ print("[torch_export_patches] restored ShapeEnv._log_guard")
375
373
 
376
374
  ShapeEnv._log_guard = f_shape_env__log_guard
377
375
 
378
376
  if stop_if_static > 1:
379
377
  if verbose:
380
- print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
378
+ print("[torch_export_patches] restored ShapeEnv._check_frozen")
381
379
  ShapeEnv._check_frozen = f_shape_env__check_frozen
382
380
 
383
381
  if catch_constraints:
@@ -389,11 +387,11 @@ def bypass_export_some_errors(
389
387
  f__check_input_constraints_for_graph
390
388
  )
391
389
  if verbose:
392
- print("[bypass_export_some_errors] restored shape constraints")
390
+ print("[torch_export_patches] restored shape constraints")
393
391
 
394
392
  if custom_patches:
395
393
  if verbose:
396
- print("[bypass_export_some_errors] unpatch custom patches")
394
+ print("[torch_export_patches] unpatch custom patches")
397
395
  unpatch_module_or_classes(
398
396
  custom_patches, revert_custom_patches_info, verbose=verbose
399
397
  )
@@ -404,7 +402,7 @@ def bypass_export_some_errors(
404
402
 
405
403
  if patch_transformers:
406
404
  if verbose:
407
- print("[bypass_export_some_errors] unpatch transformers")
405
+ print("[torch_export_patches] unpatch transformers")
408
406
  unpatch_module_or_classes(
409
407
  patch_transformers_list, revert_patches_info, verbose=verbose
410
408
  )
@@ -38,12 +38,14 @@ __data_arch__ = textwrap.dedent(
38
38
  DeiTModel,image-feature-extraction
39
39
  DetrModel,image-feature-extraction
40
40
  Dinov2Model,image-feature-extraction
41
+ DistilBertForSequenceClassification,text-classification
41
42
  DistilBertModel,feature-extraction
42
43
  DonutSwinModel,feature-extraction
43
44
  ElectraModel,feature-extraction
44
45
  EsmModel,feature-extraction
45
46
  FalconMambaForCausalLM,text-generation
46
47
  GLPNModel,image-feature-extraction
48
+ GPT2LMHeadModel,text-generation
47
49
  GPTBigCodeModel,feature-extraction
48
50
  GPTJModel,feature-extraction
49
51
  GPTNeoModel,feature-extraction
@@ -64,6 +66,7 @@ __data_arch__ = textwrap.dedent(
64
66
  LongT5Model,feature-extraction
65
67
  LongformerModel,feature-extraction
66
68
  MCTCTModel,feature-extraction
69
+ MPNetForMaskedLM,sentence-similarity
67
70
  MPNetModel,feature-extraction
68
71
  MT5Model,feature-extraction
69
72
  MarianMTModel,text2text-generation
@@ -96,11 +99,13 @@ __data_arch__ = textwrap.dedent(
96
99
  PoolFormerModel,image-feature-extraction
97
100
  PvtForImageClassification,image-classification
98
101
  Qwen2ForCausalLM,text-generation
102
+ Qwen2_5_VLForConditionalGeneration,image-text-to-text
99
103
  RTDetrForObjectDetection,object-detection
100
104
  RegNetModel,image-feature-extraction
101
105
  RemBertModel,feature-extraction
102
106
  ResNetForImageClassification,image-classification
103
107
  RoFormerModel,feature-extraction
108
+ RobertaForMaskedLM,sentence-similarity
104
109
  RobertaModel,feature-extraction
105
110
  RtDetrV2ForObjectDetection,object-detection
106
111
  SEWDModel,feature-extraction
@@ -118,6 +123,7 @@ __data_arch__ = textwrap.dedent(
118
123
  Swinv2Model,image-feature-extraction
119
124
  T5ForConditionalGeneration,text2text-generation
120
125
  TableTransformerModel,image-feature-extraction
126
+ TableTransformerForObjectDetection,object-detection
121
127
  UniSpeechForSequenceClassification,audio-classification
122
128
  ViTForImageClassification,image-classification
123
129
  ViTMAEModel,image-feature-extraction
@@ -130,6 +136,9 @@ __data_arch__ = textwrap.dedent(
130
136
  WhisperForConditionalGeneration,automatic-speech-recognition
131
137
  XLMModel,feature-extraction
132
138
  XLMRobertaForCausalLM,text-generation
139
+ XLMRobertaForMaskedLM,fill-mask
140
+ XLMRobertaModel,sentence-similarity
141
+ Wav2Vec2ForCTC,automatic-speech-recognition
133
142
  YolosForObjectDetection,object-detection
134
143
  YolosModel,image-feature-extraction"""
135
144
  )
@@ -12,7 +12,7 @@ from ..helpers.rt_helper import make_feeds
12
12
  from ..helpers.torch_test_helper import to_any, torch_deepcopy
13
13
  from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
14
14
  from ..tasks import random_input_kwargs
15
- from ..torch_export_patches import bypass_export_some_errors
15
+ from ..torch_export_patches import torch_export_patches
16
16
  from ..torch_export_patches.patch_inputs import use_dyn_not_str
17
17
  from .hghub import get_untrained_model_with_inputs
18
18
 
@@ -242,9 +242,9 @@ def validate_model(
242
242
  depend on the the exporter
243
243
  :param quiet: if quiet, catches exception if any issue
244
244
  :param patch: applies patches (``patch_transformers=True``) before exporting,
245
- see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
245
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
246
246
  :param stop_if_static: stops if a dynamic dimension becomes static,
247
- see :func:`onnx_diagnostic.torch_export_patches.bypass_export_some_errors`
247
+ see :func:`onnx_diagnostic.torch_export_patches.torch_export_patches`
248
248
  :param dump_folder: dumps everything in a subfolder of this one
249
249
  :param drop_inputs: drops this list of inputs (given their names)
250
250
  :param ortfusiontype: runs ort fusion, the parameters defines the fusion type,
@@ -417,7 +417,7 @@ def validate_model(
417
417
  f"[validate_model] applies patches before exporting "
418
418
  f"stop_if_static={stop_if_static}"
419
419
  )
420
- with bypass_export_some_errors( # type: ignore
420
+ with torch_export_patches( # type: ignore
421
421
  patch_transformers=True,
422
422
  stop_if_static=stop_if_static,
423
423
  verbose=max(0, verbose - 1),
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.4.2
3
+ Version: 0.4.3
4
4
  Summary: Investigate ONNX models
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré
@@ -67,13 +67,13 @@ it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using
67
67
 
68
68
  .. code-block:: python
69
69
 
70
- with bypass_export_some_errors(patch_transformers=True) as f:
70
+ with torch_export_patches(patch_transformers=True) as f:
71
71
  ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
72
72
  # ...
73
73
 
74
74
  It also implements tools to investigate, validate exported models (ExportedProgramm, ONNXProgram, ...).
75
75
  See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
76
- `bypass_export_some_errors <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.bypass_export_some_errors>`_.
76
+ `torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.
77
77
 
78
78
  Getting started
79
79
  +++++++++++++++
@@ -1,4 +1,4 @@
1
- onnx_diagnostic/__init__.py,sha256=wVSctxhjG5jNBmX9oZ_oUVWt2QU4P1s8bsgeKXDj0YI,164
1
+ onnx_diagnostic/__init__.py,sha256=opPQ2jwxhWOe2Y2oDiKVTmNL4w0H1Gl0G921alsi0NM,164
2
2
  onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
3
3
  onnx_diagnostic/_command_lines_parser.py,sha256=kOECT1BccZc38vmVc3jF3xvXGDpcocvLuUGoPkzte08,14753
4
4
  onnx_diagnostic/doc.py,sha256=MTuT7Kxyvn7KEy84liQeFeqhugJrUQhjjpx21F72Uxw,926
@@ -61,8 +61,8 @@ onnx_diagnostic/tasks/text2text_generation.py,sha256=jaJLQqKk38mAop7O3zCFQjUvmYm
61
61
  onnx_diagnostic/tasks/text_classification.py,sha256=OgC_G9iumzTjTNUEvMoFFNTHCD8_BkdvdYC4jUsfpHM,2412
62
62
  onnx_diagnostic/tasks/text_generation.py,sha256=fTasu-igW-f9dyhYN4qXYkTWZU1ppgK37cmpvXV3i08,10215
63
63
  onnx_diagnostic/tasks/zero_shot_image_classification.py,sha256=N3cEG1Lq95wS1N_CWUUUCU5j-4Tp5eR8Ce68U8THYAk,4380
64
- onnx_diagnostic/torch_export_patches/__init__.py,sha256=RZzVGgouNNXaPirQJYQThiq5wrliwH4unVszeU18oJw,116
65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=xeMHu9VJcJ3suCg-OJiLtvNcR2Q8hYm8Y5aZdmcperk,16256
64
+ onnx_diagnostic/torch_export_patches/__init__.py,sha256=uRqg0-KSs_DhDnzrVp-TG2vfcDBO4HlsNkEg7RomQL0,246
65
+ onnx_diagnostic/torch_export_patches/onnx_export_errors.py,sha256=9WQUBAp5okQL9aJJKyp23ZumBnKt-qahcr94a9MYWxA,16083
66
66
  onnx_diagnostic/torch_export_patches/onnx_export_serialization.py,sha256=1s1LqgqOL_hV6yqT7sgxzTKSDAL267CcZgNq8K4oTZM,14898
67
67
  onnx_diagnostic/torch_export_patches/patch_inputs.py,sha256=FQrMjwvEgPqvYY7ptfULzfexW5yJHo6Pzq_p1HDkNrY,7680
68
68
  onnx_diagnostic/torch_export_patches/patches/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -70,10 +70,10 @@ onnx_diagnostic/torch_export_patches/patches/patch_torch.py,sha256=TKLxrIJUrQsy0
70
70
  onnx_diagnostic/torch_export_patches/patches/patch_transformers.py,sha256=exiIq8zNZsY6QTzZVDMgU2ywGzs6-54Ic4vzTQ-26YQ,21863
71
71
  onnx_diagnostic/torch_models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
72
72
  onnx_diagnostic/torch_models/llms.py,sha256=soyg4yC87ptGoeulJhKqw5opGmuLvH1pn_ZDXZ4Jr8E,90
73
- onnx_diagnostic/torch_models/test_helper.py,sha256=e-JuKDGJnbLA6FmwBLfeAPV122ZulEdEChbvKJWc9R0,46624
73
+ onnx_diagnostic/torch_models/test_helper.py,sha256=ad4C1vpnYkLMTrPMBdyk0OuUIUIAa7wP-zgCzPH88zc,46604
74
74
  onnx_diagnostic/torch_models/hghub/__init__.py,sha256=vi1Q7YHdddj1soiBN42MSvJdFqe2_KUoWafHISjwOu8,58
75
75
  onnx_diagnostic/torch_models/hghub/hub_api.py,sha256=EjwsmdHhf9ub1K5UCQPxsKiTMZy1dsdcRvNmxoZrc98,8621
76
- onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=6H3ui0R6kkvRcqL7KBCXu9XuvebodViaXg74bPcDcko,7445
76
+ onnx_diagnostic/torch_models/hghub/hub_data.py,sha256=K9fu3NA530QEqYJcFTwACfsoRAIUrJoJLNtxDrUzM3c,7863
77
77
  onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py,sha256=IYinZGm6-Ob7fQHg7rE1OE0t5muCsEq5TpQiMgEsrgs,245009
78
78
  onnx_diagnostic/torch_models/hghub/model_inputs.py,sha256=B5c_-T_Ub9Mxs_DxpP4_yb4im-85ftVvAcUBgsISp1o,5743
79
79
  onnx_diagnostic/torch_models/untrained/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -81,8 +81,8 @@ onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiF
81
81
  onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=7N3fGvT_4Mn4NbIo0Qk57c6DMc3OXGWyvj_P41rjwSY,3513
82
82
  onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
83
  onnx_diagnostic/torch_onnx/sbs.py,sha256=HEGDHhV9pfXxpBQrpOWPNWGMsNfOebWewyAazi9poV8,16872
84
- onnx_diagnostic-0.4.2.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
85
- onnx_diagnostic-0.4.2.dist-info/METADATA,sha256=xkEoLhmlKpx91wddmUKXOpmDwdsEuW2y_c7bbk7cAVw,5511
86
- onnx_diagnostic-0.4.2.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
87
- onnx_diagnostic-0.4.2.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
88
- onnx_diagnostic-0.4.2.dist-info/RECORD,,
84
+ onnx_diagnostic-0.4.3.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
85
+ onnx_diagnostic-0.4.3.dist-info/METADATA,sha256=AI3BrV-xYBj8qVBlnpPOMsKnIAtERv1fsjx7sA6hD-A,5496
86
+ onnx_diagnostic-0.4.3.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
87
+ onnx_diagnostic-0.4.3.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
88
+ onnx_diagnostic-0.4.3.dist-info/RECORD,,