statedict2pytree 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/__init__.py +3 -0
- statedict2pytree/statedict2pytree.py +30 -9
- statedict2pytree-0.3.0.dist-info/METADATA +147 -0
- statedict2pytree-0.3.0.dist-info/RECORD +8 -0
- statedict2pytree-0.1.2.dist-info/METADATA +0 -43
- statedict2pytree-0.1.2.dist-info/RECORD +0 -8
- {statedict2pytree-0.1.2.dist-info → statedict2pytree-0.3.0.dist-info}/WHEEL +0 -0
statedict2pytree/__init__.py
CHANGED
|
@@ -4,6 +4,7 @@ import re
|
|
|
4
4
|
import equinox as eqx
|
|
5
5
|
import flask
|
|
6
6
|
import jax
|
|
7
|
+
import numpy as np
|
|
7
8
|
from beartype.typing import Optional
|
|
8
9
|
from jaxtyping import PyTree
|
|
9
10
|
from loguru import logger
|
|
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
|
|
|
31
32
|
STATE_DICT: Optional[dict] = None
|
|
32
33
|
|
|
33
34
|
|
|
35
|
+
def can_reshape(shape1, shape2):
    """Return True if an array of ``shape1`` can be reshaped to ``shape2``.

    Two shapes are reshape-compatible when they describe the same number of
    elements, i.e. the products of their dimensions are equal. An empty shape
    ``()`` (a scalar) counts as one element.

    Args:
        shape1: Sequence of non-negative ints (e.g. a tuple from ``.shape``).
        shape2: Sequence of non-negative ints.

    Returns:
        ``True`` if both shapes hold the same number of elements.
    """
    # math.prod works on Python ints with arbitrary precision. np.prod would use
    # the platform's default C integer, which can silently overflow (e.g. int32
    # on Windows with NumPy < 2) for large embedding/weight matrices and yield a
    # wrong match/mismatch.
    import math

    return math.prod(shape1) == math.prod(shape2)
