@thi.ng/tensors 0.1.0 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. package/CHANGELOG.md +26 -1
  2. package/README.md +93 -34
  3. package/abs.d.ts +1 -29
  4. package/abs.js +2 -11
  5. package/add.d.ts +4 -35
  6. package/add.js +2 -11
  7. package/addn.d.ts +1 -33
  8. package/addn.js +2 -11
  9. package/api.d.ts +119 -15
  10. package/broadcast.d.ts +24 -0
  11. package/broadcast.js +54 -0
  12. package/clamp.d.ts +4 -38
  13. package/clamp.js +2 -11
  14. package/clampn.d.ts +1 -37
  15. package/clampn.js +2 -11
  16. package/cos.d.ts +1 -29
  17. package/cos.js +2 -11
  18. package/defopbtt.d.ts +11 -0
  19. package/defopbtt.js +153 -0
  20. package/defopn.d.ts +4 -7
  21. package/defopn.js +17 -7
  22. package/defoprt.d.ts +4 -6
  23. package/defoprt.js +17 -7
  24. package/defoprtt.d.ts +5 -6
  25. package/defoprtt.js +22 -20
  26. package/defopt.d.ts +4 -7
  27. package/defopt.js +17 -20
  28. package/defoptn.d.ts +4 -7
  29. package/defoptn.js +17 -20
  30. package/defoptnn.d.ts +4 -7
  31. package/defoptnn.js +20 -16
  32. package/defoptt.d.ts +5 -6
  33. package/defoptt.js +36 -27
  34. package/defopttt.d.ts +5 -7
  35. package/defopttt.js +43 -37
  36. package/diagonal.d.ts +16 -0
  37. package/diagonal.js +18 -0
  38. package/div.d.ts +4 -35
  39. package/div.js +2 -11
  40. package/divn.d.ts +1 -33
  41. package/divn.js +2 -11
  42. package/dot.d.ts +4 -26
  43. package/dot.js +2 -11
  44. package/errors.d.ts +2 -0
  45. package/errors.js +3 -0
  46. package/exp.d.ts +1 -29
  47. package/exp.js +2 -11
  48. package/exp2.d.ts +1 -29
  49. package/exp2.js +2 -11
  50. package/identity.d.ts +7 -0
  51. package/identity.js +3 -2
  52. package/index.d.ts +3 -0
  53. package/index.js +3 -0
  54. package/log.d.ts +1 -29
  55. package/log.js +2 -11
  56. package/log2.d.ts +1 -29
  57. package/log2.js +2 -11
  58. package/mag.d.ts +5 -0
  59. package/magsq.d.ts +1 -25
  60. package/magsq.js +2 -11
  61. package/max.d.ts +5 -30
  62. package/max.js +2 -11
  63. package/maxn.d.ts +1 -33
  64. package/maxn.js +2 -11
  65. package/min.d.ts +5 -30
  66. package/min.js +2 -11
  67. package/minn.d.ts +1 -33
  68. package/minn.js +2 -11
  69. package/mul.d.ts +4 -35
  70. package/mul.js +2 -11
  71. package/mulm.d.ts +1 -1
  72. package/mulm.js +21 -20
  73. package/muln.d.ts +1 -33
  74. package/muln.js +2 -11
  75. package/mulv.d.ts +1 -1
  76. package/mulv.js +16 -15
  77. package/normalize.d.ts +9 -1
  78. package/package.json +39 -4
  79. package/pow.d.ts +5 -30
  80. package/pow.js +2 -11
  81. package/pown.d.ts +1 -33
  82. package/pown.js +2 -11
  83. package/product.d.ts +1 -25
  84. package/product.js +2 -11
  85. package/rand-distrib.d.ts +13 -10
  86. package/rand-distrib.js +55 -23
  87. package/relu.d.ts +1 -29
  88. package/relu.js +2 -11
  89. package/relun.d.ts +1 -33
  90. package/relun.js +2 -11
  91. package/select.d.ts +6 -6
  92. package/select.js +26 -23
  93. package/set.d.ts +1 -6
  94. package/set.js +3 -11
  95. package/setn.d.ts +1 -6
  96. package/setn.js +3 -11
  97. package/sigmoid.d.ts +1 -29
  98. package/sigmoid.js +2 -11
  99. package/sin.d.ts +1 -29
  100. package/sin.js +2 -11
  101. package/softmax.d.ts +1 -1
  102. package/softplus.d.ts +1 -33
  103. package/softplus.js +2 -11
  104. package/sqrt.d.ts +1 -29
  105. package/sqrt.js +2 -11
  106. package/step.d.ts +1 -33
  107. package/step.js +2 -11
  108. package/sub.d.ts +4 -35
  109. package/sub.js +2 -11
  110. package/subn.d.ts +1 -33
  111. package/subn.js +2 -11
  112. package/sum.d.ts +1 -25
  113. package/sum.js +2 -11
  114. package/svd.d.ts +33 -0
  115. package/svd.js +246 -0
  116. package/tan.d.ts +1 -29
  117. package/tan.js +2 -11
  118. package/tanh.d.ts +1 -29
  119. package/tanh.js +2 -11
  120. package/tensor.d.ts +3 -0
  121. package/tensor.js +45 -21
  122. package/top.d.ts +2 -6
package/defopt.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { top } from "./top.js";
2
- const defOpT = (fn, dispatch = 1) => {
2
+ const defOpT = (fn) => {
3
3
  const f1 = (out, a) => {
4
4
  !out && (out = a);
5
5
  const {
@@ -31,9 +31,10 @@ const defOpT = (fn, dispatch = 1) => {
31
31
  shape: [sx, sy],
32
32
  stride: [txa, tya]
33
33
  } = a;
34
+ let oox, oax;
34
35
  for (let x = 0; x < sx; x++) {
35
- const oox = oo + x * txo;
36
- const oax = oa + x * txa;
36
+ oox = oo + x * txo;
37
+ oax = oa + x * txa;
37
38
  for (let y = 0; y < sy; y++) {
38
39
  odata[oox + y * tyo] = fn(adata[oax + y * tya]);
39
40
  }
@@ -53,12 +54,13 @@ const defOpT = (fn, dispatch = 1) => {
53
54
  shape: [sx, sy, sz],
54
55
  stride: [txa, tya, tza]
55
56
  } = a;
57
+ let oox, oax, ooy, oay;
56
58
  for (let x = 0; x < sx; x++) {
57
- const oox = oo + x * txo;
58
- const oax = oa + x * txa;
59
+ oox = oo + x * txo;
60
+ oax = oa + x * txa;
59
61
  for (let y = 0; y < sy; y++) {
60
- const ooy = oox + y * tyo;
61
- const oay = oax + y * tya;
62
+ ooy = oox + y * tyo;
63
+ oay = oax + y * tya;
62
64
  for (let z = 0; z < sz; z++) {
63
65
  odata[ooy + z * tzo] = fn(adata[oay + z * tza]);
64
66
  }
@@ -79,15 +81,16 @@ const defOpT = (fn, dispatch = 1) => {
79
81
  shape: [sx, sy, sz, sw],
80
82
  stride: [txa, tya, tza, twa]
81
83
  } = a;
84
+ let oox, oax, ooy, oay, ooz, oaz;
82
85
  for (let x = 0; x < sx; x++) {
83
- const oox = oo + x * txo;
84
- const oax = oa + x * txa;
86
+ oox = oo + x * txo;
87
+ oax = oa + x * txa;
85
88
  for (let y = 0; y < sy; y++) {
86
- const ooy = oox + y * tyo;
87
- const oay = oax + y * tya;
89
+ ooy = oox + y * tyo;
90
+ oay = oax + y * tya;
88
91
  for (let z = 0; z < sz; z++) {
89
- const ooz = ooy + z * tzo;
90
- const oaz = oay + z * tza;
92
+ ooz = ooy + z * tzo;
93
+ oaz = oay + z * tza;
91
94
  for (let w = 0; w < sw; w++) {
92
95
  odata[ooz + w * two] = fn(adata[oaz + w * twa]);
93
96
  }
@@ -96,13 +99,7 @@ const defOpT = (fn, dispatch = 1) => {
96
99
  }
97
100
  return out;
98
101
  };
99
- return [
100
- top(dispatch, void 0, f1, f2, f3, f4),
101
- f1,
102
- f2,
103
- f3,
104
- f4
105
- ];
102
+ return top(1, void 0, f1, f2, f3, f4);
106
103
  };
107
104
  export {
108
105
  defOpT
package/defoptn.d.ts CHANGED
@@ -1,13 +1,10 @@
1
1
  import type { FnU2 } from "@thi.ng/api";
2
- import type { MultiTensorOpTN, TensorOpTN } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpTN } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpTN}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpTN} applying the given function component-wise.
8
6
  *
9
7
  * @param fn
10
- * @param dispatch
11
8
  */
12
- export declare const defOpTN: <T = number>(fn: FnU2<T>, dispatch?: number) => [MultiTensorOpTN<T>, TensorOpTN<T, T, Tensor1<T>, Tensor1<T>>, TensorOpTN<T, T, Tensor2<T>, Tensor2<T>>, TensorOpTN<T, T, Tensor3<T>, Tensor3<T>>, TensorOpTN<T, T, Tensor4<T>, Tensor4<T>>];
9
+ export declare const defOpTN: <T = number>(fn: FnU2<T>) => import("./api.js").MultiTensorOpImpl<TensorOpTN<T>>;
13
10
  //# sourceMappingURL=defoptn.d.ts.map
package/defoptn.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { top } from "./top.js";
2
- const defOpTN = (fn, dispatch = 1) => {
2
+ const defOpTN = (fn) => {
3
3
  const f1 = (out, a, n) => {
4
4
  !out && (out = a);
5
5
  const {
@@ -31,9 +31,10 @@ const defOpTN = (fn, dispatch = 1) => {
31
31
  shape: [sx, sy],
32
32
  stride: [txa, tya]
33
33
  } = a;
34
+ let oox, oax;
34
35
  for (let x = 0; x < sx; x++) {
35
- const oox = oo + x * txo;
36
- const oax = oa + x * txa;
36
+ oox = oo + x * txo;
37
+ oax = oa + x * txa;
37
38
  for (let y = 0; y < sy; y++) {
38
39
  odata[oox + y * tyo] = fn(adata[oax + y * tya], n);
39
40
  }
@@ -53,12 +54,13 @@ const defOpTN = (fn, dispatch = 1) => {
53
54
  shape: [sx, sy, sz],
54
55
  stride: [txa, tya, tza]
55
56
  } = a;
57
+ let oox, oax, ooy, oay;
56
58
  for (let x = 0; x < sx; x++) {
57
- const oox = oo + x * txo;
58
- const oax = oa + x * txa;
59
+ oox = oo + x * txo;
60
+ oax = oa + x * txa;
59
61
  for (let y = 0; y < sy; y++) {
60
- const ooy = oox + y * tyo;
61
- const oay = oax + y * tya;
62
+ ooy = oox + y * tyo;
63
+ oay = oax + y * tya;
62
64
  for (let z = 0; z < sz; z++) {
63
65
  odata[ooy + z * tzo] = fn(adata[oay + z * tza], n);
64
66
  }
@@ -79,15 +81,16 @@ const defOpTN = (fn, dispatch = 1) => {
79
81
  shape: [sx, sy, sz, sw],
80
82
  stride: [txa, tya, tza, twa]
81
83
  } = a;
84
+ let oox, oax, ooy, oay, ooz, oaz;
82
85
  for (let x = 0; x < sx; x++) {
83
- const oox = oo + x * txo;
84
- const oax = oa + x * txa;
86
+ oox = oo + x * txo;
87
+ oax = oa + x * txa;
85
88
  for (let y = 0; y < sy; y++) {
86
- const ooy = oox + y * tyo;
87
- const oay = oax + y * tya;
89
+ ooy = oox + y * tyo;
90
+ oay = oax + y * tya;
88
91
  for (let z = 0; z < sz; z++) {
89
- const ooz = ooy + z * tzo;
90
- const oaz = oay + z * tza;
92
+ ooz = ooy + z * tzo;
93
+ oaz = oay + z * tza;
91
94
  for (let w = 0; w < sw; w++) {
92
95
  odata[ooz + w * two] = fn(adata[oaz + w * twa], n);
93
96
  }
@@ -96,13 +99,7 @@ const defOpTN = (fn, dispatch = 1) => {
96
99
  }
97
100
  return out;
98
101
  };
99
- return [
100
- top(dispatch, void 0, f1, f2, f3, f4),
101
- f1,
102
- f2,
103
- f3,
104
- f4
105
- ];
102
+ return top(1, void 0, f1, f2, f3, f4);
106
103
  };
107
104
  export {
108
105
  defOpTN
package/defoptnn.d.ts CHANGED
@@ -1,13 +1,10 @@
1
1
  import type { FnU3 } from "@thi.ng/api";
2
- import type { MultiTensorOpTNN, TensorOpTNN } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpTNN } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpTNN}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpTNN} applying the given function component-wise.
8
6
  *
9
7
  * @param fn
10
- * @param dispatch
11
8
  */
12
- export declare const defOpTNN: <T = number>(fn: FnU3<T>, dispatch?: number) => [MultiTensorOpTNN<T>, TensorOpTNN<T, T, Tensor1<T>, Tensor1<T>>, TensorOpTNN<T, T, Tensor2<T>, Tensor2<T>>, TensorOpTNN<T, T, Tensor3<T>, Tensor3<T>>, TensorOpTNN<T, T, Tensor4<T>, Tensor4<T>>];
9
+ export declare const defOpTNN: <T = number>(fn: FnU3<T>) => import("./api.js").MultiTensorOpImpl<TensorOpTNN<T>>;
13
10
  //# sourceMappingURL=defoptnn.d.ts.map
package/defoptnn.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { top } from "./top.js";
2
- const defOpTNN = (fn, dispatch = 1) => {
2
+ const defOpTNN = (fn) => {
3
3
  const f1 = (out, a, n, m) => {
4
4
  !out && (out = a);
5
5
  const {
@@ -31,9 +31,10 @@ const defOpTNN = (fn, dispatch = 1) => {
31
31
  shape: [sx, sy],
32
32
  stride: [txa, tya]
33
33
  } = a;
34
+ let oox, oax;
34
35
  for (let x = 0; x < sx; x++) {
35
- const oox = oo + x * txo;
36
- const oax = oa + x * txa;
36
+ oox = oo + x * txo;
37
+ oax = oa + x * txa;
37
38
  for (let y = 0; y < sy; y++) {
38
39
  odata[oox + y * tyo] = fn(adata[oax + y * tya], n, m);
39
40
  }
@@ -53,12 +54,13 @@ const defOpTNN = (fn, dispatch = 1) => {
53
54
  shape: [sx, sy, sz],
54
55
  stride: [txa, tya, tza]
55
56
  } = a;
57
+ let oox, oax, ooy, oay;
56
58
  for (let x = 0; x < sx; x++) {
57
- const oox = oo + x * txo;
58
- const oax = oa + x * txa;
59
+ oox = oo + x * txo;
60
+ oax = oa + x * txa;
59
61
  for (let y = 0; y < sy; y++) {
60
- const ooy = oox + y * tyo;
61
- const oay = oax + y * tya;
62
+ ooy = oox + y * tyo;
63
+ oay = oax + y * tya;
62
64
  for (let z = 0; z < sz; z++) {
63
65
  odata[ooy + z * tzo] = fn(adata[oay + z * tza], n, m);
64
66
  }
@@ -79,15 +81,16 @@ const defOpTNN = (fn, dispatch = 1) => {
79
81
  shape: [sx, sy, sz, sw],
80
82
  stride: [txa, tya, tza, twa]
81
83
  } = a;
84
+ let oox, oax, ooy, oay, ooz, oaz;
82
85
  for (let x = 0; x < sx; x++) {
83
- const oox = oo + x * txo;
84
- const oax = oa + x * txa;
86
+ oox = oo + x * txo;
87
+ oax = oa + x * txa;
85
88
  for (let y = 0; y < sy; y++) {
86
- const ooy = oox + y * tyo;
87
- const oay = oax + y * tya;
89
+ ooy = oox + y * tyo;
90
+ oay = oax + y * tya;
88
91
  for (let z = 0; z < sz; z++) {
89
- const ooz = ooy + z * tzo;
90
- const oaz = oay + z * tza;
92
+ ooz = ooy + z * tzo;
93
+ oaz = oay + z * tza;
91
94
  for (let w = 0; w < sw; w++) {
92
95
  odata[ooz + w * two] = fn(adata[oaz + w * twa], n, m);
93
96
  }
@@ -96,13 +99,14 @@ const defOpTNN = (fn, dispatch = 1) => {
96
99
  }
97
100
  return out;
98
101
  };
99
- return [
100
- top(dispatch, void 0, f1, f2, f3, f4),
102
+ return top(
103
+ 1,
104
+ void 0,
101
105
  f1,
102
106
  f2,
103
107
  f3,
104
108
  f4
105
- ];
109
+ );
106
110
  };
107
111
  export {
108
112
  defOpTNN
package/defoptt.d.ts CHANGED
@@ -1,13 +1,12 @@
1
1
  import type { FnU2 } from "@thi.ng/api";
2
- import type { MultiTensorOpTT, TensorOpTT } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpTT } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpTT}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpTT} applying the given function componentwise with
6
+ * broadcasting rules (see {@link broadcast} for details).
8
7
  *
9
8
  * @param fn
10
9
  * @param dispatch
11
10
  */
12
- export declare const defOpTT: <T = number>(fn: FnU2<T>, dispatch?: number) => [MultiTensorOpTT<T>, TensorOpTT<T, T, Tensor1<T>, Tensor1<T>>, TensorOpTT<T, T, Tensor2<T>, Tensor2<T>>, TensorOpTT<T, T, Tensor3<T>, Tensor3<T>>, TensorOpTT<T, T, Tensor4<T>, Tensor4<T>>];
11
+ export declare const defOpTT: <T = number>(fn: FnU2<T>) => TensorOpTT<T>;
13
12
  //# sourceMappingURL=defoptt.d.ts.map
package/defoptt.js CHANGED
@@ -1,5 +1,7 @@
1
- import { top } from "./top.js";
2
- const defOpTT = (fn, dispatch = 1) => {
1
+ import { broadcast } from "./broadcast.js";
2
+ import { ensureShape } from "./errors.js";
3
+ import { tensor } from "./tensor.js";
4
+ const defOpTT = (fn) => {
3
5
  const f1 = (out, a, b) => {
4
6
  !out && (out = a);
5
7
  const {
@@ -41,10 +43,11 @@ const defOpTT = (fn, dispatch = 1) => {
41
43
  offset: ob,
42
44
  stride: [txb, tyb]
43
45
  } = b;
46
+ let oox, oax, obx;
44
47
  for (let x = 0; x < sx; x++) {
45
- const oox = oo + x * txo;
46
- const oax = oa + x * txa;
47
- const obx = ob + x * txb;
48
+ oox = oo + x * txo;
49
+ oax = oa + x * txa;
50
+ obx = ob + x * txb;
48
51
  for (let y = 0; y < sy; y++) {
49
52
  odata[oox + y * tyo] = fn(
50
53
  adata[oax + y * tya],
@@ -72,14 +75,15 @@ const defOpTT = (fn, dispatch = 1) => {
72
75
  offset: ob,
73
76
  stride: [txb, tyb, tzb]
74
77
  } = b;
78
+ let oox, oax, obx, ooy, oay, oby;
75
79
  for (let x = 0; x < sx; x++) {
76
- const oox = oo + x * txo;
77
- const oax = oa + x * txa;
78
- const obx = ob + x * txb;
80
+ oox = oo + x * txo;
81
+ oax = oa + x * txa;
82
+ obx = ob + x * txb;
79
83
  for (let y = 0; y < sy; y++) {
80
- const ooy = oox + y * tyo;
81
- const oay = oax + y * tya;
82
- const oby = obx + y * tyb;
84
+ ooy = oox + y * tyo;
85
+ oay = oax + y * tya;
86
+ oby = obx + y * tyb;
83
87
  for (let z = 0; z < sz; z++) {
84
88
  odata[ooy + z * tzo] = fn(
85
89
  adata[oay + z * tza],
@@ -108,18 +112,19 @@ const defOpTT = (fn, dispatch = 1) => {
108
112
  offset: ob,
109
113
  stride: [txb, tyb, tzb, twb]
110
114
  } = b;
115
+ let oox, oax, obx, ooy, oay, oby, ooz, oaz, obz;
111
116
  for (let x = 0; x < sx; x++) {
112
- const oox = oo + x * txo;
113
- const oax = oa + x * txa;
114
- const obx = ob + x * txb;
117
+ oox = oo + x * txo;
118
+ oax = oa + x * txa;
119
+ obx = ob + x * txb;
115
120
  for (let y = 0; y < sy; y++) {
116
- const ooy = oox + y * tyo;
117
- const oay = oax + y * tya;
118
- const oby = obx + y * tyb;
121
+ ooy = oox + y * tyo;
122
+ oay = oax + y * tya;
123
+ oby = obx + y * tyb;
119
124
  for (let z = 0; z < sz; z++) {
120
- const ooz = ooy + z * tzo;
121
- const oaz = oay + z * tza;
122
- const obz = oby + z * tzb;
125
+ ooz = ooy + z * tzo;
126
+ oaz = oay + z * tza;
127
+ obz = oby + z * tzb;
123
128
  for (let w = 0; w < sw; w++) {
124
129
  odata[ooz + w * two] = fn(
125
130
  adata[oaz + w * twa],
@@ -131,13 +136,17 @@ const defOpTT = (fn, dispatch = 1) => {
131
136
  }
132
137
  return out;
133
138
  };
134
- return [
135
- top(dispatch, void 0, f1, f2, f3, f4),
136
- f1,
137
- f2,
138
- f3,
139
- f4
140
- ];
139
+ const impls = [, f1, f2, f3, f4];
140
+ const wrapper = (out, a, b) => {
141
+ const { shape, a: $a, b: $b } = broadcast(a, b);
142
+ if (out) {
143
+ ensureShape(out, shape);
144
+ } else {
145
+ out = tensor(a.type, shape, { storage: a.storage });
146
+ }
147
+ return impls[shape.length](out, $a, $b);
148
+ };
149
+ return wrapper;
141
150
  };
142
151
  export {
143
152
  defOpTT
package/defopttt.d.ts CHANGED
@@ -1,13 +1,11 @@
1
1
  import type { FnU3 } from "@thi.ng/api";
2
- import type { MultiTensorOpTTT, TensorOpTTT } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpTTT } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpTTT}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpTTT} applying the given function componentwise with
6
+ * broadcasting rules (see {@link broadcast} for details).
8
7
  *
9
8
  * @param fn
10
- * @param dispatch
11
9
  */
12
- export declare const defOpTTT: <T = number>(fn: FnU3<T>, dispatch?: number) => [MultiTensorOpTTT<T>, TensorOpTTT<T, T, Tensor1<T>, Tensor1<T>>, TensorOpTTT<T, T, Tensor2<T>, Tensor2<T>>, TensorOpTTT<T, T, Tensor3<T>, Tensor3<T>>, TensorOpTTT<T, T, Tensor4<T>, Tensor4<T>>];
10
+ export declare const defOpTTT: <T = number>(fn: FnU3<T>) => TensorOpTTT<T>;
13
11
  //# sourceMappingURL=defopttt.d.ts.map
package/defopttt.js CHANGED
@@ -1,7 +1,8 @@
1
- import { top } from "./top.js";
2
- const defOpTTT = (fn, dispatch = 1) => {
1
+ import { broadcast } from "./broadcast.js";
2
+ import { ensureShape } from "./errors.js";
3
+ import { tensor } from "./tensor.js";
4
+ const defOpTTT = (fn) => {
3
5
  const f1 = (out, a, b, c) => {
4
- !out && (out = a);
5
6
  const {
6
7
  data: odata,
7
8
  offset: oo,
@@ -33,7 +34,6 @@ const defOpTTT = (fn, dispatch = 1) => {
33
34
  return out;
34
35
  };
35
36
  const f2 = (out, a, b, c) => {
36
- !out && (out = a);
37
37
  const {
38
38
  data: odata,
39
39
  offset: oo,
@@ -55,11 +55,12 @@ const defOpTTT = (fn, dispatch = 1) => {
55
55
  offset: oc,
56
56
  stride: [txc, tyc]
57
57
  } = c;
58
+ let oox, oax, obx, ocx;
58
59
  for (let x = 0; x < sx; x++) {
59
- const oox = oo + x * txo;
60
- const oax = oa + x * txa;
61
- const obx = ob + x * txb;
62
- const ocx = oc + x * txc;
60
+ oox = oo + x * txo;
61
+ oax = oa + x * txa;
62
+ obx = ob + x * txb;
63
+ ocx = oc + x * txc;
63
64
  for (let y = 0; y < sy; y++) {
64
65
  odata[oox + y * tyo] = fn(
65
66
  adata[oax + y * tya],
@@ -71,7 +72,6 @@ const defOpTTT = (fn, dispatch = 1) => {
71
72
  return out;
72
73
  };
73
74
  const f3 = (out, a, b, c) => {
74
- !out && (out = a);
75
75
  const {
76
76
  data: odata,
77
77
  offset: oo,
@@ -93,16 +93,17 @@ const defOpTTT = (fn, dispatch = 1) => {
93
93
  offset: oc,
94
94
  stride: [txc, tyc, tzc]
95
95
  } = c;
96
+ let oox, oax, obx, ocx, ooy, oay, oby, ocy;
96
97
  for (let x = 0; x < sx; x++) {
97
- const oox = oo + x * txo;
98
- const oax = oa + x * txa;
99
- const obx = ob + x * txb;
100
- const ocx = oc + x * txc;
98
+ oox = oo + x * txo;
99
+ oax = oa + x * txa;
100
+ obx = ob + x * txb;
101
+ ocx = oc + x * txc;
101
102
  for (let y = 0; y < sy; y++) {
102
- const ooy = oox + y * tyo;
103
- const oay = oax + y * tya;
104
- const oby = obx + y * tyb;
105
- const ocy = ocx + y * tyc;
103
+ ooy = oox + y * tyo;
104
+ oay = oax + y * tya;
105
+ oby = obx + y * tyb;
106
+ ocy = ocx + y * tyc;
106
107
  for (let z = 0; z < sz; z++) {
107
108
  odata[ooy + z * tzo] = fn(
108
109
  adata[oay + z * tza],
@@ -115,7 +116,6 @@ const defOpTTT = (fn, dispatch = 1) => {
115
116
  return out;
116
117
  };
117
118
  const f4 = (out, a, b, c) => {
118
- !out && (out = a);
119
119
  const {
120
120
  data: odata,
121
121
  offset: oo,
@@ -137,21 +137,22 @@ const defOpTTT = (fn, dispatch = 1) => {
137
137
  offset: oc,
138
138
  stride: [txc, tyc, tzc, twc]
139
139
  } = c;
140
+ let oox, oax, obx, ocx, ooy, oay, oby, ocy, ooz, oaz, obz, ocz;
140
141
  for (let x = 0; x < sx; x++) {
141
- const oox = oo + x * txo;
142
- const oax = oa + x * txa;
143
- const obx = ob + x * txb;
144
- const ocx = oc + x * txc;
142
+ oox = oo + x * txo;
143
+ oax = oa + x * txa;
144
+ obx = ob + x * txb;
145
+ ocx = oc + x * txc;
145
146
  for (let y = 0; y < sy; y++) {
146
- const ooy = oox + y * tyo;
147
- const oay = oax + y * tya;
148
- const oby = obx + y * tyb;
149
- const ocy = ocx + y * tyc;
147
+ ooy = oox + y * tyo;
148
+ oay = oax + y * tya;
149
+ oby = obx + y * tyb;
150
+ ocy = ocx + y * tyc;
150
151
  for (let z = 0; z < sz; z++) {
151
- const ooz = ooy + z * tzo;
152
- const oaz = oay + z * tza;
153
- const obz = oby + z * tzb;
154
- const ocz = ocy + z * tzc;
152
+ ooz = ooy + z * tzo;
153
+ oaz = oay + z * tza;
154
+ obz = oby + z * tzb;
155
+ ocz = ocy + z * tzc;
155
156
  for (let w = 0; w < sw; w++) {
156
157
  odata[ooz + w * two] = fn(
157
158
  adata[oaz + w * twa],
@@ -164,13 +165,18 @@ const defOpTTT = (fn, dispatch = 1) => {
164
165
  }
165
166
  return out;
166
167
  };
167
- return [
168
- top(dispatch, void 0, f1, f2, f3, f4),
169
- f1,
170
- f2,
171
- f3,
172
- f4
173
- ];
168
+ const impls = [, f1, f2, f3, f4];
169
+ const wrapper = (out, a, b, c) => {
170
+ const { a: $a1, b: $b } = broadcast(a, b);
171
+ const { shape, a: $a2, b: $c } = broadcast($a1, c);
172
+ if (out) {
173
+ ensureShape(out, shape);
174
+ } else {
175
+ out = tensor(a.type, shape, { storage: a.storage });
176
+ }
177
+ return impls[shape.length](out, $a2, $b, $c);
178
+ };
179
+ return wrapper;
174
180
  };
175
181
  export {
176
182
  defOpTTT
package/diagonal.d.ts ADDED
@@ -0,0 +1,16 @@
1
+ import type { ITensor } from "./api.js";
2
+ import { Tensor1 } from "./tensor.js";
3
+ /**
4
+ * Returns a 1D tensor view of the nD diagonal of the given tensor.
5
+ *
6
+ * @remarks
7
+ * For 1D tensors this will merely be a shallow copy.
8
+ */
9
+ export declare const diagonal: <T>(a: ITensor<T>) => Tensor1<T>;
10
+ /**
11
+ * Computes the trace of given tensor, i.e. the component sum of `diagonal(a)`.
12
+ *
13
+ * @param a
14
+ */
15
+ export declare const trace: (a: ITensor) => number;
16
+ //# sourceMappingURL=diagonal.d.ts.map
package/diagonal.js ADDED
@@ -0,0 +1,18 @@
1
+ import { sum as vsum } from "@thi.ng/vectors/sum";
2
+ import { sum } from "./sum.js";
3
+ import { Tensor1 } from "./tensor.js";
4
+ const diagonal = (a) => {
5
+ return new Tensor1(
6
+ a.type,
7
+ a.storage,
8
+ a.data,
9
+ [Math.min(...a.shape)],
10
+ [vsum(a.stride)],
11
+ a.offset
12
+ );
13
+ };
14
+ const trace = (a) => sum(diagonal(a));
15
+ export {
16
+ diagonal,
17
+ trace
18
+ };
package/div.d.ts CHANGED
@@ -1,42 +1,11 @@
1
1
  /**
2
2
  * Componentwise nD tensor division. Writes result to `out`. If `out` is null,
3
- * mutates `a`. Multi-method.
3
+ * creates a new tensor using `a`'s type and storage provider and shape as
4
+ * determined by broadcasting rules (see {@link broadcast} for details).
4
5
  *
5
6
  * @param out - output tensor
6
7
  * @param a - input tensor
7
- * @param n - scalar
8
+ * @param b - input tensor
8
9
  */
9
- export declare const div: import("./api.js").MultiTensorOpTT<number>;
10
- /**
11
- * Same as {@link div} for 1D tensors.
12
- *
13
- * @param out - output tensor
14
- * @param a - input tensor
15
- * @param n - scalar
16
- */
17
- export declare const div1: import("./api.js").TensorOpTT<number, number, import("./tensor.js").Tensor1<number>, import("./tensor.js").Tensor1<number>>;
18
- /**
19
- * Same as {@link div} for 2D tensors.
20
- *
21
- * @param out - output tensor
22
- * @param a - input tensor
23
- * @param n - scalar
24
- */
25
- export declare const div2: import("./api.js").TensorOpTT<number, number, import("./tensor.js").Tensor2<number>, import("./tensor.js").Tensor2<number>>;
26
- /**
27
- * Same as {@link div} for 3D tensors.
28
- *
29
- * @param out - output tensor
30
- * @param a - input tensor
31
- * @param n - scalar
32
- */
33
- export declare const div3: import("./api.js").TensorOpTT<number, number, import("./tensor.js").Tensor3<number>, import("./tensor.js").Tensor3<number>>;
34
- /**
35
- * Same as {@link div} for 4D tensors.
36
- *
37
- * @param out - output tensor
38
- * @param a - input tensor
39
- * @param n - scalar
40
- */
41
- export declare const div4: import("./api.js").TensorOpTT<number, number, import("./tensor.js").Tensor4<number>, import("./tensor.js").Tensor4<number>>;
10
+ export declare const div: import("./api.js").TensorOpTT<number>;
42
11
  //# sourceMappingURL=div.d.ts.map
package/div.js CHANGED
@@ -1,15 +1,6 @@
1
1
  import { $div } from "@thi.ng/vectors/ops";
2
2
  import { defOpTT } from "./defoptt.js";
3
- const [a, b, c, d, e] = defOpTT($div);
4
- const div = a;
5
- const div1 = b;
6
- const div2 = c;
7
- const div3 = d;
8
- const div4 = e;
3
+ const div = defOpTT($div);
9
4
  export {
10
- div,
11
- div1,
12
- div2,
13
- div3,
14
- div4
5
+ div
15
6
  };