torchax 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +57 -19
- torchax/amp.py +333 -0
- torchax/config.py +19 -12
- torchax/decompositions.py +663 -195
- torchax/device_module.py +7 -1
- torchax/distributed.py +55 -60
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +275 -141
- torchax/mesh_util.py +211 -0
- torchax/ops/jaten.py +1718 -1294
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +219 -78
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +417 -275
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/METADATA +111 -145
- torchax-0.0.5.dist-info/RECORD +32 -0
- torchax/environment.py +0 -2
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.5.dist-info}/licenses/LICENSE +0 -0
torchax/device_module.py
CHANGED
@@ -1,20 +1,26 @@
 def _is_in_bad_fork():
   return False
 
+
 def manual_seed_all(seed):
   pass
 
+
 def device_count():
   return 1
 
+
 def get_rng_state():
   return []
 
+
 def set_rng_state(new_state, device):
   pass
 
+
 def is_available():
   return True
 
+
 def current_device():
-  return 0
+  return 0
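The functions above form the minimal device-module surface PyTorch expects from a custom backend. A hedged sketch of how such a module is typically hooked up follows; the wiring below is an illustration, not necessarily how torchax/__init__.py does it.

# Hedged sketch: registering a device module like the one above with PyTorch's
# custom-backend hooks. Assumption: torchax has not already registered it on
# import, in which case torch._register_device_module raises.
import torch
import torchax.device_module as device_module

torch._register_device_module("jax", device_module)

print(torch.jax.is_available())   # True, from is_available() above
print(torch.jax.device_count())   # 1, from device_count() above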
torchax/distributed.py
CHANGED
@@ -51,64 +51,61 @@ class ProcessGroupJax(ProcessGroup):
 
   @staticmethod
   def _work(
-
+      tensors: Union[torch.Tensor, List[torch.Tensor],
+                     List[List[torch.Tensor]]],
   ) -> dist.Work:
     fut = torch.futures.Future()
     fut.set_result(tensors)
     return torch._C._distributed_c10d._create_work_from_future(fut)
 
   def _allgather_base(
-
-
-
-
+      self,
+      output: torch.Tensor,
+      input: torch.Tensor,
+      opts=...,
   ) -> dist.Work:
     assert isinstance(input, torchax.tensor.Tensor)
     assert isinstance(output, torchax.tensor.Tensor)
     torch.distributed._functional_collectives.all_gather_tensor_inplace(
-
-    )
+        output, input, group=self)
     return self._work(output)
 
   def allreduce(
-
-
-
+      self,
+      tensors: List[torch.Tensor],
+      opts: dist.AllreduceOptions = ...,
   ) -> dist.Work:
     assert len(tensors) == 1
     assert isinstance(tensors[0], torchax.tensor.Tensor)
     torch.distributed._functional_collectives.all_reduce_inplace(
-
-
-
-
-      self,
+        tensors[0],
+        torch.distributed._functional_collectives.REDUCE_OP_TO_STR[
+            opts.reduceOp.op],
+        self,
     )
 
     return self._work(tensors)
 
   def broadcast(
-
-
-
+      self,
+      tensors: List[torch.Tensor],
+      opts: dist.BroadcastOptions = ...,
   ) -> dist.Work:
     assert len(tensors) == 1
     assert isinstance(tensors[0], torchax.tensor.Tensor)
     tensors[0].copy_(
-
-
-    )
-    )
+        torch.distributed._functional_collectives.broadcast(
+            tensors[0], opts.rootRank, group=self))
 
     return self._work(tensors)
 
 
-dist.Backend.register_backend("jax", ProcessGroupJax)
+dist.Backend.register_backend("jax", ProcessGroupJax, devices=["jax"])
 
 
-def jax_rendezvous_handler(
-
-):
+def jax_rendezvous_handler(url: str,
+                           timeout: datetime.timedelta = ...,
+                           **kwargs):
   """Initialize distributed store with JAX process IDs.
 
   Requires `$MASTER_ADDR` and `$MASTER_PORT`.
@@ -120,10 +117,10 @@ def jax_rendezvous_handler(
   master_port = int(os.environ["MASTER_PORT"])
   # TODO(wcromar): Use `torchrun`'s store if available
   store = dist.TCPStore(
-
-
-
-
+      master_ip,
+      master_port,
+      jax.process_count(),
+      is_master=jax.process_index() == 0,
   )
 
   yield (store, jax.process_index(), jax.process_count())
@@ -145,9 +142,9 @@ def spawn(f, args=(), env: Optional[torchax.tensor.Environment] = None):
     torch_outputs = f(index, *args)
     return env.t2j_iso(torch_outputs)
 
-  jax_outputs = jax.pmap(
-
-
+  jax_outputs = jax.pmap(
+      jax_wrapper, axis_name="torch_dist")(np.arange(jax.device_count()),
+                                           env.t2j_iso(args))
   return env.j2t_iso(jax_outputs)
 
 
@@ -172,11 +169,12 @@ class DistributedDataParallel(torch.nn.Module):
   jax_output = jax_model(jax_data)
   ```
   """
+
   def __init__(
-
-
-
-
+      self,
+      module: torch.nn.Module,
+      env: Optional[torchax.tensor.Environment] = None,
+      **kwargs,
   ):
     if kwargs:
       logging.warning(f"Unsupported kwargs {kwargs}")
@@ -184,17 +182,15 @@ class DistributedDataParallel(torch.nn.Module):
     super().__init__()
     self._env = env or torchax.default_env()
     self._mesh = Mesh(
-
-
+        mesh_utils.create_device_mesh((jax.device_count(),)),
+        axis_names=("batch",),
     )
     replicated_state = torch_pytree.tree_map_only(
-
-
-
-
-    )
-    ),
-    module.state_dict(),
+        torch.Tensor,
+        lambda t: self._env.j2t_iso(
+            jax.device_put(
+                self._env.to_xla(t)._elem, NamedSharding(self._mesh, P()))),
+        module.state_dict(),
     )
     # TODO: broadcast
     module.load_state_dict(replicated_state, assign=True)
@@ -208,25 +204,24 @@ class DistributedDataParallel(torch.nn.Module):
     global_batch_shape = (global_batch_size,) + inp.shape[1:]
 
     sharding = NamedSharding(self._mesh, P("batch"))
-    return self._env.j2t_iso(
-
-
-
-
-
-
-
-
-    ))
+    return self._env.j2t_iso(
+        jax.make_array_from_single_device_arrays(
+            global_batch_shape,
+            NamedSharding(self._mesh, P("batch")),
+            arrays=[
+                jax.device_put(self._env.to_xla(batch)._elem, device) for batch,
+                device in zip(per_replica_batches, sharding.addressable_devices)
+            ],
+        ))
 
   def replicate_input(self, inp):
     return self._env.j2t_iso(
-
-    )
+        jax.device_put(inp._elem, NamedSharding(self._mesh, P())))
 
   def jit_step(self, func):
-
-
+
+    @functools.partial(
+        interop.jax_jit, kwargs_for_jax_jit={'donate_argnums': 0})
    def _jit_fn(states, args):
      self.load_state_dict(states)
      outputs = func(*args)
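For context, a hedged usage sketch of the DistributedDataParallel wrapper touched above, pieced together only from the signatures and docstring fragment visible in this hunk; the model, shapes, and the assumption that the wrapper's forward simply delegates to the wrapped module are illustrative.

# Hedged sketch built from what this diff shows: __init__(module, env),
# replicate_input(), and the docstring's `jax_output = jax_model(jax_data)`.
import torch
import torchax
from torchax.distributed import DistributedDataParallel

env = torchax.default_env()
model = torch.nn.Linear(16, 4)

ddp = DistributedDataParallel(model, env=env)   # params replicated over the device mesh
inp = env.to_xla(torch.randn(8, 16))            # move the input into the torchax environment
replicated = ddp.replicate_input(inp)           # placed with NamedSharding(mesh, P())
out = ddp(replicated)                           # assumption: forward delegates to `model`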
torchax/export.py
CHANGED
@@ -4,14 +4,14 @@ import copy
 from typing import Any, Dict, Tuple
 import torch
 from torch.utils import _pytree as pytree
+import torchax
 from torchax import tensor
-from torchax.ops import ops_registry
+from torchax.ops import ops_registry, mappings
 from torchax import decompositions
 import jax
 import jax.export
 import sympy
 
-
 DEBUG = False
 
 
@@ -83,7 +83,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
   if torch.__version__ >= '2.2':
     # torch version 2.1 didn't expose this yet
     exported_program = exported_program.run_decompositions()
-    exported_program = exported_program.run_decompositions(
+    exported_program = exported_program.run_decompositions(
+        decompositions.DECOMPOSITIONS)
   if DEBUG:
     print(exported_program.graph_module.code)
 
@@ -108,8 +109,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
 
   if export_raw:
     return names, states, func
-
-  states =
+  env = torchax.default_env()
+  states = env.t2j_copy(states)
   return states, func
 
 
@@ -121,34 +122,35 @@ def extract_avals(exported):
   def _to_aval(arg_meta, symbolic_shapes):
     """Convet from torch type to jax abstract value for export tracing
     """
+
     def _get_dim(d):
       if isinstance(d, torch.SymInt):
         return symbolic_shapes[str(d)]
       return d
 
     val = arg_meta['val']
-    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
+    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
+        val, bool)
     if is_scalar:
       return jax.ShapeDtypeStruct([], type(arg_meta['val']))
 
     tensor_meta = arg_meta['tensor_meta']
     shape = [_get_dim(d) for d in tensor_meta.shape]
-    return jax.ShapeDtypeStruct(shape,
+    return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
 
   def _get_inputs(exported):
     """Return placeholders with input metadata"""
     placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
     input_placeholders = [
-
-
-        if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+        p for p, s in zip(placeholders, exported.graph_signature.input_specs)
+        if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
     ]
     return input_placeholders
 
   def _build_symbolic_shapes(range_constraints):
     """Convert torch SymInt to JAX symbolic_shape and stores in a map using the
     string name of the torch symbolic int.
-
+
     TODO: There is probably a better way of storing a key for a symbolic int.
     This value needs to be looked up again in `_to_aval` to figure out which
     JAX symbolic to map to for a given torch tensor.
@@ -163,8 +165,10 @@ def extract_avals(exported):
     torch.export.Dim("a", min=5, max=10)
     ==> ("a >= 5", "a <= 10",)
     """
-    if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
-
+    if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
+                      ValueRanges) or torch_constraint.is_bool:
+      raise TypeError(
+          f"No symbolic constraint handler for: {torch_constraint}")
 
     constraints = []
     symbol = sympy.Symbol(symbol_name)
@@ -182,7 +186,7 @@ def extract_avals(exported):
     There are two possible sympy `sym` inputs:
     1. Symbol - (s0) These can have custom constraints.
     2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.
-
+
     Currently support is limited to operations with a symbol and and int,
     in `torch/export/dynamic_shapes.py`:
     "Only increasing linear operations with integer coefficients are supported."
@@ -190,7 +194,8 @@ def extract_avals(exported):
     symbol_name = str(sym)
     constraints = _build_symbolic_constraints(symbol_name, constraint)
     if sym.is_symbol:
-      symbolic_shape = jax.export.symbolic_shape(
+      symbolic_shape = jax.export.symbolic_shape(
+          symbol_name, constraints=constraints)
     else:
       assert len(sym.free_symbols) > 0
       scope = free_symbols[str(list(sym.free_symbols)[0])].scope
@@ -203,8 +208,12 @@ def extract_avals(exported):
   # integer compuations on symbol variables, so each symbol variable is OK to
   # have its own scope.
   symbolic_shapes = {}
-  symbol_variables = [
-
+  symbol_variables = [
+      (s, v) for s, v in range_constraints.items() if s.is_symbol
+  ]
+  symbol_exprs = [
+      (s, v) for s, v in range_constraints.items() if not s.is_symbol
+  ]
   for sym, constraint in symbol_variables + symbol_exprs:
     symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
     symbolic_shapes[str(sym)] = symbolic_shape
torchax/flax.py
ADDED
@@ -0,0 +1,39 @@
+"""Flax interop."""
+
+import torch
+import torchax as tx
+import torchax.interop
+
+
+class FlaxNNModule(torch.nn.Module):
+
+  def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
+    super().__init__()
+    prng = env.prng_key
+    sample_kwargs = sample_kwargs or {}
+    parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
+                                         **sample_kwargs)
+
+    self._params = self._encode_nested_dict(parameter_dict)
+
+    self._flax_module = flax_module
+
+  def _encode_nested_dict(self, nested_dict):
+    child_module = torch.nn.Module()
+    for k, v in nested_dict.items():
+      if isinstance(v, dict):
+        child_module.add_module(k, self._encode_nested_dict(v))
+      else:
+        child_module.register_parameter(k, torch.nn.Parameter(v))
+    return child_module
+
+  def _decode_nested_dict(self, child_module):
+    result = dict(child_module.named_parameters(recurse=False))
+    for k, v in child_module.named_children():
+      result[k] = self._decode_nested_dict(v)
+    return result
+
+  def forward(self, *args, **kwargs):
+    nested_dict_params = self._decode_nested_dict(self._params)
+    return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
+                               *args, **kwargs)
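A hedged usage sketch for the new wrapper, based only on the constructor shown above; the Linen model, the shapes, and the use of the environment as a context manager with a "jax" device are assumptions for illustration.

# Hedged sketch: wrapping a Flax Linen module so it behaves like a torch.nn.Module.
# Assumptions: flax is installed, the default env works as a context manager, and
# sample_args should already be torchax ("jax" device) tensors.
import flax.linen as nn
import torch
import torchax
from torchax.flax import FlaxNNModule

env = torchax.default_env()
dense = nn.Dense(features=8)

with env:
  sample = torch.randn(2, 16, device="jax")
  wrapped = FlaxNNModule(env, dense, (sample,))
  out = wrapped(sample)   # forward() calls dense.apply via interop.call_jax
  print(out.shape)        # torch.Size([2, 8])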