@thi.ng/tensors 0.1.0 → 0.3.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (126) hide show
  1. package/CHANGELOG.md +45 -1
  2. package/README.md +93 -34
  3. package/abs.d.ts +1 -29
  4. package/abs.js +2 -11
  5. package/add.d.ts +4 -35
  6. package/add.js +2 -11
  7. package/addn.d.ts +1 -33
  8. package/addn.js +2 -11
  9. package/api.d.ts +138 -15
  10. package/broadcast.d.ts +24 -0
  11. package/broadcast.js +54 -0
  12. package/clamp.d.ts +4 -38
  13. package/clamp.js +2 -11
  14. package/clampn.d.ts +1 -37
  15. package/clampn.js +2 -11
  16. package/cos.d.ts +1 -29
  17. package/cos.js +2 -11
  18. package/defopn.d.ts +4 -7
  19. package/defopn.js +17 -7
  20. package/defoprt.d.ts +5 -7
  21. package/defoprt.js +21 -11
  22. package/defoprtt.d.ts +7 -7
  23. package/defoprtt.js +30 -26
  24. package/defopt.d.ts +4 -7
  25. package/defopt.js +17 -20
  26. package/defoptn.d.ts +4 -7
  27. package/defoptn.js +17 -20
  28. package/defoptnn.d.ts +4 -7
  29. package/defoptnn.js +20 -16
  30. package/defoptt.d.ts +5 -6
  31. package/defoptt.js +36 -27
  32. package/defopttt.d.ts +5 -7
  33. package/defopttt.js +43 -37
  34. package/diagonal.d.ts +16 -0
  35. package/diagonal.js +18 -0
  36. package/div.d.ts +4 -35
  37. package/div.js +2 -11
  38. package/divn.d.ts +1 -33
  39. package/divn.js +2 -11
  40. package/dot.d.ts +4 -26
  41. package/dot.js +3 -12
  42. package/errors.d.ts +2 -0
  43. package/errors.js +3 -0
  44. package/exp.d.ts +1 -29
  45. package/exp.js +2 -11
  46. package/exp2.d.ts +1 -29
  47. package/exp2.js +2 -11
  48. package/filtered-indices.d.ts +34 -0
  49. package/filtered-indices.js +17 -0
  50. package/identity.d.ts +7 -0
  51. package/identity.js +3 -2
  52. package/index.d.ts +6 -0
  53. package/index.js +6 -0
  54. package/log.d.ts +1 -29
  55. package/log.js +2 -11
  56. package/log2.d.ts +1 -29
  57. package/log2.js +2 -11
  58. package/mag.d.ts +5 -0
  59. package/magsq.d.ts +1 -25
  60. package/magsq.js +3 -12
  61. package/max.d.ts +5 -30
  62. package/max.js +2 -11
  63. package/maxn.d.ts +1 -33
  64. package/maxn.js +2 -11
  65. package/min.d.ts +5 -30
  66. package/min.js +2 -11
  67. package/minn.d.ts +1 -33
  68. package/minn.js +2 -11
  69. package/mul.d.ts +4 -35
  70. package/mul.js +2 -11
  71. package/mulm.d.ts +1 -1
  72. package/mulm.js +21 -20
  73. package/muln.d.ts +1 -33
  74. package/muln.js +2 -11
  75. package/mulv.d.ts +1 -1
  76. package/mulv.js +16 -15
  77. package/normalize.d.ts +9 -1
  78. package/package.json +48 -4
  79. package/pow.d.ts +5 -30
  80. package/pow.js +2 -11
  81. package/pown.d.ts +1 -33
  82. package/pown.js +2 -11
  83. package/product.d.ts +1 -25
  84. package/product.js +5 -12
  85. package/rand-distrib.d.ts +13 -10
  86. package/rand-distrib.js +55 -23
  87. package/range.d.ts +20 -0
  88. package/range.js +28 -0
  89. package/relu.d.ts +1 -29
  90. package/relu.js +2 -11
  91. package/relun.d.ts +1 -33
  92. package/relun.js +2 -11
  93. package/select.d.ts +6 -6
  94. package/select.js +26 -23
  95. package/set.d.ts +1 -6
  96. package/set.js +3 -11
  97. package/setn.d.ts +1 -6
  98. package/setn.js +3 -11
  99. package/sigmoid.d.ts +1 -29
  100. package/sigmoid.js +2 -11
  101. package/sin.d.ts +1 -29
  102. package/sin.js +2 -11
  103. package/softmax.d.ts +1 -1
  104. package/softplus.d.ts +1 -33
  105. package/softplus.js +2 -11
  106. package/sqrt.d.ts +1 -29
  107. package/sqrt.js +2 -11
  108. package/step.d.ts +1 -33
  109. package/step.js +2 -11
  110. package/sub.d.ts +4 -35
  111. package/sub.js +2 -11
  112. package/subn.d.ts +1 -33
  113. package/subn.js +2 -11
  114. package/sum.d.ts +1 -25
  115. package/sum.js +5 -12
  116. package/svd.d.ts +33 -0
  117. package/svd.js +246 -0
  118. package/swap.d.ts +26 -0
  119. package/swap.js +15 -0
  120. package/tan.d.ts +1 -29
  121. package/tan.js +2 -11
  122. package/tanh.d.ts +1 -29
  123. package/tanh.js +2 -11
  124. package/tensor.d.ts +6 -1
  125. package/tensor.js +58 -21
  126. package/top.d.ts +2 -6
