statedict2pytree 0.6.0__tar.gz → 1.0.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree-1.0.0/.github/workflows/run_tests.yml +38 -0
- statedict2pytree-1.0.0/.pre-commit-config.yaml +12 -0
- {statedict2pytree-0.6.0 → statedict2pytree-1.0.0}/PKG-INFO +10 -28
- {statedict2pytree-0.6.0 → statedict2pytree-1.0.0}/README.md +6 -15
- statedict2pytree-1.0.0/docs/index.md +186 -0
- statedict2pytree-1.0.0/mkdocs.yml +19 -0
- statedict2pytree-1.0.0/pyproject.toml +24 -0
- statedict2pytree-1.0.0/pyrightconfig.json +4 -0
- statedict2pytree-1.0.0/statedict2pytree/__init__.py +8 -0
- statedict2pytree-1.0.0/statedict2pytree/converter.py +293 -0
- statedict2pytree-1.0.0/tests/test_batchnorm.py +56 -0
- statedict2pytree-1.0.0/tests/test_conv.py +52 -0
- statedict2pytree-1.0.0/tests/test_linear.py +36 -0
- statedict2pytree-1.0.0/uv.lock +1909 -0
- statedict2pytree-0.6.0/client/.gitignore +0 -3
- statedict2pytree-0.6.0/client/package-lock.json +0 -4540
- statedict2pytree-0.6.0/client/package.json +0 -36
- statedict2pytree-0.6.0/client/public/bundle.js +0 -10072
- statedict2pytree-0.6.0/client/public/bundle.js.map +0 -1
- statedict2pytree-0.6.0/client/public/index.html +0 -14
- statedict2pytree-0.6.0/client/public/input.css +0 -3
- statedict2pytree-0.6.0/client/public/output.css +0 -1617
- statedict2pytree-0.6.0/client/rollup.config.mjs +0 -44
- statedict2pytree-0.6.0/client/src/App.svelte +0 -584
- statedict2pytree-0.6.0/client/src/empty.ts +0 -0
- statedict2pytree-0.6.0/client/src/main.js +0 -8
- statedict2pytree-0.6.0/client/tailwind.config.js +0 -8
- statedict2pytree-0.6.0/client/tsconfig.json +0 -5
- statedict2pytree-0.6.0/pyproject.toml +0 -62
- {statedict2pytree-0.6.0 → statedict2pytree-1.0.0}/.gitignore +0 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
name: Run tests
|
|
2
|
+
|
|
3
|
+
on: [pull_request]
|
|
4
|
+
|
|
5
|
+
jobs:
|
|
6
|
+
run-test:
|
|
7
|
+
strategy:
|
|
8
|
+
matrix:
|
|
9
|
+
python-version: ["3.10"]
|
|
10
|
+
os: [ubuntu-latest]
|
|
11
|
+
fail-fast: false
|
|
12
|
+
runs-on: ${{ matrix.os }}
|
|
13
|
+
steps:
|
|
14
|
+
- name: Checkout code
|
|
15
|
+
uses: actions/checkout@v4
|
|
16
|
+
|
|
17
|
+
- name: Set up Python ${{ matrix.python-version }}
|
|
18
|
+
uses: actions/setup-python@v5
|
|
19
|
+
with:
|
|
20
|
+
python-version: ${{ matrix.python-version }}
|
|
21
|
+
|
|
22
|
+
- name: Install uv
|
|
23
|
+
uses: astral-sh/setup-uv@v1
|
|
24
|
+
|
|
25
|
+
- name: Create virtual environment and install dependencies
|
|
26
|
+
run: |
|
|
27
|
+
uv venv --python ${{ matrix.python-version }}
|
|
28
|
+
source .venv/bin/activate
|
|
29
|
+
uv pip install -e ".[dev]"
|
|
30
|
+
uv pip install -e .
|
|
31
|
+
|
|
32
|
+
- name: Run pre-commit hooks
|
|
33
|
+
uses: pre-commit/action@v3.0.1
|
|
34
|
+
|
|
35
|
+
- name: Run tests
|
|
36
|
+
run: |
|
|
37
|
+
source .venv/bin/activate
|
|
38
|
+
python3 -m pytest
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
repos:
|
|
2
|
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
3
|
+
rev: v0.11.12
|
|
4
|
+
hooks:
|
|
5
|
+
- id: ruff-check
|
|
6
|
+
- id: ruff-format
|
|
7
|
+
- repo: https://github.com/RobertCraigie/pyright-python
|
|
8
|
+
rev: v1.1.351
|
|
9
|
+
hooks:
|
|
10
|
+
- id: pyright
|
|
11
|
+
additional_dependencies:
|
|
12
|
+
[beartype, jax, jaxtyping, pytest, typing_extensions]
|
|
@@ -1,58 +1,41 @@
|
|
|
1
1
|
Metadata-Version: 2.3
|
|
2
2
|
Name: statedict2pytree
|
|
3
|
-
Version: 0.6.0
|
|
3
|
+
Version: 1.0.0
|
|
4
4
|
Summary: Converts torch models into PyTrees for Equinox
|
|
5
5
|
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
6
|
Requires-Python: ~=3.10
|
|
7
|
-
Requires-Dist: anthropic
|
|
8
7
|
Requires-Dist: beartype
|
|
9
8
|
Requires-Dist: equinox
|
|
10
|
-
Requires-Dist:
|
|
11
|
-
Requires-Dist: jax
|
|
9
|
+
Requires-Dist: jax>=0.6.1
|
|
12
10
|
Requires-Dist: jaxlib
|
|
13
|
-
Requires-Dist: jaxonmodels
|
|
14
11
|
Requires-Dist: jaxtyping
|
|
15
|
-
Requires-Dist: loguru
|
|
16
|
-
Requires-Dist: penzai
|
|
17
12
|
Requires-Dist: pydantic
|
|
18
|
-
Requires-Dist:
|
|
19
|
-
Requires-Dist: python-dotenv
|
|
20
|
-
Requires-Dist: torch
|
|
21
|
-
Requires-Dist: torchvision
|
|
22
|
-
Requires-Dist: typing-extensions
|
|
13
|
+
Requires-Dist: tqdm
|
|
23
14
|
Provides-Extra: dev
|
|
24
15
|
Requires-Dist: mkdocs; extra == 'dev'
|
|
25
|
-
Requires-Dist: nox; extra == 'dev'
|
|
26
16
|
Requires-Dist: pre-commit; extra == 'dev'
|
|
27
17
|
Requires-Dist: pytest; extra == 'dev'
|
|
18
|
+
Requires-Dist: torch; extra == 'dev'
|
|
28
19
|
Provides-Extra: examples
|
|
29
20
|
Requires-Dist: jaxonmodels; extra == 'examples'
|
|
30
21
|
Description-Content-Type: text/markdown
|
|
31
22
|
|
|
32
23
|
# statedict2pytree
|
|
33
24
|
|
|
34
|
-

|
|
35
25
|
|
|
36
|
-
##
|
|
26
|
+
## Update:
|
|
37
27
|
|
|
38
|
-
|
|
28
|
+
For examples of how to use `statedict2pytree`, check out my other repository [jaxonmodels](https://github.com/Artur-Galstyan/jaxonmodels).
|
|
39
29
|
|
|
30
|
+
## Docs
|
|
40
31
|
|
|
41
|
-
|
|
32
|
+
Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
|
|
42
33
|
|
|
43
|
-
This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
|
|
44
|
-
PRs and other contributions are *highly* welcome! :)
|
|
45
34
|
|
|
46
35
|
## Info
|
|
47
36
|
|
|
48
|
-
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees
|
|
37
|
+
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees, specifically for Equinox models.
|
|
49
38
|
|
|
50
|
-
## Features
|
|
51
|
-
|
|
52
|
-
- Convert PyTorch statedicts to JAX pytrees
|
|
53
|
-
- Handle large models with chunked file conversion
|
|
54
|
-
- Provide an "intuitive-ish" UI for parameter mapping
|
|
55
|
-
- Support both in-memory and file-based conversions
|
|
56
39
|
|
|
57
40
|
## Installation
|
|
58
41
|
|
|
@@ -64,7 +47,6 @@ The goal of this package is to simplify the conversion from PyTorch models into
|
|
|
64
47
|
|
|
65
48
|
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
66
49
|
|
|
67
|
-
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
68
50
|
|
|
69
51
|
## Shape Matching? What's that?
|
|
70
52
|
|
|
@@ -73,8 +55,8 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
|
|
|
73
55
|
(8, 1, 1) and (8, ) match, because (8 × 1 × 1 = 8)
|
|
74
56
|
|
|
75
57
|
|
|
76
|
-
|
|
77
58
|
### Disclaimer
|
|
78
59
|
|
|
79
60
|
Some of the docstrings and the docs have been written with the help of
|
|
80
61
|
Claude.
|
|
62
|
+
|
|
@@ -1,27 +1,19 @@
|
|
|
1
1
|
# statedict2pytree
|
|
2
2
|
|
|
3
|
-

|
|
4
3
|
|
|
5
|
-
##
|
|
4
|
+
## Update:
|
|
6
5
|
|
|
7
|
-
|
|
6
|
+
For examples of how to use `statedict2pytree`, check out my other repository [jaxonmodels](https://github.com/Artur-Galstyan/jaxonmodels).
|
|
8
7
|
|
|
8
|
+
## Docs
|
|
9
9
|
|
|
10
|
-
|
|
10
|
+
Docs can be found [here](https://artur-galstyan.github.io/statedict2pytree/).
|
|
11
11
|
|
|
12
|
-
This package is still in its infancy and highly experimental! The code works, but it's far from perfect. With more and more iterations, it will eventually become stable and well tested.
|
|
13
|
-
PRs and other contributions are *highly* welcome! :)
|
|
14
12
|
|
|
15
13
|
## Info
|
|
16
14
|
|
|
17
|
-
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees
|
|
15
|
+
`statedict2pytree` is a powerful tool for converting PyTorch state dictionaries to JAX pytrees, specifically for Equinox models.
|
|
18
16
|
|
|
19
|
-
## Features
|
|
20
|
-
|
|
21
|
-
- Convert PyTorch statedicts to JAX pytrees
|
|
22
|
-
- Handle large models with chunked file conversion
|
|
23
|
-
- Provide an "intuitive-ish" UI for parameter mapping
|
|
24
|
-
- Support both in-memory and file-based conversions
|
|
25
17
|
|
|
26
18
|
## Installation
|
|
27
19
|
|
|
@@ -33,7 +25,6 @@ The goal of this package is to simplify the conversion from PyTorch models into
|
|
|
33
25
|
|
|
34
26
|
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
35
27
|
|
|
36
|
-
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
37
28
|
|
|
38
29
|
## Shape Matching? What's that?
|
|
39
30
|
|
|
@@ -42,8 +33,8 @@ Currently, there is no sophisticated shape matching in place. Two matrices are c
|
|
|
42
33
|
(8, 1, 1) and (8, ) match, because (8 × 1 × 1 = 8)
|
|
43
34
|
|
|
44
35
|
|
|
45
|
-
|
|
46
36
|
### Disclaimer
|
|
47
37
|
|
|
48
38
|
Some of the docstrings and the docs have been written with the help of
|
|
49
39
|
Claude.
|
|
40
|
+
|
|
@@ -0,0 +1,186 @@
|
|
|
1
|
+
# Quickstart Guide
|
|
2
|
+
|
|
3
|
+
## Installation
|
|
4
|
+
|
|
5
|
+
To install `statedict2pytree`, run:
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install statedict2pytree
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
---
|
|
12
|
+
|
|
13
|
+
## Basic Usage
|
|
14
|
+
|
|
15
|
+
There are four main functions you might interact with, plus one optional helper:
|
|
16
|
+
|
|
17
|
+
* `autoconvert`
|
|
18
|
+
* `convert`
|
|
19
|
+
* `pytree_to_fields`
|
|
20
|
+
* `state_dict_to_fields`
|
|
21
|
+
* `move_running_fields_to_the_end` (optional helper)
|
|
22
|
+
|
|
23
|
+
---
|
|
24
|
+
|
|
25
|
+
## General Information
|
|
26
|
+
|
|
27
|
+
`statedict2pytree` primarily aligns your JAX PyTree and the PyTorch `state_dict` side-by-side. It then checks if the shapes of the aligned weights match. If they do, it converts the PyTorch tensors to JAX arrays and places them into a new PyTree with the same structure as your original JAX PyTree.
|
|
28
|
+
|
|
29
|
+
**This means that the order and the shape of the arrays in your PyTree and the `state_dict` must match after any optional reordering!** The `pytree_to_fields` function uses a filter (defaulting to `equinox.is_array`) to determine which elements are considered fields.
|
|
30
|
+
|
|
31
|
+
For example, this conversion will **work** ✅:
|
|
32
|
+
|
|
33
|
+
| Parameter | JAX Shape | PyTorch Shape |
|
|
34
|
+
| :---------------- | :----------- | :------------ |
|
|
35
|
+
| `linear.weight` | `(2, 2)` | `(2, 2)` |
|
|
36
|
+
| `linear.bias` | `(2,)` | `(2,)` |
|
|
37
|
+
| `conv.weight` | `(1, 1, 2, 2)` | `(1, 1, 2, 2)`|
|
|
38
|
+
| `conv.bias` | `(1,)` | `(1,)` |
|
|
39
|
+
|
|
40
|
+
Since the shapes match when aligned in the same order, the conversion is successful.
|
|
41
|
+
|
|
42
|
+
On the other hand, this will **not work** ❌:
|
|
43
|
+
|
|
44
|
+
| Parameter | JAX Shape | PyTorch Shape | Mismatch? |
|
|
45
|
+
| :---------------- | :----------- | :------------ | :-------- |
|
|
46
|
+
| `linear.weight` | `(2, 2)` | `(3, 2)` | Yes |
|
|
47
|
+
| `linear.bias` | `(2,)` | `(3,)` | Yes |
|
|
48
|
+
| `conv.weight` | `(1, 1, 2, 2)` | `(1, 1, 2, 2)`| No |
|
|
49
|
+
| `conv.bias` | `(1,)` | `(1,)` | No |
|
|
50
|
+
|
|
51
|
+
This conversion will fail because the shapes of `linear.weight` and `linear.bias` in the PyTree don't match their counterparts in the state dict.
|
|
52
|
+
|
|
53
|
+
The conversion can also fail if the parameters appear in a different **order** in the PyTree and the `state_dict`, so that parameters aligned sequentially end up with mismatched shapes:
|
|
54
|
+
|
|
55
|
+
| JAX Parameter (Model Order) | JAX Shape | PyTorch Counterpart (`state_dict` Order) | PyTorch Shape | Issue if Matched Sequentially |
|
|
56
|
+
| :-------------------------- | :------------- | :--------------------------------------- | :------------- | :------------------------------------------------ |
|
|
57
|
+
| `model['conv']['weight']`   | `(1, 1, 2, 2)` | `state_dict['model.linear.weight']`      | `(2, 2)`       | Order: JAX `conv.w` `(1, 1, 2, 2)` vs PT `linear.w` `(2, 2)` |
|
|
58
|
+
| `model['conv']['bias']` | `(1,)` | `state_dict['model.linear.bias']` | `(2,)` | Order: JAX `conv.b` `(1,)` vs PT `linear.b` `(2,)` |
|
|
59
|
+
| `model['linear']['weight']` | `(2, 2)`       | `state_dict['model.conv.weight']`        | `(1, 1, 2, 2)` | Order: JAX `linear.w` `(2, 2)` vs PT `conv.w` `(1, 1, 2, 2)` |
|
|
60
|
+
| `model['linear']['bias']` | `(2,)` | `state_dict['model.conv.bias']` | `(1,)` | Order: JAX `linear.b` `(2,)` vs PT `conv.b` `(1,)` |
|
|
61
|
+
|
|
62
|
+
To help with the order issue, you can provide a `list[str]` specifying the desired order of PyTree fields (matching the `state_dict`'s conceptual order, or vice-versa if you reorder `state_dict` fields). This is especially helpful when you can't easily force the correct order using `move_running_fields_to_the_end`. For the example above, if your PyTree expects `conv` then `linear`, the list of strings representing the *names from the state\_dict in the JAX PyTree's desired order* would be:
|
|
63
|
+
|
|
64
|
+
```python
|
|
65
|
+
['model.conv.weight', 'model.conv.bias', 'model.linear.weight', 'model.linear.bias']
|
|
66
|
+
```
|
|
67
|
+
This list would be passed to `pytree_to_fields` via `autoconvert`'s `pytree_model_order` argument to ensure `jaxfields` are in this sequence. Alternatively, you could reorder `torchfields` using `move_running_fields_to_the_end` or other custom logic.
|
|
68
|
+
|
|
69
|
+
---
|
|
70
|
+
|
|
71
|
+
## API Reference
|
|
72
|
+
|
|
73
|
+
### `autoconvert`
|
|
74
|
+
|
|
75
|
+
This is the simplest, highest-level function for most use cases.
|
|
76
|
+
|
|
77
|
+
```python
|
|
78
|
+
def autoconvert(
|
|
79
|
+
pytree: PyTree,
|
|
80
|
+
state_dict: dict,
|
|
81
|
+
pytree_model_order: list[str] | None = None
|
|
82
|
+
) -> PyTree:
|
|
83
|
+
...
|
|
84
|
+
```
|
|
85
|
+
|
|
86
|
+
You provide your JAX `pytree` and the PyTorch `state_dict`. Optionally, you can give `pytree_model_order` (a list of strings representing `jax.tree_util.keystr(path)`) to ensure the JAX fields are processed in a specific sequence. It handles the steps of field extraction (using `pytree_to_fields` with its default `filter=eqx.is_array`), alignment, and conversion, returning the populated JAX PyTree. If you need custom filtering for PyTree leaves, you should use `pytree_to_fields` and `convert` separately.
|
|
87
|
+
|
|
88
|
+
* **Parameters**:
|
|
89
|
+
* `pytree`: The JAX PyTree (e.g., an Equinox model) whose structure is the target.
|
|
90
|
+
* `state_dict`: The PyTorch state dictionary containing the weights.
|
|
91
|
+
* `pytree_model_order` (optional): A list of JAX KeyPath strings (like `'.layers.0.linear.weight'`). If provided, JAX fields will be ordered according to this list. This is useful if the automatic PyTree traversal order doesn't match the `state_dict` order.
|
|
92
|
+
* **Returns**: A new JAX PyTree with the same structure as the input `pytree`, but with weights populated from the `state_dict`.
|
|
93
|
+
|
|
94
|
+
---
|
|
95
|
+
|
|
96
|
+
### `convert`
|
|
97
|
+
|
|
98
|
+
This is the core function that performs the actual conversion once the JAX PyTree fields and PyTorch `state_dict` fields have been extracted and aligned.
|
|
99
|
+
|
|
100
|
+
```python
|
|
101
|
+
def convert(
|
|
102
|
+
state_dict: dict[str, Any],
|
|
103
|
+
pytree: PyTree,
|
|
104
|
+
jaxfields: list[JaxField],
|
|
105
|
+
state_indices: dict | None,
|
|
106
|
+
torchfields: list[TorchField],
|
|
107
|
+
dtype: Any | None = None,
|
|
108
|
+
) -> PyTree:
|
|
109
|
+
...
|
|
110
|
+
```
|
|
111
|
+
|
|
112
|
+
It iterates through the aligned `jaxfields` and `torchfields`, checks for shape compatibility (reshapability), converts PyTorch tensors (expected as values in `state_dict`) to JAX arrays (optionally casting `dtype`), and inserts them into the correct place in the JAX PyTree.
|
|
113
|
+
|
|
114
|
+
* **Parameters**:
|
|
115
|
+
* `state_dict`: The original PyTorch state dictionary. Values are expected to be tensor-like (e.g., `torch.Tensor`).
|
|
116
|
+
* `pytree`: The JAX PyTree that will be populated.
|
|
117
|
+
* `jaxfields`: An ordered list of `JaxField` objects (obtained from `pytree_to_fields`) representing the leaves of the JAX PyTree.
|
|
118
|
+
* `state_indices`: A dictionary mapping state markers to `eqx.nn.StateIndex` objects, used for handling Equinox stateful layers.
|
|
119
|
+
* `torchfields`: An ordered list of `TorchField` objects (obtained from `state_dict_to_fields`) representing the tensors in the PyTorch `state_dict`. **This list must be ordered to match `jaxfields`**.
|
|
120
|
+
* `dtype` (optional): The JAX data type to convert floating-point tensors to (e.g., `jnp.float32`). Defaults to JAX's current default floating-point type.
|
|
121
|
+
* **Returns**: A new JAX PyTree populated with weights from the `state_dict`.
|
|
122
|
+
|
|
123
|
+
---
|
|
124
|
+
|
|
125
|
+
### `pytree_to_fields`
|
|
126
|
+
|
|
127
|
+
This function traverses a JAX PyTree and extracts information about its array leaves based on a filter.
|
|
128
|
+
|
|
129
|
+
```python
|
|
130
|
+
def pytree_to_fields(
|
|
131
|
+
pytree: PyTree,
|
|
132
|
+
model_order: list[str] | None = None,
|
|
133
|
+
filter: Callable[[Array], bool] = eqx.is_array,
|
|
134
|
+
) -> tuple[list[JaxField], dict | None]:
|
|
135
|
+
...
|
|
136
|
+
```
|
|
137
|
+
|
|
138
|
+
It identifies all JAX arrays (or other elements satisfying the `filter`) within the `pytree`, recording their `KeyPath` (path within the PyTree) and shape. If `model_order` is provided, it attempts to reorder the extracted fields according to that list. This is crucial for ensuring the JAX fields align correctly with the PyTorch fields.
|
|
139
|
+
|
|
140
|
+
* **Parameters**:
|
|
141
|
+
* `pytree`: The JAX PyTree to analyze.
|
|
142
|
+
* `model_order` (optional): A list of strings, where each string is a `jax.tree_util.keystr` representation of a `KeyPath` to an array leaf in the `pytree`. If provided, the output `JaxField` list will be sorted according to this order, with any fields not in `model_order` appended at the end.
|
|
143
|
+
* `filter` (optional): A callable that takes a PyTree leaf (e.g., an array) and returns `True` if it should be considered a field to be converted, `False` otherwise. Defaults to `equinox.is_array`.
|
|
144
|
+
* **Returns**: A tuple containing:
|
|
145
|
+
* `list[JaxField]`: A list of `JaxField` objects, each describing a filtered leaf in the PyTree (path, shape).
|
|
146
|
+
* `dict | None`: A dictionary containing information about `eqx.nn.StateIndex` objects found in the PyTree, or `None` if none are found.
|
|
147
|
+
|
|
148
|
+
---
|
|
149
|
+
|
|
150
|
+
### `state_dict_to_fields`
|
|
151
|
+
|
|
152
|
+
This function processes a PyTorch `state_dict` to extract information about its tensors.
|
|
153
|
+
|
|
154
|
+
```python
|
|
155
|
+
def state_dict_to_fields(
|
|
156
|
+
state_dict: dict[str, Any],
|
|
157
|
+
) -> list[TorchField]:
|
|
158
|
+
...
|
|
159
|
+
```
|
|
160
|
+
|
|
161
|
+
It iterates through the `state_dict`, creating a `TorchField` object for each value that has a `shape` attribute and a non-empty shape (typically tensors). This object stores the tensor's name (key in the `state_dict`) and its shape.
|
|
162
|
+
|
|
163
|
+
* **Parameters**:
|
|
164
|
+
* `state_dict`: The PyTorch state dictionary. Values are typically `torch.Tensor` or other array-like objects.
|
|
165
|
+
* **Returns**: A list of `TorchField` objects, each describing a tensor in the `state_dict` (path/key, shape). The order matches the iteration order of the input `state_dict`.
|
|
166
|
+
|
|
167
|
+
---
|
|
168
|
+
|
|
169
|
+
### `move_running_fields_to_the_end`
|
|
170
|
+
|
|
171
|
+
This is an optional utility function to help reorder fields extracted from a PyTorch `state_dict`.
|
|
172
|
+
|
|
173
|
+
```python
|
|
174
|
+
def move_running_fields_to_the_end(
|
|
175
|
+
torchfields: list[TorchField],
|
|
176
|
+
identifier: str = "running_"
|
|
177
|
+
):
|
|
178
|
+
...
|
|
179
|
+
```
|
|
180
|
+
|
|
181
|
+
It's particularly useful for models with layers like `BatchNorm`, where PyTorch often stores `running_mean` and `running_var` interspersed with weights and biases, while Equinox (a common JAX library) typically expects stateful components like these at the end of a layer's parameter list. This function moves any `TorchField` whose path contains the `identifier` (defaulting to `"running_"`) to the end of the list.
|
|
182
|
+
|
|
183
|
+
* **Parameters**:
|
|
184
|
+
* `torchfields`: The list of `TorchField` objects to be reordered.
|
|
185
|
+
* `identifier` (optional): A string that, if found within a `TorchField`'s path, will cause that field to be moved to the end of the list. Default is `"running_"`.
|
|
186
|
+
* **Returns**: The modified list of `TorchField` objects with identified fields moved to the end.
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
site_name: Statedict2Pytree
|
|
2
|
+
|
|
3
|
+
theme:
|
|
4
|
+
name: material
|
|
5
|
+
|
|
6
|
+
plugins:
|
|
7
|
+
- search
|
|
8
|
+
- mkdocstrings:
|
|
9
|
+
default_handler: python
|
|
10
|
+
|
|
11
|
+
nav:
|
|
12
|
+
- Home: index.md
|
|
13
|
+
|
|
14
|
+
markdown_extensions:
|
|
15
|
+
- pymdownx.highlight:
|
|
16
|
+
anchor_linenums: true
|
|
17
|
+
- pymdownx.inlinehilite
|
|
18
|
+
- pymdownx.snippets
|
|
19
|
+
- pymdownx.superfences
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "statedict2pytree"
|
|
3
|
+
version = "1.0.0"
|
|
4
|
+
description = "Converts torch models into PyTrees for Equinox"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = "~=3.10"
|
|
7
|
+
authors = [{ name = "Artur A. Galstyan", email = "mail@arturgalstyan.dev" }]
|
|
8
|
+
dependencies = [
|
|
9
|
+
"jax>=0.6.1",
|
|
10
|
+
"equinox",
|
|
11
|
+
"jaxlib",
|
|
12
|
+
"beartype",
|
|
13
|
+
"jaxtyping",
|
|
14
|
+
"pydantic",
|
|
15
|
+
"tqdm",
|
|
16
|
+
]
|
|
17
|
+
|
|
18
|
+
[project.optional-dependencies]
|
|
19
|
+
dev = ["pre-commit", "pytest", "mkdocs", "torch"]
|
|
20
|
+
examples = ["jaxonmodels"]
|
|
21
|
+
|
|
22
|
+
[build-system]
|
|
23
|
+
requires = ["hatchling"]
|
|
24
|
+
build-backend = "hatchling.build"
|