statedict2pytree 0.6.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
client/rollup.config.mjs DELETED
@@ -1,44 +0,0 @@
1
- import svelte from "rollup-plugin-svelte";
2
- import { nodeResolve } from "@rollup/plugin-node-resolve";
3
- import commonjs from "@rollup/plugin-commonjs";
4
- import livereload from "rollup-plugin-livereload";
5
- import terser from "@rollup/plugin-terser";
6
- import { sveltePreprocess } from "svelte-preprocess";
7
- import typescript from "@rollup/plugin-typescript";
8
-
9
- const production = !process.env.ROLLUP_WATCH;
10
-
11
- export default {
12
- input: "src/main.js",
13
- output: {
14
- sourcemap: true,
15
- format: "iife",
16
- name: "app",
17
- file: "public/bundle.js",
18
- },
19
- plugins: [
20
- svelte({
21
- compilerOptions: {
22
- dev: !production,
23
- css: (css) => {
24
- css.write("public/bundle.css");
25
- },
26
- },
27
- preprocess: sveltePreprocess(),
28
- }),
29
-
30
- typescript({ sourceMap: !production }),
31
- nodeResolve({
32
- browser: true,
33
- dedupe: (importee) =>
34
- importee === "svelte" || importee.startsWith("svelte/"),
35
- }),
36
- commonjs(),
37
-
38
- !production && livereload("public"),
39
- production && terser(),
40
- ],
41
- watch: {
42
- clearScreen: false,
43
- },
44
- };
client/src/App.svelte DELETED
@@ -1,584 +0,0 @@
1
- <script lang="ts">
2
- //@ts-ignore
3
- import Sortable from "sortablejs/modular/sortable.complete.esm.js";
4
- import { onMount } from "svelte";
5
- import Swal from "sweetalert2";
6
-
7
- let model: string = "model.eqx";
8
- let anthropicModel: "opus" | "sonnet" | "sonnet3.5" | "haiku" = "haiku";
9
-
10
- const Toast = Swal.mixin({
11
- toast: true,
12
- position: "top-end",
13
- showConfirmButton: false,
14
- timer: 5000,
15
- timerProgressBar: true,
16
- didOpen: (toast) => {
17
- toast.onmouseenter = Swal.stopTimer;
18
- toast.onmouseleave = Swal.resumeTimer;
19
- },
20
- });
21
- type Field = {
22
- path: string;
23
- shape: Number[];
24
- skip: boolean;
25
- type: string | null;
26
- };
27
-
28
- let jaxFields: Field[] = [];
29
- let torchFields: Field[] = [];
30
- onMount(async () => {
31
- let req = await fetch("/startup/getJaxFields");
32
- jaxFields = (await req.json()) as Field[];
33
- req = await fetch("/startup/getTorchFields");
34
- torchFields = await req.json();
35
- setTimeout(() => {
36
- initSortable();
37
- }, 100);
38
-
39
- setTimeout(() => {
40
- onEnd();
41
- }, 500);
42
- });
43
-
44
- function initSortable() {
45
- new Sortable(document.getElementById("torch-fields"), {
46
- animation: 150,
47
- multiDrag: true,
48
- ghostClass: "bg-blue-400",
49
- selectedClass: "bg-accent",
50
- multiDragKey: "shift",
51
- onEnd: onEnd,
52
- });
53
- }
54
-
55
- function fetchJaxAndTorchFields() {
56
- let allTorchElements =
57
- document.getElementById("torch-fields")?.children;
58
- if (!allTorchElements) {
59
- Toast.fire({
60
- icon: "error",
61
- title: "Couldn't find PyTorch elements",
62
- });
63
- return {
64
- error: "Failed to fetch PyTorch elements",
65
- jaxFields: [],
66
- torchFields: [],
67
- };
68
- }
69
- let allTorchFields: HTMLElement[] = [];
70
- for (let i = 0; i < allTorchElements.length; i++) {
71
- allTorchFields.push(allTorchElements[i].firstChild as HTMLElement);
72
- }
73
-
74
- let newTorchFields: Field[] = [];
75
- allTorchFields.forEach((el) => {
76
- newTorchFields.push({
77
- path: el.getAttribute("data-path"),
78
- shape: el
79
- .getAttribute("data-shape")
80
- ?.split(",")
81
- .map((x) => parseInt(x)),
82
- type: el.getAttribute("data-type"),
83
- skip: el.getAttribute("data-skip") === "true",
84
- } as Field);
85
- });
86
-
87
- const allJaxFields = document.querySelectorAll('[data-jax="jax"]');
88
- let newJaxFields: Field[] = [];
89
- allJaxFields.forEach((el) => {
90
- newJaxFields.push({
91
- path: el.getAttribute("data-path"),
92
- shape: el
93
- .getAttribute("data-shape")
94
- ?.split(",")
95
- .map((x) => parseInt(x)),
96
- type: el.getAttribute("data-type"),
97
- skip: el.getAttribute("data-skip") === "true",
98
- } as Field);
99
- });
100
-
101
- return { jaxFields: newJaxFields, torchFields: newTorchFields };
102
- }
103
-
104
- function onEnd() {
105
- setTimeout(() => {
106
- const updatedFields = fetchJaxAndTorchFields();
107
- if (updatedFields.error) {
108
- Toast.fire({
109
- icon: "error",
110
- title: updatedFields.error,
111
- });
112
- return;
113
- }
114
-
115
- for (let i = 0; i < updatedFields.jaxFields.length; i++) {
116
- let jaxField = updatedFields.jaxFields[i];
117
- let torchField = updatedFields.torchFields[i];
118
- if (torchField === undefined) continue;
119
- if (torchField.skip === true) {
120
- document
121
- .getElementById("jax-" + i)
122
- ?.classList.remove("bg-error");
123
- continue;
124
- }
125
- let jaxShape = jaxField.shape;
126
- let torchShape = torchField.shape;
127
- //@ts-ignore
128
- let jaxShapeProduct = jaxShape.reduce((a, b) => a * b, 1);
129
- //@ts-ignore
130
- let torchShapeProduct = torchShape.reduce((a, b) => a * b, 1);
131
- if (jaxShapeProduct !== torchShapeProduct) {
132
- document
133
- .getElementById("jax-" + i)
134
- ?.classList.add("bg-error");
135
- } else {
136
- document
137
- .getElementById("jax-" + i)
138
- ?.classList.remove("bg-error");
139
- }
140
- }
141
-
142
- if (
143
- updatedFields.torchFields.length >
144
- updatedFields.jaxFields.length
145
- ) {
146
- for (
147
- let i = updatedFields.jaxFields.length;
148
- i < updatedFields.torchFields.length;
149
- i++
150
- ) {
151
- document
152
- .getElementById("torch-" + i)
153
- ?.classList.remove("bg-error");
154
- }
155
- }
156
- }, 100);
157
- }
158
- function checkFields(jaxFields: Field[], torchFields: Field[]) {
159
- if (jaxFields.length > torchFields.length) {
160
- return {
161
- error: "JAX and PyTorch have lengths! Make sure to pad the PyTorch side.",
162
- };
163
- }
164
-
165
- for (let i = 0; i < jaxFields.length; i++) {
166
- let jaxField = jaxFields[i];
167
- let torchField = torchFields[i];
168
- if (torchField.skip === true) {
169
- continue;
170
- }
171
-
172
- //@ts-ignore
173
- let jaxShapeProduct = jaxField.shape.reduce((a, b) => a * b, 1);
174
- //@ts-ignore
175
- let torchShapeProduct = torchField.shape.reduce((a, b) => a * b, 1);
176
-
177
- if (jaxShapeProduct !== torchShapeProduct) {
178
- return {
179
- error: `JAX ${jaxField.path} with shape ${jaxField.shape} doesn't match PyTorch ${torchField.path} with shape ${torchField.shape}`,
180
- };
181
- }
182
- }
183
- return { success: true };
184
- }
185
- function removeSkipLayer(index: number) {
186
- torchFields = torchFields.toSpliced(index, 1);
187
- setTimeout(() => {
188
- onEnd();
189
- }, 100);
190
- }
191
- function addSkipLayer(index: number) {
192
- let fields = fetchJaxAndTorchFields();
193
- if (fields.error) {
194
- Toast.fire({
195
- icon: "error",
196
- text: fields.error,
197
- });
198
- return;
199
- }
200
- const newField = {
201
- skip: true,
202
- shape: [0],
203
- path: "SKIP",
204
- type: "SKIP",
205
- } as Field;
206
- torchFields = fields.torchFields.toSpliced(index, 0, newField);
207
- setTimeout(() => {
208
- onEnd();
209
- }, 100);
210
- }
211
-
212
- async function convert() {
213
- let fields = fetchJaxAndTorchFields();
214
- if (fields.error) {
215
- Toast.fire({
216
- icon: "error",
217
- title: fields.error,
218
- });
219
- return;
220
- }
221
-
222
- let check = checkFields(fields.jaxFields, fields.torchFields);
223
- if (check.error) {
224
- Toast.fire({
225
- icon: "error",
226
- title: "Failed to convert",
227
- text: check.error,
228
- });
229
- return;
230
- }
231
-
232
- const response = await fetch("/convert", {
233
- method: "POST",
234
- headers: {
235
- "Content-Type": "application/json",
236
- },
237
- body: JSON.stringify({
238
- model: model,
239
- jaxFields: fields.jaxFields,
240
- torchFields: fields.torchFields,
241
- }),
242
- });
243
-
244
- const res = await response.json();
245
- console.log(res);
246
- if (res.error) {
247
- Toast.fire({
248
- icon: "error",
249
- title: res.error,
250
- });
251
- } else {
252
- Toast.fire({
253
- icon: "success",
254
- title: "Conversion successful",
255
- });
256
- }
257
- }
258
-
259
- function padToMatch() {
260
- let fields = fetchJaxAndTorchFields();
261
- if (fields.error) {
262
- Toast.fire({
263
- icon: "error",
264
- text: fields.error,
265
- });
266
- return;
267
- }
268
-
269
- if (fields.torchFields.length < fields.jaxFields.length) {
270
- let toAdd = fields.jaxFields.length - fields.torchFields.length;
271
- for (let i = 0; i < toAdd; i++) {
272
- setTimeout(() => {
273
- console.log("adding skip at ", i);
274
- addSkipLayer(fields.jaxFields.length + i);
275
- }, 100);
276
- }
277
- }
278
- }
279
-
280
- function removeAllSkipLayers() {
281
- let fields = fetchJaxAndTorchFields();
282
- if (fields.error) {
283
- Toast.fire({
284
- icon: "error",
285
- text: fields.error,
286
- });
287
- return;
288
- }
289
- let filteredFields = [];
290
- for (let i = 0; i < fields.torchFields.length; i++) {
291
- if (fields.torchFields[i].skip === false) {
292
- filteredFields.push(fields.torchFields[i]);
293
- }
294
- }
295
- torchFields = filteredFields;
296
- }
297
-
298
- async function matchByName() {
299
- let fields = fetchJaxAndTorchFields();
300
- if (fields.error) {
301
- Toast.fire({
302
- icon: "error",
303
- text: fields.error,
304
- });
305
- return;
306
- }
307
- if (fields.jaxFields.length !== fields.torchFields.length) {
308
- Toast.fire({
309
- icon: "error",
310
- text: "PyTree and State Dict have diffent lengths. Make sure to pad first!",
311
- });
312
- return;
313
- }
314
- Toast.fire({
315
- icon: "info",
316
- title: "Matching by name...",
317
- text: "This can take a while! Hold tight.",
318
- });
319
-
320
- let content = `
321
- You will get two lists of strings. These strings are fields of a JAX and PyTorch model.
322
- For example:
323
- --JAX START--
324
- .layers[0].weight
325
- .layers[1].weight
326
- .layers[2].weight
327
- .layers[3].weight
328
- .layers[4].weight
329
- --JAX END--
330
-
331
- --PYTORCH START--
332
- layers.0.weight
333
- layers.1.weight
334
- layers.4.weight
335
- layers.2.weight
336
- layers.3.weight
337
- --PYTORCH END--
338
-
339
- As you can see, the order doesn't match. This means, you should look at the PyTorch fields and
340
- rearrange them, such that they match the JAX model. In the above example, the expected return value
341
- is:
342
- --PYTORCH START--
343
- layers.0.weight
344
- layers.1.weight
345
- layers.2.weight
346
- layers.3.weight
347
- layers.4.weight
348
- --PYTORCH END--
349
-
350
- Here's another example:
351
- --JAX START--
352
- .conv1.weight
353
- .bn1.weight
354
- .bn1.bias
355
- .bn1.state_index.init[0]
356
- .bn1.state_index.init[1]
357
- --JAX END--
358
-
359
- --PYTORCH START--
360
- bn1.running_mean
361
- bn1.running_var
362
- conv1.weight
363
- bn1.weight
364
- bn1.bias
365
- --PYTORCH END--
366
-
367
- The expected return value in this case is:
368
-
369
- --PYTORCH START--
370
- conv1.weight
371
- bn1.weight
372
- bn1.bias
373
- bn1.running_mean
374
- bn1.running_var
375
- --PYTORCH END--
376
-
377
-
378
- Sometimes, there are so-called "skip-layers" in the PyTorch model. Those can be put anywhere, preferably to
379
- the end, because your priority is to match those fields that can be matched first! Here's an example:
380
-
381
- --JAX START--
382
- .layers[0].weight
383
- .layers[1].weight
384
- .layers[2].weight
385
- .layers[3].weight
386
- .layers[4].weight
387
- --JAX END--
388
-
389
- --PYTORCH START--
390
- layers.0.weight
391
- SKIP
392
- layers.3.weight
393
- layers.2.weight
394
- layers.1.weight
395
- --PYTORCH START--
396
-
397
- This should return
398
-
399
- --PYTORCH START--
400
- layers.0.weight
401
- layers.1.weight
402
- layers.2.weight
403
- layers.3.weight
404
- SKIP
405
- --PYTORCH START--
406
-
407
-
408
- It's not always 100% which belongs to which. Use your best judgement. Start your response with
409
- --PYTORCH START-- and end it with --PYTORCH END--.
410
-
411
-
412
- Here's your input:
413
- --JAX START--
414
- `;
415
-
416
- for (let i = 0; i < fields.jaxFields.length; i++) {
417
- content += fields.jaxFields[i].path + "\n";
418
- }
419
- content += "--JAX END--\n";
420
- content += "\n";
421
- content += "--PYTORCH START--\n";
422
-
423
- for (let i = 0; i < fields.torchFields.length; i++) {
424
- content += fields.torchFields[i].path + "\n";
425
- }
426
-
427
- content += "--PYTORCH END--";
428
- console.log(content);
429
-
430
- let req = await fetch("/anthropic", {
431
- method: "POST",
432
- headers: {
433
- "Content-Type": "application/json",
434
- },
435
- body: JSON.stringify({
436
- content: content,
437
- model: anthropicModel,
438
- }),
439
- });
440
- let res = await req.json();
441
- if (res.error) {
442
- Toast.fire({
443
- icon: "error",
444
- text: res.error,
445
- });
446
- return;
447
- }
448
- console.log(res);
449
- let responseContent = res.content;
450
- let lines = responseContent.split("\n");
451
- console.log(lines);
452
- let rearrangedTorchFields = [];
453
- for (let i = 0; i < lines.length; i++) {
454
- let matchingTorchField = fields.torchFields.find(
455
- (field) => field.path === lines[i],
456
- );
457
- if (matchingTorchField) {
458
- rearrangedTorchFields.push(matchingTorchField);
459
- }
460
- }
461
- if (fields.torchFields.length !== rearrangedTorchFields.length) {
462
- Toast.fire({
463
- icon: "error",
464
- text: "Some fields are missing in the response. Try a different model instead.",
465
- });
466
- return;
467
- }
468
- console.log("rearrangedTorchFields", rearrangedTorchFields);
469
- setTimeout(() => {
470
- torchFields = rearrangedTorchFields;
471
- onEnd();
472
- Toast.fire({
473
- icon: "success",
474
- title: "Success",
475
- });
476
- }, 500);
477
- }
478
- </script>
479
-
480
- <svelte:head><title>Statedict2PyTree</title></svelte:head>
481
-
482
- <h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
483
- <div class="my-4 flex justify-evenly">
484
- <button on:click={padToMatch} class="btn btn-accent">Pad To Match</button>
485
- <button on:click={removeAllSkipLayers} class="btn btn-secondary"
486
- >Remove All Skip Layers</button
487
- >
488
- <div>
489
- <button on:click={matchByName} class="btn btn-warning"
490
- >Match By Name</button
491
- >
492
- <select bind:value={anthropicModel}>
493
- <option value="opus">opus</option>
494
- <option value="sonnet">sonnet</option>
495
- <option value="sonnet3.5">sonnet3.5</option>
496
- <option value="haiku">haiku</option>
497
- </select>
498
- </div>
499
- </div>
500
- <div class="grid grid-cols-2 gap-x-2">
501
- <div class="">
502
- <h2 class="text-2xl">JAX</h2>
503
- <div id="jax-fields" class="">
504
- {#each jaxFields as field, i}
505
- <div
506
- class="border h-12 rounded-xl flex flex-col justify-center"
507
- >
508
- <div
509
- id={"jax-" + String(i)}
510
- class="whitespace-nowrap overflow-x-scroll cursor-pointer mx-2"
511
- data-jax="jax"
512
- data-path={field.path}
513
- data-shape={field.shape}
514
- data-skip={field.skip}
515
- data-type={field.type}
516
- >
517
- {field.path}
518
- {field.shape}
519
- </div>
520
- </div>
521
- {/each}
522
- </div>
523
- </div>
524
-
525
- <div class="">
526
- <h2 class="text-2xl">PyTorch</h2>
527
- <div id="torch-fields" class="">
528
- {#key torchFields}
529
- {#each torchFields as field, i}
530
- <div class="flex space-x-2 border h-12 rounded-xl">
531
- <div
532
- id={"torch-" + String(i)}
533
- data-torch="torch"
534
- data-path={field.path}
535
- data-shape={field.shape}
536
- data-skip={field.skip}
537
- data-type={field.type}
538
- class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
539
- >
540
- {#if field.skip}
541
- SKIP
542
- {:else}
543
- {field.path}
544
- {field.shape}
545
- {/if}
546
- </div>
547
- {#if field.skip}
548
- <button
549
- class="btn btn-ghost"
550
- on:click={() => {
551
- removeSkipLayer(i);
552
- }}>-</button
553
- >
554
- {/if}
555
- <button
556
- class="btn btn-ghost"
557
- on:click={() => {
558
- addSkipLayer(i);
559
- }}>+</button
560
- >
561
- </div>
562
- {/each}
563
- {/key}
564
- </div>
565
- </div>
566
- </div>
567
- <div class="flex justify-center my-12 w-full">
568
- <div class="flex flex-col justify-center w-full">
569
- <input
570
- id="name"
571
- type="text"
572
- name="name"
573
- class="input input-primary w-full"
574
- placeholder="Name of the new file (model.eqx per default)"
575
- bind:value={model}
576
- />
577
- <button
578
- on:click={convert}
579
- class="btn btn-accent btn-wide btn-lg mx-auto my-2"
580
- >
581
- Convert!
582
- </button>
583
- </div>
584
- </div>
client/src/empty.ts DELETED
File without changes
client/src/main.js DELETED
@@ -1,8 +0,0 @@
1
- import App from "./App.svelte";
2
-
3
- const app = new App({
4
- target: document.body,
5
- props: {},
6
- });
7
-
8
- export default app;
client/tailwind.config.js DELETED
@@ -1,8 +0,0 @@
1
- /** @type {import('tailwindcss').Config} */
2
- module.exports = {
3
- content: ["./src/**/*.{html,js,svelte}"],
4
- theme: {
5
- extend: {},
6
- },
7
- plugins: [require("daisyui")],
8
- };
client/tsconfig.json DELETED
@@ -1,5 +0,0 @@
1
- {
2
- "extends": "@tsconfig/svelte/tsconfig.json",
3
- "include": ["src/**/*", "src/node_modules"],
4
- "exclude": ["node_modules/*", "__sapper__/*", "public/*"]
5
- }
@@ -1,17 +0,0 @@
1
- client/.gitignore,sha256=dUAd7J1wJVSaU9WD0Gp4c7_5yEz5lOdWLfsjubHf1z8,39
2
- client/package-lock.json,sha256=JCW5mq9bGNs29yGO1EIlcysDaEuTm2i_AJoyt-45yO8,158531
3
- client/package.json,sha256=Ad-MDEQeh7BPHWPYLd3u9sXk8YVuO_dXmpkxxU1Pglo,1044
4
- client/rollup.config.mjs,sha256=RAepJhL2V5Rf-BlJBZJxllVl0mxtr67GSVr9aU0JUnA,1073
5
- client/tailwind.config.js,sha256=TfN5eOoFOUPGBou6OoK54M14PtokgxWDJUsV4qkurS8,175
6
- client/tsconfig.json,sha256=cLHEFXx-Q55XqbF9QjQ4XScSEQ15n-vS5tsTcqY4UAY,158
7
- client/public/bundle.js,sha256=EEQ1-tpIHl5Nr2POytTpdKp9vriYvUGvbUb7mRv4-7s,356425
8
- client/public/bundle.js.map,sha256=l_aFN6UUVGz9YoLVsk2IwsOWGojpbXmZdvVE-pC9UNQ,694164
9
- client/public/index.html,sha256=jUx-NPKkFN2EF2lj-8Ml49CEHxKJFWK9seszauI4GE0,335
10
- client/public/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
11
- client/public/output.css,sha256=80svlSgNV_Fw82IYgqeRjnY86GXMBW0gWm2VyA7n1A8,36658
12
- client/src/App.svelte,sha256=9DpP0XOFs3gfEW9Sj7NdGfQe4SyPBw9DVfQglRH70Ns,17721
13
- client/src/empty.ts,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
- client/src/main.js,sha256=O_8UgVd1vJM8BcHO7U_6jkL76ZSA6oC7GLLcL9F3JLA,118
15
- statedict2pytree-0.6.0.dist-info/METADATA,sha256=Ot2rYkKkMnp142k7PmryuxRv4ZzXuNt3i5IYi54lf64,2742
16
- statedict2pytree-0.6.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
17
- statedict2pytree-0.6.0.dist-info/RECORD,,