@jax-js/jax 0.0.2 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.ts CHANGED
@@ -20,6 +20,7 @@ interface Stringable {
20
20
  //#endregion
21
21
  //#region src/shape.d.ts
22
22
 
23
+ /** @inline */
23
24
  type Pair = [number, number];
24
25
  /**
25
26
  * A multidimensional view into memory. An array can be thought of as the
@@ -46,6 +47,8 @@ declare class View {
46
47
  get size(): number;
47
48
  /** Whether this is a default, contiguous, unaltered view of the data (identity). */
48
49
  get contiguous(): boolean;
50
+ /** Return the range of data being indexed in this view, or [0, 0] if none. */
51
+ dataRange(): [number, number];
49
52
  /** Produce an AluExp for evaluating this view at an index. */
50
53
  toAluExp(idxs: AluExp[]): [AluExp, AluExp];
51
54
  /**
@@ -106,6 +109,17 @@ declare class ShapeTracker {
106
109
  reshape(newShape: number[]): ShapeTracker;
107
110
  /** Broadcast along the given new axes, then expand the shape. */
108
111
  broadcast(newShape: number[], axis: number[]): ShapeTracker;
112
+ /**
113
+ * Repeat data in each axis by a positive number of repetitions.
114
+ *
115
+ * - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
116
+ * - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
117
+ */
118
+ repeat(reps: number[], tile?: boolean): ShapeTracker;
119
+ /** Move axis i to axis j. */
120
+ moveaxis(i: number, j: number): ShapeTracker;
121
+ /** Like pad(), but allows for negative values. */
122
+ padOrShrink(arg: Pair[]): ShapeTracker;
109
123
  }
110
124
  //#endregion
111
125
  //#region src/utils.d.ts
