umap-gpu 0.2.11 → 0.2.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  export interface FuzzyGraph {
2
- rows: Float32Array;
3
- cols: Float32Array;
2
+ rows: Uint32Array;
3
+ cols: Uint32Array;
4
4
  vals: Float32Array;
5
5
  nVertices: number;
6
6
  }
@@ -7,6 +7,23 @@
7
7
  */
8
8
  export declare function getGPUDevice(): Promise<GPUDevice | null>;
9
9
  /**
10
- * Check whether WebGPU is available in the current environment.
10
+ * Fast synchronous heuristic: returns `true` if `navigator.gpu` exists.
11
+ *
12
+ * **Caveat (Bug 13):** `navigator.gpu` being truthy does NOT guarantee that a
13
+ * WebGPU adapter can be acquired — `requestAdapter()` may still return `null`
14
+ * (no compatible hardware, or the browser has disabled WebGPU for the page).
15
+ * Use `checkWebGPUAvailable()` for a reliable async check, or rely on the
16
+ * `try/catch` around `GPUSgd.init()` in the calling code.
11
17
  */
12
18
  export declare function isWebGPUAvailable(): boolean;
19
+ /**
20
+ * Reliably check whether WebGPU is usable in the current environment by
21
+ * attempting to acquire an adapter via `getGPUDevice()`.
22
+ *
23
+ * Unlike the synchronous `isWebGPUAvailable()`, this actually calls
24
+ * `navigator.gpu.requestAdapter()` and returns `false` if the adapter is
25
+ * unavailable (no compatible GPU, browser policy, etc.).
26
+ *
27
+ * The result is automatically cached — repeated calls are free.
28
+ */
29
+ export declare function checkWebGPUAvailable(): Promise<boolean>;
package/dist/gpu/sgd.d.ts CHANGED
@@ -6,11 +6,23 @@ export interface SGDParams {
6
6
  }
7
7
  /**
8
8
  * GPU-accelerated SGD optimizer for UMAP embedding.
9
- * Each GPU thread processes one graph edge, applying attraction and repulsion forces.
9
+ *
10
+ * Uses a two-pass design per epoch to eliminate write-write races on shared
11
+ * vertex positions (Bug 2 fix):
12
+ * Pass 1 (sgd.wgsl): Each thread accumulates its attraction and
13
+ * repulsion gradients into an atomic<i32>
14
+ * forces buffer — no direct embedding writes.
15
+ * Pass 2 (apply-forces.wgsl): Each thread applies one element's accumulated
16
+ * force to the embedding and resets the
17
+ * accumulator to zero for the next epoch.
18
+ *
19
+ * Both passes are submitted in the same command encoder so WebGPU guarantees
20
+ * sequential execution and storage-buffer visibility between them.
10
21
  */
