statedict2pytree 0.1.2__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- statedict2pytree-0.1.2/.gitignore +161 -0
- statedict2pytree-0.1.2/.pre-commit-config.yaml +16 -0
- statedict2pytree-0.1.2/PKG-INFO +43 -0
- statedict2pytree-0.1.2/README.md +20 -0
- statedict2pytree-0.1.2/examples/convert_resnet.py +16 -0
- statedict2pytree-0.1.2/examples/doggo.jpeg +0 -0
- statedict2pytree-0.1.2/examples/resnet.py +378 -0
- statedict2pytree-0.1.2/examples/test_resnet_inference.py +68 -0
- statedict2pytree-0.1.2/package-lock.json +1424 -0
- statedict2pytree-0.1.2/package.json +10 -0
- statedict2pytree-0.1.2/pyproject.toml +48 -0
- statedict2pytree-0.1.2/statedict2pytree/__init__.py +4 -0
- statedict2pytree-0.1.2/statedict2pytree/statedict2pytree.py +194 -0
- statedict2pytree-0.1.2/statedict2pytree/static/input.css +3 -0
- statedict2pytree-0.1.2/statedict2pytree/static/output.css +1734 -0
- statedict2pytree-0.1.2/statedict2pytree/templates/index.html +246 -0
- statedict2pytree-0.1.2/tailwind.config.js +8 -0
- statedict2pytree-0.1.2/torch2jax.png +0 -0
|
@@ -0,0 +1,161 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[cod]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py,cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
#Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# poetry
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
102
|
+
#poetry.lock
|
|
103
|
+
|
|
104
|
+
# pdm
|
|
105
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
106
|
+
#pdm.lock
|
|
107
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
|
108
|
+
# in version control.
|
|
109
|
+
# https://pdm.fming.dev/#use-with-ide
|
|
110
|
+
.pdm.toml
|
|
111
|
+
|
|
112
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
113
|
+
__pypackages__/
|
|
114
|
+
|
|
115
|
+
# Celery stuff
|
|
116
|
+
celerybeat-schedule
|
|
117
|
+
celerybeat.pid
|
|
118
|
+
|
|
119
|
+
# SageMath parsed files
|
|
120
|
+
*.sage.py
|
|
121
|
+
|
|
122
|
+
# Environments
|
|
123
|
+
.env
|
|
124
|
+
.venv
|
|
125
|
+
env/
|
|
126
|
+
venv/
|
|
127
|
+
ENV/
|
|
128
|
+
env.bak/
|
|
129
|
+
venv.bak/
|
|
130
|
+
|
|
131
|
+
# Spyder project settings
|
|
132
|
+
.spyderproject
|
|
133
|
+
.spyproject
|
|
134
|
+
|
|
135
|
+
# Rope project settings
|
|
136
|
+
.ropeproject
|
|
137
|
+
|
|
138
|
+
# mkdocs documentation
|
|
139
|
+
/site
|
|
140
|
+
|
|
141
|
+
# mypy
|
|
142
|
+
.mypy_cache/
|
|
143
|
+
.dmypy.json
|
|
144
|
+
dmypy.json
|
|
145
|
+
|
|
146
|
+
# Pyre type checker
|
|
147
|
+
.pyre/
|
|
148
|
+
|
|
149
|
+
# pytype static type analyzer
|
|
150
|
+
.pytype/
|
|
151
|
+
|
|
152
|
+
# Cython debug symbols
|
|
153
|
+
cython_debug/
|
|
154
|
+
|
|
155
|
+
# PyCharm
|
|
156
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
157
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
158
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
159
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
160
|
+
#.idea/
|
|
161
|
+
node_modules/
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
repos:
|
|
2
|
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
3
|
+
# Ruff version.
|
|
4
|
+
rev: v0.4.4
|
|
5
|
+
hooks:
|
|
6
|
+
# Run the linter.
|
|
7
|
+
- id: ruff
|
|
8
|
+
args: [--fix]
|
|
9
|
+
# Run the formatter.
|
|
10
|
+
- id: ruff-format
|
|
11
|
+
- repo: https://github.com/RobertCraigie/pyright-python
|
|
12
|
+
rev: v1.1.351
|
|
13
|
+
hooks:
|
|
14
|
+
- id: pyright
|
|
15
|
+
additional_dependencies:
|
|
16
|
+
[beartype, jax, jaxtyping, pytest, typing_extensions]
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: statedict2pytree
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Converts torch models into PyTrees for Equinox
|
|
5
|
+
Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
|
|
6
|
+
Requires-Python: ~=3.10
|
|
7
|
+
Requires-Dist: beartype
|
|
8
|
+
Requires-Dist: equinox>=0.11.4
|
|
9
|
+
Requires-Dist: flask
|
|
10
|
+
Requires-Dist: jax
|
|
11
|
+
Requires-Dist: jaxlib
|
|
12
|
+
Requires-Dist: jaxtyping
|
|
13
|
+
Requires-Dist: loguru
|
|
14
|
+
Requires-Dist: pydantic
|
|
15
|
+
Requires-Dist: torch
|
|
16
|
+
Requires-Dist: typing-extensions
|
|
17
|
+
Provides-Extra: dev
|
|
18
|
+
Requires-Dist: mkdocs; extra == 'dev'
|
|
19
|
+
Requires-Dist: nox; extra == 'dev'
|
|
20
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
21
|
+
Requires-Dist: pytest; extra == 'dev'
|
|
22
|
+
Description-Content-Type: text/markdown
|
|
23
|
+
|
|
24
|
+
# statedict2pytree
|
|
25
|
+
|
|
26
|
+

