statedict2pytree 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/__init__.py +3 -0
- statedict2pytree/statedict2pytree.py +21 -4
- {statedict2pytree-0.2.0.dist-info → statedict2pytree-0.4.0.dist-info}/METADATA +22 -1
- statedict2pytree-0.4.0.dist-info/RECORD +8 -0
- statedict2pytree-0.2.0.dist-info/RECORD +0 -8
- {statedict2pytree-0.2.0.dist-info → statedict2pytree-0.4.0.dist-info}/WHEEL +0 -0
statedict2pytree/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ import re
|
|
|
4
4
|
import equinox as eqx
|
|
5
5
|
import flask
|
|
6
6
|
import jax
|
|
7
|
+
import numpy as np
|
|
7
8
|
from beartype.typing import Optional
|
|
8
9
|
from jaxtyping import PyTree
|
|
9
10
|
from loguru import logger
|
|
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
|
|
|
31
32
|
STATE_DICT: Optional[dict] = None
|
|
32
33
|
|
|
33
34
|
|
|
35
|
+
def can_reshape(shape1, shape2):
|
|
36
|
+
product1 = np.prod(shape1)
|
|
37
|
+
product2 = np.prod(shape2)
|
|
38
|
+
|
|
39
|
+
return product1 == product2
|
|
40
|
+
|
|
41
|
+
|
|
34
42
|
def get_node(
|
|
35
43
|
tree: PyTree, targets: list[str], log_when_not_found: bool = False
|
|
36
44
|
) -> PyTree | None:
|
|
@@ -38,7 +46,7 @@ def get_node(
|
|
|
38
46
|
return tree
|
|
39
47
|
else:
|
|
40
48
|
next_target: str = targets[0]
|
|
41
|
-
if bool(re.search(r"\[\d
|
|
49
|
+
if bool(re.search(r"\[\d+\]", next_target)):
|
|
42
50
|
split_index = next_target.rfind("[")
|
|
43
51
|
name, index = next_target[:split_index], next_target[split_index:]
|
|
44
52
|
index = index[1:-1]
|
|
@@ -146,17 +154,26 @@ def index():
|
|
|
146
154
|
)
|
|
147
155
|
|
|
148
156
|
|
|
157
|
+
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
|
|
158
|
+
jax_fields = pytree_to_fields(pytree)
|
|
159
|
+
torch_fields = state_dict_to_fields(state_dict)
|
|
160
|
+
|
|
161
|
+
for k, v in state_dict.items():
|
|
162
|
+
state_dict[k] = v.numpy()
|
|
163
|
+
return convert(jax_fields, torch_fields, pytree, state_dict)
|
|
164
|
+
|
|
165
|
+
|
|
149
166
|
def convert(
|
|
150
167
|
jax_fields: list[JaxField],
|
|
151
168
|
torch_fields: list[TorchField],
|
|
152
169
|
pytree: PyTree,
|
|
153
170
|
state_dict: dict,
|
|
154
|
-
):
|
|
171
|
+
) -> tuple[PyTree, eqx.nn.State]:
|
|
155
172
|
identity = lambda *args, **kwargs: pytree
|
|
156
173
|
model, state = eqx.nn.make_with_state(identity)()
|
|
157
174
|
state_paths: list[tuple[JaxField, TorchField]] = []
|
|
158
175
|
for jax_field, torch_field in zip(jax_fields, torch_fields):
|
|
159
|
-
if jax_field.shape
|
|
176
|
+
if not can_reshape(jax_field.shape, torch_field.shape):
|
|
160
177
|
raise ValueError(
|
|
161
178
|
"Fields have incompatible shapes!"
|
|
162
179
|
f"{jax_field.shape=} != {torch_field.shape=}"
|
|
@@ -171,7 +188,7 @@ def convert(
|
|
|
171
188
|
model = eqx.tree_at(
|
|
172
189
|
where,
|
|
173
190
|
model,
|
|
174
|
-
state_dict[torch_field.path],
|
|
191
|
+
state_dict[torch_field.path].reshape(jax_field.shape),
|
|
175
192
|
)
|
|
176
193
|
result: dict[str, list[TorchField]] = {}
|
|
177
194
|
for tuple_item in state_paths:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
6
|
Requires-Python: ~=3.10
|
|
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
|
|
|
32
32
|
|
|
33
33
|
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
34
34
|
|
|
35
|
+
## Shape Matching? What's that?
|
|
36
|
+
|
|
37
|
+
Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shapes are equal, i.e. they contain the same number of elements, so that one can be reshaped into the other. For example:
|
|
38
|
+
|
|
39
|
+
1. (8, 1, 1) and (8,) match, because `8 * 1 * 1 = 8`
|
|
40
|
+
|
|
35
41
|
## Get Started
|
|
36
42
|
|
|
37
43
|
### Installation
|
|
@@ -124,3 +130,18 @@ def convert(
|
|
|
124
130
|
```
|
|
125
131
|
|
|
126
132
|
If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of matching entries must have compatible shapes (i.e. the same total number of elements)!
|
|
133
|
+
|
|
134
|
+
For the full, automatic experience, use `autoconvert`:
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
import statedict2pytree as s2p
|
|
138
|
+
|
|
139
|
+
my_model = Model(...)
|
|
140
|
+
state_dict = ...
|
|
141
|
+
|
|
142
|
+
model, state = s2p.autoconvert(my_model, state_dict)
|
|
143
|
+
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
This will, however, only work if your PyTree fields have been declared
|
|
147
|
+
in the same order as they appear in the state dict!
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
|
|
2
|
+
statedict2pytree/statedict2pytree.py,sha256=u1PddUBY_MHErl3tstSpJt2a6a_H-RizvxP9anPoFpQ,7175
|
|
3
|
+
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
|
+
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
|
+
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
+
statedict2pytree-0.4.0.dist-info/METADATA,sha256=BbJaUhVv-Qb4-e2DlX0WBNgPOKsTjlKv8ISJZgFNx1o,4232
|
|
7
|
+
statedict2pytree-0.4.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
+
statedict2pytree-0.4.0.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
|
|
2
|
-
statedict2pytree/statedict2pytree.py,sha256=SQ1Xs4VPG_gnjYdAOBzdoNih1RGVsW9x08UQGkIFUdg,6640
|
|
3
|
-
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
|
-
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
|
-
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
-
statedict2pytree-0.2.0.dist-info/METADATA,sha256=x9GhcbG0io5HIpfwocS-nsPYyKRTmG90uLKyckSm394,3673
|
|
7
|
-
statedict2pytree-0.2.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
-
statedict2pytree-0.2.0.dist-info/RECORD,,
|
|
File without changes
|