statedict2pytree 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- client/public/bundle.js +615 -234
- client/public/bundle.js.map +1 -1
- client/public/output.css +164 -217
- client/src/App.svelte +331 -91
- {statedict2pytree-0.5.2.dist-info → statedict2pytree-0.5.3.dist-info}/METADATA +23 -10
- {statedict2pytree-0.5.2.dist-info → statedict2pytree-0.5.3.dist-info}/RECORD +7 -7
- {statedict2pytree-0.5.2.dist-info → statedict2pytree-0.5.3.dist-info}/WHEEL +0 -0
client/src/App.svelte
CHANGED
|
@@ -5,6 +5,7 @@
|
|
|
5
5
|
import Swal from "sweetalert2";
|
|
6
6
|
|
|
7
7
|
let model: string = "model.eqx";
|
|
8
|
+
let anthropicModel: "opus" | "sonnet" | "haiku" = "haiku";
|
|
8
9
|
|
|
9
10
|
const Toast = Swal.mixin({
|
|
10
11
|
toast: true,
|
|
@@ -26,7 +27,6 @@
|
|
|
26
27
|
|
|
27
28
|
let jaxFields: Field[] = [];
|
|
28
29
|
let torchFields: Field[] = [];
|
|
29
|
-
let torchSortable: Sortable;
|
|
30
30
|
onMount(async () => {
|
|
31
31
|
let req = await fetch("/startup/getJaxFields");
|
|
32
32
|
jaxFields = (await req.json()) as Field[];
|
|
@@ -35,10 +35,14 @@
|
|
|
35
35
|
setTimeout(() => {
|
|
36
36
|
initSortable();
|
|
37
37
|
}, 100);
|
|
38
|
+
|
|
39
|
+
setTimeout(() => {
|
|
40
|
+
onEnd();
|
|
41
|
+
}, 500);
|
|
38
42
|
});
|
|
39
43
|
|
|
40
44
|
function initSortable() {
|
|
41
|
-
|
|
45
|
+
new Sortable(document.getElementById("torch-fields"), {
|
|
42
46
|
animation: 150,
|
|
43
47
|
multiDrag: true,
|
|
44
48
|
ghostClass: "bg-blue-400",
|
|
@@ -48,13 +52,6 @@
|
|
|
48
52
|
});
|
|
49
53
|
}
|
|
50
54
|
|
|
51
|
-
function swap(a: any, b: any, array: any[]) {
|
|
52
|
-
const temp = array[a];
|
|
53
|
-
array[a] = array[b];
|
|
54
|
-
array[b] = temp;
|
|
55
|
-
return array;
|
|
56
|
-
}
|
|
57
|
-
|
|
58
55
|
function fetchJaxAndTorchFields() {
|
|
59
56
|
let allTorchElements =
|
|
60
57
|
document.getElementById("torch-fields")?.children;
|
|
@@ -105,60 +102,58 @@
|
|
|
105
102
|
}
|
|
106
103
|
|
|
107
104
|
function onEnd() {
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
for (let i = 0; i < updatedFields.jaxFields.length; i++) {
|
|
118
|
-
let jaxField = updatedFields.jaxFields[i];
|
|
119
|
-
let torchField = updatedFields.torchFields[i];
|
|
120
|
-
if (torchField === undefined) continue;
|
|
121
|
-
if (torchField.skip === true) {
|
|
122
|
-
document
|
|
123
|
-
.getElementById("jax-" + i)
|
|
124
|
-
?.classList.remove("bg-error");
|
|
125
|
-
document
|
|
126
|
-
.getElementById("torch-" + i)
|
|
127
|
-
?.classList.remove("bg-error");
|
|
128
|
-
continue;
|
|
105
|
+
setTimeout(() => {
|
|
106
|
+
const updatedFields = fetchJaxAndTorchFields();
|
|
107
|
+
if (updatedFields.error) {
|
|
108
|
+
Toast.fire({
|
|
109
|
+
icon: "error",
|
|
110
|
+
title: updatedFields.error,
|
|
111
|
+
});
|
|
112
|
+
return;
|
|
129
113
|
}
|
|
130
|
-
|
|
131
|
-
let
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
114
|
+
|
|
115
|
+
for (let i = 0; i < updatedFields.jaxFields.length; i++) {
|
|
116
|
+
let jaxField = updatedFields.jaxFields[i];
|
|
117
|
+
let torchField = updatedFields.torchFields[i];
|
|
118
|
+
if (torchField === undefined) continue;
|
|
119
|
+
if (torchField.skip === true) {
|
|
120
|
+
document
|
|
121
|
+
.getElementById("jax-" + i)
|
|
122
|
+
?.classList.remove("bg-error");
|
|
123
|
+
continue;
|
|
124
|
+
}
|
|
125
|
+
let jaxShape = jaxField.shape;
|
|
126
|
+
let torchShape = torchField.shape;
|
|
127
|
+
//@ts-ignore
|
|
128
|
+
let jaxShapeProduct = jaxShape.reduce((a, b) => a * b, 1);
|
|
129
|
+
//@ts-ignore
|
|
130
|
+
let torchShapeProduct = torchShape.reduce((a, b) => a * b, 1);
|
|
131
|
+
if (jaxShapeProduct !== torchShapeProduct) {
|
|
132
|
+
document
|
|
133
|
+
.getElementById("jax-" + i)
|
|
134
|
+
?.classList.add("bg-error");
|
|
135
|
+
} else {
|
|
136
|
+
document
|
|
137
|
+
.getElementById("jax-" + i)
|
|
138
|
+
?.classList.remove("bg-error");
|
|
139
|
+
}
|
|
148
140
|
}
|
|
149
|
-
}
|
|
150
141
|
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
i < updatedFields.torchFields.length;
|
|
155
|
-
i++
|
|
142
|
+
if (
|
|
143
|
+
updatedFields.torchFields.length >
|
|
144
|
+
updatedFields.jaxFields.length
|
|
156
145
|
) {
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
146
|
+
for (
|
|
147
|
+
let i = updatedFields.jaxFields.length;
|
|
148
|
+
i < updatedFields.torchFields.length;
|
|
149
|
+
i++
|
|
150
|
+
) {
|
|
151
|
+
document
|
|
152
|
+
.getElementById("torch-" + i)
|
|
153
|
+
?.classList.remove("bg-error");
|
|
154
|
+
}
|
|
160
155
|
}
|
|
161
|
-
}
|
|
156
|
+
}, 100);
|
|
162
157
|
}
|
|
163
158
|
function checkFields(jaxFields: Field[], torchFields: Field[]) {
|
|
164
159
|
if (jaxFields.length > torchFields.length) {
|
|
@@ -194,13 +189,21 @@
|
|
|
194
189
|
}, 100);
|
|
195
190
|
}
|
|
196
191
|
function addSkipLayer(index: number) {
|
|
192
|
+
let fields = fetchJaxAndTorchFields();
|
|
193
|
+
if (fields.error) {
|
|
194
|
+
Toast.fire({
|
|
195
|
+
icon: "error",
|
|
196
|
+
text: fields.error,
|
|
197
|
+
});
|
|
198
|
+
return;
|
|
199
|
+
}
|
|
197
200
|
const newField = {
|
|
198
201
|
skip: true,
|
|
199
|
-
shape: [],
|
|
200
|
-
path: "",
|
|
201
|
-
type: "",
|
|
202
|
+
shape: [0],
|
|
203
|
+
path: "SKIP",
|
|
204
|
+
type: "SKIP",
|
|
202
205
|
} as Field;
|
|
203
|
-
torchFields = torchFields.toSpliced(index, 0, newField);
|
|
206
|
+
torchFields = fields.torchFields.toSpliced(index, 0, newField);
|
|
204
207
|
setTimeout(() => {
|
|
205
208
|
onEnd();
|
|
206
209
|
}, 100);
|
|
@@ -252,12 +255,247 @@
|
|
|
252
255
|
});
|
|
253
256
|
}
|
|
254
257
|
}
|
|
258
|
+
|
|
259
|
+
function padToMatch() {
|
|
260
|
+
let fields = fetchJaxAndTorchFields();
|
|
261
|
+
if (fields.error) {
|
|
262
|
+
Toast.fire({
|
|
263
|
+
icon: "error",
|
|
264
|
+
text: fields.error,
|
|
265
|
+
});
|
|
266
|
+
return;
|
|
267
|
+
}
|
|
268
|
+
|
|
269
|
+
if (fields.torchFields.length < fields.jaxFields.length) {
|
|
270
|
+
let toAdd = fields.jaxFields.length - fields.torchFields.length;
|
|
271
|
+
for (let i = 0; i < toAdd; i++) {
|
|
272
|
+
setTimeout(() => {
|
|
273
|
+
console.log("adding skip at ", i);
|
|
274
|
+
addSkipLayer(fields.jaxFields.length + i);
|
|
275
|
+
}, 100);
|
|
276
|
+
}
|
|
277
|
+
}
|
|
278
|
+
}
|
|
279
|
+
|
|
280
|
+
function removeAllSkipLayers() {
|
|
281
|
+
let fields = fetchJaxAndTorchFields();
|
|
282
|
+
if (fields.error) {
|
|
283
|
+
Toast.fire({
|
|
284
|
+
icon: "error",
|
|
285
|
+
text: fields.error,
|
|
286
|
+
});
|
|
287
|
+
return;
|
|
288
|
+
}
|
|
289
|
+
let filteredFields = [];
|
|
290
|
+
for (let i = 0; i < fields.torchFields.length; i++) {
|
|
291
|
+
if (fields.torchFields[i].skip === false) {
|
|
292
|
+
filteredFields.push(fields.torchFields[i]);
|
|
293
|
+
}
|
|
294
|
+
}
|
|
295
|
+
torchFields = filteredFields;
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
async function matchByName() {
|
|
299
|
+
let fields = fetchJaxAndTorchFields();
|
|
300
|
+
if (fields.error) {
|
|
301
|
+
Toast.fire({
|
|
302
|
+
icon: "error",
|
|
303
|
+
text: fields.error,
|
|
304
|
+
});
|
|
305
|
+
return;
|
|
306
|
+
}
|
|
307
|
+
if (fields.jaxFields.length !== fields.torchFields.length) {
|
|
308
|
+
Toast.fire({
|
|
309
|
+
icon: "error",
|
|
310
|
+
text: "PyTree and State Dict have diffent lengths. Make sure to pad first!",
|
|
311
|
+
});
|
|
312
|
+
return;
|
|
313
|
+
}
|
|
314
|
+
Toast.fire({
|
|
315
|
+
icon: "info",
|
|
316
|
+
title: "Matching by name...",
|
|
317
|
+
text: "This can take a while! Hold tight.",
|
|
318
|
+
});
|
|
319
|
+
|
|
320
|
+
let content = `
|
|
321
|
+
You will get two lists of strings. These strings are fields of a JAX and PyTorch model.
|
|
322
|
+
For example:
|
|
323
|
+
--JAX START--
|
|
324
|
+
.layers[0].weight
|
|
325
|
+
.layers[1].weight
|
|
326
|
+
.layers[2].weight
|
|
327
|
+
.layers[3].weight
|
|
328
|
+
.layers[4].weight
|
|
329
|
+
--JAX END--
|
|
330
|
+
|
|
331
|
+
--PYTORCH START--
|
|
332
|
+
layers.0.weight
|
|
333
|
+
layers.1.weight
|
|
334
|
+
layers.4.weight
|
|
335
|
+
layers.2.weight
|
|
336
|
+
layers.3.weight
|
|
337
|
+
--PYTORCH END--
|
|
338
|
+
|
|
339
|
+
As you can see, the order doesn't match. This means, you should look at the PyTorch fields and
|
|
340
|
+
rearrange them, such that they match the JAX model. In the above example, the expected return value
|
|
341
|
+
is:
|
|
342
|
+
--PYTORCH START--
|
|
343
|
+
layers.0.weight
|
|
344
|
+
layers.1.weight
|
|
345
|
+
layers.2.weight
|
|
346
|
+
layers.3.weight
|
|
347
|
+
layers.4.weight
|
|
348
|
+
--PYTORCH END--
|
|
349
|
+
|
|
350
|
+
Here's another example:
|
|
351
|
+
--JAX START--
|
|
352
|
+
.conv1.weight
|
|
353
|
+
.bn1.weight
|
|
354
|
+
.bn1.bias
|
|
355
|
+
.bn1.state_index.init[0]
|
|
356
|
+
.bn1.state_index.init[1]
|
|
357
|
+
--JAX END--
|
|
358
|
+
|
|
359
|
+
--PYTORCH START--
|
|
360
|
+
bn1.running_mean
|
|
361
|
+
bn1.running_var
|
|
362
|
+
conv1.weight
|
|
363
|
+
bn1.weight
|
|
364
|
+
bn1.bias
|
|
365
|
+
--PYTORCH END--
|
|
366
|
+
|
|
367
|
+
The expected return value in this case is:
|
|
368
|
+
|
|
369
|
+
--PYTORCH START--
|
|
370
|
+
conv1.weight
|
|
371
|
+
bn1.weight
|
|
372
|
+
bn1.bias
|
|
373
|
+
bn1.running_mean
|
|
374
|
+
bn1.running_var
|
|
375
|
+
--PYTORCH END--
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
Sometimes, there are so-called "skip-layers" in the PyTorch model. Those can be put anywhere, preferably to
|
|
379
|
+
the end, because your priority is to match those fields that can be matched first! Here's an example:
|
|
380
|
+
|
|
381
|
+
--JAX START--
|
|
382
|
+
.layers[0].weight
|
|
383
|
+
.layers[1].weight
|
|
384
|
+
.layers[2].weight
|
|
385
|
+
.layers[3].weight
|
|
386
|
+
.layers[4].weight
|
|
387
|
+
--JAX END--
|
|
388
|
+
|
|
389
|
+
--PYTORCH START--
|
|
390
|
+
layers.0.weight
|
|
391
|
+
SKIP
|
|
392
|
+
layers.3.weight
|
|
393
|
+
layers.2.weight
|
|
394
|
+
layers.1.weight
|
|
395
|
+
--PYTORCH START--
|
|
396
|
+
|
|
397
|
+
This should return
|
|
398
|
+
|
|
399
|
+
--PYTORCH START--
|
|
400
|
+
layers.0.weight
|
|
401
|
+
layers.1.weight
|
|
402
|
+
layers.2.weight
|
|
403
|
+
layers.3.weight
|
|
404
|
+
SKIP
|
|
405
|
+
--PYTORCH START--
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
It's not always 100% which belongs to which. Use your best judgement. Start your response with
|
|
409
|
+
--PYTORCH START-- and end it with --PYTORCH END--.
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
Here's your input:
|
|
413
|
+
--JAX START--
|
|
414
|
+
`;
|
|
415
|
+
|
|
416
|
+
for (let i = 0; i < fields.jaxFields.length; i++) {
|
|
417
|
+
content += fields.jaxFields[i].path + "\n";
|
|
418
|
+
}
|
|
419
|
+
content += "--JAX END--\n";
|
|
420
|
+
content += "\n";
|
|
421
|
+
content += "--PYTORCH START--\n";
|
|
422
|
+
|
|
423
|
+
for (let i = 0; i < fields.torchFields.length; i++) {
|
|
424
|
+
content += fields.torchFields[i].path + "\n";
|
|
425
|
+
}
|
|
426
|
+
|
|
427
|
+
content += "--PYTORCH END--";
|
|
428
|
+
console.log(content);
|
|
429
|
+
|
|
430
|
+
let req = await fetch("/anthropic", {
|
|
431
|
+
method: "POST",
|
|
432
|
+
headers: {
|
|
433
|
+
"Content-Type": "application/json",
|
|
434
|
+
},
|
|
435
|
+
body: JSON.stringify({
|
|
436
|
+
content: content,
|
|
437
|
+
model: anthropicModel,
|
|
438
|
+
}),
|
|
439
|
+
});
|
|
440
|
+
let res = await req.json();
|
|
441
|
+
if (res.error) {
|
|
442
|
+
Toast.fire({
|
|
443
|
+
icon: "error",
|
|
444
|
+
text: res.error,
|
|
445
|
+
});
|
|
446
|
+
return;
|
|
447
|
+
}
|
|
448
|
+
console.log(res);
|
|
449
|
+
let responseContent = res.content;
|
|
450
|
+
let lines = responseContent.split("\n");
|
|
451
|
+
console.log(lines);
|
|
452
|
+
let rearrangedTorchFields = [];
|
|
453
|
+
for (let i = 0; i < lines.length; i++) {
|
|
454
|
+
let matchingTorchField = fields.torchFields.find(
|
|
455
|
+
(field) => field.path === lines[i],
|
|
456
|
+
);
|
|
457
|
+
if (matchingTorchField) {
|
|
458
|
+
rearrangedTorchFields.push(matchingTorchField);
|
|
459
|
+
}
|
|
460
|
+
}
|
|
461
|
+
if (fields.torchFields.length !== rearrangedTorchFields.length) {
|
|
462
|
+
Toast.fire({
|
|
463
|
+
icon: "error",
|
|
464
|
+
text: "Some fields are missing in the response. Try a different model instead.",
|
|
465
|
+
});
|
|
466
|
+
return;
|
|
467
|
+
}
|
|
468
|
+
console.log("rearrangedTorchFields", rearrangedTorchFields);
|
|
469
|
+
setTimeout(() => {
|
|
470
|
+
torchFields = rearrangedTorchFields;
|
|
471
|
+
onEnd();
|
|
472
|
+
Toast.fire({
|
|
473
|
+
icon: "success",
|
|
474
|
+
title: "Success",
|
|
475
|
+
});
|
|
476
|
+
}, 500);
|
|
477
|
+
}
|
|
255
478
|
</script>
|
|
256
479
|
|
|
257
480
|
<svelte:head><title>Statedict2PyTree</title></svelte:head>
|
|
258
481
|
|
|
259
482
|
<h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
|
|
260
|
-
|
|
483
|
+
<div class="my-4 flex justify-evenly">
|
|
484
|
+
<button on:click={padToMatch} class="btn btn-accent">Pad To Match</button>
|
|
485
|
+
<button on:click={removeAllSkipLayers} class="btn btn-secondary"
|
|
486
|
+
>Remove All Skip Layers</button
|
|
487
|
+
>
|
|
488
|
+
<div>
|
|
489
|
+
<button on:click={matchByName} class="btn btn-warning"
|
|
490
|
+
>Match By Name</button
|
|
491
|
+
>
|
|
492
|
+
<select bind:value={anthropicModel}>
|
|
493
|
+
<option value="opus">opus</option>
|
|
494
|
+
<option value="sonnet">sonnet</option>
|
|
495
|
+
<option value="haiku">haiku</option>
|
|
496
|
+
</select>
|
|
497
|
+
</div>
|
|
498
|
+
</div>
|
|
261
499
|
<div class="grid grid-cols-2 gap-x-2">
|
|
262
500
|
<div class="">
|
|
263
501
|
<h2 class="text-2xl">JAX</h2>
|
|
@@ -286,40 +524,42 @@
|
|
|
286
524
|
<div class="">
|
|
287
525
|
<h2 class="text-2xl">PyTorch</h2>
|
|
288
526
|
<div id="torch-fields" class="">
|
|
289
|
-
{#
|
|
290
|
-
|
|
291
|
-
<div
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
527
|
+
{#key torchFields}
|
|
528
|
+
{#each torchFields as field, i}
|
|
529
|
+
<div class="flex space-x-2 border h-12 rounded-xl">
|
|
530
|
+
<div
|
|
531
|
+
id={"torch-" + String(i)}
|
|
532
|
+
data-torch="torch"
|
|
533
|
+
data-path={field.path}
|
|
534
|
+
data-shape={field.shape}
|
|
535
|
+
data-skip={field.skip}
|
|
536
|
+
data-type={field.type}
|
|
537
|
+
class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
|
|
538
|
+
>
|
|
539
|
+
{#if field.skip}
|
|
540
|
+
SKIP
|
|
541
|
+
{:else}
|
|
542
|
+
{field.path}
|
|
543
|
+
{field.shape}
|
|
544
|
+
{/if}
|
|
545
|
+
</div>
|
|
300
546
|
{#if field.skip}
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
547
|
+
<button
|
|
548
|
+
class="btn btn-ghost"
|
|
549
|
+
on:click={() => {
|
|
550
|
+
removeSkipLayer(i);
|
|
551
|
+
}}>-</button
|
|
552
|
+
>
|
|
305
553
|
{/if}
|
|
306
|
-
</div>
|
|
307
|
-
{#if field.skip}
|
|
308
554
|
<button
|
|
309
555
|
class="btn btn-ghost"
|
|
310
556
|
on:click={() => {
|
|
311
|
-
|
|
312
|
-
}}
|
|
557
|
+
addSkipLayer(i);
|
|
558
|
+
}}>+</button
|
|
313
559
|
>
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
on:click={() => {
|
|
318
|
-
addSkipLayer(i);
|
|
319
|
-
}}>+</button
|
|
320
|
-
>
|
|
321
|
-
</div>
|
|
322
|
-
{/each}
|
|
560
|
+
</div>
|
|
561
|
+
{/each}
|
|
562
|
+
{/key}
|
|
323
563
|
</div>
|
|
324
564
|
</div>
|
|
325
565
|
</div>
|
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version: 0.5.2
|
|
3
|
+
Version: 0.5.3
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
6
|
Requires-Python: ~=3.10
|
|
7
|
+
Requires-Dist: anthropic
|
|
7
8
|
Requires-Dist: beartype
|
|
8
9
|
Requires-Dist: equinox
|
|
9
10
|
Requires-Dist: flask
|
|
@@ -14,6 +15,7 @@ Requires-Dist: jaxtyping
|
|
|
14
15
|
Requires-Dist: loguru
|
|
15
16
|
Requires-Dist: penzai
|
|
16
17
|
Requires-Dist: pydantic
|
|
18
|
+
Requires-Dist: python-dotenv
|
|
17
19
|
Requires-Dist: torch
|
|
18
20
|
Requires-Dist: torchvision
|
|
19
21
|
Requires-Dist: typing-extensions
|
|
@@ -37,6 +39,21 @@ PRs and other contributions are *highly* welcome! :)
|
|
|
37
39
|
|
|
38
40
|
## Info
|
|
39
41
|
|
|
42
|
+
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
|
|
43
|
+
|
|
44
|
+
## Features
|
|
45
|
+
|
|
46
|
+
- Convert PyTorch statedicts to JAX pytrees
|
|
47
|
+
- Handle large models with chunked file conversion
|
|
48
|
+
- Provide an "intuitive-ish" UI for parameter mapping
|
|
49
|
+
- Support both in-memory and file-based conversions
|
|
50
|
+
|
|
51
|
+
## Installation
|
|
52
|
+
|
|
53
|
+
```bash
|
|
54
|
+
pip install statedict2pytree
|
|
55
|
+
```
|
|
56
|
+
|
|
40
57
|
The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
|
|
41
58
|
|
|
42
59
|
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
@@ -49,16 +66,12 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
|
|
|
49
66
|
|
|
50
67
|
(8, 1, 1) and (8, ) match, because `8 * 1 * 1 = 8`
|
|
51
68
|
|
|
52
|
-
## Get Started
|
|
53
|
-
|
|
54
|
-
### Installation
|
|
55
|
-
|
|
56
|
-
Run
|
|
57
|
-
|
|
58
|
-
```bash
|
|
59
|
-
pip install statedict2pytree
|
|
60
|
-
```
|
|
61
69
|
|
|
62
70
|
### Docs
|
|
63
71
|
|
|
64
72
|
Documentation will appear as soon as I have implemented all the necessary features. Until then, check out the "main.py" file for a usage example.
|
|
73
|
+
|
|
74
|
+
### Disclaimer
|
|
75
|
+
|
|
76
|
+
Some of the docstrings and the docs have been written with the help of
|
|
77
|
+
Claude.
|
|
@@ -4,14 +4,14 @@ client/package.json,sha256=Ad-MDEQeh7BPHWPYLd3u9sXk8YVuO_dXmpkxxU1Pglo,1044
|
|
|
4
4
|
client/rollup.config.mjs,sha256=RAepJhL2V5Rf-BlJBZJxllVl0mxtr67GSVr9aU0JUnA,1073
|
|
5
5
|
client/tailwind.config.js,sha256=TfN5eOoFOUPGBou6OoK54M14PtokgxWDJUsV4qkurS8,175
|
|
6
6
|
client/tsconfig.json,sha256=cLHEFXx-Q55XqbF9QjQ4XScSEQ15n-vS5tsTcqY4UAY,158
|
|
7
|
-
client/public/bundle.js,sha256=
|
|
8
|
-
client/public/bundle.js.map,sha256=
|
|
7
|
+
client/public/bundle.js,sha256=EEQ1-tpIHl5Nr2POytTpdKp9vriYvUGvbUb7mRv4-7s,356425
|
|
8
|
+
client/public/bundle.js.map,sha256=l_aFN6UUVGz9YoLVsk2IwsOWGojpbXmZdvVE-pC9UNQ,694164
|
|
9
9
|
client/public/index.html,sha256=jUx-NPKkFN2EF2lj-8Ml49CEHxKJFWK9seszauI4GE0,335
|
|
10
10
|
client/public/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
11
|
-
client/public/output.css,sha256=
|
|
12
|
-
client/src/App.svelte,sha256=
|
|
11
|
+
client/public/output.css,sha256=80svlSgNV_Fw82IYgqeRjnY86GXMBW0gWm2VyA7n1A8,36658
|
|
12
|
+
client/src/App.svelte,sha256=9pAfpz96Bk-A-5uQEBASvxEht-dPctKAZLBs2cOb2kE,17650
|
|
13
13
|
client/src/empty.ts,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
14
|
client/src/main.js,sha256=O_8UgVd1vJM8BcHO7U_6jkL76ZSA6oC7GLLcL9F3JLA,118
|
|
15
|
-
statedict2pytree-0.5.
|
|
16
|
-
statedict2pytree-0.5.
|
|
17
|
-
statedict2pytree-0.5.
|
|
15
|
+
statedict2pytree-0.5.3.dist-info/METADATA,sha256=J_rlj3ymHkqbf4NxSHnCgh-Nc6LGgK0mhBoG4exnoqo,2788
|
|
16
|
+
statedict2pytree-0.5.3.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
17
|
+
statedict2pytree-0.5.3.dist-info/RECORD,,
|
|
File without changes
|