statedict2pytree 1.0.1__tar.gz → 2.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/.gitignore +1 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/PKG-INFO +3 -3
- statedict2pytree-2.0.0/flax_example.py +33 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/pyproject.toml +2 -2
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/statedict2pytree/converter.py +92 -26
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/tests/test_batchnorm.py +5 -3
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/.github/workflows/run_tests.yml +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/.pre-commit-config.yaml +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/README.md +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/docs/index.md +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/mkdocs.yml +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/pyrightconfig.json +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/statedict2pytree/__init__.py +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/tests/test_conv.py +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/tests/test_linear.py +0 -0
- {statedict2pytree-1.0.1 → statedict2pytree-2.0.0}/uv.lock +0 -0
|
@@ -1,11 +1,11 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version:
|
|
3
|
+
Version: 2.0.0
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
6
|
Requires-Python: >=3.11
|
|
7
7
|
Requires-Dist: beartype
|
|
8
|
-
Requires-Dist: equinox
|
|
8
|
+
Requires-Dist: equinox>=0.13.0
|
|
9
9
|
Requires-Dist: jax
|
|
10
10
|
Requires-Dist: jaxlib
|
|
11
11
|
Requires-Dist: jaxtyping
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from flax import nnx
|
|
3
|
+
from statedict2pytree import autoconvert, pytree_to_fields
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TorchModel(torch.nn.Module):
|
|
7
|
+
def __init__(self, din, dout):
|
|
8
|
+
super(TorchModel, self).__init__()
|
|
9
|
+
self.linear = torch.nn.Linear(din, dout, bias=False)
|
|
10
|
+
self.linear.weight.data = torch.ones_like(self.linear.weight)
|
|
11
|
+
|
|
12
|
+
def forward(self, x):
|
|
13
|
+
return self.linear(x)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class Model(nnx.Module):
|
|
17
|
+
def __init__(self, din, dout, rngs: nnx.Rngs):
|
|
18
|
+
self.linear = nnx.Linear(din, dout, rngs=rngs, use_bias=False)
|
|
19
|
+
|
|
20
|
+
def __call__(self, x):
|
|
21
|
+
return self.linear(x)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
flax_model = Model(2, 64, rngs=nnx.Rngs(0))
|
|
25
|
+
torch_model = TorchModel(2, 64)
|
|
26
|
+
|
|
27
|
+
pt_fields = pytree_to_fields(flax_model)
|
|
28
|
+
print(pt_fields)
|
|
29
|
+
|
|
30
|
+
flax_model = autoconvert(flax_model, torch_model.state_dict())
|
|
31
|
+
|
|
32
|
+
print(flax_model)
|
|
33
|
+
print(flax_model.linear.kernel.value)
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "statedict2pytree"
|
|
3
|
-
version = "
|
|
3
|
+
version = "2.0.0"
|
|
4
4
|
description = "Converts torch models into PyTrees for Equinox"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = ">=3.11"
|
|
7
7
|
authors = [{ name = "Artur A. Galstyan", email = "mail@arturgalstyan.dev" }]
|
|
8
8
|
dependencies = [
|
|
9
9
|
"jax",
|
|
10
|
-
"equinox",
|
|
10
|
+
"equinox>=0.13.0",
|
|
11
11
|
"jaxlib",
|
|
12
12
|
"beartype",
|
|
13
13
|
"jaxtyping",
|
|
@@ -28,7 +28,8 @@ class TorchField(BaseModel):
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class JaxField(BaseModel):
|
|
31
|
-
|
|
31
|
+
key_path: KeyPath
|
|
32
|
+
path: str
|
|
32
33
|
shape: tuple[int, ...]
|
|
33
34
|
skip: bool = False
|
|
34
35
|
|
|
@@ -45,7 +46,7 @@ def is_numerical(element: Any):
|
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
def _default_floating_dtype():
|
|
48
|
-
if jax.config.jax_enable_x64: #
|
|
49
|
+
if jax.config.jax_enable_x64: # ty: ignore
|
|
49
50
|
return jnp.float64
|
|
50
51
|
else:
|
|
51
52
|
return jnp.float32
|
|
@@ -86,15 +87,15 @@ def _get_stateindex_fields(obj) -> dict:
|
|
|
86
87
|
|
|
87
88
|
|
|
88
89
|
def _get_node(
|
|
89
|
-
tree: PyTree,
|
|
90
|
+
tree: PyTree, key_path: KeyPath, state_indices: dict | None = None
|
|
90
91
|
) -> tuple[PyTree | None, dict | None]:
|
|
91
92
|
if tree is None:
|
|
92
93
|
return None, {}
|
|
93
94
|
else:
|
|
94
|
-
if len(
|
|
95
|
+
if len(key_path) == 0:
|
|
95
96
|
return tree, state_indices
|
|
96
|
-
f, *_ =
|
|
97
|
-
if hasattr(tree, "is_stateful"):
|
|
97
|
+
f, *_ = key_path
|
|
98
|
+
if hasattr(tree, "is_stateful") and tree.is_stateful():
|
|
98
99
|
if state_indices is None:
|
|
99
100
|
state_indices = {}
|
|
100
101
|
indices = _get_stateindex_fields(tree)
|
|
@@ -109,23 +110,26 @@ def _get_node(
|
|
|
109
110
|
elif isinstance(f, FlattenedIndexKey):
|
|
110
111
|
if isinstance(tree, eqx.nn.State):
|
|
111
112
|
assert state_indices is not None
|
|
112
|
-
|
|
113
|
+
|
|
114
|
+
markers = list(tree._state.keys())
|
|
115
|
+
marker = markers[f.key]
|
|
116
|
+
index = state_indices[marker]
|
|
113
117
|
subtree = tree.get(index)
|
|
114
118
|
else:
|
|
115
119
|
subtree = None
|
|
116
120
|
else:
|
|
117
121
|
subtree = None
|
|
118
|
-
return _get_node(subtree,
|
|
122
|
+
return _get_node(subtree, key_path[1:], state_indices)
|
|
119
123
|
|
|
120
124
|
|
|
121
125
|
def _replace_node(
|
|
122
|
-
tree: PyTree,
|
|
126
|
+
tree: PyTree, key_path: KeyPath, new_value: Array, state_indices: dict | None = None
|
|
123
127
|
) -> PyTree:
|
|
124
128
|
def where_wrapper(t):
|
|
125
|
-
node, _ = _get_node(t,
|
|
129
|
+
node, _ = _get_node(t, key_path=key_path, state_indices=state_indices)
|
|
126
130
|
return node
|
|
127
131
|
|
|
128
|
-
node, _ = _get_node(tree,
|
|
132
|
+
node, _ = _get_node(tree, key_path=key_path, state_indices=state_indices)
|
|
129
133
|
|
|
130
134
|
if node is not None and eqx.is_array(node):
|
|
131
135
|
tree = eqx.tree_at(
|
|
@@ -134,10 +138,48 @@ def _replace_node(
|
|
|
134
138
|
new_value.reshape(node.shape),
|
|
135
139
|
)
|
|
136
140
|
else:
|
|
137
|
-
print("WARNING: Couldn't find: ", jax.tree_util.keystr(
|
|
141
|
+
print("WARNING: Couldn't find: ", jax.tree_util.keystr(key_path))
|
|
138
142
|
return tree
|
|
139
143
|
|
|
140
144
|
|
|
145
|
+
def _resolve_state_names(pytree: PyTree, key_path: KeyPath) -> str:
|
|
146
|
+
parts = []
|
|
147
|
+
current = pytree
|
|
148
|
+
|
|
149
|
+
for i, key in enumerate(key_path):
|
|
150
|
+
if isinstance(key, GetAttrKey):
|
|
151
|
+
parts.append(key.name)
|
|
152
|
+
current = getattr(current, key.name)
|
|
153
|
+
elif isinstance(key, SequenceKey):
|
|
154
|
+
if i == 0 and isinstance(pytree, tuple) and len(pytree) == 2:
|
|
155
|
+
current = current[key.idx]
|
|
156
|
+
continue
|
|
157
|
+
current = current[key.idx]
|
|
158
|
+
parts.append(str(key.idx))
|
|
159
|
+
elif isinstance(key, FlattenedIndexKey):
|
|
160
|
+
if isinstance(current, eqx.nn.State):
|
|
161
|
+
markers = list(current._state.keys())
|
|
162
|
+
marker = markers[key.key]
|
|
163
|
+
clean_marker = marker.split(".", 1)[-1] if "." in marker else marker
|
|
164
|
+
parts.append(clean_marker)
|
|
165
|
+
current = current._state[marker]
|
|
166
|
+
else:
|
|
167
|
+
parts.append(str(key))
|
|
168
|
+
|
|
169
|
+
return ".".join(parts)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _normalize_eqx_name(path: str) -> str:
|
|
173
|
+
if "batch_state_index.0" in path:
|
|
174
|
+
return path.replace("batch_state_index.0", "running_mean")
|
|
175
|
+
if "batch_state_index.1" in path:
|
|
176
|
+
return path.replace("batch_state_index.1", "running_var")
|
|
177
|
+
if "batch_counter" in path:
|
|
178
|
+
return path.replace("batch_counter", "num_batches_tracked")
|
|
179
|
+
|
|
180
|
+
return path
|
|
181
|
+
|
|
182
|
+
|
|
141
183
|
def move_running_fields_to_the_end(
|
|
142
184
|
torchfields: list[TorchField], identifier: str = "running_"
|
|
143
185
|
):
|
|
@@ -159,14 +201,16 @@ def move_running_fields_to_the_end(
|
|
|
159
201
|
|
|
160
202
|
|
|
161
203
|
def state_dict_to_fields(
|
|
162
|
-
state_dict: dict[str, Any],
|
|
204
|
+
state_dict: dict[str, Any], sort_by_path: bool = True
|
|
163
205
|
) -> list[TorchField]:
|
|
164
206
|
if state_dict is None:
|
|
165
207
|
return []
|
|
166
208
|
fields: list[TorchField] = []
|
|
167
209
|
for key, value in state_dict.items():
|
|
168
|
-
if hasattr(value, "shape")
|
|
210
|
+
if hasattr(value, "shape"):
|
|
169
211
|
fields.append(TorchField(path=key, shape=tuple(value.shape)))
|
|
212
|
+
|
|
213
|
+
fields.sort(key=lambda x: x.path)
|
|
170
214
|
return fields
|
|
171
215
|
|
|
172
216
|
|
|
@@ -174,17 +218,29 @@ def pytree_to_fields(
|
|
|
174
218
|
pytree: PyTree,
|
|
175
219
|
model_order: list[str] | None = None,
|
|
176
220
|
filter: Callable[[Array], bool] = eqx.is_array,
|
|
221
|
+
sort_by_path: bool = True,
|
|
177
222
|
) -> tuple[list[JaxField], dict | None]:
|
|
178
223
|
jaxfields = []
|
|
179
224
|
paths = jax.tree.leaves_with_path(pytree)
|
|
180
|
-
|
|
225
|
+
state_indices = {}
|
|
181
226
|
for p in paths:
|
|
182
227
|
keys, _ = p
|
|
183
|
-
n,
|
|
228
|
+
n, state_indices = _get_node(pytree, keys, state_indices)
|
|
184
229
|
if n is not None and filter(n):
|
|
185
|
-
|
|
230
|
+
readable_path = _resolve_state_names(pytree, keys)
|
|
231
|
+
jaxfields.append(
|
|
232
|
+
JaxField(
|
|
233
|
+
key_path=keys,
|
|
234
|
+
shape=n.shape,
|
|
235
|
+
path=_normalize_eqx_name(readable_path),
|
|
236
|
+
)
|
|
237
|
+
)
|
|
186
238
|
|
|
187
239
|
if model_order is not None:
|
|
240
|
+
if sort_by_path:
|
|
241
|
+
raise ValueError(
|
|
242
|
+
"model_order is given and sort_by_path is true. Only one of those is allowed"
|
|
243
|
+
)
|
|
188
244
|
ordered_jaxfields = []
|
|
189
245
|
path_dict = {jax.tree_util.keystr(field.path): field for field in jaxfields}
|
|
190
246
|
|
|
@@ -194,7 +250,9 @@ def pytree_to_fields(
|
|
|
194
250
|
del path_dict[path_str]
|
|
195
251
|
ordered_jaxfields.extend(path_dict.values())
|
|
196
252
|
jaxfields = ordered_jaxfields
|
|
197
|
-
|
|
253
|
+
elif sort_by_path:
|
|
254
|
+
jaxfields.sort(key=lambda j: j.path)
|
|
255
|
+
return jaxfields, state_indices
|
|
198
256
|
|
|
199
257
|
|
|
200
258
|
def _chunkify_state_dict(
|
|
@@ -245,7 +303,9 @@ def convert(
|
|
|
245
303
|
if len(torchfields) != len(jaxfields):
|
|
246
304
|
raise ValueError(
|
|
247
305
|
f"Length of state_dict ({len(torchfields)}) "
|
|
248
|
-
f"!= length of pytree ({len(jaxfields)})"
|
|
306
|
+
f"!= length of pytree ({len(jaxfields)}) "
|
|
307
|
+
"Note: if you are using BatchNorm in Equinox, make sure to "
|
|
308
|
+
'add BatchNorm(..., mode="batch"), which might fix your issue!'
|
|
249
309
|
)
|
|
250
310
|
|
|
251
311
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
@@ -259,28 +319,34 @@ def convert(
|
|
|
259
319
|
"Note that the order of the fields matters "
|
|
260
320
|
"and that you can mark arrays as skippable. "
|
|
261
321
|
f"{t.path=} "
|
|
262
|
-
f"{jax.tree_util.keystr(j.
|
|
322
|
+
f"{jax.tree_util.keystr(j.key_path)=} ({j.path})"
|
|
263
323
|
)
|
|
264
324
|
state_dict_dir = pathlib.Path(chunkified_statedict_path.path) / "state_dict"
|
|
265
325
|
filename = state_dict_dir / t.path
|
|
266
326
|
new_value = jnp.array(np.load(str(filename) + ".npy"))
|
|
267
327
|
|
|
268
|
-
n, _ = _get_node(pytree, j.
|
|
269
|
-
assert n is not None, f"Node {j.
|
|
328
|
+
n, _ = _get_node(pytree, j.key_path, state_indices)
|
|
329
|
+
assert n is not None, f"Node {j.key_path} not found"
|
|
270
330
|
assert _can_reshape(n.shape, new_value.shape), (
|
|
271
331
|
f"Cannot reshape {n.shape} into {new_value.shape}"
|
|
272
332
|
)
|
|
273
333
|
|
|
274
|
-
pytree = _replace_node(pytree, j.
|
|
334
|
+
pytree = _replace_node(pytree, j.key_path, new_value, state_indices)
|
|
275
335
|
|
|
276
336
|
return pytree
|
|
277
337
|
|
|
278
338
|
|
|
279
339
|
def autoconvert(
|
|
280
|
-
pytree: PyTree,
|
|
340
|
+
pytree: PyTree,
|
|
341
|
+
state_dict: dict,
|
|
342
|
+
filter: Callable[[Array], bool] = eqx.is_array,
|
|
343
|
+
pytree_model_order: list[str] | None = None,
|
|
344
|
+
sort_by_path: bool = True,
|
|
281
345
|
) -> PyTree:
|
|
282
|
-
torchfields = state_dict_to_fields(state_dict)
|
|
283
|
-
jaxfields, state_indices = pytree_to_fields(
|
|
346
|
+
torchfields = state_dict_to_fields(state_dict, sort_by_path)
|
|
347
|
+
jaxfields, state_indices = pytree_to_fields(
|
|
348
|
+
pytree, pytree_model_order, filter=filter, sort_by_path=sort_by_path
|
|
349
|
+
)
|
|
284
350
|
|
|
285
351
|
pytree = convert(
|
|
286
352
|
state_dict,
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import equinox as eqx
|
|
2
2
|
import jax
|
|
3
3
|
import numpy as np
|
|
4
|
-
import statedict2pytree as s2p
|
|
5
4
|
import torch
|
|
6
5
|
|
|
6
|
+
import statedict2pytree as s2p
|
|
7
|
+
|
|
7
8
|
|
|
8
9
|
def test_linear():
|
|
9
10
|
in_features = 10
|
|
@@ -17,7 +18,9 @@ def test_linear():
|
|
|
17
18
|
self.linear = eqx.nn.Linear(
|
|
18
19
|
in_features, out_features, key=jax.random.PRNGKey(30)
|
|
19
20
|
)
|
|
20
|
-
self.norm = eqx.nn.BatchNorm(
|
|
21
|
+
self.norm = eqx.nn.BatchNorm(
|
|
22
|
+
input_size=out_features, axis_name="batch", mode="batch"
|
|
23
|
+
)
|
|
21
24
|
|
|
22
25
|
class T(torch.nn.Module):
|
|
23
26
|
def __init__(self) -> None:
|
|
@@ -31,7 +34,6 @@ def test_linear():
|
|
|
31
34
|
state_dict = torch_model.state_dict()
|
|
32
35
|
|
|
33
36
|
torchfields = s2p.state_dict_to_fields(state_dict)
|
|
34
|
-
torchfields = s2p.move_running_fields_to_the_end(torchfields)
|
|
35
37
|
|
|
36
38
|
jaxfields, state_indices = s2p.pytree_to_fields(
|
|
37
39
|
(model, state), filter=s2p.is_numerical
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|