@genai-fi/nanogpt 0.7.0 → 0.7.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. package/dist/Generator.js +13 -9
  2. package/dist/NanoGPTModel.js +10 -10
  3. package/dist/{RealDiv-C4hOvYOZ.js → RealDiv-CVYNbZxu.js} +11 -11
  4. package/dist/{Reshape-BLijOA8h.js → Reshape-CEsEp0AI.js} +2 -2
  5. package/dist/Reshape-Do18N3gO.js +30 -0
  6. package/dist/TeachableLLM.js +9 -5
  7. package/dist/{TiedEmbedding-BLltddza.js → TiedEmbedding-ccLBFiZi.js} +4 -4
  8. package/dist/{axis_util-DaAl5MER.js → axis_util-5DTW2tFV.js} +1 -1
  9. package/dist/backend.js +2 -2
  10. package/dist/{backend_util-DWiwsi2N.js → backend_util-C9Ut8n0Q.js} +40 -40
  11. package/dist/{broadcast_to-C4v-j9yA.js → broadcast_to-Ba9h_8DO.js} +2 -2
  12. package/dist/{concat-CsHeR4zV.js → concat-CbXTetof.js} +1 -1
  13. package/dist/{dataset-JDyjG3QR.js → dataset-U3PrjwgU.js} +7 -7
  14. package/dist/{dropout-hpDwECTe.js → dropout-DPfPgWWe.js} +11 -11
  15. package/dist/{gather-D0_gPiBz.js → gather-Bbh8DHhM.js} +4 -4
  16. package/dist/{gelu-uyHP1x1f.js → gelu-BFwVnd1r.js} +1 -1
  17. package/dist/{gpgpu_math-DJm3ZTAf.js → gpgpu_math-DffelNS-.js} +2 -2
  18. package/dist/{index-BPPzKVdR.js → index-DYD_yPa-.js} +1083 -1106
  19. package/dist/{index-C0dhsYom.js → index-UdZhlibC.js} +126 -126
  20. package/dist/{kernel_funcs_utils-CwRTFqrc.js → kernel_funcs_utils-CXDy3EN7.js} +3 -3
  21. package/dist/layers/BaseLayer.js +2 -2
  22. package/dist/layers/CausalSelfAttention.js +8 -8
  23. package/dist/layers/MLP.js +5 -5
  24. package/dist/layers/RMSNorm.js +3 -3
  25. package/dist/layers/RoPECache.js +4 -4
  26. package/dist/layers/TiedEmbedding.js +5 -5
  27. package/dist/layers/TransformerBlock.js +1 -1
  28. package/dist/loader/loadTransformers.js +1 -1
  29. package/dist/loader/oldZipLoad.js +11 -7
  30. package/dist/{log_sum_exp-D086OgZJ.js → log_sum_exp-BnmCkHWl.js} +8 -8
  31. package/dist/main.d.ts +11 -0
  32. package/dist/main.js +44 -27
  33. package/dist/{mat_mul-1nwdPkQ_.js → mat_mul-dwmZz69e.js} +1 -1
  34. package/dist/{max-BQc2Aj-I.js → max-ByjEGoFx.js} +3 -3
  35. package/dist/{mulmat_packed_gpu-Gzf3I9UV.js → mulmat_packed_gpu-IGPBp6h9.js} +1 -1
  36. package/dist/{ones-D63HpSF_.js → ones-C8Mfln6-.js} +2 -2
  37. package/dist/ops/adamAdjust.d.ts +2 -0
  38. package/dist/ops/adamAdjust.js +9 -0
  39. package/dist/ops/adamMoments.d.ts +2 -0
  40. package/dist/ops/adamMoments.js +9 -0
  41. package/dist/ops/appendCache.js +3 -3
  42. package/dist/ops/attentionMask.js +1 -1
  43. package/dist/ops/cpu/adamAdjust.d.ts +1 -0
  44. package/dist/ops/cpu/adamAdjust.js +18 -0
  45. package/dist/ops/cpu/adamMoments.d.ts +1 -0
  46. package/dist/ops/cpu/adamMoments.js +16 -0
  47. package/dist/ops/cpu/appendCache.js +2 -2
  48. package/dist/ops/cpu/attentionMask.js +5 -5
  49. package/dist/ops/cpu/fusedSoftmax.js +2 -2
  50. package/dist/ops/cpu/gatherSub.js +3 -3
  51. package/dist/ops/cpu/gelu.js +1 -1
  52. package/dist/ops/cpu/matMulGelu.js +2 -2
  53. package/dist/ops/cpu/matMulMul.js +1 -1
  54. package/dist/ops/cpu/mulDropout.js +1 -1
  55. package/dist/ops/cpu/normRMS.js +1 -1
  56. package/dist/ops/cpu/qkv.js +3 -3
  57. package/dist/ops/cpu/rope.js +5 -5
  58. package/dist/ops/cpu/scatterSub.js +11 -11
  59. package/dist/ops/fusedSoftmax.js +1 -1
  60. package/dist/ops/gatherSub.js +1 -1
  61. package/dist/ops/gelu.js +2 -2
  62. package/dist/ops/grads/attentionMask.js +1 -1
  63. package/dist/ops/grads/fusedSoftmax.js +2 -2
  64. package/dist/ops/grads/gelu.js +2 -2
  65. package/dist/ops/grads/matMulGelu.js +1 -1
  66. package/dist/ops/grads/normRMS.js +1 -1
  67. package/dist/ops/grads/qkv.js +1 -1
  68. package/dist/ops/grads/rope.js +1 -1
  69. package/dist/ops/matMulGelu.js +1 -1
  70. package/dist/ops/matMulMul.js +1 -1
  71. package/dist/ops/mulDrop.js +1 -1
  72. package/dist/ops/normRMS.js +1 -1
  73. package/dist/ops/qkv.js +1 -1
  74. package/dist/ops/rope.js +4 -4
  75. package/dist/ops/scatterSub.js +1 -1
  76. package/dist/ops/webgl/adamAdjust.d.ts +1 -0
  77. package/dist/ops/webgl/adamAdjust.js +50 -0
  78. package/dist/ops/webgl/adamMoments.d.ts +1 -0
  79. package/dist/ops/webgl/adamMoments.js +38 -0
  80. package/dist/ops/webgl/appendCache.js +1 -1
  81. package/dist/ops/webgl/attentionMask.js +1 -1
  82. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  83. package/dist/ops/webgl/gatherSub.js +8 -8
  84. package/dist/ops/webgl/gelu.js +2 -2
  85. package/dist/ops/webgl/log.js +3 -3
  86. package/dist/ops/webgl/matMulGelu.js +4 -4
  87. package/dist/ops/webgl/matMulMul.js +1 -1
  88. package/dist/ops/webgl/mulDropout.js +1 -1
  89. package/dist/ops/webgl/normRMS.js +2 -2
  90. package/dist/ops/webgl/qkv.js +1 -1
  91. package/dist/ops/webgl/rope.js +1 -1
  92. package/dist/ops/webgl/scatterSub.js +1 -1
  93. package/dist/ops/webgpu/adamAdjust.d.ts +1 -0
  94. package/dist/ops/webgpu/adamAdjust.js +52 -0
  95. package/dist/ops/webgpu/adamMoments.d.ts +1 -0
  96. package/dist/ops/webgpu/adamMoments.js +51 -0
  97. package/dist/ops/webgpu/appendCache.js +13 -12
  98. package/dist/ops/webgpu/attentionMask.js +11 -10
  99. package/dist/ops/webgpu/gatherSub.js +26 -11
  100. package/dist/ops/webgpu/gelu.js +7 -6
  101. package/dist/ops/webgpu/index.js +3 -0
  102. package/dist/ops/webgpu/normRMS.js +27 -101
  103. package/dist/ops/webgpu/normRMSGrad.d.ts +1 -0
  104. package/dist/ops/webgpu/normRMSGrad.js +128 -0
  105. package/dist/ops/webgpu/qkv.js +9 -8
  106. package/dist/ops/webgpu/rope.js +8 -7
  107. package/dist/ops/webgpu/scatterSub.js +8 -7
  108. package/dist/ops/webgpu/utils/reductions.d.ts +9 -0
  109. package/dist/ops/webgpu/utils/reductions.js +68 -0
  110. package/dist/{ops-CIQLNshk.js → ops-aRTXR2Sr.js} +195 -219
  111. package/dist/{random_width-DkYP8W8N.js → random_width-DbSpgl4o.js} +22 -21
  112. package/dist/{range-CYzpQY53.js → range-D9CZhVlR.js} +1 -1
  113. package/dist/{reciprocal-_A9yv27J.js → reciprocal-CGB48wZB.js} +1 -1
  114. package/dist/{register_all_kernels-guvSxp7M.js → register_all_kernels-DnbAyBXt.js} +30 -29
  115. package/dist/{reshape-BMUzc1UY.js → reshape-BR0eoLYN.js} +3 -3
  116. package/dist/{scatter_nd_util-IRBqKz_b.js → scatter_nd_util-OjyAxku2.js} +1 -1
  117. package/dist/{selu_util-Dt_iuXaq.js → selu_util-Ce6pu9IM.js} +41 -41
  118. package/dist/{shared-CDu9S76h.js → shared-Czipaeb6.js} +6 -6
  119. package/dist/{shared-BNa2q6jD.js → shared-DS5waSIY.js} +1 -1
  120. package/dist/{sin-Cocju-BY.js → sin-CiBxrDqX.js} +6 -6
  121. package/dist/slice-BHbDHObE.js +28 -0
  122. package/dist/{softmax-GPNK3o-U.js → softmax-JMEIUo2J.js} +3 -3
  123. package/dist/{split-CHzJjxDv.js → split-CRU0PjVV.js} +1 -1
  124. package/dist/{stack-Dpgg_1W1.js → stack-ikk2Y8_P.js} +1 -1
  125. package/dist/{sum-B8wEpKsg.js → sum-NLYbiDag.js} +3 -3
  126. package/dist/{tensor-RvZVNmg0.js → tensor-Do9PKbIE.js} +1 -1
  127. package/dist/{tensor2d-B_kyod7_.js → tensor2d-CWHxHpLh.js} +1 -1
  128. package/dist/training/Adam.d.ts +22 -0
  129. package/dist/training/Adam.js +93 -0
  130. package/dist/training/AdamExt.d.ts +1 -1
  131. package/dist/training/AdamExt.js +13 -12
  132. package/dist/training/DatasetBuilder.js +2 -2
  133. package/dist/training/FullTrainer.js +22 -22
  134. package/dist/training/Trainer.d.ts +1 -1
  135. package/dist/training/Trainer.js +32 -32
  136. package/dist/training/sparseCrossEntropy.d.ts +0 -4
  137. package/dist/training/sparseCrossEntropy.js +7 -7
  138. package/dist/utilities/arrayClose.d.ts +1 -0
  139. package/dist/utilities/arrayClose.js +11 -0
  140. package/dist/utilities/dummy.js +2 -2
  141. package/dist/utilities/generate.js +3 -3
  142. package/dist/utilities/multinomialCPU.js +2 -2
  143. package/dist/utilities/performance.d.ts +1 -1
  144. package/dist/utilities/performance.js +11 -11
  145. package/dist/utilities/profile.js +1 -1
  146. package/dist/utilities/safetensors.js +2 -2
  147. package/dist/utilities/weights.js +2 -2
  148. package/dist/{variable-DXEUOwew.js → variable-BTBkayv_.js} +1 -1
  149. package/dist/{webgpu_util-g13LvDIv.js → webgpu_program-WaoMq-WD.js} +138 -215
  150. package/dist/webgpu_util-DhSeP4b6.js +80 -0
  151. package/dist/{zeros-DCPCdFGq.js → zeros-DnPT2nD4.js} +4 -4
  152. package/package.json +1 -1
