umap-gpu 0.2.13 → 0.2.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  export interface FuzzyGraph {
2
- rows: Float32Array;
3
- cols: Float32Array;
2
+ rows: Uint32Array;
3
+ cols: Uint32Array;
4
4
  vals: Float32Array;
5
5
  nVertices: number;
6
6
  }
@@ -7,6 +7,23 @@
7
7
  */
8
8
  export declare function getGPUDevice(): Promise<GPUDevice | null>;
9
9
  /**
10
- * Check whether WebGPU is available in the current environment.
10
+ * Fast synchronous heuristic: returns `true` if `navigator.gpu` exists.
11
+ *
12
+ * **Caveat (Bug 13):** `navigator.gpu` being truthy does NOT guarantee that a
13
+ * WebGPU adapter can be acquired — `requestAdapter()` may still return `null`
14
+ * (no compatible hardware, or the browser has disabled WebGPU for the page).
15
+ * Use `checkWebGPUAvailable()` for a reliable async check, or rely on the
16
+ * `try/catch` around `GPUSgd.init()` in the calling code.
11
17
  */
12
18
  export declare function isWebGPUAvailable(): boolean;
19
+ /**
20
+ * Reliably check whether WebGPU is usable in the current environment by
21
+ * attempting to acquire an adapter via `getGPUDevice()`.
22
+ *
23
+ * Unlike the synchronous `isWebGPUAvailable()`, this actually calls
24
+ * `navigator.gpu.requestAdapter()` and returns `false` if the adapter is
25
+ * unavailable (no compatible GPU, browser policy, etc.).
26
+ *
27
+ * The result is automatically cached — repeated calls are free.
28
+ */
29
+ export declare function checkWebGPUAvailable(): Promise<boolean>;
package/dist/gpu/sgd.d.ts CHANGED
@@ -6,11 +6,23 @@ export interface SGDParams {
6
6
  }
7
7
  /**
8
8
  * GPU-accelerated SGD optimizer for UMAP embedding.
9
- * Each GPU thread processes one graph edge, applying attraction and repulsion forces.
9
+ *
10
+ * Uses a two-pass design per epoch to eliminate write-write races on shared
11
+ * vertex positions (Bug 2 fix):
12
+ * Pass 1 (sgd.wgsl): Each thread accumulates its attraction and
13
+ * repulsion gradients into an atomic<i32>
14
+ * forces buffer — no direct embedding writes.
15
+ * Pass 2 (apply-forces.wgsl): Each thread applies one element's accumulated
16
+ * force to the embedding and resets the
17
+ * accumulator to zero for the next epoch.
18
+ *
19
+ * Both passes are submitted in the same command encoder so WebGPU guarantees
20
+ * sequential execution and storage-buffer visibility between them.
10
21
  */
