statedict2pytree 0.5.0.tar.gz → 0.5.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. {statedict2pytree-0.5.0 → statedict2pytree-0.5.3}/.gitignore +2 -0
  2. statedict2pytree-0.5.3/PKG-INFO +77 -0
  3. statedict2pytree-0.5.3/README.md +47 -0
  4. statedict2pytree-0.5.3/client/.gitignore +3 -0
  5. statedict2pytree-0.5.3/client/package-lock.json +4540 -0
  6. statedict2pytree-0.5.3/client/package.json +36 -0
  7. statedict2pytree-0.5.3/client/public/bundle.js +10072 -0
  8. statedict2pytree-0.5.3/client/public/bundle.js.map +1 -0
  9. statedict2pytree-0.5.3/client/public/index.html +14 -0
  10. {statedict2pytree-0.5.0/statedict2pytree/static → statedict2pytree-0.5.3/client/public}/output.css +184 -109
  11. statedict2pytree-0.5.3/client/rollup.config.mjs +44 -0
  12. statedict2pytree-0.5.3/client/src/App.svelte +583 -0
  13. statedict2pytree-0.5.3/client/src/empty.ts +0 -0
  14. statedict2pytree-0.5.3/client/src/main.js +8 -0
  15. {statedict2pytree-0.5.0 → statedict2pytree-0.5.3/client}/tailwind.config.js +1 -1
  16. statedict2pytree-0.5.3/client/tsconfig.json +5 -0
  17. {statedict2pytree-0.5.0 → statedict2pytree-0.5.3}/pyproject.toml +17 -5
  18. statedict2pytree-0.5.0/.pre-commit-config.yaml +0 -16
  19. statedict2pytree-0.5.0/PKG-INFO +0 -147
  20. statedict2pytree-0.5.0/README.md +0 -123
  21. statedict2pytree-0.5.0/examples/convert_resnet.py +0 -19
  22. statedict2pytree-0.5.0/examples/doggo.jpeg +0 -0
  23. statedict2pytree-0.5.0/examples/resnet.py +0 -409
  24. statedict2pytree-0.5.0/examples/resnet_inference.py +0 -68
  25. statedict2pytree-0.5.0/package-lock.json +0 -1424
  26. statedict2pytree-0.5.0/package.json +0 -9
  27. statedict2pytree-0.5.0/pyrightconfig.json +0 -4
  28. statedict2pytree-0.5.0/statedict2pytree/__init__.py +0 -7
  29. statedict2pytree-0.5.0/statedict2pytree/statedict2pytree.py +0 -219
  30. statedict2pytree-0.5.0/statedict2pytree/templates/index.html +0 -308
  31. statedict2pytree-0.5.0/tests/test_batchnorm.py +0 -46
  32. statedict2pytree-0.5.0/tests/test_conv.py +0 -57
  33. statedict2pytree-0.5.0/tests/test_linear.py +0 -41
  34. statedict2pytree-0.5.0/torch2jax.png +0 -0
  35. {statedict2pytree-0.5.0/statedict2pytree/static → statedict2pytree-0.5.3/client/public}/input.css +0 -0
@@ -159,3 +159,5 @@ cython_debug/
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
  #.idea/
  node_modules/
+ examples/llama/Meta-Llama-3-8B/
+ .zed/
@@ -0,0 +1,77 @@
+ Metadata-Version: 2.3
+ Name: statedict2pytree
+ Version: 0.5.3
+ Summary: Converts torch models into PyTrees for Equinox
+ Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
+ Requires-Python: ~=3.10
+ Requires-Dist: anthropic
+ Requires-Dist: beartype
+ Requires-Dist: equinox
+ Requires-Dist: flask
+ Requires-Dist: jax
+ Requires-Dist: jaxlib
+ Requires-Dist: jaxonmodels
+ Requires-Dist: jaxtyping
+ Requires-Dist: loguru
+ Requires-Dist: penzai
+ Requires-Dist: pydantic
+ Requires-Dist: python-dotenv
+ Requires-Dist: torch
+ Requires-Dist: torchvision
+ Requires-Dist: typing-extensions
+ Provides-Extra: dev
+ Requires-Dist: mkdocs; extra == 'dev'
+ Requires-Dist: nox; extra == 'dev'
+ Requires-Dist: pre-commit; extra == 'dev'
+ Requires-Dist: pytest; extra == 'dev'
+ Provides-Extra: examples
+ Requires-Dist: jaxonmodels; extra == 'examples'
+ Description-Content-Type: text/markdown
+
+ # statedict2pytree
+
+ ![statedict2pytree](statedict2pytree.png "A ResNet demo")
+
+ ## Important
+
+ This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more iterations, it will eventually become stable and well tested.
+ PRs and other contributions are *highly* welcome! :)
+
+ ## Info
+
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
+
+ ## Features
+
+ - Convert PyTorch statedicts to JAX pytrees
+ - Handle large models with chunked file conversion
+ - Provide an "intuitive-ish" UI for parameter mapping
+ - Support both in-memory and file-based conversions
+
+ ## Installation
+
+ ```bash
+ pip install statedict2pytree
+ ```
+
+ The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
+
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
+
+ (Theoretically, you can rearrange the model in any way you like, e.g. last layer as the first layer, as long as the shapes match!)
+
+ ## Shape Matching? What's that?
+
+ Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shape dimensions are equal. For example:
+
+ (8, 1, 1) and (8, ) match, because 8 × 1 × 1 = 8.
+
+
+ ### Docs
+
+ Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
+
+ ### Disclaimer
+
+ Some of the docstrings and the docs have been written with the help of
+ Claude.
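The README's "product of the shape" rule fits in a few lines of Python. The sketch below is a hypothetical helper (`shapes_match` is not part of statedict2pytree's API) that restates the rule as described: two shapes match when they hold the same number of elements.

```python
from math import prod


def shapes_match(torch_shape: tuple[int, ...], jax_shape: tuple[int, ...]) -> bool:
    """Return True if both shapes hold the same number of elements.

    Hypothetical helper restating the README's rule; it is not
    statedict2pytree's own implementation.
    """
    # math.prod multiplies all dimensions together (and returns 1 for ()).
    return prod(torch_shape) == prod(jax_shape)


print(shapes_match((8, 1, 1), (8,)))           # 8 * 1 * 1 == 8
print(shapes_match((64, 3, 7, 7), (64, 147)))  # 9408 == 9408
print(shapes_match((8, 2), (8,)))              # 16 != 8
```

The rule is permissive by design: `(2, 3)` and `(3, 2)` also "match", even though a plain reshape of a weight that actually needs transposing would scramble its values. That is one reason the field order and the optional rearranging step matter.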
@@ -0,0 +1,47 @@
+ # statedict2pytree
+
+ ![statedict2pytree](statedict2pytree.png "A ResNet demo")
+
+ ## Important
+
+ This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more iterations, it will eventually become stable and well tested.
+ PRs and other contributions are *highly* welcome! :)
+
+ ## Info
+
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
+
+ ## Features
+
+ - Convert PyTorch statedicts to JAX pytrees
+ - Handle large models with chunked file conversion
+ - Provide an "intuitive-ish" UI for parameter mapping
+ - Support both in-memory and file-based conversions
+
+ ## Installation
+
+ ```bash
+ pip install statedict2pytree
+ ```
+
+ The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used, e.g., in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
+
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
+
+ (Theoretically, you can rearrange the model in any way you like, e.g. last layer as the first layer, as long as the shapes match!)
+
+ ## Shape Matching? What's that?
+
+ Currently, there is no sophisticated shape matching in place. Two matrices are considered "matching" if the products of their shape dimensions are equal. For example:
+
+ (8, 1, 1) and (8, ) match, because 8 × 1 × 1 = 8.
+
+
+ ### Docs
+
+ Documentation will appear as soon as I have all the necessary features implemented. Until then, check out the "main.py" file for a better example.
+
+ ### Disclaimer
+
+ Some of the docstrings and the docs have been written with the help of
+ Claude.
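The README describes the conversion as iterating over both parameter lists in order and matching the weight matrices. The sketch below illustrates that positional pairing under stated assumptions: parameters are reduced to hypothetical `(name, shape)` tuples, `pair_in_order` is an illustrative name rather than the package's API, and the JAX-side bias shape `(8, 1, 1)` assumes an Equinox-style `Conv2d` that stores its bias with trailing singleton dimensions.

```python
from math import prod

# A parameter is described here only by (name, shape). In real use these
# would come from a PyTorch state_dict and from the leaves of a JAX PyTree;
# both lists below are hypothetical.
Param = tuple[str, tuple[int, ...]]


def pair_in_order(
    torch_params: list[Param], jax_params: list[Param]
) -> list[tuple[str, str, tuple[int, ...]]]:
    """Pair parameters by position, checking only that element counts agree.

    Returns (torch_name, jax_name, target_shape) triples; a real converter
    would then reshape each torch tensor to target_shape.
    """
    if len(torch_params) != len(jax_params):
        raise ValueError(
            f"parameter counts differ: {len(torch_params)} vs {len(jax_params)}"
        )
    pairs = []
    for (t_name, t_shape), (j_name, j_shape) in zip(torch_params, jax_params):
        # Same "product of the shape" rule as in the README.
        if prod(t_shape) != prod(j_shape):
            raise ValueError(f"{t_name} {t_shape} cannot fill {j_name} {j_shape}")
        pairs.append((t_name, j_name, j_shape))
    return pairs


torch_params = [
    ("conv.weight", (8, 3, 3, 3)),
    ("conv.bias", (8,)),
    ("fc.weight", (10, 8)),
]
# Assumption: the JAX model stores the conv bias as (out_channels, 1, 1).
jax_params = [
    ("conv.weight", (8, 3, 3, 3)),
    ("conv.bias", (8, 1, 1)),
    ("fc.weight", (10, 8)),
]

for torch_name, jax_name, shape in pair_in_order(torch_params, jax_params):
    print(f"{torch_name} -> {jax_name} {shape}")
```

The sketch also shows the limit of purely positional matching: two layers with equal element counts but swapped order would pair without any error. Getting the order right (by declaring fields in the PyTorch order, or by rearranging) therefore remains the user's job.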
@@ -0,0 +1,3 @@
+ .DS_Store
+ node_modules
+ public/bundle.*