torchax 0.0.10.dev20251117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,790 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """This file contains some decompositions that are not available in torch stable.
+
+ Most are copied from
+ https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
+ at main-branch HEAD, because they are useful here.
+
+ This file can also contain decompositions of a torch op in terms of other
+ torch ops.
+ """
+
+ import functools
+ from typing import Any, Callable, List, Tuple
+
+ import torch
+ from torch import Tensor
+ import torch._decomp as decomp
+ from torch._decomp import decompositions_for_rng
+ from torch._decomp import register_decomposition
+ import torch._prims_common as utils
+ from torch._prims_common.wrappers import out_wrapper
+
+ DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+
+ # None of these functions are publicly accessible; get at them
+ # from torch._decomp
+ __all__: List[str] = []
+
+ aten = torch._ops.ops.aten
+
+
+ def _try_register(op, impl):
+   # Registration raises if this torch version already ships a decomposition
+   # for `op`; in that case, keep the existing one.
+   try:
+     register_decomposition(op)(impl)
+   except Exception:
+     pass
+
+
+ @out_wrapper()
+ def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
+
+   def idx(left, middle, right):
+     dim_idx = torch.arange(-left, middle + right, device=a.device)
+     return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
+
+   return _reflection_or_replication_pad(
+       a,
+       padding,
+       idx,
+   )
+
+
+ _try_register(aten.reflection_pad1d, _reflection_pad)
+ _try_register(aten.reflection_pad2d, _reflection_pad)
+ _try_register(aten.reflection_pad3d, _reflection_pad)
+
+
+ @out_wrapper()
+ def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
+
+   def idx(left, middle, right):
+     dim_idx = torch.arange(-left, middle + right, device=a.device)
+     return torch.clamp(dim_idx, 0, middle - 1)
+
+   return _reflection_or_replication_pad(
+       a,
+       padding,
+       idx,
+   )
+
+
+ decomp.global_decomposition_table["post_autograd"][
+     aten.replication_pad2d.default] = _replication_pad
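+ # (Written into the table directly: core torch already registers a
+ # decomposition for replication_pad2d, so register_decomposition would raise
+ # a duplicate-registration error and _try_register would silently keep the
+ # old entry; direct assignment forces the override.)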
+
+
+ def _reflection_or_replication_pad(
+     a: Tensor,
+     padding: Tuple[int, ...],
+     idx_fn: Callable[[int, int, int], Tensor],
+ ) -> Tensor:
+   dim = len(padding) // 2
+   torch._check(
+       a.dim() in (dim + 1, dim + 2),
+       lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
+   )
+   inp_shape = a.shape[-dim:]
+   nc_dim = a.dim() - dim
+
+   padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
+   padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
+
+   result = a
+   for i in range(dim):
+     idx: List[Any] = [None] * result.dim()
+     idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
+     result = aten._unsafe_index(result, idx)
+
+   # convert output to correct memory format, if necessary
+   memory_format = utils.suggest_memory_format(result)
+   result = result.contiguous(memory_format=memory_format)
+   return result
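+
+ # Worked example of the two `idx` helpers above, for one dimension of size
+ # middle=4 with padding left=2, right=2 (dim_idx = arange(-2, 6)):
+ #   reflection:  [2, 1, 0, 1, 2, 3, 2, 1]   (mirrors around the edges)
+ #   replication: [0, 0, 0, 1, 2, 3, 3, 3]   (clamps to the edge values)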
+
+
+ _try_register(aten.replication_pad1d, _replication_pad)
+ _try_register(aten.replication_pad3d, _replication_pad)
+
+
+ def bernoulli(self, *, generator=None):
+   # Treats each element of `self` as a probability p and draws Bernoulli(p)
+   # samples by comparing uniform noise against it.
+   return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
+
+
+ _try_register(aten.bernoulli.default, bernoulli)
+
+
+ def rand_like(self, **kwargs):
+   # Note: only the `dtype` kwarg is honored; others (e.g. `device`) are
+   # ignored by this decomposition.
+   dtype = kwargs.get("dtype", self.dtype)
+   return torch.rand(self.shape, dtype=dtype)
+
+
+ def channel_shuffle(self, groups):
+   batchsize, channels, height, width = self.shape
+   channels_per_group = channels // groups
+   self = self.reshape(batchsize, groups, channels_per_group, height, width)
+   self = self.transpose(1, 2)
+   self = self.reshape(batchsize, channels, height, width)
+   return self
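+
+ # Shape walk-through: input (1, 6, H, W) with groups=3 reshapes to
+ # (1, 3, 2, H, W), transposes to (1, 2, 3, H, W), and reshapes back to
+ # (1, 6, H, W); channel order [0, 1, 2, 3, 4, 5] becomes [0, 2, 4, 1, 3, 5].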
+
+
+ _try_register(aten.channel_shuffle, channel_shuffle)
+
+ _try_register(aten.bernoulli, bernoulli)
+ _try_register(aten.rand_like, rand_like)
+
+
+ def bernoulli_float(self, p=0.5):
+   return self.bernoulli_(p)
+
+
+ _try_register(aten.bernoulli_.float, bernoulli_float)
+ _try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_)
+
+
+ def _sum_tensors(ts) -> Tensor:
+   return functools.reduce(torch.add, ts)
+
+
+ @register_decomposition(aten.grid_sampler_3d)
+ def _grid_sampler_3d(
+     a: torch.Tensor,
+     grid: torch.Tensor,
+     interpolation_mode: int = 0,
+     padding_mode: int = 0,
+     align_corners: bool = False,
+ ) -> Tensor:
+   """Reference: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
+
+   The referenced code implements the 2D case; this function extends it to 3D.
+   """
+   # The 2D reference can optionally expand the grid across channels; that
+   # path is disabled here.
+   _expand_grid = False
+   torch._check(
+       interpolation_mode in (0, 1),
+       lambda: f"Invalid interpolation mode {interpolation_mode}",
+   )
+   torch._check(
+       padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}")
+
+   # a is 5D: [B, C, D, H, W]
+
+   def unnormalize(coords: Tensor, size: int) -> Tensor:
+     # Rescale coordinates from [-1, 1] to:
+     #   [0, size - 1] if align_corners is True
+     #   [-0.5, size - 0.5] if align_corners is False
+     mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
+     ofs = size * 0.5 - 0.5
+     return coords * mul + ofs
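+
+   # E.g. size=4: align_corners=True maps [-1, 1] -> [0, 3] (mul=1.5, ofs=1.5);
+   # align_corners=False maps [-1, 1] -> [-0.5, 3.5] (mul=2.0, ofs=1.5).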
+
+   # Reflects coordinates until they fall between low and high (inclusive).
+   # The bounds are passed as twice their value so that half-integer values
+   # can be represented as ints.
+   def reflect_coordinates(coords: Tensor, twice_low: int,
+                           twice_high: int) -> Tensor:
+     if twice_low == twice_high:
+       return torch.zeros_like(coords)
+     coords_min = twice_low / 2
+     coords_span = (twice_high - twice_low) / 2
+     coords2 = (coords - coords_min).abs()
+     extra = torch.fmod(coords2, coords_span)
+     flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
+     return torch.where(flips & 1 == 0, extra + coords_min,
+                        coords_span + coords_min - extra)
+
+   def compute_coordinates(coords: Tensor, size: int) -> Tensor:
+     if padding_mode == 0:  # Zero
+       return coords
+     elif padding_mode == 1:  # Borders
+       return torch.clamp(coords, 0, size - 1)
+     else:  # padding_mode == 2, Reflection
+       if align_corners:
+         coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
+       else:
+         coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
+       return torch.clamp(coords_reflected, 0, size - 1)
+
+   def compute_source_index(coords: Tensor, size: int) -> Tensor:
+     coords_un = unnormalize(coords, size)
+     return compute_coordinates(coords_un, size)
+
+   N, C, iD, iH, iW = a.shape
+   _, oD, oH, oW, three = grid.shape
+   assert three == 3, f"Last dim of grid must be 3, got {three}"
+
+   def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
+     xcheck = torch.logical_and(0 <= xs, xs < iW)
+     ycheck = torch.logical_and(0 <= ys, ys < iH)
+     zcheck = torch.logical_and(0 <= zs, zs < iD)
+     return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck))
+
+   N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
+   C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)
+
+   def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
+     cond = in_bounds_cond(xs, ys, zs)
+     # To clip to inside valid coordinates, we map the coordinates
+     # to (x, y) = (0, 0) and also set the weight to 0
+     # We also change the shape of the tensor to the appropriate one for
+     # broadcasting with N_idx, C_idx for the purposes of advanced indexing
+     c = C if _expand_grid else 1
+     return tuple(
+         torch.where(cond, t, 0).view(N, c, oD, oH, oW) for t in (
+             xs.to(dtype=torch.int64),
+             ys.to(dtype=torch.int64),
+             zs.to(dtype=torch.int64),
+             ws,
+         ))
+
+   def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
+                   w) -> Tensor:
+     # Perform clipping, index into input tensor and multiply by weight
+     idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
+     return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_
+
+   x = grid[..., 0]
+   y = grid[..., 1]
+   d = grid[..., 2]
+
+   if interpolation_mode == 0:  # Bilinear
+     ix = compute_source_index(x, iW)
+     iy = compute_source_index(y, iH)
+     id_ = compute_source_index(d, iD)
+
+     ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
+     ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
+     ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
+     ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
+     ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
+     ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
+     ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
+     ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1
+
+     w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
+     w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
+     w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
+     w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
+     w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
+     w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
+     w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
+     w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)
+
+     return _sum_tensors(
+         get_summand(ix, iy, id_, w) for (ix, iy, id_, w) in (
+             (ix_nwf, iy_nwf, id_nwf, w_nwf),
+             (ix_nef, iy_nef, id_nef, w_nef),
+             (ix_swf, iy_swf, id_swf, w_swf),
+             (ix_sef, iy_sef, id_sef, w_sef),
+             (ix_nwb, iy_nwb, id_nwb, w_nwb),
+             (ix_neb, iy_neb, id_neb, w_neb),
+             (ix_swb, iy_swb, id_swb, w_swb),
+             (ix_seb, iy_seb, id_seb, w_seb),
+         ))
+   else:  # interpolation_mode == 1, Nearest
+     ix = compute_source_index(x, iW)
+     iy = compute_source_index(y, iH)
+     iz = compute_source_index(d, iD)
+
+     ix_nearest = ix.round()
+     iy_nearest = iy.round()
+     iz_nearest = iz.round()
+
+     return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
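+
+ # Illustrative self-check (assumes CPU eager `grid_sample` as the reference):
+ #   a = torch.randn(1, 2, 4, 5, 6)
+ #   g = torch.rand(1, 3, 3, 3, 3) * 2 - 1
+ #   torch.testing.assert_close(
+ #       _grid_sampler_3d(a, g, 0, 0, True),
+ #       torch.nn.functional.grid_sample(
+ #           a, g, mode="bilinear", padding_mode="zeros", align_corners=True))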
+
+
+ DECOMPOSITIONS = decomp.get_decompositions([
+     torch.ops.aten.upsample_bicubic2d,
+     torch.ops.aten.upsample_nearest1d,
+     torch.ops.aten.upsample_nearest2d,
+     torch.ops.aten.upsample_nearest3d,
+     torch.ops.aten._upsample_nearest_exact1d,
+     torch.ops.aten._upsample_nearest_exact2d,
+     torch.ops.aten._upsample_nearest_exact3d,
+     torch.ops.aten._native_batch_norm_legit.no_stats,
+     torch.ops.aten._native_batch_norm_legit_functional.default,
+     torch.ops.aten._adaptive_avg_pool2d,
+     torch.ops.aten._adaptive_avg_pool3d,
+     torch.ops.aten.grid_sampler_2d,
+     torch.ops.aten.grid_sampler_3d,
+     torch.ops.aten.native_dropout,
+     torch.ops.aten.reflection_pad1d,
+     torch.ops.aten.reflection_pad2d,
+     torch.ops.aten.reflection_pad3d,
+     torch.ops.aten.replication_pad1d,
+     torch.ops.aten.replication_pad2d,
+     torch.ops.aten.replication_pad3d,
+     torch.ops.aten.bernoulli,
+     torch.ops.aten.rand_like,
+     torch.ops.aten._batch_norm_with_update,
+     torch.ops.aten.channel_shuffle,
+     torch.ops.aten.nll_loss2d_forward,
+     torch.ops.aten.nll_loss2d_backward,
+     torch.ops.aten.bernoulli_.Tensor,
+     torch.ops.aten.bernoulli_.float,
+     torch.ops.aten.log_normal,
+     torch.ops.aten.addcdiv.default,
+     torch.ops.aten.addcdiv.out,
+     torch.ops.aten.addcdiv_.default,
+     torch.ops.aten.addcmul.default,
+     torch.ops.aten.addcmul.out,
+     torch.ops.aten.addcmul_.default,
+     torch.ops.aten.addr.default,
+     torch.ops.aten.addr.out,
+     torch.ops.aten.affine_grid_generator.default,
+     torch.ops.aten.affine_grid_generator.out,
+     torch.ops.aten.alias_copy.default,
+     torch.ops.aten.alias_copy.out,
+     torch.ops.aten.all.default,
+     torch.ops.aten.all.dim,
+     torch.ops.aten.all.dims,
+     torch.ops.aten.all.out,
+     torch.ops.aten.all.dims_out,
+     torch.ops.aten.all.all_out,
+     torch.ops.aten.all.dimname,
+     torch.ops.aten.all.dimname_out,
+     torch.ops.aten.aminmax.default,
+     torch.ops.aten.aminmax.out,
+     torch.ops.aten.arange.default,
+     torch.ops.aten.arange.start,
+     torch.ops.aten.baddbmm.default,
+     torch.ops.aten.baddbmm.out,
+     torch.ops.aten.binary_cross_entropy.default,
+     torch.ops.aten.binary_cross_entropy.out,
+     torch.ops.aten.binary_cross_entropy_backward.default,
+     torch.ops.aten.binary_cross_entropy_backward.grad_input,
+     torch.ops.aten.binary_cross_entropy_with_logits.default,
+     torch.ops.aten.binary_cross_entropy_with_logits.out,
+     torch.ops.aten.block_diag.default,
+     torch.ops.aten.block_diag.out,
+     torch.ops.aten.celu.default,
+     torch.ops.aten.celu.out,
+     torch.ops.aten.celu_.default,
+     torch.ops.aten.channel_shuffle.default,
+     torch.ops.aten.channel_shuffle.out,
+     torch.ops.aten.clamp_max.default,
+     torch.ops.aten.clamp_max.Tensor,
+     torch.ops.aten.clamp_max.out,
+     torch.ops.aten.clamp_max.Tensor_out,
+     torch.ops.aten.clamp_min.default,
+     torch.ops.aten.clamp_min.Tensor,
+     torch.ops.aten.clamp_min.out,
+     torch.ops.aten.clamp_min.Tensor_out,
+     torch.ops.aten.col2im.default,
+     torch.ops.aten.col2im.out,
+     torch.ops.aten.count_nonzero.dim_IntList,
+     torch.ops.aten.count_nonzero.dim_IntList_out,
+     torch.ops.aten.count_nonzero.default,
+     torch.ops.aten.count_nonzero.out,
+     torch.ops.aten.linalg_cross.default,
+     torch.ops.aten.linalg_cross.out,
+     torch.ops.aten.cudnn_batch_norm.default,
+     torch.ops.aten.cudnn_batch_norm.out,
+     torch.ops.aten.cudnn_batch_norm_backward.default,
+     torch.ops.aten.cudnn_batch_norm_backward.out,
+     torch.ops.aten.miopen_batch_norm_backward.default,
+     torch.ops.aten.miopen_batch_norm_backward.out,
+     torch.ops.aten.deg2rad.default,
+     torch.ops.aten.deg2rad.out,
+     torch.ops.aten.deg2rad_.default,
+     torch.ops.aten.detach.default,
+     torch.ops.aten.diag_embed.default,
+     torch.ops.aten.diag_embed.out,
+     torch.ops.aten.diagonal_backward.default,
+     torch.ops.aten.diagonal_backward.out,
+     torch.ops.aten.dot.default,
+     torch.ops.aten.dot.out,
+     torch.ops.aten.vdot.default,
+     torch.ops.aten.vdot.out,
+     torch.ops.aten.elu.default,
+     torch.ops.aten.elu.out,
+     torch.ops.aten.elu_.default,
+     torch.ops.aten.elu_backward.default,
+     torch.ops.aten.elu_backward.grad_input,
+     torch.ops.aten.embedding_dense_backward.default,
+     torch.ops.aten.embedding_dense_backward.out,
+     torch.ops.aten.empty_like.default,
+     torch.ops.aten.empty_like.out,
+     torch.ops.aten._euclidean_dist.default,
+     torch.ops.aten.expand_copy.default,
+     torch.ops.aten.expand_copy.out,
+     torch.ops.aten.eye.default,
+     torch.ops.aten.eye.m,
+     torch.ops.aten.eye.out,
+     torch.ops.aten.eye.m_out,
+     torch.ops.aten.fill.Scalar,
+     torch.ops.aten.fill.Tensor,
+     torch.ops.aten.fill_.Scalar,
+     torch.ops.aten.fill_.Tensor,
+     torch.ops.aten.floor_divide.default,
+     torch.ops.aten.floor_divide.Scalar,
+     torch.ops.aten.floor_divide.out,
+     torch.ops.aten.floor_divide.Scalar_out,
+     torch.ops.aten.frac.default,
+     torch.ops.aten.frac.out,
+     torch.ops.aten.frac_.default,
+     torch.ops.aten.gelu_.default,
+     torch.ops.aten.gelu_backward.default,
+     torch.ops.aten.gelu_backward.grad_input,
+     torch.ops.aten.glu.default,
+     torch.ops.aten.glu.out,
+     torch.ops.aten.glu_backward.default,
+     torch.ops.aten.glu_backward.grad_input,
+     torch.ops.aten.hardshrink.default,
+     torch.ops.aten.hardshrink.out,
+     torch.ops.aten.hardsigmoid.default,
+     torch.ops.aten.hardsigmoid.out,
+     torch.ops.aten.hardsigmoid_.default,
+     torch.ops.aten.hardsigmoid_backward.default,
+     torch.ops.aten.hardsigmoid_backward.grad_input,
+     torch.ops.aten.hardswish.default,
+     torch.ops.aten.hardswish.out,
+     torch.ops.aten.hardswish_.default,
+     torch.ops.aten.hardswish_backward.default,
+     torch.ops.aten.hardswish_backward.out,
+     torch.ops.aten.hardtanh_.default,
+     torch.ops.aten.hardtanh_backward.default,
+     torch.ops.aten.hardtanh_backward.grad_input,
+     torch.ops.aten.heaviside.default,
+     torch.ops.aten.heaviside.out,
+     torch.ops.aten.heaviside_.default,
+     torch.ops.aten.huber_loss.default,
+     torch.ops.aten.huber_loss.out,
+     torch.ops.aten.huber_loss_backward.default,
+     torch.ops.aten.huber_loss_backward.out,
+     torch.ops.aten.im2col.default,
+     torch.ops.aten.im2col.out,
+     torch.ops.aten.index_add.default,
+     torch.ops.aten.index_add.out,
+     torch.ops.aten.index_add.dimname,
+     torch.ops.aten.index_add_.default,
+     torch.ops.aten.index_copy.default,
+     torch.ops.aten.index_copy.dimname,
+     torch.ops.aten.index_copy.out,
+     torch.ops.aten.index_copy_.default,
+     torch.ops.aten.index_copy_.dimname,
+     torch.ops.aten.index_fill.int_Tensor,
+     torch.ops.aten.index_fill.int_Scalar,
+     torch.ops.aten.index_fill.Dimname_Scalar,
+     torch.ops.aten.index_fill.Dimname_Tensor,
+     torch.ops.aten.index_fill.int_Scalar_out,
+     torch.ops.aten.index_fill.int_Tensor_out,
+     torch.ops.aten.index_fill_.int_Tensor,
+     torch.ops.aten.index_fill_.int_Scalar,
+     torch.ops.aten.index_fill_.Dimname_Scalar,
+     torch.ops.aten.index_fill_.Dimname_Tensor,
+     torch.ops.aten.isin.Tensor_Tensor,
+     torch.ops.aten.isin.Tensor_Tensor_out,
+     torch.ops.aten.isin.Tensor_Scalar,
+     torch.ops.aten.isin.Tensor_Scalar_out,
+     torch.ops.aten.isin.Scalar_Tensor,
+     torch.ops.aten.isin.Scalar_Tensor_out,
+     torch.ops.aten.isneginf.default,
+     torch.ops.aten.isneginf.out,
+     torch.ops.aten.isposinf.default,
+     torch.ops.aten.isposinf.out,
+     torch.ops.aten.leaky_relu_.default,
+     torch.ops.aten.leaky_relu_backward.default,
+     torch.ops.aten.leaky_relu_backward.grad_input,
+     torch.ops.aten.lerp.Scalar,
+     torch.ops.aten.lerp.Tensor,
+     torch.ops.aten.lerp.Scalar_out,
+     torch.ops.aten.lerp.Tensor_out,
+     torch.ops.aten.lerp_.Scalar,
+     torch.ops.aten.lerp_.Tensor,
+     torch.ops.aten.linspace.Tensor_Tensor,
+     torch.ops.aten.linspace.Tensor_Scalar,
+     torch.ops.aten.linspace.Scalar_Tensor,
+     torch.ops.aten.linspace.default,
+     torch.ops.aten.linspace.out,
+     torch.ops.aten.linspace.Tensor_Tensor_out,
+     torch.ops.aten.linspace.Tensor_Scalar_out,
+     torch.ops.aten.linspace.Scalar_Tensor_out,
+     torch.ops.aten.logaddexp.default,
+     torch.ops.aten.logaddexp.out,
+     torch.ops.aten.logaddexp2.default,
+     torch.ops.aten.logaddexp2.out,
+     torch.ops.aten.logit.default,
+     torch.ops.aten.logit.out,
+     torch.ops.aten.logit_.default,
+     torch.ops.aten.logit_backward.default,
+     torch.ops.aten.log_sigmoid_backward.default,
+     torch.ops.aten.log_sigmoid_backward.grad_input,
+     torch.ops.aten.log_sigmoid_forward.default,
+     torch.ops.aten.log_sigmoid_forward.output,
+     torch.ops.aten._log_softmax_backward_data.default,
+     torch.ops.aten._log_softmax_backward_data.out,
+     torch.ops.aten.logspace.Tensor_Tensor,
+     torch.ops.aten.logspace.Tensor_Scalar,
+     torch.ops.aten.logspace.Scalar_Tensor,
+     torch.ops.aten.logspace.default,
+     torch.ops.aten.logspace.out,
+     torch.ops.aten.logspace.Tensor_Tensor_out,
+     torch.ops.aten.logspace.Tensor_Scalar_out,
+     torch.ops.aten.logspace.Scalar_Tensor_out,
+     torch.ops.aten.logsumexp.default,
+     torch.ops.aten.masked_fill.Scalar,
+     torch.ops.aten.masked_fill.Tensor,
+     torch.ops.aten.masked_fill.Scalar_out,
+     torch.ops.aten.masked_fill.Tensor_out,
+     torch.ops.aten.masked_fill_.Scalar,
+     torch.ops.aten.masked_fill_.Tensor,
+     torch.ops.aten.mish.default,
+     torch.ops.aten.mish.out,
+     torch.ops.aten.mish_.default,
+     torch.ops.aten.mse_loss.default,
+     torch.ops.aten.mse_loss.out,
+     torch.ops.aten.mse_loss_backward.default,
+     torch.ops.aten.mse_loss_backward.grad_input,
+     torch.ops.aten.multi_margin_loss.default,
+     torch.ops.aten.multi_margin_loss.out,
+     torch.ops.aten.multilabel_margin_loss_forward.default,
+     torch.ops.aten.multilabel_margin_loss_forward.output,
+     torch.ops.aten.mv.default,
+     torch.ops.aten.mv.out,
+     torch.ops.aten.mvlgamma.default,
+     torch.ops.aten.mvlgamma.out,
+     torch.ops.aten.mvlgamma_.default,
+     torch.ops.aten.nansum.default,
+     torch.ops.aten.nansum.out,
+     torch.ops.aten.nan_to_num.default,
+     torch.ops.aten.nan_to_num.out,
+     torch.ops.aten.nan_to_num_.default,
+     torch.ops.aten.native_batch_norm_backward.default,
+     torch.ops.aten.native_batch_norm_backward.out,
+     torch.ops.aten.native_dropout_backward.default,
+     torch.ops.aten.native_dropout_backward.out,
+     torch.ops.aten.native_group_norm_backward.default,
+     torch.ops.aten.native_group_norm_backward.out,
+     torch.ops.aten.native_layer_norm_backward.default,
+     torch.ops.aten.native_layer_norm_backward.out,
+     torch.ops.aten.new_empty.default,
+     torch.ops.aten.new_empty.out,
+     torch.ops.aten.new_full.default,
+     torch.ops.aten.new_full.out,
+     torch.ops.aten.new_ones.default,
+     torch.ops.aten.new_ones.out,
+     torch.ops.aten.new_zeros.default,
+     torch.ops.aten.new_zeros.out,
+     torch.ops.aten.nll_loss2d_forward.default,
+     torch.ops.aten.nll_loss2d_forward.output,
+     torch.ops.aten.nll_loss2d_backward.default,
+     torch.ops.aten.nll_loss2d_backward.grad_input,
+     torch.ops.aten.nll_loss_backward.default,
+     torch.ops.aten.nll_loss_backward.grad_input,
+     torch.ops.aten.nll_loss_forward.default,
+     torch.ops.aten.nll_loss_forward.output,
+     torch.ops.aten.norm.Scalar,
+     torch.ops.aten.norm.ScalarOpt_dim,
+     torch.ops.aten.norm.names_ScalarOpt_dim,
+     torch.ops.aten.norm.ScalarOpt_dim_dtype,
+     torch.ops.aten.norm.dtype_out,
+     torch.ops.aten.norm.out,
+     torch.ops.aten.norm.ScalarOpt_dtype,
+     torch.ops.aten.norm.ScalarOpt_dtype_out,
+     torch.ops.aten.norm.Scalar_out,
+     torch.ops.aten.norm.names_ScalarOpt_dim_dtype,
+     torch.ops.aten.norm.names_dtype_out,
+     torch.ops.aten.norm.names_out,
+     torch.ops.aten.ones.default,
+     torch.ops.aten.ones_like.default,
+     torch.ops.aten.ones_like.out,
+     torch.ops.aten.pixel_shuffle.default,
+     torch.ops.aten.pixel_shuffle.out,
+     torch.ops.aten.pixel_unshuffle.default,
+     torch.ops.aten.pixel_unshuffle.out,
+     torch.ops.aten._prelu_kernel.default,
+     torch.ops.aten._prelu_kernel_backward.default,
+     torch.ops.aten._reshape_alias.default,
+     torch.ops.aten.rad2deg.default,
+     torch.ops.aten.rad2deg.out,
+     torch.ops.aten.rad2deg_.default,
+     torch.ops.aten.reflection_pad1d.default,
+     torch.ops.aten.reflection_pad1d.out,
+     torch.ops.aten.reflection_pad1d_backward.default,
+     torch.ops.aten.reflection_pad1d_backward.grad_input,
+     torch.ops.aten.reflection_pad2d.default,
+     torch.ops.aten.reflection_pad2d.out,
+     torch.ops.aten.reflection_pad2d_backward.default,
+     torch.ops.aten.reflection_pad2d_backward.grad_input,
+     torch.ops.aten.reflection_pad3d.default,
+     torch.ops.aten.reflection_pad3d.out,
+     torch.ops.aten.reflection_pad3d_backward.default,
+     torch.ops.aten.reflection_pad3d_backward.grad_input,
+     torch.ops.aten.replication_pad1d.default,
+     torch.ops.aten.replication_pad1d.out,
+     torch.ops.aten.replication_pad2d.default,
+     torch.ops.aten.replication_pad2d.out,
+     torch.ops.aten.replication_pad3d.default,
+     torch.ops.aten.replication_pad3d.out,
+     torch.ops.aten.renorm.default,
+     torch.ops.aten.renorm.out,
+     torch.ops.aten.renorm_.default,
+     torch.ops.aten.resize_as.default,
+     torch.ops.aten.resize_as.out,
+     torch.ops.aten.roll.default,
+     torch.ops.aten.roll.out,
+     torch.ops.aten.rot90.default,
+     torch.ops.aten.rot90.out,
+     torch.ops.aten.rrelu_with_noise.default,
+     torch.ops.aten.rrelu_with_noise.out,
+     torch.ops.aten.rrelu_with_noise_.default,
+     torch.ops.aten.rsub.Tensor,
+     torch.ops.aten.rsub.Scalar,
+     torch.ops.aten.rsub.Tensor_out,
+     torch.ops.aten.rsub.Scalar_out,
+     torch.ops.aten._safe_softmax.default,
+     torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+     torch.ops.aten.select_backward.default,
+     torch.ops.aten.select_backward.out,
+     torch.ops.aten.select_scatter.default,
+     torch.ops.aten.select_scatter.out,
+     torch.ops.aten.sgn.default,
+     torch.ops.aten.sgn.out,
+     torch.ops.aten.sgn_.default,
+     torch.ops.aten.sigmoid_backward.default,
+     torch.ops.aten.sigmoid_backward.grad_input,
+     torch.ops.aten.silu.default,
+     torch.ops.aten.silu.out,
+     torch.ops.aten.silu_.default,
+     torch.ops.aten.silu_backward.default,
+     torch.ops.aten.silu_backward.grad_input,
+     torch.ops.aten.sinc.default,
+     torch.ops.aten.sinc.out,
+     torch.ops.aten.sinc_.default,
+     torch.ops.aten.slice_backward.default,
+     torch.ops.aten.slice_backward.out,
+     torch.ops.aten.smooth_l1_loss.default,
+     torch.ops.aten.smooth_l1_loss.out,
+     torch.ops.aten.smooth_l1_loss_backward.default,
+     torch.ops.aten.smooth_l1_loss_backward.grad_input,
+     torch.ops.aten.soft_margin_loss.default,
+     torch.ops.aten.soft_margin_loss.out,
+     torch.ops.aten.soft_margin_loss_backward.default,
+     torch.ops.aten.soft_margin_loss_backward.grad_input,
+     torch.ops.aten._softmax_backward_data.default,
+     torch.ops.aten._softmax_backward_data.out,
+     torch.ops.aten.softplus.default,
+     torch.ops.aten.softplus.out,
+     torch.ops.aten.softplus_backward.default,
+     torch.ops.aten.softplus_backward.grad_input,
+     torch.ops.aten.softshrink.default,
+     torch.ops.aten.softshrink.out,
+     torch.ops.aten.special_entr.default,
+     torch.ops.aten.special_entr.out,
+     torch.ops.aten.special_log_ndtr.default,
+     torch.ops.aten.special_log_ndtr.out,
+     torch.ops.aten.special_xlog1py.default,
+     torch.ops.aten.special_xlog1py.other_scalar,
+     torch.ops.aten.special_xlog1py.self_scalar,
+     torch.ops.aten.special_xlog1py.out,
+     torch.ops.aten.special_xlog1py.self_scalar_out,
+     torch.ops.aten.special_xlog1py.other_scalar_out,
+     torch.ops.aten.split.Tensor,
+     torch.ops.aten.split_with_sizes_copy.default,
+     torch.ops.aten.split_with_sizes_copy.out,
+     torch.ops.aten.squeeze.default,
+     torch.ops.aten.squeeze.dim,
+     torch.ops.aten.std.default,
+     torch.ops.aten.std.dim,
+     torch.ops.aten.std.correction,
+     torch.ops.aten.std.names_dim,
+     torch.ops.aten.std.names_out,
+     torch.ops.aten.std.out,
+     torch.ops.aten.std.correction_out,
+     torch.ops.aten.std.correction_names,
+     torch.ops.aten.std.correction_names_out,
+     torch.ops.aten.std_mean.default,
+     torch.ops.aten.std_mean.dim,
+     torch.ops.aten.std_mean.correction,
+     torch.ops.aten.std_mean.names_dim,
+     torch.ops.aten.std_mean.correction_names,
+     torch.ops.aten.std_mean.correction_out,
+     torch.ops.aten.stack.default,
+     torch.ops.aten.stack.out,
+     torch.ops.aten.sum.default,
+     torch.ops.aten.sum.out,
+     torch.ops.aten.t.default,
+     torch.ops.aten.t_copy.out,
+     torch.ops.aten.t_copy.default,
+     torch.ops.aten.take.default,
+     torch.ops.aten.take.out,
+     torch.ops.aten.tanh_backward.default,
+     torch.ops.aten.tanh_backward.grad_input,
+     torch.ops.aten.threshold.default,
+     torch.ops.aten.threshold.out,
+     torch.ops.aten.threshold_.default,
+     torch.ops.aten.threshold_backward.default,
+     torch.ops.aten.threshold_backward.grad_input,
+     torch.ops.aten.trace.default,
+     torch.ops.aten.trace.out,
+     torch.ops.aten.transpose.int,
+     torch.ops.aten.tril.default,
+     torch.ops.aten.tril.out,
+     torch.ops.aten.tril_.default,
+     torch.ops.aten.triu.default,
+     torch.ops.aten.triu.out,
+     torch.ops.aten.triu_.default,
+     torch.ops.aten.unbind.int,
+     torch.ops.aten.unbind.Dimname,
+     torch.ops.aten.unfold_backward.default,
+     torch.ops.aten.unfold_backward.out,
+     torch.ops.aten.unfold_copy.default,
+     torch.ops.aten.unfold_copy.out,
+     torch.ops.aten._unsafe_index.Tensor,
+     torch.ops.aten._unsafe_index_put.default,
+     torch.ops.aten._unsafe_masked_index.default,
+     torch.ops.aten._unsafe_masked_index_put_accumulate.default,
+     torch.ops.aten.unsafe_split.Tensor,
+     torch.ops.aten.unsafe_split_with_sizes.default,
+     torch.ops.aten.unsqueeze_copy.out,
+     torch.ops.aten.unsqueeze_copy.default,
+     torch.ops.aten._unsafe_view.default,
+     torch.ops.aten._unsafe_view.out,
+     torch.ops.aten.upsample_linear1d.default,
+     torch.ops.aten.upsample_linear1d.out,
+     torch.ops.aten.upsample_bilinear2d.vec,
+     torch.ops.aten.upsample_bilinear2d.default,
+     torch.ops.aten.upsample_bilinear2d.out,
+     torch.ops.aten.upsample_trilinear3d.vec,
+     torch.ops.aten.upsample_trilinear3d.default,
+     torch.ops.aten.upsample_trilinear3d.out,
+     torch.ops.aten.xlogy.Tensor,
+     torch.ops.aten.xlogy.Scalar_Other,
+     torch.ops.aten.xlogy.Scalar_Self,
+     torch.ops.aten.xlogy.OutTensor,
+     torch.ops.aten.xlogy.OutScalar_Self,
+     torch.ops.aten.xlogy.OutScalar_Other,
+     torch.ops.aten.xlogy_.Tensor,
+     torch.ops.aten.xlogy_.Scalar_Other,
+     torch.ops.aten.zero.default,
+     torch.ops.aten.zero.out,
+     torch.ops.aten.zero_.default,
+     torch.ops.aten.zeros.default,
+     torch.ops.aten.zeros_like.default,
+     torch.ops.aten.zeros_like.out,
+     torch.ops.aten._chunk_cat.default,
+     torch.ops.aten._chunk_cat.out,
+     torch.ops.aten._weight_norm_interface.default,
+     torch.ops.aten._weight_norm_interface.out,
+     torch.ops.aten.__iand__.Tensor,
+     torch.ops.aten.__ixor__.Tensor,
+     torch.ops.aten.__ilshift__.Tensor,
+     torch.ops.aten.__ilshift__.Scalar,
+     torch.ops.aten.__irshift__.Tensor,
+     torch.ops.aten.__irshift__.Scalar,
+     torch.ops.aten.__ior__.Tensor,
+ ])
+
+ MUTABLE_DECOMPOSITION = [
+     torch.ops.aten.bernoulli_.Tensor,
+     torch.ops.aten.bernoulli_.float,
+ ]
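+
+ # Usage sketch (an assumption for illustration, not part of this module's
+ # public surface): tables like DECOMPOSITIONS can be handed to tracers that
+ # accept a decomposition table, e.g.
+ #
+ #   from torch.fx.experimental.proxy_tensor import make_fx
+ #
+ #   def f(x):
+ #     return torch.nn.functional.channel_shuffle(x, 2)
+ #
+ #   gm = make_fx(f, decomposition_table=DECOMPOSITIONS)(torch.randn(1, 4, 8, 8))
+ #   # The traced graph then contains only the ops the decompositions lower to.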