statedict2pytree 0.6.0__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,8 @@
1
+ from .converter import (
2
+ autoconvert, # noqa
3
+ convert, # noqa
4
+ is_numerical, # noqa
5
+ move_running_fields_to_the_end, # noqa
6
+ pytree_to_fields, # noqa
7
+ state_dict_to_fields, # noqa
8
+ ) # noqa
@@ -0,0 +1,293 @@
1
+ import os
2
+ import pathlib
3
+ import tempfile
4
+
5
+ import equinox as eqx
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from beartype.typing import Any, Callable
10
+ from jax.tree_util import FlattenedIndexKey, GetAttrKey, KeyPath, SequenceKey
11
+ from jaxtyping import Array, PyTree
12
+ from pydantic import BaseModel
13
+ from tqdm import tqdm
14
+
15
+
16
class ChunkifiedPytreePath(BaseModel):
    """Pydantic model wrapping a single string ``path``.

    NOTE(review): the name suggests the path points at a pytree stored in
    chunks, but nothing in the visible code constructs or reads this model --
    confirm against callers.
    """

    # Location, kept as a plain string.
    path: str
18
+
19
+
20
class ChunkifiedStatedictPath(BaseModel):
    """Pydantic model describing where a chunked state dict was written.

    ``_chunkify_state_dict`` returns this with ``path`` set to the target
    directory; the individual ``.npy`` files live in its ``state_dict``
    sub-directory.
    """

    # Base directory (as a string) that contains the "state_dict" folder.
    path: str
22
+
23
+
24
class TorchField(BaseModel):
    """Description of one array-valued entry of a PyTorch state dict."""

    # Key of the entry inside the state dict.
    path: str
    # Shape of the stored tensor.
    shape: tuple[int, ...]
    # Defaults to False. NOTE(review): not consulted by the conversion code
    # visible in this file -- confirm where it is honoured.
    skip: bool = False
28
+
29
+
30
class JaxField(BaseModel):
    """Description of one array leaf of a JAX pytree."""

    # JAX key path that leads from the pytree root to the leaf.
    path: KeyPath
    # Shape of the leaf array.
    shape: tuple[int, ...]
    # Defaults to False. NOTE(review): not consulted by the conversion code
    # visible in this file -- confirm where it is honoured.
    skip: bool = False
34
+
35
+
36
def is_numerical(element: Any):
    """Tell whether ``element`` has an integer, floating or complex dtype.

    Objects without a ``dtype`` attribute are never numerical, and boolean
    dtypes are explicitly rejected.

    Args:
        element: Any object; typically a JAX or NumPy array.

    Returns:
        bool: True for integer, floating and complex dtypes, False otherwise.
    """
    if not hasattr(element, "dtype"):
        return False

    dtype = element.dtype
    numeric_kinds = (np.integer, np.floating, np.complexfloating)
    has_numeric_kind = any(np.issubdtype(dtype, kind) for kind in numeric_kinds)
    return has_numeric_kind and not np.issubdtype(dtype, np.bool_)
