@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
package/dist/index.d.cts
CHANGED
|
@@ -23,6 +23,7 @@ interface Stringable {
|
|
|
23
23
|
}
|
|
24
24
|
//#endregion
|
|
25
25
|
//#region src/shape.d.ts
|
|
26
|
+
/** @inline */
|
|
26
27
|
type Pair = [number, number];
|
|
27
28
|
/**
|
|
28
29
|
* A multidimensional view into memory. An array can be thought of as the
|
|
@@ -49,6 +50,8 @@ declare class View {
|
|
|
49
50
|
get size(): number;
|
|
50
51
|
/** Whether this is a default, contiguous, unaltered view of the data (identity). */
|
|
51
52
|
get contiguous(): boolean;
|
|
53
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
54
|
+
dataRange(): [number, number];
|
|
52
55
|
/** Produce an AluExp for evaluating this view at an index. */
|
|
53
56
|
toAluExp(idxs: AluExp[]): [AluExp, AluExp];
|
|
54
57
|
/**
|
|
@@ -109,6 +112,17 @@ declare class ShapeTracker {
|
|
|
109
112
|
reshape(newShape: number[]): ShapeTracker;
|
|
110
113
|
/** Broadcast along the given new axes, then expand the shape. */
|
|
111
114
|
broadcast(newShape: number[], axis: number[]): ShapeTracker;
|
|
115
|
+
/**
|
|
116
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
117
|
+
*
|
|
118
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
119
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
120
|
+
*/
|
|
121
|
+
repeat(reps: number[], tile?: boolean): ShapeTracker;
|
|
122
|
+
/** Move axis i to axis j. */
|
|
123
|
+
moveaxis(i: number, j: number): ShapeTracker;
|
|
124
|
+
/** Like pad(), but allows for negative values. */
|
|
125
|
+
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
112
126
|
}
|
|
113
127
|
//#endregion
|
|
114
128
|
//#region src/utils.d.ts
|
|
@@ -133,13 +147,16 @@ declare class FpHash {
|
|
|
133
147
|
/** Run a function while caching it inline inside a `Map`. */
|
|
134
148
|
//#endregion
|
|
135
149
|
//#region src/alu.d.ts
|
|
150
|
+
/** A numerical data type for array contents. */
|
|
136
151
|
declare enum DType {
|
|
137
152
|
Float32 = "float32",
|
|
138
153
|
Int32 = "int32",
|
|
139
154
|
Uint32 = "uint32",
|
|
140
155
|
Bool = "bool",
|
|
141
|
-
|
|
156
|
+
Float16 = "float16",
|
|
142
157
|
}
|
|
158
|
+
/** @inline */
|
|
159
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
143
160
|
/**
|
|
144
161
|
* Mathematical expression on scalar values.
|
|
145
162
|
*
|
|
@@ -165,6 +182,7 @@ declare class AluExp implements FpHashable {
|
|
|
165
182
|
static cos(a: AluExp): AluExp;
|
|
166
183
|
static exp(a: AluExp): AluExp;
|
|
167
184
|
static log(a: AluExp): AluExp;
|
|
185
|
+
static sqrt(a: AluExp): AluExp;
|
|
168
186
|
static reciprocal(a: AluExp): AluExp;
|
|
169
187
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
170
188
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -175,12 +193,13 @@ declare class AluExp implements FpHashable {
|
|
|
175
193
|
static const(dtype: DType, value: any): AluExp;
|
|
176
194
|
static special(dtype: DType, name: string, n: number): AluExp;
|
|
177
195
|
static variable(dtype: DType, name: string): AluExp;
|
|
178
|
-
static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
|
|
196
|
+
static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
|
|
179
197
|
static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
|
|
198
|
+
static f32(value: number): AluExp;
|
|
180
199
|
static i32(value: number): AluExp;
|
|
181
200
|
static u32(value: number): AluExp;
|
|
182
|
-
static f32(value: number): AluExp;
|
|
183
201
|
static bool(value: boolean): AluExp;
|
|
202
|
+
static f16(value: number): AluExp;
|
|
184
203
|
not(): AluExp;
|
|
185
204
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
186
205
|
getHash(): bigint;
|
|
@@ -191,6 +210,19 @@ declare class AluExp implements FpHashable {
|
|
|
191
210
|
reindexGids(gidMap: Map<number, number>): AluExp;
|
|
192
211
|
get min(): number;
|
|
193
212
|
get max(): number;
|
|
213
|
+
/** Largest known integer that divides self. */
|
|
214
|
+
constFactor(): number;
|
|
215
|
+
/**
|
|
216
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
217
|
+
* `null` if it's not divisible.
|
|
218
|
+
*/
|
|
219
|
+
divides(v: number): AluExp | null;
|
|
220
|
+
/**
|
|
221
|
+
* Get all expressions by deeply matching an operation.
|
|
222
|
+
*
|
|
223
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
224
|
+
*/
|
|
225
|
+
splitOp(sep: AluOp): IterableIterator<AluExp>;
|
|
194
226
|
/**
|
|
195
227
|
* Simplify the expression by replacing any known patterns and deduping
|
|
196
228
|
* identical subexpressions.
|
|
@@ -211,10 +243,16 @@ declare class AluExp implements FpHashable {
|
|
|
211
243
|
toString(): string;
|
|
212
244
|
/** Generic fold() operation with a reducer over the expression tree. */
|
|
213
245
|
fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
|
|
246
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
247
|
+
some(predicate: (exp: AluExp) => boolean): boolean;
|
|
214
248
|
/** Rewrite the expression recursively using a visitor. */
|
|
215
249
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
216
250
|
/** Collect all nodes that satisfy a predicate. */
|
|
217
251
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
252
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
253
|
+
distinctOps(): Set<AluOp>;
|
|
254
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
255
|
+
rewriteGlobalViews(): AluExp;
|
|
218
256
|
}
|
|
219
257
|
/** Symbolic form for each mathematical operation. */
|
|
220
258
|
declare enum AluOp {
|
|
@@ -229,6 +267,7 @@ declare enum AluOp {
|
|
|
229
267
|
Cos = "Cos",
|
|
230
268
|
Exp = "Exp",
|
|
231
269
|
Log = "Log",
|
|
270
|
+
Sqrt = "Sqrt",
|
|
232
271
|
Reciprocal = "Reciprocal",
|
|
233
272
|
Cast = "Cast",
|
|
234
273
|
Bitcast = "Bitcast",
|
|
@@ -245,7 +284,7 @@ declare enum AluOp {
|
|
|
245
284
|
Variable = "Variable",
|
|
246
285
|
// arg = variable
|
|
247
286
|
GlobalIndex = "GlobalIndex",
|
|
248
|
-
// arg = gid; src = [bufidx]
|
|
287
|
+
// arg = [gid, len]; src = [bufidx]
|
|
249
288
|
GlobalView = "GlobalView",
|
|
250
289
|
}
|
|
251
290
|
/**
|
|
@@ -300,12 +339,12 @@ declare class Reduction implements FpHashable {
|
|
|
300
339
|
/** Size of the reduction axis. */
|
|
301
340
|
readonly size: number;
|
|
302
341
|
/** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
303
|
-
readonly
|
|
342
|
+
readonly epilogue: AluExp;
|
|
304
343
|
constructor(/** Data type of the values being reduced over. */
|
|
305
344
|
dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
|
|
306
345
|
op: AluOp, /** Size of the reduction axis. */
|
|
307
346
|
size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
308
|
-
|
|
347
|
+
epilogue?: AluExp);
|
|
309
348
|
hash(state: FpHash): void;
|
|
310
349
|
toString(): string;
|
|
311
350
|
/** Get the identity for this reduction operation. */
|
|
@@ -316,7 +355,7 @@ declare class Reduction implements FpHashable {
|
|
|
316
355
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
317
356
|
//#endregion
|
|
318
357
|
//#region src/backend.d.ts
|
|
319
|
-
type Device = "cpu" | "webgpu";
|
|
358
|
+
type Device = "cpu" | "wasm" | "webgpu";
|
|
320
359
|
declare const devices: Device[];
|
|
321
360
|
/** Set the default device backend (must be initialized). */
|
|
322
361
|
declare function setDevice(device: Device): void;
|
|
@@ -339,7 +378,7 @@ interface Backend {
|
|
|
339
378
|
/** Maximum number of arguments per dispatched kernel. */
|
|
340
379
|
readonly maxArgs: number;
|
|
341
380
|
/** Allocate a new slot with reference count 1. */
|
|
342
|
-
malloc(size: number, initialData?:
|
|
381
|
+
malloc(size: number, initialData?: Uint8Array): Slot;
|
|
343
382
|
/** Increment the reference count of the slot. */
|
|
344
383
|
incRef(slot: Slot): void;
|
|
345
384
|
/**
|
|
@@ -348,9 +387,9 @@ interface Backend {
|
|
|
348
387
|
*/
|
|
349
388
|
decRef(slot: Slot): void;
|
|
350
389
|
/** Read a range of bytes from a buffer. */
|
|
351
|
-
read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer
|
|
390
|
+
read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
|
|
352
391
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
353
|
-
readSync(slot: Slot, start?: number, count?: number): ArrayBuffer
|
|
392
|
+
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
354
393
|
/** Prepare an expression to be executed later. */
|
|
355
394
|
prepare(kernel: Kernel): Promise<Executable>;
|
|
356
395
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
@@ -372,18 +411,22 @@ declare class Executable<T = any> {
|
|
|
372
411
|
data: T);
|
|
373
412
|
}
|
|
374
413
|
declare namespace tree_d_exports {
|
|
375
|
-
export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
|
|
414
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
376
415
|
}
|
|
377
416
|
declare enum NodeType {
|
|
378
417
|
Array = "Array",
|
|
379
418
|
Object = "Object",
|
|
380
419
|
Leaf = "Leaf",
|
|
381
420
|
}
|
|
421
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
382
422
|
type JsTree<T> = T | JsTree<T>[] | {
|
|
383
423
|
[key: string]: JsTree<T>;
|
|
384
424
|
};
|
|
385
|
-
type
|
|
386
|
-
|
|
425
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
426
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
427
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
428
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
429
|
+
/** Represents the structure of a JsTree. */
|
|
387
430
|
declare class JsTreeDef {
|
|
388
431
|
readonly nodeType: NodeType;
|
|
389
432
|
readonly nodeMetadata: any;
|
|
@@ -409,6 +452,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
|
|
|
409
452
|
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
410
453
|
/** Take a reference of every array in a tree. */
|
|
411
454
|
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
455
|
+
/** Dispose every array in a tree. */
|
|
456
|
+
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
457
|
+
//#endregion
|
|
458
|
+
//#region src/frontend/convolution.d.ts
|
|
459
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
460
|
+
interface ConvParams {
|
|
461
|
+
strides: number[];
|
|
462
|
+
padding: [number, number][];
|
|
463
|
+
lhsDilation: number[];
|
|
464
|
+
rhsDilation: number[];
|
|
465
|
+
}
|
|
466
|
+
/**
|
|
467
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
468
|
+
*
|
|
469
|
+
* If the check succeeds, returns the output shape.
|
|
470
|
+
*/
|
|
471
|
+
|
|
412
472
|
//#endregion
|
|
413
473
|
//#region src/frontend/core.d.ts
|
|
414
474
|
/**
|
|
@@ -436,10 +496,16 @@ declare enum Primitive {
|
|
|
436
496
|
Cos = "cos",
|
|
437
497
|
Exp = "exp",
|
|
438
498
|
Log = "log",
|
|
499
|
+
Sqrt = "sqrt",
|
|
439
500
|
Min = "min",
|
|
440
501
|
Max = "max",
|
|
441
502
|
Reduce = "reduce",
|
|
442
503
|
Dot = "dot",
|
|
504
|
+
// sum(x*y, axis=-1)
|
|
505
|
+
Conv = "conv",
|
|
506
|
+
// see lax.conv_general_dilated
|
|
507
|
+
Pool = "pool",
|
|
508
|
+
PoolTranspose = "pool_transpose",
|
|
443
509
|
Compare = "compare",
|
|
444
510
|
Where = "where",
|
|
445
511
|
Transpose = "transpose",
|
|
@@ -451,7 +517,7 @@ declare enum Primitive {
|
|
|
451
517
|
Gather = "gather",
|
|
452
518
|
JitCall = "jit_call",
|
|
453
519
|
}
|
|
454
|
-
interface PrimitiveParamsImpl extends Record<Primitive, Record<string,
|
|
520
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
455
521
|
[Primitive.Cast]: {
|
|
456
522
|
dtype: DType;
|
|
457
523
|
};
|
|
@@ -462,6 +528,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
462
528
|
op: AluOp;
|
|
463
529
|
axis: number[];
|
|
464
530
|
};
|
|
531
|
+
[Primitive.Conv]: ConvParams;
|
|
532
|
+
[Primitive.Pool]: {
|
|
533
|
+
window: number[];
|
|
534
|
+
strides: number[];
|
|
535
|
+
};
|
|
536
|
+
[Primitive.PoolTranspose]: {
|
|
537
|
+
inShape: number[];
|
|
538
|
+
window: number[];
|
|
539
|
+
strides: number[];
|
|
540
|
+
};
|
|
465
541
|
[Primitive.Compare]: {
|
|
466
542
|
op: CompareOp;
|
|
467
543
|
};
|
|
@@ -483,10 +559,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
483
559
|
axis: number[];
|
|
484
560
|
};
|
|
485
561
|
[Primitive.Shrink]: {
|
|
486
|
-
slice: [
|
|
562
|
+
slice: Pair[];
|
|
487
563
|
};
|
|
488
564
|
[Primitive.Pad]: {
|
|
489
|
-
width: [
|
|
565
|
+
width: Pair[];
|
|
490
566
|
};
|
|
491
567
|
[Primitive.Gather]: {
|
|
492
568
|
axis: number[];
|
|
@@ -587,6 +663,7 @@ declare abstract class Tracer {
|
|
|
587
663
|
*/
|
|
588
664
|
abstract dispose(): void;
|
|
589
665
|
get shape(): number[];
|
|
666
|
+
get size(): number;
|
|
590
667
|
get dtype(): DType;
|
|
591
668
|
get ndim(): number;
|
|
592
669
|
/** @ignore */
|
|
@@ -673,7 +750,7 @@ declare abstract class Tracer {
|
|
|
673
750
|
* the "gather" primitive, and it allows you to access specific elements of
|
|
674
751
|
* the array by integer indices stored in another array.
|
|
675
752
|
*/
|
|
676
|
-
slice(...index: (number | [] | [number] |
|
|
753
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
677
754
|
}
|
|
678
755
|
declare class ShapedArray implements AbstractValue {
|
|
679
756
|
readonly shape: number[];
|
|
@@ -681,7 +758,7 @@ declare class ShapedArray implements AbstractValue {
|
|
|
681
758
|
constructor(shape: number[], dtype: DType);
|
|
682
759
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
683
760
|
get ndim(): number;
|
|
684
|
-
|
|
761
|
+
toString(): string;
|
|
685
762
|
equals(other: ShapedArray): boolean;
|
|
686
763
|
}
|
|
687
764
|
//#endregion
|
|
@@ -752,14 +829,23 @@ declare class Array extends Tracer {
|
|
|
752
829
|
*/
|
|
753
830
|
[Symbol.toPrimitive](): any;
|
|
754
831
|
/** Realize the array and return it as data. */
|
|
755
|
-
data(): Promise<
|
|
756
|
-
/**
|
|
757
|
-
|
|
832
|
+
data(): Promise<DataArray>;
|
|
833
|
+
/**
|
|
834
|
+
* Wait for this array to finish evaluation.
|
|
835
|
+
*
|
|
836
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
837
|
+
* that pending operations are dispatched and fully executed before it
|
|
838
|
+
* returns.
|
|
839
|
+
*
|
|
840
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
841
|
+
* dispatch of operations as well.
|
|
842
|
+
*/
|
|
843
|
+
wait(): Promise<Array>;
|
|
758
844
|
/**
|
|
759
845
|
* Realize the array and return it as data. This is a sync variant and not
|
|
760
846
|
* recommended for performance reasons, as it will block rendering.
|
|
761
847
|
*/
|
|
762
|
-
dataSync():
|
|
848
|
+
dataSync(): DataArray;
|
|
763
849
|
/**
|
|
764
850
|
* Convert this array into a JavaScript object.
|
|
765
851
|
*
|
|
@@ -772,6 +858,13 @@ declare class Array extends Tracer {
|
|
|
772
858
|
js(): any;
|
|
773
859
|
/** Convert this array into a JavaScript object, asynchronously. */
|
|
774
860
|
jsAsync(): Promise<any>;
|
|
861
|
+
/**
|
|
862
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
863
|
+
*
|
|
864
|
+
* Throws an error if the array does not have a single element. The array must
|
|
865
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
866
|
+
*/
|
|
867
|
+
item(): number;
|
|
775
868
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
776
869
|
static _implRules(): typeof implRules;
|
|
777
870
|
_realizeSource(): number;
|
|
@@ -782,7 +875,7 @@ declare function scalar(value: number | boolean, {
|
|
|
782
875
|
device
|
|
783
876
|
}?: DTypeAndDevice): Array;
|
|
784
877
|
/** Constructor for creating a new array from data. */
|
|
785
|
-
declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
878
|
+
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
786
879
|
shape,
|
|
787
880
|
dtype,
|
|
788
881
|
device
|
|
@@ -854,15 +947,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
854
947
|
dtype,
|
|
855
948
|
device
|
|
856
949
|
}?: DTypeAndDevice): Array;
|
|
857
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
858
950
|
declare namespace numpy_d_exports {
|
|
859
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack,
|
|
951
|
+
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
860
952
|
}
|
|
861
953
|
declare const float32 = DType.Float32;
|
|
862
954
|
declare const int32 = DType.Int32;
|
|
863
955
|
declare const uint32 = DType.Uint32;
|
|
864
956
|
declare const bool = DType.Bool;
|
|
865
|
-
declare const
|
|
957
|
+
declare const float16 = DType.Float16;
|
|
866
958
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
867
959
|
declare const e: number;
|
|
868
960
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -889,6 +981,8 @@ declare const cos: (x: ArrayLike) => Array;
|
|
|
889
981
|
declare const exp: (x: ArrayLike) => Array;
|
|
890
982
|
/** Calculate the natural logarithm of all elements in the input array. */
|
|
891
983
|
declare const log: (x: ArrayLike) => Array;
|
|
984
|
+
/** Calculate the square root of all elements in the input array. */
|
|
985
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
892
986
|
/** Return element-wise minimum of the input arrays. */
|
|
893
987
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
894
988
|
/** Return element-wise maximum of the input arrays. */
|
|
@@ -930,6 +1024,12 @@ declare const pad: (x: ArrayLike, width: number | [number, number] | [number, nu
|
|
|
930
1024
|
declare const ndim: (x: ArrayLike) => number;
|
|
931
1025
|
/** Return the shape of an array. Does not consume array reference. */
|
|
932
1026
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1027
|
+
/** Return an array of zeros with the same shape and type as a given array. */
|
|
1028
|
+
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1029
|
+
/** Return an array of ones with the same shape and type as a given array. */
|
|
1030
|
+
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1031
|
+
/** Return a full array with the same shape and type as a given array. */
|
|
1032
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
933
1033
|
/**
|
|
934
1034
|
* Return the number of elements in an array, optionally along an axis.
|
|
935
1035
|
* Does not consume array reference.
|
|
@@ -1013,9 +1113,11 @@ declare function ravel(a: ArrayLike): Array;
|
|
|
1013
1113
|
* Return specified diagonals.
|
|
1014
1114
|
*
|
|
1015
1115
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
1016
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
1116
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
1017
1117
|
*
|
|
1018
|
-
* This returns a view over the existing array.
|
|
1118
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
1119
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
1120
|
+
* then appending a new axis to the right with holding the diagonals.
|
|
1019
1121
|
*/
|
|
1020
1122
|
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1021
1123
|
/**
|
|
@@ -1087,6 +1189,24 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1087
1189
|
declare function log2(x: ArrayLike): Array;
|
|
1088
1190
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1089
1191
|
declare function log10(x: ArrayLike): Array;
|
|
1192
|
+
/**
|
|
1193
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
1194
|
+
*
|
|
1195
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1196
|
+
*/
|
|
1197
|
+
declare function sinh(x: ArrayLike): Array;
|
|
1198
|
+
/**
|
|
1199
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
1200
|
+
*
|
|
1201
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1202
|
+
*/
|
|
1203
|
+
declare function cosh(x: ArrayLike): Array;
|
|
1204
|
+
/**
|
|
1205
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
1206
|
+
*
|
|
1207
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1208
|
+
*/
|
|
1209
|
+
declare function tanh(x: ArrayLike): Array;
|
|
1090
1210
|
//#endregion
|
|
1091
1211
|
//#region src/frontend/jaxpr.d.ts
|
|
1092
1212
|
/** Variable in a Jaxpr expression. */
|
|
@@ -1149,8 +1269,38 @@ declare class Jaxpr implements FpHashable {
|
|
|
1149
1269
|
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1150
1270
|
flatten(): Jaxpr;
|
|
1151
1271
|
}
|
|
1272
|
+
/** @inline */
|
|
1273
|
+
type JitOpts = {
|
|
1274
|
+
staticArgnums?: number[];
|
|
1275
|
+
device?: Device;
|
|
1276
|
+
};
|
|
1277
|
+
declare namespace lax_d_exports {
|
|
1278
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1279
|
+
}
|
|
1280
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1281
|
+
/**
|
|
1282
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
1283
|
+
*
|
|
1284
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
1285
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
1286
|
+
*
|
|
1287
|
+
* Grouped convolutions are not supported right now.
|
|
1288
|
+
*/
|
|
1289
|
+
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1290
|
+
lhsDilation,
|
|
1291
|
+
rhsDilation
|
|
1292
|
+
}?: {
|
|
1293
|
+
lhsDilation?: number[];
|
|
1294
|
+
rhsDilation?: number[];
|
|
1295
|
+
}): Array;
|
|
1296
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1297
|
+
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
1298
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1299
|
+
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1300
|
+
/** Reduce a computation over padded windows. */
|
|
1301
|
+
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1152
1302
|
declare namespace nn_d_exports {
|
|
1153
|
-
export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1303
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1154
1304
|
}
|
|
1155
1305
|
/**
|
|
1156
1306
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1190,7 +1340,7 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1190
1340
|
*
|
|
1191
1341
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1192
1342
|
*/
|
|
1193
|
-
declare
|
|
1343
|
+
declare const silu: (x: ArrayLike) => Array;
|
|
1194
1344
|
/**
|
|
1195
1345
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1196
1346
|
* Swish, computed element-wise:
|
|
@@ -1200,7 +1350,7 @@ declare function silu(x: ArrayLike): Array;
|
|
|
1200
1350
|
*
|
|
1201
1351
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1202
1352
|
*/
|
|
1203
|
-
declare const swish:
|
|
1353
|
+
declare const swish: (x: ArrayLike) => Array;
|
|
1204
1354
|
/**
|
|
1205
1355
|
* Log-sigmoid activation function, computed element-wise:
|
|
1206
1356
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1208,6 +1358,48 @@ declare const swish: typeof silu;
|
|
|
1208
1358
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1209
1359
|
/** Identity activation function. Returns the argument unmodified. */
|
|
1210
1360
|
declare const identity: (x: ArrayLike) => Array;
|
|
1361
|
+
/** Leaky rectified linear (ReLU) activation function */
|
|
1362
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
|
|
1363
|
+
/**
|
|
1364
|
+
* Exponential linear unit activation function.
|
|
1365
|
+
*
|
|
1366
|
+
* Computes the element-wise function:
|
|
1367
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1368
|
+
*/
|
|
1369
|
+
declare function elu(x: ArrayLike, alpha?: number): Array;
|
|
1370
|
+
/**
|
|
1371
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
1372
|
+
*
|
|
1373
|
+
* Computes the element-wise function:
|
|
1374
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1375
|
+
*/
|
|
1376
|
+
declare function celu(x: ArrayLike, alpha?: number): Array;
|
|
1377
|
+
/**
|
|
1378
|
+
* Gaussion error linear unit (GELU) activation function.
|
|
1379
|
+
*
|
|
1380
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
1381
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
1382
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1383
|
+
*
|
|
1384
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1385
|
+
*
|
|
1386
|
+
* This will be improved in the future.
|
|
1387
|
+
*/
|
|
1388
|
+
declare const gelu: (x: ArrayLike) => Array;
|
|
1389
|
+
/**
|
|
1390
|
+
* Gated linear unit (GLU) activation function.
|
|
1391
|
+
*
|
|
1392
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
1393
|
+
* computes `a * sigmoid(b)`.
|
|
1394
|
+
*/
|
|
1395
|
+
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1396
|
+
/**
|
|
1397
|
+
* Mish activation function.
|
|
1398
|
+
*
|
|
1399
|
+
* Computes the element-wise function:
|
|
1400
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
1401
|
+
*/
|
|
1402
|
+
declare function mish(x: ArrayLike): Array;
|
|
1211
1403
|
/**
|
|
1212
1404
|
* Softmax function. Computes the function which rescales elements to the range
|
|
1213
1405
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -1271,38 +1463,49 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1271
1463
|
}): Array;
|
|
1272
1464
|
//#endregion
|
|
1273
1465
|
//#region src/index.d.ts
|
|
1274
|
-
/** @inline */
|
|
1275
|
-
type WithArgsSubtype<F extends (args: any[]) => any, T> = Parameters<F> extends T ? F : never;
|
|
1276
1466
|
/** Compute the forward-mode Jacobian-vector product for a function. */
|
|
1277
|
-
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1467
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1278
1468
|
/** Vectorize an operation on a batched axis for one or more inputs. */
|
|
1279
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1469
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1280
1470
|
/** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
|
|
1281
|
-
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>,
|
|
1471
|
+
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1282
1472
|
/** Construct a Jaxpr by dynamically tracing a function with example inputs. */
|
|
1283
|
-
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1473
|
+
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1284
1474
|
jaxpr: Jaxpr;
|
|
1285
1475
|
consts: Array[];
|
|
1286
1476
|
treedef: JsTreeDef;
|
|
1287
1477
|
};
|
|
1288
|
-
|
|
1478
|
+
/**
|
|
1479
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1480
|
+
*
|
|
1481
|
+
* The function will be compiled the first time it is called with a set of
|
|
1482
|
+
* argument shapes.
|
|
1483
|
+
*
|
|
1484
|
+
* **Options:**
|
|
1485
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1486
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
1487
|
+
* and different values will trigger recompilation.
|
|
1488
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
1489
|
+
* computation will be placed on the default device.
|
|
1490
|
+
*/
|
|
1491
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1289
1492
|
/**
|
|
1290
1493
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1291
1494
|
* partial evaluation.
|
|
1292
1495
|
*/
|
|
1293
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1496
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1294
1497
|
/** Calculate the reverse-mode vector-Jacobian product for a function. */
|
|
1295
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1498
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1296
1499
|
/**
|
|
1297
1500
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1298
1501
|
* first argument.
|
|
1299
1502
|
*/
|
|
1300
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1503
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1301
1504
|
/** Create a function that evaluates both `f` and the gradient of `f`. */
|
|
1302
|
-
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1505
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1303
1506
|
/** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
|
|
1304
1507
|
declare const jacrev: typeof jacfwd;
|
|
1305
1508
|
/** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
|
|
1306
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>,
|
|
1509
|
+
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1307
1510
|
//#endregion
|
|
1308
|
-
export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1511
|
+
export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|