statedict2pytree 0.1.2__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -95,7 +95,7 @@ def visualize_with_penzai():
95
95
  return flask.jsonify({"error": "No data received"})
96
96
  jax_fields = request_data["jaxFields"]
97
97
  torch_fields = request_data["torchFields"]
98
- model, state = convert(jax_fields, torch_fields)
98
+ model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
99
99
  with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
100
100
  html_jax = pz.ts.render_to_html((model, state))
101
101
  html_torch = pz.ts.render_to_html(STATE_DICT)
@@ -130,7 +130,7 @@ def convert_torch_to_jax():
130
130
  torch_fields.append(TorchField(path=f["path"], shape=shape_tuple))
131
131
 
132
132
  name = request_data["name"]
133
- model, state = convert(jax_fields, torch_fields)
133
+ model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
134
134
  eqx.tree_serialise_leaves(name, (model, state))
135
135
 
136
136
  return flask.jsonify({"status": "success"})
@@ -146,14 +146,21 @@ def index():
146
146
  )
147
147
 
148
148
 
149
- def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
150
- global PYTREE, STATE_DICT
151
- if STATE_DICT is None:
152
- raise ValueError("STATE_DICT must not be None!")
153
- identity = lambda *args, **kwargs: PYTREE
149
+ def convert(
150
+ jax_fields: list[JaxField],
151
+ torch_fields: list[TorchField],
152
+ pytree: PyTree,
153
+ state_dict: dict,
154
+ ):
155
+ identity = lambda *args, **kwargs: pytree
154
156
  model, state = eqx.nn.make_with_state(identity)()
155
157
  state_paths: list[tuple[JaxField, TorchField]] = []
156
158
  for jax_field, torch_field in zip(jax_fields, torch_fields):
159
+ if jax_field.shape != torch_field.shape:
160
+ raise ValueError(
161
+ "Fields have incompatible shapes!"
162
+ f"{jax_field.shape=} != {torch_field.shape=}"
163
+ )
157
164
  path = jax_field.path.split(".")[1:]
158
165
  if "StateIndex" in jax_field.type:
159
166
  state_paths.append((jax_field, torch_field))
@@ -164,7 +171,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
164
171
  model = eqx.tree_at(
165
172
  where,
166
173
  model,
167
- STATE_DICT[torch_field.path],
174
+ state_dict[torch_field.path],
168
175
  )
169
176
  result: dict[str, list[TorchField]] = {}
170
177
  for tuple_item in state_paths:
@@ -177,7 +184,7 @@ def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
177
184
  for key in result:
178
185
  state_index = get_node(model, key.split("."))
179
186
  if state_index is not None:
180
- to_replace_tuple = tuple([STATE_DICT[i.path] for i in result[key]])
187
+ to_replace_tuple = tuple([state_dict[i.path] for i in result[key]])
181
188
  state = state.set(state_index, to_replace_tuple)
182
189
  return model, state
183
190
 
