umap-gpu 0.2.14 → 0.2.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (2)
  1. package/dist/index.js +158 -155
  2. package/package.json +3 -2
package/dist/index.js CHANGED
@@ -1,56 +1,56 @@
1
1
  var te = Object.defineProperty;
2
2
  var ne = (e, t, f) => t in e ? te(e, t, { enumerable: !0, configurable: !0, writable: !0, value: f }) : e[t] = f;
3
- var C = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
3
+ var R = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
4
4
  import { loadHnswlib as H } from "hnswlib-wasm";
5
5
  async function se(e, t, f = {}) {
6
- const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, i = e.length, n = new c.HierarchicalNSW("l2", d, "");
7
- n.initIndex(i, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
8
- const r = [], o = [];
9
- for (let h = 0; h < i; h++) {
6
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
7
+ n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
8
+ const r = [], i = [];
9
+ for (let h = 0; h < o; h++) {
10
10
  const l = n.searchKnn(e[h], t + 1, void 0), p = l.neighbors.map((g, _) => ({ idx: g, dist: l.distances[_] })).filter(({ idx: g }) => g !== h).slice(0, t);
11
- r.push(p.map(({ idx: g }) => g)), o.push(p.map(({ dist: g }) => Math.sqrt(g)));
11
+ r.push(p.map(({ idx: g }) => g)), i.push(p.map(({ dist: g }) => Math.sqrt(g)));
12
12
  }
13
- return { indices: r, distances: o };
13
+ return { indices: r, distances: i };
14
14
  }
15
15
  async function ae(e, t, f = {}) {
16
- const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, i = e.length, n = new c.HierarchicalNSW("l2", d, "");
17
- n.initIndex(i, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
18
- const r = [], o = [];
19
- for (let l = 0; l < i; l++) {
16
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
17
+ n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
18
+ const r = [], i = [];
19
+ for (let l = 0; l < o; l++) {
20
20
  const p = n.searchKnn(e[l], t + 1, void 0), g = p.neighbors.map((_, m) => ({ idx: _, dist: p.distances[m] })).filter(({ idx: _ }) => _ !== l).slice(0, t);
21
- r.push(g.map(({ idx: _ }) => _)), o.push(g.map(({ dist: _ }) => Math.sqrt(_)));
21
+ r.push(g.map(({ idx: _ }) => _)), i.push(g.map(({ dist: _ }) => Math.sqrt(_)));
22
22
  }
23
- return { knn: { indices: r, distances: o }, index: {
23
+ return { knn: { indices: r, distances: i }, index: {
24
24
  searchKnn(l, p) {
25
25
  const g = [], _ = [];
26
26
  for (const m of l) {
27
- const y = n.searchKnn(m, p, void 0), b = y.neighbors.map((x, v) => ({ idx: x, dist: y.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
28
- g.push(b.map(({ idx: x }) => x)), _.push(b.map(({ dist: x }) => Math.sqrt(x)));
27
+ const b = n.searchKnn(m, p, void 0), y = b.neighbors.map((x, v) => ({ idx: x, dist: b.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
28
+ g.push(y.map(({ idx: x }) => x)), _.push(y.map(({ dist: x }) => Math.sqrt(x)));
29
29
  }
30
30
  return { indices: g, distances: _ };
31
31
  }
32
32
  } };
33
33
  }
34
34
  function V(e, t, f, a = 1) {
35
- const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], i = [], n = [];
36
- for (let o = 0; o < s; o++)
37
- for (let h = 0; h < e[o].length; h++) {
38
- const l = t[o][h], p = l <= c[o] ? 1 : Math.exp(-((l - c[o]) / u[o]));
39
- d.push(o), i.push(e[o][h]), n.push(p);
35
+ const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], o = [], n = [];
36
+ for (let i = 0; i < s; i++)
37
+ for (let h = 0; h < e[i].length; h++) {
38
+ const l = t[i][h], p = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / u[i]));
39
+ d.push(i), o.push(e[i][h]), n.push(p);
40
40
  }
41
- return { ...oe(d, i, n, s, a), nVertices: s };
41
+ return { ...ie(d, o, n, s, a), nVertices: s };
42
42
  }
43
43
  function re(e, t, f) {
44
- const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], i = [];
44
+ const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], o = [];
45
45
  for (let n = 0; n < a; n++)
46
46
  for (let r = 0; r < e[n].length; r++) {
47
- const o = t[n][r], h = o <= u[n] ? 1 : Math.exp(-((o - u[n]) / s[n]));
48
- c.push(n), d.push(e[n][r]), i.push(h);
47
+ const i = t[n][r], h = i <= u[n] ? 1 : Math.exp(-((i - u[n]) / s[n]));
48
+ c.push(n), d.push(e[n][r]), o.push(h);
49
49
  }
50
50
  return {
51
51
  rows: new Uint32Array(c),
52
52
  cols: new Uint32Array(d),
53
- vals: new Float32Array(i),
53
+ vals: new Float32Array(o),
54
54
  nVertices: a
55
55
  };
56
56
  }
@@ -59,35 +59,35 @@ function Y(e, t) {
59
59
  for (let c = 0; c < a; c++) {
60
60
  const d = e[c];
61
61
  u[c] = d.find((h) => h > 0) ?? 0;
62
- let i = 0, n = 1 / 0, r = 1;
63
- const o = Math.log2(t);
62
+ let o = 0, n = 1 / 0, r = 1;
63
+ const i = Math.log2(t);
64
64
  for (let h = 0; h < 64; h++) {
65
65
  let l = 0;
66
66
  for (let p = 0; p < d.length; p++)
67
67
  l += Math.exp(-Math.max(0, d[p] - u[c]) / r);
68
- if (Math.abs(l - o) < 1e-5) break;
69
- l > o ? (n = r, r = (i + n) / 2) : (i = r, r = n === 1 / 0 ? r * 2 : (i + n) / 2);
68
+ if (Math.abs(l - i) < 1e-5) break;
69
+ l > i ? (n = r, r = (o + n) / 2) : (o = r, r = n === 1 / 0 ? r * 2 : (o + n) / 2);
70
70
  }
71
71
  s[c] = r;
72
72
  }
73
73
  return { sigmas: s, rhos: u };
74
74
  }
75
- function oe(e, t, f, a, s) {
75
+ function ie(e, t, f, a, s) {
76
76
  const u = /* @__PURE__ */ new Map();
77
77
  for (let n = 0; n < e.length; n++)
78
78
  u.set(e[n] * a + t[n], f[n]);
79
- const c = [], d = [], i = [];
79
+ const c = [], d = [], o = [];
80
80
  for (const [n, r] of u) {
81
- const o = Math.floor(n / a), h = n % a, l = u.get(h * a + o) ?? 0, p = r + l - r * l, g = r * l;
82
- c.push(o), d.push(h), i.push(s * p + (1 - s) * g);
81
+ const i = Math.floor(n / a), h = n % a, l = u.get(h * a + i) ?? 0, p = r + l - r * l, g = r * l;
82
+ c.push(i), d.push(h), o.push(s * p + (1 - s) * g);
83
83
  }
84
84
  return {
85
85
  rows: new Uint32Array(c),
86
86
  cols: new Uint32Array(d),
87
- vals: new Float32Array(i)
87
+ vals: new Float32Array(o)
88
88
  };
89
89
  }
90
- const ie = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
90
+ const oe = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
91
  // Computes attraction and repulsion forces and accumulates them atomically
92
92
  // into a forces buffer. A separate apply-forces pass then updates embeddings,
93
93
  // eliminating write-write races on shared vertex positions.
@@ -181,7 +181,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
181
181
  epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * epochs_per_neg;
182
182
  }
183
183
 
184
- var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 6364136223u);
184
+ // 2654435761u is the 32-bit golden-ratio hash constant (0x9E3779B1),
185
+ // which fits in u32 unlike the 64-bit LCG value 6364136223 used originally.
186
+ // Bug 14 fix: the original constant exceeded u32 range and failed WGSL validation.
187
+ var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 2654435761u);
185
188
 
186
189
  for (var s = 0u; s < n_neg; s++) {
187
190
  rng = xorshift(rng);
@@ -250,9 +253,9 @@ async function he() {
250
253
  }
251
254
  class X {
252
255
  constructor() {
253
- C(this, "device");
254
- C(this, "sgdPipeline");
255
- C(this, "applyForcesPipeline");
256
+ R(this, "device");
257
+ R(this, "sgdPipeline");
258
+ R(this, "applyForcesPipeline");
256
259
  }
257
260
  async init() {
258
261
  const t = await Q();
@@ -260,7 +263,7 @@ class X {
260
263
  this.device = t, this.sgdPipeline = this.device.createComputePipeline({
261
264
  layout: "auto",
262
265
  compute: {
263
- module: this.device.createShaderModule({ code: ie }),
266
+ module: this.device.createShaderModule({ code: oe }),
264
267
  entryPoint: "main"
265
268
  }
266
269
  }), this.applyForcesPipeline = this.device.createComputePipeline({
@@ -284,22 +287,22 @@ class X {
284
287
  * @param params - UMAP curve parameters and repulsion settings
285
288
  * @returns Optimized embedding as Float32Array
286
289
  */
287
- async optimize(t, f, a, s, u, c, d, i, n) {
288
- const { device: r } = this, o = f.length, h = u * c, l = this.makeBuffer(
290
+ async optimize(t, f, a, s, u, c, d, o, n) {
291
+ const { device: r } = this, i = f.length, h = u * c, l = this.makeBuffer(
289
292
  t,
290
293
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
291
- ), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), y = this.makeBuffer(m, GPUBufferUsage.STORAGE), b = new Float32Array(o);
292
- for (let A = 0; A < o; A++)
293
- b[A] = s[A] / i.negativeSampleRate;
294
- const x = this.makeBuffer(b, GPUBufferUsage.STORAGE), v = new Uint32Array(o);
295
- for (let A = 0; A < o; A++)
294
+ ), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), b = this.makeBuffer(m, GPUBufferUsage.STORAGE), y = new Float32Array(i);
295
+ for (let A = 0; A < i; A++)
296
+ y[A] = s[A] / o.negativeSampleRate;
297
+ const x = this.makeBuffer(y, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
298
+ for (let A = 0; A < i; A++)
296
299
  v[A] = Math.random() * 4294967295 | 0;
297
- const P = this.makeBuffer(v, GPUBufferUsage.STORAGE), G = r.createBuffer({
300
+ const G = this.makeBuffer(v, GPUBufferUsage.STORAGE), P = r.createBuffer({
298
301
  size: h * 4,
299
302
  usage: GPUBufferUsage.STORAGE,
300
303
  mappedAtCreation: !0
301
304
  });
302
- new Int32Array(G.getMappedRange()).fill(0), G.unmap();
305
+ new Int32Array(P.getMappedRange()).fill(0), P.unmap();
303
306
  const U = r.createBuffer({
304
307
  size: 40,
305
308
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
@@ -313,37 +316,37 @@ class X {
313
316
  { binding: 1, resource: { buffer: p } },
314
317
  { binding: 2, resource: { buffer: g } },
315
318
  { binding: 3, resource: { buffer: l } },
316
- { binding: 4, resource: { buffer: y } },
319
+ { binding: 4, resource: { buffer: b } },
317
320
  { binding: 5, resource: { buffer: x } },
318
321
  { binding: 6, resource: { buffer: U } },
319
- { binding: 7, resource: { buffer: P } },
320
- { binding: 8, resource: { buffer: G } }
322
+ { binding: 7, resource: { buffer: G } },
323
+ { binding: 8, resource: { buffer: P } }
321
324
  ]
322
325
  }), M = r.createBindGroup({
323
326
  layout: this.applyForcesPipeline.getBindGroupLayout(0),
324
327
  entries: [
325
328
  { binding: 0, resource: { buffer: l } },
326
- { binding: 1, resource: { buffer: G } },
329
+ { binding: 1, resource: { buffer: P } },
327
330
  { binding: 2, resource: { buffer: F } }
328
331
  ]
329
332
  });
330
333
  for (let A = 0; A < d; A++) {
331
334
  const B = 1 - A / d, O = new ArrayBuffer(40), S = new Uint32Array(O), k = new Float32Array(O);
332
- S[0] = o, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = i.a, k[7] = i.b, k[8] = i.gamma, S[9] = i.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
335
+ S[0] = i, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = o.a, k[7] = o.b, k[8] = o.gamma, S[9] = o.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
333
336
  const D = new ArrayBuffer(16), $ = new Uint32Array(D), ee = new Float32Array(D);
334
337
  $[0] = h, ee[1] = B, r.queue.writeBuffer(F, 0, D);
335
338
  const T = r.createCommandEncoder(), q = T.beginComputePass();
336
- q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(o / 256)), q.end();
337
- const I = T.beginComputePass();
338
- I.setPipeline(this.applyForcesPipeline), I.setBindGroup(0, M), I.dispatchWorkgroups(Math.ceil(h / 256)), I.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
339
+ q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(i / 256)), q.end();
340
+ const L = T.beginComputePass();
341
+ L.setPipeline(this.applyForcesPipeline), L.setBindGroup(0, M), L.dispatchWorkgroups(Math.ceil(h / 256)), L.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
339
342
  }
340
343
  const E = r.createBuffer({
341
344
  size: t.byteLength,
342
345
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
343
346
  }), w = r.createCommandEncoder();
344
347
  w.copyBufferToBuffer(l, 0, E, 0, t.byteLength), r.queue.submit([w.finish()]), await E.mapAsync(GPUMapMode.READ);
345
- const R = new Float32Array(E.getMappedRange().slice(0));
346
- return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), y.destroy(), x.destroy(), P.destroy(), G.destroy(), U.destroy(), F.destroy(), E.destroy(), R;
348
+ const C = new Float32Array(E.getMappedRange().slice(0));
349
+ return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), b.destroy(), x.destroy(), G.destroy(), P.destroy(), U.destroy(), F.destroy(), E.destroy(), C;
347
350
  }
348
351
  makeBuffer(t, f) {
349
352
  const a = this.device.createBuffer({
@@ -354,34 +357,34 @@ class X {
354
357
  return t instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(t) : new Uint32Array(a.getMappedRange()).set(t), a.unmap(), a;
355
358
  }
356
359
  }
357
- function j(e) {
360
+ function I(e) {
358
361
  return Math.max(-4, Math.min(4, e));
359
362
  }
360
- function L(e, t, f, a, s, u, c, d) {
361
- const { a: i, b: n, gamma: r = 1, negativeSampleRate: o = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
363
+ function j(e, t, f, a, s, u, c, d) {
364
+ const { a: o, b: n, gamma: r = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
362
365
  for (let m = 0; m < h; m++)
363
- _[m] = f[m] / o;
366
+ _[m] = f[m] / i;
364
367
  for (let m = 0; m < u; m++) {
365
368
  d == null || d(m, u);
366
- const y = 1 - m / u;
367
- for (let b = 0; b < h; b++) {
368
- if (g[b] > m) continue;
369
- const x = l[b], v = p[b];
370
- let P = 0;
369
+ const b = 1 - m / u;
370
+ for (let y = 0; y < h; y++) {
371
+ if (g[y] > m) continue;
372
+ const x = l[y], v = p[y];
373
+ let G = 0;
371
374
  for (let M = 0; M < s; M++) {
372
375
  const E = e[x * s + M] - e[v * s + M];
373
- P += E * E;
376
+ G += E * E;
374
377
  }
375
- const G = Math.pow(P, n), U = -2 * i * n * (P > 0 ? G / P : 0) / (i * G + 1);
378
+ const P = Math.pow(G, n), U = -2 * o * n * (G > 0 ? P / G : 0) / (o * P + 1);
376
379
  for (let M = 0; M < s; M++) {
377
- const E = e[x * s + M] - e[v * s + M], w = j(U * E);
378
- e[x * s + M] += y * w, e[v * s + M] -= y * w;
380
+ const E = e[x * s + M] - e[v * s + M], w = I(U * E);
381
+ e[x * s + M] += b * w, e[v * s + M] -= b * w;
379
382
  }
380
- g[b] += f[b];
381
- const F = f[b] / o, N = Math.max(0, Math.floor(
382
- (m - _[b]) / F
383
+ g[y] += f[y];
384
+ const F = f[y] / i, N = Math.max(0, Math.floor(
385
+ (m - _[y]) / F
383
386
  ));
384
- _[b] += N * F;
387
+ _[y] += N * F;
385
388
  for (let M = 0; M < N; M++) {
386
389
  const E = Math.floor(Math.random() * a);
387
390
  if (E === x) continue;
@@ -390,52 +393,52 @@ function L(e, t, f, a, s, u, c, d) {
390
393
  const O = e[x * s + B] - e[E * s + B];
391
394
  w += O * O;
392
395
  }
393
- const R = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (i * R + 1));
396
+ const C = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (o * C + 1));
394
397
  for (let B = 0; B < s; B++) {
395
- const O = e[x * s + B] - e[E * s + B], S = j(A * O);
396
- e[x * s + B] += y * S;
398
+ const O = e[x * s + B] - e[E * s + B], S = I(A * O);
399
+ e[x * s + B] += b * S;
397
400
  }
398
401
  }
399
402
  }
400
403
  }
401
404
  return e;
402
405
  }
403
- function fe(e, t, f, a, s, u, c, d, i, n) {
404
- const { a: r, b: o, gamma: h = 1, negativeSampleRate: l = 5 } = i, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), y = new Float32Array(p);
405
- for (let b = 0; b < p; b++)
406
- y[b] = a[b] / l;
407
- for (let b = 0; b < d; b++) {
408
- const x = 1 - b / d;
406
+ function fe(e, t, f, a, s, u, c, d, o, n) {
407
+ const { a: r, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = o, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), b = new Float32Array(p);
408
+ for (let y = 0; y < p; y++)
409
+ b[y] = a[y] / l;
410
+ for (let y = 0; y < d; y++) {
411
+ const x = 1 - y / d;
409
412
  for (let v = 0; v < p; v++) {
410
- if (m[v] > b) continue;
411
- const P = g[v], G = _[v];
413
+ if (m[v] > y) continue;
414
+ const G = g[v], P = _[v];
412
415
  let U = 0;
413
416
  for (let w = 0; w < c; w++) {
414
- const R = e[P * c + w] - t[G * c + w];
415
- U += R * R;
417
+ const C = e[G * c + w] - t[P * c + w];
418
+ U += C * C;
416
419
  }
417
- const F = Math.pow(U, o), N = -2 * r * o * (U > 0 ? F / U : 0) / (r * F + 1);
420
+ const F = Math.pow(U, i), N = -2 * r * i * (U > 0 ? F / U : 0) / (r * F + 1);
418
421
  for (let w = 0; w < c; w++) {
419
- const R = e[P * c + w] - t[G * c + w];
420
- e[P * c + w] += x * j(N * R);
422
+ const C = e[G * c + w] - t[P * c + w];
423
+ e[G * c + w] += x * I(N * C);
421
424
  }
422
425
  m[v] += a[v];
423
426
  const M = a[v] / l, E = Math.max(0, Math.floor(
424
- (b - y[v]) / M
427
+ (y - b[v]) / M
425
428
  ));
426
- y[v] += E * M;
429
+ b[v] += E * M;
427
430
  for (let w = 0; w < E; w++) {
428
- const R = Math.floor(Math.random() * u);
429
- if (R === G) continue;
431
+ const C = Math.floor(Math.random() * u);
432
+ if (C === P) continue;
430
433
  let A = 0;
431
434
  for (let S = 0; S < c; S++) {
432
- const k = e[P * c + S] - t[R * c + S];
435
+ const k = e[G * c + S] - t[C * c + S];
433
436
  A += k * k;
434
437
  }
435
- const B = Math.pow(A, o), O = 2 * h * o / ((1e-3 + A) * (r * B + 1));
438
+ const B = Math.pow(A, i), O = 2 * h * i / ((1e-3 + A) * (r * B + 1));
436
439
  for (let S = 0; S < c; S++) {
437
- const k = e[P * c + S] - t[R * c + S];
438
- e[P * c + S] += x * j(O * k);
440
+ const k = e[G * c + S] - t[C * c + S];
441
+ e[G * c + S] += x * I(O * k);
439
442
  }
440
443
  }
441
444
  }
@@ -449,7 +452,7 @@ async function pe(e, t = {}, f) {
449
452
  minDist: u = 0.1,
450
453
  spread: c = 1,
451
454
  hnsw: d = {}
452
- } = t, i = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
455
+ } = t, o = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
453
456
  console.time("knn");
454
457
  const { indices: n, distances: r } = await se(e, s, {
455
458
  M: d.M ?? 16,
@@ -457,32 +460,32 @@ async function pe(e, t = {}, f) {
457
460
  efSearch: d.efSearch ?? 50
458
461
  });
459
462
  console.timeEnd("knn"), console.time("fuzzy-set");
460
- const o = V(n, r, s);
463
+ const i = V(n, r, s);
461
464
  console.timeEnd("fuzzy-set");
462
- const { a: h, b: l } = Z(u, c), p = W(o.vals), g = e.length, _ = new Float32Array(g * a);
463
- for (let y = 0; y < _.length; y++)
464
- _[y] = Math.random() * 20 - 10;
465
+ const { a: h, b: l } = Z(u, c), p = W(i.vals), g = e.length, _ = new Float32Array(g * a);
466
+ for (let b = 0; b < _.length; b++)
467
+ _[b] = Math.random() * 20 - 10;
465
468
  console.time("sgd");
466
469
  let m;
467
470
  if (J())
468
471
  try {
469
- const y = new X();
470
- await y.init(), m = await y.optimize(
472
+ const b = new X();
473
+ await b.init(), m = await b.optimize(
471
474
  _,
472
- new Uint32Array(o.rows),
473
- new Uint32Array(o.cols),
475
+ new Uint32Array(i.rows),
476
+ new Uint32Array(i.cols),
474
477
  p,
475
478
  g,
476
479
  a,
477
- i,
480
+ o,
478
481
  { a: h, b: l, gamma: 1, negativeSampleRate: 5 },
479
482
  f
480
483
  );
481
- } catch (y) {
482
- console.warn("WebGPU SGD failed, falling back to CPU:", y), m = L(_, o, p, g, a, i, { a: h, b: l }, f);
484
+ } catch (b) {
485
+ console.warn("WebGPU SGD failed, falling back to CPU:", b), m = j(_, i, p, g, a, o, { a: h, b: l }, f);
483
486
  }
484
487
  else
485
- m = L(_, o, p, g, a, i, { a: h, b: l }, f);
488
+ m = j(_, i, p, g, a, o, { a: h, b: l }, f);
486
489
  return console.timeEnd("sgd"), m;
487
490
  }
488
491
  function Z(e, t) {
@@ -490,45 +493,45 @@ function Z(e, t) {
490
493
  }
491
494
  function de(e, t) {
492
495
  const a = [], s = [];
493
- for (let i = 0; i < 299; i++) {
494
- const n = (i + 1) / 299 * t * 3;
496
+ for (let o = 0; o < 299; o++) {
497
+ const n = (o + 1) / 299 * t * 3;
495
498
  a.push(n), s.push(n < e ? 1 : Math.exp(-(n - e) / t));
496
499
  }
497
500
  let u = 1, c = 1, d = 1e-3;
498
- for (let i = 0; i < 500; i++) {
499
- let n = 0, r = 0, o = 0, h = 0, l = 0, p = 0;
501
+ for (let o = 0; o < 500; o++) {
502
+ let n = 0, r = 0, i = 0, h = 0, l = 0, p = 0;
500
503
  for (let U = 0; U < 299; U++) {
501
504
  const F = a[U], N = Math.pow(F, 2 * c), M = 1 + u * N, w = 1 / M - s[U];
502
505
  p += w * w;
503
- const R = M * M, A = -N / R, B = F > 0 ? -2 * Math.log(F) * u * N / R : 0;
504
- n += A * w, r += B * w, o += A * A, h += B * B, l += A * B;
506
+ const C = M * M, A = -N / C, B = F > 0 ? -2 * Math.log(F) * u * N / C : 0;
507
+ n += A * w, r += B * w, i += A * A, h += B * B, l += A * B;
505
508
  }
506
- const g = o + d, _ = h + d, m = l, y = g * _ - m * m;
507
- if (Math.abs(y) < 1e-20) break;
508
- const b = -(_ * n - m * r) / y, x = -(g * r - m * n) / y, v = Math.max(1e-4, u + b), P = Math.max(1e-4, c + x);
509
- let G = 0;
509
+ const g = i + d, _ = h + d, m = l, b = g * _ - m * m;
510
+ if (Math.abs(b) < 1e-20) break;
511
+ const y = -(_ * n - m * r) / b, x = -(g * r - m * n) / b, v = Math.max(1e-4, u + y), G = Math.max(1e-4, c + x);
512
+ let P = 0;
510
513
  for (let U = 0; U < 299; U++) {
511
- const F = Math.pow(a[U], 2 * P), N = 1 / (1 + v * F) - s[U];
512
- G += N * N;
514
+ const F = Math.pow(a[U], 2 * G), N = 1 / (1 + v * F) - s[U];
515
+ P += N * N;
513
516
  }
514
- if (G < p ? (u = v, c = P, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(b) < 1e-8 && Math.abs(x) < 1e-8) break;
517
+ if (P < p ? (u = v, c = G, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(y) < 1e-8 && Math.abs(x) < 1e-8) break;
515
518
  }
516
519
  return { a: u, b: c };
517
520
  }
518
521
  class ge {
519
522
  constructor(t = {}) {
520
- C(this, "_nComponents");
521
- C(this, "_nNeighbors");
522
- C(this, "_minDist");
523
- C(this, "_spread");
524
- C(this, "_nEpochs");
525
- C(this, "_hnswOpts");
526
- C(this, "_a");
527
- C(this, "_b");
523
+ R(this, "_nComponents");
524
+ R(this, "_nNeighbors");
525
+ R(this, "_minDist");
526
+ R(this, "_spread");
527
+ R(this, "_nEpochs");
528
+ R(this, "_hnswOpts");
529
+ R(this, "_a");
530
+ R(this, "_b");
528
531
  /** The low-dimensional embedding produced by the last fit() call. */
529
- C(this, "embedding", null);
530
- C(this, "_hnswIndex", null);
531
- C(this, "_nTrain", 0);
532
+ R(this, "embedding", null);
533
+ R(this, "_hnswIndex", null);
534
+ R(this, "_nTrain", 0);
532
535
  this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
533
536
  const { a: f, b: a } = Z(this._minDist, this._spread);
534
537
  this._a = f, this._b = a;
@@ -542,15 +545,15 @@ class ge {
542
545
  async fit(t, f) {
543
546
  const a = t.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: u = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
544
547
  console.time("knn");
545
- const { knn: i, index: n } = await ae(t, this._nNeighbors, {
548
+ const { knn: o, index: n } = await ae(t, this._nNeighbors, {
546
549
  M: u,
547
550
  efConstruction: c,
548
551
  efSearch: d
549
552
  });
550
553
  this._hnswIndex = n, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
551
- const r = V(i.indices, i.distances, this._nNeighbors);
554
+ const r = V(o.indices, o.distances, this._nNeighbors);
552
555
  console.timeEnd("fuzzy-set");
553
- const o = W(r.vals), h = new Float32Array(a * this._nComponents);
556
+ const i = W(r.vals), h = new Float32Array(a * this._nComponents);
554
557
  for (let l = 0; l < h.length; l++)
555
558
  h[l] = Math.random() * 20 - 10;
556
559
  if (console.time("sgd"), J())
@@ -560,7 +563,7 @@ class ge {
560
563
  h,
561
564
  new Uint32Array(r.rows),
562
565
  new Uint32Array(r.cols),
563
- o,
566
+ i,
564
567
  a,
565
568
  this._nComponents,
566
569
  s,
@@ -568,13 +571,13 @@ class ge {
568
571
  f
569
572
  );
570
573
  } catch (l) {
571
- console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = L(h, r, o, a, this._nComponents, s, {
574
+ console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = j(h, r, i, a, this._nComponents, s, {
572
575
  a: this._a,
573
576
  b: this._b
574
577
  }, f);
575
578
  }
576
579
  else
577
- this.embedding = L(h, r, o, a, this._nComponents, s, {
580
+ this.embedding = j(h, r, i, a, this._nComponents, s, {
578
581
  a: this._a,
579
582
  b: this._b
580
583
  }, f);
@@ -594,22 +597,22 @@ class ge {
594
597
  async transform(t, f = !1) {
595
598
  if (!this._hnswIndex || !this.embedding)
596
599
  throw new Error("UMAP.transform() must be called after fit()");
597
- const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), i = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), o = new Float32Array(a * this._nComponents);
598
- for (let p = 0; p < i.length; p++) {
599
- const g = i[p], _ = n[p], m = d.vals[p];
600
+ const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), o = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), i = new Float32Array(a * this._nComponents);
601
+ for (let p = 0; p < o.length; p++) {
602
+ const g = o[p], _ = n[p], m = d.vals[p];
600
603
  r[g] += m;
601
- for (let y = 0; y < this._nComponents; y++)
602
- o[g * this._nComponents + y] += m * this.embedding[_ * this._nComponents + y];
604
+ for (let b = 0; b < this._nComponents; b++)
605
+ i[g * this._nComponents + b] += m * this.embedding[_ * this._nComponents + b];
603
606
  }
604
607
  for (let p = 0; p < a; p++)
605
608
  if (r[p] > 0)
606
609
  for (let g = 0; g < this._nComponents; g++)
607
- o[p * this._nComponents + g] /= r[p];
610
+ i[p * this._nComponents + g] /= r[p];
608
611
  else
609
612
  for (let g = 0; g < this._nComponents; g++)
610
- o[p * this._nComponents + g] = Math.random() * 20 - 10;
613
+ i[p * this._nComponents + g] = Math.random() * 20 - 10;
611
614
  const h = W(d.vals), l = fe(
612
- o,
615
+ i,
613
616
  this.embedding,
614
617
  d,
615
618
  h,
@@ -638,13 +641,13 @@ function K(e, t, f) {
638
641
  const a = new Float32Array(e.length);
639
642
  for (let s = 0; s < f; s++) {
640
643
  let u = 1 / 0, c = -1 / 0;
641
- for (let i = 0; i < t; i++) {
642
- const n = e[i * f + s];
644
+ for (let o = 0; o < t; o++) {
645
+ const n = e[o * f + s];
643
646
  n < u && (u = n), n > c && (c = n);
644
647
  }
645
648
  const d = c - u;
646
- for (let i = 0; i < t; i++)
647
- a[i * f + s] = d > 0 ? (e[i * f + s] - u) / d : 0;
649
+ for (let o = 0; o < t; o++)
650
+ a[o * f + s] = d > 0 ? (e[o * f + s] - u) / d : 0;
648
651
  }
649
652
  return a;
650
653
  }
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.2.14",
3
+ "version": "0.2.15",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -55,7 +55,8 @@
55
55
  "vite": "^5.0.0",
56
56
  "vitepress": "^1.6.4",
57
57
  "vitepress-plugin-llms": "^1.11.0",
58
- "vitest": "^4.0.18"
58
+ "vitest": "^4.0.18",
59
+ "webgpu": "^0.3.8"
59
60
  },
60
61
  "license": "MIT"
61
62
  }