torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/CONTRIBUTING.md
CHANGED

```diff
@@ -8,7 +8,7 @@ If you plan to contribute new features, utility functions or extensions to the c
 # Developer setup
 
 ## Mac setup:
-@qihqi
+@qihqi
 
 I am able to develop directly on mac (m1) laptop for most of parts. Using steps
 in README.md works. The condensed version for easy copy & paste:
@@ -24,7 +24,7 @@ pytest test
 
 ### VSCode
 
-I use vscode on my Mac. I loosely followed instruction in
+I use vscode on my Mac. I loosely followed instruction in
 https://code.visualstudio.com/docs/python/python-tutorial
 to setup a proper python environment.
 
```
torchax/__init__.py
CHANGED

```diff
@@ -6,29 +6,31 @@ import os
 import torch
 from torch.utils import _pytree as pytree
 from torchax import tensor
-from
+from contextlib import contextmanager
 
-__version__ = "0.0.4"
+__version__ = "0.0.6"
 VERSION = __version__
 
 __all__ = [
-
-
-
+    'default_env',
+    'extract_jax',
+    'enable_globally',
 ]
 
 from jax._src import xla_bridge
+
 os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
 
 # torchax:oss-begin
 if getattr(jax.config, 'jax_pjrt_client_create_options', None):
   jax.config.update(
-
-
-  )
+      'jax_pjrt_client_create_options',
+      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
 # torchax:oss-end
 
 env = None
+
+
 def default_env():
   global env
 
@@ -37,53 +39,53 @@ def default_env():
   return env
 
 
-
 def extract_jax(mod: torch.nn.Module, env=None):
   """Returns a pytree of jax.ndarray and a jax callable."""
   if env is None:
     env = default_env()
-  states = mod.
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())
 
-  states =
+  states = env.t2j_copy(states)
 
   #@jax.jit
-  def jax_func(states,
-    (states,
+  def jax_func(states, args, kwargs=None):
+    (states, args, kwargs) = env.j2t_iso((states, args, kwargs))
     with env:
-      res = torch.func.functional_call(
+      res = torch.func.functional_call(
+          mod, states, args, kwargs, tie_weights=False)
     return env.t2j_iso(res)
 
   return states, jax_func
 
+
 def enable_globally():
   env = default_env().enable_torch_modes()
   return env
 
+
 def disable_globally():
-  global env
+  global env
   default_env().disable_torch_modes()
 
+
 @contextlib.contextmanager
 def disable_temporarily():
   prev = default_env().enabled
   if prev:
     disable_globally()
-  yield()
+  yield ()
   if prev:
     enable_globally()
 
 
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
-torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True, for_module=True, for_storage=True,
-    unsupported_dtype=unsupported_dtype)
 
 import jax
 import torchax.device_module
-torch._register_device_module('jax', torchax.device_module)
-
 
+torch._register_device_module('jax', torchax.device_module)
 
 
 def enable_accuracy_mode():
@@ -98,13 +100,13 @@ def enable_performance_mode():
   default_env().config.internal_respect_torch_return_dtypes = False
 
 
-
 @dataclasses.dataclass
 class CompileOptions:
   # only valid if compiling nn.Module
-  methods_to_compile: List[str] = dataclasses.field(
+  methods_to_compile: List[str] = dataclasses.field(
+      default_factory=lambda: ['forward'])
   jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-  mode: str = 'jax'
+  mode: str = 'jax'  # or dynamo or export
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
```
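Note: the reworked `extract_jax` now gathers both buffers and parameters into one dict, and the returned callable takes `(states, args, kwargs=None)`. A minimal usage sketch under those assumptions; the `nn.Linear` model and the input shape are illustrative, not part of the diff:

```python
import torch
import jax.numpy as jnp
import torchax

# Illustrative module; any nn.Module should work the same way.
model = torch.nn.Linear(4, 2)

# `states` is a pytree of jax arrays (buffers + parameters);
# `jax_func(states, args, kwargs=None)` runs the module functionally.
states, jax_func = torchax.extract_jax(model)
out = jax_func(states, (jnp.ones((1, 4), dtype=jnp.float32),))
print(out.shape)  # (1, 2)
```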
torchax/amp.py
ADDED

```python
import contextlib
import enum
import torch
from torch.utils import _pytree as pytree


# enum class CastPolicy : uint8_t {
#   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
#                           // running the op. Currently, lower_precision_fp is
#                           // fp16 for AutocastCUDA, and is defined by user
#                           // (default bf16) for AutocastCPU or other device.
#   fp32, // Cast all inputs to at::kFloat before running the op.
#   fp32_set_opt_dtype, // Treats functions (like softmax) that
#     // 1. we'd like to run in fp32 and
#     // 2. have a std::optional<ScalarType> arg that controls
#     //    the output type.
#     // fp32_set_opt_dtype wrappers' policy is: if the output
#     // type is already set, don't touch it, otherwise, set
#     // it to at::kFloat.
#   fp32_append_dtype, // Treats functions (like norm) that
#     // 1. we'd like to run in fp32 and
#     // 2. have some overloads that accept an output type and
#     //    other overloads that don't.
#     // fp32_append_dtype wrappers wrap the overloads that don't
#     // have an output dtype.
#     // The wrapper policy is: append at::kFloat to the args,
#     // and redispatch to the type-aware overload.
#   promote, // Run in the widest dtype among several args.
# };
class CastPolicy(enum.Enum):
  LOWER_PRECISION_FP = 0
  FP32 = 1
  FP32_SET_OPT_DTYPE = 2
  FP32_APPEND_DTYPE = 3
  PROMOTE = 4


def execute_policy(policy, args, kwargs, target_lower_fp):

  def is_float(a):
    return isinstance(a, torch.Tensor) and a.is_floating_point()
  match policy:
    case CastPolicy.LOWER_PRECISION_FP:
      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
                                  (args, kwargs))
    case CastPolicy.FP32:
      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
                                  (args, kwargs))
    case CastPolicy.PROMOTE:
      dtypes = set(a.dtype for a in args)
      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
                                  (args, kwargs))
    case _:
      raise AssertionError(f'Policy {policy} not implemented yet.')


@contextlib.contextmanager
def autocast(device, dtype=torch.bfloat16, env=None):
  del device
  if env is None:
    import torchax
    env = torchax.default_env()
  with env.override_property(autocast_dtype=dtype):
    yield


# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
autocast_policy = {
    torch.ops.aten.conv1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv1d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv2d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv3d.padding: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.bmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linalg_vecdot.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.baddbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._addmm_activation.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.addbmm.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.linear.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._convolution.deprecated: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.matmul.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_tbc.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.mkldnn_rnn_layer.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose1d.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose2d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.conv_transpose3d.input: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.prelu.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten.scaled_dot_product_attention.default: CastPolicy.LOWER_PRECISION_FP,
    torch.ops.aten._native_multi_head_attention.default: CastPolicy.LOWER_PRECISION_FP,

    # fp32 cast policy
    torch.ops.aten.avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler.default: CastPolicy.FP32,
    torch.ops.aten.polar.default: CastPolicy.FP32,
    torch.ops.aten.prod.default: CastPolicy.FP32,
    torch.ops.aten.prod.dim_int: CastPolicy.FP32,
    torch.ops.aten.prod.dim_Dimname: CastPolicy.FP32,
    torch.ops.aten.quantile.default: CastPolicy.FP32,
    torch.ops.aten.quantile.scalar: CastPolicy.FP32,
    torch.ops.aten.nanquantile.default: CastPolicy.FP32,
    torch.ops.aten.nanquantile.scalar: CastPolicy.FP32,
    torch.ops.aten.stft.default: CastPolicy.FP32,
    torch.ops.aten.stft.center: CastPolicy.FP32,
    torch.ops.aten.cdist.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_2d.default: CastPolicy.FP32,
    torch.ops.aten._grid_sampler_2d_cpu_fallback.default: CastPolicy.FP32,
    torch.ops.aten.grid_sampler_3d.default: CastPolicy.FP32,
    torch.ops.aten.trace.default: CastPolicy.FP32,
    torch.ops.aten.view_as_complex.default: CastPolicy.FP32,
    torch.ops.aten.cholesky.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_inverse.default: CastPolicy.FP32,
    torch.ops.aten.cholesky_solve.default: CastPolicy.FP32,
    torch.ops.aten.inverse.default: CastPolicy.FP32,
    torch.ops.aten.lu_solve.default: CastPolicy.FP32,
    torch.ops.aten.orgqr.default: CastPolicy.FP32,
    torch.ops.aten.ormqr.default: CastPolicy.FP32,
    torch.ops.aten.pinverse.default: CastPolicy.FP32,
    torch.ops.aten.max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool2d.default: CastPolicy.FP32,
    torch.ops.aten.max_unpool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_avg_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.reflection_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad1d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad2d.default: CastPolicy.FP32,
    torch.ops.aten.replication_pad3d.default: CastPolicy.FP32,
    torch.ops.aten.mse_loss.default: CastPolicy.FP32,
    torch.ops.aten.cosine_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.nll_loss2d.default: CastPolicy.FP32,
    torch.ops.aten.hinge_embedding_loss.default: CastPolicy.FP32,
    torch.ops.aten.poisson_nll_loss.default: CastPolicy.FP32,
    torch.ops.aten.smooth_l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.cross_entropy_loss.default: CastPolicy.FP32,
    torch.ops.aten.l1_loss.default: CastPolicy.FP32,
    torch.ops.aten.huber_loss.default: CastPolicy.FP32,
    torch.ops.aten.margin_ranking_loss.default: CastPolicy.FP32,
    torch.ops.aten.soft_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.triplet_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.multi_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.IntList: CastPolicy.FP32,
    torch.ops.aten.ctc_loss.Tensor: CastPolicy.FP32,
    torch.ops.aten.kl_div.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss.default: CastPolicy.FP32,
    torch.ops.aten.binary_cross_entropy_with_logits.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft.default: CastPolicy.FP32,
    torch.ops.aten.fft_fft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_fftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_ifftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfft2.default: CastPolicy.FP32,
    torch.ops.aten.fft_rfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_irfftn.default: CastPolicy.FP32,
    torch.ops.aten.fft_hfft.default: CastPolicy.FP32,
    torch.ops.aten.fft_ihfft.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cond.p_str: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.default: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.tol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor: CastPolicy.FP32,
    torch.ops.aten.linalg_matrix_rank.atol_rtol_float: CastPolicy.FP32,
    torch.ops.aten.linalg_solve.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svdvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvals.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigvalsh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_householder_product.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorinv.default: CastPolicy.FP32,
    torch.ops.aten.linalg_tensorsolve.default: CastPolicy.FP32,
    torch.ops.aten.fake_quantize_per_tensor_affine.default: CastPolicy.FP32,
    torch.ops.aten.geqrf.default: CastPolicy.FP32,
    torch.ops.aten._lu_with_info.default: CastPolicy.FP32,
    torch.ops.aten.qr.default: CastPolicy.FP32,
    torch.ops.aten.svd.default: CastPolicy.FP32,
    torch.ops.aten.triangular_solve.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool2d.default: CastPolicy.FP32,
    torch.ops.aten.fractional_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.adaptive_max_pool3d.default: CastPolicy.FP32,
    torch.ops.aten.multilabel_margin_loss_forward.default: CastPolicy.FP32,
    torch.ops.aten.linalg_qr.default: CastPolicy.FP32,
    torch.ops.aten.linalg_cholesky_ex.default: CastPolicy.FP32,
    torch.ops.aten.linalg_svd.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eig.default: CastPolicy.FP32,
    torch.ops.aten.linalg_eigh.default: CastPolicy.FP32,
    torch.ops.aten.linalg_lstsq.default: CastPolicy.FP32,
    torch.ops.aten.linalg_inv_ex.default: CastPolicy.FP32,

    # promote
    torch.ops.aten.stack.default: CastPolicy.PROMOTE,
    torch.ops.aten.cat.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.default: CastPolicy.PROMOTE,
    torch.ops.aten.index_copy.dimname: CastPolicy.PROMOTE,
}
```
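Note: the new `amp` module mirrors the cast policies of PyTorch's native autocast. A small sketch of how `autocast_policy` and `execute_policy` fit together, using plain CPU tensors for illustration:

```python
import torch
from torchax import amp

# Look up the cast policy registered for aten.mm and apply it to the inputs.
# LOWER_PRECISION_FP casts floating-point tensors to the target low-precision dtype.
a, b = torch.randn(2, 3), torch.randn(3, 4)
policy = amp.autocast_policy[torch.ops.aten.mm.default]
args, kwargs = amp.execute_policy(policy, (a, b), {}, torch.bfloat16)
print(args[0].dtype, args[1].dtype)  # torch.bfloat16 torch.bfloat16
```

In normal use this table is presumably consulted by the environment's dispatch logic while `amp.autocast(...)` is active (that part is not in this diff); the direct call above only illustrates the casting behavior.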
torchax/config.py
CHANGED

```diff
@@ -3,17 +3,28 @@ import dataclasses
 
 @dataclasses.dataclass
 class Configuration:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  debug_print_each_op: bool = False
+  debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
+
+  use_int32_for_index: bool = False
+
+  # normally, math between CPU torch.Tensor with torchax.Tensor is not
+  # allowed. However, if that torch.Tensor happens to be scalar, then we
+  # can use scalar * tensor math to handle it
+  allow_mixed_math_with_scalar_tensor: bool = True
+
+  # If true, we will convert Views into torchax.Tensors eagerly
+  force_materialize_views: bool = False
+
+  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+  use_dlpack_for_data_conversion: bool = False
+
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
+  # device
+  treat_cuda_as_jax_device: bool = True
+  internal_respect_torch_return_dtypes: bool = False
```
torchax/configuration.py
ADDED

```python
import dataclasses


@dataclasses.dataclass
class Configuration:
  debug_print_each_op: bool = False
  debug_accuracy_for_each_op: bool = False
  debug_mixed_tensor: bool = False
  debug_print_each_op_operands: bool = False

  use_int32_for_index: bool = False

  # normally, math between CPU torch.Tensor with torchax.Tensor is not
  # allowed. However, if that torch.Tensor happens to be scalar, then we
  # can use scalar * tensor math to handle it
  allow_mixed_math_with_scalar_tensor: bool = True

  # If true, we will convert Views into torchax.Tensors eagerly
  force_materialize_views: bool = False

  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
  use_dlpack_for_data_conversion: bool = False

  # Flash attention
  use_tpu_flash_attention: bool = False
  shmap_flash_attention: bool = False

  # device
  treat_cuda_as_jax_device: bool = True
  internal_respect_torch_return_dtypes: bool = False
```