@mni-ml/framework 0.0.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. package/dist/autodiff.d.ts +13 -0
  2. package/dist/autodiff.d.ts.map +1 -0
  3. package/dist/autodiff.js +91 -0
  4. package/dist/autodiff.js.map +1 -0
  5. package/dist/datasets.d.ts +16 -0
  6. package/dist/datasets.d.ts.map +1 -0
  7. package/dist/datasets.js +64 -0
  8. package/dist/datasets.js.map +1 -0
  9. package/dist/fast_ops.d.ts +23 -0
  10. package/dist/fast_ops.d.ts.map +1 -0
  11. package/dist/fast_ops.js +263 -0
  12. package/dist/fast_ops.js.map +1 -0
  13. package/dist/fast_ops_worker.d.ts +2 -0
  14. package/dist/fast_ops_worker.d.ts.map +1 -0
  15. package/dist/fast_ops_worker.js +119 -0
  16. package/dist/fast_ops_worker.js.map +1 -0
  17. package/dist/gpu_backend.d.ts +37 -0
  18. package/dist/gpu_backend.d.ts.map +1 -0
  19. package/dist/gpu_backend.js +163 -0
  20. package/dist/gpu_backend.js.map +1 -0
  21. package/dist/gpu_kernels.d.ts +74 -0
  22. package/dist/gpu_kernels.d.ts.map +1 -0
  23. package/dist/gpu_kernels.js +571 -0
  24. package/dist/gpu_kernels.js.map +1 -0
  25. package/dist/gpu_ops.d.ts +43 -0
  26. package/dist/gpu_ops.d.ts.map +1 -0
  27. package/dist/gpu_ops.js +365 -0
  28. package/dist/gpu_ops.js.map +1 -0
  29. package/dist/index.d.ts +15 -0
  30. package/dist/index.d.ts.map +1 -0
  31. package/dist/index.js +20 -0
  32. package/dist/index.js.map +1 -0
  33. package/dist/module.d.ts +23 -0
  34. package/dist/module.d.ts.map +1 -0
  35. package/dist/module.js +97 -0
  36. package/dist/module.js.map +1 -0
  37. package/dist/nn.d.ts +63 -0
  38. package/dist/nn.d.ts.map +1 -0
  39. package/dist/nn.js +234 -0
  40. package/dist/nn.js.map +1 -0
  41. package/dist/operators.d.ts +29 -0
  42. package/dist/operators.d.ts.map +1 -0
  43. package/dist/operators.js +91 -0
  44. package/dist/operators.js.map +1 -0
  45. package/dist/optimizer.d.ts +15 -0
  46. package/dist/optimizer.d.ts.map +1 -0
  47. package/dist/optimizer.js +62 -0
  48. package/dist/optimizer.js.map +1 -0
  49. package/dist/scalar.d.ts +42 -0
  50. package/dist/scalar.d.ts.map +1 -0
  51. package/dist/scalar.js +126 -0
  52. package/dist/scalar.js.map +1 -0
  53. package/dist/scalar_functions.d.ts +62 -0
  54. package/dist/scalar_functions.d.ts.map +1 -0
  55. package/dist/scalar_functions.js +127 -0
  56. package/dist/scalar_functions.js.map +1 -0
  57. package/dist/tensor.d.ts +58 -0
  58. package/dist/tensor.d.ts.map +1 -0
  59. package/dist/tensor.js +288 -0
  60. package/dist/tensor.js.map +1 -0
  61. package/dist/tensor_data.d.ts +29 -0
  62. package/dist/tensor_data.d.ts.map +1 -0
  63. package/dist/tensor_data.js +131 -0
  64. package/dist/tensor_data.js.map +1 -0
  65. package/dist/tensor_functions.d.ts +97 -0
  66. package/dist/tensor_functions.d.ts.map +1 -0
  67. package/dist/tensor_functions.js +465 -0
  68. package/dist/tensor_functions.js.map +1 -0
  69. package/dist/tensor_ops.d.ts +47 -0
  70. package/dist/tensor_ops.d.ts.map +1 -0
  71. package/dist/tensor_ops.js +249 -0
  72. package/dist/tensor_ops.js.map +1 -0
  73. package/package.json +45 -0
@@ -0,0 +1,571 @@
1
+ import * as operators from './operators.js';
2
+ export const WORKGROUP_SIZE = 256;
3
+ export const BLOCK_SIZE = 16; // 16x16 = 256 = WORKGROUP_SIZE, used for 2D tiled matmul
4
+ const MAX_DIMS = 6;
5
// ---- Operation registries ----
// WGSL statement bodies for elementwise unary operators, keyed by op name.
// Each snippet is spliced into `fn apply(x: f32) -> f32 { ... }` by the
// shader template builders below.
export const UNARY_OPS = Object.fromEntries([
    ['neg', 'return -x;'],
    ['id', 'return x;'],
    // Branches on sign via select(): exp(x)/(1+exp(x)) for x < 0 avoids exp overflow.
    ['sigmoid', 'let s = select(1.0 / (1.0 + exp(-x)), exp(x) / (1.0 + exp(x)), x < 0.0); return s;'],
    ['relu', 'return max(x, 0.0);'],
    ['exp', 'return exp(x);'],
    ['log', 'return log(x);'],
    ['inv', 'return 1.0 / x;'],
]);
15
// WGSL statement bodies for elementwise binary operators, keyed by op name.
// Each snippet becomes the body of `fn apply(a: f32, b: f32) -> f32` in the
// generated zip/reduce shaders.
export const BINARY_OPS = Object.fromEntries([
    ['add', 'return a + b;'],
    ['mul', 'return a * b;'],
    ['max', 'return max(a, b);'],
    ['lt', 'return select(0.0, 1.0, a < b);'],
    // Float equality is approximate: absolute tolerance 1e-5.
    ['eq', 'return select(0.0, 1.0, abs(a - b) < 0.00001);'],
    // Looser tolerance (1e-2) for isClose comparisons.
    ['isClose', 'return select(0.0, 1.0, abs(a - b) < 0.01);'],
]);
23
// Map TypeScript operator functions to their registry key.
// Keys are the function objects themselves (identity lookup), values are the
// corresponding UNARY_OPS / BINARY_OPS key names.
const unaryRegistry = new Map([
    [operators.neg, 'neg'],
    [operators.id, 'id'],
    [operators.sigmoid, 'sigmoid'],
    [operators.relu, 'relu'],
    [operators.exp, 'exp'],
    [operators.log, 'log'],
    [operators.inv, 'inv'],
]);
const binaryRegistry = new Map([
    [operators.add, 'add'],
    [operators.mul, 'mul'],
    [operators.max, 'max'],
    [operators.lt, 'lt'],
    [operators.eq, 'eq'],
    [operators.isClose, 'isClose'],
]);
39
/**
 * Resolve a JS unary operator function to its WGSL registry key.
 *
 * @param fn - one of the unary functions exported by operators.js
 * @returns the matching UNARY_OPS key (e.g. 'relu')
 * @throws Error when `fn` was never registered as a unary GPU op
 */
