umap-gpu 0.1.0 → 0.2.8

package/README.md CHANGED
@@ -1,87 +1,104 @@
1
1
  # umap-gpu
2
2
 
3
- UMAP dimensionality reduction with HNSW k-nearest-neighbor search and WebGPU-accelerated SGD optimization, with a transparent CPU fallback.
3
+ UMAP dimensionality reduction with WebGPU-accelerated SGD and HNSW approximate nearest neighbors.
4
4
 
5
- ## What it does
5
+ Embed millions of high-dimensional vectors into 2D in seconds — not minutes.
6
6
 
7
- Takes a set of high-dimensional vectors and returns a low-dimensional embedding (default: 2D) suitable for visualization or downstream tasks.
7
+ ## Why GPU?
8
8
 
9
- The pipeline runs in three stages:
9
+ The bottleneck in UMAP is the SGD optimization loop: thousands of epochs, millions of edge updates per epoch. On CPU this is sequential. On GPU, all edges run in parallel across thousands of shader cores — expect a significant speedup on large datasets, scaling with both the number of points and the number of epochs.
10
10
 
11
- 1. **k-NN** approximate nearest neighbors via [hnswlib-wasm](https://github.com/yoshoku/hnswlib-wasm) (O(n log n))
12
- 2. **Fuzzy simplicial set** builds a weighted graph from the k-NN graph using smooth distances
13
- 3. **SGD** — optimizes the embedding using attraction/repulsion forces:
14
- - **WebGPU** compute shader when available (Chrome 113+, Edge 113+)
15
- - **CPU** fallback otherwise — identical output, just slower
11
+ The k-NN stage uses [hnswlib-wasm](https://github.com/yoshoku/hnswlib-wasm) (O(n log n)) so it stays fast regardless.
12
+ A transparent CPU fallback mirrors the GPU shader logic wherever WebGPU isn't available.
16
13
 
17
14
  ## Install
18
15
 
19
16
  ```bash
17
+ # npm
20
18
  npm install umap-gpu
21
- ```
22
19
 
23
- > Requires a browser or runtime with WebGPU support for GPU acceleration. The CPU fallback works anywhere.
20
+ # Bun
21
+ bun add umap-gpu
22
+
23
+ # pnpm
24
+ pnpm add umap-gpu
25
+ ```
24
26
 
25
- ## Usage
27
+ ## Quick start
26
28
 
27
29
  ```ts
28
30
  import { fit } from 'umap-gpu';
29
31
 
30
32
  const vectors = [
31
- [1.0, 0.0, 0.3],
32
- [0.9, 0.1, 0.4],
33
- [0.0, 1.0, 0.8],
33
+ [0.1, 0.4, 0.9, ...], // high-dimensional points
34
+ [0.2, 0.3, 0.8, ...],
34
35
  // ...
35
36
  ];
36
37
 
37
38
  const embedding = await fit(vectors);
38
- // Float32Array of length n * nComponents (default: n * 2)
39
- // embedding[i*2], embedding[i*2 + 1] → 2D coordinates of point i
39
+ // Float32Array: embedding[i*2] and embedding[i*2+1] are the 2D coords of point i
40
+ ```
41
+
42
+ ## Train once, project many times
43
+
44
+ Use the `UMAP` class to embed a training set and later project new points into the same space without retraining.
45
+
46
+ ```ts
47
+ import { UMAP } from 'umap-gpu';
48
+
49
+ const umap = new UMAP({ nNeighbors: 15, minDist: 0.1 });
50
+
51
+ // Train
52
+ await umap.fit(trainVectors);
53
+ console.log(umap.embedding); // Float32Array [nTrain × 2]
54
+
55
+ // Project new points (training embedding stays fixed)
56
+ const projected = await umap.transform(newVectors);
57
+ // Float32Array [nNew × 2]
40
58
  ```
41
59
 
42
- ### Options
60
+ ## Options
43
61
 
44
62
  ```ts
45
- const embedding = await fit(vectors, {
46
- nComponents: 2, // output dimensions (default: 2)
47
- nNeighbors: 15, // k-NN graph degree (default: 15)
48
- nEpochs: 500, // SGD iterations (default: 500 for <10k points, 200 otherwise)
49
- minDist: 0.1, // minimum distance between points in the embedding (default: 0.1)
50
- spread: 1.0, // scale of the embedding (default: 1.0)
63
+ const umap = new UMAP({
64
+ nComponents: 2, // output dimensions (default: 2)
65
+ nNeighbors: 15, // k-NN graph degree (default: 15)
66
+ nEpochs: 500, // SGD iterations (default: auto; 500 for up to 10k points, 200 otherwise)
67
+ minDist: 0.1, // min distance in embedding (default: 0.1)
68
+ spread: 1.0, // scale of the embedding (default: 1.0)
51
69
  hnsw: {
52
- M: 16, // HNSW graph connectivity (default: 16)
53
- efConstruction: 200, // build-time search width (default: 200)
54
- efSearch: 50, // query-time search width (default: 50)
70
+ M: 16, // graph connectivity (default: 16)
71
+ efConstruction: 200, // build-time search width (default: 200)
72
+ efSearch: 50, // query-time search width (default: 50)
55
73
  },
56
74
  });
75
+
76
+ // Same options work with the functional API
77
+ const embedding = await fit(vectors, { nNeighbors: 15, minDist: 0.05 });
57
78
  ```
58
79
 
59
- ### Checking GPU availability
80
+ ## Check GPU availability
60
81
 
61
82
  ```ts
62
83
  import { isWebGPUAvailable } from 'umap-gpu';
63
84
 
64
- if (isWebGPUAvailable()) {
65
- console.log('Will use WebGPU-accelerated SGD');
66
- } else {
67
- console.log('Will fall back to CPU SGD');
68
- }
69
- ```
70
-
71
- ## Build
72
-
73
- ```bash
74
- npm run build # compiles TypeScript to dist/
75
- npm test # runs the unit test suite (Vitest)
85
+ console.log(isWebGPUAvailable()); // true → GPU path, false → CPU fallback
76
86
  ```
77
87
 
78
88
  ## Browser support
79
89
 
80
- | Feature | Requirement |
90
+ | Feature | Supported in |
81
91
  |---------|-------------|
82
92
  | WebGPU SGD | Chrome 113+, Edge 113+, Safari 18+ |
83
- | CPU fallback | Any modern browser / Node.js |
84
- | HNSW (WASM) | Any environment with WebAssembly support |
93
+ | CPU fallback | Any modern browser / Node.js / Bun |
94
+ | HNSW (WASM) | Any environment with WebAssembly |
95
+
96
+ ## Development
97
+
98
+ ```bash
99
+ npm test # Vitest unit tests
100
+ npm run build # TypeScript → dist/
101
+ ```
85
102
 
86
103
  ## License
87
104
 
@@ -9,4 +9,21 @@ export interface CPUSgdParams {
9
9
  * CPU fallback SGD optimizer for environments without WebGPU.
10
10
  * Mirrors the GPU shader logic: per-edge attraction + negative-sample repulsion.
11
11
  */
12
- export declare function cpuSgd(embedding: Float32Array, graph: FuzzyGraph, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: CPUSgdParams): Float32Array;
12
+ export declare function cpuSgd(embedding: Float32Array, graph: FuzzyGraph, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: CPUSgdParams, onProgress?: (epoch: number, nEpochs: number) => void): Float32Array;
13
+ /**
14
+ * CPU SGD for UMAP.transform(): optimizes only the new-point embeddings.
15
+ * The training embedding is read-only; attraction pulls new points toward
16
+ * their training neighbors, and repulsion pushes them away from random
17
+ * training points.
18
+ *
19
+ * @param embeddingNew - New-point embeddings to optimize [nNew × nComponents]
20
+ * @param embeddingTrain - Fixed training embeddings [nTrain × nComponents]
21
+ * @param graph - Bipartite graph: rows=new-point indices, cols=training-point indices
22
+ * @param epochsPerSample - Per-edge epoch sampling schedule
23
+ * @param nNew - Number of new points
24
+ * @param nTrain - Number of training points
25
+ * @param nComponents - Embedding dimensionality
26
+ * @param nEpochs - Number of optimization epochs
27
+ * @param params - UMAP curve parameters
28
+ */
29
+ export declare function cpuSgdTransform(embeddingNew: Float32Array, embeddingTrain: Float32Array, graph: FuzzyGraph, epochsPerSample: Float32Array, nNew: number, nTrain: number, nComponents: number, nEpochs: number, params: CPUSgdParams, onProgress?: (epoch: number, nEpochs: number) => void): Float32Array;
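The doc comment for `cpuSgdTransform` describes an update in which only the new-point embedding moves while the training embedding stays read-only. A minimal sketch of one attraction step under that rule is shown here. The coefficient and the clipping to [-4, 4] follow the bundled `index.js`; the names `attractNewPoint`, `clip`, and `CurveParams` are hypothetical and not part of the package.

```ts
// Hypothetical helper names; only the update rule is taken from the bundled code.
type CurveParams = { a: number; b: number };

// Clamp a gradient component to [-4, 4], as the bundled optimizer does.
function clip(x: number): number {
  return Math.max(-4, Math.min(4, x));
}

// Pull one new point toward one fixed training point.
// Writes only to embeddingNew; embeddingTrain is never modified.
function attractNewPoint(
  embeddingNew: Float32Array,
  embeddingTrain: Float32Array,
  newIdx: number,
  trainIdx: number,
  nComponents: number,
  alpha: number,
  params: CurveParams
): void {
  const { a, b } = params;
  let distSq = 0;
  for (let d = 0; d < nComponents; d++) {
    const diff =
      embeddingNew[newIdx * nComponents + d] -
      embeddingTrain[trainIdx * nComponents + d];
    distSq += diff * diff;
  }
  const powB = Math.pow(distSq, b);
  // Guarded form: a zero distance yields a zero coefficient.
  const coeff = (-2 * a * b * (distSq > 0 ? powB / distSq : 0)) / (a * powB + 1);
  for (let d = 0; d < nComponents; d++) {
    const diff =
      embeddingNew[newIdx * nComponents + d] -
      embeddingTrain[trainIdx * nComponents + d];
    embeddingNew[newIdx * nComponents + d] += alpha * clip(coeff * diff);
  }
}

const train = new Float32Array([0, 0]);
const fresh = new Float32Array([3, 4]);
// a and b are the constants the bundle uses for minDist = 0.1, spread = 1.
attractNewPoint(fresh, train, 0, 0, 2, 1.0, { a: 1.9292, b: 0.7915 });
```

After the call, `fresh` sits slightly closer to the origin and `train` is unchanged. In a full transform pass the same step would run once per bipartite edge per scheduled epoch, followed by the repulsion samples.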
@@ -10,3 +10,16 @@ export interface FuzzyGraph {
10
10
  * (sigmas, rhos) and symmetrizes with the fuzzy set union operation.
11
11
  */
12
12
  export declare function computeFuzzySimplicialSet(knnIndices: number[][], knnDistances: number[][], nNeighbors: number, setOpMixRatio?: number): FuzzyGraph;
13
+ /**
14
+ * Compute the fuzzy weight graph between new (query) points and training points.
15
+ * Used by UMAP.transform() to project unseen data into an existing embedding.
16
+ *
17
+ * Unlike computeFuzzySimplicialSet, this produces a bipartite graph
18
+ * (new points → training points) with no symmetrization.
19
+ *
20
+ * @param knnIndices - For each new point, the indices of its training neighbors
21
+ * @param knnDistances - For each new point, the distances to those neighbors
22
+ * @param nNeighbors - Number of neighbors used
23
+ * @returns FuzzyGraph where rows are new-point indices, cols are training-point indices
24
+ */
25
+ export declare function computeTransformFuzzyWeights(knnIndices: number[][], knnDistances: number[][], nNeighbors: number): FuzzyGraph;
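The weights in this bipartite graph come from a per-point bandwidth search followed by an exponential membership function. The sketch below reproduces that logic for a single row of neighbor distances, based on the bundled `index.js` (64 bisection steps, tolerance 1e-5, target `log2(nNeighbors)`). The function names `findSigmaRho` and `membership` are hypothetical.

```ts
// Hypothetical names; the numeric procedure follows the bundled code.
// rho is the first strictly positive distance; sigma is found by bisection so
// that the summed memberships (skipping index 0) reach log2(nNeighbors).
function findSigmaRho(dists: number[], nNeighbors: number): { sigma: number; rho: number } {
  const rho = dists.find((d) => d > 0) ?? 0;
  const target = Math.log2(nNeighbors);
  let lo = 0;
  let hi = Infinity;
  let sigma = 1;
  for (let iter = 0; iter < 64; iter++) {
    let sum = 0;
    for (let j = 1; j < dists.length; j++) {
      sum += Math.exp(-Math.max(0, dists[j] - rho) / sigma);
    }
    if (Math.abs(sum - target) < 1e-5) break;
    if (sum > target) {
      hi = sigma;
      sigma = (lo + hi) / 2;
    } else {
      lo = sigma;
      sigma = hi === Infinity ? sigma * 2 : (lo + hi) / 2;
    }
  }
  return { sigma, rho };
}

// Edge weight: 1 inside the rho radius, exponential decay outside it.
function membership(dist: number, rho: number, sigma: number): number {
  return dist <= rho ? 1 : Math.exp(-(dist - rho) / sigma);
}

const rowDists = [0.5, 1.0, 1.5, 2.0];
const { sigma, rho } = findSigmaRho(rowDists, 4);
const weights = rowDists.map((d) => membership(d, rho, sigma));
```

For this row the nearest neighbor gets weight 1 and the remaining weights decay with distance. Because the transform graph is not symmetrized, these values would be used directly as edge weights from a new point to its training neighbors.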
package/dist/gpu/sgd.d.ts CHANGED
@@ -25,6 +25,6 @@ export declare class GPUSgd {
25
25
  * @param params - UMAP curve parameters and repulsion settings
26
26
  * @returns Optimized embedding as Float32Array
27
27
  */
28
- optimize(embedding: Float32Array, head: Uint32Array, tail: Uint32Array, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: SGDParams): Promise<Float32Array>;
28
+ optimize(embedding: Float32Array, head: Uint32Array, tail: Uint32Array, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: SGDParams, onProgress?: (epoch: number, nEpochs: number) => void): Promise<Float32Array>;
29
29
  private makeBuffer;
30
30
  }
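The new `onProgress` parameter has the shape `(epoch, nEpochs) => void`. In the bundled code the GPU loop appears to invoke it on every tenth epoch and the CPU loop on every epoch, so a consumer may want to normalize both into percentage updates. The sketch below shows one way to do that; `makePercentReporter` is a hypothetical helper, and the local `ProgressCallback` alias only restates the signature from the declaration above.

```ts
// Local alias restating the callback signature from the declarations.
type ProgressCallback = (epoch: number, nEpochs: number) => void;

// Hypothetical helper: forwards a whole-number percentage, and only when it changes.
function makePercentReporter(report: (percent: number) => void): ProgressCallback {
  let last = -1;
  return (epoch, nEpochs) => {
    const percent = Math.floor((epoch * 100) / nEpochs);
    if (percent !== last) {
      last = percent;
      report(percent);
    }
  };
}

const seen: number[] = [];
const onProgress = makePercentReporter((p) => seen.push(p));

// Stand-in for an optimizer that reports every tenth epoch.
const totalEpochs = 200;
for (let epoch = 0; epoch < totalEpochs; epoch += 10) {
  onProgress(epoch, totalEpochs);
}
```

Here `seen` ends up with 20 entries running from 0 to 95 in steps of 5. The callback never receives `epoch === nEpochs` in this loop, so a caller that wants a final 100% update would emit it after the optimizer returns.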
@@ -7,9 +7,24 @@ export interface HNSWOptions {
7
7
  efConstruction?: number;
8
8
  efSearch?: number;
9
9
  }
10
+ /**
11
+ * A built HNSW index that can be queried to find nearest neighbors in the
12
+ * training data for new (unseen) points — used by UMAP.transform().
13
+ */
14
+ export interface HNSWSearchableIndex {
15
+ searchKnn(queryVectors: number[][], nNeighbors: number): KNNResult;
16
+ }
10
17
  /**
11
18
  * Compute k-nearest neighbors using HNSW (Hierarchical Navigable Small World)
12
19
  * via hnswlib-wasm, replacing the O(n^2) brute-force search in umap-js with
13
20
  * an O(n log n) approximate nearest neighbor search.
14
21
  */
15
22
  export declare function computeKNN(vectors: number[][], nNeighbors: number, opts?: HNSWOptions): Promise<KNNResult>;
23
+ /**
24
+ * Like computeKNN, but also returns the built HNSW index so it can be reused
25
+ * later to project new points (used by UMAP.transform()).
26
+ */
27
+ export declare function computeKNNWithIndex(vectors: number[][], nNeighbors: number, opts?: HNSWOptions): Promise<{
28
+ knn: KNNResult;
29
+ index: HNSWSearchableIndex;
30
+ }>;
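`HNSWSearchableIndex` is a one-method interface, so any object with a matching `searchKnn` can stand in for the HNSW index, for example in tests. The sketch below is an exact brute-force stand-in, not the package's implementation. The `KNNResult` shape is inferred from the bundled return value (`{ indices, distances }`), and plain Euclidean distance is an assumption of this sketch; the metric reported by hnswlib-wasm may differ.

```ts
// Shapes restated locally; KNNResult is inferred from the bundled code.
interface KNNResult {
  indices: number[][];
  distances: number[][];
}

interface HNSWSearchableIndex {
  searchKnn(queryVectors: number[][], nNeighbors: number): KNNResult;
}

// Hypothetical exact-search stand-in: O(nTrain) work per query.
class BruteForceIndex implements HNSWSearchableIndex {
  private readonly trainVectors: number[][];

  constructor(trainVectors: number[][]) {
    this.trainVectors = trainVectors;
  }

  searchKnn(queryVectors: number[][], nNeighbors: number): KNNResult {
    const indices: number[][] = [];
    const distances: number[][] = [];
    for (const q of queryVectors) {
      // Rank every training vector by Euclidean distance to the query.
      const ranked = this.trainVectors
        .map((v, idx) => ({ idx, dist: Math.hypot(...v.map((x, d) => x - q[d])) }))
        .sort((p, r) => p.dist - r.dist)
        .slice(0, nNeighbors);
      indices.push(ranked.map((n) => n.idx));
      distances.push(ranked.map((n) => n.dist));
    }
    return { indices, distances };
  }
}

const bruteIndex: HNSWSearchableIndex = new BruteForceIndex([[0, 0], [1, 0], [5, 5]]);
const neighbors = bruteIndex.searchKnn([[0.9, 0]], 2);
```

The query `[0.9, 0]` returns training indices `[1, 0]`, nearest first. Unlike the self-query in `computeKNN`, a transform query does not need to drop the point itself, since new points are not in the index.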
package/dist/index.d.ts CHANGED
@@ -1,5 +1,5 @@
1
- export { fit } from './umap';
2
- export type { UMAPOptions } from './umap';
3
- export type { KNNResult, HNSWOptions } from './hnsw-knn';
1
+ export { fit, UMAP } from './umap';
2
+ export type { UMAPOptions, ProgressCallback } from './umap';
3
+ export type { KNNResult, HNSWOptions, HNSWSearchableIndex } from './hnsw-knn';
4
4
  export type { FuzzyGraph } from './fuzzy-set';
5
5
  export { isWebGPUAvailable } from './gpu/device';
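One change in the bundled `index.js` that follows is in the graph symmetrization step: edge sums were keyed by the string `` `${row},${col}` `` and are now keyed by the integer `row * nVertices + col`, decoded with `Math.floor` and `%`. The sketch below isolates that encoding. `accumulateSymmetric` and `decodeKey` are hypothetical names, and the fuzzy-union mixing applied afterwards is left out.

```ts
// Hypothetical names; the key encoding mirrors the bundled code.
// Sums each edge value under both (row, col) and (col, row).
function accumulateSymmetric(
  rows: number[],
  cols: number[],
  vals: number[],
  nVertices: number
): Map<number, number> {
  const sums = new Map<number, number>();
  const add = (r: number, c: number, v: number): void => {
    const key = r * nVertices + c; // unique per (r, c) while r, c < nVertices
    sums.set(key, (sums.get(key) ?? 0) + v);
  };
  for (let e = 0; e < rows.length; e++) {
    add(rows[e], cols[e], vals[e]);
    add(cols[e], rows[e], vals[e]);
  }
  return sums;
}

// Recover (row, col) from an integer key.
function decodeKey(key: number, nVertices: number): [number, number] {
  return [Math.floor(key / nVertices), key % nVertices];
}

// Two directed edges between vertices 0 and 1 in a 3-vertex graph.
const sums = accumulateSymmetric([0, 1], [1, 0], [0.25, 0.5], 3);
const decoded = decodeKey(3, 3);
```

Both directions end up holding 0.75, under keys 1 and 3. Integer keys avoid building and parsing a string per edge, and they stay exact as long as `nVertices * nVertices` is below `Number.MAX_SAFE_INTEGER`.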
package/dist/index.js CHANGED
@@ -1,65 +1,98 @@
1
- var N = Object.defineProperty;
2
- var q = (e, n, a) => n in e ? N(e, n, { enumerable: !0, configurable: !0, writable: !0, value: a }) : e[n] = a;
3
- var O = (e, n, a) => q(e, typeof n != "symbol" ? n + "" : n, a);
4
- import { loadHnswlib as z } from "hnswlib-wasm";
5
- async function T(e, n, a = {}) {
6
- const { M: d = 16, efConstruction: t = 200, efSearch: h = 50 } = a, p = await z(), u = e[0].length, c = e.length, o = new p.HierarchicalNSW("l2", u, "");
7
- o.initIndex(c, d, t, 200), o.setEfSearch(Math.max(h, n)), o.addItems(e, !1);
8
- const r = [], s = [];
9
- for (let f = 0; f < c; f++) {
10
- const l = o.searchKnn(e[f], n + 1, void 0), m = l.neighbors.map((_, w) => ({ idx: _, dist: l.distances[w] })).filter(({ idx: _ }) => _ !== f).slice(0, n);
11
- r.push(m.map(({ idx: _ }) => _)), s.push(m.map(({ dist: _ }) => _));
1
+ var H = Object.defineProperty;
2
+ var V = (t, e, i) => e in t ? H(t, e, { enumerable: !0, configurable: !0, writable: !0, value: i }) : t[e] = i;
3
+ var E = (t, e, i) => V(t, typeof e != "symbol" ? e + "" : e, i);
4
+ import { loadHnswlib as C } from "hnswlib-wasm";
5
+ async function Y(t, e, i = {}) {
6
+ const { M: a = 16, efConstruction: s = 200, efSearch: p = 50 } = i, f = await C(), h = t[0].length, c = t.length, o = new f.HierarchicalNSW("l2", h, "");
7
+ o.initIndex(c, a, s, 200), o.setEfSearch(Math.max(p, e)), o.addItems(t, !1);
8
+ const n = [], r = [];
9
+ for (let l = 0; l < c; l++) {
10
+ const d = o.searchKnn(t[l], e + 1, void 0), u = d.neighbors.map((g, _) => ({ idx: g, dist: d.distances[_] })).filter(({ idx: g }) => g !== l).slice(0, e);
11
+ n.push(u.map(({ idx: g }) => g)), r.push(u.map(({ dist: g }) => g));
12
12
  }
13
- return { indices: r, distances: s };
13
+ return { indices: n, distances: r };
14
14
  }
15
- function L(e, n, a, d = 1) {
16
- const t = e.length, { sigmas: h, rhos: p } = C(n, a), u = [], c = [], o = [];
17
- for (let s = 0; s < t; s++)
18
- for (let f = 0; f < e[s].length; f++) {
19
- const l = n[s][f], m = l <= p[s] ? 1 : Math.exp(-((l - p[s]) / h[s]));
20
- u.push(s), c.push(e[s][f]), o.push(m);
15
+ async function J(t, e, i = {}) {
16
+ const { M: a = 16, efConstruction: s = 200, efSearch: p = 50 } = i, f = await C(), h = t[0].length, c = t.length, o = new f.HierarchicalNSW("l2", h, "");
17
+ o.initIndex(c, a, s, 200), o.setEfSearch(Math.max(p, e)), o.addItems(t, !1);
18
+ const n = [], r = [];
19
+ for (let d = 0; d < c; d++) {
20
+ const u = o.searchKnn(t[d], e + 1, void 0), g = u.neighbors.map((_, w) => ({ idx: _, dist: u.distances[w] })).filter(({ idx: _ }) => _ !== d).slice(0, e);
21
+ n.push(g.map(({ idx: _ }) => _)), r.push(g.map(({ dist: _ }) => _));
22
+ }
23
+ return { knn: { indices: n, distances: r }, index: {
24
+ searchKnn(d, u) {
25
+ const g = [], _ = [];
26
+ for (const w of d) {
27
+ const y = o.searchKnn(w, u, void 0), b = y.neighbors.map((M, x) => ({ idx: M, dist: y.distances[x] })).sort((M, x) => M.dist - x.dist).slice(0, u);
28
+ g.push(b.map(({ idx: M }) => M)), _.push(b.map(({ dist: M }) => M));
29
+ }
30
+ return { indices: g, distances: _ };
31
+ }
32
+ } };
33
+ }
34
+ function D(t, e, i, a = 1) {
35
+ const s = t.length, { sigmas: p, rhos: f } = L(e, i), h = [], c = [], o = [];
36
+ for (let r = 0; r < s; r++)
37
+ for (let l = 0; l < t[r].length; l++) {
38
+ const d = e[r][l], u = d <= f[r] ? 1 : Math.exp(-((d - f[r]) / p[r]));
39
+ h.push(r), c.push(t[r][l]), o.push(u);
40
+ }
41
+ return { ...X(h, c, o, s, a), nVertices: s };
42
+ }
43
+ function Q(t, e, i) {
44
+ const a = t.length, { sigmas: s, rhos: p } = L(e, i), f = [], h = [], c = [];
45
+ for (let o = 0; o < a; o++)
46
+ for (let n = 0; n < t[o].length; n++) {
47
+ const r = e[o][n], l = r <= p[o] ? 1 : Math.exp(-((r - p[o]) / s[o]));
48
+ f.push(o), h.push(t[o][n]), c.push(l);
21
49
  }
22
- return { ...D(u, c, o, t, d), nVertices: t };
50
+ return {
51
+ rows: new Float32Array(f),
52
+ cols: new Float32Array(h),
53
+ vals: new Float32Array(c),
54
+ nVertices: a
55
+ };
23
56
  }
24
- function C(e, n) {
25
- const d = e.length, t = new Float32Array(d), h = new Float32Array(d);
26
- for (let p = 0; p < d; p++) {
27
- const u = e[p];
28
- h[p] = u.find((f) => f > 0) ?? 0;
29
- let c = 0, o = 1 / 0, r = 1;
30
- const s = Math.log2(n);
31
- for (let f = 0; f < 64; f++) {
32
- let l = 0;
33
- for (let m = 1; m < u.length; m++)
34
- l += Math.exp(-Math.max(0, u[m] - h[p]) / r);
35
- if (Math.abs(l - s) < 1e-5) break;
36
- l > s ? (o = r, r = (c + o) / 2) : (c = r, r = o === 1 / 0 ? r * 2 : (c + o) / 2);
57
+ function L(t, e) {
58
+ const a = t.length, s = new Float32Array(a), p = new Float32Array(a);
59
+ for (let f = 0; f < a; f++) {
60
+ const h = t[f];
61
+ p[f] = h.find((l) => l > 0) ?? 0;
62
+ let c = 0, o = 1 / 0, n = 1;
63
+ const r = Math.log2(e);
64
+ for (let l = 0; l < 64; l++) {
65
+ let d = 0;
66
+ for (let u = 1; u < h.length; u++)
67
+ d += Math.exp(-Math.max(0, h[u] - p[f]) / n);
68
+ if (Math.abs(d - r) < 1e-5) break;
69
+ d > r ? (o = n, n = (c + o) / 2) : (c = n, n = o === 1 / 0 ? n * 2 : (c + o) / 2);
37
70
  }
38
- t[p] = r;
71
+ s[f] = n;
39
72
  }
40
- return { sigmas: t, rhos: h };
73
+ return { sigmas: s, rhos: p };
41
74
  }
42
- function D(e, n, a, d, t) {
43
- const h = /* @__PURE__ */ new Map(), p = (r, s, f) => {
44
- const l = `${r},${s}`, m = h.get(l) ?? 0;
45
- h.set(l, m + f);
75
+ function X(t, e, i, a, s) {
76
+ const p = /* @__PURE__ */ new Map(), f = (n, r, l) => {
77
+ const d = n * a + r;
78
+ p.set(d, (p.get(d) ?? 0) + l);
46
79
  };
47
- for (let r = 0; r < e.length; r++)
48
- p(e[r], n[r], a[r]), p(n[r], e[r], a[r]);
49
- const u = [], c = [], o = [];
50
- for (const [r, s] of h.entries()) {
51
- const [f, l] = r.split(",").map(Number);
52
- u.push(f), c.push(l), o.push(
53
- s > 1 ? t * (2 - s) + (1 - t) * (s - 1) : s
80
+ for (let n = 0; n < t.length; n++)
81
+ f(t[n], e[n], i[n]), f(e[n], t[n], i[n]);
82
+ const h = [], c = [], o = [];
83
+ for (const [n, r] of p.entries()) {
84
+ const l = Math.floor(n / a), d = n % a;
85
+ h.push(l), c.push(d), o.push(
86
+ r > 1 ? s * (2 - r) + (1 - s) * (r - 1) : r
54
87
  );
55
88
  }
56
89
  return {
57
- rows: new Float32Array(u),
90
+ rows: new Float32Array(h),
58
91
  cols: new Float32Array(c),
59
92
  vals: new Float32Array(o)
60
93
  };
61
94
  }
62
- const W = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
95
+ const Z = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
63
96
  // Applies attraction forces between connected nodes and repulsion forces
64
97
  // against negative samples.
65
98
 
@@ -162,18 +195,18 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
162
195
  epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
163
196
  }
164
197
  `;
165
- class j {
198
+ class W {
166
199
  constructor() {
167
- O(this, "device");
168
- O(this, "pipeline");
200
+ E(this, "device");
201
+ E(this, "pipeline");
169
202
  }
170
203
  async init() {
171
- const n = await navigator.gpu.requestAdapter();
172
- if (!n) throw new Error("WebGPU not supported");
173
- this.device = await n.requestDevice(), this.pipeline = this.device.createComputePipeline({
204
+ const e = await navigator.gpu.requestAdapter();
205
+ if (!e) throw new Error("WebGPU not supported");
206
+ this.device = await e.requestDevice(), this.pipeline = this.device.createComputePipeline({
174
207
  layout: "auto",
175
208
  compute: {
176
- module: this.device.createShaderModule({ code: W }),
209
+ module: this.device.createShaderModule({ code: Z }),
177
210
  entryPoint: "main"
178
211
  }
179
212
  });
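The next hunk rewrites the CPU attraction coefficient from `Math.pow(x, b - 1)` to a guarded `Math.pow(x, b) / x`. A plausible reason, sketched below, is the zero-distance case: with `b < 1` the old expression is infinite at `x = 0`, and multiplying it by a zero coordinate difference gives `NaN`, which the [-4, 4] clamp does not remove. The function names are hypothetical; the two formulas are copied from the removed and added lines.

```ts
// Coefficient as written in the removed lines.
function oldCoeff(distSq: number, a: number, b: number): number {
  return (-2 * a * b * Math.pow(distSq, b - 1)) / (a * Math.pow(distSq, b) + 1);
}

// Coefficient as written in the added lines, with the zero-distance guard.
function newCoeff(distSq: number, a: number, b: number): number {
  const powB = Math.pow(distSq, b);
  return (-2 * a * b * (distSq > 0 ? powB / distSq : 0)) / (a * powB + 1);
}

// Constants the bundle uses for minDist = 0.1, spread = 1.
const a = 1.9292;
const b = 0.7915;

// Two coincident points: squared distance 0, coordinate difference 0.
const oldStep = oldCoeff(0, a, b) * 0; // -Infinity * 0 is NaN
const newStep = newCoeff(0, a, b) * 0; // stays finite (zero)

// Away from zero the two forms agree up to rounding.
const gap = Math.abs(oldCoeff(4, a, b) - newCoeff(4, a, b));
```

A single `NaN` written into the embedding would spread to every edge touching that point in later epochs, so the guard matters whenever two points can coincide, for example with duplicate input vectors.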
@@ -191,171 +224,347 @@ class j {
191
224
  * @param params - UMAP curve parameters and repulsion settings
192
225
  * @returns Optimized embedding as Float32Array
193
226
  */
194
- async optimize(n, a, d, t, h, p, u, c) {
195
- const { device: o } = this, r = a.length, s = this.makeBuffer(
196
- n,
227
+ async optimize(e, i, a, s, p, f, h, c, o) {
228
+ const { device: n } = this, r = i.length, l = this.makeBuffer(
229
+ e,
197
230
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
198
- ), f = this.makeBuffer(a, GPUBufferUsage.STORAGE), l = this.makeBuffer(d, GPUBufferUsage.STORAGE), m = this.makeBuffer(t, GPUBufferUsage.STORAGE), _ = new Float32Array(r).fill(0), w = this.makeBuffer(_, GPUBufferUsage.STORAGE), g = new Float32Array(r);
199
- for (let i = 0; i < r; i++)
200
- g[i] = t[i] / c.negativeSampleRate;
201
- const S = this.makeBuffer(g, GPUBufferUsage.STORAGE), b = new Uint32Array(r);
202
- for (let i = 0; i < r; i++)
203
- b[i] = Math.random() * 4294967295 | 0;
204
- const v = this.makeBuffer(b, GPUBufferUsage.STORAGE), M = o.createBuffer({
231
+ ), d = this.makeBuffer(i, GPUBufferUsage.STORAGE), u = this.makeBuffer(a, GPUBufferUsage.STORAGE), g = this.makeBuffer(s, GPUBufferUsage.STORAGE), _ = new Float32Array(r).fill(0), w = this.makeBuffer(_, GPUBufferUsage.STORAGE), y = new Float32Array(r);
232
+ for (let m = 0; m < r; m++)
233
+ y[m] = s[m] / c.negativeSampleRate;
234
+ const b = this.makeBuffer(y, GPUBufferUsage.STORAGE), M = new Uint32Array(r);
235
+ for (let m = 0; m < r; m++)
236
+ M[m] = Math.random() * 4294967295 | 0;
237
+ const x = this.makeBuffer(M, GPUBufferUsage.STORAGE), B = n.createBuffer({
205
238
  size: 40,
206
239
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
207
240
  });
208
- for (let i = 0; i < u; i++) {
209
- const A = 1 - i / u, B = new ArrayBuffer(40), U = new Uint32Array(B), y = new Float32Array(B);
210
- U[0] = r, U[1] = h, U[2] = p, U[3] = i, U[4] = u, y[5] = A, y[6] = c.a, y[7] = c.b, y[8] = c.gamma, U[9] = c.negativeSampleRate, o.queue.writeBuffer(M, 0, B);
211
- const G = o.createBindGroup({
241
+ for (let m = 0; m < h; m++) {
242
+ const F = 1 - m / h, A = new ArrayBuffer(40), v = new Uint32Array(A), N = new Float32Array(A);
243
+ v[0] = r, v[1] = p, v[2] = f, v[3] = m, v[4] = h, N[5] = F, N[6] = c.a, N[7] = c.b, N[8] = c.gamma, v[9] = c.negativeSampleRate, n.queue.writeBuffer(B, 0, A);
244
+ const S = n.createBindGroup({
212
245
  layout: this.pipeline.getBindGroupLayout(0),
213
246
  entries: [
214
- { binding: 0, resource: { buffer: m } },
215
- { binding: 1, resource: { buffer: f } },
216
- { binding: 2, resource: { buffer: l } },
217
- { binding: 3, resource: { buffer: s } },
247
+ { binding: 0, resource: { buffer: g } },
248
+ { binding: 1, resource: { buffer: d } },
249
+ { binding: 2, resource: { buffer: u } },
250
+ { binding: 3, resource: { buffer: l } },
218
251
  { binding: 4, resource: { buffer: w } },
219
- { binding: 5, resource: { buffer: S } },
220
- { binding: 6, resource: { buffer: M } },
221
- { binding: 7, resource: { buffer: v } }
252
+ { binding: 5, resource: { buffer: b } },
253
+ { binding: 6, resource: { buffer: B } },
254
+ { binding: 7, resource: { buffer: x } }
222
255
  ]
223
- }), E = o.createCommandEncoder(), R = E.beginComputePass();
224
- R.setPipeline(this.pipeline), R.setBindGroup(0, G), R.dispatchWorkgroups(Math.ceil(r / 256)), R.end(), o.queue.submit([E.finish()]), i % 10 === 0 && await o.queue.onSubmittedWorkDone();
256
+ }), k = n.createCommandEncoder(), U = k.beginComputePass();
257
+ U.setPipeline(this.pipeline), U.setBindGroup(0, S), U.dispatchWorkgroups(Math.ceil(r / 256)), U.end(), n.queue.submit([k.finish()]), m % 10 === 0 && (await n.queue.onSubmittedWorkDone(), o == null || o(m, h));
225
258
  }
226
- const x = o.createBuffer({
227
- size: n.byteLength,
259
+ const G = n.createBuffer({
260
+ size: e.byteLength,
228
261
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
229
- }), P = o.createCommandEncoder();
230
- P.copyBufferToBuffer(s, 0, x, 0, n.byteLength), o.queue.submit([P.finish()]), await x.mapAsync(GPUMapMode.READ);
231
- const k = new Float32Array(x.getMappedRange().slice(0));
232
- return x.unmap(), s.destroy(), f.destroy(), l.destroy(), m.destroy(), w.destroy(), S.destroy(), v.destroy(), M.destroy(), x.destroy(), k;
262
+ }), R = n.createCommandEncoder();
263
+ R.copyBufferToBuffer(l, 0, G, 0, e.byteLength), n.queue.submit([R.finish()]), await G.mapAsync(GPUMapMode.READ);
264
+ const O = new Float32Array(G.getMappedRange().slice(0));
265
+ return G.unmap(), l.destroy(), d.destroy(), u.destroy(), g.destroy(), w.destroy(), b.destroy(), x.destroy(), B.destroy(), G.destroy(), O;
233
266
  }
234
- makeBuffer(n, a) {
235
- const d = this.device.createBuffer({
236
- size: n.byteLength,
237
- usage: a,
267
+ makeBuffer(e, i) {
268
+ const a = this.device.createBuffer({
269
+ size: e.byteLength,
270
+ usage: i,
238
271
  mappedAtCreation: !0
239
272
  });
240
- return n instanceof Float32Array ? new Float32Array(d.getMappedRange()).set(n) : new Uint32Array(d.getMappedRange()).set(n), d.unmap(), d;
273
+ return e instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(e) : new Uint32Array(a.getMappedRange()).set(e), a.unmap(), a;
241
274
  }
242
275
  }
243
- function F(e, n, a, d, t, h, p) {
244
- const { a: u, b: c, gamma: o = 1, negativeSampleRate: r = 5 } = p, s = n.rows.length, f = new Uint32Array(n.rows), l = new Uint32Array(n.cols), m = new Float32Array(s).fill(0), _ = new Float32Array(s);
245
- for (let g = 0; g < s; g++)
246
- _[g] = a[g] / r;
247
- function w(g) {
248
- return Math.max(-4, Math.min(4, g));
276
+ function P(t) {
277
+ return Math.max(-4, Math.min(4, t));
278
+ }
279
+ function q(t, e, i, a, s, p, f, h) {
280
+ const { a: c, b: o, gamma: n = 1, negativeSampleRate: r = 5 } = f, l = e.rows.length, d = new Uint32Array(e.rows), u = new Uint32Array(e.cols), g = new Float32Array(l).fill(0), _ = new Float32Array(l);
281
+ for (let w = 0; w < l; w++)
282
+ _[w] = i[w] / r;
283
+ for (let w = 0; w < p; w++) {
284
+ h == null || h(w, p);
285
+ const y = 1 - w / p;
286
+ for (let b = 0; b < l; b++) {
287
+ if (g[b] > w) continue;
288
+ const M = d[b], x = u[b];
289
+ let B = 0;
290
+ for (let m = 0; m < s; m++) {
291
+ const F = t[M * s + m] - t[x * s + m];
292
+ B += F * F;
293
+ }
294
+ const G = Math.pow(B, o), R = -2 * c * o * (B > 0 ? G / B : 0) / (c * G + 1);
295
+ for (let m = 0; m < s; m++) {
296
+ const F = t[M * s + m] - t[x * s + m], A = P(R * F);
297
+ t[M * s + m] += y * A;
298
+ }
299
+ g[b] += i[b];
300
+ const O = _[b] > 0 ? Math.floor(i[b] / _[b]) : 0;
301
+ for (let m = 0; m < O; m++) {
302
+ const F = Math.floor(Math.random() * a);
303
+ if (F === M) continue;
304
+ let A = 0;
305
+ for (let S = 0; S < s; S++) {
306
+ const k = t[M * s + S] - t[F * s + S];
307
+ A += k * k;
308
+ }
309
+ const v = Math.pow(A, o), N = 2 * n * o / ((1e-3 + A) * (c * v + 1));
310
+ for (let S = 0; S < s; S++) {
311
+ const k = t[M * s + S] - t[F * s + S], U = P(N * k);
312
+ t[M * s + S] += y * U;
313
+ }
314
+ }
315
+ _[b] += i[b] / r;
316
+ }
249
317
  }
250
- for (let g = 0; g < h; g++) {
251
- const S = 1 - g / h;
252
- for (let b = 0; b < s; b++) {
253
- if (m[b] > g) continue;
254
- const v = f[b], M = l[b];
255
- let x = 0;
256
- for (let i = 0; i < t; i++) {
257
- const A = e[v * t + i] - e[M * t + i];
258
- x += A * A;
318
+ return t;
319
+ }
320
+ function $(t, e, i, a, s, p, f, h, c, o) {
321
+ const { a: n, b: r, gamma: l = 1, negativeSampleRate: d = 5 } = c, u = i.rows.length, g = new Uint32Array(i.rows), _ = new Uint32Array(i.cols), w = new Float32Array(u).fill(0), y = new Float32Array(u);
322
+ for (let b = 0; b < u; b++)
323
+ y[b] = a[b] / d;
324
+ for (let b = 0; b < h; b++) {
325
+ const M = 1 - b / h;
326
+ for (let x = 0; x < u; x++) {
327
+ if (w[x] > b) continue;
328
+ const B = g[x], G = _[x];
329
+ let R = 0;
330
+ for (let A = 0; A < f; A++) {
331
+ const v = t[B * f + A] - e[G * f + A];
332
+ R += v * v;
259
333
  }
260
- const P = -2 * u * c * Math.pow(x, c - 1) / (u * Math.pow(x, c) + 1);
261
- for (let i = 0; i < t; i++) {
262
- const A = e[v * t + i] - e[M * t + i], B = w(P * A);
263
- e[v * t + i] += S * B;
334
+ const O = Math.pow(R, r), m = -2 * n * r * (R > 0 ? O / R : 0) / (n * O + 1);
335
+ for (let A = 0; A < f; A++) {
336
+ const v = t[B * f + A] - e[G * f + A];
337
+ t[B * f + A] += M * P(m * v);
264
338
  }
265
- m[b] += a[b];
266
- const k = _[b] > 0 ? Math.floor(a[b] / _[b]) : 0;
267
- for (let i = 0; i < k; i++) {
268
- const A = Math.floor(Math.random() * d);
269
- if (A === v) continue;
270
- let B = 0;
271
- for (let y = 0; y < t; y++) {
272
- const G = e[v * t + y] - e[A * t + y];
273
- B += G * G;
339
+ w[x] += a[x];
340
+ const F = y[x] > 0 ? Math.floor(a[x] / y[x]) : 0;
341
+ for (let A = 0; A < F; A++) {
342
+ const v = Math.floor(Math.random() * p);
343
+ if (v === G) continue;
344
+ let N = 0;
345
+ for (let U = 0; U < f; U++) {
346
+ const z = t[B * f + U] - e[v * f + U];
347
+ N += z * z;
274
348
  }
275
- const U = 2 * o * c / ((1e-3 + B) * (u * Math.pow(B, c) + 1));
276
- for (let y = 0; y < t; y++) {
277
- const G = e[v * t + y] - e[A * t + y], E = w(U * G);
278
- e[v * t + y] += S * E;
349
+ const S = Math.pow(N, r), k = 2 * l * r / ((1e-3 + N) * (n * S + 1));
350
+ for (let U = 0; U < f; U++) {
351
+ const z = t[B * f + U] - e[v * f + U];
352
+ t[B * f + U] += M * P(k * z);
279
353
  }
280
354
  }
281
- _[b] += a[b] / r;
355
+ y[x] += a[x] / d;
282
356
  }
283
357
  }
284
- return e;
358
+ return t;
285
359
  }
286
- function I() {
360
+ function K() {
287
361
  return typeof navigator < "u" && !!navigator.gpu;
288
362
  }
289
- async function Q(e, n = {}) {
363
+ async function ae(t, e = {}, i) {
290
364
  const {
291
365
  nComponents: a = 2,
292
- nNeighbors: d = 15,
293
- minDist: t = 0.1,
294
- spread: h = 1,
295
- hnsw: p = {}
296
- } = n, u = n.nEpochs ?? (e.length > 1e4 ? 200 : 500);
366
+ nNeighbors: s = 15,
367
+ minDist: p = 0.1,
368
+ spread: f = 1,
369
+ hnsw: h = {}
370
+ } = e, c = e.nEpochs ?? (t.length > 1e4 ? 200 : 500);
297
371
  console.time("knn");
298
- const { indices: c, distances: o } = await T(e, d, {
299
- M: p.M ?? 16,
300
- efConstruction: p.efConstruction ?? 200,
301
- efSearch: p.efSearch ?? 50
372
+ const { indices: o, distances: n } = await Y(t, s, {
373
+ M: h.M ?? 16,
374
+ efConstruction: h.efConstruction ?? 200,
375
+ efSearch: h.efSearch ?? 50
302
376
  });
303
377
  console.timeEnd("knn"), console.time("fuzzy-set");
304
- const r = L(c, o, d);
378
+ const r = D(o, n, s);
305
379
  console.timeEnd("fuzzy-set");
306
- const { a: s, b: f } = K(t, h), l = Y(r.vals, u), m = e.length, _ = new Float32Array(m * a);
307
- for (let g = 0; g < _.length; g++)
308
- _[g] = Math.random() * 20 - 10;
380
+ const { a: l, b: d } = j(p, f), u = I(r.vals, c), g = t.length, _ = new Float32Array(g * a);
381
+ for (let y = 0; y < _.length; y++)
382
+ _[y] = Math.random() * 20 - 10;
309
383
  console.time("sgd");
310
384
  let w;
311
- if (I())
385
+ if (K())
312
386
  try {
313
- const g = new j();
314
- await g.init(), w = await g.optimize(
387
+ const y = new W();
388
+ await y.init(), w = await y.optimize(
315
389
  _,
316
390
  new Uint32Array(r.rows),
317
391
  new Uint32Array(r.cols),
318
- l,
319
- m,
320
- a,
321
392
  u,
322
- { a: s, b: f, gamma: 1, negativeSampleRate: 5 }
393
+ g,
394
+ a,
395
+ c,
396
+ { a: l, b: d, gamma: 1, negativeSampleRate: 5 },
397
+ i
323
398
  );
324
- } catch (g) {
325
- console.warn("WebGPU SGD failed, falling back to CPU:", g), w = F(_, r, l, m, a, u, { a: s, b: f });
399
+ } catch (y) {
400
+ console.warn("WebGPU SGD failed, falling back to CPU:", y), w = q(_, r, u, g, a, c, { a: l, b: d }, i);
326
401
  }
327
402
  else
328
- w = F(_, r, l, m, a, u, { a: s, b: f });
403
+ w = q(_, r, u, g, a, c, { a: l, b: d }, i);
329
404
  return console.timeEnd("sgd"), w;
330
405
  }
331
- function K(e, n) {
332
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
406
+ function j(t, e) {
407
+ if (Math.abs(e - 1) < 1e-6 && Math.abs(t - 0.1) < 1e-6)
333
408
  return { a: 1.9292, b: 0.7915 };
334
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
409
+ if (Math.abs(e - 1) < 1e-6 && Math.abs(t - 0) < 1e-6)
335
410
  return { a: 1.8956, b: 0.8006 };
336
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
411
+ if (Math.abs(e - 1) < 1e-6 && Math.abs(t - 0.5) < 1e-6)
337
412
  return { a: 1.5769, b: 0.8951 };
338
- const a = H(e, n);
339
- return { a: V(e, n, a), b: a };
413
+ const i = ee(t, e);
414
+ return { a: te(t, e, i), b: i };
340
415
  }
341
- function H(e, n) {
342
- return 1 / (n * 1.2);
416
+ function ee(t, e) {
417
+ return 1 / (e * 1.2);
343
418
  }
344
- function V(e, n, a) {
345
- return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * a);
419
+ function te(t, e, i) {
420
+ return t < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(t, 2 * i);
421
+ }
422
+ class re {
423
+ constructor(e = {}) {
424
+ E(this, "_nComponents");
425
+ E(this, "_nNeighbors");
426
+ E(this, "_minDist");
427
+ E(this, "_spread");
428
+ E(this, "_nEpochs");
429
+ E(this, "_hnswOpts");
430
+ E(this, "_a");
431
+ E(this, "_b");
432
+ /** The low-dimensional embedding produced by the last fit() call. */
433
+ E(this, "embedding", null);
434
+ E(this, "_hnswIndex", null);
435
+ E(this, "_nTrain", 0);
436
+ this._nComponents = e.nComponents ?? 2, this._nNeighbors = e.nNeighbors ?? 15, this._minDist = e.minDist ?? 0.1, this._spread = e.spread ?? 1, this._nEpochs = e.nEpochs, this._hnswOpts = e.hnsw ?? {};
437
+ const { a: i, b: a } = j(this._minDist, this._spread);
438
+ this._a = i, this._b = a;
439
+ }
440
+ /**
441
+ * Train UMAP on `vectors`.
442
+ * Stores the resulting embedding in `this.embedding` and retains the HNSW
443
+ * index so that transform() can project new points later.
444
+ * Returns `this` for chaining.
445
+ */
446
+ async fit(e, i) {
447
+ const a = e.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: p = 16, efConstruction: f = 200, efSearch: h = 50 } = this._hnswOpts;
448
+ console.time("knn");
449
+ const { knn: c, index: o } = await J(e, this._nNeighbors, {
450
+ M: p,
451
+ efConstruction: f,
452
+ efSearch: h
453
+ });
454
+ this._hnswIndex = o, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
455
+ const n = D(c.indices, c.distances, this._nNeighbors);
456
+ console.timeEnd("fuzzy-set");
457
+ const r = I(n.vals, s), l = new Float32Array(a * this._nComponents);
458
+ for (let d = 0; d < l.length; d++)
459
+ l[d] = Math.random() * 20 - 10;
460
+ if (console.time("sgd"), K())
461
+ try {
462
+ const d = new W();
463
+ await d.init(), this.embedding = await d.optimize(
464
+ l,
465
+ new Uint32Array(n.rows),
466
+ new Uint32Array(n.cols),
467
+ r,
468
+ a,
469
+ this._nComponents,
470
+ s,
471
+ { a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 },
472
+ i
473
+ );
474
+ } catch (d) {
475
+ console.warn("WebGPU SGD failed, falling back to CPU:", d), this.embedding = q(l, n, r, a, this._nComponents, s, {
476
+ a: this._a,
477
+ b: this._b
478
+ }, i);
479
+ }
480
+ else
481
+ this.embedding = q(l, n, r, a, this._nComponents, s, {
482
+ a: this._a,
483
+ b: this._b
484
+ }, i);
485
+ return console.timeEnd("sgd"), this;
486
+ }
487
+ /**
488
+ * Project new (unseen) `vectors` into the embedding space learned by fit().
489
+ * Must be called after fit().
490
+ *
491
+ * The training embedding is kept fixed; only the new-point positions are
492
+ * optimised. Returns a Float32Array of shape [vectors.length × nComponents].
493
+ *
494
+ * @param normalize - When `true`, min-max normalise each dimension of the
495
+ * returned embedding to [0, 1]. The stored training embedding is never
496
+ * mutated. Defaults to `false`.
497
+ */
498
+ async transform(e, i = !1) {
499
+ if (!this._hnswIndex || !this.embedding)
500
+ throw new Error("UMAP.transform() must be called after fit()");
501
+ const a = e.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), p = Math.max(100, Math.floor(s / 4)), f = this._hnswIndex.searchKnn(e, this._nNeighbors), h = Q(f.indices, f.distances, this._nNeighbors), c = new Uint32Array(h.rows), o = new Uint32Array(h.cols), n = new Float32Array(a), r = new Float32Array(a * this._nComponents);
502
+ for (let u = 0; u < c.length; u++) {
503
+ const g = c[u], _ = o[u], w = h.vals[u];
504
+ n[g] += w;
505
+ for (let y = 0; y < this._nComponents; y++)
506
+ r[g * this._nComponents + y] += w * this.embedding[_ * this._nComponents + y];
507
+ }
508
+ for (let u = 0; u < a; u++)
509
+ if (n[u] > 0)
510
+ for (let g = 0; g < this._nComponents; g++)
511
+ r[u * this._nComponents + g] /= n[u];
512
+ else
513
+ for (let g = 0; g < this._nComponents; g++)
514
+ r[u * this._nComponents + g] = Math.random() * 20 - 10;
515
+ const l = I(h.vals, p), d = $(
516
+ r,
517
+ this.embedding,
518
+ h,
519
+ l,
520
+ a,
521
+ this._nTrain,
522
+ this._nComponents,
523
+ p,
524
+ { a: this._a, b: this._b }
525
+ );
526
+ return i ? T(d, a, this._nComponents) : d;
527
+ }
528
+ /**
529
+ * Convenience method equivalent to `fit(vectors)` followed by
530
+ * `transform(vectors)` — but more efficient because the training embedding
531
+ * is returned directly without a second optimization pass.
532
+ *
533
+ * @param normalize - When `true`, min-max normalise each dimension of the
534
+ * returned embedding to [0, 1]. `this.embedding` is never mutated.
535
+ * Defaults to `false`.
536
+ */
537
+ async fit_transform(e, i, a = !1) {
538
+ return await this.fit(e, i), a ? T(this.embedding, e.length, this._nComponents) : this.embedding;
539
+ }
540
+ }
541
+ function T(t, e, i) {
542
+ const a = new Float32Array(t.length);
543
+ for (let s = 0; s < i; s++) {
544
+ let p = 1 / 0, f = -1 / 0;
545
+ for (let c = 0; c < e; c++) {
546
+ const o = t[c * i + s];
547
+ o < p && (p = o), o > f && (f = o);
548
+ }
549
+ const h = f - p;
550
+ for (let c = 0; c < e; c++)
551
+ a[c * i + s] = h > 0 ? (t[c * i + s] - p) / h : 0;
552
+ }
553
+ return a;
346
554
  }
347
- function Y(e, n) {
348
- let a = -1 / 0;
349
- for (let t = 0; t < e.length; t++)
350
- e[t] > a && (a = e[t]);
351
- const d = new Float32Array(e.length);
352
- for (let t = 0; t < e.length; t++) {
353
- const h = e[t] / a;
354
- d[t] = h > 0 ? n / h : -1;
555
+ function I(t, e) {
556
+ let i = -1 / 0;
557
+ for (let s = 0; s < t.length; s++)
558
+ t[s] > i && (i = t[s]);
559
+ const a = new Float32Array(t.length);
560
+ for (let s = 0; s < t.length; s++) {
561
+ const p = t[s] / i;
562
+ a[s] = p > 0 ? e / p : -1;
355
563
  }
356
- return d;
564
+ return a;
357
565
  }
358
566
  export {
359
- Q as fit,
360
- I as isWebGPUAvailable
567
+ re as UMAP,
568
+ ae as fit,
569
+ K as isWebGPUAvailable
361
570
  };
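For readers tracing the new `transform()` in the minified bundle above: before optimisation, each new point appears to be seeded at the weighted average of its training neighbours' embedding coordinates, and a point with no weighted neighbours is placed at random in [-10, 10). The sketch below is a readable rewrite of that step only. The names `BipartiteGraph` and `initFromNeighbors` are illustrative assumptions and are not exports of the package.

```typescript
// Hypothetical, de-minified sketch of the seeding step in transform().
// All identifiers here are illustrative; none are part of the umap-gpu API.
interface BipartiteGraph {
  rows: ArrayLike<number>; // per edge: index of the new point
  cols: ArrayLike<number>; // per edge: index of the training point
  vals: ArrayLike<number>; // per edge: membership weight
}

function initFromNeighbors(
  graph: BipartiteGraph,
  trainEmbedding: Float32Array,
  nNew: number,
  nComponents: number,
): Float32Array {
  const weightSum = new Float32Array(nNew);
  const init = new Float32Array(nNew * nComponents);

  // Accumulate weight * neighbour position for every edge.
  for (let e = 0; e < graph.rows.length; e++) {
    const i = graph.rows[e];
    const j = graph.cols[e];
    const w = graph.vals[e];
    weightSum[i] += w;
    for (let d = 0; d < nComponents; d++) {
      init[i * nComponents + d] += w * trainEmbedding[j * nComponents + d];
    }
  }

  // Divide by the total weight, or fall back to a random start.
  for (let i = 0; i < nNew; i++) {
    for (let d = 0; d < nComponents; d++) {
      const k = i * nComponents + d;
      init[k] = weightSum[i] > 0 ? init[k] / weightSum[i] : Math.random() * 20 - 10;
    }
  }
  return init;
}
```

Seeding new points near their neighbours probably explains why `transform()` runs fewer epochs than `fit()`: the bundle uses `max(100, floor(nEpochs / 4))`.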
package/dist/umap.d.ts CHANGED
@@ -16,6 +16,14 @@ export interface UMAPOptions {
16
16
  efSearch?: number;
17
17
  };
18
18
  }
19
+ /**
20
+ * Called after each completed SGD epoch (or every 10 epochs on the GPU path,
21
+ * piggybacking on the existing GPU synchronisation point to avoid extra stalls).
22
+ *
23
+ * @param epoch - Zero-based index of the epoch that just finished.
24
+ * @param nEpochs - Total number of epochs.
25
+ */
26
+ export type ProgressCallback = (epoch: number, nEpochs: number) => void;
19
27
  /**
20
28
  * Fit UMAP to the given high-dimensional vectors and return a low-dimensional embedding.
21
29
  *
@@ -24,7 +32,7 @@ export interface UMAPOptions {
24
32
  * 2. Fuzzy simplicial set construction (graph weights)
25
33
  * 3. SGD optimization (WebGPU accelerated, with CPU fallback)
26
34
  */
27
- export declare function fit(vectors: number[][], opts?: UMAPOptions): Promise<Float32Array>;
35
+ export declare function fit(vectors: number[][], opts?: UMAPOptions, onProgress?: ProgressCallback): Promise<Float32Array>;
28
36
  /**
29
37
  * Compute the a, b parameters for the UMAP curve 1/(1 + a*d^(2b)).
30
38
  *
@@ -36,6 +44,68 @@ export declare function findAB(minDist: number, spread: number): {
36
44
  a: number;
37
45
  b: number;
38
46
  };
47
+ /**
48
+ * Stateful UMAP model that supports separate fit / transform / fit_transform.
49
+ *
50
+ * Usage:
51
+ * ```ts
52
+ * const umap = new UMAP({ nNeighbors: 15, nComponents: 2 });
53
+ *
54
+ * // Train on high-dimensional data:
55
+ * await umap.fit(trainVectors);
56
+ * console.log(umap.embedding); // Float32Array [nTrain * nComponents]
57
+ *
58
+ * // Project new points into the same space:
59
+ * const newEmbedding = await umap.transform(testVectors);
60
+ *
61
+ * // Or do both at once:
62
+ * const embedding = await umap.fit_transform(vectors);
63
+ * ```
64
+ */
65
+ export declare class UMAP {
66
+ private readonly _nComponents;
67
+ private readonly _nNeighbors;
68
+ private readonly _minDist;
69
+ private readonly _spread;
70
+ private readonly _nEpochs;
71
+ private readonly _hnswOpts;
72
+ private readonly _a;
73
+ private readonly _b;
74
+ /** The low-dimensional embedding produced by the last fit() call. */
75
+ embedding: Float32Array | null;
76
+ private _hnswIndex;
77
+ private _nTrain;
78
+ constructor(opts?: UMAPOptions);
79
+ /**
80
+ * Train UMAP on `vectors`.
81
+ * Stores the resulting embedding in `this.embedding` and retains the HNSW
82
+ * index so that transform() can project new points later.
83
+ * Returns `this` for chaining.
84
+ */
85
+ fit(vectors: number[][], onProgress?: ProgressCallback): Promise<this>;
86
+ /**
87
+ * Project new (unseen) `vectors` into the embedding space learned by fit().
88
+ * Must be called after fit().
89
+ *
90
+ * The training embedding is kept fixed; only the new-point positions are
91
+ * optimised. Returns a Float32Array of shape [vectors.length × nComponents].
92
+ *
93
+ * @param normalize - When `true`, min-max normalise each dimension of the
94
+ * returned embedding to [0, 1]. The stored training embedding is never
95
+ * mutated. Defaults to `false`.
96
+ */
97
+ transform(vectors: number[][], normalize?: boolean): Promise<Float32Array>;
98
+ /**
99
+ * Convenience method equivalent to `fit(vectors)` followed by
100
+ * `transform(vectors)` — but more efficient because the training embedding
101
+ * is returned directly without a second optimization pass.
102
+ *
103
+ * @param normalize - When `true`, min-max normalise each dimension of the
104
+ * returned embedding to [0, 1]. `this.embedding` is never mutated.
105
+ * Defaults to `false`.
106
+ */
107
+ fit_transform(vectors: number[][], onProgress?: ProgressCallback, normalize?: boolean): Promise<Float32Array>;
108
+ }
39
109
  /**
40
110
  * Compute per-edge epoch sampling periods based on edge weights.
41
111
  * Higher-weight edges are sampled more frequently.
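The sampling-period comment above corresponds to the bundled helper that is renamed from `Y` to `I` in this release. A minimal sketch of its arithmetic, as it reads in the minified code, follows. The function name `epochsPerSample` is an assumption made for illustration, and how the SGD loop consumes the returned values is not shown in this diff.

```typescript
// Hypothetical readable version of the bundled helper `I(weights, nEpochs)`.
// The name epochsPerSample is illustrative, not taken from the package.
function epochsPerSample(weights: ArrayLike<number>, nEpochs: number): Float32Array {
  // Find the largest edge weight.
  let maxWeight = -Infinity;
  for (let i = 0; i < weights.length; i++) {
    if (weights[i] > maxWeight) maxWeight = weights[i];
  }

  const periods = new Float32Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    // Weight relative to the strongest edge.
    const relative = weights[i] / maxWeight;
    // Larger weights give smaller values; non-positive weights are marked -1.
    periods[i] = relative > 0 ? nEpochs / relative : -1;
  }
  return periods;
}

// Example: the strongest edge gets the smallest value.
const examplePeriods = epochsPerSample([1, 0.5, 0], 200);
```

With these inputs the result is `[200, 400, -1]`: halving the weight doubles the value, and a zero-weight edge gets the sentinel `-1`.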
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.1.0",
3
+ "version": "0.2.8",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -8,19 +8,29 @@
8
8
  "files": [
9
9
  "dist"
10
10
  ],
11
+ "repository": {
12
+ "type": "git",
13
+ "url": "https://github.com/Achuttarsing/umap-gpu"
14
+ },
11
15
  "scripts": {
12
16
  "build": "vite build && tsc",
13
17
  "dev": "vite",
14
18
  "test": "vitest run",
15
- "prepublishOnly": "npm test && npm run build"
19
+ "prepublishOnly": "bun test && bun run build",
20
+ "docs:dev": "vitepress dev docs",
21
+ "docs:build": "vitepress build docs",
22
+ "docs:generate": "bun run build && bunx api-extractor run && bun run docs:build"
16
23
  },
17
24
  "dependencies": {
18
25
  "hnswlib-wasm": "^0.8.2"
19
26
  },
20
27
  "devDependencies": {
28
+ "@microsoft/api-extractor": "^7.57.6",
21
29
  "@webgpu/types": "^0.1.40",
22
30
  "typescript": "^5.4.0",
23
31
  "vite": "^5.0.0",
32
+ "vitepress": "^1.6.4",
33
+ "vitepress-plugin-llms": "^1.11.0",
24
34
  "vitest": "^4.0.18"
25
35
  },
26
36
  "license": "MIT"