statedict2pytree 0.2.0__tar.gz → 0.3.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22) hide show
  1. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/PKG-INFO +22 -1
  2. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/README.md +21 -0
  3. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/pyproject.toml +2 -2
  4. statedict2pytree-0.3.0/pyrightconfig.json +4 -0
  5. statedict2pytree-0.3.0/statedict2pytree/__init__.py +7 -0
  6. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/statedict2pytree/statedict2pytree.py +17 -3
  7. statedict2pytree-0.3.0/tests/test_conv.py +61 -0
  8. statedict2pytree-0.3.0/tests/test_linear.py +45 -0
  9. statedict2pytree-0.2.0/statedict2pytree/__init__.py +0 -4
  10. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/.gitignore +0 -0
  11. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/.pre-commit-config.yaml +0 -0
  12. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/examples/convert_resnet.py +0 -0
  13. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/examples/doggo.jpeg +0 -0
  14. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/examples/resnet.py +0 -0
  15. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/examples/test_resnet_inference.py +0 -0
  16. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/package-lock.json +0 -0
  17. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/package.json +0 -0
  18. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/statedict2pytree/static/input.css +0 -0
  19. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/statedict2pytree/static/output.css +0 -0
  20. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/statedict2pytree/templates/index.html +0 -0
  21. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/tailwind.config.js +0 -0
  22. {statedict2pytree-0.2.0 → statedict2pytree-0.3.0}/torch2jax.png +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
32
32
 
33
33
  (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
34
 
35
+ ## Shape Matching? What's that?
36
+
37
+ Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shapes are equal, i.e. if they contain the same number of elements. For example:
38
+
39
+ 1. `(8, 1, 1)` and `(8,)` match, because 8 × 1 × 1 = 8. Note that `(3, 4)` and `(4, 3)` also "match", so a transposed layout is not detected.
40
+
35
41
  ## Get Started
36
42
 
37
43
  ### Installation
@@ -124,3 +130,18 @@ def convert(
124
130
  ```
125
131
 
126
132
  If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length and each matching entry must have a compatible shape (the same number of elements, see "Shape Matching" above)!
133
+
134
+ For the full, automatic experience, use `autoconvert`:
135
+
136
+ ```python
137
+ import statedict2pytree as s2p
138
+
139
+ my_model = Model(...)
140
+ state_dict = ...
141
+
142
+ model, state = s2p.autoconvert(my_model, state_dict)
143
+
144
+ ```
145
+
146
+ This will, however, only work if your PyTree fields have been declared
147
+ in the same order as they appear in the state dict!
@@ -8,6 +8,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
8
8
 
9
9
  (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
10
10
 
11
+ ## Shape Matching? What's that?
12
+
13
+ Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shapes are equal, i.e. if they contain the same number of elements. For example:
14
+
15
+ 1. `(8, 1, 1)` and `(8,)` match, because 8 × 1 × 1 = 8. Note that `(3, 4)` and `(4, 3)` also "match", so a transposed layout is not detected.
16
+
11
17
  ## Get Started
12
18
 
13
19
  ### Installation
@@ -100,3 +106,18 @@ def convert(
100
106
  ```
101
107
 
102
108
  If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length and each matching entry must have a compatible shape (the same number of elements, see "Shape Matching" above)!
109
+
110
+ For the full, automatic experience, use `autoconvert`:
111
+
112
+ ```python
113
+ import statedict2pytree as s2p
114
+
115
+ my_model = Model(...)
116
+ state_dict = ...
117
+
118
+ model, state = s2p.autoconvert(my_model, state_dict)
119
+
120
+ ```
121
+
122
+ This will, however, only work if your PyTree fields have been declared
123
+ in the same order as they appear in the state dict!
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "statedict2pytree"
3
- version = "0.2.0"
3
+ version = "0.3.0"
4
4
  description = "Converts torch models into PyTrees for Equinox"
5
5
  readme = "README.md"
6
6
  requires-python = "~=3.10"
@@ -16,7 +16,7 @@ dependencies = [
16
16
  "torch",
17
17
  "flask",
18
18
  "pydantic",
19
- "penzai"
19
+ "penzai",
20
20
  ]
21
21
  [project.optional-dependencies]
22
22
  dev = ["nox", "pre-commit", "pytest", "mkdocs"]
@@ -0,0 +1,4 @@
1
+ {
2
+ "venvPath": ".",
3
+ "venv": ".venv"
4
+ }
@@ -0,0 +1,7 @@
1
+ from statedict2pytree.statedict2pytree import (
2
+ autoconvert as autoconvert,
3
+ convert as convert,
4
+ pytree_to_fields as pytree_to_fields,
5
+ start_conversion as start_conversion,
6
+ state_dict_to_fields as state_dict_to_fields,
7
+ )
@@ -4,6 +4,7 @@ import re
4
4
  import equinox as eqx
5
5
  import flask
6
6
  import jax
7
+ import numpy as np
7
8
  from beartype.typing import Optional
8
9
  from jaxtyping import PyTree
9
10
  from loguru import logger
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
31
32
  STATE_DICT: Optional[dict] = None
32
33
 
33
34
 
35
def can_reshape(shape1, shape2) -> bool:
    """Return True if an array of ``shape1`` can be reshaped into ``shape2``.

    Two shapes are considered compatible when they contain the same number of
    elements, e.g. ``(8, 1, 1)`` and ``(8,)``. Only the element count is
    compared: the check cannot tell whether the memory layouts correspond, so
    ``(3, 4)`` and ``(4, 3)`` are also reported as compatible even though a
    plain reshape of one into the other is not a transpose.

    Args:
        shape1: Iterable of integer dimension sizes.
        shape2: Iterable of integer dimension sizes.

    Returns:
        A plain Python ``bool``.
    """
    import math

    # math.prod multiplies arbitrary-precision Python ints, so large shapes
    # cannot overflow a fixed-width accumulator the way np.prod's default
    # integer dtype can (e.g. int32 on Windows with NumPy < 2).
    return math.prod(shape1) == math.prod(shape2)
40
+
41
+
34
42
  def get_node(
35
43
  tree: PyTree, targets: list[str], log_when_not_found: bool = False
36
44
  ) -> PyTree | None:
@@ -146,17 +154,23 @@ def index():
146
154
  )
147
155
 
148
156
 
157
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Load ``state_dict`` into ``pytree`` without manual field matching.

    The fields of both sides are extracted with ``pytree_to_fields`` and
    ``state_dict_to_fields`` and handed to ``convert`` as-is. ``convert``
    pairs them positionally, so the result is only correct when the PyTree's
    fields are declared in the same order as the corresponding entries of the
    state dict.

    NOTE(review): ``convert`` iterates the two lists with ``zip``, so surplus
    fields on either side appear to be ignored rather than reported — confirm.

    Returns:
        The ``(model, state)`` pair produced by ``convert``.

    Raises:
        ValueError: propagated from ``convert`` when a positional pair has
            incompatible shapes.
    """
    return convert(
        pytree_to_fields(pytree),
        state_dict_to_fields(state_dict),
        pytree,
        state_dict,
    )
161
+
162
+
149
163
  def convert(
150
164
  jax_fields: list[JaxField],
151
165
  torch_fields: list[TorchField],
152
166
  pytree: PyTree,
153
167
  state_dict: dict,
154
- ):
168
+ ) -> tuple[PyTree, eqx.nn.State]:
155
169
  identity = lambda *args, **kwargs: pytree
156
170
  model, state = eqx.nn.make_with_state(identity)()
157
171
  state_paths: list[tuple[JaxField, TorchField]] = []
158
172
  for jax_field, torch_field in zip(jax_fields, torch_fields):
159
- if jax_field.shape != torch_field.shape:
173
+ if not can_reshape(jax_field.shape, torch_field.shape):
160
174
  raise ValueError(
161
175
  "Fields have incompatible shapes!"
162
176
  f"{jax_field.shape=} != {torch_field.shape=}"
@@ -171,7 +185,7 @@ def convert(
171
185
  model = eqx.tree_at(
172
186
  where,
173
187
  model,
174
- state_dict[torch_field.path],
188
+ state_dict[torch_field.path].reshape(jax_field.shape),
175
189
  )
176
190
  result: dict[str, list[TorchField]] = {}
177
191
  for tuple_item in state_paths:
@@ -0,0 +1,61 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import numpy as np
4
+ import statedict2pytree as s2p
5
+ import torch
6
+
7
+
8
+ def test_conv():
9
+ in_channels = 8
10
+ out_channels = 8
11
+ kernel_size = 4
12
+ stride = 2
13
+ padding = 1
14
+
15
+ class J(eqx.Module):
16
+ conv: eqx.nn.Conv2d
17
+
18
+ def __init__(self):
19
+ self.conv = eqx.nn.Conv2d(
20
+ in_channels=in_channels,
21
+ out_channels=out_channels,
22
+ kernel_size=kernel_size,
23
+ stride=stride,
24
+ padding=padding,
25
+ key=jax.random.PRNGKey(22),
26
+ )
27
+
28
+ class T(torch.nn.Module):
29
+ def __init__(self) -> None:
30
+ super(T, self).__init__()
31
+ self.conv = torch.nn.Conv2d(
32
+ in_channels=in_channels,
33
+ out_channels=out_channels,
34
+ kernel_size=kernel_size,
35
+ stride=stride,
36
+ padding=padding,
37
+ )
38
+
39
+ jax_model = J()
40
+ torch_model = T()
41
+ state_dict = torch_model.state_dict()
42
+
43
+ jax_fields = s2p.pytree_to_fields(jax_model)
44
+ torch_fields = s2p.state_dict_to_fields(state_dict)
45
+
46
+ model, state = s2p.convert(
47
+ jax_fields, torch_fields, pytree=jax_model, state_dict=state_dict
48
+ )
49
+
50
+ assert np.allclose(
51
+ np.array(model.conv.weight), torch_model.conv.weight.detach().numpy()
52
+ )
53
+ if torch_model.conv.bias is not None:
54
+ assert np.allclose(
55
+ np.array(model.conv.bias),
56
+ torch_model.conv.bias.detach().numpy().reshape(model.conv.bias.shape),
57
+ )
58
+
59
+
60
+ if __name__ == "__main__":
61
+ test_conv()
@@ -0,0 +1,45 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import numpy as np
4
+ import statedict2pytree as s2p
5
+ import torch
6
+
7
+
8
def test_linear():
    """Check that ``s2p.convert`` copies torch Linear parameters into Equinox."""
    in_features, out_features = 10, 10

    class J(eqx.Module):
        linear: eqx.nn.Linear

        def __init__(self):
            self.linear = eqx.nn.Linear(
                in_features, out_features, key=jax.random.PRNGKey(30)
            )

    class T(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(in_features, out_features)

    jax_model, torch_model = J(), T()
    state_dict = torch_model.state_dict()

    model, _ = s2p.convert(
        s2p.pytree_to_fields(jax_model),
        s2p.state_dict_to_fields(state_dict),
        pytree=jax_model,
        state_dict=state_dict,
    )

    # Weight first, then bias: each converted array must equal its torch source.
    for param_name in ("weight", "bias"):
        converted = np.array(getattr(model.linear, param_name))
        reference = getattr(torch_model.linear, param_name).detach().numpy()
        assert np.allclose(converted, reference)


if __name__ == "__main__":
    test_linear()
@@ -1,4 +0,0 @@
1
- from statedict2pytree.statedict2pytree import (
2
- convert as convert,
3
- start_conversion as start_conversion,
4
- )