onnx-diagnostic 0.4.2__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +5 -0
- onnx_diagnostic/tasks/__init__.py +2 -0
- onnx_diagnostic/tasks/image_classification.py +21 -0
- onnx_diagnostic/tasks/object_detection.py +123 -0
- onnx_diagnostic/tasks/text_generation.py +13 -10
- onnx_diagnostic/torch_export_patches/__init__.py +17 -1
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +27 -29
- onnx_diagnostic/torch_export_patches/patch_module.py +304 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +40 -4
- onnx_diagnostic/torch_models/hghub/hub_data.py +13 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +210 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -9
- onnx_diagnostic/torch_models/test_helper.py +82 -28
- {onnx_diagnostic-0.4.2.dist-info → onnx_diagnostic-0.4.4.dist-info}/METADATA +8 -3
- {onnx_diagnostic-0.4.2.dist-info → onnx_diagnostic-0.4.4.dist-info}/RECORD +19 -17
- {onnx_diagnostic-0.4.2.dist-info → onnx_diagnostic-0.4.4.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.4.2.dist-info → onnx_diagnostic-0.4.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.4.2.dist-info → onnx_diagnostic-0.4.4.dist-info}/top_level.txt +0 -0
onnx_diagnostic/__init__.py
CHANGED
|
@@ -336,6 +336,10 @@ def get_parser_validate() -> ArgumentParser:
|
|
|
336
336
|
help="drops the following inputs names, it should be a list "
|
|
337
337
|
"with comma separated values",
|
|
338
338
|
)
|
|
339
|
+
parser.add_argument(
|
|
340
|
+
"--subfolder",
|
|
341
|
+
help="subfolder where to find the model and the configuration",
|
|
342
|
+
)
|
|
339
343
|
parser.add_argument(
|
|
340
344
|
"--ortfusiontype",
|
|
341
345
|
required=False,
|
|
@@ -413,6 +417,7 @@ def _cmd_validate(argv: List[Any]):
|
|
|
413
417
|
ortfusiontype=args.ortfusiontype,
|
|
414
418
|
input_options=args.iop,
|
|
415
419
|
model_options=args.mop,
|
|
420
|
+
subfolder=args.subfolder,
|
|
416
421
|
)
|
|
417
422
|
print("")
|
|
418
423
|
print("-- summary --")
|
|
@@ -6,6 +6,7 @@ from . import (
|
|
|
6
6
|
image_classification,
|
|
7
7
|
image_text_to_text,
|
|
8
8
|
mixture_of_expert,
|
|
9
|
+
object_detection,
|
|
9
10
|
sentence_similarity,
|
|
10
11
|
text_classification,
|
|
11
12
|
text_generation,
|
|
@@ -20,6 +21,7 @@ __TASKS__ = [
|
|
|
20
21
|
image_classification,
|
|
21
22
|
image_text_to_text,
|
|
22
23
|
mixture_of_expert,
|
|
24
|
+
object_detection,
|
|
23
25
|
sentence_similarity,
|
|
24
26
|
text_classification,
|
|
25
27
|
text_generation,
|
|
@@ -7,6 +7,13 @@ __TASK__ = "image-classification"
|
|
|
7
7
|
|
|
8
8
|
def reduce_model_config(config: Any) -> Dict[str, Any]:
|
|
9
9
|
"""Reduces a model size."""
|
|
10
|
+
if (
|
|
11
|
+
hasattr(config, "model_type")
|
|
12
|
+
and config.model_type == "timm_wrapper"
|
|
13
|
+
and not hasattr(config, "num_hidden_layers")
|
|
14
|
+
):
|
|
15
|
+
# We cannot reduce.
|
|
16
|
+
return {}
|
|
10
17
|
check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
|
|
11
18
|
kwargs = dict(
|
|
12
19
|
num_hidden_layers=(
|
|
@@ -82,6 +89,20 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
|
|
|
82
89
|
If the configuration is None, the function selects typical dimensions.
|
|
83
90
|
"""
|
|
84
91
|
if config is not None:
|
|
92
|
+
if (
|
|
93
|
+
hasattr(config, "model_type")
|
|
94
|
+
and config.model_type == "timm_wrapper"
|
|
95
|
+
and not hasattr(config, "num_hidden_layers")
|
|
96
|
+
):
|
|
97
|
+
input_size = config.pretrained_cfg["input_size"]
|
|
98
|
+
kwargs = dict(
|
|
99
|
+
batch_size=2,
|
|
100
|
+
input_width=input_size[-2],
|
|
101
|
+
input_height=input_size[-1],
|
|
102
|
+
input_channels=input_size[-3],
|
|
103
|
+
)
|
|
104
|
+
return kwargs, get_inputs
|
|
105
|
+
|
|
85
106
|
check_hasattr(config, ("image_size", "architectures"), "num_channels")
|
|
86
107
|
if config is not None:
|
|
87
108
|
if hasattr(config, "image_size"):
|
|
@@ -0,0 +1,123 @@
|
|
|
1
|
+
from typing import Any, Callable, Dict, Optional, Tuple
|
|
2
|
+
import torch
|
|
3
|
+
from ..helpers.config_helper import update_config, check_hasattr
|
|
4
|
+
|
|
5
|
+
__TASK__ = "object-detection"
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def reduce_model_config(config: Any) -> Dict[str, Any]:
|
|
9
|
+
"""Reduces a model size."""
|
|
10
|
+
check_hasattr(config, ("num_hidden_layers", "hidden_sizes"))
|
|
11
|
+
kwargs = dict(
|
|
12
|
+
num_hidden_layers=(
|
|
13
|
+
min(config.num_hidden_layers, 2)
|
|
14
|
+
if hasattr(config, "num_hidden_layers")
|
|
15
|
+
else len(config.hidden_sizes)
|
|
16
|
+
)
|
|
17
|
+
)
|
|
18
|
+
update_config(config, kwargs)
|
|
19
|
+
return kwargs
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def get_inputs(
|
|
23
|
+
model: torch.nn.Module,
|
|
24
|
+
config: Optional[Any],
|
|
25
|
+
input_width: int,
|
|
26
|
+
input_height: int,
|
|
27
|
+
input_channels: int,
|
|
28
|
+
batch_size: int = 2,
|
|
29
|
+
dynamic_rope: bool = False,
|
|
30
|
+
add_second_input: bool = False,
|
|
31
|
+
**kwargs, # unused
|
|
32
|
+
):
|
|
33
|
+
"""
|
|
34
|
+
Generates inputs for task ``object-detection``.
|
|
35
|
+
|
|
36
|
+
:param model: model to get the missing information
|
|
37
|
+
:param config: configuration used to generate the model
|
|
38
|
+
:param batch_size: batch size
|
|
39
|
+
:param input_channels: input channel
|
|
40
|
+
:param input_width: input width
|
|
41
|
+
:param input_height: input height
|
|
42
|
+
:return: dictionary
|
|
43
|
+
"""
|
|
44
|
+
assert isinstance(
|
|
45
|
+
input_width, int
|
|
46
|
+
), f"Unexpected type for input_width {type(input_width)}{config}"
|
|
47
|
+
assert isinstance(
|
|
48
|
+
input_width, int
|
|
49
|
+
), f"Unexpected type for input_height {type(input_height)}{config}"
|
|
50
|
+
|
|
51
|
+
shapes = {
|
|
52
|
+
"pixel_values": {
|
|
53
|
+
0: torch.export.Dim("batch", min=1, max=1024),
|
|
54
|
+
2: "width",
|
|
55
|
+
3: "height",
|
|
56
|
+
}
|
|
57
|
+
}
|
|
58
|
+
inputs = dict(
|
|
59
|
+
pixel_values=torch.randn(batch_size, input_channels, input_width, input_height).clamp(
|
|
60
|
+
-1, 1
|
|
61
|
+
),
|
|
62
|
+
)
|
|
63
|
+
res = dict(inputs=inputs, dynamic_shapes=shapes)
|
|
64
|
+
if add_second_input:
|
|
65
|
+
res["inputs2"] = get_inputs(
|
|
66
|
+
model=model,
|
|
67
|
+
config=config,
|
|
68
|
+
input_width=input_width + 1,
|
|
69
|
+
input_height=input_height + 1,
|
|
70
|
+
input_channels=input_channels,
|
|
71
|
+
batch_size=batch_size + 1,
|
|
72
|
+
dynamic_rope=dynamic_rope,
|
|
73
|
+
**kwargs,
|
|
74
|
+
)["inputs"]
|
|
75
|
+
return res
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
|
|
79
|
+
"""
|
|
80
|
+
Inputs kwargs.
|
|
81
|
+
|
|
82
|
+
If the configuration is None, the function selects typical dimensions.
|
|
83
|
+
"""
|
|
84
|
+
if config is not None:
|
|
85
|
+
if (
|
|
86
|
+
hasattr(config, "model_type")
|
|
87
|
+
and config.model_type == "timm_wrapper"
|
|
88
|
+
and not hasattr(config, "num_hidden_layers")
|
|
89
|
+
):
|
|
90
|
+
input_size = config.pretrained_cfg["input_size"]
|
|
91
|
+
kwargs = dict(
|
|
92
|
+
batch_size=2,
|
|
93
|
+
input_width=input_size[-2],
|
|
94
|
+
input_height=input_size[-1],
|
|
95
|
+
input_channels=input_size[-3],
|
|
96
|
+
)
|
|
97
|
+
return kwargs, get_inputs
|
|
98
|
+
|
|
99
|
+
check_hasattr(config, ("image_size", "architectures"), "num_channels")
|
|
100
|
+
if config is not None:
|
|
101
|
+
if hasattr(config, "image_size"):
|
|
102
|
+
image_size = config.image_size
|
|
103
|
+
else:
|
|
104
|
+
assert config.architectures, f"empty architecture in {config}"
|
|
105
|
+
from ..torch_models.hghub.hub_api import get_architecture_default_values
|
|
106
|
+
|
|
107
|
+
default_values = get_architecture_default_values(config.architectures[0])
|
|
108
|
+
image_size = default_values["image_size"]
|
|
109
|
+
if config is None or isinstance(image_size, int):
|
|
110
|
+
kwargs = dict(
|
|
111
|
+
batch_size=2,
|
|
112
|
+
input_width=224 if config is None else image_size,
|
|
113
|
+
input_height=224 if config is None else image_size,
|
|
114
|
+
input_channels=3 if config is None else config.num_channels,
|
|
115
|
+
)
|
|
116
|
+
else:
|
|
117
|
+
kwargs = dict(
|
|
118
|
+
batch_size=2,
|
|
119
|
+
input_width=config.image_size[0],
|
|
120
|
+
input_height=config.image_size[1],
|
|
121
|
+
input_channels=config.num_channels,
|
|
122
|
+
)
|
|
123
|
+
return kwargs, get_inputs
|
|
@@ -19,12 +19,11 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
|
|
|
19
19
|
("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
|
|
20
20
|
"num_hidden_layers",
|
|
21
21
|
("num_key_value_heads", "num_attention_heads", "use_mambapy"),
|
|
22
|
-
"intermediate_size",
|
|
23
22
|
"hidden_size",
|
|
24
23
|
"vocab_size",
|
|
25
24
|
)
|
|
26
25
|
if config.__class__.__name__ == "FalconMambaConfig":
|
|
27
|
-
check_hasattr(config, "conv_kernel", "state_size") # 4 and 8
|
|
26
|
+
check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
|
|
28
27
|
kwargs = dict(
|
|
29
28
|
num_hidden_layers=min(config.num_hidden_layers, 2),
|
|
30
29
|
intermediate_size=256 if config is None else min(512, config.intermediate_size),
|
|
@@ -44,17 +43,18 @@ def reduce_model_config(config: Any) -> Dict[str, Any]:
|
|
|
44
43
|
if hasattr(config, "num_key_value_heads")
|
|
45
44
|
else config.num_attention_heads
|
|
46
45
|
),
|
|
47
|
-
intermediate_size=(
|
|
48
|
-
min(config.intermediate_size, 24576 // 4)
|
|
49
|
-
if config.intermediate_size % 4 == 0
|
|
50
|
-
else config.intermediate_size
|
|
51
|
-
),
|
|
52
46
|
hidden_size=(
|
|
53
47
|
min(config.hidden_size, 3072 // 4)
|
|
54
48
|
if config.hidden_size % 4 == 0
|
|
55
49
|
else config.hidden_size
|
|
56
50
|
),
|
|
57
51
|
)
|
|
52
|
+
if config is None or hasattr(config, "intermediate_size"):
|
|
53
|
+
kwargs["intermediate_size"] = (
|
|
54
|
+
min(config.intermediate_size, 24576 // 4)
|
|
55
|
+
if config.intermediate_size % 4 == 0
|
|
56
|
+
else config.intermediate_size
|
|
57
|
+
)
|
|
58
58
|
update_config(config, kwargs)
|
|
59
59
|
return kwargs
|
|
60
60
|
|
|
@@ -228,11 +228,10 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
|
|
|
228
228
|
"vocab_size",
|
|
229
229
|
("num_attention_heads", "use_mambapy"),
|
|
230
230
|
("num_key_value_heads", "num_attention_heads", "use_mambapy"),
|
|
231
|
-
"intermediate_size",
|
|
232
231
|
"hidden_size",
|
|
233
232
|
)
|
|
234
233
|
if config.__class__.__name__ == "FalconMambaConfig":
|
|
235
|
-
check_hasattr(config, "conv_kernel", "state_size") # 4 and 8
|
|
234
|
+
check_hasattr(config, "conv_kernel", "state_size", "intermediate_size") # 4 and 8
|
|
236
235
|
kwargs = dict(
|
|
237
236
|
batch_size=2,
|
|
238
237
|
sequence_length=30,
|
|
@@ -263,7 +262,11 @@ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
|
|
|
263
262
|
if config is None
|
|
264
263
|
else _pick(config, "num_key_value_heads", "num_attention_heads")
|
|
265
264
|
),
|
|
266
|
-
intermediate_size=1024 if config is None else config.intermediate_size,
|
|
267
265
|
hidden_size=512 if config is None else config.hidden_size,
|
|
268
266
|
)
|
|
267
|
+
if config is None or hasattr(config, "intermediate_size"):
|
|
268
|
+
kwargs["intermediate_size"] = (
|
|
269
|
+
1024 if config is None else config.intermediate_size,
|
|
270
|
+
)
|
|
271
|
+
|
|
269
272
|
return kwargs, get_inputs
|
|
@@ -1,4 +1,20 @@
|
|
|
1
1
|
from .onnx_export_errors import (
|
|
2
|
-
|
|
2
|
+
torch_export_patches,
|
|
3
3
|
register_additional_serialization_functions,
|
|
4
4
|
)
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# bypass_export_some_errors is the first name given to the patches.
|
|
8
|
+
bypass_export_some_errors = torch_export_patches # type: ignore
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def register_flattening_functions(verbose: int = 0):
|
|
12
|
+
"""
|
|
13
|
+
Registers functions to serialize deserialize cache or other classes
|
|
14
|
+
implemented in :epkg:`transformers` and used as inputs.
|
|
15
|
+
This is needed whenever a model must be exported through
|
|
16
|
+
:func:`torch.export.export`.
|
|
17
|
+
"""
|
|
18
|
+
from .onnx_export_serialization import _register_cache_serialization
|
|
19
|
+
|
|
20
|
+
return _register_cache_serialization(verbose=verbose)
|
|
@@ -93,7 +93,7 @@ def register_additional_serialization_functions(
|
|
|
93
93
|
|
|
94
94
|
|
|
95
95
|
@contextlib.contextmanager
|
|
96
|
-
def bypass_export_some_errors(
|
|
96
|
+
def torch_export_patches(
|
|
97
97
|
patch_sympy: bool = True,
|
|
98
98
|
patch_torch: bool = True,
|
|
99
99
|
patch_transformers: bool = False,
|
|
@@ -145,13 +145,13 @@ def bypass_export_some_errors(
|
|
|
145
145
|
|
|
146
146
|
::
|
|
147
147
|
|
|
148
|
-
with bypass_export_some_errors(patch_transformers=True) as modificator:
|
|
148
|
+
with torch_export_patches(patch_transformers=True) as modificator:
|
|
149
149
|
inputs = modificator(inputs)
|
|
150
150
|
onx = to_onnx(..., inputs, ...)
|
|
151
151
|
|
|
152
152
|
::
|
|
153
153
|
|
|
154
|
-
with bypass_export_some_errors(patch_transformers=True) as modificator:
|
|
154
|
+
with torch_export_patches(patch_transformers=True) as modificator:
|
|
155
155
|
inputs = modificator(inputs)
|
|
156
156
|
onx = torch.onnx.export(..., inputs, ...)
|
|
157
157
|
|
|
@@ -159,7 +159,7 @@ def bypass_export_some_errors(
|
|
|
159
159
|
|
|
160
160
|
::
|
|
161
161
|
|
|
162
|
-
with bypass_export_some_errors(patch_transformers=True) as modificator:
|
|
162
|
+
with torch_export_patches(patch_transformers=True) as modificator:
|
|
163
163
|
inputs = modificator(inputs)
|
|
164
164
|
ep = torch.export.export(..., inputs, ...)
|
|
165
165
|
|
|
@@ -190,7 +190,7 @@ def bypass_export_some_errors(
|
|
|
190
190
|
|
|
191
191
|
if verbose:
|
|
192
192
|
print(
|
|
193
|
-
"[
|
|
193
|
+
"[torch_export_patches] replace torch.jit.isinstance, "
|
|
194
194
|
"torch._dynamo.mark_static_address"
|
|
195
195
|
)
|
|
196
196
|
|
|
@@ -210,8 +210,8 @@ def bypass_export_some_errors(
|
|
|
210
210
|
f_sympy_name = getattr(sympy.core.numbers.IntegerConstant, "name", None)
|
|
211
211
|
|
|
212
212
|
if verbose:
|
|
213
|
-
print(f"[bypass_export_some_errors] sympy.__version__={sympy.__version__!r}")
|
|
214
|
-
print("[bypass_export_some_errors] patch sympy")
|
|
213
|
+
print(f"[torch_export_patches] sympy.__version__={sympy.__version__!r}")
|
|
214
|
+
print("[torch_export_patches] patch sympy")
|
|
215
215
|
|
|
216
216
|
sympy.core.numbers.IntegerConstant.name = lambda self: f"IntCst{str(self)}"
|
|
217
217
|
|
|
@@ -228,9 +228,9 @@ def bypass_export_some_errors(
|
|
|
228
228
|
)
|
|
229
229
|
|
|
230
230
|
if verbose:
|
|
231
|
-
print(f"[bypass_export_some_errors] torch.__version__={torch.__version__!r}")
|
|
232
|
-
print(f"[bypass_export_some_errors] stop_if_static={stop_if_static!r}")
|
|
233
|
-
print("[bypass_export_some_errors] patch pytorch")
|
|
231
|
+
print(f"[torch_export_patches] torch.__version__={torch.__version__!r}")
|
|
232
|
+
print(f"[torch_export_patches] stop_if_static={stop_if_static!r}")
|
|
233
|
+
print("[torch_export_patches] patch pytorch")
|
|
234
234
|
|
|
235
235
|
# torch.jit.isinstance
|
|
236
236
|
f_jit_isinstance = torch.jit.isinstance
|
|
@@ -252,7 +252,7 @@ def bypass_export_some_errors(
|
|
|
252
252
|
# torch._export.non_strict_utils.produce_guards_and_solve_constraints
|
|
253
253
|
if catch_constraints:
|
|
254
254
|
if verbose:
|
|
255
|
-
print("[bypass_export_some_errors] modifies shape constraints")
|
|
255
|
+
print("[torch_export_patches] modifies shape constraints")
|
|
256
256
|
f_produce_guards_and_solve_constraints = (
|
|
257
257
|
torch._export.non_strict_utils.produce_guards_and_solve_constraints
|
|
258
258
|
)
|
|
@@ -277,22 +277,20 @@ def bypass_export_some_errors(
|
|
|
277
277
|
ShapeEnv._log_guard_remember = ShapeEnv._log_guard
|
|
278
278
|
|
|
279
279
|
if verbose:
|
|
280
|
-
print(
|
|
281
|
-
|
|
282
|
-
)
|
|
283
|
-
print("[bypass_export_some_errors] replaces ShapeEnv._set_replacement")
|
|
280
|
+
print("[torch_export_patches] assert when a dynamic dimension turns static")
|
|
281
|
+
print("[torch_export_patches] replaces ShapeEnv._set_replacement")
|
|
284
282
|
|
|
285
283
|
f_shape_env__set_replacement = ShapeEnv._set_replacement
|
|
286
284
|
ShapeEnv._set_replacement = patched_ShapeEnv._set_replacement
|
|
287
285
|
|
|
288
286
|
if verbose:
|
|
289
|
-
print("[bypass_export_some_errors] replaces ShapeEnv._log_guard")
|
|
287
|
+
print("[torch_export_patches] replaces ShapeEnv._log_guard")
|
|
290
288
|
f_shape_env__log_guard = ShapeEnv._log_guard
|
|
291
289
|
ShapeEnv._log_guard = patched_ShapeEnv._log_guard
|
|
292
290
|
|
|
293
291
|
if stop_if_static > 1:
|
|
294
292
|
if verbose:
|
|
295
|
-
print("[bypass_export_some_errors] replaces ShapeEnv._check_frozen")
|
|
293
|
+
print("[torch_export_patches] replaces ShapeEnv._check_frozen")
|
|
296
294
|
f_shape_env__check_frozen = ShapeEnv._check_frozen
|
|
297
295
|
ShapeEnv._check_frozen = patched_ShapeEnv._check_frozen
|
|
298
296
|
|
|
@@ -305,7 +303,7 @@ def bypass_export_some_errors(
|
|
|
305
303
|
import transformers
|
|
306
304
|
|
|
307
305
|
print(
|
|
308
|
-
f"[bypass_export_some_errors] transformers.__version__="
|
|
306
|
+
f"[torch_export_patches] transformers.__version__="
|
|
309
307
|
f"{transformers.__version__!r}"
|
|
310
308
|
)
|
|
311
309
|
revert_patches_info = patch_module_or_classes(
|
|
@@ -314,7 +312,7 @@ def bypass_export_some_errors(
|
|
|
314
312
|
|
|
315
313
|
if custom_patches:
|
|
316
314
|
if verbose:
|
|
317
|
-
print("[bypass_export_some_errors] applies custom patches")
|
|
315
|
+
print("[torch_export_patches] applies custom patches")
|
|
318
316
|
revert_custom_patches_info = patch_module_or_classes(
|
|
319
317
|
custom_patches, verbose=verbose
|
|
320
318
|
)
|
|
@@ -326,7 +324,7 @@ def bypass_export_some_errors(
|
|
|
326
324
|
fct_callable = replacement_before_exporting if patch_transformers else (lambda x: x)
|
|
327
325
|
|
|
328
326
|
if verbose:
|
|
329
|
-
print("[bypass_export_some_errors] done patching")
|
|
327
|
+
print("[torch_export_patches] done patching")
|
|
330
328
|
|
|
331
329
|
try:
|
|
332
330
|
yield fct_callable
|
|
@@ -336,7 +334,7 @@ def bypass_export_some_errors(
|
|
|
336
334
|
#######
|
|
337
335
|
|
|
338
336
|
if verbose:
|
|
339
|
-
print("[bypass_export_some_errors] remove patches")
|
|
337
|
+
print("[torch_export_patches] remove patches")
|
|
340
338
|
|
|
341
339
|
if patch_sympy:
|
|
342
340
|
# tracked by https://github.com/pytorch/pytorch/issues/143494
|
|
@@ -346,7 +344,7 @@ def bypass_export_some_errors(
|
|
|
346
344
|
delattr(sympy.core.numbers.IntegerConstant, "name")
|
|
347
345
|
|
|
348
346
|
if verbose:
|
|
349
|
-
print("[bypass_export_some_errors] restored sympy functions")
|
|
347
|
+
print("[torch_export_patches] restored sympy functions")
|
|
350
348
|
|
|
351
349
|
#######
|
|
352
350
|
# torch
|
|
@@ -362,22 +360,22 @@ def bypass_export_some_errors(
|
|
|
362
360
|
torch._meta_registrations._broadcast_shapes = f__broadcast_shapes
|
|
363
361
|
|
|
364
362
|
if verbose:
|
|
365
|
-
print("[bypass_export_some_errors] restored pytorch functions")
|
|
363
|
+
print("[torch_export_patches] restored pytorch functions")
|
|
366
364
|
|
|
367
365
|
if stop_if_static:
|
|
368
366
|
if verbose:
|
|
369
|
-
print("[bypass_export_some_errors] restored ShapeEnv._set_replacement")
|
|
367
|
+
print("[torch_export_patches] restored ShapeEnv._set_replacement")
|
|
370
368
|
|
|
371
369
|
ShapeEnv._set_replacement = f_shape_env__set_replacement
|
|
372
370
|
|
|
373
371
|
if verbose:
|
|
374
|
-
print("[bypass_export_some_errors] restored ShapeEnv._log_guard")
|
|
372
|
+
print("[torch_export_patches] restored ShapeEnv._log_guard")
|
|
375
373
|
|
|
376
374
|
ShapeEnv._log_guard = f_shape_env__log_guard
|
|
377
375
|
|
|
378
376
|
if stop_if_static > 1:
|
|
379
377
|
if verbose:
|
|
380
|
-
print("[bypass_export_some_errors] restored ShapeEnv._check_frozen")
|
|
378
|
+
print("[torch_export_patches] restored ShapeEnv._check_frozen")
|
|
381
379
|
ShapeEnv._check_frozen = f_shape_env__check_frozen
|
|
382
380
|
|
|
383
381
|
if catch_constraints:
|
|
@@ -389,11 +387,11 @@ def bypass_export_some_errors(
|
|
|
389
387
|
f__check_input_constraints_for_graph
|
|
390
388
|
)
|
|
391
389
|
if verbose:
|
|
392
|
-
print("[bypass_export_some_errors] restored shape constraints")
|
|
390
|
+
print("[torch_export_patches] restored shape constraints")
|
|
393
391
|
|
|
394
392
|
if custom_patches:
|
|
395
393
|
if verbose:
|
|
396
|
-
print("[bypass_export_some_errors] unpatch custom patches")
|
|
394
|
+
print("[torch_export_patches] unpatch custom patches")
|
|
397
395
|
unpatch_module_or_classes(
|
|
398
396
|
custom_patches, revert_custom_patches_info, verbose=verbose
|
|
399
397
|
)
|
|
@@ -404,7 +402,7 @@ def bypass_export_some_errors(
|
|
|
404
402
|
|
|
405
403
|
if patch_transformers:
|
|
406
404
|
if verbose:
|
|
407
|
-
print("[bypass_export_some_errors] unpatch transformers")
|
|
405
|
+
print("[torch_export_patches] unpatch transformers")
|
|
408
406
|
unpatch_module_or_classes(
|
|
409
407
|
patch_transformers_list, revert_patches_info, verbose=verbose
|
|
410
408
|
)
|