@jax-js/jax 0.0.2 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,13 +30,14 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-BK21PBVP.cjs');
33
+ const require_backend = require('./backend-Ss1Mev_-.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
37
37
  __export(tree_exports, {
38
38
  JsTreeDef: () => JsTreeDef,
39
39
  NodeType: () => NodeType,
40
+ dispose: () => dispose,
40
41
  flatten: () => flatten,
41
42
  leaves: () => leaves,
42
43
  map: () => map,
@@ -51,7 +52,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
51
52
  NodeType$1["Leaf"] = "Leaf";
52
53
  return NodeType$1;
53
54
  }({});
54
- /** Analog to the JAX "pytree" object, but for JavaScript. */
55
+ /** Represents the structure of a JsTree. */
55
56
  var JsTreeDef = class JsTreeDef {
56
57
  static leaf = new JsTreeDef(NodeType.Leaf, null, []);
57
58
  constructor(nodeType, nodeMetadata, childTreedefs) {
@@ -139,6 +140,194 @@ function map(fn, tree, ...rest) {
139
140
  function ref(tree) {
140
141
  return map((x) => x.ref, tree);
141
142
  }
143
/**
 * Dispose every array in a tree.
 *
 * Calls `dispose()` on each leaf. A falsy `tree` (e.g. `null` or
 * `undefined`) is ignored. Returns nothing.
 */
function dispose(tree) {
  if (!tree) return;
  map((leaf) => leaf.dispose(), tree);
}
147
+
148
+ //#endregion
149
+ //#region src/frontend/convolution.ts
150
/**
 * Check that the shapes and parameters passed to convolution are valid.
 *
 * `lhsShape` is `[batch, inChannels, ...spatial]` and `rhsShape` is
 * `[outChannels, inChannels, ...kernel]`. Each of `strides`, `padding`
 * (`[low, high]` pairs), `lhsDilation` and `rhsDilation` holds one entry per
 * spatial dimension.
 *
 * @returns The output shape `[batch, outChannels, ...outSpatial]`.
 * @throws {Error} If ranks, channel counts or any parameter are inconsistent
 *   or out of range.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  const rank = lhsShape.length;
  if (rank !== rhsShape.length) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  }
  const spatialRank = rank - 2;
  if (spatialRank < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== spatialRank) throw new Error("conv() strides != spatial dims");
  if (padding.length !== spatialRank) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== spatialRank) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== spatialRank) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);

  const isPositiveInt = (v) => Number.isInteger(v) && v > 0;
  const spatialOut = [];
  for (let d = 0; d < spatialRank; d++) {
    if (!isPositiveInt(strides[d])) throw new Error(`conv() strides[${d}] must be a positive integer`);
    if (padding[d].length !== 2 || !padding[d].every(Number.isInteger)) throw new Error(`conv() padding[${d}] must be a 2-tuple of integers`);
    if (!isPositiveInt(lhsDilation[d])) throw new Error(`conv() lhsDilation[${d}] must be a positive integer`);
    if (!isPositiveInt(rhsDilation[d])) throw new Error(`conv() rhsDilation[${d}] must be a positive integer`);

    const inputLen = lhsShape[d + 2];
    const kernelLen = rhsShape[d + 2];
    if (kernelLen <= 0) throw new Error("conv() kernel size must be positive");

    // Padding is only destructured once it is known to be a 2-tuple.
    const [padLo, padHi] = padding[d];
    if (padLo < -inputLen || padHi < -inputLen || padLo + padHi < -inputLen) {
      throw new Error(`conv() padding[${d}]=(${padLo},${padHi}) is too negative for input size ${inputLen}`);
    }

    // Effective extents after dilating the kernel and the (padded) input.
    const kernelSize = (kernelLen - 1) * rhsDilation[d] + 1;
    const inSize = Math.max((inputLen - 1) * lhsDilation[d] + 1, 0) + padLo + padHi;
    if (kernelSize > inSize) throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${d}`);
    spatialOut.push(Math.ceil((inSize - kernelSize + 1) / strides[d]));
  }
  return [lhsShape[0], rhsShape[0], ...spatialOut];
}
181
/**
 * Check the shapes passed to a pooling operation.
 *
 * The last `window.length` axes of `inShape` are pooled; leading axes are
 * carried through unchanged.
 *
 * @returns `[...leading, ...outSpatial, ...window]`, the same layout that
 *   `pool()` produces (window axes last).
 * @throws {Error} On mismatched lengths or non-positive / oversized windows
 *   and strides.
 */
function checkPoolShape(inShape, window, strides) {
  if (strides.length !== window.length) throw new Error("pool() strides != window dims");
  if (window.length > inShape.length) throw new Error("pool() window has more dimensions than input");
  const leadingRank = inShape.length - window.length;
  const pooled = window.map((k, i) => {
    const s = strides[i];
    const size = inShape[leadingRank + i];
    if (k <= 0 || !Number.isInteger(k)) throw new Error(`pool() window[${i}] must be a positive integer`);
    if (k > size) throw new Error(`pool() window[${i}]=${k} > input size ${size}`);
    if (s <= 0 || !Number.isInteger(s)) throw new Error(`pool() strides[${i}] must be a positive integer`);
    return Math.ceil((size - k + 1) / s);
  });
  return [...inShape.slice(0, leadingRank), ...pooled, ...window];
}
196
/**
 * Takes a shape tracker and a kernel size `ks`, then reshapes it so the last
 * `ks.length` dimensions become `2 * ks.length` dimensions by treating them as
 * spatial dimensions convolved with a kernel.
 *
 * The resulting array can be multiplied with a kernel of shape `ks`, then
 * reduced along the last `ks.length` dimensions for a convolution. The result
 * has shape `[...leading, ...out, ...ks]` with
 * `out[j] = ceil((in[j] - dilation[j] * (ks[j] - 1)) / strides[j])`.
 *
 * @param st Shape tracker whose trailing `ks.length` axes are pooled.
 * @param ks Kernel (window) size for each pooled axis.
 * @param strides Stride per pooled axis, or one number for all of them.
 * @param dilation Kernel dilation per pooled axis, or one number for all.
 * @throws {Error} If `st` has fewer axes than `ks`, if `strides`/`dilation`
 *   do not have one positive integer per kernel axis, or if a dilated kernel
 *   does not fit in its input axis.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 */
function pool(st, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  if (st.shape.length < ks.length) throw new Error("pool() called with too many dimensions");
  if (typeof strides === "number") strides = require_backend.rep(ks.length, strides);
  if (typeof dilation === "number") dilation = require_backend.rep(ks.length, dilation);
  // Every per-axis parameter is combined elementwise with `ks` below, so the
  // lengths must agree (the tinygrad reference asserts the same).
  if (strides.length !== ks.length || dilation.length !== ks.length) {
    throw new Error("pool() strides and dilation must have one entry per kernel dimension");
  }
  if (strides.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("pool() strides must be positive integers");
  if (dilation.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("pool() dilation must be positive integers");
  const noop = st.shape.slice(0, -ks.length);
  const i_ = st.shape.slice(-ks.length);
  const s_ = strides;
  const d_ = dilation;
  // The dilated kernel must fit in the input, otherwise `o_` is <= 0 and the
  // reshapes below produce meaningless shapes (also asserted by tinygrad).
  if (ks.some((k, j) => d_[j] * (k - 1) + 1 > i_[j])) throw new Error("pool() kernel size cannot be greater than input size");
  // Number of output positions along each pooled axis.
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  // Input scaling factor (1 or 2) so the shrink used for striding stays in bounds.
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = require_backend.zipn(ks, i_, d_, f_);
  // Repeat the input so the dilated windows can be cut out without padding.
  st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
  // Handle dilation: one row of length `i * f + d` per kernel tap.
  st = st.shrink([...noop.map((x) => [0, x]), ...kidf.map(([k, i, d, f]) => [0, k * (i * f + d)])]).reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [k, i * f + d])]);
  // Handle stride: split each row into `o` groups of `s` and keep the first of each.
  const kos = require_backend.zipn(ks, o_, s_);
  st = st.shrink([...noop.map((x) => [0, x]), ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])]).reshape([...noop, ...kos.flat(1)]);
  st = st.shrink([...noop.map((x) => [0, x]), ...kos.flatMap(([k, o]) => [
    [0, k],
    [0, o],
    [0, 1]
  ])]).reshape([...noop, ...kos.flatMap(([k, o]) => [k, o])]);
  // Move the kernel axes to the end: [...noop, ...o_, ...ks].
  st = st.permute([
    ...require_backend.range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j)
  ]);
  return st;
}
236
/**
 * Perform the transpose of pool, directly undo-ing a pool() operation.
 *
 * `st` must have the pooled layout `[...leading, ...out, ...ks]` for an input
 * of shape `inShape`. The result has shape
 * `[...inShape, ...repeats]`: pool() repeats the input, so the transpose
 * technically includes a sum reduction. This function doesn't perform the
 * reduction, which should be done on the last `ks.length` axes of the
 * returned shape.
 *
 * @param st Shape tracker in pooled layout.
 * @param inShape Shape of the original (un-pooled) input.
 * @param ks Kernel (window) size for each pooled axis.
 * @param strides Stride per pooled axis, or one number for all of them.
 * @param dilation Kernel dilation per pooled axis, or one number for all.
 * @throws {Error} If `inShape` has fewer axes than `ks`, the per-axis
 *   parameters are malformed, or `st` does not match the pooled shape.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  // Same parameter validation as pool(), so malformed input fails loudly
  // instead of silently producing a wrong view.
  if (inShape.length < ks.length) throw new Error("poolTranspose() called with too many dimensions");
  if (typeof strides === "number") strides = require_backend.rep(ks.length, strides);
  if (typeof dilation === "number") dilation = require_backend.rep(ks.length, dilation);
  if (strides.length !== ks.length || dilation.length !== ks.length) {
    throw new Error("poolTranspose() strides and dilation must have one entry per kernel dimension");
  }
  if (strides.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("poolTranspose() strides must be positive integers");
  if (dilation.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("poolTranspose() dilation must be positive integers");
  const noop = inShape.slice(0, -ks.length);
  const i_ = inShape.slice(-ks.length);
  const s_ = strides;
  const d_ = dilation;
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = require_backend.zipn(ks, i_, d_, f_);
  const kos = require_backend.zipn(ks, o_, s_);
  // Interleave kernel and output axes: [...noop, k0, o0, k1, o1, ...].
  st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + ks.length + j, noop.length + j])]);
  // Undo the stride shrink: grow each output position back to `s` slots.
  st = st.reshape([...noop, ...kos.flatMap(([k, o]) => [
    k,
    o,
    1
  ])]).pad([...noop.map(() => [0, 0]), ...s_.flatMap((s) => [
    [0, 0],
    [0, 0],
    [0, s - 1]
  ])]);
  // Undo the dilation shrink: pad each row back to length `i * f + d`.
  st = st.reshape([...noop, ...kos.flatMap(([k, o, s]) => [k, o * s])]).pad([...noop.map(() => [0, 0]), ...kidf.flatMap(([_k, i, d, f], j) => [[0, 0], [0, i * f + d - o_[j] * s_[j]]])]);
  // Undo the shrink that followed the repeat.
  st = st.reshape([...noop, ...kidf.map(([k, i, d, f]) => k * (i * f + d))]).pad([...noop.map(() => [0, 0]), ...kidf.map(([k, i, d, f]) => [0, Math.ceil(k * (i * f + d) / i) * i - k * (i * f + d)])]);
  // Split off the repeat count and move those axes to the end.
  st = st.reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [Math.ceil(k * (i * f + d) / i), i])]).permute([
    ...require_backend.range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j)
  ]);
  return st;
}
275
/**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * `st` has shape `[a, b, ...spatial]` and `dilation` holds one factor per
 * spatial axis. Each spatial axis of length `k` becomes `(k - 1) * s + 1`,
 * with `s - 1` padded positions inserted between consecutive elements. If
 * every factor is 1, `st` is returned unchanged.
 */
function applyDilation(st, dilation) {
  if (dilation.every((s) => s === 1)) return st;
  const [a, b, ...spatial] = st.shape;
  // Give every spatial axis a trailing singleton axis...
  const withSlots = st.reshape([a, b, ...spatial.flatMap((k) => [k, 1])]);
  // ...widen that singleton to `s` by padding after each element...
  const padded = withSlots.pad([[0, 0], [0, 0], ...dilation.flatMap((s) => [[0, 0], [0, s - 1]])]);
  // ...merge the pair back into one axis of length `k * s`...
  const merged = padded.reshape([a, b, ...spatial.map((k, i) => k * dilation[i])]);
  // ...and drop the padding that trails the last element.
  return merged.shrink([[0, a], [0, b], ...spatial.map((k, i) => [0, (k - 1) * dilation[i] + 1])]);
}
302
/**
 * Prepare for a convolution between two arrays.
 *
 * `stX` is dilated (`lhsDilation`), padded or trimmed (`padding`) and pooled
 * with the kernel size of `stY` (`strides`, `rhsDilation`). Both trackers are
 * then reshaped so their last axes have length `inChannels * prod(kernel)`
 * and line up for a dot product:
 * - lhs: `[batch, 1, ...outSpatial, inChannels * prod(kernel)]`
 * - rhs: `[outChannels, 1, ..., 1, inChannels * prod(kernel)]`
 *
 * This does not check the validity of the shapes, which should be checked
 * beforehand using `checkConvShape()`.
 */
function prepareConv(stX, stY, params) {
  const spatialRank = stX.shape.length - 2;
  const dilated = applyDilation(stX, params.lhsDilation);
  const kernelDims = stY.shape.slice(2);
  const kernelVolume = require_backend.prod(kernelDims);
  const padded = dilated.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  const pooled = pool(padded, kernelDims, params.strides, params.rhsDilation);
  const [batch, channels] = pooled.shape;
  const outSpatial = pooled.shape.slice(2, spatialRank + 2);
  // Move channels next to the kernel axes, then fuse them into one axis.
  const lhs = pooled.moveaxis(1, spatialRank + 1).reshape([batch, 1, ...outSpatial, channels * kernelVolume]);
  const [outChannels, inChannels] = stY.shape;
  const rhs = stY.reshape([outChannels, ...require_backend.rep(spatialRank, 1), inChannels * kernelVolume]);
  return [lhs, rhs];
}
142
331
 
143
332
  //#endregion
144
333
  //#region src/frontend/core.ts
@@ -165,12 +354,18 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
165
354
  Primitive$1["RandomBits"] = "random_bits";
166
355
  Primitive$1["Sin"] = "sin";
167
356
  Primitive$1["Cos"] = "cos";
357
+ Primitive$1["Asin"] = "asin";
358
+ Primitive$1["Atan"] = "atan";
168
359
  Primitive$1["Exp"] = "exp";
169
360
  Primitive$1["Log"] = "log";
361
+ Primitive$1["Sqrt"] = "sqrt";
170
362
  Primitive$1["Min"] = "min";
171
363
  Primitive$1["Max"] = "max";
172
364
  Primitive$1["Reduce"] = "reduce";
173
365
  Primitive$1["Dot"] = "dot";
366
+ Primitive$1["Conv"] = "conv";
367
+ Primitive$1["Pool"] = "pool";
368
+ Primitive$1["PoolTranspose"] = "pool_transpose";
174
369
  Primitive$1["Compare"] = "compare";
175
370
  Primitive$1["Where"] = "where";
176
371
  Primitive$1["Transpose"] = "transpose";
@@ -228,34 +423,52 @@ function sin$1(x) {
228
423
  function cos$1(x) {
229
424
  return bind1(Primitive.Cos, [x]);
230
425
  }
426
+ function asin$1(x) {
427
+ return bind1(Primitive.Asin, [x]);
428
+ }
429
+ function atan$1(x) {
430
+ return bind1(Primitive.Atan, [x]);
431
+ }
231
432
  function exp$1(x) {
232
433
  return bind1(Primitive.Exp, [x]);
233
434
  }
234
435
  function log$1(x) {
235
436
  return bind1(Primitive.Log, [x]);
236
437
  }
438
+ function sqrt$1(x) {
439
+ return bind1(Primitive.Sqrt, [x]);
440
+ }
237
441
  function min$1(x, y) {
238
442
  return bind1(Primitive.Min, [x, y]);
239
443
  }
240
444
  function max$1(x, y) {
241
445
  return bind1(Primitive.Max, [x, y]);
242
446
  }
243
/**
 * Reduce `x` with the ALU operation `op` over `axis`.
 *
 * @param x Value to reduce.
 * @param op Reduction operation; must be in `AluGroup.Reduce`.
 * @param axis Axis or axes to reduce, normalized with `normalizeAxis()`.
 *   `null` (the default) presumably selects all axes — confirm in
 *   `normalizeAxis()`.
 * @param opts.keepdims If truthy, reduced axes are kept with size 1. The
 *   pre-0.0.4 spelling `keepDims` is still honored.
 * @throws {TypeError} If `op` is not a reduction operation.
 */
function reduce(x, op, axis = null, opts) {
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
  axis = require_backend.normalizeAxis(axis, ndim$1(x));
  const originalShape = getShape(x);
  const result = bind1(Primitive.Reduce, [x], {
    op,
    axis
  });
  // Older callers pass `keepDims`; without this fallback they would silently
  // get the un-kept (lower-rank) shape after the option was renamed.
  const keepdims = opts?.keepdims ?? opts?.keepDims;
  if (!keepdims) return result;
  const reduced = new Set(axis);
  return result.reshape(originalShape.map((dim, i) => (reduced.has(i) ? 1 : dim)));
}
256
458
  function dot$1(x, y) {
257
459
  return bind1(Primitive.Dot, [x, y]);
258
460
  }
461
/**
 * Convolve `x` with the kernel `y` by binding the `Conv` primitive.
 *
 * Both operands must have the same rank, at least 2. The first two axes are
 * not spatial (see `checkConvShape()` for the layout), and each per-axis
 * parameter has one entry per remaining spatial axis.
 *
 * @param params.strides Stride per spatial axis (default all 1).
 * @param params.padding `[low, high]` padding per spatial axis (default none).
 * @param params.lhsDilation Input dilation per spatial axis (default all 1).
 * @param params.rhsDilation Kernel dilation per spatial axis (default all 1).
 * @throws {Error} If the ranks differ or are below 2.
 */
function conv(x, y, params = {}) {
  // Use ndim$1() like reduce()/transpose$1()/flip$1() so operands without an
  // `ndim` property don't silently produce n = NaN.
  const xdim = ndim$1(x);
  const ydim = ndim$1(y);
  if (xdim !== ydim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${xdim} and ${ydim}`);
  const n = xdim - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
  return bind1(Primitive.Conv, [x, y], {
    strides: params.strides ?? require_backend.rep(n, 1),
    padding: params.padding ?? require_backend.rep(n, [0, 0]),
    lhsDilation: params.lhsDilation ?? require_backend.rep(n, 1),
    rhsDilation: params.rhsDilation ?? require_backend.rep(n, 1)
  });
}
259
472
  function compare(x, y, op) {
260
473
  return bind1(Primitive.Compare, [x, y], { op });
261
474
  }
@@ -286,10 +499,11 @@ function where$1(cond, x, y) {
286
499
  }
287
500
  function transpose$1(x, perm) {
288
501
  perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
502
+ if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
289
503
  return bind1(Primitive.Transpose, [x], { perm });
290
504
  }
291
505
  function broadcast(x, shape$1, axis) {
292
- axis = axis.map((a) => require_backend.checkAxis(a, shape$1.length));
506
+ axis = require_backend.normalizeAxis(axis, shape$1.length);
293
507
  return bind1(Primitive.Broadcast, [x], {
294
508
  shape: shape$1,
295
509
  axis
@@ -308,7 +522,7 @@ function reshape$1(x, shape$1) {
308
522
  return bind1(Primitive.Reshape, [x], { shape: shape$1 });
309
523
  }
310
524
  function flip$1(x, axis) {
311
- axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
525
+ axis = require_backend.normalizeAxis(axis, ndim$1(x));
312
526
  return bind1(Primitive.Flip, [x], { axis });
313
527
  }
314
528
  function shrink(x, slice) {
@@ -388,12 +602,19 @@ var Tracer = class Tracer {
388
602
  constructor(trace) {
389
603
  this._trace = trace;
390
604
  }
605
+ /** The shape of the array. */
391
606
  get shape() {
392
607
  return this.aval.shape;
393
608
  }
609
+ /** The total number of elements in the array. */
610
+ get size() {
611
+ return require_backend.prod(this.shape);
612
+ }
613
+ /** The dtype of the array. */
394
614
  get dtype() {
395
615
  return this.aval.dtype;
396
616
  }
617
+ /** The number of dimensions of the array. */
397
618
  get ndim() {
398
619
  return this.shape.length;
399
620
  }
@@ -429,22 +650,20 @@ var Tracer = class Tracer {
429
650
  return lessEqual$1(this, other);
430
651
  }
431
652
  /** Sum of the elements of the array over a given axis, or axes. */
432
- sum(axis, opts) {
653
+ sum(axis = null, opts) {
433
654
  return reduce(this, require_backend.AluOp.Add, axis, opts);
434
655
  }
435
656
  /** Product of the array elements over a given axis. */
436
- prod(axis, opts) {
657
+ prod(axis = null, opts) {
437
658
  return reduce(this, require_backend.AluOp.Mul, axis, opts);
438
659
  }
439
660
  /** Compute the average of the array elements along the specified axis. */
440
- mean(axis, opts) {
441
- if (axis === void 0) axis = require_backend.range(this.ndim);
442
- else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, this.ndim)];
443
- else axis = axis.map((a) => require_backend.checkAxis(a, this.ndim));
444
- let result = reduce(this, require_backend.AluOp.Add, axis);
445
- result = result.mul(require_backend.prod(result.shape) / require_backend.prod(this.shape));
446
- if (opts?.keepDims) result = broadcast(result, this.shape, axis);
447
- return result;
661
+ mean(axis = null, opts) {
662
+ axis = require_backend.normalizeAxis(axis, this.ndim);
663
+ const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
664
+ if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
665
+ const result = reduce(this, require_backend.AluOp.Add, axis, opts);
666
+ return result.mul(1 / n);
448
667
  }
449
668
  /** Permute the dimensions of an array. Defaults to reversing the axis order. */
450
669
  transpose(perm) {
@@ -476,8 +695,29 @@ var Tracer = class Tracer {
476
695
  /** Return specified diagonals. See `numpy.diagonal` for full docs. */
477
696
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
478
697
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
698
+ if (offset < 0) return this.diagonal(-offset, axis2, axis1);
699
+ axis1 = require_backend.checkAxis(axis1, this.ndim);
700
+ axis2 = require_backend.checkAxis(axis2, this.ndim);
479
701
  if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
480
- throw new Error("diagonal not implemented");
702
+ if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
703
+ let ar = this;
704
+ if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
705
+ const perm = require_backend.range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
706
+ ar = ar.transpose(perm);
707
+ }
708
+ const [n, m] = ar.shape.slice(-2);
709
+ const diagSize = Math.min(n, m - offset);
710
+ ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
711
+ const npad = diagSize * (m + 1) - n * m;
712
+ if (npad > 0) ar = pad$1(ar, [...require_backend.rep(ar.ndim - 1, [0, 0]), [0, npad]]);
713
+ else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
714
+ ar = ar.reshape([
715
+ ...ar.shape.slice(0, -1),
716
+ diagSize,
717
+ m + 1
718
+ ]);
719
+ ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
720
+ return ar;
481
721
  }
482
722
  /** Flatten the array without changing its data. */
483
723
  flatten() {
@@ -620,7 +860,7 @@ var ShapedArray = class ShapedArray {
620
860
  get ndim() {
621
861
  return this.shape.length;
622
862
  }
623
- strShort() {
863
+ toString() {
624
864
  return `${this.dtype}[${this.shape.join(",")}]`;
625
865
  }
626
866
  equals(other) {
@@ -651,7 +891,7 @@ function fullRaise(trace, val) {
651
891
  if (Object.is(val._trace.main, trace.main)) return val;
652
892
  else if (val._trace.main.level < level) return trace.lift(val);
653
893
  else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
654
- else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
894
+ else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
655
895
  }
656
896
  var TreeMismatchError = class extends TypeError {
657
897
  constructor(where$2, left, right) {
@@ -900,16 +1140,16 @@ function jitCompile(backend, jaxpr, consts) {
900
1140
  jitCompileCache.set(cacheKey, jp);
901
1141
  return jp;
902
1142
  }
903
/**
 * Rewrite every `GlobalView` node of `exp$2` through `mapping`.
 *
 * `mapping` receives each view's shape tracker and returns a replacement
 * tracker, or a falsy value to leave that view untouched. Rewritten views are
 * re-indexed by unraveling `AluVar.gidx` over the new shape. With
 * `reduceAxis`, only the leading axes use `gidx` and the last axis is indexed
 * by `AluVar.ridx`.
 *
 * @throws {Error} If the expression contains a `GlobalIndex` op.
 */
function reshapeViews(exp$2, mapping, reduceAxis = false) {
  const { AluOp, AluVar } = require_backend;
  return exp$2.rewrite((node) => {
    if (node.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
    if (node.op !== AluOp.GlobalView) return undefined;
    const [gid, st] = node.arg;
    const newSt = mapping(st);
    if (!newSt) return undefined;
    let indices;
    if (reduceAxis) indices = require_backend.unravelAlu(newSt.shape.slice(0, -1), AluVar.gidx).concat(AluVar.ridx);
    else indices = require_backend.unravelAlu(newSt.shape, AluVar.gidx);
    return require_backend.AluExp.globalView(node.dtype, gid, newSt, indices);
  });
}
915
1155
  function broadcastedJit(fn) {
@@ -956,8 +1196,11 @@ const jitRules = {
956
1196
  },
957
1197
  // Elementwise unary primitives: each one is built by unopJit() from the
  // AluExp factory of the same name.
  [Primitive.Sin]: unopJit(require_backend.AluExp.sin),
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
  [Primitive.Asin]: unopJit(require_backend.AluExp.asin),
  [Primitive.Atan]: unopJit(require_backend.AluExp.atan),
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
  [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
  // Elementwise binary primitives; broadcastedJit() presumably brings both
  // operands to a common shape first — confirm in its definition.
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
963
1206
  [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
@@ -972,18 +1215,20 @@ const jitRules = {
972
1215
  const size$1 = require_backend.prod(newShape);
973
1216
  const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
974
1217
  newShape.push(reductionSize);
975
- a = a.rewrite((exp$2) => {
976
- if (exp$2.op === require_backend.AluOp.GlobalView) {
977
- const [gid, st] = exp$2.arg;
978
- const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
979
- const indices = require_backend.unravelAlu(newShape.slice(0, -1), require_backend.AluVar.gidx);
980
- indices.push(require_backend.AluVar.ridx);
981
- return require_backend.AluExp.globalView(exp$2.dtype, gid, newSt, indices);
982
- }
983
- });
1218
+ const perm = keptAxes.concat(shiftedAxes);
1219
+ a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
984
1220
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
985
1221
  return new require_backend.Kernel(nargs, size$1, a, reduction);
986
1222
  },
1223
+ [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1224
+ [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
1225
+ let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
1226
+ const size$1 = require_backend.prod(inShape);
1227
+ stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
1228
+ a = reshapeViews(a, (st) => st.compose(stX), true);
1229
+ const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
1230
+ return new require_backend.Kernel(nargs, size$1, a, reduction);
1231
+ },
987
1232
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
988
1233
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
989
1234
  const c = k1.exp;
@@ -993,6 +1238,14 @@ const jitRules = {
993
1238
  axis: [cs.ndim - 1]
994
1239
  });
995
1240
  },
1241
+ [Primitive.Conv](nargs, [a, b], [as, bs], params) {
1242
+ const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1243
+ a = reshapeViews(a, (st) => st.compose(stX));
1244
+ b = reshapeViews(b, (st) => st.compose(stY));
1245
+ as = new ShapedArray(stX.shape, as.dtype);
1246
+ bs = new ShapedArray(stY.shape, bs.dtype);
1247
+ return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1248
+ },
996
1249
  // Comparison and selection are elementwise over the broadcastedJit() operands.
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
  // Transpose only permutes the view's shape tracker; no data is moved.
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
@@ -1005,8 +1258,20 @@ const jitRules = {
1005
1258
  }),
1006
1259
  [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1007
1260
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1008
- [Primitive.Gather]() {
1009
- throw new Error("Gather is not implemented in JIT yet");
1261
+ [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1262
+ const axisSet = new Set(axis);
1263
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1264
+ const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1265
+ finalShape.splice(outDim, 0, ...indexShape);
1266
+ const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
1267
+ const idxNonaxis = [...idxAll];
1268
+ idxNonaxis.splice(outDim, indexShape.length);
1269
+ const src = [...idxNonaxis];
1270
+ for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1271
+ for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
1272
+ const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
1273
+ if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1274
+ return new require_backend.Kernel(nargs, require_backend.prod(finalShape), x.substitute({ gidx: index }));
1010
1275
  },
1011
1276
  [Primitive.JitCall]() {
1012
1277
  throw new Error("internal: JitCall should have been flattened before JIT compilation");
@@ -1025,9 +1290,15 @@ function splitGraphDataflow(backend, jaxpr) {
1025
1290
  blackNodes.add(v);
1026
1291
  p1NextBlack.set(v, v);
1027
1292
  }
1293
+ const reducePrimitives = [
1294
+ Primitive.Reduce,
1295
+ Primitive.Dot,
1296
+ Primitive.Conv,
1297
+ Primitive.PoolTranspose
1298
+ ];
1028
1299
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1029
1300
  const eqn = jaxpr.eqns[i];
1030
- if (eqn.primitive === Primitive.Reduce || eqn.outBinders.some((v) => blackNodes.has(v))) {
1301
+ if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1031
1302
  for (const v of eqn.outBinders) {
1032
1303
  blackNodes.add(v);
1033
1304
  p1NextBlack.set(v, v);
@@ -1168,7 +1439,7 @@ var Array$1 = class Array$1 extends Tracer {
1168
1439
  * is a backend `Slot`, this constructor _takes ownership_ of the slot. It
1169
1440
  * will be freed when the array is disposed.
1170
1441
  */
1171
- constructor(source, st, dtype, backend, pending = null) {
1442
+ constructor(source, st, dtype, backend, { pending = null } = {}) {
1172
1443
  super(baseArrayTrace);
1173
1444
  this.id = Array$1.#nextId++;
1174
1445
  this.#dtype = dtype;
@@ -1177,6 +1448,8 @@ var Array$1 = class Array$1 extends Tracer {
1177
1448
  this.#backend = backend;
1178
1449
  this.#rc = 1;
1179
1450
  this.#pendingSet = new Set(pending);
1451
+ if (this.#pendingSet.size === 0) this.#pendingSet = null;
1452
+ else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
1180
1453
  }
1181
1454
  /** @ignore */
1182
1455
  get aval() {
@@ -1231,7 +1504,7 @@ var Array$1 = class Array$1 extends Tracer {
1231
1504
  const pending = this.#pending;
1232
1505
  for (const exe of pending) exe.updateRc(1);
1233
1506
  if (typeof this.#source === "number") this.#backend.incRef(this.#source);
1234
- const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
1507
+ const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
1235
1508
  this.dispose();
1236
1509
  return ar;
1237
1510
  }
@@ -1254,7 +1527,7 @@ var Array$1 = class Array$1 extends Tracer {
1254
1527
  const inputs = [];
1255
1528
  const src = [...idxNonaxis];
1256
1529
  for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1257
- for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#dtype, ar.#source, ar.#st, idxAxis));
1530
+ for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#source, ar.#st, idxAxis));
1258
1531
  else {
1259
1532
  let gid = inputs.indexOf(ar.#source);
1260
1533
  if (gid === -1) {
@@ -1264,7 +1537,7 @@ var Array$1 = class Array$1 extends Tracer {
1264
1537
  src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
1265
1538
  }
1266
1539
  let exp$2;
1267
- if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#dtype, this.#source, this.#st, src);
1540
+ if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, src);
1268
1541
  else {
1269
1542
  let gid = inputs.indexOf(this.#source);
1270
1543
  if (gid === -1) {
@@ -1280,7 +1553,7 @@ var Array$1 = class Array$1 extends Tracer {
1280
1553
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1281
1554
  this.dispose();
1282
1555
  for (const ar of indices) ar.dispose();
1283
- return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
1556
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
1284
1557
  }
1285
1558
  /** Move axes to the rightmost dimension of the shape. */
1286
1559
  #moveAxesDown(axis) {
@@ -1307,7 +1580,7 @@ var Array$1 = class Array$1 extends Tracer {
1307
1580
  this.#check();
1308
1581
  if (this.#source instanceof require_backend.AluExp) {
1309
1582
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1310
- return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
1583
+ return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1311
1584
  }
1312
1585
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1313
1586
  const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1317,7 +1590,7 @@ var Array$1 = class Array$1 extends Tracer {
1317
1590
  for (const exe of pending) exe.updateRc(1);
1318
1591
  pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
1319
1592
  this.dispose();
1320
- return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
1593
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
1321
1594
  }
1322
1595
  #binary(op, other) {
1323
1596
  const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
@@ -1340,18 +1613,18 @@ var Array$1 = class Array$1 extends Tracer {
1340
1613
  if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1341
1614
  arrays = Array$1.#broadcastArrays(arrays);
1342
1615
  const newShape = [...arrays[0].shape];
1343
- if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && reduceAxis === void 0) {
1616
+ if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
1344
1617
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1345
1618
  const exp$4 = custom(arrays.map((ar) => ar.#source));
1346
- return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
1619
+ return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1347
1620
  }
1348
1621
  const exp$3 = custom(arrays.map((ar) => {
1349
1622
  const src$1 = ar.#source;
1350
1623
  if (ar.#st.contiguous) return src$1;
1351
- return require_backend.accessorAluExp(ar.#dtype, src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1624
+ return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1352
1625
  }));
1353
1626
  const st = require_backend.ShapeTracker.fromShape(newShape);
1354
- return new Array$1(exp$3, st, exp$3.dtype, backend);
1627
+ return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1355
1628
  }
1356
1629
  let indices;
1357
1630
  if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
@@ -1361,7 +1634,7 @@ var Array$1 = class Array$1 extends Tracer {
1361
1634
  }
1362
1635
  const inputs = [];
1363
1636
  const src = [];
1364
- for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#dtype, ar.#source, ar.#st, indices));
1637
+ for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
1365
1638
  else {
1366
1639
  let gid = inputs.indexOf(ar.#source);
1367
1640
  if (gid === -1) {
@@ -1382,7 +1655,7 @@ var Array$1 = class Array$1 extends Tracer {
1382
1655
  for (const exe of pending) exe.updateRc(1);
1383
1656
  pending.add(new PendingExecute(backend, kernel, inputs, [output]));
1384
1657
  for (const ar of arrays) ar.dispose();
1385
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
1658
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
1386
1659
  }
1387
1660
  /** Reduce the last dimension of the array by an operation. */
1388
1661
  #reduce(op) {
@@ -1395,7 +1668,7 @@ var Array$1 = class Array$1 extends Tracer {
1395
1668
  const indices = [...require_backend.unravelAlu(newShape, require_backend.AluVar.gidx), require_backend.AluVar.ridx];
1396
1669
  let exp$2;
1397
1670
  const inputs = [];
1398
- if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1671
+ if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
1399
1672
  else {
1400
1673
  inputs.push(this.#source);
1401
1674
  exp$2 = require_backend.accessorGlobal(this.#dtype, 0, this.#st, indices);
@@ -1406,7 +1679,7 @@ var Array$1 = class Array$1 extends Tracer {
1406
1679
  for (const exe of pending) exe.updateRc(1);
1407
1680
  pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
1408
1681
  this.dispose();
1409
- return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
1682
+ return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
1410
1683
  }
1411
1684
  /**
1412
1685
  * Normalizes this array into one backed by a `Slot`.
@@ -1420,7 +1693,7 @@ var Array$1 = class Array$1 extends Tracer {
1420
1693
  this.#check();
1421
1694
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1422
1695
  if (this.#source instanceof require_backend.AluExp) {
1423
- const exp$2 = require_backend.accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1696
+ const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
1424
1697
  const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
1425
1698
  const output = this.#backend.malloc(kernel.bytes);
1426
1699
  const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
@@ -1458,42 +1731,54 @@ var Array$1 = class Array$1 extends Tracer {
1458
1731
  }
1459
1732
  /** Realize the array and return it as data. */
1460
1733
  async data() {
1461
- if (this.#source instanceof require_backend.AluExp && require_backend.prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1734
+ if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1462
1735
  this.#realize();
1463
1736
  const pending = this.#pending;
1464
1737
  if (pending) {
1465
1738
  await Promise.all(pending.map((p) => p.prepare()));
1466
1739
  for (const p of pending) p.submit();
1467
1740
  }
1468
- const byteCount = require_backend.byteWidth(this.#dtype) * require_backend.prod(this.shape);
1741
+ const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
1469
1742
  const buf = await this.#backend.read(this.#source, 0, byteCount);
1470
1743
  this.dispose();
1471
1744
  return require_backend.dtypedArray(this.dtype, buf);
1472
1745
  }
1473
- /** Wait for this array to be placed on the backend, if needed. */
1474
- async wait() {
1746
+ /**
1747
+ * Wait for this array to finish evaluation.
1748
+ *
1749
+ * Operations and data loading in jax-js are lazy, so this function ensures
1750
+ * that pending operations are dispatched and fully executed before it
1751
+ * returns.
1752
+ *
1753
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
1754
+ * dispatch of operations as well.
1755
+ *
1756
+ * **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
1757
+ * asynchronously for multiple arrays.
1758
+ */
1759
+ async blockUntilReady() {
1475
1760
  this.#check();
1476
- if (this.#source instanceof require_backend.AluExp) return;
1761
+ if (this.#source instanceof require_backend.AluExp) return this;
1477
1762
  const pending = this.#pending;
1478
1763
  if (pending) {
1479
1764
  await Promise.all(pending.map((p) => p.prepare()));
1480
1765
  for (const p of pending) p.submit();
1481
1766
  }
1482
1767
  await this.#backend.read(this.#source, 0, 0);
1483
- this.dispose();
1768
+ return this;
1484
1769
  }
1485
1770
  /**
1486
1771
  * Realize the array and return it as data. This is a sync variant and not
1487
1772
  * recommended for performance reasons, as it will block rendering.
1488
1773
  */
1489
1774
  dataSync() {
1490
- if (this.#source instanceof require_backend.AluExp && require_backend.prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1775
+ if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1491
1776
  this.#realize();
1492
1777
  for (const p of this.#pending) {
1493
1778
  p.prepareSync();
1494
1779
  p.submit();
1495
1780
  }
1496
- const byteCount = require_backend.byteWidth(this.#dtype) * require_backend.prod(this.shape);
1781
+ const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
1497
1782
  const buf = this.#backend.readSync(this.#source, 0, byteCount);
1498
1783
  this.dispose();
1499
1784
  return require_backend.dtypedArray(this.dtype, buf);
@@ -1514,6 +1799,16 @@ var Array$1 = class Array$1 extends Tracer {
1514
1799
  async jsAsync() {
1515
1800
  return dataToJs(this.dtype, await this.data(), this.shape);
1516
1801
  }
1802
+ /**
1803
+ * Copy an element of an array to a numeric scalar and return it.
1804
+ *
1805
+ * Throws an error if the array does not have a single element. The array must
1806
+ * either be rank-0, or all dimensions of the shape are 1.
1807
+ */
1808
+ item() {
1809
+ if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
1810
+ return this.dataSync()[0];
1811
+ }
1517
1812
  /** @private Internal plumbing method for Array / Tracer ops. */
1518
1813
  static _implRules() {
1519
1814
  return {
@@ -1527,7 +1822,7 @@ var Array$1 = class Array$1 extends Tracer {
1527
1822
  return [x.#binary(require_backend.AluOp.Idiv, y)];
1528
1823
  },
1529
1824
  [Primitive.Neg]([x]) {
1530
- return [zerosLike(x).#binary(require_backend.AluOp.Sub, x)];
1825
+ return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
1531
1826
  },
1532
1827
  [Primitive.Reciprocal]([x]) {
1533
1828
  return [x.#unary(require_backend.AluOp.Reciprocal)];
@@ -1547,7 +1842,7 @@ var Array$1 = class Array$1 extends Tracer {
1547
1842
  x.#backend.incRef(x.#source);
1548
1843
  const pending = x.#pending;
1549
1844
  for (const exe of pending) exe.updateRc(1);
1550
- const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
1845
+ const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
1551
1846
  x.dispose();
1552
1847
  return [y];
1553
1848
  }
@@ -1577,12 +1872,21 @@ var Array$1 = class Array$1 extends Tracer {
1577
1872
  [Primitive.Cos]([x]) {
1578
1873
  return [x.#unary(require_backend.AluOp.Cos)];
1579
1874
  },
1875
+ [Primitive.Asin]([x]) {
1876
+ return [x.#unary(require_backend.AluOp.Asin)];
1877
+ },
1878
+ [Primitive.Atan]([x]) {
1879
+ return [x.#unary(require_backend.AluOp.Atan)];
1880
+ },
1580
1881
  [Primitive.Exp]([x]) {
1581
1882
  return [x.#unary(require_backend.AluOp.Exp)];
1582
1883
  },
1583
1884
  [Primitive.Log]([x]) {
1584
1885
  return [x.#unary(require_backend.AluOp.Log)];
1585
1886
  },
1887
+ [Primitive.Sqrt]([x]) {
1888
+ return [x.#unary(require_backend.AluOp.Sqrt)];
1889
+ },
1586
1890
  [Primitive.Min]([x, y]) {
1587
1891
  return [x.#binary(require_backend.AluOp.Min, y)];
1588
1892
  },
@@ -1593,9 +1897,24 @@ var Array$1 = class Array$1 extends Tracer {
1593
1897
  if (axis.length === 0) return [x];
1594
1898
  return [x.#moveAxesDown(axis).#reduce(op)];
1595
1899
  },
1900
+ [Primitive.Pool]([x], { window, strides }) {
1901
+ const st = pool(x.#st, window, strides);
1902
+ return [x.#reshape(st)];
1903
+ },
1904
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
1905
+ const n = inShape.length;
1906
+ let st = poolTranspose(x.#st, inShape, window, strides);
1907
+ st = st.reshape([...st.shape.slice(0, n), require_backend.prod(st.shape.slice(n))]);
1908
+ return [x.#reshape(st).#reduce(require_backend.AluOp.Add)];
1909
+ },
1596
1910
  [Primitive.Dot]([x, y]) {
1597
1911
  return [Array$1.#naryCustom("dot", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
1598
1912
  },
1913
+ [Primitive.Conv]([x, y], params) {
1914
+ checkConvShape(x.shape, y.shape, params);
1915
+ const [stX, stY] = prepareConv(x.#st, y.#st, params);
1916
+ return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
1917
+ },
1599
1918
  [Primitive.Compare]([x, y], { op }) {
1600
1919
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1601
1920
  return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
@@ -1644,7 +1963,7 @@ var Array$1 = class Array$1 extends Tracer {
1644
1963
  pending.splice(0, 0, ...prevPending);
1645
1964
  args.forEach((x) => x.dispose());
1646
1965
  return outputs.map((source, i) => {
1647
- return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
1966
+ return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
1648
1967
  });
1649
1968
  }
1650
1969
  };
@@ -1660,6 +1979,7 @@ function scalar(value, { dtype, device } = {}) {
1660
1979
  dtype ??= require_backend.DType.Float32;
1661
1980
  if (![
1662
1981
  require_backend.DType.Float32,
1982
+ require_backend.DType.Float16,
1663
1983
  require_backend.DType.Int32,
1664
1984
  require_backend.DType.Uint32
1665
1985
  ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
@@ -1667,6 +1987,7 @@ function scalar(value, { dtype, device } = {}) {
1667
1987
  dtype ??= require_backend.DType.Bool;
1668
1988
  if (![
1669
1989
  require_backend.DType.Float32,
1990
+ require_backend.DType.Float16,
1670
1991
  require_backend.DType.Int32,
1671
1992
  require_backend.DType.Uint32,
1672
1993
  require_backend.DType.Bool
@@ -1680,7 +2001,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1680
2001
  if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1681
2002
  if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
1682
2003
  return values;
1683
- } else if (values instanceof Float32Array || values instanceof Int32Array) return arrayFromData(values, shape$1 ?? [values.length], {
2004
+ } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1684
2005
  dtype,
1685
2006
  device
1686
2007
  });
@@ -1709,7 +2030,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1709
2030
  });
1710
2031
  } else {
1711
2032
  dtype = dtype ?? require_backend.DType.Float32;
1712
- const data = require_backend.dtypedArray(dtype, flat);
2033
+ const data = require_backend.dtypedJsArray(dtype, flat);
1713
2034
  return arrayFromData(data, shape$1, {
1714
2035
  dtype,
1715
2036
  device
@@ -1730,19 +2051,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
1730
2051
  });
1731
2052
  }
1732
2053
  const backend = require_backend.getBackend(device);
1733
- if (data instanceof Float32Array) {
1734
- if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
1735
- const slot = backend.malloc(data.byteLength, data.buffer);
1736
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), require_backend.DType.Float32, backend);
1737
- } else if (data instanceof Int32Array) {
1738
- if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
1739
- const slot = backend.malloc(data.byteLength, data.buffer);
1740
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype ?? require_backend.DType.Int32, backend);
1741
- } else if (data instanceof Uint32Array) {
1742
- if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
1743
- const slot = backend.malloc(data.byteLength, data.buffer);
1744
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), require_backend.DType.Uint32, backend);
1745
- } else throw new Error("Unsupported data type");
2054
+ if (ArrayBuffer.isView(data)) {
2055
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2056
+ if (data instanceof Float32Array) {
2057
+ if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2058
+ dtype ??= require_backend.DType.Float32;
2059
+ } else if (data instanceof Int32Array) {
2060
+ if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2061
+ dtype ??= require_backend.DType.Int32;
2062
+ } else if (data instanceof Uint32Array) {
2063
+ if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2064
+ dtype ??= require_backend.DType.Uint32;
2065
+ } else if (data instanceof Float16Array) {
2066
+ if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2067
+ dtype ??= require_backend.DType.Float16;
2068
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2069
+ const slot = backend.malloc(data.byteLength, buf);
2070
+ return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
2071
+ } else throw new Error("Unsupported data type: " + data.constructor.name);
1746
2072
  }
1747
2073
  function dataToJs(dtype, data, shape$1) {
1748
2074
  if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
@@ -1769,9 +2095,20 @@ var EvalTrace = class extends Trace {
1769
2095
  };
1770
2096
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
1771
2097
  const implRules = Array$1._implRules();
1772
- function zerosLike(val) {
2098
+ function zerosLike$1(val, dtype) {
1773
2099
  const aval = getAval(val);
1774
- return zeros(aval.shape, { dtype: aval.dtype });
2100
+ if (val instanceof Tracer) val.dispose();
2101
+ return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2102
+ }
2103
+ function onesLike$1(val, dtype) {
2104
+ const aval = getAval(val);
2105
+ if (val instanceof Tracer) val.dispose();
2106
+ return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2107
+ }
2108
+ function fullLike(val, fillValue, dtype) {
2109
+ const aval = getAval(val);
2110
+ if (val instanceof Tracer) val.dispose();
2111
+ return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
1775
2112
  }
1776
2113
  /** Return a new array of given shape and type, filled with zeros. */
1777
2114
  function zeros(shape$1, { dtype, device } = {}) {
@@ -1793,6 +2130,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
1793
2130
  if (typeof fillValue === "number") {
1794
2131
  dtype = dtype ?? require_backend.DType.Float32;
1795
2132
  source = require_backend.AluExp.const(dtype, fillValue);
2133
+ } else if (typeof fillValue === "bigint") {
2134
+ dtype = dtype ?? require_backend.DType.Int32;
2135
+ source = require_backend.AluExp.const(dtype, Number(fillValue));
1796
2136
  } else if (typeof fillValue === "boolean") {
1797
2137
  dtype = dtype ?? require_backend.DType.Bool;
1798
2138
  source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
@@ -1823,7 +2163,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
1823
2163
  const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
1824
2164
  return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
1825
2165
  }
1826
- /** Return the identity array, with ones on the main diagonal. */
2166
+ /** Return the identity matrix, with ones on the main diagonal. */
1827
2167
  function identity$1(n, { dtype, device } = {}) {
1828
2168
  return eye(n, n, {
1829
2169
  dtype,
@@ -1890,7 +2230,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
1890
2230
  const st = require_backend.ShapeTracker.fromShape([num]);
1891
2231
  return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
1892
2232
  }
1893
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
1894
2233
  function aluCompare(a, b, op) {
1895
2234
  switch (op) {
1896
2235
  case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
@@ -1932,8 +2271,8 @@ function generalBroadcast(a, b) {
1932
2271
  }
1933
2272
 
1934
2273
  //#endregion
1935
- //#region node_modules/.pnpm/@oxc-project+runtime@0.77.3/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
1936
- var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.77.3/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js"(exports, module) {
2274
+ //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
2275
+ var require_usingCtx = /* @__PURE__ */ __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js": ((exports, module) => {
1937
2276
  function _usingCtx() {
1938
2277
  var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
1939
2278
  var n$1 = Error();
@@ -1989,11 +2328,11 @@ var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.7
1989
2328
  };
1990
2329
  }
1991
2330
  module.exports = _usingCtx, module.exports.__esModule = true, module.exports["default"] = module.exports;
1992
- } });
2331
+ }) });
1993
2332
 
1994
2333
  //#endregion
1995
2334
  //#region src/frontend/jaxpr.ts
1996
- var import_usingCtx$2 = __toESM(require_usingCtx(), 1);
2335
+ var import_usingCtx$2 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
1997
2336
  /** Variable in a Jaxpr expression. */
1998
2337
  var Var = class Var {
1999
2338
  static #nextId = 1;
@@ -2004,7 +2343,7 @@ var Var = class Var {
2004
2343
  this.aval = aval;
2005
2344
  }
2006
2345
  toString() {
2007
- return `Var(${this.id}):${this.aval.strShort()}`;
2346
+ return `Var(${this.id}):${this.aval.toString()}`;
2008
2347
  }
2009
2348
  };
2010
2349
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
@@ -2044,7 +2383,7 @@ var VarPrinter = class {
2044
2383
  return name;
2045
2384
  }
2046
2385
  nameType(v) {
2047
- return `${this.name(v)}:${v.aval.strShort()}`;
2386
+ return `${this.name(v)}:${v.aval.toString()}`;
2048
2387
  }
2049
2388
  };
2050
2389
  /** A single statement / binding in a Jaxpr, in ANF form. */
@@ -2104,16 +2443,19 @@ var Jaxpr = class Jaxpr {
2104
2443
  varIds.set(v, require_backend.FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
2105
2444
  return id;
2106
2445
  };
2107
- hasher.update(this.inBinders.length, ...this.inBinders.map(vi));
2108
- hasher.update(this.eqns.length, ...this.eqns.flatMap((eqn) => [
2109
- eqn.primitive,
2110
- eqn.inputs.length,
2111
- ...eqn.inputs.map((x) => x instanceof Var ? vi(x) : x.value),
2112
- JSON.stringify(eqn.params),
2113
- eqn.outBinders.length,
2114
- ...eqn.outBinders.map(vi)
2115
- ]));
2116
- hasher.update(this.outs.length, ...this.outs.map((x) => x instanceof Var ? vi(x) : x.value));
2446
+ hasher.update(this.inBinders.length);
2447
+ for (const x of this.inBinders) hasher.update(vi(x));
2448
+ hasher.update(this.eqns.length);
2449
+ for (const eqn of this.eqns) {
2450
+ hasher.update(eqn.primitive);
2451
+ hasher.update(eqn.inputs.length);
2452
+ for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
2453
+ hasher.update(JSON.stringify(eqn.params));
2454
+ hasher.update(eqn.outBinders.length);
2455
+ for (const x of eqn.outBinders) hasher.update(vi(x));
2456
+ }
2457
+ hasher.update(this.outs.length);
2458
+ for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
2117
2459
  return this.#hash = hasher.value;
2118
2460
  }
2119
2461
  hash(state) {
@@ -2150,7 +2492,7 @@ var Jaxpr = class Jaxpr {
2150
2492
  const c = eqn.outBinders[0];
2151
2493
  if (atomIsLit(b, 1)) context.set(c, a);
2152
2494
  else newEqns.push(eqn);
2153
- } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2495
+ } else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
2154
2496
  else newEqns.push(eqn);
2155
2497
  }
2156
2498
  const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
@@ -2199,8 +2541,8 @@ var JaxprType = class {
2199
2541
  this.outTypes = outTypes;
2200
2542
  }
2201
2543
  toString() {
2202
- const inTypes = this.inTypes.map((aval) => aval.strShort()).join(", ");
2203
- const outTypes = this.outTypes.map((aval) => aval.strShort()).join(", ");
2544
+ const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
2545
+ const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
2204
2546
  return `(${inTypes}) -> (${outTypes})`;
2205
2547
  }
2206
2548
  };
@@ -2279,7 +2621,7 @@ var JaxprTracer = class extends Tracer {
2279
2621
  this.aval = aval;
2280
2622
  }
2281
2623
  toString() {
2282
- return `JaxprTracer(${this.aval.strShort()})`;
2624
+ return `JaxprTracer(${this.aval.toString()})`;
2283
2625
  }
2284
2626
  get ref() {
2285
2627
  return this;
@@ -2416,8 +2758,11 @@ const abstractEvalRules = {
2416
2758
  },
2417
2759
  [Primitive.Sin]: vectorizedUnopAbstractEval,
2418
2760
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2761
+ [Primitive.Asin]: vectorizedUnopAbstractEval,
2762
+ [Primitive.Atan]: vectorizedUnopAbstractEval,
2419
2763
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2420
2764
  [Primitive.Log]: vectorizedUnopAbstractEval,
2765
+ [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2421
2766
  [Primitive.Min]: binopAbstractEval,
2422
2767
  [Primitive.Max]: binopAbstractEval,
2423
2768
  [Primitive.Reduce]([x], { axis }) {
@@ -2425,6 +2770,15 @@ const abstractEvalRules = {
2425
2770
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2426
2771
  return [new ShapedArray(newShape, x.dtype)];
2427
2772
  },
2773
+ [Primitive.Pool]([x], { window, strides }) {
2774
+ const shape$1 = checkPoolShape(x.shape, window, strides);
2775
+ return [new ShapedArray(shape$1, x.dtype)];
2776
+ },
2777
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2778
+ const shape$1 = checkPoolShape(inShape, window, strides);
2779
+ if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2780
+ return [new ShapedArray(inShape, x.dtype)];
2781
+ },
2428
2782
  [Primitive.Dot]([x, y]) {
2429
2783
  if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2430
2784
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
@@ -2432,6 +2786,11 @@ const abstractEvalRules = {
2432
2786
  shape$1.splice(-1, 1);
2433
2787
  return [new ShapedArray(shape$1, x.dtype)];
2434
2788
  },
2789
+ [Primitive.Conv]([lhs, rhs], params) {
2790
+ if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2791
+ const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2792
+ return [new ShapedArray(shape$1, lhs.dtype)];
2793
+ },
2435
2794
  [Primitive.Compare]: compareAbstractEval,
2436
2795
  [Primitive.Where]([cond, x, y]) {
2437
2796
  if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
@@ -2479,15 +2838,34 @@ const abstractEvalRules = {
2479
2838
  return outTypes;
2480
2839
  }
2481
2840
  };
2482
- function makeJaxpr$1(f) {
2841
+ function splitIdx(values, argnums) {
2842
+ const a = [];
2843
+ const b = [];
2844
+ for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
2845
+ else b.push(values[i]);
2846
+ return [a, b];
2847
+ }
2848
+ function joinIdx(n, a, b, argnums) {
2849
+ const result = [];
2850
+ let ai = 0;
2851
+ let bi = 0;
2852
+ for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
2853
+ else result.push(b[bi++]);
2854
+ return result;
2855
+ }
2856
+ function makeJaxpr$1(f, opts) {
2483
2857
  return (...argsIn) => {
2484
2858
  try {
2485
2859
  var _usingCtx$1 = (0, import_usingCtx$2.default)();
2486
- const [avalsIn, inTree] = flatten(argsIn);
2487
- const [fFlat, outTree] = flattenFun(f, inTree);
2860
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2861
+ const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
2862
+ const [avalsIn, inTree] = flatten(shapedArgs);
2863
+ const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
2864
+ return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
2865
+ }, inTree);
2488
2866
  const builder = new JaxprBuilder();
2489
2867
  const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
2490
- const _dynamic = _usingCtx$1.u(newDynamic(main));
2868
+ _usingCtx$1.u(newDynamic(main));
2491
2869
  const trace = new JaxprTrace(main);
2492
2870
  const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
2493
2871
  const outs = fFlat(...tracersIn);
@@ -2506,25 +2884,32 @@ function makeJaxpr$1(f) {
2506
2884
  }
2507
2885
  };
2508
2886
  }
2509
- function jit$1(f) {
2887
+ function jit$1(f, opts) {
2510
2888
  const cache = /* @__PURE__ */ new Map();
2511
- return ((...args) => {
2512
- const [argsFlat, inTree] = flatten(args);
2889
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2890
+ const result = ((...args) => {
2891
+ const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2892
+ const [argsFlat, inTree] = flatten(dynamicArgs);
2513
2893
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
2514
2894
  const avalsIn = unflatten(inTree, avalsInFlat);
2515
- const cacheKey = JSON.stringify(avalsIn);
2516
- const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f)(...avalsIn));
2895
+ const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
2896
+ const cacheKey = JSON.stringify(jaxprArgs);
2897
+ const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2517
2898
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2518
2899
  jaxpr,
2519
2900
  numConsts: consts.length
2520
2901
  });
2521
2902
  return unflatten(outTree, outs);
2522
2903
  });
2904
+ result.dispose = () => {
2905
+ for (const { consts } of cache.values()) for (const c of consts) c.dispose();
2906
+ };
2907
+ return result;
2523
2908
  }
2524
2909
 
2525
2910
  //#endregion
2526
2911
  //#region src/frontend/jvp.ts
2527
- var import_usingCtx$1 = __toESM(require_usingCtx(), 1);
2912
+ var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
2528
2913
  var JVPTracer = class extends Tracer {
2529
2914
  constructor(trace, primal, tangent) {
2530
2915
  super(trace);
@@ -2551,7 +2936,7 @@ var JVPTrace = class extends Trace {
2551
2936
  return this.lift(pureArray(val));
2552
2937
  }
2553
2938
  lift(val) {
2554
- return new JVPTracer(this, val, zerosLike(val));
2939
+ return new JVPTracer(this, val, zerosLike$1(val.ref));
2555
2940
  }
2556
2941
  processPrimitive(primitive, tracers, params) {
2557
2942
  const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2569,19 +2954,25 @@ function linearTangentsJvp(primitive) {
2569
2954
  return [ys, dys];
2570
2955
  };
2571
2956
  }
2957
+ /** Rule for product of gradients in bilinear operations. */
2958
+ function bilinearTangentsJvp(primitive) {
2959
+ return ([x, y], [dx, dy], params) => {
2960
+ const primal = bind1(primitive, [x.ref, y.ref], params);
2961
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
2962
+ return [[primal], [tangent]];
2963
+ };
2964
+ }
2572
2965
  /** Rule that zeros out any tangents. */
2573
2966
  function zeroTangentsJvp(primitive) {
2574
2967
  return (primals, tangents, params) => {
2575
2968
  for (const t of tangents) t.dispose();
2576
2969
  const ys = bind(primitive, primals, params);
2577
- return [ys, ys.map((y) => zerosLike(y))];
2970
+ return [ys, ys.map((y) => zerosLike$1(y.ref))];
2578
2971
  };
2579
2972
  }
2580
2973
  const jvpRules = {
2581
2974
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
2582
- [Primitive.Mul]([x, y], [dx, dy]) {
2583
- return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
2584
- },
2975
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
2585
2976
  [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
2586
2977
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
2587
2978
  [Primitive.Reciprocal]([x], [dx]) {
@@ -2594,13 +2985,13 @@ const jvpRules = {
2594
2985
  if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2595
2986
  else {
2596
2987
  dx.dispose();
2597
- return [[cast(x, dtype)], [zerosLike(x)]];
2988
+ return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
2598
2989
  }
2599
2990
  },
2600
2991
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2601
2992
  if (x.dtype === dtype) return [[x], [dx]];
2602
2993
  dx.dispose();
2603
- return [[bitcast(x, dtype)], [zerosLike(x)]];
2994
+ return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
2604
2995
  },
2605
2996
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2606
2997
  [Primitive.Sin]([x], [dx]) {
@@ -2609,6 +3000,14 @@ const jvpRules = {
2609
3000
  [Primitive.Cos]([x], [dx]) {
2610
3001
  return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
2611
3002
  },
3003
+ [Primitive.Asin]([x], [dx]) {
3004
+ const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
3005
+ return [[asin$1(x)], [denom.mul(dx)]];
3006
+ },
3007
+ [Primitive.Atan]([x], [dx]) {
3008
+ const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
3009
+ return [[atan$1(x)], [dx.div(denom)]];
3010
+ },
2612
3011
  [Primitive.Exp]([x], [dx]) {
2613
3012
  const z = exp$1(x);
2614
3013
  return [[z.ref], [z.mul(dx)]];
@@ -2616,6 +3015,10 @@ const jvpRules = {
2616
3015
  [Primitive.Log]([x], [dx]) {
2617
3016
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
2618
3017
  },
3018
+ [Primitive.Sqrt]([x], [dx]) {
3019
+ const z = sqrt$1(x);
3020
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
3021
+ },
2619
3022
  [Primitive.Min]([x, y], [dx, dy]) {
2620
3023
  return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
2621
3024
  },
@@ -2632,13 +3035,14 @@ const jvpRules = {
2632
3035
  const primal = reduce(x.ref, op, axis);
2633
3036
  const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
2634
3037
  const minCount = where$1(notMin.ref, 0, 1).sum(axis);
2635
- const tangent = where$1(notMin, op === require_backend.AluOp.Min ? 0 : 0, dx).sum(axis).mul(reciprocal$1(minCount));
3038
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
2636
3039
  return [[primal], [tangent]];
2637
3040
  } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
2638
3041
  },
2639
- [Primitive.Dot]([x, y], [dx, dy]) {
2640
- return [[dot$1(x.ref, y.ref)], [dot$1(dx, y).add(dot$1(x, dy))]];
2641
- },
3042
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3043
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3044
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3045
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
2642
3046
  [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
2643
3047
  [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
2644
3048
  dcond.dispose();
@@ -2711,7 +3115,7 @@ function jvp$1(f, primals, tangents) {
2711
3115
 
2712
3116
  //#endregion
2713
3117
  //#region src/frontend/vmap.ts
2714
- var import_usingCtx = __toESM(require_usingCtx(), 1);
3118
+ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
2715
3119
  function mappedAval(batchDim, aval) {
2716
3120
  const shape$1 = [...aval.shape];
2717
3121
  shape$1.splice(batchDim, 1);
@@ -2720,7 +3124,10 @@ function mappedAval(batchDim, aval) {
2720
3124
  /** Move one axis to a different index. */
2721
3125
  function moveaxis$1(x, src, dst) {
2722
3126
  const t = pureArray(x);
2723
- const perm = require_backend.range(t.shape.length);
3127
+ src = require_backend.checkAxis(src, t.ndim);
3128
+ dst = require_backend.checkAxis(dst, t.ndim);
3129
+ if (src === dst) return t;
3130
+ const perm = require_backend.range(t.ndim);
2724
3131
  perm.splice(src, 1);
2725
3132
  perm.splice(dst, 0, src);
2726
3133
  return transpose$1(t, perm);
@@ -2813,8 +3220,11 @@ const vmapRules = {
2813
3220
  [Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
2814
3221
  [Primitive.Sin]: unopBatcher(sin$1),
2815
3222
  [Primitive.Cos]: unopBatcher(cos$1),
3223
+ [Primitive.Asin]: unopBatcher(asin$1),
3224
+ [Primitive.Atan]: unopBatcher(atan$1),
2816
3225
  [Primitive.Exp]: unopBatcher(exp$1),
2817
3226
  [Primitive.Log]: unopBatcher(log$1),
3227
+ [Primitive.Sqrt]: unopBatcher(sqrt$1),
2818
3228
  [Primitive.Min]: broadcastBatcher(min$1),
2819
3229
  [Primitive.Max]: broadcastBatcher(max$1),
2820
3230
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
@@ -2951,7 +3361,7 @@ var PartialVal = class PartialVal {
2951
3361
  return this.val !== null;
2952
3362
  }
2953
3363
  toString() {
2954
- return this.val ? this.val.toString() : this.aval.strShort();
3364
+ return this.val ? this.val.toString() : this.aval.toString();
2955
3365
  }
2956
3366
  };
2957
3367
  function partialEvalFlat(f, pvalsIn) {
@@ -2997,20 +3407,28 @@ function linearizeFlatUtil(f, primalsIn) {
2997
3407
  function linearizeFlat(f, primalsIn) {
2998
3408
  const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
2999
3409
  const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
3000
- return [primalsOut, fLin];
3410
+ const dispose$1 = () => {
3411
+ for (const c of consts) c.dispose();
3412
+ };
3413
+ return [
3414
+ primalsOut,
3415
+ fLin,
3416
+ dispose$1
3417
+ ];
3001
3418
  }
3002
3419
  function linearize$1(f, ...primalsIn) {
3003
3420
  const [primalsInFlat, inTree] = flatten(primalsIn);
3004
3421
  const [fFlat, outTree] = flattenFun(f, inTree);
3005
- const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3422
+ const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
3006
3423
  if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
3007
3424
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3008
- const fLin = (...tangentsIn) => {
3425
+ const fLin = ((...tangentsIn) => {
3009
3426
  const [tangentsInFlat, inTree2] = flatten(tangentsIn);
3010
3427
  if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
3011
3428
  const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
3012
3429
  return unflatten(outTree.value, tangentsOutFlat);
3013
- };
3430
+ });
3431
+ fLin.dispose = dispose$1;
3014
3432
  return [primalsOut, fLin];
3015
3433
  }
3016
3434
  var PartialEvalTracer = class extends Tracer {
@@ -3126,7 +3544,10 @@ var PartialEvalTrace = class extends Trace {
3126
3544
  avalsOut: jaxpr2.outs.map((x) => x.aval),
3127
3545
  tracerRefsOut: []
3128
3546
  };
3129
- const outs2 = jaxpr2.outs.map((x) => new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe));
3547
+ const outs2 = jaxpr2.outs.map((x, i$1) => {
3548
+ if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
3549
+ return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
3550
+ });
3130
3551
  recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
3131
3552
  let i = 0;
3132
3553
  let j = 0;
@@ -3210,13 +3631,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
3210
3631
  const [consts, constvars] = require_backend.unzip2(constToVar.entries());
3211
3632
  const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
3212
3633
  const outVars = tracersOut.map((t) => tracerToVar.get(t));
3213
- const jaxpr = new Jaxpr(inBinders, eqns, outVars);
3634
+ let jaxpr = new Jaxpr(inBinders, eqns, outVars);
3214
3635
  typecheckJaxpr(jaxpr);
3215
3636
  for (const t of consts) t.ref;
3216
3637
  for (const t of tracersIn) t.dispose();
3217
3638
  for (const t of tracersOut) t.dispose();
3639
+ jaxpr = jaxpr.simplify();
3640
+ if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
3218
3641
  return {
3219
- jaxpr: jaxpr.simplify(),
3642
+ jaxpr,
3220
3643
  consts
3221
3644
  };
3222
3645
  }
@@ -3325,12 +3748,72 @@ const transposeRules = {
3325
3748
  if (op === require_backend.AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
3326
3749
  else throw new NonlinearError(Primitive.Reduce);
3327
3750
  },
3751
+ [Primitive.Pool]([ct], [x], { window, strides }) {
3752
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
3753
+ return bind(Primitive.PoolTranspose, [ct], {
3754
+ inShape: x.aval.shape,
3755
+ window,
3756
+ strides
3757
+ });
3758
+ },
3759
+ [Primitive.PoolTranspose]([ct], [x], { window, strides }) {
3760
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
3761
+ return bind(Primitive.Pool, [ct], {
3762
+ window,
3763
+ strides
3764
+ });
3765
+ },
3328
3766
  [Primitive.Dot]([ct], [x, y]) {
3329
3767
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3330
3768
  const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3331
3769
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3332
3770
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3333
3771
  },
3772
+ [Primitive.Conv]([ct], [lhs, rhs], params) {
3773
+ if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
3774
+ const rev01 = [
3775
+ 1,
3776
+ 0,
3777
+ ...require_backend.range(2, ct.ndim)
3778
+ ];
3779
+ if (lhs instanceof UndefPrimal) {
3780
+ let kernel = rhs;
3781
+ kernel = transpose$1(kernel, rev01);
3782
+ kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
3783
+ const result = conv(ct, kernel, {
3784
+ strides: params.lhsDilation,
3785
+ padding: params.padding.map(([pl, _pr], i) => {
3786
+ const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3787
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3788
+ const padBefore = dilatedKernel - 1 - pl;
3789
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3790
+ const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
3791
+ return [padBefore, padAfter];
3792
+ }),
3793
+ lhsDilation: params.strides,
3794
+ rhsDilation: params.rhsDilation
3795
+ });
3796
+ return [result, null];
3797
+ } else {
3798
+ const newLhs = transpose$1(lhs, rev01);
3799
+ const newRhs = transpose$1(ct, rev01);
3800
+ let result = conv(newLhs, newRhs, {
3801
+ strides: params.rhsDilation,
3802
+ padding: params.padding.map(([pl, _pr], i) => {
3803
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3804
+ const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3805
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3806
+ const padFromLhs = dilatedCt - dilatedLhs;
3807
+ const padFromRhs = dilatedKernel - pl - 1;
3808
+ return [pl, padFromLhs + padFromRhs];
3809
+ }),
3810
+ lhsDilation: params.lhsDilation,
3811
+ rhsDilation: params.strides
3812
+ });
3813
+ result = transpose$1(result, rev01);
3814
+ return [null, result];
3815
+ }
3816
+ },
3334
3817
  [Primitive.Where]([ct], [cond, x, y]) {
3335
3818
  const cts = [
3336
3819
  null,
@@ -3422,20 +3905,28 @@ function vjpFlat(f, primalsIn) {
3422
3905
  const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
3423
3906
  return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
3424
3907
  };
3425
- return [primalsOut, fVjp];
3908
+ const dispose$1 = () => {
3909
+ for (const c of consts) c.dispose();
3910
+ };
3911
+ return [
3912
+ primalsOut,
3913
+ fVjp,
3914
+ dispose$1
3915
+ ];
3426
3916
  }
3427
3917
  function vjp$1(f, ...primalsIn) {
3428
3918
  const [primalsInFlat, inTree] = flatten(primalsIn);
3429
3919
  const [fFlat, outTree] = flattenFun(f, inTree);
3430
- const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3920
+ const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
3431
3921
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
3432
3922
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
3433
- const fVjp = (cotangentsOut) => {
3923
+ const fVjp = ((cotangentsOut) => {
3434
3924
  const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
3435
3925
  if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
3436
3926
  const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
3437
3927
  return unflatten(inTree, cotangentsInFlat);
3438
- };
3928
+ });
3929
+ fVjp.dispose = dispose$1;
3439
3930
  return [primalsOut, fVjp];
3440
3931
  }
3441
3932
  function grad$1(f) {
@@ -3451,9 +3942,10 @@ function valueAndGrad$1(f) {
3451
3942
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
3452
3943
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3453
3944
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3454
- if (y.dtype !== require_backend.DType.Float32) throw new TypeError("grad currently only supports float32");
3455
- const [ct, ...rest] = fVjp(pureArray(1));
3456
- for (const r of rest) r.dispose();
3945
+ if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3946
+ const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3947
+ for (const r of rest) dispose(r);
3948
+ fVjp.dispose();
3457
3949
  return [y, ct];
3458
3950
  };
3459
3951
  }
@@ -3461,11 +3953,84 @@ function jacrev$1(f) {
3461
3953
  return function jacobianReverse(x) {
3462
3954
  if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
3463
3955
  const [size$1] = x.shape;
3464
- const pullback = (ct) => vjp$1(f, x)[1](ct)[0];
3956
+ const pullback = (ct) => {
3957
+ const [y, fVjp] = vjp$1(f, x);
3958
+ y.dispose();
3959
+ const [ret] = fVjp(ct);
3960
+ fVjp.dispose();
3961
+ return ret;
3962
+ };
3465
3963
  return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
3466
3964
  };
3467
3965
  }
3468
3966
 
3967
+ //#endregion
3968
+ //#region src/lax.ts
3969
+ var lax_exports = {};
3970
+ __export(lax_exports, {
3971
+ conv: () => conv$1,
3972
+ convGeneralDilated: () => convGeneralDilated,
3973
+ convWithGeneralPadding: () => convWithGeneralPadding,
3974
+ reduceWindow: () => reduceWindow
3975
+ });
3976
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
3977
+ const padType = padding.toUpperCase();
3978
+ switch (padType) {
3979
+ case "VALID": return require_backend.rep(inShape.length, [0, 0]);
3980
+ case "SAME":
3981
+ case "SAME_LOWER": {
3982
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
3983
+ const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
3984
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
3985
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
3986
+ }
3987
+ default: throw new Error(`Unknown padding type: ${padType}`);
3988
+ }
3989
+ }
3990
+ /**
3991
+ * General n-dimensional convolution operator, with optional dilation.
3992
+ *
3993
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
3994
+ * function in JAX, which wraps XLA's general convolution operator.
3995
+ *
3996
+ * Grouped convolutions are not supported right now.
3997
+ */
3998
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
3999
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
4000
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
4001
+ if (typeof padding === "string") {
4002
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
4003
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
4004
+ }
4005
+ return conv(lhs, rhs, {
4006
+ strides: windowStrides,
4007
+ padding,
4008
+ lhsDilation,
4009
+ rhsDilation
4010
+ });
4011
+ }
4012
+ /** Convenience wrapper around `convGeneralDilated`. */
4013
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
4014
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
4015
+ lhsDilation,
4016
+ rhsDilation
4017
+ });
4018
+ }
4019
+ /** Convenience wrapper around `convGeneralDilated`. */
4020
+ function conv$1(lhs, rhs, windowStrides, padding) {
4021
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
4022
+ }
4023
+ /** Reduce a computation over padded windows. */
4024
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
4025
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
4026
+ if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
4027
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
4028
+ return computation(bind1(Primitive.Pool, [operand], {
4029
+ window: windowDimensions,
4030
+ strides: windowStrides
4031
+ }));
4032
+ }
4033
+
3469
4034
  //#endregion
3470
4035
  //#region src/numpy.ts
3471
4036
  var numpy_exports = {};
@@ -3474,19 +4039,38 @@ __export(numpy_exports, {
3474
4039
  DType: () => require_backend.DType,
3475
4040
  abs: () => abs,
3476
4041
  absolute: () => absolute,
4042
+ acos: () => acos,
4043
+ acosh: () => acosh,
3477
4044
  add: () => add,
3478
4045
  allclose: () => allclose,
3479
4046
  arange: () => arange,
4047
+ arccos: () => arccos,
4048
+ arccosh: () => arccosh,
4049
+ arcsinh: () => arcsinh,
4050
+ arctan: () => arctan,
4051
+ arctan2: () => arctan2,
4052
+ arctanh: () => arctanh,
3480
4053
  argmax: () => argmax,
3481
4054
  argmin: () => argmin,
3482
4055
  array: () => array,
4056
+ asin: () => asin,
4057
+ asinh: () => asinh,
3483
4058
  astype: () => astype,
4059
+ atan: () => atan,
4060
+ atan2: () => atan2,
4061
+ atanh: () => atanh,
3484
4062
  bool: () => bool,
4063
+ broadcastArrays: () => broadcastArrays,
4064
+ broadcastShapes: () => broadcastShapes,
4065
+ broadcastTo: () => broadcastTo,
4066
+ cbrt: () => cbrt,
3485
4067
  clip: () => clip,
3486
4068
  columnStack: () => columnStack,
3487
- complex64: () => complex64,
3488
4069
  concatenate: () => concatenate,
3489
4070
  cos: () => cos,
4071
+ cosh: () => cosh,
4072
+ deg2rad: () => deg2rad,
4073
+ degrees: () => degrees,
3490
4074
  diag: () => diag,
3491
4075
  diagonal: () => diagonal,
3492
4076
  divide: () => divide,
@@ -3497,23 +4081,29 @@ __export(numpy_exports, {
3497
4081
  eulerGamma: () => eulerGamma,
3498
4082
  exp: () => exp,
3499
4083
  exp2: () => exp2,
4084
+ expm1: () => expm1,
3500
4085
  eye: () => eye,
3501
4086
  flip: () => flip,
3502
4087
  fliplr: () => fliplr,
3503
4088
  flipud: () => flipud,
4089
+ float16: () => float16,
3504
4090
  float32: () => float32,
3505
4091
  full: () => full,
4092
+ fullLike: () => fullLike$1,
3506
4093
  greater: () => greater,
3507
4094
  greaterEqual: () => greaterEqual,
3508
4095
  hstack: () => hstack,
4096
+ hypot: () => hypot,
3509
4097
  identity: () => identity$1,
3510
4098
  inf: () => inf,
4099
+ inner: () => inner,
3511
4100
  int32: () => int32,
3512
4101
  less: () => less,
3513
4102
  lessEqual: () => lessEqual,
3514
4103
  linspace: () => linspace,
3515
4104
  log: () => log,
3516
4105
  log10: () => log10,
4106
+ log1p: () => log1p,
3517
4107
  log2: () => log2,
3518
4108
  matmul: () => matmul,
3519
4109
  max: () => max,
@@ -3529,36 +4119,55 @@ __export(numpy_exports, {
3529
4119
  negative: () => negative,
3530
4120
  notEqual: () => notEqual,
3531
4121
  ones: () => ones,
4122
+ onesLike: () => onesLike,
4123
+ outer: () => outer,
3532
4124
  pad: () => pad,
3533
4125
  permuteDims: () => permuteDims,
3534
4126
  pi: () => pi,
4127
+ pow: () => pow,
4128
+ power: () => power,
3535
4129
  prod: () => prod$1,
4130
+ promoteTypes: () => require_backend.promoteTypes,
4131
+ rad2deg: () => rad2deg,
4132
+ radians: () => radians,
3536
4133
  ravel: () => ravel,
3537
4134
  reciprocal: () => reciprocal,
4135
+ repeat: () => repeat,
3538
4136
  reshape: () => reshape,
3539
- scalar: () => scalar,
3540
4137
  shape: () => shape,
4138
+ sign: () => sign,
3541
4139
  sin: () => sin,
4140
+ sinh: () => sinh,
3542
4141
  size: () => size,
4142
+ sqrt: () => sqrt,
3543
4143
  square: () => square,
3544
4144
  stack: () => stack,
4145
+ std: () => std,
4146
+ subtract: () => subtract,
3545
4147
  sum: () => sum,
3546
4148
  tan: () => tan,
4149
+ tanh: () => tanh,
4150
+ tile: () => tile,
3547
4151
  transpose: () => transpose,
4152
+ tri: () => tri,
4153
+ tril: () => tril,
4154
+ triu: () => triu,
3548
4155
  trueDivide: () => trueDivide,
3549
4156
  trunc: () => trunc,
3550
4157
  uint32: () => uint32,
4158
+ var_: () => var_,
3551
4159
  vdot: () => vdot,
3552
4160
  vecdot: () => vecdot,
3553
4161
  vstack: () => vstack,
3554
4162
  where: () => where,
3555
- zeros: () => zeros
4163
+ zeros: () => zeros,
4164
+ zerosLike: () => zerosLike
3556
4165
  });
3557
4166
  const float32 = require_backend.DType.Float32;
3558
4167
  const int32 = require_backend.DType.Int32;
3559
4168
  const uint32 = require_backend.DType.Uint32;
3560
4169
  const bool = require_backend.DType.Bool;
3561
- const complex64 = require_backend.DType.Complex64;
4170
+ const float16 = require_backend.DType.Float16;
3562
4171
  /** Euler's constant, `e = 2.7182818284590...` */
3563
4172
  const e = Math.E;
3564
4173
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -3569,52 +4178,66 @@ const inf = Number.POSITIVE_INFINITY;
3569
4178
  const nan = NaN;
3570
4179
  /** This is Pi, `π = 3.14159265358979...` */
3571
4180
  const pi = Math.PI;
3572
- /** Element-wise addition, with broadcasting. */
4181
+ /** @function Element-wise addition, with broadcasting. */
3573
4182
  const add = add$1;
3574
- /** Element-wise multiplication, with broadcasting. */
4183
+ /** @function Element-wise multiplication, with broadcasting. */
3575
4184
  const multiply = mul;
3576
- /** Numerical negative of every element of an array. */
4185
+ /** @function Numerical negative of every element of an array. */
3577
4186
  const negative = neg;
3578
- /** Calculate element-wise reciprocal of the input. This is `1/x`. */
4187
+ /** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
3579
4188
  const reciprocal = reciprocal$1;
3580
- /** Element-wise sine function (takes radians). */
4189
+ /** @function Element-wise sine function (takes radians). */
3581
4190
  const sin = sin$1;
3582
- /** Element-wise cosine function (takes radians). */
4191
+ /** @function Element-wise cosine function (takes radians). */
3583
4192
  const cos = cos$1;
3584
- /** Calculate the exponential of all elements in the input array. */
4193
+ /** @function Element-wise inverse sine function (inverse of sin). */
4194
+ const asin = asin$1;
4195
+ /** @function Element-wise inverse tangent function (inverse of tan). */
4196
+ const atan = atan$1;
4197
+ /** @function Calculate the exponential of all elements in the input array. */
3585
4198
  const exp = exp$1;
3586
- /** Calculate the natural logarithm of all elements in the input array. */
4199
+ /** @function Calculate the natural logarithm of all elements in the input array. */
3587
4200
  const log = log$1;
3588
- /** Return element-wise minimum of the input arrays. */
4201
+ /** @function Calculate the square root of all elements in the input array. */
4202
+ const sqrt = sqrt$1;
4203
+ /** @function Return element-wise minimum of the input arrays. */
3589
4204
  const minimum = min$1;
3590
- /** Return element-wise maximum of the input arrays. */
4205
+ /** @function Return element-wise maximum of the input arrays. */
3591
4206
  const maximum = max$1;
3592
- /** Compare two arrays element-wise. */
4207
+ /** @function Compare two arrays element-wise. */
3593
4208
  const greater = greater$1;
3594
- /** Compare two arrays element-wise. */
4209
+ /** @function Compare two arrays element-wise. */
3595
4210
  const less = less$1;
3596
- /** Compare two arrays element-wise. */
4211
+ /** @function Compare two arrays element-wise. */
3597
4212
  const equal = equal$1;
3598
- /** Compare two arrays element-wise. */
4213
+ /** @function Compare two arrays element-wise. */
3599
4214
  const notEqual = notEqual$1;
3600
- /** Compare two arrays element-wise. */
4215
+ /** @function Compare two arrays element-wise. */
3601
4216
  const greaterEqual = greaterEqual$1;
3602
- /** Compare two arrays element-wise. */
4217
+ /** @function Compare two arrays element-wise. */
3603
4218
  const lessEqual = lessEqual$1;
3604
- /** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
4219
+ /** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
3605
4220
  const where = where$1;
3606
- /** Permute the dimensions of an array. Defaults to reversing the axis order. */
4221
+ /**
4222
+ * @function
4223
+ * Permute the dimensions of an array. Defaults to reversing the axis order.
4224
+ */
3607
4225
  const transpose = transpose$1;
3608
4226
  /**
4227
+ * @function
3609
4228
  * Give a new shape to an array without changing its data.
3610
4229
  *
3611
4230
  * One shape dimension can be -1. In this case, the value is inferred from the
3612
4231
  * length of the array and remaining dimensions.
3613
4232
  */
3614
4233
  const reshape = reshape$1;
3615
- /** Move axes of an array to new positions. Other axes retain original order. */
4234
+ /**
4235
+ * @function
4236
+ * Move axes of an array to new positions. Other axes retain original order.
4237
+ */
3616
4238
  const moveaxis = moveaxis$1;
3617
4239
  /**
4240
+ * @function
3618
4241
  * Add padding (zeros) to an array.
3619
4242
  *
3620
4243
  * The `width` argument is either an integer or pair of integers, in which case
@@ -3622,11 +4245,29 @@ const moveaxis = moveaxis$1;
3622
4245
  * pair specifies the padding for its corresponding axis.
3623
4246
  */
3624
4247
  const pad = pad$1;
3625
- /** Return the number of dimensions of an array. Does not consume array reference. */
4248
+ /**
4249
+ * @function
4250
+ * Return the number of dimensions of an array. Does not consume array reference.
4251
+ */
3626
4252
  const ndim = ndim$1;
3627
- /** Return the shape of an array. Does not consume array reference. */
4253
+ /** @function Return the shape of an array. Does not consume array reference. */
3628
4254
  const shape = getShape;
3629
4255
  /**
4256
+ * @function
4257
+ * Return an array of zeros with the same shape and type as a given array.
4258
+ */
4259
+ const zerosLike = zerosLike$1;
4260
+ /**
4261
+ * @function
4262
+ * Return an array of ones with the same shape and type as a given array.
4263
+ */
4264
+ const onesLike = onesLike$1;
4265
+ /**
4266
+ * @function
4267
+ * Return a full array with the same shape and type as a given array.
4268
+ */
4269
+ const fullLike$1 = fullLike;
4270
+ /**
3630
4271
  * Return the number of elements in an array, optionally along an axis.
3631
4272
  * Does not consume array reference.
3632
4273
  */
@@ -3639,23 +4280,23 @@ function astype(a, dtype) {
3639
4280
  return fudgeArray(a).astype(dtype);
3640
4281
  }
3641
4282
  /** Sum of the elements of the array over a given axis, or axes. */
3642
- function sum(a, axis, opts) {
4283
+ function sum(a, axis = null, opts) {
3643
4284
  return reduce(a, require_backend.AluOp.Add, axis, opts);
3644
4285
  }
3645
4286
  /** Product of the array elements over a given axis. */
3646
- function prod$1(a, axis, opts) {
4287
+ function prod$1(a, axis = null, opts) {
3647
4288
  return reduce(a, require_backend.AluOp.Mul, axis, opts);
3648
4289
  }
3649
4290
  /** Return the minimum of array elements along a given axis. */
3650
- function min(a, axis, opts) {
4291
+ function min(a, axis = null, opts) {
3651
4292
  return reduce(a, require_backend.AluOp.Min, axis, opts);
3652
4293
  }
3653
4294
  /** Return the maximum of array elements along a given axis. */
3654
- function max(a, axis, opts) {
4295
+ function max(a, axis = null, opts) {
3655
4296
  return reduce(a, require_backend.AluOp.Max, axis, opts);
3656
4297
  }
3657
4298
  /** Compute the average of the array elements along the specified axis. */
3658
- function mean(a, axis, opts) {
4299
+ function mean(a, axis = null, opts) {
3659
4300
  return fudgeArray(a).mean(axis, opts);
3660
4301
  }
3661
4302
  /**
@@ -3671,18 +4312,12 @@ function argmin(a, axis, opts) {
3671
4312
  axis = 0;
3672
4313
  } else axis = require_backend.checkAxis(axis, a.ndim);
3673
4314
  const shape$1 = a.shape;
3674
- const isMax = equal(a, min(a.ref, axis, { keepDims: true }));
4315
+ const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
3675
4316
  const length = scalar(shape$1[axis], {
3676
4317
  dtype: int32,
3677
4318
  device: a.device
3678
4319
  });
3679
- const idx = where(isMax, scalar(1, {
3680
- dtype: int32,
3681
- device: a.device
3682
- }), scalar(0, {
3683
- dtype: int32,
3684
- device: a.device
3685
- })).mul(arange(shape$1[axis], 0, -1, {
4320
+ const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3686
4321
  dtype: int32,
3687
4322
  device: a.device
3688
4323
  }).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
@@ -3701,35 +4336,21 @@ function argmax(a, axis, opts) {
3701
4336
  axis = 0;
3702
4337
  } else axis = require_backend.checkAxis(axis, a.ndim);
3703
4338
  const shape$1 = a.shape;
3704
- const isMax = equal(a, max(a.ref, axis, { keepDims: true }));
4339
+ const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
3705
4340
  const length = scalar(shape$1[axis], {
3706
4341
  dtype: int32,
3707
4342
  device: a.device
3708
4343
  });
3709
- const idx = where(isMax, scalar(1, {
3710
- dtype: int32,
3711
- device: a.device
3712
- }), scalar(0, {
3713
- dtype: int32,
3714
- device: a.device
3715
- })).mul(arange(shape$1[axis], 0, -1, {
4344
+ const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3716
4345
  dtype: int32,
3717
4346
  device: a.device
3718
4347
  }).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
3719
4348
  return length.sub(max(idx, axis, opts));
3720
4349
  }
3721
4350
  /** Reverse the elements in an array along the given axes. */
3722
- function flip(x, axis) {
4351
+ function flip(x, axis = null) {
3723
4352
  const nd = ndim(x);
3724
- if (axis === void 0) axis = require_backend.range(nd);
3725
- else if (typeof axis === "number") axis = [axis];
3726
- const seen = /* @__PURE__ */ new Set();
3727
- for (let i = 0; i < axis.length; i++) {
3728
- if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
3729
- if (axis[i] < 0) axis[i] += nd;
3730
- if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
3731
- seen.add(axis[i]);
3732
- }
4353
+ axis = require_backend.normalizeAxis(axis, nd);
3733
4354
  return flip$1(x, axis);
3734
4355
  }
3735
4356
  /**
@@ -3835,18 +4456,88 @@ function flipud(x) {
3835
4456
  function fliplr(x) {
3836
4457
  return flip(x, 1);
3837
4458
  }
4459
+ /** @function Alternative name for `numpy.transpose()`. */
3838
4460
  const permuteDims = transpose;
3839
4461
  /** Return a 1-D flattened array containing the elements of the input. */
3840
4462
  function ravel(a) {
3841
4463
  return fudgeArray(a).ravel();
3842
4464
  }
3843
4465
  /**
4466
+ * Repeat each element of an array after themselves.
4467
+ *
4468
+ * If no axis is provided, use the flattened input array, and return a flat
4469
+ * output array.
4470
+ */
4471
+ function repeat(a, repeats, axis) {
4472
+ if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
4473
+ a = fudgeArray(a);
4474
+ if (axis === void 0) {
4475
+ a = ravel(a);
4476
+ axis = 0;
4477
+ }
4478
+ axis = require_backend.checkAxis(axis, a.ndim);
4479
+ if (repeats === 1) return a;
4480
+ const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
4481
+ const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
4482
+ return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
4483
+ }
4484
+ /**
4485
+ * Construct an array by repeating A the number of times given by reps.
4486
+ *
4487
+ * If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
4488
+ * integers, the resulting array will have a shape of `(reps[0] * d1,
4489
+ * reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
4490
+ */
4491
+ function tile(a, reps) {
4492
+ a = fudgeArray(a);
4493
+ if (typeof reps === "number") reps = [reps];
4494
+ if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
4495
+ const ndiff = reps.length - a.ndim;
4496
+ if (ndiff > 0) a = a.reshape([...require_backend.rep(ndiff, 1), ...a.shape]);
4497
+ if (ndiff < 0) reps = [...require_backend.rep(-ndiff, 1), ...reps];
4498
+ const broadcastedShape = [];
4499
+ const broadcastAxes = [];
4500
+ for (let i = 0; i < a.ndim; i++) {
4501
+ if (reps[i] > 1) {
4502
+ broadcastedShape.push(reps[i]);
4503
+ broadcastAxes.push(broadcastedShape.length - 1);
4504
+ }
4505
+ broadcastedShape.push(a.shape[i]);
4506
+ }
4507
+ const finalShape = a.shape.map((d, i) => reps[i] * d);
4508
+ return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
4509
+ }
4510
+ /**
4511
+ * Broadcast an array to a shape, with NumPy-style broadcasing rules.
4512
+ *
4513
+ * In other words, this lets you append axes to the left, and/or expand
4514
+ * dimensions where the shape is 1.
4515
+ */
4516
+ function broadcastTo(a, shape$1) {
4517
+ const nd = ndim(a);
4518
+ if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
4519
+ return broadcast(a, shape$1, require_backend.range(shape$1.length - nd));
4520
+ }
4521
+ /** Broadcast input shapes to a common output shape. */
4522
+ function broadcastShapes(...shapes) {
4523
+ if (shapes.length === 0) return [];
4524
+ return shapes.reduce(generalBroadcast);
4525
+ }
4526
+ /** Broadcast arrays to a common shape. */
4527
+ function broadcastArrays(...arrays) {
4528
+ const shapes = arrays.map((a) => shape(a));
4529
+ const outShape = broadcastShapes(...shapes);
4530
+ return arrays.map((a) => broadcastTo(a, outShape));
4531
+ }
4532
+ /**
3844
4533
  * Return specified diagonals.
3845
4534
  *
3846
4535
  * If a is 2D, return the diagonal of the array with the given offset. If a is
3847
- * 3D or higher, compute diagonals along the two given axes.
4536
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
3848
4537
  *
3849
- * This returns a view over the existing array.
4538
+ * This returns a view over the existing array. The shape of the resulting array
4539
+ * is determined by removing the two axes along which the diagonal is taken,
4540
+ * then appending a new axis to the right with holding the diagonals.
3850
4541
  */
3851
4542
  function diagonal(a, offset, axis1, axis2) {
3852
4543
  return fudgeArray(a).diagonal(offset, axis1, axis2);
@@ -3862,15 +4553,16 @@ function diag(v, k = 0) {
3862
4553
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
3863
4554
  if (a.ndim === 1) {
3864
4555
  const n = a.shape[0];
3865
- const ret = where(eye(n).equal(1), a, 0);
3866
- if (k !== 0) throw new Error("diag() for 1D arrays only for k=0");
3867
- return ret;
4556
+ const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
4557
+ if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4558
+ else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4559
+ else return ret;
3868
4560
  } else if (a.ndim === 2) return diagonal(a, k);
3869
4561
  else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
3870
4562
  }
3871
4563
  /** Return if two arrays are element-wise equal within a tolerance. */
3872
4564
  function allclose(actual, expected, options) {
3873
- const { rtol = 1e-5, atol = 1e-8 } = options ?? {};
4565
+ const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
3874
4566
  const x = array(actual);
3875
4567
  const y = array(expected);
3876
4568
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
@@ -3905,8 +4597,36 @@ function dot(x, y) {
3905
4597
  ]);
3906
4598
  return dot$1(x, y);
3907
4599
  }
3908
- /** Vector dot product of two arrays. */
3909
- function vecdot(x, y) {
4600
+ /**
4601
+ * Compute the inner product of two arrays.
4602
+ *
4603
+ * Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
4604
+ * contraction on the last axis.
4605
+ *
4606
+ * Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
4607
+ */
4608
+ function inner(x, y) {
4609
+ x = reshape(x, shape(x).toSpliced(-1, 0, ...require_backend.rep(ndim(y) - 1, 1)));
4610
+ return dot$1(x, y);
4611
+ }
4612
+ /**
4613
+ * Compute the outer product of two arrays.
4614
+ *
4615
+ * If the input arrays are not 1D, they will be flattened. Returned array will
4616
+ * be of shape `[x.size, y.size]`.
4617
+ */
4618
+ function outer(x, y) {
4619
+ x = ravel(x);
4620
+ y = ravel(y);
4621
+ return multiply(x.reshape([x.shape[0], 1]), y);
4622
+ }
4623
+ /** Vector dot product of two arrays along a given axis. */
4624
+ function vecdot(x, y, { axis } = {}) {
4625
+ const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
4626
+ const yaxis = require_backend.checkAxis(axis ?? -1, ndim(y));
4627
+ if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
4628
+ x = moveaxis(x, xaxis, -1);
4629
+ y = moveaxis(y, yaxis, -1);
3910
4630
  return dot$1(x, y);
3911
4631
  }
3912
4632
  /**
@@ -3915,7 +4635,7 @@ function vecdot(x, y) {
3915
4635
  * Like vecdot() but flattens the arguments first into vectors.
3916
4636
  */
3917
4637
  function vdot(x, y) {
3918
- return vecdot(ravel(x), ravel(y));
4638
+ return dot$1(ravel(x), ravel(y));
3919
4639
  }
3920
4640
  /**
3921
4641
  * Return a tuple of coordinate matrices from coordinate vectors.
@@ -3944,6 +4664,43 @@ function meshgrid(xs, { indexing } = {}) {
3944
4664
  return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
3945
4665
  }
3946
4666
  /**
4667
+ * Return an array with ones on and below the diagonal and zeros elsewhere.
4668
+ *
4669
+ * If `k` is provided, it specifies the sub-diagonal on and below which the
4670
+ * array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
4671
+ * `k>0` is above it.
4672
+ */
4673
+ function tri(n, m, k = 0, { dtype, device } = {}) {
4674
+ m ??= n;
4675
+ dtype ??= require_backend.DType.Float32;
4676
+ if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
4677
+ if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
4678
+ if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
4679
+ const rows = arange(k, n + k, 1, {
4680
+ dtype: require_backend.DType.Int32,
4681
+ device
4682
+ });
4683
+ const cols = arange(0, m, 1, {
4684
+ dtype: require_backend.DType.Int32,
4685
+ device
4686
+ });
4687
+ return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
4688
+ }
4689
+ /** Return the lower triangle of an array. Must be of dimension >= 2. */
4690
+ function tril(a, k = 0) {
4691
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4692
+ a = fudgeArray(a);
4693
+ const [n, m] = a.shape.slice(-2);
4694
+ return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
4695
+ }
4696
+ /** Return the upper triangle of an array. Must be of dimension >= 2. */
4697
+ function triu(a, k = 0) {
4698
+ if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
4699
+ a = fudgeArray(a);
4700
+ const [n, m] = a.shape.slice(-2);
4701
+ return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
4702
+ }
4703
+ /**
3947
4704
  * Clip (limit) the values in an array.
3948
4705
  *
3949
4706
  * Given an interval, values outside the interval are clipped to the interval
@@ -3967,18 +4724,70 @@ function absolute(x) {
3967
4724
  x = fudgeArray(x);
3968
4725
  return where(less(x.ref, 0), x.ref.mul(-1), x);
3969
4726
  }
3970
- /** Alias of `jax.numpy.absolute()`. */
4727
+ /** @function Alias of `jax.numpy.absolute()`. */
3971
4728
  const abs = absolute;
4729
+ /** Return an element-wise indication of sign of the input. */
4730
+ function sign(x) {
4731
+ x = fudgeArray(x);
4732
+ return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
4733
+ }
3972
4734
  /** Calculate element-wise square of the input array. */
3973
4735
  function square(x) {
3974
4736
  x = fudgeArray(x);
3975
4737
  return x.ref.mul(x);
3976
4738
  }
3977
- /** Compute a trigonometric tangent of each element of input. */
4739
+ /** Element-wise tangent function (takes radians). */
3978
4740
  function tan(x) {
3979
4741
  x = fudgeArray(x);
3980
4742
  return sin(x.ref).div(cos(x));
3981
4743
  }
4744
+ /** Element-wise inverse cosine function (inverse of cos). */
4745
+ function acos(x) {
4746
+ return subtract(pi / 2, asin(x));
4747
+ }
4748
+ /**
4749
+ * @function
4750
+ * Return element-wise hypotenuse for the given legs of a right triangle.
4751
+ *
4752
+ * In the original NumPy/JAX implementation, this function is more numerically
4753
+ * stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
4754
+ * improvements.
4755
+ */
4756
+ const hypot = jit$1((x1, x2) => {
4757
+ return sqrt(square(x1).add(square(x2)));
4758
+ });
4759
+ /**
4760
+ * @function
4761
+ * Element-wise arc tangent of y/x with correct quadrant.
4762
+ *
4763
+ * Returns the angle in radians between the positive x-axis and the point (x, y).
4764
+ * The result is in the range [-π, π].
4765
+ *
4766
+ * Uses numerically stable formulas:
4767
+ * - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
4768
+ * - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
4769
+ *
4770
+ * The output is ill-defined when both x and y are zero.
4771
+ */
4772
+ const atan2 = jit$1((y, x) => {
4773
+ const r = sqrt(square(x.ref).add(square(y.ref)));
4774
+ const xNeg = less(x.ref, 0);
4775
+ const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
4776
+ const denom = where(xNeg, y, r.add(x));
4777
+ return atan(numer.div(denom)).mul(2);
4778
+ });
4779
+ /** @function Alias of `jax.numpy.acos()`. */
4780
+ const arccos = acos;
4781
+ /** @function Alias of `jax.numpy.atan()`. */
4782
+ const arctan = atan;
4783
+ /** @function Alias of `jax.numpy.atan2()`. */
4784
+ const arctan2 = atan2;
4785
+ /** Element-wise subtraction, with broadcasting. */
4786
+ function subtract(x, y) {
4787
+ x = fudgeArray(x);
4788
+ y = fudgeArray(y);
4789
+ return x.sub(y);
4790
+ }
3982
4791
  /** Calculates the floating-point division of x by y element-wise. */
3983
4792
  function trueDivide(x, y) {
3984
4793
  x = fudgeArray(x);
@@ -3986,7 +4795,7 @@ function trueDivide(x, y) {
3986
4795
  if (!require_backend.isFloatDtype(x.dtype) || !require_backend.isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
3987
4796
  return x.div(y);
3988
4797
  }
3989
- /** Alias of `jax.numpy.trueDivide()`. */
4798
+ /** @function Alias of `jax.numpy.trueDivide()`. */
3990
4799
  const divide = trueDivide;
3991
4800
  /** Round input to the nearest integer towards zero. */
3992
4801
  function trunc(x) {
@@ -4004,15 +4813,151 @@ function log2(x) {
4004
4813
  function log10(x) {
4005
4814
  return log(x).mul(Math.LOG10E);
4006
4815
  }
4816
+ /** Calculate `exp(x) - 1` element-wise. */
4817
+ function expm1(x) {
4818
+ return exp(x).sub(1);
4819
+ }
4820
+ /** Calculate the natural logarithm of `1 + x` element-wise. */
4821
+ function log1p(x) {
4822
+ return log(add(1, x));
4823
+ }
4824
+ /** Convert angles from degrees to radians. */
4825
+ function deg2rad(x) {
4826
+ return multiply(x, pi / 180);
4827
+ }
4828
+ /** @function Alias of `jax.numpy.deg2rad()`. */
4829
+ const radians = deg2rad;
4830
+ /** Convert angles from radians to degrees. */
4831
+ function rad2deg(x) {
4832
+ return multiply(x, 180 / pi);
4833
+ }
4834
+ /** @function Alias of `jax.numpy.rad2deg()`. */
4835
+ const degrees = rad2deg;
4836
+ /**
4837
+ * @function
4838
+ * Computes first array raised to power of second array, element-wise.
4839
+ */
4840
+ const power = jit$1((x1, x2) => {
4841
+ return exp(log(x1).mul(x2));
4842
+ });
4843
+ /** @function Alias of `jax.numpy.power()`. */
4844
+ const pow = power;
4845
+ /** @function Calculate the element-wise cube root of the input array. */
4846
+ const cbrt = jit$1((x) => {
4847
+ const sgn = where(less(x.ref, 0), -1, 1);
4848
+ return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
4849
+ });
4850
+ /**
4851
+ * @function
4852
+ * Calculate element-wise hyperbolic sine of input.
4853
+ *
4854
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
4855
+ */
4856
+ const sinh = jit$1((x) => {
4857
+ const ex = exp(x);
4858
+ const emx = reciprocal(ex.ref);
4859
+ return ex.sub(emx).mul(.5);
4860
+ });
4861
+ /**
4862
+ * @function
4863
+ * Calculate element-wise hyperbolic cosine of input.
4864
+ *
4865
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
4866
+ */
4867
+ const cosh = jit$1((x) => {
4868
+ const ex = exp(x);
4869
+ const emx = reciprocal(ex.ref);
4870
+ return ex.add(emx).mul(.5);
4871
+ });
4872
+ /**
4873
+ * @function
4874
+ * Calculate element-wise hyperbolic tangent of input.
4875
+ *
4876
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4877
+ */
4878
+ const tanh = jit$1((x) => {
4879
+ const negsgn = where(less(x.ref, 0), 1, -1);
4880
+ const en2x = exp(x.mul(negsgn.ref).mul(2));
4881
+ return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4882
+ });
4883
+ /**
4884
+ * @function
4885
+ * Calculate element-wise inverse hyperbolic sine of input.
4886
+ *
4887
+ * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
4888
+ */
4889
+ const arcsinh = jit$1((x) => {
4890
+ return log(x.ref.add(sqrt(square(x).add(1))));
4891
+ });
4892
+ /**
4893
+ * @function
4894
+ * Calculate element-wise inverse hyperbolic cosine of input.
4895
+ *
4896
+ * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
4897
+ */
4898
+ const arccosh = jit$1((x) => {
4899
+ return log(x.ref.add(sqrt(square(x).sub(1))));
4900
+ });
4901
+ /**
4902
+ * @function
4903
+ * Calculate element-wise inverse hyperbolic tangent of input.
4904
+ *
4905
+ * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
4906
+ */
4907
+ const arctanh = jit$1((x) => {
4908
+ return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
4909
+ });
4910
+ /** @function Alias of `jax.numpy.arcsinh()`. */
4911
+ const asinh = arcsinh;
4912
+ /** @function Alias of `jax.numpy.arccosh()`. */
4913
+ const acosh = arccosh;
4914
+ /** @function Alias of `jax.numpy.arctanh()`. */
4915
+ const atanh = arctanh;
4916
+ /**
4917
+ * Compute the variance of an array.
4918
+ *
4919
+ * The variance is computed for the flattened array by default, otherwise over
4920
+ * the specified axis.
4921
+ *
4922
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
4923
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
4924
+ */
4925
+ function var_(x, axis = null, opts) {
4926
+ x = fudgeArray(x);
4927
+ axis = require_backend.normalizeAxis(axis, x.ndim);
4928
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
4929
+ if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
4930
+ const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
4931
+ return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
4932
+ }
4933
+ /**
4934
+ * Compute the standard deviation of an array.
4935
+ *
4936
+ * The standard deviation is computed for the flattened array by default,
4937
+ * otherwise over the specified axis.
4938
+ *
4939
+ * If `correction` is provided, the divisor in calculation is `N - correction`,
4940
+ * where `N` represents the number of elements (e.g., for Bessel's correction).
4941
+ */
4942
+ function std(x, axis = null, opts) {
4943
+ return sqrt(var_(x, axis, opts));
4944
+ }
4007
4945
 
4008
4946
  //#endregion
4009
4947
  //#region src/nn.ts
4010
4948
  var nn_exports = {};
4011
4949
  __export(nn_exports, {
4950
+ celu: () => celu,
4951
+ elu: () => elu,
4952
+ gelu: () => gelu,
4953
+ glu: () => glu,
4012
4954
  identity: () => identity,
4955
+ leakyRelu: () => leakyRelu,
4013
4956
  logSigmoid: () => logSigmoid,
4014
4957
  logSoftmax: () => logSoftmax,
4958
+ logmeanexp: () => logmeanexp,
4015
4959
  logsumexp: () => logsumexp,
4960
+ mish: () => mish,
4016
4961
  oneHot: () => oneHot,
4017
4962
  relu: () => relu,
4018
4963
  relu6: () => relu6,
@@ -4021,6 +4966,8 @@ __export(nn_exports, {
4021
4966
  softSign: () => softSign,
4022
4967
  softmax: () => softmax,
4023
4968
  softplus: () => softplus,
4969
+ squareplus: () => squareplus,
4970
+ standardize: () => standardize,
4024
4971
  swish: () => swish
4025
4972
  });
4026
4973
  /**
@@ -4064,6 +5011,7 @@ function softSign(x) {
4064
5011
  return x.ref.div(absolute(x).add(1));
4065
5012
  }
4066
5013
  /**
5014
+ * @function
4067
5015
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4068
5016
  * Swish, computed element-wise:
4069
5017
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4072,11 +5020,9 @@ function softSign(x) {
4072
5020
  *
4073
5021
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4074
5022
  */
4075
- function silu(x) {
4076
- x = fudgeArray(x);
4077
- return x.ref.mul(sigmoid(x));
4078
- }
5023
+ const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
4079
5024
  /**
5025
+ * @function
4080
5026
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4081
5027
  * Swish, computed element-wise:
4082
5028
  * `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
@@ -4093,8 +5039,88 @@ const swish = silu;
4093
5039
  function logSigmoid(x) {
4094
5040
  return negative(softplus(negative(x)));
4095
5041
  }
4096
- /** Identity activation function. Returns the argument unmodified. */
5042
+ /**
5043
+ * @function
5044
+ * Identity activation function. Returns the argument unmodified.
5045
+ */
4097
5046
  const identity = fudgeArray;
5047
+ /** Leaky rectified linear (ReLU) activation function */
5048
+ function leakyRelu(x, negativeSlope = .01) {
5049
+ x = fudgeArray(x);
5050
+ return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
5051
+ }
5052
+ /**
5053
+ * Exponential linear unit activation function.
5054
+ *
5055
+ * Computes the element-wise function:
5056
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
5057
+ */
5058
+ function elu(x, alpha = 1) {
5059
+ x = fudgeArray(x);
5060
+ return where(less(x.ref, 0), exp(x.ref).sub(1).mul(alpha), x);
5061
+ }
5062
+ /**
5063
+ * Continuously-differentiable exponential linear unit activation function.
5064
+ *
5065
+ * Computes the element-wise function:
5066
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
5067
+ */
5068
+ function celu(x, alpha = 1) {
5069
+ x = fudgeArray(x);
5070
+ return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
5071
+ }
5072
+ /**
5073
+ * @function
5074
+ * Gaussion error linear unit (GELU) activation function.
5075
+ *
5076
+ * This is computed element-wise. Currently jax-js does not support the erf() or
5077
+ * gelu() functions exactly as primitives, so an approximation is used:
5078
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
5079
+ *
5080
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
5081
+ *
5082
+ * This will be improved in the future.
5083
+ */
5084
+ const gelu = jit$1((x) => {
5085
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
5086
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
5087
+ });
5088
+ /**
5089
+ * Gated linear unit (GLU) activation function.
5090
+ *
5091
+ * Splits the `axis` dimension of the input into two halves, a and b, then
5092
+ * computes `a * sigmoid(b)`.
5093
+ */
5094
+ function glu(x, axis = -1) {
5095
+ x = fudgeArray(x);
5096
+ axis = require_backend.checkAxis(axis, x.ndim);
5097
+ const size$1 = x.shape[axis];
5098
+ if (size$1 % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
5099
+ const slice = x.shape.map((a$1) => [0, a$1]);
5100
+ const a = shrink(x.ref, slice.toSpliced(axis, 1, [0, size$1 / 2]));
5101
+ const b = shrink(x, slice.toSpliced(axis, 1, [size$1 / 2, size$1]));
5102
+ return a.mul(sigmoid(b));
5103
+ }
5104
+ /**
5105
+ * Squareplus activation function.
5106
+ *
5107
+ * Computes the element-wise function:
5108
+ * `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
5109
+ */
5110
+ function squareplus(x, b = 4) {
5111
+ x = fudgeArray(x);
5112
+ return x.ref.add(sqrt(square(x).add(b))).mul(.5);
5113
+ }
5114
+ /**
5115
+ * Mish activation function.
5116
+ *
5117
+ * Computes the element-wise function:
5118
+ * `mish(x) = x * tanh(softplus(x))`
5119
+ */
5120
+ function mish(x) {
5121
+ x = fudgeArray(x);
5122
+ return x.ref.mul(tanh(softplus(x)));
5123
+ }
4098
5124
  /**
4099
5125
  * Softmax function. Computes the function which rescales elements to the range
4100
5126
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -4103,17 +5129,13 @@ const identity = fudgeArray;
4103
5129
  *
4104
5130
  * Reference: https://en.wikipedia.org/wiki/Softmax_function
4105
5131
  */
4106
- function softmax(x, axis) {
5132
+ function softmax(x, axis = -1) {
4107
5133
  x = fudgeArray(x);
4108
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4109
- else if (typeof axis === "number") axis = [axis];
4110
- if (axis.length === 0) {
4111
- x.dispose();
4112
- return ones(x.shape);
4113
- }
4114
- const xMax = max(x.ref, axis, { keepDims: true });
5134
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5135
+ if (axis.length === 0) return onesLike(x);
5136
+ const xMax = max(x.ref, axis, { keepdims: true });
4115
5137
  const unnormalized = exp(x.sub(stopGradient(xMax)));
4116
- return unnormalized.ref.div(unnormalized.sum(axis, { keepDims: true }));
5138
+ return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
4117
5139
  }
4118
5140
  /**
4119
5141
  * Log-Softmax function.
@@ -4123,17 +5145,13 @@ function softmax(x, axis) {
4123
5145
  *
4124
5146
  * If `axis` is not specified, it defaults to the last axis.
4125
5147
  */
4126
- function logSoftmax(x, axis) {
5148
+ function logSoftmax(x, axis = -1) {
4127
5149
  x = fudgeArray(x);
4128
- if (axis === void 0) axis = x.ndim ? [x.ndim - 1] : [];
4129
- else if (typeof axis === "number") axis = [axis];
4130
- if (axis.length === 0) {
4131
- x.dispose();
4132
- return zeros(x.shape);
4133
- }
4134
- const xMax = max(x.ref, axis, { keepDims: true });
5150
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5151
+ if (axis.length === 0) return zerosLike(x);
5152
+ const xMax = max(x.ref, axis, { keepdims: true });
4135
5153
  const shifted = x.sub(stopGradient(xMax));
4136
- const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepDims: true }));
5154
+ const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
4137
5155
  return shifted.sub(shiftedLogsumexp);
4138
5156
  }
4139
5157
  /**
@@ -4144,16 +5162,39 @@ function logSoftmax(x, axis) {
4144
5162
  *
4145
5163
  * Reference: https://en.wikipedia.org/wiki/LogSumExp
4146
5164
  */
4147
- function logsumexp(x, axis) {
5165
+ function logsumexp(x, axis = null) {
4148
5166
  x = fudgeArray(x);
4149
- if (axis === void 0) axis = require_backend.range(x.ndim);
4150
- else if (typeof axis === "number") axis = [axis];
5167
+ axis = require_backend.normalizeAxis(axis, x.ndim);
4151
5168
  if (axis.length === 0) return x;
4152
5169
  const xMax = stopGradient(max(x.ref, axis));
4153
5170
  const xMaxDims = broadcast(xMax.ref, x.shape, axis);
4154
5171
  const shifted = x.sub(xMaxDims);
4155
5172
  return xMax.add(log(exp(shifted).sum(axis)));
4156
5173
  }
5174
+ /** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
5175
+ function logmeanexp(x, axis = null) {
5176
+ x = fudgeArray(x);
5177
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5178
+ if (axis.length === 0) return x;
5179
+ const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
5180
+ return logsumexp(x, axis).sub(Math.log(n));
5181
+ }
5182
+ /**
5183
+ * Standardizes input to zero mean and unit variance.
5184
+ *
5185
+ * By default, this is computed over the last axis. You can pass in a different
5186
+ * axis, or `null` to standardize over all elements.
5187
+ *
5188
+ * Epsilon is added to denominator, it defaults to `1e-5` for stability.
5189
+ */
5190
+ function standardize(x, axis = -1, opts = {}) {
5191
+ x = fudgeArray(x);
5192
+ axis = require_backend.normalizeAxis(axis, x.ndim);
5193
+ if (axis.length === 0) return x;
5194
+ const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
5195
+ const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
5196
+ return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
5197
+ }
4157
5198
  /**
4158
5199
  * One-hot encodes the given indices.
4159
5200
  *
@@ -4171,7 +5212,7 @@ function logsumexp(x, axis) {
4171
5212
  * ```
4172
5213
  */
4173
5214
  function oneHot(x, numClasses) {
4174
- if (x.dtype !== "int32") throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
5215
+ if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4175
5216
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4176
5217
  }
4177
5218
 
@@ -4179,8 +5220,11 @@ function oneHot(x, numClasses) {
4179
5220
  //#region src/random.ts
4180
5221
  var random_exports = {};
4181
5222
  __export(random_exports, {
5223
+ bernoulli: () => bernoulli,
4182
5224
  bits: () => bits,
5225
+ exponential: () => exponential,
4183
5226
  key: () => key,
5227
+ normal: () => normal,
4184
5228
  split: () => split,
4185
5229
  uniform: () => uniform
4186
5230
  });
@@ -4211,11 +5255,11 @@ function bits(key$1, shape$1 = []) {
4211
5255
  /** Sample uniform random values in [minval, maxval) with given shape. */
4212
5256
  function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4213
5257
  if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
4214
- const mantissa = bits(key$1, shape$1).div(scalar(512, {
5258
+ const mantissa = bits(key$1, shape$1).div(array(512, {
4215
5259
  dtype: require_backend.DType.Uint32,
4216
5260
  device: key$1.device
4217
5261
  }));
4218
- const float12 = mantissa.add(scalar(1065353216, {
5262
+ const float12 = mantissa.add(array(1065353216, {
4219
5263
  dtype: require_backend.DType.Uint32,
4220
5264
  device: key$1.device
4221
5265
  }));
@@ -4223,6 +5267,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
4223
5267
  if (minval === 0 && maxval === 1) return rand;
4224
5268
  else return rand.mul(maxval - minval).add(minval);
4225
5269
  }
5270
+ /**
5271
+ * Sample Bernoulli random variables with given mean (0,1 categorical).
5272
+ *
5273
+ * Returns a random Boolean array with the specified shape. `p` can be an array
5274
+ * and must be broadcastable to `shape`.
5275
+ */
5276
+ function bernoulli(key$1, p = .5, shape$1 = []) {
5277
+ p = fudgeArray(p);
5278
+ return uniform(key$1, shape$1).less(p);
5279
+ }
5280
+ /** Sample exponential random values according to `p(x) = exp(-x)`. */
5281
+ function exponential(key$1, shape$1 = []) {
5282
+ const u = uniform(key$1, shape$1);
5283
+ return negative(log1p(negative(u)));
5284
+ }
5285
+ /**
5286
+ * Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
5287
+ *
5288
+ * Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
5289
+ * directly inverts the CDF, but we don't have support for that yet. Outputs will not be
5290
+ * bitwise identical to JAX.
5291
+ */
5292
+ function normal(key$1, shape$1 = []) {
5293
+ const [k1, k2] = split(key$1, 2);
5294
+ const u1 = uniform(k1, shape$1);
5295
+ const u2 = uniform(k2, shape$1);
5296
+ const radius = sqrt(log1p(negative(u1)).mul(-2));
5297
+ const theta = u2.mul(2 * Math.PI);
5298
+ return radius.mul(cos(theta));
5299
+ }
4226
5300
 
4227
5301
  //#endregion
4228
5302
  //#region src/polyfills.ts
@@ -4232,35 +5306,98 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
4232
5306
 
4233
5307
  //#endregion
4234
5308
  //#region src/index.ts
4235
- /** Compute the forward-mode Jacobian-vector product for a function. */
5309
+ /**
5310
+ * @function
5311
+ * Compute the forward-mode Jacobian-vector product for a function.
5312
+ */
4236
5313
  const jvp = jvp$1;
4237
- /** Vectorize an operation on a batched axis for one or more inputs. */
5314
+ /**
5315
+ * @function
5316
+ * Vectorize an operation on a batched axis for one or more inputs.
5317
+ */
4238
5318
  const vmap = vmap$1;
4239
- /** Compute the Jacobian evaluated column-by-column by forward-mode AD. */
5319
+ /**
5320
+ * @function
5321
+ * Compute the Jacobian evaluated column-by-column by forward-mode AD.
5322
+ */
4240
5323
  const jacfwd = jacfwd$1;
4241
- /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
5324
+ /**
5325
+ * @function
5326
+ * Construct a Jaxpr by dynamically tracing a function with example inputs.
5327
+ */
4242
5328
  const makeJaxpr = makeJaxpr$1;
5329
+ /**
5330
+ * @function
5331
+ * Mark a function for automatic JIT compilation, with operator fusion.
5332
+ *
5333
+ * The function will be compiled the first time it is called with a set of
5334
+ * argument shapes.
5335
+ *
5336
+ * You can call `.dispose()` on the returned, JIT-compiled function after all
5337
+ * calls to free memory associated with array constants.
5338
+ *
5339
+ * **Options:**
5340
+ * - `staticArgnums`: An array of argument indices to treat as static
5341
+ * (compile-time constant). These arguments must be hashable, won't be traced,
5342
+ * and different values will trigger recompilation.
5343
+ * - `device`: The device to place the computation on. If not specified, the
5344
+ * computation will be placed on the default device.
5345
+ */
4243
5346
  const jit = jit$1;
4244
5347
  /**
5348
+ * @function
4245
5349
  * Produce a local linear approximation to a function at a point using jvp() and
4246
5350
  * partial evaluation.
4247
5351
  */
4248
5352
  const linearize = linearize$1;
4249
- /** Calculate the reverse-mode vector-Jacobian product for a function. */
5353
+ /**
5354
+ * @function
5355
+ * Calculate the reverse-mode vector-Jacobian product for a function.
5356
+ */
4250
5357
  const vjp = vjp$1;
4251
5358
  /**
5359
+ * @function
4252
5360
  * Compute the gradient of a scalar-valued function `f` with respect to its
4253
5361
  * first argument.
4254
5362
  */
4255
5363
  const grad = grad$1;
4256
- /** Create a function that evaluates both `f` and the gradient of `f`. */
5364
+ /**
5365
+ * @function
5366
+ * Create a function that evaluates both `f` and the gradient of `f`.
5367
+ */
4257
5368
  const valueAndGrad = valueAndGrad$1;
4258
- /** Compute the Jacobian evaluated row-by-row by reverse-mode AD. */
5369
+ /**
5370
+ * @function
5371
+ * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
5372
+ */
4259
5373
  const jacrev = jacrev$1;
4260
- /** Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`. */
5374
+ /**
5375
+ * @function
5376
+ * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
5377
+ */
4261
5378
  const jacobian = jacrev;
5379
+ /**
5380
+ * Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
5381
+ *
5382
+ * This can be used to wait for the results of an intermediate computation to
5383
+ * finish. It's recommended to call this regularly in an iterative computation
5384
+ * to avoid queueing up too many pending operations.
5385
+ *
5386
+ * Does not consume reference to the arrays.
5387
+ */
5388
+ async function blockUntilReady(x) {
5389
+ const promises = [];
5390
+ for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
5391
+ await Promise.all(promises);
5392
+ return x;
5393
+ }
4262
5394
 
4263
5395
  //#endregion
5396
+ exports.Array = Array$1;
5397
+ exports.DType = require_backend.DType;
5398
+ exports.Jaxpr = Jaxpr;
5399
+ exports.blockUntilReady = blockUntilReady;
5400
+ exports.defaultDevice = require_backend.defaultDevice;
4264
5401
  exports.devices = require_backend.devices;
4265
5402
  exports.grad = grad;
4266
5403
  exports.init = require_backend.init;
@@ -4269,6 +5406,12 @@ exports.jacobian = jacobian;
4269
5406
  exports.jacrev = jacrev;
4270
5407
  exports.jit = jit;
4271
5408
  exports.jvp = jvp;
5409
+ Object.defineProperty(exports, 'lax', {
5410
+ enumerable: true,
5411
+ get: function () {
5412
+ return lax_exports;
5413
+ }
5414
+ });
4272
5415
  exports.linearize = linearize;
4273
5416
  exports.makeJaxpr = makeJaxpr;
4274
5417
  Object.defineProperty(exports, 'nn', {
@@ -4289,7 +5432,7 @@ Object.defineProperty(exports, 'random', {
4289
5432
  return random_exports;
4290
5433
  }
4291
5434
  });
4292
- exports.setDevice = require_backend.setDevice;
5435
+ exports.setDebug = require_backend.setDebug;
4293
5436
  Object.defineProperty(exports, 'tree', {
4294
5437
  enumerable: true,
4295
5438
  get: function () {