catniff 0.2.14 → 0.2.16
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/core.d.ts +1 -0
- package/dist/core.js +194 -73
- package/package.json +1 -1
package/dist/core.d.ts
CHANGED
|
@@ -43,6 +43,7 @@ export declare class Tensor {
|
|
|
43
43
|
mean(dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
44
44
|
max(dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
45
45
|
min(dims?: number[] | number, keepDims?: boolean): Tensor;
|
|
46
|
+
softmax(dims?: number[] | number): Tensor;
|
|
46
47
|
add(other: TensorValue | Tensor): Tensor;
|
|
47
48
|
sub(other: TensorValue | Tensor): Tensor;
|
|
48
49
|
subtract: (other: TensorValue | Tensor) => Tensor;
|
package/dist/core.js
CHANGED
|
@@ -355,14 +355,12 @@ class Tensor {
|
|
|
355
355
|
gradValue = new Array(originalSize).fill(0);
|
|
356
356
|
}
|
|
357
357
|
// Calculate new value after sum
|
|
358
|
-
for (let
|
|
359
|
-
const coords = Tensor.indexToCoords(
|
|
358
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
359
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
360
360
|
// Force 0 on reduced axes to collapse into size-1 dims
|
|
361
361
|
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
362
362
|
// Convert output coordinates to flat index
|
|
363
363
|
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
364
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
365
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
366
364
|
// Add into sum
|
|
367
365
|
outputValue[outFlatIndex] += this.value[realFlatIndex];
|
|
368
366
|
// Mark for gradient if needed
|
|
@@ -402,14 +400,12 @@ class Tensor {
|
|
|
402
400
|
const outputValue = new Array(outputSize).fill(1);
|
|
403
401
|
const originalSize = Tensor.shapeToSize(this.shape);
|
|
404
402
|
// Calculate new value after multiplying
|
|
405
|
-
for (let
|
|
406
|
-
const coords = Tensor.indexToCoords(
|
|
403
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
404
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
407
405
|
// Force 0 on reduced axes to collapse into size-1 dims
|
|
408
406
|
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
409
407
|
// Convert output coordinates to flat index
|
|
410
408
|
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
411
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
412
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
413
409
|
// Multiply into product
|
|
414
410
|
outputValue[outFlatIndex] *= this.value[realFlatIndex];
|
|
415
411
|
}
|
|
@@ -419,21 +415,19 @@ class Tensor {
|
|
|
419
415
|
});
|
|
420
416
|
// Set up gradient if needed
|
|
421
417
|
if (this.requiresGrad) {
|
|
422
|
-
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
423
|
-
for (let index = 0; index < originalSize; index++) {
|
|
424
|
-
const coords = Tensor.indexToCoords(index, this.strides);
|
|
425
|
-
// Force 0 on reduced axes to collapse into size-1 dims
|
|
426
|
-
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
427
|
-
// Convert output coordinates to flat index
|
|
428
|
-
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
429
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
430
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
431
|
-
// Grad is the product of other elements of the same axis, which is product of all els divided by the current value
|
|
432
|
-
gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
|
|
433
|
-
}
|
|
434
418
|
out.requiresGrad = true;
|
|
435
419
|
out.children.push(this);
|
|
436
420
|
out.gradFn = () => {
|
|
421
|
+
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
422
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
423
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
424
|
+
// Force 0 on reduced axes to collapse into size-1 dims
|
|
425
|
+
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
426
|
+
// Convert output coordinates to flat index
|
|
427
|
+
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
428
|
+
// Grad is the product of other elements of the same axis, which is product of all els divided by the current value
|
|
429
|
+
gradValue[realFlatIndex] = outputValue[outFlatIndex] / this.value[realFlatIndex];
|
|
430
|
+
}
|
|
437
431
|
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
|
|
438
432
|
Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
|
|
439
433
|
};
|
|
@@ -458,14 +452,12 @@ class Tensor {
|
|
|
458
452
|
const outputFeeders = new Array(outputSize).fill(0);
|
|
459
453
|
const originalSize = Tensor.shapeToSize(this.shape);
|
|
460
454
|
// Calculate sums and how many elements contribute to specific positions
|
|
461
|
-
for (let
|
|
462
|
-
const coords = Tensor.indexToCoords(
|
|
455
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
456
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
463
457
|
// Force 0 on reduced axes to collapse into size-1 dims
|
|
464
458
|
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
465
459
|
// Convert output coordinates to flat index
|
|
466
460
|
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
467
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
468
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
469
461
|
// Calculate sum and contributors to the sum
|
|
470
462
|
outputValue[outFlatIndex] += this.value[realFlatIndex];
|
|
471
463
|
outputFeeders[outFlatIndex]++;
|
|
@@ -480,22 +472,20 @@ class Tensor {
|
|
|
480
472
|
});
|
|
481
473
|
// Set up gradient if needed
|
|
482
474
|
if (this.requiresGrad) {
|
|
483
|
-
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
484
|
-
// Calculate grad by assigning 1 divided by the number of contributors to the position
|
|
485
|
-
for (let index = 0; index < originalSize; index++) {
|
|
486
|
-
const coords = Tensor.indexToCoords(index, this.strides);
|
|
487
|
-
// Force 0 on reduced axes to collapse into size-1 dims
|
|
488
|
-
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
489
|
-
// Convert output coordinates to flat index
|
|
490
|
-
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
491
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
492
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
493
|
-
// Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
|
|
494
|
-
gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
|
|
495
|
-
}
|
|
496
475
|
out.requiresGrad = true;
|
|
497
476
|
out.children.push(this);
|
|
498
477
|
out.gradFn = () => {
|
|
478
|
+
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
479
|
+
// Calculate grad by assigning 1 divided by the number of contributors to the position
|
|
480
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
481
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
482
|
+
// Force 0 on reduced axes to collapse into size-1 dims
|
|
483
|
+
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
484
|
+
// Convert output coordinates to flat index
|
|
485
|
+
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
486
|
+
// Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
|
|
487
|
+
gradValue[realFlatIndex] = 1 / outputFeeders[outFlatIndex];
|
|
488
|
+
}
|
|
499
489
|
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
|
|
500
490
|
Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
|
|
501
491
|
};
|
|
@@ -519,14 +509,12 @@ class Tensor {
|
|
|
519
509
|
const outputValue = new Array(outputSize).fill(-Infinity);
|
|
520
510
|
const originalSize = Tensor.shapeToSize(this.shape);
|
|
521
511
|
// Calculate maximum values of axes
|
|
522
|
-
for (let
|
|
523
|
-
const coords = Tensor.indexToCoords(
|
|
512
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
513
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
524
514
|
// Force 0 on reduced axes to collapse into size-1 dims
|
|
525
515
|
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
526
516
|
// Convert output coordinates to flat index
|
|
527
517
|
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
528
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
529
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
530
518
|
// Get max over time
|
|
531
519
|
if (this.value[realFlatIndex] > outputValue[outFlatIndex]) {
|
|
532
520
|
outputValue[outFlatIndex] = this.value[realFlatIndex];
|
|
@@ -538,21 +526,19 @@ class Tensor {
|
|
|
538
526
|
});
|
|
539
527
|
// Set up gradient if needed
|
|
540
528
|
if (this.requiresGrad) {
|
|
541
|
-
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
542
|
-
for (let index = 0; index < originalSize; index++) {
|
|
543
|
-
const coords = Tensor.indexToCoords(index, this.strides);
|
|
544
|
-
// Force 0 on reduced axes to collapse into size-1 dims
|
|
545
|
-
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
546
|
-
// Convert output coordinates to flat index
|
|
547
|
-
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
548
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
549
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
550
|
-
// Calculate grad by checking if a positon holds a value equal to the max value
|
|
551
|
-
gradValue[realFlatIndex] = outputValue[outFlatIndex] === this.value[realFlatIndex] ? 1 : 0;
|
|
552
|
-
}
|
|
553
529
|
out.requiresGrad = true;
|
|
554
530
|
out.children.push(this);
|
|
555
531
|
out.gradFn = () => {
|
|
532
|
+
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
533
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
534
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
535
|
+
// Force 0 on reduced axes to collapse into size-1 dims
|
|
536
|
+
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
537
|
+
// Convert output coordinates to flat index
|
|
538
|
+
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
539
|
+
// Calculate grad by checking if a positon holds a value equal to the max value
|
|
540
|
+
gradValue[realFlatIndex] = outputValue[outFlatIndex] === this.value[realFlatIndex] ? 1 : 0;
|
|
541
|
+
}
|
|
556
542
|
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
|
|
557
543
|
Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
|
|
558
544
|
};
|
|
@@ -576,14 +562,12 @@ class Tensor {
|
|
|
576
562
|
const outputValue = new Array(outputSize).fill(Infinity);
|
|
577
563
|
const originalSize = Tensor.shapeToSize(this.shape);
|
|
578
564
|
// Calculate minimum values of axes
|
|
579
|
-
for (let
|
|
580
|
-
const coords = Tensor.indexToCoords(
|
|
565
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
566
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
581
567
|
// Force 0 on reduced axes to collapse into size-1 dims
|
|
582
568
|
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
583
569
|
// Convert output coordinates to flat index
|
|
584
570
|
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
585
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
586
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
587
571
|
// Get min over time
|
|
588
572
|
if (this.value[realFlatIndex] < outputValue[outFlatIndex]) {
|
|
589
573
|
outputValue[outFlatIndex] = this.value[realFlatIndex];
|
|
@@ -595,27 +579,89 @@ class Tensor {
|
|
|
595
579
|
});
|
|
596
580
|
// Set up gradient if needed
|
|
597
581
|
if (this.requiresGrad) {
|
|
598
|
-
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
599
|
-
for (let index = 0; index < originalSize; index++) {
|
|
600
|
-
const coords = Tensor.indexToCoords(index, this.strides);
|
|
601
|
-
// Force 0 on reduced axes to collapse into size-1 dims
|
|
602
|
-
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
603
|
-
// Convert output coordinates to flat index
|
|
604
|
-
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
605
|
-
// Accumulate, outFlatIndex should match multiple realFlatIndexes
|
|
606
|
-
const realFlatIndex = Tensor.coordsToIndex(coords, this.strides);
|
|
607
|
-
// Calculate grad by checking if a positon holds a value equal to the min value
|
|
608
|
-
gradValue[realFlatIndex] = outputValue[outFlatIndex] === this.value[realFlatIndex] ? 1 : 0;
|
|
609
|
-
}
|
|
610
582
|
out.requiresGrad = true;
|
|
611
583
|
out.children.push(this);
|
|
612
584
|
out.gradFn = () => {
|
|
585
|
+
const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
|
|
586
|
+
for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
|
|
587
|
+
const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
|
|
588
|
+
// Force 0 on reduced axes to collapse into size-1 dims
|
|
589
|
+
const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
590
|
+
// Convert output coordinates to flat index
|
|
591
|
+
const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
|
|
592
|
+
// Calculate grad by checking if a positon holds a value equal to the min value
|
|
593
|
+
gradValue[realFlatIndex] = outputValue[outFlatIndex] === this.value[realFlatIndex] ? 1 : 0;
|
|
594
|
+
}
|
|
613
595
|
const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
|
|
614
596
|
Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
|
|
615
597
|
};
|
|
616
598
|
}
|
|
617
599
|
return keepDims ? out : out.squeeze(dims);
|
|
618
600
|
}
|
|
601
|
+
// Tensor product reduction
|
|
602
|
+
softmax(dims) {
|
|
603
|
+
if (typeof this.value === "number")
|
|
604
|
+
return this;
|
|
605
|
+
if (typeof dims === "number") {
|
|
606
|
+
dims = [dims];
|
|
607
|
+
}
|
|
608
|
+
if (typeof dims === "undefined") {
|
|
609
|
+
dims = Array.from({ length: this.shape.length }, (_, index) => index);
|
|
610
|
+
}
|
|
611
|
+
// Dims that are reduced now have size-1
|
|
612
|
+
const expSumShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
|
|
613
|
+
const expSumStrides = Tensor.getStrides(expSumShape);
|
|
614
|
+
const expSumSize = Tensor.shapeToSize(expSumShape);
|
|
615
|
+
const expSumValue = new Array(expSumSize).fill(0);
|
|
616
|
+
const outputShape = this.shape;
|
|
617
|
+
const outputStrides = this.strides;
|
|
618
|
+
const outputSize = Tensor.shapeToSize(outputShape);
|
|
619
|
+
const outputValue = new Array(outputSize);
|
|
620
|
+
// Calculate sums of e^xi over axes
|
|
621
|
+
for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
|
|
622
|
+
const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
|
|
623
|
+
// Force 0 on reduced axes to collapse into size-1 dims
|
|
624
|
+
const expSumCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
625
|
+
// Convert exp sum coordinates to flat index
|
|
626
|
+
const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
|
|
627
|
+
// Add e^x to the sum cache
|
|
628
|
+
expSumValue[expSumFlatIndex] += Math.exp(this.value[realFlatIndex]);
|
|
629
|
+
}
|
|
630
|
+
// Calculate e^xi / sum over axes
|
|
631
|
+
for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
|
|
632
|
+
const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
|
|
633
|
+
// Force 0 on reduced axes to collapse into size-1 dims
|
|
634
|
+
const expSumCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
|
|
635
|
+
// Convert exp sum coordinates to flat index
|
|
636
|
+
const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
|
|
637
|
+
// Calculate e^xi / sum
|
|
638
|
+
outputValue[realFlatIndex] = Math.exp(this.value[realFlatIndex]) / expSumValue[expSumFlatIndex];
|
|
639
|
+
}
|
|
640
|
+
const out = new Tensor(outputValue, {
|
|
641
|
+
shape: outputShape,
|
|
642
|
+
strides: outputStrides
|
|
643
|
+
});
|
|
644
|
+
// Set up gradient if needed
|
|
645
|
+
if (this.requiresGrad) {
|
|
646
|
+
out.requiresGrad = true;
|
|
647
|
+
out.children.push(this);
|
|
648
|
+
out.gradFn = () => {
|
|
649
|
+
const upstreamGrad = out.grad.withGrad(false);
|
|
650
|
+
const softmaxOutput = out.withGrad(false);
|
|
651
|
+
// Compute element-wise product: ∂L/∂σᵢ × σᵢ
|
|
652
|
+
const gradTimesOutput = upstreamGrad.mul(softmaxOutput);
|
|
653
|
+
// Sum over softmax dimensions: Σᵢ(∂L/∂σᵢ × σᵢ)
|
|
654
|
+
const sumGradOutput = gradTimesOutput.sum(dims, true); // keepDims=true for broadcasting
|
|
655
|
+
// Apply softmax gradient formula:
|
|
656
|
+
// ∂L/∂zⱼ = (∂L/∂σⱼ × σⱼ) - (σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ))
|
|
657
|
+
const term1 = upstreamGrad.mul(softmaxOutput); // ∂L/∂σⱼ × σⱼ
|
|
658
|
+
const term2 = softmaxOutput.mul(sumGradOutput); // σⱼ × Σᵢ(∂L/∂σᵢ × σᵢ)
|
|
659
|
+
const localGrad = term1.sub(term2);
|
|
660
|
+
Tensor.addGrad(this, localGrad);
|
|
661
|
+
};
|
|
662
|
+
}
|
|
663
|
+
return out;
|
|
664
|
+
}
|
|
619
665
|
// Tensor element-wise addition
|
|
620
666
|
add(other) {
|
|
621
667
|
return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
|
|
@@ -1144,19 +1190,94 @@ class Tensor {
|
|
|
1144
1190
|
// General matrix multiplication with different shapes
|
|
1145
1191
|
matmul(other) {
|
|
1146
1192
|
other = Tensor.forceTensor(other);
|
|
1147
|
-
|
|
1193
|
+
const isThis1D = this.shape.length === 1;
|
|
1194
|
+
const isOther1D = other.shape.length === 1;
|
|
1195
|
+
if (isThis1D && isOther1D) {
|
|
1148
1196
|
return this.dot(other);
|
|
1149
1197
|
}
|
|
1150
|
-
else if (
|
|
1198
|
+
else if (isThis1D && other.shape.length === 2) {
|
|
1151
1199
|
return this.unsqueeze(0).mm(other).squeeze(0);
|
|
1152
1200
|
}
|
|
1153
|
-
else if (this.shape.length === 2 &&
|
|
1201
|
+
else if (this.shape.length === 2 && isOther1D) {
|
|
1154
1202
|
return this.mv(other);
|
|
1155
1203
|
}
|
|
1156
1204
|
else if (this.shape.length === 2 && other.shape.length === 2) {
|
|
1157
1205
|
return this.mm(other);
|
|
1158
1206
|
}
|
|
1159
|
-
|
|
1207
|
+
else if ((isThis1D && other.shape.length > 2) ||
|
|
1208
|
+
(isOther1D && this.shape.length > 2) ||
|
|
1209
|
+
(other.shape.length > 2 && this.shape.length > 2)) {
|
|
1210
|
+
// Append/prepend dims if needed
|
|
1211
|
+
const self = isThis1D ? this.unsqueeze(0) : this;
|
|
1212
|
+
other = isOther1D ? other.unsqueeze(1) : other;
|
|
1213
|
+
// Padding
|
|
1214
|
+
const [selfStrides, otherStrides, selfShape, otherShape] = Tensor.padShape(self.strides, other.strides, self.shape, other.shape);
|
|
1215
|
+
const lastDim = selfShape.length - 1;
|
|
1216
|
+
// Prepare data for broadcasting
|
|
1217
|
+
const batchA = self.value;
|
|
1218
|
+
const batchB = other.value;
|
|
1219
|
+
const batchARows = selfShape[lastDim - 1];
|
|
1220
|
+
const batchACols = selfShape[lastDim];
|
|
1221
|
+
const batchBRows = otherShape[lastDim - 1];
|
|
1222
|
+
const batchBCols = otherShape[lastDim];
|
|
1223
|
+
// Verify if can do matmul
|
|
1224
|
+
if (batchACols !== batchBRows)
|
|
1225
|
+
throw new Error("Invalid matrices shape for multiplication");
|
|
1226
|
+
// Prepare shape, strides, size info, but more importantly the offset-related data to loop through the outer, non-matrix dims
|
|
1227
|
+
// Self and other's offset data
|
|
1228
|
+
const selfOffsetShape = selfShape.slice(0, -2);
|
|
1229
|
+
const otherOffsetShape = otherShape.slice(0, -2);
|
|
1230
|
+
const selfOffsetStrides = selfStrides.slice(0, -2);
|
|
1231
|
+
const otherOffsetStrides = otherStrides.slice(0, -2);
|
|
1232
|
+
// The output's offset data
|
|
1233
|
+
const offsetShape = Tensor.broadcastShapes(selfOffsetShape, otherOffsetShape);
|
|
1234
|
+
const offsetSize = Tensor.shapeToSize(offsetShape);
|
|
1235
|
+
const offsetStrides = Tensor.getStrides(offsetShape);
|
|
1236
|
+
// Output shape, strides, size, value
|
|
1237
|
+
const outputShape = [...offsetShape, batchARows, batchBCols];
|
|
1238
|
+
const outputStrides = Tensor.getStrides(outputShape);
|
|
1239
|
+
const outputSize = Tensor.shapeToSize(outputShape);
|
|
1240
|
+
const outputValue = new Array(outputSize).fill(0);
|
|
1241
|
+
// Loop through outer dims and do matmul on two outer-most dims
|
|
1242
|
+
for (let index = 0; index < offsetSize; index++) {
|
|
1243
|
+
const coords = Tensor.indexToCoords(index, offsetStrides);
|
|
1244
|
+
const offset = Tensor.coordsToIndex(coords, outputStrides.slice(0, -2));
|
|
1245
|
+
const selfOffset = Tensor.coordsToUnbroadcastedIndex(coords, selfOffsetShape, selfOffsetStrides);
|
|
1246
|
+
const otherOffset = Tensor.coordsToUnbroadcastedIndex(coords, otherOffsetShape, otherOffsetStrides);
|
|
1247
|
+
for (let i = 0; i < batchARows; i++) {
|
|
1248
|
+
for (let j = 0; j < batchBCols; j++) {
|
|
1249
|
+
for (let k = 0; k < batchACols; k++) {
|
|
1250
|
+
const outputIdx = offset + i * outputStrides[lastDim - 1] + j * outputStrides[lastDim];
|
|
1251
|
+
const selfIdx = selfOffset + i * selfStrides[lastDim - 1] + k * selfStrides[lastDim];
|
|
1252
|
+
const otherIdx = otherOffset + k * otherStrides[lastDim - 1] + j * otherStrides[lastDim];
|
|
1253
|
+
outputValue[outputIdx] += batchA[selfIdx] * batchB[otherIdx];
|
|
1254
|
+
}
|
|
1255
|
+
}
|
|
1256
|
+
}
|
|
1257
|
+
}
|
|
1258
|
+
const out = new Tensor(outputValue, { shape: outputShape, strides: outputStrides });
|
|
1259
|
+
if (this.requiresGrad) {
|
|
1260
|
+
out.requiresGrad = true;
|
|
1261
|
+
out.children.push(this);
|
|
1262
|
+
}
|
|
1263
|
+
if (other.requiresGrad) {
|
|
1264
|
+
out.requiresGrad = true;
|
|
1265
|
+
out.children.push(other);
|
|
1266
|
+
}
|
|
1267
|
+
if (out.requiresGrad) {
|
|
1268
|
+
out.gradFn = () => {
|
|
1269
|
+
other = other;
|
|
1270
|
+
const outGrad = out.grad.withGrad(false);
|
|
1271
|
+
const selfNoGrad = self.withGrad(false);
|
|
1272
|
+
const otherNoGrad = other.withGrad(false);
|
|
1273
|
+
if (this.requiresGrad)
|
|
1274
|
+
Tensor.addGrad(this, outGrad.matmul(otherNoGrad.transpose(lastDim - 1, lastDim)));
|
|
1275
|
+
if (other.requiresGrad)
|
|
1276
|
+
Tensor.addGrad(other, selfNoGrad.transpose(lastDim - 1, lastDim).matmul(outGrad));
|
|
1277
|
+
};
|
|
1278
|
+
}
|
|
1279
|
+
return out;
|
|
1280
|
+
}
|
|
1160
1281
|
throw new Error(`Shapes [${this.shape}] and [${other.shape}] are not supported`);
|
|
1161
1282
|
}
|
|
1162
1283
|
// Utility to create a new tensor filled with a number
|