onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,987 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import ctypes
|
|
3
|
+
import inspect
|
|
4
|
+
import os
|
|
5
|
+
import sys
|
|
6
|
+
import warnings
|
|
7
|
+
from collections.abc import Iterable
|
|
8
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
9
|
+
import numpy as np
|
|
10
|
+
import onnx
|
|
11
|
+
from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
|
|
12
|
+
import torch
|
|
13
|
+
from .helper import string_type, size_type
|
|
14
|
+
from .cache_helper import (
|
|
15
|
+
make_dynamic_cache,
|
|
16
|
+
make_encoder_decoder_cache,
|
|
17
|
+
make_hybrid_cache,
|
|
18
|
+
make_sliding_window_cache,
|
|
19
|
+
make_mamba_cache,
|
|
20
|
+
make_static_cache,
|
|
21
|
+
CacheKeyValue,
|
|
22
|
+
)
|
|
23
|
+
from .mini_onnx_builder import create_onnx_model_from_input_tensors
|
|
24
|
+
from .onnx_helper import (
|
|
25
|
+
to_array_extended,
|
|
26
|
+
tensor_dtype_to_np_dtype,
|
|
27
|
+
_STORAGE_TYPE,
|
|
28
|
+
onnx_dtype_name,
|
|
29
|
+
)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def proto_from_tensor(
    arr: "torch.Tensor",  # noqa: F821
    name: Optional[str] = None,
    verbose: int = 0,
) -> onnx.TensorProto:
    """
    Converts a torch Tensor into a TensorProto.

    The data is stored in ``raw_data`` (little endian, as required by onnx).

    :param arr: tensor, sparse tensors are not supported
    :param name: name given to the TensorProto, left empty if None or empty
    :param verbose: display the type and shape (if > 1 and the tensor
        has more than 100 elements)
    :return: a TensorProto
    :raises TypeError: if *arr* is not a :class:`torch.Tensor`
    :raises NotImplementedError: if *arr* is sparse
    """
    if not isinstance(arr, torch.Tensor):
        raise TypeError(f"Unexpected type {type(arr)}.")
    if arr.is_sparse:
        raise NotImplementedError(
            f"Sparse tensor is not supported yet but initializer {name!r} is."
        )

    # arr.contiguous() is slow after a transpose, maybe there is a way to optimize this.
    arr_cpu = arr.cpu() if arr.is_contiguous() else arr.contiguous().cpu()

    numel = torch.numel(arr_cpu)
    element_size = arr_cpu.element_size()

    if arr_cpu.dtype in {torch.bfloat16}:
        # No numpy conversion for bfloat16, the bytes are read from the
        # tensor memory further below.
        np_arr = arr_cpu
    elif arr_cpu.data_ptr() == arr.data_ptr():
        # arr_cpu shares its memory with arr (arr was already a contiguous CPU
        # tensor), a copy is handed to numpy instead of the caller's tensor.
        copy = arr_cpu.clone().detach().requires_grad_(False)
        assert (
            arr_cpu.data_ptr() == 0 or arr_cpu.data_ptr() != copy.data_ptr()
        ), f"Pointers are not null and different {arr_cpu.data_ptr()} != {copy.data_ptr()}"
        np_arr = np.from_dlpack(copy)
    else:
        np_arr = np.from_dlpack(arr_cpu.detach())

    tensor = onnx.TensorProto()
    tensor.dims.extend(arr_cpu.shape)
    if name:
        tensor.name = name
    itype = torch_dtype_to_onnx_dtype(arr_cpu.dtype)
    assert not hasattr(onnx.TensorProto, "INT4") or itype not in {
        onnx.TensorProto.INT4,
        onnx.TensorProto.UINT4,
    }, f"Type {arr.dtype} is not supported yet for name={name!r}"
    tensor.data_type = itype

    if verbose > 1 and numel > 100:
        print(f"[proto_from_array] {tensor.data_type}[{arr_cpu.shape}]")

    if isinstance(np_arr, torch.Tensor):
        if numel == 0:
            # Nothing to read, the data pointer of an empty tensor may be null.
            tensor.raw_data = b""
        else:
            byte_data = (ctypes.c_ubyte * (numel * element_size)).from_address(
                np_arr.data_ptr()
            )
            tensor.raw_data = bytes(byte_data)
        if sys.byteorder == "big":
            np_dtype = _STORAGE_TYPE[tensor.data_type]  # type: ignore
            # raw_data must be little endian: swap a copy and store it back,
            # np.frombuffer returns a read-only view on immutable bytes.
            tensor.raw_data = (
                np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
            )
    else:
        tensor.raw_data = np_arr.tobytes()
        if sys.byteorder == "big":
            np_dtype = tensor_dtype_to_np_dtype(tensor.data_type)
            tensor.raw_data = (
                np.frombuffer(tensor.raw_data, dtype=np_dtype).byteswap().tobytes()
            )
    return tensor
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def onnx_dtype_to_torch_dtype(itype: int) -> "torch.dtype":  # noqa: F821
    """
    Returns the torch dtype matching an onnx element type.

    :param itype: onnx element type, one of the ``onnx.TensorProto``
        constants (``onnx.TensorProto.FLOAT``, ``onnx.TensorProto.INT64``, ...)
    :return: the corresponding torch dtype
    :raises NotImplementedError: if the element type is not in the table below
    """
    # Torch dtypes are referenced by name and only looked up once a match is
    # found, so a dtype unknown to the installed torch is only resolved when
    # it is actually requested.
    candidates = (
        (onnx.TensorProto.FLOAT, "float32"),
        (onnx.TensorProto.FLOAT16, "float16"),
        (onnx.TensorProto.BFLOAT16, "bfloat16"),
        (onnx.TensorProto.DOUBLE, "float64"),
        (onnx.TensorProto.INT32, "int32"),
        (onnx.TensorProto.INT64, "int64"),
        (onnx.TensorProto.UINT32, "uint32"),
        (onnx.TensorProto.UINT64, "uint64"),
        (onnx.TensorProto.BOOL, "bool"),
        (onnx.TensorProto.INT16, "int16"),
        (onnx.TensorProto.UINT16, "uint16"),
        (onnx.TensorProto.INT8, "int8"),
        (onnx.TensorProto.UINT8, "uint8"),
        (onnx.TensorProto.COMPLEX64, "complex64"),
        (onnx.TensorProto.COMPLEX128, "complex128"),
    )
    for onnx_type, torch_name in candidates:
        if itype == onnx_type:
            return getattr(torch, torch_name)
    raise NotImplementedError(
        f"Unable to convert onnx type {onnx_dtype_name(itype)} to torch.type."
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def torch_dtype_to_onnx_dtype(to: "torch.dtype") -> int:  # noqa: F821
    """
    Returns the onnx element type matching a torch dtype.

    ``torch.SymInt`` and ``torch.SymFloat`` are accepted as well and mapped
    to ``INT64`` and ``FLOAT``.

    :param to: torch dtype
    :return: onnx element type (an ``onnx.TensorProto`` constant)
    :raises NotImplementedError: if the dtype is not in the table below
    """
    # (torch attribute name, onnx type), compared in this order; the torch
    # attribute is only fetched when its row is reached.
    table = (
        ("float32", onnx.TensorProto.FLOAT),
        ("float16", onnx.TensorProto.FLOAT16),
        ("bfloat16", onnx.TensorProto.BFLOAT16),
        ("float64", onnx.TensorProto.DOUBLE),
        ("int64", onnx.TensorProto.INT64),
        ("int32", onnx.TensorProto.INT32),
        ("uint64", onnx.TensorProto.UINT64),
        ("uint32", onnx.TensorProto.UINT32),
        ("bool", onnx.TensorProto.BOOL),
        ("SymInt", onnx.TensorProto.INT64),
        ("int16", onnx.TensorProto.INT16),
        ("uint16", onnx.TensorProto.UINT16),
        ("int8", onnx.TensorProto.INT8),
        ("uint8", onnx.TensorProto.UINT8),
        ("SymFloat", onnx.TensorProto.FLOAT),
        ("complex64", onnx.TensorProto.COMPLEX64),
        ("complex128", onnx.TensorProto.COMPLEX128),
    )
    found = next(
        (onnx_type for torch_name, onnx_type in table if to == getattr(torch, torch_name)),
        None,
    )
    if found is None:
        raise NotImplementedError(f"Unable to convert torch dtype {to!r} to onnx dtype.")
    return found
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def _forward_(
    *args,
    _f=None,
    _fprint=string_type,
    _prefix="",
    _context=None,
    _storage=None,
    _storage_limit=2**27,
    _verbose=0,
    **kwargs,
):
    """
    Replacement forward installed by :func:`steal_forward`.

    It prints the inputs and outputs of the original forward ``_f``,
    optionally stores copies of them in ``_storage`` and increments
    ``_context["iteration"]``.

    :param args: positional arguments of the forward method
    :param _f: original forward method
    :param _fprint: function used to format inputs and outputs, it receives
        ``with_shape`` and ``with_min_max`` taken from ``_context``
    :param _prefix: module name, its leading spaces define the indentation
    :param _context: dictionary shared by all calls of the same module,
        it must contain ``class_name`` and ``iteration``
    :param _storage: if not None, a deep copy of the inputs is stored under
        ``(prefix, iteration, "I")`` and of the outputs under
        ``(prefix, iteration, "O")``
    :param _storage_limit: outputs whose size (``torch_tensor_size``) is not
        below this limit are not stored
    :param _verbose: prints what is stored or skipped if not 0
    :param kwargs: named arguments of the forward method
    :return: the result of ``_f(*args, **kwargs)``
    """
    assert _f is not None, "_f cannot be None"
    assert _context is not None, "_context cannot be None"

    def _can_print() -> bool:
        # torch.compiler.is_exporting requires torch>=2.7
        return not hasattr(torch.compiler, "is_exporting") or not torch.compiler.is_exporting()

    short_name = _prefix.lstrip()
    indent = " " * (len(_prefix) - len(short_name))
    print(
        f"{indent}+{short_name} -- stolen forward for class {_context['class_name']} "
        f"-- iteration {_context['iteration']}"
    )
    fmt_options = dict(
        with_shape=_context.get("with_shape", False),
        with_min_max=_context.get("with_min_max", False),
    )
    if _can_print():
        inputs_str = _fprint(args, **fmt_options)
        kwargs_str = _fprint(kwargs, **fmt_options)
        print(f"{indent} <- args={inputs_str} --- kwargs={kwargs_str}")

    storage_key = None
    if _storage is not None:
        storage_key = (short_name, _context["iteration"])
        _storage[(*storage_key, "I")] = (torch_deepcopy(args), torch_deepcopy(kwargs))

    outputs = _f(*args, **kwargs)

    if _can_print():
        print(f"{indent} -> {_fprint(outputs, **fmt_options)}")
        print(f"{indent}-{short_name}.")

    if _storage is not None:
        nbytes = torch_tensor_size(outputs)
        small_enough = nbytes < _storage_limit
        if _verbose and small_enough:
            print(
                f"-- stores key={storage_key}, size {nbytes // 2**10}Kb -- "
                f"{string_type(outputs, with_shape=True)}"
            )
        elif _verbose:
            print(
                f"-- skips key={storage_key}, size {nbytes // 2**10}Kb -- "
                f"{string_type(outputs, with_shape=True)}"
            )
        if small_enough:
            _storage[(*storage_key, "O")] = torch_deepcopy(outputs)

    _context["iteration"] += 1
    return outputs
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
# Single-element list used as a mutable module-level flag,
# True while a steal_forward context is active (see is_stealing).
_steal_forward_status: List[bool] = [False]
# Objects registered by steal_append while steal_forward is active, indexed by name;
# steal_forward adds them to the dump when a dump_file is given.
_additional_stolen_objects: Dict[str, Any] = dict()
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def is_stealing() -> bool:
    """
    Tells whether a :func:`steal_forward` context is currently active.

    :return: True inside the context, False otherwise
    """
    active = _steal_forward_status[0]
    return active
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
def steal_append(name: str, obj: Any):
    """
    When outside a forward method, it is still possible to add
    a python object which contains tensors and dump it after the execution
    of the model.

    .. code-block:: python

        steal_append("quantize", [t1, t2])

    The same code can be executed multiple times; when *name* is already
    taken, the object is stored under ``name_1``, ``name_2``, ... (the first
    free one). Nothing happens if :func:`steal_forward` is not active.

    :param name: name used to store the object
    :param obj: object to store
    """
    if not is_stealing():
        return
    if name not in _additional_stolen_objects:
        print(f"-- stolen {name!r}: {string_type(obj, with_shape=True)}")
        _additional_stolen_objects[name] = obj
        return
    suffix = 1
    while f"{name}_{suffix}" in _additional_stolen_objects:
        suffix += 1
    new_name = f"{name}_{suffix}"
    print(f"-- stolen {name!r} renamed in {new_name!r}: {string_type(obj, with_shape=True)}")
    _additional_stolen_objects[new_name] = obj
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
@contextlib.contextmanager
def steal_forward(
    model: Union[
        Union[torch.nn.Module, Tuple[str, torch.nn.Module]],
        List[Union[torch.nn.Module, Tuple[str, torch.nn.Module]]],
    ],
    fprint: Callable = string_type,
    dump_file: Optional[str] = None,
    dump_drop: Optional[Set[str]] = None,
    submodules: bool = False,
    verbose: int = 0,
    storage_limit: int = 2**27,
    save_as_external_data: bool = True,
    **kwargs,
):
    """
    Temporarily replaces the forward method of one or several modules to print out
    their inputs and outputs using :func:`onnx_diagnostic.helpers.string_type`.
    See example :ref:`l-plot-tiny-llm-export` or
    :ref:`l-plot-intermediate-results`.

    :param model: a model or a list of models to monitor,
        every model can also be a tuple(name, model), the name is then displayed
        in front of every printed line
    :param fprint: function used to print out (or dump), by default, it is
        :func:`onnx_diagnostic.helpers.string_type`
    :param kwargs: additional parameters sent to :func:`onnx_diagnostic.helpers.string_type`
        or any other function defined by ``fprint``
    :param dump_file: dumps stolen inputs and outputs in an onnx model,
        they can be restored with :func:`create_input_tensors_from_onnx_model
        <onnx_diagnostic.helpers.mini_onnx_builder.create_input_tensors_from_onnx_model>`
    :param dump_drop: names of keyword inputs to drop before dumping
        (too big for example), only used if dump_file is specified
    :param save_as_external_data: True by default, but maybe better to have everything
        in a single file if possible
    :param submodules: if True and model is a module, the list is extended with all
        the submodules the module contains
    :param verbose: verbosity
    :param storage_limit: outputs whose size in bytes is equal or above this limit
        are not stored

    The original forward methods are restored when the context exits, even if an
    exception is raised, and the dump (if requested) happens at that moment.

    The following example shows how to steal and dump all the inputs / outputs
    for a module and its submodules, then restores them.

    .. runpython::
        :showcode:

        import torch
        from onnx_diagnostic.helpers.torch_helper import steal_forward
        from onnx_diagnostic.helpers.mini_onnx_builder import (
            create_input_tensors_from_onnx_model,
        )

        class SubModel(torch.nn.Module):
            def forward(self, x):
                return x * x

        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.s1 = SubModel()
                self.s2 = SubModel()

            def forward(self, x, y):
                return self.s1(x) + self.s2(y)

        inputs = torch.rand(2, 1), torch.rand(2, 1)
        model = Model()
        dump_file = "dump_steal_forward_submodules.onnx"
        with steal_forward(model, submodules=True, dump_file=dump_file):
            model(*inputs)

        # Let's restore the stolen data.
        restored = create_input_tensors_from_onnx_model(dump_file)
        for k, v in sorted(restored.items()):
            if isinstance(v, tuple):
                args, kwargs = v
                print("input", k, args, kwargs)
            else:
                print("output", k, v)

    Function :func:`steal_append` can be used to dump more tensors.
    When inside the context, :func:`is_stealing` returns True, False otherwise.
    """
    assert not is_stealing(), "steal_forward was already called."
    if (
        isinstance(model, tuple)
        and len(model) == 2
        and isinstance(model[0], str)
        and isinstance(model[1], torch.nn.Module)
    ):
        # A single (name, module) pair, as allowed by the signature.
        model = [model]
    assert not submodules or isinstance(
        model, torch.nn.Module
    ), f"submodules can only be True if model is a module but is is {type(model)}."
    context = dict(iteration=0, **kwargs)
    if "with_shape" not in context and fprint == string_type:
        context["with_shape"] = True
    if not isinstance(model, list):
        assert isinstance(model, torch.nn.Module), f"Unexpected type {type(model)} for model"
        if submodules:
            models = []
            for idx, m in model.named_modules():
                level = str(idx).split(".")
                ll = len(level)
                try:
                    _, start_line = inspect.getsourcelines(m.forward)
                except (OSError, TypeError):
                    # The code is not available (OSError) or the forward
                    # is not a python function (TypeError).
                    start_line = 0
                name = f"{idx}-{m.__class__.__name__}-{start_line}"
                models.append((f"{' ' * ll}{name}", m))
            model = models
        else:
            model = [model]

    # id(module) -> (module, had an instance-level forward, that forward)
    saved_forwards: Dict[int, Tuple[Any, bool, Any]] = {}
    storage: Optional[Dict[Any, Any]] = {} if dump_file else None

    def _restore() -> None:
        # Removes the stolen forward so that the class method is used again
        # (and later class-level patches stay visible), or puts back the
        # instance attribute if the module had one before.
        _steal_forward_status[0] = False
        for mod, had_own_forward, own_forward in saved_forwards.values():
            if had_own_forward:
                mod.forward = own_forward
            else:
                vars(mod).pop("forward", None)

    # The status is set only once every argument was validated, otherwise a failed
    # call would leave it on and make every following call fail.
    _steal_forward_status[0] = True
    # We clear the cache.
    _additional_stolen_objects.clear()
    try:
        for mt in model:
            name, m = mt if isinstance(mt, tuple) else ("", mt)
            if id(m) not in saved_forwards:
                # Only the first occurrence of a module holds its true original forward.
                own_attrs = vars(m)
                saved_forwards[id(m)] = (
                    m,
                    "forward" in own_attrs,
                    own_attrs.get("forward", None),
                )
            c = context.copy()
            c["class_name"] = m.__class__.__name__
            m.forward = lambda *args, _f=m.forward, _fp=fprint, _c=c, _p=name, _s=storage, _v=verbose, _sl=storage_limit, **kws: _forward_(  # noqa: E501
                *args,
                _f=_f,
                _fprint=_fp,
                _context=_c,
                _prefix=_p,
                _storage=_s,
                _verbose=_v,
                _storage_limit=_sl,
                **kws,
            )
    except BaseException:
        # Patching failed midway: undo what was done before propagating.
        _restore()
        raise

    try:
        yield
    finally:
        _restore()
        if dump_file:
            # Let's add the cached tensor
            assert storage is not None, "storage cannot be None but mypy is confused here."
            storage.update(_additional_stolen_objects)
            # We clear the cache.
            _additional_stolen_objects.clear()
            if verbose:
                size = torch_tensor_size(storage)
                print(f"-- gather stored {len(storage)} objects, size={size // 2 ** 20} Mb")
            if dump_drop:
                for key, value in storage.items():
                    # Only tuple keys (prefix, iteration, "I") hold (args, kwargs),
                    # keys added by steal_append are plain strings.
                    if not isinstance(key, tuple) or key[-1] != "I":
                        continue
                    _args, stored_kwargs = value
                    for dropped in set(stored_kwargs) & dump_drop:
                        print("---", dropped)
                        del stored_kwargs[dropped]
            proto = create_onnx_model_from_input_tensors(storage)
            if verbose:
                print("-- dumps stored objects")
            location = f"{os.path.split(dump_file)[-1]}.data"
            # onnx resolves ``location`` relative to the folder of dump_file,
            # a stale data file must be looked for there, not in the current folder.
            data_path = os.path.join(os.path.dirname(dump_file), location)
            if os.path.exists(data_path):
                os.remove(data_path)
            onnx.save(
                proto,
                dump_file,
                save_as_external_data=save_as_external_data,
                all_tensors_to_one_file=True,
                location=location,
            )
            if verbose:
                print("-- done dump stored objects")
|
|
443
|
+
|
|
444
|
+
|
|
445
|
+
@contextlib.contextmanager
|
|
446
|
+
def fake_torchdynamo_exporting():
|
|
447
|
+
"""
|
|
448
|
+
Sets ``torch.compiler._is_exporting_flag`` to True to trigger
|
|
449
|
+
pieces of code only enabled during export.
|
|
450
|
+
"""
|
|
451
|
+
memorize = torch.compiler._is_exporting_flag
|
|
452
|
+
torch.compiler._is_exporting_flag = True
|
|
453
|
+
try:
|
|
454
|
+
yield
|
|
455
|
+
finally:
|
|
456
|
+
torch.compiler._is_exporting_flag = memorize
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def is_torchdynamo_exporting() -> bool:
|
|
460
|
+
"""
|
|
461
|
+
Tells if :epkg:`torch` is exporting a model.
|
|
462
|
+
Relies on ``torch.compiler.is_exporting()``.
|
|
463
|
+
"""
|
|
464
|
+
import torch
|
|
465
|
+
|
|
466
|
+
if not hasattr(torch.compiler, "is_exporting"):
|
|
467
|
+
# torch.compiler.is_exporting requires torch>=2.7
|
|
468
|
+
return False
|
|
469
|
+
|
|
470
|
+
try:
|
|
471
|
+
return torch.compiler.is_exporting()
|
|
472
|
+
except Exception:
|
|
473
|
+
try:
|
|
474
|
+
import torch._dynamo as dynamo
|
|
475
|
+
|
|
476
|
+
return dynamo.is_exporting() # type: ignore
|
|
477
|
+
except Exception:
|
|
478
|
+
return False
|
|
479
|
+
|
|
480
|
+
|
|
481
|
+
def to_numpy(tensor: "torch.Tensor") -> np.ndarray:  # noqa: F821
    """
    Converts a :class:`torch.Tensor` to :class:`numpy.ndarray`.

    If torch cannot produce a numpy array for the dtype (it raises a TypeError),
    the tensor is cast to float32 by torch and then to the matching
    :epkg:`ml_dtypes` type. That covers bfloat16 and, when both torch and
    ml_dtypes define them, the float8 types (``float8_e4m3fn``,
    ``float8_e4m3fnuz``, ``float8_e5m2``, ``float8_e5m2fnuz``).

    :param tensor: tensor to convert
    :return: numpy array
    :raises AssertionError: if the dtype is neither supported by numpy
        nor by the fallback table
    """
    try:
        return tensor.detach().cpu().numpy()
    except TypeError:
        # We try with ml_dtypes
        pass

    import ml_dtypes

    conv = {torch.bfloat16: ml_dtypes.bfloat16}
    for name in ("float8_e4m3fn", "float8_e4m3fnuz", "float8_e5m2", "float8_e5m2fnuz"):
        # Both libraries must know the type, older versions may not.
        if hasattr(torch, name) and hasattr(ml_dtypes, name):
            conv[getattr(torch, name)] = getattr(ml_dtypes, name)
    assert tensor.dtype in conv, f"Unsupported type {tensor.dtype}, not in {conv}"
    # Every value of these types is exactly representable in float32,
    # so the round trip through float32 does not change the values.
    return tensor.detach().to(torch.float32).cpu().numpy().astype(conv[tensor.dtype])
|
|
494
|
+
|
|
495
|
+
|
|
496
|
+
def replace_string_by_dynamic(dynamic_shapes: Any) -> Any:
|
|
497
|
+
"""Replaces strings by ``torch.export.Dim.DYNAMIC``."""
|
|
498
|
+
import torch
|
|
499
|
+
|
|
500
|
+
if isinstance(dynamic_shapes, torch.export.dynamic_shapes._Dim):
|
|
501
|
+
return dynamic_shapes
|
|
502
|
+
if isinstance(dynamic_shapes, str):
|
|
503
|
+
return torch.export.Dim.DYNAMIC
|
|
504
|
+
if not dynamic_shapes:
|
|
505
|
+
return dynamic_shapes
|
|
506
|
+
if isinstance(dynamic_shapes, (tuple, list)):
|
|
507
|
+
return type(dynamic_shapes)(replace_string_by_dynamic(i) for i in dynamic_shapes)
|
|
508
|
+
if isinstance(dynamic_shapes, dict):
|
|
509
|
+
return {k: replace_string_by_dynamic(v) for k, v in dynamic_shapes.items()}
|
|
510
|
+
raise AssertionError(f"Unexpected type {type(dynamic_shapes)} for dynamic_shapes")
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def dummy_llm(
|
|
514
|
+
cls_name: Optional[str] = None,
|
|
515
|
+
dynamic_shapes: bool = False,
|
|
516
|
+
) -> Union[
|
|
517
|
+
Tuple[torch.nn.Module, Tuple[torch.Tensor, ...]],
|
|
518
|
+
Tuple[torch.nn.Module, Tuple[torch.Tensor, ...], Any],
|
|
519
|
+
]:
|
|
520
|
+
"""
|
|
521
|
+
Creates a dummy LLM for test purposes.
|
|
522
|
+
|
|
523
|
+
:param cls_name: None for whole model or a piece of it
|
|
524
|
+
:param dynamic_shapes: returns dynamic shapes as well
|
|
525
|
+
|
|
526
|
+
.. runpython::
|
|
527
|
+
:showcode:
|
|
528
|
+
|
|
529
|
+
from onnx_diagnostic.helpers.torch_helper import dummy_llm
|
|
530
|
+
print(dummy_llm())
|
|
531
|
+
"""
|
|
532
|
+
|
|
533
|
+
class Embedding(torch.nn.Module):
|
|
534
|
+
def __init__(self, vocab_size: int = 1024, embedding_dim: int = 16):
|
|
535
|
+
super().__init__()
|
|
536
|
+
self.embedding = torch.nn.Embedding(vocab_size, embedding_dim)
|
|
537
|
+
self.pe = torch.nn.Embedding(vocab_size, embedding_dim)
|
|
538
|
+
|
|
539
|
+
def forward(self, x):
|
|
540
|
+
word_emb = self.embedding(x)
|
|
541
|
+
word_pe = self.pe(x)
|
|
542
|
+
return word_emb + word_pe
|
|
543
|
+
|
|
544
|
+
class AttentionBlock(torch.nn.Module):
|
|
545
|
+
|
|
546
|
+
def __init__(self, embedding_dim: int = 16, context_size: int = 256):
|
|
547
|
+
super().__init__()
|
|
548
|
+
self.query = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
|
|
549
|
+
self.key = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
|
|
550
|
+
self.value = torch.nn.Linear(embedding_dim, embedding_dim, bias=False)
|
|
551
|
+
# torch.nn.Buffer are not fully handled by symbolic tracing
|
|
552
|
+
# Buffer(...)[:Prowy()] is not working
|
|
553
|
+
self.mask = torch.nn.Parameter(
|
|
554
|
+
torch.tril(
|
|
555
|
+
input=torch.ones(size=[context_size, context_size], dtype=torch.float)
|
|
556
|
+
)
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
def forward(self, x):
|
|
560
|
+
_B, T, C = x.shape
|
|
561
|
+
|
|
562
|
+
query = self.query(x)
|
|
563
|
+
key = self.key(x)
|
|
564
|
+
value = self.value(x)
|
|
565
|
+
|
|
566
|
+
qk = query @ key.transpose(-2, -1) * C**-0.5
|
|
567
|
+
attention = qk.masked_fill(self.mask[:T, :T] == 0, float("-inf"))
|
|
568
|
+
attention = torch.nn.functional.softmax(input=attention, dim=-1)
|
|
569
|
+
|
|
570
|
+
out = attention @ value
|
|
571
|
+
return out
|
|
572
|
+
|
|
573
|
+
class MultiAttentionBlock(torch.nn.Module):
|
|
574
|
+
|
|
575
|
+
def __init__(
|
|
576
|
+
self, embedding_dim: int = 16, num_heads: int = 2, context_size: int = 256
|
|
577
|
+
):
|
|
578
|
+
super().__init__()
|
|
579
|
+
self.attention = torch.nn.ModuleList(
|
|
580
|
+
modules=[AttentionBlock(embedding_dim, context_size) for _ in range(num_heads)]
|
|
581
|
+
)
|
|
582
|
+
self.linear = torch.nn.Linear(
|
|
583
|
+
in_features=embedding_dim * num_heads, out_features=embedding_dim
|
|
584
|
+
)
|
|
585
|
+
|
|
586
|
+
def forward(self, x):
|
|
587
|
+
out = torch.cat(tensors=[attention(x) for attention in self.attention], dim=-1)
|
|
588
|
+
x = self.linear(out)
|
|
589
|
+
return x
|
|
590
|
+
|
|
591
|
+
class FeedForward(torch.nn.Module):
|
|
592
|
+
|
|
593
|
+
def __init__(self, embedding_dim: int = 16, ff_dim: int = 128):
|
|
594
|
+
super().__init__()
|
|
595
|
+
self.linear_1 = torch.nn.Linear(embedding_dim, ff_dim)
|
|
596
|
+
self.relu = torch.nn.ReLU()
|
|
597
|
+
self.linear_2 = torch.nn.Linear(ff_dim, embedding_dim)
|
|
598
|
+
|
|
599
|
+
def forward(self, x):
|
|
600
|
+
x = self.linear_1(x)
|
|
601
|
+
x = self.relu(x)
|
|
602
|
+
x = self.linear_2(x)
|
|
603
|
+
return x
|
|
604
|
+
|
|
605
|
+
class DecoderLayer(torch.nn.Module):
|
|
606
|
+
|
|
607
|
+
def __init__(
|
|
608
|
+
self,
|
|
609
|
+
embedding_dim: int = 16,
|
|
610
|
+
num_heads: int = 2,
|
|
611
|
+
context_size: int = 256,
|
|
612
|
+
ff_dim: int = 128,
|
|
613
|
+
):
|
|
614
|
+
super().__init__()
|
|
615
|
+
self.attention = MultiAttentionBlock(embedding_dim, num_heads, context_size)
|
|
616
|
+
self.feed_forward = FeedForward(embedding_dim, ff_dim)
|
|
617
|
+
self.norm_1 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
|
|
618
|
+
self.norm_2 = torch.nn.LayerNorm(normalized_shape=embedding_dim)
|
|
619
|
+
|
|
620
|
+
def forward(self, x):
|
|
621
|
+
x_norm = self.norm_1(x)
|
|
622
|
+
attention = self.attention(x_norm)
|
|
623
|
+
attention = attention + x
|
|
624
|
+
|
|
625
|
+
attention_norm = self.norm_2(attention)
|
|
626
|
+
ff = self.feed_forward(attention_norm)
|
|
627
|
+
ff = ff + attention
|
|
628
|
+
|
|
629
|
+
return ff
|
|
630
|
+
|
|
631
|
+
class LLM(torch.nn.Module):
|
|
632
|
+
|
|
633
|
+
def __init__(
|
|
634
|
+
self,
|
|
635
|
+
vocab_size: int = 1024,
|
|
636
|
+
embedding_dim: int = 16,
|
|
637
|
+
num_heads: int = 2,
|
|
638
|
+
context_size: int = 256,
|
|
639
|
+
ff_dim: int = 128,
|
|
640
|
+
):
|
|
641
|
+
super().__init__()
|
|
642
|
+
self.embedding = Embedding(vocab_size, embedding_dim)
|
|
643
|
+
self.decoder = DecoderLayer(embedding_dim, num_heads, context_size, ff_dim)
|
|
644
|
+
|
|
645
|
+
def forward(self, input_ids):
|
|
646
|
+
x = self.embedding(input_ids)
|
|
647
|
+
y = self.decoder(x)
|
|
648
|
+
return y
|
|
649
|
+
|
|
650
|
+
if cls_name in (None, "LLM"):
|
|
651
|
+
dec: torch.nn.Module = LLM()
|
|
652
|
+
x = torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
|
|
653
|
+
dec(x)
|
|
654
|
+
if dynamic_shapes:
|
|
655
|
+
dyn = {
|
|
656
|
+
"input_ids": {
|
|
657
|
+
0: torch.export.Dim("batch", min=1, max=1024),
|
|
658
|
+
1: torch.export.Dim("length", min=1, max=255),
|
|
659
|
+
}
|
|
660
|
+
}
|
|
661
|
+
return dec, (x,), dyn
|
|
662
|
+
return dec, (x,)
|
|
663
|
+
|
|
664
|
+
if cls_name == "DecoderLayer":
|
|
665
|
+
LLM()(torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64))
|
|
666
|
+
|
|
667
|
+
dec = DecoderLayer()
|
|
668
|
+
x = Embedding()(
|
|
669
|
+
torch.randint(0, 1024, (2 if dynamic_shapes else 1, 30)).to(torch.int64)
|
|
670
|
+
)
|
|
671
|
+
dec(x)
|
|
672
|
+
if dynamic_shapes:
|
|
673
|
+
dyn = {
|
|
674
|
+
"x": {
|
|
675
|
+
0: torch.export.Dim("batch", min=1, max=1024),
|
|
676
|
+
1: torch.export.Dim("length", min=1, max=255),
|
|
677
|
+
}
|
|
678
|
+
}
|
|
679
|
+
return dec, (x,), dyn
|
|
680
|
+
return dec, (x,)
|
|
681
|
+
|
|
682
|
+
if cls_name == "MultiAttentionBlock":
|
|
683
|
+
dec = MultiAttentionBlock()
|
|
684
|
+
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
|
|
685
|
+
dec(x)
|
|
686
|
+
if dynamic_shapes:
|
|
687
|
+
dyn = {
|
|
688
|
+
"x": {
|
|
689
|
+
0: torch.export.Dim("batch", min=1, max=1024),
|
|
690
|
+
1: torch.export.Dim("length", min=1, max=255),
|
|
691
|
+
}
|
|
692
|
+
}
|
|
693
|
+
return dec, (x,), dyn
|
|
694
|
+
return dec, (x,)
|
|
695
|
+
|
|
696
|
+
if cls_name == "AttentionBlock":
|
|
697
|
+
dec = AttentionBlock()
|
|
698
|
+
x = torch.rand(2 if dynamic_shapes else 1, 30, 16).to(torch.float32)
|
|
699
|
+
dec(x)
|
|
700
|
+
if dynamic_shapes:
|
|
701
|
+
dyn = {
|
|
702
|
+
"x": {
|
|
703
|
+
0: torch.export.Dim("batch", min=1, max=1024),
|
|
704
|
+
1: torch.export.Dim("length", min=1, max=255),
|
|
705
|
+
}
|
|
706
|
+
}
|
|
707
|
+
return dec, (x,), dyn
|
|
708
|
+
return dec, (x,)
|
|
709
|
+
|
|
710
|
+
raise NotImplementedError(f"cls_name={cls_name}")
|
|
711
|
+
|
|
712
|
+
|
|
713
|
+
def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
|
|
714
|
+
"""Applies torch.to if applicable. Goes recursively."""
|
|
715
|
+
if isinstance(value, (torch.nn.Module, torch.Tensor)) and value.__class__.__name__ not in {
|
|
716
|
+
"DynamicCache",
|
|
717
|
+
"EncoderDecoderCache",
|
|
718
|
+
}:
|
|
719
|
+
if (
|
|
720
|
+
(
|
|
721
|
+
isinstance(to_value, torch.dtype)
|
|
722
|
+
or to_value in {"float16", "bfloat16", "float32", "float64"}
|
|
723
|
+
)
|
|
724
|
+
and hasattr(value, "dtype")
|
|
725
|
+
and value.dtype in {torch.int32, torch.int64, torch.int8, torch.int16}
|
|
726
|
+
):
|
|
727
|
+
# int vector should not be changed.
|
|
728
|
+
return value
|
|
729
|
+
return value.to(to_value)
|
|
730
|
+
if isinstance(value, list):
|
|
731
|
+
return [to_any(t, to_value) for t in value]
|
|
732
|
+
if isinstance(value, tuple):
|
|
733
|
+
return tuple(to_any(t, to_value) for t in value)
|
|
734
|
+
if isinstance(value, set):
|
|
735
|
+
return {to_any(t, to_value) for t in value}
|
|
736
|
+
if type(value) is dict:
|
|
737
|
+
return {k: to_any(t, to_value) for k, t in value.items()}
|
|
738
|
+
if value.__class__.__name__ in {"DynamicCache", "HybridCache"}:
|
|
739
|
+
make = dict(DynamicCache=make_dynamic_cache, HybridCache=make_hybrid_cache)
|
|
740
|
+
cc = CacheKeyValue(value)
|
|
741
|
+
return make[value.__class__.__name__]( # type: ignore[operator]
|
|
742
|
+
list(
|
|
743
|
+
zip(
|
|
744
|
+
[t.to(to_value) if t is not None else t for t in cc.key_cache],
|
|
745
|
+
[t.to(to_value) if t is not None else t for t in cc.value_cache],
|
|
746
|
+
)
|
|
747
|
+
)
|
|
748
|
+
)
|
|
749
|
+
if value.__class__.__name__ == "StaticCache":
|
|
750
|
+
cc = CacheKeyValue(value)
|
|
751
|
+
return make_static_cache(
|
|
752
|
+
list(
|
|
753
|
+
zip(
|
|
754
|
+
[t.to(to_value) if t is not None else t for t in cc.key_cache],
|
|
755
|
+
[t.to(to_value) if t is not None else t for t in cc.value_cache],
|
|
756
|
+
)
|
|
757
|
+
),
|
|
758
|
+
max_cache_len=value.max_cache_len,
|
|
759
|
+
)
|
|
760
|
+
if value.__class__.__name__ == "EncoderDecoderCache":
|
|
761
|
+
return make_encoder_decoder_cache(
|
|
762
|
+
to_any(value.self_attention_cache, to_value),
|
|
763
|
+
to_any(value.cross_attention_cache, to_value),
|
|
764
|
+
)
|
|
765
|
+
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
766
|
+
args, spec = torch.utils._pytree.tree_flatten(value)
|
|
767
|
+
new_args = to_any(args, to_value)
|
|
768
|
+
return torch.utils._pytree.tree_unflatten(new_args, spec)
|
|
769
|
+
|
|
770
|
+
if hasattr(value, "to"):
|
|
771
|
+
return value.to(to_value)
|
|
772
|
+
|
|
773
|
+
assert "Cache" not in value.__class__.__name__, (
|
|
774
|
+
f"Class {value.__class__.__name__!r} should be registered "
|
|
775
|
+
f"to be able to change the type in every tensor it contains."
|
|
776
|
+
)
|
|
777
|
+
assert not isinstance(value, Iterable), f"Unsupported type {type(value)}"
|
|
778
|
+
return value
|
|
779
|
+
|
|
780
|
+
|
|
781
|
+
def torch_deepcopy(value: Any) -> Any:
    """
    Makes a deep copy.

    Tensors are cloned, numpy arrays copied, containers rebuilt recursively
    (namedtuples and dictionary subclasses keep their class), and the caches
    handled below are rebuilt through their ``make_*_cache`` helpers.

    :param value: any value
    :return: a deep copy
    :raises NotImplementedError: if the type is not supported and the value
        has no attribute ``__nocopy__``
    """
    if value is None:
        return None
    if isinstance(value, (int, float, str)):
        return value
    if isinstance(value, tuple):
        copied = tuple(torch_deepcopy(v) for v in value)
        if type(value) is not tuple and hasattr(value, "_make"):
            # namedtuple: rebuild the same class so attribute access keeps working
            return value._make(copied)
        return copied
    if isinstance(value, list):
        return [torch_deepcopy(v) for v in value]
    if isinstance(value, set):
        return {torch_deepcopy(v) for v in value}
    if isinstance(value, dict):
        copied_dict = {k: torch_deepcopy(v) for k, v in value.items()}
        if type(value) is dict:
            return copied_dict
        if all(isinstance(k, str) for k in copied_dict):
            # for BaseModelOutput
            return value.__class__(**copied_dict)
        # keyword arguments require string keys, give the mapping instead
        return value.__class__(copied_dict)
    if isinstance(value, np.ndarray):
        return value.copy()
    if hasattr(value, "clone"):
        return value.clone()
    if value.__class__.__name__ == "DynamicCache":
        from .cache_helper import CacheKeyValue

        ca = CacheKeyValue(value)
        return make_dynamic_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
    if value.__class__.__name__ == "StaticCache":
        from .cache_helper import CacheKeyValue

        ca = CacheKeyValue(value)
        if len(ca.key_cache) == 0:
            # Empty cache: nothing to rebuild from, fall back on copy.deepcopy.
            import copy

            return copy.deepcopy(value)
        return make_static_cache(
            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))),
            # large enough to hold the copied tensors
            # (dimension 2 presumably being the sequence length)
            max_cache_len=max([value.max_cache_len, *[t.shape[2] for t in ca.key_cache]]),
        )
    if value.__class__.__name__ == "HybridCache":
        from .cache_helper import CacheKeyValue

        ca = CacheKeyValue(value)
        return make_hybrid_cache(torch_deepcopy(list(zip(ca.key_cache, ca.value_cache))))
    if value.__class__.__name__ == "SlidingWindowCache":
        from .cache_helper import CacheKeyValue

        ca = CacheKeyValue(value)
        return make_sliding_window_cache(
            torch_deepcopy(list(zip(ca.key_cache, ca.value_cache)))
        )
    if value.__class__.__name__ == "EncoderDecoderCache":
        return make_encoder_decoder_cache(
            torch_deepcopy(value.self_attention_cache),
            torch_deepcopy(value.cross_attention_cache),
        )
    if value.__class__.__name__ == "MambaCache":
        return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))

    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        # Generic path for registered classes: flatten, copy the leaves, rebuild.
        args, spec = torch.utils._pytree.tree_flatten(value)
        new_args = torch_deepcopy(args)
        return torch.utils._pytree.tree_unflatten(new_args, spec)

    if value.__class__.__name__ == "Results":
        import copy
        import ultralytics

        assert isinstance(
            value, ultralytics.engine.results.Results
        ), f"Unexpected type={type(value)}"
        return copy.deepcopy(value)

    if hasattr(value, "__nocopy__"):
        return value

    # We should have a code using serialization, deserialization assuming a model
    # cannot be exported without them.
    raise NotImplementedError(
        f"torch_deepcopy not implemented for type {type(value)}, "
        f"add attribute '__nocopy__' to return it as is."
    )
