onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1707 @@
|
|
|
1
|
+
import ast
|
|
2
|
+
import enum
|
|
3
|
+
import inspect
|
|
4
|
+
import itertools
|
|
5
|
+
from dataclasses import is_dataclass, fields
|
|
6
|
+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
|
|
7
|
+
import numpy as np
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def size_type(dtype: Any) -> int:
    """
    Returns the element size, in bytes, for an element type.

    :param dtype: an integer (interpreted as an ``onnx.TensorProto`` element
        type), a numpy dtype or numpy scalar type, a torch dtype,
        or ``ml_dtypes.bfloat16``
    :return: number of bytes used by one element of that type
    :raises AssertionError: if the element type is not supported
    """
    if isinstance(dtype, int):
        from onnx import TensorProto

        # It is a TensorProto.DATATYPE
        onnx_sizes = {
            TensorProto.COMPLEX128: 16,
            TensorProto.DOUBLE: 8,
            TensorProto.INT64: 8,
            TensorProto.UINT64: 8,
            TensorProto.COMPLEX64: 8,
            TensorProto.FLOAT: 4,
            TensorProto.INT32: 4,
            TensorProto.UINT32: 4,
            TensorProto.FLOAT16: 2,
            TensorProto.BFLOAT16: 2,
            TensorProto.INT16: 2,
            TensorProto.UINT16: 2,
            TensorProto.INT8: 1,
            TensorProto.UINT8: 1,
            TensorProto.BOOL: 1,
            TensorProto.FLOAT8E4M3FN: 1,
            TensorProto.FLOAT8E4M3FNUZ: 1,
            TensorProto.FLOAT8E5M2: 1,
            TensorProto.FLOAT8E5M2FNUZ: 1,
        }
        # FLOAT8E8M0 is looked up dynamically because it may be missing
        # from the installed version of onnx.
        float8e8m0 = getattr(TensorProto, "FLOAT8E8M0", None)
        if float8e8m0 is not None:
            onnx_sizes[float8e8m0] = 1
        if dtype in onnx_sizes:
            return onnx_sizes[dtype]
        from .onnx_helper import onnx_dtype_name

        raise AssertionError(
            f"Unable to return the element size for type {onnx_dtype_name(dtype)}"
        )

    # numpy: ``==`` is used (not a dictionary lookup) so that both scalar
    # types (np.float32) and dtype instances (np.dtype("float32")) match.
    numpy_sizes = [
        (np.complex128, 16),
        (np.float64, 8),
        (np.int64, 8),
        (np.complex64, 8),
        (np.float32, 4),
        (np.int32, 4),
        (np.float16, 2),
        (np.int16, 2),
        (np.int8, 1),
        (np.bool_, 1),
    ]
    # Unsigned types are only registered when numpy exposes them
    # (the original code reports a failure on mac).
    for name, size in (("uint64", 8), ("uint32", 4), ("uint16", 2), ("uint8", 1)):
        if hasattr(np, name):
            numpy_sizes.append((getattr(np, name), size))
    for np_type, size in numpy_sizes:
        if dtype == np_type:
            return size

    import torch

    torch_sizes = {
        torch.complex128: 16,
        torch.float64: 8,
        torch.int64: 8,
        torch.complex64: 8,
        torch.float32: 4,
        torch.int32: 4,
        torch.float16: 2,
        torch.int16: 2,
        torch.bfloat16: 2,
        torch.int8: 1,
        torch.uint8: 1,
        torch.bool: 1,
    }
    # These dtypes do not exist in every version of torch.
    optional_torch_sizes = (
        ("uint64", 8),
        ("uint32", 4),
        ("uint16", 2),
        ("float8_e4m3fn", 1),
        ("float8_e4m3fnuz", 1),
        ("float8_e5m2", 1),
        ("float8_e5m2fnuz", 1),
    )
    for name, size in optional_torch_sizes:
        if hasattr(torch, name):
            torch_sizes[getattr(torch, name)] = size
    if dtype in torch_sizes:
        return torch_sizes[dtype]

    # ml_dtypes is optional: a missing package must not hide the
    # AssertionError raised for an unsupported type.
    try:
        import ml_dtypes
    except ImportError:
        ml_dtypes = None

    if ml_dtypes is not None and dtype == ml_dtypes.bfloat16:
        return 2
    raise AssertionError(f"Unexpected dtype={dtype}")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def string_type(
|
|
98
|
+
obj: Any,
|
|
99
|
+
with_shape: bool = False,
|
|
100
|
+
with_min_max: bool = False,
|
|
101
|
+
with_device: bool = False,
|
|
102
|
+
ignore: bool = False,
|
|
103
|
+
limit: int = 20,
|
|
104
|
+
verbose: int = 0,
|
|
105
|
+
) -> str:
|
|
106
|
+
"""
|
|
107
|
+
Displays the types of an object as a string.
|
|
108
|
+
|
|
109
|
+
:param obj: any
|
|
110
|
+
:param with_shape: displays shapes as well
|
|
111
|
+
:param with_min_max: displays information about the values
|
|
112
|
+
:param with_device: display the device
|
|
113
|
+
:param ignore: if True, just prints the type for unknown types
|
|
114
|
+
:param verbose: verbosity (to show the path it followed to get that print)
|
|
115
|
+
:return: str
|
|
116
|
+
|
|
117
|
+
The function displays something like the following for a tensor.
|
|
118
|
+
|
|
119
|
+
.. code-block:: text
|
|
120
|
+
|
|
121
|
+
T7s2x7[0.5:6:A3.56]
|
|
122
|
+
^^^+-^^----+------^
|
|
123
|
+
|| | |
|
|
124
|
+
|| | +-- information about the content of a tensor or array
|
|
125
|
+
|| | [min,max:A<average>]
|
|
126
|
+
|| |
|
|
127
|
+
|| +-- a shape
|
|
128
|
+
||
|
|
129
|
+
|+-- integer following the code defined by onnx.TensorProto,
|
|
130
|
+
| 7 is onnx.TensorProto.INT64 (see onnx_dtype_name)
|
|
131
|
+
|
|
|
132
|
+
+-- A,T,F
|
|
133
|
+
A is an array from numpy
|
|
134
|
+
T is a Tensor from pytorch
|
|
135
|
+
F is a FakeTensor from pytorch
|
|
136
|
+
|
|
137
|
+
The element types for a tensor are displayed as integer to shorten the message.
|
|
138
|
+
The semantic is defined by :class:`onnx.TensorProto` and can be obtained
|
|
139
|
+
by :func:`onnx_diagnostic.helpers.onnx_helper.onnx_dtype_name`.
|
|
140
|
+
|
|
141
|
+
.. runpython::
|
|
142
|
+
:showcode:
|
|
143
|
+
|
|
144
|
+
from onnx_diagnostic.helpers import string_type
|
|
145
|
+
|
|
146
|
+
print(string_type((1, ["r", 6.6])))
|
|
147
|
+
|
|
148
|
+
With pytorch:
|
|
149
|
+
|
|
150
|
+
.. runpython::
|
|
151
|
+
:showcode:
|
|
152
|
+
|
|
153
|
+
import torch
|
|
154
|
+
from onnx_diagnostic.helpers import string_type
|
|
155
|
+
|
|
156
|
+
inputs = (
|
|
157
|
+
torch.rand((3, 4), dtype=torch.float16),
|
|
158
|
+
[
|
|
159
|
+
torch.rand((5, 6), dtype=torch.float16),
|
|
160
|
+
torch.rand((5, 6, 7), dtype=torch.float16),
|
|
161
|
+
]
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
# with shapes
|
|
165
|
+
print(string_type(inputs, with_shape=True))
|
|
166
|
+
|
|
167
|
+
# with min max
|
|
168
|
+
print(string_type(inputs, with_shape=True, with_min_max=True))
|
|
169
|
+
"""
|
|
170
|
+
if obj is None:
|
|
171
|
+
if verbose:
|
|
172
|
+
print(f"[string_type] A:{type(obj)}")
|
|
173
|
+
return "None"
|
|
174
|
+
|
|
175
|
+
# tuple
|
|
176
|
+
if isinstance(obj, tuple):
|
|
177
|
+
if len(obj) == 1:
|
|
178
|
+
s = string_type(
|
|
179
|
+
obj[0],
|
|
180
|
+
with_shape=with_shape,
|
|
181
|
+
with_min_max=with_min_max,
|
|
182
|
+
with_device=with_device,
|
|
183
|
+
ignore=ignore,
|
|
184
|
+
limit=limit,
|
|
185
|
+
verbose=verbose,
|
|
186
|
+
)
|
|
187
|
+
if verbose:
|
|
188
|
+
print(f"[string_type] C:{type(obj)}")
|
|
189
|
+
return f"({s},)"
|
|
190
|
+
if len(obj) < limit:
|
|
191
|
+
js = ",".join(
|
|
192
|
+
string_type(
|
|
193
|
+
o,
|
|
194
|
+
with_shape=with_shape,
|
|
195
|
+
with_min_max=with_min_max,
|
|
196
|
+
with_device=with_device,
|
|
197
|
+
ignore=ignore,
|
|
198
|
+
limit=limit,
|
|
199
|
+
verbose=verbose,
|
|
200
|
+
)
|
|
201
|
+
for o in obj
|
|
202
|
+
)
|
|
203
|
+
if verbose:
|
|
204
|
+
print(f"[string_type] D:{type(obj)}")
|
|
205
|
+
return f"({js})"
|
|
206
|
+
tt = string_type(
|
|
207
|
+
obj[0],
|
|
208
|
+
with_shape=with_shape,
|
|
209
|
+
with_min_max=with_min_max,
|
|
210
|
+
with_device=with_device,
|
|
211
|
+
ignore=ignore,
|
|
212
|
+
limit=limit,
|
|
213
|
+
verbose=verbose,
|
|
214
|
+
)
|
|
215
|
+
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
|
|
216
|
+
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
|
|
217
|
+
if verbose:
|
|
218
|
+
print(f"[string_type] E:{type(obj)}")
|
|
219
|
+
return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
|
|
220
|
+
if verbose:
|
|
221
|
+
print(f"[string_type] F:{type(obj)}")
|
|
222
|
+
return f"#{len(obj)}({tt},...)"
|
|
223
|
+
# list
|
|
224
|
+
if isinstance(obj, list):
|
|
225
|
+
if len(obj) < limit:
|
|
226
|
+
js = ",".join(
|
|
227
|
+
string_type(
|
|
228
|
+
o,
|
|
229
|
+
with_shape=with_shape,
|
|
230
|
+
with_min_max=with_min_max,
|
|
231
|
+
with_device=with_device,
|
|
232
|
+
ignore=ignore,
|
|
233
|
+
limit=limit,
|
|
234
|
+
verbose=verbose,
|
|
235
|
+
)
|
|
236
|
+
for o in obj
|
|
237
|
+
)
|
|
238
|
+
if verbose:
|
|
239
|
+
print(f"[string_type] G:{type(obj)}")
|
|
240
|
+
return f"#{len(obj)}[{js}]"
|
|
241
|
+
tt = string_type(
|
|
242
|
+
obj[0],
|
|
243
|
+
with_shape=with_shape,
|
|
244
|
+
with_min_max=with_min_max,
|
|
245
|
+
with_device=with_device,
|
|
246
|
+
ignore=ignore,
|
|
247
|
+
limit=limit,
|
|
248
|
+
verbose=verbose,
|
|
249
|
+
)
|
|
250
|
+
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
|
|
251
|
+
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
|
|
252
|
+
if verbose:
|
|
253
|
+
print(f"[string_type] H:{type(obj)}")
|
|
254
|
+
return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]"
|
|
255
|
+
if verbose:
|
|
256
|
+
print(f"[string_type] I:{type(obj)}")
|
|
257
|
+
return f"#{len(obj)}[{tt},...]"
|
|
258
|
+
# set
|
|
259
|
+
if isinstance(obj, set):
|
|
260
|
+
if len(obj) < 10:
|
|
261
|
+
js = ",".join(
|
|
262
|
+
string_type(
|
|
263
|
+
o,
|
|
264
|
+
with_shape=with_shape,
|
|
265
|
+
with_min_max=with_min_max,
|
|
266
|
+
with_device=with_device,
|
|
267
|
+
ignore=ignore,
|
|
268
|
+
limit=limit,
|
|
269
|
+
verbose=verbose,
|
|
270
|
+
)
|
|
271
|
+
for o in obj
|
|
272
|
+
)
|
|
273
|
+
if verbose:
|
|
274
|
+
print(f"[string_type] J:{type(obj)}")
|
|
275
|
+
return f"{{{js}}}"
|
|
276
|
+
if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
|
|
277
|
+
mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
|
|
278
|
+
if verbose:
|
|
279
|
+
print(f"[string_type] K:{type(obj)}")
|
|
280
|
+
return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]"
|
|
281
|
+
if verbose:
|
|
282
|
+
print(f"[string_type] L:{type(obj)}")
|
|
283
|
+
return f"{{...}}#{len(obj)}" if with_shape else "{...}"
|
|
284
|
+
# dict
|
|
285
|
+
if isinstance(obj, dict) and type(obj) is dict:
|
|
286
|
+
if len(obj) == 0:
|
|
287
|
+
if verbose:
|
|
288
|
+
print(f"[string_type] M:{type(obj)}")
|
|
289
|
+
return "{}"
|
|
290
|
+
|
|
291
|
+
import torch
|
|
292
|
+
|
|
293
|
+
if all(isinstance(k, int) for k in obj) and all(
|
|
294
|
+
isinstance(
|
|
295
|
+
v,
|
|
296
|
+
(
|
|
297
|
+
str,
|
|
298
|
+
torch.export.dynamic_shapes._Dim,
|
|
299
|
+
torch.export.dynamic_shapes._DerivedDim,
|
|
300
|
+
torch.export.dynamic_shapes._DimHint,
|
|
301
|
+
),
|
|
302
|
+
)
|
|
303
|
+
for v in obj.values()
|
|
304
|
+
):
|
|
305
|
+
# This is dynamic shapes
|
|
306
|
+
rows = []
|
|
307
|
+
for k, v in obj.items():
|
|
308
|
+
if isinstance(v, str):
|
|
309
|
+
rows.append(f"{k}:DYN({v})")
|
|
310
|
+
else:
|
|
311
|
+
rows.append(f"{k}:{string_type(v, verbose=verbose)}")
|
|
312
|
+
if verbose:
|
|
313
|
+
print(f"[string_type] DS0:{type(obj)}")
|
|
314
|
+
return f"{{{','.join(rows)}}}"
|
|
315
|
+
|
|
316
|
+
kws = dict(
|
|
317
|
+
with_shape=with_shape,
|
|
318
|
+
with_min_max=with_min_max,
|
|
319
|
+
with_device=with_device,
|
|
320
|
+
ignore=ignore,
|
|
321
|
+
limit=limit,
|
|
322
|
+
verbose=verbose,
|
|
323
|
+
)
|
|
324
|
+
s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items())
|
|
325
|
+
if all(isinstance(k, int) for k in obj):
|
|
326
|
+
if verbose:
|
|
327
|
+
print(f"[string_type] N:{type(obj)}")
|
|
328
|
+
return f"{{{s}}}"
|
|
329
|
+
if verbose:
|
|
330
|
+
print(f"[string_type] O:{type(obj)}")
|
|
331
|
+
return f"dict({s})"
|
|
332
|
+
# array
|
|
333
|
+
if isinstance(obj, np.ndarray):
|
|
334
|
+
from .onnx_helper import np_dtype_to_tensor_dtype
|
|
335
|
+
|
|
336
|
+
if with_min_max:
|
|
337
|
+
s = string_type(obj, with_shape=with_shape)
|
|
338
|
+
if len(obj.shape) == 0:
|
|
339
|
+
return f"{s}={obj}"
|
|
340
|
+
if obj.size == 0:
|
|
341
|
+
return f"{s}[empty]"
|
|
342
|
+
n_nan = np.isnan(obj.reshape((-1,))).astype(int).sum()
|
|
343
|
+
if n_nan > 0:
|
|
344
|
+
nob = obj.ravel()
|
|
345
|
+
nob = nob[~np.isnan(nob)]
|
|
346
|
+
if nob.size == 0:
|
|
347
|
+
if verbose:
|
|
348
|
+
print(f"[string_type] A1:{type(obj)}")
|
|
349
|
+
return f"{s}[N{n_nan}nans]"
|
|
350
|
+
if verbose:
|
|
351
|
+
print(f"[string_type] A2:{type(obj)}")
|
|
352
|
+
return f"{s}[{nob.min()},{nob.max()}:A{nob.astype(float).mean()}N{n_nan}nans]"
|
|
353
|
+
if verbose:
|
|
354
|
+
print(f"[string_type] A3:{type(obj)}")
|
|
355
|
+
return f"{s}[{obj.min()},{obj.max()}:A{obj.astype(float).mean()}]"
|
|
356
|
+
i = np_dtype_to_tensor_dtype(obj.dtype)
|
|
357
|
+
if not with_shape:
|
|
358
|
+
if verbose:
|
|
359
|
+
print(f"[string_type] A4:{type(obj)}")
|
|
360
|
+
return f"A{i}r{len(obj.shape)}"
|
|
361
|
+
if verbose:
|
|
362
|
+
print(f"[string_type] A5:{type(obj)}")
|
|
363
|
+
return f"A{i}s{'x'.join(map(str, obj.shape))}"
|
|
364
|
+
|
|
365
|
+
import torch
|
|
366
|
+
|
|
367
|
+
# Dim, SymInt
|
|
368
|
+
if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
|
|
369
|
+
if verbose:
|
|
370
|
+
print(f"[string_type] Y1:{type(obj)}")
|
|
371
|
+
return "DerivedDim"
|
|
372
|
+
if isinstance(obj, torch.export.dynamic_shapes._Dim):
|
|
373
|
+
if verbose:
|
|
374
|
+
print(f"[string_type] Y2:{type(obj)}")
|
|
375
|
+
return f"Dim({obj.__name__})"
|
|
376
|
+
if isinstance(obj, torch.SymInt):
|
|
377
|
+
if verbose:
|
|
378
|
+
print(f"[string_type] Y3:{type(obj)}")
|
|
379
|
+
return "SymInt"
|
|
380
|
+
if isinstance(obj, torch.SymFloat):
|
|
381
|
+
if verbose:
|
|
382
|
+
print(f"[string_type] Y4:{type(obj)}")
|
|
383
|
+
return "SymFloat"
|
|
384
|
+
|
|
385
|
+
if isinstance(obj, torch.export.dynamic_shapes._DimHint):
|
|
386
|
+
cl = (
|
|
387
|
+
torch.export.dynamic_shapes._DimHintType
|
|
388
|
+
if hasattr(torch.export.dynamic_shapes, "_DimHintType")
|
|
389
|
+
else torch.export.Dim
|
|
390
|
+
)
|
|
391
|
+
if obj in (torch.export.Dim.DYNAMIC, cl.DYNAMIC):
|
|
392
|
+
if verbose:
|
|
393
|
+
print(f"[string_type] Y8:{type(obj)}")
|
|
394
|
+
return "DYNAMIC"
|
|
395
|
+
if obj in (torch.export.Dim.AUTO, cl.AUTO):
|
|
396
|
+
if verbose:
|
|
397
|
+
print(f"[string_type] Y9:{type(obj)}")
|
|
398
|
+
return "AUTO"
|
|
399
|
+
if verbose:
|
|
400
|
+
print(f"[string_type] Y7:{type(obj)}")
|
|
401
|
+
return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")
|
|
402
|
+
|
|
403
|
+
if isinstance(obj, bool):
|
|
404
|
+
if with_min_max:
|
|
405
|
+
if verbose:
|
|
406
|
+
print(f"[string_type] W1:{type(obj)}")
|
|
407
|
+
return f"bool={obj}"
|
|
408
|
+
if verbose:
|
|
409
|
+
print(f"[string_type] W2:{type(obj)}")
|
|
410
|
+
return "bool"
|
|
411
|
+
if isinstance(obj, int):
|
|
412
|
+
if with_min_max:
|
|
413
|
+
if verbose:
|
|
414
|
+
print(f"[string_type] W3:{type(obj)}")
|
|
415
|
+
return f"int={obj}"
|
|
416
|
+
if verbose:
|
|
417
|
+
print(f"[string_type] W4:{type(obj)}")
|
|
418
|
+
return "int"
|
|
419
|
+
if isinstance(obj, float):
|
|
420
|
+
if with_min_max:
|
|
421
|
+
if verbose:
|
|
422
|
+
print(f"[string_type] W6:{type(obj)}")
|
|
423
|
+
return f"float={obj}"
|
|
424
|
+
if verbose:
|
|
425
|
+
print(f"[string_type] W8:{type(obj)}")
|
|
426
|
+
return "float"
|
|
427
|
+
if isinstance(obj, str):
|
|
428
|
+
if verbose:
|
|
429
|
+
print(f"[string_type] W9:{type(obj)}")
|
|
430
|
+
return "str"
|
|
431
|
+
if isinstance(obj, slice):
|
|
432
|
+
if verbose:
|
|
433
|
+
print(f"[string_type] W10:{type(obj)}")
|
|
434
|
+
return "slice"
|
|
435
|
+
|
|
436
|
+
if is_dataclass(obj):
|
|
437
|
+
# That includes torch.export.Dim.AUTO, torch.export.Dim.DYNAMIC so they need to be
|
|
438
|
+
# handled before that.
|
|
439
|
+
values = {f.name: getattr(obj, f.name, None) for f in fields(obj)}
|
|
440
|
+
values = {k: v for k, v in values.items() if v is not None}
|
|
441
|
+
s = string_type(
|
|
442
|
+
values,
|
|
443
|
+
with_shape=with_shape,
|
|
444
|
+
with_min_max=with_min_max,
|
|
445
|
+
with_device=with_device,
|
|
446
|
+
ignore=ignore,
|
|
447
|
+
limit=limit,
|
|
448
|
+
verbose=verbose,
|
|
449
|
+
)
|
|
450
|
+
if verbose:
|
|
451
|
+
print(f"[string_type] B:{type(obj)}")
|
|
452
|
+
return f"{obj.__class__.__name__}{s[4:]}"
|
|
453
|
+
|
|
454
|
+
# Tensors
|
|
455
|
+
if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
|
|
456
|
+
from .torch_helper import torch_dtype_to_onnx_dtype
|
|
457
|
+
|
|
458
|
+
i = torch_dtype_to_onnx_dtype(obj.dtype)
|
|
459
|
+
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
|
|
460
|
+
if not with_shape:
|
|
461
|
+
if verbose:
|
|
462
|
+
print(f"[string_type] F1:{type(obj)}")
|
|
463
|
+
return f"{prefix}F{i}r{len(obj.shape)}"
|
|
464
|
+
if verbose:
|
|
465
|
+
print(f"[string_type] F2:{type(obj)}")
|
|
466
|
+
return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
|
|
467
|
+
|
|
468
|
+
if isinstance(obj, torch.Tensor):
|
|
469
|
+
from .torch_helper import torch_dtype_to_onnx_dtype
|
|
470
|
+
|
|
471
|
+
if with_min_max:
|
|
472
|
+
s = string_type(obj, with_shape=with_shape, with_device=with_device)
|
|
473
|
+
if len(obj.shape) == 0:
|
|
474
|
+
if verbose:
|
|
475
|
+
print(f"[string_type] T1:{type(obj)}")
|
|
476
|
+
return f"{s}={obj}"
|
|
477
|
+
if obj.numel() == 0:
|
|
478
|
+
if verbose:
|
|
479
|
+
print(f"[string_type] T2:{type(obj)}")
|
|
480
|
+
return f"{s}[empty]"
|
|
481
|
+
n_nan = obj.reshape((-1,)).isnan().to(int).sum()
|
|
482
|
+
if n_nan > 0:
|
|
483
|
+
nob = obj.reshape((-1,))
|
|
484
|
+
nob = nob[~nob.isnan()]
|
|
485
|
+
if obj.dtype in {torch.complex64, torch.complex128}:
|
|
486
|
+
if verbose:
|
|
487
|
+
print(f"[string_type] T3:{type(obj)}")
|
|
488
|
+
return (
|
|
489
|
+
f"{s}[{nob.abs().min()},{nob.abs().max():A{nob.mean()}N{n_nan}nans}]"
|
|
490
|
+
)
|
|
491
|
+
if verbose:
|
|
492
|
+
print(f"[string_type] T5:{type(obj)}")
|
|
493
|
+
return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}N{n_nan}nans]"
|
|
494
|
+
if obj.dtype in {torch.complex64, torch.complex128}:
|
|
495
|
+
if verbose:
|
|
496
|
+
print(f"[string_type] T6:{type(obj)}")
|
|
497
|
+
return f"{s}[{obj.abs().min()},{obj.abs().max()}:A{obj.abs().mean()}]"
|
|
498
|
+
if verbose:
|
|
499
|
+
print(f"[string_type] T7:{type(obj)}")
|
|
500
|
+
return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}]"
|
|
501
|
+
i = torch_dtype_to_onnx_dtype(obj.dtype)
|
|
502
|
+
prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
|
|
503
|
+
if not with_shape:
|
|
504
|
+
if verbose:
|
|
505
|
+
print(f"[string_type] T8:{type(obj)}")
|
|
506
|
+
return f"{prefix}T{i}r{len(obj.shape)}"
|
|
507
|
+
if verbose:
|
|
508
|
+
print(f"[string_type] T9:{type(obj)}")
|
|
509
|
+
return f"{prefix}T{i}s{'x'.join(map(str, obj.shape))}"
|
|
510
|
+
|
|
511
|
+
if obj.__class__.__name__ == "OrtValue":
|
|
512
|
+
if not obj.has_value():
|
|
513
|
+
if verbose:
|
|
514
|
+
print(f"[string_type] V1:{type(obj)}")
|
|
515
|
+
return "OV(<novalue>)"
|
|
516
|
+
if not obj.is_tensor():
|
|
517
|
+
if verbose:
|
|
518
|
+
print(f"[string_type] V2:{type(obj)}")
|
|
519
|
+
return "OV(NOTENSOR)"
|
|
520
|
+
if with_min_max:
|
|
521
|
+
from .torch_helper import to_numpy
|
|
522
|
+
|
|
523
|
+
try:
|
|
524
|
+
t = to_numpy(obj)
|
|
525
|
+
except Exception:
|
|
526
|
+
# pass unable to convert into numpy (bfloat16, ...)
|
|
527
|
+
if verbose:
|
|
528
|
+
print(f"[string_type] V3:{type(obj)}")
|
|
529
|
+
return "OV(NO-NUMPY:FIXIT)"
|
|
530
|
+
if verbose:
|
|
531
|
+
print(f"[string_type] V4:{type(obj)}")
|
|
532
|
+
return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
|
|
533
|
+
dt = obj.element_type()
|
|
534
|
+
shape = obj.shape()
|
|
535
|
+
if with_shape:
|
|
536
|
+
if verbose:
|
|
537
|
+
print(f"[string_type] V5:{type(obj)}")
|
|
538
|
+
return f"OV{dt}s{'x'.join(map(str, shape))}"
|
|
539
|
+
if verbose:
|
|
540
|
+
print(f"[string_type] V6:{type(obj)}")
|
|
541
|
+
return f"OV{dt}r{len(shape)}"
|
|
542
|
+
|
|
543
|
+
# others classes
|
|
544
|
+
|
|
545
|
+
if obj.__class__.__name__ == "MambaCache":
|
|
546
|
+
c = string_type(
|
|
547
|
+
obj.conv_states,
|
|
548
|
+
with_shape=with_shape,
|
|
549
|
+
with_min_max=with_min_max,
|
|
550
|
+
with_device=with_device,
|
|
551
|
+
limit=limit,
|
|
552
|
+
verbose=verbose,
|
|
553
|
+
)
|
|
554
|
+
d = string_type(
|
|
555
|
+
obj.ssm_states,
|
|
556
|
+
with_shape=with_shape,
|
|
557
|
+
with_min_max=with_min_max,
|
|
558
|
+
with_device=with_device,
|
|
559
|
+
limit=limit,
|
|
560
|
+
verbose=verbose,
|
|
561
|
+
)
|
|
562
|
+
if verbose:
|
|
563
|
+
print(f"[string_type] CACHE1:{type(obj)}")
|
|
564
|
+
return f"MambaCache(conv_states={c}, ssm_states={d})"
|
|
565
|
+
|
|
566
|
+
if obj.__class__.__name__ in {
|
|
567
|
+
"DynamicCache",
|
|
568
|
+
"SlidingWindowCache",
|
|
569
|
+
"StaticCache",
|
|
570
|
+
"HybridCache",
|
|
571
|
+
}:
|
|
572
|
+
from .cache_helper import CacheKeyValue
|
|
573
|
+
|
|
574
|
+
ca = CacheKeyValue(obj)
|
|
575
|
+
kc = string_type(
|
|
576
|
+
ca.key_cache,
|
|
577
|
+
with_shape=with_shape,
|
|
578
|
+
with_min_max=with_min_max,
|
|
579
|
+
with_device=with_device,
|
|
580
|
+
limit=limit,
|
|
581
|
+
verbose=verbose,
|
|
582
|
+
)
|
|
583
|
+
vc = string_type(
|
|
584
|
+
ca.value_cache,
|
|
585
|
+
with_shape=with_shape,
|
|
586
|
+
with_min_max=with_min_max,
|
|
587
|
+
with_device=with_device,
|
|
588
|
+
limit=limit,
|
|
589
|
+
verbose=verbose,
|
|
590
|
+
)
|
|
591
|
+
if verbose:
|
|
592
|
+
print(f"[string_type] CACHE2:{type(obj)}")
|
|
593
|
+
return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
|
|
594
|
+
|
|
595
|
+
if obj.__class__.__name__ == "StaticLayer":
|
|
596
|
+
kc = string_type(
|
|
597
|
+
list(obj.keys),
|
|
598
|
+
with_shape=with_shape,
|
|
599
|
+
with_min_max=with_min_max,
|
|
600
|
+
with_device=with_device,
|
|
601
|
+
limit=limit,
|
|
602
|
+
verbose=verbose,
|
|
603
|
+
)
|
|
604
|
+
vc = string_type(
|
|
605
|
+
list(obj.values),
|
|
606
|
+
with_shape=with_shape,
|
|
607
|
+
with_min_max=with_min_max,
|
|
608
|
+
with_device=with_device,
|
|
609
|
+
limit=limit,
|
|
610
|
+
verbose=verbose,
|
|
611
|
+
)
|
|
612
|
+
if verbose:
|
|
613
|
+
print(f"[string_type] SL:{type(obj)}")
|
|
614
|
+
return f"{obj.__class__.__name__}(keys={kc}, values={vc})"
|
|
615
|
+
|
|
616
|
+
if obj.__class__.__name__ == "EncoderDecoderCache":
|
|
617
|
+
att = string_type(
|
|
618
|
+
obj.self_attention_cache,
|
|
619
|
+
with_shape=with_shape,
|
|
620
|
+
with_min_max=with_min_max,
|
|
621
|
+
with_device=with_device,
|
|
622
|
+
limit=limit,
|
|
623
|
+
verbose=verbose,
|
|
624
|
+
)
|
|
625
|
+
cross = string_type(
|
|
626
|
+
obj.cross_attention_cache,
|
|
627
|
+
with_shape=with_shape,
|
|
628
|
+
with_min_max=with_min_max,
|
|
629
|
+
with_device=with_device,
|
|
630
|
+
limit=limit,
|
|
631
|
+
verbose=verbose,
|
|
632
|
+
)
|
|
633
|
+
if verbose:
|
|
634
|
+
print(f"[string_type] CACHE3:{type(obj)}")
|
|
635
|
+
return (
|
|
636
|
+
f"{obj.__class__.__name__}(self_attention_cache={att}, "
|
|
637
|
+
f"cross_attention_cache={cross})"
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
641
|
+
from .cache_helper import flatten_unflatten_for_dynamic_shapes
|
|
642
|
+
|
|
643
|
+
args = flatten_unflatten_for_dynamic_shapes(obj)
|
|
644
|
+
att = string_type(
|
|
645
|
+
args,
|
|
646
|
+
with_shape=with_shape,
|
|
647
|
+
with_min_max=with_min_max,
|
|
648
|
+
with_device=with_device,
|
|
649
|
+
limit=limit,
|
|
650
|
+
verbose=verbose,
|
|
651
|
+
)
|
|
652
|
+
if verbose:
|
|
653
|
+
print(f"[string_type] DS:{type(obj)}")
|
|
654
|
+
return f"{obj.__class__.__name__}[serialized]({att})"
|
|
655
|
+
|
|
656
|
+
if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
|
|
657
|
+
# torch.fx.node.Node
|
|
658
|
+
if verbose:
|
|
659
|
+
print(f"[string_type] TT1:{type(obj)}")
|
|
660
|
+
return f"%{obj.target}"
|
|
661
|
+
if type(obj).__name__ == "ValueInfoProto":
|
|
662
|
+
if verbose:
|
|
663
|
+
print(f"[string_type] OO1:{type(obj)}")
|
|
664
|
+
return f"OT{obj.type.tensor_type.elem_type}"
|
|
665
|
+
|
|
666
|
+
if obj.__class__.__name__ == "BatchFeature":
|
|
667
|
+
s = string_type(
|
|
668
|
+
obj.data,
|
|
669
|
+
with_shape=with_shape,
|
|
670
|
+
with_min_max=with_min_max,
|
|
671
|
+
with_device=with_device,
|
|
672
|
+
limit=limit,
|
|
673
|
+
verbose=verbose,
|
|
674
|
+
)
|
|
675
|
+
if verbose:
|
|
676
|
+
print(f"[string_type] TT2:{type(obj)}")
|
|
677
|
+
return f"BatchFeature(data={s})"
|
|
678
|
+
|
|
679
|
+
if obj.__class__.__name__ == "BatchEncoding":
|
|
680
|
+
s = string_type(
|
|
681
|
+
obj.data,
|
|
682
|
+
with_shape=with_shape,
|
|
683
|
+
with_min_max=with_min_max,
|
|
684
|
+
with_device=with_device,
|
|
685
|
+
limit=limit,
|
|
686
|
+
verbose=verbose,
|
|
687
|
+
)
|
|
688
|
+
if verbose:
|
|
689
|
+
print(f"[string_type] TT3:{type(obj)}")
|
|
690
|
+
return f"BatchEncoding(data={s})"
|
|
691
|
+
|
|
692
|
+
if obj.__class__.__name__ == "VirtualTensor":
|
|
693
|
+
if verbose:
|
|
694
|
+
print(f"[string_type] TT4:{type(obj)}")
|
|
695
|
+
return (
|
|
696
|
+
f"{obj.__class__.__name__}(name={obj.name!r}, "
|
|
697
|
+
f"dtype={obj.dtype}, shape={obj.shape})"
|
|
698
|
+
)
|
|
699
|
+
|
|
700
|
+
if obj.__class__.__name__ == "KeyValuesWrapper":
|
|
701
|
+
import transformers
|
|
702
|
+
|
|
703
|
+
assert isinstance(
|
|
704
|
+
obj, transformers.cache_utils.KeyValuesWrapper
|
|
705
|
+
), f"Unexpected type {type(obj)}"
|
|
706
|
+
if verbose:
|
|
707
|
+
print(f"[string_type] KW0:{type(obj)}")
|
|
708
|
+
s = string_type(
|
|
709
|
+
list(obj),
|
|
710
|
+
with_shape=with_shape,
|
|
711
|
+
with_min_max=with_min_max,
|
|
712
|
+
with_device=with_device,
|
|
713
|
+
limit=limit,
|
|
714
|
+
verbose=verbose,
|
|
715
|
+
)
|
|
716
|
+
return f"{obj.__class__.__name__}[{obj.cache_type}]{s}"
|
|
717
|
+
|
|
718
|
+
if obj.__class__.__name__ == "DynamicLayer":
|
|
719
|
+
import transformers
|
|
720
|
+
|
|
721
|
+
assert isinstance(
|
|
722
|
+
obj, transformers.cache_utils.DynamicLayer
|
|
723
|
+
), f"Unexpected type {type(obj)}"
|
|
724
|
+
if verbose:
|
|
725
|
+
print(f"[string_type] LY0:{type(obj)}")
|
|
726
|
+
s1 = string_type(
|
|
727
|
+
obj.keys,
|
|
728
|
+
with_shape=with_shape,
|
|
729
|
+
with_min_max=with_min_max,
|
|
730
|
+
with_device=with_device,
|
|
731
|
+
limit=limit,
|
|
732
|
+
verbose=verbose,
|
|
733
|
+
)
|
|
734
|
+
s2 = string_type(
|
|
735
|
+
obj.values,
|
|
736
|
+
with_shape=with_shape,
|
|
737
|
+
with_min_max=with_min_max,
|
|
738
|
+
with_device=with_device,
|
|
739
|
+
limit=limit,
|
|
740
|
+
verbose=verbose,
|
|
741
|
+
)
|
|
742
|
+
return f"{obj.__class__.__name__}(keys={s1}, values={s2})"
|
|
743
|
+
|
|
744
|
+
if isinstance(obj, torch.nn.Module):
|
|
745
|
+
if verbose:
|
|
746
|
+
print(f"[string_type] MM:{type(obj)}")
|
|
747
|
+
return f"{obj.__class__.__name__}(...)"
|
|
748
|
+
|
|
749
|
+
if isinstance(obj, (torch.device, torch.dtype, torch.memory_format, torch.layout)):
|
|
750
|
+
if verbose:
|
|
751
|
+
print(f"[string_type] TT7:{type(obj)}")
|
|
752
|
+
return f"{obj.__class__.__name__}({obj})"
|
|
753
|
+
|
|
754
|
+
if isinstance( # TreeSpec, MappingKey, SequenceKey
|
|
755
|
+
obj,
|
|
756
|
+
(
|
|
757
|
+
torch.utils._pytree.TreeSpec,
|
|
758
|
+
torch.utils._pytree.MappingKey,
|
|
759
|
+
torch.utils._pytree.SequenceKey,
|
|
760
|
+
),
|
|
761
|
+
):
|
|
762
|
+
if verbose:
|
|
763
|
+
print(f"[string_type] TT8:{type(obj)}")
|
|
764
|
+
return repr(obj).replace(" ", "").replace("\n", " ")
|
|
765
|
+
|
|
766
|
+
if ignore:
|
|
767
|
+
if verbose:
|
|
768
|
+
print(f"[string_type] CACHE4:{type(obj)}")
|
|
769
|
+
return f"{obj.__class__.__name__}(...)"
|
|
770
|
+
|
|
771
|
+
if obj.__class__.__name__.endswith("Config"):
|
|
772
|
+
import transformers.configuration_utils as tcu
|
|
773
|
+
|
|
774
|
+
if isinstance(obj, tcu.PretrainedConfig):
|
|
775
|
+
if verbose:
|
|
776
|
+
print(f"[string_type] CONFIG:{type(obj)}")
|
|
777
|
+
s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "")
|
|
778
|
+
return f"{obj.__class__.__name__}(**{s})"
|
|
779
|
+
if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
|
|
780
|
+
return f"{obj.__class__.__name__}(...)"
|
|
781
|
+
if obj.__class__.__name__ == "Results":
|
|
782
|
+
import ultralytics
|
|
783
|
+
|
|
784
|
+
assert isinstance(
|
|
785
|
+
obj, ultralytics.engine.results.Results
|
|
786
|
+
), f"Unexpected type={type(obj)}"
|
|
787
|
+
return f"ultralytics.{obj.__class__.__name__}(...)"
|
|
788
|
+
if obj.__class__.__name__ == "FakeTensorMode":
|
|
789
|
+
return f"{obj}"
|
|
790
|
+
|
|
791
|
+
if verbose:
|
|
792
|
+
print(f"[string_type] END:{type(obj)}")
|
|
793
|
+
raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
|
|
794
|
+
|
|
795
|
+
|
|
796
|
+
def string_signature(sig: Any) -> str:
    """
    Displays the signature of a function, one parameter per line.

    :param sig: a signature, typically the result of :func:`inspect.signature`;
        only ``sig.parameters`` and ``sig.return_annotation`` are read
    :return: multi-line text; every parameter line shows the name,
        the annotation and the default value when they exist, and ends
        with ``|`` followed by the name of the parameter kind
        (``POSITIONAL_OR_KEYWORD``, ``VAR_POSITIONAL``, ...)
    """
    text = [" __call__ ("]
    for name, param in sig.parameters.items():
        # The kind must be compared as is: comparing its ``repr`` (a string)
        # to ``VAR_POSITIONAL`` is never true, so ``*args`` would never be starred.
        kind = param.kind
        t = f"{name}: {param.annotation}" if param.annotation is not inspect._empty else name
        if param.default is not inspect._empty:
            t = f"{t} = {param.default!r}"
        if kind == param.VAR_POSITIONAL:
            t = f"*{t}"
        # Pads the description to 30 characters to align the kind column.
        le = (30 - len(t)) * " "
        # Falls back on the representation if the kind has no name.
        kind_name = getattr(kind, "name", None) or repr(kind)
        text.append(f" {t}{le}|{kind_name}")
    text.append(
        f") -> {sig.return_annotation}" if sig.return_annotation is not inspect._empty else ")"
    )
    return "\n".join(text)
