statedict2pytree 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,7 @@
1
1
  from statedict2pytree.statedict2pytree import (
2
+ autoconvert as autoconvert,
2
3
  convert as convert,
4
+ pytree_to_fields as pytree_to_fields,
3
5
  start_conversion as start_conversion,
6
+ state_dict_to_fields as state_dict_to_fields,
4
7
  )
@@ -4,6 +4,7 @@ import re
4
4
  import equinox as eqx
5
5
  import flask
6
6
  import jax
7
+ import numpy as np
7
8
  from beartype.typing import Optional
8
9
  from jaxtyping import PyTree
9
10
  from loguru import logger
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
31
32
  STATE_DICT: Optional[dict] = None
32
33
 
33
34
 
35
def can_reshape(shape1, shape2):
    """Report whether two shapes describe the same number of elements.

    Only the product of the dimensions is compared, so e.g. ``(8, 1, 1)``
    and ``(8,)`` are treated as compatible. The actual dimension layout is
    not checked.
    """
    size_a, size_b = (np.prod(shape) for shape in (shape1, shape2))
    return size_a == size_b
40
+
41
+
34
42
  def get_node(
35
43
  tree: PyTree, targets: list[str], log_when_not_found: bool = False
36
44
  ) -> PyTree | None:
@@ -38,7 +46,7 @@ def get_node(
38
46
  return tree
39
47
  else:
40
48
  next_target: str = targets[0]
41
- if bool(re.search(r"\[\d\]", next_target)):
49
+ if bool(re.search(r"\[\d+\]", next_target)):
42
50
  split_index = next_target.rfind("[")
43
51
  name, index = next_target[:split_index], next_target[split_index:]
44
52
  index = index[1:-1]
@@ -146,17 +154,26 @@ def index():
146
154
  )
147
155
 
148
156
 
157
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Fill ``pytree`` from a PyTorch ``state_dict`` by positional field matching.

    Fields are extracted from both the PyTree and the state dict and then
    paired in order by ``convert``. This therefore only works if the PyTree's
    fields were declared in the same order as the entries of ``state_dict``.

    Args:
        pytree: The Equinox model (PyTree) whose leaves should be replaced.
        state_dict: Mapping of parameter path -> tensor, presumably the result
            of ``torch.nn.Module.state_dict()``. It is not modified.

    Returns:
        The ``(model, state)`` pair produced by ``convert``.
    """
    jax_fields = pytree_to_fields(pytree)
    torch_fields = state_dict_to_fields(state_dict)

    # Build a converted copy rather than overwriting the caller's dict in
    # place, so the user's torch state dict stays usable after the call.
    numpy_state_dict = {
        key: _tensor_to_numpy(value) for key, value in state_dict.items()
    }
    return convert(jax_fields, torch_fields, pytree, numpy_state_dict)


def _tensor_to_numpy(value):
    """Return ``value`` as a NumPy array, detaching and moving tensors to CPU."""
    if isinstance(value, np.ndarray):
        return value
    # ``Tensor.numpy()`` refuses tensors that require grad or live on a
    # non-CPU device, so normalise those first when the methods exist.
    if hasattr(value, "detach"):
        value = value.detach()
    if hasattr(value, "cpu"):
        value = value.cpu()
    return value.numpy()
164
+
165
+
149
166
  def convert(
150
167
  jax_fields: list[JaxField],
151
168
  torch_fields: list[TorchField],
152
169
  pytree: PyTree,
153
170
  state_dict: dict,
154
- ):
171
+ ) -> tuple[PyTree, eqx.nn.State]:
155
172
  identity = lambda *args, **kwargs: pytree
156
173
  model, state = eqx.nn.make_with_state(identity)()
157
174
  state_paths: list[tuple[JaxField, TorchField]] = []
158
175
  for jax_field, torch_field in zip(jax_fields, torch_fields):
159
- if jax_field.shape != torch_field.shape:
176
+ if not can_reshape(jax_field.shape, torch_field.shape):
160
177
  raise ValueError(
161
178
  "Fields have incompatible shapes!"
162
179
  f"{jax_field.shape=} != {torch_field.shape=}"
@@ -171,7 +188,7 @@ def convert(
171
188
  model = eqx.tree_at(
172
189
  where,
173
190
  model,
174
- state_dict[torch_field.path],
191
+ state_dict[torch_field.path].reshape(jax_field.shape),
175
192
  )
176
193
  result: dict[str, list[TorchField]] = {}
177
194
  for tuple_item in state_paths:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.2.0
3
+ Version: 0.4.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
32
32
 
33
33
  (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
34
 
35
+ ## Shape Matching? What's that?
36
+
37
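+ Currently, there is no sophisticated shape matching in place: two arrays are considered "matching" if the products of their shapes are equal (i.e. they contain the same number of elements). The PyTorch array is then simply reshaped to the JAX shape; it is never transposed, so make sure the dimensions really correspond. For example: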
38
+
39
+ 1. `(8, 1, 1)` and `(8,)` match, because 8 × 1 × 1 = 8.
40
+
35
41
  ## Get Started
36
42
 
37
43
  ### Installation
@@ -124,3 +130,18 @@ def convert(
124
130
  ```
125
131
 
126
132
  If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each pair of matching entries must contain the same number of elements (see "Shape Matching" above)!
133
+
134
+ For the full, automatic experience, use `autoconvert`:
135
+
136
+ ```python
137
+ import statedict2pytree as s2p
138
+
139
+ my_model = Model(...)
140
+ state_dict = ...
141
+
142
+ model, state = s2p.autoconvert(my_model, state_dict)
143
+
144
+ ```
145
+
146
+ This will however only work if your PyTree fields have been declared
147
+ in the same order as they appear in the state dict!
@@ -0,0 +1,8 @@
1
+ statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
2
+ statedict2pytree/statedict2pytree.py,sha256=u1PddUBY_MHErl3tstSpJt2a6a_H-RizvxP9anPoFpQ,7175
3
+ statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
+ statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
+ statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
+ statedict2pytree-0.4.0.dist-info/METADATA,sha256=BbJaUhVv-Qb4-e2DlX0WBNgPOKsTjlKv8ISJZgFNx1o,4232
7
+ statedict2pytree-0.4.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
+ statedict2pytree-0.4.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
2
- statedict2pytree/statedict2pytree.py,sha256=SQ1Xs4VPG_gnjYdAOBzdoNih1RGVsW9x08UQGkIFUdg,6640
3
- statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
- statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
- statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
- statedict2pytree-0.2.0.dist-info/METADATA,sha256=x9GhcbG0io5HIpfwocS-nsPYyKRTmG90uLKyckSm394,3673
7
- statedict2pytree-0.2.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
- statedict2pytree-0.2.0.dist-info/RECORD,,