|
|
868
|
+
|
|
869
|
+
|
|
870
|
+
def torch_tensor_size(value: Any) -> Any:
    """
    Returns the number of bytes stored in tensors.

    Containers, the caches handled below and classes registered in
    ``torch.utils._pytree`` are traversed recursively. ``None``, numbers and
    strings count for 0 bytes.

    :param value: any value
    :return: total number of bytes
    :raises NotImplementedError: if the type is not supported
    """
    if value is None:
        return 0
    if isinstance(value, (int, float, str)):
        return 0
    if isinstance(value, (tuple, list, set)):
        return sum(torch_tensor_size(v) for v in value)
    if isinstance(value, dict):
        return sum(torch_tensor_size(v) for v in value.values())
    if isinstance(value, np.ndarray):
        # number of bytes held by the array (not a copy of it)
        return value.nbytes
    if hasattr(value, "clone"):
        return value.numel() * size_type(value.dtype)
    if value.__class__.__name__ in {
        "DynamicCache",
        "SlidingWindowCache",
        "HybridCache",
        "StaticCache",
    }:
        cc = CacheKeyValue(value)
        return torch_tensor_size(cc.key_cache) + torch_tensor_size(cc.value_cache)
    if value.__class__.__name__ == "EncoderDecoderCache":
        return torch_tensor_size(value.self_attention_cache) + torch_tensor_size(
            value.cross_attention_cache
        )
    if value.__class__.__name__ == "MambaCache":
        return torch_tensor_size(value.conv_states) + torch_tensor_size(value.ssm_states)
    if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
        args, _spec = torch.utils._pytree.tree_flatten(value)
        return sum(torch_tensor_size(a) for a in args)

    # We should have a code using serialization, deserialization assuming a model
    # cannot be exported without them.
    raise NotImplementedError(f"torch_tensor_size not implemented for type {type(value)}")
