@ruvector/edge-net 0.5.0 → 0.5.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1418 @@
1
+ /**
2
+ * Training Utilities for MicroLoRA
3
+ *
4
+ * Provides comprehensive training infrastructure including data preprocessing,
5
+ * batch generation, loss computation, EWC for continual learning, and
6
+ * gradient checkpointing for memory efficiency.
7
+ *
8
+ * @module @ruvector/edge-net/models/training-utils
9
+ *
10
+ * @example
11
+ * ```javascript
12
+ * import {
13
+ * DataPreprocessor,
14
+ * BatchGenerator,
15
+ * LossComputer,
16
+ * EWCManager,
17
+ * GradientCheckpointer
18
+ * } from '@ruvector/edge-net/models/training-utils';
19
+ *
20
+ * // Preprocess training data
21
+ * const preprocessor = new DataPreprocessor();
22
+ * const processed = await preprocessor.process(examples);
23
+ *
24
+ * // Generate batches
25
+ * const batcher = new BatchGenerator(processed, { batchSize: 8 });
26
+ * for (const batch of batcher) {
27
+ * const loss = LossComputer.contrastive(batch.anchors, batch.positives, batch.negatives);
28
+ * }
29
+ *
30
+ * // Enable EWC for continual learning
31
+ * const ewc = new EWCManager({ lambda: 2000 });
32
+ * ewc.computeFisher(model, dataloader);
33
+ * ```
34
+ */
35
+
36
+ import { EventEmitter } from 'events';
37
+ import { createHash } from 'crypto';
38
+
39
+ // ============================================
40
+ // TYPE DEFINITIONS (JSDoc)
41
+ // ============================================
42
+
43
+ /**
44
+ * @typedef {Object} TrainingExample
45
+ * @property {string} input - Input text
46
+ * @property {string} output - Expected output
47
+ * @property {number} [quality=1.0] - Example quality weight
48
+ * @property {Object} [metadata] - Optional metadata
49
+ */
50
+
51
+ /**
52
+ * @typedef {Object} ProcessedExample
53
+ * @property {string} input - Original input text
54
+ * @property {string} output - Original output text
55
+ * @property {Float32Array} inputEmb - Input embedding
56
+ * @property {Float32Array} outputEmb - Output embedding
57
+ * @property {Float32Array} [negativeEmb] - Optional negative embedding
58
+ * @property {number} quality - Example quality weight
59
+ * @property {number[]} inputTokens - Tokenized input
60
+ * @property {number[]} outputTokens - Tokenized output
61
+ */
62
+
63
+ /**
64
+ * @typedef {Object} Batch
65
+ * @property {ProcessedExample[]} examples - Batch examples
66
+ * @property {Float32Array[]} anchors - Anchor embeddings
67
+ * @property {Float32Array[]} positives - Positive embeddings
68
+ * @property {Float32Array[]} [negatives] - Negative embeddings
69
+ * @property {number} size - Batch size
70
+ */
71
+
72
+ /**
73
+ * @typedef {Object} LossResult
74
+ * @property {number} loss - Computed loss value
75
+ * @property {Float32Array[]} [gradients] - Optional gradients
76
+ * @property {Object} [components] - Loss components breakdown
77
+ */
78
+
79
+ // ============================================
80
+ // DATA PREPROCESSOR
81
+ // ============================================
82
+
83
/**
 * DataPreprocessor - Prepares training data for MicroLoRA.
 *
 * Handles tokenization, embedding generation, data augmentation,
 * and quality filtering of training examples.
 *
 * @example
 * ```javascript
 * const preprocessor = new DataPreprocessor({
 *   maxLength: 512,
 *   augmentation: true,
 *   qualityThreshold: 0.5
 * });
 *
 * const processed = await preprocessor.process([
 *   { input: 'Hello', output: 'World' }
 * ]);
 * ```
 */
export class DataPreprocessor extends EventEmitter {
  /**
   * Create a DataPreprocessor.
   *
   * @param {Object} [config={}] - Preprocessor configuration.
   * @param {number} [config.maxLength=512] - Max token sequence length; longer sequences are truncated.
   * @param {number} [config.embeddingDim=384] - Dimensionality of generated embeddings.
   * @param {boolean} [config.augmentation=false] - Whether to generate augmented variants.
   * @param {number} [config.augmentationFactor=2] - Max total variants per example (original + augmented).
   * @param {number} [config.qualityThreshold=0.0] - Examples with quality below this are dropped.
   * @param {boolean} [config.normalizeEmbeddings=true] - L2-normalize generated embeddings.
   * @param {number} [config.maxCacheSize=10000] - Max entries kept in the embedding cache.
   *   Oldest entries are evicted once the cap is reached (previously the cache grew without bound,
   *   which leaks memory in long-running training jobs).
   */
  constructor(config = {}) {
    super();

    this.config = {
      maxLength: 512,
      embeddingDim: 384,
      augmentation: false,
      augmentationFactor: 2,
      qualityThreshold: 0.0,
      normalizeEmbeddings: true,
      maxCacheSize: 10000,
      ...config,
    };

    // Simple word-level vocabulary built lazily during tokenization.
    this.vocab = new Map();
    this.vocabSize = 0;

    // Bounded cache of text -> Float32Array embeddings (keyed by content hash).
    this.embeddingCache = new Map();

    this.stats = {
      processed: 0,
      filtered: 0,
      augmented: 0,
      cacheHits: 0,
    };
  }

  /**
   * Process training examples: quality-filter, tokenize, embed, and
   * (optionally) augment each example.
   *
   * Emits 'process:complete' with `{ total, stats }` when done.
   *
   * @param {TrainingExample[]} examples - Raw training examples.
   * @returns {Promise<ProcessedExample[]>} Processed examples.
   */
  async process(examples) {
    const processed = [];

    for (const example of examples) {
      // Quality filter: missing quality defaults to 1.0 (always kept
      // unless the threshold is > 1).
      if ((example.quality || 1.0) < this.config.qualityThreshold) {
        this.stats.filtered++;
        continue;
      }

      const result = await this._processExample(example);
      if (result) {
        processed.push(result);
        this.stats.processed++;
      }

      // Optional augmentation: each variant is processed like a normal example.
      if (this.config.augmentation) {
        const augmented = await this._augmentExample(example);
        for (const aug of augmented) {
          const augResult = await this._processExample(aug);
          if (augResult) {
            processed.push(augResult);
            this.stats.augmented++;
          }
        }
      }
    }

    this.emit('process:complete', {
      total: processed.length,
      stats: this.stats,
    });

    return processed;
  }

  /**
   * Process a single example into tokens and embeddings.
   *
   * Note: tokens are truncated to `maxLength`, but embeddings are computed
   * from the full (untruncated) text.
   * @private
   */
  async _processExample(example) {
    const inputTokens = this._tokenize(example.input);
    const outputTokens = this._tokenize(example.output);

    // Truncate token sequences if needed
    const truncatedInput = inputTokens.slice(0, this.config.maxLength);
    const truncatedOutput = outputTokens.slice(0, this.config.maxLength);

    // Generate embeddings from the raw text
    const inputEmb = this._embed(example.input);
    const outputEmb = this._embed(example.output);

    return {
      input: example.input,
      output: example.output,
      inputEmb,
      outputEmb,
      quality: example.quality || 1.0,
      inputTokens: truncatedInput,
      outputTokens: truncatedOutput,
      metadata: example.metadata,
    };
  }

  /**
   * Simple whitespace tokenization with a lazily-grown word vocabulary.
   *
   * Once the vocabulary reaches 50000 entries, unseen words fall back to
   * character-level tokens (char code mod 256).
   * NOTE(review): character fallback ids (0-255) can collide with early
   * word ids — acceptable for this toy tokenizer, but worth confirming.
   * @private
   */
  _tokenize(text) {
    const tokens = [];

    const words = text.split(/\s+/);
    for (const word of words) {
      if (this.vocab.has(word)) {
        tokens.push(this.vocab.get(word));
      } else {
        if (this.vocabSize < 50000) {
          this.vocab.set(word, this.vocabSize);
          tokens.push(this.vocabSize);
          this.vocabSize++;
        } else {
          // Vocabulary full: fall back to character-level tokens
          for (const char of word) {
            const charToken = char.charCodeAt(0) % 256;
            tokens.push(charToken);
          }
        }
      }
      // Space separator token (ASCII 32)
      tokens.push(32);
    }

    return tokens;
  }

  /**
   * Generate a deterministic hash-based embedding for text.
   *
   * The md5 digest is only a cache key, not a security boundary.
   * @private
   */
  _embed(text) {
    const cacheKey = createHash('md5').update(text).digest('hex');
    if (this.embeddingCache.has(cacheKey)) {
      this.stats.cacheHits++;
      return this.embeddingCache.get(cacheKey);
    }

    const dim = this.config.embeddingDim;
    const embedding = new Float32Array(dim);

    // Hash-based pseudo-embedding: bytes mapped roughly into [-1, 1)
    const hash = createHash('sha256').update(text).digest();
    for (let i = 0; i < dim; i++) {
      embedding[i] = (hash[i % 32] - 128) / 128;
    }

    // Add positional character features so similar prefixes land closer together
    for (let i = 0; i < text.length && i < dim; i++) {
      embedding[i] += (text.charCodeAt(i) - 64) / 256;
    }

    // L2-normalize (guarding against a zero vector with `|| 1`)
    if (this.config.normalizeEmbeddings) {
      let norm = 0;
      for (let i = 0; i < dim; i++) {
        norm += embedding[i] * embedding[i];
      }
      norm = Math.sqrt(norm) || 1;
      for (let i = 0; i < dim; i++) {
        embedding[i] /= norm;
      }
    }

    // Bounded cache: evict the oldest entry (Map preserves insertion order)
    // so long-running jobs cannot grow memory without limit.
    if (this.embeddingCache.size >= this.config.maxCacheSize) {
      this.embeddingCache.delete(this.embeddingCache.keys().next().value);
    }
    this.embeddingCache.set(cacheKey, embedding);

    return embedding;
  }

  /**
   * Augment a training example with randomized text variations.
   *
   * Each variant keeps the original output and slightly discounts quality.
   * Returns at most `augmentationFactor - 1` variants.
   * @private
   */
  async _augmentExample(example) {
    const augmented = [];

    // Synonym replacement (simplified)
    if (Math.random() < 0.5) {
      augmented.push({
        input: this._synonymReplace(example.input),
        output: example.output,
        quality: (example.quality || 1.0) * 0.9,
      });
    }

    // Random insertion
    if (Math.random() < 0.3) {
      augmented.push({
        input: this._randomInsert(example.input),
        output: example.output,
        quality: (example.quality || 1.0) * 0.85,
      });
    }

    // Case variation
    if (Math.random() < 0.3) {
      augmented.push({
        input: this._caseVariation(example.input),
        output: example.output,
        quality: (example.quality || 1.0) * 0.95,
      });
    }

    return augmented.slice(0, this.config.augmentationFactor - 1);
  }

  /**
   * Replace the first matching keyword with a random synonym.
   * @private
   */
  _synonymReplace(text) {
    const synonyms = {
      'write': ['create', 'make', 'generate'],
      'function': ['method', 'procedure', 'routine'],
      'code': ['program', 'script', 'implementation'],
      'help': ['assist', 'aid', 'support'],
      'explain': ['describe', 'clarify', 'elaborate'],
    };

    let result = text;
    for (const [word, syns] of Object.entries(synonyms)) {
      if (result.toLowerCase().includes(word)) {
        const syn = syns[Math.floor(Math.random() * syns.length)];
        result = result.replace(new RegExp(word, 'gi'), syn);
        break; // only substitute one keyword per call
      }
    }
    return result;
  }

  /**
   * Insert a random filler word at a random position.
   * @private
   */
  _randomInsert(text) {
    const words = text.split(' ');
    const insertWords = ['please', 'now', 'just', 'simply'];
    const insertWord = insertWords[Math.floor(Math.random() * insertWords.length)];
    const position = Math.floor(Math.random() * words.length);
    words.splice(position, 0, insertWord);
    return words.join(' ');
  }

  /**
   * Return a random casing variant of the text.
   * @private
   */
  _caseVariation(text) {
    const variations = [
      text.toLowerCase(),
      text.charAt(0).toUpperCase() + text.slice(1).toLowerCase(),
      text.toUpperCase(),
    ];
    return variations[Math.floor(Math.random() * variations.length)];
  }

  /**
   * Clear the embedding cache and its hit counter.
   */
  clearCache() {
    this.embeddingCache.clear();
    this.stats.cacheHits = 0;
  }

  /**
   * Get preprocessor statistics (counters plus vocab/cache sizes).
   * @returns {Object}
   */
  getStats() {
    return {
      ...this.stats,
      vocabSize: this.vocabSize,
      cacheSize: this.embeddingCache.size,
    };
  }
}
392
+
393
+ // ============================================
394
+ // BATCH GENERATOR
395
+ // ============================================
396
+
397
/**
 * BatchGenerator - Generates training batches with shuffling and sampling.
 *
 * Supports various batching strategies including random sampling,
 * hard negative mining, and curriculum learning. The instance is itself
 * iterable (`for (const batch of batcher)`).
 *
 * @example
 * ```javascript
 * const batcher = new BatchGenerator(processedData, {
 *   batchSize: 8,
 *   shuffle: true,
 *   dropLast: false
 * });
 *
 * for (const batch of batcher) {
 *   console.log(`Batch size: ${batch.size}`);
 * }
 * ```
 */
export class BatchGenerator {
  /**
   * Create a BatchGenerator.
   *
   * @param {ProcessedExample[]} data - Processed training data (copied defensively).
   * @param {Object} [config={}] - Generator configuration.
   * @param {number} [config.batchSize=8] - Examples per batch.
   * @param {boolean} [config.shuffle=true] - Shuffle indices each epoch.
   * @param {boolean} [config.dropLast=false] - Drop the final incomplete batch.
   * @param {boolean} [config.hardNegatives=false] - Mine in-batch hard negatives.
   * @param {boolean} [config.curriculum=false] - Sort high-quality examples first in early epochs.
   * @param {number} [config.curriculumEpochs=5] - Number of epochs curriculum ordering applies.
   */
  constructor(data, config = {}) {
    this.data = [...data];
    this.config = {
      batchSize: 8,
      shuffle: true,
      dropLast: false,
      hardNegatives: false,
      curriculum: false,
      curriculumEpochs: 5,
      ...config,
    };

    this.currentIndex = 0;
    this.epoch = 0;
    this.indices = this._createIndices();
  }

  /**
   * Create the index array, optionally shuffled and/or curriculum-sorted.
   *
   * Note: curriculum sorting runs after shuffling, so during curriculum
   * epochs the quality ordering dominates.
   * @private
   */
  _createIndices() {
    const indices = Array.from({ length: this.data.length }, (_, i) => i);

    if (this.config.shuffle) {
      // Fisher-Yates shuffle
      for (let i = indices.length - 1; i > 0; i--) {
        const j = Math.floor(Math.random() * (i + 1));
        [indices[i], indices[j]] = [indices[j], indices[i]];
      }
    }

    // Curriculum learning: highest-quality examples first in early epochs
    if (this.config.curriculum && this.epoch < this.config.curriculumEpochs) {
      indices.sort((a, b) => {
        const qualityA = this.data[a].quality || 1;
        const qualityB = this.data[b].quality || 1;
        return qualityB - qualityA;
      });
    }

    return indices;
  }

  /**
   * Iterator protocol: the generator is its own iterator.
   */
  [Symbol.iterator]() {
    return this;
  }

  /**
   * Get the next batch.
   *
   * @returns {{value: Batch, done: boolean}} Iterator result; `done: true`
   *   when all indices are consumed (or only an incomplete batch remains
   *   and `dropLast` is set).
   */
  next() {
    // Exhausted. (A previous version had a redundant dropLast branch here
    // that returned the same `{ done: true }` on both paths.)
    if (this.currentIndex >= this.indices.length) {
      return { done: true };
    }

    const endIndex = Math.min(
      this.currentIndex + this.config.batchSize,
      this.indices.length
    );

    // Skip a trailing incomplete batch when dropLast is enabled
    if (this.config.dropLast &&
        endIndex - this.currentIndex < this.config.batchSize) {
      return { done: true };
    }

    const batchIndices = this.indices.slice(this.currentIndex, endIndex);
    const examples = batchIndices.map(i => this.data[i]);

    // Anchor/positive pairs come straight from the processed embeddings
    const anchors = examples.map(e => e.inputEmb);
    const positives = examples.map(e => e.outputEmb);

    // Mine in-batch hard negatives if requested
    let negatives = null;
    if (this.config.hardNegatives) {
      negatives = this._mineHardNegatives(anchors, positives);
    }

    this.currentIndex = endIndex;

    return {
      value: {
        examples,
        anchors,
        positives,
        negatives,
        size: examples.length,
      },
      done: false,
    };
  }

  /**
   * Mine hard negatives for contrastive learning: for each anchor, pick
   * the most similar positive belonging to a *different* example. Falls
   * back to the next positive in the batch if none was found.
   * @private
   */
  _mineHardNegatives(anchors, positives) {
    const negatives = [];

    for (let i = 0; i < anchors.length; i++) {
      let hardestIdx = -1;
      let hardestSim = -Infinity;

      for (let j = 0; j < positives.length; j++) {
        if (i === j) continue;

        const sim = this._cosineSimilarity(anchors[i], positives[j]);
        if (sim > hardestSim) {
          hardestSim = sim;
          hardestIdx = j;
        }
      }

      negatives.push(hardestIdx >= 0 ? positives[hardestIdx] : positives[(i + 1) % positives.length]);
    }

    return negatives;
  }

