@jax-js/jax 0.1.1 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-BbrKEB18.cjs');
1
+ const require_backend = require('./backend-CmaidnkQ.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
284
284
  });
285
285
  }
286
286
  dispatch(exe, inputs, outputs) {
287
+ if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
288
+ if (exe.kernel.size === 0) return;
287
289
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
288
290
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
289
291
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
362
364
  else if (line === popIndent) indent = indent.slice(0, -2);
363
365
  else shader.push(line ? indent + line : line);
364
366
  };
365
- if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || re?.epilogue.some((exp) => exp.dtype === require_backend.DType.Float16)) {
367
+ if (tune.exp.some((exp) => exp.dtype === require_backend.DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === require_backend.DType.Float16)) {
366
368
  if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
367
369
  emit("enable f16;");
368
370
  }
369
371
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
370
- const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
372
+ const distinctOps = require_backend.mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
371
373
  if (distinctOps.has(require_backend.AluOp.Threefry2x32)) emit(threefrySrc);
372
374
  if (distinctOps.has(require_backend.AluOp.Erf) || distinctOps.has(require_backend.AluOp.Erfc)) emit(erfSrc);
373
375
  emit("");
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
375
377
  tune.exp.fold((exp) => {
376
378
  if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
377
379
  });
380
+ tune.epilogue?.fold((exp) => {
381
+ if (exp.op === require_backend.AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
382
+ });
378
383
  for (let i = 0; i < nargs; i++) {
379
384
  const ty = dtypeToWgsl(usedArgs[i] ?? require_backend.DType.Float32, true);
380
385
  emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
398
403
  let gensymCount = 0;
399
404
  const gensym = () => `alu${gensymCount++}`;
400
405
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
401
- for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
406
+ if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
402
407
  const references = /* @__PURE__ */ new Map();
403
408
  const seen = /* @__PURE__ */ new Set();
404
409
  const countReferences = (exp) => {
@@ -436,18 +441,20 @@ function pipelineSource(device, kernel) {
436
441
  source = `inverseSqrt(${a})`;
437
442
  } else {
438
443
  const a = gen(src[0]);
439
- if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
440
- else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
441
- else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
442
- else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
443
- else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
444
- else if (op === require_backend.AluOp.Log) source = `log(${a})`;
444
+ if (op === require_backend.AluOp.Sin) source = `sin(${require_backend.strip1(a)})`;
445
+ else if (op === require_backend.AluOp.Cos) source = `cos(${require_backend.strip1(a)})`;
446
+ else if (op === require_backend.AluOp.Asin) source = `asin(${require_backend.strip1(a)})`;
447
+ else if (op === require_backend.AluOp.Atan) source = `atan(${require_backend.strip1(a)})`;
448
+ else if (op === require_backend.AluOp.Exp) source = `exp(${require_backend.strip1(a)})`;
449
+ else if (op === require_backend.AluOp.Log) source = `log(${require_backend.strip1(a)})`;
445
450
  else if (op === require_backend.AluOp.Erf || op === require_backend.AluOp.Erfc) {
446
451
  const funcName = op === require_backend.AluOp.Erf ? "erf" : "erfc";
447
- if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
448
- else source = `${funcName}(${a})`;
449
- } else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
452
+ if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
453
+ else source = `${funcName}(${require_backend.strip1(a)})`;
454
+ } else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${require_backend.strip1(a)})`;
450
455
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
456
+ else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
457
+ else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
451
458
  else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
452
459
  else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
453
460
  }
@@ -528,7 +535,10 @@ function pipelineSource(device, kernel) {
528
535
  const exp = tune.outputIdxExp.substitute({ upcast: require_backend.AluExp.i32(i) });
529
536
  outputIdxExps.push(exp.simplify(cache));
530
537
  countReferences(outputIdxExps[i]);
531
- fusionExps.push(re.epilogue.substitute({ acc: require_backend.AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
538
+ fusionExps.push(tune.epilogue.substitute({
539
+ acc: require_backend.AluExp.variable(re.dtype, acc[i]),
540
+ upcast: require_backend.AluExp.i32(i)
541
+ }).simplify(cache));
532
542
  countReferences(fusionExps[i]);
533
543
  }
534
544
  for (let i = 0; i < upcast; i++) {
@@ -650,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
650
660
 
651
661
  //#endregion
652
662
  exports.WebGPUBackend = WebGPUBackend;
653
- //# sourceMappingURL=webgpu-DGYNVHma.cjs.map
663
+ //# sourceMappingURL=webgpu-BVns4DbI.cjs.map
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-CoVtc9dx.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BY8wlLEl.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -284,6 +284,8 @@ var WebGPUBackend = class {
284
284
  });
285
285
  }
286
286
  dispatch(exe, inputs, outputs) {
287
+ if (inputs.length !== exe.kernel.nargs) throw new Error(`webgpu: dispatch with ${inputs.length} inputs, expected ${exe.kernel.nargs}`);
288
+ if (exe.kernel.size === 0) return;
287
289
  const inputBuffers = inputs.map((slot) => this.#getBuffer(slot).buffer);
288
290
  const outputBuffers = outputs.map((slot) => this.#getBuffer(slot).buffer);
289
291
  pipelineSubmit(this.device, exe.data, inputBuffers, outputBuffers);
@@ -362,12 +364,12 @@ function pipelineSource(device, kernel) {
362
364
  else if (line === popIndent) indent = indent.slice(0, -2);
363
365
  else shader.push(line ? indent + line : line);
364
366
  };
365
- if (tune.exp.some((exp) => exp.dtype === DType.Float16) || re?.epilogue.some((exp) => exp.dtype === DType.Float16)) {
367
+ if (tune.exp.some((exp) => exp.dtype === DType.Float16) || tune.epilogue?.some((exp) => exp.dtype === DType.Float16)) {
366
368
  if (!device.features.has("shader-f16")) throw new Error("WebGPU device does not support shader-f16 feature");
367
369
  emit("enable f16;");
368
370
  }
369
371
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
370
- const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
372
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), tune.epilogue?.distinctOps());
371
373
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
372
374
  if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
373
375
  emit("");
@@ -375,6 +377,9 @@ function pipelineSource(device, kernel) {
375
377
  tune.exp.fold((exp) => {
376
378
  if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
377
379
  });
380
+ tune.epilogue?.fold((exp) => {
381
+ if (exp.op === AluOp.GlobalIndex) usedArgs[exp.arg[0]] = exp.dtype;
382
+ });
378
383
  for (let i = 0; i < nargs; i++) {
379
384
  const ty = dtypeToWgsl(usedArgs[i] ?? DType.Float32, true);
380
385
  emit(`@group(0) @binding(${i}) var<storage, read> ${args[i]} : array<${ty}>;`);
@@ -398,7 +403,7 @@ function pipelineSource(device, kernel) {
398
403
  let gensymCount = 0;
399
404
  const gensym = () => `alu${gensymCount++}`;
400
405
  const isGensym = (text) => text.match(/^alu[0-9]+$/);
401
- for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
406
+ if (args.length > 0) emit(args.map((arg) => `_ = &${arg};`).join(" "));
402
407
  const references = /* @__PURE__ */ new Map();
403
408
  const seen = /* @__PURE__ */ new Set();
404
409
  const countReferences = (exp) => {
@@ -436,18 +441,20 @@ function pipelineSource(device, kernel) {
436
441
  source = `inverseSqrt(${a})`;
437
442
  } else {
438
443
  const a = gen(src[0]);
439
- if (op === AluOp.Sin) source = `sin(${a})`;
440
- else if (op === AluOp.Cos) source = `cos(${a})`;
441
- else if (op === AluOp.Asin) source = `asin(${a})`;
442
- else if (op === AluOp.Atan) source = `atan(${a})`;
443
- else if (op === AluOp.Exp) source = `exp(${a})`;
444
- else if (op === AluOp.Log) source = `log(${a})`;
444
+ if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
445
+ else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
446
+ else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
447
+ else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
448
+ else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
449
+ else if (op === AluOp.Log) source = `log(${strip1(a)})`;
445
450
  else if (op === AluOp.Erf || op === AluOp.Erfc) {
446
451
  const funcName = op === AluOp.Erf ? "erf" : "erfc";
447
- if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
448
- else source = `${funcName}(${a})`;
449
- } else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
452
+ if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
453
+ else source = `${funcName}(${strip1(a)})`;
454
+ } else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
450
455
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
456
+ else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
457
+ else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
451
458
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
452
459
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
453
460
  }
@@ -528,7 +535,10 @@ function pipelineSource(device, kernel) {
528
535
  const exp = tune.outputIdxExp.substitute({ upcast: AluExp.i32(i) });
529
536
  outputIdxExps.push(exp.simplify(cache));
530
537
  countReferences(outputIdxExps[i]);
531
- fusionExps.push(re.epilogue.substitute({ acc: AluExp.variable(re.dtype, acc[i]) }).simplify(cache));
538
+ fusionExps.push(tune.epilogue.substitute({
539
+ acc: AluExp.variable(re.dtype, acc[i]),
540
+ upcast: AluExp.i32(i)
541
+ }).simplify(cache));
532
542
  countReferences(fusionExps[i]);
533
543
  }
534
544
  for (let i = 0; i < upcast; i++) {
@@ -650,4 +660,4 @@ async function compileError(shaderModule, scope, code) {
650
660
 
651
661
  //#endregion
652
662
  export { WebGPUBackend };
653
- //# sourceMappingURL=webgpu-B3UVme6n.js.map
663
+ //# sourceMappingURL=webgpu-C9iAP5h5.js.map
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.1",
3
+ "version": "0.1.3",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -35,19 +35,6 @@
35
35
  "types": "dist/index.d.ts",
36
36
  "module": "dist/index.js",
37
37
  "license": "MIT",
38
- "scripts": {
39
- "build": "tsdown",
40
- "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
41
- "check": "tsc",
42
- "docs": "tsx typedoc.ts",
43
- "format": "prettier --write .",
44
- "format:check": "prettier --check .",
45
- "lint": "eslint",
46
- "test": "vitest",
47
- "test:coverage": "vitest run --coverage && open coverage/index.html",
48
- "prepublishOnly": "pnpm build",
49
- "postpublish": "git tag jax/v$npm_package_version && git push --tags"
50
- },
51
38
  "devDependencies": {
52
39
  "@eslint/js": "^9.31.0",
53
40
  "@types/debug": "^4.1.12",
@@ -68,15 +55,9 @@
68
55
  "typescript-eslint": "^8.46.4",
69
56
  "vitest": "^4.0.9"
70
57
  },
71
- "packageManager": "pnpm@10.22.0",
72
58
  "engines": {
73
59
  "pnpm": ">=10.0.0"
74
60
  },
75
- "pnpm": {
76
- "overrides": {
77
- "@tensorflow/tfjs-core>@webgpu/types": "^0.1.68"
78
- }
79
- },
80
61
  "prettier": {
81
62
  "plugins": [
82
63
  "prettier-plugin-svelte"
@@ -92,5 +73,16 @@
92
73
  }
93
74
  ],
94
75
  "proseWrap": "always"
76
+ },
77
+ "scripts": {
78
+ "build": "tsdown",
79
+ "build:watch": "TSDOWN_WATCH_MODE=1 tsdown",
80
+ "check": "tsc",
81
+ "docs": "tsx typedoc.ts",
82
+ "format": "prettier --write .",
83
+ "format:check": "prettier --check .",
84
+ "lint": "eslint",
85
+ "test": "vitest",
86
+ "test:coverage": "vitest run --coverage && open coverage/index.html"
95
87
  }
96
- }
88
+ }