umap-gpu 0.2.13 → 0.2.15

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  export interface FuzzyGraph {
2
- rows: Float32Array;
3
- cols: Float32Array;
2
+ rows: Uint32Array;
3
+ cols: Uint32Array;
4
4
  vals: Float32Array;
5
5
  nVertices: number;
6
6
  }
@@ -7,6 +7,23 @@
7
7
  */
8
8
  export declare function getGPUDevice(): Promise<GPUDevice | null>;
9
9
  /**
10
- * Check whether WebGPU is available in the current environment.
10
+ * Fast synchronous heuristic: returns `true` if `navigator.gpu` exists.
11
+ *
12
+ * **Caveat (Bug 13):** `navigator.gpu` being truthy does NOT guarantee that a
13
+ * WebGPU adapter can be acquired — `requestAdapter()` may still return `null`
14
+ * (no compatible hardware, or the browser has disabled WebGPU for the page).
15
+ * Use `checkWebGPUAvailable()` for a reliable async check, or rely on the
16
+ * `try/catch` around `GPUSgd.init()` in the calling code.
11
17
  */
12
18
  export declare function isWebGPUAvailable(): boolean;
19
+ /**
20
+ * Reliably check whether WebGPU is usable in the current environment by
21
+ * attempting to acquire an adapter via `getGPUDevice()`.
22
+ *
23
+ * Unlike the synchronous `isWebGPUAvailable()`, this actually calls
24
+ * `navigator.gpu.requestAdapter()` and returns `false` if the adapter is
25
+ * unavailable (no compatible GPU, browser policy, etc.).
26
+ *
27
+ * The result is automatically cached — repeated calls are free.
28
+ */
29
+ export declare function checkWebGPUAvailable(): Promise<boolean>;
package/dist/gpu/sgd.d.ts CHANGED
@@ -6,11 +6,23 @@ export interface SGDParams {
6
6
  }
7
7
  /**
8
8
  * GPU-accelerated SGD optimizer for UMAP embedding.
9
- * Each GPU thread processes one graph edge, applying attraction and repulsion forces.
9
+ *
10
+ * Uses a two-pass design per epoch to eliminate write-write races on shared
11
+ * vertex positions (Bug 2 fix):
12
+ * Pass 1 (sgd.wgsl): Each thread accumulates its attraction and
13
+ * repulsion gradients into an atomic<i32>
14
+ * forces buffer — no direct embedding writes.
15
+ * Pass 2 (apply-forces.wgsl): Each thread applies one element's accumulated
16
+ * force to the embedding and resets the
17
+ * accumulator to zero for the next epoch.
18
+ *
19
+ * Both passes are submitted in the same command encoder so WebGPU guarantees
20
+ * sequential execution and storage-buffer visibility between them.
10
21
  */