export function resolveUnaryOp(fn) {
    const key = unaryRegistry.get(fn);
    if (key === undefined) {
        throw new Error(`Unknown GPU unary op: ${fn.name || fn.toString().slice(0, 60)}`);
    }
    return key;
}
45
/**
 * Resolve a JS binary operator function to its WGSL registry key.
 *
 * @param fn - one of the binary functions exported by operators.js
 * @returns the matching BINARY_OPS key (e.g. 'add')
 * @throws Error when `fn` was never registered as a binary GPU op
 */
export function resolveBinaryOp(fn) {
    const key = binaryRegistry.get(fn);
    if (key === undefined) {
        throw new Error(`Unknown GPU binary op: ${fn.name || fn.toString().slice(0, 60)}`);
    }
    return key;
}
51
// Identity element for each reduction op, expressed as WGSL source text.
// Spliced into buildReduceShader as the accumulator's initial value.
// 'max' uses -1.0e+38, effectively -FLT_MAX for f32.
export const REDUCE_IDENTITY = Object.fromEntries([
    ['add', '0.0'],
    ['mul', '1.0'],
    ['max', '-1.0e+38'],
]);
56
// ---- WGSL utility: index helpers ported to WGSL ----
// Uses storage buffers for shape/stride arrays to avoid uniform alignment restrictions.
//
// Shared WGSL prelude prepended to every shader that needs multi-dimensional
// indexing. Provides:
//   - toIndex:          linear ordinal -> multi-dim index (row-major, last dim fastest)
//   - indexToPosition:  multi-dim index x strides -> flat storage position
//   - broadcastIndex:   project a larger tensor's index into a smaller
//                       (broadcast) tensor's index space; size-1 dims map to 0.
// NOTE(review): broadcastIndex computes `bigDims - smallDims` in u32; callers
// are assumed to guarantee smallDims <= bigDims — confirm on the TS side.
const WGSL_INDEX_HELPERS = `
const MAX_DIMS: u32 = ${MAX_DIMS}u;

fn toIndex(ordinal: u32, shape: array<u32, ${MAX_DIMS}>, dims: u32, out_idx: ptr<function, array<u32, ${MAX_DIMS}>>) {
var remaining = ordinal;
for (var i = i32(dims) - 1; i >= 0; i--) {
let d = shape[i];
(*out_idx)[i] = remaining % d;
remaining = remaining / d;
}
}

fn indexToPosition(idx: array<u32, ${MAX_DIMS}>, strd: array<u32, ${MAX_DIMS}>, dims: u32) -> u32 {
var pos: u32 = 0u;
for (var i: u32 = 0u; i < dims; i++) {
pos += idx[i] * strd[i];
}
return pos;
}

fn broadcastIndex(
bigIdx: array<u32, ${MAX_DIMS}>, bigDims: u32,
smallShape: array<u32, ${MAX_DIMS}>, smallDims: u32,
out_idx: ptr<function, array<u32, ${MAX_DIMS}>>
) {
let off = bigDims - smallDims;
for (var i: u32 = 0u; i < smallDims; i++) {
let bigI = i + off;
if (smallShape[i] == 1u) {
(*out_idx)[i] = 0u;
} else {
(*out_idx)[i] = bigIdx[bigI];
}
}
}
`;
94
// ---- Shader template builders ----
/**
 * Aligned map: shapes & strides match, simple 1:1 element mapping.
 *
 * Builds a WGSL compute shader that applies `opBody` elementwise:
 * out_data[i] = apply(in_data[i]) for i < params.size, one thread per element.
 *
 * @param opBody - WGSL statement body for `fn apply(x: f32) -> f32`,
 *   typically a value from UNARY_OPS
 * @returns complete WGSL shader source as a string
 */
export function buildAlignedMapShader(opBody) {
    // Out-of-range threads (last workgroup) exit early via the size guard.
    return `
@group(0) @binding(0) var<storage, read> in_data: array<f32>;
@group(0) @binding(1) var<storage, read_write> out_data: array<f32>;

struct Params { size: u32 }
@group(0) @binding(2) var<uniform> params: Params;

fn apply(x: f32) -> f32 { ${opBody} }

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.size) { return; }
out_data[i] = apply(in_data[i]);
}
`;
}
116
/**
 * Broadcast map: output and input may differ in shape.
 * Uses storage buffer for params to avoid WGSL uniform array alignment rules.
 *
 * Each thread owns one output ordinal: it is decomposed into a multi-dim
 * output index, projected into the input's (possibly smaller / size-1) shape
 * via broadcastIndex, and both are converted to flat positions via strides.
 *
 * @param opBody - WGSL statement body for `fn apply(x: f32) -> f32`
 * @returns complete WGSL shader source as a string
 */
export function buildBroadcastMapShader(opBody) {
    return `
${WGSL_INDEX_HELPERS}

@group(0) @binding(0) var<storage, read> in_data: array<f32>;
@group(0) @binding(1) var<storage, read_write> out_data: array<f32>;

struct Params {
out_size: u32,
out_dims: u32,
in_dims: u32,
_pad: u32,
out_shape: array<u32, ${MAX_DIMS}>,
out_strides: array<u32, ${MAX_DIMS}>,
in_shape: array<u32, ${MAX_DIMS}>,
in_strides: array<u32, ${MAX_DIMS}>,
}
@group(0) @binding(2) var<storage, read> params: Params;

fn apply(x: f32) -> f32 { ${opBody} }

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.out_size) { return; }

var outIdx: array<u32, ${MAX_DIMS}>;
toIndex(i, params.out_shape, params.out_dims, &outIdx);

var inIdx: array<u32, ${MAX_DIMS}>;
broadcastIndex(outIdx, params.out_dims, params.in_shape, params.in_dims, &inIdx);

let outPos = indexToPosition(outIdx, params.out_strides, params.out_dims);
let inPos = indexToPosition(inIdx, params.in_strides, params.in_dims);
out_data[outPos] = apply(in_data[inPos]);
}
`;
}
158
/**
 * Aligned zip: all three tensors share shape & strides.
 *
 * Builds a WGSL compute shader applying a binary op elementwise:
 * out_data[i] = apply(a_data[i], b_data[i]), one thread per element.
 *
 * @param opBody - WGSL statement body for `fn apply(a: f32, b: f32) -> f32`,
 *   typically a value from BINARY_OPS
 * @returns complete WGSL shader source as a string
 */
export function buildAlignedZipShader(opBody) {
    return `
@group(0) @binding(0) var<storage, read> a_data: array<f32>;
@group(0) @binding(1) var<storage, read> b_data: array<f32>;
@group(0) @binding(2) var<storage, read_write> out_data: array<f32>;

struct Params { size: u32 }
@group(0) @binding(3) var<uniform> params: Params;

fn apply(a: f32, b: f32) -> f32 { ${opBody} }

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.size) { return; }
out_data[i] = apply(a_data[i], b_data[i]);
}
`;
}
180
/**
 * Broadcast zip: output, a, b may differ in shape.
 * Uses storage buffer for params to avoid WGSL uniform array alignment rules.
 *
 * Each thread handles one output ordinal: the output index is projected into
 * both operands' index spaces with broadcastIndex (size-1 dims read position
 * 0), then all three flat positions come from stride products.
 *
 * @param opBody - WGSL statement body for `fn apply(a: f32, b: f32) -> f32`
 * @returns complete WGSL shader source as a string
 */
export function buildBroadcastZipShader(opBody) {
    return `
${WGSL_INDEX_HELPERS}

@group(0) @binding(0) var<storage, read> a_data: array<f32>;
@group(0) @binding(1) var<storage, read> b_data: array<f32>;
@group(0) @binding(2) var<storage, read_write> out_data: array<f32>;

struct Params {
out_size: u32,
out_dims: u32,
a_dims: u32,
b_dims: u32,
out_shape: array<u32, ${MAX_DIMS}>,
out_strides: array<u32, ${MAX_DIMS}>,
a_shape: array<u32, ${MAX_DIMS}>,
a_strides: array<u32, ${MAX_DIMS}>,
b_shape: array<u32, ${MAX_DIMS}>,
b_strides: array<u32, ${MAX_DIMS}>,
}
@group(0) @binding(3) var<storage, read> params: Params;

fn apply(a: f32, b: f32) -> f32 { ${opBody} }

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(@builtin(global_invocation_id) gid: vec3u) {
let i = gid.x;
if (i >= params.out_size) { return; }

var outIdx: array<u32, ${MAX_DIMS}>;
toIndex(i, params.out_shape, params.out_dims, &outIdx);

var aIdx: array<u32, ${MAX_DIMS}>;
broadcastIndex(outIdx, params.out_dims, params.a_shape, params.a_dims, &aIdx);

var bIdx: array<u32, ${MAX_DIMS}>;
broadcastIndex(outIdx, params.out_dims, params.b_shape, params.b_dims, &bIdx);

let outPos = indexToPosition(outIdx, params.out_strides, params.out_dims);
let aPos = indexToPosition(aIdx, params.a_strides, params.a_dims);
let bPos = indexToPosition(bIdx, params.b_strides, params.b_dims);
out_data[outPos] = apply(a_data[aPos], b_data[bPos]);
}
`;
}
229
/**
 * Sum practice: block-level partial sums using shared memory.
 * Input: array of length size. Output: array of length ceil(size / WORKGROUP_SIZE).
 * Each workgroup sums WORKGROUP_SIZE contiguous elements into one output cell.
 *
 * Classic two-phase pattern: (1) each thread stages one element into
 * workgroup-shared `sdata` (0.0 past the end), (2) a tree reduction halves
 * the active stride each step, with a barrier between steps; thread 0 writes
 * the workgroup's partial sum.
 *
 * @returns complete WGSL shader source as a string
 */
