statedict2pytree 0.6.0__tar.gz → 1.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30) hide show
  1. statedict2pytree-1.0.1/.github/workflows/run_tests.yml +38 -0
  2. statedict2pytree-1.0.1/.pre-commit-config.yaml +12 -0
  3. {statedict2pytree-0.6.0 → statedict2pytree-1.0.1}/PKG-INFO +10 -28
  4. {statedict2pytree-0.6.0 → statedict2pytree-1.0.1}/README.md +6 -15
  5. statedict2pytree-1.0.1/docs/index.md +186 -0
  6. statedict2pytree-1.0.1/mkdocs.yml +19 -0
  7. statedict2pytree-1.0.1/pyproject.toml +24 -0
  8. statedict2pytree-1.0.1/pyrightconfig.json +4 -0
  9. statedict2pytree-1.0.1/statedict2pytree/__init__.py +8 -0
  10. statedict2pytree-1.0.1/statedict2pytree/converter.py +293 -0
  11. statedict2pytree-1.0.1/tests/test_batchnorm.py +56 -0
  12. statedict2pytree-1.0.1/tests/test_conv.py +52 -0
  13. statedict2pytree-1.0.1/tests/test_linear.py +36 -0
  14. statedict2pytree-1.0.1/uv.lock +1909 -0
  15. statedict2pytree-0.6.0/client/.gitignore +0 -3
  16. statedict2pytree-0.6.0/client/package-lock.json +0 -4540
  17. statedict2pytree-0.6.0/client/package.json +0 -36
  18. statedict2pytree-0.6.0/client/public/bundle.js +0 -10072
  19. statedict2pytree-0.6.0/client/public/bundle.js.map +0 -1
  20. statedict2pytree-0.6.0/client/public/index.html +0 -14
  21. statedict2pytree-0.6.0/client/public/input.css +0 -3
  22. statedict2pytree-0.6.0/client/public/output.css +0 -1617
  23. statedict2pytree-0.6.0/client/rollup.config.mjs +0 -44
  24. statedict2pytree-0.6.0/client/src/App.svelte +0 -584
  25. statedict2pytree-0.6.0/client/src/empty.ts +0 -0
  26. statedict2pytree-0.6.0/client/src/main.js +0 -8
  27. statedict2pytree-0.6.0/client/tailwind.config.js +0 -8
  28. statedict2pytree-0.6.0/client/tsconfig.json +0 -5
  29. statedict2pytree-0.6.0/pyproject.toml +0 -62
  30. {statedict2pytree-0.6.0 → statedict2pytree-1.0.1}/.gitignore +0 -0
@@ -0,0 +1,38 @@
1
+ name: Run tests
2
+
3
+ on: [pull_request]
4
+
5
+ jobs:
6
+   run-test:
7
+     strategy:
8
+       matrix:
9
+         # Must satisfy requires-python (>=3.11) declared in pyproject.toml.
10
+         python-version: ["3.11"]
11
+         os: [ubuntu-latest]
12
+       fail-fast: false
13
+     runs-on: ${{ matrix.os }}
14
+     steps:
15
+       - name: Checkout code
16
+         uses: actions/checkout@v4
17
+
18
+       - name: Set up Python ${{ matrix.python-version }}
19
+         uses: actions/setup-python@v5
20
+         with:
21
+           python-version: ${{ matrix.python-version }}
22
+
23
+       - name: Install uv
24
+         uses: astral-sh/setup-uv@v5
25
+
26
+       - name: Create virtual environment and install dependencies
27
+         run: |
28
+           uv venv --python ${{ matrix.python-version }}
29
+           source .venv/bin/activate
30
+           uv pip install -e ".[dev]"
31
+
32
+       - name: Run pre-commit hooks
33
+         uses: pre-commit/action@v3.0.1
34
+
35
+       - name: Run tests
36
+         run: |
37
+           source .venv/bin/activate
38
+           python3 -m pytest
@@ -0,0 +1,12 @@
1
+ repos:
2
+   - repo: https://github.com/astral-sh/ruff-pre-commit
3
+     rev: v0.11.12
4
+     hooks:
5
+       - id: ruff-check
6
+       - id: ruff-format
7
+   - repo: https://github.com/RobertCraigie/pyright-python
8
+     rev: v1.1.351
9
+     hooks:
10
+       - id: pyright
11
+         additional_dependencies:
12
+           [beartype, jax, jaxtyping, pytest, typing_extensions]
@@ -1,58 +1,41 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.6.0
3
+ Version: 1.0.1
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
- Requires-Python: ~=3.10
7
- Requires-Dist: anthropic
6
+ Requires-Python: >=3.11
8
7
  Requires-Dist: beartype
9
8
  Requires-Dist: equinox
10
- Requires-Dist: flask
11
9
  Requires-Dist: jax
12
10
  Requires-Dist: jaxlib
13
- Requires-Dist: jaxonmodels
14
11
  Requires-Dist: jaxtyping
15
- Requires-Dist: loguru
16
- Requires-Dist: penzai
17
12
  Requires-Dist: pydantic
18
- Requires-Dist: pytest
19
- Requires-Dist: python-dotenv
20
- Requires-Dist: torch
21
- Requires-Dist: torchvision
22
- Requires-Dist: typing-extensions
13
+ Requires-Dist: tqdm
23
14
  Provides-Extra: dev
24
15
  Requires-Dist: mkdocs; extra == 'dev'
25
- Requires-Dist: nox; extra == 'dev'
26
16
  Requires-Dist: pre-commit; extra == 'dev'
27
17
  Requires-Dist: pytest; extra == 'dev'
18
+ Requires-Dist: torch; extra == 'dev'
28
19
  Provides-Extra: examples
29
20
  Requires-Dist: jaxonmodels; extra == 'examples'
30
21
  Description-Content-Type: text/markdown
31
22
 
32
23
  # statedict2pytree
33
24
 
34
- ![statedict2pytree](statedict2pytree.png "A ResNet demo")
35
25
 
36
- ## Docs
26
+ ## Update:
37
27
 
38
- Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
28
+ For examples of how to use `statedict2pytree`, check out my other repository [jaxonmodels](https://github.com/Artur-Galstyan/jaxonmodels).
39
29
 
30
+ ## Docs
40
31
 
41
- ## Important
32
+ Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
42
33
 
43
- This package is still in its infancy and hihgly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
44
- PRs and other contributions are *highly* welcome! :)
45
34
 
46
35
  ## Info
47
36
 
