onnx-diagnostic 0.7.6__py3-none-any.whl → 0.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +56 -3
  3. onnx_diagnostic/export/dynamic_shapes.py +24 -10
  4. onnx_diagnostic/export/shape_helper.py +6 -2
  5. onnx_diagnostic/helpers/cache_helper.py +79 -6
  6. onnx_diagnostic/helpers/config_helper.py +10 -0
  7. onnx_diagnostic/helpers/helper.py +6 -1
  8. onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
  9. onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
  10. onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
  11. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  12. onnx_diagnostic/tasks/feature_extraction.py +7 -3
  13. onnx_diagnostic/tasks/fill_mask.py +6 -2
  14. onnx_diagnostic/tasks/image_classification.py +6 -2
  15. onnx_diagnostic/tasks/image_text_to_text.py +33 -10
  16. onnx_diagnostic/tasks/mask_generation.py +6 -2
  17. onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
  18. onnx_diagnostic/tasks/object_detection.py +6 -2
  19. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  20. onnx_diagnostic/tasks/summarization.py +7 -2
  21. onnx_diagnostic/tasks/text2text_generation.py +7 -2
  22. onnx_diagnostic/tasks/text_classification.py +6 -2
  23. onnx_diagnostic/tasks/text_generation.py +8 -14
  24. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
  25. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  26. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
  27. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +119 -0
  28. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +3 -1
  29. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
  30. onnx_diagnostic/torch_models/validate.py +1 -0
  31. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.7.dist-info}/METADATA +2 -2
  32. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.7.dist-info}/RECORD +35 -35
  33. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.7.dist-info}/WHEEL +0 -0
  34. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.7.dist-info}/licenses/LICENSE.txt +0 -0
  35. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
 Functions, classes to dig into a model when this one is right, slow, wrong...
 """
 
-__version__ = "0.7.6"
+__version__ = "0.7.7"
 __author__ = "Xavier Dupré"
onnx_diagnostic/_command_lines_parser.py
@@ -306,7 +306,7 @@ class _ParseDict(argparse.Action):
                 value = split_items[1]
 
                 if value in ("True", "true", "False", "false"):
-                    d[key] = bool(value)
+                    d[key] = value in ("True", "true")
                     continue
                 try:
                     d[key] = int(value)
@@ -323,6 +323,54 @@ class _ParseDict(argparse.Action):
         setattr(namespace, self.dest, d)
 
 
+class _BoolOrParseDictPatch(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+
+        if not values:
+            return
+        if len(values) == 1 and values[0] in (
+            "True",
+            "False",
+            "true",
+            "false",
+            "0",
+            "1",
+            0,
+            1,
+        ):
+            setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
+            return
+        d = getattr(namespace, self.dest) or {}
+        if not isinstance(d, dict):
+            d = {
+                "patch_sympy": d,
+                "patch_torch": d,
+                "patch_transformers": d,
+                "patch_diffusers": d,
+            }
+        for item in values:
+            split_items = item.split("=", 1)
+            key = split_items[0].strip()  # we remove blanks around keys, as is logical
+            value = split_items[1]
+
+            if value in ("True", "true", "False", "false"):
+                d[key] = value in ("True", "true")
+                continue
+            try:
+                d[key] = int(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            try:
+                d[key] = float(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            d[key] = _parse_json(value)
+
+        setattr(namespace, self.dest, d)
+
+
 def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
         prog="validate",
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
     parser.add_argument(
         "--patch",
         default=True,
-        action=BooleanOptionalAction,
-        help="Applies patches before exporting.",
+        action=_BoolOrParseDictPatch,
+        nargs="*",
+        help="Applies patches before exporting, it can be a boolean "
+        "to enable to disable the patches or be more finetuned. It is possible to "
+        "disable patch for torch by adding "
+        '--patch "patch_sympy=False" --patch "patch_torch=False", '
+        "default is True.",
     )
     parser.add_argument(
         "--rewrite",
onnx_diagnostic/export/dynamic_shapes.py
@@ -887,19 +887,30 @@ class ModelInputs:
 
         # In case DynamicCache is not registered.
         if obj.__class__.__name__ == "DynamicCache":
-            kc = set(len(o.key_cache) for o in objs)
-            assert (
-                len(kc) == 1
-            ), f"All attribute 'key_cache' should have the same length but found {kc}"
-            vc = set(len(o.value_cache) for o in objs)
-            assert (
-                len(vc) == 1
-            ), f"All attribute 'value_cache' should have the same length but found {vc}"
+            if hasattr(obj, "layers"):
+                kc = set(len(o.layers) for o in objs)
+                assert (
+                    len(kc) == 1
+                ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                vc = kc.copy()
+            else:
+                kc = set(len(o.key_cache) for o in objs)
+                assert (
+                    len(kc) == 1
+                ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                vc = set(len(o.value_cache) for o in objs)
+                assert (
+                    len(vc) == 1
+                ), f"All attribute 'value_cache' should have the same length but found {vc}"
+
             key_cache = []
             for i in range(kc.pop()):
                 key_cache.append(
                     self.guess_dynamic_dimensions(
-                        *[o.key_cache[i] for o in objs],
+                        *[
+                            o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
+                            for o in objs
+                        ],
                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
                     )
                 )
@@ -907,7 +918,10 @@ class ModelInputs:
             for i in range(vc.pop()):
                 value_cache.append(
                     self.guess_dynamic_dimensions(
-                        *[o.value_cache[i] for o in objs],
+                        *[
+                            o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
+                            for o in objs
+                        ],
                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                     )
                 )
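
Note (reviewer sketch): recent transformers releases replace the parallel key_cache/value_cache lists of DynamicCache with a layers list whose elements carry keys and values tensors; the hunk above reads whichever layout exists. The access pattern, as an assumption-laden sketch:

    def layer_key_value(cache, i):
        # Newer transformers: a list of layer objects with .keys/.values.
        if hasattr(cache, "layers"):
            return cache.layers[i].keys, cache.layers[i].values
        # Older transformers: two parallel lists of tensors.
        return cache.key_cache[i], cache.value_cache[i]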
onnx_diagnostic/export/shape_helper.py
@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
     All dimensions are considered as dynamic.
     ``dim_prefix`` can be a string (the function uses it as a prefix),
     or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+    Depending on the version of transformers, serializations function
+    of DynamicCache class is automatically serialized or not (>= 4.51, < 4.55).
 
     .. runpython::
         :showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
         import torch
         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
         from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+        from onnx_diagnostic.torch_export_patches import torch_export_patches
 
         bsize, nheads, slen, dim = 2, 1, 30, 96
         inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
             position_ids=torch.arange(3, dtype=torch.int64),
             past_key_values=make_dynamic_cache(
                 [(torch.randn(bsize, nheads, slen, dim),
-                   torch.randn(bsize, nheads, slen, dim))]
+                  torch.randn(bsize, nheads, slen, dim))]
             ),
         )
-        ds = all_dynamic_shape_from_inputs(inputs)
+        with torch_export_patches(patch_transformers=True):
+            ds = all_dynamic_shape_from_inputs(inputs)
         pprint.pprint(ds)
 
     For this function to work, patches must be enabled if :epkg:`transformers`
onnx_diagnostic/helpers/cache_helper.py
@@ -41,9 +41,14 @@ class CacheKeyValue:
                 f"or value_cache={string_type(self.value_cache)}, "
                 f"cache.layers={string_type(cache.layers)}"
             )
-        elif cache is not None:
+        elif cache is not None and hasattr(cache, "key_cache"):
             self.key_cache = cache.key_cache
             self.value_cache = cache.value_cache
+        elif cache is None:
+            self.key_cache = None
+            self.value_cache = None
+        else:
+            raise NotImplementedError(f"type(cache)={type(cache)}")
 
     def make_dynamic_cache(self):
         """Do the reverse operation."""
@@ -126,6 +131,8 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
     )
     values, spec = torch.utils._pytree.tree_flatten(cache)
     cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+    if hasattr(cache2, "layers") and hasattr(cache, "layers"):
+        return len(cache2.layers) == len(cache.layers)
     return len(cache2.key_cache) == len(cache.value_cache)
 
 
@@ -176,7 +183,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
             f"{len(key_value_pairs)} expected."
         )
-        return cache
+        return finalize_cache(cache)
 
 else:
 
@@ -260,6 +267,9 @@ def make_static_cache(
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)
 
+        def get_text_config(self):
+            return self
+
     assert max_cache_len is not None, (
         f"max_cache_len={max_cache_len} cannot be setup "
         f"automatically yet from shape {key_value_pairs[0][0].shape}"
@@ -280,6 +290,33 @@ def make_static_cache(
         max_cache_len=max_cache_len,
     )
     ca = CacheKeyValue(cache)
+    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+        # transformers>= 4.55.2, layers are empty
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i)
+        return cache
+
+    torch._check(
+        not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(cache.layers)={len(cache.layers)}"
+        ),
+    )
+    torch._check(
+        len(key_value_pairs) == len(ca.key_cache),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(ca.key_cache)={len(ca.key_cache)}"
+        ),
+    )
+    torch._check(
+        len(key_value_pairs) == len(ca.value_cache),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(ca.value_cache)={len(ca.value_cache)}"
+        ),
+    )
     for i in range(len(key_value_pairs)):
         assert (
             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
@@ -298,7 +335,7 @@ def make_static_cache(
         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
         f"{len(key_value_pairs)} expected."
     )
-    return cache
+    return finalize_cache(cache)
 
 
 def make_encoder_decoder_cache(
@@ -307,7 +344,10 @@
 ) -> transformers.cache_utils.EncoderDecoderCache:
     """Creates an EncoderDecoderCache."""
     return transformers.cache_utils.EncoderDecoderCache(
-        self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
+        # self_attention_cache=self_attention_cache,
+        # cross_attention_cache=cross_attention_cache
+        self_attention_cache,
+        cross_attention_cache,
    )
 
 
@@ -323,6 +363,9 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype
 
+        def get_text_config(self):
+            return self
+
     cache = MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
@@ -348,7 +391,7 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
             f"got {key_value_pairs[i][1].shape}"
         )
         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
-    return cache
+    return finalize_cache(cache)
 
 
 def make_sliding_window_cache(
  def make_sliding_window_cache(
@@ -363,6 +406,9 @@ def make_sliding_window_cache(
363
406
  self.num_hidden_layers = len(key_value_pairs)
364
407
  self.sliding_window = key_value_pairs[0][0].shape[2]
365
408
 
409
+ def get_text_config(self):
410
+ return self
411
+
366
412
  cache = transformers.cache_utils.SlidingWindowCache(
367
413
  config=_config(),
368
414
  max_batch_size=key_value_pairs[0][0].shape[0],
@@ -371,6 +417,13 @@ def make_sliding_window_cache(
371
417
  dtype=key_value_pairs[0][0].dtype,
372
418
  )
373
419
  ca = CacheKeyValue(cache)
420
+ if hasattr(cache, "layers") and len(ca.key_cache) == 0:
421
+ # transformers>= 4.55.2, layers are empty
422
+ cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
423
+ for i, (key, value) in enumerate(key_value_pairs):
424
+ cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
425
+ return cache
426
+
374
427
  for i in range(len(key_value_pairs)):
375
428
  assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
376
429
  f"Shape mismatch, expected {cache.key_cache[i].shape}, "
@@ -393,7 +446,7 @@ def make_sliding_window_cache(
393
446
  f"Unexpected number of layers in the cache ({len(cache.layers)}), "
394
447
  f"{len(key_value_pairs)} expected."
395
448
  )
396
- return cache
449
+ return finalize_cache(cache)
397
450
 
398
451
 
399
452
  def make_hybrid_cache(
@@ -521,6 +574,9 @@ def make_hybrid_cache(
         sliding_window = _sliding_window
         num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3
 
+        def get_text_config(self):
+            return self
+
     if layer_types:
         _config.layer_types = layer_types  # type: ignore[attr-defined]
 
@@ -549,4 +605,21 @@ def make_hybrid_cache(
         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
         f"{len(key_value_pairs)} expected."
     )
+    return finalize_cache(cache)
+
+
+def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+    """
+    Ensures the created cache is consistent.
+    Returns the cache modified inplace.
+    """
+    if (
+        hasattr(cache, "layer_class_to_replicate")
+        and hasattr(cache, "layers")
+        and cache.layers
+        and not cache.layer_class_to_replicate
+    ):
+        # This is used to expand the cache when it does not contains enough layers.
+        # This is needed since transformers>4.55.3
+        cache.layer_class_to_replicate = cache.layers[0].__class__
     return cache
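
Note (reviewer sketch): since recent transformers versions (the comment above says > 4.55.3), a layered cache grows during decoding by instantiating layer_class_to_replicate; finalize_cache backfills that attribute from the first existing layer so caches built with a fixed layer count can still be extended. Illustrative check, under the same version assumption:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(2, 1, 30, 96), torch.randn(2, 1, 30, 96))])
    if hasattr(cache, "layers") and hasattr(cache, "layer_class_to_replicate"):
        print(cache.layer_class_to_replicate)  # class used to append further layers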
onnx_diagnostic/helpers/config_helper.py
@@ -1,6 +1,7 @@
 import functools
 import importlib
 import inspect
+import os
 import re
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import transformers
@@ -110,3 +111,12 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
         )
     cls_name = unique.pop()
     return getattr(transformers, cls_name)
+
+
+def default_num_hidden_layers():
+    """
+    Returns the default number of layers.
+    It is lower when the unit tests are running
+    when ``UNITTEST_GOING=1``.
+    """
+    return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
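
Note (reviewer sketch): default_num_hidden_layers is what the nhl alias in the task modules below resolves to; the environment variable shrinks the reduced test models further during unit tests:

    import os
    from onnx_diagnostic.helpers.config_helper import default_num_hidden_layers

    os.environ["UNITTEST_GOING"] = "1"
    assert default_num_hidden_layers() == 2
    del os.environ["UNITTEST_GOING"]
    assert default_num_hidden_layers() == 4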
onnx_diagnostic/helpers/helper.py
@@ -36,11 +36,12 @@ def size_type(dtype: Any) -> int:
         TensorProto.FLOAT8E4M3FNUZ,
         TensorProto.FLOAT8E5M2,
         TensorProto.FLOAT8E5M2FNUZ,
+        getattr(TensorProto, "FLOAT8E8M0", None),
     }:
         return 1
     if dtype in {TensorProto.COMPLEX128}:
         return 16
-    from .helpers.onnx_helper import onnx_dtype_name
+    from .onnx_helper import onnx_dtype_name
 
     raise AssertionError(
         f"Unable to return the element size for type {onnx_dtype_name(dtype)}"
@@ -1478,8 +1479,12 @@ def max_diff(
     # backup function in case pytorch does not know how to serialize.
     if expected.__class__.__name__ == "DynamicCache":
         if got.__class__.__name__ == "DynamicCache":
+            from .cache_helper import CacheKeyValue
+
             if verbose >= 6:
                 print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
+            expected = CacheKeyValue(expected)
+            got = CacheKeyValue(got)
             return max_diff(
                 [expected.key_cache, expected.value_cache],
                 [got.key_cache, got.value_cache],
onnx_diagnostic/reference/ops/op_cast_like.py
@@ -11,22 +11,26 @@ try:
         float8e5m2fnuz,
     )
 except ImportError:
+    bfloat16 = None
     from onnx.reference.ops.op_cast import cast_to
 from ...helpers.onnx_helper import np_dtype_to_tensor_dtype
 
 
 def _cast_like(x, y, saturate):
-    if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
-        # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
-        to = TensorProto.BFLOAT16
-    elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
-        to = TensorProto.FLOAT8E4M3FN
-    elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
-        to = TensorProto.FLOAT8E4M3FNUZ
-    elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
-        to = TensorProto.FLOAT8E5M2
-    elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
-        to = TensorProto.FLOAT8E5M2FNUZ
+    if bfloat16 is not None:
+        if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
+            # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
+            to = TensorProto.BFLOAT16
+        elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
+            to = TensorProto.FLOAT8E4M3FN
+        elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
+            to = TensorProto.FLOAT8E4M3FNUZ
+        elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
+            to = TensorProto.FLOAT8E5M2
+        elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
+            to = TensorProto.FLOAT8E5M2FNUZ
+        else:
+            to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
     else:
         to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
     return (cast_to(x, to, saturate),)
onnx_diagnostic/reference/torch_ops/__init__.py
@@ -45,6 +45,7 @@ from .unary_ops import (
     Erf_9,
     Exp_1,
     Identity_1,
+    IsNaN_9,
     Log_1,
     Neg_1,
     Not_1,
onnx_diagnostic/reference/torch_ops/unary_ops.py
@@ -37,6 +37,13 @@ class Identity_1(OpRunKernel):
         return OpRunTensor(x.tensor)
 
 
+class IsNaN_9(OpRunKernel):
+    """IsNaN"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.isnan())
+
+
 class Log_1(OpRunKernel):
     """Log"""
 
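
Note (reviewer sketch): the new kernel maps ONNX IsNaN (opset 9) directly onto torch.Tensor.isnan; on a raw tensor the wrapped operation behaves as:

    import torch

    x = torch.tensor([0.0, float("nan"), 1.0])
    print(x.isnan())  # tensor([False,  True, False])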
onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -2,7 +2,11 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 import transformers
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "automatic-speech-recognition"
 
@@ -15,7 +19,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if hasattr(config, "decoder_layers"):
         config.decoder_layers = min(config.decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     update_config(config, kwargs)
     return kwargs
 
onnx_diagnostic/tasks/feature_extraction.py
@@ -1,15 +1,20 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
 
+
 __TASK__ = "feature-extraction"
 
 
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_hidden_layers")
-    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
     update_config(config, kwargs)
     return kwargs
 
@@ -160,5 +165,4 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         if hasattr(config, att):
             kwargs[att] = getattr(config, att)
     kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
-    print(kwargs)
     return kwargs, get_inputs
onnx_diagnostic/tasks/fill_mask.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "fill-mask"
 
@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers, 2),
+        num_hidden_layers=min(config.num_hidden_layers, nhl()),
         num_attention_heads=min(config.num_attention_heads, 4),
     )
     update_config(config, kwargs)
onnx_diagnostic/tasks/image_classification.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "image-classification"
 
@@ -17,7 +21,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
     kwargs = dict(
         num_hidden_layers=(
-            min(config.num_hidden_layers, 2)
+            min(config.num_hidden_layers, nhl())
             if hasattr(config, "num_hidden_layers")
             else len(config.hidden_sizes)
         )
onnx_diagnostic/tasks/image_text_to_text.py
@@ -1,7 +1,12 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
-from ..helpers.config_helper import update_config, check_hasattr, _pick
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "image-text-to-text"
 
@@ -10,7 +15,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     if hasattr(config, "mm_tokens_per_image"):
         config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
     if hasattr(config, "vision_config"):
@@ -334,7 +339,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", "image_token_index")
+        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(
@@ -348,7 +353,7 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-        check_hasattr(config.vision_config, "image_size", "num_channels")
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         batch_size=2,
         sequence_length=43,
@@ -410,18 +415,36 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             if config is None
            else (config.text_config.hidden_size if text_config else config.hidden_size)
         ),
-        width=224 if config is None else config.vision_config.image_size,
-        height=224 if config is None else config.vision_config.image_size,
-        num_channels=3 if config is None else config.vision_config.num_channels,
+        width=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        height=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        num_channels=(
+            3
+            if config is None
+            else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
+        ),
         pad_token_id=(
             0
-            if config is None or not hasattr(config, "text_config")
+            if config is None
+            or not hasattr(config, "text_config")
+            or not hasattr(config.text_config, "pad_token_id")
             else config.text_config.pad_token_id
         ),
         image_token_index=(
             4
-            if config is None or not hasattr(config, "image_token_index")
-            else config.image_token_index
+            if config is None
+            or (
+                not hasattr(config, "image_token_index")
+                and not hasattr(config, "image_token_id")
+            )
+            else _pick(config, "image_token_index", "image_token_id")
         ),
     )
     return kwargs, get_inputs
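
Note (reviewer sketch): the replaced check_hasattr tuples accept any one of several attribute names, and _pick then returns the first attribute the config actually defines; this is how both the image_token_index/image_token_id and num_channels/in_chans/in_channels spellings are now supported. The fallback logic, re-implemented locally since _pick's exact signature is assumed from its usage:

    class VisionConfig:
        in_chans = 3  # some vision configs use in_chans or in_channels

    def pick(cfg, *names):
        for name in names:
            if hasattr(cfg, name):
                return getattr(cfg, name)
        raise AttributeError(f"none of {names} found")

    assert pick(VisionConfig(), "num_channels", "in_chans", "in_channels") == 3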
onnx_diagnostic/tasks/mask_generation.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import update_config, check_hasattr
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 
 __TASK__ = "mask-generation"
 
@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers, 2)
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     if hasattr(config, "vision_config") and hasattr(config.vision_config, "num_hidden_layers"):
         config.vision_config.num_hidden_layers = min(config.vision_config.num_hidden_layers, 2)
     update_config(config, kwargs)