torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

This version of torchax has been flagged as potentially problematic.

torchax/interop.py CHANGED
@@ -1,209 +1,356 @@
+import collections
 import copy
 import functools
 import torch
+from inspect import signature
+from functools import wraps
 from torch.nn.utils import stateless as torch_stateless
 import jax
 import jax.numpy as jnp
 from jax import tree_util as pytree
 from jax.experimental.shard_map import shard_map
 from torchax import tensor
+from torchax import util
+from torchax.ops import mappings
 import torchax
 
 from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
 
 
 def extract_all_buffers(m: torch.nn.Module):
-  buffers = {}
-  params = {}
-  def extract_one(module, prefix):
-    for k in dir(module):
-      try:
-        v = getattr(module, k)
-      except:
-        continue
-      qual_name = prefix + k
-      if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
-        params[qual_name] = v
-      elif isinstance(v, torch.Tensor):
-        buffers[qual_name] = v
-    for name, child in module.named_children():
-      extract_one(child, prefix + name + '.')
-  extract_one(m, '')
-  return params, buffers
+  buffers = {}
+  params = {}
+
+  def extract_one(module, prefix):
+    for k in dir(module):
+      try:
+        v = getattr(module, k)
+      except:
+        continue
+      qual_name = prefix + k
+      if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
+        params[qual_name] = v
+      elif isinstance(v, torch.Tensor):
+        buffers[qual_name] = v
+    for name, child in module.named_children():
+      extract_one(child, prefix + name + '.')
+
+  extract_one(m, '')
+  return params, buffers
 
 
 def set_all_buffers(m, params, buffers):
-  def set_one(module, prefix):
-    for k in dir(module):
-      qual_name = prefix + k
-      if (potential_v := buffers.get(qual_name)) is not None:
-        setattr(module, k, potential_v)
-      elif (potential_v := params.get(qual_name)) is not None:
-        print(k, potential_v)
-        setattr(module, k, torch.nn.Parameter(potential_v))
-    for name, child in module.named_children():
-      set_one(child, prefix + name + '.')
 
-  set_one(m, '')
+  def set_one(module, prefix):
+    for k in dir(module):
+      qual_name = prefix + k
+      if (potential_v := buffers.get(qual_name)) is not None:
+        setattr(module, k, potential_v)
+      elif (potential_v := params.get(qual_name)) is not None:
+        print(k, potential_v)
+        setattr(module, k, torch.nn.Parameter(potential_v))
+    for name, child in module.named_children():
+      set_one(child, prefix + name + '.')
+
+  set_one(m, '')
 
 
 class JittableModule(torch.nn.Module):
 
-  def __init__(self, m: torch.nn.Module, extra_jit_args={}):
-    super().__init__()
-    self.params, self.buffers = extract_all_buffers(m)
-    self._model = m
-    self._jitted = {}
+  def __init__(self,
+               m: torch.nn.Module,
+               extra_jit_args={},
+               dedup_parameters=True):
+    super().__init__()
+    self.params, self.buffers = extract_all_buffers(m)
+    self._model = m
+    self._jitted = {}
+
+    self._extra_jit_args = extra_jit_args
+
+    self._extra_dumped_weights = {}
+
+    if dedup_parameters:
+      temp = collections.defaultdict(list)
+      for k, v in self.params.items():
+        temp[id(v)].append(k)
+
+      for v in temp.values():
+        if len(v) > 1:
+          # duplicated weights with different name
+          self._extra_dumped_weights[v[0]] = v[1:]
+          for extra_keys in v[1:]:
+            del self.params[extra_keys]
+
+  @property
+  def __class__(self):
+    # Lie about the class type so that
+    # isinstance(jittable_module, self._model.__class__) works
+    return self._model.__class__
+
+  def __call__(self, *args, **kwargs):
+    return self.forward(*args, **kwargs)
+
+  def functional_call(self, method_or_name, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    # reinflate the state dict so there are not any missing keys
+    for k, v in self._extra_dumped_weights.items():
+      for new_key in v:
+        params_copy[new_key] = params_copy[k]
+
+    if isinstance(method_or_name, str):
+      method = getattr(self._model, method_or_name)
+    else:
+      if not callable(method_or_name):
+        raise TypeError(
+            f"method_or_name should be a callable or a string, got {type(method_or_name)}"
+        )
+      method = method_or_name
+      args = (self._model,) + args
+    with torch_stateless._reparametrize_module(self._model, params_copy):
+      res = method(*args, **kwargs)
+    return res
+
+  def jittable_call(self, method_name: str, *args, **kwargs):
+    if method_name not in self._jitted:
+      jitted = jax_jit(
+          functools.partial(self.functional_call, method_name),
+          kwargs_for_jax_jit=self._extra_jit_args,
+      )
+
+      def jitted_forward(*args, **kwargs):
+        return jitted(self.params, self.buffers, *args, **kwargs)
+
+      self._jitted[method_name] = jitted_forward
+    return self._jitted[method_name](*args, **kwargs)
+
+  def forward(self, *args, **kwargs):
+    return self.jittable_call('forward', *args, **kwargs)
+
+  def __getattr__(self, key):
+    if key == '_model':
+      return super().__getattr__(key)
+    if key in self._jitted:
+      return self._jitted[key]
+    return getattr(self._model, key)
+
+  def make_jitted(self, key):
+    jitted = jax_jit(
+        functools.partial(self.functional_call, key),
+        kwargs_for_jax_jit=self._extra_jit_args)
+
+    def call(*args, **kwargs):
+      return jitted(self.params, self.buffers, *args, **kwargs)
+
+    self._jitted[key] = call
 
-    self._extra_jit_args = extra_jit_args
 
+class CompileMixin:
 
-  def __call__(self, *args, **kwargs):
-    return self.forward(*args, **kwargs)
+  def functional_call(self, method, params, buffers, *args, **kwargs):
+    kwargs = kwargs or {}
+    params_copy = copy.copy(params)
+    params_copy.update(buffers)
+    with torch_stateless._reparametrize_module(self, params_copy):
+      res = method(*args, **kwargs)
+    return res
 
+  def jit(self, method):
+    jitted = jax_jit(functools.partial(self.functional_call, method_name))
 
-  def functional_call(
-      self, method_name, params, buffers, *args, **kwargs):
-    kwargs = kwargs or {}
-    params_copy = copy.copy(params)
-    params_copy.update(buffers)
-    with torch_stateless._reparametrize_module(self._model, params_copy):
-      res = getattr(self._model, method_name)(*args, **kwargs)
-    return res
+    def call(*args, **kwargs):
+      return jitted(self.named_paramters(), self.named_buffers(), *args,
+                    **kwargs)
 
+    return call
 
-  def forward(self, *args, **kwargs):
-    if 'forward' not in self._jitted:
-      jitted = jax_jit(
-        functools.partial(self.functional_call, 'forward'),
-        kwargs_for_jax_jit=self._extra_jit_args,
-      )
-      def jitted_forward(*args, **kwargs):
-        return jitted(self.params, self.buffers, *args, **kwargs)
-      self._jitted['forward'] = jitted_forward
-    return self._jitted['forward'](*args, **kwargs)
 
-  def __getattr__(self, key):
-    if key == '_model':
-      return super().__getattr__(key)
-    if key in self._jitted:
-      return self._jitted[key]
-    return getattr(self._model, key)
+def compile_nn_module(m: torch.nn.Module, methods=None):
+  if methods is None:
+    methods = ['forward']
 
-  def make_jitted(self, key):
-    jitted = jax_jit(
-        functools.partial(self.functional_call, key),
-        kwargs_for_jax_jit=self._extra_jit_args)
-    def call(*args, **kwargs):
-      return jitted(self.params, self.buffers, *args, **kwargs)
-    self._jitted[key] = call
+  new_parent = type(
+      m.__class__.__name__ + '_with_CompileMixin',
+      (CompileMixin, m.__class__),
+  )
+  m.__class__ = NewParent
 
 
+def _torch_view(t: JaxValue) -> TorchValue:
+  # t is an object from jax land
+  # view it as-if it's a torch land object
+  if isinstance(t, jax.Array):
+    # TODO
+    return tensor.Tensor(t, torchax.default_env())
+  if isinstance(t, jnp.dtype):
+    return mappings.j2t_dtype(t)
+  if callable(t):  # t is a JaxCallable
+    return functools.partial(call_jax, t)
+  # regular types are not changed
+  return t
 
 
+torch_view = functools.partial(pytree.tree_map, _torch_view)
 
-class CompileMixin:
 
-  def functional_call(
-      self, method, params, buffers, *args, **kwargs):
-    kwargs = kwargs or {}
-    params_copy = copy.copy(params)
-    params_copy.update(buffers)
-    with torch_stateless._reparametrize_module(self, params_copy):
-      res = method(*args, **kwargs)
-    return res
+def _jax_view(t: TorchValue) -> JaxValue:
+  # t is an object from torch land
+  # view it as-if it's a jax land object
+  if isinstance(t, torch.Tensor):
+    assert isinstance(t, tensor.Tensor) or isinstance(t, tensor.View), type(t)
+    return t.jax()
+  if isinstance(t, type(torch.int32)):
+    return mappings.t2j_dtype(t)
 
-  def jit(self, method):
-    jitted = jax_jit(functools.partial(self.functional_call, method_name))
-    def call(*args, **kwargs):
-      return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs)
-    return call
+  # torch.nn.Module needs special handling
+  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
+    return functools.partial(call_torch, t)
+  # regular types are not changed
+  return t
 
 
-def compile_nn_module(m: torch.nn.Module, methods=None):
-  if methods is None:
-    methods = ['forward']
+jax_view = functools.partial(pytree.tree_map, _jax_view)
 
-  new_parent = type(
-      m.__class__.__name__ + '_with_CompileMixin',
-      (CompileMixin, m.__class__),
-  )
-  m.__class__ = NewParent
 
+def call_jax(jax_func: JaxCallable, *args: TorchValue,
+             **kwargs: TorchValue) -> TorchValue:
+  args, kwargs = jax_view((args, kwargs))
+  res: JaxValue = jax_func(*args, **kwargs)
+  return torch_view(res)
 
-def _torch_view(t: JaxValue) -> TorchValue:
-  # t is an object from jax land
-  # view it as-if it's a torch land object
-  if isinstance(t, jax.Array):
-    # TODO
-    return tensor.Tensor(t, torchax.default_env())
-  if isinstance(t, type(jnp.int32)):
-    return tensor.t2j_type(t)
-  if callable(t):  # t is a JaxCallable
-    return functools.partial(call_jax, t)
-  # regular types are not changed
-  return t
 
-torch_view = functools.partial(pytree.tree_map, _torch_view)
+def call_torch(torch_func: TorchCallable, *args: JaxValue,
+               **kwargs: JaxValue) -> JaxValue:
+  args, kwargs = torch_view((args, kwargs))
+  with torchax.default_env():
+    res: TorchValue = torch_func(*args, **kwargs)
+  return jax_view(res)
 
 
-def _jax_view(t: TorchValue) -> JaxValue:
-  # t is an object from torch land
-  # view it as-if it's a jax land object
-  if isinstance(t, torch.Tensor):
-    assert isinstance(t, tensor.Tensor), type(t)
-    return t.jax()
-  if isinstance(t, type(torch.int32)):
-    return tensor.t2j_dtype(t)
-
-  # torch.nn.Module needs special handling
-  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
-    return functools.partial(call_torch, t)
-  # regular types are not changed
-  return t
+def j2t_autograd(fn, call_jax=call_jax):
+  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.
 
-jax_view = functools.partial(pytree.tree_map, _jax_view)
+  It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+  activations). The wrapped function is then run via `call_jax` and integrated into
+  the PyTorch autograd framework by saving the residuals into the context object.
+  """
 
+  @wraps(fn)
+  def inner(*args, **kwargs):
+    from jax.tree_util import tree_flatten
 
-def call_jax(jax_func: JaxCallable,
-             *args: TorchValue,
-             **kwargs: TorchValue) -> TorchValue:
-  args, kwargs = jax_view((args, kwargs))
-  res: JaxValue = jax_func(*args, **kwargs)
-  return torch_view(res)
+    class JaxFun(torch.autograd.Function):
+
+      @staticmethod
+      def forward(ctx, tree_def, *flat_args_kwargs):
 
+        tensors, other = util.partition(flat_args_kwargs,
+                                        lambda x: isinstance(x, torch.Tensor))
+        # We want the arguments that don't require grads to be closured?
 
-def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue:
-  args, kwargs = torch_view((args, kwargs))
-  with torchax.default_env():
-    res: TorchValue = torch_func(*args, **kwargs)
-  return jax_view(res)
+        y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
+
+        # Save necessary information for backward
+        # Flatten the vjp function. `vjp_spec` contains a jaxpr for the backward pass.
+        # `residuals` contains the tensors needed for the backward pass.
+        residuals, vjp_spec = tree_flatten(fun_vjp)
+        ctx.vjp_spec = vjp_spec
+        ctx.save_for_backward(*residuals)
+        return y
+
+      @staticmethod
+      def backward(ctx, *grad_out):
+        assert len(grad_out) > 0
+        grad_out = grad_out if len(grad_out) > 1 else grad_out[0]
+
+        input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec,
+                                          ctx.saved_tensors, grad_out)
+
+        # Construct the gradient tuple to be returned.
+        # It needs to match the inputs to forward: (tree_def, *flat_inputs)
+        # The first gradient (for tree_def) is None.
+        # The subsequent gradients correspond to flat_inputs.
+        # We need to put a None for inputs that did not require gradients.
+        final_grads = [None]
+        for needs_grad, grad in zip(
+            ctx.needs_input_grad[1:], input_grads_structured, strict=True):
+          final_grads.append(grad if needs_grad else None)
+
+        return tuple(final_grads)
+
+    sig = signature(fn)
+    bound = sig.bind(*args, **kwargs)
+    bound.apply_defaults()
+    flat_args_kwargs, tree_def = tree_flatten((bound.args, bound.kwargs))
+    y = JaxFun.apply(tree_def, *flat_args_kwargs)
+    return y
+
+  return inner
+
+
+# NOTE(qihqi): This function cannot be inlined from the callsite
+# Because if it does, then it won't hit the compilation cache for
+# call_jax. Call jax uses functions' id as key.
+def _jax_forward(fn, other, tree_def, tensors):
+  """JAX function to compute output and vjp function.
+
+  primals should be a tuple (args, kwargs).
+  """
+  import jax
+  from jax.tree_util import tree_flatten, tree_unflatten
+
+  def fn_wrapper(*tensors):
+    # Reconstruct the original args and kwargs
+    flat_inputs = util.merge(tensors, other)
+    args, kwargs = tree_unflatten(tree_def, flat_inputs)
+    return fn(*args, **kwargs)
+
+  return jax.vjp(fn_wrapper, *tensors)
+
+
+def _jax_backward(vjp_spec, saved_tensors, grad_out):
+  """JAX function to compute input gradients.
+
+  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
+  """
+  from jax.tree_util import tree_unflatten
+  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
+  return fun_vjp(grad_out)
 
 
 fori_loop = torch_view(jax.lax.fori_loop)
 
 
 def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
-  kwargs_for_jax = kwargs_for_jax or {}
-  jax_func = jax_view(torch_function)
-  jitted = jax_jit_func(jax_func, **kwargs_for_jax)
-  return torch_view(jitted)
+  kwargs_for_jax = kwargs_for_jax or {}
+  jax_func = jax_view(torch_function)
+  jitted = jax_jit_func(jax_func, **kwargs_for_jax)
+  return torch_view(jitted)
 
 
-def jax_jit(torch_function, kwargs_for_jax_jit=None):
-  return wrap_jax_jit(torch_function, jax_jit_func=jax.jit,
-                      kwargs_for_jax=kwargs_for_jax_jit)
+def jax_jit(torch_function,
+            kwargs_for_jax_jit=None,
+            fix_for_buffer_donation=False):
+  return wrap_jax_jit(
+      torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)
 
 
 def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
-  return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
-                      kwargs_for_jax=kwargs_for_jax_shard_map)
+  return wrap_jax_jit(
+      torch_function,
+      jax_jit_func=shard_map,
+      kwargs_for_jax=kwargs_for_jax_shard_map)
 
 
 def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
-  return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad,
-                      kwargs_for_jax=kwargs_for_value_and_grad)
+  return wrap_jax_jit(
+      torch_function,
+      jax_jit_func=jax.value_and_grad,
+      kwargs_for_jax=kwargs_for_value_and_grad)
+
 
 def gradient_checkpoint(torch_function, kwargs=None):
-  return wrap_jax_jit(torch_function, jax_jit_func=jax.checkpoint,
-                      kwargs_for_jax=kwargs)
+  return wrap_jax_jit(
+      torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
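
A few illustrative notes on the APIs this diff reworks. The new `JittableModule` factors jitted dispatch into `jittable_call` (one cached jitted entry per method name), lets `functional_call` accept either a method name or a plain callable, and, when `dedup_parameters=True`, drops duplicate parameter names that alias the same weight and reinflates them before reparametrizing the module. A minimal usage sketch; the toy `Linear` model and the tensor setup are illustrative assumptions, not part of the package:

```python
import torch
import torchax
from torchax.interop import JittableModule

env = torchax.default_env()

with env:
  # Toy module for illustration; this sketch assumes tensors and weights
  # created under the torchax environment are torchax tensors backed by JAX.
  model = torch.nn.Linear(8, 4)
  jittable = JittableModule(model, dedup_parameters=True)

  x = torch.randn(2, 8)
  y = jittable(x)                            # jits and caches 'forward' on first call
  y2 = jittable.jittable_call('forward', x)  # reuses the cached jitted entry
```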
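`_torch_view` / `_jax_view` (and their tree-mapped forms `torch_view` / `jax_view`) are the conversion layer the rest of the file builds on: `jax.Array` values become torchax tensors, dtypes now go through `torchax.ops.mappings`, and callables are wrapped so they can be invoked from the other side. `call_jax` applies that round trip for a single call. A short sketch, assuming tensors created under `torchax.default_env()` are backed by `jax.Array` (the `add_one` function is an arbitrary example):

```python
import jax.numpy as jnp
import torch
import torchax
from torchax.interop import call_jax

def add_one(x):
  # ordinary JAX computation; arbitrary example, not from the package
  return x + jnp.ones_like(x)

with torchax.default_env():
  t = torch.ones(3)            # assumed: created under the env, backed by a jax.Array
  out = call_jax(add_one, t)   # args are viewed as JAX values, result viewed back as torch
  print(type(out), out)
```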
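`j2t_autograd` is the new bridge into PyTorch autograd: the forward pass calls `_jax_forward`, which runs `jax.vjp` on the wrapped function and returns both the output and a vjp closure; the closure is flattened, its tensor residuals saved via `ctx.save_for_backward`, and `_jax_backward` replays it during `backward`. A hedged sketch of the intended use, assuming autograd over torchax tensors works end to end in the caller's environment (the quadratic loss is an arbitrary illustration):

```python
import jax.numpy as jnp
import torch
import torchax
from torchax.interop import j2t_autograd

def jax_loss(w, x):
  # arbitrary JAX function used only for illustration
  return jnp.sum((x @ w) ** 2)

torch_loss = j2t_autograd(jax_loss)

with torchax.default_env():
  w = torch.randn(4, 3, requires_grad=True)  # assumed to be a torchax tensor under the env
  x = torch.randn(2, 4)
  loss = torch_loss(w, x)   # forward runs jax.vjp via call_jax and saves residuals on ctx
  loss.backward()           # backward unflattens the saved vjp and fills w.grad
```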