@jax-js/jax 0.1.12 → 0.1.14

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-x-6vqzIM.cjs');
33
+ const require_backend = require('./backend-VlXzdQvR.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -3224,6 +3224,15 @@ var Array$1 = class Array$1 extends Tracer {
3224
3224
  },
3225
3225
  [Primitive.Conv]([x, y], params) {
3226
3226
  checkConvShape(x.shape, y.shape, params);
3227
+ const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
3228
+ if (shouldMaterializePadding) {
3229
+ x = x.#reshape(x.#st.padOrShrink([...require_backend.rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
3230
+ x.#realize();
3231
+ params = {
3232
+ ...params,
3233
+ padding: require_backend.rep(params.padding.length, [0, 0])
3234
+ };
3235
+ }
3227
3236
  const [stX, stY] = prepareConv(x.#st, y.#st, params);
3228
3237
  return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
3229
3238
  },
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DI-V78Rk.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-apsUOPzb.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -3189,6 +3189,15 @@ var Array$1 = class Array$1 extends Tracer {
3189
3189
  },
3190
3190
  [Primitive.Conv]([x, y], params) {
3191
3191
  checkConvShape(x.shape, y.shape, params);
3192
+ const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
3193
+ if (shouldMaterializePadding) {
3194
+ x = x.#reshape(x.#st.padOrShrink([...rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
3195
+ x.#realize();
3196
+ params = {
3197
+ ...params,
3198
+ padding: rep(params.padding.length, [0, 0])
3199
+ };
3200
+ }
3192
3201
  const [stX, stY] = prepareConv(x.#st, y.#st, params);
3193
3202
  return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
3194
3203
  },
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-x-6vqzIM.cjs');
1
+ const require_backend = require('./backend-VlXzdQvR.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DI-V78Rk.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-apsUOPzb.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DI-V78Rk.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-apsUOPzb.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
147
147
  }
148
148
  throw new Error(`Unsupported const dtype: ${dtype}`);
149
149
  }
150
+ function reduceOpWgsl(op, dtype, a, b) {
151
+ if (op === AluOp.Add) return `(${a} + ${b})`;
152
+ if (op === AluOp.Mul) return `(${a} * ${b})`;
153
+ if (op === AluOp.Min) return dtype === DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
154
+ if (op === AluOp.Max) return dtype === DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
155
+ throw new Error(`Unsupported reduction op: ${op}`);
156
+ }
150
157
  /** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
151
158
  var WgslExpCodegen = class {
152
159
  #gensymCount = 0;
@@ -1099,6 +1106,8 @@ function flushTracingBatch(device, batch) {
1099
1106
 
1100
1107
  //#endregion
1101
1108
  //#region src/backend/webgpu.ts
1109
+ const MAX_REUSABLE_BUFFER_BYTES = 64 * 1024 * 1024;
1110
+ const MAX_REUSABLE_BUFFERS_PER_SIZE = 64;
1102
1111
  /** Implementation of `Backend` that uses WebGPU in browsers. */
1103
1112
  var WebGPUBackend = class {
1104
1113
  type = "webgpu";
@@ -1109,6 +1118,7 @@ var WebGPUBackend = class {
1109
1118
  nextSlot;
1110
1119
  #cachedShaderMap = /* @__PURE__ */ new Map();
1111
1120
  #reusableZsb;
1121
+ #bufferPool = /* @__PURE__ */ new Map();
1112
1122
  constructor(device) {
1113
1123
  this.device = device;
1114
1124
  if (DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
@@ -1123,31 +1133,22 @@ var WebGPUBackend = class {
1123
1133
  });
1124
1134
  }
1125
1135
  malloc(size, initialData) {
1126
- let buffer;
1127
- const paddedSize = Math.ceil(size / 4) * 4;
1128
- if (size === 0) buffer = this.#reusableZsb;
1129
- else if (initialData) {
1130
- if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
1131
- if (initialData.byteLength < 4096) {
1132
- buffer = this.#createBuffer(paddedSize, { mapped: true });
1133
- new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
1134
- buffer.unmap();
1135
- } else {
1136
- buffer = this.#createBuffer(paddedSize);
1137
- if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
1138
- else {
1139
- const aligned = initialData.byteLength - initialData.byteLength % 4;
1140
- this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
1141
- const remainder = new Uint8Array(4);
1142
- remainder.set(initialData.subarray(aligned));
1143
- this.device.queue.writeBuffer(buffer, aligned, remainder);
1144
- }
1145
- }
1146
- } else buffer = this.#createBuffer(paddedSize);
1136
+ if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
1137
+ const allocatedSize = Math.ceil(size / 4) * 4 || 4;
1138
+ const buffer = size === 0 ? this.#reusableZsb : this.#acquireBuffer(allocatedSize);
1139
+ if (initialData && size > 0) if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
1140
+ else {
1141
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
1142
+ if (aligned > 0) this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
1143
+ const remainder = new Uint8Array(4);
1144
+ remainder.set(initialData.subarray(aligned));
1145
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
1146
+ }
1147
1147
  const slot = this.nextSlot++;
1148
1148
  this.buffers.set(slot, {
1149
1149
  buffer,
1150
1150
  size,
1151
+ allocatedSize,
1151
1152
  ref: 1
1152
1153
  });
1153
1154
  return slot;
@@ -1163,7 +1164,7 @@ var WebGPUBackend = class {
1163
1164
  buffer.ref--;
1164
1165
  if (buffer.ref === 0) {
1165
1166
  this.buffers.delete(slot);
1166
- if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
1167
+ if (buffer.buffer !== this.#reusableZsb) this.#releaseBuffer(buffer.buffer, buffer.allocatedSize);
1167
1168
  }
1168
1169
  }
1169
1170
  async read(slot, start, count) {
@@ -1251,6 +1252,29 @@ var WebGPUBackend = class {
1251
1252
  size: buffer.size
1252
1253
  };
1253
1254
  }
1255
+ #acquireBuffer(size) {
1256
+ if (size > MAX_REUSABLE_BUFFER_BYTES) return this.#createBuffer(size);
1257
+ const bucket = this.#bufferPool.get(size);
1258
+ const buffer = bucket?.pop();
1259
+ if (bucket && bucket.length === 0) this.#bufferPool.delete(size);
1260
+ return buffer ?? this.#createBuffer(size);
1261
+ }
1262
+ #releaseBuffer(buffer, size) {
1263
+ if (size > MAX_REUSABLE_BUFFER_BYTES) {
1264
+ buffer.destroy();
1265
+ return;
1266
+ }
1267
+ const bucket = this.#bufferPool.get(size);
1268
+ if (!bucket) {
1269
+ this.#bufferPool.set(size, [buffer]);
1270
+ return;
1271
+ }
1272
+ if (bucket.length >= MAX_REUSABLE_BUFFERS_PER_SIZE) {
1273
+ buffer.destroy();
1274
+ return;
1275
+ }
1276
+ bucket.push(buffer);
1277
+ }
1254
1278
  /**
1255
1279
  * Create a GPU buffer.
1256
1280
  *
@@ -1299,14 +1323,30 @@ function pipelineSource(device, kernel) {
1299
1323
  }
1300
1324
  const resultTy = dtypeToWgsl(kernel.dtype, true);
1301
1325
  wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
1302
- const workgroupSize = findPow2(tune.threadCount, 256);
1303
- const gridSize = Math.ceil(tune.threadCount / workgroupSize);
1326
+ const groupCount = re ? tune.size.groups ?? 1 : 1;
1327
+ const groupedReduction = re && groupCount > 1;
1328
+ if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
1329
+ if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
1330
+ const workgroupSize = groupedReduction ? groupCount : findPow2(tune.threadCount, 256);
1331
+ const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
1304
1332
  const [gridX, gridY] = calculateGrid(gridSize);
1305
- wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1306
- if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1307
- else {
1308
- const sizeX = gridX * workgroupSize;
1309
- wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1333
+ if (groupedReduction) {
1334
+ const partialTy = dtypeToWgsl(re.dtype);
1335
+ for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
1336
+ }
1337
+ wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
1338
+ if (groupedReduction) {
1339
+ wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
1340
+ if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
1341
+ else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
1342
+ wb.emit("let group: i32 = i32(lid.x);");
1343
+ } else {
1344
+ wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1345
+ if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1346
+ else {
1347
+ const sizeX = gridX * workgroupSize;
1348
+ wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1349
+ }
1310
1350
  }
1311
1351
  wb.emitPhonyAssignments(args);
1312
1352
  const gen = new WgslExpCodegen(wb, args);
@@ -1316,7 +1356,6 @@ function pipelineSource(device, kernel) {
1316
1356
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
1317
1357
  wb.emit(`result[gidx] = ${rhs};`);
1318
1358
  } else {
1319
- if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
1320
1359
  const unroll = tune.size.unroll ?? 1;
1321
1360
  const upcast = tune.size.upcast ?? 1;
1322
1361
  const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
@@ -1352,6 +1391,15 @@ function pipelineSource(device, kernel) {
1352
1391
  else throw new Error(`Unsupported reduction op: ${re.op}`);
1353
1392
  }
1354
1393
  wb.emit(wb.popIndent, "}");
1394
+ if (groupedReduction) {
1395
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
1396
+ wb.emit("workgroupBarrier();");
1397
+ for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
1398
+ wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
1399
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
1400
+ wb.emit(wb.popIndent, "}", "workgroupBarrier();");
1401
+ }
1402
+ }
1355
1403
  gen.reset();
1356
1404
  const outputIdxExps = [];
1357
1405
  const fusionExps = [];
@@ -1365,12 +1413,17 @@ function pipelineSource(device, kernel) {
1365
1413
  }).simplify(cache));
1366
1414
  gen.countReferences(fusionExps[i]);
1367
1415
  }
1416
+ if (groupedReduction) {
1417
+ wb.emit("if (lid.x == 0u) {", wb.pushIndent);
1418
+ for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
1419
+ }
1368
1420
  for (let i = 0; i < upcast; i++) {
1369
1421
  const index = strip1(gen.run(outputIdxExps[i]));
1370
1422
  let rhs = strip1(gen.run(fusionExps[i]));
1371
1423
  if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
1372
1424
  wb.emit(`result[${index}] = ${rhs};`);
1373
1425
  }
1426
+ if (groupedReduction) wb.emit(wb.popIndent, "}");
1374
1427
  }
1375
1428
  wb.emit(wb.popIndent, "}");
1376
1429
  return {
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-x-6vqzIM.cjs');
1
+ const require_backend = require('./backend-VlXzdQvR.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
147
147
  }
148
148
  throw new Error(`Unsupported const dtype: ${dtype}`);
149
149
  }
150
+ function reduceOpWgsl(op, dtype, a, b) {
151
+ if (op === require_backend.AluOp.Add) return `(${a} + ${b})`;
152
+ if (op === require_backend.AluOp.Mul) return `(${a} * ${b})`;
153
+ if (op === require_backend.AluOp.Min) return dtype === require_backend.DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
154
+ if (op === require_backend.AluOp.Max) return dtype === require_backend.DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
155
+ throw new Error(`Unsupported reduction op: ${op}`);
156
+ }
150
157
  /** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
151
158
  var WgslExpCodegen = class {
152
159
  #gensymCount = 0;
@@ -1099,6 +1106,8 @@ function flushTracingBatch(device, batch) {
1099
1106
 
1100
1107
  //#endregion
1101
1108
  //#region src/backend/webgpu.ts
1109
+ const MAX_REUSABLE_BUFFER_BYTES = 64 * 1024 * 1024;
1110
+ const MAX_REUSABLE_BUFFERS_PER_SIZE = 64;
1102
1111
  /** Implementation of `Backend` that uses WebGPU in browsers. */
1103
1112
  var WebGPUBackend = class {
1104
1113
  type = "webgpu";
@@ -1109,6 +1118,7 @@ var WebGPUBackend = class {
1109
1118
  nextSlot;
1110
1119
  #cachedShaderMap = /* @__PURE__ */ new Map();
1111
1120
  #reusableZsb;
1121
+ #bufferPool = /* @__PURE__ */ new Map();
1112
1122
  constructor(device) {
1113
1123
  this.device = device;
1114
1124
  if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
@@ -1123,31 +1133,22 @@ var WebGPUBackend = class {
1123
1133
  });
1124
1134
  }
1125
1135
  malloc(size, initialData) {
1126
- let buffer;
1127
- const paddedSize = Math.ceil(size / 4) * 4;
1128
- if (size === 0) buffer = this.#reusableZsb;
1129
- else if (initialData) {
1130
- if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
1131
- if (initialData.byteLength < 4096) {
1132
- buffer = this.#createBuffer(paddedSize, { mapped: true });
1133
- new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
1134
- buffer.unmap();
1135
- } else {
1136
- buffer = this.#createBuffer(paddedSize);
1137
- if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
1138
- else {
1139
- const aligned = initialData.byteLength - initialData.byteLength % 4;
1140
- this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
1141
- const remainder = new Uint8Array(4);
1142
- remainder.set(initialData.subarray(aligned));
1143
- this.device.queue.writeBuffer(buffer, aligned, remainder);
1144
- }
1145
- }
1146
- } else buffer = this.#createBuffer(paddedSize);
1136
+ if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
1137
+ const allocatedSize = Math.ceil(size / 4) * 4 || 4;
1138
+ const buffer = size === 0 ? this.#reusableZsb : this.#acquireBuffer(allocatedSize);
1139
+ if (initialData && size > 0) if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
1140
+ else {
1141
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
1142
+ if (aligned > 0) this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
1143
+ const remainder = new Uint8Array(4);
1144
+ remainder.set(initialData.subarray(aligned));
1145
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
1146
+ }
1147
1147
  const slot = this.nextSlot++;
1148
1148
  this.buffers.set(slot, {
1149
1149
  buffer,
1150
1150
  size,
1151
+ allocatedSize,
1151
1152
  ref: 1
1152
1153
  });
1153
1154
  return slot;
@@ -1163,7 +1164,7 @@ var WebGPUBackend = class {
1163
1164
  buffer.ref--;
1164
1165
  if (buffer.ref === 0) {
1165
1166
  this.buffers.delete(slot);
1166
- if (buffer.buffer !== this.#reusableZsb) buffer.buffer.destroy();
1167
+ if (buffer.buffer !== this.#reusableZsb) this.#releaseBuffer(buffer.buffer, buffer.allocatedSize);
1167
1168
  }
1168
1169
  }
1169
1170
  async read(slot, start, count) {
@@ -1251,6 +1252,29 @@ var WebGPUBackend = class {
1251
1252
  size: buffer.size
1252
1253
  };
1253
1254
  }
1255
+ #acquireBuffer(size) {
1256
+ if (size > MAX_REUSABLE_BUFFER_BYTES) return this.#createBuffer(size);
1257
+ const bucket = this.#bufferPool.get(size);
1258
+ const buffer = bucket?.pop();
1259
+ if (bucket && bucket.length === 0) this.#bufferPool.delete(size);
1260
+ return buffer ?? this.#createBuffer(size);
1261
+ }
1262
+ #releaseBuffer(buffer, size) {
1263
+ if (size > MAX_REUSABLE_BUFFER_BYTES) {
1264
+ buffer.destroy();
1265
+ return;
1266
+ }
1267
+ const bucket = this.#bufferPool.get(size);
1268
+ if (!bucket) {
1269
+ this.#bufferPool.set(size, [buffer]);
1270
+ return;
1271
+ }
1272
+ if (bucket.length >= MAX_REUSABLE_BUFFERS_PER_SIZE) {
1273
+ buffer.destroy();
1274
+ return;
1275
+ }
1276
+ bucket.push(buffer);
1277
+ }
1254
1278
  /**
1255
1279
  * Create a GPU buffer.
1256
1280
  *
@@ -1299,14 +1323,30 @@ function pipelineSource(device, kernel) {
1299
1323
  }
1300
1324
  const resultTy = dtypeToWgsl(kernel.dtype, true);
1301
1325
  wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
1302
- const workgroupSize = require_backend.findPow2(tune.threadCount, 256);
1303
- const gridSize = Math.ceil(tune.threadCount / workgroupSize);
1326
+ const groupCount = re ? tune.size.groups ?? 1 : 1;
1327
+ const groupedReduction = re && groupCount > 1;
1328
+ if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
1329
+ if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
1330
+ const workgroupSize = groupedReduction ? groupCount : require_backend.findPow2(tune.threadCount, 256);
1331
+ const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
1304
1332
  const [gridX, gridY] = calculateGrid(gridSize);
1305
- wb.emit("", `@compute @workgroup_size(${workgroupSize})`, "fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1306
- if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1307
- else {
1308
- const sizeX = gridX * workgroupSize;
1309
- wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1333
+ if (groupedReduction) {
1334
+ const partialTy = dtypeToWgsl(re.dtype);
1335
+ for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
1336
+ }
1337
+ wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
1338
+ if (groupedReduction) {
1339
+ wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
1340
+ if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
1341
+ else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
1342
+ wb.emit("let group: i32 = i32(lid.x);");
1343
+ } else {
1344
+ wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
1345
+ if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
1346
+ else {
1347
+ const sizeX = gridX * workgroupSize;
1348
+ wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
1349
+ }
1310
1350
  }
1311
1351
  wb.emitPhonyAssignments(args);
1312
1352
  const gen = new WgslExpCodegen(wb, args);
@@ -1316,7 +1356,6 @@ function pipelineSource(device, kernel) {
1316
1356
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
1317
1357
  wb.emit(`result[gidx] = ${rhs};`);
1318
1358
  } else {
1319
- if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
1320
1359
  const unroll = tune.size.unroll ?? 1;
1321
1360
  const upcast = tune.size.upcast ?? 1;
1322
1361
  const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
@@ -1352,6 +1391,15 @@ function pipelineSource(device, kernel) {
1352
1391
  else throw new Error(`Unsupported reduction op: ${re.op}`);
1353
1392
  }
1354
1393
  wb.emit(wb.popIndent, "}");
1394
+ if (groupedReduction) {
1395
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
1396
+ wb.emit("workgroupBarrier();");
1397
+ for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
1398
+ wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
1399
+ for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
1400
+ wb.emit(wb.popIndent, "}", "workgroupBarrier();");
1401
+ }
1402
+ }
1355
1403
  gen.reset();
1356
1404
  const outputIdxExps = [];
1357
1405
  const fusionExps = [];
@@ -1365,12 +1413,17 @@ function pipelineSource(device, kernel) {
1365
1413
  }).simplify(cache));
1366
1414
  gen.countReferences(fusionExps[i]);
1367
1415
  }
1416
+ if (groupedReduction) {
1417
+ wb.emit("if (lid.x == 0u) {", wb.pushIndent);
1418
+ for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
1419
+ }
1368
1420
  for (let i = 0; i < upcast; i++) {
1369
1421
  const index = require_backend.strip1(gen.run(outputIdxExps[i]));
1370
1422
  let rhs = require_backend.strip1(gen.run(fusionExps[i]));
1371
1423
  if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
1372
1424
  wb.emit(`result[${index}] = ${rhs};`);
1373
1425
  }
1426
+ if (groupedReduction) wb.emit(wb.popIndent, "}");
1374
1427
  }
1375
1428
  wb.emit(wb.popIndent, "}");
1376
1429
  return {
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.12",
3
+ "version": "0.1.14",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",