statedict2pytree 0.2.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/PKG-INFO +22 -1
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/README.md +21 -0
- statedict2pytree-0.4.0/examples/convert_resnet.py +20 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/examples/resnet.py +43 -12
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/examples/test_resnet_inference.py +6 -6
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/pyproject.toml +2 -2
- statedict2pytree-0.4.0/pyrightconfig.json +4 -0
- statedict2pytree-0.4.0/statedict2pytree/__init__.py +7 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/statedict2pytree.py +21 -4
- statedict2pytree-0.4.0/tests/test_conv.py +61 -0
- statedict2pytree-0.4.0/tests/test_linear.py +45 -0
- statedict2pytree-0.2.0/examples/convert_resnet.py +0 -16
- statedict2pytree-0.2.0/statedict2pytree/__init__.py +0 -4
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/.gitignore +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/.pre-commit-config.yaml +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/examples/doggo.jpeg +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/package-lock.json +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/package.json +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/static/input.css +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/static/output.css +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/templates/index.html +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/tailwind.config.js +0 -0
- {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/torch2jax.png +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version: 0.2.0
|
|
3
|
+
Version: 0.4.0
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
6
|
Requires-Python: ~=3.10
|
|
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
|
|
|
32
32
|
|
|
33
33
|
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
34
34
|
|
|
35
|
+
## Shape Matching? What's that?
|
|
36
|
+
|
|
37
|
+
Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shape dimensions are equal, i.e. if they contain the same number of elements. For example:
|
|
38
|
+
|
|
39
|
+
1. `(8, 1, 1)` and `(8,)` match, because `8 * 1 * 1 = 8`
|
|
40
|
+
|
|
35
41
|
## Get Started
|
|
36
42
|
|
|
37
43
|
### Installation
|
|
@@ -124,3 +130,18 @@ def convert(
|
|
|
124
130
|
```
|
|
125
131
|
|
|
126
132
|
If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of corresponding entries must have compatible shapes (i.e. the same total number of elements)!
|
|
133
|
+
|
|
134
|
+
For the full, automatic experience, use `autoconvert`:
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
import statedict2pytree as s2p
|
|
138
|
+
|
|
139
|
+
my_model = Model(...)
|
|
140
|
+
state_dict = ...
|
|
141
|
+
|
|
142
|
+
model, state = s2p.autoconvert(my_model, state_dict)
|
|
143
|
+
|
|
144
|
+
```
|
|
145
|
+
|
|
146
|
+
This will however only work if your PyTree fields have been declared
|
|
147
|
+
in the same order as they appear in the state dict!
|
|
@@ -8,6 +8,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
|
|
|
8
8
|
|
|
9
9
|
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
10
10
|
|
|
11
|
+
## Shape Matching? What's that?
|
|
12
|
+
|
|
13
|
+
Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shape dimensions are equal, i.e. if they contain the same number of elements. For example:
|
|
14
|
+
|
|
15
|
+
1. `(8, 1, 1)` and `(8,)` match, because `8 * 1 * 1 = 8`
|
|
16
|
+
|
|
11
17
|
## Get Started
|
|
12
18
|
|
|
13
19
|
### Installation
|
|
@@ -100,3 +106,18 @@ def convert(
|
|
|
100
106
|
```
|
|
101
107
|
|
|
102
108
|
If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of corresponding entries must have compatible shapes (i.e. the same total number of elements)!
|
|
109
|
+
|
|
110
|
+
For the full, automatic experience, use `autoconvert`:
|
|
111
|
+
|
|
112
|
+
```python
|
|
113
|
+
import statedict2pytree as s2p
|
|
114
|
+
|
|
115
|
+
my_model = Model(...)
|
|
116
|
+
state_dict = ...
|
|
117
|
+
|
|
118
|
+
model, state = s2p.autoconvert(my_model, state_dict)
|
|
119
|
+
|
|
120
|
+
```
|
|
121
|
+
|
|
122
|
+
This will however only work if your PyTree fields have been declared
|
|
123
|
+
in the same order as they appear in the state dict!
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import equinox as eqx
|
|
2
|
+
import jax
|
|
3
|
+
import statedict2pytree as s2p
|
|
4
|
+
from resnet import resnet152
|
|
5
|
+
from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def convert_resnet():
|
|
9
|
+
resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
10
|
+
resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
|
|
11
|
+
state_dict = resnet_torch.state_dict()
|
|
12
|
+
|
|
13
|
+
# s2p.start_conversion(resnet_jax, state_dict)
|
|
14
|
+
model, state = s2p.autoconvert(resnet_jax, state_dict)
|
|
15
|
+
name = "resnet152.eqx"
|
|
16
|
+
eqx.tree_serialise_leaves(name, (model, state))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if __name__ == "__main__":
|
|
20
|
+
convert_resnet()
|
|
@@ -333,12 +333,22 @@ def resnet18(
|
|
|
333
333
|
|
|
334
334
|
|
|
335
335
|
def resnet34(
|
|
336
|
-
image_channels: int = 3,
|
|
336
|
+
image_channels: int = 3,
|
|
337
|
+
num_classes: int = 1000,
|
|
338
|
+
*,
|
|
339
|
+
key: PRNGKeyArray,
|
|
340
|
+
make_with_state: bool = True,
|
|
341
|
+
**kwargs,
|
|
337
342
|
):
|
|
338
343
|
layers = [3, 4, 6, 3]
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
344
|
+
if make_with_state:
|
|
345
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
346
|
+
BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
|
|
347
|
+
)
|
|
348
|
+
else:
|
|
349
|
+
return ResNet(
|
|
350
|
+
BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
|
|
351
|
+
)
|
|
342
352
|
|
|
343
353
|
|
|
344
354
|
def resnet50(
|
|
@@ -361,18 +371,39 @@ def resnet50(
|
|
|
361
371
|
|
|
362
372
|
|
|
363
373
|
def resnet101(
|
|
364
|
-
image_channels: int = 3,
|
|
374
|
+
image_channels: int = 3,
|
|
375
|
+
num_classes: int = 1000,
|
|
376
|
+
*,
|
|
377
|
+
key: PRNGKeyArray,
|
|
378
|
+
make_with_state: bool = True,
|
|
379
|
+
**kwargs,
|
|
365
380
|
):
|
|
366
381
|
layers = [3, 4, 23, 3]
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
382
|
+
|
|
383
|
+
if make_with_state:
|
|
384
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
385
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
386
|
+
)
|
|
387
|
+
else:
|
|
388
|
+
return ResNet(
|
|
389
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
390
|
+
)
|
|
370
391
|
|
|
371
392
|
|
|
372
393
|
def resnet152(
|
|
373
|
-
image_channels: int = 3,
|
|
394
|
+
image_channels: int = 3,
|
|
395
|
+
num_classes: int = 1000,
|
|
396
|
+
*,
|
|
397
|
+
key: PRNGKeyArray,
|
|
398
|
+
make_with_state: bool = True,
|
|
399
|
+
**kwargs,
|
|
374
400
|
):
|
|
375
401
|
layers = [3, 8, 36, 3]
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
402
|
+
if make_with_state:
|
|
403
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
404
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
405
|
+
)
|
|
406
|
+
else:
|
|
407
|
+
return ResNet(
|
|
408
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
409
|
+
)
|
|
@@ -6,15 +6,15 @@ import equinox as eqx
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import torch
|
|
9
|
+
from examples.resnet import resnet152
|
|
9
10
|
from PIL import Image
|
|
10
|
-
from tests.resnet import resnet50
|
|
11
11
|
from torchvision import transforms
|
|
12
|
-
from torchvision.models import
|
|
12
|
+
from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def test_resnet():
|
|
16
|
-
resnet_jax =
|
|
17
|
-
resnet_torch =
|
|
16
|
+
resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
17
|
+
resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
|
|
18
18
|
|
|
19
19
|
img_name = "doggo.jpeg"
|
|
20
20
|
|
|
@@ -42,7 +42,7 @@ def test_resnet():
|
|
|
42
42
|
) # Outputs the ImageNet class index of the prediction
|
|
43
43
|
|
|
44
44
|
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
|
|
45
|
-
with urllib.request.urlopen(url) as url:
|
|
45
|
+
with urllib.request.urlopen(url) as url: # pyright: ignore
|
|
46
46
|
imagenet_labels = json.loads(url.read().decode())
|
|
47
47
|
|
|
48
48
|
label = imagenet_labels[str(predicted.item())][1]
|
|
@@ -52,7 +52,7 @@ def test_resnet():
|
|
|
52
52
|
model_callable = ft.partial(identity, resnet_jax)
|
|
53
53
|
model, state = eqx.nn.make_with_state(model_callable)()
|
|
54
54
|
|
|
55
|
-
model, state = eqx.tree_deserialise_leaves("
|
|
55
|
+
model, state = eqx.tree_deserialise_leaves("resnet152.eqx", (model, state))
|
|
56
56
|
|
|
57
57
|
jax_batch = jnp.array(batch_t.numpy())
|
|
58
58
|
out, state = eqx.filter_vmap(
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "statedict2pytree"
|
|
3
|
-
version = "0.2.0"
|
|
3
|
+
version = "0.4.0"
|
|
4
4
|
description = "Converts torch models into PyTrees for Equinox"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
requires-python = "~=3.10"
|
|
@@ -16,7 +16,7 @@ dependencies = [
|
|
|
16
16
|
"torch",
|
|
17
17
|
"flask",
|
|
18
18
|
"pydantic",
|
|
19
|
-
"penzai"
|
|
19
|
+
"penzai",
|
|
20
20
|
]
|
|
21
21
|
[project.optional-dependencies]
|
|
22
22
|
dev = ["nox", "pre-commit", "pytest", "mkdocs"]
|
|
@@ -4,6 +4,7 @@ import re
|
|
|
4
4
|
import equinox as eqx
|
|
5
5
|
import flask
|
|
6
6
|
import jax
|
|
7
|
+
import numpy as np
|
|
7
8
|
from beartype.typing import Optional
|
|
8
9
|
from jaxtyping import PyTree
|
|
9
10
|
from loguru import logger
|
|
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
|
|
|
31
32
|
STATE_DICT: Optional[dict] = None
|
|
32
33
|
|
|
33
34
|
|
|
35
|
+
def can_reshape(shape1, shape2):
|
|
36
|
+
product1 = np.prod(shape1)
|
|
37
|
+
product2 = np.prod(shape2)
|
|
38
|
+
|
|
39
|
+
return product1 == product2
|
|
40
|
+
|
|
41
|
+
|
|
34
42
|
def get_node(
|
|
35
43
|
tree: PyTree, targets: list[str], log_when_not_found: bool = False
|
|
36
44
|
) -> PyTree | None:
|
|
@@ -38,7 +46,7 @@ def get_node(
|
|
|
38
46
|
return tree
|
|
39
47
|
else:
|
|
40
48
|
next_target: str = targets[0]
|
|
41
|
-
if bool(re.search(r"\[\d
|
|
49
|
+
if bool(re.search(r"\[\d+\]", next_target)):
|
|
42
50
|
split_index = next_target.rfind("[")
|
|
43
51
|
name, index = next_target[:split_index], next_target[split_index:]
|
|
44
52
|
index = index[1:-1]
|
|
@@ -146,17 +154,26 @@ def index():
|
|
|
146
154
|
)
|
|
147
155
|
|
|
148
156
|
|
|
157
|
+
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
|
|
158
|
+
jax_fields = pytree_to_fields(pytree)
|
|
159
|
+
torch_fields = state_dict_to_fields(state_dict)
|
|
160
|
+
|
|
161
|
+
for k, v in state_dict.items():
|
|
162
|
+
state_dict[k] = v.numpy()
|
|
163
|
+
return convert(jax_fields, torch_fields, pytree, state_dict)
|
|
164
|
+
|
|
165
|
+
|
|
149
166
|
def convert(
|
|
150
167
|
jax_fields: list[JaxField],
|
|
151
168
|
torch_fields: list[TorchField],
|
|
152
169
|
pytree: PyTree,
|
|
153
170
|
state_dict: dict,
|
|
154
|
-
):
|
|
171
|
+
) -> tuple[PyTree, eqx.nn.State]:
|
|
155
172
|
identity = lambda *args, **kwargs: pytree
|
|
156
173
|
model, state = eqx.nn.make_with_state(identity)()
|
|
157
174
|
state_paths: list[tuple[JaxField, TorchField]] = []
|
|
158
175
|
for jax_field, torch_field in zip(jax_fields, torch_fields):
|
|
159
|
-
if jax_field.shape
|
|
176
|
+
if not can_reshape(jax_field.shape, torch_field.shape):
|
|
160
177
|
raise ValueError(
|
|
161
178
|
"Fields have incompatible shapes!"
|
|
162
179
|
f"{jax_field.shape=} != {torch_field.shape=}"
|
|
@@ -171,7 +188,7 @@ def convert(
|
|
|
171
188
|
model = eqx.tree_at(
|
|
172
189
|
where,
|
|
173
190
|
model,
|
|
174
|
-
state_dict[torch_field.path],
|
|
191
|
+
state_dict[torch_field.path].reshape(jax_field.shape),
|
|
175
192
|
)
|
|
176
193
|
result: dict[str, list[TorchField]] = {}
|
|
177
194
|
for tuple_item in state_paths:
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import equinox as eqx
|
|
2
|
+
import jax
|
|
3
|
+
import numpy as np
|
|
4
|
+
import statedict2pytree as s2p
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_conv():
|
|
9
|
+
in_channels = 8
|
|
10
|
+
out_channels = 8
|
|
11
|
+
kernel_size = 4
|
|
12
|
+
stride = 2
|
|
13
|
+
padding = 1
|
|
14
|
+
|
|
15
|
+
class J(eqx.Module):
|
|
16
|
+
conv: eqx.nn.Conv2d
|
|
17
|
+
|
|
18
|
+
def __init__(self):
|
|
19
|
+
self.conv = eqx.nn.Conv2d(
|
|
20
|
+
in_channels=in_channels,
|
|
21
|
+
out_channels=out_channels,
|
|
22
|
+
kernel_size=kernel_size,
|
|
23
|
+
stride=stride,
|
|
24
|
+
padding=padding,
|
|
25
|
+
key=jax.random.PRNGKey(22),
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
class T(torch.nn.Module):
|
|
29
|
+
def __init__(self) -> None:
|
|
30
|
+
super(T, self).__init__()
|
|
31
|
+
self.conv = torch.nn.Conv2d(
|
|
32
|
+
in_channels=in_channels,
|
|
33
|
+
out_channels=out_channels,
|
|
34
|
+
kernel_size=kernel_size,
|
|
35
|
+
stride=stride,
|
|
36
|
+
padding=padding,
|
|
37
|
+
)
|
|
38
|
+
|
|
39
|
+
jax_model = J()
|
|
40
|
+
torch_model = T()
|
|
41
|
+
state_dict = torch_model.state_dict()
|
|
42
|
+
|
|
43
|
+
jax_fields = s2p.pytree_to_fields(jax_model)
|
|
44
|
+
torch_fields = s2p.state_dict_to_fields(state_dict)
|
|
45
|
+
|
|
46
|
+
model, state = s2p.convert(
|
|
47
|
+
jax_fields, torch_fields, pytree=jax_model, state_dict=state_dict
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
assert np.allclose(
|
|
51
|
+
np.array(model.conv.weight), torch_model.conv.weight.detach().numpy()
|
|
52
|
+
)
|
|
53
|
+
if torch_model.conv.bias is not None:
|
|
54
|
+
assert np.allclose(
|
|
55
|
+
np.array(model.conv.bias),
|
|
56
|
+
torch_model.conv.bias.detach().numpy().reshape(model.conv.bias.shape),
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
if __name__ == "__main__":
|
|
61
|
+
test_conv()
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import equinox as eqx
|
|
2
|
+
import jax
|
|
3
|
+
import numpy as np
|
|
4
|
+
import statedict2pytree as s2p
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_linear():
|
|
9
|
+
in_features = 10
|
|
10
|
+
out_features = 10
|
|
11
|
+
|
|
12
|
+
class J(eqx.Module):
|
|
13
|
+
linear: eqx.nn.Linear
|
|
14
|
+
|
|
15
|
+
def __init__(self):
|
|
16
|
+
self.linear = eqx.nn.Linear(
|
|
17
|
+
in_features, out_features, key=jax.random.PRNGKey(30)
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
class T(torch.nn.Module):
|
|
21
|
+
def __init__(self) -> None:
|
|
22
|
+
super(T, self).__init__()
|
|
23
|
+
self.linear = torch.nn.Linear(in_features, out_features)
|
|
24
|
+
|
|
25
|
+
jax_model = J()
|
|
26
|
+
torch_model = T()
|
|
27
|
+
state_dict = torch_model.state_dict()
|
|
28
|
+
|
|
29
|
+
jax_fields = s2p.pytree_to_fields(jax_model)
|
|
30
|
+
torch_fields = s2p.state_dict_to_fields(state_dict)
|
|
31
|
+
|
|
32
|
+
model, state = s2p.convert(
|
|
33
|
+
jax_fields, torch_fields, pytree=jax_model, state_dict=state_dict
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
assert np.allclose(
|
|
37
|
+
np.array(model.linear.weight), torch_model.linear.weight.detach().numpy()
|
|
38
|
+
)
|
|
39
|
+
assert np.allclose(
|
|
40
|
+
np.array(model.linear.bias), torch_model.linear.bias.detach().numpy()
|
|
41
|
+
)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
if __name__ == "__main__":
|
|
45
|
+
test_linear()
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
import jax
|
|
2
|
-
import statedict2pytree as s2p
|
|
3
|
-
from resnet import resnet50
|
|
4
|
-
from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def convert_resnet():
|
|
8
|
-
resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
9
|
-
resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
|
|
10
|
-
state_dict = resnet_torch.state_dict()
|
|
11
|
-
|
|
12
|
-
s2p.start_conversion(resnet_jax, state_dict)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
if __name__ == "__main__":
|
|
16
|
-
convert_resnet()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|