@sparkleideas/performance 3.0.0-alpha.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,566 @@
1
+ /**
2
+ * Multi-Head Attention Benchmark
3
+ *
4
+ * Target: Baseline comparison for Flash Attention improvements
5
+ *
6
+ * Measures multi-head attention performance with different
7
+ * configurations and parallelization strategies.
8
+ */
9
+
10
+ import { benchmark, BenchmarkRunner, formatTime, formatBytes } from '../../src/framework/benchmark.js';
11
+
12
+ // ============================================================================
13
+ // Multi-Head Attention Types
14
+ // ============================================================================
15
+
16
/**
 * Shape/configuration for one attention benchmark run.
 *
 * All tensors handled under this config are flat Float32Arrays in row-major
 * [seqLength, headDim] layout, one array per head.
 */
interface MHAConfig {
  seqLength: number; // number of tokens (rows) per head
  headDim: number;   // feature dimension per head
  numHeads: number;  // number of attention heads
  batchSize: number; // NOTE(review): declared but not read by the implementations below — confirm intent
  dropout?: number;  // NOTE(review): accepted but never applied by any forward pass — confirm intent
}
23
+
24
/**
 * Result of one multi-head attention forward pass.
 */
interface MHAResult {
  output: Float32Array;        // concatenated heads, flat [seqLength, numHeads * headDim]
  headOutputs: Float32Array[]; // per-head outputs, each flat [seqLength, headDim]
  memoryUsed: number;          // process.memoryUsage().heapUsed delta across the pass, bytes (GC can make this noisy)
  computeTime: number;         // wall-clock duration of the pass in ms (performance.now)
}
30
+
31
+ // ============================================================================
32
+ // Multi-Head Attention Implementation
33
+ // ============================================================================
34
+
35
+ /**
36
+ * Standard multi-head attention
37
+ */
38
+ class MultiHeadAttention {
39
+ private config: MHAConfig;
40
+
41
+ constructor(config: MHAConfig) {
42
+ this.config = config;
43
+ }
44
+
45
+ /**
46
+ * Single head attention
47
+ */
48
+ private singleHeadAttention(
49
+ query: Float32Array,
50
+ key: Float32Array,
51
+ value: Float32Array
52
+ ): Float32Array {
53
+ const { seqLength, headDim } = this.config;
54
+ const scale = 1 / Math.sqrt(headDim);
55
+
56
+ // Compute attention scores
57
+ const scores = new Float32Array(seqLength * seqLength);
58
+
59
+ for (let i = 0; i < seqLength; i++) {
60
+ for (let j = 0; j < seqLength; j++) {
61
+ let dot = 0;
62
+ for (let k = 0; k < headDim; k++) {
63
+ dot += query[i * headDim + k]! * key[j * headDim + k]!;
64
+ }
65
+ scores[i * seqLength + j] = dot * scale;
66
+ }
67
+ }
68
+
69
+ // Softmax
70
+ for (let i = 0; i < seqLength; i++) {
71
+ let max = -Infinity;
72
+ for (let j = 0; j < seqLength; j++) {
73
+ max = Math.max(max, scores[i * seqLength + j]!);
74
+ }
75
+
76
+ let sum = 0;
77
+ for (let j = 0; j < seqLength; j++) {
78
+ const exp = Math.exp(scores[i * seqLength + j]! - max);
79
+ scores[i * seqLength + j] = exp;
80
+ sum += exp;
81
+ }
82
+
83
+ for (let j = 0; j < seqLength; j++) {
84
+ scores[i * seqLength + j]! /= sum;
85
+ }
86
+ }
87
+
88
+ // Weighted sum
89
+ const output = new Float32Array(seqLength * headDim);
90
+
91
+ for (let i = 0; i < seqLength; i++) {
92
+ for (let k = 0; k < headDim; k++) {
93
+ let sum = 0;
94
+ for (let j = 0; j < seqLength; j++) {
95
+ sum += scores[i * seqLength + j]! * value[j * headDim + k]!;
96
+ }
97
+ output[i * headDim + k] = sum;
98
+ }
99
+ }
100
+
101
+ return output;
102
+ }
103
+
104
+ /**
105
+ * Forward pass through all heads sequentially
106
+ */
107
+ forwardSequential(
108
+ queries: Float32Array[],
109
+ keys: Float32Array[],
110
+ values: Float32Array[]
111
+ ): MHAResult {
112
+ const memBefore = process.memoryUsage().heapUsed;
113
+ const startTime = performance.now();
114
+
115
+ const { seqLength, headDim, numHeads } = this.config;
116
+ const headOutputs: Float32Array[] = [];
117
+
118
+ // Process each head sequentially
119
+ for (let h = 0; h < numHeads; h++) {
120
+ const headOutput = this.singleHeadAttention(
121
+ queries[h]!,
122
+ keys[h]!,
123
+ values[h]!
124
+ );
125
+ headOutputs.push(headOutput);
126
+ }
127
+
128
+ // Concatenate heads
129
+ const output = new Float32Array(seqLength * headDim * numHeads);
130
+ for (let h = 0; h < numHeads; h++) {
131
+ for (let i = 0; i < seqLength; i++) {
132
+ for (let k = 0; k < headDim; k++) {
133
+ output[i * headDim * numHeads + h * headDim + k] =
134
+ headOutputs[h]![i * headDim + k]!;
135
+ }
136
+ }
137
+ }
138
+
139
+ return {
140
+ output,
141
+ headOutputs,
142
+ memoryUsed: process.memoryUsage().heapUsed - memBefore,
143
+ computeTime: performance.now() - startTime,
144
+ };
145
+ }
146
+
147
+ /**
148
+ * Forward pass with parallel head computation (simulated)
149
+ */
150
+ async forwardParallel(
151
+ queries: Float32Array[],
152
+ keys: Float32Array[],
153
+ values: Float32Array[]
154
+ ): Promise<MHAResult> {
155
+ const memBefore = process.memoryUsage().heapUsed;
156
+ const startTime = performance.now();
157
+
158
+ const { seqLength, headDim, numHeads } = this.config;
159
+
160
+ // Process all heads in parallel
161
+ const headPromises = queries.map((q, h) =>
162
+ Promise.resolve(this.singleHeadAttention(q, keys[h]!, values[h]!))
163
+ );
164
+
165
+ const headOutputs = await Promise.all(headPromises);
166
+
167
+ // Concatenate heads
168
+ const output = new Float32Array(seqLength * headDim * numHeads);
169
+ for (let h = 0; h < numHeads; h++) {
170
+ for (let i = 0; i < seqLength; i++) {
171
+ for (let k = 0; k < headDim; k++) {
172
+ output[i * headDim * numHeads + h * headDim + k] =
173
+ headOutputs[h]![i * headDim + k]!;
174
+ }
175
+ }
176
+ }
177
+
178
+ return {
179
+ output,
180
+ headOutputs,
181
+ memoryUsed: process.memoryUsage().heapUsed - memBefore,
182
+ computeTime: performance.now() - startTime,
183
+ };
184
+ }
185
+ }
186
+
187
+ /**
188
+ * Grouped-Query Attention (GQA)
189
+ * Multiple query heads share fewer key/value heads
190
+ */
191
+ class GroupedQueryAttention {
192
+ private config: MHAConfig;
193
+ private kvHeads: number;
194
+ private groupSize: number;
195
+
196
+ constructor(config: MHAConfig, kvHeads: number) {
197
+ this.config = config;
198
+ this.kvHeads = kvHeads;
199
+ this.groupSize = config.numHeads / kvHeads;
200
+ }
201
+
202
+ forward(
203
+ queries: Float32Array[],
204
+ keys: Float32Array[],
205
+ values: Float32Array[]
206
+ ): MHAResult {
207
+ const memBefore = process.memoryUsage().heapUsed;
208
+ const startTime = performance.now();
209
+
210
+ const { seqLength, headDim, numHeads } = this.config;
211
+ const headOutputs: Float32Array[] = [];
212
+
213
+ // Process each query head, sharing K/V within groups
214
+ for (let h = 0; h < numHeads; h++) {
215
+ const kvIndex = Math.floor(h / this.groupSize);
216
+ const headOutput = this.singleHeadAttention(
217
+ queries[h]!,
218
+ keys[kvIndex]!,
219
+ values[kvIndex]!
220
+ );
221
+ headOutputs.push(headOutput);
222
+ }
223
+
224
+ // Concatenate heads
225
+ const output = new Float32Array(seqLength * headDim * numHeads);
226
+ for (let h = 0; h < numHeads; h++) {
227
+ for (let i = 0; i < seqLength; i++) {
228
+ for (let k = 0; k < headDim; k++) {
229
+ output[i * headDim * numHeads + h * headDim + k] =
230
+ headOutputs[h]![i * headDim + k]!;
231
+ }
232
+ }
233
+ }
234
+
235
+ return {
236
+ output,
237
+ headOutputs,
238
+ memoryUsed: process.memoryUsage().heapUsed - memBefore,
239
+ computeTime: performance.now() - startTime,
240
+ };
241
+ }
242
+
243
+ private singleHeadAttention(
244
+ query: Float32Array,
245
+ key: Float32Array,
246
+ value: Float32Array
247
+ ): Float32Array {
248
+ const { seqLength, headDim } = this.config;
249
+ const scale = 1 / Math.sqrt(headDim);
250
+
251
+ const scores = new Float32Array(seqLength * seqLength);
252
+
253
+ for (let i = 0; i < seqLength; i++) {
254
+ for (let j = 0; j < seqLength; j++) {
255
+ let dot = 0;
256
+ for (let k = 0; k < headDim; k++) {
257
+ dot += query[i * headDim + k]! * key[j * headDim + k]!;
258
+ }
259
+ scores[i * seqLength + j] = dot * scale;
260
+ }
261
+ }
262
+
263
+ // Softmax
264
+ for (let i = 0; i < seqLength; i++) {
265
+ let max = -Infinity;
266
+ for (let j = 0; j < seqLength; j++) {
267
+ max = Math.max(max, scores[i * seqLength + j]!);
268
+ }
269
+
270
+ let sum = 0;
271
+ for (let j = 0; j < seqLength; j++) {
272
+ const exp = Math.exp(scores[i * seqLength + j]! - max);
273
+ scores[i * seqLength + j] = exp;
274
+ sum += exp;
275
+ }
276
+
277
+ for (let j = 0; j < seqLength; j++) {
278
+ scores[i * seqLength + j]! /= sum;
279
+ }
280
+ }
281
+
282
+ const output = new Float32Array(seqLength * headDim);
283
+
284
+ for (let i = 0; i < seqLength; i++) {
285
+ for (let k = 0; k < headDim; k++) {
286
+ let sum = 0;
287
+ for (let j = 0; j < seqLength; j++) {
288
+ sum += scores[i * seqLength + j]! * value[j * headDim + k]!;
289
+ }
290
+ output[i * headDim + k] = sum;
291
+ }
292
+ }
293
+
294
+ return output;
295
+ }
296
+ }
297
+
298
+ // ============================================================================
299
+ // Helper Functions
300
+ // ============================================================================
301
+
302
+ function generateRandomTensor(size: number): Float32Array {
303
+ const tensor = new Float32Array(size);
304
+ for (let i = 0; i < size; i++) {
305
+ tensor[i] = Math.random() * 2 - 1;
306
+ }
307
+ return tensor;
308
+ }
309
+
310
+ function createMultiHeadQKV(
311
+ config: MHAConfig
312
+ ): { queries: Float32Array[]; keys: Float32Array[]; values: Float32Array[] } {
313
+ const { seqLength, headDim, numHeads } = config;
314
+ const size = seqLength * headDim;
315
+
316
+ return {
317
+ queries: Array.from({ length: numHeads }, () => generateRandomTensor(size)),
318
+ keys: Array.from({ length: numHeads }, () => generateRandomTensor(size)),
319
+ values: Array.from({ length: numHeads }, () => generateRandomTensor(size)),
320
+ };
321
+ }
322
+
323
+ // ============================================================================
324
+ // Benchmark Suite
325
+ // ============================================================================
326
+
327
/**
 * Benchmark entry point for multi-head attention.
 *
 * Runs, in order:
 *  1. Standard MHA at several (seqLength, numHeads) configurations,
 *     comparing forwardSequential against the promise-wrapped
 *     forwardParallel path and printing the speedup ratio.
 *  2. GQA with 4 and 2 shared K/V heads against standard 8-head MHA.
 *  3. A heap-memory measurement for one large (seq=512) configuration.
 *
 * Results are registered on a shared BenchmarkRunner (which prints a full
 * table at the end) and summarized on the console. Purely side-effecting;
 * resolves when all benchmarks have run.
 */
