catniff 0.3.0 → 0.4.0

package/README.md CHANGED
@@ -72,6 +72,28 @@ G.backward();
72
72
  console.log(X.grad.val(), Y.grad.val());
73
73
  ```
74
74
 
75
+ ## Optimizer
76
+
77
+ Catniff comes bundled with optimizers as well:
78
+ ```js
79
+ const { Tensor, Optim } = require("catniff");
80
+
81
+ // Define some parameter
82
+ const w = new Tensor([1.0], { requiresGrad: true });
83
+ // Define a fake loss function: L = (w - 3)^2
84
+ const loss = w.sub(3).pow(2);
85
+ // Calculate gradient
86
+ loss.backward();
87
+ // Use Adam optimizer
88
+ const optim = new Optim.Adam([w]);
89
+ // Optimization step
90
+ optim.step();
91
+
92
+ console.log("Updated weight:", w.data); // Should move toward 3.0
93
+ ```
94
+
95
+ Catniff can do much more; check out the documentation linked below for more information.
96
+
75
97
  ## Documentation
76
98
 
77
99
  Full documentation is available in [`./docs/documentation.md`](./docs/documentation.md).
@@ -83,7 +105,6 @@ All available APIs are in [`./src/`](./src/) if you want to dig deeper.
83
105
  * Bug fixes.
84
106
  * More tensor ops.
85
107
  * GPU acceleration.
86
- * Option to load more backends.
87
108
  * Some general neural net APIs.
88
109
  * More detailed documentation.
89
110
  * Code refactoring.
package/dist/backend.d.ts ADDED
@@ -0,0 +1,83 @@
1
+ import { Tensor, TensorValue } from "./core";
2
+ export interface Backend {
3
+ sum?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
4
+ prod?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
5
+ mean?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
6
+ max?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
7
+ min?(tensor: Tensor, dims?: number[] | number, keepDims?: boolean): Tensor;
8
+ softmax?(tensor: Tensor, dims?: number[] | number): Tensor;
9
+ add?(self: Tensor, other: Tensor | TensorValue): Tensor;
10
+ sub?(self: Tensor, other: Tensor | TensorValue): Tensor;
11
+ mul?(self: Tensor, other: Tensor | TensorValue): Tensor;
12
+ pow?(self: Tensor, other: Tensor | TensorValue): Tensor;
13
+ div?(self: Tensor, other: Tensor | TensorValue): Tensor;
14
+ remainder?(self: Tensor, other: Tensor | TensorValue): Tensor;
15
+ ge?(self: Tensor, other: Tensor | TensorValue): Tensor;
16
+ le?(self: Tensor, other: Tensor | TensorValue): Tensor;
17
+ gt?(self: Tensor, other: Tensor | TensorValue): Tensor;
18
+ lt?(self: Tensor, other: Tensor | TensorValue): Tensor;
19
+ eq?(self: Tensor, other: Tensor | TensorValue): Tensor;
20
+ ne?(self: Tensor, other: Tensor | TensorValue): Tensor;
21
+ logicalAnd?(self: Tensor, other: Tensor | TensorValue): Tensor;
22
+ logicalOr?(self: Tensor, other: Tensor | TensorValue): Tensor;
23
+ logicalXor?(self: Tensor, other: Tensor | TensorValue): Tensor;
24
+ logicalNot?(self: Tensor): Tensor;
25
+ bitwiseAnd?(self: Tensor, other: Tensor | TensorValue): Tensor;
26
+ bitwiseOr?(self: Tensor, other: Tensor | TensorValue): Tensor;
27
+ bitwiseXor?(self: Tensor, other: Tensor | TensorValue): Tensor;
28
+ bitwiseNot?(self: Tensor): Tensor;
29
+ bitwiseLeftShift?(self: Tensor, other: Tensor | TensorValue): Tensor;
30
+ bitwiseRightShift?(self: Tensor, other: Tensor | TensorValue): Tensor;
31
+ neg?(self: Tensor): Tensor;
32
+ reciprocal?(self: Tensor): Tensor;
33
+ square?(self: Tensor): Tensor;
34
+ abs?(self: Tensor): Tensor;
35
+ sign?(self: Tensor): Tensor;
36
+ sin?(self: Tensor): Tensor;
37
+ cos?(self: Tensor): Tensor;
38
+ tan?(self: Tensor): Tensor;
39
+ asin?(self: Tensor): Tensor;
40
+ acos?(self: Tensor): Tensor;
41
+ atan?(self: Tensor): Tensor;
42
+ atan2?(self: Tensor, other: Tensor | TensorValue): Tensor;
43
+ sinh?(self: Tensor): Tensor;
44
+ cosh?(self: Tensor): Tensor;
45
+ asinh?(self: Tensor): Tensor;
46
+ acosh?(self: Tensor): Tensor;
47
+ atanh?(self: Tensor): Tensor;
48
+ deg2rad?(self: Tensor): Tensor;
49
+ rad2deg?(self: Tensor): Tensor;
50
+ sqrt?(self: Tensor): Tensor;
51
+ rsqrt?(self: Tensor): Tensor;
52
+ exp?(self: Tensor): Tensor;
53
+ exp2?(self: Tensor): Tensor;
54
+ expm1?(self: Tensor): Tensor;
55
+ log?(self: Tensor): Tensor;
56
+ log2?(self: Tensor): Tensor;
57
+ log10?(self: Tensor): Tensor;
58
+ log1p?(self: Tensor): Tensor;
59
+ relu?(self: Tensor): Tensor;
60
+ sigmoid?(self: Tensor): Tensor;
61
+ tanh?(self: Tensor): Tensor;
62
+ softplus?(self: Tensor): Tensor;
63
+ softsign?(self: Tensor): Tensor;
64
+ silu?(self: Tensor): Tensor;
65
+ mish?(self: Tensor): Tensor;
66
+ maximum?(self: Tensor, other: Tensor | TensorValue): Tensor;
67
+ minimum?(self: Tensor, other: Tensor | TensorValue): Tensor;
68
+ round?(self: Tensor): Tensor;
69
+ floor?(self: Tensor): Tensor;
70
+ ceil?(self: Tensor): Tensor;
71
+ trunc?(self: Tensor): Tensor;
72
+ frac?(self: Tensor): Tensor;
73
+ clip?(self: Tensor, min: number, max: number): Tensor;
74
+ erf?(self: Tensor): Tensor;
75
+ erfc?(self: Tensor): Tensor;
76
+ erfinv?(self: Tensor): Tensor;
77
+ dot?(self: Tensor, other: Tensor | TensorValue): Tensor;
78
+ mm?(self: Tensor, other: Tensor | TensorValue): Tensor;
79
+ bmm?(self: Tensor, other: Tensor | TensorValue): Tensor;
80
+ mv?(self: Tensor, other: Tensor | TensorValue): Tensor;
81
+ matmul?(self: Tensor, other: Tensor | TensorValue): Tensor;
82
+ to?(tensor: Tensor): Tensor;
83
+ }
package/dist/backend.js ADDED
@@ -0,0 +1,2 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
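
The two files added above are the new `Backend` interface — every op on it is optional — and its compiled stub. Backends are registered by device name in the static `Tensor.backends` map added to `core.d.ts`/`core.js` below, and `Tensor.prototype.to(device)` hands the tensor to that backend's `to` implementation. Below is a minimal sketch of what a third-party backend could look like; the device name `"mydevice"` and the backend object are hypothetical, only `Tensor.backends`, the `Backend` shape, and `to()` come from this release.

```js
const { Tensor } = require("catniff");

// Hypothetical backend: only "to" is implemented here, so every tensor op
// still falls through to catniff's built-in CPU code. Any optional Backend
// method added to this object (add, mm, sum, ...) would be preferred instead.
const myBackend = {
  to(tensor) {
    // Re-tag the tensor with this backend's device name, keeping shape/strides
    return new Tensor(tensor.value, {
      shape: tensor.shape,
      strides: tensor.strides,
      requiresGrad: tensor.requiresGrad,
      device: "mydevice"
    });
  }
};

Tensor.backends.set("mydevice", myBackend);

const a = new Tensor([[1, 2], [3, 4]]).to("mydevice");
console.log(a.device); // "mydevice"
```

Because every method on `Backend` is optional, a backend can start by covering a handful of ops and let everything else fall back to the built-in implementations.
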
package/dist/core.d.ts CHANGED
@@ -1,3 +1,4 @@
1
+ import { Backend } from "./backend";
1
2
  export type TensorValue = number | TensorValue[];
2
3
  export interface TensorOptions {
3
4
  shape?: readonly number[];
@@ -6,6 +7,7 @@ export interface TensorOptions {
6
7
  requiresGrad?: boolean;
7
8
  gradFn?: Function;
8
9
  children?: Tensor[];
10
+ device?: string;
9
11
  }
10
12
  export declare class Tensor {
11
13
  value: number[] | number;
@@ -15,6 +17,7 @@ export declare class Tensor {
15
17
  requiresGrad: boolean;
16
18
  gradFn: Function;
17
19
  children: Tensor[];
20
+ device: string;
18
21
  constructor(value: TensorValue, options?: TensorOptions);
19
22
  static flatten(tensor: TensorValue): number[] | number;
20
23
  static getShape(tensor: TensorValue): readonly number[];
@@ -161,4 +164,6 @@ export declare class Tensor {
161
164
  detach(): Tensor;
162
165
  clone(): Tensor;
163
166
  replace(other: Tensor, allowShapeMismatch?: boolean): Tensor;
167
+ static backends: Map<string, Backend>;
168
+ to(device: string): Tensor;
164
169
  }
package/dist/core.js CHANGED
@@ -10,6 +10,7 @@ class Tensor {
10
10
  requiresGrad;
11
11
  gradFn;
12
12
  children;
13
+ device;
13
14
  constructor(value, options = {}) {
14
15
  this.value = Tensor.flatten(value);
15
16
  this.shape = options.shape || Tensor.getShape(value);
@@ -18,6 +19,7 @@ class Tensor {
18
19
  this.requiresGrad = options.requiresGrad ?? false;
19
20
  this.gradFn = options.gradFn || (() => { });
20
21
  this.children = options.children || [];
22
+ this.device = options.device || "cpu";
21
23
  }
22
24
  // Utility to flatten an nD array to be 1D
23
25
  static flatten(tensor) {
@@ -52,6 +54,8 @@ class Tensor {
52
54
  }
53
55
  // Utility to get strides from shape
54
56
  static getStrides(shape) {
57
+ if (shape.length === 0)
58
+ return [];
55
59
  const strides = new Array(shape.length);
56
60
  strides[strides.length - 1] = 1;
57
61
  for (let i = strides.length - 2; i >= 0; i--) {
@@ -282,7 +286,8 @@ class Tensor {
282
286
  const outValue = outShape.length === 0 ? this.value[0] : this.value;
283
287
  const out = new Tensor(outValue, {
284
288
  shape: outShape,
285
- strides: outStrides
289
+ strides: outStrides,
290
+ device: this.device
286
291
  });
287
292
  // Set up gradient if needed
288
293
  if (this.requiresGrad) {
@@ -319,7 +324,7 @@ class Tensor {
319
324
  newDimStride = this.strides[dim] * this.shape[dim];
320
325
  }
321
326
  newStrides.splice(dim, 0, newDimStride);
322
- const out = new Tensor(thisValue, { shape: newShape, strides: newStrides });
327
+ const out = new Tensor(thisValue, { shape: newShape, strides: newStrides, device: this.device });
323
328
  // Set up gradient if needed
324
329
  if (this.requiresGrad) {
325
330
  out.requiresGrad = true;
@@ -332,6 +337,11 @@ class Tensor {
332
337
  }
333
338
  // Tensor sum reduction
334
339
  sum(dims, keepDims = false) {
340
+ // Use backend of tensor's device if available, or else fallback to cpu
341
+ const backend = Tensor.backends.get(this.device);
342
+ if (backend && backend.sum) {
343
+ return backend.sum(this, dims, keepDims);
344
+ }
335
345
  if (typeof this.value === "number")
336
346
  return this;
337
347
  if (typeof dims === "number") {
@@ -385,6 +395,11 @@ class Tensor {
385
395
  }
386
396
  // Tensor product reduction
387
397
  prod(dims, keepDims = false) {
398
+ // Use backend of tensor's device if available, or else fallback to cpu
399
+ const backend = Tensor.backends.get(this.device);
400
+ if (backend && backend.prod) {
401
+ return backend.prod(this, dims, keepDims);
402
+ }
388
403
  if (typeof this.value === "number")
389
404
  return this;
390
405
  if (typeof dims === "number") {
@@ -436,6 +451,11 @@ class Tensor {
436
451
  }
437
452
  // Tensor mean reduction
438
453
  mean(dims, keepDims = false) {
454
+ // Use backend of tensor's device if available, or else fallback to cpu
455
+ const backend = Tensor.backends.get(this.device);
456
+ if (backend && backend.mean) {
457
+ return backend.mean(this, dims, keepDims);
458
+ }
439
459
  if (typeof this.value === "number")
440
460
  return this;
441
461
  if (typeof dims === "number") {
@@ -494,6 +514,11 @@ class Tensor {
494
514
  }
495
515
  // Tensor maximum reduction
496
516
  max(dims, keepDims = false) {
517
+ // Use backend of tensor's device if available, or else fallback to cpu
518
+ const backend = Tensor.backends.get(this.device);
519
+ if (backend && backend.max) {
520
+ return backend.max(this, dims, keepDims);
521
+ }
497
522
  if (typeof this.value === "number")
498
523
  return this;
499
524
  if (typeof dims === "number") {
@@ -547,6 +572,11 @@ class Tensor {
547
572
  }
548
573
  // Tensor minimum reduction
549
574
  min(dims, keepDims = false) {
575
+ // Use backend of tensor's device if available, or else fallback to cpu
576
+ const backend = Tensor.backends.get(this.device);
577
+ if (backend && backend.min) {
578
+ return backend.min(this, dims, keepDims);
579
+ }
550
580
  if (typeof this.value === "number")
551
581
  return this;
552
582
  if (typeof dims === "number") {
@@ -600,6 +630,11 @@ class Tensor {
600
630
  }
601
631
  // Tensor softmax
602
632
  softmax(dims) {
633
+ // Use backend of tensor's device if available, or else fallback to cpu
634
+ const backend = Tensor.backends.get(this.device);
635
+ if (backend && backend.softmax) {
636
+ return backend.softmax(this, dims);
637
+ }
603
638
  if (typeof this.value === "number")
604
639
  return this;
605
640
  if (typeof dims === "number") {
@@ -664,228 +699,488 @@ class Tensor {
664
699
  }
665
700
  // Tensor element-wise addition
666
701
  add(other) {
702
+ // Use backend of tensor's device if available, or else fallback to cpu
703
+ const backend = Tensor.backends.get(this.device);
704
+ if (backend && backend.add) {
705
+ return backend.add(this, other);
706
+ }
667
707
  return this.elementWiseABDAG(other, (a, b) => a + b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad);
668
708
  }
669
709
  // Tensor element-wise subtraction
670
710
  sub(other) {
711
+ // Use backend of tensor's device if available, or else fallback to cpu
712
+ const backend = Tensor.backends.get(this.device);
713
+ if (backend && backend.sub) {
714
+ return backend.sub(this, other);
715
+ }
671
716
  return this.elementWiseABDAG(other, (a, b) => a - b, (self, other, outGrad) => outGrad, (self, other, outGrad) => outGrad.neg());
672
717
  }
673
718
  subtract = this.sub;
674
719
  // Tensor element-wise multiplication
675
720
  mul(other) {
721
+ // Use backend of tensor's device if available, or else fallback to cpu
722
+ const backend = Tensor.backends.get(this.device);
723
+ if (backend && backend.mul) {
724
+ return backend.mul(this, other);
725
+ }
676
726
  return this.elementWiseABDAG(other, (a, b) => a * b, (self, other, outGrad) => outGrad.mul(other), (self, other, outGrad) => outGrad.mul(self));
677
727
  }
678
728
  multiply = this.mul;
679
729
  // Tensor element-wise power
680
730
  pow(other) {
731
+ // Use backend of tensor's device if available, or else fallback to cpu
732
+ const backend = Tensor.backends.get(this.device);
733
+ if (backend && backend.pow) {
734
+ return backend.pow(this, other);
735
+ }
681
736
  return this.elementWiseABDAG(other, (a, b) => a ** b, (self, other, outGrad) => outGrad.mul(other.mul(self.pow(other.sub(1)))), (self, other, outGrad) => outGrad.mul(self.pow(other).mul(self.log())));
682
737
  }
683
738
  // Tensor element-wise division
684
739
  div(other) {
740
+ // Use backend of tensor's device if available, or else fallback to cpu
741
+ const backend = Tensor.backends.get(this.device);
742
+ if (backend && backend.div) {
743
+ return backend.div(this, other);
744
+ }
685
745
  return this.elementWiseABDAG(other, (a, b) => a / b, (self, other, outGrad) => outGrad.div(other), (self, other, outGrad) => outGrad.mul(self.neg().div(other.square())));
686
746
  }
687
747
  divide = this.div;
688
748
  // Tensor element-wise modulo
689
749
  remainder(other) {
750
+ // Use backend of tensor's device if available, or else fallback to cpu
751
+ const backend = Tensor.backends.get(this.device);
752
+ if (backend && backend.remainder) {
753
+ return backend.remainder(this, other);
754
+ }
690
755
  return this.elementWiseABDAG(other, (a, b) => a % b);
691
756
  }
692
757
  // Tensor element-wise greater or equal comparison
693
758
  ge(other) {
759
+ // Use backend of tensor's device if available, or else fallback to cpu
760
+ const backend = Tensor.backends.get(this.device);
761
+ if (backend && backend.ge) {
762
+ return backend.ge(this, other);
763
+ }
694
764
  return this.elementWiseABDAG(other, (a, b) => a >= b ? 1 : 0);
695
765
  }
696
766
  greaterEqual = this.ge;
697
767
  // Tensor element-wise less or equal comparison
698
768
  le(other) {
769
+ // Use backend of tensor's device if available, or else fallback to cpu
770
+ const backend = Tensor.backends.get(this.device);
771
+ if (backend && backend.le) {
772
+ return backend.le(this, other);
773
+ }
699
774
  return this.elementWiseABDAG(other, (a, b) => a <= b ? 1 : 0);
700
775
  }
701
776
  lessEqual = this.le;
702
777
  // Tensor element-wise greater-than comparison
703
778
  gt(other) {
779
+ // Use backend of tensor's device if available, or else fallback to cpu
780
+ const backend = Tensor.backends.get(this.device);
781
+ if (backend && backend.gt) {
782
+ return backend.gt(this, other);
783
+ }
704
784
  return this.elementWiseABDAG(other, (a, b) => a > b ? 1 : 0);
705
785
  }
706
786
  greater = this.gt;
707
787
  // Tensor element-wise less-than comparison
708
788
  lt(other) {
789
+ // Use backend of tensor's device if available, or else fallback to cpu
790
+ const backend = Tensor.backends.get(this.device);
791
+ if (backend && backend.lt) {
792
+ return backend.lt(this, other);
793
+ }
709
794
  return this.elementWiseABDAG(other, (a, b) => a < b ? 1 : 0);
710
795
  }
711
796
  less = this.lt;
712
797
  // Tensor element-wise equality comparison
713
798
  eq(other) {
799
+ // Use backend of tensor's device if available, or else fallback to cpu
800
+ const backend = Tensor.backends.get(this.device);
801
+ if (backend && backend.eq) {
802
+ return backend.eq(this, other);
803
+ }
714
804
  return this.elementWiseABDAG(other, (a, b) => a === b ? 1 : 0);
715
805
  }
716
806
  equal = this.eq;
717
807
  // Tensor element-wise not equality comparison
718
808
  ne(other) {
809
+ // Use backend of tensor's device if available, or else fallback to cpu
810
+ const backend = Tensor.backends.get(this.device);
811
+ if (backend && backend.ne) {
812
+ return backend.ne(this, other);
813
+ }
719
814
  return this.elementWiseABDAG(other, (a, b) => a !== b ? 1 : 0);
720
815
  }
721
816
  notEqual = this.ne;
722
817
  // Tensor element-wise logical and
723
818
  logicalAnd(other) {
819
+ // Use backend of tensor's device if available, or else fallback to cpu
820
+ const backend = Tensor.backends.get(this.device);
821
+ if (backend && backend.logicalAnd) {
822
+ return backend.logicalAnd(this, other);
823
+ }
724
824
  return this.elementWiseABDAG(other, (a, b) => a === 1 && b === 1 ? 1 : 0);
725
825
  }
726
826
  // Tensor element-wise logical or
727
827
  logicalOr(other) {
828
+ // Use backend of tensor's device if available, or else fallback to cpu
829
+ const backend = Tensor.backends.get(this.device);
830
+ if (backend && backend.logicalOr) {
831
+ return backend.logicalOr(this, other);
832
+ }
728
833
  return this.elementWiseABDAG(other, (a, b) => a === 1 || b === 1 ? 1 : 0);
729
834
  }
730
835
  // Tensor element-wise logical xor
731
836
  logicalXor(other) {
837
+ // Use backend of tensor's device if available, or else fallback to cpu
838
+ const backend = Tensor.backends.get(this.device);
839
+ if (backend && backend.logicalXor) {
840
+ return backend.logicalXor(this, other);
841
+ }
732
842
  return this.elementWiseABDAG(other, (a, b) => (a === 1 || b === 1) && a !== b ? 1 : 0);
733
843
  }
734
844
  // Tensor element-wise logical not
735
845
  logicalNot() {
846
+ // Use backend of tensor's device if available, or else fallback to cpu
847
+ const backend = Tensor.backends.get(this.device);
848
+ if (backend && backend.logicalNot) {
849
+ return backend.logicalNot(this);
850
+ }
736
851
  return this.elementWiseSelfDAG((a) => a === 1 ? 0 : 1);
737
852
  }
738
853
  // Tensor element-wise bitwise and
739
854
  bitwiseAnd(other) {
855
+ // Use backend of tensor's device if available, or else fallback to cpu
856
+ const backend = Tensor.backends.get(this.device);
857
+ if (backend && backend.bitwiseAnd) {
858
+ return backend.bitwiseAnd(this, other);
859
+ }
740
860
  return this.elementWiseABDAG(other, (a, b) => a & b);
741
861
  }
742
862
  // Tensor element-wise bitwise or
743
863
  bitwiseOr(other) {
864
+ // Use backend of tensor's device if available, or else fallback to cpu
865
+ const backend = Tensor.backends.get(this.device);
866
+ if (backend && backend.bitwiseOr) {
867
+ return backend.bitwiseOr(this, other);
868
+ }
744
869
  return this.elementWiseABDAG(other, (a, b) => a | b);
745
870
  }
746
871
  // Tensor element-wise bitwise xor
747
872
  bitwiseXor(other) {
873
+ // Use backend of tensor's device if available, or else fallback to cpu
874
+ const backend = Tensor.backends.get(this.device);
875
+ if (backend && backend.bitwiseXor) {
876
+ return backend.bitwiseXor(this, other);
877
+ }
748
878
  return this.elementWiseABDAG(other, (a, b) => a ^ b);
749
879
  }
750
880
  // Tensor element-wise bitwise not
751
881
  bitwiseNot() {
882
+ // Use backend of tensor's device if available, or else fallback to cpu
883
+ const backend = Tensor.backends.get(this.device);
884
+ if (backend && backend.bitwiseNot) {
885
+ return backend.bitwiseNot(this);
886
+ }
752
887
  return this.elementWiseSelfDAG((a) => ~a);
753
888
  }
754
889
  // Tensor element-wise left shift
755
890
  bitwiseLeftShift(other) {
891
+ // Use backend of tensor's device if available, or else fallback to cpu
892
+ const backend = Tensor.backends.get(this.device);
893
+ if (backend && backend.bitwiseLeftShift) {
894
+ return backend.bitwiseLeftShift(this, other);
895
+ }
756
896
  return this.elementWiseABDAG(other, (a, b) => a << b);
757
897
  }
758
898
  // Tensor element-wise right shift
759
899
  bitwiseRightShift(other) {
900
+ // Use backend of tensor's device if available, or else fallback to cpu
901
+ const backend = Tensor.backends.get(this.device);
902
+ if (backend && backend.bitwiseRightShift) {
903
+ return backend.bitwiseRightShift(this, other);
904
+ }
760
905
  return this.elementWiseABDAG(other, (a, b) => a >> b);
761
906
  }
762
907
  // Tensor element-wise negation
763
908
  neg() {
909
+ // Use backend of tensor's device if available, or else fallback to cpu
910
+ const backend = Tensor.backends.get(this.device);
911
+ if (backend && backend.neg) {
912
+ return backend.neg(this);
913
+ }
764
914
  return this.elementWiseSelfDAG((a) => -a, (self, outGrad) => outGrad.mul(-1));
765
915
  }
766
916
  negative = this.neg;
767
917
  // Tensor element-wise reciprocal
768
918
  reciprocal() {
919
+ // Use backend of tensor's device if available, or else fallback to cpu
920
+ const backend = Tensor.backends.get(this.device);
921
+ if (backend && backend.reciprocal) {
922
+ return backend.reciprocal(this);
923
+ }
769
924
  return this.elementWiseSelfDAG((a) => 1 / a, (self, outGrad) => outGrad.mul(self.pow(-2).neg()));
770
925
  }
771
926
  // Tensor element-wise square
772
927
  square() {
928
+ // Use backend of tensor's device if available, or else fallback to cpu
929
+ const backend = Tensor.backends.get(this.device);
930
+ if (backend && backend.square) {
931
+ return backend.square(this);
932
+ }
773
933
  return this.elementWiseSelfDAG((a) => a * a, (self, outGrad) => outGrad.mul(self.mul(2)));
774
934
  }
775
935
  // Tensor element-wise absolute
776
936
  abs() {
937
+ // Use backend of tensor's device if available, or else fallback to cpu
938
+ const backend = Tensor.backends.get(this.device);
939
+ if (backend && backend.abs) {
940
+ return backend.abs(this);
941
+ }
777
942
  return this.elementWiseSelfDAG((a) => Math.abs(a), (self, outGrad) => outGrad.mul(self.sign()));
778
943
  }
779
944
  absolute = this.abs;
780
945
  // Tensor element-wise sign function
781
946
  sign() {
947
+ // Use backend of tensor's device if available, or else fallback to cpu
948
+ const backend = Tensor.backends.get(this.device);
949
+ if (backend && backend.sign) {
950
+ return backend.sign(this);
951
+ }
782
952
  return this.elementWiseSelfDAG((a) => Math.sign(a));
783
953
  }
784
954
  // Tensor element-wise sin
785
955
  sin() {
956
+ // Use backend of tensor's device if available, or else fallback to cpu
957
+ const backend = Tensor.backends.get(this.device);
958
+ if (backend && backend.sin) {
959
+ return backend.sin(this);
960
+ }
786
961
  return this.elementWiseSelfDAG((a) => Math.sin(a), (self, outGrad) => outGrad.mul(self.cos()));
787
962
  }
788
963
  // Tensor element-wise cos
789
964
  cos() {
965
+ // Use backend of tensor's device if available, or else fallback to cpu
966
+ const backend = Tensor.backends.get(this.device);
967
+ if (backend && backend.cos) {
968
+ return backend.cos(this);
969
+ }
790
970
  return this.elementWiseSelfDAG((a) => Math.cos(a), (self, outGrad) => outGrad.mul(self.sin().neg()));
791
971
  }
792
972
  // Tensor element-wise tan
793
973
  tan() {
974
+ // Use backend of tensor's device if available, or else fallback to cpu
975
+ const backend = Tensor.backends.get(this.device);
976
+ if (backend && backend.tan) {
977
+ return backend.tan(this);
978
+ }
794
979
  return this.elementWiseSelfDAG((a) => Math.tan(a), (self, outGrad) => outGrad.mul(self.tan().square().add(1)));
795
980
  }
796
981
  // Tensor element-wise asin
797
982
  asin() {
983
+ // Use backend of tensor's device if available, or else fallback to cpu
984
+ const backend = Tensor.backends.get(this.device);
985
+ if (backend && backend.asin) {
986
+ return backend.asin(this);
987
+ }
798
988
  return this.elementWiseSelfDAG((a) => Math.asin(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()));
799
989
  }
800
990
  arcsin = this.asin;
801
991
  // Tensor element-wise acos
802
992
  acos() {
993
+ // Use backend of tensor's device if available, or else fallback to cpu
994
+ const backend = Tensor.backends.get(this.device);
995
+ if (backend && backend.acos) {
996
+ return backend.acos(this);
997
+ }
803
998
  return this.elementWiseSelfDAG((a) => Math.acos(a), (self, outGrad) => outGrad.div(self.square().neg().add(1).sqrt()).neg());
804
999
  }
805
1000
  arccos = this.acos;
806
1001
  // Tensor element-wise atan
807
1002
  atan() {
1003
+ // Use backend of tensor's device if available, or else fallback to cpu
1004
+ const backend = Tensor.backends.get(this.device);
1005
+ if (backend && backend.atan) {
1006
+ return backend.atan(this);
1007
+ }
808
1008
  return this.elementWiseSelfDAG((a) => Math.atan(a), (self, outGrad) => outGrad.div(self.square().add(1)));
809
1009
  }
810
1010
  arctan = this.atan;
811
1011
  // Tensor element-wise atan2
812
1012
  atan2(other) {
1013
+ // Use backend of tensor's device if available, or else fallback to cpu
1014
+ const backend = Tensor.backends.get(this.device);
1015
+ if (backend && backend.atan2) {
1016
+ return backend.atan2(this, other);
1017
+ }
813
1018
  return this.elementWiseABDAG(other, (a, b) => Math.atan2(a, b), (self, other, outGrad) => outGrad.mul(other.div(self.square().add(other.square()))), (self, other, outGrad) => outGrad.mul(self.neg().div(self.square().add(other.square()))));
814
1019
  }
815
1020
  arctan2 = this.atan2;
816
1021
  // Tensor element-wise sinh
817
1022
  sinh() {
1023
+ // Use backend of tensor's device if available, or else fallback to cpu
1024
+ const backend = Tensor.backends.get(this.device);
1025
+ if (backend && backend.sinh) {
1026
+ return backend.sinh(this);
1027
+ }
818
1028
  return this.elementWiseSelfDAG((a) => Math.sinh(a), (self, outGrad) => outGrad.mul(self.cosh()));
819
1029
  }
820
1030
  // Tensor element-wise cosh
821
1031
  cosh() {
1032
+ // Use backend of tensor's device if available, or else fallback to cpu
1033
+ const backend = Tensor.backends.get(this.device);
1034
+ if (backend && backend.cosh) {
1035
+ return backend.cosh(this);
1036
+ }
822
1037
  return this.elementWiseSelfDAG((a) => Math.cosh(a), (self, outGrad) => outGrad.mul(self.sinh()));
823
1038
  }
824
1039
  // Tensor element-wise asinh
825
1040
  asinh() {
1041
+ // Use backend of tensor's device if available, or else fallback to cpu
1042
+ const backend = Tensor.backends.get(this.device);
1043
+ if (backend && backend.asinh) {
1044
+ return backend.asinh(this);
1045
+ }
826
1046
  return this.elementWiseSelfDAG((a) => Math.asinh(a), (self, outGrad) => outGrad.div(self.square().add(1).sqrt()));
827
1047
  }
828
1048
  arcsinh = this.asinh;
829
1049
  // Tensor element-wise acosh
830
1050
  acosh() {
1051
+ // Use backend of tensor's device if available, or else fallback to cpu
1052
+ const backend = Tensor.backends.get(this.device);
1053
+ if (backend && backend.acosh) {
1054
+ return backend.acosh(this);
1055
+ }
831
1056
  return this.elementWiseSelfDAG((a) => Math.acosh(a), (self, outGrad) => outGrad.div(self.add(1).sqrt().mul(self.sub(1).sqrt())));
832
1057
  }
833
1058
  arccosh = this.acosh;
834
1059
  // Tensor element-wise atanh
835
1060
  atanh() {
1061
+ // Use backend of tensor's device if available, or else fallback to cpu
1062
+ const backend = Tensor.backends.get(this.device);
1063
+ if (backend && backend.atanh) {
1064
+ return backend.atanh(this);
1065
+ }
836
1066
  return this.elementWiseSelfDAG((a) => Math.atanh(a), (self, outGrad) => outGrad.div(self.square().neg().add(1)));
837
1067
  }
838
1068
  arctanh = this.atanh;
839
1069
  // Tensor element-wise degree to radian
840
1070
  deg2rad() {
1071
+ // Use backend of tensor's device if available, or else fallback to cpu
1072
+ const backend = Tensor.backends.get(this.device);
1073
+ if (backend && backend.deg2rad) {
1074
+ return backend.deg2rad(this);
1075
+ }
841
1076
  return this.elementWiseSelfDAG((a) => a * (Math.PI / 180), (self, outGrad) => outGrad.mul(Math.PI / 180));
842
1077
  }
843
1078
  // Tensor element-wise radian to degree
844
1079
  rad2deg() {
1080
+ // Use backend of tensor's device if available, or else fallback to cpu
1081
+ const backend = Tensor.backends.get(this.device);
1082
+ if (backend && backend.rad2deg) {
1083
+ return backend.rad2deg(this);
1084
+ }
845
1085
  return this.elementWiseSelfDAG((a) => a / (Math.PI / 180), (self, outGrad) => outGrad.div(Math.PI / 180));
846
1086
  }
847
1087
  // Tensor element-wise square root
848
1088
  sqrt() {
1089
+ // Use backend of tensor's device if available, or else fallback to cpu
1090
+ const backend = Tensor.backends.get(this.device);
1091
+ if (backend && backend.sqrt) {
1092
+ return backend.sqrt(this);
1093
+ }
849
1094
  return this.elementWiseSelfDAG((a) => Math.sqrt(a), (self, outGrad) => outGrad.div(self.sqrt().mul(2)));
850
1095
  }
851
1096
  // Tensor element-wise reciprocal of square root
852
1097
  rsqrt() {
1098
+ // Use backend of tensor's device if available, or else fallback to cpu
1099
+ const backend = Tensor.backends.get(this.device);
1100
+ if (backend && backend.rsqrt) {
1101
+ return backend.rsqrt(this);
1102
+ }
853
1103
  return this.elementWiseSelfDAG((a) => 1 / Math.sqrt(a), (self, outGrad) => outGrad.mul(self.pow(-1.5).mul(-0.5)));
854
1104
  }
855
1105
  // Tensor element-wise e^x
856
1106
  exp() {
1107
+ // Use backend of tensor's device if available, or else fallback to cpu
1108
+ const backend = Tensor.backends.get(this.device);
1109
+ if (backend && backend.exp) {
1110
+ return backend.exp(this);
1111
+ }
857
1112
  return this.elementWiseSelfDAG((a) => Math.exp(a), (self, outGrad) => outGrad.mul(self.exp()));
858
1113
  }
859
1114
  // Tensor element-wise 2^x
860
1115
  exp2() {
1116
+ // Use backend of tensor's device if available, or else fallback to cpu
1117
+ const backend = Tensor.backends.get(this.device);
1118
+ if (backend && backend.exp2) {
1119
+ return backend.exp2(this);
1120
+ }
861
1121
  return this.elementWiseSelfDAG((a) => 2 ** a, (self, outGrad) => outGrad.mul(self.exp2().mul(Math.log(2))));
862
1122
  }
863
1123
  // Tensor element-wise e^x - 1
864
1124
  expm1() {
1125
+ // Use backend of tensor's device if available, or else fallback to cpu
1126
+ const backend = Tensor.backends.get(this.device);
1127
+ if (backend && backend.expm1) {
1128
+ return backend.expm1(this);
1129
+ }
865
1130
  return this.elementWiseSelfDAG((a) => Math.expm1(a), (self, outGrad) => outGrad.mul(self.exp()));
866
1131
  }
867
1132
  // Tensor element-wise natural log
868
1133
  log() {
1134
+ // Use backend of tensor's device if available, or else fallback to cpu
1135
+ const backend = Tensor.backends.get(this.device);
1136
+ if (backend && backend.log) {
1137
+ return backend.log(this);
1138
+ }
869
1139
  return this.elementWiseSelfDAG((a) => Math.log(a), (self, outGrad) => outGrad.div(self));
870
1140
  }
871
1141
  // Tensor element-wise log2
872
1142
  log2() {
1143
+ // Use backend of tensor's device if available, or else fallback to cpu
1144
+ const backend = Tensor.backends.get(this.device);
1145
+ if (backend && backend.log2) {
1146
+ return backend.log2(this);
1147
+ }
873
1148
  return this.elementWiseSelfDAG((a) => Math.log2(a), (self, outGrad) => outGrad.div(self.mul(Math.log(2))));
874
1149
  }
875
1150
  // Tensor element-wise log10
876
1151
  log10() {
1152
+ // Use backend of tensor's device if available, or else fallback to cpu
1153
+ const backend = Tensor.backends.get(this.device);
1154
+ if (backend && backend.log10) {
1155
+ return backend.log10(this);
1156
+ }
877
1157
  return this.elementWiseSelfDAG((a) => Math.log10(a), (self, outGrad) => outGrad.div(self.mul(Math.log(10))));
878
1158
  }
879
1159
  // Tensor element-wise log(1+x)
880
1160
  log1p() {
1161
+ // Use backend of tensor's device if available, or else fallback to cpu
1162
+ const backend = Tensor.backends.get(this.device);
1163
+ if (backend && backend.log1p) {
1164
+ return backend.log1p(this);
1165
+ }
881
1166
  return this.elementWiseSelfDAG((a) => Math.log1p(a), (self, outGrad) => outGrad.div(self.add(1)));
882
1167
  }
883
1168
  // Tensor element-wise relu
884
1169
  relu() {
1170
+ // Use backend of tensor's device if available, or else fallback to cpu
1171
+ const backend = Tensor.backends.get(this.device);
1172
+ if (backend && backend.relu) {
1173
+ return backend.relu(this);
1174
+ }
885
1175
  return this.elementWiseSelfDAG((a) => Math.max(a, 0), (self, outGrad) => outGrad.mul(self.gt(0)));
886
1176
  }
887
1177
  // Tensor element-wise sigmoid
888
1178
  sigmoid() {
1179
+ // Use backend of tensor's device if available, or else fallback to cpu
1180
+ const backend = Tensor.backends.get(this.device);
1181
+ if (backend && backend.sigmoid) {
1182
+ return backend.sigmoid(this);
1183
+ }
889
1184
  return this.elementWiseSelfDAG((a) => 1 / (1 + Math.exp(-a)), (self, outGrad) => {
890
1185
  const sig = self.sigmoid();
891
1186
  return outGrad.mul(sig).mul(sig.neg().add(1));
@@ -893,18 +1188,38 @@ class Tensor {
893
1188
  }
894
1189
  // Tensor element-wise tanh
895
1190
  tanh() {
1191
+ // Use backend of tensor's device if available, or else fallback to cpu
1192
+ const backend = Tensor.backends.get(this.device);
1193
+ if (backend && backend.tanh) {
1194
+ return backend.tanh(this);
1195
+ }
896
1196
  return this.elementWiseSelfDAG((a) => Math.tanh(a), (self, outGrad) => outGrad.mul(self.tanh().square().neg().add(1)));
897
1197
  }
898
1198
  // Tensor element-wise softplus
899
1199
  softplus() {
1200
+ // Use backend of tensor's device if available, or else fallback to cpu
1201
+ const backend = Tensor.backends.get(this.device);
1202
+ if (backend && backend.softplus) {
1203
+ return backend.softplus(this);
1204
+ }
900
1205
  return this.elementWiseSelfDAG((a) => Math.log1p(Math.exp(a)), (self, outGrad) => outGrad.mul(self.sigmoid()));
901
1206
  }
902
1207
  // Tensor element-wise softsign
903
1208
  softsign() {
1209
+ // Use backend of tensor's device if available, or else fallback to cpu
1210
+ const backend = Tensor.backends.get(this.device);
1211
+ if (backend && backend.softsign) {
1212
+ return backend.softsign(this);
1213
+ }
904
1214
  return this.elementWiseSelfDAG((a) => a / (1 + Math.abs(a)), (self, outGrad) => outGrad.div(self.abs().add(1).square()));
905
1215
  }
906
1216
  // Tensor element-wise silu (swish)
907
1217
  silu() {
1218
+ // Use backend of tensor's device if available, or else fallback to cpu
1219
+ const backend = Tensor.backends.get(this.device);
1220
+ if (backend && backend.silu) {
1221
+ return backend.silu(this);
1222
+ }
908
1223
  return this.elementWiseSelfDAG((a) => a / (1 + Math.exp(-a)), (self, outGrad) => {
909
1224
  const sig = self.sigmoid();
910
1225
  return outGrad.mul(sig.add(self.mul(sig).mul(sig.neg().add(1))));
@@ -912,6 +1227,11 @@ class Tensor {
912
1227
  }
913
1228
  // Tensor element-wise mish
914
1229
  mish() {
1230
+ // Use backend of tensor's device if available, or else fallback to cpu
1231
+ const backend = Tensor.backends.get(this.device);
1232
+ if (backend && backend.mish) {
1233
+ return backend.mish(this);
1234
+ }
915
1235
  return this.elementWiseSelfDAG((a) => a * Math.tanh(Math.log1p(Math.exp(a))), (self, outGrad) => {
916
1236
  const tanhSoftPlus = self.exp().add(1).log().tanh();
917
1237
  // tanh(softplus(x)) + x * (1 - tanh²(softplus(x))) * sigmoid(x)
@@ -921,48 +1241,103 @@ class Tensor {
921
1241
  }
922
1242
  // Tensor element-wise maximum
923
1243
  maximum(other) {
1244
+ // Use backend of tensor's device if available, or else fallback to cpu
1245
+ const backend = Tensor.backends.get(this.device);
1246
+ if (backend && backend.maximum) {
1247
+ return backend.maximum(this, other);
1248
+ }
924
1249
  return this.elementWiseABDAG(other, (a, b) => Math.max(a, b), (self, other, outGrad) => outGrad.mul(self.gt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.gt(self).add(other.eq(self).mul(0.5))));
925
1250
  }
926
1251
  // Tensor element-wise minimum
927
1252
  minimum(other) {
1253
+ // Use backend of tensor's device if available, or else fallback to cpu
1254
+ const backend = Tensor.backends.get(this.device);
1255
+ if (backend && backend.minimum) {
1256
+ return backend.minimum(this, other);
1257
+ }
928
1258
  return this.elementWiseABDAG(other, (a, b) => Math.min(a, b), (self, other, outGrad) => outGrad.mul(self.lt(other).add(self.eq(other).mul(0.5))), (self, other, outGrad) => outGrad.mul(other.lt(self).add(other.eq(self).mul(0.5))));
929
1259
  }
930
1260
  // Tensor element-wise round
931
1261
  round() {
1262
+ // Use backend of tensor's device if available, or else fallback to cpu
1263
+ const backend = Tensor.backends.get(this.device);
1264
+ if (backend && backend.round) {
1265
+ return backend.round(this);
1266
+ }
932
1267
  return this.elementWiseSelfDAG((a) => Math.round(a));
933
1268
  }
934
1269
  // Tensor element-wise floor
935
1270
  floor() {
1271
+ // Use backend of tensor's device if available, or else fallback to cpu
1272
+ const backend = Tensor.backends.get(this.device);
1273
+ if (backend && backend.floor) {
1274
+ return backend.floor(this);
1275
+ }
936
1276
  return this.elementWiseSelfDAG((a) => Math.floor(a));
937
1277
  }
938
1278
  // Tensor element-wise ceil
939
1279
  ceil() {
1280
+ // Use backend of tensor's device if available, or else fallback to cpu
1281
+ const backend = Tensor.backends.get(this.device);
1282
+ if (backend && backend.ceil) {
1283
+ return backend.ceil(this);
1284
+ }
940
1285
  return this.elementWiseSelfDAG((a) => Math.ceil(a));
941
1286
  }
942
1287
  // Tensor element-wise truncation
943
1288
  trunc() {
1289
+ // Use backend of tensor's device if available, or else fallback to cpu
1290
+ const backend = Tensor.backends.get(this.device);
1291
+ if (backend && backend.trunc) {
1292
+ return backend.trunc(this);
1293
+ }
944
1294
  return this.elementWiseSelfDAG((a) => Math.trunc(a));
945
1295
  }
946
1296
  fix = this.trunc;
947
1297
  // Tensor element-wise fraction portion
948
1298
  frac() {
1299
+ // Use backend of tensor's device if available, or else fallback to cpu
1300
+ const backend = Tensor.backends.get(this.device);
1301
+ if (backend && backend.frac) {
1302
+ return backend.frac(this);
1303
+ }
949
1304
  return this.elementWiseSelfDAG((a) => a - Math.floor(a));
950
1305
  }
951
1306
  // Tensor element-wise clip and clamp
952
1307
  clip(min, max) {
1308
+ // Use backend of tensor's device if available, or else fallback to cpu
1309
+ const backend = Tensor.backends.get(this.device);
1310
+ if (backend && backend.clip) {
1311
+ return backend.clip(this, min, max);
1312
+ }
953
1313
  return this.elementWiseSelfDAG((a) => Math.max(min, Math.min(max, a)), (self, outGrad) => outGrad.mul(self.ge(min).mul(self.le(max))));
954
1314
  }
955
1315
  clamp = this.clip;
956
1316
  // Tensor element-wise error function
957
1317
  erf() {
1318
+ // Use backend of tensor's device if available, or else fallback to cpu
1319
+ const backend = Tensor.backends.get(this.device);
1320
+ if (backend && backend.erf) {
1321
+ return backend.erf(this);
1322
+ }
958
1323
  return this.elementWiseSelfDAG((a) => (0, utils_1.erf)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI))));
959
1324
  }
960
1325
  // Tensor element-wise complementary error function
961
1326
  erfc() {
1327
+ // Use backend of tensor's device if available, or else fallback to cpu
1328
+ const backend = Tensor.backends.get(this.device);
1329
+ if (backend && backend.erfc) {
1330
+ return backend.erfc(this);
1331
+ }
962
1332
  return this.elementWiseSelfDAG((a) => (0, utils_1.erfc)(a), (self, outGrad) => outGrad.mul(self.square().neg().exp().mul(2 / Math.sqrt(Math.PI)).neg()));
963
1333
  }
964
1334
  // Tensor element-wise inverse error function
965
1335
  erfinv() {
1336
+ // Use backend of tensor's device if available, or else fallback to cpu
1337
+ const backend = Tensor.backends.get(this.device);
1338
+ if (backend && backend.erfinv) {
1339
+ return backend.erfinv(this);
1340
+ }
966
1341
  return this.elementWiseSelfDAG((a) => (0, utils_1.erfinv)(a), (self, outGrad) => outGrad.mul(self.erfinv().square().exp().mul(Math.sqrt(Math.PI) / 2)));
967
1342
  }
968
1343
  // Transpose
@@ -981,7 +1356,7 @@ class Tensor {
981
1356
  [newShape[dim1], newShape[dim2]] = [newShape[dim2], newShape[dim1]];
982
1357
  [newStrides[dim1], newStrides[dim2]] = [newStrides[dim2], newStrides[dim1]];
983
1358
  // Create new tensor with same data but swapped shape/strides
984
- const out = new Tensor(this.value, { shape: newShape, strides: newStrides });
1359
+ const out = new Tensor(this.value, { shape: newShape, strides: newStrides, device: this.device });
985
1360
  out.requiresGrad = this.requiresGrad;
986
1361
  // Handle gradient if needed
987
1362
  if (this.requiresGrad) {
@@ -1004,6 +1379,11 @@ class Tensor {
1004
1379
  }
1005
1380
  // 1D tensor dot product
1006
1381
  dot(other) {
1382
+ // Use backend of tensor's device if available, or else fallback to cpu
1383
+ const backend = Tensor.backends.get(this.device);
1384
+ if (backend && backend.dot) {
1385
+ return backend.dot(this, other);
1386
+ }
1007
1387
  other = Tensor.forceTensor(other);
1008
1388
  // Verify 1D shape
1009
1389
  if (this.shape.length !== 1 || other.shape.length !== 1) {
@@ -1042,6 +1422,11 @@ class Tensor {
1042
1422
  }
1043
1423
  // Matrix multiplication
1044
1424
  mm(other) {
1425
+ // Use backend of tensor's device if available, or else fallback to cpu
1426
+ const backend = Tensor.backends.get(this.device);
1427
+ if (backend && backend.mm) {
1428
+ return backend.mm(this, other);
1429
+ }
1045
1430
  other = Tensor.forceTensor(other);
1046
1431
  // Verify 2D shape
1047
1432
  if (this.shape.length !== 2 || other.shape.length !== 2) {
@@ -1097,6 +1482,11 @@ class Tensor {
1097
1482
  }
1098
1483
  // Batched 3D tensor matmul
1099
1484
  bmm(other) {
1485
+ // Use backend of tensor's device if available, or else fallback to cpu
1486
+ const backend = Tensor.backends.get(this.device);
1487
+ if (backend && backend.bmm) {
1488
+ return backend.bmm(this, other);
1489
+ }
1100
1490
  other = Tensor.forceTensor(other);
1101
1491
  // Verify 3D shape
1102
1492
  if (this.shape.length !== 3 || other.shape.length !== 3 || this.shape[0] !== other.shape[0]) {
@@ -1155,40 +1545,25 @@ class Tensor {
1155
1545
  }
1156
1546
  // Convert right-side 1D tensor to a vector (nx1 tensor) to do matmul
1157
1547
  mv(other) {
1548
+ // Use backend of tensor's device if available, or else fallback to cpu
1549
+ const backend = Tensor.backends.get(this.device);
1550
+ if (backend && backend.mv) {
1551
+ return backend.mv(this, other);
1552
+ }
1158
1553
  other = Tensor.forceTensor(other);
1159
1554
  // Verify 2D shape
1160
1555
  if (this.shape.length !== 2 || other.shape.length !== 1) {
1161
1556
  throw new Error("Input is not a 2D and 1D tensor pair");
1162
1557
  }
1163
- // MM with no grad
1164
- const thisMat = new Tensor(this.value, { shape: this.shape, strides: this.strides });
1165
- const otherMat = new Tensor(other.value, { shape: [other.shape[0], 1], strides: [other.strides[0], 1] });
1166
- const out = thisMat.mm(otherMat).squeeze(1);
1167
- // Handle grad with original tensors
1168
- if (this.requiresGrad) {
1169
- out.requiresGrad = true;
1170
- out.children.push(this);
1171
- }
1172
- if (other.requiresGrad) {
1173
- out.requiresGrad = true;
1174
- out.children.push(other);
1175
- }
1176
- if (out.requiresGrad) {
1177
- out.gradFn = () => {
1178
- // Disable gradient collecting of gradients themselves
1179
- const outGrad = out.grad.withGrad(false);
1180
- const selfNoGrad = this.withGrad(false);
1181
- const otherNoGrad = other.withGrad(false);
1182
- if (this.requiresGrad)
1183
- Tensor.addGrad(this, outGrad.unsqueeze(1).mm(otherNoGrad.unsqueeze(0)));
1184
- if (other.requiresGrad)
1185
- Tensor.addGrad(other, selfNoGrad.t().mv(outGrad));
1186
- };
1187
- }
1188
- return out;
1558
+ return this.mm(other.unsqueeze(1)).squeeze(1);
1189
1559
  }
1190
1560
  // General matrix multiplication with different shapes
1191
1561
  matmul(other) {
1562
+ // Use backend of tensor's device if available, or else fallback to cpu
1563
+ const backend = Tensor.backends.get(this.device);
1564
+ if (backend && backend.matmul) {
1565
+ return backend.matmul(this, other);
1566
+ }
1192
1567
  other = Tensor.forceTensor(other);
1193
1568
  const isThis1D = this.shape.length === 1;
1194
1569
  const isOther1D = other.shape.length === 1;
@@ -1461,6 +1836,7 @@ class Tensor {
1461
1836
  return new Tensor(this.value, {
1462
1837
  shape: this.shape,
1463
1838
  strides: this.strides,
1839
+ device: this.device,
1464
1840
  requiresGrad
1465
1841
  });
1466
1842
  }
@@ -1469,6 +1845,7 @@ class Tensor {
1469
1845
  return new Tensor(this.value, {
1470
1846
  shape: this.shape,
1471
1847
  strides: this.strides,
1848
+ device: this.device,
1472
1849
  requiresGrad: false
1473
1850
  });
1474
1851
  }
@@ -1493,5 +1870,15 @@ class Tensor {
1493
1870
  this.value = other.value;
1494
1871
  return this;
1495
1872
  }
1873
+ // Holds all available backends
1874
+ static backends = new Map();
1875
+ // Op to transfer tensor to another device
1876
+ to(device) {
1877
+ const backend = Tensor.backends.get(device);
1878
+ if (backend && backend.to) {
1879
+ return backend.to(this);
1880
+ }
1881
+ throw new Error(`No device found to transfer tensor to or "to" is not implemented for device.`);
1882
+ }
1496
1883
  }
1497
1884
  exports.Tensor = Tensor;
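
Every op in `core.js` now starts with the same dispatch check: look up the tensor's `device` in `Tensor.backends`, delegate if that backend implements the op, and otherwise run the existing built-in code (the `"cpu"` default). A self-contained sketch of that routing; the `"mydevice"` backend and its `sum` hook are hypothetical stand-ins, not a real accelerated kernel.

```js
const { Tensor } = require("catniff");

// Hypothetical backend that only overrides sum(); every other op on this
// device still falls through to the built-in implementation.
Tensor.backends.set("mydevice", {
  sum(tensor, dims, keepDims) {
    console.log("sum dispatched to mydevice");
    // Delegate to the built-in path by re-tagging the data as a cpu tensor.
    // A real backend would run its own kernel; autograd wiring is omitted here.
    const onCpu = new Tensor(tensor.value, {
      shape: tensor.shape,
      strides: tensor.strides,
      device: "cpu"
    });
    return onCpu.sum(dims, keepDims);
  }
});

// The new `device` option tags the tensor; ops then consult Tensor.backends
const x = new Tensor([[1, 2], [3, 4]], { device: "mydevice" });
console.log(x.sum().val());   // 10, routed through the backend hook
console.log(x.add(1).val());  // [[2, 3], [4, 5]], built-in fallback (no add hook)
```

The result of the hooked `sum` here is still a CPU tensor; a real backend would tag its outputs with its own device so follow-up ops keep dispatching to it.
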
package/dist/optim.d.ts CHANGED
@@ -17,7 +17,26 @@ declare class SGD {
17
17
  constructor(params: Tensor[], options?: SGDOptions);
18
18
  step(): void;
19
19
  }
20
+ export interface AdamOptions {
21
+ lr?: number;
22
+ betas?: [number, number];
23
+ eps?: number;
24
+ weightDecay?: number;
25
+ }
26
+ declare class Adam {
27
+ params: Tensor[];
28
+ momentumBuffers: Map<Tensor, Tensor>;
29
+ velocityBuffers: Map<Tensor, Tensor>;
30
+ stepCount: number;
31
+ lr: number;
32
+ betas: [number, number];
33
+ eps: number;
34
+ weightDecay: number;
35
+ constructor(params: Tensor[], options?: AdamOptions);
36
+ step(): void;
37
+ }
20
38
  export declare class Optim {
21
39
  static SGD: typeof SGD;
40
+ static Adam: typeof Adam;
22
41
  }
23
42
  export {};
package/dist/optim.js CHANGED
@@ -1,6 +1,7 @@
1
1
  "use strict";
2
2
  Object.defineProperty(exports, "__esModule", { value: true });
3
3
  exports.Optim = void 0;
4
+ const core_1 = require("./core");
4
5
  class SGD {
5
6
  params;
6
7
  momentumBuffers = new Map();
@@ -19,7 +20,7 @@ class SGD {
19
20
  }
20
21
  step() {
21
22
  for (const param of this.params) {
22
- if (typeof param.grad === "undefined") {
23
+ if (!param.grad) {
23
24
  throw new Error("Can not apply SGD on empty grad");
24
25
  }
25
26
  let grad = param.grad.detach(), detachedParam = param.detach();
@@ -55,7 +56,70 @@ class SGD {
55
56
  }
56
57
  }
57
58
  }
59
+ class Adam {
60
+ params;
61
+ momentumBuffers = new Map(); // First moment (m_t)
62
+ velocityBuffers = new Map(); // Second moment (v_t)
63
+ stepCount = 0;
64
+ lr;
65
+ betas;
66
+ eps;
67
+ weightDecay;
68
+ constructor(params, options) {
69
+ this.params = params;
70
+ this.lr = options?.lr || 0.001;
71
+ this.betas = options?.betas || [0.9, 0.999];
72
+ this.eps = options?.eps || 1e-8;
73
+ this.weightDecay = options?.weightDecay || 0;
74
+ }
75
+ step() {
76
+ this.stepCount++;
77
+ const beta1 = this.betas[0];
78
+ const beta2 = this.betas[1];
79
+ // Bias correction factors
80
+ const biasCorrection1 = 1 - Math.pow(beta1, this.stepCount);
81
+ const biasCorrection2 = 1 - Math.pow(beta2, this.stepCount);
82
+ for (const param of this.params) {
83
+ if (!param.grad) {
84
+ throw new Error("Can not apply Adam on empty grad");
85
+ }
86
+ let grad = param.grad.detach(), detachedParam = param.detach();
87
+ // Apply weight decay (L2 regularization)
88
+ if (this.weightDecay !== 0) {
89
+ grad = grad.add(detachedParam.mul(this.weightDecay));
90
+ }
91
+ // Get or initialize first moment buffer (momentum)
92
+ let momentumBuffer = this.momentumBuffers.get(param);
93
+ if (!momentumBuffer) {
94
+ momentumBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
95
+ this.momentumBuffers.set(param, momentumBuffer);
96
+ }
97
+ // Get or initialize second moment buffer (velocity)
98
+ let velocityBuffer = this.velocityBuffers.get(param);
99
+ if (!velocityBuffer) {
100
+ velocityBuffer = core_1.Tensor.zerosLike(grad); // Initialize with zeros (same shape as grad)
101
+ this.velocityBuffers.set(param, velocityBuffer);
102
+ }
103
+ // Update biased first moment estimate: m_t = β1 * m_{t-1} + (1 - β1) * g_t
104
+ momentumBuffer = momentumBuffer.mul(beta1).add(grad.mul(1 - beta1));
105
+ this.momentumBuffers.set(param, momentumBuffer);
106
+ // Update biased second moment estimate: v_t = β2 * v_{t-1} + (1 - β2) * g_t^2
107
+ velocityBuffer = velocityBuffer.mul(beta2).add(grad.pow(2).mul(1 - beta2));
108
+ this.velocityBuffers.set(param, velocityBuffer);
109
+ // Compute bias-corrected first moment: m̂_t = m_t / (1 - β1^t)
110
+ const correctedMomentum = momentumBuffer.div(biasCorrection1);
111
+ // Compute bias-corrected second moment: v̂_t = v_t / (1 - β2^t)
112
+ const correctedVelocity = velocityBuffer.div(biasCorrection2);
113
+ // Update parameters: θ_t = θ_{t-1} - α * m̂_t / (√v̂_t + ε)
114
+ const denom = correctedVelocity.sqrt().add(this.eps);
115
+ const stepSize = correctedMomentum.div(denom).mul(this.lr);
116
+ const newParam = detachedParam.sub(stepSize);
117
+ param.replace(newParam);
118
+ }
119
+ }
120
+ }
58
121
  class Optim {
59
122
  static SGD = SGD;
123
+ static Adam = Adam;
60
124
  }
61
125
  exports.Optim = Optim;
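
`optim.js` now ships an Adam implementation next to SGD: per-parameter first and second moment buffers, bias correction, optional L2 weight decay folded into the gradient, and the update θ_t = θ_{t-1} - lr * m̂_t / (√v̂_t + ε). A short usage sketch with the hyperparameters spelled out; the loss is a toy expression and the single backward/step pass mirrors the README example.

```js
const { Tensor, Optim } = require("catniff");

// Two toy parameters and a toy loss: L = (w1 - 3)^2 + (w2 + 1)^2
const w1 = new Tensor([1.0], { requiresGrad: true });
const w2 = new Tensor([2.0], { requiresGrad: true });
const loss = w1.sub(3).pow(2).add(w2.add(1).pow(2)).sum();

loss.backward();

// All options mirror AdamOptions in optim.d.ts; these are the defaults
// except for the learning rate.
const optim = new Optim.Adam([w1, w2], {
  lr: 0.01,
  betas: [0.9, 0.999],
  eps: 1e-8,
  weightDecay: 0
});

optim.step();
console.log(w1.val(), w2.val()); // nudged toward 3 and -1 respectively
```
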
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "catniff",
3
- "version": "0.3.0",
3
+ "version": "0.4.0",
4
4
  "description": "A small Torch-like deep learning framework for Javascript with tensor and autograd support",
5
5
  "main": "index.js",
6
6
  "scripts": {