statedict2pytree 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,246 @@
1
+ <!doctype html>
2
+ <html lang="en" data-theme="light">
3
+ <head>
4
+ <meta charset="UTF-8" />
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
6
+ <meta http-equiv="X-UA-Compatible" content="ie=edge" />
7
+ <title>Torch2Jax</title>
8
+
9
+ <link
10
+ rel="stylesheet"
11
+ href="{{ url_for('static', filename='output.css') }}"
12
+ />
13
+ <script src="https://cdn.jsdelivr.net/npm/sweetalert2@11"></script>
14
+ <script type="module">
15
+ import {
16
+ Draggable,
17
+ Sortable,
18
+ Droppable,
19
+ Swappable,
20
+ Plugins,
21
+ } from "https://cdn.jsdelivr.net/npm/@shopify/draggable/build/esm/index.mjs";
22
// Make both field lists (the JAX list and the PyTorch list rendered below) reorderable
// so the user can line up corresponding parameters by position before converting.
const sortableLists = document.querySelectorAll(".draggable-list");

for (const sortableList of sortableLists) {
  // The SortAnimation plugin is configured through the `sortAnimation` option.
  // The previously used `swapAnimation` key belongs to the SwapAnimation plugin
  // and was ignored, so the duration/easing below never took effect.
  // The instance is not stored: nothing else in the page references it.
  new Sortable(sortableList, {
    draggable: "li",
    plugins: [Plugins.SortAnimation],
    sortAnimation: {
      duration: 200,
      easingFunction: "ease-in-out",
    },
  });
}
34
+ </script>
35
+ </head>
36
+ <body class="w-10/12 mx-auto">
37
+ <script>
38
// Shared SweetAlert2 toast preset used for every notification on the page:
// shown in the top-right corner without a confirm button, closed automatically
// after 5000 ms with a visible progress bar.
const Toast = Swal.mixin({
  toast: true,
  position: "top-end",
  showConfirmButton: false,
  timer: 5000,
  timerProgressBar: true,
  didOpen(popup) {
    // Pause the auto-close countdown while the pointer is over the toast and
    // resume it when the pointer leaves.
    Object.assign(popup, {
      onmouseenter: Swal.stopTimer,
      onmouseleave: Swal.resumeTimer,
    });
  },
});
49
+
50
/**
 * Send the current JAX/PyTorch field pairing to the server's `/visualize`
 * endpoint and render the returned HTML into `#visualizationResult`.
 *
 * Validation problems from getJaxAndTorchFields() are shown as a toast and no
 * request is sent. HTTP and network failures are written into the result
 * container as plain text.
 *
 * @returns {Promise<void>}
 */
async function visualize() {
  const fields = getJaxAndTorchFields();
  if (fields.error) {
    // Report the actual problem (count or shape mismatch) and stop: sending the
    // request with undefined fields can only produce a server error.
    Toast.fire({
      icon: "error",
      title: fields.error,
    });
    return;
  }
  const { jaxFields, torchFields } = fields;
  const container = document.getElementById("visualizationResult");

  let html;
  try {
    const response = await fetch("/visualize", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ jaxFields, torchFields }),
    });
    if (!response.ok) {
      // textContent, not innerHTML: the status text is not markup.
      container.textContent = `Error: ${response.statusText}`;
      return;
    }
    html = await response.text();
  } catch (err) {
    // Network failure or aborted body read.
    container.textContent = `Error: ${err.message}`;
    return;
  }

  // The response is the server-rendered visualization and is inserted as
  // trusted markup.
  container.innerHTML = html;

  // Scripts inserted via innerHTML are not executed by the browser, so each
  // inline script is re-created as a fresh element; appending it runs it, after
  // which it is removed again from <head>.
  for (const original of Array.from(container.getElementsByTagName("script"))) {
    const script = document.createElement("script");
    script.text = original.text;
    document.head.appendChild(script);
    script.remove();
  }
}
86
+
87
/**
 * Read the current (user-arranged) order of fields from both draggable lists
 * and check that they can be paired one-to-one by position.
 *
 * The first `.draggable-list` on the page is the JAX list, the second the
 * PyTorch list; each `<li>` carries its metadata in `data-*` attributes.
 *
 * @returns {{jaxFields: Array<{path: ?string, shape: ?string, type: ?string}>,
 *            torchFields: Array<{path: ?string, shape: ?string}>}
 *          | {error: string}}
 *   The paired field lists, or an object whose `error` describes the problem.
 *   A shape mismatch is additionally shown as a toast from here.
 */
function getJaxAndTorchFields() {
  const [jaxList, torchList] = document.querySelectorAll(".draggable-list");
  if (!jaxList || !torchList) {
    return { error: "Could not find the JAX and PyTorch field lists" };
  }

  const jaxFields = Array.from(jaxList.children, (li) => ({
    path: li.getAttribute("data-path"),
    shape: li.getAttribute("data-shape"),
    type: li.getAttribute("data-type"),
  }));
  const torchFields = Array.from(torchList.children, (li) => ({
    path: li.getAttribute("data-path"),
    shape: li.getAttribute("data-shape"),
  }));

  if (jaxFields.length !== torchFields.length) {
    return {
      error: "The number of fields in JAX and PyTorch should be the same",
    };
  }

  // Fields are matched purely by position, so every pair must agree on shape.
  // Shapes are compared as the strings rendered into the data-shape attributes.
  const mismatchIndex = jaxFields.findIndex(
    (jaxField, i) => jaxField.shape !== torchFields[i].shape,
  );
  if (mismatchIndex !== -1) {
    const jaxField = jaxFields[mismatchIndex];
    const torchField = torchFields[mismatchIndex];
    const message = `${jaxField.path} has shape ${jaxField.shape}, while ${torchField.path} has shape ${torchField.shape}`;
    // Toast here too, so the detailed message is shown even by callers that
    // only check whether `error` is set.
    Toast.fire({
      icon: "error",
      title: message,
    });
    return { error: message };
  }

  return { jaxFields, torchFields };
}
130
+
131
/**
 * Ask the server's `/convert` endpoint to build the PyTree from the current
 * field pairing and save it under the file name entered in `#name`.
 *
 * Every outcome is reported to the user as a toast:
 * - validation failure;
 * - missing name input;
 * - network/HTTP error;
 * - server-reported `error`;
 * - success.
 *
 * @returns {Promise<void>}
 */
async function convert() {
  const fields = getJaxAndTorchFields();
  if (fields.error) {
    // Stop here: the request would otherwise carry undefined fields.
    Toast.fire({
      icon: "error",
      title: fields.error,
    });
    return;
  }
  const { jaxFields, torchFields } = fields;

  const idField = document.getElementById("name");
  if (!idField) {
    Toast.fire({
      icon: "error",
      title: "Error finding the name!",
    });
    return;
  }
  // The input's placeholder advertises model.eqx as the default file name.
  const name = idField.value.trim() || "model.eqx";

  let response;
  let res;
  try {
    response = await fetch("/convert", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        jaxFields,
        torchFields,
        name,
      }),
    });
    res = await response.json();
  } catch (err) {
    // Network failure or a non-JSON response body.
    Toast.fire({
      icon: "error",
      title: `Conversion request failed: ${err.message}`,
    });
    return;
  }

  // A non-2xx status is a failure even when the body carries no `error` key.
  if (!response.ok || res?.error) {
    Toast.fire({
      icon: "error",
      title: res?.error || `Conversion failed (HTTP ${response.status})`,
    });
    return;
  }

  Toast.fire({
    icon: "success",
    title: "Conversion successful",
  });
}
178
+ </script>
179
+ <h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
180
+
181
+ <div class="flex space-x-4">
182
+ <div class="w-full">
183
+ <h2 class="text-2xl">JAX</h2>
184
+ <ul class="draggable-list menu bg-base-200 rounded-box">
185
+ {% for field in pytree_fields %}
186
+ <li
187
+ draggable="true"
188
+ data-path="{{field.path}}"
189
+ data-shape="{{field.shape}}"
190
+ data-type="{{field.type}}"
191
+ >
192
+ <p
193
+ class="tooltip text-left"
194
+ data-path="{{field.path}}"
195
+ data-tip="{{field.type}}"
196
+ >
197
+ {{ field.path}} {{field.shape }}
198
+ </p>
199
+ </li>
200
+ {% endfor %}
201
+ </ul>
202
+ </div>
203
+
204
+ <div class="w-full">
205
+ <h2 class="text-2xl">PyTorch</h2>
206
+ <ul class="draggable-list menu bg-base-200 rounded-box w-full">
207
+ {% for field in torch_fields %}
208
+ <li
209
+ draggable="true"
210
+ data-path="{{field.path}}"
211
+ data-shape="{{field.shape}}"
212
+ >
213
+ <p class="text-left">{{ field.path }} {{ field.shape }}</p>
214
+ </li>
215
+ {% endfor %}
216
+ </ul>
217
+ </div>
218
+ </div>
219
+ <div class="flex justify-center my-12 w-full">
220
+ <div class="flex flex-col justify-center w-full">
221
+ <input
222
+ id="name"
223
+ type="text"
224
+ name="name"
225
+ class="input input-primary w-full"
226
+ placeholder="Name of the new file (model.eqx per default)"
227
+ value="model.eqx"
228
+ />
229
+ <button
230
+ onclick="convert()"
231
+ class="btn btn-accent btn-wide btn-lg mx-auto my-2"
232
+ >
233
+ Convert!
234
+ </button>
235
+ </div>
236
+ </div>
237
+ <div class="flex justify-center">
238
+ <button onclick="visualize()" class="btn btn-secondary">
239
+ Visualize with Penzai!
240
+ </button>
241
+ </div>
242
+
243
+ <hr />
244
+ <div id="visualizationResult"></div>
245
+ </body>
246
+ </html>
@@ -0,0 +1,43 @@
1
+ Metadata-Version: 2.3
2
+ Name: statedict2pytree
3
+ Version: 0.1.2
4
+ Summary: Converts torch models into PyTrees for Equinox
5
+ Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
+ Requires-Python: ~=3.10
7
+ Requires-Dist: beartype
8
+ Requires-Dist: equinox>=0.11.4
9
+ Requires-Dist: flask
10
+ Requires-Dist: jax
11
+ Requires-Dist: jaxlib
12
+ Requires-Dist: jaxtyping
13
+ Requires-Dist: loguru
14
+ Requires-Dist: pydantic
15
+ Requires-Dist: torch
16
+ Requires-Dist: typing-extensions
17
+ Provides-Extra: dev
18
+ Requires-Dist: mkdocs; extra == 'dev'
19
+ Requires-Dist: nox; extra == 'dev'
20
+ Requires-Dist: pre-commit; extra == 'dev'
21
+ Requires-Dist: pytest; extra == 'dev'
22
+ Description-Content-Type: text/markdown
23
+
24
+ # statedict2pytree
25
+
26
+ ![statedict2pytree](torch2jax.png "A ResNet demo")
27
+
28
+ The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
29
+
30
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
31
+
32
+ (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
33
+
34
+ ## Get Started
35
+
36
+ ### Installation
37
+
38
+ Run
39
+
40
+ ```bash
41
+ pip install statedict2pytree
42
+
43
+ ```
@@ -0,0 +1,8 @@
1
+ statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
2
+ statedict2pytree/statedict2pytree.py,sha256=LUM19UvNn8R9jau3iYmDLbgfOlznytNEL3J-d-RoSZ0,6455
3
+ statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
+ statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
+ statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
+ statedict2pytree-0.1.2.dist-info/METADATA,sha256=X-79GNzLPC9VXRPSVTJao9ysWygmiseBKLd4GmgAY-g,1437
7
+ statedict2pytree-0.1.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
+ statedict2pytree-0.1.2.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.24.2
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any