statedict2pytree 1.0.2__py3-none-any.whl → 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -28,7 +28,8 @@ class TorchField(BaseModel):
28
28
 
29
29
 
30
30
  class JaxField(BaseModel):
31
- path: KeyPath
31
+ key_path: KeyPath
32
+ path: str
32
33
  shape: tuple[int, ...]
33
34
  skip: bool = False
34
35
 
@@ -45,7 +46,7 @@ def is_numerical(element: Any):
45
46
 
46
47
 
47
48
  def _default_floating_dtype():
48
- if jax.config.jax_enable_x64: # pyright: ignore
49
+ if jax.config.jax_enable_x64: # ty: ignore
49
50
  return jnp.float64
50
51
  else:
51
52
  return jnp.float32
@@ -86,14 +87,14 @@ def _get_stateindex_fields(obj) -> dict:
86
87
 
87
88
 
88
89
  def _get_node(
89
- tree: PyTree, path: KeyPath, state_indices: dict | None = None
90
+ tree: PyTree, key_path: KeyPath, state_indices: dict | None = None
90
91
  ) -> tuple[PyTree | None, dict | None]:
91
92
  if tree is None:
92
93
  return None, {}
93
94
  else:
94
- if len(path) == 0:
95
+ if len(key_path) == 0:
95
96
  return tree, state_indices
96
- f, *_ = path
97
+ f, *_ = key_path
97
98
  if hasattr(tree, "is_stateful") and tree.is_stateful():
98
99
  if state_indices is None:
99
100
  state_indices = {}
