mini-jstorch 1.8.2 → 2.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/src/jstorch.js CHANGED
@@ -29,7 +29,7 @@
29
29
  // See the Documentation for more details.
30
30
  // --------------------------------------------------------------
31
31
 
32
- // ---------------------- DONOT USE THESE (ENGINE INTERNALS) ERROR/BUG ARE EXPECTED ----------------------
32
+ // ---------------------- Engine-only Utils ----------------------
33
33
  export function zeros(rows, cols) {
34
34
  return Array.from({length:rows},()=>Array(cols).fill(0));
35
35
  }
@@ -235,17 +235,111 @@ export function fu_stack(tensors) {
235
235
 
236
236
  // ---------------------- Tensor ----------------------
237
237
  export class Tensor {
238
- constructor(data){ this.data=data; this.grad=zeros(data.length,data[0].length); }
239
- shape(){ return [this.data.length,this.data[0].length]; }
240
- add(t){ return t instanceof Tensor?this.data.map((r,i)=>r.map((v,j)=>v+t.data[i][j])):this.data.map(r=>r.map(v=>v+t)); }
241
- sub(t){ return t instanceof Tensor?this.data.map((r,i)=>r.map((v,j)=>v-t.data[i][j])):this.data.map(r=>r.map(v=>v-t)); }
242
- mul(t){ return t instanceof Tensor?this.data.map((r,i)=>r.map((v,j)=>v*t.data[i][j])):this.data.map(r=>r.map(v=>v*t)); }
243
- matmul(t){ if(t instanceof Tensor) return dot(this.data,t.data); else throw new Error("matmul requires Tensor"); }
244
- transpose(){ return transpose(this.data); }
245
- flatten(){ return this.data.flat(); }
246
- static zeros(r,c){ return new Tensor(zeros(r,c)); }
247
- static ones(r,c){ return new Tensor(ones(r,c)); }
248
- static random(r,c,scale=0.1){ return new Tensor(randomMatrix(r,c,scale)); }
238
+ constructor(data, requiresGrad = false){
239
+ this.data = data;
240
+ this.rows = data.length;
241
+ this.cols = data[0].length;
242
+ this.grad = zeros(this.rows, this.cols);
243
+ this.requiresGrad = requiresGrad;
244
+
245
+ this._dataFlat = null;
246
+ this._gradFlat = null;
247
+ }
248
+
249
+ shape(){
250
+ return [this.rows, this.cols];
251
+ }
252
+
253
+ _getDataFlat() {
254
+ if (!this._dataFlat) {
255
+ this._dataFlat = new Float32Array(this.rows * this.cols);
256
+ for (let i = 0; i < this.rows; i++) {
257
+ const offset = i * this.cols;
258
+ const row = this.data[i];
259
+ for (let j = 0; j < this.cols; j++) {
260
+ this._dataFlat[offset + j] = row[j];
261
+ }
262
+ }
263
+ }
264
+ return this._dataFlat;
265
+ }
266
+
267
+ _getGradFlat() {
268
+ if (!this._gradFlat) {
269
+ this._gradFlat = new Float32Array(this.rows * this.cols);
270
+ for (let i = 0; i < this.rows; i++) {
271
+ const offset = i * this.cols;
272
+ const row = this.grad[i];
273
+ for (let j = 0; j < this.cols; j++) {
274
+ this._gradFlat[offset + j] = row[j];
275
+ }
276
+ }
277
+ }
278
+ return this._gradFlat;
279
+ }
280
+
281
+ add(t){
282
+ if (t instanceof Tensor) {
283
+ const result = Array(this.rows);
284
+ const aFlat = this._getDataFlat();
285
+ const bFlat = t._getDataFlat();
286
+ for (let i = 0; i < this.rows; i++) {
287
+ result[i] = Array(this.cols);
288
+ const offset = i * this.cols;
289
+ for (let j = 0; j < this.cols; j++) {
290
+ result[i][j] = aFlat[offset + j] + bFlat[offset + j];
291
+ }
292
+ }
293
+ return new Tensor(result);
294
+ } else {
295
+ const res = this.data.map(r => r.map(v => v + t));
296
+ return new Tensor(res);
297
+ }
298
+ }
299
+
300
+ mul(t){
301
+ if (t instanceof Tensor) {
302
+ const result = Array(this.rows);
303
+ const aFlat = this._getDataFlat();
304
+ const bFlat = t._getDataFlat();
305
+ for (let i = 0; i < this.rows; i++) {
306
+ result[i] = Array(this.cols);
307
+ const offset = i * this.cols;
308
+ for (let j = 0; j < this.cols; j++) {
309
+ result[i][j] = aFlat[offset + j] * bFlat[offset + j];
310
+ }
311
+ }
312
+ return new Tensor(result);
313
+ } else {
314
+ const res = this.data.map(r => r.map(v => v * t));
315
+ return new Tensor(res);
316
+ }
317
+ }
318
+
319
+ matmul(t){
320
+ if (!(t instanceof Tensor)) throw new Error("matmul requires Tensor");
321
+ return new Tensor(dot(this.data, t.data));
322
+ }
323
+
324
+ transpose(){
325
+ return new Tensor(transpose(this.data));
326
+ }
327
+
328
+ flatten(){
329
+ return this.data.flat();
330
+ }
331
+
332
+ static zeros(r,c){
333
+ return new Tensor(zeros(r,c));
334
+ }
335
+
336
+ static ones(r,c){
337
+ return new Tensor(ones(r,c));
338
+ }
339
+
340
+ static random(r,c,scale=0.1){
341
+ return new Tensor(randomMatrix(r,c,scale));
342
+ }
249
343
  }
250
344
 
251
345
  // ---------------------- Layers ----------------------
@@ -256,58 +350,175 @@ export class Linear {
256
350
  this.gradW = zeros(inputDim, outputDim);
257
351
  this.gradb = Array(outputDim).fill(0);
258
352
  this.x = null;
259
- this.originalShape = null; // Track input shape
260
- }
353
+ this.originalShape = null;
354
+
355
+ this._WFlat = null;
356
+ this._bFlat = null;
357
+ }
358
+
359
+ _updateCache() {
360
+ const rows = this.W.length;
361
+ const cols = this.W[0].length;
362
+ this._WFlat = new Float32Array(rows * cols);
363
+ for (let i = 0; i < rows; i++){
364
+ const offset = i * cols;
365
+ const row = this.W[i];
366
+ for (let j = 0; j < cols; j++){
367
+ this._WFlat[offset + j] = row[j];
368
+ }
369
+ }
370
+ this._bFlat = new Float32Array(this.b);
371
+ }
261
372
 