export function buildSumPracticeShader() {
    return `
const BLOCK_DIM: u32 = ${WORKGROUP_SIZE}u;
var<workgroup> sdata: array<f32, ${WORKGROUP_SIZE}>;

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read_write> result: array<f32>;

struct Params { size: u32 }
@group(0) @binding(2) var<uniform> params: Params;

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(
@builtin(local_invocation_index) tid: u32,
@builtin(workgroup_id) wid: vec3u,
) {
let global_idx = wid.x * BLOCK_DIM + tid;

if (global_idx < params.size) {
sdata[tid] = a[global_idx];
} else {
sdata[tid] = 0.0;
}
workgroupBarrier();

for (var stride = BLOCK_DIM / 2u; stride > 0u; stride = stride >> 1u) {
if (tid < stride) {
sdata[tid] = sdata[tid] + sdata[tid + stride];
}
workgroupBarrier();
}

if (tid == 0u) {
result[wid.x] = sdata[0];
}
}
`;
}
272
/**
 * General reduce along one dimension.
 * One workgroup per output element. Threads cooperatively reduce
 * the reduction dimension using shared memory.
 * Uses storage buffer for params to avoid WGSL uniform array alignment rules.
 *
 * Phase 1: each thread strides through the reduce dimension (step BLOCK_DIM)
 * accumulating into a private register seeded with `identity`.
 * Phase 2: tree reduction over `sdata` combines the per-thread partials.
 * NOTE(review): the output index is copied directly into the input index with
 * only the reduce dim zeroed, which assumes out_dims == a_dims (keepdims-style
 * output with out_shape[reduce_dim] == 1) — confirm against the TS caller.
 *
 * @param opBody - WGSL statement body for `fn apply(a: f32, b: f32) -> f32`
 * @param identity - WGSL literal text for the op's identity (see REDUCE_IDENTITY)
 * @returns complete WGSL shader source as a string
 */
export function buildReduceShader(opBody, identity) {
    return `
${WGSL_INDEX_HELPERS}

const BLOCK_DIM: u32 = ${WORKGROUP_SIZE}u;
var<workgroup> sdata: array<f32, ${WORKGROUP_SIZE}>;

@group(0) @binding(0) var<storage, read> a_data: array<f32>;
@group(0) @binding(1) var<storage, read_write> out_data: array<f32>;

struct Params {
out_size: u32,
out_dims: u32,
a_dims: u32,
reduce_dim: u32,
reduce_dim_size: u32,
reduce_stride: u32,
_pad0: u32,
_pad1: u32,
out_shape: array<u32, ${MAX_DIMS}>,
out_strides: array<u32, ${MAX_DIMS}>,
a_shape: array<u32, ${MAX_DIMS}>,
a_strides: array<u32, ${MAX_DIMS}>,
}
@group(0) @binding(2) var<storage, read> params: Params;

fn apply(a: f32, b: f32) -> f32 { ${opBody} }

@compute @workgroup_size(${WORKGROUP_SIZE})
fn main(
@builtin(local_invocation_index) tid: u32,
@builtin(workgroup_id) wid: vec3u,
) {
let out_idx = wid.x;
if (out_idx >= params.out_size) { return; }

var outMI: array<u32, ${MAX_DIMS}>;
toIndex(out_idx, params.out_shape, params.out_dims, &outMI);

var aIdx: array<u32, ${MAX_DIMS}>;
for (var d: u32 = 0u; d < params.a_dims; d++) {
aIdx[d] = outMI[d];
}
aIdx[params.reduce_dim] = 0u;
let base_pos = indexToPosition(aIdx, params.a_strides, params.a_dims);

var local_acc: f32 = ${identity};
for (var j = tid; j < params.reduce_dim_size; j += BLOCK_DIM) {
local_acc = apply(local_acc, a_data[base_pos + j * params.reduce_stride]);
}
sdata[tid] = local_acc;
workgroupBarrier();

for (var s = BLOCK_DIM / 2u; s > 0u; s = s >> 1u) {
if (tid < s) {
sdata[tid] = apply(sdata[tid], sdata[tid + s]);
}
workgroupBarrier();
}

if (tid == 0u) {
let outPos = indexToPosition(outMI, params.out_strides, params.out_dims);
out_data[outPos] = sdata[0];
}
}
`;
}
345
/**
 * Tiled matrix multiplication using workgroup shared memory.
 * Dispatched as 3D: (ceil(N/BLOCK), ceil(M/BLOCK), batchSize).
 * Each 16x16 workgroup computes one output tile, loading tiles of A and B
 * into shared memory to satisfy:
 * - all data read from shared memory (not global) during accumulation
 * - each global cell of A and B read exactly once
 * - each thread writes to global memory exactly once
 * Supports arbitrary broadcast batch dimensions via stride-based indexing.
 *
 * Layout of the shader body:
 * 1. Decompose wid.z into multi-dim output batch indices (toIndex).
 * 2. Project those into A's and B's batch spaces (broadcastIndex) and fold
 *    with strides into per-tensor base offsets.
 * 3. Loop over ceil(K/BLOCK) tiles: stage one AxB tile pair into shared
 *    memory (zero-padded at edges), barrier, accumulate the partial dot
 *    product, barrier again before the next tile overwrite.
 * 4. Bounds-checked single global write of the accumulator.
 *
 * @returns complete WGSL shader source as a string
 */
export function buildMatMulShader() {
    // Shader-local aliases: BLOCK is the tile edge, SHARED the flattened
    // (BLOCK x BLOCK) shared-memory tile size.
    const BLOCK = BLOCK_SIZE;
    const SHARED = BLOCK * BLOCK;
    return `
${WGSL_INDEX_HELPERS}

const BLOCK: u32 = ${BLOCK}u;
var<workgroup> a_shared: array<f32, ${SHARED}>;
var<workgroup> b_shared: array<f32, ${SHARED}>;

@group(0) @binding(0) var<storage, read> a_data: array<f32>;
@group(0) @binding(1) var<storage, read> b_data: array<f32>;
@group(0) @binding(2) var<storage, read_write> out_data: array<f32>;

struct Params {
batch_size: u32,
M: u32,
N: u32,
K: u32,
out_dims: u32,
a_dims: u32,
b_dims: u32,
_pad: u32,
out_shape: array<u32, ${MAX_DIMS}>,
out_strides: array<u32, ${MAX_DIMS}>,
a_shape: array<u32, ${MAX_DIMS}>,
a_strides: array<u32, ${MAX_DIMS}>,
b_shape: array<u32, ${MAX_DIMS}>,
b_strides: array<u32, ${MAX_DIMS}>,
}
@group(0) @binding(3) var<storage, read> params: Params;

@compute @workgroup_size(${BLOCK}, ${BLOCK}, 1)
fn main(
@builtin(local_invocation_id) lid: vec3u,
@builtin(workgroup_id) wid: vec3u,
) {
let tx = lid.x;
let ty = lid.y;
let row = wid.y * BLOCK + ty;
let col = wid.x * BLOCK + tx;
let batch = wid.z;

let batch_dims = params.out_dims - 2u;

// Decompose linear batch index into multi-dim output batch indices
var batch_idx: array<u32, ${MAX_DIMS}>;
if (batch_dims > 0u) {
var batch_shape: array<u32, ${MAX_DIMS}>;
for (var d: u32 = 0u; d < batch_dims; d++) {
batch_shape[d] = params.out_shape[d];
}
toIndex(batch, batch_shape, batch_dims, &batch_idx);
}

// Output base offset from batch indices
var out_base: u32 = 0u;
for (var d: u32 = 0u; d < batch_dims; d++) {
out_base += batch_idx[d] * params.out_strides[d];
}
let out_stride_row = params.out_strides[params.out_dims - 2u];
let out_stride_col = params.out_strides[params.out_dims - 1u];

// Broadcast batch indices into a's batch space and compute base offset
let a_batch_dims = params.a_dims - 2u;
var a_batch_idx: array<u32, ${MAX_DIMS}>;
if (a_batch_dims > 0u) {
var a_batch_shape: array<u32, ${MAX_DIMS}>;
for (var d: u32 = 0u; d < a_batch_dims; d++) {
a_batch_shape[d] = params.a_shape[d];
}
broadcastIndex(batch_idx, batch_dims, a_batch_shape, a_batch_dims, &a_batch_idx);
}
var a_base: u32 = 0u;
for (var d: u32 = 0u; d < a_batch_dims; d++) {
a_base += a_batch_idx[d] * params.a_strides[d];
}
let a_stride_row = params.a_strides[params.a_dims - 2u];
let a_stride_col = params.a_strides[params.a_dims - 1u];

// Broadcast batch indices into b's batch space and compute base offset
let b_batch_dims = params.b_dims - 2u;
var b_batch_idx: array<u32, ${MAX_DIMS}>;
if (b_batch_dims > 0u) {
var b_batch_shape: array<u32, ${MAX_DIMS}>;
for (var d: u32 = 0u; d < b_batch_dims; d++) {
b_batch_shape[d] = params.b_shape[d];
}
broadcastIndex(batch_idx, batch_dims, b_batch_shape, b_batch_dims, &b_batch_idx);
}
var b_base: u32 = 0u;
for (var d: u32 = 0u; d < b_batch_dims; d++) {
b_base += b_batch_idx[d] * params.b_strides[d];
}
let b_stride_row = params.b_strides[params.b_dims - 2u];
let b_stride_col = params.b_strides[params.b_dims - 1u];

// Tiled matmul: each tile loads BLOCK x BLOCK elements into shared memory
var acc: f32 = 0.0;
let num_tiles = (params.K + BLOCK - 1u) / BLOCK;

for (var t: u32 = 0u; t < num_tiles; t++) {
// Load A[row, t*BLOCK + tx] into a_shared[ty][tx]
let a_col = t * BLOCK + tx;
if (row < params.M && a_col < params.K) {
a_shared[ty * BLOCK + tx] = a_data[a_base + row * a_stride_row + a_col * a_stride_col];
} else {
a_shared[ty * BLOCK + tx] = 0.0;
}

// Load B[t*BLOCK + ty, col] into b_shared[ty][tx]
let b_row = t * BLOCK + ty;
if (b_row < params.K && col < params.N) {
b_shared[ty * BLOCK + tx] = b_data[b_base + b_row * b_stride_row + col * b_stride_col];
} else {
b_shared[ty * BLOCK + tx] = 0.0;
}

workgroupBarrier();

// Accumulate partial dot products from shared memory
for (var k: u32 = 0u; k < BLOCK; k++) {
acc += a_shared[ty * BLOCK + k] * b_shared[k * BLOCK + tx];
}

workgroupBarrier();
}

// Single write to global memory per thread
if (row < params.M && col < params.N) {
out_data[out_base + row * out_stride_row + col * out_stride_col] = acc;
}
}
`;
}
490
// ---- Tensor core matmul via experimental subgroup matrix ----
export const TC_TILE = 8; // Apple Metal simdgroup_matrix tile dimension
/**
 * Experimental tensor-core-accelerated matmul shader using the Dawn/Chrome
 * `chromium_experimental_subgroup_matrix` extension.
 *
 * On Apple Silicon (Metal), this maps to `simdgroup_matrix` 8x8 f32 hardware
 * instructions. On Vulkan with VK_KHR_cooperative_matrix it maps to the
 * equivalent SPIR-V ops.
 *
 * @param workgroupX - must equal the device's maxSubgroupSize (e.g. 64 on
 *   Apple M-series via Dawn, 32 or 64 on Vulkan). The WebGPU spec requires
 *   the x-dimension of workgroup_size to be a multiple of maxSubgroupSize
 *   when the shader uses subgroup matrices.
 * @returns complete WGSL shader source as a string
 *
 * Constraints (checked by the caller on the TS side):
 * - All tensors must be contiguous (natural row-major strides)
 * - M, N, K must all be multiples of TC_TILE (8)
 * - Batch dimensions of A and B must match exactly (no broadcasting)
 *
 * Dispatch: (N/TC_TILE, M/TC_TILE, batchSize).
 * Each workgroup computes one 8x8 output tile by tiling over the K dimension
 * with subgroupMatrixLoad / MultiplyAccumulate / Store.
 *
 * NOTE(review): because strides are assumed contiguous, base offsets are
 * computed directly as batch * M * K etc. rather than via the stride arrays
 * used by buildMatMulShader; the `false` argument to load/store selects
 * row-major (non-transposed) layout per the extension proposal — verify
 * against the targeted Dawn version, the extension is unstable.
 */
export function buildTensorCoreMatMulShader(workgroupX) {
    return `
enable chromium_experimental_subgroup_matrix;

const TILE: u32 = ${TC_TILE}u;

@group(0) @binding(0) var<storage, read> a_data: array<f32>;
@group(0) @binding(1) var<storage, read> b_data: array<f32>;
@group(0) @binding(2) var<storage, read_write> out_data: array<f32>;

struct Params {
batch_size: u32,
M: u32,
N: u32,
K: u32,
}
@group(0) @binding(3) var<uniform> params: Params;

@compute @workgroup_size(${workgroupX}, 1, 1)
fn main(@builtin(workgroup_id) wid: vec3u) {
let tile_row = wid.y * TILE;
let tile_col = wid.x * TILE;
let batch = wid.z;

let a_base = batch * params.M * params.K;
let b_base = batch * params.K * params.N;
let out_base = batch * params.M * params.N;

// Zero-initialise 8x8 accumulator distributed across the subgroup
var acc = subgroup_matrix_result<f32, 8, 8>(0.0);

// Tile over the shared K dimension
let num_tiles = params.K / TILE;
for (var t: u32 = 0u; t < num_tiles; t++) {
let k = t * TILE;

// Load 8x8 tile of A (row-major, stride = K columns)
let a_tile = subgroupMatrixLoad<subgroup_matrix_left<f32, 8, 8>>(
&a_data, a_base + tile_row * params.K + k, false, params.K
);

// Load 8x8 tile of B (row-major, stride = N columns)
let b_tile = subgroupMatrixLoad<subgroup_matrix_right<f32, 8, 8>>(
&b_data, b_base + k * params.N + tile_col, false, params.N
);

// Hardware multiply-accumulate: acc += a_tile * b_tile
acc = subgroupMatrixMultiplyAccumulate(a_tile, b_tile, acc);
}

// Single store of the 8x8 result tile to global memory
subgroupMatrixStore(
&out_data, out_base + tile_row * params.N + tile_col, acc, false, params.N
);
}
`;
}
571
+ //# sourceMappingURL=gpu_kernels.js.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"gpu_kernels.js","sourceRoot":"","sources":["../src/gpu_kernels.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,SAAS,MAAM,gBAAgB,CAAC;AAE5C,MAAM,CAAC,MAAM,cAAc,GAAG,GAAG,CAAC;AAClC,MAAM,CAAC,MAAM,UAAU,GAAG,EAAE,CAAC,CAAC,yDAAyD;AACvF,MAAM,QAAQ,GAAG,CAAC,CAAC;AAEnB,iCAAiC;AAEjC,MAAM,CAAC,MAAM,SAAS,GAA2B;IAC7C,GAAG,EAAM,YAAY;IACrB,EAAE,EAAO,WAAW;IACpB,OAAO,EAAE,oFAAoF;IAC7F,IAAI,EAAK,qBAAqB;IAC9B,GAAG,EAAM,gBAAgB;IACzB,GAAG,EAAM,gBAAgB;IACzB,GAAG,EAAM,iBAAiB;CAC7B,CAAC;AAEF,MAAM,CAAC,MAAM,UAAU,GAA2B;IAC9C,GAAG,EAAM,eAAe;IACxB,GAAG,EAAM,eAAe;IACxB,GAAG,EAAM,mBAAmB;IAC5B,EAAE,EAAO,iCAAiC;IAC1C,EAAE,EAAO,gDAAgD;IACzD,OAAO,EAAE,6CAA6C;CACzD,CAAC;AAEF,2DAA2D;AAC3D,MAAM,aAAa,GAAG,IAAI,GAAG,EAAoB,CAAC;AAClD,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAC5C,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,EAAE,EAAO,IAAI,CAAC,CAAC;AAC3C,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;AAChD,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,IAAI,EAAK,MAAM,CAAC,CAAC;AAC7C,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAC5C,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAC5C,aAAa,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAE5C,MAAM,cAAc,GAAG,IAAI,GAAG,EAAoB,CAAC;AACnD,cAAc,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAC7C,cAAc,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAC7C,cAAc,CAAC,GAAG,CAAC,SAAS,CAAC,GAAG,EAAM,KAAK,CAAC,CAAC;AAC7C,cAAc,CAAC,GAAG,CAAC,SAAS,CAAC,EAAE,EAAO,IAAI,CAAC,CAAC;AAC5C,cAAc,CAAC,GAAG,CAAC,SAAS,CAAC,EAAE,EAAO,IAAI,CAAC,CAAC;AAC5C,cAAc,CAAC,GAAG,CAAC,SAAS,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;AAEjD,MAAM,UAAU,cAAc,CAAC,EAAY;IACvC,MAAM,IAAI,GAAG,aAAa,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACnC,IAAI,IAAI;QAAE,OAAO,IAAI,CAAC;IACtB,MAAM,IAAI,KAAK,CAAC,yBAAyB,EAAE,CAAC,IAAI,IAAI,EAAE,CAAC,QAAQ,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC;AACtF,CAAC;AAED,MAAM,UAAU,eAAe,CAAC,EAAY;IACxC,MAAM,IAAI,GAAG,cAAc,CAAC,GAAG,CAAC,EAAE,CAAC,CAAC;IACpC,IAAI,IAAI;QAAE,OAAO,IAAI,CAAC;IACtB,MAAM,IAAI,KAAK,CAAC,0BAA0B,EAAE,CAAC,IAAI,IAAI,EAAE,CAAC,QAAQ
,EAAE,CAAC,KAAK,CAAC,CAAC,EAAE,EAAE,CAAC,EAAE,CAAC,CAAC;AACvF,CAAC;AAED,MAAM,CAAC,MAAM,eAAe,GAA2B;IACnD,GAAG,EAAE,KAAK;IACV,GAAG,EAAE,KAAK;IACV,GAAG,EAAE,UAAU;CAClB,CAAC;AAEF,uDAAuD;AACvD,wFAAwF;AAExF,MAAM,kBAAkB,GAAG;wBACH,QAAQ;;6CAEa,QAAQ,mDAAmD,QAAQ;;;;;;;;;qCAS3E,QAAQ,uBAAuB,QAAQ;;;;;;;;;yBASnD,QAAQ;6BACJ,QAAQ;wCACG,QAAQ;;;;;;;;;;;;CAY/C,CAAC;AAEF,qCAAqC;AAErC;;GAEG;AACH,MAAM,UAAU,qBAAqB,CAAC,MAAc;IAChD,OAAO;;;;;;;4BAOiB,MAAM;;2BAEP,cAAc;;;;;;CAMxC,CAAC;AACF,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,uBAAuB,CAAC,MAAc;IAClD,OAAO;EACT,kBAAkB;;;;;;;;;;8BAUU,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;;;;4BAIV,MAAM;;2BAEP,cAAc;;;;;6BAKZ,QAAQ;;;4BAGT,QAAQ;;;;;;;CAOnC,CAAC;AACF,CAAC;AAED;;GAEG;AACH,MAAM,UAAU,qBAAqB,CAAC,MAAc;IAChD,OAAO;;;;;;;;oCAQyB,MAAM;;2BAEf,cAAc;;;;;;CAMxC,CAAC;AACF,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,uBAAuB,CAAC,MAAc;IAClD,OAAO;EACT,kBAAkB;;;;;;;;;;;8BAWU,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;;;;oCAIF,MAAM;;2BAEf,cAAc;;;;;6BAKZ,QAAQ;;;2BAGV,QAAQ;;;2BAGR,QAAQ;;;;;;;;CAQlC,CAAC;AACF,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,sBAAsB;IAClC,OAAO;yBACc,cAAc;mCACJ,cAAc;;;;;;;;2BAQtB,cAAc;;;;;;;;;;;;;;;;;;;;;;;;;CAyBxC,CAAC;AACF,CAAC;AAED;;;;;GAKG;AACH,MAAM,UAAU,iBAAiB,CAAC,MAAc,EAAE,QAAgB;IAC9D,OAAO;EACT,kBAAkB;;yBAEK,cAAc;mCACJ,cAAc;;;;;;;;;;;;;;8BAcnB,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;;;;oCAIF,MAAM;;2BAEf,cAAc;;;;;;;;4BAQb,QAAQ;;;2BAGT,QAAQ;;;;;;;2BAOR,QAAQ;;;;;;;;;;;;;;;;;;;CAmBlC,CAAC;AACF,CAAC;AAED;;;;;;;;;GASG;AACH,MAAM,UAAU,iBAAiB;IAC7B,MAAM,KAAK,GAAG,UAAU,CAAC;IACzB,MAAM,MAAM,GAAG,KAAK,GAAG,KAAK,CAAC;IAC7B,OAAO;EACT,kBAAkB;;qBAEC,KAAK;sCACY,MAAM;sCACN,MAAM;;;;;;;;;;;;;;;8BAed,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;8BACR,QAAQ;;;;2BAIX,KAAK,KAAK,KAAK;;;;;;;;;;;;;;gCAcV,QAAQ;;sCAEF,QAAQ;;;;;;;;;;;;;;;;;kCAiBZ,QAAQ;;wCAEF,QAAQ;;;;;;;;;;;;;;;kCAed,QAAQ;;wCAEF,QAAQ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAiD/C,CAAC;AACF,CAAC;AAED,gEAAgE;AAEhE,MAAM,CAAC,MAAM,OAAO,GAAG,CAAC,CAAC,CAAC,8CAA8C;AAExE;;;;;;;;;;;;;;;;;;;;;GAqBG;AACH,MAAM,UAAU,2B
AA2B,CAAC,UAAkB;IAC1D,OAAO;;;oBAGS,OAAO;;;;;;;;;;;;;;2BAcA,UAAU;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;CAqCpC,CAAC;AACF,CAAC"}
@@ -0,0 +1,43 @@
1
import type { Storage, Shape, Strides } from './tensor_data.js';
/**
 * Practice sum kernel: given array of length `size`, produce ceil(size / WORKGROUP_SIZE)
 * partial sums. Each workgroup sums WORKGROUP_SIZE contiguous elements using shared memory.
 *
 * @param out - destination storage for the partial sums
 * @param a - source storage of length `size`
 * @param size - number of input elements to sum
 */
export declare function _sumPractice(out: Storage, a: Storage, size: number): Promise<void>;
/**
 * GPU higher-order tensor map. fn must be a known operator from operators.ts.
 *
 * Returns an async kernel that applies `fn` elementwise from the input
 * tensor (shape/strides given) into the output tensor, with broadcasting.
 */
export declare function gpuTensorMap(fn: (x: number) => number): (outStorage: Storage, outShape: Shape, outStrides: Strides, inStorage: Storage, inShape: Shape, inStrides: Strides) => Promise<void>;
/**
 * GPU higher-order tensor zip (binary map). fn must be a known binary operator.
 *
 * Returns an async kernel combining tensors a and b elementwise into the
 * output tensor, with broadcasting over both operands.
 */
export declare function gpuTensorZip(fn: (a: number, b: number) => number): (outStorage: Storage, outShape: Shape, outStrides: Strides, aStorage: Storage, aShape: Shape, aStrides: Strides, bStorage: Storage, bShape: Shape, bStrides: Strides) => Promise<void>;
/**
 * GPU higher-order tensor reduce. fn must be a known binary operator
 * with an entry in REDUCE_IDENTITY.
 * One workgroup per output element; threads cooperatively reduce
 * the target dimension using shared memory tree reduction.
 *
 * @returns an async kernel reducing tensor `a` along `reduceDim` into the output
 */
export declare function gpuTensorReduce(fn: (acc: number, x: number) => number): (outStorage: Storage, outShape: Shape, outStrides: Strides, aStorage: Storage, aShape: Shape, aStrides: Strides, reduceDim: number) => Promise<void>;
/**
 * GPU tiled matrix multiplication with shared memory.
 * Supports arbitrary broadcast batch dimensions as long as
 * aShape[-1] === bShape[-2].
 */
export declare function gpuTensorMatrixMultiply(outStorage: Storage, outShape: Shape, outStrides: Strides, outSize: number, aStorage: Storage, aShape: Shape, aStrides: Strides, bStorage: Storage, bShape: Shape, bStrides: Strides): Promise<void>;
/**
 * Experimental tensor-core-accelerated matrix multiply.
 *
 * Uses the Dawn `chromium_experimental_subgroup_matrix` extension which maps
 * to Apple simdgroup_matrix (8x8 f32) on Metal and VK_KHR_cooperative_matrix
 * on Vulkan. Returns true if the fast path executed, false if a precondition
 * was not met and the caller should fall back to `gpuTensorMatrixMultiply`.
 *
 * Preconditions (returns false when any is violated):
 * - Tensor core device available (hardware + experimental feature)
 * - All tensors contiguous (natural row-major strides)
 * - M, N, K all divisible by 8
 * - Batch dimensions of A and B identical (no broadcasting)
 */
export declare function gpuTensorMatrixMultiplyTensorCore(outStorage: Storage, outShape: Shape, outStrides: Strides, outSize: number, aStorage: Storage, aShape: Shape, aStrides: Strides, bStorage: Storage, bShape: Shape, bStrides: Strides): Promise<boolean>;
43
+ //# sourceMappingURL=gpu_ops.d.ts.map
@@ -0,0 +1 @@
1
+ {"version":3,"file":"gpu_ops.d.ts","sourceRoot":"","sources":["../src/gpu_ops.ts"],"names":[],"mappings":"AAAA,OAAO,KAAK,EAAE,OAAO,EAAE,KAAK,EAAE,OAAO,EAAE,MAAM,kBAAkB,CAAC;AAgFhE;;;GAGG;AACH,wBAAsB,YAAY,CAC9B,GAAG,EAAE,OAAO,EACZ,CAAC,EAAE,OAAO,EACV,IAAI,EAAE,MAAM,GACb,OAAO,CAAC,IAAI,CAAC,CAkCf;AAID;;GAEG;AACH,wBAAgB,YAAY,CACxB,EAAE,EAAE,CAAC,CAAC,EAAE,MAAM,KAAK,MAAM,GAC1B,CACC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,SAAS,EAAE,OAAO,EAClB,OAAO,EAAE,KAAK,EACd,SAAS,EAAE,OAAO,KACjB,OAAO,CAAC,IAAI,CAAC,CA4DjB;AAID;;GAEG;AACH,wBAAgB,YAAY,CACxB,EAAE,EAAE,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,KAAK,MAAM,GACrC,CACC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,EACjB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,KAChB,OAAO,CAAC,IAAI,CAAC,CAsEjB;AAID;;;;;GAKG;AACH,wBAAgB,eAAe,CAC3B,EAAE,EAAE,CAAC,GAAG,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,KAAK,MAAM,GACvC,CACC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,EACjB,SAAS,EAAE,MAAM,KAChB,OAAO,CAAC,IAAI,CAAC,CA+DjB;AAID;;;;GAIG;AACH,wBAAsB,uBAAuB,CACzC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,OAAO,EAAE,MAAM,EACf,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,EACjB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,GAClB,OAAO,CAAC,IAAI,CAAC,CA4Df;AAqBD;;;;;;;;;;;;;GAaG;AACH,wBAAsB,iCAAiC,CACnD,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,OAAO,EAAE,MAAM,EACf,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,EACjB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,GAClB,OAAO,CAAC,OAAO,CAAC,CAkElB"}