onnx-diagnostic 0.7.6__py3-none-any.whl → 0.7.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +56 -3
- onnx_diagnostic/export/dynamic_shapes.py +24 -10
- onnx_diagnostic/export/shape_helper.py +6 -2
- onnx_diagnostic/helpers/cache_helper.py +83 -7
- onnx_diagnostic/helpers/config_helper.py +57 -0
- onnx_diagnostic/helpers/helper.py +6 -1
- onnx_diagnostic/reference/ops/op_cast_like.py +15 -11
- onnx_diagnostic/reference/torch_ops/__init__.py +1 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +6 -2
- onnx_diagnostic/tasks/feature_extraction.py +7 -3
- onnx_diagnostic/tasks/fill_mask.py +6 -2
- onnx_diagnostic/tasks/image_classification.py +6 -2
- onnx_diagnostic/tasks/image_text_to_text.py +48 -12
- onnx_diagnostic/tasks/mask_generation.py +6 -2
- onnx_diagnostic/tasks/mixture_of_expert.py +2 -2
- onnx_diagnostic/tasks/object_detection.py +6 -2
- onnx_diagnostic/tasks/sentence_similarity.py +6 -2
- onnx_diagnostic/tasks/summarization.py +7 -2
- onnx_diagnostic/tasks/text2text_generation.py +7 -2
- onnx_diagnostic/tasks/text_classification.py +6 -2
- onnx_diagnostic/tasks/text_generation.py +8 -14
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +3 -3
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +4 -4
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +227 -1
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +3 -1
- onnx_diagnostic/torch_models/hghub/hub_data.py +5 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +70 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +13 -1
- onnx_diagnostic/torch_models/validate.py +17 -0
- {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/RECORD +37 -37
- {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.6.dist-info → onnx_diagnostic-0.7.8.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/_command_lines_parser.py
CHANGED
@@ -306,7 +306,7 @@ class _ParseDict(argparse.Action):
             value = split_items[1]

             if value in ("True", "true", "False", "false"):
-                d[key] =
+                d[key] = value in ("True", "true")
                 continue
             try:
                 d[key] = int(value)
@@ -323,6 +323,54 @@ class _ParseDict(argparse.Action):
         setattr(namespace, self.dest, d)


+class _BoolOrParseDictPatch(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+
+        if not values:
+            return
+        if len(values) == 1 and values[0] in (
+            "True",
+            "False",
+            "true",
+            "false",
+            "0",
+            "1",
+            0,
+            1,
+        ):
+            setattr(namespace, self.dest, values[0] in ("True", "true", 1, "1"))
+            return
+        d = getattr(namespace, self.dest) or {}
+        if not isinstance(d, dict):
+            d = {
+                "patch_sympy": d,
+                "patch_torch": d,
+                "patch_transformers": d,
+                "patch_diffusers": d,
+            }
+        for item in values:
+            split_items = item.split("=", 1)
+            key = split_items[0].strip()  # we remove blanks around keys, as is logical
+            value = split_items[1]
+
+            if value in ("True", "true", "False", "false"):
+                d[key] = value in ("True", "true")
+                continue
+            try:
+                d[key] = int(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            try:
+                d[key] = float(value)
+                continue
+            except (TypeError, ValueError):
+                pass
+            d[key] = _parse_json(value)
+
+        setattr(namespace, self.dest, d)
+
+
 def get_parser_validate() -> ArgumentParser:
     parser = ArgumentParser(
         prog="validate",
@@ -383,8 +431,13 @@ def get_parser_validate() -> ArgumentParser:
     parser.add_argument(
         "--patch",
         default=True,
-        action=
-
+        action=_BoolOrParseDictPatch,
+        nargs="*",
+        help="Applies patches before exporting, it can be a boolean "
+        "to enable to disable the patches or be more finetuned. It is possible to "
+        "disable patch for torch by adding "
+        '--patch "patch_sympy=False" --patch "patch_torch=False", '
+        "default is True.",
     )
     parser.add_argument(
         "--rewrite",
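
The new --patch option can therefore be passed either as a plain boolean (--patch False) or as a list of key=value pairs merged into a dictionary. A minimal standalone sketch of that dual behaviour, simplified from the action above (the real action also expands a previous boolean into per-library keys such as patch_torch, and _parse_json here is a stand-in for the parser's own JSON fallback):

import argparse
import json


def _parse_json(value):
    # stand-in fallback: return parsed JSON when possible, else the raw string
    try:
        return json.loads(value)
    except ValueError:
        return value


class BoolOrDictAction(argparse.Action):
    """Accepts "--patch False" as a boolean or "--patch a=1 b=False" as a dict."""

    def __call__(self, parser, namespace, values, option_string=None):
        if len(values) == 1 and values[0] in ("True", "true", "False", "false", "0", "1"):
            setattr(namespace, self.dest, values[0] in ("True", "true", "1"))
            return
        d = {}
        for item in values:
            key, value = item.split("=", 1)
            if value in ("True", "true", "False", "false"):
                d[key.strip()] = value in ("True", "true")
            else:
                d[key.strip()] = _parse_json(value)
        setattr(namespace, self.dest, d)


parser = argparse.ArgumentParser()
parser.add_argument("--patch", default=True, nargs="*", action=BoolOrDictAction)
print(parser.parse_args([]).patch)                                # True (default)
print(parser.parse_args(["--patch", "False"]).patch)              # False
print(parser.parse_args(["--patch", "patch_torch=False"]).patch)  # {'patch_torch': False}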

onnx_diagnostic/export/dynamic_shapes.py
CHANGED
@@ -887,19 +887,30 @@ class ModelInputs:

         # In case DynamicCache is not registered.
         if obj.__class__.__name__ == "DynamicCache":
-
-
-
-
-
-
-
-
+            if hasattr(obj, "layers"):
+                kc = set(len(o.layers) for o in objs)
+                assert (
+                    len(kc) == 1
+                ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                vc = kc.copy()
+            else:
+                kc = set(len(o.key_cache) for o in objs)
+                assert (
+                    len(kc) == 1
+                ), f"All attribute 'key_cache' should have the same length but found {kc}"
+                vc = set(len(o.value_cache) for o in objs)
+                assert (
+                    len(vc) == 1
+                ), f"All attribute 'value_cache' should have the same length but found {vc}"
+
             key_cache = []
             for i in range(kc.pop()):
                 key_cache.append(
                     self.guess_dynamic_dimensions(
-                        *[
+                        *[
+                            o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
+                            for o in objs
+                        ],
                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}kdc",
                     )
                 )
@@ -907,7 +918,10 @@ class ModelInputs:
             for i in range(vc.pop()):
                 value_cache.append(
                     self.guess_dynamic_dimensions(
-                        *[
+                        *[
+                            o.layers[i].values if hasattr(o, "layers") else o.value_cache[i]
+                            for o in objs
+                        ],
                         auto=auto if isinstance(auto, bool) else f"{auto}_{i}vdc",
                     )
                 )
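
Both hunks branch on whether the DynamicCache object exposes a "layers" list (newer transformers releases) or flat key_cache/value_cache lists (older ones). A minimal sketch of that compatibility pattern with stand-in classes rather than the real transformers cache:

class FakeLayer:
    def __init__(self, keys, values):
        self.keys, self.values = keys, values


class FakeCache:
    """Stand-in for DynamicCache; only the attribute names used above matter."""

    def __init__(self, pairs, new_api=True):
        if new_api:
            self.layers = [FakeLayer(k, v) for k, v in pairs]
        else:
            self.key_cache = [k for k, _ in pairs]
            self.value_cache = [v for _, v in pairs]


def key_tensors(cache):
    # mirrors: o.layers[i].keys if hasattr(o, "layers") else o.key_cache[i]
    if hasattr(cache, "layers"):
        return [layer.keys for layer in cache.layers]
    return list(cache.key_cache)


old_api = FakeCache([("k0", "v0")], new_api=False)
new_api = FakeCache([("k0", "v0")], new_api=True)
assert key_tensors(old_api) == key_tensors(new_api) == ["k0"]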

onnx_diagnostic/export/shape_helper.py
CHANGED
@@ -9,6 +9,8 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
     All dimensions are considered as dynamic.
     ``dim_prefix`` can be a string (the function uses it as a prefix),
     or ``torch.export.Dim.AUTO`` or ``torch.export.Dim.DYNAMIC``.
+    Depending on the version of transformers, serializations function
+    of DynamicCache class is automatically serialized or not (>= 4.51, < 4.55).

     .. runpython::
         :showcode:
@@ -17,6 +19,7 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
         import torch
         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
         from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
+        from onnx_diagnostic.torch_export_patches import torch_export_patches

         bsize, nheads, slen, dim = 2, 1, 30, 96
         inputs = dict(
@@ -25,10 +28,11 @@ def all_dynamic_shape_from_inputs(inputs: Any, dim_prefix: Any = "d") -> Any:
             position_ids=torch.arange(3, dtype=torch.int64),
             past_key_values=make_dynamic_cache(
                 [(torch.randn(bsize, nheads, slen, dim),
-
+                  torch.randn(bsize, nheads, slen, dim))]
             ),
         )
-
+        with torch_export_patches(patch_transformers=True):
+            ds = all_dynamic_shape_from_inputs(inputs)
         pprint.pprint(ds)

     For this function to work, patches must be enabled if :epkg:`transformers`
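
Assembled from the updated docstring, the example now wraps the call in torch_export_patches so the DynamicCache serialization is registered whatever the installed transformers version; only the inputs visible in the hunk are kept here:

import pprint
import torch
from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
from onnx_diagnostic.export.shape_helper import all_dynamic_shape_from_inputs
from onnx_diagnostic.torch_export_patches import torch_export_patches

bsize, nheads, slen, dim = 2, 1, 30, 96
inputs = dict(
    position_ids=torch.arange(3, dtype=torch.int64),
    past_key_values=make_dynamic_cache(
        [(torch.randn(bsize, nheads, slen, dim), torch.randn(bsize, nheads, slen, dim))]
    ),
)

# The patches register the cache serialization when transformers does not do it itself.
with torch_export_patches(patch_transformers=True):
    ds = all_dynamic_shape_from_inputs(inputs)
pprint.pprint(ds)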

onnx_diagnostic/helpers/cache_helper.py
CHANGED
@@ -41,9 +41,14 @@ class CacheKeyValue:
                 f"or value_cache={string_type(self.value_cache)}, "
                 f"cache.layers={string_type(cache.layers)}"
             )
-        elif cache is not None:
+        elif cache is not None and hasattr(cache, "key_cache"):
             self.key_cache = cache.key_cache
             self.value_cache = cache.value_cache
+        elif cache is None:
+            self.key_cache = None
+            self.value_cache = None
+        else:
+            raise NotImplementedError(f"type(cache)={type(cache)}")

     def make_dynamic_cache(self):
         """Do the reverse operation."""
@@ -91,13 +96,16 @@ def flatten_unflatten_for_dynamic_shapes(
         return tuple(subtrees)
     if spec.type is list:
         return list(subtrees)
+    if spec.type is None and not subtrees:
+        return None
     if spec.context:
         # This is a custom class with attributes.
         # It is returned as a list.
         return list(subtrees)
     raise ValueError(
         f"Unable to interpret spec type {spec.type} "
-        f"(type is {type(spec.type)}, context is {spec.context})
+        f"(type is {type(spec.type)}, context is {spec.context}), "
+        f"spec={spec}, subtrees={subtrees}"
     )
     # This is a list.
     return subtrees
@@ -126,6 +134,8 @@ def is_cache_dynamic_registered(fast: bool = False) -> bool:
     )
     values, spec = torch.utils._pytree.tree_flatten(cache)
     cache2 = torch.utils._pytree.tree_unflatten(values, spec)
+    if hasattr(cache2, "layers") and hasattr(cache, "layers"):
+        return len(cache2.layers) == len(cache.layers)
     return len(cache2.key_cache) == len(cache.value_cache)


@@ -176,7 +186,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             f"Unexpected number of layers in the cache ({len(cache.layers)}), "
             f"{len(key_value_pairs)} expected."
         )
-        return cache
+        return finalize_cache(cache)

 else:

@@ -260,6 +270,9 @@ def make_static_cache(
             self.num_attention_heads = key_value_pairs[0][0].shape[1]
             self.num_hidden_layers = len(key_value_pairs)

+        def get_text_config(self):
+            return self
+
     assert max_cache_len is not None, (
         f"max_cache_len={max_cache_len} cannot be setup "
         f"automatically yet from shape {key_value_pairs[0][0].shape}"
@@ -280,6 +293,33 @@ def make_static_cache(
         max_cache_len=max_cache_len,
     )
     ca = CacheKeyValue(cache)
+    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+        # transformers>= 4.55.2, layers are empty
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i)
+        return cache
+
+    torch._check(
+        not hasattr(cache, "layers") or len(key_value_pairs) == len(cache.layers),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(cache.layers)={len(cache.layers)}"
+        ),
+    )
+    torch._check(
+        len(key_value_pairs) == len(ca.key_cache),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(ca.key_cache)={len(ca.key_cache)}"
+        ),
+    )
+    torch._check(
+        len(key_value_pairs) == len(ca.value_cache),
+        lambda: (
+            f"Length mismatch len(key_value_pairs)={len(key_value_pairs)}, "
+            f"len(ca.value_cache)={len(ca.value_cache)}"
+        ),
+    )
     for i in range(len(key_value_pairs)):
         assert (
             key_value_pairs[i][0].shape == key_value_pairs[i][1].shape
@@ -298,7 +338,7 @@ def make_static_cache(
         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
         f"{len(key_value_pairs)} expected."
     )
-    return cache
+    return finalize_cache(cache)


 def make_encoder_decoder_cache(
@@ -307,7 +347,10 @@ def make_encoder_decoder_cache(
 ) -> transformers.cache_utils.EncoderDecoderCache:
     """Creates an EncoderDecoderCache."""
     return transformers.cache_utils.EncoderDecoderCache(
-        self_attention_cache=self_attention_cache,
+        # self_attention_cache=self_attention_cache,
+        # cross_attention_cache=cross_attention_cache
+        self_attention_cache,
+        cross_attention_cache,
     )


@@ -323,6 +366,9 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
             self.num_hidden_layers = len(key_value_pairs)
             self.dtype = dtype

+        def get_text_config(self):
+            return self
+
     cache = MambaCache(
         _config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
@@ -348,7 +394,7 @@ def make_mamba_cache(key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]]) -
             f"got {key_value_pairs[i][1].shape}"
         )
         cache.ssm_states[i][:, :, :] = key_value_pairs[i][1]
-    return cache
+    return finalize_cache(cache)


 def make_sliding_window_cache(
@@ -363,6 +409,9 @@ def make_sliding_window_cache(
             self.num_hidden_layers = len(key_value_pairs)
             self.sliding_window = key_value_pairs[0][0].shape[2]

+        def get_text_config(self):
+            return self
+
     cache = transformers.cache_utils.SlidingWindowCache(
         config=_config(),
         max_batch_size=key_value_pairs[0][0].shape[0],
@@ -371,6 +420,13 @@ def make_sliding_window_cache(
         dtype=key_value_pairs[0][0].dtype,
     )
     ca = CacheKeyValue(cache)
+    if hasattr(cache, "layers") and len(ca.key_cache) == 0:
+        # transformers>= 4.55.2, layers are empty
+        cache_position = torch.arange(key_value_pairs[0][0].shape[2], dtype=torch.int64)
+        for i, (key, value) in enumerate(key_value_pairs):
+            cache.update(key, value, i, cache_kwargs={"cache_position": cache_position})
+        return cache
+
     for i in range(len(key_value_pairs)):
         assert ca.key_cache[i].shape == key_value_pairs[i][0].shape, (
             f"Shape mismatch, expected {cache.key_cache[i].shape}, "
@@ -393,7 +449,7 @@ def make_sliding_window_cache(
         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
         f"{len(key_value_pairs)} expected."
     )
-    return cache
+    return finalize_cache(cache)


 def make_hybrid_cache(
@@ -521,6 +577,9 @@ def make_hybrid_cache(
         sliding_window = _sliding_window
         num_key_value_heads = key_value_pairs[0][1].shape[1]  # transformers 4.48.3

+        def get_text_config(self):
+            return self
+
     if layer_types:
         _config.layer_types = layer_types  # type: ignore[attr-defined]

@@ -549,4 +608,21 @@ def make_hybrid_cache(
         f"Unexpected number of layers in the cache ({len(cache.layers)}), "
         f"{len(key_value_pairs)} expected."
     )
+    return finalize_cache(cache)
+
+
+def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
+    """
+    Ensures the created cache is consistent.
+    Returns the cache modified inplace.
+    """
+    if (
+        hasattr(cache, "layer_class_to_replicate")
+        and hasattr(cache, "layers")
+        and cache.layers
+        and not cache.layer_class_to_replicate
+    ):
+        # This is used to expand the cache when it does not contains enough layers.
+        # This is needed since transformers>4.55.3
+        cache.layer_class_to_replicate = cache.layers[0].__class__
     return cache
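
Every make_*_cache helper now funnels its result through the new finalize_cache, which only touches caches exposing both "layers" and "layer_class_to_replicate" (transformers > 4.55.3). A small sketch of the pattern with stand-in classes instead of the real transformers cache:

class FakeLayer:
    """Stand-in for a transformers cache layer."""


class FakeCache:
    """Stand-in for a transformers Cache with the post-4.55.3 attributes."""

    def __init__(self):
        self.layers = [FakeLayer()]
        self.layer_class_to_replicate = None


def finalize_cache(cache):
    # same logic as the helper added above: remember the layer class used to grow the cache
    if (
        hasattr(cache, "layer_class_to_replicate")
        and hasattr(cache, "layers")
        and cache.layers
        and not cache.layer_class_to_replicate
    ):
        cache.layer_class_to_replicate = cache.layers[0].__class__
    return cache


cache = finalize_cache(FakeCache())
print(cache.layer_class_to_replicate)  # <class '__main__.FakeLayer'>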

onnx_diagnostic/helpers/config_helper.py
CHANGED
@@ -1,6 +1,7 @@
 import functools
 import importlib
 import inspect
+import os
 import re
 from typing import Any, Callable, Dict, Optional, Tuple, Union
 import transformers
@@ -110,3 +111,59 @@ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[typ
     )
     cls_name = unique.pop()
     return getattr(transformers, cls_name)
+
+
+def default_num_hidden_layers():
+    """
+    Returns the default number of layers.
+    It is lower when the unit tests are running
+    when ``UNITTEST_GOING=1``.
+    """
+    import torch
+
+    if torch.cuda.is_available():
+        capa = torch.cuda.get_device_capability(0)
+        if capa[0] < 9:
+            return 2
+    return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
+
+
+def build_diff_config(config0, config1):
+    """
+    Returns all the modified values between two configuration
+    """
+    import torch
+
+    diff = {}
+    for k in config0:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config0}"
+        if k not in config1:
+            v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+            diff[k] = f"-{v0}"
+    for k in config1:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config1}"
+        if k not in config0:
+            v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+            diff[k] = f"+{v1}"
+    for k in config0:
+        if k not in config1:
+            continue
+        v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+        v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+        if (
+            v0 is None
+            or v1 is None
+            or isinstance(v1, (float, int, bool, str, list, tuple, torch.dtype))
+            or (
+                isinstance(v0, dict)
+                and isinstance(v1, dict)
+                and all(isinstance(k, int) for k in v1)
+            )
+        ):
+            if v1 != v0:
+                diff[k] = f"{v0} -> {v1}"
+        else:
+            d = build_diff_config(v0, v1)
+            if d:
+                diff[k] = d
+    return diff
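
Both new helpers are importable from onnx_diagnostic.helpers.config_helper. A hedged usage sketch with plain dictionaries (the function also walks configuration objects; the "-value", "+value" and "old -> new" formats follow the code above, the exact key order may differ):

from onnx_diagnostic.helpers.config_helper import (
    build_diff_config,
    default_num_hidden_layers,
)

config0 = {"num_hidden_layers": 32, "hidden_size": 4096, "tie_word_embeddings": True}
config1 = {"num_hidden_layers": 2, "hidden_size": 4096, "head_dim": 128}

# removed keys show up as "-value", added keys as "+value", changed keys as "old -> new"
print(build_diff_config(config0, config1))
# {'tie_word_embeddings': '-True', 'head_dim': '+128', 'num_hidden_layers': '32 -> 2'}

# 2 when UNITTEST_GOING=1 or on an older GPU, otherwise 4
print(default_num_hidden_layers())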

onnx_diagnostic/helpers/helper.py
CHANGED
@@ -36,11 +36,12 @@ def size_type(dtype: Any) -> int:
         TensorProto.FLOAT8E4M3FNUZ,
         TensorProto.FLOAT8E5M2,
         TensorProto.FLOAT8E5M2FNUZ,
+        getattr(TensorProto, "FLOAT8E8M0", None),
     }:
         return 1
     if dtype in {TensorProto.COMPLEX128}:
         return 16
-    from .
+    from .onnx_helper import onnx_dtype_name

     raise AssertionError(
         f"Unable to return the element size for type {onnx_dtype_name(dtype)}"
@@ -1478,8 +1479,12 @@ def max_diff(
     # backup function in case pytorch does not know how to serialize.
     if expected.__class__.__name__ == "DynamicCache":
         if got.__class__.__name__ == "DynamicCache":
+            from .cache_helper import CacheKeyValue
+
             if verbose >= 6:
                 print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
+            expected = CacheKeyValue(expected)
+            got = CacheKeyValue(got)
             return max_diff(
                 [expected.key_cache, expected.value_cache],
                 [got.key_cache, got.value_cache],
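
The getattr(TensorProto, "FLOAT8E8M0", None) entry keeps the set valid on onnx releases that do not define the new 8-bit type yet: the None member is harmless. A generic sketch of that version-tolerance pattern:

from onnx import TensorProto

# enum members that may or may not exist in the installed onnx release
ONE_BYTE_TYPES = {
    TensorProto.FLOAT8E4M3FN,
    TensorProto.FLOAT8E5M2,
    getattr(TensorProto, "FLOAT8E8M0", None),  # None on older onnx, inert in the set
}


def is_one_byte(dtype: int) -> bool:
    return dtype in ONE_BYTE_TYPES


print(is_one_byte(TensorProto.FLOAT8E4M3FN))  # True
print(is_one_byte(TensorProto.FLOAT))         # False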

onnx_diagnostic/reference/ops/op_cast_like.py
CHANGED
@@ -11,22 +11,26 @@ try:
         float8e5m2fnuz,
     )
 except ImportError:
+    bfloat16 = None
 from onnx.reference.ops.op_cast import cast_to
 from ...helpers.onnx_helper import np_dtype_to_tensor_dtype


 def _cast_like(x, y, saturate):
-    if
-
-
-
-
-
-
-
-
-
-
+    if bfloat16 is not None:
+        if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
+            # np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
+            to = TensorProto.BFLOAT16
+        elif y.dtype == float8e4m3fn and y.dtype.descr[0][0] == "e4m3fn":
+            to = TensorProto.FLOAT8E4M3FN
+        elif y.dtype == float8e4m3fnuz and y.dtype.descr[0][0] == "e4m3fnuz":
+            to = TensorProto.FLOAT8E4M3FNUZ
+        elif y.dtype == float8e5m2 and y.dtype.descr[0][0] == "e5m2":
+            to = TensorProto.FLOAT8E5M2
+        elif y.dtype == float8e5m2fnuz and y.dtype.descr[0][0] == "e5m2fnuz":
+            to = TensorProto.FLOAT8E5M2FNUZ
+        else:
+            to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
     else:
         to = np_dtype_to_tensor_dtype(y.dtype)  # type: ignore
     return (cast_to(x, to, saturate),)
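
The extra dtype.descr[0][0] checks are needed because the onnx custom element types are stored as uint16/uint8 with a named field, so dtype equality alone is ambiguous (as the inline comment notes). A small demonstration, assuming an onnx build that ships onnx.reference.custom_element_types:

import numpy as np
from onnx.reference.custom_element_types import bfloat16

raw = np.zeros(2, dtype=np.uint16)
bf = np.zeros(2, dtype=bfloat16)

# equality alone cannot tell the two apart: bfloat16 is backed by uint16 storage
print(raw.dtype == bfloat16)   # True on such builds, hence the ambiguity
print(bf.dtype == bfloat16)    # True

# the field name in dtype.descr is what the new branch checks
print(raw.dtype.descr[0][0])   # '' for plain uint16
print(bf.dtype.descr[0][0])    # 'bfloat16'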

onnx_diagnostic/reference/torch_ops/unary_ops.py
CHANGED
@@ -37,6 +37,13 @@ class Identity_1(OpRunKernel):
         return OpRunTensor(x.tensor)


+class IsNaN_9(OpRunKernel):
+    """IsNaN"""
+
+    def run(self, x: OpRunTensor) -> OpRunTensor:
+        return OpRunTensor(x.tensor.isnan())
+
+
 class Log_1(OpRunKernel):
     """Log"""

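
The new kernel simply forwards to torch.Tensor.isnan, which already implements the ONNX IsNaN semantics element-wise:

import torch

x = torch.tensor([1.0, float("nan"), float("inf")])
print(x.isnan())  # tensor([False,  True, False])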

onnx_diagnostic/tasks/automatic_speech_recognition.py
CHANGED
@@ -2,7 +2,11 @@ from typing import Any, Callable, Dict, Optional, Tuple
 import torch
 import transformers
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)

 __TASK__ = "automatic-speech-recognition"

@@ -15,7 +19,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     if hasattr(config, "decoder_layers"):
         config.decoder_layers = min(config.decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
-        config.num_hidden_layers = min(config.num_hidden_layers,
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
     update_config(config, kwargs)
     return kwargs

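
Every task module now caps num_hidden_layers with default_num_hidden_layers() instead of a hard-coded constant. A hedged sketch of the shared pattern, using a SimpleNamespace as a stand-in for a real transformers config:

from types import SimpleNamespace

from onnx_diagnostic.helpers.config_helper import default_num_hidden_layers as nhl

config = SimpleNamespace(num_hidden_layers=32, decoder_layers=12)

# same capping logic as reduce_model_config above
if hasattr(config, "decoder_layers"):
    config.decoder_layers = min(config.decoder_layers, 2)
if hasattr(config, "num_hidden_layers"):
    config.num_hidden_layers = min(config.num_hidden_layers, nhl())

print(config.num_hidden_layers)  # 2 under UNITTEST_GOING=1 or an older GPU, otherwise 4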

onnx_diagnostic/tasks/feature_extraction.py
CHANGED
@@ -1,15 +1,20 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
 from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache

+
 __TASK__ = "feature-extraction"


 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_hidden_layers")
-    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers,
+    kwargs = dict(num_hidden_layers=min(config.num_hidden_layers, nhl()))
     update_config(config, kwargs)
     return kwargs

@@ -160,5 +165,4 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         if hasattr(config, att):
             kwargs[att] = getattr(config, att)
     kwargs["decoder_ffn_dim"] = kwargs["encoder_ffn_dim"] = 64
-    print(kwargs)
     return kwargs, get_inputs

onnx_diagnostic/tasks/fill_mask.py
CHANGED
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)

 __TASK__ = "fill-mask"

@@ -9,7 +13,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
     kwargs = dict(
-        num_hidden_layers=min(config.num_hidden_layers,
+        num_hidden_layers=min(config.num_hidden_layers, nhl()),
         num_attention_heads=min(config.num_attention_heads, 4),
     )
     update_config(config, kwargs)

onnx_diagnostic/tasks/image_classification.py
CHANGED
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, Optional, Tuple
 import torch
-from ..helpers.config_helper import
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)

 __TASK__ = "image-classification"

@@ -17,7 +21,7 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
     check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
     kwargs = dict(
         num_hidden_layers=(
-            min(config.num_hidden_layers,
+            min(config.num_hidden_layers, nhl())
             if hasattr(config, "num_hidden_layers")
             else len(config.hidden_sizes)
         )