torchax 0.0.4 (torchax-0.0.4-py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of torchax might be problematic.
- torchax/CONTRIBUTING.md +38 -0
- torchax/__init__.py +124 -0
- torchax/config.py +19 -0
- torchax/decompositions.py +308 -0
- torchax/device_module.py +20 -0
- torchax/distributed.py +246 -0
- torchax/environment.py +2 -0
- torchax/export.py +236 -0
- torchax/interop.py +209 -0
- torchax/ops/__init__.py +10 -0
- torchax/ops/jaten.py +5212 -0
- torchax/ops/jax_reimplement.py +169 -0
- torchax/ops/jc10d.py +51 -0
- torchax/ops/jlibrary.py +73 -0
- torchax/ops/jtorch.py +427 -0
- torchax/ops/jtorchvision_nms.py +245 -0
- torchax/ops/mappings.py +97 -0
- torchax/ops/op_base.py +104 -0
- torchax/ops/ops_registry.py +50 -0
- torchax/tensor.py +557 -0
- torchax/tf_integration.py +119 -0
- torchax/train.py +120 -0
- torchax/types.py +12 -0
- torchax-0.0.4.dist-info/METADATA +341 -0
- torchax-0.0.4.dist-info/RECORD +27 -0
- torchax-0.0.4.dist-info/WHEEL +4 -0
- torchax-0.0.4.dist-info/licenses/LICENSE +28 -0
torchax/CONTRIBUTING.md
ADDED
@@ -0,0 +1,38 @@
# Contributing to TorchXLA2

We appreciate all contributions. If you are planning to contribute a bug fix for an open issue, please comment on the thread and we're happy to provide any guidance. You are very welcome to pick issues from the good first issue and help wanted labels.

If you plan to contribute new features, utility functions, or extensions to the core, please first open an issue and discuss the feature with us. Sending a PR without discussion might result in a rejected PR, because we might be taking the core in a different direction than you are aware of.

# Developer setup

## Mac setup:
@qihqi

I am able to develop directly on a Mac (M1) laptop for most parts. Following the steps
in README.md works. The condensed version for easy copy & paste:

```bash
conda create --name <your_name> python=3.10
conda activate <your_name>
pip install --upgrade "jax[cpu]" torch
pip install -r test_requirements.txt
pip install -e .
pytest test
```

### VSCode

I use VSCode on my Mac. I loosely followed the instructions in
https://code.visualstudio.com/docs/python/python-tutorial
to set up a proper Python environment.

The plugins I installed (a subset of the ones listed above) are:
* VSCode's official Python plugin
* Ruff formatter
* Python Debugger

I also changed the Python interpreter to point at the one in my conda env.
Those are all the changes I have.
torchax/__init__.py
ADDED
@@ -0,0 +1,124 @@
import contextlib
from typing import List, Dict, Any, Optional
import dataclasses
import jax
import os
import torch
from torch.utils import _pytree as pytree
from torchax import tensor
from torchax import distributed  # noqa: F401

__version__ = "0.0.4"
VERSION = __version__

__all__ = [
    'default_env',
    'extract_jax',
    'enable_globally',
]

from jax._src import xla_bridge
os.environ.setdefault('ENABLE_RUNTIME_UPTIME_TELEMETRY', '1')

# torchax:oss-begin
if getattr(jax.config, 'jax_pjrt_client_create_options', None):
  jax.config.update(
      'jax_pjrt_client_create_options',
      f'ml_framework_name:PyTorch/XLA2;ml_framework_version:{"v0.0.1"}'
  )
# torchax:oss-end

env = None
def default_env():
  global env

  if env is None:
    env = tensor.Environment()
  return env


def extract_jax(mod: torch.nn.Module, env=None):
  """Returns a pytree of jax.ndarray and a jax callable."""
  if env is None:
    env = default_env()
  states = mod.state_dict()

  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)

  #@jax.jit
  def jax_func(states, inputs):
    (states, inputs) = env.j2t_iso((states, inputs))
    with env:
      res = torch.func.functional_call(mod, states, inputs, tie_weights=False)
    return env.t2j_iso(res)

  return states, jax_func

def enable_globally():
  env = default_env().enable_torch_modes()
  return env

def disable_globally():
  global env
  default_env().disable_torch_modes()

@contextlib.contextmanager
def disable_temporarily():
  prev = default_env().enabled
  if prev:
    disable_globally()
  yield ()
  if prev:
    enable_globally()


torch.utils.rename_privateuse1_backend('jax')
unsupported_dtype = [torch.quint8]
torch.utils.generate_methods_for_privateuse1_backend(
    for_tensor=True, for_module=True, for_storage=True,
    unsupported_dtype=unsupported_dtype)

import jax
import torchax.device_module
torch._register_device_module('jax', torchax.device_module)


def enable_accuracy_mode():
  jax.config.update('jax_enable_x64', True)
  jax.config.update('jax_default_matmul_precision', 'highest')
  default_env().config.internal_respect_torch_return_dtypes = True


def enable_performance_mode():
  jax.config.update('jax_enable_x64', False)
  jax.config.update('jax_default_matmul_precision', 'default')
  default_env().config.internal_respect_torch_return_dtypes = False


@dataclasses.dataclass
class CompileOptions:
  # only valid if compiling nn.Module
  methods_to_compile: List[str] = dataclasses.field(default_factory=lambda: ['forward'])
  jax_jit_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
  mode: str = 'jax'  # or dynamo or export


def compile(fn, options: Optional[CompileOptions] = None):
  options = options or CompileOptions()
  if options.mode == 'jax':
    from torchax import interop
    if isinstance(fn, torch.nn.Module):
      module = interop.JittableModule(fn, extra_jit_args=options.jax_jit_kwargs)
      for n in options.methods_to_compile:
        module.make_jitted(n)
      return module
    else:
      return interop.jax_jit(fn)
  elif options.mode == 'dynamo':
    raise RuntimeError('dynamo mode is not supported yet')
  elif options.mode == 'export':
    raise RuntimeError('export mode is not supported yet')
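As an orientation aid (not part of the wheel), here is a minimal sketch of how the two entry points above might be used. The `Linear` module, input shape, and variable names are made up for illustration:

```python
# Hypothetical usage sketch for torchax.extract_jax / torchax.compile.
import jax
import torch
import torchax

model = torch.nn.Linear(4, 2)

# extract_jax returns the state_dict converted to jax arrays plus a pure
# callable that replays the module's forward pass on jax inputs.
states, jax_func = torchax.extract_jax(model)
inputs = (jax.numpy.ones((1, 4)),)      # args tuple passed to functional_call
out = jax_func(states, inputs)          # designed to also work under jax.jit

# compile() in the default 'jax' mode wraps the module so that 'forward'
# is executed through interop.JittableModule's jitted method.
jitted_model = torchax.compile(model)
```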
torchax/config.py
ADDED
@@ -0,0 +1,19 @@
import dataclasses


@dataclasses.dataclass
class Configuration:
  debug_print_each_op: bool = False
  debug_accuracy_for_each_op: bool = False
  debug_mixed_tensor: bool = False
  debug_print_each_op_operands: bool = False
  use_int32_for_index: bool = False

  # Flash attention
  use_tpu_flash_attention: bool = False
  shmap_flash_attention: bool = False

  # device
  treat_cuda_as_jax_device: bool = True
  use_torch_native_for_cpu_tensor: bool = True
  internal_respect_torch_return_dtypes: bool = False
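These flags are read from the active environment; `torchax/__init__.py` above already writes `default_env().config.internal_respect_torch_return_dtypes`, so a sketch of toggling another flag (assuming the same `config` attribute) could look like:

```python
# Hypothetical sketch: flip Configuration flags on the default environment.
# Flag names come from config.py above; the effect of each flag is not shown here.
import torchax

env = torchax.default_env()
env.config.debug_print_each_op = True        # presumably logs each dispatched op
env.config.treat_cuda_as_jax_device = False  # opt out of remapping 'cuda' tensors
```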
torchax/decompositions.py
ADDED
@@ -0,0 +1,308 @@
"""This file contains some decompositions that are not available in torch stable.

Most likely from the content of
https://github.com/pytorch/pytorch/blob/main/torch/_decomp/decompositions.py
at main branch HEAD that we find useful here.

Can also contain decompositions of a torch op in terms of other torch ops.
"""

import functools
from typing import Any, Callable, List, Tuple

import torch
from torch import Tensor
import torch._decomp as decomp
from torch._decomp import decompositions_for_rng
from torch._decomp import register_decomposition
import torch._prims_common as utils
from torch._prims_common.wrappers import out_wrapper


DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]

# None of these functions are publicly accessible; get at them
# from torch._decomps
__all__: List[str] = []

aten = torch._ops.ops.aten

def _try_register(op, impl):
  try:
    register_decomposition(op)(impl)
  except:
    pass

@out_wrapper()
def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
  def idx(left, middle, right):
    dim_idx = torch.arange(-left, middle + right, device=a.device)
    return middle - 1 - (middle - 1 - dim_idx.abs()).abs()

  return _reflection_or_replication_pad(
      a,
      padding,
      idx,
  )

_try_register(aten.reflection_pad1d, _reflection_pad)
_try_register(aten.reflection_pad2d, _reflection_pad)
_try_register(aten.reflection_pad3d, _reflection_pad)

@out_wrapper()
def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
  def idx(left, middle, right):
    dim_idx = torch.arange(-left, middle + right, device=a.device)
    return torch.clamp(dim_idx, 0, middle - 1)

  return _reflection_or_replication_pad(
      a,
      padding,
      idx,
  )

decomp.global_decomposition_table['post_autograd'][aten.replication_pad2d.default] = _replication_pad


def _reflection_or_replication_pad(
    a: Tensor,
    padding: Tuple[int, ...],
    idx_fn: Callable[[int, int, int], Tensor],
) -> Tensor:
  dim = len(padding) // 2
  torch._check(
      a.dim() in (dim + 1, dim + 2),
      lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
  )
  inp_shape = a.shape[-dim:]
  nc_dim = a.dim() - dim

  padding_left = [padding[2 * (dim - 1 - i)] for i in range(dim)]
  padding_right = [padding[2 * (dim - 1 - i) + 1] for i in range(dim)]

  result = a
  for i in range(dim):
    idx: List[Any] = [None] * result.dim()
    idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
    result = aten._unsafe_index(result, idx)

  # convert output to correct memory format, if necessary
  memory_format = utils.suggest_memory_format(result)
  result = result.contiguous(memory_format=memory_format)
  return result

_try_register(aten.replication_pad1d, _replication_pad)
_try_register(aten.replication_pad3d, _replication_pad)

def bernoulli(self, *, generator=None):
  return (torch.rand_like(self, dtype=torch.float32) < self).to(self.dtype)

_try_register(aten.bernoulli.default, bernoulli)


def rand_like(self, **kwargs):
  dtype = kwargs.get('dtype', self.dtype)
  return torch.rand(self.shape, dtype=dtype)

def channel_shuffle(self, groups):
  batchsize, channels, height, width = self.shape
  channels_per_group = channels // groups
  self = self.reshape(batchsize, groups, channels_per_group, height, width)
  self = self.transpose(1, 2)
  self = self.reshape(batchsize, channels, height, width)
  return self

_try_register(aten.channel_shuffle, channel_shuffle)

_try_register(aten.bernoulli, bernoulli)
_try_register(aten.rand_like, rand_like)

def bernoulli_float(self, p=0.5):
  return self.bernoulli_(torch.tensor(p))

_try_register(aten.bernoulli_.float, bernoulli_float)
_try_register(aten.bernoulli_.Tensor, decompositions_for_rng.bernoulli_)


def _sum_tensors(ts) -> Tensor:
  return functools.reduce(torch.add, ts)


@register_decomposition(aten.grid_sampler_3d)
def _grid_sampler_3d(
    a: torch.Tensor,
    grid: torch.Tensor,
    interpolation_mode: int = 0,
    padding_mode: int = 0,
    align_corners: bool = False,
) -> Tensor:
  """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075

  The above implements the 2d case.
  """
  _expand_grid = False
  torch._check(
      interpolation_mode in (0, 1),
      lambda: f"Invalid interpolation mode {interpolation_mode}",
  )
  torch._check(
      padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
  )

  # a is 5D: [B, C, D, H, W]

  def unnormalize(coords: Tensor, size: int) -> Tensor:
    # Rescale coordinates from [-1, 1] to:
    #   [0, size - 1] if align_corners is True
    #   [-.5, size - .5] if align_corners is False
    mul = (size * 0.5 - 0.5) if align_corners else (size * 0.5)
    ofs = size * 0.5 - 0.5
    return coords * mul + ofs

  # Reflects coordinates until they fall between low and high (inclusive).
  # The bounds are passed as twice their value so that half-integer values
  # can be represented as ints.
  def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
    if twice_low == twice_high:
      return torch.zeros_like(coords)
    coords_min = twice_low / 2
    coords_span = (twice_high - twice_low) / 2
    coords2 = (coords - coords_min).abs()
    extra = torch.fmod(coords2, coords_span)
    flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
    return torch.where(
        flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
    )

  def compute_coordinates(coords: Tensor, size: int) -> Tensor:
    if padding_mode == 0:  # Zero
      return coords
    elif padding_mode == 1:  # Borders
      return torch.clamp(coords, 0, size - 1)
    else:  # padding_mode == 2, Reflection
      if align_corners:
        coords_reflected = reflect_coordinates(coords, 0, 2 * (size - 1))
      else:
        coords_reflected = reflect_coordinates(coords, -1, 2 * size - 1)
      return torch.clamp(coords_reflected, 0, size - 1)

  def compute_source_index(coords: Tensor, size: int) -> Tensor:
    coords_un = unnormalize(coords, size)
    return compute_coordinates(coords_un, size)

  N, C, iD, iH, iW = a.shape
  _, oD, oH, oW, three = grid.shape
  assert three == 3, 'Last dim of grid must be 3. got {}'.format(three)

  def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
    xcheck = torch.logical_and(0 <= xs, xs < iW)
    ycheck = torch.logical_and(0 <= ys, ys < iH)
    zcheck = torch.logical_and(0 <= zs, zs < iD)
    return torch.logical_and(
        xcheck, torch.logical_and(ycheck, zcheck)
    )

  N_idx = torch.arange(N, device=a.device).view(N, 1, 1, 1, 1)
  C_idx = torch.arange(C, device=a.device).view(1, C, 1, 1, 1)

  def clip(xs: torch.Tensor, ys: torch.Tensor, zs, ws: torch.Tensor):
    cond = in_bounds_cond(xs, ys, zs)
    # To clip to inside valid coordinates, we map the coordinates
    # to (x, y) = (0, 0) and also set the weight to 0
    # We also change the shape of the tensor to the appropriate one for
    # broadcasting with N_idx, C_idx for the purposes of advanced indexing
    c = C if _expand_grid else 1
    return tuple(
        torch.where(cond, t, 0).view(N, c, oD, oH, oW)
        for t in (xs.to(dtype=torch.int64), ys.to(dtype=torch.int64), zs.to(dtype=torch.int64), ws)
    )

  def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
    # Perform clipping, index into input tensor and multiply by weight
    idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
    return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_

  x = grid[..., 0]
  y = grid[..., 1]
  d = grid[..., 2]

  if interpolation_mode == 0:  # Bilinear
    ix = compute_source_index(x, iW)
    iy = compute_source_index(y, iH)
    id_ = compute_source_index(d, iD)

    ix_nwf, iy_nwf, id_nwf = ix.floor(), iy.floor(), id_.floor()
    ix_nef, iy_nef, id_nef = ix_nwf + 1, iy_nwf, id_nwf
    ix_swf, iy_swf, id_swf = ix_nwf, iy_nwf + 1, id_nwf
    ix_sef, iy_sef, id_sef = ix_nef, iy_swf, id_nwf
    ix_nwb, iy_nwb, id_nwb = ix_nwf, iy_nwf, id_nwf + 1
    ix_neb, iy_neb, id_neb = ix_nef, iy_nef, id_nwf + 1
    ix_swb, iy_swb, id_swb = ix_swf, iy_swf, id_nwf + 1
    ix_seb, iy_seb, id_seb = ix_sef, iy_sef, id_nwf + 1

    w_nwf = (ix_seb - ix) * (iy_seb - iy) * (id_seb - id_)
    w_nef = (ix - ix_swb) * (iy_swb - iy) * (id_swb - id_)
    w_swf = (ix_neb - ix) * (iy - iy_neb) * (id_neb - id_)
    w_sef = (ix - ix_nwb) * (iy - iy_nwb) * (id_nwb - id_)
    w_nwb = (ix_sef - ix) * (iy_sef - iy) * (id_ - id_sef)
    w_neb = (ix - ix_swf) * (iy_swf - iy) * (id_ - id_swf)
    w_swb = (ix_nef - ix) * (iy - iy_nef) * (id_ - id_nef)
    w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)

    return _sum_tensors(
        get_summand(ix, iy, id_, w)
        for (ix, iy, id_, w) in (
            (ix_nwf, iy_nwf, id_nwf, w_nwf),
            (ix_nef, iy_nef, id_nef, w_nef),
            (ix_swf, iy_swf, id_swf, w_swf),
            (ix_sef, iy_sef, id_sef, w_sef),
            (ix_nwb, iy_nwb, id_nwb, w_nwb),
            (ix_neb, iy_neb, id_neb, w_neb),
            (ix_swb, iy_swb, id_swb, w_swb),
            (ix_seb, iy_seb, id_seb, w_seb),
        )
    )
  else:  # interpolation_mode == 1: Nearest
    ix = compute_source_index(x, iW)
    iy = compute_source_index(y, iH)
    iz = compute_source_index(d, iD)

    ix_nearest = ix.round()
    iy_nearest = iy.round()
    iz_nearest = iz.round()

    return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)

EXTRA_DECOMP = decomp.get_decompositions([
    torch.ops.aten.upsample_bicubic2d,
    torch.ops.aten.upsample_nearest1d,
    torch.ops.aten.upsample_nearest2d,
    torch.ops.aten.upsample_nearest3d,
    torch.ops.aten._upsample_nearest_exact1d,
    torch.ops.aten._upsample_nearest_exact2d,
    torch.ops.aten._upsample_nearest_exact3d,
    torch.ops.aten._native_batch_norm_legit.no_stats,
    torch.ops.aten._native_batch_norm_legit_functional.default,
    torch.ops.aten._adaptive_avg_pool2d,
    torch.ops.aten._adaptive_avg_pool3d,
    torch.ops.aten.grid_sampler_2d,
    torch.ops.aten.grid_sampler_3d,
    torch.ops.aten.native_dropout,
    torch.ops.aten.reflection_pad1d,
    torch.ops.aten.reflection_pad2d,
    torch.ops.aten.reflection_pad3d,
    torch.ops.aten.replication_pad1d,
    torch.ops.aten.replication_pad2d,
    torch.ops.aten.replication_pad3d,
    torch.ops.aten.bernoulli,
    torch.ops.aten.rand_like,
    torch.ops.aten._batch_norm_with_update,
    torch.ops.aten.channel_shuffle,
    torch.ops.aten.nll_loss2d_forward,
    torch.ops.aten.nll_loss2d_backward,
    torch.ops.aten.bernoulli_.Tensor,
    torch.ops.aten.bernoulli_.float,
    torch.ops.aten.log_normal,
])
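`EXTRA_DECOMP` built by `decomp.get_decompositions` is an ordinary mapping from aten overloads to Python decompositions, so it can be fed to any tracer that accepts a decomposition table. A hypothetical sketch (not shipped in the package) using PyTorch's `make_fx`; the function `f` and tensor shapes are illustrative:

```python
# Hypothetical sketch: trace a function while applying the EXTRA_DECOMP table.
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torchax.decompositions import EXTRA_DECOMP

def f(x):
  # reflect padding on a 3D tensor dispatches to aten.reflection_pad1d
  return torch.nn.functional.pad(x, (1, 1), mode='reflect')

# Tracing with the table rewrites reflection_pad1d into the primitive ops
# of its registered decomposition.
gm = make_fx(f, decomposition_table=EXTRA_DECOMP)(torch.randn(2, 3, 8))
print(gm.graph)
```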
torchax/device_module.py
ADDED
@@ -0,0 +1,20 @@
def _is_in_bad_fork():
  return False

def manual_seed_all(seed):
  pass

def device_count():
  return 1

def get_rng_state():
  return []

def set_rng_state(new_state, device):
  pass

def is_available():
  return True

def current_device():
  return 0