@genai-fi/nanogpt 0.10.2 → 0.11.0

This diff compares the published contents of two package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (262)
  1. package/dist/Generator.d.ts +10 -5
  2. package/dist/Generator.js +11760 -146
  3. package/dist/{RealDiv-zz7FpkKX.js → RealDiv-Ds-jvL09.js} +28 -30
  4. package/dist/Reshape-Cd6e-Otn.js +14 -0
  5. package/dist/{Reshape-CHdUjC72.js → Reshape-Ct266DEk.js} +21 -23
  6. package/dist/TeachableLLM.d.ts +4 -3
  7. package/dist/TeachableLLM.js +15 -16
  8. package/dist/Trainer.d.ts +2 -2
  9. package/dist/Trainer.js +6 -6
  10. package/dist/{axis_util-BsIr9ZNu.js → axis_util-DofAuy0p.js} +1 -1
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-B1XRLuq9.js → backend_util-C7NWHpv7.js} +72 -73
  13. package/dist/{backend_webgpu-CqpfEImu.js → backend_webgpu-B0Vls736.js} +52 -54
  14. package/dist/broadcast_to-DDaNMbX7.js +28 -0
  15. package/dist/checks/appendCache.js +2 -2
  16. package/dist/checks/attentionMask.js +3 -3
  17. package/dist/checks/gelu.js +2 -2
  18. package/dist/checks/matMulGelu.js +7 -11
  19. package/dist/checks/normRMS.js +9 -9
  20. package/dist/checks/normRMSGrad.js +3 -3
  21. package/dist/checks/packUnpack.js +2 -2
  22. package/dist/checks/qkv.js +11 -12
  23. package/dist/checks/rope.js +2 -2
  24. package/dist/clip_by_value-Dn5tzexi.js +12 -0
  25. package/dist/complex-DClmWqJt.js +11 -0
  26. package/dist/concat-C6X3AAlQ.js +17 -0
  27. package/dist/{concat_util-iBYIyuQe.js → concat_util-CHsJFZJJ.js} +1 -1
  28. package/dist/{dataset-D2P7rHAw.js → dataset-DcjWqUVQ.js} +135 -137
  29. package/dist/dropout-OxuaJz6z.js +92 -0
  30. package/dist/expand_dims-BzfJK2uc.js +11 -0
  31. package/dist/{exports_initializers-CZSUJoVE.js → exports_initializers-eS9QJ6ut.js} +1 -1
  32. package/dist/floor-DIb-lN_u.js +9 -0
  33. package/dist/gather-BcO5UQNJ.js +9 -0
  34. package/dist/{gelu-Bmhopi0J.js → gelu-DqTbCx5x.js} +10 -11
  35. package/dist/{gpgpu_math-DsCcikas.js → gpgpu_math-CJcbnKPC.js} +841 -1015
  36. package/dist/index-D0RBWjq8.js +3520 -0
  37. package/dist/{index-DRyE072i.js → index-Dj5TkmPY.js} +330 -331
  38. package/dist/{kernel_funcs_utils-CWfOAPGO.js → kernel_funcs_utils-CSaumNDs.js} +132 -134
  39. package/dist/layers/BaseLayer.js +15 -16
  40. package/dist/layers/CausalSelfAttention.js +6 -6
  41. package/dist/layers/MLP.js +4 -4
  42. package/dist/layers/PositionEmbedding.js +7 -7
  43. package/dist/layers/RMSNorm.js +3 -3
  44. package/dist/layers/RoPECache.js +9 -9
  45. package/dist/layers/TiedEmbedding.js +6 -6
  46. package/dist/layers/TransformerBlock.js +1 -1
  47. package/dist/loader/loadTransformers.js +1 -1
  48. package/dist/loader/oldZipLoad.js +21 -22
  49. package/dist/log_sum_exp-VLZgbFAH.js +39 -0
  50. package/dist/main.d.ts +1 -1
  51. package/dist/main.js +49 -50
  52. package/dist/{matMul16-fEAJ4smh.js → matMul16-cDxwemKj.js} +14 -15
  53. package/dist/matMulGelu-B2s_80-H.js +163 -0
  54. package/dist/mat_mul-DxpNTCRz.js +11 -0
  55. package/dist/mod-PrOKlFxH.js +11 -0
  56. package/dist/models/NanoGPTV1.js +2 -2
  57. package/dist/models/model.js +13 -14
  58. package/dist/ones-BX_wEgzB.js +14 -0
  59. package/dist/ops/adamAdjust.js +1 -1
  60. package/dist/ops/adamMoments.js +1 -1
  61. package/dist/ops/add16.js +1 -1
  62. package/dist/ops/appendCache.js +3 -3
  63. package/dist/ops/attentionMask.js +1 -1
  64. package/dist/ops/concat16.js +2 -2
  65. package/dist/ops/cpu/adamAdjust.js +12 -13
  66. package/dist/ops/cpu/adamMoments.js +6 -7
  67. package/dist/ops/cpu/appendCache.js +7 -8
  68. package/dist/ops/cpu/attentionMask.js +11 -11
  69. package/dist/ops/cpu/fusedSoftmax.js +10 -11
  70. package/dist/ops/cpu/gatherSub.js +10 -11
  71. package/dist/ops/cpu/gelu.js +14 -15
  72. package/dist/ops/cpu/matMul16.js +6 -7
  73. package/dist/ops/cpu/matMulGelu.js +5 -6
  74. package/dist/ops/cpu/matMulMul.js +3 -4
  75. package/dist/ops/cpu/mulDropout.js +3 -4
  76. package/dist/ops/cpu/normRMS.js +11 -12
  77. package/dist/ops/cpu/qkv.js +8 -9
  78. package/dist/ops/cpu/rope.js +9 -10
  79. package/dist/ops/cpu/scatterSub.js +14 -16
  80. package/dist/ops/dot16.js +2 -2
  81. package/dist/ops/gatherSub.js +1 -1
  82. package/dist/ops/gelu.js +2 -2
  83. package/dist/ops/grads/add16.js +10 -11
  84. package/dist/ops/grads/attentionMask.js +5 -6
  85. package/dist/ops/grads/gelu.js +3 -4
  86. package/dist/ops/grads/matMul16.js +4 -5
  87. package/dist/ops/grads/matMulGelu.js +8 -9
  88. package/dist/ops/grads/normRMS.js +9 -10
  89. package/dist/ops/grads/pack16.js +4 -5
  90. package/dist/ops/grads/qkv.js +17 -19
  91. package/dist/ops/grads/rope.js +3 -5
  92. package/dist/ops/grads/softmax16.js +3 -4
  93. package/dist/ops/grads/unpack16.js +3 -4
  94. package/dist/ops/grads/utils.d.ts +1 -0
  95. package/dist/ops/grads/utils.js +8 -4
  96. package/dist/ops/matMul16.js +3 -3
  97. package/dist/ops/matMulGelu.js +2 -2
  98. package/dist/ops/matMulMul.js +1 -1
  99. package/dist/ops/mul16.js +1 -1
  100. package/dist/ops/mulDrop.js +1 -1
  101. package/dist/ops/normRMS.js +1 -1
  102. package/dist/ops/pack16.js +3 -4
  103. package/dist/ops/qkv.js +4 -8
  104. package/dist/ops/reshape16.js +16 -18
  105. package/dist/ops/rope.d.ts +1 -1
  106. package/dist/ops/rope.js +3 -8
  107. package/dist/ops/scatterSub.js +1 -1
  108. package/dist/ops/slice16.js +2 -2
  109. package/dist/ops/softmax16.js +5 -8
  110. package/dist/ops/sub16.js +1 -1
  111. package/dist/ops/sum16.js +2 -2
  112. package/dist/ops/transpose16.js +23 -24
  113. package/dist/ops/unpack16.js +2 -2
  114. package/dist/ops/webgl/adamAdjust.js +2 -3
  115. package/dist/ops/webgl/adamMoments.js +1 -2
  116. package/dist/ops/webgl/appendCache.js +1 -2
  117. package/dist/ops/webgl/attentionMask.js +5 -6
  118. package/dist/ops/webgl/fusedSoftmax.js +6 -8
  119. package/dist/ops/webgl/gatherSub.js +6 -7
  120. package/dist/ops/webgl/gelu.js +2 -3
  121. package/dist/ops/webgl/log.js +11 -12
  122. package/dist/ops/webgl/matMul16.js +15 -16
  123. package/dist/ops/webgl/matMulGelu.js +7 -111
  124. package/dist/ops/webgl/matMulMul.js +14 -15
  125. package/dist/ops/webgl/mulDropout.js +8 -9
  126. package/dist/ops/webgl/normRMS.js +7 -8
  127. package/dist/ops/webgl/qkv.js +5 -6
  128. package/dist/ops/webgl/rope.js +7 -8
  129. package/dist/ops/webgl/scatterSub.js +5 -6
  130. package/dist/ops/webgpu/adamAdjust.js +10 -12
  131. package/dist/ops/webgpu/adamMoments.js +8 -10
  132. package/dist/ops/webgpu/add16.js +8 -9
  133. package/dist/ops/webgpu/appendCache.js +23 -25
  134. package/dist/ops/webgpu/attentionMask.js +10 -12
  135. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  136. package/dist/ops/webgpu/concat16.js +12 -14
  137. package/dist/ops/webgpu/gatherSub.js +9 -11
  138. package/dist/ops/webgpu/gelu.js +28 -29
  139. package/dist/ops/webgpu/matMul16.js +26 -28
  140. package/dist/ops/webgpu/matMul16_program.js +4 -5
  141. package/dist/ops/webgpu/mul16.js +7 -8
  142. package/dist/ops/webgpu/normRMS.js +17 -19
  143. package/dist/ops/webgpu/normRMSGrad.js +21 -28
  144. package/dist/ops/webgpu/pack16.js +12 -13
  145. package/dist/ops/webgpu/pack16_program.js +2 -2
  146. package/dist/ops/webgpu/qkv.js +13 -15
  147. package/dist/ops/webgpu/rope.js +25 -27
  148. package/dist/ops/webgpu/scatterSub.js +7 -9
  149. package/dist/ops/webgpu/slice16.js +21 -23
  150. package/dist/ops/webgpu/softmax16.js +17 -19
  151. package/dist/ops/webgpu/softmax16_program.js +2 -2
  152. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  153. package/dist/ops/webgpu/softmax16grad.js +7 -8
  154. package/dist/ops/webgpu/sub16.js +8 -9
  155. package/dist/ops/webgpu/sum16.js +19 -21
  156. package/dist/ops/webgpu/transpose16.js +19 -20
  157. package/dist/ops/webgpu/transpose16_program.js +2 -2
  158. package/dist/ops/webgpu/transpose16_shared_program.js +11 -12
  159. package/dist/ops/webgpu/unpack16.js +3 -4
  160. package/dist/ops/webgpu/utils/binary_op.js +7 -8
  161. package/dist/ops/webgpu/utils/reductions.js +14 -22
  162. package/dist/ops-FJapAPfm.js +476 -0
  163. package/dist/pack16-k4jq6aMX.js +39 -0
  164. package/dist/patches/webgpu_backend.js +19 -20
  165. package/dist/patches/webgpu_base.js +1 -1
  166. package/dist/patches/webgpu_program.js +15 -16
  167. package/dist/{random_width-BVV9HveY.js → random_width-UGQn4OWb.js} +2506 -2761
  168. package/dist/range-CuGvVN2c.js +10 -0
  169. package/dist/relu-Cf80uA2p.js +9 -0
  170. package/dist/reshape-CkjKPPqB.js +9 -0
  171. package/dist/resize_nearest_neighbor-DB8k9KN_.js +175 -0
  172. package/dist/rope-BmZmp9uP.js +24 -0
  173. package/dist/{scatter_nd_util-C7zXRT_h.js → scatter_nd_util-BY22Cc-C.js} +1 -1
  174. package/dist/selu_util-BuLbmbrl.js +44 -0
  175. package/dist/{shared-CHhxz-O5.js → shared-B7USJZgw.js} +1 -1
  176. package/dist/{shared-D2NP_CpY.js → shared-BQboIImQ.js} +379 -381
  177. package/dist/slice-Aqy7KbJh.js +12 -0
  178. package/dist/{slice_util-DyjSAD0u.js → slice_util-D8CQRenR.js} +7 -7
  179. package/dist/{softmax-C9JQEtnO.js → softmax-faLoUZVT.js} +4 -5
  180. package/dist/split-BNz5jcGc.js +9 -0
  181. package/dist/squeeze--YMgaAAf.js +10 -0
  182. package/dist/stack-WJK22CFn.js +11 -0
  183. package/dist/step-dXR33iOg.js +261 -0
  184. package/dist/sum-BdplSvq_.js +11 -0
  185. package/dist/{tensor-0r5yOo2R.js → tensor-BQqrDvpx.js} +1 -1
  186. package/dist/tensor1d-LxP9asMm.js +11 -0
  187. package/dist/{tensor2d-CSB4KOb0.js → tensor2d-BN1sSfQO.js} +6 -7
  188. package/dist/{tensor4d-D7bLqGqz.js → tensor4d-DVwr7pLF.js} +6 -7
  189. package/dist/{tfjs_backend-CNkSTL0c.js → tfjs_backend-Vi4JfLzT.js} +256 -265
  190. package/dist/tile-CvN_LyVr.js +11 -0
  191. package/dist/tokeniser/BaseTokeniser.d.ts +27 -0
  192. package/dist/tokeniser/BaseTokeniser.js +94 -0
  193. package/dist/tokeniser/CharTokeniser.d.ts +4 -3
  194. package/dist/tokeniser/CharTokeniser.js +46 -32
  195. package/dist/tokeniser/bpe.d.ts +4 -3
  196. package/dist/tokeniser/bpe.js +60 -45
  197. package/dist/tokeniser/type.d.ts +11 -0
  198. package/dist/training/Adam.js +2 -2
  199. package/dist/training/AdamExt.js +1 -1
  200. package/dist/training/DatasetBuilder.d.ts +2 -2
  201. package/dist/training/DatasetBuilder.js +32 -36
  202. package/dist/training/FullTrainer.js +1 -1
  203. package/dist/training/Trainer.d.ts +3 -3
  204. package/dist/training/Trainer.js +2 -2
  205. package/dist/training/sparseCrossEntropy.js +5 -5
  206. package/dist/transpose-JawVKyZy.js +36 -0
  207. package/dist/unsorted_segment_sum-LAbmE9G4.js +277 -0
  208. package/dist/utilities/dummy.js +3 -3
  209. package/dist/utilities/multinomialCPU.js +2 -2
  210. package/dist/utilities/packed.d.ts +1 -4
  211. package/dist/utilities/packed.js +10 -745
  212. package/dist/utilities/performance.js +1 -1
  213. package/dist/utilities/profile.js +1 -1
  214. package/dist/utilities/safetensors.js +2 -2
  215. package/dist/utilities/sentences.js +5 -5
  216. package/dist/utilities/weights.js +2 -2
  217. package/dist/{variable-DzfrwYuP.js → variable-DQ9yYgEU.js} +1 -1
  218. package/dist/{webgpu_program-DzaQiqel.js → webgpu_program-CAE4RICo.js} +177 -171
  219. package/dist/{webgpu_util-0_ubCEHJ.js → webgpu_util-BdovYhXr.js} +34 -35
  220. package/dist/zeros-DeiE2zTa.js +13 -0
  221. package/dist/zeros_like-BAz3iKru.js +721 -0
  222. package/package.json +4 -2
  223. package/dist/Reshape-CDVLyVfz.js +0 -16
  224. package/dist/broadcast_to-B0ChcDaz.js +0 -30
  225. package/dist/complex-BBiRlsVq.js +0 -13
  226. package/dist/concat-DmBLPVGC.js +0 -19
  227. package/dist/dropout-B1x1kYMa.js +0 -99
  228. package/dist/expand_dims-ouvfxQ1n.js +0 -13
  229. package/dist/gather-CH9sdacz.js +0 -10
  230. package/dist/index-D6Q1lPZO.js +0 -2157
  231. package/dist/log_sum_exp-D3ftBNY5.js +0 -41
  232. package/dist/mat_mul-C59XWcJd.js +0 -12
  233. package/dist/mod-DESSvHIU.js +0 -12
  234. package/dist/mulmat_packed_gpu-Coh6qbJk.js +0 -55
  235. package/dist/ones-jU9jlQvM.js +0 -15
  236. package/dist/ops-BFDtP6th.js +0 -645
  237. package/dist/pack16-CmVZs6af.js +0 -41
  238. package/dist/patches/PackedTensor.d.ts +0 -12
  239. package/dist/patches/PackedTensor.js +0 -11
  240. package/dist/patches/engine.d.ts +0 -261
  241. package/dist/patches/engine.js +0 -12
  242. package/dist/patches/tape.d.ts +0 -12
  243. package/dist/patches/tape.js +0 -5
  244. package/dist/range-ZZZD60Fx.js +0 -11
  245. package/dist/reciprocal-CrYlsAGD.js +0 -10
  246. package/dist/register_all_kernels-nvj2k7OC.js +0 -12307
  247. package/dist/relu-BYDneVPn.js +0 -10
  248. package/dist/reshape-CaPQzFvz.js +0 -10
  249. package/dist/rope-s4W2XO9B.js +0 -32
  250. package/dist/selu_util-BGPXmd4B.js +0 -303
  251. package/dist/sin-Djs4aQiu.js +0 -16
  252. package/dist/slice-DvovR5wq.js +0 -13
  253. package/dist/split-DBck65sX.js +0 -10
  254. package/dist/squeeze-C00Ipm_7.js +0 -11
  255. package/dist/stack-ChnHwRpX.js +0 -13
  256. package/dist/sum-ywRJj3Zr.js +0 -12
  257. package/dist/tensor-CzmOBsdf.js +0 -909
  258. package/dist/tensor1d-BlUT89BP.js +0 -12
  259. package/dist/tensor_util-DfwaWayG.js +0 -523
  260. package/dist/tile-CR074jmp.js +0 -13
  261. package/dist/transpose-DH4gmHvu.js +0 -38
  262. package/dist/zeros-DBFVbpv5.js +0 -14
package/dist/training/Trainer.d.ts
@@ -1,4 +1,4 @@
- import { ITokeniser } from '../tokeniser/type';
+ import { Conversation, ITokeniser } from '../tokeniser/type';
  import { DatasetBuilder } from './DatasetBuilder';
  import { default as AdamExt } from './AdamExt';
  import { NamedTensorMap, TensorContainer } from '@tensorflow/tfjs-core/dist/tensor_types';
@@ -93,7 +93,7 @@ export default abstract class GPTTrainer {
  log: TrainingLogEntry;
  progress: TrainingProgress;
  }>;
- createTrainValidationSplit(textData: string[], batchSize?: number, validationSplit?: number): Promise<{
+ createTrainValidationSplit(textData: Conversation[][], batchSize?: number, validationSplit?: number): Promise<{
  trainDataset: Dataset<{
  xs: Tensor;
  ys: Tensor;
@@ -103,6 +103,6 @@ export default abstract class GPTTrainer {
  ys: Tensor;
  }>;
  }>;
- createDataset(textData: string[], batchSize?: number): Promise<Dataset<TensorContainer>>;
+ createDataset(textData: Conversation[][], batchSize?: number): Promise<Dataset<TensorContainer>>;
  dispose(): void;
  }
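The breaking change in this release is visible in this hunk: createTrainValidationSplit and createDataset now take Conversation[][] (one array of chat turns per training document) instead of string[]. A minimal migration sketch follows; the role/content fields are assumptions for illustration, since the real Conversation shape is declared in dist/tokeniser/type.d.ts.

// Sketch only: the actual Conversation type is exported from
// package/dist/tokeniser/type.d.ts; role/content are assumed fields.
interface Conversation {
    role: 'system' | 'user' | 'assistant';
    content: string;
}

// 0.10.x: await trainer.createDataset(['some plain text'], 32);
// 0.11.x: one Conversation[] per training document:
const textData: Conversation[][] = [
    [
        { role: 'user', content: 'Hello' },
        { role: 'assistant', content: 'Hi! How can I help?' },
    ],
];
// await trainer.createDataset(textData, 32);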
package/dist/training/Trainer.js
@@ -1,7 +1,7 @@
  import { DatasetBuilder as f, flattenTokens as h, PAGE_FACTOR as y } from "./DatasetBuilder.js";
  import z from "./AdamExt.js";
- import { t as S, v as k, k as x, d as p, b as m } from "../index-D6Q1lPZO.js";
- import { z as g } from "../zeros-DBFVbpv5.js";
+ import { t as S, v as k, k as x, d as p, b as m } from "../index-D0RBWjq8.js";
+ import { z as g } from "../zeros-DeiE2zTa.js";
  class M {
  constructor(t, e, s = 1e-3) {
  this.tokenizer = e, this.model = t, this.lossScaling = t.lossScaling, this.learningRate = s, this.resetOptimizer(), this.datasetBuilder = new f(e, t.config.blockSize);
package/dist/training/sparseCrossEntropy.js
@@ -1,15 +1,15 @@
  import { gatherSub as x } from "../ops/gatherSub.js";
  import { scatterSub as L } from "../ops/scatterSub.js";
- import { w as C, t as u, z as E, c as G } from "../index-D6Q1lPZO.js";
- import { s as y } from "../softmax-C9JQEtnO.js";
- import { m as z, l as v } from "../log_sum_exp-D3ftBNY5.js";
+ import { a2 as C, t as u, a3 as E, c as G } from "../index-D0RBWjq8.js";
+ import { s as y } from "../softmax-faLoUZVT.js";
+ import { m as z, l as v } from "../log_sum_exp-VLZgbFAH.js";
  function k(t, s) {
  return u(() => {
  const n = t.shape[t.shape.length - 1], c = t.shape.slice(0, -1).reduce((o, e) => o * e, 1), h = t.shape.length > 2 ? t.reshape([c, n]) : t, p = s.shape.length > 1 ? s.reshape([c]).cast("int32") : s.cast("int32"), r = z(h, -1, !0), a = G(h, r), d = v(a, -1);
  return x(d, p, a);
  });
  }
- function q() {
+ function w() {
  return C(
  // @ts-expect-error Invalid params
  (s, n, m) => {
@@ -22,6 +22,6 @@ function q() {
  );
  }
  export {
- q as createSoftmaxCrossEntropyWithGrad,
+ w as createSoftmaxCrossEntropyWithGrad,
  k as sparseSoftmaxCrossEntropy
  };
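The loss in this file is unchanged; only the chunk imports and the minified name q → w moved. What sparseSoftmaxCrossEntropy computes per row is logSumExp(logits) minus the logit at the target index. The same arithmetic in plain TypeScript, independent of tfjs, as a reference sketch:

// Reference sketch: loss[i] = logSumExp(logits[i]) - logits[i][labels[i]].
function sparseSoftmaxCrossEntropyRef(logits: number[][], labels: number[]): number[] {
    return logits.map((row, i) => {
        const max = Math.max(...row); // subtract the max for numerical stability
        const lse = max + Math.log(row.reduce((s, v) => s + Math.exp(v - max), 0));
        return lse - row[labels[i]];
    });
}

// sparseSoftmaxCrossEntropyRef([[2, 1, 0]], [0]) ≈ [0.408]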
package/dist/transpose-JawVKyZy.js
@@ -0,0 +1,36 @@
+ import { q as u, u as i, E as o, ap as $, aq as g, ar as m, y as l, t as x, as as p } from "./index-D0RBWjq8.js";
+ import { c as k } from "./complex-DClmWqJt.js";
+ function K(r) {
+ const e = { input: i(r, "input", "imag") };
+ return o.runKernel($, e);
+ }
+ const h = /* @__PURE__ */ u({ imag_: K });
+ function E(r) {
+ const e = { x: i(r, "x", "neg") };
+ return o.runKernel(g, e);
+ }
+ const _ = /* @__PURE__ */ u({ neg_: E });
+ function b(r) {
+ const e = { input: i(r, "input", "real") };
+ return o.runKernel(m, e);
+ }
+ const d = /* @__PURE__ */ u({ real_: b });
+ function y(r, t, e) {
+ const n = i(r, "x", "transpose");
+ if (t == null && (t = n.shape.map((s, a) => a).reverse()), l(n.rank === t.length, () => `Error in transpose: rank of input ${n.rank} must match length of perm ${t}.`), t.forEach((s) => {
+ l(s >= 0 && s < n.rank, () => `All entries in 'perm' must be between 0 and ${n.rank - 1} but got ${t}`);
+ }), n.rank <= 1)
+ return n.clone();
+ const f = { x: n }, c = { perm: t };
+ return n.dtype === "complex64" ? x(() => {
+ let s = d(n), a = h(n);
+ return s = o.runKernel(p, { x: s }, c), a = o.runKernel(p, { x: a }, c), e && (a = _(a)), k(s, a);
+ }) : o.runKernel(p, f, c);
+ }
+ const q = /* @__PURE__ */ u({ transpose_: y });
+ export {
+ h as i,
+ _ as n,
+ d as r,
+ q as t
+ };
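This new chunk appears to be the stock @tensorflow/tfjs-core transpose op (plus the imag/real/neg helpers it needs for the complex64 path), split out by the bundler rather than written for nanogpt. Equivalent public-API usage:

import * as tf from '@tensorflow/tfjs-core';

const x = tf.zeros([2, 3, 4]);
const y = tf.transpose(x, [0, 2, 1]); // swap the last two axes
console.log(y.shape); // [2, 4, 3]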
package/dist/unsorted_segment_sum-LAbmE9G4.js
@@ -0,0 +1,277 @@
+ import { q as h, u as c, E as d, bo as T, bp as q, bq as H, y as l, br as P, N as _, bs as y, bt as I, bu as W, bv as B, bw as A, bx as G, by as L, bz as O, bA as z, bB as F, D as M, $ as j, bC as J, bD as Q, bE as U, a2 as V, c as N, m as X, bF as Y, bG as Z, bH as R, bI as nn, bJ as tn, bK as sn, bL as en, bM as rn, bN as on, bO as an, bP as un, aG as cn, bQ as ln } from "./index-D0RBWjq8.js";
+ import { k as C, c as g, m as D } from "./step-dXR33iOg.js";
+ import { r as b } from "./reshape-CkjKPPqB.js";
+ import { m as pn, a as hn, e as w } from "./log_sum_exp-VLZgbFAH.js";
+ import { s as K } from "./sum-BdplSvq_.js";
+ function fn(s, n = null, t = !1) {
+ const u = { x: c(s, "x", "all", "bool") }, o = { axis: n, keepDims: t };
+ return d.runKernel(T, u, o);
+ }
+ const nt = /* @__PURE__ */ h({ all_: fn });
+ function dn(s, n = null, t = !1) {
+ const u = { x: c(s, "x", "any", "bool") }, o = { axis: n, keepDims: t };
+ return d.runKernel(q, u, o);
+ }
+ const tt = /* @__PURE__ */ h({ any_: dn });
+ function mn(s, n = 0) {
+ const e = { x: c(s, "x", "argMax") }, u = { axis: n };
+ return d.runKernel(H, e, u);
+ }
+ const st = /* @__PURE__ */ h({ argMax_: mn });
+ function $n(s, n, t, e, u) {
+ const o = c(s, "x", "avgPool", "float32"), p = 1;
+ l(C(t, p), () => `Error in avgPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`);
+ let r = o, a = !1;
+ o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in avgPool: x must be rank 4 but got rank ${r.rank}.`), g("avgPool", e, u);
+ const i = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: u };
+ let f = d.runKernel(P, i, m);
+ return f = _(f, o.dtype), a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
+ }
+ const et = /* @__PURE__ */ h({ avgPool_: $n });
+ function bn(s) {
+ const t = { x: c(s, "x", "tanh", "float32") };
+ return d.runKernel(y, t);
+ }
+ const rt = /* @__PURE__ */ h({ tanh_: bn });
+ function xn(s, n, t) {
+ const e = c(s, "x", "batchToSpaceND"), u = n.reduce((r, a) => r * a);
+ l(e.rank >= 1 + n.length, () => `input rank is ${e.rank} but should be > than blockShape.length ${n.length}`), l(t.length === n.length, () => `crops.length is ${t.length} but should be equal to blockShape.length ${n.length}`), l(e.shape[0] % u === 0, () => `input tensor batch is ${e.shape[0]} but is not divisible by the product of the elements of blockShape ${n.join(" * ")} === ${u}`);
+ const o = { x: e }, p = { blockShape: n, crops: t };
+ return d.runKernel(I, o, p);
+ }
+ const ot = /* @__PURE__ */ h({ batchToSpaceND_: xn });
+ function kn(s) {
+ let n;
+ return s.rank === 0 || s.rank === 1 ? n = b(s, [1, 1, 1, s.size]) : s.rank === 2 ? n = b(s, [1, 1, s.shape[0], s.shape[1]]) : s.rank === 3 ? n = b(s, [1, s.shape[0], s.shape[1], s.shape[2]]) : n = s, n;
+ }
+ function vn(s, n, t, e, u, o) {
+ o == null && (o = 1e-3);
+ const p = c(s, "x", "batchNorm"), r = c(n, "mean", "batchNorm"), a = c(t, "variance", "batchNorm");
+ let i;
+ u != null && (i = c(u, "scale", "batchNorm"));
+ let m;
+ e != null && (m = c(e, "offset", "batchNorm")), l(r.rank === a.rank, () => "Batch normalization gradient requires mean and variance to have equal ranks."), l(m == null || r.rank === m.rank, () => "Batch normalization gradient requires mean and offset to have equal ranks."), l(i == null || r.rank === i.rank, () => "Batch normalization gradient requires mean and scale to have equal ranks.");
+ const x = {
+ x: kn(p),
+ scale: i,
+ offset: m,
+ mean: r,
+ variance: a
+ }, k = { varianceEpsilon: o }, $ = d.runKernel(W, x, k);
+ return b($, p.shape);
+ }
+ const at = /* @__PURE__ */ h({ batchNorm_: vn });
+ function gn(s, n, t, e, u = "NHWC", o = [1, 1], p) {
+ const r = c(s, "x", "conv2d", "float32"), a = c(n, "filter", "conv2d", "float32");
+ let i = r, m = !1;
+ r.rank === 3 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(i.rank === 4, () => `Error in conv2d: input must be rank 4, but got rank ${i.rank}.`), l(a.rank === 4, () => `Error in conv2d: filter must be rank 4, but got rank ${a.rank}.`), g("conv2d", e, p);
+ const f = u === "NHWC" ? i.shape[3] : i.shape[1];
+ l(f === a.shape[2], () => `Error in conv2d: depth of input (${f}) must match input depth for filter ${a.shape[2]}.`), l(C(t, o), () => `Error in conv2D: Either strides or dilations must be 1. Got strides ${t} and dilations '${o}'`), l(D(o), () => "Error in conv2D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv2D: Strides should be larger than 0.");
+ const x = { x: i, filter: a }, k = { strides: t, pad: e, dataFormat: u, dilations: o, dimRoundingMode: p }, $ = d.runKernel(B, x, k);
+ return m ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
+ }
+ const S = /* @__PURE__ */ h({ conv2d_: gn });
+ function Dn(s, n, t, e, u = "NWC", o = 1, p) {
+ const r = c(s, "x", "conv1d"), a = c(n, "filter", "conv1d");
+ let i = r, m = !1;
+ r.rank === 2 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1]])), l(i.rank === 3, () => `Error in conv1d: input must be rank 3, but got rank ${i.rank}.`), l(a.rank === 3, () => `Error in conv1d: filter must be rank 3, but got rank ${a.rank}.`), g("conv1d", e, p), l(i.shape[2] === a.shape[1], () => `Error in conv1d: depth of input (${i.shape[2]}) must match input depth for filter ${a.shape[1]}.`), l(C(t, o), () => `Error in conv1D: Either stride or dilation must be 1. Got stride ${t} and dilation '${o}'`), l(D(o), () => "Error in conv1D: Dilated rates should be larger than 0."), l(D(t), () => "Error in conv1D: Stride should be larger than 0."), l(u === "NWC", () => `Error in conv1d: got dataFormat of ${u} but only NWC is currently supported.`);
+ const f = b(a, [1, a.shape[0], a.shape[1], a.shape[2]]), x = b(i, [i.shape[0], 1, i.shape[1], i.shape[2]]), v = S(x, f, [1, t], e, "NHWC", [1, o], p);
+ return m ? b(v, [v.shape[2], v.shape[3]]) : b(v, [v.shape[0], v.shape[2], v.shape[3]]);
+ }
+ const ut = /* @__PURE__ */ h({ conv1d_: Dn });
+ function Cn(s, n, t, e, u, o = "NHWC", p) {
+ l(s.length === n.rank, () => `Length of inShape (${s.length}) and rank of dy (${n.rank}) must match`);
+ let r = s, a = n, i = !1;
+ n.rank === 3 && (i = !0, a = b(n, [1, n.shape[0], n.shape[1], n.shape[2]]), r = [1, s[0], s[1], s[2]]), l(r.length === 4, () => `Error in conv2dDerInput: inShape must be length 4, but got length ${r.length}.`), l(a.rank === 4, () => `Error in conv2dDerInput: dy must be rank 4, but got rank ${a.rank}`), l(t.rank === 4, () => `Error in conv2dDerInput: filter must be rank 4, but got rank ${t.rank}`);
+ const m = o === "NHWC" ? r[3] : r[1], f = o === "NHWC" ? a.shape[3] : a.shape[1];
+ l(m === t.shape[2], () => `Error in conv2dDerInput: depth of input (${m}) must match input depth for filter ${t.shape[2]}.`), l(f === t.shape[3], () => `Error in conv2dDerInput: depth of output (${f}) must match output depth for filter ${t.shape[3]}.`), g("conv2dDerInput", u, p);
+ const x = { dy: a, filter: t }, k = { strides: e, pad: u, dataFormat: o, dimRoundingMode: p, inputShape: r }, $ = d.runKernel(A, x, k);
+ return i ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
+ }
+ const En = /* @__PURE__ */ h({ conv2DBackpropInput_: Cn });
+ function Nn(s, n, t, e, u, o) {
+ const p = c(s, "x", "conv2dTranspose"), r = c(n, "filter", "conv2dTranspose");
+ return En(t, p, r, e, u, "NHWC", o);
+ }
+ const it = /* @__PURE__ */ h({ conv2dTranspose_: Nn });
+ function _n(s) {
+ const t = { x: c(s, "x", "cos", "float32") };
+ return d.runKernel(G, t);
+ }
+ const ct = /* @__PURE__ */ h({ cos_: _n });
+ function wn(s) {
+ const t = { x: c(s, "x", "cosh", "float32") };
+ return d.runKernel(L, t);
+ }
+ const lt = /* @__PURE__ */ h({ cosh_: wn });
+ function Kn(s, n = 0, t = !1, e = !1) {
+ const o = { x: c(s, "x", "cumprod") }, p = { axis: n, exclusive: t, reverse: e };
+ return d.runKernel(O, o, p);
+ }
+ const pt = /* @__PURE__ */ h({ cumprod_: Kn });
+ function Sn(s, n = 0, t = !1, e = !1) {
+ const o = { x: c(s, "x", "cumsum") }, p = { axis: n, exclusive: t, reverse: e };
+ return d.runKernel(z, o, p);
+ }
+ const ht = /* @__PURE__ */ h({ cumsum_: Sn });
+ function Tn(s, n, t, e, u = "NHWC", o = [1, 1], p) {
+ const r = c(s, "x", "depthwiseConv2d", "float32"), a = c(n, "filter", "depthwiseConv2d", "float32");
+ let i = r, m = !1;
+ r.rank === 3 && (m = !0, i = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), l(i.rank === 4, () => `Error in depthwiseConv2d: input must be rank 4, but got rank ${i.rank}.`), l(a.rank === 4, () => `Error in depthwiseConv2d: filter must be rank 4, but got rank ${a.rank}.`);
+ const f = u === "NHWC" ? i.shape[3] : i.shape[1];
+ l(f === a.shape[2], () => `Error in depthwiseConv2d: number of input channels (${f}) must match the inChannels dimension in filter ${a.shape[2]}.`), g("depthwiseConv2d", e, p);
+ const x = { x: i, filter: a }, k = { strides: t, pad: e, dataFormat: u, dilations: o, dimRoundingMode: p }, $ = d.runKernel(F, x, k);
+ return m ? b($, [$.shape[1], $.shape[2], $.shape[3]]) : $;
+ }
+ const qn = /* @__PURE__ */ h({ depthwiseConv2d_: Tn });
+ function Hn(s, n) {
+ let t = c(s, "a", "equal", "string_or_numeric"), e = c(n, "b", "equal", "string_or_numeric");
+ [t, e] = M(t, e), j(t.shape, e.shape);
+ const u = { a: t, b: e };
+ return d.runKernel(J, u);
+ }
+ const ft = /* @__PURE__ */ h({ equal_: Hn });
+ function Pn(s) {
+ let n = c(s, "x", "erf");
+ l(n.dtype === "int32" || n.dtype === "float32", () => "Input dtype must be `int32` or `float32`."), n.dtype === "int32" && (n = _(n, "float32"));
+ const t = { x: n };
+ return d.runKernel(Q, t);
+ }
+ const dt = /* @__PURE__ */ h({ erf_: Pn });
+ function yn(s) {
+ const t = { x: c(s, "x", "softplus") };
+ return d.runKernel(U, t);
+ }
+ const mt = /* @__PURE__ */ h({ softplus_: yn });
+ function In(s, n = -1) {
+ const t = c(s, "logits", "logSoftmax");
+ if (n === -1 && (n = t.rank - 1), n !== t.rank - 1)
+ throw Error(`Log Softmax along a non-last dimension is not yet supported. Logits was rank ${t.rank} and axis was ${n}`);
+ return V((u, o) => {
+ const r = pn(u, n, !0), a = N(u, r), i = N(_(a, "float32"), hn(K(w(a), n, !0)));
+ return o([i]), { value: i, gradFunc: (f, x) => {
+ const [k] = x, $ = !0, E = w(k);
+ return N(f, X(K(f, n, $), E));
+ } };
+ })(t);
+ }
+ const $t = /* @__PURE__ */ h({ logSoftmax_: In });
+ function Wn(s) {
+ const t = { x: c(s, "x", "logicalNot", "bool") };
+ return d.runKernel(Y, t);
+ }
+ const bt = /* @__PURE__ */ h({ logicalNot_: Wn });
+ function Bn(s, n, t, e, u) {
+ const o = c(s, "x", "maxPool"), p = 1;
+ let r = o, a = !1;
+ o.rank === 3 && (a = !0, r = b(o, [1, o.shape[0], o.shape[1], o.shape[2]])), l(r.rank === 4, () => `Error in maxPool: input must be rank 4 but got rank ${r.rank}.`), l(C(t, p), () => `Error in maxPool: Either strides or dilations must be 1. Got strides ${t} and dilations '${p}'`), g("maxPool", e, u);
+ const i = { x: r }, m = { filterSize: n, strides: t, pad: e, dimRoundingMode: u }, f = d.runKernel(Z, i, m);
+ return a ? b(f, [f.shape[1], f.shape[2], f.shape[3]]) : f;
+ }
+ const xt = /* @__PURE__ */ h({ maxPool_: Bn });
+ function An(s, n, t = 1, e = 0, u = "int32") {
+ if (n < 2)
+ throw new Error(`Error in oneHot: depth must be >=2, but it is ${n}`);
+ const p = { indices: c(s, "indices", "oneHot", "int32") }, r = { dtype: u, depth: n, onValue: t, offValue: e };
+ return d.runKernel(R, p, r);
+ }
+ const kt = /* @__PURE__ */ h({ oneHot_: An });
+ function Gn(s) {
+ const t = { x: c(s, "x", "onesLike") };
+ return d.runKernel(nn, t);
+ }
+ const vt = /* @__PURE__ */ h({ onesLike_: Gn });
+ function Ln(s, n, t = 0) {
+ const e = c(s, "x", "pad");
+ if (e.rank === 0)
+ throw new Error("pad(scalar) is not defined. Pass non-scalar to pad");
+ const u = { paddings: n, constantValue: t }, o = { x: e };
+ return d.runKernel(tn, o, u);
+ }
+ const gt = /* @__PURE__ */ h({ pad_: Ln });
+ function On(s, n, t) {
+ const e = c(s, "x", "spaceToBatchND");
+ l(e.rank >= 1 + n.length, () => `input rank ${e.rank} should be > than [blockShape] ${n.length}`), l(t.length === n.length, () => `paddings.shape[0] ${t.length} must be equal to [blockShape] ${n.length}`), l(e.shape.reduce((p, r, a) => a > 0 && a <= n.length ? p && (r + t[a - 1][0] + t[a - 1][1]) % n[a - 1] === 0 : p, !0), () => `input spatial dimensions ${e.shape.slice(1)} with paddings ${t.toString()} must be divisible by blockShapes ${n.toString()}`);
+ const u = { x: e }, o = { blockShape: n, paddings: t };
+ return d.runKernel(sn, u, o);
+ }
+ const Dt = /* @__PURE__ */ h({ spaceToBatchND_: On });
+ function zn(s, n) {
+ const e = { x: c(s, "x", "reverse") }, u = { dims: n };
+ return d.runKernel(en, e, u);
+ }
+ const Ct = /* @__PURE__ */ h({ reverse_: zn });
+ function Fn(s) {
+ const t = { x: c(s, "x", "rsqrt", "float32") };
+ return d.runKernel(rn, t);
+ }
+ const Et = /* @__PURE__ */ h({ rsqrt_: Fn });
+ function Mn(s) {
+ const t = { x: c(s, "x", "selu") };
+ return d.runKernel(on, t);
+ }
+ const Nt = /* @__PURE__ */ h({ selu_: Mn });
+ function jn(s, n, t, e, u, o = [1, 1], p = "NHWC") {
+ const r = c(s, "x", "separableConv2d"), a = c(n, "depthwiseFilter", "separableConv2d"), i = c(t, "pointwiseFilter", "separableConv2d");
+ let m = r, f = !1;
+ if (r.rank === 3 && (f = !0, m = b(r, [1, r.shape[0], r.shape[1], r.shape[2]])), p === "NCHW")
+ throw new Error("separableConv2d currently does not support dataFormat NCHW; only NHWC is supported");
+ l(m.rank === 4, () => `Error in separableConv2d: input must be rank 4, but got rank ${m.rank}.`), l(a.rank === 4, () => `Error in separableConv2d: depthwise filter must be rank 4, but got rank ${a.rank}.`), l(i.rank === 4, () => `Error in separableConv2d: pointwise filter must be rank 4, but got rank ${a.rank}.`), l(i.shape[0] === 1, () => `Error in separableConv2d: the first dimension of pointwise filter must be 1, but got ${i.shape[0]}.`), l(i.shape[1] === 1, () => `Error in separableConv2d: the second dimension of pointwise filter must be 1, but got ${i.shape[1]}.`);
+ const x = a.shape[2], k = a.shape[3];
+ l(i.shape[2] === x * k, () => `Error in separableConv2d: the third dimension of pointwise filter must be ${x * k}, but got ${i.shape[2]}.`);
+ const $ = qn(m, a, e, u, p, o), v = S($, i, 1, "valid", p);
+ return f ? b(v, [v.shape[1], v.shape[2], v.shape[3]]) : v;
+ }
+ const _t = /* @__PURE__ */ h({ separableConv2d_: jn });
+ function Jn(s) {
+ const t = { x: c(s, "x", "sin", "float32") };
+ return d.runKernel(an, t);
+ }
+ const wt = /* @__PURE__ */ h({ sin_: Jn });
+ function Qn(s) {
+ const t = { x: c(s, "x", "sinh") };
+ return d.runKernel(un, t);
+ }
+ const Kt = /* @__PURE__ */ h({ sinh_: Qn });
+ function Un(s, n, t) {
+ const e = c(s, "x", "unsortedSegmentSum"), u = c(n, "segmentIds", "unsortedSegmentSum", "int32");
+ l(cn(t), () => "numSegments must be of dtype int");
+ const o = { x: e, segmentIds: u }, p = { numSegments: t };
+ return d.runKernel(ln, o, p);
+ }
+ const St = /* @__PURE__ */ h({ unsortedSegmentSum_: Un });
+ export {
+ Et as A,
+ Nt as B,
+ _t as C,
+ Kt as D,
+ rt as E,
+ St as F,
+ En as G,
+ mt as a,
+ Dt as b,
+ ct as c,
+ et as d,
+ ft as e,
+ ot as f,
+ nt as g,
+ tt as h,
+ st as i,
+ at as j,
+ ut as k,
+ bt as l,
+ xt as m,
+ it as n,
+ S as o,
+ lt as p,
+ pt as q,
+ Ct as r,
+ wt as s,
+ ht as t,
+ qn as u,
+ dt as v,
+ $t as w,
+ kt as x,
+ vt as y,
+ gt as z
+ };
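Likewise, this 277-line chunk is a bundle of stock tfjs-core ops (all, any, argMax, avgPool, the conv family, and so on through unsortedSegmentSum, which names the chunk), re-chunked by the bundler rather than new code. For example, the final op sums rows of x into numSegments buckets chosen by segmentIds:

import * as tf from '@tensorflow/tfjs-core';

const x = tf.tensor1d([1, 2, 3, 4]);
const ids = tf.tensor1d([0, 1, 0, 1], 'int32');
tf.unsortedSegmentSum(x, ids, 2).print(); // [4, 6] (elements 0+2 and 1+3)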
package/dist/utilities/dummy.js
@@ -1,6 +1,6 @@
- import { a as y, e as S, v as w } from "../index-D6Q1lPZO.js";
- import { z as m } from "../zeros-DBFVbpv5.js";
- import { o as P } from "../ones-jU9jlQvM.js";
+ import { a as y, e as S, v as w } from "../index-D0RBWjq8.js";
+ import { z as m } from "../zeros-DeiE2zTa.js";
+ import { o as P } from "../ones-BX_wEgzB.js";
  async function b(s) {
  const t = m([1, s.config.blockSize], "int32"), [n, o] = s.forward({ training: !1 }, t);
  await n.data(), n.dispose(), o && o.dispose(), t.dispose();
package/dist/utilities/multinomialCPU.js
@@ -1,5 +1,5 @@
- import "../index-D6Q1lPZO.js";
- import { t as e } from "../tensor2d-CSB4KOb0.js";
+ import "../index-D0RBWjq8.js";
+ import { t as e } from "../tensor2d-BN1sSfQO.js";
  function l(n) {
  let r = 0;
  const i = Math.random();
package/dist/utilities/packed.d.ts
@@ -1,7 +1,4 @@
- import { PackableTensor } from '../patches/PackedTensor';
  import { Tensor } from '@tensorflow/tfjs-core';
  export declare function packingSupported(): boolean;
- export declare function isPackableTensor(tensor: Tensor): tensor is PackableTensor;
+ export declare function isPackableTensor(tensor: Tensor): boolean;
  export declare function isPackedTensor(tensor: Tensor): boolean;
- export declare function packTensor(tensor: Tensor): Tensor;
- export declare function unpackTensor(tensor: Tensor): Tensor;
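The packing surface also shrank: packTensor, unpackTensor, and the PackableTensor patch type are gone, and isPackableTensor now returns a plain boolean. A defensive usage sketch; the deep-import path is an assumption based on the dist layout above:

import { Tensor } from '@tensorflow/tfjs-core';
// Assumed deep import; the declarations sit in dist/utilities/packed.d.ts.
import { packingSupported, isPackedTensor } from '@genai-fi/nanogpt/dist/utilities/packed.js';

function describeTensor(t: Tensor): string {
    if (!packingSupported()) return 'packing unavailable on this backend';
    return isPackedTensor(t) ? 'packed (16-bit) tensor' : 'regular tensor';
}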