mini-jstorch 1.8.2 → 2.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/Docs/About.md +84 -83
- package/Docs/Structure.md +115 -128
- package/README.md +70 -31
- package/demo/fu_fun.js +71 -71
- package/demo/linear_regression.js +42 -0
- package/demo/xor_classification.js +47 -0
- package/package.json +23 -23
- package/src/jstorch.js +838 -166
package/src/jstorch.js
CHANGED
@@ -29,7 +29,7 @@
 // See the Documentation for more details.
 // --------------------------------------------------------------
 
-// ----------------------
+// ---------------------- Engine-only Utils ----------------------
 export function zeros(rows, cols) {
   return Array.from({length:rows},()=>Array(cols).fill(0));
 }
@@ -235,17 +235,111 @@ export function fu_stack(tensors) {
 
 // ---------------------- Tensor ----------------------
 export class Tensor {
-  constructor(data
-
-
-
-
-
-
-
-
-
-
+  constructor(data, requiresGrad = false){
+    this.data = data;
+    this.rows = data.length;
+    this.cols = data[0].length;
+    this.grad = zeros(this.rows, this.cols);
+    this.requiresGrad = requiresGrad;
+
+    this._dataFlat = null;
+    this._gradFlat = null;
+  }
+
+  shape(){
+    return [this.rows, this.cols];
+  }
+
+  _getDataFlat() {
+    if (!this._dataFlat) {
+      this._dataFlat = new Float32Array(this.rows * this.cols);
+      for (let i = 0; i < this.rows; i++) {
+        const offset = i * this.cols;
+        const row = this.data[i];
+        for (let j = 0; j < this.cols; j++) {
+          this._dataFlat[offset + j] = row[j];
+        }
+      }
+    }
+    return this._dataFlat;
+  }
+
+  _getGradFlat() {
+    if (!this._gradFlat) {
+      this._gradFlat = new Float32Array(this.rows * this.cols);
+      for (let i = 0; i < this.rows; i++) {
+        const offset = i * this.cols;
+        const row = this.grad[i];
+        for (let j = 0; j < this.cols; j++) {
+          this._gradFlat[offset + j] = row[j];
+        }
+      }
+    }
+    return this._gradFlat;
+  }
+
+  add(t){
+    if (t instanceof Tensor) {
+      const result = Array(this.rows);
+      const aFlat = this._getDataFlat();
+      const bFlat = t._getDataFlat();
+      for (let i = 0; i < this.rows; i++) {
+        result[i] = Array(this.cols);
+        const offset = i * this.cols;
+        for (let j = 0; j < this.cols; j++) {
+          result[i][j] = aFlat[offset + j] + bFlat[offset + j];
+        }
+      }
+      return new Tensor(result);
+    } else {
+      const res = this.data.map(r => r.map(v => v + t));
+      return new Tensor(res);
+    }
+  }
+
+  mul(t){
+    if (t instanceof Tensor) {
+      const result = Array(this.rows);
+      const aFlat = this._getDataFlat();
+      const bFlat = t._getDataFlat();
+      for (let i = 0; i < this.rows; i++) {
+        result[i] = Array(this.cols);
+        const offset = i * this.cols;
+        for (let j = 0; j < this.cols; j++) {
+          result[i][j] = aFlat[offset + j] * bFlat[offset + j];
+        }
+      }
+      return new Tensor(result);
+    } else {
+      const res = this.data.map(r => r.map(v => v * t));
+      return new Tensor(res);
+    }
+  }
+
+  matmul(t){
+    if (!(t instanceof Tensor)) throw new Error("matmul requires Tensor");
+    return new Tensor(dot(this.data, t.data));
+  }
+
+  transpose(){
+    return new Tensor(transpose(this.data));
+  }
+
+  flatten(){
+    return this.data.flat();
+  }
+
+  static zeros(r,c){
+    return new Tensor(zeros(r,c));
+  }
+
+  static ones(r,c){
+    return new Tensor(ones(r,c));
+  }
+
+  static random(r,c,scale=0.1){
+    return new Tensor(randomMatrix(r,c,scale));
+  }
 }
 
 // ---------------------- Layers ----------------------
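The reworked `Tensor` class caches flat `Float32Array` views of its data and gradient and exposes elementwise, matrix, and factory operations. A minimal usage sketch based only on the API visible in this hunk; the `mini-jstorch` import path assumes the package root re-exports the symbols from `src/jstorch.js`:

```js
import { Tensor } from 'mini-jstorch';

// Wrap plain 2-D arrays; the second argument opts into gradient tracking.
const a = new Tensor([[1, 2], [3, 4]], true);
const b = Tensor.ones(2, 2);

const sum  = a.add(b);      // elementwise add with another Tensor
const bump = a.add(0.5);    // or broadcast a scalar over every entry
const had  = a.mul(b);      // elementwise (Hadamard) product
const mm   = a.matmul(b);   // matrix product via the engine's dot()

console.log(sum.shape());              // [2, 2]
console.log(mm.transpose().flatten()); // plain 1-D array of values
```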
@@ -256,58 +350,175 @@ export class Linear {
     this.gradW = zeros(inputDim, outputDim);
     this.gradb = Array(outputDim).fill(0);
     this.x = null;
-    this.originalShape = null;
-
+    this.originalShape = null;
+
+    this._WFlat = null;
+    this._bFlat = null;
+  }
+
+  _updateCache() {
+    const rows = this.W.length;
+    const cols = this.W[0].length;
+    this._WFlat = new Float32Array(rows * cols);
+    for (let i = 0; i < rows; i++){
+      const offset = i * cols;
+      const row = this.W[i];
+      for (let j = 0; j < cols; j++){
+        this._WFlat[offset + j] = row[j];
+      }
+    }
+    this._bFlat = new Float32Array(this.b);
+  }
 
   forward(x){
-    // Handle both [batch, features] and [batch, 1, features]
     this.originalShape = this._getShapeType(x);
 
     if (this.originalShape === '3d') {
-      // Convert from [batch, 1, features] to [batch, features]
       this.x = x.map(sample => sample[0]);
     } else {
-      // Already in [batch, features] format
       this.x = x;
     }
+
+    this._updateCache();
+
+    const m = this.x.length;
+    const k = this.x[0].length;
+    const n = this.W[0].length;
 
-
-
-
-
-
-
-
-
-
+    if (!this._WFlat) {
+      const rows = this.W.length;
+      const cols = this.W[0].length;
+      this._WFlat = new Float32Array(rows * cols);
+      for (let i = 0; i < rows; i++) {
+        const offset = i * cols;
+        const row = this.W[i];
+        for (let j = 0; j < cols; j++) {
+          this._WFlat[offset + j] = row[j];
+        }
       }
+      this._bFlat = new Float32Array(this.b);
     }
 
-
-
+    // Flatten input x to Float32Array
+    const xFlat = new Float32Array(m * k);
+    for (let i = 0; i < m; i++) {
+      const row = this.x[i];
+      const offset = i * k;
+      for (let j = 0; j < k; j++) {
+        xFlat[offset + j] = row[j];
+      }
     }
-
-    const
-    for(let i = 0; i <
-
-
-
+
+    const outFlat = new Float32Array(m * n);
+    for (let i = 0; i < m; i++) {
+      const xOffset = i * k;
+      for (let j = 0; j < n; j++) {
+        let sum = 0;
+        for (let l = 0; l < k; l++) {
+          sum += xFlat[xOffset + l] * this._WFlat[l * n + j];
         }
+        outFlat[i * n + j] = sum + this._bFlat[j];
       }
     }
 
-
-
-
+    const out = Array(m);
+    for (let i = 0; i < m; i++) {
+      const row = Array(n);
+      const offset = i * n;
+      for (let j = 0; j < n; j++) {
+        row[j] = outFlat[offset + j];
+      }
+      out[i] = row;
     }
-
+
+    return out;
   }
 
+  backward(grad){
+    const m = this.x.length;
+    const k = this.W.length;    // input dim
+    const n = this.W[0].length; // output dim
+
+    // Convert grad to Float32Array
+    const gradFlat = new Float32Array(m * n);
+    for (let i = 0; i < m; i++) {
+      const row = grad[i];
+      const offset = i * n;
+      for (let j = 0; j < n; j++) {
+        gradFlat[offset + j] = row[j];
+      }
+    }
+
+    // Convert x to Float32Array
+    const xFlat = new Float32Array(m * k);
+    for (let i = 0; i < m; i++) {
+      const row = this.x[i];
+      const offset = i * k;
+      for (let j = 0; j < k; j++) {
+        xFlat[offset + j] = row[j];
+      }
+    }
+
+    // Reset gradW
+    for (let i = 0; i < this.gradW.length; i++) {
+      for (let j = 0; j < this.gradW[0].length; j++) {
+        this.gradW[i][j] = 0;
+      }
+    }
+
+    // Compute gradW = x^T * grad
+    for (let i = 0; i < k; i++) {
+      for (let j = 0; j < n; j++) {
+        let sum = 0;
+        for (let batch = 0; batch < m; batch++) {
+          sum += xFlat[batch * k + i] * gradFlat[batch * n + j];
+        }
+        this.gradW[i][j] = sum;
+      }
+    }
+
+    // Compute gradb
+    for (let j = 0; j < n; j++) {
+      let sum = 0;
+      for (let batch = 0; batch < m; batch++) {
+        sum += gradFlat[batch * n + j];
+      }
+      this.gradb[j] = sum;
+    }
+
+    const gradInputFlat = new Float32Array(m * k);
+    for (let i = 0; i < m; i++) {
+      for (let j = 0; j < k; j++) {
+        let sum = 0;
+        for (let l = 0; l < n; l++) {
+          sum += gradFlat[i * n + l] * this.W[j][l];
+        }
+        gradInputFlat[i * k + j] = sum;
+      }
+    }
+
+    // Convert back to 2D array
+    const gradInput = Array(m);
+    for (let i = 0; i < m; i++) {
+      const row = Array(k);
+      const offset = i * k;
+      for (let j = 0; j < k; j++) {
+        row[j] = gradInputFlat[offset + j];
+      }
+      gradInput[i] = row;
+    }
+
+    if (this.originalShape === '3d') {
+      return gradInput.map(row => [row]);
+    }
+    return gradInput;
+  }
+
   _getShapeType(x) {
     if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
-      return '3d';
+      return '3d';
     } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
-      return '2d';
+      return '2d';
     } else {
       throw new Error(`Unsupported input shape for Linear layer`);
     }
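In plain terms, the new `Linear.forward` computes `y = xW + b` through flat `Float32Array` buffers, and `backward` fills `gradW = xᵀ·grad` and `gradb` (per-column sums of `grad`) and returns `grad·Wᵀ`. A small sketch; the `new Linear(inputDim, outputDim)` constructor signature is an assumption, since only the tail of the constructor appears in this hunk:

```js
import { Linear } from 'mini-jstorch';

// Assumed signature: (inputDim, outputDim) – only the end of the constructor
// is visible in the diff above.
const fc = new Linear(3, 2);

// Forward: [batch, features] -> [batch, outputDim], i.e. y = xW + b.
// A [batch, 1, features] input is also accepted and squeezed internally.
const y = fc.forward([[1, 2, 3], [4, 5, 6]]);

// Backward takes dL/dy (same shape as y) and returns dL/dx, while filling
// fc.gradW (x^T · grad) and fc.gradb (per-column sums of grad).
const dx = fc.backward([[1, 0], [0, 1]]);
console.log(y, dx, fc.gradW, fc.gradb);
```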
@@ -405,7 +616,7 @@ export class Flatten {
   parameters() { return []; }
 }
 
-// ---------------------- Conv2D ----------------------
+// ---------------------- Conv2D (BETA) ----------------------
 export class Conv2D {
   constructor(inC, outC, kernel, stride=1, padding=0){
     this.inC = inC;
@@ -420,10 +631,31 @@ export class Conv2D {
       Array(inC).fill().map(() => zeros(kernel, kernel))
     );
     this.x = null;
+
+    // Cache Float32Array for the kernels
+    this._WFlat = null;
+    this._cacheKernels();
+  }
+
+  _cacheKernels() {
+    this._WFlat = this.W.map(oc =>
+      oc.map(ic => {
+        const rows = ic.length;
+        const cols = ic[0].length;
+        const flat = new Float32Array(rows * cols);
+        for (let i = 0; i < rows; i++) {
+          const offset = i * cols;
+          const row = ic[i];
+          for (let j = 0; j < cols; j++) {
+            flat[offset + j] = row[j];
+          }
+        }
+        return flat;
+      })
+    );
   }
 
   pad2D(input, pad){
-    // Input is single channel [height, width]
     if (!input || !input.length) return input;
 
     const rows = input.length + 2 * pad;
@@ -431,26 +663,31 @@ export class Conv2D {
     const out = Array.from({length: rows}, () => Array(cols).fill(0));
 
     for(let i = 0; i < input.length; i++) {
+      const row = input[i];
+      const outRow = out[i + pad];
       for(let j = 0; j < input[0].length; j++) {
-
+        outRow[j + pad] = row[j];
       }
     }
     return out;
   }
 
-  conv2DSingle(input,
-    const rows = Math.floor((input.length -
-    const cols = Math.floor((input[0].length -
-    const out =
+  conv2DSingle(input, kernelFlat, kH, kW) {
+    const rows = Math.floor((input.length - kH) / this.stride) + 1;
+    const cols = Math.floor((input[0].length - kW) / this.stride) + 1;
+    const out = Array(rows);
 
     for(let i = 0; i < rows; i++) {
+      out[i] = Array(cols);
       for(let j = 0; j < cols; j++) {
         let sum = 0;
-        for(let ki = 0; ki <
-
-
+        for(let ki = 0; ki < kH; ki++) {
+          const inputRow = i * this.stride + ki;
+          const rowOffset = ki * kW;
+          const inputRowData = input[inputRow];
+          for(let kj = 0; kj < kW; kj++) {
             const inputCol = j * this.stride + kj;
-            sum +=
+            sum += inputRowData[inputCol] * kernelFlat[rowOffset + kj];
           }
         }
         out[i][j] = sum;
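`conv2DSingle` above uses the standard valid-convolution size formula, `outRows = floor((inRows − kH) / stride) + 1` (likewise for columns); padding is applied beforehand by `pad2D`. A standalone arithmetic check of that formula:

```js
// Output size of a 2-D convolution: floor((inSize + 2*padding - kernel) / stride) + 1.
const convOutSize = (inSize, kernel, stride = 1, padding = 0) =>
  Math.floor((inSize + 2 * padding - kernel) / stride) + 1;

console.log(convOutSize(5, 3));       // 3 – 5x5 input, 3x3 kernel, stride 1
console.log(convOutSize(5, 3, 1, 1)); // 5 – padding 1 keeps the spatial size
console.log(convOutSize(7, 3, 2));    // 3 – stride 2 roughly halves it
```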
@@ -461,6 +698,9 @@ export class Conv2D {
 
   forward(batch) {
     this.x = batch;
+    const kH = this.kernel;
+    const kW = this.kernel;
+
     return batch.map(sample => {
       const channelsOut = [];
       for(let oc = 0; oc < this.outC; oc++) {
@@ -471,7 +711,7 @@ export class Conv2D {
           inputChan = this.pad2D(inputChan, this.padding);
         }
 
-        const conv = this.conv2DSingle(inputChan, this.
+        const conv = this.conv2DSingle(inputChan, this._WFlat[oc][ic], kH, kW);
 
         if(outChan === null) {
           outChan = conv;
@@ -487,7 +727,7 @@ export class Conv2D {
 
   backward(grad) {
     const batchSize = this.x.length;
-    const gradW = this.
+    const gradW = this.gradW.map(oc => oc.map(ic => zeros(this.kernel, this.kernel)));
     const gradInput = this.x.map(sample =>
       sample.map(chan => zeros(chan.length, chan[0].length))
     );
@@ -549,10 +789,109 @@ export class Conv2D {
 
 // ---------------------- Sequential ----------------------
 export class Sequential {
-
-
-
-
+  constructor(layers=[]) {
+    this.layers = layers;
+  }
+
+  forward(x){
+    return this.layers.reduce((acc, l) => l.forward(acc), x);
+  }
+
+  backward(grad){
+    return this.layers.reduceRight((g, l) => l.backward(g), grad);
+  }
+
+  parameters(){
+    return this.layers.flatMap(l => l.parameters ? l.parameters() : []);
+  }
+
+  /**
+   * Zero out all gradients of all parameters
+   */
+  zeroGrad(){
+    const params = this.parameters();
+
+    for (const p of params){
+      if (!p.grad) continue;
+
+      // Handle different gradient shapes
+      if (Array.isArray(p.grad)){
+        // Check if it's 1D or 2D array
+        if (p.grad.length > 0 && Array.isArray(p.grad[0])){
+          // 2D Gradient (weights)
+          for (let i = 0; i < p.grad.length; i++){
+            const row = p.grad[i];
+            for (let j = 0; j < row.length; j++){
+              row[j] = 0;
+            }
+          }
+        } else {
+          // 1D Gradient (bias)
+          for (let i = 0; i < p.grad.length; i++){
+            p.grad[i] = 0;
+          }
+        }
+      } else if (typeof p.grad === 'number'){
+        // Scalar gradient (rare case)
+        p.grad = 0;
+      }
+    }
+
+    return this; // Allow chaining
+  }
+
+  /**
+   * Train mode (enable dropout, batch norm, etc.)
+   */
+  train(){
+    this.layers.forEach(layer => {
+      if (layer.train) layer.train();
+    });
+    return this;
+  }
+
+  /**
+   * Eval mode (disable dropout, batch norm, etc.)
+   */
+  eval(){
+    this.layers.forEach(layer => {
+      if (layer.eval) layer.eval();
+    });
+    return this;
+  }
+
+  /**
+   * Get model state dict (weights and biases)
+   */
+  stateDict(){
+    const state = {};
+    this.layers.forEach((layer, idx) => {
+      if (layer.W) state[`layer_${idx}.weight`] = layer.W;
+      if (layer.b) state[`layer_${idx}.bias`] = layer.b;
+    });
+    return state;
+  }
+
+  /**
+   * Load state dict
+   */
+  loadStateDict(stateDict){
+    this.layers.forEach((layer, idx) => {
+      const weightKey = `layer_${idx}.weight`;
+      const biasKey = `layer_${idx}.bias`;
+
+      if (layer.W && stateDict[weightKey]){
+        layer.W = stateDict[weightKey];
+        // Invalidate cache
+        if (layer._InvalidateCache) layer._InvalidateCache();
+      }
+      if (layer.b && stateDict[biasKey]){
+        layer.b = stateDict[biasKey];
+        if (layer._InvalidateCache) layer._InvalidateCache();
+      }
+    });
+    return this;
+  }
 }
 
 // ---------------------- Activations ----------------------
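`Sequential` gains a constructor, chained forward/backward, and the PyTorch-style helpers `zeroGrad`, `train`, `eval`, `stateDict`, and `loadStateDict`. A usage sketch built from classes that appear elsewhere in this diff (the `Linear` constructor arguments are assumed, as noted above):

```js
import { Sequential, Linear, Tanh, Dropout } from 'mini-jstorch';

const model = new Sequential([
  new Linear(4, 8),
  new Tanh(),
  new Dropout(0.2),
  new Linear(8, 1),
]);

model.train();                    // propagates train mode to Dropout, BatchNorm2d, ...
const out = model.forward([[1, 2, 3, 4]]);
model.zeroGrad();                 // clears 1-D (bias) and 2-D (weight) grad buffers

model.eval();                     // disables dropout for inference
const state = model.stateDict();  // { 'layer_0.weight': ..., 'layer_0.bias': ..., ... }
model.loadStateDict(state);       // restores weights and invalidates flat caches
```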
@@ -606,42 +945,35 @@ export class Softmax {
 
   forward(x) {
     this.input = x;
-
-
-
-
-
-
-
-
-
+
+    // x: [batch_size, num_Classes]
+    this.output = x.map(row => {
+      const maxVal = Math.max(...row);
+      const exps = row.map(v => Math.exp(v - maxVal));
+      const sumExps = exps.reduce((a, b) => a + b, 0);
+      return exps.map(v => v / sumExps);
+    });
+    return this.output;
   }
 
   backward(grad) {
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        }
-        gradInput[i][j] = sum;
-      }
-    }
-
-    return gradInput;
-  }
+    const batchSize = grad.length;
+    const numClasses = grad[0].length;
+    const gradInput = zeros(batchSize, numClasses);
+
+    for (let i = 0; i < batchSize; i++){
+      const s = this.output[i];   // Softmax output
+      const gradOut = grad[i];    // Gradient from next layer
+
+      const dot = s.reduce((sum, val, k) => sum + val * gradOut[k], 0);
+
+      for (let j = 0; j < numClasses; j++){
+        gradInput[i][j] = s[j] * (gradOut[j] - dot);
+      }
+    }
+
+    return gradInput;
+  }
 
   parameters() {
     return []; // Softmax has no trainable parameters
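For reference, the backward pass above is the softmax Jacobian–vector product: per row, `gradIn[j] = s[j] * (gradOut[j] − Σ_k s[k]·gradOut[k])`. A standalone sanity check of that identity:

```js
// One row of softmax output and one row of upstream gradient.
const s = [0.7, 0.2, 0.1];
const g = [1.0, 0.0, 0.0];

const dot = s.reduce((acc, v, k) => acc + v * g[k], 0);  // Σ_k s[k]·g[k] = 0.7
const gradIn = s.map((v, j) => v * (g[j] - dot));

console.log(gradIn); // ≈ [0.21, -0.14, -0.07], i.e. s[j] * (g[j] - 0.7)
```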
@@ -650,90 +982,209 @@ export class Softmax {
 
 // ---------------------- Tokenizer ----------------------
 export class Tokenizer {
-  constructor(vocabSize = 2000){
+  constructor(vocabSize = 2000) {
     this.vocabSize = vocabSize;
     this.wordToIndex = new Map();
     this.indexToWord = new Map();
     this.fitted = false;
-
-
-
-
-
-
+    this.wordCounts = null;
+
+    // Special tokens
+    this.PAD_TOKEN = '<PAD>';
+    this.UNK_TOKEN = '<UNK>';
+    this.PAD_INDEX = 0;
+    this.UNK_INDEX = 1;
+  }
+
+  /**
+   * Fit tokenizer on a list of texts
+   * @param {string[]} texts - Array of text strings
+   * @returns {Tokenizer} this
+   */
+  fit(texts) {
+    this.wordCounts = new Map();
+    this.wordToIndex.clear();
+    this.indexToWord.clear();
+
+    // Count word frequencies
     texts.forEach(text => {
-      const words = this.
+      const words = this._tokenize(text);
       words.forEach(word => {
-        wordCounts.set(word, (wordCounts.get(word) || 0) + 1);
+        this.wordCounts.set(word, (this.wordCounts.get(word) || 0) + 1);
       });
     });
-
-
-
-    .
-    .
-
-    //
-    this.wordToIndex.
-    this.indexToWord.
-
-    //
-    this.wordToIndex.set(
-    this.indexToWord.set(
-
-
-
-    this.wordToIndex.set(word, index
-    this.indexToWord.set(index
-    })
-
+
+    const sortedWords = [...this.wordCounts.entries()]
+      .sort((a, b) => b[1] - a[1])    // Descending: most frequent first
+      .slice(0, this.vocabSize - 2)   // -2 for PAD and UNK
+      .map(([word]) => word);
+
+    // Index 0 = <PAD> (padding)
+    this.wordToIndex.set(this.PAD_TOKEN, this.PAD_INDEX);
+    this.indexToWord.set(this.PAD_INDEX, this.PAD_TOKEN);
+
+    // Index 1 = <UNK> (unknown)
+    this.wordToIndex.set(this.UNK_TOKEN, this.UNK_INDEX);
+    this.indexToWord.set(this.UNK_INDEX, this.UNK_TOKEN);
+
+    sortedWords.forEach((word, idx) => {
+      const index = idx + 2;          // Start from index 2
+      this.wordToIndex.set(word, index);
+      this.indexToWord.set(index, word);
+    });
+
     this.fitted = true;
     return this;
   }
-
-
-
-
-
-
+
+  /**
+   * Fit and transform in one step
+   * @param {string[]} texts - Array of text strings
+   * @param {number|null} maxLength - Pad/truncate to this length
+   * @param {boolean} padToMax - Whether to pad to maxLength (default: true)
+   * @returns {number[][]} Tokenized sequences
+   */
+  fitTransform(texts, maxLength = null, padToMax = true) {
+    this.fit(texts);
+    return this.transform(texts, maxLength, padToMax);
   }
-
-
-
-
+
+  /**
+   * Transform texts to token indices
+   * @param {string[]} texts - Array of text strings
+   * @param {number|null} maxLength - Pad/truncate to this length
+   * @param {boolean} padToMax - Whether to pad to maxLength (default: true)
+   * @returns {number[][]} Tokenized sequences
+   */
+  transform(texts, maxLength = null, padToMax = true) {
+    if (!this.fitted) {
+      throw new Error("Tokenizer not fitted. Call fit() or fitTransform() first.");
+    }
+
     return texts.map(text => {
-      const
-
-
-
-
-
-      } else {
-        return [...tokens, ...Array(maxLength - tokens.length).fill(0)];
-      }
+      const words = this._tokenize(text);
+      let tokens = words.map(word => this.wordToIndex.get(word) || this.UNK_INDEX);
+
+      // Truncate if maxLength specified
+      if (maxLength !== null && tokens.length > maxLength) {
+        tokens = tokens.slice(0, maxLength);
       }
-
+
+      // Pad if maxLength specified and padToMax is true
+      if (maxLength !== null && padToMax && tokens.length < maxLength) {
+        tokens = [...tokens, ...Array(maxLength - tokens.length).fill(this.PAD_INDEX)];
+      }
+
       return tokens;
-    })
+    });
   }
-
-
-
+
+  /**
+   * Convert token indices back to text
+   * @param {number[]|number[][]} tokens - Single sequence or batch of sequences
+   * @param {boolean} skipPad - Whether to skip PAD tokens (default: true)
+   * @returns {string|string[]} Detokenized text(s)
+   */
+  inverseTransform(tokens, skipPad = true) {
+    if (!this.fitted) {
+      throw new Error("Tokenizer not fitted. Call fit() first.");
+    }
+
+    // Check if batch or single sequence
+    const isBatch = Array.isArray(tokens[0]);
+
+    if (isBatch) {
+      return tokens.map(seq => this._detokenize(seq, skipPad));
+    } else {
+      return this._detokenize(tokens, skipPad);
+    }
   }
-
-
-
+
+  /**
+   * Get vocabulary as array
+   * @returns {string[]} Vocabulary list
+   */
+  getVocabulary() {
+    if (!this.fitted) return [];
+
+    const vocab = new Array(this.wordToIndex.size);
+    for (let i = 0; i < this.wordToIndex.size; i++) {
+      vocab[i] = this.indexToWord.get(i);
+    }
+    return vocab;
   }
-
-
+
+  /**
+   * Get vocabulary size
+   * @returns {number} Number of words in vocabulary
+   */
+  getVocabSize() {
     return this.wordToIndex.size;
   }
-
-
+
+  /**
+   * Get word frequency counts
+   * @returns {Map<string, number>} Word frequency map
+   */
+  getWordCounts() {
+    return this.wordCounts ? new Map(this.wordCounts) : null;
+  }
+
+  /**
+   * Get most common words
+   * @param {number} n - Number of words to return
+   * @returns {Array<{word: string, count: number}>} Most common words
+   */
+  getMostCommon(n = 10) {
+    if (!this.wordCounts) return [];
+
+    return [...this.wordCounts.entries()]
+      .sort((a, b) => b[1] - a[1])
+      .slice(0, n)
+      .map(([word, count]) => ({ word, count }));
+  }
+
+  /**
+   * Internal: tokenize text into words
+   * @param {string} text
+   * @returns {string[]}
+   */
+  _tokenize(text) {
+    // handle contractions and preserve word boundaries
    return text.toLowerCase()
-
-
-
+      .replace(/([.!?;:,])/g, ' $1 ')
+      .replace(/\s+/g, ' ')
+      .trim()
+      .split(' ')
+      .filter(word => word.length > 0 && !/^[.!?;:,]+$/.test(word) || word.length > 1);
+  }
+
+  /**
+   * Internal: detokenize sequence to text
+   * @param {number[]} tokens
+   * @param {boolean} skipPad
+   * @returns {string}
+   */
+  _detokenize(tokens, skipPad = true) {
+    const words = [];
+
+    for (const token of tokens) {
+      if (skipPad && token === this.PAD_INDEX) {
+        continue; // Skip padding tokens
+      }
+
+      const word = this.indexToWord.get(token);
+      if (word && word !== this.PAD_TOKEN) {
+        words.push(word === this.UNK_TOKEN ? '?' : word);
+      } else if (word === undefined) {
+        words.push('?');
+      }
+    }
+
+    let text = words.join(' ');
+    text = text.replace(/\s+([.!?;:,])/g, '$1');
+    text = text.replace(/([.!?;:,])\s+/g, '$1 ');
+    return text.trim();
   }
 }
 
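A usage sketch for the expanded `Tokenizer`, again assuming the symbols are importable from the package root; the exact indices in the comments depend on word frequencies and are illustrative only:

```js
import { Tokenizer } from 'mini-jstorch';

const tok = new Tokenizer(100);              // vocabulary capped at 100 entries
const texts = ['Hello world!', 'hello there'];

// fit + transform in one call; pads/truncates every sequence to length 4
const seqs = tok.fitTransform(texts, 4);     // e.g. [[2, 3, 0, 0], [2, 4, 0, 0]]

tok.getVocabSize();                          // counts <PAD> and <UNK> as well
tok.getMostCommon(2);                        // e.g. [{ word: 'hello', count: 2 }, ...]

// Round trip: PAD tokens are skipped, unknown ids come back as '?'
tok.inverseTransform(seqs);                  // e.g. ['hello world', 'hello there']
```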
@@ -741,14 +1192,200 @@ export class Tokenizer {
 export class Sigmoid{ constructor(){ this.out=null; } forward(x){ const fn=v=>1/(1+Math.exp(-v)); this.out=x.map(r=>r.map(fn)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*this.out[i][j]*(1-this.out[i][j]))); } }
 export class Tanh{ constructor(){ this.out=null; } forward(x){ this.out=x.map(r=>r.map(v=>Math.tanh(v))); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*(1-this.out[i][j]**2))); } }
 export class LeakyReLU{ constructor(alpha=0.01){ this.alpha=alpha; this.out=null; } forward(x){ this.out=x.map(r=>r.map(v=>v>0?v:v*this.alpha)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*(this.out[i][j]>0?1:this.alpha))); } }
-
+
+// ---------------------- GELU ----------------------
+export class GELU {
+  constructor() {
+    this.x = null;
+    this.output = null;
+  }
+
+  forward(x) {
+    this.x = x;
+    const sqrt2pi = Math.sqrt(2 / Math.PI);
+    const k = 0.044715;
+
+    this.output = x.map(row =>
+      row.map(v => {
+        const cube = v * v * v;
+        const inner = sqrt2pi * (v + k * cube);
+        const tanhVal = Math.tanh(inner);
+        return 0.5 * v * (1 + tanhVal);
+      })
+    );
+    return this.output;
+  }
+
+  backward(grad) {
+    const sqrt2pi = Math.sqrt(2 / Math.PI);
+    const k = 0.044715;
+
+    return grad.map((row, i) =>
+      row.map((gradVal, j) => {
+        const x = this.x[i][j];
+
+        // Calculate the derivative in a numerically stable way
+        const x2 = x * x;
+        const x3 = x2 * x;
+
+        const inner = sqrt2pi * (x + k * x3);
+        const tanhInner = Math.tanh(inner);
+        const sech2 = 1 - tanhInner * tanhInner;
+
+        // d(inner)/dx
+        const d_inner = sqrt2pi * (1 + 3 * k * x2);
+
+        // d(GELU)/dx = 0.5 * (1 + tanh(inner)) + 0.5 * x * sech2(inner) * d_inner
+        const d_gelu = 0.5 * (1 + tanhInner) + 0.5 * x * sech2 * d_inner;
+
+        return gradVal * d_gelu;
+      })
+    );
+  }
+}
 
 // ---------------------- Dropout ----------------------
-export class Dropout
+export class Dropout {
+  constructor(p = 0.5) {
+    this.p = p;
+    this.mask = null;
+    this.training = true;
+    this.scale = null; // Will be computed when needed
+  }
+
+  _getScale() {
+    if (this.scale === null) {
+      this.scale = this.p === 1 ? 0 : 1.0 / (1.0 - this.p);
+    }
+    return this.scale;
+  }
+
+  forward(x) {
+    // Handle both 2D and 3D inputs
+    const is3D = x[0] && Array.isArray(x[0][0]);
+
+    if (!this.training || this.p === 0) {
+      // Deep copy based on dimension
+      if (is3D) {
+        return x.map(sample => sample.map(row => [...row]));
+      }
+      return x.map(row => [...row]);
+    }
+
+    const scale = this._getScale();
+
+    if (is3D) {
+      // Handle 3D input [batch, channels, features]
+      this.mask = x.map(sample =>
+        sample.map(row =>
+          row.map(() => Math.random() >= this.p ? 1 : 0)
+        )
+      );
+
+      return x.map((sample, i) =>
+        sample.map((row, j) =>
+          row.map((v, k) => v * this.mask[i][j][k] * scale)
+        )
+      );
+    } else {
+      // Handle 2D input [batch, features]
+      this.mask = x.map(row =>
+        row.map(() => Math.random() >= this.p ? 1 : 0)
+      );
+
+      return x.map((row, i) =>
+        row.map((v, j) => v * this.mask[i][j] * scale)
+      );
+    }
+  }
+
+  backward(grad) {
+    if (!this.training || this.p === 0 || !this.mask) {
+      // Deep copy gradient based on dimension
+      const is3D = grad[0] && Array.isArray(grad[0][0]);
+      if (is3D) {
+        return grad.map(sample => sample.map(row => [...row]));
+      }
+      return grad.map(row => [...row]);
+    }
+
+    const scale = this._getScale();
+    const is3D = grad[0] && Array.isArray(grad[0][0]);
+
+    if (is3D) {
+      return grad.map((sample, i) =>
+        sample.map((row, j) =>
+          row.map((v, k) => v * this.mask[i][j][k] * scale)
+        )
+      );
+    } else {
+      return grad.map((row, i) =>
+        row.map((v, j) => v * this.mask[i][j] * scale)
+      );
+    }
+  }
+
+  train() {
+    this.training = true;
+  }
+
+  eval() {
+    this.training = false;
+  }
+}
 
 // ---------------------- Losses ----------------------
-export class MSELoss
-
+export class MSELoss {
+  forward(pred, target) {
+    this.pred = pred;
+    this.target = target;
+    this.batchSize = pred.length;
+    this.featureSize = pred[0].length;
+
+    let totalLoss = 0;
+    for (let i = 0; i < this.batchSize; i++) {
+      let sampleLoss = 0;
+      for (let j = 0; j < this.featureSize; j++) { // ← uses this.featureSize
+        const diff = pred[i][j] - target[i][j];
+        sampleLoss += diff * diff;
+      }
+      totalLoss += sampleLoss / this.featureSize;
+    }
+
+    this.loss = totalLoss / this.batchSize;
+    return this.loss;
+  }
+
+  backward() {
+    const grad = Array(this.batchSize);
+
+    for (let i = 0; i < this.batchSize; i++) {
+      grad[i] = Array(this.featureSize);
+      for (let j = 0; j < this.featureSize; j++) {
+        grad[i][j] = 2 * (this.pred[i][j] - this.target[i][j]) / this.batchSize;
+      }
+    }
+
+    return grad;
+  }
+}
+
+export class CrossEntropyLoss{
+  constructor() {
+    console.warn(
+      '[JST WARN]: CrossEntrpyLoss is deprecated. ' +
+      'Use SoftmaxCrossEntropyLoss instead for better numerical stability.'
+    );
+    this._impl = new SoftmaxCrossEntropyLoss();
+  }
+
+  forward(logits, targets){
+    return this._impl.forward(logits, targets);
+  }
+  backward(){
+    return this._impl.backward();
+  }
+}
 
 export class SoftmaxCrossEntropyLoss {
   forward(logits, targets) {
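`MSELoss.forward` averages the squared error over features and then over the batch, and `backward()` returns `2·(pred − target) / batchSize`; `CrossEntropyLoss` is now just a deprecated wrapper around `SoftmaxCrossEntropyLoss`. A small forward/backward sketch with hand-checked numbers (same import assumption as above):

```js
import { MSELoss, GELU } from 'mini-jstorch';

const loss = new MSELoss();
const pred   = [[0.5, 1.0], [2.0, 0.0]];
const target = [[0.0, 1.0], [1.0, 0.0]];

// sample 0: (0.25 + 0) / 2 = 0.125, sample 1: (1 + 0) / 2 = 0.5 -> mean = 0.3125
console.log(loss.forward(pred, target));   // 0.3125
console.log(loss.backward());              // [[0.5, 0], [1, 0]] = 2*(pred - target)/2

// The new GELU activation works on the same [batch, features] arrays.
const act = new GELU();
const h  = act.forward([[-1, 0, 1]]);
const dh = act.backward([[1, 1, 1]]);
```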
@@ -816,7 +1453,6 @@ export class BCEWithLogitsLoss {
 // ---------------------- Optimizers ----------------------
 export class Adam{
   constructor(params, lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-8, max_grad_norm = 1.0){
-    // Handle both parameter styles: (params, lr) OR (params, {lr, ...})
     if (typeof lr === 'object') {
       // Options object provided
       const options = lr;
@@ -873,7 +1509,7 @@ export class Adam{
   }
 }
 
-// ---------------------- AdamW
+// ---------------------- AdamW ----------------------
 export class AdamW {
   constructor(params, options = {}) {
     const {
@@ -983,9 +1619,9 @@ export class LION {
     this.params = params;
 
     const {
-      lr = 0.0001,
-      beta1 = 0.9,
-      beta2 = 0.99,
+      lr = 0.0001,
+      beta1 = 0.9,
+      beta2 = 0.99,
       weight_decay = 0, // L2 regularization
       eps = 1e-8        // Numerical stability
     } = options;
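`Adam` keeps its positional signature (`params, lr, b1, b2, eps, max_grad_norm`) and, per the constructor hunk above, also accepts an options object in the `lr` position; `AdamW` and `LION` take an options object directly. A construction and training-step sketch; the `step()` call at the end is an assumption about the optimizer API, since the update method itself is outside this diff:

```js
import { Sequential, Linear, Sigmoid, MSELoss, Adam, AdamW, LION } from 'mini-jstorch';

const model = new Sequential([new Linear(2, 4), new Sigmoid(), new Linear(4, 1)]);

// Positional style: (params, lr, b1, b2, eps, max_grad_norm)
const adam  = new Adam(model.parameters(), 0.001);
// Options-object style, taken when the second argument is an object
const adam2 = new Adam(model.parameters(), { lr: 0.01 });
const adamw = new AdamW(model.parameters(), { lr: 0.001 });
const lion  = new LION(model.parameters(), { lr: 0.0001, weight_decay: 0.01 });

// One training iteration with the pieces shown in this diff:
const criterion = new MSELoss();
const out = model.forward([[0, 1], [1, 0]]);
const l = criterion.forward(out, [[1], [0]]);
model.zeroGrad();
model.backward(criterion.backward());
adam.step();  // assumed API: apply the parameter update (not shown in this diff)
```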
@@ -1264,7 +1900,7 @@ export class SiLU {
 }
 
 
-// ---------------------- BatchNorm2D ----------------------
+// ---------------------- BatchNorm2D (BETA) ----------------------
 export class BatchNorm2d {
   constructor(numFeatures, eps=1e-5, momentum=0.1, affine=true) {
     this.numFeatures = numFeatures;
@@ -1477,12 +2113,12 @@ export class BatchNorm2d {
   eval() { this.training = false; }
 }
 
-// ---------------------- Model Save/Load ----------------------
+// ---------------------- Model Save/Load (BETA) ----------------------
 export function saveModel(model){
   if(!(model instanceof Sequential)) throw new Error("saveModel supports only Sequential");
   const weights=model.layers.map(layer=>({weights:layer.W||null,biases:layer.b||null}));
   return JSON.stringify(weights);
-  /* Didn't expect this to work
+  /* Didn't expect this to work */
 }
 
 export function loadModel(model,json){
@@ -1492,7 +2128,7 @@ export function loadModel(model,json){
     if(layer.W && weights[i].weights) layer.W=weights[i].weights;
     if(layer.b && weights[i].biases) layer.b=weights[i].biases;
   });
-  /* Didn't expect this to work
+  /* Didn't expect this to work */
 }
 
 // ---------------------- Advanced Utils ----------------------
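`saveModel`/`loadModel` (now flagged BETA) serialize only each layer's `W` and `b` to JSON and copy them back by position. A round-trip sketch, with the same `Linear` constructor assumption as above:

```js
import { Sequential, Linear, saveModel, loadModel } from 'mini-jstorch';

const model = new Sequential([new Linear(2, 2)]);
const json = saveModel(model);        // JSON string: [{ weights, biases }, ...]

// Restore into a model with the same layer layout. Only W and b travel through
// the file, so optimizer state and running statistics are not included.
const restored = new Sequential([new Linear(2, 2)]);
loadModel(restored, json);
```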
@@ -1507,4 +2143,40 @@ export function reshape(tensor, rows, cols) {
     flat.slice(i*cols, i*cols + cols)
   );
   return out;
+}
+
+export function toFloat32(matrix) {
+  const rows = matrix.length, cols = matrix[0].length;
+  const flat = new Float32Array(rows * cols);
+  for (let i = 0; i < rows; i++)
+    for (let j = 0; j < cols; j++)
+      flat[i * cols + j] = matrix[i][j];
+  return flat;
+}
+
+export function fromFloat32(flat, rows, cols) {
+  const matrix = Array(rows);
+  for (let i = 0; i < rows; i++) {
+    matrix[i] = Array(cols);
+    for (let j = 0; j < cols; j++)
+      matrix[i][j] = flat[i * cols + j];
+  }
+  return matrix;
+}
+
+export function fastDot(a, b) {
+  const m = a.length, k = a[0].length, n = b[0].length;
+  const aFlat = toFloat32(a);
+  const bFlat = toFloat32(b);
+  const res = new Float32Array(m * n);
+
+  for (let i = 0; i < m; i++)
+    for (let j = 0; j < n; j++) {
+      let sum = 0;
+      for (let l = 0; l < k; l++)
+        sum += aFlat[i * k + l] * bFlat[l * n + j];
+      res[i * n + j] = sum;
+    }
+
+  return fromFloat32(res, m, n);
 }