@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-D2C4MJRP.cjs');
1
+ const require_backend = require('./backend-Ss1Mev_-.cjs');
2
2
 
3
3
  //#region src/backend/webgpu.ts
4
4
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -254,17 +254,23 @@ function pipelineSource(device, kernel) {
254
254
  else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
255
255
  else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
256
256
  else if (op === require_backend.AluOp.Cmpne) source = `(${a} != ${b})`;
257
- } else if (require_backend.AluGroup.Unary.has(op)) {
257
+ } else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
258
+ const a = gen(src[0].src[0]);
259
+ source = `inverseSqrt(${a})`;
260
+ } else {
258
261
  const a = gen(src[0]);
259
262
  if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
260
263
  else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
264
+ else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
265
+ else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
261
266
  else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
262
267
  else if (op === require_backend.AluOp.Log) source = `log(${a})`;
263
268
  else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
264
269
  else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
265
270
  else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
266
271
  else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
267
- } else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
272
+ }
273
+ else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
268
274
  else if (op === require_backend.AluOp.Threefry2x32) {
269
275
  const x = gensym();
270
276
  const [k0, k1, c0, c1] = src.map((x$1) => require_backend.strip1(gen(x$1)));
@@ -1,4 +1,4 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-BqDtPGaR.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-EBRGmEYw.js";
2
2
 
3
3
  //#region src/backend/webgpu.ts
4
4
  /** Implementation of `Backend` that uses WebGPU in browsers. */
@@ -254,17 +254,23 @@ function pipelineSource(device, kernel) {
254
254
  else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
255
255
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
256
256
  else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
257
- } else if (AluGroup.Unary.has(op)) {
257
+ } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
258
+ const a = gen(src[0].src[0]);
259
+ source = `inverseSqrt(${a})`;
260
+ } else {
258
261
  const a = gen(src[0]);
259
262
  if (op === AluOp.Sin) source = `sin(${a})`;
260
263
  else if (op === AluOp.Cos) source = `cos(${a})`;
264
+ else if (op === AluOp.Asin) source = `asin(${a})`;
265
+ else if (op === AluOp.Atan) source = `atan(${a})`;
261
266
  else if (op === AluOp.Exp) source = `exp(${a})`;
262
267
  else if (op === AluOp.Log) source = `log(${a})`;
263
268
  else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
264
269
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
265
270
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
266
271
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
267
- } else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
272
+ }
273
+ else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
268
274
  else if (op === AluOp.Threefry2x32) {
269
275
  const x = gensym();
270
276
  const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.0.3",
3
+ "version": "0.0.4",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -38,7 +38,7 @@
38
38
  "devDependencies": {
39
39
  "@eslint/js": "^9.31.0",
40
40
  "@types/debug": "^4.1.12",
41
- "@vitest/browser": "^3.2.4",
41
+ "@vitest/browser-playwright": "^4.0.9",
42
42
  "@webgpu/types": "^0.1.64",
43
43
  "eslint": "^9.31.0",
44
44
  "eslint-plugin-import": "^2.32.0",
@@ -52,7 +52,7 @@
52
52
  "typedoc-theme-fresh": "^0.2.1",
53
53
  "typescript": "~5.9.3",
54
54
  "typescript-eslint": "^8.46.4",
55
- "vitest": "^3.2.4"
55
+ "vitest": "^4.0.9"
56
56
  },
57
57
  "engines": {
58
58
  "pnpm": ">=10.0.0"
@@ -60,7 +60,18 @@
60
60
  "prettier": {
61
61
  "plugins": [
62
62
  "prettier-plugin-svelte"
63
- ]
63
+ ],
64
+ "overrides": [
65
+ {
66
+ "files": [
67
+ "*.md"
68
+ ],
69
+ "options": {
70
+ "printWidth": 100
71
+ }
72
+ }
73
+ ],
74
+ "proseWrap": "always"
64
75
  },
65
76
  "scripts": {
66
77
  "build": "tsdown",