@jax-js/jax 0.0.1 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +21 -15
- package/dist/backend-BqDtPGaR.js +3673 -0
- package/dist/backend-D2C4MJRP.cjs +3937 -0
- package/dist/chunk-Cl8Af3a2.js +11 -0
- package/dist/index.cjs +4705 -6069
- package/dist/index.d.cts +1110 -665
- package/dist/index.d.ts +1107 -665
- package/dist/index.js +4605 -3474
- package/dist/webgpu-CNg9JGva.js +612 -0
- package/dist/webgpu-fqhx41TC.cjs +612 -0
- package/package.json +30 -24
- package/dist/chunk-B2GFURUN.js +0 -1978
- package/dist/webgpu-QNXDOQZP.js +0 -559
package/dist/index.d.cts
CHANGED
|
@@ -1,24 +1,29 @@
|
|
|
1
|
-
|
|
2
|
-
* @file Lazy shape tracking for multidimensional tensors.
|
|
3
|
-
*
|
|
4
|
-
* This module provides an immutable `View` class that can be used to calculate
|
|
5
|
-
* shapes of arrays as operations are applied to them, lazily.
|
|
6
|
-
*
|
|
7
|
-
* Some operations like reshape() may not be representable with a single view,
|
|
8
|
-
* for instance, because composing reshape() with shrink() leads to a
|
|
9
|
-
* non-contiguous range of memory locations. This is why `ShapeTracker` is a
|
|
10
|
-
* list of views.
|
|
11
|
-
*
|
|
12
|
-
* Indexing into a `ShapeTracker` or `View` can be folded into shader code.
|
|
13
|
-
*
|
|
14
|
-
* Originally based on tinygrad's implementation of shape tracking in the
|
|
15
|
-
* `tinygrad.shape` module. But this version is simplified a bit. I'm not really
|
|
16
|
-
* trying to innovate on shape tracking in this library, so if I have doubts on
|
|
17
|
-
* something, it'll just be copied from tinygrad (with comments).
|
|
18
|
-
*
|
|
19
|
-
* This file is a bit longer than the original, since Python is more concise.
|
|
20
|
-
*/
|
|
1
|
+
import "node:module";
|
|
21
2
|
|
|
3
|
+
//#region rolldown:runtime
|
|
4
|
+
//#endregion
|
|
5
|
+
//#region src/pprint.d.ts
|
|
6
|
+
/** General class for pretty-printing expressions with indentation. */
|
|
7
|
+
declare class PPrint {
|
|
8
|
+
readonly indents: number[];
|
|
9
|
+
readonly lines: string[];
|
|
10
|
+
constructor(indents: number[], lines: string[]);
|
|
11
|
+
/** Add a fixed amount of indentation to each line. */
|
|
12
|
+
indent(spaces: number): PPrint;
|
|
13
|
+
/** Concatenate pretty-printed expressions with newlines. */
|
|
14
|
+
concat(...items: PPrint[]): PPrint;
|
|
15
|
+
/** Stack one block to the right of another one, sharing 1 common line. */
|
|
16
|
+
stack(other: PPrint): PPrint;
|
|
17
|
+
/** Combine this block of lines into a formatted string. */
|
|
18
|
+
toString(): string;
|
|
19
|
+
static pp(s: Stringable): PPrint;
|
|
20
|
+
}
|
|
21
|
+
interface Stringable {
|
|
22
|
+
toString(): string;
|
|
23
|
+
}
|
|
24
|
+
//#endregion
|
|
25
|
+
//#region src/shape.d.ts
|
|
26
|
+
/** @inline */
|
|
22
27
|
type Pair = [number, number];
|
|
23
28
|
/**
|
|
24
29
|
* A multidimensional view into memory. An array can be thought of as the
|
|
@@ -30,49 +35,56 @@ type Pair = [number, number];
|
|
|
30
35
|
* 2. Otherwise, look at this memory address: offset + ∑(strides[i] * dim[i]).
|
|
31
36
|
*/
|
|
32
37
|
declare class View {
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
38
|
+
#private;
|
|
39
|
+
/** The shape of the view (size of each dimension). */
|
|
40
|
+
readonly shape: number[];
|
|
41
|
+
/** How many indices to move in buffer for each hop in one dimension. */
|
|
42
|
+
readonly strides: number[];
|
|
43
|
+
/** Offset from the start of the buffer. */
|
|
44
|
+
readonly offset: number;
|
|
45
|
+
/** Masked out subarray where data is read. All other data is zeroed. */
|
|
46
|
+
readonly mask: Pair[] | null;
|
|
47
|
+
private constructor();
|
|
48
|
+
static create(shape: number[], strides?: number[], offset?: number, mask?: Pair[] | null): View;
|
|
49
|
+
get ndim(): number;
|
|
50
|
+
get size(): number;
|
|
51
|
+
/** Whether this is a default, contiguous, unaltered view of the data (identity). */
|
|
52
|
+
get contiguous(): boolean;
|
|
53
|
+
/** Return the range of data being indexed in this view, or [0, 0] if none. */
|
|
54
|
+
dataRange(): [number, number];
|
|
55
|
+
/** Produce an AluExp for evaluating this view at an index. */
|
|
56
|
+
toAluExp(idxs: AluExp[]): [AluExp, AluExp];
|
|
57
|
+
/**
|
|
58
|
+
* Try to compose this view with another one. `this` view is applied first,
|
|
59
|
+
* followed by the argument. If this is not possible for the specific views,
|
|
60
|
+
* return `null` instead.
|
|
61
|
+
*
|
|
62
|
+
* If composable, return a combined view with the same shape as `v1`.
|
|
63
|
+
*
|
|
64
|
+
* This is very tricky. The shapes of v1 and v2 may be different, and in that
|
|
65
|
+
* case, we do some math to figure out whether they're compatible.
|
|
66
|
+
*/
|
|
67
|
+
compose(v1: View): View | null;
|
|
68
|
+
/** Attempt to simplify this view into a smaller reshaped form. */
|
|
69
|
+
minify(): View;
|
|
70
|
+
/** Pad the view with zeros on each dimension. */
|
|
71
|
+
pad(arg: Pair[]): View;
|
|
72
|
+
/** Shrink the view by taking a subarray. */
|
|
73
|
+
shrink(arg: Pair[]): View;
|
|
74
|
+
/** Expand one or more axes with length "1" by repeating the data. */
|
|
75
|
+
expand(newShape: number[]): View;
|
|
76
|
+
/** Permute the axes of an array. */
|
|
77
|
+
permute(axis: number[]): View;
|
|
78
|
+
/** Flip (reverse) one or more axes of the view. */
|
|
79
|
+
flip(arg: boolean[]): View;
|
|
80
|
+
/** Reshape the view into a new shape. */
|
|
81
|
+
reshape(newShape: number[]): View | null;
|
|
75
82
|
}
|
|
83
|
+
/**
|
|
84
|
+
* Find position of `offset` in each dimension within an existing shape. Like
|
|
85
|
+
* `numpy.unravel_index` in behavior.
|
|
86
|
+
*/
|
|
87
|
+
|
|
76
88
|
/**
|
|
77
89
|
* Array shape after applying movement operations, as a series of views.
|
|
78
90
|
*
|
|
@@ -80,31 +92,44 @@ declare class View {
|
|
|
80
92
|
* shape, then used as the virtual buffer for the next view.
|
|
81
93
|
*/
|
|
82
94
|
declare class ShapeTracker {
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
95
|
+
readonly views: View[];
|
|
96
|
+
constructor(views: View[]);
|
|
97
|
+
/** Compose this shape tracker with another, applying it after this one. */
|
|
98
|
+
compose(other: ShapeTracker): ShapeTracker;
|
|
99
|
+
static fromShape(shape: number[]): ShapeTracker;
|
|
100
|
+
get contiguous(): boolean;
|
|
101
|
+
get consecutive(): boolean;
|
|
102
|
+
get lastStrides(): number[];
|
|
103
|
+
get shape(): number[];
|
|
104
|
+
get size(): number;
|
|
105
|
+
toAluExp(idxs: AluExp[]): [AluExp, AluExp];
|
|
106
|
+
simplify(): ShapeTracker;
|
|
107
|
+
pad(arg: Pair[]): ShapeTracker;
|
|
108
|
+
shrink(arg: Pair[]): ShapeTracker;
|
|
109
|
+
expand(newShape: number[]): ShapeTracker;
|
|
110
|
+
permute(axis: number[]): ShapeTracker;
|
|
111
|
+
flip(arg: boolean[]): ShapeTracker;
|
|
112
|
+
reshape(newShape: number[]): ShapeTracker;
|
|
113
|
+
/** Broadcast along the given new axes, then expand the shape. */
|
|
114
|
+
broadcast(newShape: number[], axis: number[]): ShapeTracker;
|
|
115
|
+
/**
|
|
116
|
+
* Repeat data in each axis by a positive number of repetitions.
|
|
117
|
+
*
|
|
118
|
+
* - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
|
|
119
|
+
* - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
|
|
120
|
+
*/
|
|
121
|
+
repeat(reps: number[], tile?: boolean): ShapeTracker;
|
|
122
|
+
/** Move axis i to axis j. */
|
|
123
|
+
moveaxis(i: number, j: number): ShapeTracker;
|
|
124
|
+
/** Like pad(), but allows for negative values. */
|
|
125
|
+
padOrShrink(arg: Pair[]): ShapeTracker;
|
|
103
126
|
}
|
|
104
|
-
|
|
127
|
+
//#endregion
|
|
128
|
+
//#region src/utils.d.ts
|
|
129
|
+
/** @inline */
|
|
105
130
|
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
106
131
|
interface FpHashable {
|
|
107
|
-
|
|
132
|
+
hash(state: FpHash): void;
|
|
108
133
|
}
|
|
109
134
|
/**
|
|
110
135
|
* Polynomial hashes modulo p are good at avoiding collisions in expectation.
|
|
@@ -114,18 +139,24 @@ interface FpHashable {
|
|
|
114
139
|
* See https://en.wikipedia.org/wiki/Lagrange%27s_theorem_(number_theory)
|
|
115
140
|
*/
|
|
116
141
|
declare class FpHash {
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
142
|
+
#private;
|
|
143
|
+
value: bigint;
|
|
144
|
+
update(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): this;
|
|
145
|
+
static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
|
|
121
146
|
}
|
|
122
|
-
|
|
147
|
+
/** Run a function while caching it inline inside a `Map`. */
|
|
148
|
+
//#endregion
|
|
149
|
+
//#region src/alu.d.ts
|
|
150
|
+
/** A numerical data type for array contents. */
|
|
123
151
|
declare enum DType {
|
|
124
|
-
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
152
|
+
Float32 = "float32",
|
|
153
|
+
Int32 = "int32",
|
|
154
|
+
Uint32 = "uint32",
|
|
155
|
+
Bool = "bool",
|
|
156
|
+
Float16 = "float16",
|
|
128
157
|
}
|
|
158
|
+
/** @inline */
|
|
159
|
+
type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
|
|
129
160
|
/**
|
|
130
161
|
* Mathematical expression on scalar values.
|
|
131
162
|
*
|
|
@@ -134,94 +165,127 @@ declare enum DType {
|
|
|
134
165
|
* graph rewrite engine.
|
|
135
166
|
*/
|
|
136
167
|
declare class AluExp implements FpHashable {
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
168
|
+
#private;
|
|
169
|
+
readonly op: AluOp;
|
|
170
|
+
readonly dtype: DType;
|
|
171
|
+
readonly src: AluExp[];
|
|
172
|
+
readonly arg: any;
|
|
173
|
+
constructor(op: AluOp, dtype: DType, src: AluExp[], arg?: any);
|
|
174
|
+
static add(a: AluExp, b: AluExp): AluExp;
|
|
175
|
+
static sub(a: AluExp, b: AluExp): AluExp;
|
|
176
|
+
static mul(a: AluExp, b: AluExp): AluExp;
|
|
177
|
+
static idiv(a: AluExp, b: AluExp): AluExp;
|
|
178
|
+
static mod(a: AluExp, b: AluExp): AluExp;
|
|
179
|
+
static min(a: AluExp, b: AluExp): AluExp;
|
|
180
|
+
static max(a: AluExp, b: AluExp): AluExp;
|
|
181
|
+
static sin(a: AluExp): AluExp;
|
|
182
|
+
static cos(a: AluExp): AluExp;
|
|
183
|
+
static exp(a: AluExp): AluExp;
|
|
184
|
+
static log(a: AluExp): AluExp;
|
|
185
|
+
static sqrt(a: AluExp): AluExp;
|
|
186
|
+
static reciprocal(a: AluExp): AluExp;
|
|
187
|
+
static cast(dtype: DType, a: AluExp): AluExp;
|
|
188
|
+
static bitcast(dtype: DType, a: AluExp): AluExp;
|
|
189
|
+
static threefry2x32(k0: AluExp, k1: AluExp, c0: AluExp, c1: AluExp, mode?: "xor" | 0 | 1): AluExp;
|
|
190
|
+
static cmplt(a: AluExp, b: AluExp): AluExp;
|
|
191
|
+
static cmpne(a: AluExp, b: AluExp): AluExp;
|
|
192
|
+
static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
|
|
193
|
+
static const(dtype: DType, value: any): AluExp;
|
|
194
|
+
static special(dtype: DType, name: string, n: number): AluExp;
|
|
195
|
+
static variable(dtype: DType, name: string): AluExp;
|
|
196
|
+
static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
|
|
197
|
+
static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
|
|
198
|
+
static f32(value: number): AluExp;
|
|
199
|
+
static i32(value: number): AluExp;
|
|
200
|
+
static u32(value: number): AluExp;
|
|
201
|
+
static bool(value: boolean): AluExp;
|
|
202
|
+
static f16(value: number): AluExp;
|
|
203
|
+
not(): AluExp;
|
|
204
|
+
/** Compute a reasonable expression hash with low collision rate. */
|
|
205
|
+
getHash(): bigint;
|
|
206
|
+
hash(state: FpHash): void;
|
|
207
|
+
/** Substitute variables in this AluExp to values. */
|
|
208
|
+
substitute(variables: Record<string, AluExp>): AluExp;
|
|
209
|
+
/** Reindex gid values in this expression as needed. */
|
|
210
|
+
reindexGids(gidMap: Map<number, number>): AluExp;
|
|
211
|
+
get min(): number;
|
|
212
|
+
get max(): number;
|
|
213
|
+
/** Largest known integer that divides self. */
|
|
214
|
+
constFactor(): number;
|
|
215
|
+
/**
|
|
216
|
+
* Checks if divisible by an integer v and returns the quotient if it is, or
|
|
217
|
+
* `null` if it's not divisible.
|
|
218
|
+
*/
|
|
219
|
+
divides(v: number): AluExp | null;
|
|
220
|
+
/**
|
|
221
|
+
* Get all expressions by deeply matching an operation.
|
|
222
|
+
*
|
|
223
|
+
* For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
|
|
224
|
+
*/
|
|
225
|
+
splitOp(sep: AluOp): IterableIterator<AluExp>;
|
|
226
|
+
/**
|
|
227
|
+
* Simplify the expression by replacing any known patterns and deduping
|
|
228
|
+
* identical subexpressions.
|
|
229
|
+
*/
|
|
230
|
+
simplify(cache?: Map<bigint, AluExp>): AluExp;
|
|
231
|
+
/** Resolve this to a value, or `undefined` if not possible. */
|
|
232
|
+
resolve(): any | undefined;
|
|
233
|
+
/**
|
|
234
|
+
* Evaluate the expression on CPU, returning the result.
|
|
235
|
+
*
|
|
236
|
+
* Typically you would compile the AluExp as a representation to a lower-level
|
|
237
|
+
* language. This is just to define the semantics and help debug.
|
|
238
|
+
*
|
|
239
|
+
* Note that the representation of Bool is as a number (0 or 1) here.
|
|
240
|
+
*/
|
|
241
|
+
evaluate(context: Record<string, any>, globals?: (gid: number, bufidx: number) => any): number;
|
|
242
|
+
/** Get this expression in debug format as a string. */
|
|
243
|
+
toString(): string;
|
|
244
|
+
/** Generic fold() operation with a reducer over the expression tree. */
|
|
245
|
+
fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
|
|
246
|
+
/** Check if any expression in the tree satisfies a predicate. */
|
|
247
|
+
some(predicate: (exp: AluExp) => boolean): boolean;
|
|
248
|
+
/** Rewrite the expression recursively using a visitor. */
|
|
249
|
+
rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
|
|
250
|
+
/** Collect all nodes that satisfy a predicate. */
|
|
251
|
+
collect(predicate: (exp: AluExp) => boolean): AluExp[];
|
|
252
|
+
/** Produce a list of all distinct AluOp in this expression. */
|
|
253
|
+
distinctOps(): Set<AluOp>;
|
|
254
|
+
/** Rewrite GlobalView operations to GlobalIndex operations. */
|
|
255
|
+
rewriteGlobalViews(): AluExp;
|
|
201
256
|
}
|
|
202
257
|
/** Symbolic form for each mathematical operation. */
|
|
203
258
|
declare enum AluOp {
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
|
|
259
|
+
Add = "Add",
|
|
260
|
+
Sub = "Sub",
|
|
261
|
+
Mul = "Mul",
|
|
262
|
+
Idiv = "Idiv",
|
|
263
|
+
Mod = "Mod",
|
|
264
|
+
Min = "Min",
|
|
265
|
+
Max = "Max",
|
|
266
|
+
Sin = "Sin",
|
|
267
|
+
Cos = "Cos",
|
|
268
|
+
Exp = "Exp",
|
|
269
|
+
Log = "Log",
|
|
270
|
+
Sqrt = "Sqrt",
|
|
271
|
+
Reciprocal = "Reciprocal",
|
|
272
|
+
Cast = "Cast",
|
|
273
|
+
Bitcast = "Bitcast",
|
|
274
|
+
Cmplt = "Cmplt",
|
|
275
|
+
Cmpne = "Cmpne",
|
|
276
|
+
Where = "Where",
|
|
277
|
+
// Ternary operator: `cond ? a : b`
|
|
278
|
+
Threefry2x32 = "Threefry2x32",
|
|
279
|
+
// PRNG operation, arg = 'xor' | 0 | 1
|
|
280
|
+
Const = "Const",
|
|
281
|
+
// arg = value
|
|
282
|
+
Special = "Special",
|
|
283
|
+
// arg = [variable, n]
|
|
284
|
+
Variable = "Variable",
|
|
285
|
+
// arg = variable
|
|
286
|
+
GlobalIndex = "GlobalIndex",
|
|
287
|
+
// arg = [gid, len]; src = [bufidx]
|
|
288
|
+
GlobalView = "GlobalView",
|
|
225
289
|
}
|
|
226
290
|
/**
|
|
227
291
|
* Description of a kernel to be compiled.
|
|
@@ -231,24 +295,26 @@ declare enum AluOp {
|
|
|
231
295
|
* indexing into a buffer.
|
|
232
296
|
*/
|
|
233
297
|
declare class Kernel implements FpHashable {
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
|
|
245
|
-
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
298
|
+
/** Number of global arguments / arrays. */
|
|
299
|
+
readonly nargs: number;
|
|
300
|
+
/** Size of the result array in element count. */
|
|
301
|
+
readonly size: number;
|
|
302
|
+
/** Expression to be evaluated. */
|
|
303
|
+
readonly exp: AluExp;
|
|
304
|
+
/** Optional reduction to be performed. */
|
|
305
|
+
readonly reduction?: Reduction | undefined;
|
|
306
|
+
constructor(/** Number of global arguments / arrays. */
|
|
307
|
+
nargs: number, /** Size of the result array in element count. */
|
|
308
|
+
size: number, /** Expression to be evaluated. */
|
|
309
|
+
exp: AluExp, /** Optional reduction to be performed. */
|
|
310
|
+
reduction?: Reduction | undefined);
|
|
311
|
+
hash(state: FpHash): void;
|
|
312
|
+
pprint(): PPrint;
|
|
313
|
+
toString(): string;
|
|
314
|
+
/** The dtype of the values output by this kernel. */
|
|
315
|
+
get dtype(): DType;
|
|
316
|
+
/** The number of bytes in the output array when evaluating this kernel. */
|
|
317
|
+
get bytes(): number;
|
|
252
318
|
}
|
|
253
319
|
/**
|
|
254
320
|
* Description of a reduction.
|
|
@@ -266,42 +332,30 @@ declare class Kernel implements FpHashable {
|
|
|
266
332
|
* at this level since they depend on GPU, versus CPU or Wasm.
|
|
267
333
|
*/
|
|
268
334
|
declare class Reduction implements FpHashable {
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
get identity(): any;
|
|
289
|
-
/** Evaluate this operation on CPU. */
|
|
290
|
-
evaluate(...values: any): any;
|
|
335
|
+
/** Data type of the values being reduced over. */
|
|
336
|
+
readonly dtype: DType;
|
|
337
|
+
/** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
|
|
338
|
+
readonly op: AluOp;
|
|
339
|
+
/** Size of the reduction axis. */
|
|
340
|
+
readonly size: number;
|
|
341
|
+
/** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
342
|
+
readonly epilogue: AluExp;
|
|
343
|
+
constructor(/** Data type of the values being reduced over. */
|
|
344
|
+
dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
|
|
345
|
+
op: AluOp, /** Size of the reduction axis. */
|
|
346
|
+
size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
|
|
347
|
+
epilogue?: AluExp);
|
|
348
|
+
hash(state: FpHash): void;
|
|
349
|
+
toString(): string;
|
|
350
|
+
/** Get the identity for this reduction operation. */
|
|
351
|
+
get identity(): any;
|
|
352
|
+
/** Evaluate this operation on CPU. */
|
|
353
|
+
evaluate(...values: any): any;
|
|
291
354
|
}
|
|
292
|
-
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
* Think of each backend as a _connector_ to a specific hardware or software
|
|
297
|
-
* implementation of the array API.
|
|
298
|
-
*
|
|
299
|
-
* Backends do not share any of the built-in operational semantics of the
|
|
300
|
-
* library. This is a private API. You must allocate and free buffers manually,
|
|
301
|
-
* and dispatch happens on the level of each shader. Buffers are untyped.
|
|
302
|
-
*/
|
|
303
|
-
|
|
304
|
-
type Device = "cpu" | "webgpu";
|
|
355
|
+
/** Expression for accessing `indices` in input array with the given shape. */
|
|
356
|
+
//#endregion
|
|
357
|
+
//#region src/backend.d.ts
|
|
358
|
+
type Device = "cpu" | "wasm" | "webgpu";
|
|
305
359
|
declare const devices: Device[];
|
|
306
360
|
/** Set the default device backend (must be initialized). */
|
|
307
361
|
declare function setDevice(device: Device): void;
|
|
@@ -313,75 +367,78 @@ declare function setDevice(device: Device): void;
|
|
|
313
367
|
* available backends.
|
|
314
368
|
*/
|
|
315
369
|
declare function init(...devicesToInit: Device[]): Promise<Device[]>;
|
|
370
|
+
/** Retrieve a backend that has been initialized. */
|
|
371
|
+
|
|
316
372
|
/** Unique identifier for an allocated, on-device buffer. */
|
|
317
373
|
type Slot = number;
|
|
318
374
|
/** A device backend. */
|
|
319
375
|
interface Backend {
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
376
|
+
/** The name of the backend as a string. */
|
|
377
|
+
readonly type: Device;
|
|
378
|
+
/** Maximum number of arguments per dispatched kernel. */
|
|
379
|
+
readonly maxArgs: number;
|
|
380
|
+
/** Allocate a new slot with reference count 1. */
|
|
381
|
+
malloc(size: number, initialData?: Uint8Array): Slot;
|
|
382
|
+
/** Increment the reference count of the slot. */
|
|
383
|
+
incRef(slot: Slot): void;
|
|
384
|
+
/**
|
|
385
|
+
* Decrement the reference count of the slot. If the reference count reaches
|
|
386
|
+
* zero, it is freed. This should throw if the slot was already freed.
|
|
387
|
+
*/
|
|
388
|
+
decRef(slot: Slot): void;
|
|
389
|
+
/** Read a range of bytes from a buffer. */
|
|
390
|
+
read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
|
|
391
|
+
/** Read a range of bytes from a buffer, blocking variant. */
|
|
392
|
+
readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
|
|
393
|
+
/** Prepare an expression to be executed later. */
|
|
394
|
+
prepare(kernel: Kernel): Promise<Executable>;
|
|
395
|
+
/** Prepare an expression to be executed later, blocking variant. */
|
|
396
|
+
prepareSync(kernel: Kernel): Executable;
|
|
397
|
+
/**
|
|
398
|
+
* Run a backend operation that was previously prepared.
|
|
399
|
+
*
|
|
400
|
+
* The operation may not run immediately, but operations are guaranteed to run
|
|
401
|
+
* in the dispatch order. Also, `read()` will wait for all pending operations
|
|
402
|
+
* on that slot to finish.
|
|
403
|
+
*/
|
|
404
|
+
dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
|
|
349
405
|
}
|
|
350
406
|
declare class Executable<T = any> {
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
407
|
+
readonly kernel: Kernel;
|
|
408
|
+
/** Extra data specific to the backend running this kernel. */
|
|
409
|
+
readonly data: T;
|
|
410
|
+
constructor(kernel: Kernel, /** Extra data specific to the backend running this kernel. */
|
|
411
|
+
data: T);
|
|
412
|
+
}
|
|
413
|
+
declare namespace tree_d_exports {
|
|
414
|
+
export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
|
|
357
415
|
}
|
|
358
|
-
|
|
359
|
-
/** @file Utilities for working with tree-like container data structures ("pytrees"). */
|
|
360
416
|
declare enum NodeType {
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
417
|
+
Array = "Array",
|
|
418
|
+
Object = "Object",
|
|
419
|
+
Leaf = "Leaf",
|
|
364
420
|
}
|
|
421
|
+
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
365
422
|
type JsTree<T> = T | JsTree<T>[] | {
|
|
366
|
-
|
|
423
|
+
[key: string]: JsTree<T>;
|
|
367
424
|
};
|
|
368
|
-
type
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
/** Analog to the JAX "pytree" object, but for JavaScript. */
|
|
425
|
+
type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
|
|
426
|
+
type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
427
|
+
/** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
|
|
428
|
+
type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
|
|
429
|
+
/** Represents the structure of a JsTree. */
|
|
374
430
|
declare class JsTreeDef {
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
431
|
+
readonly nodeType: NodeType;
|
|
432
|
+
readonly nodeMetadata: any;
|
|
433
|
+
readonly childTreedefs: JsTreeDef[];
|
|
434
|
+
static leaf: JsTreeDef;
|
|
435
|
+
constructor(nodeType: NodeType, nodeMetadata: any,
|
|
436
|
+
// Must be comparable with deepEqual.
|
|
437
|
+
childTreedefs: JsTreeDef[]);
|
|
438
|
+
/** Returns a string representation of this tree definition. */
|
|
439
|
+
toString(root?: boolean): string;
|
|
440
|
+
/** Compare this tree definition with another. */
|
|
441
|
+
equals(other: JsTreeDef): boolean;
|
|
385
442
|
}
|
|
386
443
|
/** Flatten a structured object, returning the tree definition. */
|
|
387
444
|
declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
|
|
@@ -391,151 +448,324 @@ declare function leaves<T>(tree: JsTree<T>): T[];
|
|
|
391
448
|
declare function structure<T>(tree: JsTree<T>): JsTreeDef;
|
|
392
449
|
/** Reconstruct a structured object from the flattened representation. */
|
|
393
450
|
declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
declare
|
|
398
|
-
|
|
399
|
-
|
|
400
|
-
|
|
401
|
-
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
406
|
-
|
|
451
|
+
/** Maps a multi-input function over pytree args to produce a new pytree. */
|
|
452
|
+
declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
|
|
453
|
+
/** Take a reference of every array in a tree. */
|
|
454
|
+
declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
|
|
455
|
+
/** Dispose every array in a tree. */
|
|
456
|
+
declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
|
|
457
|
+
//#endregion
|
|
458
|
+
//#region src/frontend/convolution.d.ts
|
|
459
|
+
/** Definition of a general dilated convolution. Should be valid on creation. */
|
|
460
|
+
interface ConvParams {
|
|
461
|
+
strides: number[];
|
|
462
|
+
padding: [number, number][];
|
|
463
|
+
lhsDilation: number[];
|
|
464
|
+
rhsDilation: number[];
|
|
407
465
|
}
|
|
466
|
+
/**
|
|
467
|
+
* Check that the shapes and parameters passed to convolution are valid.
|
|
468
|
+
*
|
|
469
|
+
* If the check succeeds, returns the output shape.
|
|
470
|
+
*/
|
|
408
471
|
|
|
409
|
-
|
|
410
|
-
|
|
472
|
+
//#endregion
|
|
473
|
+
//#region src/frontend/core.d.ts
|
|
474
|
+
/**
|
|
475
|
+
* Frontend primitive operations, which are lowered into Kernel objects before
|
|
476
|
+
* being dispatched to the backend.
|
|
477
|
+
*
|
|
478
|
+
* Any operation between arrays can be described in these parts. This is also
|
|
479
|
+
* the set of primitives that can occur in Jaxpr programs, and the level at
|
|
480
|
+
* which transformations like vmap, grad, and jvp occur. They are loosely based
|
|
481
|
+
* on [XLA](https://openxla.org/xla/operation_semantics).
|
|
482
|
+
*
|
|
483
|
+
* All n-ary operations support broadcasting, with NumPy semantics.
|
|
484
|
+
*/
|
|
411
485
|
declare enum Primitive {
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
486
|
+
Add = "add",
|
|
487
|
+
Mul = "mul",
|
|
488
|
+
Idiv = "idiv",
|
|
489
|
+
Neg = "neg",
|
|
490
|
+
Reciprocal = "reciprocal",
|
|
491
|
+
StopGradient = "stop_gradient",
|
|
492
|
+
Cast = "cast",
|
|
493
|
+
Bitcast = "bitcast",
|
|
494
|
+
RandomBits = "random_bits",
|
|
495
|
+
Sin = "sin",
|
|
496
|
+
Cos = "cos",
|
|
497
|
+
Exp = "exp",
|
|
498
|
+
Log = "log",
|
|
499
|
+
Sqrt = "sqrt",
|
|
500
|
+
Min = "min",
|
|
501
|
+
Max = "max",
|
|
502
|
+
Reduce = "reduce",
|
|
503
|
+
Dot = "dot",
|
|
504
|
+
// sum(x*y, axis=-1)
|
|
505
|
+
Conv = "conv",
|
|
506
|
+
// see lax.conv_general_dilated
|
|
507
|
+
Pool = "pool",
|
|
508
|
+
PoolTranspose = "pool_transpose",
|
|
509
|
+
Compare = "compare",
|
|
510
|
+
Where = "where",
|
|
511
|
+
Transpose = "transpose",
|
|
512
|
+
Broadcast = "broadcast",
|
|
513
|
+
Reshape = "reshape",
|
|
514
|
+
Flip = "flip",
|
|
515
|
+
Shrink = "shrink",
|
|
516
|
+
Pad = "pad",
|
|
517
|
+
Gather = "gather",
|
|
518
|
+
JitCall = "jit_call",
|
|
431
519
|
}
|
|
520
|
+
interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
|
|
521
|
+
[Primitive.Cast]: {
|
|
522
|
+
dtype: DType;
|
|
523
|
+
};
|
|
524
|
+
[Primitive.Bitcast]: {
|
|
525
|
+
dtype: DType;
|
|
526
|
+
};
|
|
527
|
+
[Primitive.Reduce]: {
|
|
528
|
+
op: AluOp;
|
|
529
|
+
axis: number[];
|
|
530
|
+
};
|
|
531
|
+
[Primitive.Conv]: ConvParams;
|
|
532
|
+
[Primitive.Pool]: {
|
|
533
|
+
window: number[];
|
|
534
|
+
strides: number[];
|
|
535
|
+
};
|
|
536
|
+
[Primitive.PoolTranspose]: {
|
|
537
|
+
inShape: number[];
|
|
538
|
+
window: number[];
|
|
539
|
+
strides: number[];
|
|
540
|
+
};
|
|
541
|
+
[Primitive.Compare]: {
|
|
542
|
+
op: CompareOp;
|
|
543
|
+
};
|
|
544
|
+
[Primitive.Transpose]: {
|
|
545
|
+
perm: number[];
|
|
546
|
+
};
|
|
547
|
+
[Primitive.Broadcast]: {
|
|
548
|
+
shape: number[];
|
|
549
|
+
axis: number[];
|
|
550
|
+
};
|
|
551
|
+
[Primitive.RandomBits]: {
|
|
552
|
+
shape: number[];
|
|
553
|
+
mode: "xor" | 0 | 1;
|
|
554
|
+
};
|
|
555
|
+
[Primitive.Reshape]: {
|
|
556
|
+
shape: number[];
|
|
557
|
+
};
|
|
558
|
+
[Primitive.Flip]: {
|
|
559
|
+
axis: number[];
|
|
560
|
+
};
|
|
561
|
+
[Primitive.Shrink]: {
|
|
562
|
+
slice: Pair[];
|
|
563
|
+
};
|
|
564
|
+
[Primitive.Pad]: {
|
|
565
|
+
width: Pair[];
|
|
566
|
+
};
|
|
567
|
+
[Primitive.Gather]: {
|
|
568
|
+
axis: number[];
|
|
569
|
+
outDim: number;
|
|
570
|
+
};
|
|
571
|
+
[Primitive.JitCall]: {
|
|
572
|
+
jaxpr: Jaxpr;
|
|
573
|
+
numConsts: number;
|
|
574
|
+
};
|
|
575
|
+
}
|
|
576
|
+
/** Type of parameters taken by each primitive. */
|
|
577
|
+
type PrimitiveParams<T extends Primitive> = T extends keyof PrimitiveParamsImpl ? PrimitiveParamsImpl[T] : Record<string, never>;
|
|
578
|
+
declare enum CompareOp {
|
|
579
|
+
Greater = "greater",
|
|
580
|
+
Less = "less",
|
|
581
|
+
Equal = "equal",
|
|
582
|
+
NotEqual = "not_equal",
|
|
583
|
+
GreaterEqual = "greater_equal",
|
|
584
|
+
LessEqual = "less_equal",
|
|
585
|
+
}
|
|
586
|
+
/** @inline */
|
|
587
|
+
type ReduceOpts = {
|
|
588
|
+
keepDims?: boolean;
|
|
589
|
+
};
|
|
432
590
|
type MainTrace = {
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
591
|
+
level: number;
|
|
592
|
+
traceType: new (main: MainTrace) => Trace;
|
|
593
|
+
globalData: any | null;
|
|
436
594
|
};
|
|
595
|
+
/**
|
|
596
|
+
* Push an interpreter onto the trace stack. Use this like:
|
|
597
|
+
* `using main = newMain(...);`
|
|
598
|
+
*/
|
|
599
|
+
|
|
437
600
|
type TracerValue = Tracer | number | boolean;
|
|
438
601
|
declare abstract class Trace {
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
602
|
+
readonly main: MainTrace;
|
|
603
|
+
constructor(main: MainTrace);
|
|
604
|
+
abstract pure(val: TracerValue): Tracer;
|
|
605
|
+
abstract lift(val: Tracer): Tracer;
|
|
606
|
+
abstract processPrimitive<P extends Primitive>(primitive: P, tracers: Tracer[], params: PrimitiveParams<P>): Tracer[];
|
|
444
607
|
}
|
|
445
608
|
interface AbstractValue {
|
|
446
|
-
|
|
447
|
-
|
|
609
|
+
shape: number[];
|
|
610
|
+
dtype: DType;
|
|
448
611
|
}
|
|
449
612
|
declare abstract class Tracer {
|
|
450
|
-
|
|
451
|
-
|
|
452
|
-
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
484
|
-
|
|
485
|
-
|
|
486
|
-
|
|
487
|
-
|
|
488
|
-
|
|
489
|
-
|
|
490
|
-
|
|
491
|
-
|
|
492
|
-
|
|
493
|
-
|
|
494
|
-
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
|
|
510
|
-
|
|
511
|
-
|
|
512
|
-
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
613
|
+
/** @ignore */
|
|
614
|
+
readonly _trace: Trace;
|
|
615
|
+
constructor(trace: Trace);
|
|
616
|
+
abstract get aval(): AbstractValue;
|
|
617
|
+
abstract toString(): string;
|
|
618
|
+
/**
|
|
619
|
+
* Access an array by reference, incrementing the reference count.
|
|
620
|
+
*
|
|
621
|
+
* jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
|
|
622
|
+
* Whenever you pass an array into a function, that function should consume
|
|
623
|
+
* the array, and it will no longer be usable. For example, if you had:
|
|
624
|
+
*
|
|
625
|
+
* ```
|
|
626
|
+
* const x = np.array([1, 2, 3]);
|
|
627
|
+
* const y = np.add(x, x);
|
|
628
|
+
* ```
|
|
629
|
+
*
|
|
630
|
+
* The second line does not work because the first parameter consumes `x`, and
|
|
631
|
+
* then the second parameter will already have been freed / disposed.
|
|
632
|
+
*
|
|
633
|
+
* To fix this, you can write:
|
|
634
|
+
*
|
|
635
|
+
* ```
|
|
636
|
+
* const y = np.add(x.ref, x);
|
|
637
|
+
* ```
|
|
638
|
+
*
|
|
639
|
+
* Under the hood, every access to `.ref` increments the internal reference
|
|
640
|
+
* count of the array. The reference count starts at 1. When it hits 0, the
|
|
641
|
+
* memory behind the array is freed.
|
|
642
|
+
*/
|
|
643
|
+
abstract get ref(): this;
|
|
644
|
+
/**
|
|
645
|
+
* Manually decrement the reference count of the array.
|
|
646
|
+
*
|
|
647
|
+
* Arrays are created with reference count 1. Whenever it is used as argument
|
|
648
|
+
* to a function or other operation, it is disposed (i.e., reference count
|
|
649
|
+
* decreases by 1) automatically. Whenever a `.ref` is created, the reference
|
|
650
|
+
* count increases.
|
|
651
|
+
*
|
|
652
|
+
* You generally don't need to call this function directly since arrays are
|
|
653
|
+
* automatically disposed after being passed into an operation. One common
|
|
654
|
+
* exception is when writing a function and ignoring one of its arguments. In
|
|
655
|
+
* that case, by convention you should dispose of that argument manually.
|
|
656
|
+
*
|
|
657
|
+
* ```
|
|
658
|
+
* function myCustomOperation(a: np.Array, b: np.Array) {
|
|
659
|
+
* b.dispose(); // Needed to satisfy "move" rules.
|
|
660
|
+
* return a.add(1);
|
|
661
|
+
* }
|
|
662
|
+
* ```
|
|
663
|
+
*/
|
|
664
|
+
abstract dispose(): void;
|
|
665
|
+
get shape(): number[];
|
|
666
|
+
get size(): number;
|
|
667
|
+
get dtype(): DType;
|
|
668
|
+
get ndim(): number;
|
|
669
|
+
/** @ignore */
|
|
670
|
+
fullLower(): Tracer;
|
|
671
|
+
neg(): this;
|
|
672
|
+
add(other: this | TracerValue): this;
|
|
673
|
+
mul(other: this | TracerValue): this;
|
|
674
|
+
greater(other: this | TracerValue): this;
|
|
675
|
+
less(other: this | TracerValue): this;
|
|
676
|
+
equal(other: this | TracerValue): this;
|
|
677
|
+
notEqual(other: this | TracerValue): this;
|
|
678
|
+
greaterEqual(other: this | TracerValue): this;
|
|
679
|
+
lessEqual(other: this | TracerValue): this;
|
|
680
|
+
/** Sum of the elements of the array over a given axis, or axes. */
|
|
681
|
+
sum(axis?: number | number[], opts?: ReduceOpts): this;
|
|
682
|
+
/** Product of the array elements over a given axis. */
|
|
683
|
+
prod(axis?: number | number[], opts?: ReduceOpts): this;
|
|
684
|
+
/** Compute the average of the array elements along the specified axis. */
|
|
685
|
+
mean(axis?: number | number[], opts?: ReduceOpts): this;
|
|
686
|
+
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
687
|
+
transpose(perm?: number[]): this;
|
|
688
|
+
/**
|
|
689
|
+
* Give a new shape to an array without changing its data.
|
|
690
|
+
*
|
|
691
|
+
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
692
|
+
* length of the array and remaining dimensions.
|
|
693
|
+
*/
|
|
694
|
+
reshape(shape: number | number[]): this;
|
|
695
|
+
/** Copy the array and cast to a specified dtype. */
|
|
696
|
+
astype(dtype: DType): this;
|
|
697
|
+
/** Subtract an array from this one. */
|
|
698
|
+
sub(other: this | TracerValue): this;
|
|
699
|
+
/** Divide an array by this one. */
|
|
700
|
+
div(other: this | TracerValue): this;
|
|
701
|
+
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
702
|
+
diagonal(offset?: number, axis1?: number, axis2?: number): this;
|
|
703
|
+
/** Flatten the array without changing its data. */
|
|
704
|
+
flatten(): this;
|
|
705
|
+
/** Flatten the array without changing its data. */
|
|
706
|
+
ravel(): this;
|
|
707
|
+
/**
|
|
708
|
+
* Iterate over the first dimension of this array, returning slices.
|
|
709
|
+
*
|
|
710
|
+
* This can be used to destructure arrays. For example:
|
|
711
|
+
*
|
|
712
|
+
* ```js
|
|
713
|
+
* let x = np.array([[1, 2], [3, 4]]);
|
|
714
|
+
* let [a, b] = x;
|
|
715
|
+
* console.log(a.js()); // [1, 2]
|
|
716
|
+
* console.log(b.js()); // [3, 4]
|
|
717
|
+
* ```
|
|
718
|
+
*/
|
|
719
|
+
[Symbol.iterator](): IterableIterator<this>;
|
|
720
|
+
/**
|
|
721
|
+
* Slice an array along one or more axes.
|
|
722
|
+
*
|
|
723
|
+
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
724
|
+
* mimic this in JavaScript, we would write:
|
|
725
|
+
*
|
|
726
|
+
* ```js
|
|
727
|
+
* x.slice([1, 3], 2, [], null);
|
|
728
|
+
* ```
|
|
729
|
+
*
|
|
730
|
+
* The `slice` method accepts a variable number of arguments, each of which
|
|
731
|
+
* can be a number, an empty array, a single-element array, a two-element
|
|
732
|
+
* array, or `null`. The arguments are interpreted as follows:
|
|
733
|
+
*
|
|
734
|
+
* - A number `n` means to access the `n`-th element along that axis, removing
|
|
735
|
+
* that axis from the resulting shape.
|
|
736
|
+
* - An empty array `[]` means to keep that axis as-is, like `:` in Python.
|
|
737
|
+
* - A single-element array `[i]` means to start slicing from index `i`
|
|
738
|
+
* (inclusive) to the end of the axis, like `x[i:]`.
|
|
739
|
+
* - A two-element array `[i, j]` means to slice from index `i` (inclusive)
|
|
740
|
+
* to index `j` (exclusive), like `x[i:j]`.
|
|
741
|
+
* - `null` means to add a new axis at that position, like `np.newaxis`.
|
|
742
|
+
*
|
|
743
|
+
* Like in Python, negative indices are supported, which count from the end of
|
|
744
|
+
* the axis. For example, `-1` means the last element.
|
|
745
|
+
*
|
|
746
|
+
* Strided slices are not yet implemented, so you cannot write `x[::2]` or
|
|
747
|
+
* similar.
|
|
748
|
+
*
|
|
749
|
+
* Advanced indexing by integer arrays is also supported. This translates to
|
|
750
|
+
* the "gather" primitive, and it allows you to access specific elements of
|
|
751
|
+
* the array by integer indices stored in another array.
|
|
752
|
+
*/
|
|
753
|
+
slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
|
|
527
754
|
}
|
|
528
755
|
declare class ShapedArray implements AbstractValue {
|
|
529
|
-
|
|
530
|
-
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
535
|
-
|
|
756
|
+
readonly shape: number[];
|
|
757
|
+
readonly dtype: DType;
|
|
758
|
+
constructor(shape: number[], dtype: DType);
|
|
759
|
+
static fromAval(aval: AbstractValue): ShapedArray;
|
|
760
|
+
get ndim(): number;
|
|
761
|
+
toString(): string;
|
|
762
|
+
equals(other: ShapedArray): boolean;
|
|
536
763
|
}
|
|
537
|
-
|
|
764
|
+
//#endregion
|
|
765
|
+
//#region src/frontend/array.d.ts
|
|
538
766
|
type ArrayLike = Array | number | boolean;
|
|
767
|
+
/** Version of pureArray with fudged types. */
|
|
768
|
+
|
|
539
769
|
/**
|
|
540
770
|
* An executable operation that will be dispatched to the backend.
|
|
541
771
|
*
|
|
@@ -543,22 +773,23 @@ type ArrayLike = Array | number | boolean;
|
|
|
543
773
|
* operation is dispatched, the references should be released.
|
|
544
774
|
*/
|
|
545
775
|
declare class PendingExecute {
|
|
546
|
-
|
|
547
|
-
|
|
548
|
-
|
|
549
|
-
|
|
550
|
-
|
|
551
|
-
|
|
552
|
-
|
|
553
|
-
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
776
|
+
#private;
|
|
777
|
+
readonly backend: Backend;
|
|
778
|
+
readonly kernel: Kernel;
|
|
779
|
+
readonly inputs: Slot[];
|
|
780
|
+
readonly outputs: Slot[];
|
|
781
|
+
prepared: Executable | null;
|
|
782
|
+
submitted: boolean;
|
|
783
|
+
constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
|
|
784
|
+
updateRc(delta: number): void;
|
|
785
|
+
prepare(): Promise<void>;
|
|
786
|
+
prepareSync(): void;
|
|
787
|
+
submit(): void;
|
|
558
788
|
}
|
|
789
|
+
/** @inline */
|
|
559
790
|
type DTypeAndDevice = {
|
|
560
|
-
|
|
561
|
-
|
|
791
|
+
dtype?: DType;
|
|
792
|
+
device?: Device;
|
|
562
793
|
};
|
|
563
794
|
/**
|
|
564
795
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
@@ -571,64 +802,120 @@ type DTypeAndDevice = {
|
|
|
571
802
|
* "Array" type by name.
|
|
572
803
|
*/
|
|
573
804
|
declare class Array extends Tracer {
|
|
574
|
-
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
805
|
+
#private;
|
|
806
|
+
id: number;
|
|
807
|
+
/**
|
|
808
|
+
* @ignore
|
|
809
|
+
* Constructs an array from source, shape and backend. Note that if the source
|
|
810
|
+
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
811
|
+
* will be freed when the array is disposed.
|
|
812
|
+
*/
|
|
813
|
+
constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, pending?: Iterable<PendingExecute> | null);
|
|
814
|
+
/** @ignore */
|
|
815
|
+
get aval(): ShapedArray;
|
|
816
|
+
/** Return a simple string representation of the array's dimensions. */
|
|
817
|
+
toString(): string;
|
|
818
|
+
get device(): Device;
|
|
819
|
+
get ref(): this;
|
|
820
|
+
dispose(): void;
|
|
821
|
+
/**
|
|
822
|
+
* Convert this array into a primitive value.
|
|
823
|
+
*
|
|
824
|
+
* This only works for scalars (0-dimensional arrays). It lets you get values
|
|
825
|
+
* "out" of the JAX system. For instance, if `x = np.array(5)`, then you can
|
|
826
|
+
* evaluate `x + 1` and `x ** 2` to get `6` and `25`, respectively.
|
|
827
|
+
*
|
|
828
|
+
* This method is also called for `==` equality.
|
|
829
|
+
*/
|
|
830
|
+
[Symbol.toPrimitive](): any;
|
|
831
|
+
/** Realize the array and return it as data. */
|
|
832
|
+
data(): Promise<DataArray>;
|
|
833
|
+
/**
|
|
834
|
+
* Wait for this array to finish evaluation.
|
|
835
|
+
*
|
|
836
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
837
|
+
* that pending operations are dispatched and fully executed before it
|
|
838
|
+
* returns.
|
|
839
|
+
*
|
|
840
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
841
|
+
* dispatch of operations as well.
|
|
842
|
+
*/
|
|
843
|
+
wait(): Promise<Array>;
|
|
844
|
+
/**
|
|
845
|
+
* Realize the array and return it as data. This is a sync variant and not
|
|
846
|
+
* recommended for performance reasons, as it will block rendering.
|
|
847
|
+
*/
|
|
848
|
+
dataSync(): DataArray;
|
|
849
|
+
/**
|
|
850
|
+
* Convert this array into a JavaScript object.
|
|
851
|
+
*
|
|
852
|
+
* This is a blocking operation that will compile all of the shaders and wait
|
|
853
|
+
* for execution to complete, synchronously. No other JavaScript code on the
|
|
854
|
+
* site will be run during shader execution.
|
|
855
|
+
*
|
|
856
|
+
* To avoid blocking, prefer `jsAsync()` when possible.
|
|
857
|
+
*/
|
|
858
|
+
js(): any;
|
|
859
|
+
/** Convert this array into a JavaScript object, asynchronously. */
|
|
860
|
+
jsAsync(): Promise<any>;
|
|
861
|
+
/**
|
|
862
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
863
|
+
*
|
|
864
|
+
* Throws an error if the array does not have a single element. The array must
|
|
865
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
866
|
+
*/
|
|
867
|
+
item(): number;
|
|
868
|
+
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
869
|
+
static _implRules(): typeof implRules;
|
|
870
|
+
_realizeSource(): number;
|
|
609
871
|
}
|
|
610
872
|
/** Construct an array from a single scalar constant. */
|
|
611
|
-
declare function scalar(value: number | boolean, {
|
|
873
|
+
declare function scalar(value: number | boolean, {
|
|
874
|
+
dtype,
|
|
875
|
+
device
|
|
876
|
+
}?: DTypeAndDevice): Array;
|
|
612
877
|
/** Constructor for creating a new array from data. */
|
|
613
|
-
declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
614
|
-
|
|
878
|
+
declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
|
|
879
|
+
shape,
|
|
880
|
+
dtype,
|
|
881
|
+
device
|
|
882
|
+
}?: {
|
|
883
|
+
shape?: number[];
|
|
615
884
|
} & DTypeAndDevice): Array;
|
|
616
|
-
|
|
885
|
+
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
886
|
+
|
|
887
|
+
type ImplRule<P extends Primitive> = (tracers: Array[], params: PrimitiveParams<P>) => Array[];
|
|
888
|
+
declare const implRules: { [P in Primitive]: ImplRule<P> };
|
|
617
889
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
618
|
-
declare function zeros(shape: number[], {
|
|
890
|
+
declare function zeros(shape: number[], {
|
|
891
|
+
dtype,
|
|
892
|
+
device
|
|
893
|
+
}?: DTypeAndDevice): Array;
|
|
619
894
|
/** Return a new array of given shape and type, filled with ones. */
|
|
620
|
-
declare function ones(shape: number[], {
|
|
895
|
+
declare function ones(shape: number[], {
|
|
896
|
+
dtype,
|
|
897
|
+
device
|
|
898
|
+
}?: DTypeAndDevice): Array;
|
|
621
899
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
622
|
-
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
900
|
+
declare function full(shape: number[], fillValue: number | boolean | Array, {
|
|
901
|
+
dtype,
|
|
902
|
+
device
|
|
903
|
+
}?: DTypeAndDevice): Array;
|
|
623
904
|
/**
|
|
624
905
|
* Create an identity matrix.
|
|
625
906
|
*
|
|
626
907
|
* If numCols is not provided, it defaults to numRows, i.e., a square identity
|
|
627
908
|
* matrix with ones on the diagonal.
|
|
628
909
|
*/
|
|
629
|
-
declare function eye(numRows: number, numCols?: number, {
|
|
910
|
+
declare function eye(numRows: number, numCols?: number, {
|
|
911
|
+
dtype,
|
|
912
|
+
device
|
|
913
|
+
}?: DTypeAndDevice): Array;
|
|
630
914
|
/** Return the identity array, with ones on the main diagonal. */
|
|
631
|
-
declare function identity$1(n: number, {
|
|
915
|
+
declare function identity$1(n: number, {
|
|
916
|
+
dtype,
|
|
917
|
+
device
|
|
918
|
+
}?: DTypeAndDevice): Array;
|
|
632
919
|
/**
|
|
633
920
|
* Return evenly spaced values within a given interval.
|
|
634
921
|
*
|
|
@@ -643,7 +930,10 @@ declare function identity$1(n: number, { dtype, device }?: DTypeAndDevice): Arra
|
|
|
643
930
|
* Defaults to an integer data type. This can produce unintended results when
|
|
644
931
|
* using a non-integer step, so prefer linspace() in those cases.
|
|
645
932
|
*/
|
|
646
|
-
declare function arange(start: number, stop?: number, step?: number, {
|
|
933
|
+
declare function arange(start: number, stop?: number, step?: number, {
|
|
934
|
+
dtype,
|
|
935
|
+
device
|
|
936
|
+
}?: DTypeAndDevice): Array;
|
|
647
937
|
/**
|
|
648
938
|
* Return evenly spaced numbers over a specified interval.
|
|
649
939
|
*
|
|
@@ -653,12 +943,18 @@ declare function arange(start: number, stop?: number, step?: number, { dtype, de
|
|
|
653
943
|
*
|
|
654
944
|
* The default data type is Float32. Use arange() for integer steps.
|
|
655
945
|
*/
|
|
656
|
-
declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
|
|
657
|
-
|
|
946
|
+
declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, {
|
|
947
|
+
dtype,
|
|
948
|
+
device
|
|
949
|
+
}?: DTypeAndDevice): Array;
|
|
950
|
+
declare namespace numpy_d_exports {
|
|
951
|
+
export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
|
|
952
|
+
}
|
|
658
953
|
declare const float32 = DType.Float32;
|
|
659
954
|
declare const int32 = DType.Int32;
|
|
955
|
+
declare const uint32 = DType.Uint32;
|
|
660
956
|
declare const bool = DType.Bool;
|
|
661
|
-
declare const
|
|
957
|
+
declare const float16 = DType.Float16;
|
|
662
958
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
663
959
|
declare const e: number;
|
|
664
960
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -685,6 +981,8 @@ declare const cos: (x: ArrayLike) => Array;
|
|
|
685
981
|
declare const exp: (x: ArrayLike) => Array;
|
|
686
982
|
/** Calculate the natural logarithm of all elements in the input array. */
|
|
687
983
|
declare const log: (x: ArrayLike) => Array;
|
|
984
|
+
/** Calculate the square root of all elements in the input array. */
|
|
985
|
+
declare const sqrt: (x: ArrayLike) => Array;
|
|
688
986
|
/** Return element-wise minimum of the input arrays. */
|
|
689
987
|
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
|
|
690
988
|
/** Return element-wise maximum of the input arrays. */
|
|
@@ -712,16 +1010,98 @@ declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
|
|
|
712
1010
|
* length of the array and remaining dimensions.
|
|
713
1011
|
*/
|
|
714
1012
|
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
|
|
715
|
-
|
|
1013
|
+
/** Move axes of an array to new positions. Other axes retain original order. */
|
|
716
1014
|
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
|
|
717
|
-
/**
|
|
1015
|
+
/**
|
|
1016
|
+
* Add padding (zeros) to an array.
|
|
1017
|
+
*
|
|
1018
|
+
* The `width` argument is either an integer or pair of integers, in which case
|
|
1019
|
+
* all axes are padded with the same width. Or if it is an array of pairs, each
|
|
1020
|
+
* pair specifies the padding for its corresponding axis.
|
|
1021
|
+
*/
|
|
1022
|
+
declare const pad: (x: ArrayLike, width: number | [number, number] | [number, number][]) => Array;
|
|
1023
|
+
/** Return the number of dimensions of an array. Does not consume array reference. */
|
|
718
1024
|
declare const ndim: (x: ArrayLike) => number;
|
|
719
|
-
/** Return the shape of an array. */
|
|
720
|
-
declare const shape: (x: ArrayLike) => number[];
|
|
721
|
-
/** Return
|
|
1025
|
+
/** Return the shape of an array. Does not consume array reference. */
|
|
1026
|
+
declare const shape$1: (x: ArrayLike) => number[];
|
|
1027
|
+
/** Return an array of zeros with the same shape and type as a given array. */
|
|
1028
|
+
declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1029
|
+
/** Return an array of ones with the same shape and type as a given array. */
|
|
1030
|
+
declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
|
|
1031
|
+
/** Return a full array with the same shape and type as a given array. */
|
|
1032
|
+
declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
|
|
1033
|
+
/**
|
|
1034
|
+
* Return the number of elements in an array, optionally along an axis.
|
|
1035
|
+
* Does not consume array reference.
|
|
1036
|
+
*/
|
|
722
1037
|
declare function size(a: ArrayLike, axis?: number): number;
|
|
1038
|
+
/** Convert an array to a specified dtype. */
|
|
1039
|
+
declare function astype(a: ArrayLike, dtype: DType): Array;
|
|
1040
|
+
/** Sum of the elements of the array over a given axis, or axes. */
|
|
1041
|
+
declare function sum(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
|
|
1042
|
+
/** Product of the array elements over a given axis. */
|
|
1043
|
+
declare function prod(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
|
|
1044
|
+
/** Return the minimum of array elements along a given axis. */
|
|
1045
|
+
declare function min(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
|
|
1046
|
+
/** Return the maximum of array elements along a given axis. */
|
|
1047
|
+
declare function max(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
|
|
1048
|
+
/** Compute the average of the array elements along the specified axis. */
|
|
1049
|
+
declare function mean(a: ArrayLike, axis?: number | number[], opts?: ReduceOpts): Array;
|
|
1050
|
+
/**
|
|
1051
|
+
* Returns the indices of the minimum values along an axis.
|
|
1052
|
+
*
|
|
1053
|
+
* By default, index is into the flatted array, otherwise it is along the
|
|
1054
|
+
* specified axis.
|
|
1055
|
+
*/
|
|
1056
|
+
declare function argmin(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
1057
|
+
/**
|
|
1058
|
+
* Returns the indices of the maximum values along an axis.
|
|
1059
|
+
*
|
|
1060
|
+
* By default, index is into the flatted array, otherwise it is along the
|
|
1061
|
+
* specified axis.
|
|
1062
|
+
*/
|
|
1063
|
+
declare function argmax(a: ArrayLike, axis?: number, opts?: ReduceOpts): Array;
|
|
723
1064
|
/** Reverse the elements in an array along the given axes. */
|
|
724
1065
|
declare function flip(x: ArrayLike, axis?: number | number[]): Array;
|
|
1066
|
+
/**
|
|
1067
|
+
* Join a sequence of arrays along an existing axis.
|
|
1068
|
+
*
|
|
1069
|
+
* The arrays must have the same shape, except in the dimension corresponding to
|
|
1070
|
+
* `axis` (the first, by default).
|
|
1071
|
+
*
|
|
1072
|
+
* No scalars can be passed to this function, as the axis is then ambiguous.
|
|
1073
|
+
*/
|
|
1074
|
+
declare function concatenate(xs: Array[], axis?: number): Array;
|
|
1075
|
+
/**
|
|
1076
|
+
* Join a sequence of arrays along a new axis.
|
|
1077
|
+
*
|
|
1078
|
+
* The `axis` parameter specifies the index of the new axis in the dimensions of
|
|
1079
|
+
* the result. For example, if `axis=0` it will be the first dimension and if
|
|
1080
|
+
* `axis=-1` it will be the last dimension.
|
|
1081
|
+
*
|
|
1082
|
+
* All shapes must have the same shape.
|
|
1083
|
+
*/
|
|
1084
|
+
declare function stack(xs: ArrayLike[], axis?: number): Array;
|
|
1085
|
+
/**
|
|
1086
|
+
* Horizontally stack arrays. Inputs are promoted to rank at least 1, then
|
|
1087
|
+
* concatenated along axis 1 (if rank-2 or higher) or 0 (if rank-1).
|
|
1088
|
+
*/
|
|
1089
|
+
declare function hstack(xs: ArrayLike[]): Array;
|
|
1090
|
+
/**
|
|
1091
|
+
* Vertically stack arrays. Inputs are promoted to rank at least 2, then
|
|
1092
|
+
* concatenated along axis 0.
|
|
1093
|
+
*/
|
|
1094
|
+
declare function vstack(xs: ArrayLike[]): Array;
|
|
1095
|
+
/**
|
|
1096
|
+
* Stack arrays depth-wise. Inputs are promoted to rank at least 3, then
|
|
1097
|
+
* concatenated along axis 2.
|
|
1098
|
+
*/
|
|
1099
|
+
declare function dstack(xs: ArrayLike[]): Array;
|
|
1100
|
+
/**
|
|
1101
|
+
* Stack arrays column-wise. Inputs are promoted to rank at least 2, then
|
|
1102
|
+
* concatenated along axis 1.
|
|
1103
|
+
*/
|
|
1104
|
+
declare function columnStack(xs: ArrayLike[]): Array;
|
|
725
1105
|
/** Flip an array vertically (axis=0). */
|
|
726
1106
|
declare function flipud(x: ArrayLike): Array;
|
|
727
1107
|
/** Flip an array horizontally (axis=1). */
|
|
@@ -733,13 +1113,13 @@ declare function ravel(a: ArrayLike): Array;
|
|
|
733
1113
|
* Return specified diagonals.
|
|
734
1114
|
*
|
|
735
1115
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
736
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
1116
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
737
1117
|
*
|
|
738
|
-
* This returns a view over the existing array.
|
|
1118
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
1119
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
1120
|
+
* then appending a new axis to the right with holding the diagonals.
|
|
739
1121
|
*/
|
|
740
1122
|
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
|
|
741
|
-
/** Transposes a matrix or stack of matrices `x` (swap last two axes). */
|
|
742
|
-
declare function matrixTranspose(x: ArrayLike): Array;
|
|
743
1123
|
/**
|
|
744
1124
|
* Extract a diagonal or construct a diagonal array.
|
|
745
1125
|
*
|
|
@@ -749,15 +1129,15 @@ declare function matrixTranspose(x: ArrayLike): Array;
|
|
|
749
1129
|
declare function diag(v: ArrayLike, k?: number): Array;
|
|
750
1130
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
751
1131
|
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
|
|
752
|
-
|
|
753
|
-
|
|
1132
|
+
rtol?: number;
|
|
1133
|
+
atol?: number;
|
|
754
1134
|
}): boolean;
|
|
755
1135
|
/** Matrix product of two arrays. */
|
|
756
|
-
declare
|
|
1136
|
+
declare function matmul(x: ArrayLike, y: ArrayLike): Array;
|
|
757
1137
|
/** Dot product of two arrays. */
|
|
758
|
-
declare
|
|
1138
|
+
declare function dot(x: ArrayLike, y: ArrayLike): Array;
|
|
759
1139
|
/** Vector dot product of two arrays. */
|
|
760
|
-
declare
|
|
1140
|
+
declare function vecdot(x: ArrayLike, y: ArrayLike): Array;
|
|
761
1141
|
/**
|
|
762
1142
|
* Return the dot product of two vectors.
|
|
763
1143
|
*
|
|
@@ -770,8 +1150,10 @@ declare function vdot(x: ArrayLike, y: ArrayLike): Array;
|
|
|
770
1150
|
* Make N-D coordinate arrays for vectorized evaluations of N-D scalar/vector
|
|
771
1151
|
* fields over N-D grids, given one-dimensional coordinate arrays x1, x2,…, xn.
|
|
772
1152
|
*/
|
|
773
|
-
declare function meshgrid(xs: Array[], {
|
|
774
|
-
|
|
1153
|
+
declare function meshgrid(xs: Array[], {
|
|
1154
|
+
indexing
|
|
1155
|
+
}?: {
|
|
1156
|
+
indexing?: "xy" | "ij";
|
|
775
1157
|
}): Array[];
|
|
776
1158
|
/**
|
|
777
1159
|
* Clip (limit) the values in an array.
|
|
@@ -807,161 +1189,119 @@ declare function exp2(p: ArrayLike): Array;
|
|
|
807
1189
|
declare function log2(x: ArrayLike): Array;
|
|
808
1190
|
/** Return the base-10 logarithm of x, element-wise. */
|
|
809
1191
|
declare function log10(x: ArrayLike): Array;
|
|
810
|
-
|
|
811
|
-
|
|
812
|
-
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
declare
|
|
816
|
-
|
|
817
|
-
|
|
818
|
-
|
|
819
|
-
|
|
820
|
-
|
|
821
|
-
declare
|
|
822
|
-
|
|
823
|
-
|
|
824
|
-
|
|
825
|
-
|
|
826
|
-
|
|
827
|
-
declare
|
|
828
|
-
|
|
829
|
-
|
|
830
|
-
declare const numpy_e: typeof e;
|
|
831
|
-
declare const numpy_equal: typeof equal;
|
|
832
|
-
declare const numpy_eulerGamma: typeof eulerGamma;
|
|
833
|
-
declare const numpy_exp: typeof exp;
|
|
834
|
-
declare const numpy_exp2: typeof exp2;
|
|
835
|
-
declare const numpy_eye: typeof eye;
|
|
836
|
-
declare const numpy_flip: typeof flip;
|
|
837
|
-
declare const numpy_fliplr: typeof fliplr;
|
|
838
|
-
declare const numpy_flipud: typeof flipud;
|
|
839
|
-
declare const numpy_float32: typeof float32;
|
|
840
|
-
declare const numpy_full: typeof full;
|
|
841
|
-
declare const numpy_greater: typeof greater;
|
|
842
|
-
declare const numpy_greaterEqual: typeof greaterEqual;
|
|
843
|
-
declare const numpy_inf: typeof inf;
|
|
844
|
-
declare const numpy_int32: typeof int32;
|
|
845
|
-
declare const numpy_less: typeof less;
|
|
846
|
-
declare const numpy_lessEqual: typeof lessEqual;
|
|
847
|
-
declare const numpy_linspace: typeof linspace;
|
|
848
|
-
declare const numpy_log: typeof log;
|
|
849
|
-
declare const numpy_log10: typeof log10;
|
|
850
|
-
declare const numpy_log2: typeof log2;
|
|
851
|
-
declare const numpy_matmul: typeof matmul;
|
|
852
|
-
declare const numpy_matrixTranspose: typeof matrixTranspose;
|
|
853
|
-
declare const numpy_maximum: typeof maximum;
|
|
854
|
-
declare const numpy_meshgrid: typeof meshgrid;
|
|
855
|
-
declare const numpy_minimum: typeof minimum;
|
|
856
|
-
declare const numpy_moveaxis: typeof moveaxis;
|
|
857
|
-
declare const numpy_multiply: typeof multiply;
|
|
858
|
-
declare const numpy_nan: typeof nan;
|
|
859
|
-
declare const numpy_ndim: typeof ndim;
|
|
860
|
-
declare const numpy_negative: typeof negative;
|
|
861
|
-
declare const numpy_notEqual: typeof notEqual;
|
|
862
|
-
declare const numpy_ones: typeof ones;
|
|
863
|
-
declare const numpy_permuteDims: typeof permuteDims;
|
|
864
|
-
declare const numpy_pi: typeof pi;
|
|
865
|
-
declare const numpy_ravel: typeof ravel;
|
|
866
|
-
declare const numpy_reciprocal: typeof reciprocal;
|
|
867
|
-
declare const numpy_reshape: typeof reshape;
|
|
868
|
-
declare const numpy_scalar: typeof scalar;
|
|
869
|
-
declare const numpy_shape: typeof shape;
|
|
870
|
-
declare const numpy_sin: typeof sin;
|
|
871
|
-
declare const numpy_size: typeof size;
|
|
872
|
-
declare const numpy_square: typeof square;
|
|
873
|
-
declare const numpy_sum: typeof sum;
|
|
874
|
-
declare const numpy_tan: typeof tan;
|
|
875
|
-
declare const numpy_transpose: typeof transpose;
|
|
876
|
-
declare const numpy_trueDivide: typeof trueDivide;
|
|
877
|
-
declare const numpy_trunc: typeof trunc;
|
|
878
|
-
declare const numpy_vdot: typeof vdot;
|
|
879
|
-
declare const numpy_vecdot: typeof vecdot;
|
|
880
|
-
declare const numpy_where: typeof where;
|
|
881
|
-
declare const numpy_zeros: typeof zeros;
|
|
882
|
-
declare namespace numpy {
|
|
883
|
-
export { numpy_Array as Array, type numpy_ArrayLike as ArrayLike, numpy_DType as DType, numpy_abs as abs, numpy_absolute as absolute, numpy_add as add, numpy_allclose as allclose, numpy_arange as arange, numpy_array as array, numpy_bool as bool, numpy_clip as clip, numpy_complex64 as complex64, numpy_cos as cos, numpy_diag as diag, numpy_diagonal as diagonal, numpy_divide as divide, numpy_dot as dot, numpy_e as e, numpy_equal as equal, numpy_eulerGamma as eulerGamma, numpy_exp as exp, numpy_exp2 as exp2, numpy_eye as eye, numpy_flip as flip, numpy_fliplr as fliplr, numpy_flipud as flipud, numpy_float32 as float32, numpy_full as full, numpy_greater as greater, numpy_greaterEqual as greaterEqual, identity$1 as identity, numpy_inf as inf, numpy_int32 as int32, numpy_less as less, numpy_lessEqual as lessEqual, numpy_linspace as linspace, numpy_log as log, numpy_log10 as log10, numpy_log2 as log2, numpy_matmul as matmul, numpy_matrixTranspose as matrixTranspose, numpy_maximum as maximum, numpy_meshgrid as meshgrid, numpy_minimum as minimum, numpy_moveaxis as moveaxis, numpy_multiply as multiply, numpy_nan as nan, numpy_ndim as ndim, numpy_negative as negative, numpy_notEqual as notEqual, numpy_ones as ones, numpy_permuteDims as permuteDims, numpy_pi as pi, numpy_ravel as ravel, numpy_reciprocal as reciprocal, numpy_reshape as reshape, numpy_scalar as scalar, numpy_shape as shape, numpy_sin as sin, numpy_size as size, numpy_square as square, numpy_sum as sum, numpy_tan as tan, numpy_transpose as transpose, numpy_trueDivide as trueDivide, numpy_trunc as trunc, numpy_vdot as vdot, numpy_vecdot as vecdot, numpy_where as where, numpy_zeros as zeros };
|
|
884
|
-
}
|
|
885
|
-
|
|
886
|
-
/** General class for pretty-printing expressions with indentation. */
|
|
887
|
-
declare class PPrint {
|
|
888
|
-
readonly indents: number[];
|
|
889
|
-
readonly lines: string[];
|
|
890
|
-
constructor(indents: number[], lines: string[]);
|
|
891
|
-
/** Add a fixed amount of indentation to each line. */
|
|
892
|
-
indent(spaces: number): PPrint;
|
|
893
|
-
/** Concatenate two or more pretty-printed expressions. */
|
|
894
|
-
concat(...items: PPrint[]): PPrint;
|
|
895
|
-
/** Stack one block to the right of another one, sharing 1 common line. */
|
|
896
|
-
stack(other: PPrint): PPrint;
|
|
897
|
-
/** Combine this block of lines into a formatted string. */
|
|
898
|
-
toString(): string;
|
|
899
|
-
static pp(s: Stringable): PPrint;
|
|
900
|
-
}
|
|
901
|
-
interface Stringable {
|
|
902
|
-
toString(): string;
|
|
903
|
-
}
|
|
904
|
-
|
|
1192
|
+
/**
|
|
1193
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
1194
|
+
*
|
|
1195
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
1196
|
+
*/
|
|
1197
|
+
declare function sinh(x: ArrayLike): Array;
|
|
1198
|
+
/**
|
|
1199
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
1200
|
+
*
|
|
1201
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
1202
|
+
*/
|
|
1203
|
+
declare function cosh(x: ArrayLike): Array;
|
|
1204
|
+
/**
|
|
1205
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
1206
|
+
*
|
|
1207
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
1208
|
+
*/
|
|
1209
|
+
declare function tanh(x: ArrayLike): Array;
|
|
1210
|
+
//#endregion
|
|
1211
|
+
//#region src/frontend/jaxpr.d.ts
|
|
905
1212
|
/** Variable in a Jaxpr expression. */
|
|
906
1213
|
declare class Var {
|
|
907
|
-
|
|
908
|
-
|
|
909
|
-
|
|
910
|
-
|
|
911
|
-
|
|
1214
|
+
#private;
|
|
1215
|
+
readonly id: number;
|
|
1216
|
+
readonly aval: ShapedArray;
|
|
1217
|
+
constructor(aval: ShapedArray);
|
|
1218
|
+
toString(): string;
|
|
912
1219
|
}
|
|
913
1220
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
914
1221
|
declare class Lit {
|
|
915
|
-
|
|
916
|
-
|
|
917
|
-
|
|
918
|
-
|
|
1222
|
+
readonly dtype: DType;
|
|
1223
|
+
readonly value: number;
|
|
1224
|
+
readonly aval: ShapedArray;
|
|
1225
|
+
constructor(dtype: DType, value: number);
|
|
919
1226
|
}
|
|
920
1227
|
type Atom = Var | Lit;
|
|
921
1228
|
declare class VarPrinter {
|
|
922
|
-
|
|
923
|
-
|
|
924
|
-
|
|
925
|
-
|
|
1229
|
+
#private;
|
|
1230
|
+
names: Map<Var, string>;
|
|
1231
|
+
name(v: Var): string;
|
|
1232
|
+
nameType(v: Var): string;
|
|
926
1233
|
}
|
|
927
1234
|
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
928
1235
|
declare class JaxprEqn {
|
|
929
|
-
|
|
930
|
-
|
|
931
|
-
|
|
932
|
-
|
|
933
|
-
|
|
934
|
-
|
|
935
|
-
|
|
1236
|
+
readonly primitive: Primitive;
|
|
1237
|
+
readonly inputs: Atom[];
|
|
1238
|
+
readonly params: Record<string, any>;
|
|
1239
|
+
readonly outBinders: Var[];
|
|
1240
|
+
constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
|
|
1241
|
+
pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
|
|
1242
|
+
toString(): string;
|
|
936
1243
|
}
|
|
937
1244
|
/** Typed intermediate representation for traced computations. */
|
|
938
1245
|
declare class Jaxpr implements FpHashable {
|
|
939
|
-
|
|
940
|
-
|
|
941
|
-
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
|
|
947
|
-
|
|
948
|
-
|
|
949
|
-
|
|
950
|
-
|
|
951
|
-
|
|
952
|
-
|
|
953
|
-
|
|
954
|
-
|
|
955
|
-
|
|
956
|
-
|
|
957
|
-
|
|
958
|
-
|
|
959
|
-
|
|
960
|
-
|
|
961
|
-
|
|
962
|
-
|
|
1246
|
+
#private;
|
|
1247
|
+
readonly inBinders: Var[];
|
|
1248
|
+
readonly eqns: JaxprEqn[];
|
|
1249
|
+
readonly outs: Atom[];
|
|
1250
|
+
constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
|
|
1251
|
+
pprint(): PPrint;
|
|
1252
|
+
toString(): string;
|
|
1253
|
+
/**
|
|
1254
|
+
* Gets a hash of this Jaxpr.
|
|
1255
|
+
*
|
|
1256
|
+
* Var identity is not considered in the hash, so two Jaxprs with the same
|
|
1257
|
+
* order of assignments and operators but different variable IDs will resolve
|
|
1258
|
+
* to the same hash (and toString representation).
|
|
1259
|
+
*/
|
|
1260
|
+
getHash(): bigint;
|
|
1261
|
+
hash(state: FpHash): void;
|
|
1262
|
+
/**
|
|
1263
|
+
* Produce a simplified Jaxpr with basic optimizations applied.
|
|
1264
|
+
* - Trim away unused variables.
|
|
1265
|
+
* - Fold away *1, *0, or +0 operations against literals.
|
|
1266
|
+
* - Remove no-op movement operations.
|
|
1267
|
+
*/
|
|
1268
|
+
simplify(): Jaxpr;
|
|
1269
|
+
/** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1270
|
+
flatten(): Jaxpr;
|
|
1271
|
+
}
|
|
1272
|
+
/** @inline */
|
|
1273
|
+
type JitOpts = {
|
|
1274
|
+
staticArgnums?: number[];
|
|
1275
|
+
device?: Device;
|
|
1276
|
+
};
|
|
1277
|
+
declare namespace lax_d_exports {
|
|
1278
|
+
export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
|
|
1279
|
+
}
|
|
1280
|
+
type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
|
|
1281
|
+
/**
|
|
1282
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
1283
|
+
*
|
|
1284
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
1285
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
1286
|
+
*
|
|
1287
|
+
* Grouped convolutions are not supported right now.
|
|
1288
|
+
*/
|
|
1289
|
+
declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
|
|
1290
|
+
lhsDilation,
|
|
1291
|
+
rhsDilation
|
|
1292
|
+
}?: {
|
|
1293
|
+
lhsDilation?: number[];
|
|
1294
|
+
rhsDilation?: number[];
|
|
1295
|
+
}): Array;
|
|
1296
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1297
|
+
declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
|
|
1298
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
1299
|
+
declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
|
|
1300
|
+
/** Reduce a computation over padded windows. */
|
|
1301
|
+
declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
|
|
1302
|
+
declare namespace nn_d_exports {
|
|
1303
|
+
export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
|
|
963
1304
|
}
|
|
964
|
-
|
|
965
1305
|
/**
|
|
966
1306
|
* Rectified Linear Unit (ReLU) activation function:
|
|
967
1307
|
* `relu(x) = max(x, 0)`.
|
|
@@ -1000,7 +1340,7 @@ declare function softSign(x: ArrayLike): Array;
|
|
|
1000
1340
|
*
|
|
1001
1341
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1002
1342
|
*/
|
|
1003
|
-
declare
|
|
1343
|
+
declare const silu: (x: ArrayLike) => Array;
|
|
1004
1344
|
/**
|
|
1005
1345
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
1006
1346
|
* Swish, computed element-wise:
|
|
@@ -1010,7 +1350,7 @@ declare function silu(x: ArrayLike): Array;
|
|
|
1010
1350
|
*
|
|
1011
1351
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
1012
1352
|
*/
|
|
1013
|
-
declare const swish:
|
|
1353
|
+
declare const swish: (x: ArrayLike) => Array;
|
|
1014
1354
|
/**
|
|
1015
1355
|
* Log-sigmoid activation function, computed element-wise:
|
|
1016
1356
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
@@ -1018,49 +1358,154 @@ declare const swish: typeof silu;
|
|
|
1018
1358
|
declare function logSigmoid(x: ArrayLike): Array;
|
|
1019
1359
|
/** Identity activation function. Returns the argument unmodified. */
|
|
1020
1360
|
declare const identity: (x: ArrayLike) => Array;
|
|
1021
|
-
|
|
1022
|
-
declare
|
|
1023
|
-
|
|
1024
|
-
|
|
1025
|
-
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
|
|
1029
|
-
declare
|
|
1030
|
-
|
|
1031
|
-
|
|
1032
|
-
|
|
1361
|
+
/** Leaky rectified linear (ReLU) activation function */
|
|
1362
|
+
declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
|
|
1363
|
+
/**
|
|
1364
|
+
* Exponential linear unit activation function.
|
|
1365
|
+
*
|
|
1366
|
+
* Computes the element-wise function:
|
|
1367
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
1368
|
+
*/
|
|
1369
|
+
declare function elu(x: ArrayLike, alpha?: number): Array;
|
|
1370
|
+
/**
|
|
1371
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
1372
|
+
*
|
|
1373
|
+
* Computes the element-wise function:
|
|
1374
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
1375
|
+
*/
|
|
1376
|
+
declare function celu(x: ArrayLike, alpha?: number): Array;
|
|
1377
|
+
/**
|
|
1378
|
+
* Gaussion error linear unit (GELU) activation function.
|
|
1379
|
+
*
|
|
1380
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
1381
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
1382
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
1383
|
+
*
|
|
1384
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
1385
|
+
*
|
|
1386
|
+
* This will be improved in the future.
|
|
1387
|
+
*/
|
|
1388
|
+
declare const gelu: (x: ArrayLike) => Array;
|
|
1389
|
+
/**
|
|
1390
|
+
* Gated linear unit (GLU) activation function.
|
|
1391
|
+
*
|
|
1392
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
1393
|
+
* computes `a * sigmoid(b)`.
|
|
1394
|
+
*/
|
|
1395
|
+
declare function glu(x: ArrayLike, axis?: number): Array;
|
|
1396
|
+
/**
|
|
1397
|
+
* Mish activation function.
|
|
1398
|
+
*
|
|
1399
|
+
* Computes the element-wise function:
|
|
1400
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
1401
|
+
*/
|
|
1402
|
+
declare function mish(x: ArrayLike): Array;
|
|
1403
|
+
/**
|
|
1404
|
+
* Softmax function. Computes the function which rescales elements to the range
|
|
1405
|
+
* [0, 1] such that the elements along `axis` sum to 1.
|
|
1406
|
+
*
|
|
1407
|
+
* If `axis` is not specified, it defaults to the last axis.
|
|
1408
|
+
*
|
|
1409
|
+
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
1410
|
+
*/
|
|
1411
|
+
declare function softmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
1412
|
+
/**
|
|
1413
|
+
* Log-Softmax function.
|
|
1414
|
+
*
|
|
1415
|
+
* Computes the logarithm of the `softmax` function, which rescales elements to
|
|
1416
|
+
* the range [-infinity, 0).
|
|
1417
|
+
*
|
|
1418
|
+
* If `axis` is not specified, it defaults to the last axis.
|
|
1419
|
+
*/
|
|
1420
|
+
declare function logSoftmax(x: ArrayLike, axis?: number | number[]): Array;
|
|
1421
|
+
/**
|
|
1422
|
+
* Log-sum-exp reduction. Also a multivariate version of `softplus`.
|
|
1423
|
+
*
|
|
1424
|
+
* If no axis is specified, the reduction is performed over all elements. This
|
|
1425
|
+
* convention differs from `jax.nn.logSoftmax()`.
|
|
1426
|
+
*
|
|
1427
|
+
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
1428
|
+
*/
|
|
1429
|
+
declare function logsumexp(x: ArrayLike, axis?: number | number[]): Array;
|
|
1430
|
+
/**
|
|
1431
|
+
* One-hot encodes the given indices.
|
|
1432
|
+
*
|
|
1433
|
+
* Each index in the integer input `x` is encoded as a vector of zeros of length
|
|
1434
|
+
* `numClasses`, with a 1 at the index position specified by its value.
|
|
1435
|
+
*
|
|
1436
|
+
* ```js
|
|
1437
|
+
* import { nn, numpy as np } from '@jax-js/jax';
|
|
1438
|
+
*
|
|
1439
|
+
* nn.oneHot(np.array([1, 1, 2], { dtype: np.int32 }), 3);
|
|
1440
|
+
* // Output:
|
|
1441
|
+
* // [[0, 1, 0],
|
|
1442
|
+
* // [0, 1, 0],
|
|
1443
|
+
* // [0, 0, 1]]
|
|
1444
|
+
* ```
|
|
1445
|
+
*/
|
|
1446
|
+
declare function oneHot(x: Array, numClasses: number): Array;
|
|
1447
|
+
declare namespace random_d_exports {
|
|
1448
|
+
export { bits, key, split, uniform };
|
|
1033
1449
|
}
|
|
1034
|
-
|
|
1035
|
-
|
|
1450
|
+
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
1451
|
+
declare function key(seed: number): Array;
|
|
1452
|
+
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
1453
|
+
declare function split(key: Array, num?: number | number[]): Array;
|
|
1454
|
+
/** Sample uniform bits in the form of unsigned integers. */
|
|
1455
|
+
declare function bits(key: Array, shape?: number[]): Array;
|
|
1456
|
+
/** Sample uniform random values in [minval, maxval) with given shape. */
|
|
1457
|
+
declare function uniform(key: Array, shape?: number[], {
|
|
1458
|
+
minval,
|
|
1459
|
+
maxval
|
|
1460
|
+
}?: {
|
|
1461
|
+
minval?: number;
|
|
1462
|
+
maxval?: number;
|
|
1463
|
+
}): Array;
|
|
1464
|
+
//#endregion
|
|
1465
|
+
//#region src/index.d.ts
|
|
1036
1466
|
/** Compute the forward-mode Jacobian-vector product for a function. */
|
|
1037
|
-
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1467
|
+
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
|
|
1038
1468
|
/** Vectorize an operation on a batched axis for one or more inputs. */
|
|
1039
|
-
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1469
|
+
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1040
1470
|
/** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
|
|
1041
|
-
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>,
|
|
1471
|
+
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1042
1472
|
/** Construct a Jaxpr by dynamically tracing a function with example inputs. */
|
|
1043
|
-
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1044
|
-
|
|
1045
|
-
|
|
1046
|
-
|
|
1473
|
+
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
|
|
1474
|
+
jaxpr: Jaxpr;
|
|
1475
|
+
consts: Array[];
|
|
1476
|
+
treedef: JsTreeDef;
|
|
1047
1477
|
};
|
|
1048
|
-
|
|
1478
|
+
/**
|
|
1479
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
1480
|
+
*
|
|
1481
|
+
* The function will be compiled the first time it is called with a set of
|
|
1482
|
+
* argument shapes.
|
|
1483
|
+
*
|
|
1484
|
+
* **Options:**
|
|
1485
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
1486
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
1487
|
+
* and different values will trigger recompilation.
|
|
1488
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
1489
|
+
* computation will be placed on the default device.
|
|
1490
|
+
*/
|
|
1491
|
+
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1049
1492
|
/**
|
|
1050
1493
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
1051
1494
|
* partial evaluation.
|
|
1052
1495
|
*/
|
|
1053
|
-
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1496
|
+
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
|
|
1054
1497
|
/** Calculate the reverse-mode vector-Jacobian product for a function. */
|
|
1055
|
-
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1498
|
+
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
|
|
1056
1499
|
/**
|
|
1057
1500
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
1058
1501
|
* first argument.
|
|
1059
1502
|
*/
|
|
1060
|
-
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f:
|
|
1503
|
+
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
|
|
1504
|
+
/** Create a function that evaluates both `f` and the gradient of `f`. */
|
|
1505
|
+
declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
|
|
1061
1506
|
/** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
|
|
1062
1507
|
declare const jacrev: typeof jacfwd;
|
|
1063
1508
|
/** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
|
|
1064
|
-
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>,
|
|
1065
|
-
|
|
1066
|
-
export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn, numpy, setDevice, tree, vjp, vmap };
|
|
1509
|
+
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
|
|
1510
|
+
//#endregion
|
|
1511
|
+
export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
|