umap-gpu 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Adrien Chuttarsing
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
package/README.md ADDED
@@ -0,0 +1,88 @@
1
+ # umap-gpu
2
+
3
+ UMAP dimensionality reduction with HNSW k-nearest-neighbor search and WebGPU-accelerated SGD optimization, with a transparent CPU fallback.
4
+
5
+ ## What it does
6
+
7
+ Takes a set of high-dimensional vectors and returns a low-dimensional embedding (default: 2D) suitable for visualization or downstream tasks.
8
+
9
+ The pipeline runs in three stages:
10
+
11
+ 1. **k-NN** — approximate nearest neighbors via [hnswlib-wasm](https://github.com/yoshoku/hnswlib-wasm) (O(n log n))
12
+ 2. **Fuzzy simplicial set** — builds a weighted graph from the k-NN graph using smooth distances
13
+ 3. **SGD** — optimizes the embedding using attraction/repulsion forces:
14
+ - **WebGPU** compute shader when available (see the Browser support table below)
15
+ - **CPU** fallback otherwise — same algorithm, just slower (both paths use random initialization and random negative sampling, so results vary from run to run)
16
+
17
+ ## Install
18
+
19
+ ```bash
20
+ npm install umap-gpu
21
+ ```
22
+
23
+ > Requires a browser or runtime with WebGPU support for GPU acceleration. The CPU fallback works anywhere.
24
+
25
+ ## Usage
26
+
27
+ ```ts
28
+ import { fit } from 'umap-gpu';
29
+
30
+ const vectors = [
31
+ [1.0, 0.0, 0.3],
32
+ [0.9, 0.1, 0.4],
33
+ [0.0, 1.0, 0.8],
34
+ // ...
35
+ ];
36
+
37
+ const embedding = await fit(vectors);
38
+ // Float32Array of length n * nComponents (default: n * 2)
39
+ // embedding[i*2], embedding[i*2 + 1] → 2D coordinates of point i
40
+ ```
41
+
42
+ ### Options
43
+
44
+ ```ts
45
+ const embedding = await fit(vectors, {
46
+ nComponents: 2, // output dimensions (default: 2)
47
+ nNeighbors: 15, // k-NN graph degree (default: 15)
48
+ nEpochs: 500, // SGD iterations (default: 500 for up to 10,000 points, 200 for larger inputs)
49
+ minDist: 0.1, // minimum distance between points in the embedding (default: 0.1)
50
+ spread: 1.0, // scale of the embedding (default: 1.0)
51
+ hnsw: {
52
+ M: 16, // HNSW graph connectivity (default: 16)
53
+ efConstruction: 200, // build-time search width (default: 200)
54
+ efSearch: 50, // query-time search width (default: 50)
55
+ },
56
+ });
57
+ ```
58
+
59
+ ### Checking GPU availability
60
+
61
+ ```ts
62
+ import { isWebGPUAvailable } from 'umap-gpu';
63
+
64
+ if (isWebGPUAvailable()) {
65
+ console.log('Will use WebGPU-accelerated SGD');
66
+ } else {
67
+ console.log('Will fall back to CPU SGD');
68
+ }
69
+ ```
70
+
71
+ ## Build
72
+
73
+ ```bash
74
+ npm run build # compiles TypeScript to dist/
75
+ npm test # runs the unit test suite (Vitest)
76
+ ```
77
+
78
+ ## Browser support
79
+
80
+ | Feature | Requirement |
81
+ |---------|-------------|
82
+ | WebGPU SGD | Chrome 113+, Edge 113+, Safari 18+ |
83
+ | CPU fallback | Any modern browser / Node.js |
84
+ | HNSW (WASM) | Any environment with WebAssembly support |
85
+
86
+ ## License
87
+
88
+ MIT
@@ -0,0 +1,12 @@
1
+ import type { FuzzyGraph } from '../fuzzy-set';
2
+ export interface CPUSgdParams {
3
+ a: number;
4
+ b: number;
5
+ gamma?: number;
6
+ negativeSampleRate?: number;
7
+ }
8
+ /**
9
+ * CPU fallback SGD optimizer for environments without WebGPU.
10
+ * Mirrors the GPU shader logic: per-edge attraction + negative-sample repulsion.
11
+ */
12
+ export declare function cpuSgd(embedding: Float32Array, graph: FuzzyGraph, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: CPUSgdParams): Float32Array;
@@ -0,0 +1,12 @@
1
+ export interface FuzzyGraph {
2
+ rows: Float32Array;
3
+ cols: Float32Array;
4
+ vals: Float32Array;
5
+ nVertices: number;
6
+ }
7
+ /**
8
+ * Compute the fuzzy simplicial set from kNN results.
9
+ * This builds the high-dimensional graph weights using smooth kNN distances
10
+ * (sigmas, rhos) and symmetrizes with the fuzzy set union operation.
11
+ */
12
+ export declare function computeFuzzySimplicialSet(knnIndices: number[][], knnDistances: number[][], nNeighbors: number, setOpMixRatio?: number): FuzzyGraph;
@@ -0,0 +1,12 @@
1
+ /**
2
+ * WebGPU device management — handles adapter/device acquisition and
3
+ * provides a single shared device instance.
4
+ */
5
+ /**
6
+ * Request and cache a WebGPU device. Returns null if WebGPU is not available.
7
+ */
8
+ export declare function getGPUDevice(): Promise<GPUDevice | null>;
9
+ /**
10
+ * Check whether WebGPU is available in the current environment.
11
+ */
12
+ export declare function isWebGPUAvailable(): boolean;
@@ -0,0 +1,30 @@
1
+ export interface SGDParams {
2
+ a: number;
3
+ b: number;
4
+ gamma: number;
5
+ negativeSampleRate: number;
6
+ }
7
+ /**
8
+ * GPU-accelerated SGD optimizer for UMAP embedding.
9
+ * Each GPU thread processes one graph edge, applying attraction and repulsion forces.
10
+ */
11
+ export declare class GPUSgd {
12
+ private device;
13
+ private pipeline;
14
+ init(): Promise<void>;
15
+ /**
16
+ * Run SGD optimization on the GPU.
17
+ *
18
+ * @param embedding - Initial embedding positions [nVertices * nComponents]
19
+ * @param head - Edge source node indices
20
+ * @param tail - Edge target node indices
21
+ * @param epochsPerSample - Per-edge epoch sampling period
22
+ * @param nVertices - Number of data points
23
+ * @param nComponents - Embedding dimensionality (typically 2)
24
+ * @param nEpochs - Total number of optimization epochs
25
+ * @param params - UMAP curve parameters and repulsion settings
26
+ * @returns Optimized embedding as Float32Array
27
+ */
28
+ optimize(embedding: Float32Array, head: Uint32Array, tail: Uint32Array, epochsPerSample: Float32Array, nVertices: number, nComponents: number, nEpochs: number, params: SGDParams): Promise<Float32Array>;
29
+ private makeBuffer;
30
+ }
@@ -0,0 +1,15 @@
1
+ export interface KNNResult {
2
+ indices: number[][];
3
+ distances: number[][];
4
+ }
5
+ export interface HNSWOptions {
6
+ M?: number;
7
+ efConstruction?: number;
8
+ efSearch?: number;
9
+ }
10
+ /**
11
+ * Compute k-nearest neighbors using HNSW (Hierarchical Navigable Small World)
12
+ * via hnswlib-wasm, replacing the O(n^2) brute-force search in umap-js with
13
+ * an O(n log n) approximate nearest neighbor search.
14
+ */
15
+ export declare function computeKNN(vectors: number[][], nNeighbors: number, opts?: HNSWOptions): Promise<KNNResult>;
@@ -0,0 +1,5 @@
1
+ export { fit } from './umap';
2
+ export type { UMAPOptions } from './umap';
3
+ export type { KNNResult, HNSWOptions } from './hnsw-knn';
4
+ export type { FuzzyGraph } from './fuzzy-set';
5
+ export { isWebGPUAvailable } from './gpu/device';
package/dist/index.js ADDED
@@ -0,0 +1,361 @@
1
+ var N = Object.defineProperty;
2
+ var q = (e, n, a) => n in e ? N(e, n, { enumerable: !0, configurable: !0, writable: !0, value: a }) : e[n] = a;
3
+ var O = (e, n, a) => q(e, typeof n != "symbol" ? n + "" : n, a);
4
+ import { loadHnswlib as z } from "hnswlib-wasm";
5
/**
 * Approximate k-nearest-neighbour search over `e` using an HNSW index from
 * hnswlib-wasm configured with the "l2" space.
 *
 * @param e - Input vectors; the dimensionality is taken from the first row.
 * @param n - Number of neighbours wanted per point (the point itself is excluded).
 * @param a - Optional HNSW settings: `M` (default 16), `efConstruction`
 *   (default 200) and `efSearch` (default 50).
 * @returns Promise of `{ indices, distances }` with one row per input vector.
 *   Each row holds at most `n` entries, in the order the index returned them.
 */
async function T(e, n, a = {}) {
  const { M: maxLinks = 16, efConstruction = 200, efSearch = 50 } = a;
  const count = e.length;

  // Nothing to index: return before reading e[0] (which would throw) and
  // before paying for the WASM module load.
  if (count === 0) return { indices: [], distances: [] };

  const lib = await z();
  const dim = e[0].length;
  const index = new lib.HierarchicalNSW("l2", dim, "");
  // NOTE(review): the 4th initIndex argument is passed through unchanged from
  // the original build (presumably the RNG seed) — confirm against hnswlib-wasm.
  index.initIndex(count, maxLinks, efConstruction, 200);
  // The search width must be at least the number of neighbours requested.
  index.setEfSearch(Math.max(efSearch, n));
  index.addItems(e, false);

  // One extra result is requested because the query point itself is normally
  // among its own neighbours. Never ask for more results than the index holds.
  const k = Math.min(n + 1, count);

  const indices = [];
  const distances = [];
  for (let i = 0; i < count; i++) {
    const hit = index.searchKnn(e[i], k, undefined);
    const row = hit.neighbors
      .map((idx, pos) => ({ idx, dist: hit.distances[pos] }))
      .filter(({ idx }) => idx !== i) // drop the self-match
      .slice(0, n);
    indices.push(row.map(({ idx }) => idx));
    distances.push(row.map(({ dist }) => dist));
  }
  return { indices, distances };
}
15
/**
 * Build the fuzzy simplicial set (weighted graph) from kNN results.
 *
 * Each neighbour distance becomes a membership strength: 1 when the distance
 * is within the point's rho, otherwise exp(-(distance - rho) / sigma). The
 * directed edges are then handed to `D` for symmetrisation.
 *
 * @param e - Neighbour indices, one row per point.
 * @param n - Neighbour distances, parallel to `e`.
 * @param a - Number of neighbours used to calibrate sigma.
 * @param d - Set-operation mix ratio forwarded to `D` (default 1).
 * @returns The symmetrised graph from `D` plus `nVertices`.
 */
function L(e, n, a, d = 1) {
  const nVertices = e.length;
  const smooth = C(n, a);
  const sigmas = smooth.sigmas;
  const rhos = smooth.rhos;

  const sources = [];
  const targets = [];
  const strengths = [];

  for (const [point, neighbours] of e.entries()) {
    for (let k = 0; k < neighbours.length; k++) {
      const dist = n[point][k];
      let strength;
      if (dist <= rhos[point]) {
        strength = 1;
      } else {
        strength = Math.exp(-((dist - rhos[point]) / sigmas[point]));
      }
      sources.push(point);
      targets.push(neighbours[k]);
      strengths.push(strength);
    }
  }

  const symmetric = D(sources, targets, strengths, nVertices, d);
  return { ...symmetric, nVertices };
}
24
/**
 * Compute the per-point smoothing parameters used to turn kNN distances into
 * membership strengths.
 *
 * For each point, rho is the first strictly positive neighbour distance (0 if
 * there is none). Sigma is found by bisection (at most 64 steps, tolerance
 * 1e-5) so that
 *
 *     sum_j exp(-max(0, d_j - rho) / sigma) == log2(n)
 *
 * The sum runs over every neighbour in the row. The rows passed in here have
 * already had the self-match removed, so the first entry is a real neighbour
 * and must be counted. (The original loop started at index 1, which skipped
 * the nearest neighbour and calibrated sigma against the wrong sum.)
 *
 * @param e - Neighbour distances, one row per point, self excluded.
 * @param n - Number of neighbours; `log2(n)` is the calibration target.
 * @returns `{ sigmas, rhos }` as Float32Arrays with one entry per point.
 */
function C(e, n) {
  const count = e.length;
  const sigmas = new Float32Array(count);
  const rhos = new Float32Array(count);
  const target = Math.log2(n); // loop-invariant, hoisted out of the per-point loop

  for (let i = 0; i < count; i++) {
    const row = e[i];
    rhos[i] = row.find((dist) => dist > 0) ?? 0;

    let lo = 0;
    let hi = Infinity;
    let sigma = 1;
    for (let step = 0; step < 64; step++) {
      let sum = 0;
      for (let j = 0; j < row.length; j++) {
        sum += Math.exp(-Math.max(0, row[j] - rhos[i]) / sigma);
      }
      if (Math.abs(sum - target) < 1e-5) break;
      if (sum > target) {
        // Too much mass: shrink sigma.
        hi = sigma;
        sigma = (lo + hi) / 2;
      } else {
        // Too little mass: grow sigma, doubling until an upper bound exists.
        lo = sigma;
        sigma = hi === Infinity ? sigma * 2 : (lo + hi) / 2;
      }
    }
    sigmas[i] = sigma;
  }
  return { sigmas, rhos };
}
42
/**
 * Symmetrise a directed weighted graph with the fuzzy set union/intersection
 * mix used by UMAP.
 *
 * For every ordered pair (i, j) that appears in the input in either
 * direction, with directed weights w = W[i][j] and wT = W[j][i] (0 when the
 * edge is absent), the output weight is
 *
 *     t * (w + wT - w * wT) + (1 - t) * (w * wT)
 *
 * so `t = 1` is the pure fuzzy union and `t = 0` the pure intersection.
 * The original code derived this from the sum `w + wT` alone, which gave a
 * weight of 0 to two mutual neighbours that each had weight 1.
 *
 * Output entries appear in first-seen order: (i, j) then (j, i) for each
 * input edge, without repeats. This matches the original ordering.
 *
 * @param e - Edge source indices.
 * @param n - Edge target indices.
 * @param a - Directed edge weights, parallel to `e` and `n`.
 * @param d - Number of vertices (not used by the computation).
 * @param t - Set-operation mix ratio in [0, 1].
 * @returns `{ rows, cols, vals }` as Float32Arrays.
 */
function D(e, n, a, d, t) {
  // Directed weights keyed "row,col". Duplicates accumulate, as they would in
  // a sparse coordinate-format matrix.
  const directed = new Map();
  for (let k = 0; k < e.length; k++) {
    const key = `${e[k]},${n[k]}`;
    directed.set(key, (directed.get(key) ?? 0) + a[k]);
  }

  const rows = [];
  const cols = [];
  const vals = [];
  const emitted = new Set();

  const emit = (i, j) => {
    const key = `${i},${j}`;
    if (emitted.has(key)) return;
    emitted.add(key);
    const w = directed.get(key) ?? 0;
    const wT = directed.get(`${j},${i}`) ?? 0;
    const product = w * wT;
    rows.push(i);
    cols.push(j);
    vals.push(t * (w + wT - product) + (1 - t) * product);
  };

  for (let k = 0; k < e.length; k++) {
    emit(e[k], n[k]);
    emit(n[k], e[k]);
  }

  return {
    rows: new Float32Array(rows),
    cols: new Float32Array(cols),
    vals: new Float32Array(vals)
  };
}
62
// WGSL source of the SGD kernel. The binding numbers and the field order of
// `Params` are part of the contract with the host code that fills the 40-byte
// uniform buffer, so they must not change.
const W = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
// Applies attraction forces between connected nodes and repulsion forces
// against negative samples.

@group(0) @binding(0) var<storage, read>       epochs_per_sample : array<f32>;
@group(0) @binding(1) var<storage, read>       head : array<u32>;  // edge source
@group(0) @binding(2) var<storage, read>       tail : array<u32>;  // edge target
@group(0) @binding(3) var<storage, read_write> embedding : array<f32>;  // [n * nComponents]
@group(0) @binding(4) var<storage, read_write> epoch_of_next_sample : array<f32>;
@group(0) @binding(5) var<storage, read_write> epoch_of_next_negative_sample : array<f32>;
@group(0) @binding(6) var<uniform>             params : Params;
@group(0) @binding(7) var<storage, read>       rng_seeds : array<u32>;  // per-edge seed

struct Params {
  n_edges              : u32,
  n_vertices           : u32,
  n_components         : u32,
  current_epoch        : u32,
  n_epochs             : u32,
  alpha                : f32,   // learning rate
  a                    : f32,
  b                    : f32,
  gamma                : f32,   // repulsion strength
  negative_sample_rate : u32,
}

fn clip(v: f32, lo: f32, hi: f32) -> f32 {
  return max(lo, min(hi, v));
}

// Simple xorshift RNG per thread. Maps 0 to 0, so callers must avoid a zero state.
fn xorshift(seed: u32) -> u32 {
  var s = seed;
  s ^= s << 13u;
  s ^= s >> 17u;
  s ^= s << 5u;
  return s;
}

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
  let edge_idx = gid.x;
  if (edge_idx >= params.n_edges) { return; }

  // A non-positive sampling period marks an edge that is never sampled.
  let eps = epochs_per_sample[edge_idx];
  if (eps <= 0.0) { return; }

  // Only process this edge if its epoch has come
  let epoch = f32(params.current_epoch);
  if (epoch_of_next_sample[edge_idx] > epoch) { return; }

  let i = head[edge_idx];
  let j = tail[edge_idx];
  let nc = params.n_components;

  // --- Attraction ---
  var dist_sq : f32 = 0.0;
  for (var d = 0u; d < nc; d++) {
    let diff = embedding[i * nc + d] - embedding[j * nc + d];
    dist_sq += diff * diff;
  }

  // pow(0, b - 1) is not finite for b < 1; coincident points feel no attraction.
  var grad_coeff_attr : f32 = 0.0;
  if (dist_sq > 0.0) {
    grad_coeff_attr = -2.0 * params.a * params.b * pow(dist_sq, params.b - 1.0)
                    / (params.a * pow(dist_sq, params.b) + 1.0);
  }

  for (var d = 0u; d < nc; d++) {
    let diff = embedding[i * nc + d] - embedding[j * nc + d];
    let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
    embedding[i * nc + d] += params.alpha * grad;
  }

  epoch_of_next_sample[edge_idx] += eps;

  // --- Repulsion (negative samples) ---
  // Draw as many negative samples as have become due since the last visit:
  // (current epoch - next due epoch) / epochs per negative sample.
  var n_neg = 0u;
  var eps_neg : f32 = 0.0;
  if (params.negative_sample_rate > 0u) {
    eps_neg = eps / f32(params.negative_sample_rate);
    let due = (epoch - epoch_of_next_negative_sample[edge_idx]) / eps_neg;
    n_neg = u32(max(0.0, floor(due)));
  }

  // The multiplier must fit in u32; the product wraps, which is fine for mixing.
  var rng = rng_seeds[edge_idx] + params.current_epoch * 2654435761u;
  if (rng == 0u) { rng = 2654435769u; }
  rng = xorshift(rng);

  for (var s = 0u; s < n_neg; s++) {
    rng = xorshift(rng);
    let k = rng % params.n_vertices;
    if (k == i) { continue; }

    var neg_dist_sq : f32 = 0.0;
    for (var d = 0u; d < nc; d++) {
      let diff = embedding[i * nc + d] - embedding[k * nc + d];
      neg_dist_sq += diff * diff;
    }

    let grad_coeff_rep = 2.0 * params.gamma * params.b
                       / ((0.001 + neg_dist_sq) * (params.a * pow(neg_dist_sq, params.b) + 1.0));

    for (var d = 0u; d < nc; d++) {
      let diff = embedding[i * nc + d] - embedding[k * nc + d];
      let grad = clip(grad_coeff_rep * diff, -4.0, 4.0);
      embedding[i * nc + d] += params.alpha * grad;
    }
  }

  epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * eps_neg;
}
`;
165
/**
 * GPU-accelerated SGD optimizer for UMAP embeddings.
 * Each GPU thread processes one graph edge, applying attraction and repulsion forces.
 *
 * Usage: `await sgd.init()` once, then `await sgd.optimize(...)`.
 */
class j {
  constructor() {
    // Both fields are populated by init().
    O(this, "device");
    O(this, "pipeline");
  }

  /**
   * Acquire a WebGPU device and compile the SGD pipeline.
   *
   * The pipeline is created with the async variant so that an invalid shader
   * or pipeline rejects here. The synchronous variant only reports such
   * problems as validation errors, which the caller's try/catch never sees.
   *
   * @throws Error when no adapter is available; rejects when pipeline
   *   creation fails.
   */
  async init() {
    const adapter = await navigator.gpu.requestAdapter();
    if (!adapter) throw new Error("WebGPU not supported");
    this.device = await adapter.requestDevice();
    this.pipeline = await this.device.createComputePipelineAsync({
      layout: "auto",
      compute: {
        module: this.device.createShaderModule({ code: W }),
        entryPoint: "main"
      }
    });
  }

  /**
   * Run SGD optimization on the GPU.
   *
   * @param n - Initial embedding positions [nVertices * nComponents]; not modified.
   * @param a - Edge source node indices (Uint32Array).
   * @param d - Edge target node indices (Uint32Array).
   * @param t - Per-edge epoch sampling period (Float32Array).
   * @param h - Number of data points.
   * @param p - Embedding dimensionality (typically 2).
   * @param u - Total number of optimization epochs.
   * @param c - UMAP curve parameters and repulsion settings
   *   (`a`, `b`, `gamma`, `negativeSampleRate`).
   * @returns Optimized embedding as a new Float32Array.
   * @throws Error when called before `init()` has completed.
   */
  async optimize(n, a, d, t, h, p, u, c) {
    const device = this.device;
    const pipeline = this.pipeline;
    if (!device || !pipeline) {
      throw new Error("GPUSgd.init() must complete before optimize()");
    }

    const nEdges = a.length;
    // With no edges or no points there is nothing to dispatch, and zero-sized
    // buffers cannot be bound. Return a copy, as the normal path does.
    if (nEdges === 0 || n.length === 0) return new Float32Array(n);

    // Every GPU buffer is tracked so it is destroyed even if a step throws.
    const owned = [];
    const own = (buffer) => {
      owned.push(buffer);
      return buffer;
    };

    try {
      const embeddingBuf = own(
        this.makeBuffer(n, GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC)
      );
      const headBuf = own(this.makeBuffer(a, GPUBufferUsage.STORAGE));
      const tailBuf = own(this.makeBuffer(d, GPUBufferUsage.STORAGE));
      const periodBuf = own(this.makeBuffer(t, GPUBufferUsage.STORAGE));

      // All zeros: every edge is due at epoch 0.
      const nextSample = new Float32Array(nEdges);
      const nextSampleBuf = own(this.makeBuffer(nextSample, GPUBufferUsage.STORAGE));

      const nextNegative = new Float32Array(nEdges);
      for (let k = 0; k < nEdges; k++) {
        nextNegative[k] = t[k] / c.negativeSampleRate;
      }
      const nextNegativeBuf = own(this.makeBuffer(nextNegative, GPUBufferUsage.STORAGE));

      // Independent RNG seed per edge; the store into Uint32Array wraps to u32.
      const seeds = new Uint32Array(nEdges);
      for (let k = 0; k < nEdges; k++) {
        seeds[k] = Math.random() * 4294967295 | 0;
      }
      const seedBuf = own(this.makeBuffer(seeds, GPUBufferUsage.STORAGE));

      // 10 four-byte fields, in the same order as the shader's Params struct.
      const paramsBuf = own(device.createBuffer({
        size: 40,
        usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
      }));
      const paramBytes = new ArrayBuffer(40);
      const asU32 = new Uint32Array(paramBytes);
      const asF32 = new Float32Array(paramBytes);
      asU32[0] = nEdges;
      asU32[1] = h;
      asU32[2] = p;
      asU32[4] = u;
      asF32[6] = c.a;
      asF32[7] = c.b;
      asF32[8] = c.gamma;
      asU32[9] = c.negativeSampleRate;

      // The bind group only references buffers, so one instance serves all epochs.
      const bindGroup = device.createBindGroup({
        layout: pipeline.getBindGroupLayout(0),
        entries: [
          { binding: 0, resource: { buffer: periodBuf } },
          { binding: 1, resource: { buffer: headBuf } },
          { binding: 2, resource: { buffer: tailBuf } },
          { binding: 3, resource: { buffer: embeddingBuf } },
          { binding: 4, resource: { buffer: nextSampleBuf } },
          { binding: 5, resource: { buffer: nextNegativeBuf } },
          { binding: 6, resource: { buffer: paramsBuf } },
          { binding: 7, resource: { buffer: seedBuf } }
        ]
      });
      const workgroups = Math.ceil(nEdges / 256); // matches @workgroup_size(256)

      for (let epoch = 0; epoch < u; epoch++) {
        // Only the epoch counter and the linearly decaying learning rate change.
        asU32[3] = epoch;
        asF32[5] = 1 - epoch / u;
        device.queue.writeBuffer(paramsBuf, 0, paramBytes);

        const encoder = device.createCommandEncoder();
        const pass = encoder.beginComputePass();
        pass.setPipeline(pipeline);
        pass.setBindGroup(0, bindGroup);
        pass.dispatchWorkgroups(workgroups);
        pass.end();
        device.queue.submit([encoder.finish()]);

        // Drain the queue every 10 epochs so submissions do not pile up.
        if (epoch % 10 === 0) await device.queue.onSubmittedWorkDone();
      }

      // Copy the result into a mappable buffer and read it back.
      const readback = own(device.createBuffer({
        size: n.byteLength,
        usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
      }));
      const copyEncoder = device.createCommandEncoder();
      copyEncoder.copyBufferToBuffer(embeddingBuf, 0, readback, 0, n.byteLength);
      device.queue.submit([copyEncoder.finish()]);
      await readback.mapAsync(GPUMapMode.READ);
      // slice(0) detaches the result from the mapping before unmap().
      const result = new Float32Array(readback.getMappedRange().slice(0));
      readback.unmap();
      return result;
    } finally {
      for (const buffer of owned) buffer.destroy();
    }
  }

  /**
   * Create a GPU buffer initialised with the contents of a typed array.
   *
   * @param n - Float32Array or Uint32Array holding the initial data.
   * @param a - GPUBufferUsage flags for the new buffer.
   * @returns The unmapped, ready-to-use GPUBuffer.
   */
  makeBuffer(n, a) {
    const buffer = this.device.createBuffer({
      size: n.byteLength,
      usage: a,
      mappedAtCreation: true
    });
    const mapped = buffer.getMappedRange();
    if (n instanceof Float32Array) {
      new Float32Array(mapped).set(n);
    } else {
      new Uint32Array(mapped).set(n);
    }
    buffer.unmap();
    return buffer;
  }
}
243
/**
 * CPU fallback SGD optimizer for environments without WebGPU.
 * Per-edge attraction plus negative-sample repulsion, updating `e` in place.
 *
 * Fixes relative to the original:
 *  - The number of negative samples is derived from how many have become due
 *    since the edge was last visited,
 *    `(epoch - nextNegative) / epochsPerNegativeSample`. The original divided
 *    the edge period by an ever-growing counter, so repulsion faded to zero
 *    after a handful of visits.
 *  - Coincident endpoints no longer produce NaN (`pow(0, b - 1)` is infinite
 *    for b < 1).
 *  - Edges with a non-positive sampling period are never sampled, instead of
 *    being sampled on every epoch.
 *
 * @param e - Embedding [nVertices * nComponents]; mutated and returned.
 * @param n - Graph with `rows` / `cols` edge endpoints.
 * @param a - Per-edge epoch sampling period.
 * @param d - Number of vertices (range of the negative sampler).
 * @param t - Embedding dimensionality.
 * @param h - Number of epochs.
 * @param p - `{ a, b, gamma = 1, negativeSampleRate = 5 }`.
 * @returns The same `e` instance, optimized.
 */
function F(e, n, a, d, t, h, p) {
  const { a: curveA, b: curveB, gamma = 1, negativeSampleRate = 5 } = p;
  const nEdges = n.rows.length;
  const heads = new Uint32Array(n.rows);
  const tails = new Uint32Array(n.cols);

  // All zeros: every edge is due at epoch 0.
  const nextSample = new Float32Array(nEdges);
  const nextNegative = new Float32Array(nEdges);
  for (let k = 0; k < nEdges; k++) {
    nextNegative[k] = a[k] / negativeSampleRate;
  }

  const clip = (v) => Math.max(-4, Math.min(4, v));

  for (let epoch = 0; epoch < h; epoch++) {
    const alpha = 1 - epoch / h; // linearly decaying learning rate

    for (let k = 0; k < nEdges; k++) {
      const period = a[k];
      if (!(period > 0)) continue; // non-positive period: never sampled
      if (nextSample[k] > epoch) continue;

      const src = heads[k] * t;
      const dst = tails[k] * t;

      // --- Attraction ---
      let distSq = 0;
      for (let c = 0; c < t; c++) {
        const diff = e[src + c] - e[dst + c];
        distSq += diff * diff;
      }
      if (distSq > 0) {
        const coeff =
          (-2 * curveA * curveB * Math.pow(distSq, curveB - 1)) /
          (curveA * Math.pow(distSq, curveB) + 1);
        for (let c = 0; c < t; c++) {
          const diff = e[src + c] - e[dst + c];
          e[src + c] += alpha * clip(coeff * diff);
        }
      }
      nextSample[k] += period;

      // --- Repulsion (negative samples) ---
      const negPeriod = period / negativeSampleRate;
      const schedulable = Number.isFinite(negPeriod) && negPeriod > 0;
      const nNeg = schedulable
        ? Math.max(0, Math.floor((epoch - nextNegative[k]) / negPeriod))
        : 0;

      for (let s = 0; s < nNeg; s++) {
        const other = Math.floor(Math.random() * d);
        if (other === heads[k]) continue;
        const neg = other * t;

        let negDistSq = 0;
        for (let c = 0; c < t; c++) {
          const diff = e[src + c] - e[neg + c];
          negDistSq += diff * diff;
        }
        const coeff =
          (2 * gamma * curveB) /
          ((1e-3 + negDistSq) * (curveA * Math.pow(negDistSq, curveB) + 1));
        for (let c = 0; c < t; c++) {
          const diff = e[src + c] - e[neg + c];
          e[src + c] += alpha * clip(coeff * diff);
        }
      }
      if (schedulable) nextNegative[k] += nNeg * negPeriod;
    }
  }
  return e;
}
286
/**
 * Report whether WebGPU is exposed in the current environment, i.e. a global
 * `navigator` exists and carries a truthy `gpu` property.
 */
function I() {
  if (typeof navigator === "undefined") {
    return false;
  }
  return Boolean(navigator.gpu);
}
289
/**
 * Fit UMAP to the given high-dimensional vectors and return a low-dimensional embedding.
 *
 * Pipeline:
 *  1. HNSW k-nearest-neighbour search (`T`)
 *  2. Fuzzy simplicial set construction (`L`)
 *  3. SGD optimization: WebGPU when available, CPU otherwise or on GPU failure
 *
 * @param e - Input vectors, one row per point.
 * @param n - Options: `nComponents` (2), `nNeighbors` (15), `nEpochs`
 *   (500 for up to 10,000 points, otherwise 200), `minDist` (0.1),
 *   `spread` (1) and `hnsw` (`M` 16, `efConstruction` 200, `efSearch` 50).
 * @returns Float32Array of length `e.length * nComponents`; empty when `e` is empty.
 */
async function Q(e, n = {}) {
  const {
    nComponents = 2,
    nNeighbors = 15,
    minDist = 0.1,
    spread = 1,
    hnsw = {}
  } = n;

  const nPoints = e.length;
  // Nothing to embed. Without this guard the kNN stage reads e[0] and throws.
  if (nPoints === 0) return new Float32Array(0);

  const nEpochs = n.nEpochs ?? (nPoints > 1e4 ? 200 : 500);

  console.time("knn");
  const { indices, distances } = await T(e, nNeighbors, {
    M: hnsw.M ?? 16,
    efConstruction: hnsw.efConstruction ?? 200,
    efSearch: hnsw.efSearch ?? 50
  });
  console.timeEnd("knn");

  console.time("fuzzy-set");
  const graph = L(indices, distances, nNeighbors);
  console.timeEnd("fuzzy-set");

  const { a, b } = K(minDist, spread);
  const epochsPerSample = Y(graph.vals, nEpochs);

  // Random initial layout, uniform in [-10, 10) per coordinate.
  const initial = new Float32Array(nPoints * nComponents);
  for (let k = 0; k < initial.length; k++) {
    initial[k] = Math.random() * 20 - 10;
  }

  console.time("sgd");
  let result;
  if (I()) {
    try {
      const gpu = new j();
      await gpu.init();
      result = await gpu.optimize(
        initial,
        new Uint32Array(graph.rows),
        new Uint32Array(graph.cols),
        epochsPerSample,
        nPoints,
        nComponents,
        nEpochs,
        { a, b, gamma: 1, negativeSampleRate: 5 }
      );
    } catch (error) {
      // The GPU path never mutates `initial`, so the CPU run starts from the same layout.
      console.warn("WebGPU SGD failed, falling back to CPU:", error);
      result = F(initial, graph, epochsPerSample, nPoints, nComponents, nEpochs, { a, b });
    }
  } else {
    result = F(initial, graph, epochsPerSample, nPoints, nComponents, nEpochs, { a, b });
  }
  console.timeEnd("sgd");
  return result;
}
331
/**
 * Compute the a, b parameters for the UMAP curve 1/(1 + a*d^(2b)).
 *
 * Pre-fitted constants are returned when spread is 1 and minDist is 0.1, 0 or
 * 0.5 (each within 1e-6). Every other combination uses the approximations
 * provided by `H` (for b) and `V` (for a).
 *
 * @param e - minDist.
 * @param n - spread.
 * @returns `{ a, b }`.
 */
function K(e, n) {
  const near = (value, wanted) => Math.abs(value - wanted) < 1e-6;

  if (near(n, 1)) {
    // [minDist, a, b] triples, checked in the same order as before.
    const presets = [
      [0.1, 1.9292, 0.7915],
      [0, 1.8956, 0.8006],
      [0.5, 1.5769, 0.8951]
    ];
    for (const [presetMinDist, a, b] of presets) {
      if (near(e, presetMinDist)) {
        return { a, b };
      }
    }
  }

  const b = H(e, n);
  const a = V(e, n, b);
  return { a, b };
}
341
/**
 * Approximate the curve exponent b as 1 / (1.2 * spread).
 *
 * @param e - minDist; accepted for symmetry with the caller but not used.
 * @param n - spread.
 * @returns The approximate b parameter.
 */
function H(e, n) {
  const scaledSpread = n * 1.2;
  return 1 / scaledSpread;
}
344
/**
 * Approximate the curve scale a for a given minDist and exponent b.
 *
 * @param e - minDist; values below 1e-6 return the fixed constant 1.8956.
 * @param n - spread; not used by the approximation.
 * @param a - The b exponent already chosen by the caller.
 * @returns (1 / 1.001 - 1) / -(minDist ^ (2 * b)), or 1.8956 for tiny minDist.
 */
function V(e, n, a) {
  if (e < 1e-6) {
    return 1.8956;
  }
  const numerator = 1 / (1 + 1e-3) - 1;
  const denominator = -Math.pow(e, 2 * a);
  return numerator / denominator;
}
347
/**
 * Compute per-edge epoch sampling periods based on edge weights.
 * Higher-weight edges are sampled more frequently: the heaviest edge gets a
 * period of `n`, and lighter edges get `n / (weight / maxWeight)`.
 * Edges whose relative weight is not positive get -1.
 *
 * @param e - Edge weights (Float32Array).
 * @param n - Total number of epochs.
 * @returns Float32Array of sampling periods, parallel to `e`.
 */
function Y(e, n) {
  let heaviest = -Infinity;
  for (const weight of e) {
    if (weight > heaviest) {
      heaviest = weight;
    }
  }

  return Float32Array.from(e, (weight) => {
    const relative = weight / heaviest;
    if (relative > 0) {
      return n / relative;
    }
    return -1;
  });
}
358
+ export {
359
+ Q as fit,
360
+ I as isWebGPUAvailable
361
+ };
package/dist/umap.d.ts ADDED
@@ -0,0 +1,43 @@
1
+ export interface UMAPOptions {
2
+ /** Embedding dimensionality (default: 2) */
3
+ nComponents?: number;
4
+ /** Number of nearest neighbors (default: 15) */
5
+ nNeighbors?: number;
6
+ /** Number of optimization epochs (default: auto based on dataset size) */
7
+ nEpochs?: number;
8
+ /** Minimum distance in the embedding (default: 0.1) */
9
+ minDist?: number;
10
+ /** Spread of the embedding (default: 1.0) */
11
+ spread?: number;
12
+ /** HNSW index parameters */
13
+ hnsw?: {
14
+ M?: number;
15
+ efConstruction?: number;
16
+ efSearch?: number;
17
+ };
18
+ }
19
+ /**
20
+ * Fit UMAP to the given high-dimensional vectors and return a low-dimensional embedding.
21
+ *
22
+ * Pipeline:
23
+ * 1. HNSW k-nearest neighbor search (O(n log n) via hnswlib-wasm)
24
+ * 2. Fuzzy simplicial set construction (graph weights)
25
+ * 3. SGD optimization (WebGPU accelerated, with CPU fallback)
26
+ */
27
+ export declare function fit(vectors: number[][], opts?: UMAPOptions): Promise<Float32Array>;
28
+ /**
29
+ * Compute the a, b parameters for the UMAP curve 1/(1 + a*d^(2b)).
30
+ *
31
+ * For arbitrary minDist/spread values, a proper implementation would use
32
+ * Levenberg-Marquardt curve fitting. Here we provide pre-fitted constants
33
+ * for common parameter combinations plus an approximation for others.
34
+ */
35
+ export declare function findAB(minDist: number, spread: number): {
36
+ a: number;
37
+ b: number;
38
+ };
39
+ /**
40
+ * Compute per-edge epoch sampling periods based on edge weights.
41
+ * Higher-weight edges are sampled more frequently.
42
+ */
43
+ export declare function computeEpochsPerSample(weights: Float32Array, nEpochs: number): Float32Array;
package/package.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "name": "umap-gpu",
3
+ "version": "0.1.0",
4
+ "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
+ "type": "module",
6
+ "main": "dist/index.js",
7
+ "types": "dist/index.d.ts",
8
+ "files": [
9
+ "dist"
10
+ ],
11
+ "scripts": {
12
+ "build": "vite build && tsc",
13
+ "dev": "vite",
14
+ "test": "vitest run",
15
+ "prepublishOnly": "npm test && npm run build"
16
+ },
17
+ "dependencies": {
18
+ "hnswlib-wasm": "^0.8.2"
19
+ },
20
+ "devDependencies": {
21
+ "@webgpu/types": "^0.1.40",
22
+ "typescript": "^5.4.0",
23
+ "vite": "^5.0.0",
24
+ "vitest": "^4.0.18"
25
+ },
26
+ "license": "MIT"
27
+ }