@jax-js/jax 0.0.2 → 0.0.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/dist/index.js CHANGED
@@ -1,11 +1,12 @@
1
1
  import { __export } from "./chunk-Cl8Af3a2.js";
2
- import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip } from "./backend-1eVbAoaV.js";
2
+ import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BqDtPGaR.js";
3
3
 
4
4
  //#region src/tree.ts
5
5
  var tree_exports = {};
6
6
  __export(tree_exports, {
7
7
  JsTreeDef: () => JsTreeDef,
8
8
  NodeType: () => NodeType,
9
+ dispose: () => dispose,
9
10
  flatten: () => flatten,
10
11
  leaves: () => leaves,
11
12
  map: () => map,
@@ -20,7 +21,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
20
21
  NodeType$1["Leaf"] = "Leaf";
21
22
  return NodeType$1;
22
23
  }({});
23
- /** Analog to the JAX "pytree" object, but for JavaScript. */
24
+ /** Represents the structure of a JsTree. */
24
25
  var JsTreeDef = class JsTreeDef {
25
26
  static leaf = new JsTreeDef(NodeType.Leaf, null, []);
26
27
  constructor(nodeType, nodeMetadata, childTreedefs) {
@@ -108,6 +109,194 @@ function map(fn, tree, ...rest) {
108
109
/** Map `.ref` over every leaf of `tree` and return the resulting tree. */
function ref(tree) {
  const takeRef = (leaf) => leaf.ref;
  return map(takeRef, tree);
}
/**
 * Dispose every array in a tree.
 *
 * A falsy `tree` (e.g. `null` or `undefined`) is ignored.
 */
function dispose(tree) {
  if (!tree) return;
  map((leaf) => leaf.dispose(), tree);
}
116
+
117
+ //#endregion
118
+ //#region src/frontend/convolution.ts
119
/**
 * Check that the shapes and parameters passed to convolution are valid.
 *
 * `lhsShape` is `[batch, inChannels, ...spatial]` and `rhsShape` is
 * `[outChannels, inChannels, ...kernel]`. Padding (which may be negative, i.e.
 * cropping) is applied to each spatial dimension *after* lhs dilation, matching
 * the order used by `prepareConv()` (dilate, then pad/shrink).
 *
 * If the check succeeds, returns the output shape
 * `[batch, outChannels, ...outSpatial]`.
 *
 * @throws {Error} If ranks, channel counts, or per-dimension parameters are invalid.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  const n = lhsShape.length - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
  const outShape = [lhsShape[0], rhsShape[0]];
  for (let i = 0; i < n; i++) {
    if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
    if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
    if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
    if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
    const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
    if (k <= 0) throw new Error("conv() kernel size must be positive");
    const [pl, pr] = padding[i];
    // Padding acts on the lhs-dilated input, so negative padding (cropping) is
    // bounded by the dilated size rather than the raw size `x`. Clamp at 0 so
    // an empty dimension stays empty.
    const dilatedSize = Math.max((x - 1) * lhsDilation[i] + 1, 0);
    if (pl < -dilatedSize || pr < -dilatedSize || pl + pr < -dilatedSize) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${dilatedSize}`);
    const kernelSize = (k - 1) * rhsDilation[i] + 1;
    const inSize = dilatedSize + pl + pr;
    if (kernelSize > inSize) throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${i}`);
    outShape.push(Math.ceil((inSize - kernelSize + 1) / strides[i]));
  }
  return outShape;
}
150
/**
 * Validate a pooling window against an input shape.
 *
 * The window and strides apply to the trailing `window.length` dimensions of
 * `inShape`. Returns the pooled shape: the leading (untouched) dimensions,
 * then the number of window positions along each pooled dimension, then the
 * window dimensions themselves.
 *
 * @throws {Error} If the window/strides are malformed or a window exceeds its input dimension.
 */
function checkPoolShape(inShape, window, strides) {
  if (strides.length !== window.length) throw new Error("pool() strides != window dims");
  if (window.length > inShape.length) throw new Error("pool() window has more dimensions than input");
  const lead = inShape.length - window.length;
  const positions = window.map((k, i) => {
    const s = strides[i];
    const dim = inShape[lead + i];
    if (k <= 0 || !Number.isInteger(k)) throw new Error(`pool() window[${i}] must be a positive integer`);
    if (k > dim) throw new Error(`pool() window[${i}]=${k} > input size ${dim}`);
    if (s <= 0 || !Number.isInteger(s)) throw new Error(`pool() strides[${i}] must be a positive integer`);
    return Math.ceil((dim - k + 1) / s);
  });
  return [...inShape.slice(0, lead), ...positions, ...window];
}
165
/**
 * Takes a shape tracker and a kernel size `ks`, then reshapes it so the last
 * `ks.length` dimensions become `2 * ks.length` dimensions by treating them as
 * spatial dimensions convolved with a kernel.
 *
 * For an input of shape `[...leading, ...spatial]` the result has shape
 * `[...leading, ...outSpatial, ...ks]`, where
 * `outSpatial[j] = ceil((spatial[j] - dilation[j] * (ks[j] - 1)) / strides[j])`.
 * The resulting array can be multiplied with a kernel of shape `ks`, then
 * reduced along the last `ks.length` dimensions for a convolution.
 *
 * @param st - ShapeTracker to pool.
 * @param ks - Kernel size per pooled (trailing) dimension.
 * @param strides - Stride per pooled dimension, or one number for all.
 * @param dilation - Kernel dilation per pooled dimension, or one number for all.
 * @throws {Error} If strides/dilation are malformed or a dilated kernel exceeds its input dimension.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 */
function pool(st, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  if (st.shape.length < ks.length) throw new Error("pool() called with too many dimensions");
  const s_ = typeof strides === "number" ? rep(ks.length, strides) : strides;
  const d_ = typeof dilation === "number" ? rep(ks.length, dilation) : dilation;
  if (s_.length !== ks.length || d_.length !== ks.length) throw new Error("pool() strides/dilation must match kernel dimensions");
  if (s_.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("pool() strides must be positive integers");
  if (d_.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("pool() dilation must be positive integers");
  const noop = st.shape.slice(0, -ks.length);
  const i_ = st.shape.slice(-ks.length);
  // Without this check the output size below is <= 0 and the view math is nonsense.
  if (zipn(i_, d_, ks).some(([i, d, k]) => d * (k - 1) + 1 > i)) throw new Error("pool() kernel size cannot be greater than input size");
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  // Input scaling factor, so that the stride shrink below stays in bounds.
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = zipn(ks, i_, d_, f_);
  const keep = noop.map((x) => [0, x]);
  // Repeat the input so that dilation can be expressed without padding.
  st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
  // Handle dilation.
  st = st.shrink([...keep, ...kidf.map(([k, i, d, f]) => [0, k * (i * f + d)])]).reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [k, i * f + d])]);
  // Handle stride.
  const kos = zipn(ks, o_, s_);
  st = st.shrink([...keep, ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])]).reshape([...noop, ...kos.flat(1)]);
  st = st.shrink([...keep, ...kos.flatMap(([k, o]) => [[0, k], [0, o], [0, 1]])]).reshape([...noop, ...kos.flatMap(([k, o]) => [k, o])]);
  // Move the kernel axes to the end: [...noop, ...o_, ...ks].
  return st.permute([
    ...range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j),
  ]);
}
205
/**
 * Perform the transpose of pool(), directly undoing a pool() operation.
 *
 * `st` must have the shape produced by `pool()` on an input of shape
 * `inShape`, i.e. `[...leading, ...outSpatial, ...ks]`. The result has shape
 * `[...inShape, ...repeats]`, where `repeats` are the repeat counts pool()
 * introduced.
 *
 * Since pool() repeats the input, the transpose operation technically should
 * include a sum reduction. This function doesn't perform the reduction, which
 * should be done over the last `ks.length` axes of the returned shape.
 *
 * @param st - ShapeTracker with the pooled shape.
 * @param inShape - Shape of the original (un-pooled) input.
 * @param ks - Kernel size per pooled dimension.
 * @param strides - Stride per pooled dimension, or one number for all.
 * @param dilation - Kernel dilation per pooled dimension, or one number for all.
 * @throws {Error} If strides/dilation are malformed or `st` does not match the pooled shape of `inShape`.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  if (inShape.length < ks.length) throw new Error("poolTranspose() called with too many dimensions");
  const s_ = typeof strides === "number" ? rep(ks.length, strides) : strides;
  const d_ = typeof dilation === "number" ? rep(ks.length, dilation) : dilation;
  // Same parameter validation as pool(), which this function inverts.
  if (s_.length !== ks.length || d_.length !== ks.length) throw new Error("poolTranspose() strides/dilation must match kernel dimensions");
  if (s_.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("poolTranspose() strides must be positive integers");
  if (d_.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("poolTranspose() dilation must be positive integers");
  const noop = inShape.slice(0, -ks.length);
  const i_ = inShape.slice(-ks.length);
  const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
  const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = zipn(ks, i_, d_, f_);
  const kos = zipn(ks, o_, s_);
  const keep = noop.map(() => [0, 0]);
  // Interleave kernel and output axes: [...noop, k0, o0, k1, o1, ...].
  st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + ks.length + j, noop.length + j])]);
  // Undo the stride shrink by padding each output position back out to `s`.
  st = st.reshape([...noop, ...kos.flatMap(([k, o]) => [k, o, 1])]).pad([...keep, ...s_.flatMap((s) => [[0, 0], [0, 0], [0, s - 1]])]);
  st = st.reshape([...noop, ...kos.flatMap(([k, o, s]) => [k, o * s])]).pad([...keep, ...kidf.flatMap(([_k, i, d, f], j) => [[0, 0], [0, i * f + d - o_[j] * s_[j]]])]);
  // Undo the dilation shrink, then the trailing part of the repeat.
  st = st.reshape([...noop, ...kidf.map(([k, i, d, f]) => k * (i * f + d))]).pad([...keep, ...kidf.map(([k, i, d, f]) => [0, Math.ceil(k * (i * f + d) / i) * i - k * (i * f + d)])]);
  // Split into [repeats, i] per dimension and move the repeat axes to the end.
  return st.reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [Math.ceil(k * (i * f + d) / i), i])]).permute([
    ...range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j),
  ]);
}
244
/**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * `st` has shape `[a, b, ...spatial]`. Each spatial dimension of size `k` is
 * spread out by inserting `dilation[i] - 1` padded positions between
 * neighboring elements, giving size `(k - 1) * dilation[i] + 1` (or 0 when the
 * dimension is empty, matching the size computed by `checkConvShape()`).
 *
 * @param st - ShapeTracker of shape `[a, b, ...spatial]`.
 * @param dilation - Dilation factor per spatial dimension.
 */
