@jax-js/jax 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -121,7 +121,6 @@ declare class ShapeTracker {
121
121
  /** Like pad(), but allows for negative values. */
122
122
  padOrShrink(arg: Pair[]): ShapeTracker;
123
123
  }
124
- //# sourceMappingURL=shape.d.ts.map
125
124
  //#endregion
126
125
  //#region src/utils.d.ts
127
126
  /**
@@ -404,6 +403,52 @@ declare class Reduction implements FpHashable {
404
403
  }
405
404
  /** Expression for accessing `indices` in input array with the given shape. */
406
405
  //#endregion
406
+ //#region src/routine.d.ts
407
+ /**
408
+ * Advanced operations that don't fit into the `AluExp` compiler representation.
409
+ *
410
+ * Some routines like iterative matrix algorithms, FFTs, or sorting may not be
411
+ * easy to express efficiently as a `Kernel` object. These also tend to be
412
+ * somewhat expensive, so the benefit of kernel fusion and inlining is less
413
+ * relevant.
414
+ *
415
+ * For these operations, we dispatch them as a custom operation on the backend,
416
+ * which each backend implements in a specific way. These are listed in the
417
+ * `Routines` enum below.
418
+ *
419
+ * Routines cannot be fused into other kernels and always operate on contiguous
420
+ * arrays (default `ShapeTracker`).
421
+ */
422
+ declare class Routine {
423
+ /** The name of the routine. */
424
+ readonly name: Routines;
425
+ /** Dtype and shape of the inputs and outputs. */
426
+ readonly type: RoutineType;
427
+ /** Extra parameters specific to the routine. */
428
+ readonly params?: any | undefined;
429
+ constructor(/** The name of the routine. */
430
+ name: Routines, /** Dtype and shape of the inputs and outputs. */
431
+ type: RoutineType, /** Extra parameters specific to the routine. */
432
+ params?: any | undefined);
433
+ }
434
+ /** One of the valid `Routine` that can be dispatched to backend. */
435
+ declare enum Routines {
436
+ /** Stable sorting algorithm along the last axis. */
437
+ Sort = "Sort",
438
+ /** Returns `int32` indices of the stably sorted array. */
439
+ Argsort = "Argsort",
440
+ /** Solve a triangular system of questions. */
441
+ TriangularSolve = "TriangularSolve",
442
+ /** Cholesky decomposition of 2D positive semi-definite matrices. */
443
+ Cholesky = "Cholesky",
444
+ }
445
+ interface RoutineType {
446
+ inputShapes: number[][];
447
+ inputDtypes: DType[];
448
+ outputShapes: number[][];
449
+ outputDtypes: DType[];
450
+ }
451
+ //#endregion
407
452
  //#region src/backend.d.ts
408
453
  type Device = "cpu" | "wasm" | "webgpu";
409
454
  declare const devices: Device[];
@@ -441,9 +486,13 @@ interface Backend {
441
486
  /** Read a range of bytes from a buffer, blocking variant. */
442
487
  readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
443
488
  /** Prepare an expression to be executed later. */
444
- prepare(kernel: Kernel): Promise<Executable>;
489
+ prepareKernel(kernel: Kernel): Promise<Executable>;
445
490
  /** Prepare an expression to be executed later, blocking variant. */
446
- prepareSync(kernel: Kernel): Executable;
491
+ prepareKernelSync(kernel: Kernel): Executable;
492
+ /** Prepare an advanced routine to be executed later. */
493
+ prepareRoutine(routine: Routine): Promise<Executable>;
494
+ /** Prepare an advanced routine to be executed later, blocking variant. */
495
+ prepareRoutineSync(routine: Routine): Executable;
447
496
  /**
448
497
  * Run a backend operation that was previously prepared.
449
498
  *
@@ -454,14 +503,69 @@ interface Backend {
454
503
  dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
455
504
  }
456
505
  declare class Executable<T = any> {
457
- readonly kernel: Kernel;
458
- /** Extra data specific to the backend running this kernel. */
506
+ /** The `Kernel` or `Routine` that was prepared. */
507
+ readonly source: Kernel | Routine;
508
+ /** Extra data specific to the backend running this executable. */
459
509
  readonly data: T;
460
- constructor(kernel: Kernel, /** Extra data specific to the backend running this kernel. */
510
+ constructor(/** The `Kernel` or `Routine` that was prepared. */
511
+ source: Kernel | Routine, /** Extra data specific to the backend running this executable. */
461
512
  data: T);
462
513
  }
514
+ declare namespace numpy_fft_d_exports {
515
+ export { ComplexPair, fft, ifft };
516
+ }
517
+ /**
518
+ * A pair of arrays representing real and imaginary part `a + bj`. Both arrays
519
+ * must have the same shape.
520
+ */
521
+ type ComplexPair = {
522
+ real: Array;
523
+ imag: Array;
524
+ };
525
+ /**
526
+ * Compute a one-dimensional discrete Fourier transform.
527
+ *
528
+ * Currently, the size of the axis must be a power of two.
529
+ */
530
+ declare function fft(a: ComplexPair, axis?: number): ComplexPair;
531
+ /**
532
+ * Compute a one-dimensional inverse discrete Fourier transform.
533
+ *
534
+ * Currently, the size of the axis must be a power of two.
535
+ */
536
+ declare function ifft(a: ComplexPair, axis?: number): ComplexPair;
537
+ declare namespace numpy_linalg_d_exports {
538
+ export { cholesky$1 as cholesky, diagonal, lstsq, matmul, matrixTranspose, outer, tensordot, trace, vecdot };
539
+ }
540
+ /**
541
+ * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
542
+ *
543
+ * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
544
+ * the input matrix, which is on by default.
545
+ */
546
+ declare function cholesky$1(a: ArrayLike, {
547
+ upper,
548
+ symmetrizeInput
549
+ }?: {
550
+ upper?: boolean;
551
+ symmetrizeInput?: boolean;
552
+ }): Array;
553
+ /**
554
+ * Return the least-squares solution to a linear equation.
555
+ *
556
+ * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
557
+ * For underdetermined systems, this finds the minimum-norm solution for `x`.
558
+ *
559
+ * This currently uses Cholesky decomposition to solve the normal equations,
560
+ * under the hood. The method is not as robust as QR or SVD.
561
+ *
562
+ * @param a coefficient matrix of shape `(M, N)`
563
+ * @param b right-hand side of shape `(M,)` or `(M, K)`
564
+ * @return least-squares solution of shape `(N,)` or `(N, K)`
565
+ */
566
+ declare function lstsq(a: ArrayLike, b: ArrayLike): Array;
463
567
  declare namespace numpy_d_exports {
464
- export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, cos, cosh, cumsum, cumulativeSum, deg2rad, degrees, diag, diagonal, divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, positive, pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
568
+ export { Array, ArrayLike, DType, absolute as abs, absolute, acos, arccosh as acosh, add, all, allclose, any, arange, acos as arccos, arccosh, asin as arcsin, arcsinh, atan as arctan, atan2 as arctan2, arctanh, argmax, argmin, argsort, array, asin, arcsinh as asinh, astype, atan, atan2, arctanh as atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, ceil, clip, columnStack, concatenate, convolve, corrcoef, correlate, cos, cosh, cov, cumsum, cumsum as cumulativeSum, deg2rad, degrees, diag, diagonal, trueDivide as divide, dot$1 as dot, dstack, e, einsum, equal, eulerGamma, exp, exp2, expandDims, expm1, eye, numpy_fft_d_exports as fft, flip, fliplr, flipud, float16, float32, float64, floor, fmod, frexp, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, ldexp, less, lessEqual, numpy_linalg_d_exports as linalg, linspace, log, log10, log1p, log2, matmul, matrixTranspose, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, transpose as permuteDims, pi, positive, power as pow, power, prod, promoteTypes, ptp, rad2deg, radians, ravel, reciprocal, remainder, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sort, sqrt, square, squeeze, stack, std, subtract, sum, tan, tanh, tensordot, tile, trace, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
465
569
  }
466
570
  declare const float32 = DType.Float32;
467
571
  declare const int32 = DType.Int32;
@@ -587,6 +691,20 @@ declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
587
691
  declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
588
692
  /** Return the maximum of array elements along a given axis. */
589
693
  declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
694
+ /**
695
+ * Test whether all array elements along a given axis evaluate to True.
696
+ *
697
+ * Returns a boolean array with the same shape as `a` with the specified axis
698
+ * removed. If axis is None, returns a scalar.
699
+ */
700
+ declare function all(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
701
+ /**
702
+ * Test whether any array element along a given axis evaluates to True.
703
+ *
704
+ * Returns a boolean array with the same shape as `a` with the specified axis
705
+ * removed. If axis is None, returns a scalar.
706
+ */
707
+ declare function any(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
590
708
  /** Return the peak-to-peak range along a given axis (`max - min`). */
591
709
  declare function ptp(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
592
710
  /** Compute the average of the array elements along the specified axis. */
@@ -612,8 +730,6 @@ declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
612
730
  * two-phase parallel reduction algorithm.
613
731
  */
614
732
  declare function cumsum(a: ArrayLike, axis?: number): Array;
615
- /** @function Alternative name for `jax.numpy.cumsum()`. */
616
- declare const cumulativeSum: typeof cumsum;
617
733
  /** Reverse the elements in an array along the given axes. */
618
734
  declare function flip(x: ArrayLike, axis?: Axis): Array;
619
735
  /**
@@ -659,12 +775,29 @@ declare function columnStack(xs: ArrayLike[]): Array;
659
775
  declare function flipud(x: ArrayLike): Array;
660
776
  /** Flip an array horizontally (axis=1). */
661
777
  declare function fliplr(x: ArrayLike): Array;
662
- /** @function Alternative name for `numpy.transpose()`. */
663
- declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
778
+ /** Transpose the last two dimensions of an array. */
779
+ declare function matrixTranspose(a: ArrayLike): Array;
664
780
  /** Return a 1-D flattened array containing the elements of the input. */
665
781
  declare function ravel(a: ArrayLike): Array;
666
782
  /** Remove one or more length-1 axes from an array. */
667
783
  declare function squeeze(a: ArrayLike, axis?: Axis): Array;
784
+ /**
785
+ * Expand the shape of an array by inserting new axes of length 1.
786
+ *
787
+ * @param a - Input array.
788
+ * @param axis - Position(s) in the expanded axes where the new axis (or axes)
789
+ * is placed. Can be a single integer or an array of integers.
790
+ * @returns Array with the number of dimensions increased.
791
+ *
792
+ * @example
793
+ * ```ts
794
+ * const x = np.array([1, 2]);
795
+ * np.expandDims(x, 0); // Shape [1, 2]
796
+ * np.expandDims(x, 1); // Shape [2, 1]
797
+ * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
798
+ * ```
799
+ */
800
+ declare function expandDims(a: ArrayLike, axis: number | number[]): Array;
668
801
  /**
669
802
  * Repeat each element of an array after themselves.
670
803
  *
@@ -711,6 +844,22 @@ declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?:
711
844
  declare function diag(v: ArrayLike, k?: number): Array;
712
845
  /** Calculate the sum of the diagonal of an array along the given axes. */
713
846
  declare function trace(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
847
+ /**
848
+ * Return a sorted copy of an array.
849
+ *
850
+ * The array is sorted along a specified axis (the last by default). This may be
851
+ * an unstable sort, and it dispatches to device-specific implementation.
852
+ */
853
+ declare function sort(a: ArrayLike, axis?: number): Array;
854
+ /**
855
+ * Return indices that would sort an array. This may be an unstable sorting
856
+ * algorithm; it need not preserve order of indices in ties.
857
+ *
858
+ * Returns an array of `int32` indices.
859
+ *
860
+ * The array is sorted along a specified axis (the last by default).
861
+ */
862
+ declare function argsort(a: ArrayLike, axis?: number): Array;
714
863
  /** Return if two arrays are element-wise equal within a tolerance. */
715
864
  declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
716
865
  rtol?: number;
@@ -782,6 +931,10 @@ declare function vecdot(x: ArrayLike, y: ArrayLike, {
782
931
  * Like vecdot() but flattens the arguments first into vectors.
783
932
  */
784
933
  declare function vdot(x: ArrayLike, y: ArrayLike): Array;
934
+ /** Convolution of two one-dimensional arrays. */
935
+ declare function convolve(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
936
+ /** Correlation of two one dimensional arrays. */
937
+ declare function correlate(x: Array, y: Array, mode?: "full" | "same" | "valid"): Array;
785
938
  /**
786
939
  * Return a tuple of coordinate matrices from coordinate vectors.
787
940
  *
@@ -793,21 +946,6 @@ declare function meshgrid(xs: Array[], {
793
946
  }?: {
794
947
  indexing?: "xy" | "ij";
795
948
  }): Array[];
796
- /**
797
- * Return an array with ones on and below the diagonal and zeros elsewhere.
798
- *
799
- * If `k` is provided, it specifies the sub-diagonal on and below which the
800
- * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
801
- * `k>0` is above it.
802
- */
803
- declare function tri(n: number, m?: number, k?: number, {
804
- dtype,
805
- device
806
- }?: DTypeAndDevice): Array;
807
- /** Return the lower triangle of an array. Must be of dimension >= 2. */
808
- declare function tril(a: ArrayLike, k?: number): Array;
809
- /** Return the upper triangle of an array. Must be of dimension >= 2. */
810
- declare function triu(a: ArrayLike, k?: number): Array;
811
949
  /**
812
950
  * Clip (limit) the values in an array.
813
951
  *
@@ -824,8 +962,6 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
824
962
  * This is the same function as `jax.numpy.abs()`.
825
963
  */
826
964
  declare function absolute(x: ArrayLike): Array;
827
- /** @function Alias of `jax.numpy.absolute()`. */
828
- declare const abs: typeof absolute;
829
965
  /** Return an element-wise indication of sign of the input. */
830
966
  declare function sign(x: ArrayLike): Array;
831
967
  /** @function Return element-wise positive values of the input (no-op). */
@@ -879,12 +1015,6 @@ declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
879
1015
  * The output is ill-defined when both x and y are zero.
880
1016
  */
881
1017
  declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
882
- /** @function Alias of `jax.numpy.acos()`. */
883
- declare const arccos: typeof acos;
884
- /** @function Alias of `jax.numpy.atan()`. */
885
- declare const arctan: (x: ArrayLike) => Array;
886
- /** @function Alias of `jax.numpy.atan2()`. */
887
- declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
888
1018
  /** Element-wise subtraction, with broadcasting. */
889
1019
  declare function subtract(x: ArrayLike, y: ArrayLike): Array;
890
1020
  /** Calculates the floating-point division of x by y element-wise. */
@@ -899,8 +1029,6 @@ declare const fmod: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
899
1029
  * Calculate element-wise remainder of the division (matches sign of y).
900
1030
  */
901
1031
  declare const remainder: OwnedFunction<(x: ArrayLike, y: ArrayLike) => Array>;
902
- /** @function Alias of `jax.numpy.trueDivide()`. */
903
- declare const divide: typeof trueDivide;
904
1032
  /** Round input to the nearest integer towards zero. */
905
1033
  declare function trunc(x: ArrayLike): Array;
906
1034
  /**
@@ -940,8 +1068,6 @@ declare const degrees: typeof rad2deg;
940
1068
  * Computes first array raised to power of second array, element-wise.
941
1069
  */
942
1070
  declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
943
- /** @function Alias of `jax.numpy.power()`. */
944
- declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
945
1071
  /** @function Calculate the element-wise cube root of the input array. */
946
1072
  declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
947
1073
  /**
@@ -986,12 +1112,6 @@ declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
986
1112
  * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
987
1113
  */
988
1114
  declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
989
- /** @function Alias of `jax.numpy.arcsinh()`. */
990
- declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
991
- /** @function Alias of `jax.numpy.arccosh()`. */
992
- declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
993
- /** @function Alias of `jax.numpy.arctanh()`. */
994
- declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
995
1115
  /**
996
1116
  * Compute the variance of an array.
997
1117
  *
@@ -1018,6 +1138,10 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
1018
1138
  mean?: ArrayLike;
1019
1139
  correction?: number;
1020
1140
  } & ReduceOpts): Array;
1141
+ /** Estimate the sample covariance of a set of variables. */
1142
+ declare function cov(x: ArrayLike, y?: ArrayLike): Array;
1143
+ /** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
1144
+ declare function corrcoef(x: ArrayLike, y?: ArrayLike): Array;
1021
1145
  /** Test element-wise for positive or negative infinity, return bool array. */
1022
1146
  declare function isinf(x: ArrayLike): Array;
1023
1147
  /** Test element-wise for NaN (Not a Number). */
@@ -1031,7 +1155,6 @@ declare function isposinf(x: ArrayLike): Array;
1031
1155
  * Test element-wise for finite values (not infinity or NaN).
1032
1156
  */
1033
1157
  declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
1034
- //# sourceMappingURL=numpy.d.ts.map
1035
1158
  declare namespace tree_d_exports {
1036
1159
  export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
1037
1160
  }
@@ -1084,7 +1207,7 @@ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefin
1084
1207
  interface ConvParams {
1085
1208
  vmapDims: number;
1086
1209
  strides: number[];
1087
- padding: [number, number][];
1210
+ padding: Pair[];
1088
1211
  lhsDilation: number[];
1089
1212
  rhsDilation: number[];
1090
1213
  }
@@ -1165,9 +1288,21 @@ declare class Jaxpr implements FpHashable {
1165
1288
  * - Remove no-op movement operations.
1166
1289
  */
1167
1290
  simplify(): Jaxpr;
1168
- /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1291
+ /** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
1169
1292
  flatten(): Jaxpr;
1170
1293
  }
1294
+ /** Jaxpr with a collection of associated, traced constants. */
1295
+ declare class ClosedJaxpr {
1296
+ readonly jaxpr: Jaxpr;
1297
+ readonly consts: Tracer[];
1298
+ constructor(jaxpr: Jaxpr, consts: Tracer[]);
1299
+ /** String representation of this Jaxpr. */
1300
+ toString(): string;
1301
+ /** Apply a function to the underlying Jaxpr. */
1302
+ mapJaxpr(f: (jaxpr: Jaxpr) => Jaxpr): ClosedJaxpr;
1303
+ /** Dispose of the constants in this Jaxpr. */
1304
+ dispose(): void;
1305
+ }
1171
1306
  /** @inline */
1172
1307
  type JitOpts = {
1173
1308
  staticArgnums?: number[];
@@ -1190,7 +1325,9 @@ declare enum Primitive {
1190
1325
  Mul = "mul",
1191
1326
  Idiv = "idiv",
1192
1327
  Mod = "mod",
1193
- // uses sign of dividend, C-style, matches JS but not Python
1328
+ // uses sign of numerator, C-style, matches JS but not Python
1329
+ Min = "min",
1330
+ Max = "max",
1194
1331
  Neg = "neg",
1195
1332
  Reciprocal = "reciprocal",
1196
1333
  Floor = "floor",
@@ -1198,7 +1335,6 @@ declare enum Primitive {
1198
1335
  StopGradient = "stop_gradient",
1199
1336
  Cast = "cast",
1200
1337
  Bitcast = "bitcast",
1201
- RandomBits = "random_bits",
1202
1338
  Sin = "sin",
1203
1339
  Cos = "cos",
1204
1340
  Asin = "asin",
@@ -1208,8 +1344,6 @@ declare enum Primitive {
1208
1344
  Erf = "erf",
1209
1345
  Erfc = "erfc",
1210
1346
  Sqrt = "sqrt",
1211
- Min = "min",
1212
- Max = "max",
1213
1347
  Reduce = "reduce",
1214
1348
  Dot = "dot",
1215
1349
  // sum(x*y, axis=-1)
@@ -1219,14 +1353,23 @@ declare enum Primitive {
1219
1353
  PoolTranspose = "pool_transpose",
1220
1354
  Compare = "compare",
1221
1355
  Where = "where",
1356
+ RandomBits = "random_bits",
1357
+ Gather = "gather",
1222
1358
  Transpose = "transpose",
1223
1359
  Broadcast = "broadcast",
1224
1360
  Reshape = "reshape",
1225
1361
  Flip = "flip",
1226
1362
  Shrink = "shrink",
1227
1363
  Pad = "pad",
1228
- Gather = "gather",
1229
- JitCall = "jit_call",
1364
+ Sort = "sort",
1365
+ // sort(x, axis=-1)
1366
+ Argsort = "argsort",
1367
+ // argsort(x, axis=-1)
1368
+ TriangularSolve = "triangular_solve",
1369
+ // A is upper triangular, A @ X.T = B.T
1370
+ Cholesky = "cholesky",
1371
+ // A is positive-definite, A = L @ L^T
1372
+ Jit = "jit",
1230
1373
  }
1231
1374
  interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1232
1375
  [Primitive.Cast]: {
@@ -1252,6 +1395,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1252
1395
  [Primitive.Compare]: {
1253
1396
  op: CompareOp;
1254
1397
  };
1398
+ [Primitive.RandomBits]: {
1399
+ shape: number[];
1400
+ mode: "xor" | 0 | 1;
1401
+ };
1402
+ [Primitive.Gather]: {
1403
+ axis: number[];
1404
+ outDim: number;
1405
+ };
1255
1406
  [Primitive.Transpose]: {
1256
1407
  perm: number[];
1257
1408
  };
@@ -1259,10 +1410,6 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1259
1410
  shape: number[];
1260
1411
  axis: number[];
1261
1412
  };
1262
- [Primitive.RandomBits]: {
1263
- shape: number[];
1264
- mode: "xor" | 0 | 1;
1265
- };
1266
1413
  [Primitive.Reshape]: {
1267
1414
  shape: number[];
1268
1415
  };
@@ -1275,15 +1422,14 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
1275
1422
  [Primitive.Pad]: {
1276
1423
  width: Pair[];
1277
1424
  };
1278
- [Primitive.Gather]: {
1279
- axis: number[];
1280
- outDim: number;
1281
- };
1282
- [Primitive.JitCall]: {
1425
+ [Primitive.Jit]: {
1283
1426
  name: string;
1284
1427
  jaxpr: Jaxpr;
1285
1428
  numConsts: number;
1286
1429
  };
1430
+ [Primitive.TriangularSolve]: {
1431
+ unitDiagonal: boolean;
1432
+ };
1287
1433
  }
1288
1434
  /** Type of parameters taken by each primitive. */
1289
1435
  type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
@@ -1451,7 +1597,7 @@ declare abstract class Tracer {
1451
1597
  sub(other: this | TracerValue): this;
1452
1598
  /** Divide an array by this one. */
1453
1599
  div(other: this | TracerValue): this;
1454
- /** Return specified diagonals. See `numpy.diagonal` for full docs. */
1600
+ /** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
1455
1601
  diagonal(offset?: number, axis1?: number, axis2?: number): this;
1456
1602
  /** Flatten the array without changing its data. */
1457
1603
  flatten(): this;
@@ -1470,6 +1616,19 @@ declare abstract class Tracer {
1470
1616
  * ```
1471
1617
  */
1472
1618
  [Symbol.iterator](): IterableIterator<this>;
1619
+ /**
1620
+ * Return a sorted copy of an array in ascending order.
1621
+ *
1622
+ * See `jax.numpy.sort` for full docs.
1623
+ */
1624
+ sort(axis?: number): this;
1625
+ /**
1626
+ * Return the indices that would sort an array. This may not be a stable
1627
+ * sorting algorithm; it need not preserve order of indices in ties.
1628
+ *
1629
+ * See `jax.numpy.argsort` for full docs.
1630
+ */
1631
+ argsort(axis?: number): this;
1473
1632
  /**
1474
1633
  * Slice an array along one or more axes.
1475
1634
  *
@@ -1512,6 +1671,7 @@ declare class ShapedArray implements AbstractValue {
1512
1671
  constructor(shape: number[], dtype: DType, weakType: boolean);
1513
1672
  static fromAval(aval: AbstractValue): ShapedArray;
1514
1673
  get ndim(): number;
1674
+ get size(): number;
1515
1675
  toString(): string;
1516
1676
  equals(other: ShapedArray): boolean;
1517
1677
  }
@@ -1529,12 +1689,12 @@ type ArrayLike = Array | number | boolean;
1529
1689
  declare class PendingExecute {
1530
1690
  #private;
1531
1691
  readonly backend: Backend;
1532
- readonly kernel: Kernel;
1692
+ readonly source: Kernel | Routine;
1533
1693
  readonly inputs: Slot[];
1534
1694
  readonly outputs: Slot[];
1535
1695
  prepared: Executable | null;
1536
1696
  submitted: boolean;
1537
- constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
1697
+ constructor(backend: Backend, source: Kernel | Routine, inputs: Slot[], outputs: Slot[]);
1538
1698
  updateRc(delta: number): void;
1539
1699
  prepare(): Promise<void>;
1540
1700
  prepareSync(): void;
@@ -1566,7 +1726,6 @@ type ArrayConstructorArgs = {
1566
1726
  */
1567
1727
  declare class Array extends Tracer {
1568
1728
  #private;
1569
- id: number;
1570
1729
  /**
1571
1730
  * @ignore
1572
1731
  * Constructs an array from source, shape and backend. Note that if the source
@@ -1700,6 +1859,21 @@ declare function arange(start: number, stop?: number, step?: number, {
1700
1859
  dtype,
1701
1860
  device
1702
1861
  }?: DTypeAndDevice): Array;
1862
+ /**
1863
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
1864
+ *
1865
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
1866
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
1867
+ * `k>0` is above it.
1868
+ */
1869
+ declare function tri(n: number, m?: number, k?: number, {
1870
+ dtype,
1871
+ device
1872
+ }?: DTypeAndDevice): Array;
1873
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
1874
+ declare function tril(a: ArrayLike, k?: number): Array;
1875
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
1876
+ declare function triu(a: ArrayLike, k?: number): Array;
1703
1877
  /**
1704
1878
  * Return evenly spaced numbers over a specified interval.
1705
1879
  *
@@ -1713,8 +1887,71 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
1713
1887
  dtype,
1714
1888
  device
1715
1889
  }?: DTypeAndDevice): Array;
1890
+ declare namespace lax_linalg_d_exports {
1891
+ export { cholesky, triangularSolve };
1892
+ }
1893
+ /**
1894
+ * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
1895
+ *
1896
+ * The Cholesky decomposition of a matrix `A` is:
1897
+ *
1898
+ * - A = L @ L^T (for upper=false, default)
1899
+ * - A = U^T @ U (for upper=true)
1900
+ *
1901
+ * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
1902
+ * The input matrix must be symmetric and positive-definite.
1903
+ *
1904
+ * @example
1905
+ * ```ts
1906
+ * import { lax, numpy as np } from "@jax-js/jax";
1907
+ *
1908
+ * const x = np.array([[2., 1.], [1., 2.]]);
1909
+ *
1910
+ * // Lower Cholesky factorization (default):
1911
+ * const L = lax.linalg.cholesky(x);
1912
+ * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
1913
+ *
1914
+ * // Upper Cholesky factorization:
1915
+ * const U = lax.linalg.cholesky(x, { upper: true });
1916
+ * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
1917
+ * ```
1918
+ */
1919
+ declare function cholesky(a: ArrayLike, {
1920
+ upper
1921
+ }?: {
1922
+ upper?: boolean;
1923
+ }): Array;
1924
+ /**
1925
+ * Solve a triangular linear system.
1926
+ *
1927
+ * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
1928
+ * where `a` is a triangular matrix.
1929
+ *
1930
+ * @example
1931
+ * ```ts
1932
+ * import { lax, numpy as np } from "@jax-js/jax";
1933
+ *
1934
+ * const L = np.array([[2., 0.], [1., 3.]]);
1935
+ * const b = np.array([4., 7.]).reshape([2, 1]);
1936
+ *
1937
+ * // Solve L @ x = b
1938
+ * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
1939
+ * // x = [[2.], [5./3.]]
1940
+ * ```
1941
+ */
1942
+ declare function triangularSolve(a: ArrayLike, b: ArrayLike, {
1943
+ leftSide,
1944
+ lower,
1945
+ transposeA,
1946
+ unitDiagonal
1947
+ }?: {
1948
+ leftSide?: boolean;
1949
+ lower?: boolean;
1950
+ transposeA?: boolean;
1951
+ unitDiagonal?: boolean;
1952
+ }): Array;
1716
1953
  declare namespace lax_d_exports {
1717
- export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, reduceWindow, stopGradient };
1954
+ export { DotDimensionNumbers, PaddingType, conv, convGeneralDilated, convWithGeneralPadding, dot, erf, erfc, lax_linalg_d_exports as linalg, reduceWindow, stopGradient };
1718
1955
  }
1719
1956
  /**
1720
1957
  * Dimension numbers for general `dot()` primitive.
@@ -1745,7 +1982,7 @@ declare function dot(lhs: Array, rhs: Array, {
1745
1982
  lhsBatchDims: lb,
1746
1983
  rhsBatchDims: rb
1747
1984
  }?: DotDimensionNumbers): Array;
1748
- type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1985
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | Pair[];
1749
1986
  /**
1750
1987
  * General n-dimensional convolution operator, with optional dilation.
1751
1988
  *
@@ -1785,9 +2022,8 @@ declare function erfc(x: ArrayLike): Array;
1785
2022
  * forward or reverse-mode automatic differentiation.
1786
2023
  */
1787
2024
  declare function stopGradient(x: ArrayLike): Array;
1788
- //# sourceMappingURL=lax.d.ts.map
1789
2025
  declare namespace nn_d_exports {
1790
- export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
2026
+ export { celu, elu, gelu, glu, hardSigmoid, hardSilu, hardSilu as hardSwish, hardTanh, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, selu, sigmoid, silu, softSign, softmax, softplus, sparsePlus, sparseSigmoid, squareplus, standardize, silu as swish };
1791
2027
  }
1792
2028
  /**
1793
2029
  * Rectified Linear Unit (ReLU) activation function:
@@ -1814,21 +2050,28 @@ declare function sigmoid(x: ArrayLike): Array;
1814
2050
  */
1815
2051
  declare function softplus(x: ArrayLike): Array;
1816
2052
  /**
1817
- * Soft-sign activation function, computed element-wise:
1818
- * `softsign(x) = x / (|x| + 1)`.
2053
+ * @function
2054
+ * Sparse plus function:
2055
+ *
2056
+ * - When `x <= -1`: `0`
2057
+ * - When `-1 < x < 1`: `(x+1)**2 / 4`
2058
+ * - When `x >= 1`: `x`
1819
2059
  */
1820
- declare function softSign(x: ArrayLike): Array;
2060
+ declare const sparsePlus: OwnedFunction<(x: ArrayLike) => Array>;
1821
2061
  /**
1822
2062
  * @function
1823
- * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1824
- * Swish, computed element-wise:
1825
- * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
1826
- *
1827
- * `swish()` and `silu()` are both aliases for the same function.
2063
+ * Sparse sigmoid activation function.
1828
2064
  *
1829
- * Reference: https://en.wikipedia.org/wiki/Swish_function
2065
+ * - When `x <= -1`: `0`
2066
+ * - When `-1 < x < 1`: `(x + 1) / 2`
2067
+ * - When `x >= 1`: `1`
1830
2068
  */
1831
- declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
2069
+ declare const sparseSigmoid: OwnedFunction<(x: ArrayLike) => Array>;
2070
+ /**
2071
+ * Soft-sign activation function, computed element-wise:
2072
+ * `softsign(x) = x / (|x| + 1)`.
2073
+ */
2074
+ declare function softSign(x: ArrayLike): Array;
1832
2075
  /**
1833
2076
  * @function
1834
2077
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
@@ -1839,7 +2082,7 @@ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1839
2082
  *
1840
2083
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1841
2084
  */
1842
- declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
2085
+ declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
1843
2086
  /**
1844
2087
  * Log-sigmoid activation function, computed element-wise:
1845
2088
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
@@ -1852,6 +2095,12 @@ declare function logSigmoid(x: ArrayLike): Array;
1852
2095
  declare const identity: (x: ArrayLike) => Array;
1853
2096
  /** Leaky rectified linear (ReLU) activation function */
1854
2097
  declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
2098
+ /** Hard sigmoid activation function: `relu6(x+3)/6`. */
2099
+ declare function hardSigmoid(x: ArrayLike): Array;
2100
+ /** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
2101
+ declare function hardSilu(x: ArrayLike): Array;
2102
+ /** Hard tanh activation function: `clip(x, -1, 1)`. */
2103
+ declare function hardTanh(x: ArrayLike): Array;
1855
2104
  /**
1856
2105
  * Exponential linear unit activation function.
1857
2106
  *
@@ -1866,6 +2115,16 @@ declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
1866
2115
  * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1867
2116
  */
1868
2117
  declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
2118
+ /**
2119
+ * @function
2120
+ * Scaled exponential linear unit activation.
2121
+ *
2122
+ * Computes the element-wise function:
2123
+ * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
2124
+ *
2125
+ * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
2126
+ */
2127
+ declare const selu: OwnedFunction<(x: ArrayLike) => Array>;
1869
2128
  /**
1870
2129
  * @function
1871
2130
  * Gaussion error linear unit (GELU) activation function.
@@ -1961,9 +2220,8 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
1961
2220
  * ```
1962
2221
  */
1963
2222
  declare function oneHot(x: Array, numClasses: number): Array;
1964
- //# sourceMappingURL=nn.d.ts.map
1965
2223
  declare namespace random_d_exports {
1966
- export { bernoulli, bits, exponential, key, normal, split, uniform };
2224
+ export { bernoulli, bits, cauchy, exponential, gumbel, key, laplace, normal, split, uniform };
1967
2225
  }
1968
2226
  /** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
1969
2227
  declare function key(seed: number): Array;
@@ -1986,11 +2244,33 @@ declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefin
1986
2244
  * and must be broadcastable to `shape`.
1987
2245
  */
1988
2246
  declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
2247
+ /**
2248
+ * @function
2249
+ * Sample from a Cauchy distribution with location 0 and scale 1.
2250
+ *
2251
+ * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
2252
+ */
2253
+ declare const cauchy: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1989
2254
  /**
1990
2255
  * @function
1991
2256
  * Sample exponential random values according to `p(x) = exp(-x)`.
1992
2257
  */
1993
2258
  declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2259
+ /**
2260
+ * @function
2261
+ * Sample from a Gumbel distribution with location 0 and scale 1.
2262
+ *
2263
+ * Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
2264
+ */
2265
+ declare const gumbel: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2266
+ /**
2267
+ * @function
2268
+ * Sample from a Laplace distribution with location 0 and scale 1.
2269
+ *
2270
+ * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
2271
+ * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
2272
+ */
2273
+ declare const laplace: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
1994
2274
  /**
1995
2275
  * @function
1996
2276
  * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
@@ -2000,7 +2280,6 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
2000
2280
  * bitwise identical to JAX.
2001
2281
  */
2002
2282
  declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
2003
- //# sourceMappingURL=random.d.ts.map
2004
2283
  declare namespace scipy_special_d_exports {
2005
2284
  export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
2006
2285
  }
@@ -2031,8 +2310,7 @@ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTr
2031
2310
  * Construct a Jaxpr by dynamically tracing a function with example inputs.
2032
2311
  */
2033
2312
  declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
2034
- jaxpr: Jaxpr;
2035
- consts: Array[];
2313
+ jaxpr: ClosedJaxpr;
2036
2314
  treedef: JsTreeDef;
2037
2315
  };
2038
2316
  /**
@@ -2080,11 +2358,6 @@ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F)
2080
2358
  * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
2081
2359
  */
2082
2360
  declare const jacrev: typeof jacfwd;
2083
- /**
2084
- * @function
2085
- * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
2086
- */
2087
- declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
2088
2361
  /**
2089
2362
  * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
2090
2363
  *
@@ -2106,8 +2379,5 @@ declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
2106
2379
  * default device.
2107
2380
  */
2108
2381
  declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
2109
- //# sourceMappingURL=index.d.ts.map
2110
-
2111
2382
  //#endregion
2112
- export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
2113
- //# sourceMappingURL=index.d.ts.map
2383
+ export { Array, ClosedJaxpr, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };