onnx-diagnostic 0.8.0 (onnx_diagnostic-0.8.0-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0

onnx_diagnostic/tasks/text2text_generation.py (+230 -0)

@@ -0,0 +1,230 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.cache_helper import make_dynamic_cache, make_encoder_decoder_cache
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "text2text-generation"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if hasattr(config, "num_decoder_layers"):
+        config.num_decoder_layers = min(config.num_decoder_layers, 2)
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads_encoder: int,
+    num_key_value_heads_decoder: int,
+    num_hidden_layers: int,
+    head_dim_encoder: int,
+    head_dim_decoder: int,
+    encoder_dim: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``text2text-generation``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim_encoder: last dimension of the cache for the encoder
+    :param head_dim_decoder: last dimension of the cache for the decoder
+    :param num_key_value_heads_encoder: number of heads for the encoder
+    :param num_key_value_heads_decoder: number of heads for the decoder
+    :param dummy_max_token_id: dummy max token id
+    :param batch_size: batch size
+    :param encoder_dim: last dimension of encoder_last_hidden_state
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :return: dictionary
+
+    Stolen inputs for one model.
+
+    ::
+
+        cache_position:T7s1
+        past_key_values:EncoderDecoderCache(
+            self_attention_cache=DynamicCache(
+                key_cache=#6[T1s1x8x1x64,...],
+                value_cache=#6[T1s1x8x1x64,...]),
+            cross_attention_cache=DynamicCache(
+                key_cache=#6[T1s1x8x16x64,...],
+                value_cache=#6[T1s1x8x16x64,...])),
+        decoder_input_ids:T7s1x1,
+        encoder_outputs:dict(last_hidden_state:T1s1x16x512)
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = "cache_length_key"
+    cache_length2 = "cache_length_val"
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "decoder_input_ids": {0: batch, 1: "seq_ids"},
+        "attention_mask": {0: batch, 1: "seq_mask"},
+        # "cache_position": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        "past_key_values": [
+            [{0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)],
+            [{0: batch, 2: cache_length2} for _ in range(num_hidden_layers * 2)],
+        ],
+        # one these is selected based on the forward method signature
+        # "encoder_last_hidden_state": {0: batch, 1: torch.export.Dim.DYNAMIC},
+        # "encoder_outputs": {0: batch, 1: torch.export.Dim.DYNAMIC},
+    }
+
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        decoder_input_ids=torch.randint(
+            0, dummy_max_token_id, (batch_size, sequence_length2)
+        ).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+        # cache_position=torch.arange(sequence_length, sequence_length + sequence_length2)
+        # .to(torch.int64)
+        # .expand((batch_size, -1)),
+        past_key_values=make_encoder_decoder_cache(
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_encoder,
+                            sequence_length,
+                            head_dim_encoder,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_encoder,
+                            sequence_length,
+                            head_dim_encoder,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+            make_dynamic_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_decoder,
+                            sequence_length2,
+                            head_dim_decoder,
+                        ),
+                        torch.randn(
+                            batch_size,
+                            num_key_value_heads_decoder,
+                            sequence_length2,
+                            head_dim_decoder,
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        ),
+        # one these is selected based on the forward method signature
+        # encoder_last_hidden_state=torch.randn(batch_size, sequence_length2, encoder_dim),
+        # encoder_outputs=torch.randn(batch_size, sequence_length2, encoder_dim),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads_encoder=num_key_value_heads_encoder,
+            num_key_value_heads_decoder=num_key_value_heads_decoder,
+            num_hidden_layers=num_hidden_layers,
+            head_dim_encoder=head_dim_encoder,
+            head_dim_decoder=head_dim_decoder,
+            encoder_dim=encoder_dim,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + add_second_input,
+            sequence_length2=sequence_length2 + 1,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(
+            config,
+            "vocab_size",
+            "hidden_size",
+            "num_attention_heads",
+            ("num_hidden_layers", "num_layers"),
+            ("n_positions", "d_model"),
+            (
+                "num_key_value_heads",
+                "num_heads",
+                ("decoder_attention_heads", "encoder_attention_heads"),
+            ),
+        )
+    # exceptions = {
+    #     "PLBartForConditionalGeneration": (
+    #         lambda c: c.encoder_attention_heads + c.decoder_attention_heads
+    #     )
+    # }
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        sequence_length2=3,
+        head_dim_encoder=16 if config is None else _pick(config, "d_kv", "encoder_ffn_dim"),
+        head_dim_decoder=16 if config is None else _pick(config, "d_kv", "decoder_ffn_dim"),
+        dummy_max_token_id=31999 if config is None else config.vocab_size - 1,
+        num_hidden_layers=(
+            8 if config is None else _pick(config, "num_hidden_layers", "num_layers")
+        ),
+        num_key_value_heads_encoder=(
+            16
+            if config is None
+            else _pick(
+                config,
+                "encoder_attention_heads",
+                "num_key_value_heads",
+                "num_heads",
+            )
+        ),
+        num_key_value_heads_decoder=(
+            16
+            if config is None
+            else _pick(
+                config,
+                "decoder_attention_heads",
+                "num_key_value_heads",
+                "num_heads",
+            )
+        ),
+        encoder_dim=512 if config is None else _pick(config, "n_positions", "d_model"),
+    )
+    return kwargs, get_inputs
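
For context, a minimal usage sketch for this new task module (not taken from the package documentation; passing a bare torch.nn.Module() and config=None is an assumption that relies on the code above only forwarding the model argument and falling back to typical dimensions when no configuration is given):

    import torch
    from onnx_diagnostic.tasks.text2text_generation import random_input_kwargs

    # config=None: the helper picks typical dimensions (8 layers, head_dim 16, ...).
    kwargs, fct = random_input_kwargs(None)
    # The model is only forwarded by get_inputs, so a placeholder module is enough
    # to inspect the generated dummy inputs and the matching dynamic shapes.
    data = fct(torch.nn.Module(), None, **kwargs)
    print(sorted(data))                       # expected: ['dynamic_shapes', 'inputs', 'inputs2']
    print(data["inputs"]["input_ids"].shape)  # expected: torch.Size([2, 30])

The returned dictionary can then feed torch.export or the package's validation helpers with consistent inputs and dynamic_shapes.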

onnx_diagnostic/tasks/text_classification.py (+89 -0)

@@ -0,0 +1,89 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "text-classification"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    check_hasattr(config, "num_attention_heads", "num_hidden_layers")
+    kwargs = dict(
+        num_hidden_layers=min(config.num_hidden_layers, nhl()),
+        num_attention_heads=min(config.num_attention_heads, 4),
+    )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    batch_size: int,
+    sequence_length: int,
+    dummy_max_token_id: int,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``text-classification``.
+    Example:
+
+    ::
+
+        input_ids:T7s1x13[101,72654:A16789.23076923077],
+        token_type_ids:T7s1x13[0,0:A0.0],
+        attention_mask:T7s1x13[1,1:A1.0])
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: seq_length},
+    }
+    inputs = dict(
+        input_ids=torch.randint(0, dummy_max_token_id, (batch_size, sequence_length)).to(
+            torch.int64
+        ),
+        token_type_ids=torch.zeros((batch_size, sequence_length)).to(torch.int64),
+        attention_mask=torch.ones((batch_size, sequence_length)).to(torch.int64),
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            batch_size=batch_size + 1,
+            sequence_length=sequence_length + add_second_input,
+            dummy_max_token_id=dummy_max_token_id,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "vocab_size")
+    kwargs = dict(
+        batch_size=2,
+        sequence_length=30,
+        dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+    )
+    return kwargs, get_inputs
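
The same pattern applies to the classification task. A minimal sketch (same assumptions as above) shows that the second set of inputs varies both dynamic dimensions, which is useful to check that exported dynamic axes really move:

    import torch
    from onnx_diagnostic.tasks.text_classification import random_input_kwargs

    kwargs, fct = random_input_kwargs(None)
    data = fct(torch.nn.Module(), None, **kwargs)
    print(data["inputs"]["input_ids"].shape)    # expected: torch.Size([2, 30])
    print(data["inputs2"]["input_ids"].shape)   # expected: torch.Size([3, 31])
    print(data["dynamic_shapes"]["input_ids"])  # expected: {0: 'batch', 1: 'seq_length'}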

onnx_diagnostic/tasks/text_generation.py (+352 -0)

@@ -0,0 +1,352 @@
+from typing import Any, Callable, Dict, Optional, Tuple, Union
+import torch
+from ..helpers.cache_helper import (
+    make_dynamic_cache,
+    make_mamba_cache,
+    make_sliding_window_cache,
+    make_static_cache,
+)
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "text-generation"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    # FalconMambaConfig: use_mambapy
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
+    check_hasattr(
+        config,
+        ("head_dim", ("hidden_size", "num_attention_heads"), "use_mambapy"),
+        "num_hidden_layers",
+        ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
+        "hidden_size",
+        "vocab_size",
+    )
+    if config.__class__.__name__ == "FalconMambaConfig":
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
+        kwargs = dict(
+            num_hidden_layers=min(config.num_hidden_layers, nhl()),
+            intermediate_size=256 if config is None else min(512, config.intermediate_size),
+            hidden_size=512 if config is None else min(512, config.hidden_size),
+            cls_cache="MambaCache",
+            state_size=8 if config is None else getattr(config, "state_size", None),
+            conv_kernel=4 if config is None else getattr(config, "conv_kernel", None),
+        )
+    else:
+        kwargs = dict(
+            head_dim=getattr(
+                config, "head_dim", config.hidden_size // config.num_attention_heads
+            ),
+            num_hidden_layers=min(config.num_hidden_layers, nhl()),
+            num_key_value_heads=(
+                config.num_key_value_heads
+                if hasattr(config, "num_key_value_heads")
+                else config.num_attention_heads
+            ),
+        )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_hidden_layers: int,
+    batch_size: int = 2,
+    sequence_length: int = 30,
+    sequence_length2: int = 3,
+    dynamic_rope: bool = False,
+    num_key_value_heads: Optional[int] = None,
+    head_dim: Optional[int] = None,
+    cls_cache: Optional[Union[type, str]] = None,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates input for task ``text-generation``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: dummy max token id
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param sequence_length2: new sequence length
+    :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
+    :param cls_cache: cache class, by default it is
+        :class:`transformers.cache_utils.DynamicCache`
+    :return: dictionary
+    """
+    batch = "batch"
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+
+    if config is not None and config.__class__.__name__ == "FalconMambaConfig":
+        try:
+            from transformers.models.mamba.modeling_mamba import MambaCache
+        except ImportError:
+            from transformers.cache_utils import MambaCache
+
+        assert cls_cache in (
+            "MambaCache",
+            MambaCache,
+        ), f"Unexpected value for cls_cache={cls_cache} and config={config}"
+        seq_length_multiple = 8
+        sequence_length = (
+            (sequence_length + seq_length_multiple)
+            // seq_length_multiple
+            * seq_length_multiple
+        )
+        # sequence_inc = seq_length_multiple
+        sequence_length2 = seq_length_multiple
+
+        shapes = {
+            "input_ids": {0: batch, 1: "sequence_length"},
+            "attention_mask": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "cache_position": {
+                0: batch,
+                1: "cache+seq",  # cache_length + seq_length
+            },
+            "cache_params": [{0: batch} for _ in range(num_hidden_layers * 2)],
+        }
+        inputs = dict(
+            input_ids=torch.randint(
+                0, dummy_max_token_id, (batch_size, sequence_length + sequence_length2)
+            ).to(torch.int64),
+            attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                torch.int64
+            ),
+            cache_position=torch.arange(0, kwargs["conv_kernel"]).to(torch.int64),
+            # .expand((batch_size, -1))
+            cache_params=make_mamba_cache(
+                [
+                    (
+                        torch.randn(
+                            batch_size, kwargs["intermediate_size"], kwargs["conv_kernel"]
+                        ),
+                        torch.randn(
+                            batch_size, kwargs["intermediate_size"], kwargs["state_size"]
+                        ),
+                    )
+                    for i in range(num_hidden_layers)
+                ]
+            ),
+        )
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
+    else:
+        if head_dim is None:
+            assert config, "head_dim is None, the value cannot be set without a configuration"
+            head_dim = config.hidden_size // config.num_attention_heads
+
+        cache_name = (
+            cls_cache
+            if cls_cache is None or isinstance(cls_cache, str)
+            else cls_cache.__name__
+        )
+        make_caches = {
+            "DynamicCache": make_dynamic_cache,
+            "SlidingWindowCache": make_sliding_window_cache,
+            "StaticCache": make_static_cache,
+        }
+        assert cache_name is None or cache_name in make_caches, (
+            f"Unable to handle cls_cache={cache_name!r}, it should be in "
+            f"{sorted(make_caches)}"
+        )
+        make_cache = make_dynamic_cache if cache_name is None else make_caches[cache_name]
+        is_static = cache_name == "StaticCache"
+
+        if is_static:
+            # static
+            shapes = {
+                "input_ids": {0: batch, 1: seq_length},
+                "attention_mask": {0: batch, 2: "seq"},
+                "cache_position": {0: "seq"},
+                "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
+            }
+            inputs = dict(
+                input_ids=torch.randint(
+                    0, dummy_max_token_id, (batch_size, sequence_length2)
+                ).to(torch.int64),
+                attention_mask=torch.ones(
+                    (batch_size, num_key_value_heads, sequence_length2, head_dim)
+                ).to(torch.bool),
+                cache_position=torch.arange(sequence_length2).to(torch.int64),
+                past_key_values=make_static_cache(
+                    [
+                        (
+                            torch.randn(
+                                batch_size,
+                                num_key_value_heads,
+                                sequence_length + sequence_length2,
+                                head_dim,
+                            ),
+                            torch.randn(
+                                batch_size,
+                                num_key_value_heads,
+                                sequence_length + sequence_length2,
+                                head_dim,
+                            ),
+                        )
+                        for i in range(num_hidden_layers)
+                    ],
+                    max_cache_len=max(sequence_length + sequence_length2, head_dim),
+                ),
+            )
+        else:
+            # dynamic
+            shapes = {
+                "input_ids": {0: batch, 1: seq_length},
+                "attention_mask": {
+                    0: batch,
+                    1: "cache+seq",  # cache_length + seq_length
+                },
+                "position_ids": {0: batch, 1: seq_length},
+                "past_key_values": [
+                    {0: batch, 2: cache_length} for _ in range(num_hidden_layers * 2)
+                ],
+            }
+
+            inputs = dict(
+                input_ids=torch.randint(
+                    0, dummy_max_token_id, (batch_size, sequence_length2)
+                ).to(torch.int64),
+                attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
+                    torch.int64
+                ),
+                position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
+                .to(torch.int64)
+                .expand((batch_size, -1)),
+                past_key_values=make_cache(  # type: ignore[operator]
+                    [
+                        (
+                            torch.randn(
+                                batch_size, num_key_value_heads, sequence_length, head_dim
+                            ),
+                            torch.randn(
+                                batch_size, num_key_value_heads, sequence_length, head_dim
+                            ),
+                        )
+                        for i in range(num_hidden_layers)
+                    ]
+                ),
+            )
+        res = dict(inputs=inputs, dynamic_shapes=shapes)
+    if add_second_input:
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=(batch_size + 1) if add_second_input > 0 else 1,
+            sequence_length=sequence_length + 1,
+            sequence_length2=sequence_length2
+            + (add_second_input if add_second_input > 0 else -add_second_input),
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+        res["inputs_empty_cache"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=batch_size,
+            sequence_length=0,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+        res["inputs_batch1"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_hidden_layers=num_hidden_layers,
+            batch_size=1,
+            sequence_length=sequence_length,
+            sequence_length2=sequence_length2,
+            dynamic_rope=dynamic_rope,
+            num_key_value_heads=num_key_value_heads,
+            head_dim=head_dim,
+            cls_cache=cls_cache,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if hasattr(config, "text_config"):
+        # The model is probably of mixture of models used only for text.
+        config = config.text_config
+    if config is not None:
+        check_hasattr(
+            config,
+            "vocab_size",
+            ("num_attention_heads", "use_mambapy"),
+            ("num_key_value_heads", "num_attention_heads", "use_mambapy"),
+            "hidden_size",
+        )
+    if config.__class__.__name__ == "FalconMambaConfig":
+        check_hasattr(config, "conv_kernel", "state_size", "intermediate_size")  # 4 and 8
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+            num_hidden_layers=4 if config is None else config.num_hidden_layers,
+            intermediate_size=256 if config is None else config.intermediate_size,
+            cls_cache="MambaCache",
+            state_size=8 if config is None else getattr(config, "state_size", None),
+            conv_kernel=8 if config is None else getattr(config, "conv_kernel", None),
+        )
+    else:
+        kwargs = dict(
+            batch_size=2,
+            sequence_length=30,
+            sequence_length2=3,
+            head_dim=(
+                16
+                if config is None
+                else getattr(
+                    config, "head_dim", config.hidden_size // config.num_attention_heads
+                )
+            ),
+            dummy_max_token_id=31999 if config is None else (config.vocab_size - 1),
+            num_hidden_layers=4 if config is None else config.num_hidden_layers,
+            num_key_value_heads=(
+                24
+                if config is None
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            ),
+            hidden_size=512 if config is None else config.hidden_size,
+        )
+    if config is None or hasattr(config, "intermediate_size"):
+        kwargs["intermediate_size"] = (
+            1024 if config is None else config.intermediate_size,
+        )
+
+    return kwargs, get_inputs
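
For text-generation, a minimal sketch under the same assumptions (the docstring above states the default cache is transformers' DynamicCache, and cls_cache may instead name SlidingWindowCache or StaticCache). Besides inputs and inputs2, this task also returns inputs_empty_cache and inputs_batch1, which can be used to exercise an exported model on several shapes:

    import torch
    from onnx_diagnostic.tasks.text_generation import random_input_kwargs

    kwargs, fct = random_input_kwargs(None)  # typical dims: 4 layers, 24 kv heads, head_dim 16
    data = fct(torch.nn.Module(), None, **kwargs)
    print(sorted(data))
    # expected: ['dynamic_shapes', 'inputs', 'inputs2', 'inputs_batch1', 'inputs_empty_cache']
    print(data["inputs"]["input_ids"].shape)        # expected: torch.Size([2, 3]), sequence_length2 tokens
    print(type(data["inputs"]["past_key_values"]))  # cache built by make_dynamic_cache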