function applyDilation(st, dilation) {
  if (dilation.every((s) => s === 1)) return st;
  const [a, b, ...k_] = st.shape;
  // Give each element a trailing axis of length 1, pad it to the dilation factor, then merge back.
  let view = st.reshape([a, b, ...k_.flatMap((k) => [k, 1])]);
  view = view.pad([[0, 0], [0, 0], ...dilation.flatMap((s) => [[0, 0], [0, s - 1]])]);
  view = view.reshape([a, b, ...k_.map((k, i) => k * dilation[i])]);
  // Drop the padding after the last element. Clamp at 0 so an empty dimension
  // stays empty instead of producing a negative shrink bound.
  return view.shrink([[0, a], [0, b], ...k_.map((k, i) => [0, Math.max((k - 1) * dilation[i] + 1, 0)])]);
}
271
/**
 * Prepare for a convolution between two arrays.
 *
 * Rewrites the input view `stX` (`[batch, inChannels, ...spatial]`) and the
 * kernel view `stY` (`[outChannels, inChannels, ...ks]`) so that a broadcasted
 * multiply followed by a sum over the last axis computes the convolution:
 *   - `stX` becomes `[batch, 1, ...outSpatial, inChannels * prod(ks)]`
 *   - `stY` becomes `[outChannels, 1, ..., 1, inChannels * prod(ks)]`
 *
 * This does not check the validity of the shapes, which should be checked
 * beforehand using `checkConvShape()`.
 */
function prepareConv(stX, stY, params) {
  const n = stX.shape.length - 2;
  const ks = stY.shape.slice(2);
  const windowSize = prod(ks);

  // Dilate the input, apply (possibly negative) padding, then extract windows.
  let x = applyDilation(stX, params.lhsDilation);
  x = x.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  x = pool(x, ks, params.strides, params.rhsDilation);

  // pool() yields [batch, inChannels, ...outSpatial, ...ks]; move the channel
  // axis next to the window axes and merge them into one trailing axis.
  const [batch, inChannels] = x.shape;
  const outSpatial = x.shape.slice(2, n + 2);
  x = x.moveaxis(1, n + 1).reshape([batch, 1, ...outSpatial, inChannels * windowSize]);

  const [outChannels, kernelChannels] = stY.shape;
  const y = stY.reshape([outChannels, ...rep(n, 1), kernelChannels * windowSize]);
  return [x, y];
}
111
300
 
112
301
  //#endregion
113
302
  //#region src/frontend/core.ts
@@ -136,10 +325,14 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
136
325
  Primitive$1["Cos"] = "cos";
137
326
  Primitive$1["Exp"] = "exp";
138
327
  Primitive$1["Log"] = "log";
328
+ Primitive$1["Sqrt"] = "sqrt";
139
329
  Primitive$1["Min"] = "min";
140
330
  Primitive$1["Max"] = "max";
141
331
  Primitive$1["Reduce"] = "reduce";
142
332
  Primitive$1["Dot"] = "dot";
333
+ Primitive$1["Conv"] = "conv";
334
+ Primitive$1["Pool"] = "pool";
335
+ Primitive$1["PoolTranspose"] = "pool_transpose";
143
336
  Primitive$1["Compare"] = "compare";
144
337
  Primitive$1["Where"] = "where";
145
338
  Primitive$1["Transpose"] = "transpose";
@@ -203,6 +396,9 @@ function exp$1(x) {
203
396
/** Trace-level natural logarithm: binds the `Log` primitive on `x`. */
function log$1(x) {
  const operands = [x];
  return bind1(Primitive.Log, operands);
}
/** Trace-level square root: binds the `Sqrt` primitive on `x`. */
function sqrt$1(x) {
  const operands = [x];
  return bind1(Primitive.Sqrt, operands);
}
/** Trace-level elementwise minimum: binds the `Min` primitive on `x` and `y`. */
function min$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Min, operands);
}
@@ -225,6 +421,17 @@ function reduce(x, op, axis, opts) {
225
421
/** Trace-level dot product: binds the `Dot` primitive on `x` and `y`. */
function dot$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Dot, operands);
}
/**
 * Trace-level N-d convolution of `x` with kernel `y`.
 *
 * Both inputs must have the same rank, which must be at least 2. Missing
 * parameters (nullish entries of `params`) default per spatial dimension to
 * stride 1, padding `[0, 0]`, and lhs/rhs dilation 1.
 *
 * @throws {Error} If the ranks differ or are below 2.
 */
