@jax-js/jax 0.0.2 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +57 -25
- package/dist/backend-EBRGmEYw.js +3816 -0
- package/dist/{backend-BK21PBVP.cjs → backend-Ss1Mev_-.cjs} +2075 -107
- package/dist/index.cjs +1393 -250
- package/dist/index.d.cts +651 -102
- package/dist/index.d.ts +651 -102
- package/dist/index.js +1377 -245
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-BVdMaO9T.cjs} +62 -35
- package/dist/{webgpu-JVpVad6g.js → webgpu-ow0Pn_6q.js} +62 -35
- package/package.json +21 -9
- package/dist/backend-1eVbAoaV.js +0 -1890
package/dist/index.d.ts
CHANGED
|
@@ -20,6 +20,7 @@ interface Stringable {
|
|
|
20
20
|
//#endregion
|
|
21
21
|
//#region src/shape.d.ts
|
|
22
22
|
|
|
23
|
+
/** @inline */
|
|
23
24
|
type Pair = [number, number];
|
|
24
25
|
/**
|
|
25
26
|
* A multidimensional view into memory. An array can be thought of as the
|
|
@@ -46,6 +47,8 @@ declare class View {
|
|
|
46
47
|
get size(): number;
|
|
47
48
|
/** Whether this is a default, contiguous, unaltered view of the data (identity). */
|
|
48
49
|
get contiguous(): boolean;
|
|
50
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
51
|
+
dataRange(): [number, number];
|
|
49
52
|
/** Produce an AluExp for evaluating this view at an index. */
|
|
50
53
|
toAluExp(idxs: AluExp[]): [AluExp, AluExp];
|
|
51
54
|
/**
|
|
@@ -106,9 +109,33 @@ declare class ShapeTracker {
|
|
|
106
109
|
reshape(newShape: number[]): ShapeTracker;
|
|
107
110
|
/** Broadcast along the given new axes, then expand the shape. */
|
|
108
111
|
broadcast(newShape: number[], axis: number[]): ShapeTracker;
|
|
112
|
+
/**
|
|
113
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
114
|
+
*
|
|
115
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
116
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
117
|
+
*/
|
|
118
|
+
repeat(reps: number[], tile?: boolean): ShapeTracker;
|
|
119
|
+
/** Move axis i to axis j. */
|
|
120
|
+
moveaxis(i: number, j: number): ShapeTracker;
|
|
121
|
+
/** Like pad(), but allows for negative values. */
|
|
122
|
+
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
109
123
|
}
|
|
110
124
|
//#endregion
|
|
111
125
|
//#region src/utils.d.ts
|
|
126
|
+
/**
|
|
127
|
+
* Set the debug level for verbose logging.
|
|
128
|
+
*
|
|
129
|
+
* 1. JIT compile logs
|
|
130
|
+
* 2. Shader code
|
|
131
|
+
* 3. Expressions and metadata
|
|
132
|
+
* 4. JIT programs, tuning details
|
|
133
|
+
* 5. Most verbose operation traces
|
|
134
|
+
*
|
|
135
|
+
* This is an experimental API and may change in behavior. Do not rely on this
|
|
136
|
+
* in production.
|
|
137
|
+
*/
|
|
138
|
+
declare function setDebug(level: number): void;
|
|
112
139
|
/** @inline */
|
|
113
140
|
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
114
141
|
interface FpHashable {
|
|
@@ -124,19 +151,47 @@ interface FpHashable {
|
|
|
124
151
|
declare class FpHash {
|
|
125
152
|
#private;
|
|
126
153
|
value: bigint;
|
|
127
|
-
update(
|
|
154
|
+
update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
|
|
128
155
|
static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
|
|
129
156
|
}
|
|
130
157
|
/** Run a function while caching it inline inside a `Map`. */
|
|
131
158
|
//#endregion
|
|
132
159
|
//#region src/alu.d.ts
|
|
160
|
+
/** A numerical data type for array contents. */
|
|
133
161
|
declare enum DType {
|
|
134
162
|
Float32 = "float32",
|
|
135
163
|
Int32 = "int32",
|
|
136
164
|
Uint32 = "uint32",
|
|
137
165
|
Bool = "bool",
|
|
138
|
-
|
|
166
|
+
Float16 = "float16",
|
|
139
167
|
}
|
|
168
|
+
/** @inline */
|
|
169
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
170
|
+
/**
|
|
171
|
+
* Promote two dtypes to their join according to the type lattice.
|
|
172
|
+
*
|
|
173
|
+
* When performing operations between arrays of different types, we need to
|
|
174
|
+
* promote both operands to a common type that can represent values from both
|
|
175
|
+
* input types. This follows JAX's type promotion rules.
|
|
176
|
+
*
|
|
177
|
+
* **Type lattice:**
|
|
178
|
+
* ```text
|
|
179
|
+
* bool -> uint32 -> int32 -> float16 -> float32
|
|
180
|
+
* weak f* --^
|
|
181
|
+
* ```
|
|
182
|
+
*
|
|
183
|
+
* The asterisk f* is a weak type used for JS number constants. When creating
|
|
184
|
+
* arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
|
|
185
|
+
* any array they are first combined with.
|
|
186
|
+
*
|
|
187
|
+
* **Examples:**
|
|
188
|
+
* - `promoteTypes(bool, int32) → int32`
|
|
189
|
+
* - `promoteTypes(uint32, int32) → int32`
|
|
190
|
+
* - `promoteTypes(int32, float16) → float16`
|
|
191
|
+
* - `promoteTypes(float16, float32) → float32`
|
|
192
|
+
* - `promoteTypes(uint32, float32) → float32`
|
|
193
|
+
*/
|
|
194
|
+
declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
|
|
140
195
|
/**
|
|
141
196
|
* Mathematical expression on scalar values.
|
|
142
197
|
*
|
|
@@ -160,8 +215,11 @@ declare class AluExp implements FpHashable {
|
|
|
160
215
|
static max(a: AluExp, b: AluExp): AluExp;
|
|
161
216
|
static sin(a: AluExp): AluExp;
|
|
162
217
|
static cos(a: AluExp): AluExp;
|
|
218
|
+
static asin(a: AluExp): AluExp;
|
|
219
|
+
static atan(a: AluExp): AluExp;
|
|
163
220
|
static exp(a: AluExp): AluExp;
|
|
164
221
|
static log(a: AluExp): AluExp;
|
|
222
|
+
static sqrt(a: AluExp): AluExp;
|
|
165
223
|
static reciprocal(a: AluExp): AluExp;
|
|
166
224
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
167
225
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -172,12 +230,13 @@ declare class AluExp implements FpHashable {
|
|
|
172
230
|
static const(dtype: DType, value: any): AluExp;
|
|
173
231
|
static special(dtype: DType, name: string, n: number): AluExp;
|
|
174
232
|
static variable(dtype: DType, name: string): AluExp;
|
|
175
|
-
static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
|
|
233
|
+
static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
|
|
176
234
|
static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
|
|
235
|
+
static f32(value: number): AluExp;
|
|
177
236
|
static i32(value: number): AluExp;
|
|
178
237
|
static u32(value: number): AluExp;
|
|
179
|
-
static f32(value: number): AluExp;
|
|
180
238
|
static bool(value: boolean): AluExp;
|
|
239
|
+
static f16(value: number): AluExp;
|
|
181
240
|
not(): AluExp;
|
|
182
241
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
183
242
|
getHash(): bigint;
|
|
@@ -188,6 +247,19 @@ declare class AluExp implements FpHashable {
|
|
|
188
247
|
reindexGids(gidMap: Map<number, number>): AluExp;
|
|
189
248
|
get min(): number;
|
|
190
249
|
get max(): number;
|
|
250
|
+
/** Largest known integer that divides self. */
|
|
251
|
+
constFactor(): number;
|
|
252
|
+
/**
|
|
253
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
254
|
+
* `null` if it's not divisible.
|
|
255
|
+
*/
|
|
256
|
+
divides(v: number): AluExp | null;
|
|
257
|
+
/**
|
|
258
|
+
* Get all expressions by deeply matching an operation.
|
|
259
|
+
*
|
|
260
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
261
|
+
*/
|
|
262
|
+
splitOp(sep: AluOp): IterableIterator<AluExp>;
|
|
191
263
|
/**
|
|
192
264
|
* Simplify the expression by replacing any known patterns and deduping
|
|
193
265
|
* identical subexpressions.
|
|
@@ -208,10 +280,16 @@ declare class AluExp implements FpHashable {
|
|
|
208
280
|
toString(): string;
|
|
209
281
|
/** Generic fold() operation with a reducer over the expression tree. */
|
|
210
282
|
fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
|
|
283
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
284
|
+
some(predicate: (exp: AluExp) => boolean): boolean;
|
|
211
285
|
/** Rewrite the expression recursively using a visitor. */
|
|
212
286
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
213
287
|
/** Collect all nodes that satisfy a predicate. */
|
|
214
288
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
289
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
290
|
+
distinctOps(): Set<AluOp>;
|
|
291
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
292
|
+
rewriteGlobalViews(): AluExp;
|
|
215
293
|
}
|
|
216
294
|
/** Symbolic form for each mathematical operation. */
|
|
217
295
|
declare enum AluOp {
|
|
@@ -224,8 +302,11 @@ declare enum AluOp {
|
|
|
224
302
|
Max = "Max",
|
|
225
303
|
Sin = "Sin",
|
|
226
304
|
Cos = "Cos",
|
|
305
|
+
Asin = "Asin",
|
|
306
|
+
Atan = "Atan",
|
|
227
307
|
Exp = "Exp",
|
|
228
308
|
Log = "Log",
|
|
309
|
+
Sqrt = "Sqrt",
|
|
229
310
|
Reciprocal = "Reciprocal",
|
|
230
311
|
Cast = "Cast",
|
|
231
312
|
Bitcast = "Bitcast",
|
|
@@ -242,7 +323,7 @@ declare enum AluOp {
|
|
|
242
323
|
Variable = "Variable",
|
|
243
324
|
// arg = variable
|
|
244
325
|
GlobalIndex = "GlobalIndex",
|
|
245
|
-
// arg = gid; src = [bufidx]
|
|
326
|
+
// arg = [gid, len]; src = [bufidx]
|
|
246
327
|
GlobalView = "GlobalView",
|
|
247
328
|
}
|
|
248
329
|
/**
|
|
@@ -297,12 +378,12 @@ declare class Reduction implements FpHashable {
|
|
|
297
378
|
/** Size of the reduction axis. */
|
|
298
379
|
readonly size: number;
|
|
299
380
|
/** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
300
|
-
readonly
|
|
381
|
+
readonly epilogue: AluExp;
|
|
301
382
|
constructor(/** Data type of the values being reduced over. */
|
|
302
383
|
dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
|
|
303
384
|
op: AluOp, /** Size of the reduction axis. */
|
|
304
385
|
size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
305
|
-
|
|
386
|
+
epilogue?: AluExp);
|
|
306
387
|
hash(state: FpHash): void;
|
|
307
388
|
toString(): string;
|
|
308
389
|
/** Get the identity for this reduction operation. */
|
|
@@ -313,10 +394,10 @@ declare class Reduction implements FpHashable {
|
|
|
313
394
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
314
395
|
//#endregion
|
|
315
396
|
//#region src/backend.d.ts
|
|
316
|
-
type Device = "cpu" | "webgpu";
|
|
397
|
+
type Device = "cpu" | "wasm" | "webgpu";
|
|
317
398
|
declare const devices: Device[];
|
|
318
|
-
/**
|
|
319
|
-
declare function
|
|
399
|
+
/** Configure the default device for arrays. */
|
|
400
|
+
declare function defaultDevice(device?: Device): Device;
|
|
320
401
|
/**
|
|
321
402
|
* Initialize `jax-js` library backends.
|
|
322
403
|
*
|
|
@@ -336,7 +417,7 @@ interface Backend {
|
|
|
336
417
|
/** Maximum number of arguments per dispatched kernel. */
|
|
337
418
|
readonly maxArgs: number;
|
|
338
419
|
/** Allocate a new slot with reference count 1. */
|
|
339
|
-
malloc(size: number, initialData?:
|
|
420
|
+
malloc(size: number, initialData?: Uint8Array): Slot;
|
|
340
421
|
/** Increment the reference count of the slot. */
|
|
341
422
|
incRef(slot: Slot): void;
|
|
342
423
|
/**
|
|
@@ -345,9 +426,9 @@ interface Backend {
|
|
|
345
426
|
*/
|
|
346
427
|
decRef(slot: Slot): void;
|
|
347
428
|
/** Read a range of bytes from a buffer. */
|
|
348
|
-
read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer
|
|
429
|
+
read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
|
|
349
430
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
350
|
-
readSync(slot: Slot, start?: number, count?: number): ArrayBuffer
|
|
431
|
+
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
351
432
|
/** Prepare an expression to be executed later. */
|
|
352
433
|
prepare(kernel: Kernel): Promise<Executable>;
|
|
353
434
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
@@ -369,18 +450,22 @@ declare class Executable<T = any> {
|
|
|
369
450
|
data: T);
|
|
370
451
|
}
|
|
371
452
|
declare namespace tree_d_exports {
|
|
372
|
-
export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
|
|
453
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
373
454
|
}
|
|
374
455
|
declare enum NodeType {
|
|
375
456
|
Array = "Array",
|
|
376
457
|
Object = "Object",
|
|
377
458
|
Leaf = "Leaf",
|
|
378
459
|
}
|
|
460
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
379
461
|
type JsTree<T> = T | JsTree<T>[] | {
|
|
380
462
|
[key: string]: JsTree<T>;
|
|
381
463
|
};
|
|
382
|
-
type
|
|
383
|
-
|
|
464
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
465
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
466
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
467
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
468
|
+
/** Represents the structure of a JsTree. */
|
|
384
469
|
declare class JsTreeDef {
|
|
385
470
|
readonly nodeType: NodeType;
|
|
386
471
|
readonly nodeMetadata: any;
|
|
@@ -406,6 +491,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
|
|
|
406
491
|
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
407
492
|
/** Take a reference of every array in a tree. */
|
|
408
493
|
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
494
|
+
/** Dispose every array in a tree. */
|
|
495
|
+
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
496
|
+
//#endregion
|
|
497
|
+
//#region src/frontend/convolution.d.ts
|
|
498
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
499
|
+
interface ConvParams {
|
|
500
|
+
strides: number[];
|
|
501
|
+
padding: [number, number][];
|
|
502
|
+
lhsDilation: number[];
|
|
503
|
+
rhsDilation: number[];
|
|
504
|
+
}
|
|
505
|
+
/**
|
|
506
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
507
|
+
*
|
|
508
|
+
* If the check succeeds, returns the output shape.
|
|
509
|
+
*/
|
|
510
|
+
|
|
409
511
|
//#endregion
|
|
410
512
|
//#region src/frontend/core.d.ts
|
|
411
513
|
/**
|
|
@@ -431,12 +533,20 @@ declare enum Primitive {
|
|
|
431
533
|
RandomBits = "random_bits",
|
|
432
534
|
Sin = "sin",
|
|
433
535
|
Cos = "cos",
|
|
536
|
+
Asin = "asin",
|
|
537
|
+
Atan = "atan",
|
|
434
538
|
Exp = "exp",
|
|
435
539
|
Log = "log",
|
|
540
|
+
Sqrt = "sqrt",
|
|
436
541
|
Min = "min",
|
|
437
542
|
Max = "max",
|
|
438
543
|
Reduce = "reduce",
|
|
439
544
|
Dot = "dot",
|
|
545
|
+
// sum(x*y, axis=-1)
|
|
546
|
+
Conv = "conv",
|
|
547
|
+
// see lax.conv_general_dilated
|
|
548
|
+
Pool = "pool",
|
|
549
|
+
PoolTranspose = "pool_transpose",
|
|
440
550
|
Compare = "compare",
|
|
441
551
|
Where = "where",
|
|
442
552
|
Transpose = "transpose",
|
|
@@ -448,7 +558,7 @@ declare enum Primitive {
|
|
|
448
558
|
Gather = "gather",
|
|
449
559
|
JitCall = "jit_call",
|
|
450
560
|
}
|
|
451
|
-
interface PrimitiveParamsImpl extends Record<Primitive, Record<string,
|
|
561
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
452
562
|
[Primitive.Cast]: {
|
|
453
563
|
dtype: DType;
|
|
454
564
|
};
|
|
@@ -459,6 +569,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
459
569
|
op: AluOp;
|
|
460
570
|
axis: number[];
|
|
461
571
|
};
|
|
572
|
+
[Primitive.Conv]: ConvParams;
|
|
573
|
+
[Primitive.Pool]: {
|
|
574
|
+
window: number[];
|
|
575
|
+
strides: number[];
|
|
576
|
+
};
|
|
577
|
+
[Primitive.PoolTranspose]: {
|
|
578
|
+
inShape: number[];
|
|
579
|
+
window: number[];
|
|
580
|
+
strides: number[];
|
|
581
|
+
};
|
|
462
582
|
[Primitive.Compare]: {
|
|
463
583
|
op: CompareOp;
|
|
464
584
|
};
|
|
@@ -480,10 +600,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
480
600
|
axis: number[];
|
|
481
601
|
};
|
|
482
602
|
[Primitive.Shrink]: {
|
|
483
|
-
slice: [
|
|
603
|
+
slice: Pair[];
|
|
484
604
|
};
|
|
485
605
|
[Primitive.Pad]: {
|
|
486
|
-
width: [
|
|
606
|
+
width: Pair[];
|
|
487
607
|
};
|
|
488
608
|
[Primitive.Gather]: {
|
|
489
609
|
axis: number[];
|
|
@@ -505,8 +625,10 @@ declare enum CompareOp {
|
|
|
505
625
|
LessEqual = "less_equal",
|
|
506
626
|
}
|
|
507
627
|
/** @inline */
|
|
628
|
+
type Axis = number | number[] | null;
|
|
629
|
+
/** @inline */
|
|
508
630
|
type ReduceOpts = {
|
|
509
|
-
|
|
631
|
+
keepdims?: boolean;
|
|
510
632
|
};
|
|
511
633
|
type MainTrace = {
|
|
512
634
|
level: number;
|
|
@@ -583,8 +705,13 @@ declare abstract class Tracer {
|
|
|
583
705
|
* ```
|
|
584
706
|
*/
|
|
585
707
|
abstract dispose(): void;
|
|
708
|
+
/** The shape of the array. */
|
|
586
709
|
get shape(): number[];
|
|
710
|
+
/** The total number of elements in the array. */
|
|
711
|
+
get size(): number;
|
|
712
|
+
/** The dtype of the array. */
|
|
587
713
|
get dtype(): DType;
|
|
714
|
+
/** The number of dimensions of the array. */
|
|
588
715
|
get ndim(): number;
|
|
589
716
|
/** @ignore */
|
|
590
717
|
fullLower(): Tracer;
|
|
@@ -598,11 +725,11 @@ declare abstract class Tracer {
|
|
|
598
725
|
greaterEqual(other: this | TracerValue): this;
|
|
599
726
|
lessEqual(other: this | TracerValue): this;
|
|
600
727
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
601
|
-
sum(axis?:
|
|
728
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
602
729
|
/** Product of the array elements over a given axis. */
|
|
603
|
-
prod(axis?:
|
|
730
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
604
731
|
/** Compute the average of the array elements along the specified axis. */
|
|
605
|
-
mean(axis?:
|
|
732
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
606
733
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
607
734
|
transpose(perm?: number[]): this;
|
|
608
735
|
/**
|
|
@@ -670,7 +797,7 @@ declare abstract class Tracer {
|
|
|
670
797
|
* the "gather" primitive, and it allows you to access specific elements of
|
|
671
798
|
* the array by integer indices stored in another array.
|
|
672
799
|
*/
|
|
673
|
-
slice(...index: (number | [] | [number] |
|
|
800
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
674
801
|
}
|
|
675
802
|
declare class ShapedArray implements AbstractValue {
|
|
676
803
|
readonly shape: number[];
|
|
@@ -678,7 +805,7 @@ declare class ShapedArray implements AbstractValue {
|
|
|
678
805
|
constructor(shape: number[], dtype: DType);
|
|
679
806
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
680
807
|
get ndim(): number;
|
|
681
|
-
|
|
808
|
+
toString(): string;
|
|
682
809
|
equals(other: ShapedArray): boolean;
|
|
683
810
|
}
|
|
684
811
|
//#endregion
|
|
@@ -730,7 +857,11 @@ declare class Array extends Tracer {
|
|
|
730
857
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
731
858
|
* will be freed when the array is disposed.
|
|
732
859
|
*/
|
|
733
|
-
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend,
|
|
860
|
+
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
|
|
861
|
+
pending
|
|
862
|
+
}?: {
|
|
863
|
+
pending?: Iterable<PendingExecute> | null;
|
|
864
|
+
});
|
|
734
865
|
/** @ignore */
|
|
735
866
|
get aval(): ShapedArray;
|
|
736
867
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -749,14 +880,26 @@ declare class Array extends Tracer {
|
|
|
749
880
|
*/
|
|
750
881
|
[Symbol.toPrimitive](): any;
|
|
751
882
|
/** Realize the array and return it as data. */
|
|
752
|
-
data(): Promise<
|
|
753
|
-
/**
|
|
754
|
-
|
|
883
|
+
data(): Promise<DataArray>;
|
|
884
|
+
/**
|
|
885
|
+
* Wait for this array to finish evaluation.
|
|
886
|
+
*
|
|
887
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
888
|
+
* that pending operations are dispatched and fully executed before it
|
|
889
|
+
* returns.
|
|
890
|
+
*
|
|
891
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
892
|
+
* dispatch of operations as well.
|
|
893
|
+
*
|
|
894
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
895
|
+
* asynchronously for multiple arrays.
|
|
896
|
+
*/
|
|
897
|
+
blockUntilReady(): Promise<Array>;
|
|
755
898
|
/**
|
|
756
899
|
* Realize the array and return it as data. This is a sync variant and not
|
|
757
900
|
* recommended for performance reasons, as it will block rendering.
|
|
758
901
|
*/
|
|
759
|
-
dataSync():
|
|
902
|
+
dataSync(): DataArray;
|
|
760
903
|
/**
|
|
761
904
|
* Convert this array into a JavaScript object.
|
|
762
905
|
*
|
|
@@ -769,17 +912,21 @@ declare class Array extends Tracer {
|
|
|
769
912
|
js(): any;
|
|
770
913
|
/** Convert this array into a JavaScript object, asynchronously. */
|
|
771
914
|
jsAsync(): Promise<any>;
|
|
915
|
+
/**
|
|
916
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
917
|
+
*
|
|
918
|
+
* Throws an error if the array does not have a single element. The array must
|
|
919
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
920
|
+
*/
|
|
921
|
+
item(): number;
|
|
772
922
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
773
923
|
static _implRules(): typeof implRules;
|
|
774
924
|
_realizeSource(): number;
|
|
775
925
|
}
|
|
776
926
|
/** Construct an array from a single scalar constant. */
|
|
777
|
-
|
|
778
|
-
dtype,
|
|
779
|
-
device
|
|
780
|
-
}?: DTypeAndDevice): Array;
|
|
927
|
+
|
|
781
928
|
/** Constructor for creating a new array from data. */
|
|
782
|
-
declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
929
|
+
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
783
930
|
shape,
|
|
784
931
|
dtype,
|
|
785
932
|
device
|
|
@@ -815,7 +962,7 @@ declare function eye(numRows: number, numCols?: number, {
|
|
|
815
962
|
dtype,
|
|
816
963
|
device
|
|
817
964
|
}?: DTypeAndDevice): Array;
|
|
818
|
-
/** Return the identity
|
|
965
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
819
966
|
declare function identity$1(n: number, {
|
|
820
967
|
dtype,
|
|
821
968
|
device
|
|
@@ -851,15 +998,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
851
998
|
dtype,
|
|
852
999
|
device
|
|
853
1000
|
}?: DTypeAndDevice): Array;
|
|
854
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
855
1001
|
declare namespace numpy_d_exports {
|
|
856
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack,
|
|
1002
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
857
1003
|
}
|
|
858
1004
|
declare const float32 = DType.Float32;
|
|
859
1005
|
declare const int32 = DType.Int32;
|
|
860
1006
|
declare const uint32 = DType.Uint32;
|
|
861
1007
|
declare const bool = DType.Bool;
|
|
862
|
-
declare const
|
|
1008
|
+
declare const float16 = DType.Float16;
|
|
863
1009
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
864
1010
|
declare const e: number;
|
|
865
1011
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -870,52 +1016,66 @@ declare const inf: number;
|
|
|
870
1016
|
declare const nan: number;
|
|
871
1017
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
872
1018
|
declare const pi: number;
|
|
873
|
-
/** Element-wise addition, with broadcasting. */
|
|
1019
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
874
1020
|
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
875
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
1021
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
876
1022
|
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
877
|
-
/** Numerical negative of every element of an array. */
|
|
1023
|
+
/** @function Numerical negative of every element of an array. */
|
|
878
1024
|
declare const negative: (x: ArrayLike) => Array;
|
|
879
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1025
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
880
1026
|
declare const reciprocal: (x: ArrayLike) => Array;
|
|
881
|
-
/** Element-wise sine function (takes radians). */
|
|
1027
|
+
/** @function Element-wise sine function (takes radians). */
|
|
882
1028
|
declare const sin: (x: ArrayLike) => Array;
|
|
883
|
-
/** Element-wise cosine function (takes radians). */
|
|
1029
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
884
1030
|
declare const cos: (x: ArrayLike) => Array;
|
|
885
|
-
/**
|
|
1031
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1032
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
1033
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1034
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
1035
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
886
1036
|
declare const exp: (x: ArrayLike) => Array;
|
|
887
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
1037
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
888
1038
|
declare const log: (x: ArrayLike) => Array;
|
|
889
|
-
/**
|
|
1039
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
1040
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
1041
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
890
1042
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
891
|
-
/** Return element-wise maximum of the input arrays. */
|
|
1043
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
892
1044
|
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
893
|
-
/** Compare two arrays element-wise. */
|
|
1045
|
+
/** @function Compare two arrays element-wise. */
|
|
894
1046
|
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
895
|
-
/** Compare two arrays element-wise. */
|
|
1047
|
+
/** @function Compare two arrays element-wise. */
|
|
896
1048
|
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
897
|
-
/** Compare two arrays element-wise. */
|
|
1049
|
+
/** @function Compare two arrays element-wise. */
|
|
898
1050
|
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
899
|
-
/** Compare two arrays element-wise. */
|
|
1051
|
+
/** @function Compare two arrays element-wise. */
|
|
900
1052
|
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
901
|
-
/** Compare two arrays element-wise. */
|
|
1053
|
+
/** @function Compare two arrays element-wise. */
|
|
902
1054
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
903
|
-
/** Compare two arrays element-wise. */
|
|
1055
|
+
/** @function Compare two arrays element-wise. */
|
|
904
1056
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
905
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1057
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
906
1058
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
907
|
-
/**
|
|
1059
|
+
/**
|
|
1060
|
+
* @function
|
|
1061
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
1062
|
+
*/
|
|
908
1063
|
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
909
1064
|
/**
|
|
1065
|
+
* @function
|
|
910
1066
|
* Give a new shape to an array without changing its data.
|
|
911
1067
|
*
|
|
912
1068
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
913
1069
|
* length of the array and remaining dimensions.
|
|
914
1070
|
*/
|
|
915
1071
|
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
916
|
-
/**
|
|
1072
|
+
/**
|
|
1073
|
+
* @function
|
|
1074
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1075
|
+
*/
|
|
917
1076
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
918
1077
|
/**
|
|
1078
|
+
* @function
|
|
919
1079
|
* Add padding (zeros) to an array.
|
|
920
1080
|
*
|
|
921
1081
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -923,10 +1083,28 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
923
1083
|
* pair specifies the padding for its corresponding axis.
|
|
924
1084
|
*/
|
|
925
1085
|
declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
|
|
926
|
-
/**
|
|
1086
|
+
/**
|
|
1087
|
+
* @function
|
|
1088
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
1089
|
+
*/
|
|
927
1090
|
declare const ndim: (x: ArrayLike) => number;
|
|
928
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
1091
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
929
1092
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1093
|
+
/**
|
|
1094
|
+
* @function
|
|
1095
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
1096
|
+
*/
|
|
1097
|
+
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1098
|
+
/**
|
|
1099
|
+
* @function
|
|
1100
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
1101
|
+
*/
|
|
1102
|
+
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1103
|
+
/**
|
|
1104
|
+
* @function
|
|
1105
|
+
* Return a full array with the same shape and type as a given array.
|
|
1106
|
+
*/
|
|
1107
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
930
1108
|
/**
|
|
931
1109
|
* Return the number of elements in an array, optionally along an axis.
|
|
932
1110
|
* Does not consume array reference.
|
|
@@ -935,15 +1113,15 @@ declare function size(a: ArrayLike, axis?: number): number;
|
|
|
935
1113
|
/** Convert an array to a specified dtype. */
|
|
936
1114
|
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
937
1115
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
938
|
-
declare function sum(a: ArrayLike, axis?:
|
|
1116
|
+
declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
939
1117
|
/** Product of the array elements over a given axis. */
|
|
940
|
-
declare function prod(a: ArrayLike, axis?:
|
|
1118
|
+
declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
941
1119
|
/** Return the minimum of array elements along a given axis. */
|
|
942
|
-
declare function min(a: ArrayLike, axis?:
|
|
1120
|
+
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
943
1121
|
/** Return the maximum of array elements along a given axis. */
|
|
944
|
-
declare function max(a: ArrayLike, axis?:
|
|
1122
|
+
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
945
1123
|
/** Compute the average of the array elements along the specified axis. */
|
|
946
|
-
declare function mean(a: ArrayLike, axis?:
|
|
1124
|
+
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
947
1125
|
/**
|
|
948
1126
|
* Returns the indices of the minimum values along an axis.
|
|
949
1127
|
*
|
|
@@ -959,7 +1137,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
959
1137
|
*/
|
|
960
1138
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
961
1139
|
/** Reverse the elements in an array along the given axes. */
|
|
962
|
-
declare function flip(x: ArrayLike, axis?:
|
|
1140
|
+
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
963
1141
|
/**
|
|
964
1142
|
* Join a sequence of arrays along an existing axis.
|
|
965
1143
|
*
|
|
@@ -1003,16 +1181,45 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
1003
1181
|
declare function flipud(x: ArrayLike): Array;
|
|
1004
1182
|
/** Flip an array horizontally (axis=1). */
|
|
1005
1183
|
declare function fliplr(x: ArrayLike): Array;
|
|
1184
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
1006
1185
|
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
|
|
1007
1186
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1008
1187
|
declare function ravel(a: ArrayLike): Array;
|
|
1188
|
+
/**
|
|
1189
|
+
* Repeat each element of an array consecutively, immediately after itself.
|
|
1190
|
+
*
|
|
1191
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
1192
|
+
* output array.
|
|
1193
|
+
*/
|
|
1194
|
+
declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
|
|
1195
|
+
/**
|
|
1196
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
1197
|
+
*
|
|
1198
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
1199
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
1200
|
+
* reps[1] * d2, ..., reps[n-1] * dn)`, with `A` tiled along each dimension.
|
|
1201
|
+
*/
|
|
1202
|
+
declare function tile(a: ArrayLike, reps: number | number[]): Array;
|
|
1203
|
+
/**
|
|
1204
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
1205
|
+
*
|
|
1206
|
+
* In other words, this lets you prepend new axes on the left, and/or expand
|
|
1207
|
+
* dimensions where the shape is 1.
|
|
1208
|
+
*/
|
|
1209
|
+
declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
|
|
1210
|
+
/** Broadcast input shapes to a common output shape. */
|
|
1211
|
+
declare function broadcastShapes(...shapes: number[][]): number[];
|
|
1212
|
+
/** Broadcast arrays to a common shape. */
|
|
1213
|
+
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1009
1214
|
/**
|
|
1010
1215
|
* Return specified diagonals.
|
|
1011
1216
|
*
|
|
1012
1217
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
1013
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
1218
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
1014
1219
|
*
|
|
1015
|
-
* This returns a view over the existing array.
|
|
1220
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
1221
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
1222
|
+
* then appending a new axis on the right that holds the diagonals.
|
|
1016
1223
|
*/
|
|
1017
1224
|
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1018
1225
|
/**
|
|
@@ -1031,8 +1238,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
|
|
|
1031
1238
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1032
1239
|
/** Dot product of two arrays. */
|
|
1033
1240
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1034
|
-
/**
|
|
1035
|
-
|
|
1241
|
+
/**
|
|
1242
|
+
* Compute the inner product of two arrays.
|
|
1243
|
+
*
|
|
1244
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
1245
|
+
* contraction on the last axis.
|
|
1246
|
+
*
|
|
1247
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
1248
|
+
*/
|
|
1249
|
+
declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
1250
|
+
/**
|
|
1251
|
+
* Compute the outer product of two arrays.
|
|
1252
|
+
*
|
|
1253
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
1254
|
+
* be of shape `[x.size, y.size]`.
|
|
1255
|
+
*/
|
|
1256
|
+
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
1257
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
1258
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
1259
|
+
axis
|
|
1260
|
+
}?: {
|
|
1261
|
+
axis?: number;
|
|
1262
|
+
}): Array;
|
|
1036
1263
|
/**
|
|
1037
1264
|
* Return the dot product of two vectors.
|
|
1038
1265
|
*
|
|
@@ -1050,6 +1277,21 @@ declare function meshgrid(xs: Array[], {
|
|
|
1050
1277
|
}?: {
|
|
1051
1278
|
indexing?: "xy" | "ij";
|
|
1052
1279
|
}): Array[];
|
|
1280
|
+
/**
|
|
1281
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1282
|
+
*
|
|
1283
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1284
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1285
|
+
* `k>0` is above it.
|
|
1286
|
+
*/
|
|
1287
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1288
|
+
dtype,
|
|
1289
|
+
device
|
|
1290
|
+
}?: DTypeAndDevice): Array;
|
|
1291
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1292
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1293
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1294
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1053
1295
|
/**
|
|
1054
1296
|
* Clip (limit) the values in an array.
|
|
1055
1297
|
*
|
|
@@ -1066,15 +1308,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
1066
1308
|
* This is the same function as `jax.numpy.abs()`.
|
|
1067
1309
|
*/
|
|
1068
1310
|
declare function absolute(x: ArrayLike): Array;
|
|
1069
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
1311
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
1070
1312
|
declare const abs: typeof absolute;
|
|
1313
|
+
/** Return an element-wise indication of sign of the input. */
|
|
1314
|
+
declare function sign(x: ArrayLike): Array;
|
|
1071
1315
|
/** Calculate element-wise square of the input array. */
|
|
1072
1316
|
declare function square(x: ArrayLike): Array;
|
|
1073
|
-
/**
|
|
1317
|
+
/** Element-wise tangent function (takes radians). */
|
|
1074
1318
|
declare function tan(x: ArrayLike): Array;
|
|
1319
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
1320
|
+
declare function acos(x: ArrayLike): Array;
|
|
1321
|
+
/**
|
|
1322
|
+
* @function
|
|
1323
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1324
|
+
*
|
|
1325
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1326
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
1327
|
+
* improvements.
|
|
1328
|
+
*/
|
|
1329
|
+
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1330
|
+
/**
|
|
1331
|
+
* @function
|
|
1332
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1333
|
+
*
|
|
1334
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1335
|
+
* The result is in the range [-π, π].
|
|
1336
|
+
*
|
|
1337
|
+
* Uses numerically stable formulas:
|
|
1338
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
1339
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
1340
|
+
*
|
|
1341
|
+
* The output is ill-defined when both x and y are zero.
|
|
1342
|
+
*/
|
|
1343
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1344
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
1345
|
+
declare const arccos: typeof acos;
|
|
1346
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
1347
|
+
declare const arctan: (x: ArrayLike) => Array;
|
|
1348
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
1349
|
+
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1350
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
1351
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1075
1352
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
1076
1353
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1077
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
1354
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
1078
1355
|
declare const divide: typeof trueDivide;
|
|
1079
1356
|
/** Round input to the nearest integer towards zero. */
|
|
1080
1357
|
declare function trunc(x: ArrayLike): Array;
|
|
@@ -1084,8 +1361,112 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1084
1361
|
declare function log2(x: ArrayLike): Array;
|
|
1085
1362
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1086
1363
|
declare function log10(x: ArrayLike): Array;
|
|
1364
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
1365
|
+
declare function expm1(x: ArrayLike): Array;
|
|
1366
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1367
|
+
declare function log1p(x: ArrayLike): Array;
|
|
1368
|
+
/** Convert angles from degrees to radians. */
|
|
1369
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
1370
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1371
|
+
declare const radians: typeof deg2rad;
|
|
1372
|
+
/** Convert angles from radians to degrees. */
|
|
1373
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
1374
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1375
|
+
declare const degrees: typeof rad2deg;
|
|
1376
|
+
/**
|
|
1377
|
+
* @function
|
|
1378
|
+
* Computes first array raised to power of second array, element-wise.
|
|
1379
|
+
*/
|
|
1380
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1381
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
1382
|
+
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1383
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
1384
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1385
|
+
/**
|
|
1386
|
+
* @function
|
|
1387
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
1388
|
+
*
|
|
1389
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1390
|
+
*/
|
|
1391
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1392
|
+
/**
|
|
1393
|
+
* @function
|
|
1394
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
1395
|
+
*
|
|
1396
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1397
|
+
*/
|
|
1398
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1399
|
+
/**
|
|
1400
|
+
* @function
|
|
1401
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
1402
|
+
*
|
|
1403
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1404
|
+
*/
|
|
1405
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1406
|
+
/**
|
|
1407
|
+
* @function
|
|
1408
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1409
|
+
*
|
|
1410
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
1411
|
+
*/
|
|
1412
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1413
|
+
/**
|
|
1414
|
+
* @function
|
|
1415
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1416
|
+
*
|
|
1417
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
1418
|
+
*/
|
|
1419
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1420
|
+
/**
|
|
1421
|
+
* @function
|
|
1422
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1423
|
+
*
|
|
1424
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
1425
|
+
*/
|
|
1426
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1427
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
1428
|
+
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1429
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
1430
|
+
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1431
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
1432
|
+
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1433
|
+
/**
|
|
1434
|
+
* Compute the variance of an array.
|
|
1435
|
+
*
|
|
1436
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
1437
|
+
* the specified axis.
|
|
1438
|
+
*
|
|
1439
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1440
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1441
|
+
*/
|
|
1442
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
1443
|
+
mean?: ArrayLike;
|
|
1444
|
+
correction?: number;
|
|
1445
|
+
} & ReduceOpts): Array;
|
|
1446
|
+
/**
|
|
1447
|
+
* Compute the standard deviation of an array.
|
|
1448
|
+
*
|
|
1449
|
+
* The standard deviation is computed for the flattened array by default,
|
|
1450
|
+
* otherwise over the specified axis.
|
|
1451
|
+
*
|
|
1452
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1453
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1454
|
+
*/
|
|
1455
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
1456
|
+
mean?: ArrayLike;
|
|
1457
|
+
correction?: number;
|
|
1458
|
+
} & ReduceOpts): Array;
|
|
1087
1459
|
//#endregion
|
|
1088
1460
|
//#region src/frontend/jaxpr.d.ts
|
|
1461
|
+
/**
|
|
1462
|
+
* Function callback with an associated dispose() method.
|
|
1463
|
+
*
|
|
1464
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
1465
|
+
* by the function after the last time it is called.
|
|
1466
|
+
*/
|
|
1467
|
+
type OwnedFunction<F extends Function> = F & {
|
|
1468
|
+
dispose: () => void;
|
|
1469
|
+
};
|
|
1089
1470
|
/** Variable in a Jaxpr expression. */
|
|
1090
1471
|
declare class Var {
|
|
1091
1472
|
#private;
|
|
@@ -1146,8 +1527,38 @@ declare class Jaxpr implements FpHashable {
|
|
|
1146
1527
|
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1147
1528
|
flatten(): Jaxpr;
|
|
1148
1529
|
}
|
|
1530
|
+
/** @inline */
|
|
1531
|
+
type JitOpts = {
|
|
1532
|
+
staticArgnums?: number[];
|
|
1533
|
+
device?: Device;
|
|
1534
|
+
};
|
|
1535
|
+
declare namespace lax_d_exports {
|
|
1536
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1537
|
+
}
|
|
1538
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1539
|
+
/**
|
|
1540
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
1541
|
+
*
|
|
1542
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
1543
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
1544
|
+
*
|
|
1545
|
+
* Grouped convolutions are not supported right now.
|
|
1546
|
+
*/
|
|
1547
|
+
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1548
|
+
lhsDilation,
|
|
1549
|
+
rhsDilation
|
|
1550
|
+
}?: {
|
|
1551
|
+
lhsDilation?: number[];
|
|
1552
|
+
rhsDilation?: number[];
|
|
1553
|
+
}): Array;
|
|
1554
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1555
|
+
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
1556
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1557
|
+
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1558
|
+
/** Reduce a computation over padded windows. */
|
|
1559
|
+
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1149
1560
|
declare namespace nn_d_exports {
|
|
1150
|
-
export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1561
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1151
1562
|
}
|
|
1152
1563
|
/**
|
|
1153
1564
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1179,6 +1590,7 @@ declare function softplus(x: ArrayLike): Array;
|
|
|
1179
1590
|
*/
|
|
1180
1591
|
declare function softSign(x: ArrayLike): Array;
|
|
1181
1592
|
/**
|
|
1593
|
+
* @function
|
|
1182
1594
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1183
1595
|
* Swish, computed element-wise:
|
|
1184
1596
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1187,8 +1599,9 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1187
1599
|
*
|
|
1188
1600
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1189
1601
|
*/
|
|
1190
|
-
declare
|
|
1602
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1191
1603
|
/**
|
|
1604
|
+
* @function
|
|
1192
1605
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1193
1606
|
* Swish, computed element-wise:
|
|
1194
1607
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1197,14 +1610,67 @@ declare function silu(x: ArrayLike): Array;
|
|
|
1197
1610
|
*
|
|
1198
1611
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1199
1612
|
*/
|
|
1200
|
-
declare const swish:
|
|
1613
|
+
declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1201
1614
|
/**
|
|
1202
1615
|
* Log-sigmoid activation function, computed element-wise:
|
|
1203
1616
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
1204
1617
|
*/
|
|
1205
1618
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1206
|
-
/**
|
|
1619
|
+
/**
|
|
1620
|
+
* @function
|
|
1621
|
+
* Identity activation function. Returns the argument unmodified.
|
|
1622
|
+
*/
|
|
1207
1623
|
declare const identity: (x: ArrayLike) => Array;
|
|
1624
|
+
/** Leaky rectified linear (ReLU) activation function */
|
|
1625
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
1626
|
+
/**
|
|
1627
|
+
* Exponential linear unit activation function.
|
|
1628
|
+
*
|
|
1629
|
+
* Computes the element-wise function:
|
|
1630
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1631
|
+
*/
|
|
1632
|
+
declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1633
|
+
/**
|
|
1634
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
1635
|
+
*
|
|
1636
|
+
* Computes the element-wise function:
|
|
1637
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1638
|
+
*/
|
|
1639
|
+
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1640
|
+
/**
|
|
1641
|
+
* @function
|
|
1642
|
+
* Gaussian error linear unit (GELU) activation function.
|
|
1643
|
+
*
|
|
1644
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
1645
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
1646
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1647
|
+
*
|
|
1648
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1649
|
+
*
|
|
1650
|
+
* This will be improved in the future.
|
|
1651
|
+
*/
|
|
1652
|
+
declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1653
|
+
/**
|
|
1654
|
+
* Gated linear unit (GLU) activation function.
|
|
1655
|
+
*
|
|
1656
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
1657
|
+
* computes `a * sigmoid(b)`.
|
|
1658
|
+
*/
|
|
1659
|
+
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1660
|
+
/**
|
|
1661
|
+
* Squareplus activation function.
|
|
1662
|
+
*
|
|
1663
|
+
* Computes the element-wise function:
|
|
1664
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
1665
|
+
*/
|
|
1666
|
+
declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
|
|
1667
|
+
/**
|
|
1668
|
+
* Mish activation function.
|
|
1669
|
+
*
|
|
1670
|
+
* Computes the element-wise function:
|
|
1671
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
1672
|
+
*/
|
|
1673
|
+
declare function mish(x: ArrayLike): Array;
|
|
1208
1674
|
/**
|
|
1209
1675
|
* Softmax function. Computes the function which rescales elements to the range
|
|
1210
1676
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -1213,7 +1679,7 @@ declare const identity: (x: ArrayLike) => Array;
|
|
|
1213
1679
|
*
|
|
1214
1680
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
1215
1681
|
*/
|
|
1216
|
-
declare function softmax(x: ArrayLike, axis?:
|
|
1682
|
+
declare function softmax(x: ArrayLike, axis?: Axis): Array;
|
|
1217
1683
|
/**
|
|
1218
1684
|
* Log-Softmax function.
|
|
1219
1685
|
*
|
|
@@ -1222,7 +1688,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1222
1688
|
*
|
|
1223
1689
|
* If `axis` is not specified, it defaults to the last axis.
|
|
1224
1690
|
*/
|
|
1225
|
-
declare function logSoftmax(x: ArrayLike, axis?:
|
|
1691
|
+
declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
1226
1692
|
/**
|
|
1227
1693
|
* Log-sum-exp reduction. Also a multivariate version of `softplus`.
|
|
1228
1694
|
*
|
|
@@ -1231,7 +1697,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1231
1697
|
*
|
|
1232
1698
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1233
1699
|
*/
|
|
1234
|
-
declare function logsumexp(x: ArrayLike, axis?:
|
|
1700
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
1701
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1702
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
1703
|
+
/**
|
|
1704
|
+
* Standardizes input to zero mean and unit variance.
|
|
1705
|
+
*
|
|
1706
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
1707
|
+
* axis, or `null` to standardize over all elements.
|
|
1708
|
+
*
|
|
1709
|
+
* Epsilon is added to the denominator for numerical stability; it defaults to `1e-5`.
|
|
1710
|
+
*/
|
|
1711
|
+
declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
1712
|
+
mean?: ArrayLike;
|
|
1713
|
+
variance?: ArrayLike;
|
|
1714
|
+
epsilon?: ArrayLike;
|
|
1715
|
+
}): Array;
|
|
1235
1716
|
/**
|
|
1236
1717
|
* One-hot encodes the given indices.
|
|
1237
1718
|
*
|
|
@@ -1250,7 +1731,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1250
1731
|
*/
|
|
1251
1732
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1252
1733
|
declare namespace random_d_exports {
|
|
1253
|
-
export { bits, key, split, uniform };
|
|
1734
|
+
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1254
1735
|
}
|
|
1255
1736
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1256
1737
|
declare function key(seed: number): Array;
|
|
@@ -1266,40 +1747,108 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1266
1747
|
minval?: number;
|
|
1267
1748
|
maxval?: number;
|
|
1268
1749
|
}): Array;
|
|
1750
|
+
/**
|
|
1751
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1752
|
+
*
|
|
1753
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
1754
|
+
* and must be broadcastable to `shape`.
|
|
1755
|
+
*/
|
|
1756
|
+
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1757
|
+
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1758
|
+
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1759
|
+
/**
|
|
1760
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1761
|
+
*
|
|
1762
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1763
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1764
|
+
* bitwise identical to JAX.
|
|
1765
|
+
*/
|
|
1766
|
+
declare function normal(key: Array, shape?: number[]): Array;
|
|
1269
1767
|
//#endregion
|
|
1270
1768
|
//#region src/index.d.ts
|
|
1271
|
-
/**
|
|
1272
|
-
|
|
1273
|
-
|
|
1274
|
-
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
declare const
|
|
1769
|
+
/**
|
|
1770
|
+
* @function
|
|
1771
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
1772
|
+
*/
|
|
1773
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1774
|
+
/**
|
|
1775
|
+
* @function
|
|
1776
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1777
|
+
*/
|
|
1778
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1779
|
+
/**
|
|
1780
|
+
* @function
|
|
1781
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
1782
|
+
*/
|
|
1783
|
+
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1784
|
+
/**
|
|
1785
|
+
* @function
|
|
1786
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
1787
|
+
*/
|
|
1788
|
+
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1281
1789
|
jaxpr: Jaxpr;
|
|
1282
1790
|
consts: Array[];
|
|
1283
1791
|
treedef: JsTreeDef;
|
|
1284
1792
|
};
|
|
1285
|
-
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
|
|
1286
1793
|
/**
|
|
1794
|
+
* @function
|
|
1795
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1796
|
+
*
|
|
1797
|
+
* The function will be compiled the first time it is called with a set of
|
|
1798
|
+
* argument shapes.
|
|
1799
|
+
*
|
|
1800
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
1801
|
+
* calls to free memory associated with array constants.
|
|
1802
|
+
*
|
|
1803
|
+
* **Options:**
|
|
1804
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1805
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
1806
|
+
* and different values will trigger recompilation.
|
|
1807
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
1808
|
+
* computation will be placed on the default device.
|
|
1809
|
+
*/
|
|
1810
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
|
|
1811
|
+
/**
|
|
1812
|
+
* @function
|
|
1287
1813
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1288
1814
|
* partial evaluation.
|
|
1289
1815
|
*/
|
|
1290
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1291
|
-
/** Calculate the reverse-mode vector-Jacobian product for a function. */
|
|
1292
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1816
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1293
1817
|
/**
|
|
1818
|
+
* @function
|
|
1819
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
1820
|
+
*/
|
|
1821
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1822
|
+
/**
|
|
1823
|
+
* @function
|
|
1294
1824
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1295
1825
|
* first argument.
|
|
1296
1826
|
*/
|
|
1297
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1298
|
-
/**
|
|
1299
|
-
|
|
1300
|
-
|
|
1827
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1828
|
+
/**
|
|
1829
|
+
* @function
|
|
1830
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
1831
|
+
*/
|
|
1832
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1833
|
+
/**
|
|
1834
|
+
* @function
|
|
1835
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
1836
|
+
*/
|
|
1301
1837
|
declare const jacrev: typeof jacfwd;
|
|
1302
|
-
/**
|
|
1303
|
-
|
|
1838
|
+
/**
|
|
1839
|
+
* @function
|
|
1840
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
1841
|
+
*/
|
|
1842
|
+
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1843
|
+
/**
|
|
1844
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
1845
|
+
*
|
|
1846
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
1847
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
1848
|
+
* to avoid queueing up too many pending operations.
|
|
1849
|
+
*
|
|
1850
|
+
* Does not consume reference to the arrays.
|
|
1851
|
+
*/
|
|
1852
|
+
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1304
1853
|
//#endregion
|
|
1305
|
-
export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random,
|
|
1854
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|