statedict2pytree 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/__init__.py +3 -0
- statedict2pytree/statedict2pytree.py +17 -3
- {statedict2pytree-0.2.0.dist-info → statedict2pytree-0.3.0.dist-info}/METADATA +22 -1
- statedict2pytree-0.3.0.dist-info/RECORD +8 -0
- statedict2pytree-0.2.0.dist-info/RECORD +0 -8
- {statedict2pytree-0.2.0.dist-info → statedict2pytree-0.3.0.dist-info}/WHEEL +0 -0
statedict2pytree/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ import re
|
|
|
4
4
|
import equinox as eqx
|
|
5
5
|
import flask
|
|
6
6
|
import jax
|
|
7
|
+
import numpy as np
|
|
7
8
|
from beartype.typing import Optional
|
|
8
9
|
from jaxtyping import PyTree
|
|
9
10
|
from loguru import logger
|
|
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
|
|
|
31
32
|
STATE_DICT: Optional[dict] = None
|
|
32
33
|
|
|
33
34
|
|
|
35
|
+
def can_reshape(shape1, shape2):
    """Return whether an array of ``shape1`` can be reshaped into ``shape2``.

    Two shapes are treated as compatible when they describe the same total
    number of elements, i.e. the products of their dimensions are equal.
    An empty shape ``()`` has a product of 1 and therefore matches ``(1,)``.

    Args:
        shape1: Iterable of dimension sizes (e.g. a tuple of ints).
        shape2: Iterable of dimension sizes (e.g. a tuple of ints).

    Returns:
        ``True`` if both shapes hold the same number of elements,
        ``False`` otherwise.
    """
    # Imported locally so the module's top-level import block stays untouched.
    import math

    # math.prod multiplies Python ints with arbitrary precision, so the
    # element count of a very large shape cannot silently wrap around the
    # way a fixed-width integer product can.
    return math.prod(shape1) == math.prod(shape2)
|
|
40
|
+
|
|
41
|
+
|
|
34
42
|
def get_node(
|
|
35
43
|
tree: PyTree, targets: list[str], log_when_not_found: bool = False
|
|
36
44
|
) -> PyTree | None:
|
|
@@ -146,17 +154,23 @@ def index():
|
|
|
146
154
|
)
|
|
147
155
|
|
|
148
156
|
|
|
157
|
+
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Convert ``state_dict`` into ``pytree`` without manual field matching.

    The field lists are derived from both inputs with ``pytree_to_fields``
    and ``state_dict_to_fields`` and passed, together with the original
    objects, to ``convert``, which pairs them positionally.

    Args:
        pytree: The target PyTree whose fields should receive the weights.
        state_dict: The source state dict providing the weights.

    Returns:
        The result of ``convert`` (annotated as a model/state tuple).
    """
    return convert(
        pytree_to_fields(pytree),
        state_dict_to_fields(state_dict),
        pytree,
        state_dict,
    )
|
|
161
|
+
|
|
162
|
+
|
|
149
163
|
def convert(
|
|
150
164
|
jax_fields: list[JaxField],
|
|
151
165
|
torch_fields: list[TorchField],
|
|
152
166
|
pytree: PyTree,
|
|
153
167
|
state_dict: dict,
|
|
154
|
-
):
|
|
168
|
+
) -> tuple[PyTree, eqx.nn.State]:
|
|
155
169
|
identity = lambda *args, **kwargs: pytree
|
|
156
170
|
model, state = eqx.nn.make_with_state(identity)()
|
|
157
171
|
state_paths: list[tuple[JaxField, TorchField]] = []
|
|
158
172
|
for jax_field, torch_field in zip(jax_fields, torch_fields):
|
|
159
|
-
if jax_field.shape
|
|
173
|
+
if not can_reshape(jax_field.shape, torch_field.shape):
|
|
160
174
|
raise ValueError(
|
|
161
175
|
"Fields have incompatible shapes!"
|
|
162
176
|
f"{jax_field.shape=} != {torch_field.shape=}"
|
|
@@ -171,7 +185,7 @@ def convert(
|
|
|
171
185
|
model = eqx.tree_at(
|
|
172
186
|
where,
|
|
173
187
|
model,
|
|
174
|
-
state_dict[torch_field.path],
|
|
188
|
+
state_dict[torch_field.path].reshape(jax_field.shape),
|
|
175
189
|
)
|
|
176
190
|
result: dict[str, list[TorchField]] = {}
|
|
177
191
|
for tuple_item in state_paths:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.3.0
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
6
|
Requires-Python: ~=3.10
|
|
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
|
|
|
32
32
|
|
|
33
33
|
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
34
34
|
|
|
35
|
+
## Shape Matching? What's that?
|
|
36
|
+
|
|
37
|
+
Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal, i.e. if they contain the same number of elements. For example:
|
|
38
|
+
|
|
39
|
+
1. (8, 1, 1) and (8,) match, because 8 \* 1 \* 1 = 8
|
|
40
|
+
|
|
35
41
|
## Get Started
|
|
36
42
|
|
|
37
43
|
### Installation
|
|
@@ -124,3 +130,18 @@ def convert(
|
|
|
124
130
|
```
|
|
125
131
|
|
|
126
132
|
If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length and each matching entry must have the same shape!
|
|
133
|
+
|
|
134
|
+
For the full, automatic experience, use `autoconvert`:
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
import statedict2pytree as s2p
|
|
138
|
+
|
|
139
|
+
my_model = Model(...)
|
|
140
|
+
state_dict = ...
|
|
141
|
+
|
|
142
|
+
model, state = s2p.autoconvert(my_model, state_dict)
|
|
143
|
+
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
This will, however, only work if your PyTree fields have been declared
|
|
147
|
+
in the same order as they appear in the state dict!
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
|
|
2
|
+
statedict2pytree/statedict2pytree.py,sha256=X5Ljf4lYhhH7_V4KgdciChncbTt7YZpIWHcOxcZ3l48,7103
|
|
3
|
+
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
|
+
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
|
+
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
+
statedict2pytree-0.3.0.dist-info/METADATA,sha256=YSK4tWzNQemyZ1xKq5BhWiLWWc-RDr4E9q_eV_iOsdw,4232
|
|
7
|
+
statedict2pytree-0.3.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
+
statedict2pytree-0.3.0.dist-info/RECORD,,
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
|
|
2
|
-
statedict2pytree/statedict2pytree.py,sha256=SQ1Xs4VPG_gnjYdAOBzdoNih1RGVsW9x08UQGkIFUdg,6640
|
|
3
|
-
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
|
-
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
|
-
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
-
statedict2pytree-0.2.0.dist-info/METADATA,sha256=x9GhcbG0io5HIpfwocS-nsPYyKRTmG90uLKyckSm394,3673
|
|
7
|
-
statedict2pytree-0.2.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
-
statedict2pytree-0.2.0.dist-info/RECORD,,
|
|
File without changes
|