@@ -118,17 +119,17 @@ def _get_node(
118
119
  subtree = None
119
120
  else:
120
121
  subtree = None
121
- return _get_node(subtree, path[1:], state_indices)
122
+ return _get_node(subtree, key_path[1:], state_indices)
122
123
 
123
124
 
124
125
  def _replace_node(
125
- tree: PyTree, path: KeyPath, new_value: Array, state_indices: dict | None = None
126
+ tree: PyTree, key_path: KeyPath, new_value: Array, state_indices: dict | None = None
126
127
  ) -> PyTree:
127
128
  def where_wrapper(t):
128
- node, _ = _get_node(t, path=path, state_indices=state_indices)
129
+ node, _ = _get_node(t, key_path=key_path, state_indices=state_indices)
129
130
  return node
130
131
 
131
- node, _ = _get_node(tree, path=path, state_indices=state_indices)
132
+ node, _ = _get_node(tree, key_path=key_path, state_indices=state_indices)
132
133
 
133
134
  if node is not None and eqx.is_array(node):
134
135
  tree = eqx.tree_at(
@@ -137,10 +138,48 @@ def _replace_node(
137
138
  new_value.reshape(node.shape),
138
139
  )
139
140
  else:
140
- print("WARNING: Couldn't find: ", jax.tree_util.keystr(path))
141
+ print("WARNING: Couldn't find: ", jax.tree_util.keystr(key_path))
141
142
  return tree
142
143
 
143
144
 
145
+ def _resolve_state_names(pytree: PyTree, key_path: KeyPath) -> str:
146
+ parts = []
147
+ current = pytree
148
+
149
+ for i, key in enumerate(key_path):
150
+ if isinstance(key, GetAttrKey):
151
+ parts.append(key.name)
152
+ current = getattr(current, key.name)
153
+ elif isinstance(key, SequenceKey):
154
+ if i == 0 and isinstance(pytree, tuple) and len(pytree) == 2:
155
+ current = current[key.idx]
156
+ continue
157
+ current = current[key.idx]
158
+ parts.append(str(key.idx))
159
+ elif isinstance(key, FlattenedIndexKey):
160
+ if isinstance(current, eqx.nn.State):
161
+ markers = list(current._state.keys())
162
+ marker = markers[key.key]
163
+ clean_marker = marker.split(".", 1)[-1] if "." in marker else marker
164
+ parts.append(clean_marker)
165
+ current = current._state[marker]
166
+ else:
167
+ parts.append(str(key))
168
+
169
+ return ".".join(parts)
170
+
171
+
172
+ def _normalize_eqx_name(path: str) -> str:
173
+ if "batch_state_index.0" in path:
174
+ return path.replace("batch_state_index.0", "running_mean")
175
+ if "batch_state_index.1" in path:
176
+ return path.replace("batch_state_index.1", "running_var")
177
+ if "batch_counter" in path:
178
+ return path.replace("batch_counter", "num_batches_tracked")
179
+
180
+ return path
181
+
182
+
144
183
  def move_running_fields_to_the_end(
145
184
  torchfields: list[TorchField], identifier: str = "running_"
146
185
  ):
@@ -162,14 +201,16 @@ def move_running_fields_to_the_end(
162
201
 
163
202
 
164
203
  def state_dict_to_fields(
165
- state_dict: dict[str, Any],
204
+ state_dict: dict[str, Any], sort_by_path: bool = True
166
205
  ) -> list[TorchField]:
167
206
  if state_dict is None:
168
207
  return []
169
208
  fields: list[TorchField] = []
170
209
  for key, value in state_dict.items():
171
- if hasattr(value, "shape") and len(value.shape) > 0:
210
+ if hasattr(value, "shape"):
172
211
  fields.append(TorchField(path=key, shape=tuple(value.shape)))
212
+
213
+ fields.sort(key=lambda x: x.path)
173
214
  return fields
174
215
 
175
216
 
@@ -177,17 +218,29 @@ def pytree_to_fields(
177
218
  pytree: PyTree,
178
219
  model_order: list[str] | None = None,
179
220
  filter: Callable[[Array], bool] = eqx.is_array,
221
+ sort_by_path: bool = True,
180
222
  ) -> tuple[list[JaxField], dict | None]:
181
223
  jaxfields = []
182
224
  paths = jax.tree.leaves_with_path(pytree)
183
- i = {}
225
+ state_indices = {}
184
226
  for p in paths:
185
227
  keys, _ = p
186
- n, i = _get_node(pytree, keys, i)
228
+ n, state_indices = _get_node(pytree, keys, state_indices)
187
229
  if n is not None and filter(n):
188
- jaxfields.append(JaxField(path=keys, shape=n.shape))
230
+ readable_path = _resolve_state_names(pytree, keys)
231
+ jaxfields.append(
232
+ JaxField(
233
+ key_path=keys,
234
+ shape=n.shape,
235
+ path=_normalize_eqx_name(readable_path),
236
+ )
237
+ )
189
238
 
190
239
  if model_order is not None:
240
+ if sort_by_path:
241
+ raise ValueError(
242
+ "model_order is given and sort_by_path is true. Only one of those is allowed"
243
+ )
191
244
  ordered_jaxfields = []
192
245
  path_dict = {jax.tree_util.keystr(field.path): field for field in jaxfields}
193
246
 
@@ -197,7 +250,9 @@ def pytree_to_fields(
197
250
  del path_dict[path_str]
198
251
  ordered_jaxfields.extend(path_dict.values())
199
252
  jaxfields = ordered_jaxfields
200
- return jaxfields, i
253
+ elif sort_by_path:
254
+ jaxfields.sort(key=lambda j: j.path)
255
+ return jaxfields, state_indices
201
256
 
202
257
 
203
258
  def _chunkify_state_dict(
@@ -248,7 +303,9 @@ def convert(
248
303
  if len(torchfields) != len(jaxfields):
249
304
  raise ValueError(
250
305
  f"Length of state_dict ({len(torchfields)}) "
251
- f"!= length of pytree ({len(jaxfields)})"
306
+ f"!= length of pytree ({len(jaxfields)}) "
307
+ "Note: if you are using BatchNorm in Equinox, make sure to "
308
+ 'add BatchNorm(..., mode="batch"), which might fix your issue!'
252
309
  )
253
310
 
254
311
  with tempfile.TemporaryDirectory() as tmpdir:
@@ -262,28 +319,34 @@ def convert(
262
319
  "Note that the order of the fields matters "
263
320
  "and that you can mark arrays as skippable. "
264
321
  f"{t.path=} "
265
- f"{jax.tree_util.keystr(j.path)=}"
322
+ f"{jax.tree_util.keystr(j.key_path)=} ({j.path})"
266
323
  )
267
324
  state_dict_dir = pathlib.Path(chunkified_statedict_path.path) / "state_dict"
268
325
  filename = state_dict_dir / t.path
269
326
  new_value = jnp.array(np.load(str(filename) + ".npy"))
270
327
 
271
- n, _ = _get_node(pytree, j.path, state_indices)
272
- assert n is not None, f"Node {j.path} not found"
328
+ n, _ = _get_node(pytree, j.key_path, state_indices)
329
+ assert n is not None, f"Node {j.key_path} not found"
273
330
  assert _can_reshape(n.shape, new_value.shape), (
274
331
  f"Cannot reshape {n.shape} into {new_value.shape}"
275
332
  )
276
333
 
277
- pytree = _replace_node(pytree, j.path, new_value, state_indices)
334
+ pytree = _replace_node(pytree, j.key_path, new_value, state_indices)
278
335
 
279
336
  return pytree
280
337
 
281
338
 
282
339
  def autoconvert(
283
- pytree: PyTree, state_dict: dict, pytree_model_order: list[str] | None = None
340
+ pytree: PyTree,
341
+ state_dict: dict,
342
+ filter: Callable[[Array], bool] = eqx.is_array,
343
+ pytree_model_order: list[str] | None = None,
344
+ sort_by_path: bool = True,
284
345
  ) -> PyTree:
285
- torchfields = state_dict_to_fields(state_dict)
286
- jaxfields, state_indices = pytree_to_fields(pytree, pytree_model_order)
346
+ torchfields = state_dict_to_fields(state_dict, sort_by_path)
347
+ jaxfields, state_indices = pytree_to_fields(
348
+ pytree, pytree_model_order, filter=filter, sort_by_path=sort_by_path
349
+ )
287
350
 
288
351
  pytree = convert(
289
352
  state_dict,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: statedict2pytree
3
- Version: 1.0.2
3
+ Version: 2.0.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: >=3.11
@@ -0,0 +1,5 @@
1
+ statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
2
+ statedict2pytree/converter.py,sha256=hxUBVJswVQ1FG-6DRKRQWj7aEZuHVQFgk_sSNzPvUK0,11046
3
+ statedict2pytree-2.0.0.dist-info/METADATA,sha256=bg1HBXYFFLFwSlBEZd32p2ViJwfipXzTpQz8YWKnRq0,1926
4
+ statedict2pytree-2.0.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
5
+ statedict2pytree-2.0.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.27.0
2
+ Generator: hatchling 1.28.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,5 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=nOtp-ZeTwqHaIibPrqWWXJYzB5PnZHhaRw0HSr9fLBk,215
2
- statedict2pytree/converter.py,sha256=Fp9UK8oibon8qua-Kbl1n__KouKOQo_9W3TNss-AhHI,8734
3
- statedict2pytree-1.0.2.dist-info/METADATA,sha256=K0lvHcpx_0CEVJWCR1hZe8-uzDjAfFtiGseTrSj4BUo,1926
4
- statedict2pytree-1.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
- statedict2pytree-1.0.2.dist-info/RECORD,,