@elarsaks/umap-wasm 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/{src/heap.d.ts → heap.d.ts} +1 -1
- package/dist/heap.js +184 -0
- package/dist/index.d.ts +2 -0
- package/dist/index.js +2 -0
- package/dist/lib.d.ts +2 -0
- package/dist/lib.js +2 -0
- package/dist/matrix.js +257 -0
- package/dist/{src/nn_descent.d.ts → nn_descent.d.ts} +4 -4
- package/dist/nn_descent.js +127 -0
- package/dist/{src/tree.d.ts → tree.d.ts} +2 -2
- package/dist/tree.js +228 -0
- package/dist/{src/umap.d.ts → umap.d.ts} +1 -1
- package/dist/umap.js +700 -0
- package/dist/{src/utils.d.ts → utils.d.ts} +1 -1
- package/dist/utils.js +98 -0
- package/dist/wasmBridge.js +188 -0
- package/lib/umap-js.js +6842 -7490
- package/lib/umap-js.min.js +1 -1
- package/package.json +64 -63
- package/dist/src/heap.js +0 -226
- package/dist/src/index.d.ts +0 -2
- package/dist/src/index.js +0 -8
- package/dist/src/lib.d.ts +0 -1
- package/dist/src/lib.js +0 -5
- package/dist/src/matrix.js +0 -360
- package/dist/src/nn_descent.js +0 -204
- package/dist/src/tree.js +0 -320
- package/dist/src/umap.js +0 -842
- package/dist/src/utils.js +0 -137
- package/dist/src/wasmBridge.js +0 -290
- package/dist/test/matrix.test.d.ts +0 -1
- package/dist/test/matrix.test.js +0 -169
- package/dist/test/nn_descent.test.d.ts +0 -1
- package/dist/test/nn_descent.test.js +0 -58
- package/dist/test/smoke.playwright.test.d.ts +0 -1
- package/dist/test/smoke.playwright.test.js +0 -98
- package/dist/test/test_data.d.ts +0 -13
- package/dist/test/test_data.js +0 -1054
- package/dist/test/tree.test.d.ts +0 -1
- package/dist/test/tree.test.js +0 -60
- package/dist/test/umap.test.d.ts +0 -1
- package/dist/test/umap.test.js +0 -293
- package/dist/test/utils.test.d.ts +0 -1
- package/dist/test/utils.test.js +0 -128
- package/dist/test/wasmDistance.test.d.ts +0 -1
- package/dist/test/wasmDistance.test.js +0 -124
- package/dist/test/wasmMatrix.test.d.ts +0 -1
- package/dist/test/wasmMatrix.test.js +0 -389
- package/dist/test/wasmTree.test.d.ts +0 -1
- package/dist/test/wasmTree.test.js +0 -212
- /package/dist/{src/matrix.d.ts → matrix.d.ts} +0 -0
- /package/dist/{src/wasmBridge.d.ts → wasmBridge.d.ts} +0 -0
package/dist/umap.js
ADDED
|
@@ -0,0 +1,700 @@
|
|
|
1
|
+
import * as heap from './heap.js';
|
|
2
|
+
import * as matrix from './matrix.js';
|
|
3
|
+
import * as nnDescent from './nn_descent.js';
|
|
4
|
+
import * as tree from './tree.js';
|
|
5
|
+
import * as utils from './utils.js';
|
|
6
|
+
import LM from 'ml-levenberg-marquardt';
|
|
7
|
+
import * as wasmBridge from './wasmBridge.js';
|
|
8
|
+
// Tolerance used by UMAP#smoothKNNDistance: the sigma search stops once the
// summed membership is within this distance of the target, and fractional
// local-connectivity interpolation is skipped below this threshold.
const SMOOTH_K_TOLERANCE = 1e-5;
|
|
9
|
+
// Lower bound factor for sigma in UMAP#smoothKNNDistance: each sigma is raised
// to at least this fraction of the relevant mean neighbor distance.
const MIN_K_DIST_SCALE = 1e-3;
|
|
10
|
+
/**
 * UMAP dimensionality reduction.
 *
 * Typical usage is `new UMAP(params).fit(X)` (or `fitAsync`), optionally
 * followed by `transform(newPoints)`. The stepwise API is
 * `initializeFit(X)` followed by repeated calls to `step()`.
 */
