onnx-diagnostic 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (132)
  1. onnx_diagnostic/__init__.py +7 -0
  2. onnx_diagnostic/__main__.py +4 -0
  3. onnx_diagnostic/_command_lines_parser.py +1141 -0
  4. onnx_diagnostic/api.py +15 -0
  5. onnx_diagnostic/doc.py +100 -0
  6. onnx_diagnostic/export/__init__.py +2 -0
  7. onnx_diagnostic/export/api.py +124 -0
  8. onnx_diagnostic/export/dynamic_shapes.py +1083 -0
  9. onnx_diagnostic/export/shape_helper.py +296 -0
  10. onnx_diagnostic/export/validate.py +173 -0
  11. onnx_diagnostic/ext_test_case.py +1290 -0
  12. onnx_diagnostic/helpers/__init__.py +1 -0
  13. onnx_diagnostic/helpers/_log_helper.py +463 -0
  14. onnx_diagnostic/helpers/args_helper.py +132 -0
  15. onnx_diagnostic/helpers/bench_run.py +450 -0
  16. onnx_diagnostic/helpers/cache_helper.py +687 -0
  17. onnx_diagnostic/helpers/config_helper.py +170 -0
  18. onnx_diagnostic/helpers/doc_helper.py +163 -0
  19. onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
  20. onnx_diagnostic/helpers/graph_helper.py +386 -0
  21. onnx_diagnostic/helpers/helper.py +1707 -0
  22. onnx_diagnostic/helpers/log_helper.py +2245 -0
  23. onnx_diagnostic/helpers/memory_peak.py +249 -0
  24. onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
  25. onnx_diagnostic/helpers/model_builder_helper.py +469 -0
  26. onnx_diagnostic/helpers/onnx_helper.py +1200 -0
  27. onnx_diagnostic/helpers/ort_session.py +736 -0
  28. onnx_diagnostic/helpers/rt_helper.py +476 -0
  29. onnx_diagnostic/helpers/torch_helper.py +987 -0
  30. onnx_diagnostic/reference/__init__.py +4 -0
  31. onnx_diagnostic/reference/evaluator.py +254 -0
  32. onnx_diagnostic/reference/ops/__init__.py +1 -0
  33. onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
  34. onnx_diagnostic/reference/ops/op_attention.py +60 -0
  35. onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
  36. onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
  37. onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
  38. onnx_diagnostic/reference/ops/op_complex.py +26 -0
  39. onnx_diagnostic/reference/ops/op_concat.py +15 -0
  40. onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
  41. onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
  42. onnx_diagnostic/reference/ops/op_gather.py +29 -0
  43. onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
  44. onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
  45. onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
  46. onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
  47. onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
  48. onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
  49. onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
  50. onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
  51. onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
  52. onnx_diagnostic/reference/ops/op_rotary.py +19 -0
  53. onnx_diagnostic/reference/ops/op_scan.py +65 -0
  54. onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
  55. onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
  56. onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
  57. onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
  58. onnx_diagnostic/reference/ops/op_slice.py +20 -0
  59. onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
  60. onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
  61. onnx_diagnostic/reference/ort_evaluator.py +652 -0
  62. onnx_diagnostic/reference/quantized_tensor.py +46 -0
  63. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  64. onnx_diagnostic/reference/torch_evaluator.py +669 -0
  65. onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
  66. onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
  67. onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
  68. onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
  69. onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
  70. onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
  71. onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
  72. onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
  73. onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
  74. onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
  75. onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
  76. onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
  77. onnx_diagnostic/tasks/__init__.py +90 -0
  78. onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
  79. onnx_diagnostic/tasks/data/__init__.py +13 -0
  80. onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
  81. onnx_diagnostic/tasks/feature_extraction.py +162 -0
  82. onnx_diagnostic/tasks/fill_mask.py +89 -0
  83. onnx_diagnostic/tasks/image_classification.py +144 -0
  84. onnx_diagnostic/tasks/image_text_to_text.py +581 -0
  85. onnx_diagnostic/tasks/image_to_video.py +127 -0
  86. onnx_diagnostic/tasks/mask_generation.py +143 -0
  87. onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
  88. onnx_diagnostic/tasks/object_detection.py +134 -0
  89. onnx_diagnostic/tasks/sentence_similarity.py +89 -0
  90. onnx_diagnostic/tasks/summarization.py +227 -0
  91. onnx_diagnostic/tasks/text2text_generation.py +230 -0
  92. onnx_diagnostic/tasks/text_classification.py +89 -0
  93. onnx_diagnostic/tasks/text_generation.py +352 -0
  94. onnx_diagnostic/tasks/text_to_image.py +95 -0
  95. onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
  96. onnx_diagnostic/torch_export_patches/__init__.py +21 -0
  97. onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
  98. onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
  99. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
  100. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
  101. onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
  102. onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
  103. onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
  104. onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
  105. onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
  106. onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
  107. onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
  108. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
  109. onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
  110. onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
  111. onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
  112. onnx_diagnostic/torch_models/__init__.py +0 -0
  113. onnx_diagnostic/torch_models/code_sample.py +343 -0
  114. onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
  115. onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
  116. onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
  117. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
  118. onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
  119. onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
  120. onnx_diagnostic/torch_models/llms.py +2 -0
  121. onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
  122. onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
  123. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
  124. onnx_diagnostic/torch_models/validate.py +2124 -0
  125. onnx_diagnostic/torch_onnx/__init__.py +0 -0
  126. onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
  127. onnx_diagnostic/torch_onnx/sbs.py +440 -0
  128. onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
  129. onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
  130. onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
  131. onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
  132. onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,898 @@
1
+ import numpy as np
2
+ import torch
3
+ from ..patches.patch_torch import patched_vmap
4
+
5
# Short aliases for the dynamic-shape markers used by every model case's
# ``_dynamic`` specification below.
DIM = torch.export.Dim          # named dynamic dimension, e.g. DIM("batch")
DYN = torch.export.Dim.DYNAMIC  # anonymous dynamic dimension
7
+
8
+
9
class AtenRollRelu(torch.nn.Module):
    """Case: ``torch.roll`` with a negative shift on the last axis, then ``relu``."""

    def forward(self, x):
        return torch.relu(torch.roll(x, -1, -1))

    # One sample input; first dimension is declared dynamic ("batch").
    _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenRollPos(torch.nn.Module):
    """Case: ``torch.roll`` with a positive shift on the last axis (no activation)."""

    def forward(self, x):
        return torch.roll(x, 1, -1)

    _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}
23
+
24
+
25
class InplaceAdd(torch.nn.Module):
    """Case: in-place augmented assignment ``x += bias`` on an input tensor."""

    def __init__(self):
        super().__init__()
        # Constant (non-parameter) tensor added to the input in place.
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x += self.bias
        return x

    # Two input sets with different batch sizes to exercise the dynamic dim.
    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceAdd2(torch.nn.Module):
    """Case: explicit in-place method ``Tensor.add_`` on an input tensor."""

    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x.add_(self.bias)
        return x

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceAdd_Mul(torch.nn.Module):
    """Case: in-place ``add_`` followed by an out-of-place multiplication."""

    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x.add_(self.bias)
        return x * 2

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceCloneAdd_(torch.nn.Module):
    """Case: ``clone`` before ``add_`` so the mutation targets a local copy."""

    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x = x.clone()
        x.add_(self.bias)
        return x

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}
76
+
77
+
78
class InplaceSetItemSquare(torch.nn.Module):
    """Case: slice assignment ``x[:2, :3] = 1`` mutating the input in place."""

    def forward(self, x):
        x[:2, :3] = 1
        return x

    _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceSetItemSquareAdd(torch.nn.Module):
    """Case: slice assignment followed by one out-of-place use of the mutated tensor."""

    def forward(self, x):
        x[:2, :3] = 1
        return x + 2

    _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceSetItemSquareAdd2(torch.nn.Module):
    """Case: slice assignment followed by two uses of the mutated tensor."""

    def forward(self, x):
        x[:2, :3] = 1
        return x + 2, x + 3

    _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
    _dynamic = {"x": {0: DIM("batch")}}
103
+
104
+
105
class InplaceSetItemEllipsis_1(torch.nn.Module):
    """Case: ellipsis indexed assignment ``copy[..., index] = update`` (last dim 4)."""

    def __init__(self):
        super().__init__()
        # Constant buffer cloned in forward so the mutation stays local.
        self.params = torch.zeros((1, 8192, 4), dtype=torch.float32)

    def forward(self, index, update):
        copy = self.params.clone()
        copy[..., index] = update
        return copy

    _inputs = (
        (torch.from_numpy(np.array([0, 3, 2, 1])).to(torch.int64)),
        (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
    )
    _dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}


