torchax 0.0.10.dev20251114__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

torchax/interop.py CHANGED
@@ -15,20 +15,24 @@
  import collections
  import copy
  import functools
- import torch
- from inspect import signature
  from functools import wraps
- from torch.nn.utils import stateless as torch_stateless
+ from inspect import signature
+
  import jax
  import jax.numpy as jnp
+ import torch
  from jax import tree_util as pytree
- from jax.experimental.shard_map import shard_map
- from torchax import tensor
- from torchax import util
- from torchax.ops import mappings
+ from torch.nn.utils import stateless as torch_stateless
+
  import torchax
+ from torchax import tensor, util
+ from torchax.ops import mappings
+ from torchax.types import JaxCallable, JaxValue, TorchCallable, TorchValue

- from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
+ try:
+ from jax import shard_map as shard_map # for jax since v0.8.0
+ except ImportError:
+ from jax.experimental.shard_map import shard_map


  def extract_all_buffers(m: torch.nn.Module):
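The try/except added above only changes where `shard_map` is imported from; its call signature is the same on both sides. A minimal standalone sketch of the same fallback pattern (not taken from torchax; assumes the number of local devices divides the array's leading dimension):

import functools

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # public location since jax 0.8.0, per the hunk above
except ImportError:
    from jax.experimental.shard_map import shard_map  # older jax releases

# One-dimensional mesh over whatever devices are available.
mesh = Mesh(jax.devices(), axis_names=("x",))

@functools.partial(shard_map, mesh=mesh, in_specs=P("x"), out_specs=P("x"))
def double(block):
    # Each device sees only its own shard of the input array.
    return block * 2

print(double(jnp.arange(8.0)))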
@@ -39,7 +43,7 @@ def extract_all_buffers(m: torch.nn.Module):
  for k in dir(module):
  try:
  v = getattr(module, k)
- except:
+ except Exception:
  continue
  qual_name = prefix + k
  if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
@@ -47,14 +51,13 @@ def extract_all_buffers(m: torch.nn.Module):
  elif isinstance(v, torch.Tensor):
  buffers[qual_name] = v
  for name, child in module.named_children():
- extract_one(child, prefix + name + '.')
+ extract_one(child, prefix + name + ".")

- extract_one(m, '')
+ extract_one(m, "")
  return params, buffers


  def set_all_buffers(m, params, buffers):
-
  def set_one(module, prefix):
  for k in dir(module):
  qual_name = prefix + k
@@ -64,17 +67,15 @@ def set_all_buffers(m, params, buffers):
  print(k, potential_v)
  setattr(module, k, torch.nn.Parameter(potential_v))
  for name, child in module.named_children():
- set_one(child, prefix + name + '.')
+ set_one(child, prefix + name + ".")

- set_one(m, '')
+ set_one(m, "")


  class JittableModule(torch.nn.Module):
-
- def __init__(self,
- m: torch.nn.Module,
- extra_jit_args={},
- dedup_parameters=True):
+ def __init__(self, m: torch.nn.Module, extra_jit_args=None, dedup_parameters=True):
+ if extra_jit_args is None:
+ extra_jit_args = {}
  super().__init__()
  self.params, self.buffers = extract_all_buffers(m)
  self._model = m
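Replacing `extra_jit_args={}` with a `None` default sidesteps Python's shared-mutable-default pitfall: a `{}` default is created once at function definition time and silently shared by every call. A small illustration with a hypothetical class (not from torchax):

class Widget:
    def __init__(self, options={}):        # one dict object shared by every instance
        self.options = options

a, b = Widget(), Widget()
a.options["donate_argnums"] = (0,)
print(b.options)                            # {'donate_argnums': (0,)} -- leaked into b

class SafeWidget:
    def __init__(self, options=None):       # the pattern the new __init__ uses
        self.options = {} if options is None else options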
@@ -119,7 +120,7 @@ class JittableModule(torch.nn.Module):
  else:
  if not callable(method_or_name):
  raise TypeError(
- f"method_or_name should be a callable or a string, got {type(method_or_name)}"
+ f"method_or_name should be a callable or a string, got {type(method_or_name)}"
  )
  method = method_or_name
  args = (self._model,) + args
@@ -130,8 +131,8 @@ class JittableModule(torch.nn.Module):
  def jittable_call(self, method_name: str, *args, **kwargs):
  if method_name not in self._jitted:
  jitted = jax_jit(
- functools.partial(self.functional_call, method_name),
- kwargs_for_jax_jit=self._extra_jit_args,
+ functools.partial(self.functional_call, method_name),
+ kwargs_for_jax_jit=self._extra_jit_args,
  )

  def jitted_forward(*args, **kwargs):
@@ -141,10 +142,10 @@ class JittableModule(torch.nn.Module):
  return self._jitted[method_name](*args, **kwargs)

  def forward(self, *args, **kwargs):
- return self.jittable_call('forward', *args, **kwargs)
+ return self.jittable_call("forward", *args, **kwargs)

  def __getattr__(self, key):
- if key == '_model':
+ if key == "_model":
  return super().__getattr__(key)
  if key in self._jitted:
  return self._jitted[key]
@@ -152,8 +153,9 @@ class JittableModule(torch.nn.Module):

  def make_jitted(self, key):
  jitted = jax_jit(
- functools.partial(self.functional_call, key),
- kwargs_for_jax_jit=self._extra_jit_args)
+ functools.partial(self.functional_call, key),
+ kwargs_for_jax_jit=self._extra_jit_args,
+ )

  def call(*args, **kwargs):
  return jitted(self.params, self.buffers, *args, **kwargs)
@@ -162,7 +164,6 @@ class JittableModule(torch.nn.Module):


  class CompileMixin:
-
  def functional_call(self, method, params, buffers, *args, **kwargs):
  kwargs = kwargs or {}
  params_copy = copy.copy(params)
@@ -172,24 +173,23 @@ class CompileMixin:
  return res

  def jit(self, method):
- jitted = jax_jit(functools.partial(self.functional_call, method_name))
+ jitted = jax_jit(functools.partial(self.functional_call, method_name)) # noqa: F821

  def call(*args, **kwargs):
- return jitted(self.named_paramters(), self.named_buffers(), *args,
- **kwargs)
+ return jitted(self.named_paramters(), self.named_buffers(), *args, **kwargs)

  return call


  def compile_nn_module(m: torch.nn.Module, methods=None):
  if methods is None:
- methods = ['forward']
+ methods = ["forward"]

- new_parent = type(
- m.__class__.__name__ + '_with_CompileMixin',
- (CompileMixin, m.__class__),
+ type(
+ m.__class__.__name__ + "_with_CompileMixin",
+ (CompileMixin, m.__class__),
  )
- m.__class__ = NewParent
+ m.__class__ = NewParent # noqa: F821


  def _torch_view(t: JaxValue) -> TorchValue:
@@ -227,15 +227,17 @@ def _jax_view(t: TorchValue) -> JaxValue:
  jax_view = functools.partial(pytree.tree_map, _jax_view)


- def call_jax(jax_func: JaxCallable, *args: TorchValue,
- **kwargs: TorchValue) -> TorchValue:
+ def call_jax(
+ jax_func: JaxCallable, *args: TorchValue, **kwargs: TorchValue
+ ) -> TorchValue:
  args, kwargs = jax_view((args, kwargs))
  res: JaxValue = jax_func(*args, **kwargs)
  return torch_view(res)


- def call_torch(torch_func: TorchCallable, *args: JaxValue,
- **kwargs: JaxValue) -> JaxValue:
+ def call_torch(
+ torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue
+ ) -> JaxValue:
  args, kwargs = torch_view((args, kwargs))
  with torchax.default_env():
  res: TorchValue = torch_func(*args, **kwargs)
@@ -245,10 +247,10 @@ def call_torch(torch_func: TorchCallable, *args: JaxValue,
  def j2t_autograd(fn, call_jax=call_jax):
  """Given a JAX function, returns a PyTorch autograd function implemented with `jax.vjp(fn)`.

- It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
- activations). The wrapped function is then run via `call_jax` and integrated into
- the PyTorch autograd framework by saving the residuals into the context object.
- """
+ It wraps `fn` with `jax.vjp` to compute both the output and residuals (intermediate
+ activations). The wrapped function is then run via `call_jax` and integrated into
+ the PyTorch autograd framework by saving the residuals into the context object.
+ """

  # NOTE(qihqi): This function cannot be inlined from the callsite
  # Becuase if it does, then it won't hit the compilation cache for
@@ -261,7 +263,7 @@ def j2t_autograd(fn, call_jax=call_jax):
  primals should be a tuple (args, kwargs).
  """
  import jax
- from jax.tree_util import tree_flatten, tree_unflatten
+ from jax.tree_util import tree_unflatten

  def fn_wrapper(*tensors):
  # Reconstruct the original args and kwargs
@@ -277,6 +279,7 @@ def j2t_autograd(fn, call_jax=call_jax):
  Unflattening `saved_tensors` with `vjp_spec` should restore the original vjp function.
  """
  from jax.tree_util import tree_unflatten
+
  fun_vjp = tree_unflatten(vjp_spec, saved_tensors)
  return fun_vjp(grad_out)

@@ -285,12 +288,11 @@ def j2t_autograd(fn, call_jax=call_jax):
  from jax.tree_util import tree_flatten

  class JaxFun(torch.autograd.Function):
-
  @staticmethod
  def forward(ctx, tree_def, *flat_args_kwargs):
-
- tensors, other = util.partition(flat_args_kwargs,
- lambda x: isinstance(x, torch.Tensor))
+ tensors, other = util.partition(
+ flat_args_kwargs, lambda x: isinstance(x, torch.Tensor)
+ )
  # We want the arguments that don't require grads to be closured?

  y, fun_vjp = call_jax(_jax_forward, fn, other, tree_def, tensors)
@@ -308,8 +310,9 @@ def j2t_autograd(fn, call_jax=call_jax):
  assert len(grad_out) > 0
  grad_out = grad_out if len(grad_out) > 1 else grad_out[0]

- input_grads_structured = call_jax(_jax_backward, ctx.vjp_spec,
- ctx.saved_tensors, grad_out)
+ input_grads_structured = call_jax(
+ _jax_backward, ctx.vjp_spec, ctx.saved_tensors, grad_out
+ )

  # Construct the gradient tuple to be returned.
  # It needs to match the inputs to forward: (tree_def, *flat_inputs)
@@ -318,7 +321,8 @@ def j2t_autograd(fn, call_jax=call_jax):
  # We need to put a None for inputs that did not require gradients.
  final_grads = [None]
  for needs_grad, grad in zip(
- ctx.needs_input_grad[1:], input_grads_structured, strict=True):
+ ctx.needs_input_grad[1:], input_grads_structured, strict=True
+ ):
  final_grads.append(grad if needs_grad else None)

  return tuple(final_grads)
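For readers unfamiliar with the pattern `j2t_autograd` implements, the core idea is to run `jax.vjp` during `forward` and replay the returned vjp closure during `backward`. A stripped-down, single-tensor sketch of that bridge (residuals are kept directly on `ctx` rather than flattened into `save_for_backward` as torchax does, and the host round-trips via NumPy are purely for illustration):

import jax
import jax.numpy as jnp
import torch

class SquareViaJax(torch.autograd.Function):
    """Bridge a JAX function into torch autograd with jax.vjp."""

    @staticmethod
    def forward(ctx, x):
        # Run the JAX primal and keep the vjp closure (it holds the residuals).
        y, vjp_fn = jax.vjp(lambda a: a ** 2, jnp.asarray(x.detach().cpu().numpy()))
        ctx.vjp_fn = vjp_fn
        return torch.from_numpy(jax.device_get(y).copy())

    @staticmethod
    def backward(ctx, grad_out):
        # Replay the vjp with the incoming cotangent to get the input gradient.
        (grad_x,) = ctx.vjp_fn(jnp.asarray(grad_out.detach().cpu().numpy()))
        return torch.from_numpy(jax.device_get(grad_x).copy())

x = torch.tensor([3.0], requires_grad=True)
SquareViaJax.apply(x).sum().backward()
print(x.grad)  # tensor([6.])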
@@ -343,27 +347,27 @@ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
  return torch_view(jitted)


- def jax_jit(torch_function,
- kwargs_for_jax_jit=None,
- fix_for_buffer_donation=False):
+ def jax_jit(torch_function, kwargs_for_jax_jit=None, fix_for_buffer_donation=False):
  return wrap_jax_jit(
- torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit)
+ torch_function, jax_jit_func=jax.jit, kwargs_for_jax=kwargs_for_jax_jit
+ )


  def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
  return wrap_jax_jit(
- torch_function,
- jax_jit_func=shard_map,
- kwargs_for_jax=kwargs_for_jax_shard_map)
+ torch_function, jax_jit_func=shard_map, kwargs_for_jax=kwargs_for_jax_shard_map
+ )


  def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
  return wrap_jax_jit(
- torch_function,
- jax_jit_func=jax.value_and_grad,
- kwargs_for_jax=kwargs_for_value_and_grad)
+ torch_function,
+ jax_jit_func=jax.value_and_grad,
+ kwargs_for_jax=kwargs_for_value_and_grad,
+ )


  def gradient_checkpoint(torch_function, kwargs=None):
  return wrap_jax_jit(
- torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs)
+ torch_function, jax_jit_func=jax.checkpoint, kwargs_for_jax=kwargs
+ )
torchax/mesh_util.py CHANGED
@@ -13,8 +13,9 @@
  # limitations under the License.

  import jax
- from jax.sharding import PartitionSpec, NamedSharding
  import torch
+ from jax.sharding import NamedSharding, PartitionSpec
+
  import torchax
  from torchax import interop

@@ -94,12 +95,13 @@ class SingleAxisSharder:
  `_shard_first_multiple_of`.
  """
  del name
- sharding = _shard_first_multiple_of(self.axis_name, shapedtype.shape,
- self.axis_size)
+ sharding = _shard_first_multiple_of(
+ self.axis_name, shapedtype.shape, self.axis_size
+ )
  if not self.replicate_unshardable and all(s is None for s in sharding):
  raise AssertionError(
- f"Unable to find a dim to shard because "
- f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}"
+ f"Unable to find a dim to shard because "
+ f"None of the dims ({shapedtype.shape}) in shape is multiple of {self.axis_size}"
  )
  return sharding

@@ -159,15 +161,14 @@ class Mesh:
  self.jax_mesh = jax_mesh
  if sharder is None:
  assert len(self.jax_mesh.axis_names) == 1
- sharder = SingleAxisSharder(self.jax_mesh.axis_names[0],
- len(self.mesh.device_ids))
+ sharder = SingleAxisSharder(
+ self.jax_mesh.axis_names[0], len(self.mesh.device_ids)
+ )
  self._sharder = sharder

- def initialize_model_sharded(self,
- model_class,
- init_args,
- init_kwargs=None,
- override_sharder=None):
+ def initialize_model_sharded(
+ self, model_class, init_args, init_kwargs=None, override_sharder=None
+ ):
  """Initializes a PyTorch model with its parameters sharded across the mesh.

  This method orchestrates the initialization of a `torch.nn.Module` such
@@ -208,17 +209,18 @@ class Mesh:

  states = model.state_dict()
  output_shards = {
- name: NamedSharding(self.jax_mesh, sharder(name, tensor))
- for name, tensor in states.items()
+ name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+ for name, tensor in states.items()
  }

  def model_initializer():
- with torchax.default_env(), torch.device('meta'):
+ with torchax.default_env(), torch.device("meta"):
  model = model_class(*init_args, **init_kwargs)
  return dict(model.state_dict())

  jitted = interop.jax_jit(
- model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards})
+ model_initializer, kwargs_for_jax_jit={"out_shardings": output_shards}
+ )
  weights_dict = jitted()

  model.load_state_dict(weights_dict, assign=True)
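The jitted initializer above relies on a plain JAX mechanism: `jax.jit` with `out_shardings` places each output directly onto the mesh, so parameters are created already sharded rather than materialized on one device first. A minimal JAX-only sketch of that mechanism with hypothetical shapes and names (assumes the sharded dimension is divisible by the device count):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("axis",))
row_sharded = NamedSharding(mesh, PartitionSpec("axis"))

def init_weights():
    # Stand-in for model_initializer(): build the parameter pytree.
    return {"linear.weight": jnp.ones((8, 4)), "linear.bias": jnp.zeros((8,))}

jitted = jax.jit(
    init_weights,
    out_shardings={"linear.weight": row_sharded, "linear.bias": row_sharded},
)
weights = jitted()
print(weights["linear.weight"].sharding)  # NamedSharding over the "axis" dimension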
@@ -228,7 +230,7 @@ class Mesh:
  sharder = override_sharder or self._sharder
  states = model.state_dict()
  output_shards = {
- name: NamedSharding(self.jax_mesh, sharder(name, tensor))
- for name, tensor in states.items()
+ name: NamedSharding(self.jax_mesh, sharder(name, tensor))
+ for name, tensor in states.items()
  }
  model.load_state_dict(output_shards, assign=True)
torchax/ops/__init__.py CHANGED
@@ -12,13 +12,14 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+
  def all_aten_jax_ops():
  # to load the ops
  import torchax.ops.jaten # type: ignore
  import torchax.ops.ops_registry # type: ignore

  return {
- key: val.func
- for key, val in torchax.ops.ops_registry.all_aten_ops.items()
- if val.is_jax_function
+ key: val.func
+ for key, val in torchax.ops.ops_registry.all_aten_ops.items()
+ if val.is_jax_function
  }