statedict2pytree 0.5.2__tar.gz → 0.5.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/PKG-INFO +27 -11
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/README.md +23 -10
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/public/bundle.js +615 -234
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/public/bundle.js.map +1 -1
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/public/output.css +164 -217
- statedict2pytree-0.5.4/client/src/App.svelte +584 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/pyproject.toml +5 -2
- statedict2pytree-0.5.2/client/src/App.svelte +0 -343
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/.gitignore +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/.gitignore +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/package-lock.json +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/package.json +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/public/index.html +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/public/input.css +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/rollup.config.mjs +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/src/empty.ts +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/src/main.js +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/tailwind.config.js +0 -0
- {statedict2pytree-0.5.2 → statedict2pytree-0.5.4}/client/tsconfig.json +0 -0
--- statedict2pytree-0.5.2/PKG-INFO
+++ statedict2pytree-0.5.4/PKG-INFO
@@ -1,9 +1,10 @@
 Metadata-Version: 2.3
 Name: statedict2pytree
-Version: 0.5.
+Version: 0.5.4
 Summary: Converts torch models into PyTrees for Equinox
 Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
 Requires-Python: ~=3.10
+Requires-Dist: anthropic
 Requires-Dist: beartype
 Requires-Dist: equinox
 Requires-Dist: flask
@@ -14,6 +15,8 @@ Requires-Dist: jaxtyping
 Requires-Dist: loguru
 Requires-Dist: penzai
 Requires-Dist: pydantic
+Requires-Dist: pytest
+Requires-Dist: python-dotenv
 Requires-Dist: torch
 Requires-Dist: torchvision
 Requires-Dist: typing-extensions
@@ -30,6 +33,11 @@ Description-Content-Type: text/markdown
 
 
 
+## Docs
+
+Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
+
+
 ## Important
 
 This package is still in its infancy and hihgly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
@@ -37,6 +45,21 @@ PRs and other contributions are *highly* welcome! :)
 
 ## Info
 
+`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
+
+## Features
+
+- Convert PyTorch statedicts to JAX pytrees
+- Handle large models with chunked file conversion
+- Provide an "intuitive-ish" UI for parameter mapping
+- Support both in-memory and file-based conversions
+
+## Installation
+
+```bash
+pip install statedict2pytree
+```
+
 The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side my side and aligning the weights in the right order. Then, all statedict2pytree is doing, is iterating over both lists and matching the weight matrices.
 
 Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
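
The procedure described in the hunk above (align both parameter lists, then walk them in lockstep) can be sketched in a few lines of plain Python. This is an illustration under stated assumptions, not the package's API: `pair_by_order` is a hypothetical name, the `(name, shape)` lists stand in for a PyTorch `state_dict()` and the leaves of an Equinox model, and the optional `order` permutation stands in for the manual rearranging step. The compatibility test uses the element-count rule from the README's shape-matching section.

```python
from math import prod
from typing import Optional, Sequence

Shape = tuple[int, ...]


def pair_by_order(
    torch_params: Sequence[tuple[str, Shape]],
    jax_leaves: Sequence[tuple[str, Shape]],
    order: Optional[Sequence[int]] = None,
) -> list[tuple[str, str]]:
    """Hypothetical sketch: pair PyTorch tensors with JAX leaves by position.

    Both inputs are (name, shape) lists in declaration order. `order` is an
    optional permutation applied to the PyTorch side, only needed when the
    two models declare their fields in a different order.
    """
    if order is not None:
        torch_params = [torch_params[i] for i in order]
    if len(torch_params) != len(jax_leaves):
        raise ValueError(
            f"{len(torch_params)} PyTorch tensors vs {len(jax_leaves)} JAX leaves"
        )
    pairs = []
    for (t_name, t_shape), (j_name, j_shape) in zip(torch_params, jax_leaves):
        # Element-count check: (8, 1, 1) fits (8,) because 8 * 1 * 1 == 8.
        if prod(t_shape) != prod(j_shape):
            raise ValueError(f"{t_name}{t_shape} cannot fill {j_name}{j_shape}")
        pairs.append((t_name, j_name))
    return pairs


# Fields declared in the same order in both models: no rearranging needed.
torch_params = [("fc.weight", (8, 4)), ("fc.bias", (8,))]
jax_leaves = [("fc.weight", (8, 4)), ("fc.bias", (8, 1, 1))]
print(pair_by_order(torch_params, jax_leaves))

# JAX side declares the bias first: pass a permutation to realign.
jax_reversed = [("fc.bias", (8, 1, 1)), ("fc.weight", (8, 4))]
print(pair_by_order(torch_params, jax_reversed, order=[1, 0]))
```

Without `order`, the reversed case fails because the 32-element weight is compared against the 8-element bias slot. Raising on the first mismatch, rather than skipping it, keeps a misaligned model from being silently half-filled.
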
@@ -49,16 +72,9 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
 
 (8, 1, 1) and (8, ) match, because (8 _ 1 _ 1 = 8)
 
-## Get Started
-
-### Installation
 
-Run
-
-```bash
-pip install statedict2pytree
-```
 
-###
+### Disclaimer
 
-
+Some of the docstrings and the docs have been written with the help of
+Claude.
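
Judging by the example in the hunk above, two tensors count as a match when they hold the same number of elements: (8, 1, 1) and (8, ) both contain 8 * 1 * 1 = 8 values. The sketch below expresses that rule with `math.prod` (the helper name `shapes_match` is hypothetical, not taken from the package) and illustrates why the README calls the matching unsophisticated: a (4, 2) and a (2, 4) matrix also match, yet a row-major reshape of one into the other is not a transpose.

```python
from math import prod


def shapes_match(a: tuple[int, ...], b: tuple[int, ...]) -> bool:
    """Hypothetical helper: two shapes match when their element counts agree."""
    # math.prod(()) is 1, so a scalar shape () matches (1,) or (1, 1).
    return prod(a) == prod(b)


print(shapes_match((8, 1, 1), (8,)))  # True, since 8 * 1 * 1 == 8
print(shapes_match((8, 4), (8,)))     # False, since 32 != 8
print(shapes_match((4, 2), (2, 4)))   # True, both hold 8 elements

# ...but a row-major reshape of (4, 2) into (2, 4) is not its transpose.
w = [[1, 2], [3, 4], [5, 6], [7, 8]]         # shape (4, 2)
flat = [x for row in w for x in row]          # 1..8 in row-major order
reshaped = [flat[0:4], flat[4:8]]             # shape (2, 4)
transposed = [list(col) for col in zip(*w)]   # shape (2, 4)
print(reshaped == transposed)                 # False
```

An element-count match is therefore a necessary check, not proof that every value lands where the target model expects it.
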
--- statedict2pytree-0.5.2/README.md
+++ statedict2pytree-0.5.4/README.md
@@ -2,6 +2,11 @@
 
 
 
+## Docs
+
+Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
+
+
 ## Important
 
 This package is still in its infancy and hihgly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
@@ -9,6 +14,21 @@ PRs and other contributions are *highly* welcome! :)
 
 ## Info
 
+`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
+
+## Features
+
+- Convert PyTorch statedicts to JAX pytrees
+- Handle large models with chunked file conversion
+- Provide an "intuitive-ish" UI for parameter mapping
+- Support both in-memory and file-based conversions
+
+## Installation
+
+```bash
+pip install statedict2pytree
+```
+
 The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side my side and aligning the weights in the right order. Then, all statedict2pytree is doing, is iterating over both lists and matching the weight matrices.
 
 Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
@@ -21,16 +41,9 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
 
 (8, 1, 1) and (8, ) match, because (8 _ 1 _ 1 = 8)
 
-## Get Started
-
-### Installation
 
-Run
-
-```bash
-pip install statedict2pytree
-```
 
-###
+### Disclaimer
 
-
+Some of the docstrings and the docs have been written with the help of
+Claude.