torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/interop.py CHANGED
@@ -1,209 +1,343 @@
+import collections
 import copy
 import functools
 import torch
+from inspect import signature
+from functools import wraps
 from torch.nn.utils import stateless as torch_stateless
 import jax
 import jax.numpy as jnp
 from jax import tree_util as pytree
 from jax.experimental.shard_map import shard_map
 from torchax import tensor
+from torchax import util
 import torchax

 from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable


 def extract_all_buffers(m: torch.nn.Module):
-    buffers = {}
-    params = {}
-    def extract_one(module, prefix):
-        for k in dir(module):
-            try:
-                v = getattr(module, k)
-            except:
-                continue
-            qual_name = prefix + k
-            if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
-                params[qual_name] = v
-            elif isinstance(v, torch.Tensor):
-                buffers[qual_name] = v
-        for name, child in module.named_children():
-            extract_one(child, prefix + name + '.')
-    extract_one(m, '')
-    return params, buffers
+  buffers = {}
+  params = {}
+
+  def extract_one(module, prefix):
+    for k in dir(module):
+      try:
+        v = getattr(module, k)
+      except:
+        continue
+      qual_name = prefix + k
+      if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
+        params[qual_name] = v
+      elif isinstance(v, torch.Tensor):
+        buffers[qual_name] = v
+    for name, child in module.named_children():
+      extract_one(child, prefix + name + '.')
+
+  extract_one(m, '')
+  return params, buffers


 def set_all_buffers(m, params, buffers):
-    def set_one(module, prefix):
-        for k in dir(module):
-            qual_name = prefix + k
-            if (potential_v := buffers.get(qual_name)) is not None:
-                setattr(module, k, potential_v)
-            elif (potential_v := params.get(qual_name)) is not None:
-                print(k, potential_v)
-                setattr(module, k, torch.nn.Parameter(potential_v))
-        for name, child in module.named_children():
-            set_one(child, prefix + name + '.')

-    set_one(m, '')
+  def set_one(module, prefix):
+    for k in dir(module):
+      qual_name = prefix + k
+      if (potential_v := buffers.get(qual_name)) is not None:
+        setattr(module, k, potential_v)
+      elif (potential_v := params.get(qual_name)) is not None:
+        print(k, potential_v)
+        setattr(module, k, torch.nn.Parameter(potential_v))
+    for name, child in module.named_children():
+      set_one(child, prefix + name + '.')
+
+  set_one(m, '')


 class JittableModule(torch.nn.Module):

-    def __init__(self, m: torch.nn.Module, extra_jit_args={}):
-        super().__init__()
-        self.params, self.buffers = extract_all_buffers(m)
-        self._model = m
-        self._jitted = {}
+  def __init__(self,
+               m: torch.nn.Module,
+               extra_jit_args={},
+               dedup_parameters=True):
+    super().__init__()
+    self.params, self.buffers = extract_all_buffers(m)
+    self._model = m
+    self._jitted = {}
+
+    self._extra_jit_args = extra_jit_args
+
+    self._extra_dumped_weights = {}
+
+    if dedup_parameters:
+      temp = collections.defaultdict(list)
+      for k, v in self.params.items():
+        temp[id(v)].append(k)
+
+      for v in temp.values():
+        if len(v) > 1:
+          # duplicated weights with different name
+          self._extra_dumped_weights[v[0]] = v[1:]
+          for extra_keys in v[1:]:
+            del self.params[extra_keys]
+
+  @property
+  def __class__(self):
+    # Lie about the class type so that
+    # isinstance(jittable_module, self._model.__class__) works
+    return self._model.__class__
+
+  def __call__(self, *args, **kwargs):
+    return self.forward(*args, **kwargs)
+
+  def functional_call(self, method_name, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    # reinflate the state dict so there are not any missing keys
+    for k, v in self._extra_dumped_weights.items():
+      for new_key in v:
+        params_copy[new_key] = params_copy[k]
+    with torch_stateless._reparametrize_module(self._model, params_copy):
+      res = getattr(self._model, method_name)(*args, **kwargs)
+    return res
+
+  def forward(self, *args, **kwargs):
+    if 'forward' not in self._jitted:
+      jitted = jax_jit(
+          functools.partial(self.functional_call, 'forward'),
+          kwargs_for_jax_jit=self._extra_jit_args,
+      )
+
+      def jitted_forward(*args, **kwargs):
+        return jitted(self.params, self.buffers, *args, **kwargs)
+
+      self._jitted['forward'] = jitted_forward
+    return self._jitted['forward'](*args, **kwargs)
+
+  def __getattr__(self, key):
+    if key == '_model':
+      return super().__getattr__(key)
+    if key in self._jitted:
+      return self._jitted[key]
+    return getattr(self._model, key)
+
+  def make_jitted(self, key):
+    jitted = jax_jit(
+        functools.partial(self.functional_call, key),
+        kwargs_for_jax_jit=self._extra_jit_args)
+
+    def call(*args, **kwargs):
+      return jitted(self.params, self.buffers, *args, **kwargs)
+
+    self._jitted[key] = call

-        self._extra_jit_args = extra_jit_args

+class CompileMixin:

-    def __call__(self, *args, **kwargs):
-        return self.forward(*args, **kwargs)
+  def functional_call(self, method, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    with torch_stateless._reparametrize_module(self, params_copy):
+      res = method(*args, **kwargs)
+    return res

+  def jit(self, method):
+    jitted = jax_jit(functools.partial(self.functional_call, method_name))

-    def functional_call(
-            self, method_name, params, buffers, *args, **kwargs):
-        kwargs = kwargs or {}
-        params_copy = copy.copy(params)
-        params_copy.update(buffers)
-        with torch_stateless._reparametrize_module(self._model, params_copy):
-            res = getattr(self._model, method_name)(*args, **kwargs)
-        return res
+    def call(*args, **kwargs):
+      return jitted(self.named_paramters(), self.named_buffers(), *args,
+                    **kwargs)

+    return call

-    def forward(self, *args, **kwargs):
-        if 'forward' not in self._jitted:
-            jitted = jax_jit(
-                functools.partial(self.functional_call, 'forward'),
-                kwargs_for_jax_jit=self._extra_jit_args,
-            )
-            def jitted_forward(*args, **kwargs):
-                return jitted(self.params, self.buffers, *args, **kwargs)
-            self._jitted['forward'] = jitted_forward
-        return self._jitted['forward'](*args, **kwargs)

-    def __getattr__(self, key):
-        if key == '_model':
-            return super().__getattr__(key)
-        if key in self._jitted:
-            return self._jitted[key]
-        return getattr(self._model, key)
+def compile_nn_module(m: torch.nn.Module, methods=None):
+  if methods is None:
+    methods = ['forward']

-    def make_jitted(self, key):
-        jitted = jax_jit(
-            functools.partial(self.functional_call, key),
-            kwargs_for_jax_jit=self._extra_jit_args)
-        def call(*args, **kwargs):
-            return jitted(self.params, self.buffers, *args, **kwargs)
-        self._jitted[key] = call
+  new_parent = type(
+      m.__class__.__name__ + '_with_CompileMixin',
+      (CompileMixin, m.__class__),
+  )
+  m.__class__ = NewParent


+def _torch_view(t: JaxValue) -> TorchValue:
+  # t is an object from jax land
+  # view it as-if it's a torch land object
+  if isinstance(t, jax.Array):
+    # TODO
+    return tensor.Tensor(t, torchax.default_env())
+  if isinstance(t, type(jnp.int32)):
+    return tensor.t2j_type(t)
+  if callable(t):  # t is a JaxCallable
+    return functools.partial(call_jax, t)
+  # regular types are not changed
+  return t


+torch_view = functools.partial(pytree.tree_map, _torch_view)

-class CompileMixin:

-    def functional_call(
-            self, method, params, buffers, *args, **kwargs):
-        kwargs = kwargs or {}
-        params_copy = copy.copy(params)
-        params_copy.update(buffers)
-        with torch_stateless._reparametrize_module(self, params_copy):
-            res = method(*args, **kwargs)
-        return res
+def _jax_view(t: TorchValue) -> JaxValue:
+  # t is an object from torch land
+  # view it as-if it's a jax land object
+  if isinstance(t, torch.Tensor):
+    assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
+    return t.jax()
+  if isinstance(t, type(torch.int32)):
+    return tensor.t2j_dtype(t)

-    def jit(self, method):
-        jitted = jax_jit(functools.partial(self.functional_call, method_name))
-        def call(*args, **kwargs):
-            return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs)
-        return call
+  # torch.nn.Module needs special handling
+  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
+    return functools.partial(call_torch, t)
+  # regular types are not changed
+  return t


-def compile_nn_module(m: torch.nn.Module, methods=None):
-    if methods is None:
-        methods = ['forward']
+jax_view = functools.partial(pytree.tree_map, _jax_view)

-    new_parent = type(
-        m.__class__.__name__ + '_with_CompileMixin',
-        (CompileMixin, m.__class__),
-    )
-    m.__class__ = NewParent

+def call_jax(jax_func: JaxCallable, *args: TorchValue,
+             **kwargs: TorchValue) -> TorchValue:
+  args, kwargs = jax_view((args, kwargs))
+  res: JaxValue = jax_func(*args, **kwargs)
+  return torch_view(res)

-def _torch_view(t: JaxValue) -> TorchValue:
-    # t is an object from jax land
-    # view it as-if it's a torch land object
-    if isinstance(t, jax.Array):
-        # TODO
-        return tensor.Tensor(t, torchax.default_env())
-    if isinstance(t, type(jnp.int32)):
-        return tensor.t2j_type(t)
-    if callable(t):  # t is a JaxCallable
-        return functools.partial(call_jax, t)
-    # regular types are not changed
-    return t

-torch_view = functools.partial(pytree.tree_map, _torch_view)
+def call_torch(torch_func: TorchCallable, *args: JaxValue,
+               **kwargs: JaxValue) -> JaxValue:
+  args, kwargs = torch_view((args, kwargs))
+  with torchax.default_env():
+    res: TorchValue = torch_func(*args, **kwargs)
+  return jax_view(res)


-def _jax_view(t: TorchValue) -> JaxValue:
-    # t is an object from torch land
-    # view it as-if it's a jax land object
-    if isinstance(t, torch.Tensor):
-        assert isinstance(t, tensor.Tensor), type(t)
-        return t.jax()
-    if isinstance(t, type(torch.int32)):
-        return tensor.t2j_dtype(t)
-
-    # torch.nn.Module needs special handling
-    if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
-        return functools.partial(call_torch, t)
-    # regular types are not changed
-    return t
+def j2t_autograd(fn, call_jax=call_jax):
+  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.

-jax_view = functools.partial(pytree.tree_map, _jax_view)
+  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+  activations). The wrapped function is then run via `call_jax` and integrated into
+  the PyTorch autograd framework by saving the residuals into the context object.
+  """

+  @wraps(fn)
+  def inner(*args, **kwargs):
+    from jax.tree_util import tree_flatten, tree_unflatten
+    from jax.util import safe_zip

-def call_jax(jax_func: JaxCallable,
-             *args: TorchValue,
-             **kwargs: TorchValue) -> TorchValue:
-    args, kwargs = jax_view((args, kwargs))
-    res: JaxValue = jax_func(*args, **kwargs)
-    return torch_view(res)
+    class JaxFun(torch.autograd.Function):
+
+      @staticmethod
+      def forward(ctx, tree_def, *flat_args_kwargs):

+        tensors, other = util.partition(flat_args_kwargs,
+                                        lambda x: isinstance(x, torch.Tensor))
+        # We want the arguments that don't require grads to be closured?

-def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue:
-    args, kwargs = torch_view((args, kwargs))
-    with torchax.default_env():
-        res: TorchValue = torch_func(*args, **kwargs)
-    return jax_view(res)
+        y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
+
+        # Save necessary information for backward
+        # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass.
+        # `residuals` contains the tensors needed for the backward pass.`
+        residuals, vjp_spec = tree_flatten(fun_vjp)
+        ctx.vjp_spec = vjp_spec
+        ctx.save_for_backward(*residuals)
+        return y
+
+      @staticmethod
+      def backward(ctx, *grad_out):
+        assert len(grad_out) > 0
+        grad_out = grad_out if len(grad_out) > 1 else grad_out[0]
+
+        input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec,
+                                          ctx.saved_tensors, grad_out)
+
+        # Construct the gradient tuple to be returned.
+        # It needs to match the inputs to forward: (tree_def, *flat_inputs)
+        # The first gradient (for tree_def) is None.
+        # The subsequent gradients correspond to flat_inputs.
+        # We need to put a None for inputs that did not require gradients.
+        final_grads = [None]
+        for needs_grad, grad in safe_zip(ctx.needs_input_grad[1:],
+                                         input_grads_structured):
+          final_grads.append(grad if needs_grad else None)
+
+        return tuple(final_grads)
+
+    sig = signature(fn)
+    bound = sig.bind(*args, **kwargs)
+    bound.apply_defaults()
+    flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs))
+    y = JaxFun.apply(tree_def, *flat_args_kwargs)
+    return y
+
+  return inner
+
+
+# NOTE(qihqi): This function cannot be inlined from the callsite
+# Becuase if it does, then it won't hit the compilation cache for
+# call_jax. Call jax uses functions' id as key.
+def _jax_forward(fn, other, tree_def, tensors):
+  """JAX function to compute output and vjp function.
+
+  primals should be a tuple (args, kwargs).
+  """
+  import jax
+  from jax.tree_util import tree_flatten, tree_unflatten
+
+  def fn_wrapper(*tensors):
+    # Reconstruct the original args and kwargs
+    flat_inputs = util.merge(tensors, other)
+    args, kwargs = tree_unflatten(tree_def, flat_inputs)
+    return fn(*args, **kwargs)
+
+  return jax.vjp(fn_wrapper, *tensors)
+
+
+def _jax_backward(vjp_spec, saved_tensors, grad_out):
+  """JAX function to compute input gradients.
+
+  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+  """
+  from jax.tree_util import tree_unflatten
+  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+  return fun_vjp(grad_out)


 fori_loop = torch_view(jax.lax.fori_loop)


 def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
-    kwargs_for_jax = kwargs_for_jax or {}
-    jax_func = jax_view(torch_function)
-    jitted = jax_jit_func(jax_func, **kwargs_for_jax)
-    return torch_view(jitted)
+  kwargs_for_jax = kwargs_for_jax or {}
+  jax_func = jax_view(torch_function)
+  jitted = jax_jit_func(jax_func, **kwargs_for_jax)
+  return torch_view(jitted)


-def jax_jit(torch_function, kwargs_for_jax_jit=None):
-    return wrap_jax_jit(torch_function, jax_jit_func=jax.jit,
-                        kwargs_for_jax=kwargs_for_jax_jit)
+def jax_jit(torch_function,
+            kwargs_for_jax_jit=None,
+            fix_for_buffer_donation=False):
+  return wrap_jax_jit(
+      torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)


 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
-    return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
-                        kwargs_for_jax=kwargs_for_jax_shard_map)
+  return wrap_jax_jit(
+      torch_function,
+      jax_jit_func=shard_map,
+      kwargs_for_jax=kwargs_for_jax_shard_map)


 def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
-    return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad,
-                        kwargs_for_jax=kwargs_for_value_and_grad)
+  return wrap_jax_jit(
+      torch_function,
+      jax_jit_func=jax.value_and_grad,
+      kwargs_for_jax=kwargs_for_value_and_grad)
+

 def gradient_checkpoint(torch_function, kwargs=None):
-    return wrap_jax_jit(torch_function, jax_jit_func=jax.checkpoint,
-                        kwargs_for_jax=kwargs)
+  return wrap_jax_jit(
+      torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
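
Usage note (not part of the diff): a minimal sketch of how the new dedup_parameters option of JittableModule might be exercised. MyTiedModel is a hypothetical example module; the sketch assumes that constructing the model and its inputs under torchax.default_env() yields jax-backed torchax tensors, which is torchax's usual workflow.

import torch
import torchax
from torchax import interop


class MyTiedModel(torch.nn.Module):
  # Hypothetical module with weight tying: two submodules share one Parameter,
  # so the same tensor is reachable under two qualified names.

  def __init__(self):
    super().__init__()
    self.embed = torch.nn.Linear(8, 8, bias=False)
    self.head = torch.nn.Linear(8, 8, bias=False)
    self.head.weight = self.embed.weight

  def forward(self, x):
    return self.head(self.embed(x))


with torchax.default_env():
  model = MyTiedModel()
  x = torch.randn(2, 8)

  # With dedup_parameters=True (the default), entries of self.params that alias
  # the same tensor should be collapsed to a single key; functional_call
  # re-inflates the duplicate names before reparametrizing the module.
  jit_model = interop.JittableModule(model, dedup_parameters=True)
  print(list(jit_model.params.keys()))

  y = jit_model(x)  # the first call jits 'forward' via jax_jit and caches it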
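
Usage note (not part of the diff): a minimal sketch of the new j2t_autograd helper together with call_jax. scaled_dot is a hypothetical JAX function; the sketch assumes that torchax tensors can be consumed by jnp ops inside the wrapped function and that leaf tensors created under torchax.default_env() participate in torch autograd as shown.

import jax.numpy as jnp
import torch
import torchax
from torchax import interop


def scaled_dot(a, b):
  # Hypothetical pure-JAX function to be exposed to torch autograd.
  return 2.0 * jnp.sum(a * b)


# j2t_autograd wraps the JAX function in a torch.autograd.Function: forward
# runs jax.vjp through call_jax and saves the residuals on the autograd
# context, backward replays the saved vjp.
torch_scaled_dot = interop.j2t_autograd(scaled_dot)

with torchax.default_env():
  a = torch.randn(4).requires_grad_(True)
  b = torch.randn(4).requires_grad_(True)

  out = torch_scaled_dot(a, b)
  out.backward()
  print(a.grad)

  # call_jax can also be used directly: torch-side values are jax_view'ed on
  # the way in and the jax result is torch_view'ed back into a torchax tensor.
  z = torch.ones(3)
  y = interop.call_jax(jnp.sin, z)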