@mni-ml/framework 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/autodiff.d.ts +13 -0
- package/dist/autodiff.d.ts.map +1 -0
- package/dist/autodiff.js +91 -0
- package/dist/autodiff.js.map +1 -0
- package/dist/datasets.d.ts +16 -0
- package/dist/datasets.d.ts.map +1 -0
- package/dist/datasets.js +64 -0
- package/dist/datasets.js.map +1 -0
- package/dist/fast_ops.d.ts +23 -0
- package/dist/fast_ops.d.ts.map +1 -0
- package/dist/fast_ops.js +263 -0
- package/dist/fast_ops.js.map +1 -0
- package/dist/fast_ops_worker.d.ts +2 -0
- package/dist/fast_ops_worker.d.ts.map +1 -0
- package/dist/fast_ops_worker.js +119 -0
- package/dist/fast_ops_worker.js.map +1 -0
- package/dist/gpu_backend.d.ts +37 -0
- package/dist/gpu_backend.d.ts.map +1 -0
- package/dist/gpu_backend.js +163 -0
- package/dist/gpu_backend.js.map +1 -0
- package/dist/gpu_kernels.d.ts +74 -0
- package/dist/gpu_kernels.d.ts.map +1 -0
- package/dist/gpu_kernels.js +571 -0
- package/dist/gpu_kernels.js.map +1 -0
- package/dist/gpu_ops.d.ts +43 -0
- package/dist/gpu_ops.d.ts.map +1 -0
- package/dist/gpu_ops.js +365 -0
- package/dist/gpu_ops.js.map +1 -0
- package/dist/index.d.ts +15 -0
- package/dist/index.d.ts.map +1 -0
- package/dist/index.js +20 -0
- package/dist/index.js.map +1 -0
- package/dist/module.d.ts +23 -0
- package/dist/module.d.ts.map +1 -0
- package/dist/module.js +97 -0
- package/dist/module.js.map +1 -0
- package/dist/nn.d.ts +63 -0
- package/dist/nn.d.ts.map +1 -0
- package/dist/nn.js +234 -0
- package/dist/nn.js.map +1 -0
- package/dist/operators.d.ts +29 -0
- package/dist/operators.d.ts.map +1 -0
- package/dist/operators.js +91 -0
- package/dist/operators.js.map +1 -0
- package/dist/optimizer.d.ts +15 -0
- package/dist/optimizer.d.ts.map +1 -0
- package/dist/optimizer.js +62 -0
- package/dist/optimizer.js.map +1 -0
- package/dist/scalar.d.ts +42 -0
- package/dist/scalar.d.ts.map +1 -0
- package/dist/scalar.js +126 -0
- package/dist/scalar.js.map +1 -0
- package/dist/scalar_functions.d.ts +62 -0
- package/dist/scalar_functions.d.ts.map +1 -0
- package/dist/scalar_functions.js +127 -0
- package/dist/scalar_functions.js.map +1 -0
- package/dist/tensor.d.ts +58 -0
- package/dist/tensor.d.ts.map +1 -0
- package/dist/tensor.js +288 -0
- package/dist/tensor.js.map +1 -0
- package/dist/tensor_data.d.ts +29 -0
- package/dist/tensor_data.d.ts.map +1 -0
- package/dist/tensor_data.js +131 -0
- package/dist/tensor_data.js.map +1 -0
- package/dist/tensor_functions.d.ts +97 -0
- package/dist/tensor_functions.d.ts.map +1 -0
- package/dist/tensor_functions.js +465 -0
- package/dist/tensor_functions.js.map +1 -0
- package/dist/tensor_ops.d.ts +47 -0
- package/dist/tensor_ops.d.ts.map +1 -0
- package/dist/tensor_ops.js +249 -0
- package/dist/tensor_ops.js.map +1 -0
- package/package.json +45 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
import { Scalar } from "./scalar.js";
|
|
2
|
+
import { Tensor } from "./tensor.js";
|
|
3
|
+
/**
 * Estimates the partial derivative of `f` with respect to argument `arg` by a
 * central finite difference: (f(..., v + epsilon, ...) - f(..., v - epsilon, ...)) / (2 * epsilon).
 * `vals` is copied, not mutated. Defaults: `arg = 0`, `epsilon = 1e-6`.
 */
export declare function centralDifference(f: (...args: number[]) => number, vals: number[], arg?: number, epsilon?: number): number;
|
|
4
|
+
/**
 * Holds the numbers passed to `saveForBackward` so they can be read back
 * through `savedValues` (presumably by a backward pass).
 */
export declare class Context {
    /** Values from the most recent `saveForBackward` call; empty initially. */
    get savedValues(): number[];
    /** Replaces the stored values with `values`. */
    saveForBackward(...values: number[]): void;
    private _savedValues;
}
|
|
9
|
+
/**
 * Collects every Scalar reachable from `scalar` through `parents` links
 * (including `scalar` itself), each exactly once, ordered so that every node
 * precedes its parents; `scalar` is first.
 */
export declare function topologicalSort(scalar: Scalar): Scalar[];
|
|
10
|
+
/**
 * Propagates `dOut` backwards from `scalar`. Contributions reaching a node are
 * summed; non-leaf nodes forward the total to their parents via `chainRule`,
 * and leaf nodes receive it via `accumulateDerivative`.
 */
export declare function backPropagate(scalar: Scalar, dOut: number): void;
|
|
11
|
+
/**
 * Tensor counterpart of `topologicalSort`: every Tensor reachable through
 * `parents` (including `tensor`), each once, with each node before its parents.
 */
export declare function topologicalSortTensor(tensor: Tensor): Tensor[];
|
|
12
|
+
/**
 * Propagates `gradOutput` backwards from `tensor`. Gradients reaching the same
 * node are combined with `Tensor.add`; non-leaf nodes forward them via
 * `chainRule`, and leaf nodes receive the total via `accumulateGrad`.
 */
export declare function backPropagateTensor(tensor: Tensor, gradOutput: Tensor): void;
|
|
13
|
+
//# sourceMappingURL=autodiff.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"autodiff.d.ts","sourceRoot":"","sources":["../src/autodiff.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAC;AACrC,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAC;AAErC,wBAAgB,iBAAiB,CAC7B,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,MAAM,EAAE,KAAK,MAAM,EAChC,IAAI,EAAE,MAAM,EAAE,EACd,GAAG,GAAE,MAAU,EACf,OAAO,GAAE,MAAa,GACvB,MAAM,CAQR;AAED,qBAAa,OAAO;IAChB,OAAO,CAAC,YAAY,CAAgB;IAEpC,eAAe,CAAE,GAAG,MAAM,EAAE,MAAM,EAAE,GAAG,IAAI;IAI3C,IAAI,WAAW,IAAI,MAAM,EAAE,CAE1B;CACJ;AAED,wBAAgB,eAAe,CAAC,MAAM,EAAE,MAAM,GAAG,MAAM,EAAE,CAcxD;AAED,wBAAgB,aAAa,CAAC,MAAM,EAAE,MAAM,EAAE,IAAI,EAAE,MAAM,GAAG,IAAI,CAkBhE;AAED,wBAAgB,qBAAqB,CAAC,MAAM,EAAE,MAAM,GAAG,MAAM,EAAE,CAc9D;AAED,wBAAgB,mBAAmB,CAAC,MAAM,EAAE,MAAM,EAAE,UAAU,EAAE,MAAM,GAAG,IAAI,CAuB5E"}
|
package/dist/autodiff.js
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import { Scalar } from "./scalar.js";
|
|
2
|
+
import { Tensor } from "./tensor.js";
|
|
3
|
+
/**
 * Numerically estimates the partial derivative of `f` with respect to
 * `vals[arg]` using a symmetric difference quotient:
 *   (f(..., v + epsilon, ...) - f(..., v - epsilon, ...)) / (2 * epsilon)
 * Two shifted copies of `vals` are built; the caller's array is left untouched.
 */
export function centralDifference(f, vals, arg = 0, epsilon = 1e-6) {
    const forward = Array.from(vals);
    const backward = Array.from(vals);
    forward[arg] += epsilon;
    backward[arg] -= epsilon;
    const rise = f(...forward) - f(...backward);
    return rise / (2 * epsilon);
}
|
|
10
|
+
/**
 * Small container that keeps the numbers handed to `saveForBackward`
 * so they can later be read back through `savedValues`.
 */
export class Context {
    _savedValues = [];
    /** The values stored by the latest `saveForBackward` call (empty until then). */
    get savedValues() {
        return this._savedValues;
    }
    /** Replaces (does not append to) the stored values. */
    saveForBackward(...values) {
        this._savedValues = values;
    }
}
|
|
19
|
+
/**
 * Orders every Scalar reachable from `scalar` through `parents` links
 * (including `scalar` itself) so that each node comes before all of its
 * parents; `scalar` is always first. Each node appears exactly once.
 *
 * Implemented as a depth-first post-order walk with an explicit stack rather
 * than recursion: scalar graphs built by long chains of operations (e.g.
 * summing many terms in a loop) can be tens of thousands of nodes deep, which
 * overflows the JS call stack with a recursive walk. The resulting order is
 * identical to the recursive version (parents visited in iteration order).
 *
 * @param scalar root of the graph to sort.
 * @returns nodes in reverse post-order (root first).
 */
export function topologicalSort(scalar) {
    const visited = new Set([scalar]);
    const postOrder = [];
    // Each frame holds a node and a live iterator over its parents, mirroring
    // the for...of loop of a recursive DFS.
    const stack = [{ node: scalar, parents: scalar.parents[Symbol.iterator]() }];
    while (stack.length > 0) {
        const frame = stack[stack.length - 1];
        const step = frame.parents.next();
        if (step.done) {
            // All parents finished: emit this node after them.
            stack.pop();
            postOrder.push(frame.node);
        }
        else if (!visited.has(step.value)) {
            visited.add(step.value);
            stack.push({ node: step.value, parents: step.value.parents[Symbol.iterator]() });
        }
    }
    return postOrder.reverse();
}
|
|
34
|
+
/**
 * Reverse-mode pass over a scalar graph.
 *
 * Nodes are processed in `topologicalSort` order (root first). Every
 * derivative contribution reaching a node is summed; a leaf then receives the
 * total through `accumulateDerivative`, while any other node hands it to its
 * parents via `chainRule`. Nodes that never receive a contribution are skipped.
 *
 * @param scalar output node to differentiate from.
 * @param dOut upstream derivative seeded at `scalar`.
 */
