onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,108 @@
|
|
|
1
|
+
from typing import Callable, Set
|
|
2
|
+
import torch
|
|
3
|
+
from ..helpers.torch_helper import is_torchdynamo_exporting
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def make_undefined_dimension(i: int) -> torch.SymInt:
|
|
7
|
+
"""
|
|
8
|
+
Uses for a custom op when a new dimension must be introduced to bypass
|
|
9
|
+
some verification. The following function creates a dummy output
|
|
10
|
+
with a dimension based on the content.
|
|
11
|
+
|
|
12
|
+
.. code-block:: python
|
|
13
|
+
|
|
14
|
+
def symbolic_shape(x, y):
|
|
15
|
+
return torch.empty(
|
|
16
|
+
x.shape[0],
|
|
17
|
+
make_undefined_dimension(min(x.shape[1], y[0])),
|
|
18
|
+
)
|
|
19
|
+
"""
|
|
20
|
+
try:
|
|
21
|
+
ti = int(i)
|
|
22
|
+
except: # noqa: E722
|
|
23
|
+
ti = 10
|
|
24
|
+
t = torch.ones((ti * 2,))
|
|
25
|
+
t[:ti] = 0
|
|
26
|
+
res = torch.nonzero(t).shape[0]
|
|
27
|
+
return res
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _patched_float_arange(
|
|
31
|
+
start: torch.Tensor, end: torch.Tensor, step: torch.Tensor
|
|
32
|
+
) -> torch.Tensor:
|
|
33
|
+
"""Float arange."""
|
|
34
|
+
return torch.arange(
|
|
35
|
+
float(start.item()),
|
|
36
|
+
float(end.item()),
|
|
37
|
+
float(step.item()),
|
|
38
|
+
dtype=start.dtype,
|
|
39
|
+
device=start.device,
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _patched_float_arange_shape(start, end, step):
|
|
44
|
+
# Fails because:
|
|
45
|
+
# Did you accidentally call new_dynamic_size() or item()
|
|
46
|
+
# more times than you needed to in your fake implementation?
|
|
47
|
+
# try:
|
|
48
|
+
# n = math.ceil(((end - start) / step).item())
|
|
49
|
+
# except: # noqa: E722
|
|
50
|
+
# n = 10
|
|
51
|
+
n = 10
|
|
52
|
+
return torch.empty((make_undefined_dimension(n),), dtype=start.dtype, device=start.device)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _iterate_patched_expressions():
|
|
56
|
+
glo = globals().copy()
|
|
57
|
+
for k, _v in glo.items():
|
|
58
|
+
if k.startswith("_patched_") and not k.endswith("_shape"):
|
|
59
|
+
name = k
|
|
60
|
+
yield k[len("_patched_") :], glo[name], glo[f"{name}_shape"]
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
_registered: Set[str] = set()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _register_patched_expression(
|
|
67
|
+
fct: Callable, fct_shape: Callable, namespace: str, fname: str
|
|
68
|
+
):
|
|
69
|
+
schema_str = torch.library.infer_schema(fct, mutates_args=())
|
|
70
|
+
custom_def = torch.library.CustomOpDef(namespace, fname, schema_str, fct)
|
|
71
|
+
custom_def.register_kernel("cpu")(fct)
|
|
72
|
+
custom_def._abstract_fn = fct_shape
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def register_patched_expressions(namespace: str = "patched"):
|
|
76
|
+
"""
|
|
77
|
+
Registers as custom ops known expressions failing due to dynamic shapes.
|
|
78
|
+
|
|
79
|
+
.. runpython::
|
|
80
|
+
:showcode:
|
|
81
|
+
|
|
82
|
+
import pprint
|
|
83
|
+
from onnx_diagnostic.torch_export_patches.patch_expressions import (
|
|
84
|
+
_iterate_patched_expressions,
|
|
85
|
+
)
|
|
86
|
+
|
|
87
|
+
pprint.pprint([name for name, _f, _fsh in _iterate_patched_expressions()])
|
|
88
|
+
"""
|
|
89
|
+
for name, f, fsh in _iterate_patched_expressions():
|
|
90
|
+
if name not in _registered:
|
|
91
|
+
_register_patched_expression(f, fsh, namespace, name)
|
|
92
|
+
_registered.add(name)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def patched_selector(fct: Callable, patched_fct: Callable) -> Callable:
|
|
96
|
+
"""
|
|
97
|
+
Returns **fct** if the model is being executed or
|
|
98
|
+
**patched_fct** if it is being exported.
|
|
99
|
+
"""
|
|
100
|
+
return patched_fct if is_torchdynamo_exporting() else fct
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def patched_float_arange(start, end, step):
|
|
104
|
+
"""Patched arange when start, end, step are floats."""
|
|
105
|
+
if is_torchdynamo_exporting():
|
|
106
|
+
return torch.ops.patched.float_arange(start, end, step)
|
|
107
|
+
else:
|
|
108
|
+
return torch.arange(start, end, step)
|
|
@@ -0,0 +1,211 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
from typing import Any, Dict, Optional, Tuple
|
|
3
|
+
import torch
|
|
4
|
+
import transformers
|
|
5
|
+
from ..helpers import string_type
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def _process_cache(k: str, v):
|
|
9
|
+
assert k != "position_ids" or isinstance(
|
|
10
|
+
k, torch.Tensor
|
|
11
|
+
), f"Unexpected type for parameter {k!r} {string_type(v, with_shape=True)}"
|
|
12
|
+
if (
|
|
13
|
+
isinstance(v, list)
|
|
14
|
+
and all(isinstance(i, tuple) for i in v)
|
|
15
|
+
and set(len(t) for t in v) == {2}
|
|
16
|
+
):
|
|
17
|
+
# A dynamicCache
|
|
18
|
+
from ..helpers.cache_helper import make_dynamic_cache
|
|
19
|
+
|
|
20
|
+
cache = make_dynamic_cache(v)
|
|
21
|
+
return cache
|
|
22
|
+
if isinstance(v, torch.Tensor):
|
|
23
|
+
return v
|
|
24
|
+
raise NotImplementedError(
|
|
25
|
+
f"Unable to process parameter {k!r} with v={string_type(v,with_shape=True)}"
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _make_shape(subset: Dict, cls: type, value: Any) -> Any:
|
|
30
|
+
if cls is transformers.cache_utils.DynamicCache:
|
|
31
|
+
assert subset, "DynamicCache cannot be empty"
|
|
32
|
+
values = set(map(str, subset.values()))
|
|
33
|
+
assert len(values) == 1, (
|
|
34
|
+
f"Inconsistencies in subset={subset}, found={values}, "
|
|
35
|
+
f"it cannot be a {cls}, value={string_type(value)}"
|
|
36
|
+
)
|
|
37
|
+
cache_length = len(value.layers if hasattr(value, "layers") else value.key_cache)
|
|
38
|
+
for v in subset.values():
|
|
39
|
+
axes = v
|
|
40
|
+
break
|
|
41
|
+
new_shape = [axes for i in range(cache_length * 2)]
|
|
42
|
+
return new_shape
|
|
43
|
+
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
44
|
+
raise NotImplementedError(
|
|
45
|
+
f"_make_shape not implemented for registered class={cls}, "
|
|
46
|
+
f"subset={subset}, value={string_type(value)}"
|
|
47
|
+
)
|
|
48
|
+
raise NotImplementedError(
|
|
49
|
+
f"_make_shape not implemented for cls={cls}, "
|
|
50
|
+
f"subset={subset}, value={string_type(value)}"
|
|
51
|
+
)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def convert_dynamic_axes_into_dynamic_shapes(
|
|
55
|
+
model: torch.nn.Module,
|
|
56
|
+
args: Optional[Tuple[Any, ...]] = None,
|
|
57
|
+
kwargs: Optional[Dict[str, Any]] = None,
|
|
58
|
+
dynamic_axes: Optional[Dict[str, Dict[int, str]]] = None,
|
|
59
|
+
prefix_mapping: Optional[Dict[str, str]] = None,
|
|
60
|
+
verbose: int = 0,
|
|
61
|
+
) -> Tuple[Tuple[Any, ...], Dict[str, Any], Dict[str, Any]]:
|
|
62
|
+
"""
|
|
63
|
+
Converts the input from an export to something :func:`torch.export.export` can handle.
|
|
64
|
+
|
|
65
|
+
:param model: model to convert (used to extract the signature)
|
|
66
|
+
:param args: positional arguments
|
|
67
|
+
:param kwargs: named arguments
|
|
68
|
+
:param dynamic_axes: dynamic axes
|
|
69
|
+
:param prefix_mapping: prefix mapping
|
|
70
|
+
:param verbose: verbosity
|
|
71
|
+
:return: (args, kwargs, dynamic shapes)
|
|
72
|
+
"""
|
|
73
|
+
from ..helpers.cache_helper import CacheKeyValue
|
|
74
|
+
|
|
75
|
+
new_kwargs = {}
|
|
76
|
+
if args:
|
|
77
|
+
assert hasattr(model, "forward"), f"Missing method 'forward' for {model!r}"
|
|
78
|
+
plus = 0 if isinstance(model, torch.nn.Module) else 1
|
|
79
|
+
print(
|
|
80
|
+
f"[convert_dynamic_axes_into_dynamic_shapes] "
|
|
81
|
+
f"mapping args to kwargs for model="
|
|
82
|
+
f"{model if plus else model.__class__.__name__}"
|
|
83
|
+
)
|
|
84
|
+
pars = inspect.signature(model.forward).parameters
|
|
85
|
+
assert len(pars) >= len(
|
|
86
|
+
args
|
|
87
|
+
), f"Length mismatch, len(args)={len(args)}, pars={list(pars)}"
|
|
88
|
+
|
|
89
|
+
for i, p in enumerate(pars):
|
|
90
|
+
if i < plus:
|
|
91
|
+
continue
|
|
92
|
+
if i - plus >= len(args):
|
|
93
|
+
break
|
|
94
|
+
if verbose:
|
|
95
|
+
print(
|
|
96
|
+
f"[convert_dynamic_axes_into_dynamic_shapes] mapping args[{i-plus}] "
|
|
97
|
+
f"to {p!r} ({string_type(args[i-plus])})"
|
|
98
|
+
)
|
|
99
|
+
new_kwargs[p] = args[i - plus]
|
|
100
|
+
|
|
101
|
+
if kwargs:
|
|
102
|
+
for k, v in kwargs.items():
|
|
103
|
+
assert k not in new_kwargs, f"Argument {k!r} from kwargs already present in args."
|
|
104
|
+
new_kwargs[k] = v
|
|
105
|
+
|
|
106
|
+
# process
|
|
107
|
+
updated_kwargs = {}
|
|
108
|
+
changes = {}
|
|
109
|
+
for k, v in new_kwargs.items():
|
|
110
|
+
if isinstance(v, torch.Tensor):
|
|
111
|
+
updated_kwargs[k] = v
|
|
112
|
+
continue
|
|
113
|
+
if isinstance(v, list):
|
|
114
|
+
# cache?
|
|
115
|
+
updated_kwargs[k] = _process_cache(k, v)
|
|
116
|
+
if type(updated_kwargs[k]) is not type(v):
|
|
117
|
+
# A cache was introduced.
|
|
118
|
+
if verbose:
|
|
119
|
+
print(
|
|
120
|
+
f"[convert_dynamic_axes_into_dynamic_shapes] parameter "
|
|
121
|
+
f"{k!r} was changed into {type(updated_kwargs[k])}"
|
|
122
|
+
)
|
|
123
|
+
changes[k] = type(updated_kwargs[k])
|
|
124
|
+
continue
|
|
125
|
+
if isinstance(v, transformers.cache_utils.DynamicCache):
|
|
126
|
+
ca = CacheKeyValue(v)
|
|
127
|
+
updated_kwargs[k] = [ca.key_cache, ca.value_cache]
|
|
128
|
+
changes[k] = type(v)
|
|
129
|
+
continue
|
|
130
|
+
raise NotImplementedError(
|
|
131
|
+
f"Unexpected type {type(v)} for parameter {k!r} "
|
|
132
|
+
f"({string_type(v, with_shape=True)})"
|
|
133
|
+
)
|
|
134
|
+
|
|
135
|
+
# process dynamic axes
|
|
136
|
+
if changes:
|
|
137
|
+
dynamic_shapes = {}
|
|
138
|
+
done = set()
|
|
139
|
+
for k, v in dynamic_axes.items():
|
|
140
|
+
if k not in changes and k in updated_kwargs and isinstance(v, dict):
|
|
141
|
+
dynamic_shapes[k] = v
|
|
142
|
+
continue
|
|
143
|
+
if (
|
|
144
|
+
k in updated_kwargs
|
|
145
|
+
and k in changes
|
|
146
|
+
and changes[k] == transformers.cache_utils.DynamicCache
|
|
147
|
+
):
|
|
148
|
+
dynamic_shapes[k] = v
|
|
149
|
+
continue
|
|
150
|
+
if "." in k:
|
|
151
|
+
# something like present.0.key
|
|
152
|
+
prefix = k.split(".")[0]
|
|
153
|
+
if prefix in done:
|
|
154
|
+
continue
|
|
155
|
+
args_prefix = (
|
|
156
|
+
prefix_mapping[prefix]
|
|
157
|
+
if prefix_mapping and prefix in prefix_mapping
|
|
158
|
+
else prefix
|
|
159
|
+
)
|
|
160
|
+
if args_prefix in updated_kwargs and args_prefix in changes:
|
|
161
|
+
# A cache.
|
|
162
|
+
cls = changes[args_prefix]
|
|
163
|
+
dynamic_shapes[args_prefix] = _make_shape(
|
|
164
|
+
{
|
|
165
|
+
_: __
|
|
166
|
+
for _, __ in dynamic_axes.items()
|
|
167
|
+
if _.startswith(f"{prefix}.")
|
|
168
|
+
},
|
|
169
|
+
cls,
|
|
170
|
+
updated_kwargs[args_prefix],
|
|
171
|
+
)
|
|
172
|
+
done.add(prefix)
|
|
173
|
+
continue
|
|
174
|
+
if k not in updated_kwargs:
|
|
175
|
+
# dynamic axes not in the given inputs, should be raise an exception?
|
|
176
|
+
if verbose:
|
|
177
|
+
print(
|
|
178
|
+
f"[convert_dynamic_axes_into_dynamic_shapes] dropping axes "
|
|
179
|
+
f"{k!r}-{v!r}, not found in {set(updated_kwargs)}"
|
|
180
|
+
)
|
|
181
|
+
continue
|
|
182
|
+
raise NotImplementedError(
|
|
183
|
+
f"Unable to process dynamic axes {k!r}, axes={v}, "
|
|
184
|
+
f"value={string_type(updated_kwargs[k], with_shape=True)}, "
|
|
185
|
+
f"dynamic axes={dynamic_axes}, "
|
|
186
|
+
f"updated_kwargs={string_type(updated_kwargs, with_shape=True)}"
|
|
187
|
+
)
|
|
188
|
+
|
|
189
|
+
return (), updated_kwargs, dynamic_shapes
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def use_dyn_not_str(dynamic_shapes: Any, default_value=None) -> Any:
|
|
193
|
+
"""
|
|
194
|
+
Some functions returns dynamic shapes as string.
|
|
195
|
+
This functions replaces them with ``torch.export.Dim.DYNAMIC``.
|
|
196
|
+
``default_value=torch.export.Dim.AUTO`` changes the default value.
|
|
197
|
+
"""
|
|
198
|
+
if isinstance(dynamic_shapes, list):
|
|
199
|
+
return [use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes]
|
|
200
|
+
if isinstance(dynamic_shapes, tuple):
|
|
201
|
+
return tuple(use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes)
|
|
202
|
+
if isinstance(dynamic_shapes, dict):
|
|
203
|
+
return {
|
|
204
|
+
k: use_dyn_not_str(v, default_value=default_value)
|
|
205
|
+
for k, v in dynamic_shapes.items()
|
|
206
|
+
}
|
|
207
|
+
if isinstance(dynamic_shapes, set):
|
|
208
|
+
return {use_dyn_not_str(a, default_value=default_value) for a in dynamic_shapes}
|
|
209
|
+
if isinstance(dynamic_shapes, str):
|
|
210
|
+
return torch.export.Dim.DYNAMIC if default_value is None else default_value
|
|
211
|
+
return dynamic_shapes
|