@jax-js/jax 0.1.10 → 0.1.12

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
30
30
  }) : target, mod$1));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-DMauYnfl.cjs');
33
+ const require_backend = require('./backend-x-6vqzIM.cjs');
34
34
 
35
35
  //#region src/frontend/convolution.ts
36
36
  /**
@@ -240,7 +240,7 @@ __export(tree_exports, {
240
240
  structure: () => structure,
241
241
  unflatten: () => unflatten
242
242
  });
243
- const JsArray$2 = globalThis.Array;
243
+ const JsArray$3 = globalThis.Array;
244
244
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
245
245
  NodeType$1["Array"] = "Array";
246
246
  NodeType$1["Object"] = "Object";
@@ -288,7 +288,7 @@ function flatten(tree) {
288
288
  return [leaves$1, treedef];
289
289
  }
290
290
  function _flatten(tree, leaves$1) {
291
- if (JsArray$2.isArray(tree)) {
291
+ if (JsArray$3.isArray(tree)) {
292
292
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
293
293
  return new JsTreeDef(NodeType.Array, null, childTrees);
294
294
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -364,6 +364,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
364
364
  Primitive$1["Mod"] = "mod";
365
365
  Primitive$1["Min"] = "min";
366
366
  Primitive$1["Max"] = "max";
367
+ Primitive$1["BitCombine"] = "bit_combine";
368
+ Primitive$1["BitShift"] = "bit_shift";
367
369
  Primitive$1["Neg"] = "neg";
368
370
  Primitive$1["Reciprocal"] = "reciprocal";
369
371
  Primitive$1["Floor"] = "floor";
@@ -437,6 +439,12 @@ function min$1(x, y) {
437
439
  function max$1(x, y) {
438
440
  return bind1(Primitive.Max, [x, y]);
439
441
  }
442
+ function bitCombine(x, y, op) {
443
+ return bind1(Primitive.BitCombine, [x, y], { op });
444
+ }
445
+ function bitShift(x, y, op) {
446
+ return bind1(Primitive.BitShift, [x, y], { op });
447
+ }
440
448
  function neg(x) {
441
449
  return bind1(Primitive.Neg, [x]);
442
450
  }
@@ -1655,6 +1663,16 @@ const abstractEvalRules = {
1655
1663
  [Primitive.Mod]: binopAbstractEval,
1656
1664
  [Primitive.Min]: binopAbstractEval,
1657
1665
  [Primitive.Max]: binopAbstractEval,
1666
+ [Primitive.BitCombine]([x, y]) {
1667
+ const aval = promoteAvals(x, y);
1668
+ if (require_backend.isFloatDtype(aval.dtype)) throw new TypeError(`bitwise operations require integer or boolean inputs, got ${aval.dtype}`);
1669
+ return [aval];
1670
+ },
1671
+ [Primitive.BitShift]([x, y]) {
1672
+ const shape$1 = require_backend.generalBroadcast(x.shape, y.shape);
1673
+ if (require_backend.isFloatDtype(x.dtype) || require_backend.isFloatDtype(y.dtype) || x.dtype === require_backend.DType.Bool || y.dtype === require_backend.DType.Bool) throw new TypeError(`bit shift operations require integer inputs, got ${x} and ${y}`);
1674
+ return [new ShapedArray(shape$1, x.dtype, x.weakType)];
1675
+ },
1658
1676
  [Primitive.Neg]: vectorizedUnopAbstractEval,
1659
1677
  [Primitive.Reciprocal]: vectorizedUnopAbstractEval,
1660
1678
  [Primitive.Floor]: vectorizedUnopAbstractEval,
@@ -2190,6 +2208,8 @@ const jitRules = {
2190
2208
  [Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
2191
2209
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
2192
2210
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
2211
+ [Primitive.BitCombine]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitCombine(a, b, op)),
2212
+ [Primitive.BitShift]: broadcastedJit(([a, b], { op }) => require_backend.AluExp.bitShift(a, b, op)),
2193
2213
  [Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
2194
2214
  [Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
2195
2215
  [Primitive.Floor]: unopJit(require_backend.AluExp.floor),
@@ -2382,7 +2402,9 @@ function splitGraphDataflow(backend, jaxpr) {
2382
2402
  case Primitive.Idiv:
2383
2403
  case Primitive.Mod:
2384
2404
  case Primitive.Min:
2385
- case Primitive.Max: {
2405
+ case Primitive.Max:
2406
+ case Primitive.BitCombine:
2407
+ case Primitive.BitShift: {
2386
2408
  const otherInput = nextEqn.inputs.find((v) => v !== outVar);
2387
2409
  if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
2388
2410
  head = usages[0];
@@ -2473,7 +2495,7 @@ function splitGraphDataflow(backend, jaxpr) {
2473
2495
 
2474
2496
  //#endregion
2475
2497
  //#region src/frontend/array.ts
2476
- const JsArray$1 = globalThis.Array;
2498
+ const JsArray$2 = globalThis.Array;
2477
2499
  const inlineArrayLimit = 128;
2478
2500
  /** Version of pureArray with fudged types. */
2479
2501
  const fudgeArray = pureArray;
@@ -2913,6 +2935,15 @@ var Array$1 = class Array$1 extends Tracer {
2913
2935
  this.#check();
2914
2936
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
2915
2937
  if (this.#source instanceof require_backend.AluExp) {
2938
+ let resolvedSource;
2939
+ if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
2940
+ const byteLength = this.#st.size * require_backend.byteWidth(this.#dtype);
2941
+ const initialData = new Uint8Array(byteLength);
2942
+ require_backend.dtypedArray(this.#dtype, initialData).fill(resolvedSource);
2943
+ this.#source = this.#backend.malloc(byteLength, initialData);
2944
+ this.#st = require_backend.ShapeTracker.fromShape(this.shape);
2945
+ return;
2946
+ }
2916
2947
  const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
2917
2948
  const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
2918
2949
  const output = this.#backend.malloc(kernel.bytes);
@@ -3021,6 +3052,42 @@ var Array$1 = class Array$1 extends Tracer {
3021
3052
  return require_backend.dtypedArray(this.dtype, buf);
3022
3053
  }
3023
3054
  /**
3055
+ * Return this array as a WebGPU buffer (with `STORAGE | COPY_SRC`).
3056
+ *
3057
+ * Only available on the WebGPU backend. The array's memory is still managed
3058
+ * by jax-js, and it will be freed when the buffer is no longer in use. You
3059
+ * _should not_ mutate the buffer's contents.
3060
+ *
3061
+ * Note that the GPU buffer may be slightly larger than the array's size; it
3062
+ * will always be aligned to 4 bytes.
3063
+ */
3064
+ async gpuBuffer() {
3065
+ if (this.device !== "webgpu") throw new Error(`gpuBuffer() is only available on WebGPU backend`);
3066
+ this.#realize();
3067
+ const pending = this.#pending;
3068
+ if (pending) {
3069
+ await Promise.all(pending.map((p) => p.prepare()));
3070
+ for (const p of pending) p.submit();
3071
+ }
3072
+ const backend = this.#backend;
3073
+ const { buffer } = backend.buffers.get(this.#source);
3074
+ this.dispose();
3075
+ return buffer;
3076
+ }
3077
+ /** Synchronous version of `Array.gpuBuffer()`. */
3078
+ gpuBufferSync() {
3079
+ if (this.device !== "webgpu") throw new Error(`gpuBufferSync() is only available on WebGPU backend`);
3080
+ this.#realize();
3081
+ for (const p of this.#pending) {
3082
+ p.prepareSync();
3083
+ p.submit();
3084
+ }
3085
+ const backend = this.#backend;
3086
+ const { buffer } = backend.buffers.get(this.#source);
3087
+ this.dispose();
3088
+ return buffer;
3089
+ }
3090
+ /**
3024
3091
  * Convert this array into a JavaScript object.
3025
3092
  *
3026
3093
  * This is a blocking operation that will compile all of the shaders and wait
@@ -3067,6 +3134,14 @@ var Array$1 = class Array$1 extends Tracer {
3067
3134
  [Primitive.Max]([x, y]) {
3068
3135
  return [x.#binary(require_backend.AluOp.Max, y)];
3069
3136
  },
3137
+ [Primitive.BitCombine]([x, y], { op }) {
3138
+ const custom = (src) => require_backend.AluExp.bitCombine(src[0], src[1], op);
3139
+ return [Array$1.#naryCustom("bit_combine", custom, [x, y])];
3140
+ },
3141
+ [Primitive.BitShift]([x, y], { op }) {
3142
+ const custom = (src) => require_backend.AluExp.bitShift(src[0], src[1], op);
3143
+ return [Array$1.#naryCustom("bit_shift", custom, [x, y], { dtypeOverride: [void 0, y.dtype] })];
3144
+ },
3070
3145
  [Primitive.Neg]([x]) {
3071
3146
  return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
3072
3147
  },
@@ -3319,7 +3394,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3319
3394
  if (!shape$1) {
3320
3395
  shape$1 = [];
3321
3396
  let cur = values;
3322
- while (JsArray$1.isArray(cur)) {
3397
+ while (JsArray$2.isArray(cur)) {
3323
3398
  shape$1.push(cur.length);
3324
3399
  cur = cur[0];
3325
3400
  }
@@ -3759,6 +3834,8 @@ const vmapRules = {
3759
3834
  [Primitive.Mod]: broadcastBatcher(Primitive.Mod),
3760
3835
  [Primitive.Min]: broadcastBatcher(Primitive.Min),
3761
3836
  [Primitive.Max]: broadcastBatcher(Primitive.Max),
3837
+ [Primitive.BitCombine]: broadcastBatcher(Primitive.BitCombine),
3838
+ [Primitive.BitShift]: broadcastBatcher(Primitive.BitShift),
3762
3839
  [Primitive.Neg]: unopBatcher(Primitive.Neg),
3763
3840
  [Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
3764
3841
  [Primitive.Floor]: unopBatcher(Primitive.Floor),
@@ -4082,6 +4159,8 @@ const jvpRules = {
4082
4159
  [Primitive.Max]([x, y], [dx, dy]) {
4083
4160
  return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
4084
4161
  },
4162
+ [Primitive.BitCombine]: zeroTangentsJvp(Primitive.BitCombine),
4163
+ [Primitive.BitShift]: zeroTangentsJvp(Primitive.BitShift),
4085
4164
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
4086
4165
  [Primitive.Reciprocal]([x], [dx]) {
4087
4166
  const xRecip = reciprocal$1(x.ref);
@@ -4199,7 +4278,7 @@ const jvpRules = {
4199
4278
  return [[L], [dL]];
4200
4279
  },
4201
4280
  [Primitive.LU]([a], [da]) {
4202
- const [luMatrix, pivots, permutation] = lu$1(a);
4281
+ const [luMatrix, pivots, permutation$1] = lu$1(a);
4203
4282
  const [m, n] = a.shape.slice(-2);
4204
4283
  const k = Math.min(m, n);
4205
4284
  const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
@@ -4211,7 +4290,7 @@ const jvpRules = {
4211
4290
  const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4212
4291
  const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4213
4292
  const U = uPadded.add(uEye);
4214
- const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4293
+ const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
4215
4294
  const pda = batchMatmulT(P, mT(da));
4216
4295
  const la = mT(triangularSolve$1(L.ref, mT(pda), {
4217
4296
  lower: true,
@@ -4223,11 +4302,11 @@ const jvpRules = {
4223
4302
  return [[
4224
4303
  luMatrix,
4225
4304
  pivots,
4226
- permutation
4305
+ permutation$1
4227
4306
  ], [
4228
4307
  lDot.add(uDot),
4229
4308
  zerosLike$1(pivots.ref),
4230
- zerosLike$1(permutation.ref)
4309
+ zerosLike$1(permutation$1.ref)
4231
4310
  ]];
4232
4311
  },
4233
4312
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
@@ -5273,7 +5352,8 @@ __export(numpy_linalg_exports, {
5273
5352
  solve: () => solve,
5274
5353
  tensordot: () => tensordot,
5275
5354
  trace: () => trace,
5276
- vecdot: () => vecdot
5355
+ vecdot: () => vecdot,
5356
+ vectorNorm: () => vectorNorm
5277
5357
  });
5278
5358
  function checkSquare(name, a) {
5279
5359
  if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
@@ -5308,8 +5388,8 @@ function cross$1(x1, x2, axis = -1) {
5308
5388
  function det(a) {
5309
5389
  a = fudgeArray(a);
5310
5390
  const n = checkSquare("det", a);
5311
- const [lu$2, pivots, permutation] = lu(a);
5312
- permutation.dispose();
5391
+ const [lu$2, pivots, permutation$1] = lu(a);
5392
+ permutation$1.dispose();
5313
5393
  const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5314
5394
  const sign$1 = parity.mul(-2).add(1);
5315
5395
  const diag$1 = lu$2.diagonal(0, -1, -2);
@@ -5398,8 +5478,8 @@ function matrixPower(a, n) {
5398
5478
  function slogdet(a) {
5399
5479
  a = fudgeArray(a);
5400
5480
  const n = checkSquare("slogdet", a);
5401
- const [lu$2, pivots, permutation] = lu(a);
5402
- permutation.dispose();
5481
+ const [lu$2, pivots, permutation$1] = lu(a);
5482
+ permutation$1.dispose();
5403
5483
  let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5404
5484
  const diag$1 = lu$2.diagonal(0, -1, -2);
5405
5485
  parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
@@ -5437,9 +5517,9 @@ function solve(a, b) {
5437
5517
  n,
5438
5518
  m
5439
5519
  ]);
5440
- const [lu$2, pivots, permutation] = lu(a);
5520
+ const [lu$2, pivots, permutation$1] = lu(a);
5441
5521
  pivots.dispose();
5442
- const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5522
+ const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
5443
5523
  const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5444
5524
  leftSide: true,
5445
5525
  lower: true,
@@ -5452,6 +5532,23 @@ function solve(a, b) {
5452
5532
  if (bIs1d) x = squeeze(x, -1);
5453
5533
  return x;
5454
5534
  }
5535
+ /**
5536
+ * Compute the vector norm of an array.
5537
+ *
5538
+ * @param x - Input array.
5539
+ * @param ord - Order of the norm (default 2). Supports `Infinity`, `-Infinity`, `0`, or any real number.
5540
+ * @param axis - Axis/axes to reduce over (default: all axes).
5541
+ * @param keepdims - Whether to keep reduced dimensions as size 1.
5542
+ * @returns The norm of `x`, reduced over the given axes.
5543
+ */
5544
+ function vectorNorm(x, { ord = 2, axis = null, keepdims = false } = {}) {
5545
+ x = fudgeArray(x);
5546
+ const ax = axis ?? null;
5547
+ if (ord === Infinity) return max(absolute(x), ax, { keepdims });
5548
+ else if (ord === -Infinity) return min(absolute(x), ax, { keepdims });
5549
+ else if (ord === 0) return x.notEqual(0).astype(x.dtype).sum(ax, { keepdims });
5550
+ else return power(power(absolute(x), ord).sum(ax, { keepdims }), 1 / ord);
5551
+ }
5455
5552
 
5456
5553
  //#endregion
5457
5554
  //#region src/library/numpy/dtype-info.ts
@@ -5571,6 +5668,13 @@ __export(numpy_exports, {
5571
5668
  atan2: () => atan2,
5572
5669
  atanh: () => arctanh,
5573
5670
  average: () => average,
5671
+ bitwiseAnd: () => bitwiseAnd,
5672
+ bitwiseInvert: () => invert,
5673
+ bitwiseLeftShift: () => leftShift,
5674
+ bitwiseNot: () => invert,
5675
+ bitwiseOr: () => bitwiseOr,
5676
+ bitwiseRightShift: () => rightShift,
5677
+ bitwiseXor: () => bitwiseXor,
5574
5678
  bool: () => bool,
5575
5679
  broadcastArrays: () => broadcastArrays,
5576
5680
  broadcastShapes: () => broadcastShapes,
@@ -5632,12 +5736,14 @@ __export(numpy_exports, {
5632
5736
  inf: () => inf,
5633
5737
  inner: () => inner,
5634
5738
  int32: () => int32,
5739
+ invert: () => invert,
5635
5740
  isfinite: () => isfinite,
5636
5741
  isinf: () => isinf,
5637
5742
  isnan: () => isnan,
5638
5743
  isneginf: () => isneginf,
5639
5744
  isposinf: () => isposinf,
5640
5745
  ldexp: () => ldexp,
5746
+ leftShift: () => leftShift,
5641
5747
  less: () => less,
5642
5748
  lessEqual: () => lessEqual,
5643
5749
  linalg: () => numpy_linalg_exports,
@@ -5686,6 +5792,7 @@ __export(numpy_exports, {
5686
5792
  remainder: () => remainder,
5687
5793
  repeat: () => repeat,
5688
5794
  reshape: () => reshape,
5795
+ rightShift: () => rightShift,
5689
5796
  rint: () => rint,
5690
5797
  round: () => round,
5691
5798
  shape: () => shape,
@@ -5800,6 +5907,44 @@ function logicalXor(x, y) {
5800
5907
  function logicalNot(x) {
5801
5908
  return notEqual(astype(x, require_backend.DType.Bool), true);
5802
5909
  }
5910
+ /** Compute element-wise bitwise AND. */
5911
+ function bitwiseAnd(x, y) {
5912
+ return bitCombine(x, y, "and");
5913
+ }
5914
+ /** Compute element-wise bitwise OR. */
5915
+ function bitwiseOr(x, y) {
5916
+ return bitCombine(x, y, "or");
5917
+ }
5918
+ /** Compute element-wise bitwise XOR. */
5919
+ function bitwiseXor(x, y) {
5920
+ return bitCombine(x, y, "xor");
5921
+ }
5922
+ /** Compute element-wise bitwise NOT (inversion). */
5923
+ function invert(x) {
5924
+ const arr = fudgeArray(x);
5925
+ let allOnes;
5926
+ switch (arr.dtype) {
5927
+ case require_backend.DType.Bool:
5928
+ allOnes = true;
5929
+ break;
5930
+ case require_backend.DType.Uint32:
5931
+ allOnes = 4294967295;
5932
+ break;
5933
+ case require_backend.DType.Int32:
5934
+ allOnes = -1;
5935
+ break;
5936
+ default: throw new TypeError(`invert: unsupported dtype ${arr.dtype}`);
5937
+ }
5938
+ return bitCombine(arr, allOnes, "xor");
5939
+ }
5940
+ /** Compute element-wise left bit shift. */
5941
+ function leftShift(x, y) {
5942
+ return bitShift(x, y, "shl");
5943
+ }
5944
+ /** Compute element-wise right bit shift. */
5945
+ function rightShift(x, y) {
5946
+ return bitShift(x, y, "shr");
5947
+ }
5803
5948
  /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
5804
5949
  const where = where$1;
5805
5950
  /**
@@ -7230,7 +7375,7 @@ __export(lax_exports, {
7230
7375
  stopGradient: () => stopGradient$1,
7231
7376
  topK: () => topK
7232
7377
  });
7233
- const JsArray = globalThis.Array;
7378
+ const JsArray$1 = globalThis.Array;
7234
7379
  /** Elementwise bitcast an array into a new dtype. */
7235
7380
  function bitcastConvertType(x, newDtype) {
7236
7381
  return fudgeArray(x).view(newDtype);
@@ -7417,7 +7562,7 @@ function convTransposePadding(k, s, padding) {
7417
7562
  } else if (padding === "VALID") {
7418
7563
  padLen = k + s - 2 + Math.max(k - s, 0);
7419
7564
  pad1 = k - 1;
7420
- } else if (JsArray.isArray(padding)) {
7565
+ } else if (JsArray$1.isArray(padding)) {
7421
7566
  const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7422
7567
  pad1 = pads[0];
7423
7568
  padLen = pads[0] + pads[1];
@@ -7936,19 +8081,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
7936
8081
  //#region src/library/random.ts
7937
8082
  var random_exports = {};
7938
8083
  __export(random_exports, {
8084
+ ball: () => ball,
7939
8085
  bernoulli: () => bernoulli,
7940
8086
  bits: () => bits,
7941
8087
  categorical: () => categorical,
7942
8088
  cauchy: () => cauchy,
8089
+ choice: () => choice,
8090
+ doubleSidedMaxwell: () => doubleSidedMaxwell,
7943
8091
  exponential: () => exponential,
8092
+ geometric: () => geometric,
7944
8093
  gumbel: () => gumbel,
7945
8094
  key: () => key,
7946
8095
  laplace: () => laplace,
8096
+ logistic: () => logistic,
8097
+ lognormal: () => lognormal,
8098
+ maxwell: () => maxwell,
7947
8099
  multivariateNormal: () => multivariateNormal,
7948
8100
  normal: () => normal,
8101
+ pareto: () => pareto,
8102
+ permutation: () => permutation,
8103
+ rademacher: () => rademacher,
8104
+ randint: () => randint,
8105
+ rayleigh: () => rayleigh,
7949
8106
  split: () => split,
7950
- uniform: () => uniform
8107
+ triangular: () => triangular,
8108
+ uniform: () => uniform,
8109
+ weibullMin: () => weibullMin
7951
8110
  });
8111
+ const JsArray = globalThis.Array;
7952
8112
  function validateKeyShape(key$1, scalar = false) {
7953
8113
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
7954
8114
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
@@ -8001,6 +8161,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
8001
8161
  else return rand.mul(maxval - minval).add(minval);
8002
8162
  }, { staticArgnums: [1, 2] });
8003
8163
  /**
8164
+ * @function
8165
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
8166
+ *
8167
+ * Only the Euclidean `p=2` case is currently supported.
8168
+ */
8169
+ const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
8170
+ if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
8171
+ if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
8172
+ const [k1, k2] = split(key$1, 2);
8173
+ const z = normal(k1, [...shape$1, d]);
8174
+ const norm = sqrt(z.ref.mul(z.ref).sum(-1, { keepdims: true }));
8175
+ const radius = exp(log(uniform(k2, [...shape$1, 1])).mul(1 / d));
8176
+ return z.div(norm).mul(radius);
8177
+ }, { staticArgnums: [1, 2] });
8178
+ /**
8004
8179
  * Sample Bernoulli random variables with given mean (0,1 categorical).
8005
8180
  *
8006
8181
  * Returns a random Boolean array with the specified shape. `p` can be an array
@@ -8062,6 +8237,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
8062
8237
  return tan(u.sub(.5).mul(Math.PI));
8063
8238
  }, { staticArgnums: [1] });
8064
8239
  /**
8240
+ * Sample from a population with optional replacement and optional probabilities.
8241
+ *
8242
+ * This implements the common JAX-compatible cases: integer populations and
8243
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
8244
+ * via `categorical(log(p))`.
8245
+ */
8246
+ function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
8247
+ let n;
8248
+ let values = null;
8249
+ if (typeof a === "number") {
8250
+ if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
8251
+ n = a;
8252
+ } else {
8253
+ values = fudgeArray(a);
8254
+ axis = require_backend.checkAxis(axis, values.ndim);
8255
+ n = values.shape[axis];
8256
+ }
8257
+ let indices;
8258
+ if (p !== void 0) indices = categorical(key$1, log(p), {
8259
+ shape: shape$1,
8260
+ replace
8261
+ });
8262
+ else if (replace) indices = randint(key$1, {
8263
+ minval: 0,
8264
+ maxval: n,
8265
+ shape: shape$1
8266
+ });
8267
+ else {
8268
+ const k = shape$1.reduce((acc, x) => acc * x, 1);
8269
+ if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
8270
+ indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
8271
+ }
8272
+ if (values === null) return indices;
8273
+ const index = JsArray(axis).fill([]);
8274
+ index.push(indices);
8275
+ return values.slice(...index);
8276
+ }
8277
+ /**
8278
+ * @function
8279
+ * Sample double-sided Maxwell random values with the provided location and scale.
8280
+ */
8281
+ const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
8282
+ loc = fudgeArray(loc);
8283
+ scale = fudgeArray(scale);
8284
+ const [k1, k2] = split(key$1, 2);
8285
+ return rademacher(k1, {
8286
+ shape: shape$1,
8287
+ dtype: require_backend.DType.Float32
8288
+ }).mul(maxwell(k2, shape$1)).mul(scale).add(loc);
8289
+ }, { staticArgnums: [3] });
8290
+ /**
8065
8291
  * @function
8066
8292
  * Sample exponential random values according to `p(x) = exp(-x)`.
8067
8293
  */
@@ -8071,6 +8297,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
8071
8297
  }, { staticArgnums: [1] });
8072
8298
  /**
8073
8299
  * @function
8300
+ * Sample geometric random values: the number of trials until first success.
8301
+ */
8302
+ const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
8303
+ p = fudgeArray(p);
8304
+ return floor(log1p(negative(uniform(key$1, shape$1))).div(log1p(negative(p)))).add(1).astype(dtype);
8305
+ }, { staticArgnums: [2] });
8306
+ /**
8307
+ * @function
8074
8308
  * Sample from a Gumbel distribution with location 0 and scale 1.
8075
8309
  *
8076
8310
  * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
@@ -8095,6 +8329,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
8095
8329
  }, { staticArgnums: [1] });
8096
8330
  /**
8097
8331
  * @function
8332
+ * Sample from a logistic distribution with location 0 and scale 1.
8333
+ *
8334
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
8335
+ */
8336
+ const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
8337
+ const u = uniform(key$1, shape$1);
8338
+ return log(u.ref).sub(log1p(negative(u)));
8339
+ }, { staticArgnums: [1] });
8340
+ /**
8341
+ * @function
8342
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
8343
+ */
8344
+ const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
8345
+ sigma = fudgeArray(sigma);
8346
+ return exp(normal(key$1, shape$1).mul(sigma));
8347
+ }, { staticArgnums: [2] });
8348
+ /**
8349
+ * @function
8350
+ * Sample Maxwell-distributed random values.
8351
+ */
8352
+ const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
8353
+ const z = normal(key$1, [...shape$1, 3]);
8354
+ return sqrt(z.ref.mul(z).sum(-1));
8355
+ }, { staticArgnums: [1] });
8356
+ /**
8357
+ * @function
8098
8358
  * Sample multivariate normal random values with given mean and covariance.
8099
8359
  *
8100
8360
  * The values are returned with the given shape, along with the final dimension
@@ -8135,6 +8395,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
8135
8395
  const theta = u2.mul(2 * Math.PI);
8136
8396
  return radius.mul(cos(theta));
8137
8397
  }, { staticArgnums: [1] });
8398
+ /**
8399
+ * @function
8400
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
8401
+ */
8402
+ const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
8403
+ b = fudgeArray(b);
8404
+ return exp(exponential(key$1, shape$1).div(b));
8405
+ }, { staticArgnums: [2] });
8406
+ /**
8407
+ * Return a random permutation of an integer range or of an array along `axis`.
8408
+ */
8409
+ function permutation(key$1, x, axis = 0) {
8410
+ if (typeof x === "number") {
8411
+ if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
8412
+ return argsort(uniform(key$1, [x])).astype(require_backend.DType.Int32);
8413
+ }
8414
+ const arr = fudgeArray(x);
8415
+ axis = require_backend.checkAxis(axis, arr.ndim);
8416
+ const perm = permutation(key$1, arr.shape[axis]);
8417
+ const index = JsArray(axis).fill([]);
8418
+ index.push(perm);
8419
+ return arr.slice(...index);
8420
+ }
8421
+ /**
8422
+ * @function
8423
+ * Sample Rademacher random values, uniformly from {-1, 1}.
8424
+ */
8425
+ const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = require_backend.DType.Int32 } = {}) {
8426
+ if (dtype === require_backend.DType.Uint32 || dtype === require_backend.DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
8427
+ const one = array(1, {
8428
+ dtype,
8429
+ device: key$1.device
8430
+ });
8431
+ const minusOne = array(-1, {
8432
+ dtype,
8433
+ device: key$1.device
8434
+ });
8435
+ return where(bernoulli(key$1, .5, shape$1), one, minusOne);
8436
+ }, { staticArgnums: [1] });
8437
+ /**
8438
+ * @function
8439
+ * Sample integer values uniformly from `[minval, maxval)`.
8440
+ *
8441
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
8442
+ * not divide 2^32, this introduces a very small modulo bias.
8443
+ */
8444
+ const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = require_backend.DType.Int32 }) {
8445
+ if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
8446
+ if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
8447
+ if (dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
8448
+ if (dtype === require_backend.DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
8449
+ const range$1 = maxval - minval;
8450
+ return bits(key$1, shape$1).mod(range$1).astype(dtype).add(minval);
8451
+ }, { staticArgnums: [1] });
8452
+ /**
8453
+ * @function
8454
+ * Sample Rayleigh random values with the provided scale parameter.
8455
+ */
8456
+ const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
8457
+ scale = fudgeArray(scale);
8458
+ return sqrt(exponential(key$1, shape$1).mul(2)).mul(scale);
8459
+ }, { staticArgnums: [2] });
8460
+ /**
8461
+ * @function
8462
+ * Sample triangular random values on `[left, right]` with the given mode.
8463
+ */
8464
+ const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
8465
+ left = fudgeArray(left);
8466
+ mode = fudgeArray(mode);
8467
+ right = fudgeArray(right);
8468
+ const u = uniform(key$1, shape$1);
8469
+ const width = right.ref.sub(left.ref);
8470
+ const leftSpan = mode.ref.sub(left.ref);
8471
+ const rightSpan = right.ref.sub(mode);
8472
+ const cutoff = leftSpan.ref.div(width.ref);
8473
+ const cond = u.ref.less(cutoff);
8474
+ const lower = left.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
8475
+ const upper = right.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
8476
+ return where(cond, lower, upper);
8477
+ }, { staticArgnums: [4] });
8478
+ /**
8479
+ * @function
8480
+ * Sample Weibull minimum random values.
8481
+ *
8482
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
8483
+ */
8484
+ const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
8485
+ scale = fudgeArray(scale);
8486
+ concentration = fudgeArray(concentration);
8487
+ return scale.mul(exp(log(exponential(key$1, shape$1)).div(concentration)));
8488
+ }, { staticArgnums: [3] });
8138
8489
 
8139
8490
  //#endregion
8140
8491
  //#region src/library/scipy-special.ts
@@ -8336,6 +8687,7 @@ exports.blockUntilReady = blockUntilReady;
8336
8687
  exports.defaultDevice = require_backend.defaultDevice;
8337
8688
  exports.devicePut = devicePut;
8338
8689
  exports.devices = require_backend.devices;
8690
+ exports.getWebGPUDevice = require_backend.getWebGPUDevice;
8339
8691
  exports.grad = grad;
8340
8692
  exports.hessian = hessian;
8341
8693
  exports.init = require_backend.init;