statedict2pytree-1.0.2-py3-none-any.whl → statedict2pytree-2.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
statedict2pytree/converter.py
CHANGED
|
@@ -28,7 +28,8 @@ class TorchField(BaseModel):
|
|
|
28
28
|
|
|
29
29
|
|
|
30
30
|
class JaxField(BaseModel):
|
|
31
|
-
|
|
31
|
+
key_path: KeyPath
|
|
32
|
+
path: str
|
|
32
33
|
shape: tuple[int, ...]
|
|
33
34
|
skip: bool = False
|
|
34
35
|
|
|
@@ -45,7 +46,7 @@ def is_numerical(element: Any):
|
|
|
45
46
|
|
|
46
47
|
|
|
47
48
|
def _default_floating_dtype():
|
|
48
|
-
if jax.config.jax_enable_x64: #
|
|
49
|
+
if jax.config.jax_enable_x64: # ty: ignore
|
|
49
50
|
return jnp.float64
|
|
50
51
|
else:
|
|
51
52
|
return jnp.float32
|
|
@@ -86,14 +87,14 @@ def _get_stateindex_fields(obj) -> dict:
|
|
|
86
87
|
|
|
87
88
|
|
|
88
89
|
def _get_node(
|
|
89
|
-
tree: PyTree,
|
|
90
|
+
tree: PyTree, key_path: KeyPath, state_indices: dict | None = None
|
|
90
91
|
) -> tuple[PyTree | None, dict | None]:
|
|
91
92
|
if tree is None:
|
|
92
93
|
return None, {}
|
|
93
94
|
else:
|
|
94
|
-
if len(
|
|
95
|
+
if len(key_path) == 0:
|
|
95
96
|
return tree, state_indices
|
|
96
|
-
f, *_ =
|
|
97
|
+
f, *_ = key_path
|
|
97
98
|
if hasattr(tree, "is_stateful") and tree.is_stateful():
|
|
98
99
|
if state_indices is None:
|
|
99
100
|
state_indices = {}
|
|
@@ -118,17 +119,17 @@ def _get_node(
|
|
|
118
119
|
subtree = None
|
|
119
120
|
else:
|
|
120
121
|
subtree = None
|
|
121
|
-
return _get_node(subtree,
|
|
122
|
+
return _get_node(subtree, key_path[1:], state_indices)
|
|
122
123
|
|
|
123
124
|
|
|
124
125
|
def _replace_node(
|
|
125
|
-
tree: PyTree,
|
|
126
|
+
tree: PyTree, key_path: KeyPath, new_value: Array, state_indices: dict | None = None
|
|
126
127
|
) -> PyTree:
|
|
127
128
|
def where_wrapper(t):
|
|
128
|
-
node, _ = _get_node(t,
|
|
129
|
+
node, _ = _get_node(t, key_path=key_path, state_indices=state_indices)
|
|
129
130
|
return node
|
|
130
131
|
|
|
131
|
-
node, _ = _get_node(tree,
|
|
132
|
+
node, _ = _get_node(tree, key_path=key_path, state_indices=state_indices)
|
|
132
133
|
|
|
133
134
|
if node is not None and eqx.is_array(node):
|
|
134
135
|
tree = eqx.tree_at(
|
|
@@ -137,10 +138,48 @@ def _replace_node(
|
|
|
137
138
|
new_value.reshape(node.shape),
|
|
138
139
|
)
|
|
139
140
|
else:
|
|
140
|
-
print("WARNING: Couldn't find: ", jax.tree_util.keystr(
|
|
141
|
+
print("WARNING: Couldn't find: ", jax.tree_util.keystr(key_path))
|
|
141
142
|
return tree
|
|
142
143
|
|
|
143
144
|
|
|
145
|
+
def _resolve_state_names(pytree: PyTree, key_path: KeyPath) -> str:
|
|
146
|
+
parts = []
|
|
147
|
+
current = pytree
|
|
148
|
+
|
|
149
|
+
for i, key in enumerate(key_path):
|
|
150
|
+
if isinstance(key, GetAttrKey):
|
|
151
|
+
parts.append(key.name)
|
|
152
|
+
current = getattr(current, key.name)
|
|
153
|
+
elif isinstance(key, SequenceKey):
|
|
154
|
+
if i == 0 and isinstance(pytree, tuple) and len(pytree) == 2:
|
|
155
|
+
current = current[key.idx]
|
|
156
|
+
continue
|
|
157
|
+
current = current[key.idx]
|
|
158
|
+
parts.append(str(key.idx))
|
|
159
|
+
elif isinstance(key, FlattenedIndexKey):
|
|
160
|
+
if isinstance(current, eqx.nn.State):
|
|
161
|
+
markers = list(current._state.keys())
|
|
162
|
+
marker = markers[key.key]
|
|
163
|
+
clean_marker = marker.split(".", 1)[-1] if "." in marker else marker
|
|
164
|
+
parts.append(clean_marker)
|
|
165
|
+
current = current._state[marker]
|
|
166
|
+
else:
|
|
167
|
+
parts.append(str(key))
|
|
168
|
+
|
|
169
|
+
return ".".join(parts)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def _normalize_eqx_name(path: str) -> str:
|
|
173
|
+
if "batch_state_index.0" in path:
|
|
174
|
+
return path.replace("batch_state_index.0", "running_mean")
|
|
175
|
+
if "batch_state_index.1" in path:
|
|
176
|
+
return path.replace("batch_state_index.1", "running_var")
|
|
177
|
+
if "batch_counter" in path:
|
|
178
|
+
return path.replace("batch_counter", "num_batches_tracked")
|
|
179
|
+
|
|
180
|
+
return path
|
|
181
|
+
|
|
182
|
+
|
|
144
183
|
def move_running_fields_to_the_end(
|
|
145
184
|
torchfields: list[TorchField], identifier: str = "running_"
|
|
146
185
|
):
|
|
@@ -162,14 +201,16 @@ def move_running_fields_to_the_end(
|
|
|
162
201
|
|
|
163
202
|
|
|
164
203
|
def state_dict_to_fields(
|
|
165
|
-
state_dict: dict[str, Any],
|
|
204
|
+
state_dict: dict[str, Any], sort_by_path: bool = True
|
|
166
205
|
) -> list[TorchField]:
|
|
167
206
|
if state_dict is None:
|
|
168
207
|
return []
|
|
169
208
|
fields: list[TorchField] = []
|
|
170
209
|
for key, value in state_dict.items():
|
|
171
|
-
if hasattr(value, "shape")
|
|
210
|
+
if hasattr(value, "shape"):
|
|
172
211
|
fields.append(TorchField(path=key, shape=tuple(value.shape)))
|
|
212
|
+
|
|
213
|
+
fields.sort(key=lambda x: x.path)
|
|
173
214
|
return fields
|
|
174
215
|
|
|
175
216
|
|
|
@@ -177,17 +218,29 @@ def pytree_to_fields(
|
|
|
177
218
|
pytree: PyTree,
|
|
178
219
|
model_order: list[str] | None = None,
|
|
179
220
|
filter: Callable[[Array], bool] = eqx.is_array,
|
|
221
|
+
sort_by_path: bool = True,
|
|
180
222
|
) -> tuple[list[JaxField], dict | None]:
|
|
181
223
|
jaxfields = []
|
|
182
224
|
paths = jax.tree.leaves_with_path(pytree)
|
|
183
|
-
|
|
225
|
+
state_indices = {}
|
|
184
226
|
for p in paths:
|
|
185
227
|
keys, _ = p
|
|
186
|
-
n,
|
|
228
|
+
n, state_indices = _get_node(pytree, keys, state_indices)
|
|
187
229
|
if n is not None and filter(n):
|
|
188
|
-
|
|
230
|
+
readable_path = _resolve_state_names(pytree, keys)
|
|
231
|
+
jaxfields.append(
|
|
232
|
+
JaxField(
|
|
233
|
+
key_path=keys,
|
|
234
|
+
shape=n.shape,
|
|
235
|
+
path=_normalize_eqx_name(readable_path),
|
|
236
|
+
)
|
|
237
|
+
)
|
|
189
238
|
|
|
190
239
|
if model_order is not None:
|
|
240
|
+
if sort_by_path:
|
|
241
|
+
raise ValueError(
|
|
242
|
+
"model_order is given and sort_by_path is true. Only one of those is allowed"
|
|
243
|
+
)
|
|
191
244
|
ordered_jaxfields = []
|
|
192
245
|
path_dict = {jax.tree_util.keystr(field.path): field for field in jaxfields}
|
|
193
246
|
|
|
@@ -197,7 +250,9 @@ def pytree_to_fields(
|
|
|
197
250
|
del path_dict[path_str]
|
|
198
251
|
ordered_jaxfields.extend(path_dict.values())
|
|
199
252
|
jaxfields = ordered_jaxfields
|
|
200
|
-
|
|
253
|
+
elif sort_by_path:
|
|
254
|
+
jaxfields.sort(key=lambda j: j.path)
|
|
255
|
+
return jaxfields, state_indices
|
|
201
256
|
|
|
202
257
|
|
|
203
258
|
def _chunkify_state_dict(
|
|
@@ -248,7 +303,9 @@ def convert(
|
|
|
248
303
|
if len(torchfields) != len(jaxfields):
|
|
249
304
|
raise ValueError(
|
|
250
305
|
f"Length of state_dict ({len(torchfields)}) "
|
|
251
|
-
f"!= length of pytree ({len(jaxfields)})"
|
|
306
|
+
f"!= length of pytree ({len(jaxfields)}) "
|
|
307
|
+
"Note: if you are using BatchNorm in Equinox, make sure to "
|
|
308
|
+
'add BatchNorm(..., mode="batch"), which might fix your issue!'
|
|
252
309
|
)
|
|
253
310
|
|
|
254
311
|
with tempfile.TemporaryDirectory() as tmpdir:
|
|
@@ -262,28 +319,34 @@ def convert(
|
|
|
262
319
|
"Note that the order of the fields matters "
|
|
263
320
|
"and that you can mark arrays as skippable. "
|
|
264
321
|
f"{t.path=} "
|
|
265
|
-
f"{jax.tree_util.keystr(j.
|
|
322
|
+
f"{jax.tree_util.keystr(j.key_path)=} ({j.path})"
|
|
266
323
|
)
|
|
267
324
|
state_dict_dir = pathlib.Path(chunkified_statedict_path.path) / "state_dict"
|
|
268
325
|
filename = state_dict_dir / t.path
|
|
269
326
|
new_value = jnp.array(np.load(str(filename) + ".npy"))
|
|
270
327
|
|
|
271
|
-
n, _ = _get_node(pytree, j.
|
|
272
|
-
assert n is not None, f"Node {j.
|
|
328
|
+
n, _ = _get_node(pytree, j.key_path, state_indices)
|
|
329
|
+
assert n is not None, f"Node {j.key_path} not found"
|
|
273
330
|
assert _can_reshape(n.shape, new_value.shape), (
|
|
274
331
|
f"Cannot reshape {n.shape} into {new_value.shape}"
|
|
275
332
|
)
|
|
276
333
|
|
|
277
|
-
pytree = _replace_node(pytree, j.
|
|
334
|
+
pytree = _replace_node(pytree, j.key_path, new_value, state_indices)
|
|
278
335
|
|
|
279
336
|
return pytree
|
|
280
337
|
|
|
281
338
|
|
|
282
339
|
def autoconvert(
|
|
283
|
-
pytree: PyTree,
|
|
340
|
+
pytree: PyTree,
|
|
341
|
+
state_dict: dict,
|
|
342
|
+
filter: Callable[[Array], bool] = eqx.is_array,
|
|
343
|
+
pytree_model_order: list[str] | None = None,
|
|
344
|
+
sort_by_path: bool = True,
|
|
284
345
|
) -> PyTree:
|
|
285
|
-
torchfields = state_dict_to_fields(state_dict)
|
|
286
|
-
jaxfields, state_indices = pytree_to_fields(
|
|
346
|
+
torchfields = state_dict_to_fields(state_dict, sort_by_path)
|
|
347
|
+
jaxfields, state_indices = pytree_to_fields(
|
|
348
|
+
pytree, pytree_model_order, filter=filter, sort_by_path=sort_by_path
|
|
349
|
+
)
|
|
287
350
|
|
|
288
351
|
pytree = convert(
|
|
289
352
|
state_dict,
|
|
@@ -0,0 +1,5 @@
|
|
|
1
|
+
statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
|
|
2
|
+
statedict2pytree/converter.py,sha256=hxUBVJswVQ1FG-6DRKRQWj7aEZuHVQFgk_sSNzPvUK0,11046
|
|
3
|
+
statedict2pytree-2.0.0.dist-info/METADATA,sha256=bg1HBXYFFLFwSlBEZd32p2ViJwfipXzTpQz8YWKnRq0,1926
|
|
4
|
+
statedict2pytree-2.0.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
5
|
+
statedict2pytree-2.0.0.dist-info/RECORD,,
|
|
@@ -1,5 +0,0 @@
|
|
|
1
|
-
statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
|
|
2
|
-
statedict2pytree/converter.py,sha256=Fp9UK8oibon8qua-Kbl1n__KouKOQo_9W3TNss-AhHI,8734
|
|
3
|
-
statedict2pytree-1.0.2.dist-info/METADATA,sha256=K0lvHcpx_0CEVJWCR1hZe8-uzDjAfFtiGseTrSj4BUo,1926
|
|
4
|
-
statedict2pytree-1.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
5
|
-
statedict2pytree-1.0.2.dist-info/RECORD,,
|