catniff 0.4.0 → 0.4.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/backend.d.ts CHANGED
@@ -1,83 +1,4 @@
- import { Tensor, TensorValue } from "./core";
+ import { Tensor } from "./core";
  export interface Backend {
- sum?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
- prod?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
- mean?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
- max?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
- min?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
- softmax?(tensor: Tensor, dims?: number[] | number): Tensor;
- add?(self: Tensor, other: Tensor | TensorValue): Tensor;
- sub?(self: Tensor, other: Tensor | TensorValue): Tensor;
- mul?(self: Tensor, other: Tensor | TensorValue): Tensor;
- pow?(self: Tensor, other: Tensor | TensorValue): Tensor;
- div?(self: Tensor, other: Tensor | TensorValue): Tensor;
- remainder?(self: Tensor, other: Tensor | TensorValue): Tensor;
- ge?(self: Tensor, other: Tensor | TensorValue): Tensor;
- le?(self: Tensor, other: Tensor | TensorValue): Tensor;
- gt?(self: Tensor, other: Tensor | TensorValue): Tensor;
- lt?(self: Tensor, other: Tensor | TensorValue): Tensor;
- eq?(self: Tensor, other: Tensor | TensorValue): Tensor;
- ne?(self: Tensor, other: Tensor | TensorValue): Tensor;
- logicalAnd?(self: Tensor, other: Tensor | TensorValue): Tensor;
- logicalOr?(self: Tensor, other: Tensor | TensorValue): Tensor;
- logicalXor?(self: Tensor, other: Tensor | TensorValue): Tensor;
- logicalNot?(self: Tensor): Tensor;
- bitwiseAnd?(self: Tensor, other: Tensor | TensorValue): Tensor;
- bitwiseOr?(self: Tensor, other: Tensor | TensorValue): Tensor;
- bitwiseXor?(self: Tensor, other: Tensor | TensorValue): Tensor;
- bitwiseNot?(self: Tensor): Tensor;
- bitwiseLeftShift?(self: Tensor, other: Tensor | TensorValue): Tensor;
- bitwiseRightShift?(self: Tensor, other: Tensor | TensorValue): Tensor;
- neg?(self: Tensor): Tensor;
- reciprocal?(self: Tensor): Tensor;
- square?(self: Tensor): Tensor;
- abs?(self: Tensor): Tensor;
- sign?(self: Tensor): Tensor;
- sin?(self: Tensor): Tensor;
- cos?(self: Tensor): Tensor;
- tan?(self: Tensor): Tensor;
- asin?(self: Tensor): Tensor;
- acos?(self: Tensor): Tensor;
- atan?(self: Tensor): Tensor;
- atan2?(self: Tensor): Tensor;
- sinh?(self: Tensor): Tensor;
- cosh?(self: Tensor): Tensor;
- asinh?(self: Tensor): Tensor;
- acosh?(self: Tensor): Tensor;
- atanh?(self: Tensor): Tensor;
- deg2rad?(self: Tensor): Tensor;
- rad2deg?(self: Tensor): Tensor;
- sqrt?(self: Tensor): Tensor;
- rsqrt?(self: Tensor): Tensor;
- exp?(self: Tensor): Tensor;
- exp2?(self: Tensor): Tensor;
- expm1?(self: Tensor): Tensor;
- log?(self: Tensor): Tensor;
- log2?(self: Tensor): Tensor;
- log10?(self: Tensor): Tensor;
- log1p?(self: Tensor): Tensor;
- relu?(self: Tensor): Tensor;
- sigmoid?(self: Tensor): Tensor;
- tanh?(self: Tensor): Tensor;
- softplus?(self: Tensor): Tensor;
- softsign?(self: Tensor): Tensor;
- silu?(self: Tensor): Tensor;
- mish?(self: Tensor): Tensor;
- maximum?(self: Tensor, other: Tensor | TensorValue): Tensor;
- minimum?(self: Tensor, other: Tensor | TensorValue): Tensor;
- round?(self: Tensor): Tensor;
- floor?(self: Tensor): Tensor;
- ceil?(self: Tensor): Tensor;
- trunc?(self: Tensor): Tensor;
- frac?(self: Tensor): Tensor;
- clip?(self: Tensor, min: number, max: number): Tensor;
- erf?(self: Tensor): Tensor;
- erfc?(self: Tensor): Tensor;
- erfinv?(self: Tensor): Tensor;
- dot?(self: Tensor, other: Tensor | TensorValue): Tensor;
- mm?(self: Tensor, other: Tensor | TensorValue): Tensor;
- bmm?(self: Tensor, other: Tensor | TensorValue): Tensor;
- mv?(self: Tensor, other: Tensor | TensorValue): Tensor;
- matmul?(self: Tensor, other: Tensor | TensorValue): Tensor;
- to?(tensor: Tensor): Tensor;
+ transfer(tensor: Tensor): Tensor;
  }
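
Note: with this release a backend is expected to provide a single required transfer() hook rather than the long list of optional per-op overrides removed above; the math ops now always run through the CPU implementation in core.js. A minimal sketch of a custom backend under the new interface follows. The import paths, the "mydevice" name, and the Tensor.backends.set() registration call are assumptions inferred from the Tensor.backends map used in core.js, not a documented API.

    import { Tensor } from "catniff";        // assumed package-root export
    import type { Backend } from "catniff";  // assumed re-export of dist/backend.d.ts

    // Hypothetical device backend: only transfer() is required in 0.4.2.
    const myBackend: Backend = {
        transfer(tensor: Tensor): Tensor {
            // Copy/move the tensor's data onto this device and return the resulting tensor.
            return tensor;
        },
    };

    // Assumed registration point: core.js resolves devices via Tensor.backends.get(device).
    (Tensor as any).backends.set("mydevice", myBackend);
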
package/dist/core.d.ts CHANGED
@@ -46,6 +46,8 @@ export declare class Tensor {
  mean(dims?: number[] | number, keepDims?: boolean): Tensor;
  max(dims?: number[] | number, keepDims?: boolean): Tensor;
  min(dims?: number[] | number, keepDims?: boolean): Tensor;
+ var(dims?: number[] | number, keepDims?: boolean): Tensor;
+ std(dims?: number[] | number, keepDims?: boolean): Tensor;
  softmax(dims?: number[] | number): Tensor;
  add(other: TensorValue | Tensor): Tensor;
  sub(other: TensorValue | Tensor): Tensor;
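
Note: the two added declarations give Tensor population-style var() and std() reductions with the same dims/keepDims signature as the other reductions. A usage sketch (the bare new Tensor(values) constructor form is an assumption; only the method signatures above come from the package):

    const t = new Tensor([[1, 2], [3, 4]]);
    const rowVar = t.var(1, true);   // variance along dim 1, keeping the reduced dim
    const allStd = t.std();          // standard deviation over all dims
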
package/dist/core.js CHANGED
@@ -337,21 +337,22 @@ class Tensor {
  }
  // Tensor sum reduction
  sum(dims, keepDims = false) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sum) {
- return backend.sum(this, dims, keepDims);
- }
  if (typeof this.value === "number")
  return this;
- if (typeof dims === "number") {
- dims = [dims];
- }
  if (typeof dims === "undefined") {
  dims = Array.from({ length: this.shape.length }, (_, index) => index);
  }
+ if (Array.isArray(dims)) {
+ // Sort in descending order
+ const sortedDims = dims.sort((a, b) => b - a);
+ let reducedThis = this;
+ for (let i = 0; i < sortedDims.length; i++) {
+ reducedThis = reducedThis.sum(sortedDims[i], true);
+ }
+ return keepDims ? reducedThis : reducedThis.squeeze(dims);
+ }
  // Dims that are reduced now have size-1
- const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+ const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
  const outputStrides = Tensor.getStrides(outputShape);
  const outputSize = Tensor.shapeToSize(outputShape);
  const outputValue = new Array(outputSize).fill(0);
@@ -368,7 +369,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Add into sum
@@ -395,21 +396,22 @@
  }
  // Tensor product reduction
  prod(dims, keepDims = false) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.prod) {
- return backend.prod(this, dims, keepDims);
- }
  if (typeof this.value === "number")
  return this;
- if (typeof dims === "number") {
- dims = [dims];
- }
  if (typeof dims === "undefined") {
  dims = Array.from({ length: this.shape.length }, (_, index) => index);
  }
+ if (Array.isArray(dims)) {
+ // Sort in descending order
+ const sortedDims = dims.sort((a, b) => b - a);
+ let reducedThis = this;
+ for (let i = 0; i < sortedDims.length; i++) {
+ reducedThis = reducedThis.prod(sortedDims[i], true);
+ }
+ return keepDims ? reducedThis : reducedThis.squeeze(dims);
+ }
  // Dims that are reduced now have size-1
- const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+ const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
  const outputStrides = Tensor.getStrides(outputShape);
  const outputSize = Tensor.shapeToSize(outputShape);
  const outputValue = new Array(outputSize).fill(1);
@@ -418,7 +420,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Multiply into product
@@ -437,7 +439,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Grad is the product of other elements of the same axis, which is product of all els divided by the current value
@@ -451,21 +453,22 @@
  }
  // Tensor mean reduction
  mean(dims, keepDims = false) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.mean) {
- return backend.mean(this, dims, keepDims);
- }
  if (typeof this.value === "number")
  return this;
- if (typeof dims === "number") {
- dims = [dims];
- }
  if (typeof dims === "undefined") {
  dims = Array.from({ length: this.shape.length }, (_, index) => index);
  }
+ if (Array.isArray(dims)) {
+ // Sort in descending order
+ const sortedDims = dims.sort((a, b) => b - a);
+ let reducedThis = this;
+ for (let i = 0; i < sortedDims.length; i++) {
+ reducedThis = reducedThis.mean(sortedDims[i], true);
+ }
+ return keepDims ? reducedThis : reducedThis.squeeze(dims);
+ }
  // Dims that are reduced now have size-1
- const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+ const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
  const outputStrides = Tensor.getStrides(outputShape);
  const outputSize = Tensor.shapeToSize(outputShape);
  const outputValue = new Array(outputSize).fill(0);
@@ -475,7 +478,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Calculate sum and contributors to the sum
@@ -500,7 +503,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Mean = 1/n * (el1 + el2 + ... + eln) so grad = 1/n
@@ -514,21 +517,22 @@
  }
  // Tensor maximum reduction
  max(dims, keepDims = false) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.max) {
- return backend.max(this, dims, keepDims);
- }
  if (typeof this.value === "number")
  return this;
- if (typeof dims === "number") {
- dims = [dims];
- }
  if (typeof dims === "undefined") {
  dims = Array.from({ length: this.shape.length }, (_, index) => index);
  }
+ if (Array.isArray(dims)) {
+ // Sort in descending order
+ const sortedDims = dims.sort((a, b) => b - a);
+ let reducedThis = this;
+ for (let i = 0; i < sortedDims.length; i++) {
+ reducedThis = reducedThis.max(sortedDims[i], true);
+ }
+ return keepDims ? reducedThis : reducedThis.squeeze(dims);
+ }
  // Dims that are reduced now have size-1
- const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+ const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
  const outputStrides = Tensor.getStrides(outputShape);
  const outputSize = Tensor.shapeToSize(outputShape);
  const outputValue = new Array(outputSize).fill(-Infinity);
@@ -537,7 +541,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Get max over time
@@ -555,14 +559,25 @@
  out.children.push(this);
  out.gradFn = () => {
  const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
+ const shareCounts = new Array(outputSize).fill(0);
+ const originalValue = this.value;
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
- // Calculate grad by checking if a positon holds a value equal to the max value
- gradValue[realFlatIndex] = outputValue[outFlatIndex] === this.value[realFlatIndex] ? 1 : 0;
+ // We collect how many elements share the same max value first
+ shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
+ }
+ for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
+ const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ // Force 0 on reduced axes to collapse into size-1 dims
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ // Convert output coordinates to flat index
+ const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
+ // Here we share the grad between the elements that share the same max value
+ gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
  }
  const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
  Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
@@ -572,21 +587,22 @@
  }
  // Tensor minimum reduction
  min(dims, keepDims = false) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.min) {
- return backend.min(this, dims, keepDims);
- }
  if (typeof this.value === "number")
  return this;
- if (typeof dims === "number") {
- dims = [dims];
- }
  if (typeof dims === "undefined") {
  dims = Array.from({ length: this.shape.length }, (_, index) => index);
  }
+ if (Array.isArray(dims)) {
+ // Sort in descending order
+ const sortedDims = dims.sort((a, b) => b - a);
+ let reducedThis = this;
+ for (let i = 0; i < sortedDims.length; i++) {
+ reducedThis = reducedThis.min(sortedDims[i], true);
+ }
+ return keepDims ? reducedThis : reducedThis.squeeze(dims);
+ }
  // Dims that are reduced now have size-1
- const outputShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+ const outputShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
  const outputStrides = Tensor.getStrides(outputShape);
  const outputSize = Tensor.shapeToSize(outputShape);
  const outputValue = new Array(outputSize).fill(Infinity);
@@ -595,7 +611,7 @@
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
  // Get min over time
@@ -613,14 +629,25 @@
  out.children.push(this);
  out.gradFn = () => {
  const gradShape = this.shape, gradStrides = this.strides, gradValue = new Array(originalSize).fill(0);
+ const shareCounts = new Array(outputSize).fill(0);
+ const originalValue = this.value;
+ for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
+ const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
+ // Force 0 on reduced axes to collapse into size-1 dims
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
+ // Convert output coordinates to flat index
+ const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
+ // We collect how many elements share the same min value first
+ shareCounts[outFlatIndex] += outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 : 0;
+ }
  for (let realFlatIndex = 0; realFlatIndex < originalSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, this.strides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const outCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const outCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert output coordinates to flat index
  const outFlatIndex = Tensor.coordsToIndex(outCoords, outputStrides);
- // Calculate grad by checking if a positon holds a value equal to the min value
- gradValue[realFlatIndex] = outputValue[outFlatIndex] === this.value[realFlatIndex] ? 1 : 0;
+ // Here we share the grad between the elements that share the same min value
+ gradValue[realFlatIndex] = outputValue[outFlatIndex] === originalValue[realFlatIndex] ? 1 / shareCounts[outFlatIndex] : 0;
  }
  const localGrad = new Tensor(gradValue, { shape: gradShape, strides: gradStrides });
  Tensor.addGrad(this, out.grad.withGrad(false).mul(localGrad));
@@ -628,23 +655,34 @@
  }
  return keepDims ? out : out.squeeze(dims);
  }
+ // Tensor variance reduction
+ var(dims, keepDims = false) {
+ const meanXSquared = this.square().mean(dims, keepDims);
+ const meanXSquaredExpanded = this.mean(dims, keepDims).square();
+ return meanXSquared.sub(meanXSquaredExpanded);
+ }
+ // Tensor standard deviation reduction
+ std(dims, keepDims = false) {
+ return this.var(dims, keepDims).sqrt();
+ }
  // Tensor product reduction
  softmax(dims) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.softmax) {
- return backend.softmax(this, dims);
- }
  if (typeof this.value === "number")
  return this;
- if (typeof dims === "number") {
- dims = [dims];
- }
  if (typeof dims === "undefined") {
  dims = Array.from({ length: this.shape.length }, (_, index) => index);
  }
+ if (Array.isArray(dims)) {
+ // Sort in descending order
+ const sortedDims = dims.sort((a, b) => b - a);
+ let reducedThis = this;
+ for (let i = 0; i < sortedDims.length; i++) {
+ reducedThis = reducedThis.softmax(sortedDims[i]);
+ }
+ return reducedThis;
+ }
  // Dims that are reduced now have size-1
- const expSumShape = this.shape.map((dim, i) => dims.includes(i) ? 1 : dim);
+ const expSumShape = this.shape.map((dim, i) => dims === i ? 1 : dim);
  const expSumStrides = Tensor.getStrides(expSumShape);
  const expSumSize = Tensor.shapeToSize(expSumShape);
  const expSumValue = new Array(expSumSize).fill(0);
@@ -656,7 +694,7 @@
  for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const expSumCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert exp sum coordinates to flat index
  const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
  // Add e^x to the sum cache
@@ -666,7 +704,7 @@
  for (let realFlatIndex = 0; realFlatIndex < outputSize; realFlatIndex++) {
  const coords = Tensor.indexToCoords(realFlatIndex, outputStrides);
  // Force 0 on reduced axes to collapse into size-1 dims
- const expSumCoords = coords.map((val, i) => dims.includes(i) ? 0 : val);
+ const expSumCoords = coords.map((val, i) => dims === i ? 0 : val);
  // Convert exp sum coordinates to flat index
  const expSumFlatIndex = Tensor.coordsToIndex(expSumCoords, expSumStrides);
  // Calculate e^xi / sum
@@ -699,488 +737,228 @@ class Tensor {
  }
  // Tensor element-wise addition
  add(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.add) {
- return backend.add(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
  }
  // Tensor element-wise subtraction
  sub(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sub) {
- return backend.sub(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
  }
  subtract = this.sub;
  // Tensor element-wise multiplication
  mul(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.mul) {
- return backend.mul(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
  }
  multiply = this.mul;
  // Tensor element-wise power
  pow(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.pow) {
- return backend.pow(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
  }
  // Tensor element-wise division
  div(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.div) {
- return backend.div(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.square())));
  }
  divide = this.div;
  // Tensor element-wise modulo
  remainder(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.remainder) {
- return backend.remainder(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a % b);
  }
  // Tensor element-wise greater or equal comparison
  ge(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.ge) {
- return backend.ge(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
  }
  greaterEqual = this.ge;
  // Tensor element-wise less or equal comparison
  le(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.le) {
- return backend.le(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
  }
  lessEqual = this.le;
  // Tensor element-wise greater-than comparison
  gt(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.gt) {
- return backend.gt(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
  }
  greater = this.gt;
  // Tensor element-wise less-than comparison
  lt(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.lt) {
- return backend.lt(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
  }
  less = this.lt;
  // Tensor element-wise equality comparison
  eq(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.eq) {
- return backend.eq(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
  }
  equal = this.eq;
  // Tensor element-wise not equality comparison
  ne(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.ne) {
- return backend.ne(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a !== b ? 1 : 0);
  }
  notEqual = this.ne;
  // Tensor element-wise logical and
  logicalAnd(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.logicalAnd) {
- return backend.logicalAnd(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
  }
  // Tensor element-wise logical or
  logicalOr(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.logicalOr) {
- return backend.logicalOr(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
  }
  // Tensor element-wise logical xor
  logicalXor(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.logicalXor) {
- return backend.logicalXor(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
  }
  // Tensor element-wise logical not
  logicalNot() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.logicalNot) {
- return backend.logicalNot(this);
- }
  return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
  }
  // Tensor element-wise bitwise and
  bitwiseAnd(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bitwiseAnd) {
- return backend.bitwiseAnd(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a & b);
  }
  // Tensor element-wise bitwise or
  bitwiseOr(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bitwiseOr) {
- return backend.bitwiseOr(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a | b);
  }
  // Tensor element-wise bitwise xor
  bitwiseXor(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bitwiseXor) {
- return backend.bitwiseXor(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a ^ b);
  }
  // Tensor element-wise bitwise not
  bitwiseNot() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bitwiseNot) {
- return backend.bitwiseNot(this);
- }
  return this.elementWiseSelfDAG((a) => ~a);
  }
  // Tensor element-wise left shift
  bitwiseLeftShift(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bitwiseLeftShift) {
- return backend.bitwiseLeftShift(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a << b);
  }
  // Tensor element-wise right shift
  bitwiseRightShift(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bitwiseRightShift) {
- return backend.bitwiseRightShift(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => a >> b);
  }
  // Tensor element-wise negation
  neg() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.neg) {
- return backend.neg(this);
- }
  return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
  }
  negative = this.neg;
  // Tensor element-wise reciprocal
  reciprocal() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.reciprocal) {
- return backend.reciprocal(this);
- }
  return this.elementWiseSelfDAG((a) => 1 / a, (self, outGrad) => outGrad.mul(self.pow(-2).neg()));
  }
  // Tensor element-wise square
  square() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.square) {
- return backend.square(this);
- }
  return this.elementWiseSelfDAG((a) => a * a, (self, outGrad) => outGrad.mul(self.mul(2)));
  }
  // Tensor element-wise absolute
  abs() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.abs) {
- return backend.abs(this);
- }
  return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
  }
  absolute = this.abs;
  // Tensor element-wise sign function
  sign() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sign) {
- return backend.sign(this);
- }
  return this.elementWiseSelfDAG((a) => Math.sign(a));
  }
  // Tensor element-wise sin
  sin() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sin) {
- return backend.sin(this);
- }
  return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
  }
  // Tensor element-wise cos
  cos() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.cos) {
- return backend.cos(this);
- }
  return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
  }
  // Tensor element-wise tan
  tan() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.tan) {
- return backend.tan(this);
- }
  return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().square().add(1)));
  }
  // Tensor element-wise asin
  asin() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.asin) {
- return backend.asin(this);
- }
  return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()));
  }
  arcsin = this.asin;
  // Tensor element-wise acos
  acos() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.acos) {
- return backend.acos(this);
- }
  return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()).neg());
  }
  arccos = this.acos;
  // Tensor element-wise atan
  atan() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.atan) {
- return backend.atan(this);
- }
  return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.square().add(1)));
  }
  arctan = this.atan;
  // Tensor element-wise atan2
  atan2(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.atan2) {
- return backend.atan2(this);
- }
  return this.elementWiseABDAG(other, (a, b) => Math.atan2(a, b), (self, other, outGrad) => outGrad.mul(other.div(self.square().add(other.square()))), (self, other, outGrad) => outGrad.mul(self.neg().div(self.square().add(other.square()))));
  }
  arctan2 = this.atan2;
  // Tensor element-wise sinh
  sinh() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sinh) {
- return backend.sinh(this);
- }
  return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
  }
  // Tensor element-wise cosh
  cosh() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.cosh) {
- return backend.cosh(this);
- }
  return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
  }
  // Tensor element-wise asinh
  asinh() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.asinh) {
- return backend.asinh(this);
- }
  return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.square().add(1).sqrt()));
  }
  arcsinh = this.asinh;
  // Tensor element-wise acosh
  acosh() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.acosh) {
- return backend.acosh(this);
- }
  return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
  }
  arccosh = this.acosh;
  // Tensor element-wise atanh
  atanh() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.atanh) {
- return backend.atanh(this);
- }
  return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.square().neg().add(1)));
  }
  arctanh = this.atanh;
  // Tensor element-wise degree to radian
  deg2rad() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.deg2rad) {
- return backend.deg2rad(this);
- }
  return this.elementWiseSelfDAG((a) => a * (Math.PI / 180), (self, outGrad) => outGrad.mul(Math.PI / 180));
  }
  // Tensor element-wise radian to degree
  rad2deg() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.rad2deg) {
- return backend.rad2deg(this);
- }
  return this.elementWiseSelfDAG((a) => a / (Math.PI / 180), (self, outGrad) => outGrad.div(Math.PI / 180));
  }
  // Tensor element-wise square root
  sqrt() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sqrt) {
- return backend.sqrt(this);
- }
  return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
  }
  // Tensor element-wise reciprocal of square root
  rsqrt() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.rsqrt) {
- return backend.rsqrt(this);
- }
  return this.elementWiseSelfDAG((a) => 1 / Math.sqrt(a), (self, outGrad) => outGrad.mul(self.pow(-1.5).mul(-0.5)));
  }
  // Tensor element-wise e^x
  exp() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.exp) {
- return backend.exp(this);
- }
  return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp()));
  }
  // Tensor element-wise 2^x
  exp2() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.exp2) {
- return backend.exp2(this);
- }
  return this.elementWiseSelfDAG((a) => 2 ** a, (self, outGrad) => outGrad.mul(self.exp2().mul(Math.log(2))));
  }
  // Tensor element-wise e^x - 1
  expm1() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.expm1) {
- return backend.expm1(this);
- }
  return this.elementWiseSelfDAG((a) => Math.expm1(a), (self, outGrad) => outGrad.mul(self.exp()));
  }
  // Tensor element-wise natural log
  log() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.log) {
- return backend.log(this);
- }
  return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self));
  }
  // Tensor element-wise log2
  log2() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.log2) {
- return backend.log2(this);
- }
  return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2))));
  }
  // Tensor element-wise log10
  log10() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.log10) {
- return backend.log10(this);
- }
  return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10))));
  }
  // Tensor element-wise log(1+x)
  log1p() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.log1p) {
- return backend.log1p(this);
- }
  return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1)));
  }
  // Tensor element-wise relu
  relu() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.relu) {
- return backend.relu(this);
- }
  return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.gt(0)));
  }
  // Tensor element-wise sigmoid
  sigmoid() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.sigmoid) {
- return backend.sigmoid(this);
- }
  return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
  const sig = self.sigmoid();
  return outGrad.mul(sig).mul(sig.neg().add(1));
@@ -1188,38 +966,18 @@ class Tensor {
  }
  // Tensor element-wise tanh
  tanh() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.tanh) {
- return backend.tanh(this);
- }
  return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().square().neg().add(1)));
  }
  // Tensor element-wise softplus
  softplus() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.softplus) {
- return backend.softplus(this);
- }
  return this.elementWiseSelfDAG((a) => Math.log1p(Math.exp(a)), (self, outGrad) => outGrad.mul(self.sigmoid()));
  }
  // Tensor element-wise softsign
  softsign() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.softsign) {
- return backend.softsign(this);
- }
  return this.elementWiseSelfDAG((a) => a / (1 + Math.abs(a)), (self, outGrad) => outGrad.div(self.abs().add(1).square()));
  }
  // Tensor element-wise silu (swish)
  silu() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.silu) {
- return backend.silu(this);
- }
  return this.elementWiseSelfDAG((a) => a / (1 + Math.exp(-a)), (self, outGrad) => {
  const sig = self.sigmoid();
  return outGrad.mul(sig.add(self.mul(sig).mul(sig.neg().add(1))));
@@ -1227,11 +985,6 @@ class Tensor {
  }
  // Tensor element-wise mish
  mish() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.mish) {
- return backend.mish(this);
- }
  return this.elementWiseSelfDAG((a) => a * Math.tanh(Math.log1p(Math.exp(a))), (self, outGrad) => {
  const tanhSoftPlus = self.exp().add(1).log().tanh();
  // tanh(softplus(x)) + x * (1 - tanh²(softplus(x))) * sigmoid(x)
@@ -1241,103 +994,48 @@
  }
  // Tensor element-wise maximum
  maximum(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.maximum) {
- return backend.maximum(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
  }
  // Tensor element-wise minimum
  minimum(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.minimum) {
- return backend.minimum(this, other);
- }
  return this.elementWiseABDAG(other, (a, b) => Math.min(a, b), (self, other, outGrad) => outGrad.mul(self.lt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.lt(self).add(other.eq(self).mul(0.5))));
  }
  // Tensor element-wise round
  round() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.round) {
- return backend.round(this);
- }
  return this.elementWiseSelfDAG((a) => Math.round(a));
  }
  // Tensor element-wise floor
  floor() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.floor) {
- return backend.floor(this);
- }
  return this.elementWiseSelfDAG((a) => Math.floor(a));
  }
  // Tensor element-wise ceil
  ceil() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.ceil) {
- return backend.ceil(this);
- }
  return this.elementWiseSelfDAG((a) => Math.ceil(a));
  }
  // Tensor element-wise truncation
  trunc() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.trunc) {
- return backend.trunc(this);
- }
  return this.elementWiseSelfDAG((a) => Math.trunc(a));
  }
  fix = this.trunc;
  // Tensor element-wise fraction portion
  frac() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.frac) {
- return backend.frac(this);
- }
  return this.elementWiseSelfDAG((a) => a - Math.floor(a));
  }
  // Tensor element-wise clip and clamp
  clip(min, max) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.clip) {
- return backend.clip(this, min, max);
- }
  return this.elementWiseSelfDAG((a) => Math.max(min, Math.min(max, a)), (self, outGrad) => outGrad.mul(self.ge(min).mul(self.le(max))));
  }
  clamp = this.clip;
  // Tensor element-wise error function
  erf() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.erf) {
- return backend.erf(this);
- }
  return this.elementWiseSelfDAG((a) => (0, utils_1.erf)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI))));
  }
  // Tensor element-wise complementary error function
  erfc() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.erfc) {
- return backend.erfc(this);
- }
  return this.elementWiseSelfDAG((a) => (0, utils_1.erfc)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI)).neg()));
  }
  // Tensor element-wise inverse error function
  erfinv() {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.erfinv) {
- return backend.erfinv(this);
- }
  return this.elementWiseSelfDAG((a) => (0, utils_1.erfinv)(a), (self, outGrad) => outGrad.mul(self.erfinv().square().exp().mul(Math.sqrt(Math.PI) / 2)));
  }
  // Transpose
@@ -1379,11 +1077,6 @@
  }
  // 1D tensor dot product
  dot(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.dot) {
- return backend.dot(this, other);
- }
  other = Tensor.forceTensor(other);
  // Verify 1D shape
  if (this.shape.length !== 1 || other.shape.length !== 1) {
@@ -1422,11 +1115,6 @@
  }
  // Matrix multiplication
  mm(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.mm) {
- return backend.mm(this, other);
- }
  other = Tensor.forceTensor(other);
  // Verify 2D shape
  if (this.shape.length !== 2 || other.shape.length !== 2) {
@@ -1482,11 +1170,6 @@
  }
  // Batched 3D tensor matmul
  bmm(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.bmm) {
- return backend.bmm(this, other);
- }
  other = Tensor.forceTensor(other);
  // Verify 3D shape
  if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
@@ -1545,11 +1228,6 @@
  }
  // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
  mv(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.mv) {
- return backend.mv(this, other);
- }
  other = Tensor.forceTensor(other);
  // Verify 2D shape
  if (this.shape.length !== 2 || other.shape.length !== 1) {
@@ -1559,11 +1237,6 @@
  }
  // General matrix multiplication with different shapes
  matmul(other) {
- // Use backend of tensor's device if available, or else fallback to cpu
- const backend = Tensor.backends.get(this.device);
- if (backend && backend.matmul) {
- return backend.matmul(this, other);
- }
  other = Tensor.forceTensor(other);
  const isThis1D = this.shape.length === 1;
  const isOther1D = other.shape.length === 1;
@@ -1875,10 +1548,10 @@
  // Op to transfer tensor to another device
  to(device) {
  const backend = Tensor.backends.get(device);
- if (backend && backend.to) {
- return backend.to(this);
+ if (backend && backend.transfer) {
+ return backend.transfer(this);
  }
- throw new Error(`No device found to transfer tensor to or "to" is not implemented for device.`);
+ throw new Error(`No device found to transfer tensor to or a handler is not implemented for device.`);
  }
  }
  exports.Tensor = Tensor;
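
Note on the core.js changes above: var() is computed as E[x^2] - (E[x])^2, i.e. the population variance (n divisor, not n - 1), and std() is its square root. For x = [1, 2, 3]: mean(x^2) = 14/3, mean(x)^2 = 4, so var = 2/3 and std ≈ 0.816. A small sketch (the bare new Tensor(values) constructor form is assumed):

    const x = new Tensor([1, 2, 3]);
    const v = x.var();   // tensor holding 2/3
    const s = x.std();   // tensor holding sqrt(2/3) ≈ 0.816

The max()/min() gradients also changed: tied extrema now split the incoming gradient via shareCounts, so for [3, 3, 1] the two tied maxima each receive 0.5 instead of 1. Finally, to(device) now dispatches to the backend's transfer() hook in place of the removed optional to() method.
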
package/package.json CHANGED
@@ -1,6 +1,6 @@
  {
  "name": "catniff",
- "version": "0.4.0",
+ "version": "0.4.2",
  "description": "A small Torch-like deep learning framework for Javascript with tensor and autograd support",
  "main": "index.js",
  "scripts": {