statedict2pytree 0.3.0__tar.gz → 0.5.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/PKG-INFO +1 -1
- statedict2pytree-0.5.0/examples/convert_resnet.py +19 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/examples/resnet.py +43 -12
- statedict2pytree-0.3.0/examples/test_resnet_inference.py → statedict2pytree-0.5.0/examples/resnet_inference.py +5 -5
- statedict2pytree-0.5.0/package.json +9 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/pyproject.toml +1 -1
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/statedict2pytree/statedict2pytree.py +5 -1
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/statedict2pytree/static/output.css +87 -279
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/statedict2pytree/templates/index.html +106 -44
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/tailwind.config.js +1 -1
- statedict2pytree-0.5.0/tests/test_batchnorm.py +46 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/tests/test_conv.py +0 -4
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/tests/test_linear.py +0 -4
- statedict2pytree-0.3.0/examples/convert_resnet.py +0 -16
- statedict2pytree-0.3.0/package.json +0 -10
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/.gitignore +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/.pre-commit-config.yaml +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/README.md +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/examples/doggo.jpeg +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/package-lock.json +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/pyrightconfig.json +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/statedict2pytree/__init__.py +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/statedict2pytree/static/input.css +0 -0
- {statedict2pytree-0.3.0 → statedict2pytree-0.5.0}/torch2jax.png +0 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
import jax
|
|
2
|
+
import statedict2pytree as s2p
|
|
3
|
+
from resnet import resnet18
|
|
4
|
+
from torchvision.models import resnet18 as t_resnet18, ResNet18_Weights
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def convert_resnet():
|
|
8
|
+
resnet_jax = resnet18(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
9
|
+
resnet_torch = t_resnet18(weights=ResNet18_Weights.DEFAULT)
|
|
10
|
+
state_dict = resnet_torch.state_dict()
|
|
11
|
+
|
|
12
|
+
s2p.start_conversion(resnet_jax, state_dict)
|
|
13
|
+
# model, state = s2p.autoconvert(resnet_jax, state_dict)
|
|
14
|
+
# name = "resnet18.eqx"
|
|
15
|
+
# eqx.tree_serialise_leaves(name, (model, state))
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
if __name__ == "__main__":
|
|
19
|
+
convert_resnet()
|
|
@@ -333,12 +333,22 @@ def resnet18(
|
|
|
333
333
|
|
|
334
334
|
|
|
335
335
|
def resnet34(
|
|
336
|
-
image_channels: int = 3,
|
|
336
|
+
image_channels: int = 3,
|
|
337
|
+
num_classes: int = 1000,
|
|
338
|
+
*,
|
|
339
|
+
key: PRNGKeyArray,
|
|
340
|
+
make_with_state: bool = True,
|
|
341
|
+
**kwargs,
|
|
337
342
|
):
|
|
338
343
|
layers = [3, 4, 6, 3]
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
344
|
+
if make_with_state:
|
|
345
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
346
|
+
BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
|
|
347
|
+
)
|
|
348
|
+
else:
|
|
349
|
+
return ResNet(
|
|
350
|
+
BasicBlock, layers, image_channels, num_classes, **kwargs, key=key
|
|
351
|
+
)
|
|
342
352
|
|
|
343
353
|
|
|
344
354
|
def resnet50(
|
|
@@ -361,18 +371,39 @@ def resnet50(
|
|
|
361
371
|
|
|
362
372
|
|
|
363
373
|
def resnet101(
|
|
364
|
-
image_channels: int = 3,
|
|
374
|
+
image_channels: int = 3,
|
|
375
|
+
num_classes: int = 1000,
|
|
376
|
+
*,
|
|
377
|
+
key: PRNGKeyArray,
|
|
378
|
+
make_with_state: bool = True,
|
|
379
|
+
**kwargs,
|
|
365
380
|
):
|
|
366
381
|
layers = [3, 4, 23, 3]
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
382
|
+
|
|
383
|
+
if make_with_state:
|
|
384
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
385
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
386
|
+
)
|
|
387
|
+
else:
|
|
388
|
+
return ResNet(
|
|
389
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
390
|
+
)
|
|
370
391
|
|
|
371
392
|
|
|
372
393
|
def resnet152(
|
|
373
|
-
image_channels: int = 3,
|
|
394
|
+
image_channels: int = 3,
|
|
395
|
+
num_classes: int = 1000,
|
|
396
|
+
*,
|
|
397
|
+
key: PRNGKeyArray,
|
|
398
|
+
make_with_state: bool = True,
|
|
399
|
+
**kwargs,
|
|
374
400
|
):
|
|
375
401
|
layers = [3, 8, 36, 3]
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
402
|
+
if make_with_state:
|
|
403
|
+
return eqx.nn.make_with_state(ResNet)(
|
|
404
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
405
|
+
)
|
|
406
|
+
else:
|
|
407
|
+
return ResNet(
|
|
408
|
+
Bottleneck, layers, image_channels, num_classes, **kwargs, key=key
|
|
409
|
+
)
|
|
@@ -6,15 +6,15 @@ import equinox as eqx
|
|
|
6
6
|
import jax
|
|
7
7
|
import jax.numpy as jnp
|
|
8
8
|
import torch
|
|
9
|
+
from examples.resnet import resnet18
|
|
9
10
|
from PIL import Image
|
|
10
|
-
from tests.resnet import resnet50
|
|
11
11
|
from torchvision import transforms
|
|
12
|
-
from torchvision.models import
|
|
12
|
+
from torchvision.models import resnet18 as t_resnet18, ResNet18_Weights
|
|
13
13
|
|
|
14
14
|
|
|
15
15
|
def test_resnet():
|
|
16
|
-
resnet_jax =
|
|
17
|
-
resnet_torch =
|
|
16
|
+
resnet_jax = resnet18(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
17
|
+
resnet_torch = t_resnet18(weights=ResNet18_Weights.DEFAULT)
|
|
18
18
|
|
|
19
19
|
img_name = "doggo.jpeg"
|
|
20
20
|
|
|
@@ -42,7 +42,7 @@ def test_resnet():
|
|
|
42
42
|
) # Outputs the ImageNet class index of the prediction
|
|
43
43
|
|
|
44
44
|
url = "https://storage.googleapis.com/download.tensorflow.org/data/imagenet_class_index.json"
|
|
45
|
-
with urllib.request.urlopen(url) as url:
|
|
45
|
+
with urllib.request.urlopen(url) as url: # pyright: ignore
|
|
46
46
|
imagenet_labels = json.loads(url.read().decode())
|
|
47
47
|
|
|
48
48
|
label = imagenet_labels[str(predicted.item())][1]
|
|
@@ -46,7 +46,7 @@ def get_node(
|
|
|
46
46
|
return tree
|
|
47
47
|
else:
|
|
48
48
|
next_target: str = targets[0]
|
|
49
|
-
if bool(re.search(r"\[\d
|
|
49
|
+
if bool(re.search(r"\[\d+\]", next_target)):
|
|
50
50
|
split_index = next_target.rfind("[")
|
|
51
51
|
name, index = next_target[:split_index], next_target[split_index:]
|
|
52
52
|
index = index[1:-1]
|
|
@@ -157,6 +157,9 @@ def index():
|
|
|
157
157
|
def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
|
|
158
158
|
jax_fields = pytree_to_fields(pytree)
|
|
159
159
|
torch_fields = state_dict_to_fields(state_dict)
|
|
160
|
+
|
|
161
|
+
for k, v in state_dict.items():
|
|
162
|
+
state_dict[k] = v.numpy()
|
|
160
163
|
return convert(jax_fields, torch_fields, pytree, state_dict)
|
|
161
164
|
|
|
162
165
|
|
|
@@ -212,4 +215,5 @@ def start_conversion(pytree: PyTree, state_dict: dict):
|
|
|
212
215
|
|
|
213
216
|
for k, v in STATE_DICT.items():
|
|
214
217
|
STATE_DICT[k] = v.numpy()
|
|
218
|
+
app.jinja_env.globals.update(enumerate=enumerate)
|
|
215
219
|
app.run(debug=True, port=5500)
|
|
@@ -784,23 +784,42 @@ html {
|
|
|
784
784
|
}
|
|
785
785
|
}
|
|
786
786
|
|
|
787
|
-
.
|
|
788
|
-
display:
|
|
787
|
+
.alert {
|
|
788
|
+
display: grid;
|
|
789
|
+
width: 100%;
|
|
790
|
+
grid-auto-flow: row;
|
|
791
|
+
align-content: flex-start;
|
|
789
792
|
align-items: center;
|
|
790
|
-
justify-
|
|
793
|
+
justify-items: center;
|
|
794
|
+
gap: 1rem;
|
|
795
|
+
text-align: center;
|
|
796
|
+
border-radius: var(--rounded-box, 1rem);
|
|
797
|
+
border-width: 1px;
|
|
798
|
+
--tw-border-opacity: 1;
|
|
799
|
+
border-color: var(--fallback-b2,oklch(var(--b2)/var(--tw-border-opacity)));
|
|
800
|
+
padding: 1rem;
|
|
801
|
+
--tw-text-opacity: 1;
|
|
802
|
+
color: var(--fallback-bc,oklch(var(--bc)/var(--tw-text-opacity)));
|
|
803
|
+
--alert-bg: var(--fallback-b2,oklch(var(--b2)/1));
|
|
804
|
+
--alert-bg-mix: var(--fallback-b1,oklch(var(--b1)/1));
|
|
805
|
+
background-color: var(--alert-bg);
|
|
791
806
|
}
|
|
792
807
|
|
|
793
|
-
@media (
|
|
794
|
-
.
|
|
795
|
-
|
|
796
|
-
|
|
797
|
-
|
|
798
|
-
|
|
799
|
-
--tw-text-opacity: 1;
|
|
800
|
-
color: var(--fallback-nc,oklch(var(--nc)/var(--tw-text-opacity)));
|
|
808
|
+
@media (min-width: 640px) {
|
|
809
|
+
.alert {
|
|
810
|
+
grid-auto-flow: column;
|
|
811
|
+
grid-template-columns: auto minmax(auto,1fr);
|
|
812
|
+
justify-items: start;
|
|
813
|
+
text-align: start;
|
|
801
814
|
}
|
|
802
815
|
}
|
|
803
816
|
|
|
817
|
+
.avatar.placeholder > div {
|
|
818
|
+
display: flex;
|
|
819
|
+
align-items: center;
|
|
820
|
+
justify-content: center;
|
|
821
|
+
}
|
|
822
|
+
|
|
804
823
|
.btn {
|
|
805
824
|
display: inline-flex;
|
|
806
825
|
height: 3rem;
|
|
@@ -934,18 +953,6 @@ html {
|
|
|
934
953
|
border-color: color-mix(in oklab, var(--fallback-p,oklch(var(--p)/1)) 90%, black);
|
|
935
954
|
}
|
|
936
955
|
}
|
|
937
|
-
|
|
938
|
-
:where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(.active, .btn):hover, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(.active, .btn):hover {
|
|
939
|
-
cursor: pointer;
|
|
940
|
-
outline: 2px solid transparent;
|
|
941
|
-
outline-offset: 2px;
|
|
942
|
-
}
|
|
943
|
-
|
|
944
|
-
@supports (color: oklch(0% 0 0)) {
|
|
945
|
-
:where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(.active, .btn):hover, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(.active, .btn):hover {
|
|
946
|
-
background-color: var(--fallback-bc,oklch(var(--bc)/0.1));
|
|
947
|
-
}
|
|
948
|
-
}
|
|
949
956
|
}
|
|
950
957
|
|
|
951
958
|
.input {
|
|
@@ -978,59 +985,6 @@ html {
|
|
|
978
985
|
text-decoration-line: underline;
|
|
979
986
|
}
|
|
980
987
|
|
|
981
|
-
.menu {
|
|
982
|
-
display: flex;
|
|
983
|
-
flex-direction: column;
|
|
984
|
-
flex-wrap: wrap;
|
|
985
|
-
font-size: 0.875rem;
|
|
986
|
-
line-height: 1.25rem;
|
|
987
|
-
padding: 0.5rem;
|
|
988
|
-
}
|
|
989
|
-
|
|
990
|
-
.menu :where(li ul) {
|
|
991
|
-
position: relative;
|
|
992
|
-
white-space: nowrap;
|
|
993
|
-
margin-inline-start: 1rem;
|
|
994
|
-
padding-inline-start: 0.5rem;
|
|
995
|
-
}
|
|
996
|
-
|
|
997
|
-
.menu :where(li:not(.menu-title) > *:not(ul, details, .menu-title, .btn)), .menu :where(li:not(.menu-title) > details > summary:not(.menu-title)) {
|
|
998
|
-
display: grid;
|
|
999
|
-
grid-auto-flow: column;
|
|
1000
|
-
align-content: flex-start;
|
|
1001
|
-
align-items: center;
|
|
1002
|
-
gap: 0.5rem;
|
|
1003
|
-
grid-auto-columns: minmax(auto, max-content) auto max-content;
|
|
1004
|
-
-webkit-user-select: none;
|
|
1005
|
-
-moz-user-select: none;
|
|
1006
|
-
user-select: none;
|
|
1007
|
-
}
|
|
1008
|
-
|
|
1009
|
-
.menu li.disabled {
|
|
1010
|
-
cursor: not-allowed;
|
|
1011
|
-
-webkit-user-select: none;
|
|
1012
|
-
-moz-user-select: none;
|
|
1013
|
-
user-select: none;
|
|
1014
|
-
color: var(--fallback-bc,oklch(var(--bc)/0.3));
|
|
1015
|
-
}
|
|
1016
|
-
|
|
1017
|
-
.menu :where(li > .menu-dropdown:not(.menu-dropdown-show)) {
|
|
1018
|
-
display: none;
|
|
1019
|
-
}
|
|
1020
|
-
|
|
1021
|
-
:where(.menu li) {
|
|
1022
|
-
position: relative;
|
|
1023
|
-
display: flex;
|
|
1024
|
-
flex-shrink: 0;
|
|
1025
|
-
flex-direction: column;
|
|
1026
|
-
flex-wrap: wrap;
|
|
1027
|
-
align-items: stretch;
|
|
1028
|
-
}
|
|
1029
|
-
|
|
1030
|
-
:where(.menu li) .badge {
|
|
1031
|
-
justify-self: end;
|
|
1032
|
-
}
|
|
1033
|
-
|
|
1034
988
|
.toast {
|
|
1035
989
|
position: fixed;
|
|
1036
990
|
display: flex;
|
|
@@ -1042,6 +996,14 @@ html {
|
|
|
1042
996
|
padding: 1rem;
|
|
1043
997
|
}
|
|
1044
998
|
|
|
999
|
+
.alert-error {
|
|
1000
|
+
border-color: var(--fallback-er,oklch(var(--er)/0.2));
|
|
1001
|
+
--tw-text-opacity: 1;
|
|
1002
|
+
color: var(--fallback-erc,oklch(var(--erc)/var(--tw-text-opacity)));
|
|
1003
|
+
--alert-bg: var(--fallback-er,oklch(var(--er)/1));
|
|
1004
|
+
--alert-bg-mix: var(--fallback-b1,oklch(var(--b1)/1));
|
|
1005
|
+
}
|
|
1006
|
+
|
|
1045
1007
|
@media (prefers-reduced-motion: no-preference) {
|
|
1046
1008
|
.btn {
|
|
1047
1009
|
animation: button-pop var(--animation-btn, 0.25s) ease-out;
|
|
@@ -1273,88 +1235,6 @@ html {
|
|
|
1273
1235
|
outline-offset: 2px;
|
|
1274
1236
|
}
|
|
1275
1237
|
|
|
1276
|
-
:where(.menu li:empty) {
|
|
1277
|
-
--tw-bg-opacity: 1;
|
|
1278
|
-
background-color: var(--fallback-bc,oklch(var(--bc)/var(--tw-bg-opacity)));
|
|
1279
|
-
opacity: 0.1;
|
|
1280
|
-
margin: 0.5rem 1rem;
|
|
1281
|
-
height: 1px;
|
|
1282
|
-
}
|
|
1283
|
-
|
|
1284
|
-
.menu :where(li ul):before {
|
|
1285
|
-
position: absolute;
|
|
1286
|
-
bottom: 0.75rem;
|
|
1287
|
-
inset-inline-start: 0px;
|
|
1288
|
-
top: 0.75rem;
|
|
1289
|
-
width: 1px;
|
|
1290
|
-
--tw-bg-opacity: 1;
|
|
1291
|
-
background-color: var(--fallback-bc,oklch(var(--bc)/var(--tw-bg-opacity)));
|
|
1292
|
-
opacity: 0.1;
|
|
1293
|
-
content: "";
|
|
1294
|
-
}
|
|
1295
|
-
|
|
1296
|
-
.menu :where(li:not(.menu-title) > *:not(ul, details, .menu-title, .btn)),
|
|
1297
|
-
.menu :where(li:not(.menu-title) > details > summary:not(.menu-title)) {
|
|
1298
|
-
border-radius: var(--rounded-btn, 0.5rem);
|
|
1299
|
-
padding-left: 1rem;
|
|
1300
|
-
padding-right: 1rem;
|
|
1301
|
-
padding-top: 0.5rem;
|
|
1302
|
-
padding-bottom: 0.5rem;
|
|
1303
|
-
text-align: start;
|
|
1304
|
-
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, -webkit-backdrop-filter;
|
|
1305
|
-
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter;
|
|
1306
|
-
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter, -webkit-backdrop-filter;
|
|
1307
|
-
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
|
1308
|
-
transition-timing-function: cubic-bezier(0, 0, 0.2, 1);
|
|
1309
|
-
transition-duration: 200ms;
|
|
1310
|
-
text-wrap: balance;
|
|
1311
|
-
}
|
|
1312
|
-
|
|
1313
|
-
:where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(summary, .active, .btn).focus, :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(summary, .active, .btn):focus, :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):is(summary):not(.active, .btn):focus-visible, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(summary, .active, .btn).focus, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(summary, .active, .btn):focus, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):is(summary):not(.active, .btn):focus-visible {
|
|
1314
|
-
cursor: pointer;
|
|
1315
|
-
background-color: var(--fallback-bc,oklch(var(--bc)/0.1));
|
|
1316
|
-
--tw-text-opacity: 1;
|
|
1317
|
-
color: var(--fallback-bc,oklch(var(--bc)/var(--tw-text-opacity)));
|
|
1318
|
-
outline: 2px solid transparent;
|
|
1319
|
-
outline-offset: 2px;
|
|
1320
|
-
}
|
|
1321
|
-
|
|
1322
|
-
.menu li > *:not(ul, .menu-title, details, .btn):active,
|
|
1323
|
-
.menu li > *:not(ul, .menu-title, details, .btn).active,
|
|
1324
|
-
.menu li > details > summary:active {
|
|
1325
|
-
--tw-bg-opacity: 1;
|
|
1326
|
-
background-color: var(--fallback-n,oklch(var(--n)/var(--tw-bg-opacity)));
|
|
1327
|
-
--tw-text-opacity: 1;
|
|
1328
|
-
color: var(--fallback-nc,oklch(var(--nc)/var(--tw-text-opacity)));
|
|
1329
|
-
}
|
|
1330
|
-
|
|
1331
|
-
.menu :where(li > details > summary)::-webkit-details-marker {
|
|
1332
|
-
display: none;
|
|
1333
|
-
}
|
|
1334
|
-
|
|
1335
|
-
.menu :where(li > details > summary):after,
|
|
1336
|
-
.menu :where(li > .menu-dropdown-toggle):after {
|
|
1337
|
-
justify-self: end;
|
|
1338
|
-
display: block;
|
|
1339
|
-
margin-top: -0.5rem;
|
|
1340
|
-
height: 0.5rem;
|
|
1341
|
-
width: 0.5rem;
|
|
1342
|
-
transform: rotate(45deg);
|
|
1343
|
-
transition-property: transform, margin-top;
|
|
1344
|
-
transition-duration: 0.3s;
|
|
1345
|
-
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
|
1346
|
-
content: "";
|
|
1347
|
-
transform-origin: 75% 75%;
|
|
1348
|
-
box-shadow: 2px 2px;
|
|
1349
|
-
pointer-events: none;
|
|
1350
|
-
}
|
|
1351
|
-
|
|
1352
|
-
.menu :where(li > details[open] > summary):after,
|
|
1353
|
-
.menu :where(li > .menu-dropdown-toggle.menu-dropdown-show):after {
|
|
1354
|
-
transform: rotate(225deg);
|
|
1355
|
-
margin-top: 0;
|
|
1356
|
-
}
|
|
1357
|
-
|
|
1358
1238
|
.mockup-browser .mockup-browser-toolbar .input {
|
|
1359
1239
|
position: relative;
|
|
1360
1240
|
margin-left: auto;
|
|
@@ -1552,115 +1432,6 @@ html {
|
|
|
1552
1432
|
transform: translate(var(--tw-translate-x), var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y));
|
|
1553
1433
|
}
|
|
1554
1434
|
|
|
1555
|
-
.tooltip {
|
|
1556
|
-
position: relative;
|
|
1557
|
-
display: inline-block;
|
|
1558
|
-
--tooltip-offset: calc(100% + 1px + var(--tooltip-tail, 0px));
|
|
1559
|
-
}
|
|
1560
|
-
|
|
1561
|
-
.tooltip:before {
|
|
1562
|
-
position: absolute;
|
|
1563
|
-
pointer-events: none;
|
|
1564
|
-
z-index: 1;
|
|
1565
|
-
content: var(--tw-content);
|
|
1566
|
-
--tw-content: attr(data-tip);
|
|
1567
|
-
}
|
|
1568
|
-
|
|
1569
|
-
.tooltip:before, .tooltip-top:before {
|
|
1570
|
-
transform: translateX(-50%);
|
|
1571
|
-
top: auto;
|
|
1572
|
-
left: 50%;
|
|
1573
|
-
right: auto;
|
|
1574
|
-
bottom: var(--tooltip-offset);
|
|
1575
|
-
}
|
|
1576
|
-
|
|
1577
|
-
.tooltip {
|
|
1578
|
-
position: relative;
|
|
1579
|
-
display: inline-block;
|
|
1580
|
-
text-align: center;
|
|
1581
|
-
--tooltip-tail: 0.1875rem;
|
|
1582
|
-
--tooltip-color: var(--fallback-n,oklch(var(--n)/1));
|
|
1583
|
-
--tooltip-text-color: var(--fallback-nc,oklch(var(--nc)/1));
|
|
1584
|
-
--tooltip-tail-offset: calc(100% + 0.0625rem - var(--tooltip-tail));
|
|
1585
|
-
}
|
|
1586
|
-
|
|
1587
|
-
.tooltip:before,
|
|
1588
|
-
.tooltip:after {
|
|
1589
|
-
opacity: 0;
|
|
1590
|
-
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, -webkit-backdrop-filter;
|
|
1591
|
-
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter;
|
|
1592
|
-
transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter, -webkit-backdrop-filter;
|
|
1593
|
-
transition-delay: 100ms;
|
|
1594
|
-
transition-duration: 200ms;
|
|
1595
|
-
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
|
1596
|
-
}
|
|
1597
|
-
|
|
1598
|
-
.tooltip:after {
|
|
1599
|
-
position: absolute;
|
|
1600
|
-
content: "";
|
|
1601
|
-
border-style: solid;
|
|
1602
|
-
border-width: var(--tooltip-tail, 0);
|
|
1603
|
-
width: 0;
|
|
1604
|
-
height: 0;
|
|
1605
|
-
display: block;
|
|
1606
|
-
}
|
|
1607
|
-
|
|
1608
|
-
.tooltip:before {
|
|
1609
|
-
max-width: 20rem;
|
|
1610
|
-
border-radius: 0.25rem;
|
|
1611
|
-
padding-left: 0.5rem;
|
|
1612
|
-
padding-right: 0.5rem;
|
|
1613
|
-
padding-top: 0.25rem;
|
|
1614
|
-
padding-bottom: 0.25rem;
|
|
1615
|
-
font-size: 0.875rem;
|
|
1616
|
-
line-height: 1.25rem;
|
|
1617
|
-
background-color: var(--tooltip-color);
|
|
1618
|
-
color: var(--tooltip-text-color);
|
|
1619
|
-
width: -moz-max-content;
|
|
1620
|
-
width: max-content;
|
|
1621
|
-
}
|
|
1622
|
-
|
|
1623
|
-
.tooltip.tooltip-open:before {
|
|
1624
|
-
opacity: 1;
|
|
1625
|
-
transition-delay: 75ms;
|
|
1626
|
-
}
|
|
1627
|
-
|
|
1628
|
-
.tooltip.tooltip-open:after {
|
|
1629
|
-
opacity: 1;
|
|
1630
|
-
transition-delay: 75ms;
|
|
1631
|
-
}
|
|
1632
|
-
|
|
1633
|
-
.tooltip:hover:before {
|
|
1634
|
-
opacity: 1;
|
|
1635
|
-
transition-delay: 75ms;
|
|
1636
|
-
}
|
|
1637
|
-
|
|
1638
|
-
.tooltip:hover:after {
|
|
1639
|
-
opacity: 1;
|
|
1640
|
-
transition-delay: 75ms;
|
|
1641
|
-
}
|
|
1642
|
-
|
|
1643
|
-
.tooltip:has(:focus-visible):after,
|
|
1644
|
-
.tooltip:has(:focus-visible):before {
|
|
1645
|
-
opacity: 1;
|
|
1646
|
-
transition-delay: 75ms;
|
|
1647
|
-
}
|
|
1648
|
-
|
|
1649
|
-
.tooltip:not([data-tip]):hover:before,
|
|
1650
|
-
.tooltip:not([data-tip]):hover:after {
|
|
1651
|
-
visibility: hidden;
|
|
1652
|
-
opacity: 0;
|
|
1653
|
-
}
|
|
1654
|
-
|
|
1655
|
-
.tooltip:after, .tooltip-top:after {
|
|
1656
|
-
transform: translateX(-50%);
|
|
1657
|
-
border-color: var(--tooltip-color) transparent transparent transparent;
|
|
1658
|
-
top: auto;
|
|
1659
|
-
left: 50%;
|
|
1660
|
-
right: auto;
|
|
1661
|
-
bottom: var(--tooltip-tail-offset);
|
|
1662
|
-
}
|
|
1663
|
-
|
|
1664
1435
|
.static {
|
|
1665
1436
|
position: static;
|
|
1666
1437
|
}
|
|
@@ -1684,14 +1455,42 @@ html {
|
|
|
1684
1455
|
display: flex;
|
|
1685
1456
|
}
|
|
1686
1457
|
|
|
1458
|
+
.grid {
|
|
1459
|
+
display: grid;
|
|
1460
|
+
}
|
|
1461
|
+
|
|
1462
|
+
.hidden {
|
|
1463
|
+
display: none;
|
|
1464
|
+
}
|
|
1465
|
+
|
|
1466
|
+
.h-6 {
|
|
1467
|
+
height: 1.5rem;
|
|
1468
|
+
}
|
|
1469
|
+
|
|
1687
1470
|
.w-10\/12 {
|
|
1688
1471
|
width: 83.333333%;
|
|
1689
1472
|
}
|
|
1690
1473
|
|
|
1474
|
+
.w-6 {
|
|
1475
|
+
width: 1.5rem;
|
|
1476
|
+
}
|
|
1477
|
+
|
|
1691
1478
|
.w-full {
|
|
1692
1479
|
width: 100%;
|
|
1693
1480
|
}
|
|
1694
1481
|
|
|
1482
|
+
.shrink-0 {
|
|
1483
|
+
flex-shrink: 0;
|
|
1484
|
+
}
|
|
1485
|
+
|
|
1486
|
+
.cursor-pointer {
|
|
1487
|
+
cursor: pointer;
|
|
1488
|
+
}
|
|
1489
|
+
|
|
1490
|
+
.grid-cols-2 {
|
|
1491
|
+
grid-template-columns: repeat(2, minmax(0, 1fr));
|
|
1492
|
+
}
|
|
1493
|
+
|
|
1695
1494
|
.flex-col {
|
|
1696
1495
|
flex-direction: column;
|
|
1697
1496
|
}
|
|
@@ -1700,14 +1499,17 @@ html {
|
|
|
1700
1499
|
justify-content: center;
|
|
1701
1500
|
}
|
|
1702
1501
|
|
|
1703
|
-
.
|
|
1704
|
-
|
|
1705
|
-
|
|
1706
|
-
margin-left: calc(1rem * calc(1 - var(--tw-space-x-reverse)));
|
|
1502
|
+
.gap-x-2 {
|
|
1503
|
+
-moz-column-gap: 0.5rem;
|
|
1504
|
+
column-gap: 0.5rem;
|
|
1707
1505
|
}
|
|
1708
1506
|
|
|
1709
|
-
.
|
|
1710
|
-
|
|
1507
|
+
.overflow-x-scroll {
|
|
1508
|
+
overflow-x: scroll;
|
|
1509
|
+
}
|
|
1510
|
+
|
|
1511
|
+
.whitespace-nowrap {
|
|
1512
|
+
white-space: nowrap;
|
|
1711
1513
|
}
|
|
1712
1514
|
|
|
1713
1515
|
.bg-base-200 {
|
|
@@ -1715,8 +1517,18 @@ html {
|
|
|
1715
1517
|
background-color: var(--fallback-b2,oklch(var(--b2)/var(--tw-bg-opacity)));
|
|
1716
1518
|
}
|
|
1717
1519
|
|
|
1718
|
-
.
|
|
1719
|
-
|
|
1520
|
+
.bg-blue-400 {
|
|
1521
|
+
--tw-bg-opacity: 1;
|
|
1522
|
+
background-color: rgb(96 165 250 / var(--tw-bg-opacity));
|
|
1523
|
+
}
|
|
1524
|
+
|
|
1525
|
+
.bg-error {
|
|
1526
|
+
--tw-bg-opacity: 1;
|
|
1527
|
+
background-color: var(--fallback-er,oklch(var(--er)/var(--tw-bg-opacity)));
|
|
1528
|
+
}
|
|
1529
|
+
|
|
1530
|
+
.stroke-current {
|
|
1531
|
+
stroke: currentColor;
|
|
1720
1532
|
}
|
|
1721
1533
|
|
|
1722
1534
|
.text-2xl {
|
|
@@ -1728,7 +1540,3 @@ html {
|
|
|
1728
1540
|
font-size: 1.875rem;
|
|
1729
1541
|
line-height: 2.25rem;
|
|
1730
1542
|
}
|
|
1731
|
-
|
|
1732
|
-
.ease-in-out {
|
|
1733
|
-
transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
|
|
1734
|
-
}
|
|
@@ -11,26 +11,73 @@
|
|
|
11
11
|
href="{{ url_for('static', filename='output.css') }}"
|
|
12
12
|
/>
|
|
13
13
|
<script src="https://cdn.jsdelivr.net/npm/sweetalert2@11"></script>
|
|
14
|
+
<script src="https://cdn.jsdelivr.net/npm/sortablejs@latest/Sortable.min.js"></script>
|
|
14
15
|
<script type="module">
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
|
|
16
|
+
let jaxSortable = new Sortable(document.getElementById("jax-fields"), {
|
|
17
|
+
animation: 150,
|
|
18
|
+
ghostClass: "blue-background-class",
|
|
19
|
+
});
|
|
20
|
+
|
|
21
|
+
let torchSortable = new Sortable(
|
|
22
|
+
document.getElementById("torch-fields"),
|
|
23
|
+
{
|
|
24
|
+
animation: 150,
|
|
25
|
+
ghostClass: "bg-blue-400",
|
|
26
|
+
onEnd: function (evt) {
|
|
27
|
+
document.getElementById("error-field").classList.add("hidden");
|
|
28
|
+
let allJaxFields = document.querySelectorAll("#jax-fields > div");
|
|
29
|
+
let allTorchFields = document.querySelectorAll(
|
|
30
|
+
"#torch-fields > div",
|
|
31
|
+
);
|
|
32
|
+
if (allJaxFields.length !== allTorchFields.length) {
|
|
33
|
+
Swal.fire({
|
|
34
|
+
icon: "error",
|
|
35
|
+
title:
|
|
36
|
+
"The number of fields in JAX and PyTorch should be the same",
|
|
37
|
+
});
|
|
38
|
+
} else {
|
|
39
|
+
for (let i = 0; i < allJaxFields.length; i++) {
|
|
40
|
+
let jaxField = allJaxFields[i];
|
|
41
|
+
let torchField = allTorchFields[i];
|
|
42
|
+
let jaxShape = jaxField.getAttribute("data-shape");
|
|
43
|
+
let torchShape = torchField.getAttribute("data-shape");
|
|
44
|
+
|
|
45
|
+
jaxShape = jaxShape
|
|
46
|
+
.replace("(", "")
|
|
47
|
+
.replace(")", "")
|
|
48
|
+
.replace(/\s+/g, "")
|
|
49
|
+
.replace(/,\s*$/, "");
|
|
50
|
+
|
|
51
|
+
torchShape = torchShape
|
|
52
|
+
.replace("(", "")
|
|
53
|
+
.replace(")", "")
|
|
54
|
+
.replace(/\s+/g, "")
|
|
55
|
+
.replace(/,\s*$/, "");
|
|
56
|
+
|
|
57
|
+
let jaxShapeParts = jaxShape.split(",").map((x) => parseInt(x));
|
|
58
|
+
let torchShapeParts = torchShape
|
|
59
|
+
.split(",")
|
|
60
|
+
.map((x) => parseInt(x));
|
|
61
|
+
let jaxShapeProduct = jaxShapeParts.reduce((a, b) => a * b, 1);
|
|
62
|
+
let torchShapeProduct = torchShapeParts.reduce(
|
|
63
|
+
(a, b) => a * b,
|
|
64
|
+
1,
|
|
65
|
+
);
|
|
66
|
+
|
|
67
|
+
let jaxEl = jaxField;
|
|
68
|
+
let torchEl = torchField;
|
|
69
|
+
if (jaxShapeProduct !== torchShapeProduct) {
|
|
70
|
+
jaxEl.classList.add("bg-error");
|
|
71
|
+
torchEl.classList.add("bg-error");
|
|
72
|
+
} else {
|
|
73
|
+
jaxEl.classList.remove("bg-error");
|
|
74
|
+
torchEl.classList.remove("bg-error");
|
|
75
|
+
}
|
|
76
|
+
}
|
|
77
|
+
}
|
|
31
78
|
},
|
|
32
|
-
}
|
|
33
|
-
|
|
79
|
+
},
|
|
80
|
+
);
|
|
34
81
|
</script>
|
|
35
82
|
</head>
|
|
36
83
|
<body class="w-10/12 mx-auto">
|
|
@@ -66,10 +113,8 @@
|
|
|
66
113
|
xhr.setRequestHeader("Content-Type", "application/json");
|
|
67
114
|
xhr.onload = function () {
|
|
68
115
|
if (xhr.status >= 200 && xhr.status < 300) {
|
|
69
|
-
// Successfully received response
|
|
70
116
|
var container = document.getElementById("visualizationResult");
|
|
71
117
|
container.innerHTML = xhr.responseText;
|
|
72
|
-
// Execute any scripts that were in the response
|
|
73
118
|
var scripts = container.getElementsByTagName("script");
|
|
74
119
|
for (var i = 0; i < scripts.length; i++) {
|
|
75
120
|
var script = document.createElement("script");
|
|
@@ -86,7 +131,7 @@
|
|
|
86
131
|
|
|
87
132
|
function getJaxAndTorchFields() {
|
|
88
133
|
const jaxFields = Array.from(
|
|
89
|
-
document.querySelectorAll("
|
|
134
|
+
document.querySelectorAll("#jax-fields")[0].children,
|
|
90
135
|
).map((li) => {
|
|
91
136
|
const path = li.getAttribute("data-path");
|
|
92
137
|
const shape = li.getAttribute("data-shape");
|
|
@@ -95,7 +140,7 @@
|
|
|
95
140
|
});
|
|
96
141
|
|
|
97
142
|
const torchFields = Array.from(
|
|
98
|
-
document.querySelectorAll("
|
|
143
|
+
document.querySelectorAll("#torch-fields")[0].children,
|
|
99
144
|
).map((li) => {
|
|
100
145
|
const path = li.getAttribute("data-path");
|
|
101
146
|
const shape = li.getAttribute("data-shape");
|
|
@@ -121,6 +166,11 @@
|
|
|
121
166
|
icon: "error",
|
|
122
167
|
title: `${jaxFields[i].path} has shape ${jaxFields[i].shape}, while ${torchFields[i].path} has shape ${torchFields[i].shape}`,
|
|
123
168
|
});
|
|
169
|
+
document.getElementById("error-field").classList.remove("hidden");
|
|
170
|
+
document
|
|
171
|
+
.getElementById("error-field")
|
|
172
|
+
.querySelector("span").innerText =
|
|
173
|
+
`${jaxFields[i].path} has shape ${jaxFields[i].shape}, while ${torchFields[i].path} has shape ${torchFields[i].shape}`;
|
|
124
174
|
return { error: "Invalid shapes" };
|
|
125
175
|
}
|
|
126
176
|
}
|
|
@@ -133,7 +183,8 @@
|
|
|
133
183
|
if (fields.error) {
|
|
134
184
|
Toast.fire({
|
|
135
185
|
icon: "error",
|
|
136
|
-
title: "
|
|
186
|
+
title: "Failed to convert!",
|
|
187
|
+
text: fields.error,
|
|
137
188
|
});
|
|
138
189
|
}
|
|
139
190
|
const jaxFields = fields.jaxFields;
|
|
@@ -178,42 +229,36 @@
|
|
|
178
229
|
</script>
|
|
179
230
|
<h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
|
|
180
231
|
|
|
181
|
-
<div class="
|
|
182
|
-
<div class="
|
|
232
|
+
<div class="grid grid-cols-2 gap-x-2">
|
|
233
|
+
<div class="">
|
|
183
234
|
<h2 class="text-2xl">JAX</h2>
|
|
184
|
-
<
|
|
235
|
+
<div id="jax-fields" class="bg-base-200">
|
|
185
236
|
{% for field in pytree_fields %}
|
|
186
|
-
<
|
|
187
|
-
draggable="true"
|
|
237
|
+
<div
|
|
188
238
|
data-path="{{field.path}}"
|
|
189
239
|
data-shape="{{field.shape}}"
|
|
190
240
|
data-type="{{field.type}}"
|
|
241
|
+
class="whitespace-nowrap overflow-x-scroll cursor-pointer"
|
|
191
242
|
>
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
data-path="{{field.path}}"
|
|
195
|
-
data-tip="{{field.type}}"
|
|
196
|
-
>
|
|
197
|
-
{{ field.path}} {{field.shape }}
|
|
198
|
-
</p>
|
|
199
|
-
</li>
|
|
243
|
+
{{ field.path }} {{ field.shape }}
|
|
244
|
+
</div>
|
|
200
245
|
{% endfor %}
|
|
201
|
-
</
|
|
246
|
+
</div>
|
|
202
247
|
</div>
|
|
203
248
|
|
|
204
|
-
<div class="
|
|
249
|
+
<div class="">
|
|
205
250
|
<h2 class="text-2xl">PyTorch</h2>
|
|
206
|
-
<
|
|
251
|
+
<div id="torch-fields" class="bg-base-200">
|
|
207
252
|
{% for field in torch_fields %}
|
|
208
|
-
<
|
|
209
|
-
draggable="true"
|
|
253
|
+
<div
|
|
210
254
|
data-path="{{field.path}}"
|
|
211
255
|
data-shape="{{field.shape}}"
|
|
256
|
+
class="whitespace-nowrap overflow-x-scroll cursor-pointer"
|
|
212
257
|
>
|
|
213
|
-
|
|
214
|
-
</
|
|
258
|
+
{{ field.path }} {{ field.shape }}
|
|
259
|
+
</div>
|
|
215
260
|
{% endfor %}
|
|
216
|
-
</
|
|
261
|
+
</div>
|
|
217
262
|
</div>
|
|
218
263
|
</div>
|
|
219
264
|
<div class="flex justify-center my-12 w-full">
|
|
@@ -234,6 +279,23 @@
|
|
|
234
279
|
</button>
|
|
235
280
|
</div>
|
|
236
281
|
</div>
|
|
282
|
+
<div role="alert" class="alert alert-error hidden" id="error-field">
|
|
283
|
+
<svg
|
|
284
|
+
xmlns="http://www.w3.org/2000/svg"
|
|
285
|
+
class="stroke-current shrink-0 h-6 w-6"
|
|
286
|
+
fill="none"
|
|
287
|
+
viewBox="0 0 24 24"
|
|
288
|
+
>
|
|
289
|
+
<path
|
|
290
|
+
stroke-linecap="round"
|
|
291
|
+
stroke-linejoin="round"
|
|
292
|
+
stroke-width="2"
|
|
293
|
+
d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z"
|
|
294
|
+
/>
|
|
295
|
+
</svg>
|
|
296
|
+
<span></span>
|
|
297
|
+
</div>
|
|
298
|
+
|
|
237
299
|
<div class="flex justify-center">
|
|
238
300
|
<button onclick="visualize()" class="btn btn-secondary">
|
|
239
301
|
Visualize with Penzai!
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import equinox as eqx
|
|
2
|
+
import jax
|
|
3
|
+
import numpy as np
|
|
4
|
+
import statedict2pytree as s2p
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def test_linear():
|
|
9
|
+
in_features = 10
|
|
10
|
+
out_features = 10
|
|
11
|
+
|
|
12
|
+
class J(eqx.Module):
|
|
13
|
+
linear: eqx.nn.Linear
|
|
14
|
+
norm: eqx.nn.BatchNorm
|
|
15
|
+
|
|
16
|
+
def __init__(self):
|
|
17
|
+
self.linear = eqx.nn.Linear(
|
|
18
|
+
in_features, out_features, key=jax.random.PRNGKey(30)
|
|
19
|
+
)
|
|
20
|
+
self.norm = eqx.nn.BatchNorm(input_size=out_features, axis_name="batch")
|
|
21
|
+
|
|
22
|
+
class T(torch.nn.Module):
|
|
23
|
+
def __init__(self) -> None:
|
|
24
|
+
super(T, self).__init__()
|
|
25
|
+
self.linear = torch.nn.Linear(in_features, out_features)
|
|
26
|
+
self.norm = torch.nn.BatchNorm1d(out_features)
|
|
27
|
+
|
|
28
|
+
jax_model = J()
|
|
29
|
+
torch_model = T()
|
|
30
|
+
state_dict = torch_model.state_dict()
|
|
31
|
+
|
|
32
|
+
model, state = s2p.autoconvert(pytree=jax_model, state_dict=state_dict)
|
|
33
|
+
|
|
34
|
+
assert np.allclose(
|
|
35
|
+
np.array(model.linear.weight), torch_model.linear.weight.detach().numpy()
|
|
36
|
+
)
|
|
37
|
+
assert np.allclose(
|
|
38
|
+
np.array(model.linear.bias), torch_model.linear.bias.detach().numpy()
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
assert np.allclose(
|
|
42
|
+
np.array(model.norm.weight), torch_model.norm.weight.detach().numpy()
|
|
43
|
+
)
|
|
44
|
+
assert np.allclose(
|
|
45
|
+
np.array(model.norm.bias), torch_model.norm.bias.detach().numpy()
|
|
46
|
+
)
|
|
@@ -1,16 +0,0 @@
|
|
|
1
|
-
import jax
|
|
2
|
-
import statedict2pytree as s2p
|
|
3
|
-
from resnet import resnet50
|
|
4
|
-
from torchvision.models import resnet50 as t_resnet50, ResNet50_Weights
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
def convert_resnet():
|
|
8
|
-
resnet_jax = resnet50(key=jax.random.PRNGKey(33), make_with_state=False)
|
|
9
|
-
resnet_torch = t_resnet50(weights=ResNet50_Weights.DEFAULT)
|
|
10
|
-
state_dict = resnet_torch.state_dict()
|
|
11
|
-
|
|
12
|
-
s2p.start_conversion(resnet_jax, state_dict)
|
|
13
|
-
|
|
14
|
-
|
|
15
|
-
if __name__ == "__main__":
|
|
16
|
-
convert_resnet()
|
|
@@ -1,10 +0,0 @@
|
|
|
1
|
-
{
|
|
2
|
-
"scripts": {
|
|
3
|
-
"dev": "python3 torch2jax/torch2jax.py --debug=True",
|
|
4
|
-
"watch": "npx tailwindcss -i ./torch2jax/static/input.css -o ./torch2jax/static/output.css --watch"
|
|
5
|
-
},
|
|
6
|
-
"devDependencies": {
|
|
7
|
-
"daisyui": "^4.11.1",
|
|
8
|
-
"tailwindcss": "^3.4.3"
|
|
9
|
-
}
|
|
10
|
-
}
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|