onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/helpers/model_builder_helper.py
@@ -0,0 +1,469 @@
import copy
import importlib.util
import os
import re
import requests
import sys
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
from onnx import ModelProto, TensorProto, load as load_model

CACHE_SUBDIR = "onnx-diagnostic"


def download_model_builder_to_cache(
    url: str = "https://raw.githubusercontent.com/microsoft/onnxruntime-genai/refs/heads/main/src/python/py/models/builder.py",
):
    """
    Downloads ``builder.py`` from
    ``https://github.com/microsoft/onnxruntime-genai/blob/main/src/python/py/models/builder.py``
    into the local cache and returns the cached path.
    """
    filename = os.path.basename(urlparse(url).path)
    cache_dir = Path(os.getenv("HOME", Path.home())) / ".cache" / CACHE_SUBDIR
    cache_dir.mkdir(parents=True, exist_ok=True)

    file_path = cache_dir / filename

    # Reuse the cached copy if it is already there.
    if file_path.exists():
        return file_path

    response = requests.get(url)
    response.raise_for_status()
    with open(file_path, "wb") as f:
        f.write(response.content)
    return file_path


def import_model_builder(module_name: str = "builder") -> object:
    """Imports the downloaded ``builder.py``."""
    if module_name in sys.modules:
        return sys.modules[module_name]
    path = Path(os.getenv("HOME", Path.home())) / ".cache" / CACHE_SUBDIR
    module_file = path / f"{module_name}.py"
    assert os.path.exists(module_file), f"Unable to find {module_file!r}"
    # ``spec_from_file_location`` expects the module file itself, not its folder.
    spec = importlib.util.spec_from_file_location(module_name, str(module_file))
    if spec is None:
        # Fallback: put the cache directory on sys.path and import by name.
        spath = str(path)
        if spath not in sys.path:
            sys.path.append(spath)
        module = importlib.__import__(module_name)
        return module
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module
    spec.loader.exec_module(module)
    return module


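# A minimal usage sketch, assuming network access to raw.githubusercontent.com:
# the two helpers above are meant to be chained, download once, then import the
# cached copy.
if __name__ == "__main__":
    cached = download_model_builder_to_cache()  # ~/.cache/onnx-diagnostic/builder.py
    builder = import_model_builder()
    print(f"cached at {cached}, imported {builder!r}")

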
def _make_model(self, model, verbose: int = 0):
    # Make inputs and outputs to ONNX model
    import torch

    self.make_inputs_and_outputs()

    # Make pre-processing nodes
    self.make_preprocessing_nodes()

    # Loop through model and map each module to ONNX/ORT ops
    self.layer_id = 0
    for module in model.modules():
        if (
            isinstance(module, torch.nn.Embedding)
            and module.weight.shape[0] == self.vocab_size
        ) or (hasattr(model, "embedding") and module == model.embedding):
            # Checks (Hugging Face logic) or (GGUF logic)
            if not self.exclude_embeds:
                # Embedding layer
                if verbose:
                    print("[_make_model] Reading embedding layer")
                self.make_embedding(module.weight.detach().cpu())
            else:
                # Exclude embedding layer from model
                self.layernorm_attrs["root_input"] = "inputs_embeds"
                self.layernorm_attrs["skip_input"] = "inputs_embeds"

        elif (
            module.__class__.__name__.endswith("DecoderLayer")
            or module.__class__.__name__.endswith("GLMBlock")
        ) and self.layer_id < self.num_layers:
            # Each decoder layer of model
            if verbose:
                print(f"[_make_model] Reading decoder layer {self.layer_id}")
            self.make_layer(self.layer_id, module)
            self.layer_id += 1

        elif self.layer_id == self.num_layers and self.has_final_norm(module, model):
            # SkipLayerNorm after last decoder layer (MatMul --> SkipLayerNorm)
            if verbose:
                print("[_make_model] Reading final norm")
            self.make_layernorm(
                self.layer_id,
                module,
                skip=True,
                simple=self.layernorm_attrs["simple"],
                location="final_norm",
            )

        elif (
            isinstance(module, torch.nn.Linear) and module.out_features == self.vocab_size
        ) or (hasattr(model, "lm_head") and module == model.lm_head):
            # Checks (Hugging Face logic) or (GGUF logic)
            if not self.exclude_lm_head:
                # Language modeling head (SkipLayerNorm --> logits)
                if verbose:
                    print("[_make_model] Reading LM head")
                self.make_lm_head(module)


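# The loop in ``_make_model`` relies on ``torch.nn.Module.modules()`` yielding
# submodules in registration order: embedding first, then the decoder layers,
# then the final norm and the LM head. A sketch of that assumption on a toy
# model (``TinyLM`` is hypothetical, for illustration only):
if __name__ == "__main__":
    import torch

    class TinyLM(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embed = torch.nn.Embedding(10, 4)
            self.layer = torch.nn.Linear(4, 4)
            self.lm_head = torch.nn.Linear(4, 10)

    print([m.__class__.__name__ for m in TinyLM().modules()])
    # -> ['TinyLM', 'Embedding', 'Linear', 'Linear']

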
def save_model_builder(
    self, out_dir: Optional[str] = "", verbose: int = 0
) -> Union[str, ModelProto]:
    """
    Saves a model created by function :func:`create_model_builder`.
    If ``out_dir`` is empty or not specified, the function returns the
    generated model as a :class:`onnx.ModelProto` instead of writing it to disk.
    """
    import onnx_ir

    if verbose:
        print(f"[save_model_builder] Saving ONNX model in {out_dir!r}")

    # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
    already_quantized_in_qdq_format = (
        self.quant_type is not None and self.quant_attrs["use_qdq"]
    )
    model = (
        self.to_int4()
        if self.onnx_dtype in {onnx_ir.DataType.INT4, onnx_ir.DataType.UINT4}
        and not already_quantized_in_qdq_format
        else self.model
    )
    model.graph.sort()
    if not out_dir:
        return onnx_ir.to_proto(model)

    # Save ONNX model with only one external data file and delete any existing
    # duplicate copies.
    out_path = os.path.join(out_dir, self.filename)
    data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
    if os.path.exists(out_path):
        if verbose:
            print(f"[save_model_builder] Overwriting {out_path!r}")
        os.remove(out_path)
    if os.path.exists(data_path):
        if verbose:
            print(f"[save_model_builder] Overwriting {data_path!r}")
        os.remove(data_path)

    onnx_ir.save(
        model,
        out_path,
        external_data=os.path.basename(data_path),
        size_threshold_bytes=2**10,
    )
    if verbose:
        print(f"[save_model_builder] saved in {out_dir!r}")

    return out_path


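# ``onnx_ir.save`` above keeps every tensor larger than 1 KiB
# (``size_threshold_bytes=2**10``) in a single side file ``<name>.onnx.data``.
# ``onnx.load`` resolves that file relative to the model path, so reading the
# result back needs no extra argument (a sketch, ``out.onnx`` is illustrative):
if __name__ == "__main__":
    onx = load_model("out.onnx")  # also pulls tensors from out.onnx.data
    print(len(onx.graph.node), "nodes")

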
def create_model_builder(
    config: Any,
    model: "torch.nn.Module",  # noqa: F821
    cache_dir: str,
    precision: str = "fp32",
    execution_provider: str = "cpu",
    verbose: int = 0,
    **extra_options,
) -> "Model":  # noqa: F821
    """
    Creates a model based on a configuration.
    The ONNX model itself is obtained by passing the result to
    :func:`save_model_builder`.

    :param config: configuration
    :param model: PyTorch model to convert
    :param cache_dir: cache directory
    :param precision: precision
    :param execution_provider: execution provider
    :param verbose: verbosity
    :param extra_options: extra options
    :return: model
    """
    assert cache_dir, "create_model_builder does not work without cache_dir."
    assert os.path.exists(cache_dir), f"cache_dir={cache_dir!r} does not exist"
    precision = {"float32": "fp32", "float16": "fp16", "bfloat16": "bfp16"}.get(
        precision, precision
    )
    download_model_builder_to_cache()
    builder = import_model_builder()
    io_dtype = builder.set_io_dtype(precision, execution_provider, extra_options)

    arch_map = {
        "ChatGLMForConditionalGeneration": builder.ChatGLMModel,
        "ChatGLMModel": builder.ChatGLMModel,
        "Ernie4_5_ForCausalLM": builder.ErnieModel,
        "GemmaForCausalLM": builder.Gemma2Model,
        "Gemma2ForCausalLM": builder.Gemma2Model,
        "Gemma3ForCausalLM": builder.Gemma3Model,
        "Gemma3ForConditionalGeneration": builder.Gemma3Model,
        "GraniteForCausalLM": builder.GraniteModel,
        "GptOssForCausalLM": builder.GPTOSSModel,
        "LlamaForCausalLM": builder.LlamaModel,
        "MistralForCausalLM": builder.MistralModel,
        "NemotronForCausalLM": builder.NemotronModel,
        "OlmoForCausalLM": builder.OLMoModel,
        "PhiForCausalLM": builder.PhiModel,
        "Phi3ForCausalLM": (
            lambda config, *args: (
                (
                    builder.Phi3MiniModel
                    if config.max_position_embeddings
                    == config.original_max_position_embeddings
                    else builder.Phi3MiniLongRoPEModel
                )(config, *args)
            )
        ),
        "PhiMoEForCausalLM": builder.Phi3MoELongRoPEModel,
        "Phi3SmallForCausalLM": (
            lambda config, *args: (
                (
                    builder.Phi3SmallModel
                    if config.max_position_embeddings
                    == config.original_max_position_embeddings
                    else builder.Phi3SmallLongRoPEModel
                )(config, *args)
            )
        ),
        "Phi3VForCausalLM": builder.Phi3VModel,
        "Phi4MMForCausalLM": builder.Phi4MMModel,
        "Qwen2ForCausalLM": builder.QwenModel,
        "Qwen3ForCausalLM": builder.Qwen3Model,
        "SmolLM3ForCausalLM": builder.SmolLM3Model,
    }

    assert config.architectures[0] in arch_map, (
        f"Unable to find {config.architectures[0]!r} in the supported list "
        f"of architectures: {sorted(arch_map)}"
    )

    # Additional validations.
    post = None
    if config.architectures[0] in ("ChatGLMForConditionalGeneration", "ChatGLMModel"):
        # Quantized ChatGLM models use ChatGLMForConditionalGeneration as
        # architecture whereas the HF model uses ChatGLMModel.
        config.hidden_act = "swiglu"
    elif config.architectures[0] == "Gemma2ForCausalLM":
        assert precision == "bfp16", (
            f"architecture {config.architectures[0]!r} loses accuracy "
            f"with float16 precision, use bfp16."
        )
    elif config.architectures[0] == "Gemma3ForCausalLM":
        assert precision == "bfp16", (
            f"architecture {config.architectures[0]!r} loses accuracy "
            f"with float16 precision, use bfp16."
        )

        def _post(onnx_model):
            onnx_model.model_type = "gemma3_text"

        post = _post
    elif config.architectures[0] == "Gemma3ForConditionalGeneration":
        assert extra_options.get("exclude_embeds", False), (
            f"This is only generating the text component of architecture "
            f"{config.architectures[0]!r}. Set extra_options exclude_embeds=true."
        )
        assert precision == "bfp16", (
            f"architecture {config.architectures[0]!r} loses accuracy "
            f"with float16 precision, use bfp16."
        )
        text_config = config.text_config
        for key in text_config:
            if not hasattr(config, key):
                setattr(config, key, getattr(text_config, key))
    elif config.architectures[0] == "GptOssForCausalLM":
        delattr(config, "quantization_config")
    elif (
        config.architectures[0] == "PhiMoEForCausalLM"
        and config.max_position_embeddings != config.original_max_position_embeddings
    ):
        assert execution_provider == "cuda", (
            f"architecture {config.architectures[0]!r} works on 'cuda' "
            f"because `MoE` is only supported for CUDA in ONNX Runtime."
        )
        assert precision == "int4", f"architecture {config.architectures[0]!r} supports int4."
    elif config.architectures[0] == "Phi3VForCausalLM":
        assert extra_options.get("exclude_embeds", False), (
            f"This is only generating the text component of architecture "
            f"{config.architectures[0]!r}. Set extra_options exclude_embeds=true."
        )
    elif config.architectures[0] == "Phi4MMForCausalLM":
        assert extra_options.get("exclude_embeds", False), (
            f"This is only generating the text component of architecture "
            f"{config.architectures[0]!r}. Set extra_options exclude_embeds=true."
        )

    cls = arch_map[config.architectures[0]]

    # ModelBuilder does not like None values for some parameters.
    remove = set()
    for c in ["head_dim"]:
        if hasattr(config, c) and getattr(config, c) is None:
            remove.add(c)
    for c in remove:
        delattr(config, c)

    convert = {
        "fp32": TensorProto.FLOAT,
        "fp16": TensorProto.FLOAT16,
        "bfp16": TensorProto.BFLOAT16,
    }
    assert (
        precision in convert
    ), f"Unexpected value for precision={precision!r}, should be in {convert}"
    onnx_model = cls(
        config, io_dtype, convert[precision], execution_provider, cache_dir, extra_options
    )

    if post:
        post(onnx_model)
    _make_model(onnx_model, model, verbose=verbose)

    assert onnx_model.model, (
        f"No node in the model, io_dtype={io_dtype!r}, "
        f"precision={precision!r}, execution_provider={execution_provider!r}, "
        f"extra_options={extra_options!r}, cache_dir={cache_dir!r}, "
        f"\n-- config --\n{config}"
    )
    # onnx_model.make_genai_config(hf_name, extra_kwargs, output_dir)
    # onnx_model.save_processing(hf_name, extra_kwargs, output_dir)
    return onnx_model


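# An end-to-end sketch tying the helpers together, assuming ``transformers`` is
# installed and the checkpoint is reachable ("microsoft/phi-2" is an
# illustrative choice; its architecture PhiForCausalLM is in ``arch_map``):
if __name__ == "__main__":
    import tempfile
    from transformers import AutoConfig, AutoModelForCausalLM

    config = AutoConfig.from_pretrained("microsoft/phi-2")
    model = AutoModelForCausalLM.from_config(config)
    with tempfile.TemporaryDirectory() as cache_dir:
        onnx_model = create_model_builder(
            config, model, cache_dir, precision="fp32", execution_provider="cpu"
        )
        # An empty out_dir makes save_model_builder return the ModelProto.
        proto = save_model_builder(onnx_model, out_dir="")

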
def find_names_pattern(names: List[str]) -> str:
    """
    Finds a repeatable pattern in a list of names by locating the digits
    and replacing them with ``%d``.

    .. runpython::
        :showcode:

        from onnx_diagnostic.helpers.model_builder_helper import find_names_pattern
        pattern = find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"])
        print(pattern)
    """
    patterns = [re.sub(r"(\d+)", r"%d", t) for t in names]
    unique = set(patterns)
    assert (
        len(unique) == 1
    ), f"Unable to guess a pattern from {names} which led to the unique patterns {unique}"
    return patterns[0]


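# The ``%d`` placeholder makes the pattern reusable: consumers substitute the
# layer index back in. A quick inline check of the behavior documented above:
assert find_names_pattern(["past_key_values_key_0", "past_key_values_key_1"]) == (
    "past_key_values_key_%d"
)
assert "past_key_values_key_%d" % 3 == "past_key_values_key_3"

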
def make_genai_config(
    config,
    onnx_filename: str,
) -> Dict:
    """
    Creates the genai configuration for a model.

    :param config: configuration from transformers
    :param onnx_filename: path to the ONNX model
    :return: configuration as a dictionary
    """
    onx = load_model(onnx_filename, load_external_data=False)
    config = copy.deepcopy(config)
    defaults = {
        "bos_token_id": None,
        "do_sample": False,
        "eos_token_id": None,
        "pad_token_id": None,
        "temperature": 1.0,
        "top_k": 50,
        "top_p": 1.0,
    }
    for key, default_val in defaults.items():
        if not hasattr(config, key):
            setattr(config, key, default_val)

    bos_token_id = (
        config.bos_token_id
        if hasattr(config, "bos_token_id") and config.bos_token_id is not None
        else 1
    )
    eos_token_id = config.eos_token_id
    pad_token_id = (
        config.pad_token_id
        if hasattr(config, "pad_token_id") and config.pad_token_id is not None
        else (
            config.eos_token_id[0]
            if isinstance(config.eos_token_id, list)
            else config.eos_token_id
        )
    )
    input_names = [i.name for i in onx.graph.input]
    output_names = [i.name for i in onx.graph.output]
    past_key_values = [s for s in input_names if s.startswith("past_key_value")]
    first = [i for i in onx.graph.input if i.name == past_key_values[0]][0]  # noqa: RUF015
    shape = tuple(d.dim_value or d.dim_param for d in first.type.tensor_type.shape.dim)
    return {
        "model": {
            "bos_token_id": bos_token_id,
            "context_length": config.max_position_embeddings,
            "decoder": {
                "session_options": {
                    "log_id": "onnxruntime-genai",
                    "provider_options": [],
                },
                "filename": os.path.split(onnx_filename)[-1],
                "head_size": shape[-1],
                "hidden_size": config.hidden_size,
                "inputs": {
                    "input_ids": input_names[0],
                    "attention_mask": input_names[1],
                    "past_key_names": find_names_pattern(input_names[2::2]),
                    "past_value_names": find_names_pattern(input_names[3::2]),
                },
                "outputs": {
                    "logits": output_names[0],
                    "present_key_names": find_names_pattern(output_names[1::2]),
                    "present_value_names": find_names_pattern(output_names[2::2]),
                },
                "num_attention_heads": config.num_attention_heads,
                "num_hidden_layers": len(past_key_values) // 2,
                "num_key_value_heads": shape[1],
            },
            "eos_token_id": eos_token_id,
            "pad_token_id": pad_token_id,
            "type": config.model_type,
            # if "For" in self.model_type else len(self.model_type)].lower(),
            "vocab_size": config.vocab_size,
        },
        "search": {
            "diversity_penalty": (
                config.diversity_penalty if hasattr(config, "diversity_penalty") else 0.0
            ),
            "do_sample": config.do_sample if hasattr(config, "do_sample") else False,
            "early_stopping": True,
            "length_penalty": (
                config.length_penalty if hasattr(config, "length_penalty") else 1.0
            ),
            "max_length": config.max_position_embeddings,
            "min_length": 0,
            "no_repeat_ngram_size": (
                config.no_repeat_ngram_size if hasattr(config, "no_repeat_ngram_size") else 0
            ),
            "num_beams": config.num_beams if hasattr(config, "num_beams") else 1,
            "num_return_sequences": (
                config.num_return_sequences if hasattr(config, "num_return_sequences") else 1
            ),
            "past_present_share_buffer": False,
            "repetition_penalty": (
                config.repetition_penalty if hasattr(config, "repetition_penalty") else 1.0
            ),
            "temperature": config.temperature if hasattr(config, "temperature") else 1.0,
            "top_k": config.top_k if hasattr(config, "top_k") else 50,
            "top_p": config.top_p if hasattr(config, "top_p") else 1.0,
        },
    }
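

# The returned dictionary mirrors onnxruntime-genai's ``genai_config.json``. A
# minimal sketch of persisting it next to the exported model, assuming
# "model.onnx" was produced by :func:`save_model_builder` (file and checkpoint
# names are illustrative):
if __name__ == "__main__":
    import json
    from transformers import AutoConfig

    cfg = AutoConfig.from_pretrained("microsoft/phi-2")
    genai = make_genai_config(cfg, "model.onnx")
    with open("genai_config.json", "w") as f:
        json.dump(genai, f, indent=4)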