@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-BK21PBVP.cjs');
1
+ const require_backend = require('./backend-Ss1Mev_-.cjs');
2
2
 
3
3
  //#region src/backend/webgpu.ts
4
4
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
21
21
  }
22
22
  malloc(size, initialData) {
23
23
  let buffer;
24
+ const paddedSize = Math.ceil(size / 4) * 4;
24
25
  if (initialData) {
25
26
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
26
27
  if (initialData.byteLength < 4096) {
27
- buffer = this.#createBuffer(size, { mapped: true });
28
- new Uint8Array(buffer.getMappedRange()).set(new Uint8Array(initialData));
28
+ buffer = this.#createBuffer(paddedSize, { mapped: true });
29
+ new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
29
30
  buffer.unmap();
30
31
  } else {
31
- buffer = this.#createBuffer(size);
32
- this.device.queue.writeBuffer(buffer, 0, initialData);
32
+ buffer = this.#createBuffer(paddedSize);
33
+ if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
34
+ else {
35
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
36
+ this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
37
+ const remainder = new Uint8Array(4);
38
+ remainder.set(initialData.subarray(aligned));
39
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
40
+ }
33
41
  }
34
- } else buffer = this.#createBuffer(size);
42
+ } else buffer = this.#createBuffer(paddedSize);
35
43
  const slot = this.nextSlot++;
36
44
  this.buffers.set(slot, {
37
45
  buffer,
46
+ size,
38
47
  ref: 1
39
48
  });
40
49
  return slot;
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
54
63
  }
55
64
  }
56
65
  async read(slot, start, count) {
57
- const buffer = this.#getBuffer(slot);
66
+ const { buffer, size } = this.#getBuffer(slot);
58
67
  if (start === void 0) start = 0;
59
- if (count === void 0) count = buffer.size - start;
60
- const staging = this.#createBuffer(count, { read: true });
68
+ if (count === void 0) count = size - start;
69
+ const paddedSize = Math.ceil(count / 4) * 4;
70
+ const staging = this.#createBuffer(paddedSize, { read: true });
61
71
  try {
62
72
  const commandEncoder = this.device.createCommandEncoder();
63
- commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, count);
73
+ commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
64
74
  this.device.queue.submit([commandEncoder.finish()]);
65
75
  await staging.mapAsync(GPUMapMode.READ);
66
76
  const arrayBuffer = staging.getMappedRange();
67
- return arrayBuffer.slice();
77
+ return new Uint8Array(arrayBuffer.slice(), 0, count);
68
78
  } finally {
69
79
  staging.destroy();
70
80
  }
71
81
  }
72
82
  readSync(slot, start, count) {
73
- const buffer = this.#getBuffer(slot);
83
+ const { buffer, size } = this.#getBuffer(slot);
74
84
  if (start === void 0) start = 0;
75
- if (count === void 0) count = buffer.size - start;
85
+ if (count === void 0) count = size - start;
76
86
  return this.syncReader.read(buffer, start, count);
77
87
  }
78
88
  #cachedShader(kernel) {
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
103
113
  });
104
114
  }
105
115
  dispatch(exe, inputs, outputs) {
106
- const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
107
- const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
116
+ const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
117
+ const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
108
118
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
109
119
  }
110
120
  #getBuffer(slot) {
111
121
  const buffer = this.buffers.get(slot);
112
122
  if (!buffer) throw new require_backend.SlotError(slot);
113
- return buffer.buffer;
123
+ return {
124
+ buffer: buffer.buffer,
125
+ size: buffer.size
126
+ };
114
127
  }
