@jax-js/jax 0.0.4 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +296 -78
- package/dist/{backend-EBRGmEYw.js → backend-DwIAd0AG.js} +238 -116
- package/dist/{backend-Ss1Mev_-.cjs → backend-FtkbO6pI.cjs} +256 -122
- package/dist/index.cjs +653 -277
- package/dist/index.d.cts +167 -44
- package/dist/index.d.ts +167 -44
- package/dist/index.js +637 -268
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.d.cts
CHANGED
|
@@ -124,6 +124,7 @@ declare class ShapeTracker {
|
|
|
124
124
|
/** Like pad(), but allows for negative values. */
|
|
125
125
|
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
126
126
|
}
|
|
127
|
+
//# sourceMappingURL=shape.d.ts.map
|
|
127
128
|
//#endregion
|
|
128
129
|
//#region src/utils.d.ts
|
|
129
130
|
/**
|
|
@@ -180,12 +181,12 @@ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Arr
|
|
|
180
181
|
* **Type lattice:**
|
|
181
182
|
* ```text
|
|
182
183
|
* bool -> uint32 -> int32 -> float16 -> float32
|
|
183
|
-
*
|
|
184
|
+
* weakType --^
|
|
184
185
|
* ```
|
|
185
186
|
*
|
|
186
|
-
*
|
|
187
|
-
*
|
|
188
|
-
*
|
|
187
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
188
|
+
* which default to float32 but are "weak", so they cast to the dtype of any array
|
|
189
|
+
* they are first combined with, except `bool`.
|
|
189
190
|
*
|
|
190
191
|
* **Examples:**
|
|
191
192
|
* - `promoteTypes(bool, int32) → int32`
|
|
@@ -222,6 +223,8 @@ declare class AluExp implements FpHashable {
|
|
|
222
223
|
static atan(a: AluExp): AluExp;
|
|
223
224
|
static exp(a: AluExp): AluExp;
|
|
224
225
|
static log(a: AluExp): AluExp;
|
|
226
|
+
static erf(a: AluExp): AluExp;
|
|
227
|
+
static erfc(a: AluExp): AluExp;
|
|
225
228
|
static sqrt(a: AluExp): AluExp;
|
|
226
229
|
static reciprocal(a: AluExp): AluExp;
|
|
227
230
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -289,8 +292,8 @@ declare class AluExp implements FpHashable {
|
|
|
289
292
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
290
293
|
/** Collect all nodes that satisfy a predicate. */
|
|
291
294
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
292
|
-
/** Produce
|
|
293
|
-
distinctOps(): Set<
|
|
295
|
+
/** Produce all distinct AluOp in this expression, with their dtypes. */
|
|
296
|
+
distinctOps(): Map<AluOp, Set<DType>>;
|
|
294
297
|
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
295
298
|
rewriteGlobalViews(): AluExp;
|
|
296
299
|
}
|
|
@@ -309,6 +312,8 @@ declare enum AluOp {
|
|
|
309
312
|
Atan = "Atan",
|
|
310
313
|
Exp = "Exp",
|
|
311
314
|
Log = "Log",
|
|
315
|
+
Erf = "Erf",
|
|
316
|
+
Erfc = "Erfc",
|
|
312
317
|
Sqrt = "Sqrt",
|
|
313
318
|
Reciprocal = "Reciprocal",
|
|
314
319
|
Cast = "Cast",
|
|
@@ -465,7 +470,7 @@ type JsTree<T> = T | JsTree<T>[] | {
|
|
|
465
470
|
[key: string]: JsTree<T>;
|
|
466
471
|
};
|
|
467
472
|
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
468
|
-
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
473
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends Array ? T : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
469
474
|
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
470
475
|
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
471
476
|
/** Represents the structure of a JsTree. */
|
|
@@ -477,6 +482,8 @@ declare class JsTreeDef {
|
|
|
477
482
|
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
478
483
|
// Must be comparable with deepEqual.
|
|
479
484
|
childTreedefs: JsTreeDef[]);
|
|
485
|
+
/** Get the total number of leaves in the tree. */
|
|
486
|
+
get size(): number;
|
|
480
487
|
/** Returns a string representation of this tree definition. */
|
|
481
488
|
toString(root?: boolean): string;
|
|
482
489
|
/** Compare this tree definition with another. */
|
|
@@ -540,6 +547,8 @@ declare enum Primitive {
|
|
|
540
547
|
Atan = "atan",
|
|
541
548
|
Exp = "exp",
|
|
542
549
|
Log = "log",
|
|
550
|
+
Erf = "erf",
|
|
551
|
+
Erfc = "erfc",
|
|
543
552
|
Sqrt = "sqrt",
|
|
544
553
|
Min = "min",
|
|
545
554
|
Max = "max",
|
|
@@ -613,6 +622,7 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
|
613
622
|
outDim: number;
|
|
614
623
|
};
|
|
615
624
|
[Primitive.JitCall]: {
|
|
625
|
+
name: string;
|
|
616
626
|
jaxpr: Jaxpr;
|
|
617
627
|
numConsts: number;
|
|
618
628
|
};
|
|
@@ -651,10 +661,40 @@ declare abstract class Trace {
|
|
|
651
661
|
abstract lift(val: Tracer): Tracer;
|
|
652
662
|
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
653
663
|
}
|
|
664
|
+
/** Internal representation of an array value. */
|
|
654
665
|
interface AbstractValue {
|
|
666
|
+
/** Shape of the array. Must be a static tuple of non-negative dimensions. */
|
|
655
667
|
shape: number[];
|
|
668
|
+
/** Concrete data type of array elements. */
|
|
656
669
|
dtype: DType;
|
|
670
|
+
/**
|
|
671
|
+
* Arrays created from JavaScript numbers (e.g., `np.array(3)`) are created as
|
|
672
|
+
* _weakly typed_ unless a dtype is explicitly specified.
|
|
673
|
+
*
|
|
674
|
+
* Weakly typed values will automatically cast to the data type of other
|
|
675
|
+
* arrays when used as an operand in an expression. This property only affects
|
|
676
|
+
* how they promote in type casting; their memory layout is still determined
|
|
677
|
+
* by the actual `dtype` field.
|
|
678
|
+
*
|
|
679
|
+
* ```ts
|
|
680
|
+
* const x = np.array(3); // weakType = true, dtype = float32
|
|
681
|
+
* const y = np.array([1, 2], { dtype: np.int32 }); // weakType = false, dtype = int32
|
|
682
|
+
* const z = x.add(y); // z has dtype int32 because x is weakly typed
|
|
683
|
+
* ```
|
|
684
|
+
*
|
|
685
|
+
* Weak types are present in JIT programs in their spec (e.g., Jaxpr inputs
|
|
686
|
+
* and outputs can be weakly typed). However, they are solely a frontend
|
|
687
|
+
* concept. Backends are not aware of weak types.
|
|
688
|
+
*/
|
|
689
|
+
weakType: boolean;
|
|
657
690
|
}
|
|
691
|
+
/**
|
|
692
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
693
|
+
*
|
|
694
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
695
|
+
* implemented by that function itself, because `weakType` is not passed to it.
|
|
696
|
+
*/
|
|
697
|
+
|
|
658
698
|
declare abstract class Tracer {
|
|
659
699
|
/** @ignore */
|
|
660
700
|
readonly _trace: Trace;
|
|
@@ -712,8 +752,15 @@ declare abstract class Tracer {
|
|
|
712
752
|
get shape(): number[];
|
|
713
753
|
/** The total number of elements in the array. */
|
|
714
754
|
get size(): number;
|
|
715
|
-
/** The dtype of the array. */
|
|
755
|
+
/** The dtype of elements stored in the array. */
|
|
716
756
|
get dtype(): DType;
|
|
757
|
+
/**
|
|
758
|
+
* Whether the array is weakly typed.
|
|
759
|
+
*
|
|
760
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
761
|
+
* `promoteTypes()` for details.
|
|
762
|
+
*/
|
|
763
|
+
get weakType(): boolean;
|
|
717
764
|
/** The number of dimensions of the array. */
|
|
718
765
|
get ndim(): number;
|
|
719
766
|
/** @ignore */
|
|
@@ -805,7 +852,8 @@ declare abstract class Tracer {
|
|
|
805
852
|
declare class ShapedArray implements AbstractValue {
|
|
806
853
|
readonly shape: number[];
|
|
807
854
|
readonly dtype: DType;
|
|
808
|
-
|
|
855
|
+
readonly weakType: boolean;
|
|
856
|
+
constructor(shape: number[], dtype: DType, weakType: boolean);
|
|
809
857
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
810
858
|
get ndim(): number;
|
|
811
859
|
toString(): string;
|
|
@@ -841,10 +889,19 @@ type DTypeAndDevice = {
|
|
|
841
889
|
dtype?: DType;
|
|
842
890
|
device?: Device;
|
|
843
891
|
};
|
|
892
|
+
type ArrayConstructorArgs = {
|
|
893
|
+
source: AluExp | Slot;
|
|
894
|
+
st: ShapeTracker;
|
|
895
|
+
dtype: DType;
|
|
896
|
+
weakType: boolean;
|
|
897
|
+
backend: Backend;
|
|
898
|
+
committed: boolean;
|
|
899
|
+
pending?: Iterable<PendingExecute>;
|
|
900
|
+
};
|
|
844
901
|
/**
|
|
845
902
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
846
903
|
*
|
|
847
|
-
* This is the library's core data type. Equivalent to `
|
|
904
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
848
905
|
* `torch.Tensor`.
|
|
849
906
|
*
|
|
850
907
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -860,11 +917,7 @@ declare class Array extends Tracer {
|
|
|
860
917
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
861
918
|
* will be freed when the array is disposed.
|
|
862
919
|
*/
|
|
863
|
-
constructor(
|
|
864
|
-
pending
|
|
865
|
-
}?: {
|
|
866
|
-
pending?: Iterable<PendingExecute> | null;
|
|
867
|
-
});
|
|
920
|
+
constructor(args: ArrayConstructorArgs);
|
|
868
921
|
/** @ignore */
|
|
869
922
|
get aval(): ShapedArray;
|
|
870
923
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -924,10 +977,13 @@ declare class Array extends Tracer {
|
|
|
924
977
|
item(): number;
|
|
925
978
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
926
979
|
static _implRules(): typeof implRules;
|
|
980
|
+
/** @private */
|
|
927
981
|
_realizeSource(): number;
|
|
982
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
983
|
+
_put(backend: Backend): Promise<Array>;
|
|
984
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
985
|
+
_putSync(backend: Backend): Array;
|
|
928
986
|
}
|
|
929
|
-
/** Construct an array from a single scalar constant. */
|
|
930
|
-
|
|
931
987
|
/** Constructor for creating a new array from data. */
|
|
932
988
|
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
933
989
|
shape,
|
|
@@ -1002,7 +1058,7 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
1002
1058
|
device
|
|
1003
1059
|
}?: DTypeAndDevice): Array;
|
|
1004
1060
|
declare namespace numpy_d_exports {
|
|
1005
|
-
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1061
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hamming, hann, heaviside, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
1006
1062
|
}
|
|
1007
1063
|
declare const float32 = DType.Float32;
|
|
1008
1064
|
declare const int32 = DType.Int32;
|
|
@@ -1085,7 +1141,7 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
1085
1141
|
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1086
1142
|
* pair specifies the padding for its corresponding axis.
|
|
1087
1143
|
*/
|
|
1088
|
-
declare const pad: (x: ArrayLike, width: number |
|
|
1144
|
+
declare const pad: (x: ArrayLike, width: number | Pair | Pair[]) => Array;
|
|
1089
1145
|
/**
|
|
1090
1146
|
* @function
|
|
1091
1147
|
* Return the number of dimensions of an array. Does not consume array reference.
|
|
@@ -1315,6 +1371,26 @@ declare function absolute(x: ArrayLike): Array;
|
|
|
1315
1371
|
declare const abs: typeof absolute;
|
|
1316
1372
|
/** Return an element-wise indication of sign of the input. */
|
|
1317
1373
|
declare function sign(x: ArrayLike): Array;
|
|
1374
|
+
/**
|
|
1375
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
1376
|
+
*
|
|
1377
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1378
|
+
*/
|
|
1379
|
+
declare function hamming(M: number): Array;
|
|
1380
|
+
/**
|
|
1381
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
1382
|
+
*
|
|
1383
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
1384
|
+
*/
|
|
1385
|
+
declare function hann(M: number): Array;
|
|
1386
|
+
/**
|
|
1387
|
+
* @function
|
|
1388
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
1389
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
1390
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
1391
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
1392
|
+
*/
|
|
1393
|
+
declare const heaviside: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1318
1394
|
/** Calculate element-wise square of the input array. */
|
|
1319
1395
|
declare function square(x: ArrayLike): Array;
|
|
1320
1396
|
/** Element-wise tangent function (takes radians). */
|
|
@@ -1326,8 +1402,8 @@ declare function acos(x: ArrayLike): Array;
|
|
|
1326
1402
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1327
1403
|
*
|
|
1328
1404
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1329
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
1330
|
-
* improvements.
|
|
1405
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
1406
|
+
* stability improvements.
|
|
1331
1407
|
*/
|
|
1332
1408
|
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1333
1409
|
/**
|
|
@@ -1459,6 +1535,7 @@ declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1459
1535
|
mean?: ArrayLike;
|
|
1460
1536
|
correction?: number;
|
|
1461
1537
|
} & ReduceOpts): Array;
|
|
1538
|
+
//# sourceMappingURL=numpy.d.ts.map
|
|
1462
1539
|
//#endregion
|
|
1463
1540
|
//#region src/frontend/jaxpr.d.ts
|
|
1464
1541
|
/**
|
|
@@ -1480,10 +1557,10 @@ declare class Var {
|
|
|
1480
1557
|
}
|
|
1481
1558
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
1482
1559
|
declare class Lit {
|
|
1483
|
-
readonly dtype: DType;
|
|
1484
1560
|
readonly value: number;
|
|
1485
1561
|
readonly aval: ShapedArray;
|
|
1486
|
-
|
|
1562
|
+
get dtype(): DType;
|
|
1563
|
+
constructor(aval: AbstractValue, value: number);
|
|
1487
1564
|
}
|
|
1488
1565
|
type Atom = Var | Lit;
|
|
1489
1566
|
declare class VarPrinter {
|
|
@@ -1533,10 +1610,9 @@ declare class Jaxpr implements FpHashable {
|
|
|
1533
1610
|
/** @inline */
|
|
1534
1611
|
type JitOpts = {
|
|
1535
1612
|
staticArgnums?: number[];
|
|
1536
|
-
device?: Device;
|
|
1537
1613
|
};
|
|
1538
1614
|
declare namespace lax_d_exports {
|
|
1539
|
-
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1615
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, erf, erfc, reduceWindow, stopGradient };
|
|
1540
1616
|
}
|
|
1541
1617
|
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1542
1618
|
/**
|
|
@@ -1560,6 +1636,23 @@ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: n
|
|
|
1560
1636
|
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1561
1637
|
/** Reduce a computation over padded windows. */
|
|
1562
1638
|
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1639
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
1640
|
+
declare function erf(x: ArrayLike): Array;
|
|
1641
|
+
/**
|
|
1642
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
1643
|
+
*
|
|
1644
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
1645
|
+
* where `erf(x)` is very close to 1.
|
|
1646
|
+
*/
|
|
1647
|
+
declare function erfc(x: ArrayLike): Array;
|
|
1648
|
+
/**
|
|
1649
|
+
* Stops gradient computation.
|
|
1650
|
+
*
|
|
1651
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
1652
|
+
* forward or reverse-mode automatic differentiation.
|
|
1653
|
+
*/
|
|
1654
|
+
declare function stopGradient(x: ArrayLike): Array;
|
|
1655
|
+
//# sourceMappingURL=lax.d.ts.map
|
|
1563
1656
|
declare namespace nn_d_exports {
|
|
1564
1657
|
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1565
1658
|
}
|
|
@@ -1644,15 +1737,17 @@ declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
|
1644
1737
|
* @function
|
|
1645
1738
|
* Gaussian error linear unit (GELU) activation function.
|
|
1646
1739
|
*
|
|
1647
|
-
* This is computed element-wise.
|
|
1648
|
-
*
|
|
1649
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1740
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
1741
|
+
* `approximate` is set (default true):
|
|
1650
1742
|
*
|
|
1651
|
-
*
|
|
1743
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
1744
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
1652
1745
|
*
|
|
1653
|
-
*
|
|
1746
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1654
1747
|
*/
|
|
1655
|
-
declare const gelu: OwnedFunction<(x: ArrayLike
|
|
1748
|
+
declare const gelu: OwnedFunction<(x: ArrayLike, opts?: {
|
|
1749
|
+
approximate?: boolean | undefined;
|
|
1750
|
+
} | undefined) => Array>;
|
|
1656
1751
|
/**
|
|
1657
1752
|
* Gated linear unit (GLU) activation function.
|
|
1658
1753
|
*
|
|
@@ -1733,6 +1828,7 @@ declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
|
1733
1828
|
* ```
|
|
1734
1829
|
*/
|
|
1735
1830
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1831
|
+
//# sourceMappingURL=nn.d.ts.map
|
|
1736
1832
|
declare namespace random_d_exports {
|
|
1737
1833
|
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1738
1834
|
}
|
|
@@ -1742,14 +1838,14 @@ declare function key(seed: number): Array;
|
|
|
1742
1838
|
declare function split(key: Array, num?: number | number[]): Array;
|
|
1743
1839
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
1744
1840
|
declare function bits(key: Array, shape?: number[]): Array;
|
|
1745
|
-
/**
|
|
1746
|
-
|
|
1747
|
-
|
|
1748
|
-
|
|
1749
|
-
|
|
1750
|
-
minval?: number;
|
|
1751
|
-
maxval?: number;
|
|
1752
|
-
})
|
|
1841
|
+
/**
|
|
1842
|
+
* @function
|
|
1843
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
1844
|
+
*/
|
|
1845
|
+
declare const uniform: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined, args_2?: {
|
|
1846
|
+
minval?: number | undefined;
|
|
1847
|
+
maxval?: number | undefined;
|
|
1848
|
+
} | undefined) => Array>;
|
|
1753
1849
|
/**
|
|
1754
1850
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1755
1851
|
*
|
|
@@ -1757,16 +1853,29 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1757
1853
|
* and must be broadcastable to `shape`.
|
|
1758
1854
|
*/
|
|
1759
1855
|
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1760
|
-
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1761
|
-
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1762
1856
|
/**
|
|
1857
|
+
* @function
|
|
1858
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
1859
|
+
*/
|
|
1860
|
+
declare const exponential: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1861
|
+
/**
|
|
1862
|
+
* @function
|
|
1763
1863
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1764
1864
|
*
|
|
1765
1865
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1766
1866
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1767
1867
|
* bitwise identical to JAX.
|
|
1768
1868
|
*/
|
|
1769
|
-
declare
|
|
1869
|
+
declare const normal: OwnedFunction<(key: ArrayLike, shape?: number[] | undefined) => Array>;
|
|
1870
|
+
//# sourceMappingURL=random.d.ts.map
|
|
1871
|
+
declare namespace scipy_special_d_exports {
|
|
1872
|
+
export { erf, erfc, logSoftmax, logit, logsumexp, softmax };
|
|
1873
|
+
}
|
|
1874
|
+
/**
|
|
1875
|
+
* @function
|
|
1876
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
1877
|
+
*/
|
|
1878
|
+
declare const logit: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1770
1879
|
//#endregion
|
|
1771
1880
|
//#region src/index.d.ts
|
|
1772
1881
|
/**
|
|
@@ -1778,7 +1887,7 @@ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals:
|
|
|
1778
1887
|
* @function
|
|
1779
1888
|
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1780
1889
|
*/
|
|
1781
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number |
|
|
1890
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | (number | null | JsTree<number | null>)[]) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1782
1891
|
/**
|
|
1783
1892
|
* @function
|
|
1784
1893
|
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
@@ -1853,5 +1962,19 @@ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJs
|
|
|
1853
1962
|
* Does not consume reference to the arrays.
|
|
1854
1963
|
*/
|
|
1855
1964
|
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1965
|
+
/**
|
|
1966
|
+
* Transfer `x` to `device`.
|
|
1967
|
+
*
|
|
1968
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
1969
|
+
* is committed to `device` when one is specified.
|
|
1970
|
+
*
|
|
1971
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
1972
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
1973
|
+
* default device.
|
|
1974
|
+
*/
|
|
1975
|
+
declare function devicePut<T extends JsTree<any>>(x: T, device?: Device): Promise<MapJsTree<T, number | boolean, Array>>;
|
|
1976
|
+
//# sourceMappingURL=index.d.ts.map
|
|
1977
|
+
|
|
1856
1978
|
//#endregion
|
|
1857
|
-
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1979
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, scipy_special_d_exports as scipySpecial, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1980
|
+
//# sourceMappingURL=index.d.cts.map
|