statedict2pytree 0.3.0__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/PKG-INFO +1 -1
- statedict2pytree-0.4.0/examples/convert_resnet.py +20 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/examples/resnet.py +43 -12
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/examples/test_resnet_inference.py +6 -6
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/pyproject.toml +1 -1
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/statedict2pytree.py +4 -1
- statedict2pytree-0.3.0/examples/convert_resnet.py +0 -16
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/.gitignore +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/.pre-commit-config.yaml +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/README.md +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/examples/doggo.jpeg +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/package-lock.json +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/package.json +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/pyrightconfig.json +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/__init__.py +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/static/input.css +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/static/output.css +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/statedict2pytree/templates/index.html +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/tailwind.config.js +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/tests/test_conv.py +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/tests/test_linear.py +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.4.0}/torch2jax.png +0 -0
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import equinox as eqx
|
|
2
|
+
import jax
|
|
3
|
+
import statedict2pytree as s2p
|
|
4
|
+
from resnet import resnet152
|
|
5
|
+
from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def convert_resnet():
|
|
9
|
+
resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
10
|
+
resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
|
|
11
|
+
state_dict = resnet_torch.state_dict()
|
|
12
|
+
|
|
13
|
+
# s2p.start_conversion(resnet_jax, state_dict)
|
|
14
|
+
model, state = s2p.autoconvert(resnet_jax, state_dict)
|
|
15
|
+
name = "resnet152.eqx"
|
|
16
|
+
eqx.tree_serialise_leaves(name, (model, state))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if __name__ == "__main__":
|
|
20
|
+
convert_resnet()
|
|
@@ -333,12 +333,22 @@ def resnet18(
|
|
|
333
333
|
|
|
334
334
|
|
|
335
335
|
def resnet34(
|
|
336
|
-
image_channels: int = 3,
|
|
336
|
+
image_channels: int = 3,
|
|
337
|
+
num_classes: int = 1000,
|
|
338
|
+
*,
|
|
339
|
+
key: PRNGKeyArray,
|
|
340
|
+
make_with_state: bool = True,
|
|
341
|
+
**kwargs,
|
|
337
342
|
):
|
|
338
343
|
layers = [3, 4, 6, 3]
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
344
|
+
if make_with_state:
|
|
345
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
346
|
+
BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
|
|
347
|
+
)
|
|
348
|
+
else:
|
|
349
|
+
return ResNet(
|
|
350
|
+
BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
|
|
351
|
+
)
|
|
342
352
|
|
|
343
353
|
|
|
344
354
|
def resnet50(
|
|
@@ -361,18 +371,39 @@ def resnet50(
|
|
|
361
371
|
|
|
362
372
|
|
|
363
373
|
def resnet101(
|
|
364
|
-
image_channels: int = 3,
|
|
374
|
+
image_channels: int = 3,
|
|
375
|
+
num_classes: int = 1000,
|
|
376
|
+
*,
|
|
377
|
+
key: PRNGKeyArray,
|
|
378
|
+
make_with_state: bool = True,
|
|
379
|
+
**kwargs,
|
|
365
380
|
):
|
|
366
381
|
layers = [3, 4, 23, 3]
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
382
|
+
|
|
383
|
+
if make_with_state:
|
|
384
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
385
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
386
|
+
)
|
|
387
|
+
else:
|
|
388
|
+
return ResNet(
|
|
389
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
390
|
+
)
|
|
370
391
|
|
|
371
392
|
|
|
372
393
|
def resnet152(
|
|
373
|
-
image_channels: int = 3,
|
|
394
|
+
image_channels: int = 3,
|
|
395
|
+
num_classes: int = 1000,
|
|
396
|
+
*,
|
|
397
|
+
key: PRNGKeyArray,
|
|
398
|
+
make_with_state: bool = True,
|
|
399
|
+
**kwargs,
|
|
374
400
|
):
|
|
375
401
|
layers = [3, 8, 36, 3]
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
402
|
+
if make_with_state:
|
|
403
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
404
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
405
|
+
)
|
|
406
|
+
else:
|
|
407
|
+
return ResNet(
|
|
408
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
409
|
+
)
|
|
@@ -6,15 +6,15 @@ import equinox as eqx
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import torch
|
|
9
|
+
from examples.resnet import resnet152
|
|
9
10
|
from PIL import Image
|
|
10
|
-
from tests.resnet import resnet50
|
|
11
11
|
from torchvision import transforms
|
|
12
|
-
from torchvision.models import
|
|
12
|
+
from torchvision.models import resnet152 as t_resnet152, ResNet152_Weights
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def test_resnet():
|
|
16
|
-
resnet_jax =
|
|
17
|
-
resnet_torch =
|
|
16
|
+
resnet_jax = resnet152(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
17
|
+
resnet_torch = t_resnet152(weights=ResNet152_Weights.DEFAULT)
|
|
18
18
|
|
|
19
19
|
img_name = "doggo.jpeg"
|
|
20
20
|
|
|
@@ -42,7 +42,7 @@ def test_resnet():
|
|
|
42
42
|
) # Outputs the ImageNet class index of the prediction
|
|
43
43
|
|
|
44
44
|
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
|
|
45
|
-
with urllib.request.urlopen(url) as url:
|
|
45
|
+
with urllib.request.urlopen(url) as url: # pyright: ignore
|
|
46
46
|
imagenet_labels = json.loads(url.read().decode())
|
|
47
47
|
|
|
48
48
|
label = imagenet_labels[str(predicted.item())][1]
|
|
@@ -52,7 +52,7 @@ def test_resnet():
|
|
|
52
52
|
model_callable = ft.partial(identity, resnet_jax)
|
|
53
53
|
model, state = eqx.nn.make_with_state(model_callable)()
|
|
54
54
|
|
|
55
|
-
model, state = eqx.tree_deserialise_leaves("
|
|
55
|
+
model, state = eqx.tree_deserialise_leaves("resnet152.eqx", (model, state))
|
|
56
56
|
|
|
57
57
|
jax_batch = jnp.array(batch_t.numpy())
|
|
58
58
|
out, state = eqx.filter_vmap(
|
|
@@ -46,7 +46,7 @@ def get_node(
|
|
|
46
46
|
return tree
|
|
47
47
|
else:
|
|
48
48
|
next_target: str = targets[0]
|
|
49
|
-
if bool(re.search(r"\[\d
|
|
49
|
+
if bool(re.search(r"\[\d+\]", next_target)):
|
|
50
50
|
split_index = next_target.rfind("[")
|
|
51
51
|
name, index = next_target[:split_index], next_target[split_index:]
|
|
52
52
|
index = index[1:-1]
|
|
@@ -157,6 +157,9 @@ def index():
|
|
|
157
157
|
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
|
|
158
158
|
jax_fields = pytree_to_fields(pytree)
|
|
159
159
|
torch_fields = state_dict_to_fields(state_dict)
|
|
160
|
+
|
|
161
|
+
for k, v in state_dict.items():
|
|
162
|
+
state_dict[k] = v.numpy()
|
|
160
163
|
return convert(jax_fields, torch_fields, pytree, state_dict)
|
|
161
164
|
|
|
162
165
|
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
import jax
|
|
2
|
-
import statedict2pytree as s2p
|
|
3
|
-
from resnet import resnet50
|
|
4
|
-
from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def convert_resnet():
|
|
8
|
-
resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
9
|
-
resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
|
|
10
|
-
state_dict = resnet_torch.state_dict()
|
|
11
|
-
|
|
12
|
-
s2p.start_conversion(resnet_jax, state_dict)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
if __name__ == "__main__":
|
|
16
|
-
convert_resnet()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|