torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/decompositions.py CHANGED
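Scope note: most of the hunks below appear to be formatting-only (the file moves from 4-space to 2-space, Google-style indentation), which is why many lines are removed and re-added with identical or nearly identical text. The substantive changes are small: bernoulli_float now passes its float probability directly to Tensor.bernoulli_ instead of wrapping it in torch.tensor(p); the EXTRA_DECOMP table is renamed to DECOMPOSITIONS and extended with several hundred additional aten overloads; and a new MUTABLE_DECOMPOSITION list is added for the two in-place bernoulli variants.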
@@ -8,7 +8,7 @@ Can also contain decompositions of a torch op in terms of other torch ops.
 """
 
 import functools
-from typing import Any, Callable, List, Tuple
+from typing import Any, Callable, List, Tuple
 
 import torch
 from torch import Tensor
@@ -18,7 +18,6 @@ from torch._decomp import register_decomposition
 import torch._prims_common as utils
 from torch._prims_common.wrappers import out_wrapper
 
-
 DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
 
 # None of these functions are publicly accessible; get at them
@@ -27,41 +26,49 @@ __all__: List[str] = []
 
 aten = torch._ops.ops.aten
 
+
 def _try_register(op, impl):
-    try:
-        register_decomposition(op)(impl)
-    except:
-        pass
+  try:
+    register_decomposition(op)(impl)
+  except:
+    pass
+
 
 @out_wrapper()
 def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
-    def idx(left, middle, right):
-        dim_idx = torch.arange(-left, middle + right, device=a.device)
-        return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
 
-    return _reflection_or_replication_pad(
-        a,
-        padding,
-        idx,
-    )
+  def idx(left, middle, right):
+    dim_idx = torch.arange(-left, middle + right, device=a.device)
+    return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
+
+  return _reflection_or_replication_pad(
+      a,
+      padding,
+      idx,
+  )
+
 
 _try_register(aten.reflection_pad1d, _reflection_pad)
 _try_register(aten.reflection_pad2d, _reflection_pad)
 _try_register(aten.reflection_pad3d, _reflection_pad)
 
+
 @out_wrapper()
 def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
-    def idx(left, middle, right):
-        dim_idx = torch.arange(-left, middle + right, device=a.device)
-        return torch.clamp(dim_idx, 0, middle - 1)
 
-    return _reflection_or_replication_pad(
-        a,
-        padding,
-        idx,
-    )
-
-decomp.global_decomposition_table['post_autograd'][aten.replication_pad2d.default] = _replication_pad
+  def idx(left, middle, right):
+    dim_idx = torch.arange(-left, middle + right, device=a.device)
+    return torch.clamp(dim_idx, 0, middle - 1)
+
+  return _reflection_or_replication_pad(
+      a,
+      padding,
+      idx,
+  )
+
+
+decomp.global_decomposition_table["post_autograd"][
+    aten.replication_pad2d.default] = _replication_pad
 
 
 def _reflection_or_replication_pad(
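The idx helper inside _reflection_pad computes mirrored source indices in closed form rather than by looping. A small standalone check of that formula (illustrative values, not part of the package):

    import torch

    # Reflection indices for a dimension of size middle=4, padded by
    # left=2 and right=2: expect [2, 1, 0, 1, 2, 3, 2, 1].
    left, middle, right = 2, 4, 2
    dim_idx = torch.arange(-left, middle + right)
    print((middle - 1 - (middle - 1 - dim_idx.abs()).abs()).tolist())
    # -> [2, 1, 0, 1, 2, 3, 2, 1]

The surrounding _try_register calls wrap register_decomposition in a bare try/except, presumably so the module still imports when registration fails, for example on a torch build where the op already has a decomposition registered.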
@@ -69,64 +76,70 @@ def _reflection_or_replication_pad(
     padding: Tuple[int, ...],
     idx_fn: Callable[[int, int, int], Tensor],
 ) -> Tensor:
-    dim = len(padding) // 2
-    torch._check(
-        a.dim() in (dim + 1, dim + 2),
-        lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
-    )
-    inp_shape = a.shape[-dim:]
-    nc_dim = a.dim() - dim
-
-    padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
-    padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
-
-    result = a
-    for i in range(dim):
-        idx: List[Any] = [None] * result.dim()
-        idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
-        result = aten._unsafe_index(result, idx)
-
-    # convert output to correct memory format, if necessary
-    memory_format = utils.suggest_memory_format(result)
-    result = result.contiguous(memory_format=memory_format)
-    return result
+  dim = len(padding) // 2
+  torch._check(
+      a.dim() in (dim + 1, dim + 2),
+      lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
+  )
+  inp_shape = a.shape[-dim:]
+  nc_dim = a.dim() - dim
+
+  padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
+  padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
+
+  result = a
+  for i in range(dim):
+    idx: List[Any] = [None] * result.dim()
+    idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
+    result = aten._unsafe_index(result, idx)
+
+  # convert output to correct memory format, if necessary
+  memory_format = utils.suggest_memory_format(result)
+  result = result.contiguous(memory_format=memory_format)
+  return result
+
 
 _try_register(aten.replication_pad1d, _replication_pad)
 _try_register(aten.replication_pad3d, _replication_pad)
 
+
 def bernoulli(self, *, generator=None):
-    return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
+  return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)
+
 
 _try_register(aten.bernoulli.default, bernoulli)
 
 
 def rand_like(self, **kwargs):
-    dtype = kwargs.get('dtype', self.dtype)
-    return torch.rand(self.shape, dtype=dtype)
+  dtype = kwargs.get("dtype", self.dtype)
+  return torch.rand(self.shape, dtype=dtype)
+
 
 def channel_shuffle(self, groups):
-    batchsize, channels, height, width = self.shape
-    channels_per_group = channels // groups
-    self = self.reshape(batchsize, groups, channels_per_group, height, width)
-    self = self.transpose(1, 2)
-    self = self.reshape(batchsize, channels, height, width)
-    return self
+  batchsize, channels, height, width = self.shape
+  channels_per_group = channels // groups
+  self = self.reshape(batchsize, groups, channels_per_group, height, width)
+  self = self.transpose(1, 2)
+  self = self.reshape(batchsize, channels, height, width)
+  return self
+
 
 _try_register(aten.channel_shuffle, channel_shuffle)
 
 _try_register(aten.bernoulli, bernoulli)
 _try_register(aten.rand_like, rand_like)
 
+
 def bernoulli_float(self, p=0.5):
-    return self.bernoulli_(torch.tensor(p))
+  return self.bernoulli_(p)
+
 
 _try_register(aten.bernoulli_.float, bernoulli_float)
 _try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_)
 
 
-
 def _sum_tensors(ts) -> Tensor:
-    return functools.reduce(torch.add, ts)
+  return functools.reduce(torch.add, ts)
 
 
 @register_decomposition(aten.grid_sampler_3d)
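The one behavioral change in this hunk is bernoulli_float: 0.0.4 wrapped the probability as self.bernoulli_(torch.tensor(p)), which dispatches to the Tensor overload, while 0.0.6 passes the Python float through so aten.bernoulli_.float handles it directly. Both call paths exist on a plain tensor; a quick sketch (not from the diff):

    import torch

    t = torch.empty(3, 3)
    t.bernoulli_(0.5)                      # float overload, the new code path
    t.bernoulli_(torch.full((3, 3), 0.5))  # Tensor overload, the old wrapping
    print(t)

Everything else in the hunk (bernoulli, rand_like, channel_shuffle, _sum_tensors) is re-indented without change.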
@@ -137,145 +150,144 @@ def _grid_sampler_3d(
     padding_mode: int = 0,
     align_corners: bool = False,
 ) -> Tensor:
-    """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
-
-    The above implement the 2d case.
-    """
-    _expand_grid = False
-    torch._check(
-        interpolation_mode in (0, 1),
-        lambda: f"Invalid interpolation mode {interpolation_mode}",
-    )
-    torch._check(
-        padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
-    )
-
-    # a is 5D: [B, C, D, H, W]
-
-    def unnormalize(coords: Tensor, size: int) -> Tensor:
-        # Rescale coordinates from [-1, 1] to:
-        #   [0, size - 1] if align_corners is True
-        #   [-.5, size -.5] if align_corners is False
-        mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
-        ofs = size * 0.5 - 0.5
-        return coords * mul + ofs
-
-    # Reflects coordinates until they fall between low and high (inclusive).
-    # The bounds are passed as twice their value so that half-integer values
-    # can be represented as ints.
-    def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
-        if twice_low == twice_high:
-            return torch.zeros_like(coords)
-        coords_min = twice_low / 2
-        coords_span = (twice_high - twice_low) / 2
-        coords2 = (coords - coords_min).abs()
-        extra = torch.fmod(coords2, coords_span)
-        flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
-        return torch.where(
-            flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
-        )
-
-    def compute_coordinates(coords: Tensor, size: int) -> Tensor:
-        if padding_mode == 0:  # Zero
-            return coords
-        elif padding_mode == 1:  # Borders
-            return torch.clamp(coords, 0, size - 1)
-        else:  # padding_mode == 2, Reflection
-            if align_corners:
-                coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
-            else:
-                coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
-            return torch.clamp(coords_reflected, 0, size - 1)
-
-    def compute_source_index(coords: Tensor, size: int) -> Tensor:
-        coords_un = unnormalize(coords, size)
-        return compute_coordinates(coords_un, size)
-
-    N, C, iD, iH, iW = a.shape
-    _, oD, oH, oW, three = grid.shape
-    assert three == 3, 'Last dim of grid must be 3. got {}'.format(three)
-
-
-    def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
-        xcheck = torch.logical_and(0 <= xs, xs < iW)
-        ycheck = torch.logical_and(0 <= ys, ys < iH)
-        zcheck = torch.logical_and(0 <= zs, zs < iD)
-        return torch.logical_and(
-            xcheck, torch.logical_and(ycheck, zcheck)
-        )
-
-    N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
-    C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)
-
-    def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
-        cond = in_bounds_cond(xs, ys, zs)
-        # To clip to inside valid coordinates, we map the coordinates
-        # to (x, y) = (0, 0) and also set the weight to 0
-        # We also change the shape of the tensor to the appropriate one for
-        # broadcasting with N_idx, C_idx for the purposes of advanced indexing
-        c = C if _expand_grid else 1
-        return tuple(
-            torch.where(cond, t, 0).view(N, c, oD, oH, oW)
-            for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), zs.to(dtype=torch.int64), ws)
-        )
-
-    def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
-        # Perform clipping, index into input tensor and multiply by weight
-        idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
-        return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_
-
-    x = grid[..., 0]
-    y = grid[..., 1]
-    d = grid[..., 2]
-
-    if interpolation_mode == 0:  # Bilinear
-        ix = compute_source_index(x, iW)
-        iy = compute_source_index(y, iH)
-        id_ = compute_source_index(d, iD)
-
-        ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
-        ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
-        ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
-        ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
-        ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
-        ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
-        ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
-        ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1
-
-        w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
-        w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
-        w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
-        w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
-        w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
-        w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
-        w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
-        w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)
-
-        return _sum_tensors(
-            get_summand(ix, iy, id_, w)
-            for (ix, iy, id_, w) in (
-                (ix_nwf, iy_nwf, id_nwf, w_nwf),
-                (ix_nef, iy_nef, id_nef, w_nef),
-                (ix_swf, iy_swf, id_swf, w_swf),
-                (ix_sef, iy_sef, id_sef, w_sef),
-                (ix_nwb, iy_nwb, id_nwb, w_nwb),
-                (ix_neb, iy_neb, id_neb, w_neb),
-                (ix_swb, iy_swb, id_swb, w_swb),
-                (ix_seb, iy_seb, id_seb, w_seb),
-            )
-        )
-    else:  # interpolation_mode == 1: # Nearest
-        ix = compute_source_index(x, iW)
-        iy = compute_source_index(y, iH)
-        iz = compute_source_index(d, iD)
-
-        ix_nearest = ix.round()
-        iy_nearest = iy.round()
-        iz_nearest = iz.round()
-
-        return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
-
-EXTRA_DECOMP = decomp.get_decompositions([
+  """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
+
+  The above implement the 2d case.
+  """
+  _expand_grid = False
+  torch._check(
+      interpolation_mode in (0, 1),
+      lambda: f"Invalid interpolation mode {interpolation_mode}",
+  )
+  torch._check(
+      padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}")
+
+  # a is 5D: [B, C, D, H, W]
+
+  def unnormalize(coords: Tensor, size: int) -> Tensor:
+    # Rescale coordinates from [-1, 1] to:
+    #   [0, size - 1] if align_corners is True
+    #   [-.5, size -.5] if align_corners is False
+    mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
+    ofs = size * 0.5 - 0.5
+    return coords * mul + ofs
+
+  # Reflects coordinates until they fall between low and high (inclusive).
+  # The bounds are passed as twice their value so that half-integer values
+  # can be represented as ints.
+  def reflect_coordinates(coords: Tensor, twice_low: int,
+                          twice_high: int) -> Tensor:
+    if twice_low == twice_high:
+      return torch.zeros_like(coords)
+    coords_min = twice_low / 2
+    coords_span = (twice_high - twice_low) / 2
+    coords2 = (coords - coords_min).abs()
+    extra = torch.fmod(coords2, coords_span)
+    flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
+    return torch.where(flips & 1 == 0, extra + coords_min,
+                       coords_span + coords_min - extra)
+
+  def compute_coordinates(coords: Tensor, size: int) -> Tensor:
+    if padding_mode == 0:  # Zero
+      return coords
+    elif padding_mode == 1:  # Borders
+      return torch.clamp(coords, 0, size - 1)
+    else:  # padding_mode == 2, Reflection
+      if align_corners:
+        coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
+      else:
+        coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
+      return torch.clamp(coords_reflected, 0, size - 1)
+
+  def compute_source_index(coords: Tensor, size: int) -> Tensor:
+    coords_un = unnormalize(coords, size)
+    return compute_coordinates(coords_un, size)
+
+  N, C, iD, iH, iW = a.shape
+  _, oD, oH, oW, three = grid.shape
+  assert three == 3, "Last dim of grid must be 3. got {}".format(three)
+
+  def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
+    xcheck = torch.logical_and(0 <= xs, xs < iW)
+    ycheck = torch.logical_and(0 <= ys, ys < iH)
+    zcheck = torch.logical_and(0 <= zs, zs < iD)
+    return torch.logical_and(xcheck, torch.logical_and(ycheck, zcheck))
+
+  N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
+  C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)
+
+  def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
+    cond = in_bounds_cond(xs, ys, zs)
+    # To clip to inside valid coordinates, we map the coordinates
+    # to (x, y) = (0, 0) and also set the weight to 0
+    # We also change the shape of the tensor to the appropriate one for
+    # broadcasting with N_idx, C_idx for the purposes of advanced indexing
+    c = C if _expand_grid else 1
+    return tuple(
+        torch.where(cond, t, 0).view(N, c, oD, oH, oW) for t in (
+            xs.to(dtype=torch.int64),
+            ys.to(dtype=torch.int64),
+            zs.to(dtype=torch.int64),
+            ws,
+        ))
+
+  def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
+                  w) -> Tensor:
+    # Perform clipping, index into input tensor and multiply by weight
+    idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
+    return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_
+
+  x = grid[..., 0]
+  y = grid[..., 1]
+  d = grid[..., 2]
+
+  if interpolation_mode == 0:  # Bilinear
+    ix = compute_source_index(x, iW)
+    iy = compute_source_index(y, iH)
+    id_ = compute_source_index(d, iD)
+
+    ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
+    ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
+    ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
+    ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
+    ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
+    ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
+    ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
+    ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1
+
+    w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
+    w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
+    w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
+    w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
+    w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
+    w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
+    w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
+    w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)
+
+    return _sum_tensors(
+        get_summand(ix, iy, id_, w) for (ix, iy, id_, w) in (
+            (ix_nwf, iy_nwf, id_nwf, w_nwf),
+            (ix_nef, iy_nef, id_nef, w_nef),
+            (ix_swf, iy_swf, id_swf, w_swf),
+            (ix_sef, iy_sef, id_sef, w_sef),
+            (ix_nwb, iy_nwb, id_nwb, w_nwb),
+            (ix_neb, iy_neb, id_neb, w_neb),
+            (ix_swb, iy_swb, id_swb, w_swb),
+            (ix_seb, iy_seb, id_seb, w_seb),
+        ))
+  else:  # interpolation_mode == 1: # Nearest
+    ix = compute_source_index(x, iW)
+    iy = compute_source_index(y, iH)
+    iz = compute_source_index(d, iD)
+
+    ix_nearest = ix.round()
+    iy_nearest = iy.round()
+    iz_nearest = iz.round()
+
+    return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
+
+
+DECOMPOSITIONS = decomp.get_decompositions([
     torch.ops.aten.upsample_bicubic2d,
     torch.ops.aten.upsample_nearest1d,
     torch.ops.aten.upsample_nearest2d,
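For context, aten.grid_sampler_3d is the kernel behind torch.nn.functional.grid_sample on 5-D inputs; the decomposition above reimplements it with plain indexing and arithmetic, presumably so torchax can lower it to backends without the native kernel. The shapes involved, matching the code's assert (example values are illustrative):

    import torch
    import torch.nn.functional as F

    a = torch.randn(1, 2, 4, 4, 4)            # input [N, C, D, H, W]
    grid = torch.rand(1, 3, 3, 3, 3) * 2 - 1  # [N, oD, oH, oW, 3], coords in [-1, 1]
    out = F.grid_sample(a, grid, mode="bilinear", padding_mode="zeros",
                        align_corners=False)
    print(out.shape)  # torch.Size([1, 2, 3, 3, 3])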
@@ -305,4 +317,460 @@ EXTRA_DECOMP = decomp.get_decompositions([
     torch.ops.aten.bernoulli_.Tensor,
     torch.ops.aten.bernoulli_.float,
     torch.ops.aten.log_normal,
+    torch.ops.aten.addcdiv.default,
+    torch.ops.aten.addcdiv.out,
+    torch.ops.aten.addcdiv_.default,
+    torch.ops.aten.addcmul.default,
+    torch.ops.aten.addcmul.out,
+    torch.ops.aten.addcmul_.default,
+    torch.ops.aten.addr.default,
+    torch.ops.aten.addr.out,
+    torch.ops.aten.affine_grid_generator.default,
+    torch.ops.aten.affine_grid_generator.out,
+    torch.ops.aten.alias_copy.default,
+    torch.ops.aten.alias_copy.out,
+    torch.ops.aten.all.default,
+    torch.ops.aten.all.dim,
+    torch.ops.aten.all.dims,
+    torch.ops.aten.all.out,
+    torch.ops.aten.all.dims_out,
+    torch.ops.aten.all.all_out,
+    torch.ops.aten.all.dimname,
+    torch.ops.aten.all.dimname_out,
+    torch.ops.aten.aminmax.default,
+    torch.ops.aten.aminmax.out,
+    torch.ops.aten.arange.default,
+    torch.ops.aten.arange.start,
+    torch.ops.aten.baddbmm.default,
+    torch.ops.aten.baddbmm.out,
+    torch.ops.aten.binary_cross_entropy.default,
+    torch.ops.aten.binary_cross_entropy.out,
+    torch.ops.aten.binary_cross_entropy_backward.default,
+    torch.ops.aten.binary_cross_entropy_backward.grad_input,
+    torch.ops.aten.binary_cross_entropy_with_logits.default,
+    torch.ops.aten.binary_cross_entropy_with_logits.out,
+    torch.ops.aten.block_diag.default,
+    torch.ops.aten.block_diag.out,
+    torch.ops.aten.celu.default,
+    torch.ops.aten.celu.out,
+    torch.ops.aten.celu_.default,
+    torch.ops.aten.channel_shuffle.default,
+    torch.ops.aten.channel_shuffle.out,
+    torch.ops.aten.clamp_max.default,
+    torch.ops.aten.clamp_max.Tensor,
+    torch.ops.aten.clamp_max.out,
+    torch.ops.aten.clamp_max.Tensor_out,
+    torch.ops.aten.clamp_min.default,
+    torch.ops.aten.clamp_min.Tensor,
+    torch.ops.aten.clamp_min.out,
+    torch.ops.aten.clamp_min.Tensor_out,
+    torch.ops.aten.col2im.default,
+    torch.ops.aten.col2im.out,
+    torch.ops.aten.count_nonzero.dim_IntList,
+    torch.ops.aten.count_nonzero.dim_IntList_out,
+    torch.ops.aten.count_nonzero.default,
+    torch.ops.aten.count_nonzero.out,
+    torch.ops.aten.linalg_cross.default,
+    torch.ops.aten.linalg_cross.out,
+    torch.ops.aten.cudnn_batch_norm.default,
+    torch.ops.aten.cudnn_batch_norm.out,
+    torch.ops.aten.cudnn_batch_norm_backward.default,
+    torch.ops.aten.cudnn_batch_norm_backward.out,
+    torch.ops.aten.miopen_batch_norm_backward.default,
+    torch.ops.aten.miopen_batch_norm_backward.out,
+    torch.ops.aten.deg2rad.default,
+    torch.ops.aten.deg2rad.out,
+    torch.ops.aten.deg2rad_.default,
+    torch.ops.aten.detach.default,
+    torch.ops.aten.diag_embed.default,
+    torch.ops.aten.diag_embed.out,
+    torch.ops.aten.diagonal_backward.default,
+    torch.ops.aten.diagonal_backward.out,
+    torch.ops.aten.dot.default,
+    torch.ops.aten.dot.out,
+    torch.ops.aten.vdot.default,
+    torch.ops.aten.vdot.out,
+    torch.ops.aten.elu.default,
+    torch.ops.aten.elu.out,
+    torch.ops.aten.elu_.default,
+    torch.ops.aten.elu_backward.default,
+    torch.ops.aten.elu_backward.grad_input,
+    torch.ops.aten.embedding_dense_backward.default,
+    torch.ops.aten.embedding_dense_backward.out,
+    torch.ops.aten.empty_like.default,
+    torch.ops.aten.empty_like.out,
+    torch.ops.aten._euclidean_dist.default,
+    torch.ops.aten.expand_copy.default,
+    torch.ops.aten.expand_copy.out,
+    torch.ops.aten.eye.default,
+    torch.ops.aten.eye.m,
+    torch.ops.aten.eye.out,
+    torch.ops.aten.eye.m_out,
+    torch.ops.aten.fill.Scalar,
+    torch.ops.aten.fill.Tensor,
+    torch.ops.aten.fill_.Scalar,
+    torch.ops.aten.fill_.Tensor,
+    torch.ops.aten.floor_divide.default,
+    torch.ops.aten.floor_divide.Scalar,
+    torch.ops.aten.floor_divide.out,
+    torch.ops.aten.floor_divide.Scalar_out,
+    torch.ops.aten.frac.default,
+    torch.ops.aten.frac.out,
+    torch.ops.aten.frac_.default,
+    torch.ops.aten.gelu_.default,
+    torch.ops.aten.gelu_backward.default,
+    torch.ops.aten.gelu_backward.grad_input,
+    torch.ops.aten.glu.default,
+    torch.ops.aten.glu.out,
+    torch.ops.aten.glu_backward.default,
+    torch.ops.aten.glu_backward.grad_input,
+    torch.ops.aten.hardshrink.default,
+    torch.ops.aten.hardshrink.out,
+    torch.ops.aten.hardsigmoid.default,
+    torch.ops.aten.hardsigmoid.out,
+    torch.ops.aten.hardsigmoid_.default,
+    torch.ops.aten.hardsigmoid_backward.default,
+    torch.ops.aten.hardsigmoid_backward.grad_input,
+    torch.ops.aten.hardswish.default,
+    torch.ops.aten.hardswish.out,
+    torch.ops.aten.hardswish_.default,
+    torch.ops.aten.hardswish_backward.default,
+    torch.ops.aten.hardswish_backward.out,
+    torch.ops.aten.hardtanh_.default,
+    torch.ops.aten.hardtanh_backward.default,
+    torch.ops.aten.hardtanh_backward.grad_input,
+    torch.ops.aten.heaviside.default,
+    torch.ops.aten.heaviside.out,
+    torch.ops.aten.heaviside_.default,
+    torch.ops.aten.huber_loss.default,
+    torch.ops.aten.huber_loss.out,
+    torch.ops.aten.huber_loss_backward.default,
+    torch.ops.aten.huber_loss_backward.out,
+    torch.ops.aten.im2col.default,
+    torch.ops.aten.im2col.out,
+    torch.ops.aten.index_add.default,
+    torch.ops.aten.index_add.out,
+    torch.ops.aten.index_add.dimname,
+    torch.ops.aten.index_add_.default,
+    torch.ops.aten.index_copy.default,
+    torch.ops.aten.index_copy.dimname,
+    torch.ops.aten.index_copy.out,
+    torch.ops.aten.index_copy_.default,
+    torch.ops.aten.index_copy_.dimname,
+    torch.ops.aten.index_fill.int_Tensor,
+    torch.ops.aten.index_fill.int_Scalar,
+    torch.ops.aten.index_fill.Dimname_Scalar,
+    torch.ops.aten.index_fill.Dimname_Tensor,
+    torch.ops.aten.index_fill.int_Scalar_out,
+    torch.ops.aten.index_fill.int_Tensor_out,
+    torch.ops.aten.index_fill_.int_Tensor,
+    torch.ops.aten.index_fill_.int_Scalar,
+    torch.ops.aten.index_fill_.Dimname_Scalar,
+    torch.ops.aten.index_fill_.Dimname_Tensor,
+    torch.ops.aten.isin.Tensor_Tensor,
+    torch.ops.aten.isin.Tensor_Tensor_out,
+    torch.ops.aten.isin.Tensor_Scalar,
+    torch.ops.aten.isin.Tensor_Scalar_out,
+    torch.ops.aten.isin.Scalar_Tensor,
+    torch.ops.aten.isin.Scalar_Tensor_out,
+    torch.ops.aten.isneginf.default,
+    torch.ops.aten.isneginf.out,
+    torch.ops.aten.isposinf.default,
+    torch.ops.aten.isposinf.out,
+    torch.ops.aten.leaky_relu_.default,
+    torch.ops.aten.leaky_relu_backward.default,
+    torch.ops.aten.leaky_relu_backward.grad_input,
+    torch.ops.aten.lerp.Scalar,
+    torch.ops.aten.lerp.Tensor,
+    torch.ops.aten.lerp.Scalar_out,
+    torch.ops.aten.lerp.Tensor_out,
+    torch.ops.aten.lerp_.Scalar,
+    torch.ops.aten.lerp_.Tensor,
+    torch.ops.aten.linspace.Tensor_Tensor,
+    torch.ops.aten.linspace.Tensor_Scalar,
+    torch.ops.aten.linspace.Scalar_Tensor,
+    torch.ops.aten.linspace.default,
+    torch.ops.aten.linspace.out,
+    torch.ops.aten.linspace.Tensor_Tensor_out,
+    torch.ops.aten.linspace.Tensor_Scalar_out,
+    torch.ops.aten.linspace.Scalar_Tensor_out,
+    torch.ops.aten.logaddexp.default,
+    torch.ops.aten.logaddexp.out,
+    torch.ops.aten.logaddexp2.default,
+    torch.ops.aten.logaddexp2.out,
+    torch.ops.aten.logit.default,
+    torch.ops.aten.logit.out,
+    torch.ops.aten.logit_.default,
+    torch.ops.aten.logit_backward.default,
+    torch.ops.aten.log_sigmoid_backward.default,
+    torch.ops.aten.log_sigmoid_backward.grad_input,
+    torch.ops.aten.log_sigmoid_forward.default,
+    torch.ops.aten.log_sigmoid_forward.output,
+    torch.ops.aten._log_softmax_backward_data.default,
+    torch.ops.aten._log_softmax_backward_data.out,
+    torch.ops.aten.logspace.Tensor_Tensor,
+    torch.ops.aten.logspace.Tensor_Scalar,
+    torch.ops.aten.logspace.Scalar_Tensor,
+    torch.ops.aten.logspace.default,
+    torch.ops.aten.logspace.out,
+    torch.ops.aten.logspace.Tensor_Tensor_out,
+    torch.ops.aten.logspace.Tensor_Scalar_out,
+    torch.ops.aten.logspace.Scalar_Tensor_out,
+    torch.ops.aten.logsumexp.default,
+    torch.ops.aten.masked_fill.Scalar,
+    torch.ops.aten.masked_fill.Tensor,
+    torch.ops.aten.masked_fill.Scalar_out,
+    torch.ops.aten.masked_fill.Tensor_out,
+    torch.ops.aten.masked_fill_.Scalar,
+    torch.ops.aten.masked_fill_.Tensor,
+    torch.ops.aten.mish.default,
+    torch.ops.aten.mish.out,
+    torch.ops.aten.mish_.default,
+    torch.ops.aten.mse_loss.default,
+    torch.ops.aten.mse_loss.out,
+    torch.ops.aten.mse_loss_backward.default,
+    torch.ops.aten.mse_loss_backward.grad_input,
+    torch.ops.aten.multi_margin_loss.default,
+    torch.ops.aten.multi_margin_loss.out,
+    torch.ops.aten.multilabel_margin_loss_forward.default,
+    torch.ops.aten.multilabel_margin_loss_forward.output,
+    torch.ops.aten.mv.default,
+    torch.ops.aten.mv.out,
+    torch.ops.aten.mvlgamma.default,
+    torch.ops.aten.mvlgamma.out,
+    torch.ops.aten.mvlgamma_.default,
+    torch.ops.aten.nansum.default,
+    torch.ops.aten.nansum.out,
+    torch.ops.aten.nan_to_num.default,
+    torch.ops.aten.nan_to_num.out,
+    torch.ops.aten.nan_to_num_.default,
+    torch.ops.aten.native_batch_norm_backward.default,
+    torch.ops.aten.native_batch_norm_backward.out,
+    torch.ops.aten.native_dropout_backward.default,
+    torch.ops.aten.native_dropout_backward.out,
+    torch.ops.aten.native_group_norm_backward.default,
+    torch.ops.aten.native_group_norm_backward.out,
+    torch.ops.aten.native_layer_norm_backward.default,
+    torch.ops.aten.native_layer_norm_backward.out,
+    torch.ops.aten.new_empty.default,
+    torch.ops.aten.new_empty.out,
+    torch.ops.aten.new_full.default,
+    torch.ops.aten.new_full.out,
+    torch.ops.aten.new_ones.default,
+    torch.ops.aten.new_ones.out,
+    torch.ops.aten.new_zeros.default,
+    torch.ops.aten.new_zeros.out,
+    torch.ops.aten.nll_loss2d_forward.default,
+    torch.ops.aten.nll_loss2d_forward.output,
+    torch.ops.aten.nll_loss2d_backward.default,
+    torch.ops.aten.nll_loss2d_backward.grad_input,
+    torch.ops.aten.nll_loss_backward.default,
+    torch.ops.aten.nll_loss_backward.grad_input,
+    torch.ops.aten.nll_loss_forward.default,
+    torch.ops.aten.nll_loss_forward.output,
+    torch.ops.aten.norm.Scalar,
+    torch.ops.aten.norm.ScalarOpt_dim,
+    torch.ops.aten.norm.names_ScalarOpt_dim,
+    torch.ops.aten.norm.ScalarOpt_dim_dtype,
+    torch.ops.aten.norm.dtype_out,
+    torch.ops.aten.norm.out,
+    torch.ops.aten.norm.ScalarOpt_dtype,
+    torch.ops.aten.norm.ScalarOpt_dtype_out,
+    torch.ops.aten.norm.Scalar_out,
+    torch.ops.aten.norm.names_ScalarOpt_dim_dtype,
+    torch.ops.aten.norm.names_dtype_out,
+    torch.ops.aten.norm.names_out,
+    torch.ops.aten.ones.default,
+    torch.ops.aten.ones_like.default,
+    torch.ops.aten.ones_like.out,
+    torch.ops.aten.pixel_shuffle.default,
+    torch.ops.aten.pixel_shuffle.out,
+    torch.ops.aten.pixel_unshuffle.default,
+    torch.ops.aten.pixel_unshuffle.out,
+    torch.ops.aten._prelu_kernel.default,
+    torch.ops.aten._prelu_kernel_backward.default,
+    torch.ops.aten._reshape_alias.default,
+    torch.ops.aten.rad2deg.default,
+    torch.ops.aten.rad2deg.out,
+    torch.ops.aten.rad2deg_.default,
+    torch.ops.aten.reflection_pad1d.default,
+    torch.ops.aten.reflection_pad1d.out,
+    torch.ops.aten.reflection_pad1d_backward.default,
+    torch.ops.aten.reflection_pad1d_backward.grad_input,
+    torch.ops.aten.reflection_pad2d.default,
+    torch.ops.aten.reflection_pad2d.out,
+    torch.ops.aten.reflection_pad2d_backward.default,
+    torch.ops.aten.reflection_pad2d_backward.grad_input,
+    torch.ops.aten.reflection_pad3d.default,
+    torch.ops.aten.reflection_pad3d.out,
+    torch.ops.aten.reflection_pad3d_backward.default,
+    torch.ops.aten.reflection_pad3d_backward.grad_input,
+    torch.ops.aten.replication_pad1d.default,
+    torch.ops.aten.replication_pad1d.out,
+    torch.ops.aten.replication_pad2d.default,
+    torch.ops.aten.replication_pad2d.out,
+    torch.ops.aten.replication_pad3d.default,
+    torch.ops.aten.replication_pad3d.out,
+    torch.ops.aten.renorm.default,
+    torch.ops.aten.renorm.out,
+    torch.ops.aten.renorm_.default,
+    torch.ops.aten.resize_as.default,
+    torch.ops.aten.resize_as.out,
+    torch.ops.aten.roll.default,
+    torch.ops.aten.roll.out,
+    torch.ops.aten.rot90.default,
+    torch.ops.aten.rot90.out,
+    torch.ops.aten.rrelu_with_noise.default,
+    torch.ops.aten.rrelu_with_noise.out,
+    torch.ops.aten.rrelu_with_noise_.default,
+    torch.ops.aten.rsub.Tensor,
+    torch.ops.aten.rsub.Scalar,
+    torch.ops.aten.rsub.Tensor_out,
+    torch.ops.aten.rsub.Scalar_out,
+    torch.ops.aten._safe_softmax.default,
+    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default,
+    torch.ops.aten.select_backward.default,
+    torch.ops.aten.select_backward.out,
+    torch.ops.aten.select_scatter.default,
+    torch.ops.aten.select_scatter.out,
+    torch.ops.aten.sgn.default,
+    torch.ops.aten.sgn.out,
+    torch.ops.aten.sgn_.default,
+    torch.ops.aten.sigmoid_backward.default,
+    torch.ops.aten.sigmoid_backward.grad_input,
+    torch.ops.aten.silu.default,
+    torch.ops.aten.silu.out,
+    torch.ops.aten.silu_.default,
+    torch.ops.aten.silu_backward.default,
+    torch.ops.aten.silu_backward.grad_input,
+    torch.ops.aten.sinc.default,
+    torch.ops.aten.sinc.out,
+    torch.ops.aten.sinc_.default,
+    torch.ops.aten.slice_backward.default,
+    torch.ops.aten.slice_backward.out,
+    torch.ops.aten.smooth_l1_loss.default,
+    torch.ops.aten.smooth_l1_loss.out,
+    torch.ops.aten.smooth_l1_loss_backward.default,
+    torch.ops.aten.smooth_l1_loss_backward.grad_input,
+    torch.ops.aten.soft_margin_loss.default,
+    torch.ops.aten.soft_margin_loss.out,
+    torch.ops.aten.soft_margin_loss_backward.default,
+    torch.ops.aten.soft_margin_loss_backward.grad_input,
+    torch.ops.aten._softmax_backward_data.default,
+    torch.ops.aten._softmax_backward_data.out,
+    torch.ops.aten.softplus.default,
+    torch.ops.aten.softplus.out,
+    torch.ops.aten.softplus_backward.default,
+    torch.ops.aten.softplus_backward.grad_input,
+    torch.ops.aten.softshrink.default,
+    torch.ops.aten.softshrink.out,
+    torch.ops.aten.special_entr.default,
+    torch.ops.aten.special_entr.out,
+    torch.ops.aten.special_log_ndtr.default,
+    torch.ops.aten.special_log_ndtr.out,
+    torch.ops.aten.special_xlog1py.default,
+    torch.ops.aten.special_xlog1py.other_scalar,
+    torch.ops.aten.special_xlog1py.self_scalar,
+    torch.ops.aten.special_xlog1py.out,
+    torch.ops.aten.special_xlog1py.self_scalar_out,
+    torch.ops.aten.special_xlog1py.other_scalar_out,
+    torch.ops.aten.split.Tensor,
+    torch.ops.aten.split_with_sizes_copy.default,
+    torch.ops.aten.split_with_sizes_copy.out,
+    torch.ops.aten.squeeze.default,
+    torch.ops.aten.squeeze.dim,
+    torch.ops.aten.std.default,
+    torch.ops.aten.std.dim,
+    torch.ops.aten.std.correction,
+    torch.ops.aten.std.names_dim,
+    torch.ops.aten.std.names_out,
+    torch.ops.aten.std.out,
+    torch.ops.aten.std.correction_out,
+    torch.ops.aten.std.correction_names,
+    torch.ops.aten.std.correction_names_out,
+    torch.ops.aten.std_mean.default,
+    torch.ops.aten.std_mean.dim,
+    torch.ops.aten.std_mean.correction,
+    torch.ops.aten.std_mean.names_dim,
+    torch.ops.aten.std_mean.correction_names,
+    torch.ops.aten.std_mean.correction_out,
+    torch.ops.aten.stack.default,
+    torch.ops.aten.stack.out,
+    torch.ops.aten.sum.default,
+    torch.ops.aten.sum.out,
+    torch.ops.aten.t.default,
+    torch.ops.aten.t_copy.out,
+    torch.ops.aten.t_copy.default,
+    torch.ops.aten.take.default,
+    torch.ops.aten.take.out,
+    torch.ops.aten.tanh_backward.default,
+    torch.ops.aten.tanh_backward.grad_input,
+    torch.ops.aten.threshold.default,
+    torch.ops.aten.threshold.out,
+    torch.ops.aten.threshold_.default,
+    torch.ops.aten.threshold_backward.default,
+    torch.ops.aten.threshold_backward.grad_input,
+    torch.ops.aten.trace.default,
+    torch.ops.aten.trace.out,
+    torch.ops.aten.transpose.int,
+    torch.ops.aten.tril.default,
+    torch.ops.aten.tril.out,
+    torch.ops.aten.tril_.default,
+    torch.ops.aten.triu.default,
+    torch.ops.aten.triu.out,
+    torch.ops.aten.triu_.default,
+    torch.ops.aten.unbind.int,
+    torch.ops.aten.unbind.Dimname,
+    torch.ops.aten.unfold_backward.default,
+    torch.ops.aten.unfold_backward.out,
+    torch.ops.aten.unfold_copy.default,
+    torch.ops.aten.unfold_copy.out,
+    torch.ops.aten._unsafe_index.Tensor,
+    torch.ops.aten._unsafe_index_put.default,
+    torch.ops.aten._unsafe_masked_index.default,
+    torch.ops.aten._unsafe_masked_index_put_accumulate.default,
+    torch.ops.aten.unsafe_split.Tensor,
+    torch.ops.aten.unsafe_split_with_sizes.default,
+    torch.ops.aten.unsqueeze_copy.out,
+    torch.ops.aten.unsqueeze_copy.default,
+    torch.ops.aten._unsafe_view.default,
+    torch.ops.aten._unsafe_view.out,
+    torch.ops.aten.upsample_linear1d.default,
+    torch.ops.aten.upsample_linear1d.out,
+    torch.ops.aten.upsample_bilinear2d.vec,
+    torch.ops.aten.upsample_bilinear2d.default,
+    torch.ops.aten.upsample_bilinear2d.out,
+    torch.ops.aten.upsample_trilinear3d.vec,
+    torch.ops.aten.upsample_trilinear3d.default,
+    torch.ops.aten.upsample_trilinear3d.out,
+    torch.ops.aten.xlogy.Tensor,
+    torch.ops.aten.xlogy.Scalar_Other,
+    torch.ops.aten.xlogy.Scalar_Self,
+    torch.ops.aten.xlogy.OutTensor,
+    torch.ops.aten.xlogy.OutScalar_Self,
+    torch.ops.aten.xlogy.OutScalar_Other,
+    torch.ops.aten.xlogy_.Tensor,
+    torch.ops.aten.xlogy_.Scalar_Other,
+    torch.ops.aten.zero.default,
+    torch.ops.aten.zero.out,
+    torch.ops.aten.zero_.default,
+    torch.ops.aten.zeros.default,
+    torch.ops.aten.zeros_like.default,
+    torch.ops.aten.zeros_like.out,
+    torch.ops.aten._chunk_cat.default,
+    torch.ops.aten._chunk_cat.out,
+    torch.ops.aten._weight_norm_interface.default,
+    torch.ops.aten._weight_norm_interface.out,
+    torch.ops.aten.__iand__.Tensor,
+    torch.ops.aten.__ixor__.Tensor,
+    torch.ops.aten.__ilshift__.Tensor,
+    torch.ops.aten.__ilshift__.Scalar,
+    torch.ops.aten.__irshift__.Tensor,
+    torch.ops.aten.__irshift__.Scalar,
+    torch.ops.aten.__ior__.Tensor,
 ])
+
+MUTABLE_DECOMPOSITION = [
+    torch.ops.aten.bernoulli_.Tensor,
+    torch.ops.aten.bernoulli_.float,
+]
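The diff does not show how DECOMPOSITIONS and MUTABLE_DECOMPOSITION are consumed. One plausible sketch, assuming the table is importable from torchax.decompositions and using torch's generic make_fx tracer (which accepts a decomposition_table):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torchax.decompositions import DECOMPOSITIONS  # assumed import path

    def f(x):
      return torch.ops.aten.hardswish(x)

    # Tracing with the table rewrites hardswish into primitive ops.
    gm = make_fx(f, decomposition_table=DECOMPOSITIONS)(torch.randn(4))
    print(gm.code)

The separate MUTABLE_DECOMPOSITION list holds only the two in-place bernoulli overloads, presumably so callers can give mutating ops special handling during functionalization.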