@genai-fi/nanogpt 0.10.2 → 0.11.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262)
  1. package/dist/Generator.d.ts +10 -5
  2. package/dist/Generator.js +11760 -146
  3. package/dist/{RealDiv-zz7FpkKX.js → RealDiv-Ds-jvL09.js} +28 -30
  4. package/dist/Reshape-Cd6e-Otn.js +14 -0
  5. package/dist/{Reshape-CHdUjC72.js → Reshape-Ct266DEk.js} +21 -23
  6. package/dist/TeachableLLM.d.ts +4 -3
  7. package/dist/TeachableLLM.js +15 -16
  8. package/dist/Trainer.d.ts +2 -2
  9. package/dist/Trainer.js +6 -6
  10. package/dist/{axis_util-BsIr9ZNu.js → axis_util-DofAuy0p.js} +1 -1
  11. package/dist/backend.js +2 -2
  12. package/dist/{backend_util-B1XRLuq9.js → backend_util-C7NWHpv7.js} +72 -73
  13. package/dist/{backend_webgpu-CqpfEImu.js → backend_webgpu-B0Vls736.js} +52 -54
  14. package/dist/broadcast_to-DDaNMbX7.js +28 -0
  15. package/dist/checks/appendCache.js +2 -2
  16. package/dist/checks/attentionMask.js +3 -3
  17. package/dist/checks/gelu.js +2 -2
  18. package/dist/checks/matMulGelu.js +7 -11
  19. package/dist/checks/normRMS.js +9 -9
  20. package/dist/checks/normRMSGrad.js +3 -3
  21. package/dist/checks/packUnpack.js +2 -2
  22. package/dist/checks/qkv.js +11 -12
  23. package/dist/checks/rope.js +2 -2
  24. package/dist/clip_by_value-Dn5tzexi.js +12 -0
  25. package/dist/complex-DClmWqJt.js +11 -0
  26. package/dist/concat-C6X3AAlQ.js +17 -0
  27. package/dist/{concat_util-iBYIyuQe.js → concat_util-CHsJFZJJ.js} +1 -1
  28. package/dist/{dataset-D2P7rHAw.js → dataset-DcjWqUVQ.js} +135 -137
  29. package/dist/dropout-OxuaJz6z.js +92 -0
  30. package/dist/expand_dims-BzfJK2uc.js +11 -0
  31. package/dist/{exports_initializers-CZSUJoVE.js → exports_initializers-eS9QJ6ut.js} +1 -1
  32. package/dist/floor-DIb-lN_u.js +9 -0
  33. package/dist/gather-BcO5UQNJ.js +9 -0
  34. package/dist/{gelu-Bmhopi0J.js → gelu-DqTbCx5x.js} +10 -11
  35. package/dist/{gpgpu_math-DsCcikas.js → gpgpu_math-CJcbnKPC.js} +841 -1015
  36. package/dist/index-D0RBWjq8.js +3520 -0
  37. package/dist/{index-DRyE072i.js → index-Dj5TkmPY.js} +330 -331
  38. package/dist/{kernel_funcs_utils-CWfOAPGO.js → kernel_funcs_utils-CSaumNDs.js} +132 -134
  39. package/dist/layers/BaseLayer.js +15 -16
  40. package/dist/layers/CausalSelfAttention.js +6 -6
  41. package/dist/layers/MLP.js +4 -4
  42. package/dist/layers/PositionEmbedding.js +7 -7
  43. package/dist/layers/RMSNorm.js +3 -3
  44. package/dist/layers/RoPECache.js +9 -9
  45. package/dist/layers/TiedEmbedding.js +6 -6
  46. package/dist/layers/TransformerBlock.js +1 -1
  47. package/dist/loader/loadTransformers.js +1 -1
  48. package/dist/loader/oldZipLoad.js +21 -22
  49. package/dist/log_sum_exp-VLZgbFAH.js +39 -0
  50. package/dist/main.d.ts +1 -1
  51. package/dist/main.js +49 -50
  52. package/dist/{matMul16-fEAJ4smh.js → matMul16-cDxwemKj.js} +14 -15
  53. package/dist/matMulGelu-B2s_80-H.js +163 -0
  54. package/dist/mat_mul-DxpNTCRz.js +11 -0
  55. package/dist/mod-PrOKlFxH.js +11 -0
  56. package/dist/models/NanoGPTV1.js +2 -2
  57. package/dist/models/model.js +13 -14
  58. package/dist/ones-BX_wEgzB.js +14 -0
  59. package/dist/ops/adamAdjust.js +1 -1
  60. package/dist/ops/adamMoments.js +1 -1
  61. package/dist/ops/add16.js +1 -1
  62. package/dist/ops/appendCache.js +3 -3
  63. package/dist/ops/attentionMask.js +1 -1
  64. package/dist/ops/concat16.js +2 -2
  65. package/dist/ops/cpu/adamAdjust.js +12 -13
  66. package/dist/ops/cpu/adamMoments.js +6 -7
  67. package/dist/ops/cpu/appendCache.js +7 -8
  68. package/dist/ops/cpu/attentionMask.js +11 -11
  69. package/dist/ops/cpu/fusedSoftmax.js +10 -11
  70. package/dist/ops/cpu/gatherSub.js +10 -11
  71. package/dist/ops/cpu/gelu.js +14 -15
  72. package/dist/ops/cpu/matMul16.js +6 -7
  73. package/dist/ops/cpu/matMulGelu.js +5 -6
  74. package/dist/ops/cpu/matMulMul.js +3 -4
  75. package/dist/ops/cpu/mulDropout.js +3 -4
  76. package/dist/ops/cpu/normRMS.js +11 -12
  77. package/dist/ops/cpu/qkv.js +8 -9
  78. package/dist/ops/cpu/rope.js +9 -10
  79. package/dist/ops/cpu/scatterSub.js +14 -16
  80. package/dist/ops/dot16.js +2 -2
  81. package/dist/ops/gatherSub.js +1 -1
  82. package/dist/ops/gelu.js +2 -2
  83. package/dist/ops/grads/add16.js +10 -11
  84. package/dist/ops/grads/attentionMask.js +5 -6
  85. package/dist/ops/grads/gelu.js +3 -4
  86. package/dist/ops/grads/matMul16.js +4 -5
  87. package/dist/ops/grads/matMulGelu.js +8 -9
  88. package/dist/ops/grads/normRMS.js +9 -10
  89. package/dist/ops/grads/pack16.js +4 -5
  90. package/dist/ops/grads/qkv.js +17 -19
  91. package/dist/ops/grads/rope.js +3 -5
  92. package/dist/ops/grads/softmax16.js +3 -4
  93. package/dist/ops/grads/unpack16.js +3 -4
  94. package/dist/ops/grads/utils.d.ts +1 -0
  95. package/dist/ops/grads/utils.js +8 -4
  96. package/dist/ops/matMul16.js +3 -3
  97. package/dist/ops/matMulGelu.js +2 -2
  98. package/dist/ops/matMulMul.js +1 -1
  99. package/dist/ops/mul16.js +1 -1
  100. package/dist/ops/mulDrop.js +1 -1
  101. package/dist/ops/normRMS.js +1 -1
  102. package/dist/ops/pack16.js +3 -4
  103. package/dist/ops/qkv.js +4 -8
  104. package/dist/ops/reshape16.js +16 -18
  105. package/dist/ops/rope.d.ts +1 -1
  106. package/dist/ops/rope.js +3 -8
  107. package/dist/ops/scatterSub.js +1 -1
  108. package/dist/ops/slice16.js +2 -2
  109. package/dist/ops/softmax16.js +5 -8
  110. package/dist/ops/sub16.js +1 -1
  111. package/dist/ops/sum16.js +2 -2
  112. package/dist/ops/transpose16.js +23 -24
  113. package/dist/ops/unpack16.js +2 -2
  114. package/dist/ops/webgl/adamAdjust.js +2 -3
  115. package/dist/ops/webgl/adamMoments.js +1 -2
  116. package/dist/ops/webgl/appendCache.js +1 -2
  117. package/dist/ops/webgl/attentionMask.js +5 -6
  118. package/dist/ops/webgl/fusedSoftmax.js +6 -8
  119. package/dist/ops/webgl/gatherSub.js +6 -7
  120. package/dist/ops/webgl/gelu.js +2 -3
  121. package/dist/ops/webgl/log.js +11 -12
  122. package/dist/ops/webgl/matMul16.js +15 -16
  123. package/dist/ops/webgl/matMulGelu.js +7 -111
  124. package/dist/ops/webgl/matMulMul.js +14 -15
  125. package/dist/ops/webgl/mulDropout.js +8 -9
  126. package/dist/ops/webgl/normRMS.js +7 -8
  127. package/dist/ops/webgl/qkv.js +5 -6
  128. package/dist/ops/webgl/rope.js +7 -8
  129. package/dist/ops/webgl/scatterSub.js +5 -6
  130. package/dist/ops/webgpu/adamAdjust.js +10 -12
  131. package/dist/ops/webgpu/adamMoments.js +8 -10
  132. package/dist/ops/webgpu/add16.js +8 -9
  133. package/dist/ops/webgpu/appendCache.js +23 -25
  134. package/dist/ops/webgpu/attentionMask.js +10 -12
  135. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
  136. package/dist/ops/webgpu/concat16.js +12 -14
  137. package/dist/ops/webgpu/gatherSub.js +9 -11
  138. package/dist/ops/webgpu/gelu.js +28 -29
  139. package/dist/ops/webgpu/matMul16.js +26 -28
  140. package/dist/ops/webgpu/matMul16_program.js +4 -5
  141. package/dist/ops/webgpu/mul16.js +7 -8
  142. package/dist/ops/webgpu/normRMS.js +17 -19
  143. package/dist/ops/webgpu/normRMSGrad.js +21 -28
  144. package/dist/ops/webgpu/pack16.js +12 -13
  145. package/dist/ops/webgpu/pack16_program.js +2 -2
  146. package/dist/ops/webgpu/qkv.js +13 -15
  147. package/dist/ops/webgpu/rope.js +25 -27
  148. package/dist/ops/webgpu/scatterSub.js +7 -9
  149. package/dist/ops/webgpu/slice16.js +21 -23
  150. package/dist/ops/webgpu/softmax16.js +17 -19
  151. package/dist/ops/webgpu/softmax16_program.js +2 -2
  152. package/dist/ops/webgpu/softmax16_subgroup_program.js +2 -2
  153. package/dist/ops/webgpu/softmax16grad.js +7 -8
  154. package/dist/ops/webgpu/sub16.js +8 -9
  155. package/dist/ops/webgpu/sum16.js +19 -21
  156. package/dist/ops/webgpu/transpose16.js +19 -20
  157. package/dist/ops/webgpu/transpose16_program.js +2 -2
  158. package/dist/ops/webgpu/transpose16_shared_program.js +11 -12
  159. package/dist/ops/webgpu/unpack16.js +3 -4
  160. package/dist/ops/webgpu/utils/binary_op.js +7 -8
  161. package/dist/ops/webgpu/utils/reductions.js +14 -22
  162. package/dist/ops-FJapAPfm.js +476 -0
  163. package/dist/pack16-k4jq6aMX.js +39 -0
  164. package/dist/patches/webgpu_backend.js +19 -20
  165. package/dist/patches/webgpu_base.js +1 -1
  166. package/dist/patches/webgpu_program.js +15 -16
  167. package/dist/{random_width-BVV9HveY.js → random_width-UGQn4OWb.js} +2506 -2761
  168. package/dist/range-CuGvVN2c.js +10 -0
  169. package/dist/relu-Cf80uA2p.js +9 -0
  170. package/dist/reshape-CkjKPPqB.js +9 -0
  171. package/dist/resize_nearest_neighbor-DB8k9KN_.js +175 -0
  172. package/dist/rope-BmZmp9uP.js +24 -0
  173. package/dist/{scatter_nd_util-C7zXRT_h.js → scatter_nd_util-BY22Cc-C.js} +1 -1
  174. package/dist/selu_util-BuLbmbrl.js +44 -0
  175. package/dist/{shared-CHhxz-O5.js → shared-B7USJZgw.js} +1 -1
  176. package/dist/{shared-D2NP_CpY.js → shared-BQboIImQ.js} +379 -381
  177. package/dist/slice-Aqy7KbJh.js +12 -0
  178. package/dist/{slice_util-DyjSAD0u.js → slice_util-D8CQRenR.js} +7 -7
  179. package/dist/{softmax-C9JQEtnO.js → softmax-faLoUZVT.js} +4 -5
  180. package/dist/split-BNz5jcGc.js +9 -0
  181. package/dist/squeeze--YMgaAAf.js +10 -0
  182. package/dist/stack-WJK22CFn.js +11 -0
  183. package/dist/step-dXR33iOg.js +261 -0
  184. package/dist/sum-BdplSvq_.js +11 -0
  185. package/dist/{tensor-0r5yOo2R.js → tensor-BQqrDvpx.js} +1 -1
  186. package/dist/tensor1d-LxP9asMm.js +11 -0
  187. package/dist/{tensor2d-CSB4KOb0.js → tensor2d-BN1sSfQO.js} +6 -7
  188. package/dist/{tensor4d-D7bLqGqz.js → tensor4d-DVwr7pLF.js} +6 -7
  189. package/dist/{tfjs_backend-CNkSTL0c.js → tfjs_backend-Vi4JfLzT.js} +256 -265
  190. package/dist/tile-CvN_LyVr.js +11 -0
  191. package/dist/tokeniser/BaseTokeniser.d.ts +27 -0
  192. package/dist/tokeniser/BaseTokeniser.js +94 -0
  193. package/dist/tokeniser/CharTokeniser.d.ts +4 -3
  194. package/dist/tokeniser/CharTokeniser.js +46 -32
  195. package/dist/tokeniser/bpe.d.ts +4 -3
  196. package/dist/tokeniser/bpe.js +60 -45
  197. package/dist/tokeniser/type.d.ts +11 -0
  198. package/dist/training/Adam.js +2 -2
  199. package/dist/training/AdamExt.js +1 -1
  200. package/dist/training/DatasetBuilder.d.ts +2 -2
  201. package/dist/training/DatasetBuilder.js +32 -36
  202. package/dist/training/FullTrainer.js +1 -1
  203. package/dist/training/Trainer.d.ts +3 -3
  204. package/dist/training/Trainer.js +2 -2
  205. package/dist/training/sparseCrossEntropy.js +5 -5
  206. package/dist/transpose-JawVKyZy.js +36 -0
  207. package/dist/unsorted_segment_sum-LAbmE9G4.js +277 -0
  208. package/dist/utilities/dummy.js +3 -3
  209. package/dist/utilities/multinomialCPU.js +2 -2
  210. package/dist/utilities/packed.d.ts +1 -4
  211. package/dist/utilities/packed.js +10 -745
  212. package/dist/utilities/performance.js +1 -1
  213. package/dist/utilities/profile.js +1 -1
  214. package/dist/utilities/safetensors.js +2 -2
  215. package/dist/utilities/sentences.js +5 -5
  216. package/dist/utilities/weights.js +2 -2
  217. package/dist/{variable-DzfrwYuP.js → variable-DQ9yYgEU.js} +1 -1
  218. package/dist/{webgpu_program-DzaQiqel.js → webgpu_program-CAE4RICo.js} +177 -171
  219. package/dist/{webgpu_util-0_ubCEHJ.js → webgpu_util-BdovYhXr.js} +34 -35
  220. package/dist/zeros-DeiE2zTa.js +13 -0
  221. package/dist/zeros_like-BAz3iKru.js +721 -0
  222. package/package.json +4 -2
  223. package/dist/Reshape-CDVLyVfz.js +0 -16
  224. package/dist/broadcast_to-B0ChcDaz.js +0 -30
  225. package/dist/complex-BBiRlsVq.js +0 -13
  226. package/dist/concat-DmBLPVGC.js +0 -19
  227. package/dist/dropout-B1x1kYMa.js +0 -99
  228. package/dist/expand_dims-ouvfxQ1n.js +0 -13
  229. package/dist/gather-CH9sdacz.js +0 -10
  230. package/dist/index-D6Q1lPZO.js +0 -2157
  231. package/dist/log_sum_exp-D3ftBNY5.js +0 -41
  232. package/dist/mat_mul-C59XWcJd.js +0 -12
  233. package/dist/mod-DESSvHIU.js +0 -12
  234. package/dist/mulmat_packed_gpu-Coh6qbJk.js +0 -55
  235. package/dist/ones-jU9jlQvM.js +0 -15
  236. package/dist/ops-BFDtP6th.js +0 -645
  237. package/dist/pack16-CmVZs6af.js +0 -41
  238. package/dist/patches/PackedTensor.d.ts +0 -12
  239. package/dist/patches/PackedTensor.js +0 -11
  240. package/dist/patches/engine.d.ts +0 -261
  241. package/dist/patches/engine.js +0 -12
  242. package/dist/patches/tape.d.ts +0 -12
  243. package/dist/patches/tape.js +0 -5
  244. package/dist/range-ZZZD60Fx.js +0 -11
  245. package/dist/reciprocal-CrYlsAGD.js +0 -10
  246. package/dist/register_all_kernels-nvj2k7OC.js +0 -12307
  247. package/dist/relu-BYDneVPn.js +0 -10
  248. package/dist/reshape-CaPQzFvz.js +0 -10
  249. package/dist/rope-s4W2XO9B.js +0 -32
  250. package/dist/selu_util-BGPXmd4B.js +0 -303
  251. package/dist/sin-Djs4aQiu.js +0 -16
  252. package/dist/slice-DvovR5wq.js +0 -13
  253. package/dist/split-DBck65sX.js +0 -10
  254. package/dist/squeeze-C00Ipm_7.js +0 -11
  255. package/dist/stack-ChnHwRpX.js +0 -13
  256. package/dist/sum-ywRJj3Zr.js +0 -12
  257. package/dist/tensor-CzmOBsdf.js +0 -909
  258. package/dist/tensor1d-BlUT89BP.js +0 -12
  259. package/dist/tensor_util-DfwaWayG.js +0 -523
  260. package/dist/tile-CR074jmp.js +0 -13
  261. package/dist/transpose-DH4gmHvu.js +0 -38
  262. package/dist/zeros-DBFVbpv5.js +0 -14
135. package/dist/ops/webgpu/attentionMask32_program.js +2 -2
@@ -1,5 +1,5 @@
- import { e as r } from "../../webgpu_program-DzaQiqel.js";
- import { f as a, c as u } from "../../webgpu_util-0_ubCEHJ.js";
+ import { e as r } from "../../webgpu_program-CAE4RICo.js";
+ import { f as a, c as u } from "../../webgpu_util-BdovYhXr.js";
  class p {
  variableNames = ["q", "k"];
  outputShape;
136. package/dist/ops/webgpu/concat16.js +12 -14
@@ -1,10 +1,8 @@
- import "../../index-D6Q1lPZO.js";
- import { e as x } from "../../webgpu_program-DzaQiqel.js";
- import { f as I, c as D } from "../../webgpu_util-0_ubCEHJ.js";
- import { r as y } from "../../Reshape-CDVLyVfz.js";
- import { r as $ } from "../../tensor_util-DfwaWayG.js";
- import { p as F, s as c } from "../../tensor-CzmOBsdf.js";
- import { a as L, c as d } from "../../concat_util-iBYIyuQe.js";
+ import { h as x, af as I, V as c } from "../../index-D0RBWjq8.js";
+ import { e as D } from "../../webgpu_program-CAE4RICo.js";
+ import { f as $, c as F } from "../../webgpu_util-BdovYhXr.js";
+ import { r as g } from "../../Reshape-Cd6e-Otn.js";
+ import { a as L, c as d } from "../../concat_util-CHsJFZJJ.js";
  class T {
  outputShape;
  shaderKey;
@@ -21,7 +19,7 @@ class T {
  t,
  1
  /* axis */
- ), this.variableNames = t.map((e, a) => `T${a}`), this.dispatchLayout = I(this.outputShape), this.dispatch = D(this.dispatchLayout, this.outputShape, this.workgroupSize, [
+ ), this.variableNames = t.map((e, a) => `T${a}`), this.dispatchLayout = $(this.outputShape), this.dispatch = F(this.dispatchLayout, this.outputShape, this.workgroupSize, [
  this.workPerThread,
  1,
  1
@@ -49,7 +47,7 @@ class T {
  "result[getIndexFromCoords2D(coords, uniforms.outShape)] = T0[getIndexFromCoords2D(vec2<i32>(yR, yC), uniforms.t0Shape)];"
  );
  return `
- ${x("index")} {
+ ${D("index")} {
  for(var i = 0; i < ${this.workPerThread}; i = i + 1) {
  let flatIndex = index * ${this.workPerThread} + i;
  if(flatIndex < uniforms.size) {
@@ -86,8 +84,8 @@ function m(n, t, e) {
  }
  const l = e.runWebGPUProgram(u, i, i[0].dtype, f);
  i.forEach((o) => e.disposeData(o.dataId));
- const g = y({ inputs: { x: l }, backend: e, attrs: { shape: s } });
- return e.disposeData(l.dataId), g.packed = !0, g;
+ const y = g({ inputs: { x: l }, backend: e, attrs: { shape: s } });
+ return e.disposeData(l.dataId), y;
  }
  function P(n, t, e) {
  const a = d(
@@ -95,7 +93,7 @@ function P(n, t, e) {
  t
  );
  return { tensors2D: n.map(
- (s) => y({
+ (s) => g({
  inputs: { x: s },
  backend: e,
  attrs: {
@@ -105,7 +103,7 @@ function P(n, t, e) {
  ), outShape: a };
  }
  function w(n) {
- const { inputs: t, backend: e, attrs: a } = n, { axis: i } = a, s = F(i, t[0].shape)[0], h = t.map((r) => r.shape);
+ const { inputs: t, backend: e, attrs: a } = n, { axis: i } = a, s = I(i, t[0].shape)[0], h = t.map((r) => r.shape);
  L(h, s);
  const u = d(
  t.map((r) => r.shape),
@@ -121,7 +119,7 @@ const v = {
  backendName: "webgpu",
  kernelFunc: w
  };
- $(v);
+ x(v);
  export {
  T as ConcatProgram,
  v as concatConfig
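
For orientation: helper P() above collapses every input to 2D around the concat axis, m() runs the 2D ConcatProgram and reshapes the result back, and the reshaped result is no longer tagged with a manual packed flag. A minimal TFJS restatement of that flatten/concat/restore idea (function name is mine, packing ignored):

import * as tf from "@tensorflow/tfjs";

// Concat along `axis` by collapsing each input to [outer, inner] and
// concatenating along axis 1, mirroring the 2D strategy in the hunk above.
function concatVia2D(tensors: tf.Tensor[], axis: number): tf.Tensor {
  return tf.tidy(() => {
    const outShape = tensors[0].shape.slice();
    outShape[axis] = tensors.reduce((sum, t) => sum + t.shape[axis], 0);
    const flat = tensors.map((t) => {
      const outer = t.shape.slice(0, axis).reduce((a, b) => a * b, 1);
      return tf.reshape(t, [outer, t.size / outer]);
    });
    return tf.reshape(tf.concat(flat, 1), outShape);
  });
}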
137. package/dist/ops/webgpu/gatherSub.js +9 -11
@@ -1,8 +1,6 @@
- import { e as u } from "../../webgpu_program-DzaQiqel.js";
- import { f as p, c as h } from "../../webgpu_util-0_ubCEHJ.js";
- import "../../index-D6Q1lPZO.js";
- import { j as s } from "../../tensor-CzmOBsdf.js";
- import { r as c } from "../../tensor_util-DfwaWayG.js";
+ import { e as u } from "../../webgpu_program-CAE4RICo.js";
+ import { f as h, c as p } from "../../webgpu_util-BdovYhXr.js";
+ import { h as c, a7 as r } from "../../index-D0RBWjq8.js";
  class l {
  variableNames = ["labels", "logits", "values"];
  outputShape;
@@ -11,8 +9,8 @@ class l {
  dispatch;
  workgroupSize = [64, 1, 1];
  size = !0;
- constructor(e) {
- this.outputShape = [e], this.dispatchLayout = p(this.outputShape), this.dispatch = h(this.dispatchLayout, this.outputShape, this.workgroupSize);
+ constructor(t) {
+ this.outputShape = [t], this.dispatchLayout = h(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }
  getUserCode() {
  return `
@@ -27,11 +25,11 @@ class l {
  `;
  }
  }
- function d(t) {
- const { logits: e, labels: a, values: r } = t.inputs, o = t.backend, i = a.shape[0];
- s(r.shape, [i], "Error in EfficientGatherSub: "), s(a.shape, [i], "Error in EfficientGatherSub: ");
+ function d(e) {
+ const { logits: t, labels: a, values: s } = e.inputs, o = e.backend, i = a.shape[0];
+ r(s.shape, [i], "Error in EfficientGatherSub: "), r(a.shape, [i], "Error in EfficientGatherSub: ");
  const n = new l(i);
- return o.runWebGPUProgram(n, [a, e, r], "float32");
+ return o.runWebGPUProgram(n, [a, t, s], "float32");
  }
  const f = {
  kernelName: "EfficientGatherSub",
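
The shader body is elided from this hunk, so the exact arithmetic is not visible here. From the variable names (labels, logits, values), the [batchSize] output shape, and the package's training/sparseCrossEntropy.js, one plausible reading is out[i] = values[i] - logits[i][labels[i]]. A CPU sketch of that assumed semantics (inferred, not taken from the shader):

// Assumed semantics of EfficientGatherSub: for each row i, gather the
// logit at the label index and subtract it from values[i] (e.g. a
// row-wise logSumExp), yielding a per-row sparse cross-entropy term.
function gatherSub(
  labels: Int32Array,    // [batch]
  logits: Float32Array,  // [batch, numClasses], row-major
  values: Float32Array,  // [batch]
  numClasses: number
): Float32Array {
  const out = new Float32Array(labels.length);
  for (let i = 0; i < labels.length; i++) {
    out[i] = values[i] - logits[i * numClasses + labels[i]];
  }
  return out;
}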
138. package/dist/ops/webgpu/gelu.js +28 -29
@@ -1,10 +1,9 @@
- import "../../index-D6Q1lPZO.js";
- import { e as s } from "../../webgpu_program-DzaQiqel.js";
- import { f as o, c as p } from "../../webgpu_util-0_ubCEHJ.js";
+ import { h as d } from "../../index-D0RBWjq8.js";
+ import { e as s } from "../../webgpu_program-CAE4RICo.js";
+ import { f as n, c as o } from "../../webgpu_util-BdovYhXr.js";
  import { isPackedTensor as l } from "../../utilities/packed.js";
- import { r as h } from "../../tensor_util-DfwaWayG.js";
- const r = 0.7978845608028654, u = 0.044715;
- class x {
+ const u = 0.7978845608028654, r = 0.044715;
+ class c {
  outputShape;
  shaderKey;
  dispatchLayout;
@@ -13,7 +12,7 @@ class x {
  workgroupSize;
  size = !0;
  constructor(e) {
- this.workgroupSize = [128, 1, 1], this.outputShape = e, this.dispatchLayout = o(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize), this.shaderKey = "unary_gelu";
+ this.workgroupSize = [128, 1, 1], this.outputShape = e, this.dispatchLayout = n(this.outputShape), this.dispatch = o(this.dispatchLayout, this.outputShape, this.workgroupSize), this.shaderKey = "unary_gelu";
  }
  getUserCode() {
  return `
@@ -23,8 +22,8 @@ class x {
  }
  fn unaryOperation(x : f32) -> f32 {
  let x3 = x * x * x;
- var inner = fma(${u}, x3, x);
- inner = ${r} * inner;
+ var inner = fma(${r}, x3, x);
+ inner = ${u} * inner;
  inner = tanhComplete(inner);
  inner = 0.5 * (1.0 + inner);
  return x * inner;
@@ -38,17 +37,17 @@ class x {
  `;
  }
  }
- function g(t) {
- const { x: e } = t.inputs, a = t.backend, i = new x(e.shape);
+ function x(t) {
+ const { x: e } = t.inputs, a = t.backend, i = new c(e.shape);
  return a.runWebGPUProgram(i, [e], "float32");
  }
- const f = {
+ const g = {
  kernelName: "Gelu",
  backendName: "webgpu",
- kernelFunc: g
+ kernelFunc: x
  };
- h(f);
- class m {
+ d(g);
+ class f {
  // Inputs: dy, x
  variableNames = ["dy", "x"];
  outputShape;
@@ -58,7 +57,7 @@ class m {
  workgroupSize = [128, 1, 1];
  size = !0;
  constructor(e) {
- this.outputShape = e, this.dispatchLayout = o(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize);
+ this.outputShape = e, this.dispatchLayout = n(this.outputShape), this.dispatch = o(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }
  getUserCode() {
  return `
@@ -69,10 +68,10 @@ class m {
  fn activationGrad(dy: f32, X: f32) -> f32 {
  let x2 = X * X;
  let x3 = x2 * X;
- let u = ${r} * (X + ${u} * x3);
+ let u = ${u} * (X + ${r} * x3);
  let t = tanhComplete(u);
  let sech2 = 1.0 - t * t;
- let du_dx = ${r} * (1.0 + 3.0 * ${u} * x2);
+ let du_dx = ${u} * (1.0 + 3.0 * ${r} * x2);
  let dgelu = 0.5 * (1.0 + t) + 0.5 * X * sech2 * du_dx;
  return dy *dgelu;
  }
@@ -89,7 +88,7 @@ class m {
  }`;
  }
  }
- class y {
+ class m {
  // Inputs: dy, x
  variableNames = ["dy", "x"];
  outputShape;
@@ -99,7 +98,7 @@ class y {
  workgroupSize = [128, 1, 1];
  size = !0;
  constructor(e) {
- this.outputShape = e, this.dispatchLayout = o(this.outputShape), this.dispatch = p(this.dispatchLayout, this.outputShape, this.workgroupSize);
+ this.outputShape = e, this.dispatchLayout = n(this.outputShape), this.dispatch = o(this.dispatchLayout, this.outputShape, this.workgroupSize);
  }
  getUserCode() {
  return `
@@ -110,10 +109,10 @@ class y {
  fn activationGrad(dy: f32, X: f32) -> f32 {
  let x2 = X * X;
  let x3 = x2 * X;
- let u = ${r} * (X + ${u} * x3);
+ let u = ${u} * (X + ${r} * x3);
  let t = tanhComplete(u);
  let sech2 = 1.0 - t * t;
- let du_dx = ${r} * (1.0 + 3.0 * ${u} * x2);
+ let du_dx = ${u} * (1.0 + 3.0 * ${r} * x2);
  let dgelu = 0.5 * (1.0 + t) + 0.5 * X * sech2 * du_dx;
  return dy *dgelu;
  }
@@ -127,16 +126,16 @@ class y {
  }`;
  }
  }
- function b(t) {
- const { dy: e, x: a } = t.inputs, i = t.backend, n = l(e), c = n ? new m(a.shape) : new y(a.shape), d = i.runWebGPUProgram(c, [e, a], n ? "int32" : "float32");
- return d.packed = n, d;
+ function y(t) {
+ const { dy: e, x: a } = t.inputs, i = t.backend, p = l(e), h = p ? new f(a.shape) : new m(a.shape);
+ return i.runWebGPUProgram(h, [e, a], p ? "packedF16" : "float32");
  }
- const k = {
+ const b = {
  kernelName: "GeluGrad",
  backendName: "webgpu",
- kernelFunc: b
+ kernelFunc: y
  };
- h(k);
+ d(b);
  export {
- x as GeluProgram
+ c as GeluProgram
  };
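
The two constants renamed in this hunk are the usual tanh-GELU coefficients: 0.7978845608028654 ≈ √(2/π) and 0.044715. A plain TypeScript transcription of the shader's unaryOperation and activationGrad bodies, for reference (function and constant names are mine):

const SQRT_2_OVER_PI = 0.7978845608028654; // sqrt(2 / pi)
const GELU_COEFF = 0.044715;

// Forward: gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
function gelu(x: number): number {
  const inner = SQRT_2_OVER_PI * (x + GELU_COEFF * x * x * x);
  return 0.5 * x * (1 + Math.tanh(inner));
}

// Gradient, mirroring activationGrad():
// d/dx gelu(x) = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * du/dx,
// with u = sqrt(2/pi) * (x + 0.044715 * x^3).
function geluGrad(dy: number, x: number): number {
  const u = SQRT_2_OVER_PI * (x + GELU_COEFF * x * x * x);
  const t = Math.tanh(u);
  const sech2 = 1 - t * t;
  const duDx = SQRT_2_OVER_PI * (1 + 3 * GELU_COEFF * x * x);
  return dy * (0.5 * (1 + t) + 0.5 * x * sech2 * duDx);
}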
139. package/dist/ops/webgpu/matMul16.js +26 -28
@@ -1,34 +1,32 @@
- import { m as y, b as B, j as Q } from "../../index-D6Q1lPZO.js";
+ import { h as H, m as P, b as B, V as y, $ as J } from "../../index-D0RBWjq8.js";
  import { isPackedTensor as R } from "../../utilities/packed.js";
  import { reshape16 as U } from "../reshape16.js";
- import { matMulMul as V } from "../matMulMul.js";
+ import { matMulMul as Q } from "../matMulMul.js";
  import { matMulGelu as X } from "../matMulGelu.js";
  import Y from "./matMul16_program.js";
- import { r as Z } from "../../tensor_util-DfwaWayG.js";
- import { m as _ } from "../../mat_mul-C59XWcJd.js";
- import { r as x } from "../../reshape-CaPQzFvz.js";
- import { t as C } from "../../transpose-DH4gmHvu.js";
- import { s as E } from "../../tensor-CzmOBsdf.js";
- function $(p) {
- const { A: e, B: s } = p.inputs, { transposeA: d, transposeB: f, scale: i, activation: k, scaleA: c, scaleB: u, forceOutputShape: o, perm: m, causalMask: g, pastLen: W } = p.attrs, z = p.backend, S = !R(e), M = !R(s);
+ import { m as Z } from "../../mat_mul-DxpNTCRz.js";
+ import { r as x } from "../../reshape-CkjKPPqB.js";
+ import { t as C } from "../../transpose-JawVKyZy.js";
+ function _(p) {
+ const { A: e, B: s } = p.inputs, { transposeA: d, transposeB: f, scale: i, activation: k, scaleA: c, scaleB: u, forceOutputShape: o, perm: h, causalMask: g, pastLen: E } = p.attrs, F = p.backend, S = !R(e), M = !R(s);
  if (S && M) {
- const A = c !== void 0 ? y(e, B(c)) : e, b = u !== void 0 ? y(s, B(u)) : s;
+ const A = c !== void 0 ? P(e, B(c)) : e, b = u !== void 0 ? P(s, B(u)) : s;
  if (g)
  throw new Error("Causal mask is not supported for unpacked MatMul16.");
  let a;
- if (i !== void 0 ? a = V(A, b, B(i), d, f) : k === "gelu" ? a = X(A, b) : a = _(A, b, d, f), m)
+ if (i !== void 0 ? a = Q(A, b, B(i), d, f) : k === "gelu" ? a = X(A, b) : a = Z(A, b, d, f), h)
  if (o) {
- const n = x(a, o);
+ const r = x(a, o);
  a.dispose();
- const J = C(n, m);
- return n.dispose(), J;
+ const q = C(r, h);
+ return r.dispose(), q;
  } else {
- const n = C(a, m);
- return a.dispose(), n;
+ const r = C(a, h);
+ return a.dispose(), r;
  }
  else if (o) {
- const n = x(a, o);
- return a.dispose(), n;
+ const r = x(a, o);
+ return a.dispose(), r;
  } else
  return a;
  }
@@ -36,23 +34,23 @@ function $(p) {
  throw new Error("When using mixed precision, A must be packed if B is packed.");
  if (!S && M)
  throw new Error("When using mixed precision, B must be packed if A is packed.");
- const h = e.shape.length, l = s.shape.length, F = e.shape.slice(0, -2), I = s.shape.slice(0, -2), v = E(F), w = E(I), N = Q(e.shape.slice(0, -2), s.shape.slice(0, -2)), j = Math.max(v, w), K = e.shape[h - 2], L = s.shape[l - 2], T = e.shape[h - 1] * 2, q = s.shape[l - 1] * 2, D = U(e, [v, e.shape[h - 2], e.shape[h - 1]]), G = U(s, [w, s.shape[l - 2], s.shape[l - 1]]), t = new Y(j, K, L, T, q, d, f), r = [];
- i !== void 0 && (t.useScale(), r.push({ type: "float32", data: [i] })), c !== void 0 && (t.useScaleA(), r.push({ type: "float32", data: [c] })), u !== void 0 && (t.useScaleB(), r.push({ type: "float32", data: [u] })), k !== void 0 && t.useActivation(k), g && (t.useCausalMask(), r.push({ type: "int32", data: [W || 0] }));
+ const l = e.shape.length, m = s.shape.length, W = e.shape.slice(0, -2), z = s.shape.slice(0, -2), v = y(W), w = y(z), I = J(e.shape.slice(0, -2), s.shape.slice(0, -2)), N = Math.max(v, w), K = e.shape[l - 2], L = s.shape[m - 2], T = e.shape[l - 1] * 2, V = s.shape[m - 1] * 2, D = U(e, [v, e.shape[l - 2], e.shape[l - 1]]), G = U(s, [w, s.shape[m - 2], s.shape[m - 1]]), t = new Y(N, K, L, T, V, d, f), n = [];
+ i !== void 0 && (t.useScale(), n.push({ type: "float32", data: [i] })), c !== void 0 && (t.useScaleA(), n.push({ type: "float32", data: [c] })), u !== void 0 && (t.useScaleB(), n.push({ type: "float32", data: [u] })), k !== void 0 && t.useActivation(k), g && (t.useCausalMask(), n.push({ type: "int32", data: [E || 0] }));
  const O = t.outputShape.length;
  o && (p.attrs.originalShape = t.outputShape);
- const H = o ?? N.concat([t.outputShape[O - 2], t.outputShape[O - 1]]);
- t.setOutputShape(H, m);
- const P = z.runWebGPUProgram(
+ const $ = o ?? I.concat([t.outputShape[O - 2], t.outputShape[O - 1]]);
+ t.setOutputShape($, h);
+ const j = F.runWebGPUProgram(
  t,
  [D, G],
- "int32",
- r.length > 0 ? r : void 0
+ "packedF16",
+ n.length > 0 ? n : void 0
  );
- return P.packed = !0, D.dispose(), G.dispose(), P;
+ return D.dispose(), G.dispose(), j;
  }
  const ee = {
  kernelName: "MatMul16",
  backendName: "webgpu",
- kernelFunc: $
+ kernelFunc: _
  };
- Z(ee);
+ H(ee);
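
For orientation, the unpacked branch of the kernelFunc above boils down to optional per-operand scaling, a (possibly fused) matmul, then an optional reshape (forceOutputShape) followed by an optional transpose (perm). A rough TFJS-level sketch of that branch; the helper name and option bag are mine, the fused matMulMul path is replaced by an unfused scalar multiply, and the GELU-fused path is omitted:

import * as tf from "@tensorflow/tfjs";

function matMul16Unpacked(
  A: tf.Tensor,
  B: tf.Tensor,
  opts: {
    transposeA?: boolean; transposeB?: boolean;
    scale?: number; scaleA?: number; scaleB?: number;
    forceOutputShape?: number[]; perm?: number[];
  } = {}
): tf.Tensor {
  return tf.tidy(() => {
    // Optional pre-scaling of each operand (scaleA / scaleB).
    const a = opts.scaleA !== undefined ? tf.mul(A, tf.scalar(opts.scaleA)) : A;
    const b = opts.scaleB !== undefined ? tf.mul(B, tf.scalar(opts.scaleB)) : B;
    let out: tf.Tensor = tf.matMul(a, b, opts.transposeA ?? false, opts.transposeB ?? false);
    if (opts.scale !== undefined) out = tf.mul(out, tf.scalar(opts.scale)); // matMulMul fuses this
    if (opts.forceOutputShape) out = tf.reshape(out, opts.forceOutputShape);
    if (opts.perm) out = tf.transpose(out, opts.perm);
    return out;
  });
}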
140. package/dist/ops/webgpu/matMul16_program.js +4 -5
@@ -1,7 +1,6 @@
- import "../../index-D6Q1lPZO.js";
- import { e as h } from "../../webgpu_program-DzaQiqel.js";
- import { s as f } from "../../tensor-CzmOBsdf.js";
- class A {
+ import { V as f } from "../../index-D0RBWjq8.js";
+ import { e as h } from "../../webgpu_program-CAE4RICo.js";
+ class B {
  variableNames = ["A", "B"];
  outputShape;
  shaderKey = "MatMul16TB";
@@ -332,5 +331,5 @@ class A {
  }
  }
  export {
- A as default
+ B as default
  };
141. package/dist/ops/webgpu/mul16.js +7 -8
@@ -1,14 +1,13 @@
- import "../../index-D6Q1lPZO.js";
+ import { h as t } from "../../index-D0RBWjq8.js";
  import { BinaryOpProgram as m } from "./utils/binary_op.js";
  import { B as p } from "../../binary_op_util-pKXltfxI.js";
- import { r as c } from "../../tensor_util-DfwaWayG.js";
- function i(r) {
- const { a: e, b: n } = r.inputs, t = r.backend, a = new m(p.MUL, e.shape, n.shape), o = t.runWebGPUProgram(a, [e, n], "int32");
- return o.packed = !0, o;
+ function s(e) {
+ const { a: r, b: n } = e.inputs, o = e.backend, a = new m(p.MUL, r.shape, n.shape);
+ return o.runWebGPUProgram(a, [r, n], "packedF16");
  }
- const s = {
+ const c = {
  kernelName: "Mul16",
  backendName: "webgpu",
- kernelFunc: i
+ kernelFunc: s
  };
- c(s);
+ t(c);
142. package/dist/ops/webgpu/normRMS.js +17 -19
@@ -1,30 +1,28 @@
- import "../../index-D6Q1lPZO.js";
- import { createReduceInfo as g, reduce as l } from "./utils/reductions.js";
- import { j as w } from "../../tensor-CzmOBsdf.js";
- import { isPackedTensor as f } from "../../utilities/packed.js";
- import { p as k } from "../../pack16-CmVZs6af.js";
- import S from "./normRMS16_program.js";
- import z from "./normRMS32_program.js";
- import N from "./utils/deviceInfo.js";
- import { r as b } from "../../tensor_util-DfwaWayG.js";
- function P(m) {
- const { x: e, gamma: n } = m.inputs, c = m.backend, i = N(c), s = f(e), a = f(n), o = s || a, r = !o || s ? e : k(e), p = !o || a ? n : k(n), h = [r, p], t = g(h, -1), u = o ? new S(i, t) : new z(i, t);
- if (w(p.shape, [r.shape[r.shape.length - 1]], "Error in RMSNorm: "), e.shape.length !== 3)
+ import { h as g, a7 as l } from "../../index-D0RBWjq8.js";
+ import { createReduceInfo as w, reduce as S } from "./utils/reductions.js";
+ import { isPackedTensor as d } from "../../utilities/packed.js";
+ import { p as f } from "../../pack16-k4jq6aMX.js";
+ import z from "./normRMS16_program.js";
+ import N from "./normRMS32_program.js";
+ import b from "./utils/deviceInfo.js";
+ function P(c) {
+ const { x: e, gamma: s } = c.inputs, m = c.backend, i = b(m), t = d(e), a = d(s), n = t || a, r = !n || t ? e : f(e), p = !n || a ? s : f(s), h = [r, p], o = w(h, -1), u = n ? new z(i, o) : new N(i, o);
+ if (l(p.shape, [r.shape[r.shape.length - 1]], "Error in RMSNorm: "), e.shape.length !== 3)
  throw new Error(`rmsNormGPU: input rank ${e.shape.length} not supported, only rank 3 is supported`);
- if (t.inSize !== r.shape[r.shape.length - 1])
+ if (o.inSize !== r.shape[r.shape.length - 1])
  throw new Error(
- `rmsNormGPU: reduction size ${t.inSize} does not match expected size ${r.shape[r.shape.length - 1]}`
+ `rmsNormGPU: reduction size ${o.inSize} does not match expected size ${r.shape[r.shape.length - 1]}`
  );
- if (t.batchSize !== e.shape[0] * e.shape[1])
+ if (o.batchSize !== e.shape[0] * e.shape[1])
  throw new Error(
- `rmsNormGPU: batch size ${t.batchSize} does not match expected size ${e.shape[0] * e.shape[1]}`
+ `rmsNormGPU: batch size ${o.batchSize} does not match expected size ${e.shape[0] * e.shape[1]}`
  );
- const d = l(u, h, c);
- return d.packed = o, o && !s && r.dispose(), o && !a && p.dispose(), d;
+ const k = S(u, h, m);
+ return n && !t && r.dispose(), n && !a && p.dispose(), k;
  }
  const G = {
  kernelName: "RMSNorm",
  backendName: "webgpu",
  kernelFunc: P
  };
- b(G);
+ g(G);
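
For reference, RMSNorm as dispatched here normalizes a rank-3 input over its last axis and scales by gamma. A per-row TypeScript sketch of that computation; whether and where the shader adds an epsilon is not visible in this diff, so its inclusion here is an assumption:

// y[i] = x[i] / sqrt(mean(x^2) + eps) * gamma[i], over one row of the last axis.
function rmsNormRow(x: Float32Array, gamma: Float32Array, eps = 1e-6): Float32Array {
  let sumSq = 0;
  for (let i = 0; i < x.length; i++) sumSq += x[i] * x[i];
  const invRms = 1 / Math.sqrt(sumSq / x.length + eps); // eps is assumed, not confirmed
  const out = new Float32Array(x.length);
  for (let i = 0; i < x.length; i++) out[i] = x[i] * invRms * gamma[i];
  return out;
}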
143. package/dist/ops/webgpu/normRMSGrad.js +21 -28
@@ -1,14 +1,12 @@
- import { e as _ } from "../../index-D6Q1lPZO.js";
- import { createReduceInfo as D } from "./utils/reductions.js";
- import { f as X } from "../../webgpu_util-0_ubCEHJ.js";
- import { e as $ } from "../../webgpu_program-DzaQiqel.js";
- import { j as z } from "../../tensor-CzmOBsdf.js";
- import { p as k, u as M } from "../../pack16-CmVZs6af.js";
+ import { h as _, a7 as y, e as D } from "../../index-D0RBWjq8.js";
+ import { createReduceInfo as X } from "./utils/reductions.js";
+ import { f as $ } from "../../webgpu_util-BdovYhXr.js";
+ import { e as M } from "../../webgpu_program-CAE4RICo.js";
+ import { p as k, u as R } from "../../pack16-k4jq6aMX.js";
  import { isPackedTensor as h } from "../../utilities/packed.js";
- import { reshape16 as R } from "../reshape16.js";
- import { sum16 as L } from "../sum16.js";
- import { slice16 as w } from "../slice16.js";
- import { r as P } from "../../tensor_util-DfwaWayG.js";
+ import { reshape16 as L } from "../reshape16.js";
+ import { sum16 as P } from "../sum16.js";
+ import { slice16 as z } from "../slice16.js";
  class N {
  outputShape;
  shaderKey = "RMSNormGrad";
@@ -23,7 +21,7 @@ class N {
  packed = !1;
  outputComponent;
  constructor(a, e = 4, o = !1) {
- if (this.packed = o, this.shaderKey = `RMSNormGrad_${e}`, this.rowsPerWorkgroup = e, this.inputShape = [a.batchSize, a.inSize], this.outputShape = [a.batchSize + a.batchSize / this.rowsPerWorkgroup, a.inSize], this.dispatchLayout = X(this.outputShape), this.dispatch = [a.batchSize / this.rowsPerWorkgroup, 1, 1], a.batchSize % this.rowsPerWorkgroup !== 0)
+ if (this.packed = o, this.shaderKey = `RMSNormGrad_${e}`, this.rowsPerWorkgroup = e, this.inputShape = [a.batchSize, a.inSize], this.outputShape = [a.batchSize + a.batchSize / this.rowsPerWorkgroup, a.inSize], this.dispatchLayout = $(this.outputShape), this.dispatch = [a.batchSize / this.rowsPerWorkgroup, 1, 1], a.batchSize % this.rowsPerWorkgroup !== 0)
  throw new Error(
  `RMSNormGradProgram: batch size ${a.batchSize} must be divisible by rowsPerWorkgroup ${this.rowsPerWorkgroup}`
  );
@@ -87,7 +85,7 @@ class N {

  ${o}

- ${$("index")} {
+ ${M("index")} {
  // One workgroup per row (batch).
  let Length = uniforms.reduceSize;
  let BatchSize = uniforms.batchSize;
@@ -145,10 +143,10 @@ class N {
  }
  function W(p) {
  const { dy: a, x: e, gamma: o } = p.inputs, n = 4;
- z(e.shape, a.shape, "Error in RMSNormGrad dy: ");
+ y(e.shape, a.shape, "Error in RMSNormGrad dy: ");
  const s = h(e), i = h(o), u = h(a), r = s || i || u, m = !r || s ? e : k(e), c = !r || i ? o : k(o), d = !r || u ? a : k(a);
- z(c.shape, [m.shape[m.shape.length - 1]], "Error in RMSNormGrad gamma: ");
- const G = p.backend, t = D([m, c, d], -1), f = new N(t, n, r), v = [
+ y(c.shape, [m.shape[m.shape.length - 1]], "Error in RMSNormGrad gamma: ");
+ const w = p.backend, t = X([m, c, d], -1), f = new N(t, n, r), G = [
  { type: "int32", data: [f.inputShape[1]] },
  // Reduce size
  { type: "int32", data: [f.inputShape[0]] }
@@ -156,27 +154,22 @@ function W(p) {
  ];
  if (t.inSize > 1024)
  throw new Error(`rmsNormGradGPU: inSize ${t.inSize} exceeds max of 1024`);
- const x = G.runWebGPUProgram(
- f,
- [m, c, d],
- r ? "int32" : "float32",
- v
- );
- x.packed = r, r && !s && m.dispose(), r && !i && c.dispose(), r && !u && d.dispose();
- const l = _().makeTensorFromTensorInfo(x), S = w(l, [0, 0], [t.batchSize, t.inSize]), g = w(
+ const v = w.runWebGPUProgram(f, [m, c, d], r ? "packedF16" : "float32", G);
+ r && !s && m.dispose(), r && !i && c.dispose(), r && !u && d.dispose();
+ const l = D().makeTensorFromTensorInfo(v), x = z(l, [0, 0], [t.batchSize, t.inSize]), S = z(
  l,
  [t.batchSize, 0],
  [t.batchSize / n, t.inSize]
  );
  l.dispose();
- const b = R(S, e.shape);
- S.dispose();
- const y = L(g, [0]);
- return g.dispose(), [b, !r || i ? y : M(y)];
+ const b = L(x, e.shape);
+ x.dispose();
+ const g = P(S, [0]);
+ return S.dispose(), [b, !r || i ? g : R(g)];
  }
  const Y = {
  kernelName: "RMSNormGrad",
  backendName: "webgpu",
  kernelFunc: W
  };
- P(Y);
+ _(Y);
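
The RMSNormGrad program writes both gradients into one buffer: its outputShape is [batchSize + batchSize/rowsPerWorkgroup, inSize], and the kernelFunc slices the first batchSize rows out as dx, then sum16-reduces the trailing partial rows into dGamma. A CPU-side sketch of that split (function name is mine; the packed-f16 case is ignored):

// Split the single [batch + batch/rowsPerWorkgroup, inSize] output:
// rows [0, batch) are dx; the remaining rows are per-workgroup partial
// dGamma sums that still need a reduction over axis 0.
function splitRmsNormGradOutput(
  out: Float32Array,
  batch: number,
  inSize: number,
  rowsPerWorkgroup = 4
): { dx: Float32Array; dGamma: Float32Array } {
  const dx = out.slice(0, batch * inSize);
  const dGamma = new Float32Array(inSize);
  for (let r = 0; r < batch / rowsPerWorkgroup; r++)
    for (let c = 0; c < inSize; c++)
      dGamma[c] += out[(batch + r) * inSize + c];
  return { dx, dGamma };
}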
144. package/dist/ops/webgpu/pack16.js +12 -13
@@ -1,19 +1,18 @@
- import "../../index-D6Q1lPZO.js";
- import c from "./pack16_program.js";
- import { r as p } from "../../tensor_util-DfwaWayG.js";
- function m(n) {
- const { x: e } = n.inputs, { scaling: t, padding: r } = n.attrs, i = n.backend;
- if (e.shape[e.shape.length - 1] % 2 !== 0)
+ import { h as i } from "../../index-D0RBWjq8.js";
+ import p from "./pack16_program.js";
+ function m(e) {
+ const { x: n } = e.inputs, { scaling: a, padding: r } = e.attrs, s = e.backend;
+ if (n.shape[n.shape.length - 1] % 2 !== 0)
  throw new Error("Last dimension of input tensor must be even to use Pack16.");
- n.attrs && (n.attrs.originalShape = e.shape);
- const a = new c(e.shape, r), o = t !== 1;
- o && a.useScaling();
- const s = [{ type: "float32", data: [t] }];
- return i.runWebGPUProgram(a, [e], "int32", o ? s : void 0);
+ e.attrs && (e.attrs.originalShape = n.shape);
+ const t = new p(n.shape, r), o = a !== 1;
+ o && t.useScaling();
+ const c = [{ type: "float32", data: [a] }];
+ return s.runWebGPUProgram(t, [n], "packedF16", o ? c : void 0);
  }
- const u = {
+ const k = {
  kernelName: "Pack16",
  backendName: "webgpu",
  kernelFunc: m
  };
- p(u);
+ i(k);
145. package/dist/ops/webgpu/pack16_program.js +2 -2
@@ -1,5 +1,5 @@
- import { f as o, c as a } from "../../webgpu_util-0_ubCEHJ.js";
- import { e as s } from "../../webgpu_program-DzaQiqel.js";
+ import { f as o, c as a } from "../../webgpu_util-BdovYhXr.js";
+ import { e as s } from "../../webgpu_program-CAE4RICo.js";
  class h {
  outputShape;
  shaderKey = "Pack16";
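
Across these kernels the output dtype string changes from "int32" (with a manually set packed flag on the result) to a first-class "packedF16" dtype. Pack16's even-last-dimension check points at the layout: two IEEE half floats per 32-bit element, halving the last axis. A CPU-side sketch of that packing under the assumed layout (low half first; a truncating conversion that flushes f16 subnormals; not the library's shader):

// Simplified f32 -> f16 bit conversion: truncates the mantissa instead of
// rounding to nearest, and flushes f16 subnormals to signed zero.
function f32ToF16Bits(val: number): number {
  const f32 = new Float32Array(1);
  const u32 = new Uint32Array(f32.buffer);
  f32[0] = val;
  const x = u32[0];
  const sign = (x >>> 16) & 0x8000;
  const exp = ((x >>> 23) & 0xff) - 127 + 15;
  const mant = x & 0x7fffff;
  if (((x >>> 23) & 0xff) === 0xff) return sign | 0x7c00 | (mant ? 0x200 : 0); // Inf/NaN
  if (exp >= 0x1f) return sign | 0x7c00; // overflow -> Inf
  if (exp <= 0) return sign;             // underflow -> zero
  return sign | (exp << 10) | (mant >>> 13);
}

// Assumed Pack16 layout: each u32 holds a pair of halves, low half first.
function pack16(data: Float32Array): Uint32Array {
  if (data.length % 2 !== 0)
    throw new Error("Last dimension of input tensor must be even to use Pack16.");
  const out = new Uint32Array(data.length / 2);
  for (let i = 0; i < out.length; i++) {
    out[i] = f32ToF16Bits(data[2 * i]) | (f32ToF16Bits(data[2 * i + 1]) << 16);
  }
  return out;
}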
146. package/dist/ops/webgpu/qkv.js +13 -15
@@ -1,26 +1,24 @@
- import "../../index-D6Q1lPZO.js";
- import { j as h } from "../../tensor-CzmOBsdf.js";
- import { b as f } from "../../matMul16-fEAJ4smh.js";
+ import { h, a7 as l } from "../../index-D0RBWjq8.js";
+ import { b as f } from "../../matMul16-cDxwemKj.js";
  import { slice16 as a } from "../slice16.js";
- import { isPackedTensor as l } from "../../utilities/packed.js";
- import { r as u } from "../../tensor_util-DfwaWayG.js";
+ import { isPackedTensor as u } from "../../utilities/packed.js";
  function k(i) {
- const { x: r, kernel: c } = i.inputs, { heads: e } = i.attrs, t = r.shape[0], n = r.shape[1], s = r.shape[2], m = l(r);
- if (h(c.shape, [m ? s * 2 : s, 3 * s], "Error in QKV: "), s % e !== 0)
+ const { x: n, kernel: c } = i.inputs, { heads: e } = i.attrs, r = n.shape[0], t = n.shape[1], s = n.shape[2], p = u(n);
+ if (l(c.shape, [p ? s * 2 : s, 3 * s], "Error in QKV: "), s % e !== 0)
  throw new Error(`Channel dimension ${s} must be divisible by number of heads ${e} in QKV.`);
- const o = f(r, c, !1, !1, {
- forceOutputShape: [t, n, 3 * e, s / e],
+ const o = f(n, c, !1, !1, {
+ forceOutputShape: [r, t, 3 * e, s / e],
  perm: [0, 2, 1, 3]
- }), p = [
- a(o, [0, 0, 0, 0], [t, e, n, s / e]),
- a(o, [0, e, 0, 0], [t, e, n, s / e]),
- a(o, [0, 2 * e, 0, 0], [t, e, n, s / e])
+ }), m = [
+ a(o, [0, 0, 0, 0], [r, e, t, s / e]),
+ a(o, [0, e, 0, 0], [r, e, t, s / e]),
+ a(o, [0, 2 * e, 0, 0], [r, e, t, s / e])
  ];
- return o.dispose(), p;
+ return o.dispose(), m;
  }
  const b = {
  kernelName: "QKV",
  backendName: "webgpu",
  kernelFunc: k
  };
- u(b);
+ h(b);
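
The fused QKV kernel above is a single matmul whose output is reshaped to [batch, 3·heads, seq, channels/heads] (via forceOutputShape plus the [0, 2, 1, 3] transpose inside MatMul16), then sliced three ways along the heads axis. An unpacked TFJS restatement of that flow (function name is mine):

import * as tf from "@tensorflow/tfjs";

// x: [batch, seq, channels]; kernel: [channels, 3*channels].
// Returns [q, k, v], each [batch, heads, seq, channels/heads].
function qkvUnpacked(x: tf.Tensor3D, kernel: tf.Tensor2D, heads: number): tf.Tensor4D[] {
  return tf.tidy(() => {
    const [b, t, c] = x.shape;
    if (c % heads !== 0)
      throw new Error(`Channel dimension ${c} must be divisible by number of heads ${heads} in QKV.`);
    const proj = tf.matMul(tf.reshape(x, [b * t, c]), kernel); // [b*t, 3c]
    const split = tf.transpose(
      tf.reshape(proj, [b, t, 3 * heads, c / heads]),
      [0, 2, 1, 3]
    ) as tf.Tensor4D;                                          // [b, 3*heads, t, c/heads]
    return [0, 1, 2].map((i) =>
      tf.slice(split, [0, i * heads, 0, 0], [b, heads, t, c / heads])
    );
  });
}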