statedict2pytree 0.5.4__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/__init__.py +8 -0
- statedict2pytree/converter.py +293 -0
- {statedict2pytree-0.5.4.dist-info → statedict2pytree-1.0.0.dist-info}/METADATA +10 -28
- statedict2pytree-1.0.0.dist-info/RECORD +5 -0
- client/.gitignore +0 -3
- client/package-lock.json +0 -4540
- client/package.json +0 -36
- client/public/bundle.js +0 -10072
- client/public/bundle.js.map +0 -1
- client/public/index.html +0 -14
- client/public/input.css +0 -3
- client/public/output.css +0 -1617
- client/rollup.config.mjs +0 -44
- client/src/App.svelte +0 -584
- client/src/empty.ts +0 -0
- client/src/main.js +0 -8
- client/tailwind.config.js +0 -8
- client/tsconfig.json +0 -5
- statedict2pytree-0.5.4.dist-info/RECORD +0 -17
- {statedict2pytree-0.5.4.dist-info → statedict2pytree-1.0.0.dist-info}/WHEEL +0 -0
client/rollup.config.mjs
DELETED
|
@@ -1,44 +0,0 @@
|
|
|
1
|
-
import svelte from "rollup-plugin-svelte";
|
|
2
|
-
import { nodeResolve } from "@rollup/plugin-node-resolve";
|
|
3
|
-
import commonjs from "@rollup/plugin-commonjs";
|
|
4
|
-
import livereload from "rollup-plugin-livereload";
|
|
5
|
-
import terser from "@rollup/plugin-terser";
|
|
6
|
-
import { sveltePreprocess } from "svelte-preprocess";
|
|
7
|
-
import typescript from "@rollup/plugin-typescript";
|
|
8
|
-
|
|
9
|
-
// True unless the ROLLUP_WATCH environment variable is set.
const production = !process.env.ROLLUP_WATCH;

// Options handed to rollup-plugin-svelte.
// NOTE(review): `css` is given as a callback inside `compilerOptions`; confirm
// this matches what the installed rollup-plugin-svelte / Svelte versions expect.
const svelteOptions = {
  compilerOptions: {
    dev: !production,
    css: (css) => {
      css.write("public/bundle.css");
    },
  },
  preprocess: sveltePreprocess(),
};

// Resolve browser builds of dependencies and keep a single copy of svelte.
const resolveOptions = {
  browser: true,
  dedupe: (id) => id === "svelte" || id.startsWith("svelte/"),
};

// Bundles src/main.js into one IIFE at public/bundle.js (global name "app").
export default {
  input: "src/main.js",
  output: {
    sourcemap: true,
    format: "iife",
    name: "app",
    file: "public/bundle.js",
  },
  plugins: [
    svelte(svelteOptions),

    typescript({ sourceMap: !production }),
    nodeResolve(resolveOptions),
    commonjs(),

    // Live reload only outside production; minify only in production.
    !production && livereload("public"),
    production && terser(),
  ],
  watch: {
    clearScreen: false,
  },
};
|
client/src/App.svelte
DELETED
|
@@ -1,584 +0,0 @@
|
|
|
1
|
-
<script lang="ts">
|
|
2
|
-
//@ts-ignore
|
|
3
|
-
import Sortable from "sortablejs/modular/sortable.complete.esm.js";
|
|
4
|
-
import { onMount } from "svelte";
|
|
5
|
-
import Swal from "sweetalert2";
|
|
6
|
-
|
|
7
|
-
// Model aliases offered in the "Match By Name" selector.
type AnthropicModel = "opus" | "sonnet" | "sonnet3.5" | "haiku";

// File name sent to the backend as `model` when converting.
let model = "model.eqx";
// Alias sent to the backend as `model` when matching by name.
let anthropicModel: AnthropicModel = "haiku";
|
|
9
|
-
|
|
10
|
-
// Shared SweetAlert2 toast preset: shown at the top-end corner without a
// confirm button, on a 5000 ms timer with a progress bar. The timer is stopped
// while the pointer is over the toast and resumed when it leaves.
const Toast = Swal.mixin({
  position: "top-end",
  toast: true,
  timer: 5000,
  timerProgressBar: true,
  showConfirmButton: false,
  didOpen(popup) {
    popup.onmouseenter = Swal.stopTimer;
    popup.onmouseleave = Swal.resumeTimer;
  },
});
|
|
21
|
-
/**
 * One model parameter entry as received from the backend and mirrored into
 * the `data-*` attributes of the list markup.
 */
type Field = {
  /** Path of the parameter inside the model (the literal "SKIP" for placeholders). */
  path: string;
  /** Dimensions of the parameter. Primitive `number[]`, so arithmetic type-checks. */
  shape: number[];
  /** True for placeholder rows that are ignored by the shape checks. */
  skip: boolean;
  type: string | null;
};

// Component state: the two lists rendered side by side.
let jaxFields: Field[] = [];
let torchFields: Field[] = [];
|
|
30
|
-
// On mount: load both field lists from the backend, then attach drag-and-drop
// and run the first shape check. A failed or non-OK request is reported in a
// toast instead of surfacing as an unhandled promise rejection.
onMount(async () => {
  try {
    const jaxResponse = await fetch("/startup/getJaxFields");
    if (!jaxResponse.ok) {
      throw new Error(`getJaxFields responded with status ${jaxResponse.status}`);
    }
    jaxFields = (await jaxResponse.json()) as Field[];

    const torchResponse = await fetch("/startup/getTorchFields");
    if (!torchResponse.ok) {
      throw new Error(`getTorchFields responded with status ${torchResponse.status}`);
    }
    torchFields = (await torchResponse.json()) as Field[];
  } catch (err: unknown) {
    Toast.fire({
      icon: "error",
      title: "Failed to load model fields",
      text: err instanceof Error ? err.message : String(err),
    });
    return;
  }

  // Deferred, presumably so the lists are rendered before Sortable attaches
  // and before the first shape check reads them back from the DOM.
  setTimeout(initSortable, 100);
  setTimeout(onEnd, 500);
});
|
|
43
|
-
|
|
44
|
-
/**
 * Attaches SortableJS to the `#torch-fields` container with multi-drag
 * (shift key) enabled; `onEnd` runs after every completed drag.
 */
function initSortable() {
  const torchList = document.getElementById("torch-fields");
  const sortableOptions = {
    animation: 150,
    multiDrag: true,
    multiDragKey: "shift",
    ghostClass: "bg-blue-400",
    selectedClass: "bg-accent",
    onEnd: onEnd,
  };
  new Sortable(torchList, sortableOptions);
}
|
|
54
|
-
|
|
55
|
-
/**
 * Parses a `data-shape` attribute ("2,3") into its dimensions.
 * A missing or empty attribute is a scalar shape and yields `[]`; splitting
 * the empty string would otherwise produce `[NaN]`, which never compares equal.
 */
function parseShapeAttribute(raw: string | null): number[] {
  if (raw === null || raw.trim() === "") {
    return [];
  }
  return raw.split(",").map((dim) => parseInt(dim, 10));
}

/** Rebuilds a Field from the `data-*` attributes of one list entry. */
function readFieldFromElement(el: Element): Field {
  // The cast keeps a missing `data-path` as null rather than inventing a value.
  return {
    path: el.getAttribute("data-path"),
    shape: parseShapeAttribute(el.getAttribute("data-shape")),
    type: el.getAttribute("data-type"),
    skip: el.getAttribute("data-skip") === "true",
  } as Field;
}

/**
 * Reads both field lists back from the DOM in their current on-screen order.
 *
 * PyTorch entries are the first element child of each row inside
 * `#torch-fields`; JAX entries are all elements carrying `data-jax="jax"`.
 *
 * @returns the two lists, plus `error` (and a toast) when the PyTorch
 *          container cannot be found.
 */
