statedict2pytree 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,308 +0,0 @@
1
- <!doctype html>
2
- <html lang="en" data-theme="light">
3
- <head>
4
- <meta charset="UTF-8" />
5
- <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
- <meta http-equiv="X-UA-Compatible" content="ie=edge" />
7
- <title>Torch2Jax</title>
8
-
9
- <link
10
- rel="stylesheet"
11
- href="{{ url_for('static', filename='output.css') }}"
12
- />
13
- <script src="https://cdn.jsdelivr.net/npm/sweetalert2@11"></script>
14
- <script src="https://cdn.jsdelivr.net/npm/sortablejs@latest/Sortable.min.js"></script>
15
- <script type="module">
16
/**
 * Parse a Python-style shape string from a row's data-shape attribute,
 * e.g. "(8, 1, 1)", "(8,)" or "()", into its dimension sizes.
 *
 * @param {string | null} shapeText - raw data-shape attribute value
 * @returns {number[]} the dimensions; an empty array for a scalar "()"
 */
function parseShape(shapeText) {
  return (shapeText ?? "")
    .replace(/[()\s]/g, "")
    .split(",")
    // "(8,)" leaves a trailing empty part and "()" leaves only an empty part;
    // dropping them keeps NaN out of the product below.
    .filter((part) => part !== "")
    .map((part) => Number.parseInt(part, 10));
}

/**
 * Number of elements described by a shape string (1 for a scalar).
 *
 * @param {string | null} shapeText - raw data-shape attribute value
 * @returns {number} product of all dimensions
 */
function shapeProduct(shapeText) {
  return parseShape(shapeText).reduce((acc, dim) => acc * dim, 1);
}

/**
 * Re-check the row-by-row pairing of the JAX and PyTorch columns after a
 * drag. Rows whose element counts (product of dimensions) differ are marked
 * with "bg-error"; a modal is shown if the two columns differ in length.
 */
function validateAlignment() {
  document.getElementById("error-field").classList.add("hidden");
  const jaxRows = document.querySelectorAll("#jax-fields > div");
  const torchRows = document.querySelectorAll("#torch-fields > div");

  if (jaxRows.length !== torchRows.length) {
    Swal.fire({
      icon: "error",
      title: "The number of fields in JAX and PyTorch should be the same",
    });
    return;
  }

  jaxRows.forEach((jaxRow, i) => {
    const torchRow = torchRows[i];
    const isMismatch =
      shapeProduct(jaxRow.getAttribute("data-shape")) !==
      shapeProduct(torchRow.getAttribute("data-shape"));
    jaxRow.classList.toggle("bg-error", isMismatch);
    torchRow.classList.toggle("bg-error", isMismatch);
  });
}

/**
 * Make one column drag-sortable. Both columns share the same options so that
 * reordering either of them re-validates the pairing.
 *
 * @param {string} containerId - id of the column's container element
 * @returns {Sortable} the SortableJS instance
 */
function makeSortable(containerId) {
  return new Sortable(document.getElementById(containerId), {
    animation: 150,
    ghostClass: "bg-blue-400",
    onEnd: validateAlignment,
  });
}

const jaxSortable = makeSortable("jax-fields");
const torchSortable = makeSortable("torch-fields");
81
- </script>
82
- </head>
83
- <body class="w-10/12 mx-auto">
84
- <script>
85
// Shared SweetAlert2 preset for transient notifications: a toast anchored to
// the top-right corner, without a confirm button, that closes itself after
// 5 seconds (with a progress bar) and pauses that countdown while hovered.
const Toast = Swal.mixin({
  toast: true,
  position: "top-end",
  timer: 5000,
  timerProgressBar: true,
  showConfirmButton: false,
  didOpen(popup) {
    popup.onmouseleave = Swal.resumeTimer;
    popup.onmouseenter = Swal.stopTimer;
  },
});
96
-
97
/**
 * Ask the backend (POST /visualize) to render a Penzai visualization of the
 * current JAX/PyTorch field pairing and show it in #visualizationResult.
 *
 * If getJaxAndTorchFields() reports an error, nothing is sent and the error is
 * shown in a toast. HTTP and network failures are written into the result
 * container as plain text. Never throws.
 */
async function visualize() {
  const fields = getJaxAndTorchFields();
  if (fields.error) {
    // The columns cannot be paired, so there is nothing valid to send.
    Toast.fire({
      icon: "error",
      title: "Failed to visualize!",
      text: fields.error,
    });
    return;
  }
  const { jaxFields, torchFields } = fields;
  const container = document.getElementById("visualizationResult");

  let markup;
  try {
    const response = await fetch("/visualize", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ jaxFields, torchFields }),
    });
    if (!response.ok) {
      // Plain-text status; statusText alone may be empty (e.g. over HTTP/2).
      container.textContent = `Error: ${response.status} ${response.statusText}`;
      return;
    }
    markup = await response.text();
  } catch (err) {
    container.textContent = `Error: ${err.message}`;
    return;
  }

  // NOTE(review): assumes /visualize is this app's own trusted backend, since
  // its HTML (including scripts) is injected and executed below.
  container.innerHTML = markup;
  // <script> elements inserted via innerHTML never run, so copy each one into
  // a fresh element; appending it to <head> executes it, then it is removed.
  for (const inertScript of Array.from(container.getElementsByTagName("script"))) {
    const liveScript = document.createElement("script");
    liveScript.text = inertScript.text;
    document.head.appendChild(liveScript);
    document.head.removeChild(liveScript);
  }
}
131
-
132
/**
 * Read the JAX and PyTorch columns in their current (possibly user-reordered)
 * order and check that they can be paired row by row.
 *
 * Pairing rule: both columns must have the same number of rows, and each pair
 * must have identical data-shape strings. On the first shape mismatch, a toast
 * is shown and the #error-field alert is filled in. When everything lines up,
 * that alert is hidden again.
 *
 * @returns {{jaxFields: Array<{path: string, shape: string, type: string}>,
 *            torchFields: Array<{path: string, shape: string}>}
 *           | {error: string}}
 *   The paired field lists, or an object carrying an error message.
 */
function getJaxAndTorchFields() {
  const rowsOf = (containerId) =>
    Array.from(document.getElementById(containerId).children);

  const jaxFields = rowsOf("jax-fields").map((row) => ({
    path: row.getAttribute("data-path"),
    shape: row.getAttribute("data-shape"),
    type: row.getAttribute("data-type"),
  }));
  const torchFields = rowsOf("torch-fields").map((row) => ({
    path: row.getAttribute("data-path"),
    shape: row.getAttribute("data-shape"),
  }));

  if (jaxFields.length !== torchFields.length) {
    return {
      error: "The number of fields in JAX and PyTorch should be the same",
    };
  }

  const errorField = document.getElementById("error-field");
  const mismatchIndex = jaxFields.findIndex(
    (jaxField, i) => jaxField.shape !== torchFields[i].shape,
  );
  if (mismatchIndex !== -1) {
    const jaxField = jaxFields[mismatchIndex];
    const torchField = torchFields[mismatchIndex];
    const message = `${jaxField.path} has shape ${jaxField.shape}, while ${torchField.path} has shape ${torchField.shape}`;
    Toast.fire({ icon: "error", title: message });
    errorField.classList.remove("hidden");
    errorField.querySelector("span").textContent = message;
    return { error: "Invalid shapes" };
  }

  // Clear an alert left over from an earlier failed attempt.
  errorField.classList.add("hidden");
  return { jaxFields, torchFields };
}
180
-
181
/**
 * Validate the current field pairing and ask the backend (POST /convert) to
 * write the converted model to the file named in the #name input.
 *
 * Aborts with an error toast if validation fails or the name input is
 * missing. An empty name falls back to "model.eqx", as the input's
 * placeholder promises. The outcome is reported via toasts; never throws.
 */
async function convert() {
  const fields = getJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      title: "Failed to convert!",
      text: fields.error,
    });
    return;
  }
  const { jaxFields, torchFields } = fields;

  const idField = document.getElementById("name");
  if (!idField) {
    Toast.fire({
      icon: "error",
      title: "Error finding the name!",
    });
    return;
  }
  const name = idField.value.trim() || "model.eqx";

  let result;
  try {
    const response = await fetch("/convert", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ jaxFields, torchFields, name }),
    });
    if (response.ok) {
      result = await response.json();
    } else {
      // Error pages are not necessarily JSON: use the server's message if it
      // sent one, otherwise the HTTP status line.
      const body = await response.json().catch(() => null);
      result = {
        error: body?.error ?? `Server responded with ${response.status} ${response.statusText}`,
      };
    }
  } catch (err) {
    result = { error: `Conversion request failed: ${err.message}` };
  }

  if (result?.error) {
    Toast.fire({
      icon: "error",
      title: result.error,
    });
  } else {
    Toast.fire({
      icon: "success",
      title: "Conversion successful",
    });
  }
}
229
- </script>
230
- <h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
231
-
232
- <div class="grid grid-cols-2 gap-x-2">
233
- <div class="">
234
- <h2 class="text-2xl">JAX</h2>
235
- <div id="jax-fields" class="bg-base-200">
236
- {% for field in pytree_fields %}
237
- <div
238
- data-path="{{field.path}}"
239
- data-shape="{{field.shape}}"
240
- data-type="{{field.type}}"
241
- class="whitespace-nowrap overflow-x-scroll cursor-pointer"
242
- >
243
- {{ field.path }} {{ field.shape }}
244
- </div>
245
- {% endfor %}
246
- </div>
247
- </div>
248
-
249
- <div class="">
250
- <h2 class="text-2xl">PyTorch</h2>
251
- <div id="torch-fields" class="bg-base-200">
252
- {% for field in torch_fields %}
253
- <div
254
- data-path="{{field.path}}"
255
- data-shape="{{field.shape}}"
256
- class="whitespace-nowrap overflow-x-scroll cursor-pointer"
257
- >
258
- {{ field.path }} {{ field.shape }}
259
- </div>
260
- {% endfor %}
261
- </div>
262
- </div>
263
- </div>
264
- <div class="flex justify-center my-12 w-full">
265
- <div class="flex flex-col justify-center w-full">
266
- <input
267
- id="name"
268
- type="text"
269
- name="name"
270
- class="input input-primary w-full"
271
- placeholder="Name of the new file (model.eqx per default)"
272
- value="model.eqx"
273
- />
274
- <button
275
- onclick="convert()"
276
- class="btn btn-accent btn-wide btn-lg mx-auto my-2"
277
- >
278
- Convert!
279
- </button>
280
- </div>
281
- </div>
282
- <div role="alert" class="alert alert-error hidden" id="error-field">
283
- <svg
284
- xmlns="http://www.w3.org/2000/svg"
285
- class="stroke-current shrink-0 h-6 w-6"
286
- fill="none"
287
- viewBox="0 0 24 24"
288
- >
289
- <path
290
- stroke-linecap="round"
291
- stroke-linejoin="round"
292
- stroke-width="2"
293
- d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z"
294
- />
295
- </svg>
296
- <span></span>
297
- </div>
298
-
299
- <div class="flex justify-center">
300
- <button onclick="visualize()" class="btn btn-secondary">
301
- Visualize with Penzai!
302
- </button>
303
- </div>
304
-
305
- <hr />
306
- <div id="visualizationResult"></div>
307
- </body>
308
- </html>
@@ -1,147 +0,0 @@
1
- Metadata-Version: 2.3
2
- Name: statedict2pytree
3
- Version: 0.5.0
4
- Summary: Converts torch models into PyTrees for Equinox
5
- Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
- Requires-Python: ~=3.10
7
- Requires-Dist: beartype
8
- Requires-Dist: equinox>=0.11.4
9
- Requires-Dist: flask
10
- Requires-Dist: jax
11
- Requires-Dist: jaxlib
12
- Requires-Dist: jaxtyping
13
- Requires-Dist: loguru
14
- Requires-Dist: penzai
15
- Requires-Dist: pydantic
16
- Requires-Dist: torch
17
- Requires-Dist: typing-extensions
18
- Provides-Extra: dev
19
- Requires-Dist: mkdocs; extra == 'dev'
20
- Requires-Dist: nox; extra == 'dev'
21
- Requires-Dist: pre-commit; extra == 'dev'
22
- Requires-Dist: pytest; extra == 'dev'
23
- Description-Content-Type: text/markdown
24
-
25
- # statedict2pytree
26
-
27
- ![statedict2pytree](torch2jax.png "A ResNet demo")
28
-
29
- The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning their weights in the right order. Then all statedict2pytree does is iterate over both lists and match up the weight matrices.
30
-
31
- Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
32
-
33
- (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
-
35
- ## Shape Matching? What's that?
36
-
37
- Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal. For example:
38
-
39
- 1. (8, 1, 1) and (8,) match, because 8 × 1 × 1 = 8.
40
-
41
- ## Get Started
42
-
43
- ### Installation
44
-
45
- Run
46
-
47
- ```bash
48
- pip install statedict2pytree
49
-
50
- ```
51
-
52
- ### Basic Example
53
-
54
- ```python
55
- import equinox as eqx
56
- import jax
57
- import torch
58
- import statedict2pytree as s2p
59
-
60
-
61
- def test_mlp():
62
- in_size = 784
63
- out_size = 10
64
- width_size = 64
65
- depth = 2
66
- key = jax.random.PRNGKey(22)
67
-
68
- class EqxMLP(eqx.Module):
69
- mlp: eqx.nn.MLP
70
- batch_norm: eqx.nn.BatchNorm
71
-
72
- def __init__(self, in_size, out_size, width_size, depth, key):
73
- self.mlp = eqx.nn.MLP(in_size, out_size, width_size, depth, key=key)
74
- self.batch_norm = eqx.nn.BatchNorm(out_size, axis_name="batch")
75
-
76
- def __call__(self, x, state):
77
- return self.batch_norm(self.mlp(x), state)
78
-
79
- jax_model = EqxMLP(in_size, out_size, width_size, depth, key)
80
-
81
- class TorchMLP(torch.nn.Module):
82
- def __init__(self, in_size, out_size, width_size, depth):
83
- super(TorchMLP, self).__init__()
84
- self.layers = torch.nn.ModuleList()
85
- self.layers.append(torch.nn.Linear(in_size, width_size))
86
- for _ in range(depth - 1):
87
- self.layers.append(torch.nn.Linear(width_size, width_size))
88
- self.layers.append(torch.nn.Linear(width_size, out_size))
89
- self.batch_norm = torch.nn.BatchNorm1d(out_size)
90
-
91
- def forward(self, x):
92
- for layer in self.layers[:-1]:
93
- x = torch.relu(layer(x))
94
- x = self.batch_norm(self.layers[-1](x))
95
- return x
96
-
97
- torch_model = TorchMLP(in_size, out_size, width_size, depth)
98
- state_dict = torch_model.state_dict()
99
- s2p.start_conversion(jax_model, state_dict)
100
-
101
-
102
- if __name__ == "__main__":
103
- test_mlp()
104
-
105
- ```
106
-
107
- There is also a function called `s2p.convert`, which does the actual conversion:
108
-
109
- ```python
110
-
111
- class Field(BaseModel):
112
- path: str
113
- shape: tuple[int, ...]
114
-
115
-
116
- class TorchField(Field):
117
- pass
118
-
119
-
120
- class JaxField(Field):
121
- type: str
122
-
123
- def convert(
124
- jax_fields: list[JaxField],
125
- torch_fields: list[TorchField],
126
- pytree: PyTree,
127
- state_dict: dict,
128
- ):
129
- ...
130
- ```
131
-
132
- If your models already have the right "order", you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of corresponding entries must have the same shape!
133
-
134
- For the full, automatic experience, use `autoconvert`:
135
-
136
- ```python
137
- import statedict2pytree as s2p
138
-
139
- my_model = Model(...)
140
- state_dict = ...
141
-
142
- model, state = s2p.autoconvert(my_model, state_dict)
143
-
144
- ```
145
-
146
- This will however only work if your PyTree fields have been declared
147
- in the same order as they appear in the state dict!
@@ -1,8 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
2
- statedict2pytree/statedict2pytree.py,sha256=yLOWx1D-6tX1VjiEg_-JcYPTrg6KWAgw6waZQi1GNvA,7229
3
- statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
- statedict2pytree/static/output.css,sha256=B0itthSyy_tduTWMyTK5sAry-W6WbeODnpQ-oOcQQng,33966
5
- statedict2pytree/templates/index.html,sha256=Mbo8fFHV6kYRiBiiwayku-p-y3hUaLw_Yj3zn_cfmb0,10027
6
- statedict2pytree-0.5.0.dist-info/METADATA,sha256=TOf10T0EZoPGAc0qSltZRcr8Ni7y4bHW7w3wzRVJH7A,4232
7
- statedict2pytree-0.5.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
- statedict2pytree-0.5.0.dist-info/RECORD,,
File without changes