@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
package/dist/index.d.ts
CHANGED
|
@@ -20,6 +20,7 @@ interface Stringable {
|
|
|
20
20
|
//#endregion
|
|
21
21
|
//#region src/shape.d.ts
|
|
22
22
|
|
|
23
|
+
/** @inline */
|
|
23
24
|
type Pair = [number, number];
|
|
24
25
|
/**
|
|
25
26
|
* A multidimensional view into memory. An array can be thought of as the
|
|
@@ -46,6 +47,8 @@ declare class View {
|
|
|
46
47
|
get size(): number;
|
|
47
48
|
/** Whether this is a default, contiguous, unaltered view of the data (identity). */
|
|
48
49
|
get contiguous(): boolean;
|
|
50
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
51
|
+
dataRange(): [number, number];
|
|
49
52
|
/** Produce an AluExp for evaluating this view at an index. */
|
|
50
53
|
toAluExp(idxs: AluExp[]): [AluExp, AluExp];
|
|
51
54
|
/**
|
|
@@ -106,6 +109,17 @@ declare class ShapeTracker {
|
|
|
106
109
|
reshape(newShape: number[]): ShapeTracker;
|
|
107
110
|
/** Broadcast along the given new axes, then expand the shape. */
|
|
108
111
|
broadcast(newShape: number[], axis: number[]): ShapeTracker;
|
|
112
|
+
/**
|
|
113
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
114
|
+
*
|
|
115
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
116
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
117
|
+
*/
|
|
118
|
+
repeat(reps: number[], tile?: boolean): ShapeTracker;
|
|
119
|
+
/** Move axis i to axis j. */
|
|
120
|
+
moveaxis(i: number, j: number): ShapeTracker;
|
|
121
|
+
/** Like pad(), but allows for negative values. */
|
|
122
|
+
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
109
123
|
}
|
|
110
124
|
//#endregion
|
|
111
125
|
//#region src/utils.d.ts
|
|
@@ -130,13 +144,16 @@ declare class FpHash {
|
|
|
130
144
|
/** Run a function while caching it inline inside a `Map`. */
|
|
131
145
|
//#endregion
|
|
132
146
|
//#region src/alu.d.ts
|
|
147
|
+
/** A numerical data type for array contents. */
|
|
133
148
|
declare enum DType {
|
|
134
149
|
Float32 = "float32",
|
|
135
150
|
Int32 = "int32",
|
|
136
151
|
Uint32 = "uint32",
|
|
137
152
|
Bool = "bool",
|
|
138
|
-
|
|
153
|
+
Float16 = "float16",
|
|
139
154
|
}
|
|
155
|
+
/** @inline */
|
|
156
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
140
157
|
/**
|
|
141
158
|
* Mathematical expression on scalar values.
|
|
142
159
|
*
|
|
@@ -162,6 +179,7 @@ declare class AluExp implements FpHashable {
|
|
|
162
179
|
static cos(a: AluExp): AluExp;
|
|
163
180
|
static exp(a: AluExp): AluExp;
|
|
164
181
|
static log(a: AluExp): AluExp;
|
|
182
|
+
static sqrt(a: AluExp): AluExp;
|
|
165
183
|
static reciprocal(a: AluExp): AluExp;
|
|
166
184
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
167
185
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -172,12 +190,13 @@ declare class AluExp implements FpHashable {
|
|
|
172
190
|
static const(dtype: DType, value: any): AluExp;
|
|
173
191
|
static special(dtype: DType, name: string, n: number): AluExp;
|
|
174
192
|
static variable(dtype: DType, name: string): AluExp;
|
|
175
|
-
static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
|
|
193
|
+
static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
|
|
176
194
|
static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
|
|
195
|
+
static f32(value: number): AluExp;
|
|
177
196
|
static i32(value: number): AluExp;
|
|
178
197
|
static u32(value: number): AluExp;
|
|
179
|
-
static f32(value: number): AluExp;
|
|
180
198
|
static bool(value: boolean): AluExp;
|
|
199
|
+
static f16(value: number): AluExp;
|
|
181
200
|
not(): AluExp;
|
|
182
201
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
183
202
|
getHash(): bigint;
|
|
@@ -188,6 +207,19 @@ declare class AluExp implements FpHashable {
|
|
|
188
207
|
reindexGids(gidMap: Map<number, number>): AluExp;
|
|
189
208
|
get min(): number;
|
|
190
209
|
get max(): number;
|
|
210
|
+
/** Largest known integer that divides self. */
|
|
211
|
+
constFactor(): number;
|
|
212
|
+
/**
|
|
213
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
214
|
+
* `null` if it's not divisible.
|
|
215
|
+
*/
|
|
216
|
+
divides(v: number): AluExp | null;
|
|
217
|
+
/**
|
|
218
|
+
* Get all expressions by deeply matching an operation.
|
|
219
|
+
*
|
|
220
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
221
|
+
*/
|
|
222
|
+
splitOp(sep: AluOp): IterableIterator<AluExp>;
|
|
191
223
|
/**
|
|
192
224
|
* Simplify the expression by replacing any known patterns and deduping
|
|
193
225
|
* identical subexpressions.
|
|
@@ -208,10 +240,16 @@ declare class AluExp implements FpHashable {
|
|
|
208
240
|
toString(): string;
|
|
209
241
|
/** Generic fold() operation with a reducer over the expression tree. */
|
|
210
242
|
fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
|
|
243
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
244
|
+
some(predicate: (exp: AluExp) => boolean): boolean;
|
|
211
245
|
/** Rewrite the expression recursively using a visitor. */
|
|
212
246
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
213
247
|
/** Collect all nodes that satisfy a predicate. */
|
|
214
248
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
249
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
250
|
+
distinctOps(): Set<AluOp>;
|
|
251
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
252
|
+
rewriteGlobalViews(): AluExp;
|
|
215
253
|
}
|
|
216
254
|
/** Symbolic form for each mathematical operation. */
|
|
217
255
|
declare enum AluOp {
|
|
@@ -226,6 +264,7 @@ declare enum AluOp {
|
|
|
226
264
|
Cos = "Cos",
|
|
227
265
|
Exp = "Exp",
|
|
228
266
|
Log = "Log",
|
|
267
|
+
Sqrt = "Sqrt",
|
|
229
268
|
Reciprocal = "Reciprocal",
|
|
230
269
|
Cast = "Cast",
|
|
231
270
|
Bitcast = "Bitcast",
|
|
@@ -242,7 +281,7 @@ declare enum AluOp {
|
|
|
242
281
|
Variable = "Variable",
|
|
243
282
|
// arg = variable
|
|
244
283
|
GlobalIndex = "GlobalIndex",
|
|
245
|
-
// arg = gid; src = [bufidx]
|
|
284
|
+
// arg = [gid, len]; src = [bufidx]
|
|
246
285
|
GlobalView = "GlobalView",
|
|
247
286
|
}
|
|
248
287
|
/**
|
|
@@ -297,12 +336,12 @@ declare class Reduction implements FpHashable {
|
|
|
297
336
|
/** Size of the reduction axis. */
|
|
298
337
|
readonly size: number;
|
|
299
338
|
/** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
300
|
-
readonly
|
|
339
|
+
readonly epilogue: AluExp;
|
|
301
340
|
constructor(/** Data type of the values being reduced over. */
|
|
302
341
|
dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
|
|
303
342
|
op: AluOp, /** Size of the reduction axis. */
|
|
304
343
|
size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
305
|
-
|
|
344
|
+
epilogue?: AluExp);
|
|
306
345
|
hash(state: FpHash): void;
|
|
307
346
|
toString(): string;
|
|
308
347
|
/** Get the identity for this reduction operation. */
|
|
@@ -313,7 +352,7 @@ declare class Reduction implements FpHashable {
|
|
|
313
352
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
314
353
|
//#endregion
|
|
315
354
|
//#region src/backend.d.ts
|
|
316
|
-
type Device = "cpu" | "webgpu";
|
|
355
|
+
type Device = "cpu" | "wasm" | "webgpu";
|
|
317
356
|
declare const devices: Device[];
|
|
318
357
|
/** Set the default device backend (must be initialized). */
|
|
319
358
|
declare function setDevice(device: Device): void;
|
|
@@ -336,7 +375,7 @@ interface Backend {
|
|
|
336
375
|
/** Maximum number of arguments per dispatched kernel. */
|
|
337
376
|
readonly maxArgs: number;
|
|
338
377
|
/** Allocate a new slot with reference count 1. */
|
|
339
|
-
malloc(size: number, initialData?:
|
|
378
|
+
malloc(size: number, initialData?: Uint8Array): Slot;
|
|
340
379
|
/** Increment the reference count of the slot. */
|
|
341
380
|
incRef(slot: Slot): void;
|
|
342
381
|
/**
|
|
@@ -345,9 +384,9 @@ interface Backend {
|
|
|
345
384
|
*/
|
|
346
385
|
decRef(slot: Slot): void;
|
|
347
386
|
/** Read a range of bytes from a buffer. */
|
|
348
|
-
read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer
|
|
387
|
+
read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
|
|
349
388
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
350
|
-
readSync(slot: Slot, start?: number, count?: number): ArrayBuffer
|
|
389
|
+
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
351
390
|
/** Prepare an expression to be executed later. */
|
|
352
391
|
prepare(kernel: Kernel): Promise<Executable>;
|
|
353
392
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
@@ -369,18 +408,22 @@ declare class Executable<T = any> {
|
|
|
369
408
|
data: T);
|
|
370
409
|
}
|
|
371
410
|
declare namespace tree_d_exports {
|
|
372
|
-
export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
|
|
411
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
373
412
|
}
|
|
374
413
|
declare enum NodeType {
|
|
375
414
|
Array = "Array",
|
|
376
415
|
Object = "Object",
|
|
377
416
|
Leaf = "Leaf",
|
|
378
417
|
}
|
|
418
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
379
419
|
type JsTree<T> = T | JsTree<T>[] | {
|
|
380
420
|
[key: string]: JsTree<T>;
|
|
381
421
|
};
|
|
382
|
-
type
|
|
383
|
-
|
|
422
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
423
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
424
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
425
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
426
|
+
/** Represents the structure of a JsTree. */
|
|
384
427
|
declare class JsTreeDef {
|
|
385
428
|
readonly nodeType: NodeType;
|
|
386
429
|
readonly nodeMetadata: any;
|
|
@@ -406,6 +449,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
|
|
|
406
449
|
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
407
450
|
/** Take a reference of every array in a tree. */
|
|
408
451
|
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
452
|
+
/** Dispose every array in a tree. */
|
|
453
|
+
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
454
|
+
//#endregion
|
|
455
|
+
//#region src/frontend/convolution.d.ts
|
|
456
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
457
|
+
interface ConvParams {
|
|
458
|
+
strides: number[];
|
|
459
|
+
padding: [number, number][];
|
|
460
|
+
lhsDilation: number[];
|
|
461
|
+
rhsDilation: number[];
|
|
462
|
+
}
|
|
463
|
+
/**
|
|
464
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
465
|
+
*
|
|
466
|
+
* If the check succeeds, returns the output shape.
|
|
467
|
+
*/
|
|
468
|
+
|
|
409
469
|
//#endregion
|
|
410
470
|
//#region src/frontend/core.d.ts
|
|
411
471
|
/**
|
|
@@ -433,10 +493,16 @@ declare enum Primitive {
|
|
|
433
493
|
Cos = "cos",
|
|
434
494
|
Exp = "exp",
|
|
435
495
|
Log = "log",
|
|
496
|
+
Sqrt = "sqrt",
|
|
436
497
|
Min = "min",
|
|
437
498
|
Max = "max",
|
|
438
499
|
Reduce = "reduce",
|
|
439
500
|
Dot = "dot",
|
|
501
|
+
// sum(x*y, axis=-1)
|
|
502
|
+
Conv = "conv",
|
|
503
|
+
// see lax.conv_general_dilated
|
|
504
|
+
Pool = "pool",
|
|
505
|
+
PoolTranspose = "pool_transpose",
|
|
440
506
|
Compare = "compare",
|
|
441
507
|
Where = "where",
|
|
442
508
|
Transpose = "transpose",
|
|
@@ -448,7 +514,7 @@ declare enum Primitive {
|
|
|
448
514
|
Gather = "gather",
|
|
449
515
|
JitCall = "jit_call",
|
|
450
516
|
}
|
|
451
|
-
interface PrimitiveParamsImpl extends Record<Primitive, Record<string,
|
|
517
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
452
518
|
[Primitive.Cast]: {
|
|
453
519
|
dtype: DType;
|
|
454
520
|
};
|
|
@@ -459,6 +525,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
459
525
|
op: AluOp;
|
|
460
526
|
axis: number[];
|
|
461
527
|
};
|
|
528
|
+
[Primitive.Conv]: ConvParams;
|
|
529
|
+
[Primitive.Pool]: {
|
|
530
|
+
window: number[];
|
|
531
|
+
strides: number[];
|
|
532
|
+
};
|
|
533
|
+
[Primitive.PoolTranspose]: {
|
|
534
|
+
inShape: number[];
|
|
535
|
+
window: number[];
|
|
536
|
+
strides: number[];
|
|
537
|
+
};
|
|
462
538
|
[Primitive.Compare]: {
|
|
463
539
|
op: CompareOp;
|
|
464
540
|
};
|
|
@@ -480,10 +556,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
480
556
|
axis: number[];
|
|
481
557
|
};
|
|
482
558
|
[Primitive.Shrink]: {
|
|
483
|
-
slice: [
|
|
559
|
+
slice: Pair[];
|
|
484
560
|
};
|
|
485
561
|
[Primitive.Pad]: {
|
|
486
|
-
width: [
|
|
562
|
+
width: Pair[];
|
|
487
563
|
};
|
|
488
564
|
[Primitive.Gather]: {
|
|
489
565
|
axis: number[];
|
|
@@ -584,6 +660,7 @@ declare abstract class Tracer {
|
|
|
584
660
|
*/
|
|
585
661
|
abstract dispose(): void;
|
|
586
662
|
get shape(): number[];
|
|
663
|
+
get size(): number;
|
|
587
664
|
get dtype(): DType;
|
|
588
665
|
get ndim(): number;
|
|
589
666
|
/** @ignore */
|
|
@@ -670,7 +747,7 @@ declare abstract class Tracer {
|
|
|
670
747
|
* the "gather" primitive, and it allows you to access specific elements of
|
|
671
748
|
* the array by integer indices stored in another array.
|
|
672
749
|
*/
|
|
673
|
-
slice(...index: (number | [] | [number] |
|
|
750
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
674
751
|
}
|
|
675
752
|
declare class ShapedArray implements AbstractValue {
|
|
676
753
|
readonly shape: number[];
|
|
@@ -678,7 +755,7 @@ declare class ShapedArray implements AbstractValue {
|
|
|
678
755
|
constructor(shape: number[], dtype: DType);
|
|
679
756
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
680
757
|
get ndim(): number;
|
|
681
|
-
|
|
758
|
+
toString(): string;
|
|
682
759
|
equals(other: ShapedArray): boolean;
|
|
683
760
|
}
|
|
684
761
|
//#endregion
|
|
@@ -749,14 +826,23 @@ declare class Array extends Tracer {
|
|
|
749
826
|
*/
|
|
750
827
|
[Symbol.toPrimitive](): any;
|
|
751
828
|
/** Realize the array and return it as data. */
|
|
752
|
-
data(): Promise<
|
|
753
|
-
/**
|
|
754
|
-
|
|
829
|
+
data(): Promise<DataArray>;
|
|
830
|
+
/**
|
|
831
|
+
* Wait for this array to finish evaluation.
|
|
832
|
+
*
|
|
833
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
834
|
+
* that pending operations are dispatched and fully executed before it
|
|
835
|
+
* returns.
|
|
836
|
+
*
|
|
837
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
838
|
+
* dispatch of operations as well.
|
|
839
|
+
*/
|
|
840
|
+
wait(): Promise<Array>;
|
|
755
841
|
/**
|
|
756
842
|
* Realize the array and return it as data. This is a sync variant and not
|
|
757
843
|
* recommended for performance reasons, as it will block rendering.
|
|
758
844
|
*/
|
|
759
|
-
dataSync():
|
|
845
|
+
dataSync(): DataArray;
|
|
760
846
|
/**
|
|
761
847
|
* Convert this array into a JavaScript object.
|
|
762
848
|
*
|
|
@@ -769,6 +855,13 @@ declare class Array extends Tracer {
|
|
|
769
855
|
js(): any;
|
|
770
856
|
/** Convert this array into a JavaScript object, asynchronously. */
|
|
771
857
|
jsAsync(): Promise<any>;
|
|
858
|
+
/**
|
|
859
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
860
|
+
*
|
|
861
|
+
* Throws an error if the array does not have a single element. The array must
|
|
862
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
863
|
+
*/
|
|
864
|
+
item(): number;
|
|
772
865
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
773
866
|
static _implRules(): typeof implRules;
|
|
774
867
|
_realizeSource(): number;
|
|
@@ -779,7 +872,7 @@ declare function scalar(value: number | boolean, {
|
|
|
779
872
|
device
|
|
780
873
|
}?: DTypeAndDevice): Array;
|
|
781
874
|
/** Constructor for creating a new array from data. */
|
|
782
|
-
declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
875
|
+
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
783
876
|
shape,
|
|
784
877
|
dtype,
|
|
785
878
|
device
|
|
@@ -851,15 +944,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
851
944
|
dtype,
|
|
852
945
|
device
|
|
853
946
|
}?: DTypeAndDevice): Array;
|
|
854
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
855
947
|
declare namespace numpy_d_exports {
|
|
856
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack,
|
|
948
|
+
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
857
949
|
}
|
|
858
950
|
declare const float32 = DType.Float32;
|
|
859
951
|
declare const int32 = DType.Int32;
|
|
860
952
|
declare const uint32 = DType.Uint32;
|
|
861
953
|
declare const bool = DType.Bool;
|
|
862
|
-
declare const
|
|
954
|
+
declare const float16 = DType.Float16;
|
|
863
955
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
864
956
|
declare const e: number;
|
|
865
957
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -886,6 +978,8 @@ declare const cos: (x: ArrayLike) => Array;
|
|
|
886
978
|
declare const exp: (x: ArrayLike) => Array;
|
|
887
979
|
/** Calculate the natural logarithm of all elements in the input array. */
|
|
888
980
|
declare const log: (x: ArrayLike) => Array;
|
|
981
|
+
/** Calculate the square root of all elements in the input array. */
|
|
982
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
889
983
|
/** Return element-wise minimum of the input arrays. */
|
|
890
984
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
891
985
|
/** Return element-wise maximum of the input arrays. */
|
|
@@ -927,6 +1021,12 @@ declare const pad: (x: ArrayLike, width: number | [number, number] | [number, nu
|
|
|
927
1021
|
declare const ndim: (x: ArrayLike) => number;
|
|
928
1022
|
/** Return the shape of an array. Does not consume array reference. */
|
|
929
1023
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1024
|
+
/** Return an array of zeros with the same shape and type as a given array. */
|
|
1025
|
+
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1026
|
+
/** Return an array of ones with the same shape and type as a given array. */
|
|
1027
|
+
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1028
|
+
/** Return a full array with the same shape and type as a given array. */
|
|
1029
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
930
1030
|
/**
|
|
931
1031
|
* Return the number of elements in an array, optionally along an axis.
|
|
932
1032
|
* Does not consume array reference.
|
|
@@ -1010,9 +1110,11 @@ declare function ravel(a: ArrayLike): Array;
|
|
|
1010
1110
|
* Return specified diagonals.
|
|
1011
1111
|
*
|
|
1012
1112
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
1013
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
1113
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
1014
1114
|
*
|
|
1015
|
-
* This returns a view over the existing array.
|
|
1115
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
1116
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
1117
|
+
* then appending a new axis to the right with holding the diagonals.
|
|
1016
1118
|
*/
|
|
1017
1119
|
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1018
1120
|
/**
|
|
@@ -1084,6 +1186,24 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1084
1186
|
declare function log2(x: ArrayLike): Array;
|
|
1085
1187
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1086
1188
|
declare function log10(x: ArrayLike): Array;
|
|
1189
|
+
/**
|
|
1190
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
1191
|
+
*
|
|
1192
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1193
|
+
*/
|
|
1194
|
+
declare function sinh(x: ArrayLike): Array;
|
|
1195
|
+
/**
|
|
1196
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
1197
|
+
*
|
|
1198
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1199
|
+
*/
|
|
1200
|
+
declare function cosh(x: ArrayLike): Array;
|
|
1201
|
+
/**
|
|
1202
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
1203
|
+
*
|
|
1204
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1205
|
+
*/
|
|
1206
|
+
declare function tanh(x: ArrayLike): Array;
|
|
1087
1207
|
//#endregion
|
|
1088
1208
|
//#region src/frontend/jaxpr.d.ts
|
|
1089
1209
|
/** Variable in a Jaxpr expression. */
|
|
@@ -1146,8 +1266,38 @@ declare class Jaxpr implements FpHashable {
|
|
|
1146
1266
|
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1147
1267
|
flatten(): Jaxpr;
|
|
1148
1268
|
}
|
|
1269
|
+
/** @inline */
|
|
1270
|
+
type JitOpts = {
|
|
1271
|
+
staticArgnums?: number[];
|
|
1272
|
+
device?: Device;
|
|
1273
|
+
};
|
|
1274
|
+
declare namespace lax_d_exports {
|
|
1275
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1276
|
+
}
|
|
1277
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1278
|
+
/**
|
|
1279
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
1280
|
+
*
|
|
1281
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
1282
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
1283
|
+
*
|
|
1284
|
+
* Grouped convolutions are not supported right now.
|
|
1285
|
+
*/
|
|
1286
|
+
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1287
|
+
lhsDilation,
|
|
1288
|
+
rhsDilation
|
|
1289
|
+
}?: {
|
|
1290
|
+
lhsDilation?: number[];
|
|
1291
|
+
rhsDilation?: number[];
|
|
1292
|
+
}): Array;
|
|
1293
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1294
|
+
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
1295
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1296
|
+
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1297
|
+
/** Reduce a computation over padded windows. */
|
|
1298
|
+
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1149
1299
|
declare namespace nn_d_exports {
|
|
1150
|
-
export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1300
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1151
1301
|
}
|
|
1152
1302
|
/**
|
|
1153
1303
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1187,7 +1337,7 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1187
1337
|
*
|
|
1188
1338
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1189
1339
|
*/
|
|
1190
|
-
declare
|
|
1340
|
+
declare const silu: (x: ArrayLike) => Array;
|
|
1191
1341
|
/**
|
|
1192
1342
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1193
1343
|
* Swish, computed element-wise:
|
|
@@ -1197,7 +1347,7 @@ declare function silu(x: ArrayLike): Array;
|
|
|
1197
1347
|
*
|
|
1198
1348
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1199
1349
|
*/
|
|
1200
|
-
declare const swish:
|
|
1350
|
+
declare const swish: (x: ArrayLike) => Array;
|
|
1201
1351
|
/**
|
|
1202
1352
|
* Log-sigmoid activation function, computed element-wise:
|
|
1203
1353
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1205,6 +1355,48 @@ declare const swish: typeof silu;
|
|
|
1205
1355
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1206
1356
|
/** Identity activation function. Returns the argument unmodified. */
|
|
1207
1357
|
declare const identity: (x: ArrayLike) => Array;
|
|
1358
|
+
/** Leaky rectified linear (ReLU) activation function */
|
|
1359
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
|
|
1360
|
+
/**
|
|
1361
|
+
* Exponential linear unit activation function.
|
|
1362
|
+
*
|
|
1363
|
+
* Computes the element-wise function:
|
|
1364
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1365
|
+
*/
|
|
1366
|
+
declare function elu(x: ArrayLike, alpha?: number): Array;
|
|
1367
|
+
/**
|
|
1368
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
1369
|
+
*
|
|
1370
|
+
* Computes the element-wise function:
|
|
1371
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1372
|
+
*/
|
|
1373
|
+
declare function celu(x: ArrayLike, alpha?: number): Array;
|
|
1374
|
+
/**
|
|
1375
|
+
* Gaussion error linear unit (GELU) activation function.
|
|
1376
|
+
*
|
|
1377
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
1378
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
1379
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1380
|
+
*
|
|
1381
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1382
|
+
*
|
|
1383
|
+
* This will be improved in the future.
|
|
1384
|
+
*/
|
|
1385
|
+
declare const gelu: (x: ArrayLike) => Array;
|
|
1386
|
+
/**
|
|
1387
|
+
* Gated linear unit (GLU) activation function.
|
|
1388
|
+
*
|
|
1389
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
1390
|
+
* computes `a * sigmoid(b)`.
|
|
1391
|
+
*/
|
|
1392
|
+
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1393
|
+
/**
|
|
1394
|
+
* Mish activation function.
|
|
1395
|
+
*
|
|
1396
|
+
* Computes the element-wise function:
|
|
1397
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
1398
|
+
*/
|
|
1399
|
+
declare function mish(x: ArrayLike): Array;
|
|
1208
1400
|
/**
|
|
1209
1401
|
* Softmax function. Computes the function which rescales elements to the range
|
|
1210
1402
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -1268,38 +1460,49 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1268
1460
|
}): Array;
|
|
1269
1461
|
//#endregion
|
|
1270
1462
|
//#region src/index.d.ts
|
|
1271
|
-
/** @inline */
|
|
1272
|
-
type WithArgsSubtype<F extends (args: any[]) => any, T> = Parameters<F> extends T ? F : never;
|
|
1273
1463
|
/** Compute the forward-mode Jacobian-vector product for a function. */
|
|
1274
|
-
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1464
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1275
1465
|
/** Vectorize an operation on a batched axis for one or more inputs. */
|
|
1276
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1466
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1277
1467
|
/** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
|
|
1278
|
-
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>,
|
|
1468
|
+
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1279
1469
|
/** Construct a Jaxpr by dynamically tracing a function with example inputs. */
|
|
1280
|
-
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1470
|
+
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1281
1471
|
jaxpr: Jaxpr;
|
|
1282
1472
|
consts: Array[];
|
|
1283
1473
|
treedef: JsTreeDef;
|
|
1284
1474
|
};
|
|
1285
|
-
|
|
1475
|
+
/**
|
|
1476
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1477
|
+
*
|
|
1478
|
+
* The function will be compiled the first time it is called with a set of
|
|
1479
|
+
* argument shapes.
|
|
1480
|
+
*
|
|
1481
|
+
* **Options:**
|
|
1482
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1483
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
1484
|
+
* and different values will trigger recompilation.
|
|
1485
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
1486
|
+
* computation will be placed on the default device.
|
|
1487
|
+
*/
|
|
1488
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1286
1489
|
/**
|
|
1287
1490
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1288
1491
|
* partial evaluation.
|
|
1289
1492
|
*/
|
|
1290
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1493
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1291
1494
|
/** Calculate the reverse-mode vector-Jacobian product for a function. */
|
|
1292
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1495
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1293
1496
|
/**
|
|
1294
1497
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1295
1498
|
* first argument.
|
|
1296
1499
|
*/
|
|
1297
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1500
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1298
1501
|
/** Create a function that evaluates both `f` and the gradient of `f`. */
|
|
1299
|
-
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1502
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1300
1503
|
/** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
|
|
1301
1504
|
declare const jacrev: typeof jacfwd;
|
|
1302
1505
|
/** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
|
|
1303
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>,
|
|
1506
|
+
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1304
1507
|
//#endregion
|
|
1305
|
-
export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|
|
1508
|
+
export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|