@sparkleideas/performance 3.0.0-alpha.10
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +256 -0
- package/__tests__/README.md +242 -0
- package/__tests__/attention.test.ts +516 -0
- package/__tests__/benchmarks.test.ts +515 -0
- package/benchmarks/attention/memory-efficiency.bench.ts +569 -0
- package/benchmarks/attention/multi-head-attention.bench.ts +566 -0
- package/benchmarks/startup/agent-spawn.bench.ts +422 -0
- package/benchmarks/startup/cli-cold-start.bench.ts +327 -0
- package/benchmarks/startup/cli-warm-start.bench.ts +277 -0
- package/benchmarks/startup/mcp-server-init.bench.ts +380 -0
- package/docs/ATTENTION.md +277 -0
- package/package.json +29 -0
- package/src/attention-benchmarks.ts +459 -0
- package/src/attention-integration.ts +507 -0
- package/src/examples/flash-attention-demo.ts +160 -0
- package/src/examples/quick-test.ts +62 -0
- package/src/framework/benchmark.ts +583 -0
- package/src/index.ts +63 -0
- package/tmp.json +0 -0
- package/tsconfig.json +9 -0
- package/vitest.config.ts +31 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @sparkleideas/performance - Flash Attention Integration
|
|
3
|
+
*
|
|
4
|
+
* Integrates @ruvector/attention Flash Attention capabilities into V3 performance module.
|
|
5
|
+
* Provides optimized attention mechanisms with 2.49x-7.47x speedup targets.
|
|
6
|
+
*
|
|
7
|
+
* Features:
|
|
8
|
+
* - Flash Attention for memory-efficient processing
|
|
9
|
+
* - Automatic runtime selection (NAPI/WASM/JS)
|
|
10
|
+
* - Performance benchmarking and metrics
|
|
11
|
+
* - Speedup tracking and validation
|
|
12
|
+
*/
|
|
13
|
+
|
|
14
|
+
import {
|
|
15
|
+
FlashAttention,
|
|
16
|
+
DotProductAttention,
|
|
17
|
+
type BenchmarkResult as AttentionBenchmarkResult,
|
|
18
|
+
type ArrayInput,
|
|
19
|
+
} from '@ruvector/attention';
|
|
20
|
+
|
|
21
|
+
// ============================================================================
|
|
22
|
+
// Types
|
|
23
|
+
// ============================================================================
|
|
24
|
+
|
|
25
|
+
/**
 * Input for a single attention computation.
 *
 * Each vector may be given either as a `Float32Array` or as a plain number
 * array; `FlashAttentionOptimizer.optimize` converts plain arrays to
 * `Float32Array` before computing.
 */
export interface AttentionInput {
  /** The query vector. */
  query: Float32Array | number[];
  /** The key vectors, one entry per key. */
  keys: Float32Array[] | number[][];
  /** The value vectors, one entry per value. */
  values: Float32Array[] | number[][];
  /**
   * Optional vector dimension.
   * NOTE(review): `FlashAttentionOptimizer.optimize` in this file does not read
   * this field — confirm whether another consumer does.
   */
  dim?: number;
  /**
   * Optional Flash Attention block size.
   * NOTE(review): `FlashAttentionOptimizer.optimize` in this file does not read
   * this field — confirm whether another consumer does.
   */
  blockSize?: number;
}
|
|
32
|
+
|
|
33
|
+
/**
 * Result of a single `FlashAttentionOptimizer.optimize` call.
 */
export interface AttentionOutput {
  /** The attention output vector. */
  result: Float32Array;
  /** Runtime label reported for this call. */
  runtime: 'napi' | 'wasm' | 'js';
  /** Wall-clock duration of the call, in milliseconds. */
  executionTimeMs: number;
  /** Heap growth observed across the call, in bytes; omitted when no growth was observed. */
  memoryUsageBytes?: number;
}
|
|
39
|
+
|
|
40
|
+
/**
 * Outcome of one Flash Attention vs. baseline benchmark run.
 */
