statedict2pytree 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
client/src/App.svelte ADDED
@@ -0,0 +1,343 @@
1
+ <script lang="ts">
2
+ //@ts-ignore
3
+ import Sortable from "sortablejs/modular/sortable.complete.esm.js";
4
+ import { onMount } from "svelte";
5
+ import Swal from "sweetalert2";
6
+
7
+ let model: string = "model.eqx";
8
+
9
// Shared SweetAlert2 preset for the small notifications used throughout this
// component: shown top-right, no confirm button, closes itself after five seconds.
const Toast = Swal.mixin({
  didOpen: (popup) => {
    // Hovering the toast pauses the auto-close countdown; leaving resumes it.
    popup.onmouseenter = Swal.stopTimer;
    popup.onmouseleave = Swal.resumeTimer;
  },
  timerProgressBar: true,
  timer: 5000,
  showConfirmButton: false,
  position: "top-end",
  toast: true,
});
20
/** One row of the alignment UI: a single tensor entry on the JAX or PyTorch side. */
type Field = {
  /** Path of the entry; rendered in the list and sent back to the server on convert. */
  path: string;
  /** Tensor dimensions. Primitive `number` (not the boxed `Number`) so arithmetic type-checks. */
  shape: number[];
  /** True for placeholder rows such as the ones created by `addSkipLayer`. */
  skip: boolean;
  /** Type string supplied by the server, if any. */
  type: string | null;
};
26
+
27
+ let jaxFields: Field[] = [];
28
+ let torchFields: Field[] = [];
29
+ let torchSortable: Sortable;
30
// On mount: load both field lists from the server, then enable drag-and-drop.
onMount(async () => {
  try {
    // The two endpoints are independent, so request them in parallel.
    const [jaxResponse, torchResponse] = await Promise.all([
      fetch("/startup/getJaxFields"),
      fetch("/startup/getTorchFields"),
    ]);
    if (!jaxResponse.ok || !torchResponse.ok) {
      throw new Error("Failed to load the JAX/PyTorch fields from the server");
    }
    jaxFields = (await jaxResponse.json()) as Field[];
    torchFields = (await torchResponse.json()) as Field[];
  } catch (e: unknown) {
    // Surface network/JSON failures instead of leaving an unhandled rejection.
    Toast.fire({
      icon: "error",
      title: e instanceof Error ? e.message : String(e),
    });
    return;
  }
  // Presumably the delay lets the rows render before Sortable attaches.
  setTimeout(initSortable, 100);
});
39
+
40
/**
 * Attaches SortableJS (multi-drag, shift to multi-select) to the PyTorch list
 * and re-validates the pairing via `onEnd` after every drop.
 */
function initSortable() {
  const container = document.getElementById("torch-fields");
  if (!container) {
    // Without the container there is nothing to attach to; report it
    // instead of handing `null` to Sortable.
    Toast.fire({
      icon: "error",
      title: "Couldn't find PyTorch elements",
    });
    return;
  }
  torchSortable = new Sortable(container, {
    animation: 150,
    multiDrag: true,
    ghostClass: "bg-blue-400",
    selectedClass: "bg-accent",
    multiDragKey: "shift",
    onEnd: onEnd,
  });
}
50
+
51
/**
 * Swaps the elements at positions `a` and `b` of `array` in place.
 *
 * @param a index of the first element
 * @param b index of the second element
 * @param array array to modify
 * @returns the same (mutated) array instance
 */
function swap<T>(a: number, b: number, array: T[]): T[] {
  [array[a], array[b]] = [array[b], array[a]];
  return array;
}
57
+
58
/** Rebuilds a Field from the `data-*` attributes of one rendered cell. */
function readFieldFromElement(el: Element): Field {
  const rawShape = el.getAttribute("data-shape") ?? "";
  return {
    path: el.getAttribute("data-path") ?? "",
    // An empty attribute (e.g. a skip row with shape []) must yield [],
    // not [NaN], which is what "".split(",").map(parseInt) would produce.
    shape: rawShape
      .split(",")
      .filter((dim) => dim.trim() !== "")
      .map((dim) => parseInt(dim, 10)),
    type: el.getAttribute("data-type"),
    skip: el.getAttribute("data-skip") === "true",
  };
}

/**
 * Reads the current JAX and PyTorch fields back from the DOM, in DOM order,
 * so that the result reflects any drag-and-drop reordering.
 *
 * @returns both field lists; when the PyTorch container is missing, `error`
 *          is set (an error toast is also shown) and both lists are empty.
 */
function fetchJaxAndTorchFields(): {
  error?: string;
  jaxFields: Field[];
  torchFields: Field[];
} {
  const torchRows = document.getElementById("torch-fields")?.children;
  if (!torchRows) {
    Toast.fire({
      icon: "error",
      title: "Couldn't find PyTorch elements",
    });
    return {
      error: "Failed to fetch PyTorch elements",
      jaxFields: [],
      torchFields: [],
    };
  }

  // Each row wraps its data cell as the first element child; use
  // firstElementChild so stray text nodes can never be picked up.
  const newTorchFields: Field[] = [];
  for (const row of Array.from(torchRows)) {
    const cell = row.firstElementChild;
    if (cell) {
      newTorchFields.push(readFieldFromElement(cell));
    }
  }

  const newJaxFields: Field[] = Array.from(
    document.querySelectorAll('[data-jax="jax"]'),
    (cell) => readFieldFromElement(cell),
  );

  return { jaxFields: newJaxFields, torchFields: newTorchFields };
}
106
+
107
/**
 * Re-checks every JAX/PyTorch pair and toggles the `bg-error` highlight on
 * pairs whose shape products differ. Skip rows are never marked, and PyTorch
 * rows beyond the end of the JAX list are always cleared.
 *
 * Cells are addressed by their current DOM position, not by their
 * `jax-<i>` / `torch-<i>` ids: the ids are assigned at render time and travel
 * with a row when it is dragged, whereas the fields returned by
 * `fetchJaxAndTorchFields` are in DOM order.
 */
function onEnd() {
  const updatedFields = fetchJaxAndTorchFields();
  if (updatedFields.error) {
    Toast.fire({
      icon: "error",
      title: updatedFields.error,
    });
    return;
  }

  const jaxCells = document.querySelectorAll('[data-jax="jax"]');
  const torchRows = document.getElementById("torch-fields")?.children;
  const torchCellAt = (position: number) =>
    torchRows?.[position]?.firstElementChild;
  const markError = (cell: Element | null | undefined, hasError: boolean) => {
    cell?.classList.toggle("bg-error", hasError);
  };
  const shapeProduct = (shape: Field["shape"]) =>
    (shape ?? []).reduce((product: number, dim) => product * Number(dim), 1);

  const jax = updatedFields.jaxFields;
  const torch = updatedFields.torchFields;

  for (let i = 0; i < jax.length; i++) {
    const torchField = torch[i];
    // No PyTorch partner at this position: leave the row untouched.
    if (torchField === undefined) continue;
    const mismatch =
      torchField.skip !== true &&
      shapeProduct(jax[i].shape) !== shapeProduct(torchField.shape);
    markError(jaxCells[i], mismatch);
    markError(torchCellAt(i), mismatch);
  }

  // Surplus PyTorch rows have nothing to be compared with.
  for (let i = jax.length; i < torch.length; i++) {
    markError(torchCellAt(i), false);
  }
}
163
/**
 * Validates that every JAX field has a compatible PyTorch partner at the
 * same index. Two fields are compatible when the products of their shapes
 * are equal; PyTorch skip rows are accepted unconditionally.
 *
 * @param jaxFields   JAX fields in display order
 * @param torchFields PyTorch fields in display order (may be longer)
 * @returns `{ success: true }`, or `{ error }` describing the first problem
 */
function checkFields(
  jaxFields: Field[],
  torchFields: Field[],
): { error?: string; success?: boolean } {
  if (jaxFields.length > torchFields.length) {
    return {
      error:
        "JAX and PyTorch have different lengths! Make sure to pad the PyTorch side.",
    };
  }

  const shapeProduct = (shape: Field["shape"]) =>
    shape.reduce((product: number, dim) => product * Number(dim), 1);

  for (let i = 0; i < jaxFields.length; i++) {
    const jaxField = jaxFields[i];
    const torchField = torchFields[i];
    if (torchField.skip === true) {
      continue;
    }

    if (shapeProduct(jaxField.shape) !== shapeProduct(torchField.shape)) {
      return {
        error: `JAX ${jaxField.path} with shape ${jaxField.shape} doesn't match PyTorch ${torchField.path} with shape ${torchField.shape}`,
      };
    }
  }
  return { success: true };
}
190
/**
 * Drops the PyTorch row at `index` and, shortly afterwards, re-runs the
 * shape validation (`onEnd`).
 */
function removeSkipLayer(index: number) {
  const withoutRow = torchFields.toSpliced(index, 1);
  torchFields = withoutRow;
  setTimeout(onEnd, 100);
}
196
/**
 * Inserts an empty skip row into the PyTorch list at `index` and, shortly
 * afterwards, re-runs the shape validation (`onEnd`).
 */
function addSkipLayer(index: number) {
  const placeholder: Field = { path: "", shape: [], type: "", skip: true };
  torchFields = torchFields.toSpliced(index, 0, placeholder);
  setTimeout(onEnd, 100);
}
208
+
209
/**
 * Validates the current pairing and, if it is consistent, POSTs it to
 * `/convert`. Every outcome (validation error, server error, network
 * failure, success) is reported through a toast; the function never rejects.
 */
async function convert() {
  const fields = fetchJaxAndTorchFields();
  if (fields.error) {
    Toast.fire({
      icon: "error",
      title: fields.error,
    });
    return;
  }

  const check = checkFields(fields.jaxFields, fields.torchFields);
  if (check.error) {
    Toast.fire({
      icon: "error",
      title: "Failed to convert",
      text: check.error,
    });
    return;
  }

  try {
    const response = await fetch("/convert", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({
        // The input's placeholder promises "model.eqx per default", so an
        // emptied field falls back to that name. `||` is deliberate here:
        // the empty string is exactly the case being replaced.
        model: model.trim() || "model.eqx",
        jaxFields: fields.jaxFields,
        torchFields: fields.torchFields,
      }),
    });

    const res = await response.json();
    if (!response.ok || res.error) {
      Toast.fire({
        icon: "error",
        title: res.error ?? `Server responded with status ${response.status}`,
      });
    } else {
      Toast.fire({
        icon: "success",
        title: "Conversion successful",
      });
    }
  } catch (e: unknown) {
    // Network failure or a non-JSON response body.
    Toast.fire({
      icon: "error",
      title: "Failed to convert",
      text: e instanceof Error ? e.message : String(e),
    });
  }
}
255
+ </script>
256
+
257
+ <svelte:head><title>Statedict2PyTree</title></svelte:head>
258
+
259
+ <h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
260
+
261
+ <div class="grid grid-cols-2 gap-x-2">
262
+ <div class="">
263
+ <h2 class="text-2xl">JAX</h2>
264
+ <div id="jax-fields" class="">
265
+ {#each jaxFields as field, i}
266
+ <div
267
+ class="border h-12 rounded-xl flex flex-col justify-center"
268
+ >
269
+ <div
270
+ id={"jax-" + String(i)}
271
+ class="whitespace-nowrap overflow-x-scroll cursor-pointer mx-2"
272
+ data-jax="jax"
273
+ data-path={field.path}
274
+ data-shape={field.shape}
275
+ data-skip={field.skip}
276
+ data-type={field.type}
277
+ >
278
+ {field.path}
279
+ {field.shape}
280
+ </div>
281
+ </div>
282
+ {/each}
283
+ </div>
284
+ </div>
285
+
286
+ <div class="">
287
+ <h2 class="text-2xl">PyTorch</h2>
288
+ <div id="torch-fields" class="">
289
+ {#each torchFields as field, i}
290
+ <div class="flex space-x-2 border h-12 rounded-xl">
291
+ <div
292
+ id={"torch-" + String(i)}
293
+ data-torch="torch"
294
+ data-path={field.path}
295
+ data-shape={field.shape}
296
+ data-skip={field.skip}
297
+ data-type={field.type}
298
+ class="flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer"
299
+ >
300
+ {#if field.skip}
301
+ SKIP
302
+ {:else}
303
+ {field.path}
304
+ {field.shape}
305
+ {/if}
306
+ </div>
307
+ {#if field.skip}
308
+ <button
309
+ class="btn btn-ghost"
310
+ on:click={() => {
311
+ removeSkipLayer(i);
312
+ }}>-</button
313
+ >
314
+ {/if}
315
+ <button
316
+ class="btn btn-ghost"
317
+ on:click={() => {
318
+ addSkipLayer(i);
319
+ }}>+</button
320
+ >
321
+ </div>
322
+ {/each}
323
+ </div>
324
+ </div>
325
+ </div>
326
+ <div class="flex justify-center my-12 w-full">
327
+ <div class="flex flex-col justify-center w-full">
328
+ <input
329
+ id="name"
330
+ type="text"
331
+ name="name"
332
+ class="input input-primary w-full"
333
+ placeholder="Name of the new file (model.eqx per default)"
334
+ bind:value={model}
335
+ />
336
+ <button
337
+ on:click={convert}
338
+ class="btn btn-accent btn-wide btn-lg mx-auto my-2"
339
+ >
340
+ Convert!
341
+ </button>
342
+ </div>
343
+ </div>
client/src/empty.ts ADDED
File without changes
client/src/main.js ADDED
@@ -0,0 +1,8 @@
1
+ import App from "./App.svelte";
2
+
3
+ const app = new App({
4
+ target: document.body,
5
+ props: {},
6
+ });
7
+
8
+ export default app;
@@ -0,0 +1,8 @@
1
+ /** @type {import('tailwindcss').Config} */
2
+ module.exports = {
3
+ content: ["./src/**/*.{html,js,svelte}"],
4
+ theme: {
5
+ extend: {},
6
+ },
7
+ plugins: [require("daisyui")],
8
+ };
client/tsconfig.json ADDED
@@ -0,0 +1,5 @@
1
+ {
2
+ "extends": "@tsconfig/svelte/tsconfig.json",
3
+ "include": ["src/**/*", "src/node_modules"],
4
+ "exclude": ["node_modules/*", "__sapper__/*", "public/*"]
5
+ }
@@ -0,0 +1,64 @@
1
+ Metadata-Version: 2.3
2
+ Name: statedict2pytree
3
+ Version: 0.5.2
4
+ Summary: Converts torch models into PyTrees for Equinox
5
+ Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
+ Requires-Python: ~=3.10
7
+ Requires-Dist: beartype
8
+ Requires-Dist: equinox
9
+ Requires-Dist: flask
10
+ Requires-Dist: jax
11
+ Requires-Dist: jaxlib
12
+ Requires-Dist: jaxonmodels
13
+ Requires-Dist: jaxtyping
14
+ Requires-Dist: loguru
15
+ Requires-Dist: penzai
16
+ Requires-Dist: pydantic
17
+ Requires-Dist: torch
18
+ Requires-Dist: torchvision
19
+ Requires-Dist: typing-extensions
20
+ Provides-Extra: dev
21
+ Requires-Dist: mkdocs; extra == 'dev'
22
+ Requires-Dist: nox; extra == 'dev'
23
+ Requires-Dist: pre-commit; extra == 'dev'
24
+ Requires-Dist: pytest; extra == 'dev'
25
+ Provides-Extra: examples
26
+ Requires-Dist: jaxonmodels; extra == 'examples'
27
+ Description-Content-Type: text/markdown
28
+
29
+ # statedict2pytree
30
+
31
+ ![statedict2pytree](statedict2pytree.png "A ResNet demo")
32
+
33
+ ## Important
34
+
35
+ This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
36
+ PRs and other contributions are *highly* welcome! :)
37
+
38
+ ## Info
39
+
40
+ The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
41
+
42
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
43
+
44
+ (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
45
+
46
+ ## Shape Matching? What's that?
47
+
48
+ Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shapes are equal. For example:
49
+
50
+ (8, 1, 1) and (8, ) match, because `8 * 1 * 1 = 8`
51
+
52
+ ## Get Started
53
+
54
+ ### Installation
55
+
56
+ Run
57
+
58
+ ```bash
59
+ pip install statedict2pytree
60
+ ```
61
+
62
+ ### Docs
63
+
64
+ Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
@@ -0,0 +1,17 @@
1
+ client/.gitignore,sha256=dUAd7J1wJVSaU9WD0Gp4c7_5yEz5lOdWLfsjubHf1z8,39
2
+ client/package-lock.json,sha256=JCW5mq9bGNs29yGO1EIlcysDaEuTm2i_AJoyt-45yO8,158531
3
+ client/package.json,sha256=Ad-MDEQeh7BPHWPYLd3u9sXk8YVuO_dXmpkxxU1Pglo,1044
4
+ client/rollup.config.mjs,sha256=RAepJhL2V5Rf-BlJBZJxllVl0mxtr67GSVr9aU0JUnA,1073
5
+ client/tailwind.config.js,sha256=TfN5eOoFOUPGBou6OoK54M14PtokgxWDJUsV4qkurS8,175
6
+ client/tsconfig.json,sha256=cLHEFXx-Q55XqbF9QjQ4XScSEQ15n-vS5tsTcqY4UAY,158
7
+ client/public/bundle.js,sha256=l4Vu_8_7v6k7XFisnI8jCFzyMeztr9jc3Z3lrrPDpk0,347197
8
+ client/public/bundle.js.map,sha256=s9zOkP-34BWrQFunVbkouwTKVAfL8EyIUcfKgSicH8M,682783
9
+ client/public/index.html,sha256=jUx-NPKkFN2EF2lj-8Ml49CEHxKJFWK9seszauI4GE0,335
10
+ client/public/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
11
+ client/public/output.css,sha256=3iiBiTGfqAeVKuRZRqgcixX3ztSnlp0zqHXkjSKtmVs,38664
12
+ client/src/App.svelte,sha256=hHVoQ_C2xGMmd4d86ftZFeTlGmOcOW_wJ6abq0_qvWo,11170
13
+ client/src/empty.ts,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
14
+ client/src/main.js,sha256=O_8UgVd1vJM8BcHO7U_6jkL76ZSA6oC7GLLcL9F3JLA,118
15
+ statedict2pytree-0.5.2.dist-info/METADATA,sha256=YgE7bgWDMI6urA71bH-zWleR_mw6SJ5QPneUZVvHL2E,2242
16
+ statedict2pytree-0.5.2.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
17
+ statedict2pytree-0.5.2.dist-info/RECORD,,
@@ -1,7 +0,0 @@
1
- from statedict2pytree.statedict2pytree import (
2
- autoconvert as autoconvert,
3
- convert as convert,
4
- pytree_to_fields as pytree_to_fields,
5
- start_conversion as start_conversion,
6
- state_dict_to_fields as state_dict_to_fields,
7
- )
@@ -1,219 +0,0 @@
1
- import functools as ft
2
- import re
3
-
4
- import equinox as eqx
5
- import flask
6
- import jax
7
- import numpy as np
8
- from beartype.typing import Optional
9
- from jaxtyping import PyTree
10
- from loguru import logger
11
- from penzai import pz
12
- from pydantic import BaseModel
13
-
14
-
15
- app = flask.Flask(__name__)
16
-
17
-
18
class Field(BaseModel):
    """Common description of one array entry: where it lives and its shape."""

    # Path string identifying the entry.
    path: str
    # Array dimensions.
    shape: tuple[int, ...]


class TorchField(Field):
    """Entry of a PyTorch state dict; adds nothing to ``Field``."""


class JaxField(Field):
    """Entry of a JAX PyTree; additionally carries a type string."""

    # ``convert`` checks this string for "StateIndex".
    type: str
29
-
30
-
31
- PYTREE: Optional[PyTree] = None
32
- STATE_DICT: Optional[dict] = None
33
-
34
-
35
def can_reshape(shape1, shape2):
    """Return whether two shapes have the same product of dimensions."""
    return np.prod(shape1) == np.prod(shape2)
40
-
41
-
42
- def get_node(
43
- tree: PyTree, targets: list[str], log_when_not_found: bool = False
44
- ) -> PyTree | None:
45
- if len(targets) == 0 or tree is None:
46
- return tree
47
- else:
48
- next_target: str = targets[0]
49
- if bool(re.search(r"\[\d+\]", next_target)):
50
- split_index = next_target.rfind("[")
51
- name, index = next_target[:split_index], next_target[split_index:]
52
- index = index[1:-1]
53
- if hasattr(tree, name):
54
- subtree = getattr(tree, name)[int(index)]
55
- else:
56
- subtree = None
57
- if log_when_not_found:
58
- logger.info(f"Couldn't find {name} in {tree.__class__}")
59
- else:
60
- if hasattr(tree, next_target):
61
- subtree = getattr(tree, next_target)
62
- else:
63
- subtree = None
64
- if log_when_not_found:
65
- logger.info(f"Couldn't find {next_target} in {tree.__class__}")
66
- return get_node(subtree, targets[1:])
67
-
68
-
69
def pytree_to_fields(pytree: PyTree) -> list[JaxField]:
    """List every leaf of ``pytree`` that has a non-empty ``shape``.

    For each flattened leaf the key path is rendered with
    ``jax.tree_util.keystr``, split on ``"."``, and resolved again through
    ``get_node``: once without its last segment (to obtain the type of the
    owning node) and once in full (to obtain the leaf itself).
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(pytree)
    collected: list[JaxField] = []
    for key_path, _leaf in leaves_with_paths:
        path = jax.tree_util.keystr(key_path)
        segments = path.split(".")
        owner_type = type(get_node(pytree, segments[1:-1], log_when_not_found=True))
        node = get_node(pytree, segments[1:], log_when_not_found=True)
        if node is None or not hasattr(node, "shape") or len(node.shape) == 0:
            continue
        collected.append(
            JaxField(path=path, type=str(owner_type), shape=tuple(node.shape))
        )

    return collected
84
-
85
-
86
def state_dict_to_fields(state_dict: Optional[dict]) -> list[TorchField]:
    """Describe every state-dict entry that has a non-empty ``shape``.

    Returns an empty list when ``state_dict`` is ``None``.
    """
    if state_dict is None:
        return []
    return [
        TorchField(path=key, shape=tuple(value.shape))
        for key, value in state_dict.items()
        if hasattr(value, "shape") and len(value.shape) > 0
    ]
94
-
95
-
96
@app.route("/visualize", methods=["POST"])
def visualize_with_penzai():
    """Convert with the posted field order and render both sides as HTML.

    Expects a JSON body with ``jaxFields`` and ``torchFields``. Returns a JSON
    ``{"error": ...}`` object when no model is loaded or no body was sent;
    otherwise an HTML page with the penzai rendering of the converted
    ``(model, state)`` followed by the rendering of the state dict.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    request_data = flask.request.json
    if request_data is None:
        return flask.jsonify({"error": "No data received"})

    def parse_shape(raw) -> tuple[int, ...]:
        # "(8, 1, 1)"-style strings are what the /convert endpoint parses;
        # plain sequences of ints are accepted as well.
        if isinstance(raw, str):
            return tuple(int(i) for i in raw.strip("()").split(",") if len(i) > 0)
        return tuple(int(i) for i in raw)

    # convert() reads .path/.shape/.type attributes, so the raw JSON dicts
    # must be turned into field models first.
    jax_fields = [
        JaxField(path=f["path"], type=f["type"], shape=parse_shape(f["shape"]))
        for f in request_data["jaxFields"]
    ]
    torch_fields = [
        TorchField(path=f["path"], shape=parse_shape(f["shape"]))
        for f in request_data["torchFields"]
    ]

    model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
    with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
        html_jax = pz.ts.render_to_html((model, state))
        html_torch = pz.ts.render_to_html(STATE_DICT)

    combined_html = f"<html><body>{html_jax}<hr>{html_torch}</body></html>"
    return combined_html
113
-
114
-
115
def _parse_shape(raw) -> tuple[int, ...]:
    """Turn a posted shape into a tuple of ints.

    Accepts a string such as ``"(8, 1, 1)"`` or a plain sequence of ints.
    """
    if isinstance(raw, str):
        return tuple(int(i) for i in raw.strip("()").split(",") if len(i) > 0)
    return tuple(int(i) for i in raw)


@app.route("/convert", methods=["POST"])
def convert_torch_to_jax():
    """Convert with the posted field order and serialise the result.

    Expects a JSON body with ``jaxFields``, ``torchFields`` and ``name`` (the
    output file). Every failure handled here is answered with a JSON
    ``{"error": ...}`` object; success is ``{"status": "success"}``.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    request_data = flask.request.json
    if request_data is None:
        return flask.jsonify({"error": "No data received"})

    # Report missing keys as a JSON error instead of an unhandled KeyError.
    missing = [
        key for key in ("jaxFields", "torchFields", "name") if key not in request_data
    ]
    if missing:
        return flask.jsonify({"error": f"Missing fields: {', '.join(missing)}"})

    jax_fields: list[JaxField] = [
        JaxField(path=f["path"], type=f["type"], shape=_parse_shape(f["shape"]))
        for f in request_data["jaxFields"]
    ]
    torch_fields: list[TorchField] = [
        TorchField(path=f["path"], shape=_parse_shape(f["shape"]))
        for f in request_data["torchFields"]
    ]

    name = request_data["name"]
    try:
        model, state = convert(jax_fields, torch_fields, PYTREE, STATE_DICT)
    except ValueError as e:
        # convert() raises ValueError on incompatible shapes.
        return flask.jsonify({"error": str(e)})
    eqx.tree_serialise_leaves(name, (model, state))

    return flask.jsonify({"status": "success"})
145
-
146
-
147
@app.route("/", methods=["GET"])
def index():
    """Render ``index.html`` with the field lists of the loaded model pair."""
    template_context = {
        "pytree_fields": pytree_to_fields(PYTREE),
        "torch_fields": state_dict_to_fields(STATE_DICT),
    }
    return flask.render_template("index.html", **template_context)
155
-
156
-
157
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Convert without the UI, pairing fields in their natural order.

    Args:
        pytree: Target JAX model.
        state_dict: Mapping whose values expose ``.numpy()``
            (presumably torch tensors - confirm against callers).

    Returns:
        The ``(model, state)`` pair produced by ``convert``.
    """
    jax_fields = pytree_to_fields(pytree)
    torch_fields = state_dict_to_fields(state_dict)

    # Build a converted copy rather than overwriting the caller's dict, so
    # the input stays usable (a second call on the same dict used to fail
    # because the values had already been replaced by arrays).
    numpy_state_dict = {k: v.numpy() for k, v in state_dict.items()}
    return convert(jax_fields, torch_fields, pytree, numpy_state_dict)
164
-
165
-
166
def convert(
    jax_fields: list[JaxField],
    torch_fields: list[TorchField],
    pytree: PyTree,
    state_dict: dict,
) -> tuple[PyTree, eqx.nn.State]:
    """Copy state-dict values into ``pytree`` following the given pairing.

    Fields are paired by position (``zip``, so surplus entries on either side
    are ignored). Regular fields are written into the model with
    ``eqx.tree_at`` after reshaping to the JAX shape. Fields whose type
    string contains ``"StateIndex"`` are grouped by their parent path and
    stored in the state as one tuple per parent.

    Raises:
        ValueError: If a pair's shapes have different element counts.
    """
    identity = lambda *args, **kwargs: pytree
    model, state = eqx.nn.make_with_state(identity)()
    state_paths: list[tuple[JaxField, TorchField]] = []
    for jax_field, torch_field in zip(jax_fields, torch_fields):
        if not can_reshape(jax_field.shape, torch_field.shape):
            raise ValueError(
                # Trailing space keeps the two message parts from running together.
                "Fields have incompatible shapes! "
                f"{jax_field.shape=} != {torch_field.shape=}"
            )
        if "StateIndex" in jax_field.type:
            state_paths.append((jax_field, torch_field))
            continue

        # Drop the leading empty segment of ".a.b.c" before walking the tree.
        path = jax_field.path.split(".")[1:]
        where = ft.partial(get_node, targets=path)
        if where(model) is not None:
            model = eqx.tree_at(
                where,
                model,
                state_dict[torch_field.path].reshape(jax_field.shape),
            )

    # Group the state-backed fields by the path of their parent node.
    result: dict[str, list[TorchField]] = {}
    for jax_field, torch_field in state_paths:
        prefix_key = ".".join(jax_field.path.split(".")[1:-1])
        result.setdefault(prefix_key, []).append(torch_field)

    for key, grouped_fields in result.items():
        state_index = get_node(model, key.split("."))
        if state_index is not None:
            to_replace_tuple = tuple(state_dict[f.path] for f in grouped_fields)
            state = state.set(state_index, to_replace_tuple)
    return model, state
207
-
208
-
209
def start_conversion(pytree: PyTree, state_dict: dict):
    """Store the model pair in module globals and start the Flask UI.

    Args:
        pytree: Target JAX model.
        state_dict: Mapping whose values expose ``.numpy()``
            (presumably torch tensors - confirm against callers).

    Raises:
        ValueError: If ``state_dict`` is ``None``.
    """
    global PYTREE, STATE_DICT
    if state_dict is None:
        raise ValueError("STATE_DICT must not be None!")
    PYTREE = pytree
    # Keep a converted copy instead of overwriting the caller's dict.
    STATE_DICT = {k: v.numpy() for k, v in state_dict.items()}

    app.jinja_env.globals.update(enumerate=enumerate)
    # NOTE(review): debug=True turns on Flask's debug mode; confirm this is
    # only ever run locally.
    app.run(debug=True, port=5500)