@jax-js/jax 0.1.10 → 0.1.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-Ctqs8la1.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DI-V78Rk.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -209,7 +209,7 @@ __export(tree_exports, {
209
209
  structure: () => structure,
210
210
  unflatten: () => unflatten
211
211
  });
212
- const JsArray$2 = globalThis.Array;
212
+ const JsArray$3 = globalThis.Array;
213
213
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
214
214
  NodeType$1["Array"] = "Array";
215
215
  NodeType$1["Object"] = "Object";
@@ -257,7 +257,7 @@ function flatten(tree) {
257
257
  return [leaves$1, treedef];
258
258
  }
259
259
  function _flatten(tree, leaves$1) {
260
- if (JsArray$2.isArray(tree)) {
260
+ if (JsArray$3.isArray(tree)) {
261
261
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
262
262
  return new JsTreeDef(NodeType.Array, null, childTrees);
263
263
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -333,6 +333,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
333
333
  Primitive$1["Mod"] = "mod";
334
334
  Primitive$1["Min"] = "min";
335
335
  Primitive$1["Max"] = "max";
336
+ Primitive$1["BitCombine"] = "bit_combine";
337
+ Primitive$1["BitShift"] = "bit_shift";
336
338
  Primitive$1["Neg"] = "neg";
337
339
  Primitive$1["Reciprocal"] = "reciprocal";
338
340
  Primitive$1["Floor"] = "floor";
@@ -406,6 +408,12 @@ function min$1(x, y) {
406
408
  function max$1(x, y) {
407
409
  return bind1(Primitive.Max, [x, y]);
408
410
  }
411
/**
 * Bind the `BitCombine` primitive on the operand pair `x`, `y`.
 * @param op - Bitwise operation name stored in the primitive params (e.g. "and", "or", "xor").
 */
function bitCombine(x, y, op) {
  const operands = [x, y];
  const params = { op };
  return bind1(Primitive.BitCombine, operands, params);
}
414
/**
 * Bind the `BitShift` primitive on the operand pair `x`, `y`.
 * @param op - Shift direction stored in the primitive params (e.g. "shl", "shr").
 */
function bitShift(x, y, op) {
  const operands = [x, y];
  const params = { op };
  return bind1(Primitive.BitShift, operands, params);
}
409
417
  function neg(x) {
410
418
  return bind1(Primitive.Neg, [x]);
411
419
  }
@@ -1620,6 +1628,16 @@ const abstractEvalRules = {
1620
1628
  [Primitive.Mod]: binopAbstractEval,
1621
1629
  [Primitive.Min]: binopAbstractEval,
1622
1630
  [Primitive.Max]: binopAbstractEval,
1631
+ [Primitive.BitCombine]([x, y]) {
1632
+ const aval = promoteAvals(x, y);
1633
+ if (isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
1634
+ return [aval];
1635
+ },
1636
+ [Primitive.BitShift]([x, y]) {
1637
+ const shape$1 = generalBroadcast(x.shape, y.shape);
1638
+ if (isFloatDtype(x.dtype) || isFloatDtype(y.dtype) || x.dtype === DType.Bool || y.dtype === DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
1639
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
1640
+ },
1623
1641
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1624
1642
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1625
1643
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -2155,6 +2173,8 @@ const jitRules = {
2155
2173
  [Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
2156
2174
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
2157
2175
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
2176
+ [Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => AluExp.bitCombine(a, b, op)),
2177
+ [Primitive.BitShift]: broadcastedJit(([a, b], { op }) => AluExp.bitShift(a, b, op)),
2158
2178
  [Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
2159
2179
  [Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
2160
2180
  [Primitive.Floor]: unopJit(AluExp.floor),
@@ -2347,7 +2367,9 @@ function splitGraphDataflow(backend, jaxpr) {
2347
2367
  case Primitive.Idiv:
2348
2368
  case Primitive.Mod:
2349
2369
  case Primitive.Min:
2350
- case Primitive.Max: {
2370
+ case Primitive.Max:
2371
+ case Primitive.BitCombine:
2372
+ case Primitive.BitShift: {
2351
2373
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2352
2374
  if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2353
2375
  head = usages[0];
@@ -2438,7 +2460,7 @@ function splitGraphDataflow(backend, jaxpr) {
2438
2460
 
2439
2461
  //#endregion
2440
2462
  //#region src/frontend/array.ts
2441
- const JsArray$1 = globalThis.Array;
2463
+ const JsArray$2 = globalThis.Array;
2442
2464
  const inlineArrayLimit = 128;
2443
2465
  /** Version of pureArray with fudged types. */
2444
2466
  const fudgeArray = pureArray;
@@ -2878,6 +2900,15 @@ var Array$1 = class Array$1 extends Tracer {
2878
2900
  this.#check();
2879
2901
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
2880
2902
  if (this.#source instanceof AluExp) {
2903
+ let resolvedSource;
2904
+ if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
2905
+ const byteLength = this.#st.size * byteWidth(this.#dtype);
2906
+ const initialData = new Uint8Array(byteLength);
2907
+ dtypedArray(this.#dtype, initialData).fill(resolvedSource);
2908
+ this.#source = this.#backend.malloc(byteLength, initialData);
2909
+ this.#st = ShapeTracker.fromShape(this.shape);
2910
+ return;
2911
+ }
2881
2912
  const exp$2 = accessorAluExp(this.#source, this.#st, indices);
2882
2913
  const kernel = new Kernel(0, this.#st.size, exp$2);
2883
2914
  const output = this.#backend.malloc(kernel.bytes);
@@ -2986,6 +3017,42 @@ var Array$1 = class Array$1 extends Tracer {
2986
3017
  return dtypedArray(this.dtype, buf);
2987
3018
  }
2988
3019
  /**
3020
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
3021
+ *
3022
+ * Only available on the WebGPU backend. The array's memory is still managed
3023
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
3024
+ * _should not_ mutate the buffer's contents.
3025
+ *
3026
+ * Note that the GPU buffer may be slightly larger than the array's size; it
3027
+ * will always be aligned to 4 bytes.
3028
+ */
3029
+ async gpuBuffer() {
3030
+ if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
3031
+ this.#realize();
3032
+ const pending = this.#pending;
3033
+ if (pending) {
3034
+ await Promise.all(pending.map((p) => p.prepare()));
3035
+ for (const p of pending) p.submit();
3036
+ }
3037
+ const backend = this.#backend;
3038
+ const { buffer } = backend.buffers.get(this.#source);
3039
+ this.dispose();
3040
+ return buffer;
3041
+ }
3042
+ /** Synchronous version of `Array.gpuBuffer()`. */
3043
+ gpuBufferSync() {
3044
+ if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
3045
+ this.#realize();
3046
+ for (const p of this.#pending) {
3047
+ p.prepareSync();
3048
+ p.submit();
3049
+ }
3050
+ const backend = this.#backend;
3051
+ const { buffer } = backend.buffers.get(this.#source);
3052
+ this.dispose();
3053
+ return buffer;
3054
+ }
3055
+ /**
2989
3056
  * Convert this array into a JavaScript object.
2990
3057
  *
2991
3058
  * This is a blocking operation that will compile all of the shaders and wait
@@ -3032,6 +3099,14 @@ var Array$1 = class Array$1 extends Tracer {
3032
3099
  [Primitive.Max]([x, y]) {
3033
3100
  return [x.#binary(AluOp.Max, y)];
3034
3101
  },
3102
+ [Primitive.BitCombine]([x, y], { op }) {
3103
+ const custom = (src) => AluExp.bitCombine(src[0], src[1], op);
3104
+ return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
3105
+ },
3106
+ [Primitive.BitShift]([x, y], { op }) {
3107
+ const custom = (src) => AluExp.bitShift(src[0], src[1], op);
3108
+ return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
3109
+ },
3035
3110
  [Primitive.Neg]([x]) {
3036
3111
  return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
3037
3112
  },
@@ -3284,7 +3359,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3284
3359
  if (!shape$1) {
3285
3360
  shape$1 = [];
3286
3361
  let cur = values;
3287
- while (JsArray$1.isArray(cur)) {
3362
+ while (JsArray$2.isArray(cur)) {
3288
3363
  shape$1.push(cur.length);
3289
3364
  cur = cur[0];
3290
3365
  }
@@ -3723,6 +3798,8 @@ const vmapRules = {
3723
3798
  [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3724
3799
  [Primitive.Min]: broadcastBatcher(Primitive.Min),
3725
3800
  [Primitive.Max]: broadcastBatcher(Primitive.Max),
3801
+ [Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
3802
+ [Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
3726
3803
  [Primitive.Neg]: unopBatcher(Primitive.Neg),
3727
3804
  [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3728
3805
  [Primitive.Floor]: unopBatcher(Primitive.Floor),
@@ -4045,6 +4122,8 @@ const jvpRules = {
4045
4122
  [Primitive.Max]([x, y], [dx, dy]) {
4046
4123
  return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4047
4124
  },
4125
+ [Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
4126
+ [Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
4048
4127
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4049
4128
  [Primitive.Reciprocal]([x], [dx]) {
4050
4129
  const xRecip = reciprocal$1(x.ref);
@@ -4162,7 +4241,7 @@ const jvpRules = {
4162
4241
  return [[L], [dL]];
4163
4242
  },
4164
4243
  [Primitive.LU]([a], [da]) {
4165
- const [luMatrix, pivots, permutation] = lu$1(a);
4244
+ const [luMatrix, pivots, permutation$1] = lu$1(a);
4166
4245
  const [m, n] = a.shape.slice(-2);
4167
4246
  const k = Math.min(m, n);
4168
4247
  const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
@@ -4174,7 +4253,7 @@ const jvpRules = {
4174
4253
  const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4175
4254
  const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4176
4255
  const U = uPadded.add(uEye);
4177
- const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4256
+ const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
4178
4257
  const pda = batchMatmulT(P, mT(da));
4179
4258
  const la = mT(triangularSolve$1(L.ref, mT(pda), {
4180
4259
  lower: true,
@@ -4186,11 +4265,11 @@ const jvpRules = {
4186
4265
  return [[
4187
4266
  luMatrix,
4188
4267
  pivots,
4189
- permutation
4268
+ permutation$1
4190
4269
  ], [
4191
4270
  lDot.add(uDot),
4192
4271
  zerosLike$1(pivots.ref),
4193
- zerosLike$1(permutation.ref)
4272
+ zerosLike$1(permutation$1.ref)
4194
4273
  ]];
4195
4274
  },
4196
4275
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
@@ -5236,7 +5315,8 @@ __export(numpy_linalg_exports, {
5236
5315
  solve: () => solve,
5237
5316
  tensordot: () => tensordot,
5238
5317
  trace: () => trace,
5239
- vecdot: () => vecdot
5318
+ vecdot: () => vecdot,
5319
+ vectorNorm: () => vectorNorm
5240
5320
  });
5241
5321
  function checkSquare(name, a) {
5242
5322
  if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
@@ -5271,8 +5351,8 @@ function cross$1(x1, x2, axis = -1) {
5271
5351
  function det(a) {
5272
5352
  a = fudgeArray(a);
5273
5353
  const n = checkSquare("det", a);
5274
- const [lu$2, pivots, permutation] = lu(a);
5275
- permutation.dispose();
5354
+ const [lu$2, pivots, permutation$1] = lu(a);
5355
+ permutation$1.dispose();
5276
5356
  const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5277
5357
  const sign$1 = parity.mul(-2).add(1);
5278
5358
  const diag$1 = lu$2.diagonal(0, -1, -2);
@@ -5361,8 +5441,8 @@ function matrixPower(a, n) {
5361
5441
  function slogdet(a) {
5362
5442
  a = fudgeArray(a);
5363
5443
  const n = checkSquare("slogdet", a);
5364
- const [lu$2, pivots, permutation] = lu(a);
5365
- permutation.dispose();
5444
+ const [lu$2, pivots, permutation$1] = lu(a);
5445
+ permutation$1.dispose();
5366
5446
  let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5367
5447
  const diag$1 = lu$2.diagonal(0, -1, -2);
5368
5448
  parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
@@ -5400,9 +5480,9 @@ function solve(a, b) {
5400
5480
  n,
5401
5481
  m
5402
5482
  ]);
5403
- const [lu$2, pivots, permutation] = lu(a);
5483
+ const [lu$2, pivots, permutation$1] = lu(a);
5404
5484
  pivots.dispose();
5405
- const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5485
+ const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
5406
5486
  const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5407
5487
  leftSide: true,
5408
5488
  lower: true,
@@ -5415,6 +5495,23 @@ function solve(a, b) {
5415
5495
  if (bIs1d) x = squeeze(x, -1);
5416
5496
  return x;
5417
5497
  }
5498
/**
 * Compute the vector norm of an array.
 *
 * @param x - Input array.
 * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
 * @param axis - Axis/axes to reduce over (default: all axes).
 * @param keepdims - Whether to keep reduced dimensions as size 1.
 * @returns The norm of `x`, reduced over the given axes.
 * @throws {TypeError} If `ord` is not a (non-NaN) number.
 */
function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
  // A non-numeric ord (e.g. "fro") would otherwise flow into `1 / ord` as NaN.
  if (typeof ord !== "number" || Number.isNaN(ord)) {
    throw new TypeError(`vectorNorm: ord must be a number, got ${ord}`);
  }
  x = fudgeArray(x);
  // Read the dtype up front: the array operations below consume `x`.
  const dtype = x.dtype;
  const ax = axis ?? null;
  if (ord === Infinity) return max(absolute(x), ax, { keepdims });
  if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
  // ord = 0 counts the nonzero entries, returned in the input dtype.
  if (ord === 0) return x.notEqual(0).astype(dtype).sum(ax, { keepdims });
  return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
}
5418
5515
 
5419
5516
  //#endregion
5420
5517
  //#region src/library/numpy/dtype-info.ts
@@ -5534,6 +5631,13 @@ __export(numpy_exports, {
5534
5631
  atan2: () => atan2,
5535
5632
  atanh: () => arctanh,
5536
5633
  average: () => average,
5634
+ bitwiseAnd: () => bitwiseAnd,
5635
+ bitwiseInvert: () => invert,
5636
+ bitwiseLeftShift: () => leftShift,
5637
+ bitwiseNot: () => invert,
5638
+ bitwiseOr: () => bitwiseOr,
5639
+ bitwiseRightShift: () => rightShift,
5640
+ bitwiseXor: () => bitwiseXor,
5537
5641
  bool: () => bool,
5538
5642
  broadcastArrays: () => broadcastArrays,
5539
5643
  broadcastShapes: () => broadcastShapes,
@@ -5595,12 +5699,14 @@ __export(numpy_exports, {
5595
5699
  inf: () => inf,
5596
5700
  inner: () => inner,
5597
5701
  int32: () => int32,
5702
+ invert: () => invert,
5598
5703
  isfinite: () => isfinite,
5599
5704
  isinf: () => isinf,
5600
5705
  isnan: () => isnan,
5601
5706
  isneginf: () => isneginf,
5602
5707
  isposinf: () => isposinf,
5603
5708
  ldexp: () => ldexp,
5709
+ leftShift: () => leftShift,
5604
5710
  less: () => less,
5605
5711
  lessEqual: () => lessEqual,
5606
5712
  linalg: () => numpy_linalg_exports,
@@ -5649,6 +5755,7 @@ __export(numpy_exports, {
5649
5755
  remainder: () => remainder,
5650
5756
  repeat: () => repeat,
5651
5757
  reshape: () => reshape,
5758
+ rightShift: () => rightShift,
5652
5759
  rint: () => rint,
5653
5760
  round: () => round,
5654
5761
  shape: () => shape,
@@ -5763,6 +5870,44 @@ function logicalXor(x, y) {
5763
5870
  function logicalNot(x) {
5764
5871
  return notEqual(astype(x, DType.Bool), true);
5765
5872
  }
5873
/** Element-wise bitwise AND of `x` and `y`. */
function bitwiseAnd(x, y) {
  const op = "and";
  return bitCombine(x, y, op);
}
5877
/** Element-wise bitwise OR of `x` and `y`. */
function bitwiseOr(x, y) {
  const op = "or";
  return bitCombine(x, y, op);
}
5881
/** Element-wise bitwise XOR of `x` and `y`. */
function bitwiseXor(x, y) {
  const op = "xor";
  return bitCombine(x, y, op);
}
5885
/**
 * Compute element-wise bitwise NOT (inversion).
 *
 * Implemented as XOR with the all-ones value of the input's dtype.
 * @throws {TypeError} For dtypes other than bool, uint32 and int32.
 */
function invert(x) {
  const arr = fudgeArray(x);
  const allOnesByDtype = new Map([
    [DType.Bool, true],
    [DType.Uint32, 0xffffffff],
    [DType.Int32, -1],
  ]);
  const { dtype } = arr;
  if (!allOnesByDtype.has(dtype)) throw new TypeError(`invert: unsupported dtype ${dtype}`);
  return bitCombine(arr, allOnesByDtype.get(dtype), "xor");
}
5903
/** Element-wise left shift of `x` by `y` bits. */
function leftShift(x, y) {
  const op = "shl";
  return bitShift(x, y, op);
}
5907
/** Element-wise right shift of `x` by `y` bits. */
function rightShift(x, y) {
  const op = "shr";
  return bitShift(x, y, op);
}
5766
5911
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5767
5912
  const where = where$1;
5768
5913
  /**
@@ -7193,7 +7338,7 @@ __export(lax_exports, {
7193
7338
  stopGradient: () => stopGradient$1,
7194
7339
  topK: () => topK
7195
7340
  });
7196
- const JsArray = globalThis.Array;
7341
+ const JsArray$1 = globalThis.Array;
7197
7342
  /** Elementwise bitcast an array into a new dtype. */
7198
7343
  function bitcastConvertType(x, newDtype) {
7199
7344
  return fudgeArray(x).view(newDtype);
@@ -7380,7 +7525,7 @@ function convTransposePadding(k, s, padding) {
7380
7525
  } else if (padding === "VALID") {
7381
7526
  padLen = k + s - 2 + Math.max(k - s, 0);
7382
7527
  pad1 = k - 1;
7383
- } else if (JsArray.isArray(padding)) {
7528
+ } else if (JsArray$1.isArray(padding)) {
7384
7529
  const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7385
7530
  pad1 = pads[0];
7386
7531
  padLen = pads[0] + pads[1];
@@ -7899,19 +8044,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
7899
8044
  //#region src/library/random.ts
7900
8045
  var random_exports = {};
7901
8046
  __export(random_exports, {
8047
+ ball: () => ball,
7902
8048
  bernoulli: () => bernoulli,
7903
8049
  bits: () => bits,
7904
8050
  categorical: () => categorical,
7905
8051
  cauchy: () => cauchy,
8052
+ choice: () => choice,
8053
+ doubleSidedMaxwell: () => doubleSidedMaxwell,
7906
8054
  exponential: () => exponential,
8055
+ geometric: () => geometric,
7907
8056
  gumbel: () => gumbel,
7908
8057
  key: () => key,
7909
8058
  laplace: () => laplace,
8059
+ logistic: () => logistic,
8060
+ lognormal: () => lognormal,
8061
+ maxwell: () => maxwell,
7910
8062
  multivariateNormal: () => multivariateNormal,
7911
8063
  normal: () => normal,
8064
+ pareto: () => pareto,
8065
+ permutation: () => permutation,
8066
+ rademacher: () => rademacher,
8067
+ randint: () => randint,
8068
+ rayleigh: () => rayleigh,
7912
8069
  split: () => split,
7913
- uniform: () => uniform
8070
+ triangular: () => triangular,
8071
+ uniform: () => uniform,
8072
+ weibullMin: () => weibullMin
7914
8073
  });
8074
+ const JsArray = globalThis.Array;
7915
8075
  function validateKeyShape(key$1, scalar = false) {
7916
8076
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
7917
8077
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
@@ -7964,6 +8124,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
7964
8124
  else return rand.mul(maxval - minval).add(minval);
7965
8125
  }, { staticArgnums: [1, 2] });
7966
8126
  /**
8127
+ * @function
8128
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
8129
+ *
8130
+ * Only the Euclidean `p=2` case is currently supported.
8131
+ */
8132
const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
  if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
  if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
  const [dirKey, radiusKey] = split(key$1, 2);
  // Direction: a Gaussian vector scaled to unit Euclidean length.
  const g = normal(dirKey, [...shape$1, d]);
  const length = sqrt(g.ref.mul(g.ref).sum(-1, { keepdims: true }));
  // Radius: u^(1/d), computed as exp(log(u) / d).
  const r = exp(log(uniform(radiusKey, [...shape$1, 1])).mul(1 / d));
  return g.div(length).mul(r);
}, { staticArgnums: [1, 2] });
8141
+ /**
7967
8142
  * Sample Bernoulli random variables with given mean (0,1 categorical).
7968
8143
  *
7969
8144
  * Returns a random Boolean array with the specified shape. `p` can be an array
@@ -8025,6 +8200,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
8025
8200
  return tan(u.sub(.5).mul(Math.PI));
8026
8201
  }, { staticArgnums: [1] });
8027
8202
  /**
8203
/**
 * Sample from a population with optional replacement and optional probabilities.
 *
 * This implements the common JAX-compatible cases: integer populations and
 * array populations along `axis`. Probabilities `p`, if provided, are sampled
 * via `categorical(log(p))` and must be a 1D array with one entry per
 * population element.
 *
 * @throws {Error} If `a` is not a non-negative integer (when a number), if `p`
 *   does not match the population size, or if more samples are requested
 *   without replacement than the population holds.
 */
function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
  let n;
  let values = null;
  // Normalized copy of `axis`; the caller's options are left untouched.
  let ax = axis;
  if (typeof a === "number") {
    if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
    n = a;
  } else {
    values = fudgeArray(a);
    ax = checkAxis(axis, values.ndim);
    n = values.shape[ax];
  }
  let indices;
  if (p !== void 0) {
    const probs = fudgeArray(p);
    // A mismatched p would otherwise silently sample indices from the wrong range.
    if (probs.ndim !== 1 || probs.shape[0] !== n) {
      throw new Error(`choice: p must be a 1D array of length ${n}, got shape [${probs.shape}]`);
    }
    indices = categorical(key$1, log(probs), { shape: shape$1, replace });
  } else if (replace) {
    indices = randint(key$1, { minval: 0, maxval: n, shape: shape$1 });
  } else {
    const k = shape$1.reduce((acc, dim) => acc * dim, 1);
    if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
    indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
  }
  if (values === null) return indices;
  // Full slices on the axes before `ax`, then gather `indices` along `ax`.
  const index = JsArray(ax).fill([]);
  index.push(indices);
  return values.slice(...index);
}
8240
+ /**
8241
+ * @function
8242
+ * Sample double-sided Maxwell random values with the provided location and scale.
8243
+ */
8244
const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
  const locArr = fudgeArray(loc);
  const scaleArr = fudgeArray(scale);
  const [signKey, magKey] = split(key$1, 2);
  // A random ±1 sign times a Maxwell magnitude, then scaled and shifted.
  const signs = rademacher(signKey, { shape: shape$1, dtype: DType.Float32 });
  const mags = maxwell(magKey, shape$1);
  return signs.mul(mags).mul(scaleArr).add(locArr);
}, { staticArgnums: [3] });
8253
+ /**
8028
8254
  * @function
8029
8255
  * Sample exponential random values according to `p(x) = exp(-x)`.
8030
8256
  */
@@ -8034,6 +8260,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
8034
8260
  }, { staticArgnums: [1] });
8035
8261
  /**
8036
8262
  * @function
8263
+ * Sample geometric random values: the number of trials until first success.
8264
+ */
8265
const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
  const prob = fudgeArray(p);
  const u = uniform(key$1, shape$1);
  // Inverse CDF: trials = floor(log(1 - u) / log(1 - p)) + 1.
  const failures = floor(log1p(negative(u)).div(log1p(negative(prob))));
  return failures.add(1).astype(dtype);
}, { staticArgnums: [2] });
8269
+ /**
8270
+ * @function
8037
8271
  * Sample from a Gumbel distribution with location 0 and scale 1.
8038
8272
  *
8039
8273
  * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
@@ -8058,6 +8292,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
8058
8292
  }, { staticArgnums: [1] });
8059
8293
  /**
8060
8294
  * @function
8295
/**
 * Sample from a logistic distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling: `x = log(u) - log(1-u)`, with `u` drawn
 * from `[tiny, 1)` (tiny = smallest normal float32) so the result stays finite.
 */
const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
  // u = 0 is a possible uniform draw and would give log(0) = -Infinity.
  const float32Tiny = 1.1754943508222875e-38;
  const u = uniform(key$1, shape$1, { minval: float32Tiny });
  return log(u.ref).sub(log1p(negative(u)));
}, { staticArgnums: [1] });
8303
+ /**
8304
+ * @function
8305
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
8306
+ */
8307
const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
  const sigmaArr = fudgeArray(sigma);
  const z = normal(key$1, shape$1);
  // exp(sigma * z) with z standard normal.
  return exp(z.mul(sigmaArr));
}, { staticArgnums: [2] });
8311
+ /**
8312
+ * @function
8313
+ * Sample Maxwell-distributed random values.
8314
+ */
8315
const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
  // Euclidean length of a 3-component standard normal vector.
  const samples = normal(key$1, [...shape$1, 3]);
  const sumSq = samples.ref.mul(samples).sum(-1);
  return sqrt(sumSq);
}, { staticArgnums: [1] });
8319
+ /**
8320
+ * @function
8061
8321
  * Sample multivariate normal random values with given mean and covariance.
8062
8322
  *
8063
8323
  * The values are returned with the given shape, along with the final dimension
@@ -8098,6 +8358,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
8098
8358
  const theta = u2.mul(2 * Math.PI);
8099
8359
  return radius.mul(cos(theta));
8100
8360
  }, { staticArgnums: [1] });
8361
+ /**
8362
+ * @function
8363
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
8364
+ */
8365
const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
  const shapeParam = fudgeArray(b);
  // exp(E / b) with E ~ Exp(1).
  const e = exponential(key$1, shape$1);
  return exp(e.div(shapeParam));
}, { staticArgnums: [2] });
8369
+ /**
8370
+ * Return a random permutation of an integer range or of an array along `axis`.
8371
+ */
8372
function permutation(key$1, x, axis = 0) {
  if (typeof x !== "number") {
    const arr = fudgeArray(x);
    const ax = checkAxis(axis, arr.ndim);
    const order = permutation(key$1, arr.shape[ax]);
    // Full slices on the axes before `ax`, then gather `order` along `ax`.
    const index = [...JsArray(ax).fill([]), order];
    return arr.slice(...index);
  }
  if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
  // Sorting independent uniform draws yields a random ordering of 0..x-1.
  return argsort(uniform(key$1, [x])).astype(DType.Int32);
}
8384
+ /**
8385
+ * @function
8386
+ * Sample Rademacher random values, uniformly from {-1, 1}.
8387
+ */
8388
const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
  if (dtype === DType.Uint32 || dtype === DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
  // Scalar constants on the key's device in the requested dtype.
  const scalar = (v) => array(v, { dtype, device: key$1.device });
  const plusOne = scalar(1);
  const minusOne = scalar(-1);
  // A fair coin picks +1 or -1 per element.
  return where(bernoulli(key$1, .5, shape$1), plusOne, minusOne);
}, { staticArgnums: [1] });
8400
/**
 * @function
 * Sample integer values uniformly from `[minval, maxval)`.
 *
 * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
 * not divide 2^32, this introduces a modulo bias that grows with the size of
 * the range.
 *
 * @throws {Error} If the bounds are not integers, the range is empty, `dtype`
 *   is not int32/uint32, or `[minval, maxval)` does not fit in `dtype`.
 */
const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = DType.Int32 }) {
  if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
  if (dtype !== DType.Int32 && dtype !== DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
  if (dtype === DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
  // Representable bounds for the output dtype; maxval is exclusive, so it may equal `hi`.
  const [lo, hi] = dtype === DType.Int32 ? [-(2 ** 31), 2 ** 31] : [0, 2 ** 32];
  if (minval < lo || maxval > hi) throw new Error(`randint: range [${minval}, ${maxval}) does not fit in ${dtype}`);
  const range$1 = maxval - minval;
  let sample = bits(key$1, shape$1);
  // A range of exactly 2^32 already covers every 32-bit pattern, and 2^32 itself
  // is not representable as a uint32 modulus (it would wrap to 0).
  if (range$1 < 2 ** 32) sample = sample.mod(range$1);
  return sample.astype(dtype).add(minval);
}, { staticArgnums: [1] });
8415
+ /**
8416
+ * @function
8417
+ * Sample Rayleigh random values with the provided scale parameter.
8418
+ */
8419
const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
  const scaleArr = fudgeArray(scale);
  // sqrt(2E) with E ~ Exp(1), then multiplied by the scale.
  const unit = sqrt(exponential(key$1, shape$1).mul(2));
  return unit.mul(scaleArr);
}, { staticArgnums: [2] });
8423
+ /**
8424
+ * @function
8425
+ * Sample triangular random values on `[left, right]` with the given mode.
8426
+ */
8427
const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
  const a = fudgeArray(left);
  const c = fudgeArray(mode);
  const b = fudgeArray(right);
  const u = uniform(key$1, shape$1);
  // Inverse-CDF sampling. `.ref` is taken before every non-final use of an
  // array; the final use of each array takes it without `.ref`.
  const span = b.ref.sub(a.ref); // b - a
  const below = c.ref.sub(a.ref); // c - a
  const above = b.ref.sub(c); // b - c
  const modeCdf = below.ref.div(span.ref); // CDF value at the mode
  const inLower = u.ref.less(modeCdf);
  const lowerBranch = a.add(sqrt(u.ref.mul(span.ref).mul(below)));
  const upperBranch = b.sub(sqrt(negative(u).add(1).mul(span).mul(above)));
  return where(inLower, lowerBranch, upperBranch);
}, { staticArgnums: [4] });
8441
+ /**
8442
+ * @function
8443
+ * Sample Weibull minimum random values.
8444
+ *
8445
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
8446
+ */
8447
const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
  const lambda = fudgeArray(scale);
  const k = fudgeArray(concentration);
  // E^(1/k) computed as exp(log(E) / k), with E ~ Exp(1).
  const powered = exp(log(exponential(key$1, shape$1)).div(k));
  return lambda.mul(powered);
}, { staticArgnums: [3] });
8101
8452
 
8102
8453
  //#endregion
8103
8454
  //#region src/library/scipy-special.ts
@@ -8291,4 +8642,4 @@ async function devicePut(x, device) {
8291
8642
  }
8292
8643
 
8293
8644
  //#endregion
8294
- export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
8645
+ export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, getWebGPUDevice, grad, hessian, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, profiler, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-Ctqs8la1.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DI-V78Rk.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
458
458
  else source = `min(${a}, ${b})`;
459
459
  else if (op === AluOp.Max) if (dtype === DType.Bool) source = `(${a} || ${b})`;
460
460
  else source = `max(${a}, ${b})`;
461
+ else if (op === AluOp.BitCombine) {
462
+ let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
463
+ if (dtype === DType.Bool) infix = infix + infix;
464
+ source = `(${a} ${infix} ${b})`;
465
+ } else if (op === AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
466
+ else source = `(${a} >> ${b})`;
461
467
  } else if (AluGroup.Compare.has(op)) {
462
468
  const a = gen(src[0]);
463
469
  const b = gen(src[1]);
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DMauYnfl.cjs');
1
+ const require_backend = require('./backend-x-6vqzIM.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -458,6 +458,12 @@ function generateExpression(exp, args, inputDtypes) {
458
458
  else source = `min(${a}, ${b})`;
459
459
  else if (op === require_backend.AluOp.Max) if (dtype === require_backend.DType.Bool) source = `(${a} || ${b})`;
460
460
  else source = `max(${a}, ${b})`;
461
+ else if (op === require_backend.AluOp.BitCombine) {
462
+ let infix = arg === "and" ? "&" : arg === "or" ? "|" : "^";
463
+ if (dtype === require_backend.DType.Bool) infix = infix + infix;
464
+ source = `(${a} ${infix} ${b})`;
465
+ } else if (op === require_backend.AluOp.BitShift) if (arg === "shl") source = `(${a} << ${b})`;
466
+ else source = `(${a} >> ${b})`;
461
467
  } else if (require_backend.AluGroup.Compare.has(op)) {
462
468
  const a = gen(src[0]);
463
469
  const b = gen(src[1]);