torchax 0.0.10.dev20251118__py3-none-any.whl


@@ -0,0 +1,796 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """This file contains some decompositions that are not available in stable torch.
+
+ Most of them come from
+ https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
+ at main-branch HEAD and are mirrored here because we find them useful.
+
+ It can also contain decompositions of a torch op in terms of other torch ops.
+ """
+
+ import functools
+ from collections.abc import Callable
+ from typing import Any
+
+ import torch
+ import torch._decomp as decomp
+ import torch._prims_common as utils
+ from torch import Tensor
+ from torch._decomp import decompositions_for_rng, register_decomposition
+ from torch._prims_common.wrappers import out_wrapper
+
+ DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
+
+ # None of these functions are publicly accessible; get at them
+ # from torch._decomp
+ __all__: list[str] = []
+
+ aten = torch._ops.ops.aten
+
+
+ def _try_register(op, impl):
+     # Ignore registration failures (e.g. when a decomposition for `op` is
+     # already registered upstream); keep the existing one in that case.
+     try:
+         register_decomposition(op)(impl)
+     except Exception:
+         pass
+
+
+ @out_wrapper()
+ def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
+     def idx(left, middle, right):
+         dim_idx = torch.arange(-left, middle + right, device=a.device)
+         return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
+
+     return _reflection_or_replication_pad(
+         a,
+         padding,
+         idx,
+     )
+
+
+ _try_register(aten.reflection_pad1d, _reflection_pad)
+ _try_register(aten.reflection_pad2d, _reflection_pad)
+ _try_register(aten.reflection_pad3d, _reflection_pad)
+
+
+ @out_wrapper()
+ def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
+     def idx(left, middle, right):
+         dim_idx = torch.arange(-left, middle + right, device=a.device)
+         return torch.clamp(dim_idx, 0, middle - 1)
+
+     return _reflection_or_replication_pad(
+         a,
+         padding,
+         idx,
+     )
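+
+ # Worked example (illustrative): for a padded axis with left=2, middle=4,
+ # right=1, the gather indices produced by the two idx() helpers above are
+ #   reflection:  [2, 1, 0, 1, 2, 3, 2]
+ #   replication: [0, 0, 0, 1, 2, 3, 3]
+ # i.e. reflection mirrors around the edge values while replication repeats
+ # them.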
+
+
+ # Install directly into the post-autograd table, overriding any existing
+ # entry for replication_pad2d.
+ decomp.global_decomposition_table["post_autograd"][aten.replication_pad2d.default] = (
+     _replication_pad
+ )
+
+
+ def _reflection_or_replication_pad(
+     a: Tensor,
+     padding: tuple[int, ...],
+     idx_fn: Callable[[int, int, int], Tensor],
+ ) -> Tensor:
+     dim = len(padding) // 2
+     torch._check(
+         a.dim() in (dim + 1, dim + 2),
+         lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
+     )
+     inp_shape = a.shape[-dim:]
+     nc_dim = a.dim() - dim
+
+     padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
+     padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
+
+     result = a
+     for i in range(dim):
+         idx: list[Any] = [None] * result.dim()
+         idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
+         result = aten._unsafe_index(result, idx)
+
+     # convert output to correct memory format, if necessary
+     memory_format = utils.suggest_memory_format(result)
+     result = result.contiguous(memory_format=memory_format)
+     return result
+
+
+ _try_register(aten.replication_pad1d, _replication_pad)
+ _try_register(aten.replication_pad3d, _replication_pad)
+
+
+ def bernoulli(self, *, generator=None):
+     # Sample elementwise Bernoulli(p) by thresholding uniform noise against
+     # `self`, which holds the per-element probabilities.
+     return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
+
+
+ _try_register(aten.bernoulli.default, bernoulli)
+
+
+ def rand_like(self, **kwargs):
+     # Only `dtype` is honored here; other kwargs (device, layout, ...) are
+     # ignored.
+     dtype = kwargs.get("dtype", self.dtype)
+     return torch.rand(self.shape, dtype=dtype)
+
+
+ def channel_shuffle(self, groups):
+     batchsize, channels, height, width = self.shape
+     channels_per_group = channels // groups
+     self = self.reshape(batchsize, groups, channels_per_group, height, width)
+     self = self.transpose(1, 2)
+     self = self.reshape(batchsize, channels, height, width)
+     return self
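+
+ # Illustrative walk-through (hypothetical sizes): with channels=6, groups=2,
+ # channel_shuffle maps (N, 6, H, W) -> (N, 2, 3, H, W) -> (N, 3, 2, H, W)
+ # -> (N, 6, H, W), yielding channel order [0, 3, 1, 4, 2, 5].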
+
+
+ _try_register(aten.channel_shuffle, channel_shuffle)
+
+ _try_register(aten.bernoulli, bernoulli)
+ _try_register(aten.rand_like, rand_like)
+
+
+ def bernoulli_float(self, p=0.5):
+     return self.bernoulli_(p)
+
+
+ _try_register(aten.bernoulli_.float, bernoulli_float)
+ _try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_)
+
+
+ def _sum_tensors(ts) -> Tensor:
+     return functools.reduce(torch.add, ts)
+
+
+ @register_decomposition(aten.grid_sampler_3d)
+ def _grid_sampler_3d(
+     a: torch.Tensor,
+     grid: torch.Tensor,
+     interpolation_mode: int = 0,
+     padding_mode: int = 0,
+     align_corners: bool = False,
+ ) -> Tensor:
+     """3D grid sampler, extending the upstream 2D decomposition to 3D.
+
+     Reference: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
+     which implements the 2D case.
+     """
+     _expand_grid = False
+     torch._check(
+         interpolation_mode in (0, 1),
+         lambda: f"Invalid interpolation mode {interpolation_mode}",
+     )
+     torch._check(
+         padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
+     )
+
+     # a is 5D: [B, C, D, H, W]
+
+     def unnormalize(coords: Tensor, size: int) -> Tensor:
+         # Rescale coordinates from [-1, 1] to:
+         #   [0, size - 1] if align_corners is True
+         #   [-0.5, size - 0.5] if align_corners is False
+         mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
+         ofs = size * 0.5 - 0.5
+         return coords * mul + ofs
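+
+     # Example (illustrative): with size=4, align_corners=True maps grid
+     # coordinates [-1, 1] onto [0, 3] (pixel centers); align_corners=False
+     # maps them onto [-0.5, 3.5] (pixel edges).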
+
+     # Reflects coordinates until they fall between low and high (inclusive).
+     # The bounds are passed as twice their value so that half-integer values
+     # can be represented as ints.
+     def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
+         if twice_low == twice_high:
+             return torch.zeros_like(coords)
+         coords_min = twice_low / 2
+         coords_span = (twice_high - twice_low) / 2
+         coords2 = (coords - coords_min).abs()
+         extra = torch.fmod(coords2, coords_span)
+         flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
+         return torch.where(
+             flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
+         )
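+
+     # Example (illustrative): with align_corners=True and size=4, the bounds
+     # are (twice_low, twice_high) = (0, 6); an out-of-range coordinate 4.5
+     # reflects to 1.5 (one flip about the high edge at 3).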
+
+     def compute_coordinates(coords: Tensor, size: int) -> Tensor:
+         if padding_mode == 0:  # Zero
+             return coords
+         elif padding_mode == 1:  # Borders
+             return torch.clamp(coords, 0, size - 1)
+         else:  # padding_mode == 2, Reflection
+             if align_corners:
+                 coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
+             else:
+                 coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
+             return torch.clamp(coords_reflected, 0, size - 1)
+
+     def compute_source_index(coords: Tensor, size: int) -> Tensor:
+         coords_un = unnormalize(coords, size)
+         return compute_coordinates(coords_un, size)
+
+     N, C, iD, iH, iW = a.shape
+     _, oD, oH, oW, three = grid.shape
+     assert three == 3, f"Last dim of grid must be 3, got {three}"
+
+     def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
+         xcheck = torch.logical_and(0 <= xs, xs < iW)
+         ycheck = torch.logical_and(0 <= ys, ys < iH)
+         zcheck = torch.logical_and(0 <= zs, zs < iD)
+         return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck))
+
+     N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
+     C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)
+
+     def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
+         cond = in_bounds_cond(xs, ys, zs)
+         # To clip to inside valid coordinates, we map out-of-bounds
+         # coordinates to (x, y, z) = (0, 0, 0) and also set the weight to 0.
+         # We also change the shape of the tensor to the appropriate one for
+         # broadcasting with N_idx, C_idx for the purposes of advanced indexing.
+         c = C if _expand_grid else 1
+         return tuple(
+             torch.where(cond, t, 0).view(N, c, oD, oH, oW)
+             for t in (
+                 xs.to(dtype=torch.int64),
+                 ys.to(dtype=torch.int64),
+                 zs.to(dtype=torch.int64),
+                 ws,
+             )
+         )
+
+     def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
+         # Perform clipping, index into input tensor and multiply by weight
+         idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
+         return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_
+
255
+
256
+ x = grid[..., 0]
257
+ y = grid[..., 1]
258
+ d = grid[..., 2]
259
+
260
+ if interpolation_mode == 0: # Bilinear
261
+ ix = compute_source_index(x, iW)
262
+ iy = compute_source_index(y, iH)
263
+ id_ = compute_source_index(d, iD)
264
+
265
+ ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
266
+ ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
267
+ ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
268
+ ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
269
+ ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
270
+ ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
271
+ ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
272
+ ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1
273
+
274
+ w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
275
+ w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
276
+ w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
277
+ w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
278
+ w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
279
+ w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
280
+ w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
281
+ w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)
282
+
283
+ return _sum_tensors(
284
+ get_summand(ix, iy, id_, w)
285
+ for (ix, iy, id_, w) in (
286
+ (ix_nwf, iy_nwf, id_nwf, w_nwf),
287
+ (ix_nef, iy_nef, id_nef, w_nef),
288
+ (ix_swf, iy_swf, id_swf, w_swf),
289
+ (ix_sef, iy_sef, id_sef, w_sef),
290
+ (ix_nwb, iy_nwb, id_nwb, w_nwb),
291
+ (ix_neb, iy_neb, id_neb, w_neb),
292
+ (ix_swb, iy_swb, id_swb, w_swb),
293
+ (ix_seb, iy_seb, id_seb, w_seb),
294
+ )
295
+ )
296
+ else: # interpolation_mode == 1: # Nearest
297
+ ix = compute_source_index(x, iW)
298
+ iy = compute_source_index(y, iH)
299
+ iz = compute_source_index(d, iD)
300
+
301
+ ix_nearest = ix.round()
302
+ iy_nearest = iy.round()
303
+ iz_nearest = iz.round()
304
+
305
+ return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
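+
+ # Note on the trilinear branch above (illustrative): the eight corner weights
+ # are products of complementary axis distances, e.g. w_nwf =
+ # (1 - dx) * (1 - dy) * (1 - dz) when ix = ix_nwf + dx, so for in-bounds
+ # coordinates the eight weights sum to 1.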
+
+
+ DECOMPOSITIONS = decomp.get_decompositions(
+     [
+         torch.ops.aten.upsample_bicubic2d,
+         torch.ops.aten.upsample_nearest1d,
+         torch.ops.aten.upsample_nearest2d,
+         torch.ops.aten.upsample_nearest3d,
+         torch.ops.aten._upsample_nearest_exact1d,
+         torch.ops.aten._upsample_nearest_exact2d,
+         torch.ops.aten._upsample_nearest_exact3d,
+         torch.ops.aten._native_batch_norm_legit.no_stats,
+         torch.ops.aten._native_batch_norm_legit_functional.default,
+         torch.ops.aten._adaptive_avg_pool2d,
+         torch.ops.aten._adaptive_avg_pool3d,
+         torch.ops.aten.grid_sampler_2d,
+         torch.ops.aten.grid_sampler_3d,
+         torch.ops.aten.native_dropout,
+         torch.ops.aten.reflection_pad1d,
+         torch.ops.aten.reflection_pad2d,
+         torch.ops.aten.reflection_pad3d,
+         torch.ops.aten.replication_pad1d,
+         torch.ops.aten.replication_pad2d,
+         torch.ops.aten.replication_pad3d,
+         torch.ops.aten.bernoulli,
+         torch.ops.aten.rand_like,
+         torch.ops.aten._batch_norm_with_update,
+         torch.ops.aten.channel_shuffle,
+         torch.ops.aten.nll_loss2d_forward,
+         torch.ops.aten.nll_loss2d_backward,
+         torch.ops.aten.bernoulli_.Tensor,
+         torch.ops.aten.bernoulli_.float,
+         torch.ops.aten.log_normal,
+         torch.ops.aten.addcdiv.default,
+         torch.ops.aten.addcdiv.out,
+         torch.ops.aten.addcdiv_.default,
+         torch.ops.aten.addcmul.default,
+         torch.ops.aten.addcmul.out,
+         torch.ops.aten.addcmul_.default,
+         torch.ops.aten.addr.default,
+         torch.ops.aten.addr.out,
+         torch.ops.aten.affine_grid_generator.default,
+         torch.ops.aten.affine_grid_generator.out,
+         torch.ops.aten.alias_copy.default,
+         torch.ops.aten.alias_copy.out,
+         torch.ops.aten.all.default,
+         torch.ops.aten.all.dim,
+         torch.ops.aten.all.dims,
+         torch.ops.aten.all.out,
+         torch.ops.aten.all.dims_out,
+         torch.ops.aten.all.all_out,
+         torch.ops.aten.all.dimname,
+         torch.ops.aten.all.dimname_out,
+         torch.ops.aten.aminmax.default,
+         torch.ops.aten.aminmax.out,
+         torch.ops.aten.arange.default,
+         torch.ops.aten.arange.start,
+         torch.ops.aten.baddbmm.default,
+         torch.ops.aten.baddbmm.out,
+         torch.ops.aten.binary_cross_entropy.default,
+         torch.ops.aten.binary_cross_entropy.out,
+         torch.ops.aten.binary_cross_entropy_backward.default,
+         torch.ops.aten.binary_cross_entropy_backward.grad_input,
+         torch.ops.aten.binary_cross_entropy_with_logits.default,
+         torch.ops.aten.binary_cross_entropy_with_logits.out,
+         torch.ops.aten.block_diag.default,
+         torch.ops.aten.block_diag.out,
+         torch.ops.aten.celu.default,
+         torch.ops.aten.celu.out,
+         torch.ops.aten.celu_.default,
+         torch.ops.aten.channel_shuffle.default,
+         torch.ops.aten.channel_shuffle.out,
+         torch.ops.aten.clamp_max.default,
+         torch.ops.aten.clamp_max.Tensor,
+         torch.ops.aten.clamp_max.out,
+         torch.ops.aten.clamp_max.Tensor_out,
+         torch.ops.aten.clamp_min.default,
+         torch.ops.aten.clamp_min.Tensor,
+         torch.ops.aten.clamp_min.out,
+         torch.ops.aten.clamp_min.Tensor_out,
+         torch.ops.aten.col2im.default,
+         torch.ops.aten.col2im.out,
+         torch.ops.aten.count_nonzero.dim_IntList,
+         torch.ops.aten.count_nonzero.dim_IntList_out,
+         torch.ops.aten.count_nonzero.default,
+         torch.ops.aten.count_nonzero.out,
+         torch.ops.aten.linalg_cross.default,
+         torch.ops.aten.linalg_cross.out,
+         torch.ops.aten.cudnn_batch_norm.default,
+         torch.ops.aten.cudnn_batch_norm.out,
+         torch.ops.aten.cudnn_batch_norm_backward.default,
+         torch.ops.aten.cudnn_batch_norm_backward.out,
+         torch.ops.aten.miopen_batch_norm_backward.default,
+         torch.ops.aten.miopen_batch_norm_backward.out,
+         torch.ops.aten.deg2rad.default,
+         torch.ops.aten.deg2rad.out,
+         torch.ops.aten.deg2rad_.default,
+         torch.ops.aten.detach.default,
+         torch.ops.aten.diag_embed.default,
+         torch.ops.aten.diag_embed.out,
+         torch.ops.aten.diagonal_backward.default,
+         torch.ops.aten.diagonal_backward.out,
+         torch.ops.aten.dot.default,
+         torch.ops.aten.dot.out,
+         torch.ops.aten.vdot.default,
+         torch.ops.aten.vdot.out,
+         torch.ops.aten.elu.default,
+         torch.ops.aten.elu.out,
+         torch.ops.aten.elu_.default,
+         torch.ops.aten.elu_backward.default,
+         torch.ops.aten.elu_backward.grad_input,
+         torch.ops.aten.embedding_dense_backward.default,
+         torch.ops.aten.embedding_dense_backward.out,
+         torch.ops.aten.empty_like.default,
+         torch.ops.aten.empty_like.out,
+         torch.ops.aten._euclidean_dist.default,
+         torch.ops.aten.expand_copy.default,
+         torch.ops.aten.expand_copy.out,
+         torch.ops.aten.eye.default,
+         torch.ops.aten.eye.m,
+         torch.ops.aten.eye.out,
+         torch.ops.aten.eye.m_out,
+         torch.ops.aten.fill.Scalar,
+         torch.ops.aten.fill.Tensor,
+         torch.ops.aten.fill_.Scalar,
+         torch.ops.aten.fill_.Tensor,
+         torch.ops.aten.floor_divide.default,
+         torch.ops.aten.floor_divide.Scalar,
+         torch.ops.aten.floor_divide.out,
+         torch.ops.aten.floor_divide.Scalar_out,
+         torch.ops.aten.frac.default,
+         torch.ops.aten.frac.out,
+         torch.ops.aten.frac_.default,
+         torch.ops.aten.gelu_.default,
+         torch.ops.aten.gelu_backward.default,
+         torch.ops.aten.gelu_backward.grad_input,
+         torch.ops.aten.glu.default,
+         torch.ops.aten.glu.out,
+         torch.ops.aten.glu_backward.default,
+         torch.ops.aten.glu_backward.grad_input,
+         torch.ops.aten.hardshrink.default,
+         torch.ops.aten.hardshrink.out,
+         torch.ops.aten.hardsigmoid.default,
+         torch.ops.aten.hardsigmoid.out,
+         torch.ops.aten.hardsigmoid_.default,
+         torch.ops.aten.hardsigmoid_backward.default,
+         torch.ops.aten.hardsigmoid_backward.grad_input,
+         torch.ops.aten.hardswish.default,
+         torch.ops.aten.hardswish.out,
+         torch.ops.aten.hardswish_.default,
+         torch.ops.aten.hardswish_backward.default,
+         torch.ops.aten.hardswish_backward.out,
+         torch.ops.aten.hardtanh_.default,
+         torch.ops.aten.hardtanh_backward.default,
+         torch.ops.aten.hardtanh_backward.grad_input,
+         torch.ops.aten.heaviside.default,
+         torch.ops.aten.heaviside.out,
+         torch.ops.aten.heaviside_.default,
+         torch.ops.aten.huber_loss.default,
+         torch.ops.aten.huber_loss.out,
+         torch.ops.aten.huber_loss_backward.default,
+         torch.ops.aten.huber_loss_backward.out,
+         torch.ops.aten.im2col.default,
+         torch.ops.aten.im2col.out,
+         torch.ops.aten.index_add.default,
+         torch.ops.aten.index_add.out,
+         torch.ops.aten.index_add.dimname,
+         torch.ops.aten.index_add_.default,
+         torch.ops.aten.index_copy.default,
+         torch.ops.aten.index_copy.dimname,
+         torch.ops.aten.index_copy.out,
+         torch.ops.aten.index_copy_.default,
+         torch.ops.aten.index_copy_.dimname,
+         torch.ops.aten.index_fill.int_Tensor,
+         torch.ops.aten.index_fill.int_Scalar,
+         torch.ops.aten.index_fill.Dimname_Scalar,
+         torch.ops.aten.index_fill.Dimname_Tensor,
+         torch.ops.aten.index_fill.int_Scalar_out,
+         torch.ops.aten.index_fill.int_Tensor_out,
+         torch.ops.aten.index_fill_.int_Tensor,
+         torch.ops.aten.index_fill_.int_Scalar,
+         torch.ops.aten.index_fill_.Dimname_Scalar,
+         torch.ops.aten.index_fill_.Dimname_Tensor,
+         torch.ops.aten.isin.Tensor_Tensor,
+         torch.ops.aten.isin.Tensor_Tensor_out,
+         torch.ops.aten.isin.Tensor_Scalar,
+         torch.ops.aten.isin.Tensor_Scalar_out,
+         torch.ops.aten.isin.Scalar_Tensor,
+         torch.ops.aten.isin.Scalar_Tensor_out,
+         torch.ops.aten.isneginf.default,
+         torch.ops.aten.isneginf.out,
+         torch.ops.aten.isposinf.default,
+         torch.ops.aten.isposinf.out,
+         torch.ops.aten.leaky_relu_.default,
+         torch.ops.aten.leaky_relu_backward.default,
+         torch.ops.aten.leaky_relu_backward.grad_input,
+         torch.ops.aten.lerp.Scalar,
+         torch.ops.aten.lerp.Tensor,
+         torch.ops.aten.lerp.Scalar_out,
+         torch.ops.aten.lerp.Tensor_out,
+         torch.ops.aten.lerp_.Scalar,
+         torch.ops.aten.lerp_.Tensor,
+         torch.ops.aten.linspace.Tensor_Tensor,
+         torch.ops.aten.linspace.Tensor_Scalar,
+         torch.ops.aten.linspace.Scalar_Tensor,
+         torch.ops.aten.linspace.default,
+         torch.ops.aten.linspace.out,
+         torch.ops.aten.linspace.Tensor_Tensor_out,
+         torch.ops.aten.linspace.Tensor_Scalar_out,
+         torch.ops.aten.linspace.Scalar_Tensor_out,
+         torch.ops.aten.logaddexp.default,
+         torch.ops.aten.logaddexp.out,
+         torch.ops.aten.logaddexp2.default,
+         torch.ops.aten.logaddexp2.out,
+         torch.ops.aten.logit.default,
+         torch.ops.aten.logit.out,
+         torch.ops.aten.logit_.default,
+         torch.ops.aten.logit_backward.default,
+         torch.ops.aten.log_sigmoid_backward.default,
+         torch.ops.aten.log_sigmoid_backward.grad_input,
+         torch.ops.aten.log_sigmoid_forward.default,
+         torch.ops.aten.log_sigmoid_forward.output,
+         torch.ops.aten._log_softmax_backward_data.default,
+         torch.ops.aten._log_softmax_backward_data.out,
+         torch.ops.aten.logspace.Tensor_Tensor,
+         torch.ops.aten.logspace.Tensor_Scalar,
+         torch.ops.aten.logspace.Scalar_Tensor,
+         torch.ops.aten.logspace.default,
+         torch.ops.aten.logspace.out,
+         torch.ops.aten.logspace.Tensor_Tensor_out,
+         torch.ops.aten.logspace.Tensor_Scalar_out,
+         torch.ops.aten.logspace.Scalar_Tensor_out,
+         torch.ops.aten.logsumexp.default,
+         torch.ops.aten.masked_fill.Scalar,
+         torch.ops.aten.masked_fill.Tensor,
+         torch.ops.aten.masked_fill.Scalar_out,
+         torch.ops.aten.masked_fill.Tensor_out,
+         torch.ops.aten.masked_fill_.Scalar,
+         torch.ops.aten.masked_fill_.Tensor,
+         torch.ops.aten.mish.default,
+         torch.ops.aten.mish.out,
+         torch.ops.aten.mish_.default,
+         torch.ops.aten.mse_loss.default,
+         torch.ops.aten.mse_loss.out,
+         torch.ops.aten.mse_loss_backward.default,
+         torch.ops.aten.mse_loss_backward.grad_input,
+         torch.ops.aten.multi_margin_loss.default,
+         torch.ops.aten.multi_margin_loss.out,
+         torch.ops.aten.multilabel_margin_loss_forward.default,
+         torch.ops.aten.multilabel_margin_loss_forward.output,
+         torch.ops.aten.mv.default,
+         torch.ops.aten.mv.out,
+         torch.ops.aten.mvlgamma.default,
+         torch.ops.aten.mvlgamma.out,
+         torch.ops.aten.mvlgamma_.default,
+         torch.ops.aten.nansum.default,
+         torch.ops.aten.nansum.out,
+         torch.ops.aten.nan_to_num.default,
+         torch.ops.aten.nan_to_num.out,
+         torch.ops.aten.nan_to_num_.default,
+         torch.ops.aten.native_batch_norm_backward.default,
+         torch.ops.aten.native_batch_norm_backward.out,
+         torch.ops.aten.native_dropout_backward.default,
+         torch.ops.aten.native_dropout_backward.out,
+         torch.ops.aten.native_group_norm_backward.default,
+         torch.ops.aten.native_group_norm_backward.out,
+         torch.ops.aten.native_layer_norm_backward.default,
+         torch.ops.aten.native_layer_norm_backward.out,
+         torch.ops.aten.new_empty.default,
+         torch.ops.aten.new_empty.out,
+         torch.ops.aten.new_full.default,
+         torch.ops.aten.new_full.out,
+         torch.ops.aten.new_ones.default,
+         torch.ops.aten.new_ones.out,
+         torch.ops.aten.new_zeros.default,
+         torch.ops.aten.new_zeros.out,
+         torch.ops.aten.nll_loss2d_forward.default,
+         torch.ops.aten.nll_loss2d_forward.output,
+         torch.ops.aten.nll_loss2d_backward.default,
+         torch.ops.aten.nll_loss2d_backward.grad_input,
+         torch.ops.aten.nll_loss_backward.default,
+         torch.ops.aten.nll_loss_backward.grad_input,
+         torch.ops.aten.nll_loss_forward.default,
+         torch.ops.aten.nll_loss_forward.output,
+         torch.ops.aten.norm.Scalar,
+         torch.ops.aten.norm.ScalarOpt_dim,
+         torch.ops.aten.norm.names_ScalarOpt_dim,
+         torch.ops.aten.norm.ScalarOpt_dim_dtype,
+         torch.ops.aten.norm.dtype_out,
+         torch.ops.aten.norm.out,
+         torch.ops.aten.norm.ScalarOpt_dtype,
+         torch.ops.aten.norm.ScalarOpt_dtype_out,
+         torch.ops.aten.norm.Scalar_out,
+         torch.ops.aten.norm.names_ScalarOpt_dim_dtype,
+         torch.ops.aten.norm.names_dtype_out,
+         torch.ops.aten.norm.names_out,
+         torch.ops.aten.ones.default,
+         torch.ops.aten.ones_like.default,
+         torch.ops.aten.ones_like.out,
+         torch.ops.aten.pixel_shuffle.default,
+         torch.ops.aten.pixel_shuffle.out,
+         torch.ops.aten.pixel_unshuffle.default,
+         torch.ops.aten.pixel_unshuffle.out,
+         torch.ops.aten._prelu_kernel.default,
+         torch.ops.aten._prelu_kernel_backward.default,
+         torch.ops.aten._reshape_alias.default,
+         torch.ops.aten.rad2deg.default,
+         torch.ops.aten.rad2deg.out,
+         torch.ops.aten.rad2deg_.default,
+         torch.ops.aten.reflection_pad1d.default,
+         torch.ops.aten.reflection_pad1d.out,
+         torch.ops.aten.reflection_pad1d_backward.default,
+         torch.ops.aten.reflection_pad1d_backward.grad_input,
+         torch.ops.aten.reflection_pad2d.default,
+         torch.ops.aten.reflection_pad2d.out,
+         torch.ops.aten.reflection_pad2d_backward.default,
+         torch.ops.aten.reflection_pad2d_backward.grad_input,
+         torch.ops.aten.reflection_pad3d.default,
+         torch.ops.aten.reflection_pad3d.out,
+         torch.ops.aten.reflection_pad3d_backward.default,
+         torch.ops.aten.reflection_pad3d_backward.grad_input,
+         torch.ops.aten.replication_pad1d.default,
+         torch.ops.aten.replication_pad1d.out,
+         torch.ops.aten.replication_pad2d.default,
+         torch.ops.aten.replication_pad2d.out,
+         torch.ops.aten.replication_pad3d.default,
+         torch.ops.aten.replication_pad3d.out,
+         torch.ops.aten.renorm.default,
+         torch.ops.aten.renorm.out,
+         torch.ops.aten.renorm_.default,
+         torch.ops.aten.resize_as.default,
+         torch.ops.aten.resize_as.out,
+         torch.ops.aten.roll.default,
+         torch.ops.aten.roll.out,
+         torch.ops.aten.rot90.default,
+         torch.ops.aten.rot90.out,
+         torch.ops.aten.rrelu_with_noise.default,
+         torch.ops.aten.rrelu_with_noise.out,
+         torch.ops.aten.rrelu_with_noise_.default,
+         torch.ops.aten.rsub.Tensor,
+         torch.ops.aten.rsub.Scalar,
+         torch.ops.aten.rsub.Tensor_out,
+         torch.ops.aten.rsub.Scalar_out,
+         torch.ops.aten._safe_softmax.default,
+         torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+         torch.ops.aten.select_backward.default,
+         torch.ops.aten.select_backward.out,
+         torch.ops.aten.select_scatter.default,
+         torch.ops.aten.select_scatter.out,
+         torch.ops.aten.sgn.default,
+         torch.ops.aten.sgn.out,
+         torch.ops.aten.sgn_.default,
+         torch.ops.aten.sigmoid_backward.default,
+         torch.ops.aten.sigmoid_backward.grad_input,
+         torch.ops.aten.silu.default,
+         torch.ops.aten.silu.out,
+         torch.ops.aten.silu_.default,
+         torch.ops.aten.silu_backward.default,
+         torch.ops.aten.silu_backward.grad_input,
+         torch.ops.aten.sinc.default,
+         torch.ops.aten.sinc.out,
+         torch.ops.aten.sinc_.default,
+         torch.ops.aten.slice_backward.default,
+         torch.ops.aten.slice_backward.out,
+         torch.ops.aten.smooth_l1_loss.default,
+         torch.ops.aten.smooth_l1_loss.out,
+         torch.ops.aten.smooth_l1_loss_backward.default,
+         torch.ops.aten.smooth_l1_loss_backward.grad_input,
+         torch.ops.aten.soft_margin_loss.default,
+         torch.ops.aten.soft_margin_loss.out,
+         torch.ops.aten.soft_margin_loss_backward.default,
+         torch.ops.aten.soft_margin_loss_backward.grad_input,
+         torch.ops.aten._softmax_backward_data.default,
+         torch.ops.aten._softmax_backward_data.out,
+         torch.ops.aten.softplus.default,
+         torch.ops.aten.softplus.out,
+         torch.ops.aten.softplus_backward.default,
+         torch.ops.aten.softplus_backward.grad_input,
+         torch.ops.aten.softshrink.default,
+         torch.ops.aten.softshrink.out,
+         torch.ops.aten.special_entr.default,
+         torch.ops.aten.special_entr.out,
+         torch.ops.aten.special_log_ndtr.default,
+         torch.ops.aten.special_log_ndtr.out,
+         torch.ops.aten.special_xlog1py.default,
+         torch.ops.aten.special_xlog1py.other_scalar,
+         torch.ops.aten.special_xlog1py.self_scalar,
+         torch.ops.aten.special_xlog1py.out,
+         torch.ops.aten.special_xlog1py.self_scalar_out,
+         torch.ops.aten.special_xlog1py.other_scalar_out,
+         torch.ops.aten.split.Tensor,
+         torch.ops.aten.split_with_sizes_copy.default,
+         torch.ops.aten.split_with_sizes_copy.out,
+         torch.ops.aten.squeeze.default,
+         torch.ops.aten.squeeze.dim,
+         torch.ops.aten.std.default,
+         torch.ops.aten.std.dim,
+         torch.ops.aten.std.correction,
+         torch.ops.aten.std.names_dim,
+         torch.ops.aten.std.names_out,
+         torch.ops.aten.std.out,
+         torch.ops.aten.std.correction_out,
+         torch.ops.aten.std.correction_names,
+         torch.ops.aten.std.correction_names_out,
+         torch.ops.aten.std_mean.default,
+         torch.ops.aten.std_mean.dim,
+         torch.ops.aten.std_mean.correction,
+         torch.ops.aten.std_mean.names_dim,
+         torch.ops.aten.std_mean.correction_names,
+         torch.ops.aten.std_mean.correction_out,
+         torch.ops.aten.stack.default,
+         torch.ops.aten.stack.out,
+         torch.ops.aten.sum.default,
+         torch.ops.aten.sum.out,
+         torch.ops.aten.t.default,
+         torch.ops.aten.t_copy.out,
+         torch.ops.aten.t_copy.default,
+         torch.ops.aten.take.default,
+         torch.ops.aten.take.out,
+         torch.ops.aten.tanh_backward.default,
+         torch.ops.aten.tanh_backward.grad_input,
+         torch.ops.aten.threshold.default,
+         torch.ops.aten.threshold.out,
+         torch.ops.aten.threshold_.default,
+         torch.ops.aten.threshold_backward.default,
+         torch.ops.aten.threshold_backward.grad_input,
+         torch.ops.aten.trace.default,
+         torch.ops.aten.trace.out,
+         torch.ops.aten.transpose.int,
+         torch.ops.aten.tril.default,
+         torch.ops.aten.tril.out,
+         torch.ops.aten.tril_.default,
+         torch.ops.aten.triu.default,
+         torch.ops.aten.triu.out,
+         torch.ops.aten.triu_.default,
+         torch.ops.aten.unbind.int,
+         torch.ops.aten.unbind.Dimname,
+         torch.ops.aten.unfold_backward.default,
+         torch.ops.aten.unfold_backward.out,
+         torch.ops.aten.unfold_copy.default,
+         torch.ops.aten.unfold_copy.out,
+         torch.ops.aten._unsafe_index.Tensor,
+         torch.ops.aten._unsafe_index_put.default,
+         torch.ops.aten._unsafe_masked_index.default,
+         torch.ops.aten._unsafe_masked_index_put_accumulate.default,
+         torch.ops.aten.unsafe_split.Tensor,
+         torch.ops.aten.unsafe_split_with_sizes.default,
+         torch.ops.aten.unsqueeze_copy.out,
+         torch.ops.aten.unsqueeze_copy.default,
+         torch.ops.aten._unsafe_view.default,
+         torch.ops.aten._unsafe_view.out,
+         torch.ops.aten.upsample_linear1d.default,
+         torch.ops.aten.upsample_linear1d.out,
+         torch.ops.aten.upsample_bilinear2d.vec,
+         torch.ops.aten.upsample_bilinear2d.default,
+         torch.ops.aten.upsample_bilinear2d.out,
+         torch.ops.aten.upsample_trilinear3d.vec,
+         torch.ops.aten.upsample_trilinear3d.default,
+         torch.ops.aten.upsample_trilinear3d.out,
+         torch.ops.aten.xlogy.Tensor,
+         torch.ops.aten.xlogy.Scalar_Other,
+         torch.ops.aten.xlogy.Scalar_Self,
+         torch.ops.aten.xlogy.OutTensor,
+         torch.ops.aten.xlogy.OutScalar_Self,
+         torch.ops.aten.xlogy.OutScalar_Other,
+         torch.ops.aten.xlogy_.Tensor,
+         torch.ops.aten.xlogy_.Scalar_Other,
+         torch.ops.aten.zero.default,
+         torch.ops.aten.zero.out,
+         torch.ops.aten.zero_.default,
+         torch.ops.aten.zeros.default,
+         torch.ops.aten.zeros_like.default,
+         torch.ops.aten.zeros_like.out,
+         torch.ops.aten._chunk_cat.default,
+         torch.ops.aten._chunk_cat.out,
+         torch.ops.aten._weight_norm_interface.default,
+         torch.ops.aten._weight_norm_interface.out,
+         torch.ops.aten.__iand__.Tensor,
+         torch.ops.aten.__ixor__.Tensor,
+         torch.ops.aten.__ilshift__.Tensor,
+         torch.ops.aten.__ilshift__.Scalar,
+         torch.ops.aten.__irshift__.Tensor,
+         torch.ops.aten.__irshift__.Scalar,
+         torch.ops.aten.__ior__.Tensor,
+     ]
+ )
+
+ # In-place ops whose decompositions mutate their input.
+ MUTABLE_DECOMPOSITION = [
+     torch.ops.aten.bernoulli_.Tensor,
+     torch.ops.aten.bernoulli_.float,
+ ]
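+
+ # Usage sketch (illustrative; `make_fx` is a standard torch tracing utility,
+ # and `f` is a hypothetical function): a decomposition table like
+ # DECOMPOSITIONS is typically handed to a tracer, e.g.
+ #
+ #   from torch.fx.experimental.proxy_tensor import make_fx
+ #
+ #   def f(x):
+ #       return torch.nn.functional.hardswish(x)
+ #
+ #   gm = make_fx(f, decomposition_table=DECOMPOSITIONS)(torch.randn(8))
+ #
+ # The traced graph `gm` then contains only ops rewritten per the table.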