@jax-js/jax 0.0.2 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, findPow2, isFloatDtype, strip1, tuneWebgpu } from "./backend-1eVbAoaV.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-BqDtPGaR.js";
2
2
 
3
3
  //#region src/backend/webgpu.ts
4
4
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
21
21
  }
22
22
  malloc(size, initialData) {
23
23
  let buffer;
24
+ const paddedSize = Math.ceil(size / 4) * 4;
24
25
  if (initialData) {
25
26
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
26
27
  if (initialData.byteLength < 4096) {
27
- buffer = this.#createBuffer(size, { mapped: true });
28
- new Uint8Array(buffer.getMappedRange()).set(new Uint8Array(initialData));
28
+ buffer = this.#createBuffer(paddedSize, { mapped: true });
29
+ new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
29
30
  buffer.unmap();
30
31
  } else {
31
- buffer = this.#createBuffer(size);
32
- this.device.queue.writeBuffer(buffer, 0, initialData);
32
+ buffer = this.#createBuffer(paddedSize);
33
+ if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
34
+ else {
35
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
36
+ this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
37
+ const remainder = new Uint8Array(4);
38
+ remainder.set(initialData.subarray(aligned));
39
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
40
+ }
33
41
  }
34
- } else buffer = this.#createBuffer(size);
42
+ } else buffer = this.#createBuffer(paddedSize);
35
43
  const slot = this.nextSlot++;
36
44
  this.buffers.set(slot, {
37
45
  buffer,
46
+ size,
38
47
  ref: 1
39
48
  });
40
49
  return slot;
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
54
63
  }
55
64
  }
56
65
  async read(slot, start, count) {
57
- const buffer = this.#getBuffer(slot);
66
+ const { buffer, size } = this.#getBuffer(slot);
58
67
  if (start === void 0) start = 0;
59
- if (count === void 0) count = buffer.size - start;
60
- const staging = this.#createBuffer(count, { read: true });
68
+ if (count === void 0) count = size - start;
69
+ const paddedSize = Math.ceil(count / 4) * 4;
70
+ const staging = this.#createBuffer(paddedSize, { read: true });
61
71
  try {
62
72
  const commandEncoder = this.device.createCommandEncoder();
63
- commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, count);
73
+ commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
64
74
  this.device.queue.submit([commandEncoder.finish()]);
65
75
  await staging.mapAsync(GPUMapMode.READ);
66
76
  const arrayBuffer = staging.getMappedRange();
67
- return arrayBuffer.slice();
77
+ return new Uint8Array(arrayBuffer.slice(), 0, count);
68
78
  } finally {
69
79
  staging.destroy();
70
80
  }
71
81
  }
72
82
  readSync(slot, start, count) {
73
- const buffer = this.#getBuffer(slot);
83
+ const { buffer, size } = this.#getBuffer(slot);
74
84
  if (start === void 0) start = 0;
75
- if (count === void 0) count = buffer.size - start;
85
+ if (count === void 0) count = size - start;
76
86
  return this.syncReader.read(buffer, start, count);
77
87
  }
78
88
  #cachedShader(kernel) {
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
103
113
  });
104
114
  }
105
115
  dispatch(exe, inputs, outputs) {
106
- const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
107
- const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
116
+ const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
117
+ const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
108
118
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
109
119
  }
110
120
  #getBuffer(slot) {
111
121
  const buffer = this.buffers.get(slot);
112
122
  if (!buffer) throw new SlotError(slot);
113
- return buffer.buffer;
123
+ return {
124
+ buffer: buffer.buffer,
125
+ size: buffer.size
126
+ };
114
127
  }
115
128
  /**
116
129
  * Create a GPU buffer.
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
138
151
  case DType.Int32: return "i32";
139
152
  case DType.Uint32: return "u32";
140
153
  case DType.Float32: return "f32";
154
+ case DType.Float16: return "f16";
141
155
  default: throw new Error(`Unsupported dtype: ${dtype}`);
142
156
  }
143
157
  }
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
148
162
  if (dtype === DType.Float32) {
149
163
  if (Number.isNaN(value)) return "nan()";
150
164
  if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
151
- let s = value.toString();
152
- if (!s.includes(".")) s += ".0";
153
- return s;
165
+ return "f32(" + value.toString() + ")";
166
+ }
167
+ if (dtype === DType.Float16) {
168
+ if (Number.isNaN(value)) return "f16(nan())";
169
+ if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
170
+ return "f16(" + value.toString() + ")";
154
171
  }
155
172
  throw new Error(`Unsupported const dtype: ${dtype}`);
156
173
  }
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
163
180
  function pipelineSource(device, kernel) {
164
181
  const tune = tuneWebgpu(kernel);
165
182
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
166
- const { nargs } = kernel;
183
+ const { nargs, reduction: re } = kernel;
167
184
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
168
185
  const shader = [];
169
186
  let indent = "";
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
174
191
  else if (line === popIndent) indent = indent.slice(0, -2);
175
192
  else shader.push(line ? indent + line : line);
176
193
  };
194
+ if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
195
+ if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
196
+ emit("enable f16;");
197
+ }
177
198
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
178
- if (tune.exp.collect((exp) => exp.op === AluOp.Threefry2x32).length > 0) emit(threefrySrc);
199
+ const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
200
+ if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
179
201
  emit("");
180
202
  const usedArgs = Array.from({ length: nargs }, () => null);
181
203
  tune.exp.fold((exp) => {
182
- if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
204
+ if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
183
205
  });
184
206
  for (let i = 0; i < nargs; i++) {
185
207
  const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
@@ -226,7 +248,7 @@ function pipelineSource(device, kernel) {
226
248
  else if (op === AluOp.Sub) source = `(${a} - ${b})`;
227
249
  else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
228
250
  else source = `(${a} * ${b})`;
229
- else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `floor(${a} / ${b})` : `(${a} / ${b})`;
251
+ else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
230
252
  else if (op === AluOp.Mod) source = `(${a} % ${b})`;
231
253
  else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
232
254
  else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
@@ -238,6 +260,7 @@ function pipelineSource(device, kernel) {
238
260
  else if (op === AluOp.Cos) source = `cos(${a})`;
239
261
  else if (op === AluOp.Exp) source = `exp(${a})`;
240
262
  else if (op === AluOp.Log) source = `log(${a})`;
263
+ else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
241
264
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
242
265
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
243
266
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
@@ -249,15 +272,15 @@ function pipelineSource(device, kernel) {
249
272
  if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
250
273
  else if (arg === 0) source = `${x}.x`;
251
274
  else if (arg === 1) source = `${x}.y`;
252
- else throw new Error("Invalid Threefry2x32 mode: " + arg);
275
+ else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
253
276
  } else if (op === AluOp.Const) return constToWgsl(dtype, arg);
254
277
  else if (op === AluOp.Special) return arg[0];
255
278
  else if (op === AluOp.Variable) return arg;
256
279
  else if (op === AluOp.GlobalIndex) {
257
- source = `${args[arg]}[${strip1(gen(src[0]))}]`;
280
+ source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
258
281
  if (dtype === DType.Bool) source = `(${source} != 0)`;
259
282
  }
260
- if (!source) throw new Error(`Missing impl for op: ${op}`);
283
+ if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
261
284
  const typeName = dtypeToWgsl(dtype);
262
285
  if ((references.get(exp) ?? 0) > 1) {
263
286
  const name = gensym();
@@ -269,13 +292,12 @@ function pipelineSource(device, kernel) {
269
292
  return source;
270
293
  }
271
294
  };
272
- if (!kernel.reduction) {
295
+ if (!re) {
273
296
  countReferences(tune.exp);
274
297
  let rhs = strip1(gen(tune.exp));
275
298
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
276
299
  emit(`result[gidx] = ${rhs};`);
277
300
  } else {
278
- const re = kernel.reduction;
279
301
  if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
280
302
  const unroll = tune.size.unroll ?? 1;
281
303
  const upcast = tune.size.upcast ?? 1;
@@ -319,7 +341,7 @@ function pipelineSource(device, kernel) {
319
341
  const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
320
342
  outputIdxExps.push(exp.simplify(cache));
321
343
  countReferences(outputIdxExps[i]);
322
- fusionExps.push(re.fusion.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
344
+ fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
323
345
  countReferences(fusionExps[i]);
324
346
  }
325
347
  for (let i = 0; i < upcast; i++) {
@@ -487,13 +509,12 @@ var SyncReader = class SyncReader {
487
509
  }
488
510
  read(buffer, start, count) {
489
511
  if (!this.initialized) this.#init();
490
- if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
491
512
  const deviceStorage = this.deviceStorage;
492
513
  const deviceContexts = this.deviceContexts;
493
514
  const hostContext = this.hostContext;
494
- const pixelsSize = count / 4;
515
+ const pixelsSize = Math.ceil(count / 4);
495
516
  const bytesPerRow = SyncReader.width * 4;
496
- const valsGPU = new ArrayBuffer(count);
517
+ const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
497
518
  for (let i = 0; i < deviceContexts.length; i++) {
498
519
  const texture = deviceContexts[i].getCurrentTexture();
499
520
  const readData = (width, height, offset$1) => {
@@ -537,7 +558,7 @@ var SyncReader = class SyncReader {
537
558
  }
538
559
  if (remainder > 0) readData(remainder, 1, offset);
539
560
  }
540
- return valsGPU;
561
+ return new Uint8Array(valsGPU, 0, count);
541
562
  }
542
563
  };
543
564
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-BK21PBVP.cjs');
1
+ const require_backend = require('./backend-D2C4MJRP.cjs');
2
2
 
3
3
  //#region src/backend/webgpu.ts
4
4
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
21
21
  }
22
22
  malloc(size, initialData) {
23
23
  let buffer;
24
+ const paddedSize = Math.ceil(size / 4) * 4;
24
25
  if (initialData) {
25
26
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
26
27
  if (initialData.byteLength < 4096) {
27
- buffer = this.#createBuffer(size, { mapped: true });
28
- new Uint8Array(buffer.getMappedRange()).set(new Uint8Array(initialData));
28
+ buffer = this.#createBuffer(paddedSize, { mapped: true });
29
+ new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
29
30
  buffer.unmap();
30
31
  } else {
31
- buffer = this.#createBuffer(size);
32
- this.device.queue.writeBuffer(buffer, 0, initialData);
32
+ buffer = this.#createBuffer(paddedSize);
33
+ if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
34
+ else {
35
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
36
+ this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
37
+ const remainder = new Uint8Array(4);
38
+ remainder.set(initialData.subarray(aligned));
39
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
40
+ }
33
41
  }
34
- } else buffer = this.#createBuffer(size);
42
+ } else buffer = this.#createBuffer(paddedSize);
35
43
  const slot = this.nextSlot++;
36
44
  this.buffers.set(slot, {
37
45
  buffer,
46
+ size,
38
47
  ref: 1
39
48
  });
40
49
  return slot;
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
54
63
  }
55
64
  }
56
65
  async read(slot, start, count) {
57
- const buffer = this.#getBuffer(slot);
66
+ const { buffer, size } = this.#getBuffer(slot);
58
67
  if (start === void 0) start = 0;
59
- if (count === void 0) count = buffer.size - start;
60
- const staging = this.#createBuffer(count, { read: true });
68
+ if (count === void 0) count = size - start;
69
+ const paddedSize = Math.ceil(count / 4) * 4;
70
+ const staging = this.#createBuffer(paddedSize, { read: true });
61
71
  try {
62
72
  const commandEncoder = this.device.createCommandEncoder();
63
- commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, count);
73
+ commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
64
74
  this.device.queue.submit([commandEncoder.finish()]);
65
75
  await staging.mapAsync(GPUMapMode.READ);
66
76
  const arrayBuffer = staging.getMappedRange();
67
- return arrayBuffer.slice();
77
+ return new Uint8Array(arrayBuffer.slice(), 0, count);
68
78
  } finally {
69
79
  staging.destroy();
70
80
  }
71
81
  }
72
82
  readSync(slot, start, count) {
73
- const buffer = this.#getBuffer(slot);
83
+ const { buffer, size } = this.#getBuffer(slot);
74
84
  if (start === void 0) start = 0;
75
- if (count === void 0) count = buffer.size - start;
85
+ if (count === void 0) count = size - start;
76
86
  return this.syncReader.read(buffer, start, count);
77
87
  }
78
88
  #cachedShader(kernel) {
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
103
113
  });
104
114
  }
105
115
  dispatch(exe, inputs, outputs) {
106
- const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
107
- const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
116
+ const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
117
+ const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
108
118
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
109
119
  }
110
120
  #getBuffer(slot) {
111
121
  const buffer = this.buffers.get(slot);
112
122
  if (!buffer) throw new require_backend.SlotError(slot);
113
- return buffer.buffer;
123
+ return {
124
+ buffer: buffer.buffer,
125
+ size: buffer.size
126
+ };
114
127
  }
115
128
  /**
116
129
  * Create a GPU buffer.
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
138
151
  case require_backend.DType.Int32: return "i32";
139
152
  case require_backend.DType.Uint32: return "u32";
140
153
  case require_backend.DType.Float32: return "f32";
154
+ case require_backend.DType.Float16: return "f16";
141
155
  default: throw new Error(`Unsupported dtype: ${dtype}`);
142
156
  }
143
157
  }
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
148
162
  if (dtype === require_backend.DType.Float32) {
149
163
  if (Number.isNaN(value)) return "nan()";
150
164
  if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
151
- let s = value.toString();
152
- if (!s.includes(".")) s += ".0";
153
- return s;
165
+ return "f32(" + value.toString() + ")";
166
+ }
167
+ if (dtype === require_backend.DType.Float16) {
168
+ if (Number.isNaN(value)) return "f16(nan())";
169
+ if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
170
+ return "f16(" + value.toString() + ")";
154
171
  }
155
172
  throw new Error(`Unsupported const dtype: ${dtype}`);
156
173
  }
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
163
180
  function pipelineSource(device, kernel) {
164
181
  const tune = require_backend.tuneWebgpu(kernel);
165
182
  if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
166
- const { nargs } = kernel;
183
+ const { nargs, reduction: re } = kernel;
167
184
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
168
185
  const shader = [];
169
186
  let indent = "";
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
174
191
  else if (line === popIndent) indent = indent.slice(0, -2);
175
192
  else shader.push(line ? indent + line : line);
176
193
  };
194
+ if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || re?.epilogue.some((exp) => exp.dtype === require_backend.DType.Float16)) {
195
+ if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
196
+ emit("enable f16;");
197
+ }
177
198
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
178
- if (tune.exp.collect((exp) => exp.op === require_backend.AluOp.Threefry2x32).length > 0) emit(threefrySrc);
199
+ const distinctOps = require_backend.union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
200
+ if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
179
201
  emit("");
180
202
  const usedArgs = Array.from({ length: nargs }, () => null);
181
203
  tune.exp.fold((exp) => {
182
- if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
204
+ if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
183
205
  });
184
206
  for (let i = 0; i < nargs; i++) {
185
207
  const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
@@ -226,7 +248,7 @@ function pipelineSource(device, kernel) {
226
248
  else if (op === require_backend.AluOp.Sub) source = `(${a} - ${b})`;
227
249
  else if (op === require_backend.AluOp.Mul) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
228
250
  else source = `(${a} * ${b})`;
229
- else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `floor(${a} / ${b})` : `(${a} / ${b})`;
251
+ else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
230
252
  else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
231
253
  else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
232
254
  else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
@@ -238,6 +260,7 @@ function pipelineSource(device, kernel) {
238
260
  else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
239
261
  else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
240
262
  else if (op === require_backend.AluOp.Log) source = `log(${a})`;
263
+ else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
241
264
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
242
265
  else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
243
266
  else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
@@ -249,15 +272,15 @@ function pipelineSource(device, kernel) {
249
272
  if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
250
273
  else if (arg === 0) source = `${x}.x`;
251
274
  else if (arg === 1) source = `${x}.y`;
252
- else throw new Error("Invalid Threefry2x32 mode: " + arg);
275
+ else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
253
276
  } else if (op === require_backend.AluOp.Const) return constToWgsl(dtype, arg);
254
277
  else if (op === require_backend.AluOp.Special) return arg[0];
255
278
  else if (op === require_backend.AluOp.Variable) return arg;
256
279
  else if (op === require_backend.AluOp.GlobalIndex) {
257
- source = `${args[arg]}[${require_backend.strip1(gen(src[0]))}]`;
280
+ source = `${args[arg[0]]}[${require_backend.strip1(gen(src[0]))}]`;
258
281
  if (dtype === require_backend.DType.Bool) source = `(${source} != 0)`;
259
282
  }
260
- if (!source) throw new Error(`Missing impl for op: ${op}`);
283
+ if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
261
284
  const typeName = dtypeToWgsl(dtype);
262
285
  if ((references.get(exp) ?? 0) > 1) {
263
286
  const name = gensym();
@@ -269,13 +292,12 @@ function pipelineSource(device, kernel) {
269
292
  return source;
270
293
  }
271
294
  };
272
- if (!kernel.reduction) {
295
+ if (!re) {
273
296
  countReferences(tune.exp);
274
297
  let rhs = require_backend.strip1(gen(tune.exp));
275
298
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
276
299
  emit(`result[gidx] = ${rhs};`);
277
300
  } else {
278
- const re = kernel.reduction;
279
301
  if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
280
302
  const unroll = tune.size.unroll ?? 1;
281
303
  const upcast = tune.size.upcast ?? 1;
@@ -319,7 +341,7 @@ function pipelineSource(device, kernel) {
319
341
  const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
320
342
  outputIdxExps.push(exp.simplify(cache));
321
343
  countReferences(outputIdxExps[i]);
322
- fusionExps.push(re.fusion.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
344
+ fusionExps.push(re.epilogue.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
323
345
  countReferences(fusionExps[i]);
324
346
  }
325
347
  for (let i = 0; i < upcast; i++) {
@@ -487,13 +509,12 @@ var SyncReader = class SyncReader {
487
509
  }
488
510
  read(buffer, start, count) {
489
511
  if (!this.initialized) this.#init();
490
- if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
491
512
  const deviceStorage = this.deviceStorage;
492
513
  const deviceContexts = this.deviceContexts;
493
514
  const hostContext = this.hostContext;
494
- const pixelsSize = count / 4;
515
+ const pixelsSize = Math.ceil(count / 4);
495
516
  const bytesPerRow = SyncReader.width * 4;
496
- const valsGPU = new ArrayBuffer(count);
517
+ const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
497
518
  for (let i = 0; i < deviceContexts.length; i++) {
498
519
  const texture = deviceContexts[i].getCurrentTexture();
499
520
  const readData = (width, height, offset$1) => {
@@ -537,7 +558,7 @@ var SyncReader = class SyncReader {
537
558
  }
538
559
  if (remainder > 0) readData(remainder, 1, offset);
539
560
  }
540
- return valsGPU;
561
+ return new Uint8Array(valsGPU, 0, count);
541
562
  }
542
563
  };
543
564
  const threefrySrc = `
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.0.2",
3
+ "version": "0.0.3",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -43,14 +43,15 @@
43
43
  "eslint": "^9.31.0",
44
44
  "eslint-plugin-import": "^2.32.0",
45
45
  "globals": "^16.0.0",
46
- "playwright": "~1.50.1",
46
+ "playwright": "~1.52.0",
47
47
  "prettier": "^3.6.2",
48
48
  "prettier-plugin-svelte": "^3.4.0",
49
- "tsdown": "^0.13.0",
49
+ "tsdown": "^0.13.2",
50
50
  "tsx": "^4.20.3",
51
- "typedoc": "^0.28.7",
52
- "typescript": "~5.8.3",
53
- "typescript-eslint": "^8.38.0",
51
+ "typedoc": "^0.28.14",
52
+ "typedoc-theme-fresh": "^0.2.1",
53
+ "typescript": "~5.9.3",
54
+ "typescript-eslint": "^8.46.4",
54
55
  "vitest": "^3.2.4"
55
56
  },
56
57
  "engines": {