torchax-0.0.4-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of torchax might be problematic.
- torchax/CONTRIBUTING.md +38 -0
- torchax/__init__.py +124 -0
- torchax/config.py +19 -0
- torchax/decompositions.py +308 -0
- torchax/device_module.py +20 -0
- torchax/distributed.py +246 -0
- torchax/environment.py +2 -0
- torchax/export.py +236 -0
- torchax/interop.py +209 -0
- torchax/ops/__init__.py +10 -0
- torchax/ops/jaten.py +5212 -0
- torchax/ops/jax_reimplement.py +169 -0
- torchax/ops/jc10d.py +51 -0
- torchax/ops/jlibrary.py +73 -0
- torchax/ops/jtorch.py +427 -0
- torchax/ops/jtorchvision_nms.py +245 -0
- torchax/ops/mappings.py +97 -0
- torchax/ops/op_base.py +104 -0
- torchax/ops/ops_registry.py +50 -0
- torchax/tensor.py +557 -0
- torchax/tf_integration.py +119 -0
- torchax/train.py +120 -0
- torchax/types.py +12 -0
- torchax-0.0.4.dist-info/METADATA +341 -0
- torchax-0.0.4.dist-info/RECORD +27 -0
- torchax-0.0.4.dist-info/WHEEL +4 -0
- torchax-0.0.4.dist-info/licenses/LICENSE +28 -0
torchax/distributed.py
ADDED
@@ -0,0 +1,246 @@
"""`torch.distributed` backend implemented with JAX collective ops.

EXPERIMENTAL: This module is still highly experimental, and it may be removed
before any stable release.

Note: JAX collective ops require that axis names be defined in `pmap` or
`shmap`. The distributed backend only supports one axis, named `torch_dist`.
This name is defined by our mirror implementation of `spawn`.
"""

import datetime
import functools
import logging
import os
from typing import List, Optional, Union

import jax
import numpy as np
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives
from torch._C._distributed_c10d import ProcessGroup  # type: ignore
import torch.distributed
import torchax
from jax.sharding import NamedSharding
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental import mesh_utils
import torch.utils._pytree as torch_pytree
from torchax import interop


class ProcessGroupJax(ProcessGroup):
  """Distributed backend implemented with JAX."""

  def __init__(self, prefix_store, rank, size, timeout):
    super().__init__(rank, size)
    self._group_name = None

  def getBackendName(self):
    return "jax"

  # TODO(wcromar): why doesn't default group name setter work?
  # https://github.com/pytorch/pytorch/blob/7b1988f9222f3dec5cc2012afce84218199748ae/torch/csrc/distributed/c10d/ProcessGroup.cpp#L148-L152
  def _set_group_name(self, name: str) -> None:
    self._group_name = name

  @property
  def group_name(self):
    assert self._group_name
    return self._group_name

  @staticmethod
  def _work(
      tensors: Union[torch.Tensor, List[torch.Tensor], List[List[torch.Tensor]]],
  ) -> dist.Work:
    fut = torch.futures.Future()
    fut.set_result(tensors)
    return torch._C._distributed_c10d._create_work_from_future(fut)

  def _allgather_base(
      self,
      output: torch.Tensor,
      input: torch.Tensor,
      opts=...,
  ) -> dist.Work:
    assert isinstance(input, torchax.tensor.Tensor)
    assert isinstance(output, torchax.tensor.Tensor)
    torch.distributed._functional_collectives.all_gather_tensor_inplace(
        output, input, group=self
    )
    return self._work(output)

  def allreduce(
      self,
      tensors: List[torch.Tensor],
      opts: dist.AllreduceOptions = ...,
  ) -> dist.Work:
    assert len(tensors) == 1
    assert isinstance(tensors[0], torchax.tensor.Tensor)
    torch.distributed._functional_collectives.all_reduce_inplace(
        tensors[0],
        torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
            opts.reduceOp.op
        ],
        self,
    )

    return self._work(tensors)

  def broadcast(
      self,
      tensors: List[torch.Tensor],
      opts: dist.BroadcastOptions = ...,
  ) -> dist.Work:
    assert len(tensors) == 1
    assert isinstance(tensors[0], torchax.tensor.Tensor)
    tensors[0].copy_(
        torch.distributed._functional_collectives.broadcast(
            tensors[0], opts.rootRank, group=self
        )
    )

    return self._work(tensors)


dist.Backend.register_backend("jax", ProcessGroupJax)


def jax_rendezvous_handler(
    url: str, timeout: datetime.timedelta = ..., **kwargs
):
  """Initialize distributed store with JAX process IDs.

  Requires `$MASTER_ADDR` and `$MASTER_PORT`.
  """
  # TODO(wcromar): jax.distributed.initialize(...) for multiprocess on GPU
  # TODO(wcromar): Can we use the XLA coordinator as a Store? This isn't part
  # of their public Python API
  master_ip = os.environ["MASTER_ADDR"]
  master_port = int(os.environ["MASTER_PORT"])
  # TODO(wcromar): Use `torchrun`'s store if available
  store = dist.TCPStore(
      master_ip,
      master_port,
      jax.process_count(),
      is_master=jax.process_index() == 0,
  )

  yield (store, jax.process_index(), jax.process_count())


dist.register_rendezvous_handler("jax", jax_rendezvous_handler)


def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
  """Wrap `f` in a JAX `pmap` with the axis name `torch_dist` defined.
  `f` is expected to take the replica index as a positional argument, similar
  to `torch.multiprocessing.spawn`.
  Note: `spawn` does not actually create parallel processes.
  """
  env = env or torchax.default_env()

  def jax_wrapper(index, jax_args):
    index, args = env.j2t_iso([index, jax_args])
    torch_outputs = f(index, *args)
    return env.t2j_iso(torch_outputs)

  jax_outputs = jax.pmap(jax_wrapper, axis_name="torch_dist")(
      np.arange(jax.device_count()), env.t2j_iso(args)
  )
  return env.j2t_iso(jax_outputs)


class DistributedDataParallel(torch.nn.Module):
  """Re-implementation of DistributedDataParallel using JAX SPMD.

  Splits inputs along the batch dimension (assumed to be 0) across all devices
  in the JAX runtime, including remote devices. Each process should load a
  distinct shard of the input data using e.g. DistributedSampler. Each
  process' shard is then further split among the addressable devices (e.g.
  local TPU chips) by `shard_input`.

  Note: since parameters are replicated across addressable devices, inputs
  must also be SPMD sharded using `shard_input` or `replicate_input`.

  Example usage:

  ```
  jax_model = torchax.distributed.DistributedDataParallel(create_model())
  for data in dataloader:
    jax_data = jax_model.shard_input(data)
    jax_output = jax_model(jax_data)
  ```
  """
  def __init__(
      self,
      module: torch.nn.Module,
      env: Optional[torchax.tensor.Environment] = None,
      **kwargs,
  ):
    if kwargs:
      logging.warning(f"Unsupported kwargs {kwargs}")

    super().__init__()
    self._env = env or torchax.default_env()
    self._mesh = Mesh(
        mesh_utils.create_device_mesh((jax.device_count(),)),
        axis_names=("batch",),
    )
    replicated_state = torch_pytree.tree_map_only(
        torch.Tensor,
        lambda t: self._env.j2t_iso(
            jax.device_put(
                self._env.to_xla(t)._elem, NamedSharding(self._mesh, P())
            )
        ),
        module.state_dict(),
    )
    # TODO: broadcast
    module.load_state_dict(replicated_state, assign=True)
    self._module = module

  def shard_input(self, inp):
    per_process_batch_size = inp.shape[0]  # assumes batch dim is 0
    per_replica_batch_size = per_process_batch_size // jax.local_device_count()
    per_replica_batches = torch.chunk(inp, jax.local_device_count())
    global_batch_size = per_replica_batch_size * jax.device_count()
    global_batch_shape = (global_batch_size,) + inp.shape[1:]

    sharding = NamedSharding(self._mesh, P("batch"))
    return self._env.j2t_iso(jax.make_array_from_single_device_arrays(
        global_batch_shape,
        NamedSharding(self._mesh, P("batch")),
        arrays=[
            jax.device_put(self._env.to_xla(batch)._elem, device)
            for batch, device in zip(
                per_replica_batches, sharding.addressable_devices
            )
        ],
    ))

  def replicate_input(self, inp):
    return self._env.j2t_iso(
        jax.device_put(inp._elem, NamedSharding(self._mesh, P()))
    )

  def jit_step(self, func):
    @functools.partial(interop.jax_jit,
                       kwargs_for_jax_jit={'donate_argnums': 0})
    def _jit_fn(states, args):
      self.load_state_dict(states)
      outputs = func(*args)
      return self.state_dict(), outputs

    @functools.wraps(func)
    def inner(*args):
      jax_states = self.state_dict()
      new_states, outputs = _jit_fn(jax_states, args)
      self.load_state_dict(new_states)
      return outputs

    return inner

  def forward(self, *args):
    with self._env:
      return self._module(*args)
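
To illustrate how the pieces above fit together, here is a minimal, hypothetical usage sketch (not part of the package). It assumes `MASTER_ADDR`/`MASTER_PORT` are set, that the `jax://` init method resolves to the rendezvous handler registered above, and that each positional argument passed to `spawn` carries a leading axis of size `jax.device_count()` for `pmap` to map over. Whether this runs end to end depends on the runtime environment.

# Hypothetical sketch only; everything beyond the torchax/torch APIs shown
# above is illustrative.
import os

import jax
import torch
import torch.distributed as dist
import torchax
import torchax.distributed


def _replica_fn(index, x):
  # Runs under the `torch_dist` pmap axis set up by `spawn`, so collectives
  # dispatched through ProcessGroupJax can execute.
  dist.all_reduce(x, op=dist.ReduceOp.SUM)
  return x


def main():
  os.environ.setdefault("MASTER_ADDR", "localhost")
  os.environ.setdefault("MASTER_PORT", "12355")
  # "jax://" is assumed to route to jax_rendezvous_handler registered above.
  dist.init_process_group(backend="jax", init_method="jax://")
  with torchax.default_env():
    # Leading axis of size jax.device_count() is mapped by pmap inside spawn.
    x = torch.ones(jax.device_count(), 4)
    out = torchax.distributed.spawn(_replica_fn, (x,))
    print(out.shape)
  dist.destroy_process_group()


if __name__ == "__main__":
  main()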
torchax/environment.py
ADDED
torchax/export.py
ADDED
@@ -0,0 +1,236 @@
# pylint: disable
"""Utilities for exporting a torch program to jax/stablehlo."""
import copy
from typing import Any, Dict, Tuple
import torch
from torch.utils import _pytree as pytree
from torchax import tensor
from torchax.ops import ops_registry
from torchax import decompositions
import jax
import jax.export
import sympy


DEBUG = False


class JaxInterpreter(torch.fx.Interpreter):
  """Experimental."""

  def __init__(self, graph_module):
    super().__init__(graph_module)
    import torchax.ops.jaten
    import torchax.ops.jtorch

  def call_function(self, target, args: Tuple, kwargs: Dict) -> Any:
    if not isinstance(target,
                      (torch._ops.OpOverloadPacket, torch._ops.OpOverload)):
      return super().call_function(target, args, kwargs)

    if DEBUG:
      print('Running ', target.name(), '--------')

    op = ops_registry.all_aten_ops.get(target)
    if op is None:
      op = ops_registry.all_aten_ops.get(target.overloadpacket)
    assert op is not None, target
    assert op.is_jax_function, op
    if op is None:
      op = ops_registry.all_aten_ops.get(target.overloadpacket)
    if op is None:
      print(target.name(), target.tags)
      raise RuntimeError('No lowering found for', target.name())
    return op.func(*args, **kwargs)

  def run_node(self, n) -> Any:
    res = super().run_node(n)
    if DEBUG:
      if n.op == 'call_function':
        if hasattr(res, 'shape'):
          print('Meta:', n.meta.get('val').shape, 'REAL: ', res.shape)
    return res


from torch._decomp import get_decompositions
import torch._refs

_extra_decomp = get_decompositions([torch.ops.aten.unfold])


def _extract_states_from_exported_program(exported_model):
  # NOTE call convention: (parameters, buffers, user_inputs)
  param_and_buffer_keys = exported_model.graph_signature.parameters + exported_model.graph_signature.buffers
  state_dict = copy.copy(exported_model.state_dict)
  if (constants := getattr(exported_model, 'constants', None)) is not None:
    state_dict.update(constants)
  param_buffer_values = list(state_dict[key] for key in param_and_buffer_keys)

  if hasattr(exported_model.graph_signature, "lifted_tensor_constants"):
    for name in exported_model.graph_signature.lifted_tensor_constants:
      param_buffer_values.append(exported_model.tensor_constants[name])

  return param_and_buffer_keys, param_buffer_values


def exported_program_to_jax(exported_program, export_raw: bool = False):
  """Returns a pytree of jax arrays (state), and

  a callable (func) that is a jax function.

  func(state, input) would be how you call it.
  """
  if torch.__version__ >= '2.2':
    # torch version 2.1 didn't expose this yet
    exported_program = exported_program.run_decompositions()
    exported_program = exported_program.run_decompositions(decompositions.EXTRA_DECOMP)
  if DEBUG:
    print(exported_program.graph_module.code)

  names, states = _extract_states_from_exported_program(exported_program)

  def _extract_args(args, kwargs):
    flat_args, received_spec = pytree.tree_flatten(
        (args, kwargs))  # type: ignore[possibly-undefined]
    return flat_args

  num_mutations = len(exported_program.graph_signature.buffers_to_mutate)

  def func(states, inputs):
    args = _extract_args(inputs, {})
    res = JaxInterpreter(exported_program.graph_module).run(
        *states,
        *args,
        enable_io_processing=False,
    )
    res = res[num_mutations:]
    return res

  if export_raw:
    return names, states, func

  states = pytree.tree_map_only(torch.Tensor, tensor.t2j, states)
  return states, func


def extract_avals(exported):
  """Return JAX Abstract Value shapes for all input parameters of the exported
  program. This supports dynamic batch dimensions, including with constraints.
  """

  def _to_aval(arg_meta, symbolic_shapes):
    """Convert from torch type to jax abstract value for export tracing
    """
    def _get_dim(d):
      if isinstance(d, torch.SymInt):
        return symbolic_shapes[str(d)]
      return d

    val = arg_meta['val']
    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(val, bool)
    if is_scalar:
      return jax.ShapeDtypeStruct([], type(arg_meta['val']))

    tensor_meta = arg_meta['tensor_meta']
    shape = [_get_dim(d) for d in tensor_meta.shape]
    return jax.ShapeDtypeStruct(shape, tensor.t2j_dtype(tensor_meta.dtype))

  def _get_inputs(exported):
    """Return placeholders with input metadata"""
    placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
    input_placeholders = [
        p
        for p, s in zip(placeholders, exported.graph_signature.input_specs)
        if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
    ]
    return input_placeholders

  def _build_symbolic_shapes(range_constraints):
    """Convert torch SymInt to JAX symbolic_shape and store in a map using the
    string name of the torch symbolic int.

    TODO: There is probably a better way of storing a key for a symbolic int.
    This value needs to be looked up again in `_to_aval` to figure out which
    JAX symbolic to map to for a given torch tensor.
    """
    if len(range_constraints) == 0:
      return None

    def _build_symbolic_constraints(symbol_name, torch_constraint):
      """Convert torch SymInt constraints to strings for JAX symbolic_shape.
      Using sympy may be overkill here; currently PyTorch only uses ValueRanges,
      which allow specifying the min and the max of a value, for example:
      torch.export.Dim("a", min=5, max=10)
      ==> ("a >= 5", "a <= 10",)
      """
      if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.ValueRanges) or torch_constraint.is_bool:
        raise TypeError(f"No symbolic constraint handler for: {torch_constraint}")

      constraints = []
      symbol = sympy.Symbol(symbol_name)
      if torch_constraint.lower != 2:
        constraints.append(symbol >= torch_constraint.lower)
      from sympy.core.singleton import S
      if not torch_constraint.upper.is_infinite and torch_constraint.upper is not S.IntInfinity:
        constraints.append(symbol <= torch_constraint.upper)

      return tuple(sympy.pretty(c, use_unicode=False) for c in constraints)

    def _build_symbolic_shape(sym, constraint, free_symbols):
      """Returns a JAX symbolic shape for a given symbol and constraint

      There are two possible sympy `sym` inputs:
      1. Symbol - (s0) These can have custom constraints.
      2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.

      Currently support is limited to operations with a symbol and an int,
      in `torch/export/dynamic_shapes.py`:
      "Only increasing linear operations with integer coefficients are supported."
      """
      symbol_name = str(sym)
      constraints = _build_symbolic_constraints(symbol_name, constraint)
      if sym.is_symbol:
        symbolic_shape = jax.export.symbolic_shape(symbol_name, constraints=constraints)
      else:
        assert len(sym.free_symbols) > 0
        scope = free_symbols[str(list(sym.free_symbols)[0])].scope
        symbolic_shape = jax.export.symbolic_shape(symbol_name, scope=scope)
      assert len(symbolic_shape) == 1
      return symbolic_shape[0]

    # Populate symbol variables before expressions; exprs need to use the same
    # symbolic scope as the variable they operate on. Expressions can only be
    # integer computations on symbol variables, so each symbol variable is OK to
    # have its own scope.
    symbolic_shapes = {}
    symbol_variables = [(s, v) for s, v in range_constraints.items() if s.is_symbol]
    symbol_exprs = [(s, v) for s, v in range_constraints.items() if not s.is_symbol]
    for sym, constraint in symbol_variables + symbol_exprs:
      symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
      symbolic_shapes[str(sym)] = symbolic_shape
    return symbolic_shapes

  symbolic_shapes = _build_symbolic_shapes(exported.range_constraints)
  args = _get_inputs(exported)

  if DEBUG:
    print('Inputs to aval:', args, '--------')
    print('Symbolic shapes:', symbolic_shapes)
    for arg in args:
      print('Meta2Aval', arg.meta, '--> ', _to_aval(arg.meta, symbolic_shapes))

  return [_to_aval(arg.meta, symbolic_shapes) for arg in args]


def exported_program_to_stablehlo(exported_program):
  """Replacement for torch_xla.stablehlo.exported_program_to_stablehlo

  Convert a program exported via torch.export to StableHLO.

  This supports dynamic dimension sizes and generates explicit checks for
  dynamo guards in the IR using shape_assertion custom_call ops.
  """
  weights, func = exported_program_to_jax(exported_program)
  jax_avals = extract_avals(exported_program)
  jax_export = jax.export.export(jax.jit(func))(weights, (jax_avals,))
  return weights, jax_export
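
To show how these helpers are intended to be called, here is a minimal, hypothetical sketch (not from the package): `TinyModel` and the input shapes are made up, and the final `mlir_module()` call assumes the `jax.export.Exported` object returned above exposes its StableHLO text that way.

# Hypothetical sketch; TinyModel and the shapes are illustrative.
import jax.numpy as jnp
import torch
from torchax import export as torchax_export


class TinyModel(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear = torch.nn.Linear(8, 4)

  def forward(self, x):
    return torch.relu(self.linear(x))


exported = torch.export.export(TinyModel(), (torch.randn(2, 8),))

# Option 1: run the lowered graph directly as a JAX function.
states, jax_func = torchax_export.exported_program_to_jax(exported)
outputs = jax_func(states, (jnp.ones((2, 8), dtype=jnp.float32),))

# Option 2: lower all the way to StableHLO via jax.export.
weights, jax_exported = torchax_export.exported_program_to_stablehlo(exported)
print(jax_exported.mlir_module())  # assumed accessor for the StableHLO text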
torchax/interop.py
ADDED
@@ -0,0 +1,209 @@
import copy
import functools
import torch
from torch.nn.utils import stateless as torch_stateless
import jax
import jax.numpy as jnp
from jax import tree_util as pytree
from jax.experimental.shard_map import shard_map
from torchax import tensor
import torchax

from torchax.types import JaxValue, TorchValue, JaxCallable, TorchCallable


def extract_all_buffers(m: torch.nn.Module):
  buffers = {}
  params = {}
  def extract_one(module, prefix):
    for k in dir(module):
      try:
        v = getattr(module, k)
      except:
        continue
      qual_name = prefix + k
      if isinstance(v, torch.nn.parameter.Parameter) and v.requires_grad:
        params[qual_name] = v
      elif isinstance(v, torch.Tensor):
        buffers[qual_name] = v
    for name, child in module.named_children():
      extract_one(child, prefix + name + '.')
  extract_one(m, '')
  return params, buffers


def set_all_buffers(m, params, buffers):
  def set_one(module, prefix):
    for k in dir(module):
      qual_name = prefix + k
      if (potential_v := buffers.get(qual_name)) is not None:
        setattr(module, k, potential_v)
      elif (potential_v := params.get(qual_name)) is not None:
        print(k, potential_v)
        setattr(module, k, torch.nn.Parameter(potential_v))
    for name, child in module.named_children():
      set_one(child, prefix + name + '.')

  set_one(m, '')


class JittableModule(torch.nn.Module):

  def __init__(self, m: torch.nn.Module, extra_jit_args={}):
    super().__init__()
    self.params, self.buffers = extract_all_buffers(m)
    self._model = m
    self._jitted = {}

    self._extra_jit_args = extra_jit_args

  def __call__(self, *args, **kwargs):
    return self.forward(*args, **kwargs)

  def functional_call(
      self, method_name, params, buffers, *args, **kwargs):
    kwargs = kwargs or {}
    params_copy = copy.copy(params)
    params_copy.update(buffers)
    with torch_stateless._reparametrize_module(self._model, params_copy):
      res = getattr(self._model, method_name)(*args, **kwargs)
    return res

  def forward(self, *args, **kwargs):
    if 'forward' not in self._jitted:
      jitted = jax_jit(
          functools.partial(self.functional_call, 'forward'),
          kwargs_for_jax_jit=self._extra_jit_args,
      )
      def jitted_forward(*args, **kwargs):
        return jitted(self.params, self.buffers, *args, **kwargs)
      self._jitted['forward'] = jitted_forward
    return self._jitted['forward'](*args, **kwargs)

  def __getattr__(self, key):
    if key == '_model':
      return super().__getattr__(key)
    if key in self._jitted:
      return self._jitted[key]
    return getattr(self._model, key)

  def make_jitted(self, key):
    jitted = jax_jit(
        functools.partial(self.functional_call, key),
        kwargs_for_jax_jit=self._extra_jit_args)
    def call(*args, **kwargs):
      return jitted(self.params, self.buffers, *args, **kwargs)
    self._jitted[key] = call


class CompileMixin:

  def functional_call(
      self, method, params, buffers, *args, **kwargs):
    kwargs = kwargs or {}
    params_copy = copy.copy(params)
    params_copy.update(buffers)
    with torch_stateless._reparametrize_module(self, params_copy):
      res = method(*args, **kwargs)
    return res

  def jit(self, method):
    jitted = jax_jit(functools.partial(self.functional_call, method))
    def call(*args, **kwargs):
      return jitted(self.named_parameters(), self.named_buffers(), *args, **kwargs)
    return call


def compile_nn_module(m: torch.nn.Module, methods=None):
  if methods is None:
    methods = ['forward']

  new_parent = type(
      m.__class__.__name__ + '_with_CompileMixin',
      (CompileMixin, m.__class__),
  )
  m.__class__ = new_parent


def _torch_view(t: JaxValue) -> TorchValue:
  # t is an object from jax land
  # view it as-if it's a torch land object
  if isinstance(t, jax.Array):
    # TODO
    return tensor.Tensor(t, torchax.default_env())
  if isinstance(t, type(jnp.int32)):
    return tensor.t2j_type(t)
  if callable(t):  # t is a JaxCallable
    return functools.partial(call_jax, t)
  # regular types are not changed
  return t

torch_view = functools.partial(pytree.tree_map, _torch_view)


def _jax_view(t: TorchValue) -> JaxValue:
  # t is an object from torch land
  # view it as-if it's a jax land object
  if isinstance(t, torch.Tensor):
    assert isinstance(t, tensor.Tensor), type(t)
    return t.jax()
  if isinstance(t, type(torch.int32)):
    return tensor.t2j_dtype(t)

  # torch.nn.Module needs special handling
  if not isinstance(t, torch.nn.Module) and callable(t):  # t is a TorchCallable
    return functools.partial(call_torch, t)
  # regular types are not changed
  return t

jax_view = functools.partial(pytree.tree_map, _jax_view)


def call_jax(jax_func: JaxCallable,
             *args: TorchValue,
             **kwargs: TorchValue) -> TorchValue:
  args, kwargs = jax_view((args, kwargs))
  res: JaxValue = jax_func(*args, **kwargs)
  return torch_view(res)


def call_torch(torch_func: TorchCallable, *args: JaxValue, **kwargs: JaxValue) -> JaxValue:
  args, kwargs = torch_view((args, kwargs))
  with torchax.default_env():
    res: TorchValue = torch_func(*args, **kwargs)
  return jax_view(res)


fori_loop = torch_view(jax.lax.fori_loop)


def wrap_jax_jit(torch_function, jax_jit_func=jax.jit, kwargs_for_jax=None):
  kwargs_for_jax = kwargs_for_jax or {}
  jax_func = jax_view(torch_function)
  jitted = jax_jit_func(jax_func, **kwargs_for_jax)
  return torch_view(jitted)


def jax_jit(torch_function, kwargs_for_jax_jit=None):
  return wrap_jax_jit(torch_function, jax_jit_func=jax.jit,
                      kwargs_for_jax=kwargs_for_jax_jit)


def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
  return wrap_jax_jit(torch_function, jax_jit_func=shard_map,
                      kwargs_for_jax=kwargs_for_jax_shard_map)


def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
  return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad,
                      kwargs_for_jax=kwargs_for_value_and_grad)

def gradient_checkpoint(torch_function, kwargs=None):
  return wrap_jax_jit(torch_function, jax_jit_func=jax.checkpoint,
                      kwargs_for_jax=kwargs)
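
A short, hypothetical sketch (not from the package) of the two main entry points above: wrapping a module with `JittableModule` so its `forward` runs through `jax.jit`, and using `call_jax` to apply a plain JAX function to torch tensors. It assumes the model and tensors are created inside the default torchax environment so that they are `torchax.tensor.Tensor` instances.

# Hypothetical sketch; the model and shapes are illustrative.
import jax.numpy as jnp
import torch
import torchax
from torchax import interop

env = torchax.default_env()

with env:
  model = torch.nn.Linear(8, 4)           # parameters become torchax tensors
  jitted = interop.JittableModule(model)

  x = torch.randn(2, 8)                   # also a torchax tensor inside env
  y = jitted(x)                           # first call builds and caches the jitted forward

  # call_jax views torch tensors as jax arrays, invokes the jax function,
  # then views the result back as a torch tensor.
  z = interop.call_jax(jnp.tanh, y)
  print(z.shape)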
torchax/ops/__init__.py
ADDED