onnx-diagnostic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry; it is provided for informational purposes only.
@@ -3,5 +3,5 @@ Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """

-__version__ = "0.4.2"
+__version__ = "0.4.4"
 __author__ = "Xavier Dupré"
@@ -336,6 +336,10 @@ def get_parser_validate() -> ArgumentParser:
         help="drops the following inputs names, it should be a list "
         "with comma separated values",
     )
+    parser.add_argument(
+        "--subfolder",
+        help="subfolder where to find the model and the configuration",
+    )
     parser.add_argument(
         "--ortfusiontype",
         required=False,
@@ -413,6 +417,7 @@ def _cmd_validate(argv: List[Any]):
         ortfusiontype=args.ortfusiontype,
         input_options=args.iop,
         model_options=args.mop,
+        subfolder=args.subfolder,
     )
     print("")
     print("-- summary --")
@@ -6,6 +6,7 @@ from . import (
     image_classification,
     image_text_to_text,
     mixture_of_expert,
+    object_detection,
     sentence_similarity,
     text_classification,
     text_generation,
@@ -20,6 +21,7 @@ __TASKS__ = [
     image_classification,
     image_text_to_text,
     mixture_of_expert,
+    object_detection,
     sentence_similarity,
     text_classification,
     text_generation,
@@ -7,6 +7,13 @@ __TASK__ = "image-classification"

 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
+    if (
+        hasattr(config, "model_type")
+        and config.model_type == "timm_wrapper"
+        and not hasattr(config, "num_hidden_layers")
+    ):
+        # We cannot reduce.
+        return {}
     check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
     kwargs = dict(
         num_hidden_layers=(
@@ -82,6 +89,20 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
     If the configuration is None, the function selects typical dimensions.
     """
     if config is not None:
+        if (
+            hasattr(config, "model_type")
+            and config.model_type == "timm_wrapper"
+            and not hasattr(config, "num_hidden_layers")
+        ):
+            input_size = config.pretrained_cfg["input_size"]
+            kwargs = dict(
+                batch_size=2,
+                input_width=input_size[-2],
+                input_height=input_size[-1],
+                input_channels=input_size[-3],
+            )
+            return kwargs, get_inputs
+
         check_hasattr(config, ("image_size", "architectures"), "num_channels")
     if config is not None:
         if hasattr(config, "image_size"):
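
The new branch handles timm models wrapped by transformers: such configurations expose the expected input resolution through pretrained_cfg["input_size"] rather than through image_size, with channels read from input_size[-3] and the two spatial dimensions from input_size[-2] and input_size[-1]. An illustrative sketch, assuming random_input_kwargs from this task module is in scope (the toy configuration below is made up for the example):

    from types import SimpleNamespace

    # Toy config imitating a timm_wrapper configuration.
    cfg = SimpleNamespace(
        model_type="timm_wrapper",
        pretrained_cfg={"input_size": (3, 224, 224)},
    )
    kwargs, fct = random_input_kwargs(cfg)
    assert kwargs == dict(
        batch_size=2, input_width=224, input_height=224, input_channels=3
    )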
@@ -0,0 +1,123 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import update_config, check_hasattr
+
+__TASK__ = "object-detection"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
+    kwargs = dict(
+        num_hidden_layers=(
+            min(config.num_hidden_layers, 2)
+            if hasattr(config, "num_hidden_layers")
+            else len(config.hidden_sizes)
+        )
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    input_width: int,
+    input_height: int,
+    input_channels: int,
+    batch_size: int = 2,
+    dynamic_rope: bool = False,
+    add_second_input: bool = False,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``object-detection``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param batch_size: batch size
+    :param input_channels: input channel
+    :param input_width: input width
+    :param input_height: input height
+    :return: dictionary
+    """
+    assert isinstance(
+        input_width, int
+    ), f"Unexpected type for input_width {type(input_width)}{config}"
+    assert isinstance(
+        input_height, int
+    ), f"Unexpected type for input_height {type(input_height)}{config}"
+
+    shapes = {
+        "pixel_values": {
+            0: torch.export.Dim("batch", min=1, max=1024),
+            2: "width",
+            3: "height",
+        }
+    }
+    inputs = dict(
+        pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
+            -1, 1
+        ),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            input_width=input_width + 1,
+            input_height=input_height + 1,
+            input_channels=input_channels,
+            batch_size=batch_size + 1,
+            dynamic_rope=dynamic_rope,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        if (
+            hasattr(config, "model_type")
+            and config.model_type == "timm_wrapper"
+            and not hasattr(config, "num_hidden_layers")
+        ):
+            input_size = config.pretrained_cfg["input_size"]
+            kwargs = dict(
+                batch_size=2,
+                input_width=input_size[-2],
+                input_height=input_size[-1],
+                input_channels=input_size[-3],
+            )
+            return kwargs, get_inputs
+
+        check_hasattr(config, ("image_size", "architectures"), "num_channels")
+    if config is not None:
+        if hasattr(config, "image_size"):
+            image_size = config.image_size
+        else:
+            assert config.architectures, f"empty architecture in {config}"
+            from ..torch_models.hghub.hub_api import get_architecture_default_values
+
+            default_values = get_architecture_default_values(config.architectures[0])
+            image_size = default_values["image_size"]
+    if config is None or isinstance(image_size, int):
+        kwargs = dict(
+            batch_size=2,
+            input_width=224 if config is None else image_size,
+            input_height=224 if config is None else image_size,
+            input_channels=3 if config is None else config.num_channels,
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            input_width=config.image_size[0],
+            input_height=config.image_size[1],
+            input_channels=config.num_channels,
+        )
+    return kwargs, get_inputs
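
This new module mirrors the image-classification helper but marks the batch and both spatial dimensions of pixel_values as dynamic. A short usage sketch, assuming the module's functions are in scope (the module path inside the package is not shown in the diff, and the toy configuration is made up; real ones come from transformers):

    from types import SimpleNamespace

    # Toy config with the attributes random_input_kwargs checks for.
    cfg = SimpleNamespace(
        image_size=224, num_channels=3, architectures=["SomeDetrModel"]
    )
    kwargs, fct = random_input_kwargs(cfg)
    data = fct(model=None, config=cfg, **kwargs)
    print(data["inputs"]["pixel_values"].shape)  # torch.Size([2, 3, 224, 224])
    print(data["dynamic_shapes"])  # batch, width, height marked dynamic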
@@ -19,12 +19,11 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
         "num_hidden_layers",
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
-        "intermediate_size",
         "hidden_size",
         "vocab_size",
     )
     if config.__class__.__name__ == "FalconMambaConfig":
-        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             num_hidden_layers=min(config.num_hidden_layers, 2),
             intermediate_size=256 if config is None else min(512, config.intermediate_size),
@@ -44,17 +43,18 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
             if hasattr(config, "num_key_value_heads")
             else config.num_attention_heads
         ),
-        intermediate_size=(
-            min(config.intermediate_size, 24576 // 4)
-            if config.intermediate_size % 4 == 0
-            else config.intermediate_size
-        ),
         hidden_size=(
             min(config.hidden_size, 3072 // 4)
             if config.hidden_size % 4 == 0
             else config.hidden_size
         ),
     )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            min(config.intermediate_size, 24576 // 4)
+            if config.intermediate_size % 4 == 0
+            else config.intermediate_size
+        )
     update_config(config, kwargs)
     return kwargs

@@ -228,11 +228,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         "vocab_size",
         ("num_attention_heads", "use_mambapy"),
         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
-        "intermediate_size",
         "hidden_size",
     )
     if config.__class__.__name__ == "FalconMambaConfig":
-        check_hasattr(config, "conv_kernel", "state_size")  # 4 and 8
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
         kwargs = dict(
             batch_size=2,
             sequence_length=30,
@@ -263,7 +262,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             if config is None
             else _pick(config, "num_key_value_heads", "num_attention_heads")
         ),
-        intermediate_size=1024 if config is None else config.intermediate_size,
         hidden_size=512 if config is None else config.hidden_size,
     )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            1024 if config is None else config.intermediate_size
+        )
+
     return kwargs, get_inputs
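
Both hunks apply the same fix: intermediate_size moves out of the mandatory check_hasattr list and is only set when the configuration actually defines it, while FalconMamba configurations, which always carry it, now check it explicitly. The guard in isolation (the helper name below is made up for illustration):

    def _maybe_set_intermediate_size(config, kwargs):
        # Configurations without intermediate_size are now left alone
        # instead of failing the attribute check.
        if config is None or hasattr(config, "intermediate_size"):
            kwargs["intermediate_size"] = (
                1024 if config is None else config.intermediate_size
            )

A Mamba-style configuration that derives its intermediate size from hidden_size, for instance, no longer fails validation.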
@@ -1,4 +1,20 @@
 from .onnx_export_errors import (
-    bypass_export_some_errors,
+    torch_export_patches,
     register_additional_serialization_functions,
 )
+
+
+# bypass_export_some_errors is the first name given to the patches.
+bypass_export_some_errors = torch_export_patches  # type: ignore
+
+
+def register_flattening_functions(verbose: int = 0):
+    """
+    Registers functions to serialize and deserialize caches or other classes
+    implemented in :epkg:`transformers` and used as inputs.
+    This is needed whenever a model must be exported through
+    :func:`torch.export.export`.
+    """
+    from .onnx_export_serialization import _register_cache_serialization
+
+    return _register_cache_serialization(verbose=verbose)
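
The context manager is renamed from bypass_export_some_errors to torch_export_patches, and the old name stays importable as an alias, so existing code keeps working. A sketch of both spellings, assuming the subpackage path onnx_diagnostic.torch_export_patches (the toy model and inputs are made up for the example):

    import torch
    from onnx_diagnostic.torch_export_patches import (
        torch_export_patches,
        bypass_export_some_errors,
    )

    # The old name is kept as an alias of the new one.
    assert bypass_export_some_errors is torch_export_patches

    model = torch.nn.Linear(4, 2)          # toy model for the sketch
    inputs = {"input": torch.randn(3, 4)}  # keyword inputs as a dictionary

    with torch_export_patches(patch_transformers=True) as modificator:
        inputs = modificator(inputs)
        ep = torch.export.export(model, (), kwargs=inputs)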
@@ -93,7 +93,7 @@ def register_additional_serialization_functions(


 @contextlib.contextmanager
-def bypass_export_some_errors(
+def torch_export_patches(
     patch_sympy: bool = True,
     patch_torch: bool = True,
     patch_transformers: bool = False,
@@ -145,13 +145,13 @@ def bypass_export_some_errors(

     ::

-        with bypass_export_some_errors(patch_transformers=True) as modificator:
+        with torch_export_patches(patch_transformers=True) as modificator:
             inputs = modificator(inputs)
             onx = to_onnx(..., inputs, ...)

     ::

-        with bypass_export_some_errors(patch_transformers=True) as modificator:
+        with torch_export_patches(patch_transformers=True) as modificator:
             inputs = modificator(inputs)
             onx = torch.onnx.export(..., inputs, ...)

@@ -159,7 +159,7 @@ def bypass_export_some_errors(

     ::

-        with bypass_export_some_errors(patch_transformers=True) as modificator:
+        with torch_export_patches(patch_transformers=True) as modificator:
             inputs = modificator(inputs)
             ep = torch.export.export(..., inputs, ...)

@@ -190,7 +190,7 @@ def bypass_export_some_errors(

     if verbose:
         print(
-            "[bypass_export_some_errors] replace torch.jit.isinstance, "
+            "[torch_export_patches] replace torch.jit.isinstance, "
             "torch._dynamo.mark_static_address"
         )

@@ -210,8 +210,8 @@ def bypass_export_some_errors(
         f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)

         if verbose:
-            print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}")
-            print("[bypass_export_some_errors] patch sympy")
+            print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
+            print("[torch_export_patches] patch sympy")

         sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"

@@ -228,9 +228,9 @@ def bypass_export_some_errors(
         )

         if verbose:
-            print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}")
-            print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}")
-            print("[bypass_export_some_errors] patch pytorch")
+            print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
+            print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
+            print("[torch_export_patches] patch pytorch")

         # torch.jit.isinstance
         f_jit_isinstance = torch.jit.isinstance
@@ -252,7 +252,7 @@ def bypass_export_some_errors(
         # torch._export.non_strict_utils.produce_guards_and_solve_constraints
         if catch_constraints:
             if verbose:
-                print("[bypass_export_some_errors] modifies shape constraints")
+                print("[torch_export_patches] modifies shape constraints")
             f_produce_guards_and_solve_constraints = (
                 torch._export.non_strict_utils.produce_guards_and_solve_constraints
             )
@@ -277,22 +277,20 @@ def bypass_export_some_errors(
             ShapeEnv._log_guard_remember = ShapeEnv._log_guard

             if verbose:
-                print(
-                    "[bypass_export_some_errors] assert when a dynamic dimension turns static"
-                )
-                print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")
+                print("[torch_export_patches] assert when a dynamic dimension turns static")
+                print("[torch_export_patches] replaces ShapeEnv._set_replacement")

             f_shape_env__set_replacement = ShapeEnv._set_replacement
             ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement

             if verbose:
-                print("[bypass_export_some_errors] replaces ShapeEnv._log_guard")
+                print("[torch_export_patches] replaces ShapeEnv._log_guard")
             f_shape_env__log_guard = ShapeEnv._log_guard
             ShapeEnv._log_guard = patched_ShapeEnv._log_guard

             if stop_if_static > 1:
                 if verbose:
-                    print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
+                    print("[torch_export_patches] replaces ShapeEnv._check_frozen")
                 f_shape_env__check_frozen = ShapeEnv._check_frozen
                 ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen

@@ -305,7 +303,7 @@ def bypass_export_some_errors(
             import transformers

             print(
-                f"[bypass_export_some_errors] transformers.__version__="
+                f"[torch_export_patches] transformers.__version__="
                 f"{transformers.__version__!r}"
             )
         revert_patches_info = patch_module_or_classes(
@@ -314,7 +312,7 @@ def bypass_export_some_errors(
         )
     if custom_patches:
         if verbose:
-            print("[bypass_export_some_errors] applies custom patches")
+            print("[torch_export_patches] applies custom patches")
         revert_custom_patches_info = patch_module_or_classes(
             custom_patches, verbose=verbose
         )
@@ -326,7 +324,7 @@ def bypass_export_some_errors(
     fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)

     if verbose:
-        print("[bypass_export_some_errors] done patching")
+        print("[torch_export_patches] done patching")

     try:
         yield fct_callable
@@ -336,7 +334,7 @@ def bypass_export_some_errors(
         #######

         if verbose:
-            print("[bypass_export_some_errors] remove patches")
+            print("[torch_export_patches] remove patches")

         if patch_sympy:
             # tracked by https://github.com/pytorch/pytorch/issues/143494
@@ -346,7 +344,7 @@ def bypass_export_some_errors(
                 delattr(sympy.core.numbers.IntegerConstant, "name")

             if verbose:
-                print("[bypass_export_some_errors] restored sympy functions")
+                print("[torch_export_patches] restored sympy functions")

             #######
             # torch
@@ -362,22 +360,22 @@ def bypass_export_some_errors(
             torch._meta_registrations._broadcast_shapes = f__broadcast_shapes

             if verbose:
-                print("[bypass_export_some_errors] restored pytorch functions")
+                print("[torch_export_patches] restored pytorch functions")

             if stop_if_static:
                 if verbose:
-                    print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
+                    print("[torch_export_patches] restored ShapeEnv._set_replacement")

                 ShapeEnv._set_replacement = f_shape_env__set_replacement

                 if verbose:
-                    print("[bypass_export_some_errors] restored ShapeEnv._log_guard")
+                    print("[torch_export_patches] restored ShapeEnv._log_guard")

                 ShapeEnv._log_guard = f_shape_env__log_guard

                 if stop_if_static > 1:
                     if verbose:
-                        print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
+                        print("[torch_export_patches] restored ShapeEnv._check_frozen")
                     ShapeEnv._check_frozen = f_shape_env__check_frozen

         if catch_constraints:
@@ -389,11 +387,11 @@ def bypass_export_some_errors(
                 f__check_input_constraints_for_graph
             )
             if verbose:
-                print("[bypass_export_some_errors] restored shape constraints")
+                print("[torch_export_patches] restored shape constraints")

         if custom_patches:
             if verbose:
-                print("[bypass_export_some_errors] unpatch custom patches")
+                print("[torch_export_patches] unpatch custom patches")
             unpatch_module_or_classes(
                 custom_patches, revert_custom_patches_info, verbose=verbose
             )
@@ -404,7 +402,7 @@ def bypass_export_some_errors(

         if patch_transformers:
             if verbose:
-                print("[bypass_export_some_errors] unpatch transformers")
+                print("[torch_export_patches] unpatch transformers")
             unpatch_module_or_classes(
                 patch_transformers_list, revert_patches_info, verbose=verbose
             )