onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/rt_helper.py
@@ -0,0 +1,476 @@
import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
import numpy as np
import onnx
import torch
from .helper import string_type, flatten_object, max_diff
from .torch_helper import torch_deepcopy
from .ort_session import InferenceSessionForTorch


def name_type_to_onnx_dtype(name: str) -> int:
    if name == "tensor(int64)":
        return onnx.TensorProto.INT64
    if name == "tensor(float)":
        return onnx.TensorProto.FLOAT
    if name == "tensor(float16)":
        return onnx.TensorProto.FLOAT16
    raise AssertionError(f"Unexpected value {name!r}")

def make_feeds(
    proto: Union[onnx.ModelProto, List[str]],
    inputs: Any,
    use_numpy: bool = False,
    copy: bool = False,
    check_flatten: bool = True,
    is_modelbuilder: bool = False,
) -> Dict[str, Union[torch.Tensor, np.ndarray]]:
    """
    Serializes the inputs to produce the feeds expected
    by :class:`onnxruntime.InferenceSession`.

    :param proto: onnx model or list of names
    :param inputs: any kind of inputs
    :param use_numpy: if True, converts torch tensors into numpy arrays
    :param copy: a copy is made, this should be the case if the inputs are ingested
        by ``OrtValue``
    :param check_flatten: if True, checks that ``torch.utils._pytree.tree_flatten``
        returns the same number of outputs
    :param is_modelbuilder: if True, the exporter is ModelBuilder; the past_key_values
        inputs are reordered to match the expected order and position_ids is dropped
    :return: feeds dictionary
    """
    # NOTE: position_ids is a special case because ModelBuilder does not usually use it,
    # it is fused into the rotary embedding in GQA.
    if is_modelbuilder and isinstance(inputs, dict):
        inputs.pop("position_ids", None)  # Remove 'position_ids' if it is present.

    flat = flatten_object(inputs, drop_keys=True)
    assert (
        not check_flatten
        or not all(isinstance(obj, torch.Tensor) for obj in flat)
        # or not is_cache_dynamic_registered(fast=True)
        or len(flat) == len(torch.utils._pytree.tree_flatten(inputs)[0])
    ), (
        f"Unexpected number of flattened objects, "
        f"{string_type(flat, with_shape=True)} != "
        f"{string_type(torch.utils._pytree.tree_flatten(inputs)[0], with_shape=True)}"
    )
    if use_numpy:
        from .torch_helper import to_numpy

        flat = [to_numpy(t) if isinstance(t, torch.Tensor) else t for t in flat]
    names = (
        [i.name for i in proto.graph.input]
        if isinstance(proto, onnx.ModelProto)
        else (
            [i.name for i in proto.get_inputs()]
            if hasattr(proto, "get_inputs")
            else (proto.input_names if hasattr(proto, "input_names") else proto)
        )
    )
    assert (
        isinstance(names, list)
        and len(names) <= len(flat)
        and (
            len(names) == len(flat)
            or isinstance(proto, onnx.ModelProto)
            or hasattr(proto, "get_inputs")
        )
    ), (
        f"Not the same number of given inputs {len(flat)} "
        f"and the number of model inputs {len(names)}, "
        f"type(names)={type(names)}, type(proto)={type(proto)}"
        f"\n-- inputs={string_type(inputs, with_shape=True)}"
        f"\n-- names={names}"
    )

    if copy:
        flat = [t.copy() if hasattr(t, "copy") else t.clone() for t in flat]
    # onnxruntime does not accept Python bool, int, or float; convert them to arrays.
    new_flat = []
    for i in flat:
        if isinstance(i, bool):
            i = np.array(i, dtype=np.bool_)
        elif isinstance(i, int):
            i = np.array(i, dtype=np.int64)
        elif isinstance(i, float):
            i = np.array(i, dtype=np.float32)
        new_flat.append(i)
    return dict(zip(names, new_flat))

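A minimal usage sketch of ``make_feeds`` (illustrative only; ``model.onnx`` and its input names ``input_ids``/``attention_mask`` are hypothetical and must match the actual model):

```python
# Sketch: build onnxruntime feeds from a dict of torch tensors.
# "model.onnx" and its input names are hypothetical.
import onnx
import onnxruntime
import torch

from onnx_diagnostic.helpers.rt_helper import make_feeds

proto = onnx.load("model.onnx")
inputs = {
    "input_ids": torch.randint(0, 1000, (1, 8), dtype=torch.int64),
    "attention_mask": torch.ones((1, 8), dtype=torch.int64),
}
# use_numpy=True converts the torch tensors into numpy arrays for onnxruntime.
feeds = make_feeds(proto, inputs, use_numpy=True)
sess = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
outputs = sess.run(None, feeds)
```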
def _get_dim(i: int, s: Union[str, int], batch: int = 1) -> int:
    if isinstance(s, int):
        return s
    if s == "batch":
        return batch
    # Everything else is cache length or sequence length.
    return 0


_DTYPES = {
    "tensor(float)": torch.float32,
    "tensor(float16)": torch.float16,
    "tensor(bfloat16)": torch.bfloat16,
    "tensor(int64)": torch.int64,
    "tensor(int32)": torch.int32,
}


def rt_type_to_torch_dtype(typename: str) -> torch.dtype:
    """Converts a string such as ``tensor(float)`` into a dtype (torch.float32)."""
    return _DTYPES[typename]


def make_empty_cache(
    batch: int,
    onnx_input_names: List[str],
    onnx_input_shapes: List[Tuple[Union[int, str], ...]],
    onnx_input_types: List[str],
) -> Dict[str, torch.Tensor]:
    """
    Creates an empty cache. Example:

    .. code-block:: python

        make_empty_cache(
            1,
            sess.input_names[2:],
            [i.shape for i in sess.get_inputs()[2:]],
            [i.type for i in sess.get_inputs()[2:]],
        )
    """
    feeds = {}
    for name, shape, dtype in zip(onnx_input_names, onnx_input_shapes, onnx_input_types):
        new_shape = tuple(_get_dim(i, s, batch=batch) for i, s in enumerate(shape))
        feeds[name] = torch.empty(new_shape, dtype=rt_type_to_torch_dtype(dtype))
    return feeds

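A short sketch of how ``make_empty_cache`` and ``_get_dim`` interact; the input names and symbolic dimensions below are hypothetical:

```python
# Sketch: make_empty_cache builds zero-length past_key_values tensors for the prefill call.
# The input names and symbolic dimensions are hypothetical.
import torch

from onnx_diagnostic.helpers.rt_helper import make_empty_cache

cache = make_empty_cache(
    batch=1,
    onnx_input_names=["past_key_values.0.key", "past_key_values.0.value"],
    onnx_input_shapes=[("batch", 2, "past_sequence_length", 64)] * 2,
    onnx_input_types=["tensor(float)"] * 2,
)
# "batch" is replaced by 1 and every other symbolic dimension becomes 0,
# so each tensor has shape (1, 2, 0, 64) and dtype torch.float32.
print({k: (tuple(v.shape), v.dtype) for k, v in cache.items()})
```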
def generate_and_validate(
    model,
    input_ids: torch.Tensor,
    eos_token_id: int,
    max_new_tokens: int = 100,
    session: Optional[Union[InferenceSessionForTorch, onnx.ModelProto, str]] = None,
    atol: float = 0.1,
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[Dict]]]:
    """
    Implements a simple method ``generate`` for a torch model.
    The function does not expect any ``position_ids`` as input.
    The function also checks that the outputs coming from an onnx model
    are close to the outputs the torch model produces.

    :param model: torch model
    :param input_ids: input tokens
    :param eos_token_id: token representing the end of an answer
    :param max_new_tokens: stops after this number of generated tokens
    :param session: the onnx model
    :param atol: absolute tolerance used to compare torch and onnx outputs
    :return: input tokens concatenated with new tokens;
        if session is not None, it also returns the maximum differences
        at every iteration

    See example given with function :func:`onnx_generate
    <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
    """
    if session is not None:
        if not isinstance(session, InferenceSessionForTorch):
            providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
            providers.append("CPUExecutionProvider")
            session = InferenceSessionForTorch(session, providers=providers)

    # First call: prefill
    attention_mask = torch.ones(
        input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
    )
    if session:
        feeds = {
            **dict(zip(session.input_names[:2], [input_ids, attention_mask])),
            **make_empty_cache(
                input_ids.shape[0],
                session.input_names[2:],
                session.input_shapes[2:],
                session.input_types[2:],
            ),
        }
        onnx_results = session.run(None, feeds)

    outputs = model(input_ids, use_cache=True, attention_mask=attention_mask)

    if session:
        diff = max_diff(outputs, onnx_results)
        assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
            f"Unexpected issue with {type(model)}\ndiff={diff}"
            f"\ninput_ids.shape={input_ids.shape}"
            f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
            f"\n got=\n"
            f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
            f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
        )
        diffs = [diff]

    # Next calls: decode
    for iteration in range(max_new_tokens):
        next_token_logits = outputs.logits[:, -1, :]
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        if next_token_id.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        attention_mask = torch.ones(
            input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
        )
        if session:
            feeds = dict(
                zip(
                    session.input_names,
                    [
                        t.detach()
                        for t in torch_deepcopy(
                            flatten_object(
                                [next_token_id, attention_mask, outputs.past_key_values]
                            )
                        )
                    ],
                )
            )
            onnx_results = session.run(None, feeds)
        outputs = model(
            next_token_id,
            use_cache=True,
            past_key_values=outputs.past_key_values,
            attention_mask=attention_mask,
        )
        if session:
            diff = max_diff(outputs, onnx_results)
            assert isinstance(diff["abs"], float) and diff["abs"] <= atol, (
                f"Unexpected issue with {type(model)}, iteration={iteration}"
                f"\ndiff={diff}\ninput_ids.shape={input_ids.shape}"
                f"\nexpected={string_type(outputs, with_shape=True, with_min_max=True)}"
                f"\n got=\n"
                f"{string_type(onnx_results, with_shape=True, with_min_max=True)}\n"
                f"feeds={string_type(feeds, with_shape=True, with_min_max=True)}"
            )
            diffs.append(diff)
    if session:
        return input_ids, diffs
    return input_ids

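A sketch of calling ``generate_and_validate`` to compare a transformers model with an exported ONNX file step by step; ``dump_test/model.onnx`` is hypothetical and must have been exported from the same model (see the full example in the next function's docstring):

```python
# Sketch: greedy decoding with a torch model while validating every step
# against an exported ONNX file. "dump_test/model.onnx" is hypothetical.
import torch
from transformers import AutoModelForCausalLM

from onnx_diagnostic.helpers import string_diff
from onnx_diagnostic.helpers.rt_helper import generate_and_validate

model = AutoModelForCausalLM.from_pretrained("arnir0/Tiny-LLM")
input_ids = torch.randint(0, 100, (1, 8), dtype=torch.int64)
tokens, diffs = generate_and_validate(
    model,
    input_ids,
    eos_token_id=2,
    max_new_tokens=5,
    session="dump_test/model.onnx",
)
for i, d in enumerate(diffs):
    print(f"iteration {i}: {string_diff(d)}")
```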
def onnx_generate(
    model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
    input_ids: torch.Tensor,
    eos_token_id: int,
    max_new_tokens=100,
    return_session: bool = False,
) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
    """
    Implements a simple method ``generate`` for an ONNX model.
    The function does not expect any ``position_ids`` as input.

    :param model_or_path: model or loaded model
    :param input_ids: input tokens
    :param eos_token_id: token representing the end of an answer
    :param max_new_tokens: stops after this number of generated tokens
    :param return_session: returns the instance of class
        :class:`InferenceSessionForTorch
        <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
        created if necessary
    :return: input tokens concatenated with new tokens

    .. runpython::
        :showcode:

        import os
        from onnx_diagnostic.helpers import string_type, string_diff
        from onnx_diagnostic.helpers.rt_helper import (
            onnx_generate,
            generate_and_validate,
            onnx_generate_with_genai,
        )
        from onnx_diagnostic.torch_models.hghub import get_untrained_model_with_inputs
        from onnx_diagnostic.torch_export_patches import torch_export_patches
        from onnx_diagnostic.export.api import to_onnx

        mid = "arnir0/Tiny-LLM"
        print(f"-- get model for {mid!r}")
        data = get_untrained_model_with_inputs(mid)
        model, inputs, ds = data["model"], data["inputs"], data["dynamic_shapes"]
        del inputs["position_ids"]
        del ds["position_ids"]
        input_ids = inputs["input_ids"]

        print(f"-- input_ids={input_ids.shape}")
        print(f"-- inputs: {string_type(inputs, with_shape=True)}")
        print(f"-- dynamic_shapes: {string_type(ds)}")
        folder = "dump_test"
        os.makedirs(folder, exist_ok=True)
        model_name = os.path.join(folder, "model.onnx")
        print("-- test_onnx_generate: export model")
        with torch_export_patches(patch_transformers=True, patch_torch=False):
            to_onnx(
                model,
                (),
                kwargs=inputs,
                dynamic_shapes=ds,
                filename=model_name,
                exporter="custom",  # custom, dynamo or onnx-dynamo, modelbuilder
            )

        print("-- generate with onnx")
        onnx_outputs = onnx_generate(model_name, input_ids[:1], 2, max_new_tokens=10)
        print("-- onnx output", onnx_outputs)

        # The example continues with other functions doing the same.
        print("-- generate with pytorch")
        torch_outputs, diffs = generate_and_validate(
            model, input_ids[:1], 2, max_new_tokens=10, session=model_name
        )
        print("-- torch output", torch_outputs)
        print("-- differences at each step:")
        for i, d in enumerate(diffs):
            print(f"iteration {i}: {string_diff(d)}")

        print("-- generate with genai")
        genai_outputs, session = onnx_generate_with_genai(
            model_name,
            input_ids[:1],
            max_new_tokens=10,
            return_session=True,
            transformers_config=data["configuration"],
        )
        print("-- genai output", genai_outputs)
    """
    if not isinstance(model_or_path, InferenceSessionForTorch):
        providers = ["CUDAExecutionProvider"] if input_ids.is_cuda else []
        providers.append("CPUExecutionProvider")
        session = InferenceSessionForTorch(model_or_path, providers=providers)
    else:
        session = model_or_path

    input_shapes = session.input_shapes
    input_names = session.input_names
    input_types = session.input_types

    assert (
        len(input_names) > 2
        and input_names[:2] == ["input_ids", "attention_mask"]
        and input_names[2].startswith("past_key_values")
    ), f"Only text generation is supported but input_names == {input_names}"

    # First call: prefill
    feeds = dict(
        input_ids=input_ids,
        attention_mask=torch.ones(
            input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
        ),
        **make_empty_cache(
            input_ids.shape[0], input_names[2:], input_shapes[2:], input_types[2:]
        ),
    )

    outputs = session.run(None, feeds)

    # Next calls: decode
    for _ in range(max_new_tokens):
        next_token_logits = outputs[0][:, -1, :]

        # The most probable next token is chosen.
        next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        # But it could be sampled from a multinomial distribution instead:
        # <<< probs = torch.softmax(next_token_logits / temperature, dim=-1)
        # <<< top_probs, top_indices = torch.topk(probs, top_k)
        # <<< next_token_id = top_indices[torch.multinomial(top_probs, 1)]

        if next_token_id.item() == eos_token_id:
            break
        input_ids = torch.cat([input_ids, next_token_id.to(input_ids.device)], dim=-1)
        feeds = dict(
            input_ids=next_token_id,
            attention_mask=torch.ones(
                input_ids.shape, dtype=input_ids.dtype, device=input_ids.device
            ),
        )
        feeds.update(dict(zip(input_names[2:], outputs[1:])))
        outputs = session.run(None, feeds)

    if return_session:
        return input_ids, session
    return input_ids

def onnx_generate_with_genai(
    model_or_path: Union[onnx.ModelProto, str, InferenceSessionForTorch],
    input_ids: torch.Tensor,
    max_new_tokens=100,
    return_session: bool = False,
    transformers_config: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, InferenceSessionForTorch]]:
    """
    Uses :epkg:`onnxruntime-genai` to implement a simple method ``generate``
    for an ONNX model. The function does not expect any ``position_ids`` as input.

    :param model_or_path: model or loaded model
    :param input_ids: input tokens
    :param max_new_tokens: stops after this number of generated tokens
    :param return_session: returns the instance of class
        :class:`InferenceSessionForTorch
        <onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch>`
        created if necessary
    :param transformers_config: transformers configuration used to write
        ``genai_config.json`` when that file is missing
    :return: input tokens concatenated with new tokens

    See example given with function :func:`onnx_generate
    <onnx_diagnostic.helpers.rt_helper.onnx_generate>`.
    """
    import onnxruntime_genai as og

    if not isinstance(model_or_path, og.Model):
        from .model_builder_helper import make_genai_config

        assert isinstance(
            model_or_path, str
        ), f"Only a filename is allowed for model_or_path but type is {type(model_or_path)}"
        folder = os.path.dirname(model_or_path)
        assert os.path.exists(folder), f"Folder {folder!r} does not exist."
        assert os.path.exists(model_or_path), f"File {model_or_path!r} does not exist."
        config_file = os.path.join(folder, "genai_config.json")
        if not os.path.exists(config_file):
            if not transformers_config:
                raise FileNotFoundError(
                    f"Folder {folder!r} does not contain 'genai_config.json'."
                )
            config = make_genai_config(transformers_config, model_or_path)
            with open(config_file, "w") as f:
                json.dump(config, f, indent=4)

        config = og.Config(os.path.dirname(config_file))
        if input_ids.is_cuda:
            config.clear_providers()
            config.append_provider("cuda")
        session = og.Model(config)
    else:
        session = model_or_path

    params = og.GeneratorParams(session)
    params.set_search_options(
        max_length=max_new_tokens + input_ids.shape[1], batch_size=input_ids.shape[0]
    )
    generator = og.Generator(session, params)

    # First call: prefill
    cats = []
    generator.append_tokens(input_ids)
    while not generator.is_done():
        generator.generate_next_token()
        new_token = generator.get_next_tokens()[0]
        cats.append(int(new_token))

    input_ids = torch.cat([input_ids, torch.tensor([cats], dtype=torch.int64)], dim=-1)
    if return_session:
        return input_ids, session
    return input_ids
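Finally, a short sketch (illustrative, using only calls shown above) of how ``return_session=True`` lets the session created by ``onnx_generate`` be reused for further prompts; ``dump_test/model.onnx`` is hypothetical and must follow the input layout asserted in ``onnx_generate``:

```python
# Sketch: reuse the InferenceSessionForTorch created by the first call.
import torch

from onnx_diagnostic.helpers.rt_helper import onnx_generate

prompt1 = torch.randint(0, 100, (1, 8), dtype=torch.int64)
prompt2 = torch.randint(0, 100, (1, 12), dtype=torch.int64)

# The first call loads the hypothetical model and returns the session with the tokens.
out1, session = onnx_generate(
    "dump_test/model.onnx", prompt1, eos_token_id=2, max_new_tokens=10, return_session=True
)
# Later calls pass the session directly and skip the loading step.
out2 = onnx_generate(session, prompt2, eos_token_id=2, max_new_tokens=10)
print(out1.shape, out2.shape)
```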