@ruvector/attention-wasm 0.1.0 → 2.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary file
package/js/index.ts DELETED
@@ -1,412 +0,0 @@
1
- /**
2
- * TypeScript wrapper for ruvector-attention-wasm
3
- * Provides a clean, type-safe API for attention mechanisms
4
- */
5
-
6
- import init, * as wasm from '../pkg/ruvector_attention_wasm';
7
- import type {
8
- AttentionConfig,
9
- MultiHeadConfig,
10
- HyperbolicConfig,
11
- LinearAttentionConfig,
12
- FlashAttentionConfig,
13
- LocalGlobalConfig,
14
- MoEConfig,
15
- TrainingConfig,
16
- SchedulerConfig,
17
- ExpertStats,
18
- AttentionType,
19
- } from './types';
20
-
21
- export * from './types';
22
-
23
- let initialized = false;
24
-
25
- /**
26
- * Initialize the WASM module
27
- * Must be called before using any attention mechanisms
28
- */
29
- export async function initialize(): Promise<void> {
30
- if (!initialized) {
31
- await init();
32
- initialized = true;
33
- }
34
- }
35
-
36
- /**
37
- * Get the version of the ruvector-attention-wasm package
38
- */
39
- export function version(): string {
40
- return wasm.version();
41
- }
42
-
43
- /**
44
- * Get list of available attention mechanisms
45
- */
46
- export function availableMechanisms(): AttentionType[] {
47
- return wasm.available_mechanisms() as AttentionType[];
48
- }
49
-
50
- /**
51
- * Multi-head attention mechanism
52
- */
53
- export class MultiHeadAttention {
54
- private inner: wasm.WasmMultiHeadAttention;
55
-
56
- constructor(config: MultiHeadConfig) {
57
- this.inner = new wasm.WasmMultiHeadAttention(config.dim, config.numHeads);
58
- }
59
-
60
- /**
61
- * Compute multi-head attention
62
- */
63
- compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
64
- const result = this.inner.compute(query, keys, values);
65
- return new Float32Array(result);
66
- }
67
-
68
- get numHeads(): number {
69
- return this.inner.num_heads;
70
- }
71
-
72
- get dim(): number {
73
- return this.inner.dim;
74
- }
75
-
76
- free(): void {
77
- this.inner.free();
78
- }
79
- }
80
-
81
- /**
82
- * Hyperbolic attention mechanism
83
- */
84
- export class HyperbolicAttention {
85
- private inner: wasm.WasmHyperbolicAttention;
86
-
87
- constructor(config: HyperbolicConfig) {
88
- this.inner = new wasm.WasmHyperbolicAttention(config.dim, config.curvature);
89
- }
90
-
91
- compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
92
- const result = this.inner.compute(query, keys, values);
93
- return new Float32Array(result);
94
- }
95
-
96
- get curvature(): number {
97
- return this.inner.curvature;
98
- }
99
-
100
- free(): void {
101
- this.inner.free();
102
- }
103
- }
104
-
105
- /**
106
- * Linear attention (Performer-style)
107
- */
108
- export class LinearAttention {
109
- private inner: wasm.WasmLinearAttention;
110
-
111
- constructor(config: LinearAttentionConfig) {
112
- this.inner = new wasm.WasmLinearAttention(config.dim, config.numFeatures);
113
- }
114
-
115
- compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
116
- const result = this.inner.compute(query, keys, values);
117
- return new Float32Array(result);
118
- }
119
-
120
- free(): void {
121
- this.inner.free();
122
- }
123
- }
124
-
125
- /**
126
- * Flash attention mechanism
127
- */
128
- export class FlashAttention {
129
- private inner: wasm.WasmFlashAttention;
130
-
131
- constructor(config: FlashAttentionConfig) {
132
- this.inner = new wasm.WasmFlashAttention(config.dim, config.blockSize);
133
- }
134
-
135
- compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
136
- const result = this.inner.compute(query, keys, values);
137
- return new Float32Array(result);
138
- }
139
-
140
- free(): void {
141
- this.inner.free();
142
- }
143
- }
144
-
145
- /**
146
- * Local-global attention mechanism
147
- */
148
- export class LocalGlobalAttention {
149
- private inner: wasm.WasmLocalGlobalAttention;
150
-
151
- constructor(config: LocalGlobalConfig) {
152
- this.inner = new wasm.WasmLocalGlobalAttention(
153
- config.dim,
154
- config.localWindow,
155
- config.globalTokens
156
- );
157
- }
158
-
159
- compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
160
- const result = this.inner.compute(query, keys, values);
161
- return new Float32Array(result);
162
- }
163
-
164
- free(): void {
165
- this.inner.free();
166
- }
167
- }
168
-
169
- /**
170
- * Mixture of Experts attention
171
- */
172
- export class MoEAttention {
173
- private inner: wasm.WasmMoEAttention;
174
-
175
- constructor(config: MoEConfig) {
176
- this.inner = new wasm.WasmMoEAttention(config.dim, config.numExperts, config.topK);
177
- }
178
-
179
- compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
180
- const result = this.inner.compute(query, keys, values);
181
- return new Float32Array(result);
182
- }
183
-
184
- getExpertStats(): ExpertStats {
185
- return this.inner.expert_stats() as ExpertStats;
186
- }
187
-
188
- free(): void {
189
- this.inner.free();
190
- }
191
- }
192
-
193
- /**
194
- * InfoNCE contrastive loss
195
- */
196
- export class InfoNCELoss {
197
- private inner: wasm.WasmInfoNCELoss;
198
-
199
- constructor(temperature: number = 0.07) {
200
- this.inner = new wasm.WasmInfoNCELoss(temperature);
201
- }
202
-
203
- compute(anchor: Float32Array, positive: Float32Array, negatives: Float32Array[]): number {
204
- return this.inner.compute(anchor, positive, negatives);
205
- }
206
-
207
- computeMultiPositive(
208
- anchor: Float32Array,
209
- positives: Float32Array[],
210
- negatives: Float32Array[]
211
- ): number {
212
- return this.inner.compute_multi_positive(anchor, positives, negatives);
213
- }
214
-
215
- free(): void {
216
- this.inner.free();
217
- }
218
- }
219
-
220
- /**
221
- * Adam optimizer
222
- */
223
- export class Adam {
224
- private inner: wasm.WasmAdam;
225
-
226
- constructor(paramCount: number, config: TrainingConfig) {
227
- this.inner = new wasm.WasmAdam(
228
- paramCount,
229
- config.learningRate,
230
- config.beta1,
231
- config.beta2,
232
- config.epsilon
233
- );
234
- }
235
-
236
- step(params: Float32Array, gradients: Float32Array): void {
237
- this.inner.step(params, gradients);
238
- }
239
-
240
- reset(): void {
241
- this.inner.reset();
242
- }
243
-
244
- get learningRate(): number {
245
- return this.inner.learning_rate;
246
- }
247
-
248
- set learningRate(lr: number) {
249
- this.inner.learning_rate = lr;
250
- }
251
-
252
- free(): void {
253
- this.inner.free();
254
- }
255
- }
256
-
257
- /**
258
- * AdamW optimizer (Adam with decoupled weight decay)
259
- */
260
- export class AdamW {
261
- private inner: wasm.WasmAdamW;
262
-
263
- constructor(paramCount: number, config: TrainingConfig) {
264
- if (!config.weightDecay) {
265
- throw new Error('AdamW requires weightDecay parameter');
266
- }
267
-
268
- this.inner = new wasm.WasmAdamW(
269
- paramCount,
270
- config.learningRate,
271
- config.weightDecay,
272
- config.beta1,
273
- config.beta2,
274
- config.epsilon
275
- );
276
- }
277
-
278
- step(params: Float32Array, gradients: Float32Array): void {
279
- this.inner.step(params, gradients);
280
- }
281
-
282
- reset(): void {
283
- this.inner.reset();
284
- }
285
-
286
- get learningRate(): number {
287
- return this.inner.learning_rate;
288
- }
289
-
290
- set learningRate(lr: number) {
291
- this.inner.learning_rate = lr;
292
- }
293
-
294
- get weightDecay(): number {
295
- return this.inner.weight_decay;
296
- }
297
-
298
- free(): void {
299
- this.inner.free();
300
- }
301
- }
302
-
303
- /**
304
- * Learning rate scheduler with warmup and cosine decay
305
- */
306
- export class LRScheduler {
307
- private inner: wasm.WasmLRScheduler;
308
-
309
- constructor(config: SchedulerConfig) {
310
- this.inner = new wasm.WasmLRScheduler(
311
- config.initialLR,
312
- config.warmupSteps,
313
- config.totalSteps
314
- );
315
- }
316
-
317
- getLR(): number {
318
- return this.inner.get_lr();
319
- }
320
-
321
- step(): void {
322
- this.inner.step();
323
- }
324
-
325
- reset(): void {
326
- this.inner.reset();
327
- }
328
-
329
- free(): void {
330
- this.inner.free();
331
- }
332
- }
333
-
334
- /**
335
- * Utility functions
336
- */
337
- export const utils = {
338
- /**
339
- * Compute cosine similarity between two vectors
340
- */
341
- cosineSimilarity(a: Float32Array, b: Float32Array): number {
342
- return wasm.cosine_similarity(a, b);
343
- },
344
-
345
- /**
346
- * Compute L2 norm of a vector
347
- */
348
- l2Norm(vec: Float32Array): number {
349
- return wasm.l2_norm(vec);
350
- },
351
-
352
- /**
353
- * Normalize a vector to unit length (in-place)
354
- */
355
- normalize(vec: Float32Array): void {
356
- wasm.normalize(vec);
357
- },
358
-
359
- /**
360
- * Apply softmax to a vector (in-place)
361
- */
362
- softmax(vec: Float32Array): void {
363
- wasm.softmax(vec);
364
- },
365
-
366
- /**
367
- * Compute attention weights from scores (in-place)
368
- */
369
- attentionWeights(scores: Float32Array, temperature?: number): void {
370
- wasm.attention_weights(scores, temperature);
371
- },
372
-
373
- /**
374
- * Batch normalize vectors
375
- */
376
- batchNormalize(vectors: Float32Array[], epsilon?: number): Float32Array {
377
- const result = wasm.batch_normalize(vectors, epsilon);
378
- return new Float32Array(result);
379
- },
380
-
381
- /**
382
- * Generate random orthogonal matrix
383
- */
384
- randomOrthogonalMatrix(dim: number): Float32Array {
385
- const result = wasm.random_orthogonal_matrix(dim);
386
- return new Float32Array(result);
387
- },
388
-
389
- /**
390
- * Compute pairwise distances between vectors
391
- */
392
- pairwiseDistances(vectors: Float32Array[]): Float32Array {
393
- const result = wasm.pairwise_distances(vectors);
394
- return new Float32Array(result);
395
- },
396
- };
397
-
398
- /**
399
- * Simple scaled dot-product attention (functional API)
400
- */
401
- export function scaledDotAttention(
402
- query: Float32Array,
403
- keys: Float32Array[],
404
- values: Float32Array[],
405
- scale?: number
406
- ): Float32Array {
407
- const result = wasm.scaled_dot_attention(query, keys, values, scale);
408
- return new Float32Array(result);
409
- }
410
-
411
- // Re-export WASM module for advanced usage
412
- export { wasm };
package/js/types.ts DELETED
@@ -1,108 +0,0 @@
1
- /**
2
- * TypeScript type definitions for ruvector-attention-wasm
3
- */
4
-
5
- export interface AttentionConfig {
6
- /** Embedding dimension */
7
- dim: number;
8
- /** Number of attention heads (for multi-head attention) */
9
- numHeads?: number;
10
- /** Dropout probability */
11
- dropout?: number;
12
- /** Scaling factor for attention scores */
13
- scale?: number;
14
- /** Whether to use causal masking */
15
- causal?: boolean;
16
- }
17
-
18
- export interface MultiHeadConfig extends AttentionConfig {
19
- numHeads: number;
20
- }
21
-
22
- export interface HyperbolicConfig extends AttentionConfig {
23
- /** Hyperbolic space curvature */
24
- curvature: number;
25
- }
26
-
27
- export interface LinearAttentionConfig extends AttentionConfig {
28
- /** Number of random features for kernel approximation */
29
- numFeatures: number;
30
- }
31
-
32
- export interface FlashAttentionConfig extends AttentionConfig {
33
- /** Block size for tiling */
34
- blockSize: number;
35
- }
36
-
37
- export interface LocalGlobalConfig extends AttentionConfig {
38
- /** Size of local attention window */
39
- localWindow: number;
40
- /** Number of global attention tokens */
41
- globalTokens: number;
42
- }
43
-
44
- export interface MoEConfig extends AttentionConfig {
45
- /** Number of expert attention mechanisms */
46
- numExperts: number;
47
- /** Number of experts to use per query */
48
- topK: number;
49
- /** Maximum capacity per expert */
50
- expertCapacity?: number;
51
- /** Load balancing coefficient */
52
- balanceCoeff?: number;
53
- }
54
-
55
- export interface TrainingConfig {
56
- /** Learning rate for optimizer */
57
- learningRate: number;
58
- /** Temperature parameter for contrastive loss */
59
- temperature?: number;
60
- /** First moment decay rate (Adam/AdamW) */
61
- beta1?: number;
62
- /** Second moment decay rate (Adam/AdamW) */
63
- beta2?: number;
64
- /** Weight decay coefficient (AdamW) */
65
- weightDecay?: number;
66
- /** Numerical stability constant */
67
- epsilon?: number;
68
- }
69
-
70
- export interface SchedulerConfig {
71
- /** Initial learning rate */
72
- initialLR: number;
73
- /** Number of warmup steps */
74
- warmupSteps: number;
75
- /** Total training steps */
76
- totalSteps: number;
77
- }
78
-
79
- export interface ExpertStats {
80
- /** Number of times each expert was selected */
81
- selectionCounts: number[];
82
- /** Average load per expert */
83
- averageLoad: number[];
84
- /** Load balance factor (lower is better) */
85
- loadBalance: number;
86
- }
87
-
88
- /**
89
- * Attention mechanism types
90
- */
91
- export type AttentionType =
92
- | 'scaled_dot_product'
93
- | 'multi_head'
94
- | 'hyperbolic'
95
- | 'linear'
96
- | 'flash'
97
- | 'local_global'
98
- | 'moe';
99
-
100
- /**
101
- * Optimizer types
102
- */
103
- export type OptimizerType = 'adam' | 'adamw';
104
-
105
- /**
106
- * Loss function types
107
- */
108
- export type LossType = 'info_nce';
package/pkg/LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2025 rUv
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.