@jax-js/jax 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,7 +35,22 @@ var PPrint = class PPrint {
35
35
  //#endregion
36
36
  //#region src/utils.ts
37
37
  /** @file Generic programming utilities with no dependencies on library code. */
38
/** Current debug verbosity level; starts at 0 and is changed via `setDebug`. */
let DEBUG = 0;
/**
 * Set the debug level for verbose logging.
 *
 * Levels, from least to most verbose:
 * 1. JIT compile logs
 * 2. Shader code
 * 3. Expressions and metadata
 * 4. JIT programs, tuning details
 * 5. Most verbose operation traces
 *
 * This is an experimental API and may change in behavior. Do not rely on this
 * in production.
 *
 * @param {number} level - New value stored into `DEBUG`.
 */
function setDebug(level) { DEBUG = level; }
39
54
  function unzip2(pairs) {
40
55
  const lst1 = [];
41
56
  const lst2 = [];
@@ -109,9 +124,23 @@ function isNumberPair(x) {
109
124
  }
110
125
  /** Check an axis against number of dimensions, and resolve negative axes. */
111
126
  function checkAxis(axis, ndim) {
112
- if (axis < -ndim || axis >= ndim) throw new Error(`Invalid axis ${axis} for array of ${ndim} dimensions`);
127
+ if (axis < -ndim || axis >= ndim) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
113
128
  return axis < 0 ? axis + ndim : axis;
114
129
  }
130
+ /** Normalize common axis argument for functions, defaulting to all axes. */
131
+ function normalizeAxis(axis, ndim) {
132
+ if (axis === null) return range(ndim);
133
+ else if (typeof axis === "number") return [checkAxis(axis, ndim)];
134
+ else {
135
+ const seen = /* @__PURE__ */ new Set();
136
+ for (const a of axis) {
137
+ const ca = checkAxis(a, ndim);
138
+ if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
139
+ seen.add(ca);
140
+ }
141
+ return [...seen].sort();
142
+ }
143
+ }
115
144
  function range(start, stop, step = 1) {
116
145
  if (stop === void 0) {
117
146
  stop = start;
@@ -186,6 +215,7 @@ function strip1(str) {
186
215
  if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
187
216
  return str;
188
217
  }
218
/**
 * Reusable 8-byte scratch view. `FpHash` writes a float64 into it and reads the
 * same bytes back as a little-endian uint64, avoiding a fresh allocation per value.
 */
const _stagingbuf = /* @__PURE__ */ new DataView(new ArrayBuffer(Float64Array.BYTES_PER_ELEMENT));
189
219
  /**
190
220
  * Polynomial hashes modulo p are good at avoiding collisions in expectation.
191
221
  * Probability-wise, it's good enough to be used for something like
@@ -200,22 +230,26 @@ var FpHash = class FpHash {
200
230
  const modulus = 3189051996290219n;
201
231
  this.value = (this.value * base + x) % modulus;
202
232
  }
203
- update(...values) {
204
- for (const x of values) if (typeof x === "string") for (const c of x) this.#update(BigInt(199 + c.charCodeAt(0)));
205
- else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
233
+ update(x) {
234
+ if (typeof x === "string") {
235
+ this.#update(BigInt(x.length));
236
+ for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
237
+ } else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
206
238
  else {
207
- const ar = new Float64Array([x]);
208
- this.#update(new DataView(ar.buffer).getBigUint64(0, true));
239
+ _stagingbuf.setFloat64(0, x, true);
240
+ this.#update(_stagingbuf.getBigUint64(0, true));
209
241
  }
210
242
  else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
211
243
  else if (typeof x === "bigint") this.#update(x ^ 71657401n);
212
244
  else if (x === null) this.#update(37832657n);
213
245
  else if (x === void 0) this.#update(18145117n);
214
- else if (typeof x === "object" && "hash" in x) x.hash(this);
246
+ else x.hash(this);
215
247
  return this;
216
248
  }
217
249
  static hash(...values) {
218
- return new FpHash().update(...values).value;
250
+ const h = new FpHash();
251
+ for (const x of values) h.update(x);
252
+ return h.value;
219
253
  }
220
254
  };
221
255
  /** Run a function while caching it inline inside a `Map`. */
@@ -250,6 +284,41 @@ const byteWidth = (dtype) => {
250
284
  }
251
285
  };
252
286
  const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
287
+ /**
288
+ * Promote two dtypes to their join according to the type lattice.
289
+ *
290
+ * When performing operations between arrays of different types, we need to
291
+ * promote both operands to a common type that can represent values from both
292
+ * input types. This follows JAX's type promotion rules.
293
+ *
294
+ * **Type lattice:**
295
+ * ```text
296
+ * bool -> uint32 -> int32 -> float16 -> float32
297
+ * weak f* --^
298
+ * ```
299
+ *
300
+ * The asterisk f* is a weak type used for JS number constants. When creating
301
+ * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
302
+ * any array they are first combined with.
303
+ *
304
+ * **Examples:**
305
+ * - `promoteTypes(bool, int32) → int32`
306
+ * - `promoteTypes(uint32, int32) → int32`
307
+ * - `promoteTypes(int32, float16) → float16`
308
+ * - `promoteTypes(float16, float32) → float32`
309
+ * - `promoteTypes(uint32, float32) → float32`
310
+ */
311
+ function promoteTypes(dtype1, dtype2) {
312
+ if (dtype1 === dtype2) return dtype1;
313
+ const rank = {
314
+ [DType.Bool]: 0,
315
+ [DType.Uint32]: 1,
316
+ [DType.Int32]: 2,
317
+ [DType.Float16]: 3,
318
+ [DType.Float32]: 4
319
+ };
320
+ return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
321
+ }
253
322
  function dtypedArray(dtype, data) {
254
323
  const { buffer, byteLength, byteOffset } = data;
255
324
  const length = byteLength / byteWidth(dtype);
@@ -319,6 +388,12 @@ var AluExp = class AluExp {
319
388
  static cos(a) {
320
389
  return new AluExp(AluOp.Cos, a.dtype, [a]);
321
390
  }
391
+ static asin(a) {
392
+ return new AluExp(AluOp.Asin, a.dtype, [a]);
393
+ }
394
+ static atan(a) {
395
+ return new AluExp(AluOp.Atan, a.dtype, [a]);
396
+ }
322
397
  static exp(a) {
323
398
  return new AluExp(AluOp.Exp, a.dtype, [a]);
324
399
  }
@@ -402,8 +477,11 @@ var AluExp = class AluExp {
402
477
  getHash() {
403
478
  if (this.#hash !== void 0) return this.#hash;
404
479
  const hasher = new FpHash();
405
- hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
406
- hasher.update(this.src.length, ...this.src);
480
+ hasher.update(this.op);
481
+ hasher.update(this.dtype);
482
+ hasher.update(JSON.stringify(this.arg));
483
+ hasher.update(this.src.length);
484
+ for (const s of this.src) hasher.update(s);
407
485
  this.#hash = hasher.value;
408
486
  return this.#hash;
409
487
  }
@@ -475,10 +553,16 @@ var AluExp = class AluExp {
475
553
  ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
476
554
  break;
477
555
  case AluOp.Sin:
478
- ret = [Math.sin(src[0].min), Math.sin(src[0].max)];
556
+ ret = [-1, 1];
479
557
  break;
480
558
  case AluOp.Cos:
481
- ret = [Math.cos(src[0].min), Math.cos(src[0].max)];
559
+ ret = [-1, 1];
560
+ break;
561
+ case AluOp.Asin:
562
+ ret = [-Math.PI / 2, Math.PI / 2];
563
+ break;
564
+ case AluOp.Atan:
565
+ ret = [-Math.PI / 2, Math.PI / 2];
482
566
  break;
483
567
  case AluOp.Exp:
484
568
  ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
@@ -594,11 +678,12 @@ var AluExp = class AluExp {
594
678
  simplify(cache = /* @__PURE__ */ new Map()) {
595
679
  if (this.#simplified !== void 0) return this.#simplified;
596
680
  const hash = this.getHash();
597
- if (cache.has(hash)) return this.#simplified = cache.get(hash);
681
+ const prevCachedValue = cache.get(hash);
682
+ if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
598
683
  const simplified = this.#simplifyInner(cache);
599
684
  const simplifiedHash = simplified.getHash();
600
- if (cache.has(simplifiedHash)) {
601
- const prevSimplified = cache.get(simplifiedHash);
685
+ const prevSimplified = cache.get(simplifiedHash);
686
+ if (prevSimplified !== void 0) {
602
687
  cache.set(hash, prevSimplified);
603
688
  this.#simplified = prevSimplified;
604
689
  return prevSimplified;
@@ -802,6 +887,8 @@ var AluExp = class AluExp {
802
887
  switch (this.op) {
803
888
  case AluOp.Sin: return Math.sin(x);
804
889
  case AluOp.Cos: return Math.cos(x);
890
+ case AluOp.Asin: return Math.asin(x);
891
+ case AluOp.Atan: return Math.atan(x);
805
892
  case AluOp.Exp: return Math.exp(x);
806
893
  case AluOp.Log: return Math.log(x);
807
894
  case AluOp.Sqrt: return Math.sqrt(x);
@@ -981,6 +1068,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
981
1068
  AluOp$1["Max"] = "Max";
982
1069
  AluOp$1["Sin"] = "Sin";
983
1070
  AluOp$1["Cos"] = "Cos";
1071
+ AluOp$1["Asin"] = "Asin";
1072
+ AluOp$1["Atan"] = "Atan";
984
1073
  AluOp$1["Exp"] = "Exp";
985
1074
  AluOp$1["Log"] = "Log";
986
1075
  AluOp$1["Sqrt"] = "Sqrt";
@@ -1011,6 +1100,8 @@ const AluGroup = {
1011
1100
  Unary: new Set([
1012
1101
  AluOp.Sin,
1013
1102
  AluOp.Cos,
1103
+ AluOp.Asin,
1104
+ AluOp.Atan,
1014
1105
  AluOp.Exp,
1015
1106
  AluOp.Log,
1016
1107
  AluOp.Sqrt,
@@ -1034,6 +1125,8 @@ const AluGroup = {
1034
1125
  RequiredFloat: new Set([
1035
1126
  AluOp.Sin,
1036
1127
  AluOp.Cos,
1128
+ AluOp.Asin,
1129
+ AluOp.Atan,
1037
1130
  AluOp.Exp,
1038
1131
  AluOp.Log,
1039
1132
  AluOp.Sqrt,
@@ -1065,7 +1158,7 @@ var Kernel = class {
1065
1158
  this.exp = exp.simplify();
1066
1159
  }
1067
1160
  hash(state) {
1068
- state.update(this.nargs, this.size, this.exp, this.reduction);
1161
+ state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
1069
1162
  }
1070
1163
  pprint() {
1071
1164
  let details = PPrint.pp(`exp = ${this.exp}`);
@@ -1111,7 +1204,7 @@ var Reduction = class {
1111
1204
  this.epilogue = epilogue.simplify();
1112
1205
  }
1113
1206
  hash(state) {
1114
- state.update(this.dtype, this.op, this.size, this.epilogue);
1207
+ state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
1115
1208
  }
1116
1209
  toString() {
1117
1210
  return `${this.op}{${this.size}} -> ${this.epilogue}`;
@@ -2283,78 +2376,92 @@ function wasm_log(cg) {
2283
2376
  });
2284
2377
  }
2285
2378
  /**
2286
- * Approximate sin(x).
2379
+ * Common helper to approximate sin(x) and cos(x).
2287
2380
  *
2288
2381
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2289
- * z = y - q*(π/2); use odd polynomial on z:
2382
+ * z = y - q*(π/2); use one of two polynomials on z:
2290
2383
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2384
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2385
+ */
2386
+ function _sincos(cg) {
2387
+ const y = cg.local.declare(cg.f32);
2388
+ const qf = cg.local.declare(cg.f32);
2389
+ const q = cg.local.declare(cg.i32);
2390
+ const z = cg.local.declare(cg.f32);
2391
+ const z2 = cg.local.declare(cg.f32);
2392
+ const sz = cg.local.declare(cg.f32);
2393
+ const cz = cg.local.declare(cg.f32);
2394
+ cg.local.get(0);
2395
+ cg.local.get(0);
2396
+ cg.f32.const(1 / (2 * Math.PI));
2397
+ cg.f32.mul();
2398
+ cg.f32.nearest();
2399
+ cg.local.tee(qf);
2400
+ cg.f32.const(2 * Math.PI);
2401
+ cg.f32.mul();
2402
+ cg.f32.sub();
2403
+ cg.local.set(y);
2404
+ cg.local.get(y);
2405
+ cg.f32.const(2 / Math.PI);
2406
+ cg.f32.mul();
2407
+ cg.f32.nearest();
2408
+ cg.local.tee(qf);
2409
+ cg.i32.trunc_f32_s();
2410
+ cg.local.set(q);
2411
+ cg.local.get(y);
2412
+ cg.local.get(qf);
2413
+ cg.f32.const(Math.PI / 2);
2414
+ cg.f32.mul();
2415
+ cg.f32.sub();
2416
+ cg.local.tee(z);
2417
+ cg.local.get(z);
2418
+ cg.f32.mul();
2419
+ cg.local.set(z2);
2420
+ cg.f32.const(-1 / 5040);
2421
+ cg.local.get(z2);
2422
+ cg.f32.mul();
2423
+ cg.f32.const(1 / 120);
2424
+ cg.f32.add();
2425
+ cg.local.get(z2);
2426
+ cg.f32.mul();
2427
+ cg.f32.const(-1 / 6);
2428
+ cg.f32.add();
2429
+ cg.local.get(z2);
2430
+ cg.f32.mul();
2431
+ cg.f32.const(1);
2432
+ cg.f32.add();
2433
+ cg.local.get(z);
2434
+ cg.f32.mul();
2435
+ cg.local.set(sz);
2436
+ cg.f32.const(-1 / 720);
2437
+ cg.local.get(z2);
2438
+ cg.f32.mul();
2439
+ cg.f32.const(1 / 24);
2440
+ cg.f32.add();
2441
+ cg.local.get(z2);
2442
+ cg.f32.mul();
2443
+ cg.f32.const(-1 / 2);
2444
+ cg.f32.add();
2445
+ cg.local.get(z2);
2446
+ cg.f32.mul();
2447
+ cg.f32.const(1);
2448
+ cg.f32.add();
2449
+ cg.local.set(cz);
2450
+ return {
2451
+ q,
2452
+ sz,
2453
+ cz
2454
+ };
2455
+ }
2456
+ /**
2457
+ * Approximate sin(x).
2458
+ *
2459
+ * Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
2291
2460
  */
2292
2461
  function wasm_sin(cg) {
2293
2462
  return cg.function([cg.f32], [cg.f32], () => {
2294
- const y = cg.local.declare(cg.f32);
2295
- const qf = cg.local.declare(cg.f32);
2296
- const q = cg.local.declare(cg.i32);
2297
- const z = cg.local.declare(cg.f32);
2298
- const z2 = cg.local.declare(cg.f32);
2299
- const sz = cg.local.declare(cg.f32);
2300
- const cz = cg.local.declare(cg.f32);
2463
+ const { q, sz, cz } = _sincos(cg);
2301
2464
  const mag = cg.local.declare(cg.f32);
2302
- cg.local.get(0);
2303
- cg.local.get(0);
2304
- cg.f32.const(1 / (2 * Math.PI));
2305
- cg.f32.mul();
2306
- cg.f32.nearest();
2307
- cg.local.tee(qf);
2308
- cg.f32.const(2 * Math.PI);
2309
- cg.f32.mul();
2310
- cg.f32.sub();
2311
- cg.local.set(y);
2312
- cg.local.get(y);
2313
- cg.f32.const(2 / Math.PI);
2314
- cg.f32.mul();
2315
- cg.f32.nearest();
2316
- cg.local.tee(qf);
2317
- cg.i32.trunc_f32_s();
2318
- cg.local.set(q);
2319
- cg.local.get(y);
2320
- cg.local.get(qf);
2321
- cg.f32.const(Math.PI / 2);
2322
- cg.f32.mul();
2323
- cg.f32.sub();
2324
- cg.local.tee(z);
2325
- cg.local.get(z);
2326
- cg.f32.mul();
2327
- cg.local.set(z2);
2328
- cg.f32.const(-1 / 5040);
2329
- cg.local.get(z2);
2330
- cg.f32.mul();
2331
- cg.f32.const(1 / 120);
2332
- cg.f32.add();
2333
- cg.local.get(z2);
2334
- cg.f32.mul();
2335
- cg.f32.const(-1 / 6);
2336
- cg.f32.add();
2337
- cg.local.get(z2);
2338
- cg.f32.mul();
2339
- cg.f32.const(1);
2340
- cg.f32.add();
2341
- cg.local.get(z);
2342
- cg.f32.mul();
2343
- cg.local.set(sz);
2344
- cg.f32.const(-1 / 720);
2345
- cg.local.get(z2);
2346
- cg.f32.mul();
2347
- cg.f32.const(1 / 24);
2348
- cg.f32.add();
2349
- cg.local.get(z2);
2350
- cg.f32.mul();
2351
- cg.f32.const(-1 / 2);
2352
- cg.f32.add();
2353
- cg.local.get(z2);
2354
- cg.f32.mul();
2355
- cg.f32.const(1);
2356
- cg.f32.add();
2357
- cg.local.set(cz);
2358
2465
  cg.local.get(cz);
2359
2466
  cg.local.get(sz);
2360
2467
  cg.local.get(q);
@@ -2373,75 +2480,12 @@ function wasm_sin(cg) {
2373
2480
  /**
2374
2481
  * Approximate cos(x).
2375
2482
  *
2376
- * Same reduction as sinf, then quadrant mapping:
2377
- * k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2483
+ * Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2378
2484
  */
2379
2485
  function wasm_cos(cg) {
2380
2486
  return cg.function([cg.f32], [cg.f32], () => {
2381
- const y = cg.local.declare(cg.f32);
2382
- const qf = cg.local.declare(cg.f32);
2383
- const q = cg.local.declare(cg.i32);
2384
- const z = cg.local.declare(cg.f32);
2385
- const z2 = cg.local.declare(cg.f32);
2386
- const sz = cg.local.declare(cg.f32);
2387
- const cz = cg.local.declare(cg.f32);
2487
+ const { q, sz, cz } = _sincos(cg);
2388
2488
  const mag = cg.local.declare(cg.f32);
2389
- cg.local.get(0);
2390
- cg.local.get(0);
2391
- cg.f32.const(1 / (2 * Math.PI));
2392
- cg.f32.mul();
2393
- cg.f32.nearest();
2394
- cg.local.tee(qf);
2395
- cg.f32.const(2 * Math.PI);
2396
- cg.f32.mul();
2397
- cg.f32.sub();
2398
- cg.local.set(y);
2399
- cg.local.get(y);
2400
- cg.f32.const(2 / Math.PI);
2401
- cg.f32.mul();
2402
- cg.f32.nearest();
2403
- cg.local.tee(qf);
2404
- cg.i32.trunc_f32_s();
2405
- cg.local.set(q);
2406
- cg.local.get(y);
2407
- cg.local.get(qf);
2408
- cg.f32.const(Math.PI / 2);
2409
- cg.f32.mul();
2410
- cg.f32.sub();
2411
- cg.local.tee(z);
2412
- cg.local.get(z);
2413
- cg.f32.mul();
2414
- cg.local.set(z2);
2415
- cg.f32.const(-1 / 5040);
2416
- cg.local.get(z2);
2417
- cg.f32.mul();
2418
- cg.f32.const(1 / 120);
2419
- cg.f32.add();
2420
- cg.local.get(z2);
2421
- cg.f32.mul();
2422
- cg.f32.const(-1 / 6);
2423
- cg.f32.add();
2424
- cg.local.get(z2);
2425
- cg.f32.mul();
2426
- cg.f32.const(1);
2427
- cg.f32.add();
2428
- cg.local.get(z);
2429
- cg.f32.mul();
2430
- cg.local.set(sz);
2431
- cg.f32.const(-1 / 720);
2432
- cg.local.get(z2);
2433
- cg.f32.mul();
2434
- cg.f32.const(1 / 24);
2435
- cg.f32.add();
2436
- cg.local.get(z2);
2437
- cg.f32.mul();
2438
- cg.f32.const(-1 / 2);
2439
- cg.f32.add();
2440
- cg.local.get(z2);
2441
- cg.f32.mul();
2442
- cg.f32.const(1);
2443
- cg.f32.add();
2444
- cg.local.set(cz);
2445
2489
  cg.local.get(sz);
2446
2490
  cg.local.get(cz);
2447
2491
  cg.local.get(q);
@@ -2459,6 +2503,100 @@ function wasm_cos(cg) {
2459
2503
  cg.select();
2460
2504
  });
2461
2505
  }
2506
/**
 * Helper for approximating arctan(x) in generated wasm code.
 *
 * Pops the f32 on top of the value stack and pushes an f32 approximation of
 * its arctangent. With z = |x| (or z = 1/|x| when |x| >= 1), it computes
 * p = z * P(z^2) / Q(z^2). The result is p, or π/2 - p in the reciprocal
 * case, and the sign of x is restored with copysign.
 */
function _atan(cg) {
  const x = cg.local.declare(cg.f32);
  const absX = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);

  // Push the select condition |x| >= 1 (the reciprocal branch).
  const pushIsLarge = () => {
    cg.local.get(absX);
    cg.f32.const(1);
    cg.f32.ge();
  };
  // Push c[0]*u^(n-1) + ... + c[n-1] with u = z2, using Horner's rule.
  const hornerZ2 = (coeffs) => {
    cg.f32.const(coeffs[0]);
    for (const c of coeffs.slice(1)) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(c);
      cg.f32.add();
    }
  };

  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(absX);

  // z = |x| >= 1 ? 1/|x| : |x|
  cg.f32.const(1);
  cg.local.get(absX);
  cg.f32.div();
  cg.local.get(absX);
  pushIsLarge();
  cg.select();
  cg.local.set(z);

  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // p = z * P(z2) / Q(z2)
  hornerZ2([0.0415796528637, 0.661705427875, 0.999998614341]);
  hornerZ2([0.173698870181, 0.994987933645, 1]);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);

  // result = copysign(|x| >= 1 ? π/2 - p : p, x)
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  pushIsLarge();
  cg.select();
  cg.local.get(x);
  cg.f32.copysign();
}
2563
/**
 * Emit a wasm function `(f32) -> f32` approximating atan(x).
 *
 * Method: if |x| < 1, use the rational approximation atan(x) ≈ x * P(x^2) / Q(x^2)
 *   where P(u) = A0 + A1*u + A2*u^2 (degree 2)
 *         Q(u) = 1 + B1*u + B2*u^2  (degree 2)
 * If |x| >= 1, use atan(x) = sign(x)*π/2 - atan(1/x).
 * (Fitted coefficients, max error ~5e-7 on [0,1].) The arithmetic itself is
 * emitted by `_atan`; this wrapper only feeds it the parameter.
 */
function wasm_atan(cg) {
  const emitBody = () => {
    cg.local.get(0); // the single f32 parameter
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2578
/**
 * Emit a wasm function `(f32) -> f32` approximating asin(x).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))), with atan emitted by
 * `_atan`. Inputs with |x| > 1 make the square root NaN, so the result is NaN.
 */
function wasm_asin(cg) {
  const param = 0; // local index of the f32 argument
  // Push 1 + sqrt(1 - x*x).
  const emitDenominator = () => {
    cg.f32.const(1);
    cg.local.get(param);
    cg.local.get(param);
    cg.f32.mul();
    cg.f32.sub();
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();
  };
  const emitBody = () => {
    cg.local.get(param);
    emitDenominator();
    cg.f32.div();
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2462
2600
  /**
2463
2601
  * Threefry2x32 pseudorandom number generator.
2464
2602
  *
@@ -3338,6 +3476,8 @@ function codegenWasm(kernel) {
3338
3476
  const funcs = {};
3339
3477
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3340
3478
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3479
+ if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3480
+ if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3341
3481
  if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3342
3482
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3343
3483
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
@@ -3489,6 +3629,8 @@ function translateExp(cg, funcs, exp, ctx) {
3489
3629
  else throw new UnsupportedOpError(op, dtype, "wasm");
3490
3630
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3491
3631
  else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3632
+ else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3633
+ else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3492
3634
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3493
3635
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3494
3636
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
@@ -3587,10 +3729,11 @@ let defaultBackend = "wasm";
3587
3729
  const initializedBackends = /* @__PURE__ */ new Map();
3588
3730
  initializedBackends.set("cpu", new CpuBackend());
3589
3731
  initializedBackends.set("wasm", new WasmBackend());
3590
/**
 * Configure the default device for arrays.
 *
 * Called with no argument, it only reports the current default. Called with a
 * device name, that backend must already be initialized and becomes the new
 * default.
 *
 * @param {string} [device] - Backend to make the default.
 * @returns {string} The default backend after the call.
 * @throws {Error} If `device` is given but that backend is not initialized.
 */
function defaultDevice(device) {
  if (device === void 0) return defaultBackend;
  if (!initializedBackends.has(device)) {
    throw new Error(`Backend not initialized: ${device}`);
  }
  defaultBackend = device;
  return defaultBackend;
}
3595
3738
  /**
3596
3739
  * Initialize `jax-js` library backends.
@@ -3617,7 +3760,7 @@ async function createBackend(device) {
3617
3760
  if (!navigator.gpu) return null;
3618
3761
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3619
3762
  if (!adapter) return null;
3620
- const { WebGPUBackend } = await import("./webgpu-CNg9JGva.js");
3763
+ const { WebGPUBackend } = await import("./webgpu-ow0Pn_6q.js");
3621
3764
  const importantLimits = [
3622
3765
  "maxBufferSize",
3623
3766
  "maxComputeInvocationsPerWorkgroup",
@@ -3670,4 +3813,4 @@ var UnsupportedOpError = class extends Error {
3670
3813
  };
3671
3814
 
3672
3815
  //#endregion
3673
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
3816
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };