@jax-js/jax 0.0.2 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.d.cts CHANGED
@@ -23,6 +23,7 @@ interface Stringable {
23
23
  }
24
24
  //#endregion
25
25
  //#region src/shape.d.ts
26
+ /** @inline */
26
27
  type Pair = [number, number];
27
28
  /**
28
29
  * A multidimensional view into memory. An array can be thought of as the
@@ -49,6 +50,8 @@ declare class View {
49
50
  get size(): number;
50
51
  /** Whether this is a default, contiguous, unaltered view of the data (identity). */
51
52
  get contiguous(): boolean;
53
+ /** Return the range of data being indexed in this view, or [0, 0] if none. */
54
+ dataRange(): [number, number];
52
55
  /** Produce an AluExp for evaluating this view at an index. */
53
56
  toAluExp(idxs: AluExp[]): [AluExp, AluExp];
54
57
  /**
@@ -109,6 +112,17 @@ declare class ShapeTracker {
109
112
  reshape(newShape: number[]): ShapeTracker;
110
113
  /** Broadcast along the given new axes, then expand the shape. */
111
114
  broadcast(newShape: number[], axis: number[]): ShapeTracker;
115
+ /**
116
+ * Repeat data in each axis by a positive number of repetitions.
117
+ *
118
+ * - If `tile` is true (default): [1, 2, 3] -> [1, 2, 3, 1, 2, 3].
119
+ * - If `tile` is false: [1, 2, 3] -> [1, 1, 2, 2, 3, 3].
120
+ */
121
+ repeat(reps: number[], tile?: boolean): ShapeTracker;
122
+ /** Move axis i to axis j. */
123
+ moveaxis(i: number, j: number): ShapeTracker;
124
+ /** Like pad(), but allows for negative values. */
125
+ padOrShrink(arg: Pair[]): ShapeTracker;
112
126
  }
113
127
  //#endregion
114
128
  //#region src/utils.d.ts
@@ -133,13 +147,16 @@ declare class FpHash {
133
147
  /** Run a function while caching it inline inside a `Map`. */
134
148
  //#endregion
135
149
  //#region src/alu.d.ts
150
+ /** A numerical data type for array contents. */
136
151
  declare enum DType {
137
152
  Float32 = "float32",
138
153
  Int32 = "int32",
139
154
  Uint32 = "uint32",
140
155
  Bool = "bool",
141
- Complex64 = "complex64",
156
+ Float16 = "float16",
142
157
  }
158
+ /** @inline */
159
+ type DataArray = Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | Float16Array<ArrayBuffer>;
143
160
  /**
144
161
  * Mathematical expression on scalar values.
145
162
  *
@@ -165,6 +182,7 @@ declare class AluExp implements FpHashable {
165
182
  static cos(a: AluExp): AluExp;
166
183
  static exp(a: AluExp): AluExp;
167
184
  static log(a: AluExp): AluExp;
185
+ static sqrt(a: AluExp): AluExp;
168
186
  static reciprocal(a: AluExp): AluExp;
169
187
  static cast(dtype: DType, a: AluExp): AluExp;
170
188
  static bitcast(dtype: DType, a: AluExp): AluExp;
@@ -175,12 +193,13 @@ declare class AluExp implements FpHashable {
175
193
  static const(dtype: DType, value: any): AluExp;
176
194
  static special(dtype: DType, name: string, n: number): AluExp;
177
195
  static variable(dtype: DType, name: string): AluExp;
178
- static globalIndex(dtype: DType, gid: number, bufidx: AluExp): AluExp;
196
+ static globalIndex(dtype: DType, gid: number, len: number, bufidx: AluExp): AluExp;
179
197
  static globalView(dtype: DType, gid: number, st: ShapeTracker, indices: AluExp[]): AluExp;
198
+ static f32(value: number): AluExp;
180
199
  static i32(value: number): AluExp;
181
200
  static u32(value: number): AluExp;
182
- static f32(value: number): AluExp;
183
201
  static bool(value: boolean): AluExp;
202
+ static f16(value: number): AluExp;
184
203
  not(): AluExp;
185
204
  /** Compute a reasonable expression hash with low collision rate. */
186
205
  getHash(): bigint;
@@ -191,6 +210,19 @@ declare class AluExp implements FpHashable {
191
210
  reindexGids(gidMap: Map<number, number>): AluExp;
192
211
  get min(): number;
193
212
  get max(): number;
213
+ /** Largest known integer that divides self. */
214
+ constFactor(): number;
215
+ /**
216
+ * Checks if divisible by an integer v and returns the quotient if it is, or
217
+ * `null` if it's not divisible.
218
+ */
219
+ divides(v: number): AluExp | null;
220
+ /**
221
+ * Get all expressions by deeply matching an operation.
222
+ *
223
+ * For example: `((2+(3*5))+4).splitOp(+) -> [2,(3*5),4]`.
224
+ */
225
+ splitOp(sep: AluOp): IterableIterator<AluExp>;
194
226
  /**
195
227
  * Simplify the expression by replacing any known patterns and deduping
196
228
  * identical subexpressions.
@@ -211,10 +243,16 @@ declare class AluExp implements FpHashable {
211
243
  toString(): string;
212
244
  /** Generic fold() operation with a reducer over the expression tree. */
213
245
  fold<T = void>(reducer: (exp: AluExp, mappedSrc: T[]) => T): T;
246
+ /** Check if any expression in the tree satisfies a predicate. */
247
+ some(predicate: (exp: AluExp) => boolean): boolean;
214
248
  /** Rewrite the expression recursively using a visitor. */
215
249
  rewrite(visitor: (exp: AluExp) => AluExp | undefined | null): AluExp;
216
250
  /** Collect all nodes that satisfy a predicate. */
217
251
  collect(predicate: (exp: AluExp) => boolean): AluExp[];
252
+ /** Produce a list of all distinct AluOp in this expression. */
253
+ distinctOps(): Set<AluOp>;
254
+ /** Rewrite GlobalView operations to GlobalIndex operations. */
255
+ rewriteGlobalViews(): AluExp;
218
256
  }
219
257
  /** Symbolic form for each mathematical operation. */
220
258
  declare enum AluOp {
@@ -229,6 +267,7 @@ declare enum AluOp {
229
267
  Cos = "Cos",
230
268
  Exp = "Exp",
231
269
  Log = "Log",
270
+ Sqrt = "Sqrt",
232
271
  Reciprocal = "Reciprocal",
233
272
  Cast = "Cast",
234
273
  Bitcast = "Bitcast",
@@ -245,7 +284,7 @@ declare enum AluOp {
245
284
  Variable = "Variable",
246
285
  // arg = variable
247
286
  GlobalIndex = "GlobalIndex",
248
- // arg = gid; src = [bufidx]
287
+ // arg = [gid, len]; src = [bufidx]
249
288
  GlobalView = "GlobalView",
250
289
  }
251
290
  /**
@@ -300,12 +339,12 @@ declare class Reduction implements FpHashable {
300
339
  /** Size of the reduction axis. */
301
340
  readonly size: number;
302
341
  /** Follow-up expression defined with the "acc" variable, defaults to identity. */
303
- readonly fusion: AluExp;
342
+ readonly epilogue: AluExp;
304
343
  constructor(/** Data type of the values being reduced over. */
305
344
  dtype: DType, /** Operation to perform. Only ops in `AluGroup.Reduce` are supported. */
306
345
  op: AluOp, /** Size of the reduction axis. */
307
346
  size: number, /** Follow-up expression defined with the "acc" variable, defaults to identity. */
308
- fusion?: AluExp);
347
+ epilogue?: AluExp);
309
348
  hash(state: FpHash): void;
310
349
  toString(): string;
311
350
  /** Get the identity for this reduction operation. */
@@ -316,7 +355,7 @@ declare class Reduction implements FpHashable {
316
355
  /** Expression for accessing `indices` in input array with the given shape. */
317
356
  //#endregion
318
357
  //#region src/backend.d.ts
319
- type Device = "cpu" | "webgpu";
358
+ type Device = "cpu" | "wasm" | "webgpu";
320
359
  declare const devices: Device[];
321
360
  /** Set the default device backend (must be initialized). */
322
361
  declare function setDevice(device: Device): void;
@@ -339,7 +378,7 @@ interface Backend {
339
378
  /** Maximum number of arguments per dispatched kernel. */
340
379
  readonly maxArgs: number;
341
380
  /** Allocate a new slot with reference count 1. */
342
- malloc(size: number, initialData?: ArrayBuffer): Slot;
381
+ malloc(size: number, initialData?: Uint8Array): Slot;
343
382
  /** Increment the reference count of the slot. */
344
383
  incRef(slot: Slot): void;
345
384
  /**
@@ -348,9 +387,9 @@ interface Backend {
348
387
  */
349
388
  decRef(slot: Slot): void;
350
389
  /** Read a range of bytes from a buffer. */
351
- read(slot: Slot, start?: number, count?: number): Promise<ArrayBuffer>;
390
+ read(slot: Slot, start?: number, count?: number): Promise<Uint8Array<ArrayBuffer>>;
352
391
  /** Read a range of bytes from a buffer, blocking variant. */
353
- readSync(slot: Slot, start?: number, count?: number): ArrayBuffer;
392
+ readSync(slot: Slot, start?: number, count?: number): Uint8Array<ArrayBuffer>;
354
393
  /** Prepare an expression to be executed later. */
355
394
  prepare(kernel: Kernel): Promise<Executable>;
356
395
  /** Prepare an expression to be executed later, blocking variant. */
@@ -372,18 +411,22 @@ declare class Executable<T = any> {
372
411
  data: T);
373
412
  }
374
413
  declare namespace tree_d_exports {
375
- export { JsTree, JsTreeDef, MapJsTree, NodeType, flatten, leaves, map, ref, structure, unflatten };
414
+ export { JsTree, JsTreeDef, MapJsTree, NodeType, dispose, flatten, leaves, map, ref, structure, unflatten };
376
415
  }
377
416
  declare enum NodeType {
378
417
  Array = "Array",
379
418
  Object = "Object",
380
419
  Leaf = "Leaf",
381
420
  }
421
+ /** Analog to the JAX "pytree" object, but for JavaScript. */
382
422
  type JsTree<T> = T | JsTree<T>[] | {
383
423
  [key: string]: JsTree<T>;
384
424
  };
385
- type MapJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
386
- /** Analog to the JAX "pytree" object, but for JavaScript. */
425
+ type Same<X, Y> = (<T>() => T extends X ? 1 : 2) extends (<T>() => T extends Y ? 1 : 2) ? true : false;
426
+ type MappedJsTree<T, A, B> = T extends A ? B : T extends globalThis.Array<infer U> ? number extends T["length"] ? MapJsTree<U, A, B>[] : { [K in keyof T]: MapJsTree<T[K], A, B> } : { [K in keyof T]: MapJsTree<T[K], A, B> };
427
+ /** @ignore Convert a subtype of JsTree<A> into a JsTree<B>, with the same structure. */
428
+ type MapJsTree<T, A, B> = Same<A, B> extends true ? T : MappedJsTree<T, A, B>;
429
+ /** Represents the structure of a JsTree. */
387
430
  declare class JsTreeDef {
388
431
  readonly nodeType: NodeType;
389
432
  readonly nodeMetadata: any;
@@ -409,6 +452,23 @@ declare function unflatten<T>(treedef: JsTreeDef, leaves: Iterable<T>): JsTree<T
409
452
  declare function map<T, U, Tree extends JsTree<T>>(fn: (...args: T[]) => U, tree: Tree, ...rest: Tree[]): MapJsTree<Tree, T, U>;
410
453
  /** Take a reference of every array in a tree. */
411
454
  declare function ref<Tree extends JsTree<Array>>(tree: Tree): Tree;
455
+ /** Dispose every array in a tree. */
456
+ declare function dispose<Tree extends JsTree<Array>>(tree: Tree | null | undefined): void;
457
+ //#endregion
458
+ //#region src/frontend/convolution.d.ts
459
+ /** Definition of a general dilated convolution. Should be valid on creation. */
460
+ interface ConvParams {
461
+ strides: number[];
462
+ padding: [number, number][];
463
+ lhsDilation: number[];
464
+ rhsDilation: number[];
465
+ }
466
+ /**
467
+ * Check that the shapes and parameters passed to convolution are valid.
468
+ *
469
+ * If the check succeeds, returns the output shape.
470
+ */
471
+
412
472
  //#endregion
413
473
  //#region src/frontend/core.d.ts
414
474
  /**
@@ -436,10 +496,16 @@ declare enum Primitive {
436
496
  Cos = "cos",
437
497
  Exp = "exp",
438
498
  Log = "log",
499
+ Sqrt = "sqrt",
439
500
  Min = "min",
440
501
  Max = "max",
441
502
  Reduce = "reduce",
442
503
  Dot = "dot",
504
+ // sum(x*y, axis=-1)
505
+ Conv = "conv",
506
+ // see lax.conv_general_dilated
507
+ Pool = "pool",
508
+ PoolTranspose = "pool_transpose",
443
509
  Compare = "compare",
444
510
  Where = "where",
445
511
  Transpose = "transpose",
@@ -451,7 +517,7 @@ declare enum Primitive {
451
517
  Gather = "gather",
452
518
  JitCall = "jit_call",
453
519
  }
454
- interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>> {
520
+ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, any>> {
455
521
  [Primitive.Cast]: {
456
522
  dtype: DType;
457
523
  };
@@ -462,6 +528,16 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
462
528
  op: AluOp;
463
529
  axis: number[];
464
530
  };
531
+ [Primitive.Conv]: ConvParams;
532
+ [Primitive.Pool]: {
533
+ window: number[];
534
+ strides: number[];
535
+ };
536
+ [Primitive.PoolTranspose]: {
537
+ inShape: number[];
538
+ window: number[];
539
+ strides: number[];
540
+ };
465
541
  [Primitive.Compare]: {
466
542
  op: CompareOp;
467
543
  };
@@ -483,10 +559,10 @@ interface PrimitiveParamsImpl extends Record<Primitive, Record<string, unknown>>
483
559
  axis: number[];
484
560
  };
485
561
  [Primitive.Shrink]: {
486
- slice: [number, number][];
562
+ slice: Pair[];
487
563
  };
488
564
  [Primitive.Pad]: {
489
- width: [number, number][];
565
+ width: Pair[];
490
566
  };
491
567
  [Primitive.Gather]: {
492
568
  axis: number[];
@@ -587,6 +663,7 @@ declare abstract class Tracer {
587
663
  */
588
664
  abstract dispose(): void;
589
665
  get shape(): number[];
666
+ get size(): number;
590
667
  get dtype(): DType;
591
668
  get ndim(): number;
592
669
  /** @ignore */
@@ -673,7 +750,7 @@ declare abstract class Tracer {
673
750
  * the "gather" primitive, and it allows you to access specific elements of
674
751
  * the array by integer indices stored in another array.
675
752
  */
676
- slice(...index: (number | [] | [number] | [number, number] | null | Tracer)[]): this;
753
+ slice(...index: (number | [] | [number] | Pair | null | Tracer)[]): this;
677
754
  }
678
755
  declare class ShapedArray implements AbstractValue {
679
756
  readonly shape: number[];
@@ -681,7 +758,7 @@ declare class ShapedArray implements AbstractValue {
681
758
  constructor(shape: number[], dtype: DType);
682
759
  static fromAval(aval: AbstractValue): ShapedArray;
683
760
  get ndim(): number;
684
- strShort(): string;
761
+ toString(): string;
685
762
  equals(other: ShapedArray): boolean;
686
763
  }
687
764
  //#endregion
@@ -752,14 +829,23 @@ declare class Array extends Tracer {
752
829
  */
753
830
  [Symbol.toPrimitive](): any;
754
831
  /** Realize the array and return it as data. */
755
- data(): Promise<Float32Array | Int32Array | Uint32Array>;
756
- /** Wait for this array to be placed on the backend, if needed. */
757
- wait(): Promise<void>;
832
+ data(): Promise<DataArray>;
833
+ /**
834
+ * Wait for this array to finish evaluation.
835
+ *
836
+ * Operations and data loading in jax-js are lazy, so this function ensures
837
+ * that pending operations are dispatched and fully executed before it
838
+ * returns.
839
+ *
840
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
841
+ * dispatch of operations as well.
842
+ */
843
+ wait(): Promise<Array>;
758
844
  /**
759
845
  * Realize the array and return it as data. This is a sync variant and not
760
846
  * recommended for performance reasons, as it will block rendering.
761
847
  */
762
- dataSync(): Float32Array | Int32Array | Uint32Array;
848
+ dataSync(): DataArray;
763
849
  /**
764
850
  * Convert this array into a JavaScript object.
765
851
  *
@@ -772,6 +858,13 @@ declare class Array extends Tracer {
772
858
  js(): any;
773
859
  /** Convert this array into a JavaScript object, asynchronously. */
774
860
  jsAsync(): Promise<any>;
861
+ /**
862
+ * Copy an element of an array to a numeric scalar and return it.
863
+ *
864
+ * Throws an error if the array does not have a single element. The array must
865
+ * either be rank-0, or all dimensions of the shape are 1.
866
+ */
867
+ item(): number;
775
868
  /** @private Internal plumbing method for Array / Tracer ops. */
776
869
  static _implRules(): typeof implRules;
777
870
  _realizeSource(): number;
@@ -782,7 +875,7 @@ declare function scalar(value: number | boolean, {
782
875
  device
783
876
  }?: DTypeAndDevice): Array;
784
877
  /** Constructor for creating a new array from data. */
785
- declare function array(values: Array | Float32Array | Int32Array | RecursiveArray<number> | RecursiveArray<boolean>, {
878
+ declare function array(values: Array | Float16Array<ArrayBuffer> | Float32Array<ArrayBuffer> | Int32Array<ArrayBuffer> | Uint32Array<ArrayBuffer> | RecursiveArray<number> | RecursiveArray<boolean>, {
786
879
  shape,
787
880
  dtype,
788
881
  device
@@ -854,15 +947,14 @@ declare function linspace(start: number, stop: number, num?: number, endpoint?:
854
947
  dtype,
855
948
  device
856
949
  }?: DTypeAndDevice): Array;
857
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
858
950
  declare namespace numpy_d_exports {
859
- export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, complex64, concatenate, cos, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float32, full, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, size, square, stack, sum, tan, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros };
951
+ export { Array, ArrayLike, DType, abs, absolute, add, allclose, arange, argmax, argmin, array, astype, bool, clip, columnStack, concatenate, cos, cosh, diag, diagonal, divide, dot, dstack, e, equal, eulerGamma, exp, exp2, eye, flip, fliplr, flipud, float16, float32, full, fullLike, greater, greaterEqual, hstack, identity$1 as identity, inf, int32, less, lessEqual, linspace, log, log10, log2, matmul, max, maximum, mean, meshgrid, min, minimum, moveaxis, multiply, nan, ndim, negative, notEqual, ones, onesLike, pad, permuteDims, pi, prod, ravel, reciprocal, reshape, scalar, shape$1 as shape, sin, sinh, size, sqrt, square, stack, sum, tan, tanh, transpose, trueDivide, trunc, uint32, vdot, vecdot, vstack, where, zeros, zerosLike };
860
952
  }
861
953
  declare const float32 = DType.Float32;
862
954
  declare const int32 = DType.Int32;
863
955
  declare const uint32 = DType.Uint32;
864
956
  declare const bool = DType.Bool;
865
- declare const complex64 = DType.Complex64;
957
+ declare const float16 = DType.Float16;
866
958
  /** Euler's constant, `e = 2.7182818284590...` */
867
959
  declare const e: number;
868
960
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -889,6 +981,8 @@ declare const cos: (x: ArrayLike) => Array;
889
981
  declare const exp: (x: ArrayLike) => Array;
890
982
  /** Calculate the natural logarithm of all elements in the input array. */
891
983
  declare const log: (x: ArrayLike) => Array;
984
+ /** Calculate the square root of all elements in the input array. */
985
+ declare const sqrt: (x: ArrayLike) => Array;
892
986
  /** Return element-wise minimum of the input arrays. */
893
987
  declare const minimum: (x: ArrayLike, y: ArrayLike) => Array;
894
988
  /** Return element-wise maximum of the input arrays. */
@@ -930,6 +1024,12 @@ declare const pad: (x: ArrayLike, width: number | [number, number] | [number, nu
930
1024
  declare const ndim: (x: ArrayLike) => number;
931
1025
  /** Return the shape of an array. Does not consume array reference. */
932
1026
  declare const shape$1: (x: ArrayLike) => number[];
1027
+ /** Return an array of zeros with the same shape and type as a given array. */
1028
+ declare const zerosLike: (a: ArrayLike, dtype?: DType) => Array;
1029
+ /** Return an array of ones with the same shape and type as a given array. */
1030
+ declare const onesLike: (a: ArrayLike, dtype?: DType) => Array;
1031
+ /** Return a full array with the same shape and type as a given array. */
1032
+ declare const fullLike: (a: ArrayLike, fillValue: number | boolean | Array, dtype?: DType) => Array;
933
1033
  /**
934
1034
  * Return the number of elements in an array, optionally along an axis.
935
1035
  * Does not consume array reference.
@@ -1013,9 +1113,11 @@ declare function ravel(a: ArrayLike): Array;
1013
1113
  * Return specified diagonals.
1014
1114
  *
1015
1115
  * If a is 2D, return the diagonal of the array with the given offset. If a is
1016
- * 3D or higher, compute diagonals along the two given axes.
1116
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
1017
1117
  *
1018
- * This returns a view over the existing array.
1118
+ * This returns a view over the existing array. The shape of the resulting array
1119
+ * is determined by removing the two axes along which the diagonal is taken,
1120
+ * then appending a new axis on the right that holds the diagonals.
1019
1121
  */
1020
1122
  declare function diagonal(a: ArrayLike, offset?: number, axis1?: number, axis2?: number): Array;
1021
1123
  /**
@@ -1087,6 +1189,24 @@ declare function exp2(p: ArrayLike): Array;
1087
1189
  declare function log2(x: ArrayLike): Array;
1088
1190
  /** Return the base-10 logarithm of x, element-wise. */
1089
1191
  declare function log10(x: ArrayLike): Array;
1192
+ /**
1193
+ * Calculate element-wise hyperbolic sine of input.
1194
+ *
1195
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
1196
+ */
1197
+ declare function sinh(x: ArrayLike): Array;
1198
+ /**
1199
+ * Calculate element-wise hyperbolic cosine of input.
1200
+ *
1201
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
1202
+ */
1203
+ declare function cosh(x: ArrayLike): Array;
1204
+ /**
1205
+ * Calculate element-wise hyperbolic tangent of input.
1206
+ *
1207
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
1208
+ */
1209
+ declare function tanh(x: ArrayLike): Array;
1090
1210
  //#endregion
1091
1211
  //#region src/frontend/jaxpr.d.ts
1092
1212
  /** Variable in a Jaxpr expression. */
@@ -1149,8 +1269,38 @@ declare class Jaxpr implements FpHashable {
1149
1269
  /** Flattens nested JitCall in a Jaxpr. Useful for handling jit-of-jit. */
1150
1270
  flatten(): Jaxpr;
1151
1271
  }
1272
+ /** @inline */
1273
+ type JitOpts = {
1274
+ staticArgnums?: number[];
1275
+ device?: Device;
1276
+ };
1277
+ declare namespace lax_d_exports {
1278
+ export { PaddingType, conv, convGeneralDilated, convWithGeneralPadding, reduceWindow };
1279
+ }
1280
+ type PaddingType = "VALID" | "SAME" | "SAME_LOWER" | [number, number][];
1281
+ /**
1282
+ * General n-dimensional convolution operator, with optional dilation.
1283
+ *
1284
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
1285
+ * function in JAX, which wraps XLA's general convolution operator.
1286
+ *
1287
+ * Grouped convolutions are not supported right now.
1288
+ */
1289
+ declare function convGeneralDilated(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, {
1290
+ lhsDilation,
1291
+ rhsDilation
1292
+ }?: {
1293
+ lhsDilation?: number[];
1294
+ rhsDilation?: number[];
1295
+ }): Array;
1296
+ /** Convenience wrapper around `convGeneralDilated`. */
1297
+ declare function convWithGeneralPadding(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType, lhsDilation?: number[], rhsDilation?: number[]): Array;
1298
+ /** Convenience wrapper around `convGeneralDilated`. */
1299
+ declare function conv(lhs: Array, rhs: Array, windowStrides: number[], padding: PaddingType): Array;
1300
+ /** Reduce a computation over padded windows. */
1301
+ declare function reduceWindow(operand: Array, computation: (x: Array) => Array, windowDimensions: number[], windowStrides?: number[]): Array;
1152
1302
  declare namespace nn_d_exports {
1153
- export { identity, logSigmoid, logSoftmax, logsumexp, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1303
+ export { celu, elu, gelu, glu, identity, leakyRelu, logSigmoid, logSoftmax, logsumexp, mish, oneHot, relu, relu6, sigmoid, silu, softSign, softmax, softplus, swish };
1154
1304
  }
1155
1305
  /**
1156
1306
  * Rectified Linear Unit (ReLU) activation function:
@@ -1190,7 +1340,7 @@ declare function softSign(x: ArrayLike): Array;
1190
1340
  *
1191
1341
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1192
1342
  */
1193
- declare function silu(x: ArrayLike): Array;
1343
+ declare const silu: (x: ArrayLike) => Array;
1194
1344
  /**
1195
1345
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
1196
1346
  * Swish, computed element-wise:
@@ -1200,7 +1350,7 @@ declare function silu(x: ArrayLike): Array;
1200
1350
  *
1201
1351
  * Reference: https://en.wikipedia.org/wiki/Swish_function
1202
1352
  */
1203
- declare const swish: typeof silu;
1353
+ declare const swish: (x: ArrayLike) => Array;
1204
1354
  /**
1205
1355
  * Log-sigmoid activation function, computed element-wise:
1206
1356
  * `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
@@ -1208,6 +1358,48 @@ declare const swish: typeof silu;
1208
1358
  declare function logSigmoid(x: ArrayLike): Array;
1209
1359
  /** Identity activation function. Returns the argument unmodified. */
1210
1360
  declare const identity: (x: ArrayLike) => Array;
1361
+ /** Leaky rectified linear (ReLU) activation function */
1362
+ declare function leakyRelu(x: ArrayLike, negativeSlope?: number): Array;
1363
+ /**
1364
+ * Exponential linear unit activation function.
1365
+ *
1366
+ * Computes the element-wise function:
1367
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
1368
+ */
1369
+ declare function elu(x: ArrayLike, alpha?: number): Array;
1370
+ /**
1371
+ * Continuously-differentiable exponential linear unit activation function.
1372
+ *
1373
+ * Computes the element-wise function:
1374
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
1375
+ */
1376
+ declare function celu(x: ArrayLike, alpha?: number): Array;
1377
+ /**
1378
+ * Gaussian error linear unit (GELU) activation function.
1379
+ *
1380
+ * This is computed element-wise. Currently jax-js does not support the erf() or
1381
+ * gelu() functions exactly as primitives, so an approximation is used:
1382
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
1383
+ *
1384
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
1385
+ *
1386
+ * This will be improved in the future.
1387
+ */
1388
+ declare const gelu: (x: ArrayLike) => Array;
1389
+ /**
1390
+ * Gated linear unit (GLU) activation function.
1391
+ *
1392
+ * Splits the `axis` dimension of the input into two halves, a and b, then
1393
+ * computes `a * sigmoid(b)`.
1394
+ */
1395
+ declare function glu(x: ArrayLike, axis?: number): Array;
1396
+ /**
1397
+ * Mish activation function.
1398
+ *
1399
+ * Computes the element-wise function:
1400
+ * `mish(x) = x * tanh(softplus(x))`
1401
+ */
1402
+ declare function mish(x: ArrayLike): Array;
1211
1403
  /**
1212
1404
  * Softmax function. Computes the function which rescales elements to the range
1213
1405
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -1271,38 +1463,49 @@ declare function uniform(key: Array, shape?: number[], {
1271
1463
  }): Array;
1272
1464
  //#endregion
1273
1465
  //#region src/index.d.ts
1274
- /** @inline */
1275
- type WithArgsSubtype<F extends (args: any[]) => any, T> = Parameters<F> extends T ? F : never;
1276
1466
  /** Compute the forward-mode Jacobian-vector product for a function. */
1277
- declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>, tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1467
+ declare const jvp: <F extends (...args: any[]) => JsTree<Array>>(f: F, primals: MapJsTree<Parameters<F>, Array, ArrayLike>, tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, ReturnType<F>];
1278
1468
  /** Vectorize an operation on a batched axis for one or more inputs. */
1279
- declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1469
+ declare const vmap: <F extends (...args: any[]) => JsTree<Array>>(f: F, inAxes?: number | MapJsTree<Parameters<F>, ArrayLike, number | null>) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1280
1470
  /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
1281
- declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1471
+ declare const jacfwd: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1282
1472
  /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
1283
- declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: Parameters<F>) => {
1473
+ declare const makeJaxpr: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...args: Parameters<F>) => {
1284
1474
  jaxpr: Jaxpr;
1285
1475
  consts: Array[];
1286
1476
  treedef: JsTreeDef;
1287
1477
  };
1288
- declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1478
+ /**
1479
+ * Mark a function for automatic JIT compilation, with operator fusion.
1480
+ *
1481
+ * The function will be compiled the first time it is called with a set of
1482
+ * argument shapes.
1483
+ *
1484
+ * **Options:**
1485
+ * - `staticArgnums`: An array of argument indices to treat as static
1486
+ * (compile-time constant). These arguments must be hashable, won't be traced,
1487
+ * and different values will trigger recompilation.
1488
+ * - `device`: The device to place the computation on. If not specified, the
1489
+ * computation will be placed on the default device.
1490
+ */
1491
+ declare const jit: <F extends (...args: any[]) => JsTree<Array>>(f: F, opts?: JitOpts) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1289
1492
  /**
1290
1493
  * Produce a local linear approximation to a function at a point using jvp() and
1291
1494
  * partial evaluation.
1292
1495
  */
1293
- declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>];
1496
+ declare const linearize: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (...tangents: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>];
1294
1497
  /** Calculate the reverse-mode vector-Jacobian product for a function. */
1295
- declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>, ...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1498
+ declare const vjp: <F extends (...args: any[]) => JsTree<Array>>(f: F, ...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, (cotangents: MapJsTree<ReturnType<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>, ArrayLike, Array>];
1296
1499
  /**
1297
1500
  * Compute the gradient of a scalar-valued function `f` with respect to its
1298
1501
  * first argument.
1299
1502
  */
1300
- declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1503
+ declare const grad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => MapJsTree<Parameters<F>[0], ArrayLike, Array>;
1301
1504
  /** Create a function that evaluates both `f` and the gradient of `f`. */
1302
- declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: WithArgsSubtype<F, JsTree<ArrayLike>>) => (...primals: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1505
+ declare const valueAndGrad: <F extends (...args: any[]) => JsTree<Array>>(f: F) => (...primals: MapJsTree<Parameters<F>, Array, ArrayLike>) => [ReturnType<F>, MapJsTree<Parameters<F>[0], ArrayLike, Array>];
1303
1506
  /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
1304
1507
  declare const jacrev: typeof jacfwd;
1305
1508
  /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
1306
- declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, ArrayLike, ArrayLike>) => ReturnType<F>;
1509
+ declare const jacobian: <F extends (x: Array) => Array>(f: F) => (...args: MapJsTree<Parameters<F>, Array, ArrayLike>) => ReturnType<F>;
1307
1510
  //#endregion
1308
- export { type Device, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };
1511
+ export { DType, type Device, type JsTree, type JsTreeDef, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_d_exports as lax, linearize, makeJaxpr, nn_d_exports as nn, numpy_d_exports as numpy, random_d_exports as random, setDevice, tree_d_exports as tree, valueAndGrad, vjp, vmap };