statedict2pytree 0.4.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- client/.gitignore +3 -0
- client/package-lock.json +4540 -0
- client/package.json +36 -0
- client/public/bundle.js +9691 -0
- client/public/bundle.js.map +1 -0
- client/public/index.html +14 -0
- {statedict2pytree/static → client/public}/output.css +132 -196
- client/rollup.config.mjs +44 -0
- client/src/App.svelte +343 -0
- client/src/empty.ts +0 -0
- client/src/main.js +8 -0
- client/tailwind.config.js +8 -0
- client/tsconfig.json +5 -0
- statedict2pytree-0.5.2.dist-info/METADATA +64 -0
- statedict2pytree-0.5.2.dist-info/RECORD +17 -0
- statedict2pytree/__init__.py +0 -7
- statedict2pytree/statedict2pytree.py +0 -218
- statedict2pytree/templates/index.html +0 -246
- statedict2pytree-0.4.0.dist-info/METADATA +0 -147
- statedict2pytree-0.4.0.dist-info/RECORD +0 -8
- {statedict2pytree/static → client/public}/input.css +0 -0
- {statedict2pytree-0.4.0.dist-info → statedict2pytree-0.5.2.dist-info}/WHEEL +0 -0
client/src/App.svelte
ADDED
|
@@ -0,0 +1,343 @@
|
|
|
1
|
+
<script lang="ts">
|
|
2
|
+
//@ts-ignore
import Sortable from "sortablejs/modular/sortable.complete.esm.js";
import { onMount, tick } from "svelte";
import Swal from "sweetalert2";
|
|
6
|
+
|
|
7
|
+
// Output file name sent to the backend's /convert endpoint; two-way bound to
// the file-name input in the markup.
let model: string = "model.eqx";
|
|
8
|
+
|
|
9
|
+
/**
 * Shared SweetAlert2 preset for all notifications in this page: a toast in
 * the top-right corner, no confirm button, auto-closing after 5 s with a
 * progress bar. Hovering the toast pauses the countdown.
 */
const Toast = Swal.mixin({
  toast: true,
  position: "top-end",
  showConfirmButton: false,
  timer: 5000,
  timerProgressBar: true,
  didOpen(popup) {
    // Pause while the pointer is over the toast, resume when it leaves.
    popup.onmouseenter = Swal.stopTimer;
    popup.onmouseleave = Swal.resumeTimer;
  },
});
|
|
20
|
+
/**
 * One parameter row, as rendered in either column and sent to the backend.
 *
 * - `path`:  the parameter's path/key ("" for a padding row).
 * - `shape`: the parameter's dimensions. Uses the primitive `number`, not the
 *   boxed `Number` wrapper, so arithmetic such as
 *   `shape.reduce((a, b) => a * b, 1)` type-checks without `@ts-ignore`.
 * - `skip`:  true for a padding ("SKIP") row on the PyTorch side.
 * - `type`:  value of the row's `data-type` attribute; null when absent.
 */
type Field = {
  path: string;
  shape: number[];
  skip: boolean;
  type: string | null;
};
|
|
26
|
+
|
|
27
|
+
// Rows of the JAX column; loaded from /startup/getJaxFields in onMount.
let jaxFields: Field[] = [];
// Rows of the PyTorch column; loaded from /startup/getTorchFields in onMount.
// Padding ("SKIP") rows are inserted/removed by addSkipLayer/removeSkipLayer.
// NOTE(review): SortableJS reorders the rendered rows in the DOM without
// updating this array, so its order can diverge from what is on screen.
let torchFields: Field[] = [];
// SortableJS instance attached to #torch-fields by initSortable; it is not
// otherwise read in this component.
let torchSortable: Sortable;
|
|
30
|
+
// Load both field lists, then make the PyTorch column sortable once Svelte
// has rendered the rows. `tick()` resolves after pending state changes have
// been applied to the DOM, which replaces the previous fixed 100 ms delay.
onMount(async () => {
  try {
    // The two requests are independent, so issue them in parallel.
    const [jaxRes, torchRes] = await Promise.all([
      fetch("/startup/getJaxFields"),
      fetch("/startup/getTorchFields"),
    ]);
    if (!jaxRes.ok || !torchRes.ok) {
      throw new Error(`HTTP ${jaxRes.status} / ${torchRes.status}`);
    }
    jaxFields = (await jaxRes.json()) as Field[];
    torchFields = (await torchRes.json()) as Field[];
  } catch (e: unknown) {
    Toast.fire({
      icon: "error",
      title: "Failed to load fields",
      text: e instanceof Error ? e.message : String(e),
    });
    return;
  }
  await tick();
  initSortable();
});
|
|
39
|
+
|
|
40
|
+
/**
 * Make the PyTorch column drag-and-drop reorderable with SortableJS.
 * Holding shift enables multi-select (multiDragKey). After every drop,
 * `onEnd` re-validates the pairing with the JAX column.
 */
