onnx-diagnostic 0.4.1__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -3,5 +3,5 @@ Investigates onnx models.
3
3
  Functions, classes to dig into a model when this one is right, slow, wrong...
4
4
  """
5
5
 
6
- __version__ = "0.4.1"
6
+ __version__ = "0.4.2"
7
7
  __author__ = "Xavier Dupré"
@@ -155,6 +155,7 @@ def make_mamba_cache(
155
155
  key_value_pairs: List[Tuple[torch.Tensor, torch.Tensor]],
156
156
  ) -> transformers.cache_utils.MambaCache:
157
157
  "Creates a :class:`transformers.cache_utils.MambaCache`."
158
+ dtype = key_value_pairs[0][0].dtype
158
159
 
159
160
  class _config:
160
161
  def __init__(self):
@@ -162,14 +163,23 @@ def make_mamba_cache(
162
163
  self.conv_kernel = key_value_pairs[0][0].shape[-1]
163
164
  self.state_size = key_value_pairs[0][1].shape[-1]
164
165
  self.num_hidden_layers = len(key_value_pairs)
165
- self.dtype = key_value_pairs[0][0].dtype
166
+ self.dtype = dtype
166
167
 
167
168
  cache = transformers.cache_utils.MambaCache(
168
169
  _config(),
169
170
  max_batch_size=key_value_pairs[0][0].shape[0],
170
171
  device=key_value_pairs[0][0].device,
172
+ dtype=dtype,
171
173
  )
172
174
  for i in range(len(key_value_pairs)):
175
+ assert cache.conv_states[i].dtype == dtype, (
176
+ f"Type mismatch for cache.conv_states[{i}].dtype="
177
+ f"{cache.conv_states[i].dtype} != {dtype}"
178
+ )
179
+ assert cache.ssm_states[i].dtype == dtype, (
180
+ f"Type mismatch for cache.ssm_states[{i}].dtype="
181
+ f"{cache.ssm_states[i].dtype} != {dtype}"
182
+ )
173
183
  assert cache.conv_states[i].shape == key_value_pairs[i][0].shape, (
174
184
  f"Shape mismatch, expected {cache.conv_states[i].shape}, "
175
185
  f"got {key_value_pairs[i][0].shape}"
@@ -1404,6 +1404,28 @@ def max_diff(
1404
1404
  f"level={level}"
1405
1405
  )
1406
1406
 
1407
+ if expected.__class__.__name__ == "SlidingWindowCache":
1408
+ if got.__class__.__name__ == "SlidingWindowCache":
1409
+ if verbose >= 6:
1410
+ print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
1411
+ return max_diff(
1412
+ [expected.key_cache, expected.value_cache],
1413
+ [got.key_cache, got.value_cache],
1414
+ verbose=verbose,
1415
+ )
1416
+ if isinstance(got, tuple) and len(got) == 2:
1417
+ return max_diff(
1418
+ [expected.key_cache, expected.value_cache],
1419
+ [got[0], got[1]],
1420
+ verbose=verbose,
1421
+ )
1422
+ raise AssertionError(
1423
+ f"SlidingWindowCache not fully implemented with classes "
1424
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
1425
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
1426
+ f"level={level}"
1427
+ )
1428
+
1407
1429
  if expected.__class__.__name__ == "EncoderDecoderCache":
1408
1430
  if got.__class__.__name__ == "EncoderDecoderCache":
1409
1431
  if verbose >= 6:
@@ -8,6 +8,7 @@ from .cache_helper import (
8
8
  make_dynamic_cache,
9
9
  make_encoder_decoder_cache,
10
10
  make_sliding_window_cache,
11
+ make_mamba_cache,
11
12
  )
12
13
 
13
14
 
@@ -346,6 +347,8 @@ def torch_deepcopy(value: Any) -> Any:
346
347
  """
347
348
  Makes a deepcopy.
