onnx-diagnostic 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff shows the content added between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
@@ -0,0 +1,896 @@
+ import numpy as np
+ import torch
+ from ..patches.patch_torch import patched_vmap
+
+ DIM = torch.export.Dim
+ DYN = torch.export.Dim.DYNAMIC
+
+
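+ # Each class below is a small torch.nn.Module capturing one torch.export
+ # edge case.  By convention in this file (not a torch API), the class
+ # attributes `_inputs` and `_dynamic` hold sample inputs and the matching
+ # `dynamic_shapes` specification.  A minimal usage sketch for one case:
+ #
+ #     ep = torch.export.export(Case(), Case._inputs, dynamic_shapes=Case._dynamic)
+ #
+ # When `_inputs` is a list, each element is one tuple of sample arguments.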
+ class AtenRollRelu(torch.nn.Module):
+     def forward(self, x):
+         return torch.relu(torch.roll(x, -1, -1))
+
+     _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class AtenRollPos(torch.nn.Module):
+     def forward(self, x):
+         return torch.roll(x, 1, -1)
+
+     _inputs = ((torch.arange(8 * 3) + 10).reshape((2, -1, 4)).to(torch.float32),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
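+ # The Inplace* cases below mutate an input or a module attribute in place;
+ # torch.export functionalizes graphs, so these are known trouble spots.
+ # InplaceCloneAdd_ shows the usual workaround: clone() before mutating.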
+ class InplaceAdd(torch.nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.bias = torch.ones((1, 4), dtype=torch.float32)
+
+     def forward(self, x):
+         x += self.bias
+         return x
+
+     _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class InplaceAdd2(torch.nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.bias = torch.ones((1, 4), dtype=torch.float32)
+
+     def forward(self, x):
+         x.add_(self.bias)
+         return x
+
+     _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class InplaceAdd_Mul(torch.nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.bias = torch.ones((1, 4), dtype=torch.float32)
+
+     def forward(self, x):
+         x.add_(self.bias)
+         return x * 2
+
+     _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class InplaceCloneAdd_(torch.nn.Module):
+
+     def __init__(self):
+         super().__init__()
+         self.bias = torch.ones((1, 4), dtype=torch.float32)
+
+     def forward(self, x):
+         x = x.clone()
+         x.add_(self.bias)
+         return x
+
+     _inputs = [(torch.rand(3, 4),), (torch.rand(5, 4),)]
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
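+ # The InplaceSetItem* cases assign into slices, advanced indices, or boolean
+ # masks; such setitem calls typically lower to index_put_/slice-scatter style
+ # mutations during export.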
82
+ class InplaceSetItemSquare(torch.nn.Module):
83
+
84
+ def forward(self, x):
85
+ x[:2, :3] = 1
86
+ return x
87
+
88
+ _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
89
+ _dynamic = {"x": {0: DIM("batch")}}
90
+
91
+
92
+ class InplaceSetItemSquareAdd(torch.nn.Module):
93
+
94
+ def forward(self, x):
95
+ x[:2, :3] = 1
96
+ return x + 2
97
+
98
+ _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
99
+ _dynamic = {"x": {0: DIM("batch")}}
100
+
101
+
102
+ class InplaceSetItemSquareAdd2(torch.nn.Module):
103
+
104
+ def forward(self, x):
105
+ x[:2, :3] = 1
106
+ return x + 2, x + 3
107
+
108
+ _inputs = [(torch.rand(5, 5),), (torch.rand(7, 5),)]
109
+ _dynamic = {"x": {0: DIM("batch")}}
110
+
111
+
112
+ class InplaceSetItemEllipsis_1(torch.nn.Module):
113
+
114
+ def __init__(self):
115
+ super().__init__()
116
+ self.params = torch.zeros((1, 8192, 4), dtype=torch.float32)
117
+
118
+ def forward(self, index, update):
119
+ copy = self.params.clone()
120
+ copy[..., index] = update
121
+ return copy
122
+
123
+ _inputs = (
124
+ (torch.from_numpy(np.array([0, 3, 2, 1])).to(torch.int64)),
125
+ (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
126
+ )
127
+ _dynamic = {"x": {0: DIM("batch")}}
128
+
129
+
130
+ class InplaceSetItemEllipsis_2(torch.nn.Module):
131
+
132
+ def __init__(self):
133
+ super().__init__()
134
+ self.params = torch.zeros((1, 8192, 6), dtype=torch.float32)
135
+
136
+ def forward(self, index, update):
137
+ copy = self.params.clone()
138
+ copy[..., index] = update
139
+ return copy
140
+
141
+ _inputs = (
142
+ torch.from_numpy(np.array([0, 3, 2, 5])).to(torch.int64),
143
+ (torch.arange(4 * 8192) + 10).reshape((-1, 4)).to(torch.float32),
144
+ )
145
+ _dynamic = {"x": {0: DIM("batch")}}
146
+
147
+
148
+ class InplaceSetItemMask(torch.nn.Module):
149
+ def forward(self, x):
150
+ mask = x.to(bool)
151
+ x[mask] = 2
152
+ return x
153
+
154
+ _inputs = [(torch.randn((2, 3, 3)),), (torch.randn((3, 3, 3)),)]
155
+ _dynamic = {"x": {0: DIM("batch")}}
156
+
157
+
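+ # The next few cases each target a single operator with a non-trivial export
+ # or ONNX lowering: data-dependent output shapes for nonzero, explicit
+ # strides for as_strided, complex outputs for polar.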
+ class AtenInterpolate(torch.nn.Module):
+
+     def forward(self, x):
+         y = torch.nn.functional.interpolate(
+             x,
+             scale_factor=2.0,
+             mode="bilinear",
+             recompute_scale_factor=False,
+         )
+         return y
+
+     _inputs = (torch.randn(2, 2, 3, 4, requires_grad=False),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class AtenNonZero(torch.nn.Module):
+
+     def forward(self, x):
+         y = torch.nonzero(x)
+         return y
+
+     _inputs = (torch.randn(3, 4, requires_grad=False),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class AtenNonZeroTuple(torch.nn.Module):
+
+     def forward(self, x):
+         y = torch.nonzero(x, as_tuple=True)
+         return y[0], y[1]
+
+     _inputs = (torch.randn(3, 4, requires_grad=False),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class AtenAsStrided(torch.nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         y = torch.as_strided(x, (2, 2, 8, 4), (128, 8, 16, 1))
+         return y
+
+     _inputs = (torch.randn((2, 2, 8, 8), requires_grad=False),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ComplexPolar(torch.nn.Module):
+     def forward(self, x, angle):
+         return torch.polar(x, angle)
+
+     _inputs = (torch.rand(4, 4), torch.rand(4, 4))
+     _dynamic = {"x": {0: DIM("batch")}, "angle": {0: DIM("batch")}}
+
+
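+ # The ControlFlowCond* cases exercise torch.cond(pred, true_fn, false_fn,
+ # operands).  Both branches must take the same operands and return outputs
+ # with the same structure (and matching dtypes/shapes) for export to succeed.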
+ class ControlFlowCond(torch.nn.Module):
+     def forward(self, x):
+         def true_fn(x):
+             return torch.sin(x)
+
+         def false_fn(x):
+             return torch.cos(x)
+
+         return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
+
+     _inputs = (torch.rand(5, 3),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowCond2Outputs(torch.nn.Module):
+     def forward(self, x):
+         def true_fn(x):
+             return torch.sin(x), torch.cos(x)
+
+         def false_fn(x):
+             return torch.cos(x), torch.sin(x)
+
+         return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
+
+     _inputs = (torch.rand(5, 3),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowCond2Inputs(torch.nn.Module):
+     def forward(self, x, y):
+         def true_fn(x, y):
+             return torch.sin(x), torch.cos(x) + y
+
+         def false_fn(x, y):
+             return torch.cos(x), torch.sin(x) + y
+
+         return torch.cond(x.sum() > 0, true_fn, false_fn, [x, y])
+
+     _inputs = torch.rand(5, 3), torch.rand(5, 3)
+     _dynamic = {"x": {0: DIM("batch")}, "y": {0: DIM("batch")}}
+
+
+ class ControlFlowNestCond(torch.nn.Module):
+     def forward(self, x):
+         def true_fn2(x):
+             def true_fn1(x):
+                 return torch.sin(x)
+
+             def false_fn1(x):
+                 return torch.cos(x)
+
+             return torch.cond(x.sum() < 0, true_fn1, false_fn1, [x])
+
+         def false_fn2(x):
+             return -x
+
+         return torch.cond(x.sum() > 0, true_fn2, false_fn2, [x])
+
+     _inputs = (torch.rand(5, 3),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowCondConstant(torch.nn.Module):
+     def forward(self, x):
+         def true_fn(x):
+             return torch.sin(x) - torch.ones(x.shape, dtype=x.dtype)
+
+         def false_fn(x):
+             return torch.cos(x) + torch.ones((1, 1024), dtype=x.dtype)
+
+         return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
+
+     _inputs = (torch.rand(1024, 1024),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowCondNestedModule(torch.nn.Module):
+
+     class Submodule(torch.nn.Module):
+         def __init__(self):
+             super().__init__()
+             # Nested weight
+             self.weight = torch.nn.Parameter(torch.tensor([100.0]))
+
+         def forward(self, x):
+             def true_fn(x):
+                 return x * self.weight
+
+             def false_fn(x):
+                 return x / self.weight
+
+             y = torch.cond(torch.abs(x).sum() > 100, true_fn, false_fn, [x])
+             return y
+
+     def __init__(self):
+         super().__init__()
+         self.submodule = ControlFlowCondNestedModule.Submodule()
+         self.weight = torch.nn.Parameter(torch.tensor([42.0]))
+
+     def forward(self, x):
+         def true_fn(x):
+             return self.submodule(x)
+
+         def false_fn(x):
+             return x - self.weight
+
+         y = torch.cond(x.sum() > 0, true_fn, false_fn, [x])
+         return y
+
+     _inputs = (torch.tensor([-1, 2]),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
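+ # In the next case both branches must return the same output structure, so
+ # the else branch uses torch.where on an all-False mask to produce empty
+ # index tensors mirroring the nonzero positions computed in the then branch.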
+ class ControlFlowCondNonZero(torch.nn.Module):
+     def forward(self, input_ids, image_features, vocab_size):
+         def then_branch(input_ids, image_features, vocab_size):
+             input_shape = input_ids.size()
+             input_ids = input_ids.view(-1, input_shape[-1])
+
+             condition = (input_ids < 0) & (input_ids > -int(1e9))
+             positions = torch.nonzero(condition, as_tuple=True)
+             input_ids = input_ids.clamp_min(0).clamp_max(vocab_size)
+             return (input_ids, positions[0], positions[1])
+
+         def else_branch(input_ids, image_features, vocab_size):
+             r = torch.where(torch.zeros((1, 1), dtype=torch.bool))
+             return (input_ids, r[0], r[1])
+
+         a, b, c = torch.cond(
+             image_features.numel() > 0,
+             then_branch,
+             else_branch,
+             [input_ids, image_features, vocab_size],
+         )
+         return a, b, c
+
+     _inputs = [
+         (
+             (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
+             torch.arange(32).reshape((2, -1)).to(torch.float32),
+             1025,
+         ),
+         (
+             (torch.arange(24) - 8).reshape((2, -1)).to(torch.int64),
+             torch.tensor([[], []], dtype=torch.float32),
+             1025,
+         ),
+     ]
+     _dynamic = (
+         {0: DIM("batch")},
+         {0: DIM("batch"), 1: DIM("seq_length")},
+         None,
+     )
+
+
+ class ControlFlowCondIdentity_153832(torch.nn.Module):
+     """`#153832 <https://github.com/pytorch/pytorch/issues/153832>`_"""
+
+     def forward(self, x, y):
+
+         def branch_cond_then_1(x):
+             x = torch.abs(x) + 1
+             return x
+
+         def branch_cond_else_1(x):
+             return x  # fails but succeeds with x.clone()
+
+         x = torch.cond(x.sum() > 0, branch_cond_then_1, branch_cond_else_1, [x])
+         return x + y
+
+     _inputs = [
+         (torch.rand((3, 4)), torch.rand((3, 4))),
+         (torch.rand((4, 5)), torch.rand((4, 5))),
+     ]
+     _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}}
+
+
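+ # The ControlFlowScan* cases use the scan higher-order operator.  As used
+ # here, torch.ops.higher_order.scan(combine_fn, init, xs, additional_inputs)
+ # applies combine_fn(*carries, *slices, *extras) along dim 0 of each tensor
+ # in xs; combine_fn returns the next carries followed by per-step outputs,
+ # which scan stacks into the final outputs.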
+ class ControlFlowScan(torch.nn.Module):
+
+     @staticmethod
+     def add(carry: torch.Tensor, y: torch.Tensor):
+         next_carry = carry + y
+         return [next_carry, next_carry]
+
+     def forward(self, x):
+         init = torch.zeros_like(x[0])
+         carry, out = torch.ops.higher_order.scan(
+             ControlFlowScan.add, [init], [x], additional_inputs=[]
+         )
+         return carry
+
+     _inputs = (torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float32),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowScan2Carried(torch.nn.Module):
+     @staticmethod
+     def add(carry1: torch.Tensor, carry2: torch.Tensor, y1: torch.Tensor, y2: torch.Tensor):
+         next_carry1 = carry1 + y1
+         next_carry2 = carry2 * y2
+         return [next_carry1, next_carry2, next_carry1, next_carry2]
+
+     def forward(self, x):
+         init1 = torch.zeros_like(x[0])
+         init2 = torch.ones_like(x[0])
+         carry1, carry2, out1, out2 = torch.ops.higher_order.scan(
+             ControlFlowScan2Carried.add,
+             [init1, init2],
+             [x, x * 2],
+             # dim=0, # 01/31/2025, not supported anymore
+             additional_inputs=[],
+         )
+         return carry1, carry2, out1, out2
+
+     _inputs = (
+         torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
+     )
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowScanCDist(torch.nn.Module):
+     @staticmethod
+     def dist(carry: torch.Tensor, x: torch.Tensor):
+         sub = carry - x.reshape((1, -1))
+         sq = sub * sub
+         rd = sq.sum(axis=1) ** 0.5
+         # clone --> UnsupportedAliasMutationException:
+         # Combine_fn might be aliasing the input!
+         return [carry.clone(), rd]
+
+     def forward(self, x):
+         carry, out = torch.ops.higher_order.scan(
+             ControlFlowScanCDist.dist,
+             [x],
+             [x],
+             # dim=0, # 01/31/2025, not supported anymore
+             additional_inputs=[],
+         )
+         return out
+
+     _inputs = (
+         torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
+     )
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowScanCDist2(torch.nn.Module):
+     @staticmethod
+     def dist(unused: torch.Tensor, x: torch.Tensor, samex: torch.Tensor):
+         sub = samex - x.reshape((1, -1))
+         sq = sub * sub
+         rd = torch.sqrt(sq.sum(axis=1))
+         # clone --> UnsupportedAliasMutationException:
+         # Combine_fn might be aliasing the input!
+         return [unused.clone(), rd]
+
+     def forward(self, x):
+         z = torch.tensor([0], dtype=torch.float32)
+         y = x.clone()
+         out = torch.ops.higher_order.scan(
+             ControlFlowScanCDist2.dist,
+             [z],
+             [x],
+             # dim=0, # 01/31/2025, not supported anymore
+             additional_inputs=[y],
+         )
+         return out[1]
+
+     _inputs = (
+         torch.tensor([[1, 2, 3, -1], [4, 5, 6, -1], [7, 8, 9, -1]], dtype=torch.float32),
+     )
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class ControlFlowScanCDistXY(torch.nn.Module):
+
+     @staticmethod
+     def dist(y: torch.Tensor, scanned_x: torch.Tensor):
+         sub = y - scanned_x.reshape((1, -1))
+         sq = sub * sub
+         rd = torch.sqrt(sq.sum(axis=1))
+         # clone --> UnsupportedAliasMutationException:
+         # Combine_fn might be aliasing the input!
+         return [y.clone(), rd]
+
+     def forward(self, x, y):
+         carry, out = torch.ops.higher_order.scan(
+             ControlFlowScanCDistXY.dist,
+             [y],
+             [x],
+             # dim=0, # 01/31/2025, not supported anymore
+             additional_inputs=[],
+         )
+         return out
+
+     _inputs = [
+         (torch.randn(3, 4), torch.randn(5, 4)),
+         (torch.randn(13, 14), torch.randn(15, 14)),
+     ]
+     _dynamic = {
+         "x": {0: DIM("x_rows"), 1: DIM("dim")},
+         "y": {0: DIM("y_rows"), 1: DIM("dim")},
+     }
+
+
+ class ControlFlowScanInplace_153705(torch.nn.Module):
+     """`#153705 <https://github.com/pytorch/pytorch/issues/153705>`_"""
+
+     def forward(self, x, y):
+         def loop_body_1(z, iv, x, y):
+             z = z.clone()
+             i = iv.item()
+             z[i, :] = ((x[i, :] - y) ** 2).sum(dim=-1)
+             return [z, iv]
+
+         z = torch.empty((x.shape[0], y.shape[0]))
+         r = torch.ops.higher_order.scan(
+             loop_body_1, [z], [torch.arange(x.shape[0], dtype=torch.int64)], [x, y]
+         )
+         return r[0]
+
+     _inputs = [
+         (torch.rand((3, 4)), torch.rand((5, 4))),
+         (torch.rand((4, 5)), torch.rand((6, 5))),
+     ]
+     _dynamic = {"x": {0: DYN, 1: DYN}, "y": {0: DYN, 1: DYN}}
+
+
+ class ControlFlowScanDecomposition_151564(torch.nn.Module):
+     """`#151564 <https://github.com/pytorch/pytorch/issues/151564>`_"""
+
+     @classmethod
+     def dummy_loop(cls, padded: torch.Tensor, pos: torch.Tensor):
+         copy = torch.zeros(padded.shape)
+         for i in range(pos.shape[0]):
+             p = pos[i]
+             copy[i, :p] = padded[i, :p]
+         return copy
+
+     @classmethod
+     def dummy_loop_with_scan(cls, padded: torch.Tensor, pos: torch.Tensor):
+         def pad_row(padded, p):
+             row = torch.zeros((padded.shape[0],))
+             torch._check(p.item() > 0)
+             torch._check(p.item() < padded.shape[0])
+             # This check does not always hold, but it forces the dimension to be
+             # >= 2 and avoids an exception about a dynamic dimension in {0, 1}.
+             if torch.compiler.is_exporting():
+                 torch._check(p.item() > 1)
+             row[: p.item()] = padded[: p.item()]
+             return (row,)
+
+         return torch.ops.higher_order.scan(
+             pad_row,
+             [],
+             [padded, pos],
+             [],
+         )
+
+     @classmethod
+     def select_when_exporting(cls, f, f_scan):
+         return f_scan if torch.compiler.is_exporting() else f
+
+     def forward(self, images, position):
+         return self.select_when_exporting(self.dummy_loop, self.dummy_loop_with_scan)(
+             images, position
+         )
+
+     _inputs = [(torch.randn((5, 6)), torch.arange(5, dtype=torch.int64) + 1)]
+     _dynamic = {"images": {0: DYN, 1: DYN}, "position": {0: DYN}}
+
+
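+ # The Signature* cases vary the forward signature itself: scalar int/float
+ # arguments, nested lists of tensors, and optional entries.  Python scalars
+ # are typically specialized (constant-folded) at export time, which is why
+ # their `_dynamic` entry is None.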
+ class SignatureInt1(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, i: int = 2):
+         return torch.sigmoid(self.linear(x)) - self.buff + x[:, i : i + 1]
+
+     _inputs = [
+         ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1),
+         ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2),
+     ]
+     _dynamic = ({0: DIM("batch", min=1, max=1024)}, None)
+
+
+ class SignatureFloat1(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, alpha: float = 2.0):
+         return torch.sigmoid(self.linear(x)) - self.buff * alpha
+
+     _inputs = [
+         ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1.5),
+         ((torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32), 2.5),
+     ]
+     _dynamic = ({0: DIM("batch", min=1, max=1024)}, None)
+
+
+ class SignatureInt2(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, i: int = 2):
+         return torch.sigmoid(self.linear(x)) - self.buff + x[:, i]
+
+     _inputs = ((torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32), 1)
+     _dynamic = {
+         "x": {0: DIM("batch")},
+         "i": None,  # DIM("ii", min=0, max=3)}
+     }
+
+
+ class SignatureListFixedLength(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, lx: list):
+         return (
+             torch.sigmoid(self.linear(x)) - self.buff + lx[0] * lx[1].sum(axis=1, keepdim=True)
+         )
+
+     _inputs = [
+         (
+             (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+             ],
+         ),
+         (
+             (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+             ],
+         ),
+     ]
+     _dynamic = {
+         "x": {0: DIM("batch")},
+         "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
+     }
+
+
+ class SignatureListVariableLength(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, lx: list):
+         t = torch.cat(lx, dim=1).sum(axis=1, keepdim=True)
+         return torch.sigmoid(self.linear(x)) - self.buff + t
+
+     _inputs = [
+         (
+             (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+             ],
+         ),
+         (
+             (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+                 (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             ],
+         ),
+     ]
+     _dynamic = {
+         "x": {0: DIM("batch")},
+         "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
+     }
+
+
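+ # In the next two cases, built-ins such as len() and isinstance() are
+ # evaluated while tracing, so the exported graph is specialized to the
+ # branch taken for the sample inputs.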
+ class BuildInLen(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, lx: list):
+         t = lx[0] * lx[1].sum(axis=1, keepdim=True)
+         if len(lx) > 2:
+             t = t + lx[2].sum(axis=1, keepdim=True)
+         return torch.sigmoid(self.linear(x)) - self.buff + t
+
+     _inputs = [
+         (
+             (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+             ],
+         ),
+         (
+             (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+                 (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             ],
+         ),
+     ]
+     _dynamic = {
+         "x": {0: DIM("batch")},
+         "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
+     }
+
+
+ class BuildInIsInstance(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, lx: list | torch.Tensor):
+         if isinstance(lx, list):
+             t = lx[0] * lx[1].sum(axis=1, keepdim=True)
+             return torch.sigmoid(self.linear(x)) - self.buff + t
+         return torch.sigmoid(self.linear(x)) - self.buff + lx
+
+     _inputs = [
+         (
+             (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(4) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+             ],
+         ),
+         (
+             (torch.arange(8 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+             [
+                 (torch.arange(8) + 10).reshape((-1, 1)).to(torch.float32),
+                 (torch.arange(8 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+             ],
+         ),
+     ]
+     _dynamic = {
+         "x": {0: DIM("batch")},
+         "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
+     }
+
+
+ class SignatureShapeAsIndex(torch.nn.Module):
+     def __init__(self, n_dims: int = 3, n_targets: int = 1):
+         super().__init__()
+         self.linear = torch.nn.Linear(n_dims, n_targets)
+         self.buff = torch.nn.parameter.Buffer(torch.tensor([0.5] * n_targets))
+
+     def forward(self, x, y):
+         t = torch.sigmoid(self.linear(x)) + x
+         return t[:, : y.shape[1]]
+
+     _inputs = (
+         (torch.arange(4 * 3) + 10).reshape((-1, 3)).to(torch.float32),
+         (torch.arange(4 * 2) + 10).reshape((-1, 2)).to(torch.float32),
+     )
+     _dynamic = {
+         "x": {0: DIM("batch", min=0, max=1024)},
+         "y": {
+             0: DIM("batch", min=0, max=1024),
+             1: DIM("length", min=0, max=2),
+         },
+     }
+
+
+ class TypeBFloat16(torch.nn.Module):
+
+     def forward(self, x):
+         xb = x.to(torch.bfloat16)
+         return (xb + xb).to(torch.float32)
+
+     _inputs = (torch.rand(4, 4).to(torch.float32),)
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
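+ # The two Crop* cases contrast slicing by a shape (y.shape[0], a symbolic
+ # size known to the exporter) with slicing by a tensor's content (shape[0]
+ # read from data, which makes the slice length data-dependent).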
+ class CropLastDimensionWithTensorShape(torch.nn.Module):
+
+     def forward(self, x, y):
+         return x[..., : y.shape[0]]
+
+     _inputs = [
+         (torch.rand(3, 4, 4).to(torch.float32), torch.rand(2).to(torch.float32)),
+         (torch.rand(6, 4, 4).to(torch.float32), torch.rand(3).to(torch.float32)),
+     ]
+     _dynamic = {
+         "x": {0: DIM("batch")},
+         "y": {0: DIM("crop", min=1, max=3)},
+     }
+
+
+ class CropLastDimensionWithTensorContent(torch.nn.Module):
+
+     def forward(self, x, shape):
+         return x[..., : shape[0]]
+
+     _inputs = [
+         (torch.rand(3, 4, 4).to(torch.float32), torch.tensor([2], dtype=torch.int64)),
+         (torch.rand(6, 4, 4).to(torch.float32), torch.tensor([3], dtype=torch.int64)),
+     ]
+     _dynamic = {"x": {0: DIM("batch")}}
+
+
+ class SignatureListFixedWithNone(torch.nn.Module):
+
+     def forward(self, lx):
+         x = lx[0]
+         if lx[1] is not None:
+             x += lx[1]
+         if lx[2] is not None:
+             x += lx[2]
+         return x
+
+     _inputs = [
+         ([torch.rand((4, 4)), torch.rand((4, 4)), None],),
+         ([torch.rand((4, 4)), torch.rand((4, 4)), torch.rand((4, 4))],),
+     ]
+     _dynamic = {
+         "lx": [{0: DIM("batch")}, {0: DIM("batch")}],
+     }
+
+
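+ # The CreateFromShape* cases build a new tensor from symbolic input sizes,
+ # including a size that passes through a plain Python helper function.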
+ class CreateFromShape(torch.nn.Module):
+     def forward(self, x):
+         y = torch.ones((x.shape[0], x.shape[1] + 1))
+         return y
+
+     _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
+     _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
+
+
+ class CreateFromShapeThroughFunction(torch.nn.Module):
+     @staticmethod
+     def add_one(dim):
+         return dim + 1
+
+     def forward(self, x):
+         dy1 = CreateFromShapeThroughFunction.add_one(x.shape[1])
+         y = torch.ones((x.shape[0], dy1))
+         return y
+
+     _inputs = [(torch.rand((4, 4)),), (torch.rand((5, 5)),)]
+     _dynamic = {"x": {0: DIM("dx"), 1: DIM("dy")}}
+
+
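+ # The last two cases run the same lambda through torch.vmap and through
+ # patched_vmap, which (judging from its import from ..patches.patch_torch)
+ # is a Python-level replacement intended to behave better under export.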
+ class Vmap(torch.nn.Module):
+     def forward(self, x, y):
+         f = lambda x, y: x * y + 1  # noqa: E731
+         return torch.vmap(f)(x, y)
+
+     _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
+     _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
+
+
+ class VmapPython(torch.nn.Module):
+     def forward(self, x, y):
+         f = lambda x, y: x * y + 1  # noqa: E731
+         return patched_vmap(f)(x, y)
+
+     _inputs = [(torch.tensor([1.0, 2.0, 3.0]), torch.tensor([0.1, 0.2, 0.3]))]
+     _dynamic = {"x": {0: DYN}, "y": {0: DYN}}
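+
+ # A minimal driver sketch, not part of the package API; several cases are
+ # deliberately expected to fail at export time:
+ #
+ #     for cls in (AtenRollRelu, ControlFlowCond, CreateFromShape):
+ #         samples = cls._inputs if isinstance(cls._inputs, list) else [cls._inputs]
+ #         for args in samples:
+ #             torch.export.export(cls(), args, dynamic_shapes=cls._dynamic)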