umap-gpu 0.2.13 → 0.2.15
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/dist/fuzzy-set.d.ts +2 -2
- package/dist/gpu/device.d.ts +18 -1
- package/dist/gpu/sgd.d.ts +14 -2
- package/dist/index.d.ts +1 -1
- package/dist/index.js +394 -300
- package/package.json +3 -2
package/dist/fuzzy-set.d.ts
CHANGED
package/dist/gpu/device.d.ts
CHANGED
|
@@ -7,6 +7,23 @@
|
|
|
7
7
|
*/
|
|
8
8
|
export declare function getGPUDevice(): Promise<GPUDevice | null>;
|
|
9
9
|
/**
|
|
10
|
-
*
|
|
10
|
+
* Fast synchronous heuristic: returns `true` if `navigator.gpu` exists.
|
|
11
|
+
*
|
|
12
|
+
* **Caveat (Bug 13):** `navigator.gpu` being truthy does NOT guarantee that a
|
|
13
|
+
* WebGPU adapter can be acquired — `requestAdapter()` may still return `null`
|
|
14
|
+
* (no compatible hardware, or the browser has disabled WebGPU for the page).
|
|
15
|
+
* Use `checkWebGPUAvailable()` for a reliable async check, or rely on the
|
|
16
|
+
* `try/catch` around `GPUSgd.init()` in the calling code.
|
|
11
17
|
*/
|
|
12
18
|
export declare function isWebGPUAvailable(): boolean;
|
|
19
|
+
/**
|
|
20
|
+
* Reliably check whether WebGPU is usable in the current environment by
|
|
21
|
+
* attempting to acquire an adapter via `getGPUDevice()`.
|
|
22
|
+
*
|
|
23
|
+
* Unlike the synchronous `isWebGPUAvailable()`, this actually calls
|
|
24
|
+
* `navigator.gpu.requestAdapter()` and returns `false` if the adapter is
|
|
25
|
+
* unavailable (no compatible GPU, browser policy, etc.).
|
|
26
|
+
*
|
|
27
|
+
* The result is automatically cached — repeated calls are free.
|
|
28
|
+
*/
|
|
29
|
+
export declare function checkWebGPUAvailable(): Promise<boolean>;
|
package/dist/gpu/sgd.d.ts
CHANGED
|
@@ -6,11 +6,23 @@ export interface SGDParams {
|
|
|
6
6
|
}
|
|
7
7
|
/**
|
|
8
8
|
* GPU-accelerated SGD optimizer for UMAP embedding.
|
|
9
|
-
*
|
|
9
|
+
*
|
|
10
|
+
* Uses a two-pass design per epoch to eliminate write-write races on shared
|
|
11
|
+
* vertex positions (Bug 2 fix):
|
|
12
|
+
* Pass 1 (sgd.wgsl): Each thread accumulates its attraction and
|
|
13
|
+
* repulsion gradients into an atomic<i32>
|
|
14
|
+
* forces buffer — no direct embedding writes.
|
|
15
|
+
* Pass 2 (apply-forces.wgsl): Each thread applies one element's accumulated
|
|
16
|
+
* force to the embedding and resets the
|
|
17
|
+
* accumulator to zero for the next epoch.
|
|
18
|
+
*
|
|
19
|
+
* Both passes are submitted in the same command encoder so WebGPU guarantees
|
|
20
|
+
* sequential execution and storage-buffer visibility between them.
|
|
10
21
|
*/
|
|
11
22
|
export declare class GPUSgd {
|
|
12
23
|
private device;
|
|
13
|
-
private
|
|
24
|
+
private sgdPipeline;
|
|
25
|
+
private applyForcesPipeline;
|
|
14
26
|
init(): Promise<void>;
|
|
15
27
|
/**
|
|
16
28
|
* Run SGD optimization on the GPU.
|
package/dist/index.d.ts
CHANGED
|
@@ -2,4 +2,4 @@ export { fit, UMAP } from './umap';
|
|
|
2
2
|
export type { UMAPOptions, ProgressCallback } from './umap';
|
|
3
3
|
export type { KNNResult, HNSWOptions, HNSWSearchableIndex } from './hnsw-knn';
|
|
4
4
|
export type { FuzzyGraph } from './fuzzy-set';
|
|
5
|
-
export { isWebGPUAvailable } from './gpu/device';
|
|
5
|
+
export { isWebGPUAvailable, checkWebGPUAvailable } from './gpu/device';
|
package/dist/index.js
CHANGED
|
@@ -1,104 +1,111 @@
|
|
|
1
|
-
var
|
|
2
|
-
var
|
|
3
|
-
var
|
|
4
|
-
import { loadHnswlib as
|
|
5
|
-
async function
|
|
6
|
-
const { M:
|
|
7
|
-
|
|
8
|
-
const
|
|
9
|
-
for (let h = 0; h <
|
|
10
|
-
const l =
|
|
11
|
-
|
|
1
|
+
var te = Object.defineProperty;
|
|
2
|
+
var ne = (e, t, f) => t in e ? te(e, t, { enumerable: !0, configurable: !0, writable: !0, value: f }) : e[t] = f;
|
|
3
|
+
var R = (e, t, f) => ne(e, typeof t != "symbol" ? t + "" : t, f);
|
|
4
|
+
import { loadHnswlib as H } from "hnswlib-wasm";
|
|
5
|
+
async function se(e, t, f = {}) {
|
|
6
|
+
const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
|
|
7
|
+
n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
|
|
8
|
+
const r = [], i = [];
|
|
9
|
+
for (let h = 0; h < o; h++) {
|
|
10
|
+
const l = n.searchKnn(e[h], t + 1, void 0), p = l.neighbors.map((g, _) => ({ idx: g, dist: l.distances[_] })).filter(({ idx: g }) => g !== h).slice(0, t);
|
|
11
|
+
r.push(p.map(({ idx: g }) => g)), i.push(p.map(({ dist: g }) => Math.sqrt(g)));
|
|
12
12
|
}
|
|
13
|
-
return { indices:
|
|
13
|
+
return { indices: r, distances: i };
|
|
14
14
|
}
|
|
15
|
-
async function
|
|
16
|
-
const { M:
|
|
17
|
-
|
|
18
|
-
const
|
|
19
|
-
for (let l = 0; l <
|
|
20
|
-
const
|
|
21
|
-
|
|
15
|
+
async function ae(e, t, f = {}) {
|
|
16
|
+
const { M: a = 16, efConstruction: s = 200, efSearch: u = 50 } = f, c = await H(), d = e[0].length, o = e.length, n = new c.HierarchicalNSW("l2", d, "");
|
|
17
|
+
n.initIndex(o, a, s, 200), n.setEfSearch(Math.max(u, t)), n.addItems(e, !1);
|
|
18
|
+
const r = [], i = [];
|
|
19
|
+
for (let l = 0; l < o; l++) {
|
|
20
|
+
const p = n.searchKnn(e[l], t + 1, void 0), g = p.neighbors.map((_, m) => ({ idx: _, dist: p.distances[m] })).filter(({ idx: _ }) => _ !== l).slice(0, t);
|
|
21
|
+
r.push(g.map(({ idx: _ }) => _)), i.push(g.map(({ dist: _ }) => Math.sqrt(_)));
|
|
22
22
|
}
|
|
23
|
-
return { knn: { indices:
|
|
24
|
-
searchKnn(l,
|
|
25
|
-
const
|
|
26
|
-
for (const
|
|
27
|
-
const
|
|
28
|
-
|
|
23
|
+
return { knn: { indices: r, distances: i }, index: {
|
|
24
|
+
searchKnn(l, p) {
|
|
25
|
+
const g = [], _ = [];
|
|
26
|
+
for (const m of l) {
|
|
27
|
+
const b = n.searchKnn(m, p, void 0), y = b.neighbors.map((x, v) => ({ idx: x, dist: b.distances[v] })).sort((x, v) => x.dist - v.dist).slice(0, p);
|
|
28
|
+
g.push(y.map(({ idx: x }) => x)), _.push(y.map(({ dist: x }) => Math.sqrt(x)));
|
|
29
29
|
}
|
|
30
|
-
return { indices:
|
|
30
|
+
return { indices: g, distances: _ };
|
|
31
31
|
}
|
|
32
32
|
} };
|
|
33
33
|
}
|
|
34
|
-
function
|
|
35
|
-
const
|
|
36
|
-
for (let i = 0; i <
|
|
34
|
+
function V(e, t, f, a = 1) {
|
|
35
|
+
const s = e.length, { sigmas: u, rhos: c } = Y(t, f), d = [], o = [], n = [];
|
|
36
|
+
for (let i = 0; i < s; i++)
|
|
37
37
|
for (let h = 0; h < e[i].length; h++) {
|
|
38
|
-
const l = t[i][h],
|
|
39
|
-
d.push(i),
|
|
38
|
+
const l = t[i][h], p = l <= c[i] ? 1 : Math.exp(-((l - c[i]) / u[i]));
|
|
39
|
+
d.push(i), o.push(e[i][h]), n.push(p);
|
|
40
40
|
}
|
|
41
|
-
return { ...
|
|
41
|
+
return { ...ie(d, o, n, s, a), nVertices: s };
|
|
42
42
|
}
|
|
43
|
-
function
|
|
44
|
-
const
|
|
45
|
-
for (let
|
|
46
|
-
for (let
|
|
47
|
-
const i = t[
|
|
48
|
-
c.push(
|
|
43
|
+
function re(e, t, f) {
|
|
44
|
+
const a = e.length, { sigmas: s, rhos: u } = Y(t, f), c = [], d = [], o = [];
|
|
45
|
+
for (let n = 0; n < a; n++)
|
|
46
|
+
for (let r = 0; r < e[n].length; r++) {
|
|
47
|
+
const i = t[n][r], h = i <= u[n] ? 1 : Math.exp(-((i - u[n]) / s[n]));
|
|
48
|
+
c.push(n), d.push(e[n][r]), o.push(h);
|
|
49
49
|
}
|
|
50
50
|
return {
|
|
51
|
-
rows: new
|
|
52
|
-
cols: new
|
|
53
|
-
vals: new Float32Array(
|
|
54
|
-
nVertices:
|
|
51
|
+
rows: new Uint32Array(c),
|
|
52
|
+
cols: new Uint32Array(d),
|
|
53
|
+
vals: new Float32Array(o),
|
|
54
|
+
nVertices: a
|
|
55
55
|
};
|
|
56
56
|
}
|
|
57
|
-
function
|
|
58
|
-
const
|
|
59
|
-
for (let c = 0; c <
|
|
57
|
+
function Y(e, t) {
|
|
58
|
+
const a = e.length, s = new Float32Array(a), u = new Float32Array(a);
|
|
59
|
+
for (let c = 0; c < a; c++) {
|
|
60
60
|
const d = e[c];
|
|
61
|
-
|
|
62
|
-
let
|
|
61
|
+
u[c] = d.find((h) => h > 0) ?? 0;
|
|
62
|
+
let o = 0, n = 1 / 0, r = 1;
|
|
63
63
|
const i = Math.log2(t);
|
|
64
64
|
for (let h = 0; h < 64; h++) {
|
|
65
65
|
let l = 0;
|
|
66
|
-
for (let
|
|
67
|
-
l += Math.exp(-Math.max(0, d[
|
|
66
|
+
for (let p = 0; p < d.length; p++)
|
|
67
|
+
l += Math.exp(-Math.max(0, d[p] - u[c]) / r);
|
|
68
68
|
if (Math.abs(l - i) < 1e-5) break;
|
|
69
|
-
l > i ? (
|
|
69
|
+
l > i ? (n = r, r = (o + n) / 2) : (o = r, r = n === 1 / 0 ? r * 2 : (o + n) / 2);
|
|
70
70
|
}
|
|
71
|
-
|
|
71
|
+
s[c] = r;
|
|
72
72
|
}
|
|
73
|
-
return { sigmas:
|
|
73
|
+
return { sigmas: s, rhos: u };
|
|
74
74
|
}
|
|
75
|
-
function
|
|
76
|
-
const
|
|
77
|
-
for (let
|
|
78
|
-
|
|
79
|
-
const c = [], d = [],
|
|
80
|
-
for (const [
|
|
81
|
-
const i = Math.floor(
|
|
82
|
-
c.push(i), d.push(h),
|
|
75
|
+
function ie(e, t, f, a, s) {
|
|
76
|
+
const u = /* @__PURE__ */ new Map();
|
|
77
|
+
for (let n = 0; n < e.length; n++)
|
|
78
|
+
u.set(e[n] * a + t[n], f[n]);
|
|
79
|
+
const c = [], d = [], o = [];
|
|
80
|
+
for (const [n, r] of u) {
|
|
81
|
+
const i = Math.floor(n / a), h = n % a, l = u.get(h * a + i) ?? 0, p = r + l - r * l, g = r * l;
|
|
82
|
+
c.push(i), d.push(h), o.push(s * p + (1 - s) * g);
|
|
83
83
|
}
|
|
84
84
|
return {
|
|
85
|
-
rows: new
|
|
86
|
-
cols: new
|
|
87
|
-
vals: new Float32Array(
|
|
85
|
+
rows: new Uint32Array(c),
|
|
86
|
+
cols: new Uint32Array(d),
|
|
87
|
+
vals: new Float32Array(o)
|
|
88
88
|
};
|
|
89
89
|
}
|
|
90
|
-
const
|
|
91
|
-
//
|
|
92
|
-
//
|
|
90
|
+
const oe = `// UMAP SGD compute shader — processes one graph edge per GPU thread.
|
|
91
|
+
// Computes attraction and repulsion forces and accumulates them atomically
|
|
92
|
+
// into a forces buffer. A separate apply-forces pass then updates embeddings,
|
|
93
|
+
// eliminating write-write races on shared vertex positions.
|
|
93
94
|
|
|
94
95
|
@group(0) @binding(0) var<storage, read> epochs_per_sample : array<f32>;
|
|
95
96
|
@group(0) @binding(1) var<storage, read> head : array<u32>; // edge source
|
|
96
97
|
@group(0) @binding(2) var<storage, read> tail : array<u32>; // edge target
|
|
97
|
-
@group(0) @binding(3) var<storage,
|
|
98
|
+
@group(0) @binding(3) var<storage, read> embedding : array<f32>; // [n * nComponents], read-only
|
|
98
99
|
@group(0) @binding(4) var<storage, read_write> epoch_of_next_sample : array<f32>;
|
|
99
100
|
@group(0) @binding(5) var<storage, read_write> epoch_of_next_negative_sample : array<f32>;
|
|
100
101
|
@group(0) @binding(6) var<uniform> params : Params;
|
|
101
102
|
@group(0) @binding(7) var<storage, read> rng_seeds : array<u32>; // per-edge seed
|
|
103
|
+
@group(0) @binding(8) var<storage, read_write> forces : array<atomic<i32>>; // [n * nComponents]
|
|
104
|
+
|
|
105
|
+
// Scale factor for quantizing f32 gradients into i32 for atomic accumulation.
|
|
106
|
+
// Gradients are clipped to [-4, 4]. With up to ~1000 edges sharing a vertex
|
|
107
|
+
// the max accumulated magnitude is ~4000, well within i32 range at this scale.
|
|
108
|
+
const FORCE_SCALE : f32 = 65536.0;
|
|
102
109
|
|
|
103
110
|
struct Params {
|
|
104
111
|
n_edges : u32,
|
|
@@ -106,7 +113,7 @@ struct Params {
|
|
|
106
113
|
n_components : u32,
|
|
107
114
|
current_epoch : u32,
|
|
108
115
|
n_epochs : u32,
|
|
109
|
-
alpha : f32, // learning rate
|
|
116
|
+
alpha : f32, // learning rate (applied by apply-forces pass)
|
|
110
117
|
a : f32,
|
|
111
118
|
b : f32,
|
|
112
119
|
gamma : f32, // repulsion strength
|
|
@@ -147,7 +154,6 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
147
154
|
|
|
148
155
|
let pow_b = pow(dist_sq, params.b);
|
|
149
156
|
// Guard dist_sq == 0: b-1 is negative so pow(0, b-1) = +Inf.
|
|
150
|
-
// Mirror CPU: use pow_b / dist_sq only when dist_sq > 0, else 0.
|
|
151
157
|
let grad_coeff_attr = select(
|
|
152
158
|
-2.0 * params.a * params.b * (pow_b / dist_sq) / (params.a * pow_b + 1.0),
|
|
153
159
|
0.0,
|
|
@@ -157,20 +163,28 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
157
163
|
for (var d = 0u; d < nc; d++) {
|
|
158
164
|
let diff = embedding[i * nc + d] - embedding[j * nc + d];
|
|
159
165
|
let grad = clip(grad_coeff_attr * diff, -4.0, 4.0);
|
|
160
|
-
|
|
161
|
-
|
|
166
|
+
// Accumulate atomically to avoid write-write races across threads.
|
|
167
|
+
atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
|
|
168
|
+
atomicAdd(&forces[j * nc + d], -i32(grad * FORCE_SCALE));
|
|
162
169
|
}
|
|
163
170
|
|
|
164
171
|
epoch_of_next_sample[edge_idx] += epochs_per_sample[edge_idx];
|
|
165
172
|
|
|
166
173
|
// --- Repulsion (negative samples) ---
|
|
167
|
-
|
|
168
|
-
|
|
174
|
+
// Compute how many negative samples are overdue relative to current epoch,
|
|
175
|
+
// matching the Python reference: n_neg = floor((n - next_neg) / eps_per_neg).
|
|
176
|
+
let epoch_f = f32(params.current_epoch);
|
|
177
|
+
let epochs_per_neg = epochs_per_sample[edge_idx] / f32(params.negative_sample_rate);
|
|
169
178
|
var n_neg = 0u;
|
|
170
|
-
if (
|
|
171
|
-
n_neg = u32(
|
|
179
|
+
if (epochs_per_neg > 0.0 && epoch_f >= epoch_of_next_negative_sample[edge_idx]) {
|
|
180
|
+
n_neg = u32((epoch_f - epoch_of_next_negative_sample[edge_idx]) / epochs_per_neg);
|
|
181
|
+
epoch_of_next_negative_sample[edge_idx] += f32(n_neg) * epochs_per_neg;
|
|
172
182
|
}
|
|
173
|
-
|
|
183
|
+
|
|
184
|
+
// 2654435761u is the 32-bit golden-ratio hash constant (0x9E3779B1),
|
|
185
|
+
// which fits in u32 unlike the 64-bit LCG value 6364136223 used originally.
|
|
186
|
+
// Bug 14 fix: the original constant exceeded u32 range and failed WGSL validation.
|
|
187
|
+
var rng = xorshift(rng_seeds[edge_idx] + params.current_epoch * 2654435761u);
|
|
174
188
|
|
|
175
189
|
for (var s = 0u; s < n_neg; s++) {
|
|
176
190
|
rng = xorshift(rng);
|
|
@@ -189,26 +203,73 @@ fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
|
189
203
|
for (var d = 0u; d < nc; d++) {
|
|
190
204
|
let diff = embedding[i * nc + d] - embedding[k * nc + d];
|
|
191
205
|
let grad = clip(grad_coeff_rep * diff, -4.0, 4.0);
|
|
192
|
-
|
|
206
|
+
atomicAdd(&forces[i * nc + d], i32(grad * FORCE_SCALE));
|
|
193
207
|
}
|
|
194
208
|
}
|
|
209
|
+
}
|
|
210
|
+
`, ce = `// Apply-forces shader — second pass of the two-pass GPU SGD.
|
|
211
|
+
//
|
|
212
|
+
// After the SGD pass has atomically accumulated all gradients into the forces
|
|
213
|
+
// buffer, this shader applies each element's accumulated force to the
|
|
214
|
+
// embedding and resets the accumulator to zero for the next epoch.
|
|
195
215
|
|
|
196
|
-
|
|
197
|
-
|
|
216
|
+
@group(0) @binding(0) var<storage, read_write> embedding : array<f32>;
|
|
217
|
+
@group(0) @binding(1) var<storage, read_write> forces : array<atomic<i32>>;
|
|
218
|
+
@group(0) @binding(2) var<uniform> params : ApplyParams;
|
|
219
|
+
|
|
220
|
+
struct ApplyParams {
|
|
221
|
+
n_elements : u32, // nVertices * nComponents
|
|
222
|
+
alpha : f32, // current learning rate
|
|
223
|
+
}
|
|
224
|
+
|
|
225
|
+
// Must match FORCE_SCALE in sgd.wgsl
|
|
226
|
+
const FORCE_SCALE : f32 = 65536.0;
|
|
227
|
+
|
|
228
|
+
@compute @workgroup_size(256)
|
|
229
|
+
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
|
230
|
+
let idx = gid.x;
|
|
231
|
+
if (idx >= params.n_elements) { return; }
|
|
232
|
+
|
|
233
|
+
// atomicExchange atomically reads the accumulated force and resets it to 0.
|
|
234
|
+
let raw = atomicExchange(&forces[idx], 0);
|
|
235
|
+
embedding[idx] += params.alpha * f32(raw) / FORCE_SCALE;
|
|
198
236
|
}
|
|
199
237
|
`;
|
|
200
|
-
|
|
238
|
+
let z = null;
|
|
239
|
+
async function Q() {
|
|
240
|
+
if (z) return z;
|
|
241
|
+
if (typeof navigator > "u" || !navigator.gpu)
|
|
242
|
+
return null;
|
|
243
|
+
const e = await navigator.gpu.requestAdapter();
|
|
244
|
+
return e ? (z = await e.requestDevice(), z.lost.then(() => {
|
|
245
|
+
z = null;
|
|
246
|
+
}), z) : null;
|
|
247
|
+
}
|
|
248
|
+
function J() {
|
|
249
|
+
return typeof navigator < "u" && !!navigator.gpu;
|
|
250
|
+
}
|
|
251
|
+
async function he() {
|
|
252
|
+
return await Q() !== null;
|
|
253
|
+
}
|
|
254
|
+
class X {
|
|
201
255
|
constructor() {
|
|
202
|
-
|
|
203
|
-
|
|
256
|
+
R(this, "device");
|
|
257
|
+
R(this, "sgdPipeline");
|
|
258
|
+
R(this, "applyForcesPipeline");
|
|
204
259
|
}
|
|
205
260
|
async init() {
|
|
206
|
-
const t = await
|
|
261
|
+
const t = await Q();
|
|
207
262
|
if (!t) throw new Error("WebGPU not supported");
|
|
208
|
-
this.device =
|
|
263
|
+
this.device = t, this.sgdPipeline = this.device.createComputePipeline({
|
|
264
|
+
layout: "auto",
|
|
265
|
+
compute: {
|
|
266
|
+
module: this.device.createShaderModule({ code: oe }),
|
|
267
|
+
entryPoint: "main"
|
|
268
|
+
}
|
|
269
|
+
}), this.applyForcesPipeline = this.device.createComputePipeline({
|
|
209
270
|
layout: "auto",
|
|
210
271
|
compute: {
|
|
211
|
-
module: this.device.createShaderModule({ code:
|
|
272
|
+
module: this.device.createShaderModule({ code: ce }),
|
|
212
273
|
entryPoint: "main"
|
|
213
274
|
}
|
|
214
275
|
});
|
|
@@ -226,222 +287,254 @@ class j {
|
|
|
226
287
|
* @param params - UMAP curve parameters and repulsion settings
|
|
227
288
|
* @returns Optimized embedding as Float32Array
|
|
228
289
|
*/
|
|
229
|
-
async optimize(t,
|
|
230
|
-
const { device:
|
|
290
|
+
async optimize(t, f, a, s, u, c, d, o, n) {
|
|
291
|
+
const { device: r } = this, i = f.length, h = u * c, l = this.makeBuffer(
|
|
231
292
|
t,
|
|
232
293
|
GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
|
|
233
|
-
),
|
|
234
|
-
for (let
|
|
235
|
-
|
|
236
|
-
const
|
|
237
|
-
for (let
|
|
238
|
-
A
|
|
239
|
-
const
|
|
294
|
+
), p = this.makeBuffer(f, GPUBufferUsage.STORAGE), g = this.makeBuffer(a, GPUBufferUsage.STORAGE), _ = this.makeBuffer(s, GPUBufferUsage.STORAGE), m = new Float32Array(s), b = this.makeBuffer(m, GPUBufferUsage.STORAGE), y = new Float32Array(i);
|
|
295
|
+
for (let A = 0; A < i; A++)
|
|
296
|
+
y[A] = s[A] / o.negativeSampleRate;
|
|
297
|
+
const x = this.makeBuffer(y, GPUBufferUsage.STORAGE), v = new Uint32Array(i);
|
|
298
|
+
for (let A = 0; A < i; A++)
|
|
299
|
+
v[A] = Math.random() * 4294967295 | 0;
|
|
300
|
+
const G = this.makeBuffer(v, GPUBufferUsage.STORAGE), P = r.createBuffer({
|
|
301
|
+
size: h * 4,
|
|
302
|
+
usage: GPUBufferUsage.STORAGE,
|
|
303
|
+
mappedAtCreation: !0
|
|
304
|
+
});
|
|
305
|
+
new Int32Array(P.getMappedRange()).fill(0), P.unmap();
|
|
306
|
+
const U = r.createBuffer({
|
|
240
307
|
size: 40,
|
|
241
308
|
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
|
|
309
|
+
}), F = r.createBuffer({
|
|
310
|
+
size: 16,
|
|
311
|
+
usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
|
|
312
|
+
}), N = r.createBindGroup({
|
|
313
|
+
layout: this.sgdPipeline.getBindGroupLayout(0),
|
|
314
|
+
entries: [
|
|
315
|
+
{ binding: 0, resource: { buffer: _ } },
|
|
316
|
+
{ binding: 1, resource: { buffer: p } },
|
|
317
|
+
{ binding: 2, resource: { buffer: g } },
|
|
318
|
+
{ binding: 3, resource: { buffer: l } },
|
|
319
|
+
{ binding: 4, resource: { buffer: b } },
|
|
320
|
+
{ binding: 5, resource: { buffer: x } },
|
|
321
|
+
{ binding: 6, resource: { buffer: U } },
|
|
322
|
+
{ binding: 7, resource: { buffer: G } },
|
|
323
|
+
{ binding: 8, resource: { buffer: P } }
|
|
324
|
+
]
|
|
325
|
+
}), M = r.createBindGroup({
|
|
326
|
+
layout: this.applyForcesPipeline.getBindGroupLayout(0),
|
|
327
|
+
entries: [
|
|
328
|
+
{ binding: 0, resource: { buffer: l } },
|
|
329
|
+
{ binding: 1, resource: { buffer: P } },
|
|
330
|
+
{ binding: 2, resource: { buffer: F } }
|
|
331
|
+
]
|
|
242
332
|
});
|
|
243
|
-
for (let
|
|
244
|
-
const
|
|
245
|
-
|
|
246
|
-
const
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
{ binding: 3, resource: { buffer: h } },
|
|
253
|
-
{ binding: 4, resource: { buffer: _ } },
|
|
254
|
-
{ binding: 5, resource: { buffer: b } },
|
|
255
|
-
{ binding: 6, resource: { buffer: U } },
|
|
256
|
-
{ binding: 7, resource: { buffer: x } }
|
|
257
|
-
]
|
|
258
|
-
}), E = o.createCommandEncoder(), P = E.beginComputePass();
|
|
259
|
-
P.setPipeline(this.pipeline), P.setBindGroup(0, k), P.dispatchWorkgroups(Math.ceil(i / 256)), P.end(), o.queue.submit([E.finish()]), M % 10 === 0 && (await o.queue.onSubmittedWorkDone(), a == null || a(M, d));
|
|
333
|
+
for (let A = 0; A < d; A++) {
|
|
334
|
+
const B = 1 - A / d, O = new ArrayBuffer(40), S = new Uint32Array(O), k = new Float32Array(O);
|
|
335
|
+
S[0] = i, S[1] = u, S[2] = c, S[3] = A, S[4] = d, k[5] = B, k[6] = o.a, k[7] = o.b, k[8] = o.gamma, S[9] = o.negativeSampleRate, r.queue.writeBuffer(U, 0, O);
|
|
336
|
+
const D = new ArrayBuffer(16), $ = new Uint32Array(D), ee = new Float32Array(D);
|
|
337
|
+
$[0] = h, ee[1] = B, r.queue.writeBuffer(F, 0, D);
|
|
338
|
+
const T = r.createCommandEncoder(), q = T.beginComputePass();
|
|
339
|
+
q.setPipeline(this.sgdPipeline), q.setBindGroup(0, N), q.dispatchWorkgroups(Math.ceil(i / 256)), q.end();
|
|
340
|
+
const L = T.beginComputePass();
|
|
341
|
+
L.setPipeline(this.applyForcesPipeline), L.setBindGroup(0, M), L.dispatchWorkgroups(Math.ceil(h / 256)), L.end(), r.queue.submit([T.finish()]), A % 10 === 0 && (await r.queue.onSubmittedWorkDone(), n == null || n(A, d));
|
|
260
342
|
}
|
|
261
|
-
const
|
|
343
|
+
const E = r.createBuffer({
|
|
262
344
|
size: t.byteLength,
|
|
263
345
|
usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ
|
|
264
|
-
}),
|
|
265
|
-
|
|
266
|
-
const
|
|
267
|
-
return
|
|
346
|
+
}), w = r.createCommandEncoder();
|
|
347
|
+
w.copyBufferToBuffer(l, 0, E, 0, t.byteLength), r.queue.submit([w.finish()]), await E.mapAsync(GPUMapMode.READ);
|
|
348
|
+
const C = new Float32Array(E.getMappedRange().slice(0));
|
|
349
|
+
return E.unmap(), l.destroy(), p.destroy(), g.destroy(), _.destroy(), b.destroy(), x.destroy(), G.destroy(), P.destroy(), U.destroy(), F.destroy(), E.destroy(), C;
|
|
268
350
|
}
|
|
269
|
-
makeBuffer(t,
|
|
270
|
-
const
|
|
351
|
+
makeBuffer(t, f) {
|
|
352
|
+
const a = this.device.createBuffer({
|
|
271
353
|
size: t.byteLength,
|
|
272
|
-
usage:
|
|
354
|
+
usage: f,
|
|
273
355
|
mappedAtCreation: !0
|
|
274
356
|
});
|
|
275
|
-
return t instanceof Float32Array ? new Float32Array(
|
|
357
|
+
return t instanceof Float32Array ? new Float32Array(a.getMappedRange()).set(t) : new Uint32Array(a.getMappedRange()).set(t), a.unmap(), a;
|
|
276
358
|
}
|
|
277
359
|
}
|
|
278
|
-
function
|
|
360
|
+
function I(e) {
|
|
279
361
|
return Math.max(-4, Math.min(4, e));
|
|
280
362
|
}
|
|
281
|
-
function
|
|
282
|
-
const { a:
|
|
283
|
-
for (let
|
|
284
|
-
m
|
|
285
|
-
for (let
|
|
286
|
-
d == null || d(
|
|
287
|
-
const
|
|
288
|
-
for (let
|
|
289
|
-
if (
|
|
290
|
-
const
|
|
291
|
-
let
|
|
292
|
-
for (let
|
|
293
|
-
const
|
|
294
|
-
|
|
363
|
+
function j(e, t, f, a, s, u, c, d) {
|
|
364
|
+
const { a: o, b: n, gamma: r = 1, negativeSampleRate: i = 5 } = c, h = t.rows.length, l = new Uint32Array(t.rows), p = new Uint32Array(t.cols), g = new Float32Array(f), _ = new Float32Array(h);
|
|
365
|
+
for (let m = 0; m < h; m++)
|
|
366
|
+
_[m] = f[m] / i;
|
|
367
|
+
for (let m = 0; m < u; m++) {
|
|
368
|
+
d == null || d(m, u);
|
|
369
|
+
const b = 1 - m / u;
|
|
370
|
+
for (let y = 0; y < h; y++) {
|
|
371
|
+
if (g[y] > m) continue;
|
|
372
|
+
const x = l[y], v = p[y];
|
|
373
|
+
let G = 0;
|
|
374
|
+
for (let M = 0; M < s; M++) {
|
|
375
|
+
const E = e[x * s + M] - e[v * s + M];
|
|
376
|
+
G += E * E;
|
|
295
377
|
}
|
|
296
|
-
const
|
|
297
|
-
for (let
|
|
298
|
-
const
|
|
299
|
-
e[
|
|
378
|
+
const P = Math.pow(G, n), U = -2 * o * n * (G > 0 ? P / G : 0) / (o * P + 1);
|
|
379
|
+
for (let M = 0; M < s; M++) {
|
|
380
|
+
const E = e[x * s + M] - e[v * s + M], w = I(U * E);
|
|
381
|
+
e[x * s + M] += b * w, e[v * s + M] -= b * w;
|
|
300
382
|
}
|
|
301
|
-
|
|
302
|
-
const
|
|
303
|
-
(
|
|
383
|
+
g[y] += f[y];
|
|
384
|
+
const F = f[y] / i, N = Math.max(0, Math.floor(
|
|
385
|
+
(m - _[y]) / F
|
|
304
386
|
));
|
|
305
|
-
|
|
306
|
-
for (let
|
|
307
|
-
const
|
|
308
|
-
if (
|
|
309
|
-
let
|
|
310
|
-
for (let
|
|
311
|
-
const
|
|
312
|
-
|
|
387
|
+
_[y] += N * F;
|
|
388
|
+
for (let M = 0; M < N; M++) {
|
|
389
|
+
const E = Math.floor(Math.random() * a);
|
|
390
|
+
if (E === x) continue;
|
|
391
|
+
let w = 0;
|
|
392
|
+
for (let B = 0; B < s; B++) {
|
|
393
|
+
const O = e[x * s + B] - e[E * s + B];
|
|
394
|
+
w += O * O;
|
|
313
395
|
}
|
|
314
|
-
const
|
|
315
|
-
for (let
|
|
316
|
-
const
|
|
317
|
-
e[
|
|
396
|
+
const C = Math.pow(w, n), A = 2 * r * n / ((1e-3 + w) * (o * C + 1));
|
|
397
|
+
for (let B = 0; B < s; B++) {
|
|
398
|
+
const O = e[x * s + B] - e[E * s + B], S = I(A * O);
|
|
399
|
+
e[x * s + B] += b * S;
|
|
318
400
|
}
|
|
319
401
|
}
|
|
320
402
|
}
|
|
321
403
|
}
|
|
322
404
|
return e;
|
|
323
405
|
}
|
|
324
|
-
function
|
|
325
|
-
const { a:
|
|
326
|
-
for (let
|
|
327
|
-
|
|
328
|
-
for (let
|
|
329
|
-
const
|
|
330
|
-
for (let
|
|
331
|
-
if (
|
|
332
|
-
const
|
|
333
|
-
let
|
|
334
|
-
for (let
|
|
335
|
-
const
|
|
336
|
-
|
|
406
|
+
function fe(e, t, f, a, s, u, c, d, o, n) {
|
|
407
|
+
const { a: r, b: i, gamma: h = 1, negativeSampleRate: l = 5 } = o, p = f.rows.length, g = new Uint32Array(f.rows), _ = new Uint32Array(f.cols), m = new Float32Array(a), b = new Float32Array(p);
|
|
408
|
+
for (let y = 0; y < p; y++)
|
|
409
|
+
b[y] = a[y] / l;
|
|
410
|
+
for (let y = 0; y < d; y++) {
|
|
411
|
+
const x = 1 - y / d;
|
|
412
|
+
for (let v = 0; v < p; v++) {
|
|
413
|
+
if (m[v] > y) continue;
|
|
414
|
+
const G = g[v], P = _[v];
|
|
415
|
+
let U = 0;
|
|
416
|
+
for (let w = 0; w < c; w++) {
|
|
417
|
+
const C = e[G * c + w] - t[P * c + w];
|
|
418
|
+
U += C * C;
|
|
337
419
|
}
|
|
338
|
-
const
|
|
339
|
-
for (let
|
|
340
|
-
const
|
|
341
|
-
e[
|
|
420
|
+
const F = Math.pow(U, i), N = -2 * r * i * (U > 0 ? F / U : 0) / (r * F + 1);
|
|
421
|
+
for (let w = 0; w < c; w++) {
|
|
422
|
+
const C = e[G * c + w] - t[P * c + w];
|
|
423
|
+
e[G * c + w] += x * I(N * C);
|
|
342
424
|
}
|
|
343
|
-
|
|
344
|
-
const
|
|
345
|
-
(
|
|
425
|
+
m[v] += a[v];
|
|
426
|
+
const M = a[v] / l, E = Math.max(0, Math.floor(
|
|
427
|
+
(y - b[v]) / M
|
|
346
428
|
));
|
|
347
|
-
|
|
348
|
-
for (let
|
|
349
|
-
const
|
|
350
|
-
if (
|
|
351
|
-
let
|
|
352
|
-
for (let
|
|
353
|
-
const
|
|
354
|
-
|
|
429
|
+
b[v] += E * M;
|
|
430
|
+
for (let w = 0; w < E; w++) {
|
|
431
|
+
const C = Math.floor(Math.random() * u);
|
|
432
|
+
if (C === P) continue;
|
|
433
|
+
let A = 0;
|
|
434
|
+
for (let S = 0; S < c; S++) {
|
|
435
|
+
const k = e[G * c + S] - t[C * c + S];
|
|
436
|
+
A += k * k;
|
|
355
437
|
}
|
|
356
|
-
const
|
|
357
|
-
for (let
|
|
358
|
-
const
|
|
359
|
-
e[
|
|
438
|
+
const B = Math.pow(A, i), O = 2 * h * i / ((1e-3 + A) * (r * B + 1));
|
|
439
|
+
for (let S = 0; S < c; S++) {
|
|
440
|
+
const k = e[G * c + S] - t[C * c + S];
|
|
441
|
+
e[G * c + S] += x * I(O * k);
|
|
360
442
|
}
|
|
361
443
|
}
|
|
362
444
|
}
|
|
363
445
|
}
|
|
364
446
|
return e;
|
|
365
447
|
}
|
|
366
|
-
function
|
|
367
|
-
return typeof navigator < "u" && !!navigator.gpu;
|
|
368
|
-
}
|
|
369
|
-
async function re(e, t = {}, r) {
|
|
448
|
+
async function pe(e, t = {}, f) {
|
|
370
449
|
const {
|
|
371
|
-
nComponents:
|
|
372
|
-
nNeighbors:
|
|
373
|
-
minDist:
|
|
450
|
+
nComponents: a = 2,
|
|
451
|
+
nNeighbors: s = 15,
|
|
452
|
+
minDist: u = 0.1,
|
|
374
453
|
spread: c = 1,
|
|
375
454
|
hnsw: d = {}
|
|
376
|
-
} = t,
|
|
455
|
+
} = t, o = t.nEpochs ?? (e.length > 1e4 ? 200 : 500);
|
|
377
456
|
console.time("knn");
|
|
378
|
-
const { indices:
|
|
457
|
+
const { indices: n, distances: r } = await se(e, s, {
|
|
379
458
|
M: d.M ?? 16,
|
|
380
459
|
efConstruction: d.efConstruction ?? 200,
|
|
381
460
|
efSearch: d.efSearch ?? 50
|
|
382
461
|
});
|
|
383
462
|
console.timeEnd("knn"), console.time("fuzzy-set");
|
|
384
|
-
const i =
|
|
463
|
+
const i = V(n, r, s);
|
|
385
464
|
console.timeEnd("fuzzy-set");
|
|
386
|
-
const { a: h, b: l } =
|
|
387
|
-
for (let
|
|
388
|
-
|
|
465
|
+
const { a: h, b: l } = Z(u, c), p = W(i.vals), g = e.length, _ = new Float32Array(g * a);
|
|
466
|
+
for (let b = 0; b < _.length; b++)
|
|
467
|
+
_[b] = Math.random() * 20 - 10;
|
|
389
468
|
console.time("sgd");
|
|
390
|
-
let
|
|
391
|
-
if (
|
|
469
|
+
let m;
|
|
470
|
+
if (J())
|
|
392
471
|
try {
|
|
393
|
-
const
|
|
394
|
-
await
|
|
395
|
-
|
|
472
|
+
const b = new X();
|
|
473
|
+
await b.init(), m = await b.optimize(
|
|
474
|
+
_,
|
|
396
475
|
new Uint32Array(i.rows),
|
|
397
476
|
new Uint32Array(i.cols),
|
|
398
|
-
u,
|
|
399
477
|
p,
|
|
400
|
-
|
|
401
|
-
|
|
478
|
+
g,
|
|
479
|
+
a,
|
|
480
|
+
o,
|
|
402
481
|
{ a: h, b: l, gamma: 1, negativeSampleRate: 5 },
|
|
403
|
-
|
|
482
|
+
f
|
|
404
483
|
);
|
|
405
|
-
} catch (
|
|
406
|
-
console.warn("WebGPU SGD failed, falling back to CPU:",
|
|
484
|
+
} catch (b) {
|
|
485
|
+
console.warn("WebGPU SGD failed, falling back to CPU:", b), m = j(_, i, p, g, a, o, { a: h, b: l }, f);
|
|
407
486
|
}
|
|
408
487
|
else
|
|
409
|
-
|
|
410
|
-
return console.timeEnd("sgd"),
|
|
488
|
+
m = j(_, i, p, g, a, o, { a: h, b: l }, f);
|
|
489
|
+
return console.timeEnd("sgd"), m;
|
|
411
490
|
}
|
|
412
|
-
function
|
|
413
|
-
|
|
414
|
-
return { a: 1.9292, b: 0.7915 };
|
|
415
|
-
if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6)
|
|
416
|
-
return { a: 1.8956, b: 0.8006 };
|
|
417
|
-
if (Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6)
|
|
418
|
-
return { a: 1.5769, b: 0.8951 };
|
|
419
|
-
const r = te(e, t);
|
|
420
|
-
return { a: ne(e, t, r), b: r };
|
|
491
|
+
function Z(e, t) {
|
|
492
|
+
return Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.1) < 1e-6 ? { a: 1.9292, b: 0.7915 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0) < 1e-6 ? { a: 1.8956, b: 0.8006 } : Math.abs(t - 1) < 1e-6 && Math.abs(e - 0.5) < 1e-6 ? { a: 1.5769, b: 0.8951 } : de(e, t);
|
|
421
493
|
}
|
|
422
|
-
function
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
494
|
+
function de(e, t) {
|
|
495
|
+
const a = [], s = [];
|
|
496
|
+
for (let o = 0; o < 299; o++) {
|
|
497
|
+
const n = (o + 1) / 299 * t * 3;
|
|
498
|
+
a.push(n), s.push(n < e ? 1 : Math.exp(-(n - e) / t));
|
|
499
|
+
}
|
|
500
|
+
let u = 1, c = 1, d = 1e-3;
|
|
501
|
+
for (let o = 0; o < 500; o++) {
|
|
502
|
+
let n = 0, r = 0, i = 0, h = 0, l = 0, p = 0;
|
|
503
|
+
for (let U = 0; U < 299; U++) {
|
|
504
|
+
const F = a[U], N = Math.pow(F, 2 * c), M = 1 + u * N, w = 1 / M - s[U];
|
|
505
|
+
p += w * w;
|
|
506
|
+
const C = M * M, A = -N / C, B = F > 0 ? -2 * Math.log(F) * u * N / C : 0;
|
|
507
|
+
n += A * w, r += B * w, i += A * A, h += B * B, l += A * B;
|
|
508
|
+
}
|
|
509
|
+
const g = i + d, _ = h + d, m = l, b = g * _ - m * m;
|
|
510
|
+
if (Math.abs(b) < 1e-20) break;
|
|
511
|
+
const y = -(_ * n - m * r) / b, x = -(g * r - m * n) / b, v = Math.max(1e-4, u + y), G = Math.max(1e-4, c + x);
|
|
512
|
+
let P = 0;
|
|
513
|
+
for (let U = 0; U < 299; U++) {
|
|
514
|
+
const F = Math.pow(a[U], 2 * G), N = 1 / (1 + v * F) - s[U];
|
|
515
|
+
P += N * N;
|
|
516
|
+
}
|
|
517
|
+
if (P < p ? (u = v, c = G, d = Math.max(1e-10, d / 10)) : d = Math.min(1e10, d * 10), Math.abs(y) < 1e-8 && Math.abs(x) < 1e-8) break;
|
|
518
|
+
}
|
|
519
|
+
return { a: u, b: c };
|
|
427
520
|
}
|
|
428
|
-
class
|
|
521
|
+
class ge {
|
|
429
522
|
constructor(t = {}) {
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
523
|
+
R(this, "_nComponents");
|
|
524
|
+
R(this, "_nNeighbors");
|
|
525
|
+
R(this, "_minDist");
|
|
526
|
+
R(this, "_spread");
|
|
527
|
+
R(this, "_nEpochs");
|
|
528
|
+
R(this, "_hnswOpts");
|
|
529
|
+
R(this, "_a");
|
|
530
|
+
R(this, "_b");
|
|
438
531
|
/** The low-dimensional embedding produced by the last fit() call. */
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
532
|
+
R(this, "embedding", null);
|
|
533
|
+
R(this, "_hnswIndex", null);
|
|
534
|
+
R(this, "_nTrain", 0);
|
|
442
535
|
this._nComponents = t.nComponents ?? 2, this._nNeighbors = t.nNeighbors ?? 15, this._minDist = t.minDist ?? 0.1, this._spread = t.spread ?? 1, this._nEpochs = t.nEpochs, this._hnswOpts = t.hnsw ?? {};
|
|
443
|
-
const { a:
|
|
444
|
-
this._a =
|
|
536
|
+
const { a: f, b: a } = Z(this._minDist, this._spread);
|
|
537
|
+
this._a = f, this._b = a;
|
|
445
538
|
}
|
|
446
539
|
/**
|
|
447
540
|
* Train UMAP on `vectors`.
|
|
@@ -449,45 +542,45 @@ class ie {
|
|
|
449
542
|
* index so that transform() can project new points later.
|
|
450
543
|
* Returns `this` for chaining.
|
|
451
544
|
*/
|
|
452
|
-
async fit(t,
|
|
453
|
-
const
|
|
545
|
+
async fit(t, f) {
|
|
546
|
+
const a = t.length, s = this._nEpochs ?? (a > 1e4 ? 200 : 500), { M: u = 16, efConstruction: c = 200, efSearch: d = 50 } = this._hnswOpts;
|
|
454
547
|
console.time("knn");
|
|
455
|
-
const { knn:
|
|
456
|
-
M:
|
|
548
|
+
const { knn: o, index: n } = await ae(t, this._nNeighbors, {
|
|
549
|
+
M: u,
|
|
457
550
|
efConstruction: c,
|
|
458
551
|
efSearch: d
|
|
459
552
|
});
|
|
460
|
-
this._hnswIndex =
|
|
461
|
-
const
|
|
553
|
+
this._hnswIndex = n, this._nTrain = a, console.timeEnd("knn"), console.time("fuzzy-set");
|
|
554
|
+
const r = V(o.indices, o.distances, this._nNeighbors);
|
|
462
555
|
console.timeEnd("fuzzy-set");
|
|
463
|
-
const i =
|
|
556
|
+
const i = W(r.vals), h = new Float32Array(a * this._nComponents);
|
|
464
557
|
for (let l = 0; l < h.length; l++)
|
|
465
558
|
h[l] = Math.random() * 20 - 10;
|
|
466
|
-
if (console.time("sgd"),
|
|
559
|
+
if (console.time("sgd"), J())
|
|
467
560
|
try {
|
|
468
|
-
const l = new
|
|
561
|
+
const l = new X();
|
|
469
562
|
await l.init(), this.embedding = await l.optimize(
|
|
470
563
|
h,
|
|
471
|
-
new Uint32Array(
|
|
472
|
-
new Uint32Array(
|
|
564
|
+
new Uint32Array(r.rows),
|
|
565
|
+
new Uint32Array(r.cols),
|
|
473
566
|
i,
|
|
474
|
-
|
|
567
|
+
a,
|
|
475
568
|
this._nComponents,
|
|
476
|
-
|
|
569
|
+
s,
|
|
477
570
|
{ a: this._a, b: this._b, gamma: 1, negativeSampleRate: 5 },
|
|
478
|
-
|
|
571
|
+
f
|
|
479
572
|
);
|
|
480
573
|
} catch (l) {
|
|
481
|
-
console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding =
|
|
574
|
+
console.warn("WebGPU SGD failed, falling back to CPU:", l), this.embedding = j(h, r, i, a, this._nComponents, s, {
|
|
482
575
|
a: this._a,
|
|
483
576
|
b: this._b
|
|
484
|
-
},
|
|
577
|
+
}, f);
|
|
485
578
|
}
|
|
486
579
|
else
|
|
487
|
-
this.embedding =
|
|
580
|
+
this.embedding = j(h, r, i, a, this._nComponents, s, {
|
|
488
581
|
a: this._a,
|
|
489
582
|
b: this._b
|
|
490
|
-
},
|
|
583
|
+
}, f);
|
|
491
584
|
return console.timeEnd("sgd"), this;
|
|
492
585
|
}
|
|
493
586
|
/**
|
|
@@ -501,35 +594,35 @@ class ie {
|
|
|
501
594
|
* returned embedding to [0, 1]. The stored training embedding is never
|
|
502
595
|
* mutated. Defaults to `false`.
|
|
503
596
|
*/
|
|
504
|
-
async transform(t,
|
|
597
|
+
async transform(t, f = !1) {
|
|
505
598
|
if (!this._hnswIndex || !this.embedding)
|
|
506
599
|
throw new Error("UMAP.transform() must be called after fit()");
|
|
507
|
-
const
|
|
508
|
-
for (let
|
|
509
|
-
const
|
|
510
|
-
|
|
511
|
-
for (let
|
|
512
|
-
i[
|
|
600
|
+
const a = t.length, s = this._nEpochs ?? (this._nTrain > 1e4 ? 200 : 500), u = Math.max(100, Math.floor(s / 4)), c = this._hnswIndex.searchKnn(t, this._nNeighbors), d = re(c.indices, c.distances, this._nNeighbors), o = new Uint32Array(d.rows), n = new Uint32Array(d.cols), r = new Float32Array(a), i = new Float32Array(a * this._nComponents);
|
|
601
|
+
for (let p = 0; p < o.length; p++) {
|
|
602
|
+
const g = o[p], _ = n[p], m = d.vals[p];
|
|
603
|
+
r[g] += m;
|
|
604
|
+
for (let b = 0; b < this._nComponents; b++)
|
|
605
|
+
i[g * this._nComponents + b] += m * this.embedding[_ * this._nComponents + b];
|
|
513
606
|
}
|
|
514
|
-
for (let
|
|
515
|
-
if (
|
|
516
|
-
for (let
|
|
517
|
-
i[
|
|
607
|
+
for (let p = 0; p < a; p++)
|
|
608
|
+
if (r[p] > 0)
|
|
609
|
+
for (let g = 0; g < this._nComponents; g++)
|
|
610
|
+
i[p * this._nComponents + g] /= r[p];
|
|
518
611
|
else
|
|
519
|
-
for (let
|
|
520
|
-
i[
|
|
521
|
-
const h =
|
|
612
|
+
for (let g = 0; g < this._nComponents; g++)
|
|
613
|
+
i[p * this._nComponents + g] = Math.random() * 20 - 10;
|
|
614
|
+
const h = W(d.vals), l = fe(
|
|
522
615
|
i,
|
|
523
616
|
this.embedding,
|
|
524
617
|
d,
|
|
525
618
|
h,
|
|
526
|
-
|
|
619
|
+
a,
|
|
527
620
|
this._nTrain,
|
|
528
621
|
this._nComponents,
|
|
529
|
-
|
|
622
|
+
u,
|
|
530
623
|
{ a: this._a, b: this._b }
|
|
531
624
|
);
|
|
532
|
-
return
|
|
625
|
+
return f ? K(l, a, this._nComponents) : l;
|
|
533
626
|
}
|
|
534
627
|
/**
|
|
535
628
|
* Convenience method equivalent to `fit(vectors)` followed by
|
|
@@ -540,37 +633,38 @@ class ie {
|
|
|
540
633
|
* returned embedding to [0, 1]. `this.embedding` is never mutated.
|
|
541
634
|
* Defaults to `false`.
|
|
542
635
|
*/
|
|
543
|
-
async fit_transform(t,
|
|
544
|
-
return await this.fit(t,
|
|
636
|
+
async fit_transform(t, f, a = !1) {
|
|
637
|
+
return await this.fit(t, f), a ? K(this.embedding, t.length, this._nComponents) : this.embedding;
|
|
545
638
|
}
|
|
546
639
|
}
|
|
547
|
-
function
|
|
548
|
-
const
|
|
549
|
-
for (let
|
|
550
|
-
let
|
|
551
|
-
for (let
|
|
552
|
-
const
|
|
553
|
-
|
|
640
|
+
function K(e, t, f) {
|
|
641
|
+
const a = new Float32Array(e.length);
|
|
642
|
+
for (let s = 0; s < f; s++) {
|
|
643
|
+
let u = 1 / 0, c = -1 / 0;
|
|
644
|
+
for (let o = 0; o < t; o++) {
|
|
645
|
+
const n = e[o * f + s];
|
|
646
|
+
n < u && (u = n), n > c && (c = n);
|
|
554
647
|
}
|
|
555
|
-
const d = c -
|
|
556
|
-
for (let
|
|
557
|
-
|
|
648
|
+
const d = c - u;
|
|
649
|
+
for (let o = 0; o < t; o++)
|
|
650
|
+
a[o * f + s] = d > 0 ? (e[o * f + s] - u) / d : 0;
|
|
558
651
|
}
|
|
559
|
-
return
|
|
652
|
+
return a;
|
|
560
653
|
}
|
|
561
|
-
function
|
|
562
|
-
let
|
|
563
|
-
for (let
|
|
564
|
-
e[
|
|
565
|
-
const
|
|
566
|
-
for (let
|
|
567
|
-
const
|
|
568
|
-
s
|
|
654
|
+
function W(e, t) {
|
|
655
|
+
let f = -1 / 0;
|
|
656
|
+
for (let s = 0; s < e.length; s++)
|
|
657
|
+
e[s] > f && (f = e[s]);
|
|
658
|
+
const a = new Float32Array(e.length);
|
|
659
|
+
for (let s = 0; s < e.length; s++) {
|
|
660
|
+
const u = e[s] / f;
|
|
661
|
+
a[s] = u > 0 ? 1 / u : -1;
|
|
569
662
|
}
|
|
570
|
-
return
|
|
663
|
+
return a;
|
|
571
664
|
}
|
|
572
665
|
export {
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
666
|
+
ge as UMAP,
|
|
667
|
+
he as checkWebGPUAvailable,
|
|
668
|
+
pe as fit,
|
|
669
|
+
J as isWebGPUAvailable
|
|
576
670
|
};
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "umap-gpu",
|
|
3
|
-
"version": "0.2.
|
|
3
|
+
"version": "0.2.15",
|
|
4
4
|
"description": "UMAP with HNSW kNN and WebGPU-accelerated SGD",
|
|
5
5
|
"type": "module",
|
|
6
6
|
"main": "dist/index.js",
|
|
@@ -55,7 +55,8 @@
|
|
|
55
55
|
"vite": "^5.0.0",
|
|
56
56
|
"vitepress": "^1.6.4",
|
|
57
57
|
"vitepress-plugin-llms": "^1.11.0",
|
|
58
|
-
"vitest": "^4.0.18"
|
|
58
|
+
"vitest": "^4.0.18",
|
|
59
|
+
"webgpu": "^0.3.8"
|
|
59
60
|
},
|
|
60
61
|
"license": "MIT"
|
|
61
62
|
}
|