  /**
   * Cosine similarity between two embeddings (epsilon-guarded denominator).
   * @private
   */
  _cosineSimilarity(a, b) {
    let dot = 0, normA = 0, normB = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-8);
  }

  /**
   * Reset for a new epoch: rewind the cursor, bump the epoch counter,
   * and regenerate (re-shuffle / re-sort) the index order.
   */
  reset() {
    this.currentIndex = 0;
    this.epoch++;
    this.indices = this._createIndices();
  }

  /**
   * Total number of batches per epoch.
   * @returns {number}
   */
  get length() {
    if (this.config.dropLast) {
      return Math.floor(this.data.length / this.config.batchSize);
    }
    return Math.ceil(this.data.length / this.config.batchSize);
  }

  /**
   * Number of batches remaining in the current epoch.
   * @returns {number}
   */
  get remaining() {
    const remaining = this.indices.length - this.currentIndex;
    if (this.config.dropLast) {
      return Math.floor(remaining / this.config.batchSize);
    }
    return Math.ceil(remaining / this.config.batchSize);
  }
}
599
+
600
+ // ============================================
601
+ // LOSS COMPUTER
602
+ // ============================================
603
+
604
/**
 * LossComputer - Various loss functions for training.
 *
 * Implements multiple loss functions optimized for different training
 * scenarios: contrastive learning, cross-entropy, triplet loss, etc.
 * All methods are static; every method returns 0 loss for empty input
 * instead of dividing by zero (NaN).
 *
 * @example
 * ```javascript
 * // Contrastive loss
 * const loss = LossComputer.contrastive(anchors, positives, negatives, {
 *   temperature: 0.07,
 *   margin: 0.5
 * });
 *
 * // Cross-entropy loss
 * const ceLoss = LossComputer.crossEntropy(predictions, targets);
 *
 * // Combined loss
 * const combined = LossComputer.combine([
 *   { loss: contrastiveLoss, weight: 0.7 },
 *   { loss: ceLoss, weight: 0.3 }
 * ]);
 * ```
 */
export class LossComputer {
  /**
   * Contrastive loss (InfoNCE).
   *
   * Uses explicit negatives when provided, otherwise in-batch negatives
   * (every other positive in the batch).
   *
   * @param {Float32Array[]} anchors - Anchor embeddings
   * @param {Float32Array[]} positives - Positive embeddings
   * @param {Float32Array[]} [negatives] - Negative embeddings
   * @param {Object} [options={}] - Loss options
   * @param {number} [options.temperature=0.07] - Softmax temperature
   * @param {number} [options.margin=0.0] - Similarity margin subtracted before scaling
   * @returns {LossResult}
   */
  static contrastive(anchors, positives, negatives = null, options = {}) {
    const { temperature = 0.07, margin = 0.0 } = options;

    const n = anchors.length;
    if (n === 0) {
      return { loss: 0, components: { contrastive: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const anchor = anchors[i];
      const positive = positives[i];

      const posSim = LossComputer._cosineSimilarity(anchor, positive);

      // Sum of exp-scaled negative similarities (explicit or in-batch)
      let negSum = 0;
      if (negatives) {
        for (let j = 0; j < negatives.length; j++) {
          const negSim = LossComputer._cosineSimilarity(anchor, negatives[j]);
          negSum += Math.exp((negSim - margin) / temperature);
        }
      } else {
        // In-batch negatives: all other positives
        for (let j = 0; j < n; j++) {
          if (i === j) continue;
          const negSim = LossComputer._cosineSimilarity(anchor, positives[j]);
          negSum += Math.exp((negSim - margin) / temperature);
        }
      }

      // InfoNCE: -log(exp(pos/t) / (exp(pos/t) + sum(exp(neg/t))))
      const posExp = Math.exp((posSim - margin) / temperature);
      const loss = -Math.log(posExp / (posExp + negSum + 1e-8));
      totalLoss += loss;
    }

    return {
      loss: totalLoss / n,
      components: { contrastive: totalLoss / n },
    };
  }

  /**
   * Triplet loss with margin: max(0, d(a, p) - d(a, n) + margin).
   *
   * @param {Float32Array[]} anchors - Anchor embeddings
   * @param {Float32Array[]} positives - Positive embeddings
   * @param {Float32Array[]} negatives - Negative embeddings
   * @param {Object} [options={}] - Loss options
   * @param {number} [options.margin=0.5] - Separation margin
   * @returns {LossResult}
   */
  static triplet(anchors, positives, negatives, options = {}) {
    const { margin = 0.5 } = options;

    const n = anchors.length;
    if (n === 0) {
      return { loss: 0, components: { triplet: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const posDistance = LossComputer._euclideanDistance(anchors[i], positives[i]);
      const negDistance = LossComputer._euclideanDistance(anchors[i], negatives[i]);

      totalLoss += Math.max(0, posDistance - negDistance + margin);
    }

    return {
      loss: totalLoss / n,
      components: { triplet: totalLoss / n },
    };
  }

  /**
   * Cross-entropy loss over logits, with optional label smoothing.
   *
   * @param {number[][]} predictions - Predicted logits (one row per example)
   * @param {number[]} targets - Target class indices
   * @param {Object} [options={}] - Loss options
   * @param {number} [options.labelSmoothing=0.0] - Smoothing mass spread uniformly over classes
   * @returns {LossResult}
   */
  static crossEntropy(predictions, targets, options = {}) {
    const { labelSmoothing = 0.0 } = options;

    const n = predictions.length;
    if (n === 0) {
      return { loss: 0, components: { crossEntropy: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const pred = predictions[i];
      const target = targets[i];

      // Numerically stable log-softmax (max-logit subtraction)
      const maxLogit = Math.max(...pred);
      const expSum = pred.reduce((sum, p) => sum + Math.exp(p - maxLogit), 0);
      const logProbs = pred.map(p => p - maxLogit - Math.log(expSum));

      if (labelSmoothing > 0) {
        // Distribute labelSmoothing mass uniformly; keep the rest on the target
        const numClasses = pred.length;
        const smoothTarget = new Array(numClasses).fill(labelSmoothing / numClasses);
        smoothTarget[target] = 1 - labelSmoothing + labelSmoothing / numClasses;

        let loss = 0;
        for (let c = 0; c < numClasses; c++) {
          loss -= smoothTarget[c] * logProbs[c];
        }
        totalLoss += loss;
      } else {
        totalLoss -= logProbs[target];
      }
    }

    return {
      loss: totalLoss / n,
      components: { crossEntropy: totalLoss / n },
    };
  }

  /**
   * Mean Squared Error loss, averaged per dimension and per example.
   *
   * @param {Float32Array[]} predictions - Predicted embeddings
   * @param {Float32Array[]} targets - Target embeddings
   * @returns {LossResult}
   */
  static mse(predictions, targets) {
    const n = predictions.length;
    if (n === 0) {
      return { loss: 0, components: { mse: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const pred = predictions[i];
      const target = targets[i];
      const dim = Math.min(pred.length, target.length);

      let loss = 0;
      for (let d = 0; d < dim; d++) {
        const diff = pred[d] - target[d];
        loss += diff * diff;
      }
      totalLoss += loss / dim;
    }

    return {
      loss: totalLoss / n,
      components: { mse: totalLoss / n },
    };
  }

  /**
   * Cosine embedding loss: (1 - sim) for similar pairs (label 1),
   * max(0, sim - margin) for dissimilar pairs (label -1).
   *
   * @param {Float32Array[]} predictions - Predicted embeddings
   * @param {Float32Array[]} targets - Target embeddings
   * @param {number[]} [labels] - Labels: 1 for similar, -1 for dissimilar (default: all similar)
   * @param {Object} [options={}] - Loss options
   * @param {number} [options.margin=0.0] - Margin for dissimilar pairs
   * @returns {LossResult}
   */
  static cosineEmbedding(predictions, targets, labels = null, options = {}) {
    const { margin = 0.0 } = options;

    const n = predictions.length;
    if (n === 0) {
      return { loss: 0, components: { cosineEmbedding: 0 } };
    }

    let totalLoss = 0;

    for (let i = 0; i < n; i++) {
      const sim = LossComputer._cosineSimilarity(predictions[i], targets[i]);
      const label = labels ? labels[i] : 1;

      if (label === 1) {
        totalLoss += 1 - sim;
      } else {
        totalLoss += Math.max(0, sim - margin);
      }
    }

    return {
      loss: totalLoss / n,
      components: { cosineEmbedding: totalLoss / n },
    };
  }

  /**
   * Combine multiple losses with weights. Component breakdowns are merged
   * (weighted) under their original names.
   *
   * @param {Array<{loss: LossResult, weight: number}>} losses - Weighted losses
   * @returns {LossResult}
   */
  static combine(losses) {
    let totalLoss = 0;
    const components = {};

    for (const { loss, weight } of losses) {
      totalLoss += loss.loss * weight;
      for (const [name, value] of Object.entries(loss.components || {})) {
        components[name] = (components[name] || 0) + value * weight;
      }
    }

    return {
      loss: totalLoss,
      components,
    };
  }

  /**
   * Cosine similarity helper (epsilon-guarded denominator).
   * @private
   */
  static _cosineSimilarity(a, b) {
    let dot = 0, normA = 0, normB = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      dot += a[i] * b[i];
      normA += a[i] * a[i];
      normB += b[i] * b[i];
    }
    return dot / (Math.sqrt(normA) * Math.sqrt(normB) + 1e-8);
  }

  /**
   * Euclidean distance helper (over the shorter of the two vectors).
   * @private
   */
  static _euclideanDistance(a, b) {
    let sum = 0;
    const len = Math.min(a.length, b.length);
    for (let i = 0; i < len; i++) {
      const diff = a[i] - b[i];
      sum += diff * diff;
    }
    return Math.sqrt(sum);
  }
}
870
+
871
+ // ============================================
872
+ // EWC MANAGER
873
+ // ============================================
874
+
875
+ /**
876
+ * EWCManager - Elastic Weight Consolidation for Continual Learning
877
+ *
878
+ * Prevents catastrophic forgetting when training on new tasks by
879
+ * regularizing important weights based on Fisher information.
880
+ *
881
+ * @example
882
+ * ```javascript
883
+ * const ewc = new EWCManager({ lambda: 2000 });
884
+ *
885
+ * // After training on task 1
886
+ * ewc.computeFisher(adapters, task1Data);
887
+ *
888
+ * // When training on task 2, add EWC penalty
889
+ * const ewcLoss = ewc.computePenalty(currentAdapters);
890
+ * totalLoss = taskLoss + ewcLoss;
891
+ * ```
892
+ */
893
+ export class EWCManager extends EventEmitter {
894
+ /**
895
+ * Create an EWCManager
896
+ *
897
+ * @param {Object} [config={}] - EWC configuration
898
+ */
899
+ constructor(config = {}) {
900
+ super();
901
+
902
+ this.config = {
903
+ lambda: 2000, // Regularization strength
904
+ sampleSize: 200, // Samples for Fisher estimation
905
+ normalize: true, // Normalize Fisher values
906
+ online: false, // Online EWC (cumulative Fisher)
907
+ ...config,
908
+ };
909
+
910
+ // Stored Fisher information and optimal parameters
911
+ this.fisherInfo = new Map();
912
+ this.optimalParams = new Map();
913
+
914
+ this.stats = {
915
+ tasksLearned: 0,
916
+ totalPenalties: 0,
917
+ };
918
+ }
919
+
920
  /**
   * Estimate the diagonal Fisher information for each adapter and snapshot
   * the current parameters as the "optimal" values for the just-finished task.
   *
   * Fisher diagonal is approximated as the mean of squared per-example
   * gradients over a (sub)sample of the data. With `config.online`, the new
   * estimate is averaged 50/50 with the previously stored one.
   *
   * Emits 'fisher:start' with `{ samples }` and 'fisher:complete' with
   * `{ adapters }`.
   *
   * NOTE(review): `config.normalize` is never consulted here — the division
   * by the sample count always happens; confirm that is intended.
   *
   * @param {Map<string, Object>} adapters - Adapter weights; each adapter is
   *   assumed to expose `loraA` and `loraB` as 2-D arrays (per the indexing
   *   below) plus a `scaling` factor used by `_computeGradients`.
   * @param {ProcessedExample[]} data - Training data for estimation.
   */
  computeFisher(adapters, data) {
    this.emit('fisher:start', { samples: data.length });

    // Subsample (with replacement, see _sampleData) when the dataset is
    // larger than the configured estimation budget.
    const sampleData = data.length > this.config.sampleSize
      ? this._sampleData(data, this.config.sampleSize)
      : data;

    for (const [name, adapter] of adapters) {
      // Deep-copy the current weights so later training cannot mutate
      // the stored "optimal" snapshot.
      this.optimalParams.set(name, {
        loraA: adapter.loraA.map(row => [...row]),
        loraB: adapter.loraB.map(row => [...row]),
      });

      // Fisher diagonal accumulators, same shapes as loraA / loraB.
      const fisherA = this._zeros(adapter.loraA.length, adapter.loraA[0].length);
      const fisherB = this._zeros(adapter.loraB.length, adapter.loraB[0].length);

      for (const example of sampleData) {
        const grads = this._computeGradients(example, adapter);

        // Accumulate squared gradients (diagonal Fisher approximation).
        for (let i = 0; i < fisherA.length; i++) {
          for (let j = 0; j < fisherA[0].length; j++) {
            fisherA[i][j] += grads.gradA[i][j] * grads.gradA[i][j];
          }
        }
        for (let i = 0; i < fisherB.length; i++) {
          for (let j = 0; j < fisherB[0].length; j++) {
            fisherB[i][j] += grads.gradB[i][j] * grads.gradB[i][j];
          }
        }
      }

      // Average over the sample to get the mean squared gradient.
      const n = sampleData.length;
      for (let i = 0; i < fisherA.length; i++) {
        for (let j = 0; j < fisherA[0].length; j++) {
          fisherA[i][j] /= n;
        }
      }
      for (let i = 0; i < fisherB.length; i++) {
        for (let j = 0; j < fisherB[0].length; j++) {
          fisherB[i][j] /= n;
        }
      }

      // Online EWC: blend equally with the previously stored Fisher so
      // importance from earlier tasks decays geometrically.
      if (this.config.online && this.fisherInfo.has(name)) {
        const prevFisher = this.fisherInfo.get(name);
        for (let i = 0; i < fisherA.length; i++) {
          for (let j = 0; j < fisherA[0].length; j++) {
            fisherA[i][j] = 0.5 * (fisherA[i][j] + prevFisher.fisherA[i][j]);
          }
        }
        for (let i = 0; i < fisherB.length; i++) {
          for (let j = 0; j < fisherB[0].length; j++) {
            fisherB[i][j] = 0.5 * (fisherB[i][j] + prevFisher.fisherB[i][j]);
          }
        }
      }

      this.fisherInfo.set(name, { fisherA, fisherB });
    }

    this.stats.tasksLearned++;
    this.emit('fisher:complete', { adapters: adapters.size });
  }
994
+
995
+ /**
996
+ * Compute EWC penalty for current adapter values
997
+ *
998
+ * @param {Map<string, Object>} adapters - Current adapter weights
999
+ * @returns {number} EWC penalty value
1000
+ */
1001
+ computePenalty(adapters) {
1002
+ if (this.fisherInfo.size === 0) {
1003
+ return 0;
1004
+ }
1005
+
1006
+ let penalty = 0;
1007
+
1008
+ for (const [name, adapter] of adapters) {
1009
+ const fisher = this.fisherInfo.get(name);
1010
+ const optimal = this.optimalParams.get(name);
1011
+
1012
+ if (!fisher || !optimal) continue;
1013
+
1014
+ // Sum of F_i * (theta_i - theta*_i)^2
1015
+ for (let i = 0; i < adapter.loraA.length; i++) {
1016
+ for (let j = 0; j < adapter.loraA[0].length; j++) {
1017
+ const diff = adapter.loraA[i][j] - optimal.loraA[i][j];
1018
+ penalty += fisher.fisherA[i][j] * diff * diff;
1019
+ }
1020
+ }
1021
+ for (let i = 0; i < adapter.loraB.length; i++) {
1022
+ for (let j = 0; j < adapter.loraB[0].length; j++) {
1023
+ const diff = adapter.loraB[i][j] - optimal.loraB[i][j];
1024
+ penalty += fisher.fisherB[i][j] * diff * diff;
1025
+ }
1026
+ }
1027
+ }
1028
+
1029
+ this.stats.totalPenalties++;
1030
+ return this.config.lambda * penalty * 0.5;
1031
+ }
1032
+
1033
+ /**
1034
+ * Apply EWC gradient to adapters
1035
+ *
1036
+ * @param {Map<string, Object>} adapters - Adapter weights to update
1037
+ * @param {number} learningRate - Learning rate
1038
+ */
1039
+ applyGradient(adapters, learningRate) {
1040
+ for (const [name, adapter] of adapters) {
1041
+ const fisher = this.fisherInfo.get(name);
1042
+ const optimal = this.optimalParams.get(name);
1043
+
1044
+ if (!fisher || !optimal) continue;
1045
+
1046
+ for (let i = 0; i < adapter.loraA.length; i++) {
1047
+ for (let j = 0; j < adapter.loraA[0].length; j++) {
1048
+ const diff = adapter.loraA[i][j] - optimal.loraA[i][j];
1049
+ adapter.loraA[i][j] -= learningRate * this.config.lambda * fisher.fisherA[i][j] * diff;
1050
+ }
1051
+ }
1052
+ for (let i = 0; i < adapter.loraB.length; i++) {
1053
+ for (let j = 0; j < adapter.loraB[0].length; j++) {
1054
+ const diff = adapter.loraB[i][j] - optimal.loraB[i][j];
1055
+ adapter.loraB[i][j] -= learningRate * this.config.lambda * fisher.fisherB[i][j] * diff;
1056
+ }
1057
+ }
1058
+ }
1059
+ }
1060
+
1061
  /**
   * Compute per-example gradients of a squared-error reconstruction loss
   * with respect to the adapter's LoRA matrices.
   *
   * Forward pass (inferred from the indexing below — loraA is dim x rank,
   * loraB is rank x dim, scaling is a scalar):
   *   hidden  = input^T . loraA
   *   adapted = input + scaling * (hidden . loraB)
   * Loss (implicit): mean squared error between `adapted` and the target
   * embedding; gradOutput below is d(MSE)/d(adapted).
   *
   * NOTE(review): target dimensions beyond `dim` are ignored and missing
   * target entries are treated as 0 (`target[i] || 0`) — confirm intended.
   *
   * @private
   * @param {ProcessedExample} example - Provides inputEmb / outputEmb.
   * @param {Object} adapter - Adapter with loraA, loraB, scaling.
   * @returns {{gradA: number[][], gradB: number[][]}} Gradients matching
   *   the shapes of loraA (dim x rank) and loraB (rank x dim).
   */
  _computeGradients(example, adapter) {
    const input = example.inputEmb;
    const target = example.outputEmb;
    const rank = adapter.loraA[0].length;
    // Clamp to the smaller of embedding size and adapter input size.
    const dim = Math.min(input.length, adapter.loraA.length);

    // Forward: hidden = input^T . loraA
    const hidden = new Float64Array(rank);
    for (let r = 0; r < rank; r++) {
      for (let d = 0; d < dim; d++) {
        hidden[r] += input[d] * adapter.loraA[d][r];
      }
    }

    // Forward: adapted = input + scaling * (hidden . loraB)
    const adapted = [...input];
    for (let d = 0; d < dim; d++) {
      for (let r = 0; r < rank; r++) {
        adapted[d] += adapter.scaling * hidden[r] * adapter.loraB[r][d];
      }
    }

    // d(MSE)/d(adapted) = 2 * (adapted - target) / dim
    const gradOutput = adapted.map((val, i) =>
      2 * (val - (target[i] || 0)) / dim
    );

    // dL/dB[r][d] = hidden[r] * gradOutput[d] * scaling
    const gradB = this._zeros(rank, dim);
    for (let r = 0; r < rank; r++) {
      for (let d = 0; d < dim; d++) {
        gradB[r][d] = hidden[r] * gradOutput[d] * adapter.scaling;
      }
    }

    // Backprop through B: dL/dhidden[r] = sum_d gradOutput[d] * B[r][d] * scaling
    const gradHidden = new Float64Array(rank);
    for (let r = 0; r < rank; r++) {
      for (let d = 0; d < dim; d++) {
        gradHidden[r] += gradOutput[d] * adapter.loraB[r][d] * adapter.scaling;
      }
    }

    // dL/dA[d][r] = input[d] * gradHidden[r]
    const gradA = this._zeros(dim, rank);
    for (let d = 0; d < dim; d++) {
      for (let r = 0; r < rank; r++) {
        gradA[d][r] = input[d] * gradHidden[r];
      }
    }

    return { gradA, gradB };
  }
1117
+
1118
+ /**
1119
+ * Random sample from data
1120
+ * @private
1121
+ */
1122
+ _sampleData(data, size) {
1123
+ const indices = [];
1124
+ for (let i = 0; i < size; i++) {
1125
+ indices.push(Math.floor(Math.random() * data.length));
1126
+ }
1127
+ return indices.map(i => data[i]);
1128
+ }
1129
+
1130
+ /**
1131
+ * Zero matrix helper
1132
+ * @private
1133
+ */
1134
+ _zeros(rows, cols) {
1135
+ return Array(rows).fill(null).map(() => Array(cols).fill(0));
1136
+ }
1137
+
1138
+ /**
1139
+ * Clear stored Fisher information
1140
+ */
1141
+ reset() {
1142
+ this.fisherInfo.clear();
1143
+ this.optimalParams.clear();
1144
+ this.stats.tasksLearned = 0;
1145
+ }
1146
+
1147
+ /**
1148
+ * Export EWC state for persistence
1149
+ */
1150
+ export() {
1151
+ return {
1152
+ fisherInfo: Object.fromEntries(this.fisherInfo),
1153
+ optimalParams: Object.fromEntries(this.optimalParams),
1154
+ config: this.config,
1155
+ stats: this.stats,
1156
+ };
1157
+ }
1158
+
1159
+ /**
1160
+ * Import EWC state
1161
+ */
1162
+ import(data) {
1163
+ this.fisherInfo = new Map(Object.entries(data.fisherInfo || {}));
1164
+ this.optimalParams = new Map(Object.entries(data.optimalParams || {}));
1165
+ if (data.config) this.config = { ...this.config, ...data.config };
1166
+ if (data.stats) this.stats = data.stats;
1167
+ }
1168
+ }
1169
+
1170
+ // ============================================
1171
+ // GRADIENT CHECKPOINTER
1172
+ // ============================================
1173
+
1174
/**
 * GradientCheckpointer - Memory-efficient training via gradient checkpointing.
 *
 * Trades compute for memory: instead of keeping every intermediate
 * activation alive for the backward pass, it stores a bounded number of
 * input/output pairs plus the forward closures needed to recompute the rest.
 *
 * @example
 * ```javascript
 * const checkpointer = new GradientCheckpointer({
 *   checkpointSegments: 4,
 *   enabled: true
 * });
 *
 * // Wrap forward pass
 * const output = checkpointer.checkpoint(forwardFn, inputs);
 *
 * // Later, materialize a segment's activations for backward
 * const { input, output: out } = checkpointer.recompute(0);
 * ```
 */
export class GradientCheckpointer extends EventEmitter {
    /**
     * Create a GradientCheckpointer.
     *
     * @param {Object} [config={}] - Checkpointer configuration
     * @param {boolean} [config.enabled=true] - Disable to run forwards untouched
     * @param {number} [config.checkpointSegments=4] - Segments the computation is divided into
     * @param {number} [config.maxStoredActivations=10] - Cap on stored input/output pairs
     */
    constructor(config = {}) {
        super();

        const defaults = {
            enabled: true,
            checkpointSegments: 4,
            maxStoredActivations: 10,
        };
        this.config = { ...defaults, ...config };

        /** @type {Array<{input: any, output: any}>} cached activation pairs */
        this.checkpoints = [];

        /** @type {Array<{fn: Function, input: any}>} closures kept for replay */
        this.forwardFns = [];

        this.stats = {
            checkpoints: 0,
            recomputations: 0,
            memoryReduction: 0,
        };
    }

    /**
     * Run a forward pass while remembering enough to rebuild it later.
     *
     * @param {Function} forwardFn - Forward function to checkpoint
     * @param {any} input - Input passed to the forward function
     * @returns {any} The forward function's output
     */
    checkpoint(forwardFn, input) {
        if (!this.config.enabled) {
            return forwardFn(input);
        }

        // Keep the closure plus a copy of its input so the segment can be replayed.
        this.forwardFns.push({ fn: forwardFn, input: this._shallowCopy(input) });
        this.stats.checkpoints += 1;

        const output = forwardFn(input);

        // Cache the activation pair while under the storage budget.
        if (this.checkpoints.length < this.config.maxStoredActivations) {
            this.checkpoints.push({
                input: this._shallowCopy(input),
                output: this._shallowCopy(output),
            });
        }

        return output;
    }

    /**
     * Fetch (or replay) the activations for one segment.
     *
     * @param {number} segmentIdx - Segment index to materialize
     * @returns {{input: any, output: any}} Activations for the segment
     * @throws {Error} When the segment was never checkpointed
     */
    recompute(segmentIdx) {
        // Fast path: the activation pair is still cached.
        if (segmentIdx < this.checkpoints.length) {
            return this.checkpoints[segmentIdx];
        }

        // Slow path: replay the stored forward closure.
        if (segmentIdx < this.forwardFns.length) {
            const segment = this.forwardFns[segmentIdx];
            const replayed = segment.fn(segment.input);
            this.stats.recomputations += 1;
            return { input: segment.input, output: replayed };
        }

        throw new Error(`Cannot recompute segment ${segmentIdx}`);
    }

    /**
     * Drop all cached checkpoints and stored forward closures.
     */
    clear() {
        this.checkpoints = [];
        this.forwardFns = [];
    }

    /**
     * Shallow copy of arrays, Float32/Float64 typed arrays, and plain
     * objects; primitives (and null) pass through unchanged.
     * @private
     */
    _shallowCopy(value) {
        if (Array.isArray(value)) {
            return Array.from(value);
        }
        if (value instanceof Float32Array || value instanceof Float64Array) {
            return new value.constructor(value);
        }
        if (typeof value === 'object' && value !== null) {
            return { ...value };
        }
        return value;
    }

    /**
     * Estimate the activation-memory footprint with vs without checkpointing.
     *
     * @param {number} totalLayers - Number of layers in the model
     * @param {number} activationSize - Memory units per stored activation
     * @returns {Object} Footprints plus relative savings
     */
    estimateMemorySavings(totalLayers, activationSize) {
        const full = totalLayers * activationSize;
        // Checkpointed cost: one stored activation per segment, plus one
        // segment's worth of activations recomputed during backward.
        const checkpointed = this.config.checkpointSegments * activationSize +
            (totalLayers / this.config.checkpointSegments) * activationSize;

        const savings = 1 - (checkpointed / full);
        this.stats.memoryReduction = savings;

        return {
            without: full,
            with: checkpointed,
            savings: savings * 100,
            savingsPercentage: `${(savings * 100).toFixed(1)}%`,
        };
    }

    /**
     * Snapshot of the counters plus current storage occupancy.
     */
    getStats() {
        return {
            ...this.stats,
            storedCheckpoints: this.checkpoints.length,
            storedFunctions: this.forwardFns.length,
        };
    }
}
1331
+
1332
+ // ============================================
1333
+ // LEARNING RATE SCHEDULERS
1334
+ // ============================================
1335
+
1336
/**
 * Learning rate scheduler implementations.
 *
 * Each scheduler is a pure function `(baseLR, step, totalSteps, opts?) => lr`
 * mapping the current optimizer step to a learning rate.
 */
export const LRSchedulers = {
    /**
     * Constant learning rate: always returns baseLR.
     */
    constant: (baseLR, step, totalSteps) => baseLR,

    /**
     * Linear decay from baseLR to 0 over totalSteps.
     * Clamped at 0 so steps beyond totalSteps never produce a negative LR.
     */
    linear: (baseLR, step, totalSteps) =>
        baseLR * Math.max(0, 1 - step / totalSteps),

    /**
     * Cosine annealing from baseLR toward 0 over totalSteps.
     */
    cosine: (baseLR, step, totalSteps) =>
        baseLR * 0.5 * (1 + Math.cos(Math.PI * step / totalSteps)),

    /**
     * Cosine annealing restarted every restartPeriod steps.
     * NOTE(review): opts.restartMultiplier is accepted for API compatibility
     * but not yet applied — the restart period stays fixed across cycles.
     */
    cosineWarmRestarts: (baseLR, step, totalSteps, opts = {}) => {
        const { restartPeriod = 100 } = opts;
        const cycleStep = step % restartPeriod;
        return baseLR * 0.5 * (1 + Math.cos(Math.PI * cycleStep / restartPeriod));
    },

    /**
     * Step-wise exponential decay: lr = baseLR * gamma^floor(step / stepSize).
     */
    exponential: (baseLR, step, totalSteps, opts = {}) => {
        const { gamma = 0.95, stepSize = 100 } = opts;
        return baseLR * Math.pow(gamma, Math.floor(step / stepSize));
    },

    /**
     * Linear warmup for warmupSteps, then cosine decay toward 0.
     * Guards the degenerate case totalSteps <= warmupSteps, which previously
     * divided by zero and returned NaN.
     */
    warmupCosine: (baseLR, step, totalSteps, opts = {}) => {
        const { warmupSteps = 100 } = opts;
        if (step < warmupSteps) {
            return baseLR * (step / warmupSteps);
        }
        const decayTotal = totalSteps - warmupSteps;
        if (decayTotal <= 0) {
            // No decay phase exists; hold at baseLR instead of producing NaN.
            return baseLR;
        }
        const decayStep = step - warmupSteps;
        return baseLR * 0.5 * (1 + Math.cos(Math.PI * decayStep / decayTotal));
    },

    /**
     * One-cycle policy: ramp linearly from baseLR up to maxLR over the first
     * half of training, then anneal linearly down to baseLR / finalDiv.
     * NOTE(review): opts.divFactor is accepted for API compatibility but
     * unused — the warmup floor is baseLR itself.
     */
    oneCycle: (baseLR, step, totalSteps, opts = {}) => {
        const { maxLR = baseLR * 10, finalDiv = 10000 } = opts;
        const midPoint = totalSteps / 2;

        if (step < midPoint) {
            // Warmup phase
            const progress = step / midPoint;
            return baseLR + progress * (maxLR - baseLR);
        }
        // Annealing phase
        const progress = (step - midPoint) / midPoint;
        const finalLR = baseLR / finalDiv;
        return maxLR - progress * (maxLR - finalLR);
    },
};
1406
+
1407
+ // ============================================
1408
+ // EXPORTS
1409
+ // ============================================
1410
+
1411
/**
 * Aggregated default export; every member is also available as a named
 * export from this module.
 */
const trainingUtils = {
    DataPreprocessor,
    BatchGenerator,
    LossComputer,
    EWCManager,
    GradientCheckpointer,
    LRSchedulers,
};

export default trainingUtils;