torchax-0.0.4-py3-none-any.whl → torchax-0.0.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/interop.py
CHANGED
@@ -1,209 +1,356 @@
+import collections
 import copy
 import functools
 import torch
+from inspect import signature
+from functools import wraps
 from torch.nn.utils import stateless as torch_stateless
 import jax
 import jax.numpy as jnp
 from jax import tree_util as pytree
 from jax.experimental.shard_map import shard_map
 from torchax import tensor
+from torchax import util
+from torchax.ops import mappings
 import torchax

 from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable


 def extract_all_buffers(m: torch.nn.Module):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+  buffers = {}
+  params = {}
+
+  def extract_one(module, prefix):
+    for k in dir(module):
+      try:
+        v = getattr(module, k)
+      except:
+        continue
+      qual_name = prefix + k
+      if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
+        params[qual_name] = v
+      elif isinstance(v, torch.Tensor):
+        buffers[qual_name] = v
+    for name, child in module.named_children():
+      extract_one(child, prefix + name + '.')
+
+  extract_one(m, '')
+  return params, buffers


 def set_all_buffers(m, params, buffers):
-  def set_one(module, prefix):
-    for k in dir(module):
-      qual_name = prefix + k
-      if (potential_v := buffers.get(qual_name)) is not None:
-        setattr(module, k, potential_v)
-      elif (potential_v := params.get(qual_name)) is not None:
-        print(k, potential_v)
-        setattr(module, k, torch.nn.Parameter(potential_v))
-    for name, child in module.named_children():
-      set_one(child, prefix + name + '.')

-
+  def set_one(module, prefix):
+    for k in dir(module):
+      qual_name = prefix + k
+      if (potential_v := buffers.get(qual_name)) is not None:
+        setattr(module, k, potential_v)
+      elif (potential_v := params.get(qual_name)) is not None:
+        print(k, potential_v)
+        setattr(module, k, torch.nn.Parameter(potential_v))
+    for name, child in module.named_children():
+      set_one(child, prefix + name + '.')
+
+  set_one(m, '')


 class JittableModule(torch.nn.Module):

-
-
-
-
-
+  def __init__(self,
+               m: torch.nn.Module,
+               extra_jit_args={},
+               dedup_parameters=True):
+    super().__init__()
+    self.params, self.buffers = extract_all_buffers(m)
+    self._model = m
+    self._jitted = {}
+
+    self._extra_jit_args = extra_jit_args
+
+    self._extra_dumped_weights = {}
+
+    if dedup_parameters:
+      temp = collections.defaultdict(list)
+      for k, v in self.params.items():
+        temp[id(v)].append(k)
+
+      for v in temp.values():
+        if len(v) > 1:
+          # duplicated weights with different name
+          self._extra_dumped_weights[v[0]] = v[1:]
+          for extra_keys in v[1:]:
+            del self.params[extra_keys]
+
+  @property
+  def __class__(self):
+    # Lie about the class type so that
+    # isinstance(jittable_module, self._model.__class__) works
+    return self._model.__class__
+
+  def __call__(self, *args, **kwargs):
+    return self.forward(*args, **kwargs)
+
+  def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    # reinflate the state dict so there are not any missing keys
+    for k, v in self._extra_dumped_weights.items():
+      for new_key in v:
+        params_copy[new_key] = params_copy[k]
+
+    if isinstance(method_or_name, str):
+      method = getattr(self._model, method_or_name)
+    else:
+      if not callable(method_or_name):
+        raise TypeError(
+            f"method_or_name should be a callable or a string, got {type(method_or_name)}"
+        )
+      method = method_or_name
+      args = (self._model,) + args
+    with torch_stateless._reparametrize_module(self._model, params_copy):
+      res = method(*args, **kwargs)
+    return res
+
+  def jittable_call(self, method_name: str, *args, **kwargs):
+    if method_name not in self._jitted:
+      jitted = jax_jit(
+          functools.partial(self.functional_call, method_name),
+          kwargs_for_jax_jit=self._extra_jit_args,
+      )
+
+      def jitted_forward(*args, **kwargs):
+        return jitted(self.params, self.buffers, *args, **kwargs)
+
+      self._jitted[method_name] = jitted_forward
+    return self._jitted[method_name](*args, **kwargs)
+
+  def forward(self, *args, **kwargs):
+    return self.jittable_call('forward', *args, **kwargs)
+
+  def __getattr__(self, key):
+    if key == '_model':
+      return super().__getattr__(key)
+    if key in self._jitted:
+      return self._jitted[key]
+    return getattr(self._model, key)
+
+  def make_jitted(self, key):
+    jitted = jax_jit(
+        functools.partial(self.functional_call, key),
+        kwargs_for_jax_jit=self._extra_jit_args)
+
+    def call(*args, **kwargs):
+      return jitted(self.params, self.buffers, *args, **kwargs)
+
+    self._jitted[key] = call

