@genai-fi/nanogpt 0.6.3 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (140) hide show
  1. package/dist/Generator.js +11 -11
  2. package/dist/NanoGPTModel.d.ts +2 -2
  3. package/dist/NanoGPTModel.js +104 -136
  4. package/dist/{RealDiv-BYViZwhN.js → RealDiv-C4hOvYOZ.js} +26 -25
  5. package/dist/{Reshape-t7Kcikjk.js → Reshape-BLijOA8h.js} +5 -5
  6. package/dist/TeachableLLM.js +5 -5
  7. package/dist/{TiedEmbedding-9WeDwvjO.js → TiedEmbedding-BLltddza.js} +4 -4
  8. package/dist/{axis_util-Bu4h7XWV.js → axis_util-DaAl5MER.js} +3 -3
  9. package/dist/backend.d.ts +1 -0
  10. package/dist/backend.js +7 -0
  11. package/dist/backend_util-DWiwsi2N.js +749 -0
  12. package/dist/{broadcast_to-DARN-DBD.js → broadcast_to-C4v-j9yA.js} +2 -2
  13. package/dist/{concat-5aPGqw3Z.js → concat-CsHeR4zV.js} +8 -8
  14. package/dist/{dataset-pgqp-YfL.js → dataset-JDyjG3QR.js} +3 -3
  15. package/dist/{dropout-Bciw46HT.js → dropout-hpDwECTe.js} +7 -7
  16. package/dist/{gather-DjyCjmOD.js → gather-D0_gPiBz.js} +4 -4
  17. package/dist/gelu-uyHP1x1f.js +26 -0
  18. package/dist/gpgpu_math-DJm3ZTAf.js +2371 -0
  19. package/dist/index-BPPzKVdR.js +12099 -0
  20. package/dist/{index-BAzbokzv.js → index-C0dhsYom.js} +405 -389
  21. package/dist/{kernel_funcs_utils-CUxJCg0g.js → kernel_funcs_utils-CwRTFqrc.js} +31 -30
  22. package/dist/layers/BaseLayer.js +2 -2
  23. package/dist/layers/CausalSelfAttention.js +6 -6
  24. package/dist/layers/MLP.js +5 -5
  25. package/dist/layers/RMSNorm.js +3 -3
  26. package/dist/layers/RoPECache.js +4 -4
  27. package/dist/layers/TiedEmbedding.js +5 -5
  28. package/dist/layers/TransformerBlock.js +1 -1
  29. package/dist/loader/loadTransformers.js +1 -1
  30. package/dist/loader/oldZipLoad.js +5 -5
  31. package/dist/{log_sum_exp-YEo2h3gb.js → log_sum_exp-D086OgZJ.js} +15 -15
  32. package/dist/main.d.ts +2 -0
  33. package/dist/main.js +9 -5
  34. package/dist/{mat_mul-7121rsJk.js → mat_mul-1nwdPkQ_.js} +4 -4
  35. package/dist/{max-DtlIuVeW.js → max-BQc2Aj-I.js} +4 -4
  36. package/dist/{mulmat_packed_gpu-D4nKF7Je.js → mulmat_packed_gpu-Gzf3I9UV.js} +1 -1
  37. package/dist/non_max_suppression_impl-CsEgBuMA.js +134 -0
  38. package/dist/{ones-BBlSRqn1.js → ones-D63HpSF_.js} +2 -2
  39. package/dist/ops/appendCache.js +3 -3
  40. package/dist/ops/attentionMask.js +1 -1
  41. package/dist/ops/cpu/appendCache.js +8 -8
  42. package/dist/ops/cpu/attentionMask.js +9 -9
  43. package/dist/ops/cpu/fusedSoftmax.js +17 -11
  44. package/dist/ops/cpu/gatherSub.js +7 -7
  45. package/dist/ops/cpu/gelu.js +13 -13
  46. package/dist/ops/cpu/matMulGelu.js +36 -24
  47. package/dist/ops/cpu/matMulMul.js +14 -8
  48. package/dist/ops/cpu/mulDropout.js +9 -3
  49. package/dist/ops/cpu/normRMS.js +5 -5
  50. package/dist/ops/cpu/qkv.js +3 -3
  51. package/dist/ops/cpu/rope.js +5 -5
  52. package/dist/ops/cpu/scatterSub.js +11 -11
  53. package/dist/ops/fusedSoftmax.js +1 -1
  54. package/dist/ops/gatherSub.js +1 -1
  55. package/dist/ops/gelu.js +2 -2
  56. package/dist/ops/grads/attentionMask.js +1 -1
  57. package/dist/ops/grads/fusedSoftmax.js +2 -2
  58. package/dist/ops/grads/gelu.js +3 -24
  59. package/dist/ops/grads/matMulGelu.js +5 -5
  60. package/dist/ops/grads/normRMS.js +6 -6
  61. package/dist/ops/grads/qkv.js +1 -1
  62. package/dist/ops/grads/rope.js +3 -3
  63. package/dist/ops/matMulGelu.js +1 -1
  64. package/dist/ops/matMulMul.js +1 -1
  65. package/dist/ops/mulDrop.js +1 -1
  66. package/dist/ops/normRMS.js +1 -1
  67. package/dist/ops/qkv.js +1 -1
  68. package/dist/ops/rope.js +4 -4
  69. package/dist/ops/scatterSub.js +1 -1
  70. package/dist/ops/webgl/appendCache.js +1 -1
  71. package/dist/ops/webgl/attentionMask.js +1 -1
  72. package/dist/ops/webgl/fusedSoftmax.js +4 -4
  73. package/dist/ops/webgl/gatherSub.js +1 -1
  74. package/dist/ops/webgl/gelu.js +2 -2
  75. package/dist/ops/webgl/log.js +5 -5
  76. package/dist/ops/webgl/matMulGelu.js +17 -17
  77. package/dist/ops/webgl/matMulMul.js +1 -1
  78. package/dist/ops/webgl/mulDropout.js +4 -4
  79. package/dist/ops/webgl/normRMS.js +2 -2
  80. package/dist/ops/webgl/qkv.js +1 -1
  81. package/dist/ops/webgl/rope.js +1 -1
  82. package/dist/ops/webgl/scatterSub.js +1 -1
  83. package/dist/ops/webgpu/appendCache.d.ts +1 -0
  84. package/dist/ops/webgpu/appendCache.js +56 -0
  85. package/dist/ops/webgpu/attentionMask.d.ts +1 -0
  86. package/dist/ops/webgpu/attentionMask.js +64 -0
  87. package/dist/ops/webgpu/gatherSub.d.ts +1 -0
  88. package/dist/ops/webgpu/gatherSub.js +37 -0
  89. package/dist/ops/webgpu/gelu.d.ts +14 -0
  90. package/dist/ops/webgpu/gelu.js +86 -0
  91. package/dist/ops/webgpu/index.d.ts +0 -0
  92. package/dist/ops/webgpu/index.js +8 -0
  93. package/dist/ops/webgpu/normRMS.d.ts +1 -0
  94. package/dist/ops/webgpu/normRMS.js +115 -0
  95. package/dist/ops/webgpu/qkv.d.ts +1 -0
  96. package/dist/ops/webgpu/qkv.js +56 -0
  97. package/dist/ops/webgpu/rope.d.ts +1 -0
  98. package/dist/ops/webgpu/rope.js +68 -0
  99. package/dist/ops/webgpu/scatterSub.d.ts +1 -0
  100. package/dist/ops/webgpu/scatterSub.js +37 -0
  101. package/dist/{ops-C0sQEcPw.js → ops-CIQLNshk.js} +452 -503
  102. package/dist/{random_width-DWzaOgrn.js → random_width-DkYP8W8N.js} +143 -144
  103. package/dist/{range-DYsrnfiy.js → range-CYzpQY53.js} +1 -1
  104. package/dist/{reciprocal-CJQeasVa.js → reciprocal-_A9yv27J.js} +1 -1
  105. package/dist/{register_all_kernels-BfFCQAqs.js → register_all_kernels-guvSxp7M.js} +202 -200
  106. package/dist/{reshape-krWGKraP.js → reshape-BMUzc1UY.js} +3 -3
  107. package/dist/{scatter_nd_util-93ln7Hut.js → scatter_nd_util-IRBqKz_b.js} +3 -3
  108. package/dist/{selu_util-sntGesxr.js → selu_util-Dt_iuXaq.js} +6 -6
  109. package/dist/shared-BNa2q6jD.js +69 -0
  110. package/dist/{shared-Ca6iDobD.js → shared-CDu9S76h.js} +541 -606
  111. package/dist/{sin-D_h-qCSx.js → sin-Cocju-BY.js} +6 -6
  112. package/dist/{softmax-fsdtf6JC.js → softmax-GPNK3o-U.js} +3 -3
  113. package/dist/{split-eiktj-6L.js → split-CHzJjxDv.js} +4 -4
  114. package/dist/{stack-dfEEz2OY.js → stack-Dpgg_1W1.js} +2 -2
  115. package/dist/{sum-BE_Irnim.js → sum-B8wEpKsg.js} +5 -5
  116. package/dist/{tensor-Xyi595sG.js → tensor-RvZVNmg0.js} +1 -1
  117. package/dist/{tensor2d-CPEkynbH.js → tensor2d-B_kyod7_.js} +1 -1
  118. package/dist/training/AdamExt.js +1 -1
  119. package/dist/training/DatasetBuilder.js +2 -2
  120. package/dist/training/Evaluator.js +1 -1
  121. package/dist/training/FullTrainer.js +20 -20
  122. package/dist/training/Trainer.d.ts +5 -6
  123. package/dist/training/Trainer.js +59 -60
  124. package/dist/training/sparseCrossEntropy.js +4 -4
  125. package/dist/utilities/dummy.js +19 -19
  126. package/dist/utilities/generate.js +15 -16
  127. package/dist/utilities/multinomialCPU.d.ts +2 -0
  128. package/dist/utilities/multinomialCPU.js +13 -0
  129. package/dist/utilities/performance.d.ts +2 -0
  130. package/dist/utilities/performance.js +16 -0
  131. package/dist/utilities/profile.d.ts +1 -0
  132. package/dist/utilities/profile.js +9 -6
  133. package/dist/utilities/safetensors.js +2 -2
  134. package/dist/utilities/weights.js +2 -2
  135. package/dist/{variable-wSS22xj5.js → variable-DXEUOwew.js} +1 -1
  136. package/dist/webgpu_util-g13LvDIv.js +625 -0
  137. package/dist/{zeros-YJDE7oRb.js → zeros-DCPCdFGq.js} +8 -8
  138. package/package.json +2 -1
  139. package/dist/gpgpu_math-CNslybmD.js +0 -3115
  140. package/dist/norm-CzltS9Fz.js +0 -86
@@ -1,4 +1,4 @@
1
- import { E as i } from "./index-BAzbokzv.js";
1
+ import { E as i } from "./index-C0dhsYom.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2018 Google LLC. All Rights Reserved.
@@ -0,0 +1,625 @@
1
+ import { aa as F, ab as O, ac as L, Z as N, l as _ } from "./index-C0dhsYom.js";
2
+ /**
3
+ * @license
4
+ * Copyright 2019 Google LLC. All Rights Reserved.
5
+ * Licensed under the Apache License, Version 2.0 (the "License");
6
+ * you may not use this file except in compliance with the License.
7
+ * You may obtain a copy of the License at
8
+ *
9
+ * http://www.apache.org/licenses/LICENSE-2.0
10
+ *
11
+ * Unless required by applicable law or agreed to in writing, software
12
+ * distributed under the License is distributed on an "AS IS" BASIS,
13
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ * See the License for the specific language governing permissions and
15
+ * limitations under the License.
16
+ * =============================================================================
17
+ */
18
/**
 * Builds symbolic stride expressions for a group of shape axes.
 *
 * Each axis index in `e` is mapped to a component name ("x".."v") of the
 * shape expression `s`; the returned array has one entry per axis except the
 * last, where entry `i` is the product of the sizes of all later axes.
 *
 * @param {number[]} e Axis indices (each must be in 0..5).
 * @param {string} s Expression naming the shape value, e.g. "uniforms.outShape".
 * @returns {string[]} Stride expressions; empty when fewer than two axes are given.
 * @throws {Error} When any axis index is greater than 5.
 */
function U(e, s) {
  if (Math.max(...e) > 5) {
    throw new Error("Cannot symbolically compute strides for rank > 6 tensor.");
  }
  const axisCount = e.length;
  // With zero or one axis there is no stride to compute. Without this guard
  // `new Array(-1)` throws a RangeError for an empty list and a single axis
  // writes a stray "-1" property onto the result.
  if (axisCount < 2) {
    return [];
  }
  const components = "xyzwuv";
  const dims = e.map((axis) => `${s}.${components[axis]}`);
  const strides = new Array(axisCount - 1);
  // Innermost stride is simply the size of the last axis.
  strides[axisCount - 2] = dims[axisCount - 1];
  for (let i = axisCount - 3; i >= 0; --i) {
    strides[i] = `(${strides[i + 1]} * ${dims[i + 1]})`;
  }
  return strides;
}
27
/**
 * Returns shader source text that adds the value expression `s` to the
 * atomic location expression `e`.
 * When `t` is "int32" a single `atomicAdd` on the bitcast value is emitted.
 * For any other `t` the snippet is a loop that reinterprets the stored i32
 * bits as f32, adds `s`, and retries `atomicCompareExchangeWeak` until the
 * exchange succeeds.
 */
+ const H = (e, s, t) => t === "int32" ? `atomicAdd(${e}, bitcast<i32>(${s}));` : `
28
+ {
29
+ var oldValue = 0;
30
+ loop {
31
+ let newValueF32 = bitcast<f32>(oldValue) + (${s});
32
+ let newValue = bitcast<i32>(newValueF32);
33
+ let res = atomicCompareExchangeWeak(${e}, oldValue, newValue);
34
+ if res.exchanged {
35
+ break;
36
+ }
37
+ oldValue = res.old_value;
38
+ }
39
+ }`;
40
+ /**
41
+ * @license
42
+ * Copyright 2022 Google LLC. All Rights Reserved.
43
+ * Licensed under the Apache License, Version 2.0 (the "License");
44
+ * you may not use this file except in compliance with the License.
45
+ * You may obtain a copy of the License at
46
+ *
47
+ * http://www.apache.org/licenses/LICENSE-2.0
48
+ *
49
+ * Unless required by applicable law or agreed to in writing, software
50
+ * distributed under the License is distributed on an "AS IS" BASIS,
51
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
52
+ * See the License for the specific language governing permissions and
53
+ * limitations under the License.
54
+ * =============================================================================
55
+ */
56
/**
 * Numeric enum of pixel operation kinds with reverse (number -> name) lookup:
 * FROM_PIXELS = 0, DRAW = 1.
 */
