onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/tasks/text2text_generation.py
@@ -0,0 +1,230 @@
+ from typing import Any, Callable, Dict, Optional, Tuple
+ import torch
+ from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     _pick,
+     default_num_hidden_layers as nhl,
+ )
+
+ __TASK__ = "text2text-generation"
+
+
+ def reduce_model_config(config: Any) -> Dict[str, Any]:
+     """Reduces a model size."""
+     kwargs: Dict[str, Any] = {}
+     if hasattr(config, "num_decoder_layers"):
+         config.num_decoder_layers = min(config.num_decoder_layers, 2)
+     if hasattr(config, "num_hidden_layers"):
+         config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+     update_config(config, kwargs)
+     return kwargs
+
+
+ def get_inputs(
+     model: torch.nn.Module,
+     config: Optional[Any],
+     dummy_max_token_id: int,
+     num_key_value_heads_encoder: int,
+     num_key_value_heads_decoder: int,
+     num_hidden_layers: int,
+     head_dim_encoder: int,
+     head_dim_decoder: int,
+     encoder_dim: int,
+     batch_size: int = 2,
+     sequence_length: int = 30,
+     sequence_length2: int = 3,
+     add_second_input: int = 1,
+     **kwargs,  # unused
+ ):
+     """
+     Generates inputs for task ``text2text-generation``.
+
+     :param model: model used to get the missing information
+     :param config: configuration used to generate the model
+     :param head_dim_encoder: last dimension of the cache for the encoder
+     :param head_dim_decoder: last dimension of the cache for the decoder
+     :param num_key_value_heads_encoder: number of heads for the encoder
+     :param num_key_value_heads_decoder: number of heads for the decoder
+     :param dummy_max_token_id: dummy max token id
+     :param batch_size: batch size
+     :param encoder_dim: last dimension of encoder_last_hidden_state
+     :param sequence_length: sequence length
+     :param sequence_length2: new sequence length
+     :return: dictionary
+
+     Inputs captured from one model:
+
+     ::
+
+         cache_position:T7s1
+         past_key_values:EncoderDecoderCache(
+             self_attention_cache=DynamicCache(
+                 key_cache=#6[T1s1x8x1x64,...],
+                 value_cache=#6[T1s1x8x1x64,...]),
+             cross_attention_cache=DynamicCache(
+                 key_cache=#6[T1s1x8x16x64,...],
+                 value_cache=#6[T1s1x8x16x64,...])),
+         decoder_input_ids:T7s1x1,
+         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
+     """
+     assert (
+         "cls_cache" not in kwargs
+     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+     batch = "batch"
+     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+     cache_length = "cache_length_key"
+     cache_length2 = "cache_length_val"
+
+     shapes = {
+         "input_ids": {0: batch, 1: seq_length},
+         "decoder_input_ids": {0: batch, 1: "seq_ids"},
+         "attention_mask": {0: batch, 1: "seq_mask"},
+         # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
+         "past_key_values": [
+             [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+             [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
+         ],
+         # one of these is selected based on the forward method signature
+         # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
+         # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
+     }
+
+     inputs = dict(
+         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+             torch.int64
+         ),
+         decoder_input_ids=torch.randint(
+             0, dummy_max_token_id, (batch_size, sequence_length2)
+         ).to(torch.int64),
+         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+         # cache_position=torch.arange(sequence_length, sequence_length + sequence_length2)
+         # .to(torch.int64)
+         # .expand((batch_size, -1)),
+         past_key_values=make_encoder_decoder_cache(
+             make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(
+                             batch_size,
+                             num_key_value_heads_encoder,
+                             sequence_length,
+                             head_dim_encoder,
+                         ),
+                         torch.randn(
+                             batch_size,
+                             num_key_value_heads_encoder,
+                             sequence_length,
+                             head_dim_encoder,
+                         ),
+                     )
+                     for i in range(num_hidden_layers)
+                 ]
+             ),
+             make_dynamic_cache(
+                 [
+                     (
+                         torch.randn(
+                             batch_size,
+                             num_key_value_heads_decoder,
+                             sequence_length2,
+                             head_dim_decoder,
+                         ),
+                         torch.randn(
+                             batch_size,
+                             num_key_value_heads_decoder,
+                             sequence_length2,
+                             head_dim_decoder,
+                         ),
+                     )
+                     for i in range(num_hidden_layers)
+                 ]
+             ),
+         ),
+         # one of these is selected based on the forward method signature
+         # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
+         # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
+     )
+     res = dict(inputs=inputs, dynamic_shapes=shapes)
+     if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
+         res["inputs2"] = get_inputs(
+             model=model,
+             config=config,
+             dummy_max_token_id=dummy_max_token_id,
+             num_key_value_heads_encoder=num_key_value_heads_encoder,
+             num_key_value_heads_decoder=num_key_value_heads_decoder,
+             num_hidden_layers=num_hidden_layers,
+             head_dim_encoder=head_dim_encoder,
+             head_dim_decoder=head_dim_decoder,
+             encoder_dim=encoder_dim,
+             batch_size=batch_size + 1,
+             sequence_length=sequence_length + add_second_input,
+             sequence_length2=sequence_length2 + 1,
+             add_second_input=0,
+             **kwargs,
+         )["inputs"]
+     return res
+
+
+ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+     """
+     Inputs kwargs.
+
+     If the configuration is None, the function selects typical dimensions.
+     """
+     if config is not None:
+         check_hasattr(
+             config,
+             "vocab_size",
+             "hidden_size",
+             "num_attention_heads",
+             ("num_hidden_layers", "num_layers"),
+             ("n_positions", "d_model"),
+             (
+                 "num_key_value_heads",
+                 "num_heads",
+                 ("decoder_attention_heads", "encoder_attention_heads"),
+             ),
+         )
+     # exceptions = {
+     #     "PLBartForConditionalGeneration": (
+     #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
+     #     )
+     # }
+     kwargs = dict(
+         batch_size=2,
+         sequence_length=30,
+         sequence_length2=3,
+         head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
+         head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
+         dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+         num_hidden_layers=(
+             8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+         ),
+         num_key_value_heads_encoder=(
+             16
+             if config is None
+             else _pick(
+                 config,
+                 "encoder_attention_heads",
+                 "num_key_value_heads",
+                 "num_heads",
+             )
+         ),
+         num_key_value_heads_decoder=(
+             16
+             if config is None
+             else _pick(
+                 config,
+                 "decoder_attention_heads",
+                 "num_key_value_heads",
+                 "num_heads",
+             )
+         ),
+         encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
+     )
+     return kwargs, get_inputs
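
How the two entry points fit together: ``random_input_kwargs`` reads typical dimensions off a configuration and returns them together with ``get_inputs``, which builds the dummy tensors and the matching ``dynamic_shapes`` for ``torch.export``. A minimal sketch, assuming onnx-diagnostic 0.8.0 and transformers are installed; ``t5-small`` is only an illustrative checkpoint whose config exposes the attributes checked above::

    from transformers import AutoConfig
    from onnx_diagnostic.tasks.text2text_generation import random_input_kwargs

    config = AutoConfig.from_pretrained("t5-small")
    kwargs, fct = random_input_kwargs(config)       # typical dimensions read from the config
    res = fct(model=None, config=config, **kwargs)  # model is never dereferenced here
    print(sorted(res["inputs"]))          # attention_mask, decoder_input_ids, input_ids, past_key_values
    print(res["dynamic_shapes"]["input_ids"])  # {0: 'batch', 1: 'seq_length'}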
onnx_diagnostic/tasks/text_classification.py
@@ -0,0 +1,89 @@
+ from typing import Any, Callable, Dict, Optional, Tuple
+ import torch
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     default_num_hidden_layers as nhl,
+ )
+
+ __TASK__ = "text-classification"
+
+
+ def reduce_model_config(config: Any) -> Dict[str, Any]:
+     """Reduces a model size."""
+     check_hasattr(config, "num_attention_heads", "num_hidden_layers")
+     kwargs = dict(
+         num_hidden_layers=min(config.num_hidden_layers, nhl()),
+         num_attention_heads=min(config.num_attention_heads, 4),
+     )
+     update_config(config, kwargs)
+     return kwargs
+
+
+ def get_inputs(
+     model: torch.nn.Module,
+     config: Optional[Any],
+     batch_size: int,
+     sequence_length: int,
+     dummy_max_token_id: int,
+     add_second_input: int = 1,
+     **kwargs,  # unused
+ ):
+     """
+     Generates inputs for task ``text-classification``.
+     Example:
+
+     ::
+
+         input_ids:T7s1x13[101,72654:A16789.23076923077],
+         token_type_ids:T7s1x13[0,0:A0.0],
+         attention_mask:T7s1x13[1,1:A1.0])
+     """
+     assert (
+         "cls_cache" not in kwargs
+     ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+     batch = "batch"
+     seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
+     shapes = {
+         "input_ids": {0: batch, 1: seq_length},
+         "token_type_ids": {0: batch, 1: seq_length},
+         "attention_mask": {0: batch, 1: seq_length},
+     }
+     inputs = dict(
+         input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+             torch.int64
+         ),
+         token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
+         attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+     )
+     res = dict(inputs=inputs, dynamic_shapes=shapes)
+     if add_second_input:
+         assert (
+             add_second_input > 0
+         ), f"Not implemented for add_second_input={add_second_input}."
+         res["inputs2"] = get_inputs(
+             model=model,
+             config=config,
+             batch_size=batch_size + 1,
+             sequence_length=sequence_length + add_second_input,
+             dummy_max_token_id=dummy_max_token_id,
+             add_second_input=0,
+             **kwargs,
+         )["inputs"]
+     return res
+
+
+ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+     """
+     Inputs kwargs.
+
+     If the configuration is None, the function selects typical dimensions.
+     """
+     if config is not None:
+         check_hasattr(config, "vocab_size")
+     kwargs = dict(
+         batch_size=2,
+         sequence_length=30,
+         dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+     )
+     return kwargs, get_inputs
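
The ``config=None`` fallback makes this module exercisable without downloading a checkpoint. A short sketch of that path; the shapes follow from the defaults above (``batch_size=2``, ``sequence_length=30``), and ``inputs2`` bumps both axes by one::

    from onnx_diagnostic.tasks.text_classification import random_input_kwargs

    kwargs, fct = random_input_kwargs(None)  # batch_size=2, sequence_length=30, dummy_max_token_id=31999
    res = fct(model=None, config=None, **kwargs)
    assert res["inputs"]["input_ids"].shape == (2, 30)
    assert res["inputs2"]["input_ids"].shape == (3, 31)  # add_second_input=1 grows both axes
    assert set(res["dynamic_shapes"]) == {"input_ids", "token_type_ids", "attention_mask"}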
onnx_diagnostic/tasks/text_generation.py
@@ -0,0 +1,352 @@
+ from typing import Any, Callable, Dict, Optional, Tuple, Union
+ import torch
+ from ..helpers.cache_helper import (
+     make_dynamic_cache,
+     make_mamba_cache,
+     make_sliding_window_cache,
+     make_static_cache,
+ )
+ from ..helpers.config_helper import (
+     update_config,
+     check_hasattr,
+     _pick,
+     default_num_hidden_layers as nhl,
+ )
+
+ __TASK__ = "text-generation"
+
+
+ def reduce_model_config(config: Any) -> Dict[str, Any]:
+     """Reduces a model size."""
+     # FalconMambaConfig: use_mambapy
+     if hasattr(config, "text_config"):
+         # The model is probably a mixture of models used only for text.
+         config = config.text_config
+     check_hasattr(
+         config,
+         ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
+         "num_hidden_layers",
+         ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
+         "hidden_size",
+         "vocab_size",
+     )
+     if config.__class__.__name__ == "FalconMambaConfig":
+         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
+         kwargs = dict(
+             num_hidden_layers=min(config.num_hidden_layers, nhl()),
+             intermediate_size=256 if config is None else min(512, config.intermediate_size),
+             hidden_size=512 if config is None else min(512, config.hidden_size),
+             cls_cache="MambaCache",
+             state_size=8 if config is None else getattr(config, "state_size", None),
+             conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
+         )
+     else:
+         kwargs = dict(
+             head_dim=getattr(
+                 config, "head_dim", config.hidden_size // config.num_attention_heads
+             ),
+             num_hidden_layers=min(config.num_hidden_layers, nhl()),
+             num_key_value_heads=(
+                 config.num_key_value_heads
+                 if hasattr(config, "num_key_value_heads")
+                 else config.num_attention_heads
+             ),
+         )
+     update_config(config, kwargs)
+     return kwargs
+
+
+ def get_inputs(
+     model: torch.nn.Module,
+     config: Optional[Any],
+     dummy_max_token_id: int,
+     num_hidden_layers: int,
+     batch_size: int = 2,
+     sequence_length: int = 30,
+     sequence_length2: int = 3,
+     dynamic_rope: bool = False,
+     num_key_value_heads: Optional[int] = None,
+     head_dim: Optional[int] = None,
+     cls_cache: Optional[Union[type, str]] = None,
+     add_second_input: int = 1,
+     **kwargs,  # used by the FalconMamba path (conv_kernel, state_size, intermediate_size)
+ ):
+     """
+     Generates inputs for task ``text-generation``.
+
+     :param model: model used to get the missing information
+     :param config: configuration used to generate the model
+     :param head_dim: last dimension of the cache
+     :param dummy_max_token_id: dummy max token id
+     :param batch_size: batch size
+     :param sequence_length: sequence length
+     :param sequence_length2: new sequence length
+     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+     :param cls_cache: cache class, by default it is
+         :class:`transformers.cache_utils.DynamicCache`
+     :return: dictionary
+     """
+     batch = "batch"
+     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+
+     if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+         try:
+             from transformers.models.mamba.modeling_mamba import MambaCache
+         except ImportError:
+             from transformers.cache_utils import MambaCache
+
+         assert cls_cache in (
+             "MambaCache",
+             MambaCache,
+         ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+         seq_length_multiple = 8
+         sequence_length = (
+             (sequence_length + seq_length_multiple)
+             // seq_length_multiple
+             * seq_length_multiple
+         )
+         # sequence_inc = seq_length_multiple
+         sequence_length2 = seq_length_multiple
+
+         shapes = {
+             "input_ids": {0: batch, 1: "sequence_length"},
+             "attention_mask": {
+                 0: batch,
+                 1: "cache+seq",  # cache_length + seq_length
+             },
+             "cache_position": {
+                 0: batch,
+                 1: "cache+seq",  # cache_length + seq_length
+             },
+             "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+         }
+         inputs = dict(
+             input_ids=torch.randint(
+                 0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+             ).to(torch.int64),
+             attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                 torch.int64
+             ),
+             cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+             # .expand((batch_size, -1))
+             cache_params=make_mamba_cache(
+                 [
+                     (
+                         torch.randn(
+                             batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                         ),
+                         torch.randn(
+                             batch_size, kwargs["intermediate_size"], kwargs["state_size"]
+                         ),
+                     )
+                     for i in range(num_hidden_layers)
+                 ]
+             ),
+         )
+         res = dict(inputs=inputs, dynamic_shapes=shapes)
+     else:
+         if head_dim is None:
+             assert config, "head_dim is None, the value cannot be set without a configuration"
+             head_dim = config.hidden_size // config.num_attention_heads
+
+         cache_name = (
+             cls_cache
+             if cls_cache is None or isinstance(cls_cache, str)
+             else cls_cache.__name__
+         )
+         make_caches = {
+             "DynamicCache": make_dynamic_cache,
+             "SlidingWindowCache": make_sliding_window_cache,
+             "StaticCache": make_static_cache,
+         }
+         assert cache_name is None or cache_name in make_caches, (
+             f"Unable to handle cls_cache={cache_name!r}, it should be in "
+             f"{sorted(make_caches)}"
+         )
+         make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
+         is_static = cache_name == "StaticCache"
+
+         if is_static:
+             # static
+             shapes = {
+                 "input_ids": {0: batch, 1: seq_length},
+                 "attention_mask": {0: batch, 2: "seq"},
+                 "cache_position": {0: "seq"},
+                 "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
+             }
+             inputs = dict(
+                 input_ids=torch.randint(
+                     0, dummy_max_token_id, (batch_size, sequence_length2)
+                 ).to(torch.int64),
+                 attention_mask=torch.ones(
+                     (batch_size, num_key_value_heads, sequence_length2, head_dim)
+                 ).to(torch.bool),
+                 cache_position=torch.arange(sequence_length2).to(torch.int64),
+                 past_key_values=make_static_cache(
+                     [
+                         (
+                             torch.randn(
+                                 batch_size,
+                                 num_key_value_heads,
+                                 sequence_length + sequence_length2,
+                                 head_dim,
+                             ),
+                             torch.randn(
+                                 batch_size,
+                                 num_key_value_heads,
+                                 sequence_length + sequence_length2,
+                                 head_dim,
+                             ),
+                         )
+                         for i in range(num_hidden_layers)
+                     ],
+                     max_cache_len=max(sequence_length + sequence_length2, head_dim),
+                 ),
+             )
+         else:
+             # dynamic
+             shapes = {
+                 "input_ids": {0: batch, 1: seq_length},
+                 "attention_mask": {
+                     0: batch,
+                     1: "cache+seq",  # cache_length + seq_length
+                 },
+                 "position_ids": {0: batch, 1: seq_length},
+                 "past_key_values": [
+                     {0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)
+                 ],
+             }
+
+             inputs = dict(
+                 input_ids=torch.randint(
+                     0, dummy_max_token_id, (batch_size, sequence_length2)
+                 ).to(torch.int64),
+                 attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                     torch.int64
+                 ),
+                 position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+                 .to(torch.int64)
+                 .expand((batch_size, -1)),
+                 past_key_values=make_cache(  # type: ignore[operator]
+                     [
+                         (
+                             torch.randn(
+                                 batch_size, num_key_value_heads, sequence_length, head_dim
+                             ),
+                             torch.randn(
+                                 batch_size, num_key_value_heads, sequence_length, head_dim
+                             ),
+                         )
+                         for i in range(num_hidden_layers)
+                     ]
+                 ),
+             )
+         res = dict(inputs=inputs, dynamic_shapes=shapes)
+     if add_second_input:
+         res["inputs2"] = get_inputs(
+             model=model,
+             config=config,
+             dummy_max_token_id=dummy_max_token_id,
+             num_hidden_layers=num_hidden_layers,
+             batch_size=(batch_size + 1) if add_second_input > 0 else 1,
+             sequence_length=sequence_length + 1,
+             sequence_length2=sequence_length2
+             + (add_second_input if add_second_input > 0 else -add_second_input),
+             dynamic_rope=dynamic_rope,
+             num_key_value_heads=num_key_value_heads,
+             head_dim=head_dim,
+             cls_cache=cls_cache,
+             add_second_input=0,
+             **kwargs,
+         )["inputs"]
+         res["inputs_empty_cache"] = get_inputs(
+             model=model,
+             config=config,
+             dummy_max_token_id=dummy_max_token_id,
+             num_hidden_layers=num_hidden_layers,
+             batch_size=batch_size,
+             sequence_length=0,
+             sequence_length2=sequence_length2,
+             dynamic_rope=dynamic_rope,
+             num_key_value_heads=num_key_value_heads,
+             head_dim=head_dim,
+             cls_cache=cls_cache,
+             add_second_input=0,
+             **kwargs,
+         )["inputs"]
+         res["inputs_batch1"] = get_inputs(
+             model=model,
+             config=config,
+             dummy_max_token_id=dummy_max_token_id,
+             num_hidden_layers=num_hidden_layers,
+             batch_size=1,
+             sequence_length=sequence_length,
+             sequence_length2=sequence_length2,
+             dynamic_rope=dynamic_rope,
+             num_key_value_heads=num_key_value_heads,
+             head_dim=head_dim,
+             cls_cache=cls_cache,
+             add_second_input=0,
+             **kwargs,
+         )["inputs"]
+     return res
+
+
+ def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+     """
+     Inputs kwargs.
+
+     If the configuration is None, the function selects typical dimensions.
+     """
+     if hasattr(config, "text_config"):
+         # The model is probably a mixture of models used only for text.
+         config = config.text_config
+     if config is not None:
+         check_hasattr(
+             config,
+             "vocab_size",
+             ("num_attention_heads", "use_mambapy"),
+             ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
+             "hidden_size",
+         )
+     if config.__class__.__name__ == "FalconMambaConfig":
+         check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
+         kwargs = dict(
+             batch_size=2,
+             sequence_length=30,
+             sequence_length2=3,
+             dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+             num_hidden_layers=4 if config is None else config.num_hidden_layers,
+             intermediate_size=256 if config is None else config.intermediate_size,
+             cls_cache="MambaCache",
+             state_size=8 if config is None else getattr(config, "state_size", None),
+             conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
+         )
+     else:
+         kwargs = dict(
+             batch_size=2,
+             sequence_length=30,
+             sequence_length2=3,
+             head_dim=(
+                 16
+                 if config is None
+                 else getattr(
+                     config, "head_dim", config.hidden_size // config.num_attention_heads
+                 )
+             ),
+             dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+             num_hidden_layers=4 if config is None else config.num_hidden_layers,
+             num_key_value_heads=(
+                 24
+                 if config is None
+                 else _pick(config, "num_key_value_heads", "num_attention_heads")
+             ),
+             hidden_size=512 if config is None else config.hidden_size,
+         )
+     if config is None or hasattr(config, "intermediate_size"):
+         kwargs["intermediate_size"] = (
+             1024 if config is None else config.intermediate_size
+         )
+
+     return kwargs, get_inputs
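
Unlike the other two tasks, ``get_inputs`` here returns three extra input sets (``inputs2``, ``inputs_empty_cache``, ``inputs_batch1``) so a validator can probe a second shape, an empty cache, and batch size 1 against the same ``dynamic_shapes``. A sketch of the default ``DynamicCache`` path with ``config=None``, which skips the FalconMamba branch (assumes onnx-diagnostic 0.8.0 and transformers are installed)::

    from onnx_diagnostic.tasks.text_generation import random_input_kwargs

    kwargs, fct = random_input_kwargs(None)  # head_dim=16, num_key_value_heads=24, 4 layers
    res = fct(model=None, config=None, **kwargs)
    print(sorted(res))
    # ['dynamic_shapes', 'inputs', 'inputs2', 'inputs_batch1', 'inputs_empty_cache']
    # input_ids holds only the new tokens; the past lives in the cache,
    # one (2, 24, 30, 16) key/value pair per layer.
    print(res["inputs"]["input_ids"].shape)               # (2, 3)
    print(type(res["inputs"]["past_key_values"]).__name__)  # DynamicCache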