onnx-diagnostic 0.8.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1707 @@
1
+ import ast
2
+ import enum
3
+ import inspect
4
+ import itertools
5
+ from dataclasses import is_dataclass, fields
6
+ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
7
+ import numpy as np
8
+
9
+
10
+ def size_type(dtype: Any) -> int:
11
+ """Returns the element size for an element type."""
12
+ if isinstance(dtype, int):
13
+ from onnx import TensorProto
14
+
15
+ # It is a TensorProto.DATATYPE
16
+ if dtype in {
17
+ TensorProto.DOUBLE,
18
+ TensorProto.INT64,
19
+ TensorProto.UINT64,
20
+ TensorProto.COMPLEX64,
21
+ }:
22
+ return 8
23
+ if dtype in {TensorProto.FLOAT, TensorProto.INT32, TensorProto.UINT32}:
24
+ return 4
25
+ if dtype in {
26
+ TensorProto.FLOAT16,
27
+ TensorProto.BFLOAT16,
28
+ TensorProto.INT16,
29
+ TensorProto.UINT16,
30
+ }:
31
+ return 2
32
+ if dtype in {
33
+ TensorProto.INT8,
34
+ TensorProto.UINT8,
35
+ TensorProto.BOOL,
36
+ TensorProto.FLOAT8E4M3FN,
37
+ TensorProto.FLOAT8E4M3FNUZ,
38
+ TensorProto.FLOAT8E5M2,
39
+ TensorProto.FLOAT8E5M2FNUZ,
40
+ getattr(TensorProto, "FLOAT8E8M0", None),
41
+ }:
42
+ return 1
43
+ if dtype in {TensorProto.COMPLEX128}:
44
+ return 16
45
+ from .onnx_helper import onnx_dtype_name
46
+
47
+ raise AssertionError(
48
+ f"Unable to return the element size for type {onnx_dtype_name(dtype)}"
49
+ )
50
+
51
+ if dtype == np.float64 or dtype == np.int64:
52
+ return 8
53
+ if dtype == np.float32:
54
+ return 4
55
+ if dtype == np.float16 or dtype == np.int16:
56
+ return 2
57
+ if dtype == np.int32:
58
+ return 4
59
+ if dtype == np.int8:
60
+ return 1
61
+ if hasattr(np, "uint64"):
62
+ # it fails on mac
63
+ if dtype == np.uint64:
64
+ return 8
65
+ if dtype == np.uint32:
66
+ return 4
67
+ if dtype == np.uint16:
68
+ return 2
69
+ if dtype == np.uint8:
70
+ return 1
71
+
72
+ import torch
73
+
74
+ if dtype in {torch.float64, torch.int64}:
75
+ return 8
76
+ if dtype in {torch.float32, torch.int32}:
77
+ return 4
78
+ if dtype in {torch.float16, torch.int16, torch.bfloat16}:
79
+ return 2
80
+ if dtype in {torch.int8, torch.uint8, torch.bool}:
81
+ return 1
82
+ if hasattr(torch, "uint64"):
83
+ # it fails on mac
84
+ if dtype in {torch.uint64}:
85
+ return 8
86
+ if dtype in {torch.uint32}:
87
+ return 4
88
+ if dtype in {torch.uint16}:
89
+ return 2
90
+ import ml_dtypes
91
+
92
+ if dtype == ml_dtypes.bfloat16:
93
+ return 2
94
+ raise AssertionError(f"Unexpected dtype={dtype}")
95
+
96
+
97
+ def string_type(
98
+ obj: Any,
99
+ with_shape: bool = False,
100
+ with_min_max: bool = False,
101
+ with_device: bool = False,
102
+ ignore: bool = False,
103
+ limit: int = 20,
104
+ verbose: int = 0,
105
+ ) -> str:
106
+ """
107
+ Displays the types of an object as a string.
108
+
109
+ :param obj: any
110
+ :param with_shape: displays shapes as well
111
+ :param with_min_max: displays information about the values
112
+ :param with_device: display the device
113
+ :param ignore: if True, returns the class name instead of failing on unknown types
114
+ :param limit: maximum number of elements displayed for sequences before truncating
+ :param verbose: verbosity (prints which branch was used to build the string)
115
+ :return: str
116
+
117
+ The function displays something like the following for a tensor.
118
+
119
+ .. code-block:: text
120
+
121
+ T7s2x7[0.5:6:A3.56]
122
+ ^^^+-^^----+------^
123
+ || | |
124
+ || | +-- information about the content of a tensor or array
125
+ || | [min,max:A<average>]
126
+ || |
127
+ || +-- a shape
128
+ ||
129
+ |+-- integer following the code defined by onnx.TensorProto,
130
+ | 7 is onnx.TensorProto.INT64 (see onnx_dtype_name)
131
+ |
132
+ +-- A,T,F
133
+ A is an array from numpy
134
+ T is a Tensor from pytorch
135
+ F is a FakeTensor from pytorch
136
+
137
+ The element types for a tensor are displayed as integer to shorten the message.
138
+ The semantic is defined by :class:`onnx.TensorProto` and can be obtained
139
+ by :func:`onnx_diagnostic.helpers.onnx_helper.onnx_dtype_name`.
140
+
141
+ .. runpython::
142
+ :showcode:
143
+
144
+ from onnx_diagnostic.helpers import string_type
145
+
146
+ print(string_type((1, ["r", 6.6])))
147
+
148
+ With pytorch:
149
+
150
+ .. runpython::
151
+ :showcode:
152
+
153
+ import torch
154
+ from onnx_diagnostic.helpers import string_type
155
+
156
+ inputs = (
157
+ torch.rand((3, 4), dtype=torch.float16),
158
+ [
159
+ torch.rand((5, 6), dtype=torch.float16),
160
+ torch.rand((5, 6, 7), dtype=torch.float16),
161
+ ]
162
+ )
163
+
164
+ # with shapes
165
+ print(string_type(inputs, with_shape=True))
166
+
167
+ # with min max
168
+ print(string_type(inputs, with_shape=True, with_min_max=True))
169
+ """
170
+ if obj is None:
171
+ if verbose:
172
+ print(f"[string_type] A:{type(obj)}")
173
+ return "None"
174
+
175
+ # tuple
176
+ if isinstance(obj, tuple):
177
+ if len(obj) == 1:
178
+ s = string_type(
179
+ obj[0],
180
+ with_shape=with_shape,
181
+ with_min_max=with_min_max,
182
+ with_device=with_device,
183
+ ignore=ignore,
184
+ limit=limit,
185
+ verbose=verbose,
186
+ )
187
+ if verbose:
188
+ print(f"[string_type] C:{type(obj)}")
189
+ return f"({s},)"
190
+ if len(obj) < limit:
191
+ js = ",".join(
192
+ string_type(
193
+ o,
194
+ with_shape=with_shape,
195
+ with_min_max=with_min_max,
196
+ with_device=with_device,
197
+ ignore=ignore,
198
+ limit=limit,
199
+ verbose=verbose,
200
+ )
201
+ for o in obj
202
+ )
203
+ if verbose:
204
+ print(f"[string_type] D:{type(obj)}")
205
+ return f"({js})"
206
+ tt = string_type(
207
+ obj[0],
208
+ with_shape=with_shape,
209
+ with_min_max=with_min_max,
210
+ with_device=with_device,
211
+ ignore=ignore,
212
+ limit=limit,
213
+ verbose=verbose,
214
+ )
215
+ if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
216
+ mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
217
+ if verbose:
218
+ print(f"[string_type] E:{type(obj)}")
219
+ return f"#{len(obj)}({tt},...)[{mini},{maxi}:A[{avg}]]"
220
+ if verbose:
221
+ print(f"[string_type] F:{type(obj)}")
222
+ return f"#{len(obj)}({tt},...)"
223
+ # list
224
+ if isinstance(obj, list):
225
+ if len(obj) < limit:
226
+ js = ",".join(
227
+ string_type(
228
+ o,
229
+ with_shape=with_shape,
230
+ with_min_max=with_min_max,
231
+ with_device=with_device,
232
+ ignore=ignore,
233
+ limit=limit,
234
+ verbose=verbose,
235
+ )
236
+ for o in obj
237
+ )
238
+ if verbose:
239
+ print(f"[string_type] G:{type(obj)}")
240
+ return f"#{len(obj)}[{js}]"
241
+ tt = string_type(
242
+ obj[0],
243
+ with_shape=with_shape,
244
+ with_min_max=with_min_max,
245
+ with_device=with_device,
246
+ ignore=ignore,
247
+ limit=limit,
248
+ verbose=verbose,
249
+ )
250
+ if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
251
+ mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
252
+ if verbose:
253
+ print(f"[string_type] H:{type(obj)}")
254
+ return f"#{len(obj)}[{tt},...][{mini},{maxi}:{avg}]"
255
+ if verbose:
256
+ print(f"[string_type] I:{type(obj)}")
257
+ return f"#{len(obj)}[{tt},...]"
258
+ # set
259
+ if isinstance(obj, set):
260
+ if len(obj) < 10:
261
+ js = ",".join(
262
+ string_type(
263
+ o,
264
+ with_shape=with_shape,
265
+ with_min_max=with_min_max,
266
+ with_device=with_device,
267
+ ignore=ignore,
268
+ limit=limit,
269
+ verbose=verbose,
270
+ )
271
+ for o in obj
272
+ )
273
+ if verbose:
274
+ print(f"[string_type] J:{type(obj)}")
275
+ return f"{{{js}}}"
276
+ if with_min_max and all(isinstance(_, (int, float, bool)) for _ in obj):
277
+ mini, maxi, avg = min(obj), max(obj), sum(float(_) for _ in obj) / len(obj)
278
+ if verbose:
279
+ print(f"[string_type] K:{type(obj)}")
280
+ return f"{{...}}#{len(obj)}[{mini},{maxi}:A{avg}]"
281
+ if verbose:
282
+ print(f"[string_type] L:{type(obj)}")
283
+ return f"{{...}}#{len(obj)}" if with_shape else "{...}"
284
+ # dict
285
+ if isinstance(obj, dict) and type(obj) is dict:
286
+ if len(obj) == 0:
287
+ if verbose:
288
+ print(f"[string_type] M:{type(obj)}")
289
+ return "{}"
290
+
291
+ import torch
292
+
293
+ if all(isinstance(k, int) for k in obj) and all(
294
+ isinstance(
295
+ v,
296
+ (
297
+ str,
298
+ torch.export.dynamic_shapes._Dim,
299
+ torch.export.dynamic_shapes._DerivedDim,
300
+ torch.export.dynamic_shapes._DimHint,
301
+ ),
302
+ )
303
+ for v in obj.values()
304
+ ):
305
+ # This is dynamic shapes
306
+ rows = []
307
+ for k, v in obj.items():
308
+ if isinstance(v, str):
309
+ rows.append(f"{k}:DYN({v})")
310
+ else:
311
+ rows.append(f"{k}:{string_type(v, verbose=verbose)}")
312
+ if verbose:
313
+ print(f"[string_type] DS0:{type(obj)}")
314
+ return f"{{{','.join(rows)}}}"
315
+
316
+ kws = dict(
317
+ with_shape=with_shape,
318
+ with_min_max=with_min_max,
319
+ with_device=with_device,
320
+ ignore=ignore,
321
+ limit=limit,
322
+ verbose=verbose,
323
+ )
324
+ s = ",".join(f"{kv[0]}:{string_type(kv[1],**kws)}" for kv in obj.items())
325
+ if all(isinstance(k, int) for k in obj):
326
+ if verbose:
327
+ print(f"[string_type] N:{type(obj)}")
328
+ return f"{{{s}}}"
329
+ if verbose:
330
+ print(f"[string_type] O:{type(obj)}")
331
+ return f"dict({s})"
332
+ # array
333
+ if isinstance(obj, np.ndarray):
334
+ from .onnx_helper import np_dtype_to_tensor_dtype
335
+
336
+ if with_min_max:
337
+ s = string_type(obj, with_shape=with_shape)
338
+ if len(obj.shape) == 0:
339
+ return f"{s}={obj}"
340
+ if obj.size == 0:
341
+ return f"{s}[empty]"
342
+ n_nan = np.isnan(obj.reshape((-1,))).astype(int).sum()
343
+ if n_nan > 0:
344
+ nob = obj.ravel()
345
+ nob = nob[~np.isnan(nob)]
346
+ if nob.size == 0:
347
+ if verbose:
348
+ print(f"[string_type] A1:{type(obj)}")
349
+ return f"{s}[N{n_nan}nans]"
350
+ if verbose:
351
+ print(f"[string_type] A2:{type(obj)}")
352
+ return f"{s}[{nob.min()},{nob.max()}:A{nob.astype(float).mean()}N{n_nan}nans]"
353
+ if verbose:
354
+ print(f"[string_type] A3:{type(obj)}")
355
+ return f"{s}[{obj.min()},{obj.max()}:A{obj.astype(float).mean()}]"
356
+ i = np_dtype_to_tensor_dtype(obj.dtype)
357
+ if not with_shape:
358
+ if verbose:
359
+ print(f"[string_type] A4:{type(obj)}")
360
+ return f"A{i}r{len(obj.shape)}"
361
+ if verbose:
362
+ print(f"[string_type] A5:{type(obj)}")
363
+ return f"A{i}s{'x'.join(map(str, obj.shape))}"
364
+
365
+ import torch
366
+
367
+ # Dim, SymInt
368
+ if isinstance(obj, torch.export.dynamic_shapes._DerivedDim):
369
+ if verbose:
370
+ print(f"[string_type] Y1:{type(obj)}")
371
+ return "DerivedDim"
372
+ if isinstance(obj, torch.export.dynamic_shapes._Dim):
373
+ if verbose:
374
+ print(f"[string_type] Y2:{type(obj)}")
375
+ return f"Dim({obj.__name__})"
376
+ if isinstance(obj, torch.SymInt):
377
+ if verbose:
378
+ print(f"[string_type] Y3:{type(obj)}")
379
+ return "SymInt"
380
+ if isinstance(obj, torch.SymFloat):
381
+ if verbose:
382
+ print(f"[string_type] Y4:{type(obj)}")
383
+ return "SymFloat"
384
+
385
+ if isinstance(obj, torch.export.dynamic_shapes._DimHint):
386
+ cl = (
387
+ torch.export.dynamic_shapes._DimHintType
388
+ if hasattr(torch.export.dynamic_shapes, "_DimHintType")
389
+ else torch.export.Dim
390
+ )
391
+ if obj in (torch.export.Dim.DYNAMIC, cl.DYNAMIC):
392
+ if verbose:
393
+ print(f"[string_type] Y8:{type(obj)}")
394
+ return "DYNAMIC"
395
+ if obj in (torch.export.Dim.AUTO, cl.AUTO):
396
+ if verbose:
397
+ print(f"[string_type] Y9:{type(obj)}")
398
+ return "AUTO"
399
+ if verbose:
400
+ print(f"[string_type] Y7:{type(obj)}")
401
+ return str(obj).replace("DimHint(DYNAMIC)", "DYNAMIC").replace("DimHint(AUTO)", "AUTO")
402
+
403
+ if isinstance(obj, bool):
404
+ if with_min_max:
405
+ if verbose:
406
+ print(f"[string_type] W1:{type(obj)}")
407
+ return f"bool={obj}"
408
+ if verbose:
409
+ print(f"[string_type] W2:{type(obj)}")
410
+ return "bool"
411
+ if isinstance(obj, int):
412
+ if with_min_max:
413
+ if verbose:
414
+ print(f"[string_type] W3:{type(obj)}")
415
+ return f"int={obj}"
416
+ if verbose:
417
+ print(f"[string_type] W4:{type(obj)}")
418
+ return "int"
419
+ if isinstance(obj, float):
420
+ if with_min_max:
421
+ if verbose:
422
+ print(f"[string_type] W6:{type(obj)}")
423
+ return f"float={obj}"
424
+ if verbose:
425
+ print(f"[string_type] W8:{type(obj)}")
426
+ return "float"
427
+ if isinstance(obj, str):
428
+ if verbose:
429
+ print(f"[string_type] W9:{type(obj)}")
430
+ return "str"
431
+ if isinstance(obj, slice):
432
+ if verbose:
433
+ print(f"[string_type] W10:{type(obj)}")
434
+ return "slice"
435
+
436
+ if is_dataclass(obj):
437
+ # That includes torch.export.Dim.AUTO, torch.export.Dim.DYNAMIC so they need to be
438
+ # handled before that.
439
+ values = {f.name: getattr(obj, f.name, None) for f in fields(obj)}
440
+ values = {k: v for k, v in values.items() if v is not None}
441
+ s = string_type(
442
+ values,
443
+ with_shape=with_shape,
444
+ with_min_max=with_min_max,
445
+ with_device=with_device,
446
+ ignore=ignore,
447
+ limit=limit,
448
+ verbose=verbose,
449
+ )
450
+ if verbose:
451
+ print(f"[string_type] B:{type(obj)}")
452
+ return f"{obj.__class__.__name__}{s[4:]}"
453
+
454
+ # Tensors
455
+ if isinstance(obj, torch._subclasses.fake_tensor.FakeTensor):
456
+ from .torch_helper import torch_dtype_to_onnx_dtype
457
+
458
+ i = torch_dtype_to_onnx_dtype(obj.dtype)
459
+ prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
460
+ if not with_shape:
461
+ if verbose:
462
+ print(f"[string_type] F1:{type(obj)}")
463
+ return f"{prefix}F{i}r{len(obj.shape)}"
464
+ if verbose:
465
+ print(f"[string_type] F2:{type(obj)}")
466
+ return f"{prefix}F{i}s{'x'.join(map(str, obj.shape))}"
467
+
468
+ if isinstance(obj, torch.Tensor):
469
+ from .torch_helper import torch_dtype_to_onnx_dtype
470
+
471
+ if with_min_max:
472
+ s = string_type(obj, with_shape=with_shape, with_device=with_device)
473
+ if len(obj.shape) == 0:
474
+ if verbose:
475
+ print(f"[string_type] T1:{type(obj)}")
476
+ return f"{s}={obj}"
477
+ if obj.numel() == 0:
478
+ if verbose:
479
+ print(f"[string_type] T2:{type(obj)}")
480
+ return f"{s}[empty]"
481
+ n_nan = obj.reshape((-1,)).isnan().to(int).sum()
482
+ if n_nan > 0:
483
+ nob = obj.reshape((-1,))
484
+ nob = nob[~nob.isnan()]
485
+ if obj.dtype in {torch.complex64, torch.complex128}:
486
+ if verbose:
487
+ print(f"[string_type] T3:{type(obj)}")
488
+ return (
489
+ f"{s}[{nob.abs().min()},{nob.abs().max():A{nob.mean()}N{n_nan}nans}]"
490
+ )
491
+ if verbose:
492
+ print(f"[string_type] T5:{type(obj)}")
493
+ return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}N{n_nan}nans]"
494
+ if obj.dtype in {torch.complex64, torch.complex128}:
495
+ if verbose:
496
+ print(f"[string_type] T6:{type(obj)}")
497
+ return f"{s}[{obj.abs().min()},{obj.abs().max()}:A{obj.abs().mean()}]"
498
+ if verbose:
499
+ print(f"[string_type] T7:{type(obj)}")
500
+ return f"{s}[{obj.min()},{obj.max()}:A{obj.to(float).mean()}]"
501
+ i = torch_dtype_to_onnx_dtype(obj.dtype)
502
+ prefix = ("G" if obj.get_device() >= 0 else "C") if with_device else ""
503
+ if not with_shape:
504
+ if verbose:
505
+ print(f"[string_type] T8:{type(obj)}")
506
+ return f"{prefix}T{i}r{len(obj.shape)}"
507
+ if verbose:
508
+ print(f"[string_type] T9:{type(obj)}")
509
+ return f"{prefix}T{i}s{'x'.join(map(str, obj.shape))}"
510
+
511
+ if obj.__class__.__name__ == "OrtValue":
512
+ if not obj.has_value():
513
+ if verbose:
514
+ print(f"[string_type] V1:{type(obj)}")
515
+ return "OV(<novalue>)"
516
+ if not obj.is_tensor():
517
+ if verbose:
518
+ print(f"[string_type] V2:{type(obj)}")
519
+ return "OV(NOTENSOR)"
520
+ if with_min_max:
521
+ from .torch_helper import to_numpy
522
+
523
+ try:
524
+ t = to_numpy(obj)
525
+ except Exception:
526
+ # unable to convert to numpy (bfloat16, ...)
527
+ if verbose:
528
+ print(f"[string_type] V3:{type(obj)}")
529
+ return "OV(NO-NUMPY:FIXIT)"
530
+ if verbose:
531
+ print(f"[string_type] V4:{type(obj)}")
532
+ return f"OV({string_type(t, with_shape=with_shape, with_min_max=with_min_max)})"
533
+ dt = obj.element_type()
534
+ shape = obj.shape()
535
+ if with_shape:
536
+ if verbose:
537
+ print(f"[string_type] V5:{type(obj)}")
538
+ return f"OV{dt}s{'x'.join(map(str, shape))}"
539
+ if verbose:
540
+ print(f"[string_type] V6:{type(obj)}")
541
+ return f"OV{dt}r{len(shape)}"
542
+
543
+ # others classes
544
+
545
+ if obj.__class__.__name__ == "MambaCache":
546
+ c = string_type(
547
+ obj.conv_states,
548
+ with_shape=with_shape,
549
+ with_min_max=with_min_max,
550
+ with_device=with_device,
551
+ limit=limit,
552
+ verbose=verbose,
553
+ )
554
+ d = string_type(
555
+ obj.ssm_states,
556
+ with_shape=with_shape,
557
+ with_min_max=with_min_max,
558
+ with_device=with_device,
559
+ limit=limit,
560
+ verbose=verbose,
561
+ )
562
+ if verbose:
563
+ print(f"[string_type] CACHE1:{type(obj)}")
564
+ return f"MambaCache(conv_states={c}, ssm_states={d})"
565
+
566
+ if obj.__class__.__name__ in {
567
+ "DynamicCache",
568
+ "SlidingWindowCache",
569
+ "StaticCache",
570
+ "HybridCache",
571
+ }:
572
+ from .cache_helper import CacheKeyValue
573
+
574
+ ca = CacheKeyValue(obj)
575
+ kc = string_type(
576
+ ca.key_cache,
577
+ with_shape=with_shape,
578
+ with_min_max=with_min_max,
579
+ with_device=with_device,
580
+ limit=limit,
581
+ verbose=verbose,
582
+ )
583
+ vc = string_type(
584
+ ca.value_cache,
585
+ with_shape=with_shape,
586
+ with_min_max=with_min_max,
587
+ with_device=with_device,
588
+ limit=limit,
589
+ verbose=verbose,
590
+ )
591
+ if verbose:
592
+ print(f"[string_type] CACHE2:{type(obj)}")
593
+ return f"{obj.__class__.__name__}(key_cache={kc}, value_cache={vc})"
594
+
595
+ if obj.__class__.__name__ == "StaticLayer":
596
+ kc = string_type(
597
+ list(obj.keys),
598
+ with_shape=with_shape,
599
+ with_min_max=with_min_max,
600
+ with_device=with_device,
601
+ limit=limit,
602
+ verbose=verbose,
603
+ )
604
+ vc = string_type(
605
+ list(obj.values),
606
+ with_shape=with_shape,
607
+ with_min_max=with_min_max,
608
+ with_device=with_device,
609
+ limit=limit,
610
+ verbose=verbose,
611
+ )
612
+ if verbose:
613
+ print(f"[string_type] SL:{type(obj)}")
614
+ return f"{obj.__class__.__name__}(keys={kc}, values={vc})"
615
+
616
+ if obj.__class__.__name__ == "EncoderDecoderCache":
617
+ att = string_type(
618
+ obj.self_attention_cache,
619
+ with_shape=with_shape,
620
+ with_min_max=with_min_max,
621
+ with_device=with_device,
622
+ limit=limit,
623
+ verbose=verbose,
624
+ )
625
+ cross = string_type(
626
+ obj.cross_attention_cache,
627
+ with_shape=with_shape,
628
+ with_min_max=with_min_max,
629
+ with_device=with_device,
630
+ limit=limit,
631
+ verbose=verbose,
632
+ )
633
+ if verbose:
634
+ print(f"[string_type] CACHE3:{type(obj)}")
635
+ return (
636
+ f"{obj.__class__.__name__}(self_attention_cache={att}, "
637
+ f"cross_attention_cache={cross})"
638
+ )
639
+
640
+ if obj.__class__ in torch.utils._pytree.SUPPORTED_NODES:
641
+ from .cache_helper import flatten_unflatten_for_dynamic_shapes
642
+
643
+ args = flatten_unflatten_for_dynamic_shapes(obj)
644
+ att = string_type(
645
+ args,
646
+ with_shape=with_shape,
647
+ with_min_max=with_min_max,
648
+ with_device=with_device,
649
+ limit=limit,
650
+ verbose=verbose,
651
+ )
652
+ if verbose:
653
+ print(f"[string_type] DS:{type(obj)}")
654
+ return f"{obj.__class__.__name__}[serialized]({att})"
655
+
656
+ if type(obj).__name__ == "Node" and hasattr(obj, "meta"):
657
+ # torch.fx.node.Node
658
+ if verbose:
659
+ print(f"[string_type] TT1:{type(obj)}")
660
+ return f"%{obj.target}"
661
+ if type(obj).__name__ == "ValueInfoProto":
662
+ if verbose:
663
+ print(f"[string_type] OO1:{type(obj)}")
664
+ return f"OT{obj.type.tensor_type.elem_type}"
665
+
666
+ if obj.__class__.__name__ == "BatchFeature":
667
+ s = string_type(
668
+ obj.data,
669
+ with_shape=with_shape,
670
+ with_min_max=with_min_max,
671
+ with_device=with_device,
672
+ limit=limit,
673
+ verbose=verbose,
674
+ )
675
+ if verbose:
676
+ print(f"[string_type] TT2:{type(obj)}")
677
+ return f"BatchFeature(data={s})"
678
+
679
+ if obj.__class__.__name__ == "BatchEncoding":
680
+ s = string_type(
681
+ obj.data,
682
+ with_shape=with_shape,
683
+ with_min_max=with_min_max,
684
+ with_device=with_device,
685
+ limit=limit,
686
+ verbose=verbose,
687
+ )
688
+ if verbose:
689
+ print(f"[string_type] TT3:{type(obj)}")
690
+ return f"BatchEncoding(data={s})"
691
+
692
+ if obj.__class__.__name__ == "VirtualTensor":
693
+ if verbose:
694
+ print(f"[string_type] TT4:{type(obj)}")
695
+ return (
696
+ f"{obj.__class__.__name__}(name={obj.name!r}, "
697
+ f"dtype={obj.dtype}, shape={obj.shape})"
698
+ )
699
+
700
+ if obj.__class__.__name__ == "KeyValuesWrapper":
701
+ import transformers
702
+
703
+ assert isinstance(
704
+ obj, transformers.cache_utils.KeyValuesWrapper
705
+ ), f"Unexpected type {type(obj)}"
706
+ if verbose:
707
+ print(f"[string_type] KW0:{type(obj)}")
708
+ s = string_type(
709
+ list(obj),
710
+ with_shape=with_shape,
711
+ with_min_max=with_min_max,
712
+ with_device=with_device,
713
+ limit=limit,
714
+ verbose=verbose,
715
+ )
716
+ return f"{obj.__class__.__name__}[{obj.cache_type}]{s}"
717
+
718
+ if obj.__class__.__name__ == "DynamicLayer":
719
+ import transformers
720
+
721
+ assert isinstance(
722
+ obj, transformers.cache_utils.DynamicLayer
723
+ ), f"Unexpected type {type(obj)}"
724
+ if verbose:
725
+ print(f"[string_type] LY0:{type(obj)}")
726
+ s1 = string_type(
727
+ obj.keys,
728
+ with_shape=with_shape,
729
+ with_min_max=with_min_max,
730
+ with_device=with_device,
731
+ limit=limit,
732
+ verbose=verbose,
733
+ )
734
+ s2 = string_type(
735
+ obj.values,
736
+ with_shape=with_shape,
737
+ with_min_max=with_min_max,
738
+ with_device=with_device,
739
+ limit=limit,
740
+ verbose=verbose,
741
+ )
742
+ return f"{obj.__class__.__name__}(keys={s1}, values={s2})"
743
+
744
+ if isinstance(obj, torch.nn.Module):
745
+ if verbose:
746
+ print(f"[string_type] MM:{type(obj)}")
747
+ return f"{obj.__class__.__name__}(...)"
748
+
749
+ if isinstance(obj, (torch.device, torch.dtype, torch.memory_format, torch.layout)):
750
+ if verbose:
751
+ print(f"[string_type] TT7:{type(obj)}")
752
+ return f"{obj.__class__.__name__}({obj})"
753
+
754
+ if isinstance( # TreeSpec, MappingKey, SequenceKey
755
+ obj,
756
+ (
757
+ torch.utils._pytree.TreeSpec,
758
+ torch.utils._pytree.MappingKey,
759
+ torch.utils._pytree.SequenceKey,
760
+ ),
761
+ ):
762
+ if verbose:
763
+ print(f"[string_type] TT8:{type(obj)}")
764
+ return repr(obj).replace(" ", "").replace("\n", " ")
765
+
766
+ if ignore:
767
+ if verbose:
768
+ print(f"[string_type] CACHE4:{type(obj)}")
769
+ return f"{obj.__class__.__name__}(...)"
770
+
771
+ if obj.__class__.__name__.endswith("Config"):
772
+ import transformers.configuration_utils as tcu
773
+
774
+ if isinstance(obj, tcu.PretrainedConfig):
775
+ if verbose:
776
+ print(f"[string_type] CONFIG:{type(obj)}")
777
+ s = str(obj.to_diff_dict()).replace("\n", "").replace(" ", "")
778
+ return f"{obj.__class__.__name__}(**{s})"
779
+ if obj.__class__.__name__ in {"TorchModelContainer", "InferenceSession"}:
780
+ return f"{obj.__class__.__name__}(...)"
781
+ if obj.__class__.__name__ == "Results":
782
+ import ultralytics
783
+
784
+ assert isinstance(
785
+ obj, ultralytics.engine.results.Results
786
+ ), f"Unexpected type={type(obj)}"
787
+ return f"ultralytics.{obj.__class__.__name__}(...)"
788
+ if obj.__class__.__name__ == "FakeTensorMode":
789
+ return f"{obj}"
790
+
791
+ if verbose:
792
+ print(f"[string_type] END:{type(obj)}")
793
+ raise AssertionError(f"Unsupported type {type(obj).__name__!r} - {type(obj)}")
794
+
795
+
796
+ def string_signature(sig: Any) -> str:
797
+ """Displays the signature of a functions."""
798
+
799
+ def _k(p, kind):
800
+ for name in dir(p):
801
+ if getattr(p, name) == kind:
802
+ return name
803
+ return repr(kind)
804
+
805
+ text = [" __call__ ("]
806
+ for p in sig.parameters:
807
+ pp = sig.parameters[p]
808
+ kind = pp.kind
809
+ t = f"{p}: {pp.annotation}" if pp.annotation is not inspect._empty else p
810
+ if pp.default is not inspect._empty:
811
+ t = f"{t} = {pp.default!r}"
812
+ if kind == pp.VAR_POSITIONAL:
813
+ t = f"*{t}"
814
+ le = (30 - len(t)) * " "
815
+ text.append(f" {t}{le}|{_k(pp,kind)}")
816
+ text.append(
817
+ f") -> {sig.return_annotation}" if sig.return_annotation is not inspect._empty else ")"
818
+ )
819
+ return "\n".join(text)
820
+
821
+
822
+ def string_sig(f: Callable, kwargs: Optional[Dict[str, Any]] = None) -> str:
823
+ """
824
+ Displays the signature of a function, showing only the arguments
825
+ whose given value differs from their default.
826
+ """
827
+ if hasattr(f, "__init__") and kwargs is None:
828
+ fct = f.__init__
829
+ kwargs = f.__dict__
830
+ name = f.__class__.__name__
831
+ else:
832
+ fct = f
833
+ name = f.__name__
834
+
835
+ if kwargs is None:
836
+ kwargs = {}
837
+ rows = []
838
+ sig = inspect.signature(fct)
839
+ for p in sig.parameters:
840
+ pp = sig.parameters[p]
841
+ d = pp.default
842
+ if d is inspect._empty:
843
+ if p in kwargs:
844
+ v = kwargs[p]
845
+ rows.append(
846
+ f"{p}={v!r}" if not isinstance(v, enum.IntEnum) else f"{p}={v.name}"
847
+ )
848
+ continue
849
+ v = kwargs.get(p, d)
850
+ if d != v:
851
+ rows.append(f"{p}={v!r}" if not isinstance(v, enum.IntEnum) else f"{p}={v.name}")
852
+ continue
853
+ atts = ", ".join(rows)
854
+ return f"{name}({atts})"
855
+
856
+
857
+ def make_hash(obj: Any) -> str:
858
+ """
859
+ Returns a simple three-letter hash of ``id(obj)``.
860
+ """
861
+ aa = id(obj) % (26**3)
862
+ return f"{chr(65 + aa // 26 ** 2)}{chr(65 + (aa // 26) % 26)}{chr(65 + aa % 26)}"
863
+
864
+
865
+ def rename_dynamic_dimensions(
866
+ constraints: Dict[str, Set[str]], original: Set[str], ban_prefix: str = "DYN"
867
+ ) -> Dict[str, str]:
868
+ """
869
+ Renames dynamic shapes as requested by the user. :func:`torch.export.export` uses
870
+ many names for dynamic dimensions. When building the onnx model,
871
+ some of them are redundant and can be replaced by the name provided by the user.
872
+
873
+ :param constraints: mapping from every used name to the set of names equal to it
874
+ :param original: the names to use if possible
875
+ :param ban_prefix: never replace a name with a candidate starting with this prefix
876
+ :return: replacement dictionary
877
+ """
878
+ replacements = {s: s for s in original}
879
+ all_values = set(constraints) | original
880
+
881
+ not_done = set(constraints)
882
+ max_iter = len(replacements)
883
+ while not_done and max_iter > 0:
884
+ max_iter -= 1
885
+ for k, v in constraints.items():
886
+ common = v & original
887
+ if not common:
888
+ continue
889
+ sorted_common = sorted(common)
890
+ by = sorted_common[0]
891
+ if ban_prefix and by.startswith(ban_prefix):
892
+ continue
893
+ replacements[k] = by
894
+ for vv in v:
895
+ if vv not in replacements:
896
+ replacements[vv] = by
897
+ not_done = all_values - set(replacements)
898
+ return replacements
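A small sketch (editor's illustration, hypothetical dimension names): every name proven equal to a user-provided one is rewritten to that name.

    from onnx_diagnostic.helpers.helper import rename_dynamic_dimensions

    constraints = {"s15": {"batch", "s21"}, "s21": {"s15"}}
    print(rename_dynamic_dimensions(constraints, original={"batch"}))
    # {'batch': 'batch', 's15': 'batch', 's21': 'batch'}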
899
+
900
+
901
+ def rename_dynamic_expression(expression: str, replacements: Dict[str, str]):
902
+ """
903
+ Renames variables of an expression.
904
+
905
+ :param expression: something like ``s15 + seq_length``
906
+ :param replacements: replacements to make
907
+ :return: new string
908
+ """
909
+
910
+ class RenameVariable(ast.NodeTransformer):
911
+ def visit_Name(self, node):
912
+ if node.id in replacements:
913
+ node.id = replacements[node.id]
914
+ return node
915
+
916
+ tree = ast.parse(expression)
917
+ transformer = RenameVariable()
918
+ new_tree = transformer.visit(tree)
919
+ return ast.unparse(new_tree)
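A one-line sketch (editor's illustration) using the expression mentioned in the docstring:

    from onnx_diagnostic.helpers.helper import rename_dynamic_expression

    print(rename_dynamic_expression("s15 + seq_length", {"s15": "batch"}))
    # batch + seq_length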
920
+
921
+
922
+ def flatten_object(x: Any, drop_keys: bool = False) -> Any:
923
+ """
924
+ Flattens the object.
925
+ It accepts some common classes used in deep learning.
926
+
927
+ :param x: any object
928
+ :param drop_keys: drop the keys if a dictionary is flattened.
929
+ The order defined by the dictionary is preserved in either case.
930
+ :return: flattened object
931
+ """
932
+ if x is None:
933
+ return x
934
+ if isinstance(x, (list, tuple)):
935
+ res = []
936
+ for i in x:
937
+ if i is None or hasattr(i, "shape") or isinstance(i, (int, float, str)):
938
+ res.append(i)
939
+ else:
940
+ res.extend(flatten_object(i, drop_keys=drop_keys))
941
+ return tuple(res) if isinstance(x, tuple) else res
942
+ if isinstance(x, dict):
943
+ # We flatten the keys.
944
+ if drop_keys:
945
+ return flatten_object(list(x.values()), drop_keys=drop_keys)
946
+ return flatten_object(list(x.items()), drop_keys=drop_keys)
947
+
948
+ if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
949
+ from .cache_helper import CacheKeyValue
950
+
951
+ kc = CacheKeyValue(x)
952
+ return list(itertools.chain.from_iterable(zip(kc.key_cache, kc.value_cache)))
953
+
954
+ if x.__class__.__name__ == "EncoderDecoderCache":
955
+ res = flatten_object(x.self_attention_cache) + flatten_object(x.cross_attention_cache)
956
+ return tuple(res)
957
+ if x.__class__.__name__ == "MambaCache":
958
+ if isinstance(x.conv_states, list):
959
+ res = flatten_object(x.conv_states) + flatten_object(x.ssm_states)
960
+ return tuple(res)
961
+ return (x.conv_states, x.ssm_states)
962
+ if hasattr(x, "to_tuple"):
963
+ return flatten_object(x.to_tuple(), drop_keys=drop_keys)
964
+ if hasattr(x, "shape"):
965
+ # A tensor. Nothing to do.
966
+ return x
967
+ raise TypeError(
968
+ f"Unexpected type {type(x)} for x, drop_keys={drop_keys}, "
969
+ f"content is {string_type(x, with_shape=True)}"
970
+ )
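A minimal sketch (editor's illustration) combining flatten_object with string_type defined above:

    import torch
    from onnx_diagnostic.helpers.helper import flatten_object, string_type

    nested = {"a": torch.zeros((2, 2)), "b": [torch.ones((3,)), 4.5]}
    flat = flatten_object(nested, drop_keys=True)
    print(string_type(flat, with_shape=True))  # #3[T1s2x2,T1s3,float]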
971
+
972
+
973
+ def _make_debug_info(msg, level, debug_info, verbose) -> Optional[List[str]]:
974
+ return (
975
+ [*(debug_info if debug_info else []), f"{' ' * level}{msg}"] if verbose > 5 else None
976
+ )
977
+
978
+
979
+ def max_diff(
980
+ expected: Any,
981
+ got: Any,
982
+ verbose: int = 0,
983
+ level: int = 0,
984
+ flatten: bool = False,
985
+ debug_info: Optional[List[str]] = None,
986
+ begin: int = 0,
987
+ end: int = -1,
988
+ _index: int = 0,
989
+ allow_unique_tensor_with_list_of_one_element: bool = True,
990
+ hist: Optional[Union[bool, List[float]]] = None,
991
+ ) -> Dict[str, Union[float, int, Tuple[int, ...]]]:
992
+ """
993
+ Returns the maximum discrepancy.
994
+
995
+ :param expected: expected values
996
+ :param got: obtained values
997
+ :param verbose: verbosity level
998
+ :param level: for nested outputs, used for debugging purposes
999
+ :param flatten: flatten outputs
1000
+ :param debug_info: debug information
1001
+ :param begin: first output to be considered
1002
+ :param end: last output to be considered (-1 for the last one)
1003
+ :param _index: used with begin and end
1004
+ :param allow_unique_tensor_with_list_of_one_element:
1005
+ allow a comparison between a single tensor and a list of one tensor
1006
+ :param hist: if True or a list of thresholds, computes a histogram of the discrepancies
1007
+ :return: dictionary with the following values:
1008
+
1009
+ * abs: max absolute error
1010
+ * rel: max relative error
1011
+ * sum: sum of the errors
1012
+ * n: number of output values; if there is a single
1013
+ output, this number will be the number of elements
1014
+ of this output
1015
+ * dnan: difference in the number of nan
1016
+
1017
+ You may use :func:`string_diff` to display the discrepancies in one string.
1018
+ """
1019
+ if expected is None and got is None:
1020
+ return dict(abs=0, rel=0, sum=0, n=0, dnan=0)
1021
+
1022
+ _dkws_ = dict(
1023
+ verbose=verbose,
1024
+ level=level + 1,
1025
+ begin=begin,
1026
+ end=end,
1027
+ _index=_index,
1028
+ hist=hist,
1029
+ )
1030
+ _dkws = {**_dkws_, "flatten": flatten}
1031
+ _dkwsf = {**_dkws_, "flatten": False}
1032
+
1033
+ _debug = lambda msg: _make_debug_info(msg, level, debug_info, verbose) # noqa: E731
1034
+
1035
+ if allow_unique_tensor_with_list_of_one_element:
1036
+ if hasattr(expected, "shape") and isinstance(got, (list, tuple)) and len(got) == 1:
1037
+ return max_diff(
1038
+ expected,
1039
+ got[0],
1040
+ verbose=verbose,
1041
+ level=level,
1042
+ flatten=False,
1043
+ debug_info=debug_info,
1044
+ allow_unique_tensor_with_list_of_one_element=False,
1045
+ hist=hist,
1046
+ )
1047
+ return max_diff(
1048
+ expected,
1049
+ got,
1050
+ verbose=verbose,
1051
+ level=level,
1052
+ flatten=flatten,
1053
+ debug_info=debug_info,
1054
+ begin=begin,
1055
+ end=end,
1056
+ _index=_index,
1057
+ allow_unique_tensor_with_list_of_one_element=False,
1058
+ hist=hist,
1059
+ )
1060
+
1061
+ if expected.__class__.__name__ == "CausalLMOutputWithPast":
1062
+ if verbose >= 6:
1063
+ print(
1064
+ f"[max_diff] CausalLMOutputWithPast: {string_type(expected)} "
1065
+ f"? {string_type(got)}"
1066
+ )
1067
+ if got.__class__.__name__ == "CausalLMOutputWithPast":
1068
+ return max_diff(
1069
+ [expected.logits, *flatten_object(expected.past_key_values)],
1070
+ [got.logits, *flatten_object(got.past_key_values)],
1071
+ debug_info=_debug(expected.__class__.__name__),
1072
+ **_dkws,
1073
+ )
1074
+ return max_diff(
1075
+ [expected.logits, *flatten_object(expected.past_key_values)],
1076
+ got,
1077
+ debug_info=_debug(expected.__class__.__name__),
1078
+ **_dkws,
1079
+ )
1080
+
1081
+ if hasattr(expected, "to_tuple"):
1082
+ if verbose >= 6:
1083
+ print(f"[max_diff] to_tuple1: {string_type(expected)} ? {string_type(got)}")
1084
+ return max_diff(expected.to_tuple(), got, debug_info=_debug("to_tuple1"), **_dkws)
1085
+
1086
+ if hasattr(got, "to_tuple"):
1087
+ if verbose >= 6:
1088
+ print(f"[max_diff] to_tuple2: {string_type(expected)} ? {string_type(got)}")
1089
+ return max_diff(expected, got.to_tuple(), debug_info=_debug("to_tuple2"), **_dkws)
1090
+
1091
+ if isinstance(expected, (tuple, list)):
1092
+ if verbose >= 6:
1093
+ print(f"[max_diff] list,tuple,0: {string_type(expected)} ? {string_type(got)}")
1094
+ if len(expected) == 1 and not isinstance(got, type(expected)):
1095
+ if verbose >= 6:
1096
+ print(f"[max_diff] list,tuple,3: {string_type(expected)} ? {string_type(got)}")
1097
+ return max_diff(expected[0], got, debug_info=_debug("lt2"), **_dkws)
1098
+ if not isinstance(got, (tuple, list)):
1099
+ if verbose >= 6:
1100
+ print(f"[max_diff] list,tuple,4: {string_type(expected)} ? {string_type(got)}")
1101
+ if verbose > 2:
1102
+ print(
1103
+ f"[max_diff] inf because type(expected)={type(expected)}, "
1104
+ f"type(got)={type(got)}, level={level}, _index={_index}"
1105
+ )
1106
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1107
+
1108
+ if len(got) != len(expected):
1109
+ if flatten:
1110
+ if verbose >= 6:
1111
+ print(
1112
+ f"[max_diff] list,tuple,5: {string_type(expected)} "
1113
+ f"? {string_type(got)}"
1114
+ )
1115
+ # Let's flatten.
1116
+ if verbose > 2:
1117
+ print(
1118
+ f"[max_diff] flattening because of length mismatch, "
1119
+ f"expected is\n {string_type(expected)}\n -- and got is\n "
1120
+ f"{string_type(got)}"
1121
+ )
1122
+ flat_a = flatten_object(expected, drop_keys=True)
1123
+ flat_b = flatten_object(got, drop_keys=True)
1124
+ if verbose > 2:
1125
+ print(
1126
+ f"[max_diff] after flattening, "
1127
+ f"expected is\n {string_type(flat_a)}\n -- and got is\n "
1128
+ f"{string_type(flat_b)}"
1129
+ )
1130
+ return max_diff(
1131
+ flat_a,
1132
+ flat_b,
1133
+ debug_info=(
1134
+ [
1135
+ *(debug_info if debug_info else []),
1136
+ (
1137
+ f"{' ' * level}flatten["
1138
+ f"{string_type(expected)},{string_type(got)}]"
1139
+ ),
1140
+ ]
1141
+ if verbose > 5
1142
+ else None
1143
+ ),
1144
+ **_dkwsf,
1145
+ )
1146
+
1147
+ if verbose > 2:
1148
+ import torch
1149
+
1150
+ print(
1151
+ f"[max_diff] (b) inf because len(expected)={len(expected)}, "
1152
+ f"len(got)={len(got)}, level={level}, _index={_index}"
1153
+ )
1154
+ for i, (a, b) in enumerate(zip(expected, got)):
1155
+ if isinstance(a, torch.Tensor) and isinstance(b, torch.Tensor):
1156
+ print(
1157
+ f" i={i} expected {a.dtype}:{a.shape}, "
1158
+ f"has {b.dtype}:{b.shape}, _index={_index}"
1159
+ )
1160
+ else:
1161
+ print(f" i={i} a is {type(a)}, b is {type(b)}")
1162
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1163
+
1164
+ if verbose >= 6:
1165
+ print(f"[max_diff] list,tuple,6: {string_type(expected)} ? {string_type(got)}")
1166
+ am, rm, sm, n, dn, drep = 0, 0, 0.0, 0.0, 0, None
1167
+ for ip, (e, g) in enumerate(zip(expected, got)):
1168
+ d = max_diff(
1169
+ e,
1170
+ g,
1171
+ verbose=verbose,
1172
+ level=level + 1,
1173
+ debug_info=(
1174
+ [
1175
+ *(debug_info if debug_info else []),
1176
+ f"{' ' * level}[{ip}] so far abs {am} - rel {rm}",
1177
+ ]
1178
+ if verbose > 5
1179
+ else None
1180
+ ),
1181
+ begin=begin,
1182
+ end=end,
1183
+ _index=_index + ip,
1184
+ flatten=flatten,
1185
+ hist=hist,
1186
+ )
1187
+ am = max(am, d["abs"])
1188
+ dn = max(dn, d["dnan"])
1189
+ rm = max(rm, d["rel"])
1190
+ sm += d["sum"] # type: ignore
1191
+ n += d["n"] # type: ignore
1192
+ if "rep" in d:
1193
+ if drep is None:
1194
+ drep = d["rep"].copy()
1195
+ else:
1196
+ for k, v in d["rep"].items():
1197
+ drep[k] += v
1198
+ res = dict(abs=am, rel=rm, sum=sm, n=n, dnan=dn)
1199
+ if drep:
1200
+ res["rep"] = drep
1201
+ return res # type: ignore
1202
+
1203
+ if isinstance(expected, dict):
1204
+ if verbose >= 6:
1205
+ print(f"[max_diff] dict: {string_type(expected)} ? {string_type(got)}")
1206
+ assert begin == 0 and end == -1, (
1207
+ f"begin={begin}, end={end} not compatible with dictionaries, "
1208
+ f"keys={sorted(expected)}"
1209
+ )
1210
+ if isinstance(got, dict):
1211
+ if len(expected) != len(got):
1212
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1213
+ if set(expected) != set(got):
1214
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1215
+ keys = sorted(expected)
1216
+ return max_diff(
1217
+ [expected[k] for k in keys],
1218
+ [got[k] for k in keys],
1219
+ debug_info=_debug("dict1"),
1220
+ **_dkws,
1221
+ )
1222
+
1223
+ if not isinstance(got, (tuple, list)):
1224
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1225
+ if len(expected) != len(got):
1226
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1227
+ return max_diff(list(expected.values()), got, debug_info=_debug("dict2"), **_dkws)
1228
+
1229
+ import torch
1230
+
1231
+ if isinstance(expected, np.ndarray) or isinstance(got, np.ndarray):
1232
+ if isinstance(expected, torch.Tensor):
1233
+ from .torch_helper import to_numpy
1234
+
1235
+ expected = to_numpy(expected)
1236
+ if isinstance(got, torch.Tensor):
1237
+ from .torch_helper import to_numpy
1238
+
1239
+ got = to_numpy(got)
1240
+ if verbose >= 6:
1241
+ print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1242
+
1243
+ if _index < begin or (end != -1 and _index >= end):
1244
+ # out of boundary
1245
+ return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1246
+ if isinstance(expected, (int, float)):
1247
+ if isinstance(got, np.ndarray) and len(got.shape) == 0:
1248
+ got = float(got)
1249
+ if isinstance(got, (int, float)):
1250
+ if expected == got:
1251
+ return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1252
+ return dict(
1253
+ abs=abs(expected - got),
1254
+ rel=abs(expected - got) / (abs(expected) + 1e-5),
1255
+ sum=abs(expected - got),
1256
+ n=1,
1257
+ dnan=0,
1258
+ )
1259
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1260
+ if expected.dtype in (np.complex64, np.complex128):
1261
+ if got.dtype == expected.dtype:
1262
+ got = np.real(got)
1263
+ elif got.dtype not in (np.float32, np.float64):
1264
+ if verbose >= 10:
1265
+ # To understand the value it comes from.
1266
+ if debug_info:
1267
+ print("\n".join(debug_info))
1268
+ print(
1269
+ f"[max_diff-c] expected.dtype={expected.dtype}, "
1270
+ f"got.dtype={got.dtype}"
1271
+ )
1272
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1273
+ expected = np.real(expected)
1274
+
1275
+ if expected.shape != got.shape:
1276
+ if verbose >= 10:
1277
+ # To understand the value it comes from.
1278
+ if debug_info:
1279
+ print("\n".join(debug_info))
1280
+ print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}")
1281
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1282
+ # nans are replaced by 1e10, any discrepancy of that order of magnitude
1283
+ # is likely caused by nans
1284
+ exp_cpu = np.nan_to_num(expected.astype(np.float64), nan=1e10)
1285
+ got_cpu = np.nan_to_num(got.astype(np.float64), nan=1e10)
1286
+ diff = np.abs(got_cpu - exp_cpu)
1287
+ ndiff = np.abs(np.isnan(expected).astype(int) - np.isnan(got).astype(int))
1288
+ rdiff = diff / (np.abs(exp_cpu) + 1e-3)
1289
+ if diff.size == 0:
1290
+ abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
1291
+ (0, 0, 0, 0, 0)
1292
+ if exp_cpu.size == got_cpu.size
1293
+ else (np.inf, np.inf, np.inf, 0, np.inf)
1294
+ )
1295
+ argm = None
1296
+ else:
1297
+ abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
1298
+ float(diff.max()),
1299
+ float(rdiff.max()),
1300
+ float(diff.sum()),
1301
+ float(diff.size),
1302
+ float(ndiff.sum()),
1303
+ )
1304
+ argm = tuple(map(int, np.unravel_index(diff.argmax(), diff.shape)))
1305
+ if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
1306
+ # To understand the value it comes from.
1307
+ if debug_info:
1308
+ print("\n".join(debug_info))
1309
+ print(
1310
+ f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, "
1311
+ f"nan_diff={nan_diff}, dtype={expected.dtype}, "
1312
+ f"shape={expected.shape}, level={level}, _index={_index}"
1313
+ )
1314
+ if abs_diff >= 10:
1315
+ idiff = np.argmax(diff.reshape((-1,)))
1316
+ x = expected.reshape((-1,))[idiff]
1317
+ y = got.reshape((-1,))[idiff]
1318
+ print(
1319
+ f" [max_diff-2] abs diff={abs_diff}, "
1320
+ f"x={x}, y={y}, level={level}, "
1321
+ f"_index={_index}"
1322
+ )
1323
+ print(y)
1324
+
1325
+ if rel_diff >= 10:
1326
+ idiff = np.argmax(rdiff.reshape((-1,)))
1327
+ x = expected.reshape((-1,))[idiff]
1328
+ y = got.reshape((-1,))[idiff]
1329
+ print(
1330
+ f" [max_diff-3] rel diff={rel_diff}, "
1331
+ f"x={x}, y={y}, level={level}, "
1332
+ f"_index={_index}"
1333
+ )
1334
+
1335
+ res: Dict[str, float] = dict( # type: ignore
1336
+ abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1337
+ )
1338
+ if hist:
1339
+ if isinstance(hist, bool):
1340
+ hist = np.array([0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype)
1341
+ ind = np.digitize(diff.reshape((-1,)), hist, right=True)
1342
+ cou = np.bincount(ind, minlength=ind.shape[0] + 1)
1343
+ res["rep"] = dict(
1344
+ zip([f">{x}" for x in hist], [int(i) for i in (cou.sum() - np.cumsum(cou))])
1345
+ )
1346
+ return res # type: ignore
1347
+
1348
+ if isinstance(expected, torch.Tensor) and isinstance(got, torch.Tensor):
1349
+ if verbose >= 6:
1350
+ print(f"[max_diff] tensor: {string_type(expected)} ? {string_type(got)}")
1351
+ if _index < begin or (end != -1 and _index >= end):
1352
+ # out of boundary
1353
+ return dict(abs=0.0, rel=0.0, sum=0.0, n=0.0, dnan=0)
1354
+ if expected.dtype in (torch.complex64, torch.complex128):
1355
+ if got.dtype == expected.dtype:
1356
+ got = torch.view_as_real(got)
1357
+ elif got.dtype not in (torch.float32, torch.float64):
1358
+ if verbose >= 10:
1359
+ # To understand the value it comes from.
1360
+ if debug_info:
1361
+ print("\n".join(debug_info))
1362
+ print(
1363
+ f"[max_diff-c] expected.dtype={expected.dtype}, "
1364
+ f"got.dtype={got.dtype}"
1365
+ )
1366
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1367
+ expected = torch.view_as_real(expected)
1368
+
1369
+ if expected.shape != got.shape:
1370
+ if verbose >= 10:
1371
+ # To understand the value it comes from.
1372
+ if debug_info:
1373
+ print("\n".join(debug_info))
1374
+ print(f"[max_diff-s] expected.shape={expected.shape}, got.shape={got.shape}")
1375
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1376
+ # nans are replaced by 1e10, any discrepancy of that order of magnitude
1377
+ # is likely caused by nans
1378
+ exp_cpu = expected.to(torch.float64).nan_to_num(1e10)
1379
+ got_cpu = got.to(torch.float64).nan_to_num(1e10)
1380
+ if got_cpu.device != exp_cpu.device:
1381
+ if torch.device("cuda:0") in {got_cpu.device, exp_cpu.device}:
1382
+ got_cpu = got_cpu.to("cuda:0")
1383
+ exp_cpu = exp_cpu.to("cuda:0")
1384
+ expected = expected.to("cuda:0")
1385
+ got = got.to("cuda:0")
1386
+ else:
1387
+ got_cpu = got_cpu.detach().to("cpu")
1388
+ exp_cpu = exp_cpu.detach().to("cpu")
1389
+ expected = expected.to("cpu")
1390
+ got = got.to("cpu")
1391
+ diff = (got_cpu - exp_cpu).abs()
1392
+ ndiff = (expected.isnan().to(int) - got.isnan().to(int)).abs()
1393
+ rdiff = diff / (exp_cpu.abs() + 1e-3)
1394
+ if diff.numel() > 0:
1395
+ abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
1396
+ float(diff.max().detach()),
1397
+ float(rdiff.max().detach()),
1398
+ float(diff.sum().detach()),
1399
+ float(diff.numel()),
1400
+ float(ndiff.sum().detach()),
1401
+ )
1402
+ argm = tuple(map(int, torch.unravel_index(diff.argmax(), diff.shape)))
1403
+ elif got_cpu.numel() == exp_cpu.numel():
1404
+ abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (0.0, 0.0, 0.0, 0.0, 0.0)
1405
+ argm = None
1406
+ else:
1407
+ abs_diff, rel_diff, sum_diff, n_diff, nan_diff = (
1408
+ np.inf,
1409
+ np.inf,
1410
+ np.inf,
1411
+ np.inf,
1412
+ np.inf,
1413
+ )
1414
+ argm = None
1415
+
1416
+ if verbose >= 10 and (abs_diff >= 10 or rel_diff >= 10):
1417
+ # To understand the value it comes from.
1418
+ if debug_info:
1419
+ print("\n".join(debug_info))
1420
+ print(
1421
+ f"[max_diff-1] abs_diff={abs_diff}, rel_diff={rel_diff}, "
1422
+ f"nan_diff={nan_diff}, dtype={expected.dtype}, "
1423
+ f"shape={expected.shape}, level={level}, _index={_index}"
1424
+ )
1425
+ if abs_diff >= 10:
1426
+ idiff = torch.argmax(diff.reshape((-1,)))
1427
+ x = expected.reshape((-1,))[idiff]
1428
+ y = got.reshape((-1,))[idiff]
1429
+ print(
1430
+ f" [max_diff-2] abs diff={abs_diff}, "
1431
+ f"x={x}, y={y}, level={level}, "
1432
+ f"_index={_index}"
1433
+ )
1434
+ print(y)
1435
+
1436
+ if rel_diff >= 10:
1437
+ idiff = torch.argmax(rdiff.reshape((-1,)))
1438
+ x = expected.reshape((-1,))[idiff]
1439
+ y = got.reshape((-1,))[idiff]
1440
+ print(
1441
+ f" [max_diff-3] rel diff={rel_diff}, "
1442
+ f"x={x}, y={y}, level={level}, "
1443
+ f"_index={_index}"
1444
+ )
1445
+
1446
+ res: Dict[str, float] = dict( # type: ignore
1447
+ abs=abs_diff, rel=rel_diff, sum=sum_diff, n=n_diff, dnan=nan_diff, argm=argm
1448
+ )
1449
+ if hist:
1450
+ if isinstance(hist, bool):
1451
+ hist = torch.tensor(
1452
+ [0, 0.0001, 0.001, 0.01, 0.1, 1, 10, 100], dtype=diff.dtype
1453
+ )
1454
+ hist = hist.to(diff.device)
1455
+ ind = torch.bucketize(diff.reshape((-1,)), hist, right=False)
1456
+ cou = torch.bincount(ind, minlength=ind.shape[0] + 1)
1457
+ res["rep"] = dict(
1458
+ zip(
1459
+ [f">{x}" for x in hist],
1460
+ [int(i) for i in (cou.sum() - torch.cumsum(cou, 0))],
1461
+ )
1462
+ )
1463
+ return res # type: ignore
1464
+
1465
+ if "SquashedNormal" in expected.__class__.__name__:
1466
+ if verbose >= 6:
1467
+ print(f"[max_diff] SquashedNormal: {string_type(expected)} ? {string_type(got)}")
1468
+ values = (
1469
+ expected.mean.detach().to("cpu"),
1470
+ expected.scale.detach().to("cpu"),
1471
+ )
1472
+ return max_diff(values, got, debug_info=_debug("SquashedNormal"), **_dkws)
1473
+
1474
+ if expected.__class__ in torch.utils._pytree.SUPPORTED_NODES:
1475
+ if got.__class__ not in torch.utils._pytree.SUPPORTED_NODES:
1476
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
1477
+ if verbose >= 6:
1478
+ print(
1479
+ f"[max_diff*] {expected.__class__.__name__}: "
1480
+ f"{string_type(expected)} ? {string_type(got)}"
1481
+ )
1482
+ expected_args, _spec = torch.utils._pytree.tree_flatten(expected)
1483
+ got_args, _spec = torch.utils._pytree.tree_flatten(got)
1484
+ return max_diff(
1485
+ expected_args, got_args, debug_info=_debug(expected.__class__.__name__), **_dkws
1486
+ )
1487
+
1488
+ # backup function in case pytorch does not know how to serialize.
1489
+ if expected.__class__.__name__ == "DynamicCache":
+ if got.__class__.__name__ == "DynamicCache":
+ from .cache_helper import CacheKeyValue
+
+ if verbose >= 6:
+ print(f"[max_diff] DynamicCache: {string_type(expected)} ? {string_type(got)}")
+ expected = CacheKeyValue(expected)
+ got = CacheKeyValue(got)
+ return max_diff(
+ [expected.key_cache, expected.value_cache],
+ [got.key_cache, got.value_cache],
+ verbose=verbose,
+ hist=hist,
+ )
+ if isinstance(got, tuple) and len(got) == 2:
+ return max_diff(
+ [expected.key_cache, expected.value_cache],
+ [got[0], got[1]],
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+ raise AssertionError(
+ f"DynamicCache not fully implemented with classes "
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
+ f"level={level}"
+ )
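+ # The cache classes below (HybridCache, StaticCache, SlidingWindowCache)
+ # follow the same pattern: the cache is reduced to its key/value tensor
+ # lists through CacheKeyValue and the two lists are compared recursively.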
+
+ # Fallback used when pytorch does not know how to serialize this class.
+ if expected.__class__.__name__ == "HybridCache":
+ if expected.__class__.__name__ == "HybridCache":
+ if got.__class__.__name__ == "HybridCache":
+ from .cache_helper import CacheKeyValue
+
+ if verbose >= 6:
+ print(f"[max_diff] HybridCache: {string_type(expected)} ? {string_type(got)}")
+ cae = CacheKeyValue(expected)
+ cag = CacheKeyValue(got)
+ return max_diff(
+ [cae.key_cache, cae.value_cache],
+ [cag.key_cache, cag.value_cache],
+ verbose=verbose,
+ hist=hist,
+ )
+ if isinstance(got, tuple) and len(got) == 2:
+ from .cache_helper import CacheKeyValue
+
+ cae = CacheKeyValue(expected)
+ return max_diff(
+ [cae.key_cache, cae.value_cache],
+ [got[0], got[1]],
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+ raise AssertionError(
+ f"HybridCache not fully implemented with classes "
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
+ f"level={level}"
+ )
+
+ if expected.__class__.__name__ == "StaticCache":
+ if got.__class__.__name__ == "StaticCache":
+ from .cache_helper import CacheKeyValue
+
+ cae = CacheKeyValue(expected)
+ cag = CacheKeyValue(got)
+ if verbose >= 6:
+ print(f"[max_diff] StaticCache: {string_type(expected)} ? {string_type(got)}")
+ return max_diff(
+ [cae.key_cache, cae.value_cache],
+ [cag.key_cache, cag.value_cache],
+ verbose=verbose,
+ hist=hist,
+ )
+ if isinstance(got, tuple) and len(got) == 2:
+ from .cache_helper import CacheKeyValue
+
+ cae = CacheKeyValue(expected)
+ return max_diff(
+ [cae.key_cache, cae.value_cache],
+ [got[0], got[1]],
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+ raise AssertionError(
+ f"StaticCache not fully implemented with classes "
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
+ f"level={level}"
+ )
+
+ if expected.__class__.__name__ == "SlidingWindowCache":
+ if got.__class__.__name__ == "SlidingWindowCache":
+ if verbose >= 6:
+ print(
+ f"[max_diff] SlidingWindowCache: "
+ f"{string_type(expected)} ? {string_type(got)}"
+ )
+ from .cache_helper import CacheKeyValue
+
+ cae = CacheKeyValue(expected)
+ cag = CacheKeyValue(got)
+ return max_diff(
+ [cae.key_cache, cae.value_cache],
+ [cag.key_cache, cag.value_cache],
+ verbose=verbose,
+ hist=hist,
+ )
+ if isinstance(got, tuple) and len(got) == 2:
+ from .cache_helper import CacheKeyValue
+
+ cae = CacheKeyValue(expected)
+ return max_diff(
+ [cae.key_cache, cae.value_cache],
+ [got[0], got[1]],
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+ raise AssertionError(
+ f"SlidingWindowCache not fully implemented with classes "
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
+ f"level={level}"
+ )
+
+ if expected.__class__.__name__ == "EncoderDecoderCache":
+ if got.__class__.__name__ == "EncoderDecoderCache":
+ if verbose >= 6:
+ print(
+ f"[max_diff] EncoderDecoderCache: "
+ f"{string_type(expected)} ? {string_type(got)}"
+ )
+ return max_diff(
+ [expected.self_attention_cache, expected.cross_attention_cache],
+ [got.self_attention_cache, got.cross_attention_cache],
+ verbose=verbose,
+ hist=hist,
+ )
+ if isinstance(got, tuple) and len(got) == 2:
+ return max_diff(
+ [expected.self_attention_cache, expected.cross_attention_cache],
+ [got[0], got[1]],
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+ raise AssertionError(
+ f"EncoderDecoderCache not fully implemented with classes "
+ f"{expected.__class__.__name__!r} and {got.__class__.__name__!r}, "
+ f"and expected={string_type(expected)}, got={string_type(got)},\n"
+ f"level={level}"
+ )
+
+ if expected.__class__.__name__ in ("transformers.cache_utils.MambaCache", "MambaCache"):
+ if verbose >= 6:
+ print(f"[max_diff] MambaCache: {string_type(expected)} ? {string_type(got)}")
+ if got.__class__.__name__ != expected.__class__.__name__:
+ # This case happens with onnx where the outputs are flattened.
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+ atts = []
+ for k in ["conv_states", "ssm_states"]:
+ if hasattr(expected, k) and not hasattr(got, k):
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+ atts.append(k)
+
+ return max_diff(
+ [getattr(expected, k) for k in atts],
+ [getattr(got, k) for k in atts],
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+
+ if expected.__class__.__name__ == "KeyValuesWrapper":
+ if verbose >= 6:
+ print(f"[max_diff] KeyValuesWrapper: {string_type(expected)} ? {string_type(got)}")
+ if got.__class__.__name__ != expected.__class__.__name__:
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+ if got.cache_type != expected.cache_type:
+ return dict(abs=np.inf, rel=np.inf, sum=np.inf, n=np.inf, dnan=np.inf)
+ return max_diff(
+ list(expected),
+ list(got),
+ debug_info=_debug(expected.__class__.__name__),
+ **_dkws,
+ )
+
+ raise AssertionError(
+ f"Not implemented with expected="
+ f"{string_type(expected)}, got={string_type(got)},\n"
+ f"level={level}"
+ )
+
+
+ def string_diff(diff: Dict[str, Any]) -> str:
+ """Renders discrepancies return by :func:`max_diff` into one string."""
1683
+ # dict(abs=, rel=, sum=, n=n_diff, dnan=)
+ suffix = ""
+ if "rep" in diff:
+ rows = []
+ for k, v in diff["rep"].items():
+ if v > 0:
+ rows.append(f"#{v}{k}")
+ suffix = "-".join(rows)
+ suffix = f"/{suffix}"
+ if "argm" in diff:
+ sa = (
+ ",".join(map(str, diff["argm"]))
+ if isinstance(diff["argm"], tuple)
+ else str(diff["argm"])
+ )
+ suffix += f",amax={sa}"
+ if diff.get("dnan", None):
+ if diff["abs"] == 0 or diff["rel"] == 0:
+ return f"abs={diff['abs']}, rel={diff['rel']}, dnan={diff['dnan']}{suffix}"
+ return (
+ f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}, dnan={diff['dnan']}{suffix}"
+ )
+ if diff["abs"] == 0 or diff["rel"] == 0:
+ return f"abs={diff['abs']}, rel={diff['rel']}{suffix}"
+ return f"abs={diff['abs']}, rel={diff['rel']}, n={diff['n']}{suffix}"