@@ -1,4 +1,4 @@
1
- import { aa as F, ab as O, ac as L, Z as N, l as _ } from "./index-C0dhsYom.js";
1
+ import { aa as k, ab as z, ac as E, a1 as j, l as A } from "./index-UdZhlibC.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2019 Google LLC. All Rights Reserved.
@@ -15,16 +15,16 @@ import { aa as F, ab as O, ac as L, Z as N, l as _ } from "./index-C0dhsYom.js";
15
15
  * limitations under the License.
16
16
  * =============================================================================
17
17
  */
18
- function U(e, s) {
18
+ function L(e, s) {
19
19
  if (Math.max(...e) > 5)
20
20
  throw new Error("Cannot symbolically compute strides for rank > 6 tensor.");
21
- const t = e.length, o = "xyzwuv", n = e.map((a) => `${s}.${o[a]}`), r = new Array(t - 1);
22
- r[t - 2] = n[t - 1];
23
- for (let a = t - 3; a >= 0; --a)
24
- r[a] = `(${r[a + 1]} * ${n[a + 1]})`;
25
- return r;
21
+ const t = e.length, o = "xyzwuv", i = e.map((n) => `${s}.${o[n]}`), u = new Array(t - 1);
22
+ u[t - 2] = i[t - 1];
23
+ for (let n = t - 3; n >= 0; --n)
24
+ u[n] = `(${u[n + 1]} * ${i[n + 1]})`;
25
+ return u;
26
26
  }
27
- const H = (e, s, t) => t === "int32" ? `atomicAdd(${e}, bitcast<i32>(${s}));` : `
27
+ const X = (e, s, t) => t === "int32" ? `atomicAdd(${e}, bitcast<i32>(${s}));` : `
28
28
  {
29
29
  var oldValue = 0;
30
30
  loop {
@@ -53,24 +53,24 @@ const H = (e, s, t) => t === "int32" ? `atomicAdd(${e}, bitcast<i32>(${s}));` :
53
53
  * limitations under the License.
54
54
  * =============================================================================
55
55
  */
56
- var y;
56
+ var b;
57
57
  (function(e) {
58
58
  e[e.FROM_PIXELS = 0] = "FROM_PIXELS", e[e.DRAW = 1] = "DRAW";
59
- })(y || (y = {}));
60
- const Z = (e, s, t, o, n) => {
61
- const r = { dtype: o.dtype, shape: o.shape }, a = P(t, r, s), i = e.createShaderModule({ code: a, label: s.constructor.name });
62
- let d = L().get("WEBGPU_PRINT_SHADER");
59
+ })(b || (b = {}));
60
+ const H = (e, s, t, o, i) => {
61
+ const u = { dtype: o.dtype, shape: o.shape }, n = D(t, u, s), r = e.createShaderModule({ code: n, label: s.constructor.name });
62
+ let d = E().get("WEBGPU_PRINT_SHADER");
63
63
  if (d !== "") {
64
64
  d = d.toLowerCase();
65
65
  const p = d.split(",");
66
- (d === "all" || p.some((u) => s.shaderKey.toLowerCase().includes(u))) && (console.group(s.shaderKey), console.debug(a), console.groupEnd());
66
+ (d === "all" || p.some((a) => s.shaderKey.toLowerCase().includes(a))) && (console.group(s.shaderKey), console.debug(n), console.groupEnd());
67
67
  }
68
- return n ? e.createComputePipelineAsync({
69
- compute: { module: i, entryPoint: "_start" },
68
+ return i ? e.createComputePipelineAsync({
69
+ compute: { module: r, entryPoint: "_start" },
70
70
  label: s.constructor.name,
71
71
  layout: "auto"
72
72
  }) : e.createComputePipeline({
73
- compute: { module: i, entryPoint: "_start" },
73
+ compute: { module: r, entryPoint: "_start" },
74
74
  label: s.constructor.name,
75
75
  layout: "auto"
76
76
  });
@@ -118,7 +118,7 @@ function I(e) {
118
118
  return "v";
119
119
  throw Error(`Index ${e} is not yet supported`);
120
120
  }
121
- function q(...e) {
121
+ function Y(...e) {
122
122
  let s;
123
123
  switch (e.length) {
124
124
  case 0:
@@ -139,7 +139,7 @@ function q(...e) {
139
139
  function w(e, s) {
140
140
  let t;
141
141
  return t = `
142
- ${D(s)}
142
+ ${N(s)}
143
143
  fn _start(@builtin(local_invocation_id) LocalId : vec3<u32>,
144
144
  @builtin(global_invocation_id) GlobalId : vec3<u32>,
145
145
  @builtin(local_invocation_index) LocalIndex: u32,
@@ -154,13 +154,13 @@ function w(e, s) {
154
154
  }
155
155
  `, t;
156
156
  }
157
- function D(e) {
157
+ function N(e) {
158
158
  return `
159
159
  @compute @workgroup_size(${e.workgroupSize[0]}, ${e.workgroupSize[1]}, ${e.workgroupSize[2]})
160
160
  `;
161
161
  }
162
- function P(e, s, t) {
163
- const o = [], n = t.workgroupSize[0] * t.workgroupSize[1] * t.workgroupSize[2];
162
+ function D(e, s, t) {
163
+ const o = [], i = t.workgroupSize[0] * t.workgroupSize[1] * t.workgroupSize[2];
164
164
  if (t.outputComponent = t.outputComponent ? t.outputComponent : 1, o.push(`
165
165
 
166
166
  var<private> localId: vec3<u32>;
@@ -171,13 +171,13 @@ function P(e, s, t) {
171
171
 
172
172
  // Only used when the y/z dimension of workgroup size is 1.
173
173
  fn getGlobalIndex() -> i32 {
174
- ${E(t) ? " return i32(globalId.x);" : ` return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y +
175
- workgroupId.y * numWorkgroups.x + workgroupId.x) * ${n}u +
174
+ ${F(t) ? " return i32(globalId.x);" : ` return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y +
175
+ workgroupId.y * numWorkgroups.x + workgroupId.x) * ${i}u +
176
176
  localIndex);
177
177
  `}
178
178
  }
179
179
  `), t.pixelsOpType != null) {
180
- const h = t.pixelsOpType === y.FROM_PIXELS ? `@group(0) @binding(0) var<storage, read_write> result: array<${b(s.dtype, t.outputComponent)}>;` : `@group(0) @binding(1) var<storage, read> inBuf : array<${b(e[0].dtype, t.outputComponent)}>;`, c = s.shape.length === 3 ? "vec2<i32>" : "i32";
180
+ const h = t.pixelsOpType === b.FROM_PIXELS ? `@group(0) @binding(0) var<storage, read_write> result: array<${x(s.dtype, t.outputComponent)}>;` : `@group(0) @binding(1) var<storage, read> inBuf : array<${x(e[0].dtype, t.outputComponent)}>;`, c = s.shape.length === 3 ? "vec2<i32>" : "i32";
181
181
  o.push(`
182
182
  struct Uniform {
183
183
  outShapeStrides : ${c},
@@ -189,63 +189,63 @@ function P(e, s, t) {
189
189
  ${h}
190
190
  @group(0) @binding(2) var<uniform> uniforms: Uniform;
191
191
  `);
192
- const x = k(t);
192
+ const S = m(t);
193
193
  return [
194
194
  C,
195
195
  o.join(`
196
196
  `),
197
- m(s.shape),
197
+ y(s.shape),
198
198
  t.getUserCode(),
199
- w(x, t)
199
+ w(S, t)
200
200
  ].join(`
201
201
  `);
202
202
  }
203
- let r, a, i = "struct Uniforms { NAN : f32, INFINITY : f32, ";
203
+ let u, n, r = "struct Uniforms { NAN : f32, INFINITY : f32, ";
204
204
  t.variableNames.forEach((h, c) => {
205
- const x = v(e[c].shape.length);
206
- i += `${h.charAt(0).toLowerCase() + h.slice(1)}Shape : ${x}, `, r = e[c].shape.length - 1, a = v(r), i += `${h.charAt(0).toLowerCase() + h.slice(1)}ShapeStrides: ${a}, `;
205
+ const S = v(e[c].shape.length);
206
+ r += `${h.charAt(0).toLowerCase() + h.slice(1)}Shape : ${S}, `, u = e[c].shape.length - 1, n = v(u), r += `${h.charAt(0).toLowerCase() + h.slice(1)}ShapeStrides: ${n}, `;
207
207
  });
208
208
  const d = v(s.shape.length);
209
- i += `outShape : ${d}, `, r = s.shape.length - 1, a = v(r), i += `
210
- outShapeStrides: ${a}, `, t.size && (i += "size : i32, "), t.uniforms && (i += t.uniforms), i += "};", i = K(i), o.push(i), t.atomic ? o.push(`
209
+ r += `outShape : ${d}, `, u = s.shape.length - 1, n = v(u), r += `
210
+ outShapeStrides: ${n}, `, t.size && (r += "size : i32, "), t.uniforms && (r += t.uniforms), r += "};", r = B(r), o.push(r), t.atomic ? o.push(`
211
211
  @group(0) @binding(0) var<storage, read_write> result: array<atomic<i32>>;
212
212
  `) : o.push(`
213
- @group(0) @binding(0) var<storage, read_write> result: array<${b(s.dtype, t.outputComponent)}>;
213
+ @group(0) @binding(0) var<storage, read_write> result: array<${x(s.dtype, t.outputComponent)}>;
214
214
  `), t.variableNames.forEach((h, c) => {
215
215
  o.push(`
216
- @group(0) @binding(${1 + c}) var<storage, read> ${h}: array<${t.variableComponents ? b(e[c].dtype, t.variableComponents[c]) : b(e[c].dtype, t.outputComponent)}>;
216
+ @group(0) @binding(${1 + c}) var<storage, read> ${h}: array<${t.variableComponents ? x(e[c].dtype, t.variableComponents[c]) : x(e[c].dtype, t.outputComponent)}>;
217
217
  `);
218
- }), i !== "" && o.push(`
218
+ }), r !== "" && o.push(`
219
219
  @group(0) @binding(${1 + t.variableNames.length}) var<uniform> uniforms: Uniforms;
220
220
  `);
221
- const p = T(s.shape, t.dispatchLayout), u = [
221
+ const p = R(s.shape, t.dispatchLayout), a = [
222
222
  C,
223
223
  o.join(`
224
- `) + W,
225
- m(s.shape),
224
+ `) + T,
225
+ y(s.shape),
226
226
  p,
227
- V(s.shape.length)
227
+ G(s.shape.length)
228
228
  ];
229
- t.atomic || u.push(B(s.shape, s.dtype, t.outputComponent)), t.variableNames.forEach((h, c) => {
230
- u.push(`${m(e[c].shape, h)}`);
229
+ t.atomic || a.push(V(s.shape, s.dtype, t.outputComponent)), t.variableNames.forEach((h, c) => {
230
+ a.push(`${y(e[c].shape, h)}`);
231
231
  });
232
- const g = e.map((h, c) => G(h, s.shape, t.variableComponents ? t.variableComponents[c] : t.outputComponent, t.dispatchLayout.x.length === s.shape.length)).join(`
232
+ const g = e.map((h, c) => P(h, s.shape, t.variableComponents ? t.variableComponents[c] : t.outputComponent, t.dispatchLayout.x.length === s.shape.length)).join(`
233
233
  `);
234
- u.push(g), u.push(t.getUserCode());
235
- const l = k(t);
236
- return u.push(w(l, t)), u.join(`
234
+ a.push(g), a.push(t.getUserCode());
235
+ const l = m(t);
236
+ return a.push(w(l, t)), a.join(`
237
237
  `);
238
238
  }
239
- function J(e, s, t) {
239
+ function q(e, s, t) {
240
240
  let o = e.shaderKey;
241
241
  if (e.pixelsOpType != null)
242
242
  return o;
243
- const n = [], r = [];
244
- s.forEach((u) => {
245
- n.push(u.shape), r.push(u.dtype);
246
- }), n.push(t.shape), r.push(t.dtype);
247
- const a = s.map((u) => F(u.shape, t.shape)), i = s.map((u) => O(u.shape, t.shape)).join("_"), d = a.map((u) => u.join("_")).join(";"), p = E(e) ? "flatDispatch" : "";
248
- return o += "_" + (e.workgroupSize ? e.workgroupSize.join(",") : "") + n.map((u) => u.length).join(",") + r.join(",") + e.variableNames.join(",") + d + i + p, o;
243
+ const i = [], u = [];
244
+ s.forEach((a) => {
245
+ i.push(a.shape), u.push(a.dtype);
246
+ }), i.push(t.shape), u.push(t.dtype);
247
+ const n = s.map((a) => k(a.shape, t.shape)), r = s.map((a) => z(a.shape, t.shape)).join("_"), d = n.map((a) => a.join("_")).join(";"), p = F(e) ? "flatDispatch" : "";
248
+ return o += "_" + (e.workgroupSize ? e.workgroupSize.join(",") : "") + i.map((a) => a.length).join(",") + u.join(",") + e.variableNames.join(",") + d + r + p, o;
249
249
  }
250
250
  const C = `
251
251
  struct vec5 {x: i32, y: i32, z: i32, w: i32, u: i32};
@@ -297,138 +297,138 @@ const C = `
297
297
  let floatToUint: vec4<u32> = bitcast<vec4<u32>>(val);
298
298
  return (floatToUint & vec4<u32>(0x7fffffffu)) > vec4<u32>(0x7f800000u);
299
299
  }
300
- `, W = `
300
+ `, T = `
301
301
  fn isinf(val: f32) -> bool {
302
302
  return abs(val) == uniforms.INFINITY;
303
303
  }
304
304
  `;
305
- function m(e, s = "") {
306
- const t = e.length, o = s !== "" ? `get${s.charAt(0).toUpperCase() + s.slice(1)}CoordsFromIndex` : "getCoordsFromIndex", n = s !== "" ? `${s.charAt(0).toLowerCase() + s.slice(1)}ShapeStrides` : "outShapeStrides";
305
+ function y(e, s = "") {
306
+ const t = e.length, o = s !== "" ? `get${s.charAt(0).toUpperCase() + s.slice(1)}CoordsFromIndex` : "getCoordsFromIndex", i = s !== "" ? `${s.charAt(0).toLowerCase() + s.slice(1)}ShapeStrides` : "outShapeStrides";
307
307
  if (t <= 1)
308
308
  return `fn ${o}(index : i32) -> i32 { return index; }`;
309
- const r = N(e), a = v(t), i = [];
309
+ const u = j(e), n = v(t), r = [];
310
310
  for (let p = 0; p < t; p++)
311
- i.push(`d${p}`);
312
- if (r.length === 1)
311
+ r.push(`d${p}`);
312
+ if (u.length === 1)
313
313
  return ` fn ${o}(index : i32) -> vec2<i32> {
314
- let d0 = index / uniforms.${n}; let d1 = index - d0 * uniforms.${n};
314
+ let d0 = index / uniforms.${i}; let d1 = index - d0 * uniforms.${i};
315
315
  return vec2<i32>(d0, d1);
316
316
  }`;
317
317
  let d;
318
- return d = "var index2 = index;" + r.map((p, u) => {
319
- const g = `let ${i[u]} = index2 / uniforms.${n}.${I(u)}`, l = u === r.length - 1 ? `let ${i[u + 1]} = index2 - ${i[u]} * uniforms.${n}.${I(u)}` : `index2 = index2 - ${i[u]} * uniforms.${n}.${I(u)}`;
318
+ return d = "var index2 = index;" + u.map((p, a) => {
319
+ const g = `let ${r[a]} = index2 / uniforms.${i}.${I(a)}`, l = a === u.length - 1 ? `let ${r[a + 1]} = index2 - ${r[a]} * uniforms.${i}.${I(a)}` : `index2 = index2 - ${r[a]} * uniforms.${i}.${I(a)}`;
320
320
  return `${g}; ${l};`;
321
321
  }).join(""), `
322
- fn ${o}(index : i32) -> ${a} {
322
+ fn ${o}(index : i32) -> ${n} {
323
323
  ${d}
324
- return ${a}(${i.join(",")});
324
+ return ${n}(${r.join(",")});
325
325
  }
326
326
  `;
327
327
  }
328
- function M(e, s) {
329
- const t = e.name, o = e.shape.length, n = v(o), r = "get" + t.charAt(0).toUpperCase() + t.slice(1), a = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), i = a.map((u) => `${u} : i32`).join(", ");
328
+ function U(e, s) {
329
+ const t = e.name, o = e.shape.length, i = v(o), u = "get" + t.charAt(0).toUpperCase() + t.slice(1), n = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), r = n.map((a) => `${a} : i32`).join(", ");
330
330
  if (o < 1)
331
331
  return `
332
- fn ${r}() -> ${f(s)} {
332
+ fn ${u}() -> ${f(s)} {
333
333
  return ${f(s)}(${t}[0]);
334
334
  }
335
335
  `;
336
336
  const d = `uniforms.${t.charAt(0).toLowerCase() + t.slice(1)}Shape`;
337
337
  let p = `${o}D`;
338
338
  return o === 0 && (p = "1D"), `
339
- fn ${r}(${i}) -> ${f(s)} {
340
- return ${f(s)}(${t}[getIndexFromCoords${p}(${n}(${a.join(",")}),
339
+ fn ${u}(${r}) -> ${f(s)} {
340
+ return ${f(s)}(${t}[getIndexFromCoords${p}(${i}(${n.join(",")}),
341
341
  ${d})${s === 1 ? "" : ` / ${s}`}]);
342
342
  }
343
343
  `;
344
344
  }
345
- function R(e, s, t, o) {
346
- const n = e.name, r = n.charAt(0).toUpperCase() + n.slice(1), a = "get" + r + "ByOutput", i = e.shape.length, d = s.length, p = v(d);
347
- if (O(e.shape, s) && o)
345
+ function W(e, s, t, o) {
346
+ const i = e.name, u = i.charAt(0).toUpperCase() + i.slice(1), n = "get" + u + "ByOutput", r = e.shape.length, d = s.length, p = v(d);
347
+ if (z(e.shape, s) && o)
348
348
  return `
349
- fn ${a}Index(globalIndex : i32) -> ${f(t)} {
350
- return ${f(t)}(${n}[globalIndex]);
349
+ fn ${n}Index(globalIndex : i32) -> ${f(t)} {
350
+ return ${f(t)}(${i}[globalIndex]);
351
351
  }
352
352
 
353
- fn ${a}Coords(coords : ${p}) -> ${f(t)} {
354
- return ${f(t)}(${n}[${d > 1 ? "getOutputIndexFromCoords(coords)" : "coords"}${t === 1 ? "" : ` / ${t}`}]);
353
+ fn ${n}Coords(coords : ${p}) -> ${f(t)} {
354
+ return ${f(t)}(${i}[${d > 1 ? "getOutputIndexFromCoords(coords)" : "coords"}${t === 1 ? "" : ` / ${t}`}]);
355
355
  }
356
356
  `;
357
- const u = F(e.shape, s), g = d - i;
357
+ const a = k(e.shape, s), g = d - r;
358
358
  let l = "";
359
- if (i === 0)
359
+ if (r === 0)
360
360
  return `
361
- fn ${a}Index(globalIndex : i32) -> ${f(t)}{
362
- return get${r}();
361
+ fn ${n}Index(globalIndex : i32) -> ${f(t)}{
362
+ return get${u}();
363
363
  }
364
364
 
365
- fn ${a}Coords(coords : ${p}) -> ${f(t)}{
366
- return get${r}();
365
+ fn ${n}Coords(coords : ${p}) -> ${f(t)}{
366
+ return get${u}();
367
367
  }
368
368
  `;
369
- d < 2 && u.length >= 1 ? l = "coords = 0;" : l = u.map((x) => `coords.${I(x + g)} = 0;`).join(`
369
+ d < 2 && a.length >= 1 ? l = "coords = 0;" : l = a.map((S) => `coords.${I(S + g)} = 0;`).join(`
370
370
  `);
371
371
  let $ = "";
372
- if (d < 2 && i > 0)
372
+ if (d < 2 && r > 0)
373
373
  $ = "coords";
374
374
  else if (d > 1) {
375
- const x = v(i), j = e.shape.map((X, A) => `coords.${I(A + g)}`).join(", ");
376
- $ = `${x}(${j})`;
375
+ const S = v(r), O = e.shape.map((M, _) => `coords.${I(_ + g)}`).join(", ");
376
+ $ = `${S}(${O})`;
377
377
  } else
378
378
  $ = "coords";
379
- const h = `uniforms.${n.charAt(0).toLowerCase() + n.slice(1)}Shape`, c = `${i}D`;
379
+ const h = `uniforms.${i.charAt(0).toLowerCase() + i.slice(1)}Shape`, c = `${r}D`;
380
380
  return `
381
- fn ${a}Index(globalIndex : i32) -> ${f(t)} {
381
+ fn ${n}Index(globalIndex : i32) -> ${f(t)} {
382
382
  var coords = getCoordsFromIndex(globalIndex);
383
383
  ${l}
384
- return ${f(t)}(${n}[getIndexFromCoords${c}(${$}, ${h})${t === 1 ? "" : ` / ${t}`}]);
384
+ return ${f(t)}(${i}[getIndexFromCoords${c}(${$}, ${h})${t === 1 ? "" : ` / ${t}`}]);
385
385
  }
386
386
 
387
- fn ${a}Coords(coordsIn : ${p}) -> ${f(t)} {
387
+ fn ${n}Coords(coordsIn : ${p}) -> ${f(t)} {
388
388
  var coords = coordsIn;
389
389
  ${l}
390
- return ${f(t)}(${n}[getIndexFromCoords${c}(${$}, ${h})${t === 1 ? "" : ` / ${t}`}]);
390
+ return ${f(t)}(${i}[getIndexFromCoords${c}(${$}, ${h})${t === 1 ? "" : ` / ${t}`}]);
391
391
  }
392
392
  `;
393
393
  }
394
- function G(e, s, t, o) {
395
- let n = M(e, t);
396
- return e.shape.length <= s.length && (n += R(e, s, t, o)), n;
394
+ function P(e, s, t, o) {
395
+ let i = U(e, t);
396
+ return e.shape.length <= s.length && (i += W(e, s, t, o)), i;
397
397
  }
398
- function T(e, s) {
399
- const { x: t, y: o = [], z: n = [] } = s, r = e.length, a = t.length + o.length + n.length;
400
- if (a !== r)
398
+ function R(e, s) {
399
+ const { x: t, y: o = [], z: i = [] } = s, u = e.length, n = t.length + o.length + i.length;
400
+ if (n !== u)
401
401
  return "";
402
- if (t.length === r)
403
- return `fn getOutputCoords() -> ${v(r)}{
402
+ if (t.length === u)
403
+ return `fn getOutputCoords() -> ${v(u)}{
404
404
  let globalIndex = getGlobalIndex();
405
405
  return getCoordsFromIndex(globalIndex);
406
406
  }
407
407
  `;
408
- let i = "";
409
- const d = [t, o, n];
408
+ let r = "";
409
+ const d = [t, o, i];
410
410
  for (let l = 0; l < d.length; l++) {
411
411
  const $ = d[l];
412
412
  if ($.length !== 0)
413
413
  if ($.length === 1)
414
- i += `let d${$[0]} = i32(globalId[${l}]);`;
414
+ r += `let d${$[0]} = i32(globalId[${l}]);`;
415
415
  else {
416
- const h = U($, "uniforms.outShape");
417
- i += `var index${l} = i32(globalId[${l}]);`;
416
+ const h = L($, "uniforms.outShape");
417
+ r += `var index${l} = i32(globalId[${l}]);`;
418
418
  for (let c = 0; c < h.length; c++)
419
- i += `let d${$[c]} = index${l} / ${h[c]};`, c === h.length - 1 ? i += `let d${$[c + 1]} = index${l} - d${$[c]} * ${h[c]};` : i += `index${l} = index${l} - d${$[c]} * ${h[c]};`;
419
+ r += `let d${$[c]} = index${l} / ${h[c]};`, c === h.length - 1 ? r += `let d${$[c + 1]} = index${l} - d${$[c]} * ${h[c]};` : r += `index${l} = index${l} - d${$[c]} * ${h[c]};`;
420
420
  }
421
421
  }
422
422
  const p = [];
423
- for (let l = 0; l < a; l++)
423
+ for (let l = 0; l < n; l++)
424
424
  p.push(`d${l}`);
425
- const u = v(a);
426
- let g = `fn getOutputCoords() -> ${u} {
427
- ${i}
425
+ const a = v(n);
426
+ let g = `fn getOutputCoords() -> ${a} {
427
+ ${r}
428
428
  `;
429
- return p.length === 0 ? g += `return ${u}(0); }` : g += `return ${u}(${p.join(",")}); }`, g;
429
+ return p.length === 0 ? g += `return ${a}(0); }` : g += `return ${a}(${p.join(",")}); }`, g;
430
430
  }
431
- function V(e) {
431
+ function G(e) {
432
432
  let s = "";
433
433
  switch (e) {
434
434
  case 0:
@@ -485,141 +485,64 @@ function V(e) {
485
485
  `;
486
486
  break;
487
487
  default:
488
- _(!1, () => `Unsupported ${e}D shape`);
488
+ A(!1, () => `Unsupported ${e}D shape`);
489
489
  break;
490
490
  }
491
491
  return s;
492
492
  }
493
- function E(e) {
493
+ function F(e) {
494
494
  return e.dispatch[1] === 1 && e.dispatch[2] === 1;
495
495
  }
496
- function b(e, s = 1) {
496
+ function x(e, s = 1) {
497
497
  if (e === "float32")
498
498
  return f(s, "f32");
499
499
  if (e === "int32" || e === "bool")
500
500
  return f(s, "i32");
501
501
  throw new Error(`type ${e} is not supported.`);
502
502
  }
503
- function B(e, s, t) {
504
- const o = e.length, n = b(s, t);
505
- let r = `fn setOutputAtIndex(flatIndex : i32, value : ${f(t)}) {
506
- result[flatIndex] = ${n}(value);
503
+ function V(e, s, t) {
504
+ const o = e.length, i = x(s, t);
505
+ let u = `fn setOutputAtIndex(flatIndex : i32, value : ${f(t)}) {
506
+ result[flatIndex] = ${i}(value);
507
507
  }
508
508
 
509
509
  fn setOutputAtIndexI32(flatIndex : i32, value : ${f(t, "i32")}) {
510
- result[flatIndex] = ${n}(value);
510
+ result[flatIndex] = ${i}(value);
511
511
  }
512
512
  `;
513
513
  if (o >= 2) {
514
- const a = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), i = v(o);
515
- r += `
516
- fn setOutputAtCoords(${a.map((d) => `${d} : i32`).join(", ")}, value : ${f(t)}) {
517
- let flatIndex = getOutputIndexFromCoords(${i}(${a.join(", ")}));
514
+ const n = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), r = v(o);
515
+ u += `
516
+ fn setOutputAtCoords(${n.map((d) => `${d} : i32`).join(", ")}, value : ${f(t)}) {
517
+ let flatIndex = getOutputIndexFromCoords(${r}(${n.join(", ")}));
518
518
  setOutputAtIndex(flatIndex${t === 1 ? "" : ` / ${t}`}, value);
519
519
  }
520
- fn setOutputAtCoordsI32(${a.map((d) => `${d} : i32`).join(", ")}, value : ${f(t, "i32")}) {
521
- let flatIndex = getOutputIndexFromCoords(${i}(${a.join(", ")}));
520
+ fn setOutputAtCoordsI32(${n.map((d) => `${d} : i32`).join(", ")}, value : ${f(t, "i32")}) {
521
+ let flatIndex = getOutputIndexFromCoords(${r}(${n.join(", ")}));
522
522
  setOutputAtIndexI32(flatIndex${t === 1 ? "" : ` / ${t}`}, value);
523
523
  }
524
524
  `;
525
525
  }
526
- return r;
526
+ return u;
527
527
  }
528
- function K(e) {
528
+ function B(e) {
529
529
  const s = /(\w+)\s*:\s*vec(5|6)/g;
530
530
  e = e.replace(s, (o) => "@align(16) " + o);
531
531
  const t = /vec(5|6)\s*,\s*(\w+)/g;
532
- return e = e.replace(t, (o, n, r) => `vec${n}, @align(16) ${r}`), e;
532
+ return e = e.replace(t, (o, i, u) => `vec${i}, @align(16) ${u}`), e;
533
533
  }
534
- function k(e) {
534
+ function m(e) {
535
535
  return !(e.dispatchLayout.hasOwnProperty("y") && e.dispatchLayout.y.length !== 0 || e.dispatchLayout.hasOwnProperty("z") && e.dispatchLayout.z.length !== 0);
536
536
  }
537
- /**
538
- * @license
539
- * Copyright 2019 Google LLC. All Rights Reserved.
540
- * Licensed under the Apache License, Version 2.0 (the "License");
541
- * you may not use this file except in compliance with the License.
542
- * You may obtain a copy of the License at
543
- *
544
- * http://www.apache.org/licenses/LICENSE-2.0
545
- *
546
- * Unless required by applicable law or agreed to in writing, software
547
- * distributed under the License is distributed on an "AS IS" BASIS,
548
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
549
- * See the License for the specific language governing permissions and
550
- * limitations under the License.
551
- * =============================================================================
552
- */
553
- const S = (e) => {
554
- let s = 1;
555
- for (let t = 0; t < e.length; t++)
556
- s *= e[t];
557
- return s;
558
- };
559
- function Q(e, s, t = [1, 1, 1], o = [1, 1, 1]) {
560
- const [n, r, a] = [
561
- Math.ceil(S(e.x.map((i) => s[i])) / (t[0] * o[0])),
562
- e.y ? Math.ceil(S(e.y.map((i) => s[i])) / (t[1] * o[1])) : 1,
563
- e.z ? Math.ceil(S(e.z.map((i) => s[i])) / (t[2] * o[2])) : 1
564
- ];
565
- return [n, r, a];
566
- }
567
- function ee(e, s, t, o = !1) {
568
- const n = [8, 8, 1], r = [4, 4, 1];
569
- return o || (e <= 8 && (r[1] = 1), s <= 16 && t <= 16 && (n[0] = 4)), { workgroupSize: n, elementsPerThread: r };
570
- }
571
- function te(e, s, t = !1) {
572
- if (t)
573
- return [8, 8, 1];
574
- const o = S(e.x.map((r) => s[r])), n = S(e.y.map((r) => s[r]));
575
- return o <= 4 ? [4, 16, 1] : n <= 4 ? [16, 4, 1] : [16, 16, 1];
576
- }
577
- function se(e, s, t = !1) {
578
- if (t)
579
- return [4, 4, 1];
580
- const o = S(e.x.map((r) => s[r])), n = S(e.y.map((r) => s[r]));
581
- return o <= 4 ? [1, 2, 1] : n <= 4 ? [2, 1, 1] : [2, 2, 1];
582
- }
583
- function oe(e) {
584
- return { x: e.map((s, t) => t) };
585
- }
586
- function re(e) {
587
- if (e === "float32" || e === "int32" || e === "bool" || e === "string")
588
- return 4;
589
- if (e === "complex64")
590
- return 8;
591
- throw new Error(`Unknown dtype ${e}`);
592
- }
593
- function ne() {
594
- return !!(typeof globalThis < "u" && globalThis.navigator && globalThis.navigator.gpu);
595
- }
596
- function ie(e, s) {
597
- Array.isArray(e) || (e = [e]), e.forEach((t) => {
598
- t != null && _(t.dtype !== "complex64", () => `${s} does not support complex64 tensors in the WebGPU backend.`);
599
- });
600
- }
601
- var z;
602
- (function(e) {
603
- e[e.MatMulReduceProgram = 0] = "MatMulReduceProgram", e[e.MatMulSplitKProgram = 1] = "MatMulSplitKProgram", e[e.MatMulSmallOutputSizeProgram = 2] = "MatMulSmallOutputSizeProgram", e[e.MatMulPackedProgram = 3] = "MatMulPackedProgram", e[e.MatMulMax = 4] = "MatMulMax";
604
- })(z || (z = {}));
605
537
  export {
606
- re as G,
607
- z as M,
608
- y as P,
609
- Z as a,
610
- ee as b,
611
- Q as c,
612
- H as d,
613
- v as e,
614
- oe as f,
615
- q as g,
616
- I as h,
617
- ne as i,
618
- ie as j,
619
- te as k,
620
- se as l,
621
- J as m,
622
- b as n,
623
- m as o,
538
+ b as P,
539
+ X as a,
540
+ v as b,
541
+ H as c,
542
+ I as d,
543
+ x as e,
544
+ y as f,
545
+ Y as g,
546
+ q as m,
624
547
  f as t
625
548
  };
@@ -0,0 +1,80 @@
1
+ import { l as u } from "./index-UdZhlibC.js";
2
+ /**
3
+ * @license
4
+ * Copyright 2019 Google LLC. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ * =============================================================================
17
+ */
18
+ const e = (r) => {
19
+ let t = 1;
20
+ for (let n = 0; n < r.length; n++)
21
+ t *= r[n];
22
+ return t;
23
+ };
24
+ function m(r, t, n = [1, 1, 1], a = [1, 1, 1]) {
25
+ const [o, i, f] = [
26
+ Math.ceil(e(r.x.map((c) => t[c])) / (n[0] * a[0])),
27
+ r.y ? Math.ceil(e(r.y.map((c) => t[c])) / (n[1] * a[1])) : 1,
28
+ r.z ? Math.ceil(e(r.z.map((c) => t[c])) / (n[2] * a[2])) : 1
29
+ ];
30
+ return [o, i, f];
31
+ }
32
+ function d(r, t, n, a = !1) {
33
+ const o = [8, 8, 1], i = [4, 4, 1];
34
+ return a || (r <= 8 && (i[1] = 1), t <= 16 && n <= 16 && (o[0] = 4)), { workgroupSize: o, elementsPerThread: i };
35
+ }
36
+ function p(r, t, n = !1) {
37
+ if (n)
38
+ return [8, 8, 1];
39
+ const a = e(r.x.map((i) => t[i])), o = e(r.y.map((i) => t[i]));
40
+ return a <= 4 ? [4, 16, 1] : o <= 4 ? [16, 4, 1] : [16, 16, 1];
41
+ }
42
+ function M(r, t, n = !1) {
43
+ if (n)
44
+ return [4, 4, 1];
45
+ const a = e(r.x.map((i) => t[i])), o = e(r.y.map((i) => t[i]));
46
+ return a <= 4 ? [1, 2, 1] : o <= 4 ? [2, 1, 1] : [2, 2, 1];
47
+ }
48
+ function h(r) {
49
+ return { x: r.map((t, n) => n) };
50
+ }
51
+ function x(r) {
52
+ if (r === "float32" || r === "int32" || r === "bool" || r === "string")
53
+ return 4;
54
+ if (r === "complex64")
55
+ return 8;
56
+ throw new Error(`Unknown dtype ${r}`);
57
+ }
58
+ function g() {
59
+ return !!(typeof globalThis < "u" && globalThis.navigator && globalThis.navigator.gpu);
60
+ }
61
+ function b(r, t) {
62
+ Array.isArray(r) || (r = [r]), r.forEach((n) => {
63
+ n != null && u(n.dtype !== "complex64", () => `${t} does not support complex64 tensors in the WebGPU backend.`);
64
+ });
65
+ }
66
+ var s;
67
+ (function(r) {
68
+ r[r.MatMulReduceProgram = 0] = "MatMulReduceProgram", r[r.MatMulSplitKProgram = 1] = "MatMulSplitKProgram", r[r.MatMulSmallOutputSizeProgram = 2] = "MatMulSmallOutputSizeProgram", r[r.MatMulPackedProgram = 3] = "MatMulPackedProgram", r[r.MatMulMax = 4] = "MatMulMax";
69
+ })(s || (s = {}));
70
+ export {
71
+ x as G,
72
+ s as M,
73
+ d as a,
74
+ b,
75
+ m as c,
76
+ p as d,
77
+ M as e,
78
+ h as f,
79
+ g as i
80
+ };