11
22
  export declare class GPUSgd {
12
23
  private device;
13
- private pipeline;
24
+ private sgdPipeline;
25
+ private applyForcesPipeline;
14
26
  init(): Promise<void>;
15
27
  /**
16
28
  * Run SGD optimization on the GPU.
package/dist/index.d.ts CHANGED
@@ -2,4 +2,4 @@ export { fit, UMAP } from './umap';
2
2
  export type { UMAPOptions, ProgressCallback } from './umap';
3
3
  export type { KNNResult, HNSWOptions, HNSWSearchableIndex } from './hnsw-knn';
4
4
  export type { FuzzyGraph } from './fuzzy-set';
5
- export { isWebGPUAvailable } from './gpu/device';
5
+ export { isWebGPUAvailable, checkWebGPUAvailable } from './gpu/device';
package/dist/index.js CHANGED
@@ -1,104 +1,111 @@
1
- var V = Object.defineProperty;
2
- var Y = (e, t, r) => t in e ? V(e, t, { enumerable: !0, configurable: !0, writable: !0, value: r }) : e[t] = r;
3
- var G = (e, t, r) => Y(e, typeof t != "symbol" ? t + "" : t, r);
4
- import { loadHnswlib as D } from "hnswlib-wasm";
5
- async function Q(e, t, r = {}) {
6
- const { M: s = 16, efConstruction: n = 200, efSearch: g = 50 } = r, c = await D(), d = e[0].length, f = e.length, a = new c.HierarchicalNSW("l2", d, "");
7
- a.initIndex(f, s, n, 200), a.setEfSearch(Math.max(g, t)), a.addItems(e, !1);
8
- const o = [], i = [];
9
- for (let h = 0; h < f; h++) {
10
- const l = a.searchKnn(e[h], t + 1, void 0), u = l.neighbors.map((p, m) => ({ idx: p, dist: l.distances[m] })).filter(({ idx: p }) => p !== h).slice(0, t);
11
- o.push(u.map(({ idx: p }) => p)), i.push(u.map(({ dist: p }) => p));
1
+ var te = Object.defineProperty;
2
+ var ne = (e, t, f) => t in e ? te(e, t, { enumerable: !0, configurable: !0, writable: !0, value: f }) : e[t] = f;
3
+ var C = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
4
+ import { loadHnswlib as H } from "hnswlib-wasm";
5
+ async function se(e, t, f = {}) {
6
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, i = e.length, n = new c.HierarchicalNSW("l2", d, "");
7
+ n.initIndex(i, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
8
+ const r = [], o = [];
9
+ for (let h = 0; h < i; h++) {
10
+ const l = n.searchKnn(e[h], t + 1, void 0), p = l.neighbors.map((g, _) => ({ idx: g, dist: l.distances[_] })).filter(({ idx: g }) => g !== h).slice(0, t);
11
+ r.push(p.map(({ idx: g }) => g)), o.push(p.map(({ dist: g }) => Math.sqrt(g)));
12
12
  }
13
- return { indices: o, distances: i };
13
+ return { indices: r, distances: o };
14
14
  }
15
- async function J(e, t, r = {}) {
16
- const { M: s = 16, efConstruction: n = 200, efSearch: g = 50 } = r, c = await D(), d = e[0].length, f = e.length, a = new c.HierarchicalNSW("l2", d, "");
17
- a.initIndex(f, s, n, 200), a.setEfSearch(Math.max(g, t)), a.addItems(e, !1);
18
- const o = [], i = [];
19
- for (let l = 0; l < f; l++) {
20
- const u = a.searchKnn(e[l], t + 1, void 0), p = u.neighbors.map((m, _) => ({ idx: m, dist: u.distances[_] })).filter(({ idx: m }) => m !== l).slice(0, t);
21
- o.push(p.map(({ idx: m }) => m)), i.push(p.map(({ dist: m }) => m));
15
+ async function ae(e, t, f = {}) {
16
+ const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, i = e.length, n = new c.HierarchicalNSW("l2", d, "");
17
+ n.initIndex(i, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
18
+ const r = [], o = [];
19
+ for (let l = 0; l < i; l++) {
20
+ const p = n.searchKnn(e[l], t + 1, void 0), g = p.neighbors.map((_, m) => ({ idx: _, dist: p.distances[m] })).filter(({ idx: _ }) => _ !== l).slice(0, t);
21
+ r.push(g.map(({ idx: _ }) => _)), o.push(g.map(({ dist: _ }) => Math.sqrt(_)));
22
22
  }
23
- return { knn: { indices: o, distances: i }, index: {
24
- searchKnn(l, u) {
25
- const p = [], m = [];
26
- for (const _ of l) {
27
- const w = a.searchKnn(_, u, void 0), b = w.neighbors.map((A, x) => ({ idx: A, dist: w.distances[x] })).sort((A, x) => A.dist - x.dist).slice(0, u);
28
- p.push(b.map(({ idx: A }) => A)), m.push(b.map(({ dist: A }) => A));
23
+ return { knn: { indices: r, distances: o }, index: {
24
+ searchKnn(l, p) {
25
+ const g = [], _ = [];
26
+ for (const m of l) {
27
+ const y = n.searchKnn(m, p, void 0), b = y.neighbors.map((x, v) => ({ idx: x, dist: y.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
28
+ g.push(b.map(({ idx: x }) => x)), _.push(b.map(({ dist: x }) => Math.sqrt(x)));
29
29
  }
30
- return { indices: p, distances: m };
30
+ return { indices: g, distances: _ };
31
31
  }
32
32
  } };
33
33
  }
34
- function L(e, t, r, s = 1) {
35
- const n = e.length, { sigmas: g, rhos: c } = W(t, r), d = [], f = [], a = [];
36
- for (let i = 0; i < n; i++)
37
- for (let h = 0; h < e[i].length; h++) {
38
- const l = t[i][h], u = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / g[i]));
39
- d.push(i), f.push(e[i][h]), a.push(u);
34
+ function V(e, t, f, a = 1) {
35
+ const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], i = [], n = [];
36
+ for (let o = 0; o < s; o++)
37
+ for (let h = 0; h < e[o].length; h++) {
38
+ const l = t[o][h], p = l <= c[o] ? 1 : Math.exp(-((l - c[o]) / u[o]));
39
+ d.push(o), i.push(e[o][h]), n.push(p);
40
40
  }
41
- return { ...Z(d, f, a, n, s), nVertices: n };
41
+ return { ...oe(d, i, n, s, a), nVertices: s };
42
42
  }
43
- function X(e, t, r) {
44
- const s = e.length, { sigmas: n, rhos: g } = W(t, r), c = [], d = [], f = [];
45
- for (let a = 0; a < s; a++)
46
- for (let o = 0; o < e[a].length; o++) {
47
- const i = t[a][o], h = i <= g[a] ? 1 : Math.exp(-((i - g[a]) / n[a]));
48
- c.push(a), d.push(e[a][o]), f.push(h);
43
+ function re(e, t, f) {
44
+ const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], i = [];
45
+ for (let n = 0; n < a; n++)
46
+ for (let r = 0; r < e[n].length; r++) {
47
+ const o = t[n][r], h = o <= u[n] ? 1 : Math.exp(-((o - u[n]) / s[n]));
48
+ c.push(n), d.push(e[n][r]), i.push(h);
49
49
  }
50
50
  return {
51
- rows: new Float32Array(c),
52
- cols: new Float32Array(d),
53
- vals: new Float32Array(f),
54
- nVertices: s
51
+ rows: new Uint32Array(c),
52
+ cols: new Uint32Array(d),
53
+ vals: new Float32Array(i),
54
+ nVertices: a
55
55
  };
56
56
  }
57
- function W(e, t) {
58
- const s = e.length, n = new Float32Array(s), g = new Float32Array(s);
59
- for (let c = 0; c < s; c++) {
57
+ function Y(e, t) {
58
+ const a = e.length, s = new Float32Array(a), u = new Float32Array(a);
59
+ for (let c = 0; c < a; c++) {
60
60
  const d = e[c];
61
- g[c] = d.find((h) => h > 0) ?? 0;
62
- let f = 0, a = 1 / 0, o = 1;
63
- const i = Math.log2(t);
61
+ u[c] = d.find((h) => h > 0) ?? 0;
62
+ let i = 0, n = 1 / 0, r = 1;
63
+ const o = Math.log2(t);
64
64
  for (let h = 0; h < 64; h++) {
65
65
  let l = 0;
66
- for (let u = 0; u < d.length; u++)
67
- l += Math.exp(-Math.max(0, d[u] - g[c]) / o);
68
- if (Math.abs(l - i) < 1e-5) break;
69
- l > i ? (a = o, o = (f + a) / 2) : (f = o, o = a === 1 / 0 ? o * 2 : (f + a) / 2);
66
+ for (let p = 0; p < d.length; p++)
67
+ l += Math.exp(-Math.max(0, d[p] - u[c]) / r);
68
+ if (Math.abs(l - o) < 1e-5) break;
69
+ l > o ? (n = r, r = (i + n) / 2) : (i = r, r = n === 1 / 0 ? r * 2 : (i + n) / 2);
70
70
  }
71
- n[c] = o;
71
+ s[c] = r;
72
72
  }
73
- return { sigmas: n, rhos: g };
73
+ return { sigmas: s, rhos: u };
74
74
  }
75
- function Z(e, t, r, s, n) {
76
- const g = /* @__PURE__ */ new Map();
77
- for (let a = 0; a < e.length; a++)
78
- g.set(e[a] * s + t[a], r[a]);
79
- const c = [], d = [], f = [];
80
- for (const [a, o] of g) {
81
- const i = Math.floor(a / s), h = a % s, l = g.get(h * s + i) ?? 0, u = o + l - o * l, p = o * l;
82
- c.push(i), d.push(h), f.push(n * u + (1 - n) * p);
75
+ function oe(e, t, f, a, s) {
76
+ const u = /* @__PURE__ */ new Map();
77
+ for (let n = 0; n < e.length; n++)
78
+ u.set(e[n] * a + t[n], f[n]);
79
+ const c = [], d = [], i = [];
80
+ for (const [n, r] of u) {
81
+ const o = Math.floor(n / a), h = n % a, l = u.get(h * a + o) ?? 0, p = r + l - r * l, g = r * l;
82
+ c.push(o), d.push(h), i.push(s * p + (1 - s) * g);
83
83
  }
84
84
  return {
85
- rows: new Float32Array(c),
86
- cols: new Float32Array(d),
87
- vals: new Float32Array(f)
85
+ rows: new Uint32Array(c),
86
+ cols: new Uint32Array(d),
87
+ vals: new Float32Array(i)
88
88
  };
89
89
  }
90
- const $ = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
- // Applies attraction forces between connected nodes and repulsion forces
92
- // against negative samples.
90
+ const ie = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
91
+ // Computes attraction and repulsion forces and accumulates them atomically
92
+ // into a forces buffer. A separate apply-forces pass then updates embeddings,
93
+ // eliminating write-write races on shared vertex positions.
93
94
 
94
95
  @group(0) @binding(0) var<storage, read> epochs_per_sample : array<f32>;
95
96
  @group(0) @binding(1) var<storage, read> head : array<u32>; // edge source
96
97
  @group(0) @binding(2) var<storage, read> tail : array<u32>; // edge target
97
- @group(0) @binding(3) var<storage, read_write> embedding : array<f32>; // [n * nComponents]
98
+ @group(0) @binding(3) var<storage, read> embedding : array<f32>; // [n * nComponents], read-only
98
99
  @group(0) @binding(4) var<storage, read_write> epoch_of_next_sample : array<f32>;
99
100
  @group(0) @binding(5) var<storage, read_write> epoch_of_next_negative_sample : array<f32>;
100
101
  @group(0) @binding(6) var<uniform> params : Params;
101
102
  @group(0) @binding(7) var<storage, read> rng_seeds : array<u32>; // per-edge seed
103
+ @group(0) @binding(8) var<storage, read_write> forces : array<atomic<i32>>; // [n * nComponents]
104
+
105
+ // Scale factor for quantizing f32 gradients into i32 for atomic accumulation.
106
+ // Gradients are clipped to [-4, 4]. With up to ~1000 edges sharing a vertex
107
+ // the max accumulated magnitude is ~4000, well within i32 range at this scale.
108
+ const FORCE_SCALE : f32 = 65536.0;
102
109
 
103
110
  struct Params {
104
111
  n_edges : u32,
@@ -106,7 +113,7 @@ struct Params {
106
113
  n_components : u32,
107
114
  current_epoch : u32,
108
115
  n_epochs : u32,
109
- alpha : f32, // learning rate
116
+ alpha : f32, // learning rate (applied by apply-forces pass)
110
117
  a : f32,
111
118
  b : f32,
112
119
  gamma : f32, // repulsion strength
@@ -147,7 +154,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
147
154
 
148
155
  let pow_b = pow(dist_sq, params.b);
149
156
  // Guard dist_sq == 0: b-1 is negative so pow(0, b-1) = +Inf.
150
- // Mirror CPU: use pow_b / dist_sq only when dist_sq > 0, else 0.
151
157
  let grad_coeff_attr = select(
152
158
  -2.0 * params.a * params.b * (pow_b / dist_sq) / (params.a * pow_b + 1.0),
153
159
  0.0,
@@ -157,19 +163,24 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
157
163
  for (var d = 0u; d < nc; d++) {
158
164
  let diff = embedding[i * nc + d] - embedding[j * nc + d];
159
165
  let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
160
- embedding[i * nc + d] += params.alpha * grad;
161
- embedding[j * nc + d] -= params.alpha * grad;
166
+ // Accumulate atomically to avoid write-write races across threads.
167
+ atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
168
+ atomicAdd(&forces[j * nc + d], -i32(grad * FORCE_SCALE));
162
169
  }
163
170
 
164
171
  epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
165
172
 
166
173
  // --- Repulsion (negative samples) ---
167
- let eps = epochs_per_sample[edge_idx];
168
- let neg_eps = epoch_of_next_negative_sample[edge_idx];
174
+ // Compute how many negative samples are overdue relative to current epoch,
175
+ // matching the Python reference: n_neg = floor((n - next_neg) / eps_per_neg).
176
+ let epoch_f = f32(params.current_epoch);
177
+ let epochs_per_neg = epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
169
178
  var n_neg = 0u;
170
- if (neg_eps > 0.0) {
171
- n_neg = u32(eps / neg_eps);
179
+ if (epochs_per_neg > 0.0 && epoch_f >= epoch_of_next_negative_sample[edge_idx]) {
180
+ n_neg = u32((epoch_f - epoch_of_next_negative_sample[edge_idx]) / epochs_per_neg);
181
+ epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * epochs_per_neg;
172
182
  }
183
+
173
184
  var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 6364136223u);
174
185
 
175
186
  for (var s = 0u; s < n_neg; s++) {
@@ -189,26 +200,73 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
189
200
  for (var d = 0u; d < nc; d++) {
190
201
  let diff = embedding[i * nc + d] - embedding[k * nc + d];
191
202
  let grad = clip(grad_coeff_rep * diff, -4.0, 4.0);
192
- embedding[i * nc + d] += params.alpha * grad;
203
+ atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
193
204
  }
194
205
  }
206
+ }
207
+ `, ce = `// Apply-forces shader — second pass of the two-pass GPU SGD.
208
+ //
209
+ // After the SGD pass has atomically accumulated all gradients into the forces
210
+ // buffer, this shader applies each element's accumulated force to the
211
+ // embedding and resets the accumulator to zero for the next epoch.
195
212
 
196
- epoch_of_next_negative_sample[edge_idx] +=
197
- epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
213
+ @group(0) @binding(0) var<storage, read_write> embedding : array<f32>;
214
+ @group(0) @binding(1) var<storage, read_write> forces : array<atomic<i32>>;
215
+ @group(0) @binding(2) var<uniform> params : ApplyParams;
216
+
217
+ struct ApplyParams {
218
+ n_elements : u32, // nVertices * nComponents
219
+ alpha : f32, // current learning rate
220
+ }
221
+
222
+ // Must match FORCE_SCALE in sgd.wgsl
223
+ const FORCE_SCALE : f32 = 65536.0;
224
+
225
+ @compute @workgroup_size(256)
226
+ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
227
+ let idx = gid.x;
228
+ if (idx >= params.n_elements) { return; }
229
+
230
+ // atomicExchange atomically reads the accumulated force and resets it to 0.
231
+ let raw = atomicExchange(&forces[idx], 0);
232
+ embedding[idx] += params.alpha * f32(raw) / FORCE_SCALE;
198
233
  }
199
234
  `;
200
- class j {
235
+ let z = null;
236
+ async function Q() {
237
+ if (z) return z;
238
+ if (typeof navigator > "u" || !navigator.gpu)
239
+ return null;
240
+ const e = await navigator.gpu.requestAdapter();
241
+ return e ? (z = await e.requestDevice(), z.lost.then(() => {
242
+ z = null;
243
+ }), z) : null;
244
+ }
245
+ function J() {
246
+ return typeof navigator < "u" && !!navigator.gpu;
247
+ }
248
+ async function he() {
249
+ return await Q() !== null;
250
+ }
251
+ class X {
201
252
  constructor() {
202
- G(this, "device");
203
- G(this, "pipeline");
253
+ C(this, "device");
254
+ C(this, "sgdPipeline");
255
+ C(this, "applyForcesPipeline");
204
256
  }
205
257
  async init() {
206
- const t = await navigator.gpu.requestAdapter();
258
+ const t = await Q();
207
259
  if (!t) throw new Error("WebGPU not supported");
208
- this.device = await t.requestDevice(), this.pipeline = this.device.createComputePipeline({
260
+ this.device = t, this.sgdPipeline = this.device.createComputePipeline({
261
+ layout: "auto",
262
+ compute: {
263
+ module: this.device.createShaderModule({ code: ie }),
264
+ entryPoint: "main"
265
+ }
266
+ }), this.applyForcesPipeline = this.device.createComputePipeline({
209
267
  layout: "auto",
210
268
  compute: {
211
- module: this.device.createShaderModule({ code: $ }),
269
+ module: this.device.createShaderModule({ code: ce }),
212
270
  entryPoint: "main"
213
271
  }
214
272
  });
@@ -226,222 +284,254 @@ class j {
226
284
  * @param params - UMAP curve parameters and repulsion settings
227
285
  * @returns Optimized embedding as Float32Array
228
286
  */
229
- async optimize(t, r, s, n, g, c, d, f, a) {
230
- const { device: o } = this, i = r.length, h = this.makeBuffer(
287
+ async optimize(t, f, a, s, u, c, d, i, n) {
288
+ const { device: r } = this, o = f.length, h = u * c, l = this.makeBuffer(
231
289
  t,
232
290
  GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
233
- ), l = this.makeBuffer(r, GPUBufferUsage.STORAGE), u = this.makeBuffer(s, GPUBufferUsage.STORAGE), p = this.makeBuffer(n, GPUBufferUsage.STORAGE), m = new Float32Array(i).fill(0), _ = this.makeBuffer(m, GPUBufferUsage.STORAGE), w = new Float32Array(i);
234
- for (let M = 0; M < i; M++)
235
- w[M] = n[M] / f.negativeSampleRate;
236
- const b = this.makeBuffer(w, GPUBufferUsage.STORAGE), A = new Uint32Array(i);
237
- for (let M = 0; M < i; M++)
238
- A[M] = Math.random() * 4294967295 | 0;
239
- const x = this.makeBuffer(A, GPUBufferUsage.STORAGE), U = o.createBuffer({
291
+ ), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), y = this.makeBuffer(m, GPUBufferUsage.STORAGE), b = new Float32Array(o);
292
+ for (let A = 0; A < o; A++)
293
+ b[A] = s[A] / i.negativeSampleRate;
294
+ const x = this.makeBuffer(b, GPUBufferUsage.STORAGE), v = new Uint32Array(o);
295
+ for (let A = 0; A < o; A++)
296
+ v[A] = Math.random() * 4294967295 | 0;
297
+ const P = this.makeBuffer(v, GPUBufferUsage.STORAGE), G = r.createBuffer({
298
+ size: h * 4,
299
+ usage: GPUBufferUsage.STORAGE,
300
+ mappedAtCreation: !0
301
+ });
302
+ new Int32Array(G.getMappedRange()).fill(0), G.unmap();
303
+ const U = r.createBuffer({
240
304
  size: 40,
241
305
  usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
306
+ }), F = r.createBuffer({
307
+ size: 16,
308
+ usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
309
+ }), N = r.createBindGroup({
310
+ layout: this.sgdPipeline.getBindGroupLayout(0),
311
+ entries: [
312
+ { binding: 0, resource: { buffer: _ } },
313
+ { binding: 1, resource: { buffer: p } },
314
+ { binding: 2, resource: { buffer: g } },
315
+ { binding: 3, resource: { buffer: l } },
316
+ { binding: 4, resource: { buffer: y } },
317
+ { binding: 5, resource: { buffer: x } },
318
+ { binding: 6, resource: { buffer: U } },
319
+ { binding: 7, resource: { buffer: P } },
320
+ { binding: 8, resource: { buffer: G } }
321
+ ]
322
+ }), M = r.createBindGroup({
323
+ layout: this.applyForcesPipeline.getBindGroupLayout(0),
324
+ entries: [
325
+ { binding: 0, resource: { buffer: l } },
326
+ { binding: 1, resource: { buffer: G } },
327
+ { binding: 2, resource: { buffer: F } }
328
+ ]
242
329
  });
243
- for (let M = 0; M < d; M++) {
244
- const v = 1 - M / d, S = new ArrayBuffer(40), y = new Uint32Array(S), B = new Float32Array(S);
245
- y[0] = i, y[1] = g, y[2] = c, y[3] = M, y[4] = d, B[5] = v, B[6] = f.a, B[7] = f.b, B[8] = f.gamma, y[9] = f.negativeSampleRate, o.queue.writeBuffer(U, 0, S);
246
- const k = o.createBindGroup({
247
- layout: this.pipeline.getBindGroupLayout(0),
248
- entries: [
249
- { binding: 0, resource: { buffer: p } },
250
- { binding: 1, resource: { buffer: l } },
251
- { binding: 2, resource: { buffer: u } },
252
- { binding: 3, resource: { buffer: h } },
253
- { binding: 4, resource: { buffer: _ } },
254
- { binding: 5, resource: { buffer: b } },
255
- { binding: 6, resource: { buffer: U } },
256
- { binding: 7, resource: { buffer: x } }
257
- ]
258
- }), E = o.createCommandEncoder(), P = E.beginComputePass();
259
- P.setPipeline(this.pipeline), P.setBindGroup(0, k), P.dispatchWorkgroups(Math.ceil(i / 256)), P.end(), o.queue.submit([E.finish()]), M % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(M, d));
330
+ for (let A = 0; A < d; A++) {
331
+ const B = 1 - A / d, O = new ArrayBuffer(40), S = new Uint32Array(O), k = new Float32Array(O);
332
+ S[0] = o, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = i.a, k[7] = i.b, k[8] = i.gamma, S[9] = i.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
333
+ const D = new ArrayBuffer(16), $ = new Uint32Array(D), ee = new Float32Array(D);
334
+ $[0] = h, ee[1] = B, r.queue.writeBuffer(F, 0, D);
335
+ const T = r.createCommandEncoder(), q = T.beginComputePass();
336
+ q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(o / 256)), q.end();
337
+ const I = T.beginComputePass();
338
+ I.setPipeline(this.applyForcesPipeline), I.setBindGroup(0, M), I.dispatchWorkgroups(Math.ceil(h / 256)), I.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
260
339
  }
261
- const N = o.createBuffer({
340
+ const E = r.createBuffer({
262
341
  size: t.byteLength,
263
342
  usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
264
- }), R = o.createCommandEncoder();
265
- R.copyBufferToBuffer(h, 0, N, 0, t.byteLength), o.queue.submit([R.finish()]), await N.mapAsync(GPUMapMode.READ);
266
- const O = new Float32Array(N.getMappedRange().slice(0));
267
- return N.unmap(), h.destroy(), l.destroy(), u.destroy(), p.destroy(), _.destroy(), b.destroy(), x.destroy(), U.destroy(), N.destroy(), O;
343
+ }), w = r.createCommandEncoder();
344
+ w.copyBufferToBuffer(l, 0, E, 0, t.byteLength), r.queue.submit([w.finish()]), await E.mapAsync(GPUMapMode.READ);
345
+ const R = new Float32Array(E.getMappedRange().slice(0));
346
+ return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), y.destroy(), x.destroy(), P.destroy(), G.destroy(), U.destroy(), F.destroy(), E.destroy(), R;
268
347
  }
269
- makeBuffer(t, r) {
270
- const s = this.device.createBuffer({
348
+ makeBuffer(t, f) {
349
+ const a = this.device.createBuffer({
271
350
  size: t.byteLength,
272
- usage: r,
351
+ usage: f,
273
352
  mappedAtCreation: !0
274
353
  });
275
- return t instanceof Float32Array ? new Float32Array(s.getMappedRange()).set(t) : new Uint32Array(s.getMappedRange()).set(t), s.unmap(), s;
354
+ return t instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(t) : new Uint32Array(a.getMappedRange()).set(t), a.unmap(), a;
276
355
  }
277
356
  }
278
- function z(e) {
357
+ function j(e) {
279
358
  return Math.max(-4, Math.min(4, e));
280
359
  }
281
- function I(e, t, r, s, n, g, c, d) {
282
- const { a: f, b: a, gamma: o = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), u = new Uint32Array(t.cols), p = new Float32Array(h).fill(0), m = new Float32Array(h);
283
- for (let _ = 0; _ < h; _++)
284
- m[_] = r[_] / i;
285
- for (let _ = 0; _ < g; _++) {
286
- d == null || d(_, g);
287
- const w = 1 - _ / g;
360
+ function L(e, t, f, a, s, u, c, d) {
361
+ const { a: i, b: n, gamma: r = 1, negativeSampleRate: o = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
362
+ for (let m = 0; m < h; m++)
363
+ _[m] = f[m] / o;
364
+ for (let m = 0; m < u; m++) {
365
+ d == null || d(m, u);
366
+ const y = 1 - m / u;
288
367
  for (let b = 0; b < h; b++) {
289
- if (p[b] > _) continue;
290
- const A = l[b], x = u[b];
291
- let U = 0;
292
- for (let v = 0; v < n; v++) {
293
- const S = e[A * n + v] - e[x * n + v];
294
- U += S * S;
368
+ if (g[b] > m) continue;
369
+ const x = l[b], v = p[b];
370
+ let P = 0;
371
+ for (let M = 0; M < s; M++) {
372
+ const E = e[x * s + M] - e[v * s + M];
373
+ P += E * E;
295
374
  }
296
- const N = Math.pow(U, a), R = -2 * f * a * (U > 0 ? N / U : 0) / (f * N + 1);
297
- for (let v = 0; v < n; v++) {
298
- const S = e[A * n + v] - e[x * n + v], y = z(R * S);
299
- e[A * n + v] += w * y, e[x * n + v] -= w * y;
375
+ const G = Math.pow(P, n), U = -2 * i * n * (P > 0 ? G / P : 0) / (i * G + 1);
376
+ for (let M = 0; M < s; M++) {
377
+ const E = e[x * s + M] - e[v * s + M], w = j(U * E);
378
+ e[x * s + M] += y * w, e[v * s + M] -= y * w;
300
379
  }
301
- p[b] += r[b];
302
- const O = r[b] / i, M = Math.max(0, Math.floor(
303
- (_ - m[b]) / O
380
+ g[b] += f[b];
381
+ const F = f[b] / o, N = Math.max(0, Math.floor(
382
+ (m - _[b]) / F
304
383
  ));
305
- m[b] += M * O;
306
- for (let v = 0; v < M; v++) {
307
- const S = Math.floor(Math.random() * s);
308
- if (S === A) continue;
309
- let y = 0;
310
- for (let E = 0; E < n; E++) {
311
- const P = e[A * n + E] - e[S * n + E];
312
- y += P * P;
384
+ _[b] += N * F;
385
+ for (let M = 0; M < N; M++) {
386
+ const E = Math.floor(Math.random() * a);
387
+ if (E === x) continue;
388
+ let w = 0;
389
+ for (let B = 0; B < s; B++) {
390
+ const O = e[x * s + B] - e[E * s + B];
391
+ w += O * O;
313
392
  }
314
- const B = Math.pow(y, a), k = 2 * o * a / ((1e-3 + y) * (f * B + 1));
315
- for (let E = 0; E < n; E++) {
316
- const P = e[A * n + E] - e[S * n + E], F = z(k * P);
317
- e[A * n + E] += w * F;
393
+ const R = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (i * R + 1));
394
+ for (let B = 0; B < s; B++) {
395
+ const O = e[x * s + B] - e[E * s + B], S = j(A * O);
396
+ e[x * s + B] += y * S;
318
397
  }
319
398
  }
320
399
  }
321
400
  }
322
401
  return e;
323
402
  }
324
- function ee(e, t, r, s, n, g, c, d, f, a) {
325
- const { a: o, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = f, u = r.rows.length, p = new Uint32Array(r.rows), m = new Uint32Array(r.cols), _ = new Float32Array(u).fill(0), w = new Float32Array(u);
326
- for (let b = 0; b < u; b++)
327
- w[b] = s[b] / l;
403
+ function fe(e, t, f, a, s, u, c, d, i, n) {
404
+ const { a: r, b: o, gamma: h = 1, negativeSampleRate: l = 5 } = i, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), y = new Float32Array(p);
405
+ for (let b = 0; b < p; b++)
406
+ y[b] = a[b] / l;
328
407
  for (let b = 0; b < d; b++) {
329
- const A = 1 - b / d;
330
- for (let x = 0; x < u; x++) {
331
- if (_[x] > b) continue;
332
- const U = p[x], N = m[x];
333
- let R = 0;
334
- for (let y = 0; y < c; y++) {
335
- const B = e[U * c + y] - t[N * c + y];
336
- R += B * B;
408
+ const x = 1 - b / d;
409
+ for (let v = 0; v < p; v++) {
410
+ if (m[v] > b) continue;
411
+ const P = g[v], G = _[v];
412
+ let U = 0;
413
+ for (let w = 0; w < c; w++) {
414
+ const R = e[P * c + w] - t[G * c + w];
415
+ U += R * R;
337
416
  }
338
- const O = Math.pow(R, i), M = -2 * o * i * (R > 0 ? O / R : 0) / (o * O + 1);
339
- for (let y = 0; y < c; y++) {
340
- const B = e[U * c + y] - t[N * c + y];
341
- e[U * c + y] += A * z(M * B);
417
+ const F = Math.pow(U, o), N = -2 * r * o * (U > 0 ? F / U : 0) / (r * F + 1);
418
+ for (let w = 0; w < c; w++) {
419
+ const R = e[P * c + w] - t[G * c + w];
420
+ e[P * c + w] += x * j(N * R);
342
421
  }
343
- _[x] += s[x];
344
- const v = s[x] / l, S = Math.max(0, Math.floor(
345
- (b - w[x]) / v
422
+ m[v] += a[v];
423
+ const M = a[v] / l, E = Math.max(0, Math.floor(
424
+ (b - y[v]) / M
346
425
  ));
347
- w[x] += S * v;
348
- for (let y = 0; y < S; y++) {
349
- const B = Math.floor(Math.random() * g);
350
- if (B === N) continue;
351
- let k = 0;
352
- for (let F = 0; F < c; F++) {
353
- const q = e[U * c + F] - t[B * c + F];
354
- k += q * q;
426
+ y[v] += E * M;
427
+ for (let w = 0; w < E; w++) {
428
+ const R = Math.floor(Math.random() * u);
429
+ if (R === G) continue;
430
+ let A = 0;
431
+ for (let S = 0; S < c; S++) {
432
+ const k = e[P * c + S] - t[R * c + S];
433
+ A += k * k;
355
434
  }
356
- const E = Math.pow(k, i), P = 2 * h * i / ((1e-3 + k) * (o * E + 1));
357
- for (let F = 0; F < c; F++) {
358
- const q = e[U * c + F] - t[B * c + F];
359
- e[U * c + F] += A * z(P * q);
435
+ const B = Math.pow(A, o), O = 2 * h * o / ((1e-3 + A) * (r * B + 1));
436
+ for (let S = 0; S < c; S++) {
437
+ const k = e[P * c + S] - t[R * c + S];
438
+ e[P * c + S] += x * j(O * k);
360
439
  }
361
440
  }
362
441
  }
363
442
  }
364
443
  return e;
365
444
  }
366
- function K() {
367
- return typeof navigator < "u" && !!navigator.gpu;
368
- }
369
- async function re(e, t = {}, r) {
445
+ async function pe(e, t = {}, f) {
370
446
  const {
371
- nComponents: s = 2,
372
- nNeighbors: n = 15,
373
- minDist: g = 0.1,
447
+ nComponents: a = 2,
448
+ nNeighbors: s = 15,
449
+ minDist: u = 0.1,
374
450
  spread: c = 1,
375
451
  hnsw: d = {}
376
- } = t, f = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
452
+ } = t, i = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
377
453
  console.time("knn");
378
- const { indices: a, distances: o } = await Q(e, n, {
454
+ const { indices: n, distances: r } = await se(e, s, {
379
455
  M: d.M ?? 16,
380
456
  efConstruction: d.efConstruction ?? 200,
381
457
  efSearch: d.efSearch ?? 50
382
458
  });
383
459
  console.timeEnd("knn"), console.time("fuzzy-set");
384
- const i = L(a, o, n);
460
+ const o = V(n, r, s);
385
461
  console.timeEnd("fuzzy-set");
386
- const { a: h, b: l } = H(g, c), u = T(i.vals), p = e.length, m = new Float32Array(p * s);
387
- for (let w = 0; w < m.length; w++)
388
- m[w] = Math.random() * 20 - 10;
462
+ const { a: h, b: l } = Z(u, c), p = W(o.vals), g = e.length, _ = new Float32Array(g * a);
463
+ for (let y = 0; y < _.length; y++)
464
+ _[y] = Math.random() * 20 - 10;
389
465
  console.time("sgd");
390
- let _;
391
- if (K())
466
+ let m;
467
+ if (J())
392
468
  try {
393
- const w = new j();
394
- await w.init(), _ = await w.optimize(
395
- m,
396
- new Uint32Array(i.rows),
397
- new Uint32Array(i.cols),
398
- u,
469
+ const y = new X();
470
+ await y.init(), m = await y.optimize(
471
+ _,
472
+ new Uint32Array(o.rows),
473
+ new Uint32Array(o.cols),
399
474
  p,
400
- s,
401
- f,
475
+ g,
476
+ a,
477
+ i,
402
478
  { a: h, b: l, gamma: 1, negativeSampleRate: 5 },
403
- r
479
+ f
404
480
  );
405
- } catch (w) {
406
- console.warn("WebGPU SGD failed, falling back to CPU:", w), _ = I(m, i, u, p, s, f, { a: h, b: l }, r);
481
+ } catch (y) {
482
+ console.warn("WebGPU SGD failed, falling back to CPU:", y), m = L(_, o, p, g, a, i, { a: h, b: l }, f);
407
483
  }
408
484
  else
409
- _ = I(m, i, u, p, s, f, { a: h, b: l }, r);
410
- return console.timeEnd("sgd"), _;
485
+ m = L(_, o, p, g, a, i, { a: h, b: l }, f);
486
+ return console.timeEnd("sgd"), m;
411
487
  }
412
- function H(e, t) {
413
- if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6)
414
- return { a: 1.9292, b: 0.7915 };
415
- if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
416
- return { a: 1.8956, b: 0.8006 };
417
- if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
418
- return { a: 1.5769, b: 0.8951 };
419
- const r = te(e, t);
420
- return { a: ne(e, t, r), b: r };
488
+ function Z(e, t) {
489
+ return Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6 ? { a: 1.9292, b: 0.7915 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6 ? { a: 1.8956, b: 0.8006 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6 ? { a: 1.5769, b: 0.8951 } : de(e, t);
421
490
  }
422
- function te(e, t) {
423
- return 1 / (t * 1.2);
424
- }
425
- function ne(e, t, r) {
426
- return e < 1e-6 ? 1.8956 : (1 / (1 + 1e-3) - 1) / -Math.pow(e, 2 * r);
491
+ function de(e, t) {
492
+ const a = [], s = [];
493
+ for (let i = 0; i < 299; i++) {
494
+ const n = (i + 1) / 299 * t * 3;
495
+ a.push(n), s.push(n < e ? 1 : Math.exp(-(n - e) / t));
496
+ }
497
+ let u = 1, c = 1, d = 1e-3;
498
+ for (let i = 0; i < 500; i++) {
499
+ let n = 0, r = 0, o = 0, h = 0, l = 0, p = 0;
500
+ for (let U = 0; U < 299; U++) {
501
+ const F = a[U], N = Math.pow(F, 2 * c), M = 1 + u * N, w = 1 / M - s[U];
502
+ p += w * w;
503
+ const R = M * M, A = -N / R, B = F > 0 ? -2 * Math.log(F) * u * N / R : 0;
504
+ n += A * w, r += B * w, o += A * A, h += B * B, l += A * B;
505
+ }
506
+ const g = o + d, _ = h + d, m = l, y = g * _ - m * m;
507
+ if (Math.abs(y) < 1e-20) break;
508
+ const b = -(_ * n - m * r) / y, x = -(g * r - m * n) / y, v = Math.max(1e-4, u + b), P = Math.max(1e-4, c + x);
509
+ let G = 0;
510
+ for (let U = 0; U < 299; U++) {
511
+ const F = Math.pow(a[U], 2 * P), N = 1 / (1 + v * F) - s[U];
512
+ G += N * N;
513
+ }
514
+ if (G < p ? (u = v, c = P, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(b) < 1e-8 && Math.abs(x) < 1e-8) break;
515
+ }
516
+ return { a: u, b: c };
427
517
  }
428
- class ie {
518
+ class ge {
429
519
  constructor(t = {}) {
430
- G(this, "_nComponents");
431
- G(this, "_nNeighbors");
432
- G(this, "_minDist");
433
- G(this, "_spread");
434
- G(this, "_nEpochs");
435
- G(this, "_hnswOpts");
436
- G(this, "_a");
437
- G(this, "_b");
520
+ C(this, "_nComponents");
521
+ C(this, "_nNeighbors");
522
+ C(this, "_minDist");
523
+ C(this, "_spread");
524
+ C(this, "_nEpochs");
525
+ C(this, "_hnswOpts");
526
+ C(this, "_a");
527
+ C(this, "_b");
438
528
  /** The low-dimensional embedding produced by the last fit() call. */
439
- G(this, "embedding", null);
440
- G(this, "_hnswIndex", null);
441
- G(this, "_nTrain", 0);
529
+ C(this, "embedding", null);
530
+ C(this, "_hnswIndex", null);
531
+ C(this, "_nTrain", 0);
442
532
  this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
443
- const { a: r, b: s } = H(this._minDist, this._spread);
444
- this._a = r, this._b = s;
533
+ const { a: f, b: a } = Z(this._minDist, this._spread);
534
+ this._a = f, this._b = a;
445
535
  }
446
536
  /**
447
537
  * Train UMAP on `vectors`.
@@ -449,45 +539,45 @@ class ie {
449
539
  * index so that transform() can project new points later.
450
540
  * Returns `this` for chaining.
451
541
  */
452
- async fit(t, r) {
453
- const s = t.length, n = this._nEpochs ?? (s > 1e4 ? 200 : 500), { M: g = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
542
+ async fit(t, f) {
543
+ const a = t.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: u = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
454
544
  console.time("knn");
455
- const { knn: f, index: a } = await J(t, this._nNeighbors, {
456
- M: g,
545
+ const { knn: i, index: n } = await ae(t, this._nNeighbors, {
546
+ M: u,
457
547
  efConstruction: c,
458
548
  efSearch: d
459
549
  });
460
- this._hnswIndex = a, this._nTrain = s, console.timeEnd("knn"), console.time("fuzzy-set");
461
- const o = L(f.indices, f.distances, this._nNeighbors);
550
+ this._hnswIndex = n, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
551
+ const r = V(i.indices, i.distances, this._nNeighbors);
462
552
  console.timeEnd("fuzzy-set");
463
- const i = T(o.vals), h = new Float32Array(s * this._nComponents);
553
+ const o = W(r.vals), h = new Float32Array(a * this._nComponents);
464
554
  for (let l = 0; l < h.length; l++)
465
555
  h[l] = Math.random() * 20 - 10;
466
- if (console.time("sgd"), K())
556
+ if (console.time("sgd"), J())
467
557
  try {
468
- const l = new j();
558
+ const l = new X();
469
559
  await l.init(), this.embedding = await l.optimize(
470
560
  h,
471
- new Uint32Array(o.rows),
472
- new Uint32Array(o.cols),
473
- i,
474
- s,
561
+ new Uint32Array(r.rows),
562
+ new Uint32Array(r.cols),
563
+ o,
564
+ a,
475
565
  this._nComponents,
476
- n,
566
+ s,
477
567
  { a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 },
478
- r
568
+ f
479
569
  );
480
570
  } catch (l) {
481
- console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = I(h, o, i, s, this._nComponents, n, {
571
+ console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = L(h, r, o, a, this._nComponents, s, {
482
572
  a: this._a,
483
573
  b: this._b
484
- }, r);
574
+ }, f);
485
575
  }
486
576
  else
487
- this.embedding = I(h, o, i, s, this._nComponents, n, {
577
+ this.embedding = L(h, r, o, a, this._nComponents, s, {
488
578
  a: this._a,
489
579
  b: this._b
490
- }, r);
580
+ }, f);
491
581
  return console.timeEnd("sgd"), this;
492
582
  }
493
583
  /**
@@ -501,35 +591,35 @@ class ie {
501
591
  * returned embedding to [0, 1]. The stored training embedding is never
502
592
  * mutated. Defaults to `false`.
503
593
  */
504
- async transform(t, r = !1) {
594
+ async transform(t, f = !1) {
505
595
  if (!this._hnswIndex || !this.embedding)
506
596
  throw new Error("UMAP.transform() must be called after fit()");
507
- const s = t.length, n = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), g = Math.max(100, Math.floor(n / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = X(c.indices, c.distances, this._nNeighbors), f = new Uint32Array(d.rows), a = new Uint32Array(d.cols), o = new Float32Array(s), i = new Float32Array(s * this._nComponents);
508
- for (let u = 0; u < f.length; u++) {
509
- const p = f[u], m = a[u], _ = d.vals[u];
510
- o[p] += _;
511
- for (let w = 0; w < this._nComponents; w++)
512
- i[p * this._nComponents + w] += _ * this.embedding[m * this._nComponents + w];
597
+ const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), i = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), o = new Float32Array(a * this._nComponents);
598
+ for (let p = 0; p < i.length; p++) {
599
+ const g = i[p], _ = n[p], m = d.vals[p];
600
+ r[g] += m;
601
+ for (let y = 0; y < this._nComponents; y++)
602
+ o[g * this._nComponents + y] += m * this.embedding[_ * this._nComponents + y];
513
603
  }
514
- for (let u = 0; u < s; u++)
515
- if (o[u] > 0)
516
- for (let p = 0; p < this._nComponents; p++)
517
- i[u * this._nComponents + p] /= o[u];
604
+ for (let p = 0; p < a; p++)
605
+ if (r[p] > 0)
606
+ for (let g = 0; g < this._nComponents; g++)
607
+ o[p * this._nComponents + g] /= r[p];
518
608
  else
519
- for (let p = 0; p < this._nComponents; p++)
520
- i[u * this._nComponents + p] = Math.random() * 20 - 10;
521
- const h = T(d.vals), l = ee(
522
- i,
609
+ for (let g = 0; g < this._nComponents; g++)
610
+ o[p * this._nComponents + g] = Math.random() * 20 - 10;
611
+ const h = W(d.vals), l = fe(
612
+ o,
523
613
  this.embedding,
524
614
  d,
525
615
  h,
526
- s,
616
+ a,
527
617
  this._nTrain,
528
618
  this._nComponents,
529
- g,
619
+ u,
530
620
  { a: this._a, b: this._b }
531
621
  );
532
- return r ? C(l, s, this._nComponents) : l;
622
+ return f ? K(l, a, this._nComponents) : l;
533
623
  }
534
624
  /**
535
625
  * Convenience method equivalent to `fit(vectors)` followed by
@@ -540,37 +630,38 @@ class ie {
540
630
  * returned embedding to [0, 1]. `this.embedding` is never mutated.
541
631
  * Defaults to `false`.
542
632
  */
543
- async fit_transform(t, r, s = !1) {
544
- return await this.fit(t, r), s ? C(this.embedding, t.length, this._nComponents) : this.embedding;
633
+ async fit_transform(t, f, a = !1) {
634
+ return await this.fit(t, f), a ? K(this.embedding, t.length, this._nComponents) : this.embedding;
545
635
  }
546
636
  }
547
- function C(e, t, r) {
548
- const s = new Float32Array(e.length);
549
- for (let n = 0; n < r; n++) {
550
- let g = 1 / 0, c = -1 / 0;
551
- for (let f = 0; f < t; f++) {
552
- const a = e[f * r + n];
553
- a < g && (g = a), a > c && (c = a);
637
+ function K(e, t, f) {
638
+ const a = new Float32Array(e.length);
639
+ for (let s = 0; s < f; s++) {
640
+ let u = 1 / 0, c = -1 / 0;
641
+ for (let i = 0; i < t; i++) {
642
+ const n = e[i * f + s];
643
+ n < u && (u = n), n > c && (c = n);
554
644
  }
555
- const d = c - g;
556
- for (let f = 0; f < t; f++)
557
- s[f * r + n] = d > 0 ? (e[f * r + n] - g) / d : 0;
645
+ const d = c - u;
646
+ for (let i = 0; i < t; i++)
647
+ a[i * f + s] = d > 0 ? (e[i * f + s] - u) / d : 0;
558
648
  }
559
- return s;
649
+ return a;
560
650
  }
561
- function T(e, t) {
562
- let r = -1 / 0;
563
- for (let n = 0; n < e.length; n++)
564
- e[n] > r && (r = e[n]);
565
- const s = new Float32Array(e.length);
566
- for (let n = 0; n < e.length; n++) {
567
- const g = e[n] / r;
568
- s[n] = g > 0 ? 1 / g : -1;
651
+ function W(e, t) {
652
+ let f = -1 / 0;
653
+ for (let s = 0; s < e.length; s++)
654
+ e[s] > f && (f = e[s]);
655
+ const a = new Float32Array(e.length);
656
+ for (let s = 0; s < e.length; s++) {
657
+ const u = e[s] / f;
658
+ a[s] = u > 0 ? 1 / u : -1;
569
659
  }
570
- return s;
660
+ return a;
571
661
  }
572
662
  export {
573
- ie as UMAP,
574
- re as fit,
575
- K as isWebGPUAvailable
663
+ ge as UMAP,
664
+ he as checkWebGPUAvailable,
665
+ pe as fit,
666
+ J as isWebGPUAvailable
576
667
  };
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "umap-gpu",
3
- "version": "0.2.13",
3
+ "version": "0.2.14",
4
4
  "description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
5
5
  "type": "module",
6
6
  "main": "dist/index.js",