statedict2pytree 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
client/public/bundle.js CHANGED
@@ -116,6 +116,12 @@ var app = (function () {
116
116
  return text(' ');
117
117
  }
118
118
 
119
+ /**
120
+ * @returns {Text} */
121
+ function empty() {
122
+ return text('');
123
+ }
124
+
119
125
  /**
120
126
  * @param {EventTarget} node
121
127
  * @param {string} event
@@ -153,6 +159,26 @@ var app = (function () {
153
159
  input.value = value == null ? '' : value;
154
160
  }
155
161
 
162
+ /**
163
+ * @returns {void} */
164
+ function select_option(select, value, mounting) {
165
+ for (let i = 0; i < select.options.length; i += 1) {
166
+ const option = select.options[i];
167
+ if (option.__value === value) {
168
+ option.selected = true;
169
+ return;
170
+ }
171
+ }
172
+ if (!mounting || value !== undefined) {
173
+ select.selectedIndex = -1; // no option should be selected
174
+ }
175
+ }
176
+
177
+ function select_value(select) {
178
+ const selected_option = select.querySelector(':checked');
179
+ return selected_option && selected_option.__value;
180
+ }
181
+
156
182
  /**
157
183
  * @template T
158
184
  * @param {string} type
@@ -8737,26 +8763,26 @@ var app = (function () {
8737
8763
 
8738
8764
  function get_each_context(ctx, list, i) {
8739
8765
  const child_ctx = ctx.slice();
8740
- child_ctx[14] = list[i];
8741
- child_ctx[16] = i;
8766
+ child_ctx[18] = list[i];
8767
+ child_ctx[20] = i;
8742
8768
  return child_ctx;
8743
8769
  }
8744
8770
 
8745
8771
  function get_each_context_1(ctx, list, i) {
8746
8772
  const child_ctx = ctx.slice();
8747
- child_ctx[14] = list[i];
8748
- child_ctx[16] = i;
8773
+ child_ctx[18] = list[i];
8774
+ child_ctx[20] = i;
8749
8775
  return child_ctx;
8750
8776
  }
8751
8777
 
8752
- // (231:12) {#each jaxFields as field, i}
8778
+ // (458:12) {#each jaxFields as field, i}
8753
8779
  function create_each_block_1(ctx) {
8754
8780
  let div1;
8755
8781
  let div0;
8756
- let t0_value = /*field*/ ctx[14].path + "";
8782
+ let t0_value = /*field*/ ctx[18].path + "";
8757
8783
  let t0;
8758
8784
  let t1;
8759
- let t2_value = /*field*/ ctx[14].shape + "";
8785
+ let t2_value = /*field*/ ctx[18].shape + "";
8760
8786
  let t2;
8761
8787
  let div0_data_path_value;
8762
8788
  let div0_data_shape_value;
@@ -8772,16 +8798,16 @@ var app = (function () {
8772
8798
  t1 = space();
8773
8799
  t2 = text(t2_value);
8774
8800
  t3 = space();
8775
- attr_dev(div0, "id", "jax-" + String(/*i*/ ctx[16]));
8801
+ attr_dev(div0, "id", "jax-" + String(/*i*/ ctx[20]));
8776
8802
  attr_dev(div0, "class", "whitespace-nowrap overflow-x-scroll cursor-pointer mx-2");
8777
8803
  attr_dev(div0, "data-jax", "jax");
8778
- attr_dev(div0, "data-path", div0_data_path_value = /*field*/ ctx[14].path);
8779
- attr_dev(div0, "data-shape", div0_data_shape_value = /*field*/ ctx[14].shape);
8780
- attr_dev(div0, "data-skip", div0_data_skip_value = /*field*/ ctx[14].skip);
8781
- attr_dev(div0, "data-type", div0_data_type_value = /*field*/ ctx[14].type);
8782
- add_location(div0, file, 268, 20, 7693);
8804
+ attr_dev(div0, "data-path", div0_data_path_value = /*field*/ ctx[18].path);
8805
+ attr_dev(div0, "data-shape", div0_data_shape_value = /*field*/ ctx[18].shape);
8806
+ attr_dev(div0, "data-skip", div0_data_skip_value = /*field*/ ctx[18].skip);
8807
+ attr_dev(div0, "data-type", div0_data_type_value = /*field*/ ctx[18].type);
8808
+ add_location(div0, file, 506, 20, 13315);
8783
8809
  attr_dev(div1, "class", "border h-12 rounded-xl flex flex-col justify-center");
8784
- add_location(div1, file, 265, 16, 7570);
8810
+ add_location(div1, file, 503, 16, 13192);
8785
8811
  },
8786
8812
  m: function mount(target, anchor) {
8787
8813
  insert_dev(target, div1, anchor);
@@ -8792,22 +8818,22 @@ var app = (function () {
8792
8818
  append_dev(div1, t3);
8793
8819
  },
8794
8820
  p: function update(ctx, dirty) {
8795
- if (dirty & /*jaxFields*/ 2 && t0_value !== (t0_value = /*field*/ ctx[14].path + "")) set_data_dev(t0, t0_value);
8796
- if (dirty & /*jaxFields*/ 2 && t2_value !== (t2_value = /*field*/ ctx[14].shape + "")) set_data_dev(t2, t2_value);
8821
+ if (dirty & /*jaxFields*/ 4 && t0_value !== (t0_value = /*field*/ ctx[18].path + "")) set_data_dev(t0, t0_value);
8822
+ if (dirty & /*jaxFields*/ 4 && t2_value !== (t2_value = /*field*/ ctx[18].shape + "")) set_data_dev(t2, t2_value);
8797
8823
 
8798
- if (dirty & /*jaxFields*/ 2 && div0_data_path_value !== (div0_data_path_value = /*field*/ ctx[14].path)) {
8824
+ if (dirty & /*jaxFields*/ 4 && div0_data_path_value !== (div0_data_path_value = /*field*/ ctx[18].path)) {
8799
8825
  attr_dev(div0, "data-path", div0_data_path_value);
8800
8826
  }
8801
8827
 
8802
- if (dirty & /*jaxFields*/ 2 && div0_data_shape_value !== (div0_data_shape_value = /*field*/ ctx[14].shape)) {
8828
+ if (dirty & /*jaxFields*/ 4 && div0_data_shape_value !== (div0_data_shape_value = /*field*/ ctx[18].shape)) {
8803
8829
  attr_dev(div0, "data-shape", div0_data_shape_value);
8804
8830
  }
8805
8831
 
8806
- if (dirty & /*jaxFields*/ 2 && div0_data_skip_value !== (div0_data_skip_value = /*field*/ ctx[14].skip)) {
8832
+ if (dirty & /*jaxFields*/ 4 && div0_data_skip_value !== (div0_data_skip_value = /*field*/ ctx[18].skip)) {
8807
8833
  attr_dev(div0, "data-skip", div0_data_skip_value);
8808
8834
  }
8809
8835
 
8810
- if (dirty & /*jaxFields*/ 2 && div0_data_type_value !== (div0_data_type_value = /*field*/ ctx[14].type)) {
8836
+ if (dirty & /*jaxFields*/ 4 && div0_data_type_value !== (div0_data_type_value = /*field*/ ctx[18].type)) {
8811
8837
  attr_dev(div0, "data-type", div0_data_type_value);
8812
8838
  }
8813
8839
  },
@@ -8822,19 +8848,19 @@ var app = (function () {
8822
8848
  block,
8823
8849
  id: create_each_block_1.name,
8824
8850
  type: "each",
8825
- source: "(231:12) {#each jaxFields as field, i}",
8851
+ source: "(458:12) {#each jaxFields as field, i}",
8826
8852
  ctx
8827
8853
  });
8828
8854
 
8829
8855
  return block;
8830
8856
  }
8831
8857
 
8832
- // (268:24) {:else}
8858
+ // (496:28) {:else}
8833
8859
  function create_else_block(ctx) {
8834
- let t0_value = /*field*/ ctx[14].path + "";
8860
+ let t0_value = /*field*/ ctx[18].path + "";
8835
8861
  let t0;
8836
8862
  let t1;
8837
- let t2_value = /*field*/ ctx[14].shape + "";
8863
+ let t2_value = /*field*/ ctx[18].shape + "";
8838
8864
  let t2;
8839
8865
 
8840
8866
  const block = {
@@ -8849,8 +8875,8 @@ var app = (function () {
8849
8875
  insert_dev(target, t2, anchor);
8850
8876
  },
8851
8877
  p: function update(ctx, dirty) {
8852
- if (dirty & /*torchFields*/ 4 && t0_value !== (t0_value = /*field*/ ctx[14].path + "")) set_data_dev(t0, t0_value);
8853
- if (dirty & /*torchFields*/ 4 && t2_value !== (t2_value = /*field*/ ctx[14].shape + "")) set_data_dev(t2, t2_value);
8878
+ if (dirty & /*torchFields*/ 8 && t0_value !== (t0_value = /*field*/ ctx[18].path + "")) set_data_dev(t0, t0_value);
8879
+ if (dirty & /*torchFields*/ 8 && t2_value !== (t2_value = /*field*/ ctx[18].shape + "")) set_data_dev(t2, t2_value);
8854
8880
  },
8855
8881
  d: function destroy(detaching) {
8856
8882
  if (detaching) {
@@ -8865,14 +8891,14 @@ var app = (function () {
8865
8891
  block,
8866
8892
  id: create_else_block.name,
8867
8893
  type: "else",
8868
- source: "(268:24) {:else}",
8894
+ source: "(496:28) {:else}",
8869
8895
  ctx
8870
8896
  });
8871
8897
 
8872
8898
  return block;
8873
8899
  }
8874
8900
 
8875
- // (266:24) {#if field.skip}
8901
+ // (494:28) {#if field.skip}
8876
8902
  function create_if_block_1(ctx) {
8877
8903
  let t;
8878
8904
 
@@ -8895,21 +8921,21 @@ var app = (function () {
8895
8921
  block,
8896
8922
  id: create_if_block_1.name,
8897
8923
  type: "if",
8898
- source: "(266:24) {#if field.skip}",
8924
+ source: "(494:28) {#if field.skip}",
8899
8925
  ctx
8900
8926
  });
8901
8927
 
8902
8928
  return block;
8903
8929
  }
8904
8930
 
8905
- // (273:20) {#if field.skip}
8931
+ // (501:24) {#if field.skip}
8906
8932
  function create_if_block(ctx) {
8907
8933
  let button;
8908
8934
  let mounted;
8909
8935
  let dispose;
8910
8936
 
8911
8937
  function click_handler() {
8912
- return /*click_handler*/ ctx[6](/*i*/ ctx[16]);
8938
+ return /*click_handler*/ ctx[11](/*i*/ ctx[20]);
8913
8939
  }
8914
8940
 
8915
8941
  const block = {
@@ -8917,7 +8943,7 @@ var app = (function () {
8917
8943
  button = element("button");
8918
8944
  button.textContent = "-";
8919
8945
  attr_dev(button, "class", "btn btn-ghost");
8920
- add_location(button, file, 307, 24, 9211);
8946
+ add_location(button, file, 546, 28, 14944);
8921
8947
  },
8922
8948
  m: function mount(target, anchor) {
8923
8949
  insert_dev(target, button, anchor);
@@ -8944,14 +8970,14 @@ var app = (function () {
8944
8970
  block,
8945
8971
  id: create_if_block.name,
8946
8972
  type: "if",
8947
- source: "(273:20) {#if field.skip}",
8973
+ source: "(501:24) {#if field.skip}",
8948
8974
  ctx
8949
8975
  });
8950
8976
 
8951
8977
  return block;
8952
8978
  }
8953
8979
 
8954
- // (255:12) {#each torchFields as field, i}
8980
+ // (483:16) {#each torchFields as field, i}
8955
8981
  function create_each_block(ctx) {
8956
8982
  let div1;
8957
8983
  let div0;
@@ -8967,16 +8993,16 @@ var app = (function () {
8967
8993
  let dispose;
8968
8994
 
8969
8995
  function select_block_type(ctx, dirty) {
8970
- if (/*field*/ ctx[14].skip) return create_if_block_1;
8996
+ if (/*field*/ ctx[18].skip) return create_if_block_1;
8971
8997
  return create_else_block;
8972
8998
  }
8973
8999
 
8974
9000
  let current_block_type = select_block_type(ctx);
8975
9001
  let if_block0 = current_block_type(ctx);
8976
- let if_block1 = /*field*/ ctx[14].skip && create_if_block(ctx);
9002
+ let if_block1 = /*field*/ ctx[18].skip && create_if_block(ctx);
8977
9003
 
8978
9004
  function click_handler_1() {
8979
- return /*click_handler_1*/ ctx[7](/*i*/ ctx[16]);
9005
+ return /*click_handler_1*/ ctx[12](/*i*/ ctx[20]);
8980
9006
  }
8981
9007
 
8982
9008
  const block = {
@@ -8990,18 +9016,18 @@ var app = (function () {
8990
9016
  button = element("button");
8991
9017
  button.textContent = "+";
8992
9018
  t3 = space();
8993
- attr_dev(div0, "id", "torch-" + String(/*i*/ ctx[16]));
9019
+ attr_dev(div0, "id", "torch-" + String(/*i*/ ctx[20]));
8994
9020
  attr_dev(div0, "data-torch", "torch");
8995
- attr_dev(div0, "data-path", div0_data_path_value = /*field*/ ctx[14].path);
8996
- attr_dev(div0, "data-shape", div0_data_shape_value = /*field*/ ctx[14].shape);
8997
- attr_dev(div0, "data-skip", div0_data_skip_value = /*field*/ ctx[14].skip);
8998
- attr_dev(div0, "data-type", div0_data_type_value = /*field*/ ctx[14].type);
9021
+ attr_dev(div0, "data-path", div0_data_path_value = /*field*/ ctx[18].path);
9022
+ attr_dev(div0, "data-shape", div0_data_shape_value = /*field*/ ctx[18].shape);
9023
+ attr_dev(div0, "data-skip", div0_data_skip_value = /*field*/ ctx[18].skip);
9024
+ attr_dev(div0, "data-type", div0_data_type_value = /*field*/ ctx[18].type);
8999
9025
  attr_dev(div0, "class", "flex-1 mx-2 my-auto whitespace-nowrap overflow-x-scroll cursor-pointer");
9000
- add_location(div0, file, 290, 20, 8491);
9026
+ add_location(div0, file, 529, 24, 14156);
9001
9027
  attr_dev(button, "class", "btn btn-ghost");
9002
- add_location(button, file, 314, 20, 9480);
9028
+ add_location(button, file, 553, 24, 15241);
9003
9029
  attr_dev(div1, "class", "flex space-x-2 border h-12 rounded-xl");
9004
- add_location(div1, file, 289, 16, 8419);
9030
+ add_location(div1, file, 528, 20, 14080);
9005
9031
  },
9006
9032
  m: function mount(target, anchor) {
9007
9033
  insert_dev(target, div1, anchor);
@@ -9033,23 +9059,23 @@ var app = (function () {
9033
9059
  }
9034
9060
  }
9035
9061
 
9036
- if (dirty & /*torchFields*/ 4 && div0_data_path_value !== (div0_data_path_value = /*field*/ ctx[14].path)) {
9062
+ if (dirty & /*torchFields*/ 8 && div0_data_path_value !== (div0_data_path_value = /*field*/ ctx[18].path)) {
9037
9063
  attr_dev(div0, "data-path", div0_data_path_value);
9038
9064
  }
9039
9065
 
9040
- if (dirty & /*torchFields*/ 4 && div0_data_shape_value !== (div0_data_shape_value = /*field*/ ctx[14].shape)) {
9066
+ if (dirty & /*torchFields*/ 8 && div0_data_shape_value !== (div0_data_shape_value = /*field*/ ctx[18].shape)) {
9041
9067
  attr_dev(div0, "data-shape", div0_data_shape_value);
9042
9068
  }
9043
9069
 
9044
- if (dirty & /*torchFields*/ 4 && div0_data_skip_value !== (div0_data_skip_value = /*field*/ ctx[14].skip)) {
9070
+ if (dirty & /*torchFields*/ 8 && div0_data_skip_value !== (div0_data_skip_value = /*field*/ ctx[18].skip)) {
9045
9071
  attr_dev(div0, "data-skip", div0_data_skip_value);
9046
9072
  }
9047
9073
 
9048
- if (dirty & /*torchFields*/ 4 && div0_data_type_value !== (div0_data_type_value = /*field*/ ctx[14].type)) {
9074
+ if (dirty & /*torchFields*/ 8 && div0_data_type_value !== (div0_data_type_value = /*field*/ ctx[18].type)) {
9049
9075
  attr_dev(div0, "data-type", div0_data_type_value);
9050
9076
  }
9051
9077
 
9052
- if (/*field*/ ctx[14].skip) {
9078
+ if (/*field*/ ctx[18].skip) {
9053
9079
  if (if_block1) {
9054
9080
  if_block1.p(ctx, dirty);
9055
9081
  } else {
@@ -9078,7 +9104,78 @@ var app = (function () {
9078
9104
  block,
9079
9105
  id: create_each_block.name,
9080
9106
  type: "each",
9081
- source: "(255:12) {#each torchFields as field, i}",
9107
+ source: "(483:16) {#each torchFields as field, i}",
9108
+ ctx
9109
+ });
9110
+
9111
+ return block;
9112
+ }
9113
+
9114
+ // (482:12) {#key torchFields}
9115
+ function create_key_block(ctx) {
9116
+ let each_1_anchor;
9117
+ let each_value = ensure_array_like_dev(/*torchFields*/ ctx[3]);
9118
+ let each_blocks = [];
9119
+
9120
+ for (let i = 0; i < each_value.length; i += 1) {
9121
+ each_blocks[i] = create_each_block(get_each_context(ctx, each_value, i));
9122
+ }
9123
+
9124
+ const block = {
9125
+ c: function create() {
9126
+ for (let i = 0; i < each_blocks.length; i += 1) {
9127
+ each_blocks[i].c();
9128
+ }
9129
+
9130
+ each_1_anchor = empty();
9131
+ },
9132
+ m: function mount(target, anchor) {
9133
+ for (let i = 0; i < each_blocks.length; i += 1) {
9134
+ if (each_blocks[i]) {
9135
+ each_blocks[i].m(target, anchor);
9136
+ }
9137
+ }
9138
+
9139
+ insert_dev(target, each_1_anchor, anchor);
9140
+ },
9141
+ p: function update(ctx, dirty) {
9142
+ if (dirty & /*addSkipLayer, removeSkipLayer, torchFields, String*/ 56) {
9143
+ each_value = ensure_array_like_dev(/*torchFields*/ ctx[3]);
9144
+ let i;
9145
+
9146
+ for (i = 0; i < each_value.length; i += 1) {
9147
+ const child_ctx = get_each_context(ctx, each_value, i);
9148
+
9149
+ if (each_blocks[i]) {
9150
+ each_blocks[i].p(child_ctx, dirty);
9151
+ } else {
9152
+ each_blocks[i] = create_each_block(child_ctx);
9153
+ each_blocks[i].c();
9154
+ each_blocks[i].m(each_1_anchor.parentNode, each_1_anchor);
9155
+ }
9156
+ }
9157
+
9158
+ for (; i < each_blocks.length; i += 1) {
9159
+ each_blocks[i].d(1);
9160
+ }
9161
+
9162
+ each_blocks.length = each_value.length;
9163
+ }
9164
+ },
9165
+ d: function destroy(detaching) {
9166
+ if (detaching) {
9167
+ detach_dev(each_1_anchor);
9168
+ }
9169
+
9170
+ destroy_each(each_blocks, detaching);
9171
+ }
9172
+ };
9173
+
9174
+ dispatch_dev("SvelteRegisterBlock", {
9175
+ block,
9176
+ id: create_key_block.name,
9177
+ type: "key",
9178
+ source: "(482:12) {#key torchFields}",
9082
9179
  ctx
9083
9180
  });
9084
9181
 
@@ -9089,37 +9186,46 @@ var app = (function () {
9089
9186
  let t0;
9090
9187
  let h1;
9091
9188
  let t2;
9092
- let div4;
9093
9189
  let div1;
9094
- let h20;
9190
+ let button0;
9095
9191
  let t4;
9192
+ let button1;
9193
+ let t6;
9096
9194
  let div0;
9097
- let t5;
9098
- let div3;
9099
- let h21;
9100
- let t7;
9101
- let div2;
9195
+ let button2;
9102
9196
  let t8;
9197
+ let select;
9198
+ let option0;
9199
+ let option1;
9200
+ let option2;
9201
+ let t12;
9103
9202
  let div6;
9203
+ let div3;
9204
+ let h20;
9205
+ let t14;
9206
+ let div2;
9207
+ let t15;
9104
9208
  let div5;
9209
+ let h21;
9210
+ let t17;
9211
+ let div4;
9212
+ let previous_key = /*torchFields*/ ctx[3];
9213
+ let t18;
9214
+ let div8;
9215
+ let div7;
9105
9216
  let input;
9106
- let t9;
9107
- let button;
9217
+ let t19;
9218
+ let button3;
9108
9219
  let mounted;
9109
9220
  let dispose;
9110
- let each_value_1 = ensure_array_like_dev(/*jaxFields*/ ctx[1]);
9111
- let each_blocks_1 = [];
9221
+ let each_value_1 = ensure_array_like_dev(/*jaxFields*/ ctx[2]);
9222
+ let each_blocks = [];
9112
9223
 
9113
9224
  for (let i = 0; i < each_value_1.length; i += 1) {
9114
- each_blocks_1[i] = create_each_block_1(get_each_context_1(ctx, each_value_1, i));
9225
+ each_blocks[i] = create_each_block_1(get_each_context_1(ctx, each_value_1, i));
9115
9226
  }
9116
9227
 
9117
- let each_value = ensure_array_like_dev(/*torchFields*/ ctx[2]);
9118
- let each_blocks = [];
9119
-
9120
- for (let i = 0; i < each_value.length; i += 1) {
9121
- each_blocks[i] = create_each_block(get_each_context(ctx, each_value, i));
9122
- }
9228
+ let key_block = create_key_block(ctx);
9123
9229
 
9124
9230
  const block = {
9125
9231
  c: function create() {
@@ -9127,66 +9233,101 @@ var app = (function () {
9127
9233
  h1 = element("h1");
9128
9234
  h1.textContent = "Welcome to Torch2Jax";
9129
9235
  t2 = space();
9130
- div4 = element("div");
9131
9236
  div1 = element("div");
9132
- h20 = element("h2");
9133
- h20.textContent = "JAX";
9237
+ button0 = element("button");
9238
+ button0.textContent = "Pad To Match";
9134
9239
  t4 = space();
9240
+ button1 = element("button");
9241
+ button1.textContent = "Remove All Skip Layers";
9242
+ t6 = space();
9135
9243
  div0 = element("div");
9136
-
9137
- for (let i = 0; i < each_blocks_1.length; i += 1) {
9138
- each_blocks_1[i].c();
9139
- }
9140
-
9141
- t5 = space();
9244
+ button2 = element("button");
9245
+ button2.textContent = "Match By Name";
9246
+ t8 = space();
9247
+ select = element("select");
9248
+ option0 = element("option");
9249
+ option0.textContent = "opus";
9250
+ option1 = element("option");
9251
+ option1.textContent = "sonnet";
9252
+ option2 = element("option");
9253
+ option2.textContent = "haiku";
9254
+ t12 = space();
9255
+ div6 = element("div");
9142
9256
  div3 = element("div");
9143
- h21 = element("h2");
9144
- h21.textContent = "PyTorch";
9145
- t7 = space();
9257
+ h20 = element("h2");
9258
+ h20.textContent = "JAX";
9259
+ t14 = space();
9146
9260
  div2 = element("div");
9147
9261
 
9148
9262
  for (let i = 0; i < each_blocks.length; i += 1) {
9149
9263
  each_blocks[i].c();
9150
9264
  }
9151
9265
 
9152
- t8 = space();
9153
- div6 = element("div");
9266
+ t15 = space();
9154
9267
  div5 = element("div");
9268
+ h21 = element("h2");
9269
+ h21.textContent = "PyTorch";
9270
+ t17 = space();
9271
+ div4 = element("div");
9272
+ key_block.c();
9273
+ t18 = space();
9274
+ div8 = element("div");
9275
+ div7 = element("div");
9155
9276
  input = element("input");
9156
- t9 = space();
9157
- button = element("button");
9158
- button.textContent = "Convert!";
9277
+ t19 = space();
9278
+ button3 = element("button");
9279
+ button3.textContent = "Convert!";
9159
9280
  document_1.title = "Statedict2PyTree";
9160
9281
  attr_dev(h1, "class", "text-3xl my-12");
9161
- add_location(h1, file, 258, 0, 7323);
9282
+ add_location(h1, file, 481, 0, 12367);
9283
+ attr_dev(button0, "class", "btn btn-accent");
9284
+ add_location(button0, file, 483, 4, 12463);
9285
+ attr_dev(button1, "class", "btn btn-secondary");
9286
+ add_location(button1, file, 484, 4, 12542);
9287
+ attr_dev(button2, "class", "btn btn-warning");
9288
+ add_location(button2, file, 488, 8, 12671);
9289
+ option0.__value = "opus";
9290
+ set_input_value(option0, option0.__value);
9291
+ add_location(option0, file, 492, 12, 12828);
9292
+ option1.__value = "sonnet";
9293
+ set_input_value(option1, option1.__value);
9294
+ add_location(option1, file, 493, 12, 12875);
9295
+ option2.__value = "haiku";
9296
+ set_input_value(option2, option2.__value);
9297
+ add_location(option2, file, 494, 12, 12926);
9298
+ if (/*anthropicModel*/ ctx[1] === void 0) add_render_callback(() => /*select_change_handler*/ ctx[10].call(select));
9299
+ add_location(select, file, 491, 8, 12779);
9300
+ add_location(div0, file, 487, 4, 12657);
9301
+ attr_dev(div1, "class", "my-4 flex justify-evenly");
9302
+ add_location(div1, file, 482, 0, 12420);
9162
9303
  attr_dev(h20, "class", "text-2xl");
9163
- add_location(h20, file, 262, 8, 7443);
9164
- attr_dev(div0, "id", "jax-fields");
9165
- attr_dev(div0, "class", "");
9166
- add_location(div0, file, 263, 8, 7481);
9167
- attr_dev(div1, "class", "");
9168
- add_location(div1, file, 261, 4, 7420);
9169
- attr_dev(h21, "class", "text-2xl");
9170
- add_location(h21, file, 286, 8, 8284);
9171
- attr_dev(div2, "id", "torch-fields");
9304
+ add_location(h20, file, 500, 8, 13065);
9305
+ attr_dev(div2, "id", "jax-fields");
9172
9306
  attr_dev(div2, "class", "");
9173
- add_location(div2, file, 287, 8, 8326);
9307
+ add_location(div2, file, 501, 8, 13103);
9174
9308
  attr_dev(div3, "class", "");
9175
- add_location(div3, file, 285, 4, 8261);
9176
- attr_dev(div4, "class", "grid grid-cols-2 gap-x-2");
9177
- add_location(div4, file, 260, 0, 7377);
9309
+ add_location(div3, file, 499, 4, 13042);
9310
+ attr_dev(h21, "class", "text-2xl");
9311
+ add_location(h21, file, 524, 8, 13906);
9312
+ attr_dev(div4, "id", "torch-fields");
9313
+ attr_dev(div4, "class", "");
9314
+ add_location(div4, file, 525, 8, 13948);
9315
+ attr_dev(div5, "class", "");
9316
+ add_location(div5, file, 523, 4, 13883);
9317
+ attr_dev(div6, "class", "grid grid-cols-2 gap-x-2");
9318
+ add_location(div6, file, 498, 0, 12999);
9178
9319
  attr_dev(input, "id", "name");
9179
9320
  attr_dev(input, "type", "text");
9180
9321
  attr_dev(input, "name", "name");
9181
9322
  attr_dev(input, "class", "input input-primary w-full");
9182
9323
  attr_dev(input, "placeholder", "Name of the new file (model.eqx per default)");
9183
- add_location(input, file, 327, 8, 9865);
9184
- attr_dev(button, "class", "btn btn-accent btn-wide btn-lg mx-auto my-2");
9185
- add_location(button, file, 335, 8, 10110);
9186
- attr_dev(div5, "class", "flex flex-col justify-center w-full");
9187
- add_location(div5, file, 326, 4, 9807);
9188
- attr_dev(div6, "class", "flex justify-center my-12 w-full");
9189
- add_location(div6, file, 325, 0, 9756);
9324
+ add_location(input, file, 567, 8, 15673);
9325
+ attr_dev(button3, "class", "btn btn-accent btn-wide btn-lg mx-auto my-2");
9326
+ add_location(button3, file, 575, 8, 15918);
9327
+ attr_dev(div7, "class", "flex flex-col justify-center w-full");
9328
+ add_location(div7, file, 566, 4, 15615);
9329
+ attr_dev(div8, "class", "flex justify-center my-12 w-full");
9330
+ add_location(div8, file, 565, 0, 15564);
9190
9331
  },
9191
9332
  l: function claim(nodes) {
9192
9333
  throw new Error("options.hydrate only works if the component was compiled with the `hydratable: true` option");
@@ -9195,22 +9336,24 @@ var app = (function () {
9195
9336
  insert_dev(target, t0, anchor);
9196
9337
  insert_dev(target, h1, anchor);
9197
9338
  insert_dev(target, t2, anchor);
9198
- insert_dev(target, div4, anchor);
9199
- append_dev(div4, div1);
9200
- append_dev(div1, h20);
9339
+ insert_dev(target, div1, anchor);
9340
+ append_dev(div1, button0);
9201
9341
  append_dev(div1, t4);
9342
+ append_dev(div1, button1);
9343
+ append_dev(div1, t6);
9202
9344
  append_dev(div1, div0);
9203
-
9204
- for (let i = 0; i < each_blocks_1.length; i += 1) {
9205
- if (each_blocks_1[i]) {
9206
- each_blocks_1[i].m(div0, null);
9207
- }
9208
- }
9209
-
9210
- append_dev(div4, t5);
9211
- append_dev(div4, div3);
9212
- append_dev(div3, h21);
9213
- append_dev(div3, t7);
9345
+ append_dev(div0, button2);
9346
+ append_dev(div0, t8);
9347
+ append_dev(div0, select);
9348
+ append_dev(select, option0);
9349
+ append_dev(select, option1);
9350
+ append_dev(select, option2);
9351
+ select_option(select, /*anthropicModel*/ ctx[1], true);
9352
+ insert_dev(target, t12, anchor);
9353
+ insert_dev(target, div6, anchor);
9354
+ append_dev(div6, div3);
9355
+ append_dev(div3, h20);
9356
+ append_dev(div3, t14);
9214
9357
  append_dev(div3, div2);
9215
9358
 
9216
9359
  for (let i = 0; i < each_blocks.length; i += 1) {
@@ -9219,58 +9362,49 @@ var app = (function () {
9219
9362
  }
9220
9363
  }
9221
9364
 
9222
- insert_dev(target, t8, anchor);
9223
- insert_dev(target, div6, anchor);
9365
+ append_dev(div6, t15);
9224
9366
  append_dev(div6, div5);
9225
- append_dev(div5, input);
9367
+ append_dev(div5, h21);
9368
+ append_dev(div5, t17);
9369
+ append_dev(div5, div4);
9370
+ key_block.m(div4, null);
9371
+ insert_dev(target, t18, anchor);
9372
+ insert_dev(target, div8, anchor);
9373
+ append_dev(div8, div7);
9374
+ append_dev(div7, input);
9226
9375
  set_input_value(input, /*model*/ ctx[0]);
9227
- append_dev(div5, t9);
9228
- append_dev(div5, button);
9376
+ append_dev(div7, t19);
9377
+ append_dev(div7, button3);
9229
9378
 
9230
9379
  if (!mounted) {
9231
9380
  dispose = [
9232
- listen_dev(input, "input", /*input_input_handler*/ ctx[8]),
9233
- listen_dev(button, "click", /*convert*/ ctx[5], false, false, false, false)
9381
+ listen_dev(button0, "click", /*padToMatch*/ ctx[7], false, false, false, false),
9382
+ listen_dev(button1, "click", /*removeAllSkipLayers*/ ctx[8], false, false, false, false),
9383
+ listen_dev(button2, "click", /*matchByName*/ ctx[9], false, false, false, false),
9384
+ listen_dev(select, "change", /*select_change_handler*/ ctx[10]),
9385
+ listen_dev(input, "input", /*input_input_handler*/ ctx[13]),
9386
+ listen_dev(button3, "click", /*convert*/ ctx[6], false, false, false, false)
9234
9387
  ];
9235
9388
 
9236
9389
  mounted = true;
9237
9390
  }
9238
9391
  },
9239
9392
  p: function update(ctx, [dirty]) {
9240
- if (dirty & /*String, jaxFields*/ 2) {
9241
- each_value_1 = ensure_array_like_dev(/*jaxFields*/ ctx[1]);
9242
- let i;
9243
-
9244
- for (i = 0; i < each_value_1.length; i += 1) {
9245
- const child_ctx = get_each_context_1(ctx, each_value_1, i);
9246
-
9247
- if (each_blocks_1[i]) {
9248
- each_blocks_1[i].p(child_ctx, dirty);
9249
- } else {
9250
- each_blocks_1[i] = create_each_block_1(child_ctx);
9251
- each_blocks_1[i].c();
9252
- each_blocks_1[i].m(div0, null);
9253
- }
9254
- }
9255
-
9256
- for (; i < each_blocks_1.length; i += 1) {
9257
- each_blocks_1[i].d(1);
9258
- }
9259
-
9260
- each_blocks_1.length = each_value_1.length;
9393
+ if (dirty & /*anthropicModel*/ 2) {
9394
+ select_option(select, /*anthropicModel*/ ctx[1]);
9261
9395
  }
9262
9396
 
9263
- if (dirty & /*addSkipLayer, removeSkipLayer, torchFields, String*/ 28) {
9264
- each_value = ensure_array_like_dev(/*torchFields*/ ctx[2]);
9397
+ if (dirty & /*String, jaxFields*/ 4) {
9398
+ each_value_1 = ensure_array_like_dev(/*jaxFields*/ ctx[2]);
9265
9399
  let i;
9266
9400
 
9267
- for (i = 0; i < each_value.length; i += 1) {
9268
- const child_ctx = get_each_context(ctx, each_value, i);
9401
+ for (i = 0; i < each_value_1.length; i += 1) {
9402
+ const child_ctx = get_each_context_1(ctx, each_value_1, i);
9269
9403
 
9270
9404
  if (each_blocks[i]) {
9271
9405
  each_blocks[i].p(child_ctx, dirty);
9272
9406
  } else {
9273
- each_blocks[i] = create_each_block(child_ctx);
9407
+ each_blocks[i] = create_each_block_1(child_ctx);
9274
9408
  each_blocks[i].c();
9275
9409
  each_blocks[i].m(div2, null);
9276
9410
  }
@@ -9280,7 +9414,16 @@ var app = (function () {
9280
9414
  each_blocks[i].d(1);
9281
9415
  }
9282
9416
 
9283
- each_blocks.length = each_value.length;
9417
+ each_blocks.length = each_value_1.length;
9418
+ }
9419
+
9420
+ if (dirty & /*torchFields*/ 8 && safe_not_equal(previous_key, previous_key = /*torchFields*/ ctx[3])) {
9421
+ key_block.d(1);
9422
+ key_block = create_key_block(ctx);
9423
+ key_block.c();
9424
+ key_block.m(div4, null);
9425
+ } else {
9426
+ key_block.p(ctx, dirty);
9284
9427
  }
9285
9428
 
9286
9429
  if (dirty & /*model*/ 1 && input.value !== /*model*/ ctx[0]) {
@@ -9294,13 +9437,15 @@ var app = (function () {
9294
9437
  detach_dev(t0);
9295
9438
  detach_dev(h1);
9296
9439
  detach_dev(t2);
9297
- detach_dev(div4);
9298
- detach_dev(t8);
9440
+ detach_dev(div1);
9441
+ detach_dev(t12);
9299
9442
  detach_dev(div6);
9443
+ detach_dev(t18);
9444
+ detach_dev(div8);
9300
9445
  }
9301
9446
 
9302
- destroy_each(each_blocks_1, detaching);
9303
9447
  destroy_each(each_blocks, detaching);
9448
+ key_block.d(detaching);
9304
9449
  mounted = false;
9305
9450
  run_all(dispose);
9306
9451
  }
@@ -9317,13 +9462,6 @@ var app = (function () {
9317
9462
  return block;
9318
9463
  }
9319
9464
 
9320
- function swap(a, b, array) {
9321
- const temp = array[a];
9322
- array[a] = array[b];
9323
- array[b] = temp;
9324
- return array;
9325
- }
9326
-
9327
9465
  function checkFields(jaxFields, torchFields) {
9328
9466
  if (jaxFields.length > torchFields.length) {
9329
9467
  return {
@@ -9359,6 +9497,7 @@ var app = (function () {
9359
9497
  let { $$slots: slots = {}, $$scope } = $$props;
9360
9498
  validate_slots('App', slots, []);
9361
9499
  let model = "model.eqx";
9500
+ let anthropicModel = "haiku";
9362
9501
 
9363
9502
  const Toast = Swal.mixin({
9364
9503
  toast: true,
@@ -9374,13 +9513,12 @@ var app = (function () {
9374
9513
 
9375
9514
  let jaxFields = [];
9376
9515
  let torchFields = [];
9377
- let torchSortable;
9378
9516
 
9379
9517
  onMount(async () => {
9380
9518
  let req = await fetch("/startup/getJaxFields");
9381
- $$invalidate(1, jaxFields = await req.json());
9519
+ $$invalidate(2, jaxFields = await req.json());
9382
9520
  req = await fetch("/startup/getTorchFields");
9383
- $$invalidate(2, torchFields = await req.json());
9521
+ $$invalidate(3, torchFields = await req.json());
9384
9522
 
9385
9523
  setTimeout(
9386
9524
  () => {
@@ -9388,10 +9526,17 @@ var app = (function () {
9388
9526
  },
9389
9527
  100
9390
9528
  );
9529
+
9530
+ setTimeout(
9531
+ () => {
9532
+ onEnd();
9533
+ },
9534
+ 500
9535
+ );
9391
9536
  });
9392
9537
 
9393
9538
  function initSortable() {
9394
- torchSortable = new Sortable(document.getElementById("torch-fields"),
9539
+ new Sortable(document.getElementById("torch-fields"),
9395
9540
  {
9396
9541
  animation: 150,
9397
9542
  multiDrag: true,
@@ -9466,74 +9611,67 @@ var app = (function () {
9466
9611
  }
9467
9612
 
9468
9613
  function onEnd() {
9469
- var _a, _b, _c, _d, _e, _f, _g;
9470
- const updatedFields = fetchJaxAndTorchFields();
9471
-
9472
- if (updatedFields.error) {
9473
- Toast.fire({
9474
- icon: "error",
9475
- title: updatedFields.error
9476
- });
9477
-
9478
- return;
9479
- }
9480
-
9481
- for (let i = 0; i < updatedFields.jaxFields.length; i++) {
9482
- let jaxField = updatedFields.jaxFields[i];
9483
- let torchField = updatedFields.torchFields[i];
9484
- if (torchField === undefined) continue;
9614
+ setTimeout(
9615
+ () => {
9616
+ var _a, _b, _c, _d;
9617
+ const updatedFields = fetchJaxAndTorchFields();
9485
9618
 
9486
- if (torchField.skip === true) {
9487
- (_a = document.getElementById("jax-" + i)) === null || _a === void 0
9488
- ? void 0
9489
- : _a.classList.remove("bg-error");
9619
+ if (updatedFields.error) {
9620
+ Toast.fire({
9621
+ icon: "error",
9622
+ title: updatedFields.error
9623
+ });
9490
9624
 
9491
- (_b = document.getElementById("torch-" + i)) === null || _b === void 0
9492
- ? void 0
9493
- : _b.classList.remove("bg-error");
9625
+ return;
9626
+ }
9494
9627
 
9495
- continue;
9496
- }
9628
+ for (let i = 0; i < updatedFields.jaxFields.length; i++) {
9629
+ let jaxField = updatedFields.jaxFields[i];
9630
+ let torchField = updatedFields.torchFields[i];
9631
+ if (torchField === undefined) continue;
9497
9632
 
9498
- let jaxShape = jaxField.shape;
9499
- let torchShape = torchField.shape;
9633
+ if (torchField.skip === true) {
9634
+ (_a = document.getElementById("jax-" + i)) === null || _a === void 0
9635
+ ? void 0
9636
+ : _a.classList.remove("bg-error");
9500
9637
 
9501
- //@ts-ignore
9502
- let jaxShapeProduct = jaxShape.reduce((a, b) => a * b, 1);
9638
+ continue;
9639
+ }
9503
9640
 
9504
- //@ts-ignore
9505
- let torchShapeProduct = torchShape.reduce((a, b) => a * b, 1);
9641
+ let jaxShape = jaxField.shape;
9642
+ let torchShape = torchField.shape;
9506
9643
 
9507
- if (jaxShapeProduct !== torchShapeProduct) {
9508
- (_c = document.getElementById("jax-" + i)) === null || _c === void 0
9509
- ? void 0
9510
- : _c.classList.add("bg-error");
9644
+ //@ts-ignore
9645
+ let jaxShapeProduct = jaxShape.reduce((a, b) => a * b, 1);
9511
9646
 
9512
- (_d = document.getElementById("torch-" + i)) === null || _d === void 0
9513
- ? void 0
9514
- : _d.classList.add("bg-error");
9515
- } else {
9516
- (_e = document.getElementById("jax-" + i)) === null || _e === void 0
9517
- ? void 0
9518
- : _e.classList.remove("bg-error");
9647
+ //@ts-ignore
9648
+ let torchShapeProduct = torchShape.reduce((a, b) => a * b, 1);
9519
9649
 
9520
- (_f = document.getElementById("torch-" + i)) === null || _f === void 0
9521
- ? void 0
9522
- : _f.classList.remove("bg-error");
9523
- }
9524
- }
9650
+ if (jaxShapeProduct !== torchShapeProduct) {
9651
+ (_b = document.getElementById("jax-" + i)) === null || _b === void 0
9652
+ ? void 0
9653
+ : _b.classList.add("bg-error");
9654
+ } else {
9655
+ (_c = document.getElementById("jax-" + i)) === null || _c === void 0
9656
+ ? void 0
9657
+ : _c.classList.remove("bg-error");
9658
+ }
9659
+ }
9525
9660
 
9526
- if (updatedFields.torchFields.length > updatedFields.jaxFields.length) {
9527
- for (let i = updatedFields.jaxFields.length; i < updatedFields.torchFields.length; i++) {
9528
- (_g = document.getElementById("torch-" + i)) === null || _g === void 0
9529
- ? void 0
9530
- : _g.classList.remove("bg-error");
9531
- }
9532
- }
9661
+ if (updatedFields.torchFields.length > updatedFields.jaxFields.length) {
9662
+ for (let i = updatedFields.jaxFields.length; i < updatedFields.torchFields.length; i++) {
9663
+ (_d = document.getElementById("torch-" + i)) === null || _d === void 0
9664
+ ? void 0
9665
+ : _d.classList.remove("bg-error");
9666
+ }
9667
+ }
9668
+ },
9669
+ 100
9670
+ );
9533
9671
  }
9534
9672
 
9535
9673
  function removeSkipLayer(index) {
9536
- $$invalidate(2, torchFields = torchFields.toSpliced(index, 1));
9674
+ $$invalidate(3, torchFields = torchFields.toSpliced(index, 1));
9537
9675
 
9538
9676
  setTimeout(
9539
9677
  () => {
@@ -9544,14 +9682,21 @@ var app = (function () {
9544
9682
  }
9545
9683
 
9546
9684
  function addSkipLayer(index) {
9685
+ let fields = fetchJaxAndTorchFields();
9686
+
9687
+ if (fields.error) {
9688
+ Toast.fire({ icon: "error", text: fields.error });
9689
+ return;
9690
+ }
9691
+
9547
9692
  const newField = {
9548
9693
  skip: true,
9549
- shape: [],
9550
- path: "",
9551
- type: ""
9694
+ shape: [0],
9695
+ path: "SKIP",
9696
+ type: "SKIP"
9552
9697
  };
9553
9698
 
9554
- $$invalidate(2, torchFields = torchFields.toSpliced(index, 0, newField));
9699
+ $$invalidate(3, torchFields = fields.torchFields.toSpliced(index, 0, newField));
9555
9700
 
9556
9701
  setTimeout(
9557
9702
  () => {
@@ -9604,12 +9749,241 @@ var app = (function () {
9604
9749
  }
9605
9750
  }
9606
9751
 
9752
+ function padToMatch() {
9753
+ let fields = fetchJaxAndTorchFields();
9754
+
9755
+ if (fields.error) {
9756
+ Toast.fire({ icon: "error", text: fields.error });
9757
+ return;
9758
+ }
9759
+
9760
+ if (fields.torchFields.length < fields.jaxFields.length) {
9761
+ let toAdd = fields.jaxFields.length - fields.torchFields.length;
9762
+
9763
+ for (let i = 0; i < toAdd; i++) {
9764
+ setTimeout(
9765
+ () => {
9766
+ console.log("adding skip at ", i);
9767
+ addSkipLayer(fields.jaxFields.length + i);
9768
+ },
9769
+ 100
9770
+ );
9771
+ }
9772
+ }
9773
+ }
9774
+
9775
+ function removeAllSkipLayers() {
9776
+ let fields = fetchJaxAndTorchFields();
9777
+
9778
+ if (fields.error) {
9779
+ Toast.fire({ icon: "error", text: fields.error });
9780
+ return;
9781
+ }
9782
+
9783
+ let filteredFields = [];
9784
+
9785
+ for (let i = 0; i < fields.torchFields.length; i++) {
9786
+ if (fields.torchFields[i].skip === false) {
9787
+ filteredFields.push(fields.torchFields[i]);
9788
+ }
9789
+ }
9790
+
9791
+ $$invalidate(3, torchFields = filteredFields);
9792
+ }
9793
+
9794
+ async function matchByName() {
9795
+ let fields = fetchJaxAndTorchFields();
9796
+
9797
+ if (fields.error) {
9798
+ Toast.fire({ icon: "error", text: fields.error });
9799
+ return;
9800
+ }
9801
+
9802
+ if (fields.jaxFields.length !== fields.torchFields.length) {
9803
+ Toast.fire({
9804
+ icon: "error",
9805
+ text: "PyTree and State Dict have diffent lengths. Make sure to pad first!"
9806
+ });
9807
+
9808
+ return;
9809
+ }
9810
+
9811
+ Toast.fire({
9812
+ icon: "info",
9813
+ title: "Matching by name...",
9814
+ text: "This can take a while! Hold tight."
9815
+ });
9816
+
9817
+ let content = `
9818
+ You will get two lists of strings. These strings are fields of a JAX and PyTorch model.
9819
+ For example:
9820
+ --JAX START--
9821
+ .layers[0].weight
9822
+ .layers[1].weight
9823
+ .layers[2].weight
9824
+ .layers[3].weight
9825
+ .layers[4].weight
9826
+ --JAX END--
9827
+
9828
+ --PYTORCH START--
9829
+ layers.0.weight
9830
+ layers.1.weight
9831
+ layers.4.weight
9832
+ layers.2.weight
9833
+ layers.3.weight
9834
+ --PYTORCH END--
9835
+
9836
+ As you can see, the order doesn't match. This means, you should look at the PyTorch fields and
9837
+ rearrange them, such that they match the JAX model. In the above example, the expected return value
9838
+ is:
9839
+ --PYTORCH START--
9840
+ layers.0.weight
9841
+ layers.1.weight
9842
+ layers.2.weight
9843
+ layers.3.weight
9844
+ layers.4.weight
9845
+ --PYTORCH END--
9846
+
9847
+ Here's another example:
9848
+ --JAX START--
9849
+ .conv1.weight
9850
+ .bn1.weight
9851
+ .bn1.bias
9852
+ .bn1.state_index.init[0]
9853
+ .bn1.state_index.init[1]
9854
+ --JAX END--
9855
+
9856
+ --PYTORCH START--
9857
+ bn1.running_mean
9858
+ bn1.running_var
9859
+ conv1.weight
9860
+ bn1.weight
9861
+ bn1.bias
9862
+ --PYTORCH END--
9863
+
9864
+ The expected return value in this case is:
9865
+
9866
+ --PYTORCH START--
9867
+ conv1.weight
9868
+ bn1.weight
9869
+ bn1.bias
9870
+ bn1.running_mean
9871
+ bn1.running_var
9872
+ --PYTORCH END--
9873
+
9874
+
9875
+ Sometimes, there are so-called "skip-layers" in the PyTorch model. Those can be put anywhere, preferably to
9876
+ the end, because your priority is to match those fields that can be matched first! Here's an example:
9877
+
9878
+ --JAX START--
9879
+ .layers[0].weight
9880
+ .layers[1].weight
9881
+ .layers[2].weight
9882
+ .layers[3].weight
9883
+ .layers[4].weight
9884
+ --JAX END--
9885
+
9886
+ --PYTORCH START--
9887
+ layers.0.weight
9888
+ SKIP
9889
+ layers.3.weight
9890
+ layers.2.weight
9891
+ layers.1.weight
9892
+ --PYTORCH START--
9893
+
9894
+ This should return
9895
+
9896
+ --PYTORCH START--
9897
+ layers.0.weight
9898
+ layers.1.weight
9899
+ layers.2.weight
9900
+ layers.3.weight
9901
+ SKIP
9902
+ --PYTORCH START--
9903
+
9904
+
9905
+ It's not always 100% which belongs to which. Use your best judgement. Start your response with
9906
+ --PYTORCH START-- and end it with --PYTORCH END--.
9907
+
9908
+
9909
+ Here's your input:
9910
+ --JAX START--
9911
+ `;
9912
+
9913
+ for (let i = 0; i < fields.jaxFields.length; i++) {
9914
+ content += fields.jaxFields[i].path + "\n";
9915
+ }
9916
+
9917
+ content += "--JAX END--\n";
9918
+ content += "\n";
9919
+ content += "--PYTORCH START--\n";
9920
+
9921
+ for (let i = 0; i < fields.torchFields.length; i++) {
9922
+ content += fields.torchFields[i].path + "\n";
9923
+ }
9924
+
9925
+ content += "--PYTORCH END--";
9926
+ console.log(content);
9927
+
9928
+ let req = await fetch("/anthropic", {
9929
+ method: "POST",
9930
+ headers: { "Content-Type": "application/json" },
9931
+ body: JSON.stringify({ content, model: anthropicModel })
9932
+ });
9933
+
9934
+ let res = await req.json();
9935
+
9936
+ if (res.error) {
9937
+ Toast.fire({ icon: "error", text: res.error });
9938
+ return;
9939
+ }
9940
+
9941
+ console.log(res);
9942
+ let responseContent = res.content;
9943
+ let lines = responseContent.split("\n");
9944
+ console.log(lines);
9945
+ let rearrangedTorchFields = [];
9946
+
9947
+ for (let i = 0; i < lines.length; i++) {
9948
+ let matchingTorchField = fields.torchFields.find(field => field.path === lines[i]);
9949
+
9950
+ if (matchingTorchField) {
9951
+ rearrangedTorchFields.push(matchingTorchField);
9952
+ }
9953
+ }
9954
+
9955
+ if (fields.torchFields.length !== rearrangedTorchFields.length) {
9956
+ Toast.fire({
9957
+ icon: "error",
9958
+ text: "Some fields are missing in the response. Try a different model instead."
9959
+ });
9960
+
9961
+ return;
9962
+ }
9963
+
9964
+ console.log("rearrangedTorchFields", rearrangedTorchFields);
9965
+
9966
+ setTimeout(
9967
+ () => {
9968
+ $$invalidate(3, torchFields = rearrangedTorchFields);
9969
+ onEnd();
9970
+ Toast.fire({ icon: "success", title: "Success" });
9971
+ },
9972
+ 500
9973
+ );
9974
+ }
9975
+
9607
9976
  const writable_props = [];
9608
9977
 
9609
9978
  Object.keys($$props).forEach(key => {
9610
9979
  if (!~writable_props.indexOf(key) && key.slice(0, 2) !== '$$' && key !== 'slot') console_1.warn(`<App> was created with unknown prop '${key}'`);
9611
9980
  });
9612
9981
 
9982
+ function select_change_handler() {
9983
+ anthropicModel = select_value(this);
9984
+ $$invalidate(1, anthropicModel);
9985
+ }
9986
+
9613
9987
  const click_handler = i => {
9614
9988
  removeSkipLayer(i);
9615
9989
  };
@@ -9628,25 +10002,27 @@ var app = (function () {
9628
10002
  onMount,
9629
10003
  Swal,
9630
10004
  model,
10005
+ anthropicModel,
9631
10006
  Toast,
9632
10007
  jaxFields,
9633
10008
  torchFields,
9634
- torchSortable,
9635
10009
  initSortable,
9636
- swap,
9637
10010
  fetchJaxAndTorchFields,
9638
10011
  onEnd,
9639
10012
  checkFields,
9640
10013
  removeSkipLayer,
9641
10014
  addSkipLayer,
9642
- convert
10015
+ convert,
10016
+ padToMatch,
10017
+ removeAllSkipLayers,
10018
+ matchByName
9643
10019
  });
9644
10020
 
9645
10021
  $$self.$inject_state = $$props => {
9646
10022
  if ('model' in $$props) $$invalidate(0, model = $$props.model);
9647
- if ('jaxFields' in $$props) $$invalidate(1, jaxFields = $$props.jaxFields);
9648
- if ('torchFields' in $$props) $$invalidate(2, torchFields = $$props.torchFields);
9649
- if ('torchSortable' in $$props) torchSortable = $$props.torchSortable;
10023
+ if ('anthropicModel' in $$props) $$invalidate(1, anthropicModel = $$props.anthropicModel);
10024
+ if ('jaxFields' in $$props) $$invalidate(2, jaxFields = $$props.jaxFields);
10025
+ if ('torchFields' in $$props) $$invalidate(3, torchFields = $$props.torchFields);
9650
10026
  };
9651
10027
 
9652
10028
  if ($$props && "$$inject" in $$props) {
@@ -9655,11 +10031,16 @@ var app = (function () {
9655
10031
 
9656
10032
  return [
9657
10033
  model,
10034
+ anthropicModel,
9658
10035
  jaxFields,
9659
10036
  torchFields,
9660
10037
  removeSkipLayer,
9661
10038
  addSkipLayer,
9662
10039
  convert,
10040
+ padToMatch,
10041
+ removeAllSkipLayers,
10042
+ matchByName,
10043
+ select_change_handler,
9663
10044
  click_handler,
9664
10045
  click_handler_1,
9665
10046
  input_input_handler