262
373
  forward(x){
263
- // Handle both [batch, features] and [batch, 1, features]
264
374
  this.originalShape = this._getShapeType(x);
265
375
 
266
376
  if (this.originalShape === '3d') {
267
- // Convert from [batch, 1, features] to [batch, features]
268
377
  this.x = x.map(sample => sample[0]);
269
378
  } else {
270
- // Already in [batch, features] format
271
379
  this.x = x;
272
380
  }
381
+
382
+ this._updateCache();
383
+
384
+ const m = this.x.length;
385
+ const k = this.x[0].length;
386
+ const n = this.W[0].length;
273
387
 
274
- const out = dot(this.x, this.W);
275
- return out.map((row, i) => row.map((v, j) => v + this.b[j]));
276
- }
277
-
278
- backward(grad){
279
- // Compute gradients
280
- for(let i = 0; i < this.W.length; i++) {
281
- for(let j = 0; j < this.W[0].length; j++) {
282
- this.gradW[i][j] = this.x.reduce((sum, row, k) => sum + row[i] * grad[k][j], 0);
388
+ if (!this._WFlat) {
389
+ const rows = this.W.length;
390
+ const cols = this.W[0].length;
391
+ this._WFlat = new Float32Array(rows * cols);
392
+ for (let i = 0; i < rows; i++) {
393
+ const offset = i * cols;
394
+ const row = this.W[i];
395
+ for (let j = 0; j < cols; j++) {
396
+ this._WFlat[offset + j] = row[j];
397
+ }
283
398
  }
399
+ this._bFlat = new Float32Array(this.b);
284
400
  }
285
401
 
286
- for(let j = 0; j < this.b.length; j++) {
287
- this.gradb[j] = grad.reduce((sum, row) => sum + row[j], 0);
402
+ // Flatten input x to Float32Array
403
+ const xFlat = new Float32Array(m * k);
404
+ for (let i = 0; i < m; i++) {
405
+ const row = this.x[i];
406
+ const offset = i * k;
407
+ for (let j = 0; j < k; j++) {
408
+ xFlat[offset + j] = row[j];
409
+ }
288
410
  }
289
-
290
- const gradInput = zeros(this.x.length, this.W.length);
291
- for(let i = 0; i < this.x.length; i++) {
292
- for(let j = 0; j < this.W.length; j++) {
293
- for(let k = 0; k < this.W[0].length; k++) {
294
- gradInput[i][j] += grad[i][k] * this.W[j][k];
411
+
412
+ const outFlat = new Float32Array(m * n);
413
+ for (let i = 0; i < m; i++) {
414
+ const xOffset = i * k;
415
+ for (let j = 0; j < n; j++) {
416
+ let sum = 0;
417
+ for (let l = 0; l < k; l++) {
418
+ sum += xFlat[xOffset + l] * this._WFlat[l * n + j];
295
419
  }
420
+ outFlat[i * n + j] = sum + this._bFlat[j];
296
421
  }
297
422
  }
298
423
 
299
- //Convert back to original shape if needed
300
- if (this.originalShape === '3d') {
301
- return gradInput.map(row => [row]); // Back to [batch, 1, features]
424
+ const out = Array(m);
425
+ for (let i = 0; i < m; i++) {
426
+ const row = Array(n);
427
+ const offset = i * n;
428
+ for (let j = 0; j < n; j++) {
429
+ row[j] = outFlat[offset + j];
430
+ }
431
+ out[i] = row;
302
432
  }
303
- return gradInput;
433
+
434
+ return out;
304
435
  }
305
436
 
437
+ backward(grad){
438
+ const m = this.x.length;
439
+ const k = this.W.length; // input dim
440
+ const n = this.W[0].length; // output dim
441
+
442
+ // Convert grad to Float32Array
443
+ const gradFlat = new Float32Array(m * n);
444
+ for (let i = 0; i < m; i++) {
445
+ const row = grad[i];
446
+ const offset = i * n;
447
+ for (let j = 0; j < n; j++) {
448
+ gradFlat[offset + j] = row[j];
449
+ }
450
+ }
451
+
452
+ // Convert x to Float32Array
453
+ const xFlat = new Float32Array(m * k);
454
+ for (let i = 0; i < m; i++) {
455
+ const row = this.x[i];
456
+ const offset = i * k;
457
+ for (let j = 0; j < k; j++) {
458
+ xFlat[offset + j] = row[j];
459
+ }
460
+ }
461
+
462
+ // Reset gradW
463
+ for (let i = 0; i < this.gradW.length; i++) {
464
+ for (let j = 0; j < this.gradW[0].length; j++) {
465
+ this.gradW[i][j] = 0;
466
+ }
467
+ }
468
+
469
+ // Compute gradW = x^T * grad
470
+ for (let i = 0; i < k; i++) {
471
+ for (let j = 0; j < n; j++) {
472
+ let sum = 0;
473
+ for (let batch = 0; batch < m; batch++) {
474
+ sum += xFlat[batch * k + i] * gradFlat[batch * n + j];
475
+ }
476
+ this.gradW[i][j] = sum;
477
+ }
478
+ }
479
+
480
+ // Compute gradb
481
+ for (let j = 0; j < n; j++) {
482
+ let sum = 0;
483
+ for (let batch = 0; batch < m; batch++) {
484
+ sum += gradFlat[batch * n + j];
485
+ }
486
+ this.gradb[j] = sum;
487
+ }
488
+
489
+ const gradInputFlat = new Float32Array(m * k);
490
+ for (let i = 0; i < m; i++) {
491
+ for (let j = 0; j < k; j++) {
492
+ let sum = 0;
493
+ for (let l = 0; l < n; l++) {
494
+ sum += gradFlat[i * n + l] * this.W[j][l];
495
+ }
496
+ gradInputFlat[i * k + j] = sum;
497
+ }
498
+ }
499
+
500
+ // Convert back to 2D array
501
+ const gradInput = Array(m);
502
+ for (let i = 0; i < m; i++) {
503
+ const row = Array(k);
504
+ const offset = i * k;
505
+ for (let j = 0; j < k; j++) {
506
+ row[j] = gradInputFlat[offset + j];
507
+ }
508
+ gradInput[i] = row;
509
+ }
510
+
511
+ if (this.originalShape === '3d') {
512
+ return gradInput.map(row => [row]);
513
+ }
514
+ return gradInput;
515
+ }
516
+
306
517
  _getShapeType(x) {
307
518
  if (Array.isArray(x[0]) && Array.isArray(x[0][0]) && !Array.isArray(x[0][0][0])) {
308
- return '3d'; // [batch, 1, features]
519
+ return '3d';
309
520
  } else if (Array.isArray(x[0]) && !Array.isArray(x[0][0])) {
310
- return '2d'; // [batch, features]
521
+ return '2d';
311
522
  } else {
312
523
  throw new Error(`Unsupported input shape for Linear layer`);
313
524
  }
@@ -405,7 +616,7 @@ export class Flatten {
405
616
  parameters() { return []; }
406
617
  }
407
618
 
408
- // ---------------------- Conv2D ----------------------
619
+ // ---------------------- Conv2D (BETA) ----------------------
409
620
  export class Conv2D {
410
621
  constructor(inC, outC, kernel, stride=1, padding=0){
411
622
  this.inC = inC;
@@ -420,10 +631,31 @@ export class Conv2D {
420
631
  Array(inC).fill().map(() => zeros(kernel, kernel))
421
632
  );
422
633
  this.x = null;
634
+
635
+ // Cache Float32Array untuk kernels
636
+ this._WFlat = null;
637
+ this._cacheKernels();
638
+ }
639
+
640
+ _cacheKernels() {
641
+ this._WFlat = this.W.map(oc =>
642
+ oc.map(ic => {
643
+ const rows = ic.length;
644
+ const cols = ic[0].length;
645
+ const flat = new Float32Array(rows * cols);
646
+ for (let i = 0; i < rows; i++) {
647
+ const offset = i * cols;
648
+ const row = ic[i];
649
+ for (let j = 0; j < cols; j++) {
650
+ flat[offset + j] = row[j];
651
+ }
652
+ }
653
+ return flat;
654
+ })
655
+ );
423
656
  }
424
657
 
425
658
  pad2D(input, pad){
426
- // Input is single channel [height, width]
427
659
  if (!input || !input.length) return input;
428
660
 
429
661
  const rows = input.length + 2 * pad;
@@ -431,26 +663,31 @@ export class Conv2D {
431
663
  const out = Array.from({length: rows}, () => Array(cols).fill(0));
432
664
 
433
665
  for(let i = 0; i < input.length; i++) {
666
+ const row = input[i];
667
+ const outRow = out[i + pad];
434
668
  for(let j = 0; j < input[0].length; j++) {
435
- out[i + pad][j + pad] = input[i][j];
669
+ outRow[j + pad] = row[j];
436
670
  }
437
671
  }
438
672
  return out;
439
673
  }
440
674
 
441
- conv2DSingle(input, kernel) {
442
- const rows = Math.floor((input.length - kernel.length) / this.stride) + 1;
443
- const cols = Math.floor((input[0].length - kernel[0].length) / this.stride) + 1;
444
- const out = zeros(rows, cols);
675
+ conv2DSingle(input, kernelFlat, kH, kW) {
676
+ const rows = Math.floor((input.length - kH) / this.stride) + 1;
677
+ const cols = Math.floor((input[0].length - kW) / this.stride) + 1;
678
+ const out = Array(rows);
445
679
 
446
680
  for(let i = 0; i < rows; i++) {
681
+ out[i] = Array(cols);
447
682
  for(let j = 0; j < cols; j++) {
448
683
  let sum = 0;
449
- for(let ki = 0; ki < kernel.length; ki++) {
450
- for(let kj = 0; kj < kernel[0].length; kj++) {
451
- const inputRow = i * this.stride + ki;
684
+ for(let ki = 0; ki < kH; ki++) {
685
+ const inputRow = i * this.stride + ki;
686
+ const rowOffset = ki * kW;
687
+ const inputRowData = input[inputRow];
688
+ for(let kj = 0; kj < kW; kj++) {
452
689
  const inputCol = j * this.stride + kj;
453
- sum += input[inputRow][inputCol] * kernel[ki][kj];
690
+ sum += inputRowData[inputCol] * kernelFlat[rowOffset + kj];
454
691
  }
455
692
  }
456
693
  out[i][j] = sum;
@@ -461,6 +698,9 @@ export class Conv2D {
461
698
 
462
699
  forward(batch) {
463
700
  this.x = batch;
701
+ const kH = this.kernel;
702
+ const kW = this.kernel;
703
+
464
704
  return batch.map(sample => {
465
705
  const channelsOut = [];
466
706
  for(let oc = 0; oc < this.outC; oc++) {
@@ -471,7 +711,7 @@ export class Conv2D {
471
711
  inputChan = this.pad2D(inputChan, this.padding);
472
712
  }
473
713
 
474
- const conv = this.conv2DSingle(inputChan, this.W[oc][ic]);
714
+ const conv = this.conv2DSingle(inputChan, this._WFlat[oc][ic], kH, kW);
475
715
 
476
716
  if(outChan === null) {
477
717
  outChan = conv;
@@ -487,7 +727,7 @@ export class Conv2D {
487
727
 
488
728
  backward(grad) {
489
729
  const batchSize = this.x.length;
490
- const gradW = this.W.map(oc => oc.map(ic => zeros(this.kernel, this.kernel)));
730
+ const gradW = this.gradW.map(oc => oc.map(ic => zeros(this.kernel, this.kernel)));
491
731
  const gradInput = this.x.map(sample =>
492
732
  sample.map(chan => zeros(chan.length, chan[0].length))
493
733
  );
@@ -549,10 +789,109 @@ export class Conv2D {
549
789
 
550
790
  // ---------------------- Sequential ----------------------
551
791
  export class Sequential {
552
- constructor(layers=[]){ this.layers=layers; }
553
- forward(x){ return this.layers.reduce((acc,l)=>l.forward(acc), x); }
554
- backward(grad){ return this.layers.reduceRight((g,l)=>l.backward(g), grad); }
555
- parameters(){ return this.layers.flatMap(l=>l.parameters?l.parameters():[]); }
792
+ constructor(layers=[]) {
793
+ this.layers = layers;
794
+ }
795
+
796
+ forward(x){
797
+ return this.layers.reduce((acc, l) => l.forward(acc), x);
798
+ }
799
+
800
+ backward(grad){
801
+ return this.layers.reduceRight((g, l) => l.backward(g), grad);
802
+ }
803
+
804
+ parameters(){
805
+ return this.layers.flatMap(l => l.parameters ? l.parameters() : []);
806
+ }
807
+
808
+ /**
809
+ * Zero out all gradients of all parameters
810
+ */
811
+ zeroGrad(){
812
+ const params = this.parameters();
813
+
814
+ for (const p of params){
815
+ if (!p.grad) continue;
816
+
817
+ // Handle different gradient shapes
818
+ if (Array.isArray(p.grad)){
819
+ // Check if it's 1D or 2D array
820
+ if (p.grad.length > 0 && Array.isArray(p.grad[0])){
821
+ // 2D Gradient (weights)
822
+ for (let i = 0; i < p.grad.length; i++){
823
+ const row = p.grad[i];
824
+ for (let j = 0; j < row.length; j++){
825
+ row[j] = 0;
826
+ }
827
+ }
828
+ } else {
829
+ // 1D Gradient (bias)
830
+ for (let i = 0; i < p.grad.length; i++){
831
+ p.grad[i] = 0;
832
+ }
833
+ }
834
+ } else if (typeof p.grad === 'number'){
835
+ // Scalar gradient (rare case)
836
+ p.grad = 0;
837
+ }
838
+ }
839
+
840
+ return this; // Allow chaining
841
+ }
842
+
843
+ /**
844
+ * Train mode (enable dropout, batch norm, etc.)
845
+ */
846
+ train(){
847
+ this.layers.forEach(layer => {
848
+ if (layer.train) layer.train();
849
+ });
850
+ return this;
851
+ }
852
+
853
+ /**
854
+ * Eval mode (disable dropout, batch norm, etc.)
855
+ */
856
+ eval(){
857
+ this.layers.forEach(layer => {
858
+ if (layer.eval) layer.eval();
859
+ });
860
+ return this;
861
+ }
862
+
863
+ /**
864
+ * Get model state dict (weights and biases)
865
+ */
866
+ stateDict(){
867
+ const state = {};
868
+ this.layers.forEach((layer, idx) => {
869
+ if (layer.W) state[`layer_${idx}.weight`] = layer.W;
870
+ if (layer.b) state[`layer_${idx}.bias`] = layer.b;
871
+ });
872
+ return state;
873
+ }
874
+
875
+ /**
876
+ * Load state dict
877
+ */
878
+ loadStateDict(stateDict){
879
+ this.layers.forEach((layer, idx) => {
880
+ const weightKey = `layer_${idx}.weight`;
881
+ const biasKey = `layer_${idx}.bias`;
882
+
883
+ if (layer.W && stateDict[weightKey]){
884
+ layer.W = stateDict[weightKey];
885
+ // Invalidate cache
886
+ if (layer._InvalidateCache) layer._InvalidateCache();
887
+ }
888
+ if (layer.b && stateDict[biasKey]){
889
+ layer.b = stateDict[biasKey];
890
+ if (layer._InvalidateCache) layer._InvalidateCache();
891
+ }
892
+ });
893
+ return this;
894
+ }
556
895
  }
557
896
 
558
897
  // ---------------------- Activations ----------------------
@@ -606,42 +945,35 @@ export class Softmax {
606
945
 
607
946
  forward(x) {
608
947
  this.input = x;
609
-
610
- // x: [batch_size, num_classes]
611
- this.output = x.map(row => {
612
- const maxVal = Math.max(...row);
613
- const exps = row.map(v => Math.exp(v - maxVal));
614
- const sumExps = exps.reduce((a, b) => a + b, 0);
615
- return exps.map(v => v / sumExps);
616
- });
617
- return this.output;
948
+
949
+ // x: [batch_size, num_Classes]
950
+ this.output = x.map(row => {
951
+ const maxVal = Math.max(...row);
952
+ const exps = row.map(v => Math.exp(v - maxVal));
953
+ const sumExps = exps.reduce((a, b) => a + b, 0);
954
+ return exps.map(v => v / sumExps);
955
+ });
956
+ return this.output;
618
957
  }
619
958
 
620
959
  backward(grad) {
621
- // grad: [batch_size, num_classes] - gradient from next layer
622
- const batchSize = grad.length;
623
- const numClasses = grad[0].length;
624
-
625
- const gradInput = zeros(batchSize, numClasses);
626
-
627
- for (let i = 0; i < batchSize; i++) {
628
- const s = this.output[i]; // Softmax output for this sample
629
- const gradOut = grad[i]; // Gradient from loss
630
-
631
- // Compute Jacobian matrix: J_ij = s_i * (δ_ij - s_j)
632
- for (let j = 0; j < numClasses; j++) {
633
- let sum = 0;
634
- for (let k = 0; k < numClasses; k++) {
635
- // J[j][k] = s[j] * ((j === k ? 1 : 0) - s[k])
636
- const jacobian = s[j] * ((j === k ? 1 : 0) - s[k]);
637
- sum += jacobian * gradOut[k];
638
- }
639
- gradInput[i][j] = sum;
640
- }
641
- }
642
-
643
- return gradInput;
644
- }
960
+ const batchSize = grad.length;
961
+ const numClasses = grad[0].length;
962
+ const gradInput = zeros(batchSize, numClasses);
963
+
964
+ for (let i = 0; i < batchSize; i++){
965
+ const s = this.output[i]; // Softmax output
966
+ const gradOut = grad[i]; // Gradient from next layer
967
+
968
+ const dot = s.reduce((sum, val, k) => sum + val * gradOut[k], 0);
969
+
970
+ for (let j = 0; j < numClasses; j++){
971
+ gradInput[i][j] = s[j] * (gradOut[j] - dot);
972
+ }
973
+ }
974
+
975
+ return gradInput;
976
+ }
645
977
 
646
978
  parameters() {
647
979
  return []; // Softmax has no trainable parameters
@@ -650,90 +982,209 @@ export class Softmax {
650
982
 
651
983
  // ---------------------- Tokenizer ----------------------
652
984
  export class Tokenizer {
653
- constructor(vocabSize = 2000){
985
+ constructor(vocabSize = 2000) {
654
986
  this.vocabSize = vocabSize;
655
987
  this.wordToIndex = new Map();
656
988
  this.indexToWord = new Map();
657
989
  this.fitted = false;
658
- }
659
-
660
- fit(texts){
661
- const wordCounts = new Map();
662
-
663
- // Count word frequencies from all texts
990
+ this.wordCounts = null;
991
+
992
+ // Special tokens
993
+ this.PAD_TOKEN = '<PAD>';
994
+ this.UNK_TOKEN = '<UNK>';
995
+ this.PAD_INDEX = 0;
996
+ this.UNK_INDEX = 1;
997
+ }
998
+
999
+ /**
1000
+ * Fit tokenizer on a list of texts
1001
+ * @param {string[]} texts - Array of text strings
1002
+ * @returns {Tokenizer} this
1003
+ */
1004
+ fit(texts) {
1005
+ this.wordCounts = new Map();
1006
+ this.wordToIndex.clear();
1007
+ this.indexToWord.clear();
1008
+
1009
+ // Count word frequencies
664
1010
  texts.forEach(text => {
665
- const words = this._preprocess(text);
1011
+ const words = this._tokenize(text);
666
1012
  words.forEach(word => {
667
- wordCounts.set(word, (wordCounts.get(word) || 0) + 1);
1013
+ this.wordCounts.set(word, (this.wordCounts.get(word) || 0) + 1);
668
1014
  });
669
1015
  });
670
-
671
- // Sort by frequency and take top words
672
- const sortedWords = [...wordCounts.entries()]
673
- .sort((a, b) => a[1] - a[1])
674
- .slice(0, this.vocabSize - 1); // Reverse 1 for unknown
675
-
676
- // Build vocabulary
677
- this.wordToIndex.clear();
678
- this.indexToWord.clear();
679
-
680
- // Add unk token
681
- this.wordToIndex.set('<UNK>', 0);
682
- this.indexToWord.set(0, '<UNK>');
683
-
684
- // Add most frequent words
685
- sortedWords.forEach(([word], index) =>{
686
- this.wordToIndex.set(word, index + 1);
687
- this.indexToWord.set(index + 1, word);
688
- })
689
-
1016
+
1017
+ const sortedWords = [...this.wordCounts.entries()]
1018
+ .sort((a, b) => b[1] - a[1]) // Descending: most frequent first
1019
+ .slice(0, this.vocabSize - 2) // -2 for PAD and UNK
1020
+ .map(([word]) => word);
1021
+
1022
+ // Index 0 = <PAD> (padding)
1023
+ this.wordToIndex.set(this.PAD_TOKEN, this.PAD_INDEX);
1024
+ this.indexToWord.set(this.PAD_INDEX, this.PAD_TOKEN);
1025
+
1026
+ // Index 1 = <UNK> (unknown)
1027
+ this.wordToIndex.set(this.UNK_TOKEN, this.UNK_INDEX);
1028
+ this.indexToWord.set(this.UNK_INDEX, this.UNK_TOKEN);
1029
+
1030
+ sortedWords.forEach((word, idx) => {
1031
+ const index = idx + 2; // Start from index 2
1032
+ this.wordToIndex.set(word, index);
1033
+ this.indexToWord.set(index, word);
1034
+ });
1035
+
690
1036
  this.fitted = true;
691
1037
  return this;
692
1038
  }
693
-
694
- tokenize(text){
695
- if (!this.fitted) throw new Error("Tokenizer not fitted. Call fit() first.");
696
-
697
- const words = this._preprocess(text);
698
- return words.map(word => this.wordToIndex.get(word) || 0);
1039
+
1040
+ /**
1041
+ * Fit and transform in one step
1042
+ * @param {string[]} texts - Array of text strings
1043
+ * @param {number|null} maxLength - Pad/truncate to this length
1044
+ * @param {boolean} padToMax - Whether to pad to maxLength (default: true)
1045
+ * @returns {number[][]} Tokenized sequences
1046
+ */
1047
+ fitTransform(texts, maxLength = null, padToMax = true) {
1048
+ this.fit(texts);
1049
+ return this.transform(texts, maxLength, padToMax);
699
1050
  }
700
-
701
- tokenizeBatch(texts, maxLength=null){
702
- if (!this.fitted) throw new Error("Tokenizer not fitted. Call fit() first.");
703
-
1051
+
1052
+ /**
1053
+ * Transform texts to token indices
1054
+ * @param {string[]} texts - Array of text strings
1055
+ * @param {number|null} maxLength - Pad/truncate to this length
1056
+ * @param {boolean} padToMax - Whether to pad to maxLength (default: true)
1057
+ * @returns {number[][]} Tokenized sequences
1058
+ */
1059
+ transform(texts, maxLength = null, padToMax = true) {
1060
+ if (!this.fitted) {
1061
+ throw new Error("Tokenizer not fitted. Call fit() or fitTransform() first.");
1062
+ }
1063
+
704
1064
  return texts.map(text => {
705
- const tokens = this.tokenize(text);
706
-
707
- if (maxLength !== null){
708
- // Pad or truncate to maxLength
709
- if (tokens.length > maxLength){
710
- return tokens.slice(0, maxLength);
711
- } else {
712
- return [...tokens, ...Array(maxLength - tokens.length).fill(0)];
713
- }
1065
+ const words = this._tokenize(text);
1066
+ let tokens = words.map(word => this.wordToIndex.get(word) || this.UNK_INDEX);
1067
+
1068
+ // Truncate if maxLength specified
1069
+ if (maxLength !== null && tokens.length > maxLength) {
1070
+ tokens = tokens.slice(0, maxLength);
714
1071
  }
715
-
1072
+
1073
+ // Pad if maxLength specified and padToMax is true
1074
+ if (maxLength !== null && padToMax && tokens.length < maxLength) {
1075
+ tokens = [...tokens, ...Array(maxLength - tokens.length).fill(this.PAD_INDEX)];
1076
+ }
1077
+
716
1078
  return tokens;
717
- })
1079
+ });
718
1080
  }
719
-
720
- detokenize(tokens){
721
- return tokens.map(token => this.indexToWord.get(token) || '<UNK>').join(' ');
1081
+
1082
+ /**
1083
+ * Convert token indices back to text
1084
+ * @param {number[]|number[][]} tokens - Single sequence or batch of sequences
1085
+ * @param {boolean} skipPad - Whether to skip PAD tokens (default: true)
1086
+ * @returns {string|string[]} Detokenized text(s)
1087
+ */
1088
+ inverseTransform(tokens, skipPad = true) {
1089
+ if (!this.fitted) {
1090
+ throw new Error("Tokenizer not fitted. Call fit() first.");
1091
+ }
1092
+
1093
+ // Check if batch or single sequence
1094
+ const isBatch = Array.isArray(tokens[0]);
1095
+
1096
+ if (isBatch) {
1097
+ return tokens.map(seq => this._detokenize(seq, skipPad));
1098
+ } else {
1099
+ return this._detokenize(tokens, skipPad);
1100
+ }
722
1101
  }
723
-
724
- detokenizeBatch(tokenBatches){
725
- return tokenBatches.map(tokens => this.detokenize(tokens));
1102
+
1103
+ /**
1104
+ * Get vocabulary as array
1105
+ * @returns {string[]} Vocabulary list
1106
+ */
1107
+ getVocabulary() {
1108
+ if (!this.fitted) return [];
1109
+
1110
+ const vocab = new Array(this.wordToIndex.size);
1111
+ for (let i = 0; i < this.wordToIndex.size; i++) {
1112
+ vocab[i] = this.indexToWord.get(i);
1113
+ }
1114
+ return vocab;
726
1115
  }
727
-
728
- getVocabSize(){
1116
+
1117
+ /**
1118
+ * Get vocabulary size
1119
+ * @returns {number} Number of words in vocabulary
1120
+ */
1121
+ getVocabSize() {
729
1122
  return this.wordToIndex.size;
730
1123
  }
731
-
732
- _preprocess(text) {
1124
+
1125
+ /**
1126
+ * Get word frequency counts
1127
+ * @returns {Map<string, number>} Word frequency map
1128
+ */
1129
+ getWordCounts() {
1130
+ return this.wordCounts ? new Map(this.wordCounts) : null;
1131
+ }
1132
+
1133
+ /**
1134
+ * Get most common words
1135
+ * @param {number} n - Number of words to return
1136
+ * @returns {Array<{word: string, count: number}>} Most common words
1137
+ */
1138
+ getMostCommon(n = 10) {
1139
+ if (!this.wordCounts) return [];
1140
+
1141
+ return [...this.wordCounts.entries()]
1142
+ .sort((a, b) => b[1] - a[1])
1143
+ .slice(0, n)
1144
+ .map(([word, count]) => ({ word, count }));
1145
+ }
1146
+
1147
+ /**
1148
+ * Internal: tokenize text into words
1149
+ * @param {string} text
1150
+ * @returns {string[]}
1151
+ */
1152
+ _tokenize(text) {
1153
+ // handle contractions and preserve word boundaries
733
1154
  return text.toLowerCase()
734
- .replace(/[^\w\s]/g, ' ') // Remove punctuation
735
- .split(/\s+/) // Split by whitespace
736
- .filter(word => word.length > 0); // Remove empty strings
1155
+ .replace(/([.!?;:,])/g, ' $1 ')
1156
+ .replace(/\s+/g, ' ')
1157
+ .trim()
1158
+ .split(' ')
1159
+ .filter(word => word.length > 0 && !/^[.!?;:,]+$/.test(word) || word.length > 1);
1160
+ }
1161
+
1162
+ /**
1163
+ * Internal: detokenize sequence to text
1164
+ * @param {number[]} tokens
1165
+ * @param {boolean} skipPad
1166
+ * @returns {string}
1167
+ */
1168
+ _detokenize(tokens, skipPad = true) {
1169
+ const words = [];
1170
+
1171
+ for (const token of tokens) {
1172
+ if (skipPad && token === this.PAD_INDEX) {
1173
+ continue; // Skip padding tokens
1174
+ }
1175
+
1176
+ const word = this.indexToWord.get(token);
1177
+ if (word && word !== this.PAD_TOKEN) {
1178
+ words.push(word === this.UNK_TOKEN ? '?' : word);
1179
+ } else if (word === undefined) {
1180
+ words.push('?');
1181
+ }
1182
+ }
1183
+
1184
+ let text = words.join(' ');
1185
+ text = text.replace(/\s+([.!?;:,])/g, '$1');
1186
+ text = text.replace(/([.!?;:,])\s+/g, '$1 ');
1187
+ return text.trim();
737
1188
  }
738
1189
  }
739
1190
 
@@ -741,14 +1192,200 @@ export class Tokenizer {
741
1192
  export class Sigmoid{ constructor(){ this.out=null; } forward(x){ const fn=v=>1/(1+Math.exp(-v)); this.out=x.map(r=>r.map(fn)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*this.out[i][j]*(1-this.out[i][j]))); } }
742
1193
  export class Tanh{ constructor(){ this.out=null; } forward(x){ this.out=x.map(r=>r.map(v=>Math.tanh(v))); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*(1-this.out[i][j]**2))); } }
743
1194
  export class LeakyReLU{ constructor(alpha=0.01){ this.alpha=alpha; this.out=null; } forward(x){ this.out=x.map(r=>r.map(v=>v>0?v:v*this.alpha)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map((v,j)=>v*(this.out[i][j]>0?1:this.alpha))); } }
744
- export class GELU{ constructor(){ this.out=null; } forward(x){ const fn=v=>0.5*v*(1+Math.tanh(Math.sqrt(2/Math.PI)*(v+0.044715*v**3))); this.out=x.map(r=>r.map(fn)); return this.out; } backward(grad){ return grad.map((r,i)=>r.map(v=>v*1)); } }
1195
+
1196
+ // ---------------------- GELU ----------------------
1197
/**
 * GELU activation (tanh approximation) applied element-wise to a 2D array:
 *   GELU(x) = 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
 * forward() keeps a reference to its input so that backward() can evaluate the
 * analytic derivative at the same points.
 */
export class GELU {
  constructor() {
    this.x = null;
    this.output = null;
  }

  forward(x) {
    this.x = x;
    const alpha = Math.sqrt(2 / Math.PI);
    const beta = 0.044715;

    const activate = (v) => {
      const t = Math.tanh(alpha * (v + beta * (v * v * v)));
      return 0.5 * v * (1 + t);
    };

    this.output = x.map((row) => row.map(activate));
    return this.output;
  }

  backward(grad) {
    const alpha = Math.sqrt(2 / Math.PI);
    const beta = 0.044715;

    // With u = alpha * (x + beta * x^3):
    //   d/dx GELU = 0.5 * (1 + tanh(u)) + 0.5 * x * sech^2(u) * du/dx
    //   du/dx     = alpha * (1 + 3 * beta * x^2)
    const derivative = (v) => {
      const sq = v * v;
      const t = Math.tanh(alpha * (v + beta * (sq * v)));
      const sech2 = 1 - t * t;
      const du = alpha * (1 + 3 * beta * sq);
      return 0.5 * (1 + t) + 0.5 * v * sech2 * du;
    };

    return grad.map((row, i) => row.map((g, j) => g * derivative(this.x[i][j])));
  }
}
745
1246
 
746
1247
  // ---------------------- Dropout ----------------------
747
- export class Dropout{ constructor(p=0.5){ this.p=p; } forward(x){ return x.map(r=>r.map(v=>v*Math.random()>=this.p?v:0)); } backward(grad){ return grad.map(r=>r.map(v=>v*(1-this.p))); } }
1248
/**
 * Inverted dropout for nested numeric arrays of any depth, for example
 * 1D, 2D [batch, features] or 3D [batch, channels, features].
 *
 * While training, each element is kept with probability 1 - p and each survivor
 * is multiplied by 1 / (1 - p). The 0/1 mask is stored so that backward() can
 * route gradients through the same elements. In eval mode, or when p === 0,
 * forward() and backward() return a deep copy of their argument unchanged.
 */
export class Dropout {
  /**
   * @param {number} [p=0.5] - Probability of zeroing each element.
   */
  constructor(p = 0.5) {
    this.p = p;
    this.mask = null;
    this.training = true;
    this.scale = null; // scale used by the most recent training forward pass
  }

  /**
   * Scale for surviving elements, derived from the *current* `p`.
   * It is recomputed on every call. It used to be cached once, which meant that
   * changing `p` after the first forward pass silently kept the old scale.
   * @returns {number}
   */
  _getScale() {
    this.scale = this.p === 1 ? 0 : 1.0 / (1.0 - this.p);
    return this.scale;
  }

  /**
   * @param {Array} x - Nested numeric array (any depth)
   * @returns {Array} Masked and rescaled copy of `x` when training, otherwise a deep copy
   */
  forward(x) {
    if (!this.training || this.p === 0) {
      return Dropout._deepCopy(x);
    }

    const scale = this._getScale();
    // The whole mask is drawn (row-major) before it is applied.
    this.mask = Dropout._sampleMask(x, this.p);
    return Dropout._applyMask(x, this.mask, scale);
  }

  /**
   * @param {Array} grad - Upstream gradient, same shape as the last forward input
   * @returns {Array} Gradient with respect to the forward input
   */
  backward(grad) {
    if (!this.training || this.p === 0 || !this.mask) {
      return Dropout._deepCopy(grad);
    }
    // Reuse the scale of the forward pass that produced the mask, so that
    // editing `p` between forward and backward cannot desynchronise them.
    const scale = this.scale ?? this._getScale();
    return Dropout._applyMask(grad, this.mask, scale);
  }

  train() {
    this.training = true;
  }

  eval() {
    this.training = false;
  }

  /** True for values to recurse into (plain arrays and typed-array rows). */
  static _isNested(value) {
    return Array.isArray(value) || ArrayBuffer.isView(value);
  }

  /** Copy every nesting level into plain arrays; leaf values are returned as-is. */
  static _deepCopy(value) {
    if (!Dropout._isNested(value)) return value;
    return Array.from(value, (item) => Dropout._deepCopy(item));
  }

  /** Draw a 0/1 keep-mask shaped like `value`, where each leaf is 1 with probability 1 - p. */
  static _sampleMask(value, p) {
    return value.map((item) => {
      if (Dropout._isNested(item)) return Dropout._sampleMask(item, p);
      return Math.random() >= p ? 1 : 0;
    });
  }

  /** Multiply each leaf of `value` by the matching mask entry and by `scale`. */
  static _applyMask(value, mask, scale) {
    return value.map((item, i) => {
      if (Dropout._isNested(item)) return Dropout._applyMask(item, mask[i], scale);
      return item * mask[i] * scale;
    });
  }
}
748
1336
 
749
1337
  // ---------------------- Losses ----------------------
750
- export class MSELoss{ forward(pred,target){ this.pred=pred; this.target=target; const losses=pred.map((row,i)=>row.reduce((sum,v,j)=>sum+(v-target[i][j])**2,0)/row.length); return losses.reduce((a,b)=>a+b,0)/pred.length; } backward(){ return this.pred.map((row,i)=>row.map((v,j)=>2*(v-this.target[i][j])/row.length)); } }
751
- export class CrossEntropyLoss{ forward(pred,target){ this.pred=pred; this.target=target; const losses=pred.map((p,i)=>crossEntropy(softmax(p),target[i])); return losses.reduce((a,b)=>a+b,0)/pred.length; } backward(){ return this.pred.map((p,i)=>{ const s=softmax(p); return s.map((v,j)=>(v-this.target[i][j])/this.pred.length); }); } }
1338
/**
 * Mean squared error over a 2D batch. The squared error is averaged over the
 * features of each sample and then over the samples:
 *   loss = (1 / (B * F)) * sum_ij (pred[i][j] - target[i][j])^2
 * so the gradient is d loss / d pred[i][j] = 2 * (pred[i][j] - target[i][j]) / (B * F).
 */
export class MSELoss {
  /**
   * @param {number[][]} pred - Predictions, B rows of F values each
   * @param {number[][]} target - Targets with the same shape as `pred`
   * @returns {number} Scalar loss
   * @throws {Error} If `pred` is empty or `target` does not match its shape
   */
  forward(pred, target) {
    if (!Array.isArray(pred) || pred.length === 0 || !Array.isArray(pred[0]) || pred[0].length === 0) {
      throw new Error('MSELoss: pred must be a non-empty 2D array');
    }
    const batchSize = pred.length;
    const featureSize = pred[0].length;

    if (!Array.isArray(target) || target.length !== batchSize) {
      throw new Error(`MSELoss: target must have ${batchSize} rows`);
    }
    for (let i = 0; i < batchSize; i++) {
      const predRow = pred[i];
      const targetRow = target[i];
      // Mismatched rows used to read `undefined` and yield NaN silently.
      if (predRow.length !== featureSize || !Array.isArray(targetRow) || targetRow.length !== featureSize) {
        throw new Error(`MSELoss: row ${i} must have ${featureSize} predictions and targets`);
      }
    }

    this.pred = pred;
    this.target = target;
    this.batchSize = batchSize;
    this.featureSize = featureSize;

    let totalLoss = 0;
    for (let i = 0; i < batchSize; i++) {
      let sampleLoss = 0;
      for (let j = 0; j < featureSize; j++) {
        const diff = pred[i][j] - target[i][j];
        sampleLoss += diff * diff;
      }
      totalLoss += sampleLoss / featureSize;
    }

    this.loss = totalLoss / batchSize;
    return this.loss;
  }

  /**
   * Gradient of the last forward() loss with respect to `pred`.
   * @returns {number[][]} Same shape as `pred`
   * @throws {Error} If forward() has not been called
   */
  backward() {
    if (!this.pred) {
      throw new Error('MSELoss: backward() called before forward()');
    }
    // forward() averages over both batch and features, so the gradient must be
    // normalised by B * F. Dividing by B alone made it F times too large.
    const norm = this.batchSize * this.featureSize;
    return this.pred.map((row, i) =>
      row.map((v, j) => (2 * (v - this.target[i][j])) / norm),
    );
  }
}
1372
+
1373
/**
 * Deprecated alias kept for backward compatibility.
 * Every call is delegated unchanged to an internal SoftmaxCrossEntropyLoss
 * instance. A deprecation warning is printed each time an instance is
 * constructed.
 */
export class CrossEntropyLoss {
  constructor() {
    // Fixed the misspelled class name ("CrossEntrpyLoss") in the warning.
    console.warn(
      '[JST WARN]: CrossEntropyLoss is deprecated. ' +
        'Use SoftmaxCrossEntropyLoss instead for better numerical stability.',
    );
    this._impl = new SoftmaxCrossEntropyLoss();
  }

  /**
   * @param {*} logits - Passed through to SoftmaxCrossEntropyLoss#forward
   * @param {*} targets - Passed through to SoftmaxCrossEntropyLoss#forward
   * @returns {*} Whatever the delegate's forward() returns
   */
  forward(logits, targets) {
    return this._impl.forward(logits, targets);
  }

  /** @returns {*} Whatever the delegate's backward() returns */
  backward() {
    return this._impl.backward();
  }
}
752
1389
 
753
1390
  export class SoftmaxCrossEntropyLoss {
754
1391
  forward(logits, targets) {
@@ -816,7 +1453,6 @@ export class BCEWithLogitsLoss {
816
1453
  // ---------------------- Optimizers ----------------------
817
1454
  export class Adam{
818
1455
  constructor(params, lr = 0.001, b1 = 0.9, b2 = 0.999, eps = 1e-8, max_grad_norm = 1.0){
819
- // Handle both parameter styles: (params, lr) OR (params, {lr, ...})
820
1456
  if (typeof lr === 'object') {
821
1457
  // Options object provided
822
1458
  const options = lr;
@@ -873,7 +1509,7 @@ export class Adam{
873
1509
  }
874
1510
  }
875
1511
 
876
- // ---------------------- AdamW Optimizer ----------------------
1512
+ // ---------------------- AdamW ----------------------
877
1513
  export class AdamW {
878
1514
  constructor(params, options = {}) {
879
1515
  const {
@@ -983,9 +1619,9 @@ export class LION {
983
1619
  this.params = params;
984
1620
 
985
1621
  const {
986
- lr = 0.0001, // Lions typically uses smaller LR
987
- beta1 = 0.9, // First moment decay
988
- beta2 = 0.99, // Second moment decay
1622
+ lr = 0.0001,
1623
+ beta1 = 0.9,
1624
+ beta2 = 0.99,
989
1625
  weight_decay = 0, // L2 regularization
990
1626
  eps = 1e-8 // Numerical stability
991
1627
  } = options;
@@ -1264,7 +1900,7 @@ export class SiLU {
1264
1900
  }
1265
1901
 
1266
1902
 
1267
- // ---------------------- BatchNorm2D ----------------------
1903
+ // ---------------------- BatchNorm2D (BETA) ----------------------
1268
1904
  export class BatchNorm2d {
1269
1905
  constructor(numFeatures, eps=1e-5, momentum=0.1, affine=true) {
1270
1906
  this.numFeatures = numFeatures;
@@ -1477,12 +2113,12 @@ export class BatchNorm2d {
1477
2113
  eval() { this.training = false; }
1478
2114
  }
1479
2115
 
1480
- // ---------------------- Model Save/Load ----------------------
2116
+ // ---------------------- Model Save/Load (BETA) ----------------------
1481
2117
  export function saveModel(model){
1482
2118
  if(!(model instanceof Sequential)) throw new Error("saveModel supports only Sequential");
1483
2119
  const weights=model.layers.map(layer=>({weights:layer.W||null,biases:layer.b||null}));
1484
2120
  return JSON.stringify(weights);
1485
- /* Didn't expect this to work /: */
2121
+ /* Didn't expect this to work */
1486
2122
  }
1487
2123
 
1488
2124
  export function loadModel(model,json){
@@ -1492,7 +2128,7 @@ export function loadModel(model,json){
1492
2128
  if(layer.W && weights[i].weights) layer.W=weights[i].weights;
1493
2129
  if(layer.b && weights[i].biases) layer.b=weights[i].biases;
1494
2130
  });
1495
- /* Didn't expect this to work /: */
2131
+ /* Didn't expect this to work */
1496
2132
  }
1497
2133
 
1498
2134
  // ---------------------- Advanced Utils ----------------------
@@ -1507,4 +2143,40 @@ export function reshape(tensor, rows, cols) {
1507
2143
  flat.slice(i*cols, i*cols + cols)
1508
2144
  );
1509
2145
  return out;
2146
+ }
2147
+
2148
/**
 * Flatten a 2D array into a row-major Float32Array.
 * The column count is taken from the first row, and values are rounded to
 * float32 precision on store.
 * @param {number[][]} matrix
 * @returns {Float32Array} Buffer of length rows * cols (empty for an empty matrix)
 */
export function toFloat32(matrix) {
  const rows = matrix.length;
  // An empty matrix used to throw on `matrix[0].length`.
  const cols = rows === 0 ? 0 : matrix[0].length;
  const flat = new Float32Array(rows * cols);
  for (let i = 0; i < rows; i++) {
    const row = matrix[i];
    const offset = i * cols;
    for (let j = 0; j < cols; j++) {
      flat[offset + j] = row[j];
    }
  }
  return flat;
}
2156
+
2157
/**
 * Rebuild a rows x cols nested array from a row-major flat buffer
 * (the inverse layout of toFloat32). Positions past the end of `flat`
 * read as undefined.
 * @param {ArrayLike<number>} flat - Row-major values
 * @param {number} rows
 * @param {number} cols
 * @returns {number[][]}
 */
export function fromFloat32(flat, rows, cols) {
  const out = Array(rows);
  for (let r = 0; r < rows; r++) {
    const base = r * cols;
    const line = Array(cols);
    for (let c = 0; c < cols; c++) {
      line[c] = flat[base + c];
    }
    out[r] = line;
  }
  return out;
}
2166
+
2167
/**
 * Matrix product of a (m x k) and b (k x n) computed on flat Float32Array buffers.
 * Inputs are flattened with toFloat32 and the result is rebuilt with
 * fromFloat32, so inputs and outputs are rounded to float32. Each dot product
 * itself is accumulated in a JS number before it is stored.
 * @param {number[][]} a - m x k
 * @param {number[][]} b - k x n
 * @returns {number[][]} m x n product ([] when `a` has no rows)
 * @throws {Error} If the inner dimensions differ
 */
export function fastDot(a, b) {
  const m = a.length;
  if (m === 0) return []; // previously threw on `a[0].length`

  const k = a[0].length;
  if (b.length !== k) {
    // Previously a mismatch produced NaN (too few rows) or silently ignored rows.
    throw new Error(`fastDot: shape mismatch, a is ${m}x${k} but b has ${b.length} rows`);
  }

  const n = k === 0 ? 0 : b[0].length;
  if (k === 0 || n === 0) {
    // Empty inner or outer dimension: the product has no terms (or no columns).
    return Array.from({ length: m }, () => Array(n).fill(0));
  }

  const aFlat = toFloat32(a);
  const bFlat = toFloat32(b);
  const res = new Float32Array(m * n);

  for (let i = 0; i < m; i++) {
    const aRow = i * k;
    const outRow = i * n;
    for (let j = 0; j < n; j++) {
      let sum = 0;
      for (let l = 0; l < k; l++) {
        sum += aFlat[aRow + l] * bFlat[l * n + j];
      }
      res[outRow + j] = sum;
    }
  }

  return fromFloat32(res, m, n);
}