statedict2pytree 0.1.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,161 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+ node_modules/
@@ -0,0 +1,16 @@
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ # Ruff version.
4
+ rev: v0.4.4
5
+ hooks:
6
+ # Run the linter.
7
+ - id: ruff
8
+ args: [--fix]
9
+ # Run the formatter.
10
+ - id: ruff-format
11
+ - repo: https://github.com/RobertCraigie/pyright-python
12
+ rev: v1.1.351
13
+ hooks:
14
+ - id: pyright
15
+ additional_dependencies:
16
+ [beartype, jax, jaxtyping, pytest, typing_extensions]
@@ -0,0 +1,43 @@
1
+ Metadata-Version: 2.3
2
+ Name: statedict2pytree
3
+ Version: 0.1.2
4
+ Summary: Converts torch models into PyTrees for Equinox
5
+ Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
+ Requires-Python: ~=3.10
7
+ Requires-Dist: beartype
8
+ Requires-Dist: equinox>=0.11.4
9
+ Requires-Dist: flask
10
+ Requires-Dist: jax
11
+ Requires-Dist: jaxlib
12
+ Requires-Dist: jaxtyping
13
+ Requires-Dist: loguru
14
+ Requires-Dist: pydantic
15
+ Requires-Dist: torch
16
+ Requires-Dist: typing-extensions
17
+ Provides-Extra: dev
18
+ Requires-Dist: mkdocs; extra == 'dev'
19
+ Requires-Dist: nox; extra == 'dev'
20
+ Requires-Dist: pre-commit; extra == 'dev'
21
+ Requires-Dist: pytest; extra == 'dev'
22
+ Description-Content-Type: text/markdown
23
+
24
+ # statedict2pytree
25
+
26
+ ![statedict2pytree](torch2jax.png "A ResNet demo")
27
+
28
+ The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
29
+
30
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
31
+
32
+ (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
33
+
34
+ ## Get Started
35
+
36
+ ### Installation
37
+
38
+ Run
39
+
40
+ ```bash
41
+ pip install statedict2pytree
42
+
43
+ ```
@@ -0,0 +1,20 @@
1
+ # statedict2pytree
2
+
3
+ ![statedict2pytree](torch2jax.png "A ResNet demo")
4
+
5
+ The goal of this package is to simplify the conversion from PyTorch models into JAX PyTrees (which can be used e.g. in Equinox). The way this works is by putting both models side by side and aligning the weights in the right order. Then, all statedict2pytree does is iterate over both lists and match the weight matrices.
6
+
7
+ Usually, if you _declared the fields in the same order as in the PyTorch model_, you don't have to rearrange anything -- but the option is there if you need it.
8
+
9
+ (Theoretically, you can rearrange the model in any way you like - e.g. last layer as the first layer - as long as the shapes match!)
10
+
11
+ ## Get Started
12
+
13
+ ### Installation
14
+
15
+ Run
16
+
17
+ ```bash
18
+ pip install statedict2pytree
19
+
20
+ ```
@@ -0,0 +1,16 @@
1
+ import jax
2
+ import statedict2pytree as s2p
3
+ from tests.resnet import resnet50
4
+ from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
5
+
6
+
7
+ def convert_resnet():
8
+ resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
9
+ resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
10
+ state_dict = resnet_torch.state_dict()
11
+
12
+ s2p.start_conversion(resnet_jax, state_dict)
13
+
14
+
15
+ if __name__ == "__main__":
16
+ convert_resnet()
@@ -0,0 +1,378 @@
1
+ import equinox as eqx
2
+ import jax
3
+ from beartype.typing import Optional, Type
4
+ from equinox.nn import State
5
+ from jaxtyping import Array, PRNGKeyArray
6
+
7
+
8
+ def conv3x3(
9
+ in_channels: int,
10
+ out_channels: int,
11
+ stride: int = 1,
12
+ groups: int = 1,
13
+ *,
14
+ key: PRNGKeyArray,
15
+ ) -> eqx.nn.Conv2d:
16
+ return eqx.nn.Conv2d(
17
+ in_channels,
18
+ out_channels,
19
+ kernel_size=3,
20
+ stride=stride,
21
+ padding=1,
22
+ groups=groups,
23
+ use_bias=False,
24
+ key=key,
25
+ )
26
+
27
+
28
+ def conv1x1(
29
+ in_channels: int, out_channels: int, stride: int = 1, *, key: PRNGKeyArray
30
+ ) -> eqx.nn.Conv2d:
31
+ return eqx.nn.Conv2d(
32
+ in_channels, out_channels, kernel_size=1, stride=stride, use_bias=False, key=key
33
+ )
34
+
35
+
36
+ class Downsample(eqx.Module):
37
+ conv: eqx.nn.Conv2d
38
+ norm: eqx.nn.BatchNorm
39
+
40
+ def __init__(
41
+ self, in_channels: int, out_channels: int, stride: int, *, key: PRNGKeyArray
42
+ ) -> None:
43
+ self.conv = conv1x1(in_channels, out_channels, stride, key=key)
44
+ self.norm = eqx.nn.BatchNorm(out_channels, axis_name="batch")
45
+
46
+ def __call__(self, x: Array, state: State) -> tuple[Array, State]:
47
+ x = self.conv(x)
48
+ x, state = self.norm(x, state)
49
+
50
+ return x, state
51
+
52
+
53
+ class BasicBlock(eqx.Module):
54
+ conv1: eqx.nn.Conv2d
55
+ bn1: eqx.nn.BatchNorm
56
+ conv2: eqx.nn.Conv2d
57
+ bn2: eqx.nn.BatchNorm
58
+ downsample: Optional[Downsample]
59
+
60
+ def __init__(
61
+ self,
62
+ in_channels: int,
63
+ out_channels: int,
64
+ stride: int = 1,
65
+ downsample: Optional[Downsample] = None,
66
+ groups: int = 1,
67
+ base_width: int = 64,
68
+ *,
69
+ key: PRNGKeyArray,
70
+ ):
71
+ if groups != 1 or base_width != 64:
72
+ raise ValueError("BasicBlock only supports groups=1 and base_width=64")
73
+ key, conv1_key, conv2_key = jax.random.split(key, 3)
74
+ self.conv1 = conv3x3(in_channels, out_channels, stride, key=conv1_key)
75
+ self.bn1 = eqx.nn.BatchNorm(out_channels, axis_name="batch")
76
+ self.conv2 = conv3x3(out_channels, out_channels, key=conv2_key)
77
+ self.bn2 = eqx.nn.BatchNorm(out_channels, axis_name="batch")
78
+ self.downsample = downsample
79
+
80
+ def __call__(self, x: Array, state: State) -> tuple[Array, State]:
81
+ identity = x
82
+
83
+ out = self.conv1(x)
84
+ out, state = self.bn1(out, state)
85
+ out = jax.nn.relu(out)
86
+
87
+ out = self.conv2(out)
88
+ out, state = self.bn2(out, state)
89
+
90
+ if self.downsample is not None:
91
+ identity, state = self.downsample(x, state)
92
+
93
+ out += identity
94
+ out = jax.nn.relu(out)
95
+
96
+ return out, state
97
+
98
+
99
+ class Bottleneck(eqx.Module):
100
+ conv1: eqx.nn.Conv2d
101
+ bn1: eqx.nn.BatchNorm
102
+ conv2: eqx.nn.Conv2d
103
+ bn2: eqx.nn.BatchNorm
104
+ conv3: eqx.nn.Conv2d
105
+ bn3: eqx.nn.BatchNorm
106
+
107
+ downsample: Optional[Downsample]
108
+
109
+ def __init__(
110
+ self,
111
+ in_channels: int,
112
+ out_channels: int,
113
+ stride: int = 1,
114
+ downsample: Optional[Downsample] = None,
115
+ groups: int = 1,
116
+ base_width: int = 64,
117
+ *,
118
+ key: PRNGKeyArray,
119
+ ) -> None:
120
+ width = int(out_channels * (base_width / 64.0)) * groups
121
+ conv1_key, conv2_key, conv3_key = jax.random.split(key, 3)
122
+ expansion = 4
123
+ self.conv1 = conv1x1(in_channels, width, key=conv1_key)
124
+ self.bn1 = eqx.nn.BatchNorm(width, axis_name="batch")
125
+ self.conv2 = conv3x3(width, width, stride, groups, key=conv2_key)
126
+ self.bn2 = eqx.nn.BatchNorm(width, axis_name="batch")
127
+ self.conv3 = conv1x1(width, out_channels * expansion, key=conv3_key)
128
+ self.bn3 = eqx.nn.BatchNorm(out_channels * expansion, axis_name="batch")
129
+ self.downsample = downsample
130
+
131
+ def __call__(self, x: Array, state: State) -> tuple[Array, State]:
132
+ identity = x
133
+ x = self.conv1(x)
134
+ x, state = self.bn1(x, state)
135
+ x = jax.nn.relu(x)
136
+
137
+ x = self.conv2(x)
138
+ x, state = self.bn2(x, state)
139
+ x = jax.nn.relu(x)
140
+
141
+ x = self.conv3(x)
142
+ x, state = self.bn3(x, state)
143
+
144
+ if self.downsample is not None:
145
+ identity, state = self.downsample(identity, state)
146
+
147
+ x += identity
148
+ x = jax.nn.relu(x)
149
+
150
+ return x, state
151
+
152
+
153
+ class ResnetLayer(eqx.Module):
154
+ layers: list[BasicBlock | Bottleneck]
155
+
156
+ def __init__(self, layers: list[BasicBlock | Bottleneck]) -> None:
157
+ self.layers = layers
158
+
159
+ def __call__(self, x: Array, state: State) -> tuple[Array, State]:
160
+ for l in self.layers:
161
+ x, state = l(x, state)
162
+ return x, state
163
+
164
+
165
+ class ResNet(eqx.Module):
166
+ in_channels: int = eqx.field(static=True)
167
+
168
+ conv1: eqx.nn.Conv2d
169
+ bn1: eqx.nn.BatchNorm
170
+ maxpool: eqx.nn.MaxPool2d
171
+
172
+ layer1: ResnetLayer
173
+ layer2: ResnetLayer
174
+ layer3: ResnetLayer
175
+ layer4: ResnetLayer
176
+
177
+ avgpool: eqx.nn.AdaptiveAvgPool2d
178
+ fc: eqx.nn.Linear
179
+
180
+ def __init__(
181
+ self,
182
+ block: Type[BasicBlock | Bottleneck],
183
+ layers: list[int],
184
+ image_channels: int = 3,
185
+ num_classes: int = 1000,
186
+ groups: int = 1,
187
+ width_per_group: int = 64,
188
+ *,
189
+ key: PRNGKeyArray,
190
+ ) -> None:
191
+ self.in_channels = 64
192
+ key, conv_key = jax.random.split(key)
193
+ self.conv1 = eqx.nn.Conv2d(
194
+ image_channels,
195
+ self.in_channels,
196
+ kernel_size=7,
197
+ stride=2,
198
+ padding=3,
199
+ use_bias=False,
200
+ key=conv_key,
201
+ )
202
+ self.bn1 = eqx.nn.BatchNorm(self.in_channels, axis_name="batch")
203
+ self.maxpool = eqx.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
204
+ key, *layer_keys = jax.random.split(key, len(layers) + 1)
205
+ self.layer1 = self._make_layer(
206
+ block,
207
+ out_channels=64,
208
+ num_residual_blocks=layers[0],
209
+ stride=1,
210
+ groups=groups,
211
+ base_width=width_per_group,
212
+ key=layer_keys[0],
213
+ )
214
+ self.layer2 = self._make_layer(
215
+ block,
216
+ out_channels=128,
217
+ num_residual_blocks=layers[1],
218
+ stride=2,
219
+ groups=groups,
220
+ base_width=width_per_group,
221
+ key=layer_keys[1],
222
+ )
223
+
224
+ self.layer3 = self._make_layer(
225
+ block,
226
+ out_channels=256,
227
+ num_residual_blocks=layers[2],
228
+ stride=2,
229
+ groups=groups,
230
+ base_width=width_per_group,
231
+ key=layer_keys[2],
232
+ )
233
+
234
+ self.layer4 = self._make_layer(
235
+ block,
236
+ out_channels=512,
237
+ num_residual_blocks=layers[3],
238
+ stride=2,
239
+ groups=groups,
240
+ base_width=width_per_group,
241
+ key=layer_keys[3],
242
+ )
243
+
244
+ self.avgpool = eqx.nn.AdaptiveAvgPool2d((1, 1))
245
+ key, fc_key = jax.random.split(key)
246
+ self.fc = eqx.nn.Linear(512 * _get_expansion(block), num_classes, key=fc_key)
247
+
248
+ def __call__(self, x: Array, state: State) -> tuple[Array, State]:
249
+ x = self.conv1(x)
250
+ x, state = self.bn1(x, state)
251
+ x = jax.nn.relu(x)
252
+ x = self.maxpool(x)
253
+
254
+ x, state = self.layer1(x, state)
255
+ x, state = self.layer2(x, state)
256
+ x, state = self.layer3(x, state)
257
+ x, state = self.layer4(x, state)
258
+ x = self.avgpool(x)
259
+ x = x.reshape(-1)
260
+ x = self.fc(x)
261
+ return x, state
262
+
263
+ def _make_layer(
264
+ self,
265
+ block: Type[BasicBlock | Bottleneck],
266
+ out_channels: int,
267
+ num_residual_blocks: int,
268
+ stride: int,
269
+ groups: int,
270
+ base_width: int,
271
+ *,
272
+ key: PRNGKeyArray,
273
+ ):
274
+ downsample = None
275
+ expansion = _get_expansion(block)
276
+ key, downsample_key = jax.random.split(key)
277
+ if stride != 1 or self.in_channels != out_channels * expansion:
278
+ downsample = Downsample(
279
+ self.in_channels, out_channels * expansion, stride, key=downsample_key
280
+ )
281
+ layers = []
282
+ key, *layer_keys = jax.random.split(key, num_residual_blocks + 1)
283
+
284
+ layers.append(
285
+ block(
286
+ self.in_channels,
287
+ out_channels,
288
+ stride,
289
+ downsample,
290
+ groups=groups,
291
+ base_width=base_width,
292
+ key=layer_keys[0],
293
+ )
294
+ )
295
+ self.in_channels = out_channels * expansion
296
+ for i in range(num_residual_blocks - 1):
297
+ layers.append(
298
+ block(
299
+ self.in_channels,
300
+ out_channels,
301
+ groups=groups,
302
+ base_width=base_width,
303
+ key=layer_keys[i + 1],
304
+ )
305
+ )
306
+ return ResnetLayer(layers)
307
+
308
+
309
+ def _get_expansion(block_type: Type[Bottleneck | BasicBlock]) -> int:
310
+ if block_type == Bottleneck:
311
+ return 4
312
+ else:
313
+ return 1
314
+
315
+
316
+ def resnet18(
317
+ image_channels: int = 3,
318
+ num_classes: int = 1000,
319
+ *,
320
+ key: PRNGKeyArray,
321
+ make_with_state: bool = True,
322
+ **kwargs,
323
+ ):
324
+ layers = [2, 2, 2, 2]
325
+ if make_with_state:
326
+ return eqx.nn.make_with_state(ResNet)(
327
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
328
+ )
329
+ else:
330
+ return ResNet(
331
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
332
+ )
333
+
334
+
335
+ def resnet34(
336
+ image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
337
+ ):
338
+ layers = [3, 4, 6, 3]
339
+ return eqx.nn.make_with_state(ResNet)(
340
+ BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
341
+ )
342
+
343
+
344
+ def resnet50(
345
+ image_channels: int = 3,
346
+ num_classes: int = 1000,
347
+ *,
348
+ key: PRNGKeyArray,
349
+ make_with_state: bool = True,
350
+ **kwargs,
351
+ ):
352
+ layers = [3, 4, 6, 3]
353
+ if make_with_state:
354
+ return eqx.nn.make_with_state(ResNet)(
355
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
356
+ )
357
+ else:
358
+ return ResNet(
359
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
360
+ )
361
+
362
+
363
+ def resnet101(
364
+ image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
365
+ ):
366
+ layers = [3, 4, 23, 3]
367
+ return eqx.nn.make_with_state(ResNet)(
368
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
369
+ )
370
+
371
+
372
+ def resnet152(
373
+ image_channels: int = 3, num_classes: int = 1000, *, key: PRNGKeyArray, **kwargs
374
+ ):
375
+ layers = [3, 8, 36, 3]
376
+ return eqx.nn.make_with_state(ResNet)(
377
+ Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
378
+ )
@@ -0,0 +1,68 @@
1
+ import functools as ft
2
+ import json
3
+ import urllib
4
+
5
+ import equinox as eqx
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import torch
9
+ from PIL import Image
10
+ from tests.resnet import resnet50
11
+ from torchvision import transforms
12
+ from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
13
+
14
+
15
+ def test_resnet():
16
+ resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
17
+ resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
18
+
19
+ img_name = "doggo.jpeg"
20
+
21
+ transform = transforms.Compose(
22
+ [
23
+ transforms.Resize(256),
24
+ transforms.CenterCrop(224),
25
+ transforms.ToTensor(),
26
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
27
+ ]
28
+ )
29
+ img = Image.open(img_name)
30
+ img_t = transform(img)
31
+ print(img_t.shape) # pyright: ignore
32
+ batch_t = torch.unsqueeze(img_t, 0) # pyright:ignore
33
+
34
+ # Predict
35
+ with torch.no_grad():
36
+ output = resnet_torch(batch_t)
37
+ print(output.shape) # pyright: ignore
38
+ _, predicted = torch.max(output, 1)
39
+
40
+ print(
41
+ f"Predicted: {predicted.item()}"
42
+ ) # Outputs the ImageNet class index of the prediction
43
+
44
+ url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
45
+ with urllib.request.urlopen(url) as url:
46
+ imagenet_labels = json.loads(url.read().decode())
47
+
48
+ label = imagenet_labels[str(predicted.item())][1]
49
+ print(f"Label for index {predicted.item()}: {label}")
50
+
51
+ identity = lambda x: x
52
+ model_callable = ft.partial(identity, resnet_jax)
53
+ model, state = eqx.nn.make_with_state(model_callable)()
54
+
55
+ model, state = eqx.tree_deserialise_leaves("model.eqx", (model, state))
56
+
57
+ jax_batch = jnp.array(batch_t.numpy())
58
+ out, state = eqx.filter_vmap(
59
+ model, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
60
+ )(jax_batch, state)
61
+ print(f"{out.shape}")
62
+
63
+ label = imagenet_labels[str(jnp.argmax(out))][1]
64
+ print(f"Label for index {jnp.argmax(out)}: {label}")
65
+
66
+
67
+ if __name__ == "__main__":
68
+ test_resnet()