export interface BenchmarkResult {
  /** Measurements for the Flash Attention implementation. */
  flashAttention: {
    /** Mean time per computation, in milliseconds. */
    averageTimeMs: number;
    /** Computations per second derived from the mean time. */
    opsPerSecond: number;
    /** Peak heap growth sampled during the run, in bytes. */
    memoryUsageBytes?: number;
  };
  /** Measurements for the baseline (dot-product) implementation. */
  baseline: {
    /** Mean time per computation, in milliseconds. */
    averageTimeMs: number;
    /** Computations per second derived from the mean time. */
    opsPerSecond: number;
    /** Peak heap growth sampled during the run, in bytes. */
    memoryUsageBytes?: number;
  };
  /** Baseline mean time divided by Flash Attention mean time. */
  speedup: number;
  /** `true` when `speedup` is at least 2.49x. */
  meetsTarget: boolean;
  /** When the result object was created. */
  timestamp: Date;
}
|
|
55
|
+
|
|
56
|
+
/**
 * Aggregated statistics returned by `FlashAttentionOptimizer.getMetrics`.
 */
export interface PerformanceMetrics {
  /** Number of recorded operations. */
  totalOperations: number;
  /** Mean speedup factor. */
  averageSpeedup: number;
  /** Largest speedup factor observed. */
  peakSpeedup: number;
  /** Mean execution time per operation, in milliseconds. */
  averageExecutionTimeMs: number;
  /** Accumulated baseline-minus-optimized memory, in bytes (never negative). */
  totalMemorySavedBytes: number;
  /** Percentage (0-100) of operations meeting the speedup target. */
  successRate: number;

  // Memory tracking metrics

  /** Accumulated memory attributed to the baseline implementation, in bytes. */
  baselineMemoryBytes: number;
  /** Accumulated memory attributed to Flash Attention, in bytes. */
  optimizedMemoryBytes: number;
  /** Same value as `totalMemorySavedBytes`. */
  memorySavedBytes: number;
  /** `memorySavedBytes` as a percentage of `baselineMemoryBytes` (0 when the baseline is 0). */
  memorySavedPercent: number;
  /** Highest heap reading sampled during Flash Attention runs, in bytes. */
  peakMemoryBytes: number;
}
|
|
70
|
+
|
|
71
|
+
// ============================================================================
|
|
72
|
+
// Flash Attention Optimizer
|
|
73
|
+
// ============================================================================
|
|
74
|
+
|
|
75
|
+
/**
 * Wraps a Flash Attention instance and a dot-product baseline, runs attention
 * computations and benchmarks against them, and accumulates timing, speedup
 * and heap-usage statistics.
 */
export class FlashAttentionOptimizer {
  /** Flash Attention implementation, built for the configured `dim` / `blockSize`. */
  private flashAttention: FlashAttention;
  /** Dot-product attention used as the comparison baseline. */
  private baselineAttention: DotProductAttention;
  /** Running totals behind `getSpeedup()` and `getMetrics()`. */
  private metrics: {
    /** Number of `optimize()` calls. */
    operations: number;
    /** Number of `benchmark()` calls; denominator for speedup average and success rate. */
    benchmarkRuns: number;
    /** Sum of the speedups measured by `benchmark()`. */
    totalSpeedup: number;
    /** Largest speedup measured by `benchmark()`. */
    peakSpeedup: number;
    /** Sum of `optimize()` execution times, in milliseconds. */
    totalExecutionTime: number;
    /** Number of `benchmark()` runs that reached the 2.49x target. */
    successfulOperations: number;
    // Memory tracking
    totalBaselineMemory: number;
    totalOptimizedMemory: number;
    peakMemory: number;
  };

  /**
   * @param dim - Dimension of the attention vectors (default: 512)
   * @param blockSize - Block size passed to Flash Attention (default: 64)
   */
  constructor(
    private readonly dim: number = 512,
    private readonly blockSize: number = 64
  ) {
    this.flashAttention = new FlashAttention(dim, blockSize);
    this.baselineAttention = new DotProductAttention(dim);
    this.metrics = FlashAttentionOptimizer.emptyMetrics();
  }

