statedict2pytree 0.2.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (23) hide show
  1. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/PKG-INFO +22 -1
  2. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/README.md +21 -0
  3. statedict2pytree-0.4.0/examples/convert_resnet.py +20 -0
  4. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/examples/resnet.py +43 -12
  5. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/examples/test_resnet_inference.py +6 -6
  6. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/pyproject.toml +2 -2
  7. statedict2pytree-0.4.0/pyrightconfig.json +4 -0
  8. statedict2pytree-0.4.0/statedict2pytree/__init__.py +7 -0
  9. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/statedict2pytree.py +21 -4
  10. statedict2pytree-0.4.0/tests/test_conv.py +61 -0
  11. statedict2pytree-0.4.0/tests/test_linear.py +45 -0
  12. statedict2pytree-0.2.0/examples/convert_resnet.py +0 -16
  13. statedict2pytree-0.2.0/statedict2pytree/__init__.py +0 -4
  14. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/.gitignore +0 -0
  15. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/.pre-commit-config.yaml +0 -0
  16. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/examples/doggo.jpeg +0 -0
  17. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/package-lock.json +0 -0
  18. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/package.json +0 -0
  19. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/static/input.css +0 -0
  20. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/static/output.css +0 -0
  21. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/statedict2pytree/templates/index.html +0 -0
  22. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/tailwind.config.js +0 -0
  23. {statedict2pytree-0.2.0 → statedict2pytree-0.4.0}/torch2jax.png +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.2.0
3
+ Version: 0.4.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
@@ -32,6 +32,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
32
32
 
33
33
  (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
34
34
 
35
+ ## Shape Matching? What's that?
36
+
37
+ Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shape dimensions are equal, i.e. if they contain the same number of elements. For example:
38
+
39
+ 1. (8, 1, 1) and (8,) match, because `8 * 1 * 1 = 8`
40
+
35
41
  ## Get Started
36
42
 
37
43
  ### Installation
@@ -124,3 +130,18 @@ def convert(
124
130
  ```
125
131
 
126
132
  If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length and each matching entry must have the same shape!
133
+
134
+ For the full, automatic experience, use `autoconvert`:
135
+
136
+ ```python
137
+ import statedict2pytree as s2p
138
+
139
+ my_model = Model(...)
140
+ state_dict = ...
141
+
142
+ model, state = s2p.autoconvert(my_model, state_dict)
143
+
144
+ ```
145
+
146
+ This will however only work if your PyTree fields have been declared
147
+ in the same order as they appear in the state dict!
@@ -8,6 +8,12 @@ Usually, if you _declared the fields in the same order as in the PyTorch model_,
8
8
 
9
9
  (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
10
10
 
11
+ ## Shape Matching? What's that?
12
+
13
+ Currently, there is no sophisticated shape matching in place. Two arrays are considered "matching" if the products of their shape dimensions are equal, i.e. if they contain the same number of elements. For example:
14
+
15
+ 1. (8, 1, 1) and (8,) match, because `8 * 1 * 1 = 8`
16
+
11
17
  ## Get Started
12
18
 
13
19
  ### Installation
@@ -100,3 +106,18 @@ def convert(
100
106
  ```
101
107
 
102
108
  If your models already have the right "order", then you might as well use this function directly. Note that the lists `jax_fields` and `torch_fields` must have the same length and each matching entry must have the same shape!
109
+
110
+ For the full, automatic experience, use `autoconvert`:
111
+
112
+ ```python
113
+ import statedict2pytree as s2p
114
+
115
+ my_model = Model(...)
116
+ state_dict = ...
117
+
118
+ model, state = s2p.autoconvert(my_model, state_dict)
119
+
120
+ ```
121
+
122
+ This will however only work if your PyTree fields have been declared
123
+ in the same order as they appear in the state dict!
@@ -0,0 +1,20 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import statedict2pytree as s2p
4
+ from resnet import resnet152
5
+ from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
6
+
7
+
8
+ def convert_resnet():
9
+ resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
10
+ resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
11
+ state_dict = resnet_torch.state_dict()
12
+
13
+ # s2p.start_conversion(resnet_jax, state_dict)
14
+ model, state = s2p.autoconvert(resnet_jax, state_dict)
15
+ name = "resnet152.eqx"
16
+ eqx.tree_serialise_leaves(name, (model, state))
17
+
18
+
19
+ if __name__ == "__main__":
20
+ convert_resnet()
@@ -333,12 +333,22 @@ def resnet18(
333
333
 
334
334
 
335
335
  def resnet34(
336
- image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
336
+ image_channels: int = 3,
337
+ num_classes: int = 1000,
338
+ *,
339
+ key: PRNGKeyArray,
340
+ make_with_state: bool = True,
341
+ **kwargs,
337
342
  ):
338
343
  layers = [3, 4, 6, 3]
339
- return eqx.nn.make_with_state(ResNet)(
340
- BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
341
- )
344
+ if make_with_state:
345
+ return eqx.nn.make_with_state(ResNet)(
346
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
347
+ )
348
+ else:
349
+ return ResNet(
350
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
351
+ )
342
352
 
343
353
 
344
354
  def resnet50(
@@ -361,18 +371,39 @@ def resnet50(
361
371
 
362
372
 
363
373
  def resnet101(
364
- image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
374
+ image_channels: int = 3,
375
+ num_classes: int = 1000,
376
+ *,
377
+ key: PRNGKeyArray,
378
+ make_with_state: bool = True,
379
+ **kwargs,
365
380
  ):
366
381
  layers = [3, 4, 23, 3]
367
- return eqx.nn.make_with_state(ResNet)(
368
- Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
369
- )
382
+
383
+ if make_with_state:
384
+ return eqx.nn.make_with_state(ResNet)(
385
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
386
+ )
387
+ else:
388
+ return ResNet(
389
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
390
+ )
370
391
 
371
392
 
372
393
  def resnet152(
373
- image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
394
+ image_channels: int = 3,
395
+ num_classes: int = 1000,
396
+ *,
397
+ key: PRNGKeyArray,
398
+ make_with_state: bool = True,
399
+ **kwargs,
374
400
  ):
375
401
  layers = [3, 8, 36, 3]
376
- return eqx.nn.make_with_state(ResNet)(
377
- Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
378
- )
402
+ if make_with_state:
403
+ return eqx.nn.make_with_state(ResNet)(
404
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
405
+ )
406
+ else:
407
+ return ResNet(
408
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
409
+ )
@@ -6,15 +6,15 @@ import equinox as eqx
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import torch
9
+ from examples.resnet import resnet152
9
10
  from PIL import Image
10
- from tests.resnet import resnet50
11
11
  from torchvision import transforms
12
- from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
12
+ from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
13
13
 
14
14
 
15
15
  def test_resnet():
16
- resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
17
- resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
16
+ resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
17
+ resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
18
18
 
19
19
  img_name = "doggo.jpeg"
20
20
 
@@ -42,7 +42,7 @@ def test_resnet():
42
42
  ) # Outputs the ImageNet class index of the prediction
43
43
 
44
44
  url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
45
- with urllib.request.urlopen(url) as url:
45
+ with urllib.request.urlopen(url) as url: # pyright: ignore
46
46
  imagenet_labels = json.loads(url.read().decode())
47
47
 
48
48
  label = imagenet_labels[str(predicted.item())][1]
@@ -52,7 +52,7 @@ def test_resnet():
52
52
  model_callable = ft.partial(identity, resnet_jax)
53
53
  model, state = eqx.nn.make_with_state(model_callable)()
54
54
 
55
- model, state = eqx.tree_deserialise_leaves("model.eqx", (model, state))
55
+ model, state = eqx.tree_deserialise_leaves("resnet152.eqx", (model, state))
56
56
 
57
57
  jax_batch = jnp.array(batch_t.numpy())
58
58
  out, state = eqx.filter_vmap(
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "statedict2pytree"
3
- version = "0.2.0"
3
+ version = "0.4.0"
4
4
  description = "Converts torch models into PyTrees for Equinox"
5
5
  readme = "README.md"
6
6
  requires-python = "~=3.10"
@@ -16,7 +16,7 @@ dependencies = [
16
16
  "torch",
17
17
  "flask",
18
18
  "pydantic",
19
- "penzai"
19
+ "penzai",
20
20
  ]
21
21
  [project.optional-dependencies]
22
22
  dev = ["nox", "pre-commit", "pytest", "mkdocs"]
@@ -0,0 +1,4 @@
1
+ {
2
+ "venvPath": ".",
3
+ "venv": ".venv"
4
+ }
@@ -0,0 +1,7 @@
1
+ from statedict2pytree.statedict2pytree import (
2
+ autoconvert as autoconvert,
3
+ convert as convert,
4
+ pytree_to_fields as pytree_to_fields,
5
+ start_conversion as start_conversion,
6
+ state_dict_to_fields as state_dict_to_fields,
7
+ )
@@ -4,6 +4,7 @@ import re
4
4
  import equinox as eqx
5
5
  import flask
6
6
  import jax
7
+ import numpy as np
7
8
  from beartype.typing import Optional
8
9
  from jaxtyping import PyTree
9
10
  from loguru import logger
@@ -31,6 +32,13 @@ PYTREE: Optional[PyTree] = None
31
32
  STATE_DICT: Optional[dict] = None
32
33
 
33
34
 
35
+ def can_reshape(shape1, shape2):
36
+ product1 = np.prod(shape1)
37
+ product2 = np.prod(shape2)
38
+
39
+ return product1 == product2
40
+
41
+
34
42
  def get_node(
35
43
  tree: PyTree, targets: list[str], log_when_not_found: bool = False
36
44
  ) -> PyTree | None:
@@ -38,7 +46,7 @@ def get_node(
38
46
  return tree
39
47
  else:
40
48
  next_target: str = targets[0]
41
- if bool(re.search(r"\[\d\]", next_target)):
49
+ if bool(re.search(r"\[\d+\]", next_target)):
42
50
  split_index = next_target.rfind("[")
43
51
  name, index = next_target[:split_index], next_target[split_index:]
44
52
  index = index[1:-1]
@@ -146,17 +154,26 @@ def index():
146
154
  )
147
155
 
148
156
 
157
+ def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
158
+ jax_fields = pytree_to_fields(pytree)
159
+ torch_fields = state_dict_to_fields(state_dict)
160
+
161
+ for k, v in state_dict.items():
162
+ state_dict[k] = v.numpy()
163
+ return convert(jax_fields, torch_fields, pytree, state_dict)
164
+
165
+
149
166
  def convert(
150
167
  jax_fields: list[JaxField],
151
168
  torch_fields: list[TorchField],
152
169
  pytree: PyTree,
153
170
  state_dict: dict,
154
- ):
171
+ ) -> tuple[PyTree, eqx.nn.State]:
155
172
  identity = lambda *args, **kwargs: pytree
156
173
  model, state = eqx.nn.make_with_state(identity)()
157
174
  state_paths: list[tuple[JaxField, TorchField]] = []
158
175
  for jax_field, torch_field in zip(jax_fields, torch_fields):
159
- if jax_field.shape != torch_field.shape:
176
+ if not can_reshape(jax_field.shape, torch_field.shape):
160
177
  raise ValueError(
161
178
  "Fields have incompatible shapes!"
162
179
  f"{jax_field.shape=} != {torch_field.shape=}"
@@ -171,7 +188,7 @@ def convert(
171
188
  model = eqx.tree_at(
172
189
  where,
173
190
  model,
174
- state_dict[torch_field.path],
191
+ state_dict[torch_field.path].reshape(jax_field.shape),
175
192
  )
176
193
  result: dict[str, list[TorchField]] = {}
177
194
  for tuple_item in state_paths:
@@ -0,0 +1,61 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import numpy as np
4
+ import statedict2pytree as s2p
5
+ import torch
6
+
7
+
8
+ def test_conv():
9
+ in_channels = 8
10
+ out_channels = 8
11
+ kernel_size = 4
12
+ stride = 2
13
+ padding = 1
14
+
15
+ class J(eqx.Module):
16
+ conv: eqx.nn.Conv2d
17
+
18
+ def __init__(self):
19
+ self.conv = eqx.nn.Conv2d(
20
+ in_channels=in_channels,
21
+ out_channels=out_channels,
22
+ kernel_size=kernel_size,
23
+ stride=stride,
24
+ padding=padding,
25
+ key=jax.random.PRNGKey(22),
26
+ )
27
+
28
+ class T(torch.nn.Module):
29
+ def __init__(self) -> None:
30
+ super(T, self).__init__()
31
+ self.conv = torch.nn.Conv2d(
32
+ in_channels=in_channels,
33
+ out_channels=out_channels,
34
+ kernel_size=kernel_size,
35
+ stride=stride,
36
+ padding=padding,
37
+ )
38
+
39
+ jax_model = J()
40
+ torch_model = T()
41
+ state_dict = torch_model.state_dict()
42
+
43
+ jax_fields = s2p.pytree_to_fields(jax_model)
44
+ torch_fields = s2p.state_dict_to_fields(state_dict)
45
+
46
+ model, state = s2p.convert(
47
+ jax_fields, torch_fields, pytree=jax_model, state_dict=state_dict
48
+ )
49
+
50
+ assert np.allclose(
51
+ np.array(model.conv.weight), torch_model.conv.weight.detach().numpy()
52
+ )
53
+ if torch_model.conv.bias is not None:
54
+ assert np.allclose(
55
+ np.array(model.conv.bias),
56
+ torch_model.conv.bias.detach().numpy().reshape(model.conv.bias.shape),
57
+ )
58
+
59
+
60
+ if __name__ == "__main__":
61
+ test_conv()
@@ -0,0 +1,45 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import numpy as np
4
+ import statedict2pytree as s2p
5
+ import torch
6
+
7
+
8
+ def test_linear():
9
+ in_features = 10
10
+ out_features = 10
11
+
12
+ class J(eqx.Module):
13
+ linear: eqx.nn.Linear
14
+
15
+ def __init__(self):
16
+ self.linear = eqx.nn.Linear(
17
+ in_features, out_features, key=jax.random.PRNGKey(30)
18
+ )
19
+
20
+ class T(torch.nn.Module):
21
+ def __init__(self) -> None:
22
+ super(T, self).__init__()
23
+ self.linear = torch.nn.Linear(in_features, out_features)
24
+
25
+ jax_model = J()
26
+ torch_model = T()
27
+ state_dict = torch_model.state_dict()
28
+
29
+ jax_fields = s2p.pytree_to_fields(jax_model)
30
+ torch_fields = s2p.state_dict_to_fields(state_dict)
31
+
32
+ model, state = s2p.convert(
33
+ jax_fields, torch_fields, pytree=jax_model, state_dict=state_dict
34
+ )
35
+
36
+ assert np.allclose(
37
+ np.array(model.linear.weight), torch_model.linear.weight.detach().numpy()
38
+ )
39
+ assert np.allclose(
40
+ np.array(model.linear.bias), torch_model.linear.bias.detach().numpy()
41
+ )
42
+
43
+
44
+ if __name__ == "__main__":
45
+ test_linear()
@@ -1,16 +0,0 @@
1
- import jax
2
- import statedict2pytree as s2p
3
- from resnet import resnet50
4
- from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
5
-
6
-
7
- def convert_resnet():
8
- resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
9
- resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
10
- state_dict = resnet_torch.state_dict()
11
-
12
- s2p.start_conversion(resnet_jax, state_dict)
13
-
14
-
15
- if __name__ == "__main__":
16
- convert_resnet()
@@ -1,4 +0,0 @@
1
- from statedict2pytree.statedict2pytree import (
2
- convert as convert,
3
- start_conversion as start_conversion,
4
- )