onnx-diagnostic 0.6.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +18 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/ext_test_case.py +3 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/helper.py +6 -5
- onnx_diagnostic/helpers/model_builder_helper.py +24 -8
- onnx_diagnostic/helpers/rt_helper.py +5 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +518 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +106 -37
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +12 -25
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +130 -16
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +88 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +115 -15
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +38 -23
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/patches/patch_transformers.py

@@ -1,5 +1,6 @@
 import inspect
 from dataclasses import dataclass
+from functools import wraps
 from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import transformers
@@ -531,3 +532,90 @@ class patched_GenerationMixin:
         # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
         model_inputs.pop("labels", None)
         return model_inputs
+
+
+def patched_dynamic_rope_update(rope_forward):
+    """
+    patch:transformers.modeling_rope_utils.dynamic_rope_update
+    """
+
+    def longrope_frequency_update(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if hasattr(self.config, "original_max_position_embeddings"):
+            original_max_position_embeddings = self.config.original_max_position_embeddings
+        else:
+            original_max_position_embeddings = self.config.max_position_embeddings
+        # At export time, seq_len is unknown.
+        long_inv_freq, _ = self.rope_init_fn(
+            self.config, device, seq_len=original_max_position_embeddings + 1
+        )
+        original_inv_freq = self.original_inv_freq.to(device)
+
+        cond = (seq_len > original_max_position_embeddings).item()
+        inv_freq = torch.cond(
+            cond,
+            (lambda x, y: x.clone()),
+            (lambda x, y: y.clone()),
+            [long_inv_freq, original_inv_freq],
+        )
+        self.inv_freq = inv_freq
+        # if seq_len > original_max_position_embeddings:
+        #     self.inv_freq = self.long_inv_freq
+        # else:
+        #     self.inv_freq = self.original_inv_freq
+
+    def dynamic_frequency_update(self, position_ids, device):
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(
+                self.config, device, seq_len=seq_len
+            )
+            self.register_buffer("inv_freq", inv_freq, persistent=False)
+            self.max_seq_len_cached = seq_len
+
+        if (
+            seq_len < self.original_max_seq_len
+            and self.max_seq_len_cached > self.original_max_seq_len
+        ):
+            self.original_inv_freq = self.original_inv_freq.to(device)
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
+
+    @wraps(rope_forward)
+    def wrapper(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            dynamic_frequency_update(self, position_ids, device=x.device)
+        elif self.rope_type == "longrope":
+            longrope_frequency_update(self, position_ids, device=x.device)
+        return rope_forward(self, x, position_ids)
+
+    return wrapper
+
+
+class patched_Phi3RotaryEmbedding(torch.nn.Module):
+    _PATCHES_ = ["forward"]
+    _PATCHED_CLASS_ = transformers.models.phi3.modeling_phi3.Phi3RotaryEmbedding
+
+    @torch.no_grad()
+    @patched_dynamic_rope_update
+    def forward(self, x, position_ids):
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
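The `torch.cond` in `longrope_frequency_update` is the heart of the patch: `torch.export` cannot trace a Python `if` on a data-dependent `seq_len`, so both branches stay in the graph and the predicate is resolved at run time. A minimal standalone sketch of the same pattern (names and shapes are illustrative, not from the package):

```python
import torch

def pick_inv_freq(seq_len, long_inv_freq, original_inv_freq, threshold):
    # torch.cond keeps both branches in the exported graph; the predicate is
    # evaluated at run time instead of being frozen at trace time.
    return torch.cond(
        seq_len > threshold,
        lambda a, b: a.clone(),  # taken when seq_len exceeds the threshold
        lambda a, b: b.clone(),  # taken otherwise
        [long_inv_freq, original_inv_freq],
    )

long_f, orig_f = torch.rand(16), torch.rand(16)
picked = pick_inv_freq(torch.tensor(4097), long_f, orig_f, 4096)
assert torch.equal(picked, long_f)
```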
onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py

@@ -3951,3 +3951,145 @@ def _ccached_facebook_bart_large_cnn():
             "vocab_size": 50264,
         }
     )
+
+
+def _ccached_microsoft_phi4_reasoning():
+    "microsoft/Phi-4-mini-reasoning"
+    return transformers.Phi3Config(
+        **{
+            "architectures": ["Phi3ForCausalLM"],
+            "attention_bias": False,
+            "attention_dropout": 0.0,
+            "bos_token_id": 199999,
+            "embd_pdrop": 0.0,
+            "eos_token_id": 199999,
+            "full_attn_mod": 1,
+            "hidden_act": "silu",
+            "hidden_size": 3072,
+            "initializer_range": 0.02,
+            "intermediate_size": 8192,
+            "interpolate_factor": 1,
+            "lm_head_bias": False,
+            "max_position_embeddings": 131072,
+            "mlp_bias": False,
+            "model_type": "phi3",
+            "num_attention_heads": 24,
+            "num_hidden_layers": 32,
+            "num_key_value_heads": 8,
+            "original_max_position_embeddings": 4096,
+            "pad_token_id": 199999,
+            "partial_rotary_factor": 0.75,
+            "resid_pdrop": 0.0,
+            "rms_norm_eps": 1e-05,
+            "rope_scaling": {
+                "long_factor": [
+                    1,
+                    1.118320672,
+                    1.250641126,
+                    1.398617824,
+                    1.564103225,
+                    1.74916897,
+                    1.956131817,
+                    2.187582649,
+                    2.446418898,
+                    2.735880826,
+                    3.059592084,
+                    3.421605075,
+                    3.826451687,
+                    4.279200023,
+                    4.785517845,
+                    5.351743533,
+                    5.984965424,
+                    6.693110555,
+                    7.485043894,
+                    8.370679318,
+                    9.36110372,
+                    10.4687158,
+                    11.70738129,
+                    13.09260651,
+                    14.64173252,
+                    16.37415215,
+                    18.31155283,
+                    20.47818807,
+                    22.90118105,
+                    25.61086418,
+                    28.64115884,
+                    32.03,
+                    32.1,
+                    32.13,
+                    32.23,
+                    32.6,
+                    32.61,
+                    32.64,
+                    32.66,
+                    32.7,
+                    32.71,
+                    32.93,
+                    32.97,
+                    33.28,
+                    33.49,
+                    33.5,
+                    44.16,
+                    47.77,
+                ],
+                "short_factor": [
+                    1,
+                    1.118320672,
+                    1.250641126,
+                    1.398617824,
+                    1.564103225,
+                    1.74916897,
+                    1.956131817,
+                    2.187582649,
+                    2.446418898,
+                    2.735880826,
+                    3.059592084,
+                    3.421605075,
+                    3.826451687,
+                    4.279200023,
+                    4.785517845,
+                    5.351743533,
+                    5.984965424,
+                    6.693110555,
+                    7.485043894,
+                    8.370679318,
+                    9.36110372,
+                    10.4687158,
+                    11.70738129,
+                    13.09260651,
+                    14.64173252,
+                    16.37415215,
+                    18.31155283,
+                    20.47818807,
+                    22.90118105,
+                    25.61086418,
+                    28.64115884,
+                    32.03,
+                    32.1,
+                    32.13,
+                    32.23,
+                    32.6,
+                    32.61,
+                    32.64,
+                    32.66,
+                    32.7,
+                    32.71,
+                    32.93,
+                    32.97,
+                    33.28,
+                    33.49,
+                    33.5,
+                    44.16,
+                    47.77,
+                ],
+                "type": "longrope",
+            },
+            "rope_theta": 10000.0,
+            "sliding_window": 262144,
+            "tie_word_embeddings": True,
+            "torch_dtype": "bfloat16",
+            "transformers_version": "4.50.0",
+            "use_cache": True,
+            "vocab_size": 200064,
+        }
+    )
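Cached configurations like this one let the package build a model architecture without downloading anything from the hub. A hedged sketch of the idea with a deliberately tiny `Phi3Config` (all sizes below are illustrative, not the cached values):

```python
import transformers

# Tiny stand-in for the cached Phi-4-mini-reasoning configuration above;
# shrinking the sizes keeps the untrained model cheap to instantiate.
cfg = transformers.Phi3Config(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=2,
    vocab_size=1024,
    max_position_embeddings=512,
)
model = transformers.Phi3ForCausalLM(cfg)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```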
onnx_diagnostic/torch_models/test_helper.py

@@ -4,6 +4,7 @@ import os
 import sys
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import time
+import numpy as np
 import onnx
 import onnxscript
 import onnxscript.rewriter.ort_fusions as ort_fusions
@@ -17,6 +18,7 @@ from ..helpers.cache_helper import flatten_unflatten_for_dynamic_shapes
 from ..tasks import random_input_kwargs
 from ..torch_export_patches import torch_export_patches
 from ..torch_export_patches.patch_inputs import use_dyn_not_str
+from ..reference import TorchOnnxEvaluator
 from .hghub import get_untrained_model_with_inputs
 
 
@@ -192,11 +194,16 @@ def _quiet_or_not_quiet(
     summary: Dict[str, Any],
     data: Optional[Dict[str, Any]],
     fct: Callable,
+    repeat: int = 1,
+    warmup: int = 0,
 ) -> Any:
     begin = time.perf_counter()
     if quiet:
         try:
-            return fct()
+            res = fct()
+            summary[f"time_{suffix}"] = time.perf_counter() - begin
+            if warmup + repeat == 1:
+                return res
         except Exception as e:
             summary[f"ERR_{suffix}"] = str(e)
             summary[f"time_{suffix}"] = time.perf_counter() - begin
@@ -204,11 +211,45 @@ def _quiet_or_not_quiet(
                 return {f"ERR_{suffix}": e}
             data[f"ERR_{suffix}"] = e
             return None
-    res = fct()
+    else:
+        res = fct()
     summary[f"time_{suffix}"] = time.perf_counter() - begin
+    if warmup + repeat > 1:
+        if suffix == "run":
+            res = torch_deepcopy(res)
+        summary[f"{suffix}_output"] = string_type(res, with_shape=True, with_min_max=True)
+        summary[f"{suffix}_warmup"] = warmup
+        summary[f"{suffix}_repeat"] = repeat
+        for _w in range(max(0, warmup - 1)):
+            t = fct()
+            summary[f"io_{suffix}_{_w+1}"] = string_type(t, with_shape=True, with_min_max=True)
+        summary[f"time_{suffix}_warmup"] = time.perf_counter() - begin
+        times = []
+        for _r in range(repeat):
+            begin = time.perf_counter()
+            t = fct()
+            times.append(time.perf_counter() - begin)
+        a = np.array(times)
+        summary[f"time_{suffix}_latency"] = a.mean()
+        summary[f"time_{suffix}_latency_std"] = a.std()
+        summary[f"time_{suffix}_latency_min"] = a.min()
+        summary[f"time_{suffix}_latency_max"] = a.max()
     return res
 
 
+def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
+    """Shrinks the configuration before it gets added to the information to log."""
+    new_cfg = {}
+    for k, v in cfg.items():
+        new_cfg[k] = (
+            v
+            if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
+            else (v.__class__("...") if isinstance(v, (list, tuple)) else "...")
+        )
+    return new_cfg
+
+
 def validate_model(
     model_id: str,
     task: Optional[str] = None,
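`repeat` and `warmup` turn the single timed call into a small latency benchmark: warmup runs are discarded, then `repeat` measured runs feed the `time_*_latency*` statistics. A minimal sketch of the measurement pattern, assuming only the standard library and numpy:

```python
import time
import numpy as np

def measure(fct, warmup: int = 2, repeat: int = 5) -> dict:
    for _ in range(warmup):  # warmup calls are not measured
        fct()
    times = []
    for _ in range(repeat):  # measured calls
        begin = time.perf_counter()
        fct()
        times.append(time.perf_counter() - begin)
    a = np.array(times)
    return {
        "latency": a.mean(),
        "latency_std": a.std(),
        "latency_min": a.min(),
        "latency_max": a.max(),
    }

print(measure(lambda: sum(range(100_000))))
```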
@@ -231,6 +272,9 @@ def validate_model(
     model_options: Optional[Dict[str, Any]] = None,
     subfolder: Optional[str] = None,
     opset: Optional[int] = None,
+    runtime: str = "onnxruntime",
+    repeat: int = 1,
+    warmup: int = 0,
 ) -> Tuple[Dict[str, Union[int, float, str]], Dict[str, Any]]:
     """
     Validates a model.
@@ -267,6 +311,10 @@ def validate_model(
         ``num_hidden_layers`` or ``attn_implementation``
     :param subfolder: version or subfolder to use when retrieving a model id
     :param opset: onnx opset to use for the conversion
+    :param runtime: onnx runtime to use to check for discrepancies,
+        only if `do_run` is true
+    :param repeat: number of times to run the model when measuring latency
+    :param warmup: number of warmup runs before measuring
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
 
@@ -295,6 +343,7 @@ def validate_model(
             version_ortfusiontype=ortfusiontype or "",
             version_stop_if_static=str(stop_if_static),
             version_exporter=exporter or "",
+            version_runtime=runtime,
         )
     )
     if opset:
@@ -436,7 +485,9 @@ def validate_model(
     if summary["model_module"] in sys.modules:
         summary["model_file"] = str(sys.modules[summary["model_module"]].__file__)  # type: ignore[index]
     summary["model_config_class"] = data["configuration"].__class__.__name__
-    summary["model_config"] = str(data["configuration"].to_dict()).replace(" ", "")
+    summary["model_config"] = str(shrink_config(data["configuration"].to_dict())).replace(
+        " ", ""
+    )
     summary["model_id"] = model_id
 
     if verbose:
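`shrink_config` keeps the logged `model_config` compact by collapsing any container with 50 or more elements; note that for a list the replacement `v.__class__("...")` evaluates to `['.', '.', '.']`. A quick check (the function body is copied from the hunk above):

```python
from typing import Any, Dict

def shrink_config(cfg: Dict[str, Any]) -> Dict[str, Any]:
    new_cfg = {}
    for k, v in cfg.items():
        new_cfg[k] = (
            v
            if (not isinstance(v, (list, tuple, set, dict)) or len(v) < 50)
            else (v.__class__("...") if isinstance(v, (list, tuple)) else "...")
        )
    return new_cfg

cfg = {"vocab_size": 200064, "long_factor": [1.0] * 64}
print(shrink_config(cfg))
# {'vocab_size': 200064, 'long_factor': ['.', '.', '.']}
```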
@@ -460,7 +511,13 @@ def validate_model(
     model = data["model"]
 
     expected = _quiet_or_not_quiet(
-        quiet,
+        quiet,
+        "run",
+        summary,
+        data,
+        (lambda m=model, inp=inputs: m(**torch_deepcopy(inp))),
+        repeat=repeat,
+        warmup=warmup,
     )
     if "ERR_run" in summary:
         return summary, data
@@ -522,7 +579,7 @@ def validate_model(
 
         disc = max_diff(data["expected"], expected)
         for k, v in disc.items():
-            summary[f"disc_patched_{k}"] = v
+            summary[f"disc_patched_{k}"] = str(v)
         if verbose:
             print("[validate_model] done (patched run)")
             print(f"[validate_model] patched discrepancies={string_diff(disc)}")
@@ -618,7 +675,14 @@ def validate_model(
         return summary, data
 
     if do_run:
-        summary_valid, data = validate_onnx_model(
+        summary_valid, data = validate_onnx_model(
+            data=data,
+            quiet=quiet,
+            verbose=verbose,
+            runtime=runtime,
+            repeat=repeat,
+            warmup=warmup,
+        )
         summary.update(summary_valid)
 
     if ortfusiontype and "onnx_filename" in data:
@@ -671,7 +735,13 @@ def validate_model(
 
     if do_run:
         summary_valid, data = validate_onnx_model(
-            data=data,
+            data=data,
+            quiet=quiet,
+            verbose=verbose,
+            flavour=flavour,
+            runtime=runtime,
+            repeat=repeat,
+            warmup=warmup,
         )
         summary.update(summary_valid)
 
@@ -883,6 +953,9 @@ def validate_onnx_model(
     quiet: bool = False,
     verbose: int = 0,
     flavour: Optional[str] = None,
+    runtime: str = "onnxruntime",
+    repeat: int = 1,
+    warmup: int = 0,
 ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """
     Verifies that an onnx model produces the same
@@ -895,6 +968,9 @@ def validate_onnx_model(
     :param quiet: catch exception or not
     :param verbose: verbosity
     :param flavour: use a different version of the inputs
+    :param runtime: onnx runtime to use, onnxruntime or torch
+    :param repeat: number of times to run the model
+    :param warmup: number of warmup runs
     :return: two dictionaries, one with some metrics,
         another one with whatever the function produces
     """
@@ -936,18 +1012,28 @@ def validate_onnx_model(
             f"{providers}..., flavour={flavour!r}"
         )
 
+    cls_runtime = (
+        (
+            lambda model, providers: onnxruntime.InferenceSession(
+                (model.SerializeToString() if isinstance(model, onnx.ModelProto) else model),
+                providers=providers,
+            )
+        )
+        if runtime == "onnxruntime"
+        else (
+            lambda model, providers: TorchOnnxEvaluator(
+                model, providers=providers, verbose=max(verbose - 1, 0)
+            )
+        )
+    )
     sess = _quiet_or_not_quiet(
         quiet,
-        _mk("
+        _mk("onnx_ort_create"),
         summary,
         data,
-        (
-            lambda source=source, providers=providers: onnxruntime.InferenceSession(
-                source, providers=providers
-            )
-        ),
+        (lambda source=source, providers=providers: cls_runtime(source, providers)),
     )
-    if f"ERR_{_mk('
+    if f"ERR_{_mk('onnx_ort_create')}" in summary:
         return summary, data
 
     data[_mk("onnx_ort_sess")] = sess
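The new `runtime` parameter swaps `onnxruntime.InferenceSession` for the package's pure-torch `TorchOnnxEvaluator` (introduced in `onnx_diagnostic/reference/torch_evaluator.py` in this release) behind a common interface. A hedged sketch of the dispatch, mirroring the `cls_runtime` lambda above; the helper name is illustrative:

```python
import onnx
import onnxruntime
from onnx_diagnostic.reference import TorchOnnxEvaluator

def make_session(model: onnx.ModelProto, runtime: str = "onnxruntime",
                 providers=("CPUExecutionProvider",)):
    # Both objects expose run(None, feeds), so downstream validation code
    # is unchanged whichever runtime is picked.
    if runtime == "onnxruntime":
        return onnxruntime.InferenceSession(
            model.SerializeToString(), providers=list(providers)
        )
    return TorchOnnxEvaluator(model, providers=list(providers))
```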
@@ -975,6 +1061,8 @@ def validate_onnx_model(
         summary,
         data,
         (lambda sess=sess, feeds=feeds: sess.run(None, feeds)),
+        repeat=repeat,
+        warmup=warmup,
     )
     if f"ERR_{_mk('time_onnx_ort_run')}" in summary:
         return summary, data
@@ -1051,7 +1139,7 @@ def call_torch_export_onnx(
         dynamo=False,
         dynamic_axes={
             k: v
-            for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)
+            for k, v in CoupleInputsDynamicShapes(args, kwargs, ds)  # type: ignore[arg-type]
             .replace_by_string()
             .items()
             if isinstance(v, dict)
@@ -1229,6 +1317,13 @@ def call_torch_export_custom(
         "custom-nostrict",
         "custom-nostrict-default",
         "custom-nostrict-all",
+        "custom-inline",
+        "custom-strict-inline",
+        "custom-strict-default-inline",
+        "custom-strict-all-inline",
+        "custom-nostrict-inline",
+        "custom-nostrict-default-inline",
+        "custom-nostrict-all-inline",
     }
     assert exporter in available, f"Unexpected value for exporter={exporter!r} in {available}"
     assert "model" in data, f"model is missing from data: {sorted(data)}"
@@ -1269,6 +1364,10 @@ def call_torch_export_custom(
         ),
         save_ep=(os.path.join(dump_folder, f"{exporter}.ep") if dump_folder else None),
     )
+    inline = "-inline" in exporter
+    if inline:
+        export_options.aten_as_function = set()
+
     options = OptimizationOptions(patterns=optimization) if optimization else None
     model = data["model"]
     kws = dict(
@@ -1279,6 +1378,7 @@ def call_torch_export_custom(
         large_model=True,
         return_optimize_report=True,
         verbose=max(verbose - 2, 0),
+        inline=inline,
     )
     if opset:
         kws["target_opset"] = opset