torchax 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- torchax/CONTRIBUTING.md +2 -2
- torchax/__init__.py +26 -24
- torchax/amp.py +332 -0
- torchax/config.py +25 -14
- torchax/configuration.py +30 -0
- torchax/decompositions.py +663 -195
- torchax/device_module.py +14 -1
- torchax/environment.py +0 -1
- torchax/export.py +26 -17
- torchax/flax.py +39 -0
- torchax/interop.py +288 -141
- torchax/mesh_util.py +220 -0
- torchax/ops/jaten.py +1723 -1297
- torchax/ops/jax_reimplement.py +23 -21
- torchax/ops/jc10d.py +5 -4
- torchax/ops/jimage.py +113 -0
- torchax/ops/jlibrary.py +9 -2
- torchax/ops/jtorch.py +237 -88
- torchax/ops/jtorchvision_nms.py +32 -43
- torchax/ops/mappings.py +77 -35
- torchax/ops/op_base.py +59 -32
- torchax/ops/ops_registry.py +40 -35
- torchax/tensor.py +442 -288
- torchax/train.py +38 -41
- torchax/util.py +88 -0
- torchax/view.py +377 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/METADATA +111 -145
- torchax-0.0.6.dist-info/RECORD +33 -0
- torchax/distributed.py +0 -246
- torchax-0.0.4.dist-info/RECORD +0 -27
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/WHEEL +0 -0
- {torchax-0.0.4.dist-info → torchax-0.0.6.dist-info}/licenses/LICENSE +0 -0
torchax/device_module.py
CHANGED
@@ -1,20 +1,33 @@
+import torch
+
+
 def _is_in_bad_fork():
   return False
 
+
 def manual_seed_all(seed):
   pass
 
+
 def device_count():
   return 1
 
+
 def get_rng_state():
   return []
 
+
 def set_rng_state(new_state, device):
   pass
 
+
 def is_available():
   return True
 
+
 def current_device():
-  return 0
+  return 0
+
+
+def get_amp_supported_dtype():
+  return [torch.float16, torch.bfloat16]
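The 0.0.6 change adds a `get_amp_supported_dtype` hook (and the `import torch` it needs) so that autocast can ask this backend which mixed-precision dtypes it supports, in line with the new `torchax/amp.py` module in the file list above. A minimal sketch of exercising these hooks, assuming torchax 0.0.6 is installed; how torchax registers this module with PyTorch (e.g. via `torch._register_device_module`) is not shown in this diff:

```python
import torch
from torchax import device_module

# The pre-existing device hooks describe a single, always-available device.
assert device_module.is_available()
assert device_module.device_count() == 1
assert device_module.current_device() == 0

# New in 0.0.6: autocast can query the dtypes this backend supports for AMP.
assert device_module.get_amp_supported_dtype() == [torch.float16, torch.bfloat16]
```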
torchax/environment.py
CHANGED
torchax/export.py
CHANGED
@@ -4,14 +4,14 @@ import copy
 from typing import Any, Dict, Tuple
 import torch
 from torch.utils import _pytree as pytree
+import torchax
 from torchax import tensor
-from torchax.ops import ops_registry
+from torchax.ops import ops_registry, mappings
 from torchax import decompositions
 import jax
 import jax.export
 import sympy
 
-
 DEBUG = False
 
 
@@ -83,7 +83,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
   if torch.__version__ >= '2.2':
     # torch version 2.1 didn't expose this yet
     exported_program = exported_program.run_decompositions()
-    exported_program = exported_program.run_decompositions(decompositions.DECOMPOSITIONS)
+    exported_program = exported_program.run_decompositions(
+        decompositions.DECOMPOSITIONS)
   if DEBUG:
     print(exported_program.graph_module.code)
 
@@ -108,8 +109,8 @@ def exported_program_to_jax(exported_program, export_raw: bool = False):
 
   if export_raw:
     return names, states, func
-
-  states =
+  env = torchax.default_env()
+  states = env.t2j_copy(states)
   return states, func
 
 
@@ -121,34 +122,35 @@ def extract_avals(exported):
   def _to_aval(arg_meta, symbolic_shapes):
     """Convet from torch type to jax abstract value for export tracing
     """
+
     def _get_dim(d):
       if isinstance(d, torch.SymInt):
         return symbolic_shapes[str(d)]
       return d
 
     val = arg_meta['val']
-    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
+    is_scalar = isinstance(val, float) or isinstance(val, int) or isinstance(
+        val, bool)
     if is_scalar:
       return jax.ShapeDtypeStruct([], type(arg_meta['val']))
 
     tensor_meta = arg_meta['tensor_meta']
     shape = [_get_dim(d) for d in tensor_meta.shape]
-    return jax.ShapeDtypeStruct(shape,
+    return jax.ShapeDtypeStruct(shape, mappings.t2j_dtype(tensor_meta.dtype))
 
   def _get_inputs(exported):
     """Return placeholders with input metadata"""
     placeholders = [p for p in exported.graph.nodes if p.op == "placeholder"]
     input_placeholders = [
-
-
-        if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
+        p for p, s in zip(placeholders, exported.graph_signature.input_specs)
+        if s.kind == torch.export.graph_signature.InputKind.USER_INPUT
     ]
     return input_placeholders
 
   def _build_symbolic_shapes(range_constraints):
     """Convert torch SymInt to JAX symbolic_shape and stores in a map using the
     string name of the torch symbolic int.
-
+
     TODO: There is probably a better way of storing a key for a symbolic int.
     This value needs to be looked up again in `_to_aval` to figure out which
     JAX symbolic to map to for a given torch tensor.
 
@@ -163,8 +165,10 @@ def extract_avals(exported):
       torch.export.Dim("a", min=5, max=10)
       ==> ("a >= 5", "a <= 10",)
     """
-    if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
-
+    if not isinstance(torch_constraint, torch.utils._sympy.value_ranges.
+                      ValueRanges) or torch_constraint.is_bool:
+      raise TypeError(
+          f"No symbolic constraint handler for: {torch_constraint}")
 
     constraints = []
     symbol = sympy.Symbol(symbol_name)
 
@@ -182,7 +186,7 @@ def extract_avals(exported):
     There are two possible sympy `sym` inputs:
     1. Symbol - (s0) These can have custom constraints.
     2. Expr - (s0*2) These apply the expr to s0's constraints, cannot override.
-
+
     Currently support is limited to operations with a symbol and and int,
     in `torch/export/dynamic_shapes.py`:
     "Only increasing linear operations with integer coefficients are supported."
 
@@ -190,7 +194,8 @@ def extract_avals(exported):
     symbol_name = str(sym)
     constraints = _build_symbolic_constraints(symbol_name, constraint)
     if sym.is_symbol:
-      symbolic_shape = jax.export.symbolic_shape(
+      symbolic_shape = jax.export.symbolic_shape(
+          symbol_name, constraints=constraints)
     else:
       assert len(sym.free_symbols) > 0
       scope = free_symbols[str(list(sym.free_symbols)[0])].scope
 
@@ -203,8 +208,12 @@ def extract_avals(exported):
   # integer compuations on symbol variables, so each symbol variable is OK to
   # have its own scope.
   symbolic_shapes = {}
-  symbol_variables = [
-
+  symbol_variables = [
+      (s, v) for s, v in range_constraints.items() if s.is_symbol
+  ]
+  symbol_exprs = [
+      (s, v) for s, v in range_constraints.items() if not s.is_symbol
+  ]
   for sym, constraint in symbol_variables + symbol_exprs:
     symbolic_shape = _build_symbolic_shape(sym, constraint, symbolic_shapes)
     symbolic_shapes[str(sym)] = symbolic_shape
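Taken together, these hunks route the exported program through `decompositions.DECOMPOSITIONS`, convert the captured weights with `torchax.default_env().t2j_copy(...)` instead of the previous per-tensor mapping, and derive JAX avals with `mappings.t2j_dtype`. A hedged sketch of the call path they affect; the toy model and shapes are illustrative, not from the diff:

```python
import torch
from torchax import export as tx_export


class TinyLinear(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.w = torch.nn.Parameter(torch.randn(4, 3))

  def forward(self, x):
    return x @ self.w


# Export with torch.export, then lower to JAX with torchax.
exported = torch.export.export(TinyLinear(), (torch.randn(2, 4),))

# After this change, `states` comes back as JAX arrays produced by
# env.t2j_copy(...), and `jax_func` is a JAX-traceable function that takes
# the converted states plus the (JAX) inputs.
states, jax_func = tx_export.exported_program_to_jax(exported)
```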
torchax/flax.py
ADDED
@@ -0,0 +1,39 @@
+"""Flax interop."""
+
+import torch
+import torchax as tx
+import torchax.interop
+
+
+class FlaxNNModule(torch.nn.Module):
+
+  def __init__(self, env, flax_module, sample_args, sample_kwargs=None):
+    super().__init__()
+    prng = env.prng_key
+    sample_kwargs = sample_kwargs or {}
+    parameter_dict = tx.interop.call_jax(flax_module.init, prng, *sample_args,
+                                         **sample_kwargs)
+
+    self._params = self._encode_nested_dict(parameter_dict)
+
+    self._flax_module = flax_module
+
+  def _encode_nested_dict(self, nested_dict):
+    child_module = torch.nn.Module()
+    for k, v in nested_dict.items():
+      if isinstance(v, dict):
+        child_module.add_module(k, self._encode_nested_dict(v))
+      else:
+        child_module.register_parameter(k, torch.nn.Parameter(v))
+    return child_module
+
+  def _decode_nested_dict(self, child_module):
+    result = dict(child_module.named_parameters(recurse=False))
+    for k, v in child_module.named_children():
+      result[k] = self._decode_nested_dict(v)
+    return result
+
+  def forward(self, *args, **kwargs):
+    nested_dict_params = self._decode_nested_dict(self._params)
+    return tx.interop.call_jax(self._flax_module.apply, nested_dict_params,
+                               *args, **kwargs)
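The new `FlaxNNModule` initializes a Flax module's parameters through `tx.interop.call_jax(flax_module.init, ...)` and stores the resulting nested dict as nested `torch.nn.Module`s holding `torch.nn.Parameter`s, so the weights show up in `state_dict()` and can be handed to torch optimizers; `forward` rebuilds the dict and calls `flax_module.apply` through `call_jax`. A hedged usage sketch; the Flax layer, input shape, and the `enable_globally()` / `device='jax'` setup are illustrative assumptions, not part of the diff:

```python
import flax.linen as nn
import torch
import torchax
from torchax.flax import FlaxNNModule

torchax.enable_globally()   # route torch ops on the 'jax' device through JAX
env = torchax.default_env()

sample = torch.randn(1, 784, device='jax')

# Wrap a Flax Dense layer as a torch.nn.Module; its parameters are
# initialized with env.prng_key and registered as torch parameters.
wrapped = FlaxNNModule(env, nn.Dense(features=10), (sample,))

out = wrapped(sample)       # dispatches to flax_module.apply via call_jax
print(out.shape)            # torch.Size([1, 10])
```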