onnx-diagnostic 0.8.5__py3-none-any.whl → 0.8.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +154 -3
- onnx_diagnostic/ci_models/__init__.py +0 -0
- onnx_diagnostic/ci_models/ci_helpers.py +435 -0
- onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
- onnx_diagnostic/ci_models/export_qwen25_vl.py +568 -0
- onnx_diagnostic/export/api.py +1 -0
- onnx_diagnostic/export/cf_simple_loop_for.py +537 -0
- onnx_diagnostic/export/control_flow_onnx.py +23 -17
- onnx_diagnostic/ext_test_case.py +23 -2
- onnx_diagnostic/helpers/bench_run.py +1 -1
- onnx_diagnostic/helpers/log_helper.py +1 -3
- onnx_diagnostic/helpers/optim_helper.py +116 -0
- onnx_diagnostic/tasks/image_text_to_text.py +15 -5
- onnx_diagnostic/tasks/text2text_generation.py +84 -48
- onnx_diagnostic/tasks/text_generation.py +3 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +44 -2
- onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
- onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +86 -3
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +23 -24
- onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +29 -8
- onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
- onnx_diagnostic/torch_onnx/compare.py +357 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +33 -27
- onnx_diagnostic/export/control_flow.py +0 -214
- onnx_diagnostic/export/control_flow_research.py +0 -140
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.5.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,1062 @@
|
|
|
1
|
+
r"""
|
|
2
|
+
Export visual and embedding parts of microsoft/Phi-4-multimodal-instruct
|
|
3
|
+
========================================================================
|
|
4
|
+
|
|
5
|
+
Requirements
|
|
6
|
+
++++++++++++
|
|
7
|
+
|
|
8
|
+
::
|
|
9
|
+
|
|
10
|
+
git+https://github.com/sdpython/experimental-experiment.git # optional
|
|
11
|
+
backoff
|
|
12
|
+
huggingface_hub
|
|
13
|
+
onnx-diagnostic>=0.8.6
|
|
14
|
+
onnxruntime>=1.23
|
|
15
|
+
peft==0.17.1
|
|
16
|
+
Pillow
|
|
17
|
+
requests
|
|
18
|
+
torch>=2.10 # weekly is better
|
|
19
|
+
tqdm
|
|
20
|
+
transformers==4.48.3
|
|
21
|
+
|
|
22
|
+
.. note::
|
|
23
|
+
|
|
24
|
+
``flash_attn`` must be removed to export if it was installed.
|
|
25
|
+
|
|
26
|
+
Examples
|
|
27
|
+
++++++++
|
|
28
|
+
|
|
29
|
+
.. code-block:: bash
|
|
30
|
+
|
|
31
|
+
python -m onnx_diagnostic.ci_models.export_phi4_mm \
|
|
32
|
+
-m microsoft/Phi-4-multimodal-instruct --device cuda --dtype float16 \
|
|
33
|
+
--exporter custom --pretrained --second-input --part vision
|
|
34
|
+
"""
|
|
35
|
+
|
|
36
|
+
import os
|
|
37
|
+
import pprint
|
|
38
|
+
import sys
|
|
39
|
+
import textwrap
|
|
40
|
+
import time
|
|
41
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
42
|
+
|
|
43
|
+
from .ci_helpers import (
|
|
44
|
+
check_for_discrepancies_and_log_everything_into_a_json_file,
|
|
45
|
+
compute_expected_outputs,
|
|
46
|
+
get_parser,
|
|
47
|
+
get_torch_dtype_from_command_line_args,
|
|
48
|
+
simplify_model_id_for_a_filename,
|
|
49
|
+
zip_model_and_data_into_a_single_file,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def get_patches_transformers():
|
|
54
|
+
import re
|
|
55
|
+
from itertools import cycle
|
|
56
|
+
import torch
|
|
57
|
+
import transformers
|
|
58
|
+
|
|
59
|
+
class patched_PreTrainedModel(torch.nn.Module):
|
|
60
|
+
_PATCHES_ = ["get_expanded_tied_weights_keys"]
|
|
61
|
+
_PATCHED_CLASS_ = transformers.modeling_utils.PreTrainedModel
|
|
62
|
+
|
|
63
|
+
def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
|
|
64
|
+
if all_submodels:
|
|
65
|
+
expanded_tied_weights = {}
|
|
66
|
+
for prefix, submodule in self.named_modules(remove_duplicate=False):
|
|
67
|
+
if isinstance(submodule, transformers.modeling_utils.PreTrainedModel):
|
|
68
|
+
submodel_tied_weights = submodule.get_expanded_tied_weights_keys(
|
|
69
|
+
all_submodels=False
|
|
70
|
+
)
|
|
71
|
+
if prefix != "":
|
|
72
|
+
submodel_tied_weights = {
|
|
73
|
+
f"{prefix}.{k}": f"{prefix}.{v}"
|
|
74
|
+
for k, v in submodel_tied_weights.items()
|
|
75
|
+
}
|
|
76
|
+
expanded_tied_weights.update(submodel_tied_weights)
|
|
77
|
+
return expanded_tied_weights
|
|
78
|
+
|
|
79
|
+
tied_mapping = self._tied_weights_keys
|
|
80
|
+
if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
|
|
81
|
+
return {}
|
|
82
|
+
elif tied_mapping is None:
|
|
83
|
+
return {}
|
|
84
|
+
common_case_regex = re.compile(r"^[A-Za-z0-9_\.]+(weight)|(bias)$")
|
|
85
|
+
# PATCHED
|
|
86
|
+
if tied_mapping == ["lm_head.weight"]:
|
|
87
|
+
tied_mapping = {"lm_head.weight": "model.embed_tokens.weight"}
|
|
88
|
+
if all(
|
|
89
|
+
common_case_regex.match(k) for k in tied_mapping.keys() | tied_mapping.values()
|
|
90
|
+
):
|
|
91
|
+
return tied_mapping.copy()
|
|
92
|
+
|
|
93
|
+
expanded_tied_weights = {}
|
|
94
|
+
all_param_names = {k for k, _ in self.named_parameters(remove_duplicate=False)} | {
|
|
95
|
+
k for k, _ in self.named_buffers(remove_duplicate=False)
|
|
96
|
+
}
|
|
97
|
+
for target_name, source_name in tied_mapping.items():
|
|
98
|
+
target_name = "^" + target_name
|
|
99
|
+
source_name = "^" + source_name
|
|
100
|
+
|
|
101
|
+
source_params = sorted(
|
|
102
|
+
filter(lambda x: re.search(source_name, x), all_param_names)
|
|
103
|
+
)
|
|
104
|
+
target_params = sorted(
|
|
105
|
+
filter(lambda x: re.search(target_name, x), all_param_names)
|
|
106
|
+
)
|
|
107
|
+
if (
|
|
108
|
+
not len(source_params) > 0
|
|
109
|
+
or not len(target_params) > 0
|
|
110
|
+
or len(target_params) % len(source_params) != 0
|
|
111
|
+
):
|
|
112
|
+
raise ValueError(
|
|
113
|
+
f"There is an issue with your definition of "
|
|
114
|
+
f"`tie_weights_keys` for {source_name}:{target_name}. "
|
|
115
|
+
f"We found {source_params} to tie into {target_params}"
|
|
116
|
+
)
|
|
117
|
+
for target_n, source_n in zip(target_params, cycle(source_params)):
|
|
118
|
+
if source_n in expanded_tied_weights.keys():
|
|
119
|
+
expanded_tied_weights[target_n] = expanded_tied_weights[source_n]
|
|
120
|
+
else:
|
|
121
|
+
expanded_tied_weights[target_n] = source_n
|
|
122
|
+
|
|
123
|
+
return expanded_tied_weights
|
|
124
|
+
|
|
125
|
+
return [patched_PreTrainedModel]
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def get_patches(mod, mod_siglip):
|
|
129
|
+
import torch
|
|
130
|
+
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
|
131
|
+
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
|
|
132
|
+
from ..export.cf_simple_loop_for import simple_loop_for
|
|
133
|
+
|
|
134
|
+
_IMAGE_SPECIAL_TOKEN_ID = mod._IMAGE_SPECIAL_TOKEN_ID
|
|
135
|
+
|
|
136
|
+
class patched_SiglipVisionEmbeddings(torch.nn.Module):
|
|
137
|
+
_PATCHES_ = ["forward"]
|
|
138
|
+
_PATCHED_CLASS_ = mod_siglip.SiglipVisionEmbeddings
|
|
139
|
+
|
|
140
|
+
def forward(
|
|
141
|
+
self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
|
|
142
|
+
) -> torch.Tensor:
|
|
143
|
+
batch_size = pixel_values.size(0)
|
|
144
|
+
|
|
145
|
+
patch_embeds = self.patch_embedding(pixel_values)
|
|
146
|
+
embeddings = patch_embeds.flatten(2).transpose(1, 2)
|
|
147
|
+
|
|
148
|
+
max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
|
|
149
|
+
max_nb_patches_h, max_nb_patches_w = (
|
|
150
|
+
max_im_h // self.patch_size,
|
|
151
|
+
max_im_w // self.patch_size,
|
|
152
|
+
)
|
|
153
|
+
boundaries = torch.arange(
|
|
154
|
+
torch.tensor(1 / self.num_patches_per_side, dtype=pixel_values.dtype),
|
|
155
|
+
torch.tensor(1.0, dtype=pixel_values.dtype),
|
|
156
|
+
torch.tensor(1 / self.num_patches_per_side, dtype=pixel_values.dtype),
|
|
157
|
+
)
|
|
158
|
+
position_ids = torch.full(
|
|
159
|
+
size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
|
|
160
|
+
)
|
|
161
|
+
|
|
162
|
+
# PATCHED: the loop is replaced with scan.
|
|
163
|
+
|
|
164
|
+
def body(p_attn_mask, position_ids_row, boundaries):
|
|
165
|
+
h_len = torch.tensor(1, dtype=boundaries.dtype) / p_attn_mask[:, 0].sum()
|
|
166
|
+
w_len = torch.tensor(1, dtype=boundaries.dtype) / p_attn_mask[0].sum()
|
|
167
|
+
torch._check(h_len.item() > 0)
|
|
168
|
+
fractional_coords_h = torch.arange(
|
|
169
|
+
torch.tensor(0.0, dtype=boundaries.dtype),
|
|
170
|
+
torch.tensor(1 - 1e-6, dtype=boundaries.dtype),
|
|
171
|
+
h_len,
|
|
172
|
+
)
|
|
173
|
+
torch._check(w_len.item() > 0)
|
|
174
|
+
fractional_coords_w = torch.arange(
|
|
175
|
+
torch.tensor(0.0, dtype=boundaries.dtype),
|
|
176
|
+
torch.tensor(1 - 1e-6, dtype=boundaries.dtype),
|
|
177
|
+
w_len,
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
|
|
181
|
+
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
|
|
182
|
+
|
|
183
|
+
pos_ids = (
|
|
184
|
+
bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
|
|
185
|
+
).flatten()
|
|
186
|
+
|
|
187
|
+
row = position_ids_row.clone()
|
|
188
|
+
row[p_attn_mask.view(-1)] = pos_ids
|
|
189
|
+
return [row]
|
|
190
|
+
|
|
191
|
+
position_ids = torch.ops.higher_order.scan(
|
|
192
|
+
body, [], [patch_attention_mask, position_ids], additional_inputs=[boundaries]
|
|
193
|
+
)[0]
|
|
194
|
+
|
|
195
|
+
position_ids = position_ids.to(self.position_embedding.weight.device)
|
|
196
|
+
embeddings = embeddings + self.position_embedding(position_ids)
|
|
197
|
+
return embeddings
|
|
198
|
+
|
|
199
|
+
class patched_SiglipVisionTransformer(torch.nn.Module):
|
|
200
|
+
_PATCHES_ = ["forward"]
|
|
201
|
+
_PATCHED_CLASS_ = mod_siglip.SiglipVisionTransformer
|
|
202
|
+
|
|
203
|
+
def forward(
|
|
204
|
+
self,
|
|
205
|
+
pixel_values,
|
|
206
|
+
patch_attention_mask: Optional[torch.BoolTensor] = None,
|
|
207
|
+
output_attentions: Optional[bool] = None,
|
|
208
|
+
output_hidden_states: Optional[bool] = None,
|
|
209
|
+
return_dict: Optional[bool] = None,
|
|
210
|
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
|
211
|
+
output_attentions = (
|
|
212
|
+
output_attentions
|
|
213
|
+
if output_attentions is not None
|
|
214
|
+
else self.config.output_attentions
|
|
215
|
+
)
|
|
216
|
+
output_hidden_states = (
|
|
217
|
+
output_hidden_states
|
|
218
|
+
if output_hidden_states is not None
|
|
219
|
+
else self.config.output_hidden_states
|
|
220
|
+
)
|
|
221
|
+
return_dict = (
|
|
222
|
+
return_dict if return_dict is not None else self.config.use_return_dict
|
|
223
|
+
)
|
|
224
|
+
|
|
225
|
+
batch_size = pixel_values.size(0)
|
|
226
|
+
if patch_attention_mask is None:
|
|
227
|
+
patch_attention_mask = torch.ones(
|
|
228
|
+
size=(
|
|
229
|
+
batch_size,
|
|
230
|
+
pixel_values.size(2) // self.config.patch_size,
|
|
231
|
+
pixel_values.size(3) // self.config.patch_size,
|
|
232
|
+
),
|
|
233
|
+
dtype=torch.bool,
|
|
234
|
+
device=pixel_values.device,
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
hidden_states = self.embeddings(
|
|
238
|
+
pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
patch_attention_mask = patch_attention_mask.view(batch_size, -1)
|
|
242
|
+
# PATCHED: skip the test
|
|
243
|
+
# if not torch.any(~patch_attention_mask):
|
|
244
|
+
# attention_mask = None
|
|
245
|
+
# else:
|
|
246
|
+
# attention_mask = (
|
|
247
|
+
# _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
|
|
248
|
+
## if not self.config._flash_attn_2_enabled
|
|
249
|
+
# else patch_attention_mask
|
|
250
|
+
# )
|
|
251
|
+
attention_mask = _prepare_4d_attention_mask(
|
|
252
|
+
patch_attention_mask, hidden_states.dtype
|
|
253
|
+
)
|
|
254
|
+
|
|
255
|
+
encoder_outputs = self.encoder(
|
|
256
|
+
inputs_embeds=hidden_states,
|
|
257
|
+
attention_mask=attention_mask,
|
|
258
|
+
output_attentions=output_attentions,
|
|
259
|
+
output_hidden_states=output_hidden_states,
|
|
260
|
+
return_dict=return_dict,
|
|
261
|
+
)
|
|
262
|
+
|
|
263
|
+
last_hidden_state = encoder_outputs[0]
|
|
264
|
+
last_hidden_state = self.post_layernorm(last_hidden_state)
|
|
265
|
+
|
|
266
|
+
pooled_output = self.head(
|
|
267
|
+
hidden_state=last_hidden_state,
|
|
268
|
+
attention_mask=patch_attention_mask,
|
|
269
|
+
)
|
|
270
|
+
|
|
271
|
+
if not return_dict:
|
|
272
|
+
return (last_hidden_state, pooled_output, *encoder_outputs[1:])
|
|
273
|
+
|
|
274
|
+
return BaseModelOutputWithPooling(
|
|
275
|
+
last_hidden_state=last_hidden_state,
|
|
276
|
+
pooler_output=pooled_output,
|
|
277
|
+
hidden_states=encoder_outputs.hidden_states,
|
|
278
|
+
attentions=encoder_outputs.attentions,
|
|
279
|
+
)
|
|
280
|
+
|
|
281
|
+
class patched_Phi4MMImageEmbedding(torch.nn.Module):
|
|
282
|
+
_PATCHES_ = ["forward"]
|
|
283
|
+
_PATCHED_CLASS_ = mod.Phi4MMImageEmbedding
|
|
284
|
+
|
|
285
|
+
def forward(
|
|
286
|
+
self,
|
|
287
|
+
input_ids: torch.LongTensor,
|
|
288
|
+
input_embeds: torch.FloatTensor,
|
|
289
|
+
image_sizes=None,
|
|
290
|
+
**kwargs,
|
|
291
|
+
) -> torch.FloatTensor:
|
|
292
|
+
|
|
293
|
+
if isinstance(input_ids, tuple):
|
|
294
|
+
input_ids, input_embeds = input_ids
|
|
295
|
+
|
|
296
|
+
img_embeds = input_embeds
|
|
297
|
+
if image_sizes is None and "image_sizes" in kwargs:
|
|
298
|
+
image_sizes = kwargs["image_sizes"]
|
|
299
|
+
img_sizes = image_sizes
|
|
300
|
+
|
|
301
|
+
if self.img_features is not None:
|
|
302
|
+
img_embeds = self.img_features.clone()
|
|
303
|
+
self.img_features = None
|
|
304
|
+
|
|
305
|
+
if self.img_sizes is not None:
|
|
306
|
+
img_sizes = self.img_sizes
|
|
307
|
+
|
|
308
|
+
dtype = self.img_processor.embeddings.patch_embedding.weight.dtype
|
|
309
|
+
if img_embeds is not None:
|
|
310
|
+
img_embeds = img_embeds.to(dtype)
|
|
311
|
+
|
|
312
|
+
if self.image_attention_mask is not None:
|
|
313
|
+
image_attention_mask = self.image_attention_mask.clone()
|
|
314
|
+
self.image_attention_mask = None
|
|
315
|
+
elif "image_attention_mask" in kwargs:
|
|
316
|
+
image_attention_mask = kwargs["image_attention_mask"]
|
|
317
|
+
else:
|
|
318
|
+
image_attention_mask = None
|
|
319
|
+
input_shape = input_ids.size()
|
|
320
|
+
input_ids = input_ids.view(-1, input_shape[-1])
|
|
321
|
+
|
|
322
|
+
with torch.no_grad():
|
|
323
|
+
positions_tuple = torch.nonzero(
|
|
324
|
+
input_ids == _IMAGE_SPECIAL_TOKEN_ID, as_tuple=True
|
|
325
|
+
)
|
|
326
|
+
|
|
327
|
+
select = False
|
|
328
|
+
hd_transform = False
|
|
329
|
+
|
|
330
|
+
if isinstance(self.img_projection, torch.nn.Sequential):
|
|
331
|
+
target_device = self.img_projection[0].bias.device
|
|
332
|
+
else:
|
|
333
|
+
target_device = self.img_projection.bias.device
|
|
334
|
+
|
|
335
|
+
# PATCHED: Let's assume it is always true.
|
|
336
|
+
if True: # len(positions.tolist()) > 0:
|
|
337
|
+
if self.use_hd_transform and img_sizes is not None:
|
|
338
|
+
hd_transform = True
|
|
339
|
+
bs = img_embeds.shape[0]
|
|
340
|
+
if image_attention_mask is not None:
|
|
341
|
+
img_features = self.get_img_features(
|
|
342
|
+
img_embeds.flatten(0, 1),
|
|
343
|
+
attention_mask=image_attention_mask.type(torch.BoolTensor)
|
|
344
|
+
.flatten(0, 1)
|
|
345
|
+
.to(target_device),
|
|
346
|
+
)
|
|
347
|
+
else:
|
|
348
|
+
img_features = self.get_img_features(img_embeds.flatten(0, 1))
|
|
349
|
+
|
|
350
|
+
base_resolution = self.crop_size
|
|
351
|
+
base_feat_height_reduction = self.base_feat_height_reduction
|
|
352
|
+
|
|
353
|
+
base_feat_height = base_feat_width = torch.sym_int(
|
|
354
|
+
img_features.shape[1] ** 0.5
|
|
355
|
+
)
|
|
356
|
+
img_features = img_features.view(
|
|
357
|
+
bs, -1, base_feat_height * base_feat_width, self.image_dim_out
|
|
358
|
+
)
|
|
359
|
+
C = self.image_dim_out
|
|
360
|
+
H = base_feat_height
|
|
361
|
+
|
|
362
|
+
if isinstance(img_sizes, torch.Tensor):
|
|
363
|
+
img_sizes = img_sizes.view(-1, 2)
|
|
364
|
+
else:
|
|
365
|
+
raise NotImplementedError
|
|
366
|
+
select = True
|
|
367
|
+
|
|
368
|
+
hidden_states = kwargs["wte"](input_ids)
|
|
369
|
+
|
|
370
|
+
assert select
|
|
371
|
+
if hd_transform:
|
|
372
|
+
|
|
373
|
+
def body_fn(
|
|
374
|
+
_bs,
|
|
375
|
+
img_features,
|
|
376
|
+
img_sizes,
|
|
377
|
+
image_attention_mask,
|
|
378
|
+
cst_shape_CH,
|
|
379
|
+
glb_GN,
|
|
380
|
+
sub_GN,
|
|
381
|
+
proj_0_weight,
|
|
382
|
+
proj_0_bias,
|
|
383
|
+
proj_1_weight,
|
|
384
|
+
proj_1_bias,
|
|
385
|
+
base_resolution=None,
|
|
386
|
+
base_feat_height_reduction=None,
|
|
387
|
+
base_feat_height=None,
|
|
388
|
+
base_feat_width=None,
|
|
389
|
+
):
|
|
390
|
+
# oddly, it seems impossible to write img_sizes[_bs.item()]
|
|
391
|
+
# it needs img_sizes[_bs.item() : (_bs + 1).item()][0]
|
|
392
|
+
row = img_sizes[_bs.item() : (_bs + 1).item()]
|
|
393
|
+
row = row[0]
|
|
394
|
+
h, w = row[0], row[1]
|
|
395
|
+
h = h // base_resolution
|
|
396
|
+
w = w // base_resolution
|
|
397
|
+
B_ = h * w
|
|
398
|
+
C, H = cst_shape_CH.shape
|
|
399
|
+
|
|
400
|
+
# 1 x (24x24) x 1024
|
|
401
|
+
global_img_feature = img_features[_bs.item() : (_bs + 1).item(), :1][0]
|
|
402
|
+
|
|
403
|
+
# 1 x 12 x 12 x 4096
|
|
404
|
+
glb_img = (
|
|
405
|
+
global_img_feature.reshape(1, H, H, C)
|
|
406
|
+
.reshape(
|
|
407
|
+
1,
|
|
408
|
+
H // base_feat_height_reduction,
|
|
409
|
+
base_feat_height_reduction,
|
|
410
|
+
H // base_feat_height_reduction,
|
|
411
|
+
base_feat_height_reduction,
|
|
412
|
+
C,
|
|
413
|
+
)
|
|
414
|
+
.contiguous()
|
|
415
|
+
.permute(0, 1, 3, 2, 4, 5)
|
|
416
|
+
.reshape(
|
|
417
|
+
1,
|
|
418
|
+
H // base_feat_height_reduction,
|
|
419
|
+
H // base_feat_height_reduction,
|
|
420
|
+
base_feat_height_reduction * base_feat_height_reduction * C,
|
|
421
|
+
)
|
|
422
|
+
.contiguous()
|
|
423
|
+
)
|
|
424
|
+
temp_glb_GN = sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1)
|
|
425
|
+
|
|
426
|
+
# 1 x 156 x 4096
|
|
427
|
+
glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
|
|
428
|
+
1, -1, base_feat_height_reduction * base_feat_height_reduction * C
|
|
429
|
+
)
|
|
430
|
+
|
|
431
|
+
# (max_num_crops-1) x (12x12) x C
|
|
432
|
+
sub_img = img_features[_bs.item() : (_bs + 1).item(), 1:][0]
|
|
433
|
+
# 16x574x1024
|
|
434
|
+
# get rid of padding sub_img
|
|
435
|
+
sub_img = sub_img[: B_.item()]
|
|
436
|
+
|
|
437
|
+
# (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024)
|
|
438
|
+
# -> (num_crops, 12*12, 4*1024)
|
|
439
|
+
sub_img = (
|
|
440
|
+
sub_img.reshape(B_.item(), H, H, C)
|
|
441
|
+
.reshape(
|
|
442
|
+
B_.item(),
|
|
443
|
+
H // base_feat_height_reduction,
|
|
444
|
+
base_feat_height_reduction,
|
|
445
|
+
H // base_feat_height_reduction,
|
|
446
|
+
base_feat_height_reduction,
|
|
447
|
+
C,
|
|
448
|
+
)
|
|
449
|
+
.contiguous()
|
|
450
|
+
.permute(0, 1, 3, 2, 4, 5)
|
|
451
|
+
.reshape(
|
|
452
|
+
B_.item(),
|
|
453
|
+
-1,
|
|
454
|
+
base_feat_height_reduction * base_feat_height_reduction * C,
|
|
455
|
+
)
|
|
456
|
+
.contiguous()
|
|
457
|
+
)
|
|
458
|
+
sub_img = (
|
|
459
|
+
sub_img.reshape(
|
|
460
|
+
1,
|
|
461
|
+
h.item(),
|
|
462
|
+
w.item(),
|
|
463
|
+
base_feat_height // base_feat_height_reduction,
|
|
464
|
+
base_feat_width // base_feat_height_reduction,
|
|
465
|
+
-1,
|
|
466
|
+
)
|
|
467
|
+
.permute(0, 1, 3, 2, 4, 5)
|
|
468
|
+
.reshape(
|
|
469
|
+
1,
|
|
470
|
+
(h * base_feat_height // base_feat_height_reduction).item(),
|
|
471
|
+
(w * base_feat_width // base_feat_height_reduction).item(),
|
|
472
|
+
base_feat_height_reduction * base_feat_height_reduction * C,
|
|
473
|
+
)
|
|
474
|
+
)
|
|
475
|
+
|
|
476
|
+
reshaped_image_attention_mask = (
|
|
477
|
+
image_attention_mask[
|
|
478
|
+
_bs.item() : (_bs + 1).item(), 1 : (B_ + 1).item(), 0::2, 0::2
|
|
479
|
+
][0]
|
|
480
|
+
.reshape(
|
|
481
|
+
1,
|
|
482
|
+
h.item(),
|
|
483
|
+
w.item(),
|
|
484
|
+
base_feat_height // base_feat_height_reduction,
|
|
485
|
+
base_feat_width // base_feat_height_reduction,
|
|
486
|
+
)
|
|
487
|
+
.permute(0, 1, 3, 2, 4)
|
|
488
|
+
.reshape(
|
|
489
|
+
1,
|
|
490
|
+
(h * base_feat_height // base_feat_height_reduction).item(),
|
|
491
|
+
(w * base_feat_width // base_feat_height_reduction).item(),
|
|
492
|
+
)
|
|
493
|
+
)
|
|
494
|
+
useful_height = (
|
|
495
|
+
reshaped_image_attention_mask[0, :, 0].sum().to(torch.int64).item()
|
|
496
|
+
)
|
|
497
|
+
useful_width = (
|
|
498
|
+
reshaped_image_attention_mask[0, 0, :].sum().to(torch.int64).item()
|
|
499
|
+
)
|
|
500
|
+
# the module cannot be extracted from here
|
|
501
|
+
sub_img = sub_img[:, :useful_height, :useful_width]
|
|
502
|
+
temp_sub_GN = sub_GN.repeat(1, useful_height, 1, 1)
|
|
503
|
+
# temp_len = (
|
|
504
|
+
# image_attention_mask[_bs, : B_ + 1, 0::2, 0::2]
|
|
505
|
+
# .sum()
|
|
506
|
+
# .to(torch.int64)
|
|
507
|
+
# .item()
|
|
508
|
+
# + (useful_height + 1)
|
|
509
|
+
# + base_feat_height // base_feat_height_reduction
|
|
510
|
+
# )
|
|
511
|
+
|
|
512
|
+
sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
|
|
513
|
+
1, -1, base_feat_height_reduction * base_feat_height_reduction * C
|
|
514
|
+
)
|
|
515
|
+
# (1, num_img_tokens, 1024*4)
|
|
516
|
+
|
|
517
|
+
# glb + sub
|
|
518
|
+
# glb_sub
|
|
519
|
+
# output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
|
|
520
|
+
# sub_glb
|
|
521
|
+
_output_img = torch.cat([sub_img, glb_GN, glb_img], dim=1)
|
|
522
|
+
# output_len.append(temp_len)
|
|
523
|
+
proj = torch.nn.functional.linear(_output_img, proj_0_weight, proj_0_bias)
|
|
524
|
+
proj = torch.nn.functional.gelu(proj)
|
|
525
|
+
proj = torch.nn.functional.linear(proj, proj_1_weight, proj_1_bias)
|
|
526
|
+
return (proj,)
|
|
527
|
+
|
|
528
|
+
def local_body_fn(
|
|
529
|
+
n_iter,
|
|
530
|
+
img_features,
|
|
531
|
+
img_sizes,
|
|
532
|
+
image_attention_mask,
|
|
533
|
+
cst_shape_CH,
|
|
534
|
+
glb_GN,
|
|
535
|
+
sub_GN,
|
|
536
|
+
proj_0_weight,
|
|
537
|
+
proj_0_bias,
|
|
538
|
+
proj_1_weight,
|
|
539
|
+
proj_1_bias,
|
|
540
|
+
):
|
|
541
|
+
return body_fn(
|
|
542
|
+
n_iter,
|
|
543
|
+
img_features,
|
|
544
|
+
img_sizes,
|
|
545
|
+
image_attention_mask,
|
|
546
|
+
cst_shape_CH,
|
|
547
|
+
glb_GN,
|
|
548
|
+
sub_GN,
|
|
549
|
+
proj_0_weight,
|
|
550
|
+
proj_0_bias,
|
|
551
|
+
proj_1_weight,
|
|
552
|
+
proj_1_bias,
|
|
553
|
+
base_resolution=base_resolution,
|
|
554
|
+
base_feat_height_reduction=base_feat_height_reduction,
|
|
555
|
+
base_feat_height=base_feat_height,
|
|
556
|
+
base_feat_width=base_feat_width,
|
|
557
|
+
)
|
|
558
|
+
|
|
559
|
+
tmp = torch.arange(bs + 1).max()
|
|
560
|
+
glb_GN = self.glb_GN
|
|
561
|
+
sub_GN = self.sub_GN
|
|
562
|
+
cst_shape_CH = torch.zeros((C, H), dtype=torch.int32)
|
|
563
|
+
|
|
564
|
+
merged_img_set_tensor = simple_loop_for(
|
|
565
|
+
tmp,
|
|
566
|
+
local_body_fn,
|
|
567
|
+
(
|
|
568
|
+
img_features,
|
|
569
|
+
img_sizes,
|
|
570
|
+
image_attention_mask,
|
|
571
|
+
cst_shape_CH,
|
|
572
|
+
glb_GN,
|
|
573
|
+
sub_GN,
|
|
574
|
+
self.img_projection[0].weight,
|
|
575
|
+
self.img_projection[0].bias,
|
|
576
|
+
# self.img_projection[1] is GELU
|
|
577
|
+
self.img_projection[2].weight,
|
|
578
|
+
self.img_projection[2].bias,
|
|
579
|
+
),
|
|
580
|
+
[1],
|
|
581
|
+
)
|
|
582
|
+
torch._check(isinstance(merged_img_set_tensor, torch.Tensor))
|
|
583
|
+
merged_img_set_tensor = merged_img_set_tensor.squeeze(0)
|
|
584
|
+
|
|
585
|
+
# merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
|
|
586
|
+
merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(
|
|
587
|
+
hidden_states.device
|
|
588
|
+
)
|
|
589
|
+
with torch.autocast(device_type=hidden_states.device.type, enabled=False):
|
|
590
|
+
new_hidden_states = hidden_states.index_put(
|
|
591
|
+
indices=positions_tuple,
|
|
592
|
+
values=merged_img_set_tensor,
|
|
593
|
+
accumulate=False,
|
|
594
|
+
)
|
|
595
|
+
hidden_states = new_hidden_states
|
|
596
|
+
else:
|
|
597
|
+
raise NotImplementedError
|
|
598
|
+
|
|
599
|
+
if self.drop is not None:
|
|
600
|
+
hidden_states = self.drop(hidden_states)
|
|
601
|
+
|
|
602
|
+
return hidden_states
|
|
603
|
+
|
|
604
|
+
return [
|
|
605
|
+
*get_patches_transformers(),
|
|
606
|
+
patched_Phi4MMImageEmbedding,
|
|
607
|
+
patched_SiglipVisionEmbeddings,
|
|
608
|
+
patched_SiglipVisionTransformer,
|
|
609
|
+
]
|
|
610
|
+
|
|
611
|
+
|
|
612
|
+
def get_inputs_for_part(
|
|
613
|
+
model_id: str,
|
|
614
|
+
part: str,
|
|
615
|
+
torch_dtype: "torch.dtype", # noqa: F821
|
|
616
|
+
device: str,
|
|
617
|
+
second_input: bool,
|
|
618
|
+
) -> Tuple[Dict[str, "torch.Tensor"], List[Dict[str, "torch.Tensor"]]]: # noqa: F821
|
|
619
|
+
if part == "vision":
|
|
620
|
+
import requests
|
|
621
|
+
from PIL import Image
|
|
622
|
+
from transformers import AutoProcessor
|
|
623
|
+
|
|
624
|
+
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
|
|
625
|
+
user_prompt = "<|user|>\n"
|
|
626
|
+
assistant_prompt = "<|assistant|>\n"
|
|
627
|
+
prompt_suffix = "<|end|>\n"
|
|
628
|
+
prompt = (
|
|
629
|
+
f"{user_prompt}<|image_1|>\n<|image_2|>\n"
|
|
630
|
+
f"What is shown in these four images?{prompt_suffix}{assistant_prompt}"
|
|
631
|
+
)
|
|
632
|
+
|
|
633
|
+
root = os.path.join(os.path.dirname(__file__), "..", "..", "_small_data")
|
|
634
|
+
# "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
635
|
+
url = os.path.join(root, "American_Flamingo_JG.jpg")
|
|
636
|
+
image_1 = (
|
|
637
|
+
Image.open(requests.get(url, stream=True).raw)
|
|
638
|
+
if url.startswith("https")
|
|
639
|
+
else Image.open(url)
|
|
640
|
+
)
|
|
641
|
+
# "https://wallpaper.dog/large/10809054.jpg"
|
|
642
|
+
url = os.path.join(root, "RedcrestedTuraco.jpg")
|
|
643
|
+
image_4 = (
|
|
644
|
+
Image.open(requests.get(url, stream=True).raw)
|
|
645
|
+
if url.startswith("https")
|
|
646
|
+
else Image.open(url)
|
|
647
|
+
)
|
|
648
|
+
|
|
649
|
+
images = [image_1, image_4]
|
|
650
|
+
inputs = processor(prompt, images=images, return_tensors="pt").to(device)
|
|
651
|
+
export_inputs = dict(
|
|
652
|
+
input_ids=inputs["input_ids"].to(device),
|
|
653
|
+
input_image_embeds=inputs["input_image_embeds"].to(torch_dtype).to(device),
|
|
654
|
+
image_attention_mask=inputs["image_attention_mask"].to(torch_dtype).to(device),
|
|
655
|
+
image_sizes=inputs["image_sizes"].to(device),
|
|
656
|
+
)
|
|
657
|
+
assert (
|
|
658
|
+
export_inputs["input_image_embeds"].shape[-2] >= 28
|
|
659
|
+
and export_inputs["input_image_embeds"].shape[-1] >= 28
|
|
660
|
+
), (
|
|
661
|
+
f"required by the exported program but shape is "
|
|
662
|
+
f"{export_inputs['input_image_embeds'].shape}"
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
other_inputs = []
|
|
666
|
+
if second_input:
|
|
667
|
+
prompt = (
|
|
668
|
+
f"{user_prompt}<|image_1|>\n<|image_2|>\n<|image_3|>\n<|image_4|>\n"
|
|
669
|
+
f"What is shown in these four images?{prompt_suffix}{assistant_prompt}"
|
|
670
|
+
)
|
|
671
|
+
url = "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000"
|
|
672
|
+
image_2 = Image.open(requests.get(url, stream=True).raw)
|
|
673
|
+
url = (
|
|
674
|
+
"https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain"
|
|
675
|
+
)
|
|
676
|
+
image_3 = Image.open(requests.get(url, stream=True).raw)
|
|
677
|
+
|
|
678
|
+
images = [image_1, image_2, image_3, image_4]
|
|
679
|
+
inputs = processor(prompt, images=images, return_tensors="pt").to(device)
|
|
680
|
+
other_inputs = [
|
|
681
|
+
dict(
|
|
682
|
+
input_ids=inputs["input_ids"].to(device),
|
|
683
|
+
input_image_embeds=inputs["input_image_embeds"].to(torch_dtype).to(device),
|
|
684
|
+
image_attention_mask=inputs["image_attention_mask"]
|
|
685
|
+
.to(torch_dtype)
|
|
686
|
+
.to(device),
|
|
687
|
+
image_sizes=inputs["image_sizes"].to(device),
|
|
688
|
+
)
|
|
689
|
+
]
|
|
690
|
+
return export_inputs, other_inputs
|
|
691
|
+
|
|
692
|
+
raise NotImplementedError(f"No inputs yet implement for part={part!r}")
|
|
693
|
+
|
|
694
|
+
|
|
695
|
+
def main(
|
|
696
|
+
model_id: str = "microsoft/Phi-4-multimodal-instruct",
|
|
697
|
+
device: str = "cpu",
|
|
698
|
+
dtype: str = "float32",
|
|
699
|
+
exporter: str = "onnx-dynamo",
|
|
700
|
+
pretrained: bool = True,
|
|
701
|
+
second_input: bool = True,
|
|
702
|
+
make_zip: bool = False,
|
|
703
|
+
output_folder: str = "dump_models",
|
|
704
|
+
existing_onnx: str | None = None,
|
|
705
|
+
part: str = "vision",
|
|
706
|
+
atol: float = 2,
|
|
707
|
+
mismatch01: float = 0.01,
|
|
708
|
+
profile_exporter: bool = False,
|
|
709
|
+
):
|
|
710
|
+
"""
|
|
711
|
+
Exports model microsoft/Phi-4-multimodal-instruct or pieces of it.
|
|
712
|
+
The script applies as well to other models based on the same architecture.
|
|
713
|
+
|
|
714
|
+
The function saves everything on disk. It does not generate new inputs
|
|
715
|
+
on the second run but reuses the saved ones. Same goes for the expected
|
|
716
|
+
outputs which are also saved on disk.
|
|
717
|
+
|
|
718
|
+
:param model_id: model id
|
|
719
|
+
:param device: device
|
|
720
|
+
:param dtype: dtype
|
|
721
|
+
:param exporter: exporter to use
|
|
722
|
+
:param pretrained: pretrained=False is usually used to test
|
|
723
|
+
:param second_input: checks discrepancies on more examples
|
|
724
|
+
:param make_zip: creates a zip at the end
|
|
725
|
+
:param output_folder: output folder
|
|
726
|
+
:param part: "" to export the whole model, ``"vision"`` for vision part,
|
|
727
|
+
...
|
|
728
|
+
:param atol: raises an exception if tolerance is above that threshold
|
|
729
|
+
:param mismatch01: raises an exception if the ratio of mismatches
|
|
730
|
+
is above that threshold
|
|
731
|
+
:param profile_exporter: profiles the exporter
|
|
732
|
+
"""
|
|
733
|
+
prefix = simplify_model_id_for_a_filename(model_id)
|
|
734
|
+
basename = os.path.join(
|
|
735
|
+
output_folder, f"model.{prefix}.{part}.{device}.{dtype}.{exporter}"
|
|
736
|
+
)
|
|
737
|
+
filename = f"{basename}.onnx"
|
|
738
|
+
stat_file = f"{basename}.stats"
|
|
739
|
+
|
|
740
|
+
print("------------------------------------------------------------------")
|
|
741
|
+
print(f"-- model_id={model_id}")
|
|
742
|
+
print(f"-- part={part}")
|
|
743
|
+
print(f"-- device={device}")
|
|
744
|
+
print(f"-- dtype={dtype}")
|
|
745
|
+
print(f"-- exporter={exporter}")
|
|
746
|
+
print(f"-- pretrained={pretrained}")
|
|
747
|
+
print(f"-- second_input={second_input}")
|
|
748
|
+
print(f"-- make_zip={make_zip}")
|
|
749
|
+
print(f"-- output_folder={output_folder}")
|
|
750
|
+
print(f"-- atol={atol}")
|
|
751
|
+
print(f"-- mismatch01={mismatch01}")
|
|
752
|
+
print(f"-- profile_exporter={profile_exporter}")
|
|
753
|
+
print("------------------------------------------------------------------")
|
|
754
|
+
print(f"-- prefix={prefix}")
|
|
755
|
+
print(f"-- export in {filename!r}")
|
|
756
|
+
print("------------------------------------------------------------------")
|
|
757
|
+
|
|
758
|
+
if os.path.exists(stat_file) and not existing_onnx:
|
|
759
|
+
print(f"-- skipping because {stat_file!r} already exists")
|
|
760
|
+
return
|
|
761
|
+
|
|
762
|
+
print("-- import torch and others")
|
|
763
|
+
import torch
|
|
764
|
+
from transformers import AutoConfig, AutoModelForCausalLM
|
|
765
|
+
from ..helpers import string_type, string_diff, max_diff
|
|
766
|
+
from ..torch_export_patches import torch_export_patches
|
|
767
|
+
from ..torch_export_patches.patch_details import PatchDetails
|
|
768
|
+
from ..torch_export_patches.patch_inputs import use_dyn_not_str
|
|
769
|
+
from ..export.api import to_onnx
|
|
770
|
+
|
|
771
|
+
if output_folder and output_folder != ".":
|
|
772
|
+
os.makedirs(output_folder, exist_ok=True)
|
|
773
|
+
|
|
774
|
+
print(f"-- create model {model_id!r}")
|
|
775
|
+
print(
|
|
776
|
+
f"-- device={device!r}, dtype={dtype!r}, exporter={exporter!r}, "
|
|
777
|
+
f"pretrained={pretrained!r}"
|
|
778
|
+
)
|
|
779
|
+
torch_dtype = get_torch_dtype_from_command_line_args(dtype)
|
|
780
|
+
|
|
781
|
+
if pretrained:
|
|
782
|
+
print("-- pretrained model")
|
|
783
|
+
config = AutoConfig.from_pretrained(
|
|
784
|
+
model_id, trust_remote_code=True, attn_implementation="sdpa"
|
|
785
|
+
)
|
|
786
|
+
model = AutoModelForCausalLM.from_pretrained(
|
|
787
|
+
model_id,
|
|
788
|
+
config=config,
|
|
789
|
+
trust_remote_code=True,
|
|
790
|
+
torch_dtype=torch_dtype,
|
|
791
|
+
device_map=device,
|
|
792
|
+
attn_implementation="sdpa",
|
|
793
|
+
).eval()
|
|
794
|
+
data = dict(model=model)
|
|
795
|
+
else:
|
|
796
|
+
print("-- random model")
|
|
797
|
+
config = AutoConfig.from_pretrained(
|
|
798
|
+
model_id, trust_remote_code=True, attn_implementation="sdpa"
|
|
799
|
+
)
|
|
800
|
+
config.attn_implementation = "sdpa"
|
|
801
|
+
config._attn_implementation = "sdpa"
|
|
802
|
+
config.num_hidden_layers = 2
|
|
803
|
+
model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
|
|
804
|
+
data = dict(model=model)
|
|
805
|
+
|
|
806
|
+
main_mod_name = model.__module__
|
|
807
|
+
assert (
|
|
808
|
+
main_mod_name in sys.modules
|
|
809
|
+
), f"Unable to find {main_mod_name!r} in {pprint.pformat(list(sys.modules))}"
|
|
810
|
+
main_mod = sys.modules[main_mod_name]
|
|
811
|
+
model = model.to(device).to(getattr(torch, dtype))
|
|
812
|
+
mod_siglip_name = model.model.embed_tokens_extend.image_embed.img_processor.__module__
|
|
813
|
+
assert (
|
|
814
|
+
mod_siglip_name in sys.modules
|
|
815
|
+
), f"Unable to find {mod_siglip_name!r} in {pprint.pformat(list(sys.modules))}"
|
|
816
|
+
mod_siglip = sys.modules[mod_siglip_name]
|
|
817
|
+
|
|
818
|
+
print(f"-- config._attn_implementation={model.config._attn_implementation}")
|
|
819
|
+
print(f"-- model.dtype={model.dtype}")
|
|
820
|
+
print(f"-- model.device={model.device}")
|
|
821
|
+
|
|
822
|
+
export_inputs, other_inputs = None, None
|
|
823
|
+
if not part:
|
|
824
|
+
# used to unit test
|
|
825
|
+
from ..helpers.torch_helper import to_any
|
|
826
|
+
|
|
827
|
+
assert "inputs" in data, f"key 'inputs' is missing from data (available {set(data)})"
|
|
828
|
+
model_to_export = data["model"]
|
|
829
|
+
model_to_export.eval()
|
|
830
|
+
export_inputs = to_any(to_any(data["inputs"], device), torch_dtype)
|
|
831
|
+
other_inputs = [
|
|
832
|
+
v for k, v in data.items() if k.startswith("inputs_") if k != "inputs_prompt"
|
|
833
|
+
]
|
|
834
|
+
dynamic_shapes = data["dynamic_shapes"]
|
|
835
|
+
assert other_inputs, f"No other inputs was found from data (available {set(data)})"
|
|
836
|
+
|
|
837
|
+
elif part == "vision":
|
|
838
|
+
|
|
839
|
+
class VisionPart(torch.nn.Module):
|
|
840
|
+
def __init__(self, model):
|
|
841
|
+
super().__init__()
|
|
842
|
+
self.model = model
|
|
843
|
+
|
|
844
|
+
def forward(
|
|
845
|
+
self, input_ids, input_image_embeds, image_attention_mask, image_sizes
|
|
846
|
+
):
|
|
847
|
+
torch._check(input_image_embeds.shape[-2] >= 28)
|
|
848
|
+
torch._check(input_image_embeds.shape[-1] >= 28)
|
|
849
|
+
return model.model.embed_tokens_extend.image_embed(
|
|
850
|
+
input_ids=input_ids,
|
|
851
|
+
input_embeds=input_image_embeds,
|
|
852
|
+
image_attention_mask=image_attention_mask,
|
|
853
|
+
image_sizes=image_sizes,
|
|
854
|
+
wte=model.model.embed_tokens,
|
|
855
|
+
)
|
|
856
|
+
|
|
857
|
+
model_to_export = VisionPart(model)
|
|
858
|
+
model_to_export.eval()
|
|
859
|
+
|
|
860
|
+
dynamic_shapes = {
|
|
861
|
+
"input_ids": {1: "seq_length"},
|
|
862
|
+
"input_image_embeds": {
|
|
863
|
+
0: "num_images",
|
|
864
|
+
1: "max_num_crops",
|
|
865
|
+
3: "height",
|
|
866
|
+
4: "width",
|
|
867
|
+
},
|
|
868
|
+
"image_attention_mask": {0: "num_images", 1: "max_num_crops"},
|
|
869
|
+
"image_sizes": {0: "num_images"},
|
|
870
|
+
}
|
|
871
|
+
|
|
872
|
+
else:
|
|
873
|
+
raise NotImplementedError(f"no export yet for part={part!r}")
|
|
874
|
+
|
|
875
|
+
print(f"-- part={part!r}")
|
|
876
|
+
print(f"-- model_to_export={type(model_to_export)}")
|
|
877
|
+
print(f"-- dynamic_shapes={dynamic_shapes}")
|
|
878
|
+
print("-- ############")
|
|
879
|
+
print("-- INPUT/OUTPUT")
|
|
880
|
+
print("-- ############")
|
|
881
|
+
|
|
882
|
+
input_filename = os.path.join(output_folder, f"inputs.{prefix}.{part}.{device}.{dtype}.pt")
|
|
883
|
+
if os.path.exists(input_filename):
|
|
884
|
+
print(f"-- restore inputs from {input_filename!r}")
|
|
885
|
+
data = torch.load(input_filename, weights_only=False)
|
|
886
|
+
export_inputs = data["export_inputs"]
|
|
887
|
+
other_inputs = data["other_inputs"]
|
|
888
|
+
dynamic_shapes = data["dynamic_shapes"]
|
|
889
|
+
elif export_inputs is not None:
|
|
890
|
+
data = dict(
|
|
891
|
+
export_inputs=export_inputs,
|
|
892
|
+
other_inputs=other_inputs,
|
|
893
|
+
dynamic_shapes=dynamic_shapes,
|
|
894
|
+
)
|
|
895
|
+
print(f"-- dump inputs into {input_filename!r}")
|
|
896
|
+
torch.save(data, input_filename)
|
|
897
|
+
else:
|
|
898
|
+
export_inputs, other_inputs = get_inputs_for_part(
|
|
899
|
+
model_id,
|
|
900
|
+
part,
|
|
901
|
+
torch_dtype,
|
|
902
|
+
device,
|
|
903
|
+
second_input,
|
|
904
|
+
)
|
|
905
|
+
data = dict(
|
|
906
|
+
export_inputs=export_inputs,
|
|
907
|
+
other_inputs=other_inputs,
|
|
908
|
+
dynamic_shapes=dynamic_shapes,
|
|
909
|
+
)
|
|
910
|
+
print(f"-- dump inputs into {input_filename!r}")
|
|
911
|
+
torch.save(data, input_filename)
|
|
912
|
+
|
|
913
|
+
print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}")
|
|
914
|
+
print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}")
|
|
915
|
+
print(f"-- dynamic_shapes={dynamic_shapes}")
|
|
916
|
+
output_filename = os.path.join(
|
|
917
|
+
output_folder, f"expected.{prefix}.visual.{device}.{dtype}.pt"
|
|
918
|
+
)
|
|
919
|
+
|
|
920
|
+
print("-- ##################")
|
|
921
|
+
print("-- # EXPECTED_OUTPUTS")
|
|
922
|
+
print("-- ##################")
|
|
923
|
+
|
|
924
|
+
export_expected, *_ = compute_expected_outputs(
|
|
925
|
+
output_filename, model_to_export, input_filename
|
|
926
|
+
)
|
|
927
|
+
|
|
928
|
+
if existing_onnx and os.path.exists(existing_onnx):
|
|
929
|
+
print("-- ######")
|
|
930
|
+
print(f"-- USING EXISTING ONNX {existing_onnx!r}")
|
|
931
|
+
print("-- ######")
|
|
932
|
+
|
|
933
|
+
exporter = existing_onnx
|
|
934
|
+
filename = existing_onnx
|
|
935
|
+
target_opset = None
|
|
936
|
+
else:
|
|
937
|
+
print("-- ######")
|
|
938
|
+
print("-- EXPORT")
|
|
939
|
+
print("-- ######")
|
|
940
|
+
|
|
941
|
+
additional_patches = get_patches(main_mod, mod_siglip)
|
|
942
|
+
|
|
943
|
+
begin = time.perf_counter()
|
|
944
|
+
|
|
945
|
+
target_opset = 22
|
|
946
|
+
|
|
947
|
+
details = PatchDetails()
|
|
948
|
+
with torch_export_patches(
|
|
949
|
+
patch_torch=True, # needed for DynamicDimConstraintPrinter
|
|
950
|
+
patch_sympy=False,
|
|
951
|
+
patch_transformers=True,
|
|
952
|
+
verbose=1,
|
|
953
|
+
stop_if_static=0,
|
|
954
|
+
profile=(f"{basename}.profile.html" if profile_exporter else None),
|
|
955
|
+
custom_patches=additional_patches,
|
|
956
|
+
patch_details=details,
|
|
957
|
+
):
|
|
958
|
+
# check that the patched code still runs and returns the same outputs
|
|
959
|
+
patched_expected = model_to_export(**export_inputs)
|
|
960
|
+
diff = max_diff(export_expected, patched_expected, hist=[0.1, 0.01])
|
|
961
|
+
print(f"-- discrepancies PATCHED/ORIGINAL {string_diff(diff)}")
|
|
962
|
+
assert diff["abs"] < atol, (
|
|
963
|
+
f"Patches do not output the same values\n"
|
|
964
|
+
f"\nexpected={string_type(export_expected, with_shape=True)}"
|
|
965
|
+
f"\n patched={string_type(patched_expected, with_shape=True)}"
|
|
966
|
+
f"\ndiff={string_diff(diff)}"
|
|
967
|
+
)
|
|
968
|
+
if details and not os.path.exists(f"{basename}.patches_details.rst"):
|
|
969
|
+
print("-- builds patch details")
|
|
970
|
+
ep = torch.export.export(
|
|
971
|
+
model_to_export,
|
|
972
|
+
(),
|
|
973
|
+
kwargs=export_inputs,
|
|
974
|
+
dynamic_shapes=use_dyn_not_str(dynamic_shapes),
|
|
975
|
+
)
|
|
976
|
+
patches = details.patches_involded_in_graph(ep.graph)
|
|
977
|
+
report = details.make_report(patches, format="rst")
|
|
978
|
+
with open(f"{basename}.patches_details.rst", "w") as f:
|
|
979
|
+
f.write(report)
|
|
980
|
+
with open(f"{basename}.ep", "w") as f:
|
|
981
|
+
f.write(str(ep))
|
|
982
|
+
with open(f"{basename}.graph", "w") as f:
|
|
983
|
+
f.write(str(ep.graph))
|
|
984
|
+
print("-- done writing patch details")
|
|
985
|
+
|
|
986
|
+
to_onnx(
|
|
987
|
+
model_to_export,
|
|
988
|
+
kwargs=export_inputs,
|
|
989
|
+
dynamic_shapes=dynamic_shapes,
|
|
990
|
+
filename=filename,
|
|
991
|
+
exporter=exporter,
|
|
992
|
+
verbose=1,
|
|
993
|
+
save_ep=None,
|
|
994
|
+
target_opset=target_opset,
|
|
995
|
+
optimize=True,
|
|
996
|
+
)
|
|
997
|
+
export_duration = time.perf_counter() - begin
|
|
998
|
+
|
|
999
|
+
print("-- ###############")
|
|
1000
|
+
print("-- # DISCREPANCIES")
|
|
1001
|
+
print("-- ###############")
|
|
1002
|
+
|
|
1003
|
+
info = {
|
|
1004
|
+
"model_id": model_id,
|
|
1005
|
+
"part": part,
|
|
1006
|
+
"device": device,
|
|
1007
|
+
"dtype": dtype,
|
|
1008
|
+
"exporter": exporter,
|
|
1009
|
+
"pretrained": pretrained,
|
|
1010
|
+
"attention": os.environ.get("QWEN25ATTENTION", "default"),
|
|
1011
|
+
}
|
|
1012
|
+
|
|
1013
|
+
check_for_discrepancies_and_log_everything_into_a_json_file(
|
|
1014
|
+
agg_stat_file=os.path.join(output_folder, "collection_statistics.js"),
|
|
1015
|
+
stat_file=stat_file,
|
|
1016
|
+
export_duration=export_duration,
|
|
1017
|
+
device=device,
|
|
1018
|
+
model_file=filename,
|
|
1019
|
+
cached_inputs=input_filename,
|
|
1020
|
+
cached_expected_outputs=output_filename,
|
|
1021
|
+
main_info=info,
|
|
1022
|
+
atol=atol,
|
|
1023
|
+
mismatch01=mismatch01,
|
|
1024
|
+
)
|
|
1025
|
+
|
|
1026
|
+
if make_zip:
|
|
1027
|
+
print("-- #####")
|
|
1028
|
+
print("-- # ZIP")
|
|
1029
|
+
print("-- #####")
|
|
1030
|
+
zip_model_and_data_into_a_single_file(f"{basename}.zip", filename)
|
|
1031
|
+
|
|
1032
|
+
|
|
1033
|
+
if __name__ == "__main__":
    # Script entry point: builds the command-line parser, parses the
    # arguments and forwards every option to ``main``.
    usage_examples = textwrap.dedent(
        r"""
        Tested command lines::

        python -m onnx_diagnostic.ci_models.export_phi4_mm \
        -m microsoft/Phi-4-multimodal-instruct \
        --device cuda --dtype float16 --exporter custom \
        --pretrained --second-input --part vision
        """
    )
    # NOTE(review): the parser is labelled "qwen25" although this script
    # exports Phi-4 -- looks like a copy/paste leftover, confirm against get_parser.
    cli = get_parser("qwen25", epilog=usage_examples)
    options = cli.parse_args(sys.argv[1:])
    # Map the parsed options onto the keyword arguments of ``main``
    # (``--mid`` and ``--zip`` use different names on the command line).
    main_kwargs = {
        "model_id": options.mid,
        "device": options.device,
        "dtype": options.dtype,
        "exporter": options.exporter,
        "pretrained": options.pretrained,
        "second_input": options.second_input,
        "make_zip": options.zip,
        "output_folder": options.output_folder,
        "existing_onnx": options.existing_onnx,
        "part": options.part,
        "atol": options.atol,
        "mismatch01": options.mismatch01,
        "profile_exporter": options.profile_exporter,
    }
    main(**main_kwargs)
|