onnx-diagnostic 0.8.6__py3-none-any.whl → 0.8.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (29)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +108 -3
  3. onnx_diagnostic/ci_models/ci_helpers.py +12 -7
  4. onnx_diagnostic/ci_models/export_phi4_mm.py +1062 -0
  5. onnx_diagnostic/ci_models/export_qwen25_vl.py +12 -4
  6. onnx_diagnostic/export/api.py +1 -0
  7. onnx_diagnostic/export/cf_simple_loop_for.py +195 -10
  8. onnx_diagnostic/ext_test_case.py +9 -2
  9. onnx_diagnostic/helpers/bench_run.py +1 -1
  10. onnx_diagnostic/helpers/log_helper.py +1 -3
  11. onnx_diagnostic/helpers/optim_helper.py +116 -0
  12. onnx_diagnostic/tasks/image_text_to_text.py +15 -5
  13. onnx_diagnostic/tasks/text2text_generation.py +84 -48
  14. onnx_diagnostic/tasks/text_generation.py +3 -0
  15. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +28 -2
  16. onnx_diagnostic/torch_export_patches/patch_expressions.py +4 -1
  17. onnx_diagnostic/torch_export_patches/patch_module.py +31 -23
  18. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_funnel.py +80 -0
  19. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +12 -1
  20. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +15 -0
  21. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +22 -24
  22. onnx_diagnostic/torch_models/hghub/hub_api.py +11 -0
  23. onnx_diagnostic/torch_models/hghub/hub_data.py +9 -1
  24. onnx_diagnostic/torch_models/hghub/model_inputs.py +24 -19
  25. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/METADATA +1 -1
  26. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/RECORD +29 -26
  27. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/WHEEL +0 -0
  28. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/licenses/LICENSE.txt +0 -0
  29. {onnx_diagnostic-0.8.6.dist-info → onnx_diagnostic-0.8.7.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1062 @@
1
+ r"""
2
+ Export visual and embedding parts of microsoft/Phi-4-multimodal-instruct
3
+ ========================================================================
4
+
5
+ Requirements
6
+ ++++++++++++
7
+
8
+ ::
9
+
10
+ git+https://github.com/sdpython/experimental-experiment.git # optional
11
+ backoff
12
+ huggingface_hub
13
+ onnx-diagnostic>=0.8.6
14
+ onnxruntime>=1.23
15
+ peft==0.17.1
16
+ Pillow
17
+ requests
18
+ torch>=2.10 # weekly is better
19
+ tqdm
20
+ transformers==4.48.3
21
+
22
+ .. note::
23
+
24
+ ``flash_attn`` must be uninstalled before exporting if it was installed.
25
+
26
+ Examples
27
+ ++++++++
28
+
29
+ .. code-block:: bash
30
+
31
+ python -m onnx_diagnostic.ci_models.export_phi4_mm \
32
+ -m microsoft/Phi-4-multimodal-instruct --device cuda --dtype float16 \
33
+ --exporter custom --pretrained --second-input --part vision
34
+ """
35
+
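# Usage sketch (a separate script, not part of this module): the bash example in the
# docstring above maps directly onto a call of main(), defined at the end of this file.
# All keyword arguments below exist in main()'s signature; the remaining parameters
# keep their defaults (output_folder="dump_models", make_zip=False, ...).
from onnx_diagnostic.ci_models.export_phi4_mm import main

main(
    model_id="microsoft/Phi-4-multimodal-instruct",
    device="cuda",
    dtype="float16",
    exporter="custom",
    pretrained=True,
    second_input=True,
    part="vision",
)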
36
+ import os
37
+ import pprint
38
+ import sys
39
+ import textwrap
40
+ import time
41
+ from typing import Dict, List, Optional, Tuple, Union
42
+
43
+ from .ci_helpers import (
44
+ check_for_discrepancies_and_log_everything_into_a_json_file,
45
+ compute_expected_outputs,
46
+ get_parser,
47
+ get_torch_dtype_from_command_line_args,
48
+ simplify_model_id_for_a_filename,
49
+ zip_model_and_data_into_a_single_file,
50
+ )
51
+
52
+
53
+ def get_patches_transformers():
54
+ import re
55
+ from itertools import cycle
56
+ import torch
57
+ import transformers
58
+
59
+ class patched_PreTrainedModel(torch.nn.Module):
60
+ _PATCHES_ = ["get_expanded_tied_weights_keys"]
61
+ _PATCHED_CLASS_ = transformers.modeling_utils.PreTrainedModel
62
+
63
+ def get_expanded_tied_weights_keys(self, all_submodels: bool = False) -> dict:
64
+ if all_submodels:
65
+ expanded_tied_weights = {}
66
+ for prefix, submodule in self.named_modules(remove_duplicate=False):
67
+ if isinstance(submodule, transformers.modeling_utils.PreTrainedModel):
68
+ submodel_tied_weights = submodule.get_expanded_tied_weights_keys(
69
+ all_submodels=False
70
+ )
71
+ if prefix != "":
72
+ submodel_tied_weights = {
73
+ f"{prefix}.{k}": f"{prefix}.{v}"
74
+ for k, v in submodel_tied_weights.items()
75
+ }
76
+ expanded_tied_weights.update(submodel_tied_weights)
77
+ return expanded_tied_weights
78
+
79
+ tied_mapping = self._tied_weights_keys
80
+ if not self.config.tie_word_embeddings and not self.config.tie_encoder_decoder:
81
+ return {}
82
+ elif tied_mapping is None:
83
+ return {}
84
+ common_case_regex = re.compile(r"^[A-Za-z0-9_\.]+(weight)|(bias)$")
85
+ # PATCHED
86
+ if tied_mapping == ["lm_head.weight"]:
87
+ tied_mapping = {"lm_head.weight": "model.embed_tokens.weight"}
88
+ if all(
89
+ common_case_regex.match(k) for k in tied_mapping.keys() | tied_mapping.values()
90
+ ):
91
+ return tied_mapping.copy()
92
+
93
+ expanded_tied_weights = {}
94
+ all_param_names = {k for k, _ in self.named_parameters(remove_duplicate=False)} | {
95
+ k for k, _ in self.named_buffers(remove_duplicate=False)
96
+ }
97
+ for target_name, source_name in tied_mapping.items():
98
+ target_name = "^" + target_name
99
+ source_name = "^" + source_name
100
+
101
+ source_params = sorted(
102
+ filter(lambda x: re.search(source_name, x), all_param_names)
103
+ )
104
+ target_params = sorted(
105
+ filter(lambda x: re.search(target_name, x), all_param_names)
106
+ )
107
+ if (
108
+ not len(source_params) > 0
109
+ or not len(target_params) > 0
110
+ or len(target_params) % len(source_params) != 0
111
+ ):
112
+ raise ValueError(
113
+ f"There is an issue with your definition of "
114
+ f"`tie_weights_keys` for {source_name}:{target_name}. "
115
+ f"We found {source_params} to tie into {target_params}"
116
+ )
117
+ for target_n, source_n in zip(target_params, cycle(source_params)):
118
+ if source_n in expanded_tied_weights.keys():
119
+ expanded_tied_weights[target_n] = expanded_tied_weights[source_n]
120
+ else:
121
+ expanded_tied_weights[target_n] = source_n
122
+
123
+ return expanded_tied_weights
124
+
125
+ return [patched_PreTrainedModel]
126
+
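# A minimal sketch of the pairing logic used by the patched
# get_expanded_tied_weights_keys above, in pure Python (no transformers needed):
# every matching target parameter is tied to a source parameter, cycling over the
# sources when there are more targets than sources. The parameter names are made up.
import re
from itertools import cycle

all_param_names = {
    "lm_head.weight",
    "model.embed_tokens.weight",
    "decoder.block.0.weight",
    "decoder.block.1.weight",
}
tied_mapping = {"decoder.block": "model.embed_tokens"}

expanded = {}
for target_name, source_name in tied_mapping.items():
    sources = sorted(p for p in all_param_names if re.search("^" + source_name, p))
    targets = sorted(p for p in all_param_names if re.search("^" + target_name, p))
    for target, source in zip(targets, cycle(sources)):
        # follow an already-expanded source, as the patched method does
        expanded[target] = expanded.get(source, source)

print(expanded)
# {'decoder.block.0.weight': 'model.embed_tokens.weight',
#  'decoder.block.1.weight': 'model.embed_tokens.weight'}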
127
+
128
+ def get_patches(mod, mod_siglip):
129
+ import torch
130
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
131
+ from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask
132
+ from ..export.cf_simple_loop_for import simple_loop_for
133
+
134
+ _IMAGE_SPECIAL_TOKEN_ID = mod._IMAGE_SPECIAL_TOKEN_ID
135
+
136
+ class patched_SiglipVisionEmbeddings(torch.nn.Module):
137
+ _PATCHES_ = ["forward"]
138
+ _PATCHED_CLASS_ = mod_siglip.SiglipVisionEmbeddings
139
+
140
+ def forward(
141
+ self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor
142
+ ) -> torch.Tensor:
143
+ batch_size = pixel_values.size(0)
144
+
145
+ patch_embeds = self.patch_embedding(pixel_values)
146
+ embeddings = patch_embeds.flatten(2).transpose(1, 2)
147
+
148
+ max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3)
149
+ max_nb_patches_h, max_nb_patches_w = (
150
+ max_im_h // self.patch_size,
151
+ max_im_w // self.patch_size,
152
+ )
153
+ boundaries = torch.arange(
154
+ torch.tensor(1 / self.num_patches_per_side, dtype=pixel_values.dtype),
155
+ torch.tensor(1.0, dtype=pixel_values.dtype),
156
+ torch.tensor(1 / self.num_patches_per_side, dtype=pixel_values.dtype),
157
+ )
158
+ position_ids = torch.full(
159
+ size=(batch_size, max_nb_patches_h * max_nb_patches_w), fill_value=0
160
+ )
161
+
162
+ # PATCHED: a loop replaced with a scan (see the bucketize sketch after this class).
163
+
164
+ def body(p_attn_mask, position_ids_row, boundaries):
165
+ h_len = torch.tensor(1, dtype=boundaries.dtype) / p_attn_mask[:, 0].sum()
166
+ w_len = torch.tensor(1, dtype=boundaries.dtype) / p_attn_mask[0].sum()
167
+ torch._check(h_len.item() > 0)
168
+ fractional_coords_h = torch.arange(
169
+ torch.tensor(0.0, dtype=boundaries.dtype),
170
+ torch.tensor(1 - 1e-6, dtype=boundaries.dtype),
171
+ h_len,
172
+ )
173
+ torch._check(w_len.item() > 0)
174
+ fractional_coords_w = torch.arange(
175
+ torch.tensor(0.0, dtype=boundaries.dtype),
176
+ torch.tensor(1 - 1e-6, dtype=boundaries.dtype),
177
+ w_len,
178
+ )
179
+
180
+ bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
181
+ bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
182
+
183
+ pos_ids = (
184
+ bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w
185
+ ).flatten()
186
+
187
+ row = position_ids_row.clone()
188
+ row[p_attn_mask.view(-1)] = pos_ids
189
+ return [row]
190
+
191
+ position_ids = torch.ops.higher_order.scan(
192
+ body, [], [patch_attention_mask, position_ids], additional_inputs=[boundaries]
193
+ )[0]
194
+
195
+ position_ids = position_ids.to(self.position_embedding.weight.device)
196
+ embeddings = embeddings + self.position_embedding(position_ids)
197
+ return embeddings
198
+
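# Sketch of the bucketing trick used inside the scan body above: fractional patch
# coordinates are mapped to a fixed grid of position ids with torch.bucketize.
# num_patches_per_side=4 and a row of 3 useful patches are arbitrary illustration values.
import torch

num_patches_per_side = 4
boundaries = torch.arange(1 / num_patches_per_side, 1.0, 1 / num_patches_per_side)
fractional_coords = torch.arange(0.0, 1 - 1e-6, 1 / 3)  # 3 useful patches in this row
bucket_coords = torch.bucketize(fractional_coords, boundaries, right=True)
print(bucket_coords)  # tensor([0, 1, 2]) -> indices into the 4x4 position grid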
199
+ class patched_SiglipVisionTransformer(torch.nn.Module):
200
+ _PATCHES_ = ["forward"]
201
+ _PATCHED_CLASS_ = mod_siglip.SiglipVisionTransformer
202
+
203
+ def forward(
204
+ self,
205
+ pixel_values,
206
+ patch_attention_mask: Optional[torch.BoolTensor] = None,
207
+ output_attentions: Optional[bool] = None,
208
+ output_hidden_states: Optional[bool] = None,
209
+ return_dict: Optional[bool] = None,
210
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
211
+ output_attentions = (
212
+ output_attentions
213
+ if output_attentions is not None
214
+ else self.config.output_attentions
215
+ )
216
+ output_hidden_states = (
217
+ output_hidden_states
218
+ if output_hidden_states is not None
219
+ else self.config.output_hidden_states
220
+ )
221
+ return_dict = (
222
+ return_dict if return_dict is not None else self.config.use_return_dict
223
+ )
224
+
225
+ batch_size = pixel_values.size(0)
226
+ if patch_attention_mask is None:
227
+ patch_attention_mask = torch.ones(
228
+ size=(
229
+ batch_size,
230
+ pixel_values.size(2) // self.config.patch_size,
231
+ pixel_values.size(3) // self.config.patch_size,
232
+ ),
233
+ dtype=torch.bool,
234
+ device=pixel_values.device,
235
+ )
236
+
237
+ hidden_states = self.embeddings(
238
+ pixel_values=pixel_values, patch_attention_mask=patch_attention_mask
239
+ )
240
+
241
+ patch_attention_mask = patch_attention_mask.view(batch_size, -1)
242
+ # PATCHED: skip the test
243
+ # if not torch.any(~patch_attention_mask):
244
+ # attention_mask = None
245
+ # else:
246
+ # attention_mask = (
247
+ # _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype)
248
+ ## if not self.config._flash_attn_2_enabled
249
+ # else patch_attention_mask
250
+ # )
251
+ attention_mask = _prepare_4d_attention_mask(
252
+ patch_attention_mask, hidden_states.dtype
253
+ )
254
+
255
+ encoder_outputs = self.encoder(
256
+ inputs_embeds=hidden_states,
257
+ attention_mask=attention_mask,
258
+ output_attentions=output_attentions,
259
+ output_hidden_states=output_hidden_states,
260
+ return_dict=return_dict,
261
+ )
262
+
263
+ last_hidden_state = encoder_outputs[0]
264
+ last_hidden_state = self.post_layernorm(last_hidden_state)
265
+
266
+ pooled_output = self.head(
267
+ hidden_state=last_hidden_state,
268
+ attention_mask=patch_attention_mask,
269
+ )
270
+
271
+ if not return_dict:
272
+ return (last_hidden_state, pooled_output, *encoder_outputs[1:])
273
+
274
+ return BaseModelOutputWithPooling(
275
+ last_hidden_state=last_hidden_state,
276
+ pooler_output=pooled_output,
277
+ hidden_states=encoder_outputs.hidden_states,
278
+ attentions=encoder_outputs.attentions,
279
+ )
280
+
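# Illustration of the 4D mask the patched forward always builds (instead of
# conditionally passing None): a (batch, seq) boolean mask is expanded to
# (batch, 1, seq, seq) with 0 where attention is allowed and dtype-min elsewhere.
# This mirrors what transformers' _prepare_4d_attention_mask computes, on a toy input.
import torch

patch_attention_mask = torch.tensor([[True, True, False]])
dtype = torch.float32
expanded = patch_attention_mask[:, None, None, :].expand(1, 1, 3, 3)
mask_4d = torch.where(expanded, 0.0, torch.finfo(dtype).min)
print(mask_4d.shape)  # torch.Size([1, 1, 3, 3])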
281
+ class patched_Phi4MMImageEmbedding(torch.nn.Module):
282
+ _PATCHES_ = ["forward"]
283
+ _PATCHED_CLASS_ = mod.Phi4MMImageEmbedding
284
+
285
+ def forward(
286
+ self,
287
+ input_ids: torch.LongTensor,
288
+ input_embeds: torch.FloatTensor,
289
+ image_sizes=None,
290
+ **kwargs,
291
+ ) -> torch.FloatTensor:
292
+
293
+ if isinstance(input_ids, tuple):
294
+ input_ids, input_embeds = input_ids
295
+
296
+ img_embeds = input_embeds
297
+ if image_sizes is None and "image_sizes" in kwargs:
298
+ image_sizes = kwargs["image_sizes"]
299
+ img_sizes = image_sizes
300
+
301
+ if self.img_features is not None:
302
+ img_embeds = self.img_features.clone()
303
+ self.img_features = None
304
+
305
+ if self.img_sizes is not None:
306
+ img_sizes = self.img_sizes
307
+
308
+ dtype = self.img_processor.embeddings.patch_embedding.weight.dtype
309
+ if img_embeds is not None:
310
+ img_embeds = img_embeds.to(dtype)
311
+
312
+ if self.image_attention_mask is not None:
313
+ image_attention_mask = self.image_attention_mask.clone()
314
+ self.image_attention_mask = None
315
+ elif "image_attention_mask" in kwargs:
316
+ image_attention_mask = kwargs["image_attention_mask"]
317
+ else:
318
+ image_attention_mask = None
319
+ input_shape = input_ids.size()
320
+ input_ids = input_ids.view(-1, input_shape[-1])
321
+
322
+ with torch.no_grad():
323
+ positions_tuple = torch.nonzero(
324
+ input_ids == _IMAGE_SPECIAL_TOKEN_ID, as_tuple=True
325
+ )
326
+
327
+ select = False
328
+ hd_transform = False
329
+
330
+ if isinstance(self.img_projection, torch.nn.Sequential):
331
+ target_device = self.img_projection[0].bias.device
332
+ else:
333
+ target_device = self.img_projection.bias.device
334
+
335
+ # PATCHED: Let's assume it is always true.
336
+ if True: # len(positions.tolist()) > 0:
337
+ if self.use_hd_transform and img_sizes is not None:
338
+ hd_transform = True
339
+ bs = img_embeds.shape[0]
340
+ if image_attention_mask is not None:
341
+ img_features = self.get_img_features(
342
+ img_embeds.flatten(0, 1),
343
+ attention_mask=image_attention_mask.type(torch.BoolTensor)
344
+ .flatten(0, 1)
345
+ .to(target_device),
346
+ )
347
+ else:
348
+ img_features = self.get_img_features(img_embeds.flatten(0, 1))
349
+
350
+ base_resolution = self.crop_size
351
+ base_feat_height_reduction = self.base_feat_height_reduction
352
+
353
+ base_feat_height = base_feat_width = torch.sym_int(
354
+ img_features.shape[1] ** 0.5
355
+ )
356
+ img_features = img_features.view(
357
+ bs, -1, base_feat_height * base_feat_width, self.image_dim_out
358
+ )
359
+ C = self.image_dim_out
360
+ H = base_feat_height
361
+
362
+ if isinstance(img_sizes, torch.Tensor):
363
+ img_sizes = img_sizes.view(-1, 2)
364
+ else:
365
+ raise NotImplementedError
366
+ select = True
367
+
368
+ hidden_states = kwargs["wte"](input_ids)
369
+
370
+ assert select
371
+ if hd_transform:
372
+
373
+ def body_fn(
374
+ _bs,
375
+ img_features,
376
+ img_sizes,
377
+ image_attention_mask,
378
+ cst_shape_CH,
379
+ glb_GN,
380
+ sub_GN,
381
+ proj_0_weight,
382
+ proj_0_bias,
383
+ proj_1_weight,
384
+ proj_1_bias,
385
+ base_resolution=None,
386
+ base_feat_height_reduction=None,
387
+ base_feat_height=None,
388
+ base_feat_width=None,
389
+ ):
390
+ # oddly, it seems impossible to write img_sizes[_bs.item()]
391
+ # it needs img_sizes[_bs.item() : (_bs + 1).item()][0]
392
+ row = img_sizes[_bs.item() : (_bs + 1).item()]
393
+ row = row[0]
394
+ h, w = row[0], row[1]
395
+ h = h // base_resolution
396
+ w = w // base_resolution
397
+ B_ = h * w
398
+ C, H = cst_shape_CH.shape
399
+
400
+ # 1 x (24x24) x 1024
401
+ global_img_feature = img_features[_bs.item() : (_bs + 1).item(), :1][0]
402
+
403
+ # 1 x 12 x 12 x 4096
404
+ glb_img = (
405
+ global_img_feature.reshape(1, H, H, C)
406
+ .reshape(
407
+ 1,
408
+ H // base_feat_height_reduction,
409
+ base_feat_height_reduction,
410
+ H // base_feat_height_reduction,
411
+ base_feat_height_reduction,
412
+ C,
413
+ )
414
+ .contiguous()
415
+ .permute(0, 1, 3, 2, 4, 5)
416
+ .reshape(
417
+ 1,
418
+ H // base_feat_height_reduction,
419
+ H // base_feat_height_reduction,
420
+ base_feat_height_reduction * base_feat_height_reduction * C,
421
+ )
422
+ .contiguous()
423
+ )
424
+ temp_glb_GN = sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1)
425
+
426
+ # 1 x 156 x 4096
427
+ glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape(
428
+ 1, -1, base_feat_height_reduction * base_feat_height_reduction * C
429
+ )
430
+
431
+ # (max_num_crops-1) x (12x12) x C
432
+ sub_img = img_features[_bs.item() : (_bs + 1).item(), 1:][0]
433
+ # 16x574x1024
434
+ # get rid of padding sub_img
435
+ sub_img = sub_img[: B_.item()]
436
+
437
+ # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024)
438
+ # -> (num_crops, 12*12, 4*1024)
439
+ sub_img = (
440
+ sub_img.reshape(B_.item(), H, H, C)
441
+ .reshape(
442
+ B_.item(),
443
+ H // base_feat_height_reduction,
444
+ base_feat_height_reduction,
445
+ H // base_feat_height_reduction,
446
+ base_feat_height_reduction,
447
+ C,
448
+ )
449
+ .contiguous()
450
+ .permute(0, 1, 3, 2, 4, 5)
451
+ .reshape(
452
+ B_.item(),
453
+ -1,
454
+ base_feat_height_reduction * base_feat_height_reduction * C,
455
+ )
456
+ .contiguous()
457
+ )
458
+ sub_img = (
459
+ sub_img.reshape(
460
+ 1,
461
+ h.item(),
462
+ w.item(),
463
+ base_feat_height // base_feat_height_reduction,
464
+ base_feat_width // base_feat_height_reduction,
465
+ -1,
466
+ )
467
+ .permute(0, 1, 3, 2, 4, 5)
468
+ .reshape(
469
+ 1,
470
+ (h * base_feat_height // base_feat_height_reduction).item(),
471
+ (w * base_feat_width // base_feat_height_reduction).item(),
472
+ base_feat_height_reduction * base_feat_height_reduction * C,
473
+ )
474
+ )
475
+
476
+ reshaped_image_attention_mask = (
477
+ image_attention_mask[
478
+ _bs.item() : (_bs + 1).item(), 1 : (B_ + 1).item(), 0::2, 0::2
479
+ ][0]
480
+ .reshape(
481
+ 1,
482
+ h.item(),
483
+ w.item(),
484
+ base_feat_height // base_feat_height_reduction,
485
+ base_feat_width // base_feat_height_reduction,
486
+ )
487
+ .permute(0, 1, 3, 2, 4)
488
+ .reshape(
489
+ 1,
490
+ (h * base_feat_height // base_feat_height_reduction).item(),
491
+ (w * base_feat_width // base_feat_height_reduction).item(),
492
+ )
493
+ )
494
+ useful_height = (
495
+ reshaped_image_attention_mask[0, :, 0].sum().to(torch.int64).item()
496
+ )
497
+ useful_width = (
498
+ reshaped_image_attention_mask[0, 0, :].sum().to(torch.int64).item()
499
+ )
500
+ # the module cannot be extracted from here
501
+ sub_img = sub_img[:, :useful_height, :useful_width]
502
+ temp_sub_GN = sub_GN.repeat(1, useful_height, 1, 1)
503
+ # temp_len = (
504
+ # image_attention_mask[_bs, : B_ + 1, 0::2, 0::2]
505
+ # .sum()
506
+ # .to(torch.int64)
507
+ # .item()
508
+ # + (useful_height + 1)
509
+ # + base_feat_height // base_feat_height_reduction
510
+ # )
511
+
512
+ sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape(
513
+ 1, -1, base_feat_height_reduction * base_feat_height_reduction * C
514
+ )
515
+ # (1, num_img_tokens, 1024*4)
516
+
517
+ # glb + sub
518
+ # glb_sub
519
+ # output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1))
520
+ # sub_glb
521
+ _output_img = torch.cat([sub_img, glb_GN, glb_img], dim=1)
522
+ # output_len.append(temp_len)
523
+ proj = torch.nn.functional.linear(_output_img, proj_0_weight, proj_0_bias)
524
+ proj = torch.nn.functional.gelu(proj)
525
+ proj = torch.nn.functional.linear(proj, proj_1_weight, proj_1_bias)
526
+ return (proj,)
527
+
528
+ def local_body_fn(
529
+ n_iter,
530
+ img_features,
531
+ img_sizes,
532
+ image_attention_mask,
533
+ cst_shape_CH,
534
+ glb_GN,
535
+ sub_GN,
536
+ proj_0_weight,
537
+ proj_0_bias,
538
+ proj_1_weight,
539
+ proj_1_bias,
540
+ ):
541
+ return body_fn(
542
+ n_iter,
543
+ img_features,
544
+ img_sizes,
545
+ image_attention_mask,
546
+ cst_shape_CH,
547
+ glb_GN,
548
+ sub_GN,
549
+ proj_0_weight,
550
+ proj_0_bias,
551
+ proj_1_weight,
552
+ proj_1_bias,
553
+ base_resolution=base_resolution,
554
+ base_feat_height_reduction=base_feat_height_reduction,
555
+ base_feat_height=base_feat_height,
556
+ base_feat_width=base_feat_width,
557
+ )
558
+
559
+ tmp = torch.arange(bs + 1).max()
560
+ glb_GN = self.glb_GN
561
+ sub_GN = self.sub_GN
562
+ cst_shape_CH = torch.zeros((C, H), dtype=torch.int32)
563
+
564
+ merged_img_set_tensor = simple_loop_for(
565
+ tmp,
566
+ local_body_fn,
567
+ (
568
+ img_features,
569
+ img_sizes,
570
+ image_attention_mask,
571
+ cst_shape_CH,
572
+ glb_GN,
573
+ sub_GN,
574
+ self.img_projection[0].weight,
575
+ self.img_projection[0].bias,
576
+ # self.img_projection[1] is GELU
577
+ self.img_projection[2].weight,
578
+ self.img_projection[2].bias,
579
+ ),
580
+ [1],
581
+ )
582
+ torch._check(isinstance(merged_img_set_tensor, torch.Tensor))
583
+ merged_img_set_tensor = merged_img_set_tensor.squeeze(0)
584
+
585
+ # merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0)
586
+ merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(
587
+ hidden_states.device
588
+ )
589
+ with torch.autocast(device_type=hidden_states.device.type, enabled=False):
590
+ new_hidden_states = hidden_states.index_put(
591
+ indices=positions_tuple,
592
+ values=merged_img_set_tensor,
593
+ accumulate=False,
594
+ )
595
+ hidden_states = new_hidden_states
596
+ else:
597
+ raise NotImplementedError
598
+
599
+ if self.drop is not None:
600
+ hidden_states = self.drop(hidden_states)
601
+
602
+ return hidden_states
603
+
604
+ return [
605
+ *get_patches_transformers(),
606
+ patched_Phi4MMImageEmbedding,
607
+ patched_SiglipVisionEmbeddings,
608
+ patched_SiglipVisionTransformer,
609
+ ]
610
+
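# Sketch of how the patched image embedding writes image features into the text
# sequence: the positions of the image special token are located with
# torch.nonzero(as_tuple=True) and overwritten in a single index_put call.
# Token id 7 and the tensor sizes are arbitrary values chosen for illustration.
import torch

IMAGE_TOKEN_ID = 7
input_ids = torch.tensor([[1, 7, 7, 2]])          # batch of one sequence
hidden_states = torch.zeros(1, 4, 3)              # (batch, seq, hidden)
image_features = torch.ones(2, 3)                 # one row per image token

positions = torch.nonzero(input_ids == IMAGE_TOKEN_ID, as_tuple=True)
hidden_states = hidden_states.index_put(
    indices=positions, values=image_features, accumulate=False
)
print(hidden_states[0, 1])  # tensor([1., 1., 1.])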
611
+
612
+ def get_inputs_for_part(
613
+ model_id: str,
614
+ part: str,
615
+ torch_dtype: "torch.dtype", # noqa: F821
616
+ device: str,
617
+ second_input: bool,
618
+ ) -> Tuple[Dict[str, "torch.Tensor"], List[Dict[str, "torch.Tensor"]]]: # noqa: F821
619
+ if part == "vision":
620
+ import requests
621
+ from PIL import Image
622
+ from transformers import AutoProcessor
623
+
624
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
625
+ user_prompt = "<|user|>\n"
626
+ assistant_prompt = "<|assistant|>\n"
627
+ prompt_suffix = "<|end|>\n"
628
+ prompt = (
629
+ f"{user_prompt}<|image_1|>\n<|image_2|>\n"
630
+ f"What is shown in these four images?{prompt_suffix}{assistant_prompt}"
631
+ )
632
+
633
+ root = os.path.join(os.path.dirname(__file__), "..", "..", "_small_data")
634
+ # "https://www.ilankelman.org/stopsigns/australia.jpg"
635
+ url = os.path.join(root, "American_Flamingo_JG.jpg")
636
+ image_1 = (
637
+ Image.open(requests.get(url, stream=True).raw)
638
+ if url.startswith("https")
639
+ else Image.open(url)
640
+ )
641
+ # "https://wallpaper.dog/large/10809054.jpg"
642
+ url = os.path.join(root, "RedcrestedTuraco.jpg")
643
+ image_4 = (
644
+ Image.open(requests.get(url, stream=True).raw)
645
+ if url.startswith("https")
646
+ else Image.open(url)
647
+ )
648
+
649
+ images = [image_1, image_4]
650
+ inputs = processor(prompt, images=images, return_tensors="pt").to(device)
651
+ export_inputs = dict(
652
+ input_ids=inputs["input_ids"].to(device),
653
+ input_image_embeds=inputs["input_image_embeds"].to(torch_dtype).to(device),
654
+ image_attention_mask=inputs["image_attention_mask"].to(torch_dtype).to(device),
655
+ image_sizes=inputs["image_sizes"].to(device),
656
+ )
657
+ assert (
658
+ export_inputs["input_image_embeds"].shape[-2] >= 28
659
+ and export_inputs["input_image_embeds"].shape[-1] >= 28
660
+ ), (
661
+ f"required by the exported program but shape is "
662
+ f"{export_inputs['input_image_embeds'].shape}"
663
+ )
664
+
665
+ other_inputs = []
666
+ if second_input:
667
+ prompt = (
668
+ f"{user_prompt}<|image_1|>\n<|image_2|>\n<|image_3|>\n<|image_4|>\n"
669
+ f"What is shown in these four images?{prompt_suffix}{assistant_prompt}"
670
+ )
671
+ url = "https://img.freepik.com/free-photo/painting-mountain-lake-with-mountain-background_188544-9126.jpg?w=2000"
672
+ image_2 = Image.open(requests.get(url, stream=True).raw)
673
+ url = (
674
+ "https://th.bing.com/th/id/OIP.gCvQ1vmPVJmrq1nnzM3ZHQHaEo?rs=1&pid=ImgDetMain"
675
+ )
676
+ image_3 = Image.open(requests.get(url, stream=True).raw)
677
+
678
+ images = [image_1, image_2, image_3, image_4]
679
+ inputs = processor(prompt, images=images, return_tensors="pt").to(device)
680
+ other_inputs = [
681
+ dict(
682
+ input_ids=inputs["input_ids"].to(device),
683
+ input_image_embeds=inputs["input_image_embeds"].to(torch_dtype).to(device),
684
+ image_attention_mask=inputs["image_attention_mask"]
685
+ .to(torch_dtype)
686
+ .to(device),
687
+ image_sizes=inputs["image_sizes"].to(device),
688
+ )
689
+ ]
690
+ return export_inputs, other_inputs
691
+
692
+ raise NotImplementedError(f"No inputs yet implement for part={part!r}")
693
+
694
+
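# Sketch of how string dynamic shapes like the ones built in main() below can be
# turned into torch.export.Dim objects for torch.export.export; this is presumably
# similar to what use_dyn_not_str does. Toy module and shapes, for illustration only.
import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids):
        return input_ids * 2

dynamic_shapes = {"input_ids": {1: "seq_length"}}
as_dims = {
    name: {axis: torch.export.Dim(dim_name) for axis, dim_name in axes.items()}
    for name, axes in dynamic_shapes.items()
}
ep = torch.export.export(
    Toy(),
    (),
    kwargs={"input_ids": torch.arange(8).reshape(1, 8)},
    dynamic_shapes=as_dims,
)
print(ep.graph_signature)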
695
+ def main(
696
+ model_id: str = "microsoft/Phi-4-multimodal-instruct",
697
+ device: str = "cpu",
698
+ dtype: str = "float32",
699
+ exporter: str = "onnx-dynamo",
700
+ pretrained: bool = True,
701
+ second_input: bool = True,
702
+ make_zip: bool = False,
703
+ output_folder: str = "dump_models",
704
+ existing_onnx: str | None = None,
705
+ part: str = "vision",
706
+ atol: float = 2,
707
+ mismatch01: float = 0.01,
708
+ profile_exporter: bool = False,
709
+ ):
710
+ """
711
+ Exports model microsoft/Phi-4-multimodal-instruct or pieces of it.
712
+ The script applies as well to other models based on the same architecture.
713
+
714
+ The function saves everything on disk. It does not generate new inputs
715
+ on the second run but reuses the saved ones. Same goes for the expected
716
+ outputs which are also saved on disk.
717
+
718
+ :param model_id: model id
719
+ :param device: device
720
+ :param dtype: dtype
721
+ :param exporter: exporter to use
722
+ :param pretrained: pretrained=False is usually used to test
723
+ :param second_input: checks discrepancies on more examples
724
+ :param make_zip: creates a zip at the end
725
+ :param output_folder: output folder
726
+ :param part: "" to export the whole model, ``"vision"`` for vision part,
727
+ ...
728
+ :param atol: raises an exception if tolerance is above that threshold
729
+ :param mismatch01: raises an exception if the ratio of mismatches
730
+ is above that threshold
731
+ :param profile_exporter: profiles the exporter
732
+ """
733
+ prefix = simplify_model_id_for_a_filename(model_id)
734
+ basename = os.path.join(
735
+ output_folder, f"model.{prefix}.{part}.{device}.{dtype}.{exporter}"
736
+ )
737
+ filename = f"{basename}.onnx"
738
+ stat_file = f"{basename}.stats"
739
+
740
+ print("------------------------------------------------------------------")
741
+ print(f"-- model_id={model_id}")
742
+ print(f"-- part={part}")
743
+ print(f"-- device={device}")
744
+ print(f"-- dtype={dtype}")
745
+ print(f"-- exporter={exporter}")
746
+ print(f"-- pretrained={pretrained}")
747
+ print(f"-- second_input={second_input}")
748
+ print(f"-- make_zip={make_zip}")
749
+ print(f"-- output_folder={output_folder}")
750
+ print(f"-- atol={atol}")
751
+ print(f"-- mismatch01={mismatch01}")
752
+ print(f"-- profile_exporter={profile_exporter}")
753
+ print("------------------------------------------------------------------")
754
+ print(f"-- prefix={prefix}")
755
+ print(f"-- export in {filename!r}")
756
+ print("------------------------------------------------------------------")
757
+
758
+ if os.path.exists(stat_file) and not existing_onnx:
759
+ print(f"-- skipping because {stat_file!r} already exists")
760
+ return
761
+
762
+ print("-- import torch and others")
763
+ import torch
764
+ from transformers import AutoConfig, AutoModelForCausalLM
765
+ from ..helpers import string_type, string_diff, max_diff
766
+ from ..torch_export_patches import torch_export_patches
767
+ from ..torch_export_patches.patch_details import PatchDetails
768
+ from ..torch_export_patches.patch_inputs import use_dyn_not_str
769
+ from ..export.api import to_onnx
770
+
771
+ if output_folder and output_folder != ".":
772
+ os.makedirs(output_folder, exist_ok=True)
773
+
774
+ print(f"-- create model {model_id!r}")
775
+ print(
776
+ f"-- device={device!r}, dtype={dtype!r}, exporter={exporter!r}, "
777
+ f"pretrained={pretrained!r}"
778
+ )
779
+ torch_dtype = get_torch_dtype_from_command_line_args(dtype)
780
+
781
+ if pretrained:
782
+ print("-- pretrained model")
783
+ config = AutoConfig.from_pretrained(
784
+ model_id, trust_remote_code=True, attn_implementation="sdpa"
785
+ )
786
+ model = AutoModelForCausalLM.from_pretrained(
787
+ model_id,
788
+ config=config,
789
+ trust_remote_code=True,
790
+ torch_dtype=torch_dtype,
791
+ device_map=device,
792
+ attn_implementation="sdpa",
793
+ ).eval()
794
+ data = dict(model=model)
795
+ else:
796
+ print("-- random model")
797
+ config = AutoConfig.from_pretrained(
798
+ model_id, trust_remote_code=True, attn_implementation="sdpa"
799
+ )
800
+ config.attn_implementation = "sdpa"
801
+ config._attn_implementation = "sdpa"
802
+ config.num_hidden_layers = 2
803
+ model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
804
+ data = dict(model=model)
805
+
806
+ main_mod_name = model.__module__
807
+ assert (
808
+ main_mod_name in sys.modules
809
+ ), f"Unable to find {main_mod_name!r} in {pprint.pformat(list(sys.modules))}"
810
+ main_mod = sys.modules[main_mod_name]
811
+ model = model.to(device).to(getattr(torch, dtype))
812
+ mod_siglip_name = model.model.embed_tokens_extend.image_embed.img_processor.__module__
813
+ assert (
814
+ mod_siglip_name in sys.modules
815
+ ), f"Unable to find {mod_siglip_name!r} in {pprint.pformat(list(sys.modules))}"
816
+ mod_siglip = sys.modules[mod_siglip_name]
817
+
818
+ print(f"-- config._attn_implementation={model.config._attn_implementation}")
819
+ print(f"-- model.dtype={model.dtype}")
820
+ print(f"-- model.device={model.device}")
821
+
822
+ export_inputs, other_inputs = None, None
823
+ if not part:
824
+ # used to unit test
825
+ from ..helpers.torch_helper import to_any
826
+
827
+ assert "inputs" in data, f"key 'inputs' is missing from data (available {set(data)})"
828
+ model_to_export = data["model"]
829
+ model_to_export.eval()
830
+ export_inputs = to_any(to_any(data["inputs"], device), torch_dtype)
831
+ other_inputs = [
832
+ v for k, v in data.items() if k.startswith("inputs_") if k != "inputs_prompt"
833
+ ]
834
+ dynamic_shapes = data["dynamic_shapes"]
835
+ assert other_inputs, f"No other inputs was found from data (available {set(data)})"
836
+
837
+ elif part == "vision":
838
+
839
+ class VisionPart(torch.nn.Module):
840
+ def __init__(self, model):
841
+ super().__init__()
842
+ self.model = model
843
+
844
+ def forward(
845
+ self, input_ids, input_image_embeds, image_attention_mask, image_sizes
846
+ ):
847
+ torch._check(input_image_embeds.shape[-2] >= 28)
848
+ torch._check(input_image_embeds.shape[-1] >= 28)
849
+ return model.model.embed_tokens_extend.image_embed(
850
+ input_ids=input_ids,
851
+ input_embeds=input_image_embeds,
852
+ image_attention_mask=image_attention_mask,
853
+ image_sizes=image_sizes,
854
+ wte=model.model.embed_tokens,
855
+ )
856
+
857
+ model_to_export = VisionPart(model)
858
+ model_to_export.eval()
859
+
860
+ dynamic_shapes = {
861
+ "input_ids": {1: "seq_length"},
862
+ "input_image_embeds": {
863
+ 0: "num_images",
864
+ 1: "max_num_crops",
865
+ 3: "height",
866
+ 4: "width",
867
+ },
868
+ "image_attention_mask": {0: "num_images", 1: "max_num_crops"},
869
+ "image_sizes": {0: "num_images"},
870
+ }
871
+
872
+ else:
873
+ raise NotImplementedError(f"no export yet for part={part!r}")
874
+
875
+ print(f"-- part={part!r}")
876
+ print(f"-- model_to_export={type(model_to_export)}")
877
+ print(f"-- dynamic_shapes={dynamic_shapes}")
878
+ print("-- ############")
879
+ print("-- INPUT/OUTPUT")
880
+ print("-- ############")
881
+
882
+ input_filename = os.path.join(output_folder, f"inputs.{prefix}.{part}.{device}.{dtype}.pt")
883
+ if os.path.exists(input_filename):
884
+ print(f"-- restore inputs from {input_filename!r}")
885
+ data = torch.load(input_filename, weights_only=False)
886
+ export_inputs = data["export_inputs"]
887
+ other_inputs = data["other_inputs"]
888
+ dynamic_shapes = data["dynamic_shapes"]
889
+ elif export_inputs is not None:
890
+ data = dict(
891
+ export_inputs=export_inputs,
892
+ other_inputs=other_inputs,
893
+ dynamic_shapes=dynamic_shapes,
894
+ )
895
+ print(f"-- dump inputs into {input_filename!r}")
896
+ torch.save(data, input_filename)
897
+ else:
898
+ export_inputs, other_inputs = get_inputs_for_part(
899
+ model_id,
900
+ part,
901
+ torch_dtype,
902
+ device,
903
+ second_input,
904
+ )
905
+ data = dict(
906
+ export_inputs=export_inputs,
907
+ other_inputs=other_inputs,
908
+ dynamic_shapes=dynamic_shapes,
909
+ )
910
+ print(f"-- dump inputs into {input_filename!r}")
911
+ torch.save(data, input_filename)
912
+
913
+ print(f"-- export_inputs={string_type(export_inputs, with_shape=True, with_device=True)}")
914
+ print(f"-- other_inputs={string_type(other_inputs, with_shape=True, with_device=True)}")
915
+ print(f"-- dynamic_shapes={dynamic_shapes}")
916
+ output_filename = os.path.join(
917
+ output_folder, f"expected.{prefix}.visual.{device}.{dtype}.pt"
918
+ )
919
+
920
+ print("-- ##################")
921
+ print("-- # EXPECTED_OUTPUTS")
922
+ print("-- ##################")
923
+
924
+ export_expected, *_ = compute_expected_outputs(
925
+ output_filename, model_to_export, input_filename
926
+ )
927
+
928
+ if existing_onnx and os.path.exists(existing_onnx):
929
+ print("-- ######")
930
+ print(f"-- USING EXISTING ONNX {existing_onnx!r}")
931
+ print("-- ######")
932
+
933
+ exporter = existing_onnx
934
+ filename = existing_onnx
935
+ target_opset = None
936
+ else:
937
+ print("-- ######")
938
+ print("-- EXPORT")
939
+ print("-- ######")
940
+
941
+ additional_patches = get_patches(main_mod, mod_siglip)
942
+
943
+ begin = time.perf_counter()
944
+
945
+ target_opset = 22
946
+
947
+ details = PatchDetails()
948
+ with torch_export_patches(
949
+ patch_torch=True, # needed for DynamicDimConstraintPrinter
950
+ patch_sympy=False,
951
+ patch_transformers=True,
952
+ verbose=1,
953
+ stop_if_static=0,
954
+ profile=(f"{basename}.profile.html" if profile_exporter else None),
955
+ custom_patches=additional_patches,
956
+ patch_details=details,
957
+ ):
958
+ # let's check again that the patched code runs
959
+ patched_expected = model_to_export(**export_inputs)
960
+ diff = max_diff(export_expected, patched_expected, hist=[0.1, 0.01])
961
+ print(f"-- discrepancies PATCHED/ORIGINAL {string_diff(diff)}")
962
+ assert diff["abs"] < atol, (
963
+ f"Patches do not output the same values\n"
964
+ f"\nexpected={string_type(export_expected, with_shape=True)}"
965
+ f"\n patched={string_type(patched_expected, with_shape=True)}"
966
+ f"\ndiff={string_diff(diff)}"
967
+ )
968
+ if details and not os.path.exists(f"{basename}.patches_details.rst"):
969
+ print("-- builds patch details")
970
+ ep = torch.export.export(
971
+ model_to_export,
972
+ (),
973
+ kwargs=export_inputs,
974
+ dynamic_shapes=use_dyn_not_str(dynamic_shapes),
975
+ )
976
+ patches = details.patches_involded_in_graph(ep.graph)
977
+ report = details.make_report(patches, format="rst")
978
+ with open(f"{basename}.patches_details.rst", "w") as f:
979
+ f.write(report)
980
+ with open(f"{basename}.ep", "w") as f:
981
+ f.write(str(ep))
982
+ with open(f"{basename}.graph", "w") as f:
983
+ f.write(str(ep.graph))
984
+ print("-- done writing patch details")
985
+
986
+ to_onnx(
987
+ model_to_export,
988
+ kwargs=export_inputs,
989
+ dynamic_shapes=dynamic_shapes,
990
+ filename=filename,
991
+ exporter=exporter,
992
+ verbose=1,
993
+ save_ep=None,
994
+ target_opset=target_opset,
995
+ optimize=True,
996
+ )
997
+ export_duration = time.perf_counter() - begin
998
+
999
+ print("-- ###############")
1000
+ print("-- # DISCREPANCIES")
1001
+ print("-- ###############")
1002
+
1003
+ info = {
1004
+ "model_id": model_id,
1005
+ "part": part,
1006
+ "device": device,
1007
+ "dtype": dtype,
1008
+ "exporter": exporter,
1009
+ "pretrained": pretrained,
1010
+ "attention": os.environ.get("QWEN25ATTENTION", "default"),
1011
+ }
1012
+
1013
+ check_for_discrepancies_and_log_everything_into_a_json_file(
1014
+ agg_stat_file=os.path.join(output_folder, "collection_statistics.js"),
1015
+ stat_file=stat_file,
1016
+ export_duration=export_duration,
1017
+ device=device,
1018
+ model_file=filename,
1019
+ cached_inputs=input_filename,
1020
+ cached_expected_outputs=output_filename,
1021
+ main_info=info,
1022
+ atol=atol,
1023
+ mismatch01=mismatch01,
1024
+ )
1025
+
1026
+ if make_zip:
1027
+ print("-- #####")
1028
+ print("-- # ZIP")
1029
+ print("-- #####")
1030
+ zip_model_and_data_into_a_single_file(f"{basename}.zip", filename)
1031
+
1032
+
1033
+ if __name__ == "__main__":
1034
+ parser = get_parser(
1035
+ "qwen25",
1036
+ epilog=textwrap.dedent(
1037
+ r"""
1038
+ Tested command lines::
1039
+
1040
+ python -m onnx_diagnostic.ci_models.export_phi4_mm \
1041
+ -m microsoft/Phi-4-multimodal-instruct \
1042
+ --device cuda --dtype float16 --exporter custom \
1043
+ --pretrained --second-input --part vision
1044
+ """
1045
+ ),
1046
+ )
1047
+ args = parser.parse_args(sys.argv[1:])
1048
+ main(
1049
+ model_id=args.mid,
1050
+ device=args.device,
1051
+ dtype=args.dtype,
1052
+ exporter=args.exporter,
1053
+ pretrained=args.pretrained,
1054
+ second_input=args.second_input,
1055
+ make_zip=args.zip,
1056
+ output_folder=args.output_folder,
1057
+ existing_onnx=args.existing_onnx,
1058
+ part=args.part,
1059
+ atol=args.atol,
1060
+ mismatch01=args.mismatch01,
1061
+ profile_exporter=args.profile_exporter,
1062
+ )