@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,11 +1,12 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip } from "./backend-1eVbAoaV.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-EBRGmEYw.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
6
6
  __export(tree_exports, {
7
7
  JsTreeDef: () => JsTreeDef,
8
8
  NodeType: () => NodeType,
9
+ dispose: () => dispose,
9
10
  flatten: () => flatten,
10
11
  leaves: () => leaves,
11
12
  map: () => map,
@@ -20,7 +21,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
20
21
  NodeType$1["Leaf"] = "Leaf";
21
22
  return NodeType$1;
22
23
  }({});
23
- /** Analog to the JAX "pytree" object, but for JavaScript. */
24
+ /** Represents the structure of a JsTree. */
24
25
  var JsTreeDef = class JsTreeDef {
25
26
  static leaf = new JsTreeDef(NodeType.Leaf, null, []);
26
27
  constructor(nodeType, nodeMetadata, childTreedefs) {
@@ -108,6 +109,194 @@ function map(fn, tree, ...rest) {
108
109
  function ref(tree) {
109
110
  return map((x) => x.ref, tree);
110
111
  }
112
/**
 * Dispose every array in a tree by calling `dispose()` on each value that
 * `map` visits.
 *
 * A falsy `tree` (for example `null` or `undefined`) is ignored.
 */
function dispose(tree) {
  if (!tree) return;
  map((leaf) => leaf.dispose(), tree);
}
116
+
117
+ //#endregion
118
+ //#region src/frontend/convolution.ts
119
/**
 * Check that the shapes and parameters passed to convolution are valid.
 *
 * Both shapes must have the same rank (at least 2). Index 1 of each shape is
 * the input-channel count and must match; the remaining `n = rank - 2`
 * trailing dimensions are spatial.
 *
 * Padding is applied to the *dilated* input extent
 * `max((x - 1) * lhsDilation + 1, 0)`, so negative padding is validated
 * against that extent rather than against the raw input size.
 *
 * @param {number[]} lhsShape - Shape of the left-hand (input) operand.
 * @param {number[]} rhsShape - Shape of the right-hand (kernel) operand.
 * @param {object} params
 * @param {number[]} params.strides - One positive integer per spatial dim.
 * @param {Array<[number, number]>} params.padding - One `[low, high]` integer
 *   pair per spatial dim; entries may be negative.
 * @param {number[]} params.lhsDilation - One positive integer per spatial dim.
 * @param {number[]} params.rhsDilation - One positive integer per spatial dim.
 * @returns {number[]} The output shape:
 *   `[lhsShape[0], rhsShape[0], ...spatialOutputSizes]`.
 * @throws {Error} If any shape or parameter is inconsistent.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  const n = lhsShape.length - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
  const outShape = [lhsShape[0], rhsShape[0]];
  for (let i = 0; i < n; i++) {
    if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
    if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
    if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
    if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
    const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
    if (k <= 0) throw new Error("conv() kernel size must be positive");
    const [pl, pr] = padding[i];
    // Extent of the input after lhs dilation; the clamp keeps an empty input at 0.
    const dilatedSize = Math.max((x - 1) * lhsDilation[i] + 1, 0);
    // Negative padding shrinks the dilated input, so it is bounded by that
    // extent (which equals `x` whenever lhsDilation is 1).
    if (pl < -dilatedSize || pr < -dilatedSize || pl + pr < -dilatedSize) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${dilatedSize}`);
    // Effective kernel extent once rhs dilation spreads the taps apart.
    const kernelSize = (k - 1) * rhsDilation[i] + 1;
    const inSize = dilatedSize + pl + pr;
    if (kernelSize > inSize) throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${i}`);
    outShape.push(Math.ceil((inSize - kernelSize + 1) / strides[i]));
  }
  return outShape;
}
150
/**
 * Validate pooling parameters against an input shape and compute the shape of
 * the pooled view.
 *
 * The window applies to the trailing `window.length` dimensions of `inShape`.
 *
 * @param {number[]} inShape - Shape of the input.
 * @param {number[]} window - Window extent per pooled dimension.
 * @param {number[]} strides - Stride per pooled dimension.
 * @returns {number[]} Leading (unpooled) dimensions, then one output extent per
 *   pooled dimension, then the window extents themselves.
 * @throws {Error} If the parameters are inconsistent with the input shape.
 */
function checkPoolShape(inShape, window, strides) {
  const rank = window.length;
  if (strides.length !== rank) throw new Error("pool() strides != window dims");
  if (rank > inShape.length) throw new Error("pool() window has more dimensions than input");
  const lead = inShape.length - rank;
  const isPositiveInt = (v) => v > 0 && Number.isInteger(v);
  const pooled = [];
  for (const [axis, extent] of Array.prototype.entries.call(window)) {
    const stride = strides[axis];
    const available = inShape[lead + axis];
    if (!isPositiveInt(extent)) throw new Error(`pool() window[${axis}] must be a positive integer`);
    if (extent > available) throw new Error(`pool() window[${axis}]=${extent} > input size ${available}`);
    if (!isPositiveInt(stride)) throw new Error(`pool() strides[${axis}] must be a positive integer`);
    pooled.push(Math.ceil((available - extent + 1) / stride));
  }
  return inShape.slice(0, lead).concat(pooled).concat(window);
}
165
/**
 * Takes a shape tracker and a kernel size `ks`, then reshapes it so the last
 * `ks.length` dimensions become `2 * ks.length` dimensions by treating them as
 * spatial dimensions convolved with a kernel.
 *
 * The returned view has shape `[...leading, ...outputSizes, ...ks]`. It can be
 * multiplied with a kernel of shape `ks`, then reduced along the last
 * `ks.length` dimensions for a convolution.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 *
 * @param st - Shape tracker to window over.
 * @param {number[]} ks - Kernel extent for each trailing dimension.
 * @param {number|number[]} [strides=1] - A scalar applied to every kernel
 *   dimension, or one positive integer per kernel dimension.
 * @param {number|number[]} [dilation=1] - A scalar applied to every kernel
 *   dimension, or one positive integer per kernel dimension.
 * @returns The windowed shape tracker (`st` itself when `ks` is empty).
 * @throws {Error} If `ks` has more dimensions than `st`, if an array-valued
 *   `strides`/`dilation` does not have one entry per kernel dimension, or if
 *   any stride/dilation is not a positive integer.
 */
function pool(st, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  if (st.shape.length < ks.length) throw new Error("pool() called with too many dimensions");
  // A wrong-length array would otherwise be zipped against `ks` and silently
  // produce a malformed view, so reject it up front.
  if (typeof strides !== "number" && strides.length !== ks.length) throw new Error(`pool() strides must have ${ks.length} entries, got ${strides.length}`);
  if (typeof dilation !== "number" && dilation.length !== ks.length) throw new Error(`pool() dilation must have ${ks.length} entries, got ${dilation.length}`);
  const s_ = typeof strides === "number" ? rep(ks.length, strides) : strides;
  const d_ = typeof dilation === "number" ? rep(ks.length, dilation) : dilation;
  if (s_.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("pool() strides must be positive integers");
  if (d_.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("pool() dilation must be positive integers");
  // Leading dimensions are passed through untouched; `i_` are the pooled ones.
  const noop = st.shape.slice(0, -ks.length);
  const i_ = st.shape.slice(-ks.length);
  // Output extent per pooled dimension.
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  // Repeat factor: 2 when `o * s` overshoots the usable extent, otherwise 1.
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = zipn(ks, i_, d_, f_);
  const keepNoop = noop.map((x) => [0, x]);
  let view = st;
  // Tile each pooled dimension enough times to hold `k` rows of `i * f + d`.
  view = view.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
  // Trim to exactly `k * (i * f + d)` and split into `[k, i * f + d]`.
  view = view
    .shrink([...keepNoop, ...kidf.map(([k, i, d, f]) => [0, k * (i * f + d)])])
    .reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [k, i * f + d])]);
  const kos = zipn(ks, o_, s_);
  // Keep `o * s` columns per kernel row and split them into `[k, o, s]`.
  view = view
    .shrink([...keepNoop, ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])])
    .reshape([...noop, ...kos.flat(1)]);
  // Keep only the first element of every stride group, leaving `[k, o]`.
  view = view
    .shrink([...keepNoop, ...kos.flatMap(([k, o]) => [[0, k], [0, o], [0, 1]])])
    .reshape([...noop, ...kos.flatMap(([k, o]) => [k, o])]);
  // Reorder `[k0, o0, k1, o1, ...]` into `[o0, o1, ..., k0, k1, ...]`.
  view = view.permute([
    ...range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j),
  ]);
  return view;
}
205
/**
 * Perform the transpose of pool, directly undo-ing a pool() operation.
 *
 * `st` is expected to have the shape pool() produces for `inShape`, i.e.
 * `[...leading, ...outputSizes, ...ks]`; the output sizes are verified.
 *
 * Note that since pool repeats the input, the transpose operation technically
 * should include a sum reduction. This function doesn't perform the reduction,
 * which should be done on the last `k` axes of the returned shape.
 *
 * @param st - Shape tracker of the pooled view.
 * @param {number[]} inShape - Shape of the array that was originally pooled.
 * @param {number[]} ks - Kernel extent for each trailing dimension.
 * @param {number|number[]} [strides=1] - Scalar or one positive integer per
 *   kernel dimension; must match what pool() was called with.
 * @param {number|number[]} [dilation=1] - Scalar or one positive integer per
 *   kernel dimension; must match what pool() was called with.
 * @returns The transposed shape tracker (`st` itself when `ks` is empty).
 * @throws {Error} If array-valued `strides`/`dilation` have the wrong length or
 *   contain non-positive/non-integer entries, or if `st` does not have the
 *   output sizes implied by `inShape`, `ks`, `strides` and `dilation`.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  // Same parameter validation as pool(): malformed strides/dilation would
  // otherwise flow into the size arithmetic below (e.g. division by zero).
  if (typeof strides !== "number" && strides.length !== ks.length) throw new Error(`poolTranspose() strides must have ${ks.length} entries, got ${strides.length}`);
  if (typeof dilation !== "number" && dilation.length !== ks.length) throw new Error(`poolTranspose() dilation must have ${ks.length} entries, got ${dilation.length}`);
  const s_ = typeof strides === "number" ? rep(ks.length, strides) : strides;
  const d_ = typeof dilation === "number" ? rep(ks.length, dilation) : dilation;
  if (s_.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("poolTranspose() strides must be positive integers");
  if (d_.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("poolTranspose() dilation must be positive integers");
  const noop = inShape.slice(0, -ks.length);
  const i_ = inShape.slice(-ks.length);
  // Output extent per pooled dimension, computed exactly as in pool().
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = zipn(ks, i_, d_, f_);
  const kos = zipn(ks, o_, s_);
  const padNoop = noop.map(() => [0, 0]);
  let view = st;
  // Interleave back to `[k0, o0, k1, o1, ...]`.
  view = view.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + ks.length + j, noop.length + j])]);
  // Re-insert the stride groups: `[k, o, 1]` padded to `[k, o, s]`.
  view = view
    .reshape([...noop, ...kos.flatMap(([k, o]) => [k, o, 1])])
    .pad([...padNoop, ...s_.flatMap((s) => [[0, 0], [0, 0], [0, s - 1]])]);
  // Merge to `[k, o * s]` and pad each row out to `i * f + d`.
  view = view
    .reshape([...noop, ...kos.flatMap(([k, o, s]) => [k, o * s])])
    .pad([...padNoop, ...kidf.flatMap(([_k, i, d, f], j) => [[0, 0], [0, i * f + d - o_[j] * s_[j]]])]);
  // Flatten to `k * (i * f + d)` and pad up to a whole multiple of `i`.
  view = view
    .reshape([...noop, ...kidf.map(([k, i, d, f]) => k * (i * f + d))])
    .pad([...padNoop, ...kidf.map(([k, i, d, f]) => [0, Math.ceil(k * (i * f + d) / i) * i - k * (i * f + d)])]);
  // Split into `[repeats, i]` and move the input axes ahead of the repeat axes.
  view = view
    .reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [Math.ceil(k * (i * f + d) / i), i])])
    .permute([
      ...range(noop.length),
      ...ks.map((_, j) => noop.length + 2 * j + 1),
      ...ks.map((_, j) => noop.length + 2 * j),
    ]);
  return view;
}
244
/**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * The first two dimensions of `st.shape` are left alone. Every remaining
 * dimension of extent `k` with dilation `s` ends up with extent
 * `(k - 1) * s + 1`: each element gets `s - 1` padded slots after it, and the
 * padding trailing the last element is cut off again.
 *
 * @param st - Shape tracker to dilate.
 * @param {number[]} dilation - Dilation factor per trailing dimension.
 * @returns The dilated shape tracker, or `st` itself when every factor is 1.
 */
function applyDilation(st, dilation) {
  const isIdentity = dilation.every((factor) => factor === 1);
  if (isIdentity) return st;
  const [lead0, lead1, ...spatial] = st.shape;
  // Give every spatial dimension a trailing unit axis to pad into.
  const withUnitAxes = st.reshape([lead0, lead1, ...spatial.flatMap((extent) => [extent, 1])]);
  const padWidths = [[0, 0], [0, 0]];
  for (const factor of dilation) padWidths.push([0, 0], [0, factor - 1]);
  const padded = withUnitAxes.pad(padWidths);
  // Fold each `[extent, factor]` pair back into a single axis.
  const merged = padded.reshape([lead0, lead1, ...spatial.map((extent, i) => extent * dilation[i])]);
  // Drop the padding that follows the final element of each axis.
  const bounds = [[0, lead0], [0, lead1]];
  spatial.forEach((extent, i) => {
    bounds.push([0, (extent - 1) * dilation[i] + 1]);
  });
  return merged.shrink(bounds);
}
271
/**
 * Prepare for a convolution between two arrays.
 *
 * The left operand is dilated by `params.lhsDilation`, padded or shrunk by
 * `params.padding` on its spatial dimensions, windowed with `pool()` using the
 * kernel's spatial extents, and finally reshaped to
 * `[dim0, 1, ...pooledSpatial, dim1 * prod(ks)]`. The right operand is reshaped
 * to `[dim0, 1 (n times), dim1 * prod(ks)]`, where `n` is the number of spatial
 * dimensions.
 *
 * This does not check the validity of the shapes, which should be checked
 * beforehand using `checkConvShape()`.
 *
 * @param stX - Shape tracker of the left-hand operand.
 * @param stY - Shape tracker of the right-hand (kernel) operand.
 * @param params - Convolution parameters (`strides`, `padding`, `lhsDilation`,
 *   `rhsDilation`).
 * @returns A `[stX, stY]` pair of prepared shape trackers.
 */
function prepareConv(stX, stY, params) {
  const n = stX.shape.length - 2;
  const dilated = applyDilation(stX, params.lhsDilation);
  const ks = stY.shape.slice(2);
  const padded = dilated.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  const windows = pool(padded, ks, params.strides, params.rhsDilation);
  // Shape of the pooled view *before* the channel axis is moved; the reshape
  // targets below are expressed in terms of it.
  const pooledShape = windows.shape;
  const lhs = windows.moveaxis(1, n + 1).reshape([
    pooledShape[0],
    1,
    ...pooledShape.slice(2, n + 2),
    pooledShape[1] * prod(ks),
  ]);
  const kernelShape = stY.shape;
  const rhs = stY.reshape([kernelShape[0], ...rep(n, 1), kernelShape[1] * prod(ks)]);
  return [lhs, rhs];
}
111
300
 
112
301
  //#endregion
