@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, findPow2, isFloatDtype, strip1, tuneWebgpu } from "./backend-1eVbAoaV.js";
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-BqDtPGaR.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu.ts
|
|
4
4
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
|
|
|
21
21
|
}
|
|
22
22
|
malloc(size, initialData) {
|
|
23
23
|
let buffer;
|
|
24
|
+
const paddedSize = Math.ceil(size / 4) * 4;
|
|
24
25
|
if (initialData) {
|
|
25
26
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
26
27
|
if (initialData.byteLength < 4096) {
|
|
27
|
-
buffer = this.#createBuffer(
|
|
28
|
-
new Uint8Array(buffer.getMappedRange()).set(
|
|
28
|
+
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
29
|
+
new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
|
|
29
30
|
buffer.unmap();
|
|
30
31
|
} else {
|
|
31
|
-
buffer = this.#createBuffer(
|
|
32
|
-
this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
32
|
+
buffer = this.#createBuffer(paddedSize);
|
|
33
|
+
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
34
|
+
else {
|
|
35
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
36
|
+
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
37
|
+
const remainder = new Uint8Array(4);
|
|
38
|
+
remainder.set(initialData.subarray(aligned));
|
|
39
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
40
|
+
}
|
|
33
41
|
}
|
|
34
|
-
} else buffer = this.#createBuffer(
|
|
42
|
+
} else buffer = this.#createBuffer(paddedSize);
|
|
35
43
|
const slot = this.nextSlot++;
|
|
36
44
|
this.buffers.set(slot, {
|
|
37
45
|
buffer,
|
|
46
|
+
size,
|
|
38
47
|
ref: 1
|
|
39
48
|
});
|
|
40
49
|
return slot;
|
|
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
|
|
|
54
63
|
}
|
|
55
64
|
}
|
|
56
65
|
async read(slot, start, count) {
|
|
57
|
-
const buffer = this.#getBuffer(slot);
|
|
66
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
58
67
|
if (start === void 0) start = 0;
|
|
59
|
-
if (count === void 0) count =
|
|
60
|
-
const
|
|
68
|
+
if (count === void 0) count = size - start;
|
|
69
|
+
const paddedSize = Math.ceil(count / 4) * 4;
|
|
70
|
+
const staging = this.#createBuffer(paddedSize, { read: true });
|
|
61
71
|
try {
|
|
62
72
|
const commandEncoder = this.device.createCommandEncoder();
|
|
63
|
-
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0,
|
|
73
|
+
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
|
|
64
74
|
this.device.queue.submit([commandEncoder.finish()]);
|
|
65
75
|
await staging.mapAsync(GPUMapMode.READ);
|
|
66
76
|
const arrayBuffer = staging.getMappedRange();
|
|
67
|
-
return arrayBuffer.slice();
|
|
77
|
+
return new Uint8Array(arrayBuffer.slice(), 0, count);
|
|
68
78
|
} finally {
|
|
69
79
|
staging.destroy();
|
|
70
80
|
}
|
|
71
81
|
}
|
|
72
82
|
readSync(slot, start, count) {
|
|
73
|
-
const buffer = this.#getBuffer(slot);
|
|
83
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
74
84
|
if (start === void 0) start = 0;
|
|
75
|
-
if (count === void 0) count =
|
|
85
|
+
if (count === void 0) count = size - start;
|
|
76
86
|
return this.syncReader.read(buffer, start, count);
|
|
77
87
|
}
|
|
78
88
|
#cachedShader(kernel) {
|
|
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
|
|
|
103
113
|
});
|
|
104
114
|
}
|
|
105
115
|
dispatch(exe, inputs, outputs) {
|
|
106
|
-
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
107
|
-
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
116
|
+
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
117
|
+
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
108
118
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
109
119
|
}
|
|
110
120
|
#getBuffer(slot) {
|
|
111
121
|
const buffer = this.buffers.get(slot);
|
|
112
122
|
if (!buffer) throw new SlotError(slot);
|
|
113
|
-
return
|
|
123
|
+
return {
|
|
124
|
+
buffer: buffer.buffer,
|
|
125
|
+
size: buffer.size
|
|
126
|
+
};
|
|
114
127
|
}
|
|
115
128
|
/**
|
|
116
129
|
* Create a GPU buffer.
|
|
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
138
151
|
case DType.Int32: return "i32";
|
|
139
152
|
case DType.Uint32: return "u32";
|
|
140
153
|
case DType.Float32: return "f32";
|
|
154
|
+
case DType.Float16: return "f16";
|
|
141
155
|
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
142
156
|
}
|
|
143
157
|
}
|
|
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
|
|
|
148
162
|
if (dtype === DType.Float32) {
|
|
149
163
|
if (Number.isNaN(value)) return "nan()";
|
|
150
164
|
if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
165
|
+
return "f32(" + value.toString() + ")";
|
|
166
|
+
}
|
|
167
|
+
if (dtype === DType.Float16) {
|
|
168
|
+
if (Number.isNaN(value)) return "f16(nan())";
|
|
169
|
+
if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
|
|
170
|
+
return "f16(" + value.toString() + ")";
|
|
154
171
|
}
|
|
155
172
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
156
173
|
}
|
|
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
|
|
|
163
180
|
function pipelineSource(device, kernel) {
|
|
164
181
|
const tune = tuneWebgpu(kernel);
|
|
165
182
|
if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
166
|
-
const { nargs } = kernel;
|
|
183
|
+
const { nargs, reduction: re } = kernel;
|
|
167
184
|
const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
|
|
168
185
|
const shader = [];
|
|
169
186
|
let indent = "";
|
|
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
|
|
|
174
191
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
175
192
|
else shader.push(line ? indent + line : line);
|
|
176
193
|
};
|
|
194
|
+
if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
|
|
195
|
+
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
196
|
+
emit("enable f16;");
|
|
197
|
+
}
|
|
177
198
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
178
|
-
|
|
199
|
+
const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
200
|
+
if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
|
|
179
201
|
emit("");
|
|
180
202
|
const usedArgs = Array.from({ length: nargs }, () => null);
|
|
181
203
|
tune.exp.fold((exp) => {
|
|
182
|
-
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
|
|
204
|
+
if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
183
205
|
});
|
|
184
206
|
for (let i = 0; i < nargs; i++) {
|
|
185
207
|
const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
|
|
@@ -226,7 +248,7 @@ function pipelineSource(device, kernel) {
|
|
|
226
248
|
else if (op === AluOp.Sub) source = `(${a} - ${b})`;
|
|
227
249
|
else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
|
|
228
250
|
else source = `(${a} * ${b})`;
|
|
229
|
-
else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `
|
|
251
|
+
else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
230
252
|
else if (op === AluOp.Mod) source = `(${a} % ${b})`;
|
|
231
253
|
else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
|
|
232
254
|
else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
@@ -238,6 +260,7 @@ function pipelineSource(device, kernel) {
|
|
|
238
260
|
else if (op === AluOp.Cos) source = `cos(${a})`;
|
|
239
261
|
else if (op === AluOp.Exp) source = `exp(${a})`;
|
|
240
262
|
else if (op === AluOp.Log) source = `log(${a})`;
|
|
263
|
+
else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
|
|
241
264
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
242
265
|
else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
|
|
243
266
|
else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
@@ -249,15 +272,15 @@ function pipelineSource(device, kernel) {
|
|
|
249
272
|
if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
|
|
250
273
|
else if (arg === 0) source = `${x}.x`;
|
|
251
274
|
else if (arg === 1) source = `${x}.y`;
|
|
252
|
-
else throw new
|
|
275
|
+
else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
253
276
|
} else if (op === AluOp.Const) return constToWgsl(dtype, arg);
|
|
254
277
|
else if (op === AluOp.Special) return arg[0];
|
|
255
278
|
else if (op === AluOp.Variable) return arg;
|
|
256
279
|
else if (op === AluOp.GlobalIndex) {
|
|
257
|
-
source = `${args[arg]}[${strip1(gen(src[0]))}]`;
|
|
280
|
+
source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
|
|
258
281
|
if (dtype === DType.Bool) source = `(${source} != 0)`;
|
|
259
282
|
}
|
|
260
|
-
if (!source) throw new
|
|
283
|
+
if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
261
284
|
const typeName = dtypeToWgsl(dtype);
|
|
262
285
|
if ((references.get(exp) ?? 0) > 1) {
|
|
263
286
|
const name = gensym();
|
|
@@ -269,13 +292,12 @@ function pipelineSource(device, kernel) {
|
|
|
269
292
|
return source;
|
|
270
293
|
}
|
|
271
294
|
};
|
|
272
|
-
if (!
|
|
295
|
+
if (!re) {
|
|
273
296
|
countReferences(tune.exp);
|
|
274
297
|
let rhs = strip1(gen(tune.exp));
|
|
275
298
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
276
299
|
emit(`result[gidx] = ${rhs};`);
|
|
277
300
|
} else {
|
|
278
|
-
const re = kernel.reduction;
|
|
279
301
|
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
280
302
|
const unroll = tune.size.unroll ?? 1;
|
|
281
303
|
const upcast = tune.size.upcast ?? 1;
|
|
@@ -319,7 +341,7 @@ function pipelineSource(device, kernel) {
|
|
|
319
341
|
const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
|
|
320
342
|
outputIdxExps.push(exp.simplify(cache));
|
|
321
343
|
countReferences(outputIdxExps[i]);
|
|
322
|
-
fusionExps.push(re.
|
|
344
|
+
fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
|
|
323
345
|
countReferences(fusionExps[i]);
|
|
324
346
|
}
|
|
325
347
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -487,13 +509,12 @@ var SyncReader = class SyncReader {
|
|
|
487
509
|
}
|
|
488
510
|
read(buffer, start, count) {
|
|
489
511
|
if (!this.initialized) this.#init();
|
|
490
|
-
if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
|
|
491
512
|
const deviceStorage = this.deviceStorage;
|
|
492
513
|
const deviceContexts = this.deviceContexts;
|
|
493
514
|
const hostContext = this.hostContext;
|
|
494
|
-
const pixelsSize = count / 4;
|
|
515
|
+
const pixelsSize = Math.ceil(count / 4);
|
|
495
516
|
const bytesPerRow = SyncReader.width * 4;
|
|
496
|
-
const valsGPU = new ArrayBuffer(
|
|
517
|
+
const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
|
|
497
518
|
for (let i = 0; i < deviceContexts.length; i++) {
|
|
498
519
|
const texture = deviceContexts[i].getCurrentTexture();
|
|
499
520
|
const readData = (width, height, offset$1) => {
|
|
@@ -537,7 +558,7 @@ var SyncReader = class SyncReader {
|
|
|
537
558
|
}
|
|
538
559
|
if (remainder > 0) readData(remainder, 1, offset);
|
|
539
560
|
}
|
|
540
|
-
return valsGPU;
|
|
561
|
+
return new Uint8Array(valsGPU, 0, count);
|
|
541
562
|
}
|
|
542
563
|
};
|
|
543
564
|
const threefrySrc = `
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-BK21PBVP.cjs');
|
|
1
|
+
const require_backend = require('./backend-D2C4MJRP.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu.ts
|
|
4
4
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
|
|
|
21
21
|
}
|
|
22
22
|
malloc(size, initialData) {
|
|
23
23
|
let buffer;
|
|
24
|
+
const paddedSize = Math.ceil(size / 4) * 4;
|
|
24
25
|
if (initialData) {
|
|
25
26
|
if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
|
|
26
27
|
if (initialData.byteLength < 4096) {
|
|
27
|
-
buffer = this.#createBuffer(
|
|
28
|
-
new Uint8Array(buffer.getMappedRange()).set(
|
|
28
|
+
buffer = this.#createBuffer(paddedSize, { mapped: true });
|
|
29
|
+
new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
|
|
29
30
|
buffer.unmap();
|
|
30
31
|
} else {
|
|
31
|
-
buffer = this.#createBuffer(
|
|
32
|
-
this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
32
|
+
buffer = this.#createBuffer(paddedSize);
|
|
33
|
+
if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
|
|
34
|
+
else {
|
|
35
|
+
const aligned = initialData.byteLength - initialData.byteLength % 4;
|
|
36
|
+
this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
|
|
37
|
+
const remainder = new Uint8Array(4);
|
|
38
|
+
remainder.set(initialData.subarray(aligned));
|
|
39
|
+
this.device.queue.writeBuffer(buffer, aligned, remainder);
|
|
40
|
+
}
|
|
33
41
|
}
|
|
34
|
-
} else buffer = this.#createBuffer(
|
|
42
|
+
} else buffer = this.#createBuffer(paddedSize);
|
|
35
43
|
const slot = this.nextSlot++;
|
|
36
44
|
this.buffers.set(slot, {
|
|
37
45
|
buffer,
|
|
46
|
+
size,
|
|
38
47
|
ref: 1
|
|
39
48
|
});
|
|
40
49
|
return slot;
|
|
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
|
|
|
54
63
|
}
|
|
55
64
|
}
|
|
56
65
|
async read(slot, start, count) {
|
|
57
|
-
const buffer = this.#getBuffer(slot);
|
|
66
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
58
67
|
if (start === void 0) start = 0;
|
|
59
|
-
if (count === void 0) count =
|
|
60
|
-
const
|
|
68
|
+
if (count === void 0) count = size - start;
|
|
69
|
+
const paddedSize = Math.ceil(count / 4) * 4;
|
|
70
|
+
const staging = this.#createBuffer(paddedSize, { read: true });
|
|
61
71
|
try {
|
|
62
72
|
const commandEncoder = this.device.createCommandEncoder();
|
|
63
|
-
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0,
|
|
73
|
+
commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
|
|
64
74
|
this.device.queue.submit([commandEncoder.finish()]);
|
|
65
75
|
await staging.mapAsync(GPUMapMode.READ);
|
|
66
76
|
const arrayBuffer = staging.getMappedRange();
|
|
67
|
-
return arrayBuffer.slice();
|
|
77
|
+
return new Uint8Array(arrayBuffer.slice(), 0, count);
|
|
68
78
|
} finally {
|
|
69
79
|
staging.destroy();
|
|
70
80
|
}
|
|
71
81
|
}
|
|
72
82
|
readSync(slot, start, count) {
|
|
73
|
-
const buffer = this.#getBuffer(slot);
|
|
83
|
+
const { buffer, size } = this.#getBuffer(slot);
|
|
74
84
|
if (start === void 0) start = 0;
|
|
75
|
-
if (count === void 0) count =
|
|
85
|
+
if (count === void 0) count = size - start;
|
|
76
86
|
return this.syncReader.read(buffer, start, count);
|
|
77
87
|
}
|
|
78
88
|
#cachedShader(kernel) {
|
|
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
|
|
|
103
113
|
});
|
|
104
114
|
}
|
|
105
115
|
dispatch(exe, inputs, outputs) {
|
|
106
|
-
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
|
|
107
|
-
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
|
|
116
|
+
const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
117
|
+
const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
|
|
108
118
|
pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
|
|
109
119
|
}
|
|
110
120
|
#getBuffer(slot) {
|
|
111
121
|
const buffer = this.buffers.get(slot);
|
|
112
122
|
if (!buffer) throw new require_backend.SlotError(slot);
|
|
113
|
-
return
|
|
123
|
+
return {
|
|
124
|
+
buffer: buffer.buffer,
|
|
125
|
+
size: buffer.size
|
|
126
|
+
};
|
|
114
127
|
}
|
|
115
128
|
/**
|
|
116
129
|
* Create a GPU buffer.
|
|
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
|
|
|
138
151
|
case require_backend.DType.Int32: return "i32";
|
|
139
152
|
case require_backend.DType.Uint32: return "u32";
|
|
140
153
|
case require_backend.DType.Float32: return "f32";
|
|
154
|
+
case require_backend.DType.Float16: return "f16";
|
|
141
155
|
default: throw new Error(`Unsupported dtype: ${dtype}`);
|
|
142
156
|
}
|
|
143
157
|
}
|
|
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
|
|
|
148
162
|
if (dtype === require_backend.DType.Float32) {
|
|
149
163
|
if (Number.isNaN(value)) return "nan()";
|
|
150
164
|
if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
165
|
+
return "f32(" + value.toString() + ")";
|
|
166
|
+
}
|
|
167
|
+
if (dtype === require_backend.DType.Float16) {
|
|
168
|
+
if (Number.isNaN(value)) return "f16(nan())";
|
|
169
|
+
if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
|
|
170
|
+
return "f16(" + value.toString() + ")";
|
|
154
171
|
}
|
|
155
172
|
throw new Error(`Unsupported const dtype: ${dtype}`);
|
|
156
173
|
}
|
|
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
|
|
|
163
180
|
function pipelineSource(device, kernel) {
|
|
164
181
|
const tune = require_backend.tuneWebgpu(kernel);
|
|
165
182
|
if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
|
|
166
|
-
const { nargs } = kernel;
|
|
183
|
+
const { nargs, reduction: re } = kernel;
|
|
167
184
|
const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
|
|
168
185
|
const shader = [];
|
|
169
186
|
let indent = "";
|
|
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
|
|
|
174
191
|
else if (line === popIndent) indent = indent.slice(0, -2);
|
|
175
192
|
else shader.push(line ? indent + line : line);
|
|
176
193
|
};
|
|
194
|
+
if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || re?.epilogue.some((exp) => exp.dtype === require_backend.DType.Float16)) {
|
|
195
|
+
if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
|
|
196
|
+
emit("enable f16;");
|
|
197
|
+
}
|
|
177
198
|
emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
|
|
178
|
-
|
|
199
|
+
const distinctOps = require_backend.union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
|
|
200
|
+
if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
|
|
179
201
|
emit("");
|
|
180
202
|
const usedArgs = Array.from({ length: nargs }, () => null);
|
|
181
203
|
tune.exp.fold((exp) => {
|
|
182
|
-
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
|
|
204
|
+
if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
|
|
183
205
|
});
|
|
184
206
|
for (let i = 0; i < nargs; i++) {
|
|
185
207
|
const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
|
|
@@ -226,7 +248,7 @@ function pipelineSource(device, kernel) {
|
|
|
226
248
|
else if (op === require_backend.AluOp.Sub) source = `(${a} - ${b})`;
|
|
227
249
|
else if (op === require_backend.AluOp.Mul) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
|
|
228
250
|
else source = `(${a} * ${b})`;
|
|
229
|
-
else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `
|
|
251
|
+
else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
|
|
230
252
|
else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
|
|
231
253
|
else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
232
254
|
else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
@@ -238,6 +260,7 @@ function pipelineSource(device, kernel) {
|
|
|
238
260
|
else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
|
|
239
261
|
else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
|
|
240
262
|
else if (op === require_backend.AluOp.Log) source = `log(${a})`;
|
|
263
|
+
else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
|
|
241
264
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
242
265
|
else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
|
|
243
266
|
else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
@@ -249,15 +272,15 @@ function pipelineSource(device, kernel) {
|
|
|
249
272
|
if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
|
|
250
273
|
else if (arg === 0) source = `${x}.x`;
|
|
251
274
|
else if (arg === 1) source = `${x}.y`;
|
|
252
|
-
else throw new
|
|
275
|
+
else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
253
276
|
} else if (op === require_backend.AluOp.Const) return constToWgsl(dtype, arg);
|
|
254
277
|
else if (op === require_backend.AluOp.Special) return arg[0];
|
|
255
278
|
else if (op === require_backend.AluOp.Variable) return arg;
|
|
256
279
|
else if (op === require_backend.AluOp.GlobalIndex) {
|
|
257
|
-
source = `${args[arg]}[${require_backend.strip1(gen(src[0]))}]`;
|
|
280
|
+
source = `${args[arg[0]]}[${require_backend.strip1(gen(src[0]))}]`;
|
|
258
281
|
if (dtype === require_backend.DType.Bool) source = `(${source} != 0)`;
|
|
259
282
|
}
|
|
260
|
-
if (!source) throw new
|
|
283
|
+
if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
|
|
261
284
|
const typeName = dtypeToWgsl(dtype);
|
|
262
285
|
if ((references.get(exp) ?? 0) > 1) {
|
|
263
286
|
const name = gensym();
|
|
@@ -269,13 +292,12 @@ function pipelineSource(device, kernel) {
|
|
|
269
292
|
return source;
|
|
270
293
|
}
|
|
271
294
|
};
|
|
272
|
-
if (!
|
|
295
|
+
if (!re) {
|
|
273
296
|
countReferences(tune.exp);
|
|
274
297
|
let rhs = require_backend.strip1(gen(tune.exp));
|
|
275
298
|
if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
|
|
276
299
|
emit(`result[gidx] = ${rhs};`);
|
|
277
300
|
} else {
|
|
278
|
-
const re = kernel.reduction;
|
|
279
301
|
if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
|
|
280
302
|
const unroll = tune.size.unroll ?? 1;
|
|
281
303
|
const upcast = tune.size.upcast ?? 1;
|
|
@@ -319,7 +341,7 @@ function pipelineSource(device, kernel) {
|
|
|
319
341
|
const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
|
|
320
342
|
outputIdxExps.push(exp.simplify(cache));
|
|
321
343
|
countReferences(outputIdxExps[i]);
|
|
322
|
-
fusionExps.push(re.
|
|
344
|
+
fusionExps.push(re.epilogue.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
|
|
323
345
|
countReferences(fusionExps[i]);
|
|
324
346
|
}
|
|
325
347
|
for (let i = 0; i < upcast; i++) {
|
|
@@ -487,13 +509,12 @@ var SyncReader = class SyncReader {
|
|
|
487
509
|
}
|
|
488
510
|
read(buffer, start, count) {
|
|
489
511
|
if (!this.initialized) this.#init();
|
|
490
|
-
if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
|
|
491
512
|
const deviceStorage = this.deviceStorage;
|
|
492
513
|
const deviceContexts = this.deviceContexts;
|
|
493
514
|
const hostContext = this.hostContext;
|
|
494
|
-
const pixelsSize = count / 4;
|
|
515
|
+
const pixelsSize = Math.ceil(count / 4);
|
|
495
516
|
const bytesPerRow = SyncReader.width * 4;
|
|
496
|
-
const valsGPU = new ArrayBuffer(
|
|
517
|
+
const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
|
|
497
518
|
for (let i = 0; i < deviceContexts.length; i++) {
|
|
498
519
|
const texture = deviceContexts[i].getCurrentTexture();
|
|
499
520
|
const readData = (width, height, offset$1) => {
|
|
@@ -537,7 +558,7 @@ var SyncReader = class SyncReader {
|
|
|
537
558
|
}
|
|
538
559
|
if (remainder > 0) readData(remainder, 1, offset);
|
|
539
560
|
}
|
|
540
|
-
return valsGPU;
|
|
561
|
+
return new Uint8Array(valsGPU, 0, count);
|
|
541
562
|
}
|
|
542
563
|
};
|
|
543
564
|
const threefrySrc = `
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.0.2",
|
|
3
|
+
"version": "0.0.3",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -43,14 +43,15 @@
|
|
|
43
43
|
"eslint": "^9.31.0",
|
|
44
44
|
"eslint-plugin-import": "^2.32.0",
|
|
45
45
|
"globals": "^16.0.0",
|
|
46
|
-
"playwright": "~1.
|
|
46
|
+
"playwright": "~1.52.0",
|
|
47
47
|
"prettier": "^3.6.2",
|
|
48
48
|
"prettier-plugin-svelte": "^3.4.0",
|
|
49
|
-
"tsdown": "^0.13.
|
|
49
|
+
"tsdown": "^0.13.2",
|
|
50
50
|
"tsx": "^4.20.3",
|
|
51
|
-
"typedoc": "^0.28.
|
|
52
|
-
"
|
|
53
|
-
"typescript
|
|
51
|
+
"typedoc": "^0.28.14",
|
|
52
|
+
"typedoc-theme-fresh": "^0.2.1",
|
|
53
|
+
"typescript": "~5.9.3",
|
|
54
|
+
"typescript-eslint": "^8.46.4",
|
|
54
55
|
"vitest": "^3.2.4"
|
|
55
56
|
},
|
|
56
57
|
"engines": {
|