onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/rt_helper.py
@@ -0,0 +1,476 @@
+ import json
+ import os
+ from typing import Any, Dict, List, Optional, Tuple, Union
+ import numpy as np
+ import onnx
+ import torch
+ from .helper import string_type, flatten_object, max_diff
+ from .torch_helper import torch_deepcopy
+ from .ort_session import InferenceSessionForTorch
+
+
+ def name_type_to_onnx_dtype(name: str) -> int:
+     if name == "tensor(int64)":
+         return onnx.TensorProto.INT64
+     if name == "tensor(float)":
+         return onnx.TensorProto.FLOAT
+     if name == "tensor(float16)":
+         return onnx.TensorProto.FLOAT16
+     raise AssertionError(f"Unexpected value {name!r}")
+
+
+ def make_feeds(
+     proto: Union[onnx.ModelProto, List[str]],
+     inputs: Any,
+     use_numpy: bool = False,
+     copy: bool = False,
+     check_flatten: bool = True,
+     is_modelbuilder: bool = False,
+ ) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
+     """
+     Serializes the inputs to produce the feeds expected
+     by :class:`onnxruntime.InferenceSession`.
+
+     :param proto: onnx model or list of names
+     :param inputs: any kind of inputs
+     :param use_numpy: if True, converts torch tensors into numpy arrays
+     :param copy: if True, a copy of every tensor is made; this should be
+         the case if the inputs are ingested by ``OrtValue``
+     :param check_flatten: if True, checks that ``torch.utils._pytree.tree_flatten``
+         returns the same number of outputs
+     :param is_modelbuilder: if True, the exporter is ModelBuilder; the
+         past_key_values inputs need to be reordered to match the expected
+         order and position_ids must be dropped
+     :return: feeds dictionary
+     """
+     # NOTE: position_ids is a special case because ModelBuilder does not
+     # usually use it: it is fused into the rotary embedding in GQA.
+     if is_modelbuilder and isinstance(inputs, dict):
+         inputs.pop("position_ids", None)  # Removes 'position_ids' if present.
+
+     flat = flatten_object(inputs, drop_keys=True)
+     assert (
+         not check_flatten
+         or not all(isinstance(obj, torch.Tensor) for obj in flat)
+         # or not is_cache_dynamic_registered(fast=True)
+         or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
+     ), (
+         f"Unexpected number of flattened objects, "
+         f"{string_type(flat, with_shape=True)} != "
+         f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
+     )
+     if use_numpy:
+         from .torch_helper import to_numpy
+
+         flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
+     names = (
+         [i.name for i in proto.graph.input]
+         if isinstance(proto, onnx.ModelProto)
+         else (
+             [i.name for i in proto.get_inputs()]
+             if hasattr(proto, "get_inputs")
+             else (proto.input_names if hasattr(proto, "input_names") else proto)
+         )
+     )
+     assert (
+         isinstance(names, list)
+         and len(names) <= len(flat)
+         and (
+             len(names) == len(flat)
+             or isinstance(proto, onnx.ModelProto)
+             or hasattr(proto, "get_inputs")
+         )
+     ), (
+         f"Not the same number of given inputs {len(flat)} "
+         f"and the number of model inputs {len(names)}, "
+         f"type(names)={type(names)}, type(proto)={type(proto)}"
+         f"\n-- inputs={string_type(inputs, with_shape=True)}"
+         f"\n-- names={names}"
+     )
+
+     if copy:
+         flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
+     # onnxruntime does not accept plain Python bool, int or float values,
+     # they are wrapped into numpy arrays.
+     new_flat = []
+     for i in flat:
+         if isinstance(i, bool):
+             i = np.array(i, dtype=np.bool_)
+         elif isinstance(i, int):
+             i = np.array(i, dtype=np.int64)
+         elif isinstance(i, float):
+             i = np.array(i, dtype=np.float32)
+         new_flat.append(i)
+     return dict(zip(names, new_flat))
+
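
A minimal usage sketch for make_feeds (not part of the diff; the input names and shapes below are hypothetical and assume a model taking input_ids and attention_mask):

    # Sketch: build feeds from a plain list of input names.
    import torch
    from onnx_diagnostic.helpers.rt_helper import make_feeds

    inputs = {
        "input_ids": torch.tensor([[1, 2, 3]], dtype=torch.int64),
        "attention_mask": torch.ones((1, 3), dtype=torch.int64),
    }
    # proto may be an onnx.ModelProto, a session, or simply the input names.
    feeds = make_feeds(["input_ids", "attention_mask"], inputs, use_numpy=True)
    # feeds maps each input name to a numpy array ready for InferenceSession.run.
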
+
+ def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
+     if isinstance(s, int):
+         return s
+     if s == "batch":
+         return batch
+     # Everything else is cache length or sequence length.
+     return 0
+
+
+ _DTYPES = {
+     "tensor(float)": torch.float32,
+     "tensor(float16)": torch.float16,
+     "tensor(bfloat16)": torch.bfloat16,
+     "tensor(int64)": torch.int64,
+     "tensor(int32)": torch.int32,
+ }
+
+
+ def rt_type_to_torch_dtype(typename: str) -> torch.dtype:
+     """Converts a string such as ``tensor(float)`` into a dtype (torch.float32)."""
+     return _DTYPES[typename]
+
+
+ def make_empty_cache(
+     batch: int,
+     onnx_input_names: List[str],
+     onnx_input_shapes: List[Tuple[Union[int, str], ...]],
+     onnx_input_types: List[str],
+ ) -> Dict[str, torch.Tensor]:
+     """
+     Creates an empty cache. Example:
+
+     .. code-block:: python
+
+         make_empty_cache(
+             1,
+             sess.input_names[2:],
+             [i.shape for i in sess.get_inputs()[2:]],
+             [i.type for i in sess.get_inputs()[2:]],
+         )
+     """
+     feeds = {}
+     for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
+         new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
+         feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
+     return feeds
+
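
A minimal sketch of what make_empty_cache produces (not part of the diff; the input name and symbolic shape are hypothetical, in the style onnxruntime reports for past_key_values inputs):

    # Sketch: an empty KV-cache feed for one hypothetical cache input.
    feeds = make_empty_cache(
        1,                                           # batch size
        ["past_key_values.0.key"],                   # onnx input names
        [("batch", 8, "past_sequence_length", 64)],  # symbolic shapes
        ["tensor(float)"],                           # onnxruntime type strings
    )
    # feeds["past_key_values.0.key"] is an empty tensor of shape (1, 8, 0, 64):
    # "batch" becomes 1, fixed dims are kept, symbolic cache dims become 0.
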
+
+ def generate_and_validate(
+     model,
+     input_ids: torch.Tensor,
+     eos_token_id: int,
+     max_new_tokens: int = 100,
+     session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
+     atol: float = 0.1,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
+     """
+     Implements a simple method ``generate`` for a torch model.
+     The function does not expect any ``position_ids`` as input.
+     The function also checks that the outputs coming from an onnx model
+     are close to the outputs the torch model produces.
+
+     :param model: torch model
+     :param input_ids: input tokens
+     :param eos_token_id: token representing the end of an answer
+     :param max_new_tokens: stops after this number of generated tokens
+     :param session: the onnx model
+     :param atol: absolute tolerance used to compare torch and onnx outputs
+     :return: input tokens concatenated with new tokens;
+         if session is not None, it also returns the maximum differences
+         at every iteration
+
+     See example given with function :func:`onnx_generate
+     <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
+     """
+     if session is not None:
+         if not isinstance(session, InferenceSessionForTorch):
+             providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
+             providers.append("CPUExecutionProvider")
+             session = InferenceSessionForTorch(session, providers=providers)
+
+     # First call: prefill
+     attention_mask = torch.ones(
+         input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+     )
+     if session:
+         feeds = {
+             **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
+             **make_empty_cache(
+                 input_ids.shape[0],
+                 session.input_names[2:],
+                 session.input_shapes[2:],
+                 session.input_types[2:],
+             ),
+         }
+         onnx_results = session.run(None, feeds)
+
+     outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)
+
+     if session:
+         diff = max_diff(outputs, onnx_results)
+         assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+             f"Unexpected issue with {type(model)}\ndiff={diff}"
+             f"\ninput_ids.shape={input_ids.shape}"
+             f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+             f"\n got=\n"
+             f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+             f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+         )
+         diffs = [diff]
+
+     # Next calls: decode
+     for iteration in range(max_new_tokens):
+         next_token_logits = outputs.logits[:, -1, :]
+         next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+         if next_token_id.item() == eos_token_id:
+             break
+         input_ids = torch.cat([input_ids, next_token_id], dim=-1)
+         attention_mask = torch.ones(
+             input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+         )
+         if session:
+             feeds = dict(
+                 zip(
+                     session.input_names,
+                     [
+                         t.detach()
+                         for t in torch_deepcopy(
+                             flatten_object(
+                                 [next_token_id, attention_mask, outputs.past_key_values]
+                             )
+                         )
+                     ],
+                 )
+             )
+             onnx_results = session.run(None, feeds)
+         outputs = model(
+             next_token_id,
+             use_cache=True,
+             past_key_values=outputs.past_key_values,
+             attention_mask=attention_mask,
+         )
+         if session:
+             diff = max_diff(outputs, onnx_results)
+             assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
+                 f"Unexpected issue with {type(model)}, iteration={iteration}"
+                 f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
+                 f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
+                 f"\n got=\n"
+                 f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
+                 f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
+             )
+             diffs.append(diff)
+     if session:
+         return input_ids, diffs
+     return input_ids
+
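
A minimal call sketch (not part of the diff; it mirrors the example shown under onnx_generate below, so model, input_ids, and the exported "model.onnx" are assumed to exist, and eos_token_id=2 is an assumption):

    # Sketch: greedy-decode with the torch model while validating each step
    # against the exported onnx model.
    new_ids, diffs = generate_and_validate(
        model, input_ids[:1], eos_token_id=2, max_new_tokens=10, session="model.onnx"
    )
    # diffs[i]["abs"] is the maximum absolute difference observed at iteration i.
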
+
+ def onnx_generate(
+     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
+     input_ids: torch.Tensor,
+     eos_token_id: int,
+     max_new_tokens: int = 100,
+     return_session: bool = False,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+     """
+     Implements a simple method ``generate`` for an ONNX model.
+     The function does not expect any ``position_ids`` as input.
+
+     :param model_or_path: model or loaded model
+     :param input_ids: input tokens
+     :param eos_token_id: token representing the end of an answer
+     :param max_new_tokens: stops after this number of generated tokens
+     :param return_session: returns the instance of class
+         :class:`InferenceSessionForTorch
+         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
+         created if necessary
+     :return: input tokens concatenated with new tokens
+
+     .. runpython::
+         :showcode:
+
+         import os
+         from onnx_diagnostic.helpers import string_type, string_diff
+         from onnx_diagnostic.helpers.rt_helper import (
+             onnx_generate,
+             generate_and_validate,
+             onnx_generate_with_genai,
+         )
+         from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
+         from onnx_diagnostic.torch_export_patches import torch_export_patches
+         from onnx_diagnostic.export.api import to_onnx
+
+         mid = "arnir0/Tiny-LLM"
+         print(f"-- get model for {mid!r}")
+         data = get_untrained_model_with_inputs(mid)
+         model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
+         del inputs["position_ids"]
+         del ds["position_ids"]
+         input_ids = inputs["input_ids"]
+
+         print(f"-- input_ids={input_ids.shape}")
+         print(f"-- inputs: {string_type(inputs, with_shape=True)}")
+         print(f"-- dynamic_shapes: {string_type(ds)}")
+         folder = "dump_test"
+         os.makedirs(folder, exist_ok=True)
+         model_name = os.path.join(folder, "model.onnx")
+         print("-- test_onnx_generate: export model")
+         with torch_export_patches(patch_transformers=True, patch_torch=False):
+             to_onnx(
+                 model,
+                 (),
+                 kwargs=inputs,
+                 dynamic_shapes=ds,
+                 filename=model_name,
+                 exporter="custom",  # custom, dynamo or onnx-dynamo, modelbuilder
+             )
+
+         print("-- generate with onnx")
+         onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
+         print("-- onnx output", onnx_outputs)
+
+         # The example continues with other functions doing the same.
+         print("-- generate with pytorch")
+         torch_outputs, diffs = generate_and_validate(
+             model, input_ids[:1], 2, max_new_tokens=10, session=model_name
+         )
+         print("-- torch output", torch_outputs)
+         print("-- differences at each step:")
+         for i, d in enumerate(diffs):
+             print(f"iteration {i}: {string_diff(d)}")
+
+         print("-- generate with genai")
+         genai_outputs, session = onnx_generate_with_genai(
+             model_name,
+             input_ids[:1],
+             max_new_tokens=10,
+             return_session=True,
+             transformers_config=data["configuration"],
+         )
+         print("-- genai output", genai_outputs)
+     """
+     if not isinstance(model_or_path, InferenceSessionForTorch):
+         providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
+         providers.append("CPUExecutionProvider")
+         session = InferenceSessionForTorch(model_or_path, providers=providers)
+     else:
+         session = model_or_path
+
+     input_shapes = session.input_shapes
+     input_names = session.input_names
+     input_types = session.input_types
+
+     assert (
+         len(input_names) > 2
+         and input_names[:2] == ["input_ids", "attention_mask"]
+         and input_names[2].startswith("past_key_values")
+     ), f"Only text generation is supported but input_names == {input_names}"
+
+     # First call: prefill
+     feeds = dict(
+         input_ids=input_ids,
+         attention_mask=torch.ones(
+             input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+         ),
+         **make_empty_cache(
+             input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
+         ),
+     )
+
+     outputs = session.run(None, feeds)
+
+     # Next calls: decode
+     for _ in range(max_new_tokens):
+         next_token_logits = outputs[0][:, -1, :]
+
+         # The most probable next token is chosen.
+         next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
+         # But we could select it using a multinomial law:
+         # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
+         # <<< top_probs, top_indices = torch.topk(probs, top_k)
+         # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]
+
+         if next_token_id.item() == eos_token_id:
+             break
+         input_ids = torch.cat([input_ids, next_token_id.to(input_ids.device)], dim=-1)
+         feeds = dict(
+             input_ids=next_token_id,
+             attention_mask=torch.ones(
+                 input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
+             ),
+         )
+         feeds.update(dict(zip(input_names[2:], outputs[1:])))
+         outputs = session.run(None, feeds)
+
+     if return_session:
+         return input_ids, session
+     return input_ids
+
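
The multinomial alternative left commented out in the decode loop, expanded into a self-contained sketch (not part of the diff; indexed with gather so it also works batched, and the temperature and top_k values are arbitrary assumptions):

    # Sketch: top-k temperature sampling instead of greedy argmax.
    import torch

    def sample_next_token(next_token_logits: torch.Tensor,
                          temperature: float = 0.8, top_k: int = 50) -> torch.Tensor:
        # Scale the logits, keep the top_k candidates, then draw one of them.
        probs = torch.softmax(next_token_logits / temperature, dim=-1)
        top_probs, top_indices = torch.topk(probs, top_k)
        choice = torch.multinomial(top_probs, 1)  # index into the top_k set
        return top_indices.gather(-1, choice)     # map back to vocabulary ids
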
+
+ def onnx_generate_with_genai(
+     model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
+     input_ids: torch.Tensor,
+     max_new_tokens: int = 100,
+     return_session: bool = False,
+     transformers_config: Optional[Any] = None,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
+     """
+     Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
+     for an ONNX model. The function does not expect any ``position_ids`` as input.
+
+     :param model_or_path: model or loaded model
+     :param input_ids: input tokens
+     :param max_new_tokens: stops after this number of generated tokens
+     :param return_session: returns the instance of class
+         :class:`InferenceSessionForTorch
+         <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
+         created if necessary
+     :param transformers_config: writes the genai configuration file if it is
+         missing and this configuration is provided
+     :return: input tokens concatenated with new tokens
+
+     See example given with function :func:`onnx_generate
+     <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
+     """
+     import onnxruntime_genai as og
+
+     if not isinstance(model_or_path, og.Model):
+         from .model_builder_helper import make_genai_config
+
+         assert isinstance(
+             model_or_path, str
+         ), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
+         folder = os.path.dirname(model_or_path)
+         assert os.path.exists(folder), f"Folder {folder!r} does not exist."
+         assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
+         config_file = os.path.join(folder, "genai_config.json")
+         if not os.path.exists(config_file):
+             if not transformers_config:
+                 raise FileNotFoundError(
+                     f"Folder {model_or_path!r} does not contain 'genai_config.json'."
+                 )
+             config = make_genai_config(transformers_config, model_or_path)
+             with open(config_file, "w") as f:
+                 json.dump(config, f, indent=4)
+
+         config = og.Config(os.path.dirname(config_file))
+         if input_ids.is_cuda:
+             config.clear_providers()
+             config.append_provider("cuda")
+         session = og.Model(config)
+     else:
+         session = model_or_path
+
+     params = og.GeneratorParams(session)
+     params.set_search_options(
+         max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
+     )
+     generator = og.Generator(session, params)
+
+     # First call: prefill
+     cats = []
+     generator.append_tokens(input_ids)
+     while not generator.is_done():
+         generator.generate_next_token()
+         new_token = generator.get_next_tokens()[0]
+         cats.append(int(new_token))
+
+     input_ids = torch.cat([input_ids, torch.tensor([cats], dtype=torch.int64)], dim=-1)
+     if return_session:
+         return input_ids, session
+     return input_ids
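
A minimal call sketch for onnx_generate_with_genai (not part of the diff; model_name and data["configuration"] come from the example under onnx_generate above):

    # Sketch: generate with onnxruntime-genai; on the first run this writes
    # genai_config.json next to the onnx file when a transformers config is given.
    ids, session = onnx_generate_with_genai(
        model_name,
        input_ids[:1],
        max_new_tokens=10,
        return_session=True,
        transformers_config=data["configuration"],
    )
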