statedict2pytree 0.5.2__tar.gz → 0.5.3__tar.gz

@@ -1,9 +1,10 @@
  Metadata-Version: 2.3
  Name: statedict2pytree
- Version: 0.5.2
+ Version: 0.5.3
  Summary: Converts torch models into PyTrees for Equinox
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
  Requires-Python: ~=3.10
+ Requires-Dist: anthropic
  Requires-Dist: beartype
  Requires-Dist: equinox
  Requires-Dist: flask
@@ -14,6 +15,7 @@ Requires-Dist: jaxtyping
  Requires-Dist: loguru
  Requires-Dist: penzai
  Requires-Dist: pydantic
+ Requires-Dist: python-dotenv
  Requires-Dist: torch
  Requires-Dist: torchvision
  Requires-Dist: typing-extensions
@@ -37,6 +39,21 @@ PRs and other contributions are *highly* welcome! :)
 
  ## Info
 
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
+
+ ## Features
+
+ - Convert PyTorch statedicts to JAX pytrees
+ - Handle large models with chunked file conversion
+ - Provide an "intuitive-ish" UI for parameter mapping
+ - Support both in-memory and file-based conversions
+
+ ## Installation
+
+ ```bash
+ pip install statedict2pytree
+ ```
+
  The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
 
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
@@ -49,16 +66,12 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
 
  (8, 1, 1) and (8, ) match, because (8 * 1 * 1 = 8)
 
- ## Get Started
-
- ### Installation
-
- Run
-
- ```bash
- pip install statedict2pytree
- ```
 
  ### Docs
 
  Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
+
+ ### Disclaimer
+
+ Some of the docstrings and the docs have been written with the help of
+ Claude.
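
The shape rule quoted in the README text above can be sketched as a comparison of element counts. This is a minimal sketch, assuming the rule is "two shapes match when the products of their dimensions are equal", which is inferred from the (8, 1, 1) versus (8, ) example rather than from the package source. The helper name `shapes_match` is hypothetical and is not part of the statedict2pytree API.

```python
from math import prod


def shapes_match(torch_shape: tuple[int, ...], jax_shape: tuple[int, ...]) -> bool:
    """Return True when both shapes hold the same number of elements.

    Hypothetical helper: it only compares element counts, not layout.
    """
    return prod(torch_shape) == prod(jax_shape)


# The README example: 8 * 1 * 1 == 8, so the shapes are treated as compatible.
readme_example = shapes_match((8, 1, 1), (8,))
```

Under this rule, (2, 4) and (8,) would also count as a match, which is probably what the README means by the matching not being sophisticated.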
@@ -9,6 +9,21 @@ PRs and other contributions are *highly* welcome! :)
 
  ## Info
 
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
+
+ ## Features
+
+ - Convert PyTorch statedicts to JAX pytrees
+ - Handle large models with chunked file conversion
+ - Provide an "intuitive-ish" UI for parameter mapping
+ - Support both in-memory and file-based conversions
+
+ ## Installation
+
+ ```bash
+ pip install statedict2pytree
+ ```
+
  The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
 
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
@@ -21,16 +36,12 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
 
  (8, 1, 1) and (8, ) match, because (8 * 1 * 1 = 8)
 
- ## Get Started
-
- ### Installation
-
- Run
-
- ```bash
- pip install statedict2pytree
- ```
 
  ### Docs
 
  Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
+
+ ### Disclaimer
+
+ Some of the docstrings and the docs have been written with the help of
+ Claude.
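
The README text in this diff describes the conversion as walking two ordered parameter lists in lockstep. The following is a rough sketch of that idea in plain Python, not the package's implementation. It assumes parameters can be represented as `(name, shape)` pairs; the function `pair_in_order` and the example parameter names are hypothetical.

```python
from math import prod


def pair_in_order(torch_params, jax_params):
    """Pair two ordered lists of (name, shape) entries position by position.

    Hypothetical illustration: raises ValueError when the lists differ in
    length or when a pair does not hold the same number of elements.
    """
    if len(torch_params) != len(jax_params):
        raise ValueError("both models must expose the same number of parameters")
    pairs = []
    for (t_name, t_shape), (j_name, j_shape) in zip(torch_params, jax_params):
        # Element counts must agree, mirroring the (8, 1, 1) vs (8, ) example.
        if prod(t_shape) != prod(j_shape):
            raise ValueError(f"{t_name} {t_shape} does not fit {j_name} {j_shape}")
        pairs.append((t_name, j_name))
    return pairs


# Made-up parameter lists, declared in the same order on both sides.
torch_params = [("conv.weight", (8, 1, 1)), ("conv.bias", (8,))]
jax_params = [(".conv.weight", (8,)), (".conv.bias", (8,))]
pairs = pair_in_order(torch_params, jax_params)
```

Because the pairing is purely positional, fields declared in a different order than in the PyTorch model would need to be rearranged before a step like this.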