@jax-js/jax 0.0.5 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-yEU0L_ig.cjs → backend-BbrKEB18.cjs} +378 -183
- package/dist/{backend-CdcTZEOF.js → backend-CoVtc9dx.js} +366 -177
- package/dist/index.cjs +385 -74
- package/dist/index.d.cts +115 -23
- package/dist/index.d.ts +115 -23
- package/dist/index.js +378 -74
- package/dist/{webgpu-CM-xNYzW.js → webgpu-B3UVme6n.js} +188 -153
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-DGYNVHma.cjs} +188 -153
- package/package.json +25 -15
package/dist/index.d.cts
CHANGED
|
@@ -124,6 +124,7 @@ declare class ShapeTracker {
|
|
|
124
124
|
/** Like pad(), but allows for negative values. */
|
|
125
125
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
126
126
|
}
|
|
127
|
+
//# sourceMappingURL=shape.d.ts.map
|
|
127
128
|
//#endregion
|
|
128
129
|
//#region src/utils.d.ts
|
|
129
130
|
/**
|
|
@@ -167,9 +168,10 @@ declare enum DType {
|
|
|
167
168
|
Uint32 = "uint32",
|
|
168
169
|
Bool = "bool",
|
|
169
170
|
Float16 = "float16",
|
|
171
|
+
Float64 = "float64",
|
|
170
172
|
}
|
|
171
173
|
/** @inline */
|
|
172
|
-
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
174
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
|
|
173
175
|
/**
|
|
174
176
|
* Promote two dtypes to their join according to the type lattice.
|
|
175
177
|
*
|
|
@@ -179,7 +181,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
179
181
|
*
|
|
180
182
|
* **Type lattice:**
|
|
181
183
|
* ```text
|
|
182
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
184
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
183
185
|
* weakType --^
|
|
184
186
|
* ```
|
|
185
187
|
*
|
|
@@ -222,6 +224,8 @@ declare class AluExp implements FpHashable {
|
|
|
222
224
|
static atan(a: AluExp): AluExp;
|
|
223
225
|
static exp(a: AluExp): AluExp;
|
|
224
226
|
static log(a: AluExp): AluExp;
|
|
227
|
+
static erf(a: AluExp): AluExp;
|
|
228
|
+
static erfc(a: AluExp): AluExp;
|
|
225
229
|
static sqrt(a: AluExp): AluExp;
|
|
226
230
|
static reciprocal(a: AluExp): AluExp;
|
|
227
231
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -240,6 +244,7 @@ declare class AluExp implements FpHashable {
|
|
|
240
244
|
static u32(value: number): AluExp;
|
|
241
245
|
static bool(value: boolean): AluExp;
|
|
242
246
|
static f16(value: number): AluExp;
|
|
247
|
+
static f64(value: number): AluExp;
|
|
243
248
|
not(): AluExp;
|
|
244
249
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
245
250
|
getHash(): bigint;
|
|
@@ -289,8 +294,8 @@ declare class AluExp implements FpHashable {
|
|
|
289
294
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
290
295
|
/** Collect all nodes that satisfy a predicate. */
|
|
291
296
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
292
|
-
/** Produce
|
|
293
|
-
distinctOps(): Set<
|
|
297
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
298
|
+
distinctOps(): Map<AluOp, Set<DType>>;
|
|
294
299
|
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
295
300
|
rewriteGlobalViews(): AluExp;
|
|
296
301
|
}
|
|
@@ -309,6 +314,8 @@ declare enum AluOp {
|
|
|
309
314
|
Atan = "Atan",
|
|
310
315
|
Exp = "Exp",
|
|
311
316
|
Log = "Log",
|
|
317
|
+
Erf = "Erf",
|
|
318
|
+
Erfc = "Erfc",
|
|
312
319
|
Sqrt = "Sqrt",
|
|
313
320
|
Reciprocal = "Reciprocal",
|
|
314
321
|
Cast = "Cast",
|
|
@@ -465,7 +472,7 @@ type JsTree<T> = T | JsTree<T>[] | {
|
|
|
465
472
|
[key: string]: JsTree<T>;
|
|
466
473
|
};
|
|
467
474
|
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
468
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
475
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
469
476
|
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
470
477
|
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
471
478
|
/** Represents the structure of a JsTree. */
|
|
@@ -477,6 +484,8 @@ declare class JsTreeDef {
|
|
|
477
484
|
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
478
485
|
// Must be comparable with deepEqual.
|
|
479
486
|
childTreedefs: JsTreeDef[]);
|
|
487
|
+
/** Get the total number of leaves in the tree. */
|
|
488
|
+
get size(): number;
|
|
480
489
|
/** Returns a string representation of this tree definition. */
|
|
481
490
|
toString(root?: boolean): string;
|
|
482
491
|
/** Compare this tree definition with another. */
|
|
@@ -540,6 +549,8 @@ declare enum Primitive {
|
|
|
540
549
|
Atan = "atan",
|
|
541
550
|
Exp = "exp",
|
|
542
551
|
Log = "log",
|
|
552
|
+
Erf = "erf",
|
|
553
|
+
Erfc = "erfc",
|
|
543
554
|
Sqrt = "sqrt",
|
|
544
555
|
Min = "min",
|
|
545
556
|
Max = "max",
|
|
@@ -621,11 +632,9 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
621
632
|
/** Type of parameters taken by each primitive. */
|
|
622
633
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
623
634
|
declare enum CompareOp {
|
|
624
|
-
Greater = "greater",
|
|
625
635
|
Less = "less",
|
|
626
636
|
Equal = "equal",
|
|
627
637
|
NotEqual = "not_equal",
|
|
628
|
-
GreaterEqual = "greater_equal",
|
|
629
638
|
LessEqual = "less_equal",
|
|
630
639
|
}
|
|
631
640
|
/** @inline */
|
|
@@ -886,12 +895,13 @@ type ArrayConstructorArgs = {
|
|
|
886
895
|
dtype: DType;
|
|
887
896
|
weakType: boolean;
|
|
888
897
|
backend: Backend;
|
|
898
|
+
committed: boolean;
|
|
889
899
|
pending?: Iterable<PendingExecute>;
|
|
890
900
|
};
|
|
891
901
|
/**
|
|
892
902
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
893
903
|
*
|
|
894
|
-
* This is the library's core data type. Equivalent to `
|
|
904
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
895
905
|
* `torch.Tensor`.
|
|
896
906
|
*
|
|
897
907
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -967,10 +977,15 @@ declare class Array extends Tracer {
|
|
|
967
977
|
item(): number;
|
|
968
978
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
969
979
|
static _implRules(): typeof implRules;
|
|
980
|
+
/** @private */
|
|
970
981
|
_realizeSource(): number;
|
|
982
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
983
|
+
_put(backend: Backend): Promise<Array>;
|
|
984
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
985
|
+
_putSync(backend: Backend): Array;
|
|
971
986
|
}
|
|
972
987
|
/** Constructor for creating a new array from data. */
|
|
973
|
-
declare function array(values: Array |
|
|
988
|
+
declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
974
989
|
shape,
|
|
975
990
|
dtype,
|
|
976
991
|
device
|
|
@@ -1043,13 +1058,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1043
1058
|
device
|
|
1044
1059
|
}?: DTypeAndDevice): Array;
|
|
1045
1060
|
declare namespace numpy_d_exports {
|
|
1046
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1061
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1047
1062
|
}
|
|
1048
1063
|
declare const float32 = DType.Float32;
|
|
1049
1064
|
declare const int32 = DType.Int32;
|
|
1050
1065
|
declare const uint32 = DType.Uint32;
|
|
1051
1066
|
declare const bool = DType.Bool;
|
|
1052
1067
|
declare const float16 = DType.Float16;
|
|
1068
|
+
declare const float64 = DType.Float64;
|
|
1053
1069
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
1054
1070
|
declare const e: number;
|
|
1055
1071
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -1126,7 +1142,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1126
1142
|
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1127
1143
|
* pair specifies the padding for its corresponding axis.
|
|
1128
1144
|
*/
|
|
1129
|
-
declare const pad: (x: ArrayLike, width: number |
|
|
1145
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
1130
1146
|
/**
|
|
1131
1147
|
* @function
|
|
1132
1148
|
* Return the number of dimensions of an array. Does not consume array reference.
|
|
@@ -1356,6 +1372,26 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1356
1372
|
declare const abs: typeof absolute;
|
|
1357
1373
|
/** Return an element-wise indication of sign of the input. */
|
|
1358
1374
|
declare function sign(x: ArrayLike): Array;
|
|
1375
|
+
/**
|
|
1376
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1377
|
+
*
|
|
1378
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1379
|
+
*/
|
|
1380
|
+
declare function hamming(M: number): Array;
|
|
1381
|
+
/**
|
|
1382
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
1383
|
+
*
|
|
1384
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1385
|
+
*/
|
|
1386
|
+
declare function hann(M: number): Array;
|
|
1387
|
+
/**
|
|
1388
|
+
* @function
|
|
1389
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
1390
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
1391
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
1392
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1393
|
+
*/
|
|
1394
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1359
1395
|
/** Calculate element-wise square of the input array. */
|
|
1360
1396
|
declare function square(x: ArrayLike): Array;
|
|
1361
1397
|
/** Element-wise tangent function (takes radians). */
|
|
@@ -1367,8 +1403,8 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1367
1403
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1368
1404
|
*
|
|
1369
1405
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1370
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
1371
|
-
* improvements.
|
|
1406
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
1407
|
+
* stability improvements.
|
|
1372
1408
|
*/
|
|
1373
1409
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1374
1410
|
/**
|
|
@@ -1500,6 +1536,20 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1500
1536
|
mean?: ArrayLike;
|
|
1501
1537
|
correction?: number;
|
|
1502
1538
|
} & ReduceOpts): Array;
|
|
1539
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1540
|
+
declare function isinf(x: ArrayLike): Array;
|
|
1541
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
1542
|
+
declare function isnan(x: ArrayLike): Array;
|
|
1543
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
1544
|
+
declare function isneginf(x: ArrayLike): Array;
|
|
1545
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
1546
|
+
declare function isposinf(x: ArrayLike): Array;
|
|
1547
|
+
/**
|
|
1548
|
+
* @function
|
|
1549
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
1550
|
+
*/
|
|
1551
|
+
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1552
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1503
1553
|
//#endregion
|
|
1504
1554
|
//#region src/frontend/jaxpr.d.ts
|
|
1505
1555
|
/**
|
|
@@ -1574,10 +1624,9 @@ declare class Jaxpr implements FpHashable {
|
|
|
1574
1624
|
/** @inline */
|
|
1575
1625
|
type JitOpts = {
|
|
1576
1626
|
staticArgnums?: number[];
|
|
1577
|
-
device?: Device;
|
|
1578
1627
|
};
|
|
1579
1628
|
declare namespace lax_d_exports {
|
|
1580
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1629
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1581
1630
|
}
|
|
1582
1631
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1583
1632
|
/**
|
|
@@ -1601,6 +1650,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
|
|
|
1601
1650
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1602
1651
|
/** Reduce a computation over padded windows. */
|
|
1603
1652
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1653
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1654
|
+
declare function erf(x: ArrayLike): Array;
|
|
1655
|
+
/**
|
|
1656
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
1657
|
+
*
|
|
1658
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1659
|
+
* where `erf(x)` is very close to 1.
|
|
1660
|
+
*/
|
|
1661
|
+
declare function erfc(x: ArrayLike): Array;
|
|
1662
|
+
/**
|
|
1663
|
+
* Stops gradient computation.
|
|
1664
|
+
*
|
|
1665
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1666
|
+
* forward or reverse-mode automatic differentiation.
|
|
1667
|
+
*/
|
|
1668
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1669
|
+
//# sourceMappingURL=lax.d.ts.map
|
|
1604
1670
|
declare namespace nn_d_exports {
|
|
1605
1671
|
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1606
1672
|
}
|
|
@@ -1685,15 +1751,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1685
1751
|
* @function
|
|
1686
1752
|
* Gaussion error linear unit (GELU) activation function.
|
|
1687
1753
|
*
|
|
1688
|
-
* This is computed element-wise.
|
|
1689
|
-
*
|
|
1690
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1754
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
1755
|
+
* `approximate` is set (default true):
|
|
1691
1756
|
*
|
|
1692
|
-
*
|
|
1757
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
1758
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
1693
1759
|
*
|
|
1694
|
-
*
|
|
1760
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1695
1761
|
*/
|
|
1696
|
-
declare const gelu: OwnedFunction<(x: ArrayLike
|
|
1762
|
+
declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
|
|
1763
|
+
approximate?: boolean | undefined;
|
|
1764
|
+
} | undefined) => Array>;
|
|
1697
1765
|
/**
|
|
1698
1766
|
* Gated linear unit (GLU) activation function.
|
|
1699
1767
|
*
|
|
@@ -1774,6 +1842,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1774
1842
|
* ```
|
|
1775
1843
|
*/
|
|
1776
1844
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1845
|
+
//# sourceMappingURL=nn.d.ts.map
|
|
1777
1846
|
declare namespace random_d_exports {
|
|
1778
1847
|
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1779
1848
|
}
|
|
@@ -1812,6 +1881,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
1812
1881
|
* bitwise identical to JAX.
|
|
1813
1882
|
*/
|
|
1814
1883
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1884
|
+
//# sourceMappingURL=random.d.ts.map
|
|
1885
|
+
declare namespace scipy_special_d_exports {
|
|
1886
|
+
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1887
|
+
}
|
|
1888
|
+
/**
|
|
1889
|
+
* @function
|
|
1890
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
1891
|
+
*/
|
|
1892
|
+
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1815
1893
|
//#endregion
|
|
1816
1894
|
//#region src/index.d.ts
|
|
1817
1895
|
/**
|
|
@@ -1823,7 +1901,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
|
|
|
1823
1901
|
* @function
|
|
1824
1902
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1825
1903
|
*/
|
|
1826
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number |
|
|
1904
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1827
1905
|
/**
|
|
1828
1906
|
* @function
|
|
1829
1907
|
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
@@ -1898,5 +1976,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
|
|
|
1898
1976
|
* Does not consume reference to the arrays.
|
|
1899
1977
|
*/
|
|
1900
1978
|
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1979
|
+
/**
|
|
1980
|
+
* Transfer `x` to `device`.
|
|
1981
|
+
*
|
|
1982
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
1983
|
+
* is committed to the device.
|
|
1984
|
+
*
|
|
1985
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
1986
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
1987
|
+
* default device.
|
|
1988
|
+
*/
|
|
1989
|
+
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
1990
|
+
//# sourceMappingURL=index.d.ts.map
|
|
1991
|
+
|
|
1901
1992
|
//#endregion
|
|
1902
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1993
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1994
|
+
//# sourceMappingURL=index.d.cts.map
|
package/dist/index.d.ts
CHANGED
|
@@ -121,6 +121,7 @@ declare class ShapeTracker {
|
|
|
121
121
|
/** Like pad(), but allows for negative values. */
|
|
122
122
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
123
123
|
}
|
|
124
|
+
//# sourceMappingURL=shape.d.ts.map
|
|
124
125
|
//#endregion
|
|
125
126
|
//#region src/utils.d.ts
|
|
126
127
|
/**
|
|
@@ -164,9 +165,10 @@ declare enum DType {
|
|
|
164
165
|
Uint32 = "uint32",
|
|
165
166
|
Bool = "bool",
|
|
166
167
|
Float16 = "float16",
|
|
168
|
+
Float64 = "float64",
|
|
167
169
|
}
|
|
168
170
|
/** @inline */
|
|
169
|
-
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
171
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer> | Float64Array<ArrayBuffer>;
|
|
170
172
|
/**
|
|
171
173
|
* Promote two dtypes to their join according to the type lattice.
|
|
172
174
|
*
|
|
@@ -176,7 +178,7 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
176
178
|
*
|
|
177
179
|
* **Type lattice:**
|
|
178
180
|
* ```text
|
|
179
|
-
* bool -> uint32 -> int32 -> float16 -> float32
|
|
181
|
+
* bool -> uint32 -> int32 -> float16 -> float32 -> float64
|
|
180
182
|
* weakType --^
|
|
181
183
|
* ```
|
|
182
184
|
*
|
|
@@ -219,6 +221,8 @@ declare class AluExp implements FpHashable {
|
|
|
219
221
|
static atan(a: AluExp): AluExp;
|
|
220
222
|
static exp(a: AluExp): AluExp;
|
|
221
223
|
static log(a: AluExp): AluExp;
|
|
224
|
+
static erf(a: AluExp): AluExp;
|
|
225
|
+
static erfc(a: AluExp): AluExp;
|
|
222
226
|
static sqrt(a: AluExp): AluExp;
|
|
223
227
|
static reciprocal(a: AluExp): AluExp;
|
|
224
228
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -237,6 +241,7 @@ declare class AluExp implements FpHashable {
|
|
|
237
241
|
static u32(value: number): AluExp;
|
|
238
242
|
static bool(value: boolean): AluExp;
|
|
239
243
|
static f16(value: number): AluExp;
|
|
244
|
+
static f64(value: number): AluExp;
|
|
240
245
|
not(): AluExp;
|
|
241
246
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
242
247
|
getHash(): bigint;
|
|
@@ -286,8 +291,8 @@ declare class AluExp implements FpHashable {
|
|
|
286
291
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
287
292
|
/** Collect all nodes that satisfy a predicate. */
|
|
288
293
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
289
|
-
/** Produce
|
|
290
|
-
distinctOps(): Set<
|
|
294
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
295
|
+
distinctOps(): Map<AluOp, Set<DType>>;
|
|
291
296
|
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
292
297
|
rewriteGlobalViews(): AluExp;
|
|
293
298
|
}
|
|
@@ -306,6 +311,8 @@ declare enum AluOp {
|
|
|
306
311
|
Atan = "Atan",
|
|
307
312
|
Exp = "Exp",
|
|
308
313
|
Log = "Log",
|
|
314
|
+
Erf = "Erf",
|
|
315
|
+
Erfc = "Erfc",
|
|
309
316
|
Sqrt = "Sqrt",
|
|
310
317
|
Reciprocal = "Reciprocal",
|
|
311
318
|
Cast = "Cast",
|
|
@@ -462,7 +469,7 @@ type JsTree<T> = T | JsTree<T>[] | {
|
|
|
462
469
|
[key: string]: JsTree<T>;
|
|
463
470
|
};
|
|
464
471
|
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
465
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
472
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
466
473
|
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
467
474
|
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
468
475
|
/** Represents the structure of a JsTree. */
|
|
@@ -474,6 +481,8 @@ declare class JsTreeDef {
|
|
|
474
481
|
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
475
482
|
// Must be comparable with deepEqual.
|
|
476
483
|
childTreedefs: JsTreeDef[]);
|
|
484
|
+
/** Get the total number of leaves in the tree. */
|
|
485
|
+
get size(): number;
|
|
477
486
|
/** Returns a string representation of this tree definition. */
|
|
478
487
|
toString(root?: boolean): string;
|
|
479
488
|
/** Compare this tree definition with another. */
|
|
@@ -537,6 +546,8 @@ declare enum Primitive {
|
|
|
537
546
|
Atan = "atan",
|
|
538
547
|
Exp = "exp",
|
|
539
548
|
Log = "log",
|
|
549
|
+
Erf = "erf",
|
|
550
|
+
Erfc = "erfc",
|
|
540
551
|
Sqrt = "sqrt",
|
|
541
552
|
Min = "min",
|
|
542
553
|
Max = "max",
|
|
@@ -618,11 +629,9 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
618
629
|
/** Type of parameters taken by each primitive. */
|
|
619
630
|
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
620
631
|
declare enum CompareOp {
|
|
621
|
-
Greater = "greater",
|
|
622
632
|
Less = "less",
|
|
623
633
|
Equal = "equal",
|
|
624
634
|
NotEqual = "not_equal",
|
|
625
|
-
GreaterEqual = "greater_equal",
|
|
626
635
|
LessEqual = "less_equal",
|
|
627
636
|
}
|
|
628
637
|
/** @inline */
|
|
@@ -883,12 +892,13 @@ type ArrayConstructorArgs = {
|
|
|
883
892
|
dtype: DType;
|
|
884
893
|
weakType: boolean;
|
|
885
894
|
backend: Backend;
|
|
895
|
+
committed: boolean;
|
|
886
896
|
pending?: Iterable<PendingExecute>;
|
|
887
897
|
};
|
|
888
898
|
/**
|
|
889
899
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
890
900
|
*
|
|
891
|
-
* This is the library's core data type. Equivalent to `
|
|
901
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
892
902
|
* `torch.Tensor`.
|
|
893
903
|
*
|
|
894
904
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -964,10 +974,15 @@ declare class Array extends Tracer {
|
|
|
964
974
|
item(): number;
|
|
965
975
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
966
976
|
static _implRules(): typeof implRules;
|
|
977
|
+
/** @private */
|
|
967
978
|
_realizeSource(): number;
|
|
979
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
980
|
+
_put(backend: Backend): Promise<Array>;
|
|
981
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
982
|
+
_putSync(backend: Backend): Array;
|
|
968
983
|
}
|
|
969
984
|
/** Constructor for creating a new array from data. */
|
|
970
|
-
declare function array(values: Array |
|
|
985
|
+
declare function array(values: Array | DataArray | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
971
986
|
shape,
|
|
972
987
|
dtype,
|
|
973
988
|
device
|
|
@@ -1040,13 +1055,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1040
1055
|
device
|
|
1041
1056
|
}?: DTypeAndDevice): Array;
|
|
1042
1057
|
declare namespace numpy_d_exports {
|
|
1043
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1058
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, float64, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, isfinite, isinf, isnan, isneginf, isposinf, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1044
1059
|
}
|
|
1045
1060
|
declare const float32 = DType.Float32;
|
|
1046
1061
|
declare const int32 = DType.Int32;
|
|
1047
1062
|
declare const uint32 = DType.Uint32;
|
|
1048
1063
|
declare const bool = DType.Bool;
|
|
1049
1064
|
declare const float16 = DType.Float16;
|
|
1065
|
+
declare const float64 = DType.Float64;
|
|
1050
1066
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
1051
1067
|
declare const e: number;
|
|
1052
1068
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -1123,7 +1139,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1123
1139
|
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1124
1140
|
* pair specifies the padding for its corresponding axis.
|
|
1125
1141
|
*/
|
|
1126
|
-
declare const pad: (x: ArrayLike, width: number |
|
|
1142
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
1127
1143
|
/**
|
|
1128
1144
|
* @function
|
|
1129
1145
|
* Return the number of dimensions of an array. Does not consume array reference.
|
|
@@ -1353,6 +1369,26 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1353
1369
|
declare const abs: typeof absolute;
|
|
1354
1370
|
/** Return an element-wise indication of sign of the input. */
|
|
1355
1371
|
declare function sign(x: ArrayLike): Array;
|
|
1372
|
+
/**
|
|
1373
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1374
|
+
*
|
|
1375
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1376
|
+
*/
|
|
1377
|
+
declare function hamming(M: number): Array;
|
|
1378
|
+
/**
|
|
1379
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
1380
|
+
*
|
|
1381
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1382
|
+
*/
|
|
1383
|
+
declare function hann(M: number): Array;
|
|
1384
|
+
/**
|
|
1385
|
+
* @function
|
|
1386
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
1387
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
1388
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
1389
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1390
|
+
*/
|
|
1391
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1356
1392
|
/** Calculate element-wise square of the input array. */
|
|
1357
1393
|
declare function square(x: ArrayLike): Array;
|
|
1358
1394
|
/** Element-wise tangent function (takes radians). */
|
|
@@ -1364,8 +1400,8 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1364
1400
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1365
1401
|
*
|
|
1366
1402
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1367
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
1368
|
-
* improvements.
|
|
1403
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
1404
|
+
* stability improvements.
|
|
1369
1405
|
*/
|
|
1370
1406
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1371
1407
|
/**
|
|
@@ -1497,6 +1533,20 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1497
1533
|
mean?: ArrayLike;
|
|
1498
1534
|
correction?: number;
|
|
1499
1535
|
} & ReduceOpts): Array;
|
|
1536
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
1537
|
+
declare function isinf(x: ArrayLike): Array;
|
|
1538
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
1539
|
+
declare function isnan(x: ArrayLike): Array;
|
|
1540
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
1541
|
+
declare function isneginf(x: ArrayLike): Array;
|
|
1542
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
1543
|
+
declare function isposinf(x: ArrayLike): Array;
|
|
1544
|
+
/**
|
|
1545
|
+
* @function
|
|
1546
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
1547
|
+
*/
|
|
1548
|
+
declare const isfinite: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1549
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1500
1550
|
//#endregion
|
|
1501
1551
|
//#region src/frontend/jaxpr.d.ts
|
|
1502
1552
|
/**
|
|
@@ -1571,10 +1621,9 @@ declare class Jaxpr implements FpHashable {
|
|
|
1571
1621
|
/** @inline */
|
|
1572
1622
|
type JitOpts = {
|
|
1573
1623
|
staticArgnums?: number[];
|
|
1574
|
-
device?: Device;
|
|
1575
1624
|
};
|
|
1576
1625
|
declare namespace lax_d_exports {
|
|
1577
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1626
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1578
1627
|
}
|
|
1579
1628
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1580
1629
|
/**
|
|
@@ -1598,6 +1647,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
|
|
|
1598
1647
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1599
1648
|
/** Reduce a computation over padded windows. */
|
|
1600
1649
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1650
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1651
|
+
declare function erf(x: ArrayLike): Array;
|
|
1652
|
+
/**
|
|
1653
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
1654
|
+
*
|
|
1655
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1656
|
+
* where `erf(x)` is very close to 1.
|
|
1657
|
+
*/
|
|
1658
|
+
declare function erfc(x: ArrayLike): Array;
|
|
1659
|
+
/**
|
|
1660
|
+
* Stops gradient computation.
|
|
1661
|
+
*
|
|
1662
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1663
|
+
* forward or reverse-mode automatic differentiation.
|
|
1664
|
+
*/
|
|
1665
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1666
|
+
//# sourceMappingURL=lax.d.ts.map
|
|
1601
1667
|
declare namespace nn_d_exports {
|
|
1602
1668
|
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1603
1669
|
}
|
|
@@ -1682,15 +1748,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1682
1748
|
* @function
|
|
1683
1749
|
* Gaussion error linear unit (GELU) activation function.
|
|
1684
1750
|
*
|
|
1685
|
-
* This is computed element-wise.
|
|
1686
|
-
*
|
|
1687
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1751
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
1752
|
+
* `approximate` is set (default true):
|
|
1688
1753
|
*
|
|
1689
|
-
*
|
|
1754
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
1755
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
1690
1756
|
*
|
|
1691
|
-
*
|
|
1757
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1692
1758
|
*/
|
|
1693
|
-
declare const gelu: OwnedFunction<(x: ArrayLike
|
|
1759
|
+
declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
|
|
1760
|
+
approximate?: boolean | undefined;
|
|
1761
|
+
} | undefined) => Array>;
|
|
1694
1762
|
/**
|
|
1695
1763
|
* Gated linear unit (GLU) activation function.
|
|
1696
1764
|
*
|
|
@@ -1771,6 +1839,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1771
1839
|
* ```
|
|
1772
1840
|
*/
|
|
1773
1841
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1842
|
+
//# sourceMappingURL=nn.d.ts.map
|
|
1774
1843
|
declare namespace random_d_exports {
|
|
1775
1844
|
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1776
1845
|
}
|
|
@@ -1809,6 +1878,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
1809
1878
|
* bitwise identical to JAX.
|
|
1810
1879
|
*/
|
|
1811
1880
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1881
|
+
//# sourceMappingURL=random.d.ts.map
|
|
1882
|
+
declare namespace scipy_special_d_exports {
|
|
1883
|
+
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1884
|
+
}
|
|
1885
|
+
/**
|
|
1886
|
+
* @function
|
|
1887
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
1888
|
+
*/
|
|
1889
|
+
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1812
1890
|
//#endregion
|
|
1813
1891
|
//#region src/index.d.ts
|
|
1814
1892
|
/**
|
|
@@ -1820,7 +1898,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
|
|
|
1820
1898
|
* @function
|
|
1821
1899
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1822
1900
|
*/
|
|
1823
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number |
|
|
1901
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1824
1902
|
/**
|
|
1825
1903
|
* @function
|
|
1826
1904
|
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
@@ -1895,5 +1973,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
|
|
|
1895
1973
|
* Does not consume reference to the arrays.
|
|
1896
1974
|
*/
|
|
1897
1975
|
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1976
|
+
/**
|
|
1977
|
+
* Transfer `x` to `device`.
|
|
1978
|
+
*
|
|
1979
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
1980
|
+
* is committed to the device.
|
|
1981
|
+
*
|
|
1982
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
1983
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
1984
|
+
* default device.
|
|
1985
|
+
*/
|
|
1986
|
+
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
1987
|
+
//# sourceMappingURL=index.d.ts.map
|
|
1988
|
+
|
|
1898
1989
|
//#endregion
|
|
1899
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1990
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1991
|
+
//# sourceMappingURL=index.d.ts.map
|