onnx-diagnostic 0.8.11__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/ci_models/data/Blanca_Lake_Hudak.jpg +0 -0
  3. onnx_diagnostic/ci_models/data/Ice_worm_glacier.jpg +0 -0
  4. onnx_diagnostic/ci_models/data/__init__.py +0 -0
  5. onnx_diagnostic/ci_models/export_phi4_mm.py +8 -3
  6. onnx_diagnostic/export/api.py +11 -0
  7. onnx_diagnostic/export/dynamic_shapes.py +1 -1
  8. onnx_diagnostic/helpers/cache_helper.py +96 -30
  9. onnx_diagnostic/helpers/helper.py +39 -0
  10. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  11. onnx_diagnostic/helpers/ort_session.py +5 -1
  12. onnx_diagnostic/helpers/rt_helper.py +53 -9
  13. onnx_diagnostic/helpers/torch_helper.py +7 -2
  14. onnx_diagnostic/investigate/input_observer.py +793 -152
  15. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +32 -14
  16. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +107 -6
  17. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +13 -3
  18. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +1 -0
  19. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +28 -2
  20. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/METADATA +2 -2
  21. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/RECORD +24 -21
  22. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/WHEEL +1 -1
  23. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  24. {onnx_diagnostic-0.8.11.dist-info → onnx_diagnostic-0.9.0.dist-info}/top_level.txt +0 -0
@@ -3,5 +3,5 @@ Patches, Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.8.11"
6
+ __version__ = "0.9.0"
7
7
  __author__ = "Xavier Dupré"
File without changes
@@ -668,12 +668,17 @@ def get_inputs_for_part(
668
668
  f"{user_prompt}<|image_1|>\n<|image_2|>\n<|image_3|>\n<|image_4|>\n"
669
669
  f"What is shown in these four images?{prompt_suffix}{assistant_prompt}"
670
670
  )
671
- url = "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000"
672
- image_2 = Image.open(requests.get(url, stream=True).raw)
671
+ image_2_path = os.path.join(
672
+ os.path.dirname(__file__), "data", "Blanca_Lake_Hudak.jpg"
673
+ )
674
+ image_2 = Image.open(image_2_path)
673
675
  url = (
674
676
  "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain"
675
677
  )
676
- image_3 = Image.open(requests.get(url, stream=True).raw)
678
+ image_3_path = os.path.join(
679
+ os.path.dirname(__file__), "data", "Ice_worm_glacier.jpg"
680
+ )
681
+ image_3 = Image.open(image_3_path)
677
682
 
678
683
  images = [image_1, image_2, image_3, image_4]
679
684
  inputs = processor(prompt, images=images, return_tensors="pt").to(device)
@@ -428,6 +428,16 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
428
428
  new_kwargs[k] = v
429
429
  return new_kwargs
430
430
 
431
+ def is_empty_cache(self, cache):
432
+ if cache.__class__.__name__ == "DynamicCache" and hasattr(cache, "layers"):
433
+ if len(cache.layers) == 1 and cache.layers[0].keys is None:
434
+ return True
435
+ if len(cache.layers) == 0:
436
+ return True
437
+ if cache is None:
438
+ return True
439
+ return False
440
+
431
441
  def forward(self, *args, **kwargs):
432
442
  if not self._export_done:
433
443
  inp_args = args
@@ -443,6 +453,7 @@ class WrapperToExportMethodToOnnx(torch.nn.Module):
443
453
  if v is not None
444
454
  and (not self.skip_kwargs_names or k not in self.skip_kwargs_names)
445
455
  and not isinstance(v, (bool, int, float))
456
+ and not self.is_empty_cache(v)
446
457
  }
447
458
  )
448
459
  inp_args, inp_kwargs = torch_deepcopy((inp_args, inp_kwargs))
@@ -834,7 +834,7 @@ class ModelInputs:
834
834
  """Guesses the dynamic shapes for one argument."""
835
835
  if len(objs) == 0:
836
836
  return None