function fetchJaxAndTorchFields(): {
  error?: string;
  jaxFields: Field[];
  torchFields: Field[];
} {
  const torchRows = document.getElementById("torch-fields")?.children;
  if (!torchRows) {
    Toast.fire({
      icon: "error",
      title: "Couldn't find PyTorch elements",
    });
    return {
      error: "Failed to fetch PyTorch elements",
      jaxFields: [],
      torchFields: [],
    };
  }

  const newTorchFields: Field[] = [];
  for (const row of Array.from(torchRows)) {
    // firstElementChild ignores whitespace text nodes that firstChild could return.
    const entry = row.firstElementChild;
    if (entry) {
      newTorchFields.push(readFieldFromElement(entry));
    }
  }

  const newJaxFields: Field[] = Array.from(
    document.querySelectorAll('[data-jax="jax"]'),
  ).map(readFieldFromElement);

  return { jaxFields: newJaxFields, torchFields: newTorchFields };
}
|
|
103
|
-
|
|
104
|
-
/**
 * After a 100 ms delay, re-reads both lists from the DOM and updates the
 * `bg-error` class on each JAX entry: set when the paired PyTorch entry holds
 * a different number of elements, cleared when the counts agree or the
 * PyTorch entry is a skip placeholder. JAX entries without a PyTorch partner
 * are left as they are. Surplus PyTorch entries have `bg-error` cleared.
 */
function onEnd() {
  // Total element count of a shape (product of its dimensions).
  const elementCount = (shape: Field["shape"]) =>
    (shape as number[]).reduce((acc, dim) => acc * dim, 1);

  const refreshHighlights = () => {
    const snapshot = fetchJaxAndTorchFields();
    if (snapshot.error) {
      Toast.fire({
        icon: "error",
        title: snapshot.error,
      });
      return;
    }
    const jax: Field[] = snapshot.jaxFields;
    const torch: Field[] = snapshot.torchFields;

    jax.forEach((jaxField, row) => {
      const torchField = torch[row];
      if (torchField === undefined) return;

      // Skip placeholders never count as a mismatch; shapes are only
      // inspected for real entries.
      const mismatch =
        torchField.skip !== true &&
        elementCount(jaxField.shape) !== elementCount(torchField.shape);
      document
        .getElementById("jax-" + row)
        ?.classList.toggle("bg-error", mismatch);
    });

    // PyTorch entries beyond the end of the JAX list.
    for (let row = jax.length; row < torch.length; row++) {
      document.getElementById("torch-" + row)?.classList.remove("bg-error");
    }
  };

  setTimeout(refreshHighlights, 100);
}
|
|
158
|
-
/**
 * Validates that the two lists can be converted pairwise.
 *
 * Entries are paired by index. A pair is accepted when the PyTorch entry is a
 * skip placeholder or when both shapes hold the same total number of elements
 * (the shapes themselves need not be identical).
 *
 * @param jaxFields   JAX entries in their current order.
 * @param torchFields PyTorch entries in their current order; must be at least
 *                    as long as `jaxFields`.
 * @returns `{ success: true }`, or `{ error }` describing the first problem.
 */
function checkFields(jaxFields: Field[], torchFields: Field[]) {
  if (jaxFields.length > torchFields.length) {
    return {
      error:
        "JAX and PyTorch have different lengths! Make sure to pad the PyTorch side.",
    };
  }

  // Total element count of a shape (product of its dimensions).
  const elementCount = (shape: Field["shape"]): number =>
    (shape as number[]).reduce((acc, dim) => acc * dim, 1);

  for (const [i, jaxField] of jaxFields.entries()) {
    const torchField = torchFields[i];
    if (torchField.skip === true) {
      continue;
    }

    if (elementCount(jaxField.shape) !== elementCount(torchField.shape)) {
      return {
        error: `JAX ${jaxField.path} with shape ${jaxField.shape} doesn't match PyTorch ${torchField.path} with shape ${torchField.shape}`,
      };
    }
  }
  return { success: true };
}
|
|
185
|
-
/** Removes the PyTorch entry at `index` from the list, then schedules `onEnd` after 100 ms. */
function removeSkipLayer(index: number) {
  const remaining = torchFields.toSpliced(index, 1);
  torchFields = remaining;
  setTimeout(onEnd, 100);
}
|
|
191
|
-
/**
 * Inserts a skip placeholder at `index` into the PyTorch list as currently
 * ordered in the DOM, then schedules `onEnd` after 100 ms. Shows a toast and
 * does nothing else when the list cannot be read.
 */
function addSkipLayer(index: number) {
  const current = fetchJaxAndTorchFields();
  if (current.error) {
    Toast.fire({ icon: "error", text: current.error });
    return;
  }

  const placeholder = {
    skip: true,
    shape: [0],
    path: "SKIP",
    type: "SKIP",
  } as Field;

  torchFields = current.torchFields.toSpliced(index, 0, placeholder);
  setTimeout(onEnd, 100);
}
|
|
211
|
-
|
|
212
|
-
/**
 * Validates the current pairing and, if it passes, posts both lists together
 * with the target file name to `/convert`. Every outcome is reported in a
 * toast, including network failures and unparsable responses.
 */
async function convert(): Promise<void> {
  const fields = fetchJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      title: fields.error,
    });
    return;
  }

  const check = checkFields(fields.jaxFields, fields.torchFields);
  if (check.error) {
    Toast.fire({
      icon: "error",
      title: "Failed to convert",
      text: check.error,
    });
    return;
  }

  let res: { error?: string };
  try {
    const response = await fetch("/convert", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        model: model,
        jaxFields: fields.jaxFields,
        torchFields: fields.torchFields,
      }),
    });
    res = await response.json();
  } catch (err: unknown) {
    // fetch rejects on network failure; json() rejects on a non-JSON body.
    Toast.fire({
      icon: "error",
      title: "Failed to convert",
      text: err instanceof Error ? err.message : String(err),
    });
    return;
  }

  if (res.error) {
    Toast.fire({
      icon: "error",
      title: res.error,
    });
  } else {
    Toast.fire({
      icon: "success",
      title: "Conversion successful",
    });
  }
}
|
|
258
|
-
|
|
259
|
-
/**
 * Appends skip placeholders to the PyTorch list until it is as long as the
 * JAX list. Does nothing when the PyTorch list is already long enough.
 *
 * All placeholders are added in one state update, so the result does not
 * depend on the list being re-rendered between individual insertions.
 */
function padToMatch() {
  const fields = fetchJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      text: fields.error,
    });
    return;
  }

  const missing = fields.jaxFields.length - fields.torchFields.length;
  if (missing <= 0) {
    return;
  }

  const placeholders = Array.from(
    { length: missing },
    (): Field => ({
      skip: true,
      shape: [0],
      path: "SKIP",
      type: "SKIP",
    }),
  );
  torchFields = [...fields.torchFields, ...placeholders];

  // Refresh the mismatch highlighting once the new rows exist.
  setTimeout(onEnd, 100);
}
|
|
279
|
-
|
|
280
|
-
/**
 * Drops every skip placeholder from the PyTorch list (as currently ordered in
 * the DOM) and then refreshes the mismatch highlighting, as the single-entry
 * removal does.
 */
function removeAllSkipLayers() {
  const fields = fetchJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      text: fields.error,
    });
    return;
  }

  const realFields: Field[] = fields.torchFields.filter(
    (field) => field.skip === false,
  );
  torchFields = realFields;

  // Pairings shift after removal, so the highlights must be recomputed.
  setTimeout(onEnd, 100);
}
|
|
297
|
-
|
|
298
|
-
/**
 * Builds the prompt sent to the language model: fixed instructions with
 * worked examples, followed by the JAX paths and the PyTorch paths, each
 * wrapped in START/END markers.
 */
function buildMatchPrompt(jaxPaths: string[], torchPaths: string[]): string {
  const instructions = `
You will get two lists of strings. These strings are fields of a JAX and PyTorch model.
For example:
--JAX START--
.layers[0].weight
.layers[1].weight
.layers[2].weight
.layers[3].weight
.layers[4].weight
--JAX END--

--PYTORCH START--
layers.0.weight
layers.1.weight
layers.4.weight
layers.2.weight
layers.3.weight
--PYTORCH END--

As you can see, the order doesn't match. This means, you should look at the PyTorch fields and
rearrange them, such that they match the JAX model. In the above example, the expected return value
is:
--PYTORCH START--
layers.0.weight
layers.1.weight
layers.2.weight
layers.3.weight
layers.4.weight
--PYTORCH END--

Here's another example:
--JAX START--
.conv1.weight
.bn1.weight
.bn1.bias
.bn1.state_index.init[0]
.bn1.state_index.init[1]
--JAX END--

--PYTORCH START--
bn1.running_mean
bn1.running_var
conv1.weight
bn1.weight
bn1.bias
--PYTORCH END--

The expected return value in this case is:

--PYTORCH START--
conv1.weight
bn1.weight
bn1.bias
bn1.running_mean
bn1.running_var
--PYTORCH END--


Sometimes, there are so-called "skip-layers" in the PyTorch model. Those can be put anywhere, preferably to
the end, because your priority is to match those fields that can be matched first! Here's an example:

--JAX START--
.layers[0].weight
.layers[1].weight
.layers[2].weight
.layers[3].weight
.layers[4].weight
--JAX END--

--PYTORCH START--
layers.0.weight
SKIP
layers.3.weight
layers.2.weight
layers.1.weight
--PYTORCH END--

This should return

--PYTORCH START--
layers.0.weight
layers.1.weight
layers.2.weight
layers.3.weight
SKIP
--PYTORCH END--


It's not always 100% which belongs to which. Use your best judgement. Start your response with
--PYTORCH START-- and end it with --PYTORCH END--.


Here's your input:
--JAX START--
`;

  const jaxSection = jaxPaths.map((path) => path + "\n").join("");
  const torchSection = torchPaths.map((path) => path + "\n").join("");

  return (
    instructions +
    jaxSection +
    "--JAX END--\n" +
    "\n" +
    "--PYTORCH START--\n" +
    torchSection +
    "--PYTORCH END--"
  );
}

/**
 * Reorders `available` to follow `orderedPaths`. Each line consumes at most
 * one not-yet-used field with that path, so repeated paths (e.g. several
 * "SKIP" rows) map to distinct fields and a path repeated by mistake cannot
 * stand in for a missing one. Lines matching no remaining field (markers,
 * blank lines) are ignored.
 */
function reorderByPaths(orderedPaths: string[], available: Field[]): Field[] {
  const pool = [...available];
  const ordered: Field[] = [];
  for (const rawLine of orderedPaths) {
    // Tolerate "\r" and stray whitespace around a returned path.
    const trimmed = rawLine.trim();
    const at = pool.findIndex(
      (field) => field.path === rawLine || field.path === trimmed,
    );
    if (at !== -1) {
      ordered.push(...pool.splice(at, 1));
    }
  }
  return ordered;
}

/**
 * Asks the backend's `/anthropic` endpoint to reorder the PyTorch list so
 * that it lines up with the JAX list by name, then applies the new order.
 * Requires both lists to have the same length. Every failure (unreadable
 * lists, network error, backend error, incomplete answer) is reported in a
 * toast and leaves the current order untouched.
 */