  /**
   * Optimize attention computation using Flash Attention.
   *
   * Plain number arrays are converted to `Float32Array` first. The reported
   * execution time covers the conversion and the initial heap reading as well
   * as the attention call itself. `input.dim` and `input.blockSize` are not
   * consulted; the instance's constructor arguments apply.
   *
   * @param input - Query, keys, and values for attention computation
   * @returns Attention output together with timing and heap-growth figures
   */
  optimize(input: AttentionInput): AttentionOutput {
    const startTime = performance.now();
    const startMemory = this.getMemoryUsage();

    // Convert inputs if needed
    const query = this.ensureFloat32Array(input.query);
    const keys = input.keys.map(k => this.ensureFloat32Array(k));
    const values = input.values.map(v => this.ensureFloat32Array(v));

    // Use synchronous Flash Attention with raw Float32Arrays
    const result = this.flashAttention.computeRaw(query, keys, values);

    const executionTimeMs = performance.now() - startTime;
    const memoryUsageBytes = this.getMemoryUsage() - startMemory;

    this.metrics.operations++;
    this.metrics.totalExecutionTime += executionTimeMs;

    return {
      result,
      runtime: this.detectRuntime(),
      executionTimeMs,
      // A non-positive delta (e.g. the heap shrank mid-call) is reported as "unknown".
      memoryUsageBytes: memoryUsageBytes > 0 ? memoryUsageBytes : undefined,
    };
  }

  /**
   * Benchmark Flash Attention vs baseline attention.
   *
   * Runs 1000 iterations of each implementation over the same random input
   * (100 keys/values of the configured dimension), baseline first, and records
   * the resulting speedup and heap growth in the accumulated metrics.
   *
   * @returns Benchmark results with speedup metrics
   */
  benchmark(): BenchmarkResult {
    const numKeys = 100;
    const iterations = 1000;
    const sampleEvery = 100; // sample the heap rarely to keep timing overhead low

    const { query, keys, values } = FlashAttentionOptimizer.randomInput(this.dim, numKeys);

    const baselineRun = this.measure(
      () => {
        this.baselineAttention.computeRaw(query, keys, values);
      },
      iterations,
      sampleEvery
    );
    const flashRun = this.measure(
      () => {
        this.flashAttention.computeRaw(query, keys, values);
      },
      iterations,
      sampleEvery
    );

    const baselineAvgMs = baselineRun.elapsedMs / iterations;
    const flashAvgMs = flashRun.elapsedMs / iterations;
    const speedup = baselineAvgMs / flashAvgMs;
    const meetsTarget = speedup >= 2.49; // Minimum target: 2.49x

    // Speedup statistics are counted per benchmark run, independently of the
    // number of optimize() calls, so that averages are not diluted by them.
    this.metrics.benchmarkRuns++;
    this.metrics.totalSpeedup += speedup;
    if (speedup > this.metrics.peakSpeedup) {
      this.metrics.peakSpeedup = speedup;
    }
    if (meetsTarget) {
      this.metrics.successfulOperations++;
    }

    this.recordMemory(baselineRun.memoryUsed, flashRun.memoryUsed, flashRun.peakMemory);

    return {
      flashAttention: {
        averageTimeMs: flashAvgMs,
        opsPerSecond: 1000 / flashAvgMs,
        memoryUsageBytes: flashRun.memoryUsed,
      },
      baseline: {
        averageTimeMs: baselineAvgMs,
        opsPerSecond: 1000 / baselineAvgMs,
        memoryUsageBytes: baselineRun.memoryUsed,
      },
      speedup,
      meetsTarget,
      timestamp: new Date(),
    };
  }

  /**
   * Get current speedup factor from accumulated metrics.
   * @returns Mean speedup across all `benchmark()` runs, or 0 if none has run
   */
  getSpeedup(): number {
    const { benchmarkRuns, totalSpeedup } = this.metrics;
    return benchmarkRuns === 0 ? 0 : totalSpeedup / benchmarkRuns;
  }

  /**
   * Get comprehensive performance metrics.
   *
   * `totalOperations` and `averageExecutionTimeMs` describe `optimize()` calls;
   * `averageSpeedup`, `peakSpeedup` and `successRate` describe `benchmark()`
   * runs; the memory figures accumulate over `benchmark()` and
   * `benchmarkMemory()`.
   *
   * @returns Detailed performance statistics
   */
  getMetrics(): PerformanceMetrics {
    const m = this.metrics;

    // Calculate memory savings (clamped so a regression never reports negative savings)
    const memorySaved = Math.max(0, m.totalBaselineMemory - m.totalOptimizedMemory);
    const memorySavedPercent =
      m.totalBaselineMemory > 0 ? (memorySaved / m.totalBaselineMemory) * 100 : 0;

    return {
      totalOperations: m.operations,
      averageSpeedup: this.getSpeedup(),
      peakSpeedup: m.peakSpeedup,
      averageExecutionTimeMs: m.operations > 0 ? m.totalExecutionTime / m.operations : 0,
      totalMemorySavedBytes: memorySaved,
      successRate: m.benchmarkRuns > 0 ? (m.successfulOperations / m.benchmarkRuns) * 100 : 0,
      // Memory tracking metrics
      baselineMemoryBytes: m.totalBaselineMemory,
      optimizedMemoryBytes: m.totalOptimizedMemory,
      memorySavedBytes: memorySaved,
      memorySavedPercent,
      peakMemoryBytes: m.peakMemory,
    };
  }

  /**
   * Reset all metrics
   */
  resetMetrics(): void {
    this.metrics = FlashAttentionOptimizer.emptyMetrics();
  }

  /**
   * Benchmark memory usage across multiple dimensions.
   *
   * For each dimension, fresh Flash Attention and baseline instances are built
   * and run for 100 iterations over the same random input while the heap is
   * sampled every 10 iterations. Results are also added to the accumulated
   * memory metrics.
   *
   * @param dimensions - Array of dimensions to test (default: [128, 256, 512, 1024])
   * @returns Memory profiling results for each dimension
   */
  benchmarkMemory(
    dimensions: number[] = [128, 256, 512, 1024]
  ): {
    dimension: number;
    baselineMemoryBytes: number;
    optimizedMemoryBytes: number;
    memorySavedBytes: number;
    memorySavedPercent: number;
    meetsTarget: boolean; // true if at least 50% reduction achieved
  }[] {
    const numKeys = 100;
    const iterations = 100; // Fewer iterations for memory profiling
    const sampleEvery = 10;

    const results: ReturnType<FlashAttentionOptimizer['benchmarkMemory']> = [];

    for (const dim of dimensions) {
      const { query, keys, values } = FlashAttentionOptimizer.randomInput(dim, numKeys);

      // Create temporary instances for this dimension
      const flashAttention = new FlashAttention(dim, this.blockSize);
      const baselineAttention = new DotProductAttention(dim);

      const baselineRun = this.measure(
        () => {
          baselineAttention.computeRaw(query, keys, values);
        },
        iterations,
        sampleEvery
      );
      const flashRun = this.measure(
        () => {
          flashAttention.computeRaw(query, keys, values);
        },
        iterations,
        sampleEvery
      );

      const memorySaved = Math.max(0, baselineRun.memoryUsed - flashRun.memoryUsed);
      const memorySavedPercent =
        baselineRun.memoryUsed > 0 ? (memorySaved / baselineRun.memoryUsed) * 100 : 0;

      // Target: 50-75% memory reduction; anything from 50% upwards counts as met.
      const meetsTarget = memorySavedPercent >= 50 && memorySavedPercent <= 100;

      results.push({
        dimension: dim,
        baselineMemoryBytes: baselineRun.memoryUsed,
        optimizedMemoryBytes: flashRun.memoryUsed,
        memorySavedBytes: memorySaved,
        memorySavedPercent,
        meetsTarget,
      });

      this.recordMemory(baselineRun.memoryUsed, flashRun.memoryUsed, flashRun.peakMemory);
    }

    return results;
  }

  /** Build a zeroed metrics record. */
  private static emptyMetrics() {
    return {
      operations: 0,
      benchmarkRuns: 0,
      totalSpeedup: 0,
      peakSpeedup: 0,
      totalExecutionTime: 0,
      successfulOperations: 0,
      totalBaselineMemory: 0,
      totalOptimizedMemory: 0,
      peakMemory: 0,
    };
  }

  /** Create one query plus `numKeys` keys and values of length `dim`, filled with `Math.random()`. */
  private static randomInput(
    dim: number,
    numKeys: number
  ): { query: Float32Array; keys: Float32Array[]; values: Float32Array[] } {
    const randomVector = (): Float32Array => {
      const vector = new Float32Array(dim);
      for (let i = 0; i < dim; i++) {
        vector[i] = Math.random();
      }
      return vector;
    };
    return {
      query: randomVector(),
      keys: Array.from({ length: numKeys }, randomVector),
      values: Array.from({ length: numKeys }, randomVector),
    };
  }

  /**
   * Run `run` `iterations` times after an optional GC, timing the loop and
   * sampling the heap every `sampleEvery` iterations.
   */
  private measure(
    run: () => void,
    iterations: number,
    sampleEvery: number
  ): { elapsedMs: number; peakMemory: number; memoryUsed: number } {
    // Force garbage collection if available for accurate memory measurement
    this.forceGC();

    const memoryBefore = this.getMemoryUsage();
    let peakMemory = memoryBefore;

    const start = performance.now();
    for (let i = 0; i < iterations; i++) {
      run();
      if (i % sampleEvery === 0) {
        const current = this.getMemoryUsage();
        if (current > peakMemory) {
          peakMemory = current;
        }
      }
    }
    const elapsedMs = performance.now() - start;

    return { elapsedMs, peakMemory, memoryUsed: Math.max(0, peakMemory - memoryBefore) };
  }

  /** Add one baseline/optimized measurement pair to the accumulated memory metrics. */
  private recordMemory(baselineUsed: number, optimizedUsed: number, optimizedPeak: number): void {
    this.metrics.totalBaselineMemory += baselineUsed;
    this.metrics.totalOptimizedMemory += optimizedUsed;
    if (optimizedPeak > this.metrics.peakMemory) {
      this.metrics.peakMemory = optimizedPeak;
    }
  }

  /**
   * Ensure input is Float32Array; existing Float32Arrays are returned as-is
   * (not copied).
   */
  private ensureFloat32Array(input: ArrayInput): Float32Array {
    return input instanceof Float32Array ? input : new Float32Array(input);
  }

  /**
   * Report a runtime label based on the host environment.
   *
   * NOTE(review): this only inspects `process.versions` and `globalThis`; it
   * does not ask the attention library which backend it actually loaded —
   * confirm the label reflects the backend in use.
   */
  private detectRuntime(): 'napi' | 'wasm' | 'js' {
    // Check if NAPI bindings are available
    try {
      if (typeof process !== 'undefined' && process.versions && 'napi' in process.versions) {
        return 'napi';
      }
    } catch {
      // Not in Node.js environment
    }

    // Check for WebAssembly support
    if (typeof globalThis !== 'undefined' && 'WebAssembly' in globalThis) {
      return 'wasm';
    }

    // Fallback to pure JS
    return 'js';
  }

  /**
   * Get current heap usage in bytes, or 0 when `process.memoryUsage` is unavailable.
   */
  private getMemoryUsage(): number {
    if (typeof process !== 'undefined' && process.memoryUsage) {
      return process.memoryUsage().heapUsed;
    }
    return 0;
  }

  /**
   * Force garbage collection if available (requires --expose-gc flag).
   * This helps get more accurate memory measurements; it is a no-op otherwise.
   */
  private forceGC(): void {
    const host = globalThis as { gc?: unknown };
    if (typeof host.gc === 'function') {
      host.gc();
    }
  }
}
|
|
471
|
+
|
|
472
|
+
// ============================================================================
|
|
473
|
+
// Convenience Factory Functions
|
|
474
|
+
// ============================================================================
|
|
475
|
+
|
|
476
|
+
/**
 * Factory for a `FlashAttentionOptimizer`; both arguments are forwarded
 * unchanged to its constructor.
 * @param dim - Dimension of attention vectors (default: 512)
 * @param blockSize - Block size for Flash Attention (default: 64)
 * @returns A new FlashAttentionOptimizer instance
 */
export function createFlashAttentionOptimizer(
  dim: number = 512,
  blockSize: number = 64
): FlashAttentionOptimizer {
  const optimizer = new FlashAttentionOptimizer(dim, blockSize);
  return optimizer;
}
|
|
488
|
+
|
|
489
|
+
/**
 * One-shot benchmark: builds a throwaway optimizer for `dim` (default block
 * size) and runs its `benchmark()` once.
 * @param dim - Dimension to test (default: 512)
 * @returns Benchmark results with speedup metrics
 */
