@ruvector/attention-wasm 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +21 -0
- package/README.md +193 -0
- package/js/index.ts +412 -0
- package/js/types.ts +108 -0
- package/package.json +45 -0
- package/pkg/LICENSE +21 -0
- package/pkg/README.md +193 -0
package/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 rUv

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
package/README.md
ADDED
@@ -0,0 +1,193 @@
# ruvector-attention-wasm

WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.

## Features

- **Multiple Attention Mechanisms**:
  - Scaled Dot-Product Attention
  - Multi-Head Attention
  - Hyperbolic Attention (for hierarchical data)
  - Linear Attention (Performer-style)
  - Flash Attention (memory-efficient)
  - Local-Global Attention
  - Mixture of Experts (MoE) Attention

- **Training Utilities**:
  - InfoNCE contrastive loss
  - Adam optimizer
  - AdamW optimizer (with decoupled weight decay)
  - Learning rate scheduler (warmup + cosine decay)

- **TypeScript Support**: Full type definitions and modern API

## Installation

```bash
npm install @ruvector/attention-wasm
```

## Usage

### TypeScript/JavaScript

```typescript
import { initialize, MultiHeadAttention, utils } from '@ruvector/attention-wasm';

// Initialize the WASM module (required before any other call)
await initialize();

// Create multi-head attention
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });

// Prepare inputs
const query = new Float32Array(64);
const keys = [new Float32Array(64), new Float32Array(64)];
const values = [new Float32Array(64), new Float32Array(64)];

// Compute attention
const output = attention.compute(query, keys, values);

// Use utilities
const similarity = utils.cosineSimilarity(query, keys[0]);
```

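Each mechanism class wraps an object allocated on the WASM side, so its memory is not reclaimed by the JavaScript garbage collector alone. Based on the `free()` methods exposed in `js/index.ts`, a minimal cleanup sketch (the `try`/`finally` framing here is illustrative, not a requirement of the API):

```typescript
import { initialize, MultiHeadAttention } from '@ruvector/attention-wasm';

await initialize();

const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });
try {
  const query = new Float32Array(64);
  const keys = [new Float32Array(64)];
  const values = [new Float32Array(64)];
  const output = attention.compute(query, keys, values);
  console.log(output.length); // expected to match dim
} finally {
  // Release the WASM-side allocation explicitly
  attention.free();
}
```
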
### Advanced Examples

#### Hyperbolic Attention

```typescript
import { HyperbolicAttention } from '@ruvector/attention-wasm';

const hyperbolic = new HyperbolicAttention({
  dim: 128,
  curvature: 1.0
});

const output = hyperbolic.compute(query, keys, values);
```

#### MoE Attention with Expert Stats

```typescript
import { MoEAttention } from '@ruvector/attention-wasm';

const moe = new MoEAttention({
  dim: 64,
  numExperts: 4,
  topK: 2
});

const output = moe.compute(query, keys, values);

// Get expert utilization
const stats = moe.getExpertStats();
console.log('Load balance:', stats.loadBalance);
```

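The object returned by `getExpertStats()` is typed by the `ExpertStats` interface in `js/types.ts`, so all three fields are available. A self-contained sketch of reading them:

```typescript
import { initialize, MoEAttention } from '@ruvector/attention-wasm';

await initialize();

const moe = new MoEAttention({ dim: 64, numExperts: 4, topK: 2 });

// Fields per the ExpertStats interface in js/types.ts
const stats = moe.getExpertStats();
console.log(stats.selectionCounts); // number[] - how often each expert was selected
console.log(stats.averageLoad);     // number[] - average load per expert
console.log(stats.loadBalance);     // number   - load balance factor (lower is better)

moe.free();
```
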
#### Training with InfoNCE Loss

```typescript
import { InfoNCELoss, Adam } from '@ruvector/attention-wasm';

const loss = new InfoNCELoss(0.07);
const optimizer = new Adam(paramCount, {
  learningRate: 0.001,
  beta1: 0.9,
  beta2: 0.999,
});

// Training loop (anchor, positive, negatives, params, gradients, and
// paramCount come from your own model and data pipeline)
const lossValue = loss.compute(anchor, positive, negatives);
optimizer.step(params, gradients);
```

#### Learning Rate Scheduling

```typescript
import { LRScheduler, AdamW } from '@ruvector/attention-wasm';

const scheduler = new LRScheduler({
  initialLR: 0.001,
  warmupSteps: 1000,
  totalSteps: 10000,
});

const optimizer = new AdamW(paramCount, {
  learningRate: scheduler.getLR(),
  weightDecay: 0.01,
});

// Training loop
for (let step = 0; step < 10000; step++) {
  optimizer.learningRate = scheduler.getLR();
  optimizer.step(params, gradients);
  scheduler.step();
}
```

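The exact schedule is computed on the Rust side; per the package description it rises over the first `warmupSteps` steps and then follows a cosine decay through `totalSteps`. A small probe sketch (the sampling interval is arbitrary) to observe the curve:

```typescript
import { initialize, LRScheduler } from '@ruvector/attention-wasm';

await initialize();

const scheduler = new LRScheduler({ initialLR: 0.001, warmupSteps: 1000, totalSteps: 10000 });

for (let step = 0; step < 10000; step++) {
  if (step % 2000 === 0) {
    // Typically climbs toward initialLR during warmup, then decays
    console.log(`step ${step}: lr = ${scheduler.getLR()}`);
  }
  scheduler.step();
}
scheduler.free();
```
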
## Building from Source

### Prerequisites

- Rust 1.70+
- wasm-pack

### Build Commands

```bash
# Build for web (ES modules)
wasm-pack build --target web --out-dir pkg

# Build for Node.js
wasm-pack build --target nodejs --out-dir pkg-node

# Build for bundlers (webpack, vite, etc.)
wasm-pack build --target bundler --out-dir pkg-bundler

# Run tests
wasm-pack test --headless --firefox
```

## API Reference

### Attention Mechanisms

- `MultiHeadAttention` - Standard multi-head attention
- `HyperbolicAttention` - Attention in hyperbolic space
- `LinearAttention` - Linear-complexity attention (Performer)
- `FlashAttention` - Memory-efficient attention
- `LocalGlobalAttention` - Combined local and global attention
- `MoEAttention` - Mixture of Experts attention
- `scaledDotAttention()` - Functional API for basic attention (see the sketch below)

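For one-off computations, the functional API avoids constructing a class at all. A minimal sketch based on the `scaledDotAttention` signature in `js/index.ts` (when the optional `scale` argument is omitted, the default is decided on the Rust side, commonly 1/sqrt(dim)):

```typescript
import { initialize, scaledDotAttention } from '@ruvector/attention-wasm';

await initialize();

const dim = 8;
const query = new Float32Array(dim).fill(0.1);
const keys = [new Float32Array(dim).fill(0.2), new Float32Array(dim).fill(0.3)];
const values = [new Float32Array(dim).fill(1.0), new Float32Array(dim).fill(2.0)];

// Omitting `scale` falls back to the WASM-side default
const out = scaledDotAttention(query, keys, values);
console.log(out.length); // expected to match dim
```
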
### Training

- `InfoNCELoss` - Contrastive loss function
- `Adam` - Adam optimizer
- `AdamW` - AdamW optimizer with decoupled weight decay
- `LRScheduler` - Learning rate scheduler (warmup + cosine decay)

### Utilities

- `utils.cosineSimilarity()` - Cosine similarity between vectors
- `utils.l2Norm()` - L2 norm of a vector
- `utils.normalize()` - Normalize vector to unit length (in place; see the sketch below)
- `utils.softmax()` - Apply softmax transformation (in place)
- `utils.attentionWeights()` - Compute attention weights from scores (in place)
- `utils.batchNormalize()` - Batch normalization
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
- `utils.pairwiseDistances()` - Compute pairwise distances

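Note the mutation semantics: per the JSDoc in `js/index.ts`, `normalize`, `softmax`, and `attentionWeights` modify their argument in place and return nothing, while `batchNormalize`, `randomOrthogonalMatrix`, and `pairwiseDistances` return new arrays. A short sketch:

```typescript
import { initialize, utils } from '@ruvector/attention-wasm';

await initialize();

const v = new Float32Array([3, 4]);
utils.normalize(v);           // mutates v in place; now ~[0.6, 0.8]
console.log(utils.l2Norm(v)); // ~1

const scores = new Float32Array([1, 2, 3]);
utils.softmax(scores);        // mutates scores into a probability distribution

// batchNormalize returns a single new Float32Array
const batch = [new Float32Array([1, 2]), new Float32Array([3, 4])];
const normalized = utils.batchNormalize(batch);
```
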
## Performance

The WASM bindings provide near-native performance for attention computations:

- Optimized with `opt-level = "s"` and LTO
- SIMD acceleration where available
- Efficient memory management
- Zero-copy data transfer where possible

## License

MIT OR Apache-2.0
package/js/index.ts
ADDED
@@ -0,0 +1,412 @@
/**
 * TypeScript wrapper for ruvector-attention-wasm
 * Provides a clean, type-safe API for attention mechanisms
 */

import init, * as wasm from '../pkg/ruvector_attention_wasm';
import type {
  AttentionConfig,
  MultiHeadConfig,
  HyperbolicConfig,
  LinearAttentionConfig,
  FlashAttentionConfig,
  LocalGlobalConfig,
  MoEConfig,
  TrainingConfig,
  SchedulerConfig,
  ExpertStats,
  AttentionType,
} from './types';

export * from './types';

let initialized = false;

/**
 * Initialize the WASM module
 * Must be called before using any attention mechanisms
 */
export async function initialize(): Promise<void> {
  if (!initialized) {
    await init();
    initialized = true;
  }
}

/**
 * Get the version of the ruvector-attention-wasm package
 */
export function version(): string {
  return wasm.version();
}

/**
 * Get list of available attention mechanisms
 */
export function availableMechanisms(): AttentionType[] {
  return wasm.available_mechanisms() as AttentionType[];
}

/**
 * Multi-head attention mechanism
 */
export class MultiHeadAttention {
  private inner: wasm.WasmMultiHeadAttention;

  constructor(config: MultiHeadConfig) {
    this.inner = new wasm.WasmMultiHeadAttention(config.dim, config.numHeads);
  }

  /**
   * Compute multi-head attention
   */
  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  get numHeads(): number {
    return this.inner.num_heads;
  }

  get dim(): number {
    return this.inner.dim;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Hyperbolic attention mechanism
 */
export class HyperbolicAttention {
  private inner: wasm.WasmHyperbolicAttention;

  constructor(config: HyperbolicConfig) {
    this.inner = new wasm.WasmHyperbolicAttention(config.dim, config.curvature);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  get curvature(): number {
    return this.inner.curvature;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Linear attention (Performer-style)
 */
export class LinearAttention {
  private inner: wasm.WasmLinearAttention;

  constructor(config: LinearAttentionConfig) {
    this.inner = new wasm.WasmLinearAttention(config.dim, config.numFeatures);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Flash attention mechanism
 */
export class FlashAttention {
  private inner: wasm.WasmFlashAttention;

  constructor(config: FlashAttentionConfig) {
    this.inner = new wasm.WasmFlashAttention(config.dim, config.blockSize);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Local-global attention mechanism
 */
export class LocalGlobalAttention {
  private inner: wasm.WasmLocalGlobalAttention;

  constructor(config: LocalGlobalConfig) {
    this.inner = new wasm.WasmLocalGlobalAttention(
      config.dim,
      config.localWindow,
      config.globalTokens
    );
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Mixture of Experts attention
 */
export class MoEAttention {
  private inner: wasm.WasmMoEAttention;

  constructor(config: MoEConfig) {
    this.inner = new wasm.WasmMoEAttention(config.dim, config.numExperts, config.topK);
  }

  compute(query: Float32Array, keys: Float32Array[], values: Float32Array[]): Float32Array {
    const result = this.inner.compute(query, keys, values);
    return new Float32Array(result);
  }

  getExpertStats(): ExpertStats {
    return this.inner.expert_stats() as ExpertStats;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * InfoNCE contrastive loss
 */
export class InfoNCELoss {
  private inner: wasm.WasmInfoNCELoss;

  constructor(temperature: number = 0.07) {
    this.inner = new wasm.WasmInfoNCELoss(temperature);
  }

  compute(anchor: Float32Array, positive: Float32Array, negatives: Float32Array[]): number {
    return this.inner.compute(anchor, positive, negatives);
  }

  computeMultiPositive(
    anchor: Float32Array,
    positives: Float32Array[],
    negatives: Float32Array[]
  ): number {
    return this.inner.compute_multi_positive(anchor, positives, negatives);
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Adam optimizer
 */
export class Adam {
  private inner: wasm.WasmAdam;

  constructor(paramCount: number, config: TrainingConfig) {
    this.inner = new wasm.WasmAdam(
      paramCount,
      config.learningRate,
      config.beta1,
      config.beta2,
      config.epsilon
    );
  }

  step(params: Float32Array, gradients: Float32Array): void {
    this.inner.step(params, gradients);
  }

  reset(): void {
    this.inner.reset();
  }

  get learningRate(): number {
    return this.inner.learning_rate;
  }

  set learningRate(lr: number) {
    this.inner.learning_rate = lr;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * AdamW optimizer (Adam with decoupled weight decay)
 */
export class AdamW {
  private inner: wasm.WasmAdamW;

  constructor(paramCount: number, config: TrainingConfig) {
    // Check for undefined explicitly: a weightDecay of 0 is a valid setting
    if (config.weightDecay === undefined) {
      throw new Error('AdamW requires weightDecay parameter');
    }

    this.inner = new wasm.WasmAdamW(
      paramCount,
      config.learningRate,
      config.weightDecay,
      config.beta1,
      config.beta2,
      config.epsilon
    );
  }

  step(params: Float32Array, gradients: Float32Array): void {
    this.inner.step(params, gradients);
  }

  reset(): void {
    this.inner.reset();
  }

  get learningRate(): number {
    return this.inner.learning_rate;
  }

  set learningRate(lr: number) {
    this.inner.learning_rate = lr;
  }

  get weightDecay(): number {
    return this.inner.weight_decay;
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Learning rate scheduler with warmup and cosine decay
 */
export class LRScheduler {
  private inner: wasm.WasmLRScheduler;

  constructor(config: SchedulerConfig) {
    this.inner = new wasm.WasmLRScheduler(
      config.initialLR,
      config.warmupSteps,
      config.totalSteps
    );
  }

  getLR(): number {
    return this.inner.get_lr();
  }

  step(): void {
    this.inner.step();
  }

  reset(): void {
    this.inner.reset();
  }

  free(): void {
    this.inner.free();
  }
}

/**
 * Utility functions
 */
export const utils = {
  /**
   * Compute cosine similarity between two vectors
   */
  cosineSimilarity(a: Float32Array, b: Float32Array): number {
    return wasm.cosine_similarity(a, b);
  },

  /**
   * Compute L2 norm of a vector
   */
  l2Norm(vec: Float32Array): number {
    return wasm.l2_norm(vec);
  },

  /**
   * Normalize a vector to unit length (in-place)
   */
  normalize(vec: Float32Array): void {
    wasm.normalize(vec);
  },

  /**
   * Apply softmax to a vector (in-place)
   */
  softmax(vec: Float32Array): void {
    wasm.softmax(vec);
  },

  /**
   * Compute attention weights from scores (in-place)
   */
  attentionWeights(scores: Float32Array, temperature?: number): void {
    wasm.attention_weights(scores, temperature);
  },

  /**
   * Batch normalize vectors
   */
  batchNormalize(vectors: Float32Array[], epsilon?: number): Float32Array {
    const result = wasm.batch_normalize(vectors, epsilon);
    return new Float32Array(result);
  },

  /**
   * Generate random orthogonal matrix
   */
  randomOrthogonalMatrix(dim: number): Float32Array {
    const result = wasm.random_orthogonal_matrix(dim);
    return new Float32Array(result);
  },

  /**
   * Compute pairwise distances between vectors
   */
  pairwiseDistances(vectors: Float32Array[]): Float32Array {
    const result = wasm.pairwise_distances(vectors);
    return new Float32Array(result);
  },
};

/**
 * Simple scaled dot-product attention (functional API)
 */
export function scaledDotAttention(
  query: Float32Array,
  keys: Float32Array[],
  values: Float32Array[],
  scale?: number
): Float32Array {
  const result = wasm.scaled_dot_attention(query, keys, values, scale);
  return new Float32Array(result);
}

// Re-export WASM module for advanced usage
export { wasm };
package/js/types.ts
ADDED
@@ -0,0 +1,108 @@
/**
 * TypeScript type definitions for ruvector-attention-wasm
 */

export interface AttentionConfig {
  /** Embedding dimension */
  dim: number;
  /** Number of attention heads (for multi-head attention) */
  numHeads?: number;
  /** Dropout probability */
  dropout?: number;
  /** Scaling factor for attention scores */
  scale?: number;
  /** Whether to use causal masking */
  causal?: boolean;
}

export interface MultiHeadConfig extends AttentionConfig {
  numHeads: number;
}

export interface HyperbolicConfig extends AttentionConfig {
  /** Hyperbolic space curvature */
  curvature: number;
}

export interface LinearAttentionConfig extends AttentionConfig {
  /** Number of random features for kernel approximation */
  numFeatures: number;
}

export interface FlashAttentionConfig extends AttentionConfig {
  /** Block size for tiling */
  blockSize: number;
}

export interface LocalGlobalConfig extends AttentionConfig {
  /** Size of local attention window */
  localWindow: number;
  /** Number of global attention tokens */
  globalTokens: number;
}

export interface MoEConfig extends AttentionConfig {
  /** Number of expert attention mechanisms */
  numExperts: number;
  /** Number of experts to use per query */
  topK: number;
  /** Maximum capacity per expert */
  expertCapacity?: number;
  /** Load balancing coefficient */
  balanceCoeff?: number;
}

export interface TrainingConfig {
  /** Learning rate for optimizer */
  learningRate: number;
  /** Temperature parameter for contrastive loss */
  temperature?: number;
  /** First moment decay rate (Adam/AdamW) */
  beta1?: number;
  /** Second moment decay rate (Adam/AdamW) */
  beta2?: number;
  /** Weight decay coefficient (AdamW) */
  weightDecay?: number;
  /** Numerical stability constant */
  epsilon?: number;
}

export interface SchedulerConfig {
  /** Initial learning rate */
  initialLR: number;
  /** Number of warmup steps */
  warmupSteps: number;
  /** Total training steps */
  totalSteps: number;
}

export interface ExpertStats {
  /** Number of times each expert was selected */
  selectionCounts: number[];
  /** Average load per expert */
  averageLoad: number[];
  /** Load balance factor (lower is better) */
  loadBalance: number;
}

/**
 * Attention mechanism types
 */
export type AttentionType =
  | 'scaled_dot_product'
  | 'multi_head'
  | 'hyperbolic'
  | 'linear'
  | 'flash'
  | 'local_global'
  | 'moe';

/**
 * Optimizer types
 */
export type OptimizerType = 'adam' | 'adamw';

/**
 * Loss function types
 */
export type LossType = 'info_nce';
package/package.json
ADDED
@@ -0,0 +1,45 @@
{
  "name": "@ruvector/attention-wasm",
  "version": "0.1.0",
  "description": "WebAssembly bindings for ruvector-attention - high-performance attention mechanisms",
  "main": "pkg/ruvector_attention_wasm.js",
  "types": "js/index.ts",
  "files": [
    "pkg/",
    "js/"
  ],
  "scripts": {
    "build": "wasm-pack build --target web --out-dir pkg",
    "build:node": "wasm-pack build --target nodejs --out-dir pkg-node",
    "build:bundler": "wasm-pack build --target bundler --out-dir pkg-bundler",
    "build:all": "npm run build && npm run build:node && npm run build:bundler",
    "test": "wasm-pack test --headless --firefox",
    "test:chrome": "wasm-pack test --headless --chrome",
    "clean": "rm -rf pkg pkg-node pkg-bundler target"
  },
  "repository": {
    "type": "git",
    "url": "https://github.com/ruvnet/ruvector"
  },
  "keywords": [
    "wasm",
    "webassembly",
    "attention",
    "transformer",
    "machine-learning",
    "neural-networks",
    "hyperbolic",
    "moe",
    "flash-attention"
  ],
  "author": "rUv",
  "license": "MIT OR Apache-2.0",
  "bugs": {
    "url": "https://github.com/ruvnet/ruvector/issues"
  },
  "homepage": "https://ruv.io/ruvector",
  "devDependencies": {
    "@types/node": "^20.0.0",
    "typescript": "^5.0.0"
  }
}
package/pkg/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 rUv

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
package/pkg/README.md
ADDED
@@ -0,0 +1,193 @@
# ruvector-attention-wasm

WebAssembly bindings for the ruvector-attention package, providing high-performance attention mechanisms for browser and Node.js environments.

## Features

- **Multiple Attention Mechanisms**:
  - Scaled Dot-Product Attention
  - Multi-Head Attention
  - Hyperbolic Attention (for hierarchical data)
  - Linear Attention (Performer-style)
  - Flash Attention (memory-efficient)
  - Local-Global Attention
  - Mixture of Experts (MoE) Attention

- **Training Utilities**:
  - InfoNCE contrastive loss
  - Adam optimizer
  - AdamW optimizer (with decoupled weight decay)
  - Learning rate scheduler (warmup + cosine decay)

- **TypeScript Support**: Full type definitions and modern API

## Installation

```bash
npm install @ruvector/attention-wasm
```

## Usage

### TypeScript/JavaScript

```typescript
import { initialize, MultiHeadAttention, utils } from '@ruvector/attention-wasm';

// Initialize the WASM module (required before any other call)
await initialize();

// Create multi-head attention
const attention = new MultiHeadAttention({ dim: 64, numHeads: 8 });

// Prepare inputs
const query = new Float32Array(64);
const keys = [new Float32Array(64), new Float32Array(64)];
const values = [new Float32Array(64), new Float32Array(64)];

// Compute attention
const output = attention.compute(query, keys, values);

// Use utilities
const similarity = utils.cosineSimilarity(query, keys[0]);
```

### Advanced Examples

#### Hyperbolic Attention

```typescript
import { HyperbolicAttention } from '@ruvector/attention-wasm';

const hyperbolic = new HyperbolicAttention({
  dim: 128,
  curvature: 1.0
});

const output = hyperbolic.compute(query, keys, values);
```

#### MoE Attention with Expert Stats

```typescript
import { MoEAttention } from '@ruvector/attention-wasm';

const moe = new MoEAttention({
  dim: 64,
  numExperts: 4,
  topK: 2
});

const output = moe.compute(query, keys, values);

// Get expert utilization
const stats = moe.getExpertStats();
console.log('Load balance:', stats.loadBalance);
```

#### Training with InfoNCE Loss

```typescript
import { InfoNCELoss, Adam } from '@ruvector/attention-wasm';

const loss = new InfoNCELoss(0.07);
const optimizer = new Adam(paramCount, {
  learningRate: 0.001,
  beta1: 0.9,
  beta2: 0.999,
});

// Training loop (anchor, positive, negatives, params, gradients, and
// paramCount come from your own model and data pipeline)
const lossValue = loss.compute(anchor, positive, negatives);
optimizer.step(params, gradients);
```

#### Learning Rate Scheduling

```typescript
import { LRScheduler, AdamW } from '@ruvector/attention-wasm';

const scheduler = new LRScheduler({
  initialLR: 0.001,
  warmupSteps: 1000,
  totalSteps: 10000,
});

const optimizer = new AdamW(paramCount, {
  learningRate: scheduler.getLR(),
  weightDecay: 0.01,
});

// Training loop
for (let step = 0; step < 10000; step++) {
  optimizer.learningRate = scheduler.getLR();
  optimizer.step(params, gradients);
  scheduler.step();
}
```

## Building from Source

### Prerequisites

- Rust 1.70+
- wasm-pack

### Build Commands

```bash
# Build for web (ES modules)
wasm-pack build --target web --out-dir pkg

# Build for Node.js
wasm-pack build --target nodejs --out-dir pkg-node

# Build for bundlers (webpack, vite, etc.)
wasm-pack build --target bundler --out-dir pkg-bundler

# Run tests
wasm-pack test --headless --firefox
```

## API Reference

### Attention Mechanisms

- `MultiHeadAttention` - Standard multi-head attention
- `HyperbolicAttention` - Attention in hyperbolic space
- `LinearAttention` - Linear-complexity attention (Performer)
- `FlashAttention` - Memory-efficient attention
- `LocalGlobalAttention` - Combined local and global attention
- `MoEAttention` - Mixture of Experts attention
- `scaledDotAttention()` - Functional API for basic attention

### Training

- `InfoNCELoss` - Contrastive loss function
- `Adam` - Adam optimizer
- `AdamW` - AdamW optimizer with decoupled weight decay
- `LRScheduler` - Learning rate scheduler (warmup + cosine decay)

### Utilities

- `utils.cosineSimilarity()` - Cosine similarity between vectors
- `utils.l2Norm()` - L2 norm of a vector
- `utils.normalize()` - Normalize vector to unit length (in place)
- `utils.softmax()` - Apply softmax transformation (in place)
- `utils.attentionWeights()` - Compute attention weights from scores (in place)
- `utils.batchNormalize()` - Batch normalization
- `utils.randomOrthogonalMatrix()` - Generate random orthogonal matrix
- `utils.pairwiseDistances()` - Compute pairwise distances

## Performance

The WASM bindings provide near-native performance for attention computations:

- Optimized with `opt-level = "s"` and LTO
- SIMD acceleration where available
- Efficient memory management
- Zero-copy data transfer where possible

## License

MIT OR Apache-2.0