torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


torchax/device_module.py CHANGED
@@ -1,20 +1,26 @@
 def _is_in_bad_fork():
   return False
 
+
 def manual_seed_all(seed):
   pass
 
+
 def device_count():
   return 1
 
+
 def get_rng_state():
   return []
 
+
 def set_rng_state(new_state, device):
   pass
 
+
 def is_available():
   return True
 
+
 def current_device():
-  return 0
+  return 0
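
Note: these stubs are the torch "device module" torchax exposes for its JAX backend. Below is a minimal sketch of how such a module is typically attached to PyTorch's custom-backend hooks; the "jax" backend name and the registration calls are assumptions for illustration and are not part of this diff (torchax may already perform an equivalent registration on import).

# Sketch only: wiring a stub device module into PyTorch (assumed setup).
import torch

import torchax.device_module

# Rename the reserved private backend and attach the module so that
# torch.jax.is_available(), torch.jax.device_count(), etc. resolve to the
# stub functions shown above. Skip if torchax already does this on import.
torch.utils.rename_privateuse1_backend("jax")
torch._register_device_module("jax", torchax.device_module)

print(torch.jax.is_available())  # -> True
print(torch.jax.device_count())  # -> 1
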
torchax/distributed.py CHANGED
@@ -51,64 +51,61 @@ class ProcessGroupJax(ProcessGroup):
 
   @staticmethod
   def _work(
-    tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]],
+      tensors: Union[torch.Tensor, List[torch.Tensor],
+                     List[List[torch.Tensor]]],
   ) -> dist.Work:
     fut = torch.futures.Future()
     fut.set_result(tensors)
     return torch._C._distributed_c10d._create_work_from_future(fut)
 
   def _allgather_base(
-    self,
-    output: torch.Tensor,
-    input: torch.Tensor,
-    opts=...,
+      self,
+      output: torch.Tensor,
+      input: torch.Tensor,
+      opts=...,
   ) -> dist.Work:
     assert isinstance(input, torchax.tensor.Tensor)
     assert isinstance(output, torchax.tensor.Tensor)
     torch.distributed._functional_collectives.all_gather_tensor_inplace(
-      output, input, group=self
-    )
+        output, input, group=self)
     return self._work(output)
 
   def allreduce(
-    self,
-    tensors: List[torch.Tensor],
-    opts: dist.AllreduceOptions = ...,
+      self,
+      tensors: List[torch.Tensor],
+      opts: dist.AllreduceOptions = ...,
   ) -> dist.Work:
     assert len(tensors) == 1
     assert isinstance(tensors[0], torchax.tensor.Tensor)
     torch.distributed._functional_collectives.all_reduce_inplace(
-      tensors[0],
-      torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
-        opts.reduceOp.op
-      ],
-      self,
+        tensors[0],
+        torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
+            opts.reduceOp.op],
+        self,
     )
 
     return self._work(tensors)
 
   def broadcast(
-    self,
-    tensors: List[torch.Tensor],
-    opts: dist.BroadcastOptions = ...,
+      self,
+      tensors: List[torch.Tensor],
+      opts: dist.BroadcastOptions = ...,
   ) -> dist.Work:
     assert len(tensors) == 1
     assert isinstance(tensors[0], torchax.tensor.Tensor)
     tensors[0].copy_(
-      torch.distributed._functional_collectives.broadcast(
-        tensors[0], opts.rootRank, group=self
-      )
-    )
+        torch.distributed._functional_collectives.broadcast(
+            tensors[0], opts.rootRank, group=self))
 
     return self._work(tensors)
 
 
-dist.Backend.register_backend("jax", ProcessGroupJax)
+dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"])
 
 
-def jax_rendezvous_handler(
-  url: str, timeout: datetime.timedelta = ..., **kwargs
-):
+def jax_rendezvous_handler(url: str,
+                           timeout: datetime.timedelta = ...,
+                           **kwargs):
   """Initialize distributed store with JAX process IDs.
 
   Requires `$MASTER_ADDR` and `$MASTER_PORT`.
@@ -120,10 +117,10 @@ def jax_rendezvous_handler(
   master_port = int(os.environ["MASTER_PORT"])
   # TODO(wcromar): Use `torchrun`'s store if available
   store = dist.TCPStore(
-    master_ip,
-    master_port,
-    jax.process_count(),
-    is_master=jax.process_index() == 0,
+      master_ip,
+      master_port,
+      jax.process_count(),
+      is_master=jax.process_index() == 0,
   )
 
   yield (store, jax.process_index(), jax.process_count())
@@ -145,9 +142,9 @@ def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
     torch_outputs = f(index, *args)
     return env.t2j_iso(torch_outputs)
 
-  jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")(
-    np.arange(jax.device_count()), env.t2j_iso(args)
-  )
+  jax_outputs = jax.pmap(
+      jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()),
+                                           env.t2j_iso(args))
   return env.j2t_iso(jax_outputs)
 
 
@@ -172,11 +169,12 @@ class DistributedDataParallel(torch.nn.Module):
   jax_output = jax_model(jax_data)
   ```
   """
+
   def __init__(
-    self,
-    module: torch.nn.Module,
-    env: Optional[torchax.tensor.Environment] = None,
-    **kwargs,
+      self,
+      module: torch.nn.Module,
+      env: Optional[torchax.tensor.Environment] = None,
+      **kwargs,
   ):
     if kwargs:
       logging.warning(f"Unsupported kwargs {kwargs}")
@@ -184,17 +182,15 @@ class DistributedDataParallel(torch.nn.Module):
     super().__init__()
     self._env = env or torchax.default_env()
     self._mesh = Mesh(
-      mesh_utils.create_device_mesh((jax.device_count(),)),
-      axis_names=("batch",),
+        mesh_utils.create_device_mesh((jax.device_count(),)),
+        axis_names=("batch",),
     )
     replicated_state = torch_pytree.tree_map_only(
-      torch.Tensor,
-      lambda t: self._env.j2t_iso(
-        jax.device_put(
-          self._env.to_xla(t)._elem, NamedSharding(self._mesh, P())
-        )
-      ),
-      module.state_dict(),
+        torch.Tensor,
+        lambda t: self._env.j2t_iso(
+            jax.device_put(
+                self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))),
+        module.state_dict(),
     )
     # TODO: broadcast
     module.load_state_dict(replicated_state, assign=True)
@@ -208,25 +204,24 @@ class DistributedDataParallel(torch.nn.Module):
     global_batch_shape = (global_batch_size,) + inp.shape[1:]
 
     sharding = NamedSharding(self._mesh, P("batch"))
-    return self._env.j2t_iso(jax.make_array_from_single_device_arrays(
-      global_batch_shape,
-      NamedSharding(self._mesh, P("batch")),
-      arrays=[
-        jax.device_put(self._env.to_xla(batch)._elem, device)
-        for batch, device in zip(
-          per_replica_batches, sharding.addressable_devices
-        )
-      ],
-    ))
+    return self._env.j2t_iso(
+        jax.make_array_from_single_device_arrays(
+            global_batch_shape,
+            NamedSharding(self._mesh, P("batch")),
+            arrays=[
+                jax.device_put(self._env.to_xla(batch)._elem, device) for batch,
+                device in zip(per_replica_batches, sharding.addressable_devices)
+            ],
+        ))
 
   def replicate_input(self, inp):
     return self._env.j2t_iso(
-      jax.device_put(inp._elem, NamedSharding(self._mesh, P()))
-    )
+        jax.device_put(inp._elem, NamedSharding(self._mesh, P())))
 
   def jit_step(self, func):
-    @functools.partial(interop.jax_jit,
-                       kwargs_for_jax_jit={'donate_argnums': 0})
+
+    @functools.partial(
+        interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0})
     def _jit_fn(states, args):
       self.load_state_dict(states)
       outputs = func(*args)
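
Note: with the backend now registered as dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"]), a process group can be created on it. Below is a minimal sketch; it assumes jax_rendezvous_handler is registered for the jax:// URL scheme (that registration is not shown in this diff), and the MASTER_ADDR/MASTER_PORT values are illustrative.

# Sketch only: bringing up torch.distributed on the "jax" backend (assumed setup).
import os

import torch.distributed as dist
import torchax.distributed  # noqa: F401  importing runs register_backend("jax", ...)

os.environ.setdefault("MASTER_ADDR", "localhost")  # illustrative values
os.environ.setdefault("MASTER_PORT", "29500")

# The jax:// init_method is assumed to route to jax_rendezvous_handler above,
# which derives rank and world size from jax.process_index()/jax.process_count().
dist.init_process_group(backend="jax", init_method="jax://")

print(dist.get_rank(), dist.get_world_size())
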
torchax/export.py CHANGED
@@ -4,14 +4,14 @@ import copy
 from typing import Any, Dict, Tuple
 import torch
 from torch.utils import _pytree as pytree
+import torchax
 from torchax import tensor
-from torchax.ops import ops_registry
+from torchax.ops import ops_registry, mappings
 from torchax import decompositions
 import jax
 import jax.export
 import sympy
 
-
 DEBUG = False
 
 
@@ -83,7 +83,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
   if torch.__version__ >= '2.2':
     # torch version 2.1 didn't expose this yet
     exported_program = exported_program.run_decompositions()
-    exported_program = exported_program.run_decompositions(decompositions.EXTRA_DECOMP)
+    exported_program = exported_program.run_decompositions(
+        decompositions.DECOMPOSITIONS)
   if DEBUG:
     print(exported_program.graph_module.code)
 
@@ -108,8 +109,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
 
   if export_raw:
     return names, states, func
-
-  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
+  env = torchax.default_env()
+  states = env.t2j_copy(states)
   return states, func
 
 
@@ -121,34 +122,35 @@ def extract_avals(exported):
   def _to_aval(arg_meta, symbolic_shapes):
     """Convet from torch type to jax abstract value for export tracing
     """
+
     def _get_dim(d):
       if isinstance(d, torch.SymInt):
         return symbolic_shapes[str(d)]
       return d
 
     val = arg_meta['val']
-    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
+    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
+        val, bool)
     if is_scalar:
       return jax.ShapeDtypeStruct([], type(arg_meta['val']))
 
     tensor_meta = arg_meta['tensor_meta']
     shape = [_get_dim(d) for d in tensor_meta.shape]
-    return jax.ShapeDtypeStruct(shape, tensor.t2j_dtype(tensor_meta.dtype))
+    return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
 
   def _get_inputs(exported):
     """Return placeholders with input metadata"""
     placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
     input_placeholders = [
-      p
-      for p, s in zip(placeholders, exported.graph_signature.input_specs)
-      if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+        p for p, s in zip(placeholders, exported.graph_signature.input_specs)
+        if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
     ]
     return input_placeholders
 
   def _build_symbolic_shapes(range_constraints):
     """Convert torch SymInt to JAX symbolic_shape and stores in a map using the
     string name of the torch symbolic int.
-
+
     TODO: There is probably a better way of storing a key for a symbolic int.
     This value needs to be looked up again in `_to_aval` to figure out which
     JAX symbolic to map to for a given torch tensor.
@@ -163,8 +165,10 @@ def extract_avals(exported):
         torch.export.Dim("a", min=5, max=10)
         ==> ("a >= 5", "a <= 10",)
       """
-      if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges) or torch_constraint.is_bool:
-        raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")
+      if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
+                        ValueRanges) or torch_constraint.is_bool:
+        raise TypeError(
+            f"No symbolic constraint handler for: {torch_constraint}")
 
       constraints = []
       symbol = sympy.Symbol(symbol_name)
@@ -182,7 +186,7 @@ def extract_avals(exported):
       There are two possible sympy `sym` inputs:
       1. Symbol - (s0) These can have custom constraints.
       2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.
-
+
       Currently support is limited to operations with a symbol and and int,
       in `torch/export/dynamic_shapes.py`:
       "Only increasing linear operations with integer coefficients are supported."
@@ -190,7 +194,8 @@ def extract_avals(exported):
       symbol_name = str(sym)
       constraints = _build_symbolic_constraints(symbol_name, constraint)
       if sym.is_symbol:
-        symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
+        symbolic_shape = jax.export.symbolic_shape(
+            symbol_name, constraints=constraints)
       else:
         assert len(sym.free_symbols) > 0
         scope = free_symbols[str(list(sym.free_symbols)[0])].scope
@@ -203,8 +208,12 @@ def extract_avals(exported):
   # integer compuations on symbol variables, so each symbol variable is OK to
   # have its own scope.
   symbolic_shapes = {}
-  symbol_variables = [(s,v) for s,v in range_constraints.items() if s.is_symbol]
-  symbol_exprs = [(s,v) for s,v in range_constraints.items() if not s.is_symbol]
+  symbol_variables = [
+      (s, v) for s, v in range_constraints.items() if s.is_symbol
+  ]
+  symbol_exprs = [
+      (s, v) for s, v in range_constraints.items() if not s.is_symbol
+  ]
   for sym, constraint in symbol_variables + symbol_exprs:
     symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
     symbolic_shapes[str(sym)] = symbolic_shape
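
Note: with states now converted through env.t2j_copy and dtypes mapped via mappings.t2j_dtype, the usual flow for this module is torch.export followed by exported_program_to_jax. A minimal sketch follows; the toy model is illustrative, and the (weights, args) calling convention of the returned function is an assumption based on how states and func are returned above, not something this diff confirms.

# Sketch only: lowering an exported torch program to JAX (assumed usage).
import jax.numpy as jnp
import torch

from torchax import export as tx_export


class TinyModel(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.fc = torch.nn.Linear(4, 2)

  def forward(self, x):
    return self.fc(x)


exported = torch.export.export(TinyModel(), (torch.randn(3, 4),))
weights, jax_func = tx_export.exported_program_to_jax(exported)

# `weights` should now hold JAX arrays (via env.t2j_copy); inputs are passed
# as a tuple of JAX arrays alongside the weights.
out = jax_func(weights, (jnp.ones((3, 4), dtype=jnp.float32),))
print(out)
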
torchax/flax.py ADDED
@@ -0,0 +1,39 @@
+"""Flax interop."""
+
+import torch
+import torchax as tx
+import torchax.interop
+
+
+class FlaxNNModule(torch.nn.Module):
+
+  def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
+    super().__init__()
+    prng = env.prng_key
+    sample_kwargs = sample_kwargs or {}
+    parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
+                                         **sample_kwargs)
+
+    self._params = self._encode_nested_dict(parameter_dict)
+
+    self._flax_module = flax_module
+
+  def _encode_nested_dict(self, nested_dict):
+    child_module = torch.nn.Module()
+    for k, v in nested_dict.items():
+      if isinstance(v, dict):
+        child_module.add_module(k, self._encode_nested_dict(v))
+      else:
+        child_module.register_parameter(k, torch.nn.Parameter(v))
+    return child_module
+
+  def _decode_nested_dict(self, child_module):
+    result = dict(child_module.named_parameters(recurse=False))
+    for k, v in child_module.named_children():
+      result[k] = self._decode_nested_dict(v)
+    return result
+
+  def forward(self, *args, **kwargs):
+    nested_dict_params = self._decode_nested_dict(self._params)
+    return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
+                               *args, **kwargs)
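
Note: the new FlaxNNModule re-encodes the nested Flax parameter dict as child torch.nn.Modules holding torch.nn.Parameters, then decodes it back on every forward and dispatches through interop.call_jax. A minimal usage sketch, with an illustrative flax.linen.Dense model and shapes (not taken from this diff):

# Sketch only: wrapping a flax.linen module as a torch.nn.Module (assumed usage).
import flax.linen as nn
import torch

import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()

with env:  # torch factory ops below should produce JAX-backed torchax tensors
  sample = torch.ones(1, 16)  # used only to initialize the Flax parameters
  wrapped = FlaxNNModule(env, nn.Dense(features=8), (sample,))

  x = torch.randn(1, 16)
  y = wrapped(x)   # runs flax_module.apply on the decoded parameter dict
  print(y.shape)   # expected: torch.Size([1, 8])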