@jax-js/jax 0.0.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/LICENSE +20 -0
- package/README.md +102 -0
- package/dist/chunk-B2GFURUN.js +1978 -0
- package/dist/index.cjs +6284 -0
- package/dist/index.d.cts +1066 -0
- package/dist/index.d.ts +1066 -0
- package/dist/index.js +3708 -0
- package/dist/webgpu-QNXDOQZP.js +559 -0
- package/package.json +69 -0
package/dist/index.d.cts
ADDED
|
@@ -0,0 +1,1066 @@
|
|
|
1
|
+
/**
|
|
2
|
+
* @file Lazy shape tracking for multidimensional tensors.
|
|
3
|
+
*
|
|
4
|
+
* This module provides an immutable `View` class that can be used to calculate
|
|
5
|
+
* shapes of arrays as operations are applied to them, lazily.
|
|
6
|
+
*
|
|
7
|
+
* Some operations like reshape() may not be representable with a single view,
|
|
8
|
+
* for instance, because composing reshape() with shrink() leads to a
|
|
9
|
+
* non-contiguous range of memory locations. This is why `ShapeTracker` is a
|
|
10
|
+
* list of views.
|
|
11
|
+
*
|
|
12
|
+
* Indexing into a `ShapeTracker` or `View` can be folded into shader code.
|
|
13
|
+
*
|
|
14
|
+
* Originally based on tinygrad's implementation of shape tracking in the
|
|
15
|
+
* `tinygrad.shape` module. But this version is simplified a bit. I'm not really
|
|
16
|
+
* trying to innovate on shape tracking in this library, so if I have doubts on
|
|
17
|
+
* something, it'll just be copied from tinygrad (with comments).
|
|
18
|
+
*
|
|
19
|
+
* This file is a bit longer than the original, since Python is more concise.
|
|
20
|
+
*/
|
|
21
|
+
|
|
22
|
+
/** A two-element numeric tuple, used per dimension by `View.mask`, `pad()` and `shrink()`. */
type Pair = [number, number];
|
|
23
|
+
/**
 * A multidimensional view into memory. An array can be thought of as the
 * combination of a linear buffer of memory, along with a `View`.
 *
 * Formula for getting a data point is basically:
 * 1. Check if ∀i. 0 <= dim[i] < shape[i], otherwise out of bounds.
 * 2. If mask exists, and ∃i. dim[i] ∉ mask[i], return 0.
 * 3. Otherwise, look at this memory address: offset + ∑(strides[i] * dim[i]).
 */
declare class View {
    #private;
    /** The shape of the view (size of each dimension). */
    readonly shape: number[];
    /** How many indices to move in buffer for each hop in one dimension. */
    readonly strides: number[];
    /** Offset from the start of the buffer. */
    readonly offset: number;
    /** Masked out subarray where data is read. All other data is zeroed. */
    readonly mask: Pair[] | null;
    private constructor();
    /** Construct a view; `strides`, `offset` and `mask` are optional. */
    static create(shape: number[], strides?: number[], offset?: number, mask?: Pair[] | null): View;
    get ndim(): number;
    get size(): number;
    /** Whether this is a default, contiguous, unaltered view of the data (identity). */
    get contiguous(): boolean;
    /**
     * Produce an AluExp for evaluating this view at an index.
     *
     * Returns a pair of expressions. NOTE(review): presumably the buffer index
     * and the mask/validity condition from steps 2-3 above — TODO confirm.
     */
    toAluExp(idxs: AluExp[]): [AluExp, AluExp];
    /**
     * Try to compose this view with another one. `this` view is applied first,
     * followed by the argument. If this is not possible for the specific views,
     * return `null` instead.
     *
     * If composable, return a combined view with the same shape as `v1`.
     *
     * This is very tricky. The shapes of v1 and v2 may be different, and in that
     * case, we do some math to figure out whether they're compatible.
     */
    compose(v1: View): View | null;
    /** Attempt to simplify this view into a smaller reshaped form. */
    minify(): View;
    /** Pad the view with zeros on each dimension. */
    pad(arg: Pair[]): View;
    /** Shrink the view by taking a subarray. */
    shrink(arg: Pair[]): View;
    /** Expand one or more axes with length "1" by repeating the data. */
    expand(newShape: number[]): View;
    /** Permute the axes of an array. */
    permute(axis: number[]): View;
    /** Flip (reverse) one or more axes of the view. */
    flip(arg: boolean[]): View;
    /** Reshape the view into a new shape, or return `null` if not representable. */
    reshape(newShape: number[]): View | null;
}
|
|
76
|
+
/**
 * The shape of an array after movement operations, tracked as a sequence of
 * views.
 *
 * Views are applied in order. The output of each view is treated as a
 * contiguous array of that view's shape, and that array is the virtual buffer
 * read by the next view.
 */
declare class ShapeTracker {
    /** The views, in application order. */
    readonly views: View[];
    constructor(views: View[]);
    /** Compose this tracker with `other`, which is applied afterwards. */
    compose(other: ShapeTracker): ShapeTracker;
    static fromShape(shape: number[]): ShapeTracker;
    get contiguous(): boolean;
    get consecutive(): boolean;
    get lastStrides(): number[];
    get shape(): number[];
    get size(): number;
    /** Returns a pair of index expressions, mirroring `View.toAluExp`. */
    toAluExp(idxs: AluExp[]): [AluExp, AluExp];
    simplify(): ShapeTracker;
    pad(arg: Pair[]): ShapeTracker;
    shrink(arg: Pair[]): ShapeTracker;
    expand(newShape: number[]): ShapeTracker;
    permute(axis: number[]): ShapeTracker;
    flip(arg: boolean[]): ShapeTracker;
    reshape(newShape: number[]): ShapeTracker;
    /** Insert broadcast axes at the positions in `axis`, then expand to `newShape`. */
    broadcast(newShape: number[], axis: number[]): ShapeTracker;
}
|
|
104
|
+
|
|
105
|
+
/** Either a bare `T`, or an arbitrarily deep nesting of arrays of `T`. */
type RecursiveArray<T> = T | RecursiveArray<T>[];
|
|
106
|
+
/** Something that can mix its own contents into an `FpHash` state. */
interface FpHashable {
    /** Feed this object's identifying data into `state`. */
    hash(state: FpHash): void;
}
|
|
109
|
+
/**
 * Polynomial hash modulo a prime p.
 *
 * In expectation such hashes rarely collide, which makes them good enough for
 * tasks like deduplicating compiler expressions that have already been seen.
 * They are not designed to resist adversarially chosen inputs.
 *
 * See https://en.wikipedia.org/wiki/Lagrange%27s_theorem_(number_theory)
 */
declare class FpHash {
    #private;
    value: bigint;
    /** Mix the given values into this hash. Returns `this` so calls can be chained. */
    update(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): this;
    /** One-shot helper: hash the given values and return the resulting `bigint`. */
    static hash(...values: (string | boolean | number | bigint | null | undefined | FpHashable)[]): bigint;
}
|
|
122
|
+
|
|
123
|
+
/** Element data types, identified by their string names. */
declare enum DType {
    Float32 = "float32",
    Int32 = "int32",
    Bool = "bool",
    Complex64 = "complex64"
}
|
|
129
|
+
/**
 * Mathematical expression on scalar values.
 *
 * This is similar to and based on tinygrad's UOp class, but it's more specific
 * to just math on scalars. We're doing this to avoid the complexity of a full
 * graph rewrite engine.
 */
declare class AluExp implements FpHashable {
    #private;
    readonly op: AluOp;
    readonly dtype: DType;
    readonly src: AluExp[];
    readonly arg: any;
    constructor(op: AluOp, dtype: DType, src: AluExp[], arg?: any);
    static add(a: AluExp, b: AluExp): AluExp;
    static sub(a: AluExp, b: AluExp): AluExp;
    static mul(a: AluExp, b: AluExp): AluExp;
    static idiv(a: AluExp, b: AluExp): AluExp;
    static mod(a: AluExp, b: AluExp): AluExp;
    static min(a: AluExp, b: AluExp): AluExp;
    static max(a: AluExp, b: AluExp): AluExp;
    static sin(a: AluExp): AluExp;
    static cos(a: AluExp): AluExp;
    static exp(a: AluExp): AluExp;
    static log(a: AluExp): AluExp;
    static reciprocal(a: AluExp): AluExp;
    static cast(dtype: DType, a: AluExp): AluExp;
    static cmplt(a: AluExp, b: AluExp): AluExp;
    static cmpne(a: AluExp, b: AluExp): AluExp;
    static where(cond: AluExp, a: AluExp, b: AluExp): AluExp;
    static const(dtype: DType, value: any): AluExp;
    static special(dtype: DType, name: string, n: number): AluExp;
    static variable(dtype: DType, name: string): AluExp;
    static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
    static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
    static i32(value: number): AluExp;
    static f32(value: number): AluExp;
    static bool(value: boolean): AluExp;
    not(): AluExp;
    /** Compute a reasonable expression hash with low collision rate. */
    getHash(): bigint;
    hash(state: FpHash): void;
    /** Substitute variables in this AluExp to values. */
    substitute(variables: Record<string, AluExp>): AluExp;
    /** Reindex gid values in this expression as needed. */
    reindexGids(gidMap: Map<number, number>): AluExp;
    get min(): number;
    get max(): number;
    /**
     * Simplify the expression by replacing any known patterns and deduping
     * identical subexpressions.
     */
    simplify(cache?: Map<bigint, AluExp>): AluExp;
    /** Resolve this to a value, or `undefined` if not possible. */
    resolve(): any | undefined;
    /**
     * Evaluate the expression on CPU, returning the result.
     *
     * Typically you would compile the AluExp as a representation to a lower-level
     * language. This is just to define the semantics and help debug.
     *
     * Note that the representation of Bool is as a number (0 or 1) here.
     */
    evaluate(context: Record<string, any>, globals?: (gid: number, bufidx: number) => any): any;
    /** Get this expression in debug format as a string. */
    toString(): string;
    /** Generic fold() operation with a reducer over the expression tree. */
    fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
    /** Rewrite the expression recursively using a visitor. */
    rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
    /** Collect all nodes that satisfy a predicate. */
    collect(predicate: (exp: AluExp) => boolean): AluExp[];
}
|
|
202
|
+
/** Symbolic form for each mathematical operation. */
declare enum AluOp {
    Add = "Add",
    Sub = "Sub",
    Mul = "Mul",
    Idiv = "Idiv",
    Mod = "Mod",
    Min = "Min",
    Max = "Max",
    Sin = "Sin",
    Cos = "Cos",
    Exp = "Exp",
    Log = "Log",
    Reciprocal = "Reciprocal",
    Cast = "Cast",
    Cmplt = "Cmplt",
    Cmpne = "Cmpne",
    /** Ternary operator: `cond ? a : b`. */
    Where = "Where",
    /** `arg` holds the constant value. */
    Const = "Const",
    /** `arg` is `[variable, n]`. */
    Special = "Special",
    /** `arg` is the variable. */
    Variable = "Variable",
    /** `arg` is the global id `gid`; `src` is `[bufidx]`. */
    GlobalIndex = "GlobalIndex",
    /**
     * NOTE(review): undocumented upstream. `AluExp.globalView()` takes
     * `(dtype, gid, st, indices)`, so presumably `arg` carries the gid and
     * ShapeTracker and `src` the indices — TODO confirm.
     */
    GlobalView = "GlobalView"
}
|
|
226
|
+
/**
 * Description of a kernel to be compiled.
 *
 * A backend turns each kernel into some lower-level representation. A kernel
 * holds one or more fused operations, which may index into a buffer.
 */
declare class Kernel implements FpHashable {
    /** Number of global arguments / arrays. */
    readonly nargs: number;
    /** Size of the result array in element count. */
    readonly size: number;
    /** Expression to be evaluated. */
    readonly exp: AluExp;
    /** Optional reduction to be performed. */
    readonly reduction?: Reduction | undefined;
    /**
     * @param nargs Number of global arguments / arrays.
     * @param size Size of the result array in element count.
     * @param exp Expression to be evaluated.
     * @param reduction Optional reduction to be performed.
     */
    constructor(nargs: number, size: number, exp: AluExp, reduction?: Reduction | undefined);
    hash(state: FpHash): void;
}
|
|
253
|
+
/**
 * Description of a reduction.
 *
 * A jax-js backend either runs a standard operation vectorized over an array,
 * or reduces some computation over a single axis. This class describes the
 * latter.
 *
 * Only a few operations are supported, and only along one axis. To reduce
 * over more axes, `flatten()` the array first.
 *
 * The backend must implement the reduction so that it keeps global memory
 * loads to a minimum, which requires some optimization strategy. Those
 * optimizations are not expressed here, because they differ between GPU,
 * CPU and Wasm.
 */
declare class Reduction implements FpHashable {
    /** Data type of the values being reduced over. */
    readonly dtype: DType;
    /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
    readonly op: AluOp;
    /** Size of the reduction axis. */
    readonly size: number;
    /** Follow-up expression defined with the "acc" variable, defaults to identity. */
    readonly fusion: AluExp;
    /**
     * @param dtype Data type of the values being reduced over.
     * @param op Operation to perform; only ops in `AluGroup.Reduce` are supported.
     * @param size Size of the reduction axis.
     * @param fusion Follow-up expression over "acc"; defaults to identity.
     */
    constructor(dtype: DType, op: AluOp, size: number, fusion?: AluExp);
    hash(state: FpHash): void;
    /** Get the identity for this reduction operation. */
    get identity(): any;
    /** Evaluate this operation on CPU. */
    evaluate(...values: any): any;
}
|
|
292
|
+
|
|
293
|
+
/**
|
|
294
|
+
* @file Shared interfaces and code for the low-level backend API.
|
|
295
|
+
*
|
|
296
|
+
* Think of each backend as a _connector_ to a specific hardware or software
|
|
297
|
+
* implementation of the array API.
|
|
298
|
+
*
|
|
299
|
+
* Backends do not share any of the built-in operational semantics of the
|
|
300
|
+
* library. This is a private API. You must allocate and free buffers manually,
|
|
301
|
+
* and dispatch happens on the level of each shader. Buffers are untyped.
|
|
302
|
+
*/
|
|
303
|
+
|
|
304
|
+
/** Names of the device backends. */
type Device = "cpu" | "webgpu";
declare const devices: Device[];
/** Choose the default device backend. The backend must already be initialized. */
declare function setDevice(device: Device): void;
|
|
308
|
+
/**
 * Initialize `jax-js` library backends.
 *
 * By default, this will initialize all available backends. If one or more
 * backends are provided, only attempt to initialize those.
 *
 * @param devicesToInit Backends to initialize; all available ones if omitted.
 * @returns A promise for the list of available backends.
 */
declare function init(...devicesToInit: Device[]): Promise<Device[]>;
|
|
316
|
+
/** Numeric handle that uniquely identifies an allocated buffer on a device. */
type Slot = number;
|
|
318
|
+
/** Contract that every device backend implements. */
interface Backend {
    /** Which device this backend drives. */
    readonly type: Device;
    /** Upper bound on the number of arguments a dispatched kernel may take. */
    readonly maxArgs: number;
    /** Allocate a new slot with reference count 1. */
    malloc(size: number, initialData?: ArrayBuffer): Slot;
    /** Increment the reference count of the slot. */
    incRef(slot: Slot): void;
    /**
     * Decrement the slot's reference count. The slot is freed when the count
     * reaches zero. Implementations should throw if the slot was already freed.
     */
    decRef(slot: Slot): void;
    /** Read a range of bytes from a buffer. */
    read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer>;
    /** Read a range of bytes from a buffer, blocking variant. */
    readSync(slot: Slot, start?: number, count?: number): ArrayBuffer;
    /** Prepare an expression to be executed later. */
    prepare(kernel: Kernel): Promise<Executable>;
    /** Prepare an expression to be executed later, blocking variant. */
    prepareSync(kernel: Kernel): Executable;
    /**
     * Run a backend operation that was previously prepared.
     *
     * The operation may not start right away, but operations always run in
     * dispatch order. `read()` waits for every pending operation on its slot
     * to finish.
     */
    dispatch(exe: Executable, inputs: Slot[], outputs: Slot[]): void;
}
|
|
350
|
+
/** A kernel paired with backend-specific data, as produced by `Backend.prepare`. */
declare class Executable<T = any> {
    readonly kernel: Kernel;
    /** Extra data specific to the backend running this kernel. */
    readonly data: T;
    /**
     * @param kernel The kernel this executable runs.
     * @param data Extra data specific to the backend running this kernel.
     */
    constructor(kernel: Kernel, data: T);
}
|
|
358
|
+
|
|
359
|
+
/** @file Utilities for working with tree-like container data structures ("pytrees"). */
|
|
360
|
+
/** Kind of node in a `JsTreeDef`: array container, object container, or leaf. */
declare enum NodeType {
    Array = "Array",
    Object = "Object",
    Leaf = "Leaf"
}
|
|
365
|
+
/** A `T` leaf, or any nesting of arrays and string-keyed objects around such leaves. */
type JsTree<T> =
    | T
    | JsTree<T>[]
    | { [key: string]: JsTree<T> };
|
|
368
|
+
/**
 * Type-level map over a tree: each part of `T` that extends `A` becomes `B`.
 * Arrays of unknown length map their element type. Tuples and objects are
 * mapped key by key.
 */
type MapJsTree<T, A, B> = T extends A
    ? B
    : T extends globalThis.Array<infer U>
        ? number extends T["length"]
            ? MapJsTree<U, A, B>[]
            : { [K in keyof T]: MapJsTree<T[K], A, B> }
        : { [K in keyof T]: MapJsTree<T[K], A, B> };
|
|
373
|
+
/** JavaScript counterpart of a JAX "pytree" definition: the structure of a tree without its leaves. */
declare class JsTreeDef {
    readonly nodeType: NodeType;
    /** Per-node metadata; must be comparable with deepEqual. */
    readonly nodeMetadata: any;
    readonly childTreedefs: JsTreeDef[];
    static leaf: JsTreeDef;
    constructor(nodeType: NodeType, nodeMetadata: any, childTreedefs: JsTreeDef[]);
    /** Render this tree definition as a string. */
    toString(root?: boolean): string;
    /** Check whether this tree definition equals `other`. */
    equals(other: JsTreeDef): boolean;
}
|
|
386
|
+
/** Split a structured value into its leaves and its tree definition. */
declare function flatten<T>(tree: JsTree<T>): [T[], JsTreeDef];
/** Return only the leaves of `tree`. */
declare function leaves<T>(tree: JsTree<T>): T[];
/** Return only the tree definition of `tree`. */
declare function structure<T>(tree: JsTree<T>): JsTreeDef;
/** Rebuild a structured value from a tree definition and its leaves (the inverse of `flatten`). */
declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T>;
|
|
394
|
+
|
|
395
|
+
// NOTE(review): these `tree_*` aliases look like declaration-bundler output.
// Their only role is to re-export the pytree utilities under the `tree` namespace.
type tree_JsTree<T> = JsTree<T>;
type tree_JsTreeDef = JsTreeDef;
declare const tree_JsTreeDef: typeof JsTreeDef;
type tree_MapJsTree<T, A, B> = MapJsTree<T, A, B>;
type tree_NodeType = NodeType;
declare const tree_NodeType: typeof NodeType;
declare const tree_flatten: typeof flatten;
declare const tree_leaves: typeof leaves;
declare const tree_structure: typeof structure;
declare const tree_unflatten: typeof unflatten;
declare namespace tree {
    export {
        type tree_JsTree as JsTree,
        tree_JsTreeDef as JsTreeDef,
        type tree_MapJsTree as MapJsTree,
        tree_NodeType as NodeType,
        tree_flatten as flatten,
        tree_leaves as leaves,
        tree_structure as structure,
        tree_unflatten as unflatten,
    };
}
|
|
408
|
+
|
|
409
|
+
/** @file Core library internals and interpreter stack, based on Autodidax. */
|
|
410
|
+
|
|
411
|
+
/** Names of the primitive operations that traces process (see `Trace.processPrimitive`). */
declare enum Primitive {
    Add = "add",
    Mul = "mul",
    Idiv = "idiv",
    Neg = "neg",
    Reciprocal = "reciprocal",
    Sin = "sin",
    Cos = "cos",
    Exp = "exp",
    Log = "log",
    Min = "min",
    Max = "max",
    ReduceSum = "reduce_sum",
    Compare = "compare",
    Where = "where",
    Transpose = "transpose",
    Broadcast = "broadcast",
    Reshape = "reshape",
    Flip = "flip",
    JitCall = "jit_call"
}
|
|
432
|
+
/** Record for one level of the interpreter stack (Autodidax-style). */
type MainTrace = {
    /** Position of this trace in the stack. NOTE(review): presumably higher = more nested — confirm. */
    level: number;
    /** Constructor for the `Trace` subclass that handles this level. */
    traceType: new (main: MainTrace) => Trace;
    /** Arbitrary data shared by the trace at this level, or `null`. */
    globalData: any | null;
};
|
|
437
|
+
/** Anything a trace accepts as an operand: a tracer, or a plain number or boolean. */
type TracerValue = Tracer | number | boolean;
|
|
438
|
+
/** Base class for interpreters on the trace stack. */
declare abstract class Trace {
    readonly main: MainTrace;
    constructor(main: MainTrace);
    /** Wrap a value as a tracer belonging to this trace. */
    abstract pure(val: TracerValue): Tracer;
    /** Bring a tracer from a lower-level trace into this one. */
    abstract lift(val: Tracer): Tracer;
    /** Apply `primitive` to `tracers` with `params`, returning the output tracers. */
    abstract processPrimitive(primitive: Primitive, tracers: Tracer[], params: Record<string, any>): Tracer[];
}
|
|
445
|
+
/** Abstract description of a value: only its shape and dtype, no data. */
interface AbstractValue {
    shape: number[];
    dtype: DType;
}
|
|
449
|
+
/** Base class for values that flow through a `Trace`, including concrete arrays. */
declare abstract class Tracer {
    readonly _trace: Trace;
    constructor(trace: Trace);
    abstract get aval(): AbstractValue;
    abstract toString(): string;
    /**
     * Access an array by reference, incrementing the reference count.
     *
     * jax-js handles freeing arrays by using "move" semantics, like in Rust/C++.
     * Whenever you pass an array into a function, that function should consume
     * the array, and it will no longer be usable. For example, if you had:
     *
     * ```
     * const x = np.array([1, 2, 3]);
     * const y = np.add(x, x);
     * ```
     *
     * The second line does not work because the first parameter consumes `x`, and
     * then the second parameter will already have been freed / disposed.
     *
     * To fix this, you can write:
     *
     * ```
     * const y = np.add(x.ref, x);
     * ```
     *
     * Under the hood, every access to `.ref` increments the internal reference
     * count of the array. The reference count starts at 1. When it hits 0, the
     * memory behind the array is freed.
     */
    abstract get ref(): this;
    /**
     * Manually decrement the reference count of the array.
     *
     * Arrays are created with reference count 1. Whenever it is used as an
     * argument to a function or other operation, it is disposed (i.e., reference
     * count decreases by 1) automatically. Whenever a `.ref` is created, the
     * reference count increases.
     *
     * You generally don't need to call this function directly since arrays are
     * automatically disposed after being passed into an operation. One common
     * exception is when writing a function and ignoring one of its arguments. In
     * that case, by convention you should dispose of that argument manually.
     *
     * ```
     * function myCustomOperation(a: np.Array, b: np.Array) {
     *   b.dispose(); // Needed to satisfy "move" rules.
     *   return a.add(1);
     * }
     * ```
     */
    abstract dispose(): void;
    get shape(): number[];
    get dtype(): DType;
    get ndim(): number;
    fullLower(): Tracer;
    neg(): this;
    add(other: this | TracerValue): this;
    mul(other: this | TracerValue): this;
    greater(other: this | TracerValue): this;
    less(other: this | TracerValue): this;
    equal(other: this | TracerValue): this;
    notEqual(other: this | TracerValue): this;
    greaterEqual(other: this | TracerValue): this;
    lessEqual(other: this | TracerValue): this;
    sum(axis?: number | number[]): this;
    transpose(perm?: number[]): this;
    reshape(shape: number | number[]): this;
    /** Subtract an array from this one. */
    sub(other: this | TracerValue): this;
    /** Divide this array by another array or value. */
    div(other: this | TracerValue): this;
    /** Return specified diagonals. See `numpy.diagonal` for full docs. */
    diagonal(offset?: number, axis1?: number, axis2?: number): this;
    /** Flatten the array without changing its data. */
    flatten(): this;
    /** Flatten the array without changing its data. */
    ravel(): this;
}
|
|
528
|
+
/** Concrete `AbstractValue` that records a shape and a dtype. */
declare class ShapedArray implements AbstractValue {
    readonly shape: number[];
    readonly dtype: DType;
    constructor(shape: number[], dtype: DType);
    /** Build a `ShapedArray` from any abstract value. */
    static fromAval(aval: AbstractValue): ShapedArray;
    get ndim(): number;
    /** Short string form of this abstract value. */
    strShort(): string;
    equals(other: ShapedArray): boolean;
}
|
|
537
|
+
|
|
538
|
+
/** A jax-js `Array` or a plain number or boolean scalar. */
type ArrayLike = Array | number | boolean;
|
|
539
|
+
/**
 * An operation queued for dispatch to a backend.
 *
 * It keeps references to every input buffer the operation uses. Those
 * references should be released once the operation has been dispatched.
 */