|
|
905
|
+
|
|
906
|
+
|
|
907
|
+
def model_statistics(model: torch.nn.Module):
    """
    Returns statistics on a model in a dictionary.

    Keys: ``type`` (class name), ``n_modules`` (number of modules returned by
    ``model.modules()``, the model included), ``param_size`` and
    ``buffer_size`` (bytes held by parameters and buffers), ``size_mb``
    (their sum in megabytes, rounded down), then one key per dtype name
    (``float32``, ``int64``, ...) giving the bytes stored with that dtype.

    :param model: a torch module
    :return: dictionary of statistics
    """
    bytes_per_dtype = {}

    def _tally(tensors) -> int:
        # Sums the bytes of every tensor and records them per dtype name.
        subtotal = 0
        for tensor in tensors:
            nbytes = tensor.nelement() * tensor.element_size()
            subtotal += nbytes
            dtype_name = str(tensor.dtype).replace("torch.", "")
            bytes_per_dtype[dtype_name] = bytes_per_dtype.get(dtype_name, 0) + nbytes
        return subtotal

    param_size = _tally(model.parameters())
    buffer_size = _tally(model.buffers())
    stats = {
        "type": model.__class__.__name__,
        "n_modules": sum(1 for _ in model.modules()),
        "param_size": param_size,
        "buffer_size": buffer_size,
        "size_mb": (param_size + buffer_size) // 2**20,
    }
    stats.update(bytes_per_dtype)
    return stats
