@jax-js/jax 0.0.3 → 0.0.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -36,7 +36,22 @@ var PPrint = class PPrint {
36
36
  //#endregion
37
37
  //#region src/utils.ts
38
38
  /** @file Generic programming utilities with no dependencies on library code. */
39
- const DEBUG = 3;
39
/**
 * Verbosity level for debug logging. Starts at 0; change it with `setDebug()`.
 */
let DEBUG = 0;

/**
 * Change the verbosity of debug logging.
 *
 * Documented levels:
 *
 * 1. JIT compile logs
 * 2. Shader code
 * 3. Expressions and metadata
 * 4. JIT programs, tuning details
 * 5. Most verbose operation traces
 *
 * Experimental: the behavior of this API may change, so production code
 * should not depend on it.
 *
 * @param {number} level - Value stored into the module-level `DEBUG` binding.
 */
function setDebug(level) {
  // Overwrite the shared module binding; readers observe it on their next check.
  DEBUG = level;
}
40
55
  function unzip2(pairs) {
41
56
  const lst1 = [];
42
57
  const lst2 = [];
@@ -110,9 +125,23 @@ function isNumberPair(x) {
110
125
  }
111
126
  /** Check an axis against number of dimensions, and resolve negative axes. */
112
127
  function checkAxis(axis, ndim) {
113
- if (axis < -ndim || axis >= ndim) throw new Error(`Invalid axis ${axis} for array of ${ndim} dimensions`);
128
+ if (axis < -ndim || axis >= ndim) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
114
129
  return axis < 0 ? axis + ndim : axis;
115
130
  }
131
/**
 * Normalize the common `axis` argument accepted by array functions.
 *
 * - `null` means "all axes" and returns `range(ndim)`.
 * - A single number is validated (and a negative value resolved) with
 *   `checkAxis`, and returned as a one-element list.
 * - An iterable of numbers is validated element-wise. Duplicates, detected
 *   after resolving negative axes, are rejected. The result is sorted in
 *   ascending numeric order.
 *
 * @param {number | Iterable<number> | null} axis
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number[]} Resolved axes.
 * @throws {Error} on an out-of-bounds axis (from `checkAxis`) or a duplicate axis.
 */
function normalizeAxis(axis, ndim) {
  if (axis === null) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = /* @__PURE__ */ new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // A numeric comparator is required: the default sort compares strings,
  // which would order [10, 2] as [10, 2].
  return [...seen].sort((x, y) => x - y);
}
116
145
  function range(start, stop, step = 1) {
117
146
  if (stop === void 0) {
118
147
  stop = start;
@@ -178,6 +207,35 @@ function findPow2(hint, max) {
178
207
  while (ret < hint && 2 * ret <= max) ret *= 2;
179
208
  return ret;
180
209
  }
210
/**
 * Compute the shape produced by broadcasting two array shapes, using NumPy's
 * general broadcasting rule.
 *
 * The shapes are aligned at their trailing (rightmost) dimension. Each aligned
 * pair of sizes must be equal, or one of them must be 1, in which case the
 * other size is used. Leading dimensions that exist in only one shape are
 * copied unchanged.
 *
 * @param {number[]} a - First shape.
 * @param {number[]} b - Second shape.
 * @returns {number[]} The broadcast shape.
 * @throws {TypeError} if some aligned pair of sizes cannot be broadcast.
 * @see https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules
 */
function generalBroadcast(a, b) {
  const rank = Math.max(a.length, b.length);
  const shape = new Array(rank);
  // k counts dimensions from the trailing end: k = 1 is the last axis.
  for (let k = 1; k <= rank; k++) {
    const x = a[a.length - k];
    const y = b[b.length - k];
    let dim;
    if (k > b.length) dim = x; // only `a` reaches this far left
    else if (k > a.length) dim = y; // only `b` reaches this far left
    else if (x === y || y === 1) dim = x;
    else if (x === 1) dim = y;
    else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
    shape[rank - k] = dim;
  }
  return shape;
}
181
239
  function recursiveFlatten(ar) {
182
240
  if (!Array.isArray(ar)) return [ar];
183
241
  return ar.flat(Infinity);
@@ -187,6 +245,7 @@ function strip1(str) {
187
245
  if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
188
246
  return str;
189
247
  }
248
// Shared 8-byte scratch view. FpHash.update writes a non-integer number into it
// with setFloat64 and reads the same bytes back with getBigUint64, so such
// numbers are hashed by their float64 bit pattern.
+ const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(8));
190
249
  /**
191
250
  * Polynomial hashes modulo p are good at avoiding collisions in expectation.
192
251
  * Probability-wise, it's good enough to be used for something like
@@ -201,22 +260,26 @@ var FpHash = class FpHash {
201
260
  const modulus = 3189051996290219n;
202
261
  this.value = (this.value * base + x) % modulus;
203
262
  }
204
- update(...values) {
205
- for (const x of values) if (typeof x === "string") for (const c of x) this.#update(BigInt(199 + c.charCodeAt(0)));
206
- else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
263
+ update(x) {
264
+ if (typeof x === "string") {
265
+ this.#update(BigInt(x.length));
266
+ for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
267
+ } else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
207
268
  else {
208
- const ar = new Float64Array([x]);
209
- this.#update(new DataView(ar.buffer).getBigUint64(0, true));
269
+ _stagingbuf.setFloat64(0, x, true);
270
+ this.#update(_stagingbuf.getBigUint64(0, true));
210
271
  }
211
272
  else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
212
273
  else if (typeof x === "bigint") this.#update(x ^ 71657401n);
213
274
  else if (x === null) this.#update(37832657n);
214
275
  else if (x === void 0) this.#update(18145117n);
215
- else if (typeof x === "object" && "hash" in x) x.hash(this);
276
+ else x.hash(this);
216
277
  return this;
217
278
  }
218
279
  static hash(...values) {
219
- return new FpHash().update(...values).value;
280
+ const h = new FpHash();
281
+ for (const x of values) h.update(x);
282
+ return h.value;
220
283
  }
221
284
  };
222
285
  /** Run a function while caching it inline inside a `Map`. */
@@ -251,6 +314,41 @@ const byteWidth = (dtype) => {
251
314
  }
252
315
  };
253
316
  const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
317
/**
 * Return the dtype that two operands are promoted to, i.e. their join in the
 * type lattice, following JAX's type promotion rules.
 *
 * Mixed-dtype operations need a common type that can represent values of both
 * inputs. The lattice is a chain:
 *
 * ```text
 * bool -> uint32 -> int32 -> float16 -> float32
 * weakType --^
 * ```
 *
 * so the join of two dtypes is whichever one sits further to the right.
 *
 * `weakType` describes weakly typed arrays created from JS numbers. They
 * default to float32 but are "weak": they take the dtype of whichever array
 * they are first combined with, except `bool`. This function only ranks the
 * five dtypes shown above.
 *
 * **Examples:**
 * - `promoteTypes(bool, int32) → int32`
 * - `promoteTypes(uint32, int32) → int32`
 * - `promoteTypes(int32, float16) → float16`
 * - `promoteTypes(float16, float32) → float32`
 * - `promoteTypes(uint32, float32) → float32`
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  // Position of each dtype along the chain; larger means further right.
  const position = new Map([
    [DType.Bool, 0],
    [DType.Uint32, 1],
    [DType.Int32, 2],
    [DType.Float16, 3],
    [DType.Float32, 4],
  ]);
  const p1 = position.get(dtype1);
  const p2 = position.get(dtype2);
  return p1 > p2 ? dtype1 : dtype2;
}
254
352
  function dtypedArray(dtype, data) {
255
353
  const { buffer, byteLength, byteOffset } = data;
256
354
  const length = byteLength / byteWidth(dtype);
@@ -320,6 +418,12 @@ var AluExp = class AluExp {
320
418
  static cos(a) {
321
419
  return new AluExp(AluOp.Cos, a.dtype, [a]);
322
420
  }
421
+ static asin(a) {
422
+ return new AluExp(AluOp.Asin, a.dtype, [a]);
423
+ }
424
+ static atan(a) {
425
+ return new AluExp(AluOp.Atan, a.dtype, [a]);
426
+ }
323
427
  static exp(a) {
324
428
  return new AluExp(AluOp.Exp, a.dtype, [a]);
325
429
  }
@@ -403,8 +507,11 @@ var AluExp = class AluExp {
403
507
  getHash() {
404
508
  if (this.#hash !== void 0) return this.#hash;
405
509
  const hasher = new FpHash();
406
- hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
407
- hasher.update(this.src.length, ...this.src);
510
+ hasher.update(this.op);
511
+ hasher.update(this.dtype);
512
+ hasher.update(JSON.stringify(this.arg));
513
+ hasher.update(this.src.length);
514
+ for (const s of this.src) hasher.update(s);
408
515
  this.#hash = hasher.value;
409
516
  return this.#hash;
410
517
  }
@@ -476,10 +583,16 @@ var AluExp = class AluExp {
476
583
  ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
477
584
  break;
478
585
  case AluOp.Sin:
479
- ret = [Math.sin(src[0].min), Math.sin(src[0].max)];
586
+ ret = [-1, 1];
480
587
  break;
481
588
  case AluOp.Cos:
482
- ret = [Math.cos(src[0].min), Math.cos(src[0].max)];
589
+ ret = [-1, 1];
590
+ break;
591
+ case AluOp.Asin:
592
+ ret = [-Math.PI / 2, Math.PI / 2];
593
+ break;
594
+ case AluOp.Atan:
595
+ ret = [-Math.PI / 2, Math.PI / 2];
483
596
  break;
484
597
  case AluOp.Exp:
485
598
  ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
@@ -595,11 +708,12 @@ var AluExp = class AluExp {
595
708
  simplify(cache = /* @__PURE__ */ new Map()) {
596
709
  if (this.#simplified !== void 0) return this.#simplified;
597
710
  const hash = this.getHash();
598
- if (cache.has(hash)) return this.#simplified = cache.get(hash);
711
+ const prevCachedValue = cache.get(hash);
712
+ if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
599
713
  const simplified = this.#simplifyInner(cache);
600
714
  const simplifiedHash = simplified.getHash();
601
- if (cache.has(simplifiedHash)) {
602
- const prevSimplified = cache.get(simplifiedHash);
715
+ const prevSimplified = cache.get(simplifiedHash);
716
+ if (prevSimplified !== void 0) {
603
717
  cache.set(hash, prevSimplified);
604
718
  this.#simplified = prevSimplified;
605
719
  return prevSimplified;
@@ -803,6 +917,8 @@ var AluExp = class AluExp {
803
917
  switch (this.op) {
804
918
  case AluOp.Sin: return Math.sin(x);
805
919
  case AluOp.Cos: return Math.cos(x);
920
+ case AluOp.Asin: return Math.asin(x);
921
+ case AluOp.Atan: return Math.atan(x);
806
922
  case AluOp.Exp: return Math.exp(x);
807
923
  case AluOp.Log: return Math.log(x);
808
924
  case AluOp.Sqrt: return Math.sqrt(x);
@@ -982,6 +1098,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
982
1098
  AluOp$1["Max"] = "Max";
983
1099
  AluOp$1["Sin"] = "Sin";
984
1100
  AluOp$1["Cos"] = "Cos";
1101
+ AluOp$1["Asin"] = "Asin";
1102
+ AluOp$1["Atan"] = "Atan";
985
1103
  AluOp$1["Exp"] = "Exp";
986
1104
  AluOp$1["Log"] = "Log";
987
1105
  AluOp$1["Sqrt"] = "Sqrt";
@@ -1012,6 +1130,8 @@ const AluGroup = {
1012
1130
  Unary: new Set([
1013
1131
  AluOp.Sin,
1014
1132
  AluOp.Cos,
1133
+ AluOp.Asin,
1134
+ AluOp.Atan,
1015
1135
  AluOp.Exp,
1016
1136
  AluOp.Log,
1017
1137
  AluOp.Sqrt,
@@ -1035,6 +1155,8 @@ const AluGroup = {
1035
1155
  RequiredFloat: new Set([
1036
1156
  AluOp.Sin,
1037
1157
  AluOp.Cos,
1158
+ AluOp.Asin,
1159
+ AluOp.Atan,
1038
1160
  AluOp.Exp,
1039
1161
  AluOp.Log,
1040
1162
  AluOp.Sqrt,
@@ -1066,7 +1188,7 @@ var Kernel = class {
1066
1188
  this.exp = exp.simplify();
1067
1189
  }
1068
1190
  hash(state) {
1069
- state.update(this.nargs, this.size, this.exp, this.reduction);
1191
+ state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
1070
1192
  }
1071
1193
  pprint() {
1072
1194
  let details = PPrint.pp(`exp = ${this.exp}`);
@@ -1112,7 +1234,7 @@ var Reduction = class {
1112
1234
  this.epilogue = epilogue.simplify();
1113
1235
  }
1114
1236
  hash(state) {
1115
- state.update(this.dtype, this.op, this.size, this.epilogue);
1237
+ state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
1116
1238
  }
1117
1239
  toString() {
1118
1240
  return `${this.op}{${this.size}} -> ${this.epilogue}`;
@@ -2284,78 +2406,92 @@ function wasm_log(cg) {
2284
2406
  });
2285
2407
  }
2286
2408
  /**
2287
- * Approximate sin(x).
2409
+ * Common helper to approximate sin(x) and cos(x).
2288
2410
  *
2289
2411
  * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
2290
- * z = y - q*(π/2); use odd polynomial on z:
2412
+ * z = y - q*(π/2); use one of two polynomials on z:
2291
2413
  * sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
2414
+ * cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
2415
+ */
2416
+ function _sincos(cg) {
2417
+ const y = cg.local.declare(cg.f32);
2418
+ const qf = cg.local.declare(cg.f32);
2419
+ const q = cg.local.declare(cg.i32);
2420
+ const z = cg.local.declare(cg.f32);
2421
+ const z2 = cg.local.declare(cg.f32);
2422
+ const sz = cg.local.declare(cg.f32);
2423
+ const cz = cg.local.declare(cg.f32);
2424
+ cg.local.get(0);
2425
+ cg.local.get(0);
2426
+ cg.f32.const(1 / (2 * Math.PI));
2427
+ cg.f32.mul();
2428
+ cg.f32.nearest();
2429
+ cg.local.tee(qf);
2430
+ cg.f32.const(2 * Math.PI);
2431
+ cg.f32.mul();
2432
+ cg.f32.sub();
2433
+ cg.local.set(y);
2434
+ cg.local.get(y);
2435
+ cg.f32.const(2 / Math.PI);
2436
+ cg.f32.mul();
2437
+ cg.f32.nearest();
2438
+ cg.local.tee(qf);
2439
+ cg.i32.trunc_f32_s();
2440
+ cg.local.set(q);
2441
+ cg.local.get(y);
2442
+ cg.local.get(qf);
2443
+ cg.f32.const(Math.PI / 2);
2444
+ cg.f32.mul();
2445
+ cg.f32.sub();
2446
+ cg.local.tee(z);
2447
+ cg.local.get(z);
2448
+ cg.f32.mul();
2449
+ cg.local.set(z2);
2450
+ cg.f32.const(-1 / 5040);
2451
+ cg.local.get(z2);
2452
+ cg.f32.mul();
2453
+ cg.f32.const(1 / 120);
2454
+ cg.f32.add();
2455
+ cg.local.get(z2);
2456
+ cg.f32.mul();
2457
+ cg.f32.const(-1 / 6);
2458
+ cg.f32.add();
2459
+ cg.local.get(z2);
2460
+ cg.f32.mul();
2461
+ cg.f32.const(1);
2462
+ cg.f32.add();
2463
+ cg.local.get(z);
2464
+ cg.f32.mul();
2465
+ cg.local.set(sz);
2466
+ cg.f32.const(-1 / 720);
2467
+ cg.local.get(z2);
2468
+ cg.f32.mul();
2469
+ cg.f32.const(1 / 24);
2470
+ cg.f32.add();
2471
+ cg.local.get(z2);
2472
+ cg.f32.mul();
2473
+ cg.f32.const(-1 / 2);
2474
+ cg.f32.add();
2475
+ cg.local.get(z2);
2476
+ cg.f32.mul();
2477
+ cg.f32.const(1);
2478
+ cg.f32.add();
2479
+ cg.local.set(cz);
2480
+ return {
2481
+ q,
2482
+ sz,
2483
+ cz
2484
+ };
2485
+ }
2486
+ /**
2487
+ * Approximate sin(x).
2488
+ *
2489
+ * Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
2292
2490
  */
2293
2491
  function wasm_sin(cg) {
2294
2492
  return cg.function([cg.f32], [cg.f32], () => {
2295
- const y = cg.local.declare(cg.f32);
2296
- const qf = cg.local.declare(cg.f32);
2297
- const q = cg.local.declare(cg.i32);
2298
- const z = cg.local.declare(cg.f32);
2299
- const z2 = cg.local.declare(cg.f32);
2300
- const sz = cg.local.declare(cg.f32);
2301
- const cz = cg.local.declare(cg.f32);
2493
+ const { q, sz, cz } = _sincos(cg);
2302
2494
  const mag = cg.local.declare(cg.f32);
2303
- cg.local.get(0);
2304
- cg.local.get(0);
2305
- cg.f32.const(1 / (2 * Math.PI));
2306
- cg.f32.mul();
2307
- cg.f32.nearest();
2308
- cg.local.tee(qf);
2309
- cg.f32.const(2 * Math.PI);
2310
- cg.f32.mul();
2311
- cg.f32.sub();
2312
- cg.local.set(y);
2313
- cg.local.get(y);
2314
- cg.f32.const(2 / Math.PI);
2315
- cg.f32.mul();
2316
- cg.f32.nearest();
2317
- cg.local.tee(qf);
2318
- cg.i32.trunc_f32_s();
2319
- cg.local.set(q);
2320
- cg.local.get(y);
2321
- cg.local.get(qf);
2322
- cg.f32.const(Math.PI / 2);
2323
- cg.f32.mul();
2324
- cg.f32.sub();
2325
- cg.local.tee(z);
2326
- cg.local.get(z);
2327
- cg.f32.mul();
2328
- cg.local.set(z2);
2329
- cg.f32.const(-1 / 5040);
2330
- cg.local.get(z2);
2331
- cg.f32.mul();
2332
- cg.f32.const(1 / 120);
2333
- cg.f32.add();
2334
- cg.local.get(z2);
2335
- cg.f32.mul();
2336
- cg.f32.const(-1 / 6);
2337
- cg.f32.add();
2338
- cg.local.get(z2);
2339
- cg.f32.mul();
2340
- cg.f32.const(1);
2341
- cg.f32.add();
2342
- cg.local.get(z);
2343
- cg.f32.mul();
2344
- cg.local.set(sz);
2345
- cg.f32.const(-1 / 720);
2346
- cg.local.get(z2);
2347
- cg.f32.mul();
2348
- cg.f32.const(1 / 24);
2349
- cg.f32.add();
2350
- cg.local.get(z2);
2351
- cg.f32.mul();
2352
- cg.f32.const(-1 / 2);
2353
- cg.f32.add();
2354
- cg.local.get(z2);
2355
- cg.f32.mul();
2356
- cg.f32.const(1);
2357
- cg.f32.add();
2358
- cg.local.set(cz);
2359
2495
  cg.local.get(cz);
2360
2496
  cg.local.get(sz);
2361
2497
  cg.local.get(q);
@@ -2374,75 +2510,12 @@ function wasm_sin(cg) {
2374
2510
  /**
2375
2511
  * Approximate cos(x).
2376
2512
  *
2377
- * Same reduction as sinf, then quadrant mapping:
2378
- * k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2513
+ * Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
2379
2514
  */
2380
2515
  function wasm_cos(cg) {
2381
2516
  return cg.function([cg.f32], [cg.f32], () => {
2382
- const y = cg.local.declare(cg.f32);
2383
- const qf = cg.local.declare(cg.f32);
2384
- const q = cg.local.declare(cg.i32);
2385
- const z = cg.local.declare(cg.f32);
2386
- const z2 = cg.local.declare(cg.f32);
2387
- const sz = cg.local.declare(cg.f32);
2388
- const cz = cg.local.declare(cg.f32);
2517
+ const { q, sz, cz } = _sincos(cg);
2389
2518
  const mag = cg.local.declare(cg.f32);
2390
- cg.local.get(0);
2391
- cg.local.get(0);
2392
- cg.f32.const(1 / (2 * Math.PI));
2393
- cg.f32.mul();
2394
- cg.f32.nearest();
2395
- cg.local.tee(qf);
2396
- cg.f32.const(2 * Math.PI);
2397
- cg.f32.mul();
2398
- cg.f32.sub();
2399
- cg.local.set(y);
2400
- cg.local.get(y);
2401
- cg.f32.const(2 / Math.PI);
2402
- cg.f32.mul();
2403
- cg.f32.nearest();
2404
- cg.local.tee(qf);
2405
- cg.i32.trunc_f32_s();
2406
- cg.local.set(q);
2407
- cg.local.get(y);
2408
- cg.local.get(qf);
2409
- cg.f32.const(Math.PI / 2);
2410
- cg.f32.mul();
2411
- cg.f32.sub();
2412
- cg.local.tee(z);
2413
- cg.local.get(z);
2414
- cg.f32.mul();
2415
- cg.local.set(z2);
2416
- cg.f32.const(-1 / 5040);
2417
- cg.local.get(z2);
2418
- cg.f32.mul();
2419
- cg.f32.const(1 / 120);
2420
- cg.f32.add();
2421
- cg.local.get(z2);
2422
- cg.f32.mul();
2423
- cg.f32.const(-1 / 6);
2424
- cg.f32.add();
2425
- cg.local.get(z2);
2426
- cg.f32.mul();
2427
- cg.f32.const(1);
2428
- cg.f32.add();
2429
- cg.local.get(z);
2430
- cg.f32.mul();
2431
- cg.local.set(sz);
2432
- cg.f32.const(-1 / 720);
2433
- cg.local.get(z2);
2434
- cg.f32.mul();
2435
- cg.f32.const(1 / 24);
2436
- cg.f32.add();
2437
- cg.local.get(z2);
2438
- cg.f32.mul();
2439
- cg.f32.const(-1 / 2);
2440
- cg.f32.add();
2441
- cg.local.get(z2);
2442
- cg.f32.mul();
2443
- cg.f32.const(1);
2444
- cg.f32.add();
2445
- cg.local.set(cz);
2446
2519
  cg.local.get(sz);
2447
2520
  cg.local.get(cz);
2448
2521
  cg.local.get(q);
@@ -2460,6 +2533,100 @@ function wasm_cos(cg) {
2460
2533
  cg.select();
2461
2534
  });
2462
2535
  }
2536
/**
 * Emit an inline approximation of atan. It pops the f32 on top of the operand
 * stack and pushes atan of that value.
 *
 * Let a = |x|. The argument is reduced to z = a when a < 1, or z = 1/a when
 * a >= 1. Then p = z · P(z²) / Q(z²), with
 *   P(u) = 0.999998614341 + 0.661705427875·u + 0.0415796528637·u²
 *   Q(u) = 1 + 0.994987933645·u + 0.173698870181·u²
 * The magnitude is p (a < 1) or π/2 - p (a >= 1), and copysign then gives it
 * the sign of x.
 */
function _atan(cg) {
  const x = cg.local.declare(cg.f32);
  const absX = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);

  // Horner's rule in z2: push c[0], then repeatedly (acc * z2) + c[k].
  const polyInZ2 = (coeffs) => {
    cg.f32.const(coeffs[0]);
    for (let k = 1; k < coeffs.length; k++) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(coeffs[k]);
      cg.f32.add();
    }
  };
  // Push the i32 condition |x| >= 1, which is consumed by `select`.
  const pushIsLarge = () => {
    cg.local.get(absX);
    cg.f32.const(1);
    cg.f32.ge();
  };

  // Pop the input into x and record |x|.
  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(absX);

  // z = |x| >= 1 ? 1 / |x| : |x|
  cg.f32.const(1);
  cg.local.get(absX);
  cg.f32.div();
  cg.local.get(absX);
  pushIsLarge();
  cg.select();
  cg.local.set(z);

  // z2 = z * z
  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // p = P(z2) / Q(z2) * z
  polyInZ2([0.0415796528637, 0.661705427875, 0.999998614341]);
  polyInZ2([0.173698870181, 0.994987933645, 1]);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);

  // result = copysign(|x| >= 1 ? π/2 - p : p, x)
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  pushIsLarge();
  cg.select();
  cg.local.get(x);
  cg.f32.copysign();
}
2593
/**
 * Build a wasm helper `(f32) -> f32` that approximates atan(x).
 *
 * Method: when |x| >= 1, use atan(x) = sign(x)·π/2 - atan(1/x), so the
 * argument always lies in [0, 1]. On that interval use the rational
 * approximation atan(x) ≈ x · P(x²) / Q(x²), where
 *   P(u) = A0 + A1·u + A2·u²   (degree 2)
 *   Q(u) = 1 + B1·u + B2·u²    (degree 2)
 * (fitted coefficients, max error ~5e-7 on [0, 1]).
 */
function wasm_atan(cg) {
  const emitBody = () => {
    // Push the sole parameter; `_atan` replaces it with atan(x).
    cg.local.get(0);
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
2608
/**
 * Build a wasm helper `(f32) -> f32` that approximates asin(x) using the
 * identity
 *   asin(x) = 2 · atan(x / (1 + sqrt(1 - x²)))
 * with the inline `_atan` approximation. When |x| > 1 the square root
 * argument is negative, which yields NaN.
 */
function wasm_asin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    const x = 0; // local index of the sole f32 parameter
    // Numerator: x
    cg.local.get(x);
    // Denominator: 1 + sqrt(1 - x * x)
    cg.f32.const(1);
    cg.local.get(x);
    cg.local.get(x);
    cg.f32.mul();
    cg.f32.sub();
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div();
    // 2 * atan(quotient)
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  });
}
2463
2630
  /**
2464
2631
  * Threefry2x32 pseudorandom number generator.
2465
2632
  *
@@ -3339,6 +3506,8 @@ function codegenWasm(kernel) {
3339
3506
  const funcs = {};
3340
3507
  if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
3341
3508
  if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
3509
+ if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
3510
+ if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
3342
3511
  if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
3343
3512
  if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
3344
3513
  if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
@@ -3490,6 +3659,8 @@ function translateExp(cg, funcs, exp, ctx) {
3490
3659
  else throw new UnsupportedOpError(op, dtype, "wasm");
3491
3660
  } else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
3492
3661
  else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
3662
+ else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
3663
+ else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
3493
3664
  else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
3494
3665
  else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
3495
3666
  else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
@@ -3588,10 +3759,11 @@ let defaultBackend = "wasm";
3588
3759
  const initializedBackends = /* @__PURE__ */ new Map();
3589
3760
  initializedBackends.set("cpu", new CpuBackend());
3590
3761
  initializedBackends.set("wasm", new WasmBackend());
3591
/**
 * Get or set the default device for arrays.
 *
 * @param {string} [device] - Optional. If given, it must name an
 *   already-initialized backend, which becomes the new default.
 * @returns {string} The default device after any update.
 * @throws {Error} if `device` is given but has not been initialized.
 */
function defaultDevice(device) {
  if (device === void 0) return defaultBackend;
  if (!initializedBackends.has(device)) {
    throw new Error(`Backend not initialized: ${device}`);
  }
  defaultBackend = device;
  return defaultBackend;
}
3596
3768
  /**
3597
3769
  * Initialize `jax-js` library backends.
@@ -3618,7 +3790,7 @@ async function createBackend(device) {
3618
3790
  if (!navigator.gpu) return null;
3619
3791
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
3620
3792
  if (!adapter) return null;
3621
- const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-fqhx41TC.cjs"));
3793
+ const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CNOpiO5T.cjs"));
3622
3794
  const importantLimits = [
3623
3795
  "maxBufferSize",
3624
3796
  "maxComputeInvocationsPerWorkgroup",
@@ -3785,6 +3957,12 @@ Object.defineProperty(exports, 'deepEqual', {
3785
3957
  return deepEqual;
3786
3958
  }
3787
3959
  });
3960
// CommonJS export defined as a getter, so it always reads the current `defaultDevice` binding.
Object.defineProperty(exports, 'defaultDevice', {
  enumerable: true,
  get: () => defaultDevice,
});
3788
3966
  Object.defineProperty(exports, 'devices', {
3789
3967
  enumerable: true,
3790
3968
  get: function () {
@@ -3809,6 +3987,12 @@ Object.defineProperty(exports, 'findPow2', {
3809
3987
  return findPow2;
3810
3988
  }
3811
3989
  });
3990
// CommonJS export defined as a getter, so it always reads the current `generalBroadcast` binding.
Object.defineProperty(exports, 'generalBroadcast', {
  enumerable: true,
  get: () => generalBroadcast,
});
3812
3996
  Object.defineProperty(exports, 'getBackend', {
3813
3997
  enumerable: true,
3814
3998
  get: function () {
@@ -3845,6 +4029,12 @@ Object.defineProperty(exports, 'isPermutation', {
3845
4029
  return isPermutation;
3846
4030
  }
3847
4031
  });
4032
// CommonJS export defined as a getter, so it always reads the current `normalizeAxis` binding.
Object.defineProperty(exports, 'normalizeAxis', {
  enumerable: true,
  get: () => normalizeAxis,
});
3848
4038
  Object.defineProperty(exports, 'partitionList', {
3849
4039
  enumerable: true,
3850
4040
  get: function () {
@@ -3857,6 +4047,12 @@ Object.defineProperty(exports, 'prod', {
3857
4047
  return prod;
3858
4048
  }
3859
4049
  });
4050
// CommonJS export defined as a getter, so it always reads the current `promoteTypes` binding.
Object.defineProperty(exports, 'promoteTypes', {
  enumerable: true,
  get: () => promoteTypes,
});
3860
4056
  Object.defineProperty(exports, 'range', {
3861
4057
  enumerable: true,
3862
4058
  get: function () {
@@ -3881,10 +4077,10 @@ Object.defineProperty(exports, 'runWithCache', {
3881
4077
  return runWithCache;
3882
4078
  }
3883
4079
  });
3884
// CommonJS export defined as a getter, so it always reads the current `setDebug` binding.
Object.defineProperty(exports, 'setDebug', {
  enumerable: true,
  get: () => setDebug,
});
3890
4086
  Object.defineProperty(exports, 'strip1', {