export function quickBenchmark(dim: number = 512): BenchmarkResult {
  return createFlashAttentionOptimizer(dim).benchmark();
}
|
|
498
|
+
|
|
499
|
+
// ============================================================================
|
|
500
|
+
// Exports
|
|
501
|
+
// ============================================================================
|
|
502
|
+
|
|
503
|
+
export {
|
|
504
|
+
FlashAttention,
|
|
505
|
+
DotProductAttention,
|
|
506
|
+
type AttentionBenchmarkResult,
|
|
507
|
+
};
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Flash Attention Integration Demo
|
|
3
|
+
*
|
|
4
|
+
* Demonstrates how to use the Flash Attention integration in V3 performance module.
|
|
5
|
+
*/
|
|
6
|
+
|
|
7
|
+
import { pathToFileURL } from 'node:url';

import {
  FlashAttentionOptimizer,
  createFlashAttentionOptimizer,
  quickBenchmark,
  AttentionBenchmarkRunner,
  quickValidation,
  runAndDisplaySuite,
  type AttentionInput,
} from '../index.js';
|
|
16
|
+
|
|
17
|
+
// ============================================================================
|
|
18
|
+
// Example 1: Basic Flash Attention Usage
|
|
19
|
+
// ============================================================================
|
|
20
|
+
|
|
21
|
+
/**
 * Example 1: run one attention computation over all-ones vectors
 * (512 dimensions, 100 keys/values) and print what came back.
 */
async function basicUsageExample(): Promise<void> {
  console.log('\n=== Example 1: Basic Flash Attention Usage ===\n');

  // Optimizer for 512-dimensional vectors with a block size of 64
  const optimizer = createFlashAttentionOptimizer(512, 64);

  // Input data: every component set to 1.0
  const dim = 512;
  const numKeys = 100;
  const onesVector = (): Float32Array => new Float32Array(dim).fill(1.0);
  const input: AttentionInput = {
    query: onesVector(),
    keys: Array.from({ length: numKeys }, onesVector),
    values: Array.from({ length: numKeys }, onesVector),
  };

  const { runtime, executionTimeMs, result, memoryUsageBytes } = await optimizer.optimize(input);

  // Memory is only reported when a positive heap delta was observed
  let memoryLabel = 'N/A';
  if (memoryUsageBytes) {
    memoryLabel = `${(memoryUsageBytes / 1024).toFixed(2)} KB`;
  }

  console.log(`Runtime: ${runtime}`);
  console.log(`Execution time: ${executionTimeMs.toFixed(3)}ms`);
  console.log(`Result shape: Float32Array[${result.length}]`);
  console.log(`Memory usage: ${memoryLabel}`);
}
|
|
45
|
+
|
|
46
|
+
// ============================================================================
|
|
47
|
+
// Example 2: Performance Benchmarking
|
|
48
|
+
// ============================================================================
|
|
49
|
+
|
|
50
|
+
/**
 * Example 2: run the quick benchmark at 512 dimensions and print the
 * per-implementation timings, the speedup and the target verdict.
 */
async function benchmarkExample(): Promise<void> {
  console.log('\n=== Example 2: Performance Benchmarking ===\n');

  const { flashAttention, baseline, speedup, meetsTarget } = await quickBenchmark(512);
  const verdict = meetsTarget ? 'YES ✓' : 'NO ✗';

  console.log(`Flash Attention: ${flashAttention.averageTimeMs.toFixed(3)}ms`);
  console.log(`Baseline: ${baseline.averageTimeMs.toFixed(3)}ms`);
  console.log(`Speedup: ${speedup.toFixed(2)}x`);
  console.log(`Meets target (≥2.49x): ${verdict}`);
}
|
|
61
|
+
|
|
62
|
+
// ============================================================================
|
|
63
|
+
// Example 3: Comprehensive Suite
|
|
64
|
+
// ============================================================================
|
|
65
|
+
|
|
66
|
+
/**
 * Example 3: run the comprehensive benchmark suite and print its summary.
 */
async function comprehensiveSuiteExample(): Promise<void> {
  console.log('\n=== Example 3: Comprehensive Benchmark Suite ===\n');

  const runner = new AttentionBenchmarkRunner();
  const { suiteName, summary } = await runner.runComprehensiveSuite();

  console.log(`Suite: ${suiteName}`);
  console.log(`Benchmarks run: ${summary.totalBenchmarks}`);
  console.log(`Average speedup: ${summary.averageSpeedup.toFixed(2)}x`);
  console.log(`Min speedup: ${summary.minSpeedup.toFixed(2)}x`);
  console.log(`Max speedup: ${summary.maxSpeedup.toFixed(2)}x`);
  console.log(`Success rate: ${summary.successRate.toFixed(1)}%`);
}
|
|
79
|
+
|
|
80
|
+
// ============================================================================
|
|
81
|
+
// Example 4: V3 Target Validation
|
|
82
|
+
// ============================================================================
|
|
83
|
+
|
|
84
|
+
/**
 * Example 4: run the quick validation and print whether it passed.
 */
async function targetValidationExample(): Promise<void> {
  console.log('\n=== Example 4: V3 Target Validation ===\n');

  const passed = await quickValidation();
  const verdict = passed ? 'PASSED ✓' : 'FAILED ✗';

  console.log(`\nValidation result: ${verdict}`);
}
|
|
91
|
+
|
|
92
|
+
// ============================================================================
|
|
93
|
+
// Example 5: Continuous Metrics Tracking
|
|
94
|
+
// ============================================================================
|
|
95
|
+
|
|
96
|
+
/**
 * Example 5: run ten attention operations and one benchmark on the same
 * optimizer, then print the metrics it accumulated.
 */
async function metricsTrackingExample(): Promise<void> {
  console.log('\n=== Example 5: Continuous Metrics Tracking ===\n');

  const optimizer = createFlashAttentionOptimizer(512);

  // All-ones input: 512 dimensions, 100 keys/values
  const dim = 512;
  const onesVector = (): Float32Array => new Float32Array(dim).fill(1.0);
  const input: AttentionInput = {
    query: onesVector(),
    keys: Array.from({ length: 100 }, onesVector),
    values: Array.from({ length: 100 }, onesVector),
  };

  console.log('Running 10 operations...\n');

  let remaining = 10;
  while (remaining > 0) {
    await optimizer.optimize(input);
    remaining -= 1;
  }

  // The benchmark is what feeds the speedup figures
  await optimizer.benchmark();

  const {
    totalOperations,
    averageSpeedup,
    peakSpeedup,
    averageExecutionTimeMs,
    successRate,
  } = optimizer.getMetrics();

  console.log(`Total operations: ${totalOperations}`);
  console.log(`Average speedup: ${averageSpeedup.toFixed(2)}x`);
  console.log(`Peak speedup: ${peakSpeedup.toFixed(2)}x`);
  console.log(`Average execution time: ${averageExecutionTimeMs.toFixed(3)}ms`);
  console.log(`Success rate: ${successRate.toFixed(1)}%`);
}
|
|
127
|
+
|
|
128
|
+
// ============================================================================
|
|
129
|
+
// Main Demo Runner
|
|
130
|
+
// ============================================================================
|
|
131
|
+
|
|
132
|
+
/**
 * Run the five examples one after another, in order. If any of them throws,
 * the error is logged and the process exits with status 1.
 */
async function runAllExamples(): Promise<void> {
  const examples = [
    basicUsageExample,
    benchmarkExample,
    comprehensiveSuiteExample,
    targetValidationExample,
    metricsTrackingExample,
  ];

  try {
    // Sequential on purpose: each example finishes before the next one starts
    for (const runExample of examples) {
      await runExample();
    }

    console.log('\n=== All Examples Completed ===\n');
  } catch (error) {
    console.error('Error running examples:', error);
    process.exit(1);
  }
}
|
|
146
|
+
|
|
147
|
+
// Run the demo only when this module is the script Node was started with.
// The script path is converted with pathToFileURL() rather than by prefixing
// "file://", so that Windows paths and paths containing characters that need
// percent-encoding compare equal to import.meta.url.
const launchedScript = process.argv[1];
if (launchedScript !== undefined && import.meta.url === pathToFileURL(launchedScript).href) {
  // runAllExamples() handles its own errors; `void` marks the promise as intentionally not awaited.
  void runAllExamples();
}
|
|
151
|
+
|
|
152
|
+
// Export for programmatic use
|
|
153
|
+
export {
|
|
154
|
+
basicUsageExample,
|
|
155
|
+
benchmarkExample,
|
|
156
|
+
comprehensiveSuiteExample,
|
|
157
|
+
targetValidationExample,
|
|
158
|
+
metricsTrackingExample,
|
|
159
|
+
runAllExamples,
|
|
160
|
+
};
|