statedict2pytree 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/statedict2pytree.py +16 -9
- statedict2pytree-0.2.0.dist-info/METADATA +126 -0
- {statedict2pytree-0.1.2.dist-info → statedict2pytree-0.2.0.dist-info}/RECORD +4 -4
- statedict2pytree-0.1.2.dist-info/METADATA +0 -43
- {statedict2pytree-0.1.2.dist-info → statedict2pytree-0.2.0.dist-info}/WHEEL +0 -0
|
@@ -95,7 +95,7 @@ def visualize_with_penzai():
|
|
|
95
95
|
return flask.jsonify({"error": "No data received"})
|
|
96
96
|
jax_fields = request_data["jaxFields"]
|
|
97
97
|
torch_fields = request_data["torchFields"]
|
|
98
|
-
model, state = convert(jax_fields, torch_fields)
|
|
98
|
+
model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
|
|
99
99
|
with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
|
|
100
100
|
html_jax = pz.ts.render_to_html((model, state))
|
|
101
101
|
html_torch = pz.ts.render_to_html(STATE_DICT)
|
|
@@ -130,7 +130,7 @@ def convert_torch_to_jax():
|
|
|
130
130
|
torch_fields.append(TorchField(path=f["path"], shape=shape_tuple))
|
|
131
131
|
|
|
132
132
|
name = request_data["name"]
|
|
133
|
-
model, state = convert(jax_fields, torch_fields)
|
|
133
|
+
model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
|
|
134
134
|
eqx.tree_serialise_leaves(name, (model, state))
|
|
135
135
|
|
|
136
136
|
return flask.jsonify({"status": "success"})
|
|
@@ -146,14 +146,21 @@ def index():
|
|
|
146
146
|
)
|
|
147
147
|
|
|
148
148
|
|
|
149
|
-
def convert(
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
149
|
+
def convert(
|
|
150
|
+
jax_fields: list[JaxField],
|
|
151
|
+
torch_fields: list[TorchField],
|
|
152
|
+
pytree: PyTree,
|
|
153
|
+
state_dict: dict,
|
|
154
|
+
):
|
|
155
|
+
identity = lambda *args, **kwargs: pytree
|
|
154
156
|
model, state = eqx.nn.make_with_state(identity)()
|
|
155
157
|
state_paths: list[tuple[JaxField, TorchField]] = []
|
|
156
158
|
for jax_field, torch_field in zip(jax_fields, torch_fields):
|
|
159
|
+
if jax_field.shape != torch_field.shape:
|
|
160
|
+
raise ValueError(
|
|
161
|
+
"Fields have incompatible shapes!"
|
|
162
|
+
f"{jax_field.shape=} != {torch_field.shape=}"
|
|
163
|
+
)
|
|
157
164
|
path = jax_field.path.split(".")[1:]
|
|
158
165
|
if "StateIndex" in jax_field.type:
|
|
159
166
|
state_paths.append((jax_field, torch_field))
|
|
@@ -164,7 +171,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
|
|
|
164
171
|
model = eqx.tree_at(
|
|
165
172
|
where,
|
|
166
173
|
model,
|
|
167
|
-
|
|
174
|
+
state_dict[torch_field.path],
|
|
168
175
|
)
|
|
169
176
|
result: dict[str, list[TorchField]] = {}
|
|
170
177
|
for tuple_item in state_paths:
|
|
@@ -177,7 +184,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
|
|
|
177
184
|
for key in result:
|
|
178
185
|
state_index = get_node(model, key.split("."))
|
|
179
186
|
if state_index is not None:
|
|
180
|
-
to_replace_tuple = tuple([
|
|
187
|
+
to_replace_tuple = tuple([state_dict[i.path] for i in result[key]])
|
|
181
188
|
state = state.set(state_index, to_replace_tuple)
|
|
182
189
|
return model, state
|
|
183
190
|
|
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: statedict2pytree
|
|
3
|
+
Version: 0.2.0
|
|
4
|
+
Summary: Converts torch models into PyTrees for Equinox
|
|
5
|
+
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
+
Requires-Python: ~=3.10
|
|
7
|
+
Requires-Dist: beartype
|
|
8
|
+
Requires-Dist: equinox>=0.11.4
|
|
9
|
+
Requires-Dist: flask
|
|
10
|
+
Requires-Dist: jax
|
|
11
|
+
Requires-Dist: jaxlib
|
|
12
|
+
Requires-Dist: jaxtyping
|
|
13
|
+
Requires-Dist: loguru
|
|
14
|
+
Requires-Dist: penzai
|
|
15
|
+
Requires-Dist: pydantic
|
|
16
|
+
Requires-Dist: torch
|
|
17
|
+
Requires-Dist: typing-extensions
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: mkdocs; extra == 'dev'
|
|
20
|
+
Requires-Dist: nox; extra == 'dev'
|
|
21
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
22
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
23
|
+
Description-Content-Type: text/markdown
|
|
24
|
+
|
|
25
|
+
# statedict2pytree
|
|
26
|
+
|
|
27
|
+

|
|
28
|
+
|
|
29
|
+
The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
|
|
30
|
+
|
|
31
|
+
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
32
|
+
|
|
33
|
+
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
34
|
+
|
|
35
|
+
## Get Started
|
|
36
|
+
|
|
37
|
+
### Installation
|
|
38
|
+
|
|
39
|
+
Run
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
pip install statedict2pytree
|
|
43
|
+
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
### Basic Example
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
import equinox as eqx
|
|
50
|
+
import jax
|
|
51
|
+
import torch
|
|
52
|
+
import statedict2pytree as s2p
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def test_mlp():
|
|
56
|
+
in_size = 784
|
|
57
|
+
out_size = 10
|
|
58
|
+
width_size = 64
|
|
59
|
+
depth = 2
|
|
60
|
+
key = jax.random.PRNGKey(22)
|
|
61
|
+
|
|
62
|
+
class EqxMLP(eqx.Module):
|
|
63
|
+
mlp: eqx.nn.MLP
|
|
64
|
+
batch_norm: eqx.nn.BatchNorm
|
|
65
|
+
|
|
66
|
+
def __init__(self, in_size, out_size, width_size, depth, key):
|
|
67
|
+
self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
|
|
68
|
+
self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")
|
|
69
|
+
|
|
70
|
+
def __call__(self, x, state):
|
|
71
|
+
return self.batch_norm(self.mlp(x), state)
|
|
72
|
+
|
|
73
|
+
jax_model = EqxMLP(in_size, out_size, width_size, depth, key)
|
|
74
|
+
|
|
75
|
+
class TorchMLP(torch.nn.Module):
|
|
76
|
+
def __init__(self, in_size, out_size, width_size, depth):
|
|
77
|
+
super(TorchMLP, self).__init__()
|
|
78
|
+
self.layers = torch.nn.ModuleList()
|
|
79
|
+
self.layers.append(torch.nn.Linear(in_size, width_size))
|
|
80
|
+
for _ in range(depth - 1):
|
|
81
|
+
self.layers.append(torch.nn.Linear(width_size, width_size))
|
|
82
|
+
self.layers.append(torch.nn.Linear(width_size, out_size))
|
|
83
|
+
self.batch_norm = torch.nn.BatchNorm1d(out_size)
|
|
84
|
+
|
|
85
|
+
def forward(self, x):
|
|
86
|
+
for layer in self.layers[:-1]:
|
|
87
|
+
x = torch.relu(layer(x))
|
|
88
|
+
x = self.batch_norm(self.layers[-1](x))
|
|
89
|
+
return x
|
|
90
|
+
|
|
91
|
+
torch_model = TorchMLP(in_size, out_size, width_size, depth)
|
|
92
|
+
state_dict = torch_model.state_dict()
|
|
93
|
+
s2p.start_conversion(jax_model, state_dict)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
if __name__ == "__main__":
|
|
97
|
+
test_mlp()
|
|
98
|
+
|
|
99
|
+
```
|
|
100
|
+
|
|
101
|
+
There is also a function called `s2p.convert` that does the actual conversion:
|
|
102
|
+
|
|
103
|
+
```python
|
|
104
|
+
|
|
105
|
+
class Field(BaseModel):
|
|
106
|
+
path: str
|
|
107
|
+
shape: tuple[int, ...]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class TorchField(Field):
|
|
111
|
+
pass
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class JaxField(Field):
|
|
115
|
+
type: str
|
|
116
|
+
|
|
117
|
+
def convert(
|
|
118
|
+
jax_fields: list[JaxField],
|
|
119
|
+
torch_fields: list[TorchField],
|
|
120
|
+
pytree: PyTree,
|
|
121
|
+
state_dict: dict,
|
|
122
|
+
):
|
|
123
|
+
...
|
|
124
|
+
```
|
|
125
|
+
|
|
126
|
+
If your models already have the right "order", you can use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of matching entries must have the same shape (a shape mismatch raises a `ValueError`)!
|
|
@@ -1,8 +1,8 @@
|
|
|
1
1
|
statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
|
|
2
|
-
statedict2pytree/statedict2pytree.py,sha256=
|
|
2
|
+
statedict2pytree/statedict2pytree.py,sha256=SQ1Xs4VPG_gnjYdAOBzdoNih1RGVsW9x08UQGkIFUdg,6640
|
|
3
3
|
statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
4
4
|
statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
|
|
5
5
|
statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
|
|
6
|
-
statedict2pytree-0.
|
|
7
|
-
statedict2pytree-0.
|
|
8
|
-
statedict2pytree-0.
|
|
6
|
+
statedict2pytree-0.2.0.dist-info/METADATA,sha256=x9GhcbG0io5HIpfwocS-nsPYyKRTmG90uLKyckSm394,3673
|
|
7
|
+
statedict2pytree-0.2.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
8
|
+
statedict2pytree-0.2.0.dist-info/RECORD,,
|
|
@@ -1,43 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.3
|
|
2
|
-
Name: statedict2pytree
|
|
3
|
-
Version: 0.1.2
|
|
4
|
-
Summary: Converts torch models into PyTrees for Equinox
|
|
5
|
-
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
-
Requires-Python: ~=3.10
|
|
7
|
-
Requires-Dist: beartype
|
|
8
|
-
Requires-Dist: equinox>=0.11.4
|
|
9
|
-
Requires-Dist: flask
|
|
10
|
-
Requires-Dist: jax
|
|
11
|
-
Requires-Dist: jaxlib
|
|
12
|
-
Requires-Dist: jaxtyping
|
|
13
|
-
Requires-Dist: loguru
|
|
14
|
-
Requires-Dist: pydantic
|
|
15
|
-
Requires-Dist: torch
|
|
16
|
-
Requires-Dist: typing-extensions
|
|
17
|
-
Provides-Extra: dev
|
|
18
|
-
Requires-Dist: mkdocs; extra == 'dev'
|
|
19
|
-
Requires-Dist: nox; extra == 'dev'
|
|
20
|
-
Requires-Dist: pre-commit; extra == 'dev'
|
|
21
|
-
Requires-Dist: pytest; extra == 'dev'
|
|
22
|
-
Description-Content-Type: text/markdown
|
|
23
|
-
|
|
24
|
-
# statedict2pytree
|
|
25
|
-
|
|
26
|
-

|
|
27
|
-
|
|
28
|
-
The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
|
|
29
|
-
|
|
30
|
-
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
31
|
-
|
|
32
|
-
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
33
|
-
|
|
34
|
-
## Get Started
|
|
35
|
-
|
|
36
|
-
### Installation
|
|
37
|
-
|
|
38
|
-
Run
|
|
39
|
-
|
|
40
|
-
```bash
|
|
41
|
-
pip install statedict2pytree
|
|
42
|
-
|
|
43
|
-
```
|
|
File without changes
|