export class UMAP {
    /**
     * @param {object} [params] Optional overrides. Recognised keys are
     *   distanceFn, useWasmDistance, useWasmMatrix, useWasmTree, learningRate,
     *   localConnectivity, minDist, nComponents, nEpochs, nNeighbors,
     *   negativeSampleRate, random, repulsionStrength, setOpMixRatio, spread
     *   and transformQueueSize. Keys whose value is `undefined` are ignored.
     */
    constructor(params = {}) {
        this.learningRate = 1.0;
        this.localConnectivity = 1.0;
        this.minDist = 0.1;
        this.nComponents = 2;
        this.nEpochs = 0;
        this.nNeighbors = 15;
        this.negativeSampleRate = 5;
        this.random = Math.random;
        this.repulsionStrength = 1.0;
        this.setOpMixRatio = 1.0;
        this.spread = 1.0;
        this.transformQueueSize = 4.0;
        this.targetMetric = "categorical";
        this.targetWeight = 0.5;
        // NOTE(review): this is captured before `params.nNeighbors` is applied,
        // so it always starts at the default nNeighbors - confirm intended.
        this.targetNNeighbors = this.nNeighbors;
        this.distanceFn = euclidean;
        this.useWasmDistance = false;
        this.useWasmMatrix = false;
        this.useWasmTree = false;
        this.isInitialized = false;
        this.rpForest = [];
        this.embedding = [];
        this.optimizationState = new OptimizationState();
        const setParam = (key) => {
            if (params[key] !== undefined) {
                this[key] = params[key];
            }
        };
        setParam('distanceFn');
        setParam('useWasmDistance');
        setParam('useWasmMatrix');
        setParam('useWasmTree');
        setParam('learningRate');
        setParam('localConnectivity');
        setParam('minDist');
        setParam('nComponents');
        setParam('nEpochs');
        setParam('nNeighbors');
        setParam('negativeSampleRate');
        setParam('random');
        setParam('repulsionStrength');
        setParam('setOpMixRatio');
        setParam('spread');
        setParam('transformQueueSize');
    }

    /**
     * Fits the data synchronously and returns the embedding.
     * @param {number[][]} X Input vectors.
     * @returns {number[][]} The embedding, one row per input vector.
     */
    fit(X) {
        this.initializeFit(X);
        this.optimizeLayout();
        return this.embedding;
    }

    /**
     * Fits the data, yielding to the event loop between epochs.
     * @param {number[][]} X Input vectors.
     * @param {(epoch: number) => (boolean|void)} [callback] Called after each
     *   epoch; returning exactly `false` stops the optimization early.
     * @returns {Promise<number[][]>} The embedding.
     */
    async fitAsync(X, callback = () => true) {
        this.initializeFit(X);
        await this.optimizeLayoutAsync(callback);
        return this.embedding;
    }

    /**
     * Registers target labels for supervised projection.
     * @param {Array} Y Labels, one per data point.
     * @param {object} [params] Optional targetMetric, targetWeight and
     *   targetNNeighbors overrides.
     */
    setSupervisedProjection(Y, params = {}) {
        this.Y = Y;
        this.targetMetric = params.targetMetric || this.targetMetric;
        // An explicit targetWeight of 0 is a valid value and must not be
        // discarded, so only fall back when the option is absent.
        if (params.targetWeight !== undefined && params.targetWeight !== null) {
            this.targetWeight = params.targetWeight;
        }
        this.targetNNeighbors = params.targetNNeighbors || this.targetNNeighbors;
    }

    /**
     * Supplies precomputed nearest neighbors so `initializeFit` skips the
     * nearest-neighbor search.
     */
    setPrecomputedKNN(knnIndices, knnDistances) {
        this.knnIndices = knnIndices;
        this.knnDistances = knnDistances;
    }

    /**
     * Prepares all state needed for optimization of `X`.
     * @param {number[][]} X Input vectors.
     * @returns {number} The number of epochs the optimization will run.
     * @throws {Error} When `X` has no more than `nNeighbors` points.
     */
    initializeFit(X) {
        if (X.length <= this.nNeighbors) {
            throw new Error(`Not enough data points (${X.length}) to create nNeighbors: ${this.nNeighbors}.  Add more data points or adjust the configuration.`);
        }
        // Same array instance as an earlier, completed initialization.
        if (this.X === X && this.isInitialized) {
            return this.getNEpochs();
        }
        this.X = X;
        // NOTE(review): neighbors computed for a previous X are reused here if
        // a different X is fitted on the same instance - confirm intended.
        if (!this.knnIndices && !this.knnDistances) {
            const knnResults = this.nearestNeighbors(X);
            this.knnIndices = knnResults.knnIndices;
            this.knnDistances = knnResults.knnDistances;
        }
        this.graph = this.fuzzySimplicialSet(X, this.nNeighbors, this.setOpMixRatio);
        this.makeSearchFns();
        this.searchGraph = this.makeSearchGraph(X);
        this.processGraphForSupervisedProjection();
        const { head, tail, epochsPerSample } = this.initializeSimplicialSetEmbedding();
        this.optimizationState.head = head;
        this.optimizationState.tail = tail;
        this.optimizationState.epochsPerSample = epochsPerSample;
        this.initializeOptimization();
        this.prepareForOptimizationLoop();
        this.isInitialized = true;
        return this.getNEpochs();
    }

    /**
     * Builds the search helpers used by `transform`, all sharing the same
     * distance function as `computeDistance`.
     */
    makeSearchFns() {
        const distanceWrapper = (a, b) => this.computeDistance(a, b);
        const { initFromTree, initFromRandom } = nnDescent.makeInitializations(distanceWrapper);
        this.initFromTree = initFromTree;
        this.initFromRandom = initFromRandom;
        this.search = nnDescent.makeInitializedNNSearch(distanceWrapper);
    }

