statedict2pytree 0.6.0__py3-none-any.whl → 1.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/__init__.py +8 -0
- statedict2pytree/converter.py +293 -0
- {statedict2pytree-0.6.0.dist-info → statedict2pytree-1.0.1.dist-info}/METADATA +10 -28
- statedict2pytree-1.0.1.dist-info/RECORD +5 -0
- client/.gitignore +0 -3
- client/package-lock.json +0 -4540
- client/package.json +0 -36
- client/public/bundle.js +0 -10072
- client/public/bundle.js.map +0 -1
- client/public/index.html +0 -14
- client/public/input.css +0 -3
- client/public/output.css +0 -1617
- client/rollup.config.mjs +0 -44
- client/src/App.svelte +0 -584
- client/src/empty.ts +0 -0
- client/src/main.js +0 -8
- client/tailwind.config.js +0 -8
- client/tsconfig.json +0 -5
- statedict2pytree-0.6.0.dist-info/RECORD +0 -17
- {statedict2pytree-0.6.0.dist-info → statedict2pytree-1.0.1.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,293 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pathlib
|
|
3
|
+
import tempfile
|
|
4
|
+
|
|
5
|
+
import equinox as eqx
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import numpy as np
|
|
9
|
+
from beartype.typing import Any, Callable
|
|
10
|
+
from jax.tree_util import FlattenedIndexKey, GetAttrKey, KeyPath, SequenceKey
|
|
11
|
+
from jaxtyping import Array, PyTree
|
|
12
|
+
from pydantic import BaseModel
|
|
13
|
+
from tqdm import tqdm
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class ChunkifiedPytreePath(BaseModel):
    """Pydantic wrapper around a filesystem path stored as a string.

    Presumably points at a chunkified pytree on disk; the converter
    functions in this module do not use it.
    """

    path: str  # filesystem location, kept as a plain string
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ChunkifiedStatedictPath(BaseModel):
    """Location of a state dict that was split into per-key ``.npy`` files.

    ``path`` is the root directory. ``_chunkify_state_dict`` writes the
    individual arrays into its ``state_dict`` sub-directory, one file per
    state-dict key.
    """

    path: str  # root directory containing the ``state_dict`` folder
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class TorchField(BaseModel):
    """Description of a single array in a PyTorch state dict.

    Attributes:
        path: The state-dict key the array is stored under.
        shape: The shape of the array.
        skip: Flag marking the field as skippable; defaults to ``False``.
    """

    path: str
    shape: tuple[int, ...]
    skip: bool = False
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class JaxField(BaseModel):
    """Description of a single array leaf inside a JAX pytree.

    Attributes:
        path: The ``jax.tree_util`` key path that locates the leaf.
        shape: The shape of the leaf array.
        skip: Flag marking the field as skippable; defaults to ``False``.
    """

    path: KeyPath
    shape: tuple[int, ...]
    skip: bool = False
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def is_numerical(element: Any):
    """Tell whether ``element`` is an array-like that holds numeric data.

    An object is numerical when it has a ``dtype`` attribute and that dtype
    is an integer, floating-point or complex type. Boolean dtypes never
    count, and objects without a ``dtype`` are never numerical.
    """
    if not hasattr(element, "dtype"):
        return False
    # JAX and NumPy arrays both expose a NumPy-compatible ``dtype``.
    numeric_kinds = (np.integer, np.floating, np.complexfloating)
    has_numeric_kind = any(
        np.issubdtype(element.dtype, kind) for kind in numeric_kinds
    )
    return has_numeric_kind and not np.issubdtype(element.dtype, np.bool_)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def _default_floating_dtype():
|
|
48
|
+
if jax.config.jax_enable_x64: # pyright: ignore
|
|
49
|
+
return jnp.float64
|
|
50
|
+
else:
|
|
51
|
+
return jnp.float32
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _can_reshape(shape1: tuple, shape2: tuple):
|
|
55
|
+
"""
|
|
56
|
+
Check if two shapes can be reshaped to each other.
|
|
57
|
+
|
|
58
|
+
Args:
|
|
59
|
+
shape1 (tuple): First shape.
|
|
60
|
+
shape2 (tuple): Second shape.
|
|
61
|
+
|
|
62
|
+
Returns:
|
|
63
|
+
bool: True if shapes can be reshaped to each other, False otherwise.
|
|
64
|
+
"""
|
|
65
|
+
product1 = np.prod(shape1)
|
|
66
|
+
product2 = np.prod(shape2)
|
|
67
|
+
|
|
68
|
+
return product1 == product2
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_stateindex_fields(obj) -> dict:
|
|
72
|
+
state_indices = {}
|
|
73
|
+
|
|
74
|
+
for attr_name in dir(obj):
|
|
75
|
+
if attr_name.startswith("_"):
|
|
76
|
+
continue
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
attr_value = getattr(obj, attr_name)
|
|
80
|
+
if isinstance(attr_value, eqx.nn.StateIndex):
|
|
81
|
+
state_indices[attr_name] = attr_value
|
|
82
|
+
except: # noqa
|
|
83
|
+
pass
|
|
84
|
+
|
|
85
|
+
return state_indices
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _get_node(
    tree: PyTree, path: KeyPath, state_indices: dict | None = None
) -> tuple[PyTree | None, dict | None]:
    """Follow ``path`` down from ``tree`` and return the node it reaches.

    Supported path entries are ``SequenceKey`` (indexing), ``GetAttrKey``
    (attribute access) and ``FlattenedIndexKey`` into an ``eqx.nn.State``.
    Any other entry, or a ``FlattenedIndexKey`` on something that is not an
    ``eqx.nn.State``, makes the lookup yield ``None``.

    While descending, every node that has an ``is_stateful`` attribute has
    its ``eqx.nn.StateIndex`` fields recorded in ``state_indices`` under the
    index's ``marker``. A later ``FlattenedIndexKey`` step looks its ``key``
    up in that mapping and fetches the value via ``State.get(index)``.

    Args:
        tree: Pytree (or subtree) to descend into. ``None`` means that an
            earlier step of the path could not be followed.
        path: Key path to follow, e.g. as produced by ``jax.tree_util``.
        state_indices: Accumulated ``marker -> StateIndex`` mapping. A dict
            passed in is updated in place.

    Returns:
        ``(node, state_indices)``. ``node`` is ``None`` if the path cannot
        be followed. The returned mapping is the one passed in (possibly
        updated). If ``None`` was passed, it is a fresh dict when one had
        to be created or the lookup failed, and ``None`` otherwise.

    Raises:
        KeyError: If a ``FlattenedIndexKey`` into an ``eqx.nn.State`` names a
            marker that has not been recorded in ``state_indices``.
    """
    if tree is None:
        # An earlier step failed. Hand back the caller's accumulated indices
        # instead of a fresh ``{}``: callers that rebind their mapping from
        # this return value would otherwise lose every marker collected so
        # far and later state lookups would raise KeyError.
        return None, state_indices if state_indices is not None else {}
    if len(path) == 0:
        return tree, state_indices

    first = path[0]

    if hasattr(tree, "is_stateful"):
        if state_indices is None:
            state_indices = {}
        for index in _get_stateindex_fields(tree).values():
            assert isinstance(index, eqx.nn.StateIndex)
            state_indices[index.marker] = index

    if isinstance(first, SequenceKey):
        subtree = tree[first.idx]
    elif isinstance(first, GetAttrKey):
        subtree = getattr(tree, first.name)
    elif isinstance(first, FlattenedIndexKey):
        if isinstance(tree, eqx.nn.State):
            assert state_indices is not None
            state_index = state_indices[first.key]
            subtree = tree.get(state_index)
        else:
            subtree = None
    else:
        subtree = None
    return _get_node(subtree, path[1:], state_indices)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _replace_node(
    tree: PyTree, path: KeyPath, new_value: Array, state_indices: dict | None = None
) -> PyTree:
    """Return ``tree`` with the array at ``path`` replaced by ``new_value``.

    ``new_value`` is reshaped to the shape of the array currently at
    ``path`` and written in with ``eqx.tree_at``. If ``path`` does not lead
    to an array, a warning is printed and ``tree`` is returned unchanged.
    """
    target, _ = _get_node(tree, path=path, state_indices=state_indices)

    if target is None or not eqx.is_array(target):
        print("WARNING: Couldn't find: ", jax.tree_util.keystr(path))
        return tree

    return eqx.tree_at(
        lambda t: _get_node(t, path=path, state_indices=state_indices)[0],
        tree,
        new_value.reshape(target.shape),
    )
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def move_running_fields_to_the_end(
    torchfields: list[TorchField], identifier: str = "running_"
):
    """
    Reorder ``torchfields`` so that fields whose path contains ``identifier``
    come last. Helpful for stateful layers such as BatchNorm, whose running
    statistics appear at the end of Equinox pytrees.

    Both groups keep their original relative order. The list is reordered
    in place and also returned.
    """
    matching = [field for field in torchfields if identifier in field.path]
    remaining = [field for field in torchfields if identifier not in field.path]
    torchfields[:] = remaining + matching
    return torchfields
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
def state_dict_to_fields(
    state_dict: dict[str, Any],
) -> list[TorchField]:
    """Describe every non-scalar array in ``state_dict`` as a ``TorchField``.

    Entries without a ``shape`` attribute and 0-d entries are left out. The
    result follows the state dict's iteration order. ``None`` yields ``[]``.
    """
    if state_dict is None:
        return []
    return [
        TorchField(path=name, shape=tuple(tensor.shape))
        for name, tensor in state_dict.items()
        if hasattr(tensor, "shape") and len(tensor.shape) > 0
    ]
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
def pytree_to_fields(
    pytree: PyTree,
    model_order: list[str] | None = None,
    filter: Callable[[Array], bool] = eqx.is_array,
) -> tuple[list[JaxField], dict | None]:
    """Describe the leaves of ``pytree`` that pass ``filter`` as ``JaxField``s.

    Args:
        pytree: The pytree to inspect.
        model_order: Optional list of ``jax.tree_util.keystr`` path strings.
            Fields whose path string appears here come first, in that order.
            All other fields follow in their original traversal order.
        filter: Predicate applied to each leaf; defaults to ``eqx.is_array``.

    Returns:
        ``(jaxfields, state_indices)``, where ``state_indices`` is the
        ``marker -> StateIndex`` mapping collected while walking the paths.
    """
    jaxfields: list[JaxField] = []
    # One mapping shared by every lookup. ``_get_node`` updates a passed-in
    # dict in place, so it is deliberately NOT rebound from the return value:
    # a failed lookup must not discard the markers collected from earlier
    # paths, which later ``eqx.nn.State`` lookups depend on.
    state_indices: dict = {}

    for keys, _ in jax.tree.leaves_with_path(pytree):
        node, _ = _get_node(pytree, keys, state_indices)
        if node is not None and filter(node):
            jaxfields.append(JaxField(path=keys, shape=node.shape))

    if model_order is not None:
        by_path = {jax.tree_util.keystr(field.path): field for field in jaxfields}
        # ``pop`` removes each placed field, so duplicates in ``model_order``
        # are only honoured once and the leftovers keep their order.
        ordered = [by_path.pop(p) for p in model_order if p in by_path]
        ordered.extend(by_path.values())
        jaxfields = ordered

    return jaxfields, state_indices
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def _chunkify_state_dict(
    state_dict: dict[str, np.ndarray], target_path: str
) -> ChunkifiedStatedictPath:
    """
    Write each array of a (NumPy-converted) state dict to its own ``.npy`` file.

    Every entry that has a ``shape`` is saved with ``np.save`` as
    ``<target_path>/state_dict/<key>``; ``np.save`` appends the ``.npy``
    suffix. The ``state_dict`` sub-directory is created on first use.
    Entries without ``shape`` are skipped.

    Args:
        state_dict (dict[str, np.ndarray]): The arrays to write, keyed by name.
        target_path (str): Existing directory that will hold ``state_dict/``.

    Returns:
        ChunkifiedStatedictPath: Wrapper whose ``path`` is ``target_path``.
    """
    root = pathlib.Path(target_path)
    chunk_dir = root / "state_dict"

    for name, array in state_dict.items():
        if not hasattr(array, "shape"):
            continue
        if not os.path.exists(chunk_dir):
            os.mkdir(chunk_dir)
        np.save(chunk_dir / name, array)

    return ChunkifiedStatedictPath(path=str(root))
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
def convert(
    state_dict: dict[str, Any],
    pytree: PyTree,
    jaxfields: list[JaxField],
    state_indices: dict | None,
    torchfields: list[TorchField],
    dtype: Any | None = None,
) -> PyTree:
    """Copy the arrays of a PyTorch ``state_dict`` into ``pytree``.

    Fields are paired by position: the i-th non-skipped ``TorchField`` is
    written into the leaf described by the i-th non-skipped ``JaxField``.
    The state dict is first spilled to a temporary directory as ``.npy``
    files, and each array is loaded back individually when it is needed.

    Args:
        state_dict: Mapping of names to tensors. Every value must support
            ``.detach().cpu().numpy()``.
        pytree: Target pytree whose leaves are replaced.
        jaxfields: Target leaves, in pairing order.
        state_indices: ``marker -> StateIndex`` mapping used to resolve
            ``eqx.nn.State`` leaves (as returned by ``pytree_to_fields``).
        torchfields: Source state-dict entries, in pairing order.
        dtype: Dtype that floating-point arrays are cast to. Defaults to
            ``jnp.float64`` when ``jax_enable_x64`` is set, else ``jnp.float32``.

    Returns:
        The pytree with every paired leaf replaced.

    Raises:
        ValueError: If the numbers of non-skipped fields differ, or a pair of
            fields has a different number of elements.
    """
    if dtype is None:
        dtype = _default_floating_dtype()
    assert dtype is not None

    # Honour the ``skip`` flag on both sides before pairing fields by
    # position; the reshape error below tells users they can mark arrays as
    # skippable. The caller's lists are not modified.
    torchfields = [t for t in torchfields if not t.skip]
    jaxfields = [j for j in jaxfields if not j.skip]

    # Validate the pairing before the potentially expensive conversion.
    if len(torchfields) != len(jaxfields):
        raise ValueError(
            f"Length of state_dict ({len(torchfields)}) "
            f"!= length of pytree ({len(jaxfields)})"
        )

    # ``Tensor.numpy()`` only works for CPU tensors; ``.cpu()`` is a no-op for
    # tensors already on the CPU and moves tensors on other devices there.
    state_dict_np: dict[str, np.ndarray] = {
        k: state_dict[k].detach().cpu().numpy() for k in state_dict
    }

    for k in state_dict_np:
        if np.issubdtype(state_dict_np[k].dtype, np.floating):
            state_dict_np[k] = state_dict_np[k].astype(dtype)

    with tempfile.TemporaryDirectory() as tmpdir:
        chunkified_statedict_path = _chunkify_state_dict(state_dict_np, tmpdir)
        # Release this function's references to the converted arrays; from
        # here on each array is loaded from disk only when it is needed.
        del state_dict_np, state_dict
        state_dict_dir = pathlib.Path(chunkified_statedict_path.path) / "state_dict"
        for t, j in tqdm(zip(torchfields, jaxfields), total=len(torchfields)):
            if not _can_reshape(t.shape, j.shape):
                raise ValueError(
                    f"Cannot reshape {t.shape} "
                    f"into shape {j.shape}. "
                    "Note that the order of the fields matters "
                    "and that you can mark arrays as skippable. "
                    f"{t.path=} "
                    f"{jax.tree_util.keystr(j.path)=}"
                )
            filename = state_dict_dir / t.path
            new_value = jnp.array(np.load(str(filename) + ".npy"))

            n, _ = _get_node(pytree, j.path, state_indices)
            assert n is not None, f"Node {j.path} not found"
            assert _can_reshape(n.shape, new_value.shape), (
                f"Cannot reshape {n.shape} into {new_value.shape}"
            )

            pytree = _replace_node(pytree, j.path, new_value, state_indices)

    return pytree
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def autoconvert(
    pytree: PyTree, state_dict: dict, pytree_model_order: list[str] | None = None
) -> PyTree:
    """Load ``state_dict`` into ``pytree`` using automatically derived fields.

    The source fields come from ``state_dict_to_fields`` and the target
    fields from ``pytree_to_fields``, optionally ordered by
    ``pytree_model_order``. Both are handed to ``convert`` with its default
    dtype.

    Args:
        pytree: Target pytree.
        state_dict: PyTorch state dict supplying the values.
        pytree_model_order: Optional ``keystr`` path strings that fix the
            order of the pytree fields.

    Returns:
        The pytree with its matched leaves replaced.
    """
    torch_side = state_dict_to_fields(state_dict)
    jax_side, indices = pytree_to_fields(pytree, pytree_model_order)
    return convert(
        state_dict=state_dict,
        pytree=pytree,
        jaxfields=jax_side,
        state_indices=indices,
        torchfields=torch_side,
    )
|
|
@@ -1,58 +1,41 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version: 0.6.0
|
|
3
|
+
Version: 1.0.1
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
-
Requires-Python:
|
|
7
|
-
Requires-Dist: anthropic
|
|
6
|
+
Requires-Python: >=3.11
|
|
8
7
|
Requires-Dist: beartype
|
|
9
8
|
Requires-Dist: equinox
|
|
10
|
-
Requires-Dist: flask
|
|
11
9
|
Requires-Dist: jax
|
|
12
10
|
Requires-Dist: jaxlib
|
|
13
|
-
Requires-Dist: jaxonmodels
|
|
14
11
|
Requires-Dist: jaxtyping
|
|
15
|
-
Requires-Dist: loguru
|
|
16
|
-
Requires-Dist: penzai
|
|
17
12
|
Requires-Dist: pydantic
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist: python-dotenv
|
|
20
|
-
Requires-Dist: torch
|
|
21
|
-
Requires-Dist: torchvision
|
|
22
|
-
Requires-Dist: typing-extensions
|
|
13
|
+
Requires-Dist: tqdm
|
|
23
14
|
Provides-Extra: dev
|
|
24
15
|
Requires-Dist: mkdocs; extra == 'dev'
|
|
25
|
-
Requires-Dist: nox; extra == 'dev'
|
|
26
16
|
Requires-Dist: pre-commit; extra == 'dev'
|
|
27
17
|
Requires-Dist: pytest; extra == 'dev'
|
|
18
|
+
Requires-Dist: torch; extra == 'dev'
|
|
28
19
|
Provides-Extra: examples
|
|
29
20
|
Requires-Dist: jaxonmodels; extra == 'examples'
|
|
30
21
|
Description-Content-Type: text/markdown
|
|
31
22
|
|
|
32
23
|
# statedict2pytree
|
|
33
24
|
|
|
34
|
-

|
|
35
25
|
|
|
36
|
-
##
|
|
26
|
+
## Update:
|
|
37
27
|
|
|
38
|
-
|
|
28
|
+
For examples of how to use `statedict2pytree`, check out my other repository [jaxonmodels](https://github.com/Artur-Galstyan/jaxonmodels).
|
|
39
29
|
|
|
30
|
+
## Docs
|
|
40
31
|
|
|
41
|
-
|
|
32
|
+
Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
|
|
42
33
|
|
|
43
|
-
This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
|
|
44
|
-
PRs and other contributions are *highly* welcome! :)
|
|
45
34
|
|
|
46
35
|
## Info
|
|
47
36
|
|
|
48
|
-
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees
|
|
37
|
+
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees, specifically for Equinox.
|
|
49
38
|
|
|
50
|
-
## Features
|
|
51
|
-
|
|
52
|
-
- Convert PyTorch statedicts to JAX pytrees
|
|
53
|
-
- Handle large models with chunked file conversion
|
|
54
|
-
- Provide an "intuitive-ish" UI for parameter mapping
|
|
55
|
-
- Support both in-memory and file-based conversions
|
|
56
39
|
|
|
57
40
|
## Installation
|
|
58
41
|
|
|
@@ -64,7 +47,6 @@ The goal of this package is to simplify the conversion from PyTorch models into
|
|
|
64
47
|
|
|
65
48
|
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
66
49
|
|
|
67
|
-
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
68
50
|
|
|
69
51
|
## Shape Matching? What's that?
|
|
70
52
|
|
|
@@ -73,8 +55,8 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
|
|
|
73
55
|
(8, 1, 1) and (8, ) match, because $8 \times 1 \times 1 = 8$
|
|
74
56
|
|
|
75
57
|
|
|
76
|
-
|
|
77
58
|
### Disclaimer
|
|
78
59
|
|
|
79
60
|
Some of the docstrings and the docs have been written with the help of
|
|
80
61
|
Claude.
|
|
62
|
+
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
|
|
2
|
+
statedict2pytree/converter.py,sha256=Toehra2_guidbR5Gyl1uUi1YYv37jkoSKDQjxo1tzwg,8626
|
|
3
|
+
statedict2pytree-1.0.1.dist-info/METADATA,sha256=Z_tRmXwxVMGUMrOYm3SiCxnl4L9OUINZMrBqxbEg5dE,1918
|
|
4
|
+
statedict2pytree-1.0.1.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
5
|
+
statedict2pytree-1.0.1.dist-info/RECORD,,
|
client/.gitignore
DELETED