torchax-0.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of torchax might be problematic.

torchax/distributed.py ADDED
@@ -0,0 +1,246 @@
+ """`torch.distributed` backend implemented with JAX collective ops.
+
+ EXPERIMENTAL: This module is still highly experimental, and it may be removed
+ before any stable release.
+
+ Note: JAX collective ops require that axis names be defined in `pmap` or
+ `shmap`. The distributed backend only supports one axis, named `torch_dist`.
+ This name is defined by our mirror implementation of `spawn`.
+ """
+
+ import datetime
+ import functools
+ import logging
+ import os
+ from typing import List, Optional, Union
+
+ import jax
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+ import torch.distributed._functional_collectives
+ from torch._C._distributed_c10d import ProcessGroup  # type: ignore
+ import torch.distributed
+ import torchax
+ from jax.sharding import NamedSharding
+ from jax.sharding import Mesh, PartitionSpec as P
+ from jax.experimental import mesh_utils
+ import torch.utils._pytree as torch_pytree
+ from torchax import interop
+
+
+ class ProcessGroupJax(ProcessGroup):
+   """Distributed backend implemented with JAX."""
+
+   def __init__(self, prefix_store, rank, size, timeout):
+     super().__init__(rank, size)
+     self._group_name = None
+
+   def getBackendName(self):
+     return "jax"
+
+   # TODO(wcromar): why doesn't the default group name setter work?
+   # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
+   def _set_group_name(self, name: str) -> None:
+     self._group_name = name
+
+   @property
+   def group_name(self):
+     assert self._group_name
+     return self._group_name
+
+   @staticmethod
+   def _work(
+       tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]],
+   ) -> dist.Work:
+     fut = torch.futures.Future()
+     fut.set_result(tensors)
+     return torch._C._distributed_c10d._create_work_from_future(fut)
+
+   def _allgather_base(
+       self,
+       output: torch.Tensor,
+       input: torch.Tensor,
+       opts=...,
+   ) -> dist.Work:
+     assert isinstance(input, torchax.tensor.Tensor)
+     assert isinstance(output, torchax.tensor.Tensor)
+     torch.distributed._functional_collectives.all_gather_tensor_inplace(
+         output, input, group=self
+     )
+     return self._work(output)
+
+   def allreduce(
+       self,
+       tensors: List[torch.Tensor],
+       opts: dist.AllreduceOptions = ...,
+   ) -> dist.Work:
+     assert len(tensors) == 1
+     assert isinstance(tensors[0], torchax.tensor.Tensor)
+     torch.distributed._functional_collectives.all_reduce_inplace(
+         tensors[0],
+         torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
+             opts.reduceOp.op
+         ],
+         self,
+     )
+
+     return self._work(tensors)
+
+   def broadcast(
+       self,
+       tensors: List[torch.Tensor],
+       opts: dist.BroadcastOptions = ...,
+   ) -> dist.Work:
+     assert len(tensors) == 1
+     assert isinstance(tensors[0], torchax.tensor.Tensor)
+     tensors[0].copy_(
+         torch.distributed._functional_collectives.broadcast(
+             tensors[0], opts.rootRank, group=self
+         )
+     )
+
+     return self._work(tensors)
+
+
+ dist.Backend.register_backend("jax", ProcessGroupJax)
+
+
+ def jax_rendezvous_handler(
+     url: str, timeout: datetime.timedelta = ..., **kwargs
+ ):
+   """Initialize distributed store with JAX process IDs.
+
+   Requires `$MASTER_ADDR` and `$MASTER_PORT`.
+   """
+   # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU
+   # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part
+   # of their public Python API
+   master_ip = os.environ["MASTER_ADDR"]
+   master_port = int(os.environ["MASTER_PORT"])
+   # TODO(wcromar): Use `torchrun`'s store if available
+   store = dist.TCPStore(
+       master_ip,
+       master_port,
+       jax.process_count(),
+       is_master=jax.process_index() == 0,
+   )
+
+   yield (store, jax.process_index(), jax.process_count())
+
+
+ dist.register_rendezvous_handler("jax", jax_rendezvous_handler)
+
+
+ def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
+   """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined.
+   `f` is expected to take the replica index as a positional argument, similar
+   to `torch.multiprocessing.spawn`.
+   Note: `spawn` does not actually create parallel processes.
+   """
+   env = env or torchax.default_env()
+
+   def jax_wrapper(index, jax_args):
+     index, args = env.j2t_iso([index, jax_args])
+     torch_outputs = f(index, *args)
+     return env.t2j_iso(torch_outputs)
+
+   jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")(
+       np.arange(jax.device_count()), env.t2j_iso(args)
+   )
+   return env.j2t_iso(jax_outputs)
+
+
+ class DistributedDataParallel(torch.nn.Module):
+   """Re-implementation of DistributedDataParallel using JAX SPMD.
+
+   Splits inputs along the batch dimension (assumed to be 0) across all devices
+   in the JAX runtime, including remote devices. Each process should load a
+   distinct shard of the input data using e.g. DistributedSampler. Each
+   process' shard is then further split among the addressable devices (e.g.
+   local TPU chips) by `shard_input`.
+
+   Note: since parameters are replicated across addressable devices, inputs
+   must also be SPMD sharded using `shard_input` or `replicate_input`.
+
+   Example usage:
+
+   ```
+   jax_model = torchax.distributed.DistributedDataParallel(create_model())
+   for data in dataloader:
+     jax_data = jax_model.shard_input(data)
+     jax_output = jax_model(jax_data)
+   ```
+   """
+   def __init__(
+       self,
+       module: torch.nn.Module,
+       env: Optional[torchax.tensor.Environment] = None,
+       **kwargs,
+   ):
+     if kwargs:
+       logging.warning(f"Unsupported kwargs {kwargs}")
+
+     super().__init__()
+     self._env = env or torchax.default_env()
+     self._mesh = Mesh(
+         mesh_utils.create_device_mesh((jax.device_count(),)),
+         axis_names=("batch",),
+     )
+     replicated_state = torch_pytree.tree_map_only(
+         torch.Tensor,
+         lambda t: self._env.j2t_iso(
+             jax.device_put(
+                 self._env.to_xla(t)._elem, NamedSharding(self._mesh, P())
+             )
+         ),
+         module.state_dict(),
+     )
+     # TODO: broadcast
+     module.load_state_dict(replicated_state, assign=True)
+     self._module = module
+
+   def shard_input(self, inp):
+     per_process_batch_size = inp.shape[0]  # assumes batch dim is 0
+     per_replica_batch_size = per_process_batch_size // jax.local_device_count()
+     per_replica_batches = torch.chunk(inp, jax.local_device_count())
+     global_batch_size = per_replica_batch_size * jax.device_count()
+     global_batch_shape = (global_batch_size,) + inp.shape[1:]
+
+     sharding = NamedSharding(self._mesh, P("batch"))
+     return self._env.j2t_iso(jax.make_array_from_single_device_arrays(
+         global_batch_shape,
+         NamedSharding(self._mesh, P("batch")),
+         arrays=[
+             jax.device_put(self._env.to_xla(batch)._elem, device)
+             for batch, device in zip(
+                 per_replica_batches, sharding.addressable_devices
+             )
+         ],
+     ))
+
+   def replicate_input(self, inp):
+     return self._env.j2t_iso(
+         jax.device_put(inp._elem, NamedSharding(self._mesh, P()))
+     )
+
+   def jit_step(self, func):
+     @functools.partial(interop.jax_jit,
+                        kwargs_for_jax_jit={'donate_argnums': 0})
+     def _jit_fn(states, args):
+       self.load_state_dict(states)
+       outputs = func(*args)
+       return self.state_dict(), outputs
+
+     @functools.wraps(func)
+     def inner(*args):
+       jax_states = self.state_dict()
+       new_states, outputs = _jit_fn(jax_states, args)
+       self.load_state_dict(new_states)
+       return outputs
+
+     return inner
+
+   def forward(self, *args):
+     with self._env:
+       return self._module(*args)
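
The `DistributedDataParallel` docstring above gives the core loop; here is a slightly fuller usage sketch. It is not part of the packaged file: `create_model` and `dataloader` are illustrative placeholders, and a multi-device JAX runtime is assumed.

```python
import torchax.distributed

# Parameters are replicated across all addressable devices (see __init__ above).
jax_model = torchax.distributed.DistributedDataParallel(create_model())

for data in dataloader:  # each process loads its own shard of the dataset
  # Split this process' batch across its addressable devices along dim 0...
  jax_data = jax_model.shard_input(data)
  # ...or, for small inputs, replicate instead: jax_model.replicate_input(data)
  jax_output = jax_model(jax_data)
```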
torchax/environment.py ADDED
@@ -0,0 +1,2 @@
+
+
torchax/export.py ADDED
@@ -0,0 +1,236 @@
+ # pylint: disable
+ """Utilities for exporting a torch program to jax/stablehlo."""
+ import copy
+ from typing import Any, Dict, Tuple
+ import torch
+ from torch.utils import _pytree as pytree
+ from torchax import tensor
+ from torchax.ops import ops_registry
+ from torchax import decompositions
+ import jax
+ import jax.export
+ import sympy
+
+
+ DEBUG = False
+
+
+ class JaxInterpreter(torch.fx.Interpreter):
+   """Experimental."""
+
+   def __init__(self, graph_module):
+     super().__init__(graph_module)
+     import torchax.ops.jaten
+     import torchax.ops.jtorch
+
+   def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
+     if not isinstance(target,
+                       (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
+       return super().call_function(target, args, kwargs)
+
+     if DEBUG:
+       print('Running ', target.name(), '--------')
+
+     op = ops_registry.all_aten_ops.get(target)
+     if op is None:
+       op = ops_registry.all_aten_ops.get(target.overloadpacket)
+     # If no lowering is registered for this op, raise a descriptive error
+     # instead of failing an assertion below.
+     if op is None:
+       print(target.name(), target.tags)
+       raise RuntimeError('No lowering found for', target.name())
+     assert op is not None, target
+     assert op.is_jax_function, op
+     return op.func(*args, **kwargs)
+
+   def run_node(self, n) -> Any:
+     res = super().run_node(n)
+     if DEBUG:
+       if n.op == 'call_function':
+         if hasattr(res, 'shape'):
+           print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
+     return res
+
+
+ from torch._decomp import get_decompositions
+ import torch._refs
+
+ _extra_decomp = get_decompositions([torch.ops.aten.unfold])
+
+
+ def _extract_states_from_exported_program(exported_model):
+   # NOTE call convention: (parameters, buffers, user_inputs)
+   param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
+   state_dict = copy.copy(exported_model.state_dict)
+   if (constants := getattr(exported_model, 'constants', None)) is not None:
+     state_dict.update(constants)
+   param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys)
+
+   if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
+     for name in exported_model.graph_signature.lifted_tensor_constants:
+       param_buffer_values.append(exported_model.tensor_constants[name])
+
+   return param_and_buffer_keys, param_buffer_values
+
+
+ def exported_program_to_jax(exported_program, export_raw: bool = False):
+   """Returns a pytree of jax arrays (the state), and
+
+   a callable (func) that is a jax function.
+
+   `func(state, input)` is how you call it.
+   """
+   if torch.__version__ >= '2.2':
+     # torch version 2.1 didn't expose this yet
+     exported_program = exported_program.run_decompositions()
+     exported_program = exported_program.run_decompositions(decompositions.EXTRA_DECOMP)
+   if DEBUG:
+     print(exported_program.graph_module.code)
+
+   names, states = _extract_states_from_exported_program(exported_program)
+
+   def _extract_args(args, kwargs):
+     flat_args, received_spec = pytree.tree_flatten(
+         (args, kwargs))  # type: ignore[possibly-undefined]
+     return flat_args
+
+   num_mutations = len(exported_program.graph_signature.buffers_to_mutate)
+
+   def func(states, inputs):
+     args = _extract_args(inputs, {})
+     res = JaxInterpreter(exported_program.graph_module).run(
+         *states,
+         *args,
+         enable_io_processing=False,
+     )
+     res = res[num_mutations:]
+     return res
+
+   if export_raw:
+     return names, states, func
+
+   states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
+   return states, func
+
+
+ def extract_avals(exported):
+   """Return JAX Abstract Value shapes for all input parameters of the exported
+   program. This supports dynamic batch dimensions, including with constraints.
+   """
+
+   def _to_aval(arg_meta, symbolic_shapes):
+     """Convert from a torch type to a jax abstract value for export tracing.
+     """
+     def _get_dim(d):
+       if isinstance(d, torch.SymInt):
+         return symbolic_shapes[str(d)]
+       return d
+
+     val = arg_meta['val']
+     is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
+     if is_scalar:
+       return jax.ShapeDtypeStruct([], type(arg_meta['val']))
+
+     tensor_meta = arg_meta['tensor_meta']
+     shape = [_get_dim(d) for d in tensor_meta.shape]
+     return jax.ShapeDtypeStruct(shape, tensor.t2j_dtype(tensor_meta.dtype))
+
+   def _get_inputs(exported):
+     """Return placeholders with input metadata"""
+     placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
+     input_placeholders = [
+         p
+         for p, s in zip(placeholders, exported.graph_signature.input_specs)
+         if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+     ]
+     return input_placeholders
+
+   def _build_symbolic_shapes(range_constraints):
+     """Convert torch SymInts to JAX symbolic shapes and store them in a map keyed
+     by the string name of the torch symbolic int.
+
+     TODO: There is probably a better way of storing a key for a symbolic int.
+     This value needs to be looked up again in `_to_aval` to figure out which
+     JAX symbolic to map to for a given torch tensor.
+     """
+     if len(range_constraints) == 0:
+       return None
+
+     def _build_symbolic_constraints(symbol_name, torch_constraint):
+       """Convert torch SymInt constraints to strings for JAX symbolic_shape.
+       Using sympy may be overkill here; currently PyTorch only uses ValueRanges,
+       which allow specifying the min and the max of a value, for example:
+       torch.export.Dim("a", min=5, max=10)
+       ==> ("a >= 5", "a <= 10",)
+       """
+       if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges) or torch_constraint.is_bool:
+         raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")
+
+       constraints = []
+       symbol = sympy.Symbol(symbol_name)
+       if torch_constraint.lower != 2:
+         constraints.append(symbol >= torch_constraint.lower)
+       from sympy.core.singleton import S
+       if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity:
+         constraints.append(symbol <= torch_constraint.upper)
+
+       return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)
+
+     def _build_symbolic_shape(sym, constraint, free_symbols):
+       """Returns a JAX symbolic shape for a given symbol and constraint.
+
+       There are two possible sympy `sym` inputs:
+       1. Symbol - (s0) These can have custom constraints.
+       2. Expr - (s0*2) These apply the expr to s0's constraints and cannot override them.
+
+       Currently, support is limited to operations with a symbol and an int,
+       per `torch/export/dynamic_shapes.py`:
+       "Only increasing linear operations with integer coefficients are supported."
+       """
+       symbol_name = str(sym)
+       constraints = _build_symbolic_constraints(symbol_name, constraint)
+       if sym.is_symbol:
+         symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
+       else:
+         assert len(sym.free_symbols) > 0
+         scope = free_symbols[str(list(sym.free_symbols)[0])].scope
+         symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope)
+       assert len(symbolic_shape) == 1
+       return symbolic_shape[0]
+
+     # Populate symbol variables before expressions; exprs need to use the same
+     # symbolic scope as the variable they operate on. Expressions can only be
+     # integer computations on symbol variables, so each symbol variable is OK to
+     # have its own scope.
+     symbolic_shapes = {}
+     symbol_variables = [(s, v) for s, v in range_constraints.items() if s.is_symbol]
+     symbol_exprs = [(s, v) for s, v in range_constraints.items() if not s.is_symbol]
+     for sym, constraint in symbol_variables + symbol_exprs:
+       symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
+       symbolic_shapes[str(sym)] = symbolic_shape
+     return symbolic_shapes
+
+   symbolic_shapes = _build_symbolic_shapes(exported.range_constraints)
+   args = _get_inputs(exported)
+
+   if DEBUG:
+     print('Inputs to aval:', args, '--------')
+     print('Symbolic shapes:', symbolic_shapes)
+     for arg in args:
+       print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes))
+
+   return [_to_aval(arg.meta, symbolic_shapes) for arg in args]
+
+
+ def exported_program_to_stablehlo(exported_program):
+   """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo.
+
+   Convert a program exported via torch.export to StableHLO.
+
+   This supports dynamic dimension sizes and generates explicit checks for
+   dynamo guards in the IR using shape_assertion custom_call ops.
+   """
+   weights, func = exported_program_to_jax(exported_program)
+   jax_avals = extract_avals(exported_program)
+   jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
+   return weights, jax_export
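
A hedged end-to-end sketch of the two entry points above. The toy module, the sample shapes, and the `mlir_module()` accessor on the returned `jax.export.Exported` object are assumptions for illustration, not part of the file.

```python
import jax.numpy as jnp
import torch
from torchax import export as torchax_export

class SinPlusOne(torch.nn.Module):
  def forward(self, x):
    return torch.sin(x) + 1.0

sample_args = (torch.randn(4, 8),)
exported = torch.export.export(SinPlusOne(), sample_args)

# Path 1: a jax callable plus its state pytree; call it as func(state, inputs).
states, func = torchax_export.exported_program_to_jax(exported)
out = func(states, (jnp.ones((4, 8), dtype=jnp.float32),))

# Path 2: lower all the way to StableHLO through jax.export.
weights, jax_exported = torchax_export.exported_program_to_stablehlo(exported)
print(jax_exported.mlir_module())  # StableHLO text (assumed accessor)
```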
torchax/interop.py ADDED
@@ -0,0 +1,209 @@
+ import copy
+ import functools
+ import torch
+ from torch.nn.utils import stateless as torch_stateless
+ import jax
+ import jax.numpy as jnp
+ from jax import tree_util as pytree
+ from jax.experimental.shard_map import shard_map
+ from torchax import tensor
+ import torchax
+
+ from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable
+
+
+ def extract_all_buffers(m: torch.nn.Module):
+   buffers = {}
+   params = {}
+   def extract_one(module, prefix):
+     for k in dir(module):
+       try:
+         v = getattr(module, k)
+       except:
+         continue
+       qual_name = prefix + k
+       if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
+         params[qual_name] = v
+       elif isinstance(v, torch.Tensor):
+         buffers[qual_name] = v
+     for name, child in module.named_children():
+       extract_one(child, prefix + name + '.')
+   extract_one(m, '')
+   return params, buffers
+
+
+ def set_all_buffers(m, params, buffers):
+   def set_one(module, prefix):
+     for k in dir(module):
+       qual_name = prefix + k
+       if (potential_v := buffers.get(qual_name)) is not None:
+         setattr(module, k, potential_v)
+       elif (potential_v := params.get(qual_name)) is not None:
+         print(k, potential_v)
+         setattr(module, k, torch.nn.Parameter(potential_v))
+     for name, child in module.named_children():
+       set_one(child, prefix + name + '.')
+
+   set_one(m, '')
+
+
+ class JittableModule(torch.nn.Module):
+
+   def __init__(self, m: torch.nn.Module, extra_jit_args={}):
+     super().__init__()
+     self.params, self.buffers = extract_all_buffers(m)
+     self._model = m
+     self._jitted = {}
+
+     self._extra_jit_args = extra_jit_args
+
+
+   def __call__(self, *args, **kwargs):
+     return self.forward(*args, **kwargs)
+
+
+   def functional_call(
+       self, method_name, params, buffers, *args, **kwargs):
+     kwargs = kwargs or {}
+     params_copy = copy.copy(params)
+     params_copy.update(buffers)
+     with torch_stateless._reparametrize_module(self._model, params_copy):
+       res = getattr(self._model, method_name)(*args, **kwargs)
+     return res
+
+
+   def forward(self, *args, **kwargs):
+     if 'forward' not in self._jitted:
+       jitted = jax_jit(
+           functools.partial(self.functional_call, 'forward'),
+           kwargs_for_jax_jit=self._extra_jit_args,
+       )
+       def jitted_forward(*args, **kwargs):
+         return jitted(self.params, self.buffers, *args, **kwargs)
+       self._jitted['forward'] = jitted_forward
+     return self._jitted['forward'](*args, **kwargs)
+
+   def __getattr__(self, key):
+     if key == '_model':
+       return super().__getattr__(key)
+     if key in self._jitted:
+       return self._jitted[key]
+     return getattr(self._model, key)
+
+   def make_jitted(self, key):
+     jitted = jax_jit(
+         functools.partial(self.functional_call, key),
+         kwargs_for_jax_jit=self._extra_jit_args)
+     def call(*args, **kwargs):
+       return jitted(self.params, self.buffers, *args, **kwargs)
+     self._jitted[key] = call
+
+
+
+
+
+ class CompileMixin:
+
+   def functional_call(
+       self, method, params, buffers, *args, **kwargs):
+     kwargs = kwargs or {}
+     params_copy = copy.copy(params)
+     params_copy.update(buffers)
+     with torch_stateless._reparametrize_module(self, params_copy):
+       res = method(*args, **kwargs)
+     return res
+
+   def jit(self, method):
+     jitted = jax_jit(functools.partial(self.functional_call, method))
+     def call(*args, **kwargs):
+       return jitted(dict(self.named_parameters()), dict(self.named_buffers()), *args, **kwargs)
+     return call
+
+
+ def compile_nn_module(m: torch.nn.Module, methods=None):
+   if methods is None:
+     methods = ['forward']
+
+   new_parent = type(
+       m.__class__.__name__ + '_with_CompileMixin',
+       (CompileMixin, m.__class__),
+       {})
+   m.__class__ = new_parent
+
+
+ def _torch_view(t: JaxValue) -> TorchValue:
+   # t is an object from jax land
+   # view it as-if it's a torch land object
+   if isinstance(t, jax.Array):
+     # TODO
+     return tensor.Tensor(t, torchax.default_env())
+   if isinstance(t, type(jnp.int32)):
+     return tensor.t2j_type(t)
+   if callable(t):  # t is a JaxCallable
+     return functools.partial(call_jax, t)
+   # regular types are not changed
+   return t
+
+ torch_view = functools.partial(pytree.tree_map, _torch_view)
+
+
+ def _jax_view(t: TorchValue) -> JaxValue:
+   # t is an object from torch land
+   # view it as-if it's a jax land object
+   if isinstance(t, torch.Tensor):
+     assert isinstance(t, tensor.Tensor), type(t)
+     return t.jax()
+   if isinstance(t, type(torch.int32)):
+     return tensor.t2j_dtype(t)
+
+   # torch.nn.Module needs special handling
+   if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
+     return functools.partial(call_torch, t)
+   # regular types are not changed
+   return t
+
+ jax_view = functools.partial(pytree.tree_map, _jax_view)
+
+
+ def call_jax(jax_func: JaxCallable,
+              *args: TorchValue,
+              **kwargs: TorchValue) -> TorchValue:
+   args, kwargs = jax_view((args, kwargs))
+   res: JaxValue = jax_func(*args, **kwargs)
+   return torch_view(res)
+
+
+ def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue:
+   args, kwargs = torch_view((args, kwargs))
+   with torchax.default_env():
+     res: TorchValue = torch_func(*args, **kwargs)
+   return jax_view(res)
+
+
+ fori_loop = torch_view(jax.lax.fori_loop)
+
+
+ def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
+   kwargs_for_jax = kwargs_for_jax or {}
+   jax_func = jax_view(torch_function)
+   jitted = jax_jit_func(jax_func, **kwargs_for_jax)
+   return torch_view(jitted)
+
+
+ def jax_jit(torch_function, kwargs_for_jax_jit=None):
+   return wrap_jax_jit(torch_function, jax_jit_func=jax.jit,
+                       kwargs_for_jax=kwargs_for_jax_jit)
+
+
+ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
+   return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
+                       kwargs_for_jax=kwargs_for_jax_shard_map)
+
+
+ def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
+   return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad,
+                       kwargs_for_jax=kwargs_for_value_and_grad)
+
+ def gradient_checkpoint(torch_function, kwargs=None):
+   return wrap_jax_jit(torch_function, jax_jit_func=jax.checkpoint,
+                       kwargs_for_jax=kwargs)
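
A hedged sketch of the interop helpers above. The shapes and the `Linear` model are illustrative, and it assumes `torchax.default_env()` yields jax-backed tensors as elsewhere in the package.

```python
import jax.numpy as jnp
import torch
import torchax
from torchax import interop

env = torchax.default_env()
with env:
  x = torch.randn(4, 8)        # jax-backed torchax tensors
  w = torch.randn(8, 2)
  model = torch.nn.Linear(8, 2)

# call_jax: run a jax function on torch-land values.
y = interop.call_jax(jnp.matmul, x, w)

# jax_jit: trace a torch function through jax.jit and get a torch-callable back.
jitted_matmul = interop.jax_jit(lambda a, b: a @ b)
y2 = jitted_matmul(x, w)

# JittableModule: capture params/buffers and jit the module's forward.
jit_model = interop.JittableModule(model)
out = jit_model(x)
```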
@@ -0,0 +1,10 @@
1
+ def all_aten_jax_ops():
2
+ # to load the ops
3
+ import torchax.ops.jaten # type: ignore
4
+ import torchax.ops.ops_registry # type: ignore
5
+
6
+ return {
7
+ key: val.func
8
+ for key, val in torchax.ops.ops_registry.all_aten_ops.items()
9
+ if val.is_jax_function
10
+ }
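
A hedged sketch of calling the helper defined above to inspect the aten-to-jax lowering table; the printed entries are whatever the registry happens to contain.

```python
ops = all_aten_jax_ops()
print(f"{len(ops)} aten ops have jax lowerings")
some_op, some_fn = next(iter(ops.items()))  # keys are aten op objects
print(some_op, "->", some_fn)
```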