onnx-diagnostic 0.7.16__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +78 -22
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +2 -1
- onnx_diagnostic/export/shape_helper.py +47 -70
- onnx_diagnostic/ext_test_case.py +11 -0
- onnx_diagnostic/helpers/cache_helper.py +38 -7
- onnx_diagnostic/helpers/fake_tensor_helper.py +224 -104
- onnx_diagnostic/helpers/helper.py +27 -33
- onnx_diagnostic/helpers/log_helper.py +109 -5
- onnx_diagnostic/helpers/memory_peak.py +2 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +1 -1
- onnx_diagnostic/helpers/model_builder_helper.py +132 -2
- onnx_diagnostic/helpers/onnx_helper.py +1 -1
- onnx_diagnostic/helpers/ort_session.py +4 -0
- onnx_diagnostic/helpers/rt_helper.py +393 -43
- onnx_diagnostic/helpers/torch_helper.py +20 -1
- onnx_diagnostic/tasks/__init__.py +7 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +2 -8
- onnx_diagnostic/tasks/feature_extraction.py +2 -8
- onnx_diagnostic/tasks/image_text_to_text.py +10 -8
- onnx_diagnostic/tasks/summarization.py +2 -8
- onnx_diagnostic/tasks/text2text_generation.py +3 -8
- onnx_diagnostic/tasks/text_generation.py +86 -65
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +718 -438
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +1 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +9 -36
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +12 -6
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +162 -24
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +140 -104
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +1 -4
- onnx_diagnostic/torch_models/validate.py +626 -228
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/RECORD +38 -36
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.7.16.dist-info → onnx_diagnostic-0.8.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/fake_tensor_helper.py

@@ -1,82 +1,236 @@
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Set, Tuple


-[… old line 4: content not recoverable from the rendered diff …]
+class FakeTensorContext:
+    """Stores information used to reuse the same dimension for the same dimension names."""

+    def __init__(self, fake_mode: Optional["FakeTensorMode"] = None):  # noqa: F821
+        if fake_mode is None:
+            from torch.fx.experimental.symbolic_shapes import ShapeEnv
+            from torch._subclasses.fake_tensor import FakeTensorMode

-[… old lines 7-12: content not recoverable …]
+            shape_env = ShapeEnv()
+            self.fake_mode = FakeTensorMode(shape_env=shape_env)
+        else:
+            self.fake_mode = fake_mode
+        self._candidates = self._first_primes()
+        self._unique_: Set[str] = set()
+        self._mapping_int: Dict[int, str] = {}
+        self._mapping_str: Dict[str, int] = {}

+    @classmethod
+    def _first_primes(cls, n=1000):
+        sieve = [True] * (n + 1)
+        sieve[0:2] = [False, False]

-def fake_reshape(
-    true_tensor: "torch.Tensor",  # noqa: F821
-    sh: Dict[int, Any],  # noqa: F821
-    fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
-    fake_mode: Optional["FakeTensorMode"] = None,  # noqa: F821
-) -> "FakeTensor":  # noqa: F821
-    """
-    Changes the shape of a true tensor to make it dynamic.
+        for i in range(2, int(n**0.5) + 1):
+            if sieve[i]:
+                # eliminate the multiples of i
+                sieve[i * i : n + 1 : i] = [False] * len(range(i * i, n + 1, i))

-[… old lines 24-38: content not recoverable …]
+        return [i for i, prime in enumerate(sieve) if prime and i >= 13]
+
+    def _unique(self) -> int:
+        i = 0
+        c = self._candidates[i]
+        while c in self._unique_ or c in self._mapping_int:
+            i += 1
+            assert i < len(
+                self._candidates
+            ), f"Too many unique dimensions to generate, requested: {len(self._unique_)}"
+            c = self._candidates[i]
+        self._unique_.add(c)
+        return c
+
+    def from_tensor(self, x, static_shapes=False) -> "FakeTensor":  # noqa: F821
+        """
+        Returns a fake tensor.
+        ``pytorch`` returns the same name for the same dimension.
+        """
+        fake = self.fake_mode.from_tensor(x, static_shapes=static_shapes)
+        for i, s in zip(x.shape, fake.shape):
+            assert i not in self._mapping_int or self._mapping_int[i] == s, (
+                f"Inconsistency between {x.shape} and {fake.shape}, "
+                f"mapping has {self._mapping_int[i]} and s={s}"
+            )
+            self._mapping_int[i] = s
+        return fake
+
+    def fake_reshape(
+        self,
+        true_tensor: "torch.Tensor",  # noqa: F821
+        sh: Dict[int, Any],  # noqa: F821
+        fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
+    ) -> "FakeTensor":  # noqa: F821
+        """
+        Changes the shape of a true tensor to make it dynamic.
+
+        :param true_tensor: true tensor
+        :param sh: dynamic shape
+        :param fake_tensor: fake tensor, if None, make a fake one
+        :return: fake tensor
+        """
+        import torch
+
+        # deal with 0/1
+        for i in sh:
+            if true_tensor.shape[i] <= 1:
+                expanded_shape = list(true_tensor.shape)
+                expanded_shape[i] = self._unique()
+                true_tensor = torch.empty(
+                    tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+                )
+
+        # deal with equivalent dimension
+        new_shape = list(true_tensor.shape)
+        mapping = {}
+        for i, s in sh.items():
+            d = true_tensor.shape[i]
+            if d not in mapping:
+                mapping[d] = s
+            elif mapping[d] != s:
+                d = self._unique()
+                mapping[d] = s
+            new_shape[i] = d
+        true_tensor = torch.empty(
+            tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+        )
+
+        # now switch to FakeTensor
+        fake_tensor = self.from_tensor(true_tensor, static_shapes=False)
+        new_shape = list(true_tensor.shape)
+        for i in sh:
+            new_shape[i] = fake_tensor.shape[i]
+
+        reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
+            axis=tuple(sorted(sh)), keepdim=True
+        )
+        return reduced_tensor.expand(*new_shape)
+
+    def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
+        """See :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`."""
+        if x is None:
+            return None
+        if isinstance(x, (list, tuple)):
+            return x.__class__([self.make_fake(i) for i in x])
+        if isinstance(x, dict):
+            return {k: self.make_fake(v) for k, v in x.items()}
+        if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+            assert hasattr(x, "layers"), (
+                f"Use a more recent version of transformers (>=4.55), "
+                f"'layers' not found in class {type(x)}"
+            )
+            for layer in x.layers:
+                assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                    f"Use a more recent version of transformers (>=4.55), 'layers' "
+                    f"not found in class {type(layer)} ({dir(layer)})"
+                )
+                layer.keys = self.make_fake(layer.keys)
+                layer.values = self.make_fake(layer.values)
+            return x
+        if x.__class__.__name__ == "EncoderDecoderCache":
+            self.make_fake(x.self_attention_cache)
+            self.make_fake(x.cross_attention_cache)
+            return x
+        if hasattr(x, "shape"):
+            return self.from_tensor(x, static_shapes=False)
+        from . import string_type
+
+        raise TypeError(
+            f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+        )
+
+    def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
+        """
+        See
+        :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
+        """
+        if x is None:
+            return None, None
+        if isinstance(x, (list, tuple)):
+            return x.__class__(
+                [
+                    self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
+                    for i, ds in zip(x, dynamic_shapes)
+                ]
+            )
+        if isinstance(x, dict):
+            return {
+                k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
+                for k, v in x.items()
+            }
+        if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+            assert hasattr(x, "layers"), (
+                f"Use a more recent version of transformers (>=4.55), "
+                f"'layers' not found in class {type(x)}"
+            )
+            assert isinstance(dynamic_shapes, list) and (
+                not dynamic_shapes or not isinstance(dynamic_shapes[0], list)
+            ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
+            for il, layer in enumerate(x.layers):
+                assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                    f"Use a more recent version of transformers (>=4.55), 'layers' "
+                    f"not found in class {type(layer)} ({dir(layer)})"
+                )
+                layer.keys = self.make_fake_with_dynamic_dimensions(
+                    layer.keys, dynamic_shapes=dynamic_shapes[il * 2]
+                )
+                layer.values = self.make_fake_with_dynamic_dimensions(
+                    layer.values, dynamic_shapes=dynamic_shapes[il * 2 + 1]
+                )
+            return x
+        if x.__class__.__name__ == "EncoderDecoderCache":
+            self.make_fake_with_dynamic_dimensions(
+                x.self_attention_cache, dynamic_shapes=dynamic_shapes[0]
             )
+            self.make_fake_with_dynamic_dimensions(
+                x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
+            )
+            return x
+        if hasattr(x, "shape"):
+            assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
+                f"dynamic_shapes must be a dictionary at this stage but "
+                f"dynamic_shapes={dynamic_shapes}"
+            )
+            # We need to overwrite the values.
+            new_shape = []
+            for idim, dim in enumerate(x.shape):
+                if dynamic_shapes is not None and idim in dynamic_shapes:
+                    s = dynamic_shapes[idim]
+                    assert isinstance(s, str), (
+                        f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
+                        f"at index {idim}"
+                    )
+                    if s in self._mapping_str:
+                        dim = self._mapping_str[s]
+                    else:
+                        i = self._unique()
+                        self._mapping_str[s] = i
+                        dim = i
+                assert isinstance(dim, int), (
+                    f"Unexpected type {type(dim)}, dynamic_shapes={dynamic_shapes} "
+                    f"at index {idim}, dim={dim}"
+                )
+                new_shape.append(dim)
+            if tuple(new_shape) != x.shape:
+                import torch

-[… old lines 41-51: content not recoverable …]
-    true_tensor = torch.empty(
-        tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
-    )
-
-    # now switch to FakeTensor
-    if fake_mode is None:
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from torch._subclasses.fake_tensor import FakeTensorMode
-
-        shape_env = ShapeEnv()
-        fake_mode = FakeTensorMode(shape_env=shape_env)
-    if fake_tensor is None:
-        fake_tensor = fake_mode.from_tensor(true_tensor, static_shapes=False)
-    assert fake_mode is not None, "fake_mode must be provided"
-
-    new_shape = list(true_tensor.shape)
-    for i in sh:
-        new_shape[i] = fake_tensor.shape[i]
-
-    reduced_tensor = fake_mode.from_tensor(true_tensor, static_shapes=True).sum(
-        axis=tuple(sorted(sh)), keepdim=True
-    )
-    return reduced_tensor.expand(*new_shape)
+                x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)
+
+            t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+            assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
+            assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
+            return t
+        from ..helpers import string_type
+
+        raise TypeError(
+            f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+        )


 def make_fake(
-    x: Any, fake_mode: Optional["FakeTensorMode"] = None
-) -> Tuple[Optional["FakeTensor"], Optional["FakeTensorMode"]]:  # noqa: F821
+    x: Any, context: Optional[FakeTensorContext] = None
+) -> Tuple[Optional["FakeTensor"], Optional[FakeTensorContext]]:  # noqa: F821
     """
     Replaces all tensors by fake tensors.
     This modification happens inplace for caches.
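The hunk above folds the module's free functions into a single `FakeTensorContext` class that owns the `FakeTensorMode` and hands out placeholder dimensions drawn from the primes >= 13 returned by `_first_primes`, presumably so a generated dimension never coincides with a real size or with another generated one. A minimal usage sketch, assuming the 0.8.1 wheel and `torch` are installed (shapes and values are arbitrary):

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import FakeTensorContext

    ctx = FakeTensorContext()
    # Tensors sharing a dimension size receive the same symbolic dimension
    # because the context records the int -> symbolic mapping in from_tensor.
    fake_a = ctx.make_fake(torch.rand(2, 33))
    fake_b = ctx.make_fake({"mask": torch.rand(2, 17)})
    print(fake_a.shape, fake_b["mask"].shape)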
@@ -114,40 +268,6 @@ def make_fake
     """
     if x is None:
         return None, None
-    if fake_mode is None:
-        from torch.fx.experimental.symbolic_shapes import ShapeEnv
-        from torch._subclasses.fake_tensor import FakeTensorMode
-
-        shape_env = ShapeEnv()
-        fake_mode = FakeTensorMode(shape_env=shape_env)
-
-    if isinstance(x, (list, tuple)):
-        return x.__class__([make_fake(i, fake_mode=fake_mode)[0] for i in x]), fake_mode
-    if isinstance(x, dict):
-        return {k: make_fake(v, fake_mode=fake_mode)[0] for k, v in x.items()}, fake_mode
-
-    if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
-        assert hasattr(x, "layers"), (
-            f"Use a more recent version of transformers (>=4.55), "
-            f"'layers' not found in class {type(x)}"
-        )
-        for layer in x.layers:
-            assert hasattr(layer, "keys") and hasattr(layer, "values"), (
-                f"Use a more recent version of transformers (>=4.55), 'layers' "
-                f"not found in class {type(layer)} ({dir(layer)})"
-            )
-            layer.keys = make_fake(layer.keys, fake_mode=fake_mode)[0]
-            layer.values = make_fake(layer.values, fake_mode=fake_mode)[0]
-        return x, fake_mode
-    if x.__class__.__name__ == "EncoderDecoderCache":
-        make_fake(x.self_attention_cache, fake_mode=fake_mode)
-        make_fake(x.cross_attention_cache, fake_mode=fake_mode)
-        return x, fake_mode
-    if hasattr(x, "shape"):
-        t = fake_mode.from_tensor(x, static_shapes=False)
-        return t, fake_mode
-    from . import string_type
-
-    raise TypeError(
-        f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
-    )
+    if context is None:
+        context = FakeTensorContext()
+    return context.make_fake(x), context
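Together the two hunks reduce the public `make_fake` entry point to a thin wrapper around the context: it now returns a `FakeTensorContext` instead of a raw `FakeTensorMode`, and feeding that context back in keeps dimension naming consistent across calls. A sketch of the new calling convention (input shapes are illustrative):

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import make_fake

    inputs = {"input_ids": torch.randint(0, 100, (2, 7))}
    fake_inputs, context = make_fake(inputs)
    # Passing the returned context makes the second batch share the same
    # symbolic dimensions as the first.
    fake_more, _ = make_fake(torch.rand(2, 7), context)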
onnx_diagnostic/helpers/helper.py

@@ -1,6 +1,7 @@
 import ast
 import enum
 import inspect
+import itertools
 from dataclasses import is_dataclass, fields
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
 import numpy as np
@@ -948,8 +949,8 @@ def flatten_object
         from .cache_helper import CacheKeyValue

         kc = CacheKeyValue(x)
-[… old lines 951-952: content not recoverable from the rendered diff …]
+        return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
+
     if x.__class__.__name__ == "EncoderDecoderCache":
         res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
         return tuple(res)
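The `DynamicCache` branch of `flatten_object` now interleaves each layer's key and value tensors in a single expression, which is why `itertools` is imported above. The idiom, demonstrated with strings standing in for tensors:

    import itertools

    key_cache = ["k0", "k1"]
    value_cache = ["v0", "v1"]
    # zip pairs layer i's key with its value; chain flattens the pairs,
    # yielding the interleaved order key0, value0, key1, value1.
    flat = list(itertools.chain.from_iterable(zip(key_cache, value_cache)))
    assert flat == ["k0", "v0", "k1", "v1"]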
@@ -1015,6 +1016,8 @@ def max_diff

     You may use :func:`string_diff` to display the discrepancies in one string.
     """
+    if verbose >= 10:
+        print(f"[max_diff] {type(expected)} ? {type(got)}")
     if expected is None and got is None:
         return dict(abs=0, rel=0, sum=0, n=0, dnan=0)

@@ -1056,6 +1059,27 @@ def max_diff
         allow_unique_tensor_with_list_of_one_element=False,
         hist=hist,
     )
+
+    if expected.__class__.__name__ == "CausalLMOutputWithPast":
+        if verbose >= 6:
+            print(
+                f"[max_diff] CausalLMOutputWithPast: {string_type(expected, with_shape=True)} "
+                f"? {string_type(got, with_shape=True)}"
+            )
+        if got.__class__.__name__ == "CausalLMOutputWithPast":
+            return max_diff(
+                [expected.logits, *flatten_object(expected.past_key_values)],
+                [got.logits, *flatten_object(got.past_key_values)],
+                debug_info=_debug(expected.__class__.__name__),
+                **_dkws,
+            )
+        return max_diff(
+            [expected.logits, *flatten_object(expected.past_key_values)],
+            got,
+            debug_info=_debug(expected.__class__.__name__),
+            **_dkws,
+        )
+
     if hasattr(expected, "to_tuple"):
         if verbose >= 6:
             print(f"[max_diff] to_tuple1: {string_type(expected)} ? {string_type(got)}")
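The added branch lets `max_diff` compare a `transformers` `CausalLMOutputWithPast` either with another structured output or with the flat list of tensors a runtime returns: logits first, then the interleaved key/value cache produced by `flatten_object`. A sketch of the flat-list comparison it reduces to (tensor values and the 1e-5 perturbation are made up):

    import torch
    from onnx_diagnostic.helpers.helper import max_diff

    expected = [torch.ones(2, 3), torch.zeros(2, 3)]    # logits, then cache
    got = [torch.ones(2, 3), torch.zeros(2, 3) + 1e-5]  # runtime outputs
    # Returns a dict with keys abs, rel, sum, n, dnan, as used in this module.
    print(max_diff(expected, got))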
@@ -1066,36 +1090,6 @@ def max_diff
             print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
         return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)

-    if isinstance(got, (list, tuple)):
-        if len(got) != 1:
-            if verbose >= 6:
-                print(
-                    f"[max_diff] list,tuple,2: {string_type(expected)} "
-                    f"? {string_type(got)}"
-                )
-            if verbose > 2:
-                import torch
-
-                print(
-                    f"[max_diff] (a) inf because len(expected)={len(expected)}!=1, "
-                    f"len(got)={len(got)}, level={level}, _index={_index}"
-                )
-                for i, (a, b) in enumerate(zip(expected, got)):
-                    if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
-                        print(
-                            f" i={i} expected {a.dtype}:{a.shape}, "
-                            f"has {b.dtype}:{b.shape}, _index={_index}"
-                        )
-                    else:
-                        print(
-                            f" i={i} a is {type(a)}, "
-                            f"b is {type(b)}, _index={_index}"
-                        )
-            return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
-        if verbose >= 6:
-            print(f"[max_diff] list,tuple,1: {string_type(expected)} ? {string_type(got)}")
-        return max_diff(expected, got[0], debug_info=_debug("lt1"), **_dkws)
-
     if isinstance(expected, (tuple, list)):
         if verbose >= 6:
             print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
@@ -1484,7 +1478,7 @@ def max_diff
         return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
     if verbose >= 6:
         print(
-            f"[max_diff] {expected.__class__.__name__}: "
+            f"[max_diff*] {expected.__class__.__name__}: "
             f"{string_type(expected)} ? {string_type(got)}"
         )
     expected_args, _spec = torch.utils._pytree.tree_flatten(expected)