umap-gpu 0.1.0 → 0.2.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -1,87 +1,104 @@
1
1
  # umap-gpu
2
2
 
3
- UMAP dimensionality reduction with HNSW k-nearest-neighbor search and WebGPU-accelerated SGD optimization, with a transparent CPU fallback.
3
+ UMAP dimensionality reduction with WebGPU-accelerated SGD and HNSW approximate nearest neighbors.
4
4
 
5
- ## What it does
5
+ Embed millions of high-dimensional vectors into 2D in seconds — not minutes.
6
6
 
7
- Takes a set of high-dimensional vectors and returns a low-dimensional embedding (default: 2D) suitable for visualization or downstream tasks.
7
+ ## Why GPU?
8
8
 
9
- The pipeline runs in three stages:
9
+ The bottleneck in UMAP is the SGD optimization loop: hundreds of epochs (200–500 by default), with up to millions of edge updates per epoch. On CPU these edge updates run sequentially. On GPU, each edge is handled by its own thread, so all edges of an epoch run in parallel across thousands of shader cores — expect a significant speedup on large datasets, scaling with both the number of points and the number of epochs.
10
10
 
11
- 1. **k-NN** approximate nearest neighbors via [hnswlib-wasm](https://github.com/yoshoku/hnswlib-wasm) (O(n log n))
12
- 2. **Fuzzy simplicial set** builds a weighted graph from the k-NN graph using smooth distances
13
- 3. **SGD** — optimizes the embedding using attraction/repulsion forces:
14
- - **WebGPU** compute shader when available (Chrome 113+, Edge 113+)
15
- - **CPU** fallback otherwise — identical output, just slower
11
+ The k-NN stage uses [hnswlib-wasm](https://github.com/yoshoku/hnswlib-wasm) for approximate nearest-neighbor search (O(n log n)), so it stays fast whether or not WebGPU is available.
12
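+ A transparent CPU fallback runs the same optimization wherever WebGPU isn't available. Because initialization and negative sampling are random, exact coordinates differ from run to run, but the output format and overall embedding structure are the same.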
16
13
 
17
14
  ## Install
18
15
 
19
16
  ```bash
17
+ # npm
20
18
  npm install umap-gpu
21
- ```
22
19
 
23
- > Requires a browser or runtime with WebGPU support for GPU acceleration. The CPU fallback works anywhere.
20
+ # Bun
21
+ bun add umap-gpu
22
+
23
+ # pnpm
24
+ pnpm add umap-gpu
25
+ ```
24
26
 
25
- ## Usage
27
+ ## Quick start
26
28
 
27
29
  ```ts
28
30
  import { fit } from 'umap-gpu';
29
31
 
30
32
  const vectors = [
31
- [1.0, 0.0, 0.3],
32
- [0.9, 0.1, 0.4],
33
- [0.0, 1.0, 0.8],
33
+ [0.1, 0.4, 0.9, ...], // high-dimensional points
34
+ [0.2, 0.3, 0.8, ...],
34
35
  // ...
35
36
  ];
36
37
 
37
38
  const embedding = await fit(vectors);
38
- // Float32Array of length n * nComponents (default: n * 2)
39
- // embedding[i*2], embedding[i*2 + 1] → 2D coordinates of point i
39
+ // Float32Array of length n * 2: embedding[i*2], embedding[i*2+1] are the 2D coords of point i
40
+ ```
41
+
42
+ ## Train once, project many times
43
+
44
+ Use the `UMAP` class to embed a training set and later project new points into the same space without retraining.
45
+
46
+ ```ts
47
+ import { UMAP } from 'umap-gpu';
48
+
49
+ const umap = new UMAP({ nNeighbors: 15, minDist: 0.1 });
50
+
51
+ // Train
52
+ await umap.fit(trainVectors);
53
+ console.log(umap.embedding); // Float32Array [nTrain × 2]
54
+
55
+ // Project new points (training embedding stays fixed)
56
+ const projected = await umap.transform(newVectors);
57
+ // Float32Array [nNew × 2]
40
58
  ```
41
59
 
42
- ### Options
60
+ ## Options
43
61
 
44
62
  ```ts
45
- const embedding = await fit(vectors, {
46
- nComponents: 2, // output dimensions (default: 2)
47
- nNeighbors: 15, // k-NN graph degree (default: 15)
48
- nEpochs: 500, // SGD iterations (default: 500 for <10k points, 200 otherwise)
49
- minDist: 0.1, // minimum distance between points in the embedding (default: 0.1)
50
- spread: 1.0, // scale of the embedding (default: 1.0)
63
+ const umap = new UMAP({
64
+ nComponents: 2, // output dimensions (default: 2)
65
+ nNeighbors: 15, // k-NN graph degree (default: 15)
66
+ nEpochs: 500, // SGD iterations (default: auto — 500 for ≤10k points, 200 otherwise)
67
+ minDist: 0.1, // min distance in embedding (default: 0.1)
68
+ spread: 1.0, // scale of the embedding (default: 1.0)
51
69
  hnsw: {
52
- M: 16, // HNSW graph connectivity (default: 16)
53
- efConstruction: 200, // build-time search width (default: 200)
54
- efSearch: 50, // query-time search width (default: 50)
70
+ M: 16, // graph connectivity (default: 16)
71
+ efConstruction: 200, // build-time search width (default: 200)
72
+ efSearch: 50, // query-time search width (default: 50)
55
73
  },
56
74
  });
75
+
76
+ // Same options work with the functional API
77
+ const embedding = await fit(vectors, { nNeighbors: 15, minDist: 0.05 });
57
78
  ```
58
79
 
59
- ### Checking GPU availability
80
+ ## Check GPU availability
60
81
 
61
82
  ```ts
62
83
  import { isWebGPUAvailable } from 'umap-gpu';
63
84
 
64
- if (isWebGPUAvailable()) {
65
- console.log('Will use WebGPU-accelerated SGD');
66
- } else {
67
- console.log('Will fall back to CPU SGD');
68
- }
69
- ```
70
-
71
- ## Build
72
-
73
- ```bash
74
- npm run build # compiles TypeScript to dist/
75
- npm test # runs the unit test suite (Vitest)
85
+ console.log(isWebGPUAvailable()); // true → GPU path, false → CPU fallback
76
86
  ```
77
87
 
78
88
  ## Browser support
79
89
 
80
- | Feature | Requirement |
90
+ | Feature | Supported in |
81
91
  |---------|-------------|
82
92
  | WebGPU SGD | Chrome 113+, Edge 113+, Safari 18+ |
83
- | CPU fallback | Any modern browser / Node.js |
84
- | HNSW (WASM) | Any environment with WebAssembly support |
93
+ | CPU fallback | Any modern browser / Node.js / Bun |
94
+ | HNSW (WASM) | Any environment with WebAssembly |
95
+
96
+ ## Development
97
+
98
+ ```bash
99
+ npm test # Vitest unit tests
100
+ npm run build # TypeScript → dist/
101
+ ```
85
102
 
86
103
  ## License
87
104
 
@@ -10,3 +10,20 @@ export interface CPUSgdParams {
10
10
  * Mirrors the GPU shader logic: per-edge attraction + negative-sample repulsion.
11
11
  */
12
12
  export declare function cpuSgd(embedding: Float32Array, graph: FuzzyGraph, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: CPUSgdParams): Float32Array;
13
+ /**
14
+ * CPU SGD for UMAP.transform(): optimizes only the new-point embeddings.
15
+ * The training embedding is read-only; attraction pulls new points toward
16
+ * their training neighbors, and repulsion pushes them away from random
17
+ * training points.
18
+ *
19
+ * @param embeddingNew - New-point embeddings to optimize [nNew × nComponents]
20
+ * @param embeddingTrain - Fixed training embeddings [nTrain × nComponents]
21
+ * @param graph - Bipartite graph: rows=new-point indices, cols=training-point indices
22
+ * @param epochsPerSample - Per-edge epoch sampling schedule
23
+ * @param nNew - Number of new points
24
+ * @param nTrain - Number of training points
25
+ * @param nComponents - Embedding dimensionality
26
+ * @param nEpochs - Number of optimization epochs
27
+ * @param params - UMAP curve parameters
28
+ */
29
+ export declare function cpuSgdTransform(embeddingNew: Float32Array, embeddingTrain: Float32Array, graph: FuzzyGraph, epochsPerSample: Float32Array, nNew: number, nTrain: number, nComponents: number, nEpochs: number, params: CPUSgdParams): Float32Array;
@@ -10,3 +10,16 @@ export interface FuzzyGraph {
10
10
  * (sigmas, rhos) and symmetrizes with the fuzzy set union operation.
11
11
  */
12
12
  export declare function computeFuzzySimplicialSet(knnIndices: number[][], knnDistances: number[][], nNeighbors: number, setOpMixRatio?: number): FuzzyGraph;
13
+ /**
14
+ * Compute the fuzzy weight graph between new (query) points and training points.
15
+ * Used by UMAP.transform() to project unseen data into an existing embedding.
16
+ *
17
+ * Unlike computeFuzzySimplicialSet, this produces a bipartite graph
18
+ * (new points → training points) with no symmetrization.
19
+ *
20
+ * @param knnIndices - For each new point, the indices of its training neighbors
21
+ * @param knnDistances - For each new point, the distances to those neighbors
22
+ * @param nNeighbors - Number of neighbors used
23
+ * @returns FuzzyGraph where rows are new-point indices, cols are training-point indices
24
+ */
25
+ export declare function computeTransformFuzzyWeights(knnIndices: number[][], knnDistances: number[][], nNeighbors: number): FuzzyGraph;
@@ -7,9 +7,24 @@ export interface HNSWOptions {
7
7
  efConstruction?: number;
8
8
  efSearch?: number;
9
9
  }
10
+ /**
11
+ * A built HNSW index that can be queried to find nearest neighbors in the
12
+ * training data for new (unseen) points — used by UMAP.transform().
13
+ */
14
+ export interface HNSWSearchableIndex {
15
+ searchKnn(queryVectors: number[][], nNeighbors: number): KNNResult;
16
+ }
10
17
  /**
11
18
  * Compute k-nearest neighbors using HNSW (Hierarchical Navigable Small World)
12
19
  * via hnswlib-wasm, replacing the O(n^2) brute-force search in umap-js with
13
20
  * an O(n log n) approximate nearest neighbor search.
14
21
  */
15
22
  export declare function computeKNN(vectors: number[][], nNeighbors: number, opts?: HNSWOptions): Promise<KNNResult>;
23
+ /**
24
+ * Like computeKNN, but also returns the built HNSW index so it can be reused
25
+ * later to project new points (used by UMAP.transform()).
26
+ */
27
+ export declare function computeKNNWithIndex(vectors: number[][], nNeighbors: number, opts?: HNSWOptions): Promise<{
28
+ knn: KNNResult;
29
+ index: HNSWSearchableIndex;
30
+ }>;
package/dist/index.d.ts CHANGED
@@ -1,5 +1,5 @@
1
- export { fit } from './umap';
1
+ export { fit, UMAP } from './umap';
2
2
  export type { UMAPOptions } from './umap';
3
- export type { KNNResult, HNSWOptions } from './hnsw-knn';
3
+ export type { KNNResult, HNSWOptions, HNSWSearchableIndex } from './hnsw-knn';
4
4
  export type { FuzzyGraph } from './fuzzy-set';
5
5
  export { isWebGPUAvailable } from './gpu/device';
package/dist/index.js CHANGED
@@ -1,65 +1,98 @@
1
- var N = Object.defineProperty;
2
- var q = (e, n, a) => n in e ? N(e, n, { enumerable: !0, configurable: !0, writable: !0, value: a }) : e[n] = a;
3
- var O = (e, n, a) => q(e, typeof n != "symbol" ? n + "" : n, a);
4
- import { loadHnswlib as z } from "hnswlib-wasm";
5
- async function T(e, n, a = {}) {
6
- const { M: d = 16, efConstruction: t = 200, efSearch: h = 50 } = a, p = await z(), u = e[0].length, c = e.length, o = new p.HierarchicalNSW("l2", u, "");
7
- o.initIndex(c, d, t, 200), o.setEfSearch(Math.max(h, n)), o.addItems(e, !1);
8
- const r = [], s = [];
9
- for (let f = 0; f < c; f++) {
10
- const l = o.searchKnn(e[f], n + 1, void 0), m = l.neighbors.map((_, w) => ({ idx: _, dist: l.distances[w] })).filter(({ idx: _ }) => _ !== f).slice(0, n);
11
- r.push(m.map(({ idx: _ }) => _)), s.push(m.map(({ dist: _ }) => _));
1
+ var K = Object.defineProperty;
2
+ var j = (n, e, a) => e in n ? K(n, e, { enumerable: !0, configurable: !0, writable: !0, value: a }) : n[e] = a;
3
+ var S = (n, e, a) => j(n, typeof e != "symbol" ? e + "" : e, a);
4
+ import { loadHnswlib as T } from "hnswlib-wasm";
5
+ async function H(n, e, a = {}) {
6
+ const { M: f = 16, efConstruction: o = 200, efSearch: p = 50 } = a, c = await T(), l = n[0].length, u = n.length, s = new c.HierarchicalNSW("l2", l, "");
7
+ s.initIndex(u, f, o, 200), s.setEfSearch(Math.max(p, e)), s.addItems(n, !1);
8
+ const t = [], i = [];
9
+ for (let r = 0; r < u; r++) {
10
+ const d = s.searchKnn(n[r], e + 1, void 0), _ = d.neighbors.map((g, h) => ({ idx: g, dist: d.distances[h] })).filter(({ idx: g }) => g !== r).slice(0, e);
11
+ t.push(_.map(({ idx: g }) => g)), i.push(_.map(({ dist: g }) => g));
12
12
  }
13
- return { indices: r, distances: s };
13
+ return { indices: t, distances: i };
14
14
  }
15
- function L(e, n, a, d = 1) {
16
- const t = e.length, { sigmas: h, rhos: p } = C(n, a), u = [], c = [], o = [];
17
- for (let s = 0; s < t; s++)
18
- for (let f = 0; f < e[s].length; f++) {
19
- const l = n[s][f], m = l <= p[s] ? 1 : Math.exp(-((l - p[s]) / h[s]));
20
- u.push(s), c.push(e[s][f]), o.push(m);
15
+ async function V(n, e, a = {}) {
16
+ const { M: f = 16, efConstruction: o = 200, efSearch: p = 50 } = a, c = await T(), l = n[0].length, u = n.length, s = new c.HierarchicalNSW("l2", l, "");
17
+ s.initIndex(u, f, o, 200), s.setEfSearch(Math.max(p, e)), s.addItems(n, !1);
18
+ const t = [], i = [];
19
+ for (let d = 0; d < u; d++) {
20
+ const _ = s.searchKnn(n[d], e + 1, void 0), g = _.neighbors.map((h, y) => ({ idx: h, dist: _.distances[y] })).filter(({ idx: h }) => h !== d).slice(0, e);
21
+ t.push(g.map(({ idx: h }) => h)), i.push(g.map(({ dist: h }) => h));
22
+ }
23
+ return { knn: { indices: t, distances: i }, index: {
24
+ searchKnn(d, _) {
25
+ const g = [], h = [];
26
+ for (const y of d) {
27
+ const w = s.searchKnn(y, _, void 0), U = w.neighbors.map((b, x) => ({ idx: b, dist: w.distances[x] })).sort((b, x) => b.dist - x.dist).slice(0, _);
28
+ g.push(U.map(({ idx: b }) => b)), h.push(U.map(({ dist: b }) => b));
29
+ }
30
+ return { indices: g, distances: h };
21
31
  }
22
- return { ...D(u, c, o, t, d), nVertices: t };
32
+ } };
23
33
  }
24
- function C(e, n) {
25
- const d = e.length, t = new Float32Array(d), h = new Float32Array(d);
26
- for (let p = 0; p < d; p++) {
27
- const u = e[p];
28
- h[p] = u.find((f) => f > 0) ?? 0;
29
- let c = 0, o = 1 / 0, r = 1;
30
- const s = Math.log2(n);
31
- for (let f = 0; f < 64; f++) {
32
- let l = 0;
33
- for (let m = 1; m < u.length; m++)
34
- l += Math.exp(-Math.max(0, u[m] - h[p]) / r);
35
- if (Math.abs(l - s) < 1e-5) break;
36
- l > s ? (o = r, r = (c + o) / 2) : (c = r, r = o === 1 / 0 ? r * 2 : (c + o) / 2);
34
+ function C(n, e, a, f = 1) {
35
+ const o = n.length, { sigmas: p, rhos: c } = I(e, a), l = [], u = [], s = [];
36
+ for (let i = 0; i < o; i++)
37
+ for (let r = 0; r < n[i].length; r++) {
38
+ const d = e[i][r], _ = d <= c[i] ? 1 : Math.exp(-((d - c[i]) / p[i]));
39
+ l.push(i), u.push(n[i][r]), s.push(_);
37
40
  }
38
- t[p] = r;
41
+ return { ...J(l, u, s, o, f), nVertices: o };
42
+ }
43
+ function Y(n, e, a) {
44
+ const f = n.length, { sigmas: o, rhos: p } = I(e, a), c = [], l = [], u = [];
45
+ for (let s = 0; s < f; s++)
46
+ for (let t = 0; t < n[s].length; t++) {
47
+ const i = e[s][t], r = i <= p[s] ? 1 : Math.exp(-((i - p[s]) / o[s]));
48
+ c.push(s), l.push(n[s][t]), u.push(r);
49
+ }
50
+ return {
51
+ rows: new Float32Array(c),
52
+ cols: new Float32Array(l),
53
+ vals: new Float32Array(u),
54
+ nVertices: f
55
+ };
56
+ }
57
+ function I(n, e) {
58
+ const f = n.length, o = new Float32Array(f), p = new Float32Array(f);
59
+ for (let c = 0; c < f; c++) {
60
+ const l = n[c];
61
+ p[c] = l.find((r) => r > 0) ?? 0;
62
+ let u = 0, s = 1 / 0, t = 1;
63
+ const i = Math.log2(e);
64
+ for (let r = 0; r < 64; r++) {
65
+ let d = 0;
66
+ for (let _ = 1; _ < l.length; _++)
67
+ d += Math.exp(-Math.max(0, l[_] - p[c]) / t);
68
+ if (Math.abs(d - i) < 1e-5) break;
69
+ d > i ? (s = t, t = (u + s) / 2) : (u = t, t = s === 1 / 0 ? t * 2 : (u + s) / 2);
70
+ }
71
+ o[c] = t;
39
72
  }
40
- return { sigmas: t, rhos: h };
73
+ return { sigmas: o, rhos: p };
41
74
  }
42
- function D(e, n, a, d, t) {
43
- const h = /* @__PURE__ */ new Map(), p = (r, s, f) => {
44
- const l = `${r},${s}`, m = h.get(l) ?? 0;
45
- h.set(l, m + f);
75
+ function J(n, e, a, f, o) {
76
+ const p = /* @__PURE__ */ new Map(), c = (t, i, r) => {
77
+ const d = t * f + i;
78
+ p.set(d, (p.get(d) ?? 0) + r);
46
79
  };
47
- for (let r = 0; r < e.length; r++)
48
- p(e[r], n[r], a[r]), p(n[r], e[r], a[r]);
49
- const u = [], c = [], o = [];
50
- for (const [r, s] of h.entries()) {
51
- const [f, l] = r.split(",").map(Number);
52
- u.push(f), c.push(l), o.push(
53
- s > 1 ? t * (2 - s) + (1 - t) * (s - 1) : s
80
+ for (let t = 0; t < n.length; t++)
81
+ c(n[t], e[t], a[t]), c(e[t], n[t], a[t]);
82
+ const l = [], u = [], s = [];
83
+ for (const [t, i] of p.entries()) {
84
+ const r = Math.floor(t / f), d = t % f;
85
+ l.push(r), u.push(d), s.push(
86
+ i > 1 ? o * (2 - i) + (1 - o) * (i - 1) : i
54
87
  );
55
88
  }
56
89
  return {
57
- rows: new Float32Array(u),
58
- cols: new Float32Array(c),
59
- vals: new Float32Array(o)
90
+ rows: new Float32Array(l),
91
+ cols: new Float32Array(u),
92
+ vals: new Float32Array(s)
60
93
  };
61
94
  }
62
- const W = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
95
+ const Q = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
63
96
  // Applies attraction forces between connected nodes and repulsion forces
64
97
  // against negative samples.
65
98
 
@@ -162,18 +195,18 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
162
195
  epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
163
196
  }
164
197
  `;
165
- class j {
198
+ class D {
166
199
  constructor() {
167
- O(this, "device");
168
- O(this, "pipeline");
200
+ S(this, "device");
201
+ S(this, "pipeline");
169
202
  }
170
203
  async init() {
171
- const n = await navigator.gpu.requestAdapter();
172
- if (!n) throw new Error("WebGPU not supported");
173
- this.device = await n.requestDevice(), this.pipeline = this.device.createComputePipeline({
204
+ const e = await navigator.gpu.requestAdapter();
205
+ if (!e) throw new Error("WebGPU not supported");
206
+ this.device = await e.requestDevice(), this.pipeline = this.device.createComputePipeline({
174
207
  layout: "auto",
175
208
  compute: {
176
- module: this.device.createShaderModule({ code: W }),
209
+ module: this.device.createShaderModule({ code: Q }),
177
210
  entryPoint: "main"
178
211
  }
179
212
  });
@@ -191,171 +224,322 @@ class j {
191
224
  * @param params - UMAP curve parameters and repulsion settings
192
225
  * @returns Optimized embedding as Float32Array
193
226
  */
194
- async optimize(n, a, d, t, h, p, u, c) {
195
- const { device: o } = this, r = a.length, s = this.makeBuffer(
196
- n,
227
+ async optimize(e, a, f, o, p, c, l, u) {
228
+ const { device: s } = this, t = a.length, i = this.makeBuffer(
229
+ e,
197
230
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
198
- ), f = this.makeBuffer(a, GPUBufferUsage.STORAGE), l = this.makeBuffer(d, GPUBufferUsage.STORAGE), m = this.makeBuffer(t, GPUBufferUsage.STORAGE), _ = new Float32Array(r).fill(0), w = this.makeBuffer(_, GPUBufferUsage.STORAGE), g = new Float32Array(r);
199
- for (let i = 0; i < r; i++)
200
- g[i] = t[i] / c.negativeSampleRate;
201
- const S = this.makeBuffer(g, GPUBufferUsage.STORAGE), b = new Uint32Array(r);
202
- for (let i = 0; i < r; i++)
203
- b[i] = Math.random() * 4294967295 | 0;
204
- const v = this.makeBuffer(b, GPUBufferUsage.STORAGE), M = o.createBuffer({
231
+ ), r = this.makeBuffer(a, GPUBufferUsage.STORAGE), d = this.makeBuffer(f, GPUBufferUsage.STORAGE), _ = this.makeBuffer(o, GPUBufferUsage.STORAGE), g = new Float32Array(t).fill(0), h = this.makeBuffer(g, GPUBufferUsage.STORAGE), y = new Float32Array(t);
232
+ for (let m = 0; m < t; m++)
233
+ y[m] = o[m] / u.negativeSampleRate;
234
+ const w = this.makeBuffer(y, GPUBufferUsage.STORAGE), U = new Uint32Array(t);
235
+ for (let m = 0; m < t; m++)
236
+ U[m] = Math.random() * 4294967295 | 0;
237
+ const b = this.makeBuffer(U, GPUBufferUsage.STORAGE), x = s.createBuffer({
205
238
  size: 40,
206
239
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
207
240
  });
208
- for (let i = 0; i < u; i++) {
209
- const A = 1 - i / u, B = new ArrayBuffer(40), U = new Uint32Array(B), y = new Float32Array(B);
210
- U[0] = r, U[1] = h, U[2] = p, U[3] = i, U[4] = u, y[5] = A, y[6] = c.a, y[7] = c.b, y[8] = c.gamma, U[9] = c.negativeSampleRate, o.queue.writeBuffer(M, 0, B);
211
- const G = o.createBindGroup({
241
+ for (let m = 0; m < l; m++) {
242
+ const G = 1 - m / l, A = new ArrayBuffer(40), M = new Uint32Array(A), P = new Float32Array(A);
243
+ M[0] = t, M[1] = p, M[2] = c, M[3] = m, M[4] = l, P[5] = G, P[6] = u.a, P[7] = u.b, P[8] = u.gamma, M[9] = u.negativeSampleRate, s.queue.writeBuffer(x, 0, A);
244
+ const B = s.createBindGroup({
212
245
  layout: this.pipeline.getBindGroupLayout(0),
213
246
  entries: [
214
- { binding: 0, resource: { buffer: m } },
215
- { binding: 1, resource: { buffer: f } },
216
- { binding: 2, resource: { buffer: l } },
217
- { binding: 3, resource: { buffer: s } },
218
- { binding: 4, resource: { buffer: w } },
219
- { binding: 5, resource: { buffer: S } },
220
- { binding: 6, resource: { buffer: M } },
221
- { binding: 7, resource: { buffer: v } }
247
+ { binding: 0, resource: { buffer: _ } },
248
+ { binding: 1, resource: { buffer: r } },
249
+ { binding: 2, resource: { buffer: d } },
250
+ { binding: 3, resource: { buffer: i } },
251
+ { binding: 4, resource: { buffer: h } },
252
+ { binding: 5, resource: { buffer: w } },
253
+ { binding: 6, resource: { buffer: x } },
254
+ { binding: 7, resource: { buffer: b } }
222
255
  ]
223
- }), E = o.createCommandEncoder(), R = E.beginComputePass();
224
- R.setPipeline(this.pipeline), R.setBindGroup(0, G), R.dispatchWorkgroups(Math.ceil(r / 256)), R.end(), o.queue.submit([E.finish()]), i % 10 === 0 && await o.queue.onSubmittedWorkDone();
256
+ }), N = s.createCommandEncoder(), v = N.beginComputePass();
257
+ v.setPipeline(this.pipeline), v.setBindGroup(0, B), v.dispatchWorkgroups(Math.ceil(t / 256)), v.end(), s.queue.submit([N.finish()]), m % 10 === 0 && await s.queue.onSubmittedWorkDone();
225
258
  }
226
- const x = o.createBuffer({
227
- size: n.byteLength,
259
+ const E = s.createBuffer({
260
+ size: e.byteLength,
228
261
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
229
- }), P = o.createCommandEncoder();
230
- P.copyBufferToBuffer(s, 0, x, 0, n.byteLength), o.queue.submit([P.finish()]), await x.mapAsync(GPUMapMode.READ);
231
- const k = new Float32Array(x.getMappedRange().slice(0));
232
- return x.unmap(), s.destroy(), f.destroy(), l.destroy(), m.destroy(), w.destroy(), S.destroy(), v.destroy(), M.destroy(), x.destroy(), k;
262
+ }), F = s.createCommandEncoder();
263
+ F.copyBufferToBuffer(i, 0, E, 0, e.byteLength), s.queue.submit([F.finish()]), await E.mapAsync(GPUMapMode.READ);
264
+ const R = new Float32Array(E.getMappedRange().slice(0));
265
+ return E.unmap(), i.destroy(), r.destroy(), d.destroy(), _.destroy(), h.destroy(), w.destroy(), b.destroy(), x.destroy(), E.destroy(), R;
233
266
  }
234
- makeBuffer(n, a) {
235
- const d = this.device.createBuffer({
236
- size: n.byteLength,
267
+ makeBuffer(e, a) {
268
+ const f = this.device.createBuffer({
269
+ size: e.byteLength,
237
270
  usage: a,
238
271
  mappedAtCreation: !0
239
272
  });
240
- return n instanceof Float32Array ? new Float32Array(d.getMappedRange()).set(n) : new Uint32Array(d.getMappedRange()).set(n), d.unmap(), d;
273
+ return e instanceof Float32Array ? new Float32Array(f.getMappedRange()).set(e) : new Uint32Array(f.getMappedRange()).set(e), f.unmap(), f;
241
274
  }
242
275
  }
243
- function F(e, n, a, d, t, h, p) {
244
- const { a: u, b: c, gamma: o = 1, negativeSampleRate: r = 5 } = p, s = n.rows.length, f = new Uint32Array(n.rows), l = new Uint32Array(n.cols), m = new Float32Array(s).fill(0), _ = new Float32Array(s);
245
- for (let g = 0; g < s; g++)
246
- _[g] = a[g] / r;
247
- function w(g) {
248
- return Math.max(-4, Math.min(4, g));
249
- }
250
- for (let g = 0; g < h; g++) {
251
- const S = 1 - g / h;
252
- for (let b = 0; b < s; b++) {
253
- if (m[b] > g) continue;
254
- const v = f[b], M = l[b];
276
+ function O(n) {
277
+ return Math.max(-4, Math.min(4, n));
278
+ }
279
+ function z(n, e, a, f, o, p, c) {
280
+ const { a: l, b: u, gamma: s = 1, negativeSampleRate: t = 5 } = c, i = e.rows.length, r = new Uint32Array(e.rows), d = new Uint32Array(e.cols), _ = new Float32Array(i).fill(0), g = new Float32Array(i);
281
+ for (let h = 0; h < i; h++)
282
+ g[h] = a[h] / t;
283
+ for (let h = 0; h < p; h++) {
284
+ const y = 1 - h / p;
285
+ for (let w = 0; w < i; w++) {
286
+ if (_[w] > h) continue;
287
+ const U = r[w], b = d[w];
255
288
  let x = 0;
256
- for (let i = 0; i < t; i++) {
257
- const A = e[v * t + i] - e[M * t + i];
258
- x += A * A;
289
+ for (let m = 0; m < o; m++) {
290
+ const G = n[U * o + m] - n[b * o + m];
291
+ x += G * G;
292
+ }
293
+ const E = Math.pow(x, u), F = -2 * l * u * (x > 0 ? E / x : 0) / (l * E + 1);
294
+ for (let m = 0; m < o; m++) {
295
+ const G = n[U * o + m] - n[b * o + m], A = O(F * G);
296
+ n[U * o + m] += y * A;
297
+ }
298
+ _[w] += a[w];
299
+ const R = g[w] > 0 ? Math.floor(a[w] / g[w]) : 0;
300
+ for (let m = 0; m < R; m++) {
301
+ const G = Math.floor(Math.random() * f);
302
+ if (G === U) continue;
303
+ let A = 0;
304
+ for (let B = 0; B < o; B++) {
305
+ const N = n[U * o + B] - n[G * o + B];
306
+ A += N * N;
307
+ }
308
+ const M = Math.pow(A, u), P = 2 * s * u / ((1e-3 + A) * (l * M + 1));
309
+ for (let B = 0; B < o; B++) {
310
+ const N = n[U * o + B] - n[G * o + B], v = O(P * N);
311
+ n[U * o + B] += y * v;
312
+ }
313
+ }
314
+ g[w] += a[w] / t;
315
+ }
316
+ }
317
+ return n;
318
+ }
319
+ function X(n, e, a, f, o, p, c, l, u) {
320
+ const { a: s, b: t, gamma: i = 1, negativeSampleRate: r = 5 } = u, d = a.rows.length, _ = new Uint32Array(a.rows), g = new Uint32Array(a.cols), h = new Float32Array(d).fill(0), y = new Float32Array(d);
321
+ for (let w = 0; w < d; w++)
322
+ y[w] = f[w] / r;
323
+ for (let w = 0; w < l; w++) {
324
+ const U = 1 - w / l;
325
+ for (let b = 0; b < d; b++) {
326
+ if (h[b] > w) continue;
327
+ const x = _[b], E = g[b];
328
+ let F = 0;
329
+ for (let A = 0; A < c; A++) {
330
+ const M = n[x * c + A] - e[E * c + A];
331
+ F += M * M;
259
332
  }
260
- const P = -2 * u * c * Math.pow(x, c - 1) / (u * Math.pow(x, c) + 1);
261
- for (let i = 0; i < t; i++) {
262
- const A = e[v * t + i] - e[M * t + i], B = w(P * A);
263
- e[v * t + i] += S * B;
333
+ const R = Math.pow(F, t), m = -2 * s * t * (F > 0 ? R / F : 0) / (s * R + 1);
334
+ for (let A = 0; A < c; A++) {
335
+ const M = n[x * c + A] - e[E * c + A];
336
+ n[x * c + A] += U * O(m * M);
264
337
  }
265
- m[b] += a[b];
266
- const k = _[b] > 0 ? Math.floor(a[b] / _[b]) : 0;
267
- for (let i = 0; i < k; i++) {
268
- const A = Math.floor(Math.random() * d);
269
- if (A === v) continue;
270
- let B = 0;
271
- for (let y = 0; y < t; y++) {
272
- const G = e[v * t + y] - e[A * t + y];
273
- B += G * G;
338
+ h[b] += f[b];
339
+ const G = y[b] > 0 ? Math.floor(f[b] / y[b]) : 0;
340
+ for (let A = 0; A < G; A++) {
341
+ const M = Math.floor(Math.random() * p);
342
+ if (M === E) continue;
343
+ let P = 0;
344
+ for (let v = 0; v < c; v++) {
345
+ const k = n[x * c + v] - e[M * c + v];
346
+ P += k * k;
274
347
  }
275
- const U = 2 * o * c / ((1e-3 + B) * (u * Math.pow(B, c) + 1));
276
- for (let y = 0; y < t; y++) {
277
- const G = e[v * t + y] - e[A * t + y], E = w(U * G);
278
- e[v * t + y] += S * E;
348
+ const B = Math.pow(P, t), N = 2 * i * t / ((1e-3 + P) * (s * B + 1));
349
+ for (let v = 0; v < c; v++) {
350
+ const k = n[x * c + v] - e[M * c + v];
351
+ n[x * c + v] += U * O(N * k);
279
352
  }
280
353
  }
281
- _[b] += a[b] / r;
354
+ y[b] += f[b] / r;
282
355
  }
283
356
  }
284
- return e;
357
+ return n;
285
358
  }
286
- function I() {
359
+ function L() {
287
360
  return typeof navigator < "u" && !!navigator.gpu;
288
361
  }
289
- async function Q(e, n = {}) {
362
+ async function te(n, e = {}) {
290
363
  const {
291
364
  nComponents: a = 2,
292
- nNeighbors: d = 15,
293
- minDist: t = 0.1,
294
- spread: h = 1,
295
- hnsw: p = {}
296
- } = n, u = n.nEpochs ?? (e.length > 1e4 ? 200 : 500);
365
+ nNeighbors: f = 15,
366
+ minDist: o = 0.1,
367
+ spread: p = 1,
368
+ hnsw: c = {}
369
+ } = e, l = e.nEpochs ?? (n.length > 1e4 ? 200 : 500);
297
370
  console.time("knn");
298
- const { indices: c, distances: o } = await T(e, d, {
299
- M: p.M ?? 16,
300
- efConstruction: p.efConstruction ?? 200,
301
- efSearch: p.efSearch ?? 50
371
+ const { indices: u, distances: s } = await H(n, f, {
372
+ M: c.M ?? 16,
373
+ efConstruction: c.efConstruction ?? 200,
374
+ efSearch: c.efSearch ?? 50
302
375
  });
303
376
  console.timeEnd("knn"), console.time("fuzzy-set");
304
- const r = L(c, o, d);
377
+ const t = C(u, s, f);
305
378
  console.timeEnd("fuzzy-set");
306
- const { a: s, b: f } = K(t, h), l = Y(r.vals, u), m = e.length, _ = new Float32Array(m * a);
307
- for (let g = 0; g < _.length; g++)
308
- _[g] = Math.random() * 20 - 10;
379
+ const { a: i, b: r } = W(o, p), d = q(t.vals, l), _ = n.length, g = new Float32Array(_ * a);
380
+ for (let y = 0; y < g.length; y++)
381
+ g[y] = Math.random() * 20 - 10;
309
382
  console.time("sgd");
310
- let w;
311
- if (I())
383
+ let h;
384
+ if (L())
312
385
  try {
313
- const g = new j();
314
- await g.init(), w = await g.optimize(
386
+ const y = new D();
387
+ await y.init(), h = await y.optimize(
388
+ g,
389
+ new Uint32Array(t.rows),
390
+ new Uint32Array(t.cols),
391
+ d,
315
392
  _,
316
- new Uint32Array(r.rows),
317
- new Uint32Array(r.cols),
318
- l,
319
- m,
320
393
  a,
321
- u,
322
- { a: s, b: f, gamma: 1, negativeSampleRate: 5 }
394
+ l,
395
+ { a: i, b: r, gamma: 1, negativeSampleRate: 5 }
323
396
  );
324
- } catch (g) {
325
- console.warn("WebGPU SGD failed, falling back to CPU:", g), w = F(_, r, l, m, a, u, { a: s, b: f });
397
+ } catch (y) {
398
+ console.warn("WebGPU SGD failed, falling back to CPU:", y), h = z(g, t, d, _, a, l, { a: i, b: r });
326
399
  }
327
400
  else
328
- w = F(_, r, l, m, a, u, { a: s, b: f });
329
- return console.timeEnd("sgd"), w;
401
+ h = z(g, t, d, _, a, l, { a: i, b: r });
402
+ return console.timeEnd("sgd"), h;
330
403
  }
331
- function K(e, n) {
332
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
404
+ function W(n, e) {
405
+ if (Math.abs(e - 1) < 1e-6 && Math.abs(n - 0.1) < 1e-6)
333
406
  return { a: 1.9292, b: 0.7915 };
334
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
407
+ if (Math.abs(e - 1) < 1e-6 && Math.abs(n - 0) < 1e-6)
335
408
  return { a: 1.8956, b: 0.8006 };
336
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
409
+ if (Math.abs(e - 1) < 1e-6 && Math.abs(n - 0.5) < 1e-6)
337
410
  return { a: 1.5769, b: 0.8951 };
338
- const a = H(e, n);
339
- return { a: V(e, n, a), b: a };
411
+ const a = Z(n, e);
412
+ return { a: $(n, e, a), b: a };
413
+ }
414
+ function Z(n, e) {
415
+ return 1 / (e * 1.2);
340
416
  }
341
- function H(e, n) {
342
- return 1 / (n * 1.2);
417
+ function $(n, e, a) {
418
+ return n < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(n, 2 * a);
343
419
  }
344
- function V(e, n, a) {
345
- return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * a);
420
+ class se {
421
+ constructor(e = {}) {
422
+ S(this, "_nComponents");
423
+ S(this, "_nNeighbors");
424
+ S(this, "_minDist");
425
+ S(this, "_spread");
426
+ S(this, "_nEpochs");
427
+ S(this, "_hnswOpts");
428
+ S(this, "_a");
429
+ S(this, "_b");
430
+ /** The low-dimensional embedding produced by the last fit() call. */
431
+ S(this, "embedding", null);
432
+ S(this, "_hnswIndex", null);
433
+ S(this, "_nTrain", 0);
434
+ this._nComponents = e.nComponents ?? 2, this._nNeighbors = e.nNeighbors ?? 15, this._minDist = e.minDist ?? 0.1, this._spread = e.spread ?? 1, this._nEpochs = e.nEpochs, this._hnswOpts = e.hnsw ?? {};
435
+ const { a, b: f } = W(this._minDist, this._spread);
436
+ this._a = a, this._b = f;
437
+ }
438
+ /**
439
+ * Train UMAP on `vectors`.
440
+ * Stores the resulting embedding in `this.embedding` and retains the HNSW
441
+ * index so that transform() can project new points later.
442
+ * Returns `this` for chaining.
443
+ */
444
+ async fit(e) {
445
+ const a = e.length, f = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: o = 16, efConstruction: p = 200, efSearch: c = 50 } = this._hnswOpts;
446
+ console.time("knn");
447
+ const { knn: l, index: u } = await V(e, this._nNeighbors, {
448
+ M: o,
449
+ efConstruction: p,
450
+ efSearch: c
451
+ });
452
+ this._hnswIndex = u, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
453
+ const s = C(l.indices, l.distances, this._nNeighbors);
454
+ console.timeEnd("fuzzy-set");
455
+ const t = q(s.vals, f), i = new Float32Array(a * this._nComponents);
456
+ for (let r = 0; r < i.length; r++)
457
+ i[r] = Math.random() * 20 - 10;
458
+ if (console.time("sgd"), L())
459
+ try {
460
+ const r = new D();
461
+ await r.init(), this.embedding = await r.optimize(
462
+ i,
463
+ new Uint32Array(s.rows),
464
+ new Uint32Array(s.cols),
465
+ t,
466
+ a,
467
+ this._nComponents,
468
+ f,
469
+ { a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 }
470
+ );
471
+ } catch (r) {
472
+ console.warn("WebGPU SGD failed, falling back to CPU:", r), this.embedding = z(i, s, t, a, this._nComponents, f, {
473
+ a: this._a,
474
+ b: this._b
475
+ });
476
+ }
477
+ else
478
+ this.embedding = z(i, s, t, a, this._nComponents, f, {
479
+ a: this._a,
480
+ b: this._b
481
+ });
482
+ return console.timeEnd("sgd"), this;
483
+ }
484
+ /**
485
+ * Project new (unseen) `vectors` into the embedding space learned by fit().
486
+ * Must be called after fit().
487
+ *
488
+ * The training embedding is kept fixed; only the new-point positions are
489
+ * optimised. Returns a Float32Array of shape [vectors.length × nComponents].
490
+ */
491
+ async transform(e) {
492
+ if (!this._hnswIndex || !this.embedding)
493
+ throw new Error("UMAP.transform() must be called after fit()");
494
+ const a = e.length, f = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), o = Math.max(100, Math.floor(f / 4)), p = this._hnswIndex.searchKnn(e, this._nNeighbors), c = Y(p.indices, p.distances, this._nNeighbors), l = new Uint32Array(c.rows), u = new Uint32Array(c.cols), s = new Float32Array(a), t = new Float32Array(a * this._nComponents);
495
+ for (let r = 0; r < l.length; r++) {
496
+ const d = l[r], _ = u[r], g = c.vals[r];
497
+ s[d] += g;
498
+ for (let h = 0; h < this._nComponents; h++)
499
+ t[d * this._nComponents + h] += g * this.embedding[_ * this._nComponents + h];
500
+ }
501
+ for (let r = 0; r < a; r++)
502
+ if (s[r] > 0)
503
+ for (let d = 0; d < this._nComponents; d++)
504
+ t[r * this._nComponents + d] /= s[r];
505
+ else
506
+ for (let d = 0; d < this._nComponents; d++)
507
+ t[r * this._nComponents + d] = Math.random() * 20 - 10;
508
+ const i = q(c.vals, o);
509
+ return X(
510
+ t,
511
+ this.embedding,
512
+ c,
513
+ i,
514
+ a,
515
+ this._nTrain,
516
+ this._nComponents,
517
+ o,
518
+ { a: this._a, b: this._b }
519
+ );
520
+ }
521
+ /**
522
+ * Convenience method equivalent to `fit(vectors)` followed by
523
+ * `transform(vectors)` — but more efficient because the training embedding
524
+ * is returned directly without a second optimization pass.
525
+ */
526
+ async fit_transform(e) {
527
+ return await this.fit(e), this.embedding;
528
+ }
346
529
  }
347
- function Y(e, n) {
530
+ function q(n, e) {
348
531
  let a = -1 / 0;
349
- for (let t = 0; t < e.length; t++)
350
- e[t] > a && (a = e[t]);
351
- const d = new Float32Array(e.length);
352
- for (let t = 0; t < e.length; t++) {
353
- const h = e[t] / a;
354
- d[t] = h > 0 ? n / h : -1;
532
+ for (let o = 0; o < n.length; o++)
533
+ n[o] > a && (a = n[o]);
534
+ const f = new Float32Array(n.length);
535
+ for (let o = 0; o < n.length; o++) {
536
+ const p = n[o] / a;
537
+ f[o] = p > 0 ? e / p : -1;
355
538
  }
356
- return d;
539
+ return f;
357
540
  }
358
541
  export {
359
- Q as fit,
360
- I as isWebGPUAvailable
542
+ se as UMAP,
543
+ te as fit,
544
+ L as isWebGPUAvailable
361
545
  };
package/dist/umap.d.ts CHANGED
@@ -36,6 +36,60 @@ export declare function findAB(minDist: number, spread: number): {
36
36
  a: number;
37
37
  b: number;
38
38
  };
39
+ /**
40
+ * Stateful UMAP model that supports separate fit / transform / fit_transform.
41
+ *
42
+ * Usage:
43
+ * ```ts
44
+ * const umap = new UMAP({ nNeighbors: 15, nComponents: 2 });
45
+ *
46
+ * // Train on high-dimensional data:
47
+ * await umap.fit(trainVectors);
48
+ * console.log(umap.embedding); // Float32Array [nTrain * nComponents]
49
+ *
50
+ * // Project new points into the same space:
51
+ * const newEmbedding = await umap.transform(testVectors);
52
+ *
53
+ * // Or do both at once:
54
+ * const embedding = await umap.fit_transform(vectors);
55
+ * ```
56
+ */
57
+ export declare class UMAP {
58
+ private readonly _nComponents;
59
+ private readonly _nNeighbors;
60
+ private readonly _minDist;
61
+ private readonly _spread;
62
+ private readonly _nEpochs;
63
+ private readonly _hnswOpts;
64
+ private readonly _a;
65
+ private readonly _b;
66
+ /** The low-dimensional embedding produced by the last fit() call. */
67
+ embedding: Float32Array | null;
68
+ private _hnswIndex;
69
+ private _nTrain;
70
+ constructor(opts?: UMAPOptions);
71
+ /**
72
+ * Train UMAP on `vectors`.
73
+ * Stores the resulting embedding in `this.embedding` and retains the HNSW
74
+ * index so that transform() can project new points later.
75
+ * Returns `this` for chaining.
76
+ */
77
+ fit(vectors: number[][]): Promise<this>;
78
+ /**
79
+ * Project new (unseen) `vectors` into the embedding space learned by fit().
80
+ * Must be called after fit().
81
+ *
82
+ * The training embedding is kept fixed; only the new-point positions are
83
+ * optimised. Returns a Float32Array of shape [vectors.length × nComponents].
84
+ */
85
+ transform(vectors: number[][]): Promise<Float32Array>;
86
+ /**
87
+ * Convenience method: runs `fit(vectors)` and returns the training embedding
88
+ * it produces (`this.embedding`). Unlike calling `transform(vectors)` after
89
+ * `fit(vectors)`, no second optimization pass is run on the same points.
90
+ */
91
+ fit_transform(vectors: number[][]): Promise<Float32Array>;
92
+ }
39
93
  /**
40
94
  * Compute per-edge epoch sampling periods based on edge weights.
41
95
  * Higher-weight edges are sampled more frequently.
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.1.0",
3
+ "version": "0.2.6",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -8,6 +8,10 @@
8
8
  "files": [
9
9
  "dist"
10
10
  ],
11
+ "repository": {
12
+ "type": "git",
13
+ "url": "https://github.com/Achuttarsing/umap-gpu"
14
+ },
11
15
  "scripts": {
12
16
  "build": "vite build && tsc",
13
17
  "dev": "vite",