statedict2pytree 0.1.2__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,7 @@
1
1
  from statedict2pytree.statedict2pytree import (
2
+ autoconvert as autoconvert,
2
3
  convert as convert,
4
+ pytree_to_fields as pytree_to_fields,
3
5
  start_conversion as start_conversion,
6
+ state_dict_to_fields as state_dict_to_fields,
4
7
  )
@@ -4,6 +4,7 @@ import re
4
4
  import equinox as eqx
5
5
  import flask
6
6
  import jax
7
+ import numpy as np
7
8
  from beartype.typing import Optional
8
9
  from jaxtyping import PyTree
9
10
  from loguru import logger
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
31
32
  STATE_DICT: Optional[dict] = None
32
33
 
33
34
 
35
+ def can_reshape(shape1, shape2):
36
+ product1 = np.prod(shape1)
37
+ product2 = np.prod(shape2)
38
+
39
+ return product1 == product2
40
+
41
+
34
42
  def get_node(
35
43
  tree: PyTree, targets: list[str], log_when_not_found: bool = False
36
44
  ) -> PyTree | None:
@@ -95,7 +103,7 @@ def visualize_with_penzai():
95
103
  return flask.jsonify({"error": "No data received"})
96
104
  jax_fields = request_data["jaxFields"]
97
105
  torch_fields = request_data["torchFields"]
98
- model, state = convert(jax_fields, torch_fields)
106
+ model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
99
107
  with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
100
108
  html_jax = pz.ts.render_to_html((model, state))
101
109
  html_torch = pz.ts.render_to_html(STATE_DICT)
@@ -130,7 +138,7 @@ def convert_torch_to_jax():
130
138
  torch_fields.append(TorchField(path=f["path"], shape=shape_tuple))
131
139
 
132
140
  name = request_data["name"]
133
- model, state = convert(jax_fields, torch_fields)
141
+ model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
134
142
  eqx.tree_serialise_leaves(name, (model, state))
135
143
 
136
144
  return flask.jsonify({"status": "success"})
@@ -146,14 +154,27 @@ def index():
146
154
  )
147
155
 
148
156
 
149
- def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
150
- global PYTREE, STATE_DICT
151
- if STATE_DICT is None:
152
- raise ValueError("STATE_DICT must not be None!")
153
- identity = lambda *args, **kwargs: PYTREE
157
+ def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
158
+ jax_fields = pytree_to_fields(pytree)
159
+ torch_fields = state_dict_to_fields(state_dict)
160
+ return convert(jax_fields, torch_fields, pytree, state_dict)
161
+
162
+
163
+ def convert(
164
+ jax_fields: list[JaxField],
165
+ torch_fields: list[TorchField],
166
+ pytree: PyTree,
167
+ state_dict: dict,
168
+ ) -> tuple[PyTree, eqx.nn.State]:
169
+ identity = lambda *args, **kwargs: pytree
154
170
  model, state = eqx.nn.make_with_state(identity)()
155
171
  state_paths: list[tuple[JaxField, TorchField]] = []
156
172
  for jax_field, torch_field in zip(jax_fields, torch_fields):
173
+ if not can_reshape(jax_field.shape, torch_field.shape):
174
+ raise ValueError(
175
+ "Fields have incompatible shapes!"
176
+ f"{jax_field.shape=} != {torch_field.shape=}"
177
+ )
157
178
  path = jax_field.path.split(".")[1:]
158
179
  if "StateIndex" in jax_field.type:
159
180
  state_paths.append((jax_field, torch_field))
@@ -164,7 +185,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
164
185
  model = eqx.tree_at(
165
186
  where,
166
187
  model,
167
- STATE_DICT[torch_field.path],
188
+ state_dict[torch_field.path].reshape(jax_field.shape),
168
189
  )
169
190
  result: dict[str, list[TorchField]] = {}
170
191
  for tuple_item in state_paths:
@@ -177,7 +198,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
177
198
  for key in result:
178
199
  state_index = get_node(model, key.split("."))
179
200
  if state_index is not None:
180
- to_replace_tuple = tuple([STATE_DICT[i.path] for i in result[key]])
201
+ to_replace_tuple = tuple([state_dict[i.path] for i in result[key]])
181
202
  state = state.set(state_index, to_replace_tuple)
182
203
  return model, state
183
204
 
@@ -0,0 +1,147 @@
1
+ Metadata-Version: 2.3
2
+ Name: statedict2pytree
3
+ Version: 0.3.0
4
+ Summary: Converts torch models into PyTrees for Equinox
5
+ Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
+ Requires-Python: ~=3.10
7
+ Requires-Dist: beartype
8
+ Requires-Dist: equinox>=0.11.4
9
+ Requires-Dist: flask
10
+ Requires-Dist: jax
11
+ Requires-Dist: jaxlib
12
+ Requires-Dist: jaxtyping
13
+ Requires-Dist: loguru
14
+ Requires-Dist: penzai
15
+ Requires-Dist: pydantic
16
+ Requires-Dist: torch
17
+ Requires-Dist: typing-extensions
18
+ Provides-Extra: dev
19
+ Requires-Dist: mkdocs; extra == 'dev'
20
+ Requires-Dist: nox; extra == 'dev'
21
+ Requires-Dist: pre-commit; extra == 'dev'
22
+ Requires-Dist: pytest; extra == 'dev'
23
+ Description-Content-Type: text/markdown
24
+
25
+ # statedict2pytree
26
+
27
+ ![statedict2pytree](torch2jax.png "A ResNet demo")
28
+
29
+ The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning their weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
30
+
31
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
32
+
33
+ (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
+
35
+ ## Shape Matching? What's that?
36
+
37
+ Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal, i.e. if one can be reshaped into the other. For example:
38
+
39
+ 1. (8, 1, 1) and (8,) match, because `8 * 1 * 1 = 8`.
40
+
41
+ ## Get Started
42
+
43
+ ### Installation
44
+
45
+ Run
46
+
47
+ ```bash
48
+ pip install statedict2pytree
49
+
50
+ ```
51
+
52
+ ### Basic Example
53
+
54
+ ```python
55
+ import equinox as eqx
56
+ import jax
57
+ import torch
58
+ import statedict2pytree as s2p
59
+
60
+
61
+ def test_mlp():
62
+ in_size = 784
63
+ out_size = 10
64
+ width_size = 64
65
+ depth = 2
66
+ key = jax.random.PRNGKey(22)
67
+
68
+ class EqxMLP(eqx.Module):
69
+ mlp: eqx.nn.MLP
70
+ batch_norm: eqx.nn.BatchNorm
71
+
72
+ def __init__(self, in_size, out_size, width_size, depth, key):
73
+ self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
74
+ self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")
75
+
76
+ def __call__(self, x, state):
77
+ return self.batch_norm(self.mlp(x), state)
78
+
79
+ jax_model = EqxMLP(in_size, out_size, width_size, depth, key)
80
+
81
+ class TorchMLP(torch.nn.Module):
82
+ def __init__(self, in_size, out_size, width_size, depth):
83
+ super(TorchMLP, self).__init__()
84
+ self.layers = torch.nn.ModuleList()
85
+ self.layers.append(torch.nn.Linear(in_size, width_size))
86
+ for _ in range(depth - 1):
87
+ self.layers.append(torch.nn.Linear(width_size, width_size))
88
+ self.layers.append(torch.nn.Linear(width_size, out_size))
89
+ self.batch_norm = torch.nn.BatchNorm1d(out_size)
90
+
91
+ def forward(self, x):
92
+ for layer in self.layers[:-1]:
93
+ x = torch.relu(layer(x))
94
+ x = self.batch_norm(self.layers[-1](x))
95
+ return x
96
+
97
+ torch_model = TorchMLP(in_size, out_size, width_size, depth)
98
+ state_dict = torch_model.state_dict()
99
+ s2p.start_conversion(jax_model, state_dict)
100
+
101
+
102
+ if __name__ == "__main__":
103
+ test_mlp()
104
+
105
+ ```
106
+
107
+ There is also a function called `s2p.convert`, which does the actual conversion:
108
+
109
+ ```python
110
+
111
+ class Field(BaseModel):
112
+ path: str
113
+ shape: tuple[int, ...]
114
+
115
+
116
+ class TorchField(Field):
117
+ pass
118
+
119
+
120
+ class JaxField(Field):
121
+ type: str
122
+
123
+ def convert(
124
+ jax_fields: list[JaxField],
125
+ torch_fields: list[TorchField],
126
+ pytree: PyTree,
127
+ state_dict: dict,
128
+ ):
129
+ ...
130
+ ```
131
+
132
+ If your models already have the right "order", you can use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of matching entries must have shapes with the same number of elements (see "Shape Matching" above); the PyTorch weight is reshaped to the JAX field's shape.
133
+
134
+ For the full, automatic experience, use `autoconvert`:
135
+
136
+ ```python
137
+ import statedict2pytree as s2p
138
+
139
+ my_model = Model(...)
140
+ state_dict = ...
141
+
142
+ model, state = s2p.autoconvert(my_model, state_dict)
143
+
144
+ ```
145
+
146
+ This will however only work if your PyTree fields have been declared
147
+ in the same order as they appear in the state dict!
@@ -0,0 +1,8 @@
1
+ statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
2
+ statedict2pytree/statedict2pytree.py,sha256=X5Ljf4lYhhH7_V4KgdciChncbTt7YZpIWHcOxcZ3l48,7103
3
+ statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
+ statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
+ statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
+ statedict2pytree-0.3.0.dist-info/METADATA,sha256=YSK4tWzNQemyZ1xKq5BhWiLWWc-RDr4E9q_eV_iOsdw,4232
7
+ statedict2pytree-0.3.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
+ statedict2pytree-0.3.0.dist-info/RECORD,,
@@ -1,43 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: statedict2pytree
3
- Version: 0.1.2
4
- Summary: Converts torch models into PyTrees for Equinox
5
- Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
- Requires-Python: ~=3.10
7
- Requires-Dist: beartype
8
- Requires-Dist: equinox>=0.11.4
9
- Requires-Dist: flask
10
- Requires-Dist: jax
11
- Requires-Dist: jaxlib
12
- Requires-Dist: jaxtyping
13
- Requires-Dist: loguru
14
- Requires-Dist: pydantic
15
- Requires-Dist: torch
16
- Requires-Dist: typing-extensions
17
- Provides-Extra: dev
18
- Requires-Dist: mkdocs; extra == 'dev'
19
- Requires-Dist: nox; extra == 'dev'
20
- Requires-Dist: pre-commit; extra == 'dev'
21
- Requires-Dist: pytest; extra == 'dev'
22
- Description-Content-Type: text/markdown
23
-
24
- # statedict2pytree
25
-
26
- ![statedict2pytree](torch2jax.png "A ResNet demo")
27
-
28
- The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side my side and aligning the weights in the right order. Then, all statedict2pytree is doing, is iterating over both lists and matching the weight matrices.
29
-
30
- Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
31
-
32
- (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
33
-
34
- ## Get Started
35
-
36
- ### Installation
37
-
38
- Run
39
-
40
- ```bash
41
- pip install statedict2pytree
42
-
43
- ```
@@ -1,8 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
2
- statedict2pytree/statedict2pytree.py,sha256=LUM19UvNn8R9jau3iYmDLbgfOlznytNEL3J-d-RoSZ0,6455
3
- statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
- statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
- statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
- statedict2pytree-0.1.2.dist-info/METADATA,sha256=X-79GNzLPC9VXRPSVTJao9ysWygmiseBKLd4GmgAY-g,1437
7
- statedict2pytree-0.1.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
- statedict2pytree-0.1.2.dist-info/RECORD,,