umap-gpu 0.2.14 → 0.2.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/index.js +158 -155
- package/package.json +3 -2
package/dist/index.js
CHANGED
|
@@ -1,56 +1,56 @@
|
|
|
1
1
|
var te = Object.defineProperty;
|
|
2
2
|
var ne = (e, t, f) => t in e ? te(e, t, { enumerable: !0, configurable: !0, writable: !0, value: f }) : e[t] = f;
|
|
3
|
-
var
|
|
3
|
+
var R = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
|
|
4
4
|
import { loadHnswlib as H } from "hnswlib-wasm";
|
|
5
5
|
async function se(e, t, f = {}) {
|
|
6
|
-
const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length,
|
|
7
|
-
n.initIndex(
|
|
8
|
-
const r = [],
|
|
9
|
-
for (let h = 0; h <
|
|
6
|
+
const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
|
|
7
|
+
n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
|
|
8
|
+
const r = [], i = [];
|
|
9
|
+
for (let h = 0; h < o; h++) {
|
|
10
10
|
const l = n.searchKnn(e[h], t + 1, void 0), p = l.neighbors.map((g, _) => ({ idx: g, dist: l.distances[_] })).filter(({ idx: g }) => g !== h).slice(0, t);
|
|
11
|
-
r.push(p.map(({ idx: g }) => g)),
|
|
11
|
+
r.push(p.map(({ idx: g }) => g)), i.push(p.map(({ dist: g }) => Math.sqrt(g)));
|
|
12
12
|
}
|
|
13
|
-
return { indices: r, distances:
|
|
13
|
+
return { indices: r, distances: i };
|
|
14
14
|
}
|
|
15
15
|
async function ae(e, t, f = {}) {
|
|
16
|
-
const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length,
|
|
17
|
-
n.initIndex(
|
|
18
|
-
const r = [],
|
|
19
|
-
for (let l = 0; l <
|
|
16
|
+
const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
|
|
17
|
+
n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
|
|
18
|
+
const r = [], i = [];
|
|
19
|
+
for (let l = 0; l < o; l++) {
|
|
20
20
|
const p = n.searchKnn(e[l], t + 1, void 0), g = p.neighbors.map((_, m) => ({ idx: _, dist: p.distances[m] })).filter(({ idx: _ }) => _ !== l).slice(0, t);
|
|
21
|
-
r.push(g.map(({ idx: _ }) => _)),
|
|
21
|
+
r.push(g.map(({ idx: _ }) => _)), i.push(g.map(({ dist: _ }) => Math.sqrt(_)));
|
|
22
22
|
}
|
|
23
|
-
return { knn: { indices: r, distances:
|
|
23
|
+
return { knn: { indices: r, distances: i }, index: {
|
|
24
24
|
searchKnn(l, p) {
|
|
25
25
|
const g = [], _ = [];
|
|
26
26
|
for (const m of l) {
|
|
27
|
-
const
|
|
28
|
-
g.push(
|
|
27
|
+
const b = n.searchKnn(m, p, void 0), y = b.neighbors.map((x, v) => ({ idx: x, dist: b.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
|
|
28
|
+
g.push(y.map(({ idx: x }) => x)), _.push(y.map(({ dist: x }) => Math.sqrt(x)));
|
|
29
29
|
}
|
|
30
30
|
return { indices: g, distances: _ };
|
|
31
31
|
}
|
|
32
32
|
} };
|
|
33
33
|
}
|
|
34
34
|
function V(e, t, f, a = 1) {
|
|
35
|
-
const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [],
|
|
36
|
-
for (let
|
|
37
|
-
for (let h = 0; h < e[
|
|
38
|
-
const l = t[
|
|
39
|
-
d.push(
|
|
35
|
+
const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], o = [], n = [];
|
|
36
|
+
for (let i = 0; i < s; i++)
|
|
37
|
+
for (let h = 0; h < e[i].length; h++) {
|
|
38
|
+
const l = t[i][h], p = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / u[i]));
|
|
39
|
+
d.push(i), o.push(e[i][h]), n.push(p);
|
|
40
40
|
}
|
|
41
|
-
return { ...
|
|
41
|
+
return { ...ie(d, o, n, s, a), nVertices: s };
|
|
42
42
|
}
|
|
43
43
|
function re(e, t, f) {
|
|
44
|
-
const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [],
|
|
44
|
+
const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], o = [];
|
|
45
45
|
for (let n = 0; n < a; n++)
|
|
46
46
|
for (let r = 0; r < e[n].length; r++) {
|
|
47
|
-
const
|
|
48
|
-
c.push(n), d.push(e[n][r]),
|
|
47
|
+
const i = t[n][r], h = i <= u[n] ? 1 : Math.exp(-((i - u[n]) / s[n]));
|
|
48
|
+
c.push(n), d.push(e[n][r]), o.push(h);
|
|
49
49
|
}
|
|
50
50
|
return {
|
|
51
51
|
rows: new Uint32Array(c),
|
|
52
52
|
cols: new Uint32Array(d),
|
|
53
|
-
vals: new Float32Array(
|
|
53
|
+
vals: new Float32Array(o),
|
|
54
54
|
nVertices: a
|
|
55
55
|
};
|
|
56
56
|
}
|
|
@@ -59,35 +59,35 @@ function Y(e, t) {
|
|
|
59
59
|
for (let c = 0; c < a; c++) {
|
|
60
60
|
const d = e[c];
|
|
61
61
|
u[c] = d.find((h) => h > 0) ?? 0;
|
|
62
|
-
let
|
|
63
|
-
const
|
|
62
|
+
let o = 0, n = 1 / 0, r = 1;
|
|
63
|
+
const i = Math.log2(t);
|
|
64
64
|
for (let h = 0; h < 64; h++) {
|
|
65
65
|
let l = 0;
|
|
66
66
|
for (let p = 0; p < d.length; p++)
|
|
67
67
|
l += Math.exp(-Math.max(0, d[p] - u[c]) / r);
|
|
68
|
-
if (Math.abs(l -
|
|
69
|
-
l >
|
|
68
|
+
if (Math.abs(l - i) < 1e-5) break;
|
|
69
|
+
l > i ? (n = r, r = (o + n) / 2) : (o = r, r = n === 1 / 0 ? r * 2 : (o + n) / 2);
|
|
70
70
|
}
|
|
71
71
|
s[c] = r;
|
|
72
72
|
}
|
|
73
73
|
return { sigmas: s, rhos: u };
|
|
74
74
|
}
|
|
75
|
-
function
|
|
75
|
+
function ie(e, t, f, a, s) {
|
|
76
76
|
const u = /* @__PURE__ */ new Map();
|
|
77
77
|
for (let n = 0; n < e.length; n++)
|
|
78
78
|
u.set(e[n] * a + t[n], f[n]);
|
|
79
|
-
const c = [], d = [],
|
|
79
|
+
const c = [], d = [], o = [];
|
|
80
80
|
for (const [n, r] of u) {
|
|
81
|
-
const
|
|
82
|
-
c.push(
|
|
81
|
+
const i = Math.floor(n / a), h = n % a, l = u.get(h * a + i) ?? 0, p = r + l - r * l, g = r * l;
|
|
82
|
+
c.push(i), d.push(h), o.push(s * p + (1 - s) * g);
|
|
83
83
|
}
|
|
84
84
|
return {
|
|
85
85
|
rows: new Uint32Array(c),
|
|
86
86
|
cols: new Uint32Array(d),
|
|
87
|
-
vals: new Float32Array(
|
|
87
|
+
vals: new Float32Array(o)
|
|
88
88
|
};
|
|
89
89
|
}
|
|
90
|
-
const
|
|
90
|
+
const oe = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
|
|
91
91
|
// Computes attraction and repulsion forces and accumulates them atomically
|
|
92
92
|
// into a forces buffer. A separate apply-forces pass then updates embeddings,
|
|
93
93
|
// eliminating write-write races on shared vertex positions.
|
|
@@ -181,7 +181,10 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
181
181
|
epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * epochs_per_neg;
|
|
182
182
|
}
|
|
183
183
|
|
|
184
|
-
|
|
184
|
+
// 2654435761u is the 32-bit golden-ratio hash constant (0x9E3779B1),
|
|
185
|
+
// which fits in u32 unlike the 64-bit LCG value 6364136223 used originally.
|
|
186
|
+
// Bug 14 fix: the original constant exceeded u32 range and failed WGSL validation.
|
|
187
|
+
var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 2654435761u);
|
|
185
188
|
|
|
186
189
|
for (var s = 0u; s < n_neg; s++) {
|
|
187
190
|
rng = xorshift(rng);
|
|
@@ -250,9 +253,9 @@ async function he() {
|
|
|
250
253
|
}
|
|
251
254
|
class X {
|
|
252
255
|
constructor() {
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
+
R(this, "device");
|
|
257
|
+
R(this, "sgdPipeline");
|
|
258
|
+
R(this, "applyForcesPipeline");
|
|
256
259
|
}
|
|
257
260
|
async init() {
|
|
258
261
|
const t = await Q();
|
|
@@ -260,7 +263,7 @@ class X {
|
|
|
260
263
|
this.device = t, this.sgdPipeline = this.device.createComputePipeline({
|
|
261
264
|
layout: "auto",
|
|
262
265
|
compute: {
|
|
263
|
-
module: this.device.createShaderModule({ code:
|
|
266
|
+
module: this.device.createShaderModule({ code: oe }),
|
|
264
267
|
entryPoint: "main"
|
|
265
268
|
}
|
|
266
269
|
}), this.applyForcesPipeline = this.device.createComputePipeline({
|
|
@@ -284,22 +287,22 @@ class X {
|
|
|
284
287
|
* @param params - UMAP curve parameters and repulsion settings
|
|
285
288
|
* @returns Optimized embedding as Float32Array
|
|
286
289
|
*/
|
|
287
|
-
async optimize(t, f, a, s, u, c, d,
|
|
288
|
-
const { device: r } = this,
|
|
290
|
+
async optimize(t, f, a, s, u, c, d, o, n) {
|
|
291
|
+
const { device: r } = this, i = f.length, h = u * c, l = this.makeBuffer(
|
|
289
292
|
t,
|
|
290
293
|
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
|
|
291
|
-
), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s),
|
|
292
|
-
for (let A = 0; A <
|
|
293
|
-
|
|
294
|
-
const x = this.makeBuffer(
|
|
295
|
-
for (let A = 0; A <
|
|
294
|
+
), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), b = this.makeBuffer(m, GPUBufferUsage.STORAGE), y = new Float32Array(i);
|
|
295
|
+
for (let A = 0; A < i; A++)
|
|
296
|
+
y[A] = s[A] / o.negativeSampleRate;
|
|
297
|
+
const x = this.makeBuffer(y, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
|
|
298
|
+
for (let A = 0; A < i; A++)
|
|
296
299
|
v[A] = Math.random() * 4294967295 | 0;
|
|
297
|
-
const
|
|
300
|
+
const G = this.makeBuffer(v, GPUBufferUsage.STORAGE), P = r.createBuffer({
|
|
298
301
|
size: h * 4,
|
|
299
302
|
usage: GPUBufferUsage.STORAGE,
|
|
300
303
|
mappedAtCreation: !0
|
|
301
304
|
});
|
|
302
|
-
new Int32Array(
|
|
305
|
+
new Int32Array(P.getMappedRange()).fill(0), P.unmap();
|
|
303
306
|
const U = r.createBuffer({
|
|
304
307
|
size: 40,
|
|
305
308
|
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
|
|
@@ -313,37 +316,37 @@ class X {
|
|
|
313
316
|
{ binding: 1, resource: { buffer: p } },
|
|
314
317
|
{ binding: 2, resource: { buffer: g } },
|
|
315
318
|
{ binding: 3, resource: { buffer: l } },
|
|
316
|
-
{ binding: 4, resource: { buffer:
|
|
319
|
+
{ binding: 4, resource: { buffer: b } },
|
|
317
320
|
{ binding: 5, resource: { buffer: x } },
|
|
318
321
|
{ binding: 6, resource: { buffer: U } },
|
|
319
|
-
{ binding: 7, resource: { buffer:
|
|
320
|
-
{ binding: 8, resource: { buffer:
|
|
322
|
+
{ binding: 7, resource: { buffer: G } },
|
|
323
|
+
{ binding: 8, resource: { buffer: P } }
|
|
321
324
|
]
|
|
322
325
|
}), M = r.createBindGroup({
|
|
323
326
|
layout: this.applyForcesPipeline.getBindGroupLayout(0),
|
|
324
327
|
entries: [
|
|
325
328
|
{ binding: 0, resource: { buffer: l } },
|
|
326
|
-
{ binding: 1, resource: { buffer:
|
|
329
|
+
{ binding: 1, resource: { buffer: P } },
|
|
327
330
|
{ binding: 2, resource: { buffer: F } }
|
|
328
331
|
]
|
|
329
332
|
});
|
|
330
333
|
for (let A = 0; A < d; A++) {
|
|
331
334
|
const B = 1 - A / d, O = new ArrayBuffer(40), S = new Uint32Array(O), k = new Float32Array(O);
|
|
332
|
-
S[0] =
|
|
335
|
+
S[0] = i, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = o.a, k[7] = o.b, k[8] = o.gamma, S[9] = o.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
|
|
333
336
|
const D = new ArrayBuffer(16), $ = new Uint32Array(D), ee = new Float32Array(D);
|
|
334
337
|
$[0] = h, ee[1] = B, r.queue.writeBuffer(F, 0, D);
|
|
335
338
|
const T = r.createCommandEncoder(), q = T.beginComputePass();
|
|
336
|
-
q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(
|
|
337
|
-
const
|
|
338
|
-
|
|
339
|
+
q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(i / 256)), q.end();
|
|
340
|
+
const L = T.beginComputePass();
|
|
341
|
+
L.setPipeline(this.applyForcesPipeline), L.setBindGroup(0, M), L.dispatchWorkgroups(Math.ceil(h / 256)), L.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
|
|
339
342
|
}
|
|
340
343
|
const E = r.createBuffer({
|
|
341
344
|
size: t.byteLength,
|
|
342
345
|
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
|
|
343
346
|
}), w = r.createCommandEncoder();
|
|
344
347
|
w.copyBufferToBuffer(l, 0, E, 0, t.byteLength), r.queue.submit([w.finish()]), await E.mapAsync(GPUMapMode.READ);
|
|
345
|
-
const
|
|
346
|
-
return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(),
|
|
348
|
+
const C = new Float32Array(E.getMappedRange().slice(0));
|
|
349
|
+
return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), b.destroy(), x.destroy(), G.destroy(), P.destroy(), U.destroy(), F.destroy(), E.destroy(), C;
|
|
347
350
|
}
|
|
348
351
|
makeBuffer(t, f) {
|
|
349
352
|
const a = this.device.createBuffer({
|
|
@@ -354,34 +357,34 @@ class X {
|
|
|
354
357
|
return t instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(t) : new Uint32Array(a.getMappedRange()).set(t), a.unmap(), a;
|
|
355
358
|
}
|
|
356
359
|
}
|
|
357
|
-
function
|
|
360
|
+
function I(e) {
|
|
358
361
|
return Math.max(-4, Math.min(4, e));
|
|
359
362
|
}
|
|
360
|
-
function
|
|
361
|
-
const { a:
|
|
363
|
+
function j(e, t, f, a, s, u, c, d) {
|
|
364
|
+
const { a: o, b: n, gamma: r = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
|
|
362
365
|
for (let m = 0; m < h; m++)
|
|
363
|
-
_[m] = f[m] /
|
|
366
|
+
_[m] = f[m] / i;
|
|
364
367
|
for (let m = 0; m < u; m++) {
|
|
365
368
|
d == null || d(m, u);
|
|
366
|
-
const
|
|
367
|
-
for (let
|
|
368
|
-
if (g[
|
|
369
|
-
const x = l[
|
|
370
|
-
let
|
|
369
|
+
const b = 1 - m / u;
|
|
370
|
+
for (let y = 0; y < h; y++) {
|
|
371
|
+
if (g[y] > m) continue;
|
|
372
|
+
const x = l[y], v = p[y];
|
|
373
|
+
let G = 0;
|
|
371
374
|
for (let M = 0; M < s; M++) {
|
|
372
375
|
const E = e[x * s + M] - e[v * s + M];
|
|
373
|
-
|
|
376
|
+
G += E * E;
|
|
374
377
|
}
|
|
375
|
-
const
|
|
378
|
+
const P = Math.pow(G, n), U = -2 * o * n * (G > 0 ? P / G : 0) / (o * P + 1);
|
|
376
379
|
for (let M = 0; M < s; M++) {
|
|
377
|
-
const E = e[x * s + M] - e[v * s + M], w =
|
|
378
|
-
e[x * s + M] +=
|
|
380
|
+
const E = e[x * s + M] - e[v * s + M], w = I(U * E);
|
|
381
|
+
e[x * s + M] += b * w, e[v * s + M] -= b * w;
|
|
379
382
|
}
|
|
380
|
-
g[
|
|
381
|
-
const F = f[
|
|
382
|
-
(m - _[
|
|
383
|
+
g[y] += f[y];
|
|
384
|
+
const F = f[y] / i, N = Math.max(0, Math.floor(
|
|
385
|
+
(m - _[y]) / F
|
|
383
386
|
));
|
|
384
|
-
_[
|
|
387
|
+
_[y] += N * F;
|
|
385
388
|
for (let M = 0; M < N; M++) {
|
|
386
389
|
const E = Math.floor(Math.random() * a);
|
|
387
390
|
if (E === x) continue;
|
|
@@ -390,52 +393,52 @@ function L(e, t, f, a, s, u, c, d) {
|
|
|
390
393
|
const O = e[x * s + B] - e[E * s + B];
|
|
391
394
|
w += O * O;
|
|
392
395
|
}
|
|
393
|
-
const
|
|
396
|
+
const C = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (o * C + 1));
|
|
394
397
|
for (let B = 0; B < s; B++) {
|
|
395
|
-
const O = e[x * s + B] - e[E * s + B], S =
|
|
396
|
-
e[x * s + B] +=
|
|
398
|
+
const O = e[x * s + B] - e[E * s + B], S = I(A * O);
|
|
399
|
+
e[x * s + B] += b * S;
|
|
397
400
|
}
|
|
398
401
|
}
|
|
399
402
|
}
|
|
400
403
|
}
|
|
401
404
|
return e;
|
|
402
405
|
}
|
|
403
|
-
function fe(e, t, f, a, s, u, c, d,
|
|
404
|
-
const { a: r, b:
|
|
405
|
-
for (let
|
|
406
|
-
y
|
|
407
|
-
for (let
|
|
408
|
-
const x = 1 -
|
|
406
|
+
function fe(e, t, f, a, s, u, c, d, o, n) {
|
|
407
|
+
const { a: r, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = o, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), b = new Float32Array(p);
|
|
408
|
+
for (let y = 0; y < p; y++)
|
|
409
|
+
b[y] = a[y] / l;
|
|
410
|
+
for (let y = 0; y < d; y++) {
|
|
411
|
+
const x = 1 - y / d;
|
|
409
412
|
for (let v = 0; v < p; v++) {
|
|
410
|
-
if (m[v] >
|
|
411
|
-
const
|
|
413
|
+
if (m[v] > y) continue;
|
|
414
|
+
const G = g[v], P = _[v];
|
|
412
415
|
let U = 0;
|
|
413
416
|
for (let w = 0; w < c; w++) {
|
|
414
|
-
const
|
|
415
|
-
U +=
|
|
417
|
+
const C = e[G * c + w] - t[P * c + w];
|
|
418
|
+
U += C * C;
|
|
416
419
|
}
|
|
417
|
-
const F = Math.pow(U,
|
|
420
|
+
const F = Math.pow(U, i), N = -2 * r * i * (U > 0 ? F / U : 0) / (r * F + 1);
|
|
418
421
|
for (let w = 0; w < c; w++) {
|
|
419
|
-
const
|
|
420
|
-
e[
|
|
422
|
+
const C = e[G * c + w] - t[P * c + w];
|
|
423
|
+
e[G * c + w] += x * I(N * C);
|
|
421
424
|
}
|
|
422
425
|
m[v] += a[v];
|
|
423
426
|
const M = a[v] / l, E = Math.max(0, Math.floor(
|
|
424
|
-
(
|
|
427
|
+
(y - b[v]) / M
|
|
425
428
|
));
|
|
426
|
-
|
|
429
|
+
b[v] += E * M;
|
|
427
430
|
for (let w = 0; w < E; w++) {
|
|
428
|
-
const
|
|
429
|
-
if (
|
|
431
|
+
const C = Math.floor(Math.random() * u);
|
|
432
|
+
if (C === P) continue;
|
|
430
433
|
let A = 0;
|
|
431
434
|
for (let S = 0; S < c; S++) {
|
|
432
|
-
const k = e[
|
|
435
|
+
const k = e[G * c + S] - t[C * c + S];
|
|
433
436
|
A += k * k;
|
|
434
437
|
}
|
|
435
|
-
const B = Math.pow(A,
|
|
438
|
+
const B = Math.pow(A, i), O = 2 * h * i / ((1e-3 + A) * (r * B + 1));
|
|
436
439
|
for (let S = 0; S < c; S++) {
|
|
437
|
-
const k = e[
|
|
438
|
-
e[
|
|
440
|
+
const k = e[G * c + S] - t[C * c + S];
|
|
441
|
+
e[G * c + S] += x * I(O * k);
|
|
439
442
|
}
|
|
440
443
|
}
|
|
441
444
|
}
|
|
@@ -449,7 +452,7 @@ async function pe(e, t = {}, f) {
|
|
|
449
452
|
minDist: u = 0.1,
|
|
450
453
|
spread: c = 1,
|
|
451
454
|
hnsw: d = {}
|
|
452
|
-
} = t,
|
|
455
|
+
} = t, o = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
|
|
453
456
|
console.time("knn");
|
|
454
457
|
const { indices: n, distances: r } = await se(e, s, {
|
|
455
458
|
M: d.M ?? 16,
|
|
@@ -457,32 +460,32 @@ async function pe(e, t = {}, f) {
|
|
|
457
460
|
efSearch: d.efSearch ?? 50
|
|
458
461
|
});
|
|
459
462
|
console.timeEnd("knn"), console.time("fuzzy-set");
|
|
460
|
-
const
|
|
463
|
+
const i = V(n, r, s);
|
|
461
464
|
console.timeEnd("fuzzy-set");
|
|
462
|
-
const { a: h, b: l } = Z(u, c), p = W(
|
|
463
|
-
for (let
|
|
464
|
-
_[
|
|
465
|
+
const { a: h, b: l } = Z(u, c), p = W(i.vals), g = e.length, _ = new Float32Array(g * a);
|
|
466
|
+
for (let b = 0; b < _.length; b++)
|
|
467
|
+
_[b] = Math.random() * 20 - 10;
|
|
465
468
|
console.time("sgd");
|
|
466
469
|
let m;
|
|
467
470
|
if (J())
|
|
468
471
|
try {
|
|
469
|
-
const
|
|
470
|
-
await
|
|
472
|
+
const b = new X();
|
|
473
|
+
await b.init(), m = await b.optimize(
|
|
471
474
|
_,
|
|
472
|
-
new Uint32Array(
|
|
473
|
-
new Uint32Array(
|
|
475
|
+
new Uint32Array(i.rows),
|
|
476
|
+
new Uint32Array(i.cols),
|
|
474
477
|
p,
|
|
475
478
|
g,
|
|
476
479
|
a,
|
|
477
|
-
|
|
480
|
+
o,
|
|
478
481
|
{ a: h, b: l, gamma: 1, negativeSampleRate: 5 },
|
|
479
482
|
f
|
|
480
483
|
);
|
|
481
|
-
} catch (
|
|
482
|
-
console.warn("WebGPU SGD failed, falling back to CPU:",
|
|
484
|
+
} catch (b) {
|
|
485
|
+
console.warn("WebGPU SGD failed, falling back to CPU:", b), m = j(_, i, p, g, a, o, { a: h, b: l }, f);
|
|
483
486
|
}
|
|
484
487
|
else
|
|
485
|
-
m =
|
|
488
|
+
m = j(_, i, p, g, a, o, { a: h, b: l }, f);
|
|
486
489
|
return console.timeEnd("sgd"), m;
|
|
487
490
|
}
|
|
488
491
|
function Z(e, t) {
|
|
@@ -490,45 +493,45 @@ function Z(e, t) {
|
|
|
490
493
|
}
|
|
491
494
|
function de(e, t) {
|
|
492
495
|
const a = [], s = [];
|
|
493
|
-
for (let
|
|
494
|
-
const n = (
|
|
496
|
+
for (let o = 0; o < 299; o++) {
|
|
497
|
+
const n = (o + 1) / 299 * t * 3;
|
|
495
498
|
a.push(n), s.push(n < e ? 1 : Math.exp(-(n - e) / t));
|
|
496
499
|
}
|
|
497
500
|
let u = 1, c = 1, d = 1e-3;
|
|
498
|
-
for (let
|
|
499
|
-
let n = 0, r = 0,
|
|
501
|
+
for (let o = 0; o < 500; o++) {
|
|
502
|
+
let n = 0, r = 0, i = 0, h = 0, l = 0, p = 0;
|
|
500
503
|
for (let U = 0; U < 299; U++) {
|
|
501
504
|
const F = a[U], N = Math.pow(F, 2 * c), M = 1 + u * N, w = 1 / M - s[U];
|
|
502
505
|
p += w * w;
|
|
503
|
-
const
|
|
504
|
-
n += A * w, r += B * w,
|
|
506
|
+
const C = M * M, A = -N / C, B = F > 0 ? -2 * Math.log(F) * u * N / C : 0;
|
|
507
|
+
n += A * w, r += B * w, i += A * A, h += B * B, l += A * B;
|
|
505
508
|
}
|
|
506
|
-
const g =
|
|
507
|
-
if (Math.abs(
|
|
508
|
-
const
|
|
509
|
-
let
|
|
509
|
+
const g = i + d, _ = h + d, m = l, b = g * _ - m * m;
|
|
510
|
+
if (Math.abs(b) < 1e-20) break;
|
|
511
|
+
const y = -(_ * n - m * r) / b, x = -(g * r - m * n) / b, v = Math.max(1e-4, u + y), G = Math.max(1e-4, c + x);
|
|
512
|
+
let P = 0;
|
|
510
513
|
for (let U = 0; U < 299; U++) {
|
|
511
|
-
const F = Math.pow(a[U], 2 *
|
|
512
|
-
|
|
514
|
+
const F = Math.pow(a[U], 2 * G), N = 1 / (1 + v * F) - s[U];
|
|
515
|
+
P += N * N;
|
|
513
516
|
}
|
|
514
|
-
if (
|
|
517
|
+
if (P < p ? (u = v, c = G, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(y) < 1e-8 && Math.abs(x) < 1e-8) break;
|
|
515
518
|
}
|
|
516
519
|
return { a: u, b: c };
|
|
517
520
|
}
|
|
518
521
|
class ge {
|
|
519
522
|
constructor(t = {}) {
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
523
|
+
R(this, "_nComponents");
|
|
524
|
+
R(this, "_nNeighbors");
|
|
525
|
+
R(this, "_minDist");
|
|
526
|
+
R(this, "_spread");
|
|
527
|
+
R(this, "_nEpochs");
|
|
528
|
+
R(this, "_hnswOpts");
|
|
529
|
+
R(this, "_a");
|
|
530
|
+
R(this, "_b");
|
|
528
531
|
/** The low-dimensional embedding produced by the last fit() call. */
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
+
R(this, "embedding", null);
|
|
533
|
+
R(this, "_hnswIndex", null);
|
|
534
|
+
R(this, "_nTrain", 0);
|
|
532
535
|
this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
|
|
533
536
|
const { a: f, b: a } = Z(this._minDist, this._spread);
|
|
534
537
|
this._a = f, this._b = a;
|
|
@@ -542,15 +545,15 @@ class ge {
|
|
|
542
545
|
async fit(t, f) {
|
|
543
546
|
const a = t.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: u = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
|
|
544
547
|
console.time("knn");
|
|
545
|
-
const { knn:
|
|
548
|
+
const { knn: o, index: n } = await ae(t, this._nNeighbors, {
|
|
546
549
|
M: u,
|
|
547
550
|
efConstruction: c,
|
|
548
551
|
efSearch: d
|
|
549
552
|
});
|
|
550
553
|
this._hnswIndex = n, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
|
|
551
|
-
const r = V(
|
|
554
|
+
const r = V(o.indices, o.distances, this._nNeighbors);
|
|
552
555
|
console.timeEnd("fuzzy-set");
|
|
553
|
-
const
|
|
556
|
+
const i = W(r.vals), h = new Float32Array(a * this._nComponents);
|
|
554
557
|
for (let l = 0; l < h.length; l++)
|
|
555
558
|
h[l] = Math.random() * 20 - 10;
|
|
556
559
|
if (console.time("sgd"), J())
|
|
@@ -560,7 +563,7 @@ class ge {
|
|
|
560
563
|
h,
|
|
561
564
|
new Uint32Array(r.rows),
|
|
562
565
|
new Uint32Array(r.cols),
|
|
563
|
-
|
|
566
|
+
i,
|
|
564
567
|
a,
|
|
565
568
|
this._nComponents,
|
|
566
569
|
s,
|
|
@@ -568,13 +571,13 @@ class ge {
|
|
|
568
571
|
f
|
|
569
572
|
);
|
|
570
573
|
} catch (l) {
|
|
571
|
-
console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding =
|
|
574
|
+
console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = j(h, r, i, a, this._nComponents, s, {
|
|
572
575
|
a: this._a,
|
|
573
576
|
b: this._b
|
|
574
577
|
}, f);
|
|
575
578
|
}
|
|
576
579
|
else
|
|
577
|
-
this.embedding =
|
|
580
|
+
this.embedding = j(h, r, i, a, this._nComponents, s, {
|
|
578
581
|
a: this._a,
|
|
579
582
|
b: this._b
|
|
580
583
|
}, f);
|
|
@@ -594,22 +597,22 @@ class ge {
|
|
|
594
597
|
async transform(t, f = !1) {
|
|
595
598
|
if (!this._hnswIndex || !this.embedding)
|
|
596
599
|
throw new Error("UMAP.transform() must be called after fit()");
|
|
597
|
-
const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors),
|
|
598
|
-
for (let p = 0; p <
|
|
599
|
-
const g =
|
|
600
|
+
const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), o = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), i = new Float32Array(a * this._nComponents);
|
|
601
|
+
for (let p = 0; p < o.length; p++) {
|
|
602
|
+
const g = o[p], _ = n[p], m = d.vals[p];
|
|
600
603
|
r[g] += m;
|
|
601
|
-
for (let
|
|
602
|
-
|
|
604
|
+
for (let b = 0; b < this._nComponents; b++)
|
|
605
|
+
i[g * this._nComponents + b] += m * this.embedding[_ * this._nComponents + b];
|
|
603
606
|
}
|
|
604
607
|
for (let p = 0; p < a; p++)
|
|
605
608
|
if (r[p] > 0)
|
|
606
609
|
for (let g = 0; g < this._nComponents; g++)
|
|
607
|
-
|
|
610
|
+
i[p * this._nComponents + g] /= r[p];
|
|
608
611
|
else
|
|
609
612
|
for (let g = 0; g < this._nComponents; g++)
|
|
610
|
-
|
|
613
|
+
i[p * this._nComponents + g] = Math.random() * 20 - 10;
|
|
611
614
|
const h = W(d.vals), l = fe(
|
|
612
|
-
|
|
615
|
+
i,
|
|
613
616
|
this.embedding,
|
|
614
617
|
d,
|
|
615
618
|
h,
|
|
@@ -638,13 +641,13 @@ function K(e, t, f) {
|
|
|
638
641
|
const a = new Float32Array(e.length);
|
|
639
642
|
for (let s = 0; s < f; s++) {
|
|
640
643
|
let u = 1 / 0, c = -1 / 0;
|
|
641
|
-
for (let
|
|
642
|
-
const n = e[
|
|
644
|
+
for (let o = 0; o < t; o++) {
|
|
645
|
+
const n = e[o * f + s];
|
|
643
646
|
n < u && (u = n), n > c && (c = n);
|
|
644
647
|
}
|
|
645
648
|
const d = c - u;
|
|
646
|
-
for (let
|
|
647
|
-
a[
|
|
649
|
+
for (let o = 0; o < t; o++)
|
|
650
|
+
a[o * f + s] = d > 0 ? (e[o * f + s] - u) / d : 0;
|
|
648
651
|
}
|
|
649
652
|
return a;
|
|
650
653
|
}
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "umap-gpu",
|
|
3
|
-
"version": "0.2.
|
|
3
|
+
"version": "0.2.15",
|
|
4
4
|
"description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "dist/index.js",
|
|
@@ -55,7 +55,8 @@
|
|
|
55
55
|
"vite": "^5.0.0",
|
|
56
56
|
"vitepress": "^1.6.4",
|
|
57
57
|
"vitepress-plugin-llms": "^1.11.0",
|
|
58
|
-
"vitest": "^4.0.18"
|
|
58
|
+
"vitest": "^4.0.18",
|
|
59
|
+
"webgpu": "^0.3.8"
|
|
59
60
|
},
|
|
60
61
|
"license": "MIT"
|
|
61
62
|
}
|