torchax 0.0.4 → 0.0.6 (torchax-0.0.4-py3-none-any.whl → torchax-0.0.6-py3-none-any.whl)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torchax might be problematic; see the package's page on its registry for more details.

torchax/device_module.py CHANGED
@@ -1,20 +1,33 @@
1
+ import torch
2
+
3
+
1
4
  def _is_in_bad_fork():
2
5
  return False
3
6
 
7
+
4
8
  def manual_seed_all(seed):
5
9
  pass
6
10
 
11
+
7
12
  def device_count():
8
13
  return 1
9
14
 
15
+
10
16
  def get_rng_state():
11
17
  return []
12
18
 
19
+
13
20
  def set_rng_state(new_state, device):
14
21
  pass
15
22
 
23
+
16
24
  def is_available():
17
25
  return True
18
26
 
27
+
19
28
  def current_device():
20
- return 0
29
+ return 0
30
+
31
+
32
+ def get_amp_supported_dtype():
33
+ return [torch.float16, torch.bfloat16]
torchax/environment.py CHANGED
@@ -1,2 +1 @@
1
1
 
2
-
torchax/export.py CHANGED
@@ -4,14 +4,14 @@ import copy
4
4
  from typing import Any, Dict, Tuple
5
5
  import torch
6
6
  from torch.utils import _pytree as pytree
7
+ import torchax
7
8
  from torchax import tensor
8
- from torchax.ops import ops_registry
9
+ from torchax.ops import ops_registry, mappings
9
10
  from torchax import decompositions
10
11
  import jax
11
12
  import jax.export
12
13
  import sympy
13
14
 
14
-
15
15
  DEBUG = False
16
16
 
17
17
 
@@ -83,7 +83,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
83
83
  if torch.__version__ >= '2.2':
84
84
  # torch version 2.1 didn't expose this yet
85
85
  exported_program = exported_program.run_decompositions()
86
- exported_program = exported_program.run_decompositions(decompositions.EXTRA_DECOMP)
86
+ exported_program = exported_program.run_decompositions(
87
+ decompositions.DECOMPOSITIONS)
87
88
  if DEBUG:
88
89
  print(exported_program.graph_module.code)
89
90
 
@@ -108,8 +109,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
108
109
 
109
110
  if export_raw:
110
111
  return names, states, func
111
-
112
- states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
112
+ env = torchax.default_env()
113
+ states = env.t2j_copy(states)
113
114
  return states, func
114
115
 
115
116
 
@@ -121,34 +122,35 @@ def extract_avals(exported):
121
122
  def _to_aval(arg_meta, symbolic_shapes):
122
123
  """Convet from torch type to jax abstract value for export tracing
123
124
  """
125
+
124
126
  def _get_dim(d):
125
127
  if isinstance(d, torch.SymInt):
126
128
  return symbolic_shapes[str(d)]
127
129
  return d
128
130
 
129
131
  val = arg_meta['val']
130
- is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
132
+ is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
133
+ val, bool)
131
134
  if is_scalar:
132
135
  return jax.ShapeDtypeStruct([], type(arg_meta['val']))
133
136
 
134
137
  tensor_meta = arg_meta['tensor_meta']
135
138
  shape = [_get_dim(d) for d in tensor_meta.shape]
136
- return jax.ShapeDtypeStruct(shape, tensor.t2j_dtype(tensor_meta.dtype))
139
+ return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
137
140
 
138
141
  def _get_inputs(exported):
139
142
  """Return placeholders with input metadata"""
140
143
  placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
141
144
  input_placeholders = [
142
- p
143
- for p, s in zip(placeholders, exported.graph_signature.input_specs)
144
- if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
145
+ p for p, s in zip(placeholders, exported.graph_signature.input_specs)
146
+ if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
145
147
  ]
146
148
  return input_placeholders
147
149
 
148
150
  def _build_symbolic_shapes(range_constraints):
149
151
  """Convert torch SymInt to JAX symbolic_shape and stores in a map using the
150
152
  string name of the torch symbolic int.
151
-
153
+
152
154
  TODO: There is probably a better way of storing a key for a symbolic int.
153
155
  This value needs to be looked up again in `_to_aval` to figure out which
154
156
  JAX symbolic to map to for a given torch tensor.
@@ -163,8 +165,10 @@ def extract_avals(exported):
163
165
  torch.export.Dim("a", min=5, max=10)
164
166
  ==> ("a >= 5", "a <= 10",)
165
167
  """
166
- if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges) or torch_constraint.is_bool:
167
- raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")
168
+ if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
169
+ ValueRanges) or torch_constraint.is_bool:
170
+ raise TypeError(
171
+ f"No symbolic constraint handler for: {torch_constraint}")
168
172
 
169
173
  constraints = []
170
174
  symbol = sympy.Symbol(symbol_name)
@@ -182,7 +186,7 @@ def extract_avals(exported):
182
186
  There are two possible sympy `sym` inputs:
183
187
  1. Symbol - (s0) These can have custom constraints.
184
188
  2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.
185
-
189
+
186
190
  Currently support is limited to operations with a symbol and and int,
187
191
  in `torch/export/dynamic_shapes.py`:
188
192
  "Only increasing linear operations with integer coefficients are supported."
@@ -190,7 +194,8 @@ def extract_avals(exported):
190
194
  symbol_name = str(sym)
191
195
  constraints = _build_symbolic_constraints(symbol_name, constraint)
192
196
  if sym.is_symbol:
193
- symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
197
+ symbolic_shape = jax.export.symbolic_shape(
198
+ symbol_name, constraints=constraints)
194
199
  else:
195
200
  assert len(sym.free_symbols) > 0
196
201
  scope = free_symbols[str(list(sym.free_symbols)[0])].scope
@@ -203,8 +208,12 @@ def extract_avals(exported):
203
208
  # integer compuations on symbol variables, so each symbol variable is OK to
204
209
  # have its own scope.
205
210
  symbolic_shapes = {}
206
- symbol_variables = [(s,v) for s,v in range_constraints.items() if s.is_symbol]
207
- symbol_exprs = [(s,v) for s,v in range_constraints.items() if not s.is_symbol]
211
+ symbol_variables = [
212
+ (s, v) for s, v in range_constraints.items() if s.is_symbol
213
+ ]
214
+ symbol_exprs = [
215
+ (s, v) for s, v in range_constraints.items() if not s.is_symbol
216
+ ]
208
217
  for sym, constraint in symbol_variables + symbol_exprs:
209
218
  symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
210
219
  symbolic_shapes[str(sym)] = symbolic_shape
torchax/flax.py ADDED
@@ -0,0 +1,39 @@
1
+ """Flax interop."""
2
+
3
+ import torch
4
+ import torchax as tx
5
+ import torchax.interop
6
+
7
+
8
+ class FlaxNNModule(torch.nn.Module):
9
+
10
+ def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
11
+ super().__init__()
12
+ prng = env.prng_key
13
+ sample_kwargs = sample_kwargs or {}
14
+ parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
15
+ **sample_kwargs)
16
+
17
+ self._params = self._encode_nested_dict(parameter_dict)
18
+
19
+ self._flax_module = flax_module
20
+
21
+ def _encode_nested_dict(self, nested_dict):
22
+ child_module = torch.nn.Module()
23
+ for k, v in nested_dict.items():
24
+ if isinstance(v, dict):
25
+ child_module.add_module(k, self._encode_nested_dict(v))
26
+ else:
27
+ child_module.register_parameter(k, torch.nn.Parameter(v))
28
+ return child_module
29
+
30
+ def _decode_nested_dict(self, child_module):
31
+ result = dict(child_module.named_parameters(recurse=False))
32
+ for k, v in child_module.named_children():
33
+ result[k] = self._decode_nested_dict(v)
34
+ return result
35
+
36
+ def forward(self, *args, **kwargs):
37
+ nested_dict_params = self._decode_nested_dict(self._params)
38
+ return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
39
+ *args, **kwargs)