|
|
938
|
+
|
|
939
|
+
|
|
940
|
+
def _byteswap_raw_data(raw_data: bytes, torch_dtype: torch.dtype) -> bytes:
    """
    Reverses the byte order of every element stored in ``raw_data``.

    Complex numbers are swapped component by component, like
    ``numpy.ndarray.byteswap`` does.

    :param raw_data: raw bytes of a tensor
    :param torch_dtype: element type
    :return: the swapped bytes
    """
    itemsize = torch.empty((), dtype=torch_dtype).element_size()
    if torch_dtype.is_complex:
        # real and imaginary parts are swapped independently
        itemsize //= 2
    if itemsize <= 1:
        return raw_data
    # one row per element, reversing the columns reverses the bytes of each element
    data = np.frombuffer(raw_data, dtype=np.uint8).reshape((-1, itemsize))
    return data[:, ::-1].tobytes()


def to_tensor(tensor: onnx.TensorProto, base_dir: str = "") -> torch.Tensor:
    """
    Converts a TensorProto to a torch tensor.

    :param tensor: a TensorProto object.
    :param base_dir: if external tensor exists, base_dir can help to find the path to it
    :return: the converted tensor
    """
    assert not tensor.HasField("segment"), "Currently not supporting loading segments."
    assert (
        tensor.data_type != onnx.TensorProto.UNDEFINED
    ), "The element type in the input tensor is not defined."
    assert tensor.data_type != onnx.TensorProto.STRING, "to_tensor not implemented for strings"

    tensor_dtype = tensor.data_type
    torch_dtype = onnx_dtype_to_torch_dtype(tensor_dtype)
    dims = tuple(tensor.dims)
    if uses_external_data(tensor):
        # Load raw data from external tensor if it exists
        load_external_data_for_tensor(tensor, base_dir)

    if tensor.HasField("raw_data"):
        raw_data = tensor.raw_data
        if len(raw_data) == 0:
            return torch.tensor([], dtype=torch_dtype).reshape(dims)
        if sys.byteorder == "big":
            # Convert endian from little to big
            # (torch tensors have no byteswap/tobytes, numpy does the swap)
            raw_data = _byteswap_raw_data(raw_data, torch_dtype)
        with warnings.catch_warnings():
            # presumably silences torch.frombuffer's warning about
            # non-writable buffers (raw_data is immutable bytes)
            warnings.simplefilter("ignore")
            # NOTE: torch.frombuffer shares memory with raw_data, no copy is made.
            return torch.frombuffer(raw_data, dtype=torch_dtype).reshape(dims)

    # Other cases, it should be small tensor. We use numpy.
    np_tensor = to_array_extended(tensor)
    return torch.from_numpy(np_tensor)
|
|
975
|
+
|
|
976
|
+
|
|
977
|
+
def get_weight_type(model: torch.nn.Module) -> torch.dtype:
    """
    Returns the most probable dtype in a model.

    The dtype shared by the largest number of parameters is returned.
    On a tie, the first one met in ``model.named_parameters()`` order wins.

    :param model: a torch module
    :return: the most frequent parameter dtype
    :raises ValueError: if the model has no parameter
    """
    counts = {}
    for _name, param in model.named_parameters():
        counts[param.dtype] = counts.get(param.dtype, 0) + 1
    if not counts:
        raise ValueError(
            f"Unable to guess the weight type, model {type(model)} has no parameter."
        )
    # The key must be the count: comparing (dtype, count) tuples would compare
    # the dtypes first, and torch.dtype does not support ordering.
    return max(counts.items(), key=lambda item: item[1])[0]
|