@jax-js/jax 0.1.4 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-Bu9GY6sK.cjs → backend-D7s-Retx.cjs} +122 -8
- package/dist/{backend-tngXtWe4.js → backend-Dx6Ob2D1.js} +111 -9
- package/dist/index.cjs +1059 -208
- package/dist/index.d.cts +429 -21
- package/dist/index.d.ts +429 -21
- package/dist/index.js +1059 -209
- package/dist/webgl-CLLvzJlO.js +522 -0
- package/dist/webgl-CyfzNW8T.cjs +522 -0
- package/dist/{webgpu-ChVgx3b6.js → webgpu-C-VfevQW.js} +296 -3
- package/dist/{webgpu-Oj3Kd-kd.cjs → webgpu-rraa6dfz.cjs} +296 -3
- package/package.json +1 -1
package/README.md
CHANGED
|
@@ -58,13 +58,13 @@ import { numpy as np } from "@jax-js/jax";
|
|
|
58
58
|
const ar = np.array([1, 2, 3]);
|
|
59
59
|
```
|
|
60
60
|
|
|
61
|
-
By default, this is a float32 array, but you can
|
|
61
|
+
By default, this is a float32 array, but you can specify a different dtype:
|
|
62
62
|
|
|
63
63
|
```ts
|
|
64
|
-
const ar = np.array([1, 2, 3], { dtype: np.
|
|
64
|
+
const ar = np.array([1, 2, 3], { dtype: np.int32 });
|
|
65
65
|
```
|
|
66
66
|
|
|
67
|
-
For more efficient construction,
|
|
67
|
+
For more efficient construction, create an array from a JS `TypedArray` buffer:
|
|
68
68
|
|
|
69
69
|
```ts
|
|
70
70
|
const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
|
|
@@ -223,14 +223,18 @@ Note that you need to use `type` alias syntax rather than `interface` to define
|
|
|
223
223
|
Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory
|
|
224
224
|
and determines how to execute compiled operations on them.
|
|
225
225
|
|
|
226
|
-
There are currently
|
|
226
|
+
There are currently 4 devices in jax-js:
|
|
227
227
|
|
|
228
|
-
- `cpu`: Slow,
|
|
228
|
+
- `cpu`: Slow, interpreted JS, only meant for debugging.
|
|
229
229
|
- `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
|
|
230
230
|
- `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
|
|
231
231
|
[supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
|
|
232
|
+
- `webgl`: [WebGL2](https://developer.mozilla.org/en-US/docs/Web/API/WebGL2RenderingContext), via
|
|
233
|
+
fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much
|
|
234
|
+
slower than WebGPU. It's offered on a best-effort basis and not as well-supported.
|
|
232
235
|
|
|
233
|
-
|
|
236
|
+
**We recommend `webgpu` for best performance, especially when running neural networks.** The default
|
|
237
|
+
device is `wasm`, but you can change this at startup time:
|
|
234
238
|
|
|
235
239
|
```ts
|
|
236
240
|
import { defaultDevice, init } from "@jax-js/jax";
|
|
@@ -333,7 +337,6 @@ Contributions are welcomed! Some fruitful areas to look into:
|
|
|
333
337
|
able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
|
|
334
338
|
to help with model performance debugging.
|
|
335
339
|
- Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
|
|
336
|
-
- Adding WebGL runtime for older browsers that don't support WebGPU.
|
|
337
340
|
- Making a fast transformer inference engine, comparing against onnxruntime-web.
|
|
338
341
|
|
|
339
342
|
You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
|
|
@@ -151,6 +151,19 @@ function normalizeAxis(axis, ndim) {
|
|
|
151
151
|
return sorted(seen);
|
|
152
152
|
}
|
|
153
153
|
}
|
|
154
|
+
/**
 * Validate one integer index, or a collection of distinct integer indices.
 *
 * A plain number must be an integer. Any other argument is iterated, and
 * every element must be an integer that has not appeared earlier in the
 * collection. Returns nothing on success.
 *
 * @param {number | Iterable<number>} indices - index or indices to check.
 * @throws {TypeError} if any value is not an integer.
 * @throws {Error} if the collection contains the same index twice.
 */
function checkInts(indices) {
  if (typeof indices !== "number") {
    const visited = /* @__PURE__ */ new Set();
    for (const idx of indices) {
      if (!Number.isInteger(idx)) {
        throw new TypeError(`Expected integer indices, got ${idx}`);
      }
      if (visited.has(idx)) {
        throw new Error(`Duplicate index ${idx} passed to function`);
      }
      visited.add(idx);
    }
    return;
  }
  if (!Number.isInteger(indices)) {
    throw new TypeError(`Expected integer index, got ${indices}`);
  }
}
|
|
154
167
|
function range(start, stop, step = 1) {
|
|
155
168
|
if (stop === void 0) {
|
|
156
169
|
stop = start;
|
|
@@ -1471,10 +1484,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
|
1471
1484
|
Routines$1["Sort"] = "Sort";
|
|
1472
1485
|
/** Returns `int32` indices of the stably sorted array. */
|
|
1473
1486
|
Routines$1["Argsort"] = "Argsort";
|
|
1474
|
-
/**
|
|
1487
|
+
/**
|
|
1488
|
+
* Solve a triangular system of equations.
|
|
1489
|
+
*
|
|
1490
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
1491
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
1492
|
+
*
|
|
1493
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
1494
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
1495
|
+
*/
|
|
1475
1496
|
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1476
|
-
/**
|
|
1497
|
+
/**
|
|
1498
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
1499
|
+
*
|
|
1500
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
1501
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
1502
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
1503
|
+
*/
|
|
1477
1504
|
Routines$1["Cholesky"] = "Cholesky";
|
|
1505
|
+
/**
|
|
1506
|
+
* LU decomposition of 2D rectangular matrices.
|
|
1507
|
+
*
|
|
1508
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
1509
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
1510
|
+
*
|
|
1511
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
1512
|
+
* triangular matrices. (lower triangular = implicit unit diagonal)
|
|
1513
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
1514
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
1515
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
1516
|
+
*/
|
|
1517
|
+
Routines$1["LU"] = "LU";
|
|
1478
1518
|
return Routines$1;
|
|
1479
1519
|
}({});
|
|
1480
1520
|
function runCpuRoutine(routine, inputs, outputs) {
|
|
@@ -1486,6 +1526,7 @@ function runCpuRoutine(routine, inputs, outputs) {
|
|
|
1486
1526
|
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1487
1527
|
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1488
1528
|
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1529
|
+
case Routines.LU: return runLU(type, inputAr, outputAr);
|
|
1489
1530
|
default:
|
|
1490
1531
|
}
|
|
1491
1532
|
}
|
|
@@ -1546,6 +1587,50 @@ function runCholesky(type, [x], [y]) {
|
|
|
1546
1587
|
}
|
|
1547
1588
|
}
|
|
1548
1589
|
}
|
|
1590
|
+
/**
 * CPU implementation of the LU routine: LU decomposition with partial
 * (row) pivoting of a batch of row-major `[..., M, N]` matrices.
 *
 * @param type - routine type info; `type.inputShapes[0]` is the input shape.
 * @param a - flat input buffer holding the batch of M×N matrices (not modified).
 * @param lu - output buffer, same size as `a`. Receives the multipliers of L
 *   strictly below the diagonal (unit diagonal implied) and U on/above it.
 * @param pivots - output buffer, `min(M, N)` entries per matrix; entry `j` is
 *   the row swapped with row `j` at elimination step `j`.
 * @param perm - output buffer, `M` entries per matrix; the accumulated row
 *   permutation (starts as the identity and is swapped alongside the rows).
 * @throws {Error} if the input has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const size = m * n;
  // Count batches from the leading dimensions rather than from `a.length`:
  // when N = 0 the input is empty, but each batch still has M permutation
  // entries that must be written as the identity.
  let batches = 1;
  for (let d = 0; d < shape.length - 2; d++) batches *= shape[d];
  for (let b = 0; b < batches; b++) {
    const out = lu.subarray(b * size, (b + 1) * size);
    const piv = pivots.subarray(b * r, (b + 1) * r);
    const p = perm.subarray(b * m, (b + 1) * m);
    out.set(a.subarray(b * size, (b + 1) * size));
    for (let i = 0; i < m; i++) p[i] = i;
    for (let j = 0; j < r; j++) {
      // Partial pivoting: choose the row (at or below j) with the largest
      // absolute value in column j.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;
      if (maxRow !== j) {
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }
      // A zero pivot means every entry of column j at or below row j is zero,
      // so there is nothing to eliminate; skip instead of dividing by zero.
      const diag = out[j * n + j];
      if (diag !== 0) {
        for (let i = j + 1; i < m; i++) {
          const factor = out[i * n + j] / diag;
          out[i * n + j] = factor;
          for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
        }
      }
    }
  }
}
|
|
1549
1634
|
|
|
1550
1635
|
//#endregion
|
|
1551
1636
|
//#region src/shape.ts
|
|
@@ -2234,10 +2319,10 @@ function tuneWebgpu(kernel) {
|
|
|
2234
2319
|
upcastedAxis.add(choices[0][2]);
|
|
2235
2320
|
} else break;
|
|
2236
2321
|
}
|
|
2237
|
-
if (/
|
|
2322
|
+
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2238
2323
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2239
2324
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2240
|
-
else for (const splits of [4]) if (s % splits === 0) {
|
|
2325
|
+
else for (const splits of [8, 4]) if (s % splits === 0) {
|
|
2241
2326
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2242
2327
|
break;
|
|
2243
2328
|
}
|
|
@@ -3513,7 +3598,7 @@ var I32 = class {
|
|
|
3513
3598
|
shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
|
|
3514
3599
|
rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
|
|
3515
3600
|
rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
|
|
3516
|
-
eqz =
|
|
3601
|
+
eqz = UNARY_OP("eqz", 69, "i32", "i32");
|
|
3517
3602
|
eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
|
|
3518
3603
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3519
3604
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
@@ -3978,7 +4063,7 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3978
4063
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3979
4064
|
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3980
4065
|
else dtyF(cg, op, dtype).max();
|
|
3981
|
-
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
4066
|
+
else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
|
|
3982
4067
|
const a = cg.local.declare(cg.i32);
|
|
3983
4068
|
const b = cg.local.declare(cg.i32);
|
|
3984
4069
|
cg.local.set(b);
|
|
@@ -4128,7 +4213,8 @@ function dtyF(cg, op, dtype) {
|
|
|
4128
4213
|
const devices = [
|
|
4129
4214
|
"cpu",
|
|
4130
4215
|
"wasm",
|
|
4131
|
-
"webgpu"
|
|
4216
|
+
"webgpu",
|
|
4217
|
+
"webgl"
|
|
4132
4218
|
];
|
|
4133
4219
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
4134
4220
|
initializedBackends.set("cpu", new CpuBackend());
|
|
@@ -4167,7 +4253,7 @@ async function createBackend(device) {
|
|
|
4167
4253
|
if (!navigator.gpu) return null;
|
|
4168
4254
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4169
4255
|
if (!adapter) return null;
|
|
4170
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
4256
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-rraa6dfz.cjs"));
|
|
4171
4257
|
const importantLimits = [
|
|
4172
4258
|
"maxBufferSize",
|
|
4173
4259
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4191,6 +4277,22 @@ async function createBackend(device) {
|
|
|
4191
4277
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
4192
4278
|
return null;
|
|
4193
4279
|
}
|
|
4280
|
+
} else if (device === "webgl") {
|
|
4281
|
+
if (typeof WebGL2RenderingContext === "undefined") return null;
|
|
4282
|
+
const canvas = new OffscreenCanvas(0, 0);
|
|
4283
|
+
const gl = canvas.getContext("webgl2", {
|
|
4284
|
+
alpha: false,
|
|
4285
|
+
antialias: false,
|
|
4286
|
+
premultipliedAlpha: false,
|
|
4287
|
+
preserveDrawingBuffer: false,
|
|
4288
|
+
depth: false,
|
|
4289
|
+
stencil: false,
|
|
4290
|
+
failIfMajorPerformanceCaveat: true
|
|
4291
|
+
});
|
|
4292
|
+
if (!gl) return null;
|
|
4293
|
+
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4294
|
+
const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-CyfzNW8T.cjs"));
|
|
4295
|
+
return new WebGLBackend(gl);
|
|
4194
4296
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4195
4297
|
}
|
|
4196
4298
|
/** Retrieve a backend that has been initialized. */
|
|
@@ -4357,6 +4459,12 @@ Object.defineProperty(exports, 'checkAxis', {
|
|
|
4357
4459
|
return checkAxis;
|
|
4358
4460
|
}
|
|
4359
4461
|
});
|
|
4462
|
+
Object.defineProperty(exports, 'checkInts', {
|
|
4463
|
+
enumerable: true,
|
|
4464
|
+
get: function () {
|
|
4465
|
+
return checkInts;
|
|
4466
|
+
}
|
|
4467
|
+
});
|
|
4360
4468
|
Object.defineProperty(exports, 'deepEqual', {
|
|
4361
4469
|
enumerable: true,
|
|
4362
4470
|
get: function () {
|
|
@@ -4507,6 +4615,12 @@ Object.defineProperty(exports, 'toposort', {
|
|
|
4507
4615
|
return toposort;
|
|
4508
4616
|
}
|
|
4509
4617
|
});
|
|
4618
|
+
Object.defineProperty(exports, 'tuneNullopt', {
|
|
4619
|
+
enumerable: true,
|
|
4620
|
+
get: function () {
|
|
4621
|
+
return tuneNullopt;
|
|
4622
|
+
}
|
|
4623
|
+
});
|
|
4510
4624
|
Object.defineProperty(exports, 'tuneWebgpu', {
|
|
4511
4625
|
enumerable: true,
|
|
4512
4626
|
get: function () {
|
|
@@ -150,6 +150,19 @@ function normalizeAxis(axis, ndim) {
|
|
|
150
150
|
return sorted(seen);
|
|
151
151
|
}
|
|
152
152
|
}
|
|
153
|
+
/**
 * Validate one integer index, or a collection of distinct integer indices.
 *
 * A plain number must be an integer. Any other argument is iterated, and
 * every element must be an integer that has not appeared earlier in the
 * collection. Returns nothing on success.
 *
 * @param {number | Iterable<number>} indices - index or indices to check.
 * @throws {TypeError} if any value is not an integer.
 * @throws {Error} if the collection contains the same index twice.
 */
function checkInts(indices) {
  if (typeof indices !== "number") {
    const visited = /* @__PURE__ */ new Set();
    for (const idx of indices) {
      if (!Number.isInteger(idx)) {
        throw new TypeError(`Expected integer indices, got ${idx}`);
      }
      if (visited.has(idx)) {
        throw new Error(`Duplicate index ${idx} passed to function`);
      }
      visited.add(idx);
    }
    return;
  }
  if (!Number.isInteger(indices)) {
    throw new TypeError(`Expected integer index, got ${indices}`);
  }
}
|
|
153
166
|
function range(start, stop, step = 1) {
|
|
154
167
|
if (stop === void 0) {
|
|
155
168
|
stop = start;
|
|
@@ -1470,10 +1483,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
|
|
|
1470
1483
|
Routines$1["Sort"] = "Sort";
|
|
1471
1484
|
/** Returns `int32` indices of the stably sorted array. */
|
|
1472
1485
|
Routines$1["Argsort"] = "Argsort";
|
|
1473
|
-
/**
|
|
1486
|
+
/**
|
|
1487
|
+
* Solve a triangular system of equations.
|
|
1488
|
+
*
|
|
1489
|
+
* The first batch of inputs `A` should be of shape `[..., N, N]` and upper
|
|
1490
|
+
* triangular, while the second batch `B` should be of shape `[..., M, N]`.
|
|
1491
|
+
*
|
|
1492
|
+
* Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
|
|
1493
|
+
* triangular matrix. This is equivalent to `X = B @ A^-T`.
|
|
1494
|
+
*/
|
|
1474
1495
|
Routines$1["TriangularSolve"] = "TriangularSolve";
|
|
1475
|
-
/**
|
|
1496
|
+
/**
|
|
1497
|
+
* Cholesky decomposition of 2D positive semi-definite matrices.
|
|
1498
|
+
*
|
|
1499
|
+
* The input batch should be of shape `[..., N, N]`, and the output batch is
|
|
1500
|
+
* of the same shape, containing the lower-triangular matrix `L` such that
|
|
1501
|
+
* `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
|
|
1502
|
+
*/
|
|
1476
1503
|
Routines$1["Cholesky"] = "Cholesky";
|
|
1504
|
+
/**
|
|
1505
|
+
* LU decomposition of 2D rectangular matrices.
|
|
1506
|
+
*
|
|
1507
|
+
* The input is a batch of shape `[..., M, N]`, and the output is a tuple of
|
|
1508
|
+
* three arrays: `LU, Pivots, Permutation`.
|
|
1509
|
+
*
|
|
1510
|
+
* - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
|
|
1511
|
+
* triangular matrices. (lower triangular = implicit unit diagonal)
|
|
1512
|
+
* - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
|
|
1513
|
+
* - `Permutation` is of shape `[..., M]`, containing the permutation vector
|
|
1514
|
+
* such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
|
|
1515
|
+
*/
|
|
1516
|
+
Routines$1["LU"] = "LU";
|
|
1477
1517
|
return Routines$1;
|
|
1478
1518
|
}({});
|
|
1479
1519
|
function runCpuRoutine(routine, inputs, outputs) {
|
|
@@ -1485,6 +1525,7 @@ function runCpuRoutine(routine, inputs, outputs) {
|
|
|
1485
1525
|
case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
|
|
1486
1526
|
case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
|
|
1487
1527
|
case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
|
|
1528
|
+
case Routines.LU: return runLU(type, inputAr, outputAr);
|
|
1488
1529
|
default:
|
|
1489
1530
|
}
|
|
1490
1531
|
}
|
|
@@ -1545,6 +1586,50 @@ function runCholesky(type, [x], [y]) {
|
|
|
1545
1586
|
}
|
|
1546
1587
|
}
|
|
1547
1588
|
}
|
|
1589
|
+
/**
 * CPU implementation of the LU routine: LU decomposition with partial
 * (row) pivoting of a batch of row-major `[..., M, N]` matrices.
 *
 * @param type - routine type info; `type.inputShapes[0]` is the input shape.
 * @param a - flat input buffer holding the batch of M×N matrices (not modified).
 * @param lu - output buffer, same size as `a`. Receives the multipliers of L
 *   strictly below the diagonal (unit diagonal implied) and U on/above it.
 * @param pivots - output buffer, `min(M, N)` entries per matrix; entry `j` is
 *   the row swapped with row `j` at elimination step `j`.
 * @param perm - output buffer, `M` entries per matrix; the accumulated row
 *   permutation (starts as the identity and is swapped alongside the rows).
 * @throws {Error} if the input has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const size = m * n;
  // Count batches from the leading dimensions rather than from `a.length`:
  // when N = 0 the input is empty, but each batch still has M permutation
  // entries that must be written as the identity.
  let batches = 1;
  for (let d = 0; d < shape.length - 2; d++) batches *= shape[d];
  for (let b = 0; b < batches; b++) {
    const out = lu.subarray(b * size, (b + 1) * size);
    const piv = pivots.subarray(b * r, (b + 1) * r);
    const p = perm.subarray(b * m, (b + 1) * m);
    out.set(a.subarray(b * size, (b + 1) * size));
    for (let i = 0; i < m; i++) p[i] = i;
    for (let j = 0; j < r; j++) {
      // Partial pivoting: choose the row (at or below j) with the largest
      // absolute value in column j.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;
      if (maxRow !== j) {
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }
      // A zero pivot means every entry of column j at or below row j is zero,
      // so there is nothing to eliminate; skip instead of dividing by zero.
      const diag = out[j * n + j];
      if (diag !== 0) {
        for (let i = j + 1; i < m; i++) {
          const factor = out[i * n + j] / diag;
          out[i * n + j] = factor;
          for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
        }
      }
    }
  }
}
|
|
1548
1633
|
|
|
1549
1634
|
//#endregion
|
|
1550
1635
|
//#region src/shape.ts
|
|
@@ -2233,10 +2318,10 @@ function tuneWebgpu(kernel) {
|
|
|
2233
2318
|
upcastedAxis.add(choices[0][2]);
|
|
2234
2319
|
} else break;
|
|
2235
2320
|
}
|
|
2236
|
-
if (/
|
|
2321
|
+
if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
|
|
2237
2322
|
const s = dim.st.shape[dim.unroll - 1];
|
|
2238
2323
|
if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
|
|
2239
|
-
else for (const splits of [4]) if (s % splits === 0) {
|
|
2324
|
+
else for (const splits of [8, 4]) if (s % splits === 0) {
|
|
2240
2325
|
dim.applyUnroll(dim.unroll - 1, splits);
|
|
2241
2326
|
break;
|
|
2242
2327
|
}
|
|
@@ -3512,7 +3597,7 @@ var I32 = class {
|
|
|
3512
3597
|
shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
|
|
3513
3598
|
rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
|
|
3514
3599
|
rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
|
|
3515
|
-
eqz =
|
|
3600
|
+
eqz = UNARY_OP("eqz", 69, "i32", "i32");
|
|
3516
3601
|
eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
|
|
3517
3602
|
ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
|
|
3518
3603
|
trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
|
|
@@ -3977,7 +4062,7 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3977
4062
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3978
4063
|
else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
|
|
3979
4064
|
else dtyF(cg, op, dtype).max();
|
|
3980
|
-
else if (dtype === DType.Int32 || dtype === DType.Uint32) {
|
|
4065
|
+
else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
|
|
3981
4066
|
const a = cg.local.declare(cg.i32);
|
|
3982
4067
|
const b = cg.local.declare(cg.i32);
|
|
3983
4068
|
cg.local.set(b);
|
|
@@ -4127,7 +4212,8 @@ function dtyF(cg, op, dtype) {
|
|
|
4127
4212
|
const devices = [
|
|
4128
4213
|
"cpu",
|
|
4129
4214
|
"wasm",
|
|
4130
|
-
"webgpu"
|
|
4215
|
+
"webgpu",
|
|
4216
|
+
"webgl"
|
|
4131
4217
|
];
|
|
4132
4218
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
4133
4219
|
initializedBackends.set("cpu", new CpuBackend());
|
|
@@ -4166,7 +4252,7 @@ async function createBackend(device) {
|
|
|
4166
4252
|
if (!navigator.gpu) return null;
|
|
4167
4253
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
4168
4254
|
if (!adapter) return null;
|
|
4169
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
4255
|
+
const { WebGPUBackend } = await import("./webgpu-C-VfevQW.js");
|
|
4170
4256
|
const importantLimits = [
|
|
4171
4257
|
"maxBufferSize",
|
|
4172
4258
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -4190,6 +4276,22 @@ async function createBackend(device) {
|
|
|
4190
4276
|
console.error("Unexpected error requesting WebGPU device:", error);
|
|
4191
4277
|
return null;
|
|
4192
4278
|
}
|
|
4279
|
+
} else if (device === "webgl") {
|
|
4280
|
+
if (typeof WebGL2RenderingContext === "undefined") return null;
|
|
4281
|
+
const canvas = new OffscreenCanvas(0, 0);
|
|
4282
|
+
const gl = canvas.getContext("webgl2", {
|
|
4283
|
+
alpha: false,
|
|
4284
|
+
antialias: false,
|
|
4285
|
+
premultipliedAlpha: false,
|
|
4286
|
+
preserveDrawingBuffer: false,
|
|
4287
|
+
depth: false,
|
|
4288
|
+
stencil: false,
|
|
4289
|
+
failIfMajorPerformanceCaveat: true
|
|
4290
|
+
});
|
|
4291
|
+
if (!gl) return null;
|
|
4292
|
+
if (!gl.getExtension("EXT_color_buffer_float")) return null;
|
|
4293
|
+
const { WebGLBackend } = await import("./webgl-CLLvzJlO.js");
|
|
4294
|
+
return new WebGLBackend(gl);
|
|
4193
4295
|
} else throw new Error(`Backend not found: ${device}`);
|
|
4194
4296
|
}
|
|
4195
4297
|
/** Retrieve a backend that has been initialized. */
|
|
@@ -4224,4 +4326,4 @@ var UnsupportedRoutineError = class extends Error {
|
|
|
4224
4326
|
};
|
|
4225
4327
|
|
|
4226
4328
|
//#endregion
|
|
4227
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|
|
4329
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
|