statedict2pytree 0.3.0__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/PKG-INFO +1 -1
  2. statedict2pytree-0.4.0/examples/convert_resnet.py +20 -0
  3. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/examples/resnet.py +43 -12
  4. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/examples/test_resnet_inference.py +6 -6
  5. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/pyproject.toml +1 -1
  6. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/statedict2pytree.py +4 -1
  7. statedict2pytree-0.3.0/examples/convert_resnet.py +0 -16
  8. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/.gitignore +0 -0
  9. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/.pre-commit-config.yaml +0 -0
  10. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/README.md +0 -0
  11. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/examples/doggo.jpeg +0 -0
  12. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/package-lock.json +0 -0
  13. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/package.json +0 -0
  14. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/pyrightconfig.json +0 -0
  15. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/__init__.py +0 -0
  16. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/static/input.css +0 -0
  17. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/static/output.css +0 -0
  18. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/templates/index.html +0 -0
  19. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/tailwind.config.js +0 -0
  20. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/tests/test_conv.py +0 -0
  21. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/tests/test_linear.py +0 -0
  22. {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/torch2jax.png +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
@@ -0,0 +1,20 @@
1
+ import equinox as eqx
2
+ import jax
3
+ import statedict2pytree as s2p
4
+ from resnet import resnet152
5
+ from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
6
+
7
+
8
def convert_resnet(output_path: str = "resnet152.eqx"):
    """Convert torchvision's pretrained ResNet-152 into an Equinox model file.

    The Equinox ``resnet152`` is built with ``make_with_state=False`` so the
    bare module PyTree is passed to ``s2p.autoconvert``, which returns a
    ``(model, state)`` pair filled with the default torchvision weights. That
    pair is then serialised with ``eqx.tree_serialise_leaves``.

    Args:
        output_path: Destination file for the serialised ``(model, state)``
            leaves. Defaults to ``"resnet152.eqx"``, the filename the
            inference example deserialises from.
    """
    resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
    resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
    state_dict = resnet_torch.state_dict()

    model, state = s2p.autoconvert(resnet_jax, state_dict)
    eqx.tree_serialise_leaves(output_path, (model, state))


if __name__ == "__main__":
    convert_resnet()
@@ -333,12 +333,22 @@ def resnet18(
333
333
 
334
334
 
335
335
  def resnet34(
336
- image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
336
+ image_channels: int = 3,
337
+ num_classes: int = 1000,
338
+ *,
339
+ key: PRNGKeyArray,
340
+ make_with_state: bool = True,
341
+ **kwargs,
337
342
  ):
338
343
  layers = [3, 4, 6, 3]
339
- return eqx.nn.make_with_state(ResNet)(
340
- BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
341
- )
344
+ if make_with_state:
345
+ return eqx.nn.make_with_state(ResNet)(
346
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
347
+ )
348
+ else:
349
+ return ResNet(
350
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
351
+ )
342
352
 
343
353
 
344
354
  def resnet50(
@@ -361,18 +371,39 @@ def resnet50(
361
371
 
362
372
 
363
373
  def resnet101(
364
- image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
374
+ image_channels: int = 3,
375
+ num_classes: int = 1000,
376
+ *,
377
+ key: PRNGKeyArray,
378
+ make_with_state: bool = True,
379
+ **kwargs,
365
380
  ):
366
381
  layers = [3, 4, 23, 3]
367
- return eqx.nn.make_with_state(ResNet)(
368
- Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
369
- )
382
+
383
+ if make_with_state:
384
+ return eqx.nn.make_with_state(ResNet)(
385
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
386
+ )
387
+ else:
388
+ return ResNet(
389
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
390
+ )
370
391
 
371
392
 
372
393
  def resnet152(
373
- image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
394
+ image_channels: int = 3,
395
+ num_classes: int = 1000,
396
+ *,
397
+ key: PRNGKeyArray,
398
+ make_with_state: bool = True,
399
+ **kwargs,
374
400
  ):
375
401
  layers = [3, 8, 36, 3]
376
- return eqx.nn.make_with_state(ResNet)(
377
- Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
378
- )
402
+ if make_with_state:
403
+ return eqx.nn.make_with_state(ResNet)(
404
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
405
+ )
406
+ else:
407
+ return ResNet(
408
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
409
+ )
@@ -6,15 +6,15 @@ import equinox as eqx
6
6
  import jax
7
7
  import jax.numpy as jnp
8
8
  import torch
9
+ from examples.resnet import resnet152
9
10
  from PIL import Image
10
- from tests.resnet import resnet50
11
11
  from torchvision import transforms
12
- from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
12
+ from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
13
13
 
14
14
 
15
15
  def test_resnet():
16
- resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
17
- resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
16
+ resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
17
+ resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
18
18
 
19
19
  img_name = "doggo.jpeg"
20
20
 
@@ -42,7 +42,7 @@ def test_resnet():
42
42
  ) # Outputs the ImageNet class index of the prediction
43
43
 
44
44
  url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
45
- with urllib.request.urlopen(url) as url:
45
+ with urllib.request.urlopen(url) as url: # pyright: ignore
46
46
  imagenet_labels = json.loads(url.read().decode())
47
47
 
48
48
  label = imagenet_labels[str(predicted.item())][1]
@@ -52,7 +52,7 @@ def test_resnet():
52
52
  model_callable = ft.partial(identity, resnet_jax)
53
53
  model, state = eqx.nn.make_with_state(model_callable)()
54
54
 
55
- model, state = eqx.tree_deserialise_leaves("model.eqx", (model, state))
55
+ model, state = eqx.tree_deserialise_leaves("resnet152.eqx", (model, state))
56
56
 
57
57
  jax_batch = jnp.array(batch_t.numpy())
58
58
  out, state = eqx.filter_vmap(
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "statedict2pytree"
3
- version = "0.3.0"
3
+ version = "0.4.0"
4
4
  description = "Converts torch models into PyTrees for Equinox"
5
5
  readme = "README.md"
6
6
  requires-python = "~=3.10"
@@ -46,7 +46,7 @@ def get_node(
46
46
  return tree
47
47
  else:
48
48
  next_target: str = targets[0]
49
- if bool(re.search(r"\[\d\]", next_target)):
49
+ if bool(re.search(r"\[\d+\]", next_target)):
50
50
  split_index = next_target.rfind("[")
51
51
  name, index = next_target[:split_index], next_target[split_index:]
52
52
  index = index[1:-1]
@@ -157,6 +157,9 @@ def index():
157
157
  def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
158
158
  jax_fields = pytree_to_fields(pytree)
159
159
  torch_fields = state_dict_to_fields(state_dict)
160
+
161
+ for k, v in state_dict.items():
162
+ state_dict[k] = v.numpy()
160
163
  return convert(jax_fields, torch_fields, pytree, state_dict)
161
164
 
162
165
 
@@ -1,16 +0,0 @@
1
- import jax
2
- import statedict2pytree as s2p
3
- from resnet import resnet50
4
- from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
5
-
6
-
7
def convert_resnet():
    """Start the statedict2pytree conversion for a pretrained ResNet-50.

    The Equinox model is built with ``make_with_state=False`` and handed, with
    the default torchvision ResNet-50 weights, to ``s2p.start_conversion``.
    """
    jax_model = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
    torch_weights = t_resnet50(weights=ResNet50_Weights.DEFAULT).state_dict()
    s2p.start_conversion(jax_model, torch_weights)


if __name__ == "__main__":
    convert_resnet()