umap-gpu 0.2.10 → 0.2.13

This diff shows the differences between two publicly released versions of a package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their respective public registries.
Files changed (3):
  1. package/README.md +2 -0
  2. package/dist/index.js +205 -200
  3. package/package.json +1 -1
package/README.md CHANGED
@@ -1,3 +1,5 @@
1
+ # ⚠️ NOT READY YET | WORK IN PROGRESS ⚠️
2
+
1
3
  # umap-gpu
2
4
 
3
5
  UMAP dimensionality reduction with WebGPU-accelerated SGD and HNSW approximate nearest neighbors.
package/dist/index.js CHANGED
@@ -1,93 +1,93 @@
1
- var H = Object.defineProperty;
2
- var V = (e, t, r) => t in e ? H(e, t, { enumerable: !0, configurable: !0, writable: !0, value: r }) : e[t] = r;
3
- var G = (e, t, r) => V(e, typeof t != "symbol" ? t + "" : t, r);
4
- import { loadHnswlib as C } from "hnswlib-wasm";
5
- async function Y(e, t, r = {}) {
6
- const { M: s = 16, efConstruction: n = 200, efSearch: p = 50 } = r, f = await C(), d = e[0].length, c = e.length, a = new f.HierarchicalNSW("l2", d, "");
7
- a.initIndex(c, s, n, 200), a.setEfSearch(Math.max(p, t)), a.addItems(e, !1);
1
+ var V = Object.defineProperty;
2
+ var Y = (e, t, r) => t in e ? V(e, t, { enumerable: !0, configurable: !0, writable: !0, value: r }) : e[t] = r;
3
+ var G = (e, t, r) => Y(e, typeof t != "symbol" ? t + "" : t, r);
4
+ import { loadHnswlib as D } from "hnswlib-wasm";
5
+ async function Q(e, t, r = {}) {
6
+ const { M: s = 16, efConstruction: n = 200, efSearch: g = 50 } = r, c = await D(), d = e[0].length, f = e.length, a = new c.HierarchicalNSW("l2", d, "");
7
+ a.initIndex(f, s, n, 200), a.setEfSearch(Math.max(g, t)), a.addItems(e, !1);
8
8
  const o = [], i = [];
9
- for (let h = 0; h < c; h++) {
10
- const l = a.searchKnn(e[h], t + 1, void 0), u = l.neighbors.map((g, m) => ({ idx: g, dist: l.distances[m] })).filter(({ idx: g }) => g !== h).slice(0, t);
11
- o.push(u.map(({ idx: g }) => g)), i.push(u.map(({ dist: g }) => g));
9
+ for (let h = 0; h < f; h++) {
10
+ const l = a.searchKnn(e[h], t + 1, void 0), u = l.neighbors.map((p, m) => ({ idx: p, dist: l.distances[m] })).filter(({ idx: p }) => p !== h).slice(0, t);
11
+ o.push(u.map(({ idx: p }) => p)), i.push(u.map(({ dist: p }) => p));
12
12
  }
13
13
  return { indices: o, distances: i };
14
14
  }
15
- async function Q(e, t, r = {}) {
16
- const { M: s = 16, efConstruction: n = 200, efSearch: p = 50 } = r, f = await C(), d = e[0].length, c = e.length, a = new f.HierarchicalNSW("l2", d, "");
17
- a.initIndex(c, s, n, 200), a.setEfSearch(Math.max(p, t)), a.addItems(e, !1);
15
+ async function J(e, t, r = {}) {
16
+ const { M: s = 16, efConstruction: n = 200, efSearch: g = 50 } = r, c = await D(), d = e[0].length, f = e.length, a = new c.HierarchicalNSW("l2", d, "");
17
+ a.initIndex(f, s, n, 200), a.setEfSearch(Math.max(g, t)), a.addItems(e, !1);
18
18
  const o = [], i = [];
19
- for (let l = 0; l < c; l++) {
20
- const u = a.searchKnn(e[l], t + 1, void 0), g = u.neighbors.map((m, w) => ({ idx: m, dist: u.distances[w] })).filter(({ idx: m }) => m !== l).slice(0, t);
21
- o.push(g.map(({ idx: m }) => m)), i.push(g.map(({ dist: m }) => m));
19
+ for (let l = 0; l < f; l++) {
20
+ const u = a.searchKnn(e[l], t + 1, void 0), p = u.neighbors.map((m, _) => ({ idx: m, dist: u.distances[_] })).filter(({ idx: m }) => m !== l).slice(0, t);
21
+ o.push(p.map(({ idx: m }) => m)), i.push(p.map(({ dist: m }) => m));
22
22
  }
23
23
  return { knn: { indices: o, distances: i }, index: {
24
24
  searchKnn(l, u) {
25
- const g = [], m = [];
26
- for (const w of l) {
27
- const b = a.searchKnn(w, u, void 0), y = b.neighbors.map((v, x) => ({ idx: v, dist: b.distances[x] })).sort((v, x) => v.dist - x.dist).slice(0, u);
28
- g.push(y.map(({ idx: v }) => v)), m.push(y.map(({ dist: v }) => v));
25
+ const p = [], m = [];
26
+ for (const _ of l) {
27
+ const w = a.searchKnn(_, u, void 0), b = w.neighbors.map((A, x) => ({ idx: A, dist: w.distances[x] })).sort((A, x) => A.dist - x.dist).slice(0, u);
28
+ p.push(b.map(({ idx: A }) => A)), m.push(b.map(({ dist: A }) => A));
29
29
  }
30
- return { indices: g, distances: m };
30
+ return { indices: p, distances: m };
31
31
  }
32
32
  } };
33
33
  }
34
- function D(e, t, r, s = 1) {
35
- const n = e.length, { sigmas: p, rhos: f } = L(t, r), d = [], c = [], a = [];
34
+ function L(e, t, r, s = 1) {
35
+ const n = e.length, { sigmas: g, rhos: c } = W(t, r), d = [], f = [], a = [];
36
36
  for (let i = 0; i < n; i++)
37
37
  for (let h = 0; h < e[i].length; h++) {
38
- const l = t[i][h], u = l <= f[i] ? 1 : Math.exp(-((l - f[i]) / p[i]));
39
- d.push(i), c.push(e[i][h]), a.push(u);
38
+ const l = t[i][h], u = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / g[i]));
39
+ d.push(i), f.push(e[i][h]), a.push(u);
40
40
  }
41
- return { ...X(d, c, a, n, s), nVertices: n };
41
+ return { ...Z(d, f, a, n, s), nVertices: n };
42
42
  }
43
- function J(e, t, r) {
44
- const s = e.length, { sigmas: n, rhos: p } = L(t, r), f = [], d = [], c = [];
43
+ function X(e, t, r) {
44
+ const s = e.length, { sigmas: n, rhos: g } = W(t, r), c = [], d = [], f = [];
45
45
  for (let a = 0; a < s; a++)
46
46
  for (let o = 0; o < e[a].length; o++) {
47
- const i = t[a][o], h = i <= p[a] ? 1 : Math.exp(-((i - p[a]) / n[a]));
48
- f.push(a), d.push(e[a][o]), c.push(h);
47
+ const i = t[a][o], h = i <= g[a] ? 1 : Math.exp(-((i - g[a]) / n[a]));
48
+ c.push(a), d.push(e[a][o]), f.push(h);
49
49
  }
50
50
  return {
51
- rows: new Float32Array(f),
51
+ rows: new Float32Array(c),
52
52
  cols: new Float32Array(d),
53
- vals: new Float32Array(c),
53
+ vals: new Float32Array(f),
54
54
  nVertices: s
55
55
  };
56
56
  }
57
- function L(e, t) {
58
- const s = e.length, n = new Float32Array(s), p = new Float32Array(s);
59
- for (let f = 0; f < s; f++) {
60
- const d = e[f];
61
- p[f] = d.find((h) => h > 0) ?? 0;
62
- let c = 0, a = 1 / 0, o = 1;
57
+ function W(e, t) {
58
+ const s = e.length, n = new Float32Array(s), g = new Float32Array(s);
59
+ for (let c = 0; c < s; c++) {
60
+ const d = e[c];
61
+ g[c] = d.find((h) => h > 0) ?? 0;
62
+ let f = 0, a = 1 / 0, o = 1;
63
63
  const i = Math.log2(t);
64
64
  for (let h = 0; h < 64; h++) {
65
65
  let l = 0;
66
- for (let u = 1; u < d.length; u++)
67
- l += Math.exp(-Math.max(0, d[u] - p[f]) / o);
66
+ for (let u = 0; u < d.length; u++)
67
+ l += Math.exp(-Math.max(0, d[u] - g[c]) / o);
68
68
  if (Math.abs(l - i) < 1e-5) break;
69
- l > i ? (a = o, o = (c + a) / 2) : (c = o, o = a === 1 / 0 ? o * 2 : (c + a) / 2);
69
+ l > i ? (a = o, o = (f + a) / 2) : (f = o, o = a === 1 / 0 ? o * 2 : (f + a) / 2);
70
70
  }
71
- n[f] = o;
71
+ n[c] = o;
72
72
  }
73
- return { sigmas: n, rhos: p };
73
+ return { sigmas: n, rhos: g };
74
74
  }
75
- function X(e, t, r, s, n) {
76
- const p = /* @__PURE__ */ new Map();
75
+ function Z(e, t, r, s, n) {
76
+ const g = /* @__PURE__ */ new Map();
77
77
  for (let a = 0; a < e.length; a++)
78
- p.set(e[a] * s + t[a], r[a]);
79
- const f = [], d = [], c = [];
80
- for (const [a, o] of p) {
81
- const i = Math.floor(a / s), h = a % s, l = p.get(h * s + i) ?? 0, u = o + l - o * l, g = o * l;
82
- f.push(i), d.push(h), c.push(n * u + (1 - n) * g);
78
+ g.set(e[a] * s + t[a], r[a]);
79
+ const c = [], d = [], f = [];
80
+ for (const [a, o] of g) {
81
+ const i = Math.floor(a / s), h = a % s, l = g.get(h * s + i) ?? 0, u = o + l - o * l, p = o * l;
82
+ c.push(i), d.push(h), f.push(n * u + (1 - n) * p);
83
83
  }
84
84
  return {
85
- rows: new Float32Array(f),
85
+ rows: new Float32Array(c),
86
86
  cols: new Float32Array(d),
87
- vals: new Float32Array(c)
87
+ vals: new Float32Array(f)
88
88
  };
89
89
  }
90
- const Z = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
90
+ const $ = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
91
  // Applies attraction forces between connected nodes and repulsion forces
92
92
  // against negative samples.
93
93
 
@@ -158,6 +158,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
158
158
  let diff = embedding[i * nc + d] - embedding[j * nc + d];
159
159
  let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
160
160
  embedding[i * nc + d] += params.alpha * grad;
161
+ embedding[j * nc + d] -= params.alpha * grad;
161
162
  }
162
163
 
163
164
  epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
@@ -196,7 +197,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
196
197
  epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
197
198
  }
198
199
  `;
199
- class W {
200
+ class j {
200
201
  constructor() {
201
202
  G(this, "device");
202
203
  G(this, "pipeline");
@@ -207,7 +208,7 @@ class W {
207
208
  this.device = await t.requestDevice(), this.pipeline = this.device.createComputePipeline({
208
209
  layout: "auto",
209
210
  compute: {
210
- module: this.device.createShaderModule({ code: Z }),
211
+ module: this.device.createShaderModule({ code: $ }),
211
212
  entryPoint: "main"
212
213
  }
213
214
  });
@@ -225,45 +226,45 @@ class W {
225
226
  * @param params - UMAP curve parameters and repulsion settings
226
227
  * @returns Optimized embedding as Float32Array
227
228
  */
228
- async optimize(t, r, s, n, p, f, d, c, a) {
229
+ async optimize(t, r, s, n, g, c, d, f, a) {
229
230
  const { device: o } = this, i = r.length, h = this.makeBuffer(
230
231
  t,
231
232
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
232
- ), l = this.makeBuffer(r, GPUBufferUsage.STORAGE), u = this.makeBuffer(s, GPUBufferUsage.STORAGE), g = this.makeBuffer(n, GPUBufferUsage.STORAGE), m = new Float32Array(i).fill(0), w = this.makeBuffer(m, GPUBufferUsage.STORAGE), b = new Float32Array(i);
233
- for (let _ = 0; _ < i; _++)
234
- b[_] = n[_] / c.negativeSampleRate;
235
- const y = this.makeBuffer(b, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
236
- for (let _ = 0; _ < i; _++)
237
- v[_] = Math.random() * 4294967295 | 0;
238
- const x = this.makeBuffer(v, GPUBufferUsage.STORAGE), B = o.createBuffer({
233
+ ), l = this.makeBuffer(r, GPUBufferUsage.STORAGE), u = this.makeBuffer(s, GPUBufferUsage.STORAGE), p = this.makeBuffer(n, GPUBufferUsage.STORAGE), m = new Float32Array(i).fill(0), _ = this.makeBuffer(m, GPUBufferUsage.STORAGE), w = new Float32Array(i);
234
+ for (let M = 0; M < i; M++)
235
+ w[M] = n[M] / f.negativeSampleRate;
236
+ const b = this.makeBuffer(w, GPUBufferUsage.STORAGE), A = new Uint32Array(i);
237
+ for (let M = 0; M < i; M++)
238
+ A[M] = Math.random() * 4294967295 | 0;
239
+ const x = this.makeBuffer(A, GPUBufferUsage.STORAGE), U = o.createBuffer({
239
240
  size: 40,
240
241
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
241
242
  });
242
- for (let _ = 0; _ < d; _++) {
243
- const F = 1 - _ / d, A = new ArrayBuffer(40), M = new Uint32Array(A), N = new Float32Array(A);
244
- M[0] = i, M[1] = p, M[2] = f, M[3] = _, M[4] = d, N[5] = F, N[6] = c.a, N[7] = c.b, N[8] = c.gamma, M[9] = c.negativeSampleRate, o.queue.writeBuffer(B, 0, A);
245
- const S = o.createBindGroup({
243
+ for (let M = 0; M < d; M++) {
244
+ const v = 1 - M / d, S = new ArrayBuffer(40), y = new Uint32Array(S), B = new Float32Array(S);
245
+ y[0] = i, y[1] = g, y[2] = c, y[3] = M, y[4] = d, B[5] = v, B[6] = f.a, B[7] = f.b, B[8] = f.gamma, y[9] = f.negativeSampleRate, o.queue.writeBuffer(U, 0, S);
246
+ const k = o.createBindGroup({
246
247
  layout: this.pipeline.getBindGroupLayout(0),
247
248
  entries: [
248
- { binding: 0, resource: { buffer: g } },
249
+ { binding: 0, resource: { buffer: p } },
249
250
  { binding: 1, resource: { buffer: l } },
250
251
  { binding: 2, resource: { buffer: u } },
251
252
  { binding: 3, resource: { buffer: h } },
252
- { binding: 4, resource: { buffer: w } },
253
- { binding: 5, resource: { buffer: y } },
254
- { binding: 6, resource: { buffer: B } },
253
+ { binding: 4, resource: { buffer: _ } },
254
+ { binding: 5, resource: { buffer: b } },
255
+ { binding: 6, resource: { buffer: U } },
255
256
  { binding: 7, resource: { buffer: x } }
256
257
  ]
257
- }), O = o.createCommandEncoder(), U = O.beginComputePass();
258
- U.setPipeline(this.pipeline), U.setBindGroup(0, S), U.dispatchWorkgroups(Math.ceil(i / 256)), U.end(), o.queue.submit([O.finish()]), _ % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(_, d));
258
+ }), E = o.createCommandEncoder(), P = E.beginComputePass();
259
+ P.setPipeline(this.pipeline), P.setBindGroup(0, k), P.dispatchWorkgroups(Math.ceil(i / 256)), P.end(), o.queue.submit([E.finish()]), M % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(M, d));
259
260
  }
260
- const E = o.createBuffer({
261
+ const N = o.createBuffer({
261
262
  size: t.byteLength,
262
263
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
263
264
  }), R = o.createCommandEncoder();
264
- R.copyBufferToBuffer(h, 0, E, 0, t.byteLength), o.queue.submit([R.finish()]), await E.mapAsync(GPUMapMode.READ);
265
- const k = new Float32Array(E.getMappedRange().slice(0));
266
- return E.unmap(), h.destroy(), l.destroy(), u.destroy(), g.destroy(), w.destroy(), y.destroy(), x.destroy(), B.destroy(), E.destroy(), k;
265
+ R.copyBufferToBuffer(h, 0, N, 0, t.byteLength), o.queue.submit([R.finish()]), await N.mapAsync(GPUMapMode.READ);
266
+ const O = new Float32Array(N.getMappedRange().slice(0));
267
+ return N.unmap(), h.destroy(), l.destroy(), u.destroy(), p.destroy(), _.destroy(), b.destroy(), x.destroy(), U.destroy(), N.destroy(), O;
267
268
  }
268
269
  makeBuffer(t, r) {
269
270
  const s = this.device.createBuffer({
@@ -274,86 +275,90 @@ class W {
274
275
  return t instanceof Float32Array ? new Float32Array(s.getMappedRange()).set(t) : new Uint32Array(s.getMappedRange()).set(t), s.unmap(), s;
275
276
  }
276
277
  }
277
- function q(e) {
278
+ function z(e) {
278
279
  return Math.max(-4, Math.min(4, e));
279
280
  }
280
- function z(e, t, r, s, n, p, f, d) {
281
- const { a: c, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = f, h = t.rows.length, l = new Uint32Array(t.rows), u = new Uint32Array(t.cols), g = new Float32Array(h).fill(0), m = new Float32Array(h);
282
- for (let w = 0; w < h; w++)
283
- m[w] = r[w] / i;
284
- for (let w = 0; w < p; w++) {
285
- d == null || d(w, p);
286
- const b = 1 - w / p;
287
- for (let y = 0; y < h; y++) {
288
- if (g[y] > w) continue;
289
- const v = l[y], x = u[y];
290
- let B = 0;
291
- for (let _ = 0; _ < n; _++) {
292
- const F = e[v * n + _] - e[x * n + _];
293
- B += F * F;
281
+ function I(e, t, r, s, n, g, c, d) {
282
+ const { a: f, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), u = new Uint32Array(t.cols), p = new Float32Array(h).fill(0), m = new Float32Array(h);
283
+ for (let _ = 0; _ < h; _++)
284
+ m[_] = r[_] / i;
285
+ for (let _ = 0; _ < g; _++) {
286
+ d == null || d(_, g);
287
+ const w = 1 - _ / g;
288
+ for (let b = 0; b < h; b++) {
289
+ if (p[b] > _) continue;
290
+ const A = l[b], x = u[b];
291
+ let U = 0;
292
+ for (let v = 0; v < n; v++) {
293
+ const S = e[A * n + v] - e[x * n + v];
294
+ U += S * S;
294
295
  }
295
- const E = Math.pow(B, a), R = -2 * c * a * (B > 0 ? E / B : 0) / (c * E + 1);
296
- for (let _ = 0; _ < n; _++) {
297
- const F = e[v * n + _] - e[x * n + _], A = q(R * F);
298
- e[v * n + _] += b * A, e[x * n + _] -= b * A;
296
+ const N = Math.pow(U, a), R = -2 * f * a * (U > 0 ? N / U : 0) / (f * N + 1);
297
+ for (let v = 0; v < n; v++) {
298
+ const S = e[A * n + v] - e[x * n + v], y = z(R * S);
299
+ e[A * n + v] += w * y, e[x * n + v] -= w * y;
299
300
  }
300
- g[y] += r[y];
301
- const k = m[y] > 0 ? Math.floor(r[y] / m[y]) : 0;
302
- for (let _ = 0; _ < k; _++) {
303
- const F = Math.floor(Math.random() * s);
304
- if (F === v) continue;
305
- let A = 0;
306
- for (let S = 0; S < n; S++) {
307
- const O = e[v * n + S] - e[F * n + S];
308
- A += O * O;
301
+ p[b] += r[b];
302
+ const O = r[b] / i, M = Math.max(0, Math.floor(
303
+ (_ - m[b]) / O
304
+ ));
305
+ m[b] += M * O;
306
+ for (let v = 0; v < M; v++) {
307
+ const S = Math.floor(Math.random() * s);
308
+ if (S === A) continue;
309
+ let y = 0;
310
+ for (let E = 0; E < n; E++) {
311
+ const P = e[A * n + E] - e[S * n + E];
312
+ y += P * P;
309
313
  }
310
- const M = Math.pow(A, a), N = 2 * o * a / ((1e-3 + A) * (c * M + 1));
311
- for (let S = 0; S < n; S++) {
312
- const O = e[v * n + S] - e[F * n + S], U = q(N * O);
313
- e[v * n + S] += b * U;
314
+ const B = Math.pow(y, a), k = 2 * o * a / ((1e-3 + y) * (f * B + 1));
315
+ for (let E = 0; E < n; E++) {
316
+ const P = e[A * n + E] - e[S * n + E], F = z(k * P);
317
+ e[A * n + E] += w * F;
314
318
  }
315
319
  }
316
- m[y] += r[y] / i;
317
320
  }
318
321
  }
319
322
  return e;
320
323
  }
321
- function $(e, t, r, s, n, p, f, d, c, a) {
322
- const { a: o, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = c, u = r.rows.length, g = new Uint32Array(r.rows), m = new Uint32Array(r.cols), w = new Float32Array(u).fill(0), b = new Float32Array(u);
323
- for (let y = 0; y < u; y++)
324
- b[y] = s[y] / l;
325
- for (let y = 0; y < d; y++) {
326
- const v = 1 - y / d;
324
+ function ee(e, t, r, s, n, g, c, d, f, a) {
325
+ const { a: o, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = f, u = r.rows.length, p = new Uint32Array(r.rows), m = new Uint32Array(r.cols), _ = new Float32Array(u).fill(0), w = new Float32Array(u);
326
+ for (let b = 0; b < u; b++)
327
+ w[b] = s[b] / l;
328
+ for (let b = 0; b < d; b++) {
329
+ const A = 1 - b / d;
327
330
  for (let x = 0; x < u; x++) {
328
- if (w[x] > y) continue;
329
- const B = g[x], E = m[x];
331
+ if (_[x] > b) continue;
332
+ const U = p[x], N = m[x];
330
333
  let R = 0;
331
- for (let A = 0; A < f; A++) {
332
- const M = e[B * f + A] - t[E * f + A];
333
- R += M * M;
334
+ for (let y = 0; y < c; y++) {
335
+ const B = e[U * c + y] - t[N * c + y];
336
+ R += B * B;
334
337
  }
335
- const k = Math.pow(R, i), _ = -2 * o * i * (R > 0 ? k / R : 0) / (o * k + 1);
336
- for (let A = 0; A < f; A++) {
337
- const M = e[B * f + A] - t[E * f + A];
338
- e[B * f + A] += v * q(_ * M);
338
+ const O = Math.pow(R, i), M = -2 * o * i * (R > 0 ? O / R : 0) / (o * O + 1);
339
+ for (let y = 0; y < c; y++) {
340
+ const B = e[U * c + y] - t[N * c + y];
341
+ e[U * c + y] += A * z(M * B);
339
342
  }
340
- w[x] += s[x];
341
- const F = b[x] > 0 ? Math.floor(s[x] / b[x]) : 0;
342
- for (let A = 0; A < F; A++) {
343
- const M = Math.floor(Math.random() * p);
344
- if (M === E) continue;
345
- let N = 0;
346
- for (let U = 0; U < f; U++) {
347
- const P = e[B * f + U] - t[M * f + U];
348
- N += P * P;
343
+ _[x] += s[x];
344
+ const v = s[x] / l, S = Math.max(0, Math.floor(
345
+ (b - w[x]) / v
346
+ ));
347
+ w[x] += S * v;
348
+ for (let y = 0; y < S; y++) {
349
+ const B = Math.floor(Math.random() * g);
350
+ if (B === N) continue;
351
+ let k = 0;
352
+ for (let F = 0; F < c; F++) {
353
+ const q = e[U * c + F] - t[B * c + F];
354
+ k += q * q;
349
355
  }
350
- const S = Math.pow(N, i), O = 2 * h * i / ((1e-3 + N) * (o * S + 1));
351
- for (let U = 0; U < f; U++) {
352
- const P = e[B * f + U] - t[M * f + U];
353
- e[B * f + U] += v * q(O * P);
356
+ const E = Math.pow(k, i), P = 2 * h * i / ((1e-3 + k) * (o * E + 1));
357
+ for (let F = 0; F < c; F++) {
358
+ const q = e[U * c + F] - t[B * c + F];
359
+ e[U * c + F] += A * z(P * q);
354
360
  }
355
361
  }
356
- b[x] += s[x] / l;
357
362
  }
358
363
  }
359
364
  return e;
@@ -361,66 +366,66 @@ function $(e, t, r, s, n, p, f, d, c, a) {
361
366
  function K() {
362
367
  return typeof navigator < "u" && !!navigator.gpu;
363
368
  }
364
- async function ae(e, t = {}, r) {
369
+ async function re(e, t = {}, r) {
365
370
  const {
366
371
  nComponents: s = 2,
367
372
  nNeighbors: n = 15,
368
- minDist: p = 0.1,
369
- spread: f = 1,
373
+ minDist: g = 0.1,
374
+ spread: c = 1,
370
375
  hnsw: d = {}
371
- } = t, c = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
376
+ } = t, f = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
372
377
  console.time("knn");
373
- const { indices: a, distances: o } = await Y(e, n, {
378
+ const { indices: a, distances: o } = await Q(e, n, {
374
379
  M: d.M ?? 16,
375
380
  efConstruction: d.efConstruction ?? 200,
376
381
  efSearch: d.efSearch ?? 50
377
382
  });
378
383
  console.timeEnd("knn"), console.time("fuzzy-set");
379
- const i = D(a, o, n);
384
+ const i = L(a, o, n);
380
385
  console.timeEnd("fuzzy-set");
381
- const { a: h, b: l } = j(p, f), u = I(i.vals, c), g = e.length, m = new Float32Array(g * s);
382
- for (let b = 0; b < m.length; b++)
383
- m[b] = Math.random() * 20 - 10;
386
+ const { a: h, b: l } = H(g, c), u = T(i.vals), p = e.length, m = new Float32Array(p * s);
387
+ for (let w = 0; w < m.length; w++)
388
+ m[w] = Math.random() * 20 - 10;
384
389
  console.time("sgd");
385
- let w;
390
+ let _;
386
391
  if (K())
387
392
  try {
388
- const b = new W();
389
- await b.init(), w = await b.optimize(
393
+ const w = new j();
394
+ await w.init(), _ = await w.optimize(
390
395
  m,
391
396
  new Uint32Array(i.rows),
392
397
  new Uint32Array(i.cols),
393
398
  u,
394
- g,
399
+ p,
395
400
  s,
396
- c,
401
+ f,
397
402
  { a: h, b: l, gamma: 1, negativeSampleRate: 5 },
398
403
  r
399
404
  );
400
- } catch (b) {
401
- console.warn("WebGPU SGD failed, falling back to CPU:", b), w = z(m, i, u, g, s, c, { a: h, b: l }, r);
405
+ } catch (w) {
406
+ console.warn("WebGPU SGD failed, falling back to CPU:", w), _ = I(m, i, u, p, s, f, { a: h, b: l }, r);
402
407
  }
403
408
  else
404
- w = z(m, i, u, g, s, c, { a: h, b: l }, r);
405
- return console.timeEnd("sgd"), w;
409
+ _ = I(m, i, u, p, s, f, { a: h, b: l }, r);
410
+ return console.timeEnd("sgd"), _;
406
411
  }
407
- function j(e, t) {
412
+ function H(e, t) {
408
413
  if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
409
414
  return { a: 1.9292, b: 0.7915 };
410
415
  if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
411
416
  return { a: 1.8956, b: 0.8006 };
412
417
  if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
413
418
  return { a: 1.5769, b: 0.8951 };
414
- const r = ee(e, t);
415
- return { a: te(e, t, r), b: r };
419
+ const r = te(e, t);
420
+ return { a: ne(e, t, r), b: r };
416
421
  }
417
- function ee(e, t) {
422
+ function te(e, t) {
418
423
  return 1 / (t * 1.2);
419
424
  }
420
- function te(e, t, r) {
425
+ function ne(e, t, r) {
421
426
  return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * r);
422
427
  }
423
- class re {
428
+ class ie {
424
429
  constructor(t = {}) {
425
430
  G(this, "_nComponents");
426
431
  G(this, "_nNeighbors");
@@ -435,7 +440,7 @@ class re {
435
440
  G(this, "_hnswIndex", null);
436
441
  G(this, "_nTrain", 0);
437
442
  this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
438
- const { a: r, b: s } = j(this._minDist, this._spread);
443
+ const { a: r, b: s } = H(this._minDist, this._spread);
439
444
  this._a = r, this._b = s;
440
445
  }
441
446
  /**
@@ -445,22 +450,22 @@ class re {
445
450
  * Returns `this` for chaining.
446
451
  */
447
452
  async fit(t, r) {
448
- const s = t.length, n = this._nEpochs ?? (s > 1e4 ? 200 : 500), { M: p = 16, efConstruction: f = 200, efSearch: d = 50 } = this._hnswOpts;
453
+ const s = t.length, n = this._nEpochs ?? (s > 1e4 ? 200 : 500), { M: g = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
449
454
  console.time("knn");
450
- const { knn: c, index: a } = await Q(t, this._nNeighbors, {
451
- M: p,
452
- efConstruction: f,
455
+ const { knn: f, index: a } = await J(t, this._nNeighbors, {
456
+ M: g,
457
+ efConstruction: c,
453
458
  efSearch: d
454
459
  });
455
460
  this._hnswIndex = a, this._nTrain = s, console.timeEnd("knn"), console.time("fuzzy-set");
456
- const o = D(c.indices, c.distances, this._nNeighbors);
461
+ const o = L(f.indices, f.distances, this._nNeighbors);
457
462
  console.timeEnd("fuzzy-set");
458
- const i = I(o.vals, n), h = new Float32Array(s * this._nComponents);
463
+ const i = T(o.vals), h = new Float32Array(s * this._nComponents);
459
464
  for (let l = 0; l < h.length; l++)
460
465
  h[l] = Math.random() * 20 - 10;
461
466
  if (console.time("sgd"), K())
462
467
  try {
463
- const l = new W();
468
+ const l = new j();
464
469
  await l.init(), this.embedding = await l.optimize(
465
470
  h,
466
471
  new Uint32Array(o.rows),
@@ -473,13 +478,13 @@ class re {
473
478
  r
474
479
  );
475
480
  } catch (l) {
476
- console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = z(h, o, i, s, this._nComponents, n, {
481
+ console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = I(h, o, i, s, this._nComponents, n, {
477
482
  a: this._a,
478
483
  b: this._b
479
484
  }, r);
480
485
  }
481
486
  else
482
- this.embedding = z(h, o, i, s, this._nComponents, n, {
487
+ this.embedding = I(h, o, i, s, this._nComponents, n, {
483
488
  a: this._a,
484
489
  b: this._b
485
490
  }, r);
@@ -499,21 +504,21 @@ class re {
499
504
  async transform(t, r = !1) {
500
505
  if (!this._hnswIndex || !this.embedding)
501
506
  throw new Error("UMAP.transform() must be called after fit()");
502
- const s = t.length, n = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), p = Math.max(100, Math.floor(n / 4)), f = this._hnswIndex.searchKnn(t, this._nNeighbors), d = J(f.indices, f.distances, this._nNeighbors), c = new Uint32Array(d.rows), a = new Uint32Array(d.cols), o = new Float32Array(s), i = new Float32Array(s * this._nComponents);
503
- for (let u = 0; u < c.length; u++) {
504
- const g = c[u], m = a[u], w = d.vals[u];
505
- o[g] += w;
506
- for (let b = 0; b < this._nComponents; b++)
507
- i[g * this._nComponents + b] += w * this.embedding[m * this._nComponents + b];
507
+ const s = t.length, n = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), g = Math.max(100, Math.floor(n / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = X(c.indices, c.distances, this._nNeighbors), f = new Uint32Array(d.rows), a = new Uint32Array(d.cols), o = new Float32Array(s), i = new Float32Array(s * this._nComponents);
508
+ for (let u = 0; u < f.length; u++) {
509
+ const p = f[u], m = a[u], _ = d.vals[u];
510
+ o[p] += _;
511
+ for (let w = 0; w < this._nComponents; w++)
512
+ i[p * this._nComponents + w] += _ * this.embedding[m * this._nComponents + w];
508
513
  }
509
514
  for (let u = 0; u < s; u++)
510
515
  if (o[u] > 0)
511
- for (let g = 0; g < this._nComponents; g++)
512
- i[u * this._nComponents + g] /= o[u];
516
+ for (let p = 0; p < this._nComponents; p++)
517
+ i[u * this._nComponents + p] /= o[u];
513
518
  else
514
- for (let g = 0; g < this._nComponents; g++)
515
- i[u * this._nComponents + g] = Math.random() * 20 - 10;
516
- const h = I(d.vals, p), l = $(
519
+ for (let p = 0; p < this._nComponents; p++)
520
+ i[u * this._nComponents + p] = Math.random() * 20 - 10;
521
+ const h = T(d.vals), l = ee(
517
522
  i,
518
523
  this.embedding,
519
524
  d,
@@ -521,10 +526,10 @@ class re {
521
526
  s,
522
527
  this._nTrain,
523
528
  this._nComponents,
524
- p,
529
+ g,
525
530
  { a: this._a, b: this._b }
526
531
  );
527
- return r ? T(l, s, this._nComponents) : l;
532
+ return r ? C(l, s, this._nComponents) : l;
528
533
  }
529
534
  /**
530
535
  * Convenience method equivalent to `fit(vectors)` followed by
@@ -536,36 +541,36 @@ class re {
536
541
  * Defaults to `false`.
537
542
  */
538
543
  async fit_transform(t, r, s = !1) {
539
- return await this.fit(t, r), s ? T(this.embedding, t.length, this._nComponents) : this.embedding;
544
+ return await this.fit(t, r), s ? C(this.embedding, t.length, this._nComponents) : this.embedding;
540
545
  }
541
546
  }
542
- function T(e, t, r) {
547
+ function C(e, t, r) {
543
548
  const s = new Float32Array(e.length);
544
549
  for (let n = 0; n < r; n++) {
545
- let p = 1 / 0, f = -1 / 0;
546
- for (let c = 0; c < t; c++) {
547
- const a = e[c * r + n];
548
- a < p && (p = a), a > f && (f = a);
550
+ let g = 1 / 0, c = -1 / 0;
551
+ for (let f = 0; f < t; f++) {
552
+ const a = e[f * r + n];
553
+ a < g && (g = a), a > c && (c = a);
549
554
  }
550
- const d = f - p;
551
- for (let c = 0; c < t; c++)
552
- s[c * r + n] = d > 0 ? (e[c * r + n] - p) / d : 0;
555
+ const d = c - g;
556
+ for (let f = 0; f < t; f++)
557
+ s[f * r + n] = d > 0 ? (e[f * r + n] - g) / d : 0;
553
558
  }
554
559
  return s;
555
560
  }
556
- function I(e, t) {
561
+ function T(e, t) {
557
562
  let r = -1 / 0;
558
563
  for (let n = 0; n < e.length; n++)
559
564
  e[n] > r && (r = e[n]);
560
565
  const s = new Float32Array(e.length);
561
566
  for (let n = 0; n < e.length; n++) {
562
- const p = e[n] / r;
563
- s[n] = p > 0 ? t / p : -1;
567
+ const g = e[n] / r;
568
+ s[n] = g > 0 ? 1 / g : -1;
564
569
  }
565
570
  return s;
566
571
  }
567
572
  export {
568
- re as UMAP,
569
- ae as fit,
573
+ ie as UMAP,
574
+ re as fit,
570
575
  K as isWebGPUAvailable
571
576
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.2.10",
3
+ "version": "0.2.13",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",