45
+
46
+
47
+ def _default_floating_dtype():
48
+ if jax.config.jax_enable_x64: # pyright: ignore
49
+ return jnp.float64
50
+ else:
51
+ return jnp.float32
52
+
53
+
54
+ def _can_reshape(shape1: tuple, shape2: tuple):
55
+ """
56
+ Check if two shapes can be reshaped to each other.
57
+
58
+ Args:
59
+ shape1 (tuple): First shape.
60
+ shape2 (tuple): Second shape.
61
+
62
+ Returns:
63
+ bool: True if shapes can be reshaped to each other, False otherwise.
64
+ """
65
+ product1 = np.prod(shape1)
66
+ product2 = np.prod(shape2)
67
+
68
+ return product1 == product2
69
+
70
+
71
+ def _get_stateindex_fields(obj) -> dict:
72
+ state_indices = {}
73
+
74
+ for attr_name in dir(obj):
75
+ if attr_name.startswith("_"):
76
+ continue
77
+
78
+ try:
79
+ attr_value = getattr(obj, attr_name)
80
+ if isinstance(attr_value, eqx.nn.StateIndex):
81
+ state_indices[attr_name] = attr_value
82
+ except: # noqa
83
+ pass
84
+
85
+ return state_indices
86
+
87
+
88
+ def _get_node(
89
+ tree: PyTree, path: KeyPath, state_indices: dict | None = None
90
+ ) -> tuple[PyTree | None, dict | None]:
91
+ if tree is None:
92
+ return None, {}
93
+ else:
94
+ if len(path) == 0:
95
+ return tree, state_indices
96
+ f, *_ = path
97
+ if hasattr(tree, "is_stateful"):
98
+ if state_indices is None:
99
+ state_indices = {}
100
+ indices = _get_stateindex_fields(tree)
101
+ for attr_name in indices:
102
+ index: eqx.nn.StateIndex = indices[attr_name]
103
+ assert isinstance(index, eqx.nn.StateIndex)
104
+ state_indices[index.marker] = index
105
+ if isinstance(f, SequenceKey):
106
+ subtree = tree[f.idx]
107
+ elif isinstance(f, GetAttrKey):
108
+ subtree = getattr(tree, f.name)
109
+ elif isinstance(f, FlattenedIndexKey):
110
+ if isinstance(tree, eqx.nn.State):
111
+ assert state_indices is not None
112
+ index = state_indices[f.key]
113
+ subtree = tree.get(index)
114
+ else:
115
+ subtree = None
116
+ else:
117
+ subtree = None
118
+ return _get_node(subtree, path[1:], state_indices)
119
+
120
+
121
def _replace_node(
    tree: PyTree, path: KeyPath, new_value: Array, state_indices: dict | None = None
) -> PyTree:
    """Return ``tree`` with the array at ``path`` replaced by ``new_value``.

    ``new_value`` is reshaped to the shape of the array it replaces. If the
    path does not lead to an array, a warning is printed and ``tree`` is
    returned unchanged.

    Args:
        tree: Pytree to update.
        path: Key path of the leaf to replace.
        new_value: Replacement data.
        state_indices: State-index markers forwarded to ``_get_node``.

    Returns:
        The updated pytree, or the original one if nothing was replaced.
    """
    target, _ = _get_node(tree, path=path, state_indices=state_indices)

    if target is None or not eqx.is_array(target):
        print("WARNING: Couldn't find: ", jax.tree_util.keystr(path))
        return tree

    def select_target(candidate):
        # ``eqx.tree_at`` needs a function that picks the node to swap out.
        found, _ = _get_node(candidate, path=path, state_indices=state_indices)
        return found

    return eqx.tree_at(select_target, tree, new_value.reshape(target.shape))
139
+
140
+
141
def move_running_fields_to_the_end(
    torchfields: list[TorchField], identifier: str = "running_"
):
    """
    Move every field whose path contains ``identifier`` to the end of the list.

    Helpful for stateful layers such as BatchNorm, whose state appears at the
    end of Equinox pytrees. The relative order inside both groups (moved and
    not moved) is preserved.

    Args:
        torchfields: Fields to reorder. The list is modified in place.
        identifier: Substring that marks a field as one to move.

    Returns:
        The same list object, reordered.
    """
    kept = []
    moved = []
    for field in torchfields:
        if identifier in field.path:
            moved.append(field)
        else:
            kept.append(field)

    # Slice assignment keeps the in-place contract without the quadratic
    # cost of repeatedly popping from the middle of the list.
    torchfields[:] = kept + moved
    return torchfields
159
+
160
+
161
def state_dict_to_fields(
    state_dict: dict[str, Any],
) -> list[TorchField]:
    """Describe every non-scalar array of a state dict as a ``TorchField``.

    Entries without a ``shape`` attribute and entries with an empty shape
    (zero-dimensional values) are left out.

    Args:
        state_dict: Mapping of parameter name to tensor; ``None`` is accepted.

    Returns:
        One ``TorchField`` per remaining entry, in state-dict order. An empty
        list when ``state_dict`` is ``None``.
    """
    if state_dict is None:
        return []

    return [
        TorchField(path=name, shape=tuple(tensor.shape))
        for name, tensor in state_dict.items()
        if hasattr(tensor, "shape") and len(tensor.shape) > 0
    ]
171
+
172
+
173
def pytree_to_fields(
    pytree: PyTree,
    model_order: list[str] | None = None,
    filter: Callable[[Array], bool] = eqx.is_array,
) -> tuple[list[JaxField], dict | None]:
    """Describe the leaves of ``pytree`` that pass ``filter`` as ``JaxField``s.

    Args:
        pytree: Pytree to inspect.
        model_order: Optional list of key-path strings (as produced by
            ``jax.tree_util.keystr``). Fields named here come first, in this
            order; all remaining fields follow in their original order.
        filter: Predicate selecting which leaves to keep.

    Returns:
        ``(jaxfields, state_indices)`` where ``state_indices`` holds the
        state-index markers collected while walking the tree.
    """
    # ``_get_node`` updates a non-None mapping in place, so one dict is shared
    # across all lookups. Its second return value is ignored on purpose: for
    # a path it cannot resolve it may hand back a fresh empty dict, which
    # would drop every marker collected so far.
    state_indices: dict = {}
    jaxfields = []
    for keys, _ in jax.tree.leaves_with_path(pytree):
        node, _ = _get_node(pytree, keys, state_indices)
        if node is not None and filter(node):
            jaxfields.append(JaxField(path=keys, shape=node.shape))

    if model_order is not None:
        remaining = {jax.tree_util.keystr(field.path): field for field in jaxfields}
        ordered_jaxfields = [
            remaining.pop(path_str)
            for path_str in model_order
            if path_str in remaining
        ]
        # Anything not mentioned in ``model_order`` keeps its relative order.
        ordered_jaxfields.extend(remaining.values())
        jaxfields = ordered_jaxfields

    return jaxfields, state_indices
198
+
199
+
200
def _chunkify_state_dict(
    state_dict: dict[str, np.ndarray], target_path: str
) -> ChunkifiedStatedictPath:
    """
    Save every array of a state dict as its own ``.npy`` file.

    The files are written to ``<target_path>/state_dict/<key>.npy``. Entries
    without a ``shape`` attribute are skipped.

    Args:
        state_dict (dict[str, np.ndarray]): The state dict to be chunked.
        target_path (str): The directory where chunked files will be saved.

    Returns:
        ChunkifiedStatedictPath: Wraps ``target_path`` (not the ``state_dict``
        sub-directory).
    """
    chunk_dir = pathlib.Path(target_path) / "state_dict"
    # Create the output directory once, up front, rather than checking on
    # every iteration. ``parents``/``exist_ok`` make the call safe when
    # ``target_path`` is missing or the directory already exists.
    chunk_dir.mkdir(parents=True, exist_ok=True)

    for key, value in state_dict.items():
        if not hasattr(value, "shape"):
            continue
        # ``np.save`` appends the ".npy" suffix itself.
        np.save(chunk_dir / key, value)

    return ChunkifiedStatedictPath(path=str(pathlib.Path(target_path)))
224
+
225
+
226
def convert(
    state_dict: dict[str, Any],
    pytree: PyTree,
    jaxfields: list[JaxField],
    state_indices: dict | None,
    torchfields: list[TorchField],
    dtype: Any | None = None,
) -> PyTree:
    """Copy the weights of a PyTorch state dict into a JAX pytree.

    ``torchfields`` and ``jaxfields`` are matched position by position: the
    i-th state-dict entry is reshaped and written to the i-th pytree leaf.
    The arrays are staged as ``.npy`` files in a temporary directory and
    loaded back one at a time.

    Args:
        state_dict: Mapping of parameter name to tensor. Every value must
            support ``.detach().cpu().numpy()``.
        pytree: Pytree that receives the weights.
        jaxfields: Target leaves, in assignment order.
        state_indices: State-index markers needed to resolve leaves inside an
            ``eqx.nn.State``.
        torchfields: Source entries, in assignment order.
        dtype: Dtype that floating-point arrays are cast to. Defaults to
            float64 when JAX runs with x64 enabled, float32 otherwise.

    Returns:
        The pytree with all listed leaves replaced.

    Raises:
        ValueError: If the two field lists differ in length, or a pair of
            fields does not hold the same number of elements.
    """
    if dtype is None:
        dtype = _default_floating_dtype()
    assert dtype is not None

    # Fail fast: check the pairing before paying for the tensor conversion.
    if len(torchfields) != len(jaxfields):
        raise ValueError(
            f"Length of state_dict ({len(torchfields)}) "
            f"!= length of pytree ({len(jaxfields)})"
        )

    state_dict_np: dict[str, np.ndarray] = {}
    for key, tensor in state_dict.items():
        # ``.cpu()`` is required for tensors living on an accelerator;
        # ``.numpy()`` only works on CPU tensors.
        array = tensor.detach().cpu().numpy()
        if np.issubdtype(array.dtype, np.floating):
            array = array.astype(dtype)
        state_dict_np[key] = array

    with tempfile.TemporaryDirectory() as tmpdir:
        chunkified_statedict_path = _chunkify_state_dict(state_dict_np, tmpdir)
        # Drop the in-memory copies; from here on the arrays are read from disk.
        del state_dict_np, state_dict

        state_dict_dir = pathlib.Path(chunkified_statedict_path.path) / "state_dict"
        for t, j in tqdm(zip(torchfields, jaxfields), total=len(torchfields)):
            if not _can_reshape(t.shape, j.shape):
                raise ValueError(
                    f"Cannot reshape {t.shape} "
                    f"into shape {j.shape}. "
                    "Note that the order of the fields matters "
                    "and that you can mark arrays as skippable. "
                    f"{t.path=} "
                    f"{jax.tree_util.keystr(j.path)=}"
                )
            filename = state_dict_dir / t.path
            new_value = jnp.array(np.load(str(filename) + ".npy"))

            n, _ = _get_node(pytree, j.path, state_indices)
            assert n is not None, f"Node {j.path} not found"
            assert _can_reshape(n.shape, new_value.shape), (
                f"Cannot reshape {n.shape} into {new_value.shape}"
            )

            pytree = _replace_node(pytree, j.path, new_value, state_indices)

    return pytree
277
+
278
+
279
def autoconvert(
    pytree: PyTree, state_dict: dict, pytree_model_order: list[str] | None = None
) -> PyTree:
    """Convert ``state_dict`` into ``pytree`` using automatically derived fields.

    The state dict and the pytree are each turned into field lists, which are
    then paired up in order by ``convert``.

    Args:
        pytree: Pytree that receives the weights.
        state_dict: PyTorch state dict providing the weights.
        pytree_model_order: Optional ordering of the pytree's fields, passed
            on to ``pytree_to_fields``.

    Returns:
        The pytree with its array leaves replaced by the state-dict values.
    """
    torchfields = state_dict_to_fields(state_dict)
    jaxfields, state_indices = pytree_to_fields(pytree, pytree_model_order)

    return convert(
        state_dict=state_dict,
        pytree=pytree,
        jaxfields=jaxfields,
        state_indices=state_indices,
        torchfields=torchfields,
    )
@@ -1,58 +1,41 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.6.0
3
+ Version: 1.0.1
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
- Requires-Python: ~=3.10
7
- Requires-Dist: anthropic
6
+ Requires-Python: >=3.11
8
7
  Requires-Dist: beartype
9
8
  Requires-Dist: equinox
10
- Requires-Dist: flask
11
9
  Requires-Dist: jax
12
10
  Requires-Dist: jaxlib
13
- Requires-Dist: jaxonmodels
14
11
  Requires-Dist: jaxtyping
15
- Requires-Dist: loguru
16
- Requires-Dist: penzai
17
12
  Requires-Dist: pydantic
18
- Requires-Dist: pytest
19
- Requires-Dist: python-dotenv
20
- Requires-Dist: torch
21
- Requires-Dist: torchvision
22
- Requires-Dist: typing-extensions
13
+ Requires-Dist: tqdm
23
14
  Provides-Extra: dev
24
15
  Requires-Dist: mkdocs; extra == 'dev'
25
- Requires-Dist: nox; extra == 'dev'
26
16
  Requires-Dist: pre-commit; extra == 'dev'
27
17
  Requires-Dist: pytest; extra == 'dev'
18
+ Requires-Dist: torch; extra == 'dev'
28
19
  Provides-Extra: examples
29
20
  Requires-Dist: jaxonmodels; extra == 'examples'
30
21
  Description-Content-Type: text/markdown
31
22
 
32
23
  # statedict2pytree
33
24
 
34
- ![statedict2pytree](statedict2pytree.png "A ResNet demo")
35
25
 
36
- ## Docs
26
+ ## Update:
37
27
 
38
- Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
28
+ For examples of how to use `statedict2pytree`, check out my other repository [jaxonmodels](https://github.com/Artur-Galstyan/jaxonmodels).
39
29
 
30
+ ## Docs
40
31
 
41
- ## Important
32
+ Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
42
33
 
43
- This package is still in its infancy and hihgly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
44
- PRs and other contributions are *highly* welcome! :)
45
34
 
46
35
  ## Info
47
36
 
48
- `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
37
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees, specifically for Equinox.
49
38
 
50
- ## Features
51
-
52
- - Convert PyTorch statedicts to JAX pytrees
53
- - Handle large models with chunked file conversion
54
- - Provide an "intuitive-ish" UI for parameter mapping
55
- - Support both in-memory and file-based conversions
56
39
 
57
40
  ## Installation
58
41
 
@@ -64,7 +47,6 @@ The goal of this package is to simplify the conversion from PyTorch models into
64
47
 
65
48
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
66
49
 
67
- (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
68
50
 
69
51
  ## Shape Matching? What's that?
70
52
 
@@ -73,8 +55,8 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
73
55
  (8, 1, 1) and (8, ) match, because (8 * 1 * 1 = 8)
74
56
 
75
57
 
76
-
77
58
  ### Disclaimer
78
59
 
79
60
  Some of the docstrings and the docs have been written with the help of
80
61
  Claude.
62
+
@@ -0,0 +1,5 @@
1
+ statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
2
+ statedict2pytree/converter.py,sha256=Toehra2_guidbR5Gyl1uUi1YYv37jkoSKDQjxo1tzwg,8626
3
+ statedict2pytree-1.0.1.dist-info/METADATA,sha256=Z_tRmXwxVMGUMrOYm3SiCxnl4L9OUINZMrBqxbEg5dE,1918
4
+ statedict2pytree-1.0.1.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
5
+ statedict2pytree-1.0.1.dist-info/RECORD,,
client/.gitignore DELETED
@@ -1,3 +0,0 @@
1
- .DS_Store
2
- node_modules
3
- public/bundle.*