onnx_diagnostic-0.8.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/tasks/image_text_to_text.py

@@ -0,0 +1,581 @@
import itertools
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
from ..helpers.config_helper import (
    update_config,
    check_hasattr,
    _pick,
    default_num_hidden_layers as nhl,
)
from .data import get_data

__TASK__ = "image-text-to-text"


def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces the model size."""
    kwargs: Dict[str, Any] = {}
    if (
        hasattr(config, "architectures")
        and config.architectures
        and config.architectures[0] == "Gemma3ForConditionalGeneration"
    ):
        if hasattr(config, "vision_config"):
            if hasattr(config.vision_config, "num_hidden_layers"):
                config.vision_config.num_hidden_layers = min(
                    config.vision_config.num_hidden_layers, nhl()
                )
        if hasattr(config, "text_config"):
            if hasattr(config.text_config, "intermediate_size"):
                config.text_config.intermediate_size = min(
                    config.text_config.intermediate_size, 10240 // 10 * 5 // 2
                )
                config.text_config.hidden_size = min(
                    config.text_config.hidden_size, 2560 // 10 * 5 // 2
                )
        update_config(config, kwargs)
        return kwargs

    if hasattr(config, "num_hidden_layers"):
        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
    if hasattr(config, "mm_tokens_per_image"):
        config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
    if hasattr(config, "vision_config"):
        if hasattr(config.vision_config, "num_hidden_layers"):
            config.vision_config.num_hidden_layers = min(
                config.vision_config.num_hidden_layers, 2
            )
        if hasattr(config.vision_config, "num_heads"):
            config.vision_config.num_heads = min(config.vision_config.num_heads, 4)
        if hasattr(config.vision_config, "image_size"):
            config.vision_config.image_size = min(config.vision_config.image_size, 168 // 2)
        if hasattr(config.vision_config, "intermediate_size"):
            config.vision_config.intermediate_size = min(
                config.vision_config.intermediate_size, 1076
            )
        if hasattr(config.vision_config, "patch_size"):
            config.vision_config.patch_size = min(config.vision_config.patch_size, 1)
        if hasattr(config.vision_config, "temporal_patch_size"):
            config.vision_config.temporal_patch_size = min(
                config.vision_config.temporal_patch_size, 8
            )
        if hasattr(config.vision_config, "hidden_size"):
            config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
    if hasattr(config, "text_config"):
        if hasattr(config.text_config, "intermediate_size"):
            config.text_config.intermediate_size = min(
                config.text_config.intermediate_size, 320
            )
        if hasattr(config.text_config, "hidden_size"):
            config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
        if hasattr(config.text_config, "num_hidden_layers"):
            config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
        if hasattr(config.text_config, "layer_types"):
            config.text_config.layer_types = config.text_config.layer_types[
                : config.text_config.num_hidden_layers
            ]
        if hasattr(config.text_config, "num_attention_heads"):
            config.text_config.num_attention_heads = min(
                config.text_config.num_attention_heads, 2
            )
    update_config(config, kwargs)
    return kwargs


def _get_inputs_gemma3(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    pad_token_id: int,
    image_token_index: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: Optional[int] = 1,
    sequence_length: Optional[int] = 281,
    n_images: Optional[int] = 1,
    max_sequence_length: Optional[int] = 580,
    total_sequence_length: Optional[int] = 860,
    **kwargs,  # unused
):
    """
    The function uses predefined values for input_ids and token_type_ids.

    **google/gemma-3-4b-it**

    iteration 1

    ::

        cache_position:T7s281,
        input_ids:T7s1x281,
        token_type_ids:T7s1x281,
        attention_mask:dict(sliding_attention:T9s1x1x281x580,
        full_attention:T9s1x1x281x580),
        pixel_values:T16s1x3x896x896,

    iteration 2

    ::

        cache_position:T7s1,
        past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
        value_cache=#34[T1s1x4x580x256,...]),
        input_ids:T7s1x1,
        inputs_embeds:None,
        token_type_ids:T7s1x1,
        attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
        position_ids:None,
    """
    batch_size = 1 if batch_size is None else batch_size
    sequence_length = 281 if sequence_length is None else sequence_length
    n_images = 1 if n_images is None else n_images
    max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
    total_sequence_length = 860 if total_sequence_length is None else total_sequence_length

    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    batch = "batch"
    seq_length = "seq_length"
    tot_length = "total_length"

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "token_type_ids": {0: batch, 1: seq_length},
        "attention_mask": {
            "full_attention": {0: batch, 2: seq_length, 3: tot_length},
            "sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
        },
        "position_ids": {0: batch, 1: seq_length},
        "cache_position": {0: seq_length},
        "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
        "pixel_values": {0: batch},
        "use_cache": None,
    }

    # retrieve specific inputs to keep the consistency between
    # ids and images
    dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
    dummies = dummies[("", 0, "I")][1]
    dummies = {k: v for k, v in dummies.items() if k in shapes}
    expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}

    def _check_():
        assert expected & set(
            dummies
        ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
        assert sequence_length == dummies["input_ids"].shape[-1], (
            f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
            f"model class {model.__class__.__name__}"
        )
        assert batch_size == dummies["input_ids"].shape[0], (
            f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
            f"model class {model.__class__.__name__}"
        )
        assert max_sequence_length == 580, (
            f"max_sequence_length={max_sequence_length} != 580 "
            f"for model {model.__class__.__name__}"
        )
        assert total_sequence_length == 860, (
            f"total_sequence_length={total_sequence_length} != 860 "
            f"for model {model.__class__.__name__}"
        )
        assert head_dim in (
            256,
            32,
        ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}"
        assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
        assert num_key_value_heads in (1, 4), (
            f"num_key_value_heads={num_key_value_heads} not in (1, 4) "
            f"for this model {model.__class__.__name__}"
        )

    _check_()

    inputs = dict(
        input_ids=dummies["input_ids"],
        token_type_ids=dummies["token_type_ids"],
        attention_mask=dict(
            full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length),
            sliding_attention=torch.randn(
                batch_size, 1, sequence_length, total_sequence_length
            ),
        ),
        position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
        cache_position=torch.arange(0, sequence_length).to(torch.int64),
        past_key_values=make_hybrid_cache(
            [
                (
                    torch.randn(
                        batch_size, num_key_value_heads, max_sequence_length, head_dim
                    ),
                    torch.randn(
                        batch_size, num_key_value_heads, max_sequence_length, head_dim
                    ),
                )
                for i in range(num_hidden_layers)
            ]
        ),
        pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
        #     torch.int64
        # ),
        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
    )
    return dict(inputs=inputs, dynamic_shapes=shapes)


def get_inputs_default(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    pad_token_id: int,
    image_token_index: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: Optional[int] = 2,
    sequence_length: Optional[int] = 43,
    n_images: Optional[int] = 2,
    max_sequence_length: Optional[int] = 43,
    total_sequence_length: Optional[int] = 43,
    add_second_input: int = 0,
    **kwargs,  # unused
):
    batch_size = 2 if batch_size is None else batch_size
    sequence_length = 43 if sequence_length is None else sequence_length
    n_images = 2 if n_images is None else n_images
    max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
    total_sequence_length = 43 if total_sequence_length is None else total_sequence_length

    assert batch_size > 0, "batch_size cannot be null"
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    batch = "batch"
    batch_img = torch.export.Dim("batch_img", min=1, max=1024)
    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
    images = "images"  # torch.export.Dim("images", min=1, max=4096)

    shapes = {
        "input_ids": {0: batch, 1: seq_length},
        "token_type_ids": {0: batch, 1: seq_length},
        "attention_mask": {0: batch, 1: "cache+seq"},
        "position_ids": {0: batch, 1: seq_length},
        "past_key_values": list(
            itertools.chain.from_iterable(
                zip(
                    [{0: batch} for _ in range(num_hidden_layers)],
                    [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
                )
            )
        ),
        "pixel_values": (
            {0: batch, 1: images}
            if model.__class__.__name__ == "IdeficsForVisionText2Text"
            else {0: batch_img}
        ),
        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
        "image_grid_thw": {0: batch},
        "use_cache": None,
    }

    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
        torch.int64
    )
    if total_sequence_length > 0:
        input_ids[0, 0] = image_token_index
        if min(input_ids.shape) > 1:
            input_ids[1, 1] = image_token_index
    # input_ids[input_ids == image_token_index] = pad_token_id
    token_type_ids = torch.zeros_like(input_ids)
    token_type_ids[input_ids == image_token_index] = 1
    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
    if n_images > 0:
        image_grid_thw[:, 1] = height
        image_grid_thw[:, 2] = width
        image_grid_thw[0, :] //= 2
        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)

    inputs = dict(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=torch.cat(
            [
                torch.ones((batch_size, sequence_length), dtype=torch.int64),
                input_ids.ne(pad_token_id).to(torch.int64),
            ],
            axis=-1,
        ),
        position_ids=torch.arange(0, total_sequence_length)
        .to(torch.int64)
        .expand((batch_size, -1)),
        past_key_values=make_dynamic_cache(
            [
                (
                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
                )
                for i in range(num_hidden_layers)
            ]
        ),
        pixel_values=(
            torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
            if model.__class__.__name__ == "IdeficsForVisionText2Text"
            else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
        ),
        image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to(
            torch.int64
        ),
        image_grid_thw=image_grid_thw,
        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)
    return res


def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    dummy_max_token_id: int,
    num_key_value_heads: int,
    num_hidden_layers: int,
    pad_token_id: int,
    image_token_index: int,
    head_dim: int,
    width: int,
    height: int,
    num_channels: int,
    batch_size: Optional[int] = None,
    sequence_length: Optional[int] = None,
    n_images: Optional[int] = None,
    max_sequence_length: Optional[int] = None,
    total_sequence_length: Optional[int] = None,
    add_second_input: int = 0,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``image-text-to-text``.

    :param model: model to get the missing information
    :param config: configuration used to generate the model
    :param head_dim: last dimension of the cache
    :param dummy_max_token_id: dummy max token id
    :param pad_token_id: pad_token_id
    :param image_token_index: image_token_index
    :param batch_size: batch size
    :param sequence_length: sequence length
    :param max_sequence_length: for the cache
    :param total_sequence_length: for the mask
    :param n_images: number of images
    :param width: width of the image
    :param height: height of the image
    :param num_channels: number of channels
    :return: dictionary

    .. note::

        The content and shape of ``input_ids`` are correlated with the images.
        The function uses predefined values and raises an exception
        if the dimensions are not the expected ones.
    """
    if model.__class__.__name__.startswith("Gemma3"):
        res = _get_inputs_gemma3(
            model,
            config,
            dummy_max_token_id=dummy_max_token_id,
            num_key_value_heads=num_key_value_heads,
            num_hidden_layers=num_hidden_layers,
            pad_token_id=pad_token_id,
            image_token_index=image_token_index,
            head_dim=head_dim,
            width=width,
            height=height,
            num_channels=num_channels,
            batch_size=batch_size,
            sequence_length=sequence_length,
            max_sequence_length=max_sequence_length,
            total_sequence_length=total_sequence_length,
            n_images=n_images,
            **kwargs,
        )
    else:
        res = get_inputs_default(
            model,
            config,
            dummy_max_token_id=dummy_max_token_id,
            num_key_value_heads=num_key_value_heads,
            num_hidden_layers=num_hidden_layers,
            pad_token_id=pad_token_id,
            image_token_index=image_token_index,
            head_dim=head_dim,
            width=width,
            height=height,
            num_channels=num_channels,
            batch_size=batch_size,
            sequence_length=sequence_length,
            max_sequence_length=max_sequence_length,
            total_sequence_length=total_sequence_length,
            n_images=n_images,
            **kwargs,
        )

    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            dummy_max_token_id=dummy_max_token_id,
            num_key_value_heads=num_key_value_heads,
            num_hidden_layers=num_hidden_layers,
            head_dim=head_dim,
            width=width,
            height=height,
            num_channels=num_channels,
            batch_size=3,
            sequence_length=1,
            max_sequence_length=1,
            total_sequence_length=1,
            n_images=0,
            pad_token_id=pad_token_id,
            image_token_index=image_token_index,
            add_second_input=0,
            **kwargs,
        )["inputs"]
    return res


def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        if hasattr(config, "text_config"):
            check_hasattr(
                config.text_config,
                "vocab_size",
                "hidden_size",
                "num_attention_heads",
                ("num_key_value_heads", "num_attention_heads"),
                "intermediate_size",
                "hidden_size",
                "pad_token_id",
            )
            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
            text_config = True
        else:
            check_hasattr(
                config,
                "vocab_size",
                "hidden_size",
                "num_attention_heads",
                ("num_key_value_heads", "num_attention_heads"),
                "intermediate_size",
                "hidden_size",
                "vision_config",
            )
            text_config = False
        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
    kwargs = dict(
        head_dim=(
            16
            if config is None
            else getattr(
                config,
                "head_dim",
                (
                    config.text_config.head_dim
                    if text_config and hasattr(config.text_config, "head_dim")
                    else (
                        (config.text_config.hidden_size if text_config else config.hidden_size)
                        // (
                            config.text_config.num_attention_heads
                            if text_config
                            else config.num_attention_heads
                        )
                    )
                ),
            )
        ),
        dummy_max_token_id=(
            31999
            if config is None
            else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
        ),
        num_hidden_layers=(
            4
            if config is None
            else (
                config.text_config.num_hidden_layers
                if text_config
                else config.num_hidden_layers
            )
        ),
        num_key_value_heads=(
            8
            if config is None
            else (
                _pick(config.text_config, "num_key_value_heads", "num_attention_heads")
                if text_config
                else _pick(config, "num_key_value_heads", "num_attention_heads")
            )
        ),
        intermediate_size=(
            1024
            if config is None
            else (
                config.text_config.intermediate_size
                if text_config
                else config.intermediate_size
            )
        ),
        hidden_size=(
            512
            if config is None
            else (config.text_config.hidden_size if text_config else config.hidden_size)
        ),
        width=(
            224
            if config is None or not hasattr(config.vision_config, "image_size")
            else config.vision_config.image_size
        ),
        height=(
            224
            if config is None or not hasattr(config.vision_config, "image_size")
            else config.vision_config.image_size
        ),
        num_channels=(
            3
            if config is None
            else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
        ),
        pad_token_id=(
            0
            if config is None
            or not hasattr(config, "text_config")
            or not hasattr(config.text_config, "pad_token_id")
            else config.text_config.pad_token_id
        ),
        image_token_index=(
            4
            if config is None
            or (
                not hasattr(config, "image_token_index")
                and not hasattr(config, "image_token_id")
            )
            else _pick(config, "image_token_index", "image_token_id")
        ),
    )
    return kwargs, get_inputs
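
The module follows the package's task-module protocol: random_input_kwargs(config) extracts size-related kwargs from a configuration and returns them together with get_inputs, which builds dummy inputs plus matching dynamic_shapes. The T7s1x281-style strings in the _get_inputs_gemma3 docstring read as an ONNX element type followed by a shape (TensorProto type 7 is INT64, 9 is BOOL, 16 is BFLOAT16), so T7s1x281 is an int64 tensor of shape 1x281. Below is a minimal usage sketch, not part of the package: the SimpleNamespace config is a hypothetical stand-in for a transformers PretrainedConfig, with attribute names taken from the check_hasattr calls above.

# Minimal usage sketch (not part of the package); assumes onnx-diagnostic
# 0.8.0 and torch are installed.
from types import SimpleNamespace

import torch

from onnx_diagnostic.tasks.image_text_to_text import random_input_kwargs

# Hypothetical stand-in config; attribute names follow the check_hasattr()
# calls in random_input_kwargs().
config = SimpleNamespace(
    text_config=SimpleNamespace(
        vocab_size=32000,
        hidden_size=64,
        num_attention_heads=4,
        num_key_value_heads=2,
        intermediate_size=128,
        num_hidden_layers=2,
        pad_token_id=0,
    ),
    vision_config=SimpleNamespace(num_channels=3, image_size=32),
    image_token_index=4,
)

kwargs, fct = random_input_kwargs(config)  # fct is get_inputs
# Any non-Gemma3 module takes the get_inputs_default() path; the model is
# only inspected for its class name here.
data = fct(torch.nn.Module(), config, **kwargs)
print(sorted(data["inputs"]))               # input_ids, pixel_values, ...
print(data["dynamic_shapes"]["input_ids"])  # {0: 'batch', 1: 'seq_length'}

The string values in dynamic_shapes (e.g. "batch", "seq_length") are dimension names; the commented-out torch.export.Dim(...) lines in get_inputs_default show the bounded-Dim alternative.
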
onnx_diagnostic/tasks/image_to_video.py

@@ -0,0 +1,127 @@
from typing import Any, Callable, Dict, Optional, Tuple
import torch
from ..helpers.config_helper import (
    update_config,
    check_hasattr,
    default_num_hidden_layers as nhl,
)

__TASK__ = "image-to-video"


def reduce_model_config(config: Any) -> Dict[str, Any]:
    """Reduces the model size."""
    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
        # We cannot reduce.
        return {}
    check_hasattr(config, ("num_hidden_layers", "num_layers"))
    kwargs = {}
    if hasattr(config, "num_layers"):
        kwargs["num_layers"] = min(config.num_layers, nhl())
    if hasattr(config, "num_hidden_layers"):
        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())

    update_config(config, kwargs)
    return kwargs


def get_inputs(
    model: torch.nn.Module,
    config: Optional[Any],
    text_embed_dim: int,
    latent_channels: int,
    batch_size: int = 2,
    image_height: int = 704,
    image_width: int = 1280,
    latent_frames: int = 1,
    text_maxlen: int = 512,
    add_second_input: int = 1,
    **kwargs,  # unused
):
    """
    Generates inputs for task ``image-to-video``.
    """
    assert (
        "cls_cache" not in kwargs
    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
    latent_height = image_height // 8
    latent_width = image_width // 8
    dtype = torch.float32

    inputs = dict(
        hidden_states=torch.randn(
            batch_size,
            latent_channels,
            latent_frames,
            latent_height,
            latent_width,
            dtype=dtype,
        ),
        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
        encoder_hidden_states=torch.randn(
            batch_size, text_maxlen, text_embed_dim, dtype=dtype
        ),
        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
        fps=torch.tensor([16] * batch_size, dtype=dtype),
        condition_mask=torch.randn(
            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
        ),
    )
    shapes = dict(
        hidden_states={
            0: "batch_size",
            2: "latent_frames",
            3: "latent_height",
            4: "latent_width",
        },
        timestep={0: "batch_size"},
        encoder_hidden_states={0: "batch_size"},
        padding_mask={0: "batch_size", 2: "height", 3: "width"},
        fps={0: "batch_size"},
        condition_mask={
            0: "batch_size",
            2: "latent_frames",
            3: "latent_height",
            4: "latent_width",
        },
    )
    res = dict(inputs=inputs, dynamic_shapes=shapes)

    if add_second_input:
        assert (
            add_second_input > 0
        ), f"Not implemented for add_second_input={add_second_input}."
        res["inputs2"] = get_inputs(
            model=model,
            config=config,
            text_embed_dim=text_embed_dim,
            latent_channels=latent_channels,
            batch_size=batch_size,
            image_height=image_height,
            image_width=image_width,
            latent_frames=latent_frames,
            text_maxlen=text_maxlen,
            add_second_input=0,
            **kwargs,
        )["inputs"]
    return res


def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
    """
    Inputs kwargs.

    If the configuration is None, the function selects typical dimensions.
    """
    if config is not None:
        check_hasattr(config, "in_channels", "text_embed_dim")
    kwargs = dict(
        text_embed_dim=1024 if config is None else config.text_embed_dim,
        latent_channels=16 if config is None else config.in_channels - 1,
        batch_size=1,
        image_height=8 * 50,
        image_width=8 * 80,
        latent_frames=1,
        text_maxlen=512,
    )
    return kwargs, get_inputs
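
Same protocol here; a minimal sketch under the same assumptions (onnx-diagnostic installed, the config a hypothetical stand-in exposing the two attributes random_input_kwargs checks). It illustrates the 8x spatial reduction between image and latent sizes, and the in_channels - 1 convention, which presumably reserves one input channel for condition_mask.

# Minimal usage sketch, not part of the package.
from types import SimpleNamespace

from onnx_diagnostic.tasks.image_to_video import random_input_kwargs

# Hypothetical stand-in config with the two attributes checked above.
config = SimpleNamespace(in_channels=17, text_embed_dim=1024)
kwargs, fct = random_input_kwargs(config)  # fct is get_inputs

# model is only forwarded on the add_second_input path, so None suffices here.
data = fct(None, config, add_second_input=0, **kwargs)
# image_height=8*50 and image_width=8*80 shrink by the VAE factor of 8;
# latent_channels = in_channels - 1 = 16.
assert data["inputs"]["hidden_states"].shape == (1, 16, 1, 50, 80)
assert data["inputs"]["condition_mask"].shape == (1, 1, 1, 50, 80)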