statedict2pytree 1.0.1__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,8 @@ class TorchField(BaseModel):
28
28
 
29
29
 
30
30
  class JaxField(BaseModel):
31
- path: KeyPath
31
+ key_path: KeyPath
32
+ path: str
32
33
  shape: tuple[int, ...]
33
34
  skip: bool = False
34
35
 
@@ -45,7 +46,7 @@ def is_numerical(element: Any):
45
46
 
46
47
 
47
48
  def _default_floating_dtype():
48
- if jax.config.jax_enable_x64: # pyright: ignore
49
+ if jax.config.jax_enable_x64: # ty: ignore
49
50
  return jnp.float64
50
51
  else:
51
52
  return jnp.float32
@@ -86,15 +87,15 @@ def _get_stateindex_fields(obj) -> dict:
86
87
 
87
88
 
88
89
  def _get_node(
89
- tree: PyTree, path: KeyPath, state_indices: dict | None = None
90
+ tree: PyTree, key_path: KeyPath, state_indices: dict | None = None
90
91
  ) -> tuple[PyTree | None, dict | None]:
91
92
  if tree is None:
92
93
  return None, {}
93
94
  else:
94
- if len(path) == 0:
95
+ if len(key_path) == 0:
95
96
  return tree, state_indices
96
- f, *_ = path
97
- if hasattr(tree, "is_stateful"):
97
+ f, *_ = key_path
98
+ if hasattr(tree, "is_stateful") and tree.is_stateful():
98
99
  if state_indices is None:
99
100
  state_indices = {}
100
101
  indices = _get_stateindex_fields(tree)
