torchax-0.0.10.dev20251116-py3-none-any.whl → torchax-0.0.11.dev202617-py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Potentially problematic release: this version of torchax might be problematic.
- torchax/__init__.py +73 -77
- torchax/amp.py +143 -271
- torchax/checkpoint.py +15 -9
- torchax/config.py +0 -4
- torchax/decompositions.py +66 -60
- torchax/export.py +53 -54
- torchax/flax.py +7 -5
- torchax/interop.py +66 -62
- torchax/mesh_util.py +20 -18
- torchax/ops/__init__.py +4 -3
- torchax/ops/jaten.py +3841 -3968
- torchax/ops/jax_reimplement.py +68 -42
- torchax/ops/jc10d.py +4 -6
- torchax/ops/jimage.py +20 -25
- torchax/ops/jlibrary.py +6 -6
- torchax/ops/jtorch.py +355 -419
- torchax/ops/jtorchvision_nms.py +69 -49
- torchax/ops/mappings.py +42 -63
- torchax/ops/op_base.py +17 -25
- torchax/ops/ops_registry.py +35 -30
- torchax/tensor.py +124 -128
- torchax/train.py +100 -102
- torchax/types.py +8 -7
- torchax/util.py +6 -4
- torchax/view.py +144 -136
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/METADATA +7 -1
- torchax-0.0.11.dev202617.dist-info/RECORD +31 -0
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/WHEEL +1 -1
- torchax-0.0.10.dev20251116.dist-info/RECORD +0 -31
- {torchax-0.0.10.dev20251116.dist-info → torchax-0.0.11.dev202617.dist-info}/licenses/LICENSE +0 -0
torchax/__init__.py
CHANGED

@@ -13,16 +13,21 @@
 # limitations under the License.

 import contextlib
-from typing import List, Dict, Any, Optional
 import dataclasses
-import jax
 import os
+from contextlib import contextmanager
+from typing import Any
+
+import jax
 import torch
 from torch.utils import _pytree as pytree
+
+import torchax.device_module
 from torchax import tensor
-from contextlib import contextmanager

-
+from .checkpoint import load_checkpoint, save_checkpoint
+
+__version__ = "0.0.11.dev202617"
 VERSION = __version__

 # the "fast path" uses some sparse tensor thingies that currently we
@@ -31,123 +36,114 @@ torch.backends.mha.set_fastpath_enabled(False)


 __all__ = [
-
-
-
-
-
+    "default_env",
+    "extract_jax",
+    "enable_globally",
+    "save_checkpoint",
+    "load_checkpoint",
 ]

-from .checkpoint import save_checkpoint, load_checkpoint

 os.environ.setdefault("ENABLE_RUNTIME_UPTIME_TELEMETRY", "1")

 # torchax:oss-begin
 if getattr(jax.config, "jax_pjrt_client_create_options", None):
-
-
-
-
+  jax.config.update(
+      "jax_pjrt_client_create_options",
+      f"ml_framework_name:PyTorch/XLA2;ml_framework_version:{'v0.0.1'}",
+  )
 # torchax:oss-end

 env = None


 def default_env():
-
+  global env

-
-
-
+  if env is None:
+    env = tensor.Environment()
+  return env


 def extract_jax(mod: torch.nn.Module, env=None):
-
-
-
-
-
+  """Returns a pytree of jax.ndarray and a jax callable."""
+  if env is None:
+    env = default_env()
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())

-
+  states = env.t2j_copy(states)

-
-
-
-
-
-
-  )
-  return env.t2j_iso(res)
+  # @jax.jit
+  def jax_func(states, args, kwargs=None):
+    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
+    with env:
+      res = torch.func.functional_call(mod, states, args, kwargs, tie_weights=False)
+    return env.t2j_iso(res)

-
+  return states, jax_func


 def enable_globally():
-
-
+  env = default_env().enable_torch_modes()
+  return env


 def disable_globally():
-
-
+  global env
+  default_env().disable_torch_modes()


 @contextlib.contextmanager
 def disable_temporarily():
-
-
-
-
-
-
+  prev = default_env().enabled
+  if prev:
+    disable_globally()
+  yield ()
+  if prev:
+    enable_globally()


 torch.utils.rename_privateuse1_backend("jax")
 unsupported_dtype = [torch.quint8]

-import jax
-import torchax.device_module

 torch._register_device_module("jax", torchax.device_module)


 def enable_accuracy_mode():
-
-
-
+  jax.config.update("jax_enable_x64", True)
+  jax.config.update("jax_default_matmul_precision", "highest")
+  default_env().config.internal_respect_torch_return_dtypes = True


 def enable_performance_mode():
-
-
-
+  jax.config.update("jax_enable_x64", False)
+  jax.config.update("jax_default_matmul_precision", "default")
+  default_env().config.internal_respect_torch_return_dtypes = False


 @dataclasses.dataclass
 class CompileOptions:
-  [… 22 removed lines (old lines 128-149) not rendered in this diff view …]
-  elif options.mode == "dynamo":
-    raise RuntimeError("dynamo mode is not supported yet")
-  elif options.mode == "export":
-    raise RuntimeError("export mode is not supported yet")
+  # only valid if compiling nn.Module
+  methods_to_compile: list[str] = dataclasses.field(default_factory=lambda: ["forward"])
+  jax_jit_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
+  mode: str = "jax"  # or dynamo or export
+
+
+def compile(fn, options: CompileOptions | None = None):
+  options = options or CompileOptions()
+  if options.mode == "jax":
+    from torchax import interop
+
+    if isinstance(fn, torch.nn.Module):
+      module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
+      for n in options.methods_to_compile:
+        module.make_jitted(n)
+      return module
+    else:
+      return interop.jax_jit(fn)
+  elif options.mode == "dynamo":
+    raise RuntimeError("dynamo mode is not supported yet")
+  elif options.mode == "export":
+    raise RuntimeError("export mode is not supported yet")
torchax/amp.py
CHANGED

@@ -14,6 +14,7 @@

 import contextlib
 import enum
+
 import torch
 from torch.utils import _pytree as pytree

@@ -50,23 +51,24 @@ class CastPolicy(enum.Enum):


 def execute_policy(policy, args, kwargs, target_lower_fp):
-
   def is_float(a):
     return isinstance(a, torch.Tensor) and a.is_floating_point()
+
   match policy:
     case CastPolicy.LOWER_PRECISION_FP:
-      return pytree.tree_map_only(
-
+      return pytree.tree_map_only(
+          is_float, lambda a: a.to(target_lower_fp), (args, kwargs)
+      )
     case CastPolicy.FP32:
-      return pytree.tree_map_only(
-
+      return pytree.tree_map_only(
+          is_float, lambda a: a.to(torch.float32), (args, kwargs)
+      )
     case CastPolicy.PROMOTE:
-      dtypes =
+      dtypes = {a.dtype for a in args}
       widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
-      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
-                                  (args, kwargs))
+      return pytree.tree_map_only(is_float, lambda a: a.to(widest), (args, kwargs))
     case _:
-      raise AssertionError(f
+      raise AssertionError(f"Policy {policy} not implemented yet.")


 @contextlib.contextmanager
@@ -74,6 +76,7 @@ def autocast(device, dtype=torch.bfloat16, env=None):
   del device
   if env is None:
     import torchax
+
     env = torchax.default_env()
   with env.override_property(autocast_dtype=dtype):
     yield
@@ -81,266 +84,135 @@ def autocast(device, dtype=torch.bfloat16, env=None):

 # https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
 autocast_policy = {
-    [… 131 removed lines (old lines 84-214) not rendered in this diff view …]
-        CastPolicy.FP32,
-    torch.ops.aten.poisson_nll_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.smooth_l1_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.cross_entropy_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.l1_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.huber_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.margin_ranking_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.soft_margin_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.triplet_margin_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.multi_margin_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.ctc_loss.IntList:
-        CastPolicy.FP32,
-    torch.ops.aten.ctc_loss.Tensor:
-        CastPolicy.FP32,
-    torch.ops.aten.kl_div.default:
-        CastPolicy.FP32,
-    torch.ops.aten.multilabel_margin_loss.default:
-        CastPolicy.FP32,
-    torch.ops.aten.binary_cross_entropy_with_logits.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_fft.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_ifft.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_fft2.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_ifft2.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_fftn.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_ifftn.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_rfft.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_irfft.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_rfft2.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_irfft2.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_rfftn.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_irfftn.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_hfft.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fft_ihfft.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_cond.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_cond.p_str:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_matrix_rank.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_matrix_rank.tol_tensor:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_matrix_rank.atol_rtol_float:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_solve.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_cholesky.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_svdvals.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_eigvals.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_eigvalsh.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_inv.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_householder_product.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_tensorinv.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_tensorsolve.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fake_quantize_per_tensor_affine.default:
-        CastPolicy.FP32,
-    torch.ops.aten.geqrf.default:
-        CastPolicy.FP32,
-    torch.ops.aten._lu_with_info.default:
-        CastPolicy.FP32,
-    torch.ops.aten.qr.default:
-        CastPolicy.FP32,
-    torch.ops.aten.svd.default:
-        CastPolicy.FP32,
-    torch.ops.aten.triangular_solve.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fractional_max_pool2d.default:
-        CastPolicy.FP32,
-    torch.ops.aten.fractional_max_pool3d.default:
-        CastPolicy.FP32,
-    torch.ops.aten.adaptive_max_pool3d.default:
-        CastPolicy.FP32,
-    torch.ops.aten.multilabel_margin_loss_forward.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_qr.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_cholesky_ex.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_svd.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_eig.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_eigh.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_lstsq.default:
-        CastPolicy.FP32,
-    torch.ops.aten.linalg_inv_ex.default:
-        CastPolicy.FP32,
-
-    # promote
-    torch.ops.aten.stack.default:
-        CastPolicy.PROMOTE,
-    torch.ops.aten.cat.default:
-        CastPolicy.PROMOTE,
-    torch.ops.aten.index_copy.default:
-        CastPolicy.PROMOTE,
-    torch.ops.aten.index_copy.dimname:
-        CastPolicy.PROMOTE,
+    torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,
+    # fp32 cast policy
+    torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
+    torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
+    torch.ops.aten.polar.default: CastPolicy.FP32,
+    torch.ops.aten.prod.default: CastPolicy.FP32,
+    torch.ops.aten.prod.dim_int: CastPolicy.FP32,
+    torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
+    torch.ops.aten.quantile.default: CastPolicy.FP32,
+    torch.ops.aten.quantile.scalar: CastPolicy.FP32,
+    torch.ops.aten.nanquantile.default: CastPolicy.FP32,
+    torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
+    torch.ops.aten.stft.default: CastPolicy.FP32,
+    torch.ops.aten.stft.center: CastPolicy.FP32,
+    torch.ops.aten.cdist.default: CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
+    torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
+    torch.ops.aten.trace.default: CastPolicy.FP32,
+    torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
+    torch.ops.aten.cholesky.default: CastPolicy.FP32,
+    torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
+    torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
+    torch.ops.aten.inverse.default: CastPolicy.FP32,
+    torch.ops.aten.lu_solve.default: CastPolicy.FP32,
+    torch.ops.aten.orgqr.default: CastPolicy.FP32,
+    torch.ops.aten.ormqr.default: CastPolicy.FP32,
+    torch.ops.aten.pinverse.default: CastPolicy.FP32,
+    torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
+    torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
+    torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
+    torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
+    torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
+    torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
+    torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
+    torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
+    torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
+    torch.ops.aten.mse_loss.default: CastPolicy.FP32,
+    torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
+    torch.ops.aten.nll_loss.default: CastPolicy.FP32,
+    torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
+    torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
+    torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
+    torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
+    torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
+    torch.ops.aten.l1_loss.default: CastPolicy.FP32,
+    torch.ops.aten.huber_loss.default: CastPolicy.FP32,
+    torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
+    torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
+    torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
+    torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
+    torch.ops.aten.kl_div.default: CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
+    torch.ops.aten.fft_fft.default: CastPolicy.FP32,
+    torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
+    torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
+    torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
+    torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
+    torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
+    torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
+    torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
+    torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
+    torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
+    torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
+    torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
+    torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
+    torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
+    torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
+    torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
+    torch.ops.aten.geqrf.default: CastPolicy.FP32,
+    torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
+    torch.ops.aten.qr.default: CastPolicy.FP32,
+    torch.ops.aten.svd.default: CastPolicy.FP32,
+    torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
+    torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
+    torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,
+    # promote
+    torch.ops.aten.stack.default: CastPolicy.PROMOTE,
+    torch.ops.aten.cat.default: CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
 }
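The amp.py change is largely a reformat: the autocast_policy table keeps the same op-to-policy mapping but now lists one entry per line, and execute_policy applies LOWER_PRECISION_FP (cast floating tensors to the autocast dtype), FP32 (force float32), or PROMOTE (cast to the widest input dtype) over the (args, kwargs) pytree. A minimal usage sketch follows, assuming the torchax environment is enabled as in the __init__.py diff; the shapes and inputs are illustrative, not taken from the package.

import torch
import torchax
from torchax import amp

torchax.enable_globally()  # route torch ops through the torchax environment

a = torch.randn(4, 4)  # illustrative inputs
b = torch.randn(4, 4)

# autocast() overrides the environment's autocast_dtype; ops mapped to
# CastPolicy.LOWER_PRECISION_FP in autocast_policy are cast to that dtype.
with amp.autocast("jax", dtype=torch.bfloat16):
    out = a @ b

Because aten.mm and aten.matmul are listed under CastPolicy.LOWER_PRECISION_FP, the matmul inside the autocast block is expected to run in bfloat16, while ops in the FP32 list (for example aten.mse_loss) would still be computed in float32.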