@jax-js/jax 0.1.2 → 0.1.3

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the contents of those package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DeVfWEFS.cjs');
1
+ const require_backend = require('./backend-CmaidnkQ.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
284
284
  });
285
285
  }
286
286
  dispatch(exe, inputs, outputs) {
287
+ if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
288
+ if (exe.kernel.size === 0) return;
287
289
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
288
290
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
289
291
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
362
364
  else if (line === popIndent) indent = indent.slice(0, -2);
363
365
  else shader.push(line ? indent + line : line);
364
366
  };
365
- if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || re?.epilogue.some((exp) => exp.dtype === require_backend.DType.Float16)) {
367
+ if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === require_backend.DType.Float16)) {
366
368
  if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
367
369
  emit("enable f16;");
368
370
  }
369
371
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
370
- const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
372
+ const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
371
373
  if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
372
374
  if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
373
375
  emit("");
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
375
377
  tune.exp.fold((exp) => {
376
378
  if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
377
379
  });
380
+ tune.epilogue?.fold((exp) => {
381
+ if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
382
+ });
378
383
  for (let i = 0; i < nargs; i++) {
379
384
  const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
380
385
  emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
398
403
  let gensymCount = 0;
399
404
  const gensym = () => `alu${gensymCount++}`;
400
405
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
401
- for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
406
+ if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
402
407
  const references = /* @__PURE__ */ new Map();
403
408
  const seen = /* @__PURE__ */ new Set();
404
409
  const countReferences = (exp) => {
@@ -530,7 +535,10 @@ function pipelineSource(device, kernel) {
530
535
  const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
531
536
  outputIdxExps.push(exp.simplify(cache));
532
537
  countReferences(outputIdxExps[i]);
533
- fusionExps.push(re.epilogue.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
538
+ fusionExps.push(tune.epilogue.substitute({
539
+ acc: require_backend.AluExp.variable(re.dtype, acc[i]),
540
+ upcast: require_backend.AluExp.i32(i)
541
+ }).simplify(cache));
534
542
  countReferences(fusionExps[i]);
535
543
  }
536
544
  for (let i = 0; i < upcast; i++) {
@@ -652,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
652
660
 
653
661
  //#endregion
654
662
  exports.WebGPUBackend = WebGPUBackend;
655
- //# sourceMappingURL=webgpu-CcGP160M.cjs.map
663
+ //# sourceMappingURL=webgpu-BVns4DbI.cjs.map
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BqymqzuU.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BY8wlLEl.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
284
284
  });
285
285
  }
286
286
  dispatch(exe, inputs, outputs) {
287
+ if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
288
+ if (exe.kernel.size === 0) return;
287
289
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
288
290
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
289
291
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
362
364
  else if (line === popIndent) indent = indent.slice(0, -2);
363
365
  else shader.push(line ? indent + line : line);
364
366
  };
365
- if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
367
+ if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
366
368
  if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
367
369
  emit("enable f16;");
368
370
  }
369
371
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
370
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
372
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
371
373
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
372
374
  if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
373
375
  emit("");
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
375
377
  tune.exp.fold((exp) => {
376
378
  if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
377
379
  });
380
+ tune.epilogue?.fold((exp) => {
381
+ if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
382
+ });
378
383
  for (let i = 0; i < nargs; i++) {
379
384
  const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
380
385
  emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
398
403
  let gensymCount = 0;
399
404
  const gensym = () => `alu${gensymCount++}`;
400
405
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
401
- for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
406
+ if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
402
407
  const references = /* @__PURE__ */ new Map();
403
408
  const seen = /* @__PURE__ */ new Set();
404
409
  const countReferences = (exp) => {
@@ -530,7 +535,10 @@ function pipelineSource(device, kernel) {
530
535
  const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
531
536
  outputIdxExps.push(exp.simplify(cache));
532
537
  countReferences(outputIdxExps[i]);
533
- fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
538
+ fusionExps.push(tune.epilogue.substitute({
539
+ acc: AluExp.variable(re.dtype, acc[i]),
540
+ upcast: AluExp.i32(i)
541
+ }).simplify(cache));
534
542
  countReferences(fusionExps[i]);
535
543
  }
536
544
  for (let i = 0; i < upcast; i++) {
@@ -652,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
652
660
 
653
661
  //#endregion
654
662
  export { WebGPUBackend };
655
- //# sourceMappingURL=webgpu-BGuG58KZ.js.map
663
+ //# sourceMappingURL=webgpu-C9iAP5h5.js.map
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.2",
3
+ "version": "0.1.3",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",