declare class PendingExecute {
    #private;
    readonly backend: Backend;
    readonly kernel: Kernel;
    readonly inputs: Slot[];
    readonly outputs: Slot[];
    /** The prepared executable, or `null` if not prepared yet. */
    prepared: Executable | null;
    submitted: boolean;
    constructor(backend: Backend, kernel: Kernel, inputs: Slot[], outputs: Slot[]);
    /** NOTE(review): presumably adjusts reference counts on the held slots by `delta` — confirm. */
    updateRc(delta: number): void;
    prepare(): Promise<void>;
    prepareSync(): void;
    submit(): void;
}
|
|
559
|
+
/** Optional dtype and device settings accepted by the array constructors. */
type DTypeAndDevice = {
    dtype?: DType;
    device?: Device;
};
|
|
563
|
+
/**
 * Multidimensional numeric array whose data lives on CPU or GPU.
 *
 * This is the core data type of the library, playing the role of `jnp.Array`
 * in JAX or `torch.Tensor` in PyTorch.
 *
 * Do not confuse it with JavaScript's built-in "Array" constructor. If your
 * code already refers to that "Array" type by name, avoid importing this one
 * into the same namespace.
 */
declare class Array extends Tracer {
    #private;
    id: number;
    constructor(source: AluExp | Slot, st: ShapeTracker, dtype: DType, backend: Backend, pending?: Iterable<PendingExecute> | null);
    get aval(): ShapedArray;
    /** Short string describing the array's dimensions. */
    toString(): string;
    get device(): Device;
    get ref(): this;
    dispose(): void;
    /**
     * Turn this array into a JavaScript primitive.
     *
     * Only scalars (0-dimensional arrays) support this. It is how values leave
     * the JAX system: with `x = np.array(5)`, the expressions `x + 1` and
     * `x ** 2` evaluate to `6` and `25`.
     *
     * `==` equality also goes through this method.
     */
    [Symbol.toPrimitive](): any;
    /** Realize the array and resolve with its data. */
    data(): Promise<Float32Array | Int32Array>;
    /** Resolve once this array has been placed on its backend, if that is needed. */
    wait(): Promise<void>;
    /**
     * Realize the array and return its data synchronously. This blocks
     * rendering, so the async `data()` is preferred for performance.
     */
    dataSync(): Float32Array | Int32Array;
    /** Convert the array to a plain JavaScript value (blocking). */
    js(): RecursiveArray<number> | RecursiveArray<boolean>;
    /** Convert the array to a plain JavaScript value, asynchronously. */
    jsAsync(): Promise<RecursiveArray<number> | RecursiveArray<boolean>>;
    /** @private Internal plumbing method for Array / Tracer ops. */
    static _implRules(): Record<Primitive, ImplRule>;
    _realizeSource(): number;
}
|
|
610
|
+
/** Make an array that holds a single scalar constant. */
declare function scalar(value: number | boolean, { dtype, device }?: DTypeAndDevice): Array;
/** Make a new array from existing data, optionally giving its shape, dtype and device. */
declare function array(
    values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>,
    { shape, dtype, device }?: { shape?: number[] } & DTypeAndDevice,
): Array;
|
|
616
|
+
type ImplRule = (tracers: Array[], params: any) => Array[];
|
|
617
|
+
/** Return a new array of given shape and type, filled with zeros. */
|
|
618
|
+
declare function zeros(shape: number[], { dtype, device }?: DTypeAndDevice): Array;
|
|
619
|
+
/** Return a new array of given shape and type, filled with ones. */
|
|
620
|
+
declare function ones(shape: number[], { dtype, device }?: DTypeAndDevice): Array;
|
|
621
|
+
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
622
|
+
declare function full(shape: number[], fillValue: number | boolean | Array, { dtype, device }?: DTypeAndDevice): Array;
|
|
623
|
+
/**
|
|
624
|
+
* Create an identity matrix.
|
|
625
|
+
*
|
|
626
|
+
* If numCols is not provided, it defaults to numRows, i.e., a square identity
|
|
627
|
+
* matrix with ones on the diagonal.
|
|
628
|
+
*/
|
|
629
|
+
declare function eye(numRows: number, numCols?: number, { dtype, device }?: DTypeAndDevice): Array;
|
|
630
|
+
/** Return the identity array, with ones on the main diagonal. */
|
|
631
|
+
declare function identity$1(n: number, { dtype, device }?: DTypeAndDevice): Array;
|
|
632
|
+
/**
|
|
633
|
+
* Return evenly spaced values within a given interval.
|
|
634
|
+
*
|
|
635
|
+
* This can be called with a varying number of arguments, just like the range()
|
|
636
|
+
* builtin function in Python.
|
|
637
|
+
*
|
|
638
|
+
* - `arange(stop)` is equivalent to `arange(0, stop, 1)`.
|
|
639
|
+
* - `arange(start, stop)` is equivalent to `arange(start, stop, 1)`.
|
|
640
|
+
* - `arange(start, stop, step)` creates an array starting at `start`, ending
|
|
641
|
+
* before `stop`, with a step size of `step`.
|
|
642
|
+
*
|
|
643
|
+
* Defaults to an integer data type. This can produce unintended results when
|
|
644
|
+
* using a non-integer step, so prefer linspace() in those cases.
|
|
645
|
+
*/
|
|
646
|
+
declare function arange(start: number, stop?: number, step?: number, { dtype, device }?: DTypeAndDevice): Array;
|
|
647
|
+
/**
|
|
648
|
+
* Return evenly spaced numbers over a specified interval.
|
|
649
|
+
*
|
|
650
|
+
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
651
|
+
* [`start`, `stop`]. The endpoint `stop` is included in the result by default,
|
|
652
|
+
* but this is controlled by the `endpoint` parameter.
|
|
653
|
+
*
|
|
654
|
+
* The default data type is Float32. Use arange() for integer steps.
|
|
655
|
+
*/
|
|
656
|
+
declare function linspace(start: number, stop: number, num?: number, endpoint?: boolean, { dtype, device }?: DTypeAndDevice): Array;
|
|
657
|
+
|
|
658
|
+
/** Shorthand for the `DType.Float32` data type. */
declare const float32 = DType.Float32;
/** Shorthand for the `DType.Int32` data type. */
declare const int32 = DType.Int32;
/** Shorthand for the `DType.Bool` data type. */
declare const bool = DType.Bool;
/** Shorthand for the `DType.Complex64` data type. */
declare const complex64 = DType.Complex64;
/** Euler's number, `e = 2.7182818284590...` */
declare const e: number;
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
declare const eulerGamma = 0.5772156649015329;
/** Positive infinity. */
declare const inf: number;
/** Floating-point representation of NaN. */
declare const nan: number;
/** This is Pi, `π = 3.14159265358979...` */
declare const pi: number;
/** Element-wise addition, with broadcasting. */
declare const add: (x: ArrayLike, y: ArrayLike) => Array;
/** Element-wise multiplication, with broadcasting. */
declare const multiply: (x: ArrayLike, y: ArrayLike) => Array;
/** Numerical negative of every element of an array. */
declare const negative: (x: ArrayLike) => Array;
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
declare const reciprocal: (x: ArrayLike) => Array;
/** Element-wise sine function (takes radians). */
declare const sin: (x: ArrayLike) => Array;
/** Element-wise cosine function (takes radians). */
declare const cos: (x: ArrayLike) => Array;
/** Calculate the exponential of all elements in the input array. */
declare const exp: (x: ArrayLike) => Array;
/** Calculate the natural logarithm of all elements in the input array. */
declare const log: (x: ArrayLike) => Array;
/** Return element-wise minimum of the input arrays. */
declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
/** Return element-wise maximum of the input arrays. */
declare const maximum: (x: ArrayLike, y: ArrayLike) => Array;
/** Return the truth value of `x > y`, element-wise. */
declare const greater: (x: ArrayLike, y: ArrayLike) => Array;
/** Return the truth value of `x < y`, element-wise. */
declare const less: (x: ArrayLike, y: ArrayLike) => Array;
/** Return the truth value of `x == y`, element-wise. */
declare const equal: (x: ArrayLike, y: ArrayLike) => Array;
/** Return the truth value of `x != y`, element-wise. */
declare const notEqual: (x: ArrayLike, y: ArrayLike) => Array;
/** Return the truth value of `x >= y`, element-wise. */
declare const greaterEqual: (x: ArrayLike, y: ArrayLike) => Array;
/** Return the truth value of `x <= y`, element-wise. */
declare const lessEqual: (x: ArrayLike, y: ArrayLike) => Array;
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
declare const where: (cond: ArrayLike, x: ArrayLike, y: ArrayLike) => Array;
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
declare const transpose: (x: ArrayLike, perm?: number[]) => Array;
/**
 * Give a new shape to an array without changing its data.
 *
 * One shape dimension can be -1. In this case, the value is inferred from the
 * length of the array and remaining dimensions.
 */
declare const reshape: (x: ArrayLike, shape: number[]) => Array;
/**
 * Sum of array elements over the given axis.
 *
 * `axis` may be a single axis index or a list of axes to reduce together.
 */
declare const sum: (x: ArrayLike, axis?: number | number[]) => Array;
/**
 * Move axis `src` of an array to position `dst`; the remaining axes keep their
 * relative order.
 */
declare const moveaxis: (x: ArrayLike, src: number, dst: number) => Array;
/** Return the number of dimensions of an array. */
declare const ndim: (x: ArrayLike) => number;
/** Return the shape of an array. */
declare const shape: (x: ArrayLike) => number[];
/** Return the number of elements in an array, optionally along an axis. */
declare function size(a: ArrayLike, axis?: number): number;
/** Reverse the elements in an array along the given axes. */
declare function flip(x: ArrayLike, axis?: number | number[]): Array;
/** Flip an array vertically (axis=0). */
declare function flipud(x: ArrayLike): Array;
/** Flip an array horizontally (axis=1). */
declare function fliplr(x: ArrayLike): Array;
/**
 * Permute the dimensions of an array.
 *
 * Array API spelling of `transpose()`; it takes the same arguments.
 */
declare const permuteDims: (x: ArrayLike, perm?: number[]) => Array;
/** Return a 1-D flattened array containing the elements of the input. */
declare function ravel(a: ArrayLike): Array;
/**
 * Return specified diagonals.
 *
 * If a is 2D, return the diagonal of the array with the given offset. If a is
 * 3D or higher, compute diagonals along the two given axes.
 *
 * This returns a view over the existing array.
 */
declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
/** Transposes a matrix or stack of matrices `x` (swap last two axes). */
declare function matrixTranspose(x: ArrayLike): Array;
/**
 * Extract a diagonal or construct a diagonal array.
 *
 * If v is a 2D array, return the k-th diagonal of v (as a view). If v is a 1D
 * array, return a 2D array with v on the k-th diagonal.
 */
declare function diag(v: ArrayLike, k?: number): Array;
/**
 * Check whether two arrays are element-wise equal within a tolerance.
 *
 * Both `actual` and `expected` accept anything `array()` accepts. The optional
 * `rtol` and `atol` set the relative and absolute tolerance, respectively.
 */
declare function allclose(actual: Parameters<typeof array>[0], expected: Parameters<typeof array>[0], options?: {
    rtol?: number;
    atol?: number;
}): boolean;
/** Matrix product of two arrays. */
declare const matmul: (x: ArrayLike, y: ArrayLike) => Array;
/** Dot product of two arrays. */
declare const dot: (x: ArrayLike, y: ArrayLike) => Array;
/** Vector dot product of two arrays. */
declare const vecdot: (x: ArrayLike, y: ArrayLike) => Array;
/**
 * Dot product of two vectors.
 *
 * Behaves like vecdot(), except that both arguments are flattened into
 * vectors first.
 */
declare function vdot(x: ArrayLike, y: ArrayLike): Array;
/**
 * Build coordinate matrices from coordinate vectors.
 *
 * Given one-dimensional coordinate arrays x1, x2, …, xn, produce N-D coordinate
 * arrays for vectorized evaluation of N-D scalar/vector fields over N-D grids.
 * The `indexing` option selects between `"xy"` and `"ij"` layouts.
 */
declare function meshgrid(xs: Array[], { indexing }?: {
    indexing?: "xy" | "ij";
}): Array[];
/**
 * Limit the values in an array to an interval.
 *
 * Values below the interval are replaced by its lower edge and values above it
 * by its upper edge; e.g. with the interval [0, 1], negatives become 0 and
 * values greater than 1 become 1.
 *
 * A bound that is left undefined is not applied.
 */
declare function clip(a: ArrayLike, min?: ArrayLike, max?: ArrayLike): Array;
/**
 * Element-wise absolute value.
 *
 * `jax.numpy.abs()` refers to this same function.
 */
declare function absolute(x: ArrayLike): Array;
/** Another name for `jax.numpy.absolute()`. */
declare const abs: typeof absolute;
/** Element-wise square of the input array. */
declare function square(x: ArrayLike): Array;
/** Element-wise trigonometric tangent of the input. */
declare function tan(x: ArrayLike): Array;
/** Element-wise floating-point division of `x` by `y`. */
declare function trueDivide(x: ArrayLike, y: ArrayLike): Array;
/** Another name for `jax.numpy.trueDivide()`. */
declare const divide: typeof trueDivide;
/** Round each element to the nearest integer in the direction of zero. */
declare function trunc(x: ArrayLike): Array;
/** Element-wise `2**p` for every `p` in the input array. */
declare function exp2(p: ArrayLike): Array;
/** Element-wise base-2 logarithm of `x`. */
declare function log2(x: ArrayLike): Array;
/** Element-wise base-10 logarithm of `x`. */
declare function log10(x: ArrayLike): Array;

