onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -3
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +435 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/ext_test_case.py +23 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
onnx_diagnostic/ext_test_case.py
CHANGED
@@ -700,6 +700,19 @@ def requires_onnx(version: str, msg: str = "") -> Callable:
     return lambda x: x
 
 
+def requires_experimental_experiment(version: str, msg: str = "") -> Callable:
+    """Skips a unit test if :epkg:`experimental-experiment` is not recent enough."""
+    import packaging.version as pv
+    import experimental_experiment
+
+    if pv.Version(experimental_experiment.__version__) < pv.Version(version):
+        msg = (
+            f"experimental-experiment version {experimental_experiment.__version__} < {version}: {msg}"
+        )
+        return unittest.skip(msg)
+    return lambda x: x
+
+
 def requires_onnx_array_api(version: str, msg: str = "") -> Callable:
     """Skips a unit test if :epkg:`onnx-array-api` is not recent enough."""
     import packaging.version as pv
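For reference, a minimal usage sketch of the new decorator (the version string and the test case are illustrative, not taken from the diff):

    from onnx_diagnostic.ext_test_case import ExtTestCase, requires_experimental_experiment

    class TestWithPatternOptimizer(ExtTestCase):
        @requires_experimental_experiment("0.1.0", "needs a recent pattern optimizer")
        def test_something(self):
            ...

Like the other requires_* helpers, it returns unittest.skip when the installed experimental_experiment is older than the requested version and an identity decorator otherwise.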
@@ -774,6 +787,7 @@ class ExtTestCase(unittest.TestCase):
     def setUpClass(cls):
         logger = logging.getLogger("onnxscript.optimizer.constant_folding")
         logger.setLevel(logging.ERROR)
+        warnings.filterwarnings("ignore", category=DeprecationWarning)
         unittest.TestCase.setUpClass()
 
     @classmethod

@@ -1253,6 +1267,7 @@ class ExtTestCase(unittest.TestCase):
         :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`
         """
         from .helpers import string_type, string_diff, max_diff
+        from .helpers.torch_helper import torch_deepcopy
         from .helpers.rt_helper import make_feeds
         from .helpers.ort_session import InferenceSessionForTorch

@@ -1269,6 +1284,12 @@ class ExtTestCase(unittest.TestCase):
             model_file = proto
             name = proto
             proto = onnx.load(name)
+        elif hasattr(proto, "save"):
+            name = f"{test_name}.onnx"
+            proto.save(name)
+            proto = onnx.load(name)
+        elif hasattr(proto, "model_proto"):
+            proto = proto.model_proto
         elif not self.unit_test_going():
            assert isinstance(
                proto, onnx.ModelProto

@@ -1327,9 +1348,9 @@ class ExtTestCase(unittest.TestCase):
         if copy_inputs:
             expected = [
                 (
-                    model(*
+                    model(*torch_deepcopy(inp))
                     if isinstance(inp, tuple)
-                    else model(**
+                    else model(**torch_deepcopy(inp))
                 )
                 for inp in inputs
             ]
onnx_diagnostic/helpers/log_helper.py
CHANGED

@@ -1921,9 +1921,7 @@ class CubeLogsPerformance(CubeLogs):
            return lambdas[formula]
 
        if formula == "onnx_n_nodes_no_cst":
-            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(
-                df, "op_onnx__Constant", 0
-            ).fillna(0)
+            return lambda df: gdf(df, "onnx_n_nodes", 0) - gdf(df, "op_onnx__Constant", 0)
        if formula == "peak_gpu_torch":
            return lambda df: gdf(df, "mema_gpu_5_after_export") - gdf(df, "mema_gpu_4_reset")
        if formula == "peak_gpu_nvidia":
onnx_diagnostic/helpers/optim_helper.py
ADDED

@@ -0,0 +1,116 @@
+from typing import Optional, Union
+import pprint
+import onnx
+
+
+def optimize_model(
+    algorithm: str,
+    model: Union[onnx.ModelProto, str],
+    output: Optional[str] = None,
+    processor: Optional[str] = None,
+    infer_shapes: bool = True,
+    remove_shape_info: bool = False,
+    verbose: int = 1,
+):
+    """
+    Optimizes an onnx model by fusing nodes. It looks for patterns in the graphs
+    and replaces them by the corresponding nodes. It also does basic optimization
+    such as removing identity nodes or unused nodes.
+
+    :param algorithm: algorithm to choose
+    :param model: model to optimize as a proto or a filename
+    :param output: if not empty, the optimized model is saved
+    :param processor: optimizations are done for this processor
+    :param infer_shapes: infer shapes before optimizing, this might not be
+        available for all algorithms
+    :param remove_shape_info: remove shape information before saving the model
+    :param verbose: verbosity level
+    :return: optimized model
+
+    The goal is to make the model faster.
+    The set of patterns applied depends on the chosen algorithm.
+    It is possible to show statistics or to remove a particular pattern.
+    Here are some environment variables which can be used to trigger
+    these displays.
+
+    Available options for algorithms ``default`` and ``default+onnxruntime``:
+
+    - ``DROPPATTERN=<pattern1,pattern2,...>``: do not apply
+      those patterns when optimizing a model
+    - ``DUMPPATTERNS=<folder>``: dumps all matched and applied nodes when a pattern is applied
+    - ``PATTERN=<pattern1,pattern2,...>``: increases verbosity
+      for specific patterns to understand why one pattern was not applied,
+      this shows which line is rejecting a pattern if it seems one pattern was missed
+    """
+    if isinstance(model, str):
+        if verbose:
+            print(f"[optimize_model] load {model!r}")
+        proto = onnx.load(model)
+        if verbose:
+            print("[optimize_model] done loading.")
+    else:
+        proto = model
+
+    if verbose:
+        print(f"[optimize_model] optimize with {algorithm!r}")
+    if algorithm in {"default", "default+onnxruntime"}:
+        from experimental_experiment.xoptim import get_pattern_list
+        from experimental_experiment.xbuilder import GraphBuilder, OptimizationOptions
+
+        pats = get_pattern_list(algorithm)
+
+        gr = GraphBuilder(
+            proto,
+            infer_shapes_options=infer_shapes,
+            optimization_options=OptimizationOptions(
+                patterns=pats,
+                verbose=verbose,
+                remove_unused=True,
+                constant_folding=True,
+                remove_identity=True,
+                max_iter=max(100, len(proto.graph.node) // 2),
+                processor=processor or "CPU",
+            ),
+        )
+        if verbose:
+            print(f"[optimize_model] starts optimizing with {len(pats)} patterns")
+            print(f"[optimize_model] model has {len(proto.graph.node)} nodes")
+        opt_onx, report = gr.to_onnx(optimize=True, return_optimize_report=True)
+        if verbose:
+            print("[optimize_model] optimization report")
+            pprint.pprint(report)
+            print("[optimize_model] done")
+
+    elif algorithm == "slim":
+        import onnxslim
+
+        opt_onx = onnxslim.slim(proto, no_shape_infer=not infer_shapes)
+    elif algorithm in {"ir", "os_ort"}:
+        import onnx_ir
+        import onnxscript.optimizer
+        from onnxscript.rewriter.ort_fusions import optimize_for_ort
+
+        model_ir = onnx_ir.from_proto(proto)
+        if algorithm == "ir":
+            onnxscript.optimizer.optimize(model_ir)
+        else:
+            optimize_for_ort(model_ir)
+        opt_onx = onnx_ir.serde.serialize_model(model_ir)
+
+    del proto
+    if verbose:
+        print(f"[optimize_model] done optimizing, model has {len(opt_onx.graph.node)} nodes")
+    if remove_shape_info:
+        if verbose:
+            print(f"[optimize_model] remove shape information {len(opt_onx.graph.value_info)}")
+        del opt_onx.graph.value_info[:]
+        if verbose:
+            print("[optimize_model] done removing shape info")
+
+    if output:
+        if verbose:
+            print(f"[optimize_model] save file into {output!r}")
+        onnx.save(opt_onx, output, save_as_external_data=True)
+        if verbose:
+            print("[optimize_model] done saving")
+    return opt_onx
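A minimal usage sketch for the new helper (file names are illustrative):

    from onnx_diagnostic.helpers.optim_helper import optimize_model

    # algorithm is one of "default", "default+onnxruntime", "slim", "ir", "os_ort"
    optimized = optimize_model(
        "default",
        "model.onnx",
        output="model.opt.onnx",
        processor="CPU",
        verbose=1,
    )

Note that an unrecognized algorithm falls through without assigning the optimized model, so callers should stick to the five values listed above.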
onnx_diagnostic/tasks/image_text_to_text.py
CHANGED

@@ -13,6 +13,10 @@ from .data import get_data
 __TASK__ = "image-text-to-text"
 
 
+def should_have_vision_config(config):
+    return config.architectures != ["FuyuForCausalLM"]
+
+
 def reduce_model_config(config: Any) -> Dict[str, Any]:
     """Reduces a model size."""
     kwargs: Dict[str, Any] = {}
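A sketch of the guard's effect (the config stand-ins are illustrative):

    from types import SimpleNamespace

    fuyu = SimpleNamespace(architectures=["FuyuForCausalLM"])
    llava = SimpleNamespace(architectures=["LlavaForConditionalGeneration"])

    assert not should_have_vision_config(fuyu)   # Fuyu has no vision_config attribute
    assert should_have_vision_config(llava)

The checks on config.vision_config later in this file are skipped when the guard returns False.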
@@ -477,7 +481,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "hidden_size",
             "pad_token_id",
         )
-        check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+        if should_have_vision_config(config):
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
         text_config = True
     else:
         check_hasattr(

@@ -491,7 +496,8 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
             "vision_config",
         )
         text_config = False
-    check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+    if should_have_vision_config(config):
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
     kwargs = dict(
         head_dim=(
             16

@@ -552,17 +558,21 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
         ),
         width=(
             224
-            if config is None
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         height=(
             224
-            if config is None
+            if config is None
+            or not should_have_vision_config(config)
+            or not hasattr(config.vision_config, "image_size")
             else config.vision_config.image_size
         ),
         num_channels=(
             3
-            if config is None
+            if config is None or not should_have_vision_config(config)
             else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
         ),
         pad_token_id=(
onnx_diagnostic/tasks/text2text_generation.py
CHANGED

@@ -18,6 +18,22 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
         config.num_decoder_layers = min(config.num_decoder_layers, 2)
     if hasattr(config, "num_hidden_layers"):
         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "encoder") and hasattr(config.encoder, "layer_types"):
+        default_layer_types = [
+            "sliding_attention",
+            "full_attention",
+            "sliding_attention",
+            "full_attention",
+        ]
+        config.encoder.num_hidden_layers = 4
+        config.encoder.layer_types = (
+            default_layer_types if config is None else config.encoder.layer_types[:4]
+        )
+        config.decoder.num_hidden_layers = 4
+        config.decoder.layer_types = (
+            default_layer_types if config is None else config.decoder.layer_types[:4]
+        )
+
     update_config(config, kwargs)
     return kwargs
 
@@ -177,55 +193,75 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
 
     If the configuration is None, the function selects typical dimensions.
     """
+    path = 1
     if config is not None:
-        check_hasattr(
-            config,
-            "vocab_size",
-            "hidden_size",
-            "num_attention_heads",
-            ("num_hidden_layers", "num_layers"),
-            ("n_positions", "d_model"),
-            (
-                "num_key_value_heads",
-                "num_heads",
-                ("decoder_attention_heads", "encoder_attention_heads"),
-            ),
-        )
-        # exceptions = {
-        #     "PLBartForConditionalGeneration": (
-        #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
-        #     )
-        # }
-    kwargs = dict(
-        batch_size=2,
-        sequence_length=30,
-        sequence_length2=3,
-        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
-        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
-        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
-        num_hidden_layers=(
-            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
-        ),
-        num_key_value_heads_encoder=(
-            16
-            if config is None
-            else _pick(
+        if hasattr(config, "num_attention_heads"):
+            check_hasattr(
                 config,
-                "encoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_hidden_layers", "num_layers"),
+                ("n_positions", "d_model"),
+                (
+                    "num_key_value_heads",
+                    "num_heads",
+                    ("decoder_attention_heads", "encoder_attention_heads"),
+                ),
             )
-        ),
-        num_key_value_heads_decoder=(
-            16
-            if config is None
-            else _pick(
-                config,
-                "decoder_attention_heads",
-                "num_key_value_heads",
-                "num_heads",
-            )
-        ),
-        encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
-    )
+        else:
+            check_hasattr(config, "encoder", "decoder")
+            path = 2
+
+    if path == 1:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            head_dim_encoder=(
+                16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim")
+            ),
+            head_dim_decoder=(
+                16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim")
+            ),
+            dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+            num_hidden_layers=(
+                8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+            ),
+            num_key_value_heads_encoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "encoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            num_key_value_heads_decoder=(
+                16
+                if config is None
+                else _pick(
+                    config,
+                    "decoder_attention_heads",
+                    "num_key_value_heads",
+                    "num_heads",
+                )
+            ),
+            encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            dummy_max_token_id=config.encoder.vocab_size - 1,
+            num_key_value_heads_encoder=config.encoder.num_key_value_heads,
+            num_key_value_heads_decoder=config.decoder.num_key_value_heads,
+            num_hidden_layers=len(config.encoder.layer_types),
+            head_dim_encoder=config.encoder.head_dim,
+            head_dim_decoder=config.decoder.head_dim,
+            encoder_dim=256,
+        )
+
     return kwargs, get_inputs
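The new path = 2 branch covers composite configurations that expose encoder and decoder sub-configurations instead of flat attributes. A runnable sketch of the dispatch criterion (the stand-in config is illustrative):

    from types import SimpleNamespace

    sub = dict(
        vocab_size=32000,
        num_key_value_heads=4,
        head_dim=16,
        layer_types=["sliding_attention", "full_attention"],
    )
    config = SimpleNamespace(encoder=SimpleNamespace(**sub), decoder=SimpleNamespace(**sub))

    # no flat num_attention_heads, so random_input_kwargs takes path 2 and reads
    # everything from config.encoder / config.decoder
    assert not hasattr(config, "num_attention_heads")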
onnx_diagnostic/tasks/text_generation.py
CHANGED

@@ -40,6 +40,9 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
            state_size=8 if config is None else getattr(config, "state_size", None),
            conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
        )
+    elif config.__class__.__name__ == "FunnelConfig":
+        # does not support num_hidden_layers
+        kwargs = dict()
     else:
         kwargs = dict(
             head_dim=getattr(
onnx_diagnostic/torch_export_patches/onnx_export_errors.py
CHANGED

@@ -221,6 +221,7 @@ def _patch_torch(
     catch_constraints: bool,
     stop_if_static: int,
 ) -> Tuple[Optional[Callable], ...]:
+    import packaging.version as pv
     import torch
     import torch.jit
     import torch._export.non_strict_utils  # produce_guards_and_solve_constraints

@@ -238,6 +239,11 @@ def _patch_torch(
         patched_ShapeEnv,
     )
 
+    if pv.Version(torch.__version__) >= pv.Version("2.9.99"):
+        from .patches.patch_torch import patched_DynamicDimConstraintPrinter
+    else:
+        patched_DynamicDimConstraintPrinter = None
+
     f___constrain_user_specified_dimhint_range = None
     f__broadcast_in_dim_meta = None
     f__broadcast_shapes = None

@@ -259,6 +265,17 @@ def _patch_torch(
         print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
         print("[torch_export_patches] patch pytorch")
 
+    # torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+    if patched_DynamicDimConstraintPrinter is not None:
+        f__print_symbol = (
+            torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol
+        )
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            patched_DynamicDimConstraintPrinter._print_Symbol
+        )
+    else:
+        f__print_symbol = None
+
     # torch.vmap
     f_vmap = torch.vmap
     torch.vmap = patched_vmap

@@ -392,6 +409,7 @@ def _patch_torch(
         f_shape_env__log_guard,
         f_shape_env__set_replacement,
         f_vmap,
+        f__print_symbol,
     )
 
 
@@ -416,6 +434,7 @@ def _unpatch_torch(
     f_shape_env__log_guard: Optional[Callable],
     f_shape_env__set_replacement: Optional[Callable],
     f_vmap: Optional[Callable],
+    f__print_symbol: Optional[Callable],
 ):
     import torch
     import torch.jit

@@ -423,6 +442,10 @@ def _unpatch_torch(
     from torch.fx.experimental.symbolic_shapes import ShapeEnv
 
     # this should disappear when torch.jit is removed
+    if f__print_symbol is not None:
+        torch.fx.experimental.symbolic_shapes.DynamicDimConstraintPrinter._print_Symbol = (
+            f__print_symbol
+        )
     torch.vmap = f_vmap
     torch.jit.isinstance = f_jit_isinstance
     torch._dynamo.mark_static_address = f_mark_static_address

@@ -818,6 +841,7 @@ def torch_export_patches(
     rewrite: Optional[List[Callable]] = None,
     dump_rewriting: Optional[str] = None,
     patch_details: Optional[PatchDetails] = None,
+    profile: Optional[str] = None,
 ) -> Callable:
     """
     Tries to bypass some situations :func:`torch.export.export` does not support.

@@ -847,9 +871,12 @@ def torch_export_patches(
         this is done by function :func:`transform_method
         <onnx_diagnostic.torch_export_patches.patch_module.transform_method>`,
         its documentation provides possible values
-    :param dump_rewriting: dumps rewriting information in file beginning with that prefix
-    :param patch_details: if specified, this class is used to store every applied rewriting.
+    :param dump_rewriting: dumps rewriting information in files beginning with that prefix,
+        this only applies to the automated rewritings
+    :param patch_details: if specified, this class is used to store every applied rewriting.
     :param verbose: to show which patches are applied
+    :param profile: starts profiling whatever is called inside the context manager,
+        writes the profiling output to a file
 
     The list of available patches.
 
@@ -989,6 +1016,7 @@ def torch_export_patches(
             f_shape_env__log_guard,
             f_shape_env__set_replacement,
             f_vmap,
+            f__print_Symbol,
         ) = _patch_torch(
             verbose, patch_details, patch_torch, catch_constraints, stop_if_static
         )

@@ -1017,10 +1045,23 @@ def torch_export_patches(
         if verbose:
             print("[torch_export_patches] done patching")
 
+        if profile:
+            from pyinstrument import Profiler
+
+            profiler = Profiler()
+            profiler.start()
+        else:
+            profiler = None
+
         try:
             yield fct_callable
         finally:
 
+            if profiler:
+                profiler.stop()
+                with open(profile, "w") as f:
+                    f.write(profiler.output_html())
+
             # unpatch
 
             if verbose:

@@ -1051,6 +1092,7 @@ def torch_export_patches(
                 f_shape_env__log_guard,
                 f_shape_env__set_replacement,
                 f_vmap,
+                f__print_Symbol,
             )
 
         if patch_transformers:
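A minimal sketch of the new profile argument (the model, inputs, and file name are illustrative; pyinstrument must be installed):

    import torch
    from onnx_diagnostic.torch_export_patches import torch_export_patches

    with torch_export_patches(patch_transformers=True, profile="export_profile.html"):
        ep = torch.export.export(model, (inputs,))  # model and inputs defined elsewhere

Since the report is written with pyinstrument's output_html, an .html file name is the natural choice even though any path is accepted.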
onnx_diagnostic/torch_export_patches/patch_expressions.py
CHANGED

@@ -101,7 +101,10 @@ def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
 
 
 def patched_float_arange(start, end, step):
-    """
+    """
+    Patched arange when start, end, step are floats.
+    This patch should not be needed after 2.10.
+    """
     if is_torchdynamo_exporting():
         return torch.ops.patched.float_arange(start, end, step)
     else:
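For context, the dispatch the patched function implements can be sketched as follows (the import location of is_torchdynamo_exporting is an assumption, and the eager branch is a stand-in for the original body, which the diff truncates):

    import torch
    from onnx_diagnostic.helpers.torch_helper import is_torchdynamo_exporting  # assumed module

    def float_arange_like(start: float, end: float, step: float) -> torch.Tensor:
        # while being traced by torch.export / dynamo, route to the registered custom op
        if is_torchdynamo_exporting():
            return torch.ops.patched.float_arange(start, end, step)
        # eager fallback (stand-in): plain float arange
        return torch.arange(start, end, step)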
onnx_diagnostic/torch_export_patches/patch_module.py
CHANGED

@@ -596,33 +596,41 @@ class RewriteControlFlow(ast.NodeTransformer):
             elts=[
                 *[
                     ast.Call(
-                        ast.Attribute(
-                            value=ast.Name(id="torch", ctx=ast.Load()),
-                            attr="arange",
-                            ctx=ast.Load(),
-                        ),
-                        args=[
-                            ast.Subscript(
-                                value=ast.Attribute(
-                                    value=ast.Name(id=v, ctx=ast.Load()),
-                                    attr="shape",
+                        func=ast.Attribute(
+                            value=ast.Call(
+                                ast.Attribute(
+                                    value=ast.Name(id="torch", ctx=ast.Load()),
+                                    attr="arange",
                                     ctx=ast.Load(),
                                 ),
-                                slice=ast.Constant(value=0, ctx=ast.Load()),
+                                args=[
+                                    ast.Subscript(
+                                        value=ast.Attribute(
+                                            value=ast.Name(id=v, ctx=ast.Load()),
+                                            attr="shape",
+                                            ctx=ast.Load(),
+                                        ),
+                                        slice=ast.Constant(value=0, ctx=ast.Load()),
+                                        ctx=ast.Load(),
+                                    ),
+                                ],
+                                keywords=[
+                                    ast.keyword(
+                                        arg="dtype",
+                                        value=ast.Attribute(
+                                            value=ast.Name(id="torch", ctx=ast.Load()),
+                                            attr="int64",
+                                            ctx=ast.Load(),
+                                        ),
+                                    )
+                                ],
                                 ctx=ast.Load(),
                             ),
-                        ],
-                        keywords=[
-                            ast.keyword(
-                                arg="dtype",
-                                value=ast.Attribute(
-                                    value=ast.Name(id="torch", ctx=ast.Load()),
-                                    attr="int64",
-                                    ctx=ast.Load(),
-                                ),
-                            )
-                        ],
-                        ctx=ast.Load(),
+                            attr="unsqueeze",
+                            ctx=ast.Load(),
+                        ),
+                        args=[ast.Constant(value=1)],
+                        keywords=[],
                     )
                     for v in scan_shape_vars
                 ],
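Unparsed, the rewritten AST corresponds to torch.arange(v.shape[0], dtype=torch.int64).unsqueeze(1) for each v in scan_shape_vars; the removed version built the same arange without the trailing .unsqueeze(1). A quick way to inspect the target shape of the tree (illustrative, with v named x):

    import ast

    src = "torch.arange(x.shape[0], dtype=torch.int64).unsqueeze(1)"
    print(ast.dump(ast.parse(src, mode="eval").body, indent=2))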