torchax 0.0.10.dev20251116__py3-none-any.whl → 0.0.11.dev202612__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax has been flagged as possibly problematic.


torchax/checkpoint.py CHANGED
@@ -12,32 +12,39 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import torch
  import os
- from typing import Any, Dict
- from flax.training import checkpoints
+ from typing import Any
+
  import jax
  import jax.numpy as jnp
  import numpy as np
+ import torch
+ from flax.training import checkpoints
+
  from . import tensor
 
+
  def _to_jax(pytree):
  def to_jax_array(x):
  if isinstance(x, tensor.Tensor):
- return x.jax()
+ return x.jax()
  elif isinstance(x, torch.Tensor):
- return jnp.asarray(x.cpu().numpy())
+ return jnp.asarray(x.cpu().numpy())
  return x
+
  return jax.tree_util.tree_map(to_jax_array, pytree)
 
 
  def _to_torch(pytree):
  return jax.tree_util.tree_map(
  lambda x: torch.from_numpy(np.asarray(x))
- if isinstance(x, (jnp.ndarray, jax.Array)) else x, pytree)
+ if isinstance(x, (jnp.ndarray, jax.Array))
+ else x,
+ pytree,
+ )
 
 
- def save_checkpoint(state: Dict[str, Any], path: str, step: int):
+ def save_checkpoint(state: dict[str, Any], path: str, step: int):
  """Saves a checkpoint to a file in JAX style.
 
  Args:
@@ -50,7 +57,7 @@ def save_checkpoint(state: Dict[str, Any], path: str, step: int):
  checkpoints.save_checkpoint(path, state, step=step, overwrite=True)
 
 
- def load_checkpoint(path: str) -> Dict[str, Any]:
+ def load_checkpoint(path: str) -> dict[str, Any]:
  """Loads a checkpoint and returns it in JAX format.
 
  This function can load both PyTorch-style (single file) and JAX-style
@@ -76,4 +83,3 @@ def load_checkpoint(path: str) -> Dict[str, Any]:
  return _to_jax(state)
  else:
  raise FileNotFoundError(f"No such file or directory: {path}")
-
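
The hunk above only reorders imports and modernizes the type hints (Dict -> dict); the save_checkpoint/load_checkpoint API is unchanged. For orientation, a hedged usage sketch based only on the signatures and docstrings visible in this hunk; the Linear model, the /tmp path and the step value are illustrative assumptions, not part of the diff.

# Hedged usage sketch; model, path and step are hypothetical.
import torch
from torchax import checkpoint

state = {"model": torch.nn.Linear(4, 2).state_dict()}
# Saved in JAX style via flax.training.checkpoints, per the docstring above.
checkpoint.save_checkpoint(state, "/tmp/torchax_ckpt", step=100)

# Returned in JAX format as a dict[str, Any], per the return annotation above.
restored = checkpoint.load_checkpoint("/tmp/torchax_ckpt")
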
torchax/config.py CHANGED
@@ -35,10 +35,6 @@ class Configuration:
  # Use DLPack for converting jax.Arrays <-> and torch.Tensor
  use_dlpack_for_data_conversion: bool = False
 
- # Flash attention
- use_tpu_flash_attention: bool = False
- shmap_flash_attention: bool = False
-
  # device
  treat_cuda_as_jax_device: bool = True
  internal_respect_torch_return_dtypes: bool = False

torchax/decompositions.py CHANGED
@@ -22,21 +22,21 @@ Can also contain decompositions of a torch op in terms of other torch ops.
  """
 
  import functools
- from typing import Any, Callable, List, Tuple
+ from collections.abc import Callable
+ from typing import Any
 
  import torch
- from torch import Tensor
  import torch._decomp as decomp
- from torch._decomp import decompositions_for_rng
- from torch._decomp import register_decomposition
  import torch._prims_common as utils
+ from torch import Tensor
+ from torch._decomp import decompositions_for_rng, register_decomposition
  from torch._prims_common.wrappers import out_wrapper
 
  DispatchKey = torch._C.DispatchKey  # type: ignore[attr-defined]
 
  # None of these functions are publicly accessible; get at them
  # from torch._decomps
- __all__: List[str] = []
+ __all__: list[str] = []
 
  aten = torch._ops.ops.aten
 
@@ -44,21 +44,21 @@ aten = torch._ops.ops.aten
  def _try_register(op, impl):
  try:
  register_decomposition(op)(impl)
- except:
+
+ except Exception:
  pass
 
 
  @out_wrapper()
- def _reflection_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
-
+ def _reflection_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
  def idx(left, middle, right):
  dim_idx = torch.arange(-left, middle + right, device=a.device)
  return middle - 1 - (middle - 1 - dim_idx.abs()).abs()
 
  return _reflection_or_replication_pad(
- a,
- padding,
- idx,
+ a,
+ padding,
+ idx,
  )
 
 
@@ -68,32 +68,32 @@ _try_register(aten.reflection_pad3d, _reflection_pad)
 
 
  @out_wrapper()
- def _replication_pad(a: Tensor, padding: Tuple[int, ...]) -> Tensor:
-
+ def _replication_pad(a: Tensor, padding: tuple[int, ...]) -> Tensor:
  def idx(left, middle, right):
  dim_idx = torch.arange(-left, middle + right, device=a.device)
  return torch.clamp(dim_idx, 0, middle - 1)
 
  return _reflection_or_replication_pad(
- a,
- padding,
- idx,
+ a,
+ padding,
+ idx,
  )
 
 
- decomp.global_decomposition_table["post_autograd"][
- aten.replication_pad2d.default] = _replication_pad
+ decomp.global_decomposition_table["post_autograd"][aten.replication_pad2d.default] = (
+ _replication_pad
+ )
 
 
  def _reflection_or_replication_pad(
- a: Tensor,
- padding: Tuple[int, ...],
- idx_fn: Callable[[int, int, int], Tensor],
+ a: Tensor,
+ padding: tuple[int, ...],
+ idx_fn: Callable[[int, int, int], Tensor],
  ) -> Tensor:
  dim = len(padding) // 2
  torch._check(
- a.dim() in (dim + 1, dim + 2),
- lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
+ a.dim() in (dim + 1, dim + 2),
+ lambda: f"reflection_pad{dim}d requires {dim + 1}D or {dim + 2}D input",
  )
  inp_shape = a.shape[-dim:]
  nc_dim = a.dim() - dim
@@ -103,7 +103,7 @@ def _reflection_or_replication_pad(
 
  result = a
  for i in range(dim):
- idx: List[Any] = [None] * result.dim()
+ idx: list[Any] = [None] * result.dim()
  idx[i + nc_dim] = idx_fn(padding_left[i], inp_shape[i], padding_right[i])
  result = aten._unsafe_index(result, idx)
 
@@ -158,11 +158,11 @@ def _sum_tensors(ts) -> Tensor:
 
  @register_decomposition(aten.grid_sampler_3d)
  def _grid_sampler_3d(
- a: torch.Tensor,
- grid: torch.Tensor,
- interpolation_mode: int = 0,
- padding_mode: int = 0,
- align_corners: bool = False,
+ a: torch.Tensor,
+ grid: torch.Tensor,
+ interpolation_mode: int = 0,
+ padding_mode: int = 0,
+ align_corners: bool = False,
  ) -> Tensor:
  """References: https://github.com/pytorch/pytorch/blob/06a7dc21c1005750598c37f3adbc031183c74de6/torch/_decomp/decompositions.py#L4075
 
@@ -170,11 +170,12 @@ def _grid_sampler_3d(
  """
  _expand_grid = False
  torch._check(
- interpolation_mode in (0, 1),
- lambda: f"Invalid interpolation mode {interpolation_mode}",
+ interpolation_mode in (0, 1),
+ lambda: f"Invalid interpolation mode {interpolation_mode}",
  )
  torch._check(
- padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}")
+ padding_mode in (0, 1, 2), lambda: f"Invalid padding mode {padding_mode}"
+ )
 
  # a is 5D: [B, C, D, H, W]
 
@@ -189,8 +190,7 @@ def _grid_sampler_3d(
  # Reflects coordinates until they fall between low and high (inclusive).
  # The bounds are passed as twice their value so that half-integer values
  # can be represented as ints.
- def reflect_coordinates(coords: Tensor, twice_low: int,
- twice_high: int) -> Tensor:
+ def reflect_coordinates(coords: Tensor, twice_low: int, twice_high: int) -> Tensor:
  if twice_low == twice_high:
  return torch.zeros_like(coords)
  coords_min = twice_low / 2
@@ -198,8 +198,9 @@ def _grid_sampler_3d(
  coords2 = (coords - coords_min).abs()
  extra = torch.fmod(coords2, coords_span)
  flips = (coords2 / coords_span).floor().to(dtype=torch.int8)
- return torch.where(flips & 1 == 0, extra + coords_min,
- coords_span + coords_min - extra)
+ return torch.where(
+ flips & 1 == 0, extra + coords_min, coords_span + coords_min - extra
+ )
 
  def compute_coordinates(coords: Tensor, size: int) -> Tensor:
  if padding_mode == 0:  # Zero
@@ -219,7 +220,7 @@ def _grid_sampler_3d(
 
  N, C, iD, iH, iW = a.shape
  _, oD, oH, oW, three = grid.shape
- assert three == 3, "Last dim of grid must be 3. got {}".format(three)
+ assert three == 3, f"Last dim of grid must be 3. got {three}"
 
  def in_bounds_cond(xs: Tensor, ys: Tensor, zs) -> Tensor:
  xcheck = torch.logical_and(0 <= xs, xs < iW)
@@ -238,15 +239,16 @@ def _grid_sampler_3d(
  # broadcasting with N_idx, C_idx for the purposes of advanced indexing
  c = C if _expand_grid else 1
  return tuple(
- torch.where(cond, t, 0).view(N, c, oD, oH, oW) for t in (
- xs.to(dtype=torch.int64),
- ys.to(dtype=torch.int64),
- zs.to(dtype=torch.int64),
- ws,
- ))
-
- def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor,
- w) -> Tensor:
+ torch.where(cond, t, 0).view(N, c, oD, oH, oW)
+ for t in (
+ xs.to(dtype=torch.int64),
+ ys.to(dtype=torch.int64),
+ zs.to(dtype=torch.int64),
+ ws,
+ )
+ )
+
+ def get_summand(ix: torch.Tensor, iy: torch.Tensor, iz: torch.Tensor, w) -> Tensor:
  # Perform clipping, index into input tensor and multiply by weight
  idx_x, idx_y, idx_z, w_ = clip(ix, iy, iz, w)
  return a[N_idx, C_idx, idx_z, idx_y, idx_x] * w_
@@ -279,16 +281,18 @@ def _grid_sampler_3d(
  w_seb = (ix - ix_nwf) * (iy - iy_nwf) * (id_ - id_nwf)
 
  return _sum_tensors(
- get_summand(ix, iy, id_, w) for (ix, iy, id_, w) in (
- (ix_nwf, iy_nwf, id_nwf, w_nwf),
- (ix_nef, iy_nef, id_nef, w_nef),
- (ix_swf, iy_swf, id_swf, w_swf),
- (ix_sef, iy_sef, id_sef, w_sef),
- (ix_nwb, iy_nwb, id_nwb, w_nwb),
- (ix_neb, iy_neb, id_neb, w_neb),
- (ix_swb, iy_swb, id_swb, w_swb),
- (ix_seb, iy_seb, id_seb, w_seb),
- ))
+ get_summand(ix, iy, id_, w)
+ for (ix, iy, id_, w) in (
+ (ix_nwf, iy_nwf, id_nwf, w_nwf),
+ (ix_nef, iy_nef, id_nef, w_nef),
+ (ix_swf, iy_swf, id_swf, w_swf),
+ (ix_sef, iy_sef, id_sef, w_sef),
+ (ix_nwb, iy_nwb, id_nwb, w_nwb),
+ (ix_neb, iy_neb, id_neb, w_neb),
+ (ix_swb, iy_swb, id_swb, w_swb),
+ (ix_seb, iy_seb, id_seb, w_seb),
+ )
+ )
  else:  # interpolation_mode == 1: # Nearest
  ix = compute_source_index(x, iW)
  iy = compute_source_index(y, iH)
@@ -301,7 +305,8 @@ def _grid_sampler_3d(
  return get_summand(ix_nearest, iy_nearest, iz_nearest, 1)
 
 
- DECOMPOSITIONS = decomp.get_decompositions([
+ DECOMPOSITIONS = decomp.get_decompositions(
+ [
  torch.ops.aten.upsample_bicubic2d,
  torch.ops.aten.upsample_nearest1d,
  torch.ops.aten.upsample_nearest2d,
@@ -782,9 +787,10 @@ DECOMPOSITIONS = decomp.get_decompositions([
  torch.ops.aten.__irshift__.Tensor,
  torch.ops.aten.__irshift__.Scalar,
  torch.ops.aten.__ior__.Tensor,
- ])
+ ]
+ )
 
 
  MUTABLE_DECOMPOSITION = [
- torch.ops.aten.bernoulli_.Tensor,
- torch.ops.aten.bernoulli_.float,
+ torch.ops.aten.bernoulli_.Tensor,
+ torch.ops.aten.bernoulli_.float,
  ]
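
Most of this file's changes are formatting, but the registration pattern is easy to miss in the noise: decompositions are attached with register_decomposition (wrapped in _try_register so ops that are already registered are skipped) and collected with torch._decomp.get_decompositions. A standalone sketch of that pattern, independent of torchax and using only public torch._decomp entry points:

# Standalone sketch of the registration pattern above, not torchax code.
import torch
import torch._decomp as decomp
from torch._decomp import register_decomposition

aten = torch._ops.ops.aten


def _try_register(op, impl):
    # Same idea as the helper in this diff: skip ops whose decomposition is
    # already registered instead of failing at import time.
    try:
        register_decomposition(op)(impl)
    except Exception:
        pass


# get_decompositions returns a dict mapping op overloads to Python
# implementations, which tracers (torch.export's run_decompositions, make_fx)
# use to rewrite those ops into simpler aten ops.
table = decomp.get_decompositions([torch.ops.aten.upsample_nearest2d])
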
torchax/export.py CHANGED
@@ -14,17 +14,21 @@
 
  # pylint: disable
  """Utilities for exporting a torch program to jax/stablehlo."""
+
  import copy
- from typing import Any, Dict, Tuple
+ from typing import Any
+
+ import jax
+ import jax.export
+ import sympy
  import torch
+ import torch._refs
+ from torch._decomp import get_decompositions
  from torch.utils import _pytree as pytree
+
  import torchax
- from torchax import tensor
- from torchax.ops import ops_registry, mappings
  from torchax import decompositions
- import jax
- import jax.export
- import sympy
+ from torchax.ops import mappings, ops_registry
 
  DEBUG = False
 
@@ -34,16 +38,13 @@ class JaxInterpreter(torch.fx.Interpreter):
 
  def __init__(self, graph_module):
  super().__init__(graph_module)
- import torchax.ops.jaten
- import torchax.ops.jtorch
 
- def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
- if not isinstance(target,
- (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
+ def call_function(self, target, args: tuple, kwargs: dict) -> Any:
+ if not isinstance(target, (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
  return super().call_function(target, args, kwargs)
 
  if DEBUG:
- print('Running ', target.name(), '--------')
+ print("Running ", target.name(), "--------")
 
  op = ops_registry.all_aten_ops.get(target)
  if op is None:
@@ -54,31 +55,30 @@ class JaxInterpreter(torch.fx.Interpreter):
  op = ops_registry.all_aten_ops.get(target.overloadpacket)
  if op is None:
  print(target.name(), target.tags)
- raise RuntimeError('No lowering found for', target.name())
+ raise RuntimeError("No lowering found for", target.name())
  return op.func(*args, **kwargs)
 
  def run_node(self, n) -> Any:
  res = super().run_node(n)
  if DEBUG:
- if n.op == 'call_function':
- if hasattr(res, 'shape'):
- print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
+ if n.op == "call_function":
+ if hasattr(res, "shape"):
+ print("Meta:", n.meta.get("val").shape, "REAL: ", res.shape)
  return res
 
 
- from torch._decomp import get_decompositions
- import torch._refs
-
  _extra_decomp = get_decompositions([torch.ops.aten.unfold])
 
 
  def _extract_states_from_exported_program(exported_model):
  # NOTE call convention: (parameters, buffers, user_inputs)
- param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
+ param_and_buffer_keys = (
+ exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
+ )
  state_dict = copy.copy(exported_model.state_dict)
- if (constants := getattr(exported_model, 'constants', None)) is not None:
+ if (constants := getattr(exported_model, "constants", None)) is not None:
  state_dict.update(constants)
- param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys)
+ param_buffer_values = [state_dict[key] for key in param_and_buffer_keys]
 
  if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
  for name in exported_model.graph_signature.lifted_tensor_constants:
@@ -94,19 +94,19 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
 
  func(state, input) would be how you call it.
  """
- if torch.__version__ >= '2.2':
+ if torch.__version__ >= "2.2":
  # torch version 2.1 didn't expose this yet
  exported_program = exported_program.run_decompositions()
  exported_program = exported_program.run_decompositions(
- decompositions.DECOMPOSITIONS)
+ decompositions.DECOMPOSITIONS
+ )
  if DEBUG:
  print(exported_program.graph_module.code)
 
  names, states = _extract_states_from_exported_program(exported_program)
 
  def _extract_args(args, kwargs):
- flat_args, received_spec = pytree.tree_flatten(
- (args, kwargs))  # type: ignore[possibly-undefined]
+ flat_args, received_spec = pytree.tree_flatten((args, kwargs))  # type: ignore[possibly-undefined]
  return flat_args
 
  num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
@@ -114,9 +114,9 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
  def func(states, inputs):
  args = _extract_args(inputs, {})
  res = JaxInterpreter(exported_program.graph_module).run(
- *states,
- *args,
- enable_io_processing=False,
+ *states,
+ *args,
+ enable_io_processing=False,
  )
  res = res[num_mutations:]
  return res
@@ -134,21 +134,19 @@ def extract_avals(exported):
  """
 
  def _to_aval(arg_meta, symbolic_shapes):
- """Convet from torch type to jax abstract value for export tracing
- """
+ """Convet from torch type to jax abstract value for export tracing"""
 
  def _get_dim(d):
  if isinstance(d, torch.SymInt):
  return symbolic_shapes[str(d)]
  return d
 
- val = arg_meta['val']
- is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
- val, bool)
+ val = arg_meta["val"]
+ is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
  if is_scalar:
- return jax.ShapeDtypeStruct([], type(arg_meta['val']))
+ return jax.ShapeDtypeStruct([], type(arg_meta["val"]))
 
- tensor_meta = arg_meta['tensor_meta']
+ tensor_meta = arg_meta["tensor_meta"]
  shape = [_get_dim(d) for d in tensor_meta.shape]
  return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
 
@@ -156,8 +154,9 @@ def extract_avals(exported):
  """Return placeholders with input metadata"""
  placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
  input_placeholders = [
- p for p, s in zip(placeholders, exported.graph_signature.input_specs)
- if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+ p
+ for p, s in zip(placeholders, exported.graph_signature.input_specs, strict=False)
+ if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
  ]
  return input_placeholders
 
@@ -179,17 +178,22 @@ def extract_avals(exported):
  torch.export.Dim("a", min=5, max=10)
  ==> ("a >= 5", "a <= 10",)
  """
- if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
- ValueRanges) or torch_constraint.is_bool:
- raise TypeError(
- f"No symbolic constraint handler for: {torch_constraint}")
+ if (
+ not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges)
+ or torch_constraint.is_bool
+ ):
+ raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")
 
  constraints = []
  symbol = sympy.Symbol(symbol_name)
  if torch_constraint.lower != 2:
  constraints.append(symbol >= torch_constraint.lower)
  from sympy.core.singleton import S
- if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity:
+
+ if (
+ not torch_constraint.upper.is_infinite
+ and torch_constraint.upper is not S.IntInfinity
+ ):
  constraints.append(symbol <= torch_constraint.upper)
 
  return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)
@@ -208,8 +212,7 @@ def extract_avals(exported):
  symbol_name = str(sym)
  constraints = _build_symbolic_constraints(symbol_name, constraint)
  if sym.is_symbol:
- symbolic_shape = jax.export.symbolic_shape(
- symbol_name, constraints=constraints)
+ symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
  else:
  assert len(sym.free_symbols) > 0
  scope = free_symbols[str(list(sym.free_symbols)[0])].scope
@@ -222,12 +225,8 @@ def extract_avals(exported):
  # integer compuations on symbol variables, so each symbol variable is OK to
  # have its own scope.
  symbolic_shapes = {}
- symbol_variables = [
- (s, v) for s, v in range_constraints.items() if s.is_symbol
- ]
- symbol_exprs = [
- (s, v) for s, v in range_constraints.items() if not s.is_symbol
- ]
+ symbol_variables = [(s, v) for s, v in range_constraints.items() if s.is_symbol]
+ symbol_exprs = [(s, v) for s, v in range_constraints.items() if not s.is_symbol]
  for sym, constraint in symbol_variables + symbol_exprs:
  symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
  symbolic_shapes[str(sym)] = symbolic_shape
@@ -237,10 +236,10 @@ def extract_avals(exported):
  args = _get_inputs(exported)
 
  if DEBUG:
- print('Inputs to aval:', args, '--------')
- print('Symbolic shapes:', symbolic_shapes)
+ print("Inputs to aval:", args, "--------")
+ print("Symbolic shapes:", symbolic_shapes)
  for arg in args:
- print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes))
+ print("Meta2Aval", arg.meta, "--> ", _to_aval(arg.meta, symbolic_shapes))
 
  return [_to_aval(arg.meta, symbolic_shapes) for arg in args]
 
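
The docstring of exported_program_to_jax above only hints at the calling convention ("func(state, input) would be how you call it"). A hedged end-to-end sketch follows; the Linear model, the sample input, and the (states, func) unpacking are assumptions rather than details confirmed by this diff.

# Hedged sketch; model, inputs and the return-value unpacking are assumptions.
import jax.numpy as jnp
import torch
import torchax.export

model = torch.nn.Linear(8, 4)
sample_args = (torch.randn(2, 8),)
exported = torch.export.export(model, sample_args)

# func replays the FX graph through JaxInterpreter on (states, inputs),
# as shown in the hunk for exported_program_to_jax above.
states, func = torchax.export.exported_program_to_jax(exported)
jax_inputs = (jnp.asarray(sample_args[0].numpy()),)
outputs = func(states, jax_inputs)
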
torchax/flax.py CHANGED
@@ -15,18 +15,19 @@
  """Flax interop."""
 
  import torch
+
  import torchax as tx
  import torchax.interop
 
 
  class FlaxNNModule(torch.nn.Module):
-
  def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
  super().__init__()
  prng = env.prng_key
  sample_kwargs = sample_kwargs or {}
- parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
- **sample_kwargs)
+ parameter_dict = tx.interop.call_jax(
+ flax_module.init, prng, *sample_args, **sample_kwargs
+ )
 
  self._params = self._encode_nested_dict(parameter_dict)
 
@@ -49,5 +50,6 @@ class FlaxNNModule(torch.nn.Module):
 
  def forward(self, *args, **kwargs):
  nested_dict_params = self._decode_nested_dict(self._params)
- return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
- *args, **kwargs)
+ return tx.interop.call_jax(
+ self._flax_module.apply, nested_dict_params, *args, **kwargs
+ )
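
FlaxNNModule is only reformatted here, but the constructor signature (env, flax_module, sample_args, sample_kwargs=None) is easier to read next to a usage sketch. This one is hedged: the Dense module, the input shape, and the torchax.default_env() context usage are assumptions, not part of this diff.

# Hedged sketch; the flax module, shapes and default_env() usage are assumptions.
import flax.linen as nn
import torch
import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()
with env:
    sample_args = (torch.randn(1, 16),)
    # __init__ calls flax_module.init via call_jax and stores the resulting
    # parameter pytree on the torch module (_encode_nested_dict).
    wrapped = FlaxNNModule(env, nn.Dense(features=4), sample_args)
    out = wrapped(*sample_args)  # forward() -> flax_module.apply via call_jax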