-    self._extra_jit_args = extra_jit_args

+class CompileMixin:

-
-
+  def functional_call(self, method, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    with torch_stateless._reparametrize_module(self, params_copy):
+      res = method(*args, **kwargs)
+    return res

+  def jit(self, method):
+    jitted = jax_jit(functools.partial(self.functional_call, method_name))

-  def
-
-
-    params_copy = copy.copy(params)
-    params_copy.update(buffers)
-    with torch_stateless._reparametrize_module(self._model, params_copy):
-      res = getattr(self._model, method_name)(*args, **kwargs)
-    return res
+    def call(*args, **kwargs):
+      return jitted(self.named_paramters(), self.named_buffers(), *args,
+                    **kwargs)

+    return call

-  def forward(self, *args, **kwargs):
-    if 'forward' not in self._jitted:
-      jitted = jax_jit(
-          functools.partial(self.functional_call, 'forward'),
-          kwargs_for_jax_jit=self._extra_jit_args,
-      )
-      def jitted_forward(*args, **kwargs):
-        return jitted(self.params, self.buffers, *args, **kwargs)
-      self._jitted['forward'] = jitted_forward
-    return self._jitted['forward'](*args, **kwargs)

-
-
-
-    if key in self._jitted:
-      return self._jitted[key]
-    return getattr(self._model, key)
+def compile_nn_module(m: torch.nn.Module, methods=None):
+  if methods is None:
+    methods = ['forward']

-
-
-
-
-
-      return jitted(self.params, self.buffers, *args, **kwargs)
-    self._jitted[key] = call
+  new_parent = type(
+      m.__class__.__name__ + '_with_CompileMixin',
+      (CompileMixin, m.__class__),
+  )
+  m.__class__ = NewParent


+def _torch_view(t: JaxValue) -> TorchValue:
+  # t is an object from jax land
+  # view it as-if it's a torch land object
+  if isinstance(t, jax.Array):
+    # TODO
+    return tensor.Tensor(t, torchax.default_env())
+  if isinstance(t, jnp.dtype):
+    return mappings.j2t_dtype(t)
+  if callable(t):  # t is a JaxCallable
+    return functools.partial(call_jax, t)
+  # regular types are not changed
+  return t


+torch_view = functools.partial(pytree.tree_map, _torch_view)

-class CompileMixin:

-
-
-
-
-
-
-
-
+def _jax_view(t: TorchValue) -> JaxValue:
+  # t is an object from torch land
+  # view it as-if it's a jax land object
+  if isinstance(t, torch.Tensor):
+    assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
+    return t.jax()
+  if isinstance(t, type(torch.int32)):
+    return mappings.t2j_dtype(t)

-
-
-
-
-
+  # torch.nn.Module needs special handling
+  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
+    return functools.partial(call_torch, t)
+  # regular types are not changed
+  return t


-
-  if methods is None:
-    methods = ['forward']
+jax_view = functools.partial(pytree.tree_map, _jax_view)

-  new_parent = type(
-      m.__class__.__name__ + '_with_CompileMixin',
-      (CompileMixin, m.__class__),
-  )
-  m.__class__ = NewParent

+def call_jax(jax_func: JaxCallable, *args: TorchValue,
+             **kwargs: TorchValue) -> TorchValue:
+  args, kwargs = jax_view((args, kwargs))
+  res: JaxValue = jax_func(*args, **kwargs)
+  return torch_view(res)

-def _torch_view(t: JaxValue) -> TorchValue:
-  # t is an object from jax land
-  # view it as-if it's a torch land object
-  if isinstance(t, jax.Array):
-    # TODO
-    return tensor.Tensor(t, torchax.default_env())
-  if isinstance(t, type(jnp.int32)):
-    return tensor.t2j_type(t)
-  if callable(t):  # t is a JaxCallable
-    return functools.partial(call_jax, t)
-  # regular types are not changed
-  return t

-
+def call_torch(torch_func: TorchCallable, *args: JaxValue,
+               **kwargs: JaxValue) -> JaxValue:
+  args, kwargs = torch_view((args, kwargs))
+  with torchax.default_env():
+    res: TorchValue = torch_func(*args, **kwargs)
+  return jax_view(res)


-def
-
-  # view it as-if it's a jax land object
-  if isinstance(t, torch.Tensor):
-    assert isinstance(t, tensor.Tensor), type(t)
-    return t.jax()
-  if isinstance(t, type(torch.int32)):
-    return tensor.t2j_dtype(t)
-
-  # torch.nn.Module needs special handling
-  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
-    return functools.partial(call_torch, t)
-  # regular types are not changed
-  return t
+def j2t_autograd(fn, call_jax=call_jax):
+  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.

-
+  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+  activations). The wrapped function is then run via `call_jax` and integrated into
+  the PyTorch autograd framework by saving the residuals into the context object.
+  """

+  @wraps(fn)
+  def inner(*args, **kwargs):
+    from jax.tree_util import tree_flatten

-
-
-
-
-  res: JaxValue = jax_func(*args, **kwargs)
-  return torch_view(res)
+    class JaxFun(torch.autograd.Function):
+
+      @staticmethod
+      def forward(ctx, tree_def, *flat_args_kwargs):

+        tensors, other = util.partition(flat_args_kwargs,
+                                        lambda x: isinstance(x, torch.Tensor))
+        # We want the arguments that don't require grads to be closured?

-
-
-
-
-
+        y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
+
+        # Save necessary information for backward
+        # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass.
+        # `residuals` contains the tensors needed for the backward pass.`
+        residuals, vjp_spec = tree_flatten(fun_vjp)
+        ctx.vjp_spec = vjp_spec
+        ctx.save_for_backward(*residuals)
+        return y
+
+      @staticmethod
+      def backward(ctx, *grad_out):
+        assert len(grad_out) > 0
+        grad_out = grad_out if len(grad_out) > 1 else grad_out[0]
+
+        input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec,
+                                          ctx.saved_tensors, grad_out)
+
+        # Construct the gradient tuple to be returned.
+        # It needs to match the inputs to forward: (tree_def, *flat_inputs)
+        # The first gradient (for tree_def) is None.
+        # The subsequent gradients correspond to flat_inputs.
+        # We need to put a None for inputs that did not require gradients.
+        final_grads = [None]
+        for needs_grad, grad in zip(
+            ctx.needs_input_grad[1:], input_grads_structured, strict=True):
+          final_grads.append(grad if needs_grad else None)
+
+        return tuple(final_grads)
+
+    sig = signature(fn)
+    bound = sig.bind(*args, **kwargs)
+    bound.apply_defaults()
+    flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs))
+    y = JaxFun.apply(tree_def, *flat_args_kwargs)
+    return y
+
+  return inner
+
+
+# NOTE(qihqi): This function cannot be inlined from the callsite
+# Becuase if it does, then it won't hit the compilation cache for
+# call_jax. Call jax uses functions' id as key.
+def _jax_forward(fn, other, tree_def, tensors):
+  """JAX function to compute output and vjp function.
+
+  primals should be a tuple (args, kwargs).
+  """
+  import jax
+  from jax.tree_util import tree_flatten, tree_unflatten
+
+  def fn_wrapper(*tensors):
+    # Reconstruct the original args and kwargs
+    flat_inputs = util.merge(tensors, other)
+    args, kwargs = tree_unflatten(tree_def, flat_inputs)
+    return fn(*args, **kwargs)
+
+  return jax.vjp(fn_wrapper, *tensors)
+
+
+def _jax_backward(vjp_spec, saved_tensors, grad_out):
+  """JAX function to compute input gradients.
+
+  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+  """
+  from jax.tree_util import tree_unflatten
+  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+  return fun_vjp(grad_out)


 fori_loop = torch_view(jax.lax.fori_loop)


 def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
-
-
-
-
+  kwargs_for_jax = kwargs_for_jax or {}
+  jax_func = jax_view(torch_function)
+  jitted = jax_jit_func(jax_func, **kwargs_for_jax)
+  return torch_view(jitted)


-def jax_jit(torch_function,
-
-
+def jax_jit(torch_function,
+            kwargs_for_jax_jit=None,
+            fix_for_buffer_donation=False):
+  return wrap_jax_jit(
+      torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)


 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
-
-
+  return wrap_jax_jit(
+      torch_function,
+      jax_jit_func=shard_map,
+      kwargs_for_jax=kwargs_for_jax_shard_map)


 def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
-
-
+  return wrap_jax_jit(
+      torch_function,
+      jax_jit_func=jax.value_and_grad,
+      kwargs_for_jax=kwargs_for_value_and_grad)
+

 def gradient_checkpoint(torch_function, kwargs=None):
-
-
+  return wrap_jax_jit(
+      torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)