837
- set_types = set(type(o) for o in objs)
837
+ set_types = set(type(o) for o in objs if o is not None)
838
838
  assert (
839
839
  len(set_types) == 1
840
840
  ), f"Unexpected variety of input type {set_types}{msg() if msg else ''})"
@@ -4,6 +4,19 @@ import torch
4
4
  import transformers
5
5
  import transformers.cache_utils
6
6
 
7
+ KWARGS_LAYER = {}
8
+ if hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer"):
9
+ KWARGS_LAYER.update(
10
+ {
11
+ transformers.cache_utils.DynamicSlidingWindowLayer: lambda tensor: {
12
+ "sliding_window": tensor.shape[2]
13
+ },
14
+ transformers.cache_utils.StaticSlidingWindowLayer: lambda tensor: {
15
+ "sliding_window": tensor.shape[2]
16
+ },
17
+ }
18
+ )
19
+
7
20
 
8
21
  class CacheKeyValue:
9
22
  """
@@ -185,6 +198,7 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
185
198
  def make_dynamic_cache(
186
199
  key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
187
200
  cls_layers: Optional[Union[str, List[type]]] = None,
201
+ cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
188
202
  ) -> transformers.cache_utils.DynamicCache:
189
203
  """
190
204
  Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -194,6 +208,8 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
194
208
  :param cls_layers: to select the appropriate class to use on each layer,
195
209
  if specified, sliding_window is ignored, it can be a string
196
210
  if all layers are expected to follow the same class
211
+ :param cls_kwargs: arguments used to build a specific layer,
212
+ such as ``sliding_window`` for ``DynamicSlidingWindowLayer``
197
213
  :return: :class:`transformers.cache_utils.DynamicCache`
198
214
 
199
215
  Example:
@@ -224,49 +240,70 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
224
240
  are supported.
225
241
  """
226
242
  key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
227
- cls_kwargs = {}
228
243
  if isinstance(cls_layers, str):
229
244
  assert hasattr(
230
245
  transformers.cache_utils, cls_layers
231
- ), f"Unable to find class {cls_layers!r} in transformers.cache_utils"
232
- cls_layer = getattr(transformers.cache_utils, cls_layers)
233
- if cls_layers == "DynamicSlidingWindowLayer":
234
- cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
235
- assert isinstance(
236
- cls_kwargs["sliding_window"], int
237
- ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
238
- elif cls_layers is not None:
239
- unique = set(cls_layers)
240
- assert len(unique) == 1, f"Not implemented when cls_layers={cls_layers}"
241
- cls_layer = unique.pop()
242
- if (
243
- hasattr(transformers.cache_utils, "DynamicSlidingWindowLayer")
244
- and cls_layer == transformers.cache_utils.DynamicSlidingWindowLayer
245
- ):
246
- from .helper import string_type
247
-
248
- assert key_value_pairs and key_value_pairs[0], (
249
- f"not implemented for key_value_pairs="
250
- f"{string_type(key_value_pairs, with_shape=True)}"
251
- )
252
- cls_kwargs["sliding_window"] = key_value_pairs[0][0].shape[2]
253
- assert isinstance(
254
- cls_kwargs["sliding_window"], int
255
- ), f"sliding_window must be an integer but shape={key_value_pairs[0][0].shape}"
246
+ ), f"Missing layer class {cls_layers!r}"
247
+ cls_layers = getattr(transformers.cache_utils, cls_layers)
248
+ if cls_layers and not isinstance(cls_layers, list):
249
+ cls_layers = [cls_layers for _ in key_value_pairs] # type: ignore[misc]
250
+ if cls_layers is not None and isinstance(cls_layers, list):
251
+ assert len(cls_layers) == len(key_value_pairs), (
252
+ f"Length mismatch {len(key_value_pairs)} expected but "
253
+ f"{len(cls_layers)} layer types are given."
254
+ )
255
+ if cls_kwargs is None:
256
+ cls_kwargs = [{} for _kv in key_value_pairs] # type: ignore[assignment]
257
+ assert len(cls_layers) == len(cls_kwargs), (
258
+ f"Length mismatch {len(cls_kwargs)} expected but "
259
+ f"{len(cls_layers)} layer types are given, "
260
+ f"cls_layers={cls_layers}, cls_kwargs={cls_kwargs}"
261
+ )
262
+ cls_layer = None
263
+ assert (
264
+ key_value_pairs and key_value_pairs[0]
265
+ ), f"not implemented for type(key_value_pairs[0])={type(key_value_pairs[0])}"
266
+ for kv, clsy, kws in zip(key_value_pairs, cls_layers, cls_kwargs):
267
+ default_values = KWARGS_LAYER.get(clsy, lambda tensor: {})(kv[0])
268
+ for k, v in default_values.items():
269
+ if k not in kws:
270
+ kws[k] = v # type: ignore[index]
256
271
  else:
272
+ assert cls_kwargs is None, "cls_layers must be a list if cls_kwargs is specified"
273
+ assert (
274
+ cls_layers is None
275
+ ), f"cls_layers must be list or a string but it is {cls_layers}"
276
+ cls_kwargs = {}
257
277
  cls_layer = (
258
278
  transformers.cache_utils.DynamicLayer
259
279
  if hasattr(transformers.cache_utils, "DynamicLayer")
260
280
  else None
261
281
  )
262
282
 
283
+ if cls_layer is not None:
284
+ assert isinstance(cls_kwargs, dict), (
285
+ f"one layer = one set of arguments, cls_layer={cls_layer}, "
286
+ f"cls_kwargs={cls_kwargs}"
287
+ )
288
+ cls_layers = [cls_layer for _ in key_value_pairs]
289
+ cls_kwargs = (
290
+ cls_kwargs # type: ignore[assignment]
291
+ if isinstance(cls_kwargs, list)
292
+ else [cls_kwargs for _ in key_value_pairs]
293
+ )
294
+ elif cls_layers is not None:
295
+ assert isinstance(cls_layers, list), f"Unexpected type cls_layers={cls_layers}"
296
+ assert isinstance(cls_kwargs, list), f"Unexpected type cls_kwargs={cls_kwargs}"
297
+
263
298
  if (
264
299
  key_value_pairs
265
300
  and isinstance(key_value_pairs[0][0], torch._subclasses.fake_tensor.FakeTensor)
266
301
  and pv.Version(transformers.__version__) >= pv.Version("4.56")
267
302
  ):
268
303
  cache = transformers.cache_utils.DynamicCache()
269
- cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
304
+ cache.layers.extend(
305
+ [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type]
306
+ )
270
307
  for i, layer in enumerate(cache.layers):
271
308
  k, v = key_value_pairs[i][0], key_value_pairs[i][1]
272
309
  layer.dtype = k.dtype
@@ -281,8 +318,25 @@ if pv.Version(transformers.__version__) > pv.Version("4.49.99999"):
281
318
  return finalize_cache(cache)
282
319
 
283
320
  cache = transformers.cache_utils.DynamicCache()
284
- if hasattr(cache, "layers") and cls_layer != transformers.cache_utils.DynamicLayer:
285
- cache.layers.extend([cls_layer(**cls_kwargs) for _ in key_value_pairs])
321
+ if hasattr(cache, "layers") and (
322
+ cls_layer is None or cls_layer != transformers.cache_utils.DynamicLayer
323
+ ):
324
+ assert isinstance(cls_layers, list) and isinstance(cls_kwargs, list), (
325
+ f"Wrong type {type(cls_layers)} for cls_layers or "
326
+ f"{type(cls_kwargs)} for cls_kwargs"
327
+ )
328
+ assert len(cls_kwargs) == len(cls_layers) and len(cls_kwargs) == len(
329
+ key_value_pairs
330
+ ), (
331
+ f"Length mismatch between len(cls_kwargs)={len(cls_kwargs)}, "
332
+ f"len(cls_layers)={len(cls_layers)}, "
333
+ f"len(key_value_pairs)={len(key_value_pairs)}, "
334
+ f"cls_kwargs={cls_kwargs}, cls_layers={cls_layers}"
335
+ )
336
+ del cache.layers[:]
337
+ cache.layers.extend(
338
+ [cls_layer(**kws) for cls_layer, kws in zip(cls_layers, cls_kwargs)] # type: ignore[operator, arg-type]
339
+ )
286
340
  for i, layer in enumerate(cache.layers):
287
341
  layer.keys, layer.values = key_value_pairs[i][0], key_value_pairs[i][1]
288
342
  layer.is_initialized = True
@@ -306,6 +360,7 @@ else:
306
360
  def make_dynamic_cache(
307
361
  key_value_pairs: Union[List[torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]],
308
362
  cls_layers: Optional[Union[str, List[type]]] = None,
363
+ cls_kwargs: Optional[Union[Dict[str, int], List[Dict[str, int]]]] = None,
309
364
  ) -> transformers.cache_utils.DynamicCache:
310
365
  """
311
366
  Creates an instance of :class:`transformers.cache_utils.DynamicCache`.
@@ -337,7 +392,9 @@ else:
337
392
  )
338
393
  print(string_type(past_key_values, with_shape=True))
339
394
  """
340
- assert not cls_layers, "cls_layers cannot be used for transformers<5."
395
+ assert (
396
+ not cls_layers and not cls_kwargs
397
+ ), "cls_layers, cls_kwargs cannot be used for transformers<5."
341
398
  key_value_pairs = _preprocess_key_value_pairs(key_value_pairs)
342
399
  cache = transformers.cache_utils.DynamicCache(len(key_value_pairs)) # type: ignore
343
400
  for i, (key, value) in enumerate(key_value_pairs):
@@ -775,4 +832,13 @@ def finalize_cache(cache: transformers.cache_utils.Cache) -> transformers.cache_
775
832
  # This is used to expand the cache when it does not contain enough layers.
776
833
  # This is needed since transformers>4.55.3
777
834
  cache.layer_class_to_replicate = cache.layers[0].__class__
835
+ assert (
836
+ not hasattr(cache, "layers")
837
+ or len(cache.layers) != 1
838
+ or cache.layers[0].keys is not None
839
+ ), (
840
+ f"Size mismatch between {len(cache.layers)=}, "
841
+ f"first key={cache.layers[0].keys}, " # type: ignore[attr-defined]
842
+ f"first value={cache.layers[0].values}" # type: ignore[attr-defined]
843
+ )
778
844
  return cache
@@ -574,6 +574,32 @@ def string_type(
574
574
  print(f"[string_type] CACHE1:{type(obj)}")
575
575
  return f"MambaCache(conv_states={c}, ssm_states={d})"
576
576
 
577
+ if (
578
+ obj.__class__.__name__ in {"DynamicCache"}
579
+ and hasattr(obj, "layers")
580
+ and any(lay.__class__.__name__ != "DynamicLayer" for lay in obj.layers)
581
+ ):
582
+ slay = []
583
+ for lay in obj.layers:
584
+ skeys = string_type(
585
+ lay.keys,
586
+ with_shape=with_shape,
587
+ with_min_max=with_min_max,
588
+ with_device=with_device,
589
+ limit=limit,
590
+ verbose=verbose,
591
+ )
592
+ svalues = string_type(
593
+ lay.keys,
594
+ with_shape=with_shape,
595
+ with_min_max=with_min_max,
596
+ with_device=with_device,
597
+ limit=limit,
598
+ verbose=verbose,
599
+ )
600
+ slay.append(f"{lay.__class__.__name__}({skeys}, {svalues})")
601
+ return f"{obj.__class__.__name__}({', '.join(slay)})"
602
+
577
603
  if obj.__class__.__name__ in {
578
604
  "DynamicCache",
579
605
  "SlidingWindowCache",
@@ -829,6 +855,19 @@ def string_type(
829
855
  return f"{obj}"
830
856
  if obj.__class__.__name__ == "FakeTensorContext":
831
857
  return "FakeTensorContext(...)"
858
+ if obj.__class__.__name__ == "Chat":
859
+ import transformers.utils.chat_template_utils as ctu
860
+
861
+ assert isinstance(obj, ctu.Chat), f"unexpected type {type(obj)}"
862
+ msg = string_type(
863
+ obj.messages,
864
+ with_shape=with_shape,
865
+ with_min_max=with_min_max,
866
+ with_device=with_device,
867
+ limit=limit,
868
+ verbose=verbose,
869
+ )
870
+ return f"Chat({msg})"
832
871
 
833
872
  if verbose:
834
873
  print(f"[string_type] END:{type(obj)}")
@@ -1742,7 +1742,7 @@ def _find_used_names(node_list, node_indices):
1742
1742
  possible_outputs |= {o for o in node_list[i_node].output if o}
1743
1743
  # find all requires input from the other nodes
1744
1744
  set_indices = set(node_indices)
1745
- not_known: Set[str] = set()
1745
+ not_known = set()
1746
1746
  ranges = list(range(len(node_list)))
1747
1747
  for i_node in ranges[::-1]:
1748
1748
  if i_node in set_indices:
@@ -6,7 +6,7 @@ import torch
6
6
  from torch._C import _from_dlpack
7
7
  import onnxruntime
8
8
  from onnxruntime.capi import _pybind_state as ORTC
9
- from .helper import size_type
9
+ from .helper import size_type, string_type
10
10
  from .onnx_helper import (
11
11
  onnx_dtype_to_np_dtype,
12
12
  np_dtype_to_tensor_dtype,
@@ -511,6 +511,10 @@ class InferenceSessionForTorch(_InferenceSession):
511
511
  device = -1
512
512
  for k, v in feeds.items():
513
513
  assert k != "", f"Input cannot be empty but feeds names={list(feeds)}"
514
+ assert hasattr(v, "device"), (
515
+ f"Unexpected class {type(v)} for input {k!r}, "
516
+ f"feeds={string_type(feeds, with_shape=True)}"
517
+ )
514
518
  device = max(device, v.get_device())
515
519
  assert hasattr(v, "__dlpack__"), f"class {type(v)} should be serialized"
516
520
  if not v.is_contiguous():
@@ -115,7 +115,7 @@ def make_feeds(
115
115
  def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
116
116
  if isinstance(s, int):
117
117
  return s
118
- if s == "batch":
118
+ if s == "batch" or i == 0:
119
119
  return batch
120
120
  # Everything else is cache length or sequence length.
121
121
  return 0
@@ -153,9 +153,13 @@ def make_empty_cache(
153
153
  [i.type for i in sess.get_inputs()[2:]],
154
154
  )
155
155
  """