    /**
     * Distance between two vectors, using the WASM euclidean implementation
     * when `useWasmDistance` is set and `distanceFn` otherwise.
     * @throws {Error} When WASM distance is requested but unavailable.
     */
    computeDistance(a, b) {
        if (this.useWasmDistance) {
            if (!wasmBridge.isWasmAvailable()) {
                throw new Error('WASM distance requested via `useWasmDistance: true` but the wasm module is not initialized or available. ' +
                    'Call `await wasmBridge.initWasm()` before using UMAP with wasm distances or build the wasm package.');
            }
            return wasmBridge.euclideanWasm(a, b);
        }
        return this.distanceFn(a, b);
    }

    /**
     * Builds a symmetric sparse graph of the positive KNN distances
     * (element-wise maximum of the graph and its transpose).
     */
    makeSearchGraph(X) {
        const knnIndices = this.knnIndices;
        const knnDistances = this.knnDistances;
        const dims = [X.length, X.length];
        const searchGraph = new matrix.SparseMatrix([], [], [], dims);
        for (let i = 0; i < knnIndices.length; i++) {
            const knn = knnIndices[i];
            const distances = knnDistances[i];
            for (let j = 0; j < knn.length; j++) {
                const neighbor = knn[j];
                const distance = distances[j];
                if (distance > 0) {
                    searchGraph.set(i, neighbor, distance);
                }
            }
        }
        const transpose = matrix.transpose(searchGraph);
        return matrix.maximum(searchGraph, transpose);
    }

    /**
     * Projects new points into the embedding learned by a previous fit.
     * @param {number[][]} toTransform New input vectors.
     * @returns {number[][]} Embedding coordinates for the new points.
     * @throws {Error} When no data has been fit.
     */
    transform(toTransform) {
        const rawData = this.X;
        if (rawData === undefined || rawData.length === 0) {
            throw new Error('No data has been fit.');
        }
        let nNeighbors = Math.floor(this.nNeighbors * this.transformQueueSize);
        nNeighbors = Math.min(rawData.length, nNeighbors);
        const init = nnDescent.initializeSearch(this.rpForest, rawData, toTransform, nNeighbors, this.initFromRandom, this.initFromTree, this.random);
        const result = this.search(rawData, this.searchGraph, init, toTransform);
        let { indices, weights: distances } = heap.deheapSort(result);
        indices = indices.map(x => x.slice(0, this.nNeighbors));
        distances = distances.map(x => x.slice(0, this.nNeighbors));
        const adjustedLocalConnectivity = Math.max(0, this.localConnectivity - 1);
        const { sigmas, rhos } = this.smoothKNNDistance(distances, this.nNeighbors, adjustedLocalConnectivity);
        const { rows, cols, vals } = this.computeMembershipStrengths(indices, distances, sigmas, rhos);
        const size = [toTransform.length, rawData.length];
        let graph = new matrix.SparseMatrix(rows, cols, vals, size);
        const normed = matrix.normalize(graph, "l1");
        const csrMatrix = matrix.getCSR(normed);
        const nPoints = toTransform.length;
        const eIndices = utils.reshape2d(csrMatrix.indices, nPoints, this.nNeighbors);
        const eWeights = utils.reshape2d(csrMatrix.values, nPoints, this.nNeighbors);
        const embedding = initTransform(eIndices, eWeights, this.embedding);
        // A third of the fit epochs, kept as a whole number of at least 1: the
        // epoch counter advances in integer steps and the value is also used
        // as a divisor below.
        let nEpochs;
        if (this.nEpochs) {
            nEpochs = Math.max(1, Math.floor(this.nEpochs / 3));
        }
        else {
            nEpochs = graph.nRows <= 10000 ? 100 : 30;
        }
        const graphMax = graph
            .getValues()
            .reduce((max, val) => (val > max ? val : max), 0);
        graph = graph.map(value => (value < graphMax / nEpochs ? 0 : value));
        graph = matrix.eliminateZeros(graph);
        const epochsPerSample = this.makeEpochsPerSample(graph.getValues(), nEpochs);
        const head = graph.getRows();
        const tail = graph.getCols();
        this.assignOptimizationStateParameters({
            headEmbedding: embedding,
            tailEmbedding: this.embedding,
            head,
            tail,
            currentEpoch: 0,
            nEpochs,
            nVertices: graph.getDims()[1],
            epochsPerSample,
        });
        this.prepareForOptimizationLoop();
        return this.optimizeLayout();
    }

    /**
     * When labels were registered and the target metric is "categorical",
     * intersects the graph with the label information.
     * @throws {Error} When the labels and data differ in length.
     */
    processGraphForSupervisedProjection() {
        const { Y, X } = this;
        if (!Y) {
            return;
        }
        if (Y.length !== X.length) {
            throw new Error('Length of X and y must be equal');
        }
        if (this.targetMetric === "categorical") {
            const lt = this.targetWeight < 1.0;
            const farDist = lt ? 2.5 * (1.0 / (1.0 - this.targetWeight)) : 1.0e12;
            this.graph = this.categoricalSimplicialSetIntersection(this.graph, Y, farDist);
        }
    }

