onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,469 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import importlib.util
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import requests
|
|
6
|
+
import sys
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, List, Optional, Union
|
|
9
|
+
from urllib.parse import urlparse
|
|
10
|
+
from onnx import ModelProto, TensorProto, load as load_model
|
|
11
|
+
|
|
12
|
+
# Name of the subdirectory under ``~/.cache`` where downloaded files are stored.
CACHE_SUBDIR = "onnx-diagnostic"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def download_model_builder_to_cache(
    url: str = "https://raw.githubusercontent.com/microsoft/onnxruntime-genai/refs/heads/main/src/python/py/models/builder.py",
) -> Path:
    """
    Downloads ``builder.py`` from the
    ``https://github.com/microsoft/onnxruntime-genai/blob/main/src/python/py/models/builder.py``.

    The file is stored in ``~/.cache/onnx-diagnostic`` and the download is
    skipped when a cached copy already exists.

    :param url: location of the raw ``builder.py`` file
    :return: path to the cached file
    """
    filename = os.path.basename(urlparse(url).path)
    cache_dir = Path(os.getenv("HOME", Path.home())) / ".cache" / CACHE_SUBDIR
    cache_dir.mkdir(parents=True, exist_ok=True)

    file_path = cache_dir / filename

    # Cached copy wins: no network access on subsequent calls.
    if file_path.exists():
        return file_path

    # A timeout makes a hung download fail instead of blocking forever.
    response = requests.get(url, timeout=60)
    response.raise_for_status()
    file_path.write_bytes(response.content)
    return file_path
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def import_model_builder(module_name: str = "builder") -> object:
    """
    Imports the downloaded ``builder.py``.

    :param module_name: stem of the cached python file to import
    :return: the imported module
    """
    # Reuse an already imported module to avoid re-executing builder.py.
    if module_name in sys.modules:
        return sys.modules[module_name]
    path = Path(os.getenv("HOME", Path.home())) / ".cache" / CACHE_SUBDIR
    module_file = path / f"{module_name}.py"
    assert os.path.exists(module_file), f"Unable to find {module_file!r}"
    # Fix: the spec must be built from the module *file*, not its directory
    # (passing the directory made spec_from_file_location return None).
    spec = importlib.util.spec_from_file_location(module_name, str(module_file))
    if spec is None:
        # Fallback: make the cache directory importable and use a plain import.
        spath = str(path)
        if spath not in sys.path:
            sys.path.append(spath)
        module = importlib.__import__(module_name)
        return module
    module = importlib.util.module_from_spec(spec)
    # Register before exec_module so self-imports inside builder.py resolve.
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _make_model(self, model, verbose: int = 0):
    """
    Walks the torch modules of ``model`` and maps each one to ONNX/ORT ops.

    The embedding, each decoder layer, the final normalization and the
    language modeling head are detected in that order and converted through
    the builder's ``make_*`` methods.

    :param model: torch module to translate
    :param verbose: verbosity level
    """
    import torch

    # Declare graph inputs/outputs and pre-processing nodes first.
    self.make_inputs_and_outputs()
    self.make_preprocessing_nodes()

    self.layer_id = 0
    for sub in model.modules():
        cls_name = type(sub).__name__
        if (
            isinstance(sub, torch.nn.Embedding) and sub.weight.shape[0] == self.vocab_size
        ) or (hasattr(model, "embedding") and sub == model.embedding):
            # Embedding detection: (Hugging Face logic) or (GGUF logic).
            if self.exclude_embeds:
                # Embedding excluded: the graph consumes 'inputs_embeds' directly.
                self.layernorm_attrs["root_input"] = "inputs_embeds"
                self.layernorm_attrs["skip_input"] = "inputs_embeds"
            else:
                if verbose:
                    print("[_make_model] Reading embedding layer")
                self.make_embedding(sub.weight.detach().cpu())

        elif (
            cls_name.endswith(("DecoderLayer", "GLMBlock"))
            and self.layer_id < self.num_layers
        ):
            # One decoder block per layer, counted by self.layer_id.
            if verbose:
                print(f"[_make_model] Reading decoder layer {self.layer_id}")
            self.make_layer(self.layer_id, sub)
            self.layer_id += 1

        elif self.layer_id == self.num_layers and self.has_final_norm(sub, model):
            # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm).
            if verbose:
                print("[_make_model] Reading final norm")
            self.make_layernorm(
                self.layer_id,
                sub,
                skip=True,
                simple=self.layernorm_attrs["simple"],
                location="final_norm",
            )

        elif (
            isinstance(sub, torch.nn.Linear) and sub.out_features == self.vocab_size
        ) or (hasattr(model, "lm_head") and sub == model.lm_head):
            # LM head detection: (Hugging Face logic) or (GGUF logic).
            if not self.exclude_lm_head:
                # Language modeling head (SkipLayerNorm --> logits).
                if verbose:
                    print("[_make_model] Reading LM head")
                self.make_lm_head(sub)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def save_model_builder(
    self, out_dir: Optional[str] = "", verbose: int = 0
) -> Union[str, ModelProto]:
    """
    Saves a model created by function :func:`create_model_builder`.

    :param out_dir: output directory; if empty or not specified, nothing is
        written and the generated model is returned as a :class:`ModelProto`
    :param verbose: verbosity level
    :return: the output path when saved to disk, otherwise the model proto
    """
    import onnx_ir

    if verbose:
        print(f"[save_model_builder] Saving ONNX model in {out_dir!r}")

    # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
    already_quantized_in_qdq_format = (
        self.quant_type is not None and self.quant_attrs["use_qdq"]
    )
    model = (
        self.to_int4()
        if self.onnx_dtype in {onnx_ir.DataType.INT4, onnx_ir.DataType.UINT4}
        and not already_quantized_in_qdq_format
        else self.model
    )
    model.graph.sort()
    if not out_dir:
        return onnx_ir.to_proto(model)

    # Save ONNX model with only one external data file; remove any stale
    # copies first.  (The original computed these two paths twice.)
    out_path = os.path.join(out_dir, self.filename)
    data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
    for existing in (out_path, data_path):
        if os.path.exists(existing):
            if verbose:
                print(f"[save_model_builder] Overwriting {existing!r}")
            os.remove(existing)

    onnx_ir.save(
        model,
        out_path,
        external_data=os.path.basename(data_path),
        # Tensors above 1 kiB go to the external data file.
        size_threshold_bytes=2**10,
    )
    if verbose:
        print(f"[save_model_builder] saved in {out_dir!r}")

    return out_path
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def create_model_builder(
    config: Any,
    model: "torch.nn.Module",  # noqa: F821
    cache_dir: str,
    precision: str = "fp32",
    execution_provider: str = "cpu",
    verbose: int = 0,
    **extra_options,
) -> "Model":  # noqa: F821
    """
    Creates a model based on a configuration.
    The onnx model is returned by function :func:`save_model_builder`.

    :param config: configuration
    :param model: torch module whose weights are read to build the ONNX model
    :param cache_dir: cache directory
    :param precision: precision (``"fp32"``, ``"fp16"``, ``"bfp16"``; long
        names such as ``"float32"`` are normalized)
    :param execution_provider: execution provider
    :param verbose: verbosity
    :param extra_options: extra options
    :return: model
    """
    assert cache_dir, "create_model_builder does not work without cache_dir."
    assert os.path.exists(cache_dir), f"cache_dir={cache_dir!r} does not exists"
    # Normalize transformers-style dtype names to ModelBuilder's short forms.
    precision = {"float32": "fp32", "float16": "fp16", "bfp16": "bfp16"}.get(
        precision, precision
    ) if False else {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}.get(
        precision, precision
    )
    # Ensure builder.py is present in the local cache, then import it.
    download_model_builder_to_cache()
    builder = import_model_builder()
    io_dtype = builder.set_io_dtype(precision, execution_provider, extra_options)

    # Maps a Hugging Face architecture name to the ModelBuilder class handling it.
    arch_map = {
        "ChatGLMForConditionalGeneration": builder.ChatGLMModel,
        "ChatGLMModel": builder.ChatGLMModel,
        "Ernie4_5_ForCausalLM": builder.ErnieModel,
        "GemmaForCausalLM": builder.Gemma2Model,
        "Gemma2ForCausalLM": builder.Gemma2Model,
        "Gemma3ForCausalLM": builder.Gemma3Model,
        "Gemma3ForConditionalGeneration": builder.Gemma3Model,
        "GraniteForCausalLM": builder.GraniteModel,
        "GptOssForCausalLM": builder.GPTOSSModel,
        "LlamaForCausalLM": builder.LlamaModel,
        "MistralForCausalLM": builder.MistralModel,
        "NemotronForCausalLM": builder.NemotronModel,
        "OlmoForCausalLM": builder.OLMoModel,
        "PhiForCausalLM": builder.PhiModel,
        # Phi3 mini/small: the class is chosen lazily because it depends on
        # whether the config uses LongRoPE (extended context window).
        "Phi3ForCausalLM": (
            lambda config, *args: (
                (
                    builder.Phi3MiniModel
                    if config.max_position_embeddings
                    == config.original_max_position_embeddings
                    else builder.Phi3MiniLongRoPEModel
                )(config, *args)
            )
        ),
        "PhiMoEForCausalLM": builder.Phi3MoELongRoPEModel,
        "Phi3SmallForCausalLM": (
            lambda config, *args: (
                (
                    builder.Phi3SmallModel
                    if config.max_position_embeddings
                    == config.original_max_position_embeddings
                    else builder.Phi3SmallLongRoPEModel
                )(config, *args)
            )
        ),
        "Phi3VForCausalLM": builder.Phi3VModel,
        "Phi4MMForCausalLM": builder.Phi4MMModel,
        "Qwen2ForCausalLM": builder.QwenModel,
        "Qwen3ForCausalLM": builder.Qwen3Model,
        "SmolLM3ForCausalLM": builder.SmolLM3Model,
    }

    assert config.architectures[0] in arch_map, (
        f"Unable find {config.architectures[0]!r} in the supported list "
        f"of architectures: {sorted(arch_map)}"
    )

    # Additional validations.
    post = None
    if config.architectures[0] in ("ChatGLMForConditionalGeneration", "ChatGLMModel"):
        # Quantized ChatGLM model has ChatGLMForConditionalGeneration
        # as architecture whereas HF model as the latter
        config.hidden_act = "swiglu"
    elif config.architectures[0] == "Gemma2ForCausalLM":
        assert precision == "bfp16", (
            f"architecture {config.architectures[0]!r} loses accuracy "
            f"with float16 precision, use bfp16."
        )
    elif config.architectures[0] == "Gemma3ForCausalLM":
        assert precision == "bfp16", (
            f"architecture {config.architectures[0]!r} loses accuracy "
            f"with float16 precision, use bfp16."
        )

        # Post-construction hook: tag the text-only Gemma3 model type.
        def _post(onnx_model):
            onnx_model.model_type = "gemma3_text"

        post = _post
    elif config.architectures[0] == "Gemma3ForConditionalGeneration":
        assert extra_options.get("exclude_embeds", False), (
            f"This is only generating the text component of architecture "
            f"{config.architectures[0]!r}. Set extra_options exclude_embeds=true."
        )
        assert precision == "bfp16", (
            f"architecture {config.architectures[0]!r} loses accuracy "
            f"with float16 precision, use bfp16."
        )
        # Promote text_config entries to the top-level config so the
        # builder finds them there.
        text_config = config.text_config
        for key in text_config:
            if not hasattr(config, key):
                setattr(config, key, getattr(text_config, key))
    elif config.architectures[0] == "GptOssForCausalLM":
        # NOTE(review): presumably the builder cannot consume this
        # attribute for GptOss -- confirm before relying on it.
        delattr(config, "quantization_config")
    elif (
        config.architectures[0] == "PhiMoEForCausalLM"
        and config.max_position_embeddings != config.original_max_position_embeddings
    ):
        assert execution_provider == "cuda", (
            f"architecture {config.architectures[0]!r} works on 'cuda' "
            f"because `MoE` is only supported for CUDA in ONNX Runtime."
        )
        # NOTE(review): this contradicts the fp32/fp16/bfp16 check further
        # below -- precision 'int4' would fail the `precision in convert`
        # assertion. Confirm the intended behaviour.
        assert precision == "int4", f"architecture {config.architectures[0]!r} supports int4."
    elif config.architectures[0] == "Phi3VForCausalLM":
        assert extra_options.get("exclude_embeds", False), (
            f"This is only generating the text component of architecture "
            f"{config.architectures[0]!r}. Set extra_options exclude_embeds=true."
        )
    elif config.architectures[0] == "Phi4MMForCausalLM":
        assert extra_options.get("exclude_embeds", False), (
            f"This is only generating the text component of architecture "
            f"{config.architectures[0]!r}. Set extra_options exclude_embeds=true."
        )

    cls = arch_map[config.architectures[0]]

    # ModelBuilder does not like None values for some parameters.
    remove = set()
    for c in ["head_dim"]:
        if hasattr(config, c) and getattr(config, c) is None:
            remove.add(c)
    for c in remove:
        delattr(config, c)

    # Translate the short precision name to the ONNX TensorProto dtype.
    convert = {
        "fp32": TensorProto.FLOAT,
        "fp16": TensorProto.FLOAT16,
        "bfp16": TensorProto.BFLOAT16,
    }
    assert (
        precision in convert
    ), f"Unexpected value for precision={precision!r}, should be in {convert}"
    onnx_model = cls(
        config, io_dtype, convert[precision], execution_provider, cache_dir, extra_options
    )

    if post:
        post(onnx_model)
    # Read the torch weights and populate the ONNX graph.
    _make_model(onnx_model, model, verbose=verbose)

    assert onnx_model.model, (
        f"No node in the model, io_dtype={io_dtype!r}, "
        f"precision={precision!r}, execution_provider={execution_provider!r}, "
        f"extra_options={extra_options!r}, cache_dir={cache_dir!r}, "
        f"\n-- config --\n{config}"
    )
    # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
    # onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
    return onnx_model
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def find_names_pattern(names: List[str]) -> str:
    """
    Finds a repeatable pattern in a list of names by replacing every run of
    digits with ``%d``.  All names must reduce to the same template.

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
        pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
        print(pattern)

    :param names: list of names sharing one numbering scheme
    :return: the common template with ``%d`` placeholders
    """
    templates = [re.sub(r"(\d+)", r"%d", name) for name in names]
    unique = set(templates)
    assert (
        len(unique) == 1
    ), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
    return templates[0]
