torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

torchax/CONTRIBUTING.md CHANGED
@@ -8,7 +8,7 @@ If you plan to contribute new features, utility functions or extensions to the c
 # Developer setup
 
 ## Mac setup:
-@qihqi
+@qihqi
 
 I am able to develop directly on mac (m1) laptop for most of parts. Using steps
 in README.md works. The condensed version for easy copy & paste:
@@ -24,7 +24,7 @@ pytest test
 
 ### VSCode
 
-I use vscode on my Mac. I loosely followed instruction in
+I use vscode on my Mac. I loosely followed instruction in
 https://code.visualstudio.com/docs/python/python-tutorial
 to setup a proper python environment.
 
torchax/__init__.py CHANGED
@@ -7,28 +7,31 @@ import torch
 from torch.utils import _pytree as pytree
 from torchax import tensor
 from torchax import distributed  # noqa: F401
+from contextlib import contextmanager
 
-__version__ = "0.0.4"
+__version__ = "0.0.5"
 VERSION = __version__
 
 __all__ = [
-  'default_env',
-  'extract_jax',
-  'enable_globally',
+    'default_env',
+    'extract_jax',
+    'enable_globally',
 ]
 
 from jax._src import xla_bridge
+
 os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
 
 # torchax:oss-begin
 if getattr(jax.config, 'jax_pjrt_client_create_options', None):
   jax.config.update(
-    'jax_pjrt_client_create_options',
-    f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
-  )
+      'jax_pjrt_client_create_options',
+      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}')
 # torchax:oss-end
 
 env = None
+
+
 def default_env():
   global env
 
@@ -37,14 +40,14 @@ def default_env():
   return env
 
 
-
 def extract_jax(mod: torch.nn.Module, env=None):
   """Returns a pytree of jax.ndarray and a jax callable."""
   if env is None:
     env = default_env()
-  states = mod.state_dict()
+  states = dict(mod.named_buffers())
+  states.update(mod.named_parameters())
 
-  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
+  states = env.t2j_copy(states)
 
   #@jax.jit
   def jax_func(states, inputs):
@@ -55,20 +58,23 @@ def extract_jax(mod: torch.nn.Module, env=None):
 
   return states, jax_func
 
+
 def enable_globally():
   env = default_env().enable_torch_modes()
   return env
 
+
 def disable_globally():
-  global env
+  global env
   default_env().disable_torch_modes()
 
+
 @contextlib.contextmanager
 def disable_temporarily():
   prev = default_env().enabled
   if prev:
     disable_globally()
-  yield()
+  yield ()
   if prev:
     enable_globally()
 
@@ -76,14 +82,15 @@ def disable_temporarily():
 torch.utils.rename_privateuse1_backend('jax')
 unsupported_dtype = [torch.quint8]
 torch.utils.generate_methods_for_privateuse1_backend(
-    for_tensor=True, for_module=True, for_storage=True,
-    unsupported_dtype=unsupported_dtype)
+    for_tensor=True,
+    for_module=True,
+    for_storage=True,
+    unsupported_dtype=unsupported_dtype)
 
 import jax
 import torchax.device_module
-torch._register_device_module('jax', torchax.device_module)
-
 
+torch._register_device_module('jax', torchax.device_module)
 
 
 def enable_accuracy_mode():
@@ -98,13 +105,13 @@ def enable_performance_mode():
   default_env().config.internal_respect_torch_return_dtypes = False
 
 
-
 @dataclasses.dataclass
 class CompileOptions:
   # only valid if compiling nn.Module
-  methods_to_compile: List[str] = dataclasses.field(default_factory=lambda: ['forward'])
+  methods_to_compile: List[str] = dataclasses.field(
+      default_factory=lambda: ['forward'])
   jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-  mode: str = 'jax' # or dynamo or export
+  mode: str = 'jax'  # or dynamo or export
 
 
 def compile(fn, options: Optional[CompileOptions] = None):
@@ -122,3 +129,34 @@ def compile(fn, options: Optional[CompileOptions] = None):
     raise RuntimeError('dynamo mode is not supported yet')
   elif options.mode == 'export':
     raise RuntimeError('export mode is not supported yet')
+
+
+@contextmanager
+def jax_device(target_device: str, env: tensor.Environment | None = None):
+  """
+  to("jax") cannot differentiate the device/platform (cpu vs tpu).
+  Use this context manager to control jax array's storage device
+
+  Examples:
+
+  a = torch.ones(3, 3)
+
+  with jax_device("cpu"):
+    b = a.to("jax")
+
+  with jax_device("tpu"):
+    c = a.to("jax")
+
+  with jax_device("tpu"):
+    c = b.to("jax")
+
+  """
+  if env is None:
+    env = default_env()
+
+  prev_target_device = env.target_device
+  try:
+    env.target_device = target_device
+    yield env
+  finally:
+    env.target_device = prev_target_device
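
For context, the reworked extract_jax and the new jax_device context manager from this diff can be exercised roughly as follows. This is an illustrative sketch, not part of the package diff; the exact calling convention for jax_func (inputs passed as a tuple) is an assumption based on its docstring.

import torch
import jax.numpy as jnp
import torchax

model = torch.nn.Linear(4, 2)

# extract_jax now gathers named buffers and parameters and copies them to jax
# arrays via env.t2j_copy, returning (states, jax_func).
states, jax_func = torchax.extract_jax(model)
out = jax_func(states, (jnp.ones((1, 4)),))  # inputs as a tuple: assumed

# jax_device controls which jax platform backs tensors moved with .to("jax").
a = torch.ones(3, 3)
with torchax.jax_device("cpu"):
  b = a.to("jax")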
torchax/amp.py ADDED
@@ -0,0 +1,333 @@
+import contextlib
+import enum
+import torch
+from torch.utils import _pytree as pytree
+
+
+# enum class CastPolicy : uint8_t {
+#   lower_precision_fp = 0,  // Cast all inputs to lower_precision_fp before
+#                            // running the op. Currently, lower_precision_fp is
+#                            // fp16 for AutocastCUDA, and is defined by user
+#                            // (default bf16) for AutocastCPU or other device.
+#   fp32,  // Cast all inputs to at::kFloat before running the op.
+#   fp32_set_opt_dtype,  // Treats functions (like softmax) that
+#     // 1. we'd like to run in fp32 and
+#     // 2. have a std::optional<ScalarType> arg that controls
+#     //    the output type.
+#     // fp32_set_opt_dtype wrappers' policy is: if the output
+#     // type is already set, don't touch it, otherwise, set
+#     // it to at::kFloat.
+#   fp32_append_dtype,  // Treats functions (like norm) that
+#     // 1. we'd like to run in fp32 and
+#     // 2. have some overloads that accept an output type and
+#     //    other overloads that don't.
+#     // fp32_append_dtype wrappers wrap the overloads that don't
+#     // have an output dtype.
+#     // The wrapper policy is: append at::kFloat to the args,
+#     // and redispatch to the type-aware overload.
+#   promote,  // Run in the widest dtype among several args.
+# };
+class CastPolicy(enum.Enum):
+  LOWER_PRECISION_FP = 0
+  FP32 = 1
+  FP32_SET_OPT_DTYPE = 2
+  FP32_APPEND_DTYPE = 3
+  PROMOTE = 4
+
+
+def execute_policy(policy, args, kwargs, target_lower_fp):
+
+  def is_float(a):
+    return isinstance(a, torch.Tensor) and a.is_floating_point()
+  match policy:
+    case CastPolicy.LOWER_PRECISION_FP:
+      return pytree.tree_map_only(is_float, lambda a: a.to(target_lower_fp),
+                                  (args, kwargs))
+    case CastPolicy.FP32:
+      return pytree.tree_map_only(is_float, lambda a: a.to(torch.float32),
+                                  (args, kwargs))
+    case CastPolicy.PROMOTE:
+      dtypes = set(a.dtype for a in args)
+      widest = max((dtype.itemsize, dtype) for dtype in dtypes)[1]
+      return pytree.tree_map_only(is_float, lambda a: a.to(widest),
+                                  (args, kwargs))
+    case _:
+      raise AssertionError(f'Policy {policy} not implemented yet.')
+
+
+@contextlib.contextmanager
+def autocast(device, dtype=torch.bfloat16, env=None):
+  del device
+  if env is None:
+    import torchax
+    env = torchax.default_env()
+  env.autocast_dtype, old = dtype, env.autocast_dtype
+  yield
+  env.autocast_dtype = old
+
+
+# https://github.com/pytorch/pytorch/blob/05faba40287cf7d8734da96cb2e904f39710bf29/aten/src/ATen/autocast_mode.cpp#L327
+autocast_policy = {
+    torch.ops.aten.conv1d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv1d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv2d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv3d.padding:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.bmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linalg_vecdot.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.baddbmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._addmm_activation.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.addbmm.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.linear.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._convolution.deprecated:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.matmul.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_tbc.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.mkldnn_rnn_layer.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose1d.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose2d.input:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.conv_transpose3d.input:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.prelu.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten.scaled_dot_product_attention.default:
+        CastPolicy.LOWER_PRECISION_FP,
+    torch.ops.aten._native_multi_head_attention.default:
+        CastPolicy.LOWER_PRECISION_FP,
+
+    # fp32 cast policy
+    torch.ops.aten.avg_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler.default:
+        CastPolicy.FP32,
+    torch.ops.aten.polar.default:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.default:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.dim_int:
+        CastPolicy.FP32,
+    torch.ops.aten.prod.dim_Dimname:
+        CastPolicy.FP32,
+    torch.ops.aten.quantile.default:
+        CastPolicy.FP32,
+    torch.ops.aten.quantile.scalar:
+        CastPolicy.FP32,
+    torch.ops.aten.nanquantile.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nanquantile.scalar:
+        CastPolicy.FP32,
+    torch.ops.aten.stft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.stft.center:
+        CastPolicy.FP32,
+    torch.ops.aten.cdist.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten._grid_sampler_2d_cpu_fallback.default:
+        CastPolicy.FP32,
+    torch.ops.aten.grid_sampler_3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.trace.default:
+        CastPolicy.FP32,
+    torch.ops.aten.view_as_complex.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky_inverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cholesky_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.inverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.lu_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.orgqr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.ormqr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.pinverse.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_unpool2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.max_unpool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.adaptive_avg_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.reflection_pad1d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.reflection_pad2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad1d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.replication_pad3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.mse_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cosine_embedding_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nll_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.nll_loss2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.hinge_embedding_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.poisson_nll_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.smooth_l1_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.cross_entropy_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.l1_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.huber_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.margin_ranking_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.soft_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.triplet_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multi_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.IntList:
+        CastPolicy.FP32,
+    torch.ops.aten.ctc_loss.Tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.kl_div.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss.default:
+        CastPolicy.FP32,
+    torch.ops.aten.binary_cross_entropy_with_logits.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_fftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ifftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfft2.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_rfftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_irfftn.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_hfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fft_ihfft.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cond.p_str:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.tol_tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_tensor:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_matrix_rank.atol_rtol_float:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_svdvals.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvals.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigvalsh.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_inv.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_householder_product.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorinv.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_tensorsolve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fake_quantize_per_tensor_affine.default:
+        CastPolicy.FP32,
+    torch.ops.aten.geqrf.default:
+        CastPolicy.FP32,
+    torch.ops.aten._lu_with_info.default:
+        CastPolicy.FP32,
+    torch.ops.aten.qr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.svd.default:
+        CastPolicy.FP32,
+    torch.ops.aten.triangular_solve.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool2d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.fractional_max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.adaptive_max_pool3d.default:
+        CastPolicy.FP32,
+    torch.ops.aten.multilabel_margin_loss_forward.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_qr.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_cholesky_ex.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_svd.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eig.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_eigh.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_lstsq.default:
+        CastPolicy.FP32,
+    torch.ops.aten.linalg_inv_ex.default:
+        CastPolicy.FP32,
+
+    # promote
+    torch.ops.aten.stack.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.cat.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.default:
+        CastPolicy.PROMOTE,
+    torch.ops.aten.index_copy.dimname:
+        CastPolicy.PROMOTE,
+}
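
The new amp module maps each aten op to a CastPolicy; execute_policy applies a policy to a set of arguments, and autocast temporarily swaps the environment's autocast_dtype. Below is a minimal sketch of exercising these pieces directly; how the environment consults autocast_policy during op dispatch is not shown in this diff, so that part is left out.

import torch
from torchax import amp

# Look up the cast policy for an op and apply it to example inputs.
policy = amp.autocast_policy[torch.ops.aten.mm.default]  # LOWER_PRECISION_FP
args, kwargs = (torch.randn(4, 4), torch.randn(4, 4)), {}
new_args, new_kwargs = amp.execute_policy(policy, args, kwargs, torch.bfloat16)
# Floating-point tensors in new_args are now bfloat16; non-float inputs are untouched.

# The context manager only flips the environment's autocast_dtype for its duration.
with amp.autocast('jax', dtype=torch.bfloat16):
  pass  # ops dispatched here see autocast_dtype == torch.bfloat16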
torchax/config.py CHANGED
@@ -3,17 +3,24 @@ import dataclasses
 
 @dataclasses.dataclass
 class Configuration:
-  debug_print_each_op: bool = False
-  debug_accuracy_for_each_op: bool = False
-  debug_mixed_tensor: bool = False
-  debug_print_each_op_operands: bool = False
-  use_int32_for_index: bool = False
+  debug_print_each_op: bool = False
+  debug_accuracy_for_each_op: bool = False
+  debug_mixed_tensor: bool = False
+  debug_print_each_op_operands: bool = False
 
-  # Flash attention
-  use_tpu_flash_attention: bool = False
-  shmap_flash_attention: bool = False
+  use_int32_for_index: bool = False
 
-  # device
-  treat_cuda_as_jax_device: bool = True
-  use_torch_native_for_cpu_tensor: bool = True
-  internal_respect_torch_return_dtypes: bool = False
+  # If true, we will convert Views into torchax.Tensors eagerly
+  force_materialize_views: bool = False
+
+  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
+  use_dlpack_for_data_conversion: bool = False
+
+  # Flash attention
+  use_tpu_flash_attention: bool = False
+  shmap_flash_attention: bool = False
+
+  # device
+  treat_cuda_as_jax_device: bool = True
+  use_torch_native_for_cpu_tensor: bool = True
+  internal_respect_torch_return_dtypes: bool = False
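
The two new flags land alongside the existing ones on the Configuration dataclass. Assuming the default environment's config attribute is an instance of this dataclass (which is how enable_accuracy_mode() in the __init__.py diff above treats it), they can be toggled as in this sketch; this is an illustration, not an API guarantee from the package.

import torchax

env = torchax.default_env()
env.config.force_materialize_views = True         # eagerly convert views to torchax.Tensors
env.config.use_dlpack_for_data_conversion = True  # use DLPack for jax <-> torch conversion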