statedict2pytree 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,4 @@
1
+ from statedict2pytree.statedict2pytree import (
2
+ convert as convert,
3
+ start_conversion as start_conversion,
4
+ )
@@ -0,0 +1,194 @@
1
+ import functools as ft
2
+ import re
3
+
4
+ import equinox as eqx
5
+ import flask
6
+ import jax
7
+ from beartype.typing import Optional
8
+ from jaxtyping import PyTree
9
+ from loguru import logger
10
+ from penzai import pz
11
+ from pydantic import BaseModel
12
+
13
+
14
# Module-level Flask application. The route handlers defined in this module
# register themselves on it, and start_conversion() runs it.
app = flask.Flask(__name__)
15
+
16
+
17
class Field(BaseModel):
    """Common description of one array-valued entry: its location and shape."""

    # Location of the entry: a key-path string for pytree leaves, or the
    # dictionary key for state-dict entries.
    path: str
    # Array dimensions, one int per axis.
    shape: tuple[int, ...]
20
+
21
+
22
class TorchField(Field):
    """A state-dict entry, identified by its dictionary key and array shape."""
24
+
25
+
26
class JaxField(Field):
    """A pytree leaf, together with the type of the node that contains it."""

    # str() of the containing node's type, as filled in by pytree_to_fields();
    # convert() checks it for the substring "StateIndex".
    type: str
28
+
29
+
30
# Conversion inputs shared between start_conversion() and the request
# handlers. Both stay None until start_conversion() assigns them.
PYTREE: Optional[PyTree] = None
STATE_DICT: Optional[dict] = None
32
+
33
+
34
+ def get_node(
35
+ tree: PyTree, targets: list[str], log_when_not_found: bool = False
36
+ ) -> PyTree | None:
37
+ if len(targets) == 0 or tree is None:
38
+ return tree
39
+ else:
40
+ next_target: str = targets[0]
41
+ if bool(re.search(r"\[\d\]", next_target)):
42
+ split_index = next_target.rfind("[")
43
+ name, index = next_target[:split_index], next_target[split_index:]
44
+ index = index[1:-1]
45
+ if hasattr(tree, name):
46
+ subtree = getattr(tree, name)[int(index)]
47
+ else:
48
+ subtree = None
49
+ if log_when_not_found:
50
+ logger.info(f"Couldn't find {name} in {tree.__class__}")
51
+ else:
52
+ if hasattr(tree, next_target):
53
+ subtree = getattr(tree, next_target)
54
+ else:
55
+ subtree = None
56
+ if log_when_not_found:
57
+ logger.info(f"Couldn't find {next_target} in {tree.__class__}")
58
+ return get_node(subtree, targets[1:])
59
+
60
+
61
def pytree_to_fields(pytree: PyTree) -> list[JaxField]:
    """Describe the array-like leaves of ``pytree`` as ``JaxField`` records.

    The tree is flattened with key paths; each path is rendered to a string,
    split on ".", and resolved again through ``get_node``. A leaf is reported
    only when that lookup succeeds and the result has a non-empty ``shape``.

    Args:
        pytree: The tree to inspect.

    Returns:
        One ``JaxField`` per reported leaf, in flattening order. ``type`` is
        the ``str()`` of the type of the node one level above the leaf.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(pytree)
    collected: list[JaxField] = []
    for key_path, _leaf in leaves_with_paths:
        path = jax.tree_util.keystr(key_path)
        # The first split piece is discarded (presumably the empty string in
        # front of a leading "." -- verify against keystr's output format).
        segments = path.split(".")[1:]
        parent = get_node(pytree, segments[:-1], log_when_not_found=True)
        leaf = get_node(pytree, segments, log_when_not_found=True)

        # Skip anything unresolved, shapeless, or zero-dimensional.
        if leaf is None or not hasattr(leaf, "shape") or len(leaf.shape) == 0:
            continue

        collected.append(
            JaxField(path=path, type=str(type(parent)), shape=tuple(leaf.shape))
        )

    return collected
76
+
77
+
78
def state_dict_to_fields(state_dict: Optional[dict]) -> list[TorchField]:
    """Describe the array-like entries of ``state_dict`` as ``TorchField`` records.

    Args:
        state_dict: Mapping from entry name to value, or ``None``.

    Returns:
        One ``TorchField`` per value that has a ``shape`` with at least one
        dimension, in the mapping's iteration order. ``None`` yields ``[]``.
    """
    if state_dict is None:
        return []
    return [
        TorchField(path=name, shape=tuple(tensor.shape))
        for name, tensor in state_dict.items()
        if hasattr(tensor, "shape") and len(tensor.shape) > 0
    ]
86
+
87
+
88
@app.route("/visualize", methods=["POST"])
def visualize_with_penzai():
    """Render the converted model and the state dict as one HTML page.

    Expects a JSON body with ``jaxFields`` and ``torchFields`` lists. The
    entries are parsed into ``JaxField`` / ``TorchField`` objects before being
    handed to ``convert()``, which reads them by attribute; passing the raw
    JSON dicts through would fail there.
    NOTE(review): the payload is assumed to have the same entry layout as the
    one sent to /convert ("path", "shape", plus "type" for jax fields) --
    confirm against the front end.

    Returns:
        A JSON error object when the globals are unset or no JSON body was
        received; otherwise an HTML string containing both renderings.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    request_data = flask.request.json
    if request_data is None:
        return flask.jsonify({"error": "No data received"})

    def parse_shape(raw) -> tuple[int, ...]:
        # Shapes arrive as strings such as "(3, 4)" on /convert; a JSON list
        # of ints is accepted as well.
        if isinstance(raw, str):
            pieces = raw.strip("()").split(",")
            return tuple(int(piece) for piece in pieces if piece.strip())
        return tuple(int(dim) for dim in raw)

    jax_fields = [
        JaxField(
            path=entry["path"],
            type=entry["type"],
            shape=parse_shape(entry["shape"]),
        )
        for entry in request_data["jaxFields"]
    ]
    torch_fields = [
        TorchField(path=entry["path"], shape=parse_shape(entry["shape"]))
        for entry in request_data["torchFields"]
    ]

    model, state = convert(jax_fields, torch_fields)
    with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
        html_jax = pz.ts.render_to_html((model, state))
        html_torch = pz.ts.render_to_html(STATE_DICT)

    combined_html = f"<html><body>{html_jax}<hr>{html_torch}</body></html>"
    return combined_html