|
|
820
|
+
|
|
821
|
+
|
|
822
|
+
def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
    """
    Renders a call-like string ``name(p1=v1, p2=v2, ...)`` showing only
    the parameters whose value is known and differs from the default value.

    :param f: the object to describe; when *kwargs* is None and *f* has
        an attribute ``__init__``, the signature of ``f.__init__`` is used,
        the values are read from ``f.__dict__`` and the displayed name is
        the class name of *f*; otherwise *f* itself is inspected and
        ``f.__name__`` is displayed
    :param kwargs: values of the parameters
    :return: the string; a parameter with no default value is displayed
        if a value is available for it, a parameter with a default value
        is displayed only if its value is different from the default one,
        members of :class:`enum.IntEnum` are displayed by name
    """

    def _render(param_name, value):
        # Integer enumerations are shown by member name, anything else with repr.
        if isinstance(value, enum.IntEnum):
            return f"{param_name}={value.name}"
        return f"{param_name}={value!r}"

    if hasattr(f, "__init__") and kwargs is None:
        target = f.__init__
        kwargs = f.__dict__
        name = f.__class__.__name__
    else:
        target = f
        name = f.__name__

    if kwargs is None:
        kwargs = {}

    shown = []
    for param_name, param in inspect.signature(target).parameters.items():
        default = param.default
        if default is inspect._empty:
            # No default value: the parameter is displayed as soon as a value is known.
            if param_name in kwargs:
                shown.append(_render(param_name, kwargs[param_name]))
            continue
        value = kwargs.get(param_name, default)
        if default != value:
            shown.append(_render(param_name, value))
    joined = ", ".join(shown)
    return f"{name}({joined})"
|
|
855
|
+
|
|
856
|
+
|
|
857
|
+
def make_hash(obj: Any) -> str:
    """
    Returns a simple hash of ``id(obj)`` made of three uppercase letters.

    :param obj: any object, only ``id(obj)`` is used
    :return: three letters between ``A`` and ``Z``, the digits in base 26
        of ``id(obj) % 26**3``, most significant first
    """
    high, rest = divmod(id(obj) % (26**3), 26**2)
    middle, low = divmod(rest, 26)
    return "".join(chr(65 + digit) for digit in (high, middle, low))
|
|
863
|
+
|
|
864
|
+
|
|
865
|
+
def rename_dynamic_dimensions(
    constraints: Dict[str, Set[str]], original: Set[str], ban_prefix: str = "DYN"
) -> Dict[str, str]:
    """
    Renames dynamic shapes as requested by the user. :func:`torch.export.export` uses
    many names for dynamic dimensions. When building the onnx model,
    some of them are redundant and can be replaced by the name provided by the user.

    :param constraints: exhaustive list of used names and all the values equal to it
    :param original: the names to use if possible
    :param ban_prefix: avoid any rewriting by a constant starting with this prefix
    :return: replacement dictionary, every name in *original* is mapped to itself,
        a name with no equivalent in *original* is left out
    """
    mapping = {name: name for name in original}
    every_name = set(constraints) | original

    pending = set(constraints)
    remaining_passes = len(mapping)
    while pending and remaining_passes > 0:
        remaining_passes -= 1
        for used, equal_names in constraints.items():
            candidates = equal_names & original
            if not candidates:
                continue
            # The smallest user name in alphabetical order is the replacement.
            chosen = min(candidates)
            if ban_prefix and chosen.startswith(ban_prefix):
                continue
            mapping[used] = chosen
            # Equivalent names already mapped keep their current replacement.
            for other in equal_names:
                mapping.setdefault(other, chosen)
        pending = every_name - set(mapping)
    return mapping
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def rename_dynamic_expression(expression: str, replacements: Dict[str, str]):
    """
    Renames variables of an expression.

    The expression is parsed with :mod:`ast`, every identifier found
    in *replacements* is substituted, then the expression is converted
    back into a string.

    :param expression: something like ``s15 + seq_length``
    :param replacements: replacements to make, identifier to new identifier
    :return: new string
    """
    tree = ast.parse(expression)
    for node in ast.walk(tree):
        # Only identifiers are renamed, every other node is left untouched.
        if isinstance(node, ast.Name) and node.id in replacements:
            node.id = replacements[node.id]
    return ast.unparse(tree)
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
def flatten_object(x: Any, drop_keys: bool = False) -> Any:
    """
    Flattens the object.
    It accepts some common classes used in deep learning.

    :param x: any object
    :param drop_keys: when a dictionary is flattened, keeps only the values
        if True, keeps keys and values (key first) if False;
        the order of the dictionary is preserved in both cases
    :return: flattened object: None stays None, a list stays a list,
        a tuple stays a tuple, an object with an attribute ``shape``
        is returned unchanged
    :raises TypeError: if the type of *x* is not supported
    """
    if x is None:
        return x

    if isinstance(x, (list, tuple)):
        flat = []
        for item in x:
            # None, scalars, strings and anything with a shape are leaves.
            is_leaf = (
                item is None or hasattr(item, "shape") or isinstance(item, (int, float, str))
            )
            if is_leaf:
                flat.append(item)
                continue
            flat.extend(flatten_object(item, drop_keys=drop_keys))
        if isinstance(x, tuple):
            return tuple(flat)
        return flat

    if isinstance(x, dict):
        # A dictionary is flattened as the list of its values or of its items.
        content = x.values() if drop_keys else x.items()
        return flatten_object(list(content), drop_keys=drop_keys)

    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
        from .cache_helper import CacheKeyValue

        wrapped = CacheKeyValue(x)
        # Keys and values are interleaved: key, value, key, value, ...
        interleaved = []
        for key_value in zip(wrapped.key_cache, wrapped.value_cache):
            interleaved.extend(key_value)
        return interleaved

    if x.__class__.__name__ == "EncoderDecoderCache":
        both = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
        return tuple(both)

    if x.__class__.__name__ == "MambaCache":
        if not isinstance(x.conv_states, list):
            return (x.conv_states, x.ssm_states)
        states = flatten_object(x.conv_states) + flatten_object(x.ssm_states)
        return tuple(states)

    if hasattr(x, "to_tuple"):
        return flatten_object(x.to_tuple(), drop_keys=drop_keys)

    if hasattr(x, "shape"):
        # A tensor. Nothing to do.
        return x

    raise TypeError(
        f"Unexpected type {type(x)} for x, drop_keys={drop_keys}, "
        f"content is {string_type(x, with_shape=True)}"
    )
|
|
971
|
+
|
|
972
|
+
|
|
973
|
+
def _make_debug_info(msg, level, debug_info, verbose) -> Optional[List[str]]:
    """
    Builds the debugging trail handed over to a nested comparison.

    :param msg: message to append
    :param level: nesting level, the message is prefixed by ``' ' * level``
    :param debug_info: messages collected so far, None or empty if there is none
    :param verbose: verbosity, nothing is collected unless it is greater than 5
    :return: a new list made of *debug_info* followed by the message,
        or None if *verbose* is not greater than 5
    """
    if verbose > 5:
        # A new list is built, the one received is left untouched.
        trail = list(debug_info) if debug_info else []
        trail.append(f"{' ' * level}{msg}")
        return trail
    return None
|
|
977
|
+
|
|
978
|
+
|
|
979
|
+
def max_diff(
|
|
980
|
+
expected: Any,
|
|
981
|
+
got: Any,
|
|
982
|
+
verbose: int = 0,
|
|
983
|
+
level: int = 0,
|
|
984
|
+
flatten: bool = False,
|
|
985
|
+
debug_info: Optional[List[str]] = None,
|
|
986
|
+
begin: int = 0,
|
|
987
|
+
end: int = -1,
|
|
988
|
+
_index: int = 0,
|
|
989
|
+
allow_unique_tensor_with_list_of_one_element: bool = True,
|
|
990
|
+
hist: Optional[Union[bool, List[float]]] = None,
|
|
991
|
+
) -> Dict[str, Union[float, int, Tuple[int, ...]]]:
|
|
992
|
+
"""
|
|
993
|
+
Returns the maximum discrepancy.
|
|
994
|
+
|
|
995
|
+
:param expected: expected values
|
|
996
|
+
:param got: values
|
|
997
|
+
:param verbose: verbosity level
|
|
998
|
+
:param level: for embedded outputs, used for debug purpposes
|
|
999
|
+
:param flatten: flatten outputs
|
|
1000
|
+
:param debug_info: debug information
|
|
1001
|
+
:param begin: first output to considered
|
|
1002
|
+
:param end: last output to considered (-1 for the last one)
|
|
1003
|
+
:param _index: used with begin and end
|
|
1004
|
+
:param allow_unique_tensor_with_list_of_one_element:
|
|
1005
|
+
allow a comparison between a single tensor and a list of one tensor
|
|
1006
|
+
:param hist: compute an histogram of the discrepancies
|
|
1007
|
+
:return: dictionary with many values
|
|
1008
|
+
|
|
1009
|
+
* abs: max absolute error
|
|
1010
|
+
* rel: max relative error
|
|
1011
|
+
* sum: sum of the errors
|
|
1012
|
+
* n: number of outputs values, if there is one
|
|
1013
|
+
output, this number will be the number of elements
|
|
1014
|
+
of this output
|
|
1015
|
+
* dnan: difference in the number of nan
|
|
1016
|
+
|
|
1017
|
+
You may use :func:`string_diff` to display the discrepancies in one string.
|
|
1018
|
+
"""
|
|
1019
|
+
if expected is None and got is None:
|
|
1020
|
+
return dict(abs=0, rel=0, sum=0, n=0, dnan=0)
|
|
1021
|
+
|
|
1022
|
+
_dkws_ = dict(
|
|
1023
|
+
verbose=verbose,
|
|
1024
|
+
level=level + 1,
|
|
1025
|
+
begin=begin,
|
|
1026
|
+
end=end,
|
|
1027
|
+
_index=_index,
|
|
1028
|
+
hist=hist,
|
|
1029
|
+
)
|
|
1030
|
+
_dkws = {**_dkws_, "flatten": flatten}
|
|
1031
|
+
_dkwsf = {**_dkws_, "flatten": False}
|
|
1032
|
+
|
|
1033
|
+
_debug = lambda msg: _make_debug_info(msg, level, debug_info, verbose) # noqa: E731
|
|
1034
|
+
|
|
1035
|
+
if allow_unique_tensor_with_list_of_one_element:
|
|
1036
|
+
if hasattr(expected, "shape") and isinstance(got, (list, tuple)) and len(got) == 1:
|
|
1037
|
+
return max_diff(
|
|
1038
|
+
expected,
|
|
1039
|
+
got[0],
|
|
1040
|
+
verbose=verbose,
|
|
1041
|
+
level=level,
|
|
1042
|
+
flatten=False,
|
|
1043
|
+
debug_info=debug_info,
|
|
1044
|
+
allow_unique_tensor_with_list_of_one_element=False,
|
|
1045
|
+
hist=hist,
|
|
1046
|
+
)
|
|
1047
|
+
return max_diff(
|
|
1048
|
+
expected,
|
|
1049
|
+
got,
|
|
1050
|
+
verbose=verbose,
|
|
1051
|
+
level=level,
|
|
1052
|
+
flatten=flatten,
|
|
1053
|
+
debug_info=debug_info,
|
|
1054
|
+
begin=begin,
|
|
1055
|
+
end=end,
|
|
1056
|
+
_index=_index,
|
|
1057
|
+
allow_unique_tensor_with_list_of_one_element=False,
|
|
1058
|
+
hist=hist,
|
|
1059
|
+
)
|
|
1060
|
+
|
|
1061
|
+
if expected.__class__.__name__ == "CausalLMOutputWithPast":
|
|
1062
|
+
if verbose >= 6:
|
|
1063
|
+
print(
|
|
1064
|
+
f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
|
|
1065
|
+
f"? {string_type(got)}"
|
|
1066
|
+
)
|
|
1067
|
+
if got.__class__.__name__ == "CausalLMOutputWithPast":
|
|
1068
|
+
return max_diff(
|
|
1069
|
+
[expected.logits, *flatten_object(expected.past_key_values)],
|
|
1070
|
+
[got.logits, *flatten_object(got.past_key_values)],
|
|
1071
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1072
|
+
**_dkws,
|
|
1073
|
+
)
|
|
1074
|
+
return max_diff(
|
|
1075
|
+
[expected.logits, *flatten_object(expected.past_key_values)],
|
|
1076
|
+
got,
|
|
1077
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1078
|
+
**_dkws,
|
|
1079
|
+
)
|
|
1080
|
+
|
|
1081
|
+
if hasattr(expected, "to_tuple"):
|
|
1082
|
+
if verbose >= 6:
|
|
1083
|
+
print(f"[max_diff] to_tuple1: {string_type(expected)} ? {string_type(got)}")
|
|
1084
|
+
return max_diff(expected.to_tuple(), got, debug_info=_debug("to_tuple1"), **_dkws)
|
|
1085
|
+
|
|
1086
|
+
if hasattr(got, "to_tuple"):
|
|
1087
|
+
if verbose >= 6:
|
|
1088
|
+
print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
|
|
1089
|
+
return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)
|
|
1090
|
+
|
|
1091
|
+
if isinstance(expected, (tuple, list)):
|
|
1092
|
+
if verbose >= 6:
|
|
1093
|
+
print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
|
|
1094
|
+
if len(expected) == 1 and not isinstance(got, type(expected)):
|
|
1095
|
+
if verbose >= 6:
|
|
1096
|
+
print(f"[max_diff] list,tuple,3: {string_type(expected)} ? {string_type(got)}")
|
|
1097
|
+
return max_diff(expected[0], got, debug_info=_debug("lt2"), **_dkws)
|
|
1098
|
+
if not isinstance(got, (tuple, list)):
|
|
1099
|
+
if verbose >= 6:
|
|
1100
|
+
print(f"[max_diff] list,tuple,4: {string_type(expected)} ? {string_type(got)}")
|
|
1101
|
+
if verbose > 2:
|
|
1102
|
+
print(
|
|
1103
|
+
f"[max_diff] inf because type(expected)={type(expected)}, "
|
|
1104
|
+
f"type(got)={type(got)}, level={level}, _index={_index}"
|
|
1105
|
+
)
|
|
1106
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1107
|
+
|
|
1108
|
+
if len(got) != len(expected):
|
|
1109
|
+
if flatten:
|
|
1110
|
+
if verbose >= 6:
|
|
1111
|
+
print(
|
|
1112
|
+
f"[max_diff] list,tuple,5: {string_type(expected)} "
|
|
1113
|
+
f"? {string_type(got)}"
|
|
1114
|
+
)
|
|
1115
|
+
# Let's flatten.
|
|
1116
|
+
if verbose > 2:
|
|
1117
|
+
print(
|
|
1118
|
+
f"[max_diff] flattening because of length mismatch, "
|
|
1119
|
+
f"expected is\n {string_type(expected)}\n -- and got is\n "
|
|
1120
|
+
f"{string_type(got)}"
|
|
1121
|
+
)
|
|
1122
|
+
flat_a = flatten_object(expected, drop_keys=True)
|
|
1123
|
+
flat_b = flatten_object(got, drop_keys=True)
|
|
1124
|
+
if verbose > 2:
|
|
1125
|
+
print(
|
|
1126
|
+
f"[max_diff] after flattening, "
|
|
1127
|
+
f"expected is\n {string_type(flat_a)}\n -- and got is\n "
|
|
1128
|
+
f"{string_type(flat_b)}"
|
|
1129
|
+
)
|
|
1130
|
+
return max_diff(
|
|
1131
|
+
flat_a,
|
|
1132
|
+
flat_b,
|
|
1133
|
+
debug_info=(
|
|
1134
|
+
[
|
|
1135
|
+
*(debug_info if debug_info else []),
|
|
1136
|
+
(
|
|
1137
|
+
f"{' ' * level}flatten["
|
|
1138
|
+
f"{string_type(expected)},{string_type(got)}]"
|
|
1139
|
+
),
|
|
1140
|
+
]
|
|
1141
|
+
if verbose > 5
|
|
1142
|
+
else None
|
|
1143
|
+
),
|
|
1144
|
+
**_dkwsf,
|
|
1145
|
+
)
|
|
1146
|
+
|
|
1147
|
+
if verbose > 2:
|
|
1148
|
+
import torch
|
|
1149
|
+
|
|
1150
|
+
print(
|
|
1151
|
+
f"[max_diff] (b) inf because len(expected)={len(expected)}, "
|
|
1152
|
+
f"len(got)={len(got)}, level={level}, _index={_index}"
|
|
1153
|
+
)
|
|
1154
|
+
for i, (a, b) in enumerate(zip(expected, got)):
|
|
1155
|
+
if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
|
|
1156
|
+
print(
|
|
1157
|
+
f" i={i} expected {a.dtype}:{a.shape}, "
|
|
1158
|
+
f"has {b.dtype}:{b.shape}, _index={_index}"
|
|
1159
|
+
)
|
|
1160
|
+
else:
|
|
1161
|
+
print(f" i={i} a is {type(a)}, b is {type(b)}")
|
|
1162
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1163
|
+
|
|
1164
|
+
if verbose >= 6:
|
|
1165
|
+
print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
|
|
1166
|
+
am, rm, sm, n, dn, drep = 0, 0, 0.0, 0.0, 0, None
|
|
1167
|
+
for ip, (e, g) in enumerate(zip(expected, got)):
|
|
1168
|
+
d = max_diff(
|
|
1169
|
+
e,
|
|
1170
|
+
g,
|
|
1171
|
+
verbose=verbose,
|
|
1172
|
+
level=level + 1,
|
|
1173
|
+
debug_info=(
|
|
1174
|
+
[
|
|
1175
|
+
*(debug_info if debug_info else []),
|
|
1176
|
+
f"{' ' * level}[{ip}] so far abs {am} - rel {rm}",
|
|
1177
|
+
]
|
|
1178
|
+
if verbose > 5
|
|
1179
|
+
else None
|
|
1180
|
+
),
|
|
1181
|
+
begin=begin,
|
|
1182
|
+
end=end,
|
|
1183
|
+
_index=_index + ip,
|
|
1184
|
+
flatten=flatten,
|
|
1185
|
+
hist=hist,
|
|
1186
|
+
)
|
|
1187
|
+
am = max(am, d["abs"])
|
|
1188
|
+
dn = max(dn, d["dnan"])
|
|
1189
|
+
rm = max(rm, d["rel"])
|
|
1190
|
+
sm += d["sum"] # type: ignore
|
|
1191
|
+
n += d["n"] # type: ignore
|
|
1192
|
+
if "rep" in d:
|
|
1193
|
+
if drep is None:
|
|
1194
|
+
drep = d["rep"].copy()
|
|
1195
|
+
else:
|
|
1196
|
+
for k, v in d["rep"].items():
|
|
1197
|
+
drep[k] += v
|
|
1198
|
+
res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
|
|
1199
|
+
if drep:
|
|
1200
|
+
res["rep"] = drep
|
|
1201
|
+
return res # type: ignore
|
|
1202
|
+
|
|
1203
|
+
if isinstance(expected, dict):
|
|
1204
|
+
if verbose >= 6:
|
|
1205
|
+
print(f"[max_diff] dict: {string_type(expected)} ? {string_type(got)}")
|
|
1206
|
+
assert begin == 0 and end == -1, (
|
|
1207
|
+
f"begin={begin}, end={end} not compatible with dictionaries, "
|
|
1208
|
+
f"keys={sorted(expected)}"
|
|
1209
|
+
)
|
|
1210
|
+
if isinstance(got, dict):
|
|
1211
|
+
if len(expected) != len(got):
|
|
1212
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1213
|
+
if set(expected) != set(got):
|
|
1214
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1215
|
+
keys = sorted(expected)
|
|
1216
|
+
return max_diff(
|
|
1217
|
+
[expected[k] for k in keys],
|
|
1218
|
+
[got[k] for k in keys],
|
|
1219
|
+
debug_info=_debug("dict1"),
|
|
1220
|
+
**_dkws,
|
|
1221
|
+
)
|
|
1222
|
+
|
|
1223
|
+
if not isinstance(got, (tuple, list)):
|
|
1224
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1225
|
+
if len(expected) != len(got):
|
|
1226
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1227
|
+
return max_diff(list(expected.values()), got, debug_info=_debug("dict2"), **_dkws)
|
|
1228
|
+
|
|
1229
|
+
import torch
|
|
1230
|
+
|
|
1231
|
+
if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
|
|
1232
|
+
if isinstance(expected, torch.Tensor):
|
|
1233
|
+
from .torch_helper import to_numpy
|
|
1234
|
+
|
|
1235
|
+
expected = to_numpy(expected)
|
|
1236
|
+
if isinstance(got, torch.Tensor):
|
|
1237
|
+
from .torch_helper import to_numpy
|
|
1238
|
+
|
|
1239
|
+
got = to_numpy(got)
|
|
1240
|
+
if verbose >= 6:
|
|
1241
|
+
print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
|
|
1242
|
+
|
|
1243
|
+
if _index < begin or (end != -1 and _index >= end):
|
|
1244
|
+
# out of boundary
|
|
1245
|
+
return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
|
|
1246
|
+
if isinstance(expected, (int, float)):
|
|
1247
|
+
if isinstance(got, np.ndarray) and len(got.shape) == 0:
|
|
1248
|
+
got = float(got)
|
|
1249
|
+
if isinstance(got, (int, float)):
|
|
1250
|
+
if expected == got:
|
|
1251
|
+
return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
|
|
1252
|
+
return dict(
|
|
1253
|
+
abs=abs(expected - got),
|
|
1254
|
+
rel=abs(expected - got) / (abs(expected) + 1e-5),
|
|
1255
|
+
sum=abs(expected - got),
|
|
1256
|
+
n=1,
|
|
1257
|
+
dnan=0,
|
|
1258
|
+
)
|
|
1259
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1260
|
+
if expected.dtype in (np.complex64, np.complex128):
|
|
1261
|
+
if got.dtype == expected.dtype:
|
|
1262
|
+
got = np.real(got)
|
|
1263
|
+
elif got.dtype not in (np.float32, np.float64):
|
|
1264
|
+
if verbose >= 10:
|
|
1265
|
+
# To understand the value it comes from.
|
|
1266
|
+
if debug_info:
|
|
1267
|
+
print("\n".join(debug_info))
|
|
1268
|
+
print(
|
|
1269
|
+
f"[max_diff-c] expected.dtype={expected.dtype}, "
|
|
1270
|
+
f"got.dtype={got.dtype}"
|
|
1271
|
+
)
|
|
1272
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1273
|
+
expected = np.real(expected)
|
|
1274
|
+
|
|
1275
|
+
if expected.shape != got.shape:
|
|
1276
|
+
if verbose >= 10:
|
|
1277
|
+
# To understand the value it comes from.
|
|
1278
|
+
if debug_info:
|
|
1279
|
+
print("\n".join(debug_info))
|
|
1280
|
+
print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}")
|
|
1281
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1282
|
+
# nan are replace by 1e10, any discrepancies in that order of magnitude
|
|
1283
|
+
# is likely caused by nans
|
|
1284
|
+
exp_cpu = np.nan_to_num(expected.astype(np.float64), nan=1e10)
|
|
1285
|
+
got_cpu = np.nan_to_num(got.astype(np.float64), nan=1e10)
|
|
1286
|
+
diff = np.abs(got_cpu - exp_cpu)
|
|
1287
|
+
ndiff = np.abs(np.isnan(expected).astype(int) - np.isnan(got).astype(int))
|
|
1288
|
+
rdiff = diff / (np.abs(exp_cpu) + 1e-3)
|
|
1289
|
+
if diff.size == 0:
|
|
1290
|
+
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
|
|
1291
|
+
(0, 0, 0, 0, 0)
|
|
1292
|
+
if exp_cpu.size == got_cpu.size
|
|
1293
|
+
else (np.inf, np.inf, np.inf, 0, np.inf)
|
|
1294
|
+
)
|
|
1295
|
+
argm = None
|
|
1296
|
+
else:
|
|
1297
|
+
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
|
|
1298
|
+
float(diff.max()),
|
|
1299
|
+
float(rdiff.max()),
|
|
1300
|
+
float(diff.sum()),
|
|
1301
|
+
float(diff.size),
|
|
1302
|
+
float(ndiff.sum()),
|
|
1303
|
+
)
|
|
1304
|
+
argm = tuple(map(int, np.unravel_index(diff.argmax(), diff.shape)))
|
|
1305
|
+
if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
|
|
1306
|
+
# To understand the value it comes from.
|
|
1307
|
+
if debug_info:
|
|
1308
|
+
print("\n".join(debug_info))
|
|
1309
|
+
print(
|
|
1310
|
+
f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, "
|
|
1311
|
+
f"nan_diff={nan_diff}, dtype={expected.dtype}, "
|
|
1312
|
+
f"shape={expected.shape}, level={level}, _index={_index}"
|
|
1313
|
+
)
|
|
1314
|
+
if abs_diff >= 10:
|
|
1315
|
+
idiff = np.argmax(diff.reshape((-1,)))
|
|
1316
|
+
x = expected.reshape((-1,))[idiff]
|
|
1317
|
+
y = got.reshape((-1,))[idiff]
|
|
1318
|
+
print(
|
|
1319
|
+
f" [max_diff-2] abs diff={abs_diff}, "
|
|
1320
|
+
f"x={x}, y={y}, level={level}, "
|
|
1321
|
+
f"_index={_index}"
|
|
1322
|
+
)
|
|
1323
|
+
print(y)
|
|
1324
|
+
|
|
1325
|
+
if rel_diff >= 10:
|
|
1326
|
+
idiff = np.argmax(rdiff.reshape((-1,)))
|
|
1327
|
+
x = expected.reshape((-1,))[idiff]
|
|
1328
|
+
y = got.reshape((-1,))[idiff]
|
|
1329
|
+
print(
|
|
1330
|
+
f" [max_diff-3] rel diff={rel_diff}, "
|
|
1331
|
+
f"x={x}, y={y}, level={level}, "
|
|
1332
|
+
f"_index={_index}"
|
|
1333
|
+
)
|
|
1334
|
+
|
|
1335
|
+
res: Dict[str, float] = dict( # type: ignore
|
|
1336
|
+
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
|
|
1337
|
+
)
|
|
1338
|
+
if hist:
|
|
1339
|
+
if isinstance(hist, bool):
|
|
1340
|
+
hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
|
|
1341
|
+
ind = np.digitize(diff.reshape((-1,)), hist, right=True)
|
|
1342
|
+
cou = np.bincount(ind, minlength=ind.shape[0] + 1)
|
|
1343
|
+
res["rep"] = dict(
|
|
1344
|
+
zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))])
|
|
1345
|
+
)
|
|
1346
|
+
return res # type: ignore
|
|
1347
|
+
|
|
1348
|
+
if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
|
|
1349
|
+
if verbose >= 6:
|
|
1350
|
+
print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
|
|
1351
|
+
if _index < begin or (end != -1 and _index >= end):
|
|
1352
|
+
# out of boundary
|
|
1353
|
+
return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
|
|
1354
|
+
if expected.dtype in (torch.complex64, torch.complex128):
|
|
1355
|
+
if got.dtype == expected.dtype:
|
|
1356
|
+
got = torch.view_as_real(got)
|
|
1357
|
+
elif got.dtype not in (torch.float32, torch.float64):
|
|
1358
|
+
if verbose >= 10:
|
|
1359
|
+
# To understand the value it comes from.
|
|
1360
|
+
if debug_info:
|
|
1361
|
+
print("\n".join(debug_info))
|
|
1362
|
+
print(
|
|
1363
|
+
f"[max_diff-c] expected.dtype={expected.dtype}, "
|
|
1364
|
+
f"got.dtype={got.dtype}"
|
|
1365
|
+
)
|
|
1366
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1367
|
+
expected = torch.view_as_real(expected)
|
|
1368
|
+
|
|
1369
|
+
if expected.shape != got.shape:
|
|
1370
|
+
if verbose >= 10:
|
|
1371
|
+
# To understand the value it comes from.
|
|
1372
|
+
if debug_info:
|
|
1373
|
+
print("\n".join(debug_info))
|
|
1374
|
+
print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}")
|
|
1375
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1376
|
+
# nan are replace by 1e10, any discrepancies in that order of magnitude
|
|
1377
|
+
# is likely caused by nans
|
|
1378
|
+
exp_cpu = expected.to(torch.float64).nan_to_num(1e10)
|
|
1379
|
+
got_cpu = got.to(torch.float64).nan_to_num(1e10)
|
|
1380
|
+
if got_cpu.device != exp_cpu.device:
|
|
1381
|
+
if torch.device("cuda:0") in {got_cpu.device, exp_cpu.device}:
|
|
1382
|
+
got_cpu = got_cpu.to("cuda:0")
|
|
1383
|
+
exp_cpu = exp_cpu.to("cuda:0")
|
|
1384
|
+
expected = expected.to("cuda:0")
|
|
1385
|
+
got = got.to("cuda:0")
|
|
1386
|
+
else:
|
|
1387
|
+
got_cpu = got_cpu.detach().to("cpu")
|
|
1388
|
+
exp_cpu = exp_cpu.detach().to("cpu")
|
|
1389
|
+
expected = expected.to("cpu")
|
|
1390
|
+
got = got.to("cpu")
|
|
1391
|
+
diff = (got_cpu - exp_cpu).abs()
|
|
1392
|
+
ndiff = (expected.isnan().to(int) - got.isnan().to(int)).abs()
|
|
1393
|
+
rdiff = diff / (exp_cpu.abs() + 1e-3)
|
|
1394
|
+
if diff.numel() > 0:
|
|
1395
|
+
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
|
|
1396
|
+
float(diff.max().detach()),
|
|
1397
|
+
float(rdiff.max().detach()),
|
|
1398
|
+
float(diff.sum().detach()),
|
|
1399
|
+
float(diff.numel()),
|
|
1400
|
+
float(ndiff.sum().detach()),
|
|
1401
|
+
)
|
|
1402
|
+
argm = tuple(map(int, torch.unravel_index(diff.argmax(), diff.shape)))
|
|
1403
|
+
elif got_cpu.numel() == exp_cpu.numel():
|
|
1404
|
+
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (0.0, 0.0, 0.0, 0.0, 0.0)
|
|
1405
|
+
argm = None
|
|
1406
|
+
else:
|
|
1407
|
+
abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
|
|
1408
|
+
np.inf,
|
|
1409
|
+
np.inf,
|
|
1410
|
+
np.inf,
|
|
1411
|
+
np.inf,
|
|
1412
|
+
np.inf,
|
|
1413
|
+
)
|
|
1414
|
+
argm = None
|
|
1415
|
+
|
|
1416
|
+
if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
|
|
1417
|
+
# To understand the value it comes from.
|
|
1418
|
+
if debug_info:
|
|
1419
|
+
print("\n".join(debug_info))
|
|
1420
|
+
print(
|
|
1421
|
+
f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, "
|
|
1422
|
+
f"nan_diff={nan_diff}, dtype={expected.dtype}, "
|
|
1423
|
+
f"shape={expected.shape}, level={level}, _index={_index}"
|
|
1424
|
+
)
|
|
1425
|
+
if abs_diff >= 10:
|
|
1426
|
+
idiff = torch.argmax(diff.reshape((-1,)))
|
|
1427
|
+
x = expected.reshape((-1,))[idiff]
|
|
1428
|
+
y = got.reshape((-1,))[idiff]
|
|
1429
|
+
print(
|
|
1430
|
+
f" [max_diff-2] abs diff={abs_diff}, "
|
|
1431
|
+
f"x={x}, y={y}, level={level}, "
|
|
1432
|
+
f"_index={_index}"
|
|
1433
|
+
)
|
|
1434
|
+
print(y)
|
|
1435
|
+
|
|
1436
|
+
if rel_diff >= 10:
|
|
1437
|
+
idiff = torch.argmax(rdiff.reshape((-1,)))
|
|
1438
|
+
x = expected.reshape((-1,))[idiff]
|
|
1439
|
+
y = got.reshape((-1,))[idiff]
|
|
1440
|
+
print(
|
|
1441
|
+
f" [max_diff-3] rel diff={rel_diff}, "
|
|
1442
|
+
f"x={x}, y={y}, level={level}, "
|
|
1443
|
+
f"_index={_index}"
|
|
1444
|
+
)
|
|
1445
|
+
|
|
1446
|
+
res: Dict[str, float] = dict( # type: ignore
|
|
1447
|
+
abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
|
|
1448
|
+
)
|
|
1449
|
+
if hist:
|
|
1450
|
+
if isinstance(hist, bool):
|
|
1451
|
+
hist = torch.tensor(
|
|
1452
|
+
[0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
|
|
1453
|
+
)
|
|
1454
|
+
hist = hist.to(diff.device)
|
|
1455
|
+
ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
|
|
1456
|
+
cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
|
|
1457
|
+
res["rep"] = dict(
|
|
1458
|
+
zip(
|
|
1459
|
+
[f">{x}" for x in hist],
|
|
1460
|
+
[int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
|
|
1461
|
+
)
|
|
1462
|
+
)
|
|
1463
|
+
return res # type: ignore
|
|
1464
|
+
|
|
1465
|
+
if "SquashedNormal" in expected.__class__.__name__:
|
|
1466
|
+
if verbose >= 6:
|
|
1467
|
+
print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}")
|
|
1468
|
+
values = (
|
|
1469
|
+
expected.mean.detach().to("cpu"),
|
|
1470
|
+
expected.scale.detach().to("cpu"),
|
|
1471
|
+
)
|
|
1472
|
+
return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws)
|
|
1473
|
+
|
|
1474
|
+
if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
1475
|
+
if got.__class__ not in torch.utils._pytree.SUPPORTED_NODES:
|
|
1476
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1477
|
+
if verbose >= 6:
|
|
1478
|
+
print(
|
|
1479
|
+
f"[max_diff*] {expected.__class__.__name__}: "
|
|
1480
|
+
f"{string_type(expected)} ? {string_type(got)}"
|
|
1481
|
+
)
|
|
1482
|
+
expected_args, _spec = torch.utils._pytree.tree_flatten(expected)
|
|
1483
|
+
got_args, _spec = torch.utils._pytree.tree_flatten(got)
|
|
1484
|
+
return max_diff(
|
|
1485
|
+
expected_args, got_args, debug_info=_debug(expected.__class__.__name__), **_dkws
|
|
1486
|
+
)
|
|
1487
|
+
|
|
1488
|
+
# backup function in case pytorch does not know how to serialize.
|
|
1489
|
+
if expected.__class__.__name__ == "DynamicCache":
|
|
1490
|
+
if got.__class__.__name__ == "DynamicCache":
|
|
1491
|
+
from .cache_helper import CacheKeyValue
|
|
1492
|
+
|
|
1493
|
+
if verbose >= 6:
|
|
1494
|
+
print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
|
|
1495
|
+
expected = CacheKeyValue(expected)
|
|
1496
|
+
got = CacheKeyValue(got)
|
|
1497
|
+
return max_diff(
|
|
1498
|
+
[expected.key_cache, expected.value_cache],
|
|
1499
|
+
[got.key_cache, got.value_cache],
|
|
1500
|
+
verbose=verbose,
|
|
1501
|
+
hist=hist,
|
|
1502
|
+
)
|
|
1503
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1504
|
+
return max_diff(
|
|
1505
|
+
[expected.key_cache, expected.value_cache],
|
|
1506
|
+
[got[0], got[1]],
|
|
1507
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1508
|
+
**_dkws,
|
|
1509
|
+
)
|
|
1510
|
+
raise AssertionError(
|
|
1511
|
+
f"DynamicCache not fully implemented with classes "
|
|
1512
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1513
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1514
|
+
f"level={level}"
|
|
1515
|
+
)
|
|
1516
|
+
|
|
1517
|
+
# backup function in case pytorch does not know how to serialize.
|
|
1518
|
+
if expected.__class__.__name__ == "HybridCache":
|
|
1519
|
+
if got.__class__.__name__ == "HybridCache":
|
|
1520
|
+
from .cache_helper import CacheKeyValue
|
|
1521
|
+
|
|
1522
|
+
if verbose >= 6:
|
|
1523
|
+
print(f"[max_diff] HybridCache: {string_type(expected)} ? {string_type(got)}")
|
|
1524
|
+
cae = CacheKeyValue(expected)
|
|
1525
|
+
cag = CacheKeyValue(got)
|
|
1526
|
+
return max_diff(
|
|
1527
|
+
[cae.key_cache, cae.value_cache],
|
|
1528
|
+
[cag.key_cache, cag.value_cache],
|
|
1529
|
+
verbose=verbose,
|
|
1530
|
+
hist=hist,
|
|
1531
|
+
)
|
|
1532
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1533
|
+
from .cache_helper import CacheKeyValue
|
|
1534
|
+
|
|
1535
|
+
cae = CacheKeyValue(expected)
|
|
1536
|
+
return max_diff(
|
|
1537
|
+
[cae.key_cache, cae.value_cache],
|
|
1538
|
+
[got[0], got[1]],
|
|
1539
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1540
|
+
**_dkws,
|
|
1541
|
+
)
|
|
1542
|
+
raise AssertionError(
|
|
1543
|
+
f"HybridCache not fully implemented with classes "
|
|
1544
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1545
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1546
|
+
f"level={level}"
|
|
1547
|
+
)
|
|
1548
|
+
|
|
1549
|
+
if expected.__class__.__name__ == "StaticCache":
|
|
1550
|
+
if got.__class__.__name__ == "StaticCache":
|
|
1551
|
+
from .cache_helper import CacheKeyValue
|
|
1552
|
+
|
|
1553
|
+
cae = CacheKeyValue(expected)
|
|
1554
|
+
cag = CacheKeyValue(got)
|
|
1555
|
+
if verbose >= 6:
|
|
1556
|
+
print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
|
|
1557
|
+
return max_diff(
|
|
1558
|
+
[cae.key_cache, cae.value_cache],
|
|
1559
|
+
[cag.key_cache, cag.value_cache],
|
|
1560
|
+
verbose=verbose,
|
|
1561
|
+
hist=hist,
|
|
1562
|
+
)
|
|
1563
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1564
|
+
from .cache_helper import CacheKeyValue
|
|
1565
|
+
|
|
1566
|
+
cae = CacheKeyValue(expected)
|
|
1567
|
+
return max_diff(
|
|
1568
|
+
[cae.key_cache, cae.value_cache],
|
|
1569
|
+
[got[0], got[1]],
|
|
1570
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1571
|
+
**_dkws,
|
|
1572
|
+
)
|
|
1573
|
+
raise AssertionError(
|
|
1574
|
+
f"StaticCache not fully implemented with classes "
|
|
1575
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1576
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1577
|
+
f"level={level}"
|
|
1578
|
+
)
|
|
1579
|
+
|
|
1580
|
+
if expected.__class__.__name__ == "SlidingWindowCache":
|
|
1581
|
+
if got.__class__.__name__ == "SlidingWindowCache":
|
|
1582
|
+
if verbose >= 6:
|
|
1583
|
+
print(
|
|
1584
|
+
f"[max_diff] SlidingWindowCache: "
|
|
1585
|
+
f"{string_type(expected)} ? {string_type(got)}"
|
|
1586
|
+
)
|
|
1587
|
+
from .cache_helper import CacheKeyValue
|
|
1588
|
+
|
|
1589
|
+
cae = CacheKeyValue(expected)
|
|
1590
|
+
cag = CacheKeyValue(got)
|
|
1591
|
+
return max_diff(
|
|
1592
|
+
[cae.key_cache, cae.value_cache],
|
|
1593
|
+
[cag.key_cache, cag.value_cache],
|
|
1594
|
+
verbose=verbose,
|
|
1595
|
+
hist=hist,
|
|
1596
|
+
)
|
|
1597
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1598
|
+
from .cache_helper import CacheKeyValue
|
|
1599
|
+
|
|
1600
|
+
cae = CacheKeyValue(expected)
|
|
1601
|
+
return max_diff(
|
|
1602
|
+
[cae.key_cache, cae.value_cache],
|
|
1603
|
+
[got[0], got[1]],
|
|
1604
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1605
|
+
**_dkws,
|
|
1606
|
+
)
|
|
1607
|
+
raise AssertionError(
|
|
1608
|
+
f"SlidingWindowCache not fully implemented with classes "
|
|
1609
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1610
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1611
|
+
f"level={level}"
|
|
1612
|
+
)
|
|
1613
|
+
|
|
1614
|
+
if expected.__class__.__name__ == "EncoderDecoderCache":
|
|
1615
|
+
if got.__class__.__name__ == "EncoderDecoderCache":
|
|
1616
|
+
if verbose >= 6:
|
|
1617
|
+
print(
|
|
1618
|
+
f"[max_diff] EncoderDecoderCache: "
|
|
1619
|
+
f"{string_type(expected)} ? {string_type(got)}"
|
|
1620
|
+
)
|
|
1621
|
+
return max_diff(
|
|
1622
|
+
[expected.self_attention_cache, expected.cross_attention_cache],
|
|
1623
|
+
[got.self_attention_cache, got.cross_attention_cache],
|
|
1624
|
+
verbose=verbose,
|
|
1625
|
+
hist=hist,
|
|
1626
|
+
)
|
|
1627
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1628
|
+
return max_diff(
|
|
1629
|
+
[expected.self_attention_cache, expected.cross_attention_cache],
|
|
1630
|
+
[got[0], got[1]],
|
|
1631
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1632
|
+
**_dkws,
|
|
1633
|
+
)
|
|
1634
|
+
raise AssertionError(
|
|
1635
|
+
f"EncoderDecoderCache not fully implemented with classes "
|
|
1636
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1637
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1638
|
+
f"level={level}"
|
|
1639
|
+
)
|
|
1640
|
+
|
|
1641
|
+
if expected.__class__.__name__ in ("transformers.cache_utils.MambaCache", "MambaCache"):
|
|
1642
|
+
if verbose >= 6:
|
|
1643
|
+
print(f"[max_diff] MambaCache: {string_type(expected)} ? {string_type(got)}")
|
|
1644
|
+
if got.__class__.__name__ != expected.__class__.__name__:
|
|
1645
|
+
# This case happens with onnx where the outputs are flattened.
|
|
1646
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1647
|
+
atts = []
|
|
1648
|
+
for k in ["conv_states", "ssm_states"]:
|
|
1649
|
+
if hasattr(expected, k) and not hasattr(got, k):
|
|
1650
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1651
|
+
atts.append(k)
|
|
1652
|
+
|
|
1653
|
+
return max_diff(
|
|
1654
|
+
[getattr(expected, k) for k in atts],
|
|
1655
|
+
[getattr(got, k) for k in atts],
|
|
1656
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1657
|
+
**_dkws,
|
|
1658
|
+
)
|
|
1659
|
+
|
|
1660
|
+
if expected.__class__.__name__ == "KeyValuesWrapper":
|
|
1661
|
+
if verbose >= 6:
|
|
1662
|
+
print(f"[max_diff] KeyValuesWrapper: {string_type(expected)} ? {string_type(got)}")
|
|
1663
|
+
if got.__class__.__name__ != expected.__class__.__name__:
|
|
1664
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1665
|
+
if got.cache_type != expected.cache_type:
|
|
1666
|
+
return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
|
|
1667
|
+
return max_diff(
|
|
1668
|
+
list(expected),
|
|
1669
|
+
list(got),
|
|
1670
|
+
debug_info=_debug(expected.__class__.__name__),
|
|
1671
|
+
**_dkws,
|
|
1672
|
+
)
|
|
1673
|
+
|
|
1674
|
+
raise AssertionError(
|
|
1675
|
+
f"Not implemented with implemented with expected="
|
|
1676
|
+
f"{string_type(expected)}, got={string_type(got)},\n"
|
|
1677
|
+
f"level={level}"
|
|
1678
|
+
)
|
|
1679
|
+
|
|
1680
|
+
|
|
1681
|
+
def string_diff(diff: Dict[str, Any]) -> str:
    """
    Formats the discrepancies returned by :func:`max_diff` as a single string.

    :param diff: dictionary holding at least ``abs`` and ``rel``;
        ``n`` is read when both of them are non zero,
        ``dnan``, ``rep`` and ``argm`` are optional
    :return: one line such as ``abs=..., rel=..., n=...``, followed by
        the histogram counts (``rep``) and the argmax position (``argm``)
        when they are available
    """
    # dict(abs=, rel=, sum=, n=n_diff, dnan=)

    # Optional trailing part: histogram counts first, then the argmax position.
    tail = ""
    if "rep" in diff:
        # Only the buckets holding at least one element are listed.
        buckets = "-".join(
            f"#{count}{label}" for label, count in diff["rep"].items() if count > 0
        )
        tail = f"/{buckets}"
    if "argm" in diff:
        argm = diff["argm"]
        if isinstance(argm, tuple):
            position = ",".join(str(i) for i in argm)
        else:
            position = str(argm)
        tail += f",amax={position}"

    pieces = [f"abs={diff['abs']}", f"rel={diff['rel']}"]
    # ``n`` is only reported when neither discrepancy is null.
    if not (diff["abs"] == 0 or diff["rel"] == 0):
        pieces.append(f"n={diff['n']}")
    # ``dnan`` is only reported when it is present and not null.
    if diff.get("dnan", None):
        pieces.append(f"dnan={diff['dnan']}")
    return ", ".join(pieces) + tail
|