statedict2pytree 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree/__init__.py +4 -0
- statedict2pytree/statedict2pytree.py +194 -0
- statedict2pytree/static/input.css +3 -0
- statedict2pytree/static/output.css +1734 -0
- statedict2pytree/templates/index.html +246 -0
- statedict2pytree-0.1.2.dist-info/METADATA +43 -0
- statedict2pytree-0.1.2.dist-info/RECORD +8 -0
- statedict2pytree-0.1.2.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import functools as ft
|
|
2
|
+
import re
|
|
3
|
+
|
|
4
|
+
import equinox as eqx
|
|
5
|
+
import flask
|
|
6
|
+
import jax
|
|
7
|
+
from beartype.typing import Optional
|
|
8
|
+
from jaxtyping import PyTree
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from penzai import pz
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
# Module-level Flask application: the route handlers in this module register on
# it via @app.route, and start_conversion() runs it.
app = flask.Flask(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Field(BaseModel):
    """A tensor slot on one side of the conversion.

    ``path`` locates the tensor (a state-dict key for torch fields, a
    ``jax.tree_util.keystr`` path for JAX fields) and ``shape`` is its shape.
    """

    path: str
    shape: tuple[int, ...]


class TorchField(Field):
    """An entry of the source state dict; carries only ``path`` and ``shape``."""


class JaxField(Field):
    """A leaf of the target pytree, tagged with the type of its parent node."""

    # String form of the parent node's type (``str(type(parent))``); used to
    # recognise leaves that belong to a StateIndex.
    type: str
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Module-level state shared by the route handlers and convert(). Both start as
# None and are assigned by start_conversion() before the server is started.
# The target pytree whose leaves are filled in from the state dict.
PYTREE: Optional[PyTree] = None
# The source state dict, keyed by parameter name.
STATE_DICT: Optional[dict] = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def get_node(
|
|
35
|
+
tree: PyTree, targets: list[str], log_when_not_found: bool = False
|
|
36
|
+
) -> PyTree | None:
|
|
37
|
+
if len(targets) == 0 or tree is None:
|
|
38
|
+
return tree
|
|
39
|
+
else:
|
|
40
|
+
next_target: str = targets[0]
|
|
41
|
+
if bool(re.search(r"\[\d\]", next_target)):
|
|
42
|
+
split_index = next_target.rfind("[")
|
|
43
|
+
name, index = next_target[:split_index], next_target[split_index:]
|
|
44
|
+
index = index[1:-1]
|
|
45
|
+
if hasattr(tree, name):
|
|
46
|
+
subtree = getattr(tree, name)[int(index)]
|
|
47
|
+
else:
|
|
48
|
+
subtree = None
|
|
49
|
+
if log_when_not_found:
|
|
50
|
+
logger.info(f"Couldn't find {name} in {tree.__class__}")
|
|
51
|
+
else:
|
|
52
|
+
if hasattr(tree, next_target):
|
|
53
|
+
subtree = getattr(tree, next_target)
|
|
54
|
+
else:
|
|
55
|
+
subtree = None
|
|
56
|
+
if log_when_not_found:
|
|
57
|
+
logger.info(f"Couldn't find {next_target} in {tree.__class__}")
|
|
58
|
+
return get_node(subtree, targets[1:])
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def pytree_to_fields(pytree: PyTree) -> list[JaxField]:
    """Describe every non-scalar leaf of ``pytree`` as a :class:`JaxField`.

    For each leaf, ``path`` is its ``jax.tree_util.keystr`` key path, ``type``
    is ``str()`` of the type of the leaf's parent node, and ``shape`` is the
    leaf's shape. Leaves that ``get_node`` cannot resolve, that have no
    ``shape``, or whose shape is empty are skipped.
    """
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(pytree)
    result: list[JaxField] = []
    for key_path, _leaf in leaves_with_paths:
        key_str = jax.tree_util.keystr(key_path)
        # Discard whatever precedes the first "." (empty for attribute paths).
        components = key_str.split(".")[1:]
        parent = get_node(pytree, components[:-1], log_when_not_found=True)
        node = get_node(pytree, components, log_when_not_found=True)
        if node is None or not hasattr(node, "shape") or len(node.shape) == 0:
            continue
        result.append(
            JaxField(path=key_str, type=str(type(parent)), shape=tuple(node.shape))
        )
    return result
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def state_dict_to_fields(state_dict: Optional[dict]) -> list[TorchField]:
    """Describe every non-scalar entry of ``state_dict`` as a :class:`TorchField`.

    Entries without a ``shape`` attribute, or with an empty shape, are left
    out. A ``None`` state dict yields an empty list.
    """
    if state_dict is None:
        return []
    return [
        TorchField(path=key, shape=tuple(tensor.shape))
        for key, tensor in state_dict.items()
        if hasattr(tensor, "shape") and len(tensor.shape) > 0
    ]
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _parse_shape(raw_shape) -> tuple[int, ...]:
    """Parse a shape sent by the UI: text such as "(3, 4)" / "(3,)" or a list of ints."""
    if isinstance(raw_shape, (list, tuple)):
        return tuple(int(dim) for dim in raw_shape)
    # Skip empty / whitespace-only pieces, e.g. the tail of "(3,)" or "(3, )".
    return tuple(
        int(dim) for dim in str(raw_shape).strip("()").split(",") if dim.strip()
    )


def _parse_fields(request_data: dict) -> tuple[list[JaxField], list[TorchField]]:
    """Turn the request's ``jaxFields``/``torchFields`` JSON lists into field models."""
    jax_fields = [
        JaxField(path=f["path"], type=f["type"], shape=_parse_shape(f["shape"]))
        for f in request_data["jaxFields"]
    ]
    torch_fields = [
        TorchField(path=f["path"], shape=_parse_shape(f["shape"]))
        for f in request_data["torchFields"]
    ]
    return jax_fields, torch_fields


@app.route("/visualize", methods=["POST"])
def visualize_with_penzai():
    """Convert with the submitted field pairing and render it with penzai.

    Expects a JSON body with ``jaxFields`` and ``torchFields`` lists. Returns an
    HTML page with the converted ``(model, state)`` and the source state dict,
    separated by ``<hr>``. Errors are reported as ``{"error": ...}`` JSON.
    """
    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    # silent=True returns None for a missing/non-JSON body instead of raising,
    # so the error response below is actually reachable.
    request_data = flask.request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return flask.jsonify({"error": "No data received"})
    missing = [k for k in ("jaxFields", "torchFields") if k not in request_data]
    if missing:
        return flask.jsonify({"error": f"Missing keys: {', '.join(missing)}"})

    # convert() reads .path/.type attributes, so raw JSON dicts must be parsed first.
    jax_fields, torch_fields = _parse_fields(request_data)
    model, state = convert(jax_fields, torch_fields)
    with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()):
        html_jax = pz.ts.render_to_html((model, state))
        html_torch = pz.ts.render_to_html(STATE_DICT)

    return f"<html><body>{html_jax}<hr>{html_torch}</body></html>"
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@app.route("/convert", methods=["POST"])
def convert_torch_to_jax():
    """Convert with the submitted field pairing and serialise the result to disk.

    Expects a JSON body with ``jaxFields`` (dicts with ``path``, ``type``,
    ``shape``), ``torchFields`` (dicts with ``path``, ``shape``) and ``name``,
    the file path handed to ``eqx.tree_serialise_leaves``. Shapes may be text
    such as ``"(3, 4)"`` or JSON lists of ints. Errors are reported as
    ``{"error": ...}`` JSON; success as ``{"status": "success"}``.
    """

    def parse_shape(raw_shape) -> tuple[int, ...]:
        if isinstance(raw_shape, (list, tuple)):
            return tuple(int(dim) for dim in raw_shape)
        # Skip empty / whitespace-only pieces, e.g. the tail of "(3,)" or "(3, )".
        return tuple(
            int(dim) for dim in str(raw_shape).strip("()").split(",") if dim.strip()
        )

    if PYTREE is None or STATE_DICT is None:
        return flask.jsonify({"error": "No Pytree or StateDict found"})
    # silent=True returns None for a missing/non-JSON body instead of raising.
    request_data = flask.request.get_json(silent=True)
    if not isinstance(request_data, dict):
        return flask.jsonify({"error": "No data received"})
    missing = [
        k for k in ("jaxFields", "torchFields", "name") if k not in request_data
    ]
    if missing:
        return flask.jsonify({"error": f"Missing keys: {', '.join(missing)}"})

    jax_fields = [
        JaxField(path=f["path"], type=f["type"], shape=parse_shape(f["shape"]))
        for f in request_data["jaxFields"]
    ]
    torch_fields = [
        TorchField(path=f["path"], shape=parse_shape(f["shape"]))
        for f in request_data["torchFields"]
    ]

    # NOTE(security): `name` is client-controlled and used directly as a
    # filesystem path, so any client that can reach this server chooses where
    # the file is written. Only acceptable while the server is local-only.
    name = request_data["name"]
    model, state = convert(jax_fields, torch_fields)
    eqx.tree_serialise_leaves(name, (model, state))

    return flask.jsonify({"status": "success"})
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
@app.route("/", methods=["GET"])
def index():
    """Render ``index.html`` with the pytree fields and the state-dict fields."""
    return flask.render_template(
        "index.html",
        pytree_fields=pytree_to_fields(PYTREE),
        torch_fields=state_dict_to_fields(STATE_DICT),
    )
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def convert(jax_fields: list[JaxField], torch_fields: list[TorchField]):
    """Build ``(model, state)`` from PYTREE and copy STATE_DICT values into it.

    Fields are paired positionally: the i-th JAX field receives
    ``STATE_DICT[torch_fields[i].path]``. JAX fields whose ``type`` contains
    "StateIndex" are grouped by their parent path and written into the Equinox
    state as a tuple (in field order). All other fields are written into the
    model with ``eqx.tree_at``. JAX fields whose path cannot be resolved in the
    model are skipped.

    Returns:
        The ``(model, state)`` pair from ``eqx.nn.make_with_state``, with
        values replaced.

    Raises:
        ValueError: if STATE_DICT has not been set.
    """
    if STATE_DICT is None:
        raise ValueError("STATE_DICT must not be None!")

    def _return_pytree(*args, **kwargs):
        return PYTREE

    model, state = eqx.nn.make_with_state(_return_pytree)()

    if len(jax_fields) != len(torch_fields):
        # zip() below pairs positionally and silently drops the surplus.
        logger.warning(
            f"{len(jax_fields)} JAX fields vs {len(torch_fields)} torch fields; "
            f"only the first {min(len(jax_fields), len(torch_fields))} pairs are used"
        )

    # Torch fields destined for a StateIndex, keyed by the StateIndex's dotted path.
    state_groups: dict[str, list[TorchField]] = {}
    for jax_field, torch_field in zip(jax_fields, torch_fields):
        path = jax_field.path.split(".")[1:]
        if "StateIndex" in jax_field.type:
            prefix_key = ".".join(path[:-1])
            state_groups.setdefault(prefix_key, []).append(torch_field)
            continue

        where = ft.partial(get_node, targets=path)
        current = where(model)
        if current is None:
            continue
        new_value = STATE_DICT[torch_field.path]
        old_shape = getattr(current, "shape", None)
        new_shape = getattr(new_value, "shape", None)
        if (
            old_shape is not None
            and new_shape is not None
            and tuple(old_shape) != tuple(new_shape)
        ):
            logger.warning(
                f"Shape mismatch: {jax_field.path} has shape {tuple(old_shape)} "
                f"but {torch_field.path} has shape {tuple(new_shape)}"
            )
        model = eqx.tree_at(where, model, new_value)

    for prefix_key, group in state_groups.items():
        state_index = get_node(model, prefix_key.split("."))
        if state_index is not None:
            state = state.set(state_index, tuple(STATE_DICT[f.path] for f in group))
    return model, state
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
def start_conversion(pytree: PyTree, state_dict: dict, port: int = 5500):
    """Register the target pytree and source state dict, then serve the UI.

    The caller's ``state_dict`` is not modified. A new dict holding NumPy
    copies of its values is stored instead. Blocks until the Flask server
    stops.

    Args:
        pytree: The pytree/model whose leaves will be filled in.
        state_dict: Mapping of names to tensors. Values exposing
            ``.detach()``/``.cpu()`` are moved through those first, then
            converted with ``.numpy()``. Values without ``.numpy()`` are
            stored unchanged.
        port: Port for the local Flask server (defaults to the previous
            hard-coded 5500).

    Raises:
        ValueError: if ``state_dict`` is None.
    """
    global PYTREE, STATE_DICT
    if state_dict is None:
        raise ValueError("STATE_DICT must not be None!")

    def to_numpy(value):
        # detach()/cpu() let on-device tensors (which .numpy() rejects) convert.
        if hasattr(value, "detach"):
            value = value.detach()
        if hasattr(value, "cpu"):
            value = value.cpu()
        return value.numpy() if hasattr(value, "numpy") else value

    PYTREE = pytree
    STATE_DICT = {key: to_numpy(value) for key, value in state_dict.items()}
    # With debug=True Flask enables the reloader by default, which re-executes
    # the calling script in a child process (building the model and loading the
    # weights twice). Keep the debugger but disable the reloader.
    app.run(debug=True, port=port, use_reloader=False)
|