@jax-js/jax 0.0.2 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.cjs CHANGED
@@ -30,13 +30,14 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
30
30
  }) : target, mod));
31
31
 
32
32
  //#endregion
33
- const require_backend = require('./backend-BK21PBVP.cjs');
33
+ const require_backend = require('./backend-D2C4MJRP.cjs');
34
34
 
35
35
  //#region src/tree.ts
36
36
  var tree_exports = {};
37
37
  __export(tree_exports, {
38
38
  JsTreeDef: () => JsTreeDef,
39
39
  NodeType: () => NodeType,
40
+ dispose: () => dispose,
40
41
  flatten: () => flatten,
41
42
  leaves: () => leaves,
42
43
  map: () => map,
@@ -51,7 +52,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
51
52
  NodeType$1["Leaf"] = "Leaf";
52
53
  return NodeType$1;
53
54
  }({});
54
- /** Analog to the JAX "pytree" object, but for JavaScript. */
55
+ /** Represents the structure of a JsTree. */
55
56
  var JsTreeDef = class JsTreeDef {
56
57
  static leaf = new JsTreeDef(NodeType.Leaf, null, []);
57
58
  constructor(nodeType, nodeMetadata, childTreedefs) {
@@ -139,6 +140,194 @@ function map(fn, tree, ...rest) {
139
140
  function ref(tree) {
140
141
  return map((x) => x.ref, tree);
141
142
  }
143
+ /**
 * Dispose every array in a tree.
 *
 * Runs `map()` over `tree` with a callback that calls `dispose()` on each
 * value it is handed; the tree produced by `map()` is discarded. A falsy
 * `tree` (for example `null` or `undefined`) is ignored. Returns nothing.
 */
144
+ function dispose(tree) {
145
+ if (tree) map((x) => x.dispose(), tree);
146
+ }
147
+
148
+ //#endregion
149
+ //#region src/frontend/convolution.ts
/**
 * Check that the shapes and parameters passed to convolution are valid.
 *
 * Both shapes must have the same rank (at least 2). Dimension 1 of each shape
 * is the input-channel count and must match; every dimension from index 2 on
 * is a spatial dimension, and each parameter array must have one entry per
 * spatial dimension.
 *
 * If the check succeeds, returns the output shape:
 * `[lhsShape[0], rhsShape[0], ...spatialOutputSizes]`.
 *
 * @param lhsShape Shape of the left-hand (input) operand.
 * @param rhsShape Shape of the right-hand (kernel) operand.
 * @param params `{ strides, padding, lhsDilation, rhsDilation }`, where
 *   `padding[i]` is a `[low, high]` pair of integers (negative values crop).
 * @returns The output shape as a new array.
 * @throws {Error} If ranks, channel counts or any parameter is invalid, or if
 *   the dilated kernel does not fit in the dilated-and-padded input.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  const n = lhsShape.length - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
  const outShape = [lhsShape[0], rhsShape[0]];
  for (let i = 0; i < n; i++) {
    if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
    if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
    if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
    if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
    const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
    if (k <= 0) throw new Error("conv() kernel size must be positive");
    const [pl, pr] = padding[i];
    // Padding (including negative padding, i.e. cropping) is applied to the
    // input *after* lhs dilation, so the crop limit is the dilated size rather
    // than the raw size `x`. The clamp keeps an empty input (x = 0) at size 0.
    const dilatedSize = Math.max((x - 1) * lhsDilation[i] + 1, 0);
    if (pl < -dilatedSize || pr < -dilatedSize || pl + pr < -dilatedSize) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${dilatedSize}`);
    // Effective extent of the kernel once rhs dilation spreads its taps apart.
    const kernelSize = (k - 1) * rhsDilation[i] + 1;
    const inSize = dilatedSize + pl + pr;
    if (kernelSize > inSize) throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${i}`);
    // Number of stride-spaced positions at which the kernel fits.
    outShape.push(Math.ceil((inSize - kernelSize + 1) / strides[i]));
  }
  return outShape;
}
181
/**
 * Check that the window and strides passed to a pooling operation are valid.
 *
 * The window applies to the last `window.length` dimensions of `inShape`.
 * If the check succeeds, returns the output shape: the leading dimensions of
 * `inShape` that the window does not cover, then one output size per windowed
 * dimension (`ceil((size - k + 1) / stride)`), then the window dimensions
 * themselves.
 *
 * Throws an Error when `strides` and `window` differ in length, when the
 * window has more dimensions than the input, when a window or stride entry is
 * not a positive integer, or when a window entry exceeds its input dimension.
 */
+ function checkPoolShape(inShape, window, strides) {
182
+ if (strides.length !== window.length) throw new Error("pool() strides != window dims");
183
+ if (window.length > inShape.length) throw new Error("pool() window has more dimensions than input");
184
+ const outShape = inShape.slice(0, inShape.length - window.length);
185
+ for (let i = 0; i < window.length; i++) {
186
+ const k = window[i];
187
+ const s = strides[i];
188
+ const size$1 = inShape[inShape.length - window.length + i];
189
+ if (k <= 0 || !Number.isInteger(k)) throw new Error(`pool() window[${i}] must be a positive integer`);
190
+ if (k > size$1) throw new Error(`pool() window[${i}]=${k} > input size ${size$1}`);
191
+ if (s <= 0 || !Number.isInteger(s)) throw new Error(`pool() strides[${i}] must be a positive integer`);
192
+ outShape.push(Math.ceil((size$1 - k + 1) / s));
193
+ }
194
+ return outShape.concat(window);
195
+ }
/**
 * Takes a shape tracker and a kernel size `ks`, then reshapes it so the last
 * `ks.length` dimensions become `2 * ks.length` dimensions by treating them as
 * spatial dimensions convolved with a kernel.
 *
 * The resulting array can be multiplied with a kernel of shape `ks`, then
 * reduced along the last `ks.length` dimensions for a convolution.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 *
 * @param st Shape tracker to transform; returned unchanged when `ks` is empty.
 * @param ks Kernel size for each of the trailing dimensions.
 * @param strides A single stride applied to every kernel dimension, or one
 *   stride per kernel dimension.
 * @param dilation A single dilation applied to every kernel dimension, or one
 *   dilation per kernel dimension.
 * @returns The transformed tracker, with the output dimensions followed by the
 *   kernel dimensions after the untouched leading dimensions.
 * @throws {Error} If `ks` has more dimensions than `st`, if `strides` or
 *   `dilation` has the wrong length or non-positive/non-integer entries, or if
 *   a dilated kernel is larger than its input dimension.
 */
function pool(st, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  if (st.shape.length < ks.length) throw new Error("pool() called with too many dimensions");
  if (typeof strides === "number") strides = require_backend.rep(ks.length, strides);
  if (typeof dilation === "number") dilation = require_backend.rep(ks.length, dilation);
  // Every kernel dimension needs exactly one stride and one dilation; a
  // mismatch would otherwise be silently mis-zipped below.
  if (strides.length !== ks.length) throw new Error("pool() strides != kernel dims");
  if (dilation.length !== ks.length) throw new Error("pool() dilation != kernel dims");
  if (strides.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("pool() strides must be positive integers");
  if (dilation.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("pool() dilation must be positive integers");
  const noop = st.shape.slice(0, -ks.length);
  const i_ = st.shape.slice(-ks.length);
  const s_ = strides;
  const d_ = dilation;
  // The dilated kernel must fit in the input; otherwise the output size below
  // is non-positive (or the repeat count divides by a zero-sized input).
  ks.forEach((k, j) => {
    const extent = d_[j] * (k - 1) + 1;
    if (extent > i_[j]) throw new Error(`pool() kernel size ${extent} > input size ${i_[j]} in dimension ${j}`);
  });
  // o_: output size per dimension.
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  // f_: 2 when the strided output span overruns the valid range, else 1.
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = require_backend.zipn(ks, i_, d_, f_);
  // Repeat each pooled dimension enough times that k rows of length (i*f + d)
  // can be carved out of it; the extra `d` per row is what produces dilation.
  st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
  st = st.shrink([...noop.map((x) => [0, x]), ...kidf.map(([k, i, d, f]) => [0, k * (i * f + d)])]).reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [k, i * f + d])]);
  const kos = require_backend.zipn(ks, o_, s_);
  // Keep o*s entries per row, split them into (o, s), and take the first of
  // every s entries to apply the stride.
  st = st.shrink([...noop.map((x) => [0, x]), ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])]).reshape([...noop, ...kos.flat(1)]);
  st = st.shrink([...noop.map((x) => [0, x]), ...kos.flatMap(([k, o]) => [
    [0, k],
    [0, o],
    [0, 1]
  ])]).reshape([...noop, ...kos.flatMap(([k, o]) => [k, o])]);
  // Reorder (k0, o0, k1, o1, ...) into (o0, o1, ..., k0, k1, ...).
  st = st.permute([
    ...require_backend.range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j)
  ]);
  return st;
}
/**
 * Perform the transpose of pool, directly undo-ing a pool() operation.
 *
 * Note that since pool repeats the input, the transpose operation technically
 * should include a sum reduction. This function doesn't perform the reduction,
 * which should be done on the last `k` axes of the returned shape.
 *
 * @param st Shape tracker of the pooled array (output dimensions followed by
 *   kernel dimensions after the untouched leading dimensions).
 * @param inShape Shape of the array that was originally passed to `pool()`.
 * @param ks Kernel size for each of the trailing dimensions of `inShape`.
 * @param strides A single stride, or one stride per kernel dimension.
 * @param dilation A single dilation, or one dilation per kernel dimension.
 * @returns The transformed tracker; returned unchanged when `ks` is empty.
 * @throws {Error} If `ks` has more dimensions than `inShape`, if `strides` or
 *   `dilation` is invalid, or if the output dimensions of `st` do not match
 *   what `pool()` would have produced for `inShape`.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  // Mirror the argument validation done by pool(), so bad parameters are
  // reported as such instead of as a shape mismatch or NaN dimensions.
  if (inShape.length < ks.length) throw new Error("poolTranspose() called with too many dimensions");
  if (typeof strides === "number") strides = require_backend.rep(ks.length, strides);
  if (typeof dilation === "number") dilation = require_backend.rep(ks.length, dilation);
  if (strides.length !== ks.length) throw new Error("poolTranspose() strides != kernel dims");
  if (dilation.length !== ks.length) throw new Error("poolTranspose() dilation != kernel dims");
  if (strides.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("poolTranspose() strides must be positive integers");
  if (dilation.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("poolTranspose() dilation must be positive integers");
  const noop = inShape.slice(0, -ks.length);
  const i_ = inShape.slice(-ks.length);
  const s_ = strides;
  const d_ = dilation;
  // Same output-size and overrun-factor formulas as pool().
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = require_backend.zipn(ks, i_, d_, f_);
  const kos = require_backend.zipn(ks, o_, s_);
  // Undo pool()'s final permute: interleave back to (k0, o0, k1, o1, ...).
  st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + ks.length + j, noop.length + j])]);
  // Undo the stride shrink: give every output a trailing axis of size 1 and
  // pad it to size s.
  st = st.reshape([...noop, ...kos.flatMap(([k, o]) => [
    k,
    o,
    1
  ])]).pad([...noop.map(() => [0, 0]), ...s_.flatMap((s) => [
    [0, 0],
    [0, 0],
    [0, s - 1]
  ])]);
  // Merge (o, s) into o*s and pad each row back out to length i*f + d.
  st = st.reshape([...noop, ...kos.flatMap(([k, o, s]) => [k, o * s])]).pad([...noop.map(() => [0, 0]), ...kidf.flatMap(([_k, i, d, f], j) => [[0, 0], [0, i * f + d - o_[j] * s_[j]]])]);
  // Flatten the k rows and pad up to a whole number of copies of the input.
  st = st.reshape([...noop, ...kidf.map(([k, i, d, f]) => k * (i * f + d))]).pad([...noop.map(() => [0, 0]), ...kidf.map(([k, i, d, f]) => [0, Math.ceil(k * (i * f + d) / i) * i - k * (i * f + d)])]);
  // Split into (copies, i) per dimension and move the `i` axes ahead of the
  // `copies` axes; the caller sums over the trailing `copies` axes.
  st = st.reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [Math.ceil(k * (i * f + d) / i), i])]).permute([
    ...require_backend.range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j)
  ]);
  return st;
}
275
+ /**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * The first two dimensions of `st` are left untouched. Every remaining
 * dimension of size `k` grows to `(k - 1) * dilation[i] + 1`: each element gets
 * a trailing axis of size 1 that is padded to `dilation[i]`, the axes are
 * merged back together, and the padding after the last element is trimmed.
 * Returns `st` unchanged when every dilation factor is 1.
 */
276
+ function applyDilation(st, dilation) {
277
+ if (dilation.every((s) => s === 1)) return st;
278
+ const s_ = dilation;
279
+ const [a, b, ...k_] = st.shape;
280
+ st = st.reshape([
281
+ a,
282
+ b,
283
+ ...k_.flatMap((k) => [k, 1])
284
+ ]);
285
+ st = st.pad([
286
+ [0, 0],
287
+ [0, 0],
288
+ ...s_.flatMap((s) => [[0, 0], [0, s - 1]])
289
+ ]);
290
+ st = st.reshape([
291
+ a,
292
+ b,
293
+ ...k_.map((k, i) => k * s_[i])
294
+ ]);
295
+ st = st.shrink([
296
+ [0, a],
297
+ [0, b],
298
+ ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
299
+ ]);
300
+ return st;
301
+ }
302
+ /**
303
+ * Prepare for a convolution between two arrays.
304
+ *
305
+ * This does not check the validity of the shapes, which should be checked
306
+ * beforehand using `checkConvShape()`.
+ *
+ * `stX` is transformed, in order, by `applyDilation()` (lhs dilation),
+ * `padOrShrink()` (padding on the spatial dimensions only) and `pool()` (with
+ * the kernel's spatial shape, `params.strides` and `params.rhsDilation`). Its
+ * dimension 1 is then moved next to the kernel dimensions and merged with
+ * them, and a size-1 dimension is inserted at index 1. `stY` is reshaped so
+ * its dimension 1 and kernel dimensions are merged the same way, with `n`
+ * size-1 dimensions in between. Both results end in a dimension of size
+ * `shape[1] * prod(ks)`.
+ *
+ * Returns the pair `[stX, stY]`.
307
+ */
308
+ function prepareConv(stX, stY, params) {
309
+ const n = stX.shape.length - 2;
310
+ stX = applyDilation(stX, params.lhsDilation);
311
+ const ks = stY.shape.slice(2);
312
+ stX = stX.padOrShrink([
313
+ [0, 0],
314
+ [0, 0],
315
+ ...params.padding
316
+ ]);
317
+ stX = pool(stX, ks, params.strides, params.rhsDilation);
318
+ stX = stX.moveaxis(1, n + 1).reshape([
319
+ stX.shape[0],
320
+ 1,
321
+ ...stX.shape.slice(2, n + 2),
322
+ stX.shape[1] * require_backend.prod(ks)
323
+ ]);
324
+ stY = stY.reshape([
325
+ stY.shape[0],
326
+ ...require_backend.rep(n, 1),
327
+ stY.shape[1] * require_backend.prod(ks)
328
+ ]);
329
+ return [stX, stY];
330
+ }
142
331
 
143
332
  //#endregion
144
333
  //#region src/frontend/core.ts
@@ -167,10 +356,14 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
167
356
  Primitive$1["Cos"] = "cos";
168
357
  Primitive$1["Exp"] = "exp";
169
358
  Primitive$1["Log"] = "log";
359
+ Primitive$1["Sqrt"] = "sqrt";
170
360
  Primitive$1["Min"] = "min";
171
361
  Primitive$1["Max"] = "max";
172
362
  Primitive$1["Reduce"] = "reduce";
173
363
  Primitive$1["Dot"] = "dot";
364
+ Primitive$1["Conv"] = "conv";
365
+ Primitive$1["Pool"] = "pool";
366
+ Primitive$1["PoolTranspose"] = "pool_transpose";
174
367
  Primitive$1["Compare"] = "compare";
175
368
  Primitive$1["Where"] = "where";
176
369
  Primitive$1["Transpose"] = "transpose";
@@ -234,6 +427,9 @@ function exp$1(x) {
234
427
  function log$1(x) {
235
428
  return bind1(Primitive.Log, [x]);
236
429
  }
430
/** Binds the `Sqrt` primitive to the single operand `x` via `bind1()`. */
+ function sqrt$1(x) {
431
+ return bind1(Primitive.Sqrt, [x]);
432
+ }
237
433
  function min$1(x, y) {
238
434
  return bind1(Primitive.Min, [x, y]);
239
435
  }
@@ -256,6 +452,17 @@ function reduce(x, op, axis, opts) {
256
452
  function dot$1(x, y) {
257
453
  return bind1(Primitive.Dot, [x, y]);
258
454
  }
455
/**
 * Binds the `Conv` primitive to `x` and `y`.
 *
 * Both operands must have the same `ndim`, and at least 2 dimensions; the
 * remaining `n = ndim - 2` dimensions are treated as spatial. Any of
 * `strides`, `padding`, `lhsDilation` and `rhsDilation` missing (null or
 * undefined) from `params` is filled in with `rep(n, ...)` of `1`, `[0, 0]`,
 * `1` and `1` respectively.
 *
 * Throws an Error if the operand ranks differ or are below 2. Further
 * validation of the parameters is not done here.
 */
+ function conv(x, y, params = {}) {
456
+ if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
457
+ const n = x.ndim - 2;
458
+ if (n < 0) throw new Error("conv() requires at least 2D inputs");
459
+ return bind1(Primitive.Conv, [x, y], {
460
+ strides: params.strides ?? require_backend.rep(n, 1),
461
+ padding: params.padding ?? require_backend.rep(n, [0, 0]),
462
+ lhsDilation: params.lhsDilation ?? require_backend.rep(n, 1),
463
+ rhsDilation: params.rhsDilation ?? require_backend.rep(n, 1)
464
+ });
465
+ }
259
466
  function compare(x, y, op) {
260
467
  return bind1(Primitive.Compare, [x, y], { op });
261
468
  }
@@ -391,6 +598,9 @@ var Tracer = class Tracer {
391
598
  get shape() {
392
599
  return this.aval.shape;
393
600
  }
601
/** The result of `prod(this.shape)`; `mean()` uses it as the element count. */
+ get size() {
602
+ return require_backend.prod(this.shape);
603
+ }
394
604
  get dtype() {
395
605
  return this.aval.dtype;
396
606
  }
@@ -442,7 +652,7 @@ var Tracer = class Tracer {
442
652
  else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, this.ndim)];
443
653
  else axis = axis.map((a) => require_backend.checkAxis(a, this.ndim));
444
654
  let result = reduce(this, require_backend.AluOp.Add, axis);
445
- result = result.mul(require_backend.prod(result.shape) / require_backend.prod(this.shape));
655
+ result = result.mul(result.size / this.size);
446
656
  if (opts?.keepDims) result = broadcast(result, this.shape, axis);
447
657
  return result;
448
658
  }
@@ -476,8 +686,29 @@ var Tracer = class Tracer {
476
686
  /** Return specified diagonals. See `numpy.diagonal` for full docs. */
477
687
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
478
688
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
689
+ if (offset < 0) return this.diagonal(-offset, axis2, axis1);
690
+ axis1 = require_backend.checkAxis(axis1, this.ndim);
691
+ axis2 = require_backend.checkAxis(axis2, this.ndim);
479
692
  if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
480
- throw new Error("diagonal not implemented");
693
+ if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
694
+ let ar = this;
695
+ if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
696
+ const perm = require_backend.range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
697
+ ar = ar.transpose(perm);
698
+ }
699
+ const [n, m] = ar.shape.slice(-2);
700
+ const diagSize = Math.min(n, m - offset);
701
+ ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
702
+ const npad = diagSize * (m + 1) - n * m;
703
+ if (npad > 0) ar = pad$1(ar, [...require_backend.rep(ar.ndim - 1, [0, 0]), [0, npad]]);
704
+ else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
705
+ ar = ar.reshape([
706
+ ...ar.shape.slice(0, -1),
707
+ diagSize,
708
+ m + 1
709
+ ]);
710
+ ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
711
+ return ar;
481
712
  }
482
713
  /** Flatten the array without changing its data. */
483
714
  flatten() {
@@ -620,7 +851,7 @@ var ShapedArray = class ShapedArray {
620
851
  get ndim() {
621
852
  return this.shape.length;
622
853
  }
623
- strShort() {
854
+ toString() {
624
855
  return `${this.dtype}[${this.shape.join(",")}]`;
625
856
  }
626
857
  equals(other) {
@@ -651,7 +882,7 @@ function fullRaise(trace, val) {
651
882
  if (Object.is(val._trace.main, trace.main)) return val;
652
883
  else if (val._trace.main.level < level) return trace.lift(val);
653
884
  else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
654
- else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
885
+ else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
655
886
  }
656
887
  var TreeMismatchError = class extends TypeError {
657
888
  constructor(where$2, left, right) {
@@ -900,16 +1131,16 @@ function jitCompile(backend, jaxpr, consts) {
900
1131
  jitCompileCache.set(cacheKey, jp);
901
1132
  return jp;
902
1133
  }
903
- function reshapeViews(exp$2, mapping) {
1134
+ function reshapeViews(exp$2, mapping, reduceAxis = false) {
904
1135
  return exp$2.rewrite((exp$3) => {
905
1136
  if (exp$3.op === require_backend.AluOp.GlobalView) {
906
1137
  const [gid, st] = exp$3.arg;
907
1138
  const newSt = mapping(st);
908
1139
  if (newSt) {
909
- const indices = require_backend.unravelAlu(newSt.shape, require_backend.AluVar.gidx);
1140
+ const indices = reduceAxis ? require_backend.unravelAlu(newSt.shape.slice(0, -1), require_backend.AluVar.gidx).concat(require_backend.AluVar.ridx) : require_backend.unravelAlu(newSt.shape, require_backend.AluVar.gidx);
910
1141
  return require_backend.AluExp.globalView(exp$3.dtype, gid, newSt, indices);
911
1142
  }
912
- }
1143
+ } else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
913
1144
  });
914
1145
  }
915
1146
  function broadcastedJit(fn) {
@@ -958,6 +1189,7 @@ const jitRules = {
958
1189
  [Primitive.Cos]: unopJit(require_backend.AluExp.cos),
959
1190
  [Primitive.Exp]: unopJit(require_backend.AluExp.exp),
960
1191
  [Primitive.Log]: unopJit(require_backend.AluExp.log),
1192
+ [Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
961
1193
  [Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
962
1194
  [Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
963
1195
  [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
@@ -972,18 +1204,20 @@ const jitRules = {
972
1204
  const size$1 = require_backend.prod(newShape);
973
1205
  const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
974
1206
  newShape.push(reductionSize);
975
- a = a.rewrite((exp$2) => {
976
- if (exp$2.op === require_backend.AluOp.GlobalView) {
977
- const [gid, st] = exp$2.arg;
978
- const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
979
- const indices = require_backend.unravelAlu(newShape.slice(0, -1), require_backend.AluVar.gidx);
980
- indices.push(require_backend.AluVar.ridx);
981
- return require_backend.AluExp.globalView(exp$2.dtype, gid, newSt, indices);
982
- }
983
- });
1207
+ const perm = keptAxes.concat(shiftedAxes);
1208
+ a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
984
1209
  const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
985
1210
  return new require_backend.Kernel(nargs, size$1, a, reduction);
986
1211
  },
1212
+ [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1213
+ [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
1214
+ let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
1215
+ const size$1 = require_backend.prod(inShape);
1216
+ stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
1217
+ a = reshapeViews(a, (st) => st.compose(stX), true);
1218
+ const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
1219
+ return new require_backend.Kernel(nargs, size$1, a, reduction);
1220
+ },
987
1221
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
988
1222
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
989
1223
  const c = k1.exp;
@@ -993,6 +1227,14 @@ const jitRules = {
993
1227
  axis: [cs.ndim - 1]
994
1228
  });
995
1229
  },
1230
+ [Primitive.Conv](nargs, [a, b], [as, bs], params) {
1231
+ const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
1232
+ a = reshapeViews(a, (st) => st.compose(stX));
1233
+ b = reshapeViews(b, (st) => st.compose(stY));
1234
+ as = new ShapedArray(stX.shape, as.dtype);
1235
+ bs = new ShapedArray(stY.shape, bs.dtype);
1236
+ return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1237
+ },
996
1238
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
997
1239
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
998
1240
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
@@ -1005,8 +1247,20 @@ const jitRules = {
1005
1247
  }),
1006
1248
  [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
1007
1249
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
1008
- [Primitive.Gather]() {
1009
- throw new Error("Gather is not implemented in JIT yet");
1250
+ [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1251
+ const axisSet = new Set(axis);
1252
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1253
+ const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1254
+ finalShape.splice(outDim, 0, ...indexShape);
1255
+ const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
1256
+ const idxNonaxis = [...idxAll];
1257
+ idxNonaxis.splice(outDim, indexShape.length);
1258
+ const src = [...idxNonaxis];
1259
+ for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1260
+ for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
1261
+ const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
1262
+ if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1263
+ return new require_backend.Kernel(nargs, require_backend.prod(finalShape), x.substitute({ gidx: index }));
1010
1264
  },
1011
1265
  [Primitive.JitCall]() {
1012
1266
  throw new Error("internal: JitCall should have been flattened before JIT compilation");
@@ -1025,9 +1279,15 @@ function splitGraphDataflow(backend, jaxpr) {
1025
1279
  blackNodes.add(v);
1026
1280
  p1NextBlack.set(v, v);
1027
1281
  }
1282
+ const reducePrimitives = [
1283
+ Primitive.Reduce,
1284
+ Primitive.Dot,
1285
+ Primitive.Conv,
1286
+ Primitive.PoolTranspose
1287
+ ];
1028
1288
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
1029
1289
  const eqn = jaxpr.eqns[i];
1030
- if (eqn.primitive === Primitive.Reduce || eqn.outBinders.some((v) => blackNodes.has(v))) {
1290
+ if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1031
1291
  for (const v of eqn.outBinders) {
1032
1292
  blackNodes.add(v);
1033
1293
  p1NextBlack.set(v, v);
@@ -1254,7 +1514,7 @@ var Array$1 = class Array$1 extends Tracer {
1254
1514
  const inputs = [];
1255
1515
  const src = [...idxNonaxis];
1256
1516
  for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1257
- for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#dtype, ar.#source, ar.#st, idxAxis));
1517
+ for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#source, ar.#st, idxAxis));
1258
1518
  else {
1259
1519
  let gid = inputs.indexOf(ar.#source);
1260
1520
  if (gid === -1) {
@@ -1264,7 +1524,7 @@ var Array$1 = class Array$1 extends Tracer {
1264
1524
  src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
1265
1525
  }
1266
1526
  let exp$2;
1267
- if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#dtype, this.#source, this.#st, src);
1527
+ if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, src);
1268
1528
  else {
1269
1529
  let gid = inputs.indexOf(this.#source);
1270
1530
  if (gid === -1) {
@@ -1307,7 +1567,7 @@ var Array$1 = class Array$1 extends Tracer {
1307
1567
  this.#check();
1308
1568
  if (this.#source instanceof require_backend.AluExp) {
1309
1569
  const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
1310
- return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
1570
+ return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1311
1571
  }
1312
1572
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1313
1573
  const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1340,18 +1600,18 @@ var Array$1 = class Array$1 extends Tracer {
1340
1600
  if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1341
1601
  arrays = Array$1.#broadcastArrays(arrays);
1342
1602
  const newShape = [...arrays[0].shape];
1343
- if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && reduceAxis === void 0) {
1603
+ if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
1344
1604
  if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
1345
1605
  const exp$4 = custom(arrays.map((ar) => ar.#source));
1346
- return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
1606
+ return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1347
1607
  }
1348
1608
  const exp$3 = custom(arrays.map((ar) => {
1349
1609
  const src$1 = ar.#source;
1350
1610
  if (ar.#st.contiguous) return src$1;
1351
- return require_backend.accessorAluExp(ar.#dtype, src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1611
+ return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
1352
1612
  }));
1353
1613
  const st = require_backend.ShapeTracker.fromShape(newShape);
1354
- return new Array$1(exp$3, st, exp$3.dtype, backend);
1614
+ return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1355
1615
  }
1356
1616
  let indices;
1357
1617
  if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
@@ -1361,7 +1621,7 @@ var Array$1 = class Array$1 extends Tracer {
1361
1621
  }
1362
1622
  const inputs = [];
1363
1623
  const src = [];
1364
- for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#dtype, ar.#source, ar.#st, indices));
1624
+ for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
1365
1625
  else {
1366
1626
  let gid = inputs.indexOf(ar.#source);
1367
1627
  if (gid === -1) {
@@ -1395,7 +1655,7 @@ var Array$1 = class Array$1 extends Tracer {
1395
1655
  const indices = [...require_backend.unravelAlu(newShape, require_backend.AluVar.gidx), require_backend.AluVar.ridx];
1396
1656
  let exp$2;
1397
1657
  const inputs = [];
1398
- if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1658
+ if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
1399
1659
  else {
1400
1660
  inputs.push(this.#source);
1401
1661
  exp$2 = require_backend.accessorGlobal(this.#dtype, 0, this.#st, indices);
@@ -1420,7 +1680,7 @@ var Array$1 = class Array$1 extends Tracer {
1420
1680
  this.#check();
1421
1681
  const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
1422
1682
  if (this.#source instanceof require_backend.AluExp) {
1423
- const exp$2 = require_backend.accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1683
+ const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
1424
1684
  const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
1425
1685
  const output = this.#backend.malloc(kernel.bytes);
1426
1686
  const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
@@ -1458,42 +1718,51 @@ var Array$1 = class Array$1 extends Tracer {
1458
1718
  }
1459
1719
  /** Realize the array and return it as data. */
1460
1720
  async data() {
1461
- if (this.#source instanceof require_backend.AluExp && require_backend.prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1721
+ if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1462
1722
  this.#realize();
1463
1723
  const pending = this.#pending;
1464
1724
  if (pending) {
1465
1725
  await Promise.all(pending.map((p) => p.prepare()));
1466
1726
  for (const p of pending) p.submit();
1467
1727
  }
1468
- const byteCount = require_backend.byteWidth(this.#dtype) * require_backend.prod(this.shape);
1728
+ const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
1469
1729
  const buf = await this.#backend.read(this.#source, 0, byteCount);
1470
1730
  this.dispose();
1471
1731
  return require_backend.dtypedArray(this.dtype, buf);
1472
1732
  }
1473
- /** Wait for this array to be placed on the backend, if needed. */
1733
+ /**
1734
+ * Wait for this array to finish evaluation.
1735
+ *
1736
+ * Operations and data loading in jax-js are lazy, so this function ensures
1737
+ * that pending operations are dispatched and fully executed before it
1738
+ * returns.
1739
+ *
1740
+ * Reading the array with `data()` or `dataSync()` also triggers dispatch of
1741
+ * pending operations.
1742
+ */
1474
1743
  async wait() {
1475
1744
  this.#check();
1476
- if (this.#source instanceof require_backend.AluExp) return;
1745
+ if (this.#source instanceof require_backend.AluExp) return this;
1477
1746
  const pending = this.#pending;
1478
1747
  if (pending) {
1479
1748
  await Promise.all(pending.map((p) => p.prepare()));
1480
1749
  for (const p of pending) p.submit();
1481
1750
  }
1482
1751
  await this.#backend.read(this.#source, 0, 0);
1483
- this.dispose();
1752
+ return this;
1484
1753
  }
1485
1754
  /**
1486
1755
  * Realize the array and return it as data. This is a sync variant and not
1487
1756
  * recommended for performance reasons, as it will block rendering.
1488
1757
  */
1489
1758
  dataSync() {
1490
- if (this.#source instanceof require_backend.AluExp && require_backend.prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1759
+ if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1491
1760
  this.#realize();
1492
1761
  for (const p of this.#pending) {
1493
1762
  p.prepareSync();
1494
1763
  p.submit();
1495
1764
  }
1496
- const byteCount = require_backend.byteWidth(this.#dtype) * require_backend.prod(this.shape);
1765
+ const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
1497
1766
  const buf = this.#backend.readSync(this.#source, 0, byteCount);
1498
1767
  this.dispose();
1499
1768
  return require_backend.dtypedArray(this.dtype, buf);
@@ -1514,6 +1783,16 @@ var Array$1 = class Array$1 extends Tracer {
1514
1783
  async jsAsync() {
1515
1784
  return dataToJs(this.dtype, await this.data(), this.shape);
1516
1785
  }
1786
+ /**
1787
+ * Copy an element of an array to a numeric scalar and return it.
1788
+ *
1789
+ * Throws an error if the array does not have a single element. The array must
1790
+ * either be rank-0, or all dimensions of the shape are 1.
1791
+ */
1792
+ item() {
1793
+ if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
1794
+ return this.dataSync()[0];
1795
+ }
1517
1796
  /** @private Internal plumbing method for Array / Tracer ops. */
1518
1797
  static _implRules() {
1519
1798
  return {
@@ -1527,7 +1806,7 @@ var Array$1 = class Array$1 extends Tracer {
1527
1806
  return [x.#binary(require_backend.AluOp.Idiv, y)];
1528
1807
  },
1529
1808
  [Primitive.Neg]([x]) {
1530
- return [zerosLike(x).#binary(require_backend.AluOp.Sub, x)];
1809
+ return [zerosLike(x.ref).#binary(require_backend.AluOp.Sub, x)];
1531
1810
  },
1532
1811
  [Primitive.Reciprocal]([x]) {
1533
1812
  return [x.#unary(require_backend.AluOp.Reciprocal)];
@@ -1583,6 +1862,9 @@ var Array$1 = class Array$1 extends Tracer {
1583
1862
  [Primitive.Log]([x]) {
1584
1863
  return [x.#unary(require_backend.AluOp.Log)];
1585
1864
  },
1865
+ [Primitive.Sqrt]([x]) {
1866
+ return [x.#unary(require_backend.AluOp.Sqrt)];
1867
+ },
1586
1868
  [Primitive.Min]([x, y]) {
1587
1869
  return [x.#binary(require_backend.AluOp.Min, y)];
1588
1870
  },
@@ -1593,9 +1875,24 @@ var Array$1 = class Array$1 extends Tracer {
1593
1875
  if (axis.length === 0) return [x];
1594
1876
  return [x.#moveAxesDown(axis).#reduce(op)];
1595
1877
  },
1878
+ [Primitive.Pool]([x], { window, strides }) {
1879
+ const st = pool(x.#st, window, strides);
1880
+ return [x.#reshape(st)];
1881
+ },
1882
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
1883
+ const n = inShape.length;
1884
+ let st = poolTranspose(x.#st, inShape, window, strides);
1885
+ st = st.reshape([...st.shape.slice(0, n), require_backend.prod(st.shape.slice(n))]);
1886
+ return [x.#reshape(st).#reduce(require_backend.AluOp.Add)];
1887
+ },
1596
1888
  [Primitive.Dot]([x, y]) {
1597
1889
  return [Array$1.#naryCustom("dot", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
1598
1890
  },
1891
+ [Primitive.Conv]([x, y], params) {
1892
+ checkConvShape(x.shape, y.shape, params);
1893
+ const [stX, stY] = prepareConv(x.#st, y.#st, params);
1894
+ return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
1895
+ },
1599
1896
  [Primitive.Compare]([x, y], { op }) {
1600
1897
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1601
1898
  return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
@@ -1660,6 +1957,7 @@ function scalar(value, { dtype, device } = {}) {
1660
1957
  dtype ??= require_backend.DType.Float32;
1661
1958
  if (![
1662
1959
  require_backend.DType.Float32,
1960
+ require_backend.DType.Float16,
1663
1961
  require_backend.DType.Int32,
1664
1962
  require_backend.DType.Uint32
1665
1963
  ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
@@ -1667,6 +1965,7 @@ function scalar(value, { dtype, device } = {}) {
1667
1965
  dtype ??= require_backend.DType.Bool;
1668
1966
  if (![
1669
1967
  require_backend.DType.Float32,
1968
+ require_backend.DType.Float16,
1670
1969
  require_backend.DType.Int32,
1671
1970
  require_backend.DType.Uint32,
1672
1971
  require_backend.DType.Bool
@@ -1680,7 +1979,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1680
1979
  if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1681
1980
  if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
1682
1981
  return values;
1683
- } else if (values instanceof Float32Array || values instanceof Int32Array) return arrayFromData(values, shape$1 ?? [values.length], {
1982
+ } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1684
1983
  dtype,
1685
1984
  device
1686
1985
  });
@@ -1709,7 +2008,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1709
2008
  });
1710
2009
  } else {
1711
2010
  dtype = dtype ?? require_backend.DType.Float32;
1712
- const data = require_backend.dtypedArray(dtype, flat);
2011
+ const data = require_backend.dtypedJsArray(dtype, flat);
1713
2012
  return arrayFromData(data, shape$1, {
1714
2013
  dtype,
1715
2014
  device
@@ -1730,19 +2029,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
1730
2029
  });
1731
2030
  }
1732
2031
  const backend = require_backend.getBackend(device);
1733
- if (data instanceof Float32Array) {
1734
- if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
1735
- const slot = backend.malloc(data.byteLength, data.buffer);
1736
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), require_backend.DType.Float32, backend);
1737
- } else if (data instanceof Int32Array) {
1738
- if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
1739
- const slot = backend.malloc(data.byteLength, data.buffer);
1740
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype ?? require_backend.DType.Int32, backend);
1741
- } else if (data instanceof Uint32Array) {
1742
- if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
1743
- const slot = backend.malloc(data.byteLength, data.buffer);
1744
- return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), require_backend.DType.Uint32, backend);
1745
- } else throw new Error("Unsupported data type");
2032
+ if (ArrayBuffer.isView(data)) {
2033
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2034
+ if (data instanceof Float32Array) {
2035
+ if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
2036
+ dtype ??= require_backend.DType.Float32;
2037
+ } else if (data instanceof Int32Array) {
2038
+ if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2039
+ dtype ??= require_backend.DType.Int32;
2040
+ } else if (data instanceof Uint32Array) {
2041
+ if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2042
+ dtype ??= require_backend.DType.Uint32;
2043
+ } else if (data instanceof Float16Array) {
2044
+ if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
2045
+ dtype ??= require_backend.DType.Float16;
2046
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2047
+ const slot = backend.malloc(data.byteLength, buf);
2048
+ return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
2049
+ } else throw new Error("Unsupported data type: " + data.constructor.name);
1746
2050
  }
1747
2051
  function dataToJs(dtype, data, shape$1) {
1748
2052
  if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
@@ -1769,9 +2073,20 @@ var EvalTrace = class extends Trace {
1769
2073
  };
1770
2074
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
1771
2075
  const implRules = Array$1._implRules();
1772
- function zerosLike(val) {
2076
+ function zerosLike(val, dtype) {
1773
2077
  const aval = getAval(val);
1774
- return zeros(aval.shape, { dtype: aval.dtype });
2078
+ if (val instanceof Tracer) val.dispose();
2079
+ return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2080
+ }
2081
+ function onesLike(val, dtype) {
2082
+ const aval = getAval(val);
2083
+ if (val instanceof Tracer) val.dispose();
2084
+ return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2085
+ }
2086
+ function fullLike(val, fillValue, dtype) {
2087
+ const aval = getAval(val);
2088
+ if (val instanceof Tracer) val.dispose();
2089
+ return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
1775
2090
  }
1776
2091
  /** Return a new array of given shape and type, filled with zeros. */
1777
2092
  function zeros(shape$1, { dtype, device } = {}) {
@@ -1793,6 +2108,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
1793
2108
  if (typeof fillValue === "number") {
1794
2109
  dtype = dtype ?? require_backend.DType.Float32;
1795
2110
  source = require_backend.AluExp.const(dtype, fillValue);
2111
+ } else if (typeof fillValue === "bigint") {
2112
+ dtype = dtype ?? require_backend.DType.Int32;
2113
+ source = require_backend.AluExp.const(dtype, Number(fillValue));
1796
2114
  } else if (typeof fillValue === "boolean") {
1797
2115
  dtype = dtype ?? require_backend.DType.Bool;
1798
2116
  source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
@@ -1890,7 +2208,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
1890
2208
  const st = require_backend.ShapeTracker.fromShape([num]);
1891
2209
  return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
1892
2210
  }
1893
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
1894
2211
  function aluCompare(a, b, op) {
1895
2212
  switch (op) {
1896
2213
  case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
@@ -1932,8 +2249,8 @@ function generalBroadcast(a, b) {
1932
2249
  }
1933
2250
 
1934
2251
  //#endregion
1935
- //#region node_modules/.pnpm/@oxc-project+runtime@0.77.3/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
1936
- var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.77.3/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js"(exports, module) {
2252
+ //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
2253
+ var require_usingCtx = /* @__PURE__ */ __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js": ((exports, module) => {
1937
2254
  function _usingCtx() {
1938
2255
  var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
1939
2256
  var n$1 = Error();
@@ -1989,11 +2306,11 @@ var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.7
1989
2306
  };
1990
2307
  }
1991
2308
  module.exports = _usingCtx, module.exports.__esModule = true, module.exports["default"] = module.exports;
1992
- } });
2309
+ }) });
1993
2310
 
1994
2311
  //#endregion
1995
2312
  //#region src/frontend/jaxpr.ts
1996
- var import_usingCtx$2 = __toESM(require_usingCtx(), 1);
2313
+ var import_usingCtx$2 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
1997
2314
  /** Variable in a Jaxpr expression. */
1998
2315
  var Var = class Var {
1999
2316
  static #nextId = 1;
@@ -2004,7 +2321,7 @@ var Var = class Var {
2004
2321
  this.aval = aval;
2005
2322
  }
2006
2323
  toString() {
2007
- return `Var(${this.id}):${this.aval.strShort()}`;
2324
+ return `Var(${this.id}):${this.aval.toString()}`;
2008
2325
  }
2009
2326
  };
2010
2327
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
@@ -2044,7 +2361,7 @@ var VarPrinter = class {
2044
2361
  return name;
2045
2362
  }
2046
2363
  nameType(v) {
2047
- return `${this.name(v)}:${v.aval.strShort()}`;
2364
+ return `${this.name(v)}:${v.aval.toString()}`;
2048
2365
  }
2049
2366
  };
2050
2367
  /** A single statement / binding in a Jaxpr, in ANF form. */
@@ -2199,8 +2516,8 @@ var JaxprType = class {
2199
2516
  this.outTypes = outTypes;
2200
2517
  }
2201
2518
  toString() {
2202
- const inTypes = this.inTypes.map((aval) => aval.strShort()).join(", ");
2203
- const outTypes = this.outTypes.map((aval) => aval.strShort()).join(", ");
2519
+ const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
2520
+ const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
2204
2521
  return `(${inTypes}) -> (${outTypes})`;
2205
2522
  }
2206
2523
  };
@@ -2279,7 +2596,7 @@ var JaxprTracer = class extends Tracer {
2279
2596
  this.aval = aval;
2280
2597
  }
2281
2598
  toString() {
2282
- return `JaxprTracer(${this.aval.strShort()})`;
2599
+ return `JaxprTracer(${this.aval.toString()})`;
2283
2600
  }
2284
2601
  get ref() {
2285
2602
  return this;
@@ -2418,6 +2735,7 @@ const abstractEvalRules = {
2418
2735
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2419
2736
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2420
2737
  [Primitive.Log]: vectorizedUnopAbstractEval,
2738
+ [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2421
2739
  [Primitive.Min]: binopAbstractEval,
2422
2740
  [Primitive.Max]: binopAbstractEval,
2423
2741
  [Primitive.Reduce]([x], { axis }) {
@@ -2425,6 +2743,15 @@ const abstractEvalRules = {
2425
2743
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2426
2744
  return [new ShapedArray(newShape, x.dtype)];
2427
2745
  },
2746
+ [Primitive.Pool]([x], { window, strides }) {
2747
+ const shape$1 = checkPoolShape(x.shape, window, strides);
2748
+ return [new ShapedArray(shape$1, x.dtype)];
2749
+ },
2750
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2751
+ const shape$1 = checkPoolShape(inShape, window, strides);
2752
+ if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2753
+ return [new ShapedArray(inShape, x.dtype)];
2754
+ },
2428
2755
  [Primitive.Dot]([x, y]) {
2429
2756
  if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2430
2757
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
@@ -2432,6 +2759,11 @@ const abstractEvalRules = {
2432
2759
  shape$1.splice(-1, 1);
2433
2760
  return [new ShapedArray(shape$1, x.dtype)];
2434
2761
  },
2762
+ [Primitive.Conv]([lhs, rhs], params) {
2763
+ if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2764
+ const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2765
+ return [new ShapedArray(shape$1, lhs.dtype)];
2766
+ },
2435
2767
  [Primitive.Compare]: compareAbstractEval,
2436
2768
  [Primitive.Where]([cond, x, y]) {
2437
2769
  if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
@@ -2479,15 +2811,34 @@ const abstractEvalRules = {
2479
2811
  return outTypes;
2480
2812
  }
2481
2813
  };
2482
- function makeJaxpr$1(f) {
2814
+ function splitIdx(values, argnums) {
2815
+ const a = [];
2816
+ const b = [];
2817
+ for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
2818
+ else b.push(values[i]);
2819
+ return [a, b];
2820
+ }
2821
+ function joinIdx(n, a, b, argnums) {
2822
+ const result = [];
2823
+ let ai = 0;
2824
+ let bi = 0;
2825
+ for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
2826
+ else result.push(b[bi++]);
2827
+ return result;
2828
+ }
2829
+ function makeJaxpr$1(f, opts) {
2483
2830
  return (...argsIn) => {
2484
2831
  try {
2485
2832
  var _usingCtx$1 = (0, import_usingCtx$2.default)();
2486
- const [avalsIn, inTree] = flatten(argsIn);
2487
- const [fFlat, outTree] = flattenFun(f, inTree);
2833
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2834
+ const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
2835
+ const [avalsIn, inTree] = flatten(shapedArgs);
2836
+ const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
2837
+ return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
2838
+ }, inTree);
2488
2839
  const builder = new JaxprBuilder();
2489
2840
  const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
2490
- const _dynamic = _usingCtx$1.u(newDynamic(main));
2841
+ _usingCtx$1.u(newDynamic(main));
2491
2842
  const trace = new JaxprTrace(main);
2492
2843
  const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
2493
2844
  const outs = fFlat(...tracersIn);
@@ -2506,14 +2857,17 @@ function makeJaxpr$1(f) {
2506
2857
  }
2507
2858
  };
2508
2859
  }
2509
- function jit$1(f) {
2860
+ function jit$1(f, opts) {
2510
2861
  const cache = /* @__PURE__ */ new Map();
2862
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2511
2863
  return ((...args) => {
2512
- const [argsFlat, inTree] = flatten(args);
2864
+ const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2865
+ const [argsFlat, inTree] = flatten(dynamicArgs);
2513
2866
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
2514
2867
  const avalsIn = unflatten(inTree, avalsInFlat);
2515
- const cacheKey = JSON.stringify(avalsIn);
2516
- const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f)(...avalsIn));
2868
+ const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
2869
+ const cacheKey = JSON.stringify(jaxprArgs);
2870
+ const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2517
2871
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2518
2872
  jaxpr,
2519
2873
  numConsts: consts.length
@@ -2524,7 +2878,7 @@ function jit$1(f) {
2524
2878
 
2525
2879
  //#endregion
2526
2880
  //#region src/frontend/jvp.ts
2527
- var import_usingCtx$1 = __toESM(require_usingCtx(), 1);
2881
+ var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
2528
2882
  var JVPTracer = class extends Tracer {
2529
2883
  constructor(trace, primal, tangent) {
2530
2884
  super(trace);
@@ -2551,7 +2905,7 @@ var JVPTrace = class extends Trace {
2551
2905
  return this.lift(pureArray(val));
2552
2906
  }
2553
2907
  lift(val) {
2554
- return new JVPTracer(this, val, zerosLike(val));
2908
+ return new JVPTracer(this, val, zerosLike(val.ref));
2555
2909
  }
2556
2910
  processPrimitive(primitive, tracers, params) {
2557
2911
  const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2569,19 +2923,25 @@ function linearTangentsJvp(primitive) {
2569
2923
  return [ys, dys];
2570
2924
  };
2571
2925
  }
2926
+ /** Rule for product of gradients in bilinear operations. */
2927
+ function bilinearTangentsJvp(primitive) {
2928
+ return ([x, y], [dx, dy], params) => {
2929
+ const primal = bind1(primitive, [x.ref, y.ref], params);
2930
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
2931
+ return [[primal], [tangent]];
2932
+ };
2933
+ }
2572
2934
  /** Rule that zeros out any tangents. */
2573
2935
  function zeroTangentsJvp(primitive) {
2574
2936
  return (primals, tangents, params) => {
2575
2937
  for (const t of tangents) t.dispose();
2576
2938
  const ys = bind(primitive, primals, params);
2577
- return [ys, ys.map((y) => zerosLike(y))];
2939
+ return [ys, ys.map((y) => zerosLike(y.ref))];
2578
2940
  };
2579
2941
  }
2580
2942
  const jvpRules = {
2581
2943
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
2582
- [Primitive.Mul]([x, y], [dx, dy]) {
2583
- return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
2584
- },
2944
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
2585
2945
  [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
2586
2946
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
2587
2947
  [Primitive.Reciprocal]([x], [dx]) {
@@ -2594,13 +2954,13 @@ const jvpRules = {
2594
2954
  if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2595
2955
  else {
2596
2956
  dx.dispose();
2597
- return [[cast(x, dtype)], [zerosLike(x)]];
2957
+ return [[cast(x.ref, dtype)], [zerosLike(x)]];
2598
2958
  }
2599
2959
  },
2600
2960
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2601
2961
  if (x.dtype === dtype) return [[x], [dx]];
2602
2962
  dx.dispose();
2603
- return [[bitcast(x, dtype)], [zerosLike(x)]];
2963
+ return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
2604
2964
  },
2605
2965
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2606
2966
  [Primitive.Sin]([x], [dx]) {
@@ -2616,6 +2976,10 @@ const jvpRules = {
2616
2976
  [Primitive.Log]([x], [dx]) {
2617
2977
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
2618
2978
  },
2979
+ [Primitive.Sqrt]([x], [dx]) {
2980
+ const z = sqrt$1(x);
2981
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
2982
+ },
2619
2983
  [Primitive.Min]([x, y], [dx, dy]) {
2620
2984
  return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
2621
2985
  },
@@ -2632,13 +2996,14 @@ const jvpRules = {
2632
2996
  const primal = reduce(x.ref, op, axis);
2633
2997
  const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
2634
2998
  const minCount = where$1(notMin.ref, 0, 1).sum(axis);
2635
- const tangent = where$1(notMin, op === require_backend.AluOp.Min ? 0 : 0, dx).sum(axis).mul(reciprocal$1(minCount));
2999
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
2636
3000
  return [[primal], [tangent]];
2637
3001
  } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
2638
3002
  },
2639
- [Primitive.Dot]([x, y], [dx, dy]) {
2640
- return [[dot$1(x.ref, y.ref)], [dot$1(dx, y).add(dot$1(x, dy))]];
2641
- },
3003
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
3004
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
3005
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
3006
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
2642
3007
  [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
2643
3008
  [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
2644
3009
  dcond.dispose();
@@ -2711,7 +3076,7 @@ function jvp$1(f, primals, tangents) {
2711
3076
 
2712
3077
  //#endregion
2713
3078
  //#region src/frontend/vmap.ts
2714
- var import_usingCtx = __toESM(require_usingCtx(), 1);
3079
+ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
2715
3080
  function mappedAval(batchDim, aval) {
2716
3081
  const shape$1 = [...aval.shape];
2717
3082
  shape$1.splice(batchDim, 1);
@@ -2815,6 +3180,7 @@ const vmapRules = {
2815
3180
  [Primitive.Cos]: unopBatcher(cos$1),
2816
3181
  [Primitive.Exp]: unopBatcher(exp$1),
2817
3182
  [Primitive.Log]: unopBatcher(log$1),
3183
+ [Primitive.Sqrt]: unopBatcher(sqrt$1),
2818
3184
  [Primitive.Min]: broadcastBatcher(min$1),
2819
3185
  [Primitive.Max]: broadcastBatcher(max$1),
2820
3186
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
@@ -2951,7 +3317,7 @@ var PartialVal = class PartialVal {
2951
3317
  return this.val !== null;
2952
3318
  }
2953
3319
  toString() {
2954
- return this.val ? this.val.toString() : this.aval.strShort();
3320
+ return this.val ? this.val.toString() : this.aval.toString();
2955
3321
  }
2956
3322
  };
2957
3323
  function partialEvalFlat(f, pvalsIn) {
@@ -3325,12 +3691,72 @@ const transposeRules = {
3325
3691
  if (op === require_backend.AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
3326
3692
  else throw new NonlinearError(Primitive.Reduce);
3327
3693
  },
3694
+ [Primitive.Pool]([ct], [x], { window, strides }) {
3695
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
3696
+ return bind(Primitive.PoolTranspose, [ct], {
3697
+ inShape: x.aval.shape,
3698
+ window,
3699
+ strides
3700
+ });
3701
+ },
3702
+ [Primitive.PoolTranspose]([ct], [x], { window, strides }) {
3703
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
3704
+ return bind(Primitive.Pool, [ct], {
3705
+ window,
3706
+ strides
3707
+ });
3708
+ },
3328
3709
  [Primitive.Dot]([ct], [x, y]) {
3329
3710
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3330
3711
  const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3331
3712
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3332
3713
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3333
3714
  },
3715
+ [Primitive.Conv]([ct], [lhs, rhs], params) {
3716
+ if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
3717
+ const rev01 = [
3718
+ 1,
3719
+ 0,
3720
+ ...require_backend.range(2, ct.ndim)
3721
+ ];
3722
+ if (lhs instanceof UndefPrimal) {
3723
+ let kernel = rhs;
3724
+ kernel = transpose$1(kernel, rev01);
3725
+ kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
3726
+ const result = conv(ct, kernel, {
3727
+ strides: params.lhsDilation,
3728
+ padding: params.padding.map(([pl, _pr], i) => {
3729
+ const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3730
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3731
+ const padBefore = dilatedKernel - 1 - pl;
3732
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3733
+ const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
3734
+ return [padBefore, padAfter];
3735
+ }),
3736
+ lhsDilation: params.strides,
3737
+ rhsDilation: params.rhsDilation
3738
+ });
3739
+ return [result, null];
3740
+ } else {
3741
+ const newLhs = transpose$1(lhs, rev01);
3742
+ const newRhs = transpose$1(ct, rev01);
3743
+ let result = conv(newLhs, newRhs, {
3744
+ strides: params.rhsDilation,
3745
+ padding: params.padding.map(([pl, _pr], i) => {
3746
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3747
+ const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3748
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3749
+ const padFromLhs = dilatedCt - dilatedLhs;
3750
+ const padFromRhs = dilatedKernel - pl - 1;
3751
+ return [pl, padFromLhs + padFromRhs];
3752
+ }),
3753
+ lhsDilation: params.lhsDilation,
3754
+ rhsDilation: params.strides
3755
+ });
3756
+ result = transpose$1(result, rev01);
3757
+ return [null, result];
3758
+ }
3759
+ },
3334
3760
  [Primitive.Where]([ct], [cond, x, y]) {
3335
3761
  const cts = [
3336
3762
  null,
@@ -3451,8 +3877,8 @@ function valueAndGrad$1(f) {
3451
3877
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
3452
3878
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3453
3879
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3454
- if (y.dtype !== require_backend.DType.Float32) throw new TypeError("grad currently only supports float32");
3455
- const [ct, ...rest] = fVjp(pureArray(1));
3880
+ if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3881
+ const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3456
3882
  for (const r of rest) r.dispose();
3457
3883
  return [y, ct];
3458
3884
  };
@@ -3466,6 +3892,73 @@ function jacrev$1(f) {
3466
3892
  };
3467
3893
  }
3468
3894
 
3895
+ //#endregion
3896
+ //#region src/lax.ts
3897
+ var lax_exports = {};
3898
+ __export(lax_exports, {
3899
+ conv: () => conv$1,
3900
+ convGeneralDilated: () => convGeneralDilated,
3901
+ convWithGeneralPadding: () => convWithGeneralPadding,
3902
+ reduceWindow: () => reduceWindow
3903
+ });
3904
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
3905
+ const padType = padding.toUpperCase();
3906
+ switch (padType) {
3907
+ case "VALID": return require_backend.rep(inShape.length, [0, 0]);
3908
+ case "SAME":
3909
+ case "SAME_LOWER": {
3910
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
3911
+ const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
3912
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
3913
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
3914
+ }
3915
+ default: throw new Error(`Unknown padding type: ${padType}`);
3916
+ }
3917
+ }
3918
+ /**
3919
+ * General n-dimensional convolution operator, with optional dilation.
3920
+ *
3921
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
3922
+ * function in JAX, which wraps XLA's general convolution operator.
3923
+ *
3924
+ * Grouped convolutions are not supported right now.
3925
+ */
3926
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
3927
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
3928
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
3929
+ if (typeof padding === "string") {
3930
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
3931
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
3932
+ }
3933
+ return conv(lhs, rhs, {
3934
+ strides: windowStrides,
3935
+ padding,
3936
+ lhsDilation,
3937
+ rhsDilation
3938
+ });
3939
+ }
3940
+ /** Convenience wrapper around `convGeneralDilated`. */
3941
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
3942
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
3943
+ lhsDilation,
3944
+ rhsDilation
3945
+ });
3946
+ }
3947
+ /** Convenience wrapper around `convGeneralDilated`. */
3948
+ function conv$1(lhs, rhs, windowStrides, padding) {
3949
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
3950
+ }
3951
+ /** Reduce a computation over padded windows. */
3952
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
3953
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
3954
+ if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
3955
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
3956
+ return computation(bind1(Primitive.Pool, [operand], {
3957
+ window: windowDimensions,
3958
+ strides: windowStrides
3959
+ }));
3960
+ }
3961
+
3469
3962
  //#endregion
3470
3963
  //#region src/numpy.ts
3471
3964
  var numpy_exports = {};
@@ -3484,9 +3977,9 @@ __export(numpy_exports, {
3484
3977
  bool: () => bool,
3485
3978
  clip: () => clip,
3486
3979
  columnStack: () => columnStack,
3487
- complex64: () => complex64,
3488
3980
  concatenate: () => concatenate,
3489
3981
  cos: () => cos,
3982
+ cosh: () => cosh,
3490
3983
  diag: () => diag,
3491
3984
  diagonal: () => diagonal,
3492
3985
  divide: () => divide,
@@ -3501,8 +3994,10 @@ __export(numpy_exports, {
3501
3994
  flip: () => flip,
3502
3995
  fliplr: () => fliplr,
3503
3996
  flipud: () => flipud,
3997
+ float16: () => float16,
3504
3998
  float32: () => float32,
3505
3999
  full: () => full,
4000
+ fullLike: () => fullLike$1,
3506
4001
  greater: () => greater,
3507
4002
  greaterEqual: () => greaterEqual,
3508
4003
  hstack: () => hstack,
@@ -3529,6 +4024,7 @@ __export(numpy_exports, {
3529
4024
  negative: () => negative,
3530
4025
  notEqual: () => notEqual,
3531
4026
  ones: () => ones,
4027
+ onesLike: () => onesLike$1,
3532
4028
  pad: () => pad,
3533
4029
  permuteDims: () => permuteDims,
3534
4030
  pi: () => pi,
@@ -3539,11 +4035,14 @@ __export(numpy_exports, {
3539
4035
  scalar: () => scalar,
3540
4036
  shape: () => shape,
3541
4037
  sin: () => sin,
4038
+ sinh: () => sinh,
3542
4039
  size: () => size,
4040
+ sqrt: () => sqrt,
3543
4041
  square: () => square,
3544
4042
  stack: () => stack,
3545
4043
  sum: () => sum,
3546
4044
  tan: () => tan,
4045
+ tanh: () => tanh,
3547
4046
  transpose: () => transpose,
3548
4047
  trueDivide: () => trueDivide,
3549
4048
  trunc: () => trunc,
@@ -3552,13 +4051,14 @@ __export(numpy_exports, {
3552
4051
  vecdot: () => vecdot,
3553
4052
  vstack: () => vstack,
3554
4053
  where: () => where,
3555
- zeros: () => zeros
4054
+ zeros: () => zeros,
4055
+ zerosLike: () => zerosLike$1
3556
4056
  });
3557
4057
  const float32 = require_backend.DType.Float32;
3558
4058
  const int32 = require_backend.DType.Int32;
3559
4059
  const uint32 = require_backend.DType.Uint32;
3560
4060
  const bool = require_backend.DType.Bool;
3561
- const complex64 = require_backend.DType.Complex64;
4061
+ const float16 = require_backend.DType.Float16;
3562
4062
  /** Euler's constant, `e = 2.7182818284590...` */
3563
4063
  const e = Math.E;
3564
4064
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -3585,6 +4085,8 @@ const cos = cos$1;
3585
4085
  const exp = exp$1;
3586
4086
  /** Calculate the natural logarithm of all elements in the input array. */
3587
4087
  const log = log$1;
4088
+ /** Calculate the square root of all elements in the input array. */
4089
+ const sqrt = sqrt$1;
3588
4090
  /** Return element-wise minimum of the input arrays. */
3589
4091
  const minimum = min$1;
3590
4092
  /** Return element-wise maximum of the input arrays. */
@@ -3626,6 +4128,12 @@ const pad = pad$1;
3626
4128
  const ndim = ndim$1;
3627
4129
  /** Return the shape of an array. Does not consume array reference. */
3628
4130
  const shape = getShape;
4131
+ /** Return an array of zeros with the same shape and type as a given array. */
4132
+ const zerosLike$1 = zerosLike;
4133
+ /** Return an array of ones with the same shape and type as a given array. */
4134
+ const onesLike$1 = onesLike;
4135
+ /** Return a full array with the same shape and type as a given array. */
4136
+ const fullLike$1 = fullLike;
3629
4137
  /**
3630
4138
  * Return the number of elements in an array, optionally along an axis.
3631
4139
  * Does not consume array reference.
@@ -3676,13 +4184,7 @@ function argmin(a, axis, opts) {
3676
4184
  dtype: int32,
3677
4185
  device: a.device
3678
4186
  });
3679
- const idx = where(isMax, scalar(1, {
3680
- dtype: int32,
3681
- device: a.device
3682
- }), scalar(0, {
3683
- dtype: int32,
3684
- device: a.device
3685
- })).mul(arange(shape$1[axis], 0, -1, {
4187
+ const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3686
4188
  dtype: int32,
3687
4189
  device: a.device
3688
4190
  }).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
@@ -3706,13 +4208,7 @@ function argmax(a, axis, opts) {
3706
4208
  dtype: int32,
3707
4209
  device: a.device
3708
4210
  });
3709
- const idx = where(isMax, scalar(1, {
3710
- dtype: int32,
3711
- device: a.device
3712
- }), scalar(0, {
3713
- dtype: int32,
3714
- device: a.device
3715
- })).mul(arange(shape$1[axis], 0, -1, {
4211
+ const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3716
4212
  dtype: int32,
3717
4213
  device: a.device
3718
4214
  }).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
@@ -3844,9 +4340,11 @@ function ravel(a) {
3844
4340
  * Return specified diagonals.
3845
4341
  *
3846
4342
  * If a is 2D, return the diagonal of the array with the given offset. If a is
3847
- * 3D or higher, compute diagonals along the two given axes.
4343
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
3848
4344
  *
3849
- * This returns a view over the existing array.
4345
+ * This returns a view over the existing array. The shape of the resulting array
4346
+ * is determined by removing the two axes along which the diagonal is taken,
4347
+ * then appending a new axis on the right that holds the diagonals.
3850
4348
  */
3851
4349
  function diagonal(a, offset, axis1, axis2) {
3852
4350
  return fudgeArray(a).diagonal(offset, axis1, axis2);
@@ -3862,15 +4360,16 @@ function diag(v, k = 0) {
3862
4360
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
3863
4361
  if (a.ndim === 1) {
3864
4362
  const n = a.shape[0];
3865
- const ret = where(eye(n).equal(1), a, 0);
3866
- if (k !== 0) throw new Error("diag() for 1D arrays only for k=0");
3867
- return ret;
4363
+ const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
4364
+ if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4365
+ else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4366
+ else return ret;
3868
4367
  } else if (a.ndim === 2) return diagonal(a, k);
3869
4368
  else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
3870
4369
  }
3871
4370
  /** Return if two arrays are element-wise equal within a tolerance. */
3872
4371
  function allclose(actual, expected, options) {
3873
- const { rtol = 1e-5, atol = 1e-8 } = options ?? {};
4372
+ const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
3874
4373
  const x = array(actual);
3875
4374
  const y = array(expected);
3876
4375
  if (!require_backend.deepEqual(x.shape, y.shape)) return false;
@@ -4004,15 +4503,52 @@ function log2(x) {
4004
4503
  function log10(x) {
4005
4504
  return log(x).mul(Math.LOG10E);
4006
4505
  }
4506
+ /**
4507
+ * Calculate element-wise hyperbolic sine of input.
4508
+ *
4509
+ * `sinh(x) = (exp(x) - exp(-x)) / 2`
4510
+ */
4511
+ function sinh(x) {
4512
+ const ex = exp(x);
4513
+ const emx = reciprocal(ex.ref);
4514
+ return ex.sub(emx).mul(.5);
4515
+ }
4516
+ /**
4517
+ * Calculate element-wise hyperbolic cosine of input.
4518
+ *
4519
+ * `cosh(x) = (exp(x) + exp(-x)) / 2`
4520
+ */
4521
+ function cosh(x) {
4522
+ const ex = exp(x);
4523
+ const emx = reciprocal(ex.ref);
4524
+ return ex.add(emx).mul(.5);
4525
+ }
4526
+ /**
4527
+ * Calculate element-wise hyperbolic tangent of input.
4528
+ *
4529
+ * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
4530
+ */
4531
+ function tanh(x) {
4532
+ x = fudgeArray(x);
4533
+ const negsgn = where(less(x.ref, 0), 1, -1);
4534
+ const en2x = exp(x.mul(negsgn.ref).mul(2));
4535
+ return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
4536
+ }
4007
4537
 
4008
4538
  //#endregion
4009
4539
  //#region src/nn.ts
4010
4540
  var nn_exports = {};
4011
4541
  __export(nn_exports, {
4542
+ celu: () => celu,
4543
+ elu: () => elu,
4544
+ gelu: () => gelu,
4545
+ glu: () => glu,
4012
4546
  identity: () => identity,
4547
+ leakyRelu: () => leakyRelu,
4013
4548
  logSigmoid: () => logSigmoid,
4014
4549
  logSoftmax: () => logSoftmax,
4015
4550
  logsumexp: () => logsumexp,
4551
+ mish: () => mish,
4016
4552
  oneHot: () => oneHot,
4017
4553
  relu: () => relu,
4018
4554
  relu6: () => relu6,
@@ -4072,10 +4608,7 @@ function softSign(x) {
4072
4608
  *
4073
4609
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4074
4610
  */
4075
- function silu(x) {
4076
- x = fudgeArray(x);
4077
- return x.ref.mul(sigmoid(x));
4078
- }
4611
+ const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
4079
4612
  /**
4080
4613
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4081
4614
  * Swish, computed element-wise:
@@ -4095,6 +4628,72 @@ function logSigmoid(x) {
4095
4628
  }
4096
4629
  /** Identity activation function. Returns the argument unmodified. */
4097
4630
  const identity = fudgeArray;
4631
+ /** Leaky rectified linear unit (Leaky ReLU) activation function. */
4632
+ function leakyRelu(x, negativeSlope = .01) {
4633
+ x = fudgeArray(x);
4634
+ return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
4635
+ }
4636
+ /**
4637
+ * Exponential linear unit activation function.
4638
+ *
4639
+ * Computes the element-wise function:
4640
+ * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
4641
+ */
4642
+ function elu(x, alpha = 1) {
4643
+ x = fudgeArray(x);
4644
+ return where(less(x.ref, 0), exp(x.ref).sub(1).mul(alpha), x);
4645
+ }
4646
+ /**
4647
+ * Continuously-differentiable exponential linear unit activation function.
4648
+ *
4649
+ * Computes the element-wise function:
4650
+ * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
4651
+ */
4652
+ function celu(x, alpha = 1) {
4653
+ x = fudgeArray(x);
4654
+ return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
4655
+ }
4656
+ /**
4657
+ * Gaussian error linear unit (GELU) activation function.
4658
+ *
4659
+ * This is computed element-wise. Currently jax-js does not support the erf() or
4660
+ * gelu() functions exactly as primitives, so an approximation is used:
4661
+ * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
4662
+ *
4663
+ * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
4664
+ *
4665
+ * This will be improved in the future.
4666
+ */
4667
+ const gelu = jit$1((x) => {
4668
+ const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
4669
+ return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
4670
+ });
4671
+ /**
4672
+ * Gated linear unit (GLU) activation function.
4673
+ *
4674
+ * Splits the `axis` dimension of the input into two halves, a and b, then
4675
+ * computes `a * sigmoid(b)`.
4676
+ */
4677
+ function glu(x, axis = -1) {
4678
+ x = fudgeArray(x);
4679
+ axis = require_backend.checkAxis(axis, x.ndim);
4680
+ const size$1 = x.shape[axis];
4681
+ if (size$1 % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
4682
+ const slice = x.shape.map((a$1) => [0, a$1]);
4683
+ const a = shrink(x.ref, slice.toSpliced(axis, 1, [0, size$1 / 2]));
4684
+ const b = shrink(x, slice.toSpliced(axis, 1, [size$1 / 2, size$1]));
4685
+ return a.mul(sigmoid(b));
4686
+ }
4687
+ /**
4688
+ * Mish activation function.
4689
+ *
4690
+ * Computes the element-wise function:
4691
+ * `mish(x) = x * tanh(softplus(x))`
4692
+ */
4693
+ function mish(x) {
4694
+ x = fudgeArray(x);
4695
+ return x.ref.mul(tanh(softplus(x)));
4696
+ }
4098
4697
  /**
4099
4698
  * Softmax function. Computes the function which rescales elements to the range
4100
4699
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -4171,7 +4770,7 @@ function logsumexp(x, axis) {
4171
4770
  * ```
4172
4771
  */
4173
4772
  function oneHot(x, numClasses) {
4174
- if (x.dtype !== "int32") throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4773
+ if (x.dtype !== require_backend.DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4175
4774
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4176
4775
  }
4177
4776
 
@@ -4240,6 +4839,19 @@ const vmap = vmap$1;
4240
4839
  const jacfwd = jacfwd$1;
4241
4840
  /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
4242
4841
  const makeJaxpr = makeJaxpr$1;
4842
+ /**
4843
+ * Mark a function for automatic JIT compilation, with operator fusion.
4844
+ *
4845
+ * The function will be compiled the first time it is called with a set of
4846
+ * argument shapes.
4847
+ *
4848
+ * **Options:**
4849
+ * - `staticArgnums`: An array of argument indices to treat as static
4850
+ * (compile-time constant). These arguments must be hashable, won't be traced,
4851
+ * and different values will trigger recompilation.
4852
+ * - `device`: The device to place the computation on. If not specified, the
4853
+ * computation will be placed on the default device.
4854
+ */
4243
4855
  const jit = jit$1;
4244
4856
  /**
4245
4857
  * Produce a local linear approximation to a function at a point using jvp() and
@@ -4261,6 +4873,7 @@ const jacrev = jacrev$1;
4261
4873
  const jacobian = jacrev;
4262
4874
 
4263
4875
  //#endregion
4876
+ exports.DType = require_backend.DType;
4264
4877
  exports.devices = require_backend.devices;
4265
4878
  exports.grad = grad;
4266
4879
  exports.init = require_backend.init;
@@ -4269,6 +4882,12 @@ exports.jacobian = jacobian;
4269
4882
  exports.jacrev = jacrev;
4270
4883
  exports.jit = jit;
4271
4884
  exports.jvp = jvp;
4885
+ Object.defineProperty(exports, 'lax', {
4886
+ enumerable: true,
4887
+ get: function () {
4888
+ return lax_exports;
4889
+ }
4890
+ });
4272
4891
  exports.linearize = linearize;
4273
4892
  exports.makeJaxpr = makeJaxpr;
4274
4893
  Object.defineProperty(exports, 'nn', {