package/broadcast.js ADDED
@@ -0,0 +1,54 @@
1
+ import { illegalArgs } from "@thi.ng/errors";
2
+ import { equals } from "@thi.ng/vectors";
3
+ import { max } from "@thi.ng/vectors/max";
4
+ const broadcast = (a, b) => {
5
+ if (equals(a.shape, b.shape)) return { shape: a.shape, a, b };
6
+ const ashape = a.shape.slice();
7
+ const astride = a.stride.slice();
8
+ const bshape = b.shape.slice();
9
+ const bstride = b.stride.slice();
10
+ let da = a.dim;
11
+ let db = b.dim;
12
+ let bcastA = da < db;
13
+ let bcastB = db < da;
14
+ if (bcastA) {
15
+ while (da < db) {
16
+ ashape.unshift(1);
17
+ astride.unshift(0);
18
+ da++;
19
+ }
20
+ } else if (bcastB) {
21
+ while (db < da) {
22
+ bshape.unshift(1);
23
+ bstride.unshift(0);
24
+ db++;
25
+ }
26
+ }
27
+ for (let i = 0; i < da; i++) {
28
+ const sa = ashape[i];
29
+ const sb = bshape[i];
30
+ if (sa < sb) {
31
+ if (sa > 1) __broadcastError(ashape, bshape);
32
+ astride[i] = 0;
33
+ bcastA = true;
34
+ } else if (sb < sa) {
35
+ if (sb > 1) __broadcastError(ashape, bshape);
36
+ bstride[i] = 0;
37
+ bcastB = true;
38
+ }
39
+ }
40
+ const shape = max([], ashape, bshape);
41
+ return {
42
+ shape,
43
+ a: bcastA ? a.broadcast(shape, astride) : a,
44
+ b: bcastB ? b.broadcast(shape, bstride) : b
45
+ };
46
+ };
47
+ const __broadcastError = (ashape, bshape) => illegalArgs(
48
+ `incompatible shapes: ${JSON.stringify(ashape)} vs ${JSON.stringify(
49
+ bshape
50
+ )}`
51
+ );
52
+ export {
53
+ broadcast
54
+ };
package/clamp.d.ts CHANGED
@@ -1,47 +1,13 @@
1
1
  /**
2
2
  * Componentwise clamps nD tensor `a` to closed interval defined by `[b,c]`.
3
- * Writes result to `out`. If `out` is null, mutates `a`. Multi-method.
3
+ * Writes result to `out`. If `out` is null, creates a new tensor using `a`'s
4
+ * type and storage provider and shape as determined by broadcasting rules (see
5
+ * {@link broadcast} for details).
4
6
  *
5
7
  * @param out - output tensor
6
8
  * @param a - input tensor
7
9
  * @param b - input tensor (min)
8
10
  * @param c - input tensor (max)
9
11
  */
