onnx-diagnostic 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +56 -3
  3. onnx_diagnostic/export/dynamic_shapes.py +24 -10
  4. onnx_diagnostic/export/shape_helper.py +6 -2
  5. onnx_diagnostic/helpers/cache_helper.py +83 -7
  6. onnx_diagnostic/helpers/config_helper.py +57 -0
  7. onnx_diagnostic/helpers/helper.py +6 -1
  8. onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
  9. onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
  10. onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
  11. onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
  12. onnx_diagnostic/tasks/feature_extraction.py +7 -3
  13. onnx_diagnostic/tasks/fill_mask.py +6 -2
  14. onnx_diagnostic/tasks/image_classification.py +6 -2
  15. onnx_diagnostic/tasks/image_text_to_text.py +48 -12
  16. onnx_diagnostic/tasks/mask_generation.py +6 -2
  17. onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
  18. onnx_diagnostic/tasks/object_detection.py +6 -2
  19. onnx_diagnostic/tasks/sentence_similarity.py +6 -2
  20. onnx_diagnostic/tasks/summarization.py +7 -2
  21. onnx_diagnostic/tasks/text2text_generation.py +7 -2
  22. onnx_diagnostic/tasks/text_classification.py +6 -2
  23. onnx_diagnostic/tasks/text_generation.py +8 -14
  24. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
  25. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  26. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
  27. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +227 -1
  28. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +3 -1
  29. onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
  30. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +70 -1
  31. onnx_diagnostic/torch_models/hghub/model_inputs.py +13 -1
  32. onnx_diagnostic/torch_models/validate.py +17 -0
  33. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/METADATA +2 -2
  34. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/RECORD +37 -37
  35. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/WHEEL +0 -0
  36. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/licenses/LICENSE.txt +0 -0
  37. {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
  Functions, classes to dig into a model when this one is right, slow, wrong...
  """

- __version__ = "0.7.6"
+ __version__ = "0.7.8"
  __author__ = "Xavier Dupré"
@@ -306,7 +306,7 @@ class _ParseDict(argparse.Action):
              value = split_items[1]

              if value in ("True", "true", "False", "false"):
-                 d[key] = bool(value)
+                 d[key] = value in ("True", "true")
                  continue
              try:
                  d[key] = int(value)
@@ -323,6 +323,54 @@ class _ParseDict(argparse.Action):
          setattr(namespace, self.dest, d)


+ class _BoolOrParseDictPatch(argparse.Action):
+     def __call__(self, parser, namespace, values, option_string=None):
+
+         if not values:
+             return
+         if len(values) == 1 and values[0] in (
+             "True",
+             "False",
+             "true",
+             "false",
+             "0",
+             "1",
+             0,
+             1,
+         ):
+             setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
+             return
+         d = getattr(namespace, self.dest) or {}
+         if not isinstance(d, dict):
+             d = {
+                 "patch_sympy": d,
+                 "patch_torch": d,
+                 "patch_transformers": d,
+                 "patch_diffusers": d,
+             }
+         for item in values:
+             split_items = item.split("=", 1)
+             key = split_items[0].strip()  # we remove blanks around keys, as is logical
+             value = split_items[1]
+
+             if value in ("True", "true", "False", "false"):
+                 d[key] = value in ("True", "true")
+                 continue
+             try:
+                 d[key] = int(value)
+                 continue
+             except (TypeError, ValueError):
+                 pass
+             try:
+                 d[key] = float(value)
+                 continue
+             except (TypeError, ValueError):
+                 pass
+             d[key] = _parse_json(value)
+
+         setattr(namespace, self.dest, d)
+
+
  def get_parser_validate() -> ArgumentParser:
      parser = ArgumentParser(
          prog="validate",
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
      parser.add_argument(
          "--patch",
          default=True,
-         action=BooleanOptionalAction,
-         help="Applies patches before exporting.",
+         action=_BoolOrParseDictPatch,
+         nargs="*",
+         help="Applies patches before exporting. It can be a boolean "
+         "to enable or disable the patches, or be more fine-grained. It is "
+         "possible to disable the torch patches by adding "
+         '--patch "patch_sympy=False" --patch "patch_torch=False"; '
+         "default is True.",
      )
      parser.add_argument(
          "--rewrite",
@@ -887,19 +887,30 @@ class ModelInputs:

          # In case DynamicCache is not registered.
          if obj.__class__.__name__ == "DynamicCache":
-             kc = set(len(o.key_cache) for o in objs)
-             assert (
-                 len(kc) == 1
-             ), f"All attribute 'key_cache' should have the same length but found {kc}"
-             vc = set(len(o.value_cache) for o in objs)
-             assert (
-                 len(vc) == 1
-             ), f"All attribute 'value_cache' should have the same length but found {vc}"
+             if hasattr(obj, "layers"):
+                 kc = set(len(o.layers) for o in objs)
+                 assert (
+                     len(kc) == 1
+                 ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                 vc = kc.copy()
+             else:
+                 kc = set(len(o.key_cache) for o in objs)
+                 assert (
+                     len(kc) == 1
+                 ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                 vc = set(len(o.value_cache) for o in objs)
+                 assert (
+                     len(vc) == 1
+                 ), f"All attribute 'value_cache' should have the same length but found {vc}"
+
              key_cache = []
              for i in range(kc.pop()):
                  key_cache.append(
                      self.guess_dynamic_dimensions(
-                         *[o.key_cache[i] for o in objs],
+                         *[
+                             o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
+                             for o in objs
+                         ],
                          auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
                      )
                  )
@@ -907,7 +918,10 @@ class ModelInputs:
              for i in range(vc.pop()):
                  value_cache.append(
                      self.guess_dynamic_dimensions(
-                         *[o.value_cache[i] for o in objs],
+                         *[
+                             o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
+                             for o in objs
+                         ],
                          auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                      )
                  )
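
The duck typing above covers the two DynamicCache layouts found across transformers versions; factored out (the helper name is illustrative, not the package's), the access pattern is:

    def cache_layer_tensors(cache, i):
        # Newer transformers store per-layer objects in `cache.layers`,
        # each exposing `.keys` / `.values`; older versions keep two
        # parallel lists `key_cache` / `value_cache`.
        if hasattr(cache, "layers"):
            return cache.layers[i].keys, cache.layers[i].values
        return cache.key_cache[i], cache.value_cache[i]
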
@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
      All dimensions are considered as dynamic.
      ``dim_prefix`` can be a string (the function uses it as a prefix),
      or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+     Depending on the transformers version, the serialization functions
+     of the DynamicCache class are registered automatically or not (>= 4.51, < 4.55).

      .. runpython::
          :showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
          import torch
          from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
          from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+         from onnx_diagnostic.torch_export_patches import torch_export_patches

          bsize, nheads, slen, dim = 2, 1, 30, 96
          inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
              position_ids=torch.arange(3, dtype=torch.int64),
              past_key_values=make_dynamic_cache(
                  [(torch.randn(bsize, nheads, slen, dim),
-                 torch.randn(bsize, nheads, slen, dim))]
+                   torch.randn(bsize, nheads, slen, dim))]
              ),
          )
-         ds = all_dynamic_shape_from_inputs(inputs)
+         with torch_export_patches(patch_transformers=True):
+             ds = all_dynamic_shape_from_inputs(inputs)
          pprint.pprint(ds)

      For this function to work, patches must be enabled if :epkg:`transformers`
@@ -41,9 +41,14 @@ class CacheKeyValue:
                  f"or value_cache={string_type(self.value_cache)}, "
                  f"cache.layers={string_type(cache.layers)}"
              )
-         elif cache is not None:
+         elif cache is not None and hasattr(cache, "key_cache"):
              self.key_cache = cache.key_cache
              self.value_cache = cache.value_cache
+         elif cache is None:
+             self.key_cache = None
+             self.value_cache = None
+         else:
+             raise NotImplementedError(f"type(cache)={type(cache)}")

      def make_dynamic_cache(self):
          """Do the reverse operation."""
@@ -91,13 +96,16 @@ def flatten_unflatten_for_dynamic_shapes(
          return tuple(subtrees)
      if spec.type is list:
          return list(subtrees)
+     if spec.type is None and not subtrees:
+         return None
      if spec.context:
          # This is a custom class with attributes.
          # It is returned as a list.
          return list(subtrees)
      raise ValueError(
          f"Unable to interpret spec type {spec.type} "
-         f"(type is {type(spec.type)}, context is {spec.context})."
+         f"(type is {type(spec.type)}, context is {spec.context}), "
+         f"spec={spec}, subtrees={subtrees}"
      )
  # This is a list.
  return subtrees
@@ -126,6 +134,8 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
      )
      values, spec = torch.utils._pytree.tree_flatten(cache)
      cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+     if hasattr(cache2, "layers") and hasattr(cache, "layers"):
+         return len(cache2.layers) == len(cache.layers)
      return len(cache2.key_cache) == len(cache.value_cache)

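
The check is a flatten/unflatten round trip through torch's pytree registry; in isolation it amounts to (a sketch):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.ones(1, 1, 2, 4), torch.ones(1, 1, 2, 4))])
    values, spec = torch.utils._pytree.tree_flatten(cache)
    cache2 = torch.utils._pytree.tree_unflatten(values, spec)
    # When serialization is registered, the layer count survives the round trip.
    n = len(cache2.layers) if hasattr(cache2, "layers") else len(cache2.key_cache)
    print(n)  # 1
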
@@ -176,7 +186,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
              f"Unexpected number of layers in the cache ({len(cache.layers)}), "
              f"{len(key_value_pairs)} expected."
          )
-         return cache
+         return finalize_cache(cache)

  else:

@@ -260,6 +270,9 @@ def make_static_cache(
              self.num_attention_heads = key_value_pairs[0][0].shape[1]
              self.num_hidden_layers = len(key_value_pairs)

+         def get_text_config(self):
+             return self
+
      assert max_cache_len is not None, (
          f"max_cache_len={max_cache_len} cannot be setup "
          f"automatically yet from shape {key_value_pairs[0][0].shape}"
@@ -280,6 +293,33 @@ def make_static_cache(
          max_cache_len=max_cache_len,
      )
      ca = CacheKeyValue(cache)
+     if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+         # transformers>= 4.55.2, layers are empty
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i)
+         return cache
+
+     torch._check(
+         not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(cache.layers)={len(cache.layers)}"
+         ),
+     )
+     torch._check(
+         len(key_value_pairs) == len(ca.key_cache),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(ca.key_cache)={len(ca.key_cache)}"
+         ),
+     )
+     torch._check(
+         len(key_value_pairs) == len(ca.value_cache),
+         lambda: (
+             f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+             f"len(ca.value_cache)={len(ca.value_cache)}"
+         ),
+     )
      for i in range(len(key_value_pairs)):
          assert (
              key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
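
torch._check is used instead of a plain assert because the export stack understands it as a guard; the message is passed as a callable so formatting is deferred until failure:

    import torch

    a, b = [1, 2], [3, 4]
    # Passes silently; on failure the lambda builds the error message.
    torch._check(len(a) == len(b), lambda: f"Length mismatch {len(a)} != {len(b)}")
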
@@ -298,7 +338,7 @@ def make_static_cache(
          f"Unexpected number of layers in the cache ({len(cache.layers)}), "
          f"{len(key_value_pairs)} expected."
      )
-     return cache
+     return finalize_cache(cache)


  def make_encoder_decoder_cache(
@@ -307,7 +347,10 @@ def make_encoder_decoder_cache(
  ) -> transformers.cache_utils.EncoderDecoderCache:
      """Creates an EncoderDecoderCache."""
      return transformers.cache_utils.EncoderDecoderCache(
-         self_attention_cache=self_attention_cache, cross_attention_cache=cross_attention_cache
+         # self_attention_cache=self_attention_cache,
+         # cross_attention_cache=cross_attention_cache
+         self_attention_cache,
+         cross_attention_cache,
      )

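
The commented-out keywords suggest the parameter names are not stable across transformers releases; a defensive variant (a sketch, not the package's code) could fall back explicitly:

    import transformers

    def build_encoder_decoder_cache(self_cache, cross_cache):
        try:
            # positional form, accepted by recent transformers
            return transformers.cache_utils.EncoderDecoderCache(self_cache, cross_cache)
        except TypeError:
            # keyword form used by older releases
            return transformers.cache_utils.EncoderDecoderCache(
                self_attention_cache=self_cache, cross_attention_cache=cross_cache
            )
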
@@ -323,6 +366,9 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
              self.num_hidden_layers = len(key_value_pairs)
              self.dtype = dtype

+         def get_text_config(self):
+             return self
+
      cache = MambaCache(
          _config(),
          max_batch_size=key_value_pairs[0][0].shape[0],
@@ -348,7 +394,7 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
              f"got {key_value_pairs[i][1].shape}"
          )
          cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
-     return cache
+     return finalize_cache(cache)


  def make_sliding_window_cache(
@@ -363,6 +409,9 @@ def make_sliding_window_cache(
              self.num_hidden_layers = len(key_value_pairs)
              self.sliding_window = key_value_pairs[0][0].shape[2]

+         def get_text_config(self):
+             return self
+
      cache = transformers.cache_utils.SlidingWindowCache(
          config=_config(),
          max_batch_size=key_value_pairs[0][0].shape[0],
@@ -371,6 +420,13 @@ def make_sliding_window_cache(
          dtype=key_value_pairs[0][0].dtype,
      )
      ca = CacheKeyValue(cache)
+     if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+         # transformers>= 4.55.2, layers are empty
+         cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+         for i, (key, value) in enumerate(key_value_pairs):
+             cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+         return cache
+
      for i in range(len(key_value_pairs)):
          assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
              f"Shape mismatch, expected {cache.key_cache[i].shape}, "
@@ -393,7 +449,7 @@ def make_sliding_window_cache(
          f"Unexpected number of layers in the cache ({len(cache.layers)}), "
          f"{len(key_value_pairs)} expected."
      )
-     return cache
+     return finalize_cache(cache)


  def make_hybrid_cache(
@@ -521,6 +577,9 @@ def make_hybrid_cache(
              sliding_window = _sliding_window
              num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3

+         def get_text_config(self):
+             return self
+
      if layer_types:
          _config.layer_types = layer_types  # type: ignore[attr-defined]

@@ -549,4 +608,21 @@ def make_hybrid_cache(
          f"Unexpected number of layers in the cache ({len(cache.layers)}), "
          f"{len(key_value_pairs)} expected."
      )
+     return finalize_cache(cache)
+
+
+ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+     """
+     Ensures the created cache is consistent.
+     Returns the cache modified in place.
+     """
+     if (
+         hasattr(cache, "layer_class_to_replicate")
+         and hasattr(cache, "layers")
+         and cache.layers
+         and not cache.layer_class_to_replicate
+     ):
+         # This is used to expand the cache when it does not contain enough layers.
+         # This is needed since transformers>4.55.3.
+         cache.layer_class_to_replicate = cache.layers[0].__class__
      return cache
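
Since every make_*_cache helper above now returns finalize_cache(cache), caches built on recent transformers come out ready to replicate layers; a quick inspection (a sketch, the printed class depends on the transformers version):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.ones(1, 1, 2, 4), torch.ones(1, 1, 2, 4))])
    if getattr(cache, "layers", None):
        # filled in by finalize_cache so the cache can grow on demand
        print(cache.layer_class_to_replicate)
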
@@ -1,6 +1,7 @@
  import functools
  import importlib
  import inspect
+ import os
  import re
  from typing import Any, Callable, Dict, Optional, Tuple, Union
  import transformers
@@ -110,3 +111,59 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
      )
      cls_name = unique.pop()
      return getattr(transformers, cls_name)
+
+
+ def default_num_hidden_layers():
+     """
+     Returns the default number of layers.
+     It is lower when the unit tests are running
+     with ``UNITTEST_GOING=1``.
+     """
+     import torch
+
+     if torch.cuda.is_available():
+         capa = torch.cuda.get_device_capability(0)
+         if capa[0] < 9:
+             return 2
+     return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
+
+
+ def build_diff_config(config0, config1):
+     """
+     Returns all the modified values between two configurations.
+     """
+     import torch
+
+     diff = {}
+     for k in config0:
+         assert isinstance(k, str), f"k={k!r}, wrong type in {config0}"
+         if k not in config1:
+             v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+             diff[k] = f"-{v0}"
+     for k in config1:
+         assert isinstance(k, str), f"k={k!r}, wrong type in {config1}"
+         if k not in config0:
+             v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+             diff[k] = f"+{v1}"
+     for k in config0:
+         if k not in config1:
+             continue
+         v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+         v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+         if (
+             v0 is None
+             or v1 is None
+             or isinstance(v1, (float, int, bool, str, list, tuple, torch.dtype))
+             or (
+                 isinstance(v0, dict)
+                 and isinstance(v1, dict)
+                 and all(isinstance(k, int) for k in v1)
+             )
+         ):
+             if v1 != v0:
+                 diff[k] = f"{v0} -> {v1}"
+         else:
+             d = build_diff_config(v0, v1)
+             if d:
+                 diff[k] = d
+     return diff
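
build_diff_config also accepts plain dicts, since it falls back to config[k] when the attribute lookup fails; with hypothetical values:

    from onnx_diagnostic.helpers.config_helper import build_diff_config

    c0 = {"num_hidden_layers": 32, "hidden_size": 4096, "tie_word_embeddings": True}
    c1 = {"num_hidden_layers": 2, "hidden_size": 4096, "head_dim": 128}
    print(build_diff_config(c0, c1))
    # {'tie_word_embeddings': '-True', 'head_dim': '+128', 'num_hidden_layers': '32 -> 2'}
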
@@ -36,11 +36,12 @@ def size_type(dtype: Any) -> int:
          TensorProto.FLOAT8E4M3FNUZ,
          TensorProto.FLOAT8E5M2,
          TensorProto.FLOAT8E5M2FNUZ,
+         getattr(TensorProto, "FLOAT8E8M0", None),
      }:
          return 1
      if dtype in {TensorProto.COMPLEX128}:
          return 16
-     from .helpers.onnx_helper import onnx_dtype_name
+     from .onnx_helper import onnx_dtype_name

      raise AssertionError(
          f"Unable to return the element size for type {onnx_dtype_name(dtype)}"
@@ -1478,8 +1479,12 @@ def max_diff(
      # backup function in case pytorch does not know how to serialize.
      if expected.__class__.__name__ == "DynamicCache":
          if got.__class__.__name__ == "DynamicCache":
+             from .cache_helper import CacheKeyValue
+
              if verbose >= 6:
                  print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
+             expected = CacheKeyValue(expected)
+             got = CacheKeyValue(got)
              return max_diff(
                  [expected.key_cache, expected.value_cache],
                  [got.key_cache, got.value_cache],
@@ -11,22 +11,26 @@ try:
          float8e5m2fnuz,
      )
  except ImportError:
+     bfloat16 = None
  from onnx.reference.ops.op_cast import cast_to
  from ...helpers.onnx_helper import np_dtype_to_tensor_dtype


  def _cast_like(x, y, saturate):
-     if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
-         # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
-         to = TensorProto.BFLOAT16
-     elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
-         to = TensorProto.FLOAT8E4M3FN
-     elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
-         to = TensorProto.FLOAT8E4M3FNUZ
-     elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
-         to = TensorProto.FLOAT8E5M2
-     elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
-         to = TensorProto.FLOAT8E5M2FNUZ
+     if bfloat16 is not None:
+         if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
+             # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
+             to = TensorProto.BFLOAT16
+         elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
+             to = TensorProto.FLOAT8E4M3FN
+         elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
+             to = TensorProto.FLOAT8E4M3FNUZ
+         elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
+             to = TensorProto.FLOAT8E5M2
+         elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
+             to = TensorProto.FLOAT8E5M2FNUZ
+         else:
+             to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
      else:
          to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
      return (cast_to(x, to, saturate),)
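
The bfloat16 = None assignment is the usual optional-import sentinel; in isolation:

    # Optional-import guard, as used above: fall back to None when the
    # custom numpy element types are unavailable.
    try:
        from onnx.reference.custom_element_types import bfloat16
    except ImportError:
        bfloat16 = None

    def maybe_bfloat16(y):
        # Only attempt the dtype comparison when the type was importable.
        return bfloat16 is not None and y.dtype == bfloat16
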
@@ -45,6 +45,7 @@ from .unary_ops import (
      Erf_9,
      Exp_1,
      Identity_1,
+     IsNaN_9,
      Log_1,
      Neg_1,
      Not_1,
@@ -37,6 +37,13 @@ class Identity_1(OpRunKernel):
          return OpRunTensor(x.tensor)


+ class IsNaN_9(OpRunKernel):
+     """IsNaN"""
+
+     def run(self, x: OpRunTensor) -> OpRunTensor:
+         return OpRunTensor(x.tensor.isnan())
+
+
  class Log_1(OpRunKernel):
      """Log"""

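
The kernel is a thin wrapper around torch.Tensor.isnan; outside the runner it reduces to:

    import torch

    x = torch.tensor([0.0, float("nan"), 1.0])
    print(x.isnan())  # tensor([False,  True, False])
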
@@ -2,7 +2,11 @@ from typing import Any, Callable, Dict, Optional, Tuple
  import torch
  import transformers
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
- from ..helpers.config_helper import update_config, check_hasattr
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     default_num_hidden_layers as nhl,
+ )

  __TASK__ = "automatic-speech-recognition"

@@ -15,7 +19,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
      if hasattr(config, "decoder_layers"):
          config.decoder_layers = min(config.decoder_layers, 2)
      if hasattr(config, "num_hidden_layers"):
-         config.num_hidden_layers = min(config.num_hidden_layers, 2)
+         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
      update_config(config, kwargs)
      return kwargs

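
All the task modules below now share the same cap through nhl(); its value depends on the environment (a sketch of the behaviour defined in config_helper above):

    import os
    from onnx_diagnostic.helpers.config_helper import default_num_hidden_layers as nhl

    os.environ["UNITTEST_GOING"] = "1"
    print(nhl())  # 2 during unit tests
    os.environ["UNITTEST_GOING"] = "0"
    print(nhl())  # 4 otherwise (still 2 on GPUs with compute capability < 9)
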
@@ -1,15 +1,20 @@
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
- from ..helpers.config_helper import update_config, check_hasattr
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     default_num_hidden_layers as nhl,
+ )
  from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

+
  __TASK__ = "feature-extraction"


  def reduce_model_config(config: Any) -> Dict[str, Any]:
      """Reduces a model size."""
      check_hasattr(config, "num_hidden_layers")
-     kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, 2))
+     kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
      update_config(config, kwargs)
      return kwargs

@@ -160,5 +165,4 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
          if hasattr(config, att):
              kwargs[att] = getattr(config, att)
      kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
-     print(kwargs)
      return kwargs, get_inputs
@@ -1,6 +1,10 @@
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
- from ..helpers.config_helper import update_config, check_hasattr
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     default_num_hidden_layers as nhl,
+ )

  __TASK__ = "fill-mask"

@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
      """Reduces a model size."""
      check_hasattr(config, "num_attention_heads", "num_hidden_layers")
      kwargs = dict(
-         num_hidden_layers=min(config.num_hidden_layers, 2),
+         num_hidden_layers=min(config.num_hidden_layers, nhl()),
          num_attention_heads=min(config.num_attention_heads, 4),
      )
      update_config(config, kwargs)
@@ -1,6 +1,10 @@
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
- from ..helpers.config_helper import update_config, check_hasattr
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     default_num_hidden_layers as nhl,
+ )

  __TASK__ = "image-classification"
@@ -17,7 +21,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
      check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
      kwargs = dict(
          num_hidden_layers=(
-             min(config.num_hidden_layers, 2)
+             min(config.num_hidden_layers, nhl())
              if hasattr(config, "num_hidden_layers")
              else len(config.hidden_sizes)
          )