umap-gpu 0.2.10 → 0.2.11
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +2 -0
- package/dist/index.js +148 -147
- package/package.json +1 -1
package/README.md
CHANGED
package/dist/index.js
CHANGED
|
@@ -1,29 +1,29 @@
|
|
|
1
1
|
var H = Object.defineProperty;
|
|
2
|
-
var V = (e,
|
|
3
|
-
var G = (e,
|
|
2
|
+
var V = (e, n, r) => n in e ? H(e, n, { enumerable: !0, configurable: !0, writable: !0, value: r }) : e[n] = r;
|
|
3
|
+
var G = (e, n, r) => V(e, typeof n != "symbol" ? n + "" : n, r);
|
|
4
4
|
import { loadHnswlib as C } from "hnswlib-wasm";
|
|
5
|
-
async function Y(e,
|
|
6
|
-
const { M: s = 16, efConstruction:
|
|
7
|
-
a.initIndex(c, s,
|
|
5
|
+
async function Y(e, n, r = {}) {
|
|
6
|
+
const { M: s = 16, efConstruction: t = 200, efSearch: p = 50 } = r, f = await C(), l = e[0].length, c = e.length, a = new f.HierarchicalNSW("l2", l, "");
|
|
7
|
+
a.initIndex(c, s, t, 200), a.setEfSearch(Math.max(p, n)), a.addItems(e, !1);
|
|
8
8
|
const o = [], i = [];
|
|
9
9
|
for (let h = 0; h < c; h++) {
|
|
10
|
-
const
|
|
10
|
+
const d = a.searchKnn(e[h], n + 1, void 0), u = d.neighbors.map((g, m) => ({ idx: g, dist: d.distances[m] })).filter(({ idx: g }) => g !== h).slice(0, n);
|
|
11
11
|
o.push(u.map(({ idx: g }) => g)), i.push(u.map(({ dist: g }) => g));
|
|
12
12
|
}
|
|
13
13
|
return { indices: o, distances: i };
|
|
14
14
|
}
|
|
15
|
-
async function Q(e,
|
|
16
|
-
const { M: s = 16, efConstruction:
|
|
17
|
-
a.initIndex(c, s,
|
|
15
|
+
async function Q(e, n, r = {}) {
|
|
16
|
+
const { M: s = 16, efConstruction: t = 200, efSearch: p = 50 } = r, f = await C(), l = e[0].length, c = e.length, a = new f.HierarchicalNSW("l2", l, "");
|
|
17
|
+
a.initIndex(c, s, t, 200), a.setEfSearch(Math.max(p, n)), a.addItems(e, !1);
|
|
18
18
|
const o = [], i = [];
|
|
19
|
-
for (let
|
|
20
|
-
const u = a.searchKnn(e[
|
|
19
|
+
for (let d = 0; d < c; d++) {
|
|
20
|
+
const u = a.searchKnn(e[d], n + 1, void 0), g = u.neighbors.map((m, w) => ({ idx: m, dist: u.distances[w] })).filter(({ idx: m }) => m !== d).slice(0, n);
|
|
21
21
|
o.push(g.map(({ idx: m }) => m)), i.push(g.map(({ dist: m }) => m));
|
|
22
22
|
}
|
|
23
23
|
return { knn: { indices: o, distances: i }, index: {
|
|
24
|
-
searchKnn(
|
|
24
|
+
searchKnn(d, u) {
|
|
25
25
|
const g = [], m = [];
|
|
26
|
-
for (const w of
|
|
26
|
+
for (const w of d) {
|
|
27
27
|
const b = a.searchKnn(w, u, void 0), y = b.neighbors.map((v, x) => ({ idx: v, dist: b.distances[x] })).sort((v, x) => v.dist - x.dist).slice(0, u);
|
|
28
28
|
g.push(y.map(({ idx: v }) => v)), m.push(y.map(({ dist: v }) => v));
|
|
29
29
|
}
|
|
@@ -31,59 +31,59 @@ async function Q(e, t, r = {}) {
|
|
|
31
31
|
}
|
|
32
32
|
} };
|
|
33
33
|
}
|
|
34
|
-
function D(e,
|
|
35
|
-
const
|
|
36
|
-
for (let i = 0; i <
|
|
34
|
+
function D(e, n, r, s = 1) {
|
|
35
|
+
const t = e.length, { sigmas: p, rhos: f } = L(n, r), l = [], c = [], a = [];
|
|
36
|
+
for (let i = 0; i < t; i++)
|
|
37
37
|
for (let h = 0; h < e[i].length; h++) {
|
|
38
|
-
const
|
|
39
|
-
|
|
38
|
+
const d = n[i][h], u = d <= f[i] ? 1 : Math.exp(-((d - f[i]) / p[i]));
|
|
39
|
+
l.push(i), c.push(e[i][h]), a.push(u);
|
|
40
40
|
}
|
|
41
|
-
return { ...X(
|
|
41
|
+
return { ...X(l, c, a, t, s), nVertices: t };
|
|
42
42
|
}
|
|
43
|
-
function J(e,
|
|
44
|
-
const s = e.length, { sigmas:
|
|
43
|
+
function J(e, n, r) {
|
|
44
|
+
const s = e.length, { sigmas: t, rhos: p } = L(n, r), f = [], l = [], c = [];
|
|
45
45
|
for (let a = 0; a < s; a++)
|
|
46
46
|
for (let o = 0; o < e[a].length; o++) {
|
|
47
|
-
const i =
|
|
48
|
-
f.push(a),
|
|
47
|
+
const i = n[a][o], h = i <= p[a] ? 1 : Math.exp(-((i - p[a]) / t[a]));
|
|
48
|
+
f.push(a), l.push(e[a][o]), c.push(h);
|
|
49
49
|
}
|
|
50
50
|
return {
|
|
51
51
|
rows: new Float32Array(f),
|
|
52
|
-
cols: new Float32Array(
|
|
52
|
+
cols: new Float32Array(l),
|
|
53
53
|
vals: new Float32Array(c),
|
|
54
54
|
nVertices: s
|
|
55
55
|
};
|
|
56
56
|
}
|
|
57
|
-
function L(e,
|
|
58
|
-
const s = e.length,
|
|
57
|
+
function L(e, n) {
|
|
58
|
+
const s = e.length, t = new Float32Array(s), p = new Float32Array(s);
|
|
59
59
|
for (let f = 0; f < s; f++) {
|
|
60
|
-
const
|
|
61
|
-
p[f] =
|
|
60
|
+
const l = e[f];
|
|
61
|
+
p[f] = l.find((h) => h > 0) ?? 0;
|
|
62
62
|
let c = 0, a = 1 / 0, o = 1;
|
|
63
|
-
const i = Math.log2(
|
|
63
|
+
const i = Math.log2(n);
|
|
64
64
|
for (let h = 0; h < 64; h++) {
|
|
65
|
-
let
|
|
66
|
-
for (let u =
|
|
67
|
-
|
|
68
|
-
if (Math.abs(
|
|
69
|
-
|
|
65
|
+
let d = 0;
|
|
66
|
+
for (let u = 0; u < l.length; u++)
|
|
67
|
+
d += Math.exp(-Math.max(0, l[u] - p[f]) / o);
|
|
68
|
+
if (Math.abs(d - i) < 1e-5) break;
|
|
69
|
+
d > i ? (a = o, o = (c + a) / 2) : (c = o, o = a === 1 / 0 ? o * 2 : (c + a) / 2);
|
|
70
70
|
}
|
|
71
|
-
|
|
71
|
+
t[f] = o;
|
|
72
72
|
}
|
|
73
|
-
return { sigmas:
|
|
73
|
+
return { sigmas: t, rhos: p };
|
|
74
74
|
}
|
|
75
|
-
function X(e,
|
|
75
|
+
function X(e, n, r, s, t) {
|
|
76
76
|
const p = /* @__PURE__ */ new Map();
|
|
77
77
|
for (let a = 0; a < e.length; a++)
|
|
78
|
-
p.set(e[a] * s +
|
|
79
|
-
const f = [],
|
|
78
|
+
p.set(e[a] * s + n[a], r[a]);
|
|
79
|
+
const f = [], l = [], c = [];
|
|
80
80
|
for (const [a, o] of p) {
|
|
81
|
-
const i = Math.floor(a / s), h = a % s,
|
|
82
|
-
f.push(i),
|
|
81
|
+
const i = Math.floor(a / s), h = a % s, d = p.get(h * s + i) ?? 0, u = o + d - o * d, g = o * d;
|
|
82
|
+
f.push(i), l.push(h), c.push(t * u + (1 - t) * g);
|
|
83
83
|
}
|
|
84
84
|
return {
|
|
85
85
|
rows: new Float32Array(f),
|
|
86
|
-
cols: new Float32Array(
|
|
86
|
+
cols: new Float32Array(l),
|
|
87
87
|
vals: new Float32Array(c)
|
|
88
88
|
};
|
|
89
89
|
}
|
|
@@ -158,6 +158,7 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
158
158
|
let diff = embedding[i * nc + d] - embedding[j * nc + d];
|
|
159
159
|
let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
|
|
160
160
|
embedding[i * nc + d] += params.alpha * grad;
|
|
161
|
+
embedding[j * nc + d] -= params.alpha * grad;
|
|
161
162
|
}
|
|
162
163
|
|
|
163
164
|
epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
|
|
@@ -202,9 +203,9 @@ class W {
|
|
|
202
203
|
G(this, "pipeline");
|
|
203
204
|
}
|
|
204
205
|
async init() {
|
|
205
|
-
const
|
|
206
|
-
if (!
|
|
207
|
-
this.device = await
|
|
206
|
+
const n = await navigator.gpu.requestAdapter();
|
|
207
|
+
if (!n) throw new Error("WebGPU not supported");
|
|
208
|
+
this.device = await n.requestDevice(), this.pipeline = this.device.createComputePipeline({
|
|
208
209
|
layout: "auto",
|
|
209
210
|
compute: {
|
|
210
211
|
module: this.device.createShaderModule({ code: Z }),
|
|
@@ -225,13 +226,13 @@ class W {
|
|
|
225
226
|
* @param params - UMAP curve parameters and repulsion settings
|
|
226
227
|
* @returns Optimized embedding as Float32Array
|
|
227
228
|
*/
|
|
228
|
-
async optimize(
|
|
229
|
+
async optimize(n, r, s, t, p, f, l, c, a) {
|
|
229
230
|
const { device: o } = this, i = r.length, h = this.makeBuffer(
|
|
230
|
-
|
|
231
|
+
n,
|
|
231
232
|
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
|
|
232
|
-
),
|
|
233
|
+
), d = this.makeBuffer(r, GPUBufferUsage.STORAGE), u = this.makeBuffer(s, GPUBufferUsage.STORAGE), g = this.makeBuffer(t, GPUBufferUsage.STORAGE), m = new Float32Array(i).fill(0), w = this.makeBuffer(m, GPUBufferUsage.STORAGE), b = new Float32Array(i);
|
|
233
234
|
for (let _ = 0; _ < i; _++)
|
|
234
|
-
b[_] =
|
|
235
|
+
b[_] = t[_] / c.negativeSampleRate;
|
|
235
236
|
const y = this.makeBuffer(b, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
|
|
236
237
|
for (let _ = 0; _ < i; _++)
|
|
237
238
|
v[_] = Math.random() * 4294967295 | 0;
|
|
@@ -239,14 +240,14 @@ class W {
|
|
|
239
240
|
size: 40,
|
|
240
241
|
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
|
|
241
242
|
});
|
|
242
|
-
for (let _ = 0; _ <
|
|
243
|
-
const F = 1 - _ /
|
|
244
|
-
M[0] = i, M[1] = p, M[2] = f, M[3] = _, M[4] =
|
|
243
|
+
for (let _ = 0; _ < l; _++) {
|
|
244
|
+
const F = 1 - _ / l, A = new ArrayBuffer(40), M = new Uint32Array(A), N = new Float32Array(A);
|
|
245
|
+
M[0] = i, M[1] = p, M[2] = f, M[3] = _, M[4] = l, N[5] = F, N[6] = c.a, N[7] = c.b, N[8] = c.gamma, M[9] = c.negativeSampleRate, o.queue.writeBuffer(B, 0, A);
|
|
245
246
|
const S = o.createBindGroup({
|
|
246
247
|
layout: this.pipeline.getBindGroupLayout(0),
|
|
247
248
|
entries: [
|
|
248
249
|
{ binding: 0, resource: { buffer: g } },
|
|
249
|
-
{ binding: 1, resource: { buffer:
|
|
250
|
+
{ binding: 1, resource: { buffer: d } },
|
|
250
251
|
{ binding: 2, resource: { buffer: u } },
|
|
251
252
|
{ binding: 3, resource: { buffer: h } },
|
|
252
253
|
{ binding: 4, resource: { buffer: w } },
|
|
@@ -255,47 +256,47 @@ class W {
|
|
|
255
256
|
{ binding: 7, resource: { buffer: x } }
|
|
256
257
|
]
|
|
257
258
|
}), O = o.createCommandEncoder(), U = O.beginComputePass();
|
|
258
|
-
U.setPipeline(this.pipeline), U.setBindGroup(0, S), U.dispatchWorkgroups(Math.ceil(i / 256)), U.end(), o.queue.submit([O.finish()]), _ % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(_,
|
|
259
|
+
U.setPipeline(this.pipeline), U.setBindGroup(0, S), U.dispatchWorkgroups(Math.ceil(i / 256)), U.end(), o.queue.submit([O.finish()]), _ % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(_, l));
|
|
259
260
|
}
|
|
260
261
|
const E = o.createBuffer({
|
|
261
|
-
size:
|
|
262
|
+
size: n.byteLength,
|
|
262
263
|
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
|
|
263
264
|
}), R = o.createCommandEncoder();
|
|
264
|
-
R.copyBufferToBuffer(h, 0, E, 0,
|
|
265
|
+
R.copyBufferToBuffer(h, 0, E, 0, n.byteLength), o.queue.submit([R.finish()]), await E.mapAsync(GPUMapMode.READ);
|
|
265
266
|
const k = new Float32Array(E.getMappedRange().slice(0));
|
|
266
|
-
return E.unmap(), h.destroy(),
|
|
267
|
+
return E.unmap(), h.destroy(), d.destroy(), u.destroy(), g.destroy(), w.destroy(), y.destroy(), x.destroy(), B.destroy(), E.destroy(), k;
|
|
267
268
|
}
|
|
268
|
-
makeBuffer(
|
|
269
|
+
makeBuffer(n, r) {
|
|
269
270
|
const s = this.device.createBuffer({
|
|
270
|
-
size:
|
|
271
|
+
size: n.byteLength,
|
|
271
272
|
usage: r,
|
|
272
273
|
mappedAtCreation: !0
|
|
273
274
|
});
|
|
274
|
-
return
|
|
275
|
+
return n instanceof Float32Array ? new Float32Array(s.getMappedRange()).set(n) : new Uint32Array(s.getMappedRange()).set(n), s.unmap(), s;
|
|
275
276
|
}
|
|
276
277
|
}
|
|
277
278
|
function q(e) {
|
|
278
279
|
return Math.max(-4, Math.min(4, e));
|
|
279
280
|
}
|
|
280
|
-
function z(e,
|
|
281
|
-
const { a: c, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = f, h =
|
|
281
|
+
function z(e, n, r, s, t, p, f, l) {
|
|
282
|
+
const { a: c, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = f, h = n.rows.length, d = new Uint32Array(n.rows), u = new Uint32Array(n.cols), g = new Float32Array(h).fill(0), m = new Float32Array(h);
|
|
282
283
|
for (let w = 0; w < h; w++)
|
|
283
284
|
m[w] = r[w] / i;
|
|
284
285
|
for (let w = 0; w < p; w++) {
|
|
285
|
-
|
|
286
|
+
l == null || l(w, p);
|
|
286
287
|
const b = 1 - w / p;
|
|
287
288
|
for (let y = 0; y < h; y++) {
|
|
288
289
|
if (g[y] > w) continue;
|
|
289
|
-
const v =
|
|
290
|
+
const v = d[y], x = u[y];
|
|
290
291
|
let B = 0;
|
|
291
|
-
for (let _ = 0; _ <
|
|
292
|
-
const F = e[v *
|
|
292
|
+
for (let _ = 0; _ < t; _++) {
|
|
293
|
+
const F = e[v * t + _] - e[x * t + _];
|
|
293
294
|
B += F * F;
|
|
294
295
|
}
|
|
295
296
|
const E = Math.pow(B, a), R = -2 * c * a * (B > 0 ? E / B : 0) / (c * E + 1);
|
|
296
|
-
for (let _ = 0; _ <
|
|
297
|
-
const F = e[v *
|
|
298
|
-
e[v *
|
|
297
|
+
for (let _ = 0; _ < t; _++) {
|
|
298
|
+
const F = e[v * t + _] - e[x * t + _], A = q(R * F);
|
|
299
|
+
e[v * t + _] += b * A, e[x * t + _] -= b * A;
|
|
299
300
|
}
|
|
300
301
|
g[y] += r[y];
|
|
301
302
|
const k = m[y] > 0 ? Math.floor(r[y] / m[y]) : 0;
|
|
@@ -303,14 +304,14 @@ function z(e, t, r, s, n, p, f, d) {
|
|
|
303
304
|
const F = Math.floor(Math.random() * s);
|
|
304
305
|
if (F === v) continue;
|
|
305
306
|
let A = 0;
|
|
306
|
-
for (let S = 0; S <
|
|
307
|
-
const O = e[v *
|
|
307
|
+
for (let S = 0; S < t; S++) {
|
|
308
|
+
const O = e[v * t + S] - e[F * t + S];
|
|
308
309
|
A += O * O;
|
|
309
310
|
}
|
|
310
311
|
const M = Math.pow(A, a), N = 2 * o * a / ((1e-3 + A) * (c * M + 1));
|
|
311
|
-
for (let S = 0; S <
|
|
312
|
-
const O = e[v *
|
|
313
|
-
e[v *
|
|
312
|
+
for (let S = 0; S < t; S++) {
|
|
313
|
+
const O = e[v * t + S] - e[F * t + S], U = q(N * O);
|
|
314
|
+
e[v * t + S] += b * U;
|
|
314
315
|
}
|
|
315
316
|
}
|
|
316
317
|
m[y] += r[y] / i;
|
|
@@ -318,23 +319,23 @@ function z(e, t, r, s, n, p, f, d) {
|
|
|
318
319
|
}
|
|
319
320
|
return e;
|
|
320
321
|
}
|
|
321
|
-
function $(e,
|
|
322
|
-
const { a: o, b: i, gamma: h = 1, negativeSampleRate:
|
|
322
|
+
function $(e, n, r, s, t, p, f, l, c, a) {
|
|
323
|
+
const { a: o, b: i, gamma: h = 1, negativeSampleRate: d = 5 } = c, u = r.rows.length, g = new Uint32Array(r.rows), m = new Uint32Array(r.cols), w = new Float32Array(u).fill(0), b = new Float32Array(u);
|
|
323
324
|
for (let y = 0; y < u; y++)
|
|
324
|
-
b[y] = s[y] /
|
|
325
|
-
for (let y = 0; y <
|
|
326
|
-
const v = 1 - y /
|
|
325
|
+
b[y] = s[y] / d;
|
|
326
|
+
for (let y = 0; y < l; y++) {
|
|
327
|
+
const v = 1 - y / l;
|
|
327
328
|
for (let x = 0; x < u; x++) {
|
|
328
329
|
if (w[x] > y) continue;
|
|
329
330
|
const B = g[x], E = m[x];
|
|
330
331
|
let R = 0;
|
|
331
332
|
for (let A = 0; A < f; A++) {
|
|
332
|
-
const M = e[B * f + A] -
|
|
333
|
+
const M = e[B * f + A] - n[E * f + A];
|
|
333
334
|
R += M * M;
|
|
334
335
|
}
|
|
335
336
|
const k = Math.pow(R, i), _ = -2 * o * i * (R > 0 ? k / R : 0) / (o * k + 1);
|
|
336
337
|
for (let A = 0; A < f; A++) {
|
|
337
|
-
const M = e[B * f + A] -
|
|
338
|
+
const M = e[B * f + A] - n[E * f + A];
|
|
338
339
|
e[B * f + A] += v * q(_ * M);
|
|
339
340
|
}
|
|
340
341
|
w[x] += s[x];
|
|
@@ -344,46 +345,46 @@ function $(e, t, r, s, n, p, f, d, c, a) {
|
|
|
344
345
|
if (M === E) continue;
|
|
345
346
|
let N = 0;
|
|
346
347
|
for (let U = 0; U < f; U++) {
|
|
347
|
-
const P = e[B * f + U] -
|
|
348
|
+
const P = e[B * f + U] - n[M * f + U];
|
|
348
349
|
N += P * P;
|
|
349
350
|
}
|
|
350
351
|
const S = Math.pow(N, i), O = 2 * h * i / ((1e-3 + N) * (o * S + 1));
|
|
351
352
|
for (let U = 0; U < f; U++) {
|
|
352
|
-
const P = e[B * f + U] -
|
|
353
|
+
const P = e[B * f + U] - n[M * f + U];
|
|
353
354
|
e[B * f + U] += v * q(O * P);
|
|
354
355
|
}
|
|
355
356
|
}
|
|
356
|
-
b[x] += s[x] /
|
|
357
|
+
b[x] += s[x] / d;
|
|
357
358
|
}
|
|
358
359
|
}
|
|
359
360
|
return e;
|
|
360
361
|
}
|
|
361
|
-
function
|
|
362
|
+
function j() {
|
|
362
363
|
return typeof navigator < "u" && !!navigator.gpu;
|
|
363
364
|
}
|
|
364
|
-
async function ae(e,
|
|
365
|
+
async function ae(e, n = {}, r) {
|
|
365
366
|
const {
|
|
366
367
|
nComponents: s = 2,
|
|
367
|
-
nNeighbors:
|
|
368
|
+
nNeighbors: t = 15,
|
|
368
369
|
minDist: p = 0.1,
|
|
369
370
|
spread: f = 1,
|
|
370
|
-
hnsw:
|
|
371
|
-
} =
|
|
371
|
+
hnsw: l = {}
|
|
372
|
+
} = n, c = n.nEpochs ?? (e.length > 1e4 ? 200 : 500);
|
|
372
373
|
console.time("knn");
|
|
373
|
-
const { indices: a, distances: o } = await Y(e,
|
|
374
|
-
M:
|
|
375
|
-
efConstruction:
|
|
376
|
-
efSearch:
|
|
374
|
+
const { indices: a, distances: o } = await Y(e, t, {
|
|
375
|
+
M: l.M ?? 16,
|
|
376
|
+
efConstruction: l.efConstruction ?? 200,
|
|
377
|
+
efSearch: l.efSearch ?? 50
|
|
377
378
|
});
|
|
378
379
|
console.timeEnd("knn"), console.time("fuzzy-set");
|
|
379
|
-
const i = D(a, o,
|
|
380
|
+
const i = D(a, o, t);
|
|
380
381
|
console.timeEnd("fuzzy-set");
|
|
381
|
-
const { a: h, b:
|
|
382
|
+
const { a: h, b: d } = K(p, f), u = I(i.vals, c), g = e.length, m = new Float32Array(g * s);
|
|
382
383
|
for (let b = 0; b < m.length; b++)
|
|
383
384
|
m[b] = Math.random() * 20 - 10;
|
|
384
385
|
console.time("sgd");
|
|
385
386
|
let w;
|
|
386
|
-
if (
|
|
387
|
+
if (j())
|
|
387
388
|
try {
|
|
388
389
|
const b = new W();
|
|
389
390
|
await b.init(), w = await b.optimize(
|
|
@@ -394,34 +395,34 @@ async function ae(e, t = {}, r) {
|
|
|
394
395
|
g,
|
|
395
396
|
s,
|
|
396
397
|
c,
|
|
397
|
-
{ a: h, b:
|
|
398
|
+
{ a: h, b: d, gamma: 1, negativeSampleRate: 5 },
|
|
398
399
|
r
|
|
399
400
|
);
|
|
400
401
|
} catch (b) {
|
|
401
|
-
console.warn("WebGPU SGD failed, falling back to CPU:", b), w = z(m, i, u, g, s, c, { a: h, b:
|
|
402
|
+
console.warn("WebGPU SGD failed, falling back to CPU:", b), w = z(m, i, u, g, s, c, { a: h, b: d }, r);
|
|
402
403
|
}
|
|
403
404
|
else
|
|
404
|
-
w = z(m, i, u, g, s, c, { a: h, b:
|
|
405
|
+
w = z(m, i, u, g, s, c, { a: h, b: d }, r);
|
|
405
406
|
return console.timeEnd("sgd"), w;
|
|
406
407
|
}
|
|
407
|
-
function
|
|
408
|
-
if (Math.abs(
|
|
408
|
+
function K(e, n) {
|
|
409
|
+
if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
|
|
409
410
|
return { a: 1.9292, b: 0.7915 };
|
|
410
|
-
if (Math.abs(
|
|
411
|
+
if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
|
|
411
412
|
return { a: 1.8956, b: 0.8006 };
|
|
412
|
-
if (Math.abs(
|
|
413
|
+
if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
|
|
413
414
|
return { a: 1.5769, b: 0.8951 };
|
|
414
|
-
const r = ee(e,
|
|
415
|
-
return { a:
|
|
415
|
+
const r = ee(e, n);
|
|
416
|
+
return { a: ne(e, n, r), b: r };
|
|
416
417
|
}
|
|
417
|
-
function ee(e,
|
|
418
|
-
return 1 / (
|
|
418
|
+
function ee(e, n) {
|
|
419
|
+
return 1 / (n * 1.2);
|
|
419
420
|
}
|
|
420
|
-
function
|
|
421
|
+
function ne(e, n, r) {
|
|
421
422
|
return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * r);
|
|
422
423
|
}
|
|
423
424
|
class re {
|
|
424
|
-
constructor(
|
|
425
|
+
constructor(n = {}) {
|
|
425
426
|
G(this, "_nComponents");
|
|
426
427
|
G(this, "_nNeighbors");
|
|
427
428
|
G(this, "_minDist");
|
|
@@ -434,8 +435,8 @@ class re {
|
|
|
434
435
|
G(this, "embedding", null);
|
|
435
436
|
G(this, "_hnswIndex", null);
|
|
436
437
|
G(this, "_nTrain", 0);
|
|
437
|
-
this._nComponents =
|
|
438
|
-
const { a: r, b: s } =
|
|
438
|
+
this._nComponents = n.nComponents ?? 2, this._nNeighbors = n.nNeighbors ?? 15, this._minDist = n.minDist ?? 0.1, this._spread = n.spread ?? 1, this._nEpochs = n.nEpochs, this._hnswOpts = n.hnsw ?? {};
|
|
439
|
+
const { a: r, b: s } = K(this._minDist, this._spread);
|
|
439
440
|
this._a = r, this._b = s;
|
|
440
441
|
}
|
|
441
442
|
/**
|
|
@@ -444,42 +445,42 @@ class re {
|
|
|
444
445
|
* index so that transform() can project new points later.
|
|
445
446
|
* Returns `this` for chaining.
|
|
446
447
|
*/
|
|
447
|
-
async fit(
|
|
448
|
-
const s =
|
|
448
|
+
async fit(n, r) {
|
|
449
|
+
const s = n.length, t = this._nEpochs ?? (s > 1e4 ? 200 : 500), { M: p = 16, efConstruction: f = 200, efSearch: l = 50 } = this._hnswOpts;
|
|
449
450
|
console.time("knn");
|
|
450
|
-
const { knn: c, index: a } = await Q(
|
|
451
|
+
const { knn: c, index: a } = await Q(n, this._nNeighbors, {
|
|
451
452
|
M: p,
|
|
452
453
|
efConstruction: f,
|
|
453
|
-
efSearch:
|
|
454
|
+
efSearch: l
|
|
454
455
|
});
|
|
455
456
|
this._hnswIndex = a, this._nTrain = s, console.timeEnd("knn"), console.time("fuzzy-set");
|
|
456
457
|
const o = D(c.indices, c.distances, this._nNeighbors);
|
|
457
458
|
console.timeEnd("fuzzy-set");
|
|
458
|
-
const i = I(o.vals,
|
|
459
|
-
for (let
|
|
460
|
-
h[
|
|
461
|
-
if (console.time("sgd"),
|
|
459
|
+
const i = I(o.vals, t), h = new Float32Array(s * this._nComponents);
|
|
460
|
+
for (let d = 0; d < h.length; d++)
|
|
461
|
+
h[d] = Math.random() * 20 - 10;
|
|
462
|
+
if (console.time("sgd"), j())
|
|
462
463
|
try {
|
|
463
|
-
const
|
|
464
|
-
await
|
|
464
|
+
const d = new W();
|
|
465
|
+
await d.init(), this.embedding = await d.optimize(
|
|
465
466
|
h,
|
|
466
467
|
new Uint32Array(o.rows),
|
|
467
468
|
new Uint32Array(o.cols),
|
|
468
469
|
i,
|
|
469
470
|
s,
|
|
470
471
|
this._nComponents,
|
|
471
|
-
|
|
472
|
+
t,
|
|
472
473
|
{ a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 },
|
|
473
474
|
r
|
|
474
475
|
);
|
|
475
|
-
} catch (
|
|
476
|
-
console.warn("WebGPU SGD failed, falling back to CPU:",
|
|
476
|
+
} catch (d) {
|
|
477
|
+
console.warn("WebGPU SGD failed, falling back to CPU:", d), this.embedding = z(h, o, i, s, this._nComponents, t, {
|
|
477
478
|
a: this._a,
|
|
478
479
|
b: this._b
|
|
479
480
|
}, r);
|
|
480
481
|
}
|
|
481
482
|
else
|
|
482
|
-
this.embedding = z(h, o, i, s, this._nComponents,
|
|
483
|
+
this.embedding = z(h, o, i, s, this._nComponents, t, {
|
|
483
484
|
a: this._a,
|
|
484
485
|
b: this._b
|
|
485
486
|
}, r);
|
|
@@ -496,12 +497,12 @@ class re {
|
|
|
496
497
|
* returned embedding to [0, 1]. The stored training embedding is never
|
|
497
498
|
* mutated. Defaults to `false`.
|
|
498
499
|
*/
|
|
499
|
-
async transform(
|
|
500
|
+
async transform(n, r = !1) {
|
|
500
501
|
if (!this._hnswIndex || !this.embedding)
|
|
501
502
|
throw new Error("UMAP.transform() must be called after fit()");
|
|
502
|
-
const s =
|
|
503
|
+
const s = n.length, t = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), p = Math.max(100, Math.floor(t / 4)), f = this._hnswIndex.searchKnn(n, this._nNeighbors), l = J(f.indices, f.distances, this._nNeighbors), c = new Uint32Array(l.rows), a = new Uint32Array(l.cols), o = new Float32Array(s), i = new Float32Array(s * this._nComponents);
|
|
503
504
|
for (let u = 0; u < c.length; u++) {
|
|
504
|
-
const g = c[u], m = a[u], w =
|
|
505
|
+
const g = c[u], m = a[u], w = l.vals[u];
|
|
505
506
|
o[g] += w;
|
|
506
507
|
for (let b = 0; b < this._nComponents; b++)
|
|
507
508
|
i[g * this._nComponents + b] += w * this.embedding[m * this._nComponents + b];
|
|
@@ -513,10 +514,10 @@ class re {
|
|
|
513
514
|
else
|
|
514
515
|
for (let g = 0; g < this._nComponents; g++)
|
|
515
516
|
i[u * this._nComponents + g] = Math.random() * 20 - 10;
|
|
516
|
-
const h = I(
|
|
517
|
+
const h = I(l.vals, p), d = $(
|
|
517
518
|
i,
|
|
518
519
|
this.embedding,
|
|
519
|
-
|
|
520
|
+
l,
|
|
520
521
|
h,
|
|
521
522
|
s,
|
|
522
523
|
this._nTrain,
|
|
@@ -524,7 +525,7 @@ class re {
|
|
|
524
525
|
p,
|
|
525
526
|
{ a: this._a, b: this._b }
|
|
526
527
|
);
|
|
527
|
-
return r ? T(
|
|
528
|
+
return r ? T(d, s, this._nComponents) : d;
|
|
528
529
|
}
|
|
529
530
|
/**
|
|
530
531
|
* Convenience method equivalent to `fit(vectors)` followed by
|
|
@@ -535,37 +536,37 @@ class re {
|
|
|
535
536
|
* returned embedding to [0, 1]. `this.embedding` is never mutated.
|
|
536
537
|
* Defaults to `false`.
|
|
537
538
|
*/
|
|
538
|
-
async fit_transform(
|
|
539
|
-
return await this.fit(
|
|
539
|
+
async fit_transform(n, r, s = !1) {
|
|
540
|
+
return await this.fit(n, r), s ? T(this.embedding, n.length, this._nComponents) : this.embedding;
|
|
540
541
|
}
|
|
541
542
|
}
|
|
542
|
-
function T(e,
|
|
543
|
+
function T(e, n, r) {
|
|
543
544
|
const s = new Float32Array(e.length);
|
|
544
|
-
for (let
|
|
545
|
+
for (let t = 0; t < r; t++) {
|
|
545
546
|
let p = 1 / 0, f = -1 / 0;
|
|
546
|
-
for (let c = 0; c <
|
|
547
|
-
const a = e[c * r +
|
|
547
|
+
for (let c = 0; c < n; c++) {
|
|
548
|
+
const a = e[c * r + t];
|
|
548
549
|
a < p && (p = a), a > f && (f = a);
|
|
549
550
|
}
|
|
550
|
-
const
|
|
551
|
-
for (let c = 0; c <
|
|
552
|
-
s[c * r +
|
|
551
|
+
const l = f - p;
|
|
552
|
+
for (let c = 0; c < n; c++)
|
|
553
|
+
s[c * r + t] = l > 0 ? (e[c * r + t] - p) / l : 0;
|
|
553
554
|
}
|
|
554
555
|
return s;
|
|
555
556
|
}
|
|
556
|
-
function I(e,
|
|
557
|
+
function I(e, n) {
|
|
557
558
|
let r = -1 / 0;
|
|
558
|
-
for (let
|
|
559
|
-
e[
|
|
559
|
+
for (let t = 0; t < e.length; t++)
|
|
560
|
+
e[t] > r && (r = e[t]);
|
|
560
561
|
const s = new Float32Array(e.length);
|
|
561
|
-
for (let
|
|
562
|
-
const p = e[
|
|
563
|
-
s[
|
|
562
|
+
for (let t = 0; t < e.length; t++) {
|
|
563
|
+
const p = e[t] / r;
|
|
564
|
+
s[t] = p > 0 ? n / p : -1;
|
|
564
565
|
}
|
|
565
566
|
return s;
|
|
566
567
|
}
|
|
567
568
|
export {
|
|
568
569
|
re as UMAP,
|
|
569
570
|
ae as fit,
|
|
570
|
-
|
|
571
|
+
j as isWebGPUAvailable
|
|
571
572
|
};
|