onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
+++ onnx_diagnostic/helpers/config_helper.py
@@ -0,0 +1,170 @@
+import functools
+import importlib
+import inspect
+import os
+import re
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+import transformers
+
+
+def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
+    """
+    Checks that the configuration has all the attributes in ``args``.
+    Raises an exception otherwise.
+    """
+    for a in args:
+        assert isinstance(a, (str, tuple)), f"unexpected type {type(a)} in {args!r}"
+        if isinstance(a, str):
+            assert (isinstance(config, dict) and a in config) or hasattr(
+                config, a
+            ), f"Missing attribute {a!r} in\n{config}"
+        elif isinstance(a, tuple):
+            assert any(
+                (isinstance(name, str) and hasattr(config, name))
+                or all(hasattr(config, _) for _ in name)
+                for name in a
+            ), f"All attributes in {a!r} are missing from\n{config}"
+
+
+def update_config(config: Any, mkwargs: Dict[str, Any]):
+    """Updates a configuration with different values."""
+    for k, v in mkwargs.items():
+        if k == "attn_implementation":
+            config._attn_implementation = v
+            if getattr(config, "_attn_implementation_autoset", False):
+                config._attn_implementation_autoset = False
+            continue
+        if isinstance(v, dict):
+            if not hasattr(config, k) or getattr(config, k) is None:
+                setattr(config, k, v)
+                continue
+            existing = getattr(config, k)
+            if type(existing) is dict:
+                existing.update(v)
+            else:
+                update_config(getattr(config, k), v)
+            continue
+        if type(config) is dict:
+            config[k] = v
+        else:
+            setattr(config, k, v)
+
+
+def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
+    """Returns the first value found in the configuration."""
+    if (
+        exceptions
+        and hasattr(config, "architectures")
+        and len(config.architectures) == 1
+        and config.architectures[0] in exceptions
+    ):
+        excs = exceptions[config.architectures[0]]
+        return excs(config)
+    for a in atts:
+        if isinstance(a, str):
+            if hasattr(config, a):
+                return getattr(config, a)
+        elif isinstance(a, tuple):
+            if all(hasattr(config, _) for _ in a[1:]):
+                return a[0]([getattr(config, _) for _ in a[1:]])
+    raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
+
+
+def pick(config, name: str, default_value: Any) -> Any:
+    """
+    Returns the value of an attribute if config has it,
+    otherwise the default value.
+    """
+    if not config:
+        return default_value
+    if type(config) is dict:
+        return config.get(name, default_value)
+    return getattr(config, name, default_value)
+
+
+@functools.cache
+def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
+    """
+    Retrieves the configuration class for a given architecture.
+
+    :param arch: architecture (class name)
+    :param exc: raise an exception if not found
+    :return: type
+    """
+    cls = getattr(transformers, arch)
+    mod_name = cls.__module__
+    mod = importlib.import_module(mod_name)
+    source = inspect.getsource(mod)
+    # [^O] avoids capturing Optional[Something]
+    reg = re.compile("config: ([^O][A-Za-z0-9]+)")
+    fall = reg.findall(source)
+    if len(fall) == 0:
+        assert not exc, (
+            f"Unable to guess Configuration class name for arch={arch!r}, "
+            f"module={mod_name!r}, no candidate, source is\n{source}"
+        )
+        return None
+    unique = set(fall)
+    assert len(unique) == 1, (
+        f"Unable to guess Configuration class name for arch={arch!r}, "
+        f"module={mod_name!r}, found={unique} (#{len(unique)}), "
+        f"source is\n{source}"
+    )
+    cls_name = unique.pop()
+    return getattr(transformers, cls_name)
+
+
+def default_num_hidden_layers():
+    """
+    Returns the default number of layers.
+    It is lower when the unit tests are running
+    (``UNITTEST_GOING=1``) or when the GPU capability is below 9.
+    """
+    import torch
+
+    if torch.cuda.is_available():
+        capa = torch.cuda.get_device_capability(0)
+        if capa[0] < 9:
+            return 2
+    return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
+
+
+def build_diff_config(config0, config1):
+    """
+    Returns all the modified values between two configurations.
+    """
+    import torch
+
+    diff = {}
+    for k in config0:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config0}"
+        if k not in config1:
+            v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+            diff[k] = f"-{v0}"
+    for k in config1:
+        assert isinstance(k, str), f"k={k!r}, wrong type in {config1}"
+        if k not in config0:
+            v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+            diff[k] = f"+{v1}"
+    for k in config0:
+        if k not in config1:
+            continue
+        v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+        v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+        if (
+            v0 is None
+            or v1 is None
+            or isinstance(v1, (float, int, bool, str, list, tuple, torch.dtype))
+            or (
+                isinstance(v0, dict)
+                and isinstance(v1, dict)
+                and all(isinstance(k, int) for k in v1)
+            )
+        ):
+            if v1 != v0:
+                diff[k] = f"{v0} -> {v1}"
+        else:
+            d = build_diff_config(v0, v1)
+            if d:
+                diff[k] = d
+    return diff
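
A small usage sketch of these helpers follows (hedged: the configuration object and values are illustrative, and the exact keys returned by ``to_dict()`` depend on the installed transformers version):

    import transformers
    from onnx_diagnostic.helpers.config_helper import (
        build_diff_config,
        check_hasattr,
        pick,
        update_config,
    )

    # PretrainedConfig stores arbitrary keyword arguments as attributes.
    config = transformers.PretrainedConfig(num_hidden_layers=12, hidden_size=768)

    # Raises an AssertionError if an attribute is missing.
    check_hasattr(config, "num_hidden_layers", "hidden_size")

    # Reads an attribute with a fallback; works on dicts and config objects.
    assert pick(config, "hidden_size", 0) == 768
    assert pick({"hidden_size": 768}, "num_attention_heads", 8) == 8

    # Shrink the model, e.g. before instantiating an untrained copy for
    # tests, then report what changed.
    before = config.to_dict()
    update_config(config, {"num_hidden_layers": 2})
    print(build_diff_config(before, config.to_dict()))
    # expected to contain {'num_hidden_layers': '12 -> 2'}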
+++ onnx_diagnostic/helpers/doc_helper.py
@@ -0,0 +1,163 @@
+import os
+from typing import Dict, List, Optional, Tuple
+import onnx
+import onnx.helper as oh
+import torch
+from ..reference.torch_ops import OpRunKernel, OpRunTensor
+from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
+from .ort_session import InferenceSessionForTorch
+
+_SAVED: List[str] = []
+_SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
+
+
+def _get_model_name(op_name: str, provider: str) -> Optional[str]:
+    if _SAVE_OPTIMIZED_MODEL_:
+        name = f"dump_doc_{op_name}_{provider}_{len(_SAVED)}.onnx"
+        _SAVED.append(name)
+        return name
+    return None
+
+
+class LayerNormalizationOrt(OpRunKernel):
+    "LayerNormalization with onnxruntime"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        "Needs device."
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version=None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.axis = self.get_attribute_int(node, "axis", -1)
+        self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+        self.device = device
+        self.stash_type = onnx_dtype_to_torch_dtype(
+            self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+        )
+        self.compute_std = len(node.output) > 1
+        assert not self.compute_std, (
+            f"This kernel implementation only works when a single output "
+            f"is required but outputs {node.output} were requested."
+        )
+        self._cache: Dict[Tuple[int, int, bool], InferenceSessionForTorch] = {}
+        self.is_cpu = torch.device("cpu") == self.device
+
+    def _make_model(self, itype: int, rank: int, has_bias: bool) -> InferenceSessionForTorch:
+        shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
+        layer_model = oh.make_model(
+            oh.make_graph(
+                [
+                    oh.make_node(
+                        "LayerNormalization",
+                        ["X", "W", "B"] if has_bias else ["X", "W"],
+                        ["Z"],
+                        axis=self.axis,
+                        epsilon=self.epsilon,
+                    )
+                ],
+                "dummy",
+                (
+                    [
+                        oh.make_tensor_value_info("X", itype, shape),
+                        oh.make_tensor_value_info("W", itype, ["last"]),
+                        oh.make_tensor_value_info("B", itype, ["last"]),
+                    ]
+                    if has_bias
+                    else [
+                        oh.make_tensor_value_info("X", itype, shape),
+                        oh.make_tensor_value_info("W", itype, ["last"]),
+                    ]
+                ),
+                [oh.make_tensor_value_info("Z", itype, shape)],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+        self._provider = provider
+        return InferenceSessionForTorch(
+            layer_model,
+            optimized_model_filepath=_get_model_name("layer_norm", provider),
+            providers=[provider],
+        )
+
+    def run(self, x, scale, bias=None):
+        itype = torch_dtype_to_onnx_dtype(x.dtype)
+        rank = len(x.shape)
+        key = itype, rank, bias is not None
+        if key not in self._cache:
+            self._cache[key] = self._make_model(itype, rank, bias is not None)
+        sess = self._cache[key]
+        if self.verbose:
+            print(f"[LayerNormalizationOrt] running on {self._provider!r}")
+        feeds = dict(X=x.tensor, W=scale.tensor)
+        if bias is not None:
+            feeds["B"] = bias.tensor
+        got = sess.run(None, feeds)[0]
+        return OpRunTensor(got)
+
+
+class MatMulOrt(OpRunKernel):
+    "MatMul with onnxruntime"
+
+    @classmethod
+    def device_dependent(cls) -> bool:
+        "Needs device."
+        return True
+
+    def __init__(
+        self,
+        node: onnx.NodeProto,
+        version=None,
+        device: Optional[torch.device] = None,
+        verbose: int = 0,
+    ):
+        super().__init__(node, version, verbose=verbose)
+        self.device = device
+        self._cache: Dict[Tuple[int, int, int], InferenceSessionForTorch] = {}
+        self.is_cpu = torch.device("cpu") == self.device
+
+    def _make_model(self, itype: int, ranka: int, rankb: int) -> InferenceSessionForTorch:
+        shapea = [f"a{i}" for i in range(ranka)]
+        shapeb = [f"b{i}" for i in range(rankb)]
+        shapec = [f"c{i}" for i in range(max(ranka, rankb))]
+        model = oh.make_model(
+            oh.make_graph(
+                [oh.make_node("MatMul", ["A", "B"], ["C"])],
+                "dummy",
+                [
+                    oh.make_tensor_value_info("A", itype, shapea),
+                    oh.make_tensor_value_info("B", itype, shapeb),
+                ],
+                [oh.make_tensor_value_info("C", itype, shapec)],
+            ),
+            ir_version=9,
+            opset_imports=[oh.make_opsetid("", 18)],
+        )
+        provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+        self._provider = provider
+        return InferenceSessionForTorch(
+            model,
+            optimized_model_filepath=_get_model_name("matmul", provider),
+            providers=[provider],
+        )
+
+    def run(self, a, b):
+        itype = torch_dtype_to_onnx_dtype(a.dtype)
+        ranka, rankb = len(a.shape), len(b.shape)
+        key = itype, ranka, rankb
+        if key not in self._cache:
+            self._cache[key] = self._make_model(itype, ranka, rankb)
+        sess = self._cache[key]
+        if self.verbose:
+            print(f"[MatMulOrt] running on {self._provider!r}")
+        feeds = dict(A=a.tensor, B=b.tensor)
+        got = sess.run(None, feeds)[0]
+        return OpRunTensor(got)
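
A hedged sketch of driving these kernels directly (it assumes onnxruntime is installed and that ``OpRunTensor`` wraps a plain ``torch.Tensor``, as the ``run`` methods above suggest):

    import onnx.helper as oh
    import torch
    from onnx_diagnostic.helpers.doc_helper import LayerNormalizationOrt, MatMulOrt
    from onnx_diagnostic.reference.torch_ops import OpRunTensor

    # LayerNormalization with weight and bias on CPU.
    node = oh.make_node("LayerNormalization", ["X", "W", "B"], ["Z"], axis=-1)
    layer_norm = LayerNormalizationOrt(node, device=torch.device("cpu"), verbose=1)
    z = layer_norm.run(
        OpRunTensor(torch.rand(2, 4, 8)),
        OpRunTensor(torch.ones(8)),
        OpRunTensor(torch.zeros(8)),
    )
    print(z.tensor.shape)  # torch.Size([2, 4, 8])

    # MatMul: the first call builds and caches one ONNX model for the
    # (dtype, rank of A, rank of B) key; later calls reuse the session.
    node = oh.make_node("MatMul", ["A", "B"], ["C"])
    matmul = MatMulOrt(node, device=torch.device("cpu"))
    c = matmul.run(OpRunTensor(torch.rand(2, 3)), OpRunTensor(torch.rand(3, 4)))
    print(c.tensor.shape)  # torch.Size([2, 4])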
+++ onnx_diagnostic/helpers/fake_tensor_helper.py
@@ -0,0 +1,273 @@
+from typing import Any, Dict, Optional, Set, Tuple
+
+
+class FakeTensorContext:
+    """Stores information used to reuse the same dimension for the same dimension names."""
+
+    def __init__(self, fake_mode: Optional["FakeTensorMode"] = None):  # noqa: F821
+        if fake_mode is None:
+            from torch.fx.experimental.symbolic_shapes import ShapeEnv
+            from torch._subclasses.fake_tensor import FakeTensorMode
+
+            shape_env = ShapeEnv()
+            self.fake_mode = FakeTensorMode(shape_env=shape_env)
+        else:
+            self.fake_mode = fake_mode
+        self._candidates = self._first_primes()
+        self._unique_: Set[int] = set()
+        self._mapping_int: Dict[int, Any] = {}
+        self._mapping_str: Dict[str, int] = {}
+
+    @classmethod
+    def _first_primes(cls, n=1000):
+        sieve = [True] * (n + 1)
+        sieve[0:2] = [False, False]
+
+        for i in range(2, int(n**0.5) + 1):
+            if sieve[i]:
+                # eliminates the multiples of i
+                sieve[i * i : n + 1 : i] = [False] * len(range(i * i, n + 1, i))
+
+        return [i for i, prime in enumerate(sieve) if prime and i >= 13]
+
+    def _unique(self) -> int:
+        i = 0
+        c = self._candidates[i]
+        while c in self._unique_ or c in self._mapping_int:
+            i += 1
+            assert i < len(
+                self._candidates
+            ), f"Too many unique dimensions to generate, requested: {len(self._unique_)}"
+            c = self._candidates[i]
+        self._unique_.add(c)
+        return c
+
+    def from_tensor(self, x, static_shapes=False) -> "FakeTensor":  # noqa: F821
+        """
+        Returns a fake tensor.
+        ``pytorch`` returns the same name for the same dimension.
+        """
+        fake = self.fake_mode.from_tensor(x, static_shapes=static_shapes)
+        for i, s in zip(x.shape, fake.shape):
+            assert i not in self._mapping_int or self._mapping_int[i] == s, (
+                f"Inconsistency between {x.shape} and {fake.shape}, "
+                f"mapping has {self._mapping_int[i]} and s={s}"
+            )
+            self._mapping_int[i] = s
+        return fake
+
+    def fake_reshape(
+        self,
+        true_tensor: "torch.Tensor",  # noqa: F821
+        sh: Dict[int, Any],  # noqa: F821
+        fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
+    ) -> "FakeTensor":  # noqa: F821
+        """
+        Changes the shape of a true tensor to make it dynamic.
+
+        :param true_tensor: true tensor
+        :param sh: dynamic shape
+        :param fake_tensor: fake tensor, if None, make a fake one
+        :return: fake tensor
+        """
+        import torch
+
+        # deal with 0/1
+        for i in sh:
+            if true_tensor.shape[i] <= 1:
+                expanded_shape = list(true_tensor.shape)
+                expanded_shape[i] = self._unique()
+                true_tensor = torch.empty(
+                    tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+                )
+
+        # deal with equivalent dimensions
+        new_shape = list(true_tensor.shape)
+        mapping = {}
+        for i, s in sh.items():
+            d = true_tensor.shape[i]
+            if d not in mapping:
+                mapping[d] = s
+            elif mapping[d] != s:
+                d = self._unique()
+                mapping[d] = s
+            new_shape[i] = d
+        true_tensor = torch.empty(
+            tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+        )
+
+        # now switch to FakeTensor
+        fake_tensor = self.from_tensor(true_tensor, static_shapes=False) if fake_tensor is None else fake_tensor
+        new_shape = list(true_tensor.shape)
+        for i in sh:
+            new_shape[i] = fake_tensor.shape[i]
+
+        reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
+            axis=tuple(sorted(sh)), keepdim=True
+        )
+        return reduced_tensor.expand(*new_shape)
+
+    def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
+        """See :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`."""
+        if x is None:
+            return None
+        if isinstance(x, (list, tuple)):
+            return x.__class__([self.make_fake(i) for i in x])
+        if isinstance(x, dict):
+            return {k: self.make_fake(v) for k, v in x.items()}
+        if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+            assert hasattr(x, "layers"), (
+                f"Use a more recent version of transformers (>=4.55), "
+                f"'layers' not found in class {type(x)}"
+            )
+            for layer in x.layers:
+                assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                    f"Use a more recent version of transformers (>=4.55), 'keys' or "
+                    f"'values' not found in class {type(layer)} ({dir(layer)})"
+                )
+                layer.keys = self.make_fake(layer.keys)
+                layer.values = self.make_fake(layer.values)
+            return x
+        if x.__class__.__name__ == "EncoderDecoderCache":
+            self.make_fake(x.self_attention_cache)
+            self.make_fake(x.cross_attention_cache)
+            return x
+        if hasattr(x, "shape"):
+            return self.from_tensor(x, static_shapes=False)
+        from . import string_type
+
+        raise TypeError(
+            f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+        )
+
+    def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
+        """
+        See
+        :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
+        """
+        if x is None:
+            return None
+        if isinstance(x, (list, tuple)):
+            return x.__class__(
+                [
+                    self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
+                    for i, ds in zip(x, dynamic_shapes)
+                ]
+            )
+        if isinstance(x, dict):
+            return {
+                k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
+                for k, v in x.items()
+            }
+        if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+            assert hasattr(x, "layers"), (
+                f"Use a more recent version of transformers (>=4.55), "
+                f"'layers' not found in class {type(x)}"
+            )
+            assert isinstance(dynamic_shapes, list) and (
+                not dynamic_shapes or not isinstance(dynamic_shapes[0], list)
+            ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
+            for il, layer in enumerate(x.layers):
+                assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                    f"Use a more recent version of transformers (>=4.55), 'keys' or "
+                    f"'values' not found in class {type(layer)} ({dir(layer)})"
+                )
+                layer.keys = self.make_fake_with_dynamic_dimensions(
+                    layer.keys, dynamic_shapes=dynamic_shapes[il * 2]
+                )
+                layer.values = self.make_fake_with_dynamic_dimensions(
+                    layer.values, dynamic_shapes=dynamic_shapes[il * 2 + 1]
+                )
+            return x
+        if x.__class__.__name__ == "EncoderDecoderCache":
+            self.make_fake_with_dynamic_dimensions(
+                x.self_attention_cache, dynamic_shapes=dynamic_shapes[0]
+            )
+            self.make_fake_with_dynamic_dimensions(
+                x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
+            )
+            return x
+        if hasattr(x, "shape"):
+            assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
+                f"dynamic_shapes must be a dictionary at this stage but "
+                f"dynamic_shapes={dynamic_shapes}"
+            )
+            # We need to overwrite the values.
+            new_shape = []
+            for idim, dim in enumerate(x.shape):
+                if dynamic_shapes is not None and idim in dynamic_shapes:
+                    s = dynamic_shapes[idim]
+                    assert isinstance(s, str), (
+                        f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
+                        f"at index {idim}"
+                    )
+                    if s in self._mapping_str:
+                        dim = self._mapping_str[s]
+                    else:
+                        i = self._unique()
+                        self._mapping_str[s] = i
+                        dim = i
+                assert isinstance(dim, int), (
+                    f"Unexpected type {type(dim)}, dynamic_shapes={dynamic_shapes} "
+                    f"at index {idim}, dim={dim}"
+                )
+                new_shape.append(dim)
+            if tuple(new_shape) != x.shape:
+                import torch
+
+                x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)
+
+            t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+            assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
+            assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
+            return t
+        from ..helpers import string_type
+
+        raise TypeError(
+            f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+        )
+
+
+def make_fake(
+    x: Any, context: Optional[FakeTensorContext] = None
+) -> Tuple[Optional["FakeTensor"], Optional[FakeTensorContext]]:  # noqa: F821
+    """
+    Replaces all tensors by fake tensors.
+    This modification happens inplace for caches.
+    This function is only implemented for caches with
+    ``transformers>=4.55``.
+
+    .. runpython::
+        :showcode:
+
+        import pprint
+        import torch
+        from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+        from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
+
+        inputs, _ = make_fake(
+            dict(
+                input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                past_key_values=make_dynamic_cache(
+                    [
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                        (
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                            torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                        ),
+                    ]
+                ),
+            )
+        )
+        pprint.pprint(inputs)
+    """
+    if x is None:
+        return None, None
+    if context is None:
+        context = FakeTensorContext()
+    return context.make_fake(x), context
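
A hedged sketch of the dimension-naming mechanism (the exact symbolic representation depends on the installed torch version): dimensions that receive the same name in ``dynamic_shapes`` are replaced by the same prime placeholder (>= 13), so the resulting fake tensors should share one symbolic dimension.

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import FakeTensorContext

    ctx = FakeTensorContext()
    x = torch.rand(4, 8)
    y = torch.rand(4, 3)
    # Both leading dimensions are named "batch": both fake tensors are built
    # from the same placeholder size and should share its symbolic dimension.
    fx = ctx.make_fake_with_dynamic_dimensions(x, dynamic_shapes={0: "batch"})
    fy = ctx.make_fake_with_dynamic_dimensions(y, dynamic_shapes={0: "batch"})
    print(fx.shape, fy.shape)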