105
+
106
+
107
@app.route("/convert", methods=["POST"])
def convert_torch_to_jax():
    """Convert the state dict into the pytree and serialise the result.

    Expects a JSON body with ``jaxFields``, ``torchFields`` and ``name``.
    Each field entry carries its shape as a string such as ``"(3, 4)"``,
    which is parsed back into a tuple of ints. The converted
    ``(model, state)`` pair is written with ``eqx.tree_serialise_leaves``
    under ``name``.

    Returns:
        A JSON error object when the globals are unset or no JSON body was
        received; otherwise ``{"status": "success"}``.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    payload = flask.request.json
    if payload is None:
        return flask.jsonify({"error": "No data received"})

    jax_fields: list[JaxField] = []
    for entry in payload["jaxFields"]:
        dims = tuple(
            int(piece) for piece in entry["shape"].strip("()").split(",") if piece
        )
        jax_fields.append(
            JaxField(path=entry["path"], type=entry["type"], shape=dims)
        )

    torch_fields: list[TorchField] = []
    for entry in payload["torchFields"]:
        dims = tuple(
            int(piece) for piece in entry["shape"].strip("()").split(",") if piece
        )
        torch_fields.append(TorchField(path=entry["path"], shape=dims))

    target_name = payload["name"]
    converted = convert(jax_fields, torch_fields)
    model, state = converted
    eqx.tree_serialise_leaves(target_name, (model, state))

    return flask.jsonify({"status": "success"})
137
+
138
+
139
@app.route("/", methods=["GET"])
def index():
    """Render ``index.html`` with the field lists of both conversion inputs."""
    template_context = {
        "pytree_fields": pytree_to_fields(PYTREE),
        "torch_fields": state_dict_to_fields(STATE_DICT),
    }
    return flask.render_template("index.html", **template_context)
147
+
148
+
149
def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
    """Copy state-dict arrays into the global pytree, pairing fields by position.

    ``jax_fields`` and ``torch_fields`` are zipped, so the i-th jax field
    receives the array stored under the i-th torch field's path. Fields whose
    ``type`` contains ``"StateIndex"`` are grouped by their parent path and
    written into the state object; all others replace the addressed node of
    the model via ``eqx.tree_at``.

    Args:
        jax_fields: Target leaves in the pytree.
        torch_fields: Source entries in ``STATE_DICT``.

    Returns:
        The ``(model, state)`` pair after all replacements.

    Raises:
        ValueError: If ``STATE_DICT`` has not been set.
    """
    if STATE_DICT is None:
        raise ValueError("STATE_DICT must not be None!")

    def identity(*args, **kwargs):
        return PYTREE

    model, state = eqx.nn.make_with_state(identity)()

    # Parent path (joined with ".") -> torch fields feeding that state entry,
    # kept in the order in which they were encountered.
    grouped: dict[str, list[TorchField]] = {}
    for jax_field, torch_field in zip(jax_fields, torch_fields):
        segments = jax_field.path.split(".")
        if "StateIndex" in jax_field.type:
            grouped.setdefault(".".join(segments[1:-1]), []).append(torch_field)
            continue

        where = ft.partial(get_node, targets=segments[1:])
        # Paths that do not resolve in the model are left untouched.
        if where(model) is not None:
            model = eqx.tree_at(where, model, STATE_DICT[torch_field.path])

    # State entries are resolved against the fully updated model.
    for prefix, members in grouped.items():
        state_index = get_node(model, prefix.split("."))
        if state_index is not None:
            replacement = tuple([STATE_DICT[member.path] for member in members])
            state = state.set(state_index, replacement)
    return model, state
183
+
184
+
185
def start_conversion(pytree: PyTree, state_dict: dict):
    """Store the conversion inputs and start the web UI on port 5500.

    The values of ``state_dict`` are converted to NumPy arrays into a new
    dictionary, so the caller's mapping is left unmodified.

    Args:
        pytree: Target pytree; stored in the module-level ``PYTREE``.
        state_dict: Mapping from entry name to tensor. Values exposing
            ``numpy()`` are converted (after ``detach()`` / ``cpu()`` when
            those methods exist); other values are stored unchanged.

    Raises:
        ValueError: If ``state_dict`` is ``None``.
    """
    global PYTREE, STATE_DICT
    if state_dict is None:
        raise ValueError("STATE_DICT must not be None!")

    def as_numpy(value):
        if not hasattr(value, "numpy"):
            return value
        # A bare .numpy() is rejected for tensors that require grad or live
        # on an accelerator, so detach and move to host memory first.
        if hasattr(value, "detach"):
            value = value.detach()
        if hasattr(value, "cpu"):
            value = value.cpu()
        return value.numpy()

    converted = {key: as_numpy(value) for key, value in state_dict.items()}

    PYTREE = pytree
    STATE_DICT = converted

    # NOTE(review): debug=True turns on Flask's interactive debugger and the
    # reloader -- confirm that this is intended for end users of the package.
    app.run(debug=True, port=5500)
@@ -0,0 +1,3 @@
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;