48
- `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
37
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees, specifically for Equinox models.
49
38
 
50
- ## Features
51
-
52
- - Convert PyTorch statedicts to JAX pytrees
53
- - Handle large models with chunked file conversion
54
- - Provide an "intuitive-ish" UI for parameter mapping
55
- - Support both in-memory and file-based conversions
56
39
 
57
40
  ## Installation
58
41
 
@@ -64,7 +47,6 @@ The goal of this package is to simplify the conversion from PyTorch models into
64
47
 
65
48
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
66
49
 
67
- (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
68
50
 
69
51
  ## Shape Matching? What's that?
70
52
 
@@ -73,8 +55,8 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
73
55
  (8, 1, 1) and (8,) match, because 8 × 1 × 1 = 8.
74
56
 
75
57
 
76
-
77
58
  ### Disclaimer
78
59
 
79
60
  Some of the docstrings and the docs have been written with the help of
80
61
  Claude.
62
+
@@ -1,27 +1,19 @@
1
1
  # statedict2pytree
2
2
 
3
- ![statedict2pytree](statedict2pytree.png "A ResNet demo")
4
3
 
5
- ## Docs
4
+ ## Update:
6
5
 
7
- Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
6
+ For examples of how to use `statedict2pytree`, check out my other repository [jaxonmodels](https://github.com/Artur-Galstyan/jaxonmodels).
8
7
 
8
+ ## Docs
9
9
 
10
- ## Important
10
+ Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
11
11
 
12
- This package is still in its infancy and hihgly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
13
- PRs and other contributions are *highly* welcome! :)
14
12
 
15
13
  ## Info
16
14
 
17
- `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees. It provides both programmatic and UI-based methods for mapping between PyTorch and JAX model parameters.
15
+ `statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees, specifically for Equinox models.
18
16
 
19
- ## Features
20
-
21
- - Convert PyTorch statedicts to JAX pytrees
22
- - Handle large models with chunked file conversion
23
- - Provide an "intuitive-ish" UI for parameter mapping
24
- - Support both in-memory and file-based conversions
25
17
 
26
18
  ## Installation
27
19
 
@@ -33,7 +25,6 @@ The goal of this package is to simplify the conversion from PyTorch models into
33
25
 
34
26
  Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
35
27
 
36
- (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
37
28
 
38
29
  ## Shape Matching? What's that?
39
30
 
@@ -42,8 +33,8 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
42
33
  (8, 1, 1) and (8,) match, because 8 × 1 × 1 = 8.
43
34
 
44
35
 
45
-
46
36
  ### Disclaimer
47
37
 
48
38
  Some of the docstrings and the docs have been written with the help of
49
39
  Claude.
40
+
@@ -0,0 +1,186 @@
1
+ # Quickstart Guide
2
+
3
+ ## Installation
4
+
5
+ To install `statedict2pytree`, run:
6
+
7
+ ```bash
8
+ pip install statedict2pytree
9
+ ```
10
+
11
+ ---
12
+
13
+ ## Basic Usage
14
+
15
+ There are five main functions you might interact with (the last one is an optional helper):
16
+
17
+ * `autoconvert`
18
+ * `convert`
19
+ * `pytree_to_fields`
20
+ * `state_dict_to_fields`
21
+ * `move_running_fields_to_the_end` (optional helper)
22
+
23
+ ---
24
+
25
+ ## General Information
26
+
27
+ `statedict2pytree` primarily aligns your JAX PyTree and the PyTorch `state_dict` side-by-side. It then checks if the shapes of the aligned weights match. If they do, it converts the PyTorch tensors to JAX arrays and places them into a new PyTree with the same structure as your original JAX PyTree.
28
+
29
+ **This means that the order and the shape of the arrays in your PyTree and the `state_dict` must match after any optional reordering!** The `pytree_to_fields` function uses a filter (defaulting to `equinox.is_array`) to determine which elements are considered fields.
30
+
31
+ For example, this conversion will **work** ✅:
32
+
33
+ | Parameter | JAX Shape | PyTorch Shape |
34
+ | :---------------- | :----------- | :------------ |
35
+ | `linear.weight` | `(2, 2)` | `(2, 2)` |
36
+ | `linear.bias` | `(2,)` | `(2,)` |
37
+ | `conv.weight` | `(1, 1, 2, 2)` | `(1, 1, 2, 2)`|
38
+ | `conv.bias` | `(1,)` | `(1,)` |
39
+
40
+ Since the shapes match when aligned in the same order, the conversion is successful.
41
+
42
+ On the other hand, this will **not work** ❌:
43
+
44
+ | Parameter | JAX Shape | PyTorch Shape | Mismatch? |
45
+ | :---------------- | :----------- | :------------ | :-------- |
46
+ | `linear.weight` | `(2, 2)` | `(3, 2)` | Yes |
47
+ | `linear.bias` | `(2,)` | `(3,)` | Yes |
48
+ | `conv.weight` | `(1, 1, 2, 2)` | `(1, 1, 2, 2)`| No |
49
+ | `conv.bias` | `(1,)` | `(1,)` | No |
50
+
51
+ This conversion will fail because the shapes of `linear.weight` and `linear.bias` don't match between the PyTree and the state dict.
52
+
53
+ The conversion can also fail if the parameters appear in a different **order**, because sequential matching then pairs up tensors whose shapes differ:
54
+
55
+ | JAX Parameter (Model Order) | JAX Shape | PyTorch Counterpart (`state_dict` Order) | PyTorch Shape | Issue if Matched Sequentially |
56
+ | :-------------------------- | :------------- | :--------------------------------------- | :------------- | :------------------------------------------------ |
57
+ | `model['conv']['weight']` | `(1, 1, 2, 2)` | `state_dict['model.linear.weight']` | `(2, 2)` | JAX `conv.weight` paired with PT `linear.weight`: `(1, 1, 2, 2)` vs `(2, 2)` |
58
+ | `model['conv']['bias']` | `(1,)` | `state_dict['model.linear.bias']` | `(2,)` | JAX `conv.bias` paired with PT `linear.bias`: `(1,)` vs `(2,)` |
59
+ | `model['linear']['weight']` | `(2, 2)` | `state_dict['model.conv.weight']` | `(1, 1, 2, 2)` | JAX `linear.weight` paired with PT `conv.weight`: `(2, 2)` vs `(1, 1, 2, 2)` |
60
+ | `model['linear']['bias']` | `(2,)` | `state_dict['model.conv.bias']` | `(1,)` | JAX `linear.bias` paired with PT `conv.bias`: `(2,)` vs `(1,)` |
61
+
62
+ To help with the order issue, you can provide a `list[str]` that sets the order in which the JAX PyTree fields are matched, so that it follows the order of the `state_dict`. Each entry is the `jax.tree_util.keystr` representation of a leaf's KeyPath in the JAX PyTree (see `pytree_to_fields` below), not a `state_dict` key. This is especially helpful when you can't easily force the correct order using `move_running_fields_to_the_end`. For the example above, where the `state_dict` lists `linear` before `conv`, the list for this dict-based PyTree would be:
63
+
64
+ ```python
65
+ ["['linear']['weight']", "['linear']['bias']", "['conv']['weight']", "['conv']['bias']"]
66
+ ```
67
+ Pass this list as `autoconvert`'s `pytree_model_order` argument (which is forwarded to `pytree_to_fields`) so that `jaxfields` are produced in this sequence. Alternatively, you could reorder `torchfields` using `move_running_fields_to_the_end` or other custom logic.
68
+
69
+ ---
70
+
71
+ ## API Reference
72
+
73
+ ### `autoconvert`
74
+
75
+ This is the simplest, highest-level function for most use cases.
76
+
77
+ ```python
78
+ def autoconvert(
79
+ pytree: PyTree,
80
+ state_dict: dict,
81
+ pytree_model_order: list[str] | None = None
82
+ ) -> PyTree:
83
+ ...
84
+ ```
85
+
86
+ You provide your JAX `pytree` and the PyTorch `state_dict`. Optionally, you can give `pytree_model_order` (a list of strings representing `jax.tree_util.keystr(path)`) to ensure the JAX fields are processed in a specific sequence. It handles the steps of field extraction (using `pytree_to_fields` with its default `filter=eqx.is_array`), alignment, and conversion, returning the populated JAX PyTree. If you need custom filtering for PyTree leaves, you should use `pytree_to_fields` and `convert` separately.
87
+
88
+ * **Parameters**:
89
+ * `pytree`: The JAX PyTree (e.g., an Equinox model) whose structure is the target.
90
+ * `state_dict`: The PyTorch state dictionary containing the weights.
91
+ * `pytree_model_order` (optional): A list of JAX KeyPath strings (like `'.layers.0.linear.weight'`). If provided, JAX fields will be ordered according to this list. This is useful if the automatic PyTree traversal order doesn't match the `state_dict` order.
92
+ * **Returns**: A new JAX PyTree with the same structure as the input `pytree`, but with weights populated from the `state_dict`.
93
+
94
+ ---
95
+
96
+ ### `convert`
97
+
98
+ This is the core function that performs the actual conversion once the JAX PyTree fields and PyTorch `state_dict` fields have been extracted and aligned.
99
+
100
+ ```python
101
+ def convert(
102
+ state_dict: dict[str, Any],
103
+ pytree: PyTree,
104
+ jaxfields: list[JaxField],
105
+ state_indices: dict | None,
106
+ torchfields: list[TorchField],
107
+ dtype: Any | None = None,
108
+ ) -> PyTree:
109
+ ...
110
+ ```
111
+
112
+ It iterates through the aligned `jaxfields` and `torchfields`, checks for shape compatibility (reshapability), converts PyTorch tensors (expected as values in `state_dict`) to JAX arrays (optionally casting `dtype`), and inserts them into the correct place in the JAX PyTree.
113
+
114
+ * **Parameters**:
115
+ * `state_dict`: The original PyTorch state dictionary. Values are expected to be tensor-like (e.g., `torch.Tensor`).
116
+ * `pytree`: The JAX PyTree that will be populated.
117
+ * `jaxfields`: An ordered list of `JaxField` objects (obtained from `pytree_to_fields`) representing the leaves of the JAX PyTree.
118
+ * `state_indices`: A dictionary mapping state markers to `eqx.nn.StateIndex` objects, used for handling Equinox stateful layers.
119
+ * `torchfields`: An ordered list of `TorchField` objects (obtained from `state_dict_to_fields`) representing the tensors in the PyTorch `state_dict`. **This list must be ordered to match `jaxfields`**.
120
+ * `dtype` (optional): The JAX data type to convert floating-point tensors to (e.g., `jnp.float32`). Defaults to JAX's current default floating-point type.
121
+ * **Returns**: A new JAX PyTree populated with weights from the `state_dict`.
122
+
123
+ ---
124
+
125
+ ### `pytree_to_fields`
126
+
127
+ This function traverses a JAX PyTree and extracts information about its array leaves based on a filter.
128
+
129
+ ```python
130
+ def pytree_to_fields(
131
+ pytree: PyTree,
132
+ model_order: list[str] | None = None,
133
+ filter: Callable[[Array], bool] = eqx.is_array,
134
+ ) -> tuple[list[JaxField], dict | None]:
135
+ ...
136
+ ```
137
+
138
+ It identifies all JAX arrays (or other elements satisfying the `filter`) within the `pytree`, recording their `KeyPath` (path within the PyTree) and shape. If `model_order` is provided, it attempts to reorder the extracted fields according to that list. This is crucial for ensuring the JAX fields align correctly with the PyTorch fields.
139
+
140
+ * **Parameters**:
141
+ * `pytree`: The JAX PyTree to analyze.
142
+ * `model_order` (optional): A list of strings, where each string is a `jax.tree_util.keystr` representation of a `KeyPath` to an array leaf in the `pytree`. If provided, the output `JaxField` list will be sorted according to this order, with any fields not in `model_order` appended at the end.
143
+ * `filter` (optional): A callable that takes a PyTree leaf (e.g., an array) and returns `True` if it should be considered a field to be converted, `False` otherwise. Defaults to `equinox.is_array`.
144
+ * **Returns**: A tuple containing:
145
+ * `list[JaxField]`: A list of `JaxField` objects, each describing a filtered leaf in the PyTree (path, shape).
146
+ * `dict | None`: A dictionary containing information about `eqx.nn.StateIndex` objects found in the PyTree, or `None` if none are found.
147
+
148
+ ---
149
+
150
+ ### `state_dict_to_fields`
151
+
152
+ This function processes a PyTorch `state_dict` to extract information about its tensors.
153
+
154
+ ```python
155
+ def state_dict_to_fields(
156
+ state_dict: dict[str, Any],
157
+ ) -> list[TorchField]:
158
+ ...
159
+ ```
160
+
161
+ It iterates through the `state_dict`, creating a `TorchField` object for each value that has a `shape` attribute and a non-empty shape (typically tensors). This object stores the tensor's name (key in the `state_dict`) and its shape.
162
+
163
+ * **Parameters**:
164
+ * `state_dict`: The PyTorch state dictionary. Values are typically `torch.Tensor` or other array-like objects.
165
+ * **Returns**: A list of `TorchField` objects, each describing a tensor in the `state_dict` (path/key, shape). The order matches the iteration order of the input `state_dict`.
166
+
167
+ ---
168
+
169
+ ### `move_running_fields_to_the_end`
170
+
171
+ This is an optional utility function to help reorder fields extracted from a PyTorch `state_dict`.
172
+
173
+ ```python
174
+ def move_running_fields_to_the_end(
175
+ torchfields: list[TorchField],
176
+ identifier: str = "running_"
177
+ ):
178
+ ...
179
+ ```
180
+
181
+ It's particularly useful for models with layers like `BatchNorm`, where PyTorch often stores `running_mean` and `running_var` interspersed with weights and biases, while Equinox (a common JAX library) typically expects stateful components like these at the end of a layer's parameter list. This function moves any `TorchField` whose path contains the `identifier` (defaulting to `"running_"`) to the end of the list.
182
+
183
+ * **Parameters**:
184
+ * `torchfields`: The list of `TorchField` objects to be reordered.
185
+ * `identifier` (optional): A string that, if found within a `TorchField`'s path, will cause that field to be moved to the end of the list. Default is `"running_"`.
186
+ * **Returns**: The modified list of `TorchField` objects with identified fields moved to the end.
@@ -0,0 +1,19 @@
1
+ site_name: Statedict2Pytree
2
+
3
+ theme:
4
+   name: material
5
+
6
+ plugins:
7
+   - search
8
+   - mkdocstrings:
9
+       default_handler: python
10
+
11
+ nav:
12
+   - Home: index.md
13
+
14
+ markdown_extensions:
15
+   - pymdownx.highlight:
16
+       anchor_linenums: true
17
+   - pymdownx.inlinehilite
18
+   - pymdownx.snippets
19
+   - pymdownx.superfences
@@ -0,0 +1,24 @@
1
+ [project]
2
+ name = "statedict2pytree"
3
+ version = "1.0.1"
4
+ description = "Converts torch models into PyTrees for Equinox"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ authors = [{ name = "Artur A. Galstyan", email = "mail@arturgalstyan.dev" }]
8
+ dependencies = [
9
+ "jax",
10
+ "equinox",
11
+ "jaxlib",
12
+ "beartype",
13
+ "jaxtyping",
14
+ "pydantic",
15
+ "tqdm",
16
+ ]
17
+
18
+ [project.optional-dependencies]  # dev also covers docs: mkdocs.yml uses the material theme and mkdocstrings
19
+ dev = ["pre-commit", "pytest", "mkdocs", "mkdocs-material", "mkdocstrings[python]", "torch"]
20
+ examples = ["jaxonmodels"]
21
+
22
+ [build-system]
23
+ requires = ["hatchling"]
24
+ build-backend = "hatchling.build"
@@ -0,0 +1,4 @@
1
+ {
2
+ "venvPath": ".",
3
+ "venv": ".venv"
4
+ }
@@ -0,0 +1,8 @@
1
+ from .converter import (
2
+ autoconvert, # noqa
3
+ convert, # noqa
4
+ is_numerical, # noqa
5
+ move_running_fields_to_the_end, # noqa
6
+ pytree_to_fields, # noqa
7
+ state_dict_to_fields, # noqa
8
+ ) # noqa