@@ -109,23 +110,26 @@ def _get_node(
109
110
  elif isinstance(f, FlattenedIndexKey):
110
111
  if isinstance(tree, eqx.nn.State):
111
112
  assert state_indices is not None
112
- index = state_indices[f.key]
113
+
114
+ markers = list(tree._state.keys())
115
+ marker = markers[f.key]
116
+ index = state_indices[marker]
113
117
  subtree = tree.get(index)
114
118
  else:
115
119
  subtree = None
116
120
  else:
117
121
  subtree = None
118
- return _get_node(subtree, path[1:], state_indices)
122
+ return _get_node(subtree, key_path[1:], state_indices)
119
123
 
120
124
 
121
125
  def _replace_node(
122
- tree: PyTree, path: KeyPath, new_value: Array, state_indices: dict | None = None
126
+ tree: PyTree, key_path: KeyPath, new_value: Array, state_indices: dict | None = None
123
127
  ) -> PyTree:
124
128
  def where_wrapper(t):
125
- node, _ = _get_node(t, path=path, state_indices=state_indices)
129
+ node, _ = _get_node(t, key_path=key_path, state_indices=state_indices)
126
130
  return node
127
131
 
128
- node, _ = _get_node(tree, path=path, state_indices=state_indices)
132
+ node, _ = _get_node(tree, key_path=key_path, state_indices=state_indices)
129
133
 
130
134
  if node is not None and eqx.is_array(node):
131
135
  tree = eqx.tree_at(
@@ -134,10 +138,48 @@ def _replace_node(
134
138
  new_value.reshape(node.shape),
135
139
  )
136
140
  else:
137
- print("WARNING: Couldn't find: ", jax.tree_util.keystr(path))
141
+ print("WARNING: Couldn't find: ", jax.tree_util.keystr(key_path))
138
142
  return tree
139
143
 
140
144
 
145
+ def _resolve_state_names(pytree: PyTree, key_path: KeyPath) -> str:
146
+ parts = []
147
+ current = pytree
148
+
149
+ for i, key in enumerate(key_path):
150
+ if isinstance(key, GetAttrKey):
151
+ parts.append(key.name)
152
+ current = getattr(current, key.name)
153
+ elif isinstance(key, SequenceKey):
154
+ if i == 0 and isinstance(pytree, tuple) and len(pytree) == 2:
155
+ current = current[key.idx]
156
+ continue
157
+ current = current[key.idx]
158
+ parts.append(str(key.idx))
159
+ elif isinstance(key, FlattenedIndexKey):
160
+ if isinstance(current, eqx.nn.State):
161
+ markers = list(current._state.keys())
162
+ marker = markers[key.key]
163
+ clean_marker = marker.split(".", 1)[-1] if "." in marker else marker
164
+ parts.append(clean_marker)
165
+ current = current._state[marker]
166
+ else:
167
+ parts.append(str(key))
168
+
169
+ return ".".join(parts)
170
+
171
+
172
+ def _normalize_eqx_name(path: str) -> str:
173
+ if "batch_state_index.0" in path:
174
+ return path.replace("batch_state_index.0", "running_mean")
175
+ if "batch_state_index.1" in path:
176
+ return path.replace("batch_state_index.1", "running_var")
177
+ if "batch_counter" in path:
178
+ return path.replace("batch_counter", "num_batches_tracked")
179
+
180
+ return path
181
+
182
+
141
183
  def move_running_fields_to_the_end(
142
184
  torchfields: list[TorchField], identifier: str = "running_"
143
185
  ):
@@ -159,14 +201,16 @@ def move_running_fields_to_the_end(
159
201
 
160
202
 
161
203
  def state_dict_to_fields(
162
- state_dict: dict[str, Any],
204
+ state_dict: dict[str, Any], sort_by_path: bool = True
163
205
  ) -> list[TorchField]:
164
206
  if state_dict is None:
165
207
  return []
166
208
  fields: list[TorchField] = []
167
209
  for key, value in state_dict.items():
168
- if hasattr(value, "shape") and len(value.shape) > 0:
210
+ if hasattr(value, "shape"):
169
211
  fields.append(TorchField(path=key, shape=tuple(value.shape)))
212
+
213
+ fields.sort(key=lambda x: x.path)
170
214
  return fields
171
215
 
172
216
 
@@ -174,17 +218,29 @@ def pytree_to_fields(
174
218
  pytree: PyTree,
175
219
  model_order: list[str] | None = None,
176
220
  filter: Callable[[Array], bool] = eqx.is_array,
221
+ sort_by_path: bool = True,
177
222
  ) -> tuple[list[JaxField], dict | None]:
178
223
  jaxfields = []
179
224
  paths = jax.tree.leaves_with_path(pytree)
180
- i = {}
225
+ state_indices = {}
181
226
  for p in paths:
182
227
  keys, _ = p
183
- n, i = _get_node(pytree, keys, i)
228
+ n, state_indices = _get_node(pytree, keys, state_indices)
184
229
  if n is not None and filter(n):
185
- jaxfields.append(JaxField(path=keys, shape=n.shape))
230
+ readable_path = _resolve_state_names(pytree, keys)
231
+ jaxfields.append(
232
+ JaxField(
233
+ key_path=keys,
234
+ shape=n.shape,
235
+ path=_normalize_eqx_name(readable_path),
236
+ )
237
+ )
186
238
 
187
239
  if model_order is not None:
240
+ if sort_by_path:
241
+ raise ValueError(
242
+ "model_order is given and sort_by_path is true. Only one of those is allowed"
243
+ )
188
244
  ordered_jaxfields = []
189
245
  path_dict = {jax.tree_util.keystr(field.path): field for field in jaxfields}
190
246
 
@@ -194,7 +250,9 @@ def pytree_to_fields(
194
250
  del path_dict[path_str]
195
251
  ordered_jaxfields.extend(path_dict.values())
196
252
  jaxfields = ordered_jaxfields
197
- return jaxfields, i
253
+ elif sort_by_path:
254
+ jaxfields.sort(key=lambda j: j.path)
255
+ return jaxfields, state_indices
198
256
 
199
257
 
200
258
  def _chunkify_state_dict(
@@ -245,7 +303,9 @@ def convert(
245
303
  if len(torchfields) != len(jaxfields):
246
304
  raise ValueError(
247
305
  f"Length of state_dict ({len(torchfields)}) "
248
- f"!= length of pytree ({len(jaxfields)})"
306
+ f"!= length of pytree ({len(jaxfields)}) "
307
+ "Note: if you are using BatchNorm in Equinox, make sure to "
308
+ 'add BatchNorm(..., mode="batch"), which might fix your issue!'
249
309
  )
250
310
 
251
311
  with tempfile.TemporaryDirectory() as tmpdir:
@@ -259,28 +319,34 @@ def convert(
259
319
  "Note that the order of the fields matters "
260
320
  "and that you can mark arrays as skippable. "
261
321
  f"{t.path=} "
262
- f"{jax.tree_util.keystr(j.path)=}"
322
+ f"{jax.tree_util.keystr(j.key_path)=} ({j.path})"
263
323
  )
264
324
  state_dict_dir = pathlib.Path(chunkified_statedict_path.path) / "state_dict"
265
325
  filename = state_dict_dir / t.path
266
326
  new_value = jnp.array(np.load(str(filename) + ".npy"))
267
327
 
268
- n, _ = _get_node(pytree, j.path, state_indices)
269
- assert n is not None, f"Node {j.path} not found"
328
+ n, _ = _get_node(pytree, j.key_path, state_indices)
329
+ assert n is not None, f"Node {j.key_path} not found"
270
330
  assert _can_reshape(n.shape, new_value.shape), (
271
331
  f"Cannot reshape {n.shape} into {new_value.shape}"
272
332
  )
273
333
 
274
- pytree = _replace_node(pytree, j.path, new_value, state_indices)
334
+ pytree = _replace_node(pytree, j.key_path, new_value, state_indices)
275
335
 
276
336
  return pytree
277
337
 
278
338
 
279
339
  def autoconvert(
280
- pytree: PyTree, state_dict: dict, pytree_model_order: list[str] | None = None
340
+ pytree: PyTree,
341
+ state_dict: dict,
342
+ filter: Callable[[Array], bool] = eqx.is_array,
343
+ pytree_model_order: list[str] | None = None,
344
+ sort_by_path: bool = True,
281
345
  ) -> PyTree:
282
- torchfields = state_dict_to_fields(state_dict)
283
- jaxfields, state_indices = pytree_to_fields(pytree, pytree_model_order)
346
+ torchfields = state_dict_to_fields(state_dict, sort_by_path)
347
+ jaxfields, state_indices = pytree_to_fields(
348
+ pytree, pytree_model_order, filter=filter, sort_by_path=sort_by_path
349
+ )
284
350
 
285
351
  pytree = convert(
286
352
  state_dict,
@@ -1,11 +1,11 @@
1
- Metadata-Version: 2.3
1
+ Metadata-Version: 2.4
2
2
  Name: statedict2pytree
3
- Version: 1.0.1
3
+ Version: 2.0.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: >=3.11
7
7
  Requires-Dist: beartype
8
- Requires-Dist: equinox
8
+ Requires-Dist: equinox>=0.13.0
9
9
  Requires-Dist: jax
10
10
  Requires-Dist: jaxlib
11
11
  Requires-Dist: jaxtyping
@@ -0,0 +1,5 @@
1
+ statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
2
+ statedict2pytree/converter.py,sha256=hxUBVJswVQ1FG-6DRKRQWj7aEZuHVQFgk_sSNzPvUK0,11046
3
+ statedict2pytree-2.0.0.dist-info/METADATA,sha256=bg1HBXYFFLFwSlBEZd32p2ViJwfipXzTpQz8YWKnRq0,1926
4
+ statedict2pytree-2.0.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ statedict2pytree-2.0.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.24.2
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,5 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
2
- statedict2pytree/converter.py,sha256=Toehra2_guidbR5Gyl1uUi1YYv37jkoSKDQjxo1tzwg,8626
3
- statedict2pytree-1.0.1.dist-info/METADATA,sha256=Z_tRmXwxVMGUMrOYm3SiCxnl4L9OUINZMrBqxbEg5dE,1918
4
- statedict2pytree-1.0.1.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
5
- statedict2pytree-1.0.1.dist-info/RECORD,,