113
302
  //#region src/frontend/core.ts
@@ -134,12 +323,18 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
134
323
  Primitive$1["RandomBits"] = "random_bits";
135
324
  Primitive$1["Sin"] = "sin";
136
325
  Primitive$1["Cos"] = "cos";
326
+ Primitive$1["Asin"] = "asin";
327
+ Primitive$1["Atan"] = "atan";
137
328
  Primitive$1["Exp"] = "exp";
138
329
  Primitive$1["Log"] = "log";
330
+ Primitive$1["Sqrt"] = "sqrt";
139
331
  Primitive$1["Min"] = "min";
140
332
  Primitive$1["Max"] = "max";
141
333
  Primitive$1["Reduce"] = "reduce";
142
334
  Primitive$1["Dot"] = "dot";
335
+ Primitive$1["Conv"] = "conv";
336
+ Primitive$1["Pool"] = "pool";
337
+ Primitive$1["PoolTranspose"] = "pool_transpose";
143
338
  Primitive$1["Compare"] = "compare";
144
339
  Primitive$1["Where"] = "where";
145
340
  Primitive$1["Transpose"] = "transpose";
@@ -197,34 +392,52 @@ function sin$1(x) {
197
392
  function cos$1(x) {
198
393
  return bind1(Primitive.Cos, [x]);
199
394
  }
395
+ function asin$1(x) {
396
+ return bind1(Primitive.Asin, [x]);
397
+ }
398
+ function atan$1(x) {
399
+ return bind1(Primitive.Atan, [x]);
400
+ }
200
401
  function exp$1(x) {
201
402
  return bind1(Primitive.Exp, [x]);
202
403
  }
203
404
  function log$1(x) {
204
405
  return bind1(Primitive.Log, [x]);
205
406
  }
407
+ function sqrt$1(x) {
408
+ return bind1(Primitive.Sqrt, [x]);
409
+ }
206
410
  function min$1(x, y) {
207
411
  return bind1(Primitive.Min, [x, y]);
208
412
  }
209
413
  function max$1(x, y) {
210
414
  return bind1(Primitive.Max, [x, y]);
211
415
  }
212
- function reduce(x, op, axis, opts) {
416
+ function reduce(x, op, axis = null, opts) {
213
417
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
214
- if (axis === void 0) if (x instanceof Tracer) axis = range(x.shape.length);
215
- else axis = [];
216
- else if (typeof axis === "number") axis = [checkAxis(axis, ndim$1(x))];
217
- else axis = axis.map((a) => checkAxis(a, ndim$1(x)));
418
+ axis = normalizeAxis(axis, ndim$1(x));
218
419
  const originalShape = getShape(x);
219
- const result = bind1(Primitive.Reduce, [x], {
420
+ let result = bind1(Primitive.Reduce, [x], {
220
421
  op,
221
422
  axis
222
423
  });
223
- return opts?.keepDims ? broadcast(result, originalShape, axis) : result;
424
+ if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
425
+ return result;
224
426
  }
225
427
  function dot$1(x, y) {
226
428
  return bind1(Primitive.Dot, [x, y]);
227
429
  }
430
/**
 * Bind the `Conv` primitive for operands `x` and `y`.
 *
 * Any parameter that is missing (`null` or `undefined`) is filled in with one
 * entry per spatial dimension (`ndim - 2`): strides of 1, padding of `[0, 0]`,
 * and dilations of 1.
 *
 * @param x - Left-hand operand; must expose `ndim`.
 * @param y - Right-hand operand; must expose `ndim`.
 * @param {object} [params] - Optional `strides`, `padding`, `lhsDilation` and
 *   `rhsDilation`.
 * @returns The result of binding `Primitive.Conv`.
 * @throws {Error} If the operands differ in rank or have fewer than 2 dims.
 */
function conv(x, y, params = {}) {
  const xRank = x.ndim;
  const yRank = y.ndim;
  if (xRank !== yRank) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${xRank} and ${yRank}`);
  }
  const n = xRank - 2;
  if (n < 0) {
    throw new Error("conv() requires at least 2D inputs");
  }
  const strides = params.strides ?? rep(n, 1);
  const padding = params.padding ?? rep(n, [0, 0]);
  const lhsDilation = params.lhsDilation ?? rep(n, 1);
  const rhsDilation = params.rhsDilation ?? rep(n, 1);
  return bind1(Primitive.Conv, [x, y], { strides, padding, lhsDilation, rhsDilation });
}
228
441
  function compare(x, y, op) {
229
442
  return bind1(Primitive.Compare, [x, y], { op });
230
443
  }
@@ -255,10 +468,11 @@ function where$1(cond, x, y) {
255
468
  }
256
469
  function transpose$1(x, perm) {
257
470
  perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
471
+ if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
258
472
  return bind1(Primitive.Transpose, [x], { perm });
259
473
  }
260
474
  function broadcast(x, shape$1, axis) {
261
- axis = axis.map((a) => checkAxis(a, shape$1.length));
475
+ axis = normalizeAxis(axis, shape$1.length);
262
476
  return bind1(Primitive.Broadcast, [x], {
263
477
  shape: shape$1,
264
478
  axis
@@ -277,7 +491,7 @@ function reshape$1(x, shape$1) {
277
491
  return bind1(Primitive.Reshape, [x], { shape: shape$1 });
278
492
  }
279
493
  function flip$1(x, axis) {
280
- axis = axis.map((a) => checkAxis(a, ndim$1(x)));
494
+ axis = normalizeAxis(axis, ndim$1(x));
281
495
  return bind1(Primitive.Flip, [x], { axis });
282
496
  }
283
497
  function shrink(x, slice) {
@@ -357,12 +571,19 @@ var Tracer = class Tracer {
357
571
  constructor(trace) {
358
572
  this._trace = trace;
359
573
  }
574
+ /** The shape of the array. */
360
575
  get shape() {
361
576
  return this.aval.shape;
362
577
  }
578
+ /** The total number of elements in the array. */
579
+ get size() {
580
+ return prod(this.shape);
581
+ }
582
+ /** The dtype of the array. */
363
583
  get dtype() {
364
584
  return this.aval.dtype;
365
585
  }
586
+ /** The number of dimensions of the array. */
366
587
  get ndim() {
367
588
  return this.shape.length;
368
589
  }
@@ -398,22 +619,20 @@ var Tracer = class Tracer {
398
619
  return lessEqual$1(this, other);
399
620
  }
400
621
  /** Sum of the elements of the array over a given axis, or axes. */
401
- sum(axis, opts) {
622
+ sum(axis = null, opts) {
402
623
  return reduce(this, AluOp.Add, axis, opts);
403
624
  }
404
625
  /** Product of the array elements over a given axis. */
405
- prod(axis, opts) {
626
+ prod(axis = null, opts) {
406
627
  return reduce(this, AluOp.Mul, axis, opts);
407
628
  }
408
629
  /** Compute the average of the array elements along the specified axis. */
409
- mean(axis, opts) {
410
- if (axis === void 0) axis = range(this.ndim);
411
- else if (typeof axis === "number") axis = [checkAxis(axis, this.ndim)];
412
- else axis = axis.map((a) => checkAxis(a, this.ndim));
413
- let result = reduce(this, AluOp.Add, axis);
414
- result = result.mul(prod(result.shape) / prod(this.shape));
415
- if (opts?.keepDims) result = broadcast(result, this.shape, axis);
416
- return result;
630
+ mean(axis = null, opts) {
631
+ axis = normalizeAxis(axis, this.ndim);
632
+ const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
633
+ if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
634
+ const result = reduce(this, AluOp.Add, axis, opts);
635
+ return result.mul(1 / n);
417
636
  }
418
637
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
419
638
  transpose(perm) {
@@ -445,8 +664,29 @@ var Tracer = class Tracer {
445
664
  /** Return specified diagonals. See `numpy.diagonal` for full docs. */
446
665
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
447
666
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
667
+ if (offset < 0) return this.diagonal(-offset, axis2, axis1);
668
+ axis1 = checkAxis(axis1, this.ndim);
669
+ axis2 = checkAxis(axis2, this.ndim);
448
670
  if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
449
- throw new Error("diagonal not implemented");
671
+ if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
672
+ let ar = this;
673
+ if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
674
+ const perm = range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
675
+ ar = ar.transpose(perm);
676
+ }
677
+ const [n, m] = ar.shape.slice(-2);
678
+ const diagSize = Math.min(n, m - offset);
679
+ ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
680
+ const npad = diagSize * (m + 1) - n * m;
681
+ if (npad > 0) ar = pad$1(ar, [...rep(ar.ndim - 1, [0, 0]), [0, npad]]);
682
+ else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
683
+ ar = ar.reshape([
684
+ ...ar.shape.slice(0, -1),
685
+ diagSize,
686
+ m + 1
687
+ ]);
688
+ ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
689
+ return ar;
450
690
  }
451
691
  /** Flatten the array without changing its data. */
452
692
  flatten() {
@@ -589,7 +829,7 @@ var ShapedArray = class ShapedArray {
589
829
  get ndim() {
590
830
  return this.shape.length;
591
831
  }
592
- strShort() {
832
+ toString() {
593
833
  return `${this.dtype}[${this.shape.join(",")}]`;
594
834
  }
595
835
  equals(other) {
@@ -620,7 +860,7 @@ function fullRaise(trace, val) {
620
860
  if (Object.is(val._trace.main, trace.main)) return val;
621
861
  else if (val._trace.main.level < level) return trace.lift(val);
622
862
  else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
623
- else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
863
+ else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
624
864
  }
625
865
  var TreeMismatchError = class extends TypeError {
626
866
  constructor(where$2, left, right) {
@@ -869,16 +1109,16 @@ function jitCompile(backend, jaxpr, consts) {
869
1109
  jitCompileCache.set(cacheKey, jp);
870
1110
  return jp;
871
1111
  }
872
- function reshapeViews(exp$2, mapping) {
1112
+ function reshapeViews(exp$2, mapping, reduceAxis = false) {
873
1113
  return exp$2.rewrite((exp$3) => {
874
1114
  if (exp$3.op === AluOp.GlobalView) {
875
1115
  const [gid, st] = exp$3.arg;
876
1116
  const newSt = mapping(st);
877
1117
  if (newSt) {
878
- const indices = unravelAlu(newSt.shape, AluVar.gidx);
1118
+ const indices = reduceAxis ? unravelAlu(newSt.shape.slice(0, -1), AluVar.gidx).concat(AluVar.ridx) : unravelAlu(newSt.shape, AluVar.gidx);
879
1119
  return AluExp.globalView(exp$3.dtype, gid, newSt, indices);
880
1120
  }
881
- }
1121
+ } else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
882
1122
  });
883
1123
  }
884
1124
  function broadcastedJit(fn) {
@@ -925,8 +1165,11 @@ const jitRules = {
925
1165
  },
926
1166
  [Primitive.Sin]: unopJit(AluExp.sin),
927
1167
  [Primitive.Cos]: unopJit(AluExp.cos),
1168
+ [Primitive.Asin]: unopJit(AluExp.asin),
1169
+ [Primitive.Atan]: unopJit(AluExp.atan),
928
1170
  [Primitive.Exp]: unopJit(AluExp.exp),
929
1171
  [Primitive.Log]: unopJit(AluExp.log),
1172
+ [Primitive.Sqrt]: unopJit(AluExp.sqrt),
930
1173
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
931
1174
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
932
1175
  [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
@@ -941,18 +1184,20 @@ const jitRules = {
941
1184
  const size$1 = prod(newShape);
942
1185
  const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
943
1186
  newShape.push(reductionSize);
944
- a = a.rewrite((exp$2) => {
945
- if (exp$2.op === AluOp.GlobalView) {
946
- const [gid, st] = exp$2.arg;
947
- const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
948
- const indices = unravelAlu(newShape.slice(0, -1), AluVar.gidx);
949
- indices.push(AluVar.ridx);
950
- return AluExp.globalView(exp$2.dtype, gid, newSt, indices);
951
- }
952
- });
1187
+ const perm = keptAxes.concat(shiftedAxes);
1188
+ a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
953
1189
  const reduction = new Reduction(a.dtype, op, reductionSize);
954
1190
  return new Kernel(nargs, size$1, a, reduction);
955
1191
  },
1192
+ [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1193
+ [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
1194
+ let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
1195
+ const size$1 = prod(inShape);
1196
+ stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
1197
+ a = reshapeViews(a, (st) => st.compose(stX), true);
1198
+ const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
1199
+ return new Kernel(nargs, size$1, a, reduction);
1200
+ },
956
1201
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
957
1202
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
958
1203
  const c = k1.exp;
@@ -962,6 +1207,14 @@ const jitRules = {
962
1207
  axis: [cs.ndim - 1]
963
1208
  });
964
1209
  },
1210
+ [Primitive.Conv](nargs, [a, b], [as, bs], params) {
1211
+ const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1212
+ a = reshapeViews(a, (st) => st.compose(stX));
1213
+ b = reshapeViews(b, (st) => st.compose(stY));
1214
+ as = new ShapedArray(stX.shape, as.dtype);
1215
+ bs = new ShapedArray(stY.shape, bs.dtype);
1216
+ return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1217
+ },
965
1218
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
966
1219
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
967
1220
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
@@ -974,8 +1227,20 @@ const jitRules = {
974
1227
  }),
975
1228
  [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
976
1229
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
977
- [Primitive.Gather]() {
978
- throw new Error("Gather is not implemented in JIT yet");
1230
+ [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1231
+ const axisSet = new Set(axis);
1232
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1233
+ const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1234
+ finalShape.splice(outDim, 0, ...indexShape);
1235
+ const idxAll = unravelAlu(finalShape, AluVar.gidx);
1236
+ const idxNonaxis = [...idxAll];
1237
+ idxNonaxis.splice(outDim, indexShape.length);
1238
+ const src = [...idxNonaxis];
1239
+ for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1240
+ for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
1241
+ const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
1242
+ if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1243
+ return new Kernel(nargs, prod(finalShape), x.substitute({ gidx: index }));
979
1244
  },
980
1245
  [Primitive.JitCall]() {
981
1246
  throw new Error("internal: JitCall should have been flattened before JIT compilation");
@@ -994,9 +1259,15 @@ function splitGraphDataflow(backend, jaxpr) {
994
1259
  blackNodes.add(v);
995
1260
  p1NextBlack.set(v, v);
996
1261
  }
1262
+ const reducePrimitives = [
1263
+ Primitive.Reduce,
1264
+ Primitive.Dot,
1265
+ Primitive.Conv,
1266
+ Primitive.PoolTranspose
1267
+ ];
997
1268
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
998
1269
  const eqn = jaxpr.eqns[i];
999
- if (eqn.primitive === Primitive.Reduce || eqn.outBinders.some((v) => blackNodes.has(v))) {
1270
+ if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1000
1271
  for (const v of eqn.outBinders) {
1001
1272
  blackNodes.add(v);
1002
1273
  p1NextBlack.set(v, v);
@@ -1137,7 +1408,7 @@ var Array$1 = class Array$1 extends Tracer {
1137
1408
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1138
1409
  * will be freed when the array is disposed.
1139
1410
  */
1140
- constructor(source, st, dtype, backend, pending = null) {
1411
+ constructor(source, st, dtype, backend, { pending = null } = {}) {
1141
1412
  super(baseArrayTrace);
1142
1413
  this.id = Array$1.#nextId++;
1143
1414
  this.#dtype = dtype;
@@ -1146,6 +1417,8 @@ var Array$1 = class Array$1 extends Tracer {
1146
1417
  this.#backend = backend;
1147
1418
  this.#rc = 1;
1148
1419
  this.#pendingSet = new Set(pending);
1420
+ if (this.#pendingSet.size === 0) this.#pendingSet = null;
1421
+ else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1149
1422
  }
1150
1423
  /** @ignore */
1151
1424
  get aval() {
@@ -1200,7 +1473,7 @@ var Array$1 = class Array$1 extends Tracer {
1200
1473
  const pending = this.#pending;
1201
1474
  for (const exe of pending) exe.updateRc(1);
1202
1475
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1203
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
1476
+ const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1204
1477
  this.dispose();
1205
1478
  return ar;
1206
1479
  }
@@ -1223,7 +1496,7 @@ var Array$1 = class Array$1 extends Tracer {
1223
1496
  const inputs = [];
1224
1497
  const src = [...idxNonaxis];
1225
1498
  for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1226
- for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#dtype, ar.#source, ar.#st, idxAxis));
1499
+ for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#source, ar.#st, idxAxis));
1227
1500
  else {
1228
1501
  let gid = inputs.indexOf(ar.#source);
1229
1502
  if (gid === -1) {
@@ -1233,7 +1506,7 @@ var Array$1 = class Array$1 extends Tracer {
1233
1506
  src[axis[i]] = AluExp.cast(DType.Int32, AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
1234
1507
  }
1235
1508
  let exp$2;
1236
- if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#dtype, this.#source, this.#st, src);
1509
+ if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, src);
1237
1510
  else {
1238
1511
  let gid = inputs.indexOf(this.#source);
1239
1512
  if (gid === -1) {
@@ -1249,7 +1522,7 @@ var Array$1 = class Array$1 extends Tracer {
1249
1522
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1250
1523
  this.dispose();
1251
1524
  for (const ar of indices) ar.dispose();
1252
- return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
1525
+ return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1253
1526
  }
1254
1527
  /** Move axes to the rightmost dimension of the shape. */
1255
1528
  #moveAxesDown(axis) {
@@ -1276,7 +1549,7 @@ var Array$1 = class Array$1 extends Tracer {
1276
1549
  this.#check();
1277
1550
  if (this.#source instanceof AluExp) {
1278
1551
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1279
- return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
1552
+ return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1280
1553
  }
1281
1554
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1282
1555
  const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1286,7 +1559,7 @@ var Array$1 = class Array$1 extends Tracer {
1286
1559
  for (const exe of pending) exe.updateRc(1);
1287
1560
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1288
1561
  this.dispose();
1289
- return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
1562
+ return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1290
1563
  }
1291
1564
  #binary(op, other) {
1292
1565
  const custom = (src) => new AluExp(op, this.#dtype, src);
@@ -1309,18 +1582,18 @@ var Array$1 = class Array$1 extends Tracer {
1309
1582
  if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1310
1583
  arrays = Array$1.#broadcastArrays(arrays);
1311
1584
  const newShape = [...arrays[0].shape];
1312
- if (arrays.every((ar) => ar.#source instanceof AluExp) && reduceAxis === void 0) {
1585
+ if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
1313
1586
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1314
1587
  const exp$4 = custom(arrays.map((ar) => ar.#source));
1315
- return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
1588
+ return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1316
1589
  }
1317
1590
  const exp$3 = custom(arrays.map((ar) => {
1318
1591
  const src$1 = ar.#source;
1319
1592
  if (ar.#st.contiguous) return src$1;
1320
- return accessorAluExp(ar.#dtype, src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1593
+ return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1321
1594
  }));
1322
1595
  const st = ShapeTracker.fromShape(newShape);
1323
- return new Array$1(exp$3, st, exp$3.dtype, backend);
1596
+ return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1324
1597
  }
1325
1598
  let indices;
1326
1599
  if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
@@ -1330,7 +1603,7 @@ var Array$1 = class Array$1 extends Tracer {
1330
1603
  }
1331
1604
  const inputs = [];
1332
1605
  const src = [];
1333
- for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#dtype, ar.#source, ar.#st, indices));
1606
+ for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
1334
1607
  else {
1335
1608
  let gid = inputs.indexOf(ar.#source);
1336
1609
  if (gid === -1) {
@@ -1351,7 +1624,7 @@ var Array$1 = class Array$1 extends Tracer {
1351
1624
  for (const exe of pending) exe.updateRc(1);
1352
1625
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1353
1626
  for (const ar of arrays) ar.dispose();
1354
- return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
1627
+ return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1355
1628
  }
1356
1629
  /** Reduce the last dimension of the array by an operation. */
1357
1630
  #reduce(op) {
@@ -1364,7 +1637,7 @@ var Array$1 = class Array$1 extends Tracer {
1364
1637
  const indices = [...unravelAlu(newShape, AluVar.gidx), AluVar.ridx];
1365
1638
  let exp$2;
1366
1639
  const inputs = [];
1367
- if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1640
+ if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, indices);
1368
1641
  else {
1369
1642
  inputs.push(this.#source);
1370
1643
  exp$2 = accessorGlobal(this.#dtype, 0, this.#st, indices);
@@ -1375,7 +1648,7 @@ var Array$1 = class Array$1 extends Tracer {
1375
1648
  for (const exe of pending) exe.updateRc(1);
1376
1649
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1377
1650
  this.dispose();
1378
- return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
1651
+ return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1379
1652
  }
1380
1653
  /**
1381
1654
  * Normalizes this array into one backed by a `Slot`.
@@ -1389,7 +1662,7 @@ var Array$1 = class Array$1 extends Tracer {
1389
1662
  this.#check();
1390
1663
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1391
1664
  if (this.#source instanceof AluExp) {
1392
- const exp$2 = accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1665
+ const exp$2 = accessorAluExp(this.#source, this.#st, indices);
1393
1666
  const kernel = new Kernel(0, this.#st.size, exp$2);
1394
1667
  const output = this.#backend.malloc(kernel.bytes);
1395
1668
  const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
@@ -1427,42 +1700,54 @@ var Array$1 = class Array$1 extends Tracer {
1427
1700
  }
1428
1701
  /** Realize the array and return it as data. */
1429
1702
  async data() {
1430
- if (this.#source instanceof AluExp && prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1703
+ if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1431
1704
  this.#realize();
1432
1705
  const pending = this.#pending;
1433
1706
  if (pending) {
1434
1707
  await Promise.all(pending.map((p) => p.prepare()));
1435
1708
  for (const p of pending) p.submit();
1436
1709
  }
1437
- const byteCount = byteWidth(this.#dtype) * prod(this.shape);
1710
+ const byteCount = byteWidth(this.#dtype) * this.size;
1438
1711
  const buf = await this.#backend.read(this.#source, 0, byteCount);
1439
1712
  this.dispose();
1440
1713
  return dtypedArray(this.dtype, buf);
1441
1714
  }
1442
- /** Wait for this array to be placed on the backend, if needed. */
1443
- async wait() {
1715
+ /**
1716
+ * Wait for this array to finish evaluation.
1717
+ *
1718
+ * Operations and data loading in jax-js are lazy, so this function ensures
1719
+ * that pending operations are dispatched and fully executed before it
1720
+ * returns.
1721
+ *
1722
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
1723
+ * dispatch of operations as well.
1724
+ *
1725
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1726
+ * asynchronously for multiple arrays.
1727
+ */
1728
+ async blockUntilReady() {
1444
1729
  this.#check();
1445
- if (this.#source instanceof AluExp) return;
1730
+ if (this.#source instanceof AluExp) return this;
1446
1731
  const pending = this.#pending;
1447
1732
  if (pending) {
1448
1733
  await Promise.all(pending.map((p) => p.prepare()));
1449
1734
  for (const p of pending) p.submit();
1450
1735
  }
1451
1736
  await this.#backend.read(this.#source, 0, 0);
1452
- this.dispose();
1737
+ return this;
1453
1738
  }
1454
1739
  /**
1455
1740
  * Realize the array and return it as data. This is a sync variant and not
1456
1741
  * recommended for performance reasons, as it will block rendering.
1457
1742
  */
1458
1743
  dataSync() {
1459
- if (this.#source instanceof AluExp && prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1744
+ if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1460
1745
  this.#realize();
1461
1746
  for (const p of this.#pending) {
1462
1747
  p.prepareSync();
1463
1748
  p.submit();
1464
1749
  }
1465
- const byteCount = byteWidth(this.#dtype) * prod(this.shape);
1750
+ const byteCount = byteWidth(this.#dtype) * this.size;
1466
1751
  const buf = this.#backend.readSync(this.#source, 0, byteCount);
1467
1752
  this.dispose();
1468
1753
  return dtypedArray(this.dtype, buf);
@@ -1483,6 +1768,16 @@ var Array$1 = class Array$1 extends Tracer {
1483
1768
  async jsAsync() {
1484
1769
  return dataToJs(this.dtype, await this.data(), this.shape);
1485
1770
  }
1771
+ /**
1772
+ * Copy an element of an array to a numeric scalar and return it.
1773
+ *
1774
+ * Throws an error if the array does not have a single element. The array must
1775
+ * either be rank-0, or all dimensions of the shape are 1.
1776
+ */
1777
+ item() {
1778
+ if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
1779
+ return this.dataSync()[0];
1780
+ }
1486
1781
  /** @private Internal plumbing method for Array / Tracer ops. */
1487
1782
  static _implRules() {
1488
1783
  return {
@@ -1496,7 +1791,7 @@ var Array$1 = class Array$1 extends Tracer {
1496
1791
  return [x.#binary(AluOp.Idiv, y)];
1497
1792
  },
1498
1793
  [Primitive.Neg]([x]) {
1499
- return [zerosLike(x).#binary(AluOp.Sub, x)];
1794
+ return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
1500
1795
  },
1501
1796
  [Primitive.Reciprocal]([x]) {
1502
1797
  return [x.#unary(AluOp.Reciprocal)];
@@ -1516,7 +1811,7 @@ var Array$1 = class Array$1 extends Tracer {
1516
1811
  x.#backend.incRef(x.#source);
1517
1812
  const pending = x.#pending;
1518
1813
  for (const exe of pending) exe.updateRc(1);
1519
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
1814
+ const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1520
1815
  x.dispose();
1521
1816
  return [y];
1522
1817
  }
@@ -1546,12 +1841,21 @@ var Array$1 = class Array$1 extends Tracer {
1546
1841
  [Primitive.Cos]([x]) {
1547
1842
  return [x.#unary(AluOp.Cos)];
1548
1843
  },
1844
+ [Primitive.Asin]([x]) {
1845
+ return [x.#unary(AluOp.Asin)];
1846
+ },
1847
+ [Primitive.Atan]([x]) {
1848
+ return [x.#unary(AluOp.Atan)];
1849
+ },
1549
1850
  [Primitive.Exp]([x]) {
1550
1851
  return [x.#unary(AluOp.Exp)];
1551
1852
  },
1552
1853
  [Primitive.Log]([x]) {
1553
1854
  return [x.#unary(AluOp.Log)];
1554
1855
  },
1856
+ [Primitive.Sqrt]([x]) {
1857
+ return [x.#unary(AluOp.Sqrt)];
1858
+ },
1555
1859
  [Primitive.Min]([x, y]) {
1556
1860
  return [x.#binary(AluOp.Min, y)];
1557
1861
  },
@@ -1562,9 +1866,24 @@ var Array$1 = class Array$1 extends Tracer {
1562
1866
  if (axis.length === 0) return [x];
1563
1867
  return [x.#moveAxesDown(axis).#reduce(op)];
1564
1868
  },
1869
+ [Primitive.Pool]([x], { window, strides }) {
1870
+ const st = pool(x.#st, window, strides);
1871
+ return [x.#reshape(st)];
1872
+ },
1873
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
1874
+ const n = inShape.length;
1875
+ let st = poolTranspose(x.#st, inShape, window, strides);
1876
+ st = st.reshape([...st.shape.slice(0, n), prod(st.shape.slice(n))]);
1877
+ return [x.#reshape(st).#reduce(AluOp.Add)];
1878
+ },
1565
1879
  [Primitive.Dot]([x, y]) {
1566
1880
  return [Array$1.#naryCustom("dot", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
1567
1881
  },
1882
+ [Primitive.Conv]([x, y], params) {
1883
+ checkConvShape(x.shape, y.shape, params);
1884
+ const [stX, stY] = prepareConv(x.#st, y.#st, params);
1885
+ return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
1886
+ },
1568
1887
  [Primitive.Compare]([x, y], { op }) {
1569
1888
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1570
1889
  return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
@@ -1613,7 +1932,7 @@ var Array$1 = class Array$1 extends Tracer {
1613
1932
  pending.splice(0, 0, ...prevPending);
1614
1933
  args.forEach((x) => x.dispose());
1615
1934
  return outputs.map((source, i) => {
1616
- return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
1935
+ return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
1617
1936
  });
1618
1937
  }
1619
1938
  };
@@ -1629,6 +1948,7 @@ function scalar(value, { dtype, device } = {}) {
1629
1948
  dtype ??= DType.Float32;
1630
1949
  if (![
1631
1950
  DType.Float32,
1951
+ DType.Float16,
1632
1952
  DType.Int32,
1633
1953
  DType.Uint32
1634
1954
  ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
@@ -1636,6 +1956,7 @@ function scalar(value, { dtype, device } = {}) {
1636
1956
  dtype ??= DType.Bool;
1637
1957
  if (![
1638
1958
  DType.Float32,
1959
+ DType.Float16,
1639
1960
  DType.Int32,
1640
1961
  DType.Uint32,
1641
1962
  DType.Bool
@@ -1649,7 +1970,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1649
1970
  if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1650
1971
  if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
1651
1972
  return values;
1652
- } else if (values instanceof Float32Array || values instanceof Int32Array) return arrayFromData(values, shape$1 ?? [values.length], {
1973
+ } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1653
1974
  dtype,
1654
1975
  device
1655
1976
  });
@@ -1678,7 +1999,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1678
1999
  });
1679
2000
  } else {
1680
2001
  dtype = dtype ?? DType.Float32;
1681
- const data = dtypedArray(dtype, flat);
2002
+ const data = dtypedJsArray(dtype, flat);
1682
2003
  return arrayFromData(data, shape$1, {
1683
2004
  dtype,
1684
2005
  device
@@ -1699,19 +2020,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
1699
2020
  });
1700
2021
  }
1701
2022
  const backend = getBackend(device);
1702
- if (data instanceof Float32Array) {
1703
- if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
1704
- const slot = backend.malloc(data.byteLength, data.buffer);
1705
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), DType.Float32, backend);
1706
- } else if (data instanceof Int32Array) {
1707
- if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
1708
- const slot = backend.malloc(data.byteLength, data.buffer);
1709
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype ?? DType.Int32, backend);
1710
- } else if (data instanceof Uint32Array) {
1711
- if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
1712
- const slot = backend.malloc(data.byteLength, data.buffer);
1713
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), DType.Uint32, backend);
1714
- } else throw new Error("Unsupported data type");
2023
+ if (ArrayBuffer.isView(data)) {
2024
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2025
+ if (data instanceof Float32Array) {
2026
+ if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2027
+ dtype ??= DType.Float32;
2028
+ } else if (data instanceof Int32Array) {
2029
+ if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2030
+ dtype ??= DType.Int32;
2031
+ } else if (data instanceof Uint32Array) {
2032
+ if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2033
+ dtype ??= DType.Uint32;
2034
+ } else if (data instanceof Float16Array) {
2035
+ if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2036
+ dtype ??= DType.Float16;
2037
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2038
+ const slot = backend.malloc(data.byteLength, buf);
2039
+ return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
2040
+ } else throw new Error("Unsupported data type: " + data.constructor.name);
1715
2041
  }
1716
2042
  function dataToJs(dtype, data, shape$1) {
1717
2043
  if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
@@ -1738,9 +2064,20 @@ var EvalTrace = class extends Trace {
1738
2064
  };
1739
2065
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
1740
2066
  const implRules = Array$1._implRules();
1741
- function zerosLike(val) {
2067
+ function zerosLike$1(val, dtype) {
2068
+ const aval = getAval(val);
2069
+ if (val instanceof Tracer) val.dispose();
2070
+ return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2071
+ }
2072
+ function onesLike$1(val, dtype) {
2073
+ const aval = getAval(val);
2074
+ if (val instanceof Tracer) val.dispose();
2075
+ return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2076
+ }
2077
+ function fullLike(val, fillValue, dtype) {
1742
2078
  const aval = getAval(val);
1743
- return zeros(aval.shape, { dtype: aval.dtype });
2079
+ if (val instanceof Tracer) val.dispose();
2080
+ return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
1744
2081
  }
1745
2082
  /** Return a new array of given shape and type, filled with zeros. */
1746
2083
  function zeros(shape$1, { dtype, device } = {}) {
@@ -1762,6 +2099,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
1762
2099
  if (typeof fillValue === "number") {
1763
2100
  dtype = dtype ?? DType.Float32;
1764
2101
  source = AluExp.const(dtype, fillValue);
2102
+ } else if (typeof fillValue === "bigint") {
2103
+ dtype = dtype ?? DType.Int32;
2104
+ source = AluExp.const(dtype, Number(fillValue));
1765
2105
  } else if (typeof fillValue === "boolean") {
1766
2106
  dtype = dtype ?? DType.Bool;
1767
2107
  source = AluExp.const(dtype, fillValue ? 1 : 0);
@@ -1792,7 +2132,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
1792
2132
  const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
1793
2133
  return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
1794
2134
  }
1795
- /** Return the identity array, with ones on the main diagonal. */
2135
+ /** Return the identity matrix, with ones on the main diagonal. */
1796
2136
  function identity$1(n, { dtype, device } = {}) {
1797
2137
  return eye(n, n, {
1798
2138
  dtype,
@@ -1859,7 +2199,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
1859
2199
  const st = ShapeTracker.fromShape([num]);
1860
2200
  return new Array$1(exp$2, st, dtype, getBackend(device));
1861
2201
  }
1862
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
1863
2202
  function aluCompare(a, b, op) {
1864
2203
  switch (op) {
1865
2204
  case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
@@ -1901,7 +2240,7 @@ function generalBroadcast(a, b) {
1901
2240
  }
1902
2241
 
1903
2242
  //#endregion
1904
- //#region node_modules/.pnpm/@oxc-project+runtime@0.77.3/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
2243
+ //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
1905
2244
  function _usingCtx() {
1906
2245
  var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
1907
2246
  var n$1 = Error();
@@ -1969,7 +2308,7 @@ var Var = class Var {
1969
2308
  this.aval = aval;
1970
2309
  }
1971
2310
  toString() {
1972
- return `Var(${this.id}):${this.aval.strShort()}`;
2311
+ return `Var(${this.id}):${this.aval.toString()}`;
1973
2312
  }
1974
2313
  };
1975
2314
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
@@ -2009,7 +2348,7 @@ var VarPrinter = class {
2009
2348
  return name;
2010
2349
  }
2011
2350
  nameType(v) {
2012
- return `${this.name(v)}:${v.aval.strShort()}`;
2351
+ return `${this.name(v)}:${v.aval.toString()}`;
2013
2352
  }
2014
2353
  };
2015
2354
  /** A single statement / binding in a Jaxpr, in ANF form. */
@@ -2069,16 +2408,19 @@ var Jaxpr = class Jaxpr {
2069
2408
  varIds.set(v, FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
2070
2409
  return id;
2071
2410
  };
2072
- hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
2073
- hasher.update(this.eqns.length, ...this.eqns.flatMap((eqn) => [
2074
- eqn.primitive,
2075
- eqn.inputs.length,
2076
- ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
2077
- JSON.stringify(eqn.params),
2078
- eqn.outBinders.length,
2079
- ...eqn.outBinders.map(vi)
2080
- ]));
2081
- hasher.update(this.outs.length, ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value));
2411
+ hasher.update(this.inBinders.length);
2412
+ for (const x of this.inBinders) hasher.update(vi(x));
2413
+ hasher.update(this.eqns.length);
2414
+ for (const eqn of this.eqns) {
2415
+ hasher.update(eqn.primitive);
2416
+ hasher.update(eqn.inputs.length);
2417
+ for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
2418
+ hasher.update(JSON.stringify(eqn.params));
2419
+ hasher.update(eqn.outBinders.length);
2420
+ for (const x of eqn.outBinders) hasher.update(vi(x));
2421
+ }
2422
+ hasher.update(this.outs.length);
2423
+ for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
2082
2424
  return this.#hash = hasher.value;
2083
2425
  }
2084
2426
  hash(state) {
@@ -2115,7 +2457,7 @@ var Jaxpr = class Jaxpr {
2115
2457
  const c = eqn.outBinders[0];
2116
2458
  if (atomIsLit(b, 1)) context.set(c, a);
2117
2459
  else newEqns.push(eqn);
2118
- } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2460
+ } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2119
2461
  else newEqns.push(eqn);
2120
2462
  }
2121
2463
  const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
@@ -2164,8 +2506,8 @@ var JaxprType = class {
2164
2506
  this.outTypes = outTypes;
2165
2507
  }
2166
2508
  toString() {
2167
- const inTypes = this.inTypes.map((aval) => aval.strShort()).join(", ");
2168
- const outTypes = this.outTypes.map((aval) => aval.strShort()).join(", ");
2509
+ const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
2510
+ const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
2169
2511
  return `(${inTypes}) -> (${outTypes})`;
2170
2512
  }
2171
2513
  };
@@ -2244,7 +2586,7 @@ var JaxprTracer = class extends Tracer {
2244
2586
  this.aval = aval;
2245
2587
  }
2246
2588
  toString() {
2247
- return `JaxprTracer(${this.aval.strShort()})`;
2589
+ return `JaxprTracer(${this.aval.toString()})`;
2248
2590
  }
2249
2591
  get ref() {
2250
2592
  return this;
@@ -2381,8 +2723,11 @@ const abstractEvalRules = {
2381
2723
  },
2382
2724
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2383
2725
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2726
+ [Primitive.Asin]: vectorizedUnopAbstractEval,
2727
+ [Primitive.Atan]: vectorizedUnopAbstractEval,
2384
2728
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2385
2729
  [Primitive.Log]: vectorizedUnopAbstractEval,
2730
+ [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2386
2731
  [Primitive.Min]: binopAbstractEval,
2387
2732
  [Primitive.Max]: binopAbstractEval,
2388
2733
  [Primitive.Reduce]([x], { axis }) {
@@ -2390,6 +2735,15 @@ const abstractEvalRules = {
2390
2735
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2391
2736
  return [new ShapedArray(newShape, x.dtype)];
2392
2737
  },
2738
+ [Primitive.Pool]([x], { window, strides }) {
2739
+ const shape$1 = checkPoolShape(x.shape, window, strides);
2740
+ return [new ShapedArray(shape$1, x.dtype)];
2741
+ },
2742
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2743
+ const shape$1 = checkPoolShape(inShape, window, strides);
2744
+ if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2745
+ return [new ShapedArray(inShape, x.dtype)];
2746
+ },
2393
2747
  [Primitive.Dot]([x, y]) {
2394
2748
  if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2395
2749
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
@@ -2397,6 +2751,11 @@ const abstractEvalRules = {
2397
2751
  shape$1.splice(-1, 1);
2398
2752
  return [new ShapedArray(shape$1, x.dtype)];
2399
2753
  },
2754
+ [Primitive.Conv]([lhs, rhs], params) {
2755
+ if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2756
+ const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2757
+ return [new ShapedArray(shape$1, lhs.dtype)];
2758
+ },
2400
2759
  [Primitive.Compare]: compareAbstractEval,
2401
2760
  [Primitive.Where]([cond, x, y]) {
2402
2761
  if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
@@ -2444,15 +2803,34 @@ const abstractEvalRules = {
2444
2803
  return outTypes;
2445
2804
  }
2446
2805
  };
2447
- function makeJaxpr$1(f) {
2806
+ function splitIdx(values, argnums) {
2807
+ const a = [];
2808
+ const b = [];
2809
+ for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
2810
+ else b.push(values[i]);
2811
+ return [a, b];
2812
+ }
2813
+ function joinIdx(n, a, b, argnums) {
2814
+ const result = [];
2815
+ let ai = 0;
2816
+ let bi = 0;
2817
+ for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
2818
+ else result.push(b[bi++]);
2819
+ return result;
2820
+ }
2821
+ function makeJaxpr$1(f, opts) {
2448
2822
  return (...argsIn) => {
2449
2823
  try {
2450
2824
  var _usingCtx$1 = _usingCtx();
2451
- const [avalsIn, inTree] = flatten(argsIn);
2452
- const [fFlat, outTree] = flattenFun(f, inTree);
2825
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2826
+ const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
2827
+ const [avalsIn, inTree] = flatten(shapedArgs);
2828
+ const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
2829
+ return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
2830
+ }, inTree);
2453
2831
  const builder = new JaxprBuilder();
2454
2832
  const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
2455
- const _dynamic = _usingCtx$1.u(newDynamic(main));
2833
+ _usingCtx$1.u(newDynamic(main));
2456
2834
  const trace = new JaxprTrace(main);
2457
2835
  const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
2458
2836
  const outs = fFlat(...tracersIn);
@@ -2471,20 +2849,27 @@ function makeJaxpr$1(f) {
2471
2849
  }
2472
2850
  };
2473
2851
  }
2474
- function jit$1(f) {
2852
+ function jit$1(f, opts) {
2475
2853
  const cache = /* @__PURE__ */ new Map();
2476
- return ((...args) => {
2477
- const [argsFlat, inTree] = flatten(args);
2854
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2855
+ const result = ((...args) => {
2856
+ const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2857
+ const [argsFlat, inTree] = flatten(dynamicArgs);
2478
2858
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
2479
2859
  const avalsIn = unflatten(inTree, avalsInFlat);
2480
- const cacheKey = JSON.stringify(avalsIn);
2481
- const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f)(...avalsIn));
2860
+ const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
2861
+ const cacheKey = JSON.stringify(jaxprArgs);
2862
+ const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2482
2863
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2483
2864
  jaxpr,
2484
2865
  numConsts: consts.length
2485
2866
  });
2486
2867
  return unflatten(outTree, outs);
2487
2868
  });
2869
+ result.dispose = () => {
2870
+ for (const { consts } of cache.values()) for (const c of consts) c.dispose();
2871
+ };
2872
+ return result;
2488
2873
  }
2489
2874
 
2490
2875
  //#endregion
@@ -2515,7 +2900,7 @@ var JVPTrace = class extends Trace {
2515
2900
  return this.lift(pureArray(val));
2516
2901
  }
2517
2902
  lift(val) {
2518
- return new JVPTracer(this, val, zerosLike(val));
2903
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
2519
2904
  }
2520
2905
  processPrimitive(primitive, tracers, params) {
2521
2906
  const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2533,19 +2918,25 @@ function linearTangentsJvp(primitive) {
2533
2918
  return [ys, dys];
2534
2919
  };
2535
2920
  }
2921
+ /** Rule for product of gradients in bilinear operations. */
2922
+ function bilinearTangentsJvp(primitive) {
2923
+ return ([x, y], [dx, dy], params) => {
2924
+ const primal = bind1(primitive, [x.ref, y.ref], params);
2925
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
2926
+ return [[primal], [tangent]];
2927
+ };
2928
+ }
2536
2929
  /** Rule that zeros out any tangents. */
2537
2930
  function zeroTangentsJvp(primitive) {
2538
2931
  return (primals, tangents, params) => {
2539
2932
  for (const t of tangents) t.dispose();
2540
2933
  const ys = bind(primitive, primals, params);
2541
- return [ys, ys.map((y) => zerosLike(y))];
2934
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
2542
2935
  };
2543
2936
  }
2544
2937
  const jvpRules = {
2545
2938
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
2546
- [Primitive.Mul]([x, y], [dx, dy]) {
2547
- return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
2548
- },
2939
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
2549
2940
  [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
2550
2941
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
2551
2942
  [Primitive.Reciprocal]([x], [dx]) {
@@ -2558,13 +2949,13 @@ const jvpRules = {
2558
2949
  if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2559
2950
  else {
2560
2951
  dx.dispose();
2561
- return [[cast(x, dtype)], [zerosLike(x)]];
2952
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
2562
2953
  }
2563
2954
  },
2564
2955
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2565
2956
  if (x.dtype === dtype) return [[x], [dx]];
2566
2957
  dx.dispose();
2567
- return [[bitcast(x, dtype)], [zerosLike(x)]];
2958
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
2568
2959
  },
2569
2960
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2570
2961
  [Primitive.Sin]([x], [dx]) {
@@ -2573,6 +2964,14 @@ const jvpRules = {
2573
2964
  [Primitive.Cos]([x], [dx]) {
2574
2965
  return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
2575
2966
  },
2967
+ [Primitive.Asin]([x], [dx]) {
2968
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
2969
+ return [[asin$1(x)], [denom.mul(dx)]];
2970
+ },
2971
+ [Primitive.Atan]([x], [dx]) {
2972
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
2973
+ return [[atan$1(x)], [dx.div(denom)]];
2974
+ },
2576
2975
  [Primitive.Exp]([x], [dx]) {
2577
2976
  const z = exp$1(x);
2578
2977
  return [[z.ref], [z.mul(dx)]];
@@ -2580,6 +2979,10 @@ const jvpRules = {
2580
2979
  [Primitive.Log]([x], [dx]) {
2581
2980
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
2582
2981
  },
2982
+ [Primitive.Sqrt]([x], [dx]) {
2983
+ const z = sqrt$1(x);
2984
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
2985
+ },
2583
2986
  [Primitive.Min]([x, y], [dx, dy]) {
2584
2987
  return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
2585
2988
  },
@@ -2596,13 +2999,14 @@ const jvpRules = {
2596
2999
  const primal = reduce(x.ref, op, axis);
2597
3000
  const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
2598
3001
  const minCount = where$1(notMin.ref, 0, 1).sum(axis);
2599
- const tangent = where$1(notMin, op === AluOp.Min ? 0 : 0, dx).sum(axis).mul(reciprocal$1(minCount));
3002
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
2600
3003
  return [[primal], [tangent]];
2601
3004
  } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
2602
3005
  },
2603
- [Primitive.Dot]([x, y], [dx, dy]) {
2604
- return [[dot$1(x.ref, y.ref)], [dot$1(dx, y).add(dot$1(x, dy))]];
2605
- },
3006
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3007
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3008
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3009
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
2606
3010
  [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
2607
3011
  [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
2608
3012
  dcond.dispose();
@@ -2683,7 +3087,10 @@ function mappedAval(batchDim, aval) {
2683
3087
  /** Move one axis to a different index. */
2684
3088
  function moveaxis$1(x, src, dst) {
2685
3089
  const t = pureArray(x);
2686
- const perm = range(t.shape.length);
3090
+ src = checkAxis(src, t.ndim);
3091
+ dst = checkAxis(dst, t.ndim);
3092
+ if (src === dst) return t;
3093
+ const perm = range(t.ndim);
2687
3094
  perm.splice(src, 1);
2688
3095
  perm.splice(dst, 0, src);
2689
3096
  return transpose$1(t, perm);
@@ -2776,8 +3183,11 @@ const vmapRules = {
2776
3183
  [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
2777
3184
  [Primitive.Sin]: unopBatcher(sin$1),
2778
3185
  [Primitive.Cos]: unopBatcher(cos$1),
3186
+ [Primitive.Asin]: unopBatcher(asin$1),
3187
+ [Primitive.Atan]: unopBatcher(atan$1),
2779
3188
  [Primitive.Exp]: unopBatcher(exp$1),
2780
3189
  [Primitive.Log]: unopBatcher(log$1),
3190
+ [Primitive.Sqrt]: unopBatcher(sqrt$1),
2781
3191
  [Primitive.Min]: broadcastBatcher(min$1),
2782
3192
  [Primitive.Max]: broadcastBatcher(max$1),
2783
3193
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
@@ -2914,7 +3324,7 @@ var PartialVal = class PartialVal {
2914
3324
  return this.val !== null;
2915
3325
  }
2916
3326
  toString() {
2917
- return this.val ? this.val.toString() : this.aval.strShort();
3327
+ return this.val ? this.val.toString() : this.aval.toString();
2918
3328
  }
2919
3329
  };
2920
3330
  function partialEvalFlat(f, pvalsIn) {
@@ -2960,20 +3370,28 @@ function linearizeFlatUtil(f, primalsIn) {
2960
3370
  function linearizeFlat(f, primalsIn) {
2961
3371
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
2962
3372
  const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
2963
- return [primalsOut, fLin];
3373
+ const dispose$1 = () => {
3374
+ for (const c of consts) c.dispose();
3375
+ };
3376
+ return [
3377
+ primalsOut,
3378
+ fLin,
3379
+ dispose$1
3380
+ ];
2964
3381
  }
2965
3382
  function linearize$1(f, ...primalsIn) {
2966
3383
  const [primalsInFlat, inTree] = flatten(primalsIn);
2967
3384
  const [fFlat, outTree] = flattenFun(f, inTree);
2968
- const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3385
+ const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
2969
3386
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
2970
3387
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
2971
- const fLin = (...tangentsIn) => {
3388
+ const fLin = ((...tangentsIn) => {
2972
3389
  const [tangentsInFlat, inTree2] = flatten(tangentsIn);
2973
3390
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
2974
3391
  const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
2975
3392
  return unflatten(outTree.value, tangentsOutFlat);
2976
- };
3393
+ });
3394
+ fLin.dispose = dispose$1;
2977
3395
  return [primalsOut, fLin];
2978
3396
  }
2979
3397
  var PartialEvalTracer = class extends Tracer {
@@ -3089,7 +3507,10 @@ var PartialEvalTrace = class extends Trace {
3089
3507
  avalsOut: jaxpr2.outs.map((x) => x.aval),
3090
3508
  tracerRefsOut: []
3091
3509
  };
3092
- const outs2 = jaxpr2.outs.map((x) => new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe));
3510
+ const outs2 = jaxpr2.outs.map((x, i$1) => {
3511
+ if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
3512
+ return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
3513
+ });
3093
3514
  recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
3094
3515
  let i = 0;
3095
3516
  let j = 0;
@@ -3173,13 +3594,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3173
3594
  const [consts, constvars] = unzip2(constToVar.entries());
3174
3595
  const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
3175
3596
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
3176
- const jaxpr = new Jaxpr(inBinders, eqns, outVars);
3597
+ let jaxpr = new Jaxpr(inBinders, eqns, outVars);
3177
3598
  typecheckJaxpr(jaxpr);
3178
3599
  for (const t of consts) t.ref;
3179
3600
  for (const t of tracersIn) t.dispose();
3180
3601
  for (const t of tracersOut) t.dispose();
3602
+ jaxpr = jaxpr.simplify();
3603
+ if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3181
3604
  return {
3182
- jaxpr: jaxpr.simplify(),
3605
+ jaxpr,
3183
3606
  consts
3184
3607
  };
3185
3608
  }
@@ -3288,12 +3711,72 @@ const transposeRules = {
3288
3711
  if (op === AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
3289
3712
  else throw new NonlinearError(Primitive.Reduce);
3290
3713
  },
3714
+ [Primitive.Pool]([ct], [x], { window, strides }) {
3715
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
3716
+ return bind(Primitive.PoolTranspose, [ct], {
3717
+ inShape: x.aval.shape,
3718
+ window,
3719
+ strides
3720
+ });
3721
+ },
3722
+ [Primitive.PoolTranspose]([ct], [x], { window, strides }) {
3723
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
3724
+ return bind(Primitive.Pool, [ct], {
3725
+ window,
3726
+ strides
3727
+ });
3728
+ },
3291
3729
  [Primitive.Dot]([ct], [x, y]) {
3292
3730
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3293
3731
  const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3294
3732
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3295
3733
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3296
3734
  },
3735
+ [Primitive.Conv]([ct], [lhs, rhs], params) {
3736
+ if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
3737
+ const rev01 = [
3738
+ 1,
3739
+ 0,
3740
+ ...range(2, ct.ndim)
3741
+ ];
3742
+ if (lhs instanceof UndefPrimal) {
3743
+ let kernel = rhs;
3744
+ kernel = transpose$1(kernel, rev01);
3745
+ kernel = flip$1(kernel, range(2, kernel.ndim));
3746
+ const result = conv(ct, kernel, {
3747
+ strides: params.lhsDilation,
3748
+ padding: params.padding.map(([pl, _pr], i) => {
3749
+ const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3750
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3751
+ const padBefore = dilatedKernel - 1 - pl;
3752
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3753
+ const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
3754
+ return [padBefore, padAfter];
3755
+ }),
3756
+ lhsDilation: params.strides,
3757
+ rhsDilation: params.rhsDilation
3758
+ });
3759
+ return [result, null];
3760
+ } else {
3761
+ const newLhs = transpose$1(lhs, rev01);
3762
+ const newRhs = transpose$1(ct, rev01);
3763
+ let result = conv(newLhs, newRhs, {
3764
+ strides: params.rhsDilation,
3765
+ padding: params.padding.map(([pl, _pr], i) => {
3766
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3767
+ const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3768
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3769
+ const padFromLhs = dilatedCt - dilatedLhs;
3770
+ const padFromRhs = dilatedKernel - pl - 1;
3771
+ return [pl, padFromLhs + padFromRhs];
3772
+ }),
3773
+ lhsDilation: params.lhsDilation,
3774
+ rhsDilation: params.strides
3775
+ });
3776
+ result = transpose$1(result, rev01);
3777
+ return [null, result];
3778
+ }
3779
+ },
3297
3780
  [Primitive.Where]([ct], [cond, x, y]) {
3298
3781
  const cts = [
3299
3782
  null,
@@ -3385,20 +3868,28 @@ function vjpFlat(f, primalsIn) {
3385
3868
  const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
3386
3869
  return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
3387
3870
  };
3388
- return [primalsOut, fVjp];
3871
+ const dispose$1 = () => {
3872
+ for (const c of consts) c.dispose();
3873
+ };
3874
+ return [
3875
+ primalsOut,
3876
+ fVjp,
3877
+ dispose$1
3878
+ ];
3389
3879
  }
3390
3880
  function vjp$1(f, ...primalsIn) {
3391
3881
  const [primalsInFlat, inTree] = flatten(primalsIn);
3392
3882
  const [fFlat, outTree] = flattenFun(f, inTree);
3393
- const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3883
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3394
3884
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
3395
3885
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3396
- const fVjp = (cotangentsOut) => {
3886
+ const fVjp = ((cotangentsOut) => {
3397
3887
  const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
3398
3888
  if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
3399
3889
  const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
3400
3890
  return unflatten(inTree, cotangentsInFlat);
3401
- };
3891
+ });
3892
+ fVjp.dispose = dispose$1;
3402
3893
  return [primalsOut, fVjp];
3403
3894
  }
3404
3895
  function grad$1(f) {
@@ -3414,9 +3905,10 @@ function valueAndGrad$1(f) {
3414
3905
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
3415
3906
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3416
3907
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3417
- if (y.dtype !== DType.Float32) throw new TypeError("grad currently only supports float32");
3418
- const [ct, ...rest] = fVjp(pureArray(1));
3419
- for (const r of rest) r.dispose();
3908
+ if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3909
+ const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3910
+ for (const r of rest) dispose(r);
3911
+ fVjp.dispose();
3420
3912
  return [y, ct];
3421
3913
  };
3422
3914
  }
@@ -3424,11 +3916,84 @@ function jacrev$1(f) {
3424
3916
  return function jacobianReverse(x) {
3425
3917
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
3426
3918
  const [size$1] = x.shape;
3427
- const pullback = (ct) => vjp$1(f, x)[1](ct)[0];
3919
+ const pullback = (ct) => {
3920
+ const [y, fVjp] = vjp$1(f, x);
3921
+ y.dispose();
3922
+ const [ret] = fVjp(ct);
3923
+ fVjp.dispose();
3924
+ return ret;
3925
+ };
3428
3926
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
3429
3927
  };
3430
3928
  }
3431
3929
 
3930
+ //#endregion
3931
+ //#region src/lax.ts
3932
+ var lax_exports = {};
3933
+ __export(lax_exports, {
3934
+ conv: () => conv$1,
3935
+ convGeneralDilated: () => convGeneralDilated,
3936
+ convWithGeneralPadding: () => convWithGeneralPadding,
3937
+ reduceWindow: () => reduceWindow
3938
+ });
3939
/**
 * Convert a string padding mode into explicit `[low, high]` pads, one pair per
 * spatial dimension of `inShape`.
 *
 * - `"VALID"`: no padding.
 * - `"SAME"` / `"SAME_LOWER"`: pad so the output extent is
 *   `ceil(in / stride)`; an odd total puts the extra element on the high side
 *   for `"SAME"` and on the low side for `"SAME_LOWER"`.
 *
 * The mode is matched case-insensitively.
 *
 * @param {number[]} inShape Spatial input extents.
 * @param {number[]} filterShape Spatial filter extents.
 * @param {number[]} strides Window stride per spatial dimension.
 * @param {number[]} dilation Filter dilation per spatial dimension.
 * @param {string} padding Padding mode name.
 * @returns {Array<[number, number]>} Fresh `[low, high]` pairs.
 * @throws {Error} On an unknown mode, or (for SAME modes) when the per-dimension
 *   arrays do not all have `inShape.length` entries.
 */
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
  const padType = padding.toUpperCase();
  switch (padType) {
    case "VALID":
      // One fresh pair per dimension so callers never share a mutable [0, 0].
      return inShape.map(() => [0, 0]);
    case "SAME":
    case "SAME_LOWER": {
      // A length mismatch would otherwise index past an array and yield NaN pads.
      const perDim = [
        ["filterShape", filterShape],
        ["strides", strides],
        ["dilation", dilation]
      ];
      for (const [label, values] of perDim) {
        if (values.length !== inShape.length) throw new Error(`padtypeToPads: ${label} has ${values.length} entries, expected ${inShape.length}`);
      }
      return inShape.map((extent, i) => {
        const outExtent = Math.ceil(extent / strides[i]);
        const total = Math.max(0, (outExtent - 1) * strides[i] + 1 + (filterShape[i] - 1) * dilation[i] - extent);
        const smaller = total >> 1;
        const larger = total - smaller;
        return padType === "SAME" ? [smaller, larger] : [larger, smaller];
      });
    }
    default: throw new Error(`Unknown padding type: ${padType}`);
  }
}
3953
/**
 * N-dimensional convolution with optional input (`lhsDilation`) and kernel
 * (`rhsDilation`) dilation.
 *
 * Modeled on `jax.lax.conv_general_dilated` from JAX. Grouped convolutions
 * are not supported right now.
 *
 * `padding` is either explicit per-dimension pads or a padding-mode string;
 * a string is resolved with `padtypeToPads` over the dimensions after the
 * first two, and is rejected when any `lhsDilation` entry differs from 1.
 *
 * @throws {Error} If `lhs` or `rhs` has fewer than 2 dimensions.
 */
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
  if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
  let pads = padding;
  if (typeof pads === "string") {
    const dilatesInput = lhsDilation?.some((d) => d !== 1);
    if (dilatesInput) throw new Error("String padding is not supported for transposed convolutions");
    const inSpatial = lhs.shape.slice(2);
    const kernelSpatial = rhs.shape.slice(2);
    const kernelDilation = rhsDilation ?? rep(rhs.ndim - 2, 1);
    pads = padtypeToPads(inSpatial, kernelSpatial, windowStrides, kernelDilation, pads);
  }
  const options = {
    strides: windowStrides,
    padding: pads,
    lhsDilation,
    rhsDilation
  };
  return conv(lhs, rhs, options);
}
3975
/** Positional-argument form of `convGeneralDilated`: the two dilations are passed as plain parameters instead of an options object. */
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
  const dilations = { lhsDilation, rhsDilation };
  return convGeneralDilated(lhs, rhs, windowStrides, padding, dilations);
}
3982
/** Shorthand for `convGeneralDilated` called without dilation options. */
function conv$1(lhs, rhs, windowStrides, padding) {
  const result = convGeneralDilated(lhs, rhs, windowStrides, padding);
  return result;
}
3986
/**
 * Apply `computation` to every window of `operand`.
 *
 * The operand is passed through the Pool primitive with the given window and
 * strides (strides default to 1 per window dimension). `computation` is
 * wrapped in `vmap` over axis 0 once per operand dimension before being
 * applied to the pooled result.
 *
 * @throws {Error} If the window has more dimensions than the operand.
 */
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
  if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
  const strides = windowStrides ? windowStrides : rep(windowDimensions.length, 1);
  let batched = computation;
  for (let remaining = operand.ndim; remaining > 0; remaining--) batched = vmap$1(batched, 0);
  const pooled = bind1(Primitive.Pool, [operand], {
    window: windowDimensions,
    strides
  });
  return batched(pooled);
}
3996
+
3432
3997
  //#endregion
3433
3998
  //#region src/numpy.ts
3434
3999
  var numpy_exports = {};
@@ -3437,19 +4002,38 @@ __export(numpy_exports, {
3437
4002
  DType: () => DType,
3438
4003
  abs: () => abs,
3439
4004
  absolute: () => absolute,
4005
+ acos: () => acos,
4006
+ acosh: () => acosh,
3440
4007
  add: () => add,
3441
4008
  allclose: () => allclose,
3442
4009
  arange: () => arange,
4010
+ arccos: () => arccos,
4011
+ arccosh: () => arccosh,
4012
+ arcsinh: () => arcsinh,
4013
+ arctan: () => arctan,
4014
+ arctan2: () => arctan2,
4015
+ arctanh: () => arctanh,
3443
4016
  argmax: () => argmax,
3444
4017
  argmin: () => argmin,
3445
4018
  array: () => array,
4019
+ asin: () => asin,
4020
+ asinh: () => asinh,
3446
4021
  astype: () => astype,
4022
+ atan: () => atan,
4023
+ atan2: () => atan2,
4024
+ atanh: () => atanh,
3447
4025
  bool: () => bool,
4026
+ broadcastArrays: () => broadcastArrays,
4027
+ broadcastShapes: () => broadcastShapes,
4028
+ broadcastTo: () => broadcastTo,
4029
+ cbrt: () => cbrt,
3448
4030
  clip: () => clip,
3449
4031
  columnStack: () => columnStack,
3450
- complex64: () => complex64,
3451
4032
  concatenate: () => concatenate,
3452
4033
  cos: () => cos,
4034
+ cosh: () => cosh,
4035
+ deg2rad: () => deg2rad,
4036
+ degrees: () => degrees,
3453
4037
  diag: () => diag,
3454
4038
  diagonal: () => diagonal,
3455
4039
  divide: () => divide,
@@ -3460,23 +4044,29 @@ __export(numpy_exports, {
3460
4044
  eulerGamma: () => eulerGamma,
3461
4045
  exp: () => exp,
3462
4046
  exp2: () => exp2,
4047
+ expm1: () => expm1,
3463
4048
  eye: () => eye,
3464
4049
  flip: () => flip,
3465
4050
  fliplr: () => fliplr,
3466
4051
  flipud: () => flipud,
4052
+ float16: () => float16,
3467
4053
  float32: () => float32,
3468
4054
  full: () => full,
4055
+ fullLike: () => fullLike$1,
3469
4056
  greater: () => greater,
3470
4057
  greaterEqual: () => greaterEqual,
3471
4058
  hstack: () => hstack,
4059
+ hypot: () => hypot,
3472
4060
  identity: () => identity$1,
3473
4061
  inf: () => inf,
4062
+ inner: () => inner,
3474
4063
  int32: () => int32,
3475
4064
  less: () => less,
3476
4065
  lessEqual: () => lessEqual,
3477
4066
  linspace: () => linspace,
3478
4067
  log: () => log,
3479
4068
  log10: () => log10,
4069
+ log1p: () => log1p,
3480
4070
  log2: () => log2,
3481
4071
  matmul: () => matmul,
3482
4072
  max: () => max,
@@ -3492,36 +4082,55 @@ __export(numpy_exports, {
3492
4082
  negative: () => negative,
3493
4083
  notEqual: () => notEqual,
3494
4084
  ones: () => ones,
4085
+ onesLike: () => onesLike,
4086
+ outer: () => outer,
3495
4087
  pad: () => pad,
3496
4088
  permuteDims: () => permuteDims,
3497
4089
  pi: () => pi,
4090
+ pow: () => pow,
4091
+ power: () => power,
3498
4092
  prod: () => prod$1,
4093
+ promoteTypes: () => promoteTypes,
4094
+ rad2deg: () => rad2deg,
4095
+ radians: () => radians,
3499
4096
  ravel: () => ravel,
3500
4097
  reciprocal: () => reciprocal,
4098
+ repeat: () => repeat,
3501
4099
  reshape: () => reshape,
3502
- scalar: () => scalar,
3503
4100
  shape: () => shape,
4101
+ sign: () => sign,
3504
4102
  sin: () => sin,
4103
+ sinh: () => sinh,
3505
4104
  size: () => size,
4105
+ sqrt: () => sqrt,
3506
4106
  square: () => square,
3507
4107
  stack: () => stack,
4108
+ std: () => std,
4109
+ subtract: () => subtract,
3508
4110
  sum: () => sum,
3509
4111
  tan: () => tan,
4112
+ tanh: () => tanh,
4113
+ tile: () => tile,
3510
4114
  transpose: () => transpose,
4115
+ tri: () => tri,
4116
+ tril: () => tril,
4117
+ triu: () => triu,
3511
4118
  trueDivide: () => trueDivide,
3512
4119
  trunc: () => trunc,
3513
4120
  uint32: () => uint32,
4121
+ var_: () => var_,
3514
4122
  vdot: () => vdot,
3515
4123
  vecdot: () => vecdot,
3516
4124
  vstack: () => vstack,
3517
4125
  where: () => where,
3518
- zeros: () => zeros
4126
+ zeros: () => zeros,
4127
+ zerosLike: () => zerosLike
3519
4128
  });
3520
4129
  const float32 = DType.Float32;
3521
4130
  const int32 = DType.Int32;
3522
4131
  const uint32 = DType.Uint32;
3523
4132
  const bool = DType.Bool;
3524
- const complex64 = DType.Complex64;
4133
+ const float16 = DType.Float16;
3525
4134
  /** Euler's constant, `e = 2.7182818284590...` */
3526
4135
  const e = Math.E;
3527
4136
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -3532,52 +4141,66 @@ const inf = Number.POSITIVE_INFINITY;
3532
4141
  const nan = NaN;
3533
4142
  /** This is Pi, `π = 3.14159265358979...` */
3534
4143
  const pi = Math.PI;
3535
- /** Element-wise addition, with broadcasting. */
4144
+ /** @function Element-wise addition, with broadcasting. */
3536
4145
  const add = add$1;
3537
- /** Element-wise multiplication, with broadcasting. */
4146
+ /** @function Element-wise multiplication, with broadcasting. */
3538
4147
  const multiply = mul;
3539
- /** Numerical negative of every element of an array. */
4148
+ /** @function Numerical negative of every element of an array. */
3540
4149
  const negative = neg;
3541
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
4150
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
3542
4151
  const reciprocal = reciprocal$1;
3543
- /** Element-wise sine function (takes radians). */
4152
+ /** @function Element-wise sine function (takes radians). */
3544
4153
  const sin = sin$1;
3545
- /** Element-wise cosine function (takes radians). */
4154
+ /** @function Element-wise cosine function (takes radians). */
3546
4155
  const cos = cos$1;
3547
- /** Calculate the exponential of all elements in the input array. */
4156
+ /** @function Element-wise inverse sine function (inverse of sin). */
4157
+ const asin = asin$1;
4158
+ /** @function Element-wise inverse tangent function (inverse of tan). */
4159
+ const atan = atan$1;
4160
+ /** @function Calculate the exponential of all elements in the input array. */
3548
4161
  const exp = exp$1;
3549
- /** Calculate the natural logarithm of all elements in the input array. */
4162
+ /** @function Calculate the natural logarithm of all elements in the input array. */
3550
4163
  const log = log$1;
3551
- /** Return element-wise minimum of the input arrays. */
4164
+ /** @function Calculate the square root of all elements in the input array. */
4165
+ const sqrt = sqrt$1;
4166
+ /** @function Return element-wise minimum of the input arrays. */
3552
4167
  const minimum = min$1;
3553
- /** Return element-wise maximum of the input arrays. */
4168
+ /** @function Return element-wise maximum of the input arrays. */
3554
4169
  const maximum = max$1;
3555
- /** Compare two arrays element-wise. */
4170
+ /** @function Compare two arrays element-wise. */
3556
4171
  const greater = greater$1;
3557
- /** Compare two arrays element-wise. */
4172
+ /** @function Compare two arrays element-wise. */
3558
4173
  const less = less$1;
3559
- /** Compare two arrays element-wise. */
4174
+ /** @function Compare two arrays element-wise. */
3560
4175
  const equal = equal$1;
3561
- /** Compare two arrays element-wise. */
4176
+ /** @function Compare two arrays element-wise. */
3562
4177
  const notEqual = notEqual$1;
3563
- /** Compare two arrays element-wise. */
4178
+ /** @function Compare two arrays element-wise. */
3564
4179
  const greaterEqual = greaterEqual$1;
3565
- /** Compare two arrays element-wise. */
4180
+ /** @function Compare two arrays element-wise. */
3566
4181
  const lessEqual = lessEqual$1;
3567
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4182
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
3568
4183
  const where = where$1;
3569
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
4184
+ /**
4185
+ * @function
4186
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
4187
+ */
3570
4188
  const transpose = transpose$1;
3571
4189
  /**
4190
+ * @function
3572
4191
  * Give a new shape to an array without changing its data.
3573
4192
  *
3574
4193
  * One shape dimension can be -1. In this case, the value is inferred from the
3575
4194
  * length of the array and remaining dimensions.
3576
4195
  */
3577
4196
  const reshape = reshape$1;
3578
- /** Move axes of an array to new positions. Other axes retain original order. */
4197
+ /**
4198
+ * @function
4199
+ * Move axes of an array to new positions. Other axes retain original order.
4200
+ */
3579
4201
  const moveaxis = moveaxis$1;
3580
4202
  /**
4203
+ * @function
3581
4204
  * Add padding (zeros) to an array.
3582
4205
  *
3583
4206
  * The `width` argument is either an integer or pair of integers, in which case
@@ -3585,11 +4208,29 @@ const moveaxis = moveaxis$1;
3585
4208
  * pair specifies the padding for its corresponding axis.
3586
4209
  */
3587
4210
  const pad = pad$1;
3588
- /** Return the number of dimensions of an array. Does not consume array reference. */
4211
+ /**
4212
+ * @function
4213
+ * Return the number of dimensions of an array. Does not consume array reference.
4214
+ */
3589
4215
  const ndim = ndim$1;
3590
- /** Return the shape of an array. Does not consume array reference. */
4216
+ /** @function Return the shape of an array. Does not consume array reference. */
3591
4217
  const shape = getShape;
3592
4218
  /**
4219
+ * @function
4220
+ * Return an array of zeros with the same shape and type as a given array.
4221
+ */
4222
+ const zerosLike = zerosLike$1;
4223
+ /**
4224
+ * @function
4225
+ * Return an array of ones with the same shape and type as a given array.
4226
+ */
4227
+ const onesLike = onesLike$1;
4228
+ /**
4229
+ * @function
4230
+ * Return a full array with the same shape and type as a given array.
4231
+ */
4232
+ const fullLike$1 = fullLike;
4233
+ /**
3593
4234
  * Return the number of elements in an array, optionally along an axis.
3594
4235
  * Does not consume array reference.
3595
4236
  */
@@ -3602,23 +4243,23 @@ function astype(a, dtype) {
3602
4243
  return fudgeArray(a).astype(dtype);
3603
4244
  }
3604
4245
  /** Sum of the elements of the array over a given axis, or axes. */
3605
- function sum(a, axis, opts) {
4246
+ function sum(a, axis = null, opts) {
3606
4247
  return reduce(a, AluOp.Add, axis, opts);
3607
4248
  }
3608
4249
  /** Product of the array elements over a given axis. */
3609
- function prod$1(a, axis, opts) {
4250
+ function prod$1(a, axis = null, opts) {
3610
4251
  return reduce(a, AluOp.Mul, axis, opts);
3611
4252
  }
3612
4253
  /** Return the minimum of array elements along a given axis. */
3613
- function min(a, axis, opts) {
4254
+ function min(a, axis = null, opts) {
3614
4255
  return reduce(a, AluOp.Min, axis, opts);
3615
4256
  }
3616
4257
  /** Return the maximum of array elements along a given axis. */
3617
- function max(a, axis, opts) {
4258
+ function max(a, axis = null, opts) {
3618
4259
  return reduce(a, AluOp.Max, axis, opts);
3619
4260
  }
3620
4261
  /** Compute the average of the array elements along the specified axis. */
3621
- function mean(a, axis, opts) {
4262
+ function mean(a, axis = null, opts) {
3622
4263
  return fudgeArray(a).mean(axis, opts);
3623
4264
  }
3624
4265
  /**
@@ -3634,18 +4275,12 @@ function argmin(a, axis, opts) {
3634
4275
  axis = 0;
3635
4276
  } else axis = checkAxis(axis, a.ndim);
3636
4277
  const shape$1 = a.shape;
3637
- const isMax = equal(a, min(a.ref, axis, { keepDims: true }));
4278
+ const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
3638
4279
  const length = scalar(shape$1[axis], {
3639
4280
  dtype: int32,
3640
4281
  device: a.device
3641
4282
  });
3642
- const idx = where(isMax, scalar(1, {
3643
- dtype: int32,
3644
- device: a.device
3645
- }), scalar(0, {
3646
- dtype: int32,
3647
- device: a.device
3648
- })).mul(arange(shape$1[axis], 0, -1, {
4283
+ const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3649
4284
  dtype: int32,
3650
4285
  device: a.device
3651
4286
  }).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
@@ -3664,35 +4299,21 @@ function argmax(a, axis, opts) {
3664
4299
  axis = 0;
3665
4300
  } else axis = checkAxis(axis, a.ndim);
3666
4301
  const shape$1 = a.shape;
3667
- const isMax = equal(a, max(a.ref, axis, { keepDims: true }));
4302
+ const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
3668
4303
  const length = scalar(shape$1[axis], {
3669
4304
  dtype: int32,
3670
4305
  device: a.device
3671
4306
  });
3672
- const idx = where(isMax, scalar(1, {
3673
- dtype: int32,
3674
- device: a.device
3675
- }), scalar(0, {
3676
- dtype: int32,
3677
- device: a.device
3678
- })).mul(arange(shape$1[axis], 0, -1, {
4307
+ const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3679
4308
  dtype: int32,
3680
4309
  device: a.device
3681
4310
  }).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
3682
4311
  return length.sub(max(idx, axis, opts));
3683
4312
  }
3684
4313
  /** Reverse the elements in an array along the given axes. */
3685
- function flip(x, axis) {
4314
+ function flip(x, axis = null) {
3686
4315
  const nd = ndim(x);
3687
- if (axis === void 0) axis = range(nd);
3688
- else if (typeof axis === "number") axis = [axis];
3689
- const seen = /* @__PURE__ */ new Set();
3690
- for (let i = 0; i < axis.length; i++) {
3691
- if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
3692
- if (axis[i] < 0) axis[i] += nd;
3693
- if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
3694
- seen.add(axis[i]);
3695
- }
4316
+ axis = normalizeAxis(axis, nd);
3696
4317
  return flip$1(x, axis);
3697
4318
  }
3698
4319
  /**
@@ -3798,18 +4419,88 @@ function flipud(x) {
3798
4419
  function fliplr(x) {
3799
4420
  return flip(x, 1);
3800
4421
  }
4422
+ /** @function Alternative name for `numpy.transpose()`. */
3801
4423
  const permuteDims = transpose;
3802
4424
  /** Return a 1-D flattened array containing the elements of the input. */
3803
4425
  function ravel(a) {
3804
4426
  return fudgeArray(a).ravel();
3805
4427
  }
3806
4428
  /**
4429
+ * Repeat each element of an array after themselves.
4430
+ *
4431
+ * If no axis is provided, use the flattened input array, and return a flat
4432
+ * output array.
4433
+ */
4434
+ function repeat(a, repeats, axis) {
4435
+ if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
4436
+ a = fudgeArray(a);
4437
+ if (axis === void 0) {
4438
+ a = ravel(a);
4439
+ axis = 0;
4440
+ }
4441
+ axis = checkAxis(axis, a.ndim);
4442
+ if (repeats === 1) return a;
4443
+ const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
4444
+ const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
4445
+ return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
4446
+ }
4447
/**
 * Construct an array by repeating `a` the number of times given by `reps`.
 *
 * A numeric `reps` is treated as a one-element list. When `reps` and the
 * array have different numbers of dimensions, the shorter one is padded with
 * leading 1s. The result has shape `reps[i] * a.shape[i]` along each axis.
 *
 * @param a Array-like input.
 * @param {number | number[]} reps Non-negative integer repetition counts.
 * @throws {Error} If any entry of `reps` is not a non-negative integer.
 */
function tile(a, reps) {
  a = fudgeArray(a);
  if (typeof reps === "number") reps = [reps];
  if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
  const ndiff = reps.length - a.ndim;
  if (ndiff > 0) a = a.reshape([...rep(ndiff, 1), ...a.shape]);
  if (ndiff < 0) reps = [...rep(-ndiff, 1), ...reps];
  const broadcastedShape = [];
  const broadcastAxes = [];
  for (let i = 0; i < a.ndim; i++) {
    // Only a count of exactly 1 can skip the extra axis. A count of 0 must
    // still add a zero-length axis, or the element count of the broadcasted
    // array would not match the (empty) final shape in the reshape below.
    if (reps[i] !== 1) {
      broadcastedShape.push(reps[i]);
      broadcastAxes.push(broadcastedShape.length - 1);
    }
    broadcastedShape.push(a.shape[i]);
  }
  const finalShape = a.shape.map((d, i) => reps[i] * d);
  return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
}
4473
/**
 * Broadcast an array to a target shape using NumPy-style broadcasting rules.
 *
 * New axes may be added on the left, and dimensions of size 1 may be
 * expanded.
 *
 * @throws {Error} If the target shape has fewer dimensions than the input.
 */
function broadcastTo(a, shape$1) {
  const nd = ndim(a);
  const addedDims = shape$1.length - nd;
  if (addedDims < 0) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
  const leadingAxes = range(addedDims);
  return broadcast(a, shape$1, leadingAxes);
}
4484
/**
 * Broadcast input shapes to a common output shape.
 *
 * Shapes are combined pairwise, left to right, with `generalBroadcast`.
 * Returns `[]` when called without arguments, and a copy of the shape when
 * called with exactly one.
 *
 * @param {...number[]} shapes Shapes to combine.
 * @returns {number[]} The combined shape.
 */
function broadcastShapes(...shapes) {
  if (shapes.length === 0) return [];
  // Copy so the caller's single input is never handed back aliased.
  if (shapes.length === 1) return [...shapes[0]];
  // Use an explicit binary callback: reduce() also passes the index and the
  // whole array, which must not reach generalBroadcast as extra arguments.
  return shapes.reduce((combined, next) => generalBroadcast(combined, next));
}
4489
/** Broadcast every input array to the single shape obtained by combining all of their shapes. */
function broadcastArrays(...arrays) {
  const inputShapes = [];
  for (const arr of arrays) inputShapes.push(shape(arr));
  const commonShape = broadcastShapes(...inputShapes);
  const results = [];
  for (const arr of arrays) results.push(broadcastTo(arr, commonShape));
  return results;
}
4495
+ /**
3807
4496
  * Return specified diagonals.
3808
4497
  *
3809
4498
  * If a is 2D, return the diagonal of the array with the given offset. If a is
3810
- * 3D or higher, compute diagonals along the two given axes.
4499
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
3811
4500
  *
3812
- * This returns a view over the existing array.
4501
+ * This returns a view over the existing array. The shape of the resulting array
4502
+ * is determined by removing the two axes along which the diagonal is taken,
4503
+ * then appending a new axis to the right holding the diagonals.
3813
4504
  */
3814
4505
  function diagonal(a, offset, axis1, axis2) {
3815
4506
  return fudgeArray(a).diagonal(offset, axis1, axis2);
@@ -3825,15 +4516,16 @@ function diag(v, k = 0) {
3825
4516
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
3826
4517
  if (a.ndim === 1) {
3827
4518
  const n = a.shape[0];
3828
- const ret = where(eye(n).equal(1), a, 0);
3829
- if (k !== 0) throw new Error("diag() for 1D arrays only for k=0");
3830
- return ret;
4519
+ const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
4520
+ if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4521
+ else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4522
+ else return ret;
3831
4523
  } else if (a.ndim === 2) return diagonal(a, k);
3832
4524
  else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
3833
4525
  }
3834
4526
  /** Return if two arrays are element-wise equal within a tolerance. */
3835
4527
  function allclose(actual, expected, options) {
3836
- const { rtol = 1e-5, atol = 1e-8 } = options ?? {};
4528
+ const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
3837
4529
  const x = array(actual);
3838
4530
  const y = array(expected);
3839
4531
  if (!deepEqual(x.shape, y.shape)) return false;
@@ -3868,8 +4560,36 @@ function dot(x, y) {
3868
4560
  ]);
3869
4561
  return dot$1(x, y);
3870
4562
  }
3871
- /** Vector dot product of two arrays. */
3872
- function vecdot(x, y) {
4563
+ /**
4564
+ * Compute the inner product of two arrays.
4565
+ *
4566
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
4567
+ * contraction on the last axis.
4568
+ *
4569
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
4570
+ */
4571
+ function inner(x, y) {
4572
+ x = reshape(x, shape(x).toSpliced(-1, 0, ...rep(ndim(y) - 1, 1)));
4573
+ return dot$1(x, y);
4574
+ }
4575
+ /**
4576
+ * Compute the outer product of two arrays.
4577
+ *
4578
+ * If the input arrays are not 1D, they will be flattened. Returned array will
4579
+ * be of shape `[x.size, y.size]`.
4580
+ */
4581
+ function outer(x, y) {
4582
+ x = ravel(x);
4583
+ y = ravel(y);
4584
+ return multiply(x.reshape([x.shape[0], 1]), y);
4585
+ }
4586
+ /** Vector dot product of two arrays along a given axis. */
4587
+ function vecdot(x, y, { axis } = {}) {
4588
+ const xaxis = checkAxis(axis ?? -1, ndim(x));
4589
+ const yaxis = checkAxis(axis ?? -1, ndim(y));
4590
+ if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
4591
+ x = moveaxis(x, xaxis, -1);
4592
+ y = moveaxis(y, yaxis, -1);
3873
4593
  return dot$1(x, y);
3874
4594
  }
3875
4595
  /**
@@ -3878,7 +4598,7 @@ function vecdot(x, y) {
3878
4598
  * Like vecdot() but flattens the arguments first into vectors.
3879
4599
  */
3880
4600
  function vdot(x, y) {
3881
- return vecdot(ravel(x), ravel(y));
4601
+ return dot$1(ravel(x), ravel(y));
3882
4602
  }
3883
4603
  /**
3884
4604
  * Return a tuple of coordinate matrices from coordinate vectors.
@@ -3907,6 +4627,43 @@ function meshgrid(xs, { indexing } = {}) {
3907
4627
  return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
3908
4628
  }
3909
4629
  /**
4630
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
4631
+ *
4632
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
4633
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
4634
+ * `k>0` is above it.
4635
+ */
4636
+ function tri(n, m, k = 0, { dtype, device } = {}) {
4637
+ m ??= n;
4638
+ dtype ??= DType.Float32;
4639
+ if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
4640
+ if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
4641
+ if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
4642
+ const rows = arange(k, n + k, 1, {
4643
+ dtype: DType.Int32,
4644
+ device
4645
+ });
4646
+ const cols = arange(0, m, 1, {
4647
+ dtype: DType.Int32,
4648
+ device
4649
+ });
4650
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
4651
+ }
4652
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
4653
+ function tril(a, k = 0) {
4654
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4655
+ a = fudgeArray(a);
4656
+ const [n, m] = a.shape.slice(-2);
4657
+ return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
4658
+ }
4659
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
4660
+ function triu(a, k = 0) {
4661
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4662
+ a = fudgeArray(a);
4663
+ const [n, m] = a.shape.slice(-2);
4664
+ return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
4665
+ }
4666
+ /**
3910
4667
  * Clip (limit) the values in an array.
3911
4668
  *
3912
4669
  * Given an interval, values outside the interval are clipped to the interval
@@ -3930,18 +4687,70 @@ function absolute(x) {
3930
4687
  x = fudgeArray(x);
3931
4688
  return where(less(x.ref, 0), x.ref.mul(-1), x);
3932
4689
  }
3933
- /** Alias of `jax.numpy.absolute()`. */
4690
+ /** @function Alias of `jax.numpy.absolute()`. */
3934
4691
  const abs = absolute;
4692
+ /** Return an element-wise indication of sign of the input. */
4693
+ function sign(x) {
4694
+ x = fudgeArray(x);
4695
+ return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4696
+ }
3935
4697
  /** Calculate element-wise square of the input array. */
3936
4698
  function square(x) {
3937
4699
  x = fudgeArray(x);
3938
4700
  return x.ref.mul(x);
3939
4701
  }
3940
- /** Compute a trigonometric tangent of each element of input. */
4702
+ /** Element-wise tangent function (takes radians). */
3941
4703
  function tan(x) {
3942
4704
  x = fudgeArray(x);
3943
4705
  return sin(x.ref).div(cos(x));
3944
4706
  }
4707
+ /** Element-wise inverse cosine function (inverse of cos). */
4708
+ function acos(x) {
4709
+ return subtract(pi / 2, asin(x));
4710
+ }
4711
+ /**
4712
+ * @function
4713
+ * Return element-wise hypotenuse for the given legs of a right triangle.
4714
+ *
4715
+ * In the original NumPy/JAX implementation, this function is more numerically
4716
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4717
+ * improvements.
4718
+ */
4719
+ const hypot = jit$1((x1, x2) => {
4720
+ return sqrt(square(x1).add(square(x2)));
4721
+ });
4722
+ /**
4723
+ * @function
4724
+ * Element-wise arc tangent of y/x with correct quadrant.
4725
+ *
4726
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
4727
+ * The result is in the range [-π, π].
4728
+ *
4729
+ * Uses numerically stable formulas:
4730
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
4731
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
4732
+ *
4733
+ * The output is ill-defined when both x and y are zero.
4734
+ */
4735
+ const atan2 = jit$1((y, x) => {
4736
+ const r = sqrt(square(x.ref).add(square(y.ref)));
4737
+ const xNeg = less(x.ref, 0);
4738
+ const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
4739
+ const denom = where(xNeg, y, r.add(x));
4740
+ return atan(numer.div(denom)).mul(2);
4741
+ });
4742
+ /** @function Alias of `jax.numpy.acos()`. */
4743
+ const arccos = acos;
4744
+ /** @function Alias of `jax.numpy.atan()`. */
4745
+ const arctan = atan;
4746
+ /** @function Alias of `jax.numpy.atan2()`. */
4747
+ const arctan2 = atan2;
4748
+ /** Element-wise subtraction, with broadcasting. */
4749
+ function subtract(x, y) {
4750
+ x = fudgeArray(x);
4751
+ y = fudgeArray(y);
4752
+ return x.sub(y);
4753
+ }
3945
4754
  /** Calculates the floating-point division of x by y element-wise. */
3946
4755
  function trueDivide(x, y) {
3947
4756
  x = fudgeArray(x);
@@ -3949,7 +4758,7 @@ function trueDivide(x, y) {
3949
4758
  if (!isFloatDtype(x.dtype) || !isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
3950
4759
  return x.div(y);
3951
4760
  }
3952
- /** Alias of `jax.numpy.trueDivide()`. */
4761
+ /** @function Alias of `jax.numpy.trueDivide()`. */
3953
4762
  const divide = trueDivide;
3954
4763
  /** Round input to the nearest integer towards zero. */
3955
4764
  function trunc(x) {
@@ -3967,15 +4776,151 @@ function log2(x) {
3967
4776
  function log10(x) {
3968
4777
  return log(x).mul(Math.LOG10E);
3969
4778
  }
4779
+ /** Calculate `exp(x) - 1` element-wise. */
4780
+ function expm1(x) {
4781
+ return exp(x).sub(1);
4782
+ }
4783
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
4784
+ function log1p(x) {
4785
+ return log(add(1, x));
4786
+ }
4787
+ /** Convert angles from degrees to radians. */
4788
+ function deg2rad(x) {
4789
+ return multiply(x, pi / 180);
4790
+ }
4791
+ /** @function Alias of `jax.numpy.deg2rad()`. */
4792
+ const radians = deg2rad;
4793
+ /** Convert angles from radians to degrees. */
4794
+ function rad2deg(x) {
4795
+ return multiply(x, 180 / pi);
4796
+ }
4797
+ /** @function Alias of `jax.numpy.rad2deg()`. */
4798
+ const degrees = rad2deg;
4799
+ /**
4800
+ * @function
4801
+ * Computes first array raised to power of second array, element-wise.
4802
+ */
4803
+ const power = jit$1((x1, x2) => {
4804
+ return exp(log(x1).mul(x2));
4805
+ });
4806
+ /** @function Alias of `jax.numpy.power()`. */
4807
+ const pow = power;
4808
+ /** @function Calculate the element-wise cube root of the input array. */
4809
+ const cbrt = jit$1((x) => {
4810
+ const sgn = where(less(x.ref, 0), -1, 1);
4811
+ return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4812
+ });
4813
+ /**
4814
+ * @function
4815
+ * Calculate element-wise hyperbolic sine of input.
4816
+ *
4817
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
4818
+ */
4819
+ const sinh = jit$1((x) => {
4820
+ const ex = exp(x);
4821
+ const emx = reciprocal(ex.ref);
4822
+ return ex.sub(emx).mul(.5);
4823
+ });
4824
+ /**
4825
+ * @function
4826
+ * Calculate element-wise hyperbolic cosine of input.
4827
+ *
4828
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
4829
+ */
4830
+ const cosh = jit$1((x) => {
4831
+ const ex = exp(x);
4832
+ const emx = reciprocal(ex.ref);
4833
+ return ex.add(emx).mul(.5);
4834
+ });
4835
+ /**
4836
+ * @function
4837
+ * Calculate element-wise hyperbolic tangent of input.
4838
+ *
4839
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4840
+ */
4841
+ const tanh = jit$1((x) => {
4842
+ const negsgn = where(less(x.ref, 0), 1, -1);
4843
+ const en2x = exp(x.mul(negsgn.ref).mul(2));
4844
+ return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4845
+ });
4846
+ /**
4847
+ * @function
4848
+ * Calculate element-wise inverse hyperbolic sine of input.
4849
+ *
4850
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4851
+ */
4852
+ const arcsinh = jit$1((x) => {
4853
+ return log(x.ref.add(sqrt(square(x).add(1))));
4854
+ });
4855
+ /**
4856
+ * @function
4857
+ * Calculate element-wise inverse hyperbolic cosine of input.
4858
+ *
4859
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4860
+ */
4861
+ const arccosh = jit$1((x) => {
4862
+ return log(x.ref.add(sqrt(square(x).sub(1))));
4863
+ });
4864
+ /**
4865
+ * @function
4866
+ * Calculate element-wise inverse hyperbolic tangent of input.
4867
+ *
4868
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4869
+ */
4870
+ const arctanh = jit$1((x) => {
4871
+ return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4872
+ });
4873
+ /** @function Alias of `jax.numpy.arcsinh()`. */
4874
+ const asinh = arcsinh;
4875
+ /** @function Alias of `jax.numpy.arccosh()`. */
4876
+ const acosh = arccosh;
4877
+ /** @function Alias of `jax.numpy.arctanh()`. */
4878
+ const atanh = arctanh;
4879
+ /**
4880
+ * Compute the variance of an array.
4881
+ *
4882
+ * The variance is computed for the flattened array by default, otherwise over
4883
+ * the specified axis.
4884
+ *
4885
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
4886
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
4887
+ */
4888
+ function var_(x, axis = null, opts) {
4889
+ x = fudgeArray(x);
4890
+ axis = normalizeAxis(axis, x.ndim);
4891
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
4892
+ if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
4893
+ const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
4894
+ return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
4895
+ }
4896
+ /**
4897
+ * Compute the standard deviation of an array.
4898
+ *
4899
+ * The standard deviation is computed for the flattened array by default,
4900
+ * otherwise over the specified axis.
4901
+ *
4902
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
4903
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
4904
+ */
4905
+ function std(x, axis = null, opts) {
4906
+ return sqrt(var_(x, axis, opts));
4907
+ }
3970
4908
 
3971
4909
  //#endregion
3972
4910
  //#region src/nn.ts
3973
4911
  var nn_exports = {};
3974
4912
  __export(nn_exports, {
4913
+ celu: () => celu,
4914
+ elu: () => elu,
4915
+ gelu: () => gelu,
4916
+ glu: () => glu,
3975
4917
  identity: () => identity,
4918
+ leakyRelu: () => leakyRelu,
3976
4919
  logSigmoid: () => logSigmoid,
3977
4920
  logSoftmax: () => logSoftmax,
4921
+ logmeanexp: () => logmeanexp,
3978
4922
  logsumexp: () => logsumexp,
4923
+ mish: () => mish,
3979
4924
  oneHot: () => oneHot,
3980
4925
  relu: () => relu,
3981
4926
  relu6: () => relu6,
@@ -3984,6 +4929,8 @@ __export(nn_exports, {
3984
4929
  softSign: () => softSign,
3985
4930
  softmax: () => softmax,
3986
4931
  softplus: () => softplus,
4932
+ squareplus: () => squareplus,
4933
+ standardize: () => standardize,
3987
4934
  swish: () => swish
3988
4935
  });
3989
4936
  /**
@@ -4027,6 +4974,7 @@ function softSign(x) {
4027
4974
  return x.ref.div(absolute(x).add(1));
4028
4975
  }
4029
4976
  /**
4977
+ * @function
4030
4978
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4031
4979
  * Swish, computed element-wise:
4032
4980
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4035,11 +4983,9 @@ function softSign(x) {
4035
4983
  *
4036
4984
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4037
4985
  */
4038
- function silu(x) {
4039
- x = fudgeArray(x);
4040
- return x.ref.mul(sigmoid(x));
4041
- }
4986
+ const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
4042
4987
  /**
4988
+ * @function
4043
4989
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4044
4990
  * Swish, computed element-wise:
4045
4991
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4056,8 +5002,88 @@ const swish = silu;
4056
5002
  function logSigmoid(x) {
4057
5003
  return negative(softplus(negative(x)));
4058
5004
  }
4059
- /** Identity activation function. Returns the argument unmodified. */
5005
+ /**
5006
+ * @function
5007
+ * Identity activation function. Returns the argument unmodified.
5008
+ */
4060
5009
  const identity = fudgeArray;
5010
+ /** Leaky rectified linear unit (Leaky ReLU) activation function. */
5011
+ function leakyRelu(x, negativeSlope = .01) {
5012
+ x = fudgeArray(x);
5013
+ return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
5014
+ }
5015
+ /**
5016
+ * Exponential linear unit activation function.
5017
+ *
5018
+ * Computes the element-wise function:
5019
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
5020
+ */
5021
+ function elu(x, alpha = 1) {
5022
+ x = fudgeArray(x);
5023
+ return where(less(x.ref, 0), exp(x.ref).sub(1).mul(alpha), x);
5024
+ }
5025
+ /**
5026
+ * Continuously-differentiable exponential linear unit activation function.
5027
+ *
5028
+ * Computes the element-wise function:
5029
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
5030
+ */
5031
+ function celu(x, alpha = 1) {
5032
+ x = fudgeArray(x);
5033
+ return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
5034
+ }
5035
+ /**
5036
+ * @function
5037
+ * Gaussian error linear unit (GELU) activation function.
5038
+ *
5039
+ * This is computed element-wise. Currently jax-js does not support the erf() or
5040
+ * gelu() functions exactly as primitives, so an approximation is used:
5041
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5042
+ *
5043
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5044
+ *
5045
+ * This will be improved in the future.
5046
+ */
5047
+ const gelu = jit$1((x) => {
5048
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5049
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5050
+ });
5051
+ /**
5052
+ * Gated linear unit (GLU) activation function.
5053
+ *
5054
+ * Splits the `axis` dimension of the input into two halves, a and b, then
5055
+ * computes `a * sigmoid(b)`.
5056
+ */
5057
+ function glu(x, axis = -1) {
5058
+ x = fudgeArray(x);
5059
+ axis = checkAxis(axis, x.ndim);
5060
+ const size$1 = x.shape[axis];
5061
+ if (size$1 % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
5062
+ const slice = x.shape.map((a$1) => [0, a$1]);
5063
+ const a = shrink(x.ref, slice.toSpliced(axis, 1, [0, size$1 / 2]));
5064
+ const b = shrink(x, slice.toSpliced(axis, 1, [size$1 / 2, size$1]));
5065
+ return a.mul(sigmoid(b));
5066
+ }
5067
+ /**
5068
+ * Squareplus activation function.
5069
+ *
5070
+ * Computes the element-wise function:
5071
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
5072
+ */
5073
+ function squareplus(x, b = 4) {
5074
+ x = fudgeArray(x);
5075
+ return x.ref.add(sqrt(square(x).add(b))).mul(.5);
5076
+ }
5077
+ /**
5078
+ * Mish activation function.
5079
+ *
5080
+ * Computes the element-wise function:
5081
+ * `mish(x) = x * tanh(softplus(x))`
5082
+ */
5083
+ function mish(x) {
5084
+ x = fudgeArray(x);
5085
+ return x.ref.mul(tanh(softplus(x)));
5086
+ }
4061
5087
  /**
4062
5088
  * Softmax function. Computes the function which rescales elements to the range
4063
5089
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -4066,17 +5092,13 @@ const identity = fudgeArray;
4066
5092
  *
4067
5093
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
4068
5094
  */
4069
- function softmax(x, axis) {
5095
+ function softmax(x, axis = -1) {
4070
5096
  x = fudgeArray(x);
4071
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4072
- else if (typeof axis === "number") axis = [axis];
4073
- if (axis.length === 0) {
4074
- x.dispose();
4075
- return ones(x.shape);
4076
- }
4077
- const xMax = max(x.ref, axis, { keepDims: true });
5097
+ axis = normalizeAxis(axis, x.ndim);
5098
+ if (axis.length === 0) return onesLike(x);
5099
+ const xMax = max(x.ref, axis, { keepdims: true });
4078
5100
  const unnormalized = exp(x.sub(stopGradient(xMax)));
4079
- return unnormalized.ref.div(unnormalized.sum(axis, { keepDims: true }));
5101
+ return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
4080
5102
  }
4081
5103
  /**
4082
5104
  * Log-Softmax function.
@@ -4086,17 +5108,13 @@ function softmax(x, axis) {
4086
5108
  *
4087
5109
  * If `axis` is not specified, it defaults to the last axis.
4088
5110
  */
4089
- function logSoftmax(x, axis) {
5111
+ function logSoftmax(x, axis = -1) {
4090
5112
  x = fudgeArray(x);
4091
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4092
- else if (typeof axis === "number") axis = [axis];
4093
- if (axis.length === 0) {
4094
- x.dispose();
4095
- return zeros(x.shape);
4096
- }
4097
- const xMax = max(x.ref, axis, { keepDims: true });
5113
+ axis = normalizeAxis(axis, x.ndim);
5114
+ if (axis.length === 0) return zerosLike(x);
5115
+ const xMax = max(x.ref, axis, { keepdims: true });
4098
5116
  const shifted = x.sub(stopGradient(xMax));
4099
- const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepDims: true }));
5117
+ const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
4100
5118
  return shifted.sub(shiftedLogsumexp);
4101
5119
  }
4102
5120
  /**
@@ -4107,16 +5125,39 @@ function logSoftmax(x, axis) {
4107
5125
  *
4108
5126
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
4109
5127
  */
4110
- function logsumexp(x, axis) {
5128
+ function logsumexp(x, axis = null) {
4111
5129
  x = fudgeArray(x);
4112
- if (axis === void 0) axis = range(x.ndim);
4113
- else if (typeof axis === "number") axis = [axis];
5130
+ axis = normalizeAxis(axis, x.ndim);
4114
5131
  if (axis.length === 0) return x;
4115
5132
  const xMax = stopGradient(max(x.ref, axis));
4116
5133
  const xMaxDims = broadcast(xMax.ref, x.shape, axis);
4117
5134
  const shifted = x.sub(xMaxDims);
4118
5135
  return xMax.add(log(exp(shifted).sum(axis)));
4119
5136
  }
5137
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
5138
+ function logmeanexp(x, axis = null) {
5139
+ x = fudgeArray(x);
5140
+ axis = normalizeAxis(axis, x.ndim);
5141
+ if (axis.length === 0) return x;
5142
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5143
+ return logsumexp(x, axis).sub(Math.log(n));
5144
+ }
5145
+ /**
5146
+ * Standardizes input to zero mean and unit variance.
5147
+ *
5148
+ * By default, this is computed over the last axis. You can pass in a different
5149
+ * axis, or `null` to standardize over all elements.
5150
+ *
5151
+ * Epsilon (default `1e-5`) is added to the variance in the denominator for numerical stability.
5152
+ */
5153
+ function standardize(x, axis = -1, opts = {}) {
5154
+ x = fudgeArray(x);
5155
+ axis = normalizeAxis(axis, x.ndim);
5156
+ if (axis.length === 0) return x;
5157
+ const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
5158
+ const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
5159
+ return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
5160
+ }
4120
5161
  /**
4121
5162
  * One-hot encodes the given indices.
4122
5163
  *
@@ -4134,7 +5175,7 @@ function logsumexp(x, axis) {
4134
5175
  * ```
4135
5176
  */
4136
5177
  function oneHot(x, numClasses) {
4137
- if (x.dtype !== "int32") throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
5178
+ if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4138
5179
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4139
5180
  }
4140
5181
 
@@ -4142,8 +5183,11 @@ function oneHot(x, numClasses) {
4142
5183
  //#region src/random.ts
4143
5184
  var random_exports = {};
4144
5185
  __export(random_exports, {
5186
+ bernoulli: () => bernoulli,
4145
5187
  bits: () => bits,
5188
+ exponential: () => exponential,
4146
5189
  key: () => key,
5190
+ normal: () => normal,
4147
5191
  split: () => split,
4148
5192
  uniform: () => uniform
4149
5193
  });
@@ -4174,11 +5218,11 @@ function bits(key$1, shape$1 = []) {
4174
5218
  /** Sample uniform random values in [minval, maxval) with given shape. */
4175
5219
  function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4176
5220
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
4177
- const mantissa = bits(key$1, shape$1).div(scalar(512, {
5221
+ const mantissa = bits(key$1, shape$1).div(array(512, {
4178
5222
  dtype: DType.Uint32,
4179
5223
  device: key$1.device
4180
5224
  }));
4181
- const float12 = mantissa.add(scalar(1065353216, {
5225
+ const float12 = mantissa.add(array(1065353216, {
4182
5226
  dtype: DType.Uint32,
4183
5227
  device: key$1.device
4184
5228
  }));
@@ -4186,6 +5230,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4186
5230
  if (minval === 0 && maxval === 1) return rand;
4187
5231
  else return rand.mul(maxval - minval).add(minval);
4188
5232
  }
5233
+ /**
5234
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
5235
+ *
5236
+ * Returns a random Boolean array with the specified shape. `p` can be an array
5237
+ * and must be broadcastable to `shape`.
5238
+ */
5239
+ function bernoulli(key$1, p = .5, shape$1 = []) {
5240
+ p = fudgeArray(p);
5241
+ return uniform(key$1, shape$1).less(p);
5242
+ }
5243
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
5244
+ function exponential(key$1, shape$1 = []) {
5245
+ const u = uniform(key$1, shape$1);
5246
+ return negative(log1p(negative(u)));
5247
+ }
5248
+ /**
5249
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5250
+ *
5251
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5252
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5253
+ * bitwise identical to JAX.
5254
+ */
5255
+ function normal(key$1, shape$1 = []) {
5256
+ const [k1, k2] = split(key$1, 2);
5257
+ const u1 = uniform(k1, shape$1);
5258
+ const u2 = uniform(k2, shape$1);
5259
+ const radius = sqrt(log1p(negative(u1)).mul(-2));
5260
+ const theta = u2.mul(2 * Math.PI);
5261
+ return radius.mul(cos(theta));
5262
+ }
4189
5263
 
4190
5264
  //#endregion
4191
5265
  //#region src/polyfills.ts
@@ -4195,33 +5269,91 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
4195
5269
 
4196
5270
  //#endregion
4197
5271
  //#region src/index.ts
4198
- /** Compute the forward-mode Jacobian-vector product for a function. */
5272
+ /**
5273
+ * @function
5274
+ * Compute the forward-mode Jacobian-vector product for a function.
5275
+ */
4199
5276
  const jvp = jvp$1;
4200
- /** Vectorize an operation on a batched axis for one or more inputs. */
5277
+ /**
5278
+ * @function
5279
+ * Vectorize an operation on a batched axis for one or more inputs.
5280
+ */
4201
5281
  const vmap = vmap$1;
4202
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
5282
+ /**
5283
+ * @function
5284
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
5285
+ */
4203
5286
  const jacfwd = jacfwd$1;
4204
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
5287
+ /**
5288
+ * @function
5289
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
5290
+ */
4205
5291
  const makeJaxpr = makeJaxpr$1;
5292
+ /**
5293
+ * @function
5294
+ * Mark a function for automatic JIT compilation, with operator fusion.
5295
+ *
5296
+ * The function will be compiled the first time it is called with a set of
5297
+ * argument shapes.
5298
+ *
5299
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
5300
+ * calls to free memory associated with array constants.
5301
+ *
5302
+ * **Options:**
5303
+ * - `staticArgnums`: An array of argument indices to treat as static
5304
+ * (compile-time constant). These arguments must be hashable, won't be traced,
5305
+ * and different values will trigger recompilation.
5306
+ * - `device`: The device to place the computation on. If not specified, the
5307
+ * computation will be placed on the default device.
5308
+ */
4206
5309
  const jit = jit$1;
4207
5310
  /**
5311
+ * @function
4208
5312
  * Produce a local linear approximation to a function at a point using jvp() and
4209
5313
  * partial evaluation.
4210
5314
  */
4211
5315
  const linearize = linearize$1;
4212
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
5316
+ /**
5317
+ * @function
5318
+ * Calculate the reverse-mode vector-Jacobian product for a function.
5319
+ */
4213
5320
  const vjp = vjp$1;
4214
5321
  /**
5322
+ * @function
4215
5323
  * Compute the gradient of a scalar-valued function `f` with respect to its
4216
5324
  * first argument.
4217
5325
  */
4218
5326
  const grad = grad$1;
4219
- /** Create a function that evaluates both `f` and the gradient of `f`. */
5327
+ /**
5328
+ * @function
5329
+ * Create a function that evaluates both `f` and the gradient of `f`.
5330
+ */
4220
5331
  const valueAndGrad = valueAndGrad$1;
4221
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
5332
+ /**
5333
+ * @function
5334
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
5335
+ */
4222
5336
  const jacrev = jacrev$1;
4223
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
5337
+ /**
5338
+ * @function
5339
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
5340
+ */
4224
5341
  const jacobian = jacrev;
5342
+ /**
5343
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
5344
+ *
5345
+ * This can be used to wait for the results of an intermediate computation to
5346
+ * finish. It's recommended to call this regularly in an iterative computation
5347
+ * to avoid queueing up too many pending operations.
5348
+ *
5349
+ * Does not consume reference to the arrays.
5350
+ */
5351
+ async function blockUntilReady(x) {
5352
+ const promises = [];
5353
+ for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
5354
+ await Promise.all(promises);
5355
+ return x;
5356
+ }
4225
5357
 
4226
5358
  //#endregion
4227
- export { devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };
5359
+ export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };