@jax-js/jax 0.1.11 → 0.1.12
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +3 -1
- package/dist/{backend-DZvR7mZV.js → backend-DI-V78Rk.js} +4 -2
- package/dist/{backend-DlYlOYqN.cjs → backend-x-6vqzIM.cjs} +4 -2
- package/dist/index.cjs +233 -18
- package/dist/index.d.cts +106 -1
- package/dist/index.d.ts +106 -1
- package/dist/index.js +233 -18
- package/dist/{webgl-D8-14NzA.js → webgl-BhsnpeB0.js} +1 -1
- package/dist/{webgl-Ovaaa-Qx.cjs → webgl-CD3WK_Me.cjs} +1 -1
- package/dist/{webgpu-Dg8FpYrH.js → webgpu-C2kLdkUh.js} +299 -154
- package/dist/{webgpu-uU9nnttc.cjs → webgpu-C4S8Uq9e.cjs} +299 -154
- package/package.json +1 -1
package/dist/index.d.ts
CHANGED
|
@@ -2719,7 +2719,7 @@ declare function dotProductAttention(query: ArrayLike, key: ArrayLike, value: Ar
|
|
|
2719
2719
|
localWindowSize?: number | [number, number];
|
|
2720
2720
|
}): Array;
|
|
2721
2721
|
declare namespace random_d_exports {
|
|
2722
|
-
export { bernoulli, bits, categorical, cauchy, exponential, gumbel, key, laplace, multivariateNormal, normal, split, uniform };
|
|
2722
|
+
export { ball, bernoulli, bits, categorical, cauchy, choice, doubleSidedMaxwell, exponential, geometric, gumbel, key, laplace, logistic, lognormal, maxwell, multivariateNormal, normal, pareto, permutation, rademacher, randint, rayleigh, split, triangular, uniform, weibullMin };
|
|
2723
2723
|
}
|
|
2724
2724
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
2725
2725
|
declare function key(seed: ArrayLike): Array;
|
|
@@ -2735,6 +2735,16 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
|
|
|
2735
2735
|
minval?: number | undefined;
|
|
2736
2736
|
maxval?: number | undefined;
|
|
2737
2737
|
} | undefined) => Array>;
|
|
2738
|
+
/**
|
|
2739
|
+
* @function
|
|
2740
|
+
* Sample points uniformly from the Euclidean unit ball in `d` dimensions.
|
|
2741
|
+
*
|
|
2742
|
+
* Only the Euclidean `p=2` case is currently supported.
|
|
2743
|
+
*/
|
|
2744
|
+
declare const ball: OwnedFunction<(key: ArrayLike, d: number, args_2?: {
|
|
2745
|
+
p?: number | undefined;
|
|
2746
|
+
shape?: number[] | undefined;
|
|
2747
|
+
} | undefined) => Array>;
|
|
2738
2748
|
/**
|
|
2739
2749
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
2740
2750
|
*
|
|
@@ -2775,11 +2785,42 @@ declare const categorical: OwnedFunction<(key: ArrayLike, logits: ArrayLike, arg
|
|
|
2775
2785
|
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
2776
2786
|
*/
|
|
2777
2787
|
declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2788
|
+
/**
|
|
2789
|
+
* Sample from a population with optional replacement and optional probabilities.
|
|
2790
|
+
*
|
|
2791
|
+
* This implements the common JAX-compatible cases: integer populations and
|
|
2792
|
+
* array populations along `axis`. Probabilities `p`, if provided, are sampled
|
|
2793
|
+
* via `categorical(log(p))`.
|
|
2794
|
+
*/
|
|
2795
|
+
declare function choice(key: Array, a: number | ArrayLike, {
|
|
2796
|
+
shape,
|
|
2797
|
+
replace,
|
|
2798
|
+
p,
|
|
2799
|
+
axis
|
|
2800
|
+
}?: {
|
|
2801
|
+
shape?: number[];
|
|
2802
|
+
replace?: boolean;
|
|
2803
|
+
p?: ArrayLike;
|
|
2804
|
+
axis?: number;
|
|
2805
|
+
}): Array;
|
|
2806
|
+
/**
|
|
2807
|
+
* @function
|
|
2808
|
+
* Sample double-sided Maxwell random values with the provided location and scale.
|
|
2809
|
+
*/
|
|
2810
|
+
declare const doubleSidedMaxwell: OwnedFunction<(key: ArrayLike, loc: ArrayLike, scale: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2778
2811
|
/**
|
|
2779
2812
|
* @function
|
|
2780
2813
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
2781
2814
|
*/
|
|
2782
2815
|
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2816
|
+
/**
|
|
2817
|
+
* @function
|
|
2818
|
+
* Sample geometric random values: the number of trials until first success.
|
|
2819
|
+
*/
|
|
2820
|
+
declare const geometric: OwnedFunction<(key: ArrayLike, p: ArrayLike, args_2?: {
|
|
2821
|
+
shape?: number[] | undefined;
|
|
2822
|
+
dtype?: DType | undefined;
|
|
2823
|
+
} | undefined) => Array>;
|
|
2783
2824
|
/**
|
|
2784
2825
|
* @function
|
|
2785
2826
|
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
@@ -2795,6 +2836,23 @@ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefine
|
|
|
2795
2836
|
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
2796
2837
|
*/
|
|
2797
2838
|
declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2839
|
+
/**
|
|
2840
|
+
* @function
|
|
2841
|
+
* Sample from a logistic distribution with location 0 and scale 1.
|
|
2842
|
+
*
|
|
2843
|
+
* Uses inverse transform sampling: `x = log(u) - log(1-u)`.
|
|
2844
|
+
*/
|
|
2845
|
+
declare const logistic: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2846
|
+
/**
|
|
2847
|
+
* @function
|
|
2848
|
+
* Sample log-normal random values: `exp(sigma * normal(key, shape))`.
|
|
2849
|
+
*/
|
|
2850
|
+
declare const lognormal: OwnedFunction<(key: ArrayLike, sigma?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
|
|
2851
|
+
/**
|
|
2852
|
+
* @function
|
|
2853
|
+
* Sample Maxwell-distributed random values.
|
|
2854
|
+
*/
|
|
2855
|
+
declare const maxwell: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2798
2856
|
/**
|
|
2799
2857
|
* @function
|
|
2800
2858
|
* Sample multivariate normal random values with given mean and covariance.
|
|
@@ -2821,6 +2879,53 @@ declare const multivariateNormal: OwnedFunction<(key: ArrayLike, mean: ArrayLike
|
|
|
2821
2879
|
* bitwise identical to JAX.
|
|
2822
2880
|
*/
|
|
2823
2881
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2882
|
+
/**
|
|
2883
|
+
* @function
|
|
2884
|
+
* Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
|
|
2885
|
+
*/
|
|
2886
|
+
declare const pareto: OwnedFunction<(key: ArrayLike, b: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2887
|
+
/**
|
|
2888
|
+
* Return a random permutation of an integer range or of an array along `axis`.
|
|
2889
|
+
*/
|
|
2890
|
+
declare function permutation(key: Array, x: number | ArrayLike, axis?: number): Array;
|
|
2891
|
+
/**
|
|
2892
|
+
* @function
|
|
2893
|
+
* Sample Rademacher random values, uniformly from {-1, 1}.
|
|
2894
|
+
*/
|
|
2895
|
+
declare const rademacher: OwnedFunction<(key: ArrayLike, args_1?: {
|
|
2896
|
+
shape?: number[] | undefined;
|
|
2897
|
+
dtype?: DType | undefined;
|
|
2898
|
+
} | undefined) => Array>;
|
|
2899
|
+
/**
|
|
2900
|
+
* @function
|
|
2901
|
+
* Sample integer values uniformly from `[minval, maxval)`.
|
|
2902
|
+
*
|
|
2903
|
+
* This uses modulo reduction of uniform 32-bit random bits. For ranges that do
|
|
2904
|
+
* not divide 2^32, this introduces a very small modulo bias.
|
|
2905
|
+
*/
|
|
2906
|
+
declare const randint: OwnedFunction<(key: ArrayLike, args_1: {
|
|
2907
|
+
minval: number;
|
|
2908
|
+
maxval: number;
|
|
2909
|
+
shape?: number[] | undefined;
|
|
2910
|
+
dtype?: DType | undefined;
|
|
2911
|
+
}) => Array>;
|
|
2912
|
+
/**
|
|
2913
|
+
* @function
|
|
2914
|
+
* Sample Rayleigh random values with the provided scale parameter.
|
|
2915
|
+
*/
|
|
2916
|
+
declare const rayleigh: OwnedFunction<(key: ArrayLike, scale?: ArrayLike | undefined, shape?: number[] | undefined) => Array>;
|
|
2917
|
+
/**
|
|
2918
|
+
* @function
|
|
2919
|
+
* Sample triangular random values on `[left, right]` with the given mode.
|
|
2920
|
+
*/
|
|
2921
|
+
declare const triangular: OwnedFunction<(key: ArrayLike, left: ArrayLike, mode: ArrayLike, right: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2922
|
+
/**
|
|
2923
|
+
* @function
|
|
2924
|
+
* Sample Weibull minimum random values.
|
|
2925
|
+
*
|
|
2926
|
+
* Uses `scale * exponential(key) ** (1 / concentration)`.
|
|
2927
|
+
*/
|
|
2928
|
+
declare const weibullMin: OwnedFunction<(key: ArrayLike, scale: ArrayLike, concentration: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
2824
2929
|
declare namespace scipy_special_d_exports {
|
|
2825
2930
|
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
2826
2931
|
}
|
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DZvR7mZV.js";
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, checkInts, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, getWebGPUDevice, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, startTrace, stopTrace, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DI-V78Rk.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -209,7 +209,7 @@ __export(tree_exports, {
|
|
|
209
209
|
structure: () => structure,
|
|
210
210
|
unflatten: () => unflatten
|
|
211
211
|
});
|
|
212
|
-
const JsArray$2 = globalThis.Array;
|
|
212
|
+
const JsArray$3 = globalThis.Array;
|
|
213
213
|
let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
214
214
|
NodeType$1["Array"] = "Array";
|
|
215
215
|
NodeType$1["Object"] = "Object";
|
|
@@ -257,7 +257,7 @@ function flatten(tree) {
|
|
|
257
257
|
return [leaves$1, treedef];
|
|
258
258
|
}
|
|
259
259
|
function _flatten(tree, leaves$1) {
|
|
260
|
-
if (JsArray$2.isArray(tree)) {
|
|
260
|
+
if (JsArray$3.isArray(tree)) {
|
|
261
261
|
const childTrees = tree.map((c) => _flatten(c, leaves$1));
|
|
262
262
|
return new JsTreeDef(NodeType.Array, null, childTrees);
|
|
263
263
|
} else if (typeof tree === "object" && tree !== null && tree.constructor === Object) {
|
|
@@ -2460,7 +2460,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2460
2460
|
|
|
2461
2461
|
//#endregion
|
|
2462
2462
|
//#region src/frontend/array.ts
|
|
2463
|
-
const JsArray$1 = globalThis.Array;
|
|
2463
|
+
const JsArray$2 = globalThis.Array;
|
|
2464
2464
|
const inlineArrayLimit = 128;
|
|
2465
2465
|
/** Version of pureArray with fudged types. */
|
|
2466
2466
|
const fudgeArray = pureArray;
|
|
@@ -2900,6 +2900,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2900
2900
|
this.#check();
|
|
2901
2901
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
2902
2902
|
if (this.#source instanceof AluExp) {
|
|
2903
|
+
let resolvedSource;
|
|
2904
|
+
if (this.#st.contiguous && this.#st.size < inlineArrayLimit && (resolvedSource = this.#source.resolve()) !== void 0) {
|
|
2905
|
+
const byteLength = this.#st.size * byteWidth(this.#dtype);
|
|
2906
|
+
const initialData = new Uint8Array(byteLength);
|
|
2907
|
+
dtypedArray(this.#dtype, initialData).fill(resolvedSource);
|
|
2908
|
+
this.#source = this.#backend.malloc(byteLength, initialData);
|
|
2909
|
+
this.#st = ShapeTracker.fromShape(this.shape);
|
|
2910
|
+
return;
|
|
2911
|
+
}
|
|
2903
2912
|
const exp$2 = accessorAluExp(this.#source, this.#st, indices);
|
|
2904
2913
|
const kernel = new Kernel(0, this.#st.size, exp$2);
|
|
2905
2914
|
const output = this.#backend.malloc(kernel.bytes);
|
|
@@ -3350,7 +3359,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
3350
3359
|
if (!shape$1) {
|
|
3351
3360
|
shape$1 = [];
|
|
3352
3361
|
let cur = values;
|
|
3353
|
-
while (JsArray$1.isArray(cur)) {
|
|
3362
|
+
while (JsArray$2.isArray(cur)) {
|
|
3354
3363
|
shape$1.push(cur.length);
|
|
3355
3364
|
cur = cur[0];
|
|
3356
3365
|
}
|
|
@@ -4232,7 +4241,7 @@ const jvpRules = {
|
|
|
4232
4241
|
return [[L], [dL]];
|
|
4233
4242
|
},
|
|
4234
4243
|
[Primitive.LU]([a], [da]) {
|
|
4235
|
-
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4244
|
+
const [luMatrix, pivots, permutation$1] = lu$1(a);
|
|
4236
4245
|
const [m, n] = a.shape.slice(-2);
|
|
4237
4246
|
const k = Math.min(m, n);
|
|
4238
4247
|
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
@@ -4244,7 +4253,7 @@ const jvpRules = {
|
|
|
4244
4253
|
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4245
4254
|
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4246
4255
|
const U = uPadded.add(uEye);
|
|
4247
|
-
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4256
|
+
const P = permutation$1.ref.reshape([...permutation$1.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4248
4257
|
const pda = batchMatmulT(P, mT(da));
|
|
4249
4258
|
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4250
4259
|
lower: true,
|
|
@@ -4256,11 +4265,11 @@ const jvpRules = {
|
|
|
4256
4265
|
return [[
|
|
4257
4266
|
luMatrix,
|
|
4258
4267
|
pivots,
|
|
4259
|
-
permutation
|
|
4268
|
+
permutation$1
|
|
4260
4269
|
], [
|
|
4261
4270
|
lDot.add(uDot),
|
|
4262
4271
|
zerosLike$1(pivots.ref),
|
|
4263
|
-
zerosLike$1(permutation.ref)
|
|
4272
|
+
zerosLike$1(permutation$1.ref)
|
|
4264
4273
|
]];
|
|
4265
4274
|
},
|
|
4266
4275
|
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
@@ -5342,8 +5351,8 @@ function cross$1(x1, x2, axis = -1) {
|
|
|
5342
5351
|
function det(a) {
|
|
5343
5352
|
a = fudgeArray(a);
|
|
5344
5353
|
const n = checkSquare("det", a);
|
|
5345
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5346
|
-
permutation.dispose();
|
|
5354
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5355
|
+
permutation$1.dispose();
|
|
5347
5356
|
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5348
5357
|
const sign$1 = parity.mul(-2).add(1);
|
|
5349
5358
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
@@ -5432,8 +5441,8 @@ function matrixPower(a, n) {
|
|
|
5432
5441
|
function slogdet(a) {
|
|
5433
5442
|
a = fudgeArray(a);
|
|
5434
5443
|
const n = checkSquare("slogdet", a);
|
|
5435
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5436
|
-
permutation.dispose();
|
|
5444
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5445
|
+
permutation$1.dispose();
|
|
5437
5446
|
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5438
5447
|
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5439
5448
|
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
@@ -5471,9 +5480,9 @@ function solve(a, b) {
|
|
|
5471
5480
|
n,
|
|
5472
5481
|
m
|
|
5473
5482
|
]);
|
|
5474
|
-
const [lu$2, pivots, permutation] = lu(a);
|
|
5483
|
+
const [lu$2, pivots, permutation$1] = lu(a);
|
|
5475
5484
|
pivots.dispose();
|
|
5476
|
-
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5485
|
+
const P = arange(n).equal(permutation$1.reshape([...permutation$1.shape, 1])).astype(b.dtype);
|
|
5477
5486
|
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5478
5487
|
leftSide: true,
|
|
5479
5488
|
lower: true,
|
|
@@ -7329,7 +7338,7 @@ __export(lax_exports, {
|
|
|
7329
7338
|
stopGradient: () => stopGradient$1,
|
|
7330
7339
|
topK: () => topK
|
|
7331
7340
|
});
|
|
7332
|
-
const JsArray = globalThis.Array;
|
|
7341
|
+
const JsArray$1 = globalThis.Array;
|
|
7333
7342
|
/** Elementwise bitcast an array into a new dtype. */
|
|
7334
7343
|
function bitcastConvertType(x, newDtype) {
|
|
7335
7344
|
return fudgeArray(x).view(newDtype);
|
|
@@ -7516,7 +7525,7 @@ function convTransposePadding(k, s, padding) {
|
|
|
7516
7525
|
} else if (padding === "VALID") {
|
|
7517
7526
|
padLen = k + s - 2 + Math.max(k - s, 0);
|
|
7518
7527
|
pad1 = k - 1;
|
|
7519
|
-
} else if (JsArray.isArray(padding)) {
|
|
7528
|
+
} else if (JsArray$1.isArray(padding)) {
|
|
7520
7529
|
const pads = [k - 1 - padding[0], k - 1 - padding[1]];
|
|
7521
7530
|
pad1 = pads[0];
|
|
7522
7531
|
padLen = pads[0] + pads[1];
|
|
@@ -8035,19 +8044,34 @@ function dotProductAttention(query, key$1, value, opts = {}) {
|
|
|
8035
8044
|
//#region src/library/random.ts
|
|
8036
8045
|
var random_exports = {};
|
|
8037
8046
|
__export(random_exports, {
|
|
8047
|
+
ball: () => ball,
|
|
8038
8048
|
bernoulli: () => bernoulli,
|
|
8039
8049
|
bits: () => bits,
|
|
8040
8050
|
categorical: () => categorical,
|
|
8041
8051
|
cauchy: () => cauchy,
|
|
8052
|
+
choice: () => choice,
|
|
8053
|
+
doubleSidedMaxwell: () => doubleSidedMaxwell,
|
|
8042
8054
|
exponential: () => exponential,
|
|
8055
|
+
geometric: () => geometric,
|
|
8043
8056
|
gumbel: () => gumbel,
|
|
8044
8057
|
key: () => key,
|
|
8045
8058
|
laplace: () => laplace,
|
|
8059
|
+
logistic: () => logistic,
|
|
8060
|
+
lognormal: () => lognormal,
|
|
8061
|
+
maxwell: () => maxwell,
|
|
8046
8062
|
multivariateNormal: () => multivariateNormal,
|
|
8047
8063
|
normal: () => normal,
|
|
8064
|
+
pareto: () => pareto,
|
|
8065
|
+
permutation: () => permutation,
|
|
8066
|
+
rademacher: () => rademacher,
|
|
8067
|
+
randint: () => randint,
|
|
8068
|
+
rayleigh: () => rayleigh,
|
|
8048
8069
|
split: () => split,
|
|
8049
|
-
uniform: () => uniform
|
|
8070
|
+
triangular: () => triangular,
|
|
8071
|
+
uniform: () => uniform,
|
|
8072
|
+
weibullMin: () => weibullMin
|
|
8050
8073
|
});
|
|
8074
|
+
const JsArray = globalThis.Array;
|
|
8051
8075
|
function validateKeyShape(key$1, scalar = false) {
|
|
8052
8076
|
if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
|
|
8053
8077
|
if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
|
|
@@ -8100,6 +8124,21 @@ const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxv
|
|
|
8100
8124
|
else return rand.mul(maxval - minval).add(minval);
|
|
8101
8125
|
}, { staticArgnums: [1, 2] });
|
|
8102
8126
|
/**
|
|
8127
|
+
* @function
|
|
8128
|
+
* Sample points uniformly from the Euclidean unit ball in `d` dimensions.
|
|
8129
|
+
*
|
|
8130
|
+
* Only the Euclidean `p=2` case is currently supported.
|
|
8131
|
+
*/
|
|
8132
|
+
const ball = jit$1(function ball$1(key$1, d, { p = 2, shape: shape$1 = [] } = {}) {
|
|
8133
|
+
if (!Number.isInteger(d) || d <= 0) throw new Error(`ball: dimension must be a positive integer, got ${d}`);
|
|
8134
|
+
if (p !== 2) throw new Error("ball: only the Euclidean p=2 case is supported");
|
|
8135
|
+
const [k1, k2] = split(key$1, 2);
|
|
8136
|
+
const z = normal(k1, [...shape$1, d]);
|
|
8137
|
+
const norm = sqrt(z.ref.mul(z.ref).sum(-1, { keepdims: true }));
|
|
8138
|
+
const radius = exp(log(uniform(k2, [...shape$1, 1])).mul(1 / d));
|
|
8139
|
+
return z.div(norm).mul(radius);
|
|
8140
|
+
}, { staticArgnums: [1, 2] });
|
|
8141
|
+
/**
|
|
8103
8142
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
8104
8143
|
*
|
|
8105
8144
|
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
@@ -8161,6 +8200,57 @@ const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
|
8161
8200
|
return tan(u.sub(.5).mul(Math.PI));
|
|
8162
8201
|
}, { staticArgnums: [1] });
|
|
8163
8202
|
/**
|
|
8203
|
+
* Sample from a population with optional replacement and optional probabilities.
|
|
8204
|
+
*
|
|
8205
|
+
* This implements the common JAX-compatible cases: integer populations and
|
|
8206
|
+
* array populations along `axis`. Probabilities `p`, if provided, are sampled
|
|
8207
|
+
* via `categorical(log(p))`.
|
|
8208
|
+
*/
|
|
8209
|
+
function choice(key$1, a, { shape: shape$1 = [], replace = true, p, axis = 0 } = {}) {
|
|
8210
|
+
let n;
|
|
8211
|
+
let values = null;
|
|
8212
|
+
if (typeof a === "number") {
|
|
8213
|
+
if (!Number.isInteger(a) || a < 0) throw new Error(`choice: a must be a non-negative integer, got ${a}`);
|
|
8214
|
+
n = a;
|
|
8215
|
+
} else {
|
|
8216
|
+
values = fudgeArray(a);
|
|
8217
|
+
axis = checkAxis(axis, values.ndim);
|
|
8218
|
+
n = values.shape[axis];
|
|
8219
|
+
}
|
|
8220
|
+
let indices;
|
|
8221
|
+
if (p !== void 0) indices = categorical(key$1, log(p), {
|
|
8222
|
+
shape: shape$1,
|
|
8223
|
+
replace
|
|
8224
|
+
});
|
|
8225
|
+
else if (replace) indices = randint(key$1, {
|
|
8226
|
+
minval: 0,
|
|
8227
|
+
maxval: n,
|
|
8228
|
+
shape: shape$1
|
|
8229
|
+
});
|
|
8230
|
+
else {
|
|
8231
|
+
const k = shape$1.reduce((acc, x) => acc * x, 1);
|
|
8232
|
+
if (k > n) throw new Error(`Number of samples without replacement (${k}) cannot exceed population size (${n}).`);
|
|
8233
|
+
indices = permutation(key$1, n).slice([0, k]).reshape(shape$1);
|
|
8234
|
+
}
|
|
8235
|
+
if (values === null) return indices;
|
|
8236
|
+
const index = JsArray(axis).fill([]);
|
|
8237
|
+
index.push(indices);
|
|
8238
|
+
return values.slice(...index);
|
|
8239
|
+
}
|
|
8240
|
+
/**
|
|
8241
|
+
* @function
|
|
8242
|
+
* Sample double-sided Maxwell random values with the provided location and scale.
|
|
8243
|
+
*/
|
|
8244
|
+
const doubleSidedMaxwell = jit$1(function doubleSidedMaxwell$1(key$1, loc, scale, shape$1 = []) {
|
|
8245
|
+
loc = fudgeArray(loc);
|
|
8246
|
+
scale = fudgeArray(scale);
|
|
8247
|
+
const [k1, k2] = split(key$1, 2);
|
|
8248
|
+
return rademacher(k1, {
|
|
8249
|
+
shape: shape$1,
|
|
8250
|
+
dtype: DType.Float32
|
|
8251
|
+
}).mul(maxwell(k2, shape$1)).mul(scale).add(loc);
|
|
8252
|
+
}, { staticArgnums: [3] });
|
|
8253
|
+
/**
|
|
8164
8254
|
* @function
|
|
8165
8255
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
8166
8256
|
*/
|
|
@@ -8170,6 +8260,14 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
8170
8260
|
}, { staticArgnums: [1] });
|
|
8171
8261
|
/**
|
|
8172
8262
|
* @function
|
|
8263
|
+
* Sample geometric random values: the number of trials until first success.
|
|
8264
|
+
*/
|
|
8265
|
+
const geometric = jit$1(function geometric$1(key$1, p, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
|
|
8266
|
+
p = fudgeArray(p);
|
|
8267
|
+
return floor(log1p(negative(uniform(key$1, shape$1))).div(log1p(negative(p)))).add(1).astype(dtype);
|
|
8268
|
+
}, { staticArgnums: [2] });
|
|
8269
|
+
/**
|
|
8270
|
+
* @function
|
|
8173
8271
|
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
8174
8272
|
*
|
|
8175
8273
|
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
@@ -8194,6 +8292,32 @@ const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
|
8194
8292
|
}, { staticArgnums: [1] });
|
|
8195
8293
|
/**
|
|
8196
8294
|
* @function
|
|
8295
|
+
* Sample from a logistic distribution with location 0 and scale 1.
|
|
8296
|
+
*
|
|
8297
|
+
* Uses inverse transform sampling: `x = log(u) - log(1-u)`.
|
|
8298
|
+
*/
|
|
8299
|
+
const logistic = jit$1(function logistic$1(key$1, shape$1 = []) {
|
|
8300
|
+
const u = uniform(key$1, shape$1);
|
|
8301
|
+
return log(u.ref).sub(log1p(negative(u)));
|
|
8302
|
+
}, { staticArgnums: [1] });
|
|
8303
|
+
/**
|
|
8304
|
+
* @function
|
|
8305
|
+
* Sample log-normal random values: `exp(sigma * normal(key, shape))`.
|
|
8306
|
+
*/
|
|
8307
|
+
const lognormal = jit$1(function lognormal$1(key$1, sigma = 1, shape$1 = []) {
|
|
8308
|
+
sigma = fudgeArray(sigma);
|
|
8309
|
+
return exp(normal(key$1, shape$1).mul(sigma));
|
|
8310
|
+
}, { staticArgnums: [2] });
|
|
8311
|
+
/**
|
|
8312
|
+
* @function
|
|
8313
|
+
* Sample Maxwell-distributed random values.
|
|
8314
|
+
*/
|
|
8315
|
+
const maxwell = jit$1(function maxwell$1(key$1, shape$1 = []) {
|
|
8316
|
+
const z = normal(key$1, [...shape$1, 3]);
|
|
8317
|
+
return sqrt(z.ref.mul(z).sum(-1));
|
|
8318
|
+
}, { staticArgnums: [1] });
|
|
8319
|
+
/**
|
|
8320
|
+
* @function
|
|
8197
8321
|
* Sample multivariate normal random values with given mean and covariance.
|
|
8198
8322
|
*
|
|
8199
8323
|
* The values are returned with the given shape, along with the final dimension
|
|
@@ -8234,6 +8358,97 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
8234
8358
|
const theta = u2.mul(2 * Math.PI);
|
|
8235
8359
|
return radius.mul(cos(theta));
|
|
8236
8360
|
}, { staticArgnums: [1] });
|
|
8361
|
+
/**
|
|
8362
|
+
* @function
|
|
8363
|
+
* Sample from a Pareto distribution with shape parameter `b` and support [1, ∞).
|
|
8364
|
+
*/
|
|
8365
|
+
const pareto = jit$1(function pareto$1(key$1, b, shape$1 = []) {
|
|
8366
|
+
b = fudgeArray(b);
|
|
8367
|
+
return exp(exponential(key$1, shape$1).div(b));
|
|
8368
|
+
}, { staticArgnums: [2] });
|
|
8369
|
+
/**
|
|
8370
|
+
* Return a random permutation of an integer range or of an array along `axis`.
|
|
8371
|
+
*/
|
|
8372
|
+
function permutation(key$1, x, axis = 0) {
|
|
8373
|
+
if (typeof x === "number") {
|
|
8374
|
+
if (!Number.isInteger(x) || x < 0) throw new Error(`permutation: x must be a non-negative integer, got ${x}`);
|
|
8375
|
+
return argsort(uniform(key$1, [x])).astype(DType.Int32);
|
|
8376
|
+
}
|
|
8377
|
+
const arr = fudgeArray(x);
|
|
8378
|
+
axis = checkAxis(axis, arr.ndim);
|
|
8379
|
+
const perm = permutation(key$1, arr.shape[axis]);
|
|
8380
|
+
const index = JsArray(axis).fill([]);
|
|
8381
|
+
index.push(perm);
|
|
8382
|
+
return arr.slice(...index);
|
|
8383
|
+
}
|
|
8384
|
+
/**
|
|
8385
|
+
* @function
|
|
8386
|
+
* Sample Rademacher random values, uniformly from {-1, 1}.
|
|
8387
|
+
*/
|
|
8388
|
+
const rademacher = jit$1(function rademacher$1(key$1, { shape: shape$1 = [], dtype = DType.Int32 } = {}) {
|
|
8389
|
+
if (dtype === DType.Uint32 || dtype === DType.Bool) throw new Error(`rademacher: unsupported dtype ${dtype}`);
|
|
8390
|
+
const one = array(1, {
|
|
8391
|
+
dtype,
|
|
8392
|
+
device: key$1.device
|
|
8393
|
+
});
|
|
8394
|
+
const minusOne = array(-1, {
|
|
8395
|
+
dtype,
|
|
8396
|
+
device: key$1.device
|
|
8397
|
+
});
|
|
8398
|
+
return where(bernoulli(key$1, .5, shape$1), one, minusOne);
|
|
8399
|
+
}, { staticArgnums: [1] });
|
|
8400
|
+
/**
|
|
8401
|
+
* @function
|
|
8402
|
+
* Sample integer values uniformly from `[minval, maxval)`.
|
|
8403
|
+
*
|
|
8404
|
+
* This uses modulo reduction of uniform 32-bit random bits. For ranges that do
|
|
8405
|
+
* not divide 2^32, this introduces a very small modulo bias.
|
|
8406
|
+
*/
|
|
8407
|
+
const randint = jit$1(function randint$1(key$1, { minval, maxval, shape: shape$1 = [], dtype = DType.Int32 }) {
|
|
8408
|
+
if (!Number.isInteger(minval) || !Number.isInteger(maxval)) throw new Error("randint: minval and maxval must be integers");
|
|
8409
|
+
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
8410
|
+
if (dtype !== DType.Int32 && dtype !== DType.Uint32) throw new Error(`randint: dtype must be int32 or uint32, got ${dtype}`);
|
|
8411
|
+
if (dtype === DType.Uint32 && minval < 0) throw new Error("randint: uint32 dtype requires minval >= 0");
|
|
8412
|
+
const range$1 = maxval - minval;
|
|
8413
|
+
return bits(key$1, shape$1).mod(range$1).astype(dtype).add(minval);
|
|
8414
|
+
}, { staticArgnums: [1] });
|
|
8415
|
+
/**
|
|
8416
|
+
* @function
|
|
8417
|
+
* Sample Rayleigh random values with the provided scale parameter.
|
|
8418
|
+
*/
|
|
8419
|
+
const rayleigh = jit$1(function rayleigh$1(key$1, scale = 1, shape$1 = []) {
|
|
8420
|
+
scale = fudgeArray(scale);
|
|
8421
|
+
return sqrt(exponential(key$1, shape$1).mul(2)).mul(scale);
|
|
8422
|
+
}, { staticArgnums: [2] });
|
|
8423
|
+
/**
|
|
8424
|
+
* @function
|
|
8425
|
+
* Sample triangular random values on `[left, right]` with the given mode.
|
|
8426
|
+
*/
|
|
8427
|
+
const triangular = jit$1(function triangular$1(key$1, left, mode, right, shape$1 = []) {
|
|
8428
|
+
left = fudgeArray(left);
|
|
8429
|
+
mode = fudgeArray(mode);
|
|
8430
|
+
right = fudgeArray(right);
|
|
8431
|
+
const u = uniform(key$1, shape$1);
|
|
8432
|
+
const width = right.ref.sub(left.ref);
|
|
8433
|
+
const leftSpan = mode.ref.sub(left.ref);
|
|
8434
|
+
const rightSpan = right.ref.sub(mode);
|
|
8435
|
+
const cutoff = leftSpan.ref.div(width.ref);
|
|
8436
|
+
const cond = u.ref.less(cutoff);
|
|
8437
|
+
const lower = left.add(sqrt(u.ref.mul(width.ref).mul(leftSpan)));
|
|
8438
|
+
const upper = right.sub(sqrt(negative(u).add(1).mul(width).mul(rightSpan)));
|
|
8439
|
+
return where(cond, lower, upper);
|
|
8440
|
+
}, { staticArgnums: [4] });
|
|
8441
|
+
/**
|
|
8442
|
+
* @function
|
|
8443
|
+
* Sample Weibull minimum random values.
|
|
8444
|
+
*
|
|
8445
|
+
* Uses `scale * exponential(key) ** (1 / concentration)`.
|
|
8446
|
+
*/
|
|
8447
|
+
const weibullMin = jit$1(function weibullMin$1(key$1, scale, concentration, shape$1 = []) {
|
|
8448
|
+
scale = fudgeArray(scale);
|
|
8449
|
+
concentration = fudgeArray(concentration);
|
|
8450
|
+
return scale.mul(exp(log(exponential(key$1, shape$1)).div(concentration)));
|
|
8451
|
+
}, { staticArgnums: [3] });
|
|
8237
8452
|
|
|
8238
8453
|
//#endregion
|
|
8239
8454
|
//#region src/library/scipy-special.ts
|
package/dist/{webgl-D8-14NzA.js → webgl-BhsnpeB0.js}
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DZvR7mZV.js";
|
|
1
|
+
import { AluGroup, AluOp, DEBUG, DType, Executable, SlotError, UnsupportedOpError, UnsupportedRoutineError, isFloatDtype, range, strip1, tuneNullopt } from "./backend-DI-V78Rk.js";
|
|
2
2
|
|
|
3
3
|
//#region src/backend/webgl/builtins.ts
|
|
4
4
|
const threefrySrc = `
|