onnx-diagnostic 0.8.11__py3-none-any.whl → 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
- onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
- onnx_diagnostic/ci_models/data/__init__.py +0 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +8 -3
- onnx_diagnostic/export/api.py +11 -0
- onnx_diagnostic/export/dynamic_shapes.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +96 -30
- onnx_diagnostic/helpers/helper.py +39 -0
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +5 -1
- onnx_diagnostic/helpers/rt_helper.py +53 -9
- onnx_diagnostic/helpers/torch_helper.py +7 -2
- onnx_diagnostic/investigate/input_observer.py +793 -152
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +28 -2
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +24 -21
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED

onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg
Binary file

onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg
Binary file

onnx_diagnostic/ci_models/data/__init__.py
File without changes

onnx_diagnostic/ci_models/export_phi4_mm.py
CHANGED
@@ -668,12 +668,17 @@ def get_inputs_for_part(
         f"{user_prompt}<|image_1|>\n<|image_2|>\n<|image_3|>\n<|image_4|>\n"
         f"What is shown in these four images?{prompt_suffix}{assistant_prompt}"
     )
-    …
-    …
+    image_2_path = os.path.join(
+        os.path.dirname(__file__), "data", "Blanca_Lake_Hudak.jpg"
+    )
+    image_2 = Image.open(image_2_path)
     url = (
         "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain"
     )
-    …
+    image_3_path = os.path.join(
+        os.path.dirname(__file__), "data", "Ice_worm_glacier.jpg"
+    )
+    image_3 = Image.open(image_3_path)

     images = [image_1, image_2, image_3, image_4]
     inputs = processor(prompt, images=images, return_tensors="pt").to(device)

(Removed lines shown as "…" had their contents elided in the rendered diff; the same convention is used below.)

onnx_diagnostic/export/api.py
CHANGED

@@ -428,6 +428,16 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
             new_kwargs[k] = v
         return new_kwargs

+    def is_empty_cache(self, cache):
+        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
+            if len(cache.layers) == 1 and cache.layers[0].keys is None:
+                return True
+            if len(cache.layers) == 0:
+                return True
+        if cache is None:
+            return True
+        return False
+
     def forward(self, *args, **kwargs):
         if not self._export_done:
             inp_args = args

@@ -443,6 +453,7 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
                     if v is not None
                     and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
                     and not isinstance(v, (bool, int, float))
+                    and not self.is_empty_cache(v)
                 }
             )
             inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
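
The new `is_empty_cache` check lets the wrapper drop uninitialized `DynamicCache` kwargs before capturing inputs for export. A standalone sketch of the same logic (illustrative, not part of the diff; it assumes a transformers version whose `DynamicCache` exposes a `layers` list of lazily initialized layers):

    import transformers.cache_utils as cu

    def is_empty_cache(cache) -> bool:
        # Same checks as the method added above, outside the wrapper class.
        if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
            if len(cache.layers) == 1 and cache.layers[0].keys is None:
                return True
            if len(cache.layers) == 0:
                return True
        if cache is None:  # None.__class__ is NoneType, so this point is safe to reach
            return True
        return False

    print(is_empty_cache(None))               # True
    print(is_empty_cache(cu.DynamicCache()))  # True while no layer holds keys/values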
onnx_diagnostic/export/dynamic_shapes.py
CHANGED

@@ -834,7 +834,7 @@ class ModelInputs:
         """Guesses the dynamic shapes for one argument."""
         if len(objs) == 0:
             return None
-        set_types = set(type(o) for o in objs)
+        set_types = set(type(o) for o in objs if o is not None)
         assert (
             len(set_types) == 1
         ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
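
In effect, the type-guessing step now tolerates optional inputs: a `None` among the observed values no longer lands in the type set and no longer trips the uniqueness assert. A tiny illustration (not from the diff):

    import torch

    objs = [None, torch.ones(2, 3), torch.zeros(4, 3)]
    set_types = set(type(o) for o in objs if o is not None)
    assert set_types == {torch.Tensor}  # previously {NoneType, Tensor} -> AssertionError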
onnx_diagnostic/helpers/cache_helper.py
CHANGED

@@ -4,6 +4,19 @@ import torch
 import transformers
 import transformers.cache_utils

+KWARGS_LAYER = {}
+if hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"):
+    KWARGS_LAYER.update(
+        {
+            transformers.cache_utils.DynamicSlidingWindowLayer: lambda tensor: {
+                "sliding_window": tensor.shape[2]
+            },
+            transformers.cache_utils.StaticSlidingWindowLayer: lambda tensor: {
+                "sliding_window": tensor.shape[2]
+            },
+        }
+    )
+

 class CacheKeyValue:
     """
@@ -185,6 +198,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
         cls_layers: Optional[Union[str, List[type]]] = None,
+        cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -194,6 +208,8 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         :param cls_layers: to select the appropriate class to use on each layer,
             if specified, sliding_window is ignored, it can be a string
             if all layers are expected to follow the same class
+        :param cls_kwargs: arguments used to build a specific layer,
+            such as ``sliding_window`` for ``DynamicSlidingWindowLayer``
         :return: :class:`transformers.cache_utils.DynamicCache`

         Example:
@@ -224,49 +240,70 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
         are supported.
         """
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
-        cls_kwargs = {}
         if isinstance(cls_layers, str):
             assert hasattr(
                 transformers.cache_utils, cls_layers
-            ), f"…
-            … (24 more removed lines, contents elided in the rendered diff)
+            ), f"Missing layer class {cls_layers!r}"
+            cls_layers = getattr(transformers.cache_utils, cls_layers)
+        if cls_layers and not isinstance(cls_layers, list):
+            cls_layers = [cls_layers for _ in key_value_pairs]  # type: ignore[misc]
+        if cls_layers is not None and isinstance(cls_layers, list):
+            assert len(cls_layers) == len(key_value_pairs), (
+                f"Length mismatch {len(key_value_pairs)} expected but "
+                f"{len(cls_layers)} layer types are given."
+            )
+            if cls_kwargs is None:
+                cls_kwargs = [{} for _kv in key_value_pairs]  # type: ignore[assignment]
+            assert len(cls_layers) == len(cls_kwargs), (
+                f"Length mismatch {len(cls_kwargs)} expected but "
+                f"{len(cls_layers)} layer types are given, "
+                f"cls_layers={cls_layers}, cls_kwargs={cls_kwargs}"
+            )
+            cls_layer = None
+            assert (
+                key_value_pairs and key_value_pairs[0]
+            ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}"
+            for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs):
+                default_values = KWARGS_LAYER.get(clsy, lambda tensor: {})(kv[0])
+                for k, v in default_values.items():
+                    if k not in kws:
+                        kws[k] = v  # type: ignore[index]
         else:
+            assert cls_kwargs is None, "cls_layers must be a list if cls_kwargs is specified"
+            assert (
+                cls_layers is None
+            ), f"cls_layers must be list or a string but it is {cls_layers}"
+            cls_kwargs = {}
             cls_layer = (
                 transformers.cache_utils.DynamicLayer
                 if hasattr(transformers.cache_utils, "DynamicLayer")
                 else None
             )

+        if cls_layer is not None:
+            assert isinstance(cls_kwargs, dict), (
+                f"one layer = one set of arguments, cls_layer={cls_layer}, "
+                f"cls_kwargs={cls_kwargs}"
+            )
+            cls_layers = [cls_layer for _ in key_value_pairs]
+            cls_kwargs = (
+                cls_kwargs  # type: ignore[assignment]
+                if isinstance(cls_kwargs, list)
+                else [cls_kwargs for _ in key_value_pairs]
+            )
+        elif cls_layers is not None:
+            assert isinstance(cls_layers, list), f"Unexpected type cls_layers={cls_layers}"
+            assert isinstance(cls_kwargs, list), f"Unexpected type cls_kwargs={cls_kwargs}"
+
         if (
             key_value_pairs
             and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
             and pv.Version(transformers.__version__) >= pv.Version("4.56")
         ):
             cache = transformers.cache_utils.DynamicCache()
-            cache.layers.extend(…
+            cache.layers.extend(
+                [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)]  # type: ignore[operator, arg-type]
+            )
             for i, layer in enumerate(cache.layers):
                 k, v = key_value_pairs[i][0], key_value_pairs[i][1]
                 layer.dtype = k.dtype
@@ -281,8 +318,25 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
             return finalize_cache(cache)

         cache = transformers.cache_utils.DynamicCache()
-        if hasattr(cache, "layers") and …
-        …
+        if hasattr(cache, "layers") and (
+            cls_layer is None or cls_layer != transformers.cache_utils.DynamicLayer
+        ):
+            assert isinstance(cls_layers, list) and isinstance(cls_kwargs, list), (
+                f"Wrong type {type(cls_layers)} for cls_layers or "
+                f"{type(cls_kwargs)} for cls_kwargs"
+            )
+            assert len(cls_kwargs) == len(cls_layers) and len(cls_kwargs) == len(
+                key_value_pairs
+            ), (
+                f"Length mismatch between len(cls_kwargs)={len(cls_kwargs)}, "
+                f"len(cls_layers)={len(cls_layers)}, "
+                f"len(key_value_pairs)={len(key_value_pairs)}, "
+                f"cls_kwargs={cls_kwargs}, cls_layers={cls_layers}"
+            )
+            del cache.layers[:]
+            cache.layers.extend(
+                [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)]  # type: ignore[operator, arg-type]
+            )
         for i, layer in enumerate(cache.layers):
             layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
             layer.is_initialized = True
@@ -306,6 +360,7 @@ else:
     def make_dynamic_cache(
         key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
         cls_layers: Optional[Union[str, List[type]]] = None,
+        cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
     ) -> transformers.cache_utils.DynamicCache:
         """
         Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -337,7 +392,9 @@ else:
         )
         print(string_type(past_key_values, with_shape=True))
         """
-        assert …
+        assert (
+            not cls_layers and not cls_kwargs
+        ), "cls_layers, cls_kwargs cannot be used for transformers<5."
         key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
         cache = transformers.cache_utils.DynamicCache(len(key_value_pairs))  # type: ignore
         for i, (key, value) in enumerate(key_value_pairs):
@@ -775,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_utils.Cache:
         # This is used to expand the cache when it does not contains enough layers.
         # This is needed since transformers>4.55.3
         cache.layer_class_to_replicate = cache.layers[0].__class__
+    assert (
+        not hasattr(cache, "layers")
+        or len(cache.layers) != 1
+        or cache.layers[0].keys is not None
+    ), (
+        f"Size mismatch between {len(cache.layers)=}, "
+        f"first key={cache.layers[0].keys}, "  # type: ignore[attr-defined]
+        f"first value={cache.layers[0].values}"  # type: ignore[attr-defined]
+    )
     return cache
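
Taken together, these cache_helper.py changes let a caller choose the layer class and per-layer constructor arguments when building a cache. A hedged usage sketch (the function and parameter names come from the diff; the tensor shapes are arbitrary, and `DynamicSlidingWindowLayer` only exists in recent transformers versions):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    pairs = [
        (torch.randn(1, 8, 6, 64), torch.randn(1, 8, 6, 64)),
        (torch.randn(1, 8, 6, 64), torch.randn(1, 8, 6, 64)),
    ]
    # Explicit per-layer kwargs, one dict per (key, value) pair:
    cache = make_dynamic_cache(
        pairs,
        cls_layers="DynamicSlidingWindowLayer",
        cls_kwargs=[{"sliding_window": 6}, {"sliding_window": 6}],
    )
    # Or rely on the new KWARGS_LAYER table, which defaults sliding_window
    # to the key tensor's shape[2]:
    cache = make_dynamic_cache(pairs, cls_layers="DynamicSlidingWindowLayer")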
onnx_diagnostic/helpers/helper.py
CHANGED

@@ -574,6 +574,32 @@ def string_type(
         print(f"[string_type] CACHE1:{type(obj)}")
         return f"MambaCache(conv_states={c}, ssm_states={d})"

+    if (
+        obj.__class__.__name__ in {"DynamicCache"}
+        and hasattr(obj, "layers")
+        and any(lay.__class__.__name__ != "DynamicLayer" for lay in obj.layers)
+    ):
+        slay = []
+        for lay in obj.layers:
+            skeys = string_type(
+                lay.keys,
+                with_shape=with_shape,
+                with_min_max=with_min_max,
+                with_device=with_device,
+                limit=limit,
+                verbose=verbose,
+            )
+            svalues = string_type(
+                lay.keys,
+                with_shape=with_shape,
+                with_min_max=with_min_max,
+                with_device=with_device,
+                limit=limit,
+                verbose=verbose,
+            )
+            slay.append(f"{lay.__class__.__name__}({skeys}, {svalues})")
+        return f"{obj.__class__.__name__}({', '.join(slay)})"
+
     if obj.__class__.__name__ in {
         "DynamicCache",
         "SlidingWindowCache",
@@ -829,6 +855,19 @@ def string_type(
         return f"{obj}"
     if obj.__class__.__name__ == "FakeTensorContext":
         return "FakeTensorContext(...)"
+    if obj.__class__.__name__ == "Chat":
+        import transformers.utils.chat_template_utils as ctu
+
+        assert isinstance(obj, ctu.Chat), f"unexpected type {type(obj)}"
+        msg = string_type(
+            obj.messages,
+            with_shape=with_shape,
+            with_min_max=with_min_max,
+            with_device=with_device,
+            limit=limit,
+            verbose=verbose,
+        )
+        return f"Chat({msg})"

     if verbose:
         print(f"[string_type] END:{type(obj)}")
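
A quick sketch of the new rendering path for caches whose layers are not plain `DynamicLayer` (illustrative; the import paths follow this package's layout, and the exact tensor notation in the output may differ):

    import torch
    from onnx_diagnostic.helpers.helper import string_type
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache(
        [(torch.randn(1, 8, 6, 64), torch.randn(1, 8, 6, 64))],
        cls_layers="DynamicSlidingWindowLayer",
    )
    # Each layer is now printed with its class name, e.g.
    # DynamicCache(DynamicSlidingWindowLayer(T1s1x8x6x64, T1s1x8x6x64))
    print(string_type(cache, with_shape=True))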
onnx_diagnostic/helpers/onnx_helper.py
CHANGED

@@ -1742,7 +1742,7 @@ def _find_used_names(node_list, node_indices):
         possible_outputs |= {o for o in node_list[i_node].output if o}
     # find all requires input from the other nodes
     set_indices = set(node_indices)
-    not_known…
+    not_known = set()
     ranges = list(range(len(node_list)))
     for i_node in ranges[::-1]:
         if i_node in set_indices:
onnx_diagnostic/helpers/ort_session.py
CHANGED

@@ -6,7 +6,7 @@ import torch
 from torch._C import _from_dlpack
 import onnxruntime
 from onnxruntime.capi import _pybind_state as ORTC
-from .helper import size_type
+from .helper import size_type, string_type
 from .onnx_helper import (
     onnx_dtype_to_np_dtype,
     np_dtype_to_tensor_dtype,
@@ -511,6 +511,10 @@ class InferenceSessionForTorch(_InferenceSession):
         device = -1
         for k, v in feeds.items():
             assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
+            assert hasattr(v, "device"), (
+                f"Unepxected class {type(v)} for input {k!r}, "
+                f"feeds={string_type(feeds, with_shape=True)}"
+            )
             device = max(device, v.get_device())
             assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
             if not v.is_contiguous():
onnx_diagnostic/helpers/rt_helper.py
CHANGED

@@ -115,7 +115,7 @@ def make_feeds(
 def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
     if isinstance(s, int):
         return s
-    if s == "batch":
+    if s == "batch" or i == 0:
         return batch
     # Everything else is cache length or sequence length.
     return 0
@@ -153,9 +153,13 @@ def make_empty_cache(
         [i.type for i in sess.get_inputs()[2:]],
     )
     """
+    assert batch > 0, f"batch size = {batch} must be positive"
     feeds = {}
     for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
         new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
+        assert (
+            new_shape and new_shape[0] > 0
+        ), f"new_shape={new_shape} cannot have a null batch size, name={name!r}, shape={shape}"
         feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
     return feeds

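
The `_get_dim` fix above means axis 0 is always treated as the batch axis, even when the exported model names it something other than "batch". A self-contained illustration (mirrors the patched function; the shape tuple is made up):

    def _get_dim(i, s, batch=1):
        if isinstance(s, int):
            return s
        if s == "batch" or i == 0:
            return batch
        return 0  # cache/sequence length starts empty

    shape = ("batch_size", 8, "past_sequence_length", 64)
    print(tuple(_get_dim(i, s, batch=2) for i, s in enumerate(shape)))  # (2, 8, 0, 64)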
@@ -272,6 +276,7 @@ def generate_and_validate(
 def onnx_generate(
     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
     input_ids: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
     eos_token_id: int = 2,
     max_new_tokens=100,
     return_session: bool = False,
@@ -330,7 +335,9 @@ def onnx_generate(
     )

     print("-- generate with onnx")
-    onnx_outputs = onnx_generate(…
+    onnx_outputs = onnx_generate(
+        model_name, input_ids[:1], eos_token_id=2, max_new_tokens=10
+    )
     print("-- onnx output", onnx_outputs)

     # The example continues with other functions doing the same.
@@ -364,6 +371,7 @@ def onnx_generate(
     input_names = session.input_names
     input_types = session.input_types
     has_position_ids = "position_ids" in session.input_names
+    has_cache_position = "cache_position" in session.input_names

     assert (
         len(input_names) > 2
@@ -377,21 +385,46 @@ def onnx_generate(
         not has_position_ids or input_names[2] == "position_ids"
     ), f"position_ids must the third input but input_names={input_names}"

+    cache_names, cache_shapes, cache_types = [], [], []
+    for name, shape, dt in zip(input_names, input_shapes, input_types):
+        if name.startswith("past_key_values"):
+            cache_names.append(name)
+            cache_shapes.append(shape)
+            cache_types.append(dt)
+
     # First call: prefill
+    empty_cache = make_empty_cache(input_ids.shape[0], cache_names, cache_shapes, cache_types)
     feeds = dict(
         input_ids=input_ids,
-        attention_mask=…
-        … (3 more removed lines, contents elided in the rendered diff)
-        input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
+        attention_mask=(
+            attention_mask
+            if attention_mask is not None
+            else torch.ones(input_ids.shape, dtype=input_ids.dtype, device=input_ids.device)
         ),
+        **empty_cache,
     )
+
     if has_position_ids:
-        …
+        assert (
+            input_ids.shape[1] > 0
+        ), f"unexpected value for input_ids shape={input_ids.shape}"
+        position_ids = torch.unsqueeze(
             torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
         )
+        feeds["position_ids"] = position_ids
+
+    if has_cache_position:
+        assert empty_cache, "no cache means no cache_position"
+        first_tensor = next(iter(empty_cache.values()))
+        cache_position = torch.arange(
+            first_tensor.shape[2],
+            input_ids.shape[1] + first_tensor.shape[2],
+            dtype=torch.int64,
+            device=input_ids.device,
+        )
+        feeds["cache_position"] = cache_position

+    # prefill step
     outputs = session.run(None, feeds)

     # Next calls: decode
@@ -424,7 +457,18 @@ def onnx_generate(
             ),
             0,
         )
-        …
+        if has_cache_position:
+            feeds["cache_position"] = torch.arange(
+                input_ids.shape[1],
+                input_ids.shape[1] + 1,
+                dtype=torch.int64,
+                device=input_ids.device,
+            )
+
+        feeds.update(
+            dict(zip([n for n in input_names if n.startswith("past_key_values")], outputs[1:]))
+        )
+        # generate/decoding step
         outputs = session.run(None, feeds)

     if return_session:
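
A hedged sketch of calling the extended `onnx_generate` (adapted from the docstring example visible in the diff; "model.onnx" is a placeholder for a decoder-style model exported with `past_key_values` inputs, and the token ids are made up):

    import torch
    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    input_ids = torch.tensor([[1, 5, 9, 3]], dtype=torch.int64)
    attention_mask = torch.ones_like(input_ids)  # new optional argument in 0.9.0
    outputs = onnx_generate(
        "model.onnx",
        input_ids,
        attention_mask=attention_mask,
        eos_token_id=2,
        max_new_tokens=10,
    )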
onnx_diagnostic/helpers/torch_helper.py
CHANGED

@@ -851,9 +851,14 @@ def torch_deepcopy(value: Any) -> Any:
         from .cache_helper import CacheKeyValue

         ca = CacheKeyValue(value)
-        …
-        …
+        pairs = list(zip(ca.key_cache, ca.value_cache))
+        assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
+            f"Size mismatch between {len(value.layers)=} and {len(pairs)=}. "
+            f"value={string_type(value, with_shape=True)}, "
+            f"first key={value.layers[0].keys}, "
+            f"first value={value.layers[0].values}"
         )
+        return make_dynamic_cache(torch_deepcopy(pairs), cls_layers=ca.cls_layers)
     if value.__class__.__name__ == "StaticCache":
         from .cache_helper import CacheKeyValue

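A small sketch of the invariant the new assertion protects: rebuilding a `DynamicCache` from its `(key, value)` pairs must preserve the layer count (illustrative; import paths follow this package's layout):

    import torch
    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy
    from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache

    cache = make_dynamic_cache([(torch.randn(1, 4, 3, 8), torch.randn(1, 4, 3, 8))])
    copy = torch_deepcopy(cache)
    if hasattr(cache, "layers"):  # transformers versions with per-layer caches
        assert len(copy.layers) == len(cache.layers)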