@jax-js/jax 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -124,7 +124,6 @@ declare class ShapeTracker {
124
124
  /** Like pad(), but allows for negative values. */
125
125
  padOrShrink(arg: Pair[]): ShapeTracker;
126
126
  }
127
- //# sourceMappingURL=shape.d.ts.map
128
127
  //#endregion
129
128
  //#region src/utils.d.ts
130
129
  /**
@@ -407,6 +406,52 @@ declare class Reduction implements FpHashable {
407
406
  }
408
407
  /** Expression for accessing `indices` in input array with the given shape. */
409
408
  //#endregion
409
+ //#region src/routine.d.ts
410
+ /**
411
+ * Advanced operations that don't fit into the `AluExp` compiler representation.
412
+ *
413
+ * Some routines like iterative matrix algorithms, FFTs, or sorting may not be
414
+ * easy to express efficiently as a `Kernel` object. These also tend to be
415
+ * somewhat expensive, so the benefit of kernel fusion and inlining is less
416
+ * relevant.
417
+ *
418
+ * For these operations, we dispatch them as a custom operation on the backend,
419
+ * which each backend implements in a specific way. These are listed in the
420
+ * `Routines` enum below.
421
+ *
422
+ * Routines cannot be fused into other kernels and always operate on contiguous
423
+ * arrays (default `ShapeTracker`).
424
+ */
425
+ declare class Routine {
426
+ /** The name of the routine. */
427
+ readonly name: Routines;
428
+ /** Dtype and shape of the inputs and outputs. */
429
+ readonly type: RoutineType;
430
+ /** Extra parameters specific to the routine. */
431
+ readonly params?: any | undefined;
432
+ constructor(/** The name of the routine. */
433
+ name: Routines, /** Dtype and shape of the inputs and outputs. */
434
+ type: RoutineType, /** Extra parameters specific to the routine. */
435
+ params?: any | undefined);
436
+ }
437
+ /** One of the valid `Routine` that can be dispatched to backend. */
438
+ declare enum Routines {
439
+ /** Stable sorting algorithm along the last axis. */
440
+ Sort = "Sort",
441
+ /** Returns `int32` indices of the stably sorted array. */
442
+ Argsort = "Argsort",
443
+ /** Solve a triangular system of equations. */
444
+ TriangularSolve = "TriangularSolve",
445
+ /** Cholesky decomposition of 2D positive-definite matrices. */
446
+ Cholesky = "Cholesky",
447
+ }
448
+ interface RoutineType {
449
+ inputShapes: number[][];
450
+ inputDtypes: DType[];
451
+ outputShapes: number[][];
452
+ outputDtypes: DType[];
453
+ }
454
+ //#endregion
410
455
  //#region src/backend.d.ts
411
456
  type Device = "cpu" | "wasm" | "webgpu";
412
457
  declare const devices: Device[];
@@ -444,9 +489,13 @@ interface Backend {
444
489
  /** Read a range of bytes from a buffer, blocking variant. */
445
490
  readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
446
491
  /** Prepare a kernel expression to be executed later. */
447
- prepare(kernel: Kernel): Promise<Executable>;
492
+ prepareKernel(kernel: Kernel): Promise<Executable>;
448
493
  /** Prepare a kernel expression to be executed later, blocking variant. */
449
- prepareSync(kernel: Kernel): Executable;
494
+ prepareKernelSync(kernel: Kernel): Executable;
495
+ /** Prepare an advanced routine to be executed later. */
496
+ prepareRoutine(routine: Routine): Promise<Executable>;
497
+ /** Prepare an advanced routine to be executed later, blocking variant. */
498
+ prepareRoutineSync(routine: Routine): Executable;
450
499
  /**
451
500
  * Run a backend operation that was previously prepared.
452
501
  *
@@ -457,14 +506,69 @@ interface Backend {
457
506
  dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
458
507
  }
459
508
  declare class Executable<T = any> {
460
- readonly kernel: Kernel;
461
- /** Extra data specific to the backend running this kernel. */
509
+ /** The `Kernel` or `Routine` that was prepared. */
510
+ readonly source: Kernel | Routine;
511
+ /** Extra data specific to the backend running this executable. */
462
512
  readonly data: T;
463
- constructor(kernel: Kernel, /** Extra data specific to the backend running this kernel. */
513
+ constructor(/** The `Kernel` or `Routine` that was prepared. */
514
+ source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
464
515
  data: T);
465
516
  }
517
+ declare namespace numpy_fft_d_exports {
518
+ export { ComplexPair, fft, ifft };
519
+ }
520
+ /**
521
+ * A pair of arrays representing real and imaginary part `a + bj`. Both arrays
522
+ * must have the same shape.
523
+ */
524
+ type ComplexPair = {
525
+ real: Array;
526
+ imag: Array;
527
+ };
528
+ /**
529
+ * Compute a one-dimensional discrete Fourier transform.
530
+ *
531
+ * Currently, the size of the axis must be a power of two.
532
+ */
533
+ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
534
+ /**
535
+ * Compute a one-dimensional inverse discrete Fourier transform.
536
+ *
537
+ * Currently, the size of the axis must be a power of two.
538
+ */
539
+ declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
540
+ declare namespace numpy_linalg_d_exports {
541
+ export { cholesky$1 as cholesky, diagonal, lstsq, matmul, matrixTranspose, outer, tensordot, trace, vecdot };
542
+ }
543
+ /**
544
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
545
+ *
546
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
547
+ * the input matrix, which is on by default.
548
+ */
549
+ declare function cholesky$1(a: ArrayLike, {
550
+ upper,
551
+ symmetrizeInput
552
+ }?: {
553
+ upper?: boolean;
554
+ symmetrizeInput?: boolean;
555
+ }): Array;
556
+ /**
557
+ * Return the least-squares solution to a linear equation.
558
+ *
559
+ * For overdetermined systems, this finds the `x` that minimizes `norm(a @ x - b)`.
560
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
561
+ *
562
+ * Under the hood, this currently uses Cholesky decomposition to solve the
563
+ * normal equations. This method is not as robust as QR or SVD.
564
+ *
565
+ * @param a coefficient matrix of shape `(M, N)`
566
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
567
+ * @returns least-squares solution of shape `(N,)` or `(N, K)`
568
+ */
569
+ declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
466
570
  declare namespace numpy_d_exports {
467
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
571
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sort, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
468
572
  }
469
573
  declare const float32 = DType.Float32;
470
574
  declare const int32 = DType.Int32;
@@ -590,6 +694,20 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
590
694
  declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
591
695
  /** Return the maximum of array elements along a given axis. */
592
696
  declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
697
+ /**
698
+ * Test whether all array elements along a given axis evaluate to True.
699
+ *
700
+ * Returns a boolean array with the same shape as `a` with the specified axis
701
+ * removed. If `axis` is not specified, returns a scalar.
702
+ */
703
+ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
704
+ /**
705
+ * Test whether any array element along a given axis evaluates to True.
706
+ *
707
+ * Returns a boolean array with the same shape as `a` with the specified axis
708
+ * removed. If `axis` is not specified, returns a scalar.
709
+ */
710
+ declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
593
711
  /** Return the peak-to-peak range along a given axis (`max - min`). */
594
712
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
595
713
  /** Compute the average of the array elements along the specified axis. */
@@ -615,8 +733,6 @@ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
615
733
  * two-phase parallel reduction algorithm.
616
734
  */
617
735
  declare function cumsum(a: ArrayLike, axis?: number): Array;
618
- /** @function Alternative name for `jax.numpy.cumsum()`. */
619
- declare const cumulativeSum: typeof cumsum;
620
736
  /** Reverse the elements in an array along the given axes. */
621
737
  declare function flip(x: ArrayLike, axis?: Axis): Array;
622
738
  /**
@@ -662,12 +778,29 @@ declare function columnStack(xs: ArrayLike[]): Array;
662
778
  declare function flipud(x: ArrayLike): Array;
663
779
  /** Flip an array horizontally (axis=1). */
664
780
  declare function fliplr(x: ArrayLike): Array;
665
- /** @function Alternative name for `numpy.transpose()`. */
666
- declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
781
+ /** Transpose the last two dimensions of an array. */
782
+ declare function matrixTranspose(a: ArrayLike): Array;
667
783
  /** Return a 1-D flattened array containing the elements of the input. */
668
784
  declare function ravel(a: ArrayLike): Array;
669
785
  /** Remove one or more length-1 axes from an array. */
670
786
  declare function squeeze(a: ArrayLike, axis?: Axis): Array;
787
+ /**
788
+ * Expand the shape of an array by inserting new axes of length 1.
789
+ *
790
+ * @param a - Input array.
791
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
792
+ * is placed. Can be a single integer or an array of integers.
793
+ * @returns Array with the number of dimensions increased.
794
+ *
795
+ * @example
796
+ * ```ts
797
+ * const x = np.array([1, 2]);
798
+ * np.expandDims(x, 0); // Shape [1, 2]
799
+ * np.expandDims(x, 1); // Shape [2, 1]
800
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
801
+ * ```
802
+ */
803
+ declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
671
804
  /**
672
805
  * Repeat each element of an array after themselves.
673
806
  *
@@ -714,6 +847,22 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
714
847
  declare function diag(v: ArrayLike, k?: number): Array;
715
848
  /** Calculate the sum of the diagonal of an array along the given axes. */
716
849
  declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
850
+ /**
851
+ * Return a sorted copy of an array.
852
+ *
853
+ * The array is sorted along a specified axis (the last by default). This may be
854
+ * an unstable sort, and it dispatches to a device-specific implementation.
855
+ */
856
+ declare function sort(a: ArrayLike, axis?: number): Array;
857
+ /**
858
+ * Return indices that would sort an array. This may be an unstable sorting
859
+ * algorithm; it need not preserve order of indices in ties.
860
+ *
861
+ * Returns an array of `int32` indices.
862
+ *
863
+ * The array is sorted along a specified axis (the last by default).
864
+ */
865
+ declare function argsort(a: ArrayLike, axis?: number): Array;
717
866
  /** Return if two arrays are element-wise equal within a tolerance. */
718
867
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
719
868
  rtol?: number;
@@ -785,6 +934,10 @@ declare function vecdot(x: ArrayLike, y: ArrayLike, {
785
934
  * Like vecdot() but flattens the arguments first into vectors.
786
935
  */
787
936
  declare function vdot(x: ArrayLike, y: ArrayLike): Array;
937
+ /** Convolution of two one-dimensional arrays. */
938
+ declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
939
+ /** Correlation of two one-dimensional arrays. */
940
+ declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
788
941
  /**
789
942
  * Return a tuple of coordinate matrices from coordinate vectors.
790
943
  *
@@ -796,21 +949,6 @@ declare function meshgrid(xs: Array[], {
796
949
  }?: {
797
950
  indexing?: "xy" | "ij";
798
951
  }): Array[];
799
- /**
800
- * Return an array with ones on and below the diagonal and zeros elsewhere.
801
- *
802
- * If `k` is provided, it specifies the sub-diagonal on and below which the
803
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
804
- * `k>0` is above it.
805
- */
806
- declare function tri(n: number, m?: number, k?: number, {
807
- dtype,
808
- device
809
- }?: DTypeAndDevice): Array;
810
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
811
- declare function tril(a: ArrayLike, k?: number): Array;
812
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
813
- declare function triu(a: ArrayLike, k?: number): Array;
814
952
  /**
815
953
  * Clip (limit) the values in an array.
816
954
  *
@@ -827,8 +965,6 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
827
965
  * This is the same function as `jax.numpy.abs()`.
828
966
  */
829
967
  declare function absolute(x: ArrayLike): Array;
830
- /** @function Alias of `jax.numpy.absolute()`. */
831
- declare const abs: typeof absolute;
832
968
  /** Return an element-wise indication of sign of the input. */
833
969
  declare function sign(x: ArrayLike): Array;
834
970
  /** @function Return element-wise positive values of the input (no-op). */
@@ -882,12 +1018,6 @@ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
882
1018
  * The output is ill-defined when both x and y are zero.
883
1019
  */
884
1020
  declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
885
- /** @function Alias of `jax.numpy.acos()`. */
886
- declare const arccos: typeof acos;
887
- /** @function Alias of `jax.numpy.atan()`. */
888
- declare const arctan: (x: ArrayLike) => Array;
889
- /** @function Alias of `jax.numpy.atan2()`. */
890
- declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
891
1021
  /** Element-wise subtraction, with broadcasting. */
892
1022
  declare function subtract(x: ArrayLike, y: ArrayLike): Array;
893
1023
  /** Calculates the floating-point division of x by y element-wise. */
@@ -902,8 +1032,6 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
902
1032
  * Calculate element-wise remainder of the division (matches sign of y).
903
1033
  */
904
1034
  declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
905
- /** @function Alias of `jax.numpy.trueDivide()`. */
906
- declare const divide: typeof trueDivide;
907
1035
  /** Round input to the nearest integer towards zero. */
908
1036
  declare function trunc(x: ArrayLike): Array;
909
1037
  /**
@@ -943,8 +1071,6 @@ declare const degrees: typeof rad2deg;
943
1071
  * Computes first array raised to power of second array, element-wise.
944
1072
  */
945
1073
  declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
946
- /** @function Alias of `jax.numpy.power()`. */
947
- declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
948
1074
  /** @function Calculate the element-wise cube root of the input array. */
949
1075
  declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
950
1076
  /**
@@ -989,12 +1115,6 @@ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
989
1115
  * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
990
1116
  */
991
1117
  declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
992
- /** @function Alias of `jax.numpy.arcsinh()`. */
993
- declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
994
- /** @function Alias of `jax.numpy.arccosh()`. */
995
- declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
996
- /** @function Alias of `jax.numpy.arctanh()`. */
997
- declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
998
1118
  /**
999
1119
  * Compute the variance of an array.
1000
1120
  *
@@ -1021,6 +1141,10 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1021
1141
  mean?: ArrayLike;
1022
1142
  correction?: number;
1023
1143
  } & ReduceOpts): Array;
1144
+ /** Estimate the sample covariance of a set of variables. */
1145
+ declare function cov(x: ArrayLike, y?: ArrayLike): Array;
1146
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
1147
+ declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
1024
1148
  /** Test element-wise for positive or negative infinity, return bool array. */
1025
1149
  declare function isinf(x: ArrayLike): Array;
1026
1150
  /** Test element-wise for NaN (Not a Number). */
@@ -1034,7 +1158,6 @@ declare function isposinf(x: ArrayLike): Array;
1034
1158
  * Test element-wise for finite values (not infinity or NaN).
1035
1159
  */
1036
1160
  declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1037
- //# sourceMappingURL=numpy.d.ts.map
1038
1161
  declare namespace tree_d_exports {
1039
1162
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
1040
1163
  }
@@ -1087,7 +1210,7 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
1087
1210
  interface ConvParams {
1088
1211
  vmapDims: number;
1089
1212
  strides: number[];
1090
- padding: [number, number][];
1213
+ padding: Pair[];
1091
1214
  lhsDilation: number[];
1092
1215
  rhsDilation: number[];
1093
1216
  }
@@ -1168,9 +1291,21 @@ declare class Jaxpr implements FpHashable {
1168
1291
  * - Remove no-op movement operations.
1169
1292
  */
1170
1293
  simplify(): Jaxpr;
1171
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1294
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1172
1295
  flatten(): Jaxpr;
1173
1296
  }
1297
+ /** Jaxpr with a collection of associated, traced constants. */
1298
+ declare class ClosedJaxpr {
1299
+ readonly jaxpr: Jaxpr;
1300
+ readonly consts: Tracer[];
1301
+ constructor(jaxpr: Jaxpr, consts: Tracer[]);
1302
+ /** String representation of this Jaxpr. */
1303
+ toString(): string;
1304
+ /** Apply a function to the underlying Jaxpr. */
1305
+ mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
1306
+ /** Dispose of the constants in this Jaxpr. */
1307
+ dispose(): void;
1308
+ }
1174
1309
  /** @inline */
1175
1310
  type JitOpts = {
1176
1311
  staticArgnums?: number[];
@@ -1193,7 +1328,9 @@ declare enum Primitive {
1193
1328
  Mul = "mul",
1194
1329
  Idiv = "idiv",
1195
1330
  Mod = "mod",
1196
- // uses sign of dividend, C-style, matches JS but not Python
1331
+ // uses sign of numerator, C-style, matches JS but not Python
1332
+ Min = "min",
1333
+ Max = "max",
1197
1334
  Neg = "neg",
1198
1335
  Reciprocal = "reciprocal",
1199
1336
  Floor = "floor",
@@ -1201,7 +1338,6 @@ declare enum Primitive {
1201
1338
  StopGradient = "stop_gradient",
1202
1339
  Cast = "cast",
1203
1340
  Bitcast = "bitcast",
1204
- RandomBits = "random_bits",
1205
1341
  Sin = "sin",
1206
1342
  Cos = "cos",
1207
1343
  Asin = "asin",
@@ -1211,8 +1347,6 @@ declare enum Primitive {
1211
1347
  Erf = "erf",
1212
1348
  Erfc = "erfc",
1213
1349
  Sqrt = "sqrt",
1214
- Min = "min",
1215
- Max = "max",
1216
1350
  Reduce = "reduce",
1217
1351
  Dot = "dot",
1218
1352
  // sum(x*y, axis=-1)
@@ -1222,14 +1356,23 @@ declare enum Primitive {
1222
1356
  PoolTranspose = "pool_transpose",
1223
1357
  Compare = "compare",
1224
1358
  Where = "where",
1359
+ RandomBits = "random_bits",
1360
+ Gather = "gather",
1225
1361
  Transpose = "transpose",
1226
1362
  Broadcast = "broadcast",
1227
1363
  Reshape = "reshape",
1228
1364
  Flip = "flip",
1229
1365
  Shrink = "shrink",
1230
1366
  Pad = "pad",
1231
- Gather = "gather",
1232
- JitCall = "jit_call",
1367
+ Sort = "sort",
1368
+ // sort(x, axis=-1)
1369
+ Argsort = "argsort",
1370
+ // argsort(x, axis=-1)
1371
+ TriangularSolve = "triangular_solve",
1372
+ // A is upper triangular, A @ X.T = B.T
1373
+ Cholesky = "cholesky",
1374
+ // A is positive-definite, A = L @ L^T
1375
+ Jit = "jit",
1233
1376
  }
1234
1377
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1235
1378
  [Primitive.Cast]: {
@@ -1255,6 +1398,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1255
1398
  [Primitive.Compare]: {
1256
1399
  op: CompareOp;
1257
1400
  };
1401
+ [Primitive.RandomBits]: {
1402
+ shape: number[];
1403
+ mode: "xor" | 0 | 1;
1404
+ };
1405
+ [Primitive.Gather]: {
1406
+ axis: number[];
1407
+ outDim: number;
1408
+ };
1258
1409
  [Primitive.Transpose]: {
1259
1410
  perm: number[];
1260
1411
  };
@@ -1262,10 +1413,6 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1262
1413
  shape: number[];
1263
1414
  axis: number[];
1264
1415
  };
1265
- [Primitive.RandomBits]: {
1266
- shape: number[];
1267
- mode: "xor" | 0 | 1;
1268
- };
1269
1416
  [Primitive.Reshape]: {
1270
1417
  shape: number[];
1271
1418
  };
@@ -1278,15 +1425,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1278
1425
  [Primitive.Pad]: {
1279
1426
  width: Pair[];
1280
1427
  };
1281
- [Primitive.Gather]: {
1282
- axis: number[];
1283
- outDim: number;
1284
- };
1285
- [Primitive.JitCall]: {
1428
+ [Primitive.Jit]: {
1286
1429
  name: string;
1287
1430
  jaxpr: Jaxpr;
1288
1431
  numConsts: number;
1289
1432
  };
1433
+ [Primitive.TriangularSolve]: {
1434
+ unitDiagonal: boolean;
1435
+ };
1290
1436
  }
1291
1437
  /** Type of parameters taken by each primitive. */
1292
1438
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
@@ -1454,7 +1600,7 @@ declare abstract class Tracer {
1454
1600
  sub(other: this | TracerValue): this;
1455
1601
  /** Divide an array by this one. */
1456
1602
  div(other: this | TracerValue): this;
1457
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
1603
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
1458
1604
  diagonal(offset?: number, axis1?: number, axis2?: number): this;
1459
1605
  /** Flatten the array without changing its data. */
1460
1606
  flatten(): this;
@@ -1473,6 +1619,19 @@ declare abstract class Tracer {
1473
1619
  * ```
1474
1620
  */
1475
1621
  [Symbol.iterator](): IterableIterator<this>;
1622
+ /**
1623
+ * Return a sorted copy of an array in ascending order.
1624
+ *
1625
+ * See `jax.numpy.sort` for full docs.
1626
+ */
1627
+ sort(axis?: number): this;
1628
+ /**
1629
+ * Return the indices that would sort an array. This may not be a stable
1630
+ * sorting algorithm; it need not preserve order of indices in ties.
1631
+ *
1632
+ * See `jax.numpy.argsort` for full docs.
1633
+ */
1634
+ argsort(axis?: number): this;
1476
1635
  /**
1477
1636
  * Slice an array along one or more axes.
1478
1637
  *
@@ -1515,6 +1674,7 @@ declare class ShapedArray implements AbstractValue {
1515
1674
  constructor(shape: number[], dtype: DType, weakType: boolean);
1516
1675
  static fromAval(aval: AbstractValue): ShapedArray;
1517
1676
  get ndim(): number;
1677
+ get size(): number;
1518
1678
  toString(): string;
1519
1679
  equals(other: ShapedArray): boolean;
1520
1680
  }
@@ -1532,12 +1692,12 @@ type ArrayLike = Array | number | boolean;
1532
1692
  declare class PendingExecute {
1533
1693
  #private;
1534
1694
  readonly backend: Backend;
1535
- readonly kernel: Kernel;
1695
+ readonly source: Kernel | Routine;
1536
1696
  readonly inputs: Slot[];
1537
1697
  readonly outputs: Slot[];
1538
1698
  prepared: Executable | null;
1539
1699
  submitted: boolean;
1540
- constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
1700
+ constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
1541
1701
  updateRc(delta: number): void;
1542
1702
  prepare(): Promise<void>;
1543
1703
  prepareSync(): void;
@@ -1569,7 +1729,6 @@ type ArrayConstructorArgs = {
1569
1729
  */
1570
1730
  declare class Array extends Tracer {
1571
1731
  #private;
1572
- id: number;
1573
1732
  /**
1574
1733
  * @ignore
1575
1734
  * Constructs an array from source, shape and backend. Note that if the source
@@ -1703,6 +1862,21 @@ declare function arange(start: number, stop?: number, step?: number, {
1703
1862
  dtype,
1704
1863
  device
1705
1864
  }?: DTypeAndDevice): Array;
1865
+ /**
1866
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
1867
+ *
1868
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1869
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1870
+ * `k>0` is above it.
1871
+ */
1872
+ declare function tri(n: number, m?: number, k?: number, {
1873
+ dtype,
1874
+ device
1875
+ }?: DTypeAndDevice): Array;
1876
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1877
+ declare function tril(a: ArrayLike, k?: number): Array;
1878
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1879
+ declare function triu(a: ArrayLike, k?: number): Array;
1706
1880
  /**
1707
1881
  * Return evenly spaced numbers over a specified interval.
1708
1882
  *
@@ -1716,8 +1890,71 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1716
1890
  dtype,
1717
1891
  device
1718
1892
  }?: DTypeAndDevice): Array;
1893
+ declare namespace lax_linalg_d_exports {
1894
+ export { cholesky, triangularSolve };
1895
+ }
1896
+ /**
1897
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
1898
+ *
1899
+ * The Cholesky decomposition of a matrix `A` is:
1900
+ *
1901
+ * - A = L @ L^T (for upper=false, default)
1902
+ * - A = U^T @ U (for upper=true)
1903
+ *
1904
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
1905
+ * The input matrix must be symmetric and positive-definite.
1906
+ *
1907
+ * @example
1908
+ * ```ts
1909
+ * import { lax, numpy as np } from "@jax-js/jax";
1910
+ *
1911
+ * const x = np.array([[2., 1.], [1., 2.]]);
1912
+ *
1913
+ * // Lower Cholesky factorization (default):
1914
+ * const L = lax.linalg.cholesky(x);
1915
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
1916
+ *
1917
+ * // Upper Cholesky factorization:
1918
+ * const U = lax.linalg.cholesky(x, { upper: true });
1919
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
1920
+ * ```
1921
+ */
1922
+ declare function cholesky(a: ArrayLike, {
1923
+ upper
1924
+ }?: {
1925
+ upper?: boolean;
1926
+ }): Array;
1927
+ /**
1928
+ * Solve a triangular linear system.
1929
+ *
1930
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
1931
+ * where `a` is a triangular matrix.
1932
+ *
1933
+ * @example
1934
+ * ```ts
1935
+ * import { lax, numpy as np } from "@jax-js/jax";
1936
+ *
1937
+ * const L = np.array([[2., 0.], [1., 3.]]);
1938
+ * const b = np.array([4., 7.]).reshape([2, 1]);
1939
+ *
1940
+ * // Solve L @ x = b
1941
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
1942
+ * // x = [[2.], [5./3.]]
1943
+ * ```
1944
+ */
1945
+ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1946
+ leftSide,
1947
+ lower,
1948
+ transposeA,
1949
+ unitDiagonal
1950
+ }?: {
1951
+ leftSide?: boolean;
1952
+ lower?: boolean;
1953
+ transposeA?: boolean;
1954
+ unitDiagonal?: boolean;
1955
+ }): Array;
1719
1956
  declare namespace lax_d_exports {
1720
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
1957
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1721
1958
  }
1722
1959
  /**
1723
1960
  * Dimension numbers for general `dot()` primitive.
@@ -1748,7 +1985,7 @@ declare function dot(lhs: Array, rhs: Array, {
1748
1985
  lhsBatchDims: lb,
1749
1986
  rhsBatchDims: rb
1750
1987
  }?: DotDimensionNumbers): Array;
1751
- type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1988
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
1752
1989
  /**
1753
1990
  * General n-dimensional convolution operator, with optional dilation.
1754
1991
  *
@@ -1788,9 +2025,8 @@ declare function erfc(x: ArrayLike): Array;
1788
2025
  * forward or reverse-mode automatic differentiation.
1789
2026
  */
1790
2027
  declare function stopGradient(x: ArrayLike): Array;
1791
- //# sourceMappingURL=lax.d.ts.map
1792
2028
  declare namespace nn_d_exports {
1793
- export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
2029
+ export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
1794
2030
  }
1795
2031
  /**
1796
2032
  * Rectified Linear Unit (ReLU) activation function:
@@ -1817,21 +2053,28 @@ declare function sigmoid(x: ArrayLike): Array;
1817
2053
  */
1818
2054
  declare function softplus(x: ArrayLike): Array;
1819
2055
  /**
1820
- * Soft-sign activation function, computed element-wise:
1821
- * `softsign(x) = x / (|x| + 1)`.
2056
+ * @function
2057
+ * Sparse plus function:
2058
+ *
2059
+ * - When `x <= -1`: `0`
2060
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
2061
+ * - When `x >= 1`: `x`
1822
2062
  */
1823
- declare function softSign(x: ArrayLike): Array;
2063
+ declare const sparsePlus: OwnedFunction<(x: ArrayLike) => Array>;
1824
2064
  /**
1825
2065
  * @function
1826
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1827
- * Swish, computed element-wise:
1828
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
1829
- *
1830
- * `swish()` and `silu()` are both aliases for the same function.
2066
+ * Sparse sigmoid activation function.
1831
2067
  *
1832
- * Reference: https://en.wikipedia.org/wiki/Swish_function
2068
+ * - When `x <= -1`: `0`
2069
+ * - When `-1 < x < 1`: `(x + 1) / 2`
2070
+ * - When `x >= 1`: `1`
1833
2071
  */
1834
- declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
2072
+ declare const sparseSigmoid: OwnedFunction<(x: ArrayLike) => Array>;
2073
+ /**
2074
+ * Soft-sign activation function, computed element-wise:
2075
+ * `softsign(x) = x / (|x| + 1)`.
2076
+ */
2077
+ declare function softSign(x: ArrayLike): Array;
1835
2078
  /**
1836
2079
  * @function
1837
2080
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
@@ -1842,7 +2085,7 @@ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1842
2085
  *
1843
2086
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1844
2087
  */
1845
- declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
2088
+ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1846
2089
  /**
1847
2090
  * Log-sigmoid activation function, computed element-wise:
1848
2091
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
@@ -1855,6 +2098,12 @@ declare function logSigmoid(x: ArrayLike): Array;
1855
2098
  declare const identity: (x: ArrayLike) => Array;
1856
2099
  /** Leaky rectified linear (ReLU) activation function */
1857
2100
  declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
2101
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
2102
+ declare function hardSigmoid(x: ArrayLike): Array;
2103
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
2104
+ declare function hardSilu(x: ArrayLike): Array;
2105
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
2106
+ declare function hardTanh(x: ArrayLike): Array;
1858
2107
  /**
1859
2108
  * Exponential linear unit activation function.
1860
2109
  *
@@ -1869,6 +2118,16 @@ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
1869
2118
  * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1870
2119
  */
1871
2120
  declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
2121
+ /**
2122
+ * @function
2123
+ * Scaled exponential linear unit activation.
2124
+ *
2125
+ * Computes the element-wise function:
2126
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
2127
+ *
2128
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
2129
+ */
2130
+ declare const selu: OwnedFunction<(x: ArrayLike) => Array>;
1872
2131
  /**
1873
2132
  * @function
1874
2133
  * Gaussion error linear unit (GELU) activation function.
@@ -1964,9 +2223,8 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1964
2223
  * ```
1965
2224
  */
1966
2225
  declare function oneHot(x: Array, numClasses: number): Array;
1967
- //# sourceMappingURL=nn.d.ts.map
1968
2226
  declare namespace random_d_exports {
1969
- export { bernoulli, bits, exponential, key, normal, split, uniform };
2227
+ export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, normal, split, uniform };
1970
2228
  }
1971
2229
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
1972
2230
  declare function key(seed: number): Array;
@@ -1989,11 +2247,33 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
1989
2247
  * and must be broadcastable to `shape`.
1990
2248
  */
1991
2249
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2250
+ /**
2251
+ * @function
2252
+ * Sample from a Cauchy distribution with location 0 and scale 1.
2253
+ *
2254
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
2255
+ */
2256
+ declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1992
2257
  /**
1993
2258
  * @function
1994
2259
  * Sample exponential random values according to `p(x) = exp(-x)`.
1995
2260
  */
1996
2261
  declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2262
+ /**
2263
+ * @function
2264
+ * Sample from a Gumbel distribution with location 0 and scale 1.
2265
+ *
2266
+ * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
2267
+ */
2268
+ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2269
+ /**
2270
+ * @function
2271
+ * Sample from a Laplace distribution with location 0 and scale 1.
2272
+ *
2273
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
2274
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2275
+ */
2276
+ declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1997
2277
  /**
1998
2278
  * @function
1999
2279
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
@@ -2003,7 +2283,6 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
2003
2283
  * bitwise identical to JAX.
2004
2284
  */
2005
2285
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2006
- //# sourceMappingURL=random.d.ts.map
2007
2286
  declare namespace scipy_special_d_exports {
2008
2287
  export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
2009
2288
  }
@@ -2034,8 +2313,7 @@ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTr
2034
2313
  * Construct a Jaxpr by dynamically tracing a function with example inputs.
2035
2314
  */
2036
2315
  declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
2037
- jaxpr: Jaxpr;
2038
- consts: Array[];
2316
+ jaxpr: ClosedJaxpr;
2039
2317
  treedef: JsTreeDef;
2040
2318
  };
2041
2319
  /**
@@ -2083,11 +2361,6 @@ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F)
2083
2361
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2084
2362
  */
2085
2363
  declare const jacrev: typeof jacfwd;
2086
- /**
2087
- * @function
2088
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
2089
- */
2090
- declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2091
2364
  /**
2092
2365
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2093
2366
  *
@@ -2109,8 +2382,5 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2109
2382
  * default device.
2110
2383
  */
2111
2384
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2112
- //# sourceMappingURL=index.d.ts.map
2113
-
2114
2385
  //#endregion
2115
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2116
- //# sourceMappingURL=index.d.cts.map
2386
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };