@jax-js/jax 0.0.2 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +57 -25
- package/dist/backend-EBRGmEYw.js +3816 -0
- package/dist/{backend-BK21PBVP.cjs → backend-Ss1Mev_-.cjs} +2075 -107
- package/dist/index.cjs +1393 -250
- package/dist/index.d.cts +651 -102
- package/dist/index.d.ts +651 -102
- package/dist/index.js +1377 -245
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-BVdMaO9T.cjs} +62 -35
- package/dist/{webgpu-JVpVad6g.js → webgpu-ow0Pn_6q.js} +62 -35
- package/package.json +21 -9
- package/dist/backend-1eVbAoaV.js +0 -1890
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-Ss1Mev_-.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu.ts
|
|
4
4
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
|
|
|
21
21
|
}
|
|
22
22
|
malloc(size, initialData) {
|
|
23
23
|
let buffer;
|
|
24
|
+
const paddedSize = Math.ceil(size / 4) * 4;
|
|
24
25
|
if (initialData) {
|
|
25
26
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
26
27
|
if (initialData.byteLength < 4096) {
|
|
27
|
-
buffer = this.#createBuffer(
|
|
28
|
-
new Uint8Array(buffer.getMappedRange()).set(
|
|
28
|
+
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
29
|
+
new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
|
|
29
30
|
buffer.unmap();
|
|
30
31
|
} else {
|
|
31
|
-
buffer = this.#createBuffer(
|
|
32
|
-
this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
32
|
+
buffer = this.#createBuffer(paddedSize);
|
|
33
|
+
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
34
|
+
else {
|
|
35
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
36
|
+
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
37
|
+
const remainder = new Uint8Array(4);
|
|
38
|
+
remainder.set(initialData.subarray(aligned));
|
|
39
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
40
|
+
}
|
|
33
41
|
}
|
|
34
|
-
} else buffer = this.#createBuffer(
|
|
42
|
+
} else buffer = this.#createBuffer(paddedSize);
|
|
35
43
|
const slot = this.nextSlot++;
|
|
36
44
|
this.buffers.set(slot, {
|
|
37
45
|
buffer,
|
|
46
|
+
size,
|
|
38
47
|
ref: 1
|
|
39
48
|
});
|
|
40
49
|
return slot;
|
|
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
|
|
|
54
63
|
}
|
|
55
64
|
}
|
|
56
65
|
async read(slot, start, count) {
|
|
57
|
-
const buffer = this.#getBuffer(slot);
|
|
66
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
58
67
|
if (start === void 0) start = 0;
|
|
59
|
-
if (count === void 0) count =
|
|
60
|
-
const
|
|
68
|
+
if (count === void 0) count = size - start;
|
|
69
|
+
const paddedSize = Math.ceil(count / 4) * 4;
|
|
70
|
+
const staging = this.#createBuffer(paddedSize, { read: true });
|
|
61
71
|
try {
|
|
62
72
|
const commandEncoder = this.device.createCommandEncoder();
|
|
63
|
-
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0,
|
|
73
|
+
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
|
|
64
74
|
this.device.queue.submit([commandEncoder.finish()]);
|
|
65
75
|
await staging.mapAsync(GPUMapMode.READ);
|
|
66
76
|
const arrayBuffer = staging.getMappedRange();
|
|
67
|
-
return arrayBuffer.slice();
|
|
77
|
+
return new Uint8Array(arrayBuffer.slice(), 0, count);
|
|
68
78
|
} finally {
|
|
69
79
|
staging.destroy();
|
|
70
80
|
}
|
|
71
81
|
}
|
|
72
82
|
readSync(slot, start, count) {
|
|
73
|
-
const buffer = this.#getBuffer(slot);
|
|
83
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
74
84
|
if (start === void 0) start = 0;
|
|
75
|
-
if (count === void 0) count =
|
|
85
|
+
if (count === void 0) count = size - start;
|
|
76
86
|
return this.syncReader.read(buffer, start, count);
|
|
77
87
|
}
|
|
78
88
|
#cachedShader(kernel) {
|
|
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
|
|
|
103
113
|
});
|
|
104
114
|
}
|
|
105
115
|
dispatch(exe, inputs, outputs) {
|
|
106
|
-
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
107
|
-
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
116
|
+
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
117
|
+
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
108
118
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
109
119
|
}
|
|
110
120
|
#getBuffer(slot) {
|
|
111
121
|
const buffer = this.buffers.get(slot);
|
|
112
122
|
if (!buffer) throw new require_backend.SlotError(slot);
|
|
113
|
-
return
|
|
123
|
+
return {
|
|
124
|
+
buffer: buffer.buffer,
|
|
125
|
+
size: buffer.size
|
|
126
|
+
};
|
|
114
127
|
}
|
|
115
128
|
/**
|
|
116
129
|
* Create a GPU buffer.
|
|
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
138
151
|
case require_backend.DType.Int32: return "i32";
|
|
139
152
|
case require_backend.DType.Uint32: return "u32";
|
|
140
153
|
case require_backend.DType.Float32: return "f32";
|
|
154
|
+
case require_backend.DType.Float16: return "f16";
|
|
141
155
|
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
142
156
|
}
|
|
143
157
|
}
|
|
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
|
|
|
148
162
|
if (dtype === require_backend.DType.Float32) {
|
|
149
163
|
if (Number.isNaN(value)) return "nan()";
|
|
150
164
|
if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
165
|
+
return "f32(" + value.toString() + ")";
|
|
166
|
+
}
|
|
167
|
+
if (dtype === require_backend.DType.Float16) {
|
|
168
|
+
if (Number.isNaN(value)) return "f16(nan())";
|
|
169
|
+
if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
|
|
170
|
+
return "f16(" + value.toString() + ")";
|
|
154
171
|
}
|
|
155
172
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
156
173
|
}
|
|
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
|
|
|
163
180
|
function pipelineSource(device, kernel) {
|
|
164
181
|
const tune = require_backend.tuneWebgpu(kernel);
|
|
165
182
|
if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
166
|
-
const { nargs } = kernel;
|
|
183
|
+
const { nargs, reduction: re } = kernel;
|
|
167
184
|
const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
|
|
168
185
|
const shader = [];
|
|
169
186
|
let indent = "";
|
|
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
|
|
|
174
191
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
175
192
|
else shader.push(line ? indent + line : line);
|
|
176
193
|
};
|
|
194
|
+
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || re?.epilogue.some((exp) => exp.dtype === require_backend.DType.Float16)) {
|
|
195
|
+
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
196
|
+
emit("enable f16;");
|
|
197
|
+
}
|
|
177
198
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
178
|
-
|
|
199
|
+
const distinctOps = require_backend.union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
200
|
+
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
179
201
|
emit("");
|
|
180
202
|
const usedArgs = Array.from({ length: nargs }, () => null);
|
|
181
203
|
tune.exp.fold((exp) => {
|
|
182
|
-
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
|
|
204
|
+
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
183
205
|
});
|
|
184
206
|
for (let i = 0; i < nargs; i++) {
|
|
185
207
|
const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
|
|
@@ -226,22 +248,29 @@ function pipelineSource(device, kernel) {
|
|
|
226
248
|
else if (op === require_backend.AluOp.Sub) source = `(${a} - ${b})`;
|
|
227
249
|
else if (op === require_backend.AluOp.Mul) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
|
|
228
250
|
else source = `(${a} * ${b})`;
|
|
229
|
-
else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `
|
|
251
|
+
else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
230
252
|
else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
|
|
231
253
|
else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
232
254
|
else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
233
255
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
234
256
|
else if (op === require_backend.AluOp.Cmpne) source = `(${a} != ${b})`;
|
|
235
|
-
} else if (require_backend.AluGroup.Unary.has(op)) {
|
|
257
|
+
} else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
|
|
258
|
+
const a = gen(src[0].src[0]);
|
|
259
|
+
source = `inverseSqrt(${a})`;
|
|
260
|
+
} else {
|
|
236
261
|
const a = gen(src[0]);
|
|
237
262
|
if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
|
|
238
263
|
else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
|
|
264
|
+
else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
|
|
265
|
+
else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
|
|
239
266
|
else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
|
|
240
267
|
else if (op === require_backend.AluOp.Log) source = `log(${a})`;
|
|
268
|
+
else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
|
|
241
269
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
242
270
|
else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
|
|
243
271
|
else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
244
|
-
}
|
|
272
|
+
}
|
|
273
|
+
else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
|
|
245
274
|
else if (op === require_backend.AluOp.Threefry2x32) {
|
|
246
275
|
const x = gensym();
|
|
247
276
|
const [k0, k1, c0, c1] = src.map((x$1) => require_backend.strip1(gen(x$1)));
|
|
@@ -249,15 +278,15 @@ function pipelineSource(device, kernel) {
|
|
|
249
278
|
if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
|
|
250
279
|
else if (arg === 0) source = `${x}.x`;
|
|
251
280
|
else if (arg === 1) source = `${x}.y`;
|
|
252
|
-
else throw new
|
|
281
|
+
else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
253
282
|
} else if (op === require_backend.AluOp.Const) return constToWgsl(dtype, arg);
|
|
254
283
|
else if (op === require_backend.AluOp.Special) return arg[0];
|
|
255
284
|
else if (op === require_backend.AluOp.Variable) return arg;
|
|
256
285
|
else if (op === require_backend.AluOp.GlobalIndex) {
|
|
257
|
-
source = `${args[arg]}[${require_backend.strip1(gen(src[0]))}]`;
|
|
286
|
+
source = `${args[arg[0]]}[${require_backend.strip1(gen(src[0]))}]`;
|
|
258
287
|
if (dtype === require_backend.DType.Bool) source = `(${source} != 0)`;
|
|
259
288
|
}
|
|
260
|
-
if (!source) throw new
|
|
289
|
+
if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
261
290
|
const typeName = dtypeToWgsl(dtype);
|
|
262
291
|
if ((references.get(exp) ?? 0) > 1) {
|
|
263
292
|
const name = gensym();
|
|
@@ -269,13 +298,12 @@ function pipelineSource(device, kernel) {
|
|
|
269
298
|
return source;
|
|
270
299
|
}
|
|
271
300
|
};
|
|
272
|
-
if (!
|
|
301
|
+
if (!re) {
|
|
273
302
|
countReferences(tune.exp);
|
|
274
303
|
let rhs = require_backend.strip1(gen(tune.exp));
|
|
275
304
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
276
305
|
emit(`result[gidx] = ${rhs};`);
|
|
277
306
|
} else {
|
|
278
|
-
const re = kernel.reduction;
|
|
279
307
|
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
280
308
|
const unroll = tune.size.unroll ?? 1;
|
|
281
309
|
const upcast = tune.size.upcast ?? 1;
|
|
@@ -319,7 +347,7 @@ function pipelineSource(device, kernel) {
|
|
|
319
347
|
const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
|
|
320
348
|
outputIdxExps.push(exp.simplify(cache));
|
|
321
349
|
countReferences(outputIdxExps[i]);
|
|
322
|
-
fusionExps.push(re.
|
|
350
|
+
fusionExps.push(re.epilogue.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
|
|
323
351
|
countReferences(fusionExps[i]);
|
|
324
352
|
}
|
|
325
353
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -487,13 +515,12 @@ var SyncReader = class SyncReader {
|
|
|
487
515
|
}
|
|
488
516
|
read(buffer, start, count) {
|
|
489
517
|
if (!this.initialized) this.#init();
|
|
490
|
-
if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
|
|
491
518
|
const deviceStorage = this.deviceStorage;
|
|
492
519
|
const deviceContexts = this.deviceContexts;
|
|
493
520
|
const hostContext = this.hostContext;
|
|
494
|
-
const pixelsSize = count / 4;
|
|
521
|
+
const pixelsSize = Math.ceil(count / 4);
|
|
495
522
|
const bytesPerRow = SyncReader.width * 4;
|
|
496
|
-
const valsGPU = new ArrayBuffer(
|
|
523
|
+
const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
|
|
497
524
|
for (let i = 0; i < deviceContexts.length; i++) {
|
|
498
525
|
const texture = deviceContexts[i].getCurrentTexture();
|
|
499
526
|
const readData = (width, height, offset$1) => {
|
|
@@ -537,7 +564,7 @@ var SyncReader = class SyncReader {
|
|
|
537
564
|
}
|
|
538
565
|
if (remainder > 0) readData(remainder, 1, offset);
|
|
539
566
|
}
|
|
540
|
-
return valsGPU;
|
|
567
|
+
return new Uint8Array(valsGPU, 0, count);
|
|
541
568
|
}
|
|
542
569
|
};
|
|
543
570
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, findPow2, isFloatDtype, strip1, tuneWebgpu } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-EBRGmEYw.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu.ts
|
|
4
4
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
|
|
|
21
21
|
}
|
|
22
22
|
malloc(size, initialData) {
|
|
23
23
|
let buffer;
|
|
24
|
+
const paddedSize = Math.ceil(size / 4) * 4;
|
|
24
25
|
if (initialData) {
|
|
25
26
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
26
27
|
if (initialData.byteLength < 4096) {
|
|
27
|
-
buffer = this.#createBuffer(
|
|
28
|
-
new Uint8Array(buffer.getMappedRange()).set(
|
|
28
|
+
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
29
|
+
new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
|
|
29
30
|
buffer.unmap();
|
|
30
31
|
} else {
|
|
31
|
-
buffer = this.#createBuffer(
|
|
32
|
-
this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
32
|
+
buffer = this.#createBuffer(paddedSize);
|
|
33
|
+
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
34
|
+
else {
|
|
35
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
36
|
+
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
37
|
+
const remainder = new Uint8Array(4);
|
|
38
|
+
remainder.set(initialData.subarray(aligned));
|
|
39
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
40
|
+
}
|
|
33
41
|
}
|
|
34
|
-
} else buffer = this.#createBuffer(
|
|
42
|
+
} else buffer = this.#createBuffer(paddedSize);
|
|
35
43
|
const slot = this.nextSlot++;
|
|
36
44
|
this.buffers.set(slot, {
|
|
37
45
|
buffer,
|
|
46
|
+
size,
|
|
38
47
|
ref: 1
|
|
39
48
|
});
|
|
40
49
|
return slot;
|
|
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
|
|
|
54
63
|
}
|
|
55
64
|
}
|
|
56
65
|
async read(slot, start, count) {
|
|
57
|
-
const buffer = this.#getBuffer(slot);
|
|
66
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
58
67
|
if (start === void 0) start = 0;
|
|
59
|
-
if (count === void 0) count =
|
|
60
|
-
const
|
|
68
|
+
if (count === void 0) count = size - start;
|
|
69
|
+
const paddedSize = Math.ceil(count / 4) * 4;
|
|
70
|
+
const staging = this.#createBuffer(paddedSize, { read: true });
|
|
61
71
|
try {
|
|
62
72
|
const commandEncoder = this.device.createCommandEncoder();
|
|
63
|
-
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0,
|
|
73
|
+
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
|
|
64
74
|
this.device.queue.submit([commandEncoder.finish()]);
|
|
65
75
|
await staging.mapAsync(GPUMapMode.READ);
|
|
66
76
|
const arrayBuffer = staging.getMappedRange();
|
|
67
|
-
return arrayBuffer.slice();
|
|
77
|
+
return new Uint8Array(arrayBuffer.slice(), 0, count);
|
|
68
78
|
} finally {
|
|
69
79
|
staging.destroy();
|
|
70
80
|
}
|
|
71
81
|
}
|
|
72
82
|
readSync(slot, start, count) {
|
|
73
|
-
const buffer = this.#getBuffer(slot);
|
|
83
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
74
84
|
if (start === void 0) start = 0;
|
|
75
|
-
if (count === void 0) count =
|
|
85
|
+
if (count === void 0) count = size - start;
|
|
76
86
|
return this.syncReader.read(buffer, start, count);
|
|
77
87
|
}
|
|
78
88
|
#cachedShader(kernel) {
|
|
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
|
|
|
103
113
|
});
|
|
104
114
|
}
|
|
105
115
|
dispatch(exe, inputs, outputs) {
|
|
106
|
-
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
107
|
-
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
116
|
+
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
117
|
+
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
108
118
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
109
119
|
}
|
|
110
120
|
#getBuffer(slot) {
|
|
111
121
|
const buffer = this.buffers.get(slot);
|
|
112
122
|
if (!buffer) throw new SlotError(slot);
|
|
113
|
-
return
|
|
123
|
+
return {
|
|
124
|
+
buffer: buffer.buffer,
|
|
125
|
+
size: buffer.size
|
|
126
|
+
};
|
|
114
127
|
}
|
|
115
128
|
/**
|
|
116
129
|
* Create a GPU buffer.
|
|
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
138
151
|
case DType.Int32: return "i32";
|
|
139
152
|
case DType.Uint32: return "u32";
|
|
140
153
|
case DType.Float32: return "f32";
|
|
154
|
+
case DType.Float16: return "f16";
|
|
141
155
|
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
142
156
|
}
|
|
143
157
|
}
|
|
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
|
|
|
148
162
|
if (dtype === DType.Float32) {
|
|
149
163
|
if (Number.isNaN(value)) return "nan()";
|
|
150
164
|
if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
165
|
+
return "f32(" + value.toString() + ")";
|
|
166
|
+
}
|
|
167
|
+
if (dtype === DType.Float16) {
|
|
168
|
+
if (Number.isNaN(value)) return "f16(nan())";
|
|
169
|
+
if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
|
|
170
|
+
return "f16(" + value.toString() + ")";
|
|
154
171
|
}
|
|
155
172
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
156
173
|
}
|
|
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
|
|
|
163
180
|
function pipelineSource(device, kernel) {
|
|
164
181
|
const tune = tuneWebgpu(kernel);
|
|
165
182
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
166
|
-
const { nargs } = kernel;
|
|
183
|
+
const { nargs, reduction: re } = kernel;
|
|
167
184
|
const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
|
|
168
185
|
const shader = [];
|
|
169
186
|
let indent = "";
|
|
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
|
|
|
174
191
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
175
192
|
else shader.push(line ? indent + line : line);
|
|
176
193
|
};
|
|
194
|
+
if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
|
|
195
|
+
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
196
|
+
emit("enable f16;");
|
|
197
|
+
}
|
|
177
198
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
178
|
-
|
|
199
|
+
const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
200
|
+
if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
|
|
179
201
|
emit("");
|
|
180
202
|
const usedArgs = Array.from({ length: nargs }, () => null);
|
|
181
203
|
tune.exp.fold((exp) => {
|
|
182
|
-
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
|
|
204
|
+
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
183
205
|
});
|
|
184
206
|
for (let i = 0; i < nargs; i++) {
|
|
185
207
|
const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
|
|
@@ -226,22 +248,29 @@ function pipelineSource(device, kernel) {
|
|
|
226
248
|
else if (op === AluOp.Sub) source = `(${a} - ${b})`;
|
|
227
249
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
|
|
228
250
|
else source = `(${a} * ${b})`;
|
|
229
|
-
else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `
|
|
251
|
+
else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
230
252
|
else if (op === AluOp.Mod) source = `(${a} % ${b})`;
|
|
231
253
|
else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
232
254
|
else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
233
255
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
234
256
|
else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
|
|
235
|
-
} else if (AluGroup.Unary.has(op)) {
|
|
257
|
+
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
|
|
258
|
+
const a = gen(src[0].src[0]);
|
|
259
|
+
source = `inverseSqrt(${a})`;
|
|
260
|
+
} else {
|
|
236
261
|
const a = gen(src[0]);
|
|
237
262
|
if (op === AluOp.Sin) source = `sin(${a})`;
|
|
238
263
|
else if (op === AluOp.Cos) source = `cos(${a})`;
|
|
264
|
+
else if (op === AluOp.Asin) source = `asin(${a})`;
|
|
265
|
+
else if (op === AluOp.Atan) source = `atan(${a})`;
|
|
239
266
|
else if (op === AluOp.Exp) source = `exp(${a})`;
|
|
240
267
|
else if (op === AluOp.Log) source = `log(${a})`;
|
|
268
|
+
else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
|
|
241
269
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
242
270
|
else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
|
|
243
271
|
else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
244
|
-
}
|
|
272
|
+
}
|
|
273
|
+
else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
|
|
245
274
|
else if (op === AluOp.Threefry2x32) {
|
|
246
275
|
const x = gensym();
|
|
247
276
|
const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
|
|
@@ -249,15 +278,15 @@ function pipelineSource(device, kernel) {
|
|
|
249
278
|
if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
|
|
250
279
|
else if (arg === 0) source = `${x}.x`;
|
|
251
280
|
else if (arg === 1) source = `${x}.y`;
|
|
252
|
-
else throw new
|
|
281
|
+
else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
253
282
|
} else if (op === AluOp.Const) return constToWgsl(dtype, arg);
|
|
254
283
|
else if (op === AluOp.Special) return arg[0];
|
|
255
284
|
else if (op === AluOp.Variable) return arg;
|
|
256
285
|
else if (op === AluOp.GlobalIndex) {
|
|
257
|
-
source = `${args[arg]}[${strip1(gen(src[0]))}]`;
|
|
286
|
+
source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
|
|
258
287
|
if (dtype === DType.Bool) source = `(${source} != 0)`;
|
|
259
288
|
}
|
|
260
|
-
if (!source) throw new
|
|
289
|
+
if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
261
290
|
const typeName = dtypeToWgsl(dtype);
|
|
262
291
|
if ((references.get(exp) ?? 0) > 1) {
|
|
263
292
|
const name = gensym();
|
|
@@ -269,13 +298,12 @@ function pipelineSource(device, kernel) {
|
|
|
269
298
|
return source;
|
|
270
299
|
}
|
|
271
300
|
};
|
|
272
|
-
if (!
|
|
301
|
+
if (!re) {
|
|
273
302
|
countReferences(tune.exp);
|
|
274
303
|
let rhs = strip1(gen(tune.exp));
|
|
275
304
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
276
305
|
emit(`result[gidx] = ${rhs};`);
|
|
277
306
|
} else {
|
|
278
|
-
const re = kernel.reduction;
|
|
279
307
|
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
280
308
|
const unroll = tune.size.unroll ?? 1;
|
|
281
309
|
const upcast = tune.size.upcast ?? 1;
|
|
@@ -319,7 +347,7 @@ function pipelineSource(device, kernel) {
|
|
|
319
347
|
const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
|
|
320
348
|
outputIdxExps.push(exp.simplify(cache));
|
|
321
349
|
countReferences(outputIdxExps[i]);
|
|
322
|
-
fusionExps.push(re.
|
|
350
|
+
fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
|
|
323
351
|
countReferences(fusionExps[i]);
|
|
324
352
|
}
|
|
325
353
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -487,13 +515,12 @@ var SyncReader = class SyncReader {
|
|
|
487
515
|
}
|
|
488
516
|
read(buffer, start, count) {
|
|
489
517
|
if (!this.initialized) this.#init();
|
|
490
|
-
if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
|
|
491
518
|
const deviceStorage = this.deviceStorage;
|
|
492
519
|
const deviceContexts = this.deviceContexts;
|
|
493
520
|
const hostContext = this.hostContext;
|
|
494
|
-
const pixelsSize = count / 4;
|
|
521
|
+
const pixelsSize = Math.ceil(count / 4);
|
|
495
522
|
const bytesPerRow = SyncReader.width * 4;
|
|
496
|
-
const valsGPU = new ArrayBuffer(
|
|
523
|
+
const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
|
|
497
524
|
for (let i = 0; i < deviceContexts.length; i++) {
|
|
498
525
|
const texture = deviceContexts[i].getCurrentTexture();
|
|
499
526
|
const readData = (width, height, offset$1) => {
|
|
@@ -537,7 +564,7 @@ var SyncReader = class SyncReader {
|
|
|
537
564
|
}
|
|
538
565
|
if (remainder > 0) readData(remainder, 1, offset);
|
|
539
566
|
}
|
|
540
|
-
return valsGPU;
|
|
567
|
+
return new Uint8Array(valsGPU, 0, count);
|
|
541
568
|
}
|
|
542
569
|
};
|
|
543
570
|
const threefrySrc = `
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.0.
|
|
3
|
+
"version": "0.0.4",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -38,20 +38,21 @@
|
|
|
38
38
|
"devDependencies": {
|
|
39
39
|
"@eslint/js": "^9.31.0",
|
|
40
40
|
"@types/debug": "^4.1.12",
|
|
41
|
-
"@vitest/browser": "^
|
|
41
|
+
"@vitest/browser-playwright": "^4.0.9",
|
|
42
42
|
"@webgpu/types": "^0.1.64",
|
|
43
43
|
"eslint": "^9.31.0",
|
|
44
44
|
"eslint-plugin-import": "^2.32.0",
|
|
45
45
|
"globals": "^16.0.0",
|
|
46
|
-
"playwright": "~1.
|
|
46
|
+
"playwright": "~1.52.0",
|
|
47
47
|
"prettier": "^3.6.2",
|
|
48
48
|
"prettier-plugin-svelte": "^3.4.0",
|
|
49
|
-
"tsdown": "^0.13.
|
|
49
|
+
"tsdown": "^0.13.2",
|
|
50
50
|
"tsx": "^4.20.3",
|
|
51
|
-
"typedoc": "^0.28.
|
|
52
|
-
"
|
|
53
|
-
"typescript
|
|
54
|
-
"
|
|
51
|
+
"typedoc": "^0.28.14",
|
|
52
|
+
"typedoc-theme-fresh": "^0.2.1",
|
|
53
|
+
"typescript": "~5.9.3",
|
|
54
|
+
"typescript-eslint": "^8.46.4",
|
|
55
|
+
"vitest": "^4.0.9"
|
|
55
56
|
},
|
|
56
57
|
"engines": {
|
|
57
58
|
"pnpm": ">=10.0.0"
|
|
@@ -59,7 +60,18 @@
|
|
|
59
60
|
"prettier": {
|
|
60
61
|
"plugins": [
|
|
61
62
|
"prettier-plugin-svelte"
|
|
62
|
-
]
|
|
63
|
+
],
|
|
64
|
+
"overrides": [
|
|
65
|
+
{
|
|
66
|
+
"files": [
|
|
67
|
+
"*.md"
|
|
68
|
+
],
|
|
69
|
+
"options": {
|
|
70
|
+
"printWidth": 100
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
],
|
|
74
|
+
"proseWrap": "always"
|
|
63
75
|
},
|
|
64
76
|
"scripts": {
|
|
65
77
|
"build": "tsdown",
|