// Local aliases re-exported below as the public `numpy` namespace.
type numpy_Array = Array;
declare const numpy_Array: typeof Array;
type numpy_ArrayLike = ArrayLike;
type numpy_DType = DType;
declare const numpy_DType: typeof DType;
declare const numpy_abs: typeof abs;
declare const numpy_absolute: typeof absolute;
declare const numpy_add: typeof add;
declare const numpy_allclose: typeof allclose;
declare const numpy_arange: typeof arange;
declare const numpy_array: typeof array;
declare const numpy_bool: typeof bool;
declare const numpy_clip: typeof clip;
declare const numpy_complex64: typeof complex64;
declare const numpy_cos: typeof cos;
declare const numpy_diag: typeof diag;
declare const numpy_diagonal: typeof diagonal;
declare const numpy_divide: typeof divide;
declare const numpy_dot: typeof dot;
declare const numpy_e: typeof e;
declare const numpy_equal: typeof equal;
declare const numpy_eulerGamma: typeof eulerGamma;
declare const numpy_exp: typeof exp;
declare const numpy_exp2: typeof exp2;
declare const numpy_eye: typeof eye;
declare const numpy_flip: typeof flip;
declare const numpy_fliplr: typeof fliplr;
declare const numpy_flipud: typeof flipud;
declare const numpy_float32: typeof float32;
declare const numpy_full: typeof full;
declare const numpy_greater: typeof greater;
declare const numpy_greaterEqual: typeof greaterEqual;
declare const numpy_inf: typeof inf;
declare const numpy_int32: typeof int32;
declare const numpy_less: typeof less;
declare const numpy_lessEqual: typeof lessEqual;
declare const numpy_linspace: typeof linspace;
declare const numpy_log: typeof log;
declare const numpy_log10: typeof log10;
declare const numpy_log2: typeof log2;
declare const numpy_matmul: typeof matmul;
declare const numpy_matrixTranspose: typeof matrixTranspose;
declare const numpy_maximum: typeof maximum;
declare const numpy_meshgrid: typeof meshgrid;
declare const numpy_minimum: typeof minimum;
declare const numpy_moveaxis: typeof moveaxis;
declare const numpy_multiply: typeof multiply;
declare const numpy_nan: typeof nan;
declare const numpy_ndim: typeof ndim;
declare const numpy_negative: typeof negative;
declare const numpy_notEqual: typeof notEqual;
declare const numpy_ones: typeof ones;
declare const numpy_permuteDims: typeof permuteDims;
declare const numpy_pi: typeof pi;
declare const numpy_ravel: typeof ravel;
declare const numpy_reciprocal: typeof reciprocal;
declare const numpy_reshape: typeof reshape;
declare const numpy_scalar: typeof scalar;
declare const numpy_shape: typeof shape;
declare const numpy_sin: typeof sin;
declare const numpy_size: typeof size;
declare const numpy_square: typeof square;
declare const numpy_sum: typeof sum;
declare const numpy_tan: typeof tan;
declare const numpy_transpose: typeof transpose;
declare const numpy_trueDivide: typeof trueDivide;
declare const numpy_trunc: typeof trunc;
declare const numpy_vdot: typeof vdot;
declare const numpy_vecdot: typeof vecdot;
declare const numpy_where: typeof where;
declare const numpy_zeros: typeof zeros;
/** NumPy-style array API (`jax.numpy`). */
declare namespace numpy {
    export {
        numpy_Array as Array,
        type numpy_ArrayLike as ArrayLike,
        numpy_DType as DType,
        numpy_abs as abs,
        numpy_absolute as absolute,
        numpy_add as add,
        numpy_allclose as allclose,
        numpy_arange as arange,
        numpy_array as array,
        numpy_bool as bool,
        numpy_clip as clip,
        numpy_complex64 as complex64,
        numpy_cos as cos,
        numpy_diag as diag,
        numpy_diagonal as diagonal,
        numpy_divide as divide,
        numpy_dot as dot,
        numpy_e as e,
        numpy_equal as equal,
        numpy_eulerGamma as eulerGamma,
        numpy_exp as exp,
        numpy_exp2 as exp2,
        numpy_eye as eye,
        numpy_flip as flip,
        numpy_fliplr as fliplr,
        numpy_flipud as flipud,
        numpy_float32 as float32,
        numpy_full as full,
        numpy_greater as greater,
        numpy_greaterEqual as greaterEqual,
        identity$1 as identity,
        numpy_inf as inf,
        numpy_int32 as int32,
        numpy_less as less,
        numpy_lessEqual as lessEqual,
        numpy_linspace as linspace,
        numpy_log as log,
        numpy_log10 as log10,
        numpy_log2 as log2,
        numpy_matmul as matmul,
        numpy_matrixTranspose as matrixTranspose,
        numpy_maximum as maximum,
        numpy_meshgrid as meshgrid,
        numpy_minimum as minimum,
        numpy_moveaxis as moveaxis,
        numpy_multiply as multiply,
        numpy_nan as nan,
        numpy_ndim as ndim,
        numpy_negative as negative,
        numpy_notEqual as notEqual,
        numpy_ones as ones,
        numpy_permuteDims as permuteDims,
        numpy_pi as pi,
        numpy_ravel as ravel,
        numpy_reciprocal as reciprocal,
        numpy_reshape as reshape,
        numpy_scalar as scalar,
        numpy_shape as shape,
        numpy_sin as sin,
        numpy_size as size,
        numpy_square as square,
        numpy_sum as sum,
        numpy_tan as tan,
        numpy_transpose as transpose,
        numpy_trueDivide as trueDivide,
        numpy_trunc as trunc,
        numpy_vdot as vdot,
        numpy_vecdot as vecdot,
        numpy_where as where,
        numpy_zeros as zeros,
    };
}

/**
 * An indented block of text lines used to pretty-print expressions.
 *
 * `lines[i]` is printed with `indents[i]` leading spaces of indentation.
 */
declare class PPrint {
    readonly indents: number[];
    readonly lines: string[];
    constructor(indents: number[], lines: string[]);
    /** Shift every line right by `spaces` columns. */
    indent(spaces: number): PPrint;
    /** Join this block with one or more other blocks, one after another. */
    concat(...items: PPrint[]): PPrint;
    /** Place `other` to the right of this block, sharing one common line. */
    stack(other: PPrint): PPrint;
    /** Render the block as a single formatted string. */
    toString(): string;
    /** Build a block from any value that can be converted to a string. */
    static pp(s: Stringable): PPrint;
}
/** Any value exposing a `toString()` method. */
interface Stringable {
    toString(): string;
}