@@ -130,13 +144,16 @@ declare class FpHash {
130
144
  /** Run a function while caching it inline inside a `Map`. */
131
145
  //#endregion
132
146
  //#region src/alu.d.ts
147
+ /** A numerical data type for array contents. */
133
148
  declare enum DType {
134
149
  Float32 = "float32",
135
150
  Int32 = "int32",
136
151
  Uint32 = "uint32",
137
152
  Bool = "bool",
138
- Complex64 = "complex64",
153
+ Float16 = "float16",
139
154
  }
155
+ /** @inline */
156
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
140
157
  /**
141
158
  * Mathematical expression on scalar values.
142
159
  *
@@ -162,6 +179,7 @@ declare class AluExp implements FpHashable {
162
179
  static cos(a: AluExp): AluExp;
163
180
  static exp(a: AluExp): AluExp;
164
181
  static log(a: AluExp): AluExp;
182
+ static sqrt(a: AluExp): AluExp;
165
183
  static reciprocal(a: AluExp): AluExp;
166
184
  static cast(dtype: DType, a: AluExp): AluExp;
167
185
  static bitcast(dtype: DType, a: AluExp): AluExp;
@@ -172,12 +190,13 @@ declare class AluExp implements FpHashable {
172
190
  static const(dtype: DType, value: any): AluExp;
173
191
  static special(dtype: DType, name: string, n: number): AluExp;
174
192
  static variable(dtype: DType, name: string): AluExp;
175
- static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
193
+ static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
176
194
  static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
195
+ static f32(value: number): AluExp;
177
196
  static i32(value: number): AluExp;
178
197
  static u32(value: number): AluExp;
179
- static f32(value: number): AluExp;
180
198
  static bool(value: boolean): AluExp;
199
+ static f16(value: number): AluExp;
181
200
  not(): AluExp;
182
201
  /** Compute a reasonable expression hash with low collision rate. */
183
202
  getHash(): bigint;
@@ -188,6 +207,19 @@ declare class AluExp implements FpHashable {
188
207
  reindexGids(gidMap: Map<number, number>): AluExp;
189
208
  get min(): number;
190
209
  get max(): number;
210
+ /** Largest known integer that divides self. */
211
+ constFactor(): number;
212
+ /**
213
+ * Checks if divisible by an integer v and returns the quotient if it is, or
214
+ * `null` if it's not divisible.
215
+ */
216
+ divides(v: number): AluExp | null;
217
+ /**
218
+ * Get all expressions by deeply matching an operation.
219
+ *
220
+ * For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
221
+ */
222
+ splitOp(sep: AluOp): IterableIterator<AluExp>;
191
223
  /**
192
224
  * Simplify the expression by replacing any known patterns and deduping
193
225
  * identical subexpressions.
@@ -208,10 +240,16 @@ declare class AluExp implements FpHashable {
208
240
  toString(): string;
209
241
  /** Generic fold() operation with a reducer over the expression tree. */
210
242
  fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
243
+ /** Check if any expression in the tree satisfies a predicate. */
244
+ some(predicate: (exp: AluExp) => boolean): boolean;
211
245
  /** Rewrite the expression recursively using a visitor. */
212
246
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
213
247
  /** Collect all nodes that satisfy a predicate. */
214
248
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
249
+ /** Produce the set of all distinct AluOp values in this expression. */
250
+ distinctOps(): Set<AluOp>;
251
+ /** Rewrite GlobalView operations to GlobalIndex operations. */
252
+ rewriteGlobalViews(): AluExp;
215
253
  }
216
254
  /** Symbolic form for each mathematical operation. */
217
255
  declare enum AluOp {
@@ -226,6 +264,7 @@ declare enum AluOp {
226
264
  Cos = "Cos",
227
265
  Exp = "Exp",
228
266
  Log = "Log",
267
+ Sqrt = "Sqrt",
229
268
  Reciprocal = "Reciprocal",
230
269
  Cast = "Cast",
231
270
  Bitcast = "Bitcast",
@@ -242,7 +281,7 @@ declare enum AluOp {
242
281
  Variable = "Variable",
243
282
  // arg = variable
244
283
  GlobalIndex = "GlobalIndex",
245
- // arg = gid; src = [bufidx]
284
+ // arg = [gid, len]; src = [bufidx]
246
285
  GlobalView = "GlobalView",
247
286
  }
248
287
  /**
@@ -297,12 +336,12 @@ declare class Reduction implements FpHashable {
297
336
  /** Size of the reduction axis. */
298
337
  readonly size: number;
299
338
  /** Follow-up expression defined with the "acc" variable, defaults to identity. */
300
- readonly fusion: AluExp;
339
+ readonly epilogue: AluExp;
301
340
  constructor(/** Data type of the values being reduced over. */
302
341
  dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
303
342
  op: AluOp, /** Size of the reduction axis. */
304
343
  size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
305
- fusion?: AluExp);
344
+ epilogue?: AluExp);
306
345
  hash(state: FpHash): void;
307
346
  toString(): string;
308
347
  /** Get the identity for this reduction operation. */
@@ -313,7 +352,7 @@ declare class Reduction implements FpHashable {
313
352
  /** Expression for accessing `indices` in input array with the given shape. */
314
353
  //#endregion
315
354
  //#region src/backend.d.ts
316
- type Device = "cpu" | "webgpu";
355
+ type Device = "cpu" | "wasm" | "webgpu";
317
356
  declare const devices: Device[];
318
357
  /** Set the default device backend (must be initialized). */
319
358
  declare function setDevice(device: Device): void;
@@ -336,7 +375,7 @@ interface Backend {
336
375
  /** Maximum number of arguments per dispatched kernel. */
337
376
  readonly maxArgs: number;
338
377
  /** Allocate a new slot with reference count 1. */
339
- malloc(size: number, initialData?: ArrayBuffer): Slot;
378
+ malloc(size: number, initialData?: Uint8Array): Slot;
340
379
  /** Increment the reference count of the slot. */
341
380
  incRef(slot: Slot): void;
342
381
  /**
@@ -345,9 +384,9 @@ interface Backend {
345
384
  */
346
385
  decRef(slot: Slot): void;
347
386
  /** Read a range of bytes from a buffer. */
348
- read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer>;
387
+ read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
349
388
  /** Read a range of bytes from a buffer, blocking variant. */
350
- readSync(slot: Slot, start?: number, count?: number): ArrayBuffer;
389
+ readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
351
390
  /** Prepare an expression to be executed later. */
352
391
  prepare(kernel: Kernel): Promise<Executable>;
353
392
  /** Prepare an expression to be executed later, blocking variant. */
@@ -369,18 +408,22 @@ declare class Executable<T = any> {
369
408
  data: T);
370
409
  }
371
410
  declare namespace tree_d_exports {
372
- export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
411
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
373
412
  }
374
413
  declare enum NodeType {
375
414
  Array = "Array",
376
415
  Object = "Object",
377
416
  Leaf = "Leaf",
378
417
  }
418
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
379
419
  type JsTree<T> = T | JsTree<T>[] | {
380
420
  [key: string]: JsTree<T>;
381
421
  };
382
- type MapJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
383
- /** Analog to the JAX "pytree" object, but for JavaScript. */
422
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
423
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
424
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
425
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
426
+ /** Represents the structure of a JsTree. */
384
427
  declare class JsTreeDef {
385
428
  readonly nodeType: NodeType;
386
429
  readonly nodeMetadata: any;
@@ -406,6 +449,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
406
449
  declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
407
450
  /** Take a reference of every array in a tree. */
408
451
  declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
452
+ /** Dispose every array in a tree. */
453
+ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
454
+ //#endregion
455
+ //#region src/frontend/convolution.d.ts
456
+ /** Definition of a general dilated convolution. Should be valid on creation. */
457
+ interface ConvParams {
458
+ strides: number[];
459
+ padding: [number, number][];
460
+ lhsDilation: number[];
461
+ rhsDilation: number[];
462
+ }
463
+ /**
464
+ * Check that the shapes and parameters passed to convolution are valid.
465
+ *
466
+ * If the check succeeds, returns the output shape.
467
+ */
468
+
409
469
  //#endregion
410
470
  //#region src/frontend/core.d.ts
411
471
  /**
@@ -433,10 +493,16 @@ declare enum Primitive {
433
493
  Cos = "cos",
434
494
  Exp = "exp",
435
495
  Log = "log",
496
+ Sqrt = "sqrt",
436
497
  Min = "min",
437
498
  Max = "max",
438
499
  Reduce = "reduce",
439
500
  Dot = "dot",
501
+ // sum(x*y, axis=-1)
502
+ Conv = "conv",
503
+ // see lax.conv_general_dilated
504
+ Pool = "pool",
505
+ PoolTranspose = "pool_transpose",
440
506
  Compare = "compare",
441
507
  Where = "where",
442
508
  Transpose = "transpose",
@@ -448,7 +514,7 @@ declare enum Primitive {
448
514
  Gather = "gather",
449
515
  JitCall = "jit_call",
450
516
  }
451
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>> {
517
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
452
518
  [Primitive.Cast]: {
453
519
  dtype: DType;
454
520
  };
@@ -459,6 +525,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
459
525
  op: AluOp;
460
526
  axis: number[];
461
527
  };
528
+ [Primitive.Conv]: ConvParams;
529
+ [Primitive.Pool]: {
530
+ window: number[];
531
+ strides: number[];
532
+ };
533
+ [Primitive.PoolTranspose]: {
534
+ inShape: number[];
535
+ window: number[];
536
+ strides: number[];
537
+ };
462
538
  [Primitive.Compare]: {
463
539
  op: CompareOp;
464
540
  };
@@ -480,10 +556,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
480
556
  axis: number[];
481
557
  };
482
558
  [Primitive.Shrink]: {
483
- slice: [number, number][];
559
+ slice: Pair[];
484
560
  };
485
561
  [Primitive.Pad]: {
486
- width: [number, number][];
562
+ width: Pair[];
487
563
  };
488
564
  [Primitive.Gather]: {
489
565
  axis: number[];
@@ -584,6 +660,7 @@ declare abstract class Tracer {
584
660
  */
585
661
  abstract dispose(): void;
586
662
  get shape(): number[];
663
+ get size(): number;
587
664
  get dtype(): DType;
588
665
  get ndim(): number;
589
666
  /** @ignore */
@@ -670,7 +747,7 @@ declare abstract class Tracer {
670
747
  * the "gather" primitive, and it allows you to access specific elements of
671
748
  * the array by integer indices stored in another array.
672
749
  */
673
- slice(...index: (number | [] | [number] | [number, number] | null | Tracer)[]): this;
750
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
674
751
  }
675
752
  declare class ShapedArray implements AbstractValue {
676
753
  readonly shape: number[];
@@ -678,7 +755,7 @@ declare class ShapedArray implements AbstractValue {
678
755
  constructor(shape: number[], dtype: DType);
679
756
  static fromAval(aval: AbstractValue): ShapedArray;
680
757
  get ndim(): number;
681
- strShort(): string;
758
+ toString(): string;
682
759
  equals(other: ShapedArray): boolean;
683
760
  }
684
761
  //#endregion
@@ -749,14 +826,23 @@ declare class Array extends Tracer {
749
826
  */
750
827
  [Symbol.toPrimitive](): any;
751
828
  /** Realize the array and return it as data. */
752
- data(): Promise<Float32Array | Int32Array | Uint32Array>;
753
- /** Wait for this array to be placed on the backend, if needed. */
754
- wait(): Promise<void>;
829
+ data(): Promise<DataArray>;
830
+ /**
831
+ * Wait for this array to finish evaluation.
832
+ *
833
+ * Operations and data loading in jax-js are lazy, so this function ensures
834
+ * that pending operations are dispatched and fully executed before it
835
+ * returns.
836
+ *
837
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
838
+ * dispatch of operations.
839
+ */
840
+ wait(): Promise<Array>;
755
841
  /**
756
842
  * Realize the array and return it as data. This is a sync variant and not
757
843
  * recommended for performance reasons, as it will block rendering.
758
844
  */
759
- dataSync(): Float32Array | Int32Array | Uint32Array;
845
+ dataSync(): DataArray;
760
846
  /**
761
847
  * Convert this array into a JavaScript object.
762
848
  *
@@ -769,6 +855,13 @@ declare class Array extends Tracer {
769
855
  js(): any;
770
856
  /** Convert this array into a JavaScript object, asynchronously. */
771
857
  jsAsync(): Promise<any>;
858
+ /**
859
+ * Copy an element of an array to a numeric scalar and return it.
860
+ *
861
+ * Throws an error if the array does not have a single element. The array must
862
+ * either be rank-0 or have every dimension of its shape equal to 1.
863
+ */
864
+ item(): number;
772
865
  /** @private Internal plumbing method for Array / Tracer ops. */
773
866
  static _implRules(): typeof implRules;
774
867
  _realizeSource(): number;
@@ -779,7 +872,7 @@ declare function scalar(value: number | boolean, {
779
872
  device
780
873
  }?: DTypeAndDevice): Array;
781
874
  /** Constructor for creating a new array from data. */
782
- declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
875
+ declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
783
876
  shape,
784
877
  dtype,
785
878
  device
@@ -851,15 +944,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
851
944
  dtype,
852
945
  device
853
946
  }?: DTypeAndDevice): Array;
854
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
855
947
  declare namespace numpy_d_exports {
856
- export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, complex64, concatenate, cos, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float32, full, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, size, square, stack, sum, tan, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros };
948
+ export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
857
949
  }
858
950
  declare const float32 = DType.Float32;
859
951
  declare const int32 = DType.Int32;
860
952
  declare const uint32 = DType.Uint32;
861
953
  declare const bool = DType.Bool;
862
- declare const complex64 = DType.Complex64;
954
+ declare const float16 = DType.Float16;
863
955
  /** Euler's constant, `e = 2.7182818284590...` */
864
956
  declare const e: number;
865
957
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -886,6 +978,8 @@ declare const cos: (x: ArrayLike) => Array;
886
978
  declare const exp: (x: ArrayLike) => Array;
887
979
  /** Calculate the natural logarithm of all elements in the input array. */
888
980
  declare const log: (x: ArrayLike) => Array;
981
+ /** Calculate the square root of all elements in the input array. */
982
+ declare const sqrt: (x: ArrayLike) => Array;
889
983
  /** Return element-wise minimum of the input arrays. */
890
984
  declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
891
985
  /** Return element-wise maximum of the input arrays. */
@@ -927,6 +1021,12 @@ declare const pad: (x: ArrayLike, width: number | [number, number] | [number, nu
927
1021
  declare const ndim: (x: ArrayLike) => number;
928
1022
  /** Return the shape of an array. Does not consume array reference. */
929
1023
  declare const shape$1: (x: ArrayLike) => number[];
1024
+ /** Return an array of zeros with the same shape and type as a given array. */
1025
+ declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1026
+ /** Return an array of ones with the same shape and type as a given array. */
1027
+ declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1028
+ /** Return a full array with the same shape and type as a given array. */
1029
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
930
1030
  /**
931
1031
  * Return the number of elements in an array, optionally along an axis.
932
1032
  * Does not consume array reference.
@@ -1010,9 +1110,11 @@ declare function ravel(a: ArrayLike): Array;
1010
1110
  * Return specified diagonals.
1011
1111
  *
1012
1112
  * If a is 2D, return the diagonal of the array with the given offset. If a is
1013
- * 3D or higher, compute diagonals along the two given axes.
1113
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1014
1114
  *
1015
- * This returns a view over the existing array.
1115
+ * This returns a view over the existing array. The shape of the resulting array
1116
+ * is determined by removing the two axes along which the diagonal is taken,
1117
+ * then appending a new axis to the right holding the diagonals.
1016
1118
  */
1017
1119
  declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1018
1120
  /**
@@ -1084,6 +1186,24 @@ declare function exp2(p: ArrayLike): Array;
1084
1186
  declare function log2(x: ArrayLike): Array;
1085
1187
  /** Return the base-10 logarithm of x, element-wise. */
1086
1188
  declare function log10(x: ArrayLike): Array;
1189
+ /**
1190
+ * Calculate element-wise hyperbolic sine of input.
1191
+ *
1192
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
1193
+ */
1194
+ declare function sinh(x: ArrayLike): Array;
1195
+ /**
1196
+ * Calculate element-wise hyperbolic cosine of input.
1197
+ *
1198
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
1199
+ */
1200
+ declare function cosh(x: ArrayLike): Array;
1201
+ /**
1202
+ * Calculate element-wise hyperbolic tangent of input.
1203
+ *
1204
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1205
+ */
1206
+ declare function tanh(x: ArrayLike): Array;
1087
1207
  //#endregion
1088
1208
  //#region src/frontend/jaxpr.d.ts
1089
1209
  /** Variable in a Jaxpr expression. */
@@ -1146,8 +1266,38 @@ declare class Jaxpr implements FpHashable {
1146
1266
  /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1147
1267
  flatten(): Jaxpr;
1148
1268
  }
1269
+ /** @inline */
1270
+ type JitOpts = {
1271
+ staticArgnums?: number[];
1272
+ device?: Device;
1273
+ };
1274
+ declare namespace lax_d_exports {
1275
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1276
+ }
1277
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1278
+ /**
1279
+ * General n-dimensional convolution operator, with optional dilation.
1280
+ *
1281
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1282
+ * function in JAX, which wraps XLA's general convolution operator.
1283
+ *
1284
+ * Grouped convolutions are not supported right now.
1285
+ */
1286
+ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1287
+ lhsDilation,
1288
+ rhsDilation
1289
+ }?: {
1290
+ lhsDilation?: number[];
1291
+ rhsDilation?: number[];
1292
+ }): Array;
1293
+ /** Convenience wrapper around `convGeneralDilated`. */
1294
+ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
1295
+ /** Convenience wrapper around `convGeneralDilated`. */
1296
+ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1297
+ /** Reduce a computation over padded windows. */
1298
+ declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1149
1299
  declare namespace nn_d_exports {
1150
- export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1300
+ export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1151
1301
  }
1152
1302
  /**
1153
1303
  * Rectified Linear Unit (ReLU) activation function:
@@ -1187,7 +1337,7 @@ declare function softSign(x: ArrayLike): Array;
1187
1337
  *
1188
1338
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1189
1339
  */
1190
- declare function silu(x: ArrayLike): Array;
1340
+ declare const silu: (x: ArrayLike) => Array;
1191
1341
  /**
1192
1342
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1193
1343
  * Swish, computed element-wise:
@@ -1197,7 +1347,7 @@ declare function silu(x: ArrayLike): Array;
1197
1347
  *
1198
1348
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1199
1349
  */
1200
- declare const swish: typeof silu;
1350
+ declare const swish: (x: ArrayLike) => Array;
1201
1351
  /**
1202
1352
  * Log-sigmoid activation function, computed element-wise:
1203
1353
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
@@ -1205,6 +1355,48 @@ declare const swish: typeof silu;
1205
1355
  declare function logSigmoid(x: ArrayLike): Array;
1206
1356
  /** Identity activation function. Returns the argument unmodified. */
1207
1357
  declare const identity: (x: ArrayLike) => Array;
1358
+ /** Leaky rectified linear unit (Leaky ReLU) activation function. */
1359
+ declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
1360
+ /**
1361
+ * Exponential linear unit activation function.
1362
+ *
1363
+ * Computes the element-wise function:
1364
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
1365
+ */
1366
+ declare function elu(x: ArrayLike, alpha?: number): Array;
1367
+ /**
1368
+ * Continuously-differentiable exponential linear unit activation function.
1369
+ *
1370
+ * Computes the element-wise function:
1371
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1372
+ */
1373
+ declare function celu(x: ArrayLike, alpha?: number): Array;
1374
+ /**
1375
+ * Gaussian error linear unit (GELU) activation function.
1376
+ *
1377
+ * This is computed element-wise. Currently jax-js does not support the erf() or
1378
+ * gelu() functions exactly as primitives, so an approximation is used:
1379
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1380
+ *
1381
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1382
+ *
1383
+ * This will be improved in the future.
1384
+ */
1385
+ declare const gelu: (x: ArrayLike) => Array;
1386
+ /**
1387
+ * Gated linear unit (GLU) activation function.
1388
+ *
1389
+ * Splits the `axis` dimension of the input into two halves, a and b, then
1390
+ * computes `a * sigmoid(b)`.
1391
+ */
1392
+ declare function glu(x: ArrayLike, axis?: number): Array;
1393
+ /**
1394
+ * Mish activation function.
1395
+ *
1396
+ * Computes the element-wise function:
1397
+ * `mish(x) = x * tanh(softplus(x))`
1398
+ */
1399
+ declare function mish(x: ArrayLike): Array;
1208
1400
  /**
1209
1401
  * Softmax function. Computes the function which rescales elements to the range
1210
1402
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -1268,38 +1460,49 @@ declare function uniform(key: Array, shape?: number[], {
1268
1460
  }): Array;
1269
1461
  //#endregion
1270
1462
  //#region src/index.d.ts
1271
- /** @inline */
1272
- type WithArgsSubtype<F extends (args: any[]) => any, T> = Parameters<F> extends T ? F : never;
1273
1463
  /** Compute the forward-mode Jacobian-vector product for a function. */
1274
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>, tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1464
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1275
1465
  /** Vectorize an operation on a batched axis for one or more inputs. */
1276
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1466
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1277
1467
  /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
1278
- declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1468
+ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1279
1469
  /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
1280
- declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: Parameters<F>) => {
1470
+ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
1281
1471
  jaxpr: Jaxpr;
1282
1472
  consts: Array[];
1283
1473
  treedef: JsTreeDef;
1284
1474
  };
1285
- declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1475
+ /**
1476
+ * Mark a function for automatic JIT compilation, with operator fusion.
1477
+ *
1478
+ * The function will be compiled the first time it is called with a set of
1479
+ * argument shapes.
1480
+ *
1481
+ * **Options:**
1482
+ * - `staticArgnums`: An array of argument indices to treat as static
1483
+ * (compile-time constant). These arguments must be hashable, won't be traced,
1484
+ * and different values will trigger recompilation.
1485
+ * - `device`: The device to place the computation on. If not specified, the
1486
+ * computation will be placed on the default device.
1487
+ */
1488
+ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1286
1489
  /**
1287
1490
  * Produce a local linear approximation to a function at a point using jvp() and
1288
1491
  * partial evaluation.
1289
1492
  */
1290
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>];
1493
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
1291
1494
  /** Calculate the reverse-mode vector-Jacobian product for a function. */
1292
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1495
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1293
1496
  /**
1294
1497
  * Compute the gradient of a scalar-valued function `f` with respect to its
1295
1498
  * first argument.
1296
1499
  */
1297
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1500
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1298
1501
  /** Create a function that evaluates both `f` and the gradient of `f`. */
1299
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1502
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1300
1503
  /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
1301
1504
  declare const jacrev: typeof jacfwd;
1302
1505
  /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
1303
- declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1506
+ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1304
1507
  //#endregion
1305
- export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1508
+ export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };