@jax-js/jax 0.0.4 → 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,176 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-EBRGmEYw.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-DwIAd0AG.js";
2
2
 
3
+ //#region src/backend/webgpu/builtins.ts
4
+ const threefrySrc = `
5
+ fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
6
+ let ks0: u32 = key.x;
7
+ let ks1: u32 = key.y;
8
+ let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;
9
+
10
+ var x0: u32 = ctr.x + ks0;
11
+ var x1: u32 = ctr.y + ks1;
12
+
13
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
14
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
15
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
16
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
17
+ x0 += ks1;
18
+ x1 += ks2 + 1u;
19
+
20
+ x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
21
+ x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
22
+ x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
23
+ x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
24
+ x0 += ks2;
25
+ x1 += ks0 + 2u;
26
+
27
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
28
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
29
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
30
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
31
+ x0 += ks0;
32
+ x1 += ks1 + 3u;
33
+
34
+ x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
35
+ x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
36
+ x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
37
+ x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
38
+ x0 += ks1;
39
+ x1 += ks2 + 4u;
40
+
41
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
42
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
43
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
44
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
45
+ x0 += ks2;
46
+ x1 += ks0 + 5u;
47
+
48
+ return vec2<u32>(x0, x1);
49
+ }`;
50
+ const erfSrc = `
51
+ const _erf_p: f32 = 0.3275911;
52
+ const _erf_a1: f32 = 0.254829592;
53
+ const _erf_a2: f32 = -0.284496736;
54
+ const _erf_a3: f32 = 1.421413741;
55
+ const _erf_a4: f32 = -1.453152027;
56
+ const _erf_a5: f32 = 1.061405429;
57
+ fn erf(x: f32) -> f32 {
58
+ let t = 1.0 / (1.0 + _erf_p * abs(x));
59
+ let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
60
+ return sign(x) * (1.0 - P_t * exp(-x * x));
61
+ }
62
+ fn erfc(x: f32) -> f32 {
63
+ let t = 1.0 / (1.0 + _erf_p * abs(x));
64
+ let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
65
+ let E = P_t * exp(-x * x);
66
+ return select(2.0 - E, E, x >= 0.0);
67
+ }`;
68
+
69
+ //#endregion
70
+ //#region src/backend/webgpu/reader.ts
71
+ /**
72
+ * Graphics state used to synchronously read data from WebGPU buffers.
73
+ *
74
+ * This trick is borrowed from TensorFlow.js. Basically, the idea is to create
75
+ * an offscreen canvas with one pixel for every 4 bytes ("device storage"), then
76
+ * configure it with a WebGPU context. Copy the buffer to a texture, then draw
77
+ * the canvas onto another offscreen canvas with '2d' context ("host storage").
78
+ *
79
+ * Once it's on host storage, we can use `getImageData()` to read the pixels
80
+ * from the image directly.
81
+ *
82
+ * We use 256x256 canvases here (256 KiB). The performance of this is bad
83
+ * because it involves multiple data copies, but it still works. We also
84
+ * actually need to copy the image twice: once in "opaque" mode for the RGB
85
+ * values, and once in "premultiplied" mode for the alpha channel.
86
+ *
87
+ * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
88
+ */
89
+ var SyncReader = class SyncReader {
90
+ static alphaModes = ["opaque", "premultiplied"];
91
+ static width = 256;
92
+ static height = 256;
93
+ initialized = false;
94
+ deviceStorage;
95
+ deviceContexts;
96
+ hostStorage;
97
+ hostContext;
98
+ constructor(device) {
99
+ this.device = device;
100
+ }
101
+ #init() {
102
+ const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
103
+ this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
104
+ this.deviceContexts = this.deviceStorage.map((canvas, i) => {
105
+ const context = canvas.getContext("webgpu");
106
+ context.configure({
107
+ device: this.device,
108
+ format: "bgra8unorm",
109
+ usage: GPUTextureUsage.COPY_DST,
110
+ alphaMode: SyncReader.alphaModes[i]
111
+ });
112
+ return context;
113
+ });
114
+ this.hostStorage = makeCanvas();
115
+ this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
116
+ this.initialized = true;
117
+ }
118
+ read(buffer, start, count) {
119
+ if (!this.initialized) this.#init();
120
+ const deviceStorage = this.deviceStorage;
121
+ const deviceContexts = this.deviceContexts;
122
+ const hostContext = this.hostContext;
123
+ const pixelsSize = Math.ceil(count / 4);
124
+ const bytesPerRow = SyncReader.width * 4;
125
+ const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
126
+ for (let i = 0; i < deviceContexts.length; i++) {
127
+ const texture = deviceContexts[i].getCurrentTexture();
128
+ const readData = (width, height, offset$1) => {
129
+ const encoder = this.device.createCommandEncoder();
130
+ encoder.copyBufferToTexture({
131
+ buffer,
132
+ bytesPerRow,
133
+ offset: offset$1 + start
134
+ }, { texture }, {
135
+ width,
136
+ height,
137
+ depthOrArrayLayers: 1
138
+ });
139
+ const commandBuffer = encoder.finish();
140
+ this.device.queue.submit([commandBuffer]);
141
+ hostContext.clearRect(0, 0, width, height);
142
+ hostContext.drawImage(deviceStorage[i], 0, 0);
143
+ const values = hostContext.getImageData(0, 0, width, height).data;
144
+ const span = new Uint8ClampedArray(valsGPU, offset$1, 4 * width * height);
145
+ const alphaMode = SyncReader.alphaModes[i];
146
+ for (let k = 0; k < span.length; k += 4) if (alphaMode === "premultiplied") span[k + 3] = values[k + 3];
147
+ else {
148
+ span[k] = values[k + 2];
149
+ span[k + 1] = values[k + 1];
150
+ span[k + 2] = values[k];
151
+ }
152
+ };
153
+ const pixelsPerCanvas = SyncReader.width * SyncReader.height;
154
+ const wholeChunks = Math.floor(pixelsSize / pixelsPerCanvas);
155
+ let remainder = pixelsSize % pixelsPerCanvas;
156
+ const remainderRows = Math.floor(remainder / SyncReader.width);
157
+ remainder = remainder % SyncReader.width;
158
+ let offset = 0;
159
+ for (let j = 0; j < wholeChunks; j++) {
160
+ readData(SyncReader.width, SyncReader.height, offset);
161
+ offset += pixelsPerCanvas * 4;
162
+ }
163
+ if (remainderRows > 0) {
164
+ readData(SyncReader.width, remainderRows, offset);
165
+ offset += remainderRows * SyncReader.width * 4;
166
+ }
167
+ if (remainder > 0) readData(remainder, 1, offset);
168
+ }
169
+ return new Uint8Array(valsGPU, 0, count);
170
+ }
171
+ };
172
+
173
+ //#endregion
3
174
  //#region src/backend/webgpu.ts
4
175
  /** Implementation of `Backend` that uses WebGPU in browsers. */
5
176
  var WebGPUBackend = class {
@@ -196,8 +367,9 @@ function pipelineSource(device, kernel) {
196
367
  emit("enable f16;");
197
368
  }
198
369
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
199
- const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
370
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
200
371
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
372
+ if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
201
373
  emit("");
202
374
  const usedArgs = Array.from({ length: nargs }, () => null);
203
375
  tune.exp.fold((exp) => {
@@ -265,7 +437,11 @@ function pipelineSource(device, kernel) {
265
437
  else if (op === AluOp.Atan) source = `atan(${a})`;
266
438
  else if (op === AluOp.Exp) source = `exp(${a})`;
267
439
  else if (op === AluOp.Log) source = `log(${a})`;
268
- else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
440
+ else if (op === AluOp.Erf || op === AluOp.Erfc) {
441
+ const funcName = op === AluOp.Erf ? "erf" : "erfc";
442
+ if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
443
+ else source = `${funcName}(${a})`;
444
+ } else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
269
445
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
270
446
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
271
447
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
@@ -466,153 +642,7 @@ async function compileError(shaderModule, scope, code) {
466
642
  if (code) message += `\n\n${code}`;
467
643
  return message;
468
644
  }
469
- /**
470
- * Graphics state used to synchronously read data from WebGPU buffers.
471
- *
472
- * This trick is borrowed from TensorFlow.js. Basically, the idea is to create
473
- * an offscreen canvas with one pixel for every 4 bytes ("device storage"), then
474
- * configure it with a WebGPU context. Copy the buffer to a texture, then draw
475
- * the canvas onto another offscreen canvas with '2d' context ("host storage").
476
- *
477
- * Once it's on host storage, we can use `getImageData()` to read the pixels
478
- * from the image directly.
479
- *
480
- * We use 256x256 canvases here (256 KiB). The performance of this is bad
481
- * because it involves multiple data copies, but it still works. We also
482
- * actually need to copy the image twice: once in "opaque" mode for the RGB
483
- * values, and once in "premultiplied" mode for the alpha channel.
484
- *
485
- * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
486
- */
487
- var SyncReader = class SyncReader {
488
- static alphaModes = ["opaque", "premultiplied"];
489
- static width = 256;
490
- static height = 256;
491
- initialized = false;
492
- deviceStorage;
493
- deviceContexts;
494
- hostStorage;
495
- hostContext;
496
- constructor(device) {
497
- this.device = device;
498
- }
499
- #init() {
500
- const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
501
- this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
502
- this.deviceContexts = this.deviceStorage.map((canvas, i) => {
503
- const context = canvas.getContext("webgpu");
504
- context.configure({
505
- device: this.device,
506
- format: "bgra8unorm",
507
- usage: GPUTextureUsage.COPY_DST,
508
- alphaMode: SyncReader.alphaModes[i]
509
- });
510
- return context;
511
- });
512
- this.hostStorage = makeCanvas();
513
- this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
514
- this.initialized = true;
515
- }
516
- read(buffer, start, count) {
517
- if (!this.initialized) this.#init();
518
- const deviceStorage = this.deviceStorage;
519
- const deviceContexts = this.deviceContexts;
520
- const hostContext = this.hostContext;
521
- const pixelsSize = Math.ceil(count / 4);
522
- const bytesPerRow = SyncReader.width * 4;
523
- const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
524
- for (let i = 0; i < deviceContexts.length; i++) {
525
- const texture = deviceContexts[i].getCurrentTexture();
526
- const readData = (width, height, offset$1) => {
527
- const encoder = this.device.createCommandEncoder();
528
- encoder.copyBufferToTexture({
529
- buffer,
530
- bytesPerRow,
531
- offset: offset$1 + start
532
- }, { texture }, {
533
- width,
534
- height,
535
- depthOrArrayLayers: 1
536
- });
537
- const commandBuffer = encoder.finish();
538
- this.device.queue.submit([commandBuffer]);
539
- hostContext.clearRect(0, 0, width, height);
540
- hostContext.drawImage(deviceStorage[i], 0, 0);
541
- const values = hostContext.getImageData(0, 0, width, height).data;
542
- const span = new Uint8ClampedArray(valsGPU, offset$1, 4 * width * height);
543
- const alphaMode = SyncReader.alphaModes[i];
544
- for (let k = 0; k < span.length; k += 4) if (alphaMode === "premultiplied") span[k + 3] = values[k + 3];
545
- else {
546
- span[k] = values[k + 2];
547
- span[k + 1] = values[k + 1];
548
- span[k + 2] = values[k];
549
- }
550
- };
551
- const pixelsPerCanvas = SyncReader.width * SyncReader.height;
552
- const wholeChunks = Math.floor(pixelsSize / pixelsPerCanvas);
553
- let remainder = pixelsSize % pixelsPerCanvas;
554
- const remainderRows = Math.floor(remainder / SyncReader.width);
555
- remainder = remainder % SyncReader.width;
556
- let offset = 0;
557
- for (let j = 0; j < wholeChunks; j++) {
558
- readData(SyncReader.width, SyncReader.height, offset);
559
- offset += pixelsPerCanvas * 4;
560
- }
561
- if (remainderRows > 0) {
562
- readData(SyncReader.width, remainderRows, offset);
563
- offset += remainderRows * SyncReader.width * 4;
564
- }
565
- if (remainder > 0) readData(remainder, 1, offset);
566
- }
567
- return new Uint8Array(valsGPU, 0, count);
568
- }
569
- };
570
- const threefrySrc = `
571
- fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
572
- let ks0: u32 = key.x;
573
- let ks1: u32 = key.y;
574
- let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;
575
-
576
- var x0: u32 = ctr.x + ks0;
577
- var x1: u32 = ctr.y + ks1;
578
-
579
- x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
580
- x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
581
- x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
582
- x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
583
- x0 += ks1;
584
- x1 += ks2 + 1u;
585
-
586
- x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
587
- x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
588
- x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
589
- x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
590
- x0 += ks2;
591
- x1 += ks0 + 2u;
592
-
593
- x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
594
- x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
595
- x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
596
- x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
597
- x0 += ks0;
598
- x1 += ks1 + 3u;
599
-
600
- x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
601
- x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
602
- x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
603
- x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
604
- x0 += ks1;
605
- x1 += ks2 + 4u;
606
-
607
- x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
608
- x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
609
- x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
610
- x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
611
- x0 += ks2;
612
- x1 += ks0 + 5u;
613
-
614
- return vec2<u32>(x0, x1);
615
- }`;
616
645
 
617
646
  //#endregion
618
- export { WebGPUBackend };
647
+ export { WebGPUBackend };
648
+ //# sourceMappingURL=webgpu-LGi2A3mS.js.map
package/package.json CHANGED
@@ -1,6 +1,6 @@
1
1
  {
2
2
  "name": "@jax-js/jax",
3
- "version": "0.0.4",
3
+ "version": "0.1.0",
4
4
  "description": "Numerical computing and ML in the browser",
5
5
  "keywords": [
6
6
  "machine learning",
@@ -19,7 +19,7 @@
19
19
  },
20
20
  "type": "module",
21
21
  "files": [
22
- "/dist"
22
+ "/dist/*.{js,cjs,d.ts,d.cts}"
23
23
  ],
24
24
  "main": "dist/index.js",
25
25
  "exports": {
@@ -39,7 +39,8 @@
39
39
  "@eslint/js": "^9.31.0",
40
40
  "@types/debug": "^4.1.12",
41
41
  "@vitest/browser-playwright": "^4.0.9",
42
- "@webgpu/types": "^0.1.64",
42
+ "@vitest/coverage-v8": "4.0.9",
43
+ "@webgpu/types": "^0.1.68",
43
44
  "eslint": "^9.31.0",
44
45
  "eslint-plugin-import": "^2.32.0",
45
46
  "globals": "^16.0.0",
@@ -49,7 +50,7 @@
49
50
  "tsdown": "^0.13.2",
50
51
  "tsx": "^4.20.3",
51
52
  "typedoc": "^0.28.14",
52
- "typedoc-theme-fresh": "^0.2.1",
53
+ "typedoc-theme-fresh": "^0.2.3",
53
54
  "typescript": "~5.9.3",
54
55
  "typescript-eslint": "^8.46.4",
55
56
  "vitest": "^4.0.9"
@@ -81,6 +82,7 @@
81
82
  "format": "prettier --write .",
82
83
  "format:check": "prettier --check .",
83
84
  "lint": "eslint",
84
- "test": "vitest"
85
+ "test": "vitest",
86
+ "test:coverage": "vitest run --coverage && open coverage/index.html"
85
87
  }
86
88
  }