onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +78 -22
  3. onnx_diagnostic/export/api.py +124 -0
  4. onnx_diagnostic/export/dynamic_shapes.py +2 -1
  5. onnx_diagnostic/export/shape_helper.py +47 -70
  6. onnx_diagnostic/ext_test_case.py +11 -0
  7. onnx_diagnostic/helpers/cache_helper.py +38 -7
  8. onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
  9. onnx_diagnostic/helpers/helper.py +27 -33
  10. onnx_diagnostic/helpers/log_helper.py +109 -5
  11. onnx_diagnostic/helpers/memory_peak.py +2 -0
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +132 -2
  14. onnx_diagnostic/helpers/onnx_helper.py +1 -1
  15. onnx_diagnostic/helpers/ort_session.py +4 -0
  16. onnx_diagnostic/helpers/rt_helper.py +393 -43
  17. onnx_diagnostic/helpers/torch_helper.py +20 -1
  18. onnx_diagnostic/tasks/__init__.py +7 -0
  19. onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
  20. onnx_diagnostic/tasks/feature_extraction.py +2 -8
  21. onnx_diagnostic/tasks/image_text_to_text.py +10 -8
  22. onnx_diagnostic/tasks/summarization.py +2 -8
  23. onnx_diagnostic/tasks/text2text_generation.py +3 -8
  24. onnx_diagnostic/tasks/text_generation.py +86 -65
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
  26. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  27. onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
  28. onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
  29. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
  30. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
  31. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
  32. onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
  33. onnx_diagnostic/torch_models/validate.py +626 -228
  34. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
  35. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
  36. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
  37. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
  38. {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
@@ -1,18 +1,18 @@
- from typing import Any, Dict, List, Union
+ import json
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Union
  import numpy as np
  import onnx
  import torch
- from .helper import string_type, flatten_object
+ from .helper import string_type, flatten_object, max_diff
+ from .torch_helper import torch_deepcopy
+ from .ort_session import InferenceSessionForTorch


  def name_type_to_onnx_dtype(name: str) -> int:
-     if name == "tensor(int64)":
-         return onnx.TensorProto.INT64
-     if name == "tensor(float)":
-         return onnx.TensorProto.FLOAT
-     if name == "tensor(float16)":
-         return onnx.TensorProto.FLOAT16
-     raise AssertionError(f"Unexpected value {name!r}")
+     assert name.startswith("tensor(") and name.endswith(")"), f"Invalid value name={name!r}"
+     look = name[7:-1]
+     return getattr(onnx.TensorProto, look.upper())


  def make_feeds(
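
The rewritten ``name_type_to_onnx_dtype`` no longer hard-codes three dtypes; it derives the ``onnx.TensorProto`` enum member from the ``tensor(...)`` string. A minimal standalone sketch of the same idea, not the package's code:

    import onnx

    def name_to_onnx_dtype(name: str) -> int:
        # "tensor(float16)" -> "FLOAT16" -> onnx.TensorProto.FLOAT16
        assert name.startswith("tensor(") and name.endswith(")"), f"unexpected {name!r}"
        return getattr(onnx.TensorProto, name[7:-1].upper())

    assert name_to_onnx_dtype("tensor(float)") == onnx.TensorProto.FLOAT
    assert name_to_onnx_dtype("tensor(int64)") == onnx.TensorProto.INT64
    assert name_to_onnx_dtype("tensor(bfloat16)") == onnx.TensorProto.BFLOAT16

Any ``tensor(<type>)`` string whose upper-cased type matches a ``TensorProto`` member now works, at the cost of an ``AttributeError`` instead of a dedicated message for unknown names.
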
@@ -95,49 +95,399 @@ def make_feeds(
          elif isinstance(i, float):
              i = np.array(i, dtype=np.float32)
          new_flat.append(i)
-
-     # NOTE: model builder has a different order for past_key_values
-     # we need to reorder them to match the expected order
-     if is_modelbuilder:
-         # We assume that if "past_key_values" is in the names when it's
-         # modelbuilder
-         non_past_kv_input_names = [n for n in names if "past_key_values" not in n]
-         past_kv_names = [n for n in names if "past_key_values" in n]
-         reorder_past_kv_names = reorder_modelbuilder_cache_to_torch(past_kv_names)
-         names = non_past_kv_input_names + reorder_past_kv_names
      return dict(zip(names, new_flat))


- def reorder_modelbuilder_cache_to_torch(past_kv: List[Any]) -> List[Any]:
+ def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
+     if isinstance(s, int):
+         return s
+     if s == "batch":
+         return batch
+     # Everything else is cache length or sequence length.
+     return 0
+
+
+ _DTYPES = {
+     "tensor(float)": torch.float32,
+     "tensor(float16)": torch.float16,
+     "tensor(bfloat16)": torch.bfloat16,
+     "tensor(int64)": torch.int64,
+     "tensor(int32)": torch.int32,
+ }
+
+
+ def rt_type_to_torch_dtype(typename: str) -> torch.dtype:
+     """Converts a string such as ``tensor(float)`` into a dtype (torch.float32)."""
+     return _DTYPES[typename]
+
+
+ def make_empty_cache(
+     batch: int,
+     onnx_input_names: List[str],
+     onnx_input_shapes: List[Tuple[Union[int, str], ...]],
+     onnx_input_types: List[str],
+ ) -> Dict[str, torch.Tensor]:
+     """
+     Creates an empty cache. Example:
+
+     .. code-block:: python
+
+         make_empty_cache(
+             1,
+             sess.input_names[2:],
+             [i.shape for i in sess.get_inputs()[2:]],
+             [i.type for i in sess.get_inputs()[2:]],
+         )
+     """
+     feeds = {}
+     for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
+         new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
+         feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
+     return feeds
+
+
+ def generate_and_validate(
+     model,
+     input_ids: torch.Tensor,
+     eos_token_id: int = 2,
+     max_new_tokens: int = 100,
+     session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
+     atol: float = 0.1,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
+     """
+     Implements a simple method ``generate`` for a torch model.
+     The function does not expect any ``position_ids`` as input.
+     The function also checks that the outputs coming from an onnx model
+     are close to the outputs the torch model produces.
+
+     :param model: torch model
+     :param input_ids: input tokens
+     :param eos_token_id: token representing the end of an answer
+     :param max_new_tokens: stops after this number of generated tokens
+     :param session: the onnx model
+     :return: input tokens concatenated with new tokens;
+         if session is not None, it also returns the maximum differences
+         at every iteration
+
+     See example given with function :func:`onnx_generate
+     <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
+     """
+     if session is not None:
+         if not isinstance(session, InferenceSessionForTorch):
+             providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
+             providers.append("CPUExecutionProvider")
+             session = InferenceSessionForTorch(session, providers=providers)
+
+     # First call: prefill
+     attention_mask = torch.ones(
+         input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+     )
+     if session:
+         feeds = {
+             **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
+             **make_empty_cache(
+                 input_ids.shape[0],
+                 session.input_names[2:],
+                 session.input_shapes[2:],
+                 session.input_types[2:],
+             ),
+         }
+         onnx_results = session.run(None, feeds)
+
+     outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
+
+     if session:
+         diff = max_diff(outputs, onnx_results)
+         assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+             f"Unexpected issue with {type(model)}\ndiff={diff}"
+             f"\ninput_ids.shape={input_ids.shape}"
+             f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+             f"\n got=\n"
+             f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+             f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+         )
+         diffs = [diff]
+
+     # Next calls: decode
+     for iteration in range(max_new_tokens):
+         next_token_logits = outputs.logits[:, -1, :]
+         next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+         if next_token_id.item() == eos_token_id:
+             break
+         input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+         attention_mask = torch.ones(
+             input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+         )
+         if session:
+             feeds = dict(
+                 zip(
+                     session.input_names,
+                     [
+                         t.detach()
+                         for t in torch_deepcopy(
+                             flatten_object(
+                                 [next_token_id, attention_mask, outputs.past_key_values]
+                             )
+                         )
+                     ],
+                 )
+             )
+             onnx_results = session.run(None, feeds)
+         outputs = model(
+             next_token_id,
+             use_cache=True,
+             past_key_values=outputs.past_key_values,
+             attention_mask=attention_mask,
+         )
+         if session:
+             diff = max_diff(outputs, onnx_results)
+             assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+                 f"Unexpected issue with {type(model)}, iteration={iteration}"
+                 f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
+                 f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+                 f"\n got=\n"
+                 f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+                 f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+             )
+             diffs.append(diff)
+     if session:
+         return input_ids, diffs
+     return input_ids
+
+
+ def onnx_generate(
+     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
+     input_ids: torch.Tensor,
+     eos_token_id: int = 2,
+     max_new_tokens=100,
+     return_session: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch, Dict[str, Any]]]:
+     """
+     Implements a simple method ``generate`` for an ONNX model.
+     The function does not expect any ``position_ids`` as input.
+
+     :param model_or_path: model or loaded model
+     :param input_ids: input tokens
+     :param eos_token_id: token representing the end of an answer
+     :param max_new_tokens: stops after this number of generated tokens
+     :param return_session: returns the instance of class
+         :class:`InferenceSessionForTorch
+         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
+         created if necessary, the function returns the feeds for the next iteration
+     :return: input tokens concatenated with new tokens
+
+     .. runpython::
+         :showcode:
+
+         import os
+         from onnx_diagnostic.helpers import string_type, string_diff
+         from onnx_diagnostic.helpers.rt_helper import (
+             onnx_generate,
+             generate_and_validate,
+             onnx_generate_with_genai,
+         )
+         from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+         from onnx_diagnostic.torch_export_patches import torch_export_patches
+         from onnx_diagnostic.export.api import to_onnx
+
+         mid = "arnir0/Tiny-LLM"
+         print(f"-- get model for {mid!r}")
+         data = get_untrained_model_with_inputs(mid)
+         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+         del inputs["position_ids"]
+         del ds["position_ids"]
+         input_ids = inputs["input_ids"]
+
+         print(f"-- input_ids={input_ids.shape}")
+         print(f"-- inputs: {string_type(inputs, with_shape=True)}")
+         print(f"-- dynamic_shapes: {string_type(ds)}")
+         folder = "dump_test"
+         os.makedirs(folder, exist_ok=True)
+         model_name = os.path.join(folder, "model.onnx")
+         print("-- test_onnx_generate: export model")
+         with torch_export_patches(patch_transformers=True, patch_torch=False):
+             to_onnx(
+                 model,
+                 (),
+                 kwargs=inputs,
+                 dynamic_shapes=ds,
+                 filename=model_name,
+                 exporter="custom",  # custom, dynamo or onnx-dynamo, modelbuilder
+             )
+
+         print("-- generate with onnx")
+         onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+         print("-- onnx output", onnx_outputs)
+
+         # The example continues with other functions doing the same.
+         print("-- generate with pytorch")
+         torch_outputs, diffs = generate_and_validate(
+             model, input_ids[:1], 2, max_new_tokens=10, session=model_name
+         )
+         print("-- torch output", torch_outputs)
+         print("-- differences at each step:")
+         for i, d in enumerate(diffs):
+             print(f"iteration {i}: {string_diff(d)}")
+
+         print("-- generate with genai")
+         genai_outputs, session = onnx_generate_with_genai(
+             model_name,
+             input_ids[:1],
+             max_new_tokens=10,
+             return_session=True,
+             transformers_config=data["configuration"],
+         )
+         print("-- genai output", genai_outputs)
      """
-     Reorders the past_kvs for ModelBuilder to match the expected order
-     by PyTorch exported models.
+     if not isinstance(model_or_path, InferenceSessionForTorch):
+         providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
+         providers.append("CPUExecutionProvider")
+         session = InferenceSessionForTorch(model_or_path, providers=providers)
+     else:
+         session = model_or_path
+
+     input_shapes = session.input_shapes
+     input_names = session.input_names
+     input_types = session.input_types
+     has_position_ids = "position_ids" in session.input_names
+
+     assert (
+         len(input_names) > 2
+         and input_names[:2] == ["input_ids", "attention_mask"]
+         and input_names[3 if has_position_ids else 2].startswith("past_key_values")
+     ), (
+         f"Only text generation is supported but input_names == {input_names}, "
+         f"has_position_ids={has_position_ids}"
+     )
+     assert (
+         not has_position_ids or input_names[2] == "position_ids"
+     ), f"position_ids must be the third input but input_names={input_names}"
+
+     # First call: prefill
+     feeds = dict(
+         input_ids=input_ids,
+         attention_mask=torch.ones(
+             input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+         ),
+         **make_empty_cache(
+             input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
+         ),
+     )
+     if has_position_ids:
+         feeds["position_ids"] = torch.unsqueeze(
+             torch.arange(input_ids.shape[1], dtype=torch.int64, device=input_ids.device), 0
+         )

-     .. note::
-         This function can take either the names or the actual tensors
-         as long as they are in a list.
+     outputs = session.run(None, feeds)

-     Conceptually,
+     # Next calls: decode
+     for _ in range(max_new_tokens):
+         next_token_logits = outputs[0][:, -1, :]

-     From::
+         # The most probable next token is chosen.
+         next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+         # But we could select it using a multinomial law
+         # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
+         # <<< top_probs, top_indices = torch.topk(probs, top_k)
+         # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]

-         [past_key_values.0.key, past_key_values.0.value,
-          past_key_values.1.key, past_key_values.1.value, ...]
+         if next_token_id.item() == eos_token_id:
+             break
+         input_ids = torch.cat([input_ids, next_token_id.to(input_ids.device)], dim=-1)
+         feeds = dict(
+             input_ids=next_token_id,
+             attention_mask=torch.ones(
+                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+             ),
+         )
+         if has_position_ids:
+             feeds["position_ids"] = torch.unsqueeze(
+                 torch.arange(
+                     input_ids.shape[1],
+                     input_ids.shape[1] + 1,
+                     dtype=torch.int64,
+                     device=input_ids.device,
+                 ),
+                 0,
+             )
+         feeds.update(dict(zip(input_names[3 if has_position_ids else 2 :], outputs[1:])))
+         outputs = session.run(None, feeds)
+
+     if return_session:
+         return input_ids, session, feeds
+     return input_ids

-     To::

-         [past_key_values.0.key, past_key_values.1.key,
-          ..., past_key_values.0.value, past_key_values.1.value, ...]
+ def onnx_generate_with_genai(
+     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
+     input_ids: torch.Tensor,
+     max_new_tokens=100,
+     return_session: bool = False,
+     transformers_config: Optional[Any] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+     """
+     Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
+     for an ONNX model. The function does not expect any ``position_ids`` as input.
+
+     :param model_or_path: model or loaded model
+     :param input_ids: input tokens
+     :param eos_token_ids: token representing the end of an answer
+     :param max_new_tokens: stops after this number of generated tokens
+     :param return_session: returns the instance of class
+         :class:`InferenceSessionForTorch
+         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
+         created if necessary
+     :param transformers_config: write configuration
+         if missing and if this configuration is provided
+     :return: input tokens concatenated with new tokens

-     :param past_kv: list of flattened inputs
-     :return: reordered list of flattened inputs
+     See example given with function :func:`onnx_generate
+     <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
      """
-     total_len = len(past_kv)
-     if total_len % 2 != 0:
-         raise ValueError("The length of past_key_values should be even.")
-     keys = []
-     values = []
-     for i in range(0, total_len, 2):
-         keys.append(past_kv[i])
-         values.append(past_kv[i + 1])
-     return keys + values
+     import onnxruntime_genai as og
+
+     if not isinstance(model_or_path, og.Model):
+         from .model_builder_helper import make_genai_config
+
+         assert isinstance(
+             model_or_path, str
+         ), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
+         folder = os.path.dirname(model_or_path)
+         assert os.path.exists(folder), f"Folder {folder!r} does not exist."
+         assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
+         config_file = os.path.join(folder, "genai_config.json")
+         if not os.path.exists(config_file):
+             if not transformers_config:
+                 raise FileNotFoundError(
+                     f"Folder {model_or_path!r} does not contain 'genai_config.json'."
+                 )
+             config = make_genai_config(transformers_config, model_or_path)
+             with open(config_file, "w") as f:
+                 json.dump(config, f, indent=4)
+
+         config = og.Config(os.path.dirname(config_file))
+         if input_ids.is_cuda:
+             config.clear_providers()
+             config.append_provider("cuda")
+         session = og.Model(config)
+     else:
+         session = model_or_path
+
+     params = og.GeneratorParams(session)
+     params.set_search_options(
+         max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
+     )
+     generator = og.Generator(session, params)
+
+     # First call: prefill
+     cats = []
+     generator.append_tokens(input_ids)
+     while not generator.is_done():
+         generator.generate_next_token()
+         new_token = generator.get_next_tokens()[0]
+         cats.append(int(new_token))
+
+     input_ids = torch.cat([input_ids, torch.tensor([cats], dtype=torch.int64)], dim=-1)
+     if return_session:
+         return input_ids, session
+     return input_ids
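
The prefill step above leans on ``_get_dim`` and ``make_empty_cache``: dimensions named ``batch`` take the batch size, every other symbolic dimension (sequence or cache length) collapses to 0, so the first run receives zero-length ``past_key_values``. A rough, self-contained illustration of the shapes this produces, with made-up input names and sizes (two layers, four KV heads, head size 64):

    import torch

    def get_dim(i, s, batch=1):
        # mirrors _get_dim: ints stay, "batch" becomes the batch size, others become 0
        if isinstance(s, int):
            return s
        return batch if s == "batch" else 0

    names = [f"past_key_values.{i}.{kv}" for i in range(2) for kv in ("key", "value")]
    shapes = [("batch", 4, "past_sequence_length", 64)] * len(names)

    empty = {
        name: torch.empty(tuple(get_dim(i, s, batch=1) for i, s in enumerate(shape)))
        for name, shape in zip(names, shapes)
    }
    print({k: tuple(v.shape) for k, v in empty.items()})
    # every entry has shape (1, 4, 0, 64): a valid but empty KV cache for the first call

During decoding, ``onnx_generate`` then replaces these empty tensors with the cache outputs of the previous iteration (the ``feeds.update(dict(zip(input_names[...], outputs[1:])))`` line above).
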
@@ -856,9 +856,15 @@ def torch_deepcopy(value: Any) -> Any:
          ), f"Unexpected type={type(value)}"
          return copy.deepcopy(value)

+     if hasattr(value, "__nocopy__"):
+         return value
+
      # We should have a code using serialization, deserialization assuming a model
      # cannot be exported without them.
-     raise NotImplementedError(f"torch_deepcopy not implemented for type {type(value)}")
+     raise NotImplementedError(
+         f"torch_deepcopy not implemented for type {type(value)}, "
+         f"add attribute '__nocopy__' to return it as is."
+     )


  def torch_tensor_size(value: Any) -> Any:
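
``torch_deepcopy`` now returns unchanged any object carrying a ``__nocopy__`` attribute instead of raising ``NotImplementedError``. A small sketch of the intended use; the class below is invented for illustration:

    from onnx_diagnostic.helpers.torch_helper import torch_deepcopy

    class SharedResource:
        # hypothetical wrapper around something that must not be duplicated
        __nocopy__ = True

        def __init__(self, handle: str):
            self.handle = handle

    res = SharedResource("session-handle")
    assert torch_deepcopy(res) is res  # objects tagged __nocopy__ come back as is
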
@@ -966,3 +972,16 @@ def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
      # Other cases, it should be small tensor. We use numpy.
      np_tensor = to_array_extended(tensor)
      return torch.from_numpy(np_tensor)
+
+
+ def get_weight_type(model: torch.nn.Module) -> torch.dtype:
+     """Returns the most probable dtype in a model."""
+     counts = {}
+     for _name, param in model.named_parameters():
+         dt = param.dtype
+         if dt not in counts:
+             counts[dt] = 1
+         else:
+             counts[dt] += 1
+     final = max(list(counts.items()))
+     return final[0]
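
``get_weight_type`` counts how many parameters use each dtype, but the final ``max(list(counts.items()))`` compares dictionary items rather than the counts. If the goal is the most frequent dtype, selecting by count is the usual approach; a standalone sketch, not the package's implementation:

    from collections import Counter
    import torch

    def most_frequent_weight_dtype(model: torch.nn.Module) -> torch.dtype:
        # count parameters per dtype and keep the most common one
        counts = Counter(p.dtype for p in model.parameters())
        return counts.most_common(1)[0][0]

    model = torch.nn.Sequential(
        torch.nn.Linear(4, 4), torch.nn.Linear(4, 4), torch.nn.Linear(4, 2)
    ).half()
    model[2].float()  # mixed precision: the last layer back to float32
    print(most_frequent_weight_dtype(model))  # torch.float16
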
@@ -77,6 +77,13 @@ def random_input_kwargs(config: Any, task: str) -> Tuple[Dict[str, Any], Callabl
      If the configuration is None, the function selects typical dimensions.
      It returns parameters and a function. The function creates dummy inputs
      if it receives the parameters returned as a first result.
+
+     .. code-block:: python
+
+         config = get_pretrained_config(model_id)
+         task = task_from_id(name)
+         kwargs, fct = random_input_kwargs(config, task)
+         res = fct(model, config, add_second_input=False, **kwargs)
      """
      tasks = {mod.__TASK__: mod.random_input_kwargs for mod in __TASKS__}
      assert task in tasks, f"Task {task!r} not found in {sorted(tasks)}"
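
The example added to the docstring assumes a configuration fetched beforehand; since the helper also accepts ``config=None`` and then picks typical dimensions, a self-contained sketch could be as short as this (the task string is an assumption and must match one of the registered ``__TASK__`` values):

    from onnx_diagnostic.tasks import random_input_kwargs

    # with config=None the helper falls back to typical dimensions for the task
    kwargs, fct = random_input_kwargs(None, "text-generation")
    print("parameters:", kwargs)
    # fct(model, None, add_second_input=False, **kwargs) would then build dummy inputs
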
@@ -84,14 +84,8 @@ def get_inputs(
          "cache_position": {0: seq_length},
          "encoder_outputs": [{0: batch}], # last_hidden_state
          "past_key_values": [
-             [
-                 [{0: batch} for _ in range(num_hidden_layers)],
-                 [{0: batch} for _ in range(num_hidden_layers)],
-             ],
-             [
-                 [{0: batch} for _ in range(num_hidden_layers)],
-                 [{0: batch} for _ in range(num_hidden_layers)],
-             ],
+             [{0: batch} for _ in range(num_hidden_layers * 2)],
+             [{0: batch} for _ in range(num_hidden_layers * 2)],
          ],
      }
      inputs = dict(
@@ -109,14 +109,8 @@ def get_inputs(
          cache_length = "cache_length_key"
          cache_length2 = "cache_length_val"
          shapes["past_key_values"] = [ # type: ignore[assignment]
-             [
-                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-             ],
-             [
-                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
-                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
-             ],
+             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+             [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
          ]

      res = dict(inputs=inputs, dynamic_shapes=shapes)
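
This change (and the similar ones in the other task files) flattens the dynamic shapes declared for encoder-decoder caches: instead of a nested ``[[keys], [values]]`` layout per cache, each cache is now a single flat list of ``2 * num_hidden_layers`` entries. What the new structure evaluates to for two layers, as a plain illustration:

    num_hidden_layers = 2
    batch, cache_length, cache_length2 = "batch", "cache_length_key", "cache_length_val"

    past_key_values_shapes = [
        [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
        [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
    ]
    print(past_key_values_shapes)
    # two flat lists of 4 axis-mappings each, one per flattened key/value tensor
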
@@ -1,3 +1,4 @@
+ import itertools
  from typing import Any, Callable, Dict, Optional, Tuple
  import torch
  from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
@@ -151,10 +152,7 @@ def _get_inputs_gemma3(
          },
          "position_ids": {0: batch, 1: seq_length},
          "cache_position": {0: seq_length},
-         "past_key_values": [
-             [{0: batch} for _ in range(num_hidden_layers)],
-             [{0: batch} for _ in range(num_hidden_layers)],
-         ],
+         "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
          "pixel_values": {0: batch},
          "use_cache": None,
      }
@@ -272,10 +270,14 @@ def get_inputs_default(
          "token_type_ids": {0: batch, 1: seq_length},
          "attention_mask": {0: batch, 1: "cache+seq"},
          "position_ids": {0: batch, 1: seq_length},
-         "past_key_values": [
-             [{0: batch} for _ in range(num_hidden_layers)],
-             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-         ],
+         "past_key_values": list(
+             itertools.chain.from_iterable(
+                 zip(
+                     [{0: batch} for _ in range(num_hidden_layers)],
+                     [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                 )
+             )
+         ),
          "pixel_values": (
              {0: batch, 1: images}
              if model.__class__.__name__ == "IdeficsForVisionText2Text"
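
For ``image_text_to_text`` the per-layer key and value shapes are interleaved rather than grouped, hence the ``zip`` plus ``itertools.chain.from_iterable`` pattern. A quick check of what this produces for two layers:

    import itertools

    num_hidden_layers = 2
    batch, cache_length = "batch", "cache_length"

    interleaved = list(
        itertools.chain.from_iterable(
            zip(
                [{0: batch} for _ in range(num_hidden_layers)],
                [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
            )
        )
    )
    print(interleaved)
    # [{0: 'batch'}, {0: 'batch', 2: 'cache_length'},
    #  {0: 'batch'}, {0: 'batch', 2: 'cache_length'}]
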
@@ -81,14 +81,8 @@ def get_inputs(
          "attention_mask": {0: batch, 1: "seq_mask"},
          # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
          "past_key_values": [
-             [
-                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-             ],
-             [
-                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
-                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
-             ],
+             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+             [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
          ],
          # one these is selected based on the forward method signature
          # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
@@ -83,14 +83,8 @@ def get_inputs(
          "attention_mask": {0: batch, 1: "seq_mask"},
          # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
          "past_key_values": [
-             [
-                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-                 [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
-             ],
-             [
-                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
-                 [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers)],
-             ],
+             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+             [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
          ],
          # one these is selected based on the forward method signature
          # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
@@ -157,6 +151,7 @@ def get_inputs(
      assert (
          add_second_input > 0
      ), f"Not implemented for add_second_input={add_second_input}."
+     res["inputs_prompt"] = dict(input_ids=torch.randint(1000, 30000, (1, 11)))
      res["inputs2"] = get_inputs(
          model=model,
          config=config,
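
``get_inputs`` for text generation now also returns an ``inputs_prompt`` entry, a single random prompt of shape (1, 11); this is the kind of input the new generation helpers in ``rt_helper`` consume. A hedged sketch of combining the two, where the exported model path is a placeholder:

    import torch
    from onnx_diagnostic.helpers.rt_helper import onnx_generate

    # the prompt added by get_inputs: one sequence of 11 random token ids
    inputs_prompt = dict(input_ids=torch.randint(1000, 30000, (1, 11)))

    # "dump_test/model.onnx" stands for a model exported as in the runpython
    # example earlier (inputs: input_ids, attention_mask, past_key_values...).
    generated = onnx_generate(
        "dump_test/model.onnx",
        inputs_prompt["input_ids"],
        eos_token_id=2,
        max_new_tokens=10,
    )
    print(generated.shape)  # (1, 11 + number of generated tokens)
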