@jax-js/jax 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -19
- package/dist/{backend-BqDtPGaR.js → backend-EBRGmEYw.js} +296 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-Ss1Mev_-.cjs} +315 -154
- package/dist/index.cjs +681 -157
- package/dist/index.d.cts +422 -76
- package/dist/index.d.ts +422 -76
- package/dist/index.js +677 -157
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-BVdMaO9T.cjs} +9 -3
- package/dist/{webgpu-CNg9JGva.js → webgpu-ow0Pn_6q.js} +9 -3
- package/package.json +15 -4
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
const require_backend = require('./backend-
|
|
1
|
+
const require_backend = require('./backend-Ss1Mev_-.cjs');
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu.ts
|
|
4
4
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -254,17 +254,23 @@ function pipelineSource(device, kernel) {
|
|
|
254
254
|
else if (op === require_backend.AluOp.Max) source = `max(${require_backend.strip1(a)}, ${require_backend.strip1(b)})`;
|
|
255
255
|
else if (op === require_backend.AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
256
256
|
else if (op === require_backend.AluOp.Cmpne) source = `(${a} != ${b})`;
|
|
257
|
-
} else if (require_backend.AluGroup.Unary.has(op)) {
|
|
257
|
+
} else if (require_backend.AluGroup.Unary.has(op)) if (op === require_backend.AluOp.Reciprocal && src[0].op === require_backend.AluOp.Sqrt) {
|
|
258
|
+
const a = gen(src[0].src[0]);
|
|
259
|
+
source = `inverseSqrt(${a})`;
|
|
260
|
+
} else {
|
|
258
261
|
const a = gen(src[0]);
|
|
259
262
|
if (op === require_backend.AluOp.Sin) source = `sin(${a})`;
|
|
260
263
|
else if (op === require_backend.AluOp.Cos) source = `cos(${a})`;
|
|
264
|
+
else if (op === require_backend.AluOp.Asin) source = `asin(${a})`;
|
|
265
|
+
else if (op === require_backend.AluOp.Atan) source = `atan(${a})`;
|
|
261
266
|
else if (op === require_backend.AluOp.Exp) source = `exp(${a})`;
|
|
262
267
|
else if (op === require_backend.AluOp.Log) source = `log(${a})`;
|
|
263
268
|
else if (op === require_backend.AluOp.Sqrt) source = `sqrt(${a})`;
|
|
264
269
|
else if (op === require_backend.AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
265
270
|
else if (op === require_backend.AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${require_backend.strip1(a)})`;
|
|
266
271
|
else if (op === require_backend.AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${require_backend.strip1(a)})`;
|
|
267
|
-
}
|
|
272
|
+
}
|
|
273
|
+
else if (op === require_backend.AluOp.Where) source = `select(${require_backend.strip1(gen(src[2]))}, ${require_backend.strip1(gen(src[1]))}, ${require_backend.strip1(gen(src[0]))})`;
|
|
268
274
|
else if (op === require_backend.AluOp.Threefry2x32) {
|
|
269
275
|
const x = gensym();
|
|
270
276
|
const [k0, k1, c0, c1] = src.map((x$1) => require_backend.strip1(gen(x$1)));
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-
|
|
1
|
+
import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-EBRGmEYw.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgpu.ts
|
|
4
4
|
/** Implementation of `Backend` that uses WebGPU in browsers. */
|
|
@@ -254,17 +254,23 @@ function pipelineSource(device, kernel) {
|
|
|
254
254
|
else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
|
|
255
255
|
else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
|
|
256
256
|
else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
|
|
257
|
-
} else if (AluGroup.Unary.has(op)) {
|
|
257
|
+
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
|
|
258
|
+
const a = gen(src[0].src[0]);
|
|
259
|
+
source = `inverseSqrt(${a})`;
|
|
260
|
+
} else {
|
|
258
261
|
const a = gen(src[0]);
|
|
259
262
|
if (op === AluOp.Sin) source = `sin(${a})`;
|
|
260
263
|
else if (op === AluOp.Cos) source = `cos(${a})`;
|
|
264
|
+
else if (op === AluOp.Asin) source = `asin(${a})`;
|
|
265
|
+
else if (op === AluOp.Atan) source = `atan(${a})`;
|
|
261
266
|
else if (op === AluOp.Exp) source = `exp(${a})`;
|
|
262
267
|
else if (op === AluOp.Log) source = `log(${a})`;
|
|
263
268
|
else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
|
|
264
269
|
else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
|
|
265
270
|
else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
|
|
266
271
|
else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
|
|
267
|
-
}
|
|
272
|
+
}
|
|
273
|
+
else if (op === AluOp.Where) source = `select(${strip1(gen(src[2]))}, ${strip1(gen(src[1]))}, ${strip1(gen(src[0]))})`;
|
|
268
274
|
else if (op === AluOp.Threefry2x32) {
|
|
269
275
|
const x = gensym();
|
|
270
276
|
const [k0, k1, c0, c1] = src.map((x$1) => strip1(gen(x$1)));
|
package/package.json
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
{
|
|
2
2
|
"name": "@jax-js/jax",
|
|
3
|
-
"version": "0.0.
|
|
3
|
+
"version": "0.0.4",
|
|
4
4
|
"description": "Numerical computing and ML in the browser",
|
|
5
5
|
"keywords": [
|
|
6
6
|
"machine learning",
|
|
@@ -38,7 +38,7 @@
|
|
|
38
38
|
"devDependencies": {
|
|
39
39
|
"@eslint/js": "^9.31.0",
|
|
40
40
|
"@types/debug": "^4.1.12",
|
|
41
|
-
"@vitest/browser": "^
|
|
41
|
+
"@vitest/browser-playwright": "^4.0.9",
|
|
42
42
|
"@webgpu/types": "^0.1.64",
|
|
43
43
|
"eslint": "^9.31.0",
|
|
44
44
|
"eslint-plugin-import": "^2.32.0",
|
|
@@ -52,7 +52,7 @@
|
|
|
52
52
|
"typedoc-theme-fresh": "^0.2.1",
|
|
53
53
|
"typescript": "~5.9.3",
|
|
54
54
|
"typescript-eslint": "^8.46.4",
|
|
55
|
-
"vitest": "^
|
|
55
|
+
"vitest": "^4.0.9"
|
|
56
56
|
},
|
|
57
57
|
"engines": {
|
|
58
58
|
"pnpm": ">=10.0.0"
|
|
@@ -60,7 +60,18 @@
|
|
|
60
60
|
"prettier": {
|
|
61
61
|
"plugins": [
|
|
62
62
|
"prettier-plugin-svelte"
|
|
63
|
-
]
|
|
63
|
+
],
|
|
64
|
+
"overrides": [
|
|
65
|
+
{
|
|
66
|
+
"files": [
|
|
67
|
+
"*.md"
|
|
68
|
+
],
|
|
69
|
+
"options": {
|
|
70
|
+
"printWidth": 100
|
|
71
|
+
}
|
|
72
|
+
}
|
|
73
|
+
],
|
|
74
|
+
"proseWrap": "always"
|
|
64
75
|
},
|
|
65
76
|
"scripts": {
|
|
66
77
|
"build": "tsdown",
|