var y;
(function (ops) {
  ops.FROM_PIXELS = 0;
  ops[0] = "FROM_PIXELS";
  ops.DRAW = 1;
  ops[1] = "DRAW";
})(y || (y = {}));
60
/**
 * Compiles a program into a compute pipeline on device `e`.
 *
 * The shader text is produced by `P(t, {dtype, shape}, s)` from the output
 * descriptor `o`. When the "WEBGPU_PRINT_SHADER" flag is non-empty, the shader
 * is logged if the flag is "all" or if any comma-separated entry is contained
 * in the lower-cased shader key.
 *
 * @param e Device exposing createShaderModule / createComputePipeline(Async).
 * @param s Program object (reads constructor.name and shaderKey).
 * @param t Input descriptors forwarded to `P`.
 * @param o Output descriptor (reads dtype and shape).
 * @param n When truthy, the async pipeline creation call is used.
 * @returns Whatever the chosen pipeline creation call returns.
 */
const Z = (e, s, t, o, n) => {
  const outputInfo = { dtype: o.dtype, shape: o.shape };
  const shaderSource = P(t, outputInfo, s);
  const shaderModule = e.createShaderModule({ code: shaderSource, label: s.constructor.name });

  let printFilter = L().get("WEBGPU_PRINT_SHADER");
  if (printFilter !== "") {
    printFilter = printFilter.toLowerCase();
    const wanted = printFilter.split(",");
    const shouldPrint = printFilter === "all" || wanted.some((fragment) => s.shaderKey.toLowerCase().includes(fragment));
    if (shouldPrint) {
      console.group(s.shaderKey);
      console.debug(shaderSource);
      console.groupEnd();
    }
  }

  const descriptor = {
    compute: { module: shaderModule, entryPoint: "_start" },
    label: s.constructor.name,
    layout: "auto"
  };
  if (n) {
    return e.createComputePipelineAsync(descriptor);
  }
  return e.createComputePipeline(descriptor);
};

/**
 * Returns the shader type name for `e` components of scalar type `s`:
 * the scalar itself for 1, `vecN<s>` for 2..4.
 *
 * @param {number} e Component count.
 * @param {string} [s="f32"] Scalar type name.
 * @throws {Error} For any other component count.
 */
const f = (e, s = "f32") => {
  if (e === 1) {
    return `${s}`;
  }
  if (e === 2 || e === 3 || e === 4) {
    return `vec${e}<${s}>`;
  }
  throw new Error(`${e}-component ${s} is not supported.`);
};
91
/**
 * Returns the shader type name used to hold coordinates of the given rank:
 * "i32" for rank <= 1, `vecN<i32>` for ranks 2..4, and the struct names
 * "vec5" / "vec6" for ranks 5 and 6.
 *
 * @param {number} e Rank.
 * @throws {Error} For any other rank.
 */
function v(e) {
  if (e <= 1) {
    return "i32";
  }
  switch (e) {
    case 2:
    case 3:
    case 4:
      return `vec${e}<i32>`;
    case 5:
    case 6:
      return `vec${e}`;
    default:
      throw Error(`GPU for rank ${e} is not yet supported`);
  }
}
106
/**
 * Maps a coordinate position (0..5) to its component name:
 * 0 -> "x", 1 -> "y", 2 -> "z", 3 -> "w", 4 -> "u", 5 -> "v".
 *
 * @param {number} e Coordinate position.
 * @throws {Error} When `e` is not one of the numbers 0..5.
 */
function I(e) {
  const componentNames = ["x", "y", "z", "w", "u", "v"];
  // Strict comparison keeps non-number inputs (e.g. "0") unsupported.
  const position = componentNames.findIndex((_name, candidate) => candidate === e);
  if (position === -1) {
    throw Error(`Index ${e} is not yet supported`);
  }
  return componentNames[position];
}
121
/**
 * Builds the text of the shader `fn main` header.
 * With no arguments the header takes no parameters; with one argument the
 * header declares that name as a single i32 parameter.
 * Any other argument count throws Error("Unreachable").
 */
+ function q(...e) {
122
+ let s;
123
+ switch (e.length) {
124
+ case 0:
125
+ s = `
126
+ fn main()
127
+ `;
128
+ break;
129
+ case 1:
130
+ s = `
131
+ fn main(${e[0]} : i32)
132
+ `;
133
+ break;
134
+ default:
135
+ throw Error("Unreachable");
136
+ }
137
+ return s;
138
+ }
139
/**
 * Builds the text of the `_start` shader entry point.
 * The text begins with the attribute line produced by D(s), copies the five
 * builtin inputs into the module-scope variables of the same (lower-cased)
 * names, and then calls `main(getGlobalIndex())` when `e` is truthy or
 * `main()` otherwise.
 */
+ function w(e, s) {
140
+ let t;
141
+ return t = `
142
+ ${D(s)}
143
+ fn _start(@builtin(local_invocation_id) LocalId : vec3<u32>,
144
+ @builtin(global_invocation_id) GlobalId : vec3<u32>,
145
+ @builtin(local_invocation_index) LocalIndex: u32,
146
+ @builtin(workgroup_id) WorkgroupId : vec3<u32>,
147
+ @builtin(num_workgroups) NumWorkgroups : vec3<u32>) {
148
+ localId = LocalId;
149
+ localIndex = LocalIndex;
150
+ globalId = GlobalId;
151
+ numWorkgroups = NumWorkgroups;
152
+ workgroupId = WorkgroupId;
153
+ ${e ? "main(getGlobalIndex());" : "main();"};
154
+ }
155
+ `, t;
156
+ }
157
/**
 * Returns the `@compute @workgroup_size(x, y, z)` attribute text, taking the
 * three sizes from `e.workgroupSize[0..2]`.
 */
+ function D(e) {
158
+ return `
159
+ @compute @workgroup_size(${e.workgroupSize[0]}, ${e.workgroupSize[1]}, ${e.workgroupSize[2]})
160
+ `;
161
+ }
162
/**
 * Assembles the complete shader source text for program `t`.
 *
 * `e` is the list of input descriptors (reads `shape` and `dtype`), `s` is the
 * output descriptor (reads `shape` and `dtype`), and `t` is the program object.
 * Side effect: `t.outputComponent` is set to 1 when it is falsy.
 *
 * Steps visible in the body:
 *  - Always emits the private builtin-mirror variables and `getGlobalIndex()`,
 *    whose body depends on E(t) and on the product of the workgroup sizes.
 *  - When `t.pixelsOpType != null`: emits a `Uniform` struct, one storage
 *    buffer binding (result at binding 0 for FROM_PIXELS, otherwise inBuf at
 *    binding 1), the uniform binding 2, and returns the joined text of C, the
 *    declarations, m(s.shape), t.getUserCode() and the entry point.
 *  - Otherwise: builds a `Uniforms` struct (NAN, INFINITY, per-input Shape and
 *    ShapeStrides, outShape, outShapeStrides, optional `size`, plus
 *    `t.uniforms` text), rewrites it with K, declares the result buffer at
 *    binding 0 (an atomic<i32> array when `t.atomic`), one read-only buffer per
 *    variable name at binding 1 + index, and the uniforms after them; then
 *    appends the helper sections, user code and entry point and joins them.
 */
+ function P(e, s, t) {
163
+ const o = [], n = t.workgroupSize[0] * t.workgroupSize[1] * t.workgroupSize[2];
164
+ if (t.outputComponent = t.outputComponent ? t.outputComponent : 1, o.push(`
165
+
166
+ var<private> localId: vec3<u32>;
167
+ var<private> localIndex: u32;
168
+ var<private> globalId: vec3<u32>;
169
+ var<private> numWorkgroups: vec3<u32>;
170
+ var<private> workgroupId: vec3<u32>;
171
+
172
+ // Only used when the y/z dimension of workgroup size is 1.
173
+ fn getGlobalIndex() -> i32 {
174
+ ${E(t) ? " return i32(globalId.x);" : ` return i32((workgroupId.z * numWorkgroups.x * numWorkgroups.y +
175
+ workgroupId.y * numWorkgroups.x + workgroupId.x) * ${n}u +
176
+ localIndex);
177
+ `}
178
+ }
179
+ `), t.pixelsOpType != null) {
180
+ const h = t.pixelsOpType === y.FROM_PIXELS ? `@group(0) @binding(0) var<storage, read_write> result: array<${b(s.dtype, t.outputComponent)}>;` : `@group(0) @binding(1) var<storage, read> inBuf : array<${b(e[0].dtype, t.outputComponent)}>;`, c = s.shape.length === 3 ? "vec2<i32>" : "i32";
181
+ o.push(`
182
+ struct Uniform {
183
+ outShapeStrides : ${c},
184
+ size : i32,
185
+ numChannels : i32,
186
+ alpha : f32,
187
+ };
188
+
189
+ ${h}
190
+ @group(0) @binding(2) var<uniform> uniforms: Uniform;
191
+ `);
192
+ const x = k(t);
193
+ return [
194
+ C,
195
+ o.join(`
196
+ `),
197
+ m(s.shape),
198
+ t.getUserCode(),
199
+ w(x, t)
200
+ ].join(`
201
+ `);
202
+ }
203
+ let r, a, i = "struct Uniforms { NAN : f32, INFINITY : f32, ";
204
+ t.variableNames.forEach((h, c) => {
205
+ const x = v(e[c].shape.length);
206
+ i += `${h.charAt(0).toLowerCase() + h.slice(1)}Shape : ${x}, `, r = e[c].shape.length - 1, a = v(r), i += `${h.charAt(0).toLowerCase() + h.slice(1)}ShapeStrides: ${a}, `;
207
+ });
208
+ const d = v(s.shape.length);
209
+ i += `outShape : ${d}, `, r = s.shape.length - 1, a = v(r), i += `
210
+ outShapeStrides: ${a}, `, t.size && (i += "size : i32, "), t.uniforms && (i += t.uniforms), i += "};", i = K(i), o.push(i), t.atomic ? o.push(`
211
+ @group(0) @binding(0) var<storage, read_write> result: array<atomic<i32>>;
212
+ `) : o.push(`
213
+ @group(0) @binding(0) var<storage, read_write> result: array<${b(s.dtype, t.outputComponent)}>;
214
+ `), t.variableNames.forEach((h, c) => {
215
+ o.push(`
216
+ @group(0) @binding(${1 + c}) var<storage, read> ${h}: array<${t.variableComponents ? b(e[c].dtype, t.variableComponents[c]) : b(e[c].dtype, t.outputComponent)}>;
217
+ `);
218
+ }), i !== "" && o.push(`
219
+ @group(0) @binding(${1 + t.variableNames.length}) var<uniform> uniforms: Uniforms;
220
+ `);
221
+ const p = T(s.shape, t.dispatchLayout), u = [
222
+ C,
223
+ o.join(`
224
+ `) + W,
225
+ m(s.shape),
226
+ p,
227
+ V(s.shape.length)
228
+ ];
229
+ t.atomic || u.push(B(s.shape, s.dtype, t.outputComponent)), t.variableNames.forEach((h, c) => {
230
+ u.push(`${m(e[c].shape, h)}`);
231
+ });
232
+ const g = e.map((h, c) => G(h, s.shape, t.variableComponents ? t.variableComponents[c] : t.outputComponent, t.dispatchLayout.x.length === s.shape.length)).join(`
233
+ `);
234
+ u.push(g), u.push(t.getUserCode());
235
+ const l = k(t);
236
+ return u.push(w(l, t)), u.join(`
237
+ `);
238
+ }
239
/**
 * Builds the cache key string for program `e` given its inputs and output.
 *
 * Programs with a non-null `pixelsOpType` use `shaderKey` unchanged. Otherwise
 * the key is `shaderKey`, an underscore, and then (concatenated without
 * separators) the workgroup size, the ranks of all inputs plus the output, their
 * dtypes, the variable names, the per-input results of `F` and `O` against the
 * output shape, and "flatDispatch" when `E(e)` is truthy.
 *
 * @param e Program (reads shaderKey, pixelsOpType, workgroupSize, variableNames).
 * @param s Input descriptors (reads shape and dtype).
 * @param t Output descriptor (reads shape and dtype).
 * @returns {string} The key.
 */
function J(e, s, t) {
  const baseKey = e.shaderKey;
  if (e.pixelsOpType != null) {
    return baseKey;
  }

  const shapes = [...s.map((input) => input.shape), t.shape];
  const dtypes = [...s.map((input) => input.dtype), t.dtype];

  const perInputDims = s.map((input) => F(input.shape, t.shape));
  const shapeMatches = s.map((input) => O(input.shape, t.shape)).join("_");
  const dimsPart = perInputDims.map((dims) => dims.join("_")).join(";");
  const dispatchTag = E(e) ? "flatDispatch" : "";

  const workgroupPart = e.workgroupSize ? e.workgroupSize.join(",") : "";
  const rankPart = shapes.map((shape) => shape.length).join(",");
  return `${baseKey}_${workgroupPart}${rankPart}${dtypes.join(",")}${e.variableNames.join(",")}${dimsPart}${shapeMatches}${dispatchTag}`;
}
250
/**
 * Shared shader source text used when assembling programs.
 *
 * `C` declares the `vec5` / `vec6` structs, the `coordsInBounds{2,3,4}D`
 * checks, the `getIndexFromCoords{1..6}D` flat-index helpers and the
 * bit-pattern based `isnan` / `isnanVec4` helpers.
 * `W` declares `isinf`, which compares against `uniforms.INFINITY` and so
 * needs a `uniforms` value with that member to be declared alongside it.
 */