|
|
40
|
+
|
|
41
|
+
|
|
34
42
|
def get_node(
|
|
35
43
|
tree: PyTree, targets: list[str], log_when_not_found: bool = False
|
|
36
44
|
) -> PyTree | None:
|
|
@@ -95,7 +103,7 @@ def visualize_with_penzai():
|
|
|
95
103
|
return flask.jsonify({"error": "No data received"})
|
|
96
104
|
jax_fields = request_data["jaxFields"]
|
|
97
105
|
torch_fields = request_data["torchFields"]
|
|
98
|
-
model, state = convert(jax_fields, torch_fields)
|
|
106
|
+
model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
|
|
99
107
|
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
|
|
100
108
|
html_jax = pz.ts.render_to_html((model, state))
|
|
101
109
|
html_torch = pz.ts.render_to_html(STATE_DICT)
|
|
@@ -130,7 +138,7 @@ def convert_torch_to_jax():
|
|
|
130
138
|
torch_fields.append(TorchField(path=f["path"], shape=shape_tuple))
|
|
131
139
|
|
|
132
140
|
name = request_data["name"]
|
|
133
|
-
model, state = convert(jax_fields, torch_fields)
|
|
141
|
+
model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
|
|
134
142
|
eqx.tree_serialise_leaves(name, (model, state))
|
|
135
143
|
|
|
136
144
|
return flask.jsonify({"status": "success"})
|
|
@@ -146,14 +154,27 @@ def index():
|
|
|
146
154
|
)
|
|
147
155
|
|
|
148
156
|
|
|
149
|
-
def
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
157
|
+
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Load a PyTorch state dict into ``pytree`` without manual field alignment.

    The PyTree is flattened with ``pytree_to_fields`` and the state dict with
    ``state_dict_to_fields``. The two field lists are then paired up, in the
    order those helpers return them, by ``convert``. The result is therefore
    only correct when the PyTree's fields are declared in the same order as
    the corresponding entries appear in the state dict.

    Args:
        pytree: The target (e.g. Equinox) model whose leaves are replaced.
        state_dict: The PyTorch state dict supplying the weights.

    Returns:
        The ``(model, state)`` pair produced by ``convert``.
    """
    # Keyword arguments keep the positional order of the original call, so the
    # JAX fields are still computed before the torch fields.
    return convert(
        jax_fields=pytree_to_fields(pytree),
        torch_fields=state_dict_to_fields(state_dict),
        pytree=pytree,
        state_dict=state_dict,
    )
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def convert(
|
|
164
|
+
jax_fields: list[JaxField],
|
|
165
|
+
torch_fields: list[TorchField],
|
|
166
|
+
pytree: PyTree,
|
|
167
|
+
state_dict: dict,
|
|
168
|
+
) -> tuple[PyTree, eqx.nn.State]:
|
|
169
|
+
identity = lambda *args, **kwargs: pytree
|
|
154
170
|
model, state = eqx.nn.make_with_state(identity)()
|
|
155
171
|
state_paths: list[tuple[JaxField, TorchField]] = []
|
|
156
172
|
for jax_field, torch_field in zip(jax_fields, torch_fields):
|
|
173
|
+
if not can_reshape(jax_field.shape, torch_field.shape):
|
|
174
|
+
raise ValueError(
|
|
175
|
+
"Fields have incompatible shapes!"
|
|
176
|
+
f"{jax_field.shape=} != {torch_field.shape=}"
|
|
177
|
+
)
|
|
157
178
|
path = jax_field.path.split(".")[1:]
|
|
158
179
|
if "StateIndex" in jax_field.type:
|
|
159
180
|
state_paths.append((jax_field, torch_field))
|
|
@@ -164,7 +185,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
|
|
|
164
185
|
model = eqx.tree_at(
|
|
165
186
|
where,
|
|
166
187
|
model,
|
|
167
|
-
|
|
188
|
+
state_dict[torch_field.path].reshape(jax_field.shape),
|
|
168
189
|
)
|
|
169
190
|
result: dict[str, list[TorchField]] = {}
|
|
170
191
|
for tuple_item in state_paths:
|
|
@@ -177,7 +198,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
|
|
|
177
198
|
for key in result:
|
|
178
199
|
state_index = get_node(model, key.split("."))
|
|
179
200
|
if state_index is not None:
|
|
180
|
-
to_replace_tuple = tuple([
|
|
201
|
+
to_replace_tuple = tuple([state_dict[i.path] for i in result[key]])
|
|
181
202
|
state = state.set(state_index, to_replace_tuple)
|
|
182
203
|
return model, state
|
|
183
204
|
|
|
@@ -0,0 +1,147 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: statedict2pytree
|
|
3
|
+
Version: 0.3.0
|
|
4
|
+
Summary: Converts torch models into PyTrees for Equinox
|
|
5
|
+
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
+
Requires-Python: ~=3.10
|
|
7
|
+
Requires-Dist: beartype
|
|
8
|
+
Requires-Dist: equinox>=0.11.4
|
|
9
|
+
Requires-Dist: flask
|
|
10
|
+
Requires-Dist: jax
|
|
11
|
+
Requires-Dist: jaxlib
|
|
12
|
+
Requires-Dist: jaxtyping
|
|
13
|
+
Requires-Dist: loguru
|
|
14
|
+
Requires-Dist: penzai
|
|
15
|
+
Requires-Dist: pydantic
|
|
16
|
+
Requires-Dist: torch
|
|
17
|
+
Requires-Dist: typing-extensions
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: mkdocs; extra == 'dev'
|
|
20
|
+
Requires-Dist: nox; extra == 'dev'
|
|
21
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
22
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# statedict2pytree
|
|
26
|
+
|
|
27
|
+

|
|
28
|
+
|
|
29
|
+
The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match up the weight matrices.
|
|
30
|
+
|
|
31
|
+
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
32
|
+
|
|
33
|
+
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
34
|
+
|
|
35
|
+
## Shape Matching? What's that?
|
|
36
|
+
|
|
37
|
+
Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal, i.e. they contain the same number of elements. For example:
|
|
38
|
+
|
|
39
|
+
1. `(8, 1, 1)` and `(8,)` match, because 8 × 1 × 1 = 8.
|
|
40
|
+
|
|
41
|
+
## Get Started
|
|
42
|
+
|
|
43
|
+
### Installation
|
|
44
|
+
|
|
45
|
+
Run
|
|
46
|
+
|
|
47
|
+
```bash
|
|
48
|
+
pip install statedict2pytree
|
|
49
|
+
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
### Basic Example
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
import equinox as eqx
|
|
56
|
+
import jax
|
|
57
|
+
import torch
|
|
58
|
+
import statedict2pytree as s2p
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def test_mlp():
|
|
62
|
+
in_size = 784
|
|
63
|
+
out_size = 10
|
|
64
|
+
width_size = 64
|
|
65
|
+
depth = 2
|
|
66
|
+
key = jax.random.PRNGKey(22)
|
|
67
|
+
|
|
68
|
+
class EqxMLP(eqx.Module):
|
|
69
|
+
mlp: eqx.nn.MLP
|
|
70
|
+
batch_norm: eqx.nn.BatchNorm
|
|
71
|
+
|
|
72
|
+
def __init__(self, in_size, out_size, width_size, depth, key):
|
|
73
|
+
self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
|
|
74
|
+
self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")
|
|
75
|
+
|
|
76
|
+
def __call__(self, x, state):
|
|
77
|
+
return self.batch_norm(self.mlp(x), state)
|
|
78
|
+
|
|
79
|
+
jax_model = EqxMLP(in_size, out_size, width_size, depth, key)
|
|
80
|
+
|
|
81
|
+
class TorchMLP(torch.nn.Module):
|
|
82
|
+
def __init__(self, in_size, out_size, width_size, depth):
|
|
83
|
+
super(TorchMLP, self).__init__()
|
|
84
|
+
self.layers = torch.nn.ModuleList()
|
|
85
|
+
self.layers.append(torch.nn.Linear(in_size, width_size))
|
|
86
|
+
for _ in range(depth - 1):
|
|
87
|
+
self.layers.append(torch.nn.Linear(width_size, width_size))
|
|
88
|
+
self.layers.append(torch.nn.Linear(width_size, out_size))
|
|
89
|
+
self.batch_norm = torch.nn.BatchNorm1d(out_size)
|
|
90
|
+
|
|
91
|
+
def forward(self, x):
|
|
92
|
+
for layer in self.layers[:-1]:
|
|
93
|
+
x = torch.relu(layer(x))
|
|
94
|
+
x = self.batch_norm(self.layers[-1](x))
|
|
95
|
+
return x
|
|
96
|
+
|
|
97
|
+
torch_model = TorchMLP(in_size, out_size, width_size, depth)
|
|
98
|
+
state_dict = torch_model.state_dict()
|
|
99
|
+
s2p.start_conversion(jax_model, state_dict)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
if __name__ == "__main__":
|
|
103
|
+
test_mlp()
|
|
104
|
+
|
|
105
|
+
```
|
|
106
|
+
|
|
107
|
+
There is also a function called `s2p.convert`, which does the actual conversion:
|
|
108
|
+
|
|
109
|
+
```python
|
|
110
|
+
|
|
111
|
+
class Field(BaseModel):
|
|
112
|
+
path: str
|
|
113
|
+
shape: tuple[int, ...]
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class TorchField(Field):
|
|
117
|
+
pass
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
class JaxField(Field):
|
|
121
|
+
type: str
|
|
122
|
+
|
|
123
|
+
def convert(
|
|
124
|
+
jax_fields: list[JaxField],
|
|
125
|
+
torch_fields: list[TorchField],
|
|
126
|
+
pytree: PyTree,
|
|
127
|
+
state_dict: dict,
|
|
128
|
+
):
|
|
129
|
+
...
|
|
130
|
+
```
|
|
131
|
+
|
|
132
|
+
If your models already have the right "order", you can use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of matching entries must have shapes with the same number of elements (see "Shape Matching" above); otherwise a `ValueError` is raised.
|
|
133
|
+
|
|
134
|
+
For the full, automatic experience, use `autoconvert`:
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
import statedict2pytree as s2p
|
|
138
|
+
|
|
139
|
+
my_model = Model(...)
|
|
140
|
+
state_dict = ...
|
|
141
|
+
|
|
142
|
+
model, state = s2p.autoconvert(my_model, state_dict)
|
|
143
|
+
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
However, this will only work if your PyTree fields have been declared
|
|
147
|
+
in the same order as they appear in the state dict!
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
|
|
2
|
+
statedict2pytree/statedict2pytree.py,sha256=X5Ljf4lYhhH7_V4KgdciChncbTt7YZpIWHcOxcZ3l48,7103
|
|
3
|
+
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
|
+
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
|
+
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
+
statedict2pytree-0.3.0.dist-info/METADATA,sha256=YSK4tWzNQemyZ1xKq5BhWiLWWc-RDr4E9q_eV_iOsdw,4232
|
|
7
|
+
statedict2pytree-0.3.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
+
statedict2pytree-0.3.0.dist-info/RECORD,,
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.3
|
|
2
|
-
Name: statedict2pytree
|
|
3
|
-
Version: 0.1.2
|
|
4
|
-
Summary: Converts torch models into PyTrees for Equinox
|
|
5
|
-
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
-
Requires-Python: ~=3.10
|
|
7
|
-
Requires-Dist: beartype
|
|
8
|
-
Requires-Dist: equinox>=0.11.4
|
|
9
|
-
Requires-Dist: flask
|
|
10
|
-
Requires-Dist: jax
|
|
11
|
-
Requires-Dist: jaxlib
|
|
12
|
-
Requires-Dist: jaxtyping
|
|
13
|
-
Requires-Dist: loguru
|
|
14
|
-
Requires-Dist: pydantic
|
|
15
|
-
Requires-Dist: torch
|
|
16
|
-
Requires-Dist: typing-extensions
|
|
17
|
-
Provides-Extra: dev
|
|
18
|
-
Requires-Dist: mkdocs; extra == 'dev'
|
|
19
|
-
Requires-Dist: nox; extra == 'dev'
|
|
20
|
-
Requires-Dist: pre-commit; extra == 'dev'
|
|
21
|
-
Requires-Dist: pytest; extra == 'dev'
|
|
22
|
-
Description-Content-Type: text/markdown
|
|
23
|
-
|
|
24
|
-
# statedict2pytree
|
|
25
|
-
|
|
26
|
-

|
|
27
|
-
|
|
28
|
-
The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side my side and aligning the weights in the right order. Then, all statedict2pytree is doing, is iterating over both lists and matching the weight matrices.
|
|
29
|
-
|
|
30
|
-
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
31
|
-
|
|
32
|
-
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
33
|
-
|
|
34
|
-
## Get Started
|
|
35
|
-
|
|
36
|
-
### Installation
|
|
37
|
-
|
|
38
|
-
Run
|
|
39
|
-
|
|
40
|
-
```bash
|
|
41
|
-
pip install statedict2pytree
|
|
42
|
-
|
|
43
|
-
```
|
|
@@ -1,8 +0,0 @@
|
|
|
1
|
-
statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
|
|
2
|
-
statedict2pytree/statedict2pytree.py,sha256=LUM19UvNn8R9jau3iYmDLbgfOlznytNEL3J-d-RoSZ0,6455
|
|
3
|
-
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
|
-
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
|
-
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
-
statedict2pytree-0.1.2.dist-info/METADATA,sha256=X-79GNzLPC9VXRPSVTJao9ysWygmiseBKLd4GmgAY-g,1437
|
|
7
|
-
statedict2pytree-0.1.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
-
statedict2pytree-0.1.2.dist-info/RECORD,,
|
|
File without changes
|