statedict2pytree 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,7 +46,7 @@ def get_node(
46
46
  return tree
47
47
  else:
48
48
  next_target: str = targets[0]
49
- if bool(re.search(r"\[\d\]", next_target)):
49
+ if bool(re.search(r"\[\d+\]", next_target)):
50
50
  split_index = next_target.rfind("[")
51
51
  name, index = next_target[:split_index], next_target[split_index:]
52
52
  index = index[1:-1]
@@ -157,6 +157,9 @@ def index():
157
157
  def autoconvert(pytree: PyTree, state_dict: dict) -> tuple[PyTree, eqx.nn.State]:
158
158
  jax_fields = pytree_to_fields(pytree)
159
159
  torch_fields = state_dict_to_fields(state_dict)
160
+
161
+ for k, v in state_dict.items():
162
+ state_dict[k] = v.numpy()
160
163
  return convert(jax_fields, torch_fields, pytree, state_dict)
161
164
 
162
165
 
@@ -212,4 +215,5 @@ def start_conversion(pytree: PyTree, state_dict: dict):
212
215
 
213
216
  for k, v in STATE_DICT.items():
214
217
  STATE_DICT[k] = v.numpy()
218
+ app.jinja_env.globals.update(enumerate=enumerate)
215
219
  app.run(debug=True, port=5500)
@@ -784,23 +784,42 @@ html {
784
784
  }
785
785
  }
786
786
 
