torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

torchax/__init__.py CHANGED
@@ -13,16 +13,21 @@
  # limitations under the License.

  import contextlib
- from typing import List, Dict, Any, Optional
  import dataclasses
- import jax
  import os
+ from contextlib import contextmanager
+ from typing import Any
+
+ import jax
  import torch
  from torch.utils import _pytree as pytree
+
+ import torchax.device_module
  from torchax import tensor
- from contextlib import contextmanager

- __version__ = "0.0.10.dev20251114"
+ from .checkpoint import load_checkpoint, save_checkpoint
+
+ __version__ = "0.0.11.dev202612"
  VERSION = __version__

  # the "fast path" uses some sparse tensor thingies that currently we
@@ -31,123 +36,114 @@ torch.backends.mha.set_fastpath_enabled(False)


  __all__ = [
- "default_env",
- "extract_jax",
- "enable_globally",
- "save_checkpoint",
- "load_checkpoint",
+ "default_env",
+ "extract_jax",
+ "enable_globally",
+ "save_checkpoint",
+ "load_checkpoint",
  ]

- from .checkpoint import save_checkpoint, load_checkpoint

  os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")

  # torchax:oss-begin
  if getattr(jax.config, "jax_pjrt_client_create_options", None):
- jax.config.update(
- "jax_pjrt_client_create_options",
- f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
- )
+ jax.config.update(
+ "jax_pjrt_client_create_options",
+ f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
+ )
  # torchax:oss-end

  env = None


  def default_env():
- global env
+ global env

- if env is None:
- env = tensor.Environment()
- return env
+ if env is None:
+ env = tensor.Environment()
+ return env


  def extract_jax(mod: torch.nn.Module, env=None):
- """Returns a pytree of jax.ndarray and a jax callable."""
- if env is None:
- env = default_env()
- states = dict(mod.named_buffers())
- states.update(mod.named_parameters())
+ """Returns a pytree of jax.ndarray and a jax callable."""
+ if env is None:
+ env = default_env()
+ states = dict(mod.named_buffers())
+ states.update(mod.named_parameters())

- states = env.t2j_copy(states)
+ states = env.t2j_copy(states)

- # @jax.jit
- def jax_func(states, args, kwargs=None):
- (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
- with env:
- res = torch.func.functional_call(
- mod, states, args, kwargs, tie_weights=False
- )
- return env.t2j_iso(res)
+ # @jax.jit
+ def jax_func(states, args, kwargs=None):
+ (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
+ with env:
+ res = torch.func.functional_call(mod, states, args, kwargs, tie_weights=False)
+ return env.t2j_iso(res)

- return states, jax_func
+ return states, jax_func


  def enable_globally():
- env = default_env().enable_torch_modes()
- return env
+ env = default_env().enable_torch_modes()
+ return env


  def disable_globally():
- global env
- default_env().disable_torch_modes()
+ global env
+ default_env().disable_torch_modes()


  @contextlib.contextmanager
  def disable_temporarily():
- prev = default_env().enabled
- if prev:
- disable_globally()
- yield ()
- if prev:
- enable_globally()
+ prev = default_env().enabled
+ if prev:
+ disable_globally()
+ yield ()
+ if prev:
+ enable_globally()


  torch.utils.rename_privateuse1_backend("jax")
  unsupported_dtype = [torch.quint8]

- import jax
- import torchax.device_module

  torch._register_device_module("jax", torchax.device_module)


  def enable_accuracy_mode():
- jax.config.update("jax_enable_x64", True)
- jax.config.update("jax_default_matmul_precision", "highest")
- default_env().config.internal_respect_torch_return_dtypes = True
+ jax.config.update("jax_enable_x64", True)
+ jax.config.update("jax_default_matmul_precision", "highest")
+ default_env().config.internal_respect_torch_return_dtypes = True


  def enable_performance_mode():
- jax.config.update("jax_enable_x64", False)
- jax.config.update("jax_default_matmul_precision", "default")
- default_env().config.internal_respect_torch_return_dtypes = False
+ jax.config.update("jax_enable_x64", False)
+ jax.config.update("jax_default_matmul_precision", "default")
+ default_env().config.internal_respect_torch_return_dtypes = False


  @dataclasses.dataclass
  class CompileOptions:
- # only valid if compiling nn.Module
- methods_to_compile: List[str] = dataclasses.field(
- default_factory=lambda: ["forward"]
- )
- jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
- mode: str = "jax" # or dynamo or export
-
-
- def compile(fn, options: Optional[CompileOptions] = None):
- options = options or CompileOptions()
- if options.mode == "jax":
- from torchax import interop
-
- if isinstance(fn, torch.nn.Module):
- module = interop.JittableModule(
- fn, extra_jit_args=options.jax_jit_kwargs
- )
- for n in options.methods_to_compile:
- module.make_jitted(n)
- return module
- else:
- return interop.jax_jit(fn)
- elif options.mode == "dynamo":
- raise RuntimeError("dynamo mode is not supported yet")
- elif options.mode == "export":
- raise RuntimeError("export mode is not supported yet")
+ # only valid if compiling nn.Module
+ methods_to_compile: list[str] = dataclasses.field(default_factory=lambda: ["forward"])
+ jax_jit_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
+ mode: str = "jax" # or dynamo or export
+
+
+ def compile(fn, options: CompileOptions | None = None):
+ options = options or CompileOptions()
+ if options.mode == "jax":
+ from torchax import interop
+
+ if isinstance(fn, torch.nn.Module):
+ module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
+ for n in options.methods_to_compile:
+ module.make_jitted(n)
+ return module
+ else:
+ return interop.jax_jit(fn)
+ elif options.mode == "dynamo":
+ raise RuntimeError("dynamo mode is not supported yet")
+ elif options.mode == "export":
+ raise RuntimeError("export mode is not supported yet")
torchax/amp.py CHANGED
@@ -14,6 +14,7 @@

  import contextlib
  import enum
+
  import torch
  from torch.utils import _pytree as pytree

@@ -50,23 +51,24 @@ class CastPolicy(enum.Enum):


  def execute_policy(policy, args, kwargs, target_lower_fp):
-
  def is_float(a):
  return isinstance(a, torch.Tensor) and a.is_floating_point()
+
  match policy:
  case CastPolicy.LOWER_PRECISION_FP:
- return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
- (args, kwargs))
+ return pytree.tree_map_only(
+ is_float, lambda a: a.to(target_lower_fp), (args, kwargs)
+ )
  case CastPolicy.FP32:
- return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
- (args, kwargs))
+ return pytree.tree_map_only(
+ is_float, lambda a: a.to(torch.float32), (args, kwargs)
+ )
  case CastPolicy.PROMOTE:
- dtypes = set(a.dtype for a in args)
+ dtypes = {a.dtype for a in args}
  widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
- return pytree.tree_map_only(is_float, lambda a: a.to(widest),
- (args, kwargs))
+ return pytree.tree_map_only(is_float, lambda a: a.to(widest), (args, kwargs))
  case _:
- raise AssertionError(f'Policy {policy} not implemented yet.')
+ raise AssertionError(f"Policy {policy} not implemented yet.")


  @contextlib.contextmanager
@@ -74,6 +76,7 @@ def autocast(device, dtype=torch.bfloat16, env=None):
  del device
  if env is None:
  import torchax
+
  env = torchax.default_env()
  with env.override_property(autocast_dtype=dtype):
  yield
@@ -81,266 +84,135 @@ def autocast(device, dtype=torch.bfloat16, env=None):

  # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
  autocast_policy = {
- torch.ops.aten.conv1d.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv1d.padding:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv2d.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv2d.padding:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv3d.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv3d.padding:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.bmm.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.mm.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.linalg_vecdot.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.baddbmm.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.addmm.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten._addmm_activation.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.addbmm.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.linear.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten._convolution.deprecated:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.matmul.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv_tbc.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.mkldnn_rnn_layer.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv_transpose1d.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv_transpose2d.input:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.conv_transpose3d.input:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.prelu.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten.scaled_dot_product_attention.default:
- CastPolicy.LOWER_PRECISION_FP,
- torch.ops.aten._native_multi_head_attention.default:
- CastPolicy.LOWER_PRECISION_FP,
-
- # fp32 cast policy
- torch.ops.aten.avg_pool3d.default:
- CastPolicy.FP32,
- torch.ops.aten.binary_cross_entropy.default:
- CastPolicy.FP32,
- torch.ops.aten.grid_sampler.default:
- CastPolicy.FP32,
- torch.ops.aten.polar.default:
- CastPolicy.FP32,
- torch.ops.aten.prod.default:
- CastPolicy.FP32,
- torch.ops.aten.prod.dim_int:
- CastPolicy.FP32,
- torch.ops.aten.prod.dim_Dimname:
- CastPolicy.FP32,
- torch.ops.aten.quantile.default:
- CastPolicy.FP32,
- torch.ops.aten.quantile.scalar:
- CastPolicy.FP32,
- torch.ops.aten.nanquantile.default:
- CastPolicy.FP32,
- torch.ops.aten.nanquantile.scalar:
- CastPolicy.FP32,
- torch.ops.aten.stft.default:
- CastPolicy.FP32,
- torch.ops.aten.stft.center:
- CastPolicy.FP32,
- torch.ops.aten.cdist.default:
- CastPolicy.FP32,
- torch.ops.aten.grid_sampler_2d.default:
- CastPolicy.FP32,
- torch.ops.aten._grid_sampler_2d_cpu_fallback.default:
- CastPolicy.FP32,
- torch.ops.aten.grid_sampler_3d.default:
- CastPolicy.FP32,
- torch.ops.aten.trace.default:
- CastPolicy.FP32,
- torch.ops.aten.view_as_complex.default:
- CastPolicy.FP32,
- torch.ops.aten.cholesky.default:
- CastPolicy.FP32,
- torch.ops.aten.cholesky_inverse.default:
- CastPolicy.FP32,
- torch.ops.aten.cholesky_solve.default:
- CastPolicy.FP32,
- torch.ops.aten.inverse.default:
- CastPolicy.FP32,
- torch.ops.aten.lu_solve.default:
- CastPolicy.FP32,
- torch.ops.aten.orgqr.default:
- CastPolicy.FP32,
- torch.ops.aten.ormqr.default:
- CastPolicy.FP32,
- torch.ops.aten.pinverse.default:
- CastPolicy.FP32,
- torch.ops.aten.max_pool3d.default:
- CastPolicy.FP32,
- torch.ops.aten.max_unpool2d.default:
- CastPolicy.FP32,
- torch.ops.aten.max_unpool3d.default:
- CastPolicy.FP32,
- torch.ops.aten.adaptive_avg_pool3d.default:
- CastPolicy.FP32,
- torch.ops.aten.reflection_pad1d.default:
- CastPolicy.FP32,
- torch.ops.aten.reflection_pad2d.default:
- CastPolicy.FP32,
- torch.ops.aten.replication_pad1d.default:
- CastPolicy.FP32,
- torch.ops.aten.replication_pad2d.default:
- CastPolicy.FP32,
- torch.ops.aten.replication_pad3d.default:
- CastPolicy.FP32,
- torch.ops.aten.mse_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.cosine_embedding_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.nll_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.nll_loss2d.default:
- CastPolicy.FP32,
- torch.ops.aten.hinge_embedding_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.poisson_nll_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.smooth_l1_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.cross_entropy_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.l1_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.huber_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.margin_ranking_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.soft_margin_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.triplet_margin_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.multi_margin_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.ctc_loss.IntList:
- CastPolicy.FP32,
- torch.ops.aten.ctc_loss.Tensor:
- CastPolicy.FP32,
- torch.ops.aten.kl_div.default:
- CastPolicy.FP32,
- torch.ops.aten.multilabel_margin_loss.default:
- CastPolicy.FP32,
- torch.ops.aten.binary_cross_entropy_with_logits.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_fft.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_ifft.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_fft2.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_ifft2.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_fftn.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_ifftn.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_rfft.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_irfft.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_rfft2.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_irfft2.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_rfftn.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_irfftn.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_hfft.default:
- CastPolicy.FP32,
- torch.ops.aten.fft_ihfft.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_cond.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_cond.p_str:
- CastPolicy.FP32,
- torch.ops.aten.linalg_matrix_rank.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_matrix_rank.tol_tensor:
- CastPolicy.FP32,
- torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor:
- CastPolicy.FP32,
- torch.ops.aten.linalg_matrix_rank.atol_rtol_float:
- CastPolicy.FP32,
- torch.ops.aten.linalg_solve.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_cholesky.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_svdvals.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_eigvals.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_eigvalsh.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_inv.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_householder_product.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_tensorinv.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_tensorsolve.default:
- CastPolicy.FP32,
- torch.ops.aten.fake_quantize_per_tensor_affine.default:
- CastPolicy.FP32,
- torch.ops.aten.geqrf.default:
- CastPolicy.FP32,
- torch.ops.aten._lu_with_info.default:
- CastPolicy.FP32,
- torch.ops.aten.qr.default:
- CastPolicy.FP32,
- torch.ops.aten.svd.default:
- CastPolicy.FP32,
- torch.ops.aten.triangular_solve.default:
- CastPolicy.FP32,
- torch.ops.aten.fractional_max_pool2d.default:
- CastPolicy.FP32,
- torch.ops.aten.fractional_max_pool3d.default:
- CastPolicy.FP32,
- torch.ops.aten.adaptive_max_pool3d.default:
- CastPolicy.FP32,
- torch.ops.aten.multilabel_margin_loss_forward.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_qr.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_cholesky_ex.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_svd.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_eig.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_eigh.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_lstsq.default:
- CastPolicy.FP32,
- torch.ops.aten.linalg_inv_ex.default:
- CastPolicy.FP32,
-
- # promote
- torch.ops.aten.stack.default:
- CastPolicy.PROMOTE,
- torch.ops.aten.cat.default:
- CastPolicy.PROMOTE,
- torch.ops.aten.index_copy.default:
- CastPolicy.PROMOTE,
- torch.ops.aten.index_copy.dimname:
- CastPolicy.PROMOTE,
+ torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
+ torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,
+ # fp32 cast policy
+ torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
+ torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
+ torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
+ torch.ops.aten.polar.default: CastPolicy.FP32,
+ torch.ops.aten.prod.default: CastPolicy.FP32,
+ torch.ops.aten.prod.dim_int: CastPolicy.FP32,
+ torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
+ torch.ops.aten.quantile.default: CastPolicy.FP32,
+ torch.ops.aten.quantile.scalar: CastPolicy.FP32,
+ torch.ops.aten.nanquantile.default: CastPolicy.FP32,
+ torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
+ torch.ops.aten.stft.default: CastPolicy.FP32,
+ torch.ops.aten.stft.center: CastPolicy.FP32,
+ torch.ops.aten.cdist.default: CastPolicy.FP32,
+ torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
+ torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
+ torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
+ torch.ops.aten.trace.default: CastPolicy.FP32,
+ torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
+ torch.ops.aten.cholesky.default: CastPolicy.FP32,
+ torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
+ torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
+ torch.ops.aten.inverse.default: CastPolicy.FP32,
+ torch.ops.aten.lu_solve.default: CastPolicy.FP32,
+ torch.ops.aten.orgqr.default: CastPolicy.FP32,
+ torch.ops.aten.ormqr.default: CastPolicy.FP32,
+ torch.ops.aten.pinverse.default: CastPolicy.FP32,
+ torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
+ torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
+ torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
+ torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
+ torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
+ torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
+ torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
+ torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
+ torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
+ torch.ops.aten.mse_loss.default: CastPolicy.FP32,
+ torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
+ torch.ops.aten.nll_loss.default: CastPolicy.FP32,
+ torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
+ torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
+ torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
+ torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
+ torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
+ torch.ops.aten.l1_loss.default: CastPolicy.FP32,
+ torch.ops.aten.huber_loss.default: CastPolicy.FP32,
+ torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
+ torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
+ torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
+ torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
+ torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
+ torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
+ torch.ops.aten.kl_div.default: CastPolicy.FP32,
+ torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
+ torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
+ torch.ops.aten.fft_fft.default: CastPolicy.FP32,
+ torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
+ torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
+ torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
+ torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
+ torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
+ torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
+ torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
+ torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
+ torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
+ torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
+ torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
+ torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
+ torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
+ torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
+ torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
+ torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
+ torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
+ torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
+ torch.ops.aten.geqrf.default: CastPolicy.FP32,
+ torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
+ torch.ops.aten.qr.default: CastPolicy.FP32,
+ torch.ops.aten.svd.default: CastPolicy.FP32,
+ torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
+ torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
+ torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
+ torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
+ torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
+ torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,
+ # promote
+ torch.ops.aten.stack.default: CastPolicy.PROMOTE,
+ torch.ops.aten.cat.default: CastPolicy.PROMOTE,
+ torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
+ torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
  }
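
The amp.py changes are formatting-only: the autocast_policy dict is collapsed to one entry per line and execute_policy is reflowed, with no change to which ops map to which CastPolicy. For context, a minimal sketch of how the autocast context manager defined in this file might be used; the matmul example, the "jax" device string, and the dtype choice are illustrative assumptions, not part of the diff:

import torch
import torchax
from torchax import amp

torchax.enable_globally()

a = torch.randn(4, 4)
b = torch.randn(4, 4)

# aten.mm / aten.matmul carry a LOWER_PRECISION_FP policy in autocast_policy,
# so inside the context their floating-point inputs should be cast to the
# autocast dtype (bfloat16 here). The device argument is ignored (del device).
with amp.autocast("jax", dtype=torch.bfloat16):
    c = a @ b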