@jax-js/jax 0.1.13 → 0.1.14
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +10 -7
- package/dist/{backend-DMyuoWi2.cjs → backend-VlXzdQvR.cjs} +2111 -1557
- package/dist/{backend-DLEk-B3V.js → backend-apsUOPzb.js} +2111 -1557
- package/dist/index.cjs +10 -1
- package/dist/index.js +10 -1
- package/dist/{webgl-pbfUGDA6.cjs → webgl-C6rCbloA.cjs} +1 -1
- package/dist/{webgl-NsFtyIts.js → webgl-Hh0FX6oV.js} +1 -1
- package/dist/{webgpu-NkF1TZ0t.js → webgpu-BRv5r9Sl.js} +45 -9
- package/dist/{webgpu-DDGCYtHa.cjs → webgpu-pWnE96Xc.cjs} +45 -9
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-DMyuoWi2.cjs');
|
|
33
|
+
const require_backend = require('./backend-VlXzdQvR.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -3224,6 +3224,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3224
3224
|
},
|
|
3225
3225
|
[Primitive.Conv]([x, y], params) {
|
|
3226
3226
|
checkConvShape(x.shape, y.shape, params);
|
|
3227
|
+
const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
|
|
3228
|
+
if (shouldMaterializePadding) {
|
|
3229
|
+
x = x.#reshape(x.#st.padOrShrink([...require_backend.rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
|
|
3230
|
+
x.#realize();
|
|
3231
|
+
params = {
|
|
3232
|
+
...params,
|
|
3233
|
+
padding: require_backend.rep(params.padding.length, [0, 0])
|
|
3234
|
+
};
|
|
3235
|
+
}
|
|
3227
3236
|
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
3228
3237
|
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
3229
3238
|
},
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DLEk-B3V.js";
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-apsUOPzb.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -3189,6 +3189,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
3189
3189
|
},
|
|
3190
3190
|
[Primitive.Conv]([x, y], params) {
|
|
3191
3191
|
checkConvShape(x.shape, y.shape, params);
|
|
3192
|
+
const shouldMaterializePadding = x.#backend.type === "wasm" && params.lhsDilation.every((d) => d === 1) && params.padding.some(([left, right]) => left > 0 || right > 0);
|
|
3193
|
+
if (shouldMaterializePadding) {
|
|
3194
|
+
x = x.#reshape(x.#st.padOrShrink([...rep(params.vmapDims + 2, [0, 0]), ...params.padding]));
|
|
3195
|
+
x.#realize();
|
|
3196
|
+
params = {
|
|
3197
|
+
...params,
|
|
3198
|
+
padding: rep(params.padding.length, [0, 0])
|
|
3199
|
+
};
|
|
3200
|
+
}
|
|
3192
3201
|
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
3193
3202
|
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
3194
3203
|
},
|
package/dist/{webgl-NsFtyIts.js → webgl-Hh0FX6oV.js}
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DLEk-B3V.js";
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-apsUOPzb.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
package/dist/{webgpu-NkF1TZ0t.js → webgpu-BRv5r9Sl.js}
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-DLEk-B3V.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, Routines, SlotError, UnsupportedOpError, UnsupportedRoutineError, emitTrace, findPow2, isFloatDtype, isTracing, mapSetUnion, onFlushTrace, prod, range, strip1, traceSourceInfo, tuneWebgpu } from "./backend-apsUOPzb.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
|
|
|
147
147
|
}
|
|
148
148
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
149
149
|
}
|
|
150
|
+
function reduceOpWgsl(op, dtype, a, b) {
|
|
151
|
+
if (op === AluOp.Add) return `(${a} + ${b})`;
|
|
152
|
+
if (op === AluOp.Mul) return `(${a} * ${b})`;
|
|
153
|
+
if (op === AluOp.Min) return dtype === DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
|
|
154
|
+
if (op === AluOp.Max) return dtype === DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
|
|
155
|
+
throw new Error(`Unsupported reduction op: ${op}`);
|
|
156
|
+
}
|
|
150
157
|
/** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
|
|
151
158
|
var WgslExpCodegen = class {
|
|
152
159
|
#gensymCount = 0;
|
|
@@ -1316,14 +1323,30 @@ function pipelineSource(device, kernel) {
|
|
|
1316
1323
|
}
|
|
1317
1324
|
const resultTy = dtypeToWgsl(kernel.dtype, true);
|
|
1318
1325
|
wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1319
|
-
const
|
|
1320
|
-
const
|
|
1326
|
+
const groupCount = re ? tune.size.groups ?? 1 : 1;
|
|
1327
|
+
const groupedReduction = re && groupCount > 1;
|
|
1328
|
+
if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
|
|
1329
|
+
if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
|
|
1330
|
+
const workgroupSize = groupedReduction ? groupCount : findPow2(tune.threadCount, 256);
|
|
1331
|
+
const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
|
|
1321
1332
|
const [gridX, gridY] = calculateGrid(gridSize);
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1333
|
+
if (groupedReduction) {
|
|
1334
|
+
const partialTy = dtypeToWgsl(re.dtype);
|
|
1335
|
+
for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
|
|
1336
|
+
}
|
|
1337
|
+
wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
|
|
1338
|
+
if (groupedReduction) {
|
|
1339
|
+
wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
|
|
1340
|
+
if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
|
|
1341
|
+
else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
|
|
1342
|
+
wb.emit("let group: i32 = i32(lid.x);");
|
|
1343
|
+
} else {
|
|
1344
|
+
wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
|
|
1345
|
+
if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1346
|
+
else {
|
|
1347
|
+
const sizeX = gridX * workgroupSize;
|
|
1348
|
+
wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1349
|
+
}
|
|
1327
1350
|
}
|
|
1328
1351
|
wb.emitPhonyAssignments(args);
|
|
1329
1352
|
const gen = new WgslExpCodegen(wb, args);
|
|
@@ -1333,7 +1356,6 @@ function pipelineSource(device, kernel) {
|
|
|
1333
1356
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1334
1357
|
wb.emit(`result[gidx] = ${rhs};`);
|
|
1335
1358
|
} else {
|
|
1336
|
-
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
1337
1359
|
const unroll = tune.size.unroll ?? 1;
|
|
1338
1360
|
const upcast = tune.size.upcast ?? 1;
|
|
1339
1361
|
const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
|
|
@@ -1369,6 +1391,15 @@ function pipelineSource(device, kernel) {
|
|
|
1369
1391
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1370
1392
|
}
|
|
1371
1393
|
wb.emit(wb.popIndent, "}");
|
|
1394
|
+
if (groupedReduction) {
|
|
1395
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
|
|
1396
|
+
wb.emit("workgroupBarrier();");
|
|
1397
|
+
for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
|
|
1398
|
+
wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
|
|
1399
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
|
|
1400
|
+
wb.emit(wb.popIndent, "}", "workgroupBarrier();");
|
|
1401
|
+
}
|
|
1402
|
+
}
|
|
1372
1403
|
gen.reset();
|
|
1373
1404
|
const outputIdxExps = [];
|
|
1374
1405
|
const fusionExps = [];
|
|
@@ -1382,12 +1413,17 @@ function pipelineSource(device, kernel) {
|
|
|
1382
1413
|
}).simplify(cache));
|
|
1383
1414
|
gen.countReferences(fusionExps[i]);
|
|
1384
1415
|
}
|
|
1416
|
+
if (groupedReduction) {
|
|
1417
|
+
wb.emit("if (lid.x == 0u) {", wb.pushIndent);
|
|
1418
|
+
for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
|
|
1419
|
+
}
|
|
1385
1420
|
for (let i = 0; i < upcast; i++) {
|
|
1386
1421
|
const index = strip1(gen.run(outputIdxExps[i]));
|
|
1387
1422
|
let rhs = strip1(gen.run(fusionExps[i]));
|
|
1388
1423
|
if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1389
1424
|
wb.emit(`result[${index}] = ${rhs};`);
|
|
1390
1425
|
}
|
|
1426
|
+
if (groupedReduction) wb.emit(wb.popIndent, "}");
|
|
1391
1427
|
}
|
|
1392
1428
|
wb.emit(wb.popIndent, "}");
|
|
1393
1429
|
return {
|
package/dist/{webgpu-DDGCYtHa.cjs → webgpu-pWnE96Xc.cjs}
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-DMyuoWi2.cjs');
|
|
1
|
+
const require_backend = require('./backend-VlXzdQvR.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu/builtins.ts
|
|
4
4
|
const threefrySrc = `
|
|
@@ -147,6 +147,13 @@ function constToWgsl(dtype, value) {
|
|
|
147
147
|
}
|
|
148
148
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
149
149
|
}
|
|
150
|
+
function reduceOpWgsl(op, dtype, a, b) {
|
|
151
|
+
if (op === require_backend.AluOp.Add) return `(${a} + ${b})`;
|
|
152
|
+
if (op === require_backend.AluOp.Mul) return `(${a} * ${b})`;
|
|
153
|
+
if (op === require_backend.AluOp.Min) return dtype === require_backend.DType.Bool ? `(${a} && ${b})` : `min(${a}, ${b})`;
|
|
154
|
+
if (op === require_backend.AluOp.Max) return dtype === require_backend.DType.Bool ? `(${a} || ${b})` : `max(${a}, ${b})`;
|
|
155
|
+
throw new Error(`Unsupported reduction op: ${op}`);
|
|
156
|
+
}
|
|
150
157
|
/** Codegen for WebGPU expressions, linearizing AluOp into a kernel. */
|
|
151
158
|
var WgslExpCodegen = class {
|
|
152
159
|
#gensymCount = 0;
|
|
@@ -1316,14 +1323,30 @@ function pipelineSource(device, kernel) {
|
|
|
1316
1323
|
}
|
|
1317
1324
|
const resultTy = dtypeToWgsl(kernel.dtype, true);
|
|
1318
1325
|
wb.emit(`@group(0) @binding(${nargs}) var<storage, read_write> result : array<${resultTy}>;`);
|
|
1319
|
-
const
|
|
1320
|
-
const
|
|
1326
|
+
const groupCount = re ? tune.size.groups ?? 1 : 1;
|
|
1327
|
+
const groupedReduction = re && groupCount > 1;
|
|
1328
|
+
if (groupedReduction && tune.threadCount % groupCount !== 0) throw new Error("WebGPU grouped reduction has invalid thread count");
|
|
1329
|
+
if (groupedReduction && groupCount > device.limits.maxComputeWorkgroupSizeX) throw new Error("WebGPU grouped reduction exceeds workgroup size limit");
|
|
1330
|
+
const workgroupSize = groupedReduction ? groupCount : require_backend.findPow2(tune.threadCount, 256);
|
|
1331
|
+
const gridSize = groupedReduction ? tune.threadCount / groupCount : Math.ceil(tune.threadCount / workgroupSize);
|
|
1321
1332
|
const [gridX, gridY] = calculateGrid(gridSize);
|
|
1322
|
-
|
|
1323
|
-
|
|
1324
|
-
|
|
1325
|
-
|
|
1326
|
-
|
|
1333
|
+
if (groupedReduction) {
|
|
1334
|
+
const partialTy = dtypeToWgsl(re.dtype);
|
|
1335
|
+
for (let i = 0; i < (tune.size.upcast ?? 1); i++) wb.emit(`var<workgroup> partial${i}: array<${partialTy}, ${groupCount}>;`);
|
|
1336
|
+
}
|
|
1337
|
+
wb.emit("", `@compute @workgroup_size(${workgroupSize})`);
|
|
1338
|
+
if (groupedReduction) {
|
|
1339
|
+
wb.emit("fn main(", wb.pushIndent, "@builtin(local_invocation_id) lid : vec3<u32>,", "@builtin(workgroup_id) wg_id : vec3<u32>,", wb.popIndent, ") {", wb.pushIndent);
|
|
1340
|
+
if (gridY === 1) wb.emit(`if (wg_id.x >= ${gridSize}u) { return; }`, "let gidx: i32 = i32(wg_id.x);");
|
|
1341
|
+
else wb.emit(`if (${gridX}u * wg_id.y + wg_id.x >= ${gridSize}u) { return; }`, `let gidx: i32 = i32(${gridX}u * wg_id.y + wg_id.x);`);
|
|
1342
|
+
wb.emit("let group: i32 = i32(lid.x);");
|
|
1343
|
+
} else {
|
|
1344
|
+
wb.emit("fn main(@builtin(global_invocation_id) id : vec3<u32>) {", wb.pushIndent);
|
|
1345
|
+
if (gridY === 1) wb.emit(`if (id.x >= ${tune.threadCount}) { return; }`, "let gidx: i32 = i32(id.x);");
|
|
1346
|
+
else {
|
|
1347
|
+
const sizeX = gridX * workgroupSize;
|
|
1348
|
+
wb.emit(`if (${sizeX} * id.y + id.x >= ${tune.threadCount}) { return; }`, `let gidx: i32 = i32(${sizeX} * id.y + id.x);`);
|
|
1349
|
+
}
|
|
1327
1350
|
}
|
|
1328
1351
|
wb.emitPhonyAssignments(args);
|
|
1329
1352
|
const gen = new WgslExpCodegen(wb, args);
|
|
@@ -1333,7 +1356,6 @@ function pipelineSource(device, kernel) {
|
|
|
1333
1356
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1334
1357
|
wb.emit(`result[gidx] = ${rhs};`);
|
|
1335
1358
|
} else {
|
|
1336
|
-
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
1337
1359
|
const unroll = tune.size.unroll ?? 1;
|
|
1338
1360
|
const upcast = tune.size.upcast ?? 1;
|
|
1339
1361
|
const acc = [...Array(upcast)].map((_, i) => `acc${i}`);
|
|
@@ -1369,6 +1391,15 @@ function pipelineSource(device, kernel) {
|
|
|
1369
1391
|
else throw new Error(`Unsupported reduction op: ${re.op}`);
|
|
1370
1392
|
}
|
|
1371
1393
|
wb.emit(wb.popIndent, "}");
|
|
1394
|
+
if (groupedReduction) {
|
|
1395
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${acc[i]};`);
|
|
1396
|
+
wb.emit("workgroupBarrier();");
|
|
1397
|
+
for (let stride = groupCount / 2; stride >= 1; stride /= 2) {
|
|
1398
|
+
wb.emit(`if (lid.x < ${stride}u) {`, wb.pushIndent);
|
|
1399
|
+
for (let i = 0; i < upcast; i++) wb.emit(`partial${i}[lid.x] = ${reduceOpWgsl(re.op, re.dtype, `partial${i}[lid.x]`, `partial${i}[lid.x + ${stride}u]`)};`);
|
|
1400
|
+
wb.emit(wb.popIndent, "}", "workgroupBarrier();");
|
|
1401
|
+
}
|
|
1402
|
+
}
|
|
1372
1403
|
gen.reset();
|
|
1373
1404
|
const outputIdxExps = [];
|
|
1374
1405
|
const fusionExps = [];
|
|
@@ -1382,12 +1413,17 @@ function pipelineSource(device, kernel) {
|
|
|
1382
1413
|
}).simplify(cache));
|
|
1383
1414
|
gen.countReferences(fusionExps[i]);
|
|
1384
1415
|
}
|
|
1416
|
+
if (groupedReduction) {
|
|
1417
|
+
wb.emit("if (lid.x == 0u) {", wb.pushIndent);
|
|
1418
|
+
for (let i = 0; i < upcast; i++) wb.emit(`${acc[i]} = partial${i}[0u];`);
|
|
1419
|
+
}
|
|
1385
1420
|
for (let i = 0; i < upcast; i++) {
|
|
1386
1421
|
const index = require_backend.strip1(gen.run(outputIdxExps[i]));
|
|
1387
1422
|
let rhs = require_backend.strip1(gen.run(fusionExps[i]));
|
|
1388
1423
|
if (resultTy !== dtypeToWgsl(fusionExps[i].dtype)) rhs = `${resultTy}(${rhs})`;
|
|
1389
1424
|
wb.emit(`result[${index}] = ${rhs};`);
|
|
1390
1425
|
}
|
|
1426
|
+
if (groupedReduction) wb.emit(wb.popIndent, "}");
|
|
1391
1427
|
}
|
|
1392
1428
|
wb.emit(wb.popIndent, "}");
|
|
1393
1429
|
return {
|