@jax-js/jax 0.0.2 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +57 -25
- package/dist/backend-EBRGmEYw.js +3816 -0
- package/dist/{backend-BK21PBVP.cjs → backend-Ss1Mev_-.cjs} +2075 -107
- package/dist/index.cjs +1393 -250
- package/dist/index.d.cts +651 -102
- package/dist/index.d.ts +651 -102
- package/dist/index.js +1377 -245
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-BVdMaO9T.cjs} +62 -35
- package/dist/{webgpu-JVpVad6g.js → webgpu-ow0Pn_6q.js} +62 -35
- package/package.json +21 -9
- package/dist/backend-1eVbAoaV.js +0 -1890
package/dist/index.d.cts
CHANGED
|
@@ -23,6 +23,7 @@ interface Stringable {
|
|
|
23
23
|
}
|
|
24
24
|
//#endregion
|
|
25
25
|
//#region src/shape.d.ts
|
|
26
|
+
/** @inline */
|
|
26
27
|
type Pair = [number, number];
|
|
27
28
|
/**
|
|
28
29
|
* A multidimensional view into memory. An array can be thought of as the
|
|
@@ -49,6 +50,8 @@ declare class View {
|
|
|
49
50
|
get size(): number;
|
|
50
51
|
/** Whether this is a default, contiguous, unaltered view of the data (identity). */
|
|
51
52
|
get contiguous(): boolean;
|
|
53
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
54
|
+
dataRange(): [number, number];
|
|
52
55
|
/** Produce an AluExp for evaluating this view at an index. */
|
|
53
56
|
toAluExp(idxs: AluExp[]): [AluExp, AluExp];
|
|
54
57
|
/**
|
|
@@ -109,9 +112,33 @@ declare class ShapeTracker {
|
|
|
109
112
|
reshape(newShape: number[]): ShapeTracker;
|
|
110
113
|
/** Broadcast along the given new axes, then expand the shape. */
|
|
111
114
|
broadcast(newShape: number[], axis: number[]): ShapeTracker;
|
|
115
|
+
/**
|
|
116
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
117
|
+
*
|
|
118
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
119
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
120
|
+
*/
|
|
121
|
+
repeat(reps: number[], tile?: boolean): ShapeTracker;
|
|
122
|
+
/** Move axis i to axis j. */
|
|
123
|
+
moveaxis(i: number, j: number): ShapeTracker;
|
|
124
|
+
/** Like pad(), but allows for negative values. */
|
|
125
|
+
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
112
126
|
}
|
|
113
127
|
//#endregion
|
|
114
128
|
//#region src/utils.d.ts
|
|
129
|
+
/**
|
|
130
|
+
* Set the debug level for verbose logging.
|
|
131
|
+
*
|
|
132
|
+
* 1. JIT compile logs
|
|
133
|
+
* 2. Shader code
|
|
134
|
+
* 3. Expressions and metadata
|
|
135
|
+
* 4. JIT programs, tuning details
|
|
136
|
+
* 5. Most verbose operation traces
|
|
137
|
+
*
|
|
138
|
+
* This is an experimental API and may change in behavior. Do not rely on this
|
|
139
|
+
* in production.
|
|
140
|
+
*/
|
|
141
|
+
declare function setDebug(level: number): void;
|
|
115
142
|
/** @inline */
|
|
116
143
|
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
117
144
|
interface FpHashable {
|
|
@@ -127,19 +154,47 @@ interface FpHashable {
|
|
|
127
154
|
declare class FpHash {
|
|
128
155
|
#private;
|
|
129
156
|
value: bigint;
|
|
130
|
-
update(
|
|
157
|
+
update(x: string | boolean | number | bigint | null | undefined | FpHashable): this;
|
|
131
158
|
static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
|
|
132
159
|
}
|
|
133
160
|
/** Run a function while caching it inline inside a `Map`. */
|
|
134
161
|
//#endregion
|
|
135
162
|
//#region src/alu.d.ts
|
|
163
|
+
/** A numerical data type for array contents. */
|
|
136
164
|
declare enum DType {
|
|
137
165
|
Float32 = "float32",
|
|
138
166
|
Int32 = "int32",
|
|
139
167
|
Uint32 = "uint32",
|
|
140
168
|
Bool = "bool",
|
|
141
|
-
|
|
169
|
+
Float16 = "float16",
|
|
142
170
|
}
|
|
171
|
+
/** @inline */
|
|
172
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
173
|
+
/**
|
|
174
|
+
* Promote two dtypes to their join according to the type lattice.
|
|
175
|
+
*
|
|
176
|
+
* When performing operations between arrays of different types, we need to
|
|
177
|
+
* promote both operands to a common type that can represent values from both
|
|
178
|
+
* input types. This follows JAX's type promotion rules.
|
|
179
|
+
*
|
|
180
|
+
* **Type lattice:**
|
|
181
|
+
* ```text
|
|
182
|
+
* bool -> uint32 -> int32 -> float16 -> float32
|
|
183
|
+
* weak f* --^
|
|
184
|
+
* ```
|
|
185
|
+
*
|
|
186
|
+
* The asterisk f* is a weak type used for JS number constants. When creating
|
|
187
|
+
* arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
|
|
188
|
+
* any array they are first combined with.
|
|
189
|
+
*
|
|
190
|
+
* **Examples:**
|
|
191
|
+
* - `promoteTypes(bool, int32) → int32`
|
|
192
|
+
* - `promoteTypes(uint32, int32) → int32`
|
|
193
|
+
* - `promoteTypes(int32, float16) → float16`
|
|
194
|
+
* - `promoteTypes(float16, float32) → float32`
|
|
195
|
+
* - `promoteTypes(uint32, float32) → float32`
|
|
196
|
+
*/
|
|
197
|
+
declare function promoteTypes(dtype1: DType, dtype2: DType): DType;
|
|
143
198
|
/**
|
|
144
199
|
* Mathematical expression on scalar values.
|
|
145
200
|
*
|
|
@@ -163,8 +218,11 @@ declare class AluExp implements FpHashable {
|
|
|
163
218
|
static max(a: AluExp, b: AluExp): AluExp;
|
|
164
219
|
static sin(a: AluExp): AluExp;
|
|
165
220
|
static cos(a: AluExp): AluExp;
|
|
221
|
+
static asin(a: AluExp): AluExp;
|
|
222
|
+
static atan(a: AluExp): AluExp;
|
|
166
223
|
static exp(a: AluExp): AluExp;
|
|
167
224
|
static log(a: AluExp): AluExp;
|
|
225
|
+
static sqrt(a: AluExp): AluExp;
|
|
168
226
|
static reciprocal(a: AluExp): AluExp;
|
|
169
227
|
static cast(dtype: DType, a: AluExp): AluExp;
|
|
170
228
|
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
@@ -175,12 +233,13 @@ declare class AluExp implements FpHashable {
|
|
|
175
233
|
static const(dtype: DType, value: any): AluExp;
|
|
176
234
|
static special(dtype: DType, name: string, n: number): AluExp;
|
|
177
235
|
static variable(dtype: DType, name: string): AluExp;
|
|
178
|
-
static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
|
|
236
|
+
static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
|
|
179
237
|
static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
|
|
238
|
+
static f32(value: number): AluExp;
|
|
180
239
|
static i32(value: number): AluExp;
|
|
181
240
|
static u32(value: number): AluExp;
|
|
182
|
-
static f32(value: number): AluExp;
|
|
183
241
|
static bool(value: boolean): AluExp;
|
|
242
|
+
static f16(value: number): AluExp;
|
|
184
243
|
not(): AluExp;
|
|
185
244
|
/** Compute a reasonable expression hash with low collision rate. */
|
|
186
245
|
getHash(): bigint;
|
|
@@ -191,6 +250,19 @@ declare class AluExp implements FpHashable {
|
|
|
191
250
|
reindexGids(gidMap: Map<number, number>): AluExp;
|
|
192
251
|
get min(): number;
|
|
193
252
|
get max(): number;
|
|
253
|
+
/** Largest known integer that divides self. */
|
|
254
|
+
constFactor(): number;
|
|
255
|
+
/**
|
|
256
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
257
|
+
* `null` if it's not divisible.
|
|
258
|
+
*/
|
|
259
|
+
divides(v: number): AluExp | null;
|
|
260
|
+
/**
|
|
261
|
+
* Get all expressions by deeply matching an operation.
|
|
262
|
+
*
|
|
263
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
264
|
+
*/
|
|
265
|
+
splitOp(sep: AluOp): IterableIterator<AluExp>;
|
|
194
266
|
/**
|
|
195
267
|
* Simplify the expression by replacing any known patterns and deduping
|
|
196
268
|
* identical subexpressions.
|
|
@@ -211,10 +283,16 @@ declare class AluExp implements FpHashable {
|
|
|
211
283
|
toString(): string;
|
|
212
284
|
/** Generic fold() operation with a reducer over the expression tree. */
|
|
213
285
|
fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
|
|
286
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
287
|
+
some(predicate: (exp: AluExp) => boolean): boolean;
|
|
214
288
|
/** Rewrite the expression recursively using a visitor. */
|
|
215
289
|
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
216
290
|
/** Collect all nodes that satisfy a predicate. */
|
|
217
291
|
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
292
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
293
|
+
distinctOps(): Set<AluOp>;
|
|
294
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
295
|
+
rewriteGlobalViews(): AluExp;
|
|
218
296
|
}
|
|
219
297
|
/** Symbolic form for each mathematical operation. */
|
|
220
298
|
declare enum AluOp {
|
|
@@ -227,8 +305,11 @@ declare enum AluOp {
|
|
|
227
305
|
Max = "Max",
|
|
228
306
|
Sin = "Sin",
|
|
229
307
|
Cos = "Cos",
|
|
308
|
+
Asin = "Asin",
|
|
309
|
+
Atan = "Atan",
|
|
230
310
|
Exp = "Exp",
|
|
231
311
|
Log = "Log",
|
|
312
|
+
Sqrt = "Sqrt",
|
|
232
313
|
Reciprocal = "Reciprocal",
|
|
233
314
|
Cast = "Cast",
|
|
234
315
|
Bitcast = "Bitcast",
|
|
@@ -245,7 +326,7 @@ declare enum AluOp {
|
|
|
245
326
|
Variable = "Variable",
|
|
246
327
|
// arg = variable
|
|
247
328
|
GlobalIndex = "GlobalIndex",
|
|
248
|
-
// arg = gid; src = [bufidx]
|
|
329
|
+
// arg = [gid, len]; src = [bufidx]
|
|
249
330
|
GlobalView = "GlobalView",
|
|
250
331
|
}
|
|
251
332
|
/**
|
|
@@ -300,12 +381,12 @@ declare class Reduction implements FpHashable {
|
|
|
300
381
|
/** Size of the reduction axis. */
|
|
301
382
|
readonly size: number;
|
|
302
383
|
/** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
303
|
-
readonly
|
|
384
|
+
readonly epilogue: AluExp;
|
|
304
385
|
constructor(/** Data type of the values being reduced over. */
|
|
305
386
|
dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
|
|
306
387
|
op: AluOp, /** Size of the reduction axis. */
|
|
307
388
|
size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
308
|
-
|
|
389
|
+
epilogue?: AluExp);
|
|
309
390
|
hash(state: FpHash): void;
|
|
310
391
|
toString(): string;
|
|
311
392
|
/** Get the identity for this reduction operation. */
|
|
@@ -316,10 +397,10 @@ declare class Reduction implements FpHashable {
|
|
|
316
397
|
/** Expression for accessing `indices` in input array with the given shape. */
|
|
317
398
|
//#endregion
|
|
318
399
|
//#region src/backend.d.ts
|
|
319
|
-
type Device = "cpu" | "webgpu";
|
|
400
|
+
type Device = "cpu" | "wasm" | "webgpu";
|
|
320
401
|
declare const devices: Device[];
|
|
321
|
-
/**
|
|
322
|
-
declare function
|
|
402
|
+
/** Configure the default device for arrays. */
|
|
403
|
+
declare function defaultDevice(device?: Device): Device;
|
|
323
404
|
/**
|
|
324
405
|
* Initialize `jax-js` library backends.
|
|
325
406
|
*
|
|
@@ -339,7 +420,7 @@ interface Backend {
|
|
|
339
420
|
/** Maximum number of arguments per dispatched kernel. */
|
|
340
421
|
readonly maxArgs: number;
|
|
341
422
|
/** Allocate a new slot with reference count 1. */
|
|
342
|
-
malloc(size: number, initialData?:
|
|
423
|
+
malloc(size: number, initialData?: Uint8Array): Slot;
|
|
343
424
|
/** Increment the reference count of the slot. */
|
|
344
425
|
incRef(slot: Slot): void;
|
|
345
426
|
/**
|
|
@@ -348,9 +429,9 @@ interface Backend {
|
|
|
348
429
|
*/
|
|
349
430
|
decRef(slot: Slot): void;
|
|
350
431
|
/** Read a range of bytes from a buffer. */
|
|
351
|
-
read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer
|
|
432
|
+
read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
|
|
352
433
|
/** Read a range of bytes from a buffer, blocking variant. */
|
|
353
|
-
readSync(slot: Slot, start?: number, count?: number): ArrayBuffer
|
|
434
|
+
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
354
435
|
/** Prepare an expression to be executed later. */
|
|
355
436
|
prepare(kernel: Kernel): Promise<Executable>;
|
|
356
437
|
/** Prepare an expression to be executed later, blocking variant. */
|
|
@@ -372,18 +453,22 @@ declare class Executable<T = any> {
|
|
|
372
453
|
data: T);
|
|
373
454
|
}
|
|
374
455
|
declare namespace tree_d_exports {
|
|
375
|
-
export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
|
|
456
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
376
457
|
}
|
|
377
458
|
declare enum NodeType {
|
|
378
459
|
Array = "Array",
|
|
379
460
|
Object = "Object",
|
|
380
461
|
Leaf = "Leaf",
|
|
381
462
|
}
|
|
463
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
382
464
|
type JsTree<T> = T | JsTree<T>[] | {
|
|
383
465
|
[key: string]: JsTree<T>;
|
|
384
466
|
};
|
|
385
|
-
type
|
|
386
|
-
|
|
467
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
468
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
469
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
470
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
471
|
+
/** Represents the structure of a JsTree. */
|
|
387
472
|
declare class JsTreeDef {
|
|
388
473
|
readonly nodeType: NodeType;
|
|
389
474
|
readonly nodeMetadata: any;
|
|
@@ -409,6 +494,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
|
|
|
409
494
|
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
410
495
|
/** Take a reference of every array in a tree. */
|
|
411
496
|
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
497
|
+
/** Dispose every array in a tree. */
|
|
498
|
+
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
499
|
+
//#endregion
|
|
500
|
+
//#region src/frontend/convolution.d.ts
|
|
501
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
502
|
+
interface ConvParams {
|
|
503
|
+
strides: number[];
|
|
504
|
+
padding: [number, number][];
|
|
505
|
+
lhsDilation: number[];
|
|
506
|
+
rhsDilation: number[];
|
|
507
|
+
}
|
|
508
|
+
/**
|
|
509
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
510
|
+
*
|
|
511
|
+
* If the check succeeds, returns the output shape.
|
|
512
|
+
*/
|
|
513
|
+
|
|
412
514
|
//#endregion
|
|
413
515
|
//#region src/frontend/core.d.ts
|
|
414
516
|
/**
|
|
@@ -434,12 +536,20 @@ declare enum Primitive {
|
|
|
434
536
|
RandomBits = "random_bits",
|
|
435
537
|
Sin = "sin",
|
|
436
538
|
Cos = "cos",
|
|
539
|
+
Asin = "asin",
|
|
540
|
+
Atan = "atan",
|
|
437
541
|
Exp = "exp",
|
|
438
542
|
Log = "log",
|
|
543
|
+
Sqrt = "sqrt",
|
|
439
544
|
Min = "min",
|
|
440
545
|
Max = "max",
|
|
441
546
|
Reduce = "reduce",
|
|
442
547
|
Dot = "dot",
|
|
548
|
+
// sum(x*y, axis=-1)
|
|
549
|
+
Conv = "conv",
|
|
550
|
+
// see lax.conv_general_dilated
|
|
551
|
+
Pool = "pool",
|
|
552
|
+
PoolTranspose = "pool_transpose",
|
|
443
553
|
Compare = "compare",
|
|
444
554
|
Where = "where",
|
|
445
555
|
Transpose = "transpose",
|
|
@@ -451,7 +561,7 @@ declare enum Primitive {
|
|
|
451
561
|
Gather = "gather",
|
|
452
562
|
JitCall = "jit_call",
|
|
453
563
|
}
|
|
454
|
-
interface PrimitiveParamsImpl extends Record<Primitive, Record<string,
|
|
564
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
455
565
|
[Primitive.Cast]: {
|
|
456
566
|
dtype: DType;
|
|
457
567
|
};
|
|
@@ -462,6 +572,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
462
572
|
op: AluOp;
|
|
463
573
|
axis: number[];
|
|
464
574
|
};
|
|
575
|
+
[Primitive.Conv]: ConvParams;
|
|
576
|
+
[Primitive.Pool]: {
|
|
577
|
+
window: number[];
|
|
578
|
+
strides: number[];
|
|
579
|
+
};
|
|
580
|
+
[Primitive.PoolTranspose]: {
|
|
581
|
+
inShape: number[];
|
|
582
|
+
window: number[];
|
|
583
|
+
strides: number[];
|
|
584
|
+
};
|
|
465
585
|
[Primitive.Compare]: {
|
|
466
586
|
op: CompareOp;
|
|
467
587
|
};
|
|
@@ -483,10 +603,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
|
|
|
483
603
|
axis: number[];
|
|
484
604
|
};
|
|
485
605
|
[Primitive.Shrink]: {
|
|
486
|
-
slice: [
|
|
606
|
+
slice: Pair[];
|
|
487
607
|
};
|
|
488
608
|
[Primitive.Pad]: {
|
|
489
|
-
width: [
|
|
609
|
+
width: Pair[];
|
|
490
610
|
};
|
|
491
611
|
[Primitive.Gather]: {
|
|
492
612
|
axis: number[];
|
|
@@ -508,8 +628,10 @@ declare enum CompareOp {
|
|
|
508
628
|
LessEqual = "less_equal",
|
|
509
629
|
}
|
|
510
630
|
/** @inline */
|
|
631
|
+
type Axis = number | number[] | null;
|
|
632
|
+
/** @inline */
|
|
511
633
|
type ReduceOpts = {
|
|
512
|
-
|
|
634
|
+
keepdims?: boolean;
|
|
513
635
|
};
|
|
514
636
|
type MainTrace = {
|
|
515
637
|
level: number;
|
|
@@ -586,8 +708,13 @@ declare abstract class Tracer {
|
|
|
586
708
|
* ```
|
|
587
709
|
*/
|
|
588
710
|
abstract dispose(): void;
|
|
711
|
+
/** The shape of the array. */
|
|
589
712
|
get shape(): number[];
|
|
713
|
+
/** The total number of elements in the array. */
|
|
714
|
+
get size(): number;
|
|
715
|
+
/** The dtype of the array. */
|
|
590
716
|
get dtype(): DType;
|
|
717
|
+
/** The number of dimensions of the array. */
|
|
591
718
|
get ndim(): number;
|
|
592
719
|
/** @ignore */
|
|
593
720
|
fullLower(): Tracer;
|
|
@@ -601,11 +728,11 @@ declare abstract class Tracer {
|
|
|
601
728
|
greaterEqual(other: this | TracerValue): this;
|
|
602
729
|
lessEqual(other: this | TracerValue): this;
|
|
603
730
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
604
|
-
sum(axis?:
|
|
731
|
+
sum(axis?: Axis, opts?: ReduceOpts): this;
|
|
605
732
|
/** Product of the array elements over a given axis. */
|
|
606
|
-
prod(axis?:
|
|
733
|
+
prod(axis?: Axis, opts?: ReduceOpts): this;
|
|
607
734
|
/** Compute the average of the array elements along the specified axis. */
|
|
608
|
-
mean(axis?:
|
|
735
|
+
mean(axis?: Axis, opts?: ReduceOpts): this;
|
|
609
736
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
610
737
|
transpose(perm?: number[]): this;
|
|
611
738
|
/**
|
|
@@ -673,7 +800,7 @@ declare abstract class Tracer {
|
|
|
673
800
|
* the "gather" primitive, and it allows you to access specific elements of
|
|
674
801
|
* the array by integer indices stored in another array.
|
|
675
802
|
*/
|
|
676
|
-
slice(...index: (number | [] | [number] |
|
|
803
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
677
804
|
}
|
|
678
805
|
declare class ShapedArray implements AbstractValue {
|
|
679
806
|
readonly shape: number[];
|
|
@@ -681,7 +808,7 @@ declare class ShapedArray implements AbstractValue {
|
|
|
681
808
|
constructor(shape: number[], dtype: DType);
|
|
682
809
|
static fromAval(aval: AbstractValue): ShapedArray;
|
|
683
810
|
get ndim(): number;
|
|
684
|
-
|
|
811
|
+
toString(): string;
|
|
685
812
|
equals(other: ShapedArray): boolean;
|
|
686
813
|
}
|
|
687
814
|
//#endregion
|
|
@@ -733,7 +860,11 @@ declare class Array extends Tracer {
|
|
|
733
860
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
734
861
|
* will be freed when the array is disposed.
|
|
735
862
|
*/
|
|
736
|
-
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend,
|
|
863
|
+
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, {
|
|
864
|
+
pending
|
|
865
|
+
}?: {
|
|
866
|
+
pending?: Iterable<PendingExecute> | null;
|
|
867
|
+
});
|
|
737
868
|
/** @ignore */
|
|
738
869
|
get aval(): ShapedArray;
|
|
739
870
|
/** Return a simple string representation of the array's dimensions. */
|
|
@@ -752,14 +883,26 @@ declare class Array extends Tracer {
|
|
|
752
883
|
*/
|
|
753
884
|
[Symbol.toPrimitive](): any;
|
|
754
885
|
/** Realize the array and return it as data. */
|
|
755
|
-
data(): Promise<
|
|
756
|
-
/**
|
|
757
|
-
|
|
886
|
+
data(): Promise<DataArray>;
|
|
887
|
+
/**
|
|
888
|
+
* Wait for this array to finish evaluation.
|
|
889
|
+
*
|
|
890
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
891
|
+
* that pending operations are dispatched and fully executed before it
|
|
892
|
+
* returns.
|
|
893
|
+
*
|
|
894
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
895
|
+
* dispatch of operations as well.
|
|
896
|
+
*
|
|
897
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
898
|
+
* asynchronously for multiple arrays.
|
|
899
|
+
*/
|
|
900
|
+
blockUntilReady(): Promise<Array>;
|
|
758
901
|
/**
|
|
759
902
|
* Realize the array and return it as data. This is a sync variant and not
|
|
760
903
|
* recommended for performance reasons, as it will block rendering.
|
|
761
904
|
*/
|
|
762
|
-
dataSync():
|
|
905
|
+
dataSync(): DataArray;
|
|
763
906
|
/**
|
|
764
907
|
* Convert this array into a JavaScript object.
|
|
765
908
|
*
|
|
@@ -772,17 +915,21 @@ declare class Array extends Tracer {
|
|
|
772
915
|
js(): any;
|
|
773
916
|
/** Convert this array into a JavaScript object, asynchronously. */
|
|
774
917
|
jsAsync(): Promise<any>;
|
|
918
|
+
/**
|
|
919
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
920
|
+
*
|
|
921
|
+
* Throws an error if the array does not have a single element. The array must
|
|
922
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
923
|
+
*/
|
|
924
|
+
item(): number;
|
|
775
925
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
776
926
|
static _implRules(): typeof implRules;
|
|
777
927
|
_realizeSource(): number;
|
|
778
928
|
}
|
|
779
929
|
/** Construct an array from a single scalar constant. */
|
|
780
|
-
|
|
781
|
-
dtype,
|
|
782
|
-
device
|
|
783
|
-
}?: DTypeAndDevice): Array;
|
|
930
|
+
|
|
784
931
|
/** Constructor for creating a new array from data. */
|
|
785
|
-
declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
932
|
+
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
786
933
|
shape,
|
|
787
934
|
dtype,
|
|
788
935
|
device
|
|
@@ -818,7 +965,7 @@ declare function eye(numRows: number, numCols?: number, {
|
|
|
818
965
|
dtype,
|
|
819
966
|
device
|
|
820
967
|
}?: DTypeAndDevice): Array;
|
|
821
|
-
/** Return the identity
|
|
968
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
822
969
|
declare function identity$1(n: number, {
|
|
823
970
|
dtype,
|
|
824
971
|
device
|
|
@@ -854,15 +1001,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
|
|
|
854
1001
|
dtype,
|
|
855
1002
|
device
|
|
856
1003
|
}?: DTypeAndDevice): Array;
|
|
857
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
858
1004
|
declare namespace numpy_d_exports {
|
|
859
|
-
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack,
|
|
1005
|
+
export { Array, ArrayLike, DType, abs, absolute, acos, acosh, add, allclose, arange, arccos, arccosh, arcsinh, arctan, arctan2, arctanh, argmax, argmin, array, asin, asinh, astype, atan, atan2, atanh, bool, broadcastArrays, broadcastShapes, broadcastTo, cbrt, clip, columnStack, concatenate, cos, cosh, deg2rad, degrees, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, expm1, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, hypot, identity$1 as identity, inf, inner, int32, less, lessEqual, linspace, log, log10, log1p, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, outer, pad, permuteDims, pi, pow, power, prod, promoteTypes, rad2deg, radians, ravel, reciprocal, repeat, reshape, shape$1 as shape, sign, sin, sinh, size, sqrt, square, stack, std, subtract, sum, tan, tanh, tile, transpose, tri, tril, triu, trueDivide, trunc, uint32, var_, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
860
1006
|
}
|
|
861
1007
|
declare const float32 = DType.Float32;
|
|
862
1008
|
declare const int32 = DType.Int32;
|
|
863
1009
|
declare const uint32 = DType.Uint32;
|
|
864
1010
|
declare const bool = DType.Bool;
|
|
865
|
-
declare const
|
|
1011
|
+
declare const float16 = DType.Float16;
|
|
866
1012
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
867
1013
|
declare const e: number;
|
|
868
1014
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -873,52 +1019,66 @@ declare const inf: number;
|
|
|
873
1019
|
declare const nan: number;
|
|
874
1020
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
875
1021
|
declare const pi: number;
|
|
876
|
-
/** Element-wise addition, with broadcasting. */
|
|
1022
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
877
1023
|
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
|
|
878
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
1024
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
879
1025
|
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
|
|
880
|
-
/** Numerical negative of every element of an array. */
|
|
1026
|
+
/** @function Numerical negative of every element of an array. */
|
|
881
1027
|
declare const negative: (x: ArrayLike) => Array;
|
|
882
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
1028
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
883
1029
|
declare const reciprocal: (x: ArrayLike) => Array;
|
|
884
|
-
/** Element-wise sine function (takes radians). */
|
|
1030
|
+
/** @function Element-wise sine function (takes radians). */
|
|
885
1031
|
declare const sin: (x: ArrayLike) => Array;
|
|
886
|
-
/** Element-wise cosine function (takes radians). */
|
|
1032
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
887
1033
|
declare const cos: (x: ArrayLike) => Array;
|
|
888
|
-
/**
|
|
1034
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
1035
|
+
declare const asin: (x: ArrayLike) => Array;
|
|
1036
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
1037
|
+
declare const atan: (x: ArrayLike) => Array;
|
|
1038
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
889
1039
|
declare const exp: (x: ArrayLike) => Array;
|
|
890
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
1040
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
891
1041
|
declare const log: (x: ArrayLike) => Array;
|
|
892
|
-
/**
|
|
1042
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
1043
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
1044
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
893
1045
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
894
|
-
/** Return element-wise maximum of the input arrays. */
|
|
1046
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
895
1047
|
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
896
|
-
/** Compare two arrays element-wise. */
|
|
1048
|
+
/** @function Compare two arrays element-wise. */
|
|
897
1049
|
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
|
|
898
|
-
/** Compare two arrays element-wise. */
|
|
1050
|
+
/** @function Compare two arrays element-wise. */
|
|
899
1051
|
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
|
|
900
|
-
/** Compare two arrays element-wise. */
|
|
1052
|
+
/** @function Compare two arrays element-wise. */
|
|
901
1053
|
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
|
|
902
|
-
/** Compare two arrays element-wise. */
|
|
1054
|
+
/** @function Compare two arrays element-wise. */
|
|
903
1055
|
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
904
|
-
/** Compare two arrays element-wise. */
|
|
1056
|
+
/** @function Compare two arrays element-wise. */
|
|
905
1057
|
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
906
|
-
/** Compare two arrays element-wise. */
|
|
1058
|
+
/** @function Compare two arrays element-wise. */
|
|
907
1059
|
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
|
|
908
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
1060
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
909
1061
|
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
|
|
910
|
-
/**
|
|
1062
|
+
/**
|
|
1063
|
+
* @function
|
|
1064
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
1065
|
+
*/
|
|
911
1066
|
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
912
1067
|
/**
|
|
1068
|
+
* @function
|
|
913
1069
|
* Give a new shape to an array without changing its data.
|
|
914
1070
|
*
|
|
915
1071
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
916
1072
|
* length of the array and remaining dimensions.
|
|
917
1073
|
*/
|
|
918
1074
|
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
919
|
-
/**
|
|
1075
|
+
/**
|
|
1076
|
+
* @function
|
|
1077
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
1078
|
+
*/
|
|
920
1079
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
921
1080
|
/**
|
|
1081
|
+
* @function
|
|
922
1082
|
* Add padding (zeros) to an array.
|
|
923
1083
|
*
|
|
924
1084
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -926,10 +1086,28 @@ declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
|
926
1086
|
* pair specifies the padding for its corresponding axis.
|
|
927
1087
|
*/
|
|
928
1088
|
declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
|
|
929
|
-
/**
|
|
1089
|
+
/**
|
|
1090
|
+
* @function
|
|
1091
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
1092
|
+
*/
|
|
930
1093
|
declare const ndim: (x: ArrayLike) => number;
|
|
931
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
1094
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
932
1095
|
declare const shape$1: (x: ArrayLike) => number[];
|
|
1096
|
+
/**
|
|
1097
|
+
* @function
|
|
1098
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
1099
|
+
*/
|
|
1100
|
+
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1101
|
+
/**
|
|
1102
|
+
* @function
|
|
1103
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
1104
|
+
*/
|
|
1105
|
+
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1106
|
+
/**
|
|
1107
|
+
* @function
|
|
1108
|
+
* Return a full array with the same shape and type as a given array.
|
|
1109
|
+
*/
|
|
1110
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
933
1111
|
/**
|
|
934
1112
|
* Return the number of elements in an array, optionally along an axis.
|
|
935
1113
|
* Does not consume array reference.
|
|
@@ -938,15 +1116,15 @@ declare function size(a: ArrayLike, axis?: number): number;
|
|
|
938
1116
|
/** Convert an array to a specified dtype. */
|
|
939
1117
|
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
940
1118
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
941
|
-
declare function sum(a: ArrayLike, axis?:
|
|
1119
|
+
declare function sum(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
942
1120
|
/** Product of the array elements over a given axis. */
|
|
943
|
-
declare function prod(a: ArrayLike, axis?:
|
|
1121
|
+
declare function prod(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
944
1122
|
/** Return the minimum of array elements along a given axis. */
|
|
945
|
-
declare function min(a: ArrayLike, axis?:
|
|
1123
|
+
declare function min(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
946
1124
|
/** Return the maximum of array elements along a given axis. */
|
|
947
|
-
declare function max(a: ArrayLike, axis?:
|
|
1125
|
+
declare function max(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
948
1126
|
/** Compute the average of the array elements along the specified axis. */
|
|
949
|
-
declare function mean(a: ArrayLike, axis?:
|
|
1127
|
+
declare function mean(a: ArrayLike, axis?: Axis, opts?: ReduceOpts): Array;
|
|
950
1128
|
/**
|
|
951
1129
|
* Returns the indices of the minimum values along an axis.
|
|
952
1130
|
*
|
|
@@ -962,7 +1140,7 @@ declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
|
962
1140
|
*/
|
|
963
1141
|
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
964
1142
|
/** Reverse the elements in an array along the given axes. */
|
|
965
|
-
declare function flip(x: ArrayLike, axis?:
|
|
1143
|
+
declare function flip(x: ArrayLike, axis?: Axis): Array;
|
|
966
1144
|
/**
|
|
967
1145
|
* Join a sequence of arrays along an existing axis.
|
|
968
1146
|
*
|
|
@@ -1006,16 +1184,45 @@ declare function columnStack(xs: ArrayLike[]): Array;
|
|
|
1006
1184
|
declare function flipud(x: ArrayLike): Array;
|
|
1007
1185
|
/** Flip an array horizontally (axis=1). */
|
|
1008
1186
|
declare function fliplr(x: ArrayLike): Array;
|
|
1187
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
1009
1188
|
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
|
|
1010
1189
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
1011
1190
|
declare function ravel(a: ArrayLike): Array;
|
|
1191
|
+
/**
|
|
1192
|
+
* Repeat each element of an array after themselves.
|
|
1193
|
+
*
|
|
1194
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
1195
|
+
* output array.
|
|
1196
|
+
*/
|
|
1197
|
+
declare function repeat(a: ArrayLike, repeats: number, axis?: number): Array;
|
|
1198
|
+
/**
|
|
1199
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
1200
|
+
*
|
|
1201
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
1202
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
1203
|
+
* reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
|
|
1204
|
+
*/
|
|
1205
|
+
declare function tile(a: ArrayLike, reps: number | number[]): Array;
|
|
1206
|
+
/**
|
|
1207
|
+
* Broadcast an array to a shape, with NumPy-style broadcasing rules.
|
|
1208
|
+
*
|
|
1209
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
1210
|
+
* dimensions where the shape is 1.
|
|
1211
|
+
*/
|
|
1212
|
+
declare function broadcastTo(a: ArrayLike, shape: number[]): Array;
|
|
1213
|
+
/** Broadcast input shapes to a common output shape. */
|
|
1214
|
+
declare function broadcastShapes(...shapes: number[][]): number[];
|
|
1215
|
+
/** Broadcast arrays to a common shape. */
|
|
1216
|
+
declare function broadcastArrays(...arrays: ArrayLike[]): Array[];
|
|
1012
1217
|
/**
|
|
1013
1218
|
* Return specified diagonals.
|
|
1014
1219
|
*
|
|
1015
1220
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
1016
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
1221
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
1017
1222
|
*
|
|
1018
|
-
* This returns a view over the existing array.
|
|
1223
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
1224
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
1225
|
+
* then appending a new axis to the right with holding the diagonals.
|
|
1019
1226
|
*/
|
|
1020
1227
|
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
1021
1228
|
/**
|
|
@@ -1034,8 +1241,28 @@ declare function allclose(actual: Parameters<typeof array>[0], expected: Paramet
|
|
|
1034
1241
|
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
1035
1242
|
/** Dot product of two arrays. */
|
|
1036
1243
|
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
1037
|
-
/**
|
|
1038
|
-
|
|
1244
|
+
/**
|
|
1245
|
+
* Compute the inner product of two arrays.
|
|
1246
|
+
*
|
|
1247
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
1248
|
+
* contraction on the last axis.
|
|
1249
|
+
*
|
|
1250
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
1251
|
+
*/
|
|
1252
|
+
declare function inner(x: ArrayLike, y: ArrayLike): Array;
|
|
1253
|
+
/**
|
|
1254
|
+
* Compute the outer product of two arrays.
|
|
1255
|
+
*
|
|
1256
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
1257
|
+
* be of shape `[x.size, y.size]`.
|
|
1258
|
+
*/
|
|
1259
|
+
declare function outer(x: ArrayLike, y: ArrayLike): Array;
|
|
1260
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
1261
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike, {
|
|
1262
|
+
axis
|
|
1263
|
+
}?: {
|
|
1264
|
+
axis?: number;
|
|
1265
|
+
}): Array;
|
|
1039
1266
|
/**
|
|
1040
1267
|
* Return the dot product of two vectors.
|
|
1041
1268
|
*
|
|
@@ -1053,6 +1280,21 @@ declare function meshgrid(xs: Array[], {
|
|
|
1053
1280
|
}?: {
|
|
1054
1281
|
indexing?: "xy" | "ij";
|
|
1055
1282
|
}): Array[];
|
|
1283
|
+
/**
|
|
1284
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
1285
|
+
*
|
|
1286
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
1287
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
1288
|
+
* `k>0` is above it.
|
|
1289
|
+
*/
|
|
1290
|
+
declare function tri(n: number, m?: number, k?: number, {
|
|
1291
|
+
dtype,
|
|
1292
|
+
device
|
|
1293
|
+
}?: DTypeAndDevice): Array;
|
|
1294
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
1295
|
+
declare function tril(a: ArrayLike, k?: number): Array;
|
|
1296
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
1297
|
+
declare function triu(a: ArrayLike, k?: number): Array;
|
|
1056
1298
|
/**
|
|
1057
1299
|
* Clip (limit) the values in an array.
|
|
1058
1300
|
*
|
|
@@ -1069,15 +1311,50 @@ declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
|
|
|
1069
1311
|
* This is the same function as `jax.numpy.abs()`.
|
|
1070
1312
|
*/
|
|
1071
1313
|
declare function absolute(x: ArrayLike): Array;
|
|
1072
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
1314
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
1073
1315
|
declare const abs: typeof absolute;
|
|
1316
|
+
/** Return an element-wise indication of sign of the input. */
|
|
1317
|
+
declare function sign(x: ArrayLike): Array;
|
|
1074
1318
|
/** Calculate element-wise square of the input array. */
|
|
1075
1319
|
declare function square(x: ArrayLike): Array;
|
|
1076
|
-
/**
|
|
1320
|
+
/** Element-wise tangent function (takes radians). */
|
|
1077
1321
|
declare function tan(x: ArrayLike): Array;
|
|
1322
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
1323
|
+
declare function acos(x: ArrayLike): Array;
|
|
1324
|
+
/**
|
|
1325
|
+
* @function
|
|
1326
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
1327
|
+
*
|
|
1328
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
1329
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
1330
|
+
* improvements.
|
|
1331
|
+
*/
|
|
1332
|
+
declare const hypot: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1333
|
+
/**
|
|
1334
|
+
* @function
|
|
1335
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
1336
|
+
*
|
|
1337
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
1338
|
+
* The result is in the range [-π, π].
|
|
1339
|
+
*
|
|
1340
|
+
* Uses numerically stable formulas:
|
|
1341
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
1342
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
1343
|
+
*
|
|
1344
|
+
* The output is ill-defined when both x and y are zero.
|
|
1345
|
+
*/
|
|
1346
|
+
declare const atan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1347
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
1348
|
+
declare const arccos: typeof acos;
|
|
1349
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
1350
|
+
declare const arctan: (x: ArrayLike) => Array;
|
|
1351
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
1352
|
+
declare const arctan2: OwnedFunction<(y: ArrayLike, x: ArrayLike) => Array>;
|
|
1353
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
1354
|
+
declare function subtract(x: ArrayLike, y: ArrayLike): Array;
|
|
1078
1355
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
1079
1356
|
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
|
|
1080
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
1357
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
1081
1358
|
declare const divide: typeof trueDivide;
|
|
1082
1359
|
/** Round input to the nearest integer towards zero. */
|
|
1083
1360
|
declare function trunc(x: ArrayLike): Array;
|
|
@@ -1087,8 +1364,112 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
1087
1364
|
declare function log2(x: ArrayLike): Array;
|
|
1088
1365
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
1089
1366
|
declare function log10(x: ArrayLike): Array;
|
|
1367
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
1368
|
+
declare function expm1(x: ArrayLike): Array;
|
|
1369
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
1370
|
+
declare function log1p(x: ArrayLike): Array;
|
|
1371
|
+
/** Convert angles from degrees to radians. */
|
|
1372
|
+
declare function deg2rad(x: ArrayLike): Array;
|
|
1373
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
1374
|
+
declare const radians: typeof deg2rad;
|
|
1375
|
+
/** Convert angles from radians to degrees. */
|
|
1376
|
+
declare function rad2deg(x: ArrayLike): Array;
|
|
1377
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
1378
|
+
declare const degrees: typeof rad2deg;
|
|
1379
|
+
/**
|
|
1380
|
+
* @function
|
|
1381
|
+
* Computes first array raised to power of second array, element-wise.
|
|
1382
|
+
*/
|
|
1383
|
+
declare const power: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1384
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
1385
|
+
declare const pow: OwnedFunction<(x1: ArrayLike, x2: ArrayLike) => Array>;
|
|
1386
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
1387
|
+
declare const cbrt: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1388
|
+
/**
|
|
1389
|
+
* @function
|
|
1390
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
1391
|
+
*
|
|
1392
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1393
|
+
*/
|
|
1394
|
+
declare const sinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1395
|
+
/**
|
|
1396
|
+
* @function
|
|
1397
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
1398
|
+
*
|
|
1399
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1400
|
+
*/
|
|
1401
|
+
declare const cosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1402
|
+
/**
|
|
1403
|
+
* @function
|
|
1404
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
1405
|
+
*
|
|
1406
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1407
|
+
*/
|
|
1408
|
+
declare const tanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1409
|
+
/**
|
|
1410
|
+
* @function
|
|
1411
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
1412
|
+
*
|
|
1413
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
1414
|
+
*/
|
|
1415
|
+
declare const arcsinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1416
|
+
/**
|
|
1417
|
+
* @function
|
|
1418
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
1419
|
+
*
|
|
1420
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
1421
|
+
*/
|
|
1422
|
+
declare const arccosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1423
|
+
/**
|
|
1424
|
+
* @function
|
|
1425
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
1426
|
+
*
|
|
1427
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
1428
|
+
*/
|
|
1429
|
+
declare const arctanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1430
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
1431
|
+
declare const asinh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1432
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
1433
|
+
declare const acosh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1434
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
1435
|
+
declare const atanh: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1436
|
+
/**
|
|
1437
|
+
* Compute the variance of an array.
|
|
1438
|
+
*
|
|
1439
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
1440
|
+
* the specified axis.
|
|
1441
|
+
*
|
|
1442
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1443
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1444
|
+
*/
|
|
1445
|
+
declare function var_(x: ArrayLike, axis?: Axis, opts?: {
|
|
1446
|
+
mean?: ArrayLike;
|
|
1447
|
+
correction?: number;
|
|
1448
|
+
} & ReduceOpts): Array;
|
|
1449
|
+
/**
|
|
1450
|
+
* Compute the standard deviation of an array.
|
|
1451
|
+
*
|
|
1452
|
+
* The standard deviation is computed for the flattened array by default,
|
|
1453
|
+
* otherwise over the specified axis.
|
|
1454
|
+
*
|
|
1455
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
1456
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
1457
|
+
*/
|
|
1458
|
+
declare function std(x: ArrayLike, axis?: Axis, opts?: {
|
|
1459
|
+
mean?: ArrayLike;
|
|
1460
|
+
correction?: number;
|
|
1461
|
+
} & ReduceOpts): Array;
|
|
1090
1462
|
//#endregion
|
|
1091
1463
|
//#region src/frontend/jaxpr.d.ts
|
|
1464
|
+
/**
|
|
1465
|
+
* Function callback with an associated dispose() method.
|
|
1466
|
+
*
|
|
1467
|
+
* The dispose() method should be called to clean up any tracer resources needed
|
|
1468
|
+
* by the function after the last time it is called.
|
|
1469
|
+
*/
|
|
1470
|
+
type OwnedFunction<F extends Function> = F & {
|
|
1471
|
+
dispose: () => void;
|
|
1472
|
+
};
|
|
1092
1473
|
/** Variable in a Jaxpr expression. */
|
|
1093
1474
|
declare class Var {
|
|
1094
1475
|
#private;
|
|
@@ -1149,8 +1530,38 @@ declare class Jaxpr implements FpHashable {
|
|
|
1149
1530
|
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1150
1531
|
flatten(): Jaxpr;
|
|
1151
1532
|
}
|
|
1533
|
+
/** @inline */
|
|
1534
|
+
type JitOpts = {
|
|
1535
|
+
staticArgnums?: number[];
|
|
1536
|
+
device?: Device;
|
|
1537
|
+
};
|
|
1538
|
+
declare namespace lax_d_exports {
|
|
1539
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1540
|
+
}
|
|
1541
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1542
|
+
/**
|
|
1543
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
1544
|
+
*
|
|
1545
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
1546
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
1547
|
+
*
|
|
1548
|
+
* Grouped convolutions are not supported right now.
|
|
1549
|
+
*/
|
|
1550
|
+
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1551
|
+
lhsDilation,
|
|
1552
|
+
rhsDilation
|
|
1553
|
+
}?: {
|
|
1554
|
+
lhsDilation?: number[];
|
|
1555
|
+
rhsDilation?: number[];
|
|
1556
|
+
}): Array;
|
|
1557
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1558
|
+
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
1559
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1560
|
+
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1561
|
+
/** Reduce a computation over padded windows. */
|
|
1562
|
+
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1152
1563
|
declare namespace nn_d_exports {
|
|
1153
|
-
export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
1564
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logmeanexp, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, squareplus, standardize, swish };
|
|
1154
1565
|
}
|
|
1155
1566
|
/**
|
|
1156
1567
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -1182,6 +1593,7 @@ declare function softplus(x: ArrayLike): Array;
|
|
|
1182
1593
|
*/
|
|
1183
1594
|
declare function softSign(x: ArrayLike): Array;
|
|
1184
1595
|
/**
|
|
1596
|
+
* @function
|
|
1185
1597
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1186
1598
|
* Swish, computed element-wise:
|
|
1187
1599
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1190,8 +1602,9 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1190
1602
|
*
|
|
1191
1603
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1192
1604
|
*/
|
|
1193
|
-
declare
|
|
1605
|
+
declare const silu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1194
1606
|
/**
|
|
1607
|
+
* @function
|
|
1195
1608
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1196
1609
|
* Swish, computed element-wise:
|
|
1197
1610
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -1200,14 +1613,67 @@ declare function silu(x: ArrayLike): Array;
|
|
|
1200
1613
|
*
|
|
1201
1614
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1202
1615
|
*/
|
|
1203
|
-
declare const swish:
|
|
1616
|
+
declare const swish: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1204
1617
|
/**
|
|
1205
1618
|
* Log-sigmoid activation function, computed element-wise:
|
|
1206
1619
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
1207
1620
|
*/
|
|
1208
1621
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1209
|
-
/**
|
|
1622
|
+
/**
|
|
1623
|
+
* @function
|
|
1624
|
+
* Identity activation function. Returns the argument unmodified.
|
|
1625
|
+
*/
|
|
1210
1626
|
declare const identity: (x: ArrayLike) => Array;
|
|
1627
|
+
/** Leaky rectified linear (ReLU) activation function */
|
|
1628
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: ArrayLike): Array;
|
|
1629
|
+
/**
|
|
1630
|
+
* Exponential linear unit activation function.
|
|
1631
|
+
*
|
|
1632
|
+
* Computes the element-wise function:
|
|
1633
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1634
|
+
*/
|
|
1635
|
+
declare function elu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1636
|
+
/**
|
|
1637
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
1638
|
+
*
|
|
1639
|
+
* Computes the element-wise function:
|
|
1640
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1641
|
+
*/
|
|
1642
|
+
declare function celu(x: ArrayLike, alpha?: ArrayLike): Array;
|
|
1643
|
+
/**
|
|
1644
|
+
* @function
|
|
1645
|
+
* Gaussion error linear unit (GELU) activation function.
|
|
1646
|
+
*
|
|
1647
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
1648
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
1649
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1650
|
+
*
|
|
1651
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1652
|
+
*
|
|
1653
|
+
* This will be improved in the future.
|
|
1654
|
+
*/
|
|
1655
|
+
declare const gelu: OwnedFunction<(x: ArrayLike) => Array>;
|
|
1656
|
+
/**
|
|
1657
|
+
* Gated linear unit (GLU) activation function.
|
|
1658
|
+
*
|
|
1659
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
1660
|
+
* computes `a * sigmoid(b)`.
|
|
1661
|
+
*/
|
|
1662
|
+
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1663
|
+
/**
|
|
1664
|
+
* Squareplus activation function.
|
|
1665
|
+
*
|
|
1666
|
+
* Computes the element-wise function:
|
|
1667
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
1668
|
+
*/
|
|
1669
|
+
declare function squareplus(x: ArrayLike, b?: ArrayLike): Array;
|
|
1670
|
+
/**
|
|
1671
|
+
* Mish activation function.
|
|
1672
|
+
*
|
|
1673
|
+
* Computes the element-wise function:
|
|
1674
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
1675
|
+
*/
|
|
1676
|
+
declare function mish(x: ArrayLike): Array;
|
|
1211
1677
|
/**
|
|
1212
1678
|
* Softmax function. Computes the function which rescales elements to the range
|
|
1213
1679
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -1216,7 +1682,7 @@ declare const identity: (x: ArrayLike) => Array;
|
|
|
1216
1682
|
*
|
|
1217
1683
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
1218
1684
|
*/
|
|
1219
|
-
declare function softmax(x: ArrayLike, axis?:
|
|
1685
|
+
declare function softmax(x: ArrayLike, axis?: Axis): Array;
|
|
1220
1686
|
/**
|
|
1221
1687
|
* Log-Softmax function.
|
|
1222
1688
|
*
|
|
@@ -1225,7 +1691,7 @@ declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1225
1691
|
*
|
|
1226
1692
|
* If `axis` is not specified, it defaults to the last axis.
|
|
1227
1693
|
*/
|
|
1228
|
-
declare function logSoftmax(x: ArrayLike, axis?:
|
|
1694
|
+
declare function logSoftmax(x: ArrayLike, axis?: Axis): Array;
|
|
1229
1695
|
/**
|
|
1230
1696
|
* Log-sum-exp reduction. Also a multivariate version of `softplus`.
|
|
1231
1697
|
*
|
|
@@ -1234,7 +1700,22 @@ declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1234
1700
|
*
|
|
1235
1701
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1236
1702
|
*/
|
|
1237
|
-
declare function logsumexp(x: ArrayLike, axis?:
|
|
1703
|
+
declare function logsumexp(x: ArrayLike, axis?: Axis): Array;
|
|
1704
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
1705
|
+
declare function logmeanexp(x: ArrayLike, axis?: Axis): Array;
|
|
1706
|
+
/**
|
|
1707
|
+
* Standardizes input to zero mean and unit variance.
|
|
1708
|
+
*
|
|
1709
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
1710
|
+
* axis, or `null` to standardize over all elements.
|
|
1711
|
+
*
|
|
1712
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
1713
|
+
*/
|
|
1714
|
+
declare function standardize(x: ArrayLike, axis?: Axis, opts?: {
|
|
1715
|
+
mean?: ArrayLike;
|
|
1716
|
+
variance?: ArrayLike;
|
|
1717
|
+
epsilon?: ArrayLike;
|
|
1718
|
+
}): Array;
|
|
1238
1719
|
/**
|
|
1239
1720
|
* One-hot encodes the given indices.
|
|
1240
1721
|
*
|
|
@@ -1253,7 +1734,7 @@ declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
|
|
|
1253
1734
|
*/
|
|
1254
1735
|
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1255
1736
|
declare namespace random_d_exports {
|
|
1256
|
-
export { bits, key, split, uniform };
|
|
1737
|
+
export { bernoulli, bits, exponential, key, normal, split, uniform };
|
|
1257
1738
|
}
|
|
1258
1739
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1259
1740
|
declare function key(seed: number): Array;
|
|
@@ -1269,40 +1750,108 @@ declare function uniform(key: Array, shape?: number[], {
|
|
|
1269
1750
|
minval?: number;
|
|
1270
1751
|
maxval?: number;
|
|
1271
1752
|
}): Array;
|
|
1753
|
+
/**
|
|
1754
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
1755
|
+
*
|
|
1756
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
1757
|
+
* and must be broadcastable to `shape`.
|
|
1758
|
+
*/
|
|
1759
|
+
declare function bernoulli(key: Array, p?: ArrayLike, shape?: number[]): Array;
|
|
1760
|
+
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
1761
|
+
declare function exponential(key: Array, shape?: number[]): Array;
|
|
1762
|
+
/**
|
|
1763
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
1764
|
+
*
|
|
1765
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
1766
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
1767
|
+
* bitwise identical to JAX.
|
|
1768
|
+
*/
|
|
1769
|
+
declare function normal(key: Array, shape?: number[]): Array;
|
|
1272
1770
|
//#endregion
|
|
1273
1771
|
//#region src/index.d.ts
|
|
1274
|
-
/**
|
|
1275
|
-
|
|
1276
|
-
|
|
1277
|
-
|
|
1278
|
-
|
|
1279
|
-
|
|
1280
|
-
|
|
1281
|
-
|
|
1282
|
-
|
|
1283
|
-
declare const
|
|
1772
|
+
/**
|
|
1773
|
+
* @function
|
|
1774
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
1775
|
+
*/
|
|
1776
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1777
|
+
/**
|
|
1778
|
+
* @function
|
|
1779
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
1780
|
+
*/
|
|
1781
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1782
|
+
/**
|
|
1783
|
+
* @function
|
|
1784
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
1785
|
+
*/
|
|
1786
|
+
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1787
|
+
/**
|
|
1788
|
+
* @function
|
|
1789
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
1790
|
+
*/
|
|
1791
|
+
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1284
1792
|
jaxpr: Jaxpr;
|
|
1285
1793
|
consts: Array[];
|
|
1286
1794
|
treedef: JsTreeDef;
|
|
1287
1795
|
};
|
|
1288
|
-
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
|
|
1289
1796
|
/**
|
|
1797
|
+
* @function
|
|
1798
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1799
|
+
*
|
|
1800
|
+
* The function will be compiled the first time it is called with a set of
|
|
1801
|
+
* argument shapes.
|
|
1802
|
+
*
|
|
1803
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
1804
|
+
* calls to free memory associated with array constants.
|
|
1805
|
+
*
|
|
1806
|
+
* **Options:**
|
|
1807
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1808
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
1809
|
+
* and different values will trigger recompilation.
|
|
1810
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
1811
|
+
* computation will be placed on the default device.
|
|
1812
|
+
*/
|
|
1813
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => OwnedFunction<(...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>>;
|
|
1814
|
+
/**
|
|
1815
|
+
* @function
|
|
1290
1816
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1291
1817
|
* partial evaluation.
|
|
1292
1818
|
*/
|
|
1293
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1294
|
-
/** Calculate the reverse-mode vector-Jacobian product for a function. */
|
|
1295
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1819
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1296
1820
|
/**
|
|
1821
|
+
* @function
|
|
1822
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
1823
|
+
*/
|
|
1824
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1825
|
+
/**
|
|
1826
|
+
* @function
|
|
1297
1827
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1298
1828
|
* first argument.
|
|
1299
1829
|
*/
|
|
1300
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1301
|
-
/**
|
|
1302
|
-
|
|
1303
|
-
|
|
1830
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1831
|
+
/**
|
|
1832
|
+
* @function
|
|
1833
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
1834
|
+
*/
|
|
1835
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1836
|
+
/**
|
|
1837
|
+
* @function
|
|
1838
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
1839
|
+
*/
|
|
1304
1840
|
declare const jacrev: typeof jacfwd;
|
|
1305
|
-
/**
|
|
1306
|
-
|
|
1841
|
+
/**
|
|
1842
|
+
* @function
|
|
1843
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
1844
|
+
*/
|
|
1845
|
+
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1846
|
+
/**
|
|
1847
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
1848
|
+
*
|
|
1849
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
1850
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
1851
|
+
* to avoid queueing up too many pending operations.
|
|
1852
|
+
*
|
|
1853
|
+
* Does not consume reference to the arrays.
|
|
1854
|
+
*/
|
|
1855
|
+
declare function blockUntilReady<T extends JsTree<any>>(x: T): Promise<T>;
|
|
1307
1856
|
//#endregion
|
|
1308
|
-
export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random,
|
|
1857
|
+
export { Array, DType, type Device, Jaxpr, type JsTree, type JsTreeDef, type OwnedFunction, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDebug, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|