@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,7 +36,22 @@ var PPrint = class PPrint {
36
36
  //#endregion
37
37
  //#region src/utils.ts
38
38
  /** @file Generic programming utilities with no dependencies on library code. */
39
/** Current debug verbosity level (initially 0); changed through `setDebug`. */
let DEBUG = 0;
/**
 * Set the debug level for verbose logging.
 *
 * Levels, with 5 being the most verbose:
 * 1. JIT compile logs
 * 2. Shader code
 * 3. Expressions and metadata
 * 4. JIT programs, tuning details
 * 5. Most verbose operation traces
 *
 * This is an experimental API and may change in behavior. Do not rely on this
 * in production.
 *
 * @param {number} newLevel - The level to store in `DEBUG`.
 */
function setDebug(newLevel) {
  DEBUG = newLevel;
}
40
55
  function unzip2(pairs) {
41
56
  const lst1 = [];
42
57
  const lst2 = [];
@@ -110,9 +125,23 @@ function isNumberPair(x) {
110
125
  }
111
126
  /** Check an axis against number of dimensions, and resolve negative axes. */
112
127
  function checkAxis(axis, ndim) {
113
- if (axis < -ndim || axis >= ndim) throw new Error(`Invalid axis ${axis} for array of ${ndim} dimensions`);
128
+ if (axis < -ndim || axis >= ndim) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
114
129
  return axis < 0 ? axis + ndim : axis;
115
130
  }
131
+ /** Normalize common axis argument for functions, defaulting to all axes. */
132
+ function normalizeAxis(axis, ndim) {
133
+ if (axis === null) return range(ndim);
134
+ else if (typeof axis === "number") return [checkAxis(axis, ndim)];
135
+ else {
136
+ const seen = /* @__PURE__ */ new Set();
137
+ for (const a of axis) {
138
+ const ca = checkAxis(a, ndim);
139
+ if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
140
+ seen.add(ca);
141
+ }
142
+ return [...seen].sort();
143
+ }
144
+ }
116
145
  function range(start, stop, step = 1) {
117
146
  if (stop === void 0) {
118
147
  stop = start;
@@ -187,6 +216,7 @@ function strip1(str) {
187
216
  if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
188
217
  return str;
189
218
  }
219
/**
 * Reusable 8-byte scratch view, used to reinterpret a float64 value's bits as
 * a 64-bit unsigned integer without allocating a new buffer each time.
 */
const _stagingbuf = /* @__PURE__ */ new DataView(new Float64Array(1).buffer);
190
220
  /**
191
221
  * Polynomial hashes modulo p are good at avoiding collisions in expectation.
192
222
  * Probability-wise, it's good enough to be used for something like
@@ -201,22 +231,26 @@ var FpHash = class FpHash {
201
231
  const modulus = 3189051996290219n;
202
232
  this.value = (this.value * base + x) % modulus;
203
233
  }
204
- update(...values) {
205
- for (const x of values) if (typeof x === "string") for (const c of x) this.#update(BigInt(199 + c.charCodeAt(0)));
206
- else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
234
+ update(x) {
235
+ if (typeof x === "string") {
236
+ this.#update(BigInt(x.length));
237
+ for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
238
+ } else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
207
239
  else {
208
- const ar = new Float64Array([x]);
209
- this.#update(new DataView(ar.buffer).getBigUint64(0, true));
240
+ _stagingbuf.setFloat64(0, x, true);
241
+ this.#update(_stagingbuf.getBigUint64(0, true));
210
242
  }
211
243
  else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
212
244
  else if (typeof x === "bigint") this.#update(x ^ 71657401n);
213
245
  else if (x === null) this.#update(37832657n);
214
246
  else if (x === void 0) this.#update(18145117n);
215
- else if (typeof x === "object" && "hash" in x) x.hash(this);
247
+ else x.hash(this);
216
248
  return this;
217
249
  }
218
250
  static hash(...values) {
219
- return new FpHash().update(...values).value;
251
+ const h = new FpHash();
252
+ for (const x of values) h.update(x);
253
+ return h.value;
220
254
  }
221
255
  };
222
256
  /** Run a function while caching it inline inside a `Map`. */
@@ -251,6 +285,41 @@ const byteWidth = (dtype) => {
251
285
  }
252
286
  };
253
287
  const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
288
/**
 * Promote two dtypes to their join according to the type lattice.
 *
 * When an operation combines arrays of different dtypes, both operands are
 * promoted to a common dtype that can represent values of either input. The
 * ordering is modeled on JAX's type promotion rules.
 *
 * **Type lattice:**
 * ```text
 * bool -> uint32 -> int32 -> float16 -> float32
 * weak f* --^
 * ```
 *
 * f* is the weak type of JS number constants. Arrays created from JS numbers
 * default to float32 but are "weak", so they take the dtype of the first array
 * they are combined with. This function itself only compares two concrete
 * dtypes and has no notion of weakness.
 *
 * **Examples:**
 * - `promoteTypes(bool, int32) → int32`
 * - `promoteTypes(uint32, int32) → int32`
 * - `promoteTypes(int32, float16) → float16`
 * - `promoteTypes(float16, float32) → float32`
 * - `promoteTypes(uint32, float32) → float32`
 *
 * If either dtype is not in the lattice, `dtype2` is returned.
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  // Lattice order, lowest first; a dtype's position in this list is its rank.
  const lattice = [DType.Bool, DType.Uint32, DType.Int32, DType.Float16, DType.Float32];
  const rank = Object.fromEntries(lattice.map((dtype, position) => [dtype, position]));
  return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
}
254
323
  function dtypedArray(dtype, data) {
255
324
  const { buffer, byteLength, byteOffset } = data;
256
325
  const length = byteLength / byteWidth(dtype);
@@ -320,6 +389,12 @@ var AluExp = class AluExp {
320
389
  static cos(a) {
321
390
  return new AluExp(AluOp.Cos, a.dtype, [a]);
322
391
  }
392
+ static asin(a) {
393
+ return new AluExp(AluOp.Asin, a.dtype, [a]);
394
+ }
395
+ static atan(a) {
396
+ return new AluExp(AluOp.Atan, a.dtype, [a]);
397
+ }
323
398
  static exp(a) {
324
399
  return new AluExp(AluOp.Exp, a.dtype, [a]);
325
400
  }
@@ -403,8 +478,11 @@ var AluExp = class AluExp {
403
478
  getHash() {
404
479
  if (this.#hash !== void 0) return this.#hash;
405
480
  const hasher = new FpHash();
406
- hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
407
- hasher.update(this.src.length, ...this.src);
481
+ hasher.update(this.op);
482
+ hasher.update(this.dtype);
483
+ hasher.update(JSON.stringify(this.arg));
484
+ hasher.update(this.src.length);
485
+ for (const s of this.src) hasher.update(s);
408
486
  this.#hash = hasher.value;
409
487
  return this.#hash;
410
488
  }
@@ -476,10 +554,16 @@ var AluExp = class AluExp {
476
554
  ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
477
555
  break;
478
556
  case AluOp.Sin:
479
- ret = [Math.sin(src[0].min), Math.sin(src[0].max)];
557
+ ret = [-1, 1];
480
558
  break;
481
559
  case AluOp.Cos:
482
- ret = [Math.cos(src[0].min), Math.cos(src[0].max)];
560
+ ret = [-1, 1];
561
+ break;
562
+ case AluOp.Asin:
563
+ ret = [-Math.PI / 2, Math.PI / 2];
564
+ break;
565
+ case AluOp.Atan:
566
+ ret = [-Math.PI / 2, Math.PI / 2];
483
567
  break;
484
568
  case AluOp.Exp:
485
569
  ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
@@ -595,11 +679,12 @@ var AluExp = class AluExp {
595
679
  simplify(cache = /* @__PURE__ */ new Map()) {
596
680
  if (this.#simplified !== void 0) return this.#simplified;
597
681
  const hash = this.getHash();
598
- if (cache.has(hash)) return this.#simplified = cache.get(hash);
682
+ const prevCachedValue = cache.get(hash);
683
+ if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
599
684
  const simplified = this.#simplifyInner(cache);
600
685
  const simplifiedHash = simplified.getHash();
601
- if (cache.has(simplifiedHash)) {
602
- const prevSimplified = cache.get(simplifiedHash);
686
+ const prevSimplified = cache.get(simplifiedHash);
687
+ if (prevSimplified !== void 0) {
603
688
  cache.set(hash, prevSimplified);
604
689
  this.#simplified = prevSimplified;
605
690
  return prevSimplified;
@@ -803,6 +888,8 @@ var AluExp = class AluExp {
803
888
  switch (this.op) {
804
889
  case AluOp.Sin: return Math.sin(x);
805
890
  case AluOp.Cos: return Math.cos(x);
891
+ case AluOp.Asin: return Math.asin(x);
892
+ case AluOp.Atan: return Math.atan(x);
806
893
  case AluOp.Exp: return Math.exp(x);
807
894
  case AluOp.Log: return Math.log(x);
808
895
  case AluOp.Sqrt: return Math.sqrt(x);
@@ -982,6 +1069,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
982
1069
  AluOp$1["Max"] = "Max";
983
1070
  AluOp$1["Sin"] = "Sin";
984
1071
  AluOp$1["Cos"] = "Cos";
1072
+ AluOp$1["Asin"] = "Asin";
1073
+ AluOp$1["Atan"] = "Atan";
985
1074
  AluOp$1["Exp"] = "Exp";
986
1075
  AluOp$1["Log"] = "Log";
987
1076
  AluOp$1["Sqrt"] = "Sqrt";
@@ -1012,6 +1101,8 @@ const AluGroup = {
1012
1101
  Unary: new Set([
1013
1102
  AluOp.Sin,
1014
1103
  AluOp.Cos,
1104
+ AluOp.Asin,
1105
+ AluOp.Atan,
1015
1106
  AluOp.Exp,
1016
1107
  AluOp.Log,
1017
1108
  AluOp.Sqrt,
@@ -1035,6 +1126,8 @@ const AluGroup = {
1035
1126
  RequiredFloat: new Set([
1036
1127
  AluOp.Sin,
1037
1128
  AluOp.Cos,
1129
+ AluOp.Asin,
1130
+ AluOp.Atan,
1038
1131
  AluOp.Exp,
1039
1132
  AluOp.Log,
1040
1133
  AluOp.Sqrt,
@@ -1066,7 +1159,7 @@ var Kernel = class {
1066
1159
  this.exp = exp.simplify();
1067
1160
  }
1068
1161
  hash(state) {
1069
- state.update(this.nargs, this.size, this.exp, this.reduction);
1162
+ state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
1070
1163
  }
1071
1164
  pprint() {
1072
1165
  let details = PPrint.pp(`exp = ${this.exp}`);
@@ -1112,7 +1205,7 @@ var Reduction = class {
1112
1205
  this.epilogue = epilogue.simplify();
1113
1206
  }
1114
1207
  hash(state) {
1115
- state.update(this.dtype, this.op, this.size, this.epilogue);
1208
+ state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
1116
1209
  }
1117
1210
  toString() {
1118
1211
  return `${this.op}{${this.size}} -> ${this.epilogue}`;
@@ -2284,78 +2377,92 @@ function wasm_log(cg) {
2284
2377
  });
2285
2378
  }
2286
2379
  /**
2287
- * Approximate sin(x).
2380
+ * Common helper to approximate sin(x) and cos(x).
2288
2381
  *
2289
2382
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2290
- * z = y - q*(π/2); use odd polynomial on z:
2383
+ * z = y - q*(π/2); use one of two polynomials on z:
2291
2384
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2385
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2386
+ */
2387
+ function _sincos(cg) {
2388
+ const y = cg.local.declare(cg.f32);
2389
+ const qf = cg.local.declare(cg.f32);
2390
+ const q = cg.local.declare(cg.i32);
2391
+ const z = cg.local.declare(cg.f32);
2392
+ const z2 = cg.local.declare(cg.f32);
2393
+ const sz = cg.local.declare(cg.f32);
2394
+ const cz = cg.local.declare(cg.f32);
2395
+ cg.local.get(0);
2396
+ cg.local.get(0);
2397
+ cg.f32.const(1 / (2 * Math.PI));
2398
+ cg.f32.mul();
2399
+ cg.f32.nearest();
2400
+ cg.local.tee(qf);
2401
+ cg.f32.const(2 * Math.PI);
2402
+ cg.f32.mul();
2403
+ cg.f32.sub();
2404
+ cg.local.set(y);
2405
+ cg.local.get(y);
2406
+ cg.f32.const(2 / Math.PI);
2407
+ cg.f32.mul();
2408
+ cg.f32.nearest();
2409
+ cg.local.tee(qf);
2410
+ cg.i32.trunc_f32_s();
2411
+ cg.local.set(q);
2412
+ cg.local.get(y);
2413
+ cg.local.get(qf);
2414
+ cg.f32.const(Math.PI / 2);
2415
+ cg.f32.mul();
2416
+ cg.f32.sub();
2417
+ cg.local.tee(z);
2418
+ cg.local.get(z);
2419
+ cg.f32.mul();
2420
+ cg.local.set(z2);
2421
+ cg.f32.const(-1 / 5040);
2422
+ cg.local.get(z2);
2423
+ cg.f32.mul();
2424
+ cg.f32.const(1 / 120);
2425
+ cg.f32.add();
2426
+ cg.local.get(z2);
2427
+ cg.f32.mul();
2428
+ cg.f32.const(-1 / 6);
2429
+ cg.f32.add();
2430
+ cg.local.get(z2);
2431
+ cg.f32.mul();
2432
+ cg.f32.const(1);
2433
+ cg.f32.add();
2434
+ cg.local.get(z);
2435
+ cg.f32.mul();
2436
+ cg.local.set(sz);
2437
+ cg.f32.const(-1 / 720);
2438
+ cg.local.get(z2);
2439
+ cg.f32.mul();
2440
+ cg.f32.const(1 / 24);
2441
+ cg.f32.add();
2442
+ cg.local.get(z2);
2443
+ cg.f32.mul();
2444
+ cg.f32.const(-1 / 2);
2445
+ cg.f32.add();
2446
+ cg.local.get(z2);
2447
+ cg.f32.mul();
2448
+ cg.f32.const(1);
2449
+ cg.f32.add();
2450
+ cg.local.set(cz);
2451
+ return {
2452
+ q,
2453
+ sz,
2454
+ cz
2455
+ };
2456
+ }
2457
+ /**
2458
+ * Approximate sin(x).
2459
+ *
2460
+ * Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
2292
2461
  */
2293
2462
  function wasm_sin(cg) {
2294
2463
  return cg.function([cg.f32], [cg.f32], () => {
2295
- const y = cg.local.declare(cg.f32);
2296
- const qf = cg.local.declare(cg.f32);
2297
- const q = cg.local.declare(cg.i32);
2298
- const z = cg.local.declare(cg.f32);
2299
- const z2 = cg.local.declare(cg.f32);
2300
- const sz = cg.local.declare(cg.f32);
2301
- const cz = cg.local.declare(cg.f32);
2464
+ const { q, sz, cz } = _sincos(cg);
2302
2465
  const mag = cg.local.declare(cg.f32);
2303
- cg.local.get(0);
2304
- cg.local.get(0);
2305
- cg.f32.const(1 / (2 * Math.PI));
2306
- cg.f32.mul();
2307
- cg.f32.nearest();
2308
- cg.local.tee(qf);
2309
- cg.f32.const(2 * Math.PI);
2310
- cg.f32.mul();
2311
- cg.f32.sub();
2312
- cg.local.set(y);
2313
- cg.local.get(y);
2314
- cg.f32.const(2 / Math.PI);
2315
- cg.f32.mul();
2316
- cg.f32.nearest();
2317
- cg.local.tee(qf);
2318
- cg.i32.trunc_f32_s();
2319
- cg.local.set(q);
2320
- cg.local.get(y);
2321
- cg.local.get(qf);
2322
- cg.f32.const(Math.PI / 2);
2323
- cg.f32.mul();
2324
- cg.f32.sub();
2325
- cg.local.tee(z);
2326
- cg.local.get(z);
2327
- cg.f32.mul();
2328
- cg.local.set(z2);
2329
- cg.f32.const(-1 / 5040);
2330
- cg.local.get(z2);
2331
- cg.f32.mul();
2332
- cg.f32.const(1 / 120);
2333
- cg.f32.add();
2334
- cg.local.get(z2);
2335
- cg.f32.mul();
2336
- cg.f32.const(-1 / 6);
2337
- cg.f32.add();
2338
- cg.local.get(z2);
2339
- cg.f32.mul();
2340
- cg.f32.const(1);
2341
- cg.f32.add();
2342
- cg.local.get(z);
2343
- cg.f32.mul();
2344
- cg.local.set(sz);
2345
- cg.f32.const(-1 / 720);
2346
- cg.local.get(z2);
2347
- cg.f32.mul();
2348
- cg.f32.const(1 / 24);
2349
- cg.f32.add();
2350
- cg.local.get(z2);
2351
- cg.f32.mul();
2352
- cg.f32.const(-1 / 2);
2353
- cg.f32.add();
2354
- cg.local.get(z2);
2355
- cg.f32.mul();
2356
- cg.f32.const(1);
2357
- cg.f32.add();
2358
- cg.local.set(cz);
2359
2466
  cg.local.get(cz);
2360
2467
  cg.local.get(sz);
2361
2468
  cg.local.get(q);
@@ -2374,75 +2481,12 @@ function wasm_sin(cg) {
2374
2481
  /**
2375
2482
  * Approximate cos(x).
2376
2483
  *
2377
- * Same reduction as sinf, then quadrant mapping:
2378
- * k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2484
+ * Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2379
2485
  */
2380
2486
  function wasm_cos(cg) {
2381
2487
  return cg.function([cg.f32], [cg.f32], () => {
2382
- const y = cg.local.declare(cg.f32);
2383
- const qf = cg.local.declare(cg.f32);
2384
- const q = cg.local.declare(cg.i32);
2385
- const z = cg.local.declare(cg.f32);
2386
- const z2 = cg.local.declare(cg.f32);
2387
- const sz = cg.local.declare(cg.f32);
2388
- const cz = cg.local.declare(cg.f32);
2488
+ const { q, sz, cz } = _sincos(cg);
2389
2489
  const mag = cg.local.declare(cg.f32);
2390
- cg.local.get(0);
2391
- cg.local.get(0);
2392
- cg.f32.const(1 / (2 * Math.PI));
2393
- cg.f32.mul();
2394
- cg.f32.nearest();
2395
- cg.local.tee(qf);
2396
- cg.f32.const(2 * Math.PI);
2397
- cg.f32.mul();
2398
- cg.f32.sub();
2399
- cg.local.set(y);
2400
- cg.local.get(y);
2401
- cg.f32.const(2 / Math.PI);
2402
- cg.f32.mul();
2403
- cg.f32.nearest();
2404
- cg.local.tee(qf);
2405
- cg.i32.trunc_f32_s();
2406
- cg.local.set(q);
2407
- cg.local.get(y);
2408
- cg.local.get(qf);
2409
- cg.f32.const(Math.PI / 2);
2410
- cg.f32.mul();
2411
- cg.f32.sub();
2412
- cg.local.tee(z);
2413
- cg.local.get(z);
2414
- cg.f32.mul();
2415
- cg.local.set(z2);
2416
- cg.f32.const(-1 / 5040);
2417
- cg.local.get(z2);
2418
- cg.f32.mul();
2419
- cg.f32.const(1 / 120);
2420
- cg.f32.add();
2421
- cg.local.get(z2);
2422
- cg.f32.mul();
2423
- cg.f32.const(-1 / 6);
2424
- cg.f32.add();
2425
- cg.local.get(z2);
2426
- cg.f32.mul();
2427
- cg.f32.const(1);
2428
- cg.f32.add();
2429
- cg.local.get(z);
2430
- cg.f32.mul();
2431
- cg.local.set(sz);
2432
- cg.f32.const(-1 / 720);
2433
- cg.local.get(z2);
2434
- cg.f32.mul();
2435
- cg.f32.const(1 / 24);
2436
- cg.f32.add();
2437
- cg.local.get(z2);
2438
- cg.f32.mul();
2439
- cg.f32.const(-1 / 2);
2440
- cg.f32.add();
2441
- cg.local.get(z2);
2442
- cg.f32.mul();
2443
- cg.f32.const(1);
2444
- cg.f32.add();
2445
- cg.local.set(cz);
2446
2490
  cg.local.get(sz);
2447
2491
  cg.local.get(cz);
2448
2492
  cg.local.get(q);
@@ -2460,6 +2504,100 @@ function wasm_cos(cg) {
2460
2504
  cg.select();
2461
2505
  });
2462
2506
  }
2507
/**
 * Emit code that replaces the f32 value x on top of the wasm stack with an
 * approximation of atan(x).
 *
 * With a = |x|, the argument is folded to z = (a >= 1 ? 1/a : a), so
 * 0 <= z <= 1 for non-NaN x. Then
 *   atan(z) ≈ z * P(z²) / Q(z²),
 *   P(u) = 0.999998614341 + 0.661705427875·u + 0.0415796528637·u²,
 *   Q(u) = 1 + 0.994987933645·u + 0.173698870181·u²,
 * and for a >= 1 the identity atan(a) = π/2 - atan(1/a) undoes the folding.
 * Finally the sign of x is restored with copysign.
 */
function _atan(cg) {
  // Declaration order fixes the local indices; keep it stable.
  const x = cg.local.declare(cg.f32);
  const abs_x = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);

  // Push the i32 condition (|x| >= 1) used by both selects.
  const pushAbsAtLeastOne = () => {
    cg.local.get(abs_x);
    cg.f32.const(1);
    cg.f32.ge();
  };
  // Push the Horner evaluation of a polynomial in z2, highest degree first.
  const pushPolyInZ2 = (coefficients) => {
    const [leading, ...rest] = coefficients;
    cg.f32.const(leading);
    for (const c of rest) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(c);
      cg.f32.add();
    }
  };

  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(abs_x);

  // z = |x| >= 1 ? 1 / |x| : |x|
  cg.f32.const(1);
  cg.local.get(abs_x);
  cg.f32.div();
  cg.local.get(abs_x);
  pushAbsAtLeastOne();
  cg.select();
  cg.local.set(z);

  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // p = z * P(z2) / Q(z2) ≈ atan(z)
  pushPolyInZ2([0.0415796528637, 0.661705427875, 0.999998614341]);
  pushPolyInZ2([0.173698870181, 0.994987933645, 1]);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);

  // result = |x| >= 1 ? π/2 - p : p, carrying the sign of x
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  pushAbsAtLeastOne();
  cg.select();
  cg.local.get(x);
  cg.f32.copysign();
}
2564
/**
 * Approximate atan(x) as a standalone wasm function (f32 -> f32).
 *
 * The work is done by `_atan`. For |x| < 1 it uses a rational approximation
 * atan(x) ≈ x * P(x^2) / Q(x^2), with P and Q of degree 2. For |x| >= 1 it
 * uses the identity atan(x) = sign(x)*π/2 - atan(1/x). The fitted
 * coefficients have a maximum error of roughly 5e-7 on [0, 1].
 */
function wasm_atan(cg) {
  const emitBody = () => {
    cg.local.get(0); // push the parameter x
    _atan(cg); // replace it with atan(x)
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2579
/**
 * Approximate asin(x) as a standalone wasm function (f32 -> f32).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))), with the atan step
 * emitted by `_atan`.
 *
 * The radicand is evaluated as (1 - x) * (1 + x) rather than 1 - x*x. Near
 * |x| = 1 the latter subtracts two nearly equal f32 values after x*x has
 * already been rounded, which loses relative precision. In the factored form,
 * the factor that becomes small (1 - x or 1 + x) is computed exactly there
 * (Sterbenz lemma). For |x| > 1 the radicand is negative and sqrt yields NaN.
 */
function wasm_asin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    cg.local.get(0);
    // sqrt((1 - x) * (1 + x))
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.sub();
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.add();
    cg.f32.mul();
    cg.f32.sqrt();
    // x / (1 + sqrt(1 - x^2))
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  });
}
2463
2601
  /**
2464
2602
  * Threefry2x32 pseudorandom number generator.
2465
2603
  *
@@ -3339,6 +3477,8 @@ function codegenWasm(kernel) {
3339
3477
  const funcs = {};
3340
3478
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3341
3479
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3480
+ if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3481
+ if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3342
3482
  if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3343
3483
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3344
3484
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
@@ -3490,6 +3630,8 @@ function translateExp(cg, funcs, exp, ctx) {
3490
3630
  else throw new UnsupportedOpError(op, dtype, "wasm");
3491
3631
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3492
3632
  else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3633
+ else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3634
+ else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3493
3635
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3494
3636
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3495
3637
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
@@ -3588,10 +3730,11 @@ let defaultBackend = "wasm";
3588
3730
  const initializedBackends = /* @__PURE__ */ new Map();
3589
3731
  initializedBackends.set("cpu", new CpuBackend());
3590
3732
  initializedBackends.set("wasm", new WasmBackend());
3591
/**
 * Get, and optionally set, the default device for arrays.
 *
 * @param {string} [device] - Backend name to make the default. It must already
 *   be in `initializedBackends`. Omit it to only query the current default.
 * @returns {string} The default backend name after any update.
 * @throws {Error} If `device` is given but that backend is not initialized.
 */
function defaultDevice(device) {
  if (device === void 0) return defaultBackend;
  if (!initializedBackends.has(device)) {
    throw new Error(`Backend not initialized: ${device}`);
  }
  defaultBackend = device;
  return defaultBackend;
}
3596
3739
  /**
3597
3740
  * Initialize `jax-js` library backends.
@@ -3618,7 +3761,7 @@ async function createBackend(device) {
3618
3761
  if (!navigator.gpu) return null;
3619
3762
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3620
3763
  if (!adapter) return null;
3621
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-fqhx41TC.cjs"));
3764
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVdMaO9T.cjs"));
3622
3765
  const importantLimits = [
3623
3766
  "maxBufferSize",
3624
3767
  "maxComputeInvocationsPerWorkgroup",
@@ -3785,6 +3928,12 @@ Object.defineProperty(exports, 'deepEqual', {
3785
3928
  return deepEqual;
3786
3929
  }
3787
3930
  });
3931
// Re-export `defaultDevice` as an enumerable, getter-only (read-only) property.
Object.defineProperty(exports, 'defaultDevice', {
  enumerable: true,
  get: () => defaultDevice,
});
3788
3937
  Object.defineProperty(exports, 'devices', {
3789
3938
  enumerable: true,
3790
3939
  get: function () {
@@ -3845,6 +3994,12 @@ Object.defineProperty(exports, 'isPermutation', {
3845
3994
  return isPermutation;
3846
3995
  }
3847
3996
  });
3997
// Re-export `normalizeAxis` as an enumerable, getter-only (read-only) property.
Object.defineProperty(exports, 'normalizeAxis', {
  enumerable: true,
  get: () => normalizeAxis,
});
3848
4003
  Object.defineProperty(exports, 'partitionList', {
3849
4004
  enumerable: true,
3850
4005
  get: function () {
@@ -3857,6 +4012,12 @@ Object.defineProperty(exports, 'prod', {
3857
4012
  return prod;
3858
4013
  }
3859
4014
  });
4015
// Re-export `promoteTypes` as an enumerable, getter-only (read-only) property.
Object.defineProperty(exports, 'promoteTypes', {
  enumerable: true,
  get: () => promoteTypes,
});
3860
4021
  Object.defineProperty(exports, 'range', {
3861
4022
  enumerable: true,
3862
4023
  get: function () {
@@ -3881,10 +4042,10 @@ Object.defineProperty(exports, 'runWithCache', {
3881
4042
  return runWithCache;
3882
4043
  }
3883
4044
  });
3884
// Re-export `setDebug` as an enumerable, getter-only (read-only) property.
Object.defineProperty(exports, 'setDebug', {
  enumerable: true,
  get: () => setDebug,
});
3890
4051
  Object.defineProperty(exports, 'strip1', {