@jax-js/jax 0.1.0 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-DwIAd0AG.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-BqymqzuU.js";
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -323,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
323
323
  case DType.Uint32: return "u32";
324
324
  case DType.Float32: return "f32";
325
325
  case DType.Float16: return "f16";
326
- default: throw new Error(`Unsupported dtype: ${dtype}`);
326
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
327
327
  }
328
328
  }
329
329
  function constToWgsl(dtype, value) {
@@ -397,6 +397,7 @@ function pipelineSource(device, kernel) {
397
397
  }
398
398
  let gensymCount = 0;
399
399
  const gensym = () => `alu${gensymCount++}`;
400
+ const isGensym = (text) => text.match(/^alu[0-9]+$/);
400
401
  for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
401
402
  const references = /* @__PURE__ */ new Map();
402
403
  const seen = /* @__PURE__ */ new Set();
@@ -425,24 +426,30 @@ function pipelineSource(device, kernel) {
425
426
  else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
426
427
  else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
427
428
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
428
- else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
429
+ else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
430
+ const x = isGensym(a) ? a : gensym();
431
+ if (x !== a) emit(`let ${x} = ${a};`);
432
+ source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
433
+ } else source = `(${a} != ${b})`;
429
434
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
430
435
  const a = gen(src[0].src[0]);
431
436
  source = `inverseSqrt(${a})`;
432
437
  } else {
433
438
  const a = gen(src[0]);
434
- if (op === AluOp.Sin) source = `sin(${a})`;
435
- else if (op === AluOp.Cos) source = `cos(${a})`;
436
- else if (op === AluOp.Asin) source = `asin(${a})`;
437
- else if (op === AluOp.Atan) source = `atan(${a})`;
438
- else if (op === AluOp.Exp) source = `exp(${a})`;
439
- else if (op === AluOp.Log) source = `log(${a})`;
439
+ if (op === AluOp.Sin) source = `sin(${strip1(a)})`;
440
+ else if (op === AluOp.Cos) source = `cos(${strip1(a)})`;
441
+ else if (op === AluOp.Asin) source = `asin(${strip1(a)})`;
442
+ else if (op === AluOp.Atan) source = `atan(${strip1(a)})`;
443
+ else if (op === AluOp.Exp) source = `exp(${strip1(a)})`;
444
+ else if (op === AluOp.Log) source = `log(${strip1(a)})`;
440
445
  else if (op === AluOp.Erf || op === AluOp.Erfc) {
441
446
  const funcName = op === AluOp.Erf ? "erf" : "erfc";
442
- if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
443
- else source = `${funcName}(${a})`;
444
- } else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
447
+ if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${strip1(a)})))`;
448
+ else source = `${funcName}(${strip1(a)})`;
449
+ } else if (op === AluOp.Sqrt) source = `sqrt(${strip1(a)})`;
445
450
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
451
+ else if (op === AluOp.Floor) source = `floor(${strip1(a)})`;
452
+ else if (op === AluOp.Ceil) source = `ceil(${strip1(a)})`;
446
453
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
447
454
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
448
455
  }
@@ -645,4 +652,4 @@ async function compileError(shaderModule, scope, code) {
645
652
 
646
653
  //#endregion
647
654
  export { WebGPUBackend };
648
- //# sourceMappingURL=webgpu-LGi2A3mS.js.map
655
+ //# sourceMappingURL=webgpu-BGuG58KZ.js.map
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-FtkbO6pI.cjs');
1
+ const require_backend = require('./backend-DeVfWEFS.cjs');
2
2
 
3
3
  //#region src/backend/webgpu/builtins.ts
4
4
  const threefrySrc = `
@@ -323,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
323
323
  case require_backend.DType.Uint32: return "u32";
324
324
  case require_backend.DType.Float32: return "f32";
325
325
  case require_backend.DType.Float16: return "f16";
326
- default: throw new Error(`Unsupported dtype: ${dtype}`);
326
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
327
327
  }
328
328
  }
329
329
  function constToWgsl(dtype, value) {
@@ -397,6 +397,7 @@ function pipelineSource(device, kernel) {
397
397
  }
398
398
  let gensymCount = 0;
399
399
  const gensym = () => `alu${gensymCount++}`;
400
+ const isGensym = (text) => text.match(/^alu[0-9]+$/);
400
401
  for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
401
402
  const references = /* @__PURE__ */ new Map();
402
403
  const seen = /* @__PURE__ */ new Set();
@@ -425,24 +426,30 @@ function pipelineSource(device, kernel) {
425
426
  else if (op === require_backend.AluOp.Min) source = `min(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
426
427
  else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
427
428
  else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
428
- else if (op === require_backend.AluOp.Cmpne) source = `(${a} != ${b})`;
429
+ else if (op === require_backend.AluOp.Cmpne) if (require_backend.isFloatDtype(src[0].dtype)) {
430
+ const x = isGensym(a) ? a : gensym();
431
+ if (x !== a) emit(`let ${x} = ${a};`);
432
+ source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
433
+ } else source = `(${a} != ${b})`;
429
434
  } else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
430
435
  const a = gen(src[0].src[0]);
431
436
  source = `inverseSqrt(${a})`;
432
437
  } else {
433
438
  const a = gen(src[0]);
434
- if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
435
- else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
436
- else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
437
- else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
438
- else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
439
- else if (op === require_backend.AluOp.Log) source = `log(${a})`;
439
+ if (op === require_backend.AluOp.Sin) source = `sin(${require_backend.strip1(a)})`;
440
+ else if (op === require_backend.AluOp.Cos) source = `cos(${require_backend.strip1(a)})`;
441
+ else if (op === require_backend.AluOp.Asin) source = `asin(${require_backend.strip1(a)})`;
442
+ else if (op === require_backend.AluOp.Atan) source = `atan(${require_backend.strip1(a)})`;
443
+ else if (op === require_backend.AluOp.Exp) source = `exp(${require_backend.strip1(a)})`;
444
+ else if (op === require_backend.AluOp.Log) source = `log(${require_backend.strip1(a)})`;
440
445
  else if (op === require_backend.AluOp.Erf || op === require_backend.AluOp.Erfc) {
441
446
  const funcName = op === require_backend.AluOp.Erf ? "erf" : "erfc";
442
- if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
443
- else source = `${funcName}(${a})`;
444
- } else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
447
+ if (dtype !== require_backend.DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${require_backend.strip1(a)})))`;
448
+ else source = `${funcName}(${require_backend.strip1(a)})`;
449
+ } else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${require_backend.strip1(a)})`;
445
450
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
451
+ else if (op === require_backend.AluOp.Floor) source = `floor(${require_backend.strip1(a)})`;
452
+ else if (op === require_backend.AluOp.Ceil) source = `ceil(${require_backend.strip1(a)})`;
446
453
  else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
447
454
  else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
448
455
  }
@@ -645,4 +652,4 @@ async function compileError(shaderModule, scope, code) {
645
652
 
646
653
  //#endregion
647
654
  exports.WebGPUBackend = WebGPUBackend;
648
- //# sourceMappingURL=webgpu-BE7zA_01.cjs.map
655
+ //# sourceMappingURL=webgpu-CcGP160M.cjs.map
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.1.0",
3
+ "version": "0.1.2",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",