@jax-js/jax 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-tngXtWe4.js → backend-DaqL-MNz.js} +96 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-DziQSaoQ.cjs} +101 -6
- package/dist/index.cjs +737 -141
- package/dist/index.d.cts +238 -9
- package/dist/index.d.ts +238 -9
- package/dist/index.js +737 -141
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-Db2JrNBr.cjs} +296 -3
- package/dist/{webgpu-ChVgx3b6.js → webgpu-Dh7k9io0.js} +296 -3
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -58,13 +58,13 @@ import { numpy as np } from "@jax-js/jax";
|
|
|
58
58
|
const ar = np.array([1, 2, 3]);
|
|
59
59
|
```
|
|
60
60
|
|
|
61
|
-
By default, this is a float32 array, but you can
|
|
61
|
+
By default, this is a float32 array, but you can specify a different dtype:
|
|
62
62
|
|
|
63
63
|
```ts
|
|
64
|
-
const ar = np.array([1, 2, 3], { dtype: np.
|
|
64
|
+
const ar = np.array([1, 2, 3], { dtype: np.int32 });
|
|
65
65
|
```
|
|
66
66
|
|
|
67
|
-
For more efficient construction,
|
|
67
|
+
For more efficient construction, create an array from a JS `TypedArray` buffer:
|
|
68
68
|
|
|
69
69
|
```ts
|
|
70
70
|
const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
|
|
@@ -223,14 +223,18 @@ Note that you need to use `type` alias syntax rather than `interface` to define
|
|
|
223
223
|
Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory
|
|
224
224
|
and determines how to execute compiled operations on them.
|
|
225
225
|
|
|
226
|
-
There are currently
|
|
226
|
+
There are currently 4 devices in jax-js:
|
|
227
227
|
|
|
228
|
-
- `cpu`: Slow,
|
|
228
|
+
- `cpu`: Slow, interpreted JS, only meant for debugging.
|
|
229
229
|
- `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
|
|
230
230
|
- `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
|
|
231
231
|
[supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
|
|
232
|
+
- `webgl`: [WebGL2](https://developer.mozilla.org/en-US/docs/Web/API/WebGL2RenderingContext), via
|
|
233
|
+
fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much
|
|
234
|
+
slower than WebGPU. It's offered on a best-effort basis and is not as well supported.
|
|
232
235
|
|
|
233
|
-
|
|
236
|
+
**We recommend `webgpu` for best performance, especially when running neural networks.** The default
|
|
237
|
+
device is `wasm`, but you can change this at startup time:
|
|
234
238
|
|
|
235
239
|
```ts
|
|
236
240
|
import { defaultDevice, init } from "@jax-js/jax";
|
|
@@ -333,7 +337,6 @@ Contributions are welcomed! Some fruitful areas to look into:
|
|
|
333
337
|
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
334
338
|
to help with model performance debugging.
|
|
335
339
|
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
336
|
-
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
337
340
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
338
341
|
|
|
339
342
|
You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
|
|
@@ -1470,10 +1470,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
|
1470
1470
|
Routines$1["Sort"] = "Sort";
|
|
1471
1471
|
/** Returns `int32` indices of the stably sorted array. */
|
|
1472
1472
|
Routines$1["Argsort"] = "Argsort";
|
|
1473
|
-
/**
|
|
1473
|
+
/**
|
|
1474
|
+
* Solve a triangular system of equations.
|
|
1475
|
+
*
|
|
1476
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
1477
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
1478
|
+
*
|
|
1479
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
1480
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
1481
|
+
*/
|
|
1474
1482
|
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1475
|
-
/**
|
|
1483
|
+
/**
|
|
1484
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
1485
|
+
*
|
|
1486
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
1487
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
1488
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
1489
|
+
*/
|
|
1476
1490
|
Routines$1["Cholesky"] = "Cholesky";
|
|
1491
|
+
/**
|
|
1492
|
+
* LU decomposition of 2D rectangular matrices.
|
|
1493
|
+
*
|
|
1494
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
1495
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
1496
|
+
*
|
|
1497
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
1498
|
+
* triangular matrices. (lower triangular = implicit unit diagonal)
|
|
1499
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
1500
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
1501
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
1502
|
+
*/
|
|
1503
|
+
Routines$1["LU"] = "LU";
|
|
1477
1504
|
return Routines$1;
|
|
1478
1505
|
}({});
|
|
1479
1506
|
function runCpuRoutine(routine, inputs, outputs) {
|
|
@@ -1485,6 +1512,7 @@ function runCpuRoutine(routine, inputs, outputs) {
|
|
|
1485
1512
|
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1486
1513
|
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1487
1514
|
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1515
|
+
case Routines.LU: return runLU(type, inputAr, outputAr);
|
|
1488
1516
|
default:
|
|
1489
1517
|
}
|
|
1490
1518
|
}
|
|
@@ -1545,6 +1573,50 @@ function runCholesky(type, [x], [y]) {
|
|
|
1545
1573
|
}
|
|
1546
1574
|
}
|
|
1547
1575
|
}
|
|
1576
|
+
/**
 * CPU implementation of the `LU` routine: LU decomposition with partial (row)
 * pivoting, applied independently to every matrix in a batch.
 *
 * @param type Routine type; only `type.inputShapes[0]` (the `[..., M, N]`
 *   shape of the input) is read here.
 * @param inputs Single-element list `[a]`: flat typed array holding the
 *   row-major input matrices back to back. It is only read, never modified.
 * @param outputs `[lu, pivots, perm]` flat typed arrays that are written:
 *   - `lu` (`M * N` per batch): combined factors; multipliers of `L` are
 *     stored below the diagonal, `U` on and above it.
 *   - `pivots` (`min(M, N)` per batch): row chosen as pivot at each step.
 *   - `perm` (`M` per batch): row permutation produced by those swaps.
 * @throws {Error} If the input shape has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const matSize = m * n;

  // Count batches from the leading dimensions instead of from `a.length`.
  // When N === 0 the input holds no elements at all, yet `perm` still has M
  // entries per batch that must be set to the identity permutation.
  let batchCount = 1;
  for (let d = 0; d < shape.length - 2; d++) batchCount *= shape[d];

  for (let batchIdx = 0; batchIdx < batchCount; batchIdx++) {
    const offset = batchIdx * matSize;
    const out = lu.subarray(offset, offset + matSize);
    const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
    const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);

    // Factorize in place inside the output buffer; the input stays untouched.
    out.set(a.subarray(offset, offset + matSize));
    for (let i = 0; i < m; i++) p[i] = i;

    for (let j = 0; j < r; j++) {
      // Partial pivoting: pick the row at or below the diagonal whose entry in
      // column j has the largest magnitude (first such row wins ties).
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;

      if (maxRow !== j) {
        // Swap the full rows (including multipliers already stored to the
        // left) and record the swap in the permutation vector.
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }

      // A zero pivot after pivoting means the rest of the column is zero too,
      // so there is nothing to eliminate and no division by zero is attempted.
      const diag = out[j * n + j];
      if (diag === 0) continue;
      for (let i = j + 1; i < m; i++) {
        const factor = out[i * n + j] / diag;
        out[i * n + j] = factor;
        for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
      }
    }
  }
}
|
|
1548
1620
|
|
|
1549
1621
|
//#endregion
|
|
1550
1622
|
//#region src/shape.ts
|
|
@@ -3512,7 +3584,7 @@ var I32 = class {
|
|
|
3512
3584
|
shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
|
|
3513
3585
|
rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
|
|
3514
3586
|
rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
|
|
3515
|
-
eqz =
|
|
3587
|
+
eqz = UNARY_OP("eqz", 69, "i32", "i32");
|
|
3516
3588
|
eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
|
|
3517
3589
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3518
3590
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
@@ -3977,7 +4049,7 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3977
4049
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3978
4050
|
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3979
4051
|
else dtyF(cg, op, dtype).max();
|
|
3980
|
-
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
4052
|
+
else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
|
|
3981
4053
|
const a = cg.local.declare(cg.i32);
|
|
3982
4054
|
const b = cg.local.declare(cg.i32);
|
|
3983
4055
|
cg.local.set(b);
|
|
@@ -4127,7 +4199,8 @@ function dtyF(cg, op, dtype) {
|
|
|
4127
4199
|
const devices = [
|
|
4128
4200
|
"cpu",
|
|
4129
4201
|
"wasm",
|
|
4130
|
-
"webgpu"
|
|
4202
|
+
"webgpu",
|
|
4203
|
+
"webgl"
|
|
4131
4204
|
];
|
|
4132
4205
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
4133
4206
|
initializedBackends.set("cpu", new CpuBackend());
|
|
@@ -4166,7 +4239,7 @@ async function createBackend(device) {
|
|
|
4166
4239
|
if (!navigator.gpu) return null;
|
|
4167
4240
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4168
4241
|
if (!adapter) return null;
|
|
4169
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4242
|
+
const { WebGPUBackend } = await import("./webgpu-Dh7k9io0.js");
|
|
4170
4243
|
const importantLimits = [
|
|
4171
4244
|
"maxBufferSize",
|
|
4172
4245
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4190,6 +4263,22 @@ async function createBackend(device) {
|
|
|
4190
4263
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
4191
4264
|
return null;
|
|
4192
4265
|
}
|
|
4266
|
+
} else if (device === "webgl") {
|
|
4267
|
+
if (typeof WebGL2RenderingContext === "undefined") return null;
|
|
4268
|
+
const canvas = new OffscreenCanvas(0, 0);
|
|
4269
|
+
const gl = canvas.getContext("webgl2", {
|
|
4270
|
+
alpha: false,
|
|
4271
|
+
antialias: false,
|
|
4272
|
+
premultipliedAlpha: false,
|
|
4273
|
+
preserveDrawingBuffer: false,
|
|
4274
|
+
depth: false,
|
|
4275
|
+
stencil: false,
|
|
4276
|
+
failIfMajorPerformanceCaveat: true
|
|
4277
|
+
});
|
|
4278
|
+
if (!gl) return null;
|
|
4279
|
+
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4280
|
+
const { WebGLBackend } = await import("./webgl-RSuZKvgc.js");
|
|
4281
|
+
return new WebGLBackend(gl);
|
|
4193
4282
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4194
4283
|
}
|
|
4195
4284
|
/** Retrieve a backend that has been initialized. */
|
|
@@ -4224,4 +4313,4 @@ var UnsupportedRoutineError = class extends Error {
|
|
|
4224
4313
|
};
|
|
4225
4314
|
|
|
4226
4315
|
//#endregion
|
|
4227
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4316
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
@@ -1471,10 +1471,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
|
1471
1471
|
Routines$1["Sort"] = "Sort";
|
|
1472
1472
|
/** Returns `int32` indices of the stably sorted array. */
|
|
1473
1473
|
Routines$1["Argsort"] = "Argsort";
|
|
1474
|
-
/**
|
|
1474
|
+
/**
|
|
1475
|
+
* Solve a triangular system of equations.
|
|
1476
|
+
*
|
|
1477
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
1478
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
1479
|
+
*
|
|
1480
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
1481
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
1482
|
+
*/
|
|
1475
1483
|
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1476
|
-
/**
|
|
1484
|
+
/**
|
|
1485
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
1486
|
+
*
|
|
1487
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
1488
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
1489
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
1490
|
+
*/
|
|
1477
1491
|
Routines$1["Cholesky"] = "Cholesky";
|
|
1492
|
+
/**
|
|
1493
|
+
* LU decomposition of 2D rectangular matrices.
|
|
1494
|
+
*
|
|
1495
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
1496
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
1497
|
+
*
|
|
1498
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
1499
|
+
* triangular matrices. (lower triangular = implicit unit diagonal)
|
|
1500
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
1501
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
1502
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
1503
|
+
*/
|
|
1504
|
+
Routines$1["LU"] = "LU";
|
|
1478
1505
|
return Routines$1;
|
|
1479
1506
|
}({});
|
|
1480
1507
|
function runCpuRoutine(routine, inputs, outputs) {
|
|
@@ -1486,6 +1513,7 @@ function runCpuRoutine(routine, inputs, outputs) {
|
|
|
1486
1513
|
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1487
1514
|
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1488
1515
|
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1516
|
+
case Routines.LU: return runLU(type, inputAr, outputAr);
|
|
1489
1517
|
default:
|
|
1490
1518
|
}
|
|
1491
1519
|
}
|
|
@@ -1546,6 +1574,50 @@ function runCholesky(type, [x], [y]) {
|
|
|
1546
1574
|
}
|
|
1547
1575
|
}
|
|
1548
1576
|
}
|
|
1577
|
+
function runLU(type, [a], [lu, pivots, perm]) {
|
|
1578
|
+
const shape = type.inputShapes[0];
|
|
1579
|
+
if (shape.length < 2) throw new Error("lu: input must be at least 2D");
|
|
1580
|
+
const m = shape[shape.length - 2];
|
|
1581
|
+
const n = shape[shape.length - 1];
|
|
1582
|
+
const r = Math.min(m, n);
|
|
1583
|
+
for (let offset = 0; offset < a.length; offset += m * n) {
|
|
1584
|
+
const ar = a.subarray(offset, offset + m * n);
|
|
1585
|
+
const out = lu.subarray(offset, offset + m * n);
|
|
1586
|
+
const batchIdx = offset / (m * n);
|
|
1587
|
+
const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
|
|
1588
|
+
const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);
|
|
1589
|
+
out.set(ar);
|
|
1590
|
+
for (let i = 0; i < m; i++) p[i] = i;
|
|
1591
|
+
for (let j = 0; j < r; j++) {
|
|
1592
|
+
let maxVal = Math.abs(out[j * n + j]);
|
|
1593
|
+
let maxRow = j;
|
|
1594
|
+
for (let i = j + 1; i < m; i++) {
|
|
1595
|
+
const val = Math.abs(out[i * n + j]);
|
|
1596
|
+
if (val > maxVal) {
|
|
1597
|
+
maxVal = val;
|
|
1598
|
+
maxRow = i;
|
|
1599
|
+
}
|
|
1600
|
+
}
|
|
1601
|
+
piv[j] = maxRow;
|
|
1602
|
+
if (maxRow !== j) {
|
|
1603
|
+
for (let col = 0; col < n; col++) {
|
|
1604
|
+
const tmp = out[j * n + col];
|
|
1605
|
+
out[j * n + col] = out[maxRow * n + col];
|
|
1606
|
+
out[maxRow * n + col] = tmp;
|
|
1607
|
+
}
|
|
1608
|
+
const tmpP = p[j];
|
|
1609
|
+
p[j] = p[maxRow];
|
|
1610
|
+
p[maxRow] = tmpP;
|
|
1611
|
+
}
|
|
1612
|
+
const diag = out[j * n + j];
|
|
1613
|
+
if (diag !== 0) for (let i = j + 1; i < m; i++) {
|
|
1614
|
+
const factor = out[i * n + j] / diag;
|
|
1615
|
+
out[i * n + j] = factor;
|
|
1616
|
+
for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
|
|
1617
|
+
}
|
|
1618
|
+
}
|
|
1619
|
+
}
|
|
1620
|
+
}
|
|
1549
1621
|
|
|
1550
1622
|
//#endregion
|
|
1551
1623
|
//#region src/shape.ts
|
|
@@ -3513,7 +3585,7 @@ var I32 = class {
|
|
|
3513
3585
|
shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
|
|
3514
3586
|
rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
|
|
3515
3587
|
rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
|
|
3516
|
-
eqz =
|
|
3588
|
+
eqz = UNARY_OP("eqz", 69, "i32", "i32");
|
|
3517
3589
|
eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
|
|
3518
3590
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3519
3591
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
@@ -3978,7 +4050,7 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3978
4050
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3979
4051
|
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3980
4052
|
else dtyF(cg, op, dtype).max();
|
|
3981
|
-
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
4053
|
+
else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
|
|
3982
4054
|
const a = cg.local.declare(cg.i32);
|
|
3983
4055
|
const b = cg.local.declare(cg.i32);
|
|
3984
4056
|
cg.local.set(b);
|
|
@@ -4128,7 +4200,8 @@ function dtyF(cg, op, dtype) {
|
|
|
4128
4200
|
const devices = [
|
|
4129
4201
|
"cpu",
|
|
4130
4202
|
"wasm",
|
|
4131
|
-
"webgpu"
|
|
4203
|
+
"webgpu",
|
|
4204
|
+
"webgl"
|
|
4132
4205
|
];
|
|
4133
4206
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
4134
4207
|
initializedBackends.set("cpu", new CpuBackend());
|
|
@@ -4167,7 +4240,7 @@ async function createBackend(device) {
|
|
|
4167
4240
|
if (!navigator.gpu) return null;
|
|
4168
4241
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4169
4242
|
if (!adapter) return null;
|
|
4170
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4243
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Db2JrNBr.cjs"));
|
|
4171
4244
|
const importantLimits = [
|
|
4172
4245
|
"maxBufferSize",
|
|
4173
4246
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4191,6 +4264,22 @@ async function createBackend(device) {
|
|
|
4191
4264
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
4192
4265
|
return null;
|
|
4193
4266
|
}
|
|
4267
|
+
} else if (device === "webgl") {
|
|
4268
|
+
if (typeof WebGL2RenderingContext === "undefined") return null;
|
|
4269
|
+
const canvas = new OffscreenCanvas(0, 0);
|
|
4270
|
+
const gl = canvas.getContext("webgl2", {
|
|
4271
|
+
alpha: false,
|
|
4272
|
+
antialias: false,
|
|
4273
|
+
premultipliedAlpha: false,
|
|
4274
|
+
preserveDrawingBuffer: false,
|
|
4275
|
+
depth: false,
|
|
4276
|
+
stencil: false,
|
|
4277
|
+
failIfMajorPerformanceCaveat: true
|
|
4278
|
+
});
|
|
4279
|
+
if (!gl) return null;
|
|
4280
|
+
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4281
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-ClIYb8jP.cjs"));
|
|
4282
|
+
return new WebGLBackend(gl);
|
|
4194
4283
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4195
4284
|
}
|
|
4196
4285
|
/** Retrieve a backend that has been initialized. */
|
|
@@ -4507,6 +4596,12 @@ Object.defineProperty(exports, 'toposort', {
|
|
|
4507
4596
|
return toposort;
|
|
4508
4597
|
}
|
|
4509
4598
|
});
|
|
4599
|
+
Object.defineProperty(exports, 'tuneNullopt', {
|
|
4600
|
+
enumerable: true,
|
|
4601
|
+
get: function () {
|
|
4602
|
+
return tuneNullopt;
|
|
4603
|
+
}
|
|
4604
|
+
});
|
|
4510
4605
|
Object.defineProperty(exports, 'tuneWebgpu', {
|
|
4511
4606
|
enumerable: true,
|
|
4512
4607
|
get: function () {
|