@sparkleideas/performance 3.0.0-alpha.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +256 -0
- package/__tests__/README.md +242 -0
- package/__tests__/attention.test.ts +516 -0
- package/__tests__/benchmarks.test.ts +515 -0
- package/benchmarks/attention/memory-efficiency.bench.ts +569 -0
- package/benchmarks/attention/multi-head-attention.bench.ts +566 -0
- package/benchmarks/startup/agent-spawn.bench.ts +422 -0
- package/benchmarks/startup/cli-cold-start.bench.ts +327 -0
- package/benchmarks/startup/cli-warm-start.bench.ts +277 -0
- package/benchmarks/startup/mcp-server-init.bench.ts +380 -0
- package/docs/ATTENTION.md +277 -0
- package/package.json +29 -0
- package/src/attention-benchmarks.ts +459 -0
- package/src/attention-integration.ts +507 -0
- package/src/examples/flash-attention-demo.ts +160 -0
- package/src/examples/quick-test.ts +62 -0
- package/src/framework/benchmark.ts +583 -0
- package/src/index.ts +63 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +31 -0
|
@@ -0,0 +1,566 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Multi-Head Attention Benchmark
|
|
3
|
+
*
|
|
4
|
+
* Target: Baseline comparison for Flash Attention improvements
|
|
5
|
+
*
|
|
6
|
+
* Measures multi-head attention performance with different
|
|
7
|
+
* configurations and parallelization strategies.
|
|
8
|
+
*/
|
|
9
|
+
|
|
10
|
+
import { pathToFileURL } from 'node:url';

import { benchmark, BenchmarkRunner, formatTime, formatBytes } from '../../src/framework/benchmark.js';
|
|
11
|
+
|
|
12
|
+
// ============================================================================
|
|
13
|
+
// Multi-Head Attention Types
|
|
14
|
+
// ============================================================================
|
|
15
|
+
|
|
16
|
+
/** Configuration for a multi-head attention computation. */
interface MHAConfig {
  // Number of tokens in the sequence; the attention matrix is seqLength x seqLength.
  seqLength: number;
  // Feature dimension of each individual attention head.
  headDim: number;
  // Number of attention (query) heads.
  numHeads: number;
  // Batch size — NOTE(review): not consumed by the implementations in this file; confirm intent.
  batchSize: number;
  // Optional dropout rate — NOTE(review): declared but unused in this file.
  dropout?: number;
}
|
|
23
|
+
|
|
24
|
+
/** Result of a multi-head attention forward pass. */
interface MHAResult {
  // Concatenated head outputs, flattened (seqLength x headDim*numHeads), feature-major per position.
  output: Float32Array;
  // Individual per-head outputs, each flattened (seqLength x headDim).
  headOutputs: Float32Array[];
  // Heap delta in bytes across the forward pass; GC-dependent and noisy — treat as indicative only.
  memoryUsed: number;
  // Wall-clock duration of the forward pass in milliseconds.
  computeTime: number;
}
|
|
30
|
+
|
|
31
|
+
// ============================================================================
|
|
32
|
+
// Multi-Head Attention Implementation
|
|
33
|
+
// ============================================================================
|
|
34
|
+
|
|
35
|
+
/**
|
|
36
|
+
* Standard multi-head attention
|
|
37
|
+
*/
|
|
38
|
+
class MultiHeadAttention {
|
|
39
|
+
private config: MHAConfig;
|
|
40
|
+
|
|
41
|
+
constructor(config: MHAConfig) {
|
|
42
|
+
this.config = config;
|
|
43
|
+
}
|
|
44
|
+
|
|
45
|
+
/**
|
|
46
|
+
* Single head attention
|
|
47
|
+
*/
|
|
48
|
+
private singleHeadAttention(
|
|
49
|
+
query: Float32Array,
|
|
50
|
+
key: Float32Array,
|
|
51
|
+
value: Float32Array
|
|
52
|
+
): Float32Array {
|
|
53
|
+
const { seqLength, headDim } = this.config;
|
|
54
|
+
const scale = 1 / Math.sqrt(headDim);
|
|
55
|
+
|
|
56
|
+
// Compute attention scores
|
|
57
|
+
const scores = new Float32Array(seqLength * seqLength);
|
|
58
|
+
|
|
59
|
+
for (let i = 0; i < seqLength; i++) {
|
|
60
|
+
for (let j = 0; j < seqLength; j++) {
|
|
61
|
+
let dot = 0;
|
|
62
|
+
for (let k = 0; k < headDim; k++) {
|
|
63
|
+
dot += query[i * headDim + k]! * key[j * headDim + k]!;
|
|
64
|
+
}
|
|
65
|
+
scores[i * seqLength + j] = dot * scale;
|
|
66
|
+
}
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
// Softmax
|
|
70
|
+
for (let i = 0; i < seqLength; i++) {
|
|
71
|
+
let max = -Infinity;
|
|
72
|
+
for (let j = 0; j < seqLength; j++) {
|
|
73
|
+
max = Math.max(max, scores[i * seqLength + j]!);
|
|
74
|
+
}
|
|
75
|
+
|
|
76
|
+
let sum = 0;
|
|
77
|
+
for (let j = 0; j < seqLength; j++) {
|
|
78
|
+
const exp = Math.exp(scores[i * seqLength + j]! - max);
|
|
79
|
+
scores[i * seqLength + j] = exp;
|
|
80
|
+
sum += exp;
|
|
81
|
+
}
|
|
82
|
+
|
|
83
|
+
for (let j = 0; j < seqLength; j++) {
|
|
84
|
+
scores[i * seqLength + j]! /= sum;
|
|
85
|
+
}
|
|
86
|
+
}
|
|
87
|
+
|
|
88
|
+
// Weighted sum
|
|
89
|
+
const output = new Float32Array(seqLength * headDim);
|
|
90
|
+
|
|
91
|
+
for (let i = 0; i < seqLength; i++) {
|
|
92
|
+
for (let k = 0; k < headDim; k++) {
|
|
93
|
+
let sum = 0;
|
|
94
|
+
for (let j = 0; j < seqLength; j++) {
|
|
95
|
+
sum += scores[i * seqLength + j]! * value[j * headDim + k]!;
|
|
96
|
+
}
|
|
97
|
+
output[i * headDim + k] = sum;
|
|
98
|
+
}
|
|
99
|
+
}
|
|
100
|
+
|
|
101
|
+
return output;
|
|
102
|
+
}
|
|
103
|
+
|
|
104
|
+
/**
|
|
105
|
+
* Forward pass through all heads sequentially
|
|
106
|
+
*/
|
|
107
|
+
forwardSequential(
|
|
108
|
+
queries: Float32Array[],
|
|
109
|
+
keys: Float32Array[],
|
|
110
|
+
values: Float32Array[]
|
|
111
|
+
): MHAResult {
|
|
112
|
+
const memBefore = process.memoryUsage().heapUsed;
|
|
113
|
+
const startTime = performance.now();
|
|
114
|
+
|
|
115
|
+
const { seqLength, headDim, numHeads } = this.config;
|
|
116
|
+
const headOutputs: Float32Array[] = [];
|
|
117
|
+
|
|
118
|
+
// Process each head sequentially
|
|
119
|
+
for (let h = 0; h < numHeads; h++) {
|
|
120
|
+
const headOutput = this.singleHeadAttention(
|
|
121
|
+
queries[h]!,
|
|
122
|
+
keys[h]!,
|
|
123
|
+
values[h]!
|
|
124
|
+
);
|
|
125
|
+
headOutputs.push(headOutput);
|
|
126
|
+
}
|
|
127
|
+
|
|
128
|
+
// Concatenate heads
|
|
129
|
+
const output = new Float32Array(seqLength * headDim * numHeads);
|
|
130
|
+
for (let h = 0; h < numHeads; h++) {
|
|
131
|
+
for (let i = 0; i < seqLength; i++) {
|
|
132
|
+
for (let k = 0; k < headDim; k++) {
|
|
133
|
+
output[i * headDim * numHeads + h * headDim + k] =
|
|
134
|
+
headOutputs[h]![i * headDim + k]!;
|
|
135
|
+
}
|
|
136
|
+
}
|
|
137
|
+
}
|
|
138
|
+
|
|
139
|
+
return {
|
|
140
|
+
output,
|
|
141
|
+
headOutputs,
|
|
142
|
+
memoryUsed: process.memoryUsage().heapUsed - memBefore,
|
|
143
|
+
computeTime: performance.now() - startTime,
|
|
144
|
+
};
|
|
145
|
+
}
|
|
146
|
+
|
|
147
|
+
/**
|
|
148
|
+
* Forward pass with parallel head computation (simulated)
|
|
149
|
+
*/
|
|
150
|
+
async forwardParallel(
|
|
151
|
+
queries: Float32Array[],
|
|
152
|
+
keys: Float32Array[],
|
|
153
|
+
values: Float32Array[]
|
|
154
|
+
): Promise<MHAResult> {
|
|
155
|
+
const memBefore = process.memoryUsage().heapUsed;
|
|
156
|
+
const startTime = performance.now();
|
|
157
|
+
|
|
158
|
+
const { seqLength, headDim, numHeads } = this.config;
|
|
159
|
+
|
|
160
|
+
// Process all heads in parallel
|
|
161
|
+
const headPromises = queries.map((q, h) =>
|
|
162
|
+
Promise.resolve(this.singleHeadAttention(q, keys[h]!, values[h]!))
|
|
163
|
+
);
|
|
164
|
+
|
|
165
|
+
const headOutputs = await Promise.all(headPromises);
|
|
166
|
+
|
|
167
|
+
// Concatenate heads
|
|
168
|
+
const output = new Float32Array(seqLength * headDim * numHeads);
|
|
169
|
+
for (let h = 0; h < numHeads; h++) {
|
|
170
|
+
for (let i = 0; i < seqLength; i++) {
|
|
171
|
+
for (let k = 0; k < headDim; k++) {
|
|
172
|
+
output[i * headDim * numHeads + h * headDim + k] =
|
|
173
|
+
headOutputs[h]![i * headDim + k]!;
|
|
174
|
+
}
|
|
175
|
+
}
|
|
176
|
+
}
|
|
177
|
+
|
|
178
|
+
return {
|
|
179
|
+
output,
|
|
180
|
+
headOutputs,
|
|
181
|
+
memoryUsed: process.memoryUsage().heapUsed - memBefore,
|
|
182
|
+
computeTime: performance.now() - startTime,
|
|
183
|
+
};
|
|
184
|
+
}
|
|
185
|
+
}
|
|
186
|
+
|
|
187
|
+
/**
|
|
188
|
+
* Grouped-Query Attention (GQA)
|
|
189
|
+
* Multiple query heads share fewer key/value heads
|
|
190
|
+
*/
|
|
191
|
+
class GroupedQueryAttention {
|
|
192
|
+
private config: MHAConfig;
|
|
193
|
+
private kvHeads: number;
|
|
194
|
+
private groupSize: number;
|
|
195
|
+
|
|
196
|
+
constructor(config: MHAConfig, kvHeads: number) {
|
|
197
|
+
this.config = config;
|
|
198
|
+
this.kvHeads = kvHeads;
|
|
199
|
+
this.groupSize = config.numHeads / kvHeads;
|
|
200
|
+
}
|
|
201
|
+
|
|
202
|
+
forward(
|
|
203
|
+
queries: Float32Array[],
|
|
204
|
+
keys: Float32Array[],
|
|
205
|
+
values: Float32Array[]
|
|
206
|
+
): MHAResult {
|
|
207
|
+
const memBefore = process.memoryUsage().heapUsed;
|
|
208
|
+
const startTime = performance.now();
|
|
209
|
+
|
|
210
|
+
const { seqLength, headDim, numHeads } = this.config;
|
|
211
|
+
const headOutputs: Float32Array[] = [];
|
|
212
|
+
|
|
213
|
+
// Process each query head, sharing K/V within groups
|
|
214
|
+
for (let h = 0; h < numHeads; h++) {
|
|
215
|
+
const kvIndex = Math.floor(h / this.groupSize);
|
|
216
|
+
const headOutput = this.singleHeadAttention(
|
|
217
|
+
queries[h]!,
|
|
218
|
+
keys[kvIndex]!,
|
|
219
|
+
values[kvIndex]!
|
|
220
|
+
);
|
|
221
|
+
headOutputs.push(headOutput);
|
|
222
|
+
}
|
|
223
|
+
|
|
224
|
+
// Concatenate heads
|
|
225
|
+
const output = new Float32Array(seqLength * headDim * numHeads);
|
|
226
|
+
for (let h = 0; h < numHeads; h++) {
|
|
227
|
+
for (let i = 0; i < seqLength; i++) {
|
|
228
|
+
for (let k = 0; k < headDim; k++) {
|
|
229
|
+
output[i * headDim * numHeads + h * headDim + k] =
|
|
230
|
+
headOutputs[h]![i * headDim + k]!;
|
|
231
|
+
}
|
|
232
|
+
}
|
|
233
|
+
}
|
|
234
|
+
|
|
235
|
+
return {
|
|
236
|
+
output,
|
|
237
|
+
headOutputs,
|
|
238
|
+
memoryUsed: process.memoryUsage().heapUsed - memBefore,
|
|
239
|
+
computeTime: performance.now() - startTime,
|
|
240
|
+
};
|
|
241
|
+
}
|
|
242
|
+
|
|
243
|
+
private singleHeadAttention(
|
|
244
|
+
query: Float32Array,
|
|
245
|
+
key: Float32Array,
|
|
246
|
+
value: Float32Array
|
|
247
|
+
): Float32Array {
|
|
248
|
+
const { seqLength, headDim } = this.config;
|
|
249
|
+
const scale = 1 / Math.sqrt(headDim);
|
|
250
|
+
|
|
251
|
+
const scores = new Float32Array(seqLength * seqLength);
|
|
252
|
+
|
|
253
|
+
for (let i = 0; i < seqLength; i++) {
|
|
254
|
+
for (let j = 0; j < seqLength; j++) {
|
|
255
|
+
let dot = 0;
|
|
256
|
+
for (let k = 0; k < headDim; k++) {
|
|
257
|
+
dot += query[i * headDim + k]! * key[j * headDim + k]!;
|
|
258
|
+
}
|
|
259
|
+
scores[i * seqLength + j] = dot * scale;
|
|
260
|
+
}
|
|
261
|
+
}
|
|
262
|
+
|
|
263
|
+
// Softmax
|
|
264
|
+
for (let i = 0; i < seqLength; i++) {
|
|
265
|
+
let max = -Infinity;
|
|
266
|
+
for (let j = 0; j < seqLength; j++) {
|
|
267
|
+
max = Math.max(max, scores[i * seqLength + j]!);
|
|
268
|
+
}
|
|
269
|
+
|
|
270
|
+
let sum = 0;
|
|
271
|
+
for (let j = 0; j < seqLength; j++) {
|
|
272
|
+
const exp = Math.exp(scores[i * seqLength + j]! - max);
|
|
273
|
+
scores[i * seqLength + j] = exp;
|
|
274
|
+
sum += exp;
|
|
275
|
+
}
|
|
276
|
+
|
|
277
|
+
for (let j = 0; j < seqLength; j++) {
|
|
278
|
+
scores[i * seqLength + j]! /= sum;
|
|
279
|
+
}
|
|
280
|
+
}
|
|
281
|
+
|
|
282
|
+
const output = new Float32Array(seqLength * headDim);
|
|
283
|
+
|
|
284
|
+
for (let i = 0; i < seqLength; i++) {
|
|
285
|
+
for (let k = 0; k < headDim; k++) {
|
|
286
|
+
let sum = 0;
|
|
287
|
+
for (let j = 0; j < seqLength; j++) {
|
|
288
|
+
sum += scores[i * seqLength + j]! * value[j * headDim + k]!;
|
|
289
|
+
}
|
|
290
|
+
output[i * headDim + k] = sum;
|
|
291
|
+
}
|
|
292
|
+
}
|
|
293
|
+
|
|
294
|
+
return output;
|
|
295
|
+
}
|
|
296
|
+
}
|
|
297
|
+
|
|
298
|
+
// ============================================================================
|
|
299
|
+
// Helper Functions
|
|
300
|
+
// ============================================================================
|
|
301
|
+
|
|
302
|
+
function generateRandomTensor(size: number): Float32Array {
|
|
303
|
+
const tensor = new Float32Array(size);
|
|
304
|
+
for (let i = 0; i < size; i++) {
|
|
305
|
+
tensor[i] = Math.random() * 2 - 1;
|
|
306
|
+
}
|
|
307
|
+
return tensor;
|
|
308
|
+
}
|
|
309
|
+
|
|
310
|
+
function createMultiHeadQKV(
|
|
311
|
+
config: MHAConfig
|
|
312
|
+
): { queries: Float32Array[]; keys: Float32Array[]; values: Float32Array[] } {
|
|
313
|
+
const { seqLength, headDim, numHeads } = config;
|
|
314
|
+
const size = seqLength * headDim;
|
|
315
|
+
|
|
316
|
+
return {
|
|
317
|
+
queries: Array.from({ length: numHeads }, () => generateRandomTensor(size)),
|
|
318
|
+
keys: Array.from({ length: numHeads }, () => generateRandomTensor(size)),
|
|
319
|
+
values: Array.from({ length: numHeads }, () => generateRandomTensor(size)),
|
|
320
|
+
};
|
|
321
|
+
}
|
|
322
|
+
|
|
323
|
+
// ============================================================================
|
|
324
|
+
// Benchmark Suite
|
|
325
|
+
// ============================================================================
|
|
326
|
+
|
|
327
|
+
/**
 * Run the multi-head attention benchmark suite.
 *
 * Covers: (1) sequential vs promise-based "parallel" MHA across several
 * sequence lengths and head counts, (2) standard MHA vs grouped-query
 * attention with 4 and 2 shared K/V heads, (3) a rough heap-delta memory
 * comparison at seq=512. Results are printed to the console and accumulated
 * in the BenchmarkRunner.
 */
export async function runMultiHeadAttentionBenchmarks(): Promise<void> {
  const runner = new BenchmarkRunner('Multi-Head Attention');

  console.log('\n--- Multi-Head Attention Benchmarks ---\n');

  // Test configurations: vary sequence length (128/256/512) at 8 heads,
  // plus one 16-head variant at seq 256.
  const configs: MHAConfig[] = [
    { seqLength: 128, headDim: 64, numHeads: 8, batchSize: 1 },
    { seqLength: 256, headDim: 64, numHeads: 8, batchSize: 1 },
    { seqLength: 512, headDim: 64, numHeads: 8, batchSize: 1 },
    { seqLength: 256, headDim: 64, numHeads: 16, batchSize: 1 },
  ];

  for (const config of configs) {
    const { seqLength, numHeads } = config;
    console.log(`\n--- Seq: ${seqLength}, Heads: ${numHeads} ---`);

    const mha = new MultiHeadAttention(config);
    // Same random Q/K/V reused for both variants so timings are comparable.
    const { queries, keys, values } = createMultiHeadQKV(config);

    // Sequential forward
    const seqResult = await runner.run(
      `mha-sequential-seq${seqLength}-h${numHeads}`,
      async () => {
        mha.forwardSequential(queries, keys, values);
      },
      { iterations: 50 }
    );

    console.log(`Sequential: ${formatTime(seqResult.mean)}`);

    // Parallel forward
    const parallelResult = await runner.run(
      `mha-parallel-seq${seqLength}-h${numHeads}`,
      async () => {
        await mha.forwardParallel(queries, keys, values);
      },
      { iterations: 50 }
    );

    console.log(`Parallel: ${formatTime(parallelResult.mean)}`);

    // Speedup of the promise-based path over the sequential path.
    const speedup = seqResult.mean / parallelResult.mean;
    console.log(`Parallel Speedup: ${speedup.toFixed(2)}x`);
  }

  // Grouped-Query Attention comparison
  console.log('\n--- Grouped-Query Attention Comparison ---');

  const gqaConfig: MHAConfig = {
    seqLength: 256,
    headDim: 64,
    numHeads: 8,
    batchSize: 1,
  };

  const standardMHA = new MultiHeadAttention(gqaConfig);
  const { queries, keys, values } = createMultiHeadQKV(gqaConfig);

  // Standard MHA (8 query heads, 8 KV heads)
  const standardResult = await runner.run(
    'mha-standard-8heads',
    async () => {
      standardMHA.forwardSequential(queries, keys, values);
    },
    { iterations: 50 }
  );

  console.log(`Standard MHA (8 QKV heads): ${formatTime(standardResult.mean)}`);

  // GQA with 4 KV heads — reuse the first 4 K/V tensors from the standard run.
  const gqa4 = new GroupedQueryAttention(gqaConfig, 4);
  const kvFor4 = {
    keys: keys.slice(0, 4),
    values: values.slice(0, 4),
  };

  const gqa4Result = await runner.run(
    'gqa-4-kv-heads',
    async () => {
      gqa4.forward(queries, kvFor4.keys, kvFor4.values);
    },
    { iterations: 50 }
  );

  console.log(`GQA (8 Q, 4 KV heads): ${formatTime(gqa4Result.mean)}`);

  // GQA with 2 KV heads
  const gqa2 = new GroupedQueryAttention(gqaConfig, 2);
  const kvFor2 = {
    keys: keys.slice(0, 2),
    values: values.slice(0, 2),
  };

  const gqa2Result = await runner.run(
    'gqa-2-kv-heads',
    async () => {
      gqa2.forward(queries, kvFor2.keys, kvFor2.values);
    },
    { iterations: 50 }
  );

  console.log(`GQA (8 Q, 2 KV heads): ${formatTime(gqa2Result.mean)}`);

  // Memory comparison — single forward pass, heap delta only (GC-noisy).
  console.log('\n--- Memory Usage Comparison ---');

  const memConfig: MHAConfig = {
    seqLength: 512,
    headDim: 64,
    numHeads: 8,
    batchSize: 1,
  };

  const { queries: q512, keys: k512, values: v512 } = createMultiHeadQKV(memConfig);
  const mha512 = new MultiHeadAttention(memConfig);

  const memResult = mha512.forwardSequential(q512, k512, v512);
  console.log(`MHA Memory (seq=512, h=8): ${formatBytes(memResult.memoryUsed)}`);

  // Per-head memory (simple average of the measured heap delta).
  const perHeadMem = memResult.memoryUsed / memConfig.numHeads;
  console.log(`Per-head memory: ${formatBytes(perHeadMem)}`);

  // Theoretical attention matrix size: seqLength^2 float32 scores (4 bytes) per head.
  const attentionMatrixSize = memConfig.seqLength * memConfig.seqLength * 4 * memConfig.numHeads;
  console.log(`Theoretical attention matrices: ${formatBytes(attentionMatrixSize)}`);

  // Summary
  console.log('\n--- Summary ---');
  console.log('Standard MHA vs GQA:');
  console.log(`  8 KV heads: ${formatTime(standardResult.mean)}`);
  console.log(`  4 KV heads: ${formatTime(gqa4Result.mean)} (${(standardResult.mean / gqa4Result.mean).toFixed(2)}x)`);
  console.log(`  2 KV heads: ${formatTime(gqa2Result.mean)} (${(standardResult.mean / gqa2Result.mean).toFixed(2)}x)`);

  // Print full results accumulated by the runner across all benchmarks above.
  runner.printResults();
}
|
|
466
|
+
|
|
467
|
+
// ============================================================================
|
|
468
|
+
// Multi-Head Attention Optimization Strategies
|
|
469
|
+
// ============================================================================
|
|
470
|
+
|
|
471
|
+
/**
 * Catalog of multi-head attention optimization strategies.
 *
 * Each entry pairs a short description and a claimed improvement with an
 * illustrative snippet. The `implementation` fields are reference strings
 * (documentation), not executable code; the `expectedImprovement` figures
 * are commonly cited estimates, not measurements from this suite.
 */
export const mhaOptimizations = {
  /**
   * Parallel head computation
   */
  parallelHeads: {
    description: 'Compute attention heads in parallel',
    expectedImprovement: 'Up to num_heads x speedup',
    implementation: `
async function parallelMHA(queries, keys, values) {
  const headResults = await Promise.all(
    queries.map((q, i) => computeHead(q, keys[i], values[i]))
  );
  return concatenateHeads(headResults);
}
`,
  },

  /**
   * Grouped-Query Attention
   */
  groupedQueryAttention: {
    description: 'Share K/V across multiple query heads',
    expectedImprovement: '2-4x memory, 1.5-2x speed',
    implementation: `
// Instead of numHeads K/V pairs, use numHeads / groupSize
class GQA {
  forward(queries, keys, values) {
    return queries.map((q, i) => {
      const kvIdx = Math.floor(i / groupSize);
      return attention(q, keys[kvIdx], values[kvIdx]);
    });
  }
}
`,
  },

  /**
   * Multi-Query Attention
   */
  multiQueryAttention: {
    description: 'Single K/V pair shared across all heads',
    expectedImprovement: '8x memory, 2-3x speed',
    implementation: `
class MQA {
  forward(queries, key, value) {
    // All heads share single K and V
    return queries.map(q => attention(q, key, value));
  }
}
`,
  },

  /**
   * Fused QKV projection
   */
  fusedQKVProjection: {
    description: 'Fuse Q, K, V projections into single operation',
    expectedImprovement: '20-30% projection overhead',
    implementation: `
function fusedQKV(input, weights) {
  // Single matmul for all QKV
  const qkv = matmul(input, weights.qkv);
  return splitQKV(qkv, numHeads, headDim);
}
`,
  },

  /**
   * KV caching for inference
   */
  kvCaching: {
    description: 'Cache K/V for autoregressive generation',
    expectedImprovement: 'O(1) per token instead of O(n)',
    implementation: `
class CachedMHA {
  private kvCache: { k: Float32Array[], v: Float32Array[] } = { k: [], v: [] };

  forward(query, key, value, useCache: boolean) {
    if (useCache) {
      this.kvCache.k.push(key);
      this.kvCache.v.push(value);
      return attention(query, this.kvCache.k, this.kvCache.v);
    }
    return attention(query, [key], [value]);
  }
}
`,
  },
};
|
|
560
|
+
|
|
561
|
+
// Run if executed directly
|
|
562
|
+
if (import.meta.url === `file://${process.argv[1]}`) {
|
|
563
|
+
runMultiHeadAttentionBenchmarks().catch(console.error);
|
|
564
|
+
}
|
|
565
|
+
|
|
566
|
+
export default runMultiHeadAttentionBenchmarks;
|