+ const C = `
251
+ struct vec5 {x: i32, y: i32, z: i32, w: i32, u: i32};
252
+ struct vec6 {x: i32, y: i32, z: i32, w: i32, u: i32, v: i32};
253
+
254
+ // Checks whether coordinates lie within the bounds of the shape.
255
+ fn coordsInBounds2D(coord : vec2<i32>, shape : vec2<i32>) -> bool {
256
+ return all(coord >= vec2<i32>(0)) && all(coord < shape);
257
+ }
258
+ fn coordsInBounds3D(coord : vec3<i32>, shape : vec3<i32>) -> bool {
259
+ return all(coord >= vec3<i32>(0)) && all(coord < shape);
260
+ }
261
+ fn coordsInBounds4D(coord : vec4<i32>, shape : vec4<i32>) -> bool {
262
+ return all(coord >= vec4<i32>(0)) && all(coord < shape);
263
+ }
264
+
265
+ fn getIndexFromCoords1D(coord : i32, shape : i32) -> i32 {
266
+ return coord;
267
+ }
268
+ fn getIndexFromCoords2D(coords : vec2<i32>, shape : vec2<i32>) -> i32 {
269
+ return dot(coords, vec2<i32>(shape.y, 1));
270
+ }
271
+ fn getIndexFromCoords3D(coords : vec3<i32>, shape : vec3<i32>) -> i32 {
272
+ return dot(coords, vec3<i32>(shape.y * shape.z, shape.z, 1));
273
+ }
274
+ fn getIndexFromCoords4D(coords : vec4<i32>, shape : vec4<i32>) -> i32 {
275
+ return dot(coords, vec4<i32>(
276
+ shape.y * shape.z * shape.w, shape.z * shape.w, shape.w, 1));
277
+ }
278
+ fn getIndexFromCoords5D(coords : vec5, shape : vec5) -> i32 {
279
+ let shapeStrides: vec5 = vec5(shape.y * shape.z * shape.w * shape.u, shape.z * shape.w * shape.u, shape.w * shape.u, shape.u, 1);
280
+ return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u;
281
+ }
282
+ fn getIndexFromCoords6D(coords : vec6, shape : vec6) -> i32 {
283
+ let shapeStrides: vec6 = vec6(shape.y * shape.z * shape.w * shape.u * shape.v, shape.z * shape.w * shape.u * shape.v, shape.w * shape.u * shape.v, shape.u * shape.v, shape.v, 1);
284
+ return coords.x*shapeStrides.x + coords.y*shapeStrides.y + coords.z*shapeStrides.z + coords.w*shapeStrides.w + coords.u*shapeStrides.u + coords.v*shapeStrides.v;
285
+ }
286
+
287
+ // NaN defination in IEEE 754-1985 is :
288
+ // - sign = either 0 or 1.
289
+ // - biased exponent = all 1 bits.
290
+ // - fraction = anything except all 0 bits (since all 0 bits represents infinity).
291
+ // https://en.wikipedia.org/wiki/IEEE_754-1985#Representation_of_non-numbers
292
+ fn isnan(val: f32) -> bool {
293
+ let floatToUint: u32 = bitcast<u32>(val);
294
+ return (floatToUint & 0x7fffffffu) > 0x7f800000u;
295
+ }
296
+ fn isnanVec4(val : vec4<f32>) -> vec4<bool> {
297
+ let floatToUint: vec4<u32> = bitcast<vec4<u32>>(val);
298
+ return (floatToUint & vec4<u32>(0x7fffffffu)) > vec4<u32>(0x7f800000u);
299
+ }
300
+ `, W = `
301
+ fn isinf(val: f32) -> bool {
302
+ return abs(val) == uniforms.INFINITY;
303
+ }
304
+ `;
305
/**
 * Generates the text of a shader function that converts a flat index into
 * coordinates for shape `e`.
 *
 * With an empty `s` the function is named `getCoordsFromIndex` and reads
 * `uniforms.outShapeStrides`; otherwise it is named `get<S>CoordsFromIndex`
 * and reads `uniforms.<s>ShapeStrides` (first letter upper/lower-cased).
 * Rank <= 1 returns the index unchanged. When `N(e)` has a single entry the
 * two coordinates come from one divide; otherwise each coordinate is peeled
 * off in turn using the matching stride component.
 * NOTE(review): `N` is imported from the index bundle and presumably returns
 * the strides of the shape; confirm against that module.
 */
+ function m(e, s = "") {
306
+ const t = e.length, o = s !== "" ? `get${s.charAt(0).toUpperCase() + s.slice(1)}CoordsFromIndex` : "getCoordsFromIndex", n = s !== "" ? `${s.charAt(0).toLowerCase() + s.slice(1)}ShapeStrides` : "outShapeStrides";
307
+ if (t <= 1)
308
+ return `fn ${o}(index : i32) -> i32 { return index; }`;
309
+ const r = N(e), a = v(t), i = [];
310
+ for (let p = 0; p < t; p++)
311
+ i.push(`d${p}`);
312
+ if (r.length === 1)
313
+ return ` fn ${o}(index : i32) -> vec2<i32> {
314
+ let d0 = index / uniforms.${n}; let d1 = index - d0 * uniforms.${n};
315
+ return vec2<i32>(d0, d1);
316
+ }`;
317
+ let d;
318
+ return d = "var index2 = index;" + r.map((p, u) => {
319
+ const g = `let ${i[u]} = index2 / uniforms.${n}.${I(u)}`, l = u === r.length - 1 ? `let ${i[u + 1]} = index2 - ${i[u]} * uniforms.${n}.${I(u)}` : `index2 = index2 - ${i[u]} * uniforms.${n}.${I(u)}`;
320
+ return `${g}; ${l};`;
321
+ }).join(""), `
322
+ fn ${o}(index : i32) -> ${a} {
323
+ ${d}
324
+ return ${a}(${i.join(",")});
325
+ }
326
+ `;
327
+ }
328
/**
 * Generates the text of the shader accessor `get<Name>(...)` for input `e`
 * (reads `e.name` and `e.shape`), returning a value of type f(s).
 *
 * For rank 0 the accessor takes no arguments and reads element 0. Otherwise
 * it takes one i32 argument per dimension (d0, d1, ...), converts them to a
 * flat index with `getIndexFromCoords<rank>D` against `uniforms.<name>Shape`,
 * and divides that index by `s` when `s` is not 1.
 * NOTE(review): the `o === 0 && (p = "1D")` reassignment looks unreachable,
 * because the `o < 1` branch above it has already returned.
 */
+ function M(e, s) {
329
+ const t = e.name, o = e.shape.length, n = v(o), r = "get" + t.charAt(0).toUpperCase() + t.slice(1), a = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), i = a.map((u) => `${u} : i32`).join(", ");
330
+ if (o < 1)
331
+ return `
332
+ fn ${r}() -> ${f(s)} {
333
+ return ${f(s)}(${t}[0]);
334
+ }
335
+ `;
336
+ const d = `uniforms.${t.charAt(0).toLowerCase() + t.slice(1)}Shape`;
337
+ let p = `${o}D`;
338
+ return o === 0 && (p = "1D"), `
339
+ fn ${r}(${i}) -> ${f(s)} {
340
+ return ${f(s)}(${t}[getIndexFromCoords${p}(${n}(${a.join(",")}),
341
+ ${d})${s === 1 ? "" : ` / ${s}`}]);
342
+ }
343
+ `;
344
+ }
345
/**
 * Generates the text of two shader accessors for input `e` addressed in
 * output space: `get<Name>ByOutputIndex(globalIndex)` and
 * `get<Name>ByOutputCoords(coords)`, both returning type f(t).
 *
 * `s` is the output shape, `t` the component count and `o` a flag.
 *  - When `O(e.shape, s)` is truthy and `o` is truthy, the input is indexed
 *    directly by the output index / coordinates.
 *  - When the input has rank 0, both accessors call the zero-argument getter.
 *  - Otherwise the coordinates listed by `F(e.shape, s)` are zeroed (offset by
 *    the rank difference), the trailing coordinates are repacked to the input
 *    rank, and the flat index is computed against `uniforms.<name>Shape`.
 * NOTE(review): `F` and `O` are imported; they look like a broadcast-dimension
 * helper and a shape-equality check respectively; confirm in the index bundle.
 * NOTE(review): the first and last branches that assign `$` both produce
 * "coords", and the map callback parameter `X` is unused.
 */
