@jax-js/jax 0.0.5 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-CdcTZEOF.js → backend-DwIAd0AG.js} +205 -112
- package/dist/{backend-yEU0L_ig.cjs → backend-FtkbO6pI.cjs} +217 -118
- package/dist/index.cjs +344 -67
- package/dist/index.d.cts +96 -18
- package/dist/index.d.ts +96 -18
- package/dist/index.js +337 -67
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-CM-xNYzW.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.d.cts
CHANGED
|
@@ -124,6 +124,7 @@ declare class ShapeTracker {
|
|
|
124
124
|
/** Like pad(), but allows for negative values. */
|
|
125
125
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
126
126
|
}
|
|
127
|
+
//# sourceMappingURL=shape.d.ts.map
|
|
127
128
|
//#endregion
|
|
128
129
|
//#region src/utils.d.ts
|
|
129
130
|
/**
|
|
@@ -222,6 +223,8 @@ declare class AluExp implements FpHashable {
|
|
|
222
223
|
static atan(a: AluExp): AluExp;
|
|
223
224
|
static exp(a: AluExp): AluExp;
|
|
224
225
|
static log(a: AluExp): AluExp;
|
|
226
|
+
static erf(a: AluExp): AluExp;
|
|
227
|
+
static erfc(a: AluExp): AluExp;
|
|
225
228
|
static sqrt(a: AluExp): AluExp;
|
|
226
229
|
static reciprocal(a: AluExp): AluExp;
|
|
227
230
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -289,8 +292,8 @@ declare class AluExp implements FpHashable {
|
|
|
289
292
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
290
293
|
/** Collect all nodes that satisfy a predicate. */
|
|
291
294
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
292
|
-
/** Produce
|
|
293
|
-
distinctOps(): Set<
|
|
295
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
296
|
+
distinctOps(): Map<AluOp, Set<DType>>;
|
|
294
297
|
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
295
298
|
rewriteGlobalViews(): AluExp;
|
|
296
299
|
}
|
|
@@ -309,6 +312,8 @@ declare enum AluOp {
|
|
|
309
312
|
Atan = "Atan",
|
|
310
313
|
Exp = "Exp",
|
|
311
314
|
Log = "Log",
|
|
315
|
+
Erf = "Erf",
|
|
316
|
+
Erfc = "Erfc",
|
|
312
317
|
Sqrt = "Sqrt",
|
|
313
318
|
Reciprocal = "Reciprocal",
|
|
314
319
|
Cast = "Cast",
|
|
@@ -465,7 +470,7 @@ type JsTree<T> = T | JsTree<T>[] | {
|
|
|
465
470
|
[key: string]: JsTree<T>;
|
|
466
471
|
};
|
|
467
472
|
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
468
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
473
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
469
474
|
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
470
475
|
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
471
476
|
/** Represents the structure of a JsTree. */
|
|
@@ -477,6 +482,8 @@ declare class JsTreeDef {
|
|
|
477
482
|
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
478
483
|
// Must be comparable with deepEqual.
|
|
479
484
|
childTreedefs: JsTreeDef[]);
|
|
485
|
+
/** Get the total number of leaves in the tree. */
|
|
486
|
+
get size(): number;
|
|
480
487
|
/** Returns a string representation of this tree definition. */
|
|
481
488
|
toString(root?: boolean): string;
|
|
482
489
|
/** Compare this tree definition with another. */
|
|
@@ -540,6 +547,8 @@ declare enum Primitive {
|
|
|
540
547
|
Atan = "atan",
|
|
541
548
|
Exp = "exp",
|
|
542
549
|
Log = "log",
|
|
550
|
+
Erf = "erf",
|
|
551
|
+
Erfc = "erfc",
|
|
543
552
|
Sqrt = "sqrt",
|
|
544
553
|
Min = "min",
|
|
545
554
|
Max = "max",
|
|
@@ -886,12 +895,13 @@ type ArrayConstructorArgs = {
|
|
|
886
895
|
dtype: DType;
|
|
887
896
|
weakType: boolean;
|
|
888
897
|
backend: Backend;
|
|
898
|
+
committed: boolean;
|
|
889
899
|
pending?: Iterable<PendingExecute>;
|
|
890
900
|
};
|
|
891
901
|
/**
|
|
892
902
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
893
903
|
*
|
|
894
|
-
* This is the library's core data type. Equivalent to `
|
|
904
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
895
905
|
* `torch.Tensor`.
|
|
896
906
|
*
|
|
897
907
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -967,7 +977,12 @@ declare class Array extends Tracer {
|
|
|
967
977
|
item(): number;
|
|
968
978
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
969
979
|
static _implRules(): typeof implRules;
|
|
980
|
+
/** @private */
|
|
970
981
|
_realizeSource(): number;
|
|
982
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
983
|
+
_put(backend: Backend): Promise<Array>;
|
|
984
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
985
|
+
_putSync(backend: Backend): Array;
|
|
971
986
|
}
|
|
972
987
|
/** Constructor for creating a new array from data. */
|
|
973
988
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
@@ -1043,7 +1058,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1043
1058
|
device
|
|
1044
1059
|
}?: DTypeAndDevice): Array;
|
|
1045
1060
|
declare namespace numpy_d_exports {
|
|
1046
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1061
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1047
1062
|
}
|
|
1048
1063
|
declare const float32 = DType.Float32;
|
|
1049
1064
|
declare const int32 = DType.Int32;
|
|
@@ -1126,7 +1141,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1126
1141
|
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1127
1142
|
* pair specifies the padding for its corresponding axis.
|
|
1128
1143
|
*/
|
|
1129
|
-
declare const pad: (x: ArrayLike, width: number |
|
|
1144
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
1130
1145
|
/**
|
|
1131
1146
|
* @function
|
|
1132
1147
|
* Return the number of dimensions of an array. Does not consume array reference.
|
|
@@ -1356,6 +1371,26 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1356
1371
|
declare const abs: typeof absolute;
|
|
1357
1372
|
/** Return an element-wise indication of sign of the input. */
|
|
1358
1373
|
declare function sign(x: ArrayLike): Array;
|
|
1374
|
+
/**
|
|
1375
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1376
|
+
*
|
|
1377
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1378
|
+
*/
|
|
1379
|
+
declare function hamming(M: number): Array;
|
|
1380
|
+
/**
|
|
1381
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
1382
|
+
*
|
|
1383
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1384
|
+
*/
|
|
1385
|
+
declare function hann(M: number): Array;
|
|
1386
|
+
/**
|
|
1387
|
+
* @function
|
|
1388
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
1389
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
1390
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
1391
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1392
|
+
*/
|
|
1393
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1359
1394
|
/** Calculate element-wise square of the input array. */
|
|
1360
1395
|
declare function square(x: ArrayLike): Array;
|
|
1361
1396
|
/** Element-wise tangent function (takes radians). */
|
|
@@ -1367,8 +1402,8 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1367
1402
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1368
1403
|
*
|
|
1369
1404
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1370
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
1371
|
-
* improvements.
|
|
1405
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
1406
|
+
* stability improvements.
|
|
1372
1407
|
*/
|
|
1373
1408
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1374
1409
|
/**
|
|
@@ -1500,6 +1535,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1500
1535
|
mean?: ArrayLike;
|
|
1501
1536
|
correction?: number;
|
|
1502
1537
|
} & ReduceOpts): Array;
|
|
1538
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1503
1539
|
//#endregion
|
|
1504
1540
|
//#region src/frontend/jaxpr.d.ts
|
|
1505
1541
|
/**
|
|
@@ -1574,10 +1610,9 @@ declare class Jaxpr implements FpHashable {
|
|
|
1574
1610
|
/** @inline */
|
|
1575
1611
|
type JitOpts = {
|
|
1576
1612
|
staticArgnums?: number[];
|
|
1577
|
-
device?: Device;
|
|
1578
1613
|
};
|
|
1579
1614
|
declare namespace lax_d_exports {
|
|
1580
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1615
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1581
1616
|
}
|
|
1582
1617
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1583
1618
|
/**
|
|
@@ -1601,6 +1636,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
|
|
|
1601
1636
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1602
1637
|
/** Reduce a computation over padded windows. */
|
|
1603
1638
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1639
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1640
|
+
declare function erf(x: ArrayLike): Array;
|
|
1641
|
+
/**
|
|
1642
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
1643
|
+
*
|
|
1644
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1645
|
+
* where `erf(x)` is very close to 1.
|
|
1646
|
+
*/
|
|
1647
|
+
declare function erfc(x: ArrayLike): Array;
|
|
1648
|
+
/**
|
|
1649
|
+
* Stops gradient computation.
|
|
1650
|
+
*
|
|
1651
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1652
|
+
* forward or reverse-mode automatic differentiation.
|
|
1653
|
+
*/
|
|
1654
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1655
|
+
//# sourceMappingURL=lax.d.ts.map
|
|
1604
1656
|
declare namespace nn_d_exports {
|
|
1605
1657
|
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1606
1658
|
}
|
|
@@ -1685,15 +1737,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1685
1737
|
* @function
|
|
1686
1738
|
* Gaussian error linear unit (GELU) activation function.
|
|
1687
1739
|
*
|
|
1688
|
-
* This is computed element-wise.
|
|
1689
|
-
*
|
|
1690
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1740
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
1741
|
+
* `approximate` is set (default true):
|
|
1691
1742
|
*
|
|
1692
|
-
*
|
|
1743
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
1744
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
1693
1745
|
*
|
|
1694
|
-
*
|
|
1746
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1695
1747
|
*/
|
|
1696
|
-
declare const gelu: OwnedFunction<(x: ArrayLike
|
|
1748
|
+
declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
|
|
1749
|
+
approximate?: boolean | undefined;
|
|
1750
|
+
} | undefined) => Array>;
|
|
1697
1751
|
/**
|
|
1698
1752
|
* Gated linear unit (GLU) activation function.
|
|
1699
1753
|
*
|
|
@@ -1774,6 +1828,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1774
1828
|
* ```
|
|
1775
1829
|
*/
|
|
1776
1830
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1831
|
+
//# sourceMappingURL=nn.d.ts.map
|
|
1777
1832
|
declare namespace random_d_exports {
|
|
1778
1833
|
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1779
1834
|
}
|
|
@@ -1812,6 +1867,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
1812
1867
|
* bitwise identical to JAX.
|
|
1813
1868
|
*/
|
|
1814
1869
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1870
|
+
//# sourceMappingURL=random.d.ts.map
|
|
1871
|
+
declare namespace scipy_special_d_exports {
|
|
1872
|
+
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1873
|
+
}
|
|
1874
|
+
/**
|
|
1875
|
+
* @function
|
|
1876
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
1877
|
+
*/
|
|
1878
|
+
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1815
1879
|
//#endregion
|
|
1816
1880
|
//#region src/index.d.ts
|
|
1817
1881
|
/**
|
|
@@ -1823,7 +1887,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
|
|
|
1823
1887
|
* @function
|
|
1824
1888
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1825
1889
|
*/
|
|
1826
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number |
|
|
1890
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1827
1891
|
/**
|
|
1828
1892
|
* @function
|
|
1829
1893
|
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
@@ -1898,5 +1962,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
|
|
|
1898
1962
|
* Does not consume reference to the arrays.
|
|
1899
1963
|
*/
|
|
1900
1964
|
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1965
|
+
/**
|
|
1966
|
+
* Transfer `x` to `device`.
|
|
1967
|
+
*
|
|
1968
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
1969
|
+
* is committed to the device.
|
|
1970
|
+
*
|
|
1971
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
1972
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
1973
|
+
* default device.
|
|
1974
|
+
*/
|
|
1975
|
+
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
1976
|
+
//# sourceMappingURL=index.d.ts.map
|
|
1977
|
+
|
|
1901
1978
|
//#endregion
|
|
1902
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1979
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1980
|
+
//# sourceMappingURL=index.d.cts.map
|
package/dist/index.d.ts
CHANGED
|
@@ -121,6 +121,7 @@ declare class ShapeTracker {
|
|
|
121
121
|
/** Like pad(), but allows for negative values. */
|
|
122
122
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
123
123
|
}
|
|
124
|
+
//# sourceMappingURL=shape.d.ts.map
|
|
124
125
|
//#endregion
|
|
125
126
|
//#region src/utils.d.ts
|
|
126
127
|
/**
|
|
@@ -219,6 +220,8 @@ declare class AluExp implements FpHashable {
|
|
|
219
220
|
static atan(a: AluExp): AluExp;
|
|
220
221
|
static exp(a: AluExp): AluExp;
|
|
221
222
|
static log(a: AluExp): AluExp;
|
|
223
|
+
static erf(a: AluExp): AluExp;
|
|
224
|
+
static erfc(a: AluExp): AluExp;
|
|
222
225
|
static sqrt(a: AluExp): AluExp;
|
|
223
226
|
static reciprocal(a: AluExp): AluExp;
|
|
224
227
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -286,8 +289,8 @@ declare class AluExp implements FpHashable {
|
|
|
286
289
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
287
290
|
/** Collect all nodes that satisfy a predicate. */
|
|
288
291
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
289
|
-
/** Produce
|
|
290
|
-
distinctOps(): Set<
|
|
292
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
293
|
+
distinctOps(): Map<AluOp, Set<DType>>;
|
|
291
294
|
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
292
295
|
rewriteGlobalViews(): AluExp;
|
|
293
296
|
}
|
|
@@ -306,6 +309,8 @@ declare enum AluOp {
|
|
|
306
309
|
Atan = "Atan",
|
|
307
310
|
Exp = "Exp",
|
|
308
311
|
Log = "Log",
|
|
312
|
+
Erf = "Erf",
|
|
313
|
+
Erfc = "Erfc",
|
|
309
314
|
Sqrt = "Sqrt",
|
|
310
315
|
Reciprocal = "Reciprocal",
|
|
311
316
|
Cast = "Cast",
|
|
@@ -462,7 +467,7 @@ type JsTree<T> = T | JsTree<T>[] | {
|
|
|
462
467
|
[key: string]: JsTree<T>;
|
|
463
468
|
};
|
|
464
469
|
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
465
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
470
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
466
471
|
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
467
472
|
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
468
473
|
/** Represents the structure of a JsTree. */
|
|
@@ -474,6 +479,8 @@ declare class JsTreeDef {
|
|
|
474
479
|
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
475
480
|
// Must be comparable with deepEqual.
|
|
476
481
|
childTreedefs: JsTreeDef[]);
|
|
482
|
+
/** Get the total number of leaves in the tree. */
|
|
483
|
+
get size(): number;
|
|
477
484
|
/** Returns a string representation of this tree definition. */
|
|
478
485
|
toString(root?: boolean): string;
|
|
479
486
|
/** Compare this tree definition with another. */
|
|
@@ -537,6 +544,8 @@ declare enum Primitive {
|
|
|
537
544
|
Atan = "atan",
|
|
538
545
|
Exp = "exp",
|
|
539
546
|
Log = "log",
|
|
547
|
+
Erf = "erf",
|
|
548
|
+
Erfc = "erfc",
|
|
540
549
|
Sqrt = "sqrt",
|
|
541
550
|
Min = "min",
|
|
542
551
|
Max = "max",
|
|
@@ -883,12 +892,13 @@ type ArrayConstructorArgs = {
|
|
|
883
892
|
dtype: DType;
|
|
884
893
|
weakType: boolean;
|
|
885
894
|
backend: Backend;
|
|
895
|
+
committed: boolean;
|
|
886
896
|
pending?: Iterable<PendingExecute>;
|
|
887
897
|
};
|
|
888
898
|
/**
|
|
889
899
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
890
900
|
*
|
|
891
|
-
* This is the library's core data type. Equivalent to `
|
|
901
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
892
902
|
* `torch.Tensor`.
|
|
893
903
|
*
|
|
894
904
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -964,7 +974,12 @@ declare class Array extends Tracer {
|
|
|
964
974
|
item(): number;
|
|
965
975
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
966
976
|
static _implRules(): typeof implRules;
|
|
977
|
+
/** @private */
|
|
967
978
|
_realizeSource(): number;
|
|
979
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
980
|
+
_put(backend: Backend): Promise<Array>;
|
|
981
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
982
|
+
_putSync(backend: Backend): Array;
|
|
968
983
|
}
|
|
969
984
|
/** Constructor for creating a new array from data. */
|
|
970
985
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
@@ -1040,7 +1055,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1040
1055
|
device
|
|
1041
1056
|
}?: DTypeAndDevice): Array;
|
|
1042
1057
|
declare namespace numpy_d_exports {
|
|
1043
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1058
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1044
1059
|
}
|
|
1045
1060
|
declare const float32 = DType.Float32;
|
|
1046
1061
|
declare const int32 = DType.Int32;
|
|
@@ -1123,7 +1138,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1123
1138
|
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1124
1139
|
* pair specifies the padding for its corresponding axis.
|
|
1125
1140
|
*/
|
|
1126
|
-
declare const pad: (x: ArrayLike, width: number |
|
|
1141
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
1127
1142
|
/**
|
|
1128
1143
|
* @function
|
|
1129
1144
|
* Return the number of dimensions of an array. Does not consume array reference.
|
|
@@ -1353,6 +1368,26 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1353
1368
|
declare const abs: typeof absolute;
|
|
1354
1369
|
/** Return an element-wise indication of sign of the input. */
|
|
1355
1370
|
declare function sign(x: ArrayLike): Array;
|
|
1371
|
+
/**
|
|
1372
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1373
|
+
*
|
|
1374
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1375
|
+
*/
|
|
1376
|
+
declare function hamming(M: number): Array;
|
|
1377
|
+
/**
|
|
1378
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
1379
|
+
*
|
|
1380
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1381
|
+
*/
|
|
1382
|
+
declare function hann(M: number): Array;
|
|
1383
|
+
/**
|
|
1384
|
+
* @function
|
|
1385
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
1386
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
1387
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
1388
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1389
|
+
*/
|
|
1390
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1356
1391
|
/** Calculate element-wise square of the input array. */
|
|
1357
1392
|
declare function square(x: ArrayLike): Array;
|
|
1358
1393
|
/** Element-wise tangent function (takes radians). */
|
|
@@ -1364,8 +1399,8 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1364
1399
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1365
1400
|
*
|
|
1366
1401
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1367
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
1368
|
-
* improvements.
|
|
1402
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
1403
|
+
* stability improvements.
|
|
1369
1404
|
*/
|
|
1370
1405
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1371
1406
|
/**
|
|
@@ -1497,6 +1532,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1497
1532
|
mean?: ArrayLike;
|
|
1498
1533
|
correction?: number;
|
|
1499
1534
|
} & ReduceOpts): Array;
|
|
1535
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1500
1536
|
//#endregion
|
|
1501
1537
|
//#region src/frontend/jaxpr.d.ts
|
|
1502
1538
|
/**
|
|
@@ -1571,10 +1607,9 @@ declare class Jaxpr implements FpHashable {
|
|
|
1571
1607
|
/** @inline */
|
|
1572
1608
|
type JitOpts = {
|
|
1573
1609
|
staticArgnums?: number[];
|
|
1574
|
-
device?: Device;
|
|
1575
1610
|
};
|
|
1576
1611
|
declare namespace lax_d_exports {
|
|
1577
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1612
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1578
1613
|
}
|
|
1579
1614
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1580
1615
|
/**
|
|
@@ -1598,6 +1633,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
|
|
|
1598
1633
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1599
1634
|
/** Reduce a computation over padded windows. */
|
|
1600
1635
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1636
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1637
|
+
declare function erf(x: ArrayLike): Array;
|
|
1638
|
+
/**
|
|
1639
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
1640
|
+
*
|
|
1641
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1642
|
+
* where `erf(x)` is very close to 1.
|
|
1643
|
+
*/
|
|
1644
|
+
declare function erfc(x: ArrayLike): Array;
|
|
1645
|
+
/**
|
|
1646
|
+
* Stops gradient computation.
|
|
1647
|
+
*
|
|
1648
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1649
|
+
* forward or reverse-mode automatic differentiation.
|
|
1650
|
+
*/
|
|
1651
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1652
|
+
//# sourceMappingURL=lax.d.ts.map
|
|
1601
1653
|
declare namespace nn_d_exports {
|
|
1602
1654
|
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1603
1655
|
}
|
|
@@ -1682,15 +1734,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1682
1734
|
* @function
|
|
1683
1735
|
* Gaussian error linear unit (GELU) activation function.
|
|
1684
1736
|
*
|
|
1685
|
-
* This is computed element-wise.
|
|
1686
|
-
*
|
|
1687
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1737
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
1738
|
+
* `approximate` is set (default true):
|
|
1688
1739
|
*
|
|
1689
|
-
*
|
|
1740
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
1741
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
1690
1742
|
*
|
|
1691
|
-
*
|
|
1743
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1692
1744
|
*/
|
|
1693
|
-
declare const gelu: OwnedFunction<(x: ArrayLike
|
|
1745
|
+
declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
|
|
1746
|
+
approximate?: boolean | undefined;
|
|
1747
|
+
} | undefined) => Array>;
|
|
1694
1748
|
/**
|
|
1695
1749
|
* Gated linear unit (GLU) activation function.
|
|
1696
1750
|
*
|
|
@@ -1771,6 +1825,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1771
1825
|
* ```
|
|
1772
1826
|
*/
|
|
1773
1827
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1828
|
+
//# sourceMappingURL=nn.d.ts.map
|
|
1774
1829
|
declare namespace random_d_exports {
|
|
1775
1830
|
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1776
1831
|
}
|
|
@@ -1809,6 +1864,15 @@ declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | und
|
|
|
1809
1864
|
* bitwise identical to JAX.
|
|
1810
1865
|
*/
|
|
1811
1866
|
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1867
|
+
//# sourceMappingURL=random.d.ts.map
|
|
1868
|
+
declare namespace scipy_special_d_exports {
|
|
1869
|
+
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1870
|
+
}
|
|
1871
|
+
/**
|
|
1872
|
+
* @function
|
|
1873
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
1874
|
+
*/
|
|
1875
|
+
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1812
1876
|
//#endregion
|
|
1813
1877
|
//#region src/index.d.ts
|
|
1814
1878
|
/**
|
|
@@ -1820,7 +1884,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
|
|
|
1820
1884
|
* @function
|
|
1821
1885
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1822
1886
|
*/
|
|
1823
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number |
|
|
1887
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1824
1888
|
/**
|
|
1825
1889
|
* @function
|
|
1826
1890
|
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
@@ -1895,5 +1959,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
|
|
|
1895
1959
|
* Does not consume reference to the arrays.
|
|
1896
1960
|
*/
|
|
1897
1961
|
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1962
|
+
/**
|
|
1963
|
+
* Transfer `x` to `device`.
|
|
1964
|
+
*
|
|
1965
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
1966
|
+
* is committed to the device.
|
|
1967
|
+
*
|
|
1968
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
1969
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
1970
|
+
* default device.
|
|
1971
|
+
*/
|
|
1972
|
+
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
1973
|
+
//# sourceMappingURL=index.d.ts.map
|
|
1974
|
+
|
|
1898
1975
|
//#endregion
|
|
1899
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1976
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1977
|
+
//# sourceMappingURL=index.d.ts.map
|