@jax-js/jax 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -2
- package/dist/{backend-CmaidnkQ.cjs → backend-Bu9GY6sK.cjs} +166 -18
- package/dist/{backend-BY8wlLEl.js → backend-tngXtWe4.js} +148 -18
- package/dist/index.cjs +1683 -1004
- package/dist/index.d.cts +365 -95
- package/dist/index.d.ts +365 -95
- package/dist/index.js +1675 -997
- package/dist/{webgpu-C9iAP5h5.js → webgpu-ChVgx3b6.js} +400 -95
- package/dist/{webgpu-BVns4DbI.cjs → webgpu-Oj3Kd-kd.cjs} +400 -95
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
<p align="center"><strong>
|
|
4
4
|
<a href="https://jax-js.com">Website</a> |
|
|
5
5
|
<a href="https://jax-js.com/docs/">API Reference</a> |
|
|
6
|
-
<a href="./FEATURES.md">Compatibility Table</a>
|
|
6
|
+
<a href="./FEATURES.md">Compatibility Table</a> |
|
|
7
|
+
<a href="https://discord.gg/BW6YsCd4Tf">Discord</a>
|
|
7
8
|
</strong></p>
|
|
8
9
|
|
|
9
10
|
**jax-js** is a machine learning framework for the browser. It aims to bring
|
|
@@ -323,7 +324,7 @@ pnpm -C website dev
|
|
|
323
324
|
|
|
324
325
|
## Future work / help wanted
|
|
325
326
|
|
|
326
|
-
Contributions are welcomed! Especially in:
|
|
327
|
+
Contributions are welcomed! Some fruitful areas to look into:
|
|
327
328
|
|
|
328
329
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
329
330
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
@@ -334,3 +335,5 @@ Contributions are welcomed! Especially in:
|
|
|
334
335
|
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
335
336
|
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
336
337
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
338
|
+
|
|
339
|
+
You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
|
|
@@ -69,6 +69,9 @@ function zipn(...arrays) {
|
|
|
69
69
|
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
70
70
|
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
71
71
|
}
|
|
72
|
+
function sorted(arr) {
|
|
73
|
+
return [...arr].sort((a, b) => a - b);
|
|
74
|
+
}
|
|
72
75
|
function rep(length, value) {
|
|
73
76
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
74
77
|
return new Array(length).fill(value);
|
|
@@ -145,7 +148,7 @@ function normalizeAxis(axis, ndim) {
|
|
|
145
148
|
if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
|
|
146
149
|
seen.add(ca);
|
|
147
150
|
}
|
|
148
|
-
return
|
|
151
|
+
return sorted(seen);
|
|
149
152
|
}
|
|
150
153
|
}
|
|
151
154
|
function range(start, stop, step = 1) {
|
|
@@ -1328,7 +1331,7 @@ var Reduction = class {
|
|
|
1328
1331
|
/** Evaluate this operation on CPU. */
|
|
1329
1332
|
evaluate(...values) {
|
|
1330
1333
|
if (this.dtype === DType.Bool) {
|
|
1331
|
-
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b,
|
|
1334
|
+
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
|
|
1332
1335
|
else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
|
|
1333
1336
|
} else if (this.dtype === DType.Int32) {
|
|
1334
1337
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
|
|
@@ -1438,6 +1441,112 @@ function erfc(x) {
|
|
|
1438
1441
|
else return 2 - _erfapprox$1(-x);
|
|
1439
1442
|
}
|
|
1440
1443
|
|
|
1444
|
+
//#endregion
|
|
1445
|
+
//#region src/routine.ts
|
|
1446
|
+
/**
|
|
1447
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
1448
|
+
*
|
|
1449
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
1450
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
1451
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
1452
|
+
* relevant.
|
|
1453
|
+
*
|
|
1454
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
1455
|
+
* which each backend implements in a specific way. These are listed in the
|
|
1456
|
+
* `Routines` enum below.
|
|
1457
|
+
*
|
|
1458
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
1459
|
+
* arrays (default `ShapeTracker`).
|
|
1460
|
+
*/
|
|
1461
|
+
var Routine = class {
|
|
1462
|
+
constructor(name, type, params) {
|
|
1463
|
+
this.name = name;
|
|
1464
|
+
this.type = type;
|
|
1465
|
+
this.params = params;
|
|
1466
|
+
}
|
|
1467
|
+
};
|
|
1468
|
+
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1469
|
+
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1470
|
+
/** Stable sorting algorithm along the last axis. */
|
|
1471
|
+
Routines$1["Sort"] = "Sort";
|
|
1472
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
1473
|
+
Routines$1["Argsort"] = "Argsort";
|
|
1474
|
+
/** Solve a triangular system of questions. */
|
|
1475
|
+
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1476
|
+
/** Cholesky decomposition of 2D positive semi-definite matrices. */
|
|
1477
|
+
Routines$1["Cholesky"] = "Cholesky";
|
|
1478
|
+
return Routines$1;
|
|
1479
|
+
}({});
|
|
1480
|
+
function runCpuRoutine(routine, inputs, outputs) {
|
|
1481
|
+
const { name, type } = routine;
|
|
1482
|
+
const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
|
|
1483
|
+
const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
|
|
1484
|
+
switch (name) {
|
|
1485
|
+
case Routines.Sort: return runSort(type, inputAr, outputAr);
|
|
1486
|
+
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1487
|
+
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1488
|
+
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1489
|
+
default:
|
|
1490
|
+
}
|
|
1491
|
+
}
|
|
1492
|
+
function runSort(type, [x], [y]) {
|
|
1493
|
+
const xs = type.inputShapes[0];
|
|
1494
|
+
if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
|
|
1495
|
+
const n = xs[xs.length - 1];
|
|
1496
|
+
y.set(x);
|
|
1497
|
+
for (let i = 0; i < y.length; i += n) y.subarray(i, i + n).sort();
|
|
1498
|
+
}
|
|
1499
|
+
function runArgsort(type, [x], [y, yi]) {
|
|
1500
|
+
const xs = type.inputShapes[0];
|
|
1501
|
+
if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
|
|
1502
|
+
const n = xs[xs.length - 1];
|
|
1503
|
+
for (let offset = 0; offset < y.length; offset += n) {
|
|
1504
|
+
const ar = x.subarray(offset, offset + n);
|
|
1505
|
+
const out = y.subarray(offset, offset + n);
|
|
1506
|
+
const outi = yi.subarray(offset, offset + n);
|
|
1507
|
+
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1508
|
+
outi.sort((a, b) => ar[a] - ar[b]);
|
|
1509
|
+
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1510
|
+
}
|
|
1511
|
+
}
|
|
1512
|
+
function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
|
|
1513
|
+
const as = type.inputShapes[0];
|
|
1514
|
+
const bs = type.inputShapes[1];
|
|
1515
|
+
if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
|
|
1516
|
+
if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
|
|
1517
|
+
const n = as[as.length - 2];
|
|
1518
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
1519
|
+
const batch = bs[bs.length - 2];
|
|
1520
|
+
for (let counter = 0; counter < a.length / (n * n); counter++) {
|
|
1521
|
+
const a1 = a.subarray(counter * n * n, (counter + 1) * n * n);
|
|
1522
|
+
for (let t = 0; t < batch; t++) {
|
|
1523
|
+
const b1 = b.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
|
|
1524
|
+
const x1 = x.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
|
|
1525
|
+
for (let i = n - 1; i >= 0; i--) {
|
|
1526
|
+
let sum = b1[i];
|
|
1527
|
+
for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
|
|
1528
|
+
x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
|
|
1529
|
+
}
|
|
1530
|
+
}
|
|
1531
|
+
}
|
|
1532
|
+
}
|
|
1533
|
+
function runCholesky(type, [x], [y]) {
|
|
1534
|
+
const xs = type.inputShapes[0];
|
|
1535
|
+
if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
|
|
1536
|
+
const n = xs[xs.length - 2];
|
|
1537
|
+
const m = xs[xs.length - 1];
|
|
1538
|
+
if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
|
|
1539
|
+
for (let offset = 0; offset < y.length; offset += n * n) {
|
|
1540
|
+
const ar = x.subarray(offset, offset + n * n);
|
|
1541
|
+
const out = y.subarray(offset, offset + n * n);
|
|
1542
|
+
for (let i = 0; i < n; i++) for (let j = 0; j <= i; j++) {
|
|
1543
|
+
let sum = ar[i * n + j];
|
|
1544
|
+
for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
|
|
1545
|
+
out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
|
|
1546
|
+
}
|
|
1547
|
+
}
|
|
1548
|
+
}
|
|
1549
|
+
|
|
1441
1550
|
//#endregion
|
|
1442
1551
|
//#region src/shape.ts
|
|
1443
1552
|
const jstr = JSON.stringify;
|
|
@@ -1908,7 +2017,7 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1908
2017
|
let st = this;
|
|
1909
2018
|
if (axis.length > 0) {
|
|
1910
2019
|
const unsqueezed = [...st.shape];
|
|
1911
|
-
for (const i of axis
|
|
2020
|
+
for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
|
|
1912
2021
|
st = st.reshape(unsqueezed);
|
|
1913
2022
|
}
|
|
1914
2023
|
return st.expand(newShape);
|
|
@@ -2133,7 +2242,7 @@ function tuneWebgpu(kernel) {
|
|
|
2133
2242
|
break;
|
|
2134
2243
|
}
|
|
2135
2244
|
}
|
|
2136
|
-
for (const ax of
|
|
2245
|
+
for (const ax of sorted(upcastedAxis)) {
|
|
2137
2246
|
const s = dim.st.shape[ax];
|
|
2138
2247
|
for (const amount of [8, 4]) if (s % amount === 0) {
|
|
2139
2248
|
dim.applyLocal(ax, amount);
|
|
@@ -2251,13 +2360,21 @@ var CpuBackend = class {
|
|
|
2251
2360
|
if (count === void 0) count = buffer.byteLength - start;
|
|
2252
2361
|
return buffer.slice(start, start + count);
|
|
2253
2362
|
}
|
|
2254
|
-
async
|
|
2255
|
-
return this.
|
|
2363
|
+
async prepareKernel(kernel) {
|
|
2364
|
+
return this.prepareKernelSync(kernel);
|
|
2256
2365
|
}
|
|
2257
|
-
|
|
2366
|
+
prepareKernelSync(kernel) {
|
|
2258
2367
|
return new Executable(kernel, void 0);
|
|
2259
2368
|
}
|
|
2260
|
-
|
|
2369
|
+
async prepareRoutine(routine) {
|
|
2370
|
+
return this.prepareRoutineSync(routine);
|
|
2371
|
+
}
|
|
2372
|
+
prepareRoutineSync(routine) {
|
|
2373
|
+
return new Executable(routine, void 0);
|
|
2374
|
+
}
|
|
2375
|
+
dispatch(exe, inputs, outputs) {
|
|
2376
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
2377
|
+
const kernel = exe.source;
|
|
2261
2378
|
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2262
2379
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2263
2380
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
@@ -2315,8 +2432,10 @@ var WasmAllocator = class {
|
|
|
2315
2432
|
const sizeClass = this.#findSizeClass(size);
|
|
2316
2433
|
const freeList = this.#freeLists.get(sizeClass);
|
|
2317
2434
|
let ptr;
|
|
2318
|
-
if (freeList && freeList.length > 0)
|
|
2319
|
-
|
|
2435
|
+
if (freeList && freeList.length > 0) {
|
|
2436
|
+
ptr = freeList.pop();
|
|
2437
|
+
new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
|
|
2438
|
+
} else ptr = this.#bumpAlloc(sizeClass);
|
|
2320
2439
|
this.#allocatedBuffers.set(ptr, sizeClass);
|
|
2321
2440
|
return ptr;
|
|
2322
2441
|
}
|
|
@@ -3682,10 +3801,10 @@ var WasmBackend = class {
|
|
|
3682
3801
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3683
3802
|
return buffer.slice(start, start + count);
|
|
3684
3803
|
}
|
|
3685
|
-
async
|
|
3686
|
-
return this.
|
|
3804
|
+
async prepareKernel(kernel) {
|
|
3805
|
+
return this.prepareKernelSync(kernel);
|
|
3687
3806
|
}
|
|
3688
|
-
|
|
3807
|
+
prepareKernelSync(kernel) {
|
|
3689
3808
|
const kernelHash = FpHash.hash(kernel);
|
|
3690
3809
|
const module$1 = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3691
3810
|
const bytes = codegenWasm(kernel);
|
|
@@ -3693,7 +3812,14 @@ var WasmBackend = class {
|
|
|
3693
3812
|
});
|
|
3694
3813
|
return new Executable(kernel, { module: module$1 });
|
|
3695
3814
|
}
|
|
3815
|
+
async prepareRoutine(routine) {
|
|
3816
|
+
return this.prepareRoutineSync(routine);
|
|
3817
|
+
}
|
|
3818
|
+
prepareRoutineSync(routine) {
|
|
3819
|
+
return new Executable(routine, void 0);
|
|
3820
|
+
}
|
|
3696
3821
|
dispatch(exe, inputs, outputs) {
|
|
3822
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
3697
3823
|
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
3698
3824
|
const func = instance.exports.kernel;
|
|
3699
3825
|
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
@@ -4041,7 +4167,7 @@ async function createBackend(device) {
|
|
|
4041
4167
|
if (!navigator.gpu) return null;
|
|
4042
4168
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4043
4169
|
if (!adapter) return null;
|
|
4044
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVns4DbI.cjs"));
|
|
4170
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Oj3Kd-kd.cjs"));
|
|
4045
4171
|
const importantLimits = [
|
|
4046
4172
|
"maxBufferSize",
|
|
4047
4173
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4075,8 +4201,8 @@ function getBackend(device) {
|
|
|
4075
4201
|
return backend;
|
|
4076
4202
|
}
|
|
4077
4203
|
var Executable = class {
|
|
4078
|
-
constructor(
|
|
4079
|
-
this.
|
|
4204
|
+
constructor(source, data) {
|
|
4205
|
+
this.source = source;
|
|
4080
4206
|
this.data = data;
|
|
4081
4207
|
}
|
|
4082
4208
|
};
|
|
@@ -4092,6 +4218,11 @@ var UnsupportedOpError = class extends Error {
|
|
|
4092
4218
|
super(msg);
|
|
4093
4219
|
}
|
|
4094
4220
|
};
|
|
4221
|
+
var UnsupportedRoutineError = class extends Error {
|
|
4222
|
+
constructor(name, device) {
|
|
4223
|
+
super(`routine '${name}' is not supported in ${device} backend`);
|
|
4224
|
+
}
|
|
4225
|
+
};
|
|
4095
4226
|
|
|
4096
4227
|
//#endregion
|
|
4097
4228
|
Object.defineProperty(exports, 'AluExp', {
|
|
@@ -4160,6 +4291,18 @@ Object.defineProperty(exports, 'Reduction', {
|
|
|
4160
4291
|
return Reduction;
|
|
4161
4292
|
}
|
|
4162
4293
|
});
|
|
4294
|
+
Object.defineProperty(exports, 'Routine', {
|
|
4295
|
+
enumerable: true,
|
|
4296
|
+
get: function () {
|
|
4297
|
+
return Routine;
|
|
4298
|
+
}
|
|
4299
|
+
});
|
|
4300
|
+
Object.defineProperty(exports, 'Routines', {
|
|
4301
|
+
enumerable: true,
|
|
4302
|
+
get: function () {
|
|
4303
|
+
return Routines;
|
|
4304
|
+
}
|
|
4305
|
+
});
|
|
4163
4306
|
Object.defineProperty(exports, 'ShapeTracker', {
|
|
4164
4307
|
enumerable: true,
|
|
4165
4308
|
get: function () {
|
|
@@ -4178,6 +4321,12 @@ Object.defineProperty(exports, 'UnsupportedOpError', {
|
|
|
4178
4321
|
return UnsupportedOpError;
|
|
4179
4322
|
}
|
|
4180
4323
|
});
|
|
4324
|
+
Object.defineProperty(exports, 'UnsupportedRoutineError', {
|
|
4325
|
+
enumerable: true,
|
|
4326
|
+
get: function () {
|
|
4327
|
+
return UnsupportedRoutineError;
|
|
4328
|
+
}
|
|
4329
|
+
});
|
|
4181
4330
|
Object.defineProperty(exports, 'accessorAluExp', {
|
|
4182
4331
|
enumerable: true,
|
|
4183
4332
|
get: function () {
|
|
@@ -4387,5 +4536,4 @@ Object.defineProperty(exports, 'zipn', {
|
|
|
4387
4536
|
get: function () {
|
|
4388
4537
|
return zipn;
|
|
4389
4538
|
}
|
|
4390
|
-
});
|
|
4391
|
-
//# sourceMappingURL=backend-CmaidnkQ.cjs.map
|
|
4539
|
+
});
|
|
@@ -68,6 +68,9 @@ function zipn(...arrays) {
|
|
|
68
68
|
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
69
69
|
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
70
70
|
}
|
|
71
|
+
function sorted(arr) {
|
|
72
|
+
return [...arr].sort((a, b) => a - b);
|
|
73
|
+
}
|
|
71
74
|
function rep(length, value) {
|
|
72
75
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
73
76
|
return new Array(length).fill(value);
|
|
@@ -144,7 +147,7 @@ function normalizeAxis(axis, ndim) {
|
|
|
144
147
|
if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
|
|
145
148
|
seen.add(ca);
|
|
146
149
|
}
|
|
147
|
-
return
|
|
150
|
+
return sorted(seen);
|
|
148
151
|
}
|
|
149
152
|
}
|
|
150
153
|
function range(start, stop, step = 1) {
|
|
@@ -1327,7 +1330,7 @@ var Reduction = class {
|
|
|
1327
1330
|
/** Evaluate this operation on CPU. */
|
|
1328
1331
|
evaluate(...values) {
|
|
1329
1332
|
if (this.dtype === DType.Bool) {
|
|
1330
|
-
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b,
|
|
1333
|
+
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
|
|
1331
1334
|
else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
|
|
1332
1335
|
} else if (this.dtype === DType.Int32) {
|
|
1333
1336
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
|
|
@@ -1437,6 +1440,112 @@ function erfc(x) {
|
|
|
1437
1440
|
else return 2 - _erfapprox$1(-x);
|
|
1438
1441
|
}
|
|
1439
1442
|
|
|
1443
|
+
//#endregion
|
|
1444
|
+
//#region src/routine.ts
|
|
1445
|
+
/**
|
|
1446
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
1447
|
+
*
|
|
1448
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
1449
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
1450
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
1451
|
+
* relevant.
|
|
1452
|
+
*
|
|
1453
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
1454
|
+
* which each backend implements in a specific way. These are listed in the
|
|
1455
|
+
* `Routines` enum below.
|
|
1456
|
+
*
|
|
1457
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
1458
|
+
* arrays (default `ShapeTracker`).
|
|
1459
|
+
*/
|
|
1460
|
+
var Routine = class {
|
|
1461
|
+
constructor(name, type, params) {
|
|
1462
|
+
this.name = name;
|
|
1463
|
+
this.type = type;
|
|
1464
|
+
this.params = params;
|
|
1465
|
+
}
|
|
1466
|
+
};
|
|
1467
|
+
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1468
|
+
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1469
|
+
/** Stable sorting algorithm along the last axis. */
|
|
1470
|
+
Routines$1["Sort"] = "Sort";
|
|
1471
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
1472
|
+
Routines$1["Argsort"] = "Argsort";
|
|
1473
|
+
/** Solve a triangular system of questions. */
|
|
1474
|
+
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1475
|
+
/** Cholesky decomposition of 2D positive semi-definite matrices. */
|
|
1476
|
+
Routines$1["Cholesky"] = "Cholesky";
|
|
1477
|
+
return Routines$1;
|
|
1478
|
+
}({});
|
|
1479
|
+
function runCpuRoutine(routine, inputs, outputs) {
|
|
1480
|
+
const { name, type } = routine;
|
|
1481
|
+
const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
|
|
1482
|
+
const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
|
|
1483
|
+
switch (name) {
|
|
1484
|
+
case Routines.Sort: return runSort(type, inputAr, outputAr);
|
|
1485
|
+
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1486
|
+
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1487
|
+
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1488
|
+
default:
|
|
1489
|
+
}
|
|
1490
|
+
}
|
|
1491
|
+
function runSort(type, [x], [y]) {
|
|
1492
|
+
const xs = type.inputShapes[0];
|
|
1493
|
+
if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
|
|
1494
|
+
const n = xs[xs.length - 1];
|
|
1495
|
+
y.set(x);
|
|
1496
|
+
for (let i = 0; i < y.length; i += n) y.subarray(i, i + n).sort();
|
|
1497
|
+
}
|
|
1498
|
+
function runArgsort(type, [x], [y, yi]) {
|
|
1499
|
+
const xs = type.inputShapes[0];
|
|
1500
|
+
if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
|
|
1501
|
+
const n = xs[xs.length - 1];
|
|
1502
|
+
for (let offset = 0; offset < y.length; offset += n) {
|
|
1503
|
+
const ar = x.subarray(offset, offset + n);
|
|
1504
|
+
const out = y.subarray(offset, offset + n);
|
|
1505
|
+
const outi = yi.subarray(offset, offset + n);
|
|
1506
|
+
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1507
|
+
outi.sort((a, b) => ar[a] - ar[b]);
|
|
1508
|
+
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1509
|
+
}
|
|
1510
|
+
}
|
|
1511
|
+
function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
|
|
1512
|
+
const as = type.inputShapes[0];
|
|
1513
|
+
const bs = type.inputShapes[1];
|
|
1514
|
+
if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
|
|
1515
|
+
if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
|
|
1516
|
+
const n = as[as.length - 2];
|
|
1517
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
1518
|
+
const batch = bs[bs.length - 2];
|
|
1519
|
+
for (let counter = 0; counter < a.length / (n * n); counter++) {
|
|
1520
|
+
const a1 = a.subarray(counter * n * n, (counter + 1) * n * n);
|
|
1521
|
+
for (let t = 0; t < batch; t++) {
|
|
1522
|
+
const b1 = b.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
|
|
1523
|
+
const x1 = x.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
|
|
1524
|
+
for (let i = n - 1; i >= 0; i--) {
|
|
1525
|
+
let sum = b1[i];
|
|
1526
|
+
for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
|
|
1527
|
+
x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
|
|
1528
|
+
}
|
|
1529
|
+
}
|
|
1530
|
+
}
|
|
1531
|
+
}
|
|
1532
|
+
function runCholesky(type, [x], [y]) {
|
|
1533
|
+
const xs = type.inputShapes[0];
|
|
1534
|
+
if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
|
|
1535
|
+
const n = xs[xs.length - 2];
|
|
1536
|
+
const m = xs[xs.length - 1];
|
|
1537
|
+
if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
|
|
1538
|
+
for (let offset = 0; offset < y.length; offset += n * n) {
|
|
1539
|
+
const ar = x.subarray(offset, offset + n * n);
|
|
1540
|
+
const out = y.subarray(offset, offset + n * n);
|
|
1541
|
+
for (let i = 0; i < n; i++) for (let j = 0; j <= i; j++) {
|
|
1542
|
+
let sum = ar[i * n + j];
|
|
1543
|
+
for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
|
|
1544
|
+
out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
|
|
1545
|
+
}
|
|
1546
|
+
}
|
|
1547
|
+
}
|
|
1548
|
+
|
|
1440
1549
|
//#endregion
|
|
1441
1550
|
//#region src/shape.ts
|
|
1442
1551
|
const jstr = JSON.stringify;
|
|
@@ -1907,7 +2016,7 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1907
2016
|
let st = this;
|
|
1908
2017
|
if (axis.length > 0) {
|
|
1909
2018
|
const unsqueezed = [...st.shape];
|
|
1910
|
-
for (const i of axis
|
|
2019
|
+
for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
|
|
1911
2020
|
st = st.reshape(unsqueezed);
|
|
1912
2021
|
}
|
|
1913
2022
|
return st.expand(newShape);
|
|
@@ -2132,7 +2241,7 @@ function tuneWebgpu(kernel) {
|
|
|
2132
2241
|
break;
|
|
2133
2242
|
}
|
|
2134
2243
|
}
|
|
2135
|
-
for (const ax of
|
|
2244
|
+
for (const ax of sorted(upcastedAxis)) {
|
|
2136
2245
|
const s = dim.st.shape[ax];
|
|
2137
2246
|
for (const amount of [8, 4]) if (s % amount === 0) {
|
|
2138
2247
|
dim.applyLocal(ax, amount);
|
|
@@ -2250,13 +2359,21 @@ var CpuBackend = class {
|
|
|
2250
2359
|
if (count === void 0) count = buffer.byteLength - start;
|
|
2251
2360
|
return buffer.slice(start, start + count);
|
|
2252
2361
|
}
|
|
2253
|
-
async
|
|
2254
|
-
return this.
|
|
2362
|
+
async prepareKernel(kernel) {
|
|
2363
|
+
return this.prepareKernelSync(kernel);
|
|
2255
2364
|
}
|
|
2256
|
-
|
|
2365
|
+
prepareKernelSync(kernel) {
|
|
2257
2366
|
return new Executable(kernel, void 0);
|
|
2258
2367
|
}
|
|
2259
|
-
|
|
2368
|
+
async prepareRoutine(routine) {
|
|
2369
|
+
return this.prepareRoutineSync(routine);
|
|
2370
|
+
}
|
|
2371
|
+
prepareRoutineSync(routine) {
|
|
2372
|
+
return new Executable(routine, void 0);
|
|
2373
|
+
}
|
|
2374
|
+
dispatch(exe, inputs, outputs) {
|
|
2375
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
2376
|
+
const kernel = exe.source;
|
|
2260
2377
|
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2261
2378
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2262
2379
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
@@ -2314,8 +2431,10 @@ var WasmAllocator = class {
|
|
|
2314
2431
|
const sizeClass = this.#findSizeClass(size);
|
|
2315
2432
|
const freeList = this.#freeLists.get(sizeClass);
|
|
2316
2433
|
let ptr;
|
|
2317
|
-
if (freeList && freeList.length > 0)
|
|
2318
|
-
|
|
2434
|
+
if (freeList && freeList.length > 0) {
|
|
2435
|
+
ptr = freeList.pop();
|
|
2436
|
+
new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
|
|
2437
|
+
} else ptr = this.#bumpAlloc(sizeClass);
|
|
2319
2438
|
this.#allocatedBuffers.set(ptr, sizeClass);
|
|
2320
2439
|
return ptr;
|
|
2321
2440
|
}
|
|
@@ -3681,10 +3800,10 @@ var WasmBackend = class {
|
|
|
3681
3800
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3682
3801
|
return buffer.slice(start, start + count);
|
|
3683
3802
|
}
|
|
3684
|
-
async
|
|
3685
|
-
return this.
|
|
3803
|
+
async prepareKernel(kernel) {
|
|
3804
|
+
return this.prepareKernelSync(kernel);
|
|
3686
3805
|
}
|
|
3687
|
-
|
|
3806
|
+
prepareKernelSync(kernel) {
|
|
3688
3807
|
const kernelHash = FpHash.hash(kernel);
|
|
3689
3808
|
const module = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3690
3809
|
const bytes = codegenWasm(kernel);
|
|
@@ -3692,7 +3811,14 @@ var WasmBackend = class {
|
|
|
3692
3811
|
});
|
|
3693
3812
|
return new Executable(kernel, { module });
|
|
3694
3813
|
}
|
|
3814
|
+
async prepareRoutine(routine) {
|
|
3815
|
+
return this.prepareRoutineSync(routine);
|
|
3816
|
+
}
|
|
3817
|
+
prepareRoutineSync(routine) {
|
|
3818
|
+
return new Executable(routine, void 0);
|
|
3819
|
+
}
|
|
3695
3820
|
dispatch(exe, inputs, outputs) {
|
|
3821
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
3696
3822
|
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
3697
3823
|
const func = instance.exports.kernel;
|
|
3698
3824
|
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
@@ -4040,7 +4166,7 @@ async function createBackend(device) {
|
|
|
4040
4166
|
if (!navigator.gpu) return null;
|
|
4041
4167
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4042
4168
|
if (!adapter) return null;
|
|
4043
|
-
const { WebGPUBackend } = await import("./webgpu-C9iAP5h5.js");
|
|
4169
|
+
const { WebGPUBackend } = await import("./webgpu-ChVgx3b6.js");
|
|
4044
4170
|
const importantLimits = [
|
|
4045
4171
|
"maxBufferSize",
|
|
4046
4172
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4074,8 +4200,8 @@ function getBackend(device) {
|
|
|
4074
4200
|
return backend;
|
|
4075
4201
|
}
|
|
4076
4202
|
var Executable = class {
|
|
4077
|
-
constructor(
|
|
4078
|
-
this.
|
|
4203
|
+
constructor(source, data) {
|
|
4204
|
+
this.source = source;
|
|
4079
4205
|
this.data = data;
|
|
4080
4206
|
}
|
|
4081
4207
|
};
|
|
@@ -4091,7 +4217,11 @@ var UnsupportedOpError = class extends Error {
|
|
|
4091
4217
|
super(msg);
|
|
4092
4218
|
}
|
|
4093
4219
|
};
|
|
4220
|
+
var UnsupportedRoutineError = class extends Error {
|
|
4221
|
+
constructor(name, device) {
|
|
4222
|
+
super(`routine '${name}' is not supported in ${device} backend`);
|
|
4223
|
+
}
|
|
4224
|
+
};
|
|
4094
4225
|
|
|
4095
4226
|
//#endregion
|
|
4096
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4097
|
-
//# sourceMappingURL=backend-BY8wlLEl.js.map
|
|
4227
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|