torchax 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic. Click here for more details.

@@ -0,0 +1,38 @@
1
+ # Contributing to TorchXLA2
2
+
3
+ We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from good first issue and help wanted labels.
4
+
5
+ If you plan to contribute new features, utility functions or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might end up resulting in a rejected PR, because we might be taking the core in a different direction than you might be aware of.
6
+
7
+
8
+ # Developer setup
9
+
10
+ ## Mac setup:
11
+ @qihqi
12
+
13
+ I am able to develop directly on a Mac (M1) laptop for the most part. Using the steps
14
+ in README.md works. The condensed version for easy copy & paste:
15
+
16
+ ```bash
17
+ conda create --name <your_name> python=3.10
18
+ conda activate <your_name>
19
+ pip install --upgrade "jax[cpu]" torch
20
+ pip install -r test_requirements.txt
21
+ pip install -e .
22
+ pytest test
23
+ ```
24
+
25
+ ### VSCode
26
+
27
+ I use VSCode on my Mac. I loosely followed the instructions in
28
+ https://code.visualstudio.com/docs/python/python-tutorial
29
+ to set up a proper Python environment.
30
+
31
+ The plugins I installed (a subset of the ones listed above) are:
32
+ * VSCode's official Python plugin
33
+ * Ruff formatter
34
+ * Python Debugger
35
+
36
+ I also changed Python interpreter to point at the one in my conda env.
37
+ That is all the changes I have.
38
+
torchax/__init__.py ADDED
@@ -0,0 +1,124 @@
1
+ import contextlib
2
+ from typing import List, Dict, Any, Optional
3
+ import dataclasses
4
+ import jax
5
+ import os
6
+ import torch
7
+ from torch.utils import _pytree as pytree
8
+ from torchax import tensor
9
+ from torchax import distributed # noqa: F401
10
+
11
+ __version__ = "0.0.4"
12
+ VERSION = __version__
13
+
14
+ __all__ = [
15
+ 'default_env',
16
+ 'extract_jax',
17
+ 'enable_globally',
18
+ ]
19
+
20
+ from jax._src import xla_bridge
21
+ os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')
22
+
23
+ # torchax:oss-begin
24
+ if getattr(jax.config, 'jax_pjrt_client_create_options', None):
25
+ jax.config.update(
26
+ 'jax_pjrt_client_create_options',
27
+ f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
28
+ )
29
+ # torchax:oss-end
30
+
31
# Lazily created, module-wide environment; populated by default_env().
env = None


def default_env():
    """Return the shared default environment, building it on first use.

    The instance is cached in the module-level ``env`` variable, so every
    later call hands back the same ``tensor.Environment`` object.
    """
    global env
    if env is not None:
        return env
    env = tensor.Environment()
    return env
38
+
39
+
40
+
41
def extract_jax(mod: torch.nn.Module, env=None):
    """Return the module's state as a JAX pytree plus a JAX-facing callable.

    Args:
        mod: The module whose ``state_dict()`` and forward call are extracted.
        env: Environment used for the conversions; when ``None`` the result
            of ``default_env()`` is used.

    Returns:
        A ``(states, jax_func)`` pair. ``states`` is ``mod.state_dict()`` with
        every ``torch.Tensor`` leaf passed through ``tensor.t2j``.
        ``jax_func(states, inputs)`` converts both arguments with
        ``env.j2t_iso``, runs ``torch.func.functional_call`` on ``mod`` inside
        the environment, and converts the result back with ``env.t2j_iso``.
    """
    active_env = default_env() if env is None else env
    jax_states = pytree.tree_map_only(torch.Tensor, tensor.t2j, mod.state_dict())

    def jax_func(states, inputs):
        torch_states, torch_inputs = active_env.j2t_iso((states, inputs))
        with active_env:
            result = torch.func.functional_call(
                mod, torch_states, torch_inputs, tie_weights=False
            )
        return active_env.t2j_iso(result)

    return jax_states, jax_func
57
+
58
def enable_globally():
    """Switch on the torch modes of the default environment.

    Returns:
        Whatever ``enable_torch_modes()`` returns for the default environment.
    """
    return default_env().enable_torch_modes()
61
+
62
def disable_globally():
    """Switch off the torch modes of the default environment."""
    default_environment = default_env()
    default_environment.disable_torch_modes()
65
+
66
@contextlib.contextmanager
def disable_temporarily():
    """Context manager that turns the default environment off for a block.

    If the default environment reports ``enabled`` on entry, it is disabled
    for the duration of the ``with`` body and re-enabled on exit. If it was
    not enabled, nothing is changed in either direction.

    Yields:
        An empty tuple (the value the original implementation yielded).
    """
    was_enabled = default_env().enabled
    if was_enabled:
        disable_globally()
    try:
        yield ()
    finally:
        # Restore in ``finally`` so an exception raised inside the ``with``
        # body cannot leave the environment permanently disabled.
        if was_enabled:
            enable_globally()
74
+
75
+
76
+ torch.utils.rename_privateuse1_backend('jax')
77
+ unsupported_dtype = [torch.quint8]
78
+ torch.utils.generate_methods_for_privateuse1_backend(
79
+ for_tensor=True, for_module=True, for_storage=True,
80
+ unsupported_dtype=unsupported_dtype)
81
+
82
+ import jax
83
+ import torchax.device_module
84
+ torch._register_device_module('jax', torchax.device_module)
85
+
86
+
87
+
88
+
89
def enable_accuracy_mode():
    """Configure JAX and the default environment to favour accuracy.

    Enables ``jax_enable_x64``, sets ``jax_default_matmul_precision`` to
    ``'highest'`` and sets ``internal_respect_torch_return_dtypes`` to True
    on the default environment's config.
    """
    jax_settings = (
        ('jax_enable_x64', True),
        ('jax_default_matmul_precision', 'highest'),
    )
    for option, value in jax_settings:
        jax.config.update(option, value)
    default_env().config.internal_respect_torch_return_dtypes = True
93
+
94
+
95
def enable_performance_mode():
    """Configure JAX and the default environment to favour speed.

    Disables ``jax_enable_x64``, sets ``jax_default_matmul_precision`` to
    ``'default'`` and sets ``internal_respect_torch_return_dtypes`` to False
    on the default environment's config.
    """
    jax_settings = (
        ('jax_enable_x64', False),
        ('jax_default_matmul_precision', 'default'),
    )
    for option, value in jax_settings:
        jax.config.update(option, value)
    default_env().config.internal_respect_torch_return_dtypes = False
99
+
100
+
101
+
102
@dataclasses.dataclass
class CompileOptions:
    """Options read by ``compile``."""

    # Names of the methods to jit; only valid when compiling an nn.Module.
    methods_to_compile: List[str] = dataclasses.field(
        default_factory=lambda: ['forward'],
    )
    # Passed as ``extra_jit_args`` when an nn.Module is wrapped for jitting.
    jax_jit_kwargs: Dict[str, Any] = dataclasses.field(
        default_factory=dict,
    )
    # Compilation mode: 'jax', 'dynamo' or 'export'.
    mode: str = 'jax'
108
+
109
+
110
def compile(fn, options: Optional["CompileOptions"] = None):
    """Compile a callable or an ``nn.Module`` according to ``options``.

    Args:
        fn: A ``torch.nn.Module`` or a plain callable.
        options: Compilation options; a default ``CompileOptions()`` is used
            when omitted (or falsy).

    Returns:
        In ``'jax'`` mode: for a module, an ``interop.JittableModule`` with
        every method named in ``options.methods_to_compile`` jitted; for any
        other callable, the result of ``interop.jax_jit(fn)``.

    Raises:
        RuntimeError: If ``options.mode`` is ``'dynamo'`` or ``'export'``,
            which are not supported yet.
        ValueError: If ``options.mode`` is not a recognised mode. Previously
            this case fell through and silently returned ``None``.
    """
    options = options or CompileOptions()
    mode = options.mode

    if mode == 'jax':
        # Imported lazily, as in the original, so it is only loaded on use.
        from torchax import interop

        if not isinstance(fn, torch.nn.Module):
            return interop.jax_jit(fn)

        module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
        for method_name in options.methods_to_compile:
            module.make_jitted(method_name)
        return module

    if mode in ('dynamo', 'export'):
        raise RuntimeError(f'{mode} mode is not supported yet')

    raise ValueError(
        f"Unknown compile mode {mode!r}; expected 'jax', 'dynamo' or 'export'"
    )
