statedict2pytree 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,7 @@
1
1
  from statedict2pytree.statedict2pytree import (
2
+ autoconvert as autoconvert,
2
3
  convert as convert,
4
+ pytree_to_fields as pytree_to_fields,
3
5
  start_conversion as start_conversion,
6
+ state_dict_to_fields as state_dict_to_fields,
4
7
  )
@@ -4,6 +4,7 @@ import re
4
4
  import equinox as eqx
5
5
  import flask
6
6
  import jax
7
+ import numpy as np
7
8
  from beartype.typing import Optional
8
9
  from jaxtyping import PyTree
9
10
  from loguru import logger
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
31
32
  STATE_DICT: Optional[dict] = None
32
33
 
33
34
 
35
def can_reshape(shape1, shape2):
    """Return whether an array of ``shape1`` can be reshaped into ``shape2``.

    Two shapes are considered compatible when they describe the same total
    number of elements, e.g. ``(8, 1, 1)`` and ``(8,)``. This is purely an
    element-count check: ``(3, 4)`` and ``(4, 3)`` also "match", even though
    reshaping one into the other does not transpose the underlying data.

    Args:
        shape1: Iterable of integer dimensions (e.g. a JAX/NumPy shape tuple
            or a ``torch.Size``).
        shape2: Iterable of integer dimensions.

    Returns:
        A plain Python ``bool``: ``True`` if both shapes have the same number
        of elements.
    """
    # Use exact Python-int arithmetic: np.prod works in a fixed-width integer
    # dtype and can silently overflow for very large tensors. np.prod also
    # returns 1.0 (a float) for the empty shape ``()`` and an ``np.bool_`` from
    # the comparison. Coercing every dimension with int() also covers NumPy
    # integer scalars, which would otherwise keep the product in fixed width.
    import math

    return math.prod(int(dim) for dim in shape1) == math.prod(
        int(dim) for dim in shape2
    )
40
+
41
+
34
42
  def get_node(
35
43
  tree: PyTree, targets: list[str], log_when_not_found: bool = False
36
44
  ) -> PyTree | None:
@@ -146,17 +154,23 @@ def index():
146
154
  )
147
155
 
148
156
 
157
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
    """Convert ``state_dict`` into ``pytree`` without manual field matching.

    Extracts the field lists from both sides with ``pytree_to_fields`` and
    ``state_dict_to_fields``, then delegates to ``convert``. The fields are
    paired purely by position, so this only produces a correct result when
    the PyTree fields are declared in the same order as the entries appear
    in the state dict.

    Args:
        pytree: The target Equinox model / PyTree.
        state_dict: The source state dict to read values from.

    Returns:
        Whatever ``convert`` returns: the populated model and its
        ``eqx.nn.State``.
    """
    # Arguments are evaluated left to right, so the PyTree fields are
    # extracted before the state-dict fields, as in the original ordering.
    return convert(
        pytree_to_fields(pytree),
        state_dict_to_fields(state_dict),
        pytree,
        state_dict,
    )
161
+
162
+
149
163
  def convert(
150
164
  jax_fields: list[JaxField],
151
165
  torch_fields: list[TorchField],
152
166
  pytree: PyTree,
153
167
  state_dict: dict,
154
- ):
168
+ ) -> tuple[PyTree, eqx.nn.State]:
155
169
  identity = lambda *args, **kwargs: pytree
156
170
  model, state = eqx.nn.make_with_state(identity)()
157
171
  state_paths: list[tuple[JaxField, TorchField]] = []
158
172
  for jax_field, torch_field in zip(jax_fields, torch_fields):
159
- if jax_field.shape != torch_field.shape:
173
+ if not can_reshape(jax_field.shape, torch_field.shape):
160
174
  raise ValueError(
161
175
  "Fields have incompatible shapes!"
162
176
  f"{jax_field.shape=} != {torch_field.shape=}"
@@ -171,7 +185,7 @@ def convert(
171
185
  model = eqx.tree_at(
172
186
  where,
173
187
  model,
174
- state_dict[torch_field.path],
188
+ state_dict[torch_field.path].reshape(jax_field.shape),
175
189
  )
176
190
  result: dict[str, list[TorchField]] = {}
177
191
  for tuple_item in state_paths:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.2.0
3
+ Version: 0.3.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
32
32
 
33
33
  (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
34
 
35
+ ## Shape Matching? What's that?
36
+
37
+ Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shapes are equal, i.e. if they contain the same number of elements. For example:
38
+
39
+ 1. `(8, 1, 1)` and `(8,)` match, because `8 * 1 * 1 = 8`.
40
+
35
41
  ## Get Started
36
42
 
37
43
  ### Installation
@@ -124,3 +130,18 @@ def convert(
124
130
  ```
125
131
 
126
132
  If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length, and each matching pair of entries must have compatible shapes (i.e. the same total number of elements)!
133
+
134
+ For the full, automatic experience, use `autoconvert`:
135
+
136
+ ```python
137
+ import statedict2pytree as s2p
138
+
139
+ my_model = Model(...)
140
+ state_dict = ...
141
+
142
+ model, state = s2p.autoconvert(my_model, state_dict)
143
+
144
+ ```
145
+
146
+ This will, however, only work if your PyTree fields have been declared
147
+ in the same order as they appear in the state dict!
@@ -0,0 +1,8 @@
1
+ statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
2
+ statedict2pytree/statedict2pytree.py,sha256=X5Ljf4lYhhH7_V4KgdciChncbTt7YZpIWHcOxcZ3l48,7103
3
+ statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
+ statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
+ statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
+ statedict2pytree-0.3.0.dist-info/METADATA,sha256=YSK4tWzNQemyZ1xKq5BhWiLWWc-RDr4E9q_eV_iOsdw,4232
7
+ statedict2pytree-0.3.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
+ statedict2pytree-0.3.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=kMuooLMZQ68rfJSJNVEpJORGnSJFY1sv6jgK9Guh4LY,116
2
- statedict2pytree/statedict2pytree.py,sha256=SQ1Xs4VPG_gnjYdAOBzdoNih1RGVsW9x08UQGkIFUdg,6640
3
- statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
- statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
- statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
- statedict2pytree-0.2.0.dist-info/METADATA,sha256=x9GhcbG0io5HIpfwocS-nsPYyKRTmG90uLKyckSm394,3673
7
- statedict2pytree-0.2.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
- statedict2pytree-0.2.0.dist-info/RECORD,,