787
- .avatar.placeholder > div {
788
- display: flex;
787
+ .alert {
788
+ display: grid;
789
+ width: 100%;
790
+ grid-auto-flow: row;
791
+ align-content: flex-start;
789
792
  align-items: center;
790
- justify-content: center;
793
+ justify-items: center;
794
+ gap: 1rem;
795
+ text-align: center;
796
+ border-radius: var(--rounded-box, 1rem);
797
+ border-width: 1px;
798
+ --tw-border-opacity: 1;
799
+ border-color: var(--fallback-b2,oklch(var(--b2)/var(--tw-border-opacity)));
800
+ padding: 1rem;
801
+ --tw-text-opacity: 1;
802
+ color: var(--fallback-bc,oklch(var(--bc)/var(--tw-text-opacity)));
803
+ --alert-bg: var(--fallback-b2,oklch(var(--b2)/1));
804
+ --alert-bg-mix: var(--fallback-b1,oklch(var(--b1)/1));
805
+ background-color: var(--alert-bg);
791
806
  }
792
807
 
793
- @media (hover:hover) {
794
- .menu li > *:not(ul, .menu-title, details, .btn):active,
795
- .menu li > *:not(ul, .menu-title, details, .btn).active,
796
- .menu li > details > summary:active {
797
- --tw-bg-opacity: 1;
798
- background-color: var(--fallback-n,oklch(var(--n)/var(--tw-bg-opacity)));
799
- --tw-text-opacity: 1;
800
- color: var(--fallback-nc,oklch(var(--nc)/var(--tw-text-opacity)));
808
+ @media (min-width: 640px) {
809
+ .alert {
810
+ grid-auto-flow: column;
811
+ grid-template-columns: auto minmax(auto,1fr);
812
+ justify-items: start;
813
+ text-align: start;
801
814
  }
802
815
  }
803
816
 
817
+ .avatar.placeholder > div {
818
+ display: flex;
819
+ align-items: center;
820
+ justify-content: center;
821
+ }
822
+
804
823
  .btn {
805
824
  display: inline-flex;
806
825
  height: 3rem;
@@ -934,18 +953,6 @@ html {
934
953
  border-color: color-mix(in oklab, var(--fallback-p,oklch(var(--p)/1)) 90%, black);
935
954
  }
936
955
  }
937
-
938
- :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(.active, .btn):hover, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(.active, .btn):hover {
939
- cursor: pointer;
940
- outline: 2px solid transparent;
941
- outline-offset: 2px;
942
- }
943
-
944
- @supports (color: oklch(0% 0 0)) {
945
- :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(.active, .btn):hover, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(.active, .btn):hover {
946
- background-color: var(--fallback-bc,oklch(var(--bc)/0.1));
947
- }
948
- }
949
956
  }
950
957
 
951
958
  .input {
@@ -978,59 +985,6 @@ html {
978
985
  text-decoration-line: underline;
979
986
  }
980
987
 
981
- .menu {
982
- display: flex;
983
- flex-direction: column;
984
- flex-wrap: wrap;
985
- font-size: 0.875rem;
986
- line-height: 1.25rem;
987
- padding: 0.5rem;
988
- }
989
-
990
- .menu :where(li ul) {
991
- position: relative;
992
- white-space: nowrap;
993
- margin-inline-start: 1rem;
994
- padding-inline-start: 0.5rem;
995
- }
996
-
997
- .menu :where(li:not(.menu-title) > *:not(ul, details, .menu-title, .btn)), .menu :where(li:not(.menu-title) > details > summary:not(.menu-title)) {
998
- display: grid;
999
- grid-auto-flow: column;
1000
- align-content: flex-start;
1001
- align-items: center;
1002
- gap: 0.5rem;
1003
- grid-auto-columns: minmax(auto, max-content) auto max-content;
1004
- -webkit-user-select: none;
1005
- -moz-user-select: none;
1006
- user-select: none;
1007
- }
1008
-
1009
- .menu li.disabled {
1010
- cursor: not-allowed;
1011
- -webkit-user-select: none;
1012
- -moz-user-select: none;
1013
- user-select: none;
1014
- color: var(--fallback-bc,oklch(var(--bc)/0.3));
1015
- }
1016
-
1017
- .menu :where(li > .menu-dropdown:not(.menu-dropdown-show)) {
1018
- display: none;
1019
- }
1020
-
1021
- :where(.menu li) {
1022
- position: relative;
1023
- display: flex;
1024
- flex-shrink: 0;
1025
- flex-direction: column;
1026
- flex-wrap: wrap;
1027
- align-items: stretch;
1028
- }
1029
-
1030
- :where(.menu li) .badge {
1031
- justify-self: end;
1032
- }
1033
-
1034
988
  .toast {
1035
989
  position: fixed;
1036
990
  display: flex;
@@ -1042,6 +996,14 @@ html {
1042
996
  padding: 1rem;
1043
997
  }
1044
998
 
999
+ .alert-error {
1000
+ border-color: var(--fallback-er,oklch(var(--er)/0.2));
1001
+ --tw-text-opacity: 1;
1002
+ color: var(--fallback-erc,oklch(var(--erc)/var(--tw-text-opacity)));
1003
+ --alert-bg: var(--fallback-er,oklch(var(--er)/1));
1004
+ --alert-bg-mix: var(--fallback-b1,oklch(var(--b1)/1));
1005
+ }
1006
+
1045
1007
  @media (prefers-reduced-motion: no-preference) {
1046
1008
  .btn {
1047
1009
  animation: button-pop var(--animation-btn, 0.25s) ease-out;
@@ -1273,88 +1235,6 @@ html {
1273
1235
  outline-offset: 2px;
1274
1236
  }
1275
1237
 
1276
- :where(.menu li:empty) {
1277
- --tw-bg-opacity: 1;
1278
- background-color: var(--fallback-bc,oklch(var(--bc)/var(--tw-bg-opacity)));
1279
- opacity: 0.1;
1280
- margin: 0.5rem 1rem;
1281
- height: 1px;
1282
- }
1283
-
1284
- .menu :where(li ul):before {
1285
- position: absolute;
1286
- bottom: 0.75rem;
1287
- inset-inline-start: 0px;
1288
- top: 0.75rem;
1289
- width: 1px;
1290
- --tw-bg-opacity: 1;
1291
- background-color: var(--fallback-bc,oklch(var(--bc)/var(--tw-bg-opacity)));
1292
- opacity: 0.1;
1293
- content: "";
1294
- }
1295
-
1296
- .menu :where(li:not(.menu-title) > *:not(ul, details, .menu-title, .btn)),
1297
- .menu :where(li:not(.menu-title) > details > summary:not(.menu-title)) {
1298
- border-radius: var(--rounded-btn, 0.5rem);
1299
- padding-left: 1rem;
1300
- padding-right: 1rem;
1301
- padding-top: 0.5rem;
1302
- padding-bottom: 0.5rem;
1303
- text-align: start;
1304
- transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, -webkit-backdrop-filter;
1305
- transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter;
1306
- transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter, -webkit-backdrop-filter;
1307
- transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
1308
- transition-timing-function: cubic-bezier(0, 0, 0.2, 1);
1309
- transition-duration: 200ms;
1310
- text-wrap: balance;
1311
- }
1312
-
1313
- :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(summary, .active, .btn).focus, :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):not(summary, .active, .btn):focus, :where(.menu li:not(.menu-title, .disabled) > *:not(ul, details, .menu-title)):is(summary):not(.active, .btn):focus-visible, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(summary, .active, .btn).focus, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):not(summary, .active, .btn):focus, :where(.menu li:not(.menu-title, .disabled) > details > summary:not(.menu-title)):is(summary):not(.active, .btn):focus-visible {
1314
- cursor: pointer;
1315
- background-color: var(--fallback-bc,oklch(var(--bc)/0.1));
1316
- --tw-text-opacity: 1;
1317
- color: var(--fallback-bc,oklch(var(--bc)/var(--tw-text-opacity)));
1318
- outline: 2px solid transparent;
1319
- outline-offset: 2px;
1320
- }
1321
-
1322
- .menu li > *:not(ul, .menu-title, details, .btn):active,
1323
- .menu li > *:not(ul, .menu-title, details, .btn).active,
1324
- .menu li > details > summary:active {
1325
- --tw-bg-opacity: 1;
1326
- background-color: var(--fallback-n,oklch(var(--n)/var(--tw-bg-opacity)));
1327
- --tw-text-opacity: 1;
1328
- color: var(--fallback-nc,oklch(var(--nc)/var(--tw-text-opacity)));
1329
- }
1330
-
1331
- .menu :where(li > details > summary)::-webkit-details-marker {
1332
- display: none;
1333
- }
1334
-
1335
- .menu :where(li > details > summary):after,
1336
- .menu :where(li > .menu-dropdown-toggle):after {
1337
- justify-self: end;
1338
- display: block;
1339
- margin-top: -0.5rem;
1340
- height: 0.5rem;
1341
- width: 0.5rem;
1342
- transform: rotate(45deg);
1343
- transition-property: transform, margin-top;
1344
- transition-duration: 0.3s;
1345
- transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
1346
- content: "";
1347
- transform-origin: 75% 75%;
1348
- box-shadow: 2px 2px;
1349
- pointer-events: none;
1350
- }
1351
-
1352
- .menu :where(li > details[open] > summary):after,
1353
- .menu :where(li > .menu-dropdown-toggle.menu-dropdown-show):after {
1354
- transform: rotate(225deg);
1355
- margin-top: 0;
1356
- }
1357
-
1358
1238
  .mockup-browser .mockup-browser-toolbar .input {
1359
1239
  position: relative;
1360
1240
  margin-left: auto;
@@ -1552,115 +1432,6 @@ html {
1552
1432
  transform: translate(var(--tw-translate-x), var(--tw-translate-y)) rotate(var(--tw-rotate)) skewX(var(--tw-skew-x)) skewY(var(--tw-skew-y)) scaleX(var(--tw-scale-x)) scaleY(var(--tw-scale-y));
1553
1433
  }
1554
1434
 
1555
- .tooltip {
1556
- position: relative;
1557
- display: inline-block;
1558
- --tooltip-offset: calc(100% + 1px + var(--tooltip-tail, 0px));
1559
- }
1560
-
1561
- .tooltip:before {
1562
- position: absolute;
1563
- pointer-events: none;
1564
- z-index: 1;
1565
- content: var(--tw-content);
1566
- --tw-content: attr(data-tip);
1567
- }
1568
-
1569
- .tooltip:before, .tooltip-top:before {
1570
- transform: translateX(-50%);
1571
- top: auto;
1572
- left: 50%;
1573
- right: auto;
1574
- bottom: var(--tooltip-offset);
1575
- }
1576
-
1577
- .tooltip {
1578
- position: relative;
1579
- display: inline-block;
1580
- text-align: center;
1581
- --tooltip-tail: 0.1875rem;
1582
- --tooltip-color: var(--fallback-n,oklch(var(--n)/1));
1583
- --tooltip-text-color: var(--fallback-nc,oklch(var(--nc)/1));
1584
- --tooltip-tail-offset: calc(100% + 0.0625rem - var(--tooltip-tail));
1585
- }
1586
-
1587
- .tooltip:before,
1588
- .tooltip:after {
1589
- opacity: 0;
1590
- transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, -webkit-backdrop-filter;
1591
- transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter;
1592
- transition-property: color, background-color, border-color, text-decoration-color, fill, stroke, opacity, box-shadow, transform, filter, backdrop-filter, -webkit-backdrop-filter;
1593
- transition-delay: 100ms;
1594
- transition-duration: 200ms;
1595
- transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
1596
- }
1597
-
1598
- .tooltip:after {
1599
- position: absolute;
1600
- content: "";
1601
- border-style: solid;
1602
- border-width: var(--tooltip-tail, 0);
1603
- width: 0;
1604
- height: 0;
1605
- display: block;
1606
- }
1607
-
1608
- .tooltip:before {
1609
- max-width: 20rem;
1610
- border-radius: 0.25rem;
1611
- padding-left: 0.5rem;
1612
- padding-right: 0.5rem;
1613
- padding-top: 0.25rem;
1614
- padding-bottom: 0.25rem;
1615
- font-size: 0.875rem;
1616
- line-height: 1.25rem;
1617
- background-color: var(--tooltip-color);
1618
- color: var(--tooltip-text-color);
1619
- width: -moz-max-content;
1620
- width: max-content;
1621
- }
1622
-
1623
- .tooltip.tooltip-open:before {
1624
- opacity: 1;
1625
- transition-delay: 75ms;
1626
- }
1627
-
1628
- .tooltip.tooltip-open:after {
1629
- opacity: 1;
1630
- transition-delay: 75ms;
1631
- }
1632
-
1633
- .tooltip:hover:before {
1634
- opacity: 1;
1635
- transition-delay: 75ms;
1636
- }
1637
-
1638
- .tooltip:hover:after {
1639
- opacity: 1;
1640
- transition-delay: 75ms;
1641
- }
1642
-
1643
- .tooltip:has(:focus-visible):after,
1644
- .tooltip:has(:focus-visible):before {
1645
- opacity: 1;
1646
- transition-delay: 75ms;
1647
- }
1648
-
1649
- .tooltip:not([data-tip]):hover:before,
1650
- .tooltip:not([data-tip]):hover:after {
1651
- visibility: hidden;
1652
- opacity: 0;
1653
- }
1654
-
1655
- .tooltip:after, .tooltip-top:after {
1656
- transform: translateX(-50%);
1657
- border-color: var(--tooltip-color) transparent transparent transparent;
1658
- top: auto;
1659
- left: 50%;
1660
- right: auto;
1661
- bottom: var(--tooltip-tail-offset);
1662
- }
1663
-
1664
1435
  .static {
1665
1436
  position: static;
1666
1437
  }
@@ -1684,14 +1455,42 @@ html {
1684
1455
  display: flex;
1685
1456
  }
1686
1457
 
1458
+ .grid {
1459
+ display: grid;
1460
+ }
1461
+
1462
+ .hidden {
1463
+ display: none;
1464
+ }
1465
+
1466
+ .h-6 {
1467
+ height: 1.5rem;
1468
+ }
1469
+
1687
1470
  .w-10\/12 {
1688
1471
  width: 83.333333%;
1689
1472
  }
1690
1473
 
1474
+ .w-6 {
1475
+ width: 1.5rem;
1476
+ }
1477
+
1691
1478
  .w-full {
1692
1479
  width: 100%;
1693
1480
  }
1694
1481
 
1482
+ .shrink-0 {
1483
+ flex-shrink: 0;
1484
+ }
1485
+
1486
+ .cursor-pointer {
1487
+ cursor: pointer;
1488
+ }
1489
+
1490
+ .grid-cols-2 {
1491
+ grid-template-columns: repeat(2, minmax(0, 1fr));
1492
+ }
1493
+
1695
1494
  .flex-col {
1696
1495
  flex-direction: column;
1697
1496
  }
@@ -1700,14 +1499,17 @@ html {
1700
1499
  justify-content: center;
1701
1500
  }
1702
1501
 
1703
- .space-x-4 > :not([hidden]) ~ :not([hidden]) {
1704
- --tw-space-x-reverse: 0;
1705
- margin-right: calc(1rem * var(--tw-space-x-reverse));
1706
- margin-left: calc(1rem * calc(1 - var(--tw-space-x-reverse)));
1502
+ .gap-x-2 {
1503
+ -moz-column-gap: 0.5rem;
1504
+ column-gap: 0.5rem;
1707
1505
  }
1708
1506
 
1709
- .rounded-box {
1710
- border-radius: var(--rounded-box, 1rem);
1507
+ .overflow-x-scroll {
1508
+ overflow-x: scroll;
1509
+ }
1510
+
1511
+ .whitespace-nowrap {
1512
+ white-space: nowrap;
1711
1513
  }
1712
1514
 
1713
1515
  .bg-base-200 {
@@ -1715,8 +1517,18 @@ html {
1715
1517
  background-color: var(--fallback-b2,oklch(var(--b2)/var(--tw-bg-opacity)));
1716
1518
  }
1717
1519
 
1718
- .text-left {
1719
- text-align: left;
1520
+ .bg-blue-400 {
1521
+ --tw-bg-opacity: 1;
1522
+ background-color: rgb(96 165 250 / var(--tw-bg-opacity));
1523
+ }
1524
+
1525
+ .bg-error {
1526
+ --tw-bg-opacity: 1;
1527
+ background-color: var(--fallback-er,oklch(var(--er)/var(--tw-bg-opacity)));
1528
+ }
1529
+
1530
+ .stroke-current {
1531
+ stroke: currentColor;
1720
1532
  }
1721
1533
 
1722
1534
  .text-2xl {
@@ -1728,7 +1540,3 @@ html {
1728
1540
  font-size: 1.875rem;
1729
1541
  line-height: 2.25rem;
1730
1542
  }
1731
-
1732
- .ease-in-out {
1733
- transition-timing-function: cubic-bezier(0.4, 0, 0.2, 1);
1734
- }
@@ -11,26 +11,73 @@
11
11
  href="{{ url_for('static', filename='output.css') }}"
12
12
  />
13
13
  <script src="https://cdn.jsdelivr.net/npm/sweetalert2@11"></script>
14
+ <script src="https://cdn.jsdelivr.net/npm/sortablejs@latest/Sortable.min.js"></script>
14
15
  <script type="module">
15
- import {
16
- Draggable,
17
- Sortable,
18
- Droppable,
19
- Swappable,
20
- Plugins,
21
- } from "https://cdn.jsdelivr.net/npm/@shopify/draggable/build/esm/index.mjs";
22
- const sortableLists = document.querySelectorAll(".draggable-list");
23
-
24
- for (const sortableList of sortableLists) {
25
- const sortable = new Sortable(sortableList, {
26
- draggable: "li",
27
- plugins: [Plugins.SortAnimation],
28
- swapAnimation: {
29
- duration: 200,
30
- easingFunction: "ease-in-out",
16
+ let jaxSortable = new Sortable(document.getElementById("jax-fields"), {
17
+ animation: 150,
18
+ ghostClass: "blue-background-class",
19
+ });
20
+
21
+ let torchSortable = new Sortable(
22
+ document.getElementById("torch-fields"),
23
+ {
24
+ animation: 150,
25
+ ghostClass: "bg-blue-400",
26
+ onEnd: function (evt) {
27
+ document.getElementById("error-field").classList.add("hidden");
28
+ let allJaxFields = document.querySelectorAll("#jax-fields > div");
29
+ let allTorchFields = document.querySelectorAll(
30
+ "#torch-fields > div",
31
+ );
32
+ if (allJaxFields.length !== allTorchFields.length) {
33
+ Swal.fire({
34
+ icon: "error",
35
+ title:
36
+ "The number of fields in JAX and PyTorch should be the same",
37
+ });
38
+ } else {
39
+ for (let i = 0; i < allJaxFields.length; i++) {
40
+ let jaxField = allJaxFields[i];
41
+ let torchField = allTorchFields[i];
42
+ let jaxShape = jaxField.getAttribute("data-shape");
43
+ let torchShape = torchField.getAttribute("data-shape");
44
+
45
+ jaxShape = jaxShape
46
+ .replace("(", "")
47
+ .replace(")", "")
48
+ .replace(/\s+/g, "")
49
+ .replace(/,\s*$/, "");
50
+
51
+ torchShape = torchShape
52
+ .replace("(", "")
53
+ .replace(")", "")
54
+ .replace(/\s+/g, "")
55
+ .replace(/,\s*$/, "");
56
+
57
+ let jaxShapeParts = jaxShape.split(",").map((x) => parseInt(x));
58
+ let torchShapeParts = torchShape
59
+ .split(",")
60
+ .map((x) => parseInt(x));
61
+ let jaxShapeProduct = jaxShapeParts.reduce((a, b) => a * b, 1);
62
+ let torchShapeProduct = torchShapeParts.reduce(
63
+ (a, b) => a * b,
64
+ 1,
65
+ );
66
+
67
+ let jaxEl = jaxField;
68
+ let torchEl = torchField;
69
+ if (jaxShapeProduct !== torchShapeProduct) {
70
+ jaxEl.classList.add("bg-error");
71
+ torchEl.classList.add("bg-error");
72
+ } else {
73
+ jaxEl.classList.remove("bg-error");
74
+ torchEl.classList.remove("bg-error");
75
+ }
76
+ }
77
+ }
31
78
  },
32
- });
33
- }
79
+ },
80
+ );
34
81
  </script>
35
82
  </head>
36
83
  <body class="w-10/12 mx-auto">
@@ -66,10 +113,8 @@
66
113
  xhr.setRequestHeader("Content-Type", "application/json");
67
114
  xhr.onload = function () {
68
115
  if (xhr.status >= 200 && xhr.status < 300) {
69
- // Successfully received response
70
116
  var container = document.getElementById("visualizationResult");
71
117
  container.innerHTML = xhr.responseText;
72
- // Execute any scripts that were in the response
73
118
  var scripts = container.getElementsByTagName("script");
74
119
  for (var i = 0; i < scripts.length; i++) {
75
120
  var script = document.createElement("script");
@@ -86,7 +131,7 @@
86
131
 
87
132
  function getJaxAndTorchFields() {
88
133
  const jaxFields = Array.from(
89
- document.querySelectorAll(".draggable-list")[0].children,
134
+ document.querySelectorAll("#jax-fields")[0].children,
90
135
  ).map((li) => {
91
136
  const path = li.getAttribute("data-path");
92
137
  const shape = li.getAttribute("data-shape");
@@ -95,7 +140,7 @@
95
140
  });
96
141
 
97
142
  const torchFields = Array.from(
98
- document.querySelectorAll(".draggable-list")[1].children,
143
+ document.querySelectorAll("#torch-fields")[0].children,
99
144
  ).map((li) => {
100
145
  const path = li.getAttribute("data-path");
101
146
  const shape = li.getAttribute("data-shape");
@@ -121,6 +166,11 @@
121
166
  icon: "error",
122
167
  title: `${jaxFields[i].path} has shape ${jaxFields[i].shape}, while ${torchFields[i].path} has shape ${torchFields[i].shape}`,
123
168
  });
169
+ document.getElementById("error-field").classList.remove("hidden");
170
+ document
171
+ .getElementById("error-field")
172
+ .querySelector("span").innerText =
173
+ `${jaxFields[i].path} has shape ${jaxFields[i].shape}, while ${torchFields[i].path} has shape ${torchFields[i].shape}`;
124
174
  return { error: "Invalid shapes" };
125
175
  }
126
176
  }
@@ -133,7 +183,8 @@
133
183
  if (fields.error) {
134
184
  Toast.fire({
135
185
  icon: "error",
136
- title: "The number of fields in JAX and PyTorch should be the same",
186
+ title: "Failed to convert!",
187
+ text: fields.error,
137
188
  });
138
189
  }
139
190
  const jaxFields = fields.jaxFields;
@@ -178,42 +229,36 @@
178
229
  </script>
179
230
  <h1 class="text-3xl my-12">Welcome to Torch2Jax</h1>
180
231
 
181
- <div class="flex space-x-4">
182
- <div class="w-full">
232
+ <div class="grid grid-cols-2 gap-x-2">
233
+ <div class="">
183
234
  <h2 class="text-2xl">JAX</h2>
184
- <ul class="draggable-list menu bg-base-200 rounded-box">
235
+ <div id="jax-fields" class="bg-base-200">
185
236
  {% for field in pytree_fields %}
186
- <li
187
- draggable="true"
237
+ <div
188
238
  data-path="{{field.path}}"
189
239
  data-shape="{{field.shape}}"
190
240
  data-type="{{field.type}}"
241
+ class="whitespace-nowrap overflow-x-scroll cursor-pointer"
191
242
  >
192
- <p
193
- class="tooltip text-left"
194
- data-path="{{field.path}}"
195
- data-tip="{{field.type}}"
196
- >
197
- {{ field.path}} {{field.shape }}
198
- </p>
199
- </li>
243
+ {{ field.path }} {{ field.shape }}
244
+ </div>
200
245
  {% endfor %}
201
- </ul>
246
+ </div>
202
247
  </div>
203
248
 
204
- <div class="w-full">
249
+ <div class="">
205
250
  <h2 class="text-2xl">PyTorch</h2>
206
- <ul class="draggable-list menu bg-base-200 rounded-box w-full">
251
+ <div id="torch-fields" class="bg-base-200">
207
252
  {% for field in torch_fields %}
208
- <li
209
- draggable="true"
253
+ <div
210
254
  data-path="{{field.path}}"
211
255
  data-shape="{{field.shape}}"
256
+ class="whitespace-nowrap overflow-x-scroll cursor-pointer"
212
257
  >
213
- <p class="text-left">{{ field.path }} {{ field.shape }}</p>
214
- </li>
258
+ {{ field.path }} {{ field.shape }}
259
+ </div>
215
260
  {% endfor %}
216
- </ul>
261
+ </div>
217
262
  </div>
218
263
  </div>
219
264
  <div class="flex justify-center my-12 w-full">
@@ -234,6 +279,23 @@
234
279
  </button>
235
280
  </div>
236
281
  </div>
282
+ <div role="alert" class="alert alert-error hidden" id="error-field">
283
+ <svg
284
+ xmlns="http://www.w3.org/2000/svg"
285
+ class="stroke-current shrink-0 h-6 w-6"
286
+ fill="none"
287
+ viewBox="0 0 24 24"
288
+ >
289
+ <path
290
+ stroke-linecap="round"
291
+ stroke-linejoin="round"
292
+ stroke-width="2"
293
+ d="M10 14l2-2m0 0l2-2m-2 2l-2-2m2 2l2 2m7-2a9 9 0 11-18 0 9 9 0 0118 0z"
294
+ />
295
+ </svg>
296
+ <span></span>
297
+ </div>
298
+
237
299
  <div class="flex justify-center">
238
300
  <button onclick="visualize()" class="btn btn-secondary">
239
301
  Visualize with Penzai!
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: statedict2pytree
3
- Version: 0.3.0
3
+ Version: 0.5.0
4
4
  Summary: Converts torch models into PyTrees for Equinox
5
5
  Author-email: "Artur A. Galstyan" <mail@arturgalstyan.dev>
6
6
  Requires-Python: ~=3.10
@@ -0,0 +1,8 @@
1
+ statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
2
+ statedict2pytree/statedict2pytree.py,sha256=yLOWx1D-6tX1VjiEg_-JcYPTrg6KWAgw6waZQi1GNvA,7229
3
+ statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
+ statedict2pytree/static/output.css,sha256=B0itthSyy_tduTWMyTK5sAry-W6WbeODnpQ-oOcQQng,33966
5
+ statedict2pytree/templates/index.html,sha256=Mbo8fFHV6kYRiBiiwayku-p-y3hUaLw_Yj3zn_cfmb0,10027
6
+ statedict2pytree-0.5.0.dist-info/METADATA,sha256=TOf10T0EZoPGAc0qSltZRcr8Ni7y4bHW7w3wzRVJH7A,4232
7
+ statedict2pytree-0.5.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
+ statedict2pytree-0.5.0.dist-info/RECORD,,
@@ -1,8 +0,0 @@
1
- statedict2pytree/__init__.py,sha256=lXxSaFFvkhXweXp5oHSkg_dPjdp49OsF8xoqwX4d_4E,240
2
- statedict2pytree/statedict2pytree.py,sha256=X5Ljf4lYhhH7_V4KgdciChncbTt7YZpIWHcOxcZ3l48,7103
3
- statedict2pytree/static/input.css,sha256=zBp60NAZ3bHTLQ7LWIugrCbOQdhiXdbDZjSLJfg6KOw,59
4
- statedict2pytree/static/output.css,sha256=KZ9GzeV3q0XKjbEiTdPkC6yV-R6jzXRflRm2S16VkJA,40813
5
- statedict2pytree/templates/index.html,sha256=0uG3dB2pAa1f2wcfTpYSO7TBNL77i2ALJP5rIhsbEnk,7506
6
- statedict2pytree-0.3.0.dist-info/METADATA,sha256=YSK4tWzNQemyZ1xKq5BhWiLWWc-RDr4E9q_eV_iOsdw,4232
7
- statedict2pytree-0.3.0.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
8
- statedict2pytree-0.3.0.dist-info/RECORD,,