@jax-js/jax 0.1.12 → 0.1.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +11 -7
- package/dist/{backend-x-6vqzIM.cjs → backend-VlXzdQvR.cjs} +2111 -1557
- package/dist/{backend-DI-V78Rk.js → backend-apsUOPzb.js} +2111 -1557
- package/dist/index.cjs +10 -1
- package/dist/index.js +10 -1
- package/dist/{webgl-CD3WK_Me.cjs → webgl-C6rCbloA.cjs} +1 -1
- package/dist/{webgl-BhsnpeB0.js → webgl-Hh0FX6oV.js} +1 -1
- package/dist/{webgpu-C2kLdkUh.js → webgpu-BRv5r9Sl.js} +84 -31
- package/dist/{webgpu-C4S8Uq9e.cjs → webgpu-pWnE96Xc.cjs} +84 -31
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-VlXzdQvR.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -3224,6 +3224,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3224
3224
|
},
|
|
3225
3225
|
[Primitive.Conv]([x, y], params) {
|
|
3226
3226
|
checkConvShape(x.shape, y.shape, params);
|
|
3227
|
+
const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
|
|
3228
|
+
if (shouldMaterializePadding) {
|
|
3229
|
+
x = x.#reshape(x.#st.padOrShrink([...require_backend.rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
|
|
3230
|
+
x.#realize();
|
|
3231
|
+
params = {
|
|
3232
|
+
...params,
|
|
3233
|
+
padding: require_backend.rep(params.padding.length, [0, 0])
|
|
3234
|
+
};
|
|
3235
|
+
}
|
|
3227
3236
|
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
3228
3237
|
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
3229
3238
|
},
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-apsUOPzb.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -3189,6 +3189,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3189
3189
|
},
|
|
3190
3190
|
[Primitive.Conv]([x, y], params) {
|
|
3191
3191
|
checkConvShape(x.shape, y.shape, params);
|
|
3192
|
+
const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
|
|
3193
|
+
if (shouldMaterializePadding) {
|
|
3194
|
+
x = x.#reshape(x.#st.padOrShrink([...rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
|
|
3195
|
+
x.#realize();
|
|
3196
|
+
params = {
|
|
3197
|
+
...params,
|
|
3198
|
+
padding: rep(params.padding.length, [0, 0])
|
|
3199
|
+
};
|
|
3200
|
+
}
|
|
3192
3201
|
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
3193
3202
|
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
3194
3203
|
},
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-apsUOPzb.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-apsUOPzb.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
|
|
|
147
147
|
}
|
|
148
148
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
149
149
|
}
|
|
150
|
+
function reduceOpWgsl(op, dtype, a, b) {
|
|
151
|
+
if (op === AluOp.Add) return `(${a} + ${b})`;
|
|
152
|
+
if (op === AluOp.Mul) return `(${a} * ${b})`;
|
|
153
|
+
if (op === AluOp.Min) return dtype === DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
|
|
154
|
+
if (op === AluOp.Max) return dtype === DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
|
|
155
|
+
throw new Error(`Unsupported reduction op: ${op}`);
|
|
156
|
+
}
|
|
150
157
|
/** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
|
|
151
158
|
var WgslExpCodegen = class {
|
|
152
159
|
#gensymCount = 0;
|
|
@@ -1099,6 +1106,8 @@ function flushTracingBatch(device, batch) {
|
|
|
1099
1106
|
|
|
1100
1107
|
//#endregion
|
|
1101
1108
|
//#region src/backend/webgpu.ts
|
|
1109
|
+
const MAX_REUSABLE_BUFFER_BYTES = 64 * 1024 * 1024;
|
|
1110
|
+
const MAX_REUSABLE_BUFFERS_PER_SIZE = 64;
|
|
1102
1111
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
1103
1112
|
var WebGPUBackend = class {
|
|
1104
1113
|
type = "webgpu";
|
|
@@ -1109,6 +1118,7 @@ var WebGPUBackend = class {
|
|
|
1109
1118
|
nextSlot;
|
|
1110
1119
|
#cachedShaderMap = /* @__PURE__ */ new Map();
|
|
1111
1120
|
#reusableZsb;
|
|
1121
|
+
#bufferPool = /* @__PURE__ */ new Map();
|
|
1112
1122
|
constructor(device) {
|
|
1113
1123
|
this.device = device;
|
|
1114
1124
|
if (DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
|
|
@@ -1123,31 +1133,22 @@ var WebGPUBackend = class {
|
|
|
1123
1133
|
});
|
|
1124
1134
|
}
|
|
1125
1135
|
malloc(size, initialData) {
|
|
1126
|
-
|
|
1127
|
-
const
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
1138
|
-
else {
|
|
1139
|
-
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
1140
|
-
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
1141
|
-
const remainder = new Uint8Array(4);
|
|
1142
|
-
remainder.set(initialData.subarray(aligned));
|
|
1143
|
-
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
1144
|
-
}
|
|
1145
|
-
}
|
|
1146
|
-
} else buffer = this.#createBuffer(paddedSize);
|
|
1136
|
+
if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
1137
|
+
const allocatedSize = Math.ceil(size / 4) * 4 || 4;
|
|
1138
|
+
const buffer = size === 0 ? this.#reusableZsb : this.#acquireBuffer(allocatedSize);
|
|
1139
|
+
if (initialData && size > 0) if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
1140
|
+
else {
|
|
1141
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
1142
|
+
if (aligned > 0) this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
1143
|
+
const remainder = new Uint8Array(4);
|
|
1144
|
+
remainder.set(initialData.subarray(aligned));
|
|
1145
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
1146
|
+
}
|
|
1147
1147
|
const slot = this.nextSlot++;
|
|
1148
1148
|
this.buffers.set(slot, {
|
|
1149
1149
|
buffer,
|
|
1150
1150
|
size,
|
|
1151
|
+
allocatedSize,
|
|
1151
1152
|
ref: 1
|
|
1152
1153
|
});
|
|
1153
1154
|
return slot;
|
|
@@ -1163,7 +1164,7 @@ var WebGPUBackend = class {
|
|
|
1163
1164
|
buffer.ref--;
|
|
1164
1165
|
if (buffer.ref === 0) {
|
|
1165
1166
|
this.buffers.delete(slot);
|
|
1166
|
-
if (buffer.buffer !== this.#reusableZsb) buffer.buffer.
|
|
1167
|
+
if (buffer.buffer !== this.#reusableZsb) this.#releaseBuffer(buffer.buffer, buffer.allocatedSize);
|
|
1167
1168
|
}
|
|
1168
1169
|
}
|
|
1169
1170
|
async read(slot, start, count) {
|
|
@@ -1251,6 +1252,29 @@ var WebGPUBackend = class {
|
|
|
1251
1252
|
size: buffer.size
|
|
1252
1253
|
};
|
|
1253
1254
|
}
|
|
1255
|
+
#acquireBuffer(size) {
|
|
1256
|
+
if (size > MAX_REUSABLE_BUFFER_BYTES) return this.#createBuffer(size);
|
|
1257
|
+
const bucket = this.#bufferPool.get(size);
|
|
1258
|
+
const buffer = bucket?.pop();
|
|
1259
|
+
if (bucket && bucket.length === 0) this.#bufferPool.delete(size);
|
|
1260
|
+
return buffer ?? this.#createBuffer(size);
|
|
1261
|
+
}
|
|
1262
|
+
#releaseBuffer(buffer, size) {
|
|
1263
|
+
if (size > MAX_REUSABLE_BUFFER_BYTES) {
|
|
1264
|
+
buffer.destroy();
|
|
1265
|
+
return;
|
|
1266
|
+
}
|
|
1267
|
+
const bucket = this.#bufferPool.get(size);
|
|
1268
|
+
if (!bucket) {
|
|
1269
|
+
this.#bufferPool.set(size, [buffer]);
|
|
1270
|
+
return;
|
|
1271
|
+
}
|
|
1272
|
+
if (bucket.length >= MAX_REUSABLE_BUFFERS_PER_SIZE) {
|
|
1273
|
+
buffer.destroy();
|
|
1274
|
+
return;
|
|
1275
|
+
}
|
|
1276
|
+
bucket.push(buffer);
|
|
1277
|
+
}
|
|
1254
1278
|
/**
|
|
1255
1279
|
* Create a GPU buffer.
|
|
1256
1280
|
*
|
|
@@ -1299,14 +1323,30 @@ function pipelineSource(device, kernel) {
|
|
|
1299
1323
|
}
|
|
1300
1324
|
const resultTy = dtypeToWgsl(kernel.dtype, true);
|
|
1301
1325
|
wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1302
|
-
const
|
|
1303
|
-
const
|
|
1326
|
+
const groupCount = re ? tune.size.groups ?? 1 : 1;
|
|
1327
|
+
const groupedReduction = re && groupCount > 1;
|
|
1328
|
+
if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
|
|
1329
|
+
if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
|
|
1330
|
+
const workgroupSize = groupedReduction ? groupCount : findPow2(tune.threadCount, 256);
|
|
1331
|
+
const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
|
|
1304
1332
|
const [gridX, gridY] = calculateGrid(gridSize);
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1333
|
+
if (groupedReduction) {
|
|
1334
|
+
const partialTy = dtypeToWgsl(re.dtype);
|
|
1335
|
+
for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
|
|
1336
|
+
}
|
|
1337
|
+
wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
|
|
1338
|
+
if (groupedReduction) {
|
|
1339
|
+
wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
|
|
1340
|
+
if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
|
|
1341
|
+
else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
|
|
1342
|
+
wb.emit("let group: i32 = i32(lid.x);");
|
|
1343
|
+
} else {
|
|
1344
|
+
wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
|
|
1345
|
+
if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1346
|
+
else {
|
|
1347
|
+
const sizeX = gridX * workgroupSize;
|
|
1348
|
+
wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1349
|
+
}
|
|
1310
1350
|
}
|
|
1311
1351
|
wb.emitPhonyAssignments(args);
|
|
1312
1352
|
const gen = new WgslExpCodegen(wb, args);
|
|
@@ -1316,7 +1356,6 @@ function pipelineSource(device, kernel) {
|
|
|
1316
1356
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1317
1357
|
wb.emit(`result[gidx] = ${rhs};`);
|
|
1318
1358
|
} else {
|
|
1319
|
-
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
1320
1359
|
const unroll = tune.size.unroll ?? 1;
|
|
1321
1360
|
const upcast = tune.size.upcast ?? 1;
|
|
1322
1361
|
const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
|
|
@@ -1352,6 +1391,15 @@ function pipelineSource(device, kernel) {
|
|
|
1352
1391
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1353
1392
|
}
|
|
1354
1393
|
wb.emit(wb.popIndent, "}");
|
|
1394
|
+
if (groupedReduction) {
|
|
1395
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
|
|
1396
|
+
wb.emit("workgroupBarrier();");
|
|
1397
|
+
for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
|
|
1398
|
+
wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
|
|
1399
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
|
|
1400
|
+
wb.emit(wb.popIndent, "}", "workgroupBarrier();");
|
|
1401
|
+
}
|
|
1402
|
+
}
|
|
1355
1403
|
gen.reset();
|
|
1356
1404
|
const outputIdxExps = [];
|
|
1357
1405
|
const fusionExps = [];
|
|
@@ -1365,12 +1413,17 @@ function pipelineSource(device, kernel) {
|
|
|
1365
1413
|
}).simplify(cache));
|
|
1366
1414
|
gen.countReferences(fusionExps[i]);
|
|
1367
1415
|
}
|
|
1416
|
+
if (groupedReduction) {
|
|
1417
|
+
wb.emit("if (lid.x == 0u) {", wb.pushIndent);
|
|
1418
|
+
for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
|
|
1419
|
+
}
|
|
1368
1420
|
for (let i = 0; i < upcast; i++) {
|
|
1369
1421
|
const index = strip1(gen.run(outputIdxExps[i]));
|
|
1370
1422
|
let rhs = strip1(gen.run(fusionExps[i]));
|
|
1371
1423
|
if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1372
1424
|
wb.emit(`result[${index}] = ${rhs};`);
|
|
1373
1425
|
}
|
|
1426
|
+
if (groupedReduction) wb.emit(wb.popIndent, "}");
|
|
1374
1427
|
}
|
|
1375
1428
|
wb.emit(wb.popIndent, "}");
|
|
1376
1429
|
return {
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-VlXzdQvR.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
|
|
|
147
147
|
}
|
|
148
148
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
149
149
|
}
|
|
150
|
+
function reduceOpWgsl(op, dtype, a, b) {
|
|
151
|
+
if (op === require_backend.AluOp.Add) return `(${a} + ${b})`;
|
|
152
|
+
if (op === require_backend.AluOp.Mul) return `(${a} * ${b})`;
|
|
153
|
+
if (op === require_backend.AluOp.Min) return dtype === require_backend.DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
|
|
154
|
+
if (op === require_backend.AluOp.Max) return dtype === require_backend.DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
|
|
155
|
+
throw new Error(`Unsupported reduction op: ${op}`);
|
|
156
|
+
}
|
|
150
157
|
/** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
|
|
151
158
|
var WgslExpCodegen = class {
|
|
152
159
|
#gensymCount = 0;
|
|
@@ -1099,6 +1106,8 @@ function flushTracingBatch(device, batch) {
|
|
|
1099
1106
|
|
|
1100
1107
|
//#endregion
|
|
1101
1108
|
//#region src/backend/webgpu.ts
|
|
1109
|
+
const MAX_REUSABLE_BUFFER_BYTES = 64 * 1024 * 1024;
|
|
1110
|
+
const MAX_REUSABLE_BUFFERS_PER_SIZE = 64;
|
|
1102
1111
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
1103
1112
|
var WebGPUBackend = class {
|
|
1104
1113
|
type = "webgpu";
|
|
@@ -1109,6 +1118,7 @@ var WebGPUBackend = class {
|
|
|
1109
1118
|
nextSlot;
|
|
1110
1119
|
#cachedShaderMap = /* @__PURE__ */ new Map();
|
|
1111
1120
|
#reusableZsb;
|
|
1121
|
+
#bufferPool = /* @__PURE__ */ new Map();
|
|
1112
1122
|
constructor(device) {
|
|
1113
1123
|
this.device = device;
|
|
1114
1124
|
if (require_backend.DEBUG >= 3 && device.adapterInfo) console.info("webgpu adapter:", device.adapterInfo.vendor, device.adapterInfo.architecture);
|
|
@@ -1123,31 +1133,22 @@ var WebGPUBackend = class {
|
|
|
1123
1133
|
});
|
|
1124
1134
|
}
|
|
1125
1135
|
malloc(size, initialData) {
|
|
1126
|
-
|
|
1127
|
-
const
|
|
1128
|
-
|
|
1129
|
-
|
|
1130
|
-
|
|
1131
|
-
|
|
1132
|
-
|
|
1133
|
-
|
|
1134
|
-
|
|
1135
|
-
|
|
1136
|
-
|
|
1137
|
-
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
1138
|
-
else {
|
|
1139
|
-
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
1140
|
-
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
1141
|
-
const remainder = new Uint8Array(4);
|
|
1142
|
-
remainder.set(initialData.subarray(aligned));
|
|
1143
|
-
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
1144
|
-
}
|
|
1145
|
-
}
|
|
1146
|
-
} else buffer = this.#createBuffer(paddedSize);
|
|
1136
|
+
if (initialData && initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
1137
|
+
const allocatedSize = Math.ceil(size / 4) * 4 || 4;
|
|
1138
|
+
const buffer = size === 0 ? this.#reusableZsb : this.#acquireBuffer(allocatedSize);
|
|
1139
|
+
if (initialData && size > 0) if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
1140
|
+
else {
|
|
1141
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
1142
|
+
if (aligned > 0) this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
1143
|
+
const remainder = new Uint8Array(4);
|
|
1144
|
+
remainder.set(initialData.subarray(aligned));
|
|
1145
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
1146
|
+
}
|
|
1147
1147
|
const slot = this.nextSlot++;
|
|
1148
1148
|
this.buffers.set(slot, {
|
|
1149
1149
|
buffer,
|
|
1150
1150
|
size,
|
|
1151
|
+
allocatedSize,
|
|
1151
1152
|
ref: 1
|
|
1152
1153
|
});
|
|
1153
1154
|
return slot;
|
|
@@ -1163,7 +1164,7 @@ var WebGPUBackend = class {
|
|
|
1163
1164
|
buffer.ref--;
|
|
1164
1165
|
if (buffer.ref === 0) {
|
|
1165
1166
|
this.buffers.delete(slot);
|
|
1166
|
-
if (buffer.buffer !== this.#reusableZsb) buffer.buffer.
|
|
1167
|
+
if (buffer.buffer !== this.#reusableZsb) this.#releaseBuffer(buffer.buffer, buffer.allocatedSize);
|
|
1167
1168
|
}
|
|
1168
1169
|
}
|
|
1169
1170
|
async read(slot, start, count) {
|
|
@@ -1251,6 +1252,29 @@ var WebGPUBackend = class {
|
|
|
1251
1252
|
size: buffer.size
|
|
1252
1253
|
};
|
|
1253
1254
|
}
|
|
1255
|
+
#acquireBuffer(size) {
|
|
1256
|
+
if (size > MAX_REUSABLE_BUFFER_BYTES) return this.#createBuffer(size);
|
|
1257
|
+
const bucket = this.#bufferPool.get(size);
|
|
1258
|
+
const buffer = bucket?.pop();
|
|
1259
|
+
if (bucket && bucket.length === 0) this.#bufferPool.delete(size);
|
|
1260
|
+
return buffer ?? this.#createBuffer(size);
|
|
1261
|
+
}
|
|
1262
|
+
#releaseBuffer(buffer, size) {
|
|
1263
|
+
if (size > MAX_REUSABLE_BUFFER_BYTES) {
|
|
1264
|
+
buffer.destroy();
|
|
1265
|
+
return;
|
|
1266
|
+
}
|
|
1267
|
+
const bucket = this.#bufferPool.get(size);
|
|
1268
|
+
if (!bucket) {
|
|
1269
|
+
this.#bufferPool.set(size, [buffer]);
|
|
1270
|
+
return;
|
|
1271
|
+
}
|
|
1272
|
+
if (bucket.length >= MAX_REUSABLE_BUFFERS_PER_SIZE) {
|
|
1273
|
+
buffer.destroy();
|
|
1274
|
+
return;
|
|
1275
|
+
}
|
|
1276
|
+
bucket.push(buffer);
|
|
1277
|
+
}
|
|
1254
1278
|
/**
|
|
1255
1279
|
* Create a GPU buffer.
|
|
1256
1280
|
*
|
|
@@ -1299,14 +1323,30 @@ function pipelineSource(device, kernel) {
|
|
|
1299
1323
|
}
|
|
1300
1324
|
const resultTy = dtypeToWgsl(kernel.dtype, true);
|
|
1301
1325
|
wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1302
|
-
const
|
|
1303
|
-
const
|
|
1326
|
+
const groupCount = re ? tune.size.groups ?? 1 : 1;
|
|
1327
|
+
const groupedReduction = re && groupCount > 1;
|
|
1328
|
+
if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
|
|
1329
|
+
if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
|
|
1330
|
+
const workgroupSize = groupedReduction ? groupCount : require_backend.findPow2(tune.threadCount, 256);
|
|
1331
|
+
const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
|
|
1304
1332
|
const [gridX, gridY] = calculateGrid(gridSize);
|
|
1305
|
-
|
|
1306
|
-
|
|
1307
|
-
|
|
1308
|
-
|
|
1309
|
-
|
|
1333
|
+
if (groupedReduction) {
|
|
1334
|
+
const partialTy = dtypeToWgsl(re.dtype);
|
|
1335
|
+
for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
|
|
1336
|
+
}
|
|
1337
|
+
wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
|
|
1338
|
+
if (groupedReduction) {
|
|
1339
|
+
wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
|
|
1340
|
+
if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
|
|
1341
|
+
else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
|
|
1342
|
+
wb.emit("let group: i32 = i32(lid.x);");
|
|
1343
|
+
} else {
|
|
1344
|
+
wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
|
|
1345
|
+
if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1346
|
+
else {
|
|
1347
|
+
const sizeX = gridX * workgroupSize;
|
|
1348
|
+
wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1349
|
+
}
|
|
1310
1350
|
}
|
|
1311
1351
|
wb.emitPhonyAssignments(args);
|
|
1312
1352
|
const gen = new WgslExpCodegen(wb, args);
|
|
@@ -1316,7 +1356,6 @@ function pipelineSource(device, kernel) {
|
|
|
1316
1356
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1317
1357
|
wb.emit(`result[gidx] = ${rhs};`);
|
|
1318
1358
|
} else {
|
|
1319
|
-
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
1320
1359
|
const unroll = tune.size.unroll ?? 1;
|
|
1321
1360
|
const upcast = tune.size.upcast ?? 1;
|
|
1322
1361
|
const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
|
|
@@ -1352,6 +1391,15 @@ function pipelineSource(device, kernel) {
|
|
|
1352
1391
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1353
1392
|
}
|
|
1354
1393
|
wb.emit(wb.popIndent, "}");
|
|
1394
|
+
if (groupedReduction) {
|
|
1395
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
|
|
1396
|
+
wb.emit("workgroupBarrier();");
|
|
1397
|
+
for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
|
|
1398
|
+
wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
|
|
1399
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
|
|
1400
|
+
wb.emit(wb.popIndent, "}", "workgroupBarrier();");
|
|
1401
|
+
}
|
|
1402
|
+
}
|
|
1355
1403
|
gen.reset();
|
|
1356
1404
|
const outputIdxExps = [];
|
|
1357
1405
|
const fusionExps = [];
|
|
@@ -1365,12 +1413,17 @@ function pipelineSource(device, kernel) {
|
|
|
1365
1413
|
}).simplify(cache));
|
|
1366
1414
|
gen.countReferences(fusionExps[i]);
|
|
1367
1415
|
}
|
|
1416
|
+
if (groupedReduction) {
|
|
1417
|
+
wb.emit("if (lid.x == 0u) {", wb.pushIndent);
|
|
1418
|
+
for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
|
|
1419
|
+
}
|
|
1368
1420
|
for (let i = 0; i < upcast; i++) {
|
|
1369
1421
|
const index = require_backend.strip1(gen.run(outputIdxExps[i]));
|
|
1370
1422
|
let rhs = require_backend.strip1(gen.run(fusionExps[i]));
|
|
1371
1423
|
if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1372
1424
|
wb.emit(`result[${index}] = ${rhs};`);
|
|
1373
1425
|
}
|
|
1426
|
+
if (groupedReduction) wb.emit(wb.popIndent, "}");
|
|
1374
1427
|
}
|
|
1375
1428
|
wb.emit(wb.popIndent, "}");
|
|
1376
1429
|
return {
|