@elarsaks/umap-wasm 0.4.4 → 0.4.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/matrix.d.ts CHANGED
@@ -2,7 +2,7 @@ export declare class SparseMatrix {
2
2
  private entries;
3
3
  readonly nRows: number;
4
4
  readonly nCols: number;
5
- constructor(rows: number[], cols: number[], values: number[], dims: number[]);
5
+ constructor(rows: ArrayLike<number>, cols: ArrayLike<number>, values: ArrayLike<number>, dims: number[]);
6
6
  private makeKey;
7
7
  private checkDims;
8
8
  set(row: number, col: number, value: number): void;
@@ -2,7 +2,7 @@ import * as heap from './heap.js';
2
2
  import * as matrix from './matrix.js';
3
3
  import * as tree from './tree.js';
4
4
  import { RandomFn, Vectors, DistanceFn } from './umap.js';
5
- export declare function makeNNDescent(distanceFn: DistanceFn, random: RandomFn, useWasm?: boolean): (data: Vectors, leafArray: Vectors, nNeighbors: number, nIters?: number, maxCandidates?: number, delta?: number, rho?: number, rpTreeInit?: boolean) => {
5
+ export declare function makeNNDescent(distanceFn: DistanceFn, random: RandomFn, useWasm?: boolean): (data: Vectors, leafArray: ArrayLike<number>[], nNeighbors: number, nIters?: number, maxCandidates?: number, delta?: number, rho?: number, rpTreeInit?: boolean) => {
6
6
  indices: number[][];
7
7
  weights: number[][];
8
8
  };
@@ -11,10 +11,39 @@ export function makeNNDescent(distanceFn, random, useWasm = false) {
11
11
  }
12
12
  const distanceMetric = distanceFn.name === 'cosine' ? 'cosine' : 'euclidean';
13
13
  const seed = Math.floor(random() * 0xFFFFFFFF);
14
- const result = wasmBridge.nnDescentWasm(data, leafArray, nNeighbors, nIters, maxCandidates, delta, rho, rpTreeInit, distanceMetric, seed);
14
+ const nSamples = data.length;
15
+ const dim = data[0].length;
16
+ const flatData = new Float64Array(nSamples * dim);
17
+ for (let i = 0; i < nSamples; i++) {
18
+ for (let j = 0; j < dim; j++) {
19
+ flatData[i * dim + j] = data[i][j];
20
+ }
21
+ }
22
+ const nLeaves = leafArray.length;
23
+ const leafSize = nLeaves > 0 ? leafArray[0].length : 0;
24
+ const flatLeafArray = new Int32Array(nLeaves * leafSize);
25
+ for (let i = 0; i < nLeaves; i++) {
26
+ for (let j = 0; j < leafSize; j++) {
27
+ flatLeafArray[i * leafSize + j] = leafArray[i][j];
28
+ }
29
+ }
30
+ const result = wasmBridge.nnDescentWasmFlat(flatData, nSamples, dim, flatLeafArray, nLeaves, leafSize, nNeighbors, nIters, maxCandidates, delta, rho, rpTreeInit, distanceMetric, seed);
31
+ const indices = [];
32
+ const distances = [];
33
+ const offset1 = nSamples * nNeighbors;
34
+ for (let i = 0; i < nSamples; i++) {
35
+ const rowIndices = [];
36
+ const rowDistances = [];
37
+ for (let j = 0; j < nNeighbors; j++) {
38
+ rowIndices.push(result[i * nNeighbors + j]);
39
+ rowDistances.push(result[offset1 + i * nNeighbors + j]);
40
+ }
41
+ indices.push(rowIndices);
42
+ distances.push(rowDistances);
43
+ }
15
44
  return {
16
- indices: result[0],
17
- weights: result[1],
45
+ indices,
46
+ weights: distances,
18
47
  };
19
48
  }
20
49
  const nVertices = data.length;
package/dist/tree.d.ts CHANGED
@@ -1,16 +1,38 @@
1
1
  import { RandomFn, Vector, Vectors } from './umap.js';
2
2
  import { WasmFlatTree } from './wasmBridge.js';
3
3
  export declare class FlatTree {
4
- hyperplanes: number[][];
5
- offsets: number[];
6
- children: number[][];
7
- indices: number[][];
4
+ hyperplanes: ArrayLike<number>[];
5
+ offsets: ArrayLike<number>;
6
+ children: ArrayLike<number>[];
7
+ indices: ArrayLike<number>[];
8
8
  private wasmTree?;
9
- constructor(hyperplanes: number[][], offsets: number[], children: number[][], indices: number[][]);
9
+ private hyperplanesFlat?;
10
+ private offsetsFlat?;
11
+ private childrenFlat?;
12
+ private indicesFlat?;
13
+ private dim?;
14
+ private nNodes?;
15
+ private nLeaves?;
16
+ private leafSize?;
17
+ constructor(hyperplanes: ArrayLike<number>[], offsets: ArrayLike<number>, children: ArrayLike<number>[], indices: ArrayLike<number>[]);
10
18
  static fromWasm(wasmTree: WasmFlatTree): FlatTree;
11
19
  getWasmTree(): WasmFlatTree | undefined;
20
+ getFlatHyperplanes(): Float64Array | undefined;
21
+ getFlatOffsets(): Float64Array | undefined;
22
+ getFlatChildren(): Int32Array | undefined;
23
+ getFlatLeafMeta(): {
24
+ indices: Int32Array;
25
+ nLeaves: number;
26
+ leafSize: number;
27
+ } | undefined;
28
+ getDim(): number | undefined;
12
29
  dispose(): void;
13
30
  }
14
31
  export declare function makeForest(data: Vectors, nNeighbors: number, nTrees: number, random: RandomFn, useWasm?: boolean): FlatTree[];
15
- export declare function makeLeafArray(rpForest: FlatTree[]): number[][];
16
- export declare function searchFlatTree(point: Vector, tree: FlatTree, random: RandomFn): number[];
32
+ export declare function makeLeafArray(rpForest: FlatTree[]): ArrayLike<number>[];
33
+ export declare function makeLeafArrayFlat(rpForest: FlatTree[]): {
34
+ flatLeafArray: Int32Array;
35
+ nLeaves: number;
36
+ leafSize: number;
37
+ };
38
+ export declare function searchFlatTree(point: Vector, tree: FlatTree, random: RandomFn): ArrayLike<number>;
package/dist/tree.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import * as utils from './utils.js';
2
- import { isWasmAvailable, buildRpTreeWasm, searchFlatTreeWasm, wasmTreeToJs } from './wasmBridge.js';
2
+ import { isWasmAvailable, buildRpTreeWasmFlat, searchFlatTreeWasm } from './wasmBridge.js';
3
3
  export class FlatTree {
4
4
  constructor(hyperplanes, offsets, children, indices) {
5
5
  this.hyperplanes = hyperplanes;
@@ -8,14 +8,77 @@ export class FlatTree {
8
8
  this.indices = indices;
9
9
  }
10
10
  static fromWasm(wasmTree) {
11
- const jsData = wasmTreeToJs(wasmTree);
12
- const tree = new FlatTree(jsData.hyperplanes, jsData.offsets, jsData.children, jsData.indices);
11
+ const tree = new FlatTree([], new Float64Array(0), [], []);
12
+ const childrenFlat = wasmTree.children();
13
+ let maxLeafIdx = -1;
14
+ for (let i = 0; i < childrenFlat.length; i++) {
15
+ const v = childrenFlat[i];
16
+ if (v <= 0) {
17
+ const leafIdx = -v;
18
+ if (leafIdx > maxLeafIdx)
19
+ maxLeafIdx = leafIdx;
20
+ }
21
+ }
22
+ const nLeaves = maxLeafIdx + 1;
23
+ const indicesFlat = wasmTree.indices();
24
+ const leafSize = nLeaves > 0 ? Math.floor(indicesFlat.length / nLeaves) : 0;
25
+ tree.hyperplanesFlat = wasmTree.hyperplanes();
26
+ tree.offsetsFlat = wasmTree.offsets();
27
+ tree.childrenFlat = childrenFlat;
28
+ tree.indicesFlat = indicesFlat;
29
+ tree.dim = wasmTree.dim();
30
+ tree.nNodes = wasmTree.n_nodes();
31
+ tree.nLeaves = nLeaves;
32
+ tree.leafSize = leafSize;
33
+ const hyperplanes = new Array(tree.nNodes);
34
+ for (let i = 0; i < tree.nNodes; i++) {
35
+ const start = i * tree.dim;
36
+ hyperplanes[i] = tree.hyperplanesFlat.subarray(start, start + tree.dim);
37
+ }
38
+ const children = new Array(tree.nNodes);
39
+ for (let i = 0; i < tree.nNodes; i++) {
40
+ const start = i * 2;
41
+ children[i] = tree.childrenFlat.subarray(start, start + 2);
42
+ }
43
+ const indices = new Array(tree.nLeaves);
44
+ for (let i = 0; i < tree.nLeaves; i++) {
45
+ const start = i * tree.leafSize;
46
+ indices[i] = tree.indicesFlat.subarray(start, start + tree.leafSize);
47
+ }
48
+ tree.hyperplanes = hyperplanes;
49
+ tree.offsets = tree.offsetsFlat;
50
+ tree.children = children;
51
+ tree.indices = indices;
13
52
  tree.wasmTree = wasmTree;
14
53
  return tree;
15
54
  }
16
55
  getWasmTree() {
17
56
  return this.wasmTree;
18
57
  }
58
+ getFlatHyperplanes() {
59
+ return this.hyperplanesFlat;
60
+ }
61
+ getFlatOffsets() {
62
+ return this.offsetsFlat;
63
+ }
64
+ getFlatChildren() {
65
+ return this.childrenFlat;
66
+ }
67
+ getFlatLeafMeta() {
68
+ if (this.indicesFlat &&
69
+ this.nLeaves !== undefined &&
70
+ this.leafSize !== undefined) {
71
+ return {
72
+ indices: this.indicesFlat,
73
+ nLeaves: this.nLeaves,
74
+ leafSize: this.leafSize,
75
+ };
76
+ }
77
+ return undefined;
78
+ }
79
+ getDim() {
80
+ return this.dim;
81
+ }
19
82
  dispose() {
20
83
  if (this.wasmTree) {
21
84
  this.wasmTree.free();
@@ -41,9 +104,15 @@ function makeForestWasm(data, leafSize, nTrees, random) {
41
104
  const nSamples = data.length;
42
105
  const dim = data[0].length;
43
106
  const forest = [];
107
+ const flatData = new Float64Array(nSamples * dim);
108
+ for (let i = 0; i < nSamples; i++) {
109
+ for (let j = 0; j < dim; j++) {
110
+ flatData[i * dim + j] = data[i][j];
111
+ }
112
+ }
44
113
  for (let i = 0; i < nTrees; i++) {
45
114
  const seed = Math.floor(random() * 0xFFFFFFFF);
46
- const wasmTree = buildRpTreeWasm(data, nSamples, dim, leafSize, seed);
115
+ const wasmTree = buildRpTreeWasmFlat(flatData, nSamples, dim, leafSize, seed);
47
116
  forest.push(FlatTree.fromWasm(wasmTree));
48
117
  }
49
118
  return forest;
@@ -183,7 +252,23 @@ export function makeLeafArray(rpForest) {
183
252
  if (rpForest.length > 0) {
184
253
  const output = [];
185
254
  for (let tree of rpForest) {
186
- output.push(...tree.indices);
255
+ if (tree.indices.length > 0) {
256
+ output.push(...tree.indices);
257
+ continue;
258
+ }
259
+ const flatMeta = tree.getFlatLeafMeta();
260
+ if (!flatMeta) {
261
+ continue;
262
+ }
263
+ const { indices, nLeaves, leafSize } = flatMeta;
264
+ for (let leaf = 0; leaf < nLeaves; leaf++) {
265
+ const start = leaf * leafSize;
266
+ const row = new Array(leafSize);
267
+ for (let i = 0; i < leafSize; i++) {
268
+ row[i] = indices[start + i];
269
+ }
270
+ output.push(row);
271
+ }
187
272
  }
188
273
  return output;
189
274
  }
@@ -191,6 +276,44 @@ export function makeLeafArray(rpForest) {
191
276
  return [[-1]];
192
277
  }
193
278
  }
279
+ export function makeLeafArrayFlat(rpForest) {
280
+ if (rpForest.length === 0) {
281
+ return { flatLeafArray: new Int32Array([-1]), nLeaves: 1, leafSize: 1 };
282
+ }
283
+ let leafSize = -1;
284
+ let totalLeaves = 0;
285
+ for (const tree of rpForest) {
286
+ const flatMeta = tree.getFlatLeafMeta();
287
+ if (flatMeta) {
288
+ if (leafSize === -1)
289
+ leafSize = flatMeta.leafSize;
290
+ totalLeaves += flatMeta.nLeaves;
291
+ }
292
+ else {
293
+ if (leafSize === -1)
294
+ leafSize = tree.indices[0]?.length ?? 0;
295
+ totalLeaves += tree.indices.length;
296
+ }
297
+ }
298
+ const flatLeafArray = new Int32Array(totalLeaves * leafSize);
299
+ let offset = 0;
300
+ for (const tree of rpForest) {
301
+ const flatMeta = tree.getFlatLeafMeta();
302
+ if (flatMeta) {
303
+ flatLeafArray.set(flatMeta.indices, offset);
304
+ offset += flatMeta.indices.length;
305
+ continue;
306
+ }
307
+ for (let i = 0; i < tree.indices.length; i++) {
308
+ const row = tree.indices[i];
309
+ for (let j = 0; j < leafSize; j++) {
310
+ flatLeafArray[offset + i * leafSize + j] = row[j];
311
+ }
312
+ }
313
+ offset += tree.indices.length * leafSize;
314
+ }
315
+ return { flatLeafArray, nLeaves: totalLeaves, leafSize };
316
+ }
194
317
  function selectSide(hyperplane, offset, point, random) {
195
318
  let margin = offset;
196
319
  for (let d = 0; d < point.length; d++) {
@@ -213,6 +336,37 @@ export function searchFlatTree(point, tree, random) {
213
336
  const seed = Math.floor(random() * 0xFFFFFFFF);
214
337
  return searchFlatTreeWasm(wasmTree, point, seed);
215
338
  }
339
+ const childrenFlat = tree.getFlatChildren();
340
+ const hyperplanesFlat = tree.getFlatHyperplanes();
341
+ const offsetsFlat = tree.getFlatOffsets();
342
+ const dim = tree.getDim();
343
+ const flatMeta = tree.getFlatLeafMeta();
344
+ if (childrenFlat &&
345
+ hyperplanesFlat &&
346
+ offsetsFlat &&
347
+ dim !== undefined &&
348
+ flatMeta) {
349
+ let node = 0;
350
+ while (childrenFlat[node * 2] > 0) {
351
+ const offset = offsetsFlat[node];
352
+ const base = node * dim;
353
+ const side = selectSide(hyperplanesFlat.subarray(base, base + dim), offset, point, random);
354
+ if (side === 0) {
355
+ node = childrenFlat[node * 2];
356
+ }
357
+ else {
358
+ node = childrenFlat[node * 2 + 1];
359
+ }
360
+ }
361
+ const leafIdx = -childrenFlat[node * 2];
362
+ const { indices, leafSize } = flatMeta;
363
+ const start = leafIdx * leafSize;
364
+ const result = new Array(leafSize);
365
+ for (let i = 0; i < leafSize; i++) {
366
+ result[i] = indices[start + i];
367
+ }
368
+ return result;
369
+ }
216
370
  let node = 0;
217
371
  while (tree.children[node][0] > 0) {
218
372
  const side = selectSide(tree.hyperplanes[node], tree.offsets[node], point, random);
package/dist/umap.js CHANGED
@@ -237,6 +237,34 @@ export class UMAP {
237
237
  const nTrees = 5 + Math.floor(round(X.length ** 0.5 / 20.0));
238
238
  const nIters = Math.max(5, Math.floor(Math.round(log2(X.length))));
239
239
  this.rpForest = tree.makeForest(X, nNeighbors, nTrees, this.random, this.useWasmTree);
240
+ if (this.useWasmNNDescent && wasmBridge.isWasmAvailable()) {
241
+ const nSamples = X.length;
242
+ const dim = X[0].length;
243
+ const flatData = new Float64Array(nSamples * dim);
244
+ for (let i = 0; i < nSamples; i++) {
245
+ for (let j = 0; j < dim; j++) {
246
+ flatData[i * dim + j] = X[i][j];
247
+ }
248
+ }
249
+ const { flatLeafArray, nLeaves, leafSize } = tree.makeLeafArrayFlat(this.rpForest);
250
+ const distanceMetric = distanceFn.name === 'cosine' ? 'cosine' : 'euclidean';
251
+ const seed = Math.floor(this.random() * 0xFFFFFFFF);
252
+ const result = wasmBridge.nnDescentWasmFlat(flatData, nSamples, dim, flatLeafArray, nLeaves, leafSize, nNeighbors, nIters, 50, 0.001, 0.5, true, distanceMetric, seed);
253
+ const indices = [];
254
+ const weights = [];
255
+ const offset1 = nSamples * nNeighbors;
256
+ for (let i = 0; i < nSamples; i++) {
257
+ const rowIndices = [];
258
+ const rowWeights = [];
259
+ for (let j = 0; j < nNeighbors; j++) {
260
+ rowIndices.push(result[i * nNeighbors + j]);
261
+ rowWeights.push(result[offset1 + i * nNeighbors + j]);
262
+ }
263
+ indices.push(rowIndices);
264
+ weights.push(rowWeights);
265
+ }
266
+ return { knnIndices: indices, knnDistances: weights };
267
+ }
240
268
  const leafArray = tree.makeLeafArray(this.rpForest);
241
269
  const { indices, weights } = metricNNDescent(X, leafArray, nNeighbors, nIters);
242
270
  return { knnIndices: indices, knnDistances: weights };
@@ -255,10 +283,7 @@ export class UMAP {
255
283
  const b = wasmBridge.sparseMultiplyScalarWasm(a, setOpMixRatio);
256
284
  const c = wasmBridge.sparseMultiplyScalarWasm(prodMatrix, 1.0 - setOpMixRatio);
257
285
  const resultWasm = wasmBridge.sparseAddWasm(b, c);
258
- const entries = wasmBridge.wasmSparseMatrixGetAll(resultWasm);
259
- const jsRows = entries.map(e => e.row);
260
- const jsCols = entries.map(e => e.col);
261
- const jsVals = entries.map(e => e.value);
286
+ const { rows: jsRows, cols: jsCols, values: jsVals } = wasmBridge.wasmSparseMatrixGetAllTyped(resultWasm);
262
287
  return new matrix.SparseMatrix(jsRows, jsCols, jsVals, size);
263
288
  }
264
289
  if (this.useWasmMatrix && wasmBridge.isWasmAvailable()) {
@@ -270,10 +295,7 @@ export class UMAP {
270
295
  const b = wasmBridge.sparseMultiplyScalarWasm(a, setOpMixRatio);
271
296
  const c = wasmBridge.sparseMultiplyScalarWasm(prodMatrix, 1.0 - setOpMixRatio);
272
297
  const resultWasm = wasmBridge.sparseAddWasm(b, c);
273
- const entries = wasmBridge.wasmSparseMatrixGetAll(resultWasm);
274
- const jsRows = entries.map(e => e.row);
275
- const jsCols = entries.map(e => e.col);
276
- const jsVals = entries.map(e => e.value);
298
+ const { rows: jsRows, cols: jsCols, values: jsVals } = wasmBridge.wasmSparseMatrixGetAllTyped(resultWasm);
277
299
  return new matrix.SparseMatrix(jsRows, jsCols, jsVals, size);
278
300
  }
279
301
  const sparseMatrix = new matrix.SparseMatrix(rows, cols, vals, size);
@@ -12,6 +12,7 @@ export interface WasmFlatTree {
12
12
  free(): void;
13
13
  }
14
14
  export declare function buildRpTreeWasm(data: number[][], nSamples: number, dim: number, leafSize: number, seed: number): WasmFlatTree;
15
+ export declare function buildRpTreeWasmFlat(flatData: Float64Array, nSamples: number, dim: number, leafSize: number, seed: number): WasmFlatTree;
15
16
  export declare function searchFlatTreeWasm(tree: WasmFlatTree, point: number[], seed: number): number[];
16
17
  export declare function wasmTreeToJs(wasmTree: WasmFlatTree): {
17
18
  hyperplanes: number[][];
@@ -55,7 +56,13 @@ export declare function wasmSparseMatrixGetAll(matrix: WasmSparseMatrix): {
55
56
  row: number;
56
57
  col: number;
57
58
  }[];
58
- export declare function nnDescentWasm(data: number[][], leafArray: number[][], nNeighbors: number, nIters?: number, maxCandidates?: number, delta?: number, rho?: number, rpTreeInit?: boolean, distanceMetric?: string, seed?: number): number[][][];
59
+ export declare function wasmSparseMatrixGetAllTyped(matrix: WasmSparseMatrix): {
60
+ rows: Int32Array;
61
+ cols: Int32Array;
62
+ values: Float64Array;
63
+ };
64
+ export declare function nnDescentWasm(data: number[][], leafArray: ArrayLike<number>[], nNeighbors: number, nIters?: number, maxCandidates?: number, delta?: number, rho?: number, rpTreeInit?: boolean, distanceMetric?: string, seed?: number): number[][][];
65
+ export declare function nnDescentWasmFlat(flatData: Float64Array, nSamples: number, dim: number, flatLeafArray: Int32Array, nLeaves: number, leafSize: number, nNeighbors: number, nIters?: number, maxCandidates?: number, delta?: number, rho?: number, rpTreeInit?: boolean, distanceMetric?: string, seed?: number): number[];
59
66
  export interface WasmOptimizerState {
60
67
  head_embedding: Float64Array;
61
68
  current_epoch: number;
@@ -64,6 +64,11 @@ export function buildRpTreeWasm(data, nSamples, dim, leafSize, seed) {
64
64
  flatData[i * dim + j] = data[i][j];
65
65
  }
66
66
  }
67
+ return buildRpTreeWasmFlat(flatData, nSamples, dim, leafSize, seed);
68
+ }
69
+ export function buildRpTreeWasmFlat(flatData, nSamples, dim, leafSize, seed) {
70
+ if (!wasmModule)
71
+ throw new Error('WASM module not initialized');
67
72
  return wasmModule.build_rp_tree(flatData, nSamples, dim, leafSize, BigInt(seed));
68
73
  }
69
74
  export function searchFlatTreeWasm(tree, point, seed) {
@@ -179,17 +184,22 @@ export function sparseGetCSRWasm(matrix) {
179
184
  return { indices, values, indptr };
180
185
  }
181
186
  export function wasmSparseMatrixToArray(matrix) {
182
- const flat = Array.from(matrix.to_array());
187
+ const flat = matrix.to_array();
183
188
  const nRows = matrix.n_rows;
184
189
  const nCols = matrix.n_cols;
185
190
  const result = [];
186
191
  for (let i = 0; i < nRows; i++) {
187
- result.push(flat.slice(i * nCols, (i + 1) * nCols));
192
+ const row = new Array(nCols);
193
+ const start = i * nCols;
194
+ for (let j = 0; j < nCols; j++) {
195
+ row[j] = flat[start + j];
196
+ }
197
+ result.push(row);
188
198
  }
189
199
  return result;
190
200
  }
191
201
  export function wasmSparseMatrixGetAll(matrix) {
192
- const flat = Array.from(matrix.get_all_ordered());
202
+ const flat = matrix.get_all_ordered();
193
203
  const entries = [];
194
204
  for (let i = 0; i < flat.length; i += 3) {
195
205
  entries.push({
@@ -200,6 +210,21 @@ export function wasmSparseMatrixGetAll(matrix) {
200
210
  }
201
211
  return entries;
202
212
  }
213
+ export function wasmSparseMatrixGetAllTyped(matrix) {
214
+ const flat = matrix.get_all_ordered();
215
+ const count = Math.floor(flat.length / 3);
216
+ const rows = new Int32Array(count);
217
+ const cols = new Int32Array(count);
218
+ const values = new Float64Array(count);
219
+ let out = 0;
220
+ for (let i = 0; i < flat.length; i += 3) {
221
+ rows[out] = flat[i];
222
+ cols[out] = flat[i + 1];
223
+ values[out] = flat[i + 2];
224
+ out += 1;
225
+ }
226
+ return { rows, cols, values };
227
+ }
203
228
  export function nnDescentWasm(data, leafArray, nNeighbors, nIters = 10, maxCandidates = 50, delta = 0.001, rho = 0.5, rpTreeInit = true, distanceMetric = 'euclidean', seed = 42) {
204
229
  if (!wasmModule)
205
230
  throw new Error('WASM module not initialized');
@@ -219,7 +244,7 @@ export function nnDescentWasm(data, leafArray, nNeighbors, nIters = 10, maxCandi
219
244
  flatLeafArray[i * leafSize + j] = leafArray[i][j];
220
245
  }
221
246
  }
222
- const result = wasmModule.nn_descent(flatData, nSamples, dim, flatLeafArray, nLeaves, leafSize, nNeighbors, nIters, maxCandidates, delta, rho, rpTreeInit, distanceMetric, BigInt(seed));
247
+ const result = nnDescentWasmFlat(flatData, nSamples, dim, flatLeafArray, nLeaves, leafSize, nNeighbors, nIters, maxCandidates, delta, rho, rpTreeInit, distanceMetric, seed);
223
248
  const indices = [];
224
249
  const distances = [];
225
250
  const flags = [];
@@ -240,6 +265,11 @@ export function nnDescentWasm(data, leafArray, nNeighbors, nIters = 10, maxCandi
240
265
  }
241
266
  return [indices, distances, flags];
242
267
  }
268
+ export function nnDescentWasmFlat(flatData, nSamples, dim, flatLeafArray, nLeaves, leafSize, nNeighbors, nIters = 10, maxCandidates = 50, delta = 0.001, rho = 0.5, rpTreeInit = true, distanceMetric = 'euclidean', seed = 42) {
269
+ if (!wasmModule)
270
+ throw new Error('WASM module not initialized');
271
+ return wasmModule.nn_descent(flatData, nSamples, dim, flatLeafArray, nLeaves, leafSize, nNeighbors, nIters, maxCandidates, delta, rho, rpTreeInit, distanceMetric, BigInt(seed));
272
+ }
243
273
  export function createOptimizerState(head, tail, headEmbedding, tailEmbedding, epochsPerSample, epochsPerNegativeSample, moveOther, initialAlpha, gamma, a, b, dim, nEpochs, nVertices) {
244
274
  if (!wasmModule)
245
275
  throw new Error('WASM module not initialized');