348
349
  """
350
+ if value is None:
351
+ return None
349
352
  if isinstance(value, (int, float, str)):
350
353
  return value
351
354
  if isinstance(value, tuple):
@@ -376,6 +379,9 @@ def torch_deepcopy(value: Any) -> Any:
376
379
  torch_deepcopy(value.self_attention_cache),
377
380
  torch_deepcopy(value.cross_attention_cache),
378
381
  )
382
+ if value.__class__.__name__ == "MambaCache":
383
+ return make_mamba_cache(list(zip(value.conv_states, value.ssm_states)))
384
+
379
385
  if value.__class__ in torch.utils._pytree.SUPPORTED_NODES:
380
386
  args, spec = torch.utils._pytree.tree_flatten(value)
381
387
  new_args = torch_deepcopy(args)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: onnx-diagnostic
3
- Version: 0.4.1
3
+ Version: 0.4.2
4
4
  Summary: Investigate ONNX models
5
5
  Home-page: https://github.com/sdpython/onnx-diagnostic
6
6
  Author: Xavier Dupré
@@ -1,4 +1,4 @@
1
- onnx_diagnostic/__init__.py,sha256=bh5VOJtQu31ncsPxImpF46_gmHuLSnb8C4SsLWLh1_k,164
1
+ onnx_diagnostic/__init__.py,sha256=wVSctxhjG5jNBmX9oZ_oUVWt2QU4P1s8bsgeKXDj0YI,164
2
2
  onnx_diagnostic/__main__.py,sha256=YmyV_Aq_ianDlHyKLHMa6h8YK3ZmFPpLVHLKjM91aCk,79
3
3
  onnx_diagnostic/_command_lines_parser.py,sha256=kOECT1BccZc38vmVc3jF3xvXGDpcocvLuUGoPkzte08,14753
4
4
  onnx_diagnostic/doc.py,sha256=MTuT7Kxyvn7KEy84liQeFeqhugJrUQhjjpx21F72Uxw,926
@@ -9,14 +9,14 @@ onnx_diagnostic/export/validate.py,sha256=FBI7Sercyio_Ozw542z5Jh7u6mzKFyDxSRL2hS
9
9
  onnx_diagnostic/helpers/__init__.py,sha256=21OeajRAtfTaQ-IRa7UqQemJq5lh_7iYMTuZUcNcEqU,67
10
10
  onnx_diagnostic/helpers/args_helper.py,sha256=7pTrw1A1wuNvLdXJdpda5spPI140FylwSmxxZTGu_4E,4389
11
11
  onnx_diagnostic/helpers/bench_run.py,sha256=CGA6VMJZMH2gDhVueT9ypNm4PMcjGrrGFYp08nhWj9k,16539
12
- onnx_diagnostic/helpers/cache_helper.py,sha256=5AaZA3F0Tta99QGr8x0mpFJnqGMZauknjtlb1voetEo,7949
12
+ onnx_diagnostic/helpers/cache_helper.py,sha256=soKjyIXa7EQgALd9PAUGIKYzXlJGoLevYiQDsxoqkQ4,8349
13
13
  onnx_diagnostic/helpers/config_helper.py,sha256=zmxKA54xYTNOQyf2MP0FnHsplDmYYtBRLIJlLBrdjUU,3039
14
- onnx_diagnostic/helpers/helper.py,sha256=8VdJqJfMXU2hB0IKEkhWtERJTOBur8GTxFMpKv1OxPM,53397
14
+ onnx_diagnostic/helpers/helper.py,sha256=HpSws6esUq86jDQKj7jUM4g-Xdd3FqsvcXBX6QDHDlg,54349
15
15
  onnx_diagnostic/helpers/memory_peak.py,sha256=lgQm5DvjxfSw9nEPBYyZKMlaGVe-dcP5jqo7SxBOIS0,6380
16
16
  onnx_diagnostic/helpers/onnx_helper.py,sha256=4jX_BxbJ29jW-LaCZib9lJXNci4u8iBM7qitx4KrDQU,29336
17
17
  onnx_diagnostic/helpers/ort_session.py,sha256=hlM6h0Bn0b5adNVK-QBwyt0SPFzo2Z7AZ5c0Zmdy928,26774
18
18
  onnx_diagnostic/helpers/rt_helper.py,sha256=zFIA3HTvogok6EUjWfTAXUd-BfP_Sh4pwEmLcdqPbl0,1774
19
- onnx_diagnostic/helpers/torch_test_helper.py,sha256=EvhliihBUIff2CEGIXB2X15IfOn5svBg4H1HgUidGB4,13376
19
+ onnx_diagnostic/helpers/torch_test_helper.py,sha256=qaXk1sqzx2kBvgAHYJ5Xzn3pdF1mn-N41tJBq59OejY,13570
20
20
  onnx_diagnostic/reference/__init__.py,sha256=0Al5kins8LlBICAsszEZ59thMwmaARBO6fMwtYpKOOQ,98
21
21
  onnx_diagnostic/reference/evaluator.py,sha256=gzfcgzc2oC99ynFJ4FF_JPlm-52_OKtpSrLBq7S-QR0,8804
22
22
  onnx_diagnostic/reference/ort_evaluator.py,sha256=aYMydWPh9oEAHf1exCu8XFfdpfXvDRY_3d4sKJe5ruo,16369
@@ -81,8 +81,8 @@ onnx_diagnostic/torch_models/untrained/llm_phi2.py,sha256=ynBTDHJHCk44NjLT_t6OiF
81
81
  onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py,sha256=7N3fGvT_4Mn4NbIo0Qk57c6DMc3OXGWyvj_P41rjwSY,3513
82
82
  onnx_diagnostic/torch_onnx/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
83
83
  onnx_diagnostic/torch_onnx/sbs.py,sha256=HEGDHhV9pfXxpBQrpOWPNWGMsNfOebWewyAazi9poV8,16872
84
- onnx_diagnostic-0.4.1.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
85
- onnx_diagnostic-0.4.1.dist-info/METADATA,sha256=XSoBJwePsa6TtlJj_ok66Xf6RIqByU5e1oihbqMQt_s,5511
86
- onnx_diagnostic-0.4.1.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
87
- onnx_diagnostic-0.4.1.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
88
- onnx_diagnostic-0.4.1.dist-info/RECORD,,
84
+ onnx_diagnostic-0.4.2.dist-info/licenses/LICENSE.txt,sha256=Vv6TXglX6Rc0d-f8aREhayhT-6PMQXEyOmI2NKlUCMc,1045
85
+ onnx_diagnostic-0.4.2.dist-info/METADATA,sha256=xkEoLhmlKpx91wddmUKXOpmDwdsEuW2y_c7bbk7cAVw,5511
86
+ onnx_diagnostic-0.4.2.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
87
+ onnx_diagnostic-0.4.2.dist-info/top_level.txt,sha256=KwNkXewmcobM3ZT1DJLVWH6ebJzA5qKg7cWqKfpGNT4,16
88
+ onnx_diagnostic-0.4.2.dist-info/RECORD,,