onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,170 @@
+ import functools
+ import importlib
+ import inspect
+ import os
+ import re
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
+ import transformers
+
+
+ def check_hasattr(config: Any, *args: Union[str, Tuple[Any, ...]]):
+     """
+     Checks that the configuration has all the attributes in ``args``.
+     Raises an exception otherwise.
+     """
+     for a in args:
+         assert isinstance(a, (str, tuple)), f"unexpected type {type(a)} in {args!r}"
+         if isinstance(a, str):
+             assert (isinstance(config, dict) and a in config) or hasattr(
+                 config, a
+             ), f"Missing attribute {a!r} in\n{config}"
+         elif isinstance(a, tuple):
+             assert any(
+                 (isinstance(name, str) and hasattr(config, name))
+                 or all(hasattr(config, _) for _ in name)
+                 for name in a
+             ), f"All attributes in {a!r} are missing from\n{config}"
+
+
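Note: a minimal usage sketch for check_hasattr (not part of the diff; the config object below is hypothetical). A plain string must be present; a tuple passes as soon as one alternative exists:

    from types import SimpleNamespace

    config = SimpleNamespace(hidden_size=64, num_attention_heads=4)
    check_hasattr(config, "hidden_size")                      # present -> no error
    check_hasattr(config, ("n_head", "num_attention_heads"))  # one alternative is enough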
+ def update_config(config: Any, mkwargs: Dict[str, Any]):
+     """Updates a configuration with different values."""
+     for k, v in mkwargs.items():
+         if k == "attn_implementation":
+             config._attn_implementation = v
+             if getattr(config, "_attn_implementation_autoset", False):
+                 config._attn_implementation_autoset = False
+             continue
+         if isinstance(v, dict):
+             if not hasattr(config, k) or getattr(config, k) is None:
+                 setattr(config, k, v)
+                 continue
+             existing = getattr(config, k)
+             if type(existing) is dict:
+                 existing.update(v)
+             else:
+                 update_config(getattr(config, k), v)
+             continue
+         if type(config) is dict:
+             config[k] = v
+         else:
+             setattr(config, k, v)
+
+
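Note: a sketch of how update_config merges values (hypothetical config below); dictionary values update an existing dict attribute in place, replace a missing or None attribute, and otherwise recurse into the sub-config:

    from types import SimpleNamespace

    config = SimpleNamespace(num_hidden_layers=32, rope_scaling=None)
    update_config(config, {"num_hidden_layers": 2, "rope_scaling": {"factor": 8.0}})
    assert config.num_hidden_layers == 2
    assert config.rope_scaling == {"factor": 8.0}  # None attribute replaced by the dict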
+ def _pick(config, *atts, exceptions: Optional[Dict[str, Callable]] = None):
+     """Returns the first value found in the configuration."""
+     if (
+         exceptions
+         and hasattr(config, "architectures")
+         and len(config.architectures) == 1
+         and config.architectures[0] in exceptions
+     ):
+         excs = exceptions[config.architectures[0]]
+         return excs(config)
+     for a in atts:
+         if isinstance(a, str):
+             if hasattr(config, a):
+                 return getattr(config, a)
+         elif isinstance(a, tuple):
+             if all(hasattr(config, _) for _ in a[1:]):
+                 return a[0]([getattr(config, _) for _ in a[1:]])
+     raise AssertionError(f"Unable to find any of these {atts!r} in {config}")
+
+
+ def pick(config, name: str, default_value: Any) -> Any:
+     """
+     Returns the value of an attribute if the config has it,
+     otherwise the default value.
+     """
+     if not config:
+         return default_value
+     if type(config) is dict:
+         return config.get(name, default_value)
+     return getattr(config, name, default_value)
+
+
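Note: sketches for _pick and pick (hypothetical values); _pick returns the first alternative found, and a leading callable in a tuple combines several attributes:

    from types import SimpleNamespace

    cfg = SimpleNamespace(num_attention_heads=8, num_key_value_heads=2)
    assert _pick(cfg, "n_head", "num_attention_heads") == 8
    assert _pick(cfg, (min, "num_attention_heads", "num_key_value_heads")) == 2
    assert pick({"hidden_size": 64}, "hidden_size", 16) == 64
    assert pick(None, "hidden_size", 16) == 16  # falsy config -> default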
+ @functools.cache
+ def config_class_from_architecture(arch: str, exc: bool = False) -> Optional[type]:
+     """
+     Retrieves the configuration class for a given architecture.
+
+     :param arch: architecture (class name)
+     :param exc: raise an exception if not found
+     :return: type
+     """
+     cls = getattr(transformers, arch)
+     mod_name = cls.__module__
+     mod = importlib.import_module(mod_name)
+     source = inspect.getsource(mod)
+     # [^O] avoids capturing Optional[Something]
+     reg = re.compile("config: ([^O][A-Za-z0-9]+)")
+     fall = reg.findall(source)
+     if len(fall) == 0:
+         assert not exc, (
+             f"Unable to guess Configuration class name for arch={arch!r}, "
+             f"module={mod_name!r}, no candidate, source is\n{source}"
+         )
+         return None
+     unique = set(fall)
+     assert len(unique) == 1, (
+         f"Unable to guess Configuration class name for arch={arch!r}, "
+         f"module={mod_name!r}, found={unique} (#{len(unique)}), "
+         f"source is\n{source}"
+     )
+     cls_name = unique.pop()
+     return getattr(transformers, cls_name)
+
+
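Note: a sketch of config_class_from_architecture (requires transformers to be installed); the class is recovered from annotations in the module source, so the exact result depends on the installed version, and the lookup fails if several candidates appear:

    cls = config_class_from_architecture("LlamaForCausalLM")
    # typically <class 'transformers.models.llama.configuration_llama.LlamaConfig'>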
+ def default_num_hidden_layers():
+     """
+     Returns the default number of layers.
+     It is lower when the unit tests are running
+     (``UNITTEST_GOING=1``).
+     """
+     import torch
+
+     if torch.cuda.is_available():
+         capa = torch.cuda.get_device_capability(0)
+         if capa[0] < 9:
+             return 2
+     return 2 if os.environ.get("UNITTEST_GOING", "0") == "1" else 4
+
+
+ def build_diff_config(config0, config1):
+     """
+     Returns all the modified values between two configurations.
+     """
+     import torch
+
+     diff = {}
+     for k in config0:
+         assert isinstance(k, str), f"k={k!r}, wrong type in {config0}"
+         if k not in config1:
+             v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+             diff[k] = f"-{v0}"
+     for k in config1:
+         assert isinstance(k, str), f"k={k!r}, wrong type in {config1}"
+         if k not in config0:
+             v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+             diff[k] = f"+{v1}"
+     for k in config0:
+         if k not in config1:
+             continue
+         v0 = getattr(config0, k) if hasattr(config0, k) else config0[k]
+         v1 = getattr(config1, k) if hasattr(config1, k) else config1[k]
+         if (
+             v0 is None
+             or v1 is None
+             or isinstance(v1, (float, int, bool, str, list, tuple, torch.dtype))
+             or (
+                 isinstance(v0, dict)
+                 and isinstance(v1, dict)
+                 and all(isinstance(k, int) for k in v1)
+             )
+         ):
+             if v1 != v0:
+                 diff[k] = f"{v0} -> {v1}"
+         else:
+             d = build_diff_config(v0, v1)
+             if d:
+                 diff[k] = d
+     return diff
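Note: a sketch of build_diff_config on plain dictionaries (hypothetical values, requires torch); removed keys are prefixed with -, added keys with +, and modified values are rendered as old -> new:

    c0 = {"num_hidden_layers": 32, "hidden_size": 4096}
    c1 = {"num_hidden_layers": 2, "hidden_size": 4096, "head_dim": 128}
    print(build_diff_config(c0, c1))
    # {'head_dim': '+128', 'num_hidden_layers': '32 -> 2'}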
@@ -0,0 +1,163 @@
+ import os
+ from typing import Dict, List, Optional, Tuple
+ import onnx
+ import onnx.helper as oh
+ import torch
+ from ..reference.torch_ops import OpRunKernel, OpRunTensor
+ from .torch_helper import onnx_dtype_to_torch_dtype, torch_dtype_to_onnx_dtype
+ from .ort_session import InferenceSessionForTorch
+
+ _SAVED: List[str] = []
+ _SAVE_OPTIMIZED_MODEL_ = int(os.environ.get("DUMP_ONNX", "0"))
+
+
+ def _get_model_name(op_name: str, provider: str) -> Optional[str]:
+     if _SAVE_OPTIMIZED_MODEL_:
+         # use op_name so MatMul dumps are not named 'layer_norm'
+         name = f"dump_doc_{op_name}_{provider}_{len(_SAVED)}.onnx"
+         _SAVED.append(name)
+         return name
+     return None
+
+
+ class LayerNormalizationOrt(OpRunKernel):
+     "LayerNormalization with onnxruntime"
+
+     @classmethod
+     def device_dependent(cls) -> bool:
+         "Needs device."
+         return True
+
+     def __init__(
+         self,
+         node: onnx.NodeProto,
+         version=None,
+         device: Optional[torch.device] = None,
+         verbose: int = 0,
+     ):
+         super().__init__(node, version, verbose=verbose)
+         self.axis = self.get_attribute_int(node, "axis", -1)
+         self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
+         self.device = device
+         self.stash_type = onnx_dtype_to_torch_dtype(
+             self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)  # type: ignore[arg-type]
+         )
+         self.compute_std = len(node.output) > 1
+         assert not self.compute_std, (
+             f"This kernel implementation only works when a single output "
+             f"is required but the node declares outputs {node.output}."
+         )
+         self._cache: Dict[Tuple[int, int], InferenceSessionForTorch] = {}
+         self.is_cpu = torch.device("cpu") == self.device
+
+     def _make_model(self, itype: int, rank: int, has_bias: bool) -> InferenceSessionForTorch:
+         # f-strings: every leading dimension gets its own symbolic name
+         shape = [*[f"d{i}" for i in range(rank - 1)], "last"]
+         layer_model = oh.make_model(
+             oh.make_graph(
+                 [
+                     oh.make_node(
+                         "LayerNormalization",
+                         ["X", "W", "B"] if has_bias else ["X", "W"],
+                         ["Z"],
+                         axis=self.axis,
+                         epsilon=self.epsilon,
+                     )
+                 ],
+                 "dummy",
+                 (
+                     [
+                         oh.make_tensor_value_info("X", itype, shape),
+                         oh.make_tensor_value_info("W", itype, ["last"]),
+                         oh.make_tensor_value_info("B", itype, ["last"]),
+                     ]
+                     if has_bias
+                     else [
+                         oh.make_tensor_value_info("X", itype, shape),
+                         oh.make_tensor_value_info("W", itype, ["last"]),
+                     ]
+                 ),
+                 [oh.make_tensor_value_info("Z", itype, shape)],
+             ),
+             ir_version=9,
+             opset_imports=[oh.make_opsetid("", 18)],
+         )
+         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+         self._provider = provider
+         return InferenceSessionForTorch(
+             layer_model,
+             optimized_model_filepath=_get_model_name("layer_norm", provider),
+             providers=[provider],
+         )
+
+     def run(self, x, scale, bias=None):
+         itype = torch_dtype_to_onnx_dtype(x.dtype)
+         rank = len(x.shape)
+         key = itype, rank
+         if key not in self._cache:
+             self._cache[key] = self._make_model(itype, rank, bias is not None)
+         sess = self._cache[key]
+         if self.verbose:
+             print(f"[LayerNormalizationOrt] running on {self._provider!r}")
+         feeds = dict(X=x.tensor, W=scale.tensor)
+         if bias is not None:
+             feeds["B"] = bias.tensor
+         got = sess.run(None, feeds)[0]
+         return OpRunTensor(got)
+
+
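Note: a minimal sketch of how this kernel might be driven directly (assuming OpRunTensor simply wraps a torch.Tensor, as run() suggests); the first call builds and caches a one-node model for the (dtype, rank) key:

    import onnx.helper as oh
    import torch

    node = oh.make_node("LayerNormalization", ["X", "W"], ["Z"], axis=-1, epsilon=1e-5)
    kernel = LayerNormalizationOrt(node, device=torch.device("cpu"))
    x = OpRunTensor(torch.rand(2, 4, 8))
    w = OpRunTensor(torch.rand(8))
    z = kernel.run(x, w)  # Z has the same shape and dtype as X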
+ class MatMulOrt(OpRunKernel):
+     "MatMul with onnxruntime"
+
+     @classmethod
+     def device_dependent(cls) -> bool:
+         "Needs device."
+         return True
+
+     def __init__(
+         self,
+         node: onnx.NodeProto,
+         version=None,
+         device: Optional[torch.device] = None,
+         verbose: int = 0,
+     ):
+         super().__init__(node, version, verbose=verbose)
+         self.device = device
+         self._cache: Dict[Tuple[int, int, int], InferenceSessionForTorch] = {}
+         self.is_cpu = torch.device("cpu") == self.device
+
+     def _make_model(self, itype: int, ranka: int, rankb: int) -> InferenceSessionForTorch:
+         shapea = [f"a{i}" for i in range(ranka)]
+         shapeb = [f"b{i}" for i in range(rankb)]
+         shapec = [f"c{i}" for i in range(max(ranka, rankb))]
+         model = oh.make_model(
+             oh.make_graph(
+                 [oh.make_node("MatMul", ["A", "B"], ["C"])],
+                 "dummy",
+                 [
+                     oh.make_tensor_value_info("A", itype, shapea),
+                     oh.make_tensor_value_info("B", itype, shapeb),
+                 ],
+                 [oh.make_tensor_value_info("C", itype, shapec)],
+             ),
+             ir_version=9,
+             opset_imports=[oh.make_opsetid("", 18)],
+         )
+         provider = "CPUExecutionProvider" if self.is_cpu else "CUDAExecutionProvider"
+         self._provider = provider
+         return InferenceSessionForTorch(
+             model,
+             optimized_model_filepath=_get_model_name("matmul", provider),
+             providers=[provider],
+         )
+
+     def run(self, a, b):
+         itype = torch_dtype_to_onnx_dtype(a.dtype)
+         ranka, rankb = len(a.shape), len(b.shape)
+         key = itype, ranka, rankb
+         if key not in self._cache:
+             self._cache[key] = self._make_model(itype, ranka, rankb)
+         sess = self._cache[key]
+         if self.verbose:
+             print(f"[MatMulOrt] running on {self._provider!r}")
+         feeds = dict(A=a.tensor, B=b.tensor)
+         got = sess.run(None, feeds)[0]
+         return OpRunTensor(got)
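Note: the same pattern for MatMulOrt (same OpRunTensor assumption); sessions are cached per (dtype, rank-of-A, rank-of-B), so the second call below is a cache hit:

    import onnx.helper as oh
    import torch

    node = oh.make_node("MatMul", ["A", "B"], ["C"])
    kernel = MatMulOrt(node, device=torch.device("cpu"))
    a, b = OpRunTensor(torch.rand(2, 3, 4)), OpRunTensor(torch.rand(4, 5))
    c = kernel.run(a, b)   # builds the session for this key
    c2 = kernel.run(a, b)  # reuses it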
@@ -0,0 +1,273 @@
+ from typing import Any, Dict, Optional, Set, Tuple
+
+
+ class FakeTensorContext:
+     """Stores information used to reuse the same dimension for the same dimension names."""
+
+     def __init__(self, fake_mode: Optional["FakeTensorMode"] = None):  # noqa: F821
+         if fake_mode is None:
+             from torch.fx.experimental.symbolic_shapes import ShapeEnv
+             from torch._subclasses.fake_tensor import FakeTensorMode
+
+             shape_env = ShapeEnv()
+             self.fake_mode = FakeTensorMode(shape_env=shape_env)
+         else:
+             self.fake_mode = fake_mode
+         self._candidates = self._first_primes()
+         self._unique_: Set[int] = set()
+         self._mapping_int: Dict[int, Any] = {}
+         self._mapping_str: Dict[str, int] = {}
+
+     @classmethod
+     def _first_primes(cls, n=1000):
+         sieve = [True] * (n + 1)
+         sieve[0:2] = [False, False]
+
+         for i in range(2, int(n**0.5) + 1):
+             if sieve[i]:
+                 # eliminate the multiples of i
+                 sieve[i * i : n + 1 : i] = [False] * len(range(i * i, n + 1, i))
+
+         return [i for i, prime in enumerate(sieve) if prime and i >= 13]
+
+     def _unique(self) -> int:
+         i = 0
+         c = self._candidates[i]
+         while c in self._unique_ or c in self._mapping_int:
+             i += 1
+             assert i < len(
+                 self._candidates
+             ), f"Too many unique dimensions to generate, requested: {len(self._unique_)}"
+             c = self._candidates[i]
+         self._unique_.add(c)
+         return c
+
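Note: a sketch of the unique-dimension machinery (requires torch); candidate sizes are primes >= 13, so synthetic dimensions are unlikely to collide with real sizes or with each other:

    ctx = FakeTensorContext()
    print(ctx._first_primes()[:5])       # [13, 17, 19, 23, 29]
    print(ctx._unique(), ctx._unique())  # 13 17 (each prime is handed out once)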
+     def from_tensor(self, x, static_shapes=False) -> "FakeTensor":  # noqa: F821
+         """
+         Returns a fake tensor.
+         ``pytorch`` returns the same name for the same dimension.
+         """
+         fake = self.fake_mode.from_tensor(x, static_shapes=static_shapes)
+         for i, s in zip(x.shape, fake.shape):
+             assert i not in self._mapping_int or self._mapping_int[i] == s, (
+                 f"Inconsistency between {x.shape} and {fake.shape}, "
+                 f"mapping has {self._mapping_int[i]} and s={s}"
+             )
+             self._mapping_int[i] = s
+         return fake
+
+     def fake_reshape(
+         self,
+         true_tensor: "torch.Tensor",  # noqa: F821
+         sh: Dict[int, Any],  # noqa: F821
+         fake_tensor: Optional["FakeTensor"] = None,  # noqa: F821
+     ) -> "FakeTensor":  # noqa: F821
+         """
+         Changes the shape of a true tensor to make it dynamic.
+
+         :param true_tensor: true tensor
+         :param sh: dynamic shape
+         :param fake_tensor: fake tensor, if None, make a fake one
+         :return: fake tensor
+         """
+         import torch
+
+         # deal with 0/1
+         for i in sh:
+             if true_tensor.shape[i] <= 1:
+                 expanded_shape = list(true_tensor.shape)
+                 expanded_shape[i] = self._unique()
+                 true_tensor = torch.empty(
+                     tuple(expanded_shape), dtype=true_tensor.dtype, device=true_tensor.device
+                 )
+
+         # deal with equivalent dimension
+         new_shape = list(true_tensor.shape)
+         mapping = {}
+         for i, s in sh.items():
+             d = true_tensor.shape[i]
+             if d not in mapping:
+                 mapping[d] = s
+             elif mapping[d] != s:
+                 d = self._unique()
+                 mapping[d] = s
+             new_shape[i] = d
+         true_tensor = torch.empty(
+             tuple(new_shape), dtype=true_tensor.dtype, device=true_tensor.device
+         )
+
+         # now switch to FakeTensor
+         fake_tensor = self.from_tensor(true_tensor, static_shapes=False)
+         new_shape = list(true_tensor.shape)
+         for i in sh:
+             new_shape[i] = fake_tensor.shape[i]
+
+         reduced_tensor = self.from_tensor(true_tensor, static_shapes=True).sum(
+             axis=tuple(sorted(sh)), keepdim=True
+         )
+         return reduced_tensor.expand(*new_shape)
+
+     def make_fake(self, x: Any) -> Optional["FakeTensor"]:  # noqa: F821
+         """See :func:`onnx_diagnostic.helpers.fake_tensor_helper.make_fake`."""
+         if x is None:
+             return None
+         if isinstance(x, (list, tuple)):
+             return x.__class__([self.make_fake(i) for i in x])
+         if isinstance(x, dict):
+             return {k: self.make_fake(v) for k, v in x.items()}
+         if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+             assert hasattr(x, "layers"), (
+                 f"Use a more recent version of transformers (>=4.55), "
+                 f"'layers' not found in class {type(x)}"
+             )
+             for layer in x.layers:
+                 assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                     f"Use a more recent version of transformers (>=4.55), 'keys' or "
+                     f"'values' not found in class {type(layer)} ({dir(layer)})"
+                 )
+                 layer.keys = self.make_fake(layer.keys)
+                 layer.values = self.make_fake(layer.values)
+             return x
+         if x.__class__.__name__ == "EncoderDecoderCache":
+             self.make_fake(x.self_attention_cache)
+             self.make_fake(x.cross_attention_cache)
+             return x
+         if hasattr(x, "shape"):
+             return self.from_tensor(x, static_shapes=False)
+         from . import string_type
+
+         raise TypeError(
+             f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+         )
+
+     def make_fake_with_dynamic_dimensions(self, x: Any, dynamic_shapes: Any) -> Any:
+         """
+         See
+         :func:`onnx_diagnostic.export.shape_helper.make_fake_with_dynamic_dimensions`.
+         """
+         if x is None:
+             return None
+         if isinstance(x, (list, tuple)):
+             return x.__class__(
+                 [
+                     self.make_fake_with_dynamic_dimensions(i, dynamic_shapes=ds)
+                     for i, ds in zip(x, dynamic_shapes)
+                 ]
+             )
+         if isinstance(x, dict):
+             return {
+                 k: self.make_fake_with_dynamic_dimensions(v, dynamic_shapes=dynamic_shapes[k])
+                 for k, v in x.items()
+             }
+         if x.__class__.__name__ in {"DynamicCache", "StaticCache", "HybridCache"}:
+             assert hasattr(x, "layers"), (
+                 f"Use a more recent version of transformers (>=4.55), "
+                 f"'layers' not found in class {type(x)}"
+             )
+             assert isinstance(dynamic_shapes, list) and (
+                 not dynamic_shapes or not isinstance(dynamic_shapes[0], list)
+             ), f"Unexpected dynamic_shapes={dynamic_shapes} for a DynamicCache"
+             for il, layer in enumerate(x.layers):
+                 assert hasattr(layer, "keys") and hasattr(layer, "values"), (
+                     f"Use a more recent version of transformers (>=4.55), 'keys' or "
+                     f"'values' not found in class {type(layer)} ({dir(layer)})"
+                 )
+                 layer.keys = self.make_fake_with_dynamic_dimensions(
+                     layer.keys, dynamic_shapes=dynamic_shapes[il * 2]
+                 )
+                 layer.values = self.make_fake_with_dynamic_dimensions(
+                     layer.values, dynamic_shapes=dynamic_shapes[il * 2 + 1]
+                 )
+             return x
+         if x.__class__.__name__ == "EncoderDecoderCache":
+             self.make_fake_with_dynamic_dimensions(
+                 x.self_attention_cache, dynamic_shapes=dynamic_shapes[0]
+             )
+             self.make_fake_with_dynamic_dimensions(
+                 x.cross_attention_cache, dynamic_shapes=dynamic_shapes[1]
+             )
+             return x
+         if hasattr(x, "shape"):
+             assert dynamic_shapes is None or isinstance(dynamic_shapes, dict), (
+                 f"dynamic_shapes must be a dictionary at this stage but "
+                 f"dynamic_shapes={dynamic_shapes}"
+             )
+             # We need to overwrite the values.
+             new_shape = []
+             for idim, dim in enumerate(x.shape):
+                 if dynamic_shapes is not None and idim in dynamic_shapes:
+                     s = dynamic_shapes[idim]
+                     assert isinstance(s, str), (
+                         f"Unexpected type {type(s)} in dynamic_shapes={dynamic_shapes} "
+                         f"at index {idim}"
+                     )
+                     if s in self._mapping_str:
+                         dim = self._mapping_str[s]
+                     else:
+                         i = self._unique()
+                         self._mapping_str[s] = i
+                         dim = i
+                 assert isinstance(dim, int), (
+                     f"Unexpected type {type(dim)}, dynamic_shapes={dynamic_shapes} "
+                     f"at index {idim}, dim={dim}"
+                 )
+                 new_shape.append(dim)
+             if tuple(new_shape) != x.shape:
+                 import torch
+
+                 x = torch.empty(tuple(new_shape), dtype=x.dtype, device=x.device)
+
+             t = self.fake_reshape(x, dynamic_shapes)  # type: ignore[arg-type]
+             assert t.device == x.device, f"device mismatch {x.device} -> {t.device}"
+             assert t.dtype == x.dtype, f"dtype mismatch {x.dtype} -> {t.dtype}"
+             return t
+         from ..helpers import string_type
+
+         raise TypeError(
+             f"Unexpected type {type(x)} for x, content is {string_type(x, with_shape=True)}"
+         )
+
+
+ def make_fake(
+     x: Any, context: Optional[FakeTensorContext] = None
+ ) -> Tuple[Optional["FakeTensor"], Optional[FakeTensorContext]]:  # noqa: F821
+     """
+     Replaces all tensors with fake tensors.
+     This modification happens in place for caches.
+     This function is only implemented for caches with
+     ``transformers>=4.55``.
+
+     .. runpython::
+         :showcode:
+
+         import pprint
+         import torch
+         from onnx_diagnostic.helpers.cache_helper import make_dynamic_cache
+         from onnx_diagnostic.helpers.fake_tensor_helper import make_fake
+
+         inputs, _ = make_fake(
+             dict(
+                 input_ids=torch.randint(30360, size=(2, 3), dtype=torch.int64),
+                 attention_mask=torch.randint(1, size=(2, 33), dtype=torch.int64),
+                 position_ids=torch.randint(32, size=(2, 3), dtype=torch.int64),
+                 past_key_values=make_dynamic_cache(
+                     [
+                         (
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                         ),
+                         (
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                             torch.rand((2, 32, 30, 96), dtype=torch.float16),
+                         ),
+                     ]
+                 ),
+             )
+         )
+         pprint.pprint(inputs)
+     """
+     if x is None:
+         return None, None
+     if context is None:
+         context = FakeTensorContext()
+     return context.make_fake(x), context
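Note: a sketch of make_fake_with_dynamic_dimensions (requires torch; relies on the context reusing one symbolic size per dimension name, as the code above implements):

    import torch
    from onnx_diagnostic.helpers.fake_tensor_helper import FakeTensorContext

    ctx = FakeTensorContext()
    fake = ctx.make_fake_with_dynamic_dimensions(
        dict(
            input_ids=torch.randint(100, size=(2, 3)),
            attention_mask=torch.randint(1, size=(2, 3)),
        ),
        dynamic_shapes=dict(
            input_ids={0: "batch", 1: "seq"},
            attention_mask={0: "batch", 1: "seq"},
        ),
    )
    # both fake tensors share the same symbolic "batch" and "seq" dimensions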