+ function R(e, s, t, o) {
346
+ const n = e.name, r = n.charAt(0).toUpperCase() + n.slice(1), a = "get" + r + "ByOutput", i = e.shape.length, d = s.length, p = v(d);
347
+ if (O(e.shape, s) && o)
348
+ return `
349
+ fn ${a}Index(globalIndex : i32) -> ${f(t)} {
350
+ return ${f(t)}(${n}[globalIndex]);
351
+ }
352
+
353
+ fn ${a}Coords(coords : ${p}) -> ${f(t)} {
354
+ return ${f(t)}(${n}[${d > 1 ? "getOutputIndexFromCoords(coords)" : "coords"}${t === 1 ? "" : ` / ${t}`}]);
355
+ }
356
+ `;
357
+ const u = F(e.shape, s), g = d - i;
358
+ let l = "";
359
+ if (i === 0)
360
+ return `
361
+ fn ${a}Index(globalIndex : i32) -> ${f(t)}{
362
+ return get${r}();
363
+ }
364
+
365
+ fn ${a}Coords(coords : ${p}) -> ${f(t)}{
366
+ return get${r}();
367
+ }
368
+ `;
369
+ d < 2 && u.length >= 1 ? l = "coords = 0;" : l = u.map((x) => `coords.${I(x + g)} = 0;`).join(`
370
+ `);
371
+ let $ = "";
372
+ if (d < 2 && i > 0)
373
+ $ = "coords";
374
+ else if (d > 1) {
375
+ const x = v(i), j = e.shape.map((X, A) => `coords.${I(A + g)}`).join(", ");
376
+ $ = `${x}(${j})`;
377
+ } else
378
+ $ = "coords";
379
+ const h = `uniforms.${n.charAt(0).toLowerCase() + n.slice(1)}Shape`, c = `${i}D`;
380
+ return `
381
+ fn ${a}Index(globalIndex : i32) -> ${f(t)} {
382
+ var coords = getCoordsFromIndex(globalIndex);
383
+ ${l}
384
+ return ${f(t)}(${n}[getIndexFromCoords${c}(${$}, ${h})${t === 1 ? "" : ` / ${t}`}]);
385
+ }
386
+
387
+ fn ${a}Coords(coordsIn : ${p}) -> ${f(t)} {
388
+ var coords = coordsIn;
389
+ ${l}
390
+ return ${f(t)}(${n}[getIndexFromCoords${c}(${$}, ${h})${t === 1 ? "" : ` / ${t}`}]);
391
+ }
392
+ `;
393
+ }
394
/**
 * Returns the shader accessor text for input `e`: always the text from
 * `M(e, t)`, followed by the text from `R(e, s, t, o)` when the input rank
 * does not exceed the length of `s`.
 *
 * @param e Input descriptor (reads shape; forwarded to M and R).
 * @param s Output shape.
 * @param t Component count forwarded to M and R.
 * @param o Flag forwarded to R.
 * @returns {string} Concatenated accessor text.
 */
function G(e, s, t, o) {
  const directAccessor = M(e, t);
  if (e.shape.length <= s.length) {
    return directAccessor + R(e, s, t, o);
  }
  return directAccessor;
}
398
/**
 * Generates the text of the shader function `getOutputCoords()` for output
 * shape `e` and dispatch layout `s` (axis lists `x`, and optional `y`, `z`).
 *
 * Returns "" when the layout's total axis count differs from the rank of `e`.
 * When `x` alone covers every axis, the function derives the coordinates from
 * `getGlobalIndex()`. Otherwise each layout group reads its own `globalId`
 * component: a single-axis group uses it directly, and a multi-axis group
 * splits it using the stride expressions from `U(group, "uniforms.outShape")`.
 */
+ function T(e, s) {
399
+ const { x: t, y: o = [], z: n = [] } = s, r = e.length, a = t.length + o.length + n.length;
400
+ if (a !== r)
401
+ return "";
402
+ if (t.length === r)
403
+ return `fn getOutputCoords() -> ${v(r)}{
404
+ let globalIndex = getGlobalIndex();
405
+ return getCoordsFromIndex(globalIndex);
406
+ }
407
+ `;
408
+ let i = "";
409
+ const d = [t, o, n];
410
+ for (let l = 0; l < d.length; l++) {
411
+ const $ = d[l];
412
+ if ($.length !== 0)
413
+ if ($.length === 1)
414
+ i += `let d${$[0]} = i32(globalId[${l}]);`;
415
+ else {
416
+ const h = U($, "uniforms.outShape");
417
+ i += `var index${l} = i32(globalId[${l}]);`;
418
+ for (let c = 0; c < h.length; c++)
419
+ i += `let d${$[c]} = index${l} / ${h[c]};`, c === h.length - 1 ? i += `let d${$[c + 1]} = index${l} - d${$[c]} * ${h[c]};` : i += `index${l} = index${l} - d${$[c]} * ${h[c]};`;
420
+ }
421
+ }
422
+ const p = [];
423
+ for (let l = 0; l < a; l++)
424
+ p.push(`d${l}`);
425
+ const u = v(a);
426
+ let g = `fn getOutputCoords() -> ${u} {
427
+ ${i}
428
+ `;
429
+ return p.length === 0 ? g += `return ${u}(0); }` : g += `return ${u}(${p.join(",")}); }`, g;
430
+ }
431
/**
 * Generates the text of the shader function `getOutputIndexFromCoords` for
 * an output of rank `e`.
 *
 * Ranks 0 and 1 return the coordinate unchanged; ranks 2..6 combine the
 * coordinates with `uniforms.outShapeStrides`, with the last coordinate added
 * unscaled. Any other rank calls `_(!1, ...)` with an "Unsupported" message
 * and, if that call returns, yields the empty string.
 * NOTE(review): `_` is imported and looks like an assertion helper that
 * throws when its first argument is false; confirm in the index bundle.
 */
+ function V(e) {
432
+ let s = "";
433
+ switch (e) {
434
+ case 0:
435
+ case 1:
436
+ s += `
437
+ fn getOutputIndexFromCoords(coords : i32) -> i32 {
438
+ return coords;
439
+ }
440
+ `;
441
+ break;
442
+ case 2:
443
+ s += `
444
+ fn getOutputIndexFromCoords(coords : vec2<i32>) -> i32 {
445
+ return dot(coords, vec2<i32>(uniforms.outShapeStrides, 1));
446
+ }
447
+ `;
448
+ break;
449
+ case 3:
450
+ s += `
451
+ fn getOutputIndexFromCoords(coords : vec3<i32>) -> i32 {
452
+ return dot(coords, vec3<i32>(uniforms.outShapeStrides.x, uniforms.outShapeStrides.y, 1));
453
+ }
454
+ `;
455
+ break;
456
+ case 4:
457
+ s += `
458
+ fn getOutputIndexFromCoords(coords : vec4<i32>) -> i32 {
459
+ return dot(coords, vec4<i32>(
460
+ uniforms.outShapeStrides.x, uniforms.outShapeStrides.y, uniforms.outShapeStrides.z, 1));
461
+ }
462
+ `;
463
+ break;
464
+ case 5:
465
+ s += `
466
+ fn getOutputIndexFromCoords(coords : vec5) -> i32 {
467
+ return coords.x * uniforms.outShapeStrides.x +
468
+ coords.y * uniforms.outShapeStrides.y +
469
+ coords.z * uniforms.outShapeStrides.z +
470
+ coords.w * uniforms.outShapeStrides.w +
471
+ coords.u;
472
+ }
473
+ `;
474
+ break;
475
+ case 6:
476
+ s += `
477
+ fn getOutputIndexFromCoords(coords : vec6) -> i32 {
478
+ return coords.x * uniforms.outShapeStrides.x +
479
+ coords.y * uniforms.outShapeStrides.y +
480
+ coords.z * uniforms.outShapeStrides.z +
481
+ coords.w * uniforms.outShapeStrides.w +
482
+ coords.u * uniforms.outShapeStrides.u +
483
+ coords.v;
484
+ }
485
+ `;
486
+ break;
487
+ default:
488
+ _(!1, () => `Unsupported ${e}D shape`);
489
+ break;
490
+ }
491
+ return s;
492
+ }
493
/**
 * Reports whether the program dispatches along the first dimension only,
 * i.e. both `e.dispatch[1]` and `e.dispatch[2]` are exactly 1.
 *
 * @param e Program object (reads dispatch).
 * @returns {boolean}
 */
