onnx_diagnostic-0.8.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
onnx_diagnostic/tasks/image_text_to_text.py
@@ -0,0 +1,581 @@
+import itertools
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.cache_helper import make_dynamic_cache, make_hybrid_cache
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    _pick,
+    default_num_hidden_layers as nhl,
+)
+from .data import get_data
+
+__TASK__ = "image-text-to-text"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    kwargs: Dict[str, Any] = {}
+    if (
+        hasattr(config, "architectures")
+        and config.architectures
+        and config.architectures[0] == "Gemma3ForConditionalGeneration"
+    ):
+        if hasattr(config, "vision_config"):
+            if hasattr(config.vision_config, "num_hidden_layers"):
+                config.vision_config.num_hidden_layers = min(
+                    config.vision_config.num_hidden_layers, nhl()
+                )
+        if hasattr(config, "text_config"):
+            if hasattr(config.text_config, "intermediate_size"):
+                config.text_config.intermediate_size = min(
+                    config.text_config.intermediate_size, 10240 // 10 * 5 // 2
+                )
+                config.text_config.hidden_size = min(
+                    config.text_config.hidden_size, 2560 // 10 * 5 // 2
+                )
+        update_config(config, kwargs)
+        return kwargs
+
+    if hasattr(config, "num_hidden_layers"):
+        config.num_hidden_layers = min(config.num_hidden_layers, nhl())
+    if hasattr(config, "mm_tokens_per_image"):
+        config.mm_tokens_per_image = min(config.mm_tokens_per_image, 2)
+    if hasattr(config, "vision_config"):
+        if hasattr(config.vision_config, "num_hidden_layers"):
+            config.vision_config.num_hidden_layers = min(
+                config.vision_config.num_hidden_layers, 2
+            )
+        if hasattr(config.vision_config, "num_heads"):
+            config.vision_config.num_heads = min(config.vision_config.num_heads, 4)
+        if hasattr(config.vision_config, "image_size"):
+            config.vision_config.image_size = min(config.vision_config.image_size, 168 // 2)
+        if hasattr(config.vision_config, "intermediate_size"):
+            config.vision_config.intermediate_size = min(
+                config.vision_config.intermediate_size, 1076
+            )
+        if hasattr(config.vision_config, "patch_size"):
+            config.vision_config.patch_size = min(config.vision_config.patch_size, 1)
+        if hasattr(config.vision_config, "temporal_patch_size"):
+            config.vision_config.temporal_patch_size = min(
+                config.vision_config.temporal_patch_size, 8
+            )
+        if hasattr(config.vision_config, "hidden_size"):
+            config.vision_config.hidden_size = min(config.vision_config.hidden_size, 16)
+    if hasattr(config, "text_config"):
+        if hasattr(config.text_config, "intermediate_size"):
+            config.text_config.intermediate_size = min(
+                config.text_config.intermediate_size, 320
+            )
+        if hasattr(config.text_config, "hidden_size"):
+            config.text_config.hidden_size = min(config.text_config.hidden_size, 16)
+        if hasattr(config.text_config, "num_hidden_layers"):
+            config.text_config.num_hidden_layers = min(config.text_config.num_hidden_layers, 2)
+        if hasattr(config.text_config, "layer_types"):
+            config.text_config.layer_types = config.text_config.layer_types[
+                : config.text_config.num_hidden_layers
+            ]
+        if hasattr(config.text_config, "num_attention_heads"):
+            config.text_config.num_attention_heads = min(
+                config.text_config.num_attention_heads, 2
+            )
+    update_config(config, kwargs)
+    return kwargs
+
+
+def _get_inputs_gemma3(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = 1,
+    sequence_length: Optional[int] = 281,
+    n_images: Optional[int] = 1,
+    max_sequence_length: Optional[int] = 580,
+    total_sequence_length: Optional[int] = 860,
+    **kwargs,  # unused
+):
+    """
+    The function uses predefined values for input_ids and token_type_ids.
+
+    **google/gemma-3-4b-it**
+
+    iteration 1
+
+    ::
+
+        cache_position:T7s281,
+        input_ids:T7s1x281,
+        token_type_ids:T7s1x281,
+        attention_mask:dict(sliding_attention:T9s1x1x281x580,
+                            full_attention:T9s1x1x281x580),
+        pixel_values:T16s1x3x896x896,
+
+    iteration 2
+
+    ::
+
+        cache_position:T7s1,
+        past_key_values:StaticCache(key_cache=#34[T1s1x4x580x256,...],
+                                    value_cache=#34[T1s1x4x580x256,...]),
+        input_ids:T7s1x1,
+        inputs_embeds:None,
+        token_type_ids:T7s1x1,
+        attention_mask:dict(sliding_attention:T9s1x1x1x580,full_attention:T9s1x1x1x580),
+        position_ids:None,
+    """
+    batch_size = 1 if batch_size is None else batch_size
+    sequence_length = 281 if sequence_length is None else sequence_length
+    n_images = 1 if n_images is None else n_images
+    max_sequence_length = 580 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 860 if total_sequence_length is None else total_sequence_length
+
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    seq_length = "seq_length"
+    tot_length = "total_length"
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
+        "attention_mask": {
+            "full_attention": {0: batch, 2: seq_length, 3: tot_length},
+            "sliding_attention": {0: batch, 2: seq_length, 3: tot_length},
+        },
+        "position_ids": {0: batch, 1: seq_length},
+        "cache_position": {0: seq_length},
+        "past_key_values": [{0: batch} for _ in range(num_hidden_layers * 2)],
+        "pixel_values": {0: batch},
+        "use_cache": None,
+    }
+
+    # retrieve specific inputs to keep the consistency between
+    # ids and images
+    dummies = get_data("dummies_imagetext2text_generation_gemma3.onnx")
+    dummies = dummies[("", 0, "I")][1]
+    dummies = {k: v for k, v in dummies.items() if k in shapes}
+    expected = {"input_ids", "token_type_ids", "position_ids", "cache_position"}
+
+    def _check_():
+        assert expected & set(
+            dummies
+        ), f"Unable to find expected inputs {expected} in loaded inputs {set(dummies)}"
+        assert sequence_length == dummies["input_ids"].shape[-1], (
+            f"sequence_length={sequence_length} != {dummies['input_ids'].shape[-1]} for "
+            f"model class {model.__class__.__name__}"
+        )
+        assert batch_size == dummies["input_ids"].shape[0], (
+            f"batch_size={batch_size} != {dummies['input_ids'].shape[0]} for "
+            f"model class {model.__class__.__name__}"
+        )
+        assert max_sequence_length == 580, (
+            f"max_sequence_length={max_sequence_length} != 580 "
+            f"for model {model.__class__.__name__}"
+        )
+        assert total_sequence_length == 860, (
+            f"total_sequence_length={total_sequence_length} != 860 "
+            f"for model {model.__class__.__name__}"
+        )
+        assert head_dim in (
+            256,
+            32,
+        ), f"head_dim={head_dim} not in (32, 256) for model {model.__class__.__name__}"
+        assert n_images == 1, f"n_images={n_images} != 1 for model {model.__class__.__name__}"
+        assert num_key_value_heads in (1, 4), (
+            f"num_key_value_heads={num_key_value_heads} not in (1, 4) "
+            f"for this model {model.__class__.__name__}"
+        )
+
+    _check_()
+
+    inputs = dict(
+        input_ids=dummies["input_ids"],
+        token_type_ids=dummies["token_type_ids"],
+        attention_mask=dict(
+            full_attention=torch.randn(batch_size, 1, sequence_length, total_sequence_length),
+            sliding_attention=torch.randn(
+                batch_size, 1, sequence_length, total_sequence_length
+            ),
+        ),
+        position_ids=torch.arange(0, sequence_length).to(torch.int64).expand((batch_size, -1)),
+        cache_position=torch.arange(0, sequence_length).to(torch.int64),
+        past_key_values=make_hybrid_cache(
+            [
+                (
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
+                    torch.randn(
+                        batch_size, num_key_value_heads, max_sequence_length, head_dim
+                    ),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+        pixel_values=torch.randn(n_images, num_channels, width, height).clamp(-1, 1),
+        # image_attention_mask=torch.ones((batch_size, sequence_length2, n_images)).to(
+        #     torch.int64
+        # ),
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+    )
+    return dict(inputs=inputs, dynamic_shapes=shapes)
+
+
+def get_inputs_default(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = 2,
+    sequence_length: Optional[int] = 43,
+    n_images: Optional[int] = 2,
+    max_sequence_length: Optional[int] = 43,
+    total_sequence_length: Optional[int] = 43,
+    add_second_input: int = 0,
+    **kwargs,  # unused
+):
+    batch_size = 2 if batch_size is None else batch_size
+    sequence_length = 43 if sequence_length is None else sequence_length
+    n_images = 2 if n_images is None else n_images
+    max_sequence_length = 43 if max_sequence_length is None else max_sequence_length
+    total_sequence_length = 43 if total_sequence_length is None else total_sequence_length
+
+    assert batch_size > 0, "batch_size cannot be null"
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    batch = "batch"
+    batch_img = torch.export.Dim("batch_img", min=1, max=1024)
+    seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
+    cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)
+    images = "images"  # torch.export.Dim("images", min=1, max=4096)
+
+    shapes = {
+        "input_ids": {0: batch, 1: seq_length},
+        "token_type_ids": {0: batch, 1: seq_length},
+        "attention_mask": {0: batch, 1: "cache+seq"},
+        "position_ids": {0: batch, 1: seq_length},
+        "past_key_values": list(
+            itertools.chain.from_iterable(
+                zip(
+                    [{0: batch} for _ in range(num_hidden_layers)],
+                    [{0: batch, 2: cache_length} for _ in range(num_hidden_layers)],
+                )
+            )
+        ),
+        "pixel_values": (
+            {0: batch, 1: images}
+            if model.__class__.__name__ == "IdeficsForVisionText2Text"
+            else {0: batch_img}
+        ),
+        "image_attention_mask": {0: batch, 1: seq_length, 2: images},
+        "image_grid_thw": {0: batch},
+        "use_cache": None,
+    }
+
+    input_ids = torch.randint(0, dummy_max_token_id, (batch_size, total_sequence_length)).to(
+        torch.int64
+    )
+    if total_sequence_length > 0:
+        input_ids[0, 0] = image_token_index
+        if min(input_ids.shape) > 1:
+            input_ids[1, 1] = image_token_index
+    # input_ids[input_ids == image_token_index] = pad_token_id
+    token_type_ids = torch.zeros_like(input_ids)
+    token_type_ids[input_ids == image_token_index] = 1
+    image_grid_thw = torch.zeros((n_images, 3), dtype=torch.int64)
+    if n_images > 0:
+        image_grid_thw[:, 1] = height
+        image_grid_thw[:, 2] = width
+        image_grid_thw[0, :] //= 2
+        image_grid_thw[:, 0] = torch.arange(n_images, dtype=image_grid_thw.dtype)
+
+    inputs = dict(
+        input_ids=input_ids,
+        token_type_ids=token_type_ids,
+        attention_mask=torch.cat(
+            [
+                torch.ones((batch_size, sequence_length), dtype=torch.int64),
+                input_ids.ne(pad_token_id).to(torch.int64),
+            ],
+            axis=-1,
+        ),
+        position_ids=torch.arange(0, total_sequence_length)
+        .to(torch.int64)
+        .expand((batch_size, -1)),
+        past_key_values=make_dynamic_cache(
+            [
+                (
+                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                    torch.randn(batch_size, num_key_value_heads, sequence_length, head_dim),
+                )
+                for i in range(num_hidden_layers)
+            ]
+        ),
+        pixel_values=(
+            torch.randn((batch_size, n_images, num_channels, width, height)).clamp(-1, 1)
+            if model.__class__.__name__ == "IdeficsForVisionText2Text"
+            else torch.randn(n_images, num_channels, width, height).clamp(-1, 1)
+        ),
+        image_attention_mask=torch.ones((batch_size, total_sequence_length, n_images)).to(
+            torch.int64
+        ),
+        image_grid_thw=image_grid_thw,
+        use_cache=True,  # Gemma3 does not set this value to true when a cache is provided
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+    return res
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    dummy_max_token_id: int,
+    num_key_value_heads: int,
+    num_hidden_layers: int,
+    pad_token_id: int,
+    image_token_index: int,
+    head_dim: int,
+    width: int,
+    height: int,
+    num_channels: int,
+    batch_size: Optional[int] = None,
+    sequence_length: Optional[int] = None,
+    n_images: Optional[int] = None,
+    max_sequence_length: Optional[int] = None,
+    total_sequence_length: Optional[int] = None,
+    add_second_input: int = 0,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-text-to-text``.
+
+    :param model: model to get the missing information
+    :param config: configuration used to generate the model
+    :param head_dim: last dimension of the cache
+    :param dummy_max_token_id: dummy max token id
+    :param pad_token_id: pad_token_id
+    :param image_token_index: image_token_index
+    :param batch_size: batch size
+    :param sequence_length: sequence length
+    :param max_sequence_length: for the cache
+    :param total_sequence_length: for the mask
+    :param n_images: number of images
+    :param width: width of the image
+    :param height: height of the image
+    :param num_channels: number of channels
+    :return: dictionary
+
+    .. note::
+
+        The content of input_ids and its shape are correlated to the images.
+        The function uses predefined values and raises an exception
+        if the dimensions are not the expected ones.
+    """
+    if model.__class__.__name__.startswith("Gemma3"):
+        res = _get_inputs_gemma3(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
+        )
+    else:
+        res = get_inputs_default(
+            model,
+            config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=batch_size,
+            sequence_length=sequence_length,
+            max_sequence_length=max_sequence_length,
+            total_sequence_length=total_sequence_length,
+            n_images=n_images,
+            **kwargs,
+        )
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            dummy_max_token_id=dummy_max_token_id,
+            num_key_value_heads=num_key_value_heads,
+            num_hidden_layers=num_hidden_layers,
+            head_dim=head_dim,
+            width=width,
+            height=height,
+            num_channels=num_channels,
+            batch_size=3,
+            sequence_length=1,
+            max_sequence_length=1,
+            total_sequence_length=1,
+            n_images=0,
+            pad_token_id=pad_token_id,
+            image_token_index=image_token_index,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        if hasattr(config, "text_config"):
+            check_hasattr(
+                config.text_config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "pad_token_id",
+            )
+            check_hasattr(config, "vision_config", ("image_token_index", "image_token_id"))
+            text_config = True
+        else:
+            check_hasattr(
+                config,
+                "vocab_size",
+                "hidden_size",
+                "num_attention_heads",
+                ("num_key_value_heads", "num_attention_heads"),
+                "intermediate_size",
+                "hidden_size",
+                "vision_config",
+            )
+            text_config = False
+        check_hasattr(config.vision_config, ("num_channels", "in_chans", "in_channels"))
+    kwargs = dict(
+        head_dim=(
+            16
+            if config is None
+            else getattr(
+                config,
+                "head_dim",
+                (
+                    config.text_config.head_dim
+                    if text_config and hasattr(config.text_config, "head_dim")
+                    else (
+                        (config.text_config.hidden_size if text_config else config.hidden_size)
+                        // (
+                            config.text_config.num_attention_heads
+                            if text_config
+                            else config.num_attention_heads
+                        )
+                    )
+                ),
+            )
+        ),
+        dummy_max_token_id=(
+            31999
+            if config is None
+            else (config.text_config.vocab_size if text_config else config.vocab_size) - 1
+        ),
+        num_hidden_layers=(
+            4
+            if config is None
+            else (
+                config.text_config.num_hidden_layers
+                if text_config
+                else config.num_hidden_layers
+            )
+        ),
+        num_key_value_heads=(
+            8
+            if config is None
+            else (
+                _pick(config.text_config, "num_key_value_heads", "num_attention_heads")
+                if text_config
+                else _pick(config, "num_key_value_heads", "num_attention_heads")
+            )
+        ),
+        intermediate_size=(
+            1024
+            if config is None
+            else (
+                config.text_config.intermediate_size
+                if text_config
+                else config.intermediate_size
+            )
+        ),
+        hidden_size=(
+            512
+            if config is None
+            else (config.text_config.hidden_size if text_config else config.hidden_size)
+        ),
+        width=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        height=(
+            224
+            if config is None or not hasattr(config.vision_config, "image_size")
+            else config.vision_config.image_size
+        ),
+        num_channels=(
+            3
+            if config is None
+            else _pick(config.vision_config, "num_channels", "in_chans", "in_channels")
+        ),
+        pad_token_id=(
+            0
+            if config is None
+            or not hasattr(config, "text_config")
+            or not hasattr(config.text_config, "pad_token_id")
+            else config.text_config.pad_token_id
+        ),
+        image_token_index=(
+            4
+            if config is None
+            or (
+                not hasattr(config, "image_token_index")
+                and not hasattr(config, "image_token_id")
+            )
+            else _pick(config, "image_token_index", "image_token_id")
+        ),
+    )
+    return kwargs, get_inputs
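
The hunk above follows the contract shared by the task modules in onnx_diagnostic/tasks: random_input_kwargs(config) derives size kwargs from a checkpoint configuration and returns them together with the get_inputs callable, which builds dummy tensors plus matching dynamic_shapes. A minimal usage sketch, assuming the wheel is installed and some image-text-to-text module is available; the model variable and the torch.export call are illustrative assumptions, not taken from this diff:

# Sketch only (not part of the package). Assumes `model` is an
# image-text-to-text torch.nn.Module whose config carries the attributes
# checked by random_input_kwargs.
import torch
from onnx_diagnostic.tasks.image_text_to_text import random_input_kwargs

kwargs, fct = random_input_kwargs(model.config)  # sizes read from the config
res = fct(model, model.config, **kwargs)         # dummy inputs + dynamic shapes
ep = torch.export.export(
    model,
    (),
    kwargs=res["inputs"],
    dynamic_shapes=res["dynamic_shapes"],
)
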
onnx_diagnostic/tasks/image_to_video.py
@@ -0,0 +1,127 @@
+from typing import Any, Callable, Dict, Optional, Tuple
+import torch
+from ..helpers.config_helper import (
+    update_config,
+    check_hasattr,
+    default_num_hidden_layers as nhl,
+)
+
+__TASK__ = "image-to-video"
+
+
+def reduce_model_config(config: Any) -> Dict[str, Any]:
+    """Reduces a model size."""
+    if not hasattr(config, "num_hidden_layers") and not hasattr(config, "num_layers"):
+        # We cannot reduce.
+        return {}
+    check_hasattr(config, ("num_hidden_layers", "num_layers"))
+    kwargs = {}
+    if hasattr(config, "num_layers"):
+        kwargs["num_layers"] = min(config.num_layers, nhl())
+    if hasattr(config, "num_hidden_layers"):
+        kwargs["num_hidden_layers"] = min(config.num_hidden_layers, nhl())
+
+    update_config(config, kwargs)
+    return kwargs
+
+
+def get_inputs(
+    model: torch.nn.Module,
+    config: Optional[Any],
+    text_embed_dim: int,
+    latent_channels: int,
+    batch_size: int = 2,
+    image_height: int = 704,
+    image_width: int = 1280,
+    latent_frames: int = 1,
+    text_maxlen: int = 512,
+    add_second_input: int = 1,
+    **kwargs,  # unused
+):
+    """
+    Generates inputs for task ``image-to-video``.
+    """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
+    latent_height = image_height // 8
+    latent_width = image_width // 8
+    dtype = torch.float32
+
+    inputs = dict(
+        hidden_states=torch.randn(
+            batch_size,
+            latent_channels,
+            latent_frames,
+            latent_height,
+            latent_width,
+            dtype=dtype,
+        ),
+        timestep=torch.tensor([1.0] * batch_size, dtype=dtype),
+        encoder_hidden_states=torch.randn(
+            batch_size, text_maxlen, text_embed_dim, dtype=dtype
+        ),
+        padding_mask=torch.ones(1, 1, image_height, image_width, dtype=dtype),
+        fps=torch.tensor([16] * batch_size, dtype=dtype),
+        condition_mask=torch.randn(
+            batch_size, 1, latent_frames, latent_height, latent_width, dtype=dtype
+        ),
+    )
+    shapes = dict(
+        hidden_states={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+        timestep={0: "batch_size"},
+        encoder_hidden_states={0: "batch_size"},
+        padding_mask={0: "batch_size", 2: "height", 3: "width"},
+        fps={0: "batch_size"},
+        condition_mask={
+            0: "batch_size",
+            2: "latent_frames",
+            3: "latent_height",
+            4: "latent_width",
+        },
+    )
+    res = dict(inputs=inputs, dynamic_shapes=shapes)
+
+    if add_second_input:
+        assert (
+            add_second_input > 0
+        ), f"Not implemented for add_second_input={add_second_input}."
+        res["inputs2"] = get_inputs(
+            model=model,
+            config=config,
+            text_embed_dim=text_embed_dim,
+            latent_channels=latent_channels,
+            batch_size=batch_size,
+            image_height=image_height,
+            image_width=image_width,
+            latent_frames=latent_frames,
+            text_maxlen=text_maxlen,
+            add_second_input=0,
+            **kwargs,
+        )["inputs"]
+    return res
+
+
+def random_input_kwargs(config: Any) -> Tuple[Dict[str, Any], Callable]:
+    """
+    Inputs kwargs.
+
+    If the configuration is None, the function selects typical dimensions.
+    """
+    if config is not None:
+        check_hasattr(config, "in_channels", "text_embed_dim")
+    kwargs = dict(
+        text_embed_dim=1024 if config is None else config.text_embed_dim,
+        latent_channels=16 if config is None else config.in_channels - 1,
+        batch_size=1,
+        image_height=8 * 50,
+        image_width=8 * 80,
+        latent_frames=1,
+        text_maxlen=512,
+    )
+    return kwargs, get_inputs
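
The same contract applies to the image-to-video task. With config=None, random_input_kwargs falls back to typical dimensions, so the dummy inputs can be built without any model at all. A minimal sketch under that assumption; the printed values follow from the defaults (batch_size=1, latent_channels=16, image_height=8*50, image_width=8*80):

# Sketch only (not part of the package). get_inputs never uses `model`
# except to forward it to the recursive call, so None suffices for a
# quick shape check.
from onnx_diagnostic.tasks.image_to_video import random_input_kwargs

kwargs, fct = random_input_kwargs(None)
res = fct(model=None, config=None, **kwargs)
print(res["inputs"]["hidden_states"].shape)    # torch.Size([1, 16, 1, 50, 80])
print(res["dynamic_shapes"]["hidden_states"])  # {0: 'batch_size', 2: 'latent_frames', ...}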