catniff 0.4.0 → 0.4.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/backend.d.ts +2 -81
- package/dist/core.d.ts +2 -0
- package/dist/core.js +110 -437
- package/package.json +1 -1
package/dist/backend.d.ts
CHANGED

```diff
@@ -1,83 +1,4 @@
-import { Tensor
+import { Tensor } from "./core";
 export interface Backend {
-
-    prod?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
-    mean?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
-    max?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
-    min?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
-    softmax?(tensor: Tensor, dims?: number[] | number): Tensor;
-    add?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    sub?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    mul?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    pow?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    div?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    remainder?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    ge?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    le?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    gt?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    lt?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    eq?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    ne?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    logicalAnd?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    logicalOr?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    logicalXor?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    logicalNot?(self: Tensor): Tensor;
-    bitwiseAnd?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    bitwiseOr?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    bitwiseXor?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    bitwiseNot?(self: Tensor): Tensor;
-    bitwiseLeftShift?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    bitwiseRightShift?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    neg?(self: Tensor): Tensor;
-    reciprocal?(self: Tensor): Tensor;
-    square?(self: Tensor): Tensor;
-    abs?(self: Tensor): Tensor;
-    sign?(self: Tensor): Tensor;
-    sin?(self: Tensor): Tensor;
-    cos?(self: Tensor): Tensor;
-    tan?(self: Tensor): Tensor;
-    asin?(self: Tensor): Tensor;
-    acos?(self: Tensor): Tensor;
-    atan?(self: Tensor): Tensor;
-    atan2?(self: Tensor): Tensor;
-    sinh?(self: Tensor): Tensor;
-    cosh?(self: Tensor): Tensor;
-    asinh?(self: Tensor): Tensor;
-    acosh?(self: Tensor): Tensor;
-    atanh?(self: Tensor): Tensor;
-    deg2rad?(self: Tensor): Tensor;
-    rad2deg?(self: Tensor): Tensor;
-    sqrt?(self: Tensor): Tensor;
-    rsqrt?(self: Tensor): Tensor;
-    exp?(self: Tensor): Tensor;
-    exp2?(self: Tensor): Tensor;
-    expm1?(self: Tensor): Tensor;
-    log?(self: Tensor): Tensor;
-    log2?(self: Tensor): Tensor;
-    log10?(self: Tensor): Tensor;
-    log1p?(self: Tensor): Tensor;
-    relu?(self: Tensor): Tensor;
-    sigmoid?(self: Tensor): Tensor;
-    tanh?(self: Tensor): Tensor;
-    softplus?(self: Tensor): Tensor;
-    softsign?(self: Tensor): Tensor;
-    silu?(self: Tensor): Tensor;
-    mish?(self: Tensor): Tensor;
-    maximum?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    minimum?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    round?(self: Tensor): Tensor;
-    floor?(self: Tensor): Tensor;
-    ceil?(self: Tensor): Tensor;
-    trunc?(self: Tensor): Tensor;
-    frac?(self: Tensor): Tensor;
-    clip?(self: Tensor, min: number, max: number): Tensor;
-    erf?(self: Tensor): Tensor;
-    erfc?(self: Tensor): Tensor;
-    erfinv?(self: Tensor): Tensor;
-    dot?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    mm?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    bmm?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    mv?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    matmul?(self: Tensor, other: Tensor | TensorValue): Tensor;
-    to?(tensor: Tensor): Tensor;
+    transfer(tensor: Tensor): Tensor;
 }
```
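The `Backend` contract thus shrinks from roughly eighty optional per-op hooks to a single required `transfer` method. A minimal sketch of what a custom backend could look like under the new interface; the import paths, the `"mydevice"` name, and registration via `Tensor.backends.set(...)` are assumptions for illustration (core.js only shows the map being read with `Tensor.backends.get(device)`):

```ts
import { Tensor } from "catniff";        // assumed entry point of the package
import type { Backend } from "catniff";  // assumed re-export of dist/backend.d.ts

// Under 0.4.2 the only hook a backend must provide is `transfer`,
// which Tensor.prototype.to(device) calls to move a tensor onto this device.
const myBackend: Backend = {
    transfer(tensor: Tensor): Tensor {
        // Copy or convert the tensor's storage here; returning it unchanged
        // is the trivial "already on this device" case.
        return tensor;
    },
};

// Hypothetical registration: core.js resolves backends from a static Map,
// so a matching `set` call is assumed here.
Tensor.backends.set("mydevice", myBackend);
```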
package/dist/core.d.ts
CHANGED

```diff
@@ -46,6 +46,8 @@ export declare class Tensor {
     mean(dims?: number[] | number, keepDims?: boolean): Tensor;
     max(dims?: number[] | number, keepDims?: boolean): Tensor;
     min(dims?: number[] | number, keepDims?: boolean): Tensor;
+    var(dims?: number[] | number, keepDims?: boolean): Tensor;
+    std(dims?: number[] | number, keepDims?: boolean): Tensor;
     softmax(dims?: number[] | number): Tensor;
     add(other: TensorValue | Tensor): Tensor;
     sub(other: TensorValue | Tensor): Tensor;
```
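The two new declarations map to the `var` and `std` implementations added to `core.js` below, which compute the biased variance as E[x²] − (E[x])² and the standard deviation as its square root. A standalone sketch of the same arithmetic on a plain array (no catniff API involved), handy for sanity-checking the tensor results:

```ts
// Mirrors the new core.js definitions: var = mean(x^2) - mean(x)^2, std = sqrt(var).
// This is the population (biased) variance; no Bessel correction is applied.
function variance(values: number[]): number {
    const mean = values.reduce((sum, v) => sum + v, 0) / values.length;
    const meanOfSquares = values.reduce((sum, v) => sum + v * v, 0) / values.length;
    return meanOfSquares - mean * mean;
}

function std(values: number[]): number {
    return Math.sqrt(variance(values));
}

console.log(variance([1, 2, 3, 4])); // 1.25
console.log(std([1, 2, 3, 4]));      // ≈ 1.118
```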
package/dist/core.js
CHANGED

```diff
@@ -337,21 +337,22 @@ class Tensor {
     }
     // Tensor sum reduction
     sum(dims, keepDims = false) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sum) {
-            return backend.sum(this, dims, keepDims);
-        }
         if (typeof this.value === "number")
             return this;
-        if (typeof dims === "number") {
-            dims = [dims];
-        }
         if (typeof dims === "undefined") {
             dims = Array.from({ length: this.shape.length }, (_, index) => index);
         }
+        if (Array.isArray(dims)) {
+            // Sort in descending order
+            const sortedDims = dims.sort((a, b) => b - a);
+            let reducedThis = this;
+            for (let i = 0; i < sortedDims.length; i++) {
+                reducedThis = reducedThis.sum(sortedDims[i], true);
+            }
+            return keepDims ? reducedThis : reducedThis.squeeze(dims);
+        }
         // Dims that are reduced now have size-1
-        const outputShape = this.shape.map((dim, i) => dims
+        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
         const outputStrides = Tensor.getStrides(outputShape);
         const outputSize = Tensor.shapeToSize(outputShape);
         const outputValue = new Array(outputSize).fill(0);
@@ -368,7 +369,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords = coords.map((val, i) => dims
+            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Add into sum
@@ -395,21 +396,22 @@ class Tensor {
     }
     // Tensor product reduction
     prod(dims, keepDims = false) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.prod) {
-            return backend.prod(this, dims, keepDims);
-        }
         if (typeof this.value === "number")
             return this;
-        if (typeof dims === "number") {
-            dims = [dims];
-        }
         if (typeof dims === "undefined") {
             dims = Array.from({ length: this.shape.length }, (_, index) => index);
         }
+        if (Array.isArray(dims)) {
+            // Sort in descending order
+            const sortedDims = dims.sort((a, b) => b - a);
+            let reducedThis = this;
+            for (let i = 0; i < sortedDims.length; i++) {
+                reducedThis = reducedThis.prod(sortedDims[i], true);
+            }
+            return keepDims ? reducedThis : reducedThis.squeeze(dims);
+        }
         // Dims that are reduced now have size-1
-        const outputShape = this.shape.map((dim, i) => dims
+        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
         const outputStrides = Tensor.getStrides(outputShape);
         const outputSize = Tensor.shapeToSize(outputShape);
         const outputValue = new Array(outputSize).fill(1);
@@ -418,7 +420,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords = coords.map((val, i) => dims
+            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Multiply into product
@@ -437,7 +439,7 @@ class Tensor {
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                 const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords = coords.map((val, i) => dims
+                const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // Grad is the product of other elements of the same axis, which is product of all els divided by the current value
@@ -451,21 +453,22 @@ class Tensor {
     }
     // Tensor mean reduction
     mean(dims, keepDims = false) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.mean) {
-            return backend.mean(this, dims, keepDims);
-        }
         if (typeof this.value === "number")
             return this;
-        if (typeof dims === "number") {
-            dims = [dims];
-        }
         if (typeof dims === "undefined") {
             dims = Array.from({ length: this.shape.length }, (_, index) => index);
         }
+        if (Array.isArray(dims)) {
+            // Sort in descending order
+            const sortedDims = dims.sort((a, b) => b - a);
+            let reducedThis = this;
+            for (let i = 0; i < sortedDims.length; i++) {
+                reducedThis = reducedThis.mean(sortedDims[i], true);
+            }
+            return keepDims ? reducedThis : reducedThis.squeeze(dims);
+        }
         // Dims that are reduced now have size-1
-        const outputShape = this.shape.map((dim, i) => dims
+        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
         const outputStrides = Tensor.getStrides(outputShape);
         const outputSize = Tensor.shapeToSize(outputShape);
         const outputValue = new Array(outputSize).fill(0);
@@ -475,7 +478,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords = coords.map((val, i) => dims
+            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Calculate sum and contributors to the sum
@@ -500,7 +503,7 @@ class Tensor {
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                 const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords = coords.map((val, i) => dims
+                const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
                 // Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
@@ -514,21 +517,22 @@ class Tensor {
     }
     // Tensor maximum reduction
    max(dims, keepDims = false) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.max) {
-            return backend.max(this, dims, keepDims);
-        }
         if (typeof this.value === "number")
             return this;
-        if (typeof dims === "number") {
-            dims = [dims];
-        }
         if (typeof dims === "undefined") {
             dims = Array.from({ length: this.shape.length }, (_, index) => index);
         }
+        if (Array.isArray(dims)) {
+            // Sort in descending order
+            const sortedDims = dims.sort((a, b) => b - a);
+            let reducedThis = this;
+            for (let i = 0; i < sortedDims.length; i++) {
+                reducedThis = reducedThis.max(sortedDims[i], true);
+            }
+            return keepDims ? reducedThis : reducedThis.squeeze(dims);
+        }
         // Dims that are reduced now have size-1
-        const outputShape = this.shape.map((dim, i) => dims
+        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
         const outputStrides = Tensor.getStrides(outputShape);
         const outputSize = Tensor.shapeToSize(outputShape);
         const outputValue = new Array(outputSize).fill(-Infinity);
@@ -537,7 +541,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords = coords.map((val, i) => dims
+            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Get max over time
@@ -555,14 +559,25 @@ class Tensor {
         out.children.push(this);
         out.gradFn = () => {
             const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
+            const shareCounts = new Array(outputSize).fill(0);
+            const originalValue = this.value;
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                 const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords = coords.map((val, i) => dims
+                const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
-                //
-
+                // We collect how many elements share the same max value first
+                shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
+            }
+            for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
+                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                // Force 0 on reduced axes to collapse into size-1 dims
+                const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+                // Convert output coordinates to flat index
+                const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
+                // Here we share the grad between the elements that share the same max value
+                gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
             }
             const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
             Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
@@ -572,21 +587,22 @@ class Tensor {
     }
     // Tensor minimum reduction
     min(dims, keepDims = false) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.min) {
-            return backend.min(this, dims, keepDims);
-        }
         if (typeof this.value === "number")
             return this;
-        if (typeof dims === "number") {
-            dims = [dims];
-        }
         if (typeof dims === "undefined") {
             dims = Array.from({ length: this.shape.length }, (_, index) => index);
         }
+        if (Array.isArray(dims)) {
+            // Sort in descending order
+            const sortedDims = dims.sort((a, b) => b - a);
+            let reducedThis = this;
+            for (let i = 0; i < sortedDims.length; i++) {
+                reducedThis = reducedThis.min(sortedDims[i], true);
+            }
+            return keepDims ? reducedThis : reducedThis.squeeze(dims);
+        }
         // Dims that are reduced now have size-1
-        const outputShape = this.shape.map((dim, i) => dims
+        const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
         const outputStrides = Tensor.getStrides(outputShape);
         const outputSize = Tensor.shapeToSize(outputShape);
         const outputValue = new Array(outputSize).fill(Infinity);
@@ -595,7 +611,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const outCoords = coords.map((val, i) => dims
+            const outCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert output coordinates to flat index
             const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
             // Get min over time
@@ -613,14 +629,25 @@ class Tensor {
         out.children.push(this);
         out.gradFn = () => {
             const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
+            const shareCounts = new Array(outputSize).fill(0);
+            const originalValue = this.value;
+            for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
+                const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
+                // Force 0 on reduced axes to collapse into size-1 dims
+                const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+                // Convert output coordinates to flat index
+                const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
+                // We collect how many elements share the same min value first
+                shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
+            }
             for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
                 const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
                 // Force 0 on reduced axes to collapse into size-1 dims
-                const outCoords = coords.map((val, i) => dims
+                const outCoords = coords.map((val, i) => dims === i ? 0 : val);
                 // Convert output coordinates to flat index
                 const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
-                //
-                gradValue[realFlatIndex] = outputValue[outFlatIndex] ===
+                // Here we share the grad between the elements that share the same min value
+                gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
             }
             const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
             Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
@@ -628,23 +655,34 @@ class Tensor {
         }
         return keepDims ? out : out.squeeze(dims);
     }
+    // Tensor variance reduction
+    var(dims, keepDims = false) {
+        const meanXSquared = this.square().mean(dims, keepDims);
+        const meanXSquaredExpanded = this.mean(dims, keepDims).square();
+        return meanXSquared.sub(meanXSquaredExpanded);
+    }
+    // Tensor standard deviation reduction
+    std(dims, keepDims = false) {
+        return this.var(dims, keepDims).sqrt();
+    }
     // Tensor product reduction
     softmax(dims) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.softmax) {
-            return backend.softmax(this, dims);
-        }
         if (typeof this.value === "number")
             return this;
-        if (typeof dims === "number") {
-            dims = [dims];
-        }
         if (typeof dims === "undefined") {
             dims = Array.from({ length: this.shape.length }, (_, index) => index);
         }
+        if (Array.isArray(dims)) {
+            // Sort in descending order
+            const sortedDims = dims.sort((a, b) => b - a);
+            let reducedThis = this;
+            for (let i = 0; i < sortedDims.length; i++) {
+                reducedThis = reducedThis.softmax(sortedDims[i]);
+            }
+            return reducedThis;
+        }
         // Dims that are reduced now have size-1
-        const expSumShape = this.shape.map((dim, i) => dims
+        const expSumShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
         const expSumStrides = Tensor.getStrides(expSumShape);
         const expSumSize = Tensor.shapeToSize(expSumShape);
         const expSumValue = new Array(expSumSize).fill(0);
@@ -656,7 +694,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const expSumCoords = coords.map((val, i) => dims
+            const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert exp sum coordinates to flat index
             const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
             // Add e^x to the sum cache
@@ -666,7 +704,7 @@ class Tensor {
         for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
             const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
             // Force 0 on reduced axes to collapse into size-1 dims
-            const expSumCoords = coords.map((val, i) => dims
+            const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
             // Convert exp sum coordinates to flat index
             const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
             // Calculate e^xi / sum
@@ -699,488 +737,228 @@ class Tensor {
     }
     // Tensor element-wise addition
     add(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.add) {
-            return backend.add(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
     }
     // Tensor element-wise subtraction
     sub(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sub) {
-            return backend.sub(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
     }
     subtract = this.sub;
     // Tensor element-wise multiplication
     mul(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.mul) {
-            return backend.mul(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
     }
     multiply = this.mul;
     // Tensor element-wise power
     pow(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.pow) {
-            return backend.pow(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
     }
     // Tensor element-wise division
     div(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.div) {
-            return backend.div(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.square())));
     }
     divide = this.div;
     // Tensor element-wise modulo
     remainder(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.remainder) {
-            return backend.remainder(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a % b);
     }
     // Tensor element-wise greater or equal comparison
     ge(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.ge) {
-            return backend.ge(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
     }
     greaterEqual = this.ge;
     // Tensor element-wise less or equal comparison
     le(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.le) {
-            return backend.le(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
     }
     lessEqual = this.le;
     // Tensor element-wise greater-than comparison
     gt(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.gt) {
-            return backend.gt(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
     }
     greater = this.gt;
     // Tensor element-wise less-than comparison
     lt(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.lt) {
-            return backend.lt(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
     }
     less = this.lt;
     // Tensor element-wise equality comparison
     eq(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.eq) {
-            return backend.eq(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
     }
     equal = this.eq;
     // Tensor element-wise not equality comparison
     ne(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.ne) {
-            return backend.ne(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a !== b ? 1 : 0);
     }
     notEqual = this.ne;
     // Tensor element-wise logical and
     logicalAnd(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.logicalAnd) {
-            return backend.logicalAnd(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
     }
     // Tensor element-wise logical or
     logicalOr(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.logicalOr) {
-            return backend.logicalOr(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
     }
     // Tensor element-wise logical xor
     logicalXor(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.logicalXor) {
-            return backend.logicalXor(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
     }
     // Tensor element-wise logical not
     logicalNot() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.logicalNot) {
-            return backend.logicalNot(this);
-        }
         return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
     }
     // Tensor element-wise bitwise and
     bitwiseAnd(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bitwiseAnd) {
-            return backend.bitwiseAnd(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a & b);
     }
     // Tensor element-wise bitwise or
     bitwiseOr(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bitwiseOr) {
-            return backend.bitwiseOr(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a | b);
     }
     // Tensor element-wise bitwise xor
     bitwiseXor(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bitwiseXor) {
-            return backend.bitwiseXor(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a ^ b);
     }
     // Tensor element-wise bitwise not
     bitwiseNot() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bitwiseNot) {
-            return backend.bitwiseNot(this);
-        }
         return this.elementWiseSelfDAG((a) => ~a);
     }
     // Tensor element-wise left shift
     bitwiseLeftShift(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bitwiseLeftShift) {
-            return backend.bitwiseLeftShift(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a << b);
     }
     // Tensor element-wise right shift
     bitwiseRightShift(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bitwiseRightShift) {
-            return backend.bitwiseRightShift(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => a >> b);
     }
     // Tensor element-wise negation
     neg() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.neg) {
-            return backend.neg(this);
-        }
         return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
     }
     negative = this.neg;
     // Tensor element-wise reciprocal
     reciprocal() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.reciprocal) {
-            return backend.reciprocal(this);
-        }
         return this.elementWiseSelfDAG((a) => 1 / a, (self, outGrad) => outGrad.mul(self.pow(-2).neg()));
     }
     // Tensor element-wise square
     square() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.square) {
-            return backend.square(this);
-        }
         return this.elementWiseSelfDAG((a) => a * a, (self, outGrad) => outGrad.mul(self.mul(2)));
     }
     // Tensor element-wise absolute
     abs() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.abs) {
-            return backend.abs(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
     }
     absolute = this.abs;
     // Tensor element-wise sign function
     sign() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sign) {
-            return backend.sign(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.sign(a));
     }
     // Tensor element-wise sin
     sin() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sin) {
-            return backend.sin(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
     }
     // Tensor element-wise cos
     cos() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.cos) {
-            return backend.cos(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
     }
     // Tensor element-wise tan
     tan() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.tan) {
-            return backend.tan(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().square().add(1)));
     }
     // Tensor element-wise asin
     asin() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.asin) {
-            return backend.asin(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()));
     }
     arcsin = this.asin;
     // Tensor element-wise acos
     acos() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.acos) {
-            return backend.acos(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()).neg());
     }
     arccos = this.acos;
     // Tensor element-wise atan
     atan() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.atan) {
-            return backend.atan(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.square().add(1)));
     }
     arctan = this.atan;
     // Tensor element-wise atan2
     atan2(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.atan2) {
-            return backend.atan2(this);
-        }
         return this.elementWiseABDAG(other, (a, b) => Math.atan2(a, b), (self, other, outGrad) => outGrad.mul(other.div(self.square().add(other.square()))), (self, other, outGrad) => outGrad.mul(self.neg().div(self.square().add(other.square()))));
     }
     arctan2 = this.atan2;
     // Tensor element-wise sinh
     sinh() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sinh) {
-            return backend.sinh(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
     }
     // Tensor element-wise cosh
     cosh() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.cosh) {
-            return backend.cosh(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
     }
     // Tensor element-wise asinh
     asinh() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.asinh) {
-            return backend.asinh(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.square().add(1).sqrt()));
     }
     arcsinh = this.asinh;
     // Tensor element-wise acosh
     acosh() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.acosh) {
-            return backend.acosh(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
     }
     arccosh = this.acosh;
     // Tensor element-wise atanh
     atanh() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.atanh) {
-            return backend.atanh(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.square().neg().add(1)));
     }
     arctanh = this.atanh;
     // Tensor element-wise degree to radian
     deg2rad() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.deg2rad) {
-            return backend.deg2rad(this);
-        }
         return this.elementWiseSelfDAG((a) => a * (Math.PI / 180), (self, outGrad) => outGrad.mul(Math.PI / 180));
     }
     // Tensor element-wise radian to degree
     rad2deg() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.rad2deg) {
-            return backend.rad2deg(this);
-        }
         return this.elementWiseSelfDAG((a) => a / (Math.PI / 180), (self, outGrad) => outGrad.div(Math.PI / 180));
     }
     // Tensor element-wise square root
     sqrt() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sqrt) {
-            return backend.sqrt(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
     }
     // Tensor element-wise reciprocal of square root
     rsqrt() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.rsqrt) {
-            return backend.rsqrt(this);
-        }
         return this.elementWiseSelfDAG((a) => 1 / Math.sqrt(a), (self, outGrad) => outGrad.mul(self.pow(-1.5).mul(-0.5)));
     }
     // Tensor element-wise e^x
     exp() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.exp) {
-            return backend.exp(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp()));
     }
     // Tensor element-wise 2^x
     exp2() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.exp2) {
-            return backend.exp2(this);
-        }
         return this.elementWiseSelfDAG((a) => 2 ** a, (self, outGrad) => outGrad.mul(self.exp2().mul(Math.log(2))));
     }
     // Tensor element-wise e^x - 1
     expm1() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.expm1) {
-            return backend.expm1(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.expm1(a), (self, outGrad) => outGrad.mul(self.exp()));
     }
     // Tensor element-wise natural log
     log() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.log) {
-            return backend.log(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self));
    }
     // Tensor element-wise log2
     log2() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.log2) {
-            return backend.log2(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2))));
     }
     // Tensor element-wise log10
     log10() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.log10) {
-            return backend.log10(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10))));
     }
     // Tensor element-wise log(1+x)
     log1p() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.log1p) {
-            return backend.log1p(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1)));
     }
     // Tensor element-wise relu
     relu() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.relu) {
-            return backend.relu(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.gt(0)));
     }
     // Tensor element-wise sigmoid
     sigmoid() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.sigmoid) {
-            return backend.sigmoid(this);
-        }
         return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
             const sig = self.sigmoid();
             return outGrad.mul(sig).mul(sig.neg().add(1));
@@ -1188,38 +966,18 @@ class Tensor {
     }
     // Tensor element-wise tanh
     tanh() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.tanh) {
-            return backend.tanh(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().square().neg().add(1)));
     }
     // Tensor element-wise softplus
     softplus() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.softplus) {
-            return backend.softplus(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.log1p(Math.exp(a)), (self, outGrad) => outGrad.mul(self.sigmoid()));
     }
     // Tensor element-wise softsign
     softsign() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.softsign) {
-            return backend.softsign(this);
-        }
         return this.elementWiseSelfDAG((a) => a / (1 + Math.abs(a)), (self, outGrad) => outGrad.div(self.abs().add(1).square()));
     }
     // Tensor element-wise silu (swish)
     silu() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.silu) {
-            return backend.silu(this);
-        }
         return this.elementWiseSelfDAG((a) => a / (1 + Math.exp(-a)), (self, outGrad) => {
             const sig = self.sigmoid();
             return outGrad.mul(sig.add(self.mul(sig).mul(sig.neg().add(1))));
@@ -1227,11 +985,6 @@ class Tensor {
     }
     // Tensor element-wise mish
     mish() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.mish) {
-            return backend.mish(this);
-        }
         return this.elementWiseSelfDAG((a) => a * Math.tanh(Math.log1p(Math.exp(a))), (self, outGrad) => {
             const tanhSoftPlus = self.exp().add(1).log().tanh();
             // tanh(softplus(x)) + x * (1 - tanh²(softplus(x))) * sigmoid(x)
@@ -1241,103 +994,48 @@ class Tensor {
     }
     // Tensor element-wise maximum
     maximum(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.maximum) {
-            return backend.maximum(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
     }
     // Tensor element-wise minimum
     minimum(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.minimum) {
-            return backend.minimum(this, other);
-        }
         return this.elementWiseABDAG(other, (a, b) => Math.min(a, b), (self, other, outGrad) => outGrad.mul(self.lt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.lt(self).add(other.eq(self).mul(0.5))));
     }
     // Tensor element-wise round
     round() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.round) {
-            return backend.round(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.round(a));
     }
     // Tensor element-wise floor
     floor() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.floor) {
-            return backend.floor(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.floor(a));
     }
     // Tensor element-wise ceil
     ceil() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.ceil) {
-            return backend.ceil(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.ceil(a));
     }
     // Tensor element-wise truncation
     trunc() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.trunc) {
-            return backend.trunc(this);
-        }
         return this.elementWiseSelfDAG((a) => Math.trunc(a));
     }
     fix = this.trunc;
     // Tensor element-wise fraction portion
     frac() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.frac) {
-            return backend.frac(this);
-        }
         return this.elementWiseSelfDAG((a) => a - Math.floor(a));
     }
     // Tensor element-wise clip and clamp
     clip(min, max) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.clip) {
-            return backend.clip(this, min, max);
-        }
         return this.elementWiseSelfDAG((a) => Math.max(min, Math.min(max, a)), (self, outGrad) => outGrad.mul(self.ge(min).mul(self.le(max))));
     }
     clamp = this.clip;
     // Tensor element-wise error function
     erf() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.erf) {
-            return backend.erf(this);
-        }
         return this.elementWiseSelfDAG((a) => (0, utils_1.erf)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI))));
     }
     // Tensor element-wise complementary error function
     erfc() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.erfc) {
-            return backend.erfc(this);
-        }
         return this.elementWiseSelfDAG((a) => (0, utils_1.erfc)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI)).neg()));
     }
     // Tensor element-wise inverse error function
     erfinv() {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.erfinv) {
-            return backend.erfinv(this);
-        }
         return this.elementWiseSelfDAG((a) => (0, utils_1.erfinv)(a), (self, outGrad) => outGrad.mul(self.erfinv().square().exp().mul(Math.sqrt(Math.PI) / 2)));
     }
     // Transpose
@@ -1379,11 +1077,6 @@ class Tensor {
     }
     // 1D tensor dot product
     dot(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.dot) {
-            return backend.dot(this, other);
-        }
         other = Tensor.forceTensor(other);
         // Verify 1D shape
         if (this.shape.length !== 1 || other.shape.length !== 1) {
@@ -1422,11 +1115,6 @@ class Tensor {
     }
     // Matrix multiplication
     mm(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.mm) {
-            return backend.mm(this, other);
-        }
         other = Tensor.forceTensor(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 2) {
@@ -1482,11 +1170,6 @@ class Tensor {
     }
     // Batched 3D tensor matmul
     bmm(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.bmm) {
-            return backend.bmm(this, other);
-        }
         other = Tensor.forceTensor(other);
         // Verify 3D shape
         if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
@@ -1545,11 +1228,6 @@ class Tensor {
     }
     // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
     mv(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.mv) {
-            return backend.mv(this, other);
-        }
         other = Tensor.forceTensor(other);
         // Verify 2D shape
         if (this.shape.length !== 2 || other.shape.length !== 1) {
@@ -1559,11 +1237,6 @@ class Tensor {
     }
     // General matrix multiplication with different shapes
     matmul(other) {
-        // Use backend of tensor's device if available, or else fallback to cpu
-        const backend = Tensor.backends.get(this.device);
-        if (backend && backend.matmul) {
-            return backend.matmul(this, other);
-        }
         other = Tensor.forceTensor(other);
         const isThis1D = this.shape.length === 1;
         const isOther1D = other.shape.length === 1;
@@ -1875,10 +1548,10 @@ class Tensor {
     // Op to transfer tensor to another device
     to(device) {
         const backend = Tensor.backends.get(device);
-        if (backend && backend.
-            return backend.
+        if (backend && backend.transfer) {
+            return backend.transfer(this);
         }
-        throw new Error(`No device found to transfer tensor to or
+        throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
     }
 }
 exports.Tensor = Tensor;
```
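Beyond dropping the per-op backend dispatch, the notable behavioural changes in this file are that multi-axis reductions are now performed one axis at a time (axes sorted in descending order, with a final squeeze when keepDims is false), and that the `max`/`min` backward pass now counts how many inputs are tied at the extremum and splits the incoming gradient evenly among them. A small standalone sketch of that tie-splitting rule for a 1-D max, mirroring the two-pass `shareCounts` logic above (plain arrays, independent of the Tensor API):

```ts
// Gradient of y = max(x) with respect to x, sharing the gradient among tied maxima.
function maxGrad(x: number[], upstreamGrad: number): number[] {
    const maxVal = Math.max(...x);
    // Pass 1: count how many elements attain the maximum (shareCounts in core.js).
    const shareCount = x.filter((v) => v === maxVal).length;
    // Pass 2: every tied element receives an equal slice of the upstream gradient.
    return x.map((v) => (v === maxVal ? upstreamGrad / shareCount : 0));
}

console.log(maxGrad([1, 3, 3, 2], 1)); // [0, 0.5, 0.5, 0]
```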