@@ -0,0 +1,126 @@
1
+ Metadata-Version: 2.3
2
+ Name: statedict2pytree
3
+ Version: 0.2.0
4
+ Summary: Converts torch models into PyTrees for Equinox
5
+ Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
+ Requires-Python: ~=3.10
7
+ Requires-Dist: beartype
8
+ Requires-Dist: equinox>=0.11.4
9
+ Requires-Dist: flask
10
+ Requires-Dist: jax
11
+ Requires-Dist: jaxlib
12
+ Requires-Dist: jaxtyping
13
+ Requires-Dist: loguru
14
+ Requires-Dist: penzai
15
+ Requires-Dist: pydantic
16
+ Requires-Dist: torch
17
+ Requires-Dist: typing-extensions
18
+ Provides-Extra: dev
19
+ Requires-Dist: mkdocs; extra == 'dev'
20
+ Requires-Dist: nox; extra == 'dev'
21
+ Requires-Dist: pre-commit; extra == 'dev'
22
+ Requires-Dist: pytest; extra == 'dev'
23
+ Description-Content-Type: text/markdown
24
+
25
+ # statedict2pytree
26
+
27
+ ![statedict2pytree](torch2jax.png "A ResNet demo")
28
+
29
+ The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
30
+
31
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
32
+
33
+ (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
+
35
+ ## Get Started
36
+
37
+ ### Installation
38
+
39
+ Run
40
+
41
+ ```bash
42
+ pip install statedict2pytree
43
+
44
+ ```
45
+
46
+ ### Basic Example
47
+
48
+ ```python
49
+ import equinox as eqx
50
+ import jax
51
+ import torch
52
+ import statedict2pytree as s2p
53
+
54
+
55
+ def test_mlp():
56
+ in_size = 784
57
+ out_size = 10
58
+ width_size = 64
59
+ depth = 2
60
+ key = jax.random.PRNGKey(22)
61
+
62
+ class EqxMLP(eqx.Module):
63
+ mlp: eqx.nn.MLP
64
+ batch_norm: eqx.nn.BatchNorm
65
+
66
+ def __init__(self, in_size, out_size, width_size, depth, key):
67
+ self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
68
+ self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")
69
+
70
+ def __call__(self, x, state):
71
+ return self.batch_norm(self.mlp(x), state)
72
+
73
+ jax_model = EqxMLP(in_size, out_size, width_size, depth, key)
74
+
75
+ class TorchMLP(torch.nn.Module):
76
+ def __init__(self, in_size, out_size, width_size, depth):
77
+ super(TorchMLP, self).__init__()
78
+ self.layers = torch.nn.ModuleList()
79
+ self.layers.append(torch.nn.Linear(in_size, width_size))
80
+ for _ in range(depth - 1):
81
+ self.layers.append(torch.nn.Linear(width_size, width_size))
82
+ self.layers.append(torch.nn.Linear(width_size, out_size))
83
+ self.batch_norm = torch.nn.BatchNorm1d(out_size)
84
+
85
+ def forward(self, x):
86
+ for layer in self.layers[:-1]:
87
+ x = torch.relu(layer(x))
88
+ x = self.batch_norm(self.layers[-1](x))
89
+ return x
90
+
91
+ torch_model = TorchMLP(in_size, out_size, width_size, depth)
92
+ state_dict = torch_model.state_dict()
93
+ s2p.start_conversion(jax_model, state_dict)
94
+
95
+
96
+ if __name__ == "__main__":
97
+ test_mlp()
98
+
99
+ ```
100
+
101
+ There is also a function called `s2p.convert`, which performs the actual conversion:
102
+
103
+ ```python
104
+
105
+ class Field(BaseModel):
106
+ path: str
107
+ shape: tuple[int, ...]
108
+
109
+
110
+ class TorchField(Field):
111
+ pass
112
+
113
+
114
+ class JaxField(Field):
115
+ type: str
116
+
117
+ def convert(
118
+ jax_fields: list[JaxField],
119
+ torch_fields: list[TorchField],
120
+ pytree: PyTree,
121
+ state_dict: dict,
122
+ ):
123
+ ...
124
+ ```
125
+
126
+ If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length and each matching entry must have the same shape!
@@ -1,8 +1,8 @@
1
1
  statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
2
- statedict2pytree/statedict2pytree.py,sha256=LUM19UvNn8R9jau3iYmDLbgfOlznytNEL3J-d-RoSZ0,6455
2
+ statedict2pytree/statedict2pytree.py,sha256=SQ1Xs4VPG_gnjYdAOBzdoNih1RGVsW9x08UQGkIFUdg,6640
3
3
  statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
4
  statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
5
  statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
- statedict2pytree-0.1.2.dist-info/METADATA,sha256=X-79GNzLPC9VXRPSVTJao9ysWygmiseBKLd4GmgAY-g,1437
7
- statedict2pytree-0.1.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
- statedict2pytree-0.1.2.dist-info/RECORD,,
6
+ statedict2pytree-0.2.0.dist-info/METADATA,sha256=x9GhcbG0io5HIpfwocS-nsPYyKRTmG90uLKyckSm394,3673
7
+ statedict2pytree-0.2.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
+ statedict2pytree-0.2.0.dist-info/RECORD,,
@@ -1,43 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: statedict2pytree
3
- Version: 0.1.2
4
- Summary: Converts torch models into PyTrees for Equinox
5
- Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
- Requires-Python: ~=3.10
7
- Requires-Dist: beartype
8
- Requires-Dist: equinox>=0.11.4
9
- Requires-Dist: flask
10
- Requires-Dist: jax
11
- Requires-Dist: jaxlib
12
- Requires-Dist: jaxtyping
13
- Requires-Dist: loguru
14
- Requires-Dist: pydantic
15
- Requires-Dist: torch
16
- Requires-Dist: typing-extensions
17
- Provides-Extra: dev
18
- Requires-Dist: mkdocs; extra == 'dev'
19
- Requires-Dist: nox; extra == 'dev'
20
- Requires-Dist: pre-commit; extra == 'dev'
21
- Requires-Dist: pytest; extra == 'dev'
22
- Description-Content-Type: text/markdown
23
-
24
- # statedict2pytree
25
-
26
- ![statedict2pytree](torch2jax.png "A ResNet demo")
27
-
28
- The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
29
-
30
- Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
31
-
32
- (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
33
-
34
- ## Get Started
35
-
36
- ### Installation
37
-
38
- Run
39
-
40
- ```bash
41
- pip install statedict2pytree
42
-
43
- ```