@jax-js/jax 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -35,7 +35,22 @@ var PPrint = class PPrint {
35
35
  //#endregion
36
36
  //#region src/utils.ts
37
37
  /** @file Generic programming utilities with no dependencies on library code. */
38
let DEBUG = 0;
/**
 * Set the debug level for verbose logging. The level defaults to 0.
 *
 * Levels:
 * 1. JIT compile logs
 * 2. Shader code
 * 3. Expressions and metadata
 * 4. JIT programs, tuning details
 * 5. Most verbose operation traces
 *
 * This is an experimental API and may change in behavior. Do not rely on this
 * in production.
 *
 * @param {number} level - New debug level, stored as-is in `DEBUG`.
 */
function setDebug(level) {
  // `DEBUG` is an exported `let`, so importers see the new value through the
  // ES module live binding.
  DEBUG = level;
}
39
54
  function unzip2(pairs) {
40
55
  const lst1 = [];
41
56
  const lst2 = [];
@@ -109,9 +124,23 @@ function isNumberPair(x) {
109
124
  }
110
125
  /** Check an axis against number of dimensions, and resolve negative axes. */
111
126
  function checkAxis(axis, ndim) {
112
- if (axis < -ndim || axis >= ndim) throw new Error(`Invalid axis ${axis} for array of ${ndim} dimensions`);
127
+ if (axis < -ndim || axis >= ndim) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
113
128
  return axis < 0 ? axis + ndim : axis;
114
129
  }
130
/**
 * Normalize the common `axis` argument of array functions.
 *
 * @param {number | Iterable<number> | null | undefined} axis - `null` (or
 *   `undefined`) selects every axis; a single number or an iterable of numbers
 *   selects specific axes. Negative values count from the end.
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number[]} Distinct, resolved axes in ascending numeric order.
 * @throws {Error} If an axis is out of bounds, or the same axis appears twice.
 */
function normalizeAxis(axis, ndim) {
  // Treat a missing JS argument the same as an explicit `null`.
  if (axis === null || axis === undefined) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // A numeric comparator is required: the default sort compares strings, so
  // [10, 2] would stay [10, 2] for arrays with more than 10 dimensions.
  return [...seen].sort((x, y) => x - y);
}
115
144
  function range(start, stop, step = 1) {
116
145
  if (stop === void 0) {
117
146
  stop = start;
@@ -177,6 +206,35 @@ function findPow2(hint, max) {
177
206
  while (ret < hint && 2 * ret <= max) ret *= 2;
178
207
  return ret;
179
208
  }
209
/**
 * Broadcast two array shapes using NumPy's general broadcasting rule.
 *
 * The shapes are aligned at their trailing (rightmost) dimension. Each aligned
 * pair of dimensions must be equal, or one of them must be 1, which stretches
 * to match the other. Leading dimensions present in only one shape are copied
 * through unchanged.
 *
 * <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
 *
 * @param {number[]} a - First shape.
 * @param {number[]} b - Second shape.
 * @returns {number[]} A new array holding the broadcast shape.
 * @throws {TypeError} If any aligned pair of dimensions is incompatible.
 */
function generalBroadcast(a, b) {
  const shared = Math.min(a.length, b.length);
  const offsetA = a.length - shared;
  const offsetB = b.length - shared;
  const trailing = [];
  for (let k = 0; k < shared; k++) {
    const da = a[offsetA + k];
    const db = b[offsetB + k];
    if (da === db || db === 1) trailing.push(da);
    else if (da === 1) trailing.push(db);
    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
  }
  // At most one of these prefixes is non-empty.
  return [...a.slice(0, offsetA), ...b.slice(0, offsetB), ...trailing];
}
180
238
  function recursiveFlatten(ar) {
181
239
  if (!Array.isArray(ar)) return [ar];
182
240
  return ar.flat(Infinity);
@@ -186,6 +244,7 @@ function strip1(str) {
186
244
  if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
187
245
  return str;
188
246
  }
247
/**
 * Reusable 8-byte scratch view.
 *
 * `FpHash` uses it to read a float64's raw bit pattern as a uint64 without
 * allocating a new buffer for every value.
 */
const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(Float64Array.BYTES_PER_ELEMENT));
189
248
  /**
190
249
  * Polynomial hashes modulo p are good at avoiding collisions in expectation.
191
250
  * Probability-wise, it's good enough to be used for something like
@@ -200,22 +259,26 @@ var FpHash = class FpHash {
200
259
  const modulus = 3189051996290219n;
201
260
  this.value = (this.value * base + x) % modulus;
202
261
  }
203
- update(...values) {
204
- for (const x of values) if (typeof x === "string") for (const c of x) this.#update(BigInt(199 + c.charCodeAt(0)));
205
- else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
262
+ update(x) {
263
+ if (typeof x === "string") {
264
+ this.#update(BigInt(x.length));
265
+ for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
266
+ } else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
206
267
  else {
207
- const ar = new Float64Array([x]);
208
- this.#update(new DataView(ar.buffer).getBigUint64(0, true));
268
+ _stagingbuf.setFloat64(0, x, true);
269
+ this.#update(_stagingbuf.getBigUint64(0, true));
209
270
  }
210
271
  else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
211
272
  else if (typeof x === "bigint") this.#update(x ^ 71657401n);
212
273
  else if (x === null) this.#update(37832657n);
213
274
  else if (x === void 0) this.#update(18145117n);
214
- else if (typeof x === "object" && "hash" in x) x.hash(this);
275
+ else x.hash(this);
215
276
  return this;
216
277
  }
217
278
  static hash(...values) {
218
- return new FpHash().update(...values).value;
279
+ const h = new FpHash();
280
+ for (const x of values) h.update(x);
281
+ return h.value;
219
282
  }
220
283
  };
221
284
  /** Run a function while caching it inline inside a `Map`. */
@@ -250,6 +313,41 @@ const byteWidth = (dtype) => {
250
313
  }
251
314
  };
252
315
  const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
316
/**
 * Promote two dtypes to their join in the type lattice.
 *
 * When an operation mixes arrays of different dtypes, both operands are
 * promoted to a common dtype that can represent values of either input. This
 * follows JAX's type promotion rules.
 *
 * **Type lattice:**
 * ```text
 * bool -> uint32 -> int32 -> float16 -> float32
 *         weakType --^
 * ```
 *
 * `weakType` stands for weakly typed arrays created from JS numbers. They
 * default to float32 but take on the dtype of any non-`bool` array they are
 * first combined with. This function takes no weak-type flag and only ranks
 * the five concrete dtypes below. NOTE(review): weak-type handling presumably
 * happens in callers — confirm.
 *
 * A dtype missing from the ranking table never wins the comparison, so
 * `dtype2` is returned in that case.
 *
 * **Examples:**
 * - `promoteTypes(bool, int32) → int32`
 * - `promoteTypes(uint32, int32) → int32`
 * - `promoteTypes(int32, float16) → float16`
 * - `promoteTypes(float16, float32) → float32`
 * - `promoteTypes(uint32, float32) → float32`
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  const lattice = {
    [DType.Bool]: 0,
    [DType.Uint32]: 1,
    [DType.Int32]: 2,
    [DType.Float16]: 3,
    [DType.Float32]: 4
  };
  const level1 = lattice[dtype1];
  const level2 = lattice[dtype2];
  if (level1 > level2) return dtype1;
  return dtype2;
}
253
351
  function dtypedArray(dtype, data) {
254
352
  const { buffer, byteLength, byteOffset } = data;
255
353
  const length = byteLength / byteWidth(dtype);
@@ -319,6 +417,12 @@ var AluExp = class AluExp {
319
417
  static cos(a) {
320
418
  return new AluExp(AluOp.Cos, a.dtype, [a]);
321
419
  }
420
+ static asin(a) {
421
+ return new AluExp(AluOp.Asin, a.dtype, [a]);
422
+ }
423
+ static atan(a) {
424
+ return new AluExp(AluOp.Atan, a.dtype, [a]);
425
+ }
322
426
  static exp(a) {
323
427
  return new AluExp(AluOp.Exp, a.dtype, [a]);
324
428
  }
@@ -402,8 +506,11 @@ var AluExp = class AluExp {
402
506
  getHash() {
403
507
  if (this.#hash !== void 0) return this.#hash;
404
508
  const hasher = new FpHash();
405
- hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
406
- hasher.update(this.src.length, ...this.src);
509
+ hasher.update(this.op);
510
+ hasher.update(this.dtype);
511
+ hasher.update(JSON.stringify(this.arg));
512
+ hasher.update(this.src.length);
513
+ for (const s of this.src) hasher.update(s);
407
514
  this.#hash = hasher.value;
408
515
  return this.#hash;
409
516
  }
@@ -475,10 +582,16 @@ var AluExp = class AluExp {
475
582
  ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
476
583
  break;
477
584
  case AluOp.Sin:
478
- ret = [Math.sin(src[0].min), Math.sin(src[0].max)];
585
+ ret = [-1, 1];
479
586
  break;
480
587
  case AluOp.Cos:
481
- ret = [Math.cos(src[0].min), Math.cos(src[0].max)];
588
+ ret = [-1, 1];
589
+ break;
590
+ case AluOp.Asin:
591
+ ret = [-Math.PI / 2, Math.PI / 2];
592
+ break;
593
+ case AluOp.Atan:
594
+ ret = [-Math.PI / 2, Math.PI / 2];
482
595
  break;
483
596
  case AluOp.Exp:
484
597
  ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
@@ -594,11 +707,12 @@ var AluExp = class AluExp {
594
707
  simplify(cache = /* @__PURE__ */ new Map()) {
595
708
  if (this.#simplified !== void 0) return this.#simplified;
596
709
  const hash = this.getHash();
597
- if (cache.has(hash)) return this.#simplified = cache.get(hash);
710
+ const prevCachedValue = cache.get(hash);
711
+ if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
598
712
  const simplified = this.#simplifyInner(cache);
599
713
  const simplifiedHash = simplified.getHash();
600
- if (cache.has(simplifiedHash)) {
601
- const prevSimplified = cache.get(simplifiedHash);
714
+ const prevSimplified = cache.get(simplifiedHash);
715
+ if (prevSimplified !== void 0) {
602
716
  cache.set(hash, prevSimplified);
603
717
  this.#simplified = prevSimplified;
604
718
  return prevSimplified;
@@ -802,6 +916,8 @@ var AluExp = class AluExp {
802
916
  switch (this.op) {
803
917
  case AluOp.Sin: return Math.sin(x);
804
918
  case AluOp.Cos: return Math.cos(x);
919
+ case AluOp.Asin: return Math.asin(x);
920
+ case AluOp.Atan: return Math.atan(x);
805
921
  case AluOp.Exp: return Math.exp(x);
806
922
  case AluOp.Log: return Math.log(x);
807
923
  case AluOp.Sqrt: return Math.sqrt(x);
@@ -981,6 +1097,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
981
1097
  AluOp$1["Max"] = "Max";
982
1098
  AluOp$1["Sin"] = "Sin";
983
1099
  AluOp$1["Cos"] = "Cos";
1100
+ AluOp$1["Asin"] = "Asin";
1101
+ AluOp$1["Atan"] = "Atan";
984
1102
  AluOp$1["Exp"] = "Exp";
985
1103
  AluOp$1["Log"] = "Log";
986
1104
  AluOp$1["Sqrt"] = "Sqrt";
@@ -1011,6 +1129,8 @@ const AluGroup = {
1011
1129
  Unary: new Set([
1012
1130
  AluOp.Sin,
1013
1131
  AluOp.Cos,
1132
+ AluOp.Asin,
1133
+ AluOp.Atan,
1014
1134
  AluOp.Exp,
1015
1135
  AluOp.Log,
1016
1136
  AluOp.Sqrt,
@@ -1034,6 +1154,8 @@ const AluGroup = {
1034
1154
  RequiredFloat: new Set([
1035
1155
  AluOp.Sin,
1036
1156
  AluOp.Cos,
1157
+ AluOp.Asin,
1158
+ AluOp.Atan,
1037
1159
  AluOp.Exp,
1038
1160
  AluOp.Log,
1039
1161
  AluOp.Sqrt,
@@ -1065,7 +1187,7 @@ var Kernel = class {
1065
1187
  this.exp = exp.simplify();
1066
1188
  }
1067
1189
  hash(state) {
1068
- state.update(this.nargs, this.size, this.exp, this.reduction);
1190
+ state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
1069
1191
  }
1070
1192
  pprint() {
1071
1193
  let details = PPrint.pp(`exp = ${this.exp}`);
@@ -1111,7 +1233,7 @@ var Reduction = class {
1111
1233
  this.epilogue = epilogue.simplify();
1112
1234
  }
1113
1235
  hash(state) {
1114
- state.update(this.dtype, this.op, this.size, this.epilogue);
1236
+ state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
1115
1237
  }
1116
1238
  toString() {
1117
1239
  return `${this.op}{${this.size}} -> ${this.epilogue}`;
@@ -2283,78 +2405,92 @@ function wasm_log(cg) {
2283
2405
  });
2284
2406
  }
2285
2407
  /**
2286
- * Approximate sin(x).
2408
+ * Common helper to approximate sin(x) and cos(x).
2287
2409
  *
2288
2410
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2289
- * z = y - q*(π/2); use odd polynomial on z:
2411
+ * z = y - q*(π/2); use one of two polynomials on z:
2290
2412
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2413
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2414
+ */
2415
+ function _sincos(cg) {
2416
+ const y = cg.local.declare(cg.f32);
2417
+ const qf = cg.local.declare(cg.f32);
2418
+ const q = cg.local.declare(cg.i32);
2419
+ const z = cg.local.declare(cg.f32);
2420
+ const z2 = cg.local.declare(cg.f32);
2421
+ const sz = cg.local.declare(cg.f32);
2422
+ const cz = cg.local.declare(cg.f32);
2423
+ cg.local.get(0);
2424
+ cg.local.get(0);
2425
+ cg.f32.const(1 / (2 * Math.PI));
2426
+ cg.f32.mul();
2427
+ cg.f32.nearest();
2428
+ cg.local.tee(qf);
2429
+ cg.f32.const(2 * Math.PI);
2430
+ cg.f32.mul();
2431
+ cg.f32.sub();
2432
+ cg.local.set(y);
2433
+ cg.local.get(y);
2434
+ cg.f32.const(2 / Math.PI);
2435
+ cg.f32.mul();
2436
+ cg.f32.nearest();
2437
+ cg.local.tee(qf);
2438
+ cg.i32.trunc_f32_s();
2439
+ cg.local.set(q);
2440
+ cg.local.get(y);
2441
+ cg.local.get(qf);
2442
+ cg.f32.const(Math.PI / 2);
2443
+ cg.f32.mul();
2444
+ cg.f32.sub();
2445
+ cg.local.tee(z);
2446
+ cg.local.get(z);
2447
+ cg.f32.mul();
2448
+ cg.local.set(z2);
2449
+ cg.f32.const(-1 / 5040);
2450
+ cg.local.get(z2);
2451
+ cg.f32.mul();
2452
+ cg.f32.const(1 / 120);
2453
+ cg.f32.add();
2454
+ cg.local.get(z2);
2455
+ cg.f32.mul();
2456
+ cg.f32.const(-1 / 6);
2457
+ cg.f32.add();
2458
+ cg.local.get(z2);
2459
+ cg.f32.mul();
2460
+ cg.f32.const(1);
2461
+ cg.f32.add();
2462
+ cg.local.get(z);
2463
+ cg.f32.mul();
2464
+ cg.local.set(sz);
2465
+ cg.f32.const(-1 / 720);
2466
+ cg.local.get(z2);
2467
+ cg.f32.mul();
2468
+ cg.f32.const(1 / 24);
2469
+ cg.f32.add();
2470
+ cg.local.get(z2);
2471
+ cg.f32.mul();
2472
+ cg.f32.const(-1 / 2);
2473
+ cg.f32.add();
2474
+ cg.local.get(z2);
2475
+ cg.f32.mul();
2476
+ cg.f32.const(1);
2477
+ cg.f32.add();
2478
+ cg.local.set(cz);
2479
+ return {
2480
+ q,
2481
+ sz,
2482
+ cz
2483
+ };
2484
+ }
2485
+ /**
2486
+ * Approximate sin(x).
2487
+ *
2488
+ * Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
2291
2489
  */
2292
2490
  function wasm_sin(cg) {
2293
2491
  return cg.function([cg.f32], [cg.f32], () => {
2294
- const y = cg.local.declare(cg.f32);
2295
- const qf = cg.local.declare(cg.f32);
2296
- const q = cg.local.declare(cg.i32);
2297
- const z = cg.local.declare(cg.f32);
2298
- const z2 = cg.local.declare(cg.f32);
2299
- const sz = cg.local.declare(cg.f32);
2300
- const cz = cg.local.declare(cg.f32);
2492
+ const { q, sz, cz } = _sincos(cg);
2301
2493
  const mag = cg.local.declare(cg.f32);
2302
- cg.local.get(0);
2303
- cg.local.get(0);
2304
- cg.f32.const(1 / (2 * Math.PI));
2305
- cg.f32.mul();
2306
- cg.f32.nearest();
2307
- cg.local.tee(qf);
2308
- cg.f32.const(2 * Math.PI);
2309
- cg.f32.mul();
2310
- cg.f32.sub();
2311
- cg.local.set(y);
2312
- cg.local.get(y);
2313
- cg.f32.const(2 / Math.PI);
2314
- cg.f32.mul();
2315
- cg.f32.nearest();
2316
- cg.local.tee(qf);
2317
- cg.i32.trunc_f32_s();
2318
- cg.local.set(q);
2319
- cg.local.get(y);
2320
- cg.local.get(qf);
2321
- cg.f32.const(Math.PI / 2);
2322
- cg.f32.mul();
2323
- cg.f32.sub();
2324
- cg.local.tee(z);
2325
- cg.local.get(z);
2326
- cg.f32.mul();
2327
- cg.local.set(z2);
2328
- cg.f32.const(-1 / 5040);
2329
- cg.local.get(z2);
2330
- cg.f32.mul();
2331
- cg.f32.const(1 / 120);
2332
- cg.f32.add();
2333
- cg.local.get(z2);
2334
- cg.f32.mul();
2335
- cg.f32.const(-1 / 6);
2336
- cg.f32.add();
2337
- cg.local.get(z2);
2338
- cg.f32.mul();
2339
- cg.f32.const(1);
2340
- cg.f32.add();
2341
- cg.local.get(z);
2342
- cg.f32.mul();
2343
- cg.local.set(sz);
2344
- cg.f32.const(-1 / 720);
2345
- cg.local.get(z2);
2346
- cg.f32.mul();
2347
- cg.f32.const(1 / 24);
2348
- cg.f32.add();
2349
- cg.local.get(z2);
2350
- cg.f32.mul();
2351
- cg.f32.const(-1 / 2);
2352
- cg.f32.add();
2353
- cg.local.get(z2);
2354
- cg.f32.mul();
2355
- cg.f32.const(1);
2356
- cg.f32.add();
2357
- cg.local.set(cz);
2358
2494
  cg.local.get(cz);
2359
2495
  cg.local.get(sz);
2360
2496
  cg.local.get(q);
@@ -2373,75 +2509,12 @@ function wasm_sin(cg) {
2373
2509
  /**
2374
2510
  * Approximate cos(x).
2375
2511
  *
2376
- * Same reduction as sinf, then quadrant mapping:
2377
- * k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2512
+ * Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2378
2513
  */
2379
2514
  function wasm_cos(cg) {
2380
2515
  return cg.function([cg.f32], [cg.f32], () => {
2381
- const y = cg.local.declare(cg.f32);
2382
- const qf = cg.local.declare(cg.f32);
2383
- const q = cg.local.declare(cg.i32);
2384
- const z = cg.local.declare(cg.f32);
2385
- const z2 = cg.local.declare(cg.f32);
2386
- const sz = cg.local.declare(cg.f32);
2387
- const cz = cg.local.declare(cg.f32);
2516
+ const { q, sz, cz } = _sincos(cg);
2388
2517
  const mag = cg.local.declare(cg.f32);
2389
- cg.local.get(0);
2390
- cg.local.get(0);
2391
- cg.f32.const(1 / (2 * Math.PI));
2392
- cg.f32.mul();
2393
- cg.f32.nearest();
2394
- cg.local.tee(qf);
2395
- cg.f32.const(2 * Math.PI);
2396
- cg.f32.mul();
2397
- cg.f32.sub();
2398
- cg.local.set(y);
2399
- cg.local.get(y);
2400
- cg.f32.const(2 / Math.PI);
2401
- cg.f32.mul();
2402
- cg.f32.nearest();
2403
- cg.local.tee(qf);
2404
- cg.i32.trunc_f32_s();
2405
- cg.local.set(q);
2406
- cg.local.get(y);
2407
- cg.local.get(qf);
2408
- cg.f32.const(Math.PI / 2);
2409
- cg.f32.mul();
2410
- cg.f32.sub();
2411
- cg.local.tee(z);
2412
- cg.local.get(z);
2413
- cg.f32.mul();
2414
- cg.local.set(z2);
2415
- cg.f32.const(-1 / 5040);
2416
- cg.local.get(z2);
2417
- cg.f32.mul();
2418
- cg.f32.const(1 / 120);
2419
- cg.f32.add();
2420
- cg.local.get(z2);
2421
- cg.f32.mul();
2422
- cg.f32.const(-1 / 6);
2423
- cg.f32.add();
2424
- cg.local.get(z2);
2425
- cg.f32.mul();
2426
- cg.f32.const(1);
2427
- cg.f32.add();
2428
- cg.local.get(z);
2429
- cg.f32.mul();
2430
- cg.local.set(sz);
2431
- cg.f32.const(-1 / 720);
2432
- cg.local.get(z2);
2433
- cg.f32.mul();
2434
- cg.f32.const(1 / 24);
2435
- cg.f32.add();
2436
- cg.local.get(z2);
2437
- cg.f32.mul();
2438
- cg.f32.const(-1 / 2);
2439
- cg.f32.add();
2440
- cg.local.get(z2);
2441
- cg.f32.mul();
2442
- cg.f32.const(1);
2443
- cg.f32.add();
2444
- cg.local.set(cz);
2445
2518
  cg.local.get(sz);
2446
2519
  cg.local.get(cz);
2447
2520
  cg.local.get(q);
@@ -2459,6 +2532,100 @@ function wasm_cos(cg) {
2459
2532
  cg.select();
2460
2533
  });
2461
2534
  }
2535
/**
 * Emit inline WASM that replaces the f32 on top of the operand stack with an
 * approximation of its arctangent.
 *
 * Five f32 locals are declared in the function currently being generated.
 *
 * With t = |x| when |x| < 1, or t = 1/|x| otherwise, the code computes
 * core = t * P(t²) / Q(t²):
 * - P(u) = 0.999998614341 + 0.661705427875u + 0.0415796528637u²
 * - Q(u) = 1 + 0.994987933645u + 0.173698870181u²
 *
 * The result is `core` for |x| < 1 and π/2 - core otherwise. The sign of x is
 * restored with copysign.
 */
function _atan(cg) {
  const input = cg.local.declare(cg.f32);
  const mag = cg.local.declare(cg.f32);
  const t = cg.local.declare(cg.f32);
  const t2 = cg.local.declare(cg.f32);
  const core = cg.local.declare(cg.f32);

  // Pushes the select condition: 1 when |x| >= 1 (reciprocal branch), else 0.
  const emitUseReciprocal = () => {
    cg.local.get(mag);
    cg.f32.const(1);
    cg.f32.ge();
  };
  // Horner's rule in t2: leaves c0*t2^(n-1) + ... + c_{n-1} on the stack.
  const emitPolyInT2 = (coeffs) => {
    cg.f32.const(coeffs[0]);
    for (let k = 1; k < coeffs.length; k++) {
      cg.local.get(t2);
      cg.f32.mul();
      cg.f32.const(coeffs[k]);
      cg.f32.add();
    }
  };

  cg.local.set(input);
  cg.local.get(input);
  cg.f32.abs();
  cg.local.set(mag);

  // t = |x| >= 1 ? 1 / |x| : |x|
  cg.f32.const(1);
  cg.local.get(mag);
  cg.f32.div();
  cg.local.get(mag);
  emitUseReciprocal();
  cg.select();
  cg.local.set(t);

  cg.local.get(t);
  cg.local.get(t);
  cg.f32.mul();
  cg.local.set(t2);

  // core = t * P(t2) / Q(t2)
  emitPolyInT2([0.0415796528637, 0.661705427875, 0.999998614341]);
  emitPolyInT2([0.173698870181, 0.994987933645, 1]);
  cg.f32.div();
  cg.local.get(t);
  cg.f32.mul();
  cg.local.set(core);

  // result = copysign(|x| >= 1 ? π/2 - core : core, x)
  cg.f32.const(Math.PI / 2);
  cg.local.get(core);
  cg.f32.sub();
  cg.local.get(core);
  emitUseReciprocal();
  cg.select();
  cg.local.get(input);
  cg.f32.copysign();
}
2592
/**
 * Approximate atan(x) as a WASM function of type `(f32) -> f32`.
 *
 * Method:
 * - If |x| < 1, use a rational approximation atan(x) ≈ x * P(x^2) / Q(x^2),
 *   where P(u) = A0 + A1*u + A2*u^2 and Q(u) = 1 + B1*u + B2*u^2.
 * - If |x| >= 1, use atan(x) = sign(x)*π/2 - atan(1/x).
 *
 * The coefficients were fitted, with a max error of about 5e-7 on [0, 1].
 */
function wasm_atan(cg) {
  const emitBody = () => {
    cg.local.get(0); // the single f32 parameter
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2607
/**
 * Approximate asin(x) as a WASM function of type `(f32) -> f32`.
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))).
 *
 * The radicand 1 - x^2 is evaluated as (1 - x) * (1 + x). As |x| -> 1, x*x
 * rounds to within half an ulp of 1, so 1 - x*x cancels most significant bits.
 * In the factored form one factor is exact there (Sterbenz lemma), which keeps
 * sqrt(1 - x^2) accurate to a few f32 ulps.
 *
 * For |x| > 1 the radicand is negative, so sqrt and the result are NaN.
 */
function wasm_asin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    cg.local.get(0);
    // (1 - x) * (1 + x)
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.sub();
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.add();
    cg.f32.mul();
    // x / (1 + sqrt(...))
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  });
}
2462
2629
  /**
2463
2630
  * Threefry2x32 pseudorandom number generator.
2464
2631
  *
@@ -3338,6 +3505,8 @@ function codegenWasm(kernel) {
3338
3505
  const funcs = {};
3339
3506
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3340
3507
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3508
+ if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3509
+ if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3341
3510
  if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3342
3511
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3343
3512
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
@@ -3489,6 +3658,8 @@ function translateExp(cg, funcs, exp, ctx) {
3489
3658
  else throw new UnsupportedOpError(op, dtype, "wasm");
3490
3659
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3491
3660
  else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3661
+ else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3662
+ else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3492
3663
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3493
3664
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3494
3665
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
@@ -3587,10 +3758,11 @@ let defaultBackend = "wasm";
3587
3758
  const initializedBackends = /* @__PURE__ */ new Map();
3588
3759
  initializedBackends.set("cpu", new CpuBackend());
3589
3760
  initializedBackends.set("wasm", new WasmBackend());
3590
- /** Set the default device backend (must be initialized). */
3591
- function setDevice(device) {
3592
- if (initializedBackends.has(device)) defaultBackend = device;
3761
+ /** Configure the default device for arrays. */
3762
+ function defaultDevice(device) {
3763
+ if (device !== void 0) if (initializedBackends.has(device)) defaultBackend = device;
3593
3764
  else throw new Error(`Backend not initialized: ${device}`);
3765
+ return defaultBackend;
3594
3766
  }
3595
3767
  /**
3596
3768
  * Initialize `jax-js` library backends.
@@ -3617,7 +3789,7 @@ async function createBackend(device) {
3617
3789
  if (!navigator.gpu) return null;
3618
3790
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3619
3791
  if (!adapter) return null;
3620
- const { WebGPUBackend } = await import("./webgpu-CNg9JGva.js");
3792
+ const { WebGPUBackend } = await import("./webgpu-CM-xNYzW.js");
3621
3793
  const importantLimits = [
3622
3794
  "maxBufferSize",
3623
3795
  "maxComputeInvocationsPerWorkgroup",
@@ -3670,4 +3842,4 @@ var UnsupportedOpError = class extends Error {
3670
3842
  };
3671
3843
 
3672
3844
  //#endregion
3673
- export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
3845
+ export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };