statedict2pytree 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
client/src/App.svelte CHANGED
@@ -5,6 +5,7 @@
5
5
  import Swal from "sweetalert2";
6
6
 
7
7
  let model: string = "model.eqx";
8
+ let anthropicModel: "opus" | "sonnet" | "haiku" = "haiku";
8
9
 
9
10
  const Toast = Swal.mixin({
10
11
  toast: true,
@@ -26,7 +27,6 @@
26
27
 
27
28
  let jaxFields: Field[] = [];
28
29
  let torchFields: Field[] = [];
29
- let torchSortable: Sortable;
30
30
  onMount(async () => {
31
31
  let req = await fetch("/startup/getJaxFields");
32
32
  jaxFields = (await req.json()) as Field[];
@@ -35,10 +35,14 @@
35
35
  setTimeout(() => {
36
36
  initSortable();
37
37
  }, 100);
38
+
39
+ setTimeout(() => {
40
+ onEnd();
41
+ }, 500);
38
42
  });
39
43
 
40
44
  function initSortable() {
41
- torchSortable = new Sortable(document.getElementById("torch-fields"), {
45
+ new Sortable(document.getElementById("torch-fields"), {
42
46
  animation: 150,
43
47
  multiDrag: true,
44
48
  ghostClass: "bg-blue-400",
@@ -48,13 +52,6 @@
48
52
  });
49
53
  }
50
54
 
51
- function swap(a: any, b: any, array: any[]) {
52
- const temp = array[a];
53
- array[a] = array[b];
54
- array[b] = temp;
55
- return array;
56
- }
57
-
58
55
  function fetchJaxAndTorchFields() {
59
56
  let allTorchElements =
60
57
  document.getElementById("torch-fields")?.children;
@@ -105,60 +102,58 @@
105
102
  }
106
103
 
107
104
  function onEnd() {
108
- const updatedFields = fetchJaxAndTorchFields();
109
- if (updatedFields.error) {
110
- Toast.fire({
111
- icon: "error",
112
- title: updatedFields.error,
113
- });
114
- return;
115
- }
116
-
117
- for (let i = 0; i < updatedFields.jaxFields.length; i++) {
118
- let jaxField = updatedFields.jaxFields[i];
119
- let torchField = updatedFields.torchFields[i];
120
- if (torchField === undefined) continue;
121
- if (torchField.skip === true) {
122
- document
123
- .getElementById("jax-" + i)
124
- ?.classList.remove("bg-error");
125
- document
126
- .getElementById("torch-" + i)
127
- ?.classList.remove("bg-error");
128
- continue;
105
+ setTimeout(() => {
106
+ const updatedFields = fetchJaxAndTorchFields();
107
+ if (updatedFields.error) {
108
+ Toast.fire({
109
+ icon: "error",
110
+ title: updatedFields.error,
111
+ });
112
+ return;
129
113
  }
130
- let jaxShape = jaxField.shape;
131
- let torchShape = torchField.shape;
132
- //@ts-ignore
133
- let jaxShapeProduct = jaxShape.reduce((a, b) => a * b, 1);
134
- //@ts-ignore
135
- let torchShapeProduct = torchShape.reduce((a, b) => a * b, 1);
136
- if (jaxShapeProduct !== torchShapeProduct) {
137
- document.getElementById("jax-" + i)?.classList.add("bg-error");
138
- document
139
- .getElementById("torch-" + i)
140
- ?.classList.add("bg-error");
141
- } else {
142
- document
143
- .getElementById("jax-" + i)
144
- ?.classList.remove("bg-error");
145
- document
146
- .getElementById("torch-" + i)
147
- ?.classList.remove("bg-error");
114
+
115
+ for (let i = 0; i < updatedFields.jaxFields.length; i++) {
116
+ let jaxField = updatedFields.jaxFields[i];
117
+ let torchField = updatedFields.torchFields[i];
118
+ if (torchField === undefined) continue;
119
+ if (torchField.skip === true) {
120
+ document
121
+ .getElementById("jax-" + i)
122
+ ?.classList.remove("bg-error");
123
+ continue;
124
+ }
125
+ let jaxShape = jaxField.shape;
126
+ let torchShape = torchField.shape;
127
+ //@ts-ignore
128
+ let jaxShapeProduct = jaxShape.reduce((a, b) => a * b, 1);
129
+ //@ts-ignore
130
+ let torchShapeProduct = torchShape.reduce((a, b) => a * b, 1);
131
+ if (jaxShapeProduct !== torchShapeProduct) {
132
+ document
133
+ .getElementById("jax-" + i)
134
+ ?.classList.add("bg-error");
135
+ } else {
136
+ document
137
+ .getElementById("jax-" + i)
138
+ ?.classList.remove("bg-error");
139
+ }
148
140
  }
149
- }
150
141
 
151
- if (updatedFields.torchFields.length > updatedFields.jaxFields.length) {
152
- for (
153
- let i = updatedFields.jaxFields.length;
154
- i < updatedFields.torchFields.length;
155
- i++
142
+ if (
143
+ updatedFields.torchFields.length >
144
+ updatedFields.jaxFields.length
156
145
  ) {
157
- document
158
- .getElementById("torch-" + i)
159
- ?.classList.remove("bg-error");
146
+ for (
147
+ let i = updatedFields.jaxFields.length;
148
+ i < updatedFields.torchFields.length;
149
+ i++
150
+ ) {
151
+ document
152
+ .getElementById("torch-" + i)
153
+ ?.classList.remove("bg-error");
154
+ }
160
155
  }
161
- }
156
+ }, 100);
162
157
  }
163
158
  function checkFields(jaxFields: Field[], torchFields: Field[]) {
164
159
  if (jaxFields.length > torchFields.length) {
@@ -194,13 +189,21 @@
194
189
  }, 100);
195
190
  }
196
191
  function addSkipLayer(index: number) {
192
+ let fields = fetchJaxAndTorchFields();
193
+ if (fields.error) {
194
+ Toast.fire({
195
+ icon: "error",
196
+ text: fields.error,
197
+ });
198
+ return;
199
+ }
197
200
  const newField = {
198
201
  skip: true,
199
- shape: [],
200
- path: "",
201
- type: "",
202
+ shape: [0],
203
+ path: "SKIP",
204
+ type: "SKIP",
202
205
  } as Field;
203
- torchFields = torchFields.toSpliced(index, 0, newField);
206
+ torchFields = fields.torchFields.toSpliced(index, 0, newField);
204
207
  setTimeout(() => {
205
208
  onEnd();
206
209
  }, 100);
@@ -252,12 +255,247 @@
252
255
  });
253
256
  }
254
257
  }
258
+
259
+ function padToMatch() {
260
+ let fields = fetchJaxAndTorchFields();
261
+ if (fields.error) {
262
+ Toast.fire({
263
+ icon: "error",
264
+ text: fields.error,
265
+ });
266
+ return;
267
+ }
268
+
269
+ if (fields.torchFields.length < fields.jaxFields.length) {
270
+ let toAdd = fields.jaxFields.length - fields.torchFields.length;
271
+ for (let i = 0; i < toAdd; i++) {
272
+ setTimeout(() => {
273
+ console.log("adding skip at ", i);
274
+ addSkipLayer(fields.jaxFields.length + i);
275
+ }, 100);
276
+ }
277
+ }
278
+ }
279
+
280
+ function removeAllSkipLayers() {
281
+ let fields = fetchJaxAndTorchFields();
282
+ if (fields.error) {
283
+ Toast.fire({
284
+ icon: "error",
285
+ text: fields.error,
286
+ });
287
+ return;
288
+ }
289
+ let filteredFields = [];
290
+ for (let i = 0; i < fields.torchFields.length; i++) {
291
+ if (fields.torchFields[i].skip === false) {
292
+ filteredFields.push(fields.torchFields[i]);
293
+ }
294
+ }
295
+ torchFields = filteredFields;
296
+ }
297
+
298
+ async function matchByName() {
299
+ let fields = fetchJaxAndTorchFields();
300
+ if (fields.error) {
301
+ Toast.fire({
302
+ icon: "error",
303
+ text: fields.error,
304
+ });
305
+ return;
306
+ }
307
+ if (fields.jaxFields.length !== fields.torchFields.length) {
308
+ Toast.fire({
309
+ icon: "error",
310
+ text: "PyTree and State Dict have diffent lengths. Make sure to pad first!",
311
+ });
312
+ return;
313
+ }
314
+ Toast.fire({
315
+ icon: "info",
316
+ title: "Matching by name...",
317
+ text: "This can take a while! Hold tight.",
318
+ });
319
+
320
+ let content = `
321
+ You will get two lists of strings. These strings are fields of a JAX and PyTorch model.
322
+ For example:
323
+ --JAX START--
324
+ .layers[0].weight
325
+ .layers[1].weight
326
+ .layers[2].weight
327
+ .layers[3].weight
328
+ .layers[4].weight
329
+ --JAX END--
330
+
331
+ --PYTORCH START--
332
+ layers.0.weight
333
+ layers.1.weight
334
+ layers.4.weight
335
+ layers.2.weight
336
+ layers.3.weight
337
+ --PYTORCH END--
338
+
339
+ As you can see, the order doesn't match. This means, you should look at the PyTorch fields and
340
+ rearrange them, such that they match the JAX model. In the above example, the expected return value
341
+ is:
342
+ --PYTORCH START--
343
+ layers.0.weight
344
+ layers.1.weight
345
+ layers.2.weight
346
+ layers.3.weight
347
+ layers.4.weight
348
+ --PYTORCH END--
349
+
350
+ Here's another example:
351
+ --JAX START--
352
+ .conv1.weight
353
+ .bn1.weight
354
+ .bn1.bias
355
+ .bn1.state_index.init[0]
356
+ .bn1.state_index.init[1]
357
+ --JAX END--
358
+
359
+ --PYTORCH START--
360
+ bn1.running_mean
361
+ bn1.running_var
362
+ conv1.weight
363
+ bn1.weight
364
+ bn1.bias
365
+ --PYTORCH END--
366
+
367
+ The expected return value in this case is:
368
+
369
+ --PYTORCH START--
370
+ conv1.weight
371
+ bn1.weight
372
+ bn1.bias
373
+ bn1.running_mean
374
+ bn1.running_var
375
+ --PYTORCH END--
376
+
377
+
378
+ Sometimes, there are so-called "skip-layers" in the PyTorch model. Those can be put anywhere, preferably to
379
+ the end, because your priority is to match those fields that can be matched first! Here's an example:
380
+
381
+ --JAX START--
382
+ .layers[0].weight
383
+ .layers[1].weight
384
+ .layers[2].weight
385
+ .layers[3].weight
386
+ .layers[4].weight
387
+ --JAX END--
388
+
389
+ --PYTORCH START--
390
+ layers.0.weight
391
+ SKIP
392
+ layers.3.weight
393
+ layers.2.weight
394
+ layers.1.weight
395
+ --PYTORCH START--
396
+
397
+ This should return
398
+
399
+ --PYTORCH START--
400
+ layers.0.weight
401
+ layers.1.weight
402
+ layers.2.weight
403
+ layers.3.weight
404
+ SKIP
405
+ --PYTORCH START--
406
+
407
+
408
+ It's not always 100% which belongs to which. Use your best judgement. Start your response with
409
+ --PYTORCH START-- and end it with --PYTORCH END--.
410
+
411
+
412
+ Here's your input:
413
+ --JAX START--
414
+ `;
415
+
416
+ for (let i = 0; i < fields.jaxFields.length; i++) {
417
+ content += fields.jaxFields[i].path + "\n";
418
+ }
419
+ content += "--JAX END--\n";
420
+ content += "\n";
421
+ content += "--PYTORCH START--\n";
422
+
423
+ for (let i = 0; i < fields.torchFields.length; i++) {
424
+ content += fields.torchFields[i].path + "\n";
425
+ }
426
+
427
+ content += "--PYTORCH END--";
428
+ console.log(content);
429
+
430
+ let req = await fetch("/anthropic", {
431
+ method: "POST",
432
+ headers: {
433
+ "Content-Type": "application/json",
434
+ },
435
+ body: JSON.stringify({
436
+ content: content,
437
+ model: anthropicModel,
438
+ }),
439
+ });
440
+ let res = await req.json();
441
+ if (res.error) {
442
+ Toast.fire({
443
+ icon: "error",
444
+ text: res.error,
445
+ });
446
+ return;
447
+ }
448
+ console.log(res);
449
+ let responseContent = res.content;
450
+ let lines = responseContent.split("\n");
451
+ console.log(lines);
452
+ let rearrangedTorchFields = [];
453
+ for (let i = 0; i < lines.length; i++) {
454
+ let matchingTorchField = fields.torchFields.find(
455
+ (field) => field.path === lines[i],
456
+ );
457
+ if (matchingTorchField) {
458
+ rearrangedTorchFields.push(matchingTorchField);
459
+ }
460
+ }
461
+ if (fields.torchFields.length !== rearrangedTorchFields.length) {
462
+ Toast.fire({
463
+ icon: "error",
464
+ text: "Some fields are missing in the response. Try a different model instead.",
465
+ });
466
+ return;
467
+ }
468
+ console.log("rearrangedTorchFields", rearrangedTorchFields);
469
+ setTimeout(() => {
470
+ torchFields = rearrangedTorchFields;
471
+ onEnd();
472
+ Toast.fire({
473
+ icon: "success",
474
+ title: "Success",
475
+ });
476
+ }, 500);
477
+ }
255
478
  </script>
256
479
 
257
480
  <svelte:head><title>Statedict2PyTree</title></svelte:head>
258
481
 
259
482
  <h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
260
-
483
+ <div class="my-4 flex justify-evenly">
484
+ <button on:click={padToMatch} class="btn btn-accent">Pad To Match</button>
485
+ <button on:click={removeAllSkipLayers} class="btn btn-secondary"
486
+ >Remove All Skip Layers</button
487
+ >
488
+ <div>
489
+ <button on:click={matchByName} class="btn btn-warning"
490
+ >Match By Name</button
491
+ >
492
+ <select bind:value={anthropicModel}>
493
+ <option value="opus">opus</option>
494
+ <option value="sonnet">sonnet</option>
495
+ <option value="haiku">haiku</option>
496
+ </select>
497
+ </div>
498
+ </div>
261
499
  <div class="grid grid-cols-2 gap-x-2">
262
500
  <div class="">
263
501
  <h2 class="text-2xl">JAX</h2>
@@ -286,40 +524,42 @@
286
524
  <div class="">
287
525
  <h2 class="text-2xl">PyTorch</h2>
288
526
  <div id="torch-fields" class="">
289
- {#each torchFields as field, i}
290
- <div class="flex space-x-2 border h-12 rounded-xl">
291
- <div
292
- id={"torch-" + String(i)}
293
- data-torch="torch"
294
- data-path={field.path}
295
- data-shape={field.shape}
296
- data-skip={field.skip}
297
- data-type={field.type}
298
- class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
299
- >
527
+ {#key torchFields}
528
+ {#each torchFields as field, i}
529
+ <div class="flex space-x-2 border h-12 rounded-xl">
530
+ <div
531
+ id={"torch-" + String(i)}
532
+ data-torch="torch"
533
+ data-path={field.path}
534
+ data-shape={field.shape}
535
+ data-skip={field.skip}
536
+ data-type={field.type}
537
+ class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
538
+ >
539
+ {#if field.skip}
540
+ SKIP
541
+ {:else}
542
+ {field.path}
543
+ {field.shape}
544
+ {/if}
545
+ </div>
300
546
  {#if field.skip}
301
- SKIP
302
- {:else}
303
- {field.path}
304
- {field.shape}
547
+ <button
548
+ class="btn btn-ghost"
549
+ on:click={() => {
550
+ removeSkipLayer(i);
551
+ }}>-</button
552
+ >
305
553
  {/if}
306
- </div>
307
- {#if field.skip}
308
554
  <button
309
555
  class="btn btn-ghost"
310
556
  on:click={() => {
311
- removeSkipLayer(i);
312
- }}>-</button
557
+ addSkipLayer(i);
558
+ }}>+</button
313
559
  >
314
- {/if}
315
- <button
316
- class="btn btn-ghost"
317
- on:click={() => {
318
- addSkipLayer(i);
319
- }}>+</button
320
- >
321
- </div>
322
- {/each}
560
+ </div>
561
+ {/each}
562
+ {/key}
323
563
  </div>
324
564
  </div>
325
565
  </div>
@@ -1,9 +1,10 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.5.2
3
+ Version: 0.5.3
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
7
+ Requires-Dist: anthropic
7
8
  Requires-Dist: beartype
8
9
  Requires-Dist: equinox
9
10
  Requires-Dist: flask
@@ -14,6 +15,7 @@ Requires-Dist: jaxtyping
14
15
  Requires-Dist: loguru
15
16
  Requires-Dist: penzai
16
17
  Requires-Dist: pydantic
18
+ Requires-Dist: python-dotenv
17
19
  Requires-Dist: torch
18
20
  Requires-Dist: torchvision
19
21
  Requires-Dist: typing-extensions
@@ -37,6 +39,21 @@ PRs and other contributions are *highly* welcome! :)
37
39
 
38
40
  ## Info
39
41
 
42
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
43
+
44
+ ## Features
45
+
46
+ - Convert PyTorch statedicts to JAX pytrees
47
+ - Handle large models with chunked file conversion
48
+ - Provide an "intuitive-ish" UI for parameter mapping
49
+ - Support both in-memory and file-based conversions
50
+
51
+ ## Installation
52
+
53
+ ```bash
54
+ pip install statedict2pytree
55
+ ```
56
+
40
57
  The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
41
58
 
42
59
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
@@ -49,16 +66,12 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
49
66
 
50
67
  (8, 1, 1) and (8, ) match, because (8 * 1 * 1 = 8)
51
68
 
52
- ## Get Started
53
-
54
- ### Installation
55
-
56
- Run
57
-
58
- ```bash
59
- pip install statedict2pytree
60
- ```
61
69
 
62
70
  ### Docs
63
71
 
64
72
  Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
73
+
74
+ ### Disclaimer
75
+
76
+ Some of the docstrings and the docs have been written with the help of
77
+ Claude.
@@ -4,14 +4,14 @@ client/package.json,sha256=Ad-MDEQeh7BPHWPYLd3u9sXk8YVuO_dXmpkxxU1Pglo,1044
4
4
  client/rollup.config.mjs,sha256=RAepJhL2V5Rf-BlJBZJxllVl0mxtr67GSVr9aU0JUnA,1073
5
5
  client/tailwind.config.js,sha256=TfN5eOoFOUPGBou6OoK54M14PtokgxWDJUsV4qkurS8,175
6
6
  client/tsconfig.json,sha256=cLHEFXx-Q55XqbF9QjQ4XScSEQ15n-vS5tsTcqY4UAY,158
7
- client/public/bundle.js,sha256=l4Vu_8_7v6k7XFisnI8jCFzyMeztr9jc3Z3lrrPDpk0,347197
8
- client/public/bundle.js.map,sha256=s9zOkP-34BWrQFunVbkouwTKVAfL8EyIUcfKgSicH8M,682783
7
+ client/public/bundle.js,sha256=EEQ1-tpIHl5Nr2POytTpdKp9vriYvUGvbUb7mRv4-7s,356425
8
+ client/public/bundle.js.map,sha256=l_aFN6UUVGz9YoLVsk2IwsOWGojpbXmZdvVE-pC9UNQ,694164
9
9
  client/public/index.html,sha256=jUx-NPKkFN2EF2lj-8Ml49CEHxKJFWK9seszauI4GE0,335
10
10
  client/public/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
11
- client/public/output.css,sha256=3iiBiTGfqAeVKuRZRqgcixX3ztSnlp0zqHXkjSKtmVs,38664
12
- client/src/App.svelte,sha256=hHVoQ_C2xGMmd4d86ftZFeTlGmOcOW_wJ6abq0_qvWo,11170
11
+ client/public/output.css,sha256=80svlSgNV_Fw82IYgqeRjnY86GXMBW0gWm2VyA7n1A8,36658
12
+ client/src/App.svelte,sha256=9pAfpz96Bk-A-5uQEBASvxEht-dPctKAZLBs2cOb2kE,17650
13
13
  client/src/empty.ts,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
14
  client/src/main.js,sha256=O_8UgVd1vJM8BcHO7U_6jkL76ZSA6oC7GLLcL9F3JLA,118
15
- statedict2pytree-0.5.2.dist-info/METADATA,sha256=YgE7bgWDMI6urA71bH-zWleR_mw6SJ5QPneUZVvHL2E,2242
16
- statedict2pytree-0.5.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
17
- statedict2pytree-0.5.2.dist-info/RECORD,,
15
+ statedict2pytree-0.5.3.dist-info/METADATA,sha256=J_rlj3ymHkqbf4NxSHnCgh-Nc6LGgK0mhBoG4exnoqo,2788
16
+ statedict2pytree-0.5.3.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
17
+ statedict2pytree-0.5.3.dist-info/RECORD,,