torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
This release has been flagged as potentially problematic.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +57 -19
- torchax/amp.py +333 -0
- torchax/config.py +19 -12
- torchax/decompositions.py +663 -195
- torchax/device_module.py +7 -1
- torchax/distributed.py +55 -60
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +275 -141
- torchax/mesh_util.py +211 -0
- torchax/ops/jaten.py +1718 -1294
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +219 -78
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +417 -275
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/METADATA +111 -145
- torchax-0.0.5.dist-info/RECORD +32 -0
- torchax/environment.py +0 -2
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/licenses/LICENSE +0 -0
torchax/CONTRIBUTING.md
CHANGED
@@ -8,7 +8,7 @@ If you plan to contribute new features, utility functions or extensions to the c
 # Developer setup
 
 ## Mac setup:
-@qihqi
+@qihqi
 
 I am able to develop directly on mac (m1) laptop for most of parts. Using steps
 in README.md works. The condensed version for easy copy & paste:
@@ -24,7 +24,7 @@ pytest test
 
 ### VSCode
 
-I use vscode on my Mac. I loosely followed instruction in
+I use vscode on my Mac. I loosely followed instruction in
 https://code.visualstudio.com/docs/python/python-tutorial
 to setup a proper python environment.
 
torchax/__init__.py
CHANGED
@@ -7,28 +7,31 @@ import torch
 from torch.utils import _pytree as pytree
 from torchax import tensor
 from torchax import distributed  # noqa: F401
+from contextlib import contextmanager
 
-__version__ = "0.0.4"
+__version__ = "0.0.5"
 VERSION = __version__
 
 __all__ = [
-
-
-
+    'default_env',
+    'extract_jax',
+    'enable_globally',
 ]
 
 from jax._src import xla_bridge
+
 os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
 
 # torchax:oss-begin
 if getattr(jax.config, 'jax_pjrt_client_create_options', None):
   jax.config.update(
-
-
-  )
+      'jax_pjrt_client_create_options',
+      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
 # torchax:oss-end
 
 env = None
+
+
 def default_env():
   global env
 
@@ -37,14 +40,14 @@ def default_env():
   return env
 
 
-
 def extract_jax(mod: torch.nn.Module, env=None):
   """Returns a pytree of jax.ndarray and a jax callable."""
   if env is None:
     env = default_env()
-  states = mod.
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())
 
-  states =
+  states = env.t2j_copy(states)
 
   #@jax.jit
   def jax_func(states, inputs):
@@ -55,20 +58,23 @@ def extract_jax(mod: torch.nn.Module, env=None):
 
   return states, jax_func
 
+
 def enable_globally():
   env = default_env().enable_torch_modes()
   return env
 
+
 def disable_globally():
-  global env
+  global env
   default_env().disable_torch_modes()
 
+
 @contextlib.contextmanager
 def disable_temporarily():
   prev = default_env().enabled
   if prev:
     disable_globally()
-  yield()
+  yield ()
   if prev:
     enable_globally()
 
@@ -76,14 +82,15 @@ def disable_temporarily():
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
 torch.utils.generate_methods_for_privateuse1_backend(
-
-
+    for_tensor=True,
+    for_module=True,
+    for_storage=True,
+    unsupported_dtype=unsupported_dtype)
 
 import jax
 import torchax.device_module
-torch._register_device_module('jax', torchax.device_module)
-
 
+torch._register_device_module('jax', torchax.device_module)
 
 
 def enable_accuracy_mode():
@@ -98,13 +105,13 @@ def enable_performance_mode():
   default_env().config.internal_respect_torch_return_dtypes = False
 
 
-
 @dataclasses.dataclass
 class CompileOptions:
   # only valid if compiling nn.Module
-  methods_to_compile: List[str] = dataclasses.field(
+  methods_to_compile: List[str] = dataclasses.field(
+      default_factory=lambda: ['forward'])
   jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-  mode: str = 'jax'
+  mode: str = 'jax'  # or dynamo or export
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
@@ -122,3 +129,34 @@ def compile(fn, options: Optional[CompileOptions] = None):
     raise RuntimeError('dynamo mode is not supported yet')
  elif options.mode == 'export':
     raise RuntimeError('export mode is not supported yet')
+
+
+@contextmanager
+def jax_device(target_device: str, env: tensor.Environment | None = None):
+  """
+  to("jax") cannot differentiate the device/platform (cpu vs tpu).
+  Use this context manager to control jax array's storage device
+
+  Examples:
+
+  a = torch.ones(3, 3)
+
+  with jax_device("cpu"):
+    b = a.to("jax")
+
+  with jax_device("tpu"):
+    c = a.to("jax")
+
+  with jax_device("tpu"):
+    c = b.to("jax")
+
+  """
+  if env is None:
+    env = default_env()
+
+  prev_target_device = env.target_device
+  try:
+    env.target_device = target_device
+    yield env
+  finally:
+    env.target_device = prev_target_device
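
Note: taken together, the __init__.py changes above expose default_env, extract_jax, enable_globally, and the new jax_device context manager. The following is a minimal usage sketch based only on the signatures and docstrings visible in this diff (the call convention jax_fn(weights, (inputs,)) follows the jax_func(states, inputs) signature shown above and is an assumption, not an excerpt from the package):

import torch
import jax.numpy as jnp
import torchax

model = torch.nn.Linear(4, 2)

# extract_jax returns (pytree of jax arrays, jax callable), per its docstring.
weights, jax_fn = torchax.extract_jax(model)
out = jax_fn(weights, (jnp.ones((1, 4)),))

# enable_globally() routes torch ops through the torchax environment;
# the new jax_device context manager controls which JAX platform backs
# tensors moved with .to("jax") ("cpu" here is illustrative).
torchax.enable_globally()
with torchax.jax_device("cpu"):
  x = torch.ones(3, 3).to("jax")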
torchax/amp.py
ADDED
@@ -0,0 +1,333 @@
+import contextlib
+import enum
+import torch
+from torch.utils import _pytree as pytree
+
+
+# enum class CastPolicy : uint8_t {
+#   lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before
+#                           // running the op. Currently, lower_precision_fp is
+#                           // fp16 for AutocastCUDA, and is defined by user
+#                           // (default bf16) for AutocastCPU or other device.
+#   fp32, // Cast all inputs to at::kFloat before running the op.
+#   fp32_set_opt_dtype, // Treats functions (like softmax) that
+#                       //   1. we'd like to run in fp32 and
+#                       //   2. have a std::optional<ScalarType> arg that controls
+#                       //      the output type.
+#                       // fp32_set_opt_dtype wrappers' policy is: if the output
+#                       //   type is already set, don't touch it, otherwise, set
+#                       //   it to at::kFloat.
+#   fp32_append_dtype, // Treats functions (like norm) that
+#                      //   1. we'd like to run in fp32 and
+#                      //   2. have some overloads that accept an output type and
+#                      //      other overloads that don't.
+#                      // fp32_append_dtype wrappers wrap the overloads that don't
+#                      //   have an output dtype.
+#                      // The wrapper policy is: append at::kFloat to the args,
+#                      //   and redispatch to the type-aware overload.
+#   promote, // Run in the widest dtype among several args.
+# };
+class CastPolicy(enum.Enum):
+  LOWER_PRECISION_FP = 0
+  FP32 = 1
+  FP32_SET_OPT_DTYPE = 2
+  FP32_APPEND_DTYPE = 3
+  PROMOTE = 4
+
+
+def execute_policy(policy, args, kwargs, target_lower_fp):
+
+  def is_float(a):
+    return isinstance(a, torch.Tensor) and a.is_floating_point()
+  match policy:
+    case CastPolicy.LOWER_PRECISION_FP:
+      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
+                                  (args, kwargs))
+    case CastPolicy.FP32:
+      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
+                                  (args, kwargs))
+    case CastPolicy.PROMOTE:
+      dtypes = set(a.dtype for a in args)
+      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
+      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
+                                  (args, kwargs))
+    case _:
+      raise AssertionError(f'Policy {policy} not implemented yet.')
+
+
+@contextlib.contextmanager
+def autocast(device, dtype=torch.bfloat16, env=None):
+  del device
+  if env is None:
+    import torchax
+    env = torchax.default_env()
+  env.autocast_dtype, old = dtype, env.autocast_dtype
+  yield
+  env.autocast_dtype = old
+
+
+# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
+autocast_policy = {
+    torch.ops.aten.conv1d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv1d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.bmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linalg_vecdot.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.baddbmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._addmm_activation.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addbmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linear.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._convolution.deprecated:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.matmul.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_tbc.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mkldnn_rnn_layer.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose1d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose2d.input:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose3d.input:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.prelu.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.scaled_dot_product_attention.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._native_multi_head_attention.default:
+        CastPolicy.LOWER_PRECISION_FP,
+
+    # fp32 cast policy
+    torch.ops.aten.avg_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler.default:
+        CastPolicy.FP32,
+    torch.ops.aten.polar.default:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.default:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.dim_int:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.dim_Dimname:
+        CastPolicy.FP32,
+    torch.ops.aten.quantile.default:
+        CastPolicy.FP32,
+    torch.ops.aten.quantile.scalar:
+        CastPolicy.FP32,
+    torch.ops.aten.nanquantile.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nanquantile.scalar:
+        CastPolicy.FP32,
+    torch.ops.aten.stft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.stft.center:
+        CastPolicy.FP32,
+    torch.ops.aten.cdist.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten._grid_sampler_2d_cpu_fallback.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.trace.default:
+        CastPolicy.FP32,
+    torch.ops.aten.view_as_complex.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky_inverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.inverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.lu_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.orgqr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.ormqr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.pinverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_unpool2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_unpool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.adaptive_avg_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.reflection_pad1d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.reflection_pad2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad1d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.mse_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cosine_embedding_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nll_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nll_loss2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.hinge_embedding_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.poisson_nll_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.smooth_l1_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cross_entropy_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.l1_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.huber_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.margin_ranking_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.soft_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.triplet_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multi_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.IntList:
+        CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.Tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.kl_div.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy_with_logits.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_hfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ihfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.p_str:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.tol_tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_float:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_svdvals.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvals.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvalsh.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_inv.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_householder_product.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorinv.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorsolve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fake_quantize_per_tensor_affine.default:
+        CastPolicy.FP32,
+    torch.ops.aten.geqrf.default:
+        CastPolicy.FP32,
+    torch.ops.aten._lu_with_info.default:
+        CastPolicy.FP32,
+    torch.ops.aten.qr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.svd.default:
+        CastPolicy.FP32,
+    torch.ops.aten.triangular_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.adaptive_max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss_forward.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_qr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky_ex.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_svd.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eig.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigh.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_lstsq.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_inv_ex.default:
+        CastPolicy.FP32,
+
+    # promote
+    torch.ops.aten.stack.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.cat.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.dimname:
+        CastPolicy.PROMOTE,
+}
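
Note: the new amp module mirrors the cast policies of PyTorch's native autocast_mode.cpp (linked in the file). Below is a hedged usage sketch, assuming the torchax environment consults autocast_policy and autocast_dtype during dispatch (that wiring lives in tensor.py, which is not shown in this section), so actual behavior in the released package may differ:

import torch
import torchax
from torchax import amp

torchax.enable_globally()

a = torch.randn(8, 8, device='jax')
b = torch.randn(8, 8, device='jax')

# Inside autocast, ops tagged LOWER_PRECISION_FP (e.g. aten.mm) are expected
# to run in the requested lower-precision dtype; FP32-tagged ops stay in float32.
with amp.autocast('jax', dtype=torch.bfloat16):
  c = a @ b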
torchax/config.py
CHANGED
@@ -3,17 +3,24 @@ import dataclasses
 
 @dataclasses.dataclass
 class Configuration:
-
-
-
-
-  use_int32_for_index: bool = False
+  debug_print_each_op: bool = False
+  debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
 
-
-  use_tpu_flash_attention: bool = False
-  shmap_flash_attention: bool = False
+  use_int32_for_index: bool = False
 
-
-
-
-
+  # If true, we will convert Views into torchax.Tensors eagerly
+  force_materialize_views: bool = False
+
+  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+  use_dlpack_for_data_conversion: bool = False
+
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
+  # device
+  treat_cuda_as_jax_device: bool = True
+  use_torch_native_for_cpu_tensor: bool = True
+  internal_respect_torch_return_dtypes: bool = False