11
22
  export declare class GPUSgd {
12
23
  private device;
13
- private pipeline;
24
+ private sgdPipeline;
25
+ private applyForcesPipeline;
14
26
  init(): Promise<void>;
15
27
  /**
16
28
  * Run SGD optimization on the GPU.
package/dist/index.d.ts CHANGED
@@ -2,4 +2,4 @@ export { fit, UMAP } from './umap';
2
2
  export type { UMAPOptions, ProgressCallback } from './umap';
3
3
  export type { KNNResult, HNSWOptions, HNSWSearchableIndex } from './hnsw-knn';
4
4
  export type { FuzzyGraph } from './fuzzy-set';
5
- export { isWebGPUAvailable } from './gpu/device';
5
+ export { isWebGPUAvailable, checkWebGPUAvailable } from './gpu/device';
package/dist/index.js CHANGED
@@ -1,104 +1,111 @@
1
- var H = Object.defineProperty;
2
- var V = (e, n, r) => n in e ? H(e, n, { enumerable: !0, configurable: !0, writable: !0, value: r }) : e[n] = r;
3
- var G = (e, n, r) => V(e, typeof n != "symbol" ? n + "" : n, r);
4
- import { loadHnswlib as C } from "hnswlib-wasm";
5
- async function Y(e, n, r = {}) {
6
- const { M: s = 16, efConstruction: t = 200, efSearch: p = 50 } = r, f = await C(), l = e[0].length, c = e.length, a = new f.HierarchicalNSW("l2", l, "");
7
- a.initIndex(c, s, t, 200), a.setEfSearch(Math.max(p, n)), a.addItems(e, !1);
8
- const o = [], i = [];
9
- for (let h = 0; h < c; h++) {
10
- const d = a.searchKnn(e[h], n + 1, void 0), u = d.neighbors.map((g, m) => ({ idx: g, dist: d.distances[m] })).filter(({ idx: g }) => g !== h).slice(0, n);
11
- o.push(u.map(({ idx: g }) => g)), i.push(u.map(({ dist: g }) => g));
1
+ var te = Object.defineProperty;
2
+ var ne = (e, t, f) => t in e ? te(e, t, { enumerable: !0, configurable: !0, writable: !0, value: f }) : e[t] = f;
3
+ var C = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
4
+ import { loadHnswlib as H } from "hnswlib-wasm";
5
+ async function se(e, t, f = {}) {
6
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, i = e.length, n = new c.HierarchicalNSW("l2", d, "");
7
+ n.initIndex(i, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
8
+ const r = [], o = [];
9
+ for (let h = 0; h < i; h++) {
10
+ const l = n.searchKnn(e[h], t + 1, void 0), p = l.neighbors.map((g, _) => ({ idx: g, dist: l.distances[_] })).filter(({ idx: g }) => g !== h).slice(0, t);
11
+ r.push(p.map(({ idx: g }) => g)), o.push(p.map(({ dist: g }) => Math.sqrt(g)));
12
12
  }
13
- return { indices: o, distances: i };
13
+ return { indices: r, distances: o };
14
14
  }
15
- async function Q(e, n, r = {}) {
16
- const { M: s = 16, efConstruction: t = 200, efSearch: p = 50 } = r, f = await C(), l = e[0].length, c = e.length, a = new f.HierarchicalNSW("l2", l, "");
17
- a.initIndex(c, s, t, 200), a.setEfSearch(Math.max(p, n)), a.addItems(e, !1);
18
- const o = [], i = [];
19
- for (let d = 0; d < c; d++) {
20
- const u = a.searchKnn(e[d], n + 1, void 0), g = u.neighbors.map((m, w) => ({ idx: m, dist: u.distances[w] })).filter(({ idx: m }) => m !== d).slice(0, n);
21
- o.push(g.map(({ idx: m }) => m)), i.push(g.map(({ dist: m }) => m));
15
+ async function ae(e, t, f = {}) {
16
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, i = e.length, n = new c.HierarchicalNSW("l2", d, "");
17
+ n.initIndex(i, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
18
+ const r = [], o = [];
19
+ for (let l = 0; l < i; l++) {
20
+ const p = n.searchKnn(e[l], t + 1, void 0), g = p.neighbors.map((_, m) => ({ idx: _, dist: p.distances[m] })).filter(({ idx: _ }) => _ !== l).slice(0, t);
21
+ r.push(g.map(({ idx: _ }) => _)), o.push(g.map(({ dist: _ }) => Math.sqrt(_)));
22
22
  }
23
- return { knn: { indices: o, distances: i }, index: {
24
- searchKnn(d, u) {
25
- const g = [], m = [];
26
- for (const w of d) {
27
- const b = a.searchKnn(w, u, void 0), y = b.neighbors.map((v, x) => ({ idx: v, dist: b.distances[x] })).sort((v, x) => v.dist - x.dist).slice(0, u);
28
- g.push(y.map(({ idx: v }) => v)), m.push(y.map(({ dist: v }) => v));
23
+ return { knn: { indices: r, distances: o }, index: {
24
+ searchKnn(l, p) {
25
+ const g = [], _ = [];
26
+ for (const m of l) {
27
+ const y = n.searchKnn(m, p, void 0), b = y.neighbors.map((x, v) => ({ idx: x, dist: y.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
28
+ g.push(b.map(({ idx: x }) => x)), _.push(b.map(({ dist: x }) => Math.sqrt(x)));
29
29
  }
30
- return { indices: g, distances: m };
30
+ return { indices: g, distances: _ };
31
31
  }
32
32
  } };
33
33
  }
34
- function D(e, n, r, s = 1) {
35
- const t = e.length, { sigmas: p, rhos: f } = L(n, r), l = [], c = [], a = [];
36
- for (let i = 0; i < t; i++)
37
- for (let h = 0; h < e[i].length; h++) {
38
- const d = n[i][h], u = d <= f[i] ? 1 : Math.exp(-((d - f[i]) / p[i]));
39
- l.push(i), c.push(e[i][h]), a.push(u);
34
+ function V(e, t, f, a = 1) {
35
+ const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], i = [], n = [];
36
+ for (let o = 0; o < s; o++)
37
+ for (let h = 0; h < e[o].length; h++) {
38
+ const l = t[o][h], p = l <= c[o] ? 1 : Math.exp(-((l - c[o]) / u[o]));
39
+ d.push(o), i.push(e[o][h]), n.push(p);
40
40
  }
41
- return { ...X(l, c, a, t, s), nVertices: t };
41
+ return { ...oe(d, i, n, s, a), nVertices: s };
42
42
  }
43
- function J(e, n, r) {
44
- const s = e.length, { sigmas: t, rhos: p } = L(n, r), f = [], l = [], c = [];
45
- for (let a = 0; a < s; a++)
46
- for (let o = 0; o < e[a].length; o++) {
47
- const i = n[a][o], h = i <= p[a] ? 1 : Math.exp(-((i - p[a]) / t[a]));
48
- f.push(a), l.push(e[a][o]), c.push(h);
43
+ function re(e, t, f) {
44
+ const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], i = [];
45
+ for (let n = 0; n < a; n++)
46
+ for (let r = 0; r < e[n].length; r++) {
47
+ const o = t[n][r], h = o <= u[n] ? 1 : Math.exp(-((o - u[n]) / s[n]));
48
+ c.push(n), d.push(e[n][r]), i.push(h);
49
49
  }
50
50
  return {
51
- rows: new Float32Array(f),
52
- cols: new Float32Array(l),
53
- vals: new Float32Array(c),
54
- nVertices: s
51
+ rows: new Uint32Array(c),
52
+ cols: new Uint32Array(d),
53
+ vals: new Float32Array(i),
54
+ nVertices: a
55
55
  };
56
56
  }
57
- function L(e, n) {
58
- const s = e.length, t = new Float32Array(s), p = new Float32Array(s);
59
- for (let f = 0; f < s; f++) {
60
- const l = e[f];
61
- p[f] = l.find((h) => h > 0) ?? 0;
62
- let c = 0, a = 1 / 0, o = 1;
63
- const i = Math.log2(n);
57
+ function Y(e, t) {
58
+ const a = e.length, s = new Float32Array(a), u = new Float32Array(a);
59
+ for (let c = 0; c < a; c++) {
60
+ const d = e[c];
61
+ u[c] = d.find((h) => h > 0) ?? 0;
62
+ let i = 0, n = 1 / 0, r = 1;
63
+ const o = Math.log2(t);
64
64
  for (let h = 0; h < 64; h++) {
65
- let d = 0;
66
- for (let u = 0; u < l.length; u++)
67
- d += Math.exp(-Math.max(0, l[u] - p[f]) / o);
68
- if (Math.abs(d - i) < 1e-5) break;
69
- d > i ? (a = o, o = (c + a) / 2) : (c = o, o = a === 1 / 0 ? o * 2 : (c + a) / 2);
65
+ let l = 0;
66
+ for (let p = 0; p < d.length; p++)
67
+ l += Math.exp(-Math.max(0, d[p] - u[c]) / r);
68
+ if (Math.abs(l - o) < 1e-5) break;
69
+ l > o ? (n = r, r = (i + n) / 2) : (i = r, r = n === 1 / 0 ? r * 2 : (i + n) / 2);
70
70
  }
71
- t[f] = o;
71
+ s[c] = r;
72
72
  }
73
- return { sigmas: t, rhos: p };
73
+ return { sigmas: s, rhos: u };
74
74
  }
75
- function X(e, n, r, s, t) {
76
- const p = /* @__PURE__ */ new Map();
77
- for (let a = 0; a < e.length; a++)
78
- p.set(e[a] * s + n[a], r[a]);
79
- const f = [], l = [], c = [];
80
- for (const [a, o] of p) {
81
- const i = Math.floor(a / s), h = a % s, d = p.get(h * s + i) ?? 0, u = o + d - o * d, g = o * d;
82
- f.push(i), l.push(h), c.push(t * u + (1 - t) * g);
75
+ function oe(e, t, f, a, s) {
76
+ const u = /* @__PURE__ */ new Map();
77
+ for (let n = 0; n < e.length; n++)
78
+ u.set(e[n] * a + t[n], f[n]);
79
+ const c = [], d = [], i = [];
80
+ for (const [n, r] of u) {
81
+ const o = Math.floor(n / a), h = n % a, l = u.get(h * a + o) ?? 0, p = r + l - r * l, g = r * l;
82
+ c.push(o), d.push(h), i.push(s * p + (1 - s) * g);
83
83
  }
84
84
  return {
85
- rows: new Float32Array(f),
86
- cols: new Float32Array(l),
87
- vals: new Float32Array(c)
85
+ rows: new Uint32Array(c),
86
+ cols: new Uint32Array(d),
87
+ vals: new Float32Array(i)
88
88
  };
89
89
  }
90
- const Z = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
- // Applies attraction forces between connected nodes and repulsion forces
92
- // against negative samples.
90
+ const ie = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
+ // Computes attraction and repulsion forces and accumulates them atomically
92
+ // into a forces buffer. A separate apply-forces pass then updates embeddings,
93
+ // eliminating write-write races on shared vertex positions.
93
94
 
94
95
  @group(0) @binding(0) var<storage, read> epochs_per_sample : array<f32>;
95
96
  @group(0) @binding(1) var<storage, read> head : array<u32>; // edge source
96
97
  @group(0) @binding(2) var<storage, read> tail : array<u32>; // edge target
97
- @group(0) @binding(3) var<storage, read_write> embedding : array<f32>; // [n * nComponents]
98
+ @group(0) @binding(3) var<storage, read> embedding : array<f32>; // [n * nComponents], read-only
98
99
  @group(0) @binding(4) var<storage, read_write> epoch_of_next_sample : array<f32>;
99
100
  @group(0) @binding(5) var<storage, read_write> epoch_of_next_negative_sample : array<f32>;
100
101
  @group(0) @binding(6) var<uniform> params : Params;
101
102
  @group(0) @binding(7) var<storage, read> rng_seeds : array<u32>; // per-edge seed
103
+ @group(0) @binding(8) var<storage, read_write> forces : array<atomic<i32>>; // [n * nComponents]
104
+
105
+ // Scale factor for quantizing f32 gradients into i32 for atomic accumulation.
106
+ // Gradients are clipped to [-4, 4]. With up to ~1000 edges sharing a vertex
107
+ // the max accumulated magnitude is ~4000, well within i32 range at this scale.
108
+ const FORCE_SCALE : f32 = 65536.0;
102
109
 
103
110
  struct Params {
104
111
  n_edges : u32,
@@ -106,7 +113,7 @@ struct Params {
106
113
  n_components : u32,
107
114
  current_epoch : u32,
108
115
  n_epochs : u32,
109
- alpha : f32, // learning rate
116
+ alpha : f32, // learning rate (applied by apply-forces pass)
110
117
  a : f32,
111
118
  b : f32,
112
119
  gamma : f32, // repulsion strength
@@ -147,7 +154,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
147
154
 
148
155
  let pow_b = pow(dist_sq, params.b);
149
156
  // Guard dist_sq == 0: b-1 is negative so pow(0, b-1) = +Inf.
150
- // Mirror CPU: use pow_b / dist_sq only when dist_sq > 0, else 0.
151
157
  let grad_coeff_attr = select(
152
158
  -2.0 * params.a * params.b * (pow_b / dist_sq) / (params.a * pow_b + 1.0),
153
159
  0.0,
@@ -157,19 +163,24 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
157
163
  for (var d = 0u; d < nc; d++) {
158
164
  let diff = embedding[i * nc + d] - embedding[j * nc + d];
159
165
  let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
160
- embedding[i * nc + d] += params.alpha * grad;
161
- embedding[j * nc + d] -= params.alpha * grad;
166
+ // Accumulate atomically to avoid write-write races across threads.
167
+ atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
168
+ atomicAdd(&forces[j * nc + d], -i32(grad * FORCE_SCALE));
162
169
  }
163
170
 
164
171
  epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
165
172
 
166
173
  // --- Repulsion (negative samples) ---
167
- let eps = epochs_per_sample[edge_idx];
168
- let neg_eps = epoch_of_next_negative_sample[edge_idx];
174
+ // Compute how many negative samples are overdue relative to current epoch,
175
+ // matching the Python reference: n_neg = floor((n - next_neg) / eps_per_neg).
176
+ let epoch_f = f32(params.current_epoch);
177
+ let epochs_per_neg = epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
169
178
  var n_neg = 0u;
170
- if (neg_eps > 0.0) {
171
- n_neg = u32(eps / neg_eps);
179
+ if (epochs_per_neg > 0.0 && epoch_f >= epoch_of_next_negative_sample[edge_idx]) {
180
+ n_neg = u32((epoch_f - epoch_of_next_negative_sample[edge_idx]) / epochs_per_neg);
181
+ epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * epochs_per_neg;
172
182
  }
183
+
173
184
  var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 6364136223u);
174
185
 
175
186
  for (var s = 0u; s < n_neg; s++) {
@@ -189,26 +200,73 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
189
200
  for (var d = 0u; d < nc; d++) {
190
201
  let diff = embedding[i * nc + d] - embedding[k * nc + d];
191
202
  let grad = clip(grad_coeff_rep * diff, -4.0, 4.0);
192
- embedding[i * nc + d] += params.alpha * grad;
203
+ atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
193
204
  }
194
205
  }
206
+ }
207
+ `, ce = `// Apply-forces shader — second pass of the two-pass GPU SGD.
208
+ //
209
+ // After the SGD pass has atomically accumulated all gradients into the forces
210
+ // buffer, this shader applies each element's accumulated force to the
211
+ // embedding and resets the accumulator to zero for the next epoch.
195
212
 
196
- epoch_of_next_negative_sample[edge_idx] +=
197
- epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
213
+ @group(0) @binding(0) var<storage, read_write> embedding : array<f32>;
214
+ @group(0) @binding(1) var<storage, read_write> forces : array<atomic<i32>>;
215
+ @group(0) @binding(2) var<uniform> params : ApplyParams;
216
+
217
+ struct ApplyParams {
218
+ n_elements : u32, // nVertices * nComponents
219
+ alpha : f32, // current learning rate
220
+ }
221
+
222
+ // Must match FORCE_SCALE in sgd.wgsl
223
+ const FORCE_SCALE : f32 = 65536.0;
224
+
225
+ @compute @workgroup_size(256)
226
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
227
+ let idx = gid.x;
228
+ if (idx >= params.n_elements) { return; }
229
+
230
+ // atomicExchange atomically reads the accumulated force and resets it to 0.
231
+ let raw = atomicExchange(&forces[idx], 0);
232
+ embedding[idx] += params.alpha * f32(raw) / FORCE_SCALE;
198
233
  }
199
234
  `;
200
- class W {
235
+ let z = null;
236
+ async function Q() {
237
+ if (z) return z;
238
+ if (typeof navigator > "u" || !navigator.gpu)
239
+ return null;
240
+ const e = await navigator.gpu.requestAdapter();
241
+ return e ? (z = await e.requestDevice(), z.lost.then(() => {
242
+ z = null;
243
+ }), z) : null;
244
+ }
245
+ function J() {
246
+ return typeof navigator < "u" && !!navigator.gpu;
247
+ }
248
+ async function he() {
249
+ return await Q() !== null;
250
+ }
251
+ class X {
201
252
  constructor() {
202
- G(this, "device");
203
- G(this, "pipeline");
253
+ C(this, "device");
254
+ C(this, "sgdPipeline");
255
+ C(this, "applyForcesPipeline");
204
256
  }
205
257
  async init() {
206
- const n = await navigator.gpu.requestAdapter();
207
- if (!n) throw new Error("WebGPU not supported");
208
- this.device = await n.requestDevice(), this.pipeline = this.device.createComputePipeline({
258
+ const t = await Q();
259
+ if (!t) throw new Error("WebGPU not supported");
260
+ this.device = t, this.sgdPipeline = this.device.createComputePipeline({
261
+ layout: "auto",
262
+ compute: {
263
+ module: this.device.createShaderModule({ code: ie }),
264
+ entryPoint: "main"
265
+ }
266
+ }), this.applyForcesPipeline = this.device.createComputePipeline({
209
267
  layout: "auto",
210
268
  compute: {
211
- module: this.device.createShaderModule({ code: Z }),
269
+ module: this.device.createShaderModule({ code: ce }),
212
270
  entryPoint: "main"
213
271
  }
214
272
  });
@@ -226,218 +284,254 @@ class W {
226
284
  * @param params - UMAP curve parameters and repulsion settings
227
285
  * @returns Optimized embedding as Float32Array
228
286
  */
229
- async optimize(n, r, s, t, p, f, l, c, a) {
230
- const { device: o } = this, i = r.length, h = this.makeBuffer(
231
- n,
287
+ async optimize(t, f, a, s, u, c, d, i, n) {
288
+ const { device: r } = this, o = f.length, h = u * c, l = this.makeBuffer(
289
+ t,
232
290
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
233
- ), d = this.makeBuffer(r, GPUBufferUsage.STORAGE), u = this.makeBuffer(s, GPUBufferUsage.STORAGE), g = this.makeBuffer(t, GPUBufferUsage.STORAGE), m = new Float32Array(i).fill(0), w = this.makeBuffer(m, GPUBufferUsage.STORAGE), b = new Float32Array(i);
234
- for (let _ = 0; _ < i; _++)
235
- b[_] = t[_] / c.negativeSampleRate;
236
- const y = this.makeBuffer(b, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
237
- for (let _ = 0; _ < i; _++)
238
- v[_] = Math.random() * 4294967295 | 0;
239
- const x = this.makeBuffer(v, GPUBufferUsage.STORAGE), B = o.createBuffer({
291
+ ), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), y = this.makeBuffer(m, GPUBufferUsage.STORAGE), b = new Float32Array(o);
292
+ for (let A = 0; A < o; A++)
293
+ b[A] = s[A] / i.negativeSampleRate;
294
+ const x = this.makeBuffer(b, GPUBufferUsage.STORAGE), v = new Uint32Array(o);
295
+ for (let A = 0; A < o; A++)
296
+ v[A] = Math.random() * 4294967295 | 0;
297
+ const P = this.makeBuffer(v, GPUBufferUsage.STORAGE), G = r.createBuffer({
298
+ size: h * 4,
299
+ usage: GPUBufferUsage.STORAGE,
300
+ mappedAtCreation: !0
301
+ });
302
+ new Int32Array(G.getMappedRange()).fill(0), G.unmap();
303
+ const U = r.createBuffer({
240
304
  size: 40,
241
305
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
306
+ }), F = r.createBuffer({
307
+ size: 16,
308
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
309
+ }), N = r.createBindGroup({
310
+ layout: this.sgdPipeline.getBindGroupLayout(0),
311
+ entries: [
312
+ { binding: 0, resource: { buffer: _ } },
313
+ { binding: 1, resource: { buffer: p } },
314
+ { binding: 2, resource: { buffer: g } },
315
+ { binding: 3, resource: { buffer: l } },
316
+ { binding: 4, resource: { buffer: y } },
317
+ { binding: 5, resource: { buffer: x } },
318
+ { binding: 6, resource: { buffer: U } },
319
+ { binding: 7, resource: { buffer: P } },
320
+ { binding: 8, resource: { buffer: G } }
321
+ ]
322
+ }), M = r.createBindGroup({
323
+ layout: this.applyForcesPipeline.getBindGroupLayout(0),
324
+ entries: [
325
+ { binding: 0, resource: { buffer: l } },
326
+ { binding: 1, resource: { buffer: G } },
327
+ { binding: 2, resource: { buffer: F } }
328
+ ]
242
329
  });
243
- for (let _ = 0; _ < l; _++) {
244
- const F = 1 - _ / l, A = new ArrayBuffer(40), M = new Uint32Array(A), N = new Float32Array(A);
245
- M[0] = i, M[1] = p, M[2] = f, M[3] = _, M[4] = l, N[5] = F, N[6] = c.a, N[7] = c.b, N[8] = c.gamma, M[9] = c.negativeSampleRate, o.queue.writeBuffer(B, 0, A);
246
- const S = o.createBindGroup({
247
- layout: this.pipeline.getBindGroupLayout(0),
248
- entries: [
249
- { binding: 0, resource: { buffer: g } },
250
- { binding: 1, resource: { buffer: d } },
251
- { binding: 2, resource: { buffer: u } },
252
- { binding: 3, resource: { buffer: h } },
253
- { binding: 4, resource: { buffer: w } },
254
- { binding: 5, resource: { buffer: y } },
255
- { binding: 6, resource: { buffer: B } },
256
- { binding: 7, resource: { buffer: x } }
257
- ]
258
- }), O = o.createCommandEncoder(), U = O.beginComputePass();
259
- U.setPipeline(this.pipeline), U.setBindGroup(0, S), U.dispatchWorkgroups(Math.ceil(i / 256)), U.end(), o.queue.submit([O.finish()]), _ % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(_, l));
330
+ for (let A = 0; A < d; A++) {
331
+ const B = 1 - A / d, O = new ArrayBuffer(40), S = new Uint32Array(O), k = new Float32Array(O);
332
+ S[0] = o, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = i.a, k[7] = i.b, k[8] = i.gamma, S[9] = i.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
333
+ const D = new ArrayBuffer(16), $ = new Uint32Array(D), ee = new Float32Array(D);
334
+ $[0] = h, ee[1] = B, r.queue.writeBuffer(F, 0, D);
335
+ const T = r.createCommandEncoder(), q = T.beginComputePass();
336
+ q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(o / 256)), q.end();
337
+ const I = T.beginComputePass();
338
+ I.setPipeline(this.applyForcesPipeline), I.setBindGroup(0, M), I.dispatchWorkgroups(Math.ceil(h / 256)), I.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
260
339
  }
261
- const E = o.createBuffer({
262
- size: n.byteLength,
340
+ const E = r.createBuffer({
341
+ size: t.byteLength,
263
342
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
264
- }), R = o.createCommandEncoder();
265
- R.copyBufferToBuffer(h, 0, E, 0, n.byteLength), o.queue.submit([R.finish()]), await E.mapAsync(GPUMapMode.READ);
266
- const k = new Float32Array(E.getMappedRange().slice(0));
267
- return E.unmap(), h.destroy(), d.destroy(), u.destroy(), g.destroy(), w.destroy(), y.destroy(), x.destroy(), B.destroy(), E.destroy(), k;
343
+ }), w = r.createCommandEncoder();
344
+ w.copyBufferToBuffer(l, 0, E, 0, t.byteLength), r.queue.submit([w.finish()]), await E.mapAsync(GPUMapMode.READ);
345
+ const R = new Float32Array(E.getMappedRange().slice(0));
346
+ return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), y.destroy(), x.destroy(), P.destroy(), G.destroy(), U.destroy(), F.destroy(), E.destroy(), R;
268
347
  }
269
- makeBuffer(n, r) {
270
- const s = this.device.createBuffer({
271
- size: n.byteLength,
272
- usage: r,
348
+ makeBuffer(t, f) {
349
+ const a = this.device.createBuffer({
350
+ size: t.byteLength,
351
+ usage: f,
273
352
  mappedAtCreation: !0
274
353
  });
275
- return n instanceof Float32Array ? new Float32Array(s.getMappedRange()).set(n) : new Uint32Array(s.getMappedRange()).set(n), s.unmap(), s;
354
+ return t instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(t) : new Uint32Array(a.getMappedRange()).set(t), a.unmap(), a;
276
355
  }
277
356
  }
278
- function q(e) {
357
+ function j(e) {
279
358
  return Math.max(-4, Math.min(4, e));
280
359
  }
281
- function z(e, n, r, s, t, p, f, l) {
282
- const { a: c, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = f, h = n.rows.length, d = new Uint32Array(n.rows), u = new Uint32Array(n.cols), g = new Float32Array(h).fill(0), m = new Float32Array(h);
283
- for (let w = 0; w < h; w++)
284
- m[w] = r[w] / i;
285
- for (let w = 0; w < p; w++) {
286
- l == null || l(w, p);
287
- const b = 1 - w / p;
288
- for (let y = 0; y < h; y++) {
289
- if (g[y] > w) continue;
290
- const v = d[y], x = u[y];
291
- let B = 0;
292
- for (let _ = 0; _ < t; _++) {
293
- const F = e[v * t + _] - e[x * t + _];
294
- B += F * F;
360
+ function L(e, t, f, a, s, u, c, d) {
361
+ const { a: i, b: n, gamma: r = 1, negativeSampleRate: o = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
362
+ for (let m = 0; m < h; m++)
363
+ _[m] = f[m] / o;
364
+ for (let m = 0; m < u; m++) {
365
+ d == null || d(m, u);
366
+ const y = 1 - m / u;
367
+ for (let b = 0; b < h; b++) {
368
+ if (g[b] > m) continue;
369
+ const x = l[b], v = p[b];
370
+ let P = 0;
371
+ for (let M = 0; M < s; M++) {
372
+ const E = e[x * s + M] - e[v * s + M];
373
+ P += E * E;
295
374
  }
296
- const E = Math.pow(B, a), R = -2 * c * a * (B > 0 ? E / B : 0) / (c * E + 1);
297
- for (let _ = 0; _ < t; _++) {
298
- const F = e[v * t + _] - e[x * t + _], A = q(R * F);
299
- e[v * t + _] += b * A, e[x * t + _] -= b * A;
375
+ const G = Math.pow(P, n), U = -2 * i * n * (P > 0 ? G / P : 0) / (i * G + 1);
376
+ for (let M = 0; M < s; M++) {
377
+ const E = e[x * s + M] - e[v * s + M], w = j(U * E);
378
+ e[x * s + M] += y * w, e[v * s + M] -= y * w;
300
379
  }
301
- g[y] += r[y];
302
- const k = m[y] > 0 ? Math.floor(r[y] / m[y]) : 0;
303
- for (let _ = 0; _ < k; _++) {
304
- const F = Math.floor(Math.random() * s);
305
- if (F === v) continue;
306
- let A = 0;
307
- for (let S = 0; S < t; S++) {
308
- const O = e[v * t + S] - e[F * t + S];
309
- A += O * O;
380
+ g[b] += f[b];
381
+ const F = f[b] / o, N = Math.max(0, Math.floor(
382
+ (m - _[b]) / F
383
+ ));
384
+ _[b] += N * F;
385
+ for (let M = 0; M < N; M++) {
386
+ const E = Math.floor(Math.random() * a);
387
+ if (E === x) continue;
388
+ let w = 0;
389
+ for (let B = 0; B < s; B++) {
390
+ const O = e[x * s + B] - e[E * s + B];
391
+ w += O * O;
310
392
  }
311
- const M = Math.pow(A, a), N = 2 * o * a / ((1e-3 + A) * (c * M + 1));
312
- for (let S = 0; S < t; S++) {
313
- const O = e[v * t + S] - e[F * t + S], U = q(N * O);
314
- e[v * t + S] += b * U;
393
+ const R = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (i * R + 1));
394
+ for (let B = 0; B < s; B++) {
395
+ const O = e[x * s + B] - e[E * s + B], S = j(A * O);
396
+ e[x * s + B] += y * S;
315
397
  }
316
398
  }
317
- m[y] += r[y] / i;
318
399
  }
319
400
  }
320
401
  return e;
321
402
  }
322
- function $(e, n, r, s, t, p, f, l, c, a) {
323
- const { a: o, b: i, gamma: h = 1, negativeSampleRate: d = 5 } = c, u = r.rows.length, g = new Uint32Array(r.rows), m = new Uint32Array(r.cols), w = new Float32Array(u).fill(0), b = new Float32Array(u);
324
- for (let y = 0; y < u; y++)
325
- b[y] = s[y] / d;
326
- for (let y = 0; y < l; y++) {
327
- const v = 1 - y / l;
328
- for (let x = 0; x < u; x++) {
329
- if (w[x] > y) continue;
330
- const B = g[x], E = m[x];
331
- let R = 0;
332
- for (let A = 0; A < f; A++) {
333
- const M = e[B * f + A] - n[E * f + A];
334
- R += M * M;
403
+ function fe(e, t, f, a, s, u, c, d, i, n) {
404
+ const { a: r, b: o, gamma: h = 1, negativeSampleRate: l = 5 } = i, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), y = new Float32Array(p);
405
+ for (let b = 0; b < p; b++)
406
+ y[b] = a[b] / l;
407
+ for (let b = 0; b < d; b++) {
408
+ const x = 1 - b / d;
409
+ for (let v = 0; v < p; v++) {
410
+ if (m[v] > b) continue;
411
+ const P = g[v], G = _[v];
412
+ let U = 0;
413
+ for (let w = 0; w < c; w++) {
414
+ const R = e[P * c + w] - t[G * c + w];
415
+ U += R * R;
335
416
  }
336
- const k = Math.pow(R, i), _ = -2 * o * i * (R > 0 ? k / R : 0) / (o * k + 1);
337
- for (let A = 0; A < f; A++) {
338
- const M = e[B * f + A] - n[E * f + A];
339
- e[B * f + A] += v * q(_ * M);
417
+ const F = Math.pow(U, o), N = -2 * r * o * (U > 0 ? F / U : 0) / (r * F + 1);
418
+ for (let w = 0; w < c; w++) {
419
+ const R = e[P * c + w] - t[G * c + w];
420
+ e[P * c + w] += x * j(N * R);
340
421
  }
341
- w[x] += s[x];
342
- const F = b[x] > 0 ? Math.floor(s[x] / b[x]) : 0;
343
- for (let A = 0; A < F; A++) {
344
- const M = Math.floor(Math.random() * p);
345
- if (M === E) continue;
346
- let N = 0;
347
- for (let U = 0; U < f; U++) {
348
- const P = e[B * f + U] - n[M * f + U];
349
- N += P * P;
422
+ m[v] += a[v];
423
+ const M = a[v] / l, E = Math.max(0, Math.floor(
424
+ (b - y[v]) / M
425
+ ));
426
+ y[v] += E * M;
427
+ for (let w = 0; w < E; w++) {
428
+ const R = Math.floor(Math.random() * u);
429
+ if (R === G) continue;
430
+ let A = 0;
431
+ for (let S = 0; S < c; S++) {
432
+ const k = e[P * c + S] - t[R * c + S];
433
+ A += k * k;
350
434
  }
351
- const S = Math.pow(N, i), O = 2 * h * i / ((1e-3 + N) * (o * S + 1));
352
- for (let U = 0; U < f; U++) {
353
- const P = e[B * f + U] - n[M * f + U];
354
- e[B * f + U] += v * q(O * P);
435
+ const B = Math.pow(A, o), O = 2 * h * o / ((1e-3 + A) * (r * B + 1));
436
+ for (let S = 0; S < c; S++) {
437
+ const k = e[P * c + S] - t[R * c + S];
438
+ e[P * c + S] += x * j(O * k);
355
439
  }
356
440
  }
357
- b[x] += s[x] / d;
358
441
  }
359
442
  }
360
443
  return e;
361
444
  }
362
- function j() {
363
- return typeof navigator < "u" && !!navigator.gpu;
364
- }
365
- async function ae(e, n = {}, r) {
445
+ async function pe(e, t = {}, f) {
366
446
  const {
367
- nComponents: s = 2,
368
- nNeighbors: t = 15,
369
- minDist: p = 0.1,
370
- spread: f = 1,
371
- hnsw: l = {}
372
- } = n, c = n.nEpochs ?? (e.length > 1e4 ? 200 : 500);
447
+ nComponents: a = 2,
448
+ nNeighbors: s = 15,
449
+ minDist: u = 0.1,
450
+ spread: c = 1,
451
+ hnsw: d = {}
452
+ } = t, i = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
373
453
  console.time("knn");
374
- const { indices: a, distances: o } = await Y(e, t, {
375
- M: l.M ?? 16,
376
- efConstruction: l.efConstruction ?? 200,
377
- efSearch: l.efSearch ?? 50
454
+ const { indices: n, distances: r } = await se(e, s, {
455
+ M: d.M ?? 16,
456
+ efConstruction: d.efConstruction ?? 200,
457
+ efSearch: d.efSearch ?? 50
378
458
  });
379
459
  console.timeEnd("knn"), console.time("fuzzy-set");
380
- const i = D(a, o, t);
460
+ const o = V(n, r, s);
381
461
  console.timeEnd("fuzzy-set");
382
- const { a: h, b: d } = K(p, f), u = I(i.vals, c), g = e.length, m = new Float32Array(g * s);
383
- for (let b = 0; b < m.length; b++)
384
- m[b] = Math.random() * 20 - 10;
462
+ const { a: h, b: l } = Z(u, c), p = W(o.vals), g = e.length, _ = new Float32Array(g * a);
463
+ for (let y = 0; y < _.length; y++)
464
+ _[y] = Math.random() * 20 - 10;
385
465
  console.time("sgd");
386
- let w;
387
- if (j())
466
+ let m;
467
+ if (J())
388
468
  try {
389
- const b = new W();
390
- await b.init(), w = await b.optimize(
391
- m,
392
- new Uint32Array(i.rows),
393
- new Uint32Array(i.cols),
394
- u,
469
+ const y = new X();
470
+ await y.init(), m = await y.optimize(
471
+ _,
472
+ new Uint32Array(o.rows),
473
+ new Uint32Array(o.cols),
474
+ p,
395
475
  g,
396
- s,
397
- c,
398
- { a: h, b: d, gamma: 1, negativeSampleRate: 5 },
399
- r
476
+ a,
477
+ i,
478
+ { a: h, b: l, gamma: 1, negativeSampleRate: 5 },
479
+ f
400
480
  );
401
- } catch (b) {
402
- console.warn("WebGPU SGD failed, falling back to CPU:", b), w = z(m, i, u, g, s, c, { a: h, b: d }, r);
481
+ } catch (y) {
482
+ console.warn("WebGPU SGD failed, falling back to CPU:", y), m = L(_, o, p, g, a, i, { a: h, b: l }, f);
403
483
  }
404
484
  else
405
- w = z(m, i, u, g, s, c, { a: h, b: d }, r);
406
- return console.timeEnd("sgd"), w;
407
- }
408
- function K(e, n) {
409
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
410
- return { a: 1.9292, b: 0.7915 };
411
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
412
- return { a: 1.8956, b: 0.8006 };
413
- if (Math.abs(n - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
414
- return { a: 1.5769, b: 0.8951 };
415
- const r = ee(e, n);
416
- return { a: ne(e, n, r), b: r };
485
+ m = L(_, o, p, g, a, i, { a: h, b: l }, f);
486
+ return console.timeEnd("sgd"), m;
417
487
  }
418
- function ee(e, n) {
419
- return 1 / (n * 1.2);
488
+ function Z(e, t) {
489
+ return Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6 ? { a: 1.9292, b: 0.7915 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6 ? { a: 1.8956, b: 0.8006 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6 ? { a: 1.5769, b: 0.8951 } : de(e, t);
420
490
  }
421
- function ne(e, n, r) {
422
- return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * r);
491
+ function de(e, t) {
492
+ const a = [], s = [];
493
+ for (let i = 0; i < 299; i++) {
494
+ const n = (i + 1) / 299 * t * 3;
495
+ a.push(n), s.push(n < e ? 1 : Math.exp(-(n - e) / t));
496
+ }
497
+ let u = 1, c = 1, d = 1e-3;
498
+ for (let i = 0; i < 500; i++) {
499
+ let n = 0, r = 0, o = 0, h = 0, l = 0, p = 0;
500
+ for (let U = 0; U < 299; U++) {
501
+ const F = a[U], N = Math.pow(F, 2 * c), M = 1 + u * N, w = 1 / M - s[U];
502
+ p += w * w;
503
+ const R = M * M, A = -N / R, B = F > 0 ? -2 * Math.log(F) * u * N / R : 0;
504
+ n += A * w, r += B * w, o += A * A, h += B * B, l += A * B;
505
+ }
506
+ const g = o + d, _ = h + d, m = l, y = g * _ - m * m;
507
+ if (Math.abs(y) < 1e-20) break;
508
+ const b = -(_ * n - m * r) / y, x = -(g * r - m * n) / y, v = Math.max(1e-4, u + b), P = Math.max(1e-4, c + x);
509
+ let G = 0;
510
+ for (let U = 0; U < 299; U++) {
511
+ const F = Math.pow(a[U], 2 * P), N = 1 / (1 + v * F) - s[U];
512
+ G += N * N;
513
+ }
514
+ if (G < p ? (u = v, c = P, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(b) < 1e-8 && Math.abs(x) < 1e-8) break;
515
+ }
516
+ return { a: u, b: c };
423
517
  }
424
- class re {
425
- constructor(n = {}) {
426
- G(this, "_nComponents");
427
- G(this, "_nNeighbors");
428
- G(this, "_minDist");
429
- G(this, "_spread");
430
- G(this, "_nEpochs");
431
- G(this, "_hnswOpts");
432
- G(this, "_a");
433
- G(this, "_b");
518
+ class ge {
519
+ constructor(t = {}) {
520
+ C(this, "_nComponents");
521
+ C(this, "_nNeighbors");
522
+ C(this, "_minDist");
523
+ C(this, "_spread");
524
+ C(this, "_nEpochs");
525
+ C(this, "_hnswOpts");
526
+ C(this, "_a");
527
+ C(this, "_b");
434
528
  /** The low-dimensional embedding produced by the last fit() call. */
435
- G(this, "embedding", null);
436
- G(this, "_hnswIndex", null);
437
- G(this, "_nTrain", 0);
438
- this._nComponents = n.nComponents ?? 2, this._nNeighbors = n.nNeighbors ?? 15, this._minDist = n.minDist ?? 0.1, this._spread = n.spread ?? 1, this._nEpochs = n.nEpochs, this._hnswOpts = n.hnsw ?? {};
439
- const { a: r, b: s } = K(this._minDist, this._spread);
440
- this._a = r, this._b = s;
529
+ C(this, "embedding", null);
530
+ C(this, "_hnswIndex", null);
531
+ C(this, "_nTrain", 0);
532
+ this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
533
+ const { a: f, b: a } = Z(this._minDist, this._spread);
534
+ this._a = f, this._b = a;
441
535
  }
442
536
  /**
443
537
  * Train UMAP on `vectors`.
@@ -445,45 +539,45 @@ class re {
445
539
  * index so that transform() can project new points later.
446
540
  * Returns `this` for chaining.
447
541
  */
448
- async fit(n, r) {
449
- const s = n.length, t = this._nEpochs ?? (s > 1e4 ? 200 : 500), { M: p = 16, efConstruction: f = 200, efSearch: l = 50 } = this._hnswOpts;
542
+ async fit(t, f) {
543
+ const a = t.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: u = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
450
544
  console.time("knn");
451
- const { knn: c, index: a } = await Q(n, this._nNeighbors, {
452
- M: p,
453
- efConstruction: f,
454
- efSearch: l
545
+ const { knn: i, index: n } = await ae(t, this._nNeighbors, {
546
+ M: u,
547
+ efConstruction: c,
548
+ efSearch: d
455
549
  });
456
- this._hnswIndex = a, this._nTrain = s, console.timeEnd("knn"), console.time("fuzzy-set");
457
- const o = D(c.indices, c.distances, this._nNeighbors);
550
+ this._hnswIndex = n, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
551
+ const r = V(i.indices, i.distances, this._nNeighbors);
458
552
  console.timeEnd("fuzzy-set");
459
- const i = I(o.vals, t), h = new Float32Array(s * this._nComponents);
460
- for (let d = 0; d < h.length; d++)
461
- h[d] = Math.random() * 20 - 10;
462
- if (console.time("sgd"), j())
553
+ const o = W(r.vals), h = new Float32Array(a * this._nComponents);
554
+ for (let l = 0; l < h.length; l++)
555
+ h[l] = Math.random() * 20 - 10;
556
+ if (console.time("sgd"), J())
463
557
  try {
464
- const d = new W();
465
- await d.init(), this.embedding = await d.optimize(
558
+ const l = new X();
559
+ await l.init(), this.embedding = await l.optimize(
466
560
  h,
467
- new Uint32Array(o.rows),
468
- new Uint32Array(o.cols),
469
- i,
470
- s,
561
+ new Uint32Array(r.rows),
562
+ new Uint32Array(r.cols),
563
+ o,
564
+ a,
471
565
  this._nComponents,
472
- t,
566
+ s,
473
567
  { a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 },
474
- r
568
+ f
475
569
  );
476
- } catch (d) {
477
- console.warn("WebGPU SGD failed, falling back to CPU:", d), this.embedding = z(h, o, i, s, this._nComponents, t, {
570
+ } catch (l) {
571
+ console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = L(h, r, o, a, this._nComponents, s, {
478
572
  a: this._a,
479
573
  b: this._b
480
- }, r);
574
+ }, f);
481
575
  }
482
576
  else
483
- this.embedding = z(h, o, i, s, this._nComponents, t, {
577
+ this.embedding = L(h, r, o, a, this._nComponents, s, {
484
578
  a: this._a,
485
579
  b: this._b
486
- }, r);
580
+ }, f);
487
581
  return console.timeEnd("sgd"), this;
488
582
  }
489
583
  /**
@@ -497,35 +591,35 @@ class re {
497
591
  * returned embedding to [0, 1]. The stored training embedding is never
498
592
  * mutated. Defaults to `false`.
499
593
  */
500
- async transform(n, r = !1) {
594
+ async transform(t, f = !1) {
501
595
  if (!this._hnswIndex || !this.embedding)
502
596
  throw new Error("UMAP.transform() must be called after fit()");
503
- const s = n.length, t = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), p = Math.max(100, Math.floor(t / 4)), f = this._hnswIndex.searchKnn(n, this._nNeighbors), l = J(f.indices, f.distances, this._nNeighbors), c = new Uint32Array(l.rows), a = new Uint32Array(l.cols), o = new Float32Array(s), i = new Float32Array(s * this._nComponents);
504
- for (let u = 0; u < c.length; u++) {
505
- const g = c[u], m = a[u], w = l.vals[u];
506
- o[g] += w;
507
- for (let b = 0; b < this._nComponents; b++)
508
- i[g * this._nComponents + b] += w * this.embedding[m * this._nComponents + b];
597
+ const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), i = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), o = new Float32Array(a * this._nComponents);
598
+ for (let p = 0; p < i.length; p++) {
599
+ const g = i[p], _ = n[p], m = d.vals[p];
600
+ r[g] += m;
601
+ for (let y = 0; y < this._nComponents; y++)
602
+ o[g * this._nComponents + y] += m * this.embedding[_ * this._nComponents + y];
509
603
  }
510
- for (let u = 0; u < s; u++)
511
- if (o[u] > 0)
604
+ for (let p = 0; p < a; p++)
605
+ if (r[p] > 0)
512
606
  for (let g = 0; g < this._nComponents; g++)
513
- i[u * this._nComponents + g] /= o[u];
607
+ o[p * this._nComponents + g] /= r[p];
514
608
  else
515
609
  for (let g = 0; g < this._nComponents; g++)
516
- i[u * this._nComponents + g] = Math.random() * 20 - 10;
517
- const h = I(l.vals, p), d = $(
518
- i,
610
+ o[p * this._nComponents + g] = Math.random() * 20 - 10;
611
+ const h = W(d.vals), l = fe(
612
+ o,
519
613
  this.embedding,
520
- l,
614
+ d,
521
615
  h,
522
- s,
616
+ a,
523
617
  this._nTrain,
524
618
  this._nComponents,
525
- p,
619
+ u,
526
620
  { a: this._a, b: this._b }
527
621
  );
528
- return r ? T(d, s, this._nComponents) : d;
622
+ return f ? K(l, a, this._nComponents) : l;
529
623
  }
530
624
  /**
531
625
  * Convenience method equivalent to `fit(vectors)` followed by
@@ -536,37 +630,38 @@ class re {
536
630
  * returned embedding to [0, 1]. `this.embedding` is never mutated.
537
631
  * Defaults to `false`.
538
632
  */
539
- async fit_transform(n, r, s = !1) {
540
- return await this.fit(n, r), s ? T(this.embedding, n.length, this._nComponents) : this.embedding;
633
+ async fit_transform(t, f, a = !1) {
634
+ return await this.fit(t, f), a ? K(this.embedding, t.length, this._nComponents) : this.embedding;
541
635
  }
542
636
  }
543
- function T(e, n, r) {
544
- const s = new Float32Array(e.length);
545
- for (let t = 0; t < r; t++) {
546
- let p = 1 / 0, f = -1 / 0;
547
- for (let c = 0; c < n; c++) {
548
- const a = e[c * r + t];
549
- a < p && (p = a), a > f && (f = a);
637
+ function K(e, t, f) {
638
+ const a = new Float32Array(e.length);
639
+ for (let s = 0; s < f; s++) {
640
+ let u = 1 / 0, c = -1 / 0;
641
+ for (let i = 0; i < t; i++) {
642
+ const n = e[i * f + s];
643
+ n < u && (u = n), n > c && (c = n);
550
644
  }
551
- const l = f - p;
552
- for (let c = 0; c < n; c++)
553
- s[c * r + t] = l > 0 ? (e[c * r + t] - p) / l : 0;
645
+ const d = c - u;
646
+ for (let i = 0; i < t; i++)
647
+ a[i * f + s] = d > 0 ? (e[i * f + s] - u) / d : 0;
554
648
  }
555
- return s;
649
+ return a;
556
650
  }
557
- function I(e, n) {
558
- let r = -1 / 0;
559
- for (let t = 0; t < e.length; t++)
560
- e[t] > r && (r = e[t]);
561
- const s = new Float32Array(e.length);
562
- for (let t = 0; t < e.length; t++) {
563
- const p = e[t] / r;
564
- s[t] = p > 0 ? n / p : -1;
651
+ function W(e, t) {
652
+ let f = -1 / 0;
653
+ for (let s = 0; s < e.length; s++)
654
+ e[s] > f && (f = e[s]);
655
+ const a = new Float32Array(e.length);
656
+ for (let s = 0; s < e.length; s++) {
657
+ const u = e[s] / f;
658
+ a[s] = u > 0 ? 1 / u : -1;
565
659
  }
566
- return s;
660
+ return a;
567
661
  }
568
662
  export {
569
- re as UMAP,
570
- ae as fit,
571
- j as isWebGPUAvailable
663
+ ge as UMAP,
664
+ he as checkWebGPUAvailable,
665
+ pe as fit,
666
+ J as isWebGPUAvailable
572
667
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.2.11",
3
+ "version": "0.2.14",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",