156
+ assert batch > 0, f"batch size = {batch} must be positive"
156
157
  feeds = {}
157
158
  for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
158
159
  new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
160
+ assert (
161
+ new_shape and new_shape[0] > 0
162
+ ), f"new_shape={new_shape} cannot have a null batch size, name={name!r}, shape={shape}"
159
163
  feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
160
164
  return feeds
161
165
 
@@ -272,6 +276,7 @@ def generate_and_validate(
272
276
  def onnx_generate(
273
277
  model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
274
278
  input_ids: torch.Tensor,
279
+ attention_mask: Optional[torch.Tensor] = None,
275
280
  eos_token_id: int = 2,
276
281
  max_new_tokens=100,
277
282
  return_session: bool = False,
@@ -330,7 +335,9 @@ def onnx_generate(
330
335
  )
331
336
 
332
337
  print("-- generate with onnx")
333
- onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
338
+ onnx_outputs = onnx_generate(
339
+ model_name, input_ids[:1], eos_token_id=2, max_new_tokens=10
340
+ )
334
341
  print("-- onnx output", onnx_outputs)
335
342
 
336
343
  # The example continues with other functions doing the same.
@@ -364,6 +371,7 @@ def onnx_generate(
364
371
  input_names = session.input_names
365
372
  input_types = session.input_types
366
373
  has_position_ids = "position_ids" in session.input_names
374
+ has_cache_position = "cache_position" in session.input_names
367
375
 
368
376
  assert (
369
377
  len(input_names) > 2
@@ -377,21 +385,46 @@ def onnx_generate(
377
385
  not has_position_ids or input_names[2] == "position_ids"
378
386
  ), f"position_ids must be the third input but input_names={input_names}"
379
387
 
388
+ cache_names, cache_shapes, cache_types = [], [], []
389
+ for name, shape, dt in zip(input_names, input_shapes, input_types):
390
+ if name.startswith("past_key_values"):
391
+ cache_names.append(name)
392
+ cache_shapes.append(shape)
393
+ cache_types.append(dt)
394
+
380
395
  # First call: prefill
396
+ empty_cache = make_empty_cache(input_ids.shape[0], cache_names, cache_shapes, cache_types)
381
397
  feeds = dict(
382
398
  input_ids=input_ids,
383
- attention_mask=torch.ones(
384
- input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
385
- ),
386
- **make_empty_cache(
387
- input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
399
+ attention_mask=(
400
+ attention_mask
401
+ if attention_mask is not None
402
+ else torch.ones(input_ids.shape, dtype=input_ids.dtype, device=input_ids.device)
388
403
  ),
404
+ **empty_cache,
389
405
  )
406
+
390
407
  if has_position_ids:
391
- feeds["position_ids"] = torch.unsqueeze(
408
+ assert (
409
+ input_ids.shape[1] > 0
410
+ ), f"unexpected value for input_ids shape={input_ids.shape}"
411
+ position_ids = torch.unsqueeze(
392
412
  torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
393
413
  )
414
+ feeds["position_ids"] = position_ids
415
+
416
+ if has_cache_position:
417
+ assert empty_cache, "no cache means no cache_position"
418
+ first_tensor = next(iter(empty_cache.values()))
419
+ cache_position = torch.arange(
420
+ first_tensor.shape[2],
421
+ input_ids.shape[1] + first_tensor.shape[2],
422
+ dtype=torch.int64,
423
+ device=input_ids.device,
424
+ )
425
+ feeds["cache_position"] = cache_position
394
426
 
427
+ # prefill step
395
428
  outputs = session.run(None, feeds)
396
429
 
397
430
  # Next calls: decode
@@ -424,7 +457,18 @@ def onnx_generate(
424
457
  ),
425
458
  0,
426
459
  )
427
- feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
460
+ if has_cache_position:
461
+ feeds["cache_position"] = torch.arange(
462
+ input_ids.shape[1],
463
+ input_ids.shape[1] + 1,
464
+ dtype=torch.int64,
465
+ device=input_ids.device,
466
+ )
467
+
468
+ feeds.update(
469
+ dict(zip([n for n in input_names if n.startswith("past_key_values")], outputs[1:]))
470
+ )
471
+ # generate/decoding step
428
472
  outputs = session.run(None, feeds)
429
473
 
430
474
  if return_session:
@@ -851,9 +851,14 @@ def torch_deepcopy(value: Any) -> Any:
851
851
  from .cache_helper import CacheKeyValue
852
852
 
853
853
  ca = CacheKeyValue(value)
854
- return make_dynamic_cache(
855
- torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))), cls_layers=ca.cls_layers
854
+ pairs = list(zip(ca.key_cache, ca.value_cache))
855
+ assert not hasattr(value, "layers") or len(value.layers) == len(pairs), (
856
+ f"Size mismatch between {len(value.layers)=} and {len(pairs)=}. "
857
+ f"value={string_type(value, with_shape=True)}, "
858
+ f"first key={value.layers[0].keys}, "
859
+ f"first value={value.layers[0].values}"
856
860
  )
861
+ return make_dynamic_cache(torch_deepcopy(pairs), cls_layers=ca.cls_layers)
857
862
  if value.__class__.__name__ == "StaticCache":
858
863
  from .cache_helper import CacheKeyValue
859
864