11
22
  export declare class GPUSgd {
12
23
  private device;
13
- private pipeline;
24
+ private sgdPipeline;
25
+ private applyForcesPipeline;
14
26
  init(): Promise<void>;
15
27
  /**
16
28
  * Run SGD optimization on the GPU.
package/dist/index.d.ts CHANGED
@@ -2,4 +2,4 @@ export { fit, UMAP } from './umap';
2
2
  export type { UMAPOptions, ProgressCallback } from './umap';
3
3
  export type { KNNResult, HNSWOptions, HNSWSearchableIndex } from './hnsw-knn';
4
4
  export type { FuzzyGraph } from './fuzzy-set';
5
- export { isWebGPUAvailable } from './gpu/device';
5
+ export { isWebGPUAvailable, checkWebGPUAvailable } from './gpu/device';
package/dist/index.js CHANGED
@@ -1,104 +1,111 @@
1
- var V = Object.defineProperty;
2
- var Y = (e, t, r) => t in e ? V(e, t, { enumerable: !0, configurable: !0, writable: !0, value: r }) : e[t] = r;
3
- var G = (e, t, r) => Y(e, typeof t != "symbol" ? t + "" : t, r);
4
- import { loadHnswlib as D } from "hnswlib-wasm";
5
- async function Q(e, t, r = {}) {
6
- const { M: s = 16, efConstruction: n = 200, efSearch: g = 50 } = r, c = await D(), d = e[0].length, f = e.length, a = new c.HierarchicalNSW("l2", d, "");
7
- a.initIndex(f, s, n, 200), a.setEfSearch(Math.max(g, t)), a.addItems(e, !1);
8
- const o = [], i = [];
9
- for (let h = 0; h < f; h++) {
10
- const l = a.searchKnn(e[h], t + 1, void 0), u = l.neighbors.map((p, m) => ({ idx: p, dist: l.distances[m] })).filter(({ idx: p }) => p !== h).slice(0, t);
11
- o.push(u.map(({ idx: p }) => p)), i.push(u.map(({ dist: p }) => p));
1
+ var te = Object.defineProperty;
2
+ var ne = (e, t, f) => t in e ? te(e, t, { enumerable: !0, configurable: !0, writable: !0, value: f }) : e[t] = f;
3
+ var R = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
4
+ import { loadHnswlib as H } from "hnswlib-wasm";
5
+ async function se(e, t, f = {}) {
6
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
7
+ n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
8
+ const r = [], i = [];
9
+ for (let h = 0; h < o; h++) {
10
+ const l = n.searchKnn(e[h], t + 1, void 0), p = l.neighbors.map((g, _) => ({ idx: g, dist: l.distances[_] })).filter(({ idx: g }) => g !== h).slice(0, t);
11
+ r.push(p.map(({ idx: g }) => g)), i.push(p.map(({ dist: g }) => Math.sqrt(g)));
12
12
  }
13
- return { indices: o, distances: i };
13
+ return { indices: r, distances: i };
14
14
  }
15
- async function J(e, t, r = {}) {
16
- const { M: s = 16, efConstruction: n = 200, efSearch: g = 50 } = r, c = await D(), d = e[0].length, f = e.length, a = new c.HierarchicalNSW("l2", d, "");
17
- a.initIndex(f, s, n, 200), a.setEfSearch(Math.max(g, t)), a.addItems(e, !1);
18
- const o = [], i = [];
19
- for (let l = 0; l < f; l++) {
20
- const u = a.searchKnn(e[l], t + 1, void 0), p = u.neighbors.map((m, _) => ({ idx: m, dist: u.distances[_] })).filter(({ idx: m }) => m !== l).slice(0, t);
21
- o.push(p.map(({ idx: m }) => m)), i.push(p.map(({ dist: m }) => m));
15
+ async function ae(e, t, f = {}) {
16
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
17
+ n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
18
+ const r = [], i = [];
19
+ for (let l = 0; l < o; l++) {
20
+ const p = n.searchKnn(e[l], t + 1, void 0), g = p.neighbors.map((_, m) => ({ idx: _, dist: p.distances[m] })).filter(({ idx: _ }) => _ !== l).slice(0, t);
21
+ r.push(g.map(({ idx: _ }) => _)), i.push(g.map(({ dist: _ }) => Math.sqrt(_)));
22
22
  }
23
- return { knn: { indices: o, distances: i }, index: {
24
- searchKnn(l, u) {
25
- const p = [], m = [];
26
- for (const _ of l) {
27
- const w = a.searchKnn(_, u, void 0), b = w.neighbors.map((A, x) => ({ idx: A, dist: w.distances[x] })).sort((A, x) => A.dist - x.dist).slice(0, u);
28
- p.push(b.map(({ idx: A }) => A)), m.push(b.map(({ dist: A }) => A));
23
+ return { knn: { indices: r, distances: i }, index: {
24
+ searchKnn(l, p) {
25
+ const g = [], _ = [];
26
+ for (const m of l) {
27
+ const b = n.searchKnn(m, p, void 0), y = b.neighbors.map((x, v) => ({ idx: x, dist: b.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
28
+ g.push(y.map(({ idx: x }) => x)), _.push(y.map(({ dist: x }) => Math.sqrt(x)));
29
29
  }
30
- return { indices: p, distances: m };
30
+ return { indices: g, distances: _ };
31
31
  }
32
32
  } };
33
33
  }
34
- function L(e, t, r, s = 1) {
35
- const n = e.length, { sigmas: g, rhos: c } = W(t, r), d = [], f = [], a = [];
36
- for (let i = 0; i < n; i++)
34
+ function V(e, t, f, a = 1) {
35
+ const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], o = [], n = [];
36
+ for (let i = 0; i < s; i++)
37
37
  for (let h = 0; h < e[i].length; h++) {
38
- const l = t[i][h], u = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / g[i]));
39
- d.push(i), f.push(e[i][h]), a.push(u);
38
+ const l = t[i][h], p = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / u[i]));
39
+ d.push(i), o.push(e[i][h]), n.push(p);
40
40
  }
41
- return { ...Z(d, f, a, n, s), nVertices: n };
41
+ return { ...ie(d, o, n, s, a), nVertices: s };
42
42
  }
43
- function X(e, t, r) {
44
- const s = e.length, { sigmas: n, rhos: g } = W(t, r), c = [], d = [], f = [];
45
- for (let a = 0; a < s; a++)
46
- for (let o = 0; o < e[a].length; o++) {
47
- const i = t[a][o], h = i <= g[a] ? 1 : Math.exp(-((i - g[a]) / n[a]));
48
- c.push(a), d.push(e[a][o]), f.push(h);
43
+ function re(e, t, f) {
44
+ const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], o = [];
45
+ for (let n = 0; n < a; n++)
46
+ for (let r = 0; r < e[n].length; r++) {
47
+ const i = t[n][r], h = i <= u[n] ? 1 : Math.exp(-((i - u[n]) / s[n]));
48
+ c.push(n), d.push(e[n][r]), o.push(h);
49
49
  }
50
50
  return {
51
- rows: new Float32Array(c),
52
- cols: new Float32Array(d),
53
- vals: new Float32Array(f),
54
- nVertices: s
51
+ rows: new Uint32Array(c),
52
+ cols: new Uint32Array(d),
53
+ vals: new Float32Array(o),
54
+ nVertices: a
55
55
  };
56
56
  }
57
- function W(e, t) {
58
- const s = e.length, n = new Float32Array(s), g = new Float32Array(s);
59
- for (let c = 0; c < s; c++) {
57
+ function Y(e, t) {
58
+ const a = e.length, s = new Float32Array(a), u = new Float32Array(a);
59
+ for (let c = 0; c < a; c++) {
60
60
  const d = e[c];
61
- g[c] = d.find((h) => h > 0) ?? 0;
62
- let f = 0, a = 1 / 0, o = 1;
61
+ u[c] = d.find((h) => h > 0) ?? 0;
62
+ let o = 0, n = 1 / 0, r = 1;
63
63
  const i = Math.log2(t);
64
64
  for (let h = 0; h < 64; h++) {
65
65
  let l = 0;
66
- for (let u = 0; u < d.length; u++)
67
- l += Math.exp(-Math.max(0, d[u] - g[c]) / o);
66
+ for (let p = 0; p < d.length; p++)
67
+ l += Math.exp(-Math.max(0, d[p] - u[c]) / r);
68
68
  if (Math.abs(l - i) < 1e-5) break;
69
- l > i ? (a = o, o = (f + a) / 2) : (f = o, o = a === 1 / 0 ? o * 2 : (f + a) / 2);
69
+ l > i ? (n = r, r = (o + n) / 2) : (o = r, r = n === 1 / 0 ? r * 2 : (o + n) / 2);
70
70
  }
71
- n[c] = o;
71
+ s[c] = r;
72
72
  }
73
- return { sigmas: n, rhos: g };
73
+ return { sigmas: s, rhos: u };
74
74
  }
75
- function Z(e, t, r, s, n) {
76
- const g = /* @__PURE__ */ new Map();
77
- for (let a = 0; a < e.length; a++)
78
- g.set(e[a] * s + t[a], r[a]);
79
- const c = [], d = [], f = [];
80
- for (const [a, o] of g) {
81
- const i = Math.floor(a / s), h = a % s, l = g.get(h * s + i) ?? 0, u = o + l - o * l, p = o * l;
82
- c.push(i), d.push(h), f.push(n * u + (1 - n) * p);
75
+ function ie(e, t, f, a, s) {
76
+ const u = /* @__PURE__ */ new Map();
77
+ for (let n = 0; n < e.length; n++)
78
+ u.set(e[n] * a + t[n], f[n]);
79
+ const c = [], d = [], o = [];
80
+ for (const [n, r] of u) {
81
+ const i = Math.floor(n / a), h = n % a, l = u.get(h * a + i) ?? 0, p = r + l - r * l, g = r * l;
82
+ c.push(i), d.push(h), o.push(s * p + (1 - s) * g);
83
83
  }
84
84
  return {
85
- rows: new Float32Array(c),
86
- cols: new Float32Array(d),
87
- vals: new Float32Array(f)
85
+ rows: new Uint32Array(c),
86
+ cols: new Uint32Array(d),
87
+ vals: new Float32Array(o)
88
88
  };
89
89
  }
90
- const $ = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
- // Applies attraction forces between connected nodes and repulsion forces
92
- // against negative samples.
90
+ const oe = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
+ // Computes attraction and repulsion forces and accumulates them atomically
92
+ // into a forces buffer. A separate apply-forces pass then updates embeddings,
93
+ // eliminating write-write races on shared vertex positions.
93
94
 
94
95
  @group(0) @binding(0) var<storage, read> epochs_per_sample : array<f32>;
95
96
  @group(0) @binding(1) var<storage, read> head : array<u32>; // edge source
96
97
  @group(0) @binding(2) var<storage, read> tail : array<u32>; // edge target
97
- @group(0) @binding(3) var<storage, read_write> embedding : array<f32>; // [n * nComponents]
98
+ @group(0) @binding(3) var<storage, read> embedding : array<f32>; // [n * nComponents], read-only
98
99
  @group(0) @binding(4) var<storage, read_write> epoch_of_next_sample : array<f32>;
99
100
  @group(0) @binding(5) var<storage, read_write> epoch_of_next_negative_sample : array<f32>;
100
101
  @group(0) @binding(6) var<uniform> params : Params;
101
102
  @group(0) @binding(7) var<storage, read> rng_seeds : array<u32>; // per-edge seed
103
+ @group(0) @binding(8) var<storage, read_write> forces : array<atomic<i32>>; // [n * nComponents]
104
+
105
+ // Scale factor for quantizing f32 gradients into i32 for atomic accumulation.
106
+ // Gradients are clipped to [-4, 4]. With up to ~1000 edges sharing a vertex
107
+ // the max accumulated magnitude is ~4000, well within i32 range at this scale.
108
+ const FORCE_SCALE : f32 = 65536.0;
102
109
 
103
110
  struct Params {
104
111
  n_edges : u32,
@@ -106,7 +113,7 @@ struct Params {
106
113
  n_components : u32,
107
114
  current_epoch : u32,
108
115
  n_epochs : u32,
109
- alpha : f32, // learning rate
116
+ alpha : f32, // learning rate (applied by apply-forces pass)
110
117
  a : f32,
111
118
  b : f32,
112
119
  gamma : f32, // repulsion strength
@@ -147,7 +154,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
147
154
 
148
155
  let pow_b = pow(dist_sq, params.b);
149
156
  // Guard dist_sq == 0: b-1 is negative so pow(0, b-1) = +Inf.
150
- // Mirror CPU: use pow_b / dist_sq only when dist_sq > 0, else 0.
151
157
  let grad_coeff_attr = select(
152
158
  -2.0 * params.a * params.b * (pow_b / dist_sq) / (params.a * pow_b + 1.0),
153
159
  0.0,
@@ -157,20 +163,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
157
163
  for (var d = 0u; d < nc; d++) {
158
164
  let diff = embedding[i * nc + d] - embedding[j * nc + d];
159
165
  let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
160
- embedding[i * nc + d] += params.alpha * grad;
161
- embedding[j * nc + d] -= params.alpha * grad;
166
+ // Accumulate atomically to avoid write-write races across threads.
167
+ atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
168
+ atomicAdd(&forces[j * nc + d], -i32(grad * FORCE_SCALE));
162
169
  }
163
170
 
164
171
  epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
165
172
 
166
173
  // --- Repulsion (negative samples) ---
167
- let eps = epochs_per_sample[edge_idx];
168
- let neg_eps = epoch_of_next_negative_sample[edge_idx];
174
+ // Compute how many negative samples are overdue relative to current epoch,
175
+ // matching the Python reference: n_neg = floor((n - next_neg) / eps_per_neg).
176
+ let epoch_f = f32(params.current_epoch);
177
+ let epochs_per_neg = epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
169
178
  var n_neg = 0u;
170
- if (neg_eps > 0.0) {
171
- n_neg = u32(eps / neg_eps);
179
+ if (epochs_per_neg > 0.0 && epoch_f >= epoch_of_next_negative_sample[edge_idx]) {
180
+ n_neg = u32((epoch_f - epoch_of_next_negative_sample[edge_idx]) / epochs_per_neg);
181
+ epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * epochs_per_neg;
172
182
  }
173
- var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 6364136223u);
183
+
184
+ // 2654435761u is the 32-bit golden-ratio hash constant (0x9E3779B1),
185
+ // which fits in u32 unlike the 64-bit LCG value 6364136223 used originally.
186
+ // Bug 14 fix: the original constant exceeded u32 range and failed WGSL validation.
187
+ var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 2654435761u);
174
188
 
175
189
  for (var s = 0u; s < n_neg; s++) {
176
190
  rng = xorshift(rng);
@@ -189,26 +203,73 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
189
203
  for (var d = 0u; d < nc; d++) {
190
204
  let diff = embedding[i * nc + d] - embedding[k * nc + d];
191
205
  let grad = clip(grad_coeff_rep * diff, -4.0, 4.0);
192
- embedding[i * nc + d] += params.alpha * grad;
206
+ atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
193
207
  }
194
208
  }
209
+ }
210
+ `, ce = `// Apply-forces shader — second pass of the two-pass GPU SGD.
211
+ //
212
+ // After the SGD pass has atomically accumulated all gradients into the forces
213
+ // buffer, this shader applies each element's accumulated force to the
214
+ // embedding and resets the accumulator to zero for the next epoch.
195
215
 
196
- epoch_of_next_negative_sample[edge_idx] +=
197
- epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
216
+ @group(0) @binding(0) var<storage, read_write> embedding : array<f32>;
217
+ @group(0) @binding(1) var<storage, read_write> forces : array<atomic<i32>>;
218
+ @group(0) @binding(2) var<uniform> params : ApplyParams;
219
+
220
+ struct ApplyParams {
221
+ n_elements : u32, // nVertices * nComponents
222
+ alpha : f32, // current learning rate
223
+ }
224
+
225
+ // Must match FORCE_SCALE in sgd.wgsl
226
+ const FORCE_SCALE : f32 = 65536.0;
227
+
228
+ @compute @workgroup_size(256)
229
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
230
+ let idx = gid.x;
231
+ if (idx >= params.n_elements) { return; }
232
+
233
+ // atomicExchange atomically reads the accumulated force and resets it to 0.
234
+ let raw = atomicExchange(&forces[idx], 0);
235
+ embedding[idx] += params.alpha * f32(raw) / FORCE_SCALE;
198
236
  }
199
237
  `;
200
- class j {
238
+ let z = null;
239
+ async function Q() {
240
+ if (z) return z;
241
+ if (typeof navigator > "u" || !navigator.gpu)
242
+ return null;
243
+ const e = await navigator.gpu.requestAdapter();
244
+ return e ? (z = await e.requestDevice(), z.lost.then(() => {
245
+ z = null;
246
+ }), z) : null;
247
+ }
248
+ function J() {
249
+ return typeof navigator < "u" && !!navigator.gpu;
250
+ }
251
+ async function he() {
252
+ return await Q() !== null;
253
+ }
254
+ class X {
201
255
  constructor() {
202
- G(this, "device");
203
- G(this, "pipeline");
256
+ R(this, "device");
257
+ R(this, "sgdPipeline");
258
+ R(this, "applyForcesPipeline");
204
259
  }
205
260
  async init() {
206
- const t = await navigator.gpu.requestAdapter();
261
+ const t = await Q();
207
262
  if (!t) throw new Error("WebGPU not supported");
208
- this.device = await t.requestDevice(), this.pipeline = this.device.createComputePipeline({
263
+ this.device = t, this.sgdPipeline = this.device.createComputePipeline({
264
+ layout: "auto",
265
+ compute: {
266
+ module: this.device.createShaderModule({ code: oe }),
267
+ entryPoint: "main"
268
+ }
269
+ }), this.applyForcesPipeline = this.device.createComputePipeline({
209
270
  layout: "auto",
210
271
  compute: {
211
- module: this.device.createShaderModule({ code: $ }),
272
+ module: this.device.createShaderModule({ code: ce }),
212
273
  entryPoint: "main"
213
274
  }
214
275
  });
@@ -226,222 +287,254 @@ class j {
226
287
  * @param params - UMAP curve parameters and repulsion settings
227
288
  * @returns Optimized embedding as Float32Array
228
289
  */
229
- async optimize(t, r, s, n, g, c, d, f, a) {
230
- const { device: o } = this, i = r.length, h = this.makeBuffer(
290
+ async optimize(t, f, a, s, u, c, d, o, n) {
291
+ const { device: r } = this, i = f.length, h = u * c, l = this.makeBuffer(
231
292
  t,
232
293
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
233
- ), l = this.makeBuffer(r, GPUBufferUsage.STORAGE), u = this.makeBuffer(s, GPUBufferUsage.STORAGE), p = this.makeBuffer(n, GPUBufferUsage.STORAGE), m = new Float32Array(i).fill(0), _ = this.makeBuffer(m, GPUBufferUsage.STORAGE), w = new Float32Array(i);
234
- for (let M = 0; M < i; M++)
235
- w[M] = n[M] / f.negativeSampleRate;
236
- const b = this.makeBuffer(w, GPUBufferUsage.STORAGE), A = new Uint32Array(i);
237
- for (let M = 0; M < i; M++)
238
- A[M] = Math.random() * 4294967295 | 0;
239
- const x = this.makeBuffer(A, GPUBufferUsage.STORAGE), U = o.createBuffer({
294
+ ), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), b = this.makeBuffer(m, GPUBufferUsage.STORAGE), y = new Float32Array(i);
295
+ for (let A = 0; A < i; A++)
296
+ y[A] = s[A] / o.negativeSampleRate;
297
+ const x = this.makeBuffer(y, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
298
+ for (let A = 0; A < i; A++)
299
+ v[A] = Math.random() * 4294967295 | 0;
300
+ const G = this.makeBuffer(v, GPUBufferUsage.STORAGE), P = r.createBuffer({
301
+ size: h * 4,
302
+ usage: GPUBufferUsage.STORAGE,
303
+ mappedAtCreation: !0
304
+ });
305
+ new Int32Array(P.getMappedRange()).fill(0), P.unmap();
306
+ const U = r.createBuffer({
240
307
  size: 40,
241
308
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
309
+ }), F = r.createBuffer({
310
+ size: 16,
311
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
312
+ }), N = r.createBindGroup({
313
+ layout: this.sgdPipeline.getBindGroupLayout(0),
314
+ entries: [
315
+ { binding: 0, resource: { buffer: _ } },
316
+ { binding: 1, resource: { buffer: p } },
317
+ { binding: 2, resource: { buffer: g } },
318
+ { binding: 3, resource: { buffer: l } },
319
+ { binding: 4, resource: { buffer: b } },
320
+ { binding: 5, resource: { buffer: x } },
321
+ { binding: 6, resource: { buffer: U } },
322
+ { binding: 7, resource: { buffer: G } },
323
+ { binding: 8, resource: { buffer: P } }
324
+ ]
325
+ }), M = r.createBindGroup({
326
+ layout: this.applyForcesPipeline.getBindGroupLayout(0),
327
+ entries: [
328
+ { binding: 0, resource: { buffer: l } },
329
+ { binding: 1, resource: { buffer: P } },
330
+ { binding: 2, resource: { buffer: F } }
331
+ ]
242
332
  });
243
- for (let M = 0; M < d; M++) {
244
- const v = 1 - M / d, S = new ArrayBuffer(40), y = new Uint32Array(S), B = new Float32Array(S);
245
- y[0] = i, y[1] = g, y[2] = c, y[3] = M, y[4] = d, B[5] = v, B[6] = f.a, B[7] = f.b, B[8] = f.gamma, y[9] = f.negativeSampleRate, o.queue.writeBuffer(U, 0, S);
246
- const k = o.createBindGroup({
247
- layout: this.pipeline.getBindGroupLayout(0),
248
- entries: [
249
- { binding: 0, resource: { buffer: p } },
250
- { binding: 1, resource: { buffer: l } },
251
- { binding: 2, resource: { buffer: u } },
252
- { binding: 3, resource: { buffer: h } },
253
- { binding: 4, resource: { buffer: _ } },
254
- { binding: 5, resource: { buffer: b } },
255
- { binding: 6, resource: { buffer: U } },
256
- { binding: 7, resource: { buffer: x } }
257
- ]
258
- }), E = o.createCommandEncoder(), P = E.beginComputePass();
259
- P.setPipeline(this.pipeline), P.setBindGroup(0, k), P.dispatchWorkgroups(Math.ceil(i / 256)), P.end(), o.queue.submit([E.finish()]), M % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(M, d));
333
+ for (let A = 0; A < d; A++) {
334
+ const B = 1 - A / d, O = new ArrayBuffer(40), S = new Uint32Array(O), k = new Float32Array(O);
335
+ S[0] = i, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = o.a, k[7] = o.b, k[8] = o.gamma, S[9] = o.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
336
+ const D = new ArrayBuffer(16), $ = new Uint32Array(D), ee = new Float32Array(D);
337
+ $[0] = h, ee[1] = B, r.queue.writeBuffer(F, 0, D);
338
+ const T = r.createCommandEncoder(), q = T.beginComputePass();
339
+ q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(i / 256)), q.end();
340
+ const L = T.beginComputePass();
341
+ L.setPipeline(this.applyForcesPipeline), L.setBindGroup(0, M), L.dispatchWorkgroups(Math.ceil(h / 256)), L.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
260
342
  }
261
- const N = o.createBuffer({
343
+ const E = r.createBuffer({
262
344
  size: t.byteLength,
263
345
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
264
- }), R = o.createCommandEncoder();
265
- R.copyBufferToBuffer(h, 0, N, 0, t.byteLength), o.queue.submit([R.finish()]), await N.mapAsync(GPUMapMode.READ);
266
- const O = new Float32Array(N.getMappedRange().slice(0));
267
- return N.unmap(), h.destroy(), l.destroy(), u.destroy(), p.destroy(), _.destroy(), b.destroy(), x.destroy(), U.destroy(), N.destroy(), O;
346
+ }), w = r.createCommandEncoder();
347
+ w.copyBufferToBuffer(l, 0, E, 0, t.byteLength), r.queue.submit([w.finish()]), await E.mapAsync(GPUMapMode.READ);
348
+ const C = new Float32Array(E.getMappedRange().slice(0));
349
+ return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), b.destroy(), x.destroy(), G.destroy(), P.destroy(), U.destroy(), F.destroy(), E.destroy(), C;
268
350
  }
269
- makeBuffer(t, r) {
270
- const s = this.device.createBuffer({
351
+ makeBuffer(t, f) {
352
+ const a = this.device.createBuffer({
271
353
  size: t.byteLength,
272
- usage: r,
354
+ usage: f,
273
355
  mappedAtCreation: !0
274
356
  });
275
- return t instanceof Float32Array ? new Float32Array(s.getMappedRange()).set(t) : new Uint32Array(s.getMappedRange()).set(t), s.unmap(), s;
357
+ return t instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(t) : new Uint32Array(a.getMappedRange()).set(t), a.unmap(), a;
276
358
  }
277
359
  }
278
- function z(e) {
360
+ function I(e) {
279
361
  return Math.max(-4, Math.min(4, e));
280
362
  }
281
- function I(e, t, r, s, n, g, c, d) {
282
- const { a: f, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), u = new Uint32Array(t.cols), p = new Float32Array(h).fill(0), m = new Float32Array(h);
283
- for (let _ = 0; _ < h; _++)
284
- m[_] = r[_] / i;
285
- for (let _ = 0; _ < g; _++) {
286
- d == null || d(_, g);
287
- const w = 1 - _ / g;
288
- for (let b = 0; b < h; b++) {
289
- if (p[b] > _) continue;
290
- const A = l[b], x = u[b];
291
- let U = 0;
292
- for (let v = 0; v < n; v++) {
293
- const S = e[A * n + v] - e[x * n + v];
294
- U += S * S;
363
+ function j(e, t, f, a, s, u, c, d) {
364
+ const { a: o, b: n, gamma: r = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
365
+ for (let m = 0; m < h; m++)
366
+ _[m] = f[m] / i;
367
+ for (let m = 0; m < u; m++) {
368
+ d == null || d(m, u);
369
+ const b = 1 - m / u;
370
+ for (let y = 0; y < h; y++) {
371
+ if (g[y] > m) continue;
372
+ const x = l[y], v = p[y];
373
+ let G = 0;
374
+ for (let M = 0; M < s; M++) {
375
+ const E = e[x * s + M] - e[v * s + M];
376
+ G += E * E;
295
377
  }
296
- const N = Math.pow(U, a), R = -2 * f * a * (U > 0 ? N / U : 0) / (f * N + 1);
297
- for (let v = 0; v < n; v++) {
298
- const S = e[A * n + v] - e[x * n + v], y = z(R * S);
299
- e[A * n + v] += w * y, e[x * n + v] -= w * y;
378
+ const P = Math.pow(G, n), U = -2 * o * n * (G > 0 ? P / G : 0) / (o * P + 1);
379
+ for (let M = 0; M < s; M++) {
380
+ const E = e[x * s + M] - e[v * s + M], w = I(U * E);
381
+ e[x * s + M] += b * w, e[v * s + M] -= b * w;
300
382
  }
301
- p[b] += r[b];
302
- const O = r[b] / i, M = Math.max(0, Math.floor(
303
- (_ - m[b]) / O
383
+ g[y] += f[y];
384
+ const F = f[y] / i, N = Math.max(0, Math.floor(
385
+ (m - _[y]) / F
304
386
  ));
305
- m[b] += M * O;
306
- for (let v = 0; v < M; v++) {
307
- const S = Math.floor(Math.random() * s);
308
- if (S === A) continue;
309
- let y = 0;
310
- for (let E = 0; E < n; E++) {
311
- const P = e[A * n + E] - e[S * n + E];
312
- y += P * P;
387
+ _[y] += N * F;
388
+ for (let M = 0; M < N; M++) {
389
+ const E = Math.floor(Math.random() * a);
390
+ if (E === x) continue;
391
+ let w = 0;
392
+ for (let B = 0; B < s; B++) {
393
+ const O = e[x * s + B] - e[E * s + B];
394
+ w += O * O;
313
395
  }
314
- const B = Math.pow(y, a), k = 2 * o * a / ((1e-3 + y) * (f * B + 1));
315
- for (let E = 0; E < n; E++) {
316
- const P = e[A * n + E] - e[S * n + E], F = z(k * P);
317
- e[A * n + E] += w * F;
396
+ const C = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (o * C + 1));
397
+ for (let B = 0; B < s; B++) {
398
+ const O = e[x * s + B] - e[E * s + B], S = I(A * O);
399
+ e[x * s + B] += b * S;
318
400
  }
319
401
  }
320
402
  }
321
403
  }
322
404
  return e;
323
405
  }
324
- function ee(e, t, r, s, n, g, c, d, f, a) {
325
- const { a: o, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = f, u = r.rows.length, p = new Uint32Array(r.rows), m = new Uint32Array(r.cols), _ = new Float32Array(u).fill(0), w = new Float32Array(u);
326
- for (let b = 0; b < u; b++)
327
- w[b] = s[b] / l;
328
- for (let b = 0; b < d; b++) {
329
- const A = 1 - b / d;
330
- for (let x = 0; x < u; x++) {
331
- if (_[x] > b) continue;
332
- const U = p[x], N = m[x];
333
- let R = 0;
334
- for (let y = 0; y < c; y++) {
335
- const B = e[U * c + y] - t[N * c + y];
336
- R += B * B;
406
+ function fe(e, t, f, a, s, u, c, d, o, n) {
407
+ const { a: r, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = o, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), b = new Float32Array(p);
408
+ for (let y = 0; y < p; y++)
409
+ b[y] = a[y] / l;
410
+ for (let y = 0; y < d; y++) {
411
+ const x = 1 - y / d;
412
+ for (let v = 0; v < p; v++) {
413
+ if (m[v] > y) continue;
414
+ const G = g[v], P = _[v];
415
+ let U = 0;
416
+ for (let w = 0; w < c; w++) {
417
+ const C = e[G * c + w] - t[P * c + w];
418
+ U += C * C;
337
419
  }
338
- const O = Math.pow(R, i), M = -2 * o * i * (R > 0 ? O / R : 0) / (o * O + 1);
339
- for (let y = 0; y < c; y++) {
340
- const B = e[U * c + y] - t[N * c + y];
341
- e[U * c + y] += A * z(M * B);
420
+ const F = Math.pow(U, i), N = -2 * r * i * (U > 0 ? F / U : 0) / (r * F + 1);
421
+ for (let w = 0; w < c; w++) {
422
+ const C = e[G * c + w] - t[P * c + w];
423
+ e[G * c + w] += x * I(N * C);
342
424
  }
343
- _[x] += s[x];
344
- const v = s[x] / l, S = Math.max(0, Math.floor(
345
- (b - w[x]) / v
425
+ m[v] += a[v];
426
+ const M = a[v] / l, E = Math.max(0, Math.floor(
427
+ (y - b[v]) / M
346
428
  ));
347
- w[x] += S * v;
348
- for (let y = 0; y < S; y++) {
349
- const B = Math.floor(Math.random() * g);
350
- if (B === N) continue;
351
- let k = 0;
352
- for (let F = 0; F < c; F++) {
353
- const q = e[U * c + F] - t[B * c + F];
354
- k += q * q;
429
+ b[v] += E * M;
430
+ for (let w = 0; w < E; w++) {
431
+ const C = Math.floor(Math.random() * u);
432
+ if (C === P) continue;
433
+ let A = 0;
434
+ for (let S = 0; S < c; S++) {
435
+ const k = e[G * c + S] - t[C * c + S];
436
+ A += k * k;
355
437
  }
356
- const E = Math.pow(k, i), P = 2 * h * i / ((1e-3 + k) * (o * E + 1));
357
- for (let F = 0; F < c; F++) {
358
- const q = e[U * c + F] - t[B * c + F];
359
- e[U * c + F] += A * z(P * q);
438
+ const B = Math.pow(A, i), O = 2 * h * i / ((1e-3 + A) * (r * B + 1));
439
+ for (let S = 0; S < c; S++) {
440
+ const k = e[G * c + S] - t[C * c + S];
441
+ e[G * c + S] += x * I(O * k);
360
442
  }
361
443
  }
362
444
  }
363
445
  }
364
446
  return e;
365
447
  }
366
- function K() {
367
- return typeof navigator < "u" && !!navigator.gpu;
368
- }
369
- async function re(e, t = {}, r) {
448
+ async function pe(e, t = {}, f) {
370
449
  const {
371
- nComponents: s = 2,
372
- nNeighbors: n = 15,
373
- minDist: g = 0.1,
450
+ nComponents: a = 2,
451
+ nNeighbors: s = 15,
452
+ minDist: u = 0.1,
374
453
  spread: c = 1,
375
454
  hnsw: d = {}
376
- } = t, f = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
455
+ } = t, o = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
377
456
  console.time("knn");
378
- const { indices: a, distances: o } = await Q(e, n, {
457
+ const { indices: n, distances: r } = await se(e, s, {
379
458
  M: d.M ?? 16,
380
459
  efConstruction: d.efConstruction ?? 200,
381
460
  efSearch: d.efSearch ?? 50
382
461
  });
383
462
  console.timeEnd("knn"), console.time("fuzzy-set");
384
- const i = L(a, o, n);
463
+ const i = V(n, r, s);
385
464
  console.timeEnd("fuzzy-set");
386
- const { a: h, b: l } = H(g, c), u = T(i.vals), p = e.length, m = new Float32Array(p * s);
387
- for (let w = 0; w < m.length; w++)
388
- m[w] = Math.random() * 20 - 10;
465
+ const { a: h, b: l } = Z(u, c), p = W(i.vals), g = e.length, _ = new Float32Array(g * a);
466
+ for (let b = 0; b < _.length; b++)
467
+ _[b] = Math.random() * 20 - 10;
389
468
  console.time("sgd");
390
- let _;
391
- if (K())
469
+ let m;
470
+ if (J())
392
471
  try {
393
- const w = new j();
394
- await w.init(), _ = await w.optimize(
395
- m,
472
+ const b = new X();
473
+ await b.init(), m = await b.optimize(
474
+ _,
396
475
  new Uint32Array(i.rows),
397
476
  new Uint32Array(i.cols),
398
- u,
399
477
  p,
400
- s,
401
- f,
478
+ g,
479
+ a,
480
+ o,
402
481
  { a: h, b: l, gamma: 1, negativeSampleRate: 5 },
403
- r
482
+ f
404
483
  );
405
- } catch (w) {
406
- console.warn("WebGPU SGD failed, falling back to CPU:", w), _ = I(m, i, u, p, s, f, { a: h, b: l }, r);
484
+ } catch (b) {
485
+ console.warn("WebGPU SGD failed, falling back to CPU:", b), m = j(_, i, p, g, a, o, { a: h, b: l }, f);
407
486
  }
408
487
  else
409
- _ = I(m, i, u, p, s, f, { a: h, b: l }, r);
410
- return console.timeEnd("sgd"), _;
488
+ m = j(_, i, p, g, a, o, { a: h, b: l }, f);
489
+ return console.timeEnd("sgd"), m;
411
490
  }
412
- function H(e, t) {
413
- if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
414
- return { a: 1.9292, b: 0.7915 };
415
- if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
416
- return { a: 1.8956, b: 0.8006 };
417
- if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
418
- return { a: 1.5769, b: 0.8951 };
419
- const r = te(e, t);
420
- return { a: ne(e, t, r), b: r };
491
+ function Z(e, t) {
492
+ return Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6 ? { a: 1.9292, b: 0.7915 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6 ? { a: 1.8956, b: 0.8006 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6 ? { a: 1.5769, b: 0.8951 } : de(e, t);
421
493
  }
422
- function te(e, t) {
423
- return 1 / (t * 1.2);
424
- }
425
- function ne(e, t, r) {
426
- return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * r);
494
+ function de(e, t) {
495
+ const a = [], s = [];
496
+ for (let o = 0; o < 299; o++) {
497
+ const n = (o + 1) / 299 * t * 3;
498
+ a.push(n), s.push(n < e ? 1 : Math.exp(-(n - e) / t));
499
+ }
500
+ let u = 1, c = 1, d = 1e-3;
501
+ for (let o = 0; o < 500; o++) {
502
+ let n = 0, r = 0, i = 0, h = 0, l = 0, p = 0;
503
+ for (let U = 0; U < 299; U++) {
504
+ const F = a[U], N = Math.pow(F, 2 * c), M = 1 + u * N, w = 1 / M - s[U];
505
+ p += w * w;
506
+ const C = M * M, A = -N / C, B = F > 0 ? -2 * Math.log(F) * u * N / C : 0;
507
+ n += A * w, r += B * w, i += A * A, h += B * B, l += A * B;
508
+ }
509
+ const g = i + d, _ = h + d, m = l, b = g * _ - m * m;
510
+ if (Math.abs(b) < 1e-20) break;
511
+ const y = -(_ * n - m * r) / b, x = -(g * r - m * n) / b, v = Math.max(1e-4, u + y), G = Math.max(1e-4, c + x);
512
+ let P = 0;
513
+ for (let U = 0; U < 299; U++) {
514
+ const F = Math.pow(a[U], 2 * G), N = 1 / (1 + v * F) - s[U];
515
+ P += N * N;
516
+ }
517
+ if (P < p ? (u = v, c = G, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(y) < 1e-8 && Math.abs(x) < 1e-8) break;
518
+ }
519
+ return { a: u, b: c };
427
520
  }
428
- class ie {
521
+ class ge {
429
522
  constructor(t = {}) {
430
- G(this, "_nComponents");
431
- G(this, "_nNeighbors");
432
- G(this, "_minDist");
433
- G(this, "_spread");
434
- G(this, "_nEpochs");
435
- G(this, "_hnswOpts");
436
- G(this, "_a");
437
- G(this, "_b");
523
+ R(this, "_nComponents");
524
+ R(this, "_nNeighbors");
525
+ R(this, "_minDist");
526
+ R(this, "_spread");
527
+ R(this, "_nEpochs");
528
+ R(this, "_hnswOpts");
529
+ R(this, "_a");
530
+ R(this, "_b");
438
531
  /** The low-dimensional embedding produced by the last fit() call. */
439
- G(this, "embedding", null);
440
- G(this, "_hnswIndex", null);
441
- G(this, "_nTrain", 0);
532
+ R(this, "embedding", null);
533
+ R(this, "_hnswIndex", null);
534
+ R(this, "_nTrain", 0);
442
535
  this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
443
- const { a: r, b: s } = H(this._minDist, this._spread);
444
- this._a = r, this._b = s;
536
+ const { a: f, b: a } = Z(this._minDist, this._spread);
537
+ this._a = f, this._b = a;
445
538
  }
446
539
  /**
447
540
  * Train UMAP on `vectors`.
@@ -449,45 +542,45 @@ class ie {
449
542
  * index so that transform() can project new points later.
450
543
  * Returns `this` for chaining.
451
544
  */
452
- async fit(t, r) {
453
- const s = t.length, n = this._nEpochs ?? (s > 1e4 ? 200 : 500), { M: g = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
545
+ async fit(t, f) {
546
+ const a = t.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: u = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
454
547
  console.time("knn");
455
- const { knn: f, index: a } = await J(t, this._nNeighbors, {
456
- M: g,
548
+ const { knn: o, index: n } = await ae(t, this._nNeighbors, {
549
+ M: u,
457
550
  efConstruction: c,
458
551
  efSearch: d
459
552
  });
460
- this._hnswIndex = a, this._nTrain = s, console.timeEnd("knn"), console.time("fuzzy-set");
461
- const o = L(f.indices, f.distances, this._nNeighbors);
553
+ this._hnswIndex = n, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
554
+ const r = V(o.indices, o.distances, this._nNeighbors);
462
555
  console.timeEnd("fuzzy-set");
463
- const i = T(o.vals), h = new Float32Array(s * this._nComponents);
556
+ const i = W(r.vals), h = new Float32Array(a * this._nComponents);
464
557
  for (let l = 0; l < h.length; l++)
465
558
  h[l] = Math.random() * 20 - 10;
466
- if (console.time("sgd"), K())
559
+ if (console.time("sgd"), J())
467
560
  try {
468
- const l = new j();
561
+ const l = new X();
469
562
  await l.init(), this.embedding = await l.optimize(
470
563
  h,
471
- new Uint32Array(o.rows),
472
- new Uint32Array(o.cols),
564
+ new Uint32Array(r.rows),
565
+ new Uint32Array(r.cols),
473
566
  i,
474
- s,
567
+ a,
475
568
  this._nComponents,
476
- n,
569
+ s,
477
570
  { a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 },
478
- r
571
+ f
479
572
  );
480
573
  } catch (l) {
481
- console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = I(h, o, i, s, this._nComponents, n, {
574
+ console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = j(h, r, i, a, this._nComponents, s, {
482
575
  a: this._a,
483
576
  b: this._b
484
- }, r);
577
+ }, f);
485
578
  }
486
579
  else
487
- this.embedding = I(h, o, i, s, this._nComponents, n, {
580
+ this.embedding = j(h, r, i, a, this._nComponents, s, {
488
581
  a: this._a,
489
582
  b: this._b
490
- }, r);
583
+ }, f);
491
584
  return console.timeEnd("sgd"), this;
492
585
  }
493
586
  /**
@@ -501,35 +594,35 @@ class ie {
501
594
  * returned embedding to [0, 1]. The stored training embedding is never
502
595
  * mutated. Defaults to `false`.
503
596
  */
504
- async transform(t, r = !1) {
597
+ async transform(t, f = !1) {
505
598
  if (!this._hnswIndex || !this.embedding)
506
599
  throw new Error("UMAP.transform() must be called after fit()");
507
- const s = t.length, n = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), g = Math.max(100, Math.floor(n / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = X(c.indices, c.distances, this._nNeighbors), f = new Uint32Array(d.rows), a = new Uint32Array(d.cols), o = new Float32Array(s), i = new Float32Array(s * this._nComponents);
508
- for (let u = 0; u < f.length; u++) {
509
- const p = f[u], m = a[u], _ = d.vals[u];
510
- o[p] += _;
511
- for (let w = 0; w < this._nComponents; w++)
512
- i[p * this._nComponents + w] += _ * this.embedding[m * this._nComponents + w];
600
+ const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), o = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), i = new Float32Array(a * this._nComponents);
601
+ for (let p = 0; p < o.length; p++) {
602
+ const g = o[p], _ = n[p], m = d.vals[p];
603
+ r[g] += m;
604
+ for (let b = 0; b < this._nComponents; b++)
605
+ i[g * this._nComponents + b] += m * this.embedding[_ * this._nComponents + b];
513
606
  }
514
- for (let u = 0; u < s; u++)
515
- if (o[u] > 0)
516
- for (let p = 0; p < this._nComponents; p++)
517
- i[u * this._nComponents + p] /= o[u];
607
+ for (let p = 0; p < a; p++)
608
+ if (r[p] > 0)
609
+ for (let g = 0; g < this._nComponents; g++)
610
+ i[p * this._nComponents + g] /= r[p];
518
611
  else
519
- for (let p = 0; p < this._nComponents; p++)
520
- i[u * this._nComponents + p] = Math.random() * 20 - 10;
521
- const h = T(d.vals), l = ee(
612
+ for (let g = 0; g < this._nComponents; g++)
613
+ i[p * this._nComponents + g] = Math.random() * 20 - 10;
614
+ const h = W(d.vals), l = fe(
522
615
  i,
523
616
  this.embedding,
524
617
  d,
525
618
  h,
526
- s,
619
+ a,
527
620
  this._nTrain,
528
621
  this._nComponents,
529
- g,
622
+ u,
530
623
  { a: this._a, b: this._b }
531
624
  );
532
- return r ? C(l, s, this._nComponents) : l;
625
+ return f ? K(l, a, this._nComponents) : l;
533
626
  }
534
627
  /**
535
628
  * Convenience method equivalent to `fit(vectors)` followed by
@@ -540,37 +633,38 @@ class ie {
540
633
  * returned embedding to [0, 1]. `this.embedding` is never mutated.
541
634
  * Defaults to `false`.
542
635
  */
543
- async fit_transform(t, r, s = !1) {
544
- return await this.fit(t, r), s ? C(this.embedding, t.length, this._nComponents) : this.embedding;
636
+ async fit_transform(t, f, a = !1) {
637
+ return await this.fit(t, f), a ? K(this.embedding, t.length, this._nComponents) : this.embedding;
545
638
  }
546
639
  }
547
- function C(e, t, r) {
548
- const s = new Float32Array(e.length);
549
- for (let n = 0; n < r; n++) {
550
- let g = 1 / 0, c = -1 / 0;
551
- for (let f = 0; f < t; f++) {
552
- const a = e[f * r + n];
553
- a < g && (g = a), a > c && (c = a);
640
+ function K(e, t, f) {
641
+ const a = new Float32Array(e.length);
642
+ for (let s = 0; s < f; s++) {
643
+ let u = 1 / 0, c = -1 / 0;
644
+ for (let o = 0; o < t; o++) {
645
+ const n = e[o * f + s];
646
+ n < u && (u = n), n > c && (c = n);
554
647
  }
555
- const d = c - g;
556
- for (let f = 0; f < t; f++)
557
- s[f * r + n] = d > 0 ? (e[f * r + n] - g) / d : 0;
648
+ const d = c - u;
649
+ for (let o = 0; o < t; o++)
650
+ a[o * f + s] = d > 0 ? (e[o * f + s] - u) / d : 0;
558
651
  }
559
- return s;
652
+ return a;
560
653
  }
561
- function T(e, t) {
562
- let r = -1 / 0;
563
- for (let n = 0; n < e.length; n++)
564
- e[n] > r && (r = e[n]);
565
- const s = new Float32Array(e.length);
566
- for (let n = 0; n < e.length; n++) {
567
- const g = e[n] / r;
568
- s[n] = g > 0 ? 1 / g : -1;
654
+ function W(e, t) {
655
+ let f = -1 / 0;
656
+ for (let s = 0; s < e.length; s++)
657
+ e[s] > f && (f = e[s]);
658
+ const a = new Float32Array(e.length);
659
+ for (let s = 0; s < e.length; s++) {
660
+ const u = e[s] / f;
661
+ a[s] = u > 0 ? 1 / u : -1;
569
662
  }
570
- return s;
663
+ return a;
571
664
  }
572
665
  export {
573
- ie as UMAP,
574
- re as fit,
575
- K as isWebGPUAvailable
666
+ ge as UMAP,
667
+ he as checkWebGPUAvailable,
668
+ pe as fit,
669
+ J as isWebGPUAvailable
576
670
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.2.13",
3
+ "version": "0.2.15",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",
@@ -55,7 +55,8 @@
55
55
  "vite": "^5.0.0",
56
56
  "vitepress": "^1.6.4",
57
57
  "vitepress-plugin-llms": "^1.11.0",
58
- "vitest": "^4.0.18"
58
+ "vitest": "^4.0.18",
59
+ "webgpu": "^0.3.8"
59
60
  },
60
61
  "license": "MIT"
61
62
  }