    /**
     * Runs one optimization epoch if any remain.
     * @returns {number} The current epoch counter after the call.
     */
    step() {
        const { currentEpoch } = this.optimizationState;
        if (currentEpoch < this.getNEpochs()) {
            this.optimizeLayoutStep(currentEpoch);
        }
        return this.optimizationState.currentEpoch;
    }

    /** @returns {number[][]} The current embedding. */
    getEmbedding() {
        return this.embedding;
    }

    /**
     * Computes approximate nearest neighbors for `X` using a random projection
     * forest followed by NN-descent.
     * @returns {{knnIndices: number[][], knnDistances: number[][]}}
     */
    nearestNeighbors(X) {
        // NOTE(review): this path uses `distanceFn` directly and does not
        // consult `useWasmDistance` - confirm that is intended.
        const { distanceFn, nNeighbors } = this;
        const log2 = (n) => Math.log(n) / Math.log(2);
        const metricNNDescent = nnDescent.makeNNDescent(distanceFn, this.random);
        // Rounds like Math.round except that exactly 0.5 rounds down to 0.
        const round = (n) => {
            return n === 0.5 ? 0 : Math.round(n);
        };
        const nTrees = 5 + Math.floor(round(X.length ** 0.5 / 20.0));
        const nIters = Math.max(5, Math.floor(Math.round(log2(X.length))));
        this.rpForest = tree.makeForest(X, nNeighbors, nTrees, this.random, this.useWasmTree);
        const leafArray = tree.makeLeafArray(this.rpForest);
        const { indices, weights } = metricNNDescent(X, leafArray, nNeighbors, nIters);
        return { knnIndices: indices, knnDistances: weights };
    }

    /**
     * Builds the weighted graph from the stored KNN results. With A the
     * membership-strength matrix and P the element-wise product of A and its
     * transpose, the result is
     * `setOpMixRatio * (A + At - P) + (1 - setOpMixRatio) * P`.
     * The same formula is evaluated through the WASM bridge when
     * `useWasmMatrix` is set and WASM is available.
     * @returns {matrix.SparseMatrix}
     */
    fuzzySimplicialSet(X, nNeighbors, setOpMixRatio = 1.0) {
        const { knnIndices = [], knnDistances = [], localConnectivity } = this;
        const { sigmas, rhos } = this.smoothKNNDistance(knnDistances, nNeighbors, localConnectivity);
        const { rows, cols, vals } = this.computeMembershipStrengths(knnIndices, knnDistances, sigmas, rhos);
        const size = [X.length, X.length];
        if (this.useWasmMatrix && wasmBridge.isWasmAvailable()) {
            const wasmMat = wasmBridge.createSparseMatrixWasm(rows, cols, vals, size[0], size[1]);
            const transpose = wasmBridge.sparseTransposeWasm(wasmMat);
            const prodMatrix = wasmBridge.sparsePairwiseMultiplyWasm(wasmMat, transpose);
            const added = wasmBridge.sparseAddWasm(wasmMat, transpose);
            const a = wasmBridge.sparseSubtractWasm(added, prodMatrix);
            const b = wasmBridge.sparseMultiplyScalarWasm(a, setOpMixRatio);
            const c = wasmBridge.sparseMultiplyScalarWasm(prodMatrix, 1.0 - setOpMixRatio);
            const resultWasm = wasmBridge.sparseAddWasm(b, c);
            // Copy the entries back into a JS SparseMatrix for the callers.
            const entries = wasmBridge.wasmSparseMatrixGetAll(resultWasm);
            const jsRows = entries.map(e => e.row);
            const jsCols = entries.map(e => e.col);
            const jsVals = entries.map(e => e.value);
            return new matrix.SparseMatrix(jsRows, jsCols, jsVals, size);
        }
        const sparseMatrix = new matrix.SparseMatrix(rows, cols, vals, size);
        const transpose = matrix.transpose(sparseMatrix);
        const prodMatrix = matrix.pairwiseMultiply(sparseMatrix, transpose);
        const a = matrix.subtract(matrix.add(sparseMatrix, transpose), prodMatrix);
        const b = matrix.multiplyScalar(a, setOpMixRatio);
        const c = matrix.multiplyScalar(prodMatrix, 1.0 - setOpMixRatio);
        return matrix.add(b, c);
    }

    /**
     * Down-weights graph entries according to categorical labels, drops the
     * resulting zeros and resets local connectivity.
     */
    categoricalSimplicialSetIntersection(simplicialSet, target, farDist, unknownDist = 1.0) {
        let intersection = fastIntersection(simplicialSet, target, unknownDist, farDist);
        intersection = matrix.eliminateZeros(intersection);
        return resetLocalConnectivity(intersection);
    }

