onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +281 -80
- onnx_diagnostic/doc.py +22 -0
- onnx_diagnostic/export/dynamic_shapes.py +48 -20
- onnx_diagnostic/export/shape_helper.py +126 -0
- onnx_diagnostic/ext_test_case.py +1 -1
- onnx_diagnostic/helpers/cache_helper.py +78 -8
- onnx_diagnostic/helpers/config_helper.py +8 -4
- onnx_diagnostic/helpers/helper.py +30 -3
- onnx_diagnostic/helpers/log_helper.py +1744 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
- onnx_diagnostic/helpers/model_builder_helper.py +54 -73
- onnx_diagnostic/helpers/torch_helper.py +18 -2
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/ort_evaluator.py +29 -4
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +21 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
- onnx_diagnostic/tasks/feature_extraction.py +3 -0
- onnx_diagnostic/tasks/fill_mask.py +3 -0
- onnx_diagnostic/tasks/image_classification.py +7 -1
- onnx_diagnostic/tasks/image_text_to_text.py +72 -18
- onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
- onnx_diagnostic/tasks/object_detection.py +3 -0
- onnx_diagnostic/tasks/sentence_similarity.py +3 -0
- onnx_diagnostic/tasks/summarization.py +3 -0
- onnx_diagnostic/tasks/text2text_generation.py +3 -0
- onnx_diagnostic/tasks/text_classification.py +3 -0
- onnx_diagnostic/tasks/text_generation.py +90 -43
- onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
- onnx_diagnostic/torch_models/hghub/hub_api.py +81 -8
- onnx_diagnostic/torch_models/hghub/hub_data.py +6 -2
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +58 -14
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
- onnx_diagnostic/torch_models/{test_helper.py → validate.py} +166 -106
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/METADATA +2 -2
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/RECORD +44 -41
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.1.dist-info}/top_level.txt +0 -0
|
@@ -3953,6 +3953,46 @@ def _ccached_facebook_bart_large_cnn():
|
|
|
3953
3953
|
)
|
|
3954
3954
|
|
|
3955
3955
|
|
|
3956
|
+
def _ccached_microsoft_phi3_mini_4k_instruct():
    "microsoft/Phi-3-mini-4k-instruct"
    # The string above is this function's docstring: the Hugging Face model id
    # this cached configuration stands for.
    # NOTE(review): it is presumably read at runtime (e.g. through ``__doc__``)
    # to look the configuration up by model id — keep it verbatim.
    # ``null`` / ``true`` / ``false`` are JSON-style names expected to be
    # defined at module level (not visible here); assumed to alias
    # None / True / False — TODO confirm.
    auto_map = {
        "AutoConfig": "configuration_phi3.Phi3Config",
        "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
    }
    # Keyword arguments forwarded to the configuration class, in the same order
    # as the published config.json.
    params = {
        "_name_or_path": "Phi-3-mini-4k-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": auto_map,
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 4096,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": null,
        "rope_theta": 10000.0,
        "sliding_window": 2047,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.40.2",
        "use_cache": true,
        "attention_bias": false,
        "vocab_size": 32064,
    }
    return transformers.Phi3Config(**params)