async function matchByName(): Promise<void> {
  const fields = fetchJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      text: fields.error,
    });
    return;
  }
  if (fields.jaxFields.length !== fields.torchFields.length) {
    Toast.fire({
      icon: "error",
      text: "PyTree and State Dict have different lengths. Make sure to pad first!",
    });
    return;
  }
  Toast.fire({
    icon: "info",
    title: "Matching by name...",
    text: "This can take a while! Hold tight.",
  });

  const content = buildMatchPrompt(
    fields.jaxFields.map((field) => field.path),
    fields.torchFields.map((field) => field.path),
  );

  let res: { error?: string; content?: unknown };
  try {
    const req = await fetch("/anthropic", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        content: content,
        model: anthropicModel,
      }),
    });
    res = await req.json();
  } catch (err: unknown) {
    // fetch rejects on network failure; json() rejects on a non-JSON body.
    Toast.fire({
      icon: "error",
      text: err instanceof Error ? err.message : String(err),
    });
    return;
  }

  if (res.error) {
    Toast.fire({
      icon: "error",
      text: res.error,
    });
    return;
  }
  if (typeof res.content !== "string") {
    Toast.fire({
      icon: "error",
      text: "The response contained no text. Try a different model instead.",
    });
    return;
  }

  const rearrangedTorchFields = reorderByPaths(
    res.content.split("\n"),
    fields.torchFields,
  );
  if (fields.torchFields.length !== rearrangedTorchFields.length) {
    Toast.fire({
      icon: "error",
      text: "Some fields are missing in the response. Try a different model instead.",
    });
    return;
  }

  setTimeout(() => {
    torchFields = rearrangedTorchFields;
    onEnd();
    Toast.fire({
      icon: "success",
      title: "Success",
    });
  }, 500);
}
|
|
478
|
-
</script>
|
|
479
|
-
|
|
480
|
-
<svelte:head><title>Statedict2PyTree</title></svelte:head>
|
|
481
|
-
|
|
482
|
-
<h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
|
|
483
|
-
<div class="my-4 flex justify-evenly">
|
|
484
|
-
<button on:click={padToMatch} class="btn btn-accent">Pad To Match</button>
|
|
485
|
-
<button on:click={removeAllSkipLayers} class="btn btn-secondary"
|
|
486
|
-
>Remove All Skip Layers</button
|
|
487
|
-
>
|
|
488
|
-
<div>
|
|
489
|
-
<button on:click={matchByName} class="btn btn-warning"
|
|
490
|
-
>Match By Name</button
|
|
491
|
-
>
|
|
492
|
-
<select bind:value={anthropicModel}>
|
|
493
|
-
<option value="opus">opus</option>
|
|
494
|
-
<option value="sonnet">sonnet</option>
|
|
495
|
-
<option value="sonnet3.5">sonnet3.5</option>
|
|
496
|
-
<option value="haiku">haiku</option>
|
|
497
|
-
</select>
|
|
498
|
-
</div>
|
|
499
|
-
</div>
|
|
500
|
-
<div class="grid grid-cols-2 gap-x-2">
|
|
501
|
-
<div class="">
|
|
502
|
-
<h2 class="text-2xl">JAX</h2>
|
|
503
|
-
<div id="jax-fields" class="">
|
|
504
|
-
{#each jaxFields as field, i}
|
|
505
|
-
<div
|
|
506
|
-
class="border h-12 rounded-xl flex flex-col justify-center"
|
|
507
|
-
>
|
|
508
|
-
<div
|
|
509
|
-
id={"jax-" + String(i)}
|
|
510
|
-
class="whitespace-nowrap overflow-x-scroll cursor-pointer mx-2"
|
|
511
|
-
data-jax="jax"
|
|
512
|
-
data-path={field.path}
|
|
513
|
-
data-shape={field.shape}
|
|
514
|
-
data-skip={field.skip}
|
|
515
|
-
data-type={field.type}
|
|
516
|
-
>
|
|
517
|
-
{field.path}
|
|
518
|
-
{field.shape}
|
|
519
|
-
</div>
|
|
520
|
-
</div>
|
|
521
|
-
{/each}
|
|
522
|
-
</div>
|
|
523
|
-
</div>
|
|
524
|
-
|
|
525
|
-
<div class="">
|
|
526
|
-
<h2 class="text-2xl">PyTorch</h2>
|
|
527
|
-
<div id="torch-fields" class="">
|
|
528
|
-
{#key torchFields}
|
|
529
|
-
{#each torchFields as field, i}
|
|
530
|
-
<div class="flex space-x-2 border h-12 rounded-xl">
|
|
531
|
-
<div
|
|
532
|
-
id={"torch-" + String(i)}
|
|
533
|
-
data-torch="torch"
|
|
534
|
-
data-path={field.path}
|
|
535
|
-
data-shape={field.shape}
|
|
536
|
-
data-skip={field.skip}
|
|
537
|
-
data-type={field.type}
|
|
538
|
-
class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
|
|
539
|
-
>
|
|
540
|
-
{#if field.skip}
|
|
541
|
-
SKIP
|
|
542
|
-
{:else}
|
|
543
|
-
{field.path}
|
|
544
|
-
{field.shape}
|
|
545
|
-
{/if}
|
|
546
|
-
</div>
|
|
547
|
-
{#if field.skip}
|
|
548
|
-
<button
|
|
549
|
-
class="btn btn-ghost"
|
|
550
|
-
on:click={() => {
|
|
551
|
-
removeSkipLayer(i);
|
|
552
|
-
}}>-</button
|
|
553
|
-
>
|
|
554
|
-
{/if}
|
|
555
|
-
<button
|
|
556
|
-
class="btn btn-ghost"
|
|
557
|
-
on:click={() => {
|
|
558
|
-
addSkipLayer(i);
|
|
559
|
-
}}>+</button
|
|
560
|
-
>
|
|
561
|
-
</div>
|
|
562
|
-
{/each}
|
|
563
|
-
{/key}
|
|
564
|
-
</div>
|
|
565
|
-
</div>
|
|
566
|
-
</div>
|
|
567
|
-
<div class="flex justify-center my-12 w-full">
|
|
568
|
-
<div class="flex flex-col justify-center w-full">
|
|
569
|
-
<input
|
|
570
|
-
id="name"
|
|
571
|
-
type="text"
|
|
572
|
-
name="name"
|
|
573
|
-
class="input input-primary w-full"
|
|
574
|
-
placeholder="Name of the new file (model.eqx per default)"
|
|
575
|
-
bind:value={model}
|
|
576
|
-
/>
|
|
577
|
-
<button
|
|
578
|
-
on:click={convert}
|
|
579
|
-
class="btn btn-accent btn-wide btn-lg mx-auto my-2"
|
|
580
|
-
>
|
|
581
|
-
Convert!
|
|
582
|
-
</button>
|
|
583
|
-
</div>
|
|
584
|
-
</div>
|
client/src/empty.ts
DELETED
|
File without changes
|
client/src/main.js
DELETED
client/tailwind.config.js
DELETED
client/tsconfig.json
DELETED
|
@@ -1,17 +0,0 @@
|
|
|
1
|
-
client/.gitignore,sha256=dUAd7J1wJVSaU9WD0Gp4c7_5yEz5lOdWLfsjubHf1z8,39
|
|
2
|
-
client/package-lock.json,sha256=JCW5mq9bGNs29yGO1EIlcysDaEuTm2i_AJoyt-45yO8,158531
|
|
3
|
-
client/package.json,sha256=Ad-MDEQeh7BPHWPYLd3u9sXk8YVuO_dXmpkxxU1Pglo,1044
|
|
4
|
-
client/rollup.config.mjs,sha256=RAepJhL2V5Rf-BlJBZJxllVl0mxtr67GSVr9aU0JUnA,1073
|
|
5
|
-
client/tailwind.config.js,sha256=TfN5eOoFOUPGBou6OoK54M14PtokgxWDJUsV4qkurS8,175
|
|
6
|
-
client/tsconfig.json,sha256=cLHEFXx-Q55XqbF9QjQ4XScSEQ15n-vS5tsTcqY4UAY,158
|
|
7
|
-
client/public/bundle.js,sha256=EEQ1-tpIHl5Nr2POytTpdKp9vriYvUGvbUb7mRv4-7s,356425
|
|
8
|
-
client/public/bundle.js.map,sha256=l_aFN6UUVGz9YoLVsk2IwsOWGojpbXmZdvVE-pC9UNQ,694164
|
|
9
|
-
client/public/index.html,sha256=jUx-NPKkFN2EF2lj-8Ml49CEHxKJFWK9seszauI4GE0,335
|
|
10
|
-
client/public/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
|
|
11
|
-
client/public/output.css,sha256=80svlSgNV_Fw82IYgqeRjnY86GXMBW0gWm2VyA7n1A8,36658
|
|
12
|
-
client/src/App.svelte,sha256=9DpP0XOFs3gfEW9Sj7NdGfQe4SyPBw9DVfQglRH70Ns,17721
|
|
13
|
-
client/src/empty.ts,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
client/src/main.js,sha256=O_8UgVd1vJM8BcHO7U_6jkL76ZSA6oC7GLLcL9F3JLA,118
|
|
15
|
-
statedict2pytree-0.5.4.dist-info/METADATA,sha256=ruY4h7mzPnIMACyxs3bkCrA-c5up5k-Lyqc7nNYy3fg,2742
|
|
16
|
-
statedict2pytree-0.5.4.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
17
|
-
statedict2pytree-0.5.4.dist-info/RECORD,,
|
|
File without changes
|