statedict2pytree 0.5.2__tar.gz → 0.5.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,9 +1,10 @@
  Metadata-Version: 2.3
  Name: statedict2pytree
- Version: 0.5.2
+ Version: 0.5.4
  Summary: Converts torch models into PyTrees for Equinox
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
  Requires-Python: ~=3.10
+ Requires-Dist: anthropic
  Requires-Dist: beartype
  Requires-Dist: equinox
  Requires-Dist: flask
@@ -14,6 +15,8 @@ Requires-Dist: jaxtyping
  Requires-Dist: loguru
  Requires-Dist: penzai
  Requires-Dist: pydantic
+ Requires-Dist: pytest
+ Requires-Dist: python-dotenv
  Requires-Dist: torch
  Requires-Dist: torchvision
  Requires-Dist: typing-extensions
@@ -30,6 +33,11 @@ Description-Content-Type: text/markdown
 
  ![statedict2pytree](statedict2pytree.png "A ResNet demo")
 
+ ## Docs
+
+ Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
+
+
  ## Important
 
  This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
@@ -37,6 +45,21 @@ PRs and other contributions are *highly* welcome! :)
 
  ## Info
 
+ `statedict2pytree` is a tool for converting PyTorch state dictionaries into JAX pytrees. It provides both programmatic and UI-based ways to map PyTorch parameters onto JAX model parameters.
+
+ ## Features
+
+ - Convert PyTorch statedicts to JAX pytrees
+ - Handle large models with chunked file conversion
+ - Provide an "intuitive-ish" UI for parameter mapping
+ - Support both in-memory and file-based conversions
+
+ ## Installation
+
+ ```bash
+ pip install statedict2pytree
+ ```
+
  The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
 
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
@@ -49,16 +72,9 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
 
  (8, 1, 1) and (8,) match, because `8 * 1 * 1 = 8`.
 
- ## Get Started
-
- ### Installation
 
- Run
-
- ```bash
- pip install statedict2pytree
- ```
 
- ### Docs
+ ### Disclaimer
 
- Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
+ Some of the docstrings and the docs have been written with the help of
+ Claude.
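The README's example, where (8, 1, 1) and (8,) match because 8 * 1 * 1 = 8, suggests that two shapes are compared by their total number of elements. A minimal sketch of such a check follows; `shapes_match` is a hypothetical helper written for illustration, not part of the statedict2pytree API.

```python
import math


def shapes_match(torch_shape: tuple[int, ...], jax_shape: tuple[int, ...]) -> bool:
    """Treat two shapes as compatible when they hold the same number of elements."""
    # math.prod multiplies all dimensions together; math.prod(()) is 1 for scalars.
    return math.prod(torch_shape) == math.prod(jax_shape)


print(shapes_match((8, 1, 1), (8,)))  # True: 8 * 1 * 1 == 8
print(shapes_match((8, 3), (8,)))     # False: 24 != 8
```

A pure element-count test is permissive: (2, 3) and (3, 2) also pass, so it cannot distinguish a transposed matrix from a correctly shaped one. That fits the hunk header's remark that there is currently no sophisticated shape matching in place.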
@@ -2,6 +2,11 @@
 
  ![statedict2pytree](statedict2pytree.png "A ResNet demo")
 
+ ## Docs
+
+ Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
+
+
  ## Important
 
  This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
@@ -9,6 +14,21 @@ PRs and other contributions are *highly* welcome! :)
 
  ## Info
 
+ `statedict2pytree` is a tool for converting PyTorch state dictionaries into JAX pytrees. It provides both programmatic and UI-based ways to map PyTorch parameters onto JAX model parameters.
+
+ ## Features
+
+ - Convert PyTorch statedicts to JAX pytrees
+ - Handle large models with chunked file conversion
+ - Provide an "intuitive-ish" UI for parameter mapping
+ - Support both in-memory and file-based conversions
+
+ ## Installation
+
+ ```bash
+ pip install statedict2pytree
+ ```
+
  The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
 
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
@@ -21,16 +41,9 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
 
  (8, 1, 1) and (8,) match, because `8 * 1 * 1 = 8`.
 
- ## Get Started
-
- ### Installation
 
- Run
-
- ```bash
- pip install statedict2pytree
- ```
 
- ### Docs
+ ### Disclaimer
 
- Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
+ Some of the docstrings and the docs have been written with the help of
+ Claude.
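The README describes the conversion as putting both models side by side, aligning their weights in order, and then matching the two lists pairwise. A minimal sketch of that positional pairing is below, assuming each parameter is represented only by a path and a shape. `Param` and `match_in_order` are hypothetical names, not statedict2pytree's API, and the element-count check is an assumption based on the README's (8, 1, 1) / (8,) example.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class Param:
    # Hypothetical record for one parameter: its path in the model and its shape.
    path: str
    shape: tuple[int, ...]


def match_in_order(torch_params: list[Param], jax_params: list[Param]) -> list[tuple[str, str]]:
    """Pair the i-th PyTorch parameter with the i-th JAX leaf.

    Assumes both lists are already aligned; raises ValueError if the lengths
    differ or if a pair holds a different number of elements.
    """
    if len(torch_params) != len(jax_params):
        raise ValueError(
            f"length mismatch: {len(torch_params)} PyTorch vs {len(jax_params)} JAX parameters"
        )
    pairs = []
    for t, j in zip(torch_params, jax_params):
        # Compare total element counts, so (8,) can map onto (8, 1, 1).
        if math.prod(t.shape) != math.prod(j.shape):
            raise ValueError(f"{t.path} {t.shape} cannot map to {j.path} {j.shape}")
        pairs.append((t.path, j.path))
    return pairs


torch_params = [Param("conv.weight", (8, 3, 3, 3)), Param("conv.bias", (8,))]
jax_params = [Param("conv.weight", (8, 3, 3, 3)), Param("conv.bias", (8, 1, 1))]
print(match_in_order(torch_params, jax_params))
# [('conv.weight', 'conv.weight'), ('conv.bias', 'conv.bias')]
```

Because the pairing is purely positional, a misalignment is fixed by reordering one of the lists, which corresponds to the README's option to rearrange fields when they were not declared in the same order as in the PyTorch model.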