onnx-diagnostic 0.8.10__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +136 -140
- onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
- onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
- onnx_diagnostic/ci_models/data/__init__.py +0 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +10 -7
- onnx_diagnostic/export/api.py +13 -4
- onnx_diagnostic/export/dynamic_shapes.py +1 -1
- onnx_diagnostic/export/validate.py +2 -0
- onnx_diagnostic/ext_test_case.py +32 -15
- onnx_diagnostic/helpers/args_helper.py +1 -0
- onnx_diagnostic/helpers/bench_run.py +0 -1
- onnx_diagnostic/helpers/cache_helper.py +102 -36
- onnx_diagnostic/helpers/doc_helper.py +7 -4
- onnx_diagnostic/helpers/graph_helper.py +6 -6
- onnx_diagnostic/helpers/helper.py +39 -0
- onnx_diagnostic/helpers/log_helper.py +37 -14
- onnx_diagnostic/helpers/memory_peak.py +5 -1
- onnx_diagnostic/helpers/mini_onnx_builder.py +9 -14
- onnx_diagnostic/helpers/model_builder_helper.py +1 -1
- onnx_diagnostic/helpers/onnx_helper.py +283 -110
- onnx_diagnostic/helpers/ort_session.py +5 -2
- onnx_diagnostic/helpers/rt_helper.py +53 -9
- onnx_diagnostic/helpers/torch_helper.py +15 -11
- onnx_diagnostic/investigate/__init__.py +0 -0
- onnx_diagnostic/investigate/input_observer.py +970 -0
- onnx_diagnostic/reference/evaluator.py +0 -1
- onnx_diagnostic/reference/ort_evaluator.py +0 -1
- onnx_diagnostic/reference/report_results_comparison.py +9 -3
- onnx_diagnostic/reference/torch_evaluator.py +5 -1
- onnx_diagnostic/reference/torch_ops/_op_run.py +3 -5
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +1 -1
- onnx_diagnostic/tasks/feature_extraction.py +0 -1
- onnx_diagnostic/torch_export_patches/__init__.py +0 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
- onnx_diagnostic/torch_export_patches/patch_module.py +1 -1
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +2 -2
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +70 -23
- onnx_diagnostic/torch_models/code_sample.py +5 -10
- onnx_diagnostic/torch_models/hghub/hub_data.py +2 -4
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +6 -12
- onnx_diagnostic/torch_models/validate.py +1 -1
- onnx_diagnostic/torch_onnx/compare.py +0 -1
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -1
- onnx_diagnostic/torch_onnx/sbs.py +1 -1
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +2 -4
- onnx_diagnostic/typing.py +15 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +55 -50
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
- onnx_diagnostic/api.py +0 -15
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.10.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/export/api.py
CHANGED

@@ -428,6 +428,16 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
                 new_kwargs[k] = v
             return new_kwargs
 
+    def is_empty_cache(self, cache):
+        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
+            if len(cache.layers) == 1 and cache.layers[0].keys is None:
+                return True
+            if len(cache.layers) == 0:
+                return True
+        if cache is None:
+            return True
+        return False
+
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args
@@ -443,6 +453,7 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
                     if v is not None
                     and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
                     and not isinstance(v, (bool, int, float))
+                    and not self.is_empty_cache(v)
                 }
             )
             inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
@@ -509,12 +520,10 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
         simple_sig = inspect.Signature(params, return_annotation=inspect._empty)
         args = str(simple_sig)[1:-1]
         calls_args = ", ".join(f"{p}={p}" for p in simple_sig.parameters)
-        src = textwrap.dedent(
-            f"""
+        src = textwrap.dedent(f"""
         def f(self, {args}):
             return self._method_call({calls_args})
-        """
-        )
+        """)
         self._method_src = src
         ns = {}
         try:
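The new is_empty_cache predicate lets the wrapper drop DynamicCache keyword arguments that carry no tensors before it captures inputs for export. A minimal sketch of the same check, using a stand-in class instead of the transformers cache (the None test is hoisted first here so the attribute accesses are obviously safe):

    from types import SimpleNamespace

    def is_empty_cache(cache) -> bool:
        # An absent cache, a DynamicCache with no layers, or a DynamicCache
        # whose single layer was never filled all count as "empty".
        if cache is None:
            return True
        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
            if len(cache.layers) == 0:
                return True
            if len(cache.layers) == 1 and cache.layers[0].keys is None:
                return True
        return False

    DynamicCache = type("DynamicCache", (), {})  # stand-in for transformers.cache_utils.DynamicCache
    cache = DynamicCache()
    cache.layers = [SimpleNamespace(keys=None, values=None)]
    assert is_empty_cache(cache) and is_empty_cache(None)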
onnx_diagnostic/export/dynamic_shapes.py
CHANGED

@@ -834,7 +834,7 @@ class ModelInputs:
         """Guesses the dynamic shapes for one argument."""
         if len(objs) == 0:
             return None
-        set_types = set(type(o) for o in objs)
+        set_types = set(type(o) for o in objs if o is not None)
         assert (
             len(set_types) == 1
         ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
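The filter makes dynamic-shape guessing tolerant of None placeholders among the sampled inputs: a None no longer adds NoneType to the type set, so the single-type assertion that follows still holds. For instance:

    import torch

    objs = [torch.ones((2, 3)), None, torch.ones((4, 3))]
    assert len(set(type(o) for o in objs)) == 2                    # old behavior: the assertion fired
    assert len(set(type(o) for o in objs if o is not None)) == 1   # new behavior: Tensor only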
onnx_diagnostic/export/validate.py
CHANGED

@@ -80,6 +80,7 @@ def compare_modules(
     )
     got = modep(*_get(args), **_get(kwargs))
     if verbose:
+        # pyrefly: ignore[unbound-name]
         d = time.perf_counter() - begin
         print(f"[compare_modules] done in {d} with output={string_type(got, with_shape=True)}")
     if mod:
@@ -89,6 +90,7 @@ def compare_modules(
         expected = mod(*_get(args), **_get(kwargs))
         diff = max_diff(expected, got)
         if verbose:
+            # pyrefly: ignore[unbound-name]
             d = time.perf_counter() - begin
             print(
                 f"[compare_modules] done in {d} with "
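Both ignore comments silence a pyrefly false positive rather than change behavior: begin is assigned under an earlier `if verbose:` guard, and the flagged reads sit under the same guard, so the name is always bound when reached. The pattern, reduced:

    import time

    verbose = True
    if verbose:
        begin = time.perf_counter()
    # ... run the module under comparison ...
    if verbose:
        # pyrefly: ignore[unbound-name]  (begin is bound whenever verbose is true)
        d = time.perf_counter() - begin
        print(f"[compare_modules] done in {d}")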
onnx_diagnostic/ext_test_case.py
CHANGED

@@ -780,7 +780,7 @@ class ExtTestCase(unittest.TestCase):
 
     @property
     def verbose(self) -> int:
-        "Returns the
+        "Returns the value of environment variable ``VERBOSE``."
         return int(os.environ.get("VERBOSE", "0"))
 
     @classmethod
@@ -1028,6 +1028,19 @@ class ExtTestCase(unittest.TestCase):
                 rtol=rtol,
                 msg=msg,
             )
+        elif expected.__class__.__name__ == "BaseModelOutputWithPooling":
+            if expected.__class__.__name__ == value.__class__.__name__:
+                self.assertEqual(len(expected), len(value), msg=msg)
+                self.assertEqual(list(expected), list(value), msg=msg)  # checks the order
+                self.assertEqualAny(
+                    {k: v for k, v in expected.items()},  # noqa: C416
+                    {k: v for k, v in value.items()},  # noqa: C416
+                    atol=atol,
+                    rtol=rtol,
+                    msg=msg,
+                )
+            else:
+                self.assertEqualArray(expected.last_hidden_state, value)
         elif isinstance(expected, (tuple, list, dict)):
             self.assertIsInstance(value, type(expected), msg=msg)
             self.assertEqual(len(expected), len(value), msg=msg)
@@ -1043,24 +1056,28 @@ class ExtTestCase(unittest.TestCase):
             "SlidingWindowCache",
             "HybridCache",
         ):
+            from .helpers.cache_helper import CacheKeyValue
+
             self.assertEqual(type(expected), type(value), msg=msg)
-
-            self.assertEqualAny(
-                {k: expected.__dict__.get(k, None) for k in atts},
-                {k: value.__dict__.get(k, None) for k in atts},
-                atol=atol,
-                rtol=rtol,
-            )
+            self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
         elif expected.__class__.__name__ == "StaticCache":
+            from .helpers.cache_helper import CacheKeyValue
+
             self.assertEqual(type(expected), type(value), msg=msg)
             self.assertEqual(expected.max_cache_len, value.max_cache_len)
-(… 7 deleted lines, content truncated in the source diff view …)
+            self.assertEqualAny(CacheKeyValue(expected), CacheKeyValue(value))
+        elif expected.__class__.__name__ == "CacheKeyValue":
+            self.assertEqual(type(expected), type(value), msg=msg)
+            if expected.cls_layers is None:
+                self.assertEqual(expected.cls_layers, value.cls_layers)
+            else:
+                self.assertEqualAny(
+                    [cls.__name__ for cls in expected.cls_layers],
+                    [cls.__name__ for cls in value.cls_layers],
+                    msg=msg,
+                )
+            self.assertEqualAny(expected.key_cache, value.key_cache, msg=msg)
+            self.assertEqualAny(expected.value_cache, value.value_cache, msg=msg)
         elif expected.__class__.__name__ == "EncoderDecoderCache":
            self.assertEqual(type(expected), type(value), msg=msg)
            atts = ["self_attention_cache", "cross_attention_cache"]
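Cache comparisons in assertEqualAny now funnel through CacheKeyValue, which normalizes every cache flavor into comparable key_cache/value_cache lists plus the layer classes, instead of poking at per-class attributes. A sketch of a test relying on the new path, assuming the cache helpers shown later in this diff:

    import torch
    from onnx_diagnostic.ext_test_case import ExtTestCase
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    class TestCacheComparison(ExtTestCase):
        def test_dynamic_cache_equality(self):
            pairs = [(torch.ones((1, 2, 4, 8)), torch.zeros((1, 2, 4, 8)))]
            expected = make_dynamic_cache(pairs)
            value = make_dynamic_cache(pairs)
            # assertEqualAny reduces both caches to CacheKeyValue and compares
            # key_cache / value_cache (and layer classes) element-wise.
            self.assertEqualAny(expected, value)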
onnx_diagnostic/helpers/cache_helper.py
CHANGED

@@ -4,6 +4,19 @@ import torch
 import transformers
 import transformers.cache_utils
 
+KWARGS_LAYER = {}
+if hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"):
+    KWARGS_LAYER.update(
+        {
+            transformers.cache_utils.DynamicSlidingWindowLayer: lambda tensor: {
+                "sliding_window": tensor.shape[2]
+            },
+            transformers.cache_utils.StaticSlidingWindowLayer: lambda tensor: {
+                "sliding_window": tensor.shape[2]
+            },
+        }
+    )
+
 
 class CacheKeyValue:
     """
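KWARGS_LAYER maps sliding-window layer classes to a function deriving missing constructor kwargs from a sample key tensor; for both registered classes the window size defaults to the sequence dimension, tensor.shape[2]. make_dynamic_cache (below) only applies these defaults for keys the caller left out of cls_kwargs. The derivation itself is just:

    import torch

    key = torch.ones((1, 2, 6, 8))  # (batch, heads, sequence, head_dim)
    derive_kwargs = lambda tensor: {"sliding_window": tensor.shape[2]}
    assert derive_kwargs(key) == {"sliding_window": 6}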
@@ -90,7 +103,7 @@ def flatten_unflatten_for_dynamic_shapes(
     the context gives the dictionary keys but it is not expressed
     in the dynamic shapes, these specifications seems to be different
     for the strict and non strict mode. It also preserves tuple.
-    :param change_function: to
+    :param change_function: to modify the tensor in the structure itself,
         like replace them by a shape
     :return: the serialized object
     """
@@ -110,7 +123,7 @@ def flatten_unflatten_for_dynamic_shapes(
         start = end
     if use_dict:
         if spec.type is dict:
-            # This a dictionary.
+            # This is a dictionary.
             return dict(zip(spec.context, subtrees))
         if spec.type is tuple:
             return tuple(subtrees)
@@ -185,6 +198,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
         cls_layers: Optional[Union[str, List[type]]] = None,
+        cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -194,6 +208,8 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         :param cls_layers: to select the appropriate class to use on each layer,
             if specified, sliding_window is ignored, it can be a string
             if all layers are expected to follow the same class
+        :param cls_kwargs: arguments used to build a specific layer,
+            such as ``sliding_window`` for ``DynamicSlidingWindowLayer``
         :return: :class:`transformers.cache_utils.DynamicCache`
 
         Example:
@@ -224,49 +240,70 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         are supported.
         """
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
-        cls_kwargs = {}
         if isinstance(cls_layers, str):
             assert hasattr(
                 transformers.cache_utils, cls_layers
-            ), f"
-(… 24 deleted lines, content truncated in the source diff view …)
+            ), f"Missing layer class {cls_layers!r}"
+            cls_layers = getattr(transformers.cache_utils, cls_layers)
+        if cls_layers and not isinstance(cls_layers, list):
+            cls_layers = [cls_layers for _ in key_value_pairs]  # type: ignore[misc]
+        if cls_layers is not None and isinstance(cls_layers, list):
+            assert len(cls_layers) == len(key_value_pairs), (
+                f"Length mismatch {len(key_value_pairs)} expected but "
+                f"{len(cls_layers)} layer types are given."
+            )
+            if cls_kwargs is None:
+                cls_kwargs = [{} for _kv in key_value_pairs]  # type: ignore[assignment]
+            assert len(cls_layers) == len(cls_kwargs), (
+                f"Length mismatch {len(cls_kwargs)} expected but "
+                f"{len(cls_layers)} layer types are given, "
+                f"cls_layers={cls_layers}, cls_kwargs={cls_kwargs}"
+            )
+            cls_layer = None
+            assert (
+                key_value_pairs and key_value_pairs[0]
+            ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}"
+            for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs):
+                default_values = KWARGS_LAYER.get(clsy, lambda tensor: {})(kv[0])
+                for k, v in default_values.items():
+                    if k not in kws:
+                        kws[k] = v  # type: ignore[index]
         else:
+            assert cls_kwargs is None, "cls_layers must be a list if cls_kwargs is specified"
+            assert (
+                cls_layers is None
+            ), f"cls_layers must be list or a string but it is {cls_layers}"
+            cls_kwargs = {}
             cls_layer = (
                 transformers.cache_utils.DynamicLayer
                 if hasattr(transformers.cache_utils, "DynamicLayer")
                 else None
             )
 
+        if cls_layer is not None:
+            assert isinstance(cls_kwargs, dict), (
+                f"one layer = one set of arguments, cls_layer={cls_layer}, "
+                f"cls_kwargs={cls_kwargs}"
+            )
+            cls_layers = [cls_layer for _ in key_value_pairs]
+            cls_kwargs = (
+                cls_kwargs  # type: ignore[assignment]
+                if isinstance(cls_kwargs, list)
+                else [cls_kwargs for _ in key_value_pairs]
+            )
+        elif cls_layers is not None:
+            assert isinstance(cls_layers, list), f"Unexpected type cls_layers={cls_layers}"
+            assert isinstance(cls_kwargs, list), f"Unexpected type cls_kwargs={cls_kwargs}"
+
         if (
             key_value_pairs
             and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
             and pv.Version(transformers.__version__) >= pv.Version("4.56")
         ):
             cache = transformers.cache_utils.DynamicCache()
-            cache.layers.extend(
+            cache.layers.extend(
+                [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)]  # type: ignore[operator, arg-type]
+            )
             for i, layer in enumerate(cache.layers):
                 k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                 layer.dtype = k.dtype
@@ -281,8 +318,25 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             return finalize_cache(cache)
 
         cache = transformers.cache_utils.DynamicCache()
-        if hasattr(cache, "layers") and
-(… 1 deleted line, content truncated in the source diff view …)
+        if hasattr(cache, "layers") and (
+            cls_layer is None or cls_layer != transformers.cache_utils.DynamicLayer
+        ):
+            assert isinstance(cls_layers, list) and isinstance(cls_kwargs, list), (
+                f"Wrong type {type(cls_layers)} for cls_layers or "
+                f"{type(cls_kwargs)} for cls_kwargs"
+            )
+            assert len(cls_kwargs) == len(cls_layers) and len(cls_kwargs) == len(
+                key_value_pairs
+            ), (
+                f"Length mismatch between len(cls_kwargs)={len(cls_kwargs)}, "
+                f"len(cls_layers)={len(cls_layers)}, "
+                f"len(key_value_pairs)={len(key_value_pairs)}, "
+                f"cls_kwargs={cls_kwargs}, cls_layers={cls_layers}"
+            )
+            del cache.layers[:]
+            cache.layers.extend(
+                [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)]  # type: ignore[operator, arg-type]
+            )
         for i, layer in enumerate(cache.layers):
             layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
             layer.is_initialized = True
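Taken together, the rewrite means callers can now request non-default layer classes with per-layer constructor arguments. A hedged usage sketch, assuming a transformers version that ships DynamicSlidingWindowLayer:

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    pairs = [(torch.ones((1, 2, 6, 8)), torch.zeros((1, 2, 6, 8)))]

    # A single string applies the same layer class to every layer; cls_kwargs
    # overrides the sliding_window default KWARGS_LAYER would derive (6 here).
    cache = make_dynamic_cache(
        pairs,
        cls_layers="DynamicSlidingWindowLayer",
        cls_kwargs=[{"sliding_window": 4}],
    )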
@@ -306,6 +360,7 @@ else:
     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
         cls_layers: Optional[Union[str, List[type]]] = None,
+        cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -337,7 +392,9 @@ else:
         )
         print(string_type(past_key_values, with_shape=True))
         """
-        assert
+        assert (
+            not cls_layers and not cls_kwargs
+        ), "cls_layers, cls_kwargs cannot be used for transformers<5."
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
         cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
         for i, (key, value) in enumerate(key_value_pairs):
@@ -348,6 +405,7 @@
 def make_static_cache(
     key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
     max_cache_len: Optional[int] = None,
+    cls_layers: Optional[Union[str, List[type]]] = None,
 ) -> transformers.cache_utils.DynamicCache:
     """
     Creates an instance of :class:`transformers.cache_utils.StaticCache`.
@@ -379,6 +437,9 @@ def make_static_cache(
     )
     print(string_type(past_key_values, with_shape=True))
     """
+    assert not cls_layers or set(cls_layers) == {
+        transformers.cache_utils.StaticLayer
+    }, f"Not implemented when cls_layers={cls_layers!r}"
     key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
 
     class _config:
@@ -583,13 +644,9 @@ if hasattr(transformers.cache_utils, "SlidingWindowCache"):
         )
         return finalize_cache(cache)
 
-    def get_make_hybrid_cache():
-        return make_sliding_window_cache
-
 else:
     make_sliding_window_cache = None  # type: ignore[assignment]
 
-
 if hasattr(transformers.cache_utils, "HybridCache"):
 
     def make_hybrid_cache(
@@ -775,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_
         # This is used to expand the cache when it does not contains enough layers.
         # This is needed since transformers>4.55.3
         cache.layer_class_to_replicate = cache.layers[0].__class__
+    assert (
+        not hasattr(cache, "layers")
+        or len(cache.layers) != 1
+        or cache.layers[0].keys is not None
+    ), (
+        f"Size mismatch between {len(cache.layers)=}, "
+        f"first key={cache.layers[0].keys}, "  # type: ignore[attr-defined]
+        f"first value={cache.layers[0].values}"  # type: ignore[attr-defined]
+    )
     return cache
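The assertion appended to finalize_cache rejects the same degenerate state that is_empty_cache in export/api.py filters on the input side: a cache left with exactly one layer whose keys were never set. In other words, after finalization this must hold:

    def well_formed(cache) -> bool:
        # The invariant finalize_cache now enforces before returning.
        return (
            not hasattr(cache, "layers")
            or len(cache.layers) != 1
            or cache.layers[0].keys is not None
        )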
onnx_diagnostic/helpers/doc_helper.py
CHANGED

@@ -1,5 +1,5 @@
 import os
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple
 import onnx
 import onnx.helper as oh
 import torch
@@ -46,10 +46,10 @@ class LayerNormalizationOrt(OpRunKernel):
             f"This kernel implementation only work when only one output "
             f"is required but {node.output} were."
         )
-        self._cache: Dict[Tuple[int, int],
+        self._cache: Dict[Tuple[int, int], Any] = {}
         self.is_cpu = torch.device("cpu") == self.device
 
-    def _make_model(self, itype: int, rank: int, has_bias: bool) ->
+    def _make_model(self, itype: int, rank: int, has_bias: bool) -> Any:
         shape = [*["d{i}" for i in range(rank - 1)], "last"]
         layer_model = oh.make_model(
             oh.make_graph(
@@ -88,6 +88,7 @@ class LayerNormalizationOrt(OpRunKernel):
             providers=[provider],
         )
 
+    # pyrefly: ignore[bad-override]
     def run(self, x, scale, bias=None):
         itype = torch_dtype_to_onnx_dtype(x.dtype)
         rank = len(x.shape)
@@ -124,7 +125,7 @@ class MatMulOrt(OpRunKernel):
         self._cache: Dict[Tuple[int, int, int], onnx.ModelProto] = {}
         self.is_cpu = torch.device("cpu") == self.device
 
-    def _make_model(self, itype: int, ranka: int, rankb: int) ->
+    def _make_model(self, itype: int, ranka: int, rankb: int) -> Any:
         shapea = ["a{i}" for i in range(ranka)]
         shapeb = ["b{i}" for i in range(rankb)]
         shapec = ["c{i}" for i in range(max(ranka, rankb))]
@@ -149,6 +150,7 @@ class MatMulOrt(OpRunKernel):
             providers=[provider],
         )
 
+    # pyrefly: ignore[bad-override]
     def run(self, a, b):
         itype = torch_dtype_to_onnx_dtype(a.dtype)
         ranka, rankb = len(a.shape), len(b.shape)
@@ -159,5 +161,6 @@ class MatMulOrt(OpRunKernel):
         if self.verbose:
             print(f"[MatMulOrt] running on {self._provider!r}")
         feeds = dict(A=a.tensor, B=b.tensor)
+        # pyrefly: ignore[missing-attribute]
         got = sess.run(None, feeds)[0]
         return OpRunTensor(got)
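Typing changes aside, both kernels keep the memoization scheme the new annotations describe: one small ONNX model per element-type/rank combination, built lazily and reused on later calls. The shape of the pattern, with a placeholder builder standing in for the oh.make_model call:

    from typing import Any, Dict, Tuple

    class KernelModelCache:
        def __init__(self) -> None:
            # Keyed by (onnx element type, tensor rank), as in LayerNormalizationOrt.
            self._cache: Dict[Tuple[int, int], Any] = {}

        def get(self, itype: int, rank: int) -> Any:
            key = (itype, rank)
            if key not in self._cache:
                self._cache[key] = f"model(itype={itype}, rank={rank})"  # placeholder
            return self._cache[key]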
onnx_diagnostic/helpers/graph_helper.py
CHANGED

@@ -36,7 +36,7 @@ class GraphRendering:
         :return: computation order
         """
         assert not ({"If", "Scan", "Loop", "SequenceMap"} & set(n.op_type for n in nodes)), (
-            f"This
+            f"This algorithm is not yet implemented if the sequence contains "
             f"a control flow, types={sorted(set(n.op_type for n in nodes))}"
         )
         number = {e: start - 1 for e in (existing or [])}  # noqa: C420
@@ -131,14 +131,14 @@ class GraphRendering:
     @property
     def nodes(self) -> List[onnx.NodeProto]:
         "Returns the list of nodes"
-        return (
+        return list(
             self.proto.graph.node
             if isinstance(self.proto, onnx.ModelProto)
             else self.proto.node
         )
 
     @property
-    def start_names(self) -> List[
+    def start_names(self) -> List[str]:
         "Returns the list of known names, inputs and initializer"
         graph = self.proto.graph if isinstance(self.proto, onnx.ModelProto) else self.proto
         input_names = (
@@ -151,7 +151,7 @@ class GraphRendering:
             if isinstance(graph, onnx.FunctionProto)
             else [
                 *[i.name for i in graph.initializer],
-                *[i.name for i in graph.sparse_initializer],
+                *[i.values.name for i in graph.sparse_initializer],
             ]
         )
         return [*input_names, *init_names]
@@ -159,7 +159,7 @@ class GraphRendering:
     @property
     def input_names(self) -> List[str]:
         "Returns the list of input names."
-        return (
+        return list(
             self.proto.input
             if isinstance(self.proto, onnx.FunctionProto)
             else [
@@ -173,7 +173,7 @@ class GraphRendering:
     @property
     def output_names(self) -> List[str]:
         "Returns the list of output names."
-        return (
+        return list(
             self.proto.output
             if isinstance(self.proto, onnx.FunctionProto)
             else [
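The start_names fix reflects how ONNX stores sparse initializers: SparseTensorProto has no name field of its own, the name lives on its values tensor, so the old `i.name` access failed on any graph that actually used one. A quick check:

    import onnx.helper as oh
    from onnx import TensorProto

    values = oh.make_tensor("sparse_w", TensorProto.FLOAT, dims=[2], vals=[1.0, 2.0])
    indices = oh.make_tensor("", TensorProto.INT64, dims=[2], vals=[0, 3])
    sparse = oh.make_sparse_tensor(values, indices, dims=[5])

    assert not hasattr(sparse, "name")       # SparseTensorProto carries no name itself
    assert sparse.values.name == "sparse_w"  # the fixed property reads it here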
onnx_diagnostic/helpers/helper.py
CHANGED

@@ -574,6 +574,32 @@ def string_type(
         print(f"[string_type] CACHE1:{type(obj)}")
         return f"MambaCache(conv_states={c}, ssm_states={d})"
 
+    if (
+        obj.__class__.__name__ in {"DynamicCache"}
+        and hasattr(obj, "layers")
+        and any(lay.__class__.__name__ != "DynamicLayer" for lay in obj.layers)
+    ):
+        slay = []
+        for lay in obj.layers:
+            skeys = string_type(
+                lay.keys,
+                with_shape=with_shape,
+                with_min_max=with_min_max,
+                with_device=with_device,
+                limit=limit,
+                verbose=verbose,
+            )
+            svalues = string_type(
+                lay.keys,
+                with_shape=with_shape,
+                with_min_max=with_min_max,
+                with_device=with_device,
+                limit=limit,
+                verbose=verbose,
+            )
+            slay.append(f"{lay.__class__.__name__}({skeys}, {svalues})")
+        return f"{obj.__class__.__name__}({', '.join(slay)})"
+
     if obj.__class__.__name__ in {
         "DynamicCache",
         "SlidingWindowCache",
@@ -829,6 +855,19 @@ def string_type(
         return f"{obj}"
     if obj.__class__.__name__ == "FakeTensorContext":
         return "FakeTensorContext(...)"
+    if obj.__class__.__name__ == "Chat":
+        import transformers.utils.chat_template_utils as ctu
+
+        assert isinstance(obj, ctu.Chat), f"unexpected type {type(obj)}"
+        msg = string_type(
+            obj.messages,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"Chat({msg})"
 
     if verbose:
         print(f"[string_type] END:{type(obj)}")
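Both string_type additions make debugging output self-describing: a DynamicCache mixing layer classes prints one "LayerClass(keys, values)" entry per layer, and transformers Chat objects print their messages. (Note that the layer branch above computes svalues from lay.keys as well, so the printed values currently mirror the keys.) A hedged usage sketch, assuming the usual import:

    import torch
    from onnx_diagnostic.helpers import string_type
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.ones((1, 2, 6, 8)), torch.zeros((1, 2, 6, 8)))])
    # Prints something like "DynamicCache(...)" with one shaped entry per layer;
    # with non-default layers the new branch labels each layer by its class.
    print(string_type(cache, with_shape=True))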