class InplaceSetItemEllipsis_2(torch.nn.Module):
    """Same pattern as ``InplaceSetItemEllipsis_1`` but with last dimension 6.

    NOTE(review): the index values reach 5 while ``update`` still has last
    dimension 4 — presumably a deliberate variant; keep as is.
    """

    def __init__(self):
        super().__init__()
        self.params = torch.zeros((1, 8192, 6), dtype=torch.float32)

    def forward(self, index, update):
        copy = self.params.clone()
        copy[..., index] = update
        return copy

    _inputs = (
        torch.from_numpy(np.array([0, 3, 2, 5])).to(torch.int64),
        (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
    )
    _dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}


class InplaceSetItemMask(torch.nn.Module):
    """Case: boolean-mask assignment ``x[mask] = 2`` mutating the input."""

    def forward(self, x):
        mask = x.to(bool)
        x[mask] = 2
        return x

    _inputs = [(torch.randn((2, 3, 3)),), (torch.randn((3, 3, 3)),)]
    _dynamic = {"x": {0: DIM("batch")}}
147
+
148
+
149
class AtenInterpolate(torch.nn.Module):
    """Case: bilinear ``interpolate`` with a fixed scale factor."""

    def forward(self, x):
        y = torch.nn.functional.interpolate(
            x,
            scale_factor=2.0,
            mode="bilinear",
            recompute_scale_factor=False,
        )
        return y

    _inputs = (torch.randn(2, 2, 3, 4, requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenNonZero(torch.nn.Module):
    """Case: ``torch.nonzero`` — output size depends on input values."""

    def forward(self, x):
        y = torch.nonzero(x)
        return y

    _inputs = (torch.randn(3, 4, requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenNonZeroTuple(torch.nn.Module):
    """Case: ``torch.nonzero(..., as_tuple=True)`` returning per-axis indices."""

    def forward(self, x):
        y = torch.nonzero(x, as_tuple=True)
        return y[0], y[1]

    _inputs = (torch.randn(3, 4, requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenAsStrided(torch.nn.Module):
    """Case: ``torch.as_strided`` with hard-coded sizes and strides."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1))
        return y

    _inputs = (torch.randn((2, 2, 8, 8), requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class ComplexPolar(torch.nn.Module):
    """Case: ``torch.polar`` producing a complex tensor from abs/angle."""

    def forward(self, x, angle):
        return torch.polar(x, angle)

    _inputs = (torch.rand(4, 4), torch.rand(4, 4))
    _dynamic = {"x": {0: DIM("batch")}, "angle": {0: DIM("batch")}}
199
+
200
+
201
class ControlFlowCond(torch.nn.Module):
    """Case: ``torch.cond`` with two local branch functions, one output."""

    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)

        def false_fn(x):
            return torch.cos(x)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    _inputs = (torch.rand(5, 3),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCond2Outputs(torch.nn.Module):
    """Case: ``torch.cond`` whose branches each return two tensors."""

    def forward(self, x):
        def true_fn(x):
            return torch.sin(x), torch.cos(x)

        def false_fn(x):
            return torch.cos(x), torch.sin(x)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    _inputs = (torch.rand(5, 3),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCond2Inputs(torch.nn.Module):
    """Case: ``torch.cond`` whose branches take two operands."""

    def forward(self, x, y):
        def true_fn(x, y):
            return torch.sin(x), torch.cos(x) + y

        def false_fn(x, y):
            return torch.cos(x), torch.sin(x) + y

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

    _inputs = torch.rand(5, 3), torch.rand(5, 3)
    _dynamic = {"x": {0: DIM("batch")}, "y": {0: DIM("batch")}}


class ControlFlowNestCond(torch.nn.Module):
    """Case: a ``torch.cond`` nested inside the true branch of another one."""

    def forward(self, x):
        def true_fn2(x):
            def true_fn1(x):
                return torch.sin(x)

            def false_fn1(x):
                return torch.cos(x)

            return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x])

        def false_fn2(x):
            return -x

        return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x])

    _inputs = (torch.rand(5, 3),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCondConstant(torch.nn.Module):
    """Case: ``torch.cond`` branches creating constant tensors of different shapes."""

    def forward(self, x):
        def true_fn(x):
            # Shape taken from the input.
            return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype)

        def false_fn(x):
            # Hard-coded shape, relies on broadcasting against x.
            return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    _inputs = (torch.rand(1024, 1024),)
    _dynamic = {"x": {0: DIM("batch")}}
275
+
276
+
277
class ControlFlowCondNestedModule(torch.nn.Module):
    """Case: ``torch.cond`` whose true branch calls a submodule that itself
    contains another ``torch.cond`` using its own parameter."""

    class Submodule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Nested weight
            self.weight = torch.nn.Parameter(torch.tensor([100.0]))

        def forward(self, x):
            def true_fn(x):
                return x * self.weight

            def false_fn(x):
                return x / self.weight

            y = torch.cond(torch.abs(x).sum() > 100, true_fn, false_fn, [x])
            return y

    def __init__(self):
        super().__init__()
        self.submodule = ControlFlowCondNestedModule.Submodule()
        self.weight = torch.nn.Parameter(torch.tensor([42.0]))

    def forward(self, x):
        def true_fn(x):
            return self.submodule(x)

        def false_fn(x):
            return x - self.weight

        y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
        return y

    _inputs = (torch.tensor([-1, 2]),)
    _dynamic = {"x": {0: DIM("batch")}}
311
+
312
+
313
class ControlFlowCondNonZero(torch.nn.Module):
    """Case: ``torch.cond`` where one branch uses data-dependent ``nonzero``
    and the other produces empty index tensors via ``torch.where``."""

    def forward(self, input_ids, image_features, vocab_size):
        def then_branch(input_ids, image_features, vocab_size):
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])

            # Data-dependent selection of negative sentinel positions.
            condition = (input_ids < 0) & (input_ids > -int(1e9))
            positions = torch.nonzero(condition, as_tuple=True)
            input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)
            return (input_ids, positions[0], positions[1])

        def else_branch(input_ids, image_features, vocab_size):
            # torch.where on an all-False mask yields empty index tensors,
            # matching the then-branch's output structure.
            r = torch.where(torch.zeros((1, 1), dtype=torch.bool))
            return (input_ids, r[0], r[1])

        a, b, c = torch.cond(
            image_features.numel() > 0,
            then_branch,
            else_branch,
            [input_ids, image_features, vocab_size],
        )
        return a, b, c

    # Second input set has empty image_features to trigger the else branch.
    _inputs = [
        (
            (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
            torch.arange(32).reshape((2, -1)).to(torch.float32),
            1025,
        ),
        (
            (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
            torch.tensor([[], []], dtype=torch.float32),
            1025,
        ),
    ]
    # Positional form: one spec per forward argument, None for the int.
    _dynamic = (
        {0: DIM("batch")},
        {0: DIM("batch"), 1: DIM("seq_length")},
        None,
    )
353
+
354
+
355
class ControlFlowCondIdentity_153832(torch.nn.Module):
    """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_

    Reproduction: a ``torch.cond`` branch that returns its input unchanged.
    """

    def forward(self, x, y):
        def branch_cond_then_1(x):
            x = torch.abs(x) + 1
            return x

        def branch_cond_else_1(x):
            return x  # fails but succeeds with x.clone()

        x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x])
        return x + y

    _inputs = [
        (torch.rand((3, 4)), torch.rand((3, 4))),
        (torch.rand((4, 5)), torch.rand((4, 5))),
    ]
    _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}}
374
+
375
+
376
class ControlFlowScan(torch.nn.Module):
    """Case: ``higher_order.scan`` with one carry and one scanned input."""

    @staticmethod
    def add(carry: torch.Tensor, y: torch.Tensor):
        # Returns [next_carry, per-step output]; both are the running sum here.
        next_carry = carry + y
        return [next_carry, next_carry]

    def forward(self, x):
        init = torch.zeros_like(x[0])
        carry, _out = torch.ops.higher_order.scan(
            ControlFlowScan.add, [init], [x], additional_inputs=[]
        )
        return carry

    _inputs = (torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScan2Carried(torch.nn.Module):
    """Case: ``higher_order.scan`` with two carried states and two scanned inputs."""

    @staticmethod
    def add(carry1: torch.Tensor, carry2: torch.Tensor, y1: torch.Tensor, y2: torch.Tensor):
        next_carry1 = carry1 + y1
        next_carry2 = carry2 * y2
        return [next_carry1, next_carry2, next_carry1, next_carry2]

    def forward(self, x):
        init1 = torch.zeros_like(x[0])
        init2 = torch.ones_like(x[0])
        carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
            ControlFlowScan2Carried.add,
            [init1, init2],
            [x, x * 2],
            # dim=0, # 01/31/2025, not supported anymore
            additional_inputs=[],
        )
        return carry1, carry2, out1, out2

    _inputs = (
        torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
    )
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScanCDist(torch.nn.Module):
    """Case: pairwise-distance computation expressed with ``higher_order.scan``."""

    @staticmethod
    def dist(carry: torch.Tensor, x: torch.Tensor):
        sub = carry - x.reshape((1, -1))
        sq = sub * sub
        rd = sq.sum(axis=1) ** 0.5
        # clone --> UnsupportedAliasMutationException:
        # Combine_fn might be aliasing the input!
        return [carry.clone(), rd]

    def forward(self, x):
        _carry, out = torch.ops.higher_order.scan(
            ControlFlowScanCDist.dist,
            [x],
            [x],
            # dim=0, # 01/31/2025, not supported anymore
            additional_inputs=[],
        )
        return out

    _inputs = (
        torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
    )
    _dynamic = {"x": {0: DIM("batch")}}
442
+
443
+
444
class ControlFlowScanCDist2(torch.nn.Module):
    """Case: ``higher_order.scan`` using ``additional_inputs`` for the full matrix."""

    @staticmethod
    def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
        sub = samex - x.reshape((1, -1))
        sq = sub * sub
        rd = torch.sqrt(sq.sum(axis=1))
        # clone --> UnsupportedAliasMutationException:
        # Combine_fn might be aliasing the input!
        return [unused.clone(), rd]

    def forward(self, x):
        # Dummy carry; the real data travels through additional_inputs.
        z = torch.tensor([0], dtype=torch.float32)
        y = x.clone()
        out = torch.ops.higher_order.scan(
            ControlFlowScanCDist2.dist,
            [z],
            [x],
            # dim=0, # 01/31/2025, not supported anymore
            additional_inputs=[y],
        )
        return out[1]

    _inputs = (
        torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
    )
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScanCDistXY(torch.nn.Module):
    """Case: cdist(x, y) via ``scan`` — y is the carry, rows of x are scanned."""

    @staticmethod
    def dist(y: torch.Tensor, scanned_x: torch.Tensor):
        sub = y - scanned_x.reshape((1, -1))
        sq = sub * sub
        rd = torch.sqrt(sq.sum(axis=1))
        # clone --> UnsupportedAliasMutationException:
        # Combine_fn might be aliasing the input!
        return [y.clone(), rd]

    def forward(self, x, y):
        _carry, out = torch.ops.higher_order.scan(
            ControlFlowScanCDistXY.dist,
            [y],
            [x],
            # dim=0, # 01/31/2025, not supported anymore
            additional_inputs=[],
        )
        return out

    _inputs = [
        (torch.randn(3, 4), torch.randn(5, 4)),
        (torch.randn(13, 14), torch.randn(15, 14)),
    ]
    # x and y share the trailing "dim" dimension.
    _dynamic = {
        "x": {0: DIM("x_rows"), 1: DIM("dim")},
        "y": {0: DIM("y_rows"), 1: DIM("dim")},
    }
500
+
501
+
502
class ControlFlowScanInplace_153705(torch.nn.Module):
    """
    `#153705 <https://github.com/pytorch/pytorch/issues/153705>`_

    Reproduction: row-wise assignment into a cloned carry inside ``scan``.
    """

    def forward(self, x, y):
        def loop_body_1(z, iv, x, y):
            z = z.clone()
            i = iv.item()
            z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
            return [z, iv]

        z = torch.empty((x.shape[0], y.shape[0]))
        # NOTE(review): [x, y] is passed positionally as additional_inputs,
        # matching the issue's reproduction — keep as is.
        r = torch.ops.higher_order.scan(
            loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
        )
        return r[0]

    _inputs = [
        (torch.rand((3, 4)), torch.rand((5, 4))),
        (torch.rand((4, 5)), torch.rand((6, 5))),
    ]
    _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}}
525
+
526
+
527
class ControlFlowScanDecomposition_151564(torch.nn.Module):
    """
    `#151564 <https://github.com/pytorch/pytorch/issues/151564>`_

    A Python loop (eager path) and its ``scan``-based equivalent (export path);
    ``forward`` picks one depending on ``torch.compiler.is_exporting()``.
    """

    @classmethod
    def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor):
        # Eager reference implementation: copy the first pos[i] entries of row i.
        copy = torch.zeros(padded.shape)
        for i in range(pos.shape[0]):
            p = pos[i]
            copy[i, :p] = padded[i, :p]
        return copy

    @classmethod
    def dummy_loop_with_scan(cls, padded: torch.Tensor, pos: torch.Tensor):
        def pad_row(padded, p):
            row = torch.zeros((padded.shape[0],))
            torch._check(p.item() > 0)
            torch._check(p.item() < padded.shape[0])
            # this check is not always true, we add it anyway to make this dimension >= 2
            # and avoid raising an exception about dynamic dimension in {0, 1}
            if torch.compiler.is_exporting():
                torch._check(p.item() > 1)
            row[: p.item()] = padded[: p.item()]
            return (row,)

        return torch.ops.higher_order.scan(
            pad_row,
            [],
            [padded, pos],
            [],
        )

    @classmethod
    def select_when_exporting(cls, f, f_scan):
        # Use the scan variant only while exporting; eager loop otherwise.
        return f_scan if torch.compiler.is_exporting() else f

    def forward(self, images, position):
        return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
            images, position
        )

    _inputs = [(torch.randn((5, 6)), torch.arange(5, dtype=torch.int64) + 1)]
    _dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}
571
+
572
+
573
class SignatureInt1(torch.nn.Module):
    """Case: forward signature mixing a tensor with a defaulted ``int`` used in a slice."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, i: int = 2):
        # i appears as a slice bound, so the output keeps a length-1 column.
        return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1]

    _inputs = [
        ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1),
        ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2),
    ]
    # Positional form: spec for x, None for the int argument.
    _dynamic = ({0: DIM("batch", min=1, max=1024)}, None)


class SignatureFloat1(torch.nn.Module):
    """Case: forward signature mixing a tensor with a defaulted ``float`` scalar."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, alpha: float = 2.0):
        return torch.sigmoid(self.linear(x)) - self.buff * alpha

    _inputs = [
        ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1.5),
        ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2.5),
    ]
    _dynamic = ({0: DIM("batch", min=1, max=1024)}, None)


class SignatureInt2(torch.nn.Module):
    """Case: defaulted ``int`` used as a direct index (drops a dimension)."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, i: int = 2):
        return torch.sigmoid(self.linear(x)) - self.buff + x[:, i]

    _inputs = ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1)
    _dynamic = {
        "x": {0: DIM("batch")},
        "i": None,  # DIM("ii", min=0, max=3)}
    }
619
+
620
+
621
class SignatureListFixedLength(torch.nn.Module):
    """Case: a ``list`` argument whose length (2) is the same for every call."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list):
        return (
            torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True)
        )

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class SignatureListVariableLength(torch.nn.Module):
    """Case: a ``list`` argument whose length differs between calls (2 then 3)."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list):
        # cat tolerates any list length, unlike indexing fixed positions.
        t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
                (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            ],
        ),
    ]
    # NOTE(review): spec lists two entries though the second call passes three
    # tensors — presumably intentional for this case; verify against the runner.
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }
685
+
686
+
687
class BuildInLen(torch.nn.Module):
    """Case: built-in ``len`` on a list argument used for branching."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list):
        t = lx[0] * lx[1].sum(axis=1, keepdim=True)
        # Python-level branch on list length, resolved at trace time.
        if len(lx) > 2:
            t = t + lx[2].sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
                (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class BuildInIsInstance(torch.nn.Module):
    """Case: built-in ``isinstance`` on an argument that is a list or a tensor."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list | torch.Tensor):
        # Python-level type dispatch, resolved at trace time.
        if isinstance(lx, list):
            t = lx[0] * lx[1].sum(axis=1, keepdim=True)
            return torch.sigmoid(self.linear(x)) - self.buff + t
        return torch.sigmoid(self.linear(x)) - self.buff + lx

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }
754
+
755
+
756
class SignatureShapeAsIndex(torch.nn.Module):
    """Case: one input's dynamic shape (``y.shape[1]``) used to slice another."""

    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, y):
        t = torch.sigmoid(self.linear(x)) + x
        return t[:, : y.shape[1]]

    _inputs = (
        (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
        (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
    )
    _dynamic = {
        "x": {0: DIM("batch", min=0, max=1024)},
        "y": {
            0: DIM("batch", min=0, max=1024),
            1: DIM("length", min=0, max=2),
        },
    }


class TypeBFloat16(torch.nn.Module):
    """Case: round trip through ``bfloat16`` with an addition in between."""

    def forward(self, x):
        xb = x.to(torch.bfloat16)
        return (xb + xb).to(torch.float32)

    _inputs = (torch.rand(4, 4).to(torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}
786
+
787
+
788
class CropLastDimensionWithTensorShape(torch.nn.Module):
    """Case: last dimension cropped using another tensor's *shape* (``y.shape[0]``)."""

    def forward(self, x, y):
        return x[..., : y.shape[0]]

    _inputs = [
        (
            torch.rand(3, 4, 4).to(torch.float32),
            torch.rand(
                2,
            ).to(torch.float32),
        ),
        (
            torch.rand(6, 4, 4).to(torch.float32),
            torch.rand(
                3,
            ).to(torch.float32),
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "y": {0: DIM("crop", min=1, max=3)},
    }


class CropLastDimensionWithTensorContent(torch.nn.Module):
    """Case: last dimension cropped using another tensor's *content* (``shape[0]``)."""

    def forward(self, x, shape):
        return x[..., : shape[0]]

    _inputs = [
        (torch.rand(3, 4, 4).to(torch.float32), torch.tensor([2], dtype=torch.int64)),
        (torch.rand(6, 4, 4).to(torch.float32), torch.tensor([3], dtype=torch.int64)),
    ]
    # shape is a static (non-dynamic) input.
    _dynamic = {"x": {0: DIM("batch")}, "shape": {}}
822
+
823
+
824
class SignatureListFixedWithNone(torch.nn.Module):
    """Case: a list argument containing ``None`` entries, guarded by ``is not None``."""

    def forward(self, lx):
        x = lx[0]
        if lx[1] is not None:
            x += lx[1]
        if lx[2] is not None:
            x += lx[2]
        return x

    # First call has a None in slot 2; second fills all three slots.
    _inputs = [
        ([torch.rand((4, 4)), torch.rand((4, 4)), None],),
        ([torch.rand((4, 4)), torch.rand((4, 4)), torch.rand((4, 4))],),
    ]
    _dynamic = {
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class CreateFromShape(torch.nn.Module):
    """Case: new tensor allocated from the input's (dynamic) shape."""

    def forward(self, x):
        y = torch.ones((x.shape[0], x.shape[1] + 1))
        return y

    _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
    _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}


class CreateFromShapeThroughFunction(torch.nn.Module):
    """Case: a dynamic dimension passed through a plain Python helper function."""

    @staticmethod
    def add_one(dim):
        return dim + 1

    def forward(self, x):
        dy1 = CreateFromShapeThroughFunction.add_one(x.shape[1])
        y = torch.ones((x.shape[0], dy1))
        return y

    _inputs = [(torch.rand((4, 4)),)]
    _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
863
+
864
+
865
class Vmap(torch.nn.Module):
    """Case: ``torch.vmap`` over a simple elementwise lambda."""

    def forward(self, x, y):
        f = lambda x, y: x * y + 1  # noqa: E731
        return torch.vmap(f)(x, y)

    _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
    _dynamic = {"x": {0: DYN}, "y": {0: DYN}}


class VmapPython(torch.nn.Module):
    """Same computation as ``Vmap`` but through this package's ``patched_vmap``."""

    def forward(self, x, y):
        f = lambda x, y: x * y + 1  # noqa: E731
        return patched_vmap(f)(x, y)

    _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
    _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
881
+
882
+
883
class ExportWithDimension0(torch.nn.Module):
    """Case: export traced with a zero-sized first dimension (0, 3)."""

    def forward(self, x):
        return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))

    _inputs = [(torch.empty((0, 3), dtype=torch.float32),)]
    _dynamic = {"x": {0: DYN, 1: DYN}}
    # Non-degenerate input used to validate the exported model.
    _valid = [(torch.rand((2, 3), dtype=torch.float32),)]


class ExportWithDimension1(torch.nn.Module):
    """Case: export traced with a size-1 first dimension (1, 3)."""

    def forward(self, x):
        return x @ torch.arange(x.shape[1], dtype=torch.float32).reshape((-1, 1))

    _inputs = [(torch.zeros((1, 3), dtype=torch.float32),)]
    _dynamic = {"x": {0: DYN, 1: DYN}}
    _valid = [(torch.rand((2, 3), dtype=torch.float32),)]