function E(e) {
  const { dispatch } = e;
  return !(dispatch[1] !== 1 || dispatch[2] !== 1);
}
496
/**
 * Maps a tensor dtype to the shader type with `s` components, via `f`:
 * "float32" uses "f32"; "int32" and "bool" use "i32".
 *
 * @param {string} e Tensor dtype.
 * @param {number} [s=1] Component count.
 * @throws {Error} For any other dtype.
 */
function b(e, s = 1) {
  switch (e) {
    case "float32":
      return f(s, "f32");
    case "int32":
    case "bool":
      return f(s, "i32");
    default:
      throw new Error(`type ${e} is not supported.`);
  }
}
503
/**
 * Generates the text of the shader output setters for an output of shape
 * `e`, dtype `s` and `t` components.
 *
 * Always emits `setOutputAtIndex` (value typed f(t)) and
 * `setOutputAtIndexI32` (value typed f(t, "i32")), both converting the value
 * to the type from b(s, t) before storing it in `result`.
 * For rank >= 2 it also emits `setOutputAtCoords` / `setOutputAtCoordsI32`,
 * which take one i32 per dimension, compute the flat index with
 * `getOutputIndexFromCoords`, and divide it by `t` when `t` is not 1.
 */
+ function B(e, s, t) {
504
+ const o = e.length, n = b(s, t);
505
+ let r = `fn setOutputAtIndex(flatIndex : i32, value : ${f(t)}) {
506
+ result[flatIndex] = ${n}(value);
507
+ }
508
+
509
+ fn setOutputAtIndexI32(flatIndex : i32, value : ${f(t, "i32")}) {
510
+ result[flatIndex] = ${n}(value);
511
+ }
512
+ `;
513
+ if (o >= 2) {
514
+ const a = ["d0", "d1", "d2", "d3", "d4", "d5"].slice(0, o), i = v(o);
515
+ r += `
516
+ fn setOutputAtCoords(${a.map((d) => `${d} : i32`).join(", ")}, value : ${f(t)}) {
517
+ let flatIndex = getOutputIndexFromCoords(${i}(${a.join(", ")}));
518
+ setOutputAtIndex(flatIndex${t === 1 ? "" : ` / ${t}`}, value);
519
+ }
520
+ fn setOutputAtCoordsI32(${a.map((d) => `${d} : i32`).join(", ")}, value : ${f(t, "i32")}) {
521
+ let flatIndex = getOutputIndexFromCoords(${i}(${a.join(", ")}));
522
+ setOutputAtIndexI32(flatIndex${t === 1 ? "" : ` / ${t}`}, value);
523
+ }
524
+ `;
525
+ }
526
+ return r;
527
+ }
528
/**
 * Inserts `@align(16)` attributes into struct declaration text around
 * `vec5` / `vec6` members.
 *
 * First pass: every `name : vec5|vec6` member is prefixed with "@align(16) ".
 * Second pass: the member that follows a `vec5|vec6` and a comma is prefixed
 * the same way, and the separator between them is rewritten as ", ".
 *
 * @param {string} e Struct declaration text.
 * @returns {string} The rewritten text.
 */
function K(e) {
  // "$&" re-inserts the whole match; "$1"/"$2" re-insert the capture groups.
  const membersAligned = e.replace(/(\w+)\s*:\s*vec(5|6)/g, "@align(16) $&");
  return membersAligned.replace(/vec(5|6)\s*,\s*(\w+)/g, "vec$1, @align(16) $2");
}
534
/**
 * Reports whether the program's dispatch layout uses only its `x` group:
 * returns false as soon as an own `y` or own `z` property with a non-empty
 * list is found, true otherwise.
 *
 * @param e Program object (reads dispatchLayout).
 * @returns {boolean}
 */
function k(e) {
  if (e.dispatchLayout.hasOwnProperty("y") && e.dispatchLayout.y.length !== 0) {
    return false;
  }
  if (e.dispatchLayout.hasOwnProperty("z") && e.dispatchLayout.z.length !== 0) {
    return false;
  }
  return true;
}
537
+ /**
538
+ * @license
539
+ * Copyright 2019 Google LLC. All Rights Reserved.
540
+ * Licensed under the Apache License, Version 2.0 (the "License");
541
+ * you may not use this file except in compliance with the License.
542
+ * You may obtain a copy of the License at
543
+ *
544
+ * http://www.apache.org/licenses/LICENSE-2.0
545
+ *
546
+ * Unless required by applicable law or agreed to in writing, software
547
+ * distributed under the License is distributed on an "AS IS" BASIS,
548
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
549
+ * See the License for the specific language governing permissions and
550
+ * limitations under the License.
551
+ * =============================================================================
552
+ */
553
/**
 * Multiplies together all elements of the indexable list `e`, front to back.
 * An empty list yields 1.
 *
 * @param {ArrayLike<number>} e Values to multiply.
 * @returns {number} The product.
 */
const S = (e) => {
  let product = 1;
  let position = 0;
  while (position < e.length) {
    product *= e[position];
    position += 1;
  }
  return product;
};
559
/**
 * Computes the three dispatch counts for a layout.
 *
 * For each layout group present, the sizes of its axes (looked up in `s`) are
 * multiplied with `S` and divided, rounding up, by the product of the matching
 * entries of `t` and `o`. A missing `y` or `z` group yields 1.
 *
 * @param e Layout with axis list `x` and optional `y`, `z`.
 * @param {number[]} s Shape the axis indices refer to.
 * @param {number[]} [t=[1,1,1]] Per-dimension divisor (first factor).
 * @param {number[]} [o=[1,1,1]] Per-dimension divisor (second factor).
 * @returns {number[]} `[countX, countY, countZ]`.
 */
function Q(e, s, t = [1, 1, 1], o = [1, 1, 1]) {
  const countFor = (axes, slot) => {
    const extent = S(axes.map((axis) => s[axis]));
    return Math.ceil(extent / (t[slot] * o[slot]));
  };
  const countX = countFor(e.x, 0);
  const countY = e.y ? countFor(e.y, 1) : 1;
  const countZ = e.z ? countFor(e.z, 2) : 1;
  return [countX, countY, countZ];
}
567
/**
 * Chooses the workgroup size and per-thread element counts from three sizes.
 *
 * Defaults are workgroupSize [8, 8, 1] and elementsPerThread [4, 4, 1].
 * Unless `o` is truthy: `e <= 8` sets elementsPerThread[1] to 1, and
 * `s <= 16 && t <= 16` sets workgroupSize[0] to 4.
 *
 * @param {number} e First size.
 * @param {number} s Second size.
 * @param {number} t Third size.
 * @param {boolean} [o=false] When truthy the defaults are returned unchanged.
 * @returns {{workgroupSize: number[], elementsPerThread: number[]}}
 */
function ee(e, s, t, o = !1) {
  const workgroupSize = [8, 8, 1];
  const elementsPerThread = [4, 4, 1];
  if (!o) {
    if (e <= 8) {
      elementsPerThread[1] = 1;
    }
    if (s <= 16 && t <= 16) {
      workgroupSize[0] = 4;
    }
  }
  return { workgroupSize, elementsPerThread };
}
571
/**
 * Chooses a workgroup size from the extents covered by the layout's `x` and
 * `y` axis groups.
 *
 * Returns [8, 8, 1] when `t` is truthy. Otherwise the extents of the `x` and
 * `y` groups (products of the matching entries of `s`, via `S`) select
 * [4, 16, 1] when the x extent is <= 4, [16, 4, 1] when the y extent is <= 4,
 * and [16, 16, 1] otherwise.
 *
 * A layout without a `y` group is treated as having a y extent of 1, the same
 * way the dispatch computation in this file treats a missing `y`; previously
 * such a layout raised a TypeError.
 *
 * @param e Layout with axis list `x` and optional `y`.
 * @param {number[]} s Shape the axis indices refer to.
 * @param {boolean} [t=false] Forces the fixed [8, 8, 1] size.
 * @returns {number[]} Workgroup size triple.
 */
