@jax-js/jax 0.1.4 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/README.md CHANGED
@@ -58,13 +58,13 @@ import { numpy as np } from "@jax-js/jax";
58
58
  const ar = np.array([1, 2, 3]);
59
59
  ```
60
60
 
61
- By default, this is a float32 array, but you can also specify a dtype explicitly:
61
+ By default, this is a float32 array, but you can specify a different dtype:
62
62
 
63
63
  ```ts
64
- const ar = np.array([1, 2, 3], { dtype: np.float32 });
64
+ const ar = np.array([1, 2, 3], { dtype: np.int32 });
65
65
  ```
66
66
 
67
- For more efficient construction, you can create an array from a JS `TypedArray` buffer:
67
+ For more efficient construction, create an array from a JS `TypedArray` buffer:
68
68
 
69
69
  ```ts
70
70
  const buf = new Float32Array([10, 20, 30, 100, 200, 300]);
@@ -223,14 +223,18 @@ Note that you need to use `type` alias syntax rather than `interface` to define
223
223
  Similar to JAX, jax-js has a concept of "devices" which are a backend that stores Arrays in memory
224
224
  and determines how to execute compiled operations on them.
225
225
 
226
- There are currently 3 devices in jax-js:
226
+ There are currently 4 devices in jax-js:
227
227
 
