@genai-fi/nanogpt 0.11.0 → 0.12.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (236)
  1. package/dist/Generator.js +29 -29
  2. package/dist/{RealDiv-Ds-jvL09.js → RealDiv-C8neBwFi.js} +17 -17
  3. package/dist/{Reshape-Cd6e-Otn.js → Reshape-Bd4V_4X7.js} +1 -1
  4. package/dist/{Reshape-Ct266DEk.js → Reshape-Ck29jQSY.js} +7 -7
  5. package/dist/TeachableLLM.d.ts +2 -1
  6. package/dist/TeachableLLM.js +9 -9
  7. package/dist/Trainer.d.ts +4 -2
  8. package/dist/Trainer.js +11 -8
  9. package/dist/{axis_util-DofAuy0p.js → axis_util-DGqbT-FX.js} +1 -1
  10. package/dist/backend.js +2 -2
  11. package/dist/{backend_util-C7NWHpv7.js → backend_util-DC3rBo_H.js} +18 -18
  12. package/dist/{backend_webgpu-B0Vls736.js → backend_webgpu-mbhNnlx9.js} +10 -10
  13. package/dist/{broadcast_to-DDaNMbX7.js → broadcast_to-D1Dmg2Oz.js} +2 -2
  14. package/dist/checks/appendCache.js +2 -2
  15. package/dist/checks/attentionMask.js +3 -3
  16. package/dist/checks/gelu.js +2 -2
  17. package/dist/checks/matMulGelu.js +2 -2
  18. package/dist/checks/normRMS.js +4 -4
  19. package/dist/checks/normRMSGrad.js +3 -3
  20. package/dist/checks/packUnpack.js +2 -2
  21. package/dist/checks/qkv.js +2 -2
  22. package/dist/checks/rope.js +2 -2
  23. package/dist/clip_by_value-fg2aKzUy.js +12 -0
  24. package/dist/{complex-DClmWqJt.js → complex-Cyg-eQeZ.js} +1 -1
  25. package/dist/concat-CSm2rMwe.js +17 -0
  26. package/dist/{concat_util-CHsJFZJJ.js → concat_util-D0je5Ppu.js} +1 -1
  27. package/dist/{dataset-DcjWqUVQ.js → dataset-CVIJu7Xa.js} +3 -3
  28. package/dist/{dropout-OxuaJz6z.js → dropout-DLhSMNTZ.js} +14 -14
  29. package/dist/expand_dims-ChkuOp6I.js +11 -0
  30. package/dist/{exports_initializers-eS9QJ6ut.js → exports_initializers-1KWPiStI.js} +1 -1
  31. package/dist/{floor-DIb-lN_u.js → floor-BRMPgeIs.js} +1 -1
  32. package/dist/gather-BSULDalH.js +9 -0
  33. package/dist/{gelu-DqTbCx5x.js → gelu-BK1k-n1i.js} +1 -1
  34. package/dist/{gpgpu_math-CJcbnKPC.js → gpgpu_math-BJSTk_mW.js} +25 -25
  35. package/dist/{index-Dj5TkmPY.js → index-BBVLAXZD.js} +14 -14
  36. package/dist/{index-D0RBWjq8.js → index-Duu1Lvvv.js} +45 -45
  37. package/dist/{kernel_funcs_utils-CSaumNDs.js → kernel_funcs_utils-BtYrPoJu.js} +8 -8
  38. package/dist/layers/BaseLayer.js +2 -2
  39. package/dist/layers/CausalSelfAttention.js +6 -6
  40. package/dist/layers/MLP.js +4 -4
  41. package/dist/layers/PositionEmbedding.js +5 -5
  42. package/dist/layers/RMSNorm.js +3 -3
  43. package/dist/layers/RoPECache.js +4 -4
  44. package/dist/layers/TiedEmbedding.js +6 -6
  45. package/dist/layers/TransformerBlock.js +1 -1
  46. package/dist/loader/loadTransformers.js +1 -1
  47. package/dist/loader/oldZipLoad.js +17 -17
  48. package/dist/{log_sum_exp-VLZgbFAH.js → log_sum_exp-CVqLsVLl.js} +4 -4
  49. package/dist/main.d.ts +9 -0
  50. package/dist/main.js +68 -58
  51. package/dist/{matMul16-cDxwemKj.js → matMul16-xswmhSuF.js} +7 -7
  52. package/dist/{matMulGelu-B2s_80-H.js → matMulGelu-BpvgnYG8.js} +26 -26
  53. package/dist/mat_mul-Bn2BDpT4.js +11 -0
  54. package/dist/{mod-PrOKlFxH.js → mod-B4AUd1Np.js} +1 -1
  55. package/dist/models/NanoGPTV1.js +2 -2
  56. package/dist/models/model.js +9 -9
  57. package/dist/{ones-BX_wEgzB.js → ones-CBI1AQjb.js} +3 -3
  58. package/dist/ops/adamAdjust.js +1 -1
  59. package/dist/ops/adamMoments.js +1 -1
  60. package/dist/ops/add16.js +1 -1
  61. package/dist/ops/appendCache.js +3 -3
  62. package/dist/ops/attentionMask.js +1 -1
  63. package/dist/ops/concat16.js +2 -2
  64. package/dist/ops/cpu/adamAdjust.js +7 -7
  65. package/dist/ops/cpu/adamMoments.js +5 -5
  66. package/dist/ops/cpu/appendCache.js +6 -6
  67. package/dist/ops/cpu/attentionMask.js +6 -6
  68. package/dist/ops/cpu/fusedSoftmax.js +5 -5
  69. package/dist/ops/cpu/gatherSub.js +7 -7
  70. package/dist/ops/cpu/gelu.js +5 -5
  71. package/dist/ops/cpu/matMul16.js +2 -2
  72. package/dist/ops/cpu/matMulGelu.js +3 -3
  73. package/dist/ops/cpu/matMulMul.js +5 -5
  74. package/dist/ops/cpu/mulDropout.js +1 -1
  75. package/dist/ops/cpu/normRMS.js +5 -5
  76. package/dist/ops/cpu/qkv.js +3 -3
  77. package/dist/ops/cpu/rope.js +9 -9
  78. package/dist/ops/cpu/scatterSub.js +5 -5
  79. package/dist/ops/dot16.js +2 -2
  80. package/dist/ops/gatherSub.js +1 -1
  81. package/dist/ops/gelu.js +2 -2
  82. package/dist/ops/grads/add16.js +1 -1
  83. package/dist/ops/grads/attentionMask.js +2 -2
  84. package/dist/ops/grads/gelu.js +2 -2
  85. package/dist/ops/grads/matMul16.js +3 -3
  86. package/dist/ops/grads/matMulGelu.js +5 -5
  87. package/dist/ops/grads/normRMS.js +6 -6
  88. package/dist/ops/grads/pack16.js +3 -3
  89. package/dist/ops/grads/qkv.js +9 -9
  90. package/dist/ops/grads/rope.js +2 -2
  91. package/dist/ops/grads/softmax16.js +1 -1
  92. package/dist/ops/grads/unpack16.js +2 -2
  93. package/dist/ops/matMul16.js +3 -3
  94. package/dist/ops/matMulGelu.js +2 -2
  95. package/dist/ops/matMulMul.js +1 -1
  96. package/dist/ops/mul16.js +1 -1
  97. package/dist/ops/mulDrop.js +1 -1
  98. package/dist/ops/normRMS.js +1 -1
  99. package/dist/ops/pack16.js +2 -2
  100. package/dist/ops/qkv.js +1 -1
  101. package/dist/ops/reshape16.js +6 -6
  102. package/dist/ops/rope.js +2 -2
  103. package/dist/ops/scatterSub.js +1 -1
  104. package/dist/ops/slice16.js +2 -2
  105. package/dist/ops/softmax16.js +1 -1
  106. package/dist/ops/sub16.js +1 -1
  107. package/dist/ops/sum16.js +2 -2
  108. package/dist/ops/transpose16.js +6 -6
  109. package/dist/ops/unpack16.js +2 -2
  110. package/dist/ops/webgl/adamAdjust.js +2 -2
  111. package/dist/ops/webgl/adamMoments.js +1 -1
  112. package/dist/ops/webgl/appendCache.js +1 -1
  113. package/dist/ops/webgl/attentionMask.js +4 -4
  114. package/dist/ops/webgl/fusedSoftmax.js +6 -6
  115. package/dist/ops/webgl/gatherSub.js +1 -1
  116. package/dist/ops/webgl/gelu.js +2 -2
  117. package/dist/ops/webgl/log.js +3 -3
  118. package/dist/ops/webgl/matMul16.js +10 -10
  119. package/dist/ops/webgl/matMulGelu.js +4 -4
  120. package/dist/ops/webgl/matMulMul.js +2 -2
  121. package/dist/ops/webgl/mulDropout.js +1 -1
  122. package/dist/ops/webgl/normRMS.js +2 -2
  123. package/dist/ops/webgl/qkv.js +1 -1
  124. package/dist/ops/webgl/rope.js +4 -4
  125. package/dist/ops/webgl/scatterSub.js +1 -1
  126. package/dist/ops/webgpu/adamAdjust.js +3 -3
  127. package/dist/ops/webgpu/adamMoments.js +5 -5
  128. package/dist/ops/webgpu/add16.js +1 -1
  129. package/dist/ops/webgpu/appendCache.js +3 -3
  130. package/dist/ops/webgpu/attentionMask.js +5 -5
  131. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  132. package/dist/ops/webgpu/concat16.js +5 -5
  133. package/dist/ops/webgpu/gatherSub.js +3 -3
  134. package/dist/ops/webgpu/gelu.js +3 -3
  135. package/dist/ops/webgpu/matMul16.js +19 -19
  136. package/dist/ops/webgpu/matMul16_program.js +2 -2
  137. package/dist/ops/webgpu/mul16.js +1 -1
  138. package/dist/ops/webgpu/normRMS.js +2 -2
  139. package/dist/ops/webgpu/normRMSGrad.js +4 -4
  140. package/dist/ops/webgpu/pack16.js +3 -3
  141. package/dist/ops/webgpu/pack16_program.js +2 -2
  142. package/dist/ops/webgpu/qkv.js +4 -4
  143. package/dist/ops/webgpu/rope.js +3 -3
  144. package/dist/ops/webgpu/scatterSub.js +3 -3
  145. package/dist/ops/webgpu/slice16.js +4 -4
  146. package/dist/ops/webgpu/softmax16.js +4 -4
  147. package/dist/ops/webgpu/softmax16_program.js +2 -2
  148. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  149. package/dist/ops/webgpu/softmax16grad.js +1 -1
  150. package/dist/ops/webgpu/sub16.js +1 -1
  151. package/dist/ops/webgpu/sum16.js +5 -5
  152. package/dist/ops/webgpu/transpose16.js +2 -2
  153. package/dist/ops/webgpu/transpose16_program.js +2 -2
  154. package/dist/ops/webgpu/transpose16_shared_program.js +3 -3
  155. package/dist/ops/webgpu/unpack16.js +5 -5
  156. package/dist/ops/webgpu/utils/binary_op.js +3 -3
  157. package/dist/ops/webgpu/utils/reductions.js +4 -4
  158. package/dist/{ops-FJapAPfm.js → ops-C2_OXuZ4.js} +35 -35
  159. package/dist/{pack16-k4jq6aMX.js → pack16-atD0eYRm.js} +6 -6
  160. package/dist/patches/webgpu_backend.js +8 -8
  161. package/dist/patches/webgpu_base.js +1 -1
  162. package/dist/patches/webgpu_program.js +2 -2
  163. package/dist/{random_width-UGQn4OWb.js → random_width-BN4wGJaW.js} +33 -33
  164. package/dist/{range-CuGvVN2c.js → range-DKmP1-OQ.js} +1 -1
  165. package/dist/relu-BsXmGzzu.js +9 -0
  166. package/dist/{reshape-CkjKPPqB.js → reshape-BI0yzp1T.js} +1 -1
  167. package/dist/{resize_nearest_neighbor-DB8k9KN_.js → resize_nearest_neighbor-BA_BX-ub.js} +25 -25
  168. package/dist/{rope-BmZmp9uP.js → rope-DJ7Y7c-u.js} +1 -1
  169. package/dist/{scatter_nd_util-BY22Cc-C.js → scatter_nd_util-k9MUVUkn.js} +1 -1
  170. package/dist/{selu_util-BuLbmbrl.js → selu_util-DyW0X1WG.js} +5 -5
  171. package/dist/{shared-B7USJZgw.js → shared-Q3BS6T03.js} +1 -1
  172. package/dist/{shared-BQboIImQ.js → shared-nnSWpC3u.js} +6 -6
  173. package/dist/{slice-Aqy7KbJh.js → slice-wBNvzVyz.js} +3 -3
  174. package/dist/{slice_util-D8CQRenR.js → slice_util-zN8KFC5I.js} +7 -7
  175. package/dist/{softmax-faLoUZVT.js → softmax-DfuYyjMh.js} +1 -1
  176. package/dist/split-BYrLboMq.js +9 -0
  177. package/dist/squeeze-Bk8Brcct.js +10 -0
  178. package/dist/{stack-WJK22CFn.js → stack-CDWShFHF.js} +1 -1
  179. package/dist/{step-dXR33iOg.js → step-BS5JXRR6.js} +14 -14
  180. package/dist/sum-BPUfDB2X.js +11 -0
  181. package/dist/{tensor-BQqrDvpx.js → tensor-CEt9Nm2s.js} +1 -1
  182. package/dist/{tensor1d-LxP9asMm.js → tensor1d-Cc_KCIDg.js} +1 -1
  183. package/dist/{tensor2d-BN1sSfQO.js → tensor2d-BN97fF71.js} +1 -1
  184. package/dist/{tensor4d-DVwr7pLF.js → tensor4d-vuDDgdUI.js} +1 -1
  185. package/dist/{tfjs_backend-Vi4JfLzT.js → tfjs_backend-806hyYve.js} +36 -36
  186. package/dist/tile-OWUvpIVt.js +11 -0
  187. package/dist/tokeniser/BaseTokeniser.d.ts +6 -8
  188. package/dist/tokeniser/BaseTokeniser.js +6 -6
  189. package/dist/tokeniser/CharTokeniser.d.ts +6 -6
  190. package/dist/tokeniser/CharTokeniser.js +26 -26
  191. package/dist/tokeniser/bpe.d.ts +6 -6
  192. package/dist/tokeniser/bpe.js +9 -9
  193. package/dist/tokeniser/type.d.ts +6 -8
  194. package/dist/training/Adam.js +2 -2
  195. package/dist/training/AdamExt.js +1 -1
  196. package/dist/training/DatasetBuilder.d.ts +1 -1
  197. package/dist/training/DatasetBuilder.js +29 -29
  198. package/dist/training/FullTrainer.js +1 -1
  199. package/dist/training/Trainer.d.ts +5 -4
  200. package/dist/training/Trainer.js +22 -25
  201. package/dist/training/sparseCrossEntropy.js +3 -3
  202. package/dist/training/tasks/ConversationTask.d.ts +11 -0
  203. package/dist/training/tasks/ConversationTask.js +26 -0
  204. package/dist/training/tasks/PretrainingTask.d.ts +11 -0
  205. package/dist/training/tasks/PretrainingTask.js +34 -0
  206. package/dist/training/tasks/StartSentenceTask.d.ts +12 -0
  207. package/dist/training/tasks/StartSentenceTask.js +42 -0
  208. package/dist/training/tasks/Task.d.ts +8 -0
  209. package/dist/training/tasks/Task.js +41 -0
  210. package/dist/{transpose-JawVKyZy.js → transpose-BUkQCJp9.js} +7 -7
  211. package/dist/{unsorted_segment_sum-LAbmE9G4.js → unsorted_segment_sum-BljxHhCY.js} +78 -78
  212. package/dist/utilities/dummy.js +3 -3
  213. package/dist/utilities/multinomialCPU.js +2 -2
  214. package/dist/utilities/packed.js +1 -1
  215. package/dist/utilities/performance.js +1 -1
  216. package/dist/utilities/profile.js +1 -1
  217. package/dist/utilities/safetensors.js +2 -2
  218. package/dist/utilities/sentences.d.ts +1 -1
  219. package/dist/utilities/sentences.js +11 -11
  220. package/dist/utilities/weights.js +2 -2
  221. package/dist/{variable-DQ9yYgEU.js → variable-DPt_Iuog.js} +1 -1
  222. package/dist/{webgpu_program-CAE4RICo.js → webgpu_program-BpWRlghH.js} +1 -1
  223. package/dist/{webgpu_util-BdovYhXr.js → webgpu_util-DMiKzzQM.js} +7 -7
  224. package/dist/{zeros-DeiE2zTa.js → zeros-5YROwwUH.js} +2 -2
  225. package/dist/{zeros_like-BAz3iKru.js → zeros_like-De4n1C3m.js} +57 -57
  226. package/package.json +1 -1
  227. package/dist/clip_by_value-Dn5tzexi.js +0 -12
  228. package/dist/concat-C6X3AAlQ.js +0 -17
  229. package/dist/expand_dims-BzfJK2uc.js +0 -11
  230. package/dist/gather-BcO5UQNJ.js +0 -9
  231. package/dist/mat_mul-DxpNTCRz.js +0 -11
  232. package/dist/relu-Cf80uA2p.js +0 -9
  233. package/dist/split-BNz5jcGc.js +0 -9
  234. package/dist/squeeze--YMgaAAf.js +0 -10
  235. package/dist/sum-BdplSvq_.js +0 -11
  236. package/dist/tile-CvN_LyVr.js +0 -11
@@ -0,0 +1,42 @@
1
import { Task } from "./Task.js";

/**
 * Task that turns an array of raw strings into conversations.
 * Each string is split at its first "." into a user prompt (the first
 * sentence) and an assistant reply (the remainder). A string with no "."
 * becomes a single assistant-only message.
 */
class StartSentenceTask extends Task {
    rawText;
    // Index of the next unread entry in rawText.
    index = 0;

    /** Number of source strings (one conversation each). */
    get length() {
        return this.rawText.length;
    }

    /**
     * @param {string[]} rawText - one entry per conversation source string
     */
    constructor(rawText) {
        super();
        this.rawText = rawText;
    }

    /** True while nextConversation() can still produce a conversation. */
    hasMoreConversations() {
        return this.index < this.rawText.length;
    }

    /**
     * Consume and return the next conversation, or null when exhausted.
     */
    nextConversation() {
        if (this.index >= this.rawText.length)
            return null;
        const text = this.rawText[this.index];
        this.index++;
        return this.conversationFromString(text);
    }

    /**
     * Split one raw string into conversation messages.
     *
     * BUG FIX: the assistant-only branch previously read
     * `this.rawText[this.index]`, but `index` has already been advanced by
     * nextConversation() before this method runs, so it returned the
     * FOLLOWING entry (and `undefined` for the last one). It must use the
     * string that was passed in.
     */
    conversationFromString(text) {
        const dot = text.indexOf(".");
        if (dot === -1) {
            return [{
                role: "assistant",
                content: text
            }];
        }
        return [
            {
                role: "user",
                content: text.slice(0, dot + 1).trim()
            },
            {
                role: "assistant",
                content: text.slice(dot + 1).trim()
            }
        ];
    }

    /**
     * Rough token estimate: encode the first entry's conversation and
     * assume every entry yields the same number of tokens.
     */
    async estimateTokens(tokeniser) {
        const sample = await tokeniser.encodeConversation(
            this.conversationFromString(this.rawText[0])
        );
        return sample.length * this.length;
    }
}

export {
    StartSentenceTask as default
};
@@ -0,0 +1,8 @@
1
import { Conversation, ITokeniser } from '../../main';
/**
 * Abstract contract for a training task: a source that yields a sequence of
 * conversations and can estimate its total token output.
 */
export declare abstract class Task {
    /** Number of conversations this task contains. */
    abstract get length(): number;
    /** True while nextConversation() can still return a conversation. */
    abstract hasMoreConversations(): boolean;
    /** The next conversation, or null once the task is exhausted. */
    abstract nextConversation(): Conversation[] | null;
    /** Estimated total token count for the whole task, per the tokeniser. */
    abstract estimateTokens(tokeniser: ITokeniser): Promise<number>;
}
/** Encode every conversation from all tasks into one contiguous token buffer. */
export declare function tokensFromTasks(tasks: Task[], tokenizer: ITokeniser): Promise<Uint16Array>;
@@ -0,0 +1,41 @@
1
// Runtime base Task class. The abstract contract lives in Task.d.ts; at
// runtime this class carries no behaviour of its own.
class g {
}
// Pull at most one conversation from EACH task in `f` and append its encoded
// tokens into the buffer list `a`. `e.offset` is the write position inside
// the LAST buffer of `a`; `r` is the original total-size estimate, used to
// size overflow buffers (10% of estimate + 100).
//
// NOTE(review): `l.encodeConversation(c)` is used synchronously here, yet
// StartSentenceTask awaits the same method — if it returns a Promise, the
// `.length` reads below are undefined and this loop writes garbage. Confirm
// ITokeniser.encodeConversation is synchronous.
// NOTE(review): if one encoded conversation exceeds the remaining space PLUS
// a fresh overflow buffer, the second `set` throws a RangeError — presumably
// conversations are short relative to the estimate; verify.
function h(f, a, l, e, r) {
  for (let i = 0; i < f.length; i++) {
    const c = f[i].nextConversation();
    if (c) {
      const o = l.encodeConversation(c), s = a[a.length - 1];
      if (e.offset + o.length > s.length) {
        // Fill the tail of the current buffer, then spill the remainder
        // into a newly allocated overflow buffer.
        const n = s.length - e.offset;
        s.set(o.slice(0, n), e.offset);
        const t = new Uint16Array(Math.floor(r * 0.1) + 100);
        t.set(o.slice(n), 0), a.push(t), e.offset = o.length - n;
      } else
        s.set(o, e.offset), e.offset += o.length;
    }
  }
}
// Encode all conversations from tasks `f` with tokeniser `a` into a single
// Uint16Array. One buffer is pre-allocated from the summed per-task
// estimates; overflow buffers are appended when the estimate falls short.
// Yields to the event loop (via requestAnimationFrame) roughly every 40 ms
// so long encodings don't freeze the UI.
async function w(f, a) {
  const l = (await Promise.all(f.map((n) => n.estimateTokens(a)))).reduce(
    (n, t) => n + t,
    0
  ), e = [new Uint16Array(l)], r = {
    offset: 0
  };
  let i = performance.now();
  // Loop until the estimate is filled or every task is exhausted. h() runs
  // inside the loop condition and mutates e/r as a side effect.
  for (; r.offset < l && (h(f, e, a, r, l), !f.every((t) => !t.hasMoreConversations())); )
    performance.now() - i > 40 && (await new Promise(requestAnimationFrame), i = performance.now());
  if (e.length === 1)
    return e[0].subarray(0, r.offset);
  // Multiple buffers: concatenate, trimming the unused tail of the last one
  // (r.offset is the fill level of the LAST buffer only).
  const c = e.reduce((n, t) => n + t.length, 0) - (e[e.length - 1].length - r.offset), o = new Uint16Array(c);
  let s = 0;
  for (let n = 0; n < e.length; n++) {
    const t = e[n];
    n === e.length - 1 ? (o.set(t.subarray(0, r.offset), s), s += r.offset) : (o.set(t, s), s += t.length);
  }
  return o;
}
export {
  g as Task,
  w as tokensFromTasks
};
@@ -1,5 +1,5 @@
1
- import { q as u, u as i, E as o, ap as $, aq as g, ar as m, y as l, t as x, as as p } from "./index-D0RBWjq8.js";
2
- import { c as k } from "./complex-DClmWqJt.js";
1
+ import { o as u, q as i, E as o, ap as $, aq as g, ar as x, x as l, t as m, as as p } from "./index-Duu1Lvvv.js";
2
+ import { c as k } from "./complex-Cyg-eQeZ.js";
3
3
  function K(r) {
4
4
  const e = { input: i(r, "input", "imag") };
5
5
  return o.runKernel($, e);
@@ -12,25 +12,25 @@ function E(r) {
12
12
  const _ = /* @__PURE__ */ u({ neg_: E });
13
13
  function b(r) {
14
14
  const e = { input: i(r, "input", "real") };
15
- return o.runKernel(m, e);
15
+ return o.runKernel(x, e);
16
16
  }
17
17
  const d = /* @__PURE__ */ u({ real_: b });
18
- function y(r, t, e) {
18
+ function N(r, t, e) {
19
19
  const n = i(r, "x", "transpose");
20
20
  if (t == null && (t = n.shape.map((s, a) => a).reverse()), l(n.rank === t.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${t}.`), t.forEach((s) => {
21
21
  l(s >= 0 && s < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${t}`);
22
22
  }), n.rank <= 1)
23
23
  return n.clone();
24
24
  const f = { x: n }, c = { perm: t };
25
- return n.dtype === "complex64" ? x(() => {
25
+ return n.dtype === "complex64" ? m(() => {
26
26
  let s = d(n), a = h(n);
27
27
  return s = o.runKernel(p, { x: s }, c), a = o.runKernel(p, { x: a }, c), e && (a = _(a)), k(s, a);
28
28
  }) : o.runKernel(p, f, c);
29
29
  }
30
- const q = /* @__PURE__ */ u({ transpose_: y });
30
+ const v = /* @__PURE__ */ u({ transpose_: N });
31
31
  export {
32
32
  h as i,
33
33
  _ as n,
34
34
  d as r,
35
- q as t
35
+ v as t
36
36
  };
@@ -1,30 +1,30 @@
1
- import { q as h, u as c, E as d, bo as T, bp as q, bq as H, y as l, br as P, N as _, bs as y, bt as I, bu as W, bv as B, bw as A, bx as G, by as L, bz as O, bA as z, bB as F, D as M, $ as j, bC as J, bD as Q, bE as U, a2 as V, c as N, m as X, bF as Y, bG as Z, bH as R, bI as nn, bJ as tn, bK as sn, bL as en, bM as rn, bN as on, bO as an, bP as un, aG as cn, bQ as ln } from "./index-D0RBWjq8.js";
2
- import { k as C, c as g, m as D } from "./step-dXR33iOg.js";
3
- import { r as b } from "./reshape-CkjKPPqB.js";
4
- import { m as pn, a as hn, e as w } from "./log_sum_exp-VLZgbFAH.js";
5
- import { s as K } from "./sum-BdplSvq_.js";
1
+ import { o as h, q as c, E as d, bo as T, bp as q, bq as H, x as l, br as P, L as _, bs as y, bt as B, bu as I, bv as W, bw as A, bx as G, by as L, bz as O, bA as z, bB as F, B as M, _ as j, bC as J, bD as Q, bE as U, a1 as V, c as N, m as X, bF as Y, bG as Z, bH as R, bI as nn, bJ as tn, bK as sn, bL as en, bM as rn, bN as on, bO as an, bP as un, aG as cn, bQ as ln } from "./index-Duu1Lvvv.js";
2
+ import { k as C, c as g, m as D } from "./step-BS5JXRR6.js";
3
+ import { r as b } from "./reshape-BI0yzp1T.js";
4
+ import { m as pn, a as hn, e as w } from "./log_sum_exp-CVqLsVLl.js";
5
+ import { s as K } from "./sum-BPUfDB2X.js";
6
6
  function fn(s, n = null, t = !1) {
7
- const u = { x: c(s, "x", "all", "bool") }, o = { axis: n, keepDims: t };
8
- return d.runKernel(T, u, o);
7
+ const i = { x: c(s, "x", "all", "bool") }, o = { axis: n, keepDims: t };
8
+ return d.runKernel(T, i, o);
9
9
  }
10
10
  const nt = /* @__PURE__ */ h({ all_: fn });
11
11
  function dn(s, n = null, t = !1) {
12
- const u = { x: c(s, "x", "any", "bool") }, o = { axis: n, keepDims: t };
13
- return d.runKernel(q, u, o);
12
+ const i = { x: c(s, "x", "any", "bool") }, o = { axis: n, keepDims: t };
13
+ return d.runKernel(q, i, o);
14
14
  }
15
15
  const tt = /* @__PURE__ */ h({ any_: dn });
16
16
  function mn(s, n = 0) {
17
- const e = { x: c(s, "x", "argMax") }, u = { axis: n };
18
- return d.runKernel(H, e, u);
17
+ const e = { x: c(s, "x", "argMax") }, i = { axis: n };
18
+ return d.runKernel(H, e, i);
19
19
  }
20
20
  const st = /* @__PURE__ */ h({ argMax_: mn });
21
- function $n(s, n, t, e, u) {
21
+ function $n(s, n, t, e, i) {
22
22
  const o = c(s, "x", "avgPool", "float32"), p = 1;
23
23
  l(C(t, p), () => `Error in avgPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`);
24
24
  let r = o, a = !1;
25
- o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${r.rank}.`), g("avgPool", e, u);
26
- const i = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: u };
27
- let f = d.runKernel(P, i, m);
25
+ o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${r.rank}.`), g("avgPool", e, i);
26
+ const u = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: i };
27
+ let f = d.runKernel(P, u, m);
28
28
  return f = _(f, o.dtype), a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
29
29
  }
30
30
  const et = /* @__PURE__ */ h({ avgPool_: $n });
@@ -34,66 +34,66 @@ function bn(s) {
34
34
  }
35
35
  const rt = /* @__PURE__ */ h({ tanh_: bn });
36
36
  function xn(s, n, t) {
37
- const e = c(s, "x", "batchToSpaceND"), u = n.reduce((r, a) => r * a);
38
- l(e.rank >= 1 + n.length, () => `input rank is ${e.rank} but should be > than blockShape.length ${n.length}`), l(t.length === n.length, () => `crops.length is ${t.length} but should be equal to blockShape.length ${n.length}`), l(e.shape[0] % u === 0, () => `input tensor batch is ${e.shape[0]} but is not divisible by the product of the elements of blockShape ${n.join(" * ")} === ${u}`);
37
+ const e = c(s, "x", "batchToSpaceND"), i = n.reduce((r, a) => r * a);
38
+ l(e.rank >= 1 + n.length, () => `input rank is ${e.rank} but should be > than blockShape.length ${n.length}`), l(t.length === n.length, () => `crops.length is ${t.length} but should be equal to blockShape.length ${n.length}`), l(e.shape[0] % i === 0, () => `input tensor batch is ${e.shape[0]} but is not divisible by the product of the elements of blockShape ${n.join(" * ")} === ${i}`);
39
39
  const o = { x: e }, p = { blockShape: n, crops: t };
40
- return d.runKernel(I, o, p);
40
+ return d.runKernel(B, o, p);
41
41
  }
42
42
  const ot = /* @__PURE__ */ h({ batchToSpaceND_: xn });
43
43
  function kn(s) {
44
44
  let n;
45
45
  return s.rank === 0 || s.rank === 1 ? n = b(s, [1, 1, 1, s.size]) : s.rank === 2 ? n = b(s, [1, 1, s.shape[0], s.shape[1]]) : s.rank === 3 ? n = b(s, [1, s.shape[0], s.shape[1], s.shape[2]]) : n = s, n;
46
46
  }
47
- function vn(s, n, t, e, u, o) {
47
+ function vn(s, n, t, e, i, o) {
48
48
  o == null && (o = 1e-3);
49
49
  const p = c(s, "x", "batchNorm"), r = c(n, "mean", "batchNorm"), a = c(t, "variance", "batchNorm");
50
- let i;
51
- u != null && (i = c(u, "scale", "batchNorm"));
50
+ let u;
51
+ i != null && (u = c(i, "scale", "batchNorm"));
52
52
  let m;
53
- e != null && (m = c(e, "offset", "batchNorm")), l(r.rank === a.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), l(m == null || r.rank === m.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), l(i == null || r.rank === i.rank, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
53
+ e != null && (m = c(e, "offset", "batchNorm")), l(r.rank === a.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), l(m == null || r.rank === m.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), l(u == null || r.rank === u.rank, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
54
54
  const x = {
55
55
  x: kn(p),
56
- scale: i,
56
+ scale: u,
57
57
  offset: m,
58
58
  mean: r,
59
59
  variance: a
60
- }, k = { varianceEpsilon: o }, $ = d.runKernel(W, x, k);
60
+ }, k = { varianceEpsilon: o }, $ = d.runKernel(I, x, k);
61
61
  return b($, p.shape);
62
62
  }
63
63
  const at = /* @__PURE__ */ h({ batchNorm_: vn });
64
- function gn(s, n, t, e, u = "NHWC", o = [1, 1], p) {
64
+ function gn(s, n, t, e, i = "NHWC", o = [1, 1], p) {
65
65
  const r = c(s, "x", "conv2d", "float32"), a = c(n, "filter", "conv2d", "float32");
66
- let i = r, m = !1;
67
- r.rank === 3 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(i.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${i.rank}.`), l(a.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ${a.rank}.`), g("conv2d", e, p);
68
- const f = u === "NHWC" ? i.shape[3] : i.shape[1];
66
+ let u = r, m = !1;
67
+ r.rank === 3 && (m = !0, u = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(u.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${u.rank}.`), l(a.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ${a.rank}.`), g("conv2d", e, p);
68
+ const f = i === "NHWC" ? u.shape[3] : u.shape[1];
69
69
  l(f === a.shape[2], () => `Error in conv2d: depth of input (${f}) must match input depth for filter ${a.shape[2]}.`), l(C(t, o), () => `Error in conv2D: Either strides or dilations must be 1. Got strides ${t} and dilations '${o}'`), l(D(o), () => "Error in conv2D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv2D: Strides should be larger than 0.");
70
- const x = { x: i, filter: a }, k = { strides: t, pad: e, dataFormat: u, dilations: o, dimRoundingMode: p }, $ = d.runKernel(B, x, k);
70
+ const x = { x: u, filter: a }, k = { strides: t, pad: e, dataFormat: i, dilations: o, dimRoundingMode: p }, $ = d.runKernel(W, x, k);
71
71
  return m ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
72
72
  }
73
73
  const S = /* @__PURE__ */ h({ conv2d_: gn });
74
- function Dn(s, n, t, e, u = "NWC", o = 1, p) {
74
+ function Dn(s, n, t, e, i = "NWC", o = 1, p) {
75
75
  const r = c(s, "x", "conv1d"), a = c(n, "filter", "conv1d");
76
- let i = r, m = !1;
77
- r.rank === 2 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1]])), l(i.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${i.rank}.`), l(a.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${a.rank}.`), g("conv1d", e, p), l(i.shape[2] === a.shape[1], () => `Error in conv1d: depth of input (${i.shape[2]}) must match input depth for filter ${a.shape[1]}.`), l(C(t, o), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${t} and dilation '${o}'`), l(D(o), () => "Error in conv1D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv1D: Stride should be larger than 0."), l(u === "NWC", () => `Error in conv1d: got dataFormat of ${u} but only NWC is currently supported.`);
78
- const f = b(a, [1, a.shape[0], a.shape[1], a.shape[2]]), x = b(i, [i.shape[0], 1, i.shape[1], i.shape[2]]), v = S(x, f, [1, t], e, "NHWC", [1, o], p);
76
+ let u = r, m = !1;
77
+ r.rank === 2 && (m = !0, u = b(r, [1, r.shape[0], r.shape[1]])), l(u.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${u.rank}.`), l(a.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${a.rank}.`), g("conv1d", e, p), l(u.shape[2] === a.shape[1], () => `Error in conv1d: depth of input (${u.shape[2]}) must match input depth for filter ${a.shape[1]}.`), l(C(t, o), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${t} and dilation '${o}'`), l(D(o), () => "Error in conv1D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv1D: Stride should be larger than 0."), l(i === "NWC", () => `Error in conv1d: got dataFormat of ${i} but only NWC is currently supported.`);
78
+ const f = b(a, [1, a.shape[0], a.shape[1], a.shape[2]]), x = b(u, [u.shape[0], 1, u.shape[1], u.shape[2]]), v = S(x, f, [1, t], e, "NHWC", [1, o], p);
79
79
  return m ? b(v, [v.shape[2], v.shape[3]]) : b(v, [v.shape[0], v.shape[2], v.shape[3]]);
80
80
  }
81
- const ut = /* @__PURE__ */ h({ conv1d_: Dn });
82
- function Cn(s, n, t, e, u, o = "NHWC", p) {
81
+ const it = /* @__PURE__ */ h({ conv1d_: Dn });
82
+ function Cn(s, n, t, e, i, o = "NHWC", p) {
83
83
  l(s.length === n.rank, () => `Length of inShape (${s.length}) and rank of dy (${n.rank}) must match`);
84
- let r = s, a = n, i = !1;
85
- n.rank === 3 && (i = !0, a = b(n, [1, n.shape[0], n.shape[1], n.shape[2]]), r = [1, s[0], s[1], s[2]]), l(r.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${r.length}.`), l(a.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${a.rank}`), l(t.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${t.rank}`);
84
+ let r = s, a = n, u = !1;
85
+ n.rank === 3 && (u = !0, a = b(n, [1, n.shape[0], n.shape[1], n.shape[2]]), r = [1, s[0], s[1], s[2]]), l(r.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${r.length}.`), l(a.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${a.rank}`), l(t.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${t.rank}`);
86
86
  const m = o === "NHWC" ? r[3] : r[1], f = o === "NHWC" ? a.shape[3] : a.shape[1];
87
- l(m === t.shape[2], () => `Error in conv2dDerInput: depth of input (${m}) must match input depth for filter ${t.shape[2]}.`), l(f === t.shape[3], () => `Error in conv2dDerInput: depth of output (${f}) must match output depth for filter ${t.shape[3]}.`), g("conv2dDerInput", u, p);
88
- const x = { dy: a, filter: t }, k = { strides: e, pad: u, dataFormat: o, dimRoundingMode: p, inputShape: r }, $ = d.runKernel(A, x, k);
89
- return i ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
87
+ l(m === t.shape[2], () => `Error in conv2dDerInput: depth of input (${m}) must match input depth for filter ${t.shape[2]}.`), l(f === t.shape[3], () => `Error in conv2dDerInput: depth of output (${f}) must match output depth for filter ${t.shape[3]}.`), g("conv2dDerInput", i, p);
88
+ const x = { dy: a, filter: t }, k = { strides: e, pad: i, dataFormat: o, dimRoundingMode: p, inputShape: r }, $ = d.runKernel(A, x, k);
89
+ return u ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
90
90
  }
91
91
  const En = /* @__PURE__ */ h({ conv2DBackpropInput_: Cn });
92
- function Nn(s, n, t, e, u, o) {
92
+ function Nn(s, n, t, e, i, o) {
93
93
  const p = c(s, "x", "conv2dTranspose"), r = c(n, "filter", "conv2dTranspose");
94
- return En(t, p, r, e, u, "NHWC", o);
94
+ return En(t, p, r, e, i, "NHWC", o);
95
95
  }
96
- const it = /* @__PURE__ */ h({ conv2dTranspose_: Nn });
96
+ const ut = /* @__PURE__ */ h({ conv2dTranspose_: Nn });
97
97
  function _n(s) {
98
98
  const t = { x: c(s, "x", "cos", "float32") };
99
99
  return d.runKernel(G, t);
@@ -114,21 +114,21 @@ function Sn(s, n = 0, t = !1, e = !1) {
114
114
  return d.runKernel(z, o, p);
115
115
  }
116
116
  const ht = /* @__PURE__ */ h({ cumsum_: Sn });
117
- function Tn(s, n, t, e, u = "NHWC", o = [1, 1], p) {
117
+ function Tn(s, n, t, e, i = "NHWC", o = [1, 1], p) {
118
118
  const r = c(s, "x", "depthwiseConv2d", "float32"), a = c(n, "filter", "depthwiseConv2d", "float32");
119
- let i = r, m = !1;
120
- r.rank === 3 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(i.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${i.rank}.`), l(a.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${a.rank}.`);
121
- const f = u === "NHWC" ? i.shape[3] : i.shape[1];
119
+ let u = r, m = !1;
120
+ r.rank === 3 && (m = !0, u = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(u.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${u.rank}.`), l(a.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${a.rank}.`);
121
+ const f = i === "NHWC" ? u.shape[3] : u.shape[1];
122
122
  l(f === a.shape[2], () => `Error in depthwiseConv2d: number of input channels (${f}) must match the inChannels dimension in filter ${a.shape[2]}.`), g("depthwiseConv2d", e, p);
123
- const x = { x: i, filter: a }, k = { strides: t, pad: e, dataFormat: u, dilations: o, dimRoundingMode: p }, $ = d.runKernel(F, x, k);
123
+ const x = { x: u, filter: a }, k = { strides: t, pad: e, dataFormat: i, dilations: o, dimRoundingMode: p }, $ = d.runKernel(F, x, k);
124
124
  return m ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
125
125
  }
126
126
  const qn = /* @__PURE__ */ h({ depthwiseConv2d_: Tn });
127
127
  function Hn(s, n) {
128
128
  let t = c(s, "a", "equal", "string_or_numeric"), e = c(n, "b", "equal", "string_or_numeric");
129
129
  [t, e] = M(t, e), j(t.shape, e.shape);
130
- const u = { a: t, b: e };
131
- return d.runKernel(J, u);
130
+ const i = { a: t, b: e };
131
+ return d.runKernel(J, i);
132
132
  }
133
133
  const ft = /* @__PURE__ */ h({ equal_: Hn });
134
134
  function Pn(s) {
@@ -143,36 +143,36 @@ function yn(s) {
143
143
  return d.runKernel(U, t);
144
144
  }
145
145
  const mt = /* @__PURE__ */ h({ softplus_: yn });
146
- function In(s, n = -1) {
146
+ function Bn(s, n = -1) {
147
147
  const t = c(s, "logits", "logSoftmax");
148
148
  if (n === -1 && (n = t.rank - 1), n !== t.rank - 1)
149
149
  throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${t.rank} and axis was ${n}`);
150
- return V((u, o) => {
151
- const r = pn(u, n, !0), a = N(u, r), i = N(_(a, "float32"), hn(K(w(a), n, !0)));
152
- return o([i]), { value: i, gradFunc: (f, x) => {
150
+ return V((i, o) => {
151
+ const r = pn(i, n, !0), a = N(i, r), u = N(_(a, "float32"), hn(K(w(a), n, !0)));
152
+ return o([u]), { value: u, gradFunc: (f, x) => {
153
153
  const [k] = x, $ = !0, E = w(k);
154
154
  return N(f, X(K(f, n, $), E));
155
155
  } };
156
156
  })(t);
157
157
  }
158
- const $t = /* @__PURE__ */ h({ logSoftmax_: In });
159
- function Wn(s) {
158
+ const $t = /* @__PURE__ */ h({ logSoftmax_: Bn });
159
+ function In(s) {
160
160
  const t = { x: c(s, "x", "logicalNot", "bool") };
161
161
  return d.runKernel(Y, t);
162
162
  }
163
- const bt = /* @__PURE__ */ h({ logicalNot_: Wn });
164
- function Bn(s, n, t, e, u) {
163
+ const bt = /* @__PURE__ */ h({ logicalNot_: In });
164
+ function Wn(s, n, t, e, i) {
165
165
  const o = c(s, "x", "maxPool"), p = 1;
166
166
  let r = o, a = !1;
167
- o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${r.rank}.`), l(C(t, p), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`), g("maxPool", e, u);
168
- const i = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: u }, f = d.runKernel(Z, i, m);
167
+ o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${r.rank}.`), l(C(t, p), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`), g("maxPool", e, i);
168
+ const u = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: i }, f = d.runKernel(Z, u, m);
169
169
  return a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
170
170
  }
171
- const xt = /* @__PURE__ */ h({ maxPool_: Bn });
172
- function An(s, n, t = 1, e = 0, u = "int32") {
171
+ const xt = /* @__PURE__ */ h({ maxPool_: Wn });
172
+ function An(s, n, t = 1, e = 0, i = "int32") {
173
173
  if (n < 2)
174
174
  throw new Error(`Error in oneHot: depth must be >=2, but it is ${n}`);
175
- const p = { indices: c(s, "indices", "oneHot", "int32") }, r = { dtype: u, depth: n, onValue: t, offValue: e };
175
+ const p = { indices: c(s, "indices", "oneHot", "int32") }, r = { dtype: i, depth: n, onValue: t, offValue: e };
176
176
  return d.runKernel(R, p, r);
177
177
  }
178
178
  const kt = /* @__PURE__ */ h({ oneHot_: An });
@@ -185,20 +185,20 @@ function Ln(s, n, t = 0) {
185
185
  const e = c(s, "x", "pad");
186
186
  if (e.rank === 0)
187
187
  throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");
188
- const u = { paddings: n, constantValue: t }, o = { x: e };
189
- return d.runKernel(tn, o, u);
188
+ const i = { paddings: n, constantValue: t }, o = { x: e };
189
+ return d.runKernel(tn, o, i);
190
190
  }
191
191
  const gt = /* @__PURE__ */ h({ pad_: Ln });
192
192
  function On(s, n, t) {
193
193
  const e = c(s, "x", "spaceToBatchND");
194
194
  l(e.rank >= 1 + n.length, () => `input rank ${e.rank} should be > than [blockShape] ${n.length}`), l(t.length === n.length, () => `paddings.shape[0] ${t.length} must be equal to [blockShape] ${n.length}`), l(e.shape.reduce((p, r, a) => a > 0 && a <= n.length ? p && (r + t[a - 1][0] + t[a - 1][1]) % n[a - 1] === 0 : p, !0), () => `input spatial dimensions ${e.shape.slice(1)} with paddings ${t.toString()} must be divisible by blockShapes ${n.toString()}`);
195
- const u = { x: e }, o = { blockShape: n, paddings: t };
196
- return d.runKernel(sn, u, o);
195
+ const i = { x: e }, o = { blockShape: n, paddings: t };
196
+ return d.runKernel(sn, i, o);
197
197
  }
198
198
  const Dt = /* @__PURE__ */ h({ spaceToBatchND_: On });
199
199
  function zn(s, n) {
200
- const e = { x: c(s, "x", "reverse") }, u = { dims: n };
201
- return d.runKernel(en, e, u);
200
+ const e = { x: c(s, "x", "reverse") }, i = { dims: n };
201
+ return d.runKernel(en, e, i);
202
202
  }
203
203
  const Ct = /* @__PURE__ */ h({ reverse_: zn });
204
204
  function Fn(s) {
@@ -211,15 +211,15 @@ function Mn(s) {
211
211
  return d.runKernel(on, t);
212
212
  }
213
213
  const Nt = /* @__PURE__ */ h({ selu_: Mn });
214
- function jn(s, n, t, e, u, o = [1, 1], p = "NHWC") {
215
- const r = c(s, "x", "separableConv2d"), a = c(n, "depthwiseFilter", "separableConv2d"), i = c(t, "pointwiseFilter", "separableConv2d");
214
+ function jn(s, n, t, e, i, o = [1, 1], p = "NHWC") {
215
+ const r = c(s, "x", "separableConv2d"), a = c(n, "depthwiseFilter", "separableConv2d"), u = c(t, "pointwiseFilter", "separableConv2d");
216
216
  let m = r, f = !1;
217
217
  if (r.rank === 3 && (f = !0, m = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), p === "NCHW")
218
218
  throw new Error("separableConv2d currently does not support dataFormat NCHW; only NHWC is supported");
219
- l(m.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${m.rank}.`), l(a.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${a.rank}.`), l(i.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but got rank ${a.rank}.`), l(i.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter must be 1, but got ${i.shape[0]}.`), l(i.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise filter must be 1, but got ${i.shape[1]}.`);
219
+ l(m.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${m.rank}.`), l(a.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${a.rank}.`), l(u.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but got rank ${a.rank}.`), l(u.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter must be 1, but got ${u.shape[0]}.`), l(u.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise filter must be 1, but got ${u.shape[1]}.`);
220
220
  const x = a.shape[2], k = a.shape[3];
221
- l(i.shape[2] === x * k, () => `Error in separableConv2d: the third dimension of pointwise filter must be ${x * k}, but got ${i.shape[2]}.`);
222
- const $ = qn(m, a, e, u, p, o), v = S($, i, 1, "valid", p);
221
+ l(u.shape[2] === x * k, () => `Error in separableConv2d: the third dimension of pointwise filter must be ${x * k}, but got ${u.shape[2]}.`);
222
+ const $ = qn(m, a, e, i, p, o), v = S($, u, 1, "valid", p);
223
223
  return f ? b(v, [v.shape[1], v.shape[2], v.shape[3]]) : v;
224
224
  }
225
225
  const _t = /* @__PURE__ */ h({ separableConv2d_: jn });
@@ -234,9 +234,9 @@ function Qn(s) {
234
234
  }
235
235
  const Kt = /* @__PURE__ */ h({ sinh_: Qn });
236
236
  function Un(s, n, t) {
237
- const e = c(s, "x", "unsortedSegmentSum"), u = c(n, "segmentIds", "unsortedSegmentSum", "int32");
237
+ const e = c(s, "x", "unsortedSegmentSum"), i = c(n, "segmentIds", "unsortedSegmentSum", "int32");
238
238
  l(cn(t), () => "numSegments must be of dtype int");
239
- const o = { x: e, segmentIds: u }, p = { numSegments: t };
239
+ const o = { x: e, segmentIds: i }, p = { numSegments: t };
240
240
  return d.runKernel(ln, o, p);
241
241
  }
242
242
  const St = /* @__PURE__ */ h({ unsortedSegmentSum_: Un });
@@ -258,10 +258,10 @@ export {
258
258
  tt as h,
259
259
  st as i,
260
260
  at as j,
261
- ut as k,
261
+ it as k,
262
262
  bt as l,
263
263
  xt as m,
264
- it as n,
264
+ ut as n,
265
265
  S as o,
266
266
  lt as p,
267
267
  pt as q,
@@ -1,6 +1,6 @@
1
- import { a as y, e as S, v as w } from "../index-D0RBWjq8.js";
2
- import { z as m } from "../zeros-DeiE2zTa.js";
3
- import { o as P } from "../ones-BX_wEgzB.js";
1
+ import { a as y, e as S, v as w } from "../index-Duu1Lvvv.js";
2
+ import { z as m } from "../zeros-5YROwwUH.js";
3
+ import { o as P } from "../ones-CBI1AQjb.js";
4
4
  async function b(s) {
5
5
  const t = m([1, s.config.blockSize], "int32"), [n, o] = s.forward({ training: !1 }, t);
6
6
  await n.data(), n.dispose(), o && o.dispose(), t.dispose();
@@ -1,5 +1,5 @@
1
- import "../index-D0RBWjq8.js";
2
- import { t as e } from "../tensor2d-BN1sSfQO.js";
1
+ import "../index-Duu1Lvvv.js";
2
+ import { t as e } from "../tensor2d-BN97fF71.js";
3
3
  function l(n) {
4
4
  let r = 0;
5
5
  const i = Math.random();
@@ -1,4 +1,4 @@
1
- import { e as n } from "../index-D0RBWjq8.js";
1
+ import { e as n } from "../index-Duu1Lvvv.js";
2
2
  function o() {
3
3
  return n().backendName === "webgpu";
4
4
  }
@@ -1,4 +1,4 @@
1
- import { t as s } from "../index-D0RBWjq8.js";
1
+ import { t as s } from "../index-Duu1Lvvv.js";
2
2
  async function f(e, o = 10, r = !1) {
3
3
  for (let t = 0; t < 100; t++) {
4
4
  const a = r ? await e() : s(e);
@@ -1,4 +1,4 @@
1
- import { a } from "../index-D0RBWjq8.js";
1
+ import { a } from "../index-Duu1Lvvv.js";
2
2
  const s = 1024 * 1024;
3
3
  class l {
4
4
  log = /* @__PURE__ */ new Map();
@@ -1,5 +1,5 @@
1
- import "../index-D0RBWjq8.js";
2
- import { t as y } from "../tensor-BQqrDvpx.js";
1
+ import "../index-Duu1Lvvv.js";
2
+ import { t as y } from "../tensor-CEt9Nm2s.js";
3
3
  function l(t) {
4
4
  if (t === "float32") return "F32";
5
5
  if (t === "int32") return "I32";
@@ -1,5 +1,5 @@
1
1
  import { default as TeachableLLM } from '../TeachableLLM';
2
2
  import { Tensor2D, Tensor3D } from '@tensorflow/tfjs-core';
3
3
  export declare function meanPooling(embeddings: Tensor3D, attentionMask?: Tensor2D): Tensor2D;
4
- export declare function sentenceEmbeddingsTensor(model: TeachableLLM, sentences: string[], batchSize?: number): Promise<Tensor2D>;
4
+ export declare function sentenceEmbeddingsTensor(model: TeachableLLM, sentences: string[], batchSize?: number): Tensor2D;
5
5
  export declare function sentenceEmbeddings(model: TeachableLLM, sentences: string[], batchSize?: number): Promise<number[][]>;
@@ -1,26 +1,26 @@
1
- import { m as w } from "../index-D0RBWjq8.js";
2
- import { t as g } from "../tensor2d-BN1sSfQO.js";
3
- import { e as y } from "../expand_dims-BzfJK2uc.js";
4
- import { s as h } from "../sum-BdplSvq_.js";
5
- import { c as T } from "../concat-C6X3AAlQ.js";
1
+ import { m as w } from "../index-Duu1Lvvv.js";
2
+ import { t as f } from "../tensor2d-BN97fF71.js";
3
+ import { e as y } from "../expand_dims-ChkuOp6I.js";
4
+ import { s as g } from "../sum-BPUfDB2X.js";
5
+ import { c as T } from "../concat-CSm2rMwe.js";
6
6
  const p = 16;
7
7
  function A(o, t) {
8
8
  if (!t)
9
9
  return o.mean(1);
10
- const r = y(t, 2), i = w(o, r), e = h(i, 1), s = h(t, 1, !0), c = e.div(s.maximum(1e-9));
10
+ const r = y(t, 2), i = w(o, r), e = g(i, 1), s = g(t, 1, !0), c = e.div(s.maximum(1e-9));
11
11
  return r.dispose(), i.dispose(), e.dispose(), s.dispose(), c;
12
12
  }
13
- async function E(o, t, r = p) {
13
+ function E(o, t, r = p) {
14
14
  const i = o.tokeniser, e = o.config.blockSize;
15
15
  let s = null, c = 0;
16
16
  for (; c < t.length; ) {
17
- const m = t.slice(c, c + p), k = await i.tokenise(m, !0), l = [], d = [];
18
- for (const n of k)
17
+ const b = t.slice(c, c + p).map((n) => i.encode(n)), l = [], d = [];
18
+ for (const n of b)
19
19
  n.length > e ? (l.push(n.slice(n.length - e, n.length)), d.push(new Array(e).fill(1))) : n.length < e ? (l.push(n.concat(new Array(e - n.length).fill(0))), d.push(
20
20
  new Array(n.length).fill(1).concat(new Array(e - n.length).fill(0))
21
21
  )) : (l.push(n), d.push(new Array(e).fill(1)));
22
- const b = g(l, [l.length, e], "int32"), u = g(d, [d.length, e], "float32"), f = o.model.forward({ skipLogits: !0, training: !1 }, b)[0], a = A(f, u);
23
- if (u.dispose(), f.dispose(), s === null)
22
+ const k = f(l, [l.length, e], "int32"), m = f(d, [d.length, e], "float32"), u = o.model.forward({ skipLogits: !0, training: !1 }, k)[0], a = A(u, m);
23
+ if (m.dispose(), u.dispose(), s === null)
24
24
  s = a;
25
25
  else {
26
26
  const n = s;
@@ -1,5 +1,5 @@
1
- import "../index-D0RBWjq8.js";
2
- import { t as p } from "../tensor-BQqrDvpx.js";
1
+ import "../index-Duu1Lvvv.js";
2
+ import { t as p } from "../tensor-CEt9Nm2s.js";
3
3
  function h(n) {
4
4
  const e = n.reduce((s, o) => s + o.length, 0), a = new Float32Array(e);
5
5
  let t = 0;
@@ -1,4 +1,4 @@
1
- import { E as i } from "./index-D0RBWjq8.js";
1
+ import { E as i } from "./index-Duu1Lvvv.js";
2
2
  function m(r, a = !0, e, t) {
3
3
  return i.makeVariable(r, a, e, t);
4
4
  }
@@ -1,4 +1,4 @@
1
- import { ad as z, ac as F, ab as E, a9 as j, y as A } from "./index-D0RBWjq8.js";
1
+ import { ac as z, ab as F, aa as E, a8 as j, x as A } from "./index-Duu1Lvvv.js";
2
2
  function L(t, s) {
3
3
  if (Math.max(...t) > 5)
4
4
  throw new Error("Cannot symbolically compute strides for rank > 6 tensor.");
@@ -1,4 +1,4 @@
1
- import { y as u } from "./index-D0RBWjq8.js";
1
+ import { x as u } from "./index-Duu1Lvvv.js";
2
2
  const c = (r) => {
3
3
  let t = 1;
4
4
  for (let n = 0; n < r.length; n++)
@@ -23,16 +23,16 @@ function p(r, t, n = !1) {
23
23
  const a = c(r.x.map((i) => t[i])), o = c(r.y.map((i) => t[i]));
24
24
  return a <= 4 ? [4, 16, 1] : o <= 4 ? [16, 4, 1] : [16, 16, 1];
25
25
  }
26
- function M(r, t, n = !1) {
26
+ function x(r, t, n = !1) {
27
27
  if (n)
28
28
  return [4, 4, 1];
29
29
  const a = c(r.x.map((i) => t[i])), o = c(r.y.map((i) => t[i]));
30
30
  return a <= 4 ? [1, 2, 1] : o <= 4 ? [2, 1, 1] : [2, 2, 1];
31
31
  }
32
- function h(r) {
32
+ function M(r) {
33
33
  return { x: r.map((t, n) => n) };
34
34
  }
35
- function x(r) {
35
+ function h(r) {
36
36
  if (r === "float32" || r === "int32" || r === "bool" || r === "string" || r === "packedF16")
37
37
  return 4;
38
38
  if (r === "complex64")
@@ -52,13 +52,13 @@ var s;
52
52
  r[r.MatMulReduceProgram = 0] = "MatMulReduceProgram", r[r.MatMulSplitKProgram = 1] = "MatMulSplitKProgram", r[r.MatMulSmallOutputSizeProgram = 2] = "MatMulSmallOutputSizeProgram", r[r.MatMulPackedProgram = 3] = "MatMulPackedProgram", r[r.MatMulMax = 4] = "MatMulMax";
53
53
  })(s || (s = {}));
54
54
  export {
55
- x as G,
55
+ h as G,
56
56
  s as M,
57
57
  d as a,
58
58
  b,
59
59
  m as c,
60
60
  p as d,
61
- M as e,
62
- h as f,
61
+ x as e,
62
+ M as f,
63
63
  g as i
64
64
  };
@@ -1,5 +1,5 @@
1
- import { w as n, U as m, V as i, E as c } from "./index-D0RBWjq8.js";
2
- import { c as f } from "./complex-DClmWqJt.js";
1
+ import { u as n, Q as m, U as i, E as c } from "./index-Duu1Lvvv.js";
2
+ import { c as f } from "./complex-Cyg-eQeZ.js";
3
3
  function e(o, r = "float32") {
4
4
  if (n(o), r === "complex64") {
5
5
  const s = e(o, "float32"), t = e(o, "float32");