    /**
     * For every row of neighbor distances, finds `rho` (from the nearest
     * non-zero distances and `localConnectivity`) and searches for `sigma`
     * such that the summed membership approaches `log2(k) * bandwidth`.
     * @param {number[][]} distances Neighbor distances per point.
     * @param {number} k Number of neighbors.
     * @param {number} [localConnectivity=1.0]
     * @param {number} [nIter=64] Maximum search iterations per row.
     * @param {number} [bandwidth=1.0]
     * @returns {{sigmas: number[], rhos: number[]}}
     */
    smoothKNNDistance(distances, k, localConnectivity = 1.0, nIter = 64, bandwidth = 1.0) {
        const target = (Math.log(k) / Math.log(2)) * bandwidth;
        const rho = utils.zeros(distances.length);
        const result = utils.zeros(distances.length);
        // Mean over all rows; identical for every row, so computed at most once.
        let meanDistances;
        for (let i = 0; i < distances.length; i++) {
            let lo = 0.0;
            let hi = Infinity;
            let mid = 1.0;
            const ithDistances = distances[i];
            const nonZeroDists = ithDistances.filter(d => d > 0.0);
            if (nonZeroDists.length >= localConnectivity) {
                const index = Math.floor(localConnectivity);
                const interpolation = localConnectivity - index;
                if (index > 0) {
                    rho[i] = nonZeroDists[index - 1];
                    if (interpolation > SMOOTH_K_TOLERANCE) {
                        rho[i] +=
                            interpolation * (nonZeroDists[index] - nonZeroDists[index - 1]);
                    }
                }
                else {
                    rho[i] = interpolation * nonZeroDists[0];
                }
            }
            else if (nonZeroDists.length > 0) {
                rho[i] = utils.max(nonZeroDists);
            }
            // Search for sigma: grow by doubling until an upper bound is
            // found, then bisect.
            for (let n = 0; n < nIter; n++) {
                let psum = 0.0;
                // Index 0 is skipped by the summation.
                for (let j = 1; j < distances[i].length; j++) {
                    const d = distances[i][j] - rho[i];
                    if (d > 0) {
                        psum += Math.exp(-(d / mid));
                    }
                    else {
                        psum += 1.0;
                    }
                }
                if (Math.abs(psum - target) < SMOOTH_K_TOLERANCE) {
                    break;
                }
                if (psum > target) {
                    hi = mid;
                    mid = (lo + hi) / 2.0;
                }
                else {
                    lo = mid;
                    if (hi === Infinity) {
                        mid *= 2;
                    }
                    else {
                        mid = (lo + hi) / 2.0;
                    }
                }
            }
            result[i] = mid;
            // Keep sigma from collapsing relative to the distance scale.
            if (rho[i] > 0.0) {
                const meanIthDistances = utils.mean(ithDistances);
                if (result[i] < MIN_K_DIST_SCALE * meanIthDistances) {
                    result[i] = MIN_K_DIST_SCALE * meanIthDistances;
                }
            }
            else {
                if (meanDistances === undefined) {
                    meanDistances = utils.mean(distances.map(utils.mean));
                }
                if (result[i] < MIN_K_DIST_SCALE * meanDistances) {
                    result[i] = MIN_K_DIST_SCALE * meanDistances;
                }
            }
        }
        return { sigmas: result, rhos: rho };
    }

    /**
     * Converts KNN distances into membership strengths in coordinate form.
     * Entries whose neighbor index is -1 are skipped (their slots stay 0),
     * self-neighbors get 0, distances at or below `rho` get 1, and everything
     * else gets `exp(-(distance - rho) / sigma)`.
     * @returns {{rows: number[], cols: number[], vals: number[]}}
     */
    computeMembershipStrengths(knnIndices, knnDistances, sigmas, rhos) {
        const nSamples = knnIndices.length;
        const nNeighbors = knnIndices[0].length;
        const rows = utils.zeros(nSamples * nNeighbors);
        const cols = utils.zeros(nSamples * nNeighbors);
        const vals = utils.zeros(nSamples * nNeighbors);
        for (let i = 0; i < nSamples; i++) {
            for (let j = 0; j < nNeighbors; j++) {
                let val = 0;
                if (knnIndices[i][j] === -1) {
                    continue;
                }
                if (knnIndices[i][j] === i) {
                    val = 0.0;
                }
                else if (knnDistances[i][j] - rhos[i] <= 0.0) {
                    val = 1.0;
                }
                else {
                    val = Math.exp(-((knnDistances[i][j] - rhos[i]) / sigmas[i]));
                }
                rows[i * nNeighbors + j] = i;
                cols[i * nNeighbors + j] = knnIndices[i][j];
                vals[i * nNeighbors + j] = val;
            }
        }
        return { rows, cols, vals };
    }

    /**
     * Creates the random initial embedding (coordinates offset to start at
     * -10) and extracts the edges to optimize: entries below
     * `max / nEpochs` are dropped.
     * @returns {{head: number[], tail: number[], epochsPerSample: number[]}}
     */
    initializeSimplicialSetEmbedding() {
        const nEpochs = this.getNEpochs();
        const { nComponents } = this;
        const graphValues = this.graph.getValues();
        let graphMax = 0;
        for (let i = 0; i < graphValues.length; i++) {
            const value = graphValues[i];
            if (graphMax < graphValues[i]) {
                graphMax = value;
            }
        }
        const graph = this.graph.map(value => {
            if (value < graphMax / nEpochs) {
                return 0;
            }
            else {
                return value;
            }
        });
        this.embedding = utils.zeros(graph.nRows).map(() => {
            return utils.zeros(nComponents).map(() => {
                return utils.tauRand(this.random) * 20 + -10;
            });
        });
        const weights = [];
        const head = [];
        const tail = [];
        const rowColValues = graph.getAll();
        for (let i = 0; i < rowColValues.length; i++) {
            const entry = rowColValues[i];
            if (entry.value) {
                weights.push(entry.value);
                tail.push(entry.row);
                head.push(entry.col);
            }
        }
        const epochsPerSample = this.makeEpochsPerSample(weights, nEpochs);
        return { head, tail, epochsPerSample };
    }

    /**
     * For each edge weight, the number of epochs between samples of that edge
     * (`max / weight`); -1 for edges that are never sampled.
     */
    makeEpochsPerSample(weights, nEpochs) {
        const result = utils.filled(weights.length, -1.0);
        const max = utils.max(weights);
        const nSamples = weights.map(w => (w / max) * nEpochs);
        nSamples.forEach((n, i) => {
            if (n > 0) {
                result[i] = nEpochs / nSamples[i];
            }
        });
        return result;
    }

    /** Merges `state` into the optimization state. */
    assignOptimizationStateParameters(state) {
        Object.assign(this.optimizationState, state);
    }

    /**
     * Derives the per-edge sampling schedule and copies the learning
     * parameters into the optimization state.
     */
    prepareForOptimizationLoop() {
        const { repulsionStrength, learningRate, negativeSampleRate } = this;
        const { epochsPerSample, headEmbedding, tailEmbedding } = this.optimizationState;
        const dim = headEmbedding[0].length;
        // Both ends of an edge move only when head and tail have equal length.
        const moveOther = headEmbedding.length === tailEmbedding.length;
        const epochsPerNegativeSample = epochsPerSample.map(e => e / negativeSampleRate);
        const epochOfNextNegativeSample = [...epochsPerNegativeSample];
        const epochOfNextSample = [...epochsPerSample];
        this.assignOptimizationStateParameters({
            epochOfNextSample,
            epochOfNextNegativeSample,
            epochsPerNegativeSample,
            moveOther,
            initialAlpha: learningRate,
            alpha: learningRate,
            gamma: repulsionStrength,
            dim,
        });
    }

    /**
     * Points the optimization state at the current embedding and fits the
     * `a`/`b` curve parameters from `spread` and `minDist`.
     */
    initializeOptimization() {
        const headEmbedding = this.embedding;
        const tailEmbedding = this.embedding;
        const { head, tail, epochsPerSample } = this.optimizationState;
        const nEpochs = this.getNEpochs();
        const nVertices = this.graph.nCols;
        const { a, b } = findABParams(this.spread, this.minDist);
        this.assignOptimizationStateParameters({
            headEmbedding,
            tailEmbedding,
            head,
            tail,
            epochsPerSample,
            a,
            b,
            nEpochs,
            nVertices,
        });
    }

    /**
     * Runs a single optimization epoch: an attractive update for every edge
     * that is due, followed by that edge's negative samples.
     * @param {number} n Epoch index.
     * @returns {number[][]} The head embedding, updated in place.
     */
    optimizeLayoutStep(n) {
        const { optimizationState } = this;
        const { head, tail, headEmbedding, tailEmbedding, epochsPerSample, epochOfNextSample, epochOfNextNegativeSample, epochsPerNegativeSample, moveOther, initialAlpha, alpha, gamma, a, b, dim, nEpochs, nVertices } = optimizationState;
        const clipValue = 4.0;
        for (let i = 0; i < epochsPerSample.length; i++) {
            if (epochOfNextSample[i] > n) {
                continue;
            }
            const j = head[i];
            const k = tail[i];
            const current = headEmbedding[j];
            const other = tailEmbedding[k];
            const distSquared = rDist(current, other);
            let gradCoeff = 0;
            if (distSquared > 0) {
                gradCoeff = -2.0 * a * b * Math.pow(distSquared, b - 1.0);
                gradCoeff /= a * Math.pow(distSquared, b) + 1.0;
            }
            for (let d = 0; d < dim; d++) {
                const gradD = clip(gradCoeff * (current[d] - other[d]), clipValue);
                current[d] += gradD * alpha;
                if (moveOther) {
                    other[d] += -gradD * alpha;
                }
            }
            epochOfNextSample[i] += epochsPerSample[i];
            const nNegSamples = Math.floor((n - epochOfNextNegativeSample[i]) / epochsPerNegativeSample[i]);
            for (let p = 0; p < nNegSamples; p++) {
                const k = utils.tauRandInt(nVertices, this.random);
                const other = tailEmbedding[k];
                const distSquared = rDist(current, other);
                let gradCoeff = 0.0;
                if (distSquared > 0.0) {
                    gradCoeff = 2.0 * gamma * b;
                    gradCoeff /=
                        (0.001 + distSquared) * (a * Math.pow(distSquared, b) + 1);
                }
                else if (j === k) {
                    continue;
                }
                for (let d = 0; d < dim; d++) {
                    // With a zero coefficient the full clip value is applied.
                    let gradD = 4.0;
                    if (gradCoeff > 0.0) {
                        gradD = clip(gradCoeff * (current[d] - other[d]), clipValue);
                    }
                    current[d] += gradD * alpha;
                }
            }
            epochOfNextNegativeSample[i] += nNegSamples * epochsPerNegativeSample[i];
        }
        // Linear learning-rate decay over the epochs.
        optimizationState.alpha = initialAlpha * (1.0 - n / nEpochs);
        optimizationState.currentEpoch += 1;
        return headEmbedding;
    }

    /**
     * Runs the remaining epochs, scheduling each one with `setTimeout`.
     * Resolves immediately (without stepping) when no epochs remain.
     * @param {(epoch: number) => (boolean|void)} [epochCallback] Called after
     *   each epoch; returning exactly `false` stops early.
     * @returns {Promise<boolean>} Resolves to whether all epochs completed;
     *   rejects if a step or the callback throws.
     */
    optimizeLayoutAsync(epochCallback = () => true) {
        return new Promise((resolve, reject) => {
            const runEpoch = () => {
                try {
                    const state = this.optimizationState;
                    // Range comparison (not equality) so a counter that is
                    // already at or past nEpochs cannot loop forever.
                    if (state.currentEpoch >= state.nEpochs) {
                        resolve(true);
                        return;
                    }
                    this.embedding = this.optimizeLayoutStep(state.currentEpoch);
                    const epochCompleted = state.currentEpoch;
                    const shouldStop = epochCallback(epochCompleted) === false;
                    const isFinished = epochCompleted >= state.nEpochs;
                    if (!shouldStop && !isFinished) {
                        setTimeout(runEpoch, 0);
                    }
                    else {
                        resolve(isFinished);
                    }
                }
                catch (err) {
                    reject(err);
                }
            };
            setTimeout(runEpoch, 0);
        });
    }

    /**
     * Runs the remaining epochs synchronously. Does nothing when no epochs
     * remain.
     * @param {(epoch: number) => (boolean|void)} [epochCallback] Called after
     *   each epoch; returning exactly `false` stops early.
     * @returns {number[][]} The head embedding.
     */
    optimizeLayout(epochCallback = () => true) {
        const state = this.optimizationState;
        // Range comparison (not equality) so a counter that is already at or
        // past nEpochs, or a non-integer nEpochs, cannot loop forever.
        while (state.currentEpoch < state.nEpochs) {
            this.optimizeLayoutStep(state.currentEpoch);
            if (epochCallback(state.currentEpoch) === false) {
                break;
            }
        }
        return state.headEmbedding;
    }

    /**
     * @returns {number} The configured `nEpochs` when positive, otherwise a
     *   default chosen from the number of graph rows.
     */
    getNEpochs() {
        const graph = this.graph;
        if (this.nEpochs > 0) {
            return this.nEpochs;
        }
        const length = graph.nRows;
        if (length <= 2500) {
            return 500;
        }
        else if (length <= 5000) {
            return 400;
        }
        else if (length <= 7500) {
            return 300;
        }
        else {
            return 200;
        }
    }
}
|
|
580
|
+
/**
 * Euclidean distance between two vectors, iterating over the length of `x`.
 * @param {number[]} x
 * @param {number[]} y
 * @returns {number}
 */
export function euclidean(x, y) {
    let sumOfSquares = 0;
    let idx = 0;
    while (idx < x.length) {
        const delta = x[idx] - y[idx];
        sumOfSquares += delta ** 2;
        idx += 1;
    }
    return Math.sqrt(sumOfSquares);
}
|
|
587
|
+
/**
 * Cosine distance (1 - cosine similarity) between two vectors.
 * Returns 0 when both vectors have zero norm and 1 when exactly one does.
 * @param {number[]} x
 * @param {number[]} y
 * @returns {number}
 */
export function cosine(x, y) {
    let dot = 0.0;
    let sqNormX = 0.0;
    let sqNormY = 0.0;
    for (let idx = 0; idx < x.length; idx++) {
        const xi = x[idx];
        const yi = y[idx];
        dot += xi * yi;
        sqNormX += xi ** 2;
        sqNormY += yi ** 2;
    }
    const xIsZero = sqNormX === 0;
    const yIsZero = sqNormY === 0;
    if (xIsZero && yIsZero) {
        return 0;
    }
    if (xIsZero || yIsZero) {
        return 1.0;
    }
    return 1.0 - dot / Math.sqrt(sqNormX * sqNormY);
}
|
|
606
|
+
/**
 * Mutable bag of values shared between the optimization setup and the
 * per-epoch update loop. All fields start at the defaults below.
 */
