@jax-js/jax 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +15 -9
- package/dist/{backend-BY8wlLEl.js → backend-DaqL-MNz.js} +240 -21
- package/dist/{backend-CmaidnkQ.cjs → backend-DziQSaoQ.cjs} +264 -21
- package/dist/index.cjs +2407 -1132
- package/dist/index.d.cts +596 -97
- package/dist/index.d.ts +596 -97
- package/dist/index.js +2400 -1126
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/webgpu-Db2JrNBr.cjs +1261 -0
- package/dist/webgpu-Dh7k9io0.js +1261 -0
- package/package.json +1 -1
- package/dist/webgpu-BVns4DbI.cjs +0 -663
- package/dist/webgpu-C9iAP5h5.js +0 -663
package/README.md
CHANGED
|
@@ -3,7 +3,8 @@
|
|
|
3
3
|
<p align="center"><strong>
|
|
4
4
|
<a href="https://jax-js.com">Website</a> |
|
|
5
5
|
<a href="https://jax-js.com/docs/">API Reference</a> |
|
|
6
|
-
<a href="./FEATURES.md">Compatibility Table</a>
|
|
6
|
+
<a href="./FEATURES.md">Compatibility Table</a> |
|
|
7
|
+
<a href="https://discord.gg/BW6YsCd4Tf">Discord</a>
|
|
7
8
|
</strong></p>
|
|
8
9
|
|
|
9
10
|
**jax-js** is a machine learning framework for the browser. It aims to bring
|
|
@@ -57,13 +58,13 @@ import { numpy as np } from "@jax-js/jax";
|
|
|
57
58
|
const ar = np.array([1, 2, 3]);
|
|
58
59
|
```
|
|
59
60
|
|
|
60
|
-
By default, this is a float32 array, but you can
|
|
61
|
+
By default, this is a float32 array, but you can specify a different dtype:
|
|
61
62
|
|
|
62
63
|
```ts
|
|
63
|
-
const ar = np.array([1, 2, 3], { dtype: np.
|
|
64
|
+
const ar = np.array([1, 2, 3], { dtype: np.int32 });
|
|
64
65
|
```
|
|
65
66
|
|
|
66
|
-
For more efficient construction,
|
|
67
|
+
For more efficient construction, create an array from a JS `TypedArray` buffer:
|
|
67
68
|
|
|
68
69
|
```ts
|
|
69
70
|
const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
|
|
@@ -222,14 +223,18 @@ Note that you need to use `type` alias syntax rather than `interface` to define
|
|
|
222
223
|
Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory
|
|
223
224
|
and determines how to execute compiled operations on them.
|
|
224
225
|
|
|
225
|
-
There are currently
|
|
226
|
+
There are currently 4 devices in jax-js:
|
|
226
227
|
|
|
227
|
-
- `cpu`: Slow,
|
|
228
|
+
- `cpu`: Slow, interpreted JS, only meant for debugging.
|
|
228
229
|
- `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
|
|
229
230
|
- `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
|
|
230
231
|
[supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
|
|
232
|
+
- `webgl`: [WebGL2](https://developer.mozilla.org/en-US/docs/Web/API/WebGL2RenderingContext), via
|
|
233
|
+
fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much
|
|
234
|
+
slower than WebGPU. It's offered on a best-effort basis and not as well-supported.
|
|
231
235
|
|
|
232
|
-
|
|
236
|
+
**We recommend `webgpu` for best performance, especially when running neural networks.** The default
|
|
237
|
+
device is `wasm`, but you can change this at startup time:
|
|
233
238
|
|
|
234
239
|
```ts
|
|
235
240
|
import { defaultDevice, init } from "@jax-js/jax";
|
|
@@ -323,7 +328,7 @@ pnpm -C website dev
|
|
|
323
328
|
|
|
324
329
|
## Future work / help wanted
|
|
325
330
|
|
|
326
|
-
Contributions are welcomed!
|
|
331
|
+
Contributions are welcomed! Some fruitful areas to look into:
|
|
327
332
|
|
|
328
333
|
- Adding support for more JAX functions and operations, see [compatibility table](./FEATURES.md).
|
|
329
334
|
- Improving performance of the WebGPU and Wasm runtimes, generating better kernels, and using SIMD
|
|
@@ -332,5 +337,6 @@ Contributions are welcomed! Especially in:
|
|
|
332
337
|
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
333
338
|
to help with model performance debugging.
|
|
334
339
|
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
335
|
-
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
336
340
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
341
|
+
|
|
342
|
+
You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
|
|
@@ -68,6 +68,9 @@ function zipn(...arrays) {
|
|
|
68
68
|
const minLength = Math.min(...arrays.map((x) => x.length));
|
|
69
69
|
return Array.from({ length: minLength }, (_, i) => arrays.map((arr) => arr[i]));
|
|
70
70
|
}
|
|
71
|
+
function sorted(arr) {
|
|
72
|
+
return [...arr].sort((a, b) => a - b);
|
|
73
|
+
}
|
|
71
74
|
function rep(length, value) {
|
|
72
75
|
if (value instanceof Function) return new Array(length).fill(0).map((_, i) => value(i));
|
|
73
76
|
return new Array(length).fill(value);
|
|
@@ -144,7 +147,7 @@ function normalizeAxis(axis, ndim) {
|
|
|
144
147
|
if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
|
|
145
148
|
seen.add(ca);
|
|
146
149
|
}
|
|
147
|
-
return
|
|
150
|
+
return sorted(seen);
|
|
148
151
|
}
|
|
149
152
|
}
|
|
150
153
|
function range(start, stop, step = 1) {
|
|
@@ -1327,7 +1330,7 @@ var Reduction = class {
|
|
|
1327
1330
|
/** Evaluate this operation on CPU. */
|
|
1328
1331
|
evaluate(...values) {
|
|
1329
1332
|
if (this.dtype === DType.Bool) {
|
|
1330
|
-
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b,
|
|
1333
|
+
if (this.op === AluOp.Add || this.op === AluOp.Max) return values.reduce((a, b) => a || b, false);
|
|
1331
1334
|
else if (this.op === AluOp.Mul || this.op === AluOp.Min) return values.reduce((a, b) => a && b, true);
|
|
1332
1335
|
} else if (this.dtype === DType.Int32) {
|
|
1333
1336
|
if (this.op === AluOp.Add) return values.reduce((a, b) => a + b | 0, 0);
|
|
@@ -1437,6 +1440,184 @@ function erfc(x) {
|
|
|
1437
1440
|
else return 2 - _erfapprox$1(-x);
|
|
1438
1441
|
}
|
|
1439
1442
|
|
|
1443
|
+
//#endregion
|
|
1444
|
+
//#region src/routine.ts
|
|
1445
|
+
/**
|
|
1446
|
+
* Advanced operations that don't fit into the `AluExp` compiler representation.
|
|
1447
|
+
*
|
|
1448
|
+
* Some routines like iterative matrix algorithms, FFTs, or sorting may not be
|
|
1449
|
+
* easy to express efficiently as a `Kernel` object. These also tend to be
|
|
1450
|
+
* somewhat expensive, so the benefit of kernel fusion and inlining is less
|
|
1451
|
+
* relevant.
|
|
1452
|
+
*
|
|
1453
|
+
* For these operations, we dispatch them as a custom operation on the backend,
|
|
1454
|
+
* which each backend implements in a specific way. These are listed in the
|
|
1455
|
+
* `Routines` enum below.
|
|
1456
|
+
*
|
|
1457
|
+
* Routines cannot be fused into other kernels and always operate on contiguous
|
|
1458
|
+
* arrays (default `ShapeTracker`).
|
|
1459
|
+
*/
|
|
1460
|
+
var Routine = class {
|
|
1461
|
+
constructor(name, type, params) {
|
|
1462
|
+
this.name = name;
|
|
1463
|
+
this.type = type;
|
|
1464
|
+
this.params = params;
|
|
1465
|
+
}
|
|
1466
|
+
};
|
|
1467
|
+
/** One of the valid `Routine` that can be dispatched to backend. */
|
|
1468
|
+
let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
1469
|
+
/** Stable sorting algorithm along the last axis. */
|
|
1470
|
+
Routines$1["Sort"] = "Sort";
|
|
1471
|
+
/** Returns `int32` indices of the stably sorted array. */
|
|
1472
|
+
Routines$1["Argsort"] = "Argsort";
|
|
1473
|
+
/**
|
|
1474
|
+
* Solve a triangular system of equations.
|
|
1475
|
+
*
|
|
1476
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
1477
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
1478
|
+
*
|
|
1479
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
1480
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
1481
|
+
*/
|
|
1482
|
+
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1483
|
+
/**
|
|
1484
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
1485
|
+
*
|
|
1486
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
1487
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
1488
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
1489
|
+
*/
|
|
1490
|
+
Routines$1["Cholesky"] = "Cholesky";
|
|
1491
|
+
/**
|
|
1492
|
+
* LU decomposition of 2D rectangular matrices.
|
|
1493
|
+
*
|
|
1494
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
1495
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
1496
|
+
*
|
|
1497
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
1498
|
+
* triangular matrices. (lower triangular = implicit unit diagonal)
|
|
1499
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
1500
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
1501
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
1502
|
+
*/
|
|
1503
|
+
Routines$1["LU"] = "LU";
|
|
1504
|
+
return Routines$1;
|
|
1505
|
+
}({});
|
|
1506
|
+
function runCpuRoutine(routine, inputs, outputs) {
|
|
1507
|
+
const { name, type } = routine;
|
|
1508
|
+
const inputAr = inputs.map((buf, i) => dtypedArray(type.inputDtypes[i], buf));
|
|
1509
|
+
const outputAr = outputs.map((buf, i) => dtypedArray(type.outputDtypes[i], buf));
|
|
1510
|
+
switch (name) {
|
|
1511
|
+
case Routines.Sort: return runSort(type, inputAr, outputAr);
|
|
1512
|
+
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1513
|
+
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1514
|
+
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1515
|
+
case Routines.LU: return runLU(type, inputAr, outputAr);
|
|
1516
|
+
default:
|
|
1517
|
+
}
|
|
1518
|
+
}
|
|
1519
|
+
function runSort(type, [x], [y]) {
|
|
1520
|
+
const xs = type.inputShapes[0];
|
|
1521
|
+
if (xs.length === 0) throw new Error("sort: cannot sort a scalar");
|
|
1522
|
+
const n = xs[xs.length - 1];
|
|
1523
|
+
y.set(x);
|
|
1524
|
+
for (let i = 0; i < y.length; i += n) y.subarray(i, i + n).sort();
|
|
1525
|
+
}
|
|
1526
|
+
function runArgsort(type, [x], [y, yi]) {
|
|
1527
|
+
const xs = type.inputShapes[0];
|
|
1528
|
+
if (xs.length === 0) throw new Error("argsort: cannot sort a scalar");
|
|
1529
|
+
const n = xs[xs.length - 1];
|
|
1530
|
+
for (let offset = 0; offset < y.length; offset += n) {
|
|
1531
|
+
const ar = x.subarray(offset, offset + n);
|
|
1532
|
+
const out = y.subarray(offset, offset + n);
|
|
1533
|
+
const outi = yi.subarray(offset, offset + n);
|
|
1534
|
+
for (let i = 0; i < n; i++) outi[i] = i;
|
|
1535
|
+
outi.sort((a, b) => ar[a] - ar[b]);
|
|
1536
|
+
for (let i = 0; i < n; i++) out[i] = ar[outi[i]];
|
|
1537
|
+
}
|
|
1538
|
+
}
|
|
1539
|
+
function runTriangularSolve(type, [a, b], [x], { unitDiagonal }) {
|
|
1540
|
+
const as = type.inputShapes[0];
|
|
1541
|
+
const bs = type.inputShapes[1];
|
|
1542
|
+
if (as.length < 2) throw new Error(`triangular_solve: a must be at least 2D, got ${as}`);
|
|
1543
|
+
if (bs.length < 2) throw new Error(`triangular_solve: b must be at least 2D, got ${bs}`);
|
|
1544
|
+
const n = as[as.length - 2];
|
|
1545
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
1546
|
+
const batch = bs[bs.length - 2];
|
|
1547
|
+
for (let counter = 0; counter < a.length / (n * n); counter++) {
|
|
1548
|
+
const a1 = a.subarray(counter * n * n, (counter + 1) * n * n);
|
|
1549
|
+
for (let t = 0; t < batch; t++) {
|
|
1550
|
+
const b1 = b.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
|
|
1551
|
+
const x1 = x.subarray((counter * batch + t) * n, (counter * batch + t + 1) * n);
|
|
1552
|
+
for (let i = n - 1; i >= 0; i--) {
|
|
1553
|
+
let sum = b1[i];
|
|
1554
|
+
for (let j = i + 1; j < n; j++) sum -= a1[i * n + j] * x1[j];
|
|
1555
|
+
x1[i] = unitDiagonal ? sum : sum / a1[i * n + i];
|
|
1556
|
+
}
|
|
1557
|
+
}
|
|
1558
|
+
}
|
|
1559
|
+
}
|
|
1560
|
+
function runCholesky(type, [x], [y]) {
|
|
1561
|
+
const xs = type.inputShapes[0];
|
|
1562
|
+
if (xs.length < 2) throw new Error("cholesky: input must be at least 2D");
|
|
1563
|
+
const n = xs[xs.length - 2];
|
|
1564
|
+
const m = xs[xs.length - 1];
|
|
1565
|
+
if (n !== m) throw new Error(`cholesky: input must be square, got [${n}, ${m}]`);
|
|
1566
|
+
for (let offset = 0; offset < y.length; offset += n * n) {
|
|
1567
|
+
const ar = x.subarray(offset, offset + n * n);
|
|
1568
|
+
const out = y.subarray(offset, offset + n * n);
|
|
1569
|
+
for (let i = 0; i < n; i++) for (let j = 0; j <= i; j++) {
|
|
1570
|
+
let sum = ar[i * n + j];
|
|
1571
|
+
for (let k = 0; k < j; k++) sum -= out[i * n + k] * out[j * n + k];
|
|
1572
|
+
out[i * n + j] = i === j ? Math.sqrt(sum) : sum / out[j * n + j];
|
|
1573
|
+
}
|
|
1574
|
+
}
|
|
1575
|
+
}
|
|
1576
|
+
function runLU(type, [a], [lu, pivots, perm]) {
|
|
1577
|
+
const shape = type.inputShapes[0];
|
|
1578
|
+
if (shape.length < 2) throw new Error("lu: input must be at least 2D");
|
|
1579
|
+
const m = shape[shape.length - 2];
|
|
1580
|
+
const n = shape[shape.length - 1];
|
|
1581
|
+
const r = Math.min(m, n);
|
|
1582
|
+
for (let offset = 0; offset < a.length; offset += m * n) {
|
|
1583
|
+
const ar = a.subarray(offset, offset + m * n);
|
|
1584
|
+
const out = lu.subarray(offset, offset + m * n);
|
|
1585
|
+
const batchIdx = offset / (m * n);
|
|
1586
|
+
const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
|
|
1587
|
+
const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);
|
|
1588
|
+
out.set(ar);
|
|
1589
|
+
for (let i = 0; i < m; i++) p[i] = i;
|
|
1590
|
+
for (let j = 0; j < r; j++) {
|
|
1591
|
+
let maxVal = Math.abs(out[j * n + j]);
|
|
1592
|
+
let maxRow = j;
|
|
1593
|
+
for (let i = j + 1; i < m; i++) {
|
|
1594
|
+
const val = Math.abs(out[i * n + j]);
|
|
1595
|
+
if (val > maxVal) {
|
|
1596
|
+
maxVal = val;
|
|
1597
|
+
maxRow = i;
|
|
1598
|
+
}
|
|
1599
|
+
}
|
|
1600
|
+
piv[j] = maxRow;
|
|
1601
|
+
if (maxRow !== j) {
|
|
1602
|
+
for (let col = 0; col < n; col++) {
|
|
1603
|
+
const tmp = out[j * n + col];
|
|
1604
|
+
out[j * n + col] = out[maxRow * n + col];
|
|
1605
|
+
out[maxRow * n + col] = tmp;
|
|
1606
|
+
}
|
|
1607
|
+
const tmpP = p[j];
|
|
1608
|
+
p[j] = p[maxRow];
|
|
1609
|
+
p[maxRow] = tmpP;
|
|
1610
|
+
}
|
|
1611
|
+
const diag = out[j * n + j];
|
|
1612
|
+
if (diag !== 0) for (let i = j + 1; i < m; i++) {
|
|
1613
|
+
const factor = out[i * n + j] / diag;
|
|
1614
|
+
out[i * n + j] = factor;
|
|
1615
|
+
for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
|
|
1616
|
+
}
|
|
1617
|
+
}
|
|
1618
|
+
}
|
|
1619
|
+
}
|
|
1620
|
+
|
|
1440
1621
|
//#endregion
|
|
1441
1622
|
//#region src/shape.ts
|
|
1442
1623
|
const jstr = JSON.stringify;
|
|
@@ -1907,7 +2088,7 @@ var ShapeTracker = class ShapeTracker {
|
|
|
1907
2088
|
let st = this;
|
|
1908
2089
|
if (axis.length > 0) {
|
|
1909
2090
|
const unsqueezed = [...st.shape];
|
|
1910
|
-
for (const i of axis
|
|
2091
|
+
for (const i of sorted(axis)) unsqueezed.splice(i, 0, 1);
|
|
1911
2092
|
st = st.reshape(unsqueezed);
|
|
1912
2093
|
}
|
|
1913
2094
|
return st.expand(newShape);
|
|
@@ -2132,7 +2313,7 @@ function tuneWebgpu(kernel) {
|
|
|
2132
2313
|
break;
|
|
2133
2314
|
}
|
|
2134
2315
|
}
|
|
2135
|
-
for (const ax of
|
|
2316
|
+
for (const ax of sorted(upcastedAxis)) {
|
|
2136
2317
|
const s = dim.st.shape[ax];
|
|
2137
2318
|
for (const amount of [8, 4]) if (s % amount === 0) {
|
|
2138
2319
|
dim.applyLocal(ax, amount);
|
|
@@ -2250,13 +2431,21 @@ var CpuBackend = class {
|
|
|
2250
2431
|
if (count === void 0) count = buffer.byteLength - start;
|
|
2251
2432
|
return buffer.slice(start, start + count);
|
|
2252
2433
|
}
|
|
2253
|
-
async
|
|
2254
|
-
return this.
|
|
2434
|
+
async prepareKernel(kernel) {
|
|
2435
|
+
return this.prepareKernelSync(kernel);
|
|
2255
2436
|
}
|
|
2256
|
-
|
|
2437
|
+
prepareKernelSync(kernel) {
|
|
2257
2438
|
return new Executable(kernel, void 0);
|
|
2258
2439
|
}
|
|
2259
|
-
|
|
2440
|
+
async prepareRoutine(routine) {
|
|
2441
|
+
return this.prepareRoutineSync(routine);
|
|
2442
|
+
}
|
|
2443
|
+
prepareRoutineSync(routine) {
|
|
2444
|
+
return new Executable(routine, void 0);
|
|
2445
|
+
}
|
|
2446
|
+
dispatch(exe, inputs, outputs) {
|
|
2447
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
2448
|
+
const kernel = exe.source;
|
|
2260
2449
|
const { exp, epilogue } = tuneNullopt(kernel);
|
|
2261
2450
|
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
2262
2451
|
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
@@ -2314,8 +2503,10 @@ var WasmAllocator = class {
|
|
|
2314
2503
|
const sizeClass = this.#findSizeClass(size);
|
|
2315
2504
|
const freeList = this.#freeLists.get(sizeClass);
|
|
2316
2505
|
let ptr;
|
|
2317
|
-
if (freeList && freeList.length > 0)
|
|
2318
|
-
|
|
2506
|
+
if (freeList && freeList.length > 0) {
|
|
2507
|
+
ptr = freeList.pop();
|
|
2508
|
+
new Uint8Array(this.#memory.buffer, ptr, sizeClass).fill(0);
|
|
2509
|
+
} else ptr = this.#bumpAlloc(sizeClass);
|
|
2319
2510
|
this.#allocatedBuffers.set(ptr, sizeClass);
|
|
2320
2511
|
return ptr;
|
|
2321
2512
|
}
|
|
@@ -3393,7 +3584,7 @@ var I32 = class {
|
|
|
3393
3584
|
shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
|
|
3394
3585
|
rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
|
|
3395
3586
|
rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
|
|
3396
|
-
eqz =
|
|
3587
|
+
eqz = UNARY_OP("eqz", 69, "i32", "i32");
|
|
3397
3588
|
eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
|
|
3398
3589
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3399
3590
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
@@ -3681,10 +3872,10 @@ var WasmBackend = class {
|
|
|
3681
3872
|
if (count === void 0) count = buffer.byteLength - start;
|
|
3682
3873
|
return buffer.slice(start, start + count);
|
|
3683
3874
|
}
|
|
3684
|
-
async
|
|
3685
|
-
return this.
|
|
3875
|
+
async prepareKernel(kernel) {
|
|
3876
|
+
return this.prepareKernelSync(kernel);
|
|
3686
3877
|
}
|
|
3687
|
-
|
|
3878
|
+
prepareKernelSync(kernel) {
|
|
3688
3879
|
const kernelHash = FpHash.hash(kernel);
|
|
3689
3880
|
const module = runWithCache(moduleCache, kernelHash.toString(), () => {
|
|
3690
3881
|
const bytes = codegenWasm(kernel);
|
|
@@ -3692,7 +3883,14 @@ var WasmBackend = class {
|
|
|
3692
3883
|
});
|
|
3693
3884
|
return new Executable(kernel, { module });
|
|
3694
3885
|
}
|
|
3886
|
+
async prepareRoutine(routine) {
|
|
3887
|
+
return this.prepareRoutineSync(routine);
|
|
3888
|
+
}
|
|
3889
|
+
prepareRoutineSync(routine) {
|
|
3890
|
+
return new Executable(routine, void 0);
|
|
3891
|
+
}
|
|
3695
3892
|
dispatch(exe, inputs, outputs) {
|
|
3893
|
+
if (exe.source instanceof Routine) return runCpuRoutine(exe.source, inputs.map((slot) => this.#getBuffer(slot)), outputs.map((slot) => this.#getBuffer(slot)));
|
|
3696
3894
|
const instance = new WebAssembly.Instance(exe.data.module, { env: { memory: this.#memory } });
|
|
3697
3895
|
const func = instance.exports.kernel;
|
|
3698
3896
|
const ptrs = [...inputs, ...outputs].map((slot) => this.#buffers.get(slot).ptr);
|
|
@@ -3851,7 +4049,7 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3851
4049
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3852
4050
|
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3853
4051
|
else dtyF(cg, op, dtype).max();
|
|
3854
|
-
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
4052
|
+
else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
|
|
3855
4053
|
const a = cg.local.declare(cg.i32);
|
|
3856
4054
|
const b = cg.local.declare(cg.i32);
|
|
3857
4055
|
cg.local.set(b);
|
|
@@ -4001,7 +4199,8 @@ function dtyF(cg, op, dtype) {
|
|
|
4001
4199
|
const devices = [
|
|
4002
4200
|
"cpu",
|
|
4003
4201
|
"wasm",
|
|
4004
|
-
"webgpu"
|
|
4202
|
+
"webgpu",
|
|
4203
|
+
"webgl"
|
|
4005
4204
|
];
|
|
4006
4205
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
4007
4206
|
initializedBackends.set("cpu", new CpuBackend());
|
|
@@ -4040,7 +4239,7 @@ async function createBackend(device) {
|
|
|
4040
4239
|
if (!navigator.gpu) return null;
|
|
4041
4240
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4042
4241
|
if (!adapter) return null;
|
|
4043
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4242
|
+
const { WebGPUBackend } = await import("./webgpu-Dh7k9io0.js");
|
|
4044
4243
|
const importantLimits = [
|
|
4045
4244
|
"maxBufferSize",
|
|
4046
4245
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4064,6 +4263,22 @@ async function createBackend(device) {
|
|
|
4064
4263
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
4065
4264
|
return null;
|
|
4066
4265
|
}
|
|
4266
|
+
} else if (device === "webgl") {
|
|
4267
|
+
if (typeof WebGL2RenderingContext === "undefined") return null;
|
|
4268
|
+
const canvas = new OffscreenCanvas(0, 0);
|
|
4269
|
+
const gl = canvas.getContext("webgl2", {
|
|
4270
|
+
alpha: false,
|
|
4271
|
+
antialias: false,
|
|
4272
|
+
premultipliedAlpha: false,
|
|
4273
|
+
preserveDrawingBuffer: false,
|
|
4274
|
+
depth: false,
|
|
4275
|
+
stencil: false,
|
|
4276
|
+
failIfMajorPerformanceCaveat: true
|
|
4277
|
+
});
|
|
4278
|
+
if (!gl) return null;
|
|
4279
|
+
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4280
|
+
const { WebGLBackend } = await import("./webgl-RSuZKvgc.js");
|
|
4281
|
+
return new WebGLBackend(gl);
|
|
4067
4282
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4068
4283
|
}
|
|
4069
4284
|
/** Retrieve a backend that has been initialized. */
|
|
@@ -4074,8 +4289,8 @@ function getBackend(device) {
|
|
|
4074
4289
|
return backend;
|
|
4075
4290
|
}
|
|
4076
4291
|
var Executable = class {
|
|
4077
|
-
constructor(
|
|
4078
|
-
this.
|
|
4292
|
+
constructor(source, data) {
|
|
4293
|
+
this.source = source;
|
|
4079
4294
|
this.data = data;
|
|
4080
4295
|
}
|
|
4081
4296
|
};
|
|
@@ -4091,7 +4306,11 @@ var UnsupportedOpError = class extends Error {
|
|
|
4091
4306
|
super(msg);
|
|
4092
4307
|
}
|
|
4093
4308
|
};
|
|
4309
|
+
var UnsupportedRoutineError = class extends Error {
|
|
4310
|
+
constructor(name, device) {
|
|
4311
|
+
super(`routine '${name}' is not supported in ${device} backend`);
|
|
4312
|
+
}
|
|
4313
|
+
};
|
|
4094
4314
|
|
|
4095
4315
|
//#endregion
|
|
4096
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4097
|
-
//# sourceMappingURL=backend-BY8wlLEl.js.map
|
|
4316
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|