onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +2 -2
- onnx_diagnostic/_command_lines_parser.py +39 -1
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/export/dynamic_shapes.py +14 -5
- onnx_diagnostic/ext_test_case.py +15 -1
- onnx_diagnostic/helpers/args_helper.py +1 -1
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +30 -5
- onnx_diagnostic/helpers/model_builder_helper.py +349 -0
- onnx_diagnostic/helpers/rt_helper.py +69 -1
- onnx_diagnostic/helpers/torch_helper.py +2 -0
- onnx_diagnostic/reference/__init__.py +1 -0
- onnx_diagnostic/reference/torch_evaluator.py +518 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +55 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +326 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +84 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +118 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +35 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +176 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +120 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +86 -0
- onnx_diagnostic/tasks/__init__.py +22 -1
- onnx_diagnostic/tasks/image_classification.py +2 -2
- onnx_diagnostic/tasks/text_generation.py +3 -3
- onnx_diagnostic/torch_export_patches/eval/__init__.py +690 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +883 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +34 -1
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +6 -1
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +148 -28
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +91 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +117 -1
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +142 -0
- onnx_diagnostic/torch_models/test_helper.py +225 -22
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/RECORD +43 -24
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/WHEEL +1 -1
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.5.0.dist-info → onnx_diagnostic-0.6.1.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_export_patches/eval/model_cases.py
@@ -0,0 +1,883 @@
import numpy as np
import torch
from ..patches.patch_torch import patched_vmap

DIM = torch.export.Dim
DYN = torch.export.Dim.DYNAMIC
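
# --- Annotation (not part of the original wheel) ---------------------------
# Every class below is a self-contained export test case: `_inputs` holds an
# example argument tuple (or a list of such tuples when several input shapes
# are tried) and `_dynamic` the matching dynamic-shape specification. A
# minimal sketch of how one case is typically exercised, using only public
# torch.export APIs:
#
#     ep = torch.export.export(
#         AtenRollRelu(),
#         AtenRollRelu._inputs,
#         dynamic_shapes=AtenRollRelu._dynamic,
#     )
#     print(ep.graph)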


class AtenRollRelu(torch.nn.Module):
    def forward(self, x):
        return torch.relu(torch.roll(x, -1, -1))

    _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenRollPos(torch.nn.Module):
    def forward(self, x):
        return torch.roll(x, 1, -1)

    _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x += self.bias
        return x

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceAdd2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x.add_(self.bias)
        return x

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceAdd_Mul(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x.add_(self.bias)
        return x * 2

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceCloneAdd_(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.bias = torch.ones((1, 4), dtype=torch.float32)

    def forward(self, x):
        x = x.clone()
        x.add_(self.bias)
        return x

    _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
    _dynamic = {"x": {0: DIM("batch")}}
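
# --- Annotation (not part of the original wheel) ---------------------------
# The InplaceSetItem* cases mutate a tensor (or a cloned buffer) through
# slice, ellipsis, and boolean-mask assignment, a second family of aliasing
# patterns the exporter has to functionalize.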


class InplaceSetItemSquare(torch.nn.Module):
    def forward(self, x):
        x[:2, :3] = 1
        return x

    _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceSetItemSquareAdd(torch.nn.Module):
    def forward(self, x):
        x[:2, :3] = 1
        return x + 2

    _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceSetItemSquareAdd2(torch.nn.Module):
    def forward(self, x):
        x[:2, :3] = 1
        return x + 2, x + 3

    _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
    _dynamic = {"x": {0: DIM("batch")}}


class InplaceSetItemEllipsis_1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.params = torch.zeros((1, 8192, 4), dtype=torch.float32)

    def forward(self, index, update):
        copy = self.params.clone()
        copy[..., index] = update
        return copy

    _inputs = (
        torch.from_numpy(np.array([0, 3, 2, 1])).to(torch.int64),
        (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
    )
    _dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}


class InplaceSetItemEllipsis_2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.params = torch.zeros((1, 8192, 6), dtype=torch.float32)

    def forward(self, index, update):
        copy = self.params.clone()
        copy[..., index] = update
        return copy

    _inputs = (
        torch.from_numpy(np.array([0, 3, 2, 5])).to(torch.int64),
        (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
    )
    _dynamic = {"index": {0: DIM("batch")}, "update": {0: DIM("batch"), 1: DYN}}


class InplaceSetItemMask(torch.nn.Module):
    def forward(self, x):
        mask = x.to(bool)
        x[mask] = 2
        return x

    _inputs = [(torch.randn((2, 3, 3)),), (torch.randn((3, 3, 3)),)]
    _dynamic = {"x": {0: DIM("batch")}}
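
# --- Annotation (not part of the original wheel) ---------------------------
# The next cases probe single operators: interpolate with a scale factor,
# nonzero (data-dependent output shape), as_strided (explicit strides), and
# polar (complex-valued output).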


class AtenInterpolate(torch.nn.Module):
    def forward(self, x):
        y = torch.nn.functional.interpolate(
            x,
            scale_factor=2.0,
            mode="bilinear",
            recompute_scale_factor=False,
        )
        return y

    _inputs = (torch.randn(2, 2, 3, 4, requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenNonZero(torch.nn.Module):
    def forward(self, x):
        y = torch.nonzero(x)
        return y

    _inputs = (torch.randn(3, 4, requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenNonZeroTuple(torch.nn.Module):
    def forward(self, x):
        y = torch.nonzero(x, as_tuple=True)
        return y[0], y[1]

    _inputs = (torch.randn(3, 4, requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class AtenAsStrided(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1))
        return y

    _inputs = (torch.randn((2, 2, 8, 8), requires_grad=False),)
    _dynamic = {"x": {0: DIM("batch")}}


class ComplexPolar(torch.nn.Module):
    def forward(self, x, angle):
        return torch.polar(x, angle)

    _inputs = (torch.rand(4, 4), torch.rand(4, 4))
    _dynamic = {"x": {0: DIM("batch")}, "angle": {0: DIM("batch")}}


class ControlFlowCond(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x)

        def false_fn(x):
            return torch.cos(x)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    _inputs = (torch.rand(5, 3),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCond2Outputs(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x), torch.cos(x)

        def false_fn(x):
            return torch.cos(x), torch.sin(x)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    _inputs = (torch.rand(5, 3),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCond2Inputs(torch.nn.Module):
    def forward(self, x, y):
        def true_fn(x, y):
            return torch.sin(x), torch.cos(x) + y

        def false_fn(x, y):
            return torch.cos(x), torch.sin(x) + y

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])

    _inputs = torch.rand(5, 3), torch.rand(5, 3)
    _dynamic = {"x": {0: DIM("batch")}, "y": {0: DIM("batch")}}


class ControlFlowNestCond(torch.nn.Module):
    def forward(self, x):
        def true_fn2(x):
            def true_fn1(x):
                return torch.sin(x)

            def false_fn1(x):
                return torch.cos(x)

            return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x])

        def false_fn2(x):
            return -x

        return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x])

    _inputs = (torch.rand(5, 3),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCondConstant(torch.nn.Module):
    def forward(self, x):
        def true_fn(x):
            return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype)

        def false_fn(x):
            return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype)

        return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    _inputs = (torch.rand(1024, 1024),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCondNestedModule(torch.nn.Module):
    class Submodule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Nested weight
            self.weight = torch.nn.Parameter(torch.tensor([100.0]))

        def forward(self, x):
            def true_fn(x):
                return x * self.weight

            def false_fn(x):
                return x / self.weight

            y = torch.cond(torch.abs(x).sum() > 100, true_fn, false_fn, [x])
            return y

    def __init__(self):
        super().__init__()
        self.submodule = ControlFlowCondNestedModule.Submodule()
        self.weight = torch.nn.Parameter(torch.tensor([42.0]))

    def forward(self, x):
        def true_fn(x):
            return self.submodule(x)

        def false_fn(x):
            return x - self.weight

        y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
        return y

    _inputs = (torch.tensor([-1, 2]),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowCondNonZero(torch.nn.Module):
    def forward(self, input_ids, image_features, vocab_size):
        def then_branch(input_ids, image_features, vocab_size):
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])

            condition = (input_ids < 0) & (input_ids > -int(1e9))
            positions = torch.nonzero(condition, as_tuple=True)
            input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)
            return (input_ids, positions[0], positions[1])

        def else_branch(input_ids, image_features, vocab_size):
            r = torch.where(torch.zeros((1, 1), dtype=torch.bool))
            return (input_ids, r[0], r[1])

        a, b, c = torch.cond(
            image_features.numel() > 0,
            then_branch,
            else_branch,
            [input_ids, image_features, vocab_size],
        )
        return a, b, c

    _inputs = [
        (
            (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
            torch.arange(32).reshape((2, -1)).to(torch.float32),
            1025,
        ),
        (
            (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
            torch.tensor([[], []], dtype=torch.float32),
            1025,
        ),
    ]
    _dynamic = (
        {0: DIM("batch")},
        {0: DIM("batch"), 1: DIM("seq_length")},
        None,
    )


class ControlFlowCondIdentity_153832(torch.nn.Module):
    """
    `#153832 <https://github.com/pytorch/pytorch/issues/153832>`_
    """

    def forward(self, x, y):
        def branch_cond_then_1(x):
            x = torch.abs(x) + 1
            return x

        def branch_cond_else_1(x):
            return x  # fails but succeeds with x.clone()

        x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x])
        return x + y

    _inputs = [
        (torch.rand((3, 4)), torch.rand((3, 4))),
        (torch.rand((4, 5)), torch.rand((4, 5))),
    ]
    _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}}


class ControlFlowScan(torch.nn.Module):
    @staticmethod
    def add(carry: torch.Tensor, y: torch.Tensor):
        next_carry = carry + y
        return [next_carry, next_carry]

    def forward(self, x):
        init = torch.zeros_like(x[0])
        carry, out = torch.ops.higher_order.scan(
            ControlFlowScan.add, [init], [x], additional_inputs=[]
        )
        return carry

    _inputs = (torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScan2Carried(torch.nn.Module):
    @staticmethod
    def add(carry1: torch.Tensor, carry2: torch.Tensor, y1: torch.Tensor, y2: torch.Tensor):
        next_carry1 = carry1 + y1
        next_carry2 = carry2 * y2
        return [next_carry1, next_carry2, next_carry1, next_carry2]

    def forward(self, x):
        init1 = torch.zeros_like(x[0])
        init2 = torch.ones_like(x[0])
        carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
            ControlFlowScan2Carried.add,
            [init1, init2],
            [x, x * 2],
            # dim=0,  # 01/31/2025, not supported anymore
            additional_inputs=[],
        )
        return carry1, carry2, out1, out2

    _inputs = (
        torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
    )
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScanCDist(torch.nn.Module):
    @staticmethod
    def dist(carry: torch.Tensor, x: torch.Tensor):
        sub = carry - x.reshape((1, -1))
        sq = sub * sub
        rd = sq.sum(axis=1) ** 0.5
        # clone --> UnsupportedAliasMutationException:
        # Combine_fn might be aliasing the input!
        return [carry.clone(), rd]

    def forward(self, x):
        carry, out = torch.ops.higher_order.scan(
            ControlFlowScanCDist.dist,
            [x],
            [x],
            # dim=0,  # 01/31/2025, not supported anymore
            additional_inputs=[],
        )
        return out

    _inputs = (
        torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
    )
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScanCDist2(torch.nn.Module):
    @staticmethod
    def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
        sub = samex - x.reshape((1, -1))
        sq = sub * sub
        rd = torch.sqrt(sq.sum(axis=1))
        # clone --> UnsupportedAliasMutationException:
        # Combine_fn might be aliasing the input!
        return [unused.clone(), rd]

    def forward(self, x):
        z = torch.tensor([0], dtype=torch.float32)
        y = x.clone()
        out = torch.ops.higher_order.scan(
            ControlFlowScanCDist2.dist,
            [z],
            [x],
            # dim=0,  # 01/31/2025, not supported anymore
            additional_inputs=[y],
        )
        return out[1]

    _inputs = (
        torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
    )
    _dynamic = {"x": {0: DIM("batch")}}


class ControlFlowScanCDistXY(torch.nn.Module):
    @staticmethod
    def dist(y: torch.Tensor, scanned_x: torch.Tensor):
        sub = y - scanned_x.reshape((1, -1))
        sq = sub * sub
        rd = torch.sqrt(sq.sum(axis=1))
        # clone --> UnsupportedAliasMutationException:
        # Combine_fn might be aliasing the input!
        return [y.clone(), rd]

    def forward(self, x, y):
        carry, out = torch.ops.higher_order.scan(
            ControlFlowScanCDistXY.dist,
            [y],
            [x],
            # dim=0,  # 01/31/2025, not supported anymore
            additional_inputs=[],
        )
        return out

    _inputs = [
        (torch.randn(3, 4), torch.randn(5, 4)),
        (torch.randn(13, 14), torch.randn(15, 14)),
    ]
    _dynamic = {
        "x": {0: DIM("x_rows"), 1: DIM("dim")},
        "y": {0: DIM("y_rows"), 1: DIM("dim")},
    }


class ControlFlowScanInplace_153705(torch.nn.Module):
    """
    `#153705 <https://github.com/pytorch/pytorch/issues/153705>`_
    """

    def forward(self, x, y):
        def loop_body_1(z, iv, x, y):
            z = z.clone()
            i = iv.item()
            z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
            return [z, iv]

        z = torch.empty((x.shape[0], y.shape[0]))
        r = torch.ops.higher_order.scan(
            loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
        )
        return r[0]

    _inputs = [
        (torch.rand((3, 4)), torch.rand((5, 4))),
        (torch.rand((4, 5)), torch.rand((6, 5))),
    ]
    _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}}


class ControlFlowScanDecomposition_151564(torch.nn.Module):
    """
    `#151564 <https://github.com/pytorch/pytorch/issues/151564>`_
    """

    @classmethod
    def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor):
        copy = torch.zeros(padded.shape)
        for i in range(pos.shape[0]):
            p = pos[i]
            copy[i, :p] = padded[i, :p]
        return copy

    @classmethod
    def dummy_loop_with_scan(cls, padded: torch.Tensor, pos: torch.Tensor):
        def pad_row(padded, p):
            row = torch.zeros((padded.shape[0],))
            torch._check(p.item() > 0)
            torch._check(p.item() < padded.shape[0])
            # this check is not always true, we add it anyway to make this dimension >= 2
            # and avoid raising an exception about dynamic dimension in {0, 1}
            if torch.compiler.is_exporting():
                torch._check(p.item() > 1)
            row[: p.item()] = padded[: p.item()]
            return (row,)

        return torch.ops.higher_order.scan(
            pad_row,
            [],
            [padded, pos],
            [],
        )

    @classmethod
    def select_when_exporting(cls, f, f_scan):
        return f_scan if torch.compiler.is_exporting() else f

    def forward(self, images, position):
        return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
            images, position
        )

    _inputs = [(torch.randn((5, 6)), torch.arange(5, dtype=torch.int64) + 1)]
    _dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}


class SignatureInt1(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, i: int = 2):
        return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1]

    _inputs = [
        ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1),
        ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2),
    ]
    _dynamic = ({0: DIM("batch", min=1, max=1024)}, None)


class SignatureFloat1(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, alpha: float = 2.0):
        return torch.sigmoid(self.linear(x)) - self.buff * alpha

    _inputs = [
        ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1.5),
        ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2.5),
    ]
    _dynamic = ({0: DIM("batch", min=1, max=1024)}, None)


class SignatureInt2(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, i: int = 2):
        return torch.sigmoid(self.linear(x)) - self.buff + x[:, i]

    _inputs = ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1)
    _dynamic = {
        "x": {0: DIM("batch")},
        "i": None,  # DIM("ii", min=0, max=3)
    }


class SignatureListFixedLength(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list):
        return (
            torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True)
        )

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class SignatureListVariableLength(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list):
        t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
                (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class BuildInLen(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list):
        t = lx[0] * lx[1].sum(axis=1, keepdim=True)
        if len(lx) > 2:
            t = t + lx[2].sum(axis=1, keepdim=True)
        return torch.sigmoid(self.linear(x)) - self.buff + t

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
                (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class BuildInIsInstance(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, lx: list | torch.Tensor):
        if isinstance(lx, list):
            t = lx[0] * lx[1].sum(axis=1, keepdim=True)
            return torch.sigmoid(self.linear(x)) - self.buff + t
        return torch.sigmoid(self.linear(x)) - self.buff + lx

    _inputs = [
        (
            (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
        (
            (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
            [
                (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
                (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
            ],
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class SignatureShapeAsIndex(torch.nn.Module):
    def __init__(self, n_dims: int = 3, n_targets: int = 1):
        super().__init__()
        self.linear = torch.nn.Linear(n_dims, n_targets)
        self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))

    def forward(self, x, y):
        t = torch.sigmoid(self.linear(x)) + x
        return t[:, : y.shape[1]]

    _inputs = (
        (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
        (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
    )
    _dynamic = {
        "x": {0: DIM("batch", min=0, max=1024)},
        "y": {
            0: DIM("batch", min=0, max=1024),
            1: DIM("length", min=0, max=2),
        },
    }


class TypeBFloat16(torch.nn.Module):
    def forward(self, x):
        xb = x.to(torch.bfloat16)
        return (xb + xb).to(torch.float32)

    _inputs = (torch.rand(4, 4).to(torch.float32),)
    _dynamic = {"x": {0: DIM("batch")}}
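
# --- Annotation (not part of the original wheel) ---------------------------
# Cropping by `y.shape[0]` depends only on a shape and exports directly,
# while cropping by `shape[0]` reads the *content* of a tensor, a
# data-dependent slice the exporter has to turn into a symbolic value.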


class CropLastDimensionWithTensorShape(torch.nn.Module):
    def forward(self, x, y):
        return x[..., : y.shape[0]]

    _inputs = [
        (
            torch.rand(3, 4, 4).to(torch.float32),
            torch.rand(2).to(torch.float32),
        ),
        (
            torch.rand(6, 4, 4).to(torch.float32),
            torch.rand(3).to(torch.float32),
        ),
    ]
    _dynamic = {
        "x": {0: DIM("batch")},
        "y": {0: DIM("crop", min=1, max=3)},
    }


class CropLastDimensionWithTensorContent(torch.nn.Module):
    def forward(self, x, shape):
        return x[..., : shape[0]]

    _inputs = [
        (torch.rand(3, 4, 4).to(torch.float32), torch.tensor([2], dtype=torch.int64)),
        (torch.rand(6, 4, 4).to(torch.float32), torch.tensor([3], dtype=torch.int64)),
    ]
    _dynamic = {"x": {0: DIM("batch")}, "shape": {}}


class SignatureListFixedWithNone(torch.nn.Module):
    def forward(self, lx):
        x = lx[0]
        if lx[1] is not None:
            x += lx[1]
        if lx[2] is not None:
            x += lx[2]
        return x

    _inputs = [
        ([torch.rand((4, 4)), torch.rand((4, 4)), None],),
        ([torch.rand((4, 4)), torch.rand((4, 4)), torch.rand((4, 4))],),
    ]
    _dynamic = {
        "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
    }


class CreateFromShape(torch.nn.Module):
    def forward(self, x):
        y = torch.ones((x.shape[0], x.shape[1] + 1))
        return y

    _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
    _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}


class CreateFromShapeThroughFunction(torch.nn.Module):
    @staticmethod
    def add_one(dim):
        return dim + 1

    def forward(self, x):
        dy1 = CreateFromShapeThroughFunction.add_one(x.shape[1])
        y = torch.ones((x.shape[0], dy1))
        return y

    _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
    _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}


class Vmap(torch.nn.Module):
    def forward(self, x, y):
        f = lambda x, y: x * y + 1  # noqa: E731
        return torch.vmap(f)(x, y)

    _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
    _dynamic = {"x": {0: DYN}, "y": {0: DYN}}


class VmapPython(torch.nn.Module):
    def forward(self, x, y):
        f = lambda x, y: x * y + 1  # noqa: E731
        return patched_vmap(f)(x, y)

    _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
    _dynamic = {"x": {0: DYN}, "y": {0: DYN}}