class OptimizationState {
    constructor() {
        Object.assign(this, {
            currentEpoch: 0,
            // Embedding arrays and edge lists.
            headEmbedding: [],
            tailEmbedding: [],
            head: [],
            tail: [],
            // Sampling schedule.
            epochsPerSample: [],
            epochOfNextSample: [],
            epochOfNextNegativeSample: [],
            epochsPerNegativeSample: [],
            // Update parameters.
            moveOther: true,
            initialAlpha: 1.0,
            alpha: 1.0,
            gamma: 1.0,
            a: 1.5769434603113077,
            b: 0.8950608779109733,
            dim: 2,
            nEpochs: 500,
            nVertices: 0,
        });
    }
}
|
|
628
|
+
/**
 * Limits `x` to `clipValue` from above and `-clipValue` from below.
 * @param {number} x
 * @param {number} clipValue
 * @returns {number}
 */
function clip(x, clipValue) {
    const exceedsUpper = x > clipValue;
    if (exceedsUpper) {
        return clipValue;
    }
    const lowerBound = -clipValue;
    return x < lowerBound ? lowerBound : x;
}
|
|
636
|
+
/**
 * Squared euclidean distance between two vectors, iterating over the length
 * of `x`.
 * @param {number[]} x
 * @param {number[]} y
 * @returns {number}
 */
function rDist(x, y) {
    let total = 0.0;
    let idx = 0;
    while (idx < x.length) {
        total += Math.pow(x[idx] - y[idx], 2);
        idx += 1;
    }
    return total;
}
|
|
643
|
+
/**
 * Fits the parameters `a` and `b` of the curve `1 / (1 + a * x^(2b))` to
 * sample points derived from `spread` and `minDist`, using the
 * Levenberg-Marquardt solver.
 * @param {number} spread
 * @param {number} minDist
 * @returns {{a: number, b: number}}
 */
export function findABParams(spread, minDist) {
    // Model family handed to the solver: parameters in, function of x out.
    const curve = ([a, b]) => (x) => 1.0 / (1.0 + a * x ** (2 * b));
    const samples = utils.linear(0, spread * 3, 300);
    const xv = samples.map(val => (val < minDist ? 1.0 : val));
    const yv = utils.zeros(xv.length).map((val, index) => {
        const xi = xv[index];
        if (xi >= minDist) {
            return Math.exp(-(xi - minDist) / spread);
        }
        return val;
    });
    const fit = LM({ x: xv, y: yv }, curve, {
        damping: 1.5,
        initialValues: [0.5, 0.5],
        gradientDifference: 10e-2,
        maxIterations: 100,
        errorTolerance: 10e-3,
    });
    const [a, b] = fit.parameterValues;
    return { a, b };
}
|
|
667
|
+
/**
 * Re-weights every graph entry according to the labels of its row and column:
 * entries touching an unknown label (-1) are scaled by `exp(-unknownDist)`,
 * entries joining different labels by `exp(-farDist)`, and entries joining
 * equal labels are kept unchanged.
 * @param graph Object exposing `map((value, row, col) => newValue)`.
 * @param {Array} target Label per point; -1 marks an unknown label.
 * @param {number} [unknownDist=1.0]
 * @param {number} [farDist=5.0]
 * @returns The result of `graph.map`.
 */
export function fastIntersection(graph, target, unknownDist = 1.0, farDist = 5.0) {
    const reweight = (value, row, col) => {
        const rowLabel = target[row];
        const colLabel = target[col];
        if (rowLabel === -1 || colLabel === -1) {
            return value * Math.exp(-unknownDist);
        }
        if (rowLabel !== colLabel) {
            return value * Math.exp(-farDist);
        }
        return value;
    };
    return graph.map(reweight);
}
|
|
680
|
+
/**
 * Max-normalizes the matrix, then combines it with its transpose as
 * `N + (Nt - Nt∘N)` (∘ being the pairwise product) and drops zero entries.
 * @param simplicialSet Sparse matrix to process.
 * @returns The processed sparse matrix.
 */
export function resetLocalConnectivity(simplicialSet) {
    const normalized = matrix.normalize(simplicialSet, "max");
    const transposed = matrix.transpose(normalized);
    const product = matrix.pairwiseMultiply(transposed, normalized);
    const remainder = matrix.subtract(transposed, product);
    const combined = matrix.add(normalized, remainder);
    return matrix.eliminateZeros(combined);
}
|
|
687
|
+
/**
 * Computes an initial position for each new point as the weighted sum of the
 * embedding rows of its neighbors.
 * @param {number[][]} indices Neighbor indices per new point.
 * @param {number[][]} weights Neighbor weights per new point.
 * @param {number[][]} embedding Existing embedding rows.
 * @returns {number[][]} One row per entry of `indices`.
 */
export function initTransform(indices, weights, embedding) {
    const result = utils
        .zeros(indices.length)
        .map(() => utils.zeros(embedding[0].length));
    for (let point = 0; point < indices.length; point++) {
        for (let nbr = 0; nbr < indices[0].length; nbr++) {
            for (let axis = 0; axis < embedding[0].length; axis++) {
                const source = indices[point][nbr];
                const contribution = weights[point][nbr] * embedding[source][axis];
                result[point][axis] += contribution;
            }
        }
    }
    return result;
}
|