function conv(x, y, params = {}) {
  if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
  const spatialDims = x.ndim - 2;
  if (spatialDims < 0) throw new Error("conv() requires at least 2D inputs");
  const resolved = {
    strides: params.strides ?? rep(spatialDims, 1),
    padding: params.padding ?? rep(spatialDims, [0, 0]),
    lhsDilation: params.lhsDilation ?? rep(spatialDims, 1),
    rhsDilation: params.rhsDilation ?? rep(spatialDims, 1),
  };
  return bind1(Primitive.Conv, [x, y], resolved);
}
/** Trace-level comparison of `x` and `y` using comparison op `op`. */
function compare(x, y, op) {
  const params = { op };
  return bind1(Primitive.Compare, [x, y], params);
}
@@ -360,6 +567,9 @@ var Tracer = class Tracer {
360
567
  get shape() {
361
568
  return this.aval.shape;
362
569
  }
570
+ get size() {
571
+ return prod(this.shape);
572
+ }
363
573
  get dtype() {
364
574
  return this.aval.dtype;
365
575
  }
@@ -411,7 +621,7 @@ var Tracer = class Tracer {
411
621
  else if (typeof axis === "number") axis = [checkAxis(axis, this.ndim)];
412
622
  else axis = axis.map((a) => checkAxis(a, this.ndim));
413
623
  let result = reduce(this, AluOp.Add, axis);
414
- result = result.mul(prod(result.shape) / prod(this.shape));
624
+ result = result.mul(result.size / this.size);
415
625
  if (opts?.keepDims) result = broadcast(result, this.shape, axis);
416
626
  return result;
417
627
  }
@@ -445,8 +655,29 @@ var Tracer = class Tracer {
445
655
  /** Return specified diagonals. See `numpy.diagonal` for full docs. */
446
656
  diagonal(offset = 0, axis1 = 0, axis2 = 1) {
447
657
  if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
658
+ if (offset < 0) return this.diagonal(-offset, axis2, axis1);
659
+ axis1 = checkAxis(axis1, this.ndim);
660
+ axis2 = checkAxis(axis2, this.ndim);
448
661
  if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
449
- throw new Error("diagonal not implemented");
662
+ if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
663
+ let ar = this;
664
+ if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
665
+ const perm = range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
666
+ ar = ar.transpose(perm);
667
+ }
668
+ const [n, m] = ar.shape.slice(-2);
669
+ const diagSize = Math.min(n, m - offset);
670
+ ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
671
+ const npad = diagSize * (m + 1) - n * m;
672
+ if (npad > 0) ar = pad$1(ar, [...rep(ar.ndim - 1, [0, 0]), [0, npad]]);
673
+ else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
674
+ ar = ar.reshape([
675
+ ...ar.shape.slice(0, -1),
676
+ diagSize,
677
+ m + 1
678
+ ]);
679
+ ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
680
+ return ar;
450
681
  }
451
682
  /** Flatten the array without changing its data. */
452
683
  flatten() {
@@ -589,7 +820,7 @@ var ShapedArray = class ShapedArray {
589
820
  get ndim() {
590
821
  return this.shape.length;
591
822
  }
592
- strShort() {
823
+ toString() {
593
824
  return `${this.dtype}[${this.shape.join(",")}]`;
594
825
  }
595
826
  equals(other) {
@@ -620,7 +851,7 @@ function fullRaise(trace, val) {
620
851
  if (Object.is(val._trace.main, trace.main)) return val;
621
852
  else if (val._trace.main.level < level) return trace.lift(val);
622
853
  else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
623
- else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
854
+ else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
624
855
  }
625
856
  var TreeMismatchError = class extends TypeError {
626
857
  constructor(where$2, left, right) {
@@ -869,16 +1100,16 @@ function jitCompile(backend, jaxpr, consts) {
869
1100
  jitCompileCache.set(cacheKey, jp);
870
1101
  return jp;
871
1102
  }
872
/**
 * Rewrite every `GlobalView` node in an expression through `mapping`.
 *
 * `mapping(st)` receives each view's ShapeTracker and returns a replacement
 * tracker; if it returns a falsy value, this callback does not replace that
 * node. Replaced views are re-indexed by unraveling `gidx` over the new shape.
 * When `reduceAxis` is true, the last dimension is indexed by `ridx` instead,
 * for kernels that carry a reduction.
 *
 * @throws {Error} If the expression contains a `GlobalIndex` op.
 */
function reshapeViews(exp$2, mapping, reduceAxis = false) {
  return exp$2.rewrite((node) => {
    if (node.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
    if (node.op !== AluOp.GlobalView) return undefined;
    const [gid, st] = node.arg;
    const newSt = mapping(st);
    if (!newSt) return undefined;
    const indices = reduceAxis
      ? [...unravelAlu(newSt.shape.slice(0, -1), AluVar.gidx), AluVar.ridx]
      : unravelAlu(newSt.shape, AluVar.gidx);
    return AluExp.globalView(node.dtype, gid, newSt, indices);
  });
}
884
1115
  function broadcastedJit(fn) {
@@ -927,6 +1158,7 @@ const jitRules = {
927
1158
  [Primitive.Cos]: unopJit(AluExp.cos),
928
1159
  [Primitive.Exp]: unopJit(AluExp.exp),
929
1160
  [Primitive.Log]: unopJit(AluExp.log),
1161
+ [Primitive.Sqrt]: unopJit(AluExp.sqrt),
930
1162
  [Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
931
1163
  [Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
932
1164
  [Primitive.Reduce](nargs, [a], [as], { op, axis }) {
@@ -941,18 +1173,20 @@ const jitRules = {
941
1173
  const size$1 = prod(newShape);
942
1174
  const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
943
1175
  newShape.push(reductionSize);
944
- a = a.rewrite((exp$2) => {
945
- if (exp$2.op === AluOp.GlobalView) {
946
- const [gid, st] = exp$2.arg;
947
- const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
948
- const indices = unravelAlu(newShape.slice(0, -1), AluVar.gidx);
949
- indices.push(AluVar.ridx);
950
- return AluExp.globalView(exp$2.dtype, gid, newSt, indices);
951
- }
952
- });
1176
+ const perm = keptAxes.concat(shiftedAxes);
1177
+ a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
953
1178
  const reduction = new Reduction(a.dtype, op, reductionSize);
954
1179
  return new Kernel(nargs, size$1, a, reduction);
955
1180
  },
1181
+ [Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
1182
+ [Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
1183
+ let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
1184
+ const size$1 = prod(inShape);
1185
+ stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
1186
+ a = reshapeViews(a, (st) => st.compose(stX), true);
1187
+ const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
1188
+ return new Kernel(nargs, size$1, a, reduction);
1189
+ },
956
1190
  [Primitive.Dot](nargs, [a, b], [as, bs]) {
957
1191
  const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
958
1192
  const c = k1.exp;
@@ -962,6 +1196,14 @@ const jitRules = {
962
1196
  axis: [cs.ndim - 1]
963
1197
  });
964
1198
  },
1199
+ [Primitive.Conv](nargs, [a, b], [as, bs], params) {
1200
+ const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
1201
+ a = reshapeViews(a, (st) => st.compose(stX));
1202
+ b = reshapeViews(b, (st) => st.compose(stY));
1203
+ as = new ShapedArray(stX.shape, as.dtype);
1204
+ bs = new ShapedArray(stY.shape, bs.dtype);
1205
+ return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
1206
+ },
965
1207
  [Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
966
1208
  [Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
967
1209
  [Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
@@ -974,8 +1216,20 @@ const jitRules = {
974
1216
  }),
975
1217
  [Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
976
1218
  [Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
977
- [Primitive.Gather]() {
978
- throw new Error("Gather is not implemented in JIT yet");
1219
+ [Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
1220
+ const axisSet = new Set(axis);
1221
+ const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
1222
+ const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
1223
+ finalShape.splice(outDim, 0, ...indexShape);
1224
+ const idxAll = unravelAlu(finalShape, AluVar.gidx);
1225
+ const idxNonaxis = [...idxAll];
1226
+ idxNonaxis.splice(outDim, indexShape.length);
1227
+ const src = [...idxNonaxis];
1228
+ for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1229
+ for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
1230
+ const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
1231
+ if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
1232
+ return new Kernel(nargs, prod(finalShape), x.substitute({ gidx: index }));
979
1233
  },
980
1234
  [Primitive.JitCall]() {
981
1235
  throw new Error("internal: JitCall should have been flattened before JIT compilation");
@@ -994,9 +1248,15 @@ function splitGraphDataflow(backend, jaxpr) {
994
1248
  blackNodes.add(v);
995
1249
  p1NextBlack.set(v, v);
996
1250
  }
1251
+ const reducePrimitives = [
1252
+ Primitive.Reduce,
1253
+ Primitive.Dot,
1254
+ Primitive.Conv,
1255
+ Primitive.PoolTranspose
1256
+ ];
997
1257
  for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
998
1258
  const eqn = jaxpr.eqns[i];
999
- if (eqn.primitive === Primitive.Reduce || eqn.outBinders.some((v) => blackNodes.has(v))) {
1259
+ if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
1000
1260
  for (const v of eqn.outBinders) {
1001
1261
  blackNodes.add(v);
1002
1262
  p1NextBlack.set(v, v);
@@ -1223,7 +1483,7 @@ var Array$1 = class Array$1 extends Tracer {
1223
1483
  const inputs = [];
1224
1484
  const src = [...idxNonaxis];
1225
1485
  for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
1226
- for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#dtype, ar.#source, ar.#st, idxAxis));
1486
+ for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#source, ar.#st, idxAxis));
1227
1487
  else {
1228
1488
  let gid = inputs.indexOf(ar.#source);
1229
1489
  if (gid === -1) {
@@ -1233,7 +1493,7 @@ var Array$1 = class Array$1 extends Tracer {
1233
1493
  src[axis[i]] = AluExp.cast(DType.Int32, AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
1234
1494
  }
1235
1495
  let exp$2;
1236
- if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#dtype, this.#source, this.#st, src);
1496
+ if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, src);
1237
1497
  else {
1238
1498
  let gid = inputs.indexOf(this.#source);
1239
1499
  if (gid === -1) {
@@ -1276,7 +1536,7 @@ var Array$1 = class Array$1 extends Tracer {
1276
1536
  this.#check();
1277
1537
  if (this.#source instanceof AluExp) {
1278
1538
  const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
1279
- return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
1539
+ return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
1280
1540
  }
1281
1541
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1282
1542
  const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
@@ -1309,18 +1569,18 @@ var Array$1 = class Array$1 extends Tracer {
1309
1569
  if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
1310
1570
  arrays = Array$1.#broadcastArrays(arrays);
1311
1571
  const newShape = [...arrays[0].shape];
1312
- if (arrays.every((ar) => ar.#source instanceof AluExp) && reduceAxis === void 0) {
1572
+ if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
1313
1573
  if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
1314
1574
  const exp$4 = custom(arrays.map((ar) => ar.#source));
1315
- return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
1575
+ return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
1316
1576
  }
1317
1577
  const exp$3 = custom(arrays.map((ar) => {
1318
1578
  const src$1 = ar.#source;
1319
1579
  if (ar.#st.contiguous) return src$1;
1320
- return accessorAluExp(ar.#dtype, src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1580
+ return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
1321
1581
  }));
1322
1582
  const st = ShapeTracker.fromShape(newShape);
1323
- return new Array$1(exp$3, st, exp$3.dtype, backend);
1583
+ return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
1324
1584
  }
1325
1585
  let indices;
1326
1586
  if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
@@ -1330,7 +1590,7 @@ var Array$1 = class Array$1 extends Tracer {
1330
1590
  }
1331
1591
  const inputs = [];
1332
1592
  const src = [];
1333
- for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#dtype, ar.#source, ar.#st, indices));
1593
+ for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
1334
1594
  else {
1335
1595
  let gid = inputs.indexOf(ar.#source);
1336
1596
  if (gid === -1) {
@@ -1364,7 +1624,7 @@ var Array$1 = class Array$1 extends Tracer {
1364
1624
  const indices = [...unravelAlu(newShape, AluVar.gidx), AluVar.ridx];
1365
1625
  let exp$2;
1366
1626
  const inputs = [];
1367
- if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1627
+ if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, indices);
1368
1628
  else {
1369
1629
  inputs.push(this.#source);
1370
1630
  exp$2 = accessorGlobal(this.#dtype, 0, this.#st, indices);
@@ -1389,7 +1649,7 @@ var Array$1 = class Array$1 extends Tracer {
1389
1649
  this.#check();
1390
1650
  const indices = unravelAlu(this.#st.shape, AluVar.gidx);
1391
1651
  if (this.#source instanceof AluExp) {
1392
- const exp$2 = accessorAluExp(this.#dtype, this.#source, this.#st, indices);
1652
+ const exp$2 = accessorAluExp(this.#source, this.#st, indices);
1393
1653
  const kernel = new Kernel(0, this.#st.size, exp$2);
1394
1654
  const output = this.#backend.malloc(kernel.bytes);
1395
1655
  const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
@@ -1427,42 +1687,51 @@ var Array$1 = class Array$1 extends Tracer {
1427
1687
  }
1428
1688
  /** Realize the array and return it as data. */
1429
1689
  async data() {
1430
- if (this.#source instanceof AluExp && prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1690
+ if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1431
1691
  this.#realize();
1432
1692
  const pending = this.#pending;
1433
1693
  if (pending) {
1434
1694
  await Promise.all(pending.map((p) => p.prepare()));
1435
1695
  for (const p of pending) p.submit();
1436
1696
  }
1437
- const byteCount = byteWidth(this.#dtype) * prod(this.shape);
1697
+ const byteCount = byteWidth(this.#dtype) * this.size;
1438
1698
  const buf = await this.#backend.read(this.#source, 0, byteCount);
1439
1699
  this.dispose();
1440
1700
  return dtypedArray(this.dtype, buf);
1441
1701
  }
1442
- /** Wait for this array to be placed on the backend, if needed. */
1702
+ /**
1703
+ * Wait for this array to finish evaluation.
1704
+ *
1705
+ * Operations and data loading in jax-js are lazy, so this function ensures
1706
+ * that pending operations are dispatched and fully executed before it
1707
+ * returns.
1708
+ *
1709
+ * If you are mapping from `data()` or `dataSync()`, it will also trigger
1710
+ * dispatch of operations as well.
1711
+ */
1443
1712
  async wait() {
1444
1713
  this.#check();
1445
- if (this.#source instanceof AluExp) return;
1714
+ if (this.#source instanceof AluExp) return this;
1446
1715
  const pending = this.#pending;
1447
1716
  if (pending) {
1448
1717
  await Promise.all(pending.map((p) => p.prepare()));
1449
1718
  for (const p of pending) p.submit();
1450
1719
  }
1451
1720
  await this.#backend.read(this.#source, 0, 0);
1452
- this.dispose();
1721
+ return this;
1453
1722
  }
1454
1723
  /**
1455
1724
  * Realize the array and return it as data. This is a sync variant and not
1456
1725
  * recommended for performance reasons, as it will block rendering.
1457
1726
  */
1458
1727
  dataSync() {
1459
- if (this.#source instanceof AluExp && prod(this.shape) < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1728
+ if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
1460
1729
  this.#realize();
1461
1730
  for (const p of this.#pending) {
1462
1731
  p.prepareSync();
1463
1732
  p.submit();
1464
1733
  }
1465
- const byteCount = byteWidth(this.#dtype) * prod(this.shape);
1734
+ const byteCount = byteWidth(this.#dtype) * this.size;
1466
1735
  const buf = this.#backend.readSync(this.#source, 0, byteCount);
1467
1736
  this.dispose();
1468
1737
  return dtypedArray(this.dtype, buf);
@@ -1483,6 +1752,16 @@ var Array$1 = class Array$1 extends Tracer {
1483
1752
  async jsAsync() {
1484
1753
  return dataToJs(this.dtype, await this.data(), this.shape);
1485
1754
  }
1755
+ /**
1756
+ * Copy an element of an array to a numeric scalar and return it.
1757
+ *
1758
+ * Throws an error if the array does not have a single element. The array must
1759
+ * either be rank-0, or all dimensions of the shape are 1.
1760
+ */
1761
+ item() {
1762
+ if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
1763
+ return this.dataSync()[0];
1764
+ }
1486
1765
  /** @private Internal plumbing method for Array / Tracer ops. */
1487
1766
  static _implRules() {
1488
1767
  return {
@@ -1496,7 +1775,7 @@ var Array$1 = class Array$1 extends Tracer {
1496
1775
  return [x.#binary(AluOp.Idiv, y)];
1497
1776
  },
1498
1777
  [Primitive.Neg]([x]) {
1499
- return [zerosLike(x).#binary(AluOp.Sub, x)];
1778
+ return [zerosLike(x.ref).#binary(AluOp.Sub, x)];
1500
1779
  },
1501
1780
  [Primitive.Reciprocal]([x]) {
1502
1781
  return [x.#unary(AluOp.Reciprocal)];
@@ -1552,6 +1831,9 @@ var Array$1 = class Array$1 extends Tracer {
1552
1831
  [Primitive.Log]([x]) {
1553
1832
  return [x.#unary(AluOp.Log)];
1554
1833
  },
1834
+ [Primitive.Sqrt]([x]) {
1835
+ return [x.#unary(AluOp.Sqrt)];
1836
+ },
1555
1837
  [Primitive.Min]([x, y]) {
1556
1838
  return [x.#binary(AluOp.Min, y)];
1557
1839
  },
@@ -1562,9 +1844,24 @@ var Array$1 = class Array$1 extends Tracer {
1562
1844
  if (axis.length === 0) return [x];
1563
1845
  return [x.#moveAxesDown(axis).#reduce(op)];
1564
1846
  },
1847
+ [Primitive.Pool]([x], { window, strides }) {
1848
+ const st = pool(x.#st, window, strides);
1849
+ return [x.#reshape(st)];
1850
+ },
1851
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
1852
+ const n = inShape.length;
1853
+ let st = poolTranspose(x.#st, inShape, window, strides);
1854
+ st = st.reshape([...st.shape.slice(0, n), prod(st.shape.slice(n))]);
1855
+ return [x.#reshape(st).#reduce(AluOp.Add)];
1856
+ },
1565
1857
  [Primitive.Dot]([x, y]) {
1566
1858
  return [Array$1.#naryCustom("dot", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
1567
1859
  },
1860
+ [Primitive.Conv]([x, y], params) {
1861
+ checkConvShape(x.shape, y.shape, params);
1862
+ const [stX, stY] = prepareConv(x.#st, y.#st, params);
1863
+ return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
1864
+ },
1568
1865
  [Primitive.Compare]([x, y], { op }) {
1569
1866
  const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
1570
1867
  return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
@@ -1629,6 +1926,7 @@ function scalar(value, { dtype, device } = {}) {
1629
1926
  dtype ??= DType.Float32;
1630
1927
  if (![
1631
1928
  DType.Float32,
1929
+ DType.Float16,
1632
1930
  DType.Int32,
1633
1931
  DType.Uint32
1634
1932
  ].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
@@ -1636,6 +1934,7 @@ function scalar(value, { dtype, device } = {}) {
1636
1934
  dtype ??= DType.Bool;
1637
1935
  if (![
1638
1936
  DType.Float32,
1937
+ DType.Float16,
1639
1938
  DType.Int32,
1640
1939
  DType.Uint32,
1641
1940
  DType.Bool
@@ -1649,7 +1948,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1649
1948
  if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
1650
1949
  if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
1651
1950
  return values;
1652
- } else if (values instanceof Float32Array || values instanceof Int32Array) return arrayFromData(values, shape$1 ?? [values.length], {
1951
+ } else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
1653
1952
  dtype,
1654
1953
  device
1655
1954
  });
@@ -1678,7 +1977,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
1678
1977
  });
1679
1978
  } else {
1680
1979
  dtype = dtype ?? DType.Float32;
1681
- const data = dtypedArray(dtype, flat);
1980
+ const data = dtypedJsArray(dtype, flat);
1682
1981
  return arrayFromData(data, shape$1, {
1683
1982
  dtype,
1684
1983
  device
@@ -1699,19 +1998,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
1699
1998
  });
1700
1999
  }
1701
2000
  const backend = getBackend(device);
1702
- if (data instanceof Float32Array) {
1703
- if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
1704
- const slot = backend.malloc(data.byteLength, data.buffer);
1705
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), DType.Float32, backend);
1706
- } else if (data instanceof Int32Array) {
1707
- if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
1708
- const slot = backend.malloc(data.byteLength, data.buffer);
1709
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype ?? DType.Int32, backend);
1710
- } else if (data instanceof Uint32Array) {
1711
- if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
1712
- const slot = backend.malloc(data.byteLength, data.buffer);
1713
- return new Array$1(slot, ShapeTracker.fromShape(shape$1), DType.Uint32, backend);
1714
- } else throw new Error("Unsupported data type");
2001
+ if (ArrayBuffer.isView(data)) {
2002
+ const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
2003
+ if (data instanceof Float32Array) {
2004
+ if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
2005
+ dtype ??= DType.Float32;
2006
+ } else if (data instanceof Int32Array) {
2007
+ if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
2008
+ dtype ??= DType.Int32;
2009
+ } else if (data instanceof Uint32Array) {
2010
+ if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
2011
+ dtype ??= DType.Uint32;
2012
+ } else if (data instanceof Float16Array) {
2013
+ if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
2014
+ dtype ??= DType.Float16;
2015
+ } else throw new Error("Unsupported data array type: " + data.constructor.name);
2016
+ const slot = backend.malloc(data.byteLength, buf);
2017
+ return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
2018
+ } else throw new Error("Unsupported data type: " + data.constructor.name);
1715
2019
  }
1716
2020
  function dataToJs(dtype, data, shape$1) {
1717
2021
  if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
@@ -1738,9 +2042,20 @@ var EvalTrace = class extends Trace {
1738
2042
  };
1739
2043
  const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
1740
2044
  const implRules = Array$1._implRules();
1741
- function zerosLike(val) {
2045
+ function zerosLike(val, dtype) {
2046
+ const aval = getAval(val);
2047
+ if (val instanceof Tracer) val.dispose();
2048
+ return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
2049
+ }
2050
+ function onesLike(val, dtype) {
2051
+ const aval = getAval(val);
2052
+ if (val instanceof Tracer) val.dispose();
2053
+ return ones(aval.shape, { dtype: dtype ?? aval.dtype });
2054
+ }
2055
+ function fullLike(val, fillValue, dtype) {
1742
2056
  const aval = getAval(val);
1743
- return zeros(aval.shape, { dtype: aval.dtype });
2057
+ if (val instanceof Tracer) val.dispose();
2058
+ return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
1744
2059
  }
1745
2060
  /** Return a new array of given shape and type, filled with zeros. */
1746
2061
  function zeros(shape$1, { dtype, device } = {}) {
@@ -1762,6 +2077,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
1762
2077
  if (typeof fillValue === "number") {
1763
2078
  dtype = dtype ?? DType.Float32;
1764
2079
  source = AluExp.const(dtype, fillValue);
2080
+ } else if (typeof fillValue === "bigint") {
2081
+ dtype = dtype ?? DType.Int32;
2082
+ source = AluExp.const(dtype, Number(fillValue));
1765
2083
  } else if (typeof fillValue === "boolean") {
1766
2084
  dtype = dtype ?? DType.Bool;
1767
2085
  source = AluExp.const(dtype, fillValue ? 1 : 0);
@@ -1859,7 +2177,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
1859
2177
  const st = ShapeTracker.fromShape([num]);
1860
2178
  return new Array$1(exp$2, st, dtype, getBackend(device));
1861
2179
  }
1862
- /** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
1863
2180
  function aluCompare(a, b, op) {
1864
2181
  switch (op) {
1865
2182
  case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
@@ -1901,7 +2218,7 @@ function generalBroadcast(a, b) {
1901
2218
  }
1902
2219
 
1903
2220
  //#endregion
1904
- //#region node_modules/.pnpm/@oxc-project+runtime@0.77.3/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
2221
+ //#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
1905
2222
  function _usingCtx() {
1906
2223
  var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
1907
2224
  var n$1 = Error();
@@ -1969,7 +2286,7 @@ var Var = class Var {
1969
2286
  this.aval = aval;
1970
2287
  }
1971
2288
  toString() {
1972
- return `Var(${this.id}):${this.aval.strShort()}`;
2289
+ return `Var(${this.id}):${this.aval.toString()}`;
1973
2290
  }
1974
2291
  };
1975
2292
  /** Literal in a Jaxpr expression. Currently, only scalars are supported. */
@@ -2009,7 +2326,7 @@ var VarPrinter = class {
2009
2326
  return name;
2010
2327
  }
2011
2328
  nameType(v) {
2012
- return `${this.name(v)}:${v.aval.strShort()}`;
2329
+ return `${this.name(v)}:${v.aval.toString()}`;
2013
2330
  }
2014
2331
  };
2015
2332
  /** A single statement / binding in a Jaxpr, in ANF form. */
@@ -2164,8 +2481,8 @@ var JaxprType = class {
2164
2481
  this.outTypes = outTypes;
2165
2482
  }
2166
2483
  toString() {
2167
- const inTypes = this.inTypes.map((aval) => aval.strShort()).join(", ");
2168
- const outTypes = this.outTypes.map((aval) => aval.strShort()).join(", ");
2484
+ const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
2485
+ const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
2169
2486
  return `(${inTypes}) -> (${outTypes})`;
2170
2487
  }
2171
2488
  };
@@ -2244,7 +2561,7 @@ var JaxprTracer = class extends Tracer {
2244
2561
  this.aval = aval;
2245
2562
  }
2246
2563
  toString() {
2247
- return `JaxprTracer(${this.aval.strShort()})`;
2564
+ return `JaxprTracer(${this.aval.toString()})`;
2248
2565
  }
2249
2566
  get ref() {
2250
2567
  return this;
@@ -2383,6 +2700,7 @@ const abstractEvalRules = {
2383
2700
  [Primitive.Cos]: vectorizedUnopAbstractEval,
2384
2701
  [Primitive.Exp]: vectorizedUnopAbstractEval,
2385
2702
  [Primitive.Log]: vectorizedUnopAbstractEval,
2703
+ [Primitive.Sqrt]: vectorizedUnopAbstractEval,
2386
2704
  [Primitive.Min]: binopAbstractEval,
2387
2705
  [Primitive.Max]: binopAbstractEval,
2388
2706
  [Primitive.Reduce]([x], { axis }) {
@@ -2390,6 +2708,15 @@ const abstractEvalRules = {
2390
2708
  const newShape = x.shape.filter((_, i) => !axisSet.has(i));
2391
2709
  return [new ShapedArray(newShape, x.dtype)];
2392
2710
  },
2711
+ [Primitive.Pool]([x], { window, strides }) {
2712
+ const shape$1 = checkPoolShape(x.shape, window, strides);
2713
+ return [new ShapedArray(shape$1, x.dtype)];
2714
+ },
2715
+ [Primitive.PoolTranspose]([x], { inShape, window, strides }) {
2716
+ const shape$1 = checkPoolShape(inShape, window, strides);
2717
+ if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
2718
+ return [new ShapedArray(inShape, x.dtype)];
2719
+ },
2393
2720
  [Primitive.Dot]([x, y]) {
2394
2721
  if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
2395
2722
  if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
@@ -2397,6 +2724,11 @@ const abstractEvalRules = {
2397
2724
  shape$1.splice(-1, 1);
2398
2725
  return [new ShapedArray(shape$1, x.dtype)];
2399
2726
  },
2727
+ [Primitive.Conv]([lhs, rhs], params) {
2728
+ if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
2729
+ const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
2730
+ return [new ShapedArray(shape$1, lhs.dtype)];
2731
+ },
2400
2732
  [Primitive.Compare]: compareAbstractEval,
2401
2733
  [Primitive.Where]([cond, x, y]) {
2402
2734
  if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
@@ -2444,15 +2776,34 @@ const abstractEvalRules = {
2444
2776
  return outTypes;
2445
2777
  }
2446
2778
  };
2447
- function makeJaxpr$1(f) {
2779
+ function splitIdx(values, argnums) {
2780
+ const a = [];
2781
+ const b = [];
2782
+ for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
2783
+ else b.push(values[i]);
2784
+ return [a, b];
2785
+ }
2786
+ function joinIdx(n, a, b, argnums) {
2787
+ const result = [];
2788
+ let ai = 0;
2789
+ let bi = 0;
2790
+ for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
2791
+ else result.push(b[bi++]);
2792
+ return result;
2793
+ }
2794
+ function makeJaxpr$1(f, opts) {
2448
2795
  return (...argsIn) => {
2449
2796
  try {
2450
2797
  var _usingCtx$1 = _usingCtx();
2451
- const [avalsIn, inTree] = flatten(argsIn);
2452
- const [fFlat, outTree] = flattenFun(f, inTree);
2798
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2799
+ const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
2800
+ const [avalsIn, inTree] = flatten(shapedArgs);
2801
+ const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
2802
+ return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
2803
+ }, inTree);
2453
2804
  const builder = new JaxprBuilder();
2454
2805
  const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
2455
- const _dynamic = _usingCtx$1.u(newDynamic(main));
2806
+ _usingCtx$1.u(newDynamic(main));
2456
2807
  const trace = new JaxprTrace(main);
2457
2808
  const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
2458
2809
  const outs = fFlat(...tracersIn);
@@ -2471,14 +2822,17 @@ function makeJaxpr$1(f) {
2471
2822
  }
2472
2823
  };
2473
2824
  }
2474
- function jit$1(f) {
2825
+ function jit$1(f, opts) {
2475
2826
  const cache = /* @__PURE__ */ new Map();
2827
+ const staticArgnums = new Set(opts?.staticArgnums ?? []);
2476
2828
  return ((...args) => {
2477
- const [argsFlat, inTree] = flatten(args);
2829
+ const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
2830
+ const [argsFlat, inTree] = flatten(dynamicArgs);
2478
2831
  const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
2479
2832
  const avalsIn = unflatten(inTree, avalsInFlat);
2480
- const cacheKey = JSON.stringify(avalsIn);
2481
- const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f)(...avalsIn));
2833
+ const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
2834
+ const cacheKey = JSON.stringify(jaxprArgs);
2835
+ const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
2482
2836
  const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
2483
2837
  jaxpr,
2484
2838
  numConsts: consts.length
@@ -2515,7 +2869,7 @@ var JVPTrace = class extends Trace {
2515
2869
  return this.lift(pureArray(val));
2516
2870
  }
2517
2871
  lift(val) {
2518
- return new JVPTracer(this, val, zerosLike(val));
2872
+ return new JVPTracer(this, val, zerosLike(val.ref));
2519
2873
  }
2520
2874
  processPrimitive(primitive, tracers, params) {
2521
2875
  const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
@@ -2533,19 +2887,25 @@ function linearTangentsJvp(primitive) {
2533
2887
  return [ys, dys];
2534
2888
  };
2535
2889
  }
2890
+ /** Rule for product of gradients in bilinear operations. */
2891
+ function bilinearTangentsJvp(primitive) {
2892
+ return ([x, y], [dx, dy], params) => {
2893
+ const primal = bind1(primitive, [x.ref, y.ref], params);
2894
+ const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
2895
+ return [[primal], [tangent]];
2896
+ };
2897
+ }
2536
2898
  /** Rule that zeros out any tangents. */
2537
2899
  function zeroTangentsJvp(primitive) {
2538
2900
  return (primals, tangents, params) => {
2539
2901
  for (const t of tangents) t.dispose();
2540
2902
  const ys = bind(primitive, primals, params);
2541
- return [ys, ys.map((y) => zerosLike(y))];
2903
+ return [ys, ys.map((y) => zerosLike(y.ref))];
2542
2904
  };
2543
2905
  }
2544
2906
  const jvpRules = {
2545
2907
  [Primitive.Add]: linearTangentsJvp(Primitive.Add),
2546
- [Primitive.Mul]([x, y], [dx, dy]) {
2547
- return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
2548
- },
2908
+ [Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
2549
2909
  [Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
2550
2910
  [Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
2551
2911
  [Primitive.Reciprocal]([x], [dx]) {
@@ -2558,13 +2918,13 @@ const jvpRules = {
2558
2918
  if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
2559
2919
  else {
2560
2920
  dx.dispose();
2561
- return [[cast(x, dtype)], [zerosLike(x)]];
2921
+ return [[cast(x.ref, dtype)], [zerosLike(x)]];
2562
2922
  }
2563
2923
  },
2564
2924
  [Primitive.Bitcast]([x], [dx], { dtype }) {
2565
2925
  if (x.dtype === dtype) return [[x], [dx]];
2566
2926
  dx.dispose();
2567
- return [[bitcast(x, dtype)], [zerosLike(x)]];
2927
+ return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
2568
2928
  },
2569
2929
  [Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
2570
2930
  [Primitive.Sin]([x], [dx]) {
@@ -2580,6 +2940,10 @@ const jvpRules = {
2580
2940
  [Primitive.Log]([x], [dx]) {
2581
2941
  return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
2582
2942
  },
2943
+ [Primitive.Sqrt]([x], [dx]) {
2944
+ const z = sqrt$1(x);
2945
+ return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
2946
+ },
2583
2947
  [Primitive.Min]([x, y], [dx, dy]) {
2584
2948
  return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
2585
2949
  },
@@ -2596,13 +2960,14 @@ const jvpRules = {
2596
2960
  const primal = reduce(x.ref, op, axis);
2597
2961
  const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
2598
2962
  const minCount = where$1(notMin.ref, 0, 1).sum(axis);
2599
- const tangent = where$1(notMin, op === AluOp.Min ? 0 : 0, dx).sum(axis).mul(reciprocal$1(minCount));
2963
+ const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
2600
2964
  return [[primal], [tangent]];
2601
2965
  } else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
2602
2966
  },
2603
- [Primitive.Dot]([x, y], [dx, dy]) {
2604
- return [[dot$1(x.ref, y.ref)], [dot$1(dx, y).add(dot$1(x, dy))]];
2605
- },
2967
+ [Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
2968
+ [Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
2969
+ [Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
2970
+ [Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
2606
2971
  [Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
2607
2972
  [Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
2608
2973
  dcond.dispose();
@@ -2778,6 +3143,7 @@ const vmapRules = {
2778
3143
  [Primitive.Cos]: unopBatcher(cos$1),
2779
3144
  [Primitive.Exp]: unopBatcher(exp$1),
2780
3145
  [Primitive.Log]: unopBatcher(log$1),
3146
+ [Primitive.Sqrt]: unopBatcher(sqrt$1),
2781
3147
  [Primitive.Min]: broadcastBatcher(min$1),
2782
3148
  [Primitive.Max]: broadcastBatcher(max$1),
2783
3149
  [Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
@@ -2914,7 +3280,7 @@ var PartialVal = class PartialVal {
2914
3280
  return this.val !== null;
2915
3281
  }
2916
3282
  toString() {
2917
- return this.val ? this.val.toString() : this.aval.strShort();
3283
+ return this.val ? this.val.toString() : this.aval.toString();
2918
3284
  }
2919
3285
  };
2920
3286
  function partialEvalFlat(f, pvalsIn) {
@@ -3288,12 +3654,72 @@ const transposeRules = {
3288
3654
  if (op === AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
3289
3655
  else throw new NonlinearError(Primitive.Reduce);
3290
3656
  },
3657
+ [Primitive.Pool]([ct], [x], { window, strides }) {
3658
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
3659
+ return bind(Primitive.PoolTranspose, [ct], {
3660
+ inShape: x.aval.shape,
3661
+ window,
3662
+ strides
3663
+ });
3664
+ },
3665
+ [Primitive.PoolTranspose]([ct], [x], { window, strides }) {
3666
+ if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
3667
+ return bind(Primitive.Pool, [ct], {
3668
+ window,
3669
+ strides
3670
+ });
3671
+ },
3291
3672
  [Primitive.Dot]([ct], [x, y]) {
3292
3673
  if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
3293
3674
  const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
3294
3675
  ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
3295
3676
  return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
3296
3677
  },
3678
+ [Primitive.Conv]([ct], [lhs, rhs], params) {
3679
+ if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
3680
+ const rev01 = [
3681
+ 1,
3682
+ 0,
3683
+ ...range(2, ct.ndim)
3684
+ ];
3685
+ if (lhs instanceof UndefPrimal) {
3686
+ let kernel = rhs;
3687
+ kernel = transpose$1(kernel, rev01);
3688
+ kernel = flip$1(kernel, range(2, kernel.ndim));
3689
+ const result = conv(ct, kernel, {
3690
+ strides: params.lhsDilation,
3691
+ padding: params.padding.map(([pl, _pr], i) => {
3692
+ const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3693
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3694
+ const padBefore = dilatedKernel - 1 - pl;
3695
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3696
+ const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
3697
+ return [padBefore, padAfter];
3698
+ }),
3699
+ lhsDilation: params.strides,
3700
+ rhsDilation: params.rhsDilation
3701
+ });
3702
+ return [result, null];
3703
+ } else {
3704
+ const newLhs = transpose$1(lhs, rev01);
3705
+ const newRhs = transpose$1(ct, rev01);
3706
+ let result = conv(newLhs, newRhs, {
3707
+ strides: params.rhsDilation,
3708
+ padding: params.padding.map(([pl, _pr], i) => {
3709
+ const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
3710
+ const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
3711
+ const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
3712
+ const padFromLhs = dilatedCt - dilatedLhs;
3713
+ const padFromRhs = dilatedKernel - pl - 1;
3714
+ return [pl, padFromLhs + padFromRhs];
3715
+ }),
3716
+ lhsDilation: params.lhsDilation,
3717
+ rhsDilation: params.strides
3718
+ });
3719
+ result = transpose$1(result, rev01);
3720
+ return [null, result];
3721
+ }
3722
+ },
3297
3723
  [Primitive.Where]([ct], [cond, x, y]) {
3298
3724
  const cts = [
3299
3725
  null,
@@ -3414,8 +3840,8 @@ function valueAndGrad$1(f) {
3414
3840
  if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
3415
3841
  const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
3416
3842
  if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
3417
- if (y.dtype !== DType.Float32) throw new TypeError("grad currently only supports float32");
3418
- const [ct, ...rest] = fVjp(pureArray(1));
3843
+ if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
3844
+ const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
3419
3845
  for (const r of rest) r.dispose();
3420
3846
  return [y, ct];
3421
3847
  };
@@ -3429,6 +3855,73 @@ function jacrev$1(f) {
3429
3855
  };
3430
3856
  }
3431
3857
 
3858
+ //#endregion
3859
+ //#region src/lax.ts
3860
+ var lax_exports = {};
3861
+ __export(lax_exports, {
3862
+ conv: () => conv$1,
3863
+ convGeneralDilated: () => convGeneralDilated,
3864
+ convWithGeneralPadding: () => convWithGeneralPadding,
3865
+ reduceWindow: () => reduceWindow
3866
+ });
3867
+ function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
3868
+ const padType = padding.toUpperCase();
3869
+ switch (padType) {
3870
+ case "VALID": return rep(inShape.length, [0, 0]);
3871
+ case "SAME":
3872
+ case "SAME_LOWER": {
3873
+ const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
3874
+ const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
3875
+ if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
3876
+ else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
3877
+ }
3878
+ default: throw new Error(`Unknown padding type: ${padType}`);
3879
+ }
3880
+ }
3881
+ /**
3882
+ * General n-dimensional convolution operator, with optional dilation.
3883
+ *
3884
+ * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
3885
+ * function in JAX, which wraps XLA's general convolution operator.
3886
+ *
3887
+ * Grouped convolutions are not supported right now.
3888
+ */
3889
+ function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
3890
+ if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
3891
+ if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
3892
+ if (typeof padding === "string") {
3893
+ if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
3894
+ padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
3895
+ }
3896
+ return conv(lhs, rhs, {
3897
+ strides: windowStrides,
3898
+ padding,
3899
+ lhsDilation,
3900
+ rhsDilation
3901
+ });
3902
+ }
3903
+ /** Convenience wrapper around `convGeneralDilated`. */
3904
+ function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
3905
+ return convGeneralDilated(lhs, rhs, windowStrides, padding, {
3906
+ lhsDilation,
3907
+ rhsDilation
3908
+ });
3909
+ }
3910
+ /** Convenience wrapper around `convGeneralDilated`. */
3911
+ function conv$1(lhs, rhs, windowStrides, padding) {
3912
+ return convGeneralDilated(lhs, rhs, windowStrides, padding);
3913
+ }
3914
+ /** Reduce a computation over padded windows. */
3915
+ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
3916
+ if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
3917
+ if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
3918
+ for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
3919
+ return computation(bind1(Primitive.Pool, [operand], {
3920
+ window: windowDimensions,
3921
+ strides: windowStrides
3922
+ }));
3923
+ }
3924
+
3432
3925
  //#endregion
3433
3926
  //#region src/numpy.ts
3434
3927
  var numpy_exports = {};
@@ -3447,9 +3940,9 @@ __export(numpy_exports, {
3447
3940
  bool: () => bool,
3448
3941
  clip: () => clip,
3449
3942
  columnStack: () => columnStack,
3450
- complex64: () => complex64,
3451
3943
  concatenate: () => concatenate,
3452
3944
  cos: () => cos,
3945
+ cosh: () => cosh,
3453
3946
  diag: () => diag,
3454
3947
  diagonal: () => diagonal,
3455
3948
  divide: () => divide,
@@ -3464,8 +3957,10 @@ __export(numpy_exports, {
3464
3957
  flip: () => flip,
3465
3958
  fliplr: () => fliplr,
3466
3959
  flipud: () => flipud,
3960
+ float16: () => float16,
3467
3961
  float32: () => float32,
3468
3962
  full: () => full,
3963
+ fullLike: () => fullLike$1,
3469
3964
  greater: () => greater,
3470
3965
  greaterEqual: () => greaterEqual,
3471
3966
  hstack: () => hstack,
@@ -3492,6 +3987,7 @@ __export(numpy_exports, {
3492
3987
  negative: () => negative,
3493
3988
  notEqual: () => notEqual,
3494
3989
  ones: () => ones,
3990
+ onesLike: () => onesLike$1,
3495
3991
  pad: () => pad,
3496
3992
  permuteDims: () => permuteDims,
3497
3993
  pi: () => pi,
@@ -3502,11 +3998,14 @@ __export(numpy_exports, {
3502
3998
  scalar: () => scalar,
3503
3999
  shape: () => shape,
3504
4000
  sin: () => sin,
4001
+ sinh: () => sinh,
3505
4002
  size: () => size,
4003
+ sqrt: () => sqrt,
3506
4004
  square: () => square,
3507
4005
  stack: () => stack,
3508
4006
  sum: () => sum,
3509
4007
  tan: () => tan,
4008
+ tanh: () => tanh,
3510
4009
  transpose: () => transpose,
3511
4010
  trueDivide: () => trueDivide,
3512
4011
  trunc: () => trunc,
@@ -3515,13 +4014,14 @@ __export(numpy_exports, {
3515
4014
  vecdot: () => vecdot,
3516
4015
  vstack: () => vstack,
3517
4016
  where: () => where,
3518
- zeros: () => zeros
4017
+ zeros: () => zeros,
4018
+ zerosLike: () => zerosLike$1
3519
4019
  });
3520
4020
  const float32 = DType.Float32;
3521
4021
  const int32 = DType.Int32;
3522
4022
  const uint32 = DType.Uint32;
3523
4023
  const bool = DType.Bool;
3524
- const complex64 = DType.Complex64;
4024
+ const float16 = DType.Float16;
3525
4025
  /** Euler's constant, `e = 2.7182818284590...` */
3526
4026
  const e = Math.E;
3527
4027
  /** Euler-Mascheroni constant, `γ = 0.5772156649...` */
@@ -3548,6 +4048,8 @@ const cos = cos$1;
3548
4048
  const exp = exp$1;
3549
4049
  /** Calculate the natural logarithm of all elements in the input array. */
3550
4050
  const log = log$1;
4051
+ /** Calculate the square root of all elements in the input array. */
4052
+ const sqrt = sqrt$1;
3551
4053
  /** Return element-wise minimum of the input arrays. */
3552
4054
  const minimum = min$1;
3553
4055
  /** Return element-wise maximum of the input arrays. */
@@ -3589,6 +4091,12 @@ const pad = pad$1;
3589
4091
  const ndim = ndim$1;
3590
4092
  /** Return the shape of an array. Does not consume array reference. */
3591
4093
  const shape = getShape;
4094
+ /** Return an array of zeros with the same shape and type as a given array. */
4095
+ const zerosLike$1 = zerosLike;
4096
+ /** Return an array of ones with the same shape and type as a given array. */
4097
+ const onesLike$1 = onesLike;
4098
+ /** Return a full array with the same shape and type as a given array. */
4099
+ const fullLike$1 = fullLike;
3592
4100
  /**
3593
4101
  * Return the number of elements in an array, optionally along an axis.
3594
4102
  * Does not consume array reference.
@@ -3639,13 +4147,7 @@ function argmin(a, axis, opts) {
3639
4147
  dtype: int32,
3640
4148
  device: a.device
3641
4149
  });
3642
- const idx = where(isMax, scalar(1, {
3643
- dtype: int32,
3644
- device: a.device
3645
- }), scalar(0, {
3646
- dtype: int32,
3647
- device: a.device
3648
- })).mul(arange(shape$1[axis], 0, -1, {
4150
+ const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3649
4151
  dtype: int32,
3650
4152
  device: a.device
3651
4153
  }).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
@@ -3669,13 +4171,7 @@ function argmax(a, axis, opts) {
3669
4171
  dtype: int32,
3670
4172
  device: a.device
3671
4173
  });
3672
- const idx = where(isMax, scalar(1, {
3673
- dtype: int32,
3674
- device: a.device
3675
- }), scalar(0, {
3676
- dtype: int32,
3677
- device: a.device
3678
- })).mul(arange(shape$1[axis], 0, -1, {
4174
+ const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
3679
4175
  dtype: int32,
3680
4176
  device: a.device
3681
4177
  }).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
@@ -3807,9 +4303,11 @@ function ravel(a) {
3807
4303
  * Return specified diagonals.
3808
4304
  *
3809
4305
  * If a is 2D, return the diagonal of the array with the given offset. If a is
3810
- * 3D or higher, compute diagonals along the two given axes.
4306
+ * 3D or higher, compute diagonals along the two given axes (default: 0, 1).
3811
4307
  *
3812
- * This returns a view over the existing array.
4308
+ * This returns a view over the existing array. The shape of the resulting array
4309
+ * is determined by removing the two axes along which the diagonal is taken,
4310
+ * then appending a new axis to the right with holding the diagonals.
3813
4311
  */
3814
4312
  function diagonal(a, offset, axis1, axis2) {
3815
4313
  return fudgeArray(a).diagonal(offset, axis1, axis2);
@@ -3825,15 +4323,16 @@ function diag(v, k = 0) {
3825
4323
  if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
3826
4324
  if (a.ndim === 1) {
3827
4325
  const n = a.shape[0];
3828
- const ret = where(eye(n).equal(1), a, 0);
3829
- if (k !== 0) throw new Error("diag() for 1D arrays only for k=0");
3830
- return ret;
4326
+ const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
4327
+ if (k > 0) return pad(ret, [[0, k], [k, 0]]);
4328
+ else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
4329
+ else return ret;
3831
4330
  } else if (a.ndim === 2) return diagonal(a, k);
3832
4331
  else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
3833
4332
  }
3834
4333
  /** Return if two arrays are element-wise equal within a tolerance. */
3835
4334
  function allclose(actual, expected, options) {
3836
- const { rtol = 1e-5, atol = 1e-8 } = options ?? {};
4335
+ const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
3837
4336
  const x = array(actual);
3838
4337
  const y = array(expected);
3839
4338
  if (!deepEqual(x.shape, y.shape)) return false;
@@ -3967,15 +4466,52 @@ function log2(x) {
3967
4466
  function log10(x) {
3968
4467
  return log(x).mul(Math.LOG10E);
3969
4468
  }
4469
/**
 * Element-wise hyperbolic sine.
 *
 * Evaluated as `sinh(x) = (exp(x) - exp(-x)) / 2`, where `exp(-x)` is taken as
 * the reciprocal of `exp(x)` so that only one exponential is computed.
 */
function sinh(x) {
  const expX = exp(x);
  // expX is used twice (reciprocal and subtraction), so take an extra reference.
  const expNegX = reciprocal(expX.ref);
  const difference = expX.sub(expNegX);
  return difference.mul(.5);
}
4479
/**
 * Element-wise hyperbolic cosine.
 *
 * Evaluated as `cosh(x) = (exp(x) + exp(-x)) / 2`, where `exp(-x)` is taken as
 * the reciprocal of `exp(x)` so that only one exponential is computed.
 */
function cosh(x) {
  const expX = exp(x);
  // expX is used twice (reciprocal and addition), so take an extra reference.
  const expNegX = reciprocal(expX.ref);
  const total = expX.add(expNegX);
  return total.mul(.5);
}
4489
/**
 * Element-wise hyperbolic tangent.
 *
 * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
 *
 * The exponent is never positive, so the intermediate exponential cannot
 * overflow. With `s = x < 0 ? 1 : -1` and `e = exp(2 * s * x)`, the result
 * is `s * (e - 1) / (e + 1)`.
 */
function tanh(x) {
  const arr = fudgeArray(x);
  const sign = where(less(arr.ref, 0), 1, -1);
  const e = exp(arr.mul(sign.ref).mul(2));
  // e is used in both numerator and denominator.
  const numerator = e.ref.sub(1);
  const denominator = e.add(1);
  return numerator.div(denominator).mul(sign);
}
3970
4500
 
3971
4501
  //#endregion
3972
4502
  //#region src/nn.ts
3973
4503
  var nn_exports = {};
3974
4504
  __export(nn_exports, {
4505
+ celu: () => celu,
4506
+ elu: () => elu,
4507
+ gelu: () => gelu,
4508
+ glu: () => glu,
3975
4509
  identity: () => identity,
4510
+ leakyRelu: () => leakyRelu,
3976
4511
  logSigmoid: () => logSigmoid,
3977
4512
  logSoftmax: () => logSoftmax,
3978
4513
  logsumexp: () => logsumexp,
4514
+ mish: () => mish,
3979
4515
  oneHot: () => oneHot,
3980
4516
  relu: () => relu,
3981
4517
  relu6: () => relu6,
@@ -4035,10 +4571,7 @@ function softSign(x) {
4035
4571
  *
4036
4572
  * Reference: https://en.wikipedia.org/wiki/Swish_function
4037
4573
  */
4038
- function silu(x) {
4039
- x = fudgeArray(x);
4040
- return x.ref.mul(sigmoid(x));
4041
- }
4574
const silu = jit$1((x) => {
  // x * sigmoid(x): sigmoid() consumes one reference, the product consumes x.
  const gate = sigmoid(x.ref);
  return x.mul(gate);
});
4042
4575
  /**
4043
4576
  * Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
4044
4577
  * Swish, computed element-wise:
@@ -4058,6 +4591,72 @@ function logSigmoid(x) {
4058
4591
  }
4059
4592
  /** Identity activation function. Returns the argument unmodified. */
4060
4593
  const identity = fudgeArray;
4594
/**
 * Leaky rectified linear unit (Leaky ReLU) activation function.
 *
 * Computes the element-wise function:
 * `leakyRelu(x) = x < 0 ? negativeSlope * x : x`
 *
 * @param x - Input array (or value accepted by fudgeArray).
 * @param negativeSlope - Multiplier applied to negative inputs (default 0.01).
 */
function leakyRelu(x, negativeSlope = .01) {
  const arr = fudgeArray(x);
  const isNegative = less(arr.ref, 0);
  const scaled = arr.ref.mul(negativeSlope);
  return where(isNegative, scaled, arr);
}
4599
/**
 * Exponential linear unit activation function.
 *
 * Computes the element-wise function:
 * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
 *
 * The argument to `exp()` is restricted to the negative inputs (zero
 * elsewhere) before the selection. Otherwise large positive inputs would
 * overflow `exp()` to infinity in the unselected branch. The forward value
 * would still be correct, but the gradient would become `0 * inf = NaN`.
 *
 * @param x - Input array (or value accepted by fudgeArray).
 * @param alpha - Scale of the negative branch (default 1).
 */
function elu(x, alpha = 1) {
  x = fudgeArray(x);
  const isNeg = less(x.ref, 0);
  // Negative part of x, 0 elsewhere, so exp() never sees a positive argument.
  const safeX = where(isNeg.ref, x.ref, zerosLike$1(x.ref));
  return where(isNeg, exp(safeX).sub(1).mul(alpha), x);
}
4609
/**
 * Continuously-differentiable exponential linear unit activation function.
 *
 * Computes the element-wise function:
 * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
 *
 * The argument to `exp()` is restricted to the negative inputs (zero
 * elsewhere) before the selection. Otherwise positive inputs large relative
 * to `alpha` would overflow `exp(x/alpha)` to infinity in the unselected
 * branch. The forward value would still be correct, but the gradient would
 * become `0 * inf = NaN`.
 *
 * @param x - Input array (or value accepted by fudgeArray).
 * @param alpha - Scale of the negative branch (default 1).
 */
function celu(x, alpha = 1) {
  x = fudgeArray(x);
  const isNeg = less(x.ref, 0);
  // Negative part of x, 0 elsewhere, so exp() stays finite on the unused branch.
  const safeX = where(isNeg.ref, x.ref, zerosLike$1(x.ref));
  return where(isNeg, exp(safeX.div(alpha)).sub(1).mul(alpha), x);
}
4619
/**
 * Gaussian error linear unit (GELU) activation function.
 *
 * Computed element-wise. jax-js does not currently provide erf() or an exact
 * gelu() primitive, so the tanh approximation is used instead:
 * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
 *
 * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
 *
 * This will be improved in the future.
 */
const gelu = jit$1((x) => {
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
  const halfX = x.ref.mul(.5);
  // 1 + 0.044715 * x^2; multiplying by x below yields x + 0.044715 * x^3.
  const poly = x.ref.mul(x.ref).mul(.044715).add(1);
  const inner = x.mul(poly).mul(SQRT_2_OVER_PI);
  return halfX.mul(tanh(inner).add(1));
});
4634
/**
 * Gated linear unit (GLU) activation function.
 *
 * Splits the `axis` dimension of the input into two equal halves, a and b,
 * then computes `a * sigmoid(b)`. Throws an Error if that dimension has odd
 * length.
 *
 * @param x - Input array (or value accepted by fudgeArray).
 * @param axis - Axis to split (default -1, the last axis).
 */
function glu(x, axis = -1) {
  const arr = fudgeArray(x);
  const ax = checkAxis(axis, arr.ndim);
  const n = arr.shape[ax];
  if (n % 2 !== 0) throw new Error(`glu: axis ${ax} of shape (${arr.shape}) does not have even length`);
  const mid = n / 2;
  // Per-axis [start, end] bounds: full extent everywhere except along `ax`.
  const boundsFor = (lo, hi) => arr.shape.map((d, i) => (i === ax ? [lo, hi] : [0, d]));
  const firstHalf = boundsFor(0, mid);
  const secondHalf = boundsFor(mid, n);
  const a = shrink(arr.ref, firstHalf);
  const b = shrink(arr, secondHalf);
  return a.mul(sigmoid(b));
}
4650
/**
 * Mish activation function.
 *
 * Computes the element-wise function:
 * `mish(x) = x * tanh(softplus(x))`
 */
function mish(x) {
  const arr = fudgeArray(x);
  // softplus() consumes one reference; the final product consumes arr itself.
  const gate = tanh(softplus(arr.ref));
  return arr.mul(gate);
}
4061
4660
  /**
4062
4661
  * Softmax function. Computes the function which rescales elements to the range
4063
4662
  * [0, 1] such that the elements along `axis` sum to 1.
@@ -4134,7 +4733,7 @@ function logsumexp(x, axis) {
4134
4733
  * ```
4135
4734
  */
4136
4735
  function oneHot(x, numClasses) {
4137
- if (x.dtype !== "int32") throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4736
+ if (x.dtype !== DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
4138
4737
  return eye(numClasses, void 0, { device: x.device }).slice(x);
4139
4738
  }
4140
4739
 
@@ -4203,6 +4802,19 @@ const vmap = vmap$1;
4203
4802
  const jacfwd = jacfwd$1;
4204
4803
  /** Construct a Jaxpr by dynamically tracing a function with example inputs. */
4205
4804
  const makeJaxpr = makeJaxpr$1;
4805
+ /**
4806
+ * Mark a function for automatic JIT compilation, with operator fusion.
4807
+ *
4808
+ * The function will be compiled the first time it is called with a set of
4809
+ * argument shapes.
4810
+ *
4811
+ * **Options:**
4812
+ * - `staticArgnums`: An array of argument indices to treat as static
4813
+ * (compile-time constant). These arguments must be hashable, won't be traced,
4814
+ * and different values will trigger recompilation.
4815
+ * - `device`: The device to place the computation on. If not specified, the
4816
+ * computation will be placed on the default device.
4817
+ */
4206
4818
  const jit = jit$1;
4207
4819
  /**
4208
4820
  * Produce a local linear approximation to a function at a point using jvp() and
@@ -4224,4 +4836,4 @@ const jacrev = jacrev$1;
4224
4836
  const jacobian = jacrev;
4225
4837
 
4226
4838
  //#endregion
4227
- export { devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };
4839
+ export { DType, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };