@jax-js/jax 0.1.11 → 0.1.13

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -2719,7 +2719,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
2719
2719
  localWindowSize?: number | [number, number];
2720
2720
  }): Array;
2721
2721
  declare namespace random_d_exports {
2722
- export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
2722
+ export { ball, bernoulli, bits, categorical, cauchy, choice, doubleSidedMaxwell, exponential, geometric, gumbel, key, laplace, logistic, lognormal, maxwell, multivariateNormal, normal, pareto, permutation, rademacher, randint, rayleigh, split, triangular, uniform, weibullMin };
2723
2723
  }
2724
2724
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
2725
2725
  declare function key(seed: ArrayLike): Array;
@@ -2735,6 +2735,16 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
2735
2735
  minval?: number | undefined;
2736
2736
  maxval?: number | undefined;
2737
2737
  } | undefined) => Array>;
2738
+ /**
2739
+ * @function
2740
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
2741
+ *
2742
+ * Only the Euclidean `p=2` case is currently supported.
2743
+ */
2744
+ declare const ball: OwnedFunction<(key: ArrayLike, d: number, args_2?: {
2745
+ p?: number | undefined;
2746
+ shape?: number[] | undefined;
2747
+ } | undefined) => Array>;
2738
2748
  /**
2739
2749
  * Sample Bernoulli random variables with given mean (0,1 categorical).
2740
2750
  *
@@ -2775,11 +2785,42 @@ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, arg
2775
2785
  * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
2776
2786
  */
2777
2787
  declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2788
+ /**
2789
+ * Sample from a population with optional replacement and optional probabilities.
2790
+ *
2791
+ * This implements the common JAX-compatible cases: integer populations and
2792
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
2793
+ * via `categorical(log(p))`.
2794
+ */
2795
+ declare function choice(key: Array, a: number | ArrayLike, {
2796
+ shape,
2797
+ replace,
2798
+ p,
2799
+ axis
2800
+ }?: {
2801
+ shape?: number[];
2802
+ replace?: boolean;
2803
+ p?: ArrayLike;
2804
+ axis?: number;
2805
+ }): Array;
2806
+ /**
2807
+ * @function
2808
+ * Sample double-sided Maxwell random values with the provided location and scale.
2809
+ */
2810
+ declare const doubleSidedMaxwell: OwnedFunction<(key: ArrayLike, loc: ArrayLike, scale: ArrayLike, shape?: number[] | undefined) => Array>;
2778
2811
  /**
2779
2812
  * @function
2780
2813
  * Sample exponential random values according to `p(x) = exp(-x)`.
2781
2814
  */
2782
2815
  declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2816
+ /**
2817
+ * @function
2818
+ * Sample geometric random values: the number of trials until first success.
2819
+ */
2820
+ declare const geometric: OwnedFunction<(key: ArrayLike, p: ArrayLike, args_2?: {
2821
+ shape?: number[] | undefined;
2822
+ dtype?: DType | undefined;
2823
+ } | undefined) => Array>;
2783
2824
  /**
2784
2825
  * @function
2785
2826
  * Sample from a Gumbel distribution with location 0 and scale 1.
@@ -2795,6 +2836,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
2795
2836
  * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2796
2837
  */
2797
2838
  declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2839
+ /**
2840
+ * @function
2841
+ * Sample from a logistic distribution with location 0 and scale 1.
2842
+ *
2843
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
2844
+ */
2845
+ declare const logistic: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2846
+ /**
2847
+ * @function
2848
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
2849
+ */
2850
+ declare const lognormal: OwnedFunction<(key: ArrayLike, sigma?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2851
+ /**
2852
+ * @function
2853
+ * Sample Maxwell-distributed random values.
2854
+ */
2855
+ declare const maxwell: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2798
2856
  /**
2799
2857
  * @function
2800
2858
  * Sample multivariate normal random values with given mean and covariance.
@@ -2821,6 +2879,53 @@ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike
2821
2879
  * bitwise identical to JAX.
2822
2880
  */
2823
2881
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2882
+ /**
2883
+ * @function
2884
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
2885
+ */
2886
+ declare const pareto: OwnedFunction<(key: ArrayLike, b: ArrayLike, shape?: number[] | undefined) => Array>;
2887
+ /**
2888
+ * Return a random permutation of an integer range or of an array along `axis`.
2889
+ */
2890
+ declare function permutation(key: Array, x: number | ArrayLike, axis?: number): Array;
2891
+ /**
2892
+ * @function
2893
+ * Sample Rademacher random values, uniformly from {-1, 1}.
2894
+ */
2895
+ declare const rademacher: OwnedFunction<(key: ArrayLike, args_1?: {
2896
+ shape?: number[] | undefined;
2897
+ dtype?: DType | undefined;
2898
+ } | undefined) => Array>;
2899
+ /**
2900
+ * @function
2901
+ * Sample integer values uniformly from `[minval, maxval)`.
2902
+ *
2903
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
2904
+ * not divide 2^32, this introduces a very small modulo bias.
2905
+ */
2906
+ declare const randint: OwnedFunction<(key: ArrayLike, args_1: {
2907
+ minval: number;
2908
+ maxval: number;
2909
+ shape?: number[] | undefined;
2910
+ dtype?: DType | undefined;
2911
+ }) => Array>;
2912
+ /**
2913
+ * @function
2914
+ * Sample Rayleigh random values with the provided scale parameter.
2915
+ */
2916
+ declare const rayleigh: OwnedFunction<(key: ArrayLike, scale?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
2917
+ /**
2918
+ * @function
2919
+ * Sample triangular random values on `[left, right]` with the given mode.
2920
+ */
2921
+ declare const triangular: OwnedFunction<(key: ArrayLike, left: ArrayLike, mode: ArrayLike, right: ArrayLike, shape?: number[] | undefined) => Array>;
2922
+ /**
2923
+ * @function
2924
+ * Sample Weibull minimum random values.
2925
+ *
2926
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
2927
+ */
2928
+ declare const weibullMin: OwnedFunction<(key: ArrayLike, scale: ArrayLike, concentration: ArrayLike, shape?: number[] | undefined) => Array>;
2824
2929
  declare namespace scipy_special_d_exports {
2825
2930
  export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
2826
2931
  }
package/dist/index.js CHANGED
@@ -1,5 +1,5 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DZvR7mZV.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DLEk-B3V.js";
3
3
 
4
4
  //#region src/frontend/convolution.ts
5
5
  /**
@@ -209,7 +209,7 @@ __export(tree_exports, {
209
209
  structure: () => structure,
210
210
  unflatten: () => unflatten
211
211
  });
212
- const JsArray$2 = globalThis.Array;
212
+ const JsArray$3 = globalThis.Array;
213
213
  let NodeType = /* @__PURE__ */ function(NodeType$1) {
214
214
  NodeType$1["Array"] = "Array";
215
215
  NodeType$1["Object"] = "Object";
@@ -257,7 +257,7 @@ function flatten(tree) {
257
257
  return [leaves$1, treedef];
258
258
  }
259
259
  function _flatten(tree, leaves$1) {
260
- if (JsArray$2.isArray(tree)) {
260
+ if (JsArray$3.isArray(tree)) {
261
261
  const childTrees = tree.map((c) => _flatten(c, leaves$1));
262
262
  return new JsTreeDef(NodeType.Array, null, childTrees);
263
263
  } else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
@@ -2460,7 +2460,7 @@ function splitGraphDataflow(backend, jaxpr) {
2460
2460
 
2461
2461
  //#endregion
2462
2462
  //#region src/frontend/array.ts
2463
- const JsArray$1 = globalThis.Array;
2463
+ const JsArray$2 = globalThis.Array;
2464
2464
  const inlineArrayLimit = 128;
2465
2465
  /** Version of pureArray with fudged types. */
2466
2466
  const fudgeArray = pureArray;
@@ -2900,6 +2900,15 @@ var Array$1 = class Array$1 extends Tracer {
2900
2900
  this.#check();
2901
2901
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
2902
2902
  if (this.#source instanceof AluExp) {
2903
+ let resolvedSource;
2904
+ if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
2905
+ const byteLength = this.#st.size * byteWidth(this.#dtype);
2906
+ const initialData = new Uint8Array(byteLength);
2907
+ dtypedArray(this.#dtype, initialData).fill(resolvedSource);
2908
+ this.#source = this.#backend.malloc(byteLength, initialData);
2909
+ this.#st = ShapeTracker.fromShape(this.shape);
2910
+ return;
2911
+ }
2903
2912
  const exp$2 = accessorAluExp(this.#source, this.#st, indices);
2904
2913
  const kernel = new Kernel(0, this.#st.size, exp$2);
2905
2914
  const output = this.#backend.malloc(kernel.bytes);
@@ -3350,7 +3359,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
3350
3359
  if (!shape$1) {
3351
3360
  shape$1 = [];
3352
3361
  let cur = values;
3353
- while (JsArray$1.isArray(cur)) {
3362
+ while (JsArray$2.isArray(cur)) {
3354
3363
  shape$1.push(cur.length);
3355
3364
  cur = cur[0];
3356
3365
  }
@@ -4232,7 +4241,7 @@ const jvpRules = {
4232
4241
  return [[L], [dL]];
4233
4242
  },
4234
4243
  [Primitive.LU]([a], [da]) {
4235
- const [luMatrix, pivots, permutation] = lu$1(a);
4244
+ const [luMatrix, pivots, permutation$1] = lu$1(a);
4236
4245
  const [m, n] = a.shape.slice(-2);
4237
4246
  const k = Math.min(m, n);
4238
4247
  const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
@@ -4244,7 +4253,7 @@ const jvpRules = {
4244
4253
  const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
4245
4254
  const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
4246
4255
  const U = uPadded.add(uEye);
4247
- const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
4256
+ const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
4248
4257
  const pda = batchMatmulT(P, mT(da));
4249
4258
  const la = mT(triangularSolve$1(L.ref, mT(pda), {
4250
4259
  lower: true,
@@ -4256,11 +4265,11 @@ const jvpRules = {
4256
4265
  return [[
4257
4266
  luMatrix,
4258
4267
  pivots,
4259
- permutation
4268
+ permutation$1
4260
4269
  ], [
4261
4270
  lDot.add(uDot),
4262
4271
  zerosLike$1(pivots.ref),
4263
- zerosLike$1(permutation.ref)
4272
+ zerosLike$1(permutation$1.ref)
4264
4273
  ]];
4265
4274
  },
4266
4275
  [Primitive.Jit](primals, tangents, { name, jaxpr }) {
@@ -5342,8 +5351,8 @@ function cross$1(x1, x2, axis = -1) {
5342
5351
  function det(a) {
5343
5352
  a = fudgeArray(a);
5344
5353
  const n = checkSquare("det", a);
5345
- const [lu$2, pivots, permutation] = lu(a);
5346
- permutation.dispose();
5354
+ const [lu$2, pivots, permutation$1] = lu(a);
5355
+ permutation$1.dispose();
5347
5356
  const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
5348
5357
  const sign$1 = parity.mul(-2).add(1);
5349
5358
  const diag$1 = lu$2.diagonal(0, -1, -2);
@@ -5432,8 +5441,8 @@ function matrixPower(a, n) {
5432
5441
  function slogdet(a) {
5433
5442
  a = fudgeArray(a);
5434
5443
  const n = checkSquare("slogdet", a);
5435
- const [lu$2, pivots, permutation] = lu(a);
5436
- permutation.dispose();
5444
+ const [lu$2, pivots, permutation$1] = lu(a);
5445
+ permutation$1.dispose();
5437
5446
  let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
5438
5447
  const diag$1 = lu$2.diagonal(0, -1, -2);
5439
5448
  parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
@@ -5471,9 +5480,9 @@ function solve(a, b) {
5471
5480
  n,
5472
5481
  m
5473
5482
  ]);
5474
- const [lu$2, pivots, permutation] = lu(a);
5483
+ const [lu$2, pivots, permutation$1] = lu(a);
5475
5484
  pivots.dispose();
5476
- const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
5485
+ const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
5477
5486
  const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
5478
5487
  leftSide: true,
5479
5488
  lower: true,
@@ -7329,7 +7338,7 @@ __export(lax_exports, {
7329
7338
  stopGradient: () => stopGradient$1,
7330
7339
  topK: () => topK
7331
7340
  });
7332
- const JsArray = globalThis.Array;
7341
+ const JsArray$1 = globalThis.Array;
7333
7342
  /** Elementwise bitcast an array into a new dtype. */
7334
7343
  function bitcastConvertType(x, newDtype) {
7335
7344
  return fudgeArray(x).view(newDtype);
@@ -7516,7 +7525,7 @@ function convTransposePadding(k, s, padding) {
7516
7525
  } else if (padding === "VALID") {
7517
7526
  padLen = k + s - 2 + Math.max(k - s, 0);
7518
7527
  pad1 = k - 1;
7519
- } else if (JsArray.isArray(padding)) {
7528
+ } else if (JsArray$1.isArray(padding)) {
7520
7529
  const pads = [k - 1 - padding[0], k - 1 - padding[1]];
7521
7530
  pad1 = pads[0];
7522
7531
  padLen = pads[0] + pads[1];
@@ -8035,19 +8044,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
8035
8044
  //#region src/library/random.ts
8036
8045
  var random_exports = {};
8037
8046
  __export(random_exports, {
8047
+ ball: () => ball,
8038
8048
  bernoulli: () => bernoulli,
8039
8049
  bits: () => bits,
8040
8050
  categorical: () => categorical,
8041
8051
  cauchy: () => cauchy,
8052
+ choice: () => choice,
8053
+ doubleSidedMaxwell: () => doubleSidedMaxwell,
8042
8054
  exponential: () => exponential,
8055
+ geometric: () => geometric,
8043
8056
  gumbel: () => gumbel,
8044
8057
  key: () => key,
8045
8058
  laplace: () => laplace,
8059
+ logistic: () => logistic,
8060
+ lognormal: () => lognormal,
8061
+ maxwell: () => maxwell,
8046
8062
  multivariateNormal: () => multivariateNormal,
8047
8063
  normal: () => normal,
8064
+ pareto: () => pareto,
8065
+ permutation: () => permutation,
8066
+ rademacher: () => rademacher,
8067
+ randint: () => randint,
8068
+ rayleigh: () => rayleigh,
8048
8069
  split: () => split,
8049
- uniform: () => uniform
8070
+ triangular: () => triangular,
8071
+ uniform: () => uniform,
8072
+ weibullMin: () => weibullMin
8050
8073
  });
8074
+ const JsArray = globalThis.Array;
8051
8075
  function validateKeyShape(key$1, scalar = false) {
8052
8076
  if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
8053
8077
  if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
@@ -8100,6 +8124,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
8100
8124
  else return rand.mul(maxval - minval).add(minval);
8101
8125
  }, { staticArgnums: [1, 2] });
8102
8126
  /**
8127
+ * @function
8128
+ * Sample points uniformly from the Euclidean unit ball in `d` dimensions.
8129
+ *
8130
+ * Only the Euclidean `p=2` case is currently supported.
8131
+ */
8132
+ const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
8133
+ if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
8134
+ if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
8135
+ const [k1, k2] = split(key$1, 2);
8136
+ const z = normal(k1, [...shape$1, d]);
8137
+ const norm = sqrt(z.ref.mul(z.ref).sum(-1, { keepdims: true }));
8138
+ const radius = exp(log(uniform(k2, [...shape$1, 1])).mul(1 / d));
8139
+ return z.div(norm).mul(radius);
8140
+ }, { staticArgnums: [1, 2] });
8141
+ /**
8103
8142
  * Sample Bernoulli random variables with given mean (0,1 categorical).
8104
8143
  *
8105
8144
  * Returns a random Boolean array with the specified shape. `p` can be an array
@@ -8161,6 +8200,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
8161
8200
  return tan(u.sub(.5).mul(Math.PI));
8162
8201
  }, { staticArgnums: [1] });
8163
8202
  /**
8203
+ * Sample from a population with optional replacement and optional probabilities.
8204
+ *
8205
+ * This implements the common JAX-compatible cases: integer populations and
8206
+ * array populations along `axis`. Probabilities `p`, if provided, are sampled
8207
+ * via `categorical(log(p))`.
8208
+ */
8209
+ function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
8210
+ let n;
8211
+ let values = null;
8212
+ if (typeof a === "number") {
8213
+ if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
8214
+ n = a;
8215
+ } else {
8216
+ values = fudgeArray(a);
8217
+ axis = checkAxis(axis, values.ndim);
8218
+ n = values.shape[axis];
8219
+ }
8220
+ let indices;
8221
+ if (p !== void 0) indices = categorical(key$1, log(p), {
8222
+ shape: shape$1,
8223
+ replace
8224
+ });
8225
+ else if (replace) indices = randint(key$1, {
8226
+ minval: 0,
8227
+ maxval: n,
8228
+ shape: shape$1
8229
+ });
8230
+ else {
8231
+ const k = shape$1.reduce((acc, x) => acc * x, 1);
8232
+ if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
8233
+ indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
8234
+ }
8235
+ if (values === null) return indices;
8236
+ const index = JsArray(axis).fill([]);
8237
+ index.push(indices);
8238
+ return values.slice(...index);
8239
+ }
8240
+ /**
8241
+ * @function
8242
+ * Sample double-sided Maxwell random values with the provided location and scale.
8243
+ */
8244
+ const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
8245
+ loc = fudgeArray(loc);
8246
+ scale = fudgeArray(scale);
8247
+ const [k1, k2] = split(key$1, 2);
8248
+ return rademacher(k1, {
8249
+ shape: shape$1,
8250
+ dtype: DType.Float32
8251
+ }).mul(maxwell(k2, shape$1)).mul(scale).add(loc);
8252
+ }, { staticArgnums: [3] });
8253
+ /**
8164
8254
  * @function
8165
8255
  * Sample exponential random values according to `p(x) = exp(-x)`.
8166
8256
  */
@@ -8170,6 +8260,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
8170
8260
  }, { staticArgnums: [1] });
8171
8261
  /**
8172
8262
  * @function
8263
+ * Sample geometric random values: the number of trials until first success.
8264
+ */
8265
+ const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
8266
+ p = fudgeArray(p);
8267
+ return floor(log1p(negative(uniform(key$1, shape$1))).div(log1p(negative(p)))).add(1).astype(dtype);
8268
+ }, { staticArgnums: [2] });
8269
+ /**
8270
+ * @function
8173
8271
  * Sample from a Gumbel distribution with location 0 and scale 1.
8174
8272
  *
8175
8273
  * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
@@ -8194,6 +8292,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
8194
8292
  }, { staticArgnums: [1] });
8195
8293
  /**
8196
8294
  * @function
8295
+ * Sample from a logistic distribution with location 0 and scale 1.
8296
+ *
8297
+ * Uses inverse transform sampling: `x = log(u) - log(1-u)`.
8298
+ */
8299
+ const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
8300
+ const u = uniform(key$1, shape$1);
8301
+ return log(u.ref).sub(log1p(negative(u)));
8302
+ }, { staticArgnums: [1] });
8303
+ /**
8304
+ * @function
8305
+ * Sample log-normal random values: `exp(sigma * normal(key, shape))`.
8306
+ */
8307
+ const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
8308
+ sigma = fudgeArray(sigma);
8309
+ return exp(normal(key$1, shape$1).mul(sigma));
8310
+ }, { staticArgnums: [2] });
8311
+ /**
8312
+ * @function
8313
+ * Sample Maxwell-distributed random values.
8314
+ */
8315
+ const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
8316
+ const z = normal(key$1, [...shape$1, 3]);
8317
+ return sqrt(z.ref.mul(z).sum(-1));
8318
+ }, { staticArgnums: [1] });
8319
+ /**
8320
+ * @function
8197
8321
  * Sample multivariate normal random values with given mean and covariance.
8198
8322
  *
8199
8323
  * The values are returned with the given shape, along with the final dimension
@@ -8234,6 +8358,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
8234
8358
  const theta = u2.mul(2 * Math.PI);
8235
8359
  return radius.mul(cos(theta));
8236
8360
  }, { staticArgnums: [1] });
8361
+ /**
8362
+ * @function
8363
+ * Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
8364
+ */
8365
+ const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
8366
+ b = fudgeArray(b);
8367
+ return exp(exponential(key$1, shape$1).div(b));
8368
+ }, { staticArgnums: [2] });
8369
+ /**
8370
+ * Return a random permutation of an integer range or of an array along `axis`.
8371
+ */
8372
+ function permutation(key$1, x, axis = 0) {
8373
+ if (typeof x === "number") {
8374
+ if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
8375
+ return argsort(uniform(key$1, [x])).astype(DType.Int32);
8376
+ }
8377
+ const arr = fudgeArray(x);
8378
+ axis = checkAxis(axis, arr.ndim);
8379
+ const perm = permutation(key$1, arr.shape[axis]);
8380
+ const index = JsArray(axis).fill([]);
8381
+ index.push(perm);
8382
+ return arr.slice(...index);
8383
+ }
8384
+ /**
8385
+ * @function
8386
+ * Sample Rademacher random values, uniformly from {-1, 1}.
8387
+ */
8388
+ const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
8389
+ if (dtype === DType.Uint32 || dtype === DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
8390
+ const one = array(1, {
8391
+ dtype,
8392
+ device: key$1.device
8393
+ });
8394
+ const minusOne = array(-1, {
8395
+ dtype,
8396
+ device: key$1.device
8397
+ });
8398
+ return where(bernoulli(key$1, .5, shape$1), one, minusOne);
8399
+ }, { staticArgnums: [1] });
8400
+ /**
8401
+ * @function
8402
+ * Sample integer values uniformly from `[minval, maxval)`.
8403
+ *
8404
+ * This uses modulo reduction of uniform 32-bit random bits. For ranges that do
8405
+ * not divide 2^32, this introduces a very small modulo bias.
8406
+ */
8407
+ const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = DType.Int32 }) {
8408
+ if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
8409
+ if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
8410
+ if (dtype !== DType.Int32 && dtype !== DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
8411
+ if (dtype === DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
8412
+ const range$1 = maxval - minval;
8413
+ return bits(key$1, shape$1).mod(range$1).astype(dtype).add(minval);
8414
+ }, { staticArgnums: [1] });
8415
+ /**
8416
+ * @function
8417
+ * Sample Rayleigh random values with the provided scale parameter.
8418
+ */
8419
+ const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
8420
+ scale = fudgeArray(scale);
8421
+ return sqrt(exponential(key$1, shape$1).mul(2)).mul(scale);
8422
+ }, { staticArgnums: [2] });
8423
+ /**
8424
+ * @function
8425
+ * Sample triangular random values on `[left, right]` with the given mode.
8426
+ */
8427
+ const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
8428
+ left = fudgeArray(left);
8429
+ mode = fudgeArray(mode);
8430
+ right = fudgeArray(right);
8431
+ const u = uniform(key$1, shape$1);
8432
+ const width = right.ref.sub(left.ref);
8433
+ const leftSpan = mode.ref.sub(left.ref);
8434
+ const rightSpan = right.ref.sub(mode);
8435
+ const cutoff = leftSpan.ref.div(width.ref);
8436
+ const cond = u.ref.less(cutoff);
8437
+ const lower = left.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
8438
+ const upper = right.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
8439
+ return where(cond, lower, upper);
8440
+ }, { staticArgnums: [4] });
8441
+ /**
8442
+ * @function
8443
+ * Sample Weibull minimum random values.
8444
+ *
8445
+ * Uses `scale * exponential(key) ** (1 / concentration)`.
8446
+ */
8447
+ const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
8448
+ scale = fudgeArray(scale);
8449
+ concentration = fudgeArray(concentration);
8450
+ return scale.mul(exp(log(exponential(key$1, shape$1)).div(concentration)));
8451
+ }, { staticArgnums: [3] });
8237
8452
 
8238
8453
  //#endregion
8239
8454
  //#region src/library/scipy-special.ts
@@ -1,4 +1,4 @@
1
- import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DZvR7mZV.js";
1
+ import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DLEk-B3V.js";
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `
@@ -1,4 +1,4 @@
1
- const require_backend = require('./backend-DlYlOYqN.cjs');
1
+ const require_backend = require('./backend-DMyuoWi2.cjs');
2
2
 
3
3
  //#region src/backend/webgl/builtins.ts
4
4
  const threefrySrc = `