export async function runMultiHeadAttentionBenchmarks(): Promise<void> {
  const runner = new BenchmarkRunner('Multi-Head Attention');

  console.log('\n--- Multi-Head Attention Benchmarks ---\n');

  // Test configurations: scale seqLength at fixed heads, then scale heads.
  const configs: MHAConfig[] = [
    { seqLength: 128, headDim: 64, numHeads: 8, batchSize: 1 },
    { seqLength: 256, headDim: 64, numHeads: 8, batchSize: 1 },
    { seqLength: 512, headDim: 64, numHeads: 8, batchSize: 1 },
    { seqLength: 256, headDim: 64, numHeads: 16, batchSize: 1 },
  ];

  for (const config of configs) {
    const { seqLength, numHeads } = config;
    console.log(`\n--- Seq: ${seqLength}, Heads: ${numHeads} ---`);

    const mha = new MultiHeadAttention(config);
    const { queries, keys, values } = createMultiHeadQKV(config);

    // Sequential forward
    const seqResult = await runner.run(
      `mha-sequential-seq${seqLength}-h${numHeads}`,
      async () => {
        mha.forwardSequential(queries, keys, values);
      },
      { iterations: 50 }
    );

    console.log(`Sequential: ${formatTime(seqResult.mean)}`);

    // Parallel forward (promise-wrapped; see forwardParallel for caveats)
    const parallelResult = await runner.run(
      `mha-parallel-seq${seqLength}-h${numHeads}`,
      async () => {
        await mha.forwardParallel(queries, keys, values);
      },
      { iterations: 50 }
    );

    console.log(`Parallel: ${formatTime(parallelResult.mean)}`);

    // Speedup: ratio of mean sequential time to mean parallel time
    const speedup = seqResult.mean / parallelResult.mean;
    console.log(`Parallel Speedup: ${speedup.toFixed(2)}x`);
  }

  // Grouped-Query Attention comparison
  console.log('\n--- Grouped-Query Attention Comparison ---');

  const gqaConfig: MHAConfig = {
    seqLength: 256,
    headDim: 64,
    numHeads: 8,
    batchSize: 1,
  };

  const standardMHA = new MultiHeadAttention(gqaConfig);
  const { queries, keys, values } = createMultiHeadQKV(gqaConfig);

  // Standard MHA (8 query heads, 8 KV heads)
  const standardResult = await runner.run(
    'mha-standard-8heads',
    async () => {
      standardMHA.forwardSequential(queries, keys, values);
    },
    { iterations: 50 }
  );

  console.log(`Standard MHA (8 QKV heads): ${formatTime(standardResult.mean)}`);

  // GQA with 4 KV heads (reuses the first 4 of the 8 generated K/V tensors)
  const gqa4 = new GroupedQueryAttention(gqaConfig, 4);
  const kvFor4 = {
    keys: keys.slice(0, 4),
    values: values.slice(0, 4),
  };

  const gqa4Result = await runner.run(
    'gqa-4-kv-heads',
    async () => {
      gqa4.forward(queries, kvFor4.keys, kvFor4.values);
    },
    { iterations: 50 }
  );

  console.log(`GQA (8 Q, 4 KV heads): ${formatTime(gqa4Result.mean)}`);

  // GQA with 2 KV heads
  const gqa2 = new GroupedQueryAttention(gqaConfig, 2);
  const kvFor2 = {
    keys: keys.slice(0, 2),
    values: values.slice(0, 2),
  };

  const gqa2Result = await runner.run(
    'gqa-2-kv-heads',
    async () => {
      gqa2.forward(queries, kvFor2.keys, kvFor2.values);
    },
    { iterations: 50 }
  );

  console.log(`GQA (8 Q, 2 KV heads): ${formatTime(gqa2Result.mean)}`);

  // Memory comparison: single (non-benchmarked) pass at the largest size
  console.log('\n--- Memory Usage Comparison ---');

  const memConfig: MHAConfig = {
    seqLength: 512,
    headDim: 64,
    numHeads: 8,
    batchSize: 1,
  };

  const { queries: q512, keys: k512, values: v512 } = createMultiHeadQKV(memConfig);
  const mha512 = new MultiHeadAttention(memConfig);

  // memoryUsed is a heapUsed delta, so GC activity can skew it
  const memResult = mha512.forwardSequential(q512, k512, v512);
  console.log(`MHA Memory (seq=512, h=8): ${formatBytes(memResult.memoryUsed)}`);

  // Per-head memory (simple average over heads)
  const perHeadMem = memResult.memoryUsed / memConfig.numHeads;
  console.log(`Per-head memory: ${formatBytes(perHeadMem)}`);

  // Theoretical attention matrix size:
  // one seqLength x seqLength Float32Array (4 bytes/element) per head
  const attentionMatrixSize = memConfig.seqLength * memConfig.seqLength * 4 * memConfig.numHeads;
  console.log(`Theoretical attention matrices: ${formatBytes(attentionMatrixSize)}`);

  // Summary
  console.log('\n--- Summary ---');
  console.log('Standard MHA vs GQA:');
  console.log(`  8 KV heads: ${formatTime(standardResult.mean)}`);
  console.log(`  4 KV heads: ${formatTime(gqa4Result.mean)} (${(standardResult.mean / gqa4Result.mean).toFixed(2)}x)`);
  console.log(`  2 KV heads: ${formatTime(gqa2Result.mean)} (${(standardResult.mean / gqa2Result.mean).toFixed(2)}x)`);

  // Print full results table accumulated by the runner
  runner.printResults();
}
466
+
467
+ // ============================================================================
468
+ // Multi-Head Attention Optimization Strategies
469
+ // ============================================================================
470
+
471
/**
 * Catalogue of multi-head attention optimization strategies.
 *
 * Reference/reporting data only: each entry pairs a one-line description and
 * an expected-improvement estimate with an illustrative implementation
 * sketch. The `implementation` strings are documentation, not executed code.
 */
export const mhaOptimizations = {
  /**
   * Parallel head computation
   */
  parallelHeads: {
    description: 'Compute attention heads in parallel',
    expectedImprovement: 'Up to num_heads x speedup',
    implementation: `
      async function parallelMHA(queries, keys, values) {
        const headResults = await Promise.all(
          queries.map((q, i) => computeHead(q, keys[i], values[i]))
        );
        return concatenateHeads(headResults);
      }
    `,
  },

  /**
   * Grouped-Query Attention
   */
  groupedQueryAttention: {
    description: 'Share K/V across multiple query heads',
    expectedImprovement: '2-4x memory, 1.5-2x speed',
    implementation: `
      // Instead of numHeads K/V pairs, use numHeads / groupSize
      class GQA {
        forward(queries, keys, values) {
          return queries.map((q, i) => {
            const kvIdx = Math.floor(i / groupSize);
            return attention(q, keys[kvIdx], values[kvIdx]);
          });
        }
      }
    `,
  },

  /**
   * Multi-Query Attention
   */
  multiQueryAttention: {
    description: 'Single K/V pair shared across all heads',
    expectedImprovement: '8x memory, 2-3x speed',
    implementation: `
      class MQA {
        forward(queries, key, value) {
          // All heads share single K and V
          return queries.map(q => attention(q, key, value));
        }
      }
    `,
  },

  /**
   * Fused QKV projection
   */
  fusedQKVProjection: {
    description: 'Fuse Q, K, V projections into single operation',
    expectedImprovement: '20-30% projection overhead',
    implementation: `
      function fusedQKV(input, weights) {
        // Single matmul for all QKV
        const qkv = matmul(input, weights.qkv);
        return splitQKV(qkv, numHeads, headDim);
      }
    `,
  },

  /**
   * KV caching for inference
   */
  kvCaching: {
    description: 'Cache K/V for autoregressive generation',
    expectedImprovement: 'O(1) per token instead of O(n)',
    implementation: `
      class CachedMHA {
        private kvCache: { k: Float32Array[], v: Float32Array[] } = { k: [], v: [] };

        forward(query, key, value, useCache: boolean) {
          if (useCache) {
            this.kvCache.k.push(key);
            this.kvCache.v.push(value);
            return attention(query, this.kvCache.k, this.kvCache.v);
          }
          return attention(query, [key], [value]);
        }
      }
    `,
  },
};
560
+
561
+ // Run if executed directly
562
+ if (import.meta.url === `file://${process.argv[1]}`) {
563
+ runMultiHeadAttentionBenchmarks().catch(console.error);
564
+ }
565
+
566
+ export default runMultiHeadAttentionBenchmarks;