export function backPropagate(scalar, dOut) {
    const order = topologicalSort(scalar);
    const pending = new Map([[scalar, dOut]]);
    for (const node of order) {
        const total = pending.get(node);
        if (total === undefined)
            continue;
        if (!node.isLeaf()) {
            for (const [parent, contribution] of node.chainRule(total)) {
                pending.set(parent, (pending.get(parent) ?? 0) + contribution);
            }
            continue;
        }
        node.accumulateDerivative(total);
    }
}
|
|
52
|
+
/**
 * Orders every Tensor reachable from `tensor` through `parents` links
 * (including `tensor` itself) so each node precedes all of its parents;
 * `tensor` is first and each node appears once.
 *
 * Uses an explicit stack instead of recursion so very deep graphs cannot
 * overflow the call stack; the order matches a recursive post-order DFS.
 *
 * @param tensor root of the graph to sort.
 * @returns nodes in reverse post-order (root first).
 */
export function topologicalSortTensor(tensor) {
    const visited = new Set([tensor]);
    const postOrder = [];
    // Frame = node + live iterator over its parents (the recursive for...of).
    const stack = [{ node: tensor, parents: tensor.parents[Symbol.iterator]() }];
    while (stack.length > 0) {
        const frame = stack[stack.length - 1];
        const step = frame.parents.next();
        if (step.done) {
            stack.pop();
            postOrder.push(frame.node);
        }
        else if (!visited.has(step.value)) {
            visited.add(step.value);
            stack.push({ node: step.value, parents: step.value.parents[Symbol.iterator]() });
        }
    }
    return postOrder.reverse();
}
|
|
67
|
+
/**
 * Reverse-mode pass over a tensor graph.
 *
 * Visits nodes in `topologicalSortTensor` order (root first). Gradients that
 * reach the same node are combined with `Tensor.add`; leaves receive the total
 * via `accumulateGrad`, other nodes forward it to their parents via
 * `chainRule`. Nodes with no incoming gradient are skipped.
 *
 * @param tensor output node to differentiate from.
 * @param gradOutput upstream gradient seeded at `tensor`.
 */
export function backPropagateTensor(tensor, gradOutput) {
    const order = topologicalSortTensor(tensor);
    const pending = new Map([[tensor, gradOutput]]);
    for (const node of order) {
        const incoming = pending.get(node);
        if (incoming === undefined)
            continue;
        if (node.isLeaf()) {
            node.accumulateGrad(incoming);
            continue;
        }
        for (const [parent, contribution] of node.chainRule(incoming)) {
            const soFar = pending.get(parent);
            pending.set(parent, soFar ? soFar.add(contribution) : contribution);
        }
    }
}
|
|
91
|
+
//# sourceMappingURL=autodiff.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"autodiff.js","sourceRoot":"","sources":["../src/autodiff.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAC;AACrC,OAAO,EAAE,MAAM,EAAE,MAAM,aAAa,CAAC;AAErC,MAAM,UAAU,iBAAiB,CAC7B,CAAgC,EAChC,IAAc,EACd,MAAc,CAAC,EACf,UAAkB,IAAI;IAEtB,MAAM,QAAQ,GAAG,CAAC,GAAG,IAAI,CAAC,CAAC;IAC3B,QAAQ,CAAC,GAAG,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAE,GAAG,OAAO,CAAC;IAEzC,MAAM,SAAS,GAAG,CAAC,GAAG,IAAI,CAAC,CAAC;IAC5B,SAAS,CAAC,GAAG,CAAC,GAAG,SAAS,CAAC,GAAG,CAAE,GAAG,OAAO,CAAC;IAE3C,OAAO,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC,GAAG,CAAC,CAAC,GAAG,SAAS,CAAC,CAAC,GAAG,CAAC,CAAC,GAAG,OAAO,CAAC,CAAC;AAC9D,CAAC;AAED,MAAM,OAAO,OAAO;IACR,YAAY,GAAa,EAAE,CAAC;IAEpC,eAAe,CAAE,GAAG,MAAgB;QAChC,IAAI,CAAC,YAAY,GAAG,MAAM,CAAC;IAC/B,CAAC;IAED,IAAI,WAAW;QACX,OAAO,IAAI,CAAC,YAAY,CAAC;IAC7B,CAAC;CACJ;AAED,MAAM,UAAU,eAAe,CAAC,MAAc;IAC1C,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;IAClC,MAAM,MAAM,GAAG,IAAI,KAAK,EAAU,CAAC;IAEnC,MAAM,GAAG,GAA6B,CAAC,MAAM,EAAE,EAAE;QAC7C,IAAI,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC;YAAE,OAAO;QAChC,OAAO,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;QACpB,KAAK,MAAM,MAAM,IAAI,MAAM,CAAC,OAAO,EAAE,CAAC;YAClC,GAAG,CAAC,MAAM,CAAC,CAAC;QAChB,CAAC;QACD,MAAM,CAAC,IAAI,CAAC,MAAM,CAAC,CAAC;IACxB,CAAC,CAAC;IACF,GAAG,CAAC,MAAM,CAAC,CAAC;IACZ,OAAO,MAAM,CAAC,OAAO,EAAE,CAAC;AAC5B,CAAC;AAED,MAAM,UAAU,aAAa,CAAC,MAAc,EAAE,IAAY;IACtD,MAAM,MAAM,GAAG,eAAe,CAAC,MAAM,CAAC,CAAC;IACvC,MAAM,WAAW,GAAwB,IAAI,GAAG,EAAE,CAAC;IAEnD,WAAW,CAAC,GAAG,CAAC,MAAM,EAAE,IAAI,CAAC,CAAC;IAE9B,KAAK,MAAM,IAAI,IAAI,MAAM,EAAE,CAAC;QACxB,MAAM,CAAC,GAAG,WAAW,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;QAChC,IAAI,CAAC,KAAK,SAAS;YAAE,SAAS;QAE9B,IAAI,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC;YAChB,IAAI,CAAC,oBAAoB,CAAC,CAAC,CAAC,CAAC;QACjC,CAAC;aAAM,CAAC;YACJ,KAAK,MAAM,CAAC,MAAM,EAAE,IAAI,CAAC,IAAI,IAAI,CAAC,SAAS,CAAC,CAAC,CAAC,EAAE,CAAC;gBAC7C,WAAW,CAAC,GAAG,CAAC,MAAM,EAAE,CAAC,WAAW,CAAC,GAAG,CAAC,MAAM,CAAC,IAAI,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;YACnE,CAAC;QACL,CAAC;IACL,CAAC;AACL,CAAC;AAED,MAAM,UAAU,qBAAqB,CAAC,MAAc;IAChD,MAAM,OAAO,GAAG,IAAI,GAAG,EAAU,CAAC;IAClC,MAAM,MAAM,
GAAa,EAAE,CAAC;IAE5B,MAAM,GAAG,GAAG,CAAC,CAAS,EAAE,EAAE;QACtB,IAAI,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC;YAAE,OAAO;QAC3B,OAAO,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC;QACf,KAAK,MAAM,MAAM,IAAI,CAAC,CAAC,OAAO,EAAE,CAAC;YAC7B,GAAG,CAAC,MAAM,CAAC,CAAC;QAChB,CAAC;QACD,MAAM,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACnB,CAAC,CAAC;IACF,GAAG,CAAC,MAAM,CAAC,CAAC;IACZ,OAAO,MAAM,CAAC,OAAO,EAAE,CAAC;AAC5B,CAAC;AAED,MAAM,UAAU,mBAAmB,CAAC,MAAc,EAAE,UAAkB;IAClE,MAAM,MAAM,GAAG,qBAAqB,CAAC,MAAM,CAAC,CAAC;IAC7C,MAAM,SAAS,GAAwB,IAAI,GAAG,EAAE,CAAC;IAEjD,SAAS,CAAC,GAAG,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;IAElC,KAAK,MAAM,IAAI,IAAI,MAAM,EAAE,CAAC;QACxB,MAAM,IAAI,GAAG,SAAS,CAAC,GAAG,CAAC,IAAI,CAAC,CAAC;QACjC,IAAI,IAAI,KAAK,SAAS;YAAE,SAAS;QAEjC,IAAI,IAAI,CAAC,MAAM,EAAE,EAAE,CAAC;YAChB,IAAI,CAAC,cAAc,CAAC,IAAI,CAAC,CAAC;QAC9B,CAAC;aAAM,CAAC;YACJ,KAAK,MAAM,CAAC,MAAM,EAAE,UAAU,CAAC,IAAI,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,EAAE,CAAC;gBACtD,MAAM,QAAQ,GAAG,SAAS,CAAC,GAAG,CAAC,MAAM,CAAC,CAAC;gBACvC,IAAI,QAAQ,EAAE,CAAC;oBACX,SAAS,CAAC,GAAG,CAAC,MAAM,EAAE,QAAQ,CAAC,GAAG,CAAC,UAAU,CAAC,CAAC,CAAC;gBACpD,CAAC;qBAAM,CAAC;oBACJ,SAAS,CAAC,GAAG,CAAC,MAAM,EAAE,UAAU,CAAC,CAAC;gBACtC,CAAC;YACL,CAAC;QACL,CAAC;IACL,CAAC;AACL,CAAC"}
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
/** A 2D point as an `[x1, x2]` tuple. */
export type Point = [number, number];
|
|
2
|
+
/** A labelled 2D classification dataset produced by the generators below. */
export interface Graph {
    /** The `N` that was passed to the generator. */
    N: number;
    /** Input points. */
    X: Point[];
    /** Class label per point (the built-in generators emit 0 or 1). */
    y: number[];
}
|
|
7
|
+
/**
 * Generates `N` random 2D points; both coordinates come from `Math.random()`
 * and therefore lie in [0, 1).
 */
export declare function makePts(N: number): Point[];
|
|
9
|
+
/** `N` random points labelled 1 when x1 < 0.5, otherwise 0. */
export declare function simple(N: number): Graph;
|
|
10
|
+
/** `N` random points labelled 1 when x1 + x2 < 0.5, otherwise 0. */
export declare function diag(N: number): Graph;
|
|
11
|
+
/** `N` random points labelled 1 when x1 < 0.2 or x1 > 0.8, otherwise 0. */
export declare function split(N: number): Graph;
|
|
12
|
+
/**
 * `N` random points labelled 1 when (x1 < 0.5 and x2 > 0.5) or
 * (x1 > 0.5 and x2 < 0.5), otherwise 0.
 */
export declare function xor(N: number): Graph;
|
|
13
|
+
/** `N` random points labelled 1 when their squared distance from (0.5, 0.5) exceeds 0.1, otherwise 0. */
export declare function circle(N: number): Graph;
|
|
14
|
+
/**
 * Deterministic (no randomness) dataset of two spiral arms around (0.5, 0.5):
 * points of the first arm are labelled 0, those of the second arm 1.
 */
export declare function spiral(N: number): Graph;
|
|
15
|
+
/** Dataset generators keyed by display name: Simple, Diag, Split, Xor, Circle, Spiral. */
export declare const datasets: Record<string, (N: number) => Graph>;
|
|
16
|
+
//# sourceMappingURL=datasets.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"datasets.d.ts","sourceRoot":"","sources":["../src/datasets.ts"],"names":[],"mappings":"AAAA,MAAM,MAAM,KAAK,GAAG,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;AAErC,MAAM,WAAW,KAAK;IACpB,CAAC,EAAE,MAAM,CAAC;IACV,CAAC,EAAE,KAAK,EAAE,CAAC;IACX,CAAC,EAAE,MAAM,EAAE,CAAC;CACb;AAED,qDAAqD;AACrD,wBAAgB,OAAO,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,EAAE,CAQ1C;AAED,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAIvC;AAED,wBAAgB,IAAI,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAIrC;AAED,wBAAgB,KAAK,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAItC;AAED,wBAAgB,GAAG,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAMpC;AAED,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAQvC;AAED,wBAAgB,MAAM,CAAC,CAAC,EAAE,MAAM,GAAG,KAAK,CAoBvC;AAED,eAAO,MAAM,QAAQ,EAAE,MAAM,CAAC,MAAM,EAAE,CAAC,CAAC,EAAE,MAAM,KAAK,KAAK,CAOzD,CAAC"}
|
package/dist/datasets.js
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
/**
 * Generates N random 2D points in [0, 1) × [0, 1).
 * Each point is a new `[x1, x2]` pair; x1 is drawn before x2 from Math.random().
 */
export function makePts(N) {
    const points = [];
    while (points.length < N) {
        points.push([Math.random(), Math.random()]);
    }
    return points;
}
|
|
11
|
+
/** Vertical split: label 1 for points with x1 < 0.5, otherwise 0. */
export function simple(N) {
    const X = makePts(N);
    const y = [];
    for (const [x1] of X) {
        y.push(x1 < 0.5 ? 1 : 0);
    }
    return { N, X, y };
}
|
|
16
|
+
/** Diagonal corner: label 1 when x1 + x2 < 0.5, otherwise 0. */
export function diag(N) {
    const X = makePts(N);
    const y = X.map((point) => {
        const sum = point[0] + point[1];
        return sum < 0.5 ? 1 : 0;
    });
    return { N, X, y };
}
|
|
21
|
+
/** Two outer bands: label 1 when x1 < 0.2 or x1 > 0.8, otherwise 0. */
export function split(N) {
    const X = makePts(N);
    const inOuterBand = (x1) => x1 < 0.2 || x1 > 0.8;
    const y = X.map((point) => (inOuterBand(point[0]) ? 1 : 0));
    return { N, X, y };
}
|
|
26
|
+
/**
 * XOR pattern: label 1 for points with (x1 < 0.5 and x2 > 0.5) or
 * (x1 > 0.5 and x2 < 0.5); everything else, including points exactly on
 * either 0.5 line, gets 0.
 */
export function xor(N) {
    const X = makePts(N);
    const label = ([x1, x2]) => {
        if (x1 < 0.5 && x2 > 0.5)
            return 1;
        if (x1 > 0.5 && x2 < 0.5)
            return 1;
        return 0;
    };
    return { N, X, y: X.map(label) };
}
|
|
31
|
+
/** Ring: label 1 when the squared distance from (0.5, 0.5) exceeds 0.1, otherwise 0. */
export function circle(N) {
    const X = makePts(N);
    const y = X.map(([x1, x2]) => {
        const r2 = (x1 - 0.5) * (x1 - 0.5) + (x2 - 0.5) * (x2 - 0.5);
        return r2 > 0.1 ? 1 : 0;
    });
    return { N, X, y };
}
|
|
40
|
+
/**
 * Two interleaved spiral arms centred on (0.5, 0.5). Deterministic.
 *
 * Arm 0 (label 0) follows (t·cos t, t·sin t)/20 for t sampled evenly in
 * [0, 10); arm 1 (label 1) uses negated t with the coordinates swapped.
 *
 * Exactly `N` points are produced, floor(N/2) on arm 0 and the remainder on
 * arm 1, so `X.length === y.length === N`. Splitting the remainder this way
 * keeps odd N from silently dropping a point. For even N the output is
 * unchanged. Negative or fractional N is clamped to the largest integer
 * count >= 0 instead of throwing from `Array(n)`.
 *
 * @param N number of points to generate.
 * @returns `{ N, X, y }` with arm-0 points first.
 */
export function spiral(N) {
    const count = Math.max(0, Math.floor(N));
    const n0 = Math.floor(count / 2);
    const n1 = count - n0;
    const fx = (t) => (t * Math.cos(t)) / 20.0;
    const fy = (t) => (t * Math.sin(t)) / 20.0;
    // Array.from never calls the callback for length 0, so i / n0 with n0 === 0 is never evaluated.
    const X1 = Array.from({ length: n0 }, (_, i) => {
        const t = 10.0 * (i / n0);
        return [fx(t) + 0.5, fy(t) + 0.5];
    });
    const X2 = Array.from({ length: n1 }, (_, i) => {
        const t = -10.0 * (i / n1);
        return [fy(t) + 0.5, fx(t) + 0.5];
    });
    const X = [...X1, ...X2];
    const y = [...Array(n0).fill(0), ...Array(n1).fill(1)];
    return { N, X, y };
}
|
|
56
|
+
/**
 * Registry of the built-in 2D classification datasets, keyed by display
 * name. Every value is a generator taking the number of points N.
 */
export const datasets = {
    "Simple": simple,
    "Diag": diag,
    "Split": split,
    "Xor": xor,
    "Circle": circle,
    "Spiral": spiral,
};
|
|
64
|
+
//# sourceMappingURL=datasets.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"datasets.js","sourceRoot":"","sources":["../src/datasets.ts"],"names":[],"mappings":"AAQA,qDAAqD;AACrD,MAAM,UAAU,OAAO,CAAC,CAAS;IAC/B,MAAM,CAAC,GAAY,EAAE,CAAC;IACtB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QAC3B,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC;QACzB,MAAM,EAAE,GAAG,IAAI,CAAC,MAAM,EAAE,CAAC;QACzB,CAAC,CAAC,IAAI,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,CAAC,CAAC;IACnB,CAAC;IACD,OAAO,CAAC,CAAC;AACX,CAAC;AAED,MAAM,UAAU,MAAM,CAAC,CAAS;IAC9B,MAAM,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;IACrB,MAAM,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC9C,OAAO,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,IAAI,CAAC,CAAS;IAC5B,MAAM,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;IACrB,MAAM,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,EAAE,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACvD,OAAO,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,KAAK,CAAC,CAAS;IAC7B,MAAM,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;IACrB,MAAM,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC,CAAC,EAAE,GAAG,GAAG,IAAI,EAAE,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IAC1D,OAAO,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,GAAG,CAAC,CAAS;IAC3B,MAAM,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;IACrB,MAAM,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE,CAC3B,CAAC,EAAE,GAAG,GAAG,IAAI,EAAE,GAAG,GAAG,CAAC,IAAI,CAAC,EAAE,GAAG,GAAG,IAAI,EAAE,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CACzD,CAAC;IACF,OAAO,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,MAAM,CAAC,CAAS;IAC9B,MAAM,CAAC,GAAG,OAAO,CAAC,CAAC,CAAC,CAAC;IACrB,MAAM,CAAC,GAAG,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC,EAAE,EAAE,EAAE,CAAC,EAAE,EAAE;QAC3B,MAAM,EAAE,GAAG,EAAE,GAAG,GAAG,CAAC;QACpB,MAAM,EAAE,GAAG,EAAE,GAAG,GAAG,CAAC;QACpB,OAAO,EAAE,GAAG,EAAE,GAAG,EAA
E,GAAG,EAAE,GAAG,GAAG,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC,CAAC;IACzC,CAAC,CAAC,CAAC;IACH,OAAO,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;AACrB,CAAC;AAED,MAAM,UAAU,MAAM,CAAC,CAAS;IAC9B,MAAM,IAAI,GAAG,IAAI,CAAC,KAAK,CAAC,CAAC,GAAG,CAAC,CAAC,CAAC;IAE/B,MAAM,EAAE,GAAG,CAAC,CAAS,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC;IACnD,MAAM,EAAE,GAAG,CAAC,CAAS,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC,CAAC,GAAG,IAAI,CAAC;IAEnD,MAAM,EAAE,GAAY,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE;QACxD,MAAM,CAAC,GAAG,IAAI,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;QAC5B,OAAO,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,MAAM,EAAE,GAAY,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,IAAI,EAAE,EAAE,CAAC,CAAC,EAAE,CAAC,EAAE,EAAE;QACxD,MAAM,CAAC,GAAG,CAAC,IAAI,GAAG,CAAC,CAAC,GAAG,IAAI,CAAC,CAAC;QAC7B,OAAO,CAAC,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,EAAE,EAAE,CAAC,CAAC,CAAC,GAAG,GAAG,CAAC,CAAC;IACpC,CAAC,CAAC,CAAC;IAEH,MAAM,CAAC,GAAG,CAAC,GAAG,EAAE,EAAE,GAAG,EAAE,CAAC,CAAC;IACzB,MAAM,CAAC,GAAG,CAAC,GAAG,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,EAAE,GAAG,KAAK,CAAC,IAAI,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC,CAAC;IAE3D,OAAO,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC,EAAE,CAAC;AACrB,CAAC;AAED,MAAM,CAAC,MAAM,QAAQ,GAAyC;IAC5D,MAAM,EAAE,MAAM;IACd,IAAI,EAAE,IAAI;IACV,KAAK,EAAE,KAAK;IACZ,GAAG,EAAE,GAAG;IACR,MAAM,EAAE,MAAM;IACd,MAAM,EAAE,MAAM;CACf,CAAC"}
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import type { Storage, Shape, Strides } from './tensor_data.js';
|
|
2
|
+
/**
 * Terminates the threads of the shared worker pool, if one exists (e.g. in
 * test teardown so Node can exit). The next parallel kernel call lazily builds
 * a new pool.
 */
export declare function destroyPool(): void;
|
|
4
|
+
/**
 * Builds an element-wise map kernel that writes `fn(in)` into `out`,
 * broadcasting the input to the output shape. Large calls (above an internal
 * element-count threshold, with SharedArrayBuffer-backed storage) are split
 * across worker threads; identical shapes and strides take a stride-aligned
 * fast path with no index arithmetic.
 *
 * `fn` must be a pure function (no captured variables) -- workers reconstruct
 * it from `fn.toString()` via `new Function()`.
 */
export declare function fastTensorMap(fn: (x: number) => number): (
    outStorage: Storage,
    outShape: Shape,
    outStrides: Strides,
    inStorage: Storage,
    inShape: Shape,
    inStrides: Strides
) => void;
|
|
12
|
+
/**
 * Builds an element-wise binary kernel that writes `fn(a, b)` into `out`,
 * broadcasting both inputs to the output shape. Parallelism and the
 * pure-function requirement are the same as for `fastTensorMap`; the
 * stride-aligned fast path applies only when all three tensors share shape and
 * strides.
 */
export declare function fastTensorZip(fn: (a: number, b: number) => number): (
    outStorage: Storage,
    outShape: Shape,
    outStrides: Strides,
    aStorage: Storage,
    aShape: Shape,
    aStrides: Strides,
    bStorage: Storage,
    bShape: Shape,
    bStrides: Strides
) => void;
|
|
17
|
+
/**
 * Builds a kernel that folds dimension `reduceDim` of `a` with `fn` into `out`.
 * Output elements may be computed in parallel, but each individual reduction
 * runs sequentially by stepping along the reduced stride. `fn` has the same
 * pure-function requirement as for `fastTensorMap`.
 */
export declare function fastTensorReduce(fn: (acc: number, x: number) => number): (
    outStorage: Storage,
    outShape: Shape,
    outStrides: Strides,
    aStorage: Storage,
    aShape: Shape,
    aStrides: Strides,
    reduceDim: number
) => void;
|
|
23
|
+
//# sourceMappingURL=fast_ops.d.ts.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"fast_ops.d.ts","sourceRoot":"","sources":["../src/fast_ops.ts"],"names":[],"mappings":"AAKA,OAAO,KAAK,EACR,OAAO,EACP,KAAK,EACL,OAAO,EACV,MAAM,kBAAkB,CAAC;AAgH1B,0EAA0E;AAC1E,wBAAgB,WAAW,IAAI,IAAI,CAKlC;AAMD;;;;;;GAMG;AACH,wBAAgB,aAAa,CACzB,EAAE,EAAE,CAAC,CAAC,EAAE,MAAM,KAAK,MAAM,GAC1B,CACC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,SAAS,EAAE,OAAO,EAClB,OAAO,EAAE,KAAK,EACd,SAAS,EAAE,OAAO,KACjB,IAAI,CAyDR;AAED;;;GAGG;AACH,wBAAgB,aAAa,CACzB,EAAE,EAAE,CAAC,CAAC,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,KAAK,MAAM,GACrC,CACC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,EACjB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,KAChB,IAAI,CAsER;AAED;;;;GAIG;AACH,wBAAgB,gBAAgB,CAC5B,EAAE,EAAE,CAAC,GAAG,EAAE,MAAM,EAAE,CAAC,EAAE,MAAM,KAAK,MAAM,GACvC,CACC,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,KAAK,EACf,UAAU,EAAE,OAAO,EACnB,QAAQ,EAAE,OAAO,EACjB,MAAM,EAAE,KAAK,EACb,QAAQ,EAAE,OAAO,EACjB,SAAS,EAAE,MAAM,KAChB,IAAI,CA+DR"}
|
package/dist/fast_ops.js
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import { Worker } from 'node:worker_threads';
|
|
2
|
+
import { cpus } from 'node:os';
|
|
3
|
+
import { fileURLToPath } from 'node:url';
|
|
4
|
+
import { dirname, join } from 'node:path';
|
|
5
|
+
import { existsSync } from 'node:fs';
|
|
6
|
+
import { indexToPosition, toIndex, shapeProduct, broadcastIndex, } from './tensor_data.js';
|
|
7
|
+
// One worker thread per logical CPU reported by os.cpus(), never fewer than one.
const NUM_WORKERS = Math.max(cpus().length, 1);
// Minimum number of output elements (4096) before a kernel is sent to the pool.
// Work is partitioned by independent output element, never inside a single
// reduction, so the parallel and sequential paths give bitwise-identical
// results for the same fn.
const PARALLEL_THRESHOLD = 1 << 12;
|
|
11
|
+
/** True when `a` and `b` have the same length and strictly equal entries. */
function shapesEqual(a, b) {
    const n = a.length;
    if (n !== b.length)
        return false;
    let i = 0;
    while (i < n && a[i] === b[i])
        i++;
    return i === n;
}
|
|
20
|
+
/** True when stride lists `a` and `b` match element for element. */
function stridesEqual(a, b) {
    const n = a.length;
    if (n !== b.length)
        return false;
    let i = 0;
    while (i < n && a[i] === b[i])
        i++;
    return i === n;
}
|
|
29
|
+
/**
 * Fixed-size pool of `worker_threads` workers that execute chunks of a tensor
 * kernel. Completion is tracked in a shared Int32Array with one slot per
 * worker: the main thread clears each slot to 0 before dispatching, then
 * blocks (Atomics.wait) until every slot becomes non-zero.
 */
class WorkerPool {
    workers;
    syncBuffer;
    syncArray;
    /**
     * @param numWorkers number of worker threads to spawn (one sync slot each).
     */
    constructor(numWorkers) {
        this.syncBuffer = new SharedArrayBuffer(numWorkers * Int32Array.BYTES_PER_ELEMENT);
        this.syncArray = new Int32Array(this.syncBuffer);
        // .ts in dev (Node 25+ type-stripping), .js after tsc build
        const currentDir = dirname(fileURLToPath(import.meta.url));
        const tsPath = join(currentDir, 'fast_ops_worker.ts');
        const jsPath = join(currentDir, 'fast_ops_worker.js');
        const workerPath = existsSync(tsPath) ? tsPath : jsPath;
        this.workers = Array.from({ length: numWorkers }, (_, id) => {
            const w = new Worker(workerPath, {
                workerData: { workerId: id, syncBuffer: this.syncBuffer },
            });
            w.on('error', (err) => {
                console.error(`[fast_ops] Worker ${id} error:`, err);
            });
            // Idle workers would otherwise keep the event loop alive, so any
            // program that once hit the parallel path could not exit without
            // calling destroyPool(). parallelFor() blocks synchronously until
            // all dispatched chunks report completion, so unref() never lets
            // the process exit with work in flight.
            w.unref();
            return w;
        });
    }
    /**
     * Split [0, size) into contiguous chunks (one per worker) and block until
     * every worker that received a chunk has signalled completion.
     *
     * @param size number of work items.
     * @param taskFactory builds the message posted to the worker owning [start, end).
     */
    parallelFor(size, taskFactory) {
        const numWorkers = this.workers.length;
        const chunkSize = Math.ceil(size / numWorkers);
        // Clear every completion slot before any task is posted.
        for (let i = 0; i < numWorkers; i++) {
            Atomics.store(this.syncArray, i, 0);
        }
        for (let i = 0; i < numWorkers; i++) {
            const start = i * chunkSize;
            const end = Math.min(start + chunkSize, size);
            if (start >= size) {
                // Nothing for this worker: pre-mark it done so the wait below skips it.
                Atomics.store(this.syncArray, i, 1);
                continue;
            }
            this.workers[i].postMessage(taskFactory(start, end));
        }
        // Atomics.wait returns when woken (presumably by the worker's
        // Atomics.notify), after the 30 s timeout, or at once if the slot is
        // no longer 0; loop until each slot is set.
        // NOTE(review): a worker that dies without signalling hangs this loop.
        for (let i = 0; i < numWorkers; i++) {
            while (Atomics.load(this.syncArray, i) === 0) {
                Atomics.wait(this.syncArray, i, 0, 30000);
            }
        }
    }
    /** Terminate every worker thread in the pool. */
    terminate() {
        for (const w of this.workers) {
            w.terminate();
        }
    }
}
|
|
79
|
+
// Lazily-created shared pool: undefined = not built yet, null = unavailable.
let _pool;
// A SharedArrayBuffer can't cross Jest's VM context boundary to real worker
// threads, so parallel execution is off under Jest. Setting
// TSTORCH_DISABLE_PARALLEL (to any value) also forces the sequential path.
const PARALLEL_DISABLED = (() => {
    if (typeof process === 'undefined')
        return false;
    const env = process.env;
    return env['JEST_WORKER_ID'] !== undefined || env['TSTORCH_DISABLE_PARALLEL'] !== undefined;
})();
|
|
84
|
+
/**
 * Returns the shared worker pool, creating it on first use. Returns null when
 * parallelism is disabled, or when construction failed; a failure is
 * remembered (`_pool = null`), so construction is not retried.
 */
function getPool() {
    if (PARALLEL_DISABLED)
        return null;
    if (_pool !== undefined)
        return _pool;
    try {
        _pool = new WorkerPool(NUM_WORKERS);
    }
    catch {
        // Best effort: callers fall back to the sequential kernels.
        _pool = null;
    }
    return _pool;
}
|
|
97
|
+
/**
 * Shuts down the shared worker pool, if one is running, and forgets it so the
 * next parallel call builds a fresh pool. Call in test teardown so Node can exit.
 */
export function destroyPool() {
    const pool = _pool;
    if (!pool)
        return;
    pool.terminate();
    _pool = undefined;
}
|
|
104
|
+
/** True when `storage` is a view over a SharedArrayBuffer (so workers can write to it). */
function isShared(storage) {
    const { buffer } = storage;
    return buffer instanceof SharedArrayBuffer;
}
|
|
107
|
+
/**
 * Parallel tensor map. Optimizations: main loop in parallel when
 * size >= PARALLEL_THRESHOLD; stride-aligned fast path avoids all indexing.
 *
 * `fn` must be a pure function (no captured variables) -- workers reconstruct
 * it from `fn.toString()` via `new Function()`. Functions without source text
 * (built-ins such as `Math.exp`, bound functions) stringify to
 * "... { [native code] }" and cannot be rebuilt that way, so they always run on
 * the sequential path.
 *
 * @param fn element-wise function.
 * @returns kernel writing `fn(in[broadcast(idx)])` to `out[idx]` for every
 *   output index.
 */
export function fastTensorMap(fn) {
    // Loop-invariant: stringify once per kernel rather than on every call.
    const fnSource = fn.toString();
    const fnIsPortable = !fnSource.includes('[native code]');
    return (outStorage, outShape, outStrides, inStorage, inShape, inStrides) => {
        const size = shapeProduct(outShape);
        // Same shape and strides: storage slot i of `in` pairs with slot i of `out`.
        const aligned = shapesEqual(outShape, inShape) && stridesEqual(outStrides, inStrides);
        const pool = getPool();
        if (pool &&
            fnIsPortable &&
            size >= PARALLEL_THRESHOLD &&
            isShared(outStorage) &&
            isShared(inStorage)) {
            pool.parallelFor(size, (start, end) => ({
                type: 'map',
                fnSource,
                start,
                end,
                outBuffer: outStorage.buffer,
                outShape: Array.from(outShape),
                outStrides: Array.from(outStrides),
                inBuffer: inStorage.buffer,
                inShape: Array.from(inShape),
                inStrides: Array.from(inStrides),
                aligned,
            }));
            return;
        }
        if (aligned) {
            for (let i = 0; i < size; i++) {
                outStorage[i] = fn(inStorage[i]);
            }
            return;
        }
        // General path: ordinal -> out index -> broadcast in index -> positions.
        const outIndex = new Array(outShape.length).fill(0);
        const inIndex = new Array(inShape.length).fill(0);
        for (let ordinal = 0; ordinal < size; ordinal++) {
            toIndex(ordinal, outShape, outIndex);
            broadcastIndex(outIndex, outShape, inShape, inIndex);
            const inPos = indexToPosition(inIndex, inStrides);
            const outPos = indexToPosition(outIndex, outStrides);
            outStorage[outPos] = fn(inStorage[inPos]);
        }
    };
}
|
|
156
|
+
/**
 * Parallel tensor zip. Same optimizations and pure-function constraint
 * as fastTensorMap; stride-aligned when all three tensors match.
 * Functions whose `toString()` yields "[native code]" (built-ins, bound
 * functions) cannot be rebuilt by the workers and always run sequentially.
 *
 * @param fn binary element-wise function.
 * @returns kernel writing `fn(a[...], b[...])` to every element of `out`, with
 *   `a` and `b` broadcast to the output shape.
 */
export function fastTensorZip(fn) {
    // Loop-invariant: stringify once per kernel rather than on every call.
    const fnSource = fn.toString();
    const fnIsPortable = !fnSource.includes('[native code]');
    return (outStorage, outShape, outStrides, aStorage, aShape, aStrides, bStorage, bShape, bStrides) => {
        const size = shapeProduct(outShape);
        const aligned = shapesEqual(outShape, aShape) &&
            shapesEqual(outShape, bShape) &&
            stridesEqual(outStrides, aStrides) &&
            stridesEqual(outStrides, bStrides);
        const pool = getPool();
        if (pool &&
            fnIsPortable &&
            size >= PARALLEL_THRESHOLD &&
            isShared(outStorage) &&
            isShared(aStorage) &&
            isShared(bStorage)) {
            pool.parallelFor(size, (start, end) => ({
                type: 'zip',
                fnSource,
                start,
                end,
                outBuffer: outStorage.buffer,
                outShape: Array.from(outShape),
                outStrides: Array.from(outStrides),
                aBuffer: aStorage.buffer,
                aShape: Array.from(aShape),
                aStrides: Array.from(aStrides),
                bBuffer: bStorage.buffer,
                bShape: Array.from(bShape),
                bStrides: Array.from(bStrides),
                aligned,
            }));
            return;
        }
        if (aligned) {
            for (let i = 0; i < size; i++) {
                outStorage[i] = fn(aStorage[i], bStorage[i]);
            }
            return;
        }
        // General path: broadcast the output index into both inputs.
        const outIndex = new Array(outShape.length).fill(0);
        const aIndex = new Array(aShape.length).fill(0);
        const bIndex = new Array(bShape.length).fill(0);
        for (let ordinal = 0; ordinal < size; ordinal++) {
            toIndex(ordinal, outShape, outIndex);
            broadcastIndex(outIndex, outShape, aShape, aIndex);
            broadcastIndex(outIndex, outShape, bShape, bIndex);
            const aPos = indexToPosition(aIndex, aStrides);
            const bPos = indexToPosition(bIndex, bStrides);
            const outPos = indexToPosition(outIndex, outStrides);
            outStorage[outPos] = fn(aStorage[aPos], bStorage[bPos]);
        }
    };
}
|
|
212
|
+
/**
 * Parallel tensor reduce. Outer loop (output elements) in parallel; inner
 * reduction is sequential per element using stride-stepping. Same
 * pure-function constraint as fastTensorMap. Functions whose `toString()`
 * yields "[native code]" (built-ins, bound functions) cannot be rebuilt by
 * the workers and always run sequentially.
 *
 * @param fn accumulator `(acc, x) => acc'`, folded left to right along the
 *   reduced dimension, seeded with that dimension's first element.
 * @returns kernel reducing dimension `reduceDim` of `a` into `out`.
 */
export function fastTensorReduce(fn) {
    // Loop-invariant: stringify once per kernel rather than on every call.
    const fnSource = fn.toString();
    const fnIsPortable = !fnSource.includes('[native code]');
    return (outStorage, outShape, outStrides, aStorage, aShape, aStrides, reduceDim) => {
        const outSize = shapeProduct(outShape);
        const reduceDimSize = aShape[reduceDim];
        const reduceStride = aStrides[reduceDim];
        const pool = getPool();
        if (pool &&
            fnIsPortable &&
            outSize >= PARALLEL_THRESHOLD &&
            isShared(outStorage) &&
            isShared(aStorage)) {
            pool.parallelFor(outSize, (start, end) => ({
                type: 'reduce',
                fnSource,
                start,
                end,
                outBuffer: outStorage.buffer,
                outShape: Array.from(outShape),
                outStrides: Array.from(outStrides),
                inBuffer: aStorage.buffer,
                inShape: Array.from(aShape),
                inStrides: Array.from(aStrides),
                reduceDim,
                reduceDimSize,
            }));
            return;
        }
        const outIndex = new Array(outShape.length).fill(0);
        const aIndex = new Array(aShape.length).fill(0);
        for (let ordinal = 0; ordinal < outSize; ordinal++) {
            toIndex(ordinal, outShape, outIndex);
            const outPos = indexToPosition(outIndex, outStrides);
            // Start at the output coordinate with the reduced axis pinned to 0.
            // NOTE(review): assumes out keeps the reduced dim (same rank as a).
            for (let i = 0; i < outShape.length; i++) {
                aIndex[i] = outIndex[i];
            }
            aIndex[reduceDim] = 0;
            let aPos = indexToPosition(aIndex, aStrides);
            // Seed with the first element, then step along the reduced stride.
            // NOTE(review): with aShape[reduceDim] === 0 this seed read is out of range.
            let acc = aStorage[aPos];
            for (let j = 1; j < reduceDimSize; j++) {
                aPos += reduceStride;
                acc = fn(acc, aStorage[aPos]);
            }
            outStorage[outPos] = acc;
        }
    };
}
|
|
263
|
+
//# sourceMappingURL=fast_ops.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"fast_ops.js","sourceRoot":"","sources":["../src/fast_ops.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,MAAM,EAAE,MAAM,qBAAqB,CAAC;AAC7C,OAAO,EAAE,IAAI,EAAE,MAAM,SAAS,CAAC;AAC/B,OAAO,EAAE,aAAa,EAAE,MAAM,UAAU,CAAC;AACzC,OAAO,EAAE,OAAO,EAAE,IAAI,EAAE,MAAM,WAAW,CAAC;AAC1C,OAAO,EAAE,UAAU,EAAE,MAAM,SAAS,CAAC;AAOrC,OAAO,EACH,eAAe,EACf,OAAO,EACP,YAAY,EACZ,cAAc,GACjB,MAAM,kBAAkB,CAAC;AAE1B,MAAM,WAAW,GAAG,IAAI,CAAC,GAAG,CAAC,CAAC,EAAE,IAAI,EAAE,CAAC,MAAM,CAAC,CAAC;AAE/C,yEAAyE;AACzE,8EAA8E;AAC9E,MAAM,kBAAkB,GAAG,IAAI,CAAC;AAEhC,SAAS,WAAW,CAAC,CAAQ,EAAE,CAAQ;IACnC,IAAI,CAAC,CAAC,MAAM,KAAK,CAAC,CAAC,MAAM;QAAE,OAAO,KAAK,CAAC;IACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAChC,IAAI,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAAE,OAAO,KAAK,CAAC;IACpC,CAAC;IACD,OAAO,IAAI,CAAC;AAChB,CAAC;AAED,SAAS,YAAY,CAAC,CAAU,EAAE,CAAU;IACxC,IAAI,CAAC,CAAC,MAAM,KAAK,CAAC,CAAC,MAAM;QAAE,OAAO,KAAK,CAAC;IACxC,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,CAAC,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAChC,IAAI,CAAC,CAAC,CAAC,CAAC,KAAK,CAAC,CAAC,CAAC,CAAC;YAAE,OAAO,KAAK,CAAC;IACpC,CAAC;IACD,OAAO,IAAI,CAAC;AAChB,CAAC;AAED,MAAM,UAAU;IACJ,OAAO,CAAW;IAClB,UAAU,CAAoB;IAC9B,SAAS,CAAa;IAE9B,YAAY,UAAkB;QAC1B,IAAI,CAAC,UAAU,GAAG,IAAI,iBAAiB,CAAC,UAAU,GAAG,UAAU,CAAC,iBAAiB,CAAC,CAAC;QACnF,IAAI,CAAC,SAAS,GAAG,IAAI,UAAU,CAAC,IAAI,CAAC,UAAU,CAAC,CAAC;QAEjD,4DAA4D;QAC5D,MAAM,UAAU,GAAG,OAAO,CAAC,aAAa,CAAC,MAAM,CAAC,IAAI,CAAC,GAAG,CAAC,CAAC,CAAC;QAC3D,MAAM,MAAM,GAAG,IAAI,CAAC,UAAU,EAAE,oBAAoB,CAAC,CAAC;QACtD,MAAM,MAAM,GAAG,IAAI,CAAC,UAAU,EAAE,oBAAoB,CAAC,CAAC;QACtD,MAAM,UAAU,GAAG,UAAU,CAAC,MAAM,CAAC,CAAC,CAAC,CAAC,MAAM,CAAC,CAAC,CAAC,MAAM,CAAC;QAExD,IAAI,CAAC,OAAO,GAAG,KAAK,CAAC,IAAI,CAAC,EAAE,MAAM,EAAE,UAAU,EAAE,EAAE,CAAC,CAAC,EAAE,EAAE,EAAE,EAAE;YACxD,MAAM,CAAC,GAAG,IAAI,MAAM,CAAC,UAAU,EAAE;gBAC7B,UAAU,EAAE,EAAE,QAAQ,EAAE,EAAE,EAAE,UAAU,EAAE,IAAI,CAAC,UAAU,EAAE;aAC5D,CAAC,CAAC;YACH,CAAC,CAAC,EAAE,CAAC,OAAO,EAAE,CAAC,GAAG,EAAE,EAAE;gBAClB,OAAO,CAAC,KAAK,CAAC,qBAAqB,EA
AE,SAAS,EAAE,GAAG,CAAC,CAAC;YACzD,CAAC,CAAC,CAAC;YACH,OAAO,CAAC,CAAC;QACb,CAAC,CAAC,CAAC;IACP,CAAC;IAED,0EAA0E;IAC1E,WAAW,CACP,IAAY,EACZ,WAAmD;QAEnD,MAAM,UAAU,GAAG,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC;QACvC,MAAM,SAAS,GAAG,IAAI,CAAC,IAAI,CAAC,IAAI,GAAG,UAAU,CAAC,CAAC;QAE/C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;QACxC,CAAC;QAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,MAAM,KAAK,GAAG,CAAC,GAAG,SAAS,CAAC;YAC5B,MAAM,GAAG,GAAG,IAAI,CAAC,GAAG,CAAC,KAAK,GAAG,SAAS,EAAE,IAAI,CAAC,CAAC;YAC9C,IAAI,KAAK,IAAI,IAAI,EAAE,CAAC;gBAChB,OAAO,CAAC,KAAK,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC,EAAE,CAAC,CAAC,CAAC;gBACpC,SAAS;YACb,CAAC;YACD,IAAI,CAAC,OAAO,CAAC,CAAC,CAAE,CAAC,WAAW,CAAC,WAAW,CAAC,KAAK,EAAE,GAAG,CAAC,CAAC,CAAC;QAC1D,CAAC;QAED,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,UAAU,EAAE,CAAC,EAAE,EAAE,CAAC;YAClC,OAAO,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC;gBAC3C,OAAO,CAAC,IAAI,CAAC,IAAI,CAAC,SAAS,EAAE,CAAC,EAAE,CAAC,EAAE,KAAK,CAAC,CAAC;YAC9C,CAAC;QACL,CAAC;IACL,CAAC;IAED,SAAS;QACL,KAAK,MAAM,CAAC,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;YAC3B,CAAC,CAAC,SAAS,EAAE,CAAC;QAClB,CAAC;IACL,CAAC;CACJ;AAED,IAAI,KAAK,GAAkC,SAAS,CAAC;AAErD,4EAA4E;AAC5E,MAAM,iBAAiB,GAAG,OAAO,OAAO,KAAK,WAAW;IACpD,CAAC,OAAO,CAAC,GAAG,CAAC,gBAAgB,CAAC,KAAK,SAAS;QAC3C,OAAO,CAAC,GAAG,CAAC,0BAA0B,CAAC,KAAK,SAAS,CAAC,CAAC;AAE5D,SAAS,OAAO;IACZ,IAAI,iBAAiB;QAAE,OAAO,IAAI,CAAC;IACnC,IAAI,KAAK,KAAK,SAAS,EAAE,CAAC;QACtB,IAAI,CAAC;YACD,KAAK,GAAG,IAAI,UAAU,CAAC,WAAW,CAAC,CAAC;QACxC,CAAC;QAAC,MAAM,CAAC;YACL,KAAK,GAAG,IAAI,CAAC;QACjB,CAAC;IACL,CAAC;IACD,OAAO,KAAK,CAAC;AACjB,CAAC;AAED,0EAA0E;AAC1E,MAAM,UAAU,WAAW;IACvB,IAAI,KAAK,EAAE,CAAC;QACR,KAAK,CAAC,SAAS,EAAE,CAAC;QAClB,KAAK,GAAG,SAAS,CAAC;IACtB,CAAC;AACL,CAAC;AAED,SAAS,QAAQ,CAAC,OAAgB;IAC9B,OAAO,OAAO,CAAC,MAAM,YAAY,iBAAiB,CAAC;AACvD,CAAC;AAED;;;;;;GAMG;AACH,MAAM,UAAU,aAAa,CACzB,EAAyB;IASzB,OAAO,CACH,UAAmB,EACnB,QAAe,EACf,UAAmB,EACnB,SAAkB,EAC
lB,OAAc,EACd,SAAkB,EACd,EAAE;QACN,MAAM,IAAI,GAAG,YAAY,CAAC,QAAQ,CAAC,CAAC;QACpC,MAAM,OAAO,GACT,WAAW,CAAC,QAAQ,EAAE,OAAO,CAAC,IAAI,YAAY,CAAC,UAAU,EAAE,SAAS,CAAC,CAAC;QAE1E,MAAM,IAAI,GAAG,OAAO,EAAE,CAAC;QACvB,IACI,IAAI;YACJ,IAAI,IAAI,kBAAkB;YAC1B,QAAQ,CAAC,UAAU,CAAC;YACpB,QAAQ,CAAC,SAAS,CAAC,EACrB,CAAC;YACC,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC;YAE/B,IAAI,CAAC,WAAW,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC;gBACpC,IAAI,EAAE,KAAK;gBACX,QAAQ;gBACR,KAAK;gBACL,GAAG;gBACH,SAAS,EAAE,UAAU,CAAC,MAA2B;gBACjD,QAAQ,EAAE,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC;gBAC9B,UAAU,EAAE,KAAK,CAAC,IAAI,CAAC,UAAU,CAAC;gBAClC,QAAQ,EAAE,SAAS,CAAC,MAA2B;gBAC/C,OAAO,EAAE,KAAK,CAAC,IAAI,CAAC,OAAO,CAAC;gBAC5B,SAAS,EAAE,KAAK,CAAC,IAAI,CAAC,SAAS,CAAC;gBAChC,OAAO;aACV,CAAC,CAAC,CAAC;YACJ,OAAO;QACX,CAAC;QAED,IAAI,OAAO,EAAE,CAAC;YACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC5B,UAAU,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,CAAC,CAAE,CAAC,CAAC;YACtC,CAAC;YACD,OAAO;QACX,CAAC;QAED,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAC9D,MAAM,OAAO,GAAa,IAAI,KAAK,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE5D,KAAK,IAAI,OAAO,GAAG,CAAC,EAAE,OAAO,GAAG,IAAI,EAAE,OAAO,EAAE,EAAE,CAAC;YAC9C,OAAO,CAAC,OAAO,EAAE,QAAQ,EAAE,QAAQ,CAAC,CAAC;YACrC,cAAc,CAAC,QAAQ,EAAE,QAAQ,EAAE,OAAO,EAAE,OAAO,CAAC,CAAC;YAErD,MAAM,KAAK,GAAG,eAAe,CAAC,OAAO,EAAE,SAAS,CAAC,CAAC;YAClD,MAAM,MAAM,GAAG,eAAe,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;YACrD,UAAU,CAAC,MAAM,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,KAAK,CAAE,CAAC,CAAC;QAC/C,CAAC;IACL,CAAC,CAAC;AACN,CAAC;AAED;;;GAGG;AACH,MAAM,UAAU,aAAa,CACzB,EAAoC;IAYpC,OAAO,CACH,UAAmB,EACnB,QAAe,EACf,UAAmB,EACnB,QAAiB,EACjB,MAAa,EACb,QAAiB,EACjB,QAAiB,EACjB,MAAa,EACb,QAAiB,EACb,EAAE;QACN,MAAM,IAAI,GAAG,YAAY,CAAC,QAAQ,CAAC,CAAC;QACpC,MAAM,OAAO,GACT,WAAW,CAAC,QAAQ,EAAE,MAAM,CAAC;YAC7B,WAAW,CAAC,QAAQ,EAAE,MAAM,CAAC;YAC7B,YAAY,CAAC,UAAU,EAAE,QAAQ,CAAC;YAClC,YAAY,CAAC,UAAU,EAAE,QAAQ,CAAC,CAAC;QAEvC,MAAM,IAAI,GAAG,OAAO,EAAE,CAAC;QACvB,IACI,IAAI;YACJ,IAAI
,IAAI,kBAAkB;YAC1B,QAAQ,CAAC,UAAU,CAAC;YACpB,QAAQ,CAAC,QAAQ,CAAC;YAClB,QAAQ,CAAC,QAAQ,CAAC,EACpB,CAAC;YACC,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC;YAE/B,IAAI,CAAC,WAAW,CAAC,IAAI,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC;gBACpC,IAAI,EAAE,KAAK;gBACX,QAAQ;gBACR,KAAK;gBACL,GAAG;gBACH,SAAS,EAAE,UAAU,CAAC,MAA2B;gBACjD,QAAQ,EAAE,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC;gBAC9B,UAAU,EAAE,KAAK,CAAC,IAAI,CAAC,UAAU,CAAC;gBAClC,OAAO,EAAE,QAAQ,CAAC,MAA2B;gBAC7C,MAAM,EAAE,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC;gBAC1B,QAAQ,EAAE,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC;gBAC9B,OAAO,EAAE,QAAQ,CAAC,MAA2B;gBAC7C,MAAM,EAAE,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC;gBAC1B,QAAQ,EAAE,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC;gBAC9B,OAAO;aACV,CAAC,CAAC,CAAC;YACJ,OAAO;QACX,CAAC;QAED,IAAI,OAAO,EAAE,CAAC;YACV,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,EAAE,CAAC,EAAE,EAAE,CAAC;gBAC5B,UAAU,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAE,EAAE,QAAQ,CAAC,CAAC,CAAE,CAAC,CAAC;YACnD,CAAC;YACD,OAAO;QACX,CAAC;QAED,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAC9D,MAAM,MAAM,GAAa,IAAI,KAAK,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAC1D,MAAM,MAAM,GAAa,IAAI,KAAK,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1D,KAAK,IAAI,OAAO,GAAG,CAAC,EAAE,OAAO,GAAG,IAAI,EAAE,OAAO,EAAE,EAAE,CAAC;YAC9C,OAAO,CAAC,OAAO,EAAE,QAAQ,EAAE,QAAQ,CAAC,CAAC;YACrC,cAAc,CAAC,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC;YACnD,cAAc,CAAC,QAAQ,EAAE,QAAQ,EAAE,MAAM,EAAE,MAAM,CAAC,CAAC;YAEnD,MAAM,IAAI,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,CAAC,CAAC;YAC/C,MAAM,IAAI,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,CAAC,CAAC;YAC/C,MAAM,MAAM,GAAG,eAAe,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;YACrD,UAAU,CAAC,MAAM,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,IAAI,CAAE,EAAE,QAAQ,CAAC,IAAI,CAAE,CAAC,CAAC;QAC9D,CAAC;IACL,CAAC,CAAC;AACN,CAAC;AAED;;;;GAIG;AACH,MAAM,UAAU,gBAAgB,CAC5B,EAAsC;IAUtC,OAAO,CACH,UAAmB,EACnB,QAAe,EACf,UAAmB,EACnB,QAAiB,EACjB,MAAa,EACb,QAAiB,EACjB,SAAiB,EACb,EAAE;QACN,MAAM,OAAO,GAAG,YAAY,CAAC,QAAQ,CAAC,CAAC;QACvC,MAAM,aAAa,GAAG,MAAM,CAAC,SAAS,CAAE,CAAC;QACzC,MAAM,YAAY,GA
AG,QAAQ,CAAC,SAAS,CAAE,CAAC;QAE1C,MAAM,IAAI,GAAG,OAAO,EAAE,CAAC;QACvB,IACI,IAAI;YACJ,OAAO,IAAI,kBAAkB;YAC7B,QAAQ,CAAC,UAAU,CAAC;YACpB,QAAQ,CAAC,QAAQ,CAAC,EACpB,CAAC;YACC,MAAM,QAAQ,GAAG,EAAE,CAAC,QAAQ,EAAE,CAAC;YAE/B,IAAI,CAAC,WAAW,CAAC,OAAO,EAAE,CAAC,KAAK,EAAE,GAAG,EAAE,EAAE,CAAC,CAAC;gBACvC,IAAI,EAAE,QAAQ;gBACd,QAAQ;gBACR,KAAK;gBACL,GAAG;gBACH,SAAS,EAAE,UAAU,CAAC,MAA2B;gBACjD,QAAQ,EAAE,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC;gBAC9B,UAAU,EAAE,KAAK,CAAC,IAAI,CAAC,UAAU,CAAC;gBAClC,QAAQ,EAAE,QAAQ,CAAC,MAA2B;gBAC9C,OAAO,EAAE,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC;gBAC3B,SAAS,EAAE,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC;gBAC/B,SAAS;gBACT,aAAa;aAChB,CAAC,CAAC,CAAC;YACJ,OAAO;QACX,CAAC;QAED,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAC9D,MAAM,MAAM,GAAa,IAAI,KAAK,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE1D,KAAK,IAAI,OAAO,GAAG,CAAC,EAAE,OAAO,GAAG,OAAO,EAAE,OAAO,EAAE,EAAE,CAAC;YACjD,OAAO,CAAC,OAAO,EAAE,QAAQ,EAAE,QAAQ,CAAC,CAAC;YAErC,MAAM,MAAM,GAAG,eAAe,CAAC,QAAQ,EAAE,UAAU,CAAC,CAAC;YAErD,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,QAAQ,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;gBACvC,MAAM,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAE,CAAC;YAC7B,CAAC;YACD,MAAM,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC;YACtB,IAAI,IAAI,GAAG,eAAe,CAAC,MAAM,EAAE,QAAQ,CAAC,CAAC;YAE7C,IAAI,GAAG,GAAG,QAAQ,CAAC,IAAI,CAAE,CAAC;YAC1B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,aAAa,EAAE,CAAC,EAAE,EAAE,CAAC;gBACrC,IAAI,IAAI,YAAY,CAAC;gBACrB,GAAG,GAAG,EAAE,CAAC,GAAG,EAAE,QAAQ,CAAC,IAAI,CAAE,CAAC,CAAC;YACnC,CAAC;YAED,UAAU,CAAC,MAAM,CAAC,GAAG,GAAG,CAAC;QAC7B,CAAC;IACL,CAAC,CAAC;AACN,CAAC"}
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"fast_ops_worker.d.ts","sourceRoot":"","sources":["../src/fast_ops_worker.ts"],"names":[],"mappings":""}
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
import { parentPort, workerData } from 'node:worker_threads';
|
|
2
|
+
// Inlined from tensor_data.ts -- workers are separate V8 isolates and
|
|
3
|
+
// can't import .js-extensioned modules from .ts source.
|
|
4
|
+
/**
 * Decode a flat ordinal into per-dimension coordinates for `shape`, writing
 * them into `outIndex`. The last dimension varies fastest (row-major order).
 */
function toIndex(ordinal, shape, outIndex) {
    let rest = ordinal;
    let dim = shape.length;
    // Walk dimensions from innermost to outermost, peeling off one coordinate each step.
    while (dim-- > 0) {
        const extent = shape[dim];
        outIndex[dim] = rest % extent;
        rest = Math.floor(rest / extent);
    }
}
|
|
12
|
+
/**
 * Storage offset of the element at `idx`: the dot product of the index with
 * `strides`, accumulated from the first dimension to the last.
 */
function indexToPosition(idx, strides) {
    const rank = idx.length;
    let offset = 0;
    let d = 0;
    while (d < rank) {
        offset += idx[d] * strides[d];
        d++;
    }
    return offset;
}
|
|
19
|
+
/**
 * Project an index into the broadcast shape `bigShape` onto an index into
 * `shape`, writing the result into `outIndex`. Dimensions are aligned from
 * the right; any dimension where `shape` has extent 1 is pinned to 0.
 */
function broadcastIndex(bigIndex, bigShape, shape, outIndex) {
    const rank = shape.length;
    const lead = bigShape.length - rank; // extra leading dims present only in bigShape
    for (let d = 0; d < rank; d++) {
        outIndex[d] = shape[d] === 1 ? 0 : bigIndex[d + lead];
    }
}
|
|
31
|
+
// This worker's slot number and the shared buffer of per-worker completion
// flags, both passed in through `workerData` when the worker was spawned.
const workerId = workerData.workerId;
const syncBuffer = workerData.syncBuffer;
const syncArray = new Int32Array(syncBuffer);
|
|
33
|
+
// fn must be pure (no closures) -- reconstructed from source via new Function().
// Because the function is rebuilt from its text, it cannot see any variables
// the original captured; only self-contained function expressions work.
//
// Compiled functions are cached by source text. Every task carries its op's
// source (task.fnSource), so without a cache each dispatch would re-parse the
// source and hand the hot loops a brand-new, not-yet-optimized function object.
// Caching is safe precisely because the functions are required to be pure.
const fnCache = new Map();
// Bound the cache so dynamically generated sources cannot grow it without limit.
const FN_CACHE_LIMIT = 256;
/**
 * Return a callable compiled from `source` (the text of a function expression,
 * e.g. as produced by `fn.toString()`), reusing a previously compiled instance
 * when the same source is seen again.
 *
 * @param source Source text of a closure-free function expression.
 * @returns The compiled function.
 * @throws Error if the source does not evaluate to a function (for example the
 *   text of a native function, `function exp() { [native code] }`, or a
 *   method-shorthand definition).
 */
function reconstructFn(source) {
    const cached = fnCache.get(source);
    if (cached !== undefined) {
        return cached;
    }
    let fn;
    try {
        fn = new Function('return ' + source)();
    }
    catch (err) {
        const reason = err instanceof Error ? err.message : String(err);
        throw new Error('fast_ops worker: cannot reconstruct function from source (' + reason + '): ' + source);
    }
    if (typeof fn !== 'function') {
        throw new Error('fast_ops worker: source did not evaluate to a function: ' + source);
    }
    if (fnCache.size >= FN_CACHE_LIMIT) {
        fnCache.clear();
    }
    fnCache.set(source, fn);
    return fn;
}
|
|
37
|
+
/**
 * Apply the task's unary function to every output ordinal in [start, end),
 * writing into the shared output buffer. Aligned tasks are processed by raw
 * storage position; otherwise each output index is broadcast back onto the
 * input shape to locate its source element.
 */
function handleMap(task) {
    const fn = reconstructFn(task.fnSource);
    const out = new Float64Array(task.outBuffer);
    const src = new Float64Array(task.inBuffer);
    const { start, end } = task;
    if (task.aligned) {
        for (let pos = start; pos < end; pos++) out[pos] = fn(src[pos]);
        return;
    }
    const { outShape, outStrides, inShape, inStrides } = task;
    const outIdx = new Array(outShape.length).fill(0);
    const inIdx = new Array(inShape.length).fill(0);
    for (let ordinal = start; ordinal < end; ordinal++) {
        toIndex(ordinal, outShape, outIdx);
        broadcastIndex(outIdx, outShape, inShape, inIdx);
        const value = src[indexToPosition(inIdx, inStrides)];
        out[indexToPosition(outIdx, outStrides)] = fn(value);
    }
}
|
|
57
|
+
/**
 * Combine the two shared input buffers `a` and `b` element-wise with the
 * task's binary function over output ordinals [start, end). Aligned tasks use
 * raw storage positions; otherwise both inputs are broadcast to the output
 * index before lookup.
 */
function handleZip(task) {
    const fn = reconstructFn(task.fnSource);
    const out = new Float64Array(task.outBuffer);
    const left = new Float64Array(task.aBuffer);
    const right = new Float64Array(task.bBuffer);
    const { start, end } = task;
    if (task.aligned) {
        for (let pos = start; pos < end; pos++) out[pos] = fn(left[pos], right[pos]);
        return;
    }
    const { outShape, outStrides, aShape, aStrides, bShape, bStrides } = task;
    const outIdx = new Array(outShape.length).fill(0);
    const aIdx = new Array(aShape.length).fill(0);
    const bIdx = new Array(bShape.length).fill(0);
    for (let ordinal = start; ordinal < end; ordinal++) {
        toIndex(ordinal, outShape, outIdx);
        broadcastIndex(outIdx, outShape, aShape, aIdx);
        broadcastIndex(outIdx, outShape, bShape, bIdx);
        const x = left[indexToPosition(aIdx, aStrides)];
        const y = right[indexToPosition(bIdx, bStrides)];
        out[indexToPosition(outIdx, outStrides)] = fn(x, y);
    }
}
|
|
80
|
+
/**
 * Fold the input along `reduceDim` with the task's binary function for every
 * output ordinal in [start, end). The fold is seeded with the first element
 * along that dimension, then walks the remaining `reduceDimSize - 1` elements
 * by stepping the storage offset by the dimension's stride.
 * Assumes the output keeps the reduced dimension (same rank as the input),
 * since output coordinates are copied straight into the input index.
 */
function handleReduce(task) {
    const fn = reconstructFn(task.fnSource);
    const out = new Float64Array(task.outBuffer);
    const src = new Float64Array(task.inBuffer);
    const { start, end, outShape, outStrides, inShape, inStrides, reduceDim, reduceDimSize } = task;
    const step = inStrides[reduceDim];
    const outIdx = new Array(outShape.length).fill(0);
    const inIdx = new Array(inShape.length).fill(0);
    for (let ordinal = start; ordinal < end; ordinal++) {
        toIndex(ordinal, outShape, outIdx);
        const outPos = indexToPosition(outIdx, outStrides);
        for (let d = 0; d < outShape.length; d++) inIdx[d] = outIdx[d];
        inIdx[reduceDim] = 0;
        let pos = indexToPosition(inIdx, inStrides);
        let acc = src[pos];
        for (let k = 1; k < reduceDimSize; k++) {
            pos += step;
            acc = fn(acc, src[pos]);
        }
        out[outPos] = acc;
    }
}
|
|
103
|
+
/**
 * Task entry point: run one map/zip/reduce slice, then publish completion by
 * writing 1 into this worker's slot of the shared sync array, waking any
 * waiter on that slot, and posting 'done'.
 *
 * The work runs inside try/catch/finally so an exception can never skip the
 * completion signal. Without it, a waiter blocked on syncArray[workerId] would
 * never be released. The uncaught exception would also terminate this worker,
 * stranding every later task sent to it. Failures are logged instead of
 * rethrown. The signalled value stays 1, exactly as before, so existing waiters
 * need no change. NOTE(review): a failed task may leave its output range
 * partially written, and the waiter currently cannot distinguish that from
 * success.
 */
parentPort.on('message', (task) => {
    try {
        switch (task.type) {
            case 'map':
                handleMap(task);
                break;
            case 'zip':
                handleZip(task);
                break;
            case 'reduce':
                handleReduce(task);
                break;
            default:
                // Previously unknown types were silently reported as done.
                throw new Error('unknown task type: ' + String(task.type));
        }
    }
    catch (err) {
        console.error('fast_ops worker ' + workerId + ': task failed', err);
    }
    finally {
        Atomics.store(syncArray, workerId, 1);
        Atomics.notify(syncArray, workerId);
        parentPort.postMessage('done');
    }
});
|
|
119
|
+
//# sourceMappingURL=fast_ops_worker.js.map
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
{"version":3,"file":"fast_ops_worker.js","sourceRoot":"","sources":["../src/fast_ops_worker.ts"],"names":[],"mappings":"AAAA,OAAO,EAAE,UAAU,EAAE,UAAU,EAAE,MAAM,qBAAqB,CAAC;AAE7D,sEAAsE;AACtE,wDAAwD;AAExD,SAAS,OAAO,CAAC,OAAe,EAAE,KAAe,EAAE,QAAkB;IACjE,IAAI,SAAS,GAAG,OAAO,CAAC;IACxB,KAAK,IAAI,CAAC,GAAG,KAAK,CAAC,MAAM,GAAG,CAAC,EAAE,CAAC,IAAI,CAAC,EAAE,CAAC,EAAE,EAAE,CAAC;QACzC,MAAM,OAAO,GAAG,KAAK,CAAC,CAAC,CAAE,CAAC;QAC1B,QAAQ,CAAC,CAAC,CAAC,GAAG,SAAS,GAAG,OAAO,CAAC;QAClC,SAAS,GAAG,IAAI,CAAC,KAAK,CAAC,SAAS,GAAG,OAAO,CAAC,CAAC;IAChD,CAAC;AACL,CAAC;AAED,SAAS,eAAe,CAAC,GAAa,EAAE,OAAiB;IACrD,IAAI,QAAQ,GAAG,CAAC,CAAC;IACjB,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,GAAG,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QAClC,QAAQ,IAAI,GAAG,CAAC,CAAC,CAAE,GAAG,OAAO,CAAC,CAAC,CAAE,CAAC;IACtC,CAAC;IACD,OAAO,QAAQ,CAAC;AACpB,CAAC;AAED,SAAS,cAAc,CACnB,QAAkB,EAClB,QAAkB,EAClB,KAAe,EACf,QAAkB;IAElB,MAAM,MAAM,GAAG,QAAQ,CAAC,MAAM,GAAG,KAAK,CAAC,MAAM,CAAC;IAC9C,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,KAAK,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;QACpC,MAAM,IAAI,GAAG,CAAC,GAAG,MAAM,CAAC;QACxB,IAAI,KAAK,CAAC,CAAC,CAAC,KAAK,CAAC,EAAE,CAAC;YACjB,QAAQ,CAAC,CAAC,CAAC,GAAG,CAAC,CAAC;QACpB,CAAC;aAAM,CAAC;YACJ,QAAQ,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC,IAAI,CAAE,CAAC;QAClC,CAAC;IACL,CAAC;AACL,CAAC;AAkDD,MAAM,EAAE,QAAQ,EAAE,UAAU,EAAE,GAAG,UAGhC,CAAC;AACF,MAAM,SAAS,GAAG,IAAI,UAAU,CAAC,UAAU,CAAC,CAAC;AAE7C,gFAAgF;AAChF,SAAS,aAAa,CAAI,MAAc;IACpC,OAAO,IAAI,QAAQ,CAAC,SAAS,GAAG,MAAM,CAAC,EAAO,CAAC;AACnD,CAAC;AAED,SAAS,SAAS,CAAC,IAAa;IAC5B,MAAM,EAAE,GAAG,aAAa,CAAwB,IAAI,CAAC,QAAQ,CAAC,CAAC;IAC/D,MAAM,UAAU,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IACpD,MAAM,SAAS,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IAElD,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC,KAAK,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,UAAU,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,SAAS,CAAC,CAAC,CAAE,CAAC,CAAC;QACtC,CAAC;IACL,CAAC;SAAM,CAAC;QACJ,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;
QACnE,MAAM,OAAO,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAEjE,KAAK,IAAI,OAAO,GAAG,IAAI,CAAC,KAAK,EAAE,OAAO,GAAG,IAAI,CAAC,GAAG,EAAE,OAAO,EAAE,EAAE,CAAC;YAC3D,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,CAAC,CAAC;YAC1C,cAAc,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,OAAO,EAAE,OAAO,CAAC,CAAC;YAC/D,UAAU,CAAC,eAAe,CAAC,QAAQ,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;gBAClD,EAAE,CAAC,SAAS,CAAC,eAAe,CAAC,OAAO,EAAE,IAAI,CAAC,SAAS,CAAC,CAAE,CAAC,CAAC;QACjE,CAAC;IACL,CAAC;AACL,CAAC;AAED,SAAS,SAAS,CAAC,IAAa;IAC5B,MAAM,EAAE,GAAG,aAAa,CAAmC,IAAI,CAAC,QAAQ,CAAC,CAAC;IAC1E,MAAM,UAAU,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,SAAS,CAAC,CAAC;IACpD,MAAM,QAAQ,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IAChD,MAAM,QAAQ,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,OAAO,CAAC,CAAC;IAEhD,IAAI,IAAI,CAAC,OAAO,EAAE,CAAC;QACf,KAAK,IAAI,CAAC,GAAG,IAAI,CAAC,KAAK,EAAE,CAAC,GAAG,IAAI,CAAC,GAAG,EAAE,CAAC,EAAE,EAAE,CAAC;YACzC,UAAU,CAAC,CAAC,CAAC,GAAG,EAAE,CAAC,QAAQ,CAAC,CAAC,CAAE,EAAE,QAAQ,CAAC,CAAC,CAAE,CAAC,CAAC;QACnD,CAAC;IACL,CAAC;SAAM,CAAC;QACJ,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QACnE,MAAM,MAAM,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAC/D,MAAM,MAAM,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,MAAM,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;QAE/D,KAAK,IAAI,OAAO,GAAG,IAAI,CAAC,KAAK,EAAE,OAAO,GAAG,IAAI,CAAC,GAAG,EAAE,OAAO,EAAE,EAAE,CAAC;YAC3D,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,CAAC,CAAC;YAC1C,cAAc,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC7D,cAAc,CAAC,QAAQ,EAAE,IAAI,CAAC,QAAQ,EAAE,IAAI,CAAC,MAAM,EAAE,MAAM,CAAC,CAAC;YAC7D,UAAU,CAAC,eAAe,CAAC,QAAQ,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;gBAClD,EAAE,CAAC,QAAQ,CAAC,eAAe,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAE,EACjD,QAAQ,CAAC,eAAe,CAAC,MAAM,EAAE,IAAI,CAAC,QAAQ,CAAC,CAAE,CAAC,CAAC;QAC9D,CAAC;IACL,CAAC;AACL,CAAC;AAED,SAAS,YAAY,CAAC,IAAgB;IAClC,MAAM,EAAE,GAAG,aAAa,CAAqC,IAAI,CAAC,QAAQ,CAAC,CAAC;IAC5E,MAAM,UAAU,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC
,SAAS,CAAC,CAAC;IACpD,MAAM,SAAS,GAAG,IAAI,YAAY,CAAC,IAAI,CAAC,QAAQ,CAAC,CAAC;IAClD,MAAM,YAAY,GAAG,IAAI,CAAC,SAAS,CAAC,IAAI,CAAC,SAAS,CAAE,CAAC;IAErD,MAAM,QAAQ,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,QAAQ,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IACnE,MAAM,OAAO,GAAa,IAAI,KAAK,CAAC,IAAI,CAAC,OAAO,CAAC,MAAM,CAAC,CAAC,IAAI,CAAC,CAAC,CAAC,CAAC;IAEjE,KAAK,IAAI,OAAO,GAAG,IAAI,CAAC,KAAK,EAAE,OAAO,GAAG,IAAI,CAAC,GAAG,EAAE,OAAO,EAAE,EAAE,CAAC;QAC3D,OAAO,CAAC,OAAO,EAAE,IAAI,CAAC,QAAQ,EAAE,QAAQ,CAAC,CAAC;QAC1C,MAAM,MAAM,GAAG,eAAe,CAAC,QAAQ,EAAE,IAAI,CAAC,UAAU,CAAC,CAAC;QAE1D,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,QAAQ,CAAC,MAAM,EAAE,CAAC,EAAE,EAAE,CAAC;YAC5C,OAAO,CAAC,CAAC,CAAC,GAAG,QAAQ,CAAC,CAAC,CAAE,CAAC;QAC9B,CAAC;QACD,OAAO,CAAC,IAAI,CAAC,SAAS,CAAC,GAAG,CAAC,CAAC;QAC5B,IAAI,KAAK,GAAG,eAAe,CAAC,OAAO,EAAE,IAAI,CAAC,SAAS,CAAC,CAAC;QAErD,IAAI,GAAG,GAAG,SAAS,CAAC,KAAK,CAAE,CAAC;QAC5B,KAAK,IAAI,CAAC,GAAG,CAAC,EAAE,CAAC,GAAG,IAAI,CAAC,aAAa,EAAE,CAAC,EAAE,EAAE,CAAC;YAC1C,KAAK,IAAI,YAAY,CAAC;YACtB,GAAG,GAAG,EAAE,CAAC,GAAG,EAAE,SAAS,CAAC,KAAK,CAAE,CAAC,CAAC;QACrC,CAAC;QAED,UAAU,CAAC,MAAM,CAAC,GAAG,GAAG,CAAC;IAC7B,CAAC;AACL,CAAC;AAED,UAAW,CAAC,EAAE,CAAC,SAAS,EAAE,CAAC,IAAU,EAAE,EAAE;IACrC,QAAQ,IAAI,CAAC,IAAI,EAAE,CAAC;QAChB,KAAK,KAAK;YAAK,SAAS,CAAC,IAAI,CAAC,CAAC;YAAI,MAAM;QACzC,KAAK,KAAK;YAAK,SAAS,CAAC,IAAI,CAAC,CAAC;YAAI,MAAM;QACzC,KAAK,QAAQ;YAAE,YAAY,CAAC,IAAI,CAAC,CAAC;YAAC,MAAM;IAC7C,CAAC;IAED,OAAO,CAAC,KAAK,CAAC,SAAS,EAAE,QAAQ,EAAE,CAAC,CAAC,CAAC;IACtC,OAAO,CAAC,MAAM,CAAC,SAAS,EAAE,QAAQ,CAAC,CAAC;IACpC,UAAW,CAAC,WAAW,CAAC,MAAM,CAAC,CAAC;AACpC,CAAC,CAAC,CAAC"}
|