|
|
3994
|
+
|
|
3995
|
+
|
|
3956
3996
|
def _ccached_microsoft_phi4_reasoning():
|
|
3957
3997
|
"microsoft/Phi-4-mini-reasoning"
|
|
3958
3998
|
return transformers.Phi3Config(
|
|
@@ -4093,3 +4133,172 @@ def _ccached_microsoft_phi4_reasoning():
|
|
|
4093
4133
|
"vocab_size": 200064,
|
|
4094
4134
|
}
|
|
4095
4135
|
)
|
|
4136
|
+
|
|
4137
|
+
|
|
4138
|
+
def _ccached_ydshieh_tiny_random_vit_for_image_classification():
    "ydshieh/tiny-random-ViTForImageClassification"
    # The string above is this function's docstring: the Hugging Face model id
    # this cached configuration stands for.
    # NOTE(review): it is presumably used as a lookup key at runtime (e.g.
    # through ``__doc__``) — keep it verbatim.
    #
    # This checkpoint is a Vision Transformer (model_type "vit", architecture
    # ViTForImageClassification), so the configuration must be a ViTConfig.
    # It was previously built with Phi3Config (copied from the LLM entries),
    # which produced a text-model configuration class for a vision model.
    #
    # ``true`` is a JSON-style name expected to be defined at module level
    # (not visible here); assumed to alias True — TODO confirm.
    return transformers.ViTConfig(
        **{
            "_name_or_path": ".temp/dummy/vit/ViTForImageClassification",
            "architectures": ["ViTForImageClassification"],
            "attention_probs_dropout_prob": 0.1,
            "encoder_stride": 2,
            "hidden_act": "gelu",
            "hidden_dropout_prob": 0.1,
            "hidden_size": 32,
            "image_size": 30,
            "initializer_range": 0.02,
            "intermediate_size": 37,
            "layer_norm_eps": 1e-12,
            "model_type": "vit",
            "num_attention_heads": 4,
            "num_channels": 3,
            "num_hidden_layers": 5,
            "patch_size": 2,
            "qkv_bias": true,
            "torch_dtype": "float32",
            "transformers_version": "4.24.0.dev0",
        }
    )
|
|
4163
|
+
|
|
4164
|
+
|
|
4165
|
+
def _ccached_microsoft_phi_35_mini_instruct():
    "microsoft/Phi-3.5-mini-instruct"
    # The string above is this function's docstring: the Hugging Face model id
    # this cached configuration stands for.
    # NOTE(review): it is presumably read at runtime (e.g. through ``__doc__``)
    # to look the configuration up by model id — keep it verbatim.
    # ``false`` / ``true`` are JSON-style names expected to be defined at module
    # level (not visible here); assumed to alias False / True — TODO confirm.

    # Per-dimension scaling factors of the "longrope" rope scaling (48 values
    # each), copied from the published config.json.
    long_rope_factors = [
        1.0800000429153442,
        1.1100000143051147,
        1.1399999856948853,
        1.340000033378601,
        1.5899999141693115,
        1.600000023841858,
        1.6200000047683716,
        2.620000123977661,
        3.2300000190734863,
        3.2300000190734863,
        4.789999961853027,
        7.400000095367432,
        7.700000286102295,
        9.09000015258789,
        12.199999809265137,
        17.670000076293945,
        24.46000099182129,
        28.57000160217285,
        30.420001983642578,
        30.840002059936523,
        32.590003967285156,
        32.93000411987305,
        42.320003509521484,
        44.96000289916992,
        50.340003967285156,
        50.45000457763672,
        57.55000305175781,
        57.93000411987305,
        58.21000289916992,
        60.1400032043457,
        62.61000442504883,
        62.62000274658203,
        62.71000289916992,
        63.1400032043457,
        63.1400032043457,
        63.77000427246094,
        63.93000411987305,
        63.96000289916992,
        63.970001220703125,
        64.02999877929688,
        64.06999969482422,
        64.08000183105469,
        64.12000274658203,
        64.41000366210938,
        64.4800033569336,
        64.51000213623047,
        64.52999877929688,
        64.83999633789062,
    ]
    short_rope_factors = [
        1.0,
        1.0199999809265137,
        1.0299999713897705,
        1.0299999713897705,
        1.0499999523162842,
        1.0499999523162842,
        1.0499999523162842,
        1.0499999523162842,
        1.0499999523162842,
        1.0699999332427979,
        1.0999999046325684,
        1.1099998950958252,
        1.1599998474121094,
        1.1599998474121094,
        1.1699998378753662,
        1.2899998426437378,
        1.339999794960022,
        1.679999828338623,
        1.7899998426437378,
        1.8199998140335083,
        1.8499997854232788,
        1.8799997568130493,
        1.9099997282028198,
        1.9399996995925903,
        1.9899996519088745,
        2.0199997425079346,
        2.0199997425079346,
        2.0199997425079346,
        2.0199997425079346,
        2.0199997425079346,
        2.0199997425079346,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0299997329711914,
        2.0799996852874756,
        2.0899996757507324,
        2.189999580383301,
        2.2199995517730713,
        2.5899994373321533,
        2.729999542236328,
        2.749999523162842,
        2.8399994373321533,
    ]
    rope_scaling = {
        "long_factor": long_rope_factors,
        "short_factor": short_rope_factors,
        "type": "longrope",
    }
    auto_map = {
        "AutoConfig": "configuration_phi3.Phi3Config",
        "AutoModelForCausalLM": "modeling_phi3.Phi3ForCausalLM",
    }
    # Keyword arguments forwarded to the configuration class, in the same order
    # as the published config.json.
    params = {
        "_name_or_path": "Phi-3.5-mini-instruct",
        "architectures": ["Phi3ForCausalLM"],
        "attention_dropout": 0.0,
        "auto_map": auto_map,
        "bos_token_id": 1,
        "embd_pdrop": 0.0,
        "eos_token_id": 32000,
        "hidden_act": "silu",
        "hidden_size": 3072,
        "initializer_range": 0.02,
        "intermediate_size": 8192,
        "max_position_embeddings": 131072,
        "model_type": "phi3",
        "num_attention_heads": 32,
        "num_hidden_layers": 32,
        "num_key_value_heads": 32,
        "original_max_position_embeddings": 4096,
        "pad_token_id": 32000,
        "resid_pdrop": 0.0,
        "rms_norm_eps": 1e-05,
        "rope_scaling": rope_scaling,
        "rope_theta": 10000.0,
        "sliding_window": 262144,
        "tie_word_embeddings": false,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.43.3",
        "use_cache": true,
        "attention_bias": false,
        "vocab_size": 32064,
    }
    return transformers.Phi3Config(**params)
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import inspect
|
|
2
2
|
import os
|
|
3
|
+
import pprint
|
|
3
4
|
from typing import Any, Dict, Optional, Tuple
|
|
4
5
|
import torch
|
|
5
6
|
import transformers
|
|
6
7
|
from ...helpers.config_helper import update_config
|
|
7
8
|
from ...tasks import reduce_model_config, random_input_kwargs
|
|
8
|
-
from .hub_api import task_from_arch, task_from_id, get_pretrained_config
|
|
9
|
+
from .hub_api import task_from_arch, task_from_id, get_pretrained_config, download_code_modelid
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def _code_needing_rewriting(model: Any) -> Any:
|
|
@@ -22,10 +23,12 @@ def get_untrained_model_with_inputs(
|
|
|
22
23
|
model_kwargs: Optional[Dict[str, Any]] = None,
|
|
23
24
|
verbose: int = 0,
|
|
24
25
|
dynamic_rope: Optional[bool] = None,
|
|
26
|
+
use_pretrained: bool = False,
|
|
25
27
|
same_as_pretrained: bool = False,
|
|
26
28
|
use_preinstalled: bool = True,
|
|
27
29
|
add_second_input: bool = False,
|
|
28
30
|
subfolder: Optional[str] = None,
|
|
31
|
+
use_only_preinstalled: bool = False,
|
|
29
32
|
) -> Dict[str, Any]:
|
|
30
33
|
"""
|
|
31
34
|
Gets a non initialized model similar to the original model
|
|
@@ -42,10 +45,12 @@ def get_untrained_model_with_inputs(
|
|
|
42
45
|
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
|
|
43
46
|
:param same_as_pretrained: if True, do not change the default values
|
|
44
47
|
to get a smaller model
|
|
48
|
+
:param use_pretrained: download the pretrained weights as well
|
|
45
49
|
:param use_preinstalled: use preinstalled configurations
|
|
46
50
|
:param add_second_input: provides a second inputs to check a model
|
|
47
51
|
supports different shapes
|
|
48
52
|
:param subfolder: subfolder to use for this model id
|
|
53
|
+
:param use_only_preinstalled: use only preinstalled version
|
|
49
54
|
:return: dictionary with a model, inputs, dynamic shapes, and the configuration,
|
|
50
55
|
some necessary rewriting as well
|
|
51
56
|
|
|
@@ -66,6 +71,10 @@ def get_untrained_model_with_inputs(
|
|
|
66
71
|
print("-- dynamic shapes:", pprint.pformat(data['dynamic_shapes']))
|
|
67
72
|
print("-- configuration:", pprint.pformat(data['configuration']))
|
|
68
73
|
"""
|
|
74
|
+
assert not use_preinstalled or not use_only_preinstalled, (
|
|
75
|
+
f"model_id={model_id!r}, pretinstalled model is only available "
|
|
76
|
+
f"if use_only_preinstalled is False."
|
|
77
|
+
)
|
|
69
78
|
if verbose:
|
|
70
79
|
print(f"[get_untrained_model_with_inputs] model_id={model_id!r}")
|
|
71
80
|
if use_preinstalled:
|
|
@@ -74,6 +83,7 @@ def get_untrained_model_with_inputs(
|
|
|
74
83
|
config = get_pretrained_config(
|
|
75
84
|
model_id,
|
|
76
85
|
use_preinstalled=use_preinstalled,
|
|
86
|
+
use_only_preinstalled=use_only_preinstalled,
|
|
77
87
|
subfolder=subfolder,
|
|
78
88
|
**(model_kwargs or {}),
|
|
79
89
|
)
|
|
@@ -96,7 +106,7 @@ def get_untrained_model_with_inputs(
|
|
|
96
106
|
print(f"[get_untrained_model_with_inputs] architectures={archs!r}")
|
|
97
107
|
print(f"[get_untrained_model_with_inputs] cls={config.__class__.__name__!r}")
|
|
98
108
|
if task is None:
|
|
99
|
-
task = task_from_arch(archs[0])
|
|
109
|
+
task = task_from_arch(archs[0], model_id=model_id)
|
|
100
110
|
if verbose:
|
|
101
111
|
print(f"[get_untrained_model_with_inputs] task={task!r}")
|
|
102
112
|
|
|
@@ -111,7 +121,6 @@ def get_untrained_model_with_inputs(
|
|
|
111
121
|
)
|
|
112
122
|
|
|
113
123
|
# updating the configuration
|
|
114
|
-
|
|
115
124
|
mkwargs = reduce_model_config(config, task) if not same_as_pretrained else {}
|
|
116
125
|
if model_kwargs:
|
|
117
126
|
for k, v in model_kwargs.items():
|
|
@@ -136,27 +145,62 @@ def get_untrained_model_with_inputs(
|
|
|
136
145
|
f"{config._attn_implementation!r}" # type: ignore[union-attr]
|
|
137
146
|
)
|
|
138
147
|
|
|
148
|
+
if use_pretrained:
|
|
149
|
+
model = transformers.AutoModel.from_pretrained(model_id, **mkwargs)
|
|
150
|
+
else:
|
|
151
|
+
if archs is not None:
|
|
152
|
+
try:
|
|
153
|
+
model = getattr(transformers, archs[0])(config)
|
|
154
|
+
except AttributeError as e:
|
|
155
|
+
# The code of the models is not in transformers but in the
|
|
156
|
+
# repository of the model. We need to download it.
|
|
157
|
+
pyfiles = download_code_modelid(model_id, verbose=verbose)
|
|
158
|
+
if pyfiles:
|
|
159
|
+
if "." in archs[0]:
|
|
160
|
+
cls_name = archs[0]
|
|
161
|
+
else:
|
|
162
|
+
modeling = [_ for _ in pyfiles if "/modeling_" in _]
|
|
163
|
+
assert len(modeling) == 1, (
|
|
164
|
+
f"Unable to guess the main file implemented class {archs[0]!r} "
|
|
165
|
+
f"from {pyfiles}, found={modeling}."
|
|
166
|
+
)
|
|
167
|
+
last_name = os.path.splitext(os.path.split(modeling[0])[-1])[0]
|
|
168
|
+
cls_name = f"{last_name}.{archs[0]}"
|
|
169
|
+
if verbose:
|
|
170
|
+
print(
|
|
171
|
+
f"[get_untrained_model_with_inputs] custom code for {cls_name!r}"
|
|
172
|
+
)
|
|
173
|
+
print(
|
|
174
|
+
f"[get_untrained_model_with_inputs] from folder "
|
|
175
|
+
f"{os.path.split(pyfiles[0])[0]!r}"
|
|
176
|
+
)
|
|
177
|
+
cls = transformers.dynamic_module_utils.get_class_from_dynamic_module(
|
|
178
|
+
cls_name, pretrained_model_name_or_path=os.path.split(pyfiles[0])[0]
|
|
179
|
+
)
|
|
180
|
+
model = cls(config)
|
|
181
|
+
else:
|
|
182
|
+
raise AttributeError(
|
|
183
|
+
f"Unable to find class 'tranformers.{archs[0]}'. "
|
|
184
|
+
f"The code needs to be downloaded, config="
|
|
185
|
+
f"\n{pprint.pformat(config)}."
|
|
186
|
+
) from e
|
|
187
|
+
else:
|
|
188
|
+
assert same_as_pretrained and use_pretrained, (
|
|
189
|
+
f"Model {model_id!r} cannot be built, the model cannot be built. "
|
|
190
|
+
f"It must be downloaded. Use same_as_pretrained=True "
|
|
191
|
+
f"and use_pretrained=True."
|
|
192
|
+
)
|
|
193
|
+
|
|
139
194
|
# input kwargs
|
|
140
195
|
kwargs, fct = random_input_kwargs(config, task)
|
|
141
196
|
if verbose:
|
|
142
197
|
print(f"[get_untrained_model_with_inputs] use fct={fct}")
|
|
143
198
|
if os.environ.get("PRINT_CONFIG") in (1, "1"):
|
|
144
|
-
import pprint
|
|
145
|
-
|
|
146
199
|
print(f"-- input kwargs for task {task!r}")
|
|
147
200
|
pprint.pprint(kwargs)
|
|
148
201
|
if inputs_kwargs:
|
|
149
202
|
kwargs.update(inputs_kwargs)
|
|
150
203
|
|
|
151
|
-
if archs is not None:
|
|
152
|
-
model = getattr(transformers, archs[0])(config)
|
|
153
|
-
else:
|
|
154
|
-
assert same_as_pretrained, (
|
|
155
|
-
f"Model {model_id!r} cannot be built, the model cannot be built. "
|
|
156
|
-
f"It must be downloaded. Use same_as_pretrained=True."
|
|
157
|
-
)
|
|
158
|
-
model = None
|
|
159
|
-
|
|
160
204
|
# This line is important. Some models may produce different
|
|
161
205
|
# outputs even with the same inputs in training mode.
|
|
162
206
|
model.eval()
|
|
@@ -1,7 +1,5 @@
|
|
|
1
1
|
from typing import Any, Dict
|
|
2
|
-
import torch
|
|
3
2
|
import transformers
|
|
4
|
-
from ...helpers.cache_helper import make_dynamic_cache
|
|
5
3
|
|
|
6
4
|
|
|
7
5
|
def get_tiny_llm(
|
|
@@ -9,6 +7,7 @@ def get_tiny_llm(
|
|
|
9
7
|
sequence_length: int = 30,
|
|
10
8
|
sequence_length2: int = 3,
|
|
11
9
|
dynamic_rope: bool = False,
|
|
10
|
+
use_static_cache: bool = False,
|
|
12
11
|
**kwargs,
|
|
13
12
|
) -> Dict[str, Any]:
|
|
14
13
|
"""
|
|
@@ -18,11 +17,14 @@ def get_tiny_llm(
|
|
|
18
17
|
:param sequence_length: sequence length
|
|
19
18
|
:param sequence_length2: new sequence length
|
|
20
19
|
:param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
|
|
20
|
+
:param use_static_cache: use StaticCache instead of DynamicCache
|
|
21
21
|
:param kwargs: to overwrite the configuration, example ``num_hidden_layers=1``
|
|
22
22
|
:return: dictionary
|
|
23
23
|
|
|
24
24
|
See :ref:`l-plot-tiny-llm-export` or :ref:`l-plot-tiny-llm-export-patched` for examples.
|
|
25
25
|
"""
|
|
26
|
+
from ...tasks.text_generation import get_inputs
|
|
27
|
+
|
|
26
28
|
config = {
|
|
27
29
|
"architectures": ["LlamaForCausalLM"],
|
|
28
30
|
"bos_token_id": 1,
|
|
@@ -48,56 +50,27 @@ def get_tiny_llm(
|
|
|
48
50
|
|
|
49
51
|
config.update(**kwargs)
|
|
50
52
|
conf = transformers.LlamaConfig(**config)
|
|
53
|
+
if use_static_cache:
|
|
54
|
+
conf.cache_implementation = "static"
|
|
51
55
|
model = transformers.LlamaForCausalLM(conf)
|
|
52
56
|
model.eval()
|
|
53
57
|
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
58
|
+
res = get_inputs(
|
|
59
|
+
model,
|
|
60
|
+
conf,
|
|
61
|
+
dummy_max_token_id=config["vocab_size"], # type: ignore[arg-type]
|
|
62
|
+
num_hidden_layers=config["num_hidden_layers"], # type: ignore[arg-type]
|
|
63
|
+
batch_size=batch_size,
|
|
64
|
+
sequence_length=sequence_length,
|
|
65
|
+
sequence_length2=sequence_length2,
|
|
66
|
+
dynamic_rope=dynamic_rope,
|
|
67
|
+
num_key_value_heads=config["num_key_value_heads"], # type: ignore[arg-type]
|
|
68
|
+
cls_cache="StaticCache" if use_static_cache else "DynamicCache",
|
|
69
|
+
)
|
|
63
70
|
|
|
64
|
-
|
|
65
|
-
"
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
},
|
|
70
|
-
"position_ids": {
|
|
71
|
-
0: batch,
|
|
72
|
-
1: torch.export.Dim.DYNAMIC, # cache_length + seq_length
|
|
73
|
-
},
|
|
74
|
-
"past_key_values": [
|
|
75
|
-
[{0: batch, 2: cache_length} for _ in range(n_layers)],
|
|
76
|
-
[{0: batch, 2: cache_length} for _ in range(n_layers)],
|
|
77
|
-
],
|
|
78
|
-
}
|
|
79
|
-
inputs = dict(
|
|
80
|
-
input_ids=torch.randint(0, max_token_id, (batch_size, sequence_length2)).to(
|
|
81
|
-
torch.int64
|
|
82
|
-
),
|
|
83
|
-
attention_mask=torch.ones((batch_size, sequence_length + sequence_length2)).to(
|
|
84
|
-
torch.int64
|
|
85
|
-
),
|
|
86
|
-
position_ids=torch.arange(sequence_length, sequence_length + sequence_length2)
|
|
87
|
-
.to(torch.int64)
|
|
88
|
-
.expand((batch_size, -1)),
|
|
89
|
-
past_key_values=make_dynamic_cache(
|
|
90
|
-
[
|
|
91
|
-
(
|
|
92
|
-
torch.randn(
|
|
93
|
-
batch_size, num_key_value_heads, sequence_length, cache_last_dim
|
|
94
|
-
),
|
|
95
|
-
torch.randn(
|
|
96
|
-
batch_size, num_key_value_heads, sequence_length, cache_last_dim
|
|
97
|
-
),
|
|
98
|
-
)
|
|
99
|
-
for i in range(n_layers)
|
|
100
|
-
]
|
|
101
|
-
),
|
|
71
|
+
return dict(
|
|
72
|
+
inputs=res["inputs"],
|
|
73
|
+
model=model,
|
|
74
|
+
dynamic_shapes=res["dynamic_shapes"],
|
|
75
|
+
configuration=conf,
|
|
102
76
|
)
|
|
103
|
-
return dict(inputs=inputs, model=model, dynamic_shapes=shapes, configuration=conf)
|