10
- export declare const clamp: import("./api.js").MultiTensorOpTTT<number>;
11
- /**
12
- * Same as {@link clamp} for 1D tensors.
13
- *
14
- * @param out - output tensor
15
- * @param a - input tensor
16
- * @param b - input tensor (min)
17
- * @param c - input tensor (max)
18
- */
19
- export declare const clamp1: import("./api.js").TensorOpTTT<number, number, import("./tensor.js").Tensor1<number>, import("./tensor.js").Tensor1<number>>;
20
- /**
21
- * Same as {@link clamp} for 2D tensors.
22
- *
23
- * @param out - output tensor
24
- * @param a - input tensor
25
- * @param b - input tensor (min)
26
- * @param c - input tensor (max)
27
- */
28
- export declare const clamp2: import("./api.js").TensorOpTTT<number, number, import("./tensor.js").Tensor2<number>, import("./tensor.js").Tensor2<number>>;
29
- /**
30
- * Same as {@link clamp} for 3D tensors.
31
- *
32
- * @param out - output tensor
33
- * @param a - input tensor
34
- * @param b - input tensor (min)
35
- * @param c - input tensor (max)
36
- */
37
- export declare const clamp3: import("./api.js").TensorOpTTT<number, number, import("./tensor.js").Tensor3<number>, import("./tensor.js").Tensor3<number>>;
38
- /**
39
- * Same as {@link clamp} for 4D tensors.
40
- *
41
- * @param out - output tensor
42
- * @param a - input tensor
43
- * @param b - input tensor (min)
44
- * @param c - input tensor (max)
45
- */
46
- export declare const clamp4: import("./api.js").TensorOpTTT<number, number, import("./tensor.js").Tensor4<number>, import("./tensor.js").Tensor4<number>>;
12
+ export declare const clamp: import("./api.js").TensorOpTTT<number>;
47
13
  //# sourceMappingURL=clamp.d.ts.map
package/clamp.js CHANGED
@@ -1,15 +1,6 @@
1
1
  import { clamp as op } from "@thi.ng/math/interval";
2
2
  import { defOpTTT } from "./defopttt.js";
3
- const [a, b, c, d, e] = defOpTTT(op);
4
- const clamp = a;
5
- const clamp1 = b;
6
- const clamp2 = c;
7
- const clamp3 = d;
8
- const clamp4 = e;
3
+ const clamp = defOpTTT(op);
9
4
  export {
10
- clamp,
11
- clamp1,
12
- clamp2,
13
- clamp3,
14
- clamp4
5
+ clamp
15
6
  };
package/clampn.d.ts CHANGED
@@ -7,41 +7,5 @@
7
7
  * @param b - scalar (min)
8
8
  * @param c - scalar (max)
9
9
  */
10
- export declare const clampN: import("./api.js").MultiTensorOpTNN<number>;
11
- /**
12
- * Same as {@link clampN} for 1D tensors.
13
- *
14
- * @param out - output tensor
15
- * @param a - input tensor
16
- * @param b - scalar (min)
17
- * @param c - scalar (max)
18
- */
19
- export declare const clampN1: import("./api.js").TensorOpTNN<number, number, import("./tensor.js").Tensor1<number>, import("./tensor.js").Tensor1<number>>;
20
- /**
21
- * Same as {@link clampN} for 2D tensors.
22
- *
23
- * @param out - output tensor
24
- * @param a - input tensor
25
- * @param b - scalar (min)
26
- * @param c - scalar (max)
27
- */
28
- export declare const clampN2: import("./api.js").TensorOpTNN<number, number, import("./tensor.js").Tensor2<number>, import("./tensor.js").Tensor2<number>>;
29
- /**
30
- * Same as {@link clampN} for 3D tensors.
31
- *
32
- * @param out - output tensor
33
- * @param a - input tensor
34
- * @param b - scalar (min)
35
- * @param c - scalar (max)
36
- */
37
- export declare const clampN3: import("./api.js").TensorOpTNN<number, number, import("./tensor.js").Tensor3<number>, import("./tensor.js").Tensor3<number>>;
38
- /**
39
- * Same as {@link clampN} for 4D tensors.
40
- *
41
- * @param out - output tensor
42
- * @param a - input tensor
43
- * @param b - scalar (min)
44
- * @param c - scalar (max)
45
- */
46
- export declare const clampN4: import("./api.js").TensorOpTNN<number, number, import("./tensor.js").Tensor4<number>, import("./tensor.js").Tensor4<number>>;
10
+ export declare const clampN: import("./api.js").MultiTensorOpImpl<import("./api.js").TensorOpTNN<number>>;
47
11
  //# sourceMappingURL=clampn.d.ts.map
package/clampn.js CHANGED
@@ -1,15 +1,6 @@
1
1
  import { clamp as op } from "@thi.ng/math/interval";
2
2
  import { defOpTNN } from "./defoptnn.js";
3
- const [a, b, c, d, e] = defOpTNN(op);
4
- const clampN = a;
5
- const clampN1 = b;
6
- const clampN2 = c;
7
- const clampN3 = d;
8
- const clampN4 = e;
3
+ const clampN = defOpTNN(op);
9
4
  export {
10
- clampN,
11
- clampN1,
12
- clampN2,
13
- clampN3,
14
- clampN4
5
+ clampN
15
6
  };
package/cos.d.ts CHANGED
@@ -5,33 +5,5 @@
5
5
  * @param out - output tensor
6
6
  * @param a - input tensor
7
7
  */
8
- export declare const cos: import("./api.js").MultiTensorOpT<number>;
9
- /**
10
- * Same as {@link cos} for 1D tensors.
11
- *
12
- * @param out - output tensor
13
- * @param a - input tensor
14
- */
15
- export declare const cos1: import("./api.js").TensorOpT<number, number, import("./tensor.js").Tensor1<number>, import("./tensor.js").Tensor1<number>>;
16
- /**
17
- * Same as {@link cos} for 2D tensors.
18
- *
19
- * @param out - output tensor
20
- * @param a - input tensor
21
- */
22
- export declare const cos2: import("./api.js").TensorOpT<number, number, import("./tensor.js").Tensor2<number>, import("./tensor.js").Tensor2<number>>;
23
- /**
24
- * Same as {@link cos} for 3D tensors.
25
- *
26
- * @param out - output tensor
27
- * @param a - input tensor
28
- */
29
- export declare const cos3: import("./api.js").TensorOpT<number, number, import("./tensor.js").Tensor3<number>, import("./tensor.js").Tensor3<number>>;
30
- /**
31
- * Same as {@link cos} for 4D tensors.
32
- *
33
- * @param out - output tensor
34
- * @param a - input tensor
35
- */
36
- export declare const cos4: import("./api.js").TensorOpT<number, number, import("./tensor.js").Tensor4<number>, import("./tensor.js").Tensor4<number>>;
8
+ export declare const cos: import("./api.js").MultiTensorOpImpl<import("./api.js").TensorOpT<number>>;
37
9
  //# sourceMappingURL=cos.d.ts.map
package/cos.js CHANGED
@@ -1,14 +1,5 @@
1
1
  import { defOpT } from "./defopt.js";
2
- const [a, b, c, d, e] = defOpT(Math.cos);
3
- const cos = a;
4
- const cos1 = b;
5
- const cos2 = c;
6
- const cos3 = d;
7
- const cos4 = e;
2
+ const cos = defOpT(Math.cos);
8
3
  export {
9
- cos,
10
- cos1,
11
- cos2,
12
- cos3,
13
- cos4
4
+ cos
14
5
  };
package/defopn.d.ts CHANGED
@@ -1,13 +1,10 @@
1
1
  import type { Fn } from "@thi.ng/api";
2
- import type { MultiTensorOpN, TensorOpN } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpN } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpN}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpN} applying the given function component-wise.
8
6
  *
9
7
  * @param fn
10
- * @param dispatch
11
8
  */
12
- export declare const defOpN: <A = number, B = A>(fn: Fn<A, B>) => [MultiTensorOpN<A, B>, TensorOpN<A, B, Tensor1<B>>, TensorOpN<A, B, Tensor2<B>>, TensorOpN<A, B, Tensor3<B>>, TensorOpN<A, B, Tensor4<B>>];
9
+ export declare const defOpN: <A = number, B = A>(fn: Fn<A, B>) => import("./api.js").MultiTensorOpImpl<TensorOpN<A, B>>;
13
10
  //# sourceMappingURL=defopn.d.ts.map
package/defopn.js CHANGED
@@ -17,8 +17,9 @@ const defOpN = (fn) => {
17
17
  stride: [tx, ty],
18
18
  offset
19
19
  } = out;
20
+ let ox;
20
21
  for (let x = 0; x < sx; x++) {
21
- const ox = offset + x * tx;
22
+ ox = offset + x * tx;
22
23
  for (let y = 0; y < sy; y++) data[ox + y * ty] = fn(a);
23
24
  }
24
25
  return out;
@@ -30,10 +31,11 @@ const defOpN = (fn) => {
30
31
  stride: [tx, ty, tz],
31
32
  offset
32
33
  } = out;
34
+ let ox, oy;
33
35
  for (let x = 0; x < sx; x++) {
34
- const ox = offset + x * tx;
36
+ ox = offset + x * tx;
35
37
  for (let y = 0; y < sy; y++) {
36
- const oy = ox + y * ty;
38
+ oy = ox + y * ty;
37
39
  for (let z = 0; z < sz; z++) data[oy + z * tz] = fn(a);
38
40
  }
39
41
  }
@@ -46,19 +48,27 @@ const defOpN = (fn) => {
46
48
  stride: [tx, ty, tz, tw],
47
49
  offset
48
50
  } = out;
51
+ let ox, oy, oz;
49
52
  for (let x = 0; x < sx; x++) {
50
- const ox = offset + x * tx;
53
+ ox = offset + x * tx;
51
54
  for (let y = 0; y < sy; y++) {
52
- const oy = ox + y * ty;
55
+ oy = ox + y * ty;
53
56
  for (let z = 0; z < sz; z++) {
54
- const oz = oy + z * tz;
57
+ oz = oy + z * tz;
55
58
  for (let w = 0; w < sw; w++) data[oz + w * tw] = fn(a);
56
59
  }
57
60
  }
58
61
  }
59
62
  return out;
60
63
  };
61
- return [top(0, void 0, f1, f2, f3, f4), f1, f2, f3, f4];
64
+ return top(
65
+ 0,
66
+ void 0,
67
+ f1,
68
+ f2,
69
+ f3,
70
+ f4
71
+ );
62
72
  };
63
73
  export {
64
74
  defOpN
package/defoprt.d.ts CHANGED
@@ -1,14 +1,12 @@
1
- import type { Fn0, Fn2 } from "@thi.ng/api";
2
- import type { MultiTensorOpRT, TensorOpRT } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
1
+ import type { Fn0 } from "@thi.ng/api";
2
+ import type { TensorData, TensorOpRT } from "./api.js";
4
3
  /**
5
4
  * Higher order tensor reduction op factory. Takes given reduction `rfn` and
6
- * `init` function to produce an initial result. Returns a 4-tuple of
7
- * {@link TensorOpRT}s applying the given function component-wise. The result
8
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d]`.
5
+ * `init` function to produce an initial result. Returns a {@link TensorOpRT}
6
+ * applying the given function component-wise.
9
7
  *
10
8
  * @param rfn
11
9
  * @param init
12
10
  */
13
- export declare const defOpRT: <A = number, B = A>(rfn: Fn2<B, A, B>, init: Fn0<B>) => [MultiTensorOpRT<A, B>, TensorOpRT<A, B, Tensor1<A>>, TensorOpRT<A, B, Tensor2<A>>, TensorOpRT<A, B, Tensor3<A>>, TensorOpRT<A, B, Tensor4<A>>];
11
+ export declare const defOpRT: <A = number, B = A>(rfn: (acc: B, data: TensorData<A>, i: number) => B, init: Fn0<B>) => import("./api.js").MultiTensorOpImpl<TensorOpRT<A, B>>;
14
12
  //# sourceMappingURL=defoprt.d.ts.map
package/defoprt.js CHANGED
@@ -9,7 +9,7 @@ const defOpRT = (rfn, init) => {
9
9
  } = a;
10
10
  let res = init();
11
11
  for (let x = 0; x < sx; x++) {
12
- res = rfn(res, data[offset + x * tx]);
12
+ res = rfn(res, data, offset + x * tx);
13
13
  }
14
14
  return res;
15
15
  };
@@ -21,10 +21,11 @@ const defOpRT = (rfn, init) => {
21
21
  stride: [tx, ty]
22
22
  } = a;
23
23
  let res = init();
24
+ let ox;
24
25
  for (let x = 0; x < sx; x++) {
25
- const ox = offset + x * tx;
26
+ ox = offset + x * tx;
26
27
  for (let y = 0; y < sy; y++) {
27
- res = rfn(res, data[ox + y * ty]);
28
+ res = rfn(res, data, ox + y * ty);
28
29
  }
29
30
  }
30
31
  return res;
@@ -37,12 +38,13 @@ const defOpRT = (rfn, init) => {
37
38
  stride: [tx, ty, tz]
38
39
  } = a;
39
40
  let res = init();
41
+ let ox, oy;
40
42
  for (let x = 0; x < sx; x++) {
41
- const ox = offset + x * tx;
43
+ ox = offset + x * tx;
42
44
  for (let y = 0; y < sy; y++) {
43
- const oy = ox + y * ty;
45
+ oy = ox + y * ty;
44
46
  for (let z = 0; z < sz; z++) {
45
- res = rfn(res, data[oy + z * tz]);
47
+ res = rfn(res, data, oy + z * tz);
46
48
  }
47
49
  }
48
50
  }
@@ -56,21 +58,29 @@ const defOpRT = (rfn, init) => {
56
58
  stride: [tx, ty, tz, tw]
57
59
  } = a;
58
60
  let res = init();
61
+ let ox, oy, oz;
59
62
  for (let x = 0; x < sx; x++) {
60
- const ox = offset + x * tx;
63
+ ox = offset + x * tx;
61
64
  for (let y = 0; y < sy; y++) {
62
- const oy = ox + y * ty;
65
+ oy = ox + y * ty;
63
66
  for (let z = 0; z < sz; z++) {
64
- const oz = oy + z * tz;
67
+ oz = oy + z * tz;
65
68
  for (let w = 0; w < sw; w++) {
66
- res = rfn(res, data[oz + w * tw]);
69
+ res = rfn(res, data, oz + w * tw);
67
70
  }
68
71
  }
69
72
  }
70
73
  }
71
74
  return res;
72
75
  };
73
- return [top(0, void 0, f1, f2, f3, f4), f1, f2, f3, f4];
76
+ return top(
77
+ 0,
78
+ void 0,
79
+ f1,
80
+ f2,
81
+ f3,
82
+ f4
83
+ );
74
84
  };
75
85
  export {
76
86
  defOpRT
package/defoprtt.d.ts CHANGED
@@ -1,14 +1,14 @@
1
- import type { Fn0, Fn3 } from "@thi.ng/api";
2
- import type { MultiTensorOpRTT, TensorOpRTT } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
1
+ import type { Fn0 } from "@thi.ng/api";
2
+ import type { TensorData, TensorOpRTT } from "./api.js";
4
3
  /**
5
4
  * Higher order tensor reduction op factory. Takes given reduction `rfn` and
6
- * `init` function to produce an initial result. Returns a 4-tuple of
7
- * {@link TensorOpRTT}s applying the given function component-wise. The result
8
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d]`.
5
+ * `init` function to produce an initial result. Returns a {@link TensorOpRTT}
6
+ * applying the given function componentwise, by default with broadcasting rules
7
+ * (see {@link broadcast} for details).
9
8
  *
10
9
  * @param rfn
11
10
  * @param init
11
+ * @param useBroadcast
12
12
  */
13
- export declare const defOpRTT: <A = number, B = A>(rfn: Fn3<B, A, A, B>, init: Fn0<B>) => [MultiTensorOpRTT<A, B>, TensorOpRTT<A, B, Tensor1<A>>, TensorOpRTT<A, B, Tensor2<A>>, TensorOpRTT<A, B, Tensor3<A>>, TensorOpRTT<A, B, Tensor4<A>>];
13
+ export declare const defOpRTT: <A = number, B = A>(rfn: (acc: B, adata: TensorData<A>, bdata: TensorData<A>, ia: number, ib: number) => B, init: Fn0<B>, useBroadcast?: boolean) => TensorOpRTT<A, B>;
14
14
  //# sourceMappingURL=defoprtt.d.ts.map
package/defoprtt.js CHANGED
@@ -1,5 +1,5 @@
1
- import { top } from "./top.js";
2
- const defOpRTT = (rfn, init) => {
1
+ import { broadcast } from "./broadcast.js";
2
+ const defOpRTT = (rfn, init, useBroadcast = true) => {
3
3
  const f1 = (a, b) => {
4
4
  const {
5
5
  data: adata,
@@ -14,7 +14,7 @@ const defOpRTT = (rfn, init) => {
14
14
  } = b;
15
15
  let res = init();
16
16
  for (let x = 0; x < sx; x++) {
17
- res = rfn(res, adata[oa + x * txa], bdata[ob + x * txb]);
17
+ res = rfn(res, adata, bdata, oa + x * txa, ob + x * txb);
18
18
  }
19
19
  return res;
20
20
  };
@@ -31,11 +31,12 @@ const defOpRTT = (rfn, init) => {
31
31
  stride: [txb, tyb]
32
32
  } = b;
33
33
  let res = init();
34
+ let oax, obx;
34
35
  for (let x = 0; x < sx; x++) {
35
- const oax = oa + x * txa;
36
- const obx = ob + x * txb;
36
+ oax = oa + x * txa;
37
+ obx = ob + x * txb;
37
38
  for (let y = 0; y < sy; y++) {
38
- res = rfn(res, adata[oax + y * tya], bdata[obx + y * tyb]);
39
+ res = rfn(res, adata, bdata, oax + y * tya, obx + y * tyb);
39
40
  }
40
41
  }
41
42
  return res;
@@ -53,14 +54,15 @@ const defOpRTT = (rfn, init) => {
53
54
  stride: [txb, tyb, tzb]
54
55
  } = b;
55
56
  let res = init();
57
+ let oax, obx, oay, oby;
56
58
  for (let x = 0; x < sx; x++) {
57
- const oax = oa + x * txa;
58
- const obx = ob + x * txb;
59
+ oax = oa + x * txa;
60
+ obx = ob + x * txb;
59
61
  for (let y = 0; y < sy; y++) {
60
- const oay = oax + y * tya;
61
- const oby = obx + y * tyb;
62
+ oay = oax + y * tya;
63
+ oby = obx + y * tyb;
62
64
  for (let z = 0; z < sz; z++) {
63
- res = rfn(res, adata[oay + z * tza], bdata[oby + z * tzb]);
65
+ res = rfn(res, adata, bdata, oay + z * tza, oby + z * tzb);
64
66
  }
65
67
  }
66
68
  }
@@ -79,20 +81,23 @@ const defOpRTT = (rfn, init) => {
79
81
  stride: [txb, tyb, tzb, twb]
80
82
  } = b;
81
83
  let res = init();
84
+ let oax, obx, oay, oby, oaz, obz;
82
85
  for (let x = 0; x < sx; x++) {
83
- const oax = oa + x * txa;
84
- const obx = ob + x * txb;
86
+ oax = oa + x * txa;
87
+ obx = ob + x * txb;
85
88
  for (let y = 0; y < sy; y++) {
86
- const oay = oax + y * tya;
87
- const oby = obx + y * tyb;
89
+ oay = oax + y * tya;
90
+ oby = obx + y * tyb;
88
91
  for (let z = 0; z < sz; z++) {
89
- const oaz = oay + z * tza;
90
- const obz = oby + z * tzb;
92
+ oaz = oay + z * tza;
93
+ obz = oby + z * tzb;
91
94
  for (let w = 0; w < sw; w++) {
92
95
  res = rfn(
93
96
  res,
94
- adata[oaz + w * twa],
95
- bdata[obz + w * twb]
97
+ adata,
98
+ bdata,
99
+ oaz + w * twa,
100
+ obz + w * twb
96
101
  );
97
102
  }
98
103
  }
@@ -100,13 +105,12 @@ const defOpRTT = (rfn, init) => {
100
105
  }
101
106
  return res;
102
107
  };
103
- return [
104
- top(0, void 0, f1, f2, f3, f4),
105
- f1,
106
- f2,
107
- f3,
108
- f4
109
- ];
108
+ const impls = [, f1, f2, f3, f4];
109
+ const wrapper = useBroadcast ? (a, b) => {
110
+ const { shape, a: $a, b: $b } = broadcast(a, b);
111
+ return impls[shape.length]($a, $b);
112
+ } : (a, b) => impls[a.dim](a, b);
113
+ return wrapper;
110
114
  };
111
115
  export {
112
116
  defOpRTT
package/defopt.d.ts CHANGED
@@ -1,13 +1,10 @@
1
1
  import type { Fn } from "@thi.ng/api";
2
- import type { MultiTensorOpT, TensorOpT } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpT } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpT}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpT} applying the given function component-wise.
8
6
  *
9
7
  * @param fn
10
- * @param dispatch
11
8
  */
12
- export declare const defOpT: <T = number>(fn: Fn<T, T>, dispatch?: number) => [MultiTensorOpT<T>, TensorOpT<T, T, Tensor1<T>, Tensor1<T>>, TensorOpT<T, T, Tensor2<T>, Tensor2<T>>, TensorOpT<T, T, Tensor3<T>, Tensor3<T>>, TensorOpT<T, T, Tensor4<T>, Tensor4<T>>];
9
+ export declare const defOpT: <T = number>(fn: Fn<T, T>) => import("./api.js").MultiTensorOpImpl<TensorOpT<T>>;
13
10
  //# sourceMappingURL=defopt.d.ts.map
package/defopt.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { top } from "./top.js";
2
- const defOpT = (fn, dispatch = 1) => {
2
+ const defOpT = (fn) => {
3
3
  const f1 = (out, a) => {
4
4
  !out && (out = a);
5
5
  const {
@@ -31,9 +31,10 @@ const defOpT = (fn, dispatch = 1) => {
31
31
  shape: [sx, sy],
32
32
  stride: [txa, tya]
33
33
  } = a;
34
+ let oox, oax;
34
35
  for (let x = 0; x < sx; x++) {
35
- const oox = oo + x * txo;
36
- const oax = oa + x * txa;
36
+ oox = oo + x * txo;
37
+ oax = oa + x * txa;
37
38
  for (let y = 0; y < sy; y++) {
38
39
  odata[oox + y * tyo] = fn(adata[oax + y * tya]);
39
40
  }
@@ -53,12 +54,13 @@ const defOpT = (fn, dispatch = 1) => {
53
54
  shape: [sx, sy, sz],
54
55
  stride: [txa, tya, tza]
55
56
  } = a;
57
+ let oox, oax, ooy, oay;
56
58
  for (let x = 0; x < sx; x++) {
57
- const oox = oo + x * txo;
58
- const oax = oa + x * txa;
59
+ oox = oo + x * txo;
60
+ oax = oa + x * txa;
59
61
  for (let y = 0; y < sy; y++) {
60
- const ooy = oox + y * tyo;
61
- const oay = oax + y * tya;
62
+ ooy = oox + y * tyo;
63
+ oay = oax + y * tya;
62
64
  for (let z = 0; z < sz; z++) {
63
65
  odata[ooy + z * tzo] = fn(adata[oay + z * tza]);
64
66
  }
@@ -79,15 +81,16 @@ const defOpT = (fn, dispatch = 1) => {
79
81
  shape: [sx, sy, sz, sw],
80
82
  stride: [txa, tya, tza, twa]
81
83
  } = a;
84
+ let oox, oax, ooy, oay, ooz, oaz;
82
85
  for (let x = 0; x < sx; x++) {
83
- const oox = oo + x * txo;
84
- const oax = oa + x * txa;
86
+ oox = oo + x * txo;
87
+ oax = oa + x * txa;
85
88
  for (let y = 0; y < sy; y++) {
86
- const ooy = oox + y * tyo;
87
- const oay = oax + y * tya;
89
+ ooy = oox + y * tyo;
90
+ oay = oax + y * tya;
88
91
  for (let z = 0; z < sz; z++) {
89
- const ooz = ooy + z * tzo;
90
- const oaz = oay + z * tza;
92
+ ooz = ooy + z * tzo;
93
+ oaz = oay + z * tza;
91
94
  for (let w = 0; w < sw; w++) {
92
95
  odata[ooz + w * two] = fn(adata[oaz + w * twa]);
93
96
  }
@@ -96,13 +99,7 @@ const defOpT = (fn, dispatch = 1) => {
96
99
  }
97
100
  return out;
98
101
  };
99
- return [
100
- top(dispatch, void 0, f1, f2, f3, f4),
101
- f1,
102
- f2,
103
- f3,
104
- f4
105
- ];
102
+ return top(1, void 0, f1, f2, f3, f4);
106
103
  };
107
104
  export {
108
105
  defOpT
package/defoptn.d.ts CHANGED
@@ -1,13 +1,10 @@
1
1
  import type { FnU2 } from "@thi.ng/api";
2
- import type { MultiTensorOpTN, TensorOpTN } from "./api.js";
3
- import type { Tensor1, Tensor2, Tensor3, Tensor4 } from "./tensor.js";
2
+ import type { TensorOpTN } from "./api.js";
4
3
  /**
5
- * Higher order tensor op factory. Takes given `fn` and returns a 4-tuple of
6
- * {@link TensorOpTN}s applying the given function component-wise. The result
7
- * tuple uses this order: `[polymorphic, 1d, 2d, 3d, 4d]`.
4
+ * Higher order tensor op factory. Takes given `fn` and returns a
5
+ * {@link TensorOpTN} applying the given function component-wise.
8
6
  *
9
7
  * @param fn
10
- * @param dispatch
11
8
  */
12
- export declare const defOpTN: <T = number>(fn: FnU2<T>, dispatch?: number) => [MultiTensorOpTN<T>, TensorOpTN<T, T, Tensor1<T>, Tensor1<T>>, TensorOpTN<T, T, Tensor2<T>, Tensor2<T>>, TensorOpTN<T, T, Tensor3<T>, Tensor3<T>>, TensorOpTN<T, T, Tensor4<T>, Tensor4<T>>];
9
+ export declare const defOpTN: <T = number>(fn: FnU2<T>) => import("./api.js").MultiTensorOpImpl<TensorOpTN<T>>;
13
10
  //# sourceMappingURL=defoptn.d.ts.map