@jax-js/jax 0.0.5 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,5 +1,176 @@
1
- import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, strip1, tuneWebgpu, union } from "./backend-CdcTZEOF.js";
1
+ import { AluExp, AluGroup, AluOp, DEBUG, DType, Executable, FpHash, SlotError, UnsupportedOpError, findPow2, isFloatDtype, mapSetUnion, strip1, tuneWebgpu } from "./backend-CoVtc9dx.js";
2
2
 
3
+ //#region src/backend/webgpu/builtins.ts
4
+ const threefrySrc = `
5
+ fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
6
+ let ks0: u32 = key.x;
7
+ let ks1: u32 = key.y;
8
+ let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;
9
+
10
+ var x0: u32 = ctr.x + ks0;
11
+ var x1: u32 = ctr.y + ks1;
12
+
13
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
14
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
15
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
16
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
17
+ x0 += ks1;
18
+ x1 += ks2 + 1u;
19
+
20
+ x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
21
+ x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
22
+ x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
23
+ x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
24
+ x0 += ks2;
25
+ x1 += ks0 + 2u;
26
+
27
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
28
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
29
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
30
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
31
+ x0 += ks0;
32
+ x1 += ks1 + 3u;
33
+
34
+ x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
35
+ x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
36
+ x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
37
+ x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
38
+ x0 += ks1;
39
+ x1 += ks2 + 4u;
40
+
41
+ x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
42
+ x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
43
+ x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
44
+ x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
45
+ x0 += ks2;
46
+ x1 += ks0 + 5u;
47
+
48
+ return vec2<u32>(x0, x1);
49
+ }`;
50
+ const erfSrc = `
51
+ const _erf_p: f32 = 0.3275911;
52
+ const _erf_a1: f32 = 0.254829592;
53
+ const _erf_a2: f32 = -0.284496736;
54
+ const _erf_a3: f32 = 1.421413741;
55
+ const _erf_a4: f32 = -1.453152027;
56
+ const _erf_a5: f32 = 1.061405429;
57
+ fn erf(x: f32) -> f32 {
58
+ let t = 1.0 / (1.0 + _erf_p * abs(x));
59
+ let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
60
+ return sign(x) * (1.0 - P_t * exp(-x * x));
61
+ }
62
+ fn erfc(x: f32) -> f32 {
63
+ let t = 1.0 / (1.0 + _erf_p * abs(x));
64
+ let P_t = fma(fma(fma(fma(_erf_a5, t, _erf_a4), t, _erf_a3), t, _erf_a2), t, _erf_a1) * t;
65
+ let E = P_t * exp(-x * x);
66
+ return select(2.0 - E, E, x >= 0.0);
67
+ }`;
68
+
69
+ //#endregion
70
+ //#region src/backend/webgpu/reader.ts
71
+ /**
72
+ * Graphics state used to synchronously read data from WebGPU buffers.
73
+ *
74
+ * This trick is borrowed from TensorFlow.js. Basically, the idea is to create
75
+ * an offscreen canvas with one pixel for every 4 bytes ("device storage"), then
76
+ * configure it with a WebGPU context. Copy the buffer to a texture, then draw
77
+ * the canvas onto another offscreen canvas with '2d' context ("host storage").
78
+ *
79
+ * Once it's on host storage, we can use `getImageData()` to read the pixels
80
+ * from the image directly.
81
+ *
82
+ * We use 256x256 canvases here (256 KiB). The performance of this is bad
83
+ * because it involves multiple data copies, but it still works. We also
84
+ * actually need to copy the image twice: once in "opaque" mode for the RGB
85
+ * values, and once in "premultiplied" mode for the alpha channel.
86
+ *
87
+ * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
88
+ */
89
+ var SyncReader = class SyncReader {
90
+ static alphaModes = ["opaque", "premultiplied"];
91
+ static width = 256;
92
+ static height = 256;
93
+ initialized = false;
94
+ deviceStorage;
95
+ deviceContexts;
96
+ hostStorage;
97
+ hostContext;
98
+ constructor(device) {
99
+ this.device = device;
100
+ }
101
+ #init() {
102
+ const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
103
+ this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
104
+ this.deviceContexts = this.deviceStorage.map((canvas, i) => {
105
+ const context = canvas.getContext("webgpu");
106
+ context.configure({
107
+ device: this.device,
108
+ format: "bgra8unorm",
109
+ usage: GPUTextureUsage.COPY_DST,
110
+ alphaMode: SyncReader.alphaModes[i]
111
+ });
112
+ return context;
113
+ });
114
+ this.hostStorage = makeCanvas();
115
+ this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
116
+ this.initialized = true;
117
+ }
118
+ read(buffer, start, count) {
119
+ if (!this.initialized) this.#init();
120
+ const deviceStorage = this.deviceStorage;
121
+ const deviceContexts = this.deviceContexts;
122
+ const hostContext = this.hostContext;
123
+ const pixelsSize = Math.ceil(count / 4);
124
+ const bytesPerRow = SyncReader.width * 4;
125
+ const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
126
+ for (let i = 0; i < deviceContexts.length; i++) {
127
+ const texture = deviceContexts[i].getCurrentTexture();
128
+ const readData = (width, height, offset$1) => {
129
+ const encoder = this.device.createCommandEncoder();
130
+ encoder.copyBufferToTexture({
131
+ buffer,
132
+ bytesPerRow,
133
+ offset: offset$1 + start
134
+ }, { texture }, {
135
+ width,
136
+ height,
137
+ depthOrArrayLayers: 1
138
+ });
139
+ const commandBuffer = encoder.finish();
140
+ this.device.queue.submit([commandBuffer]);
141
+ hostContext.clearRect(0, 0, width, height);
142
+ hostContext.drawImage(deviceStorage[i], 0, 0);
143
+ const values = hostContext.getImageData(0, 0, width, height).data;
144
+ const span = new Uint8ClampedArray(valsGPU, offset$1, 4 * width * height);
145
+ const alphaMode = SyncReader.alphaModes[i];
146
+ for (let k = 0; k < span.length; k += 4) if (alphaMode === "premultiplied") span[k + 3] = values[k + 3];
147
+ else {
148
+ span[k] = values[k + 2];
149
+ span[k + 1] = values[k + 1];
150
+ span[k + 2] = values[k];
151
+ }
152
+ };
153
+ const pixelsPerCanvas = SyncReader.width * SyncReader.height;
154
+ const wholeChunks = Math.floor(pixelsSize / pixelsPerCanvas);
155
+ let remainder = pixelsSize % pixelsPerCanvas;
156
+ const remainderRows = Math.floor(remainder / SyncReader.width);
157
+ remainder = remainder % SyncReader.width;
158
+ let offset = 0;
159
+ for (let j = 0; j < wholeChunks; j++) {
160
+ readData(SyncReader.width, SyncReader.height, offset);
161
+ offset += pixelsPerCanvas * 4;
162
+ }
163
+ if (remainderRows > 0) {
164
+ readData(SyncReader.width, remainderRows, offset);
165
+ offset += remainderRows * SyncReader.width * 4;
166
+ }
167
+ if (remainder > 0) readData(remainder, 1, offset);
168
+ }
169
+ return new Uint8Array(valsGPU, 0, count);
170
+ }
171
+ };
172
+
173
+ //#endregion
3
174
  //#region src/backend/webgpu.ts
4
175
  /** Implementation of `Backend` that uses WebGPU in browsers. */
5
176
  var WebGPUBackend = class {
@@ -152,7 +323,7 @@ function dtypeToWgsl(dtype, storage = false) {
152
323
  case DType.Uint32: return "u32";
153
324
  case DType.Float32: return "f32";
154
325
  case DType.Float16: return "f16";
155
- default: throw new Error(`Unsupported dtype: ${dtype}`);
326
+ default: throw new Error(`Unsupported dtype for WebGPU: ${dtype}`);
156
327
  }
157
328
  }
158
329
  function constToWgsl(dtype, value) {
@@ -196,8 +367,9 @@ function pipelineSource(device, kernel) {
196
367
  emit("enable f16;");
197
368
  }
198
369
  emit("fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }", "fn inf() -> f32 { let bits = 0x7f800000u; return bitcast<f32>(bits); }");
199
- const distinctOps = union(tune.exp.distinctOps(), re?.epilogue.distinctOps());
370
+ const distinctOps = mapSetUnion(tune.exp.distinctOps(), re?.epilogue.distinctOps());
200
371
  if (distinctOps.has(AluOp.Threefry2x32)) emit(threefrySrc);
372
+ if (distinctOps.has(AluOp.Erf) || distinctOps.has(AluOp.Erfc)) emit(erfSrc);
201
373
  emit("");
202
374
  const usedArgs = Array.from({ length: nargs }, () => null);
203
375
  tune.exp.fold((exp) => {
@@ -225,6 +397,7 @@ function pipelineSource(device, kernel) {
225
397
  }
226
398
  let gensymCount = 0;
227
399
  const gensym = () => `alu${gensymCount++}`;
400
+ const isGensym = (text) => text.match(/^alu[0-9]+$/);
228
401
  for (let i = 0; i < args.length; i++) if (!usedArgs[i]) emit(`_ = &${args[i]};`);
229
402
  const references = /* @__PURE__ */ new Map();
230
403
  const seen = /* @__PURE__ */ new Set();
@@ -253,7 +426,11 @@ function pipelineSource(device, kernel) {
253
426
  else if (op === AluOp.Min) source = `min(${strip1(a)}, ${strip1(b)})`;
254
427
  else if (op === AluOp.Max) source = `max(${strip1(a)}, ${strip1(b)})`;
255
428
  else if (op === AluOp.Cmplt) source = `(${a} < ${b})`;
256
- else if (op === AluOp.Cmpne) source = `(${a} != ${b})`;
429
+ else if (op === AluOp.Cmpne) if (isFloatDtype(src[0].dtype)) {
430
+ const x = isGensym(a) ? a : gensym();
431
+ if (x !== a) emit(`let ${x} = ${a};`);
432
+ source = `(${x} != ${b} || min(${x}, ${dtypeToWgsl(src[0].dtype)}(inf())) != ${x})`;
433
+ } else source = `(${a} != ${b})`;
257
434
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Reciprocal && src[0].op === AluOp.Sqrt) {
258
435
  const a = gen(src[0].src[0]);
259
436
  source = `inverseSqrt(${a})`;
@@ -265,7 +442,11 @@ function pipelineSource(device, kernel) {
265
442
  else if (op === AluOp.Atan) source = `atan(${a})`;
266
443
  else if (op === AluOp.Exp) source = `exp(${a})`;
267
444
  else if (op === AluOp.Log) source = `log(${a})`;
268
- else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
445
+ else if (op === AluOp.Erf || op === AluOp.Erfc) {
446
+ const funcName = op === AluOp.Erf ? "erf" : "erfc";
447
+ if (dtype !== DType.Float32) source = `${dtypeToWgsl(dtype)}(${funcName}(f32(${a})))`;
448
+ else source = `${funcName}(${a})`;
449
+ } else if (op === AluOp.Sqrt) source = `sqrt(${a})`;
269
450
  else if (op === AluOp.Reciprocal) source = `(1.0 / ${a})`;
270
451
  else if (op === AluOp.Cast) source = `${dtypeToWgsl(dtype)}(${strip1(a)})`;
271
452
  else if (op === AluOp.Bitcast) source = `bitcast<${dtypeToWgsl(dtype)}>(${strip1(a)})`;
@@ -466,153 +647,7 @@ async function compileError(shaderModule, scope, code) {
466
647
  if (code) message += `\n\n${code}`;
467
648
  return message;
468
649
  }
469
- /**
470
- * Graphics state used to synchronously read data from WebGPU buffers.
471
- *
472
- * This trick is borrowed from TensorFlow.js. Basically, the idea is to create
473
- * an offscreen canvas with one pixel for every 4 bytes ("device storage"), then
474
- * configure it with a WebGPU context. Copy the buffer to a texture, then draw
475
- * the canvas onto another offscreen canvas with '2d' context ("host storage").
476
- *
477
- * Once it's on host storage, we can use `getImageData()` to read the pixels
478
- * from the image directly.
479
- *
480
- * We use 256x256 canvases here (256 KiB). The performance of this is bad
481
- * because it involves multiple data copies, but it still works. We also
482
- * actually need to copy the image twice: once in "opaque" mode for the RGB
483
- * values, and once in "premultiplied" mode for the alpha channel.
484
- *
485
- * https://github.com/tensorflow/tfjs/blob/tfjs-v4.22.0/tfjs-backend-webgpu/src/backend_webgpu.ts#L379
486
- */
487
- var SyncReader = class SyncReader {
488
- static alphaModes = ["opaque", "premultiplied"];
489
- static width = 256;
490
- static height = 256;
491
- initialized = false;
492
- deviceStorage;
493
- deviceContexts;
494
- hostStorage;
495
- hostContext;
496
- constructor(device) {
497
- this.device = device;
498
- }
499
- #init() {
500
- const makeCanvas = () => new OffscreenCanvas(SyncReader.width, SyncReader.height);
501
- this.deviceStorage = SyncReader.alphaModes.map(makeCanvas);
502
- this.deviceContexts = this.deviceStorage.map((canvas, i) => {
503
- const context = canvas.getContext("webgpu");
504
- context.configure({
505
- device: this.device,
506
- format: "bgra8unorm",
507
- usage: GPUTextureUsage.COPY_DST,
508
- alphaMode: SyncReader.alphaModes[i]
509
- });
510
- return context;
511
- });
512
- this.hostStorage = makeCanvas();
513
- this.hostContext = this.hostStorage.getContext("2d", { willReadFrequently: true });
514
- this.initialized = true;
515
- }
516
- read(buffer, start, count) {
517
- if (!this.initialized) this.#init();
518
- const deviceStorage = this.deviceStorage;
519
- const deviceContexts = this.deviceContexts;
520
- const hostContext = this.hostContext;
521
- const pixelsSize = Math.ceil(count / 4);
522
- const bytesPerRow = SyncReader.width * 4;
523
- const valsGPU = /* @__PURE__ */ new ArrayBuffer(pixelsSize * 4);
524
- for (let i = 0; i < deviceContexts.length; i++) {
525
- const texture = deviceContexts[i].getCurrentTexture();
526
- const readData = (width, height, offset$1) => {
527
- const encoder = this.device.createCommandEncoder();
528
- encoder.copyBufferToTexture({
529
- buffer,
530
- bytesPerRow,
531
- offset: offset$1 + start
532
- }, { texture }, {
533
- width,
534
- height,
535
- depthOrArrayLayers: 1
536
- });
537
- const commandBuffer = encoder.finish();
538
- this.device.queue.submit([commandBuffer]);
539
- hostContext.clearRect(0, 0, width, height);
540
- hostContext.drawImage(deviceStorage[i], 0, 0);
541
- const values = hostContext.getImageData(0, 0, width, height).data;
542
- const span = new Uint8ClampedArray(valsGPU, offset$1, 4 * width * height);
543
- const alphaMode = SyncReader.alphaModes[i];
544
- for (let k = 0; k < span.length; k += 4) if (alphaMode === "premultiplied") span[k + 3] = values[k + 3];
545
- else {
546
- span[k] = values[k + 2];
547
- span[k + 1] = values[k + 1];
548
- span[k + 2] = values[k];
549
- }
550
- };
551
- const pixelsPerCanvas = SyncReader.width * SyncReader.height;
552
- const wholeChunks = Math.floor(pixelsSize / pixelsPerCanvas);
553
- let remainder = pixelsSize % pixelsPerCanvas;
554
- const remainderRows = Math.floor(remainder / SyncReader.width);
555
- remainder = remainder % SyncReader.width;
556
- let offset = 0;
557
- for (let j = 0; j < wholeChunks; j++) {
558
- readData(SyncReader.width, SyncReader.height, offset);
559
- offset += pixelsPerCanvas * 4;
560
- }
561
- if (remainderRows > 0) {
562
- readData(SyncReader.width, remainderRows, offset);
563
- offset += remainderRows * SyncReader.width * 4;
564
- }
565
- if (remainder > 0) readData(remainder, 1, offset);
566
- }
567
- return new Uint8Array(valsGPU, 0, count);
568
- }
569
- };
570
- const threefrySrc = `
571
- fn threefry2x32(key: vec2<u32>, ctr: vec2<u32>) -> vec2<u32> {
572
- let ks0: u32 = key.x;
573
- let ks1: u32 = key.y;
574
- let ks2: u32 = ks0 ^ ks1 ^ 0x1BD11BDAu;
575
-
576
- var x0: u32 = ctr.x + ks0;
577
- var x1: u32 = ctr.y + ks1;
578
-
579
- x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
580
- x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
581
- x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
582
- x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
583
- x0 += ks1;
584
- x1 += ks2 + 1u;
585
-
586
- x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
587
- x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
588
- x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
589
- x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
590
- x0 += ks2;
591
- x1 += ks0 + 2u;
592
-
593
- x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
594
- x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
595
- x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
596
- x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
597
- x0 += ks0;
598
- x1 += ks1 + 3u;
599
-
600
- x0 += x1; x1 = (x1 << 17u) | (x1 >> 15u); x1 ^= x0;
601
- x0 += x1; x1 = (x1 << 29u) | (x1 >> 3u); x1 ^= x0;
602
- x0 += x1; x1 = (x1 << 16u) | (x1 >> 16u); x1 ^= x0;
603
- x0 += x1; x1 = (x1 << 24u) | (x1 >> 8u); x1 ^= x0;
604
- x0 += ks1;
605
- x1 += ks2 + 4u;
606
-
607
- x0 += x1; x1 = (x1 << 13u) | (x1 >> 19u); x1 ^= x0;
608
- x0 += x1; x1 = (x1 << 15u) | (x1 >> 17u); x1 ^= x0;
609
- x0 += x1; x1 = (x1 << 26u) | (x1 >> 6u); x1 ^= x0;
610
- x0 += x1; x1 = (x1 << 6u) | (x1 >> 26u); x1 ^= x0;
611
- x0 += ks2;
612
- x1 += ks0 + 5u;
613
-
614
- return vec2<u32>(x0, x1);
615
- }`;
616
650
 
617
651
  //#endregion
618
- export { WebGPUBackend };
652
+ export { WebGPUBackend };
653
+ //# sourceMappingURL=webgpu-B3UVme6n.js.map