115
128
  /**
116
129
  * Create a GPU buffer.
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
138
151
  case require_backend.DType.Int32: return "i32";
139
152
  case require_backend.DType.Uint32: return "u32";
140
153
  case require_backend.DType.Float32: return "f32";
154
+ case require_backend.DType.Float16: return "f16";
141
155
  default: throw new Error(`Unsupported dtype: ${dtype}`);
142
156
  }
143
157
  }
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
148
162
  if (dtype === require_backend.DType.Float32) {
149
163
  if (Number.isNaN(value)) return "nan()";
150
164
  if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
151
- let s = value.toString();
152
- if (!s.includes(".")) s += ".0";
153
- return s;
165
+ return "f32(" + value.toString() + ")";
166
+ }
167
+ if (dtype === require_backend.DType.Float16) {
168
+ if (Number.isNaN(value)) return "f16(nan())";
169
+ if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
170
+ return "f16(" + value.toString() + ")";
154
171
  }
155
172
  throw new Error(`Unsupported const dtype: ${dtype}`);
156
173
  }
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
163
180
  function pipelineSource(device, kernel) {
164
181
  const tune = require_backend.tuneWebgpu(kernel);
165
182
  if (require_backend.DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
166
- const { nargs } = kernel;
183
+ const { nargs, reduction: re } = kernel;
167
184
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
168
185
  const shader = [];
169
186
  let indent = "";
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
174
191
  else if (line === popIndent) indent = indent.slice(0, -2);
175
192
  else shader.push(line ? indent + line : line);
176
193
  };
194
+ if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || re?.epilogue.some((exp) => exp.dtype === require_backend.DType.Float16)) {
195
+ if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
196
+ emit("enable f16;");
197
+ }
177
198
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
178
- if (tune.exp.collect((exp) => exp.op === require_backend.AluOp.Threefry2x32).length > 0) emit(threefrySrc);
199
+ const distinctOps = require_backend.union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
200
+ if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
179
201
  emit("");
180
202
  const usedArgs = Array.from({ length: nargs }, () => null);
181
203
  tune.exp.fold((exp) => {
182
- if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
204
+ if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
183
205
  });
184
206
  for (let i = 0; i < nargs; i++) {
185
207
  const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
@@ -226,22 +248,29 @@ function pipelineSource(device, kernel) {
226
248
  else if (op === require_backend.AluOp.Sub) source = `(${a} - ${b})`;
227
249
  else if (op === require_backend.AluOp.Mul) if (dtype === require_backend.DType.Bool) source = `(${a} && ${b})`;
228
250
  else source = `(${a} * ${b})`;
229
- else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `floor(${a} / ${b})` : `(${a} / ${b})`;
251
+ else if (op === require_backend.AluOp.Idiv) source = require_backend.isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
230
252
  else if (op === require_backend.AluOp.Mod) source = `(${a} % ${b})`;
231
253
  else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
232
254
  else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
233
255
  else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
234
256
  else if (op === require_backend.AluOp.Cmpne) source = `(${a} != ${b})`;
235
- } else if (require_backend.AluGroup.Unary.has(op)) {
257
+ } else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
258
+ const a = gen(src[0].src[0]);
259
+ source = `inverseSqrt(${a})`;
260
+ } else {
236
261
  const a = gen(src[0]);
237
262
  if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
238
263
  else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
264
+ else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
265
+ else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
239
266
  else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
240
267
  else if (op === require_backend.AluOp.Log) source = `log(${a})`;
268
+ else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
241
269
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
242
270
  else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
243
271
  else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
244
- } else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
272
+ }
273
+ else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
245
274
  else if (op === require_backend.AluOp.Threefry2x32) {
246
275
  const x = gensym();
247
276
  const [k0, k1, c0, c1] = src.map((x$1) => require_backend.strip1(gen(x$1)));
@@ -249,15 +278,15 @@ function pipelineSource(device, kernel) {
249
278
  if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
250
279
  else if (arg === 0) source = `${x}.x`;
251
280
  else if (arg === 1) source = `${x}.y`;
252
- else throw new Error("Invalid Threefry2x32 mode: " + arg);
281
+ else throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
253
282
  } else if (op === require_backend.AluOp.Const) return constToWgsl(dtype, arg);
254
283
  else if (op === require_backend.AluOp.Special) return arg[0];
255
284
  else if (op === require_backend.AluOp.Variable) return arg;
256
285
  else if (op === require_backend.AluOp.GlobalIndex) {
257
- source = `${args[arg]}[${require_backend.strip1(gen(src[0]))}]`;
286
+ source = `${args[arg[0]]}[${require_backend.strip1(gen(src[0]))}]`;
258
287
  if (dtype === require_backend.DType.Bool) source = `(${source} != 0)`;
259
288
  }
260
- if (!source) throw new Error(`Missing impl for op: ${op}`);
289
+ if (!source) throw new require_backend.UnsupportedOpError(op, dtype, "webgpu", arg);
261
290
  const typeName = dtypeToWgsl(dtype);
262
291
  if ((references.get(exp) ?? 0) > 1) {
263
292
  const name = gensym();
@@ -269,13 +298,12 @@ function pipelineSource(device, kernel) {
269
298
  return source;
270
299
  }
271
300
  };
272
- if (!kernel.reduction) {
301
+ if (!re) {
273
302
  countReferences(tune.exp);
274
303
  let rhs = require_backend.strip1(gen(tune.exp));
275
304
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
276
305
  emit(`result[gidx] = ${rhs};`);
277
306
  } else {
278
- const re = kernel.reduction;
279
307
  if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
280
308
  const unroll = tune.size.unroll ?? 1;
281
309
  const upcast = tune.size.upcast ?? 1;
@@ -319,7 +347,7 @@ function pipelineSource(device, kernel) {
319
347
  const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
320
348
  outputIdxExps.push(exp.simplify(cache));
321
349
  countReferences(outputIdxExps[i]);
322
- fusionExps.push(re.fusion.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
350
+ fusionExps.push(re.epilogue.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
323
351
  countReferences(fusionExps[i]);
324
352
  }
325
353
  for (let i = 0; i < upcast; i++) {
@@ -487,13 +515,12 @@ var SyncReader = class SyncReader {
487
515
  }
488
516
  read(buffer, start, count) {
489
517
  if (!this.initialized) this.#init();
490
- if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
491
518
  const deviceStorage = this.deviceStorage;
492
519
  const deviceContexts = this.deviceContexts;
493
520
  const hostContext = this.hostContext;
494
- const pixelsSize = count / 4;
521
+ const pixelsSize = Math.ceil(count / 4);
495
522
  const bytesPerRow = SyncReader.width * 4;
496
- const valsGPU = new ArrayBuffer(count);
523
+ const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
497
524
  for (let i = 0; i < deviceContexts.length; i++) {
498
525
  const texture = deviceContexts[i].getCurrentTexture();
499
526
  const readData = (width, height, offset$1) => {
@@ -537,7 +564,7 @@ var SyncReader = class SyncReader {
537
564
  }
538
565
  if (remainder > 0) readData(remainder, 1, offset);
539
566
  }
540
- return valsGPU;
567
+ return new Uint8Array(valsGPU, 0, count);
541
568
  }
542
569
  };
543
570
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, findPow2, isFloatDtype, strip1, tuneWebgpu } from "./backend-1eVbAoaV.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-EBRGmEYw.js";
2
2
 
3
3
  //#region src/backend/webgpu.ts
4
4
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -21,20 +21,29 @@ var WebGPUBackend = class {
21
21
  }
22
22
  malloc(size, initialData) {
23
23
  let buffer;
24
+ const paddedSize = Math.ceil(size / 4) * 4;
24
25
  if (initialData) {
25
26
  if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
26
27
  if (initialData.byteLength < 4096) {
27
- buffer = this.#createBuffer(size, { mapped: true });
28
- new Uint8Array(buffer.getMappedRange()).set(new Uint8Array(initialData));
28
+ buffer = this.#createBuffer(paddedSize, { mapped: true });
29
+ new Uint8Array(buffer.getMappedRange(), 0, size).set(initialData);
29
30
  buffer.unmap();
30
31
  } else {
31
- buffer = this.#createBuffer(size);
32
- this.device.queue.writeBuffer(buffer, 0, initialData);
32
+ buffer = this.#createBuffer(paddedSize);
33
+ if (initialData.byteLength % 4 === 0) this.device.queue.writeBuffer(buffer, 0, initialData);
34
+ else {
35
+ const aligned = initialData.byteLength - initialData.byteLength % 4;
36
+ this.device.queue.writeBuffer(buffer, 0, initialData, 0, aligned);
37
+ const remainder = new Uint8Array(4);
38
+ remainder.set(initialData.subarray(aligned));
39
+ this.device.queue.writeBuffer(buffer, aligned, remainder);
40
+ }
33
41
  }
34
- } else buffer = this.#createBuffer(size);
42
+ } else buffer = this.#createBuffer(paddedSize);
35
43
  const slot = this.nextSlot++;
36
44
  this.buffers.set(slot, {
37
45
  buffer,
46
+ size,
38
47
  ref: 1
39
48
  });
40
49
  return slot;
@@ -54,25 +63,26 @@ var WebGPUBackend = class {
54
63
  }
55
64
  }
56
65
  async read(slot, start, count) {
57
- const buffer = this.#getBuffer(slot);
66
+ const { buffer, size } = this.#getBuffer(slot);
58
67
  if (start === void 0) start = 0;
59
- if (count === void 0) count = buffer.size - start;
60
- const staging = this.#createBuffer(count, { read: true });
68
+ if (count === void 0) count = size - start;
69
+ const paddedSize = Math.ceil(count / 4) * 4;
70
+ const staging = this.#createBuffer(paddedSize, { read: true });
61
71
  try {
62
72
  const commandEncoder = this.device.createCommandEncoder();
63
- commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, count);
73
+ commandEncoder.copyBufferToBuffer(buffer, start, staging, 0, paddedSize);
64
74
  this.device.queue.submit([commandEncoder.finish()]);
65
75
  await staging.mapAsync(GPUMapMode.READ);
66
76
  const arrayBuffer = staging.getMappedRange();
67
- return arrayBuffer.slice();
77
+ return new Uint8Array(arrayBuffer.slice(), 0, count);
68
78
  } finally {
69
79
  staging.destroy();
70
80
  }
71
81
  }
72
82
  readSync(slot, start, count) {
73
- const buffer = this.#getBuffer(slot);
83
+ const { buffer, size } = this.#getBuffer(slot);
74
84
  if (start === void 0) start = 0;
75
- if (count === void 0) count = buffer.size - start;
85
+ if (count === void 0) count = size - start;
76
86
  return this.syncReader.read(buffer, start, count);
77
87
  }
78
88
  #cachedShader(kernel) {
@@ -103,14 +113,17 @@ var WebGPUBackend = class {
103
113
  });
104
114
  }
105
115
  dispatch(exe, inputs, outputs) {
106
- const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
107
- const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
116
+ const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
117
+ const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
108
118
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
109
119
  }
110
120
  #getBuffer(slot) {
111
121
  const buffer = this.buffers.get(slot);
112
122
  if (!buffer) throw new SlotError(slot);
113
- return buffer.buffer;
123
+ return {
124
+ buffer: buffer.buffer,
125
+ size: buffer.size
126
+ };
114
127
  }
115
128
  /**
116
129
  * Create a GPU buffer.
@@ -138,6 +151,7 @@ function dtypeToWgsl(dtype, storage = false) {
138
151
  case DType.Int32: return "i32";
139
152
  case DType.Uint32: return "u32";
140
153
  case DType.Float32: return "f32";
154
+ case DType.Float16: return "f16";
141
155
  default: throw new Error(`Unsupported dtype: ${dtype}`);
142
156
  }
143
157
  }
@@ -148,9 +162,12 @@ function constToWgsl(dtype, value) {
148
162
  if (dtype === DType.Float32) {
149
163
  if (Number.isNaN(value)) return "nan()";
150
164
  if (!Number.isFinite(value)) return value > 0 ? "inf()" : "-inf()";
151
- let s = value.toString();
152
- if (!s.includes(".")) s += ".0";
153
- return s;
165
+ return "f32(" + value.toString() + ")";
166
+ }
167
+ if (dtype === DType.Float16) {
168
+ if (Number.isNaN(value)) return "f16(nan())";
169
+ if (!Number.isFinite(value)) return value > 0 ? "f16(inf())" : "f16(-inf())";
170
+ return "f16(" + value.toString() + ")";
154
171
  }
155
172
  throw new Error(`Unsupported const dtype: ${dtype}`);
156
173
  }
@@ -163,7 +180,7 @@ function constToWgsl(dtype, value) {
163
180
  function pipelineSource(device, kernel) {
164
181
  const tune = tuneWebgpu(kernel);
165
182
  if (DEBUG >= 3) console.info(`kernel.exp: ${kernel.exp}\ntune.exp: ${tune.exp}`);
166
- const { nargs } = kernel;
183
+ const { nargs, reduction: re } = kernel;
167
184
  const args = Array.from({ length: nargs }, (_, i) => `in${i}`);
168
185
  const shader = [];
169
186
  let indent = "";
@@ -174,12 +191,17 @@ function pipelineSource(device, kernel) {
174
191
  else if (line === popIndent) indent = indent.slice(0, -2);
175
192
  else shader.push(line ? indent + line : line);
176
193
  };
194
+ if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
195
+ if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
196
+ emit("enable f16;");
197
+ }
177
198
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
178
- if (tune.exp.collect((exp) => exp.op === AluOp.Threefry2x32).length > 0) emit(threefrySrc);
199
+ const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
200
+ if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
179
201
  emit("");
180
202
  const usedArgs = Array.from({ length: nargs }, () => null);
181
203
  tune.exp.fold((exp) => {
182
- if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg] = exp.dtype;
204
+ if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
183
205
  });
184
206
  for (let i = 0; i < nargs; i++) {
185
207
  const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
@@ -226,22 +248,29 @@ function pipelineSource(device, kernel) {
226
248
  else if (op === AluOp.Sub) source = `(${a} - ${b})`;
227
249
  else if (op === AluOp.Mul) if (dtype === DType.Bool) source = `(${a} && ${b})`;
228
250
  else source = `(${a} * ${b})`;
229
- else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `floor(${a} / ${b})` : `(${a} / ${b})`;
251
+ else if (op === AluOp.Idiv) source = isFloatDtype(dtype) ? `trunc(${a} / ${b})` : `(${a} / ${b})`;
230
252
  else if (op === AluOp.Mod) source = `(${a} % ${b})`;
231
253
  else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
232
254
  else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
233
255
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
234
256
  else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
235
- } else if (AluGroup.Unary.has(op)) {
257
+ } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
258
+ const a = gen(src[0].src[0]);
259
+ source = `inverseSqrt(${a})`;
260
+ } else {
236
261
  const a = gen(src[0]);
237
262
  if (op === AluOp.Sin) source = `sin(${a})`;
238
263
  else if (op === AluOp.Cos) source = `cos(${a})`;
264
+ else if (op === AluOp.Asin) source = `asin(${a})`;
265
+ else if (op === AluOp.Atan) source = `atan(${a})`;
239
266
  else if (op === AluOp.Exp) source = `exp(${a})`;
240
267
  else if (op === AluOp.Log) source = `log(${a})`;
268
+ else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
241
269
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
242
270
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
243
271
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
244
- } else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
272
+ }
273
+ else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
245
274
  else if (op === AluOp.Threefry2x32) {
246
275
  const x = gensym();
247
276
  const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
@@ -249,15 +278,15 @@ function pipelineSource(device, kernel) {
249
278
  if (arg === "xor") source = `(${x}.x ^ ${x}.y)`;
250
279
  else if (arg === 0) source = `${x}.x`;
251
280
  else if (arg === 1) source = `${x}.y`;
252
- else throw new Error("Invalid Threefry2x32 mode: " + arg);
281
+ else throw new UnsupportedOpError(op, dtype, "webgpu", arg);
253
282
  } else if (op === AluOp.Const) return constToWgsl(dtype, arg);
254
283
  else if (op === AluOp.Special) return arg[0];
255
284
  else if (op === AluOp.Variable) return arg;
256
285
  else if (op === AluOp.GlobalIndex) {
257
- source = `${args[arg]}[${strip1(gen(src[0]))}]`;
286
+ source = `${args[arg[0]]}[${strip1(gen(src[0]))}]`;
258
287
  if (dtype === DType.Bool) source = `(${source} != 0)`;
259
288
  }
260
- if (!source) throw new Error(`Missing impl for op: ${op}`);
289
+ if (!source) throw new UnsupportedOpError(op, dtype, "webgpu", arg);
261
290
  const typeName = dtypeToWgsl(dtype);
262
291
  if ((references.get(exp) ?? 0) > 1) {
263
292
  const name = gensym();
@@ -269,13 +298,12 @@ function pipelineSource(device, kernel) {
269
298
  return source;
270
299
  }
271
300
  };
272
- if (!kernel.reduction) {
301
+ if (!re) {
273
302
  countReferences(tune.exp);
274
303
  let rhs = strip1(gen(tune.exp));
275
304
  if (resultTy !== dtypeToWgsl(tune.exp.dtype)) rhs = `${resultTy}(${rhs})`;
276
305
  emit(`result[gidx] = ${rhs};`);
277
306
  } else {
278
- const re = kernel.reduction;
279
307
  if ((tune.size.groups ?? 1) > 1) throw new Error("WebGPU backend does not support group optimization yet");
280
308
  const unroll = tune.size.unroll ?? 1;
281
309
  const upcast = tune.size.upcast ?? 1;
@@ -319,7 +347,7 @@ function pipelineSource(device, kernel) {
319
347
  const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
320
348
  outputIdxExps.push(exp.simplify(cache));
321
349
  countReferences(outputIdxExps[i]);
322
- fusionExps.push(re.fusion.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
350
+ fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
323
351
  countReferences(fusionExps[i]);
324
352
  }
325
353
  for (let i = 0; i < upcast; i++) {
@@ -487,13 +515,12 @@ var SyncReader = class SyncReader {
487
515
  }
488
516
  read(buffer, start, count) {
489
517
  if (!this.initialized) this.#init();
490
- if (count % 4 !== 0) throw new Error("Read size must be a multiple of 4 bytes");
491
518
  const deviceStorage = this.deviceStorage;
492
519
  const deviceContexts = this.deviceContexts;
493
520
  const hostContext = this.hostContext;
494
- const pixelsSize = count / 4;
521
+ const pixelsSize = Math.ceil(count / 4);
495
522
  const bytesPerRow = SyncReader.width * 4;
496
- const valsGPU = new ArrayBuffer(count);
523
+ const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
497
524
  for (let i = 0; i < deviceContexts.length; i++) {
498
525
  const texture = deviceContexts[i].getCurrentTexture();
499
526
  const readData = (width, height, offset$1) => {
@@ -537,7 +564,7 @@ var SyncReader = class SyncReader {
537
564
  }
538
565
  if (remainder > 0) readData(remainder, 1, offset);
539
566
  }
540
- return valsGPU;
567
+ return new Uint8Array(valsGPU, 0, count);
541
568
  }
542
569
  };
543
570
  const threefrySrc = `
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.0.2",
3
+ "version": "0.0.4",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -38,20 +38,21 @@
38
38
  "devDependencies": {
39
39
  "@eslint/js": "^9.31.0",
40
40
  "@types/debug": "^4.1.12",
41
- "@vitest/browser": "^3.2.4",
41
+ "@vitest/browser-playwright": "^4.0.9",
42
42
  "@webgpu/types": "^0.1.64",
43
43
  "eslint": "^9.31.0",
44
44
  "eslint-plugin-import": "^2.32.0",
45
45
  "globals": "^16.0.0",
46
- "playwright": "~1.50.1",
46
+ "playwright": "~1.52.0",
47
47
  "prettier": "^3.6.2",
48
48
  "prettier-plugin-svelte": "^3.4.0",
49
- "tsdown": "^0.13.0",
49
+ "tsdown": "^0.13.2",
50
50
  "tsx": "^4.20.3",
51
- "typedoc": "^0.28.7",
52
- "typescript": "~5.8.3",
53
- "typescript-eslint": "^8.38.0",
54
- "vitest": "^3.2.4"
51
+ "typedoc": "^0.28.14",
52
+ "typedoc-theme-fresh": "^0.2.1",
53
+ "typescript": "~5.9.3",
54
+ "typescript-eslint": "^8.46.4",
55
+ "vitest": "^4.0.9"
55
56
  },
56
57
  "engines": {
57
58
  "pnpm": ">=10.0.0"
@@ -59,7 +60,18 @@
59
60
  "prettier": {
60
61
  "plugins": [
61
62
  "prettier-plugin-svelte"
62
- ]
63
+ ],
64
+ "overrides": [
65
+ {
66
+ "files": [
67
+ "*.md"
68
+ ],
69
+ "options": {
70
+ "printWidth": 100
71
+ }
72
+ }
73
+ ],
74
+ "proseWrap": "always"
63
75
  },
64
76
  "scripts": {
65
77
  "build": "tsdown",