@jax-js/jax 0.0.4 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +296 -78
- package/dist/{backend-EBRGmEYw.js → backend-DwIAd0AG.js} +238 -116
- package/dist/{backend-Ss1Mev_-.cjs → backend-FtkbO6pI.cjs} +256 -122
- package/dist/index.cjs +653 -277
- package/dist/index.d.cts +167 -44
- package/dist/index.d.ts +167 -44
- package/dist/index.js +637 -268
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.d.ts
CHANGED
|
@@ -121,6 +121,7 @@ declare class ShapeTracker {
|
|
|
121
121
|
/** Like pad(), but allows for negative values. */
|
|
122
122
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
123
123
|
}
|
|
124
|
+
//# sourceMappingURL=shape.d.ts.map
|
|
124
125
|
//#endregion
|
|
125
126
|
//#region src/utils.d.ts
|
|
126
127
|
/**
|
|
@@ -177,12 +178,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
177
178
|
* **Type lattice:**
|
|
178
179
|
* ```text
|
|
179
180
|
* bool -> uint32 -> int32 -> float16 -> float32
|
|
180
|
-
*
|
|
181
|
+
* weakType --^
|
|
181
182
|
* ```
|
|
182
183
|
*
|
|
183
|
-
*
|
|
184
|
-
*
|
|
185
|
-
*
|
|
184
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
185
|
+
* which default to float32 but "weak" so they cast to the dtype of any array
|
|
186
|
+
* they are first combined with, except `bool`.
|
|
186
187
|
*
|
|
187
188
|
* **Examples:**
|
|
188
189
|
* - `promoteTypes(bool, int32) → int32`
|
|
@@ -219,6 +220,8 @@ declare class AluExp implements FpHashable {
|
|
|
219
220
|
static atan(a: AluExp): AluExp;
|
|
220
221
|
static exp(a: AluExp): AluExp;
|
|
221
222
|
static log(a: AluExp): AluExp;
|
|
223
|
+
static erf(a: AluExp): AluExp;
|
|
224
|
+
static erfc(a: AluExp): AluExp;
|
|
222
225
|
static sqrt(a: AluExp): AluExp;
|
|
223
226
|
static reciprocal(a: AluExp): AluExp;
|
|
224
227
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -286,8 +289,8 @@ declare class AluExp implements FpHashable {
|
|
|
286
289
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
287
290
|
/** Collect all nodes that satisfy a predicate. */
|
|
288
291
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
289
|
-
/** Produce
|
|
290
|
-
distinctOps(): Set<
|
|
292
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
293
|
+
distinctOps(): Map<AluOp, Set<DType>>;
|
|
291
294
|
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
292
295
|
rewriteGlobalViews(): AluExp;
|
|
293
296
|
}
|
|
@@ -306,6 +309,8 @@ declare enum AluOp {
|
|
|
306
309
|
Atan = "Atan",
|
|
307
310
|
Exp = "Exp",
|
|
308
311
|
Log = "Log",
|
|
312
|
+
Erf = "Erf",
|
|
313
|
+
Erfc = "Erfc",
|
|
309
314
|
Sqrt = "Sqrt",
|
|
310
315
|
Reciprocal = "Reciprocal",
|
|
311
316
|
Cast = "Cast",
|
|
@@ -462,7 +467,7 @@ type JsTree<T> = T | JsTree<T>[] | {
|
|
|
462
467
|
[key: string]: JsTree<T>;
|
|
463
468
|
};
|
|
464
469
|
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
465
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
470
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
466
471
|
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
467
472
|
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
468
473
|
/** Represents the structure of a JsTree. */
|
|
@@ -474,6 +479,8 @@ declare class JsTreeDef {
|
|
|
474
479
|
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
475
480
|
// Must be comparable with deepEqual.
|
|
476
481
|
childTreedefs: JsTreeDef[]);
|
|
482
|
+
/** Get the total number of leaves in the tree. */
|
|
483
|
+
get size(): number;
|
|
477
484
|
/** Returns a string representation of this tree definition. */
|
|
478
485
|
toString(root?: boolean): string;
|
|
479
486
|
/** Compare this tree definition with another. */
|
|
@@ -537,6 +544,8 @@ declare enum Primitive {
|
|
|
537
544
|
Atan = "atan",
|
|
538
545
|
Exp = "exp",
|
|
539
546
|
Log = "log",
|
|
547
|
+
Erf = "erf",
|
|
548
|
+
Erfc = "erfc",
|
|
540
549
|
Sqrt = "sqrt",
|
|
541
550
|
Min = "min",
|
|
542
551
|
Max = "max",
|
|
@@ -610,6 +619,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
610
619
|
outDim: number;
|
|
611
620
|
};
|
|
612
621
|
[Primitive.JitCall]: {
|
|
622
|
+
name: string;
|
|
613
623
|
jaxpr: Jaxpr;
|
|
614
624
|
numConsts: number;
|
|
615
625
|
};
|
|
@@ -648,10 +658,40 @@ declare abstract class Trace {
|
|
|
648
658
|
abstract lift(val: Tracer): Tracer;
|
|
649
659
|
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
650
660
|
}
|
|
661
|
+
/** Internal representation of an array value. */
|
|
651
662
|
interface AbstractValue {
|
|
663
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
652
664
|
shape: number[];
|
|
665
|
+
/** Concrete data type of array elements. */
|
|
653
666
|
dtype: DType;
|
|
667
|
+
/**
|
|
668
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
669
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
670
|
+
*
|
|
671
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
672
|
+
* arrays when used as an operand as an expression. This property only affects
|
|
673
|
+
* how they promote in type casting; their memory layout is still determined
|
|
674
|
+
* by the actual `dtype` field.
|
|
675
|
+
*
|
|
676
|
+
* ```ts
|
|
677
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
678
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
679
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
680
|
+
* ```
|
|
681
|
+
*
|
|
682
|
+
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
683
|
+
* and outputs can be weakly typed) form. But they're solely a frontend
|
|
684
|
+
* concept. Backends are not aware of weak types.
|
|
685
|
+
*/
|
|
686
|
+
weakType: boolean;
|
|
654
687
|
}
|
|
688
|
+
/**
|
|
689
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
690
|
+
*
|
|
691
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
692
|
+
* implemented in that function as `weakType` is not passed.
|
|
693
|
+
*/
|
|
694
|
+
|
|
655
695
|
declare abstract class Tracer {
|
|
656
696
|
/** @ignore */
|
|
657
697
|
readonly _trace: Trace;
|
|
@@ -709,8 +749,15 @@ declare abstract class Tracer {
|
|
|
709
749
|
get shape(): number[];
|
|
710
750
|
/** The total number of elements in the array. */
|
|
711
751
|
get size(): number;
|
|
712
|
-
/** The dtype of the array. */
|
|
752
|
+
/** The dtype of elements stored in the array. */
|
|
713
753
|
get dtype(): DType;
|
|
754
|
+
/**
|
|
755
|
+
* Whether the array is weakly typed.
|
|
756
|
+
*
|
|
757
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
758
|
+
* `promoteTypes()` for details.
|
|
759
|
+
*/
|
|
760
|
+
get weakType(): boolean;
|
|
714
761
|
/** The number of dimensions of the array. */
|
|
715
762
|
get ndim(): number;
|
|
716
763
|
/** @ignore */
|
|
@@ -802,7 +849,8 @@ declare abstract class Tracer {
|
|
|
802
849
|
declare class ShapedArray implements AbstractValue {
|
|
803
850
|
readonly shape: number[];
|
|
804
851
|
readonly dtype: DType;
|
|
805
|
-
|
|
852
|
+
readonly weakType: boolean;
|
|
853
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
806
854
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
807
855
|
get ndim(): number;
|
|
808
856
|
toString(): string;
|
|
@@ -838,10 +886,19 @@ type DTypeAndDevice = {
|
|
|
838
886
|
dtype?: DType;
|
|
839
887
|
device?: Device;
|
|
840
888
|
};
|
|
889
|
+
type ArrayConstructorArgs = {
|
|
890
|
+
source: AluExp | Slot;
|
|
891
|
+
st: ShapeTracker;
|
|
892
|
+
dtype: DType;
|
|
893
|
+
weakType: boolean;
|
|
894
|
+
backend: Backend;
|
|
895
|
+
committed: boolean;
|
|
896
|
+
pending?: Iterable<PendingExecute>;
|
|
897
|
+
};
|
|
841
898
|
/**
|
|
842
899
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
843
900
|
*
|
|
844
|
-
* This is the library's core data type. Equivalent to `
|
|
901
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
845
902
|
* `torch.Tensor`.
|
|
846
903
|
*
|
|
847
904
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -857,11 +914,7 @@ declare class Array extends Tracer {
|
|
|
857
914
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
858
915
|
* will be freed when the array is disposed.
|
|
859
916
|
*/
|
|
860
|
-
constructor(
|
|
861
|
-
pending
|
|
862
|
-
}?: {
|
|
863
|
-
pending?: Iterable<PendingExecute> | null;
|
|
864
|
-
});
|
|
917
|
+
constructor(args: ArrayConstructorArgs);
|
|
865
918
|
/** @ignore */
|
|
866
919
|
get aval(): ShapedArray;
|
|
867
920
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -921,10 +974,13 @@ declare class Array extends Tracer {
|
|
|
921
974
|
item(): number;
|
|
922
975
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
923
976
|
static _implRules(): typeof implRules;
|
|
977
|
+
/** @private */
|
|
924
978
|
_realizeSource(): number;
|
|
979
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
980
|
+
_put(backend: Backend): Promise<Array>;
|
|
981
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
982
|
+
_putSync(backend: Backend): Array;
|
|
925
983
|
}
|
|
926
|
-
/** Construct an array from a single scalar constant. */
|
|
927
|
-
|
|
928
984
|
/** Constructor for creating a new array from data. */
|
|
929
985
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
930
986
|
shape,
|
|
@@ -999,7 +1055,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
999
1055
|
device
|
|
1000
1056
|
}?: DTypeAndDevice): Array;
|
|
1001
1057
|
declare namespace numpy_d_exports {
|
|
1002
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1058
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1003
1059
|
}
|
|
1004
1060
|
declare const float32 = DType.Float32;
|
|
1005
1061
|
declare const int32 = DType.Int32;
|
|
@@ -1082,7 +1138,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1082
1138
|
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1083
1139
|
* pair specifies the padding for its corresponding axis.
|
|
1084
1140
|
*/
|
|
1085
|
-
declare const pad: (x: ArrayLike, width: number |
|
|
1141
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
1086
1142
|
/**
|
|
1087
1143
|
* @function
|
|
1088
1144
|
* Return the number of dimensions of an array. Does not consume array reference.
|
|
@@ -1312,6 +1368,26 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1312
1368
|
declare const abs: typeof absolute;
|
|
1313
1369
|
/** Return an element-wise indication of sign of the input. */
|
|
1314
1370
|
declare function sign(x: ArrayLike): Array;
|
|
1371
|
+
/**
|
|
1372
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1373
|
+
*
|
|
1374
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1375
|
+
*/
|
|
1376
|
+
declare function hamming(M: number): Array;
|
|
1377
|
+
/**
|
|
1378
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
1379
|
+
*
|
|
1380
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1381
|
+
*/
|
|
1382
|
+
declare function hann(M: number): Array;
|
|
1383
|
+
/**
|
|
1384
|
+
* @function
|
|
1385
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
1386
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
1387
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
1388
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1389
|
+
*/
|
|
1390
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1315
1391
|
/** Calculate element-wise square of the input array. */
|
|
1316
1392
|
declare function square(x: ArrayLike): Array;
|
|
1317
1393
|
/** Element-wise tangent function (takes radians). */
|
|
@@ -1323,8 +1399,8 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1323
1399
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1324
1400
|
*
|
|
1325
1401
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1326
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
1327
|
-
* improvements.
|
|
1402
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
1403
|
+
* stability improvements.
|
|
1328
1404
|
*/
|
|
1329
1405
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1330
1406
|
/**
|
|
@@ -1456,6 +1532,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1456
1532
|
mean?: ArrayLike;
|
|
1457
1533
|
correction?: number;
|
|
1458
1534
|
} & ReduceOpts): Array;
|
|
1535
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1459
1536
|
//#endregion
|
|
1460
1537
|
//#region src/frontend/jaxpr.d.ts
|
|
1461
1538
|
/**
|
|
@@ -1477,10 +1554,10 @@ declare class Var {
|
|
|
1477
1554
|
}
|
|
1478
1555
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1479
1556
|
declare class Lit {
|
|
1480
|
-
readonly dtype: DType;
|
|
1481
1557
|
readonly value: number;
|
|
1482
1558
|
readonly aval: ShapedArray;
|
|
1483
|
-
|
|
1559
|
+
get dtype(): DType;
|
|
1560
|
+
constructor(aval: AbstractValue, value: number);
|
|
1484
1561
|
}
|
|
1485
1562
|
type Atom = Var | Lit;
|
|
1486
1563
|
declare class VarPrinter {
|
|
@@ -1530,10 +1607,9 @@ declare class Jaxpr implements FpHashable {
|
|
|
1530
1607
|
/** @inline */
|
|
1531
1608
|
type JitOpts = {
|
|
1532
1609
|
staticArgnums?: number[];
|
|
1533
|
-
device?: Device;
|
|
1534
1610
|
};
|
|
1535
1611
|
declare namespace lax_d_exports {
|
|
1536
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1612
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1537
1613
|
}
|
|
1538
1614
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1539
1615
|
/**
|
|
@@ -1557,6 +1633,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
|
|
|
1557
1633
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1558
1634
|
/** Reduce a computation over padded windows. */
|
|
1559
1635
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1636
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1637
|
+
declare function erf(x: ArrayLike): Array;
|
|
1638
|
+
/**
|
|
1639
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
1640
|
+
*
|
|
1641
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1642
|
+
* where `erf(x)` is very close to 1.
|
|
1643
|
+
*/
|
|
1644
|
+
declare function erfc(x: ArrayLike): Array;
|
|
1645
|
+
/**
|
|
1646
|
+
* Stops gradient computation.
|
|
1647
|
+
*
|
|
1648
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1649
|
+
* forward or reverse-mode automatic differentiation.
|
|
1650
|
+
*/
|
|
1651
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1652
|
+
//# sourceMappingURL=lax.d.ts.map
|
|
1560
1653
|
declare namespace nn_d_exports {
|
|
1561
1654
|
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1562
1655
|
}
|
|
@@ -1641,15 +1734,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1641
1734
|
* @function
|
|
1642
1735
|
* Gaussion error linear unit (GELU) activation function.
|
|
1643
1736
|
*
|
|
1644
|
-
* This is computed element-wise.
|
|
1645
|
-
*
|
|
1646
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1737
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
1738
|
+
* `approximate` is set (default true):
|
|
1647
1739
|
*
|
|
1648
|
-
*
|
|
1740
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
1741
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
1649
1742
|
*
|
|
1650
|
-
*
|
|
1743
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1651
1744
|
*/
|
|
1652
|
-
declare const gelu: OwnedFunction<(x: ArrayLike
|
|
1745
|
+
declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
|
|
1746
|
+
approximate?: boolean | undefined;
|
|
1747
|
+
} | undefined) => Array>;
|
|
1653
1748
|
/**
|
|
1654
1749
|
* Gated linear unit (GLU) activation function.
|
|
1655
1750
|
*
|
|
@@ -1730,6 +1825,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1730
1825
|
* ```
|
|
1731
1826
|
*/
|
|
1732
1827
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1828
|
+
//# sourceMappingURL=nn.d.ts.map
|
|
1733
1829
|
declare namespace random_d_exports {
|
|
1734
1830
|
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1735
1831
|
}
|
|
@@ -1739,14 +1835,14 @@ declare function key(seed: number): Array;
|
|
|
1739
1835
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1740
1836
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
1741
1837
|
declare function bits(key: Array, shape?: number[]): Array;
|
|
1742
|
-
/**
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
1746
|
-
|
|
1747
|
-
minval?: number;
|
|
1748
|
-
maxval?: number;
|
|
1749
|
-
})
|
|
1838
|
+
/**
|
|
1839
|
+
* @function
|
|
1840
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
1841
|
+
*/
|
|
1842
|
+
declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
|
|
1843
|
+
minval?: number | undefined;
|
|
1844
|
+
maxval?: number | undefined;
|
|
1845
|
+
} | undefined) => Array>;
|
|
1750
1846
|
/**
|
|
1751
1847
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1752
1848
|
*
|
|
@@ -1754,16 +1850,29 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1754
1850
|
* and must be broadcastable to `shape`.
|
|
1755
1851
|
*/
|
|
1756
1852
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1757
|
-
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1758
|
-
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1759
1853
|
/**
|
|
1854
|
+
* @function
|
|
1855
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1856
|
+
*/
|
|
1857
|
+
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1858
|
+
/**
|
|
1859
|
+
* @function
|
|
1760
1860
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1761
1861
|
*
|
|
1762
1862
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1763
1863
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1764
1864
|
* bitwise identical to JAX.
|
|
1765
1865
|
*/
|
|
1766
|
-
declare
|
|
1866
|
+
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1867
|
+
//# sourceMappingURL=random.d.ts.map
|
|
1868
|
+
declare namespace scipy_special_d_exports {
|
|
1869
|
+
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1870
|
+
}
|
|
1871
|
+
/**
|
|
1872
|
+
* @function
|
|
1873
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
1874
|
+
*/
|
|
1875
|
+
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1767
1876
|
//#endregion
|
|
1768
1877
|
//#region src/index.d.ts
|
|
1769
1878
|
/**
|
|
@@ -1775,7 +1884,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
|
|
|
1775
1884
|
* @function
|
|
1776
1885
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1777
1886
|
*/
|
|
1778
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number |
|
|
1887
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1779
1888
|
/**
|
|
1780
1889
|
* @function
|
|
1781
1890
|
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
@@ -1850,5 +1959,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
|
|
|
1850
1959
|
* Does not consume reference to the arrays.
|
|
1851
1960
|
*/
|
|
1852
1961
|
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1962
|
+
/**
|
|
1963
|
+
* Transfer `x` to `device`.
|
|
1964
|
+
*
|
|
1965
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
1966
|
+
* is committed to the device.
|
|
1967
|
+
*
|
|
1968
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
1969
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
1970
|
+
* default device.
|
|
1971
|
+
*/
|
|
1972
|
+
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
1973
|
+
//# sourceMappingURL=index.d.ts.map
|
|
1974
|
+
|
|
1853
1975
|
//#endregion
|
|
1854
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1976
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1977
|
+
//# sourceMappingURL=index.d.ts.map
|