torchax/config.py ADDED
@@ -0,0 +1,19 @@
1
+ import dataclasses
2
+
3
+
4
@dataclasses.dataclass
class Configuration:
    """Boolean feature switches; every flag defaults as listed below."""

    # Debugging switches (all off by default).
    debug_print_each_op: bool = False
    debug_accuracy_for_each_op: bool = False
    debug_mixed_tensor: bool = False
    debug_print_each_op_operands: bool = False

    use_int32_for_index: bool = False

    # Flash attention
    use_tpu_flash_attention: bool = False
    shmap_flash_attention: bool = False

    # device
    treat_cuda_as_jax_device: bool = True
    use_torch_native_for_cpu_tensor: bool = True
    # Toggled by enable_accuracy_mode() / enable_performance_mode().
    internal_respect_torch_return_dtypes: bool = False
@@ -0,0 +1,308 @@
1
+ """This file contains some decompositions that are not available in torch stable.
2
+
3
+ Most likely from Content of
4
+ https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
5
+ at main branch HEAD that we find useful here.
6
+
7
+ Can also contain decompositions of a torch op in terms of other torch ops.
8
+ """
9
+
10
+ import functools
11
+ from typing import Any, Callable, List, Tuple
12
+
13
+ import torch
14
+ from torch import Tensor
15
+ import torch._decomp as decomp
16
+ from torch._decomp import decompositions_for_rng
17
+ from torch._decomp import register_decomposition
18
+ import torch._prims_common as utils
19
+ from torch._prims_common.wrappers import out_wrapper
20
+
21
+
22
+ DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
23
+
24
+ # None of these functions are publicly accessible; get at them
25
+ # from torch._decomps
26
+ __all__: List[str] = []
27
+
28
+ aten = torch._ops.ops.aten
29
+
30
def _try_register(op, impl):
    """Register ``impl`` as the decomposition for ``op``, ignoring failures.

    Registration is best-effort: if ``register_decomposition`` raises, the
    error is dropped and the module keeps loading.

    Args:
        op: The operator to register the decomposition for.
        impl: The callable implementing the decomposition.
    """
    try:
        register_decomposition(op)(impl)
    except Exception:
        # Deliberately silent, but narrowed from a bare ``except:`` so that
        # KeyboardInterrupt and SystemExit are no longer swallowed.
        pass
35
+
36
@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    """Pad ``a`` by reflecting positions that fall outside each padded axis.

    Supplies the per-axis index function and delegates the gathering to
    ``_reflection_or_replication_pad``.
    """

    def mirrored_index(left, middle, right):
        # Positions run from -left to middle + right - 1; out-of-range ones
        # are folded back into [0, middle - 1].
        positions = torch.arange(-left, middle + right, device=a.device)
        last = middle - 1
        return last - (last - positions.abs()).abs()

    return _reflection_or_replication_pad(a, padding, mirrored_index)
47
+
48
+ _try_register(aten.reflection_pad1d, _reflection_pad)
49
+ _try_register(aten.reflection_pad2d, _reflection_pad)
50
+ _try_register(aten.reflection_pad3d, _reflection_pad)
51
+
52
@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
    """Pad ``a`` by repeating the edge element of each padded axis.

    Supplies the per-axis index function and delegates the gathering to
    ``_reflection_or_replication_pad``.
    """

    def clamped_index(left, middle, right):
        # Positions run from -left to middle + right - 1; out-of-range ones
        # are clamped to the nearest valid index.
        positions = torch.arange(-left, middle + right, device=a.device)
        return torch.clamp(positions, 0, middle - 1)

    return _reflection_or_replication_pad(a, padding, clamped_index)
63
+
64
+ decomp.global_decomposition_table['post_autograd'][aten.replication_pad2d.default] = _replication_pad
65
+
66
+
67
+ def _reflection_or_replication_pad(
68
+ a: Tensor,
69
+ padding: Tuple[int, ...],
70
+ idx_fn: Callable[[int, int, int], Tensor],
71
+ ) -> Tensor:
72
+ dim = len(padding) // 2
73
+ torch._check(
74
+ a.dim() in (dim + 1, dim + 2),
75
+ lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
76
+ )
77
+ inp_shape = a.shape[-dim:]
78
+ nc_dim = a.dim() - dim
79
+
80
+ padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
81
+ padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]
82
+
83
+ result = a
84
+ for i in range(dim):
85
+ idx: List[Any] = [None] * result.dim()
86
+ idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
87
+ result = aten._unsafe_index(result, idx)
88
+
89
+ # convert output to correct memory format, if necessary
90
+ memory_format = utils.suggest_memory_format(result)
91
+ result = result.contiguous(memory_format=memory_format)
92
+ return result
93
+
94
+ _try_register(aten.replication_pad1d, _replication_pad)
95
+ _try_register(aten.replication_pad3d, _replication_pad)
96
+
97
def bernoulli(self, *, generator=None):
    """Draw Bernoulli samples using ``self`` as the success probabilities.

    A float32 uniform sample is compared against ``self`` and the boolean
    outcome is cast back to ``self.dtype``. ``generator`` is accepted but
    not used.
    """
    uniform = torch.rand_like(self, dtype=torch.float32)
    hits = uniform < self
    return hits.to(self.dtype)
99
+
100
+ _try_register(aten.bernoulli.default, bernoulli)
101
+
102
+
103
def rand_like(self, **kwargs):
    """Return uniform random values with the same shape as ``self``.

    Args:
        self: Tensor whose shape (and, by default, dtype) is reused.
        **kwargs: Only ``dtype`` is consulted; other keys are ignored.

    Returns:
        The result of ``torch.rand`` for ``self.shape``, in the requested
        dtype or in ``self.dtype`` when none is requested.
    """
    dtype = kwargs.get('dtype')
    if dtype is None:
        # An explicit ``dtype=None`` means "same dtype as the input"; without
        # this check ``torch.rand`` would use the global default dtype.
        dtype = self.dtype
    return torch.rand(self.shape, dtype=dtype)
106
+
107
def channel_shuffle(self, groups):
    """Interleave the channels of ``self`` across ``groups`` groups.

    The channel axis (axis 1) is split into ``groups`` blocks, the group and
    per-group axes are swapped, and the result is flattened back to the
    input shape.

    Args:
        self: Tensor shaped ``(N, C, *spatial)``. Any number of trailing
            axes is accepted; the original handled only ``(N, C, H, W)``.
        groups: Number of channel groups.

    Returns:
        A tensor with the same shape as ``self`` and shuffled channels.
    """
    batchsize, channels, *spatial = self.shape
    channels_per_group = channels // groups
    grouped = self.reshape(batchsize, groups, channels_per_group, *spatial)
    return grouped.transpose(1, 2).reshape(batchsize, channels, *spatial)
114
+
115
+ _try_register(aten.channel_shuffle, channel_shuffle)
116
+
117
+ _try_register(aten.bernoulli, bernoulli)
118
+ _try_register(aten.rand_like, rand_like)
119
+
120
def bernoulli_float(self, p=0.5):
    """Fill ``self`` in place with Bernoulli samples of probability ``p``.

    ``p`` is wrapped in a tensor and passed to ``self.bernoulli_``, whose
    result is returned.
    """
    probability = torch.tensor(p)
    return self.bernoulli_(probability)
122
+
123
+ _try_register(aten.bernoulli_.float, bernoulli_float)
124
+ _try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_)
125
+
126
+
127
+
128
def _sum_tensors(ts) -> Tensor:
    """Fold an iterable of tensors into one, left to right, with ``torch.add``."""
    return functools.reduce(lambda total, term: torch.add(total, term), ts)
130
+
131
+
132
@register_decomposition(aten.grid_sampler_3d)
def _grid_sampler_3d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> Tensor:
    """Decomposition of ``aten.grid_sampler_3d`` into indexing and arithmetic.

    Reference (implements the 2d case; this is the 3d extension):
    https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075

    Args:
        a: 5D input, unpacked as ``[N, C, D, H, W]``.
        grid: 5D sampling grid, unpacked as ``[_, oD, oH, oW, 3]``. The last
            axis is read as ``(x, y, z)`` and indexes ``W``, ``H`` and ``D``.
        interpolation_mode: 0 for linear interpolation over the eight
            surrounding voxels, 1 for nearest.
        padding_mode: 0 for zeros, 1 for border, 2 for reflection.
        align_corners: Selects how coordinates in ``[-1, 1]`` are mapped to
            voxel positions (see ``unnormalize`` below).

    Returns:
        A tensor of shape ``[N, C, oD, oH, oW]``.

    Raises:
        RuntimeError: Via ``torch._check`` for an invalid interpolation or
            padding mode, or when the last axis of ``grid`` is not 3.
    """
    _expand_grid = False
    torch._check(
        interpolation_mode in (0, 1),
        lambda: f"Invalid interpolation mode {interpolation_mode}",
    )
    torch._check(
        padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
    )

    # a is 5D: [B, C, D, H, W]

    def unnormalize(coords: Tensor, size: int) -> Tensor:
        # Rescale coordinates from [-1, 1] to:
        #   [0, size - 1] if align_corners is True
        #   [-.5, size -.5] if align_corners is False
        mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
        ofs = size * 0.5 - 0.5
        return coords * mul + ofs

    # Reflects coordinates until they fall between low and high (inclusive).
    # The bounds are passed as twice their value so that half-integer values
    # can be represented as ints.
    def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
        if twice_low == twice_high:
            return torch.zeros_like(coords)
        coords_min = twice_low / 2
        coords_span = (twice_high - twice_low) / 2
        coords2 = (coords - coords_min).abs()
        extra = torch.fmod(coords2, coords_span)
        flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
        return torch.where(
            flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
        )

    def compute_coordinates(coords: Tensor, size: int) -> Tensor:
        if padding_mode == 0:  # Zero
            return coords
        elif padding_mode == 1:  # Borders
            return torch.clamp(coords, 0, size - 1)
        else:  # padding_mode == 2, Reflection
            if align_corners:
                coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
            else:
                coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
            return torch.clamp(coords_reflected, 0, size - 1)

    def compute_source_index(coords: Tensor, size: int) -> Tensor:
        coords_un = unnormalize(coords, size)
        return compute_coordinates(coords_un, size)

    N, C, iD, iH, iW = a.shape
    _, oD, oH, oW, three = grid.shape
    # Validated with torch._check rather than ``assert`` so the check is not
    # stripped when Python runs with -O.
    torch._check(
        three == 3,
        lambda: 'Last dim of grid must be 3. got {}'.format(three),
    )

    def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
        xcheck = torch.logical_and(0 <= xs, xs < iW)
        ycheck = torch.logical_and(0 <= ys, ys < iH)
        zcheck = torch.logical_and(0 <= zs, zs < iD)
        return torch.logical_and(
            xcheck, torch.logical_and(ycheck, zcheck)
        )

    N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
    C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)

    def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
        cond = in_bounds_cond(xs, ys, zs)
        # To clip to inside valid coordinates, we map out-of-bounds
        # coordinates to (x, y, z) = (0, 0, 0) and also set the weight to 0.
        # We also change the shape of the tensor to the appropriate one for
        # broadcasting with N_idx, C_idx for the purposes of advanced indexing
        c = C if _expand_grid else 1
        return tuple(
            torch.where(cond, t, 0).view(N, c, oD, oH, oW)
            for t in (
                xs.to(dtype=torch.int64),
                ys.to(dtype=torch.int64),
                zs.to(dtype=torch.int64),
                ws,
            )
        )

    def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
        # Perform clipping, index into input tensor and multiply by weight
        idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
        return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_

    x = grid[..., 0]
    y = grid[..., 1]
    d = grid[..., 2]

    if interpolation_mode == 0:  # Linear over the 8 surrounding voxels
        ix = compute_source_index(x, iW)
        iy = compute_source_index(y, iH)
        id_ = compute_source_index(d, iD)

        # Corner naming: n/s selects y, w/e selects x, f/b selects depth.
        ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
        ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
        ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
        ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
        ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
        ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
        ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
        ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1

        # Each corner's weight is the product of the distances to the
        # diagonally opposite corner along x, y and depth.
        w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
        w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
        w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
        w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
        w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
        w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
        w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
        w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)

        corners = (
            (ix_nwf, iy_nwf, id_nwf, w_nwf),
            (ix_nef, iy_nef, id_nef, w_nef),
            (ix_swf, iy_swf, id_swf, w_swf),
            (ix_sef, iy_sef, id_sef, w_sef),
            (ix_nwb, iy_nwb, id_nwb, w_nwb),
            (ix_neb, iy_neb, id_neb, w_neb),
            (ix_swb, iy_swb, id_swb, w_swb),
            (ix_seb, iy_seb, id_seb, w_seb),
        )
        return _sum_tensors(
            get_summand(cx, cy, cz, cw) for (cx, cy, cz, cw) in corners
        )
    else:  # interpolation_mode == 1: Nearest
        ix = compute_source_index(x, iW)
        iy = compute_source_index(y, iH)
        iz = compute_source_index(d, iD)

        ix_nearest = ix.round()
        iy_nearest = iy.round()
        iz_nearest = iz.round()

        return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
277
+
278
+ EXTRA_DECOMP = decomp.get_decompositions([
279
+ torch.ops.aten.upsample_bicubic2d,
280
+ torch.ops.aten.upsample_nearest1d,
281
+ torch.ops.aten.upsample_nearest2d,
282
+ torch.ops.aten.upsample_nearest3d,
283
+ torch.ops.aten._upsample_nearest_exact1d,
284
+ torch.ops.aten._upsample_nearest_exact2d,
285
+ torch.ops.aten._upsample_nearest_exact3d,
286
+ torch.ops.aten._native_batch_norm_legit.no_stats,
287
+ torch.ops.aten._native_batch_norm_legit_functional.default,
288
+ torch.ops.aten._adaptive_avg_pool2d,
289
+ torch.ops.aten._adaptive_avg_pool3d,
290
+ torch.ops.aten.grid_sampler_2d,
291
+ torch.ops.aten.grid_sampler_3d,
292
+ torch.ops.aten.native_dropout,
293
+ torch.ops.aten.reflection_pad1d,
294
+ torch.ops.aten.reflection_pad2d,
295
+ torch.ops.aten.reflection_pad3d,
296
+ torch.ops.aten.replication_pad1d,
297
+ torch.ops.aten.replication_pad2d,
298
+ torch.ops.aten.replication_pad3d,
299
+ torch.ops.aten.bernoulli,
300
+ torch.ops.aten.rand_like,
301
+ torch.ops.aten._batch_norm_with_update,
302
+ torch.ops.aten.channel_shuffle,
303
+ torch.ops.aten.nll_loss2d_forward,
304
+ torch.ops.aten.nll_loss2d_backward,
305
+ torch.ops.aten.bernoulli_.Tensor,
306
+ torch.ops.aten.bernoulli_.float,
307
+ torch.ops.aten.log_normal,
308
+ ])
@@ -0,0 +1,20 @@
1
def _is_in_bad_fork():
    """Always report that the process is not in a bad fork."""
    return False
3
+
4
def manual_seed_all(seed):
    """Accept a seed and do nothing with it; returns ``None``."""
    return None
6
+
7
def device_count():
    """Report a fixed device count of 1."""
    return 1
9
+
10
def get_rng_state(device=None):
    """Return the RNG state placeholder, which is always an empty list.

    Args:
        device: Ignored. Accepted so callers that pass a device, as in
            ``get_rng_state(device)``, do not fail with a ``TypeError``.
    """
    return []
12
+
13
def set_rng_state(new_state, device=None):
    """Accept an RNG state and discard it; returns ``None``.

    Args:
        new_state: Ignored.
        device: Ignored. Now optional, so the call also works without a
            device argument.
    """
    return None
15
+
16
def is_available():
    """Always report the device as available."""
    return True
18
+
19
def current_device():
    """Report device index 0 as the current device."""
    return 0