228
- - `cpu`: Slow, mostly for debugging purposes.
228
+ - `cpu`: Slow, interpreted JS, only meant for debugging.
229
229
  - `wasm`: [WebAssembly](https://webassembly.org/), currently single-threaded and blocking.
230
230
  - `webgpu`: [WebGPU](https://developer.mozilla.org/en-US/docs/Web/API/WebGPU_API), available on
231
231
  [supported browsers](https://caniuse.com/webgpu) (Chrome, Firefox, Safari, iOS).
232
+ - `webgl`: [WebGL2](https://developer.mozilla.org/en-US/docs/Web/API/WebGL2RenderingContext), via
233
+ fragment shaders. This is an older graphics API that runs on almost all browsers, but it is much
234
+ slower than WebGPU. It's offered on a best-effort basis and is not as well supported.
232
235
 
233
- The default device is `wasm`, but you can change this at startup time:
236
+ **We recommend `webgpu` for best performance, especially when running neural networks.** The default
237
+ device is `wasm`, but you can change this at startup time:
234
238
 
235
239
  ```ts
236
240
  import { defaultDevice, init } from "@jax-js/jax";
@@ -333,7 +337,6 @@ Contributions are welcomed! Some fruitful areas to look into:
333
337
  able to generate `traceEvents` from backends (especially on GPU, with precise timestamp queries)
334
338
  to help with model performance debugging.
335
339
  - Helping the JIT compiler to fuse operations in more cases, like `tanh` branches.
336
- - Adding WebGL runtime for older browsers that don't support WebGPU.
337
340
  - Making a fast transformer inference engine, comparing against onnxruntime-web.
338
341
 
339
342
  You may join our [Discord server](https://discord.gg/BW6YsCd4Tf) and chat with the community.
@@ -151,6 +151,19 @@ function normalizeAxis(axis, ndim) {
151
151
  return sorted(seen);
152
152
  }
153
153
  }
/**
 * Validate that `indices` is a single integer, or an iterable of integers
 * that contains no duplicates.
 *
 * @param {number | Iterable<number>} indices - One index, or a collection of indices.
 * @returns {void}
 * @throws {TypeError} If an index is not an integer, or if `indices` is
 *   neither a number nor an iterable.
 * @throws {Error} If the same index appears more than once.
 */
function checkInts(indices) {
  if (typeof indices === "number") {
    if (!Number.isInteger(indices)) throw new TypeError(`Expected integer index, got ${indices}`);
    return;
  }
  // Reject null/undefined/non-iterables up front with a descriptive message,
  // rather than letting `for...of` raise a generic "is not iterable" error.
  if (indices == null || typeof indices[Symbol.iterator] !== "function") {
    throw new TypeError(`Expected integer or iterable of integers, got ${String(indices)}`);
  }
  const seen = /* @__PURE__ */ new Set();
  for (const i of indices) {
    if (!Number.isInteger(i)) throw new TypeError(`Expected integer indices, got ${i}`);
    if (seen.has(i)) throw new Error(`Duplicate index ${i} passed to function`);
    seen.add(i);
  }
}
154
167
  function range(start, stop, step = 1) {
155
168
  if (stop === void 0) {
156
169
  stop = start;
@@ -1471,10 +1484,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
1471
1484
  Routines$1["Sort"] = "Sort";
1472
1485
  /** Returns `int32` indices of the stably sorted array. */
1473
1486
  Routines$1["Argsort"] = "Argsort";
1474
- /** Solve a triangular system of questions. */
1487
+ /**
1488
+ * Solve a triangular system of equations.
1489
+ *
1490
+ * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
1491
+ * triangular, while the second batch `B` should be of shape `[..., M, N]`.
1492
+ *
1493
+ * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
1494
+ * triangular matrix. This is equivalent to `X = B @ A^-T`.
1495
+ */
1475
1496
  Routines$1["TriangularSolve"] = "TriangularSolve";
1476
- /** Cholesky decomposition of 2D positive semi-definite matrices. */
1497
+ /**
1498
+ * Cholesky decomposition of 2D positive semi-definite matrices.
1499
+ *
1500
+ * The input batch should be of shape `[..., N, N]`, and the output batch is
1501
+ * of the same shape, containing the lower-triangular matrix `L` such that
1502
+ * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
1503
+ */
1477
1504
  Routines$1["Cholesky"] = "Cholesky";
1505
+ /**
1506
+ * LU decomposition of 2D rectangular matrices.
1507
+ *
1508
+ * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
1509
+ * three arrays: `LU, Pivots, Permutation`.
1510
+ *
1511
+ * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
1512
+ * triangular matrices (the lower-triangular factor `L` has an implicit unit diagonal).
1513
+ * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
1514
+ * - `Permutation` is of shape `[..., M]`, containing the permutation vector
1515
+ * such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
1516
+ */
1517
+ Routines$1["LU"] = "LU";
1478
1518
  return Routines$1;
1479
1519
  }({});
1480
1520
  function runCpuRoutine(routine, inputs, outputs) {
@@ -1486,6 +1526,7 @@ function runCpuRoutine(routine, inputs, outputs) {
1486
1526
  case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1487
1527
  case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1488
1528
  case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1529
+ case Routines.LU: return runLU(type, inputAr, outputAr);
1489
1530
  default:
1490
1531
  }
1491
1532
  }
@@ -1546,6 +1587,50 @@ function runCholesky(type, [x], [y]) {
1546
1587
  }
1547
1588
  }
1548
1589
  }
/**
 * CPU routine: batched LU decomposition with partial (row) pivoting.
 *
 * Each `m x n` matrix of the input batch is copied into `lu` and factorized
 * in place. After the call, the strictly-lower part of each output matrix
 * holds the elimination factors and the rest holds the upper-triangular part.
 *
 * @param type - Routine type; only `type.inputShapes[0]` (shape `[..., m, n]`) is read.
 * @param inputs - `[a]`: flat row-major typed array holding the input batch.
 * @param outputs - `[lu, pivots, perm]`: flat typed arrays that are written to.
 *   `lu` has `m * n` entries per batch, `pivots` has `min(m, n)` entries per
 *   batch (the row chosen as pivot for each column), and `perm` has `m`
 *   entries per batch (the identity permutation with all row swaps applied).
 * @throws {Error} If the input shape has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const matSize = m * n;

  // Derive the batch count from the leading dimensions instead of from
  // `a.length`. For matrices with n === 0 the input has no elements, yet each
  // batch still needs its `perm` row initialized to the identity.
  let batches = 1;
  for (let d = 0; d < shape.length - 2; d++) batches *= shape[d];

  for (let batchIdx = 0; batchIdx < batches; batchIdx++) {
    const offset = batchIdx * matSize;
    const ar = a.subarray(offset, offset + matSize);
    const out = lu.subarray(offset, offset + matSize);
    const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
    const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);
    out.set(ar);
    for (let i = 0; i < m; i++) p[i] = i;

    for (let j = 0; j < r; j++) {
      // Partial pivoting: pick the row at or below j with the largest |value|
      // in column j.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;

      if (maxRow !== j) {
        // Swap whole rows (including already-computed factors) and record the
        // swap in the permutation vector.
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }

      // A zero pivot means the whole column below is zero; skip elimination
      // to avoid dividing by zero.
      const diag = out[j * n + j];
      if (diag !== 0) {
        for (let i = j + 1; i < m; i++) {
          const factor = out[i * n + j] / diag;
          out[i * n + j] = factor;
          for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
        }
      }
    }
  }
}
1549
1634
 
1550
1635
  //#endregion
1551
1636
  //#region src/shape.ts
@@ -2234,10 +2319,10 @@ function tuneWebgpu(kernel) {
2234
2319
  upcastedAxis.add(choices[0][2]);
2235
2320
  } else break;
2236
2321
  }
2237
- if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2322
+ if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2238
2323
  const s = dim.st.shape[dim.unroll - 1];
2239
2324
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2240
- else for (const splits of [4]) if (s % splits === 0) {
2325
+ else for (const splits of [8, 4]) if (s % splits === 0) {
2241
2326
  dim.applyUnroll(dim.unroll - 1, splits);
2242
2327
  break;
2243
2328
  }
@@ -3513,7 +3598,7 @@ var I32 = class {
3513
3598
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3514
3599
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3515
3600
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3516
- eqz = BINARY_OP("eqz", 69, "i32", "i32", "i32");
3601
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
3517
3602
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3518
3603
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3519
3604
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
@@ -3978,7 +4063,7 @@ function translateExp(cg, funcs, exp, ctx) {
3978
4063
  else throw new UnsupportedOpError(op, dtype, "wasm");
3979
4064
  else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3980
4065
  else dtyF(cg, op, dtype).max();
3981
- else if (dtype === DType.Int32 || dtype === DType.Uint32) {
4066
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3982
4067
  const a = cg.local.declare(cg.i32);
3983
4068
  const b = cg.local.declare(cg.i32);
3984
4069
  cg.local.set(b);
@@ -4128,7 +4213,8 @@ function dtyF(cg, op, dtype) {
4128
4213
  const devices = [
4129
4214
  "cpu",
4130
4215
  "wasm",
4131
- "webgpu"
4216
+ "webgpu",
4217
+ "webgl"
4132
4218
  ];
4133
4219
  const initializedBackends = /* @__PURE__ */ new Map();
4134
4220
  initializedBackends.set("cpu", new CpuBackend());
@@ -4167,7 +4253,7 @@ async function createBackend(device) {
4167
4253
  if (!navigator.gpu) return null;
4168
4254
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4169
4255
  if (!adapter) return null;
4170
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-Oj3Kd-kd.cjs"));
4256
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-rraa6dfz.cjs"));
4171
4257
  const importantLimits = [
4172
4258
  "maxBufferSize",
4173
4259
  "maxComputeInvocationsPerWorkgroup",
@@ -4191,6 +4277,22 @@ async function createBackend(device) {
4191
4277
  console.error("Unexpected error requesting WebGPU device:", error);
4192
4278
  return null;
4193
4279
  }
4280
+ } else if (device === "webgl") {
4281
+ if (typeof WebGL2RenderingContext === "undefined") return null;
4282
+ const canvas = new OffscreenCanvas(0, 0);
4283
+ const gl = canvas.getContext("webgl2", {
4284
+ alpha: false,
4285
+ antialias: false,
4286
+ premultipliedAlpha: false,
4287
+ preserveDrawingBuffer: false,
4288
+ depth: false,
4289
+ stencil: false,
4290
+ failIfMajorPerformanceCaveat: true
4291
+ });
4292
+ if (!gl) return null;
4293
+ if (!gl.getExtension("EXT_color_buffer_float")) return null;
4294
+ const { WebGLBackend } = await Promise.resolve().then(() => require("./webgl-CyfzNW8T.cjs"));
4295
+ return new WebGLBackend(gl);
4194
4296
  } else throw new Error(`Backend not found: ${device}`);
4195
4297
  }
4196
4298
  /** Retrieve a backend that has been initialized. */
@@ -4357,6 +4459,12 @@ Object.defineProperty(exports, 'checkAxis', {
4357
4459
  return checkAxis;
4358
4460
  }
4359
4461
  });
4462
+ Object.defineProperty(exports, 'checkInts', {
4463
+ enumerable: true,
4464
+ get: function () {
4465
+ return checkInts;
4466
+ }
4467
+ });
4360
4468
  Object.defineProperty(exports, 'deepEqual', {
4361
4469
  enumerable: true,
4362
4470
  get: function () {
@@ -4507,6 +4615,12 @@ Object.defineProperty(exports, 'toposort', {
4507
4615
  return toposort;
4508
4616
  }
4509
4617
  });
4618
+ Object.defineProperty(exports, 'tuneNullopt', {
4619
+ enumerable: true,
4620
+ get: function () {
4621
+ return tuneNullopt;
4622
+ }
4623
+ });
4510
4624
  Object.defineProperty(exports, 'tuneWebgpu', {
4511
4625
  enumerable: true,
4512
4626
  get: function () {
@@ -150,6 +150,19 @@ function normalizeAxis(axis, ndim) {
150
150
  return sorted(seen);
151
151
  }
152
152
  }
/**
 * Validate that `indices` is a single integer, or an iterable of integers
 * that contains no duplicates.
 *
 * @param {number | Iterable<number>} indices - One index, or a collection of indices.
 * @returns {void}
 * @throws {TypeError} If an index is not an integer, or if `indices` is
 *   neither a number nor an iterable.
 * @throws {Error} If the same index appears more than once.
 */
function checkInts(indices) {
  if (typeof indices === "number") {
    if (!Number.isInteger(indices)) throw new TypeError(`Expected integer index, got ${indices}`);
    return;
  }
  // Reject null/undefined/non-iterables up front with a descriptive message,
  // rather than letting `for...of` raise a generic "is not iterable" error.
  if (indices == null || typeof indices[Symbol.iterator] !== "function") {
    throw new TypeError(`Expected integer or iterable of integers, got ${String(indices)}`);
  }
  const seen = /* @__PURE__ */ new Set();
  for (const i of indices) {
    if (!Number.isInteger(i)) throw new TypeError(`Expected integer indices, got ${i}`);
    if (seen.has(i)) throw new Error(`Duplicate index ${i} passed to function`);
    seen.add(i);
  }
}
153
166
  function range(start, stop, step = 1) {
154
167
  if (stop === void 0) {
155
168
  stop = start;
@@ -1470,10 +1483,37 @@ let Routines = /* @__PURE__ */ function(Routines$1) {
1470
1483
  Routines$1["Sort"] = "Sort";
1471
1484
  /** Returns `int32` indices of the stably sorted array. */
1472
1485
  Routines$1["Argsort"] = "Argsort";
1473
- /** Solve a triangular system of questions. */
1486
+ /**
1487
+ * Solve a triangular system of equations.
1488
+ *
1489
+ * The first batch of inputs `A` should be of shape `[..., N, N]` and upper
1490
+ * triangular, while the second batch `B` should be of shape `[..., M, N]`.
1491
+ *
1492
+ * Solves for `X` in the equation `A @ X.T = B.T`, where `A` is the
1493
+ * triangular matrix. This is equivalent to `X = B @ A^-T`.
1494
+ */
1474
1495
  Routines$1["TriangularSolve"] = "TriangularSolve";
1475
- /** Cholesky decomposition of 2D positive semi-definite matrices. */
1496
+ /**
1497
+ * Cholesky decomposition of 2D positive semi-definite matrices.
1498
+ *
1499
+ * The input batch should be of shape `[..., N, N]`, and the output batch is
1500
+ * of the same shape, containing the lower-triangular matrix `L` such that
1501
+ * `A = L @ L.T`. Behavior is unspecified if A is not positive semi-definite.
1502
+ */
1476
1503
  Routines$1["Cholesky"] = "Cholesky";
1504
+ /**
1505
+ * LU decomposition of 2D rectangular matrices.
1506
+ *
1507
+ * The input is a batch of shape `[..., M, N]`, and the output is a tuple of
1508
+ * three arrays: `LU, Pivots, Permutation`.
1509
+ *
1510
+ * - `LU` is of shape `[..., M, N]`, containing the combined lower and upper
1511
+ * triangular matrices (the lower-triangular factor `L` has an implicit unit diagonal).
1512
+ * - `Pivots` is of shape `[..., min(M, N)]`, containing the row swaps.
1513
+ * - `Permutation` is of shape `[..., M]`, containing the permutation vector
1514
+ * such that `P = eye(M).slice(Permutation)` -> `P @ A = L @ U`.
1515
+ */
1516
+ Routines$1["LU"] = "LU";
1477
1517
  return Routines$1;
1478
1518
  }({});
1479
1519
  function runCpuRoutine(routine, inputs, outputs) {
@@ -1485,6 +1525,7 @@ function runCpuRoutine(routine, inputs, outputs) {
1485
1525
  case Routines.Argsort: return runArgsort(type, inputAr, outputAr);
1486
1526
  case Routines.TriangularSolve: return runTriangularSolve(type, inputAr, outputAr, routine.params);
1487
1527
  case Routines.Cholesky: return runCholesky(type, inputAr, outputAr);
1528
+ case Routines.LU: return runLU(type, inputAr, outputAr);
1488
1529
  default:
1489
1530
  }
1490
1531
  }
@@ -1545,6 +1586,50 @@ function runCholesky(type, [x], [y]) {
1545
1586
  }
1546
1587
  }
1547
1588
  }
/**
 * CPU routine: batched LU decomposition with partial (row) pivoting.
 *
 * Each `m x n` matrix of the input batch is copied into `lu` and factorized
 * in place. After the call, the strictly-lower part of each output matrix
 * holds the elimination factors and the rest holds the upper-triangular part.
 *
 * @param type - Routine type; only `type.inputShapes[0]` (shape `[..., m, n]`) is read.
 * @param inputs - `[a]`: flat row-major typed array holding the input batch.
 * @param outputs - `[lu, pivots, perm]`: flat typed arrays that are written to.
 *   `lu` has `m * n` entries per batch, `pivots` has `min(m, n)` entries per
 *   batch (the row chosen as pivot for each column), and `perm` has `m`
 *   entries per batch (the identity permutation with all row swaps applied).
 * @throws {Error} If the input shape has fewer than two dimensions.
 */
function runLU(type, [a], [lu, pivots, perm]) {
  const shape = type.inputShapes[0];
  if (shape.length < 2) throw new Error("lu: input must be at least 2D");
  const m = shape[shape.length - 2];
  const n = shape[shape.length - 1];
  const r = Math.min(m, n);
  const matSize = m * n;

  // Derive the batch count from the leading dimensions instead of from
  // `a.length`. For matrices with n === 0 the input has no elements, yet each
  // batch still needs its `perm` row initialized to the identity.
  let batches = 1;
  for (let d = 0; d < shape.length - 2; d++) batches *= shape[d];

  for (let batchIdx = 0; batchIdx < batches; batchIdx++) {
    const offset = batchIdx * matSize;
    const ar = a.subarray(offset, offset + matSize);
    const out = lu.subarray(offset, offset + matSize);
    const piv = pivots.subarray(batchIdx * r, (batchIdx + 1) * r);
    const p = perm.subarray(batchIdx * m, (batchIdx + 1) * m);
    out.set(ar);
    for (let i = 0; i < m; i++) p[i] = i;

    for (let j = 0; j < r; j++) {
      // Partial pivoting: pick the row at or below j with the largest |value|
      // in column j.
      let maxVal = Math.abs(out[j * n + j]);
      let maxRow = j;
      for (let i = j + 1; i < m; i++) {
        const val = Math.abs(out[i * n + j]);
        if (val > maxVal) {
          maxVal = val;
          maxRow = i;
        }
      }
      piv[j] = maxRow;

      if (maxRow !== j) {
        // Swap whole rows (including already-computed factors) and record the
        // swap in the permutation vector.
        for (let col = 0; col < n; col++) {
          const tmp = out[j * n + col];
          out[j * n + col] = out[maxRow * n + col];
          out[maxRow * n + col] = tmp;
        }
        const tmpP = p[j];
        p[j] = p[maxRow];
        p[maxRow] = tmpP;
      }

      // A zero pivot means the whole column below is zero; skip elimination
      // to avoid dividing by zero.
      const diag = out[j * n + j];
      if (diag !== 0) {
        for (let i = j + 1; i < m; i++) {
          const factor = out[i * n + j] / diag;
          out[i * n + j] = factor;
          for (let col = j + 1; col < n; col++) out[i * n + col] -= factor * out[j * n + col];
        }
      }
    }
  }
}
1548
1633
 
1549
1634
  //#endregion
1550
1635
  //#region src/shape.ts
@@ -2233,10 +2318,10 @@ function tuneWebgpu(kernel) {
2233
2318
  upcastedAxis.add(choices[0][2]);
2234
2319
  } else break;
2235
2320
  }
2236
- if (/chrome/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2321
+ if (!/Mobi|Android/i.test(navigator.userAgent) && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
2237
2322
  const s = dim.st.shape[dim.unroll - 1];
2238
2323
  if (0 < s && s <= 32) dim.applyUnroll(dim.reduce, s);
2239
- else for (const splits of [4]) if (s % splits === 0) {
2324
+ else for (const splits of [8, 4]) if (s % splits === 0) {
2240
2325
  dim.applyUnroll(dim.unroll - 1, splits);
2241
2326
  break;
2242
2327
  }
@@ -3512,7 +3597,7 @@ var I32 = class {
3512
3597
  shr_u = BINARY_OP("shr_u", 118, "i32", "i32", "i32");
3513
3598
  rotl = BINARY_OP("rotl", 119, "i32", "i32", "i32");
3514
3599
  rotr = BINARY_OP("rotr", 120, "i32", "i32", "i32");
3515
- eqz = BINARY_OP("eqz", 69, "i32", "i32", "i32");
3600
+ eqz = UNARY_OP("eqz", 69, "i32", "i32");
3516
3601
  eq = BINARY_OP("eq", 70, "i32", "i32", "i32");
3517
3602
  ne = BINARY_OP("ne", 71, "i32", "i32", "i32");
3518
3603
  trunc_f32_s = UNARY_OP("trunc_f32_s", 168, "f32", "i32");
@@ -3977,7 +4062,7 @@ function translateExp(cg, funcs, exp, ctx) {
3977
4062
  else throw new UnsupportedOpError(op, dtype, "wasm");
3978
4063
  else if (op === AluOp.Min || op === AluOp.Max) if (isFloatDtype(dtype)) if (op === AluOp.Min) dtyF(cg, op, dtype).min();
3979
4064
  else dtyF(cg, op, dtype).max();
3980
- else if (dtype === DType.Int32 || dtype === DType.Uint32) {
4065
+ else if (dtype === DType.Int32 || dtype === DType.Uint32 || dtype === DType.Bool) {
3981
4066
  const a = cg.local.declare(cg.i32);
3982
4067
  const b = cg.local.declare(cg.i32);
3983
4068
  cg.local.set(b);
@@ -4127,7 +4212,8 @@ function dtyF(cg, op, dtype) {
4127
4212
  const devices = [
4128
4213
  "cpu",
4129
4214
  "wasm",
4130
- "webgpu"
4215
+ "webgpu",
4216
+ "webgl"
4131
4217
  ];
4132
4218
  const initializedBackends = /* @__PURE__ */ new Map();
4133
4219
  initializedBackends.set("cpu", new CpuBackend());
@@ -4166,7 +4252,7 @@ async function createBackend(device) {
4166
4252
  if (!navigator.gpu) return null;
4167
4253
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
4168
4254
  if (!adapter) return null;
4169
- const { WebGPUBackend } = await import("./webgpu-ChVgx3b6.js");
4255
+ const { WebGPUBackend } = await import("./webgpu-C-VfevQW.js");
4170
4256
  const importantLimits = [
4171
4257
  "maxBufferSize",
4172
4258
  "maxComputeInvocationsPerWorkgroup",
@@ -4190,6 +4276,22 @@ async function createBackend(device) {
4190
4276
  console.error("Unexpected error requesting WebGPU device:", error);
4191
4277
  return null;
4192
4278
  }
4279
+ } else if (device === "webgl") {
4280
+ if (typeof WebGL2RenderingContext === "undefined") return null;
4281
+ const canvas = new OffscreenCanvas(0, 0);
4282
+ const gl = canvas.getContext("webgl2", {
4283
+ alpha: false,
4284
+ antialias: false,
4285
+ premultipliedAlpha: false,
4286
+ preserveDrawingBuffer: false,
4287
+ depth: false,
4288
+ stencil: false,
4289
+ failIfMajorPerformanceCaveat: true
4290
+ });
4291
+ if (!gl) return null;
4292
+ if (!gl.getExtension("EXT_color_buffer_float")) return null;
4293
+ const { WebGLBackend } = await import("./webgl-CLLvzJlO.js");
4294
+ return new WebGLBackend(gl);
4193
4295
  } else throw new Error(`Backend not found: ${device}`);
4194
4296
  }
4195
4297
  /** Retrieve a backend that has been initialized. */
@@ -4224,4 +4326,4 @@ var UnsupportedRoutineError = class extends Error {
4224
4326
  };
4225
4327
 
4226
4328
  //#endregion
4227
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip, zipn };
4329
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, SlotError, UnsupportedOpError, UnsupportedRoutineError, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, mapSetUnion, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneNullopt, tuneWebgpu, unravelAlu, unzip2, zip, zipn };