@ruvector/edge-net 0.5.0 → 0.5.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +281 -10
- package/core-invariants.js +942 -0
- package/models/adapter-hub.js +1008 -0
- package/models/adapter-security.js +792 -0
- package/models/benchmark.js +688 -0
- package/models/distribution.js +791 -0
- package/models/index.js +109 -0
- package/models/integrity.js +753 -0
- package/models/loader.js +725 -0
- package/models/microlora.js +1298 -0
- package/models/model-loader.js +922 -0
- package/models/model-optimizer.js +1245 -0
- package/models/model-registry.js +696 -0
- package/models/model-utils.js +548 -0
- package/models/models-cli.js +914 -0
- package/models/registry.json +214 -0
- package/models/training-utils.js +1418 -0
- package/models/wasm-core.js +1025 -0
- package/network-genesis.js +2847 -0
- package/onnx-worker.js +462 -8
- package/package.json +33 -3
- package/plugins/SECURITY-AUDIT.md +654 -0
- package/plugins/cli.js +43 -3
- package/plugins/implementations/e2e-encryption.js +57 -12
- package/plugins/plugin-loader.js +610 -21
- package/tests/model-optimizer.test.js +644 -0
- package/tests/network-genesis.test.js +562 -0
- package/tests/plugin-benchmark.js +1239 -0
- package/tests/plugin-system-test.js +163 -0
- package/tests/wasm-core.test.js +368 -0
|
@@ -0,0 +1,1418 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* Training Utilities for MicroLoRA
|
|
3
|
+
*
|
|
4
|
+
* Provides comprehensive training infrastructure including data preprocessing,
|
|
5
|
+
* batch generation, loss computation, EWC for continual learning, and
|
|
6
|
+
* gradient checkpointing for memory efficiency.
|
|
7
|
+
*
|
|
8
|
+
* @module @ruvector/edge-net/models/training-utils
|
|
9
|
+
*
|
|
10
|
+
* @example
|
|
11
|
+
* ```javascript
|
|
12
|
+
* import {
|
|
13
|
+
* DataPreprocessor,
|
|
14
|
+
* BatchGenerator,
|
|
15
|
+
* LossComputer,
|
|
16
|
+
* EWCManager,
|
|
17
|
+
* GradientCheckpointer
|
|
18
|
+
* } from '@ruvector/edge-net/models/training-utils';
|
|
19
|
+
*
|
|
20
|
+
* // Preprocess training data
|
|
21
|
+
* const preprocessor = new DataPreprocessor();
|
|
22
|
+
* const processed = await preprocessor.process(examples);
|
|
23
|
+
*
|
|
24
|
+
* // Generate batches
|
|
25
|
+
* const batcher = new BatchGenerator(processed, { batchSize: 8 });
|
|
26
|
+
* for (const batch of batcher) {
|
|
27
|
+
* const loss = LossComputer.contrastive(batch.anchors, batch.positives, batch.negatives);
|
|
28
|
+
* }
|
|
29
|
+
*
|
|
30
|
+
* // Enable EWC for continual learning
|
|
31
|
+
* const ewc = new EWCManager({ lambda: 2000 });
|
|
32
|
+
* ewc.computeFisher(model, dataloader);
|
|
33
|
+
* ```
|
|
34
|
+
*/
|
|
35
|
+
|
|
36
|
+
import { EventEmitter } from 'events';
|
|
37
|
+
import { createHash } from 'crypto';
|
|
38
|
+
|
|
39
|
+
// ============================================
|
|
40
|
+
// TYPE DEFINITIONS (JSDoc)
|
|
41
|
+
// ============================================
|
|
42
|
+
|
|
43
|
+
/**
|
|
44
|
+
* @typedef {Object} TrainingExample
|
|
45
|
+
* @property {string} input - Input text
|
|
46
|
+
* @property {string} output - Expected output
|
|
47
|
+
* @property {number} [quality=1.0] - Example quality weight
|
|
48
|
+
* @property {Object} [metadata] - Optional metadata
|
|
49
|
+
*/
|
|
50
|
+
|
|
51
|
+
/**
|
|
52
|
+
* @typedef {Object} ProcessedExample
|
|
53
|
+
* @property {string} input - Original input text
|
|
54
|
+
* @property {string} output - Original output text
|
|
55
|
+
* @property {Float32Array} inputEmb - Input embedding
|
|
56
|
+
* @property {Float32Array} outputEmb - Output embedding
|
|
57
|
+
* @property {Float32Array} [negativeEmb] - Optional negative embedding
|
|
58
|
+
* @property {number} quality - Example quality weight
|
|
59
|
+
* @property {number[]} inputTokens - Tokenized input
|
|
60
|
+
* @property {number[]} outputTokens - Tokenized output
|
|
61
|
+
*/
|
|
62
|
+
|
|
63
|
+
/**
|
|
64
|
+
* @typedef {Object} Batch
|
|
65
|
+
* @property {ProcessedExample[]} examples - Batch examples
|
|
66
|
+
* @property {Float32Array[]} anchors - Anchor embeddings
|
|
67
|
+
* @property {Float32Array[]} positives - Positive embeddings
|
|
68
|
+
* @property {Float32Array[]} [negatives] - Negative embeddings
|
|
69
|
+
* @property {number} size - Batch size
|
|
70
|
+
*/
|
|
71
|
+
|
|
72
|
+
/**
|
|
73
|
+
* @typedef {Object} LossResult
|
|
74
|
+
* @property {number} loss - Computed loss value
|
|
75
|
+
* @property {Float32Array[]} [gradients] - Optional gradients
|
|
76
|
+
* @property {Object} [components] - Loss components breakdown
|
|
77
|
+
*/
|
|
78
|
+
|
|
79
|
+
// ============================================
|
|
80
|
+
// DATA PREPROCESSOR
|
|
81
|
+
// ============================================
|
|
82
|
+
|
|
83
|
+
/**
|
|
84
|
+
* DataPreprocessor - Prepares training data for MicroLoRA
|
|
85
|
+
*
|
|
86
|
+
* Handles tokenization, embedding generation, data augmentation,
|
|
87
|
+
* and quality filtering of training examples.
|
|
88
|
+
*
|
|
89
|
+
* @example
|
|
90
|
+
* ```javascript
|
|
91
|
+
* const preprocessor = new DataPreprocessor({
|
|
92
|
+
* maxLength: 512,
|
|
93
|
+
* augmentation: true,
|
|
94
|
+
* qualityThreshold: 0.5
|
|
95
|
+
* });
|
|
96
|
+
*
|
|
97
|
+
* const processed = await preprocessor.process([
|
|
98
|
+
* { input: 'Hello', output: 'World' }
|
|
99
|
+
* ]);
|
|
100
|
+
* ```
|
|
101
|
+
*/
|
|
102
|
+
export class DataPreprocessor extends EventEmitter {
  /**
   * Create a DataPreprocessor.
   *
   * @param {Object} [config={}] - Preprocessor configuration
   * @param {number} [config.maxLength=512] - Max tokens kept per side after truncation
   * @param {number} [config.embeddingDim=384] - Embedding dimensionality
   * @param {boolean} [config.augmentation=false] - Enable data augmentation
   * @param {number} [config.augmentationFactor=2] - Max total variants per example
   * @param {number} [config.qualityThreshold=0.0] - Drop examples below this quality
   * @param {boolean} [config.normalizeEmbeddings=true] - L2-normalize embeddings
   */
  constructor(config = {}) {
    super();

    this.config = {
      maxLength: 512,
      embeddingDim: 384,
      augmentation: false,
      augmentationFactor: 2,
      qualityThreshold: 0.0,
      normalizeEmbeddings: true,
      ...config,
    };

    // Word-level vocabulary, grown on the fly during tokenization.
    this.vocab = new Map();
    this.vocabSize = 0;

    // Content-hash -> Float32Array embedding cache.
    this.embeddingCache = new Map();

    this.stats = {
      processed: 0,
      filtered: 0,
      augmented: 0,
      cacheHits: 0,
    };
  }

  /**
   * Process training examples: quality-filter, tokenize, embed, and
   * optionally augment. Emits 'process:complete' with { total, stats }.
   *
   * @param {TrainingExample[]} examples - Raw training examples
   * @returns {Promise<ProcessedExample[]>} Processed examples
   */
  async process(examples) {
    const processed = [];

    for (const example of examples) {
      // Quality filter. Use ?? (not ||) so an explicit quality of 0 is
      // honored: || would silently promote 0 to the default 1.0 and let
      // zero-quality examples through the filter at full weight.
      if ((example.quality ?? 1.0) < this.config.qualityThreshold) {
        this.stats.filtered++;
        continue;
      }

      const result = await this._processExample(example);
      if (result) {
        processed.push(result);
        this.stats.processed++;
      }

      // Augmented variants inherit a discounted quality and are not
      // re-filtered against the threshold.
      if (this.config.augmentation) {
        const augmented = await this._augmentExample(example);
        for (const aug of augmented) {
          const augResult = await this._processExample(aug);
          if (augResult) {
            processed.push(augResult);
            this.stats.augmented++;
          }
        }
      }
    }

    this.emit('process:complete', {
      total: processed.length,
      stats: this.stats,
    });

    return processed;
  }

  /**
   * Tokenize, truncate, and embed a single example.
   * @private
   */
  async _processExample(example) {
    const inputTokens = this._tokenize(example.input);
    const outputTokens = this._tokenize(example.output);

    return {
      input: example.input,
      output: example.output,
      inputEmb: this._embed(example.input),
      outputEmb: this._embed(example.output),
      // ?? keeps an explicit quality of 0 instead of defaulting it to 1.
      quality: example.quality ?? 1.0,
      inputTokens: inputTokens.slice(0, this.config.maxLength),
      outputTokens: outputTokens.slice(0, this.config.maxLength),
      metadata: example.metadata,
    };
  }

  /**
   * Simple word-level tokenization with a growable vocabulary; once the
   * vocabulary is full (50k entries) new words degrade to character codes.
   * A space token (32) is appended after every word.
   * @private
   */
  _tokenize(text) {
    const tokens = [];

    for (const word of text.split(/\s+/)) {
      if (this.vocab.has(word)) {
        tokens.push(this.vocab.get(word));
      } else if (this.vocabSize < 50000) {
        this.vocab.set(word, this.vocabSize);
        tokens.push(this.vocabSize);
        this.vocabSize++;
      } else {
        // Vocabulary is full: fall back to character-level tokens.
        for (const char of word) {
          tokens.push(char.charCodeAt(0) % 256);
        }
      }
      tokens.push(32); // space separator token
    }

    return tokens;
  }

  /**
   * Deterministic hash-based embedding (placeholder for a real encoder).
   * The SHA-256 digest both seeds the embedding and serves as the cache
   * key (previously a separate MD5 hash was used for the key; MD5 is
   * unavailable on FIPS-restricted Node builds and cost a second hash).
   * @private
   */
  _embed(text) {
    const digest = createHash('sha256').update(text).digest();
    const cacheKey = digest.toString('hex');

    if (this.embeddingCache.has(cacheKey)) {
      this.stats.cacheHits++;
      return this.embeddingCache.get(cacheKey);
    }

    const dim = this.config.embeddingDim;
    const embedding = new Float32Array(dim);

    // Seed every dimension from the 32 digest bytes, mapped to [-1, 1).
    for (let i = 0; i < dim; i++) {
      embedding[i] = (digest[i % 32] - 128) / 128;
    }

    // Mix in positional character features.
    for (let i = 0; i < text.length && i < dim; i++) {
      embedding[i] += (text.charCodeAt(i) - 64) / 256;
    }

    if (this.config.normalizeEmbeddings) {
      let norm = 0;
      for (let i = 0; i < dim; i++) {
        norm += embedding[i] * embedding[i];
      }
      norm = Math.sqrt(norm) || 1; // guard against the all-zero vector
      for (let i = 0; i < dim; i++) {
        embedding[i] /= norm;
      }
    }

    this.embeddingCache.set(cacheKey, embedding);
    return embedding;
  }

  /**
   * Produce randomized augmented variants of an example (synonym swap,
   * random insertion, case variation), capped at augmentationFactor - 1.
   * @private
   */
  async _augmentExample(example) {
    const augmented = [];
    const baseQuality = example.quality ?? 1.0;

    // Synonym replacement (simplified)
    if (Math.random() < 0.5) {
      augmented.push({
        input: this._synonymReplace(example.input),
        output: example.output,
        quality: baseQuality * 0.9,
      });
    }

    // Random insertion
    if (Math.random() < 0.3) {
      augmented.push({
        input: this._randomInsert(example.input),
        output: example.output,
        quality: baseQuality * 0.85,
      });
    }

    // Case variation
    if (Math.random() < 0.3) {
      augmented.push({
        input: this._caseVariation(example.input),
        output: example.output,
        quality: baseQuality * 0.95,
      });
    }

    return augmented.slice(0, this.config.augmentationFactor - 1);
  }

  /**
   * Replace the first matching dictionary word with a random synonym.
   * @private
   */
  _synonymReplace(text) {
    const synonyms = {
      'write': ['create', 'make', 'generate'],
      'function': ['method', 'procedure', 'routine'],
      'code': ['program', 'script', 'implementation'],
      'help': ['assist', 'aid', 'support'],
      'explain': ['describe', 'clarify', 'elaborate'],
    };

    let result = text;
    for (const [word, syns] of Object.entries(synonyms)) {
      if (result.toLowerCase().includes(word)) {
        const syn = syns[Math.floor(Math.random() * syns.length)];
        result = result.replace(new RegExp(word, 'gi'), syn);
        break; // only one substitution per variant
      }
    }
    return result;
  }

  /**
   * Insert one random filler word at a random position.
   * @private
   */
  _randomInsert(text) {
    const words = text.split(' ');
    const insertWords = ['please', 'now', 'just', 'simply'];
    const insertWord = insertWords[Math.floor(Math.random() * insertWords.length)];
    const position = Math.floor(Math.random() * words.length);
    words.splice(position, 0, insertWord);
    return words.join(' ');
  }

  /**
   * Return a random casing variant (lower, sentence, upper) of the text.
   * @private
   */
  _caseVariation(text) {
    const variations = [
      text.toLowerCase(),
      text.charAt(0).toUpperCase() + text.slice(1).toLowerCase(),
      text.toUpperCase(),
    ];
    return variations[Math.floor(Math.random() * variations.length)];
  }

  /**
   * Clear the embedding cache and reset the cache-hit counter.
   */
  clearCache() {
    this.embeddingCache.clear();
    this.stats.cacheHits = 0;
  }

  /**
   * Get preprocessor statistics (counters plus vocab/cache sizes).
   */
  getStats() {
    return {
      ...this.stats,
      vocabSize: this.vocabSize,
      cacheSize: this.embeddingCache.size,
    };
  }
}
|
|
392
|
+
|
|
393
|
+
// ============================================
|
|
394
|
+
// BATCH GENERATOR
|
|
395
|
+
// ============================================
|
|
396
|
+
|
|
397
|
+
/**
|
|
398
|
+
* BatchGenerator - Generates training batches with shuffling and sampling
|
|
399
|
+
*
|
|
400
|
+
* Supports various batching strategies including random sampling,
|
|
401
|
+
* hard negative mining, and curriculum learning.
|
|
402
|
+
*
|
|
403
|
+
* @example
|
|
404
|
+
* ```javascript
|
|
405
|
+
* const batcher = new BatchGenerator(processedData, {
|
|
406
|
+
* batchSize: 8,
|
|
407
|
+
* shuffle: true,
|
|
408
|
+
* dropLast: false
|
|
409
|
+
* });
|
|
410
|
+
*
|
|
411
|
+
* for (const batch of batcher) {
|
|
412
|
+
* console.log(`Batch size: ${batch.size}`);
|
|
413
|
+
* }
|
|
414
|
+
* ```
|
|
415
|
+
*/
|
|
416
|
+
export class BatchGenerator {
  /**
   * Create a BatchGenerator. Implements the iterator protocol, so an
   * instance can be consumed with for...of or spread.
   *
   * @param {ProcessedExample[]} data - Processed training data (defensively copied)
   * @param {Object} [config={}] - Generator configuration
   * @param {number} [config.batchSize=8] - Examples per batch
   * @param {boolean} [config.shuffle=true] - Fisher-Yates shuffle each epoch
   * @param {boolean} [config.dropLast=false] - Drop trailing partial batch
   * @param {boolean} [config.hardNegatives=false] - Mine in-batch hard negatives
   * @param {boolean} [config.curriculum=false] - Quality-first ordering early on
   * @param {number} [config.curriculumEpochs=5] - Epochs curriculum stays active
   */
  constructor(data, config = {}) {
    this.data = [...data];
    this.config = {
      batchSize: 8,
      shuffle: true,
      dropLast: false,
      hardNegatives: false,
      curriculum: false,
      curriculumEpochs: 5,
      ...config,
    };

    this.currentIndex = 0;
    this.epoch = 0;
    this.indices = this._createIndices();
  }

  /**
   * Build the iteration order: optionally shuffled, then (during early
   * curriculum epochs) sorted so higher-quality examples come first.
   * Note: the curriculum sort overrides the shuffle order while active.
   * @private
   */
  _createIndices() {
    const indices = Array.from({ length: this.data.length }, (_, i) => i);

    if (this.config.shuffle) {
      // Fisher-Yates shuffle
      for (let i = indices.length - 1; i > 0; i--) {
        const j = Math.floor(Math.random() * (i + 1));
        [indices[i], indices[j]] = [indices[j], indices[i]];
      }
    }

    if (this.config.curriculum && this.epoch < this.config.curriculumEpochs) {
      // ?? (not ||) so an explicit quality of 0 sorts last instead of
      // being promoted to the default 1.
      indices.sort(
        (a, b) => (this.data[b].quality ?? 1) - (this.data[a].quality ?? 1)
      );
    }

    return indices;
  }

  /** Iterator protocol: the generator is its own iterator. */
  [Symbol.iterator]() {
    return this;
  }

  /**
   * Produce the next batch.
   * @returns {{value: (Batch|undefined), done: boolean}}
   */
  next() {
    if (this.currentIndex >= this.indices.length) {
      return { value: undefined, done: true };
    }

    const endIndex = Math.min(
      this.currentIndex + this.config.batchSize,
      this.indices.length
    );

    // Drop a trailing partial batch when configured; mark the iterator
    // exhausted so repeated next() calls don't re-evaluate this path.
    // (The original code had a second, unreachable dropLast branch above
    // whose two arms both returned { done: true }; it has been removed.)
    if (
      this.config.dropLast &&
      endIndex - this.currentIndex < this.config.batchSize
    ) {
      this.currentIndex = this.indices.length;
      return { value: undefined, done: true };
    }

    const examples = this.indices
      .slice(this.currentIndex, endIndex)
      .map((i) => this.data[i]);

    const anchors = examples.map((e) => e.inputEmb);
    const positives = examples.map((e) => e.outputEmb);
    const negatives = this.config.hardNegatives
      ? this._mineHardNegatives(anchors, positives)
      : null;

    this.currentIndex = endIndex;

    return {
      value: { examples, anchors, positives, negatives, size: examples.length },
      done: false,
    };
  }

  /**
   * For each anchor, pick the most similar *other* positive in the batch
   * as its hard negative; a size-1 batch falls back to wrapping around.
   * @private
   */
  _mineHardNegatives(anchors, positives) {
    return anchors.map((anchor, i) => {
      let hardestIdx = -1;
      let hardestSim = -Infinity;

      for (let j = 0; j < positives.length; j++) {
        if (i === j) continue;
        const sim = this._cosineSimilarity(anchor, positives[j]);
        if (sim > hardestSim) {
          hardestSim = sim;
          hardestIdx = j;
        }
      }

      // hardestIdx stays -1 only when the batch has a single example.
      return hardestIdx >= 0
        ? positives[hardestIdx]
        : positives[(i + 1) % positives.length];
    });
  }

  /**
   * Cosine similarity over the overlapping prefix of two vectors.
   * @private
   */
  _cosineSimilarity(a, b) {
    let dot = 0;
    let normA = 0;
    let normB = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-8);
  }

  /**
   * Reset for a new epoch: rewind, bump the epoch counter, and rebuild
   * the index order (re-shuffle / curriculum re-sort).
   */
  reset() {
    this.currentIndex = 0;
    this.epoch++;
    this.indices = this._createIndices();
  }

  /** Total number of batches per epoch. */
  get length() {
    return this.config.dropLast
      ? Math.floor(this.data.length / this.config.batchSize)
      : Math.ceil(this.data.length / this.config.batchSize);
  }

  /** Number of batches not yet emitted this epoch. */
  get remaining() {
    const left = this.indices.length - this.currentIndex;
    return this.config.dropLast
      ? Math.floor(left / this.config.batchSize)
      : Math.ceil(left / this.config.batchSize);
  }
}
|
|
599
|
+
|
|
600
|
+
// ============================================
|
|
601
|
+
// LOSS COMPUTER
|
|
602
|
+
// ============================================
|
|
603
|
+
|
|
604
|
+
/**
|
|
605
|
+
* LossComputer - Various loss functions for training
|
|
606
|
+
*
|
|
607
|
+
* Implements multiple loss functions optimized for different training
|
|
608
|
+
* scenarios: contrastive learning, cross-entropy, triplet loss, etc.
|
|
609
|
+
*
|
|
610
|
+
* @example
|
|
611
|
+
* ```javascript
|
|
612
|
+
* // Contrastive loss
|
|
613
|
+
* const loss = LossComputer.contrastive(anchors, positives, negatives, {
|
|
614
|
+
* temperature: 0.07,
|
|
615
|
+
* margin: 0.5
|
|
616
|
+
* });
|
|
617
|
+
*
|
|
618
|
+
* // Cross-entropy loss
|
|
619
|
+
* const ceLoss = LossComputer.crossEntropy(predictions, targets);
|
|
620
|
+
*
|
|
621
|
+
* // Combined loss
|
|
622
|
+
* const combined = LossComputer.combine([
|
|
623
|
+
* { loss: contrastiveLoss, weight: 0.7 },
|
|
624
|
+
* { loss: ceLoss, weight: 0.3 }
|
|
625
|
+
* ]);
|
|
626
|
+
* ```
|
|
627
|
+
*/
|
|
628
|
+
export class LossComputer {
  /**
   * Contrastive loss (InfoNCE). Uses explicit negatives when provided,
   * otherwise treats every other example's positive as an in-batch
   * negative. Returns loss 0 for an empty batch (previously NaN).
   *
   * @param {Float32Array[]} anchors - Anchor embeddings
   * @param {Float32Array[]} positives - Positive embeddings
   * @param {Float32Array[]} [negatives] - Explicit negative embeddings
   * @param {Object} [options={}] - { temperature=0.07, margin=0.0 }
   * @returns {LossResult}
   */
  static contrastive(anchors, positives, negatives = null, options = {}) {
    const { temperature = 0.07, margin = 0.0 } = options;
    const n = anchors.length;
    if (n === 0) {
      return { loss: 0, components: { contrastive: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const anchor = anchors[i];
      const posSim = LossComputer._cosineSimilarity(anchor, positives[i]);

      // Sum of exponentiated negative similarities.
      let negSum = 0;
      if (negatives) {
        for (const neg of negatives) {
          const negSim = LossComputer._cosineSimilarity(anchor, neg);
          negSum += Math.exp((negSim - margin) / temperature);
        }
      } else {
        // In-batch negatives: all other positives.
        for (let j = 0; j < n; j++) {
          if (i === j) continue;
          const negSim = LossComputer._cosineSimilarity(anchor, positives[j]);
          negSum += Math.exp((negSim - margin) / temperature);
        }
      }

      // InfoNCE: -log(exp(pos/t) / (exp(pos/t) + sum(exp(neg/t))))
      const posExp = Math.exp((posSim - margin) / temperature);
      totalLoss += -Math.log(posExp / (posExp + negSum + 1e-8));
    }

    const mean = totalLoss / n;
    return { loss: mean, components: { contrastive: mean } };
  }

  /**
   * Triplet loss with margin: mean of max(0, d(a,p) - d(a,n) + margin).
   * Returns loss 0 for an empty batch.
   *
   * @param {Float32Array[]} anchors - Anchor embeddings
   * @param {Float32Array[]} positives - Positive embeddings
   * @param {Float32Array[]} negatives - Negative embeddings
   * @param {Object} [options={}] - { margin=0.5 }
   * @returns {LossResult}
   */
  static triplet(anchors, positives, negatives, options = {}) {
    const { margin = 0.5 } = options;
    const n = anchors.length;
    if (n === 0) {
      return { loss: 0, components: { triplet: 0 } };
    }

    let totalLoss = 0;
    for (let i = 0; i < n; i++) {
      const posDistance = LossComputer._euclideanDistance(anchors[i], positives[i]);
      const negDistance = LossComputer._euclideanDistance(anchors[i], negatives[i]);
      totalLoss += Math.max(0, posDistance - negDistance + margin);
    }

    const mean = totalLoss / n;
    return { loss: mean, components: { triplet: mean } };
  }

  /**
   * Cross-entropy loss over logits, with optional label smoothing.
   * Softmax uses the max-logit shift for numerical stability.
   * Returns loss 0 for an empty batch.
   *
   * @param {number[][]} predictions - Per-example logits
   * @param {number[]} targets - Target class indices
   * @param {Object} [options={}] - { labelSmoothing=0.0 }
   * @returns {LossResult}
   */
  static crossEntropy(predictions, targets, options = {}) {
    const { labelSmoothing = 0.0 } = options;
    const n = predictions.length;
    if (n === 0) {
      return { loss: 0, components: { crossEntropy: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const pred = predictions[i];
      const target = targets[i];

      // Numerically stable log-softmax.
      const maxLogit = Math.max(...pred);
      const expSum = pred.reduce((sum, p) => sum + Math.exp(p - maxLogit), 0);
      const logProbs = pred.map((p) => p - maxLogit - Math.log(expSum));

      if (labelSmoothing > 0) {
        // Distribute labelSmoothing mass uniformly across classes.
        const numClasses = pred.length;
        const smoothTarget = new Array(numClasses).fill(labelSmoothing / numClasses);
        smoothTarget[target] = 1 - labelSmoothing + labelSmoothing / numClasses;

        let loss = 0;
        for (let c = 0; c < numClasses; c++) {
          loss -= smoothTarget[c] * logProbs[c];
        }
        totalLoss += loss;
      } else {
        totalLoss -= logProbs[target];
      }
    }

    const mean = totalLoss / n;
    return { loss: mean, components: { crossEntropy: mean } };
  }

  /**
   * Mean Squared Error loss (mean over dimensions, then over examples).
   * Returns loss 0 for an empty batch.
   *
   * @param {Float32Array[]} predictions - Predicted embeddings
   * @param {Float32Array[]} targets - Target embeddings
   * @returns {LossResult}
   */
  static mse(predictions, targets) {
    const n = predictions.length;
    if (n === 0) {
      return { loss: 0, components: { mse: 0 } };
    }

    let totalLoss = 0;
    for (let i = 0; i < n; i++) {
      const pred = predictions[i];
      const target = targets[i];
      const dim = Math.min(pred.length, target.length);

      let loss = 0;
      for (let d = 0; d < dim; d++) {
        const diff = pred[d] - target[d];
        loss += diff * diff;
      }
      totalLoss += loss / dim;
    }

    const mean = totalLoss / n;
    return { loss: mean, components: { mse: mean } };
  }

  /**
   * Cosine embedding loss: 1 - sim for similar pairs (label 1),
   * max(0, sim - margin) for dissimilar pairs (label -1).
   * Labels default to 1 (similar). Returns loss 0 for an empty batch.
   *
   * @param {Float32Array[]} predictions - Predicted embeddings
   * @param {Float32Array[]} targets - Target embeddings
   * @param {number[]} [labels] - 1 for similar, -1 for dissimilar
   * @param {Object} [options={}] - { margin=0.0 }
   * @returns {LossResult}
   */
  static cosineEmbedding(predictions, targets, labels = null, options = {}) {
    const { margin = 0.0 } = options;
    const n = predictions.length;
    if (n === 0) {
      return { loss: 0, components: { cosineEmbedding: 0 } };
    }

    let totalLoss = 0;
    for (let i = 0; i < n; i++) {
      const sim = LossComputer._cosineSimilarity(predictions[i], targets[i]);
      const label = labels ? labels[i] : 1;

      if (label === 1) {
        totalLoss += 1 - sim;
      } else {
        totalLoss += Math.max(0, sim - margin);
      }
    }

    const mean = totalLoss / n;
    return { loss: mean, components: { cosineEmbedding: mean } };
  }

  /**
   * Combine multiple losses with weights; component breakdowns are
   * merged (weighted) by name.
   *
   * @param {Array<{loss: LossResult, weight: number}>} losses - Weighted losses
   * @returns {LossResult}
   */
  static combine(losses) {
    let totalLoss = 0;
    const components = {};

    for (const { loss, weight } of losses) {
      totalLoss += loss.loss * weight;
      for (const [name, value] of Object.entries(loss.components || {})) {
        components[name] = (components[name] || 0) + value * weight;
      }
    }

    return { loss: totalLoss, components };
  }

  /**
   * Cosine similarity over the overlapping prefix of two vectors.
   * The 1e-8 term guards against division by zero for zero vectors.
   * @private
   */
  static _cosineSimilarity(a, b) {
    let dot = 0;
    let normA = 0;
    let normB = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-8);
  }

  /**
   * Euclidean distance over the overlapping prefix of two vectors.
   * @private
   */
  static _euclideanDistance(a, b) {
    let sum = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      const diff = a[i] - b[i];
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }
}
|
|
870
|
+
|
|
871
|
+
// ============================================
|
|
872
|
+
// EWC MANAGER
|
|
873
|
+
// ============================================
|
|
874
|
+
|
|
875
|
+
/**
 * EWCManager - Elastic Weight Consolidation for Continual Learning
 *
 * Prevents catastrophic forgetting when training on new tasks by
 * regularizing important weights based on Fisher information. After a task
 * is learned, `computeFisher()` snapshots the current adapter weights and
 * estimates a diagonal Fisher matrix as the average of squared per-example
 * gradients. Later training can then add `computePenalty()` to its loss
 * (or call `applyGradient()` directly) to keep important weights near the
 * previous task's optimum.
 *
 * Emits 'fisher:start' / 'fisher:complete' around Fisher estimation.
 *
 * @example
 * ```javascript
 * const ewc = new EWCManager({ lambda: 2000 });
 *
 * // After training on task 1
 * ewc.computeFisher(adapters, task1Data);
 *
 * // When training on task 2, add EWC penalty
 * const ewcLoss = ewc.computePenalty(currentAdapters);
 * totalLoss = taskLoss + ewcLoss;
 * ```
 */
export class EWCManager extends EventEmitter {
    /**
     * Create an EWCManager
     *
     * @param {Object} [config={}] - EWC configuration
     * @param {number} [config.lambda=2000] - Regularization strength multiplier
     * @param {number} [config.sampleSize=200] - Max examples used for Fisher estimation
     * @param {boolean} [config.normalize=true] - NOTE(review): declared but never
     *   consulted; Fisher values are always averaged over the sample count — confirm intent
     * @param {boolean} [config.online=false] - Online EWC: average the new Fisher
     *   with the previously stored Fisher instead of replacing it
     */
    constructor(config = {}) {
        super();

        this.config = {
            lambda: 2000, // Regularization strength
            sampleSize: 200, // Samples for Fisher estimation
            normalize: true, // Normalize Fisher values
            online: false, // Online EWC (cumulative Fisher)
            ...config,
        };

        // Per-adapter diagonal Fisher estimates ({ fisherA, fisherB }) and the
        // parameter snapshot ({ loraA, loraB }) they were computed at.
        this.fisherInfo = new Map();
        this.optimalParams = new Map();

        // tasksLearned: computeFisher() calls; totalPenalties: computePenalty() calls.
        this.stats = {
            tasksLearned: 0,
            totalPenalties: 0,
        };
    }

    /**
     * Compute the diagonal of the Fisher information matrix for each adapter
     * and snapshot the current weights as that task's optimum.
     *
     * The diagonal is approximated as the mean of squared per-example
     * gradients over (at most `config.sampleSize`) samples of `data`.
     *
     * @param {Map<string, Object>} adapters - Adapter weights (loraA, loraB, scaling)
     * @param {ProcessedExample[]} data - Training data for estimation
     */
    computeFisher(adapters, data) {
        this.emit('fisher:start', { samples: data.length });

        // Cap the estimation cost by subsampling (with replacement) when needed.
        const sampleData = data.length > this.config.sampleSize
            ? this._sampleData(data, this.config.sampleSize)
            : data;

        for (const [name, adapter] of adapters) {
            // Deep-copy and store optimal parameters (theta*) for this task.
            this.optimalParams.set(name, {
                loraA: adapter.loraA.map(row => [...row]),
                loraB: adapter.loraB.map(row => [...row]),
            });

            // Accumulators for the Fisher diagonal (squared gradients).
            const fisherA = this._zeros(adapter.loraA.length, adapter.loraA[0].length);
            const fisherB = this._zeros(adapter.loraB.length, adapter.loraB[0].length);

            for (const example of sampleData) {
                const grads = this._computeGradients(example, adapter);

                // Accumulate squared gradients
                for (let i = 0; i < fisherA.length; i++) {
                    for (let j = 0; j < fisherA[0].length; j++) {
                        fisherA[i][j] += grads.gradA[i][j] * grads.gradA[i][j];
                    }
                }
                for (let i = 0; i < fisherB.length; i++) {
                    for (let j = 0; j < fisherB[0].length; j++) {
                        fisherB[i][j] += grads.gradB[i][j] * grads.gradB[i][j];
                    }
                }
            }

            // Average over the sample count (unconditional — see `normalize` note).
            const n = sampleData.length;
            for (let i = 0; i < fisherA.length; i++) {
                for (let j = 0; j < fisherA[0].length; j++) {
                    fisherA[i][j] /= n;
                }
            }
            for (let i = 0; i < fisherB.length; i++) {
                for (let j = 0; j < fisherB[0].length; j++) {
                    fisherB[i][j] /= n;
                }
            }

            // Online EWC: blend 50/50 with the previously stored Fisher so the
            // penalty reflects all tasks seen so far, not just the latest one.
            if (this.config.online && this.fisherInfo.has(name)) {
                const prevFisher = this.fisherInfo.get(name);
                for (let i = 0; i < fisherA.length; i++) {
                    for (let j = 0; j < fisherA[0].length; j++) {
                        fisherA[i][j] = 0.5 * (fisherA[i][j] + prevFisher.fisherA[i][j]);
                    }
                }
                for (let i = 0; i < fisherB.length; i++) {
                    for (let j = 0; j < fisherB[0].length; j++) {
                        fisherB[i][j] = 0.5 * (fisherB[i][j] + prevFisher.fisherB[i][j]);
                    }
                }
            }

            this.fisherInfo.set(name, { fisherA, fisherB });
        }

        this.stats.tasksLearned++;
        this.emit('fisher:complete', { adapters: adapters.size });
    }

    /**
     * Compute EWC penalty for current adapter values:
     * (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2.
     *
     * Adapters without stored Fisher info are skipped; returns 0 before the
     * first `computeFisher()` call.
     *
     * @param {Map<string, Object>} adapters - Current adapter weights
     * @returns {number} EWC penalty value
     */
    computePenalty(adapters) {
        if (this.fisherInfo.size === 0) {
            return 0;
        }

        let penalty = 0;

        for (const [name, adapter] of adapters) {
            const fisher = this.fisherInfo.get(name);
            const optimal = this.optimalParams.get(name);

            if (!fisher || !optimal) continue;

            // Sum of F_i * (theta_i - theta*_i)^2
            for (let i = 0; i < adapter.loraA.length; i++) {
                for (let j = 0; j < adapter.loraA[0].length; j++) {
                    const diff = adapter.loraA[i][j] - optimal.loraA[i][j];
                    penalty += fisher.fisherA[i][j] * diff * diff;
                }
            }
            for (let i = 0; i < adapter.loraB.length; i++) {
                for (let j = 0; j < adapter.loraB[0].length; j++) {
                    const diff = adapter.loraB[i][j] - optimal.loraB[i][j];
                    penalty += fisher.fisherB[i][j] * diff * diff;
                }
            }
        }

        this.stats.totalPenalties++;
        return this.config.lambda * penalty * 0.5;
    }

    /**
     * Apply the EWC penalty gradient (lambda * F_i * (theta_i - theta*_i))
     * directly to the adapters in place, pulling each weight back toward
     * its stored optimum. Alternative to adding `computePenalty()` to the loss.
     *
     * @param {Map<string, Object>} adapters - Adapter weights to update (mutated)
     * @param {number} learningRate - Learning rate
     */
    applyGradient(adapters, learningRate) {
        for (const [name, adapter] of adapters) {
            const fisher = this.fisherInfo.get(name);
            const optimal = this.optimalParams.get(name);

            if (!fisher || !optimal) continue;

            for (let i = 0; i < adapter.loraA.length; i++) {
                for (let j = 0; j < adapter.loraA[0].length; j++) {
                    const diff = adapter.loraA[i][j] - optimal.loraA[i][j];
                    adapter.loraA[i][j] -= learningRate * this.config.lambda * fisher.fisherA[i][j] * diff;
                }
            }
            for (let i = 0; i < adapter.loraB.length; i++) {
                for (let j = 0; j < adapter.loraB[0].length; j++) {
                    const diff = adapter.loraB[i][j] - optimal.loraB[i][j];
                    adapter.loraB[i][j] -= learningRate * this.config.lambda * fisher.fisherB[i][j] * diff;
                }
            }
        }
    }

    /**
     * Compute LoRA gradients for a single example via a manual forward/backward
     * pass of the adapter: adapted = input + scaling * (input @ A) @ B, with a
     * mean-squared-error objective against `example.outputEmb`.
     * Dimensions are truncated to min(input length, rows of A).
     * @private
     */
    _computeGradients(example, adapter) {
        const input = example.inputEmb;
        const target = example.outputEmb;
        const rank = adapter.loraA[0].length;
        const dim = Math.min(input.length, adapter.loraA.length);

        // Forward: hidden = input @ A  (length `rank`)
        const hidden = new Float64Array(rank);
        for (let r = 0; r < rank; r++) {
            for (let d = 0; d < dim; d++) {
                hidden[r] += input[d] * adapter.loraA[d][r];
            }
        }

        // Forward: adapted = input + scaling * hidden @ B
        const adapted = [...input];
        for (let d = 0; d < dim; d++) {
            for (let r = 0; r < rank; r++) {
                adapted[d] += adapter.scaling * hidden[r] * adapter.loraB[r][d];
            }
        }

        // Output gradient of mean squared error: 2 * (adapted - target) / dim.
        // Missing target dimensions are treated as 0 via `(target[i] || 0)`.
        const gradOutput = adapted.map((val, i) =>
            2 * (val - (target[i] || 0)) / dim
        );

        // Gradient for B: dL/dB[r][d] = hidden[r] * gradOutput[d] * scaling
        const gradB = this._zeros(rank, dim);
        for (let r = 0; r < rank; r++) {
            for (let d = 0; d < dim; d++) {
                gradB[r][d] = hidden[r] * gradOutput[d] * adapter.scaling;
            }
        }

        // Backprop through B to the hidden layer.
        const gradHidden = new Float64Array(rank);
        for (let r = 0; r < rank; r++) {
            for (let d = 0; d < dim; d++) {
                gradHidden[r] += gradOutput[d] * adapter.loraB[r][d] * adapter.scaling;
            }
        }

        // Gradient for A: dL/dA[d][r] = input[d] * gradHidden[r]
        const gradA = this._zeros(dim, rank);
        for (let d = 0; d < dim; d++) {
            for (let r = 0; r < rank; r++) {
                gradA[d][r] = input[d] * gradHidden[r];
            }
        }

        return { gradA, gradB };
    }

    /**
     * Random sample of `size` elements from data (with replacement).
     * @private
     */
    _sampleData(data, size) {
        const indices = [];
        for (let i = 0; i < size; i++) {
            indices.push(Math.floor(Math.random() * data.length));
        }
        return indices.map(i => data[i]);
    }

    /**
     * Allocate a rows x cols matrix of zeros.
     * @private
     */
    _zeros(rows, cols) {
        return Array(rows).fill(null).map(() => Array(cols).fill(0));
    }

    /**
     * Clear stored Fisher information and parameter snapshots.
     * Note: `stats.totalPenalties` is intentionally left untouched.
     */
    reset() {
        this.fisherInfo.clear();
        this.optimalParams.clear();
        this.stats.tasksLearned = 0;
    }

    /**
     * Export EWC state (Fisher info, snapshots, config, stats) as a plain
     * JSON-serializable object for persistence.
     */
    export() {
        return {
            fisherInfo: Object.fromEntries(this.fisherInfo),
            optimalParams: Object.fromEntries(this.optimalParams),
            config: this.config,
            stats: this.stats,
        };
    }

    /**
     * Import EWC state previously produced by `export()`. Config is merged
     * over the current config; missing fields fall back to empty state.
     */
    import(data) {
        this.fisherInfo = new Map(Object.entries(data.fisherInfo || {}));
        this.optimalParams = new Map(Object.entries(data.optimalParams || {}));
        if (data.config) this.config = { ...this.config, ...data.config };
        if (data.stats) this.stats = data.stats;
    }
}
|
|
1169
|
+
|
|
1170
|
+
// ============================================
|
|
1171
|
+
// GRADIENT CHECKPOINTER
|
|
1172
|
+
// ============================================
|
|
1173
|
+
|
|
1174
|
+
/**
 * GradientCheckpointer - Memory-efficient training with gradient checkpointing
 *
 * Trades compute for memory: instead of keeping every intermediate
 * activation alive for the backward pass, it caches only a bounded number
 * of { input, output } pairs and recomputes the rest on demand from the
 * stored forward functions.
 *
 * @example
 * ```javascript
 * const checkpointer = new GradientCheckpointer({
 *   checkpointSegments: 4,
 *   enabled: true
 * });
 *
 * // Wrap forward pass
 * const output = checkpointer.checkpoint(forwardFn, inputs);
 *
 * // Backward pass with recomputation
 * const gradients = checkpointer.backward(output, gradOutput);
 * ```
 */
export class GradientCheckpointer extends EventEmitter {
    /**
     * Create a GradientCheckpointer
     *
     * @param {Object} [config={}] - Checkpointer configuration
     * @param {boolean} [config.enabled=true] - Pass-through mode when false
     * @param {number} [config.checkpointSegments=4] - Segments used for the
     *   memory-savings estimate
     * @param {number} [config.maxStoredActivations=10] - Cap on cached
     *   activation pairs; beyond it, segments must be recomputed
     */
    constructor(config = {}) {
        super();

        const defaults = {
            enabled: true,
            checkpointSegments: 4, // Number of segments to divide computation
            maxStoredActivations: 10, // Maximum activations to store
        };
        this.config = { ...defaults, ...config };

        // Cached { input, output } activation pairs (bounded by config).
        this.checkpoints = [];

        // { fn, input } records kept for on-demand recomputation (unbounded
        // until clear() is called).
        this.forwardFns = [];

        this.stats = {
            checkpoints: 0,
            recomputations: 0,
            memoryReduction: 0,
        };
    }

    /**
     * Run a forward pass, recording what is needed to replay it later.
     *
     * @param {Function} forwardFn - Forward function to checkpoint
     * @param {any} input - Input to forward function
     * @returns {any} Forward output
     */
    checkpoint(forwardFn, input) {
        // Disabled: plain pass-through, nothing recorded.
        if (!this.config.enabled) {
            return forwardFn(input);
        }

        // Always remember how to redo this segment's forward pass.
        this.forwardFns.push({ fn: forwardFn, input: this._shallowCopy(input) });
        this.stats.checkpoints += 1;

        const result = forwardFn(input);

        // Cache the activation pair only while under the storage budget.
        if (this.checkpoints.length < this.config.maxStoredActivations) {
            this.checkpoints.push({
                input: this._shallowCopy(input),
                output: this._shallowCopy(result),
            });
        }

        return result;
    }

    /**
     * Fetch (or recompute) the activations for a segment during backward.
     *
     * @param {number} segmentIdx - Segment index to recompute
     * @returns {Object} { input, output } for that segment
     * @throws {Error} If the segment was never checkpointed
     */
    recompute(segmentIdx) {
        // Fast path: the pair was cached during the forward pass.
        if (segmentIdx < this.checkpoints.length) {
            return this.checkpoints[segmentIdx];
        }

        // Slow path: rerun the stored forward function.
        if (segmentIdx < this.forwardFns.length) {
            const entry = this.forwardFns[segmentIdx];
            const output = entry.fn(entry.input);
            this.stats.recomputations += 1;

            return { input: entry.input, output };
        }

        throw new Error(`Cannot recompute segment ${segmentIdx}`);
    }

    /**
     * Drop all cached activations and stored forward functions.
     */
    clear() {
        this.checkpoints = [];
        this.forwardFns = [];
    }

    /**
     * Shallow copy: typed arrays and arrays are duplicated, plain objects are
     * spread one level deep, primitives pass through unchanged.
     * @private
     */
    _shallowCopy(obj) {
        if (obj instanceof Float32Array || obj instanceof Float64Array) {
            return new obj.constructor(obj);
        }
        if (Array.isArray(obj)) {
            return [...obj];
        }
        if (obj !== null && typeof obj === 'object') {
            return { ...obj };
        }
        return obj;
    }

    /**
     * Estimate activation-memory usage with vs. without checkpointing and
     * record the resulting reduction ratio in stats.
     */
    estimateMemorySavings(totalLayers, activationSize) {
        const segments = this.config.checkpointSegments;
        const withoutCheckpointing = totalLayers * activationSize;
        const withCheckpointing =
            segments * activationSize + (totalLayers / segments) * activationSize;

        const savings = 1 - withCheckpointing / withoutCheckpointing;
        this.stats.memoryReduction = savings;

        return {
            without: withoutCheckpointing,
            with: withCheckpointing,
            savings: savings * 100,
            savingsPercentage: `${(savings * 100).toFixed(1)}%`,
        };
    }

    /**
     * Get checkpointer statistics, including current storage counts.
     */
    getStats() {
        return {
            ...this.stats,
            storedCheckpoints: this.checkpoints.length,
            storedFunctions: this.forwardFns.length,
        };
    }
}
|
|
1331
|
+
|
|
1332
|
+
// ============================================
|
|
1333
|
+
// LEARNING RATE SCHEDULERS
|
|
1334
|
+
// ============================================
|
|
1335
|
+
|
|
1336
|
+
/**
 * Learning rate scheduler implementations.
 *
 * Every scheduler shares the signature
 * `(baseLR, step, totalSteps, opts?) => number`, so they can be swapped
 * freely by the trainer.
 */
export const LRSchedulers = {
    /**
     * Constant learning rate (step/totalSteps are ignored).
     */
    constant: (baseLR, step, totalSteps) => baseLR,

    /**
     * Linear decay from baseLR to 0 over totalSteps.
     * Clamped at 0 so steps past totalSteps never yield a negative LR
     * (a negative LR would reverse the update direction).
     */
    linear: (baseLR, step, totalSteps) =>
        Math.max(0, baseLR * (1 - step / totalSteps)),

    /**
     * Cosine annealing from baseLR down to 0 over totalSteps.
     */
    cosine: (baseLR, step, totalSteps) =>
        baseLR * 0.5 * (1 + Math.cos(Math.PI * step / totalSteps)),

    /**
     * Cosine annealing with warm restarts every `restartPeriod` steps.
     *
     * @param {Object} [opts]
     * @param {number} [opts.restartPeriod=100] - Steps per cosine cycle
     * @param {number} [opts.restartMultiplier=2] - Accepted for API
     *   compatibility; growing the period between restarts is not implemented.
     */
    cosineWarmRestarts: (baseLR, step, totalSteps, opts = {}) => {
        const { restartPeriod = 100 } = opts;
        const cycleStep = step % restartPeriod;
        return baseLR * 0.5 * (1 + Math.cos(Math.PI * cycleStep / restartPeriod));
    },

    /**
     * Step-wise exponential decay: LR is multiplied by `gamma` once every
     * `stepSize` steps.
     */
    exponential: (baseLR, step, totalSteps, opts = {}) => {
        const { gamma = 0.95, stepSize = 100 } = opts;
        return baseLR * Math.pow(gamma, Math.floor(step / stepSize));
    },

    /**
     * Linear warmup for `warmupSteps`, then cosine decay to 0 over the
     * remaining steps. If totalSteps <= warmupSteps there is no decay phase
     * left, so 0 is returned instead of NaN from a zero-length decay window.
     */
    warmupCosine: (baseLR, step, totalSteps, opts = {}) => {
        const { warmupSteps = 100 } = opts;
        if (step < warmupSteps) {
            return baseLR * (step / warmupSteps);
        }
        const decayTotal = totalSteps - warmupSteps;
        if (decayTotal <= 0) {
            return 0; // degenerate schedule: nothing left to decay over
        }
        const decayStep = step - warmupSteps;
        return baseLR * 0.5 * (1 + Math.cos(Math.PI * decayStep / decayTotal));
    },

    /**
     * One-cycle policy: linear ramp from baseLR up to maxLR over the first
     * half of training, then linear anneal down to baseLR / finalDiv over
     * the second half.
     *
     * @param {Object} [opts]
     * @param {number} [opts.maxLR=baseLR*10] - Peak learning rate
     * @param {number} [opts.divFactor=25] - Accepted for API compatibility;
     *   the starting LR is taken from baseLR directly, so it is unused.
     * @param {number} [opts.finalDiv=10000] - Divisor for the final LR
     */
    oneCycle: (baseLR, step, totalSteps, opts = {}) => {
        const { maxLR = baseLR * 10, finalDiv = 10000 } = opts;
        const midPoint = totalSteps / 2;

        if (step < midPoint) {
            // Warmup phase
            const progress = step / midPoint;
            return baseLR + progress * (maxLR - baseLR);
        }
        // Annealing phase
        const progress = (step - midPoint) / midPoint;
        const finalLR = baseLR / finalDiv;
        return maxLR - progress * (maxLR - finalLR);
    },
};
|
|
1406
|
+
|
|
1407
|
+
// ============================================
|
|
1408
|
+
// EXPORTS
|
|
1409
|
+
// ============================================
|
|
1410
|
+
|
|
1411
|
+
// Aggregate default export: mirrors the named exports above so consumers can
// pull the whole training toolkit in as a single namespace object.
export default {
    DataPreprocessor,
    BatchGenerator,
    LossComputer,
    EWCManager,
    GradientCheckpointer,
    LRSchedulers,
};
|