@jax-js/jax 0.1.13 → 0.1.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DMyuoWi2.cjs');
33
+ const require_backend = require('./backend-VlXzdQvR.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -3224,6 +3224,15 @@ var Array$1 = class Array$1 extends Tracer {
3224
3224
  },
3225
3225
  // Conv rule: validate the operand shapes, optionally fold the padding into x
  // ahead of time, then emit a "conv" custom op that multiplies the two reshaped
  // operands elementwise with reduceAxis enabled. Returns a one-element array.
  [Primitive.Conv]([x, y], params) {
3226
3226
  checkConvShape(x.shape, y.shape, params);
  // Pre-padding applies only when all three hold: the backend type of x is "wasm",
  // every lhsDilation entry equals 1, and at least one padding pair has a positive side.
3227
+ const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
3228
+ if (shouldMaterializePadding) {
  // Apply the padding to x via padOrShrink: the rep(...) entries of [0, 0] come
  // first for the leading vmapDims + 2 axes, followed by params.padding.
  // NOTE(review): presumably the +2 skips the batch and channel axes, and rep(n, v)
  // yields n copies of v; confirm against the backend module.
3229
+ x = x.#reshape(x.#st.padOrShrink([...require_backend.rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
  // NOTE(review): presumably forces the padded view into a concrete buffer before
  // the kernel is built; confirm what #realize() does.
3230
+ x.#realize();
  // Rebind params to a copy whose padding is all [0, 0], since the padding was
  // already applied to x above. The object passed in by the caller is not mutated.
3231
+ params = {
3232
+ ...params,
3233
+ padding: require_backend.rep(params.padding.length, [0, 0])
3234
+ };
3235
+ }
  // prepareConv yields the pair (stX, stY) that x and y are reshaped with below.
3227
3236
  const [stX, stY] = prepareConv(x.#st, y.#st, params);
3228
3237
  return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
3229
3238
  },
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DLEk-B3V.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-apsUOPzb.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -3189,6 +3189,15 @@ var Array$1 = class Array$1 extends Tracer {
3189
3189
  },
3190
3190
  // Conv rule: validate the operand shapes, optionally fold the padding into x
  // ahead of time, then emit a "conv" custom op that multiplies the two reshaped
  // operands elementwise with reduceAxis enabled. Returns a one-element array.
  [Primitive.Conv]([x, y], params) {
3191
3191
  checkConvShape(x.shape, y.shape, params);
  // Pre-padding applies only when all three hold: the backend type of x is "wasm",
  // every lhsDilation entry equals 1, and at least one padding pair has a positive side.
3192
+ const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
3193
+ if (shouldMaterializePadding) {
  // Apply the padding to x via padOrShrink: the rep(...) entries of [0, 0] come
  // first for the leading vmapDims + 2 axes, followed by params.padding.
  // NOTE(review): presumably the +2 skips the batch and channel axes, and rep(n, v)
  // yields n copies of v; confirm against the backend module.
3194
+ x = x.#reshape(x.#st.padOrShrink([...rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
  // NOTE(review): presumably forces the padded view into a concrete buffer before
  // the kernel is built; confirm what #realize() does.
3195
+ x.#realize();
  // Rebind params to a copy whose padding is all [0, 0], since the padding was
  // already applied to x above. The object passed in by the caller is not mutated.
3196
+ params = {
3197
+ ...params,
3198
+ padding: rep(params.padding.length, [0, 0])
3199
+ };
3200
+ }
  // prepareConv yields the pair (stX, stY) that x and y are reshaped with below.
3192
3201
  const [stX, stY] = prepareConv(x.#st, y.#st, params);
3193
3202
  return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
3194
3203
  },
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DMyuoWi2.cjs');
1
+ const require_backend = require('./backend-VlXzdQvR.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DLEk-B3V.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-apsUOPzb.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DLEk-B3V.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-apsUOPzb.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
147
147
  }
148
148
  throw new Error(`Unsupported const dtype: ${dtype}`);
149
149
  }
150
/**
 * Build the WGSL expression text that combines two operand expressions with a
 * reduction operator.
 *
 * Mapping: Add gives (a + b); Mul gives (a * b); Min gives min(a, b) and Max
 * gives max(a, b), except that for DType.Bool they become (a && b) and (a || b).
 *
 * @param op Reduction operator, compared against AluOp.Add, Mul, Min and Max.
 * @param dtype Element dtype; only consulted for the Min and Max cases.
 * @param a WGSL expression text, interpolated verbatim as the left operand.
 * @param b WGSL expression text, interpolated verbatim as the right operand.
 * @returns The combined WGSL expression as a string.
 * @throws Error with message "Unsupported reduction op: ..." for any other op.
 *
 * NOTE(review): Add and Mul ignore dtype and always emit arithmetic operators;
 * confirm that Bool reductions never reach this helper with those two ops.
 */
+ function reduceOpWgsl(op, dtype, a, b) {
151
+ if (op === AluOp.Add) return `(${a} + ${b})`;
152
+ if (op === AluOp.Mul) return `(${a} * ${b})`;
// For Bool, Min is emitted as logical AND and Max as logical OR.
153
+ if (op === AluOp.Min) return dtype === DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
154
+ if (op === AluOp.Max) return dtype === DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
155
+ throw new Error(`Unsupported reduction op: ${op}`);
156
+ }
150
157
  /** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
151
158
  var WgslExpCodegen = class {
152
159
  #gensymCount = 0;
@@ -1316,14 +1323,30 @@ function pipelineSource(device, kernel) {
1316
1323
  }
1317
1324
  const resultTy = dtypeToWgsl(kernel.dtype, true);
1318
1325
  wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
1319
- const workgroupSize = findPow2(tune.threadCount, 256);
1320
- const gridSize = Math.ceil(tune.threadCount / workgroupSize);
1326
+ const groupCount = re ? tune.size.groups ?? 1 : 1;
1327
+ const groupedReduction = re && groupCount > 1;
1328
+ if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
1329
+ if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
1330
+ const workgroupSize = groupedReduction ? groupCount : findPow2(tune.threadCount, 256);
1331
+ const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
1321
1332
  const [gridX, gridY] = calculateGrid(gridSize);
1322
- wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1323
- if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1324
- else {
1325
- const sizeX = gridX * workgroupSize;
1326
- wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1333
+ if (groupedReduction) {
1334
+ const partialTy = dtypeToWgsl(re.dtype);
1335
+ for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
1336
+ }
1337
+ wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
1338
+ if (groupedReduction) {
1339
+ wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
1340
+ if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
1341
+ else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
1342
+ wb.emit("let group: i32 = i32(lid.x);");
1343
+ } else {
1344
+ wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1345
+ if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1346
+ else {
1347
+ const sizeX = gridX * workgroupSize;
1348
+ wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1349
+ }
1327
1350
  }
1328
1351
  wb.emitPhonyAssignments(args);
1329
1352
  const gen = new WgslExpCodegen(wb, args);
@@ -1333,7 +1356,6 @@ function pipelineSource(device, kernel) {
1333
1356
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
1334
1357
  wb.emit(`result[gidx] = ${rhs};`);
1335
1358
  } else {
1336
- if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
1337
1359
  const unroll = tune.size.unroll ?? 1;
1338
1360
  const upcast = tune.size.upcast ?? 1;
1339
1361
  const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
@@ -1369,6 +1391,15 @@ function pipelineSource(device, kernel) {
1369
1391
  else throw new Error(`Unsupported reduction op: ${re.op}`);
1370
1392
  }
1371
1393
  wb.emit(wb.popIndent, "}");
1394
+ if (groupedReduction) {
1395
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
1396
+ wb.emit("workgroupBarrier();");
1397
+ for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
1398
+ wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
1399
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
1400
+ wb.emit(wb.popIndent, "}", "workgroupBarrier();");
1401
+ }
1402
+ }
1372
1403
  gen.reset();
1373
1404
  const outputIdxExps = [];
1374
1405
  const fusionExps = [];
@@ -1382,12 +1413,17 @@ function pipelineSource(device, kernel) {
1382
1413
  }).simplify(cache));
1383
1414
  gen.countReferences(fusionExps[i]);
1384
1415
  }
1416
+ if (groupedReduction) {
1417
+ wb.emit("if (lid.x == 0u) {", wb.pushIndent);
1418
+ for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
1419
+ }
1385
1420
  for (let i = 0; i < upcast; i++) {
1386
1421
  const index = strip1(gen.run(outputIdxExps[i]));
1387
1422
  let rhs = strip1(gen.run(fusionExps[i]));
1388
1423
  if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
1389
1424
  wb.emit(`result[${index}] = ${rhs};`);
1390
1425
  }
1426
+ if (groupedReduction) wb.emit(wb.popIndent, "}");
1391
1427
  }
1392
1428
  wb.emit(wb.popIndent, "}");
1393
1429
  return {
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DMyuoWi2.cjs');
1
+ const require_backend = require('./backend-VlXzdQvR.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
147
147
  }
148
148
  throw new Error(`Unsupported const dtype: ${dtype}`);
149
149
  }
150
/**
 * Build the WGSL expression text that combines two operand expressions with a
 * reduction operator.
 *
 * Mapping: Add gives (a + b); Mul gives (a * b); Min gives min(a, b) and Max
 * gives max(a, b), except that for DType.Bool they become (a && b) and (a || b).
 *
 * @param op Reduction operator, compared against the AluOp members Add, Mul, Min and Max.
 * @param dtype Element dtype; only consulted for the Min and Max cases.
 * @param a WGSL expression text, interpolated verbatim as the left operand.
 * @param b WGSL expression text, interpolated verbatim as the right operand.
 * @returns The combined WGSL expression as a string.
 * @throws Error with message "Unsupported reduction op: ..." for any other op.
 *
 * NOTE(review): Add and Mul ignore dtype and always emit arithmetic operators;
 * confirm that Bool reductions never reach this helper with those two ops.
 */
+ function reduceOpWgsl(op, dtype, a, b) {
151
+ if (op === require_backend.AluOp.Add) return `(${a} + ${b})`;
152
+ if (op === require_backend.AluOp.Mul) return `(${a} * ${b})`;
// For Bool, Min is emitted as logical AND and Max as logical OR.
153
+ if (op === require_backend.AluOp.Min) return dtype === require_backend.DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
154
+ if (op === require_backend.AluOp.Max) return dtype === require_backend.DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
155
+ throw new Error(`Unsupported reduction op: ${op}`);
156
+ }
150
157
  /** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
151
158
  var WgslExpCodegen = class {
152
159
  #gensymCount = 0;
@@ -1316,14 +1323,30 @@ function pipelineSource(device, kernel) {
1316
1323
  }
1317
1324
  const resultTy = dtypeToWgsl(kernel.dtype, true);
1318
1325
  wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
1319
- const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
1320
- const gridSize = Math.ceil(tune.threadCount / workgroupSize);
1326
+ const groupCount = re ? tune.size.groups ?? 1 : 1;
1327
+ const groupedReduction = re && groupCount > 1;
1328
+ if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
1329
+ if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
1330
+ const workgroupSize = groupedReduction ? groupCount : require_backend.findPow2(tune.threadCount, 256);
1331
+ const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
1321
1332
  const [gridX, gridY] = calculateGrid(gridSize);
1322
- wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1323
- if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1324
- else {
1325
- const sizeX = gridX * workgroupSize;
1326
- wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1333
+ if (groupedReduction) {
1334
+ const partialTy = dtypeToWgsl(re.dtype);
1335
+ for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
1336
+ }
1337
+ wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
1338
+ if (groupedReduction) {
1339
+ wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
1340
+ if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
1341
+ else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
1342
+ wb.emit("let group: i32 = i32(lid.x);");
1343
+ } else {
1344
+ wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1345
+ if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1346
+ else {
1347
+ const sizeX = gridX * workgroupSize;
1348
+ wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1349
+ }
1327
1350
  }
1328
1351
  wb.emitPhonyAssignments(args);
1329
1352
  const gen = new WgslExpCodegen(wb, args);
@@ -1333,7 +1356,6 @@ function pipelineSource(device, kernel) {
1333
1356
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
1334
1357
  wb.emit(`result[gidx] = ${rhs};`);
1335
1358
  } else {
1336
- if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
1337
1359
  const unroll = tune.size.unroll ?? 1;
1338
1360
  const upcast = tune.size.upcast ?? 1;
1339
1361
  const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
@@ -1369,6 +1391,15 @@ function pipelineSource(device, kernel) {
1369
1391
  else throw new Error(`Unsupported reduction op: ${re.op}`);
1370
1392
  }
1371
1393
  wb.emit(wb.popIndent, "}");
1394
+ if (groupedReduction) {
1395
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
1396
+ wb.emit("workgroupBarrier();");
1397
+ for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
1398
+ wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
1399
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
1400
+ wb.emit(wb.popIndent, "}", "workgroupBarrier();");
1401
+ }
1402
+ }
1372
1403
  gen.reset();
1373
1404
  const outputIdxExps = [];
1374
1405
  const fusionExps = [];
@@ -1382,12 +1413,17 @@ function pipelineSource(device, kernel) {
1382
1413
  }).simplify(cache));
1383
1414
  gen.countReferences(fusionExps[i]);
1384
1415
  }
1416
+ if (groupedReduction) {
1417
+ wb.emit("if (lid.x == 0u) {", wb.pushIndent);
1418
+ for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
1419
+ }
1385
1420
  for (let i = 0; i < upcast; i++) {
1386
1421
  const index = require_backend.strip1(gen.run(outputIdxExps[i]));
1387
1422
  let rhs = require_backend.strip1(gen.run(fusionExps[i]));
1388
1423
  if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
1389
1424
  wb.emit(`result[${index}] = ${rhs};`);
1390
1425
  }
1426
+ if (groupedReduction) wb.emit(wb.popIndent, "}");
1391
1427
  }
1392
1428
  wb.emit(wb.popIndent, "}");
1393
1429
  return {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.13",
3
+ "version": "0.1.14",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",