function initSortable() {
  const container = document.getElementById("torch-fields");
  if (!container) {
    // Don't hand a null element to the Sortable constructor; report instead.
    Toast.fire({
      icon: "error",
      title: "Couldn't find PyTorch elements",
    });
    return;
  }
  torchSortable = new Sortable(container, {
    animation: 150,
    multiDrag: true,
    ghostClass: "bg-blue-400",
    selectedClass: "bg-accent",
    multiDragKey: "shift",
    onEnd,
  });
}
|
|
50
|
+
|
|
51
|
+
/**
 * Swap the elements at indices `a` and `b` of `array` in place.
 *
 * @returns the same (mutated) array, for chaining.
 */
function swap<T>(a: number, b: number, array: T[]): T[] {
  [array[a], array[b]] = [array[b], array[a]];
  return array;
}
|
|
57
|
+
|
|
58
|
+
/**
 * Read both columns back out of the DOM, in on-screen order.
 *
 * SortableJS reorders the PyTorch rows by moving DOM nodes, so the DOM (not
 * the `torchFields` array) reflects the current order. Each row's data lives
 * in `data-path` / `data-shape` / `data-type` / `data-skip` attributes on the
 * row's inner element.
 *
 * @returns `{ jaxFields, torchFields }`, or, when `#torch-fields` is missing,
 *   `{ error, jaxFields: [], torchFields: [] }`. The error is NOT shown here:
 *   both callers (onEnd, convert) already toast it, and toasting here as well
 *   produced two identical notifications.
 */
function fetchJaxAndTorchFields() {
  const torchContainer = document.getElementById("torch-fields");
  if (!torchContainer) {
    return {
      error: "Failed to fetch PyTorch elements",
      jaxFields: [] as Field[],
      torchFields: [] as Field[],
    };
  }

  // Decode one element's data-* attributes. `data-shape` holds the shape
  // array as stringified by the markup (comma-joined, e.g. "3,4").
  const readField = (el: Element): Field => ({
    path: el.getAttribute("data-path") ?? "",
    shape: (el.getAttribute("data-shape") ?? "")
      .split(",")
      .map((x) => parseInt(x, 10)),
    type: el.getAttribute("data-type"),
    skip: el.getAttribute("data-skip") === "true",
  });

  // firstElementChild (rather than firstChild) so a text or comment node can
  // never be mistaken for the element carrying the data attributes.
  const newTorchFields: Field[] = [];
  for (const row of Array.from(torchContainer.children)) {
    const dataEl = row.firstElementChild;
    if (dataEl) newTorchFields.push(readField(dataEl));
  }

  const newJaxFields = Array.from(
    document.querySelectorAll('[data-jax="jax"]'),
    readField,
  );

  return { jaxFields: newJaxFields, torchFields: newTorchFields };
}
|
|
106
|
+
|
|
107
|
+
/**
 * Re-check the JAX/PyTorch pairing and update the red `bg-error`
 * highlighting. Called by SortableJS after a drop (see initSortable) and
 * after a padding row is added or removed.
 *
 * Row i of the JAX column is paired with row i of the PyTorch column, in
 * on-screen order. A pair is highlighted when the products of their shapes
 * differ. SKIP rows and PyTorch rows without a JAX partner are cleared. JAX
 * rows without a PyTorch partner are highlighted, matching `checkFields`,
 * which rejects a JAX list longer than the PyTorch list.
 */
function onEnd() {
  const updatedFields = fetchJaxAndTorchFields();
  if (updatedFields.error) {
    Toast.fire({
      icon: "error",
      title: updatedFields.error,
    });
    return;
  }
  const { jaxFields: jax, torchFields: torch } = updatedFields;

  // Resolve the elements to colour by on-screen position. SortableJS moves
  // DOM nodes without renumbering their `torch-<i>` ids, so after a drag
  // getElementById("torch-" + i) can point at a different row than the i-th
  // one that was just compared.
  const jaxEls = Array.from(document.querySelectorAll('[data-jax="jax"]'));
  const torchContainer = document.getElementById("torch-fields");
  const torchEls = torchContainer
    ? Array.from(torchContainer.children, (row) => row.firstElementChild)
    : [];

  const shapeProduct = (shape: readonly unknown[]) =>
    shape.reduce<number>((acc, dim) => acc * Number(dim), 1);
  const setError = (el: Element | null | undefined, isError: boolean) => {
    el?.classList.toggle("bg-error", isError);
  };

  const rowCount = Math.max(jax.length, torch.length);
  for (let i = 0; i < rowCount; i++) {
    const jaxField = jax[i];
    const torchField = torch[i];
    let isError: boolean;
    if (jaxField === undefined) {
      isError = false; // surplus PyTorch rows are ignored by checkFields
    } else if (torchField === undefined) {
      isError = true; // JAX row with nothing to take weights from
    } else if (torchField.skip) {
      isError = false;
    } else {
      isError = shapeProduct(jaxField.shape) !== shapeProduct(torchField.shape);
    }
    setError(jaxEls[i], isError);
    setError(torchEls[i], isError);
  }
}
|
|
163
|
+
/**
 * Validate that every JAX field can be filled from the PyTorch field at the
 * same position before asking the backend to convert.
 *
 * Two fields match when the products of their shapes are equal, so (8, 1, 1)
 * matches (8). PyTorch SKIP rows are not checked, and PyTorch rows beyond
 * the end of the JAX list are ignored.
 *
 * @returns `{ error }` describing the first problem found, or `{ success: true }`.
 */
function checkFields(jaxFields: Field[], torchFields: Field[]) {
  if (jaxFields.length > torchFields.length) {
    return {
      error: `JAX and PyTorch have different lengths (${jaxFields.length} vs ${torchFields.length})! Make sure to pad the PyTorch side.`,
    };
  }

  const shapeProduct = (shape: readonly unknown[]) =>
    shape.reduce<number>((acc, dim) => acc * Number(dim), 1);

  for (let i = 0; i < jaxFields.length; i++) {
    const jaxField = jaxFields[i];
    const torchField = torchFields[i];
    if (torchField.skip) continue;

    if (shapeProduct(jaxField.shape) !== shapeProduct(torchField.shape)) {
      return {
        error: `JAX ${jaxField.path} with shape ${jaxField.shape} doesn't match PyTorch ${torchField.path} with shape ${torchField.shape}`,
      };
    }
  }
  return { success: true };
}
|
|
190
|
+
/**
 * Remove the padding row at `index` from the PyTorch column, then re-run the
 * pairing check once Svelte has re-rendered the list.
 *
 * Uses slice/spread instead of the ES2023-only `Array.prototype.toSpliced`,
 * and `tick()` instead of a fixed 100 ms delay.
 * NOTE(review): `index` is the row's position in `torchFields`; after a
 * SortableJS drag the on-screen order can differ from that array, so confirm
 * the intended row is removed.
 */
async function removeSkipLayer(index: number) {
  torchFields = [...torchFields.slice(0, index), ...torchFields.slice(index + 1)];
  await tick();
  onEnd();
}
|
|
196
|
+
/**
 * Insert a padding ("SKIP") row before position `index` of the PyTorch
 * column, then re-run the pairing check once Svelte has re-rendered the list.
 *
 * Uses slice/spread instead of the ES2023-only `Array.prototype.toSpliced`,
 * and `tick()` instead of a fixed 100 ms delay.
 * NOTE(review): `index` is the row's position in `torchFields`; after a
 * SortableJS drag the on-screen order can differ from that array.
 */
async function addSkipLayer(index: number) {
  // A type annotation (not `as Field`) so a missing property is a compile error.
  const padding: Field = {
    skip: true,
    shape: [],
    path: "",
    type: "",
  };
  torchFields = [...torchFields.slice(0, index), padding, ...torchFields.slice(index)];
  await tick();
  onEnd();
}
|
|
208
|
+
|
|
209
|
+
/**
 * Validate the current pairing and ask the backend (POST /convert) to write
 * the converted model to the file named by `model`. Every outcome, including
 * network failures and error responses without a JSON body, is reported with
 * a toast.
 */
async function convert() {
  const fields = fetchJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      title: fields.error,
    });
    return;
  }

  const check = checkFields(fields.jaxFields, fields.torchFields);
  if (check.error) {
    Toast.fire({
      icon: "error",
      title: "Failed to convert",
      text: check.error,
    });
    return;
  }

  let res: { error?: string };
  try {
    const response = await fetch("/convert", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        model: model,
        jaxFields: fields.jaxFields,
        torchFields: fields.torchFields,
      }),
    });
    // The body may not be JSON (e.g. an HTML error page); don't let a parse
    // failure hide the HTTP status.
    const body: { error?: string } | null = await response
      .json()
      .catch(() => null);
    if (body === null) {
      res = { error: `Unexpected response from server (HTTP ${response.status})` };
    } else if (!response.ok && !body.error) {
      res = { error: `Conversion failed (HTTP ${response.status})` };
    } else {
      res = body;
    }
  } catch (e: unknown) {
    // fetch itself rejects on network failure.
    res = { error: e instanceof Error ? e.message : String(e) };
  }

  if (res.error) {
    Toast.fire({
      icon: "error",
      title: res.error,
    });
  } else {
    Toast.fire({
      icon: "success",
      title: "Conversion successful",
    });
  }
}
|
|
255
|
+
</script>
|
|
256
|
+
|
|
257
|
+
<svelte:head><title>Statedict2PyTree</title></svelte:head>

<h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>

<div class="grid grid-cols-2 gap-x-2">
  <div class="">
    <h2 class="text-2xl">JAX</h2>
    <div id="jax-fields" class="">
      {#each jaxFields as field, i}
        <div
          class="border h-12 rounded-xl flex flex-col justify-center"
        >
          <div
            id={"jax-" + String(i)}
            class="whitespace-nowrap overflow-x-scroll cursor-pointer mx-2"
            data-jax="jax"
            data-path={field.path}
            data-shape={field.shape}
            data-skip={field.skip}
            data-type={field.type}
          >
            {field.path}
            {field.shape}
          </div>
        </div>
      {/each}
    </div>
  </div>

  <div class="">
    <h2 class="text-2xl">PyTorch</h2>
    <div id="torch-fields" class="">
      {#each torchFields as field, i}
        <div class="flex space-x-2 border h-12 rounded-xl">
          <div
            id={"torch-" + String(i)}
            data-torch="torch"
            data-path={field.path}
            data-shape={field.shape}
            data-skip={field.skip}
            data-type={field.type}
            class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
          >
            {#if field.skip}
              SKIP
            {:else}
              {field.path}
              {field.shape}
            {/if}
          </div>
          {#if field.skip}
            <button
              type="button"
              class="btn btn-ghost"
              aria-label="Remove padding row"
              title="Remove padding row"
              on:click={() => {
                removeSkipLayer(i);
              }}>-</button
            >
          {/if}
          <button
            type="button"
            class="btn btn-ghost"
            aria-label="Insert padding row above"
            title="Insert padding row above"
            on:click={() => {
              addSkipLayer(i);
            }}>+</button
          >
        </div>
      {/each}
    </div>
  </div>
</div>
<div class="flex justify-center my-12 w-full">
  <div class="flex flex-col justify-center w-full">
    <input
      id="name"
      type="text"
      name="name"
      class="input input-primary w-full"
      placeholder="Name of the new file (model.eqx per default)"
      aria-label="Output file name"
      bind:value={model}
    />
    <button
      type="button"
      on:click={convert}
      class="btn btn-accent btn-wide btn-lg mx-auto my-2"
    >
      Convert!
    </button>
  </div>
</div>
|
client/src/empty.ts
ADDED
|
File without changes
|
client/src/main.js
ADDED
client/tsconfig.json
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: statedict2pytree
|
|
3
|
+
Version: 0.5.2
|
|
4
|
+
Summary: Converts torch models into PyTrees for Equinox
|
|
5
|
+
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
+
Requires-Python: ~=3.10
|
|
7
|
+
Requires-Dist: beartype
|
|
8
|
+
Requires-Dist: equinox
|
|
9
|
+
Requires-Dist: flask
|
|
10
|
+
Requires-Dist: jax
|
|
11
|
+
Requires-Dist: jaxlib
|
|
12
|
+
Requires-Dist: jaxonmodels
|
|
13
|
+
Requires-Dist: jaxtyping
|
|
14
|
+
Requires-Dist: loguru
|
|
15
|
+
Requires-Dist: penzai
|
|
16
|
+
Requires-Dist: pydantic
|
|
17
|
+
Requires-Dist: torch
|
|
18
|
+
Requires-Dist: torchvision
|
|
19
|
+
Requires-Dist: typing-extensions
|
|
20
|
+
Provides-Extra: dev
|
|
21
|
+
Requires-Dist: mkdocs; extra == 'dev'
|
|
22
|
+
Requires-Dist: nox; extra == 'dev'
|
|
23
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
24
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
25
|
+
Provides-Extra: examples
|
|
26
|
+
Requires-Dist: jaxonmodels; extra == 'examples'
|
|
27
|
+
Description-Content-Type: text/markdown
|
|
28
|
+
|
|
29
|
+
# statedict2pytree
|
|
30
|
+
|
|
31
|
+

|
|
32
|
+
|
|
33
|
+
## Important
|
|
34
|
+
|
|
35
|
+
This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more iterations, it will eventually become stable and well tested.
|
|
36
|
+
PRs and other contributions are *highly* welcome! :)
|
|
37
|
+
|
|
38
|
+
## Info
|
|
39
|
+
|
|
40
|
+
The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning their weights in the right order; statedict2pytree then iterates over both lists and matches the weight matrices pairwise.
|
|
41
|
+
|
|
42
|
+
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
43
|
+
|
|
44
|
+
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
45
|
+
|
|
46
|
+
## Shape Matching? What's that?
|
|
47
|
+
|
|
48
|
+
Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal. For example:
|
|
49
|
+
|
|
50
|
+
(8, 1, 1) and (8,) match, because 8 × 1 × 1 = 8.
|
|
51
|
+
|
|
52
|
+
## Get Started
|
|
53
|
+
|
|
54
|
+
### Installation
|
|
55
|
+
|
|
56
|
+
Run
|
|
57
|
+
|
|
58
|
+
```bash
|
|
59
|
+
pip install statedict2pytree
|
|
60
|
+
```
|
|
61
|
+
|
|
62
|
+
### Docs
|
|
63
|
+
|
|
64
|
+
Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
client/.gitignore,sha256=dUAd7J1wJVSaU9WD0Gp4c7_5yEz5lOdWLfsjubHf1z8,39
|
|
2
|
+
client/package-lock.json,sha256=JCW5mq9bGNs29yGO1EIlcysDaEuTm2i_AJoyt-45yO8,158531
|
|
3
|
+
client/package.json,sha256=Ad-MDEQeh7BPHWPYLd3u9sXk8YVuO_dXmpkxxU1Pglo,1044
|
|
4
|
+
client/rollup.config.mjs,sha256=RAepJhL2V5Rf-BlJBZJxllVl0mxtr67GSVr9aU0JUnA,1073
|
|
5
|
+
client/tailwind.config.js,sha256=TfN5eOoFOUPGBou6OoK54M14PtokgxWDJUsV4qkurS8,175
|
|
6
|
+
client/tsconfig.json,sha256=cLHEFXx-Q55XqbF9QjQ4XScSEQ15n-vS5tsTcqY4UAY,158
|
|
7
|
+
client/public/bundle.js,sha256=l4Vu_8_7v6k7XFisnI8jCFzyMeztr9jc3Z3lrrPDpk0,347197
|
|
8
|
+
client/public/bundle.js.map,sha256=s9zOkP-34BWrQFunVbkouwTKVAfL8EyIUcfKgSicH8M,682783
|
|
9
|
+
client/public/index.html,sha256=jUx-NPKkFN2EF2lj-8Ml49CEHxKJFWK9seszauI4GE0,335
|
|
10
|
+
client/public/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
11
|
+
client/public/output.css,sha256=3iiBiTGfqAeVKuRZRqgcixX3ztSnlp0zqHXkjSKtmVs,38664
|
|
12
|
+
client/src/App.svelte,sha256=hHVoQ_C2xGMmd4d86ftZFeTlGmOcOW_wJ6abq0_qvWo,11170
|
|
13
|
+
client/src/empty.ts,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
+
client/src/main.js,sha256=O_8UgVd1vJM8BcHO7U_6jkL76ZSA6oC7GLLcL9F3JLA,118
|
|
15
|
+
statedict2pytree-0.5.2.dist-info/METADATA,sha256=YgE7bgWDMI6urA71bH-zWleR_mw6SJ5QPneUZVvHL2E,2242
|
|
16
|
+
statedict2pytree-0.5.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
17
|
+
statedict2pytree-0.5.2.dist-info/RECORD,,
|
statedict2pytree/__init__.py
DELETED
|
@@ -1,218 +0,0 @@
|
|
|
1
|
-
import functools as ft
|
|
2
|
-
import re
|
|
3
|
-
|
|
4
|
-
import equinox as eqx
|
|
5
|
-
import flask
|
|
6
|
-
import jax
|
|
7
|
-
import numpy as np
|
|
8
|
-
from beartype.typing import Optional
|
|
9
|
-
from jaxtyping import PyTree
|
|
10
|
-
from loguru import logger
|
|
11
|
-
from penzai import pz
|
|
12
|
-
from pydantic import BaseModel
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
# Flask app serving the alignment page ("/") and the /convert and /visualize
# endpoints; started by start_conversion().
app = flask.Flask(__name__)
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
class Field(BaseModel):
    """A parameter path together with its shape, as exchanged with the UI."""

    path: str
    shape: tuple[int, ...]


class TorchField(Field):
    """A PyTorch state-dict entry; ``path`` is the state-dict key."""


class JaxField(Field):
    """A PyTree leaf.

    ``path`` is the leaf's ``jax.tree_util.keystr`` path and ``type`` is the
    string form of the type of the node that owns the leaf.
    """

    type: str
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
# Module-level state read by the Flask route handlers; both are assigned by
# start_conversion() before the server starts and are None until then.
PYTREE: Optional[PyTree] = None
# State dict whose values start_conversion() has converted with ``.numpy()``.
STATE_DICT: Optional[dict] = None
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
def can_reshape(shape1, shape2):
    """Return whether an array of ``shape1`` can be reshaped to ``shape2``.

    Only the total element counts are compared, so (8, 1, 1) and (8,) are
    compatible. An empty shape counts as one element (``np.prod(()) == 1``).
    """
    return np.prod(shape2) == np.prod(shape1)
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def get_node(
    tree: PyTree, targets: list[str], log_when_not_found: bool = False
) -> PyTree | None:
    """Walk ``tree`` along an attribute path and return the node it reaches.

    Each element of ``targets`` is either a plain attribute name
    (``"weight"``) or an attribute followed by an integer index
    (``"layers[2]"``), resolved as ``getattr(tree, "layers")[2]``. Only the
    last ``[...]`` of a component is treated as the index.

    Args:
        tree: Object to descend into; ``None`` is returned unchanged.
        targets: Remaining path components; an empty list returns ``tree``.
        log_when_not_found: If True, log (info level) every missing attribute,
            at every depth of the walk.

    Returns:
        The node at the end of the path, or ``None`` if an attribute along the
        way does not exist.
    """
    if len(targets) == 0 or tree is None:
        return tree

    next_target: str = targets[0]
    if re.search(r"\[\d+\]", next_target):
        split_index = next_target.rfind("[")
        name = next_target[:split_index]
        index = next_target[split_index + 1 : -1]
        if hasattr(tree, name):
            subtree = getattr(tree, name)[int(index)]
        else:
            subtree = None
            if log_when_not_found:
                logger.info(f"Couldn't find {name} in {tree.__class__}")
    else:
        if hasattr(tree, next_target):
            subtree = getattr(tree, next_target)
        else:
            subtree = None
            if log_when_not_found:
                logger.info(f"Couldn't find {next_target} in {tree.__class__}")

    # Propagate the flag: previously it was dropped here, so misses below the
    # first path component were never logged.
    return get_node(subtree, targets[1:], log_when_not_found)
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def pytree_to_fields(pytree: PyTree) -> list[JaxField]:
    """Describe every leaf of ``pytree`` that has a non-empty ``shape``.

    For each leaf path (``jax.tree_util.keystr``, e.g. ``".layers[0].weight"``)
    the node is looked up again with :func:`get_node`. The recorded ``type`` is
    the type of the node one level above it (its owner).
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(pytree)
    fields: list[JaxField] = []
    for key_path, _leaf in leaves_with_paths:
        path = jax.tree_util.keystr(key_path)
        components = path.split(".")
        owner = get_node(pytree, components[1:-1], log_when_not_found=True)
        node = get_node(pytree, components[1:], log_when_not_found=True)
        if node is None or not hasattr(node, "shape") or len(node.shape) == 0:
            continue
        fields.append(
            JaxField(path=path, type=str(type(owner)), shape=tuple(node.shape))
        )
    return fields
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
def state_dict_to_fields(state_dict: Optional[dict]) -> list[TorchField]:
    """List the state-dict entries that have a non-empty ``shape``, in order.

    Returns an empty list when ``state_dict`` is None.
    """
    if state_dict is None:
        return []
    return [
        TorchField(path=key, shape=tuple(value.shape))
        for key, value in state_dict.items()
        if hasattr(value, "shape") and len(value.shape) > 0
    ]
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
@app.route("/visualize", methods=["POST"])
def visualize_with_penzai():
    """Render the converted ``(model, state)`` and the raw state dict as
    penzai treescope HTML, separated by a horizontal rule.

    Expects the same JSON body as ``/convert`` (``jaxFields`` and
    ``torchFields``). The field dicts are parsed into ``JaxField`` /
    ``TorchField`` objects first: ``convert`` reads ``.path`` / ``.shape`` /
    ``.type`` attributes, so passing the raw dicts raised AttributeError.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    # silent=True: a missing or non-JSON body yields None instead of raising.
    request_data = flask.request.get_json(silent=True)
    if request_data is None:
        return flask.jsonify({"error": "No data received"})

    def parse_shape(raw) -> tuple[int, ...]:
        # Accept "(3, 4)"-style strings (blank pieces such as the one in
        # "(8,)" or "(8, )" are skipped) as well as plain JSON lists.
        if isinstance(raw, (list, tuple)):
            return tuple(int(d) for d in raw)
        return tuple(int(d) for d in str(raw).strip("()").split(",") if d.strip())

    jax_fields = [
        JaxField(path=f["path"], type=f["type"], shape=parse_shape(f["shape"]))
        for f in request_data["jaxFields"]
    ]
    torch_fields = [
        TorchField(path=f["path"], shape=parse_shape(f["shape"]))
        for f in request_data["torchFields"]
    ]
    model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
    with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
        html_jax = pz.ts.render_to_html((model, state))
        html_torch = pz.ts.render_to_html(STATE_DICT)

    combined_html = f"<html><body>{html_jax}<hr>{html_torch}</body></html>"
    return combined_html
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
@app.route("/convert", methods=["POST"])
def convert_torch_to_jax():
    """Convert the stored PyTree/state dict using the client's field order and
    serialise the result with ``eqx.tree_serialise_leaves``.

    Expects a JSON body with:
      * ``jaxFields``: list of ``{"path", "type", "shape"}``
      * ``torchFields``: list of ``{"path", "shape"}``
      * ``name``: output file path
    where ``shape`` is a string such as ``"(3, 4)"`` (a JSON list is accepted
    too).

    Returns JSON ``{"status": "success"}`` or ``{"error": ...}``.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    # silent=True: a missing or non-JSON body yields None instead of raising.
    request_data = flask.request.get_json(silent=True)
    if request_data is None:
        return flask.jsonify({"error": "No data received"})
    # Report missing keys as a JSON error instead of an unhandled KeyError (500).
    missing = [k for k in ("jaxFields", "torchFields", "name") if k not in request_data]
    if missing:
        return flask.jsonify({"error": f"Missing request keys: {', '.join(missing)}"})

    def parse_shape(raw) -> tuple[int, ...]:
        # Blank pieces are skipped: "(8,)" splits into ["8", ""], and "(8, )"
        # into ["8", " "], which previously crashed in int(" ").
        if isinstance(raw, (list, tuple)):
            return tuple(int(d) for d in raw)
        return tuple(int(d) for d in str(raw).strip("()").split(",") if d.strip())

    jax_fields: list[JaxField] = [
        JaxField(path=f["path"], type=f["type"], shape=parse_shape(f["shape"]))
        for f in request_data["jaxFields"]
    ]
    torch_fields: list[TorchField] = [
        TorchField(path=f["path"], shape=parse_shape(f["shape"]))
        for f in request_data["torchFields"]
    ]

    # NOTE(review): `name` is client-supplied and used directly as the output
    # path; acceptable for a local tool, unsafe if this server is ever exposed.
    name = request_data["name"]
    model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
    eqx.tree_serialise_leaves(name, (model, state))

    return flask.jsonify({"status": "success"})
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
@app.route("/", methods=["GET"])
def index():
    """Serve the alignment page listing the PyTree fields and the state-dict fields."""
    context = {
        "pytree_fields": pytree_to_fields(PYTREE),
        "torch_fields": state_dict_to_fields(STATE_DICT),
    }
    return flask.render_template("index.html", **context)
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Convert without the UI, pairing fields purely in declaration order.

    Each state-dict value is converted with ``.numpy()`` into a NEW dict. The
    caller's ``state_dict`` used to be overwritten in place, replacing its
    tensors with NumPy arrays as a hidden side effect.

    Raises:
        ValueError: (from ``convert``) if a positional pair has a different
            number of elements.
    """
    jax_fields = pytree_to_fields(pytree)
    torch_fields = state_dict_to_fields(state_dict)
    numpy_state_dict = {k: v.numpy() for k, v in state_dict.items()}
    return convert(jax_fields, torch_fields, pytree, numpy_state_dict)
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
def convert(
    jax_fields: list[JaxField],
    torch_fields: list[TorchField],
    pytree: PyTree,
    state_dict: dict,
) -> tuple[PyTree, eqx.nn.State]:
    """Copy state-dict arrays into ``pytree`` by pairing fields positionally.

    The i-th JAX field receives ``state_dict[torch_fields[i].path]`` reshaped
    to the JAX field's shape; pairing stops at the shorter list (``zip``).

    Fields whose recorded owner ``type`` contains ``"StateIndex"`` are not
    written into the model. Instead, the arrays of all PyTorch fields mapped
    to the same StateIndex are stored, as a tuple in field order, in the
    returned ``State``.

    Raises:
        ValueError: if a pair's element counts differ.
    """
    # make_with_state expects a constructor; this one just returns `pytree`.
    identity = lambda *args, **kwargs: pytree  # noqa: E731
    model, state = eqx.nn.make_with_state(identity)()

    # "a.b" path of a StateIndex -> PyTorch fields destined for that state slot.
    state_groups: dict[str, list[TorchField]] = {}
    for jax_field, torch_field in zip(jax_fields, torch_fields):
        if not can_reshape(jax_field.shape, torch_field.shape):
            # Explicit separator and field paths: the two literals used to be
            # concatenated as "shapes!jax_field.shape=...".
            raise ValueError(
                "Fields have incompatible shapes! "
                f"{jax_field.path}: {jax_field.shape=} != "
                f"{torch_field.path}: {torch_field.shape=}"
            )
        path = jax_field.path.split(".")[1:]
        if "StateIndex" in jax_field.type:
            state_groups.setdefault(".".join(path[:-1]), []).append(torch_field)
            continue

        where = ft.partial(get_node, targets=path)
        if where(model) is not None:
            model = eqx.tree_at(
                where,
                model,
                state_dict[torch_field.path].reshape(jax_field.shape),
            )

    for key, group in state_groups.items():
        state_index = get_node(model, key.split("."))
        if state_index is not None:
            state = state.set(state_index, tuple(state_dict[f.path] for f in group))
    return model, state
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
def start_conversion(pytree: PyTree, state_dict: dict):
    """Launch the local alignment UI for ``pytree`` and ``state_dict``.

    Stores the PyTree and a NEW dict of ``.numpy()``-converted state-dict
    values in the module globals read by the route handlers. The caller's
    dict used to be overwritten in place. It then runs the Flask development
    server (debug mode, port 5500), which blocks.

    Raises:
        ValueError: if ``state_dict`` is None.
    """
    global PYTREE, STATE_DICT
    if state_dict is None:
        raise ValueError("STATE_DICT must not be None!")
    PYTREE = pytree
    STATE_DICT = {k: v.numpy() for k, v in state_dict.items()}
    app.run(debug=True, port=5500)
|