/** A variable in a Jaxpr expression, identified by `id` and typed by `aval`. */
declare class Var {
    #private;
    readonly id: number;
    readonly aval: ShapedArray;
    constructor(aval: ShapedArray);
    toString(): string;
}
/**
 * A literal constant in a Jaxpr expression, holding `value` with `dtype`.
 * Currently, only scalars are supported.
 */
declare class Lit {
    readonly dtype: DType;
    readonly value: number;
    readonly aval: ShapedArray;
    constructor(dtype: DType, value: number);
}
/** An operand of a Jaxpr equation: either a variable or a literal. */
type Atom = Var | Lit;
/** Maps each `Var` to a display name (kept in `names`) for pretty-printing. */
declare class VarPrinter {
    #private;
    names: Map<Var, string>;
    name(v: Var): string;
    nameType(v: Var): string;
}
/**
 * One statement (binding) of a Jaxpr, in ANF form: `primitive` is applied to
 * `inputs` with `params`, and its results are bound to `outBinders`.
 */
declare class JaxprEqn {
    readonly primitive: Primitive;
    readonly inputs: Atom[];
    readonly params: Record<string, any>;
    readonly outBinders: Var[];
    constructor(primitive: Primitive, inputs: Atom[], params: Record<string, any>, outBinders: Var[]);
    pprint(usedVars?: Set<Var>, vp?: VarPrinter): PPrint;
    toString(): string;
}
/**
 * Typed intermediate representation of a traced computation: input binders,
 * a sequence of equations, and output atoms.
 */
declare class Jaxpr implements FpHashable {
    #private;
    readonly inBinders: Var[];
    readonly eqns: JaxprEqn[];
    readonly outs: Atom[];
    constructor(inBinders: Var[], eqns: JaxprEqn[], outs: Atom[]);
    pprint(): PPrint;
    toString(): string;
    /**
     * Compute a hash of this Jaxpr.
     *
     * Variable identity does not enter the hash: two Jaxprs that perform the
     * same assignments and operators in the same order, but use different
     * variable IDs, hash identically (and print identically).
     */
    getHash(): bigint;
    hash(state: FpHash): void;
    /**
     * Return a simplified Jaxpr with basic optimizations applied:
     * - unused variables are trimmed away;
     * - `*1`, `*0` and `+0` operations against literals are folded away.
     */
    simplify(): Jaxpr;
    /** Inline nested JitCall equations, e.g. to handle jit-of-jit. */
    flatten(): Jaxpr;
}

/**
 * Rectified Linear Unit (ReLU) activation function:
 * `relu(x) = max(x, 0)`.
 */
declare function relu(x: ArrayLike): Array;
/**
 * Rectified Linear Unit 6 (ReLU6) activation function:
 * `relu6(x) = min(max(x, 0), 6)`.
 */
declare function relu6(x: ArrayLike): Array;
/**
 * Sigmoid activation function, computed element-wise:
 * `sigmoid(x) = 1 / (1 + exp(-x))`.
 *
 * Reference: https://en.wikipedia.org/wiki/Sigmoid_function
 */
declare function sigmoid(x: ArrayLike): Array;
/**
 * Softplus activation function:
 * `softplus(x) = log(1 + exp(x))`.
 *
 * Reference: https://en.wikipedia.org/wiki/Softplus
 */
declare function softplus(x: ArrayLike): Array;
/**
 * Soft-sign activation function, computed element-wise:
 * `softSign(x) = x / (|x| + 1)`.
 */
declare function softSign(x: ArrayLike): Array;
/**
 * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
 * Swish, computed element-wise:
 * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
 *
 * `swish()` is an alias of this function.
 *
 * Reference: https://en.wikipedia.org/wiki/Swish_function
 */
declare function silu(x: ArrayLike): Array;
/**
 * Alias of `silu()`: the Sigmoid-weighted Linear Unit, computed element-wise
 * as `swish(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
 *
 * Reference: https://en.wikipedia.org/wiki/Swish_function
 */
declare const swish: typeof silu;
/**
 * Log-sigmoid activation function, computed element-wise:
 * `logSigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
 */
declare function logSigmoid(x: ArrayLike): Array;
/** Identity activation function. Returns the argument unmodified. */
declare const identity: (x: ArrayLike) => Array;

// Local aliases re-exported below as the public `nn` namespace.
declare const nn_identity: typeof identity;
declare const nn_logSigmoid: typeof logSigmoid;
declare const nn_relu: typeof relu;
declare const nn_relu6: typeof relu6;
declare const nn_sigmoid: typeof sigmoid;
declare const nn_silu: typeof silu;
declare const nn_softSign: typeof softSign;
declare const nn_softplus: typeof softplus;
declare const nn_swish: typeof swish;
/** Neural-network activation functions (`jax.nn`). */
declare namespace nn {
    export {
        nn_identity as identity,
        nn_logSigmoid as logSigmoid,
        nn_relu as relu,
        nn_relu6 as relu6,
        nn_sigmoid as sigmoid,
        nn_silu as silu,
        nn_softSign as softSign,
        nn_softplus as softplus,
        nn_swish as swish,
    };
}

/**
 * Resolves to `F` when the parameter list of `F` is assignable to `T`, and to
 * `never` otherwise. Used to require that a transformed function only takes
 * (trees of) array-like arguments.
 *
 * The constraint uses a rest parameter (`...args`) so that functions of any
 * arity satisfy it; `(args: any[]) => any` would describe a function with a
 * single array parameter.
 */
type WithArgsSubtype<F extends (...args: any[]) => any, T> = Parameters<F> extends T ? F : never;
/** Compute the forward-mode Jacobian-vector product for a function. */
declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>, tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
/** Vectorize an operation on a batched axis for one or more inputs. */
declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
/** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
/** Construct a Jaxpr by dynamically tracing a function with example inputs. */
declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: Parameters<F>) => {
    jaxpr: Jaxpr;
    consts: Array[];
    treedef: JsTreeDef;
};
/**
 * Just-in-time compile a function. The returned function accepts the same
 * (array-like) arguments as `f` and returns the same result type.
 */
declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
/**
 * Produce a local linear approximation to a function at a point using jvp() and
 * partial evaluation.
 */
declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>];
/** Calculate the reverse-mode vector-Jacobian product for a function. */
declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
/**
 * Compute the gradient of a scalar-valued function `f` with respect to its
 * first argument.
 */
declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
/** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
declare const jacrev: typeof jacfwd;
/** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;

export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn, numpy, setDevice, tree, vjp, vmap };
|