function te(e, s, t = !1) {
  if (t) {
    return [8, 8, 1];
  }
  const extentX = S(e.x.map((axis) => s[axis]));
  const extentY = e.y ? S(e.y.map((axis) => s[axis])) : 1;
  if (extentX <= 4) {
    return [4, 16, 1];
  }
  if (extentY <= 4) {
    return [16, 4, 1];
  }
  return [16, 16, 1];
}
577
/**
 * Chooses per-thread element counts from the extents covered by the layout's
 * `x` and `y` axis groups.
 *
 * Returns [4, 4, 1] when `t` is truthy. Otherwise the extents of the `x` and
 * `y` groups (products of the matching entries of `s`, via `S`) select
 * [1, 2, 1] when the x extent is <= 4, [2, 1, 1] when the y extent is <= 4,
 * and [2, 2, 1] otherwise.
 *
 * A layout without a `y` group is treated as having a y extent of 1, the same
 * way the dispatch computation in this file treats a missing `y`; previously
 * such a layout raised a TypeError.
 *
 * @param e Layout with axis list `x` and optional `y`.
 * @param {number[]} s Shape the axis indices refer to.
 * @param {boolean} [t=false] Forces the fixed [4, 4, 1] counts.
 * @returns {number[]} Elements-per-thread triple.
 */
function se(e, s, t = !1) {
  if (t) {
    return [4, 4, 1];
  }
  const extentX = S(e.x.map((axis) => s[axis]));
  const extentY = e.y ? S(e.y.map((axis) => s[axis])) : 1;
  if (extentX <= 4) {
    return [1, 2, 1];
  }
  if (extentY <= 4) {
    return [2, 1, 1];
  }
  return [2, 2, 1];
}
583
/**
 * Builds a layout that places every position of `e` in the `x` group:
 * `{ x: [0, 1, ..., e.length - 1] }`.
 *
 * @param {Array} e Any array; only its positions are used.
 * @returns {{x: number[]}}
 */
function oe(e) {
  const x = e.map((_value, position) => position);
  return { x };
}
586
/**
 * Returns the per-element byte count used for a dtype: 4 for "float32",
 * "int32", "bool" and "string"; 8 for "complex64".
 *
 * @param {string} e Tensor dtype.
 * @returns {number}
 * @throws {Error} For any other dtype.
 */
function re(e) {
  switch (e) {
    case "float32":
    case "int32":
    case "bool":
    case "string":
      return 4;
    case "complex64":
      return 8;
    default:
      throw new Error(`Unknown dtype ${e}`);
  }
}
593
/**
 * Reports whether `globalThis.navigator.gpu` is present and truthy.
 *
 * @returns {boolean} false when `globalThis`, its `navigator`, or the `gpu`
 *   member is missing or falsy.
 */
function ne() {
  if (typeof globalThis === "undefined") {
    return false;
  }
  if (!globalThis.navigator) {
    return false;
  }
  return Boolean(globalThis.navigator.gpu);
}
596
/**
 * Checks that none of the given tensors has dtype "complex64".
 *
 * A single value is treated as a one-element list; null / undefined entries
 * are skipped. Each remaining entry is checked through the imported `_`
 * helper with a message naming the operation `s`.
 *
 * @param e One tensor or an array of tensors (reads dtype).
 * @param {string} s Operation name used in the message.
 */
function ie(e, s) {
  const tensors = Array.isArray(e) ? e : [e];
  for (const tensor of tensors) {
    if (tensor == null) {
      continue;
    }
    _(tensor.dtype !== "complex64", () => `${s} does not support complex64 tensors in the WebGPU backend.`);
  }
}
601
/**
 * Numeric enum of matmul program kinds with reverse (number -> name) lookup:
 * MatMulReduceProgram = 0, MatMulSplitKProgram = 1,
 * MatMulSmallOutputSizeProgram = 2, MatMulPackedProgram = 3, MatMulMax = 4.
 */
var z;
(function (kinds) {
  const names = [
    "MatMulReduceProgram",
    "MatMulSplitKProgram",
    "MatMulSmallOutputSizeProgram",
    "MatMulPackedProgram",
    "MatMulMax"
  ];
  names.forEach((name, value) => {
    kinds[name] = value;
    kinds[value] = name;
  });
})(z || (z = {}));
605
+ export {
606
+ re as G,
607
+ z as M,
608
+ y as P,
609
+ Z as a,
610
+ ee as b,
611
+ Q as c,
612
+ H as d,
613
+ v as e,
614
+ oe as f,
615
+ q as g,
616
+ I as h,
617
+ ne as i,
618
+ ie as j,
619
+ te as k,
620
+ se as l,
621
+ J as m,
622
+ b as n,
623
+ m as o,
624
+ f as t
625
+ };
@@ -1,4 +1,4 @@
1
- import { o as m, q as r, Z as l, E as c, _ as i, x as p, $ as u, g as x } from "./index-BAzbokzv.js";
1
+ import { x as m, y as r, a0 as l, E as c, a1 as i, C as p, a2 as u, j as x } from "./index-C0dhsYom.js";
2
2
  /**
3
3
  * @license
4
4
  * Copyright 2020 Google LLC. All Rights Reserved.
@@ -16,9 +16,9 @@ import { o as m, q as r, Z as l, E as c, _ as i, x as p, $ as u, g as x } from "
16
16
  * =============================================================================
17
17
  */
18
18
  function f(a, e) {
19
- const o = r(a, "real", "complex"), s = r(e, "imag", "complex");
20
- l(o.shape, s.shape, `real and imag shapes, ${o.shape} and ${s.shape}, must match in call to tf.complex().`);
21
- const n = { real: o, imag: s };
19
+ const s = r(a, "real", "complex"), o = r(e, "imag", "complex");
20
+ l(s.shape, o.shape, `real and imag shapes, ${s.shape} and ${o.shape}, must match in call to tf.complex().`);
21
+ const n = { real: s, imag: o };
22
22
  return c.runKernel(i, n);
23
23
  }
24
24
  const g = /* @__PURE__ */ m({ complex_: f });
@@ -40,11 +40,11 @@ const g = /* @__PURE__ */ m({ complex_: f });
40
40
  */
41
41
  function t(a, e = "float32") {
42
42
  if (p(a), e === "complex64") {
43
- const s = t(a, "float32"), n = t(a, "float32");
44
- return g(s, n);
43
+ const o = t(a, "float32"), n = t(a, "float32");
44
+ return g(o, n);
45
45
  }
46
- const o = u(x(a), e);
47
- return c.makeTensor(o, a, e);
46
+ const s = u(x(a), e);
47
+ return c.makeTensor(s, a, e);
48
48
  }
49
49
  export {
50
50
  g as c,
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@genai-fi/nanogpt",
3
- "version": "0.6.3",
3
+ "version": "0.7.0",
4
4
  "type": "module",
5
5
  "main": "dist/main.js",
6
6
  "types": "dist/main.d.ts",
@@ -50,6 +50,7 @@
50
50
  "dependencies": {
51
51
  "@dsnp/parquetjs": "^1.8.7",
52
52
  "@tensorflow/tfjs": "^4.22.0",
53
+ "@tensorflow/tfjs-backend-webgpu": "^4.22.0",
53
54
  "eventemitter3": "^5.0.1",
54
55
  "jszip": "^3.10.1",
55
56
  "papaparse": "^5.5.3",