|
|
362
|
+
|
|
363
|
+
|
|
364
|
+
def make_genai_config(
    config,
    onnx_filename: str,
) -> Dict:
    """
    Creates genai config file for a model.

    :param config: configuration from transformers
    :param onnx_filename: path to the ONNX model the configuration describes
    :return: configuration as a dictionary
    """
    # Only the graph signature is needed, the weights stay on disk.
    onx = load_model(onnx_filename, load_external_data=False)
    # Work on a copy so the caller's configuration is not mutated.
    config = copy.deepcopy(config)
    defaults = {
        "bos_token_id": None,
        "do_sample": False,
        "eos_token_id": None,
        "pad_token_id": None,
        "temperature": 1.0,
        "top_k": 50,
        "top_p": 1.0,
    }
    for key, default_val in defaults.items():
        if not hasattr(config, key):
            setattr(config, key, default_val)

    # The defaults loop guarantees these attributes exist, only None-ness
    # needs checking.
    bos_token_id = config.bos_token_id if config.bos_token_id is not None else 1
    eos_token_id = config.eos_token_id
    pad_token_id = (
        config.pad_token_id
        if config.pad_token_id is not None
        else (
            config.eos_token_id[0]
            if isinstance(config.eos_token_id, list)
            else config.eos_token_id
        )
    )
    input_names = [i.name for i in onx.graph.input]
    output_names = [i.name for i in onx.graph.output]
    past_key_values = [s for s in input_names if s.startswith("past_key_value")]
    # Fail with a clear message instead of an IndexError below.
    assert past_key_values, (
        f"Unable to find any input starting with 'past_key_value' "
        f"among {input_names}, cannot infer the cache layout."
    )
    first = next(i for i in onx.graph.input if i.name == past_key_values[0])
    shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
    # Assumes inputs are ordered input_ids, attention_mask, then alternating
    # key/value cache tensors -- TODO confirm against the exporter.
    return {
        "model": {
            "bos_token_id": bos_token_id,
            "context_length": config.max_position_embeddings,
            "decoder": {
                "session_options": {
                    "log_id": "onnxruntime-genai",
                    "provider_options": [],
                },
                "filename": os.path.split(onnx_filename)[-1],
                "head_size": shape[-1],
                "hidden_size": config.hidden_size,
                "inputs": {
                    "input_ids": input_names[0],
                    "attention_mask": input_names[1],
                    "past_key_names": find_names_pattern(input_names[2::2]),
                    "past_value_names": find_names_pattern(input_names[3::2]),
                },
                "outputs": {
                    "logits": output_names[0],
                    "present_key_names": find_names_pattern(output_names[1::2]),
                    "present_value_names": find_names_pattern(output_names[2::2]),
                },
                "num_attention_heads": config.num_attention_heads,
                # Each layer contributes one key and one value cache input.
                "num_hidden_layers": len(past_key_values) // 2,
                "num_key_value_heads": shape[1],
            },
            "eos_token_id": eos_token_id,
            "pad_token_id": pad_token_id,
            "type": config.model_type,
            "vocab_size": config.vocab_size,
        },
        "search": {
            "diversity_penalty": getattr(config, "diversity_penalty", 0.0),
            "do_sample": config.do_sample,
            "early_stopping": True,
            "length_penalty": getattr(config, "length_penalty", 1.0),
            "max_length": config.max_position_embeddings,
            "min_length": 0,
            "no_repeat_ngram_size": getattr(config, "no_repeat_ngram_size", 0),
            "num_beams": getattr(config, "num_beams", 1),
            "num_return_sequences": getattr(config, "num_return_sequences", 1),
            "past_present_share_buffer": False,
            "repetition_penalty": getattr(config, "repetition_penalty", 1.0),
            "temperature": config.temperature,
            "top_k": config.top_k,
            "top_p": config.top_p,
        },
    }
|