onnx-diagnostic 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +11 -1
- onnx_diagnostic/helpers/helper.py +22 -0
- onnx_diagnostic/helpers/torch_test_helper.py +6 -0
- {onnx_diagnostic-0.4.1.dist-info → onnx_diagnostic-0.4.2.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.4.1.dist-info → onnx_diagnostic-0.4.2.dist-info}/RECORD +9 -9
- {onnx_diagnostic-0.4.1.dist-info → onnx_diagnostic-0.4.2.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.4.1.dist-info → onnx_diagnostic-0.4.2.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.4.1.dist-info → onnx_diagnostic-0.4.2.dist-info}/top_level.txt +0 -0
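onnx_diagnostic/__init__.py
CHANGED
(hunk for this file is not present in this extract)
onnx_diagnostic/helpers/cache_helper.py
CHANGED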
|
@@ -155,6 +155,7 @@ def make_mamba_cache(
|
|
|
155
155
|
key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
|
|
156
156
|
) -> transformers.cache_utils.MambaCache:
|
|
157
157
|
"Creates a :class:`transformers.cache_utils.MambaCache`."
|
|
158
|
+
dtype = key_value_pairs[0][0].dtype
|
|
158
159
|
|
|
159
160
|
class _config:
|
|
160
161
|
def __init__(self):
|
|
@@ -162,14 +163,23 @@ def make_mamba_cache(
|
|
|
162
163
|
self.conv_kernel = key_value_pairs[0][0].shape[-1]
|
|
163
164
|
self.state_size = key_value_pairs[0][1].shape[-1]
|
|
164
165
|
self.num_hidden_layers = len(key_value_pairs)
|
|
165
|
-
self.dtype =
|
|
166
|
+
self.dtype = dtype
|
|
166
167
|
|
|
167
168
|
cache = transformers.cache_utils.MambaCache(
|
|
168
169
|
_config(),
|
|
169
170
|
max_batch_size=key_value_pairs[0][0].shape[0],
|
|
170
171
|
device=key_value_pairs[0][0].device,
|
|
172
|
+
dtype=dtype,
|
|
171
173
|
)
|
|
172
174
|
for i in range(len(key_value_pairs)):
|
|
175
|
+
assert cache.conv_states[i].dtype == dtype, (
|
|
176
|
+
f"Type mismatch for cache.conv_states[{i}].dtype="
|
|
177
|
+
f"{cache.conv_states[i].dtype} != {dtype}"
|
|
178
|
+
)
|
|
179
|
+
assert cache.ssm_states[i].dtype == dtype, (
|
|
180
|
+
f"Type mismatch for cache.ssm_states[{i}].dtype="
|
|
181
|
+
f"{cache.ssm_states[i].dtype} != {dtype}"
|
|
182
|
+
)
|
|
173
183
|
assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
|
|
174
184
|
f"Shape mismatch, expected {cache.conv_states[i].shape}, "
|
|
175
185
|
f"got {key_value_pairs[i][0].shape}"
|
|
onnx_diagnostic/helpers/helper.py
CHANGED
|
@@ -1404,6 +1404,28 @@ def max_diff(
|
|
|
1404
1404
|
f"level={level}"
|
|
1405
1405
|
)
|
|
1406
1406
|
|
|
1407
|
+
if expected.__class__.__name__ == "SlidingWindowCache":
|
|
1408
|
+
if got.__class__.__name__ == "SlidingWindowCache":
|
|
1409
|
+
if verbose >= 6:
|
|
1410
|
+
print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
|
|
1411
|
+
return max_diff(
|
|
1412
|
+
[expected.key_cache, expected.value_cache],
|
|
1413
|
+
[got.key_cache, got.value_cache],
|
|
1414
|
+
verbose=verbose,
|
|
1415
|
+
)
|
|
1416
|
+
if isinstance(got, tuple) and len(got) == 2:
|
|
1417
|
+
return max_diff(
|
|
1418
|
+
[expected.key_cache, expected.value_cache],
|
|
1419
|
+
[got[0], got[1]],
|
|
1420
|
+
verbose=verbose,
|
|
1421
|
+
)
|
|
1422
|
+
raise AssertionError(
|
|
1423
|
+
f"SlidingWindowCache not fully implemented with classes "
|
|
1424
|
+
f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
|
|
1425
|
+
f"and expected={string_type(expected)}, got={string_type(got)},\n"
|
|
1426
|
+
f"level={level}"
|
|
1427
|
+
)
|
|
1428
|
+
|
|
1407
1429
|
if expected.__class__.__name__ == "EncoderDecoderCache":
|
|
1408
1430
|
if got.__class__.__name__ == "EncoderDecoderCache":
|
|
1409
1431
|
if verbose >= 6:
|
|
onnx_diagnostic/helpers/torch_test_helper.py
CHANGED
|
@@ -8,6 +8,7 @@ from .cache_helper import (
|
|
|
8
8
|
make_dynamic_cache,
|
|
9
9
|
make_encoder_decoder_cache,
|
|
10
10
|
make_sliding_window_cache,
|
|
11
|
+
make_mamba_cache,
|
|
11
12
|
)
|
|
12
13
|
|
|
13
14
|
|
|
@@ -346,6 +347,8 @@ def torch_deepcopy(value: Any) -> Any:
|
|
|
346
347
|
"""
|
|
347
348
|
Makes a deepcopy.
|
|
348
349
|
"""
|
|
350
|
+
if value is None:
|
|
351
|
+
return None
|
|
349
352
|
if isinstance(value, (int, float, str)):
|
|
350
353
|
return value
|
|
351
354
|
if isinstance(value, tuple):
|
|
@@ -376,6 +379,9 @@ def torch_deepcopy(value: Any) -> Any:
|
|
|
376
379
|
torch_deepcopy(value.self_attention_cache),
|
|
377
380
|
torch_deepcopy(value.cross_attention_cache),
|
|
378
381
|
)
|
|
382
|
+
if value.__class__.__name__ == "MambaCache":
|
|
383
|
+
return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
|
|
384
|
+
|
|
379
385
|
if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
|
|
380
386
|
args, spec = torch.utils._pytree.tree_flatten(value)
|
|
381
387
|
new_args = torch_deepcopy(args)
|
|
{onnx_diagnostic-0.4.1.dist-info → onnx_diagnostic-0.4.2.dist-info}/RECORD
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
onnx_diagnostic/__init__.py,sha256=
|
|
1
|
+
onnx_diagnostic/__init__.py,sha256=wVSctxhjG5jNBmX9oZ_oUVWt2QU4P1s8bsgeKXDj0YI,164
|
|
2
2
|
onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
|
|
3
3
|
onnx_diagnostic/_command_lines_parser.py,sha256=kOECT1BccZc38vmVc3jF3xvXGDpcocvLuUGoPkzte08,14753
|
|
4
4
|
onnx_diagnostic/doc.py,sha256=MTuT7Kxyvn7KEy84liQeFeqhugJrUQhjjpx21F72Uxw,926
|
|
@@ -9,14 +9,14 @@ onnx_diagnostic/export/validate.py,sha256=FBI7Sercyio_Ozw542z5Jh7u6mzKFyDxSRL2hS
|
|
|
9
9
|
onnx_diagnostic/helpers/__init__.py,sha256=21OeajRAtfTaQ-IRa7UqQemJq5lh_7iYMTuZUcNcEqU,67
|
|
10
10
|
onnx_diagnostic/helpers/args_helper.py,sha256=7pTrw1A1wuNvLdXJdpda5spPI140FylwSmxxZTGu_4E,4389
|
|
11
11
|
onnx_diagnostic/helpers/bench_run.py,sha256=CGA6VMJZMH2gDhVueT9ypNm4PMcjGrrGFYp08nhWj9k,16539
|
|
12
|
-
onnx_diagnostic/helpers/cache_helper.py,sha256=
|
|
12
|
+
onnx_diagnostic/helpers/cache_helper.py,sha256=soKjyIXa7EQgALd9PAUGIKYzXlJGoLevYiQDsxoqkQ4,8349
|
|
13
13
|
onnx_diagnostic/helpers/config_helper.py,sha256=zmxKA54xYTNOQyf2MP0FnHsplDmYYtBRLIJlLBrdjUU,3039
|
|
14
|
-
onnx_diagnostic/helpers/helper.py,sha256=
|
|
14
|
+
onnx_diagnostic/helpers/helper.py,sha256=HpSws6esUq86jDQKj7jUM4g-Xdd3FqsvcXBX6QDHDlg,54349
|
|
15
15
|
onnx_diagnostic/helpers/memory_peak.py,sha256=lgQm5DvjxfSw9nEPBYyZKMlaGVe-dcP5jqo7SxBOIS0,6380
|
|
16
16
|
onnx_diagnostic/helpers/onnx_helper.py,sha256=4jX_BxbJ29jW-LaCZib9lJXNci4u8iBM7qitx4KrDQU,29336
|
|
17
17
|
onnx_diagnostic/helpers/ort_session.py,sha256=hlM6h0Bn0b5adNVK-QBwyt0SPFzo2Z7AZ5c0Zmdy928,26774
|
|
18
18
|
onnx_diagnostic/helpers/rt_helper.py,sha256=zFIA3HTvogok6EUjWfTAXUd-BfP_Sh4pwEmLcdqPbl0,1774
|
|
19
|
-
onnx_diagnostic/helpers/torch_test_helper.py,sha256=
|
|
19
|
+
onnx_diagnostic/helpers/torch_test_helper.py,sha256=qaXk1sqzx2kBvgAHYJ5Xzn3pdF1mn-N41tJBq59OejY,13570
|
|
20
20
|
onnx_diagnostic/reference/__init__.py,sha256=0Al5kins8LlBICAsszEZ59thMwmaARBO6fMwtYpKOOQ,98
|
|
21
21
|
onnx_diagnostic/reference/evaluator.py,sha256=gzfcgzc2oC99ynFJ4FF_JPlm-52_OKtpSrLBq7S-QR0,8804
|
|
22
22
|
onnx_diagnostic/reference/ort_evaluator.py,sha256=aYMydWPh9oEAHf1exCu8XFfdpfXvDRY_3d4sKJe5ruo,16369
|
|
@@ -81,8 +81,8 @@ onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiF
|
|
|
81
81
|
onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=7N3fGvT_4Mn4NbIo0Qk57c6DMc3OXGWyvj_P41rjwSY,3513
|
|
82
82
|
onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
83
83
|
onnx_diagnostic/torch_onnx/sbs.py,sha256=HEGDHhV9pfXxpBQrpOWPNWGMsNfOebWewyAazi9poV8,16872
|
|
84
|
-
onnx_diagnostic-0.4.
|
|
85
|
-
onnx_diagnostic-0.4.
|
|
86
|
-
onnx_diagnostic-0.4.
|
|
87
|
-
onnx_diagnostic-0.4.
|
|
88
|
-
onnx_diagnostic-0.4.
|
|
84
|
+
onnx_diagnostic-0.4.2.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
|
|
85
|
+
onnx_diagnostic-0.4.2.dist-info/METADATA,sha256=xkEoLhmlKpx91wddmUKXOpmDwdsEuW2y_c7bbk7cAVw,5511
|
|
86
|
+
onnx_diagnostic-0.4.2.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
|
87
|
+
onnx_diagnostic-0.4.2.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
|
|
88
|
+
onnx_diagnostic-0.4.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|