|
|
27
|
+
|
|
28
|
+
The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
|
|
29
|
+
|
|
30
|
+
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
31
|
+
|
|
32
|
+
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
33
|
+
|
|
34
|
+
## Get Started
|
|
35
|
+
|
|
36
|
+
### Installation
|
|
37
|
+
|
|
38
|
+
Run
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
pip install statedict2pytree
|
|
42
|
+
|
|
43
|
+
```
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# statedict2pytree
|
|
2
|
+
|
|
3
|
+

|
|
4
|
+
|
|
5
|
+
The goal of this package is to simplify the conversion of PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). It works by putting both models side by side and aligning the weights in the right order. Then all statedict2pytree does is iterate over both lists and match the weight matrices.
|
|
6
|
+
|
|
7
|
+
Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
|
|
8
|
+
|
|
9
|
+
(Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
|
|
10
|
+
|
|
11
|
+
## Get Started
|
|
12
|
+
|
|
13
|
+
### Installation
|
|
14
|
+
|
|
15
|
+
Run
|
|
16
|
+
|
|
17
|
+
```bash
|
|
18
|
+
pip install statedict2pytree
|
|
19
|
+
|
|
20
|
+
```
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import statedict2pytree as s2p
|
|
3
|
+
from tests.resnet import resnet50
|
|
4
|
+
from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def convert_resnet():
    """Start the statedict2pytree conversion for a pretrained ResNet-50.

    Builds the Equinox ResNet-50 (requested with ``make_with_state=False``)
    and passes it to ``s2p.start_conversion`` together with the state dict
    of torchvision's ResNet-50 loaded with its default weights.
    """
    jax_model = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
    torch_model = t_resnet50(weights=ResNet50_Weights.DEFAULT)

    s2p.start_conversion(jax_model, torch_model.state_dict())


if __name__ == "__main__":
    convert_resnet()
|
|
Binary file
|
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
import equinox as eqx
|
|
2
|
+
import jax
|
|
3
|
+
from beartype.typing import Optional, Type
|
|
4
|
+
from equinox.nn import State
|
|
5
|
+
from jaxtyping import Array, PRNGKeyArray
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def conv3x3(
    in_channels: int,
    out_channels: int,
    stride: int = 1,
    groups: int = 1,
    *,
    key: PRNGKeyArray,
) -> eqx.nn.Conv2d:
    """Build a bias-free 3x3 ``eqx.nn.Conv2d`` with padding 1.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        groups: Number of convolution groups.
        key: PRNG key forwarded to the convolution for initialisation.

    Returns:
        The configured ``eqx.nn.Conv2d``.
    """
    conv_options = dict(
        kernel_size=3,
        stride=stride,
        padding=1,
        groups=groups,
        use_bias=False,
    )
    return eqx.nn.Conv2d(in_channels, out_channels, key=key, **conv_options)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def conv1x1(
    in_channels: int, out_channels: int, stride: int = 1, *, key: PRNGKeyArray
) -> eqx.nn.Conv2d:
    """Build a bias-free 1x1 ``eqx.nn.Conv2d``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        stride: Convolution stride.
        key: PRNG key forwarded to the convolution for initialisation.

    Returns:
        The configured ``eqx.nn.Conv2d``.
    """
    conv_options = dict(kernel_size=1, stride=stride, use_bias=False)
    return eqx.nn.Conv2d(in_channels, out_channels, key=key, **conv_options)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class Downsample(eqx.Module):
    """A strided 1x1 convolution followed by batch normalisation."""

    conv: eqx.nn.Conv2d
    norm: eqx.nn.BatchNorm

    def __init__(
        self, in_channels: int, out_channels: int, stride: int, *, key: PRNGKeyArray
    ) -> None:
        """Create the 1x1 convolution and its batch norm.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            stride: Stride of the 1x1 convolution.
            key: PRNG key used to initialise the convolution.
        """
        self.conv = conv1x1(in_channels, out_channels, stride, key=key)
        self.norm = eqx.nn.BatchNorm(out_channels, axis_name="batch")

    def __call__(self, x: Array, state: State) -> tuple[Array, State]:
        """Apply the convolution, then the batch norm; return output and new state."""
        projected = self.conv(x)
        normalised, new_state = self.norm(projected, state)
        return normalised, new_state
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class BasicBlock(eqx.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    Only ``groups=1`` and ``base_width=64`` are accepted.
    """

    conv1: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    conv2: eqx.nn.Conv2d
    bn2: eqx.nn.BatchNorm
    downsample: Optional[Downsample]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        downsample: Optional[Downsample] = None,
        groups: int = 1,
        base_width: int = 64,
        *,
        key: PRNGKeyArray,
    ):
        """Create the block's layers.

        Args:
            in_channels: Channels of the block input.
            out_channels: Channels produced by both convolutions.
            stride: Stride of the first convolution.
            downsample: Optional module applied to the skip connection.
            groups: Must be 1.
            base_width: Must be 64.
            key: PRNG key used to initialise the convolutions.

        Raises:
            ValueError: If ``groups`` is not 1 or ``base_width`` is not 64.
        """
        unsupported = groups != 1 or base_width != 64
        if unsupported:
            raise ValueError("BasicBlock only supports groups=1 and base_width=64")

        # Three-way split kept as-is (first sub-key unused) so the derived
        # conv keys stay the same.
        _, first_key, second_key = jax.random.split(key, 3)

        self.conv1 = conv3x3(in_channels, out_channels, stride, key=first_key)
        self.bn1 = eqx.nn.BatchNorm(out_channels, axis_name="batch")
        self.conv2 = conv3x3(out_channels, out_channels, key=second_key)
        self.bn2 = eqx.nn.BatchNorm(out_channels, axis_name="batch")
        self.downsample = downsample

    def __call__(self, x: Array, state: State) -> tuple[Array, State]:
        """Run the block on ``x``; return the activated output and updated state."""
        hidden, state = self.bn1(self.conv1(x), state)
        hidden = jax.nn.relu(hidden)

        hidden, state = self.bn2(self.conv2(hidden), state)

        # The skip path is the raw input unless a downsample module is set.
        if self.downsample is None:
            shortcut = x
        else:
            shortcut, state = self.downsample(x, state)

        return jax.nn.relu(hidden + shortcut), state
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class Bottleneck(eqx.Module):
    """Residual block: 1x1 conv, 3x3 conv, 1x1 conv, each followed by batch norm.

    The final convolution outputs ``out_channels * 4`` channels.
    """

    conv1: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    conv2: eqx.nn.Conv2d
    bn2: eqx.nn.BatchNorm
    conv3: eqx.nn.Conv2d
    bn3: eqx.nn.BatchNorm

    downsample: Optional[Downsample]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        downsample: Optional[Downsample] = None,
        groups: int = 1,
        base_width: int = 64,
        *,
        key: PRNGKeyArray,
    ) -> None:
        """Create the block's layers.

        Args:
            in_channels: Channels of the block input.
            out_channels: Base channel count; the block outputs four times this.
            stride: Stride of the middle 3x3 convolution.
            downsample: Optional module applied to the skip connection.
            groups: Groups of the middle 3x3 convolution.
            base_width: Scales the inner width relative to 64.
            key: PRNG key used to initialise the convolutions.
        """
        expansion = 4
        expanded_channels = out_channels * expansion
        inner_width = int(out_channels * (base_width / 64.0)) * groups

        reduce_key, spatial_key, expand_key = jax.random.split(key, 3)

        self.conv1 = conv1x1(in_channels, inner_width, key=reduce_key)
        self.bn1 = eqx.nn.BatchNorm(inner_width, axis_name="batch")
        self.conv2 = conv3x3(inner_width, inner_width, stride, groups, key=spatial_key)
        self.bn2 = eqx.nn.BatchNorm(inner_width, axis_name="batch")
        self.conv3 = conv1x1(inner_width, expanded_channels, key=expand_key)
        self.bn3 = eqx.nn.BatchNorm(expanded_channels, axis_name="batch")
        self.downsample = downsample

    def __call__(self, x: Array, state: State) -> tuple[Array, State]:
        """Run the block on ``x``; return the activated output and updated state."""
        hidden, state = self.bn1(self.conv1(x), state)
        hidden = jax.nn.relu(hidden)

        hidden, state = self.bn2(self.conv2(hidden), state)
        hidden = jax.nn.relu(hidden)

        hidden, state = self.bn3(self.conv3(hidden), state)

        # The skip path is the raw input unless a downsample module is set.
        if self.downsample is None:
            shortcut = x
        else:
            shortcut, state = self.downsample(x, state)

        return jax.nn.relu(hidden + shortcut), state
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
class ResnetLayer(eqx.Module):
    """A sequence of residual blocks applied one after another."""

    layers: list[BasicBlock | Bottleneck]

    def __init__(self, layers: list[BasicBlock | Bottleneck]) -> None:
        """Store the residual blocks that make up this layer."""
        self.layers = layers

    def __call__(self, x: Array, state: State) -> tuple[Array, State]:
        """Feed ``x`` and ``state`` through every block in order."""
        for residual_block in self.layers:
            x, state = residual_block(x, state)
        return x, state
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
class ResNet(eqx.Module):
    """ResNet built from a 7x7 stem, four residual stages, pooling and a linear head."""

    in_channels: int = eqx.field(static=True)

    conv1: eqx.nn.Conv2d
    bn1: eqx.nn.BatchNorm
    maxpool: eqx.nn.MaxPool2d

    layer1: ResnetLayer
    layer2: ResnetLayer
    layer3: ResnetLayer
    layer4: ResnetLayer

    avgpool: eqx.nn.AdaptiveAvgPool2d
    fc: eqx.nn.Linear

    def __init__(
        self,
        block: Type[BasicBlock | Bottleneck],
        layers: list[int],
        image_channels: int = 3,
        num_classes: int = 1000,
        groups: int = 1,
        width_per_group: int = 64,
        *,
        key: PRNGKeyArray,
    ) -> None:
        """Create the network.

        Args:
            block: Residual block class used in every stage.
            layers: Number of residual blocks per stage; entries 0-3 are read.
            image_channels: Channels of the input image.
            num_classes: Output size of the final linear layer.
            groups: Forwarded to every residual block.
            width_per_group: Forwarded to every residual block as ``base_width``.
            key: PRNG key from which all sub-keys are derived.
        """
        self.in_channels = 64

        key, stem_key = jax.random.split(key)
        self.conv1 = eqx.nn.Conv2d(
            image_channels,
            self.in_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            use_bias=False,
            key=stem_key,
        )
        self.bn1 = eqx.nn.BatchNorm(self.in_channels, axis_name="batch")
        self.maxpool = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        key, *stage_keys = jax.random.split(key, len(layers) + 1)

        def build_stage(index: int, out_channels: int, stride: int) -> ResnetLayer:
            # Shared arguments for the four stages; only width, stride and
            # the per-stage block count / key differ.
            return self._make_layer(
                block,
                out_channels=out_channels,
                num_residual_blocks=layers[index],
                stride=stride,
                groups=groups,
                base_width=width_per_group,
                key=stage_keys[index],
            )

        # Order matters: each call updates ``self.in_channels`` for the next.
        self.layer1 = build_stage(0, 64, 1)
        self.layer2 = build_stage(1, 128, 2)
        self.layer3 = build_stage(2, 256, 2)
        self.layer4 = build_stage(3, 512, 2)

        self.avgpool = eqx.nn.AdaptiveAvgPool2d((1, 1))
        key, head_key = jax.random.split(key)
        head_inputs = 512 * _get_expansion(block)
        self.fc = eqx.nn.Linear(head_inputs, num_classes, key=head_key)

    def __call__(self, x: Array, state: State) -> tuple[Array, State]:
        """Run the network on ``x``; return the linear head's output and the new state."""
        x = self.conv1(x)
        x, state = self.bn1(x, state)
        x = self.maxpool(jax.nn.relu(x))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x, state = stage(x, state)

        # Flatten the pooled features before the linear head.
        pooled = self.avgpool(x).reshape(-1)
        return self.fc(pooled), state

    def _make_layer(
        self,
        block: Type[BasicBlock | Bottleneck],
        out_channels: int,
        num_residual_blocks: int,
        stride: int,
        groups: int,
        base_width: int,
        *,
        key: PRNGKeyArray,
    ):
        """Build one stage of ``num_residual_blocks`` blocks.

        The first block receives ``stride`` and, when the stride is not 1 or
        the channel count changes, a ``Downsample`` for its skip connection.
        ``self.in_channels`` is updated to the stage's output channel count.
        """
        expanded_channels = out_channels * _get_expansion(block)

        key, downsample_key = jax.random.split(key)
        needs_projection = stride != 1 or self.in_channels != expanded_channels
        downsample = (
            Downsample(self.in_channels, expanded_channels, stride, key=downsample_key)
            if needs_projection
            else None
        )

        key, *block_keys = jax.random.split(key, num_residual_blocks + 1)

        blocks = [
            block(
                self.in_channels,
                out_channels,
                stride,
                downsample,
                groups=groups,
                base_width=base_width,
                key=block_keys[0],
            )
        ]
        self.in_channels = expanded_channels

        # Remaining blocks use the default stride and no downsample.
        blocks.extend(
            block(
                self.in_channels,
                out_channels,
                groups=groups,
                base_width=base_width,
                key=block_key,
            )
            for block_key in block_keys[1:]
        )
        return ResnetLayer(blocks)
|
|
307
|
+
|
|
308
|
+
|
|
309
|
+
def _get_expansion(block_type: Type[Bottleneck | BasicBlock]) -> int:
    """Return the channel expansion factor: 4 for ``Bottleneck``, otherwise 1."""
    return 4 if block_type == Bottleneck else 1
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def resnet18(
    image_channels: int = 3,
    num_classes: int = 1000,
    *,
    key: PRNGKeyArray,
    make_with_state: bool = True,
    **kwargs,
):
    """Build a ResNet from ``BasicBlock`` with stage sizes ``[2, 2, 2, 2]``.

    Args:
        image_channels: Channels of the input image.
        num_classes: Output size of the final linear layer.
        key: PRNG key used to initialise the model.
        make_with_state: If true, construct via ``eqx.nn.make_with_state``;
            otherwise call ``ResNet`` directly.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        Whatever the selected constructor returns.
    """
    constructor = eqx.nn.make_with_state(ResNet) if make_with_state else ResNet
    stage_sizes = [2, 2, 2, 2]
    return constructor(
        BasicBlock, stage_sizes, image_channels, num_classes, **kwargs, key=key
    )
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def resnet34(
    image_channels: int = 3,
    num_classes: int = 1000,
    *,
    key: PRNGKeyArray,
    make_with_state: bool = True,
    **kwargs,
):
    """Build a ResNet from ``BasicBlock`` with stage sizes ``[3, 4, 6, 3]``.

    Args:
        image_channels: Channels of the input image.
        num_classes: Output size of the final linear layer.
        key: PRNG key used to initialise the model.
        make_with_state: If true (the default, matching the previous
            behaviour), construct via ``eqx.nn.make_with_state``; otherwise
            call ``ResNet`` directly. Mirrors ``resnet18``/``resnet50``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        Whatever the selected constructor returns.
    """
    layers = [3, 4, 6, 3]
    constructor = eqx.nn.make_with_state(ResNet) if make_with_state else ResNet
    return constructor(
        BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
    )
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
def resnet50(
    image_channels: int = 3,
    num_classes: int = 1000,
    *,
    key: PRNGKeyArray,
    make_with_state: bool = True,
    **kwargs,
):
    """Build a ResNet from ``Bottleneck`` with stage sizes ``[3, 4, 6, 3]``.

    Args:
        image_channels: Channels of the input image.
        num_classes: Output size of the final linear layer.
        key: PRNG key used to initialise the model.
        make_with_state: If true, construct via ``eqx.nn.make_with_state``;
            otherwise call ``ResNet`` directly.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        Whatever the selected constructor returns.
    """
    constructor = eqx.nn.make_with_state(ResNet) if make_with_state else ResNet
    stage_sizes = [3, 4, 6, 3]
    return constructor(
        Bottleneck, stage_sizes, image_channels, num_classes, **kwargs, key=key
    )
|
|
361
|
+
|
|
362
|
+
|
|
363
|
+
def resnet101(
    image_channels: int = 3,
    num_classes: int = 1000,
    *,
    key: PRNGKeyArray,
    make_with_state: bool = True,
    **kwargs,
):
    """Build a ResNet from ``Bottleneck`` with stage sizes ``[3, 4, 23, 3]``.

    Args:
        image_channels: Channels of the input image.
        num_classes: Output size of the final linear layer.
        key: PRNG key used to initialise the model.
        make_with_state: If true (the default, matching the previous
            behaviour), construct via ``eqx.nn.make_with_state``; otherwise
            call ``ResNet`` directly. Mirrors ``resnet18``/``resnet50``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        Whatever the selected constructor returns.
    """
    layers = [3, 4, 23, 3]
    constructor = eqx.nn.make_with_state(ResNet) if make_with_state else ResNet
    return constructor(
        Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
    )
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def resnet152(
    image_channels: int = 3,
    num_classes: int = 1000,
    *,
    key: PRNGKeyArray,
    make_with_state: bool = True,
    **kwargs,
):
    """Build a ResNet from ``Bottleneck`` with stage sizes ``[3, 8, 36, 3]``.

    Args:
        image_channels: Channels of the input image.
        num_classes: Output size of the final linear layer.
        key: PRNG key used to initialise the model.
        make_with_state: If true (the default, matching the previous
            behaviour), construct via ``eqx.nn.make_with_state``; otherwise
            call ``ResNet`` directly. Mirrors ``resnet18``/``resnet50``.
        **kwargs: Extra keyword arguments forwarded to ``ResNet``.

    Returns:
        Whatever the selected constructor returns.
    """
    layers = [3, 8, 36, 3]
    constructor = eqx.nn.make_with_state(ResNet) if make_with_state else ResNet
    return constructor(
        Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
    )
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import functools as ft
|
|
2
|
+
import json
|
|
3
|
+
import urllib
|
|
4
|
+
|
|
5
|
+
import equinox as eqx
|
|
6
|
+
import jax
|
|
7
|
+
import jax.numpy as jnp
|
|
8
|
+
import torch
|
|
9
|
+
from PIL import Image
|
|
10
|
+
from tests.resnet import resnet50
|
|
11
|
+
from torchvision import transforms
|
|
12
|
+
from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def test_resnet():
    """Classify ``doggo.jpeg`` with the torch and the Equinox ResNet-50 and print both labels.

    The torch model uses torchvision's default pretrained weights. The
    Equinox model's leaves are read from ``model.eqx`` in the working
    directory. ImageNet class names are downloaded over the network.
    """
    # Imported here because a bare ``import urllib`` does not guarantee that
    # the ``urllib.request`` submodule is loaded.
    import urllib.request

    resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
    resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
    # torchvision models start in training mode; switch to eval so batch norm
    # uses its stored running statistics for inference.
    resnet_torch.eval()

    img_name = "doggo.jpeg"

    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Context manager closes the image file once it has been converted.
    with Image.open(img_name) as img:
        img_t = transform(img)
    print(img_t.shape)  # pyright: ignore
    batch_t = torch.unsqueeze(img_t, 0)  # pyright:ignore

    # Predict
    with torch.no_grad():
        output = resnet_torch(batch_t)
    print(output.shape)  # pyright: ignore
    _, predicted = torch.max(output, 1)
    torch_index = predicted.item()

    print(
        f"Predicted: {torch_index}"
    )  # Outputs the ImageNet class index of the prediction

    labels_url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
    with urllib.request.urlopen(labels_url) as response:
        imagenet_labels = json.loads(response.read().decode())

    label = imagenet_labels[str(torch_index)][1]
    print(f"Label for index {torch_index}: {label}")

    # make_with_state expects a factory, so wrap the already-built model.
    model, state = eqx.nn.make_with_state(lambda: resnet_jax)()

    model, state = eqx.tree_deserialise_leaves("model.eqx", (model, state))
    # NOTE(review): the Equinox model is not switched to inference mode here;
    # confirm whether that is intended for this comparison.

    jax_batch = jnp.array(batch_t.numpy())
    out, state = eqx.filter_vmap(
        model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
    )(jax_batch, state)
    print(f"{out.shape}")

    jax_index = int(jnp.argmax(out))
    label = imagenet_labels[str(jax_index)][1]
    print(f"Label for index {jax_index}: {label}")


if __name__ == "__main__":
    test_resnet()
|