@jax-js/jax 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +15 -9
- package/dist/{backend-BY8wlLEl.js → backend-DaqL-MNz.js} +240 -21
- package/dist/{backend-CmaidnkQ.cjs → backend-DziQSaoQ.cjs} +264 -21
- package/dist/index.cjs +2407 -1132
- package/dist/index.d.cts +596 -97
- package/dist/index.d.ts +596 -97
- package/dist/index.js +2400 -1126
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/webgpu-Db2JrNBr.cjs +1261 -0
- package/dist/webgpu-Dh7k9io0.js +1261 -0
- package/package.json +1 -1
- package/dist/webgpu-BVns4DbI.cjs +0 -663
- package/dist/webgpu-C9iAP5h5.js +0 -663
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DaqL-MNz.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -331,6 +331,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
331
331
|
Primitive$1["Mul"] = "mul";
|
|
332
332
|
Primitive$1["Idiv"] = "idiv";
|
|
333
333
|
Primitive$1["Mod"] = "mod";
|
|
334
|
+
Primitive$1["Min"] = "min";
|
|
335
|
+
Primitive$1["Max"] = "max";
|
|
334
336
|
Primitive$1["Neg"] = "neg";
|
|
335
337
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
336
338
|
Primitive$1["Floor"] = "floor";
|
|
@@ -338,7 +340,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
338
340
|
Primitive$1["StopGradient"] = "stop_gradient";
|
|
339
341
|
Primitive$1["Cast"] = "cast";
|
|
340
342
|
Primitive$1["Bitcast"] = "bitcast";
|
|
341
|
-
Primitive$1["RandomBits"] = "random_bits";
|
|
342
343
|
Primitive$1["Sin"] = "sin";
|
|
343
344
|
Primitive$1["Cos"] = "cos";
|
|
344
345
|
Primitive$1["Asin"] = "asin";
|
|
@@ -348,8 +349,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
348
349
|
Primitive$1["Erf"] = "erf";
|
|
349
350
|
Primitive$1["Erfc"] = "erfc";
|
|
350
351
|
Primitive$1["Sqrt"] = "sqrt";
|
|
351
|
-
Primitive$1["Min"] = "min";
|
|
352
|
-
Primitive$1["Max"] = "max";
|
|
353
352
|
Primitive$1["Reduce"] = "reduce";
|
|
354
353
|
Primitive$1["Dot"] = "dot";
|
|
355
354
|
Primitive$1["Conv"] = "conv";
|
|
@@ -357,14 +356,22 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
357
356
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
358
357
|
Primitive$1["Compare"] = "compare";
|
|
359
358
|
Primitive$1["Where"] = "where";
|
|
359
|
+
Primitive$1["Concatenate"] = "concatenate";
|
|
360
|
+
Primitive$1["Split"] = "split";
|
|
361
|
+
Primitive$1["RandomBits"] = "random_bits";
|
|
362
|
+
Primitive$1["Gather"] = "gather";
|
|
360
363
|
Primitive$1["Transpose"] = "transpose";
|
|
361
364
|
Primitive$1["Broadcast"] = "broadcast";
|
|
362
365
|
Primitive$1["Reshape"] = "reshape";
|
|
363
366
|
Primitive$1["Flip"] = "flip";
|
|
364
367
|
Primitive$1["Shrink"] = "shrink";
|
|
365
368
|
Primitive$1["Pad"] = "pad";
|
|
366
|
-
Primitive$1["
|
|
367
|
-
Primitive$1["
|
|
369
|
+
Primitive$1["Sort"] = "sort";
|
|
370
|
+
Primitive$1["Argsort"] = "argsort";
|
|
371
|
+
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
372
|
+
Primitive$1["Cholesky"] = "cholesky";
|
|
373
|
+
Primitive$1["LU"] = "lu";
|
|
374
|
+
Primitive$1["Jit"] = "jit";
|
|
368
375
|
return Primitive$1;
|
|
369
376
|
}({});
|
|
370
377
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
@@ -386,6 +393,12 @@ function idiv(x, y) {
|
|
|
386
393
|
function mod(x, y) {
|
|
387
394
|
return bind1(Primitive.Mod, [x, y]);
|
|
388
395
|
}
|
|
396
|
+
function min$1(x, y) {
|
|
397
|
+
return bind1(Primitive.Min, [x, y]);
|
|
398
|
+
}
|
|
399
|
+
function max$1(x, y) {
|
|
400
|
+
return bind1(Primitive.Max, [x, y]);
|
|
401
|
+
}
|
|
389
402
|
function neg(x) {
|
|
390
403
|
return bind1(Primitive.Neg, [x]);
|
|
391
404
|
}
|
|
@@ -407,12 +420,6 @@ function cast(x, dtype) {
|
|
|
407
420
|
function bitcast(x, dtype) {
|
|
408
421
|
return bind1(Primitive.Bitcast, [x], { dtype });
|
|
409
422
|
}
|
|
410
|
-
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
411
|
-
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
412
|
-
shape: shape$1,
|
|
413
|
-
mode
|
|
414
|
-
});
|
|
415
|
-
}
|
|
416
423
|
function sin$1(x) {
|
|
417
424
|
return bind1(Primitive.Sin, [x]);
|
|
418
425
|
}
|
|
@@ -440,12 +447,6 @@ function erfc$1(x) {
|
|
|
440
447
|
function sqrt$1(x) {
|
|
441
448
|
return bind1(Primitive.Sqrt, [x]);
|
|
442
449
|
}
|
|
443
|
-
function min$1(x, y) {
|
|
444
|
-
return bind1(Primitive.Min, [x, y]);
|
|
445
|
-
}
|
|
446
|
-
function max$1(x, y) {
|
|
447
|
-
return bind1(Primitive.Max, [x, y]);
|
|
448
|
-
}
|
|
449
450
|
function reduce(x, op, axis = null, opts) {
|
|
450
451
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
451
452
|
axis = normalizeAxis(axis, ndim$1(x));
|
|
@@ -501,6 +502,41 @@ function where$1(cond, x, y) {
|
|
|
501
502
|
y
|
|
502
503
|
]);
|
|
503
504
|
}
|
|
505
|
+
function concatenate$1(xs, axis) {
|
|
506
|
+
if (xs.length === 0) throw new Error("concatenate requires at least one input");
|
|
507
|
+
const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
508
|
+
axis = checkAxis(axis, avals[0].ndim);
|
|
509
|
+
for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
510
|
+
return bind1(Primitive.Concatenate, xs, { axis });
|
|
511
|
+
}
|
|
512
|
+
function split$2(x, axis, sizes) {
|
|
513
|
+
axis = checkAxis(axis, ndim$1(x));
|
|
514
|
+
if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
|
|
515
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
516
|
+
if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
|
|
517
|
+
return bind(Primitive.Split, [x], {
|
|
518
|
+
axis,
|
|
519
|
+
sizes
|
|
520
|
+
});
|
|
521
|
+
}
|
|
522
|
+
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
523
|
+
if (!deepEqual(k0.shape, k1.shape) || k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
|
|
524
|
+
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
525
|
+
shape: shape$1,
|
|
526
|
+
mode
|
|
527
|
+
});
|
|
528
|
+
}
|
|
529
|
+
function gather(x, indices, axis, outDim) {
|
|
530
|
+
if (indices.length === 0) throw new Error("gather() requires at least one index");
|
|
531
|
+
if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
|
|
532
|
+
axis = axis.map((a) => checkAxis(a, ndim$1(x)));
|
|
533
|
+
if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
|
|
534
|
+
outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
|
|
535
|
+
return bind1(Primitive.Gather, [x, ...indices], {
|
|
536
|
+
axis,
|
|
537
|
+
outDim
|
|
538
|
+
});
|
|
539
|
+
}
|
|
504
540
|
function transpose$1(x, perm) {
|
|
505
541
|
perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
|
|
506
542
|
if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
@@ -550,16 +586,39 @@ function pad$1(x, width) {
|
|
|
550
586
|
} else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
|
|
551
587
|
return bind1(Primitive.Pad, [x], { width });
|
|
552
588
|
}
|
|
553
|
-
function
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
}
|
|
589
|
+
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
590
|
+
const as = getShape(a);
|
|
591
|
+
const bs = getShape(b);
|
|
592
|
+
if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
|
|
593
|
+
const n = as[as.length - 2];
|
|
594
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
595
|
+
if (lower) {
|
|
596
|
+
a = flip$1(a, [-2, -1]);
|
|
597
|
+
b = flip$1(b, [-1]);
|
|
598
|
+
}
|
|
599
|
+
let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
600
|
+
if (lower) x = flip$1(x, [-1]);
|
|
601
|
+
return x;
|
|
602
|
+
}
|
|
603
|
+
function cholesky$2(x) {
|
|
604
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
605
|
+
if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
|
|
606
|
+
return bind1(Primitive.Cholesky, [x]);
|
|
607
|
+
}
|
|
608
|
+
function lu$1(x) {
|
|
609
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
610
|
+
if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
|
|
611
|
+
return bind(Primitive.LU, [x]);
|
|
612
|
+
}
|
|
613
|
+
function sort$1(x) {
|
|
614
|
+
const nd = ndim$1(x);
|
|
615
|
+
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
616
|
+
return bind1(Primitive.Sort, [x]);
|
|
617
|
+
}
|
|
618
|
+
function argsort$1(x) {
|
|
619
|
+
const nd = ndim$1(x);
|
|
620
|
+
if (nd === 0) throw new Error("argsort: requires at least 1D input");
|
|
621
|
+
return bind(Primitive.Argsort, [x]);
|
|
563
622
|
}
|
|
564
623
|
function bind1(prim, args, params = {}) {
|
|
565
624
|
const [results] = bind(prim, args, params);
|
|
@@ -659,6 +718,9 @@ var Tracer = class Tracer {
|
|
|
659
718
|
mul(other) {
|
|
660
719
|
return mul(this, other);
|
|
661
720
|
}
|
|
721
|
+
mod(other) {
|
|
722
|
+
return mod(this, other);
|
|
723
|
+
}
|
|
662
724
|
greater(other) {
|
|
663
725
|
return greater$1(this, other);
|
|
664
726
|
}
|
|
@@ -722,7 +784,7 @@ var Tracer = class Tracer {
|
|
|
722
784
|
if (isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
|
|
723
785
|
return idiv(this, other);
|
|
724
786
|
}
|
|
725
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
787
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
726
788
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
727
789
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
728
790
|
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
@@ -771,8 +833,42 @@ var Tracer = class Tracer {
|
|
|
771
833
|
*/
|
|
772
834
|
*[Symbol.iterator]() {
|
|
773
835
|
if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
|
|
774
|
-
|
|
775
|
-
this.
|
|
836
|
+
let residual = this;
|
|
837
|
+
const subarrayShape = this.shape.slice(1);
|
|
838
|
+
for (let i = 0; i < this.shape[0]; i++) {
|
|
839
|
+
const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
|
|
840
|
+
yield lr[0].reshape(subarrayShape);
|
|
841
|
+
residual = lr[1];
|
|
842
|
+
}
|
|
843
|
+
residual.dispose();
|
|
844
|
+
}
|
|
845
|
+
/**
|
|
846
|
+
* Return a sorted copy of an array in ascending order.
|
|
847
|
+
*
|
|
848
|
+
* See `jax.numpy.sort` for full docs.
|
|
849
|
+
*/
|
|
850
|
+
sort(axis = -1) {
|
|
851
|
+
axis = checkAxis(axis, this.ndim);
|
|
852
|
+
if (this.shape[axis] <= 1) return this;
|
|
853
|
+
if (axis === this.ndim - 1) return sort$1(this);
|
|
854
|
+
const perm = range(this.ndim);
|
|
855
|
+
perm.splice(axis, 1);
|
|
856
|
+
perm.push(axis);
|
|
857
|
+
return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
|
|
858
|
+
}
|
|
859
|
+
/**
|
|
860
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
861
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
862
|
+
*
|
|
863
|
+
* See `jax.numpy.argsort` for full docs.
|
|
864
|
+
*/
|
|
865
|
+
argsort(axis = -1) {
|
|
866
|
+
axis = checkAxis(axis, this.ndim);
|
|
867
|
+
if (axis === this.ndim - 1) return argsort$1(this)[1];
|
|
868
|
+
const perm = range(this.ndim);
|
|
869
|
+
perm.splice(axis, 1);
|
|
870
|
+
perm.push(axis);
|
|
871
|
+
return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
|
|
776
872
|
}
|
|
777
873
|
/**
|
|
778
874
|
* Slice an array along one or more axes.
|
|
@@ -891,6 +987,12 @@ var ShapedArray = class ShapedArray {
|
|
|
891
987
|
get ndim() {
|
|
892
988
|
return this.shape.length;
|
|
893
989
|
}
|
|
990
|
+
get size() {
|
|
991
|
+
return prod(this.shape);
|
|
992
|
+
}
|
|
993
|
+
scalar() {
|
|
994
|
+
return new ShapedArray([], this.dtype, this.weakType);
|
|
995
|
+
}
|
|
894
996
|
toString() {
|
|
895
997
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
896
998
|
}
|
|
@@ -1186,13 +1288,13 @@ var Jaxpr = class Jaxpr {
|
|
|
1186
1288
|
}
|
|
1187
1289
|
return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
|
|
1188
1290
|
}
|
|
1189
|
-
/** Flattens nested
|
|
1291
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1190
1292
|
flatten() {
|
|
1191
|
-
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.
|
|
1293
|
+
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
|
|
1192
1294
|
const newEqns = [];
|
|
1193
1295
|
const varMap = /* @__PURE__ */ new Map();
|
|
1194
1296
|
const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
|
|
1195
|
-
for (const eqn of this.eqns) if (eqn.primitive === Primitive.
|
|
1297
|
+
for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
|
|
1196
1298
|
const jaxpr = eqn.params.jaxpr.flatten();
|
|
1197
1299
|
const translation = /* @__PURE__ */ new Map();
|
|
1198
1300
|
const translationF = (x) => x instanceof Var ? translation.get(x) : x;
|
|
@@ -1293,19 +1395,48 @@ function evalJaxpr(jaxpr, args) {
|
|
|
1293
1395
|
function jaxprAsFun(jaxpr) {
|
|
1294
1396
|
return (...args) => evalJaxpr(jaxpr, args);
|
|
1295
1397
|
}
|
|
1398
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1399
|
+
var ClosedJaxpr = class ClosedJaxpr {
|
|
1400
|
+
constructor(jaxpr, consts) {
|
|
1401
|
+
this.jaxpr = jaxpr;
|
|
1402
|
+
this.consts = consts;
|
|
1403
|
+
}
|
|
1404
|
+
/** String representation of this Jaxpr. */
|
|
1405
|
+
toString() {
|
|
1406
|
+
return this.jaxpr.toString();
|
|
1407
|
+
}
|
|
1408
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1409
|
+
mapJaxpr(f) {
|
|
1410
|
+
return new ClosedJaxpr(f(this.jaxpr), this.consts);
|
|
1411
|
+
}
|
|
1412
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1413
|
+
dispose() {
|
|
1414
|
+
for (const c of this.consts) c.dispose();
|
|
1415
|
+
}
|
|
1416
|
+
};
|
|
1296
1417
|
/** Tracer that records its operations to dynamically construct a Jaxpr. */
|
|
1297
1418
|
var JaxprTracer = class extends Tracer {
|
|
1419
|
+
#rc;
|
|
1298
1420
|
constructor(trace$1, aval) {
|
|
1299
1421
|
super(trace$1);
|
|
1300
1422
|
this.aval = aval;
|
|
1423
|
+
this.#rc = 1;
|
|
1301
1424
|
}
|
|
1302
1425
|
toString() {
|
|
1303
1426
|
return `JaxprTracer(${this.aval.toString()})`;
|
|
1304
1427
|
}
|
|
1305
1428
|
get ref() {
|
|
1429
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1430
|
+
this.#rc++;
|
|
1306
1431
|
return this;
|
|
1307
1432
|
}
|
|
1308
|
-
dispose() {
|
|
1433
|
+
dispose() {
|
|
1434
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1435
|
+
this.#rc--;
|
|
1436
|
+
}
|
|
1437
|
+
trackLiftedConstant() {
|
|
1438
|
+
this.#rc++;
|
|
1439
|
+
}
|
|
1309
1440
|
};
|
|
1310
1441
|
/** Analogous to the 'DynamicJaxprTrace' class in JAX. */
|
|
1311
1442
|
var JaxprTrace = class extends Trace {
|
|
@@ -1318,17 +1449,24 @@ var JaxprTrace = class extends Trace {
|
|
|
1318
1449
|
}
|
|
1319
1450
|
/** Register a constant / literal in this Jaxpr. */
|
|
1320
1451
|
getOrMakeConstTracer(val) {
|
|
1452
|
+
if (!(val instanceof Tracer)) val = pureArray(val);
|
|
1321
1453
|
let tracer = this.builder.constTracers.get(val);
|
|
1322
1454
|
if (tracer === void 0) {
|
|
1323
1455
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
1324
|
-
this.builder.addConst(tracer, val
|
|
1456
|
+
this.builder.addConst(tracer, val);
|
|
1457
|
+
} else {
|
|
1458
|
+
val.dispose();
|
|
1459
|
+
tracer.trackLiftedConstant();
|
|
1325
1460
|
}
|
|
1326
1461
|
return tracer;
|
|
1327
1462
|
}
|
|
1328
1463
|
pure = this.getOrMakeConstTracer;
|
|
1329
1464
|
lift = this.getOrMakeConstTracer;
|
|
1330
1465
|
processPrimitive(primitive, tracers, params) {
|
|
1331
|
-
const avalsIn = tracers.map((t) =>
|
|
1466
|
+
const avalsIn = tracers.map((t) => {
|
|
1467
|
+
t.dispose();
|
|
1468
|
+
return t.aval;
|
|
1469
|
+
});
|
|
1332
1470
|
const avalsOut = abstractEvalRules[primitive](avalsIn, params);
|
|
1333
1471
|
const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
|
|
1334
1472
|
this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
|
|
@@ -1371,20 +1509,17 @@ var JaxprBuilder = class {
|
|
|
1371
1509
|
return v;
|
|
1372
1510
|
}
|
|
1373
1511
|
build(inTracers, outTracers) {
|
|
1374
|
-
|
|
1512
|
+
const [constVars, consts] = unzip2(this.constVals.entries());
|
|
1375
1513
|
const t2v = this.getVar.bind(this);
|
|
1376
1514
|
const inBinders = [...constVars, ...inTracers.map(t2v)];
|
|
1377
1515
|
const outVars = outTracers.map(t2v);
|
|
1378
|
-
|
|
1516
|
+
const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
|
|
1379
1517
|
typecheckJaxpr(jaxpr);
|
|
1380
|
-
|
|
1381
|
-
return
|
|
1382
|
-
jaxpr,
|
|
1383
|
-
consts
|
|
1384
|
-
};
|
|
1518
|
+
const cjaxpr = new ClosedJaxpr(jaxpr, consts);
|
|
1519
|
+
return _inlineLiterals(cjaxpr);
|
|
1385
1520
|
}
|
|
1386
1521
|
};
|
|
1387
|
-
function _inlineLiterals(jaxpr, consts) {
|
|
1522
|
+
function _inlineLiterals({ jaxpr, consts }) {
|
|
1388
1523
|
const literals = /* @__PURE__ */ new Map();
|
|
1389
1524
|
const constBinders = [];
|
|
1390
1525
|
const newConsts = [];
|
|
@@ -1399,7 +1534,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
1399
1534
|
const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
|
|
1400
1535
|
const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
|
|
1401
1536
|
typecheckJaxpr(newJaxpr);
|
|
1402
|
-
return
|
|
1537
|
+
return new ClosedJaxpr(newJaxpr, newConsts);
|
|
1403
1538
|
}
|
|
1404
1539
|
function binopAbstractEval([x, y]) {
|
|
1405
1540
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
@@ -1418,6 +1553,8 @@ const abstractEvalRules = {
|
|
|
1418
1553
|
[Primitive.Mul]: binopAbstractEval,
|
|
1419
1554
|
[Primitive.Idiv]: binopAbstractEval,
|
|
1420
1555
|
[Primitive.Mod]: binopAbstractEval,
|
|
1556
|
+
[Primitive.Min]: binopAbstractEval,
|
|
1557
|
+
[Primitive.Max]: binopAbstractEval,
|
|
1421
1558
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1422
1559
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1423
1560
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1431,12 +1568,6 @@ const abstractEvalRules = {
|
|
|
1431
1568
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1432
1569
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1433
1570
|
},
|
|
1434
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1435
|
-
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1436
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1437
|
-
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1438
|
-
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1439
|
-
},
|
|
1440
1571
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
1441
1572
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
1442
1573
|
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
@@ -1446,8 +1577,6 @@ const abstractEvalRules = {
|
|
|
1446
1577
|
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
1447
1578
|
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
1448
1579
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
1449
|
-
[Primitive.Min]: binopAbstractEval,
|
|
1450
|
-
[Primitive.Max]: binopAbstractEval,
|
|
1451
1580
|
[Primitive.Reduce]([x], { axis }) {
|
|
1452
1581
|
const axisSet = new Set(axis);
|
|
1453
1582
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1469,7 +1598,7 @@ const abstractEvalRules = {
|
|
|
1469
1598
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1470
1599
|
},
|
|
1471
1600
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
1472
|
-
const { dtype, weakType } = promoteAvals(
|
|
1601
|
+
const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
|
|
1473
1602
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
1474
1603
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1475
1604
|
},
|
|
@@ -1480,6 +1609,40 @@ const abstractEvalRules = {
|
|
|
1480
1609
|
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
1481
1610
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1482
1611
|
},
|
|
1612
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
1613
|
+
if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
|
|
1614
|
+
for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
1615
|
+
const shape$1 = xs[0].shape.slice();
|
|
1616
|
+
shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
|
|
1617
|
+
const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
|
|
1618
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1619
|
+
},
|
|
1620
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
1621
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
1622
|
+
if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
|
|
1623
|
+
return sizes.map((size$1) => {
|
|
1624
|
+
return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
|
|
1625
|
+
});
|
|
1626
|
+
},
|
|
1627
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1628
|
+
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1629
|
+
if (!deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
|
|
1630
|
+
if (!deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
|
|
1631
|
+
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1632
|
+
},
|
|
1633
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
1634
|
+
for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
1635
|
+
if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
|
|
1636
|
+
if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
|
|
1637
|
+
if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
|
|
1638
|
+
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
1639
|
+
const axisSet = new Set(axis);
|
|
1640
|
+
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
1641
|
+
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
1642
|
+
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
1643
|
+
newShape.splice(outDim, 0, ...gatherShape);
|
|
1644
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1645
|
+
},
|
|
1483
1646
|
[Primitive.Transpose]([x], { perm }) {
|
|
1484
1647
|
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
1485
1648
|
},
|
|
@@ -1500,23 +1663,41 @@ const abstractEvalRules = {
|
|
|
1500
1663
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
1501
1664
|
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1502
1665
|
},
|
|
1503
|
-
[Primitive.
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
if (
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1666
|
+
[Primitive.Sort]([x]) {
|
|
1667
|
+
if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
|
|
1668
|
+
return [ShapedArray.fromAval(x)];
|
|
1669
|
+
},
|
|
1670
|
+
[Primitive.Argsort]([x]) {
|
|
1671
|
+
if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
|
|
1672
|
+
return [ShapedArray.fromAval(x), new ShapedArray(x.shape, DType.Int32, false)];
|
|
1673
|
+
},
|
|
1674
|
+
[Primitive.TriangularSolve]([a, b]) {
|
|
1675
|
+
if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
|
|
1676
|
+
if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
|
|
1677
|
+
const [m, n] = a.shape.slice(-2);
|
|
1678
|
+
const [_batch, q] = b.shape.slice(-2);
|
|
1679
|
+
if (!deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
|
|
1680
|
+
return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
|
|
1681
|
+
},
|
|
1682
|
+
[Primitive.Cholesky]([a]) {
|
|
1683
|
+
if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
|
|
1684
|
+
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1685
|
+
return [ShapedArray.fromAval(a)];
|
|
1686
|
+
},
|
|
1687
|
+
[Primitive.LU]([a]) {
|
|
1688
|
+
if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
|
|
1689
|
+
const batch = a.shape.slice(0, -2);
|
|
1690
|
+
const [m, n] = a.shape.slice(-2);
|
|
1691
|
+
return [
|
|
1692
|
+
ShapedArray.fromAval(a),
|
|
1693
|
+
new ShapedArray([...batch, Math.min(m, n)], DType.Int32, false),
|
|
1694
|
+
new ShapedArray([...batch, m], DType.Int32, false)
|
|
1695
|
+
];
|
|
1515
1696
|
},
|
|
1516
|
-
[Primitive.
|
|
1697
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
1517
1698
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1518
|
-
if (args.length !== inTypes.length) throw new TypeError(`
|
|
1519
|
-
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`
|
|
1699
|
+
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
1700
|
+
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
|
|
1520
1701
|
return outTypes;
|
|
1521
1702
|
}
|
|
1522
1703
|
};
|
|
@@ -1552,11 +1733,10 @@ function makeJaxpr$1(f, opts) {
|
|
|
1552
1733
|
const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
1553
1734
|
const outs = fFlat(...tracersIn);
|
|
1554
1735
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
1555
|
-
const
|
|
1736
|
+
const jaxpr = builder.build(tracersIn, tracersOut);
|
|
1556
1737
|
if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
|
|
1557
1738
|
return {
|
|
1558
|
-
jaxpr: jaxpr.simplify(),
|
|
1559
|
-
consts,
|
|
1739
|
+
jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
|
|
1560
1740
|
treedef: outTree.value
|
|
1561
1741
|
};
|
|
1562
1742
|
} catch (_) {
|
|
@@ -1575,22 +1755,29 @@ function jit$1(f, opts) {
|
|
|
1575
1755
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
1576
1756
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
1577
1757
|
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
1578
|
-
const { jaxpr,
|
|
1579
|
-
const outs = bind(Primitive.
|
|
1758
|
+
const { jaxpr, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
1759
|
+
const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
|
|
1580
1760
|
name: f.name || "closure",
|
|
1581
|
-
jaxpr,
|
|
1582
|
-
numConsts: consts.length
|
|
1761
|
+
jaxpr: jaxpr.jaxpr,
|
|
1762
|
+
numConsts: jaxpr.consts.length
|
|
1583
1763
|
});
|
|
1584
1764
|
return unflatten(outTree, outs);
|
|
1585
1765
|
});
|
|
1586
1766
|
result.dispose = () => {
|
|
1587
|
-
for (const {
|
|
1767
|
+
for (const { jaxpr } of cache.values()) jaxpr.dispose();
|
|
1588
1768
|
};
|
|
1589
1769
|
return result;
|
|
1590
1770
|
}
|
|
1591
1771
|
|
|
1592
1772
|
//#endregion
|
|
1593
1773
|
//#region src/frontend/jit.ts
|
|
1774
|
+
const routinePrimitives = new Map([
|
|
1775
|
+
[Primitive.Sort, Routines.Sort],
|
|
1776
|
+
[Primitive.Argsort, Routines.Argsort],
|
|
1777
|
+
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1778
|
+
[Primitive.Cholesky, Routines.Cholesky],
|
|
1779
|
+
[Primitive.LU, Routines.LU]
|
|
1780
|
+
]);
|
|
1594
1781
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1595
1782
|
var JitProgram = class {
|
|
1596
1783
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1605,9 +1792,14 @@ var JitProgram = class {
|
|
|
1605
1792
|
case "execute": {
|
|
1606
1793
|
const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
|
|
1607
1794
|
const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
|
|
1608
|
-
|
|
1795
|
+
const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
|
|
1796
|
+
if (step.source instanceof Kernel) return PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
|
|
1797
|
+
else if (step.source instanceof Routine) return PPrint.pp(`${executeText}, routine ${step.source.name}`);
|
|
1798
|
+
else {
|
|
1799
|
+
step.source;
|
|
1800
|
+
return PPrint.pp(executeText);
|
|
1801
|
+
}
|
|
1609
1802
|
}
|
|
1610
|
-
case "const": return PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
|
|
1611
1803
|
case "malloc": return PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
|
|
1612
1804
|
case "incref": return PPrint.pp(`incref ${step.input}`);
|
|
1613
1805
|
case "free": return PPrint.pp(`free ${step.input}`);
|
|
@@ -1630,12 +1822,9 @@ var JitProgram = class {
|
|
|
1630
1822
|
const inputs$1 = step.inputs.map((id) => scope.get(id));
|
|
1631
1823
|
const outputs = step.outputs.map((id) => scope.get(id));
|
|
1632
1824
|
if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
|
|
1633
|
-
pending.push(new PendingExecute(this.backend, step.
|
|
1825
|
+
pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
|
|
1634
1826
|
break;
|
|
1635
1827
|
}
|
|
1636
|
-
case "const":
|
|
1637
|
-
scope.set(step.output, step.slot);
|
|
1638
|
-
break;
|
|
1639
1828
|
case "malloc": {
|
|
1640
1829
|
const slot = this.backend.malloc(step.size);
|
|
1641
1830
|
scope.set(step.output, slot);
|
|
@@ -1669,34 +1858,37 @@ var JitProgramBuilder = class {
|
|
|
1669
1858
|
this.#nextId = nargs;
|
|
1670
1859
|
this.steps = [];
|
|
1671
1860
|
}
|
|
1672
|
-
pushConst(slot) {
|
|
1673
|
-
const id = this.#nextId++;
|
|
1674
|
-
this.steps.push({
|
|
1675
|
-
type: "const",
|
|
1676
|
-
slot,
|
|
1677
|
-
output: id
|
|
1678
|
-
});
|
|
1679
|
-
return id;
|
|
1680
|
-
}
|
|
1681
1861
|
pushLit(lit) {
|
|
1682
|
-
const kernel = new Kernel(0,
|
|
1862
|
+
const kernel = new Kernel(0, lit.aval.size, AluExp.const(lit.dtype, lit.value));
|
|
1683
1863
|
return this.pushKernel(kernel, []);
|
|
1684
1864
|
}
|
|
1685
|
-
|
|
1865
|
+
pushBuffer(size$1) {
|
|
1686
1866
|
const id = this.#nextId++;
|
|
1687
1867
|
this.steps.push({
|
|
1688
1868
|
type: "malloc",
|
|
1689
|
-
size:
|
|
1869
|
+
size: size$1,
|
|
1690
1870
|
output: id
|
|
1691
1871
|
});
|
|
1872
|
+
return id;
|
|
1873
|
+
}
|
|
1874
|
+
pushKernel(kernel, inputs) {
|
|
1875
|
+
const id = this.pushBuffer(kernel.bytes);
|
|
1692
1876
|
this.steps.push({
|
|
1693
1877
|
type: "execute",
|
|
1694
|
-
kernel,
|
|
1878
|
+
source: kernel,
|
|
1695
1879
|
inputs,
|
|
1696
1880
|
outputs: [id]
|
|
1697
1881
|
});
|
|
1698
1882
|
return id;
|
|
1699
1883
|
}
|
|
1884
|
+
pushRoutine(routine, inputs, outputs) {
|
|
1885
|
+
this.steps.push({
|
|
1886
|
+
type: "execute",
|
|
1887
|
+
source: routine,
|
|
1888
|
+
inputs,
|
|
1889
|
+
outputs
|
|
1890
|
+
});
|
|
1891
|
+
}
|
|
1700
1892
|
pushIncref(id) {
|
|
1701
1893
|
this.steps.push({
|
|
1702
1894
|
type: "incref",
|
|
@@ -1722,28 +1914,18 @@ var JitProgramBuilder = class {
|
|
|
1722
1914
|
}
|
|
1723
1915
|
};
|
|
1724
1916
|
const jitCompileCache = /* @__PURE__ */ new Map();
|
|
1725
|
-
function jitCompile(backend, jaxpr
|
|
1726
|
-
|
|
1727
|
-
for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
|
|
1728
|
-
const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
|
|
1917
|
+
function jitCompile(backend, jaxpr) {
|
|
1918
|
+
const cacheKey = backend.type + "," + FpHash.hash(jaxpr);
|
|
1729
1919
|
const cached = jitCompileCache.get(cacheKey);
|
|
1730
1920
|
if (cached) return cached;
|
|
1731
1921
|
if (DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
|
|
1732
1922
|
jaxpr = jaxpr.flatten().simplify();
|
|
1733
|
-
const nargs = jaxpr.inBinders.length
|
|
1923
|
+
const nargs = jaxpr.inBinders.length;
|
|
1734
1924
|
const builder = new JitProgramBuilder(backend, nargs);
|
|
1735
1925
|
const blackNodes = splitGraphDataflow(backend, jaxpr);
|
|
1736
1926
|
const ctx = /* @__PURE__ */ new Map();
|
|
1737
|
-
for (let i = 0; i < consts.length; i++) {
|
|
1738
|
-
const v = jaxpr.inBinders[i];
|
|
1739
|
-
const slot = consts[i]._realizeSource();
|
|
1740
|
-
ctx.set(v, {
|
|
1741
|
-
type: "imm",
|
|
1742
|
-
arg: builder.pushConst(slot)
|
|
1743
|
-
});
|
|
1744
|
-
}
|
|
1745
1927
|
for (let i = 0; i < nargs; i++) {
|
|
1746
|
-
const v = jaxpr.inBinders[
|
|
1928
|
+
const v = jaxpr.inBinders[i];
|
|
1747
1929
|
ctx.set(v, {
|
|
1748
1930
|
type: "imm",
|
|
1749
1931
|
arg: i
|
|
@@ -1751,6 +1933,31 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1751
1933
|
}
|
|
1752
1934
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1753
1935
|
const eqn = jaxpr.eqns[i];
|
|
1936
|
+
if (routinePrimitives.has(eqn.primitive)) {
|
|
1937
|
+
const routine = new Routine(routinePrimitives.get(eqn.primitive), {
|
|
1938
|
+
inputShapes: eqn.inputs.map((x) => x.aval.shape),
|
|
1939
|
+
inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
|
|
1940
|
+
outputShapes: eqn.outBinders.map((x) => x.aval.shape),
|
|
1941
|
+
outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
|
|
1942
|
+
}, eqn.params);
|
|
1943
|
+
const inputs = [];
|
|
1944
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1945
|
+
const jv = ctx.get(input);
|
|
1946
|
+
if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
|
|
1947
|
+
inputs.push(jv.arg);
|
|
1948
|
+
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1949
|
+
const outputs = [];
|
|
1950
|
+
for (const outVar of eqn.outBinders) {
|
|
1951
|
+
const outId = builder.pushBuffer(outVar.aval.size * byteWidth(outVar.aval.dtype));
|
|
1952
|
+
outputs.push(outId);
|
|
1953
|
+
ctx.set(outVar, {
|
|
1954
|
+
type: "imm",
|
|
1955
|
+
arg: outId
|
|
1956
|
+
});
|
|
1957
|
+
}
|
|
1958
|
+
builder.pushRoutine(routine, inputs, outputs);
|
|
1959
|
+
continue;
|
|
1960
|
+
}
|
|
1754
1961
|
const inputExps = [];
|
|
1755
1962
|
const inputAvals = [];
|
|
1756
1963
|
const inputArgs = [];
|
|
@@ -1794,35 +2001,37 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1794
2001
|
let reduction;
|
|
1795
2002
|
if (inputReduction) {
|
|
1796
2003
|
const jv = inputReduction;
|
|
1797
|
-
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1798
|
-
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
2004
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
|
|
2005
|
+
exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
|
|
1799
2006
|
reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1800
2007
|
} else {
|
|
1801
2008
|
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1802
2009
|
exp$2 = ruleOutput.exp;
|
|
1803
2010
|
reduction = ruleOutput.reduction;
|
|
1804
2011
|
}
|
|
1805
|
-
|
|
1806
|
-
|
|
1807
|
-
|
|
1808
|
-
|
|
1809
|
-
|
|
1810
|
-
|
|
1811
|
-
|
|
1812
|
-
|
|
1813
|
-
|
|
2012
|
+
for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
|
|
2013
|
+
const outVar = eqn.outBinders[i$1];
|
|
2014
|
+
if (blackNodes.has(outVar)) {
|
|
2015
|
+
const nargs$1 = inputArgs.length;
|
|
2016
|
+
const size$1 = outVar.aval.size;
|
|
2017
|
+
const kernel = new Kernel(nargs$1, size$1, exp$2[i$1], reduction);
|
|
2018
|
+
const outId = builder.pushKernel(kernel, inputArgs);
|
|
2019
|
+
ctx.set(outVar, {
|
|
2020
|
+
type: "imm",
|
|
2021
|
+
arg: outId
|
|
2022
|
+
});
|
|
2023
|
+
} else if (reduction) ctx.set(outVar, {
|
|
2024
|
+
type: "red",
|
|
2025
|
+
exp: exp$2[i$1],
|
|
2026
|
+
reduction,
|
|
2027
|
+
args: inputArgs
|
|
1814
2028
|
});
|
|
1815
|
-
|
|
1816
|
-
|
|
1817
|
-
|
|
1818
|
-
|
|
1819
|
-
|
|
1820
|
-
}
|
|
1821
|
-
else ctx.set(outVar, {
|
|
1822
|
-
type: "exp",
|
|
1823
|
-
exp: exp$2,
|
|
1824
|
-
args: inputArgs
|
|
1825
|
-
});
|
|
2029
|
+
else ctx.set(outVar, {
|
|
2030
|
+
type: "exp",
|
|
2031
|
+
exp: exp$2[i$1],
|
|
2032
|
+
args: inputArgs
|
|
2033
|
+
});
|
|
2034
|
+
}
|
|
1826
2035
|
}
|
|
1827
2036
|
const outputIds = [];
|
|
1828
2037
|
for (const out of jaxpr.outs) if (out instanceof Var) {
|
|
@@ -1830,7 +2039,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1830
2039
|
if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
|
|
1831
2040
|
outputIds.push(jitValue.arg);
|
|
1832
2041
|
} else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
|
|
1833
|
-
const outputNeedsRef = new Set(
|
|
2042
|
+
const outputNeedsRef = new Set(range(nargs));
|
|
1834
2043
|
for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
|
|
1835
2044
|
else outputNeedsRef.add(outputId);
|
|
1836
2045
|
builder.insertFreeSteps(outputIds);
|
|
@@ -1863,17 +2072,22 @@ function broadcastedJit(fn, opts) {
|
|
|
1863
2072
|
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
|
|
1864
2073
|
return exp$2;
|
|
1865
2074
|
});
|
|
1866
|
-
return { exp: fn(exps, params) };
|
|
2075
|
+
return { exp: [fn(exps, params)] };
|
|
1867
2076
|
};
|
|
1868
2077
|
}
|
|
1869
2078
|
function unopJit(fn) {
|
|
1870
2079
|
return ([a], [_as], params) => {
|
|
1871
|
-
return { exp: fn(a, params) };
|
|
2080
|
+
return { exp: [fn(a, params)] };
|
|
1872
2081
|
};
|
|
1873
2082
|
}
|
|
1874
2083
|
function reshapeJit(fn) {
|
|
1875
2084
|
return ([a], [_as], params) => {
|
|
1876
|
-
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2085
|
+
return { exp: [reshapeViews(a, (st) => fn(st, params))] };
|
|
2086
|
+
};
|
|
2087
|
+
}
|
|
2088
|
+
function routineNoJit() {
|
|
2089
|
+
return () => {
|
|
2090
|
+
throw new Error("jit: rule is not implemented for routines");
|
|
1877
2091
|
};
|
|
1878
2092
|
}
|
|
1879
2093
|
const jitRules = {
|
|
@@ -1881,6 +2095,8 @@ const jitRules = {
|
|
|
1881
2095
|
[Primitive.Mul]: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
|
|
1882
2096
|
[Primitive.Idiv]: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
|
|
1883
2097
|
[Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
|
|
2098
|
+
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
2099
|
+
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1884
2100
|
[Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
|
|
1885
2101
|
[Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
|
|
1886
2102
|
[Primitive.Floor]: unopJit(AluExp.floor),
|
|
@@ -1888,17 +2104,6 @@ const jitRules = {
|
|
|
1888
2104
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1889
2105
|
[Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
|
|
1890
2106
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
|
|
1891
|
-
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1892
|
-
const mapping = (st) => {
|
|
1893
|
-
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
|
|
1894
|
-
};
|
|
1895
|
-
const k0 = reshapeViews(keys[0], mapping);
|
|
1896
|
-
const k1 = reshapeViews(keys[1], mapping);
|
|
1897
|
-
const c0 = AluExp.u32(0);
|
|
1898
|
-
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1899
|
-
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1900
|
-
return { exp: exp$2 };
|
|
1901
|
-
},
|
|
1902
2107
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
1903
2108
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
1904
2109
|
[Primitive.Asin]: unopJit(AluExp.asin),
|
|
@@ -1908,8 +2113,6 @@ const jitRules = {
|
|
|
1908
2113
|
[Primitive.Erf]: unopJit(AluExp.erf),
|
|
1909
2114
|
[Primitive.Erfc]: unopJit(AluExp.erfc),
|
|
1910
2115
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1911
|
-
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
1912
|
-
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1913
2116
|
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1914
2117
|
const keptAxes = [];
|
|
1915
2118
|
const shiftedAxes = [];
|
|
@@ -1925,7 +2128,7 @@ const jitRules = {
|
|
|
1925
2128
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
1926
2129
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
1927
2130
|
return {
|
|
1928
|
-
exp: a,
|
|
2131
|
+
exp: [a],
|
|
1929
2132
|
reduction
|
|
1930
2133
|
};
|
|
1931
2134
|
},
|
|
@@ -1936,13 +2139,13 @@ const jitRules = {
|
|
|
1936
2139
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1937
2140
|
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1938
2141
|
return {
|
|
1939
|
-
exp: a,
|
|
2142
|
+
exp: [a],
|
|
1940
2143
|
reduction
|
|
1941
2144
|
};
|
|
1942
2145
|
},
|
|
1943
2146
|
[Primitive.Dot]([a, b], [as, bs]) {
|
|
1944
2147
|
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
1945
|
-
const c = k1.exp;
|
|
2148
|
+
const [c] = k1.exp;
|
|
1946
2149
|
const cs = promoteAvals(as, bs);
|
|
1947
2150
|
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
1948
2151
|
op: AluOp.Add,
|
|
@@ -1959,16 +2162,42 @@ const jitRules = {
|
|
|
1959
2162
|
},
|
|
1960
2163
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1961
2164
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1962
|
-
[Primitive.
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
const
|
|
1967
|
-
|
|
1968
|
-
|
|
1969
|
-
|
|
1970
|
-
|
|
1971
|
-
|
|
2165
|
+
[Primitive.Concatenate](exps, avals, { axis }) {
|
|
2166
|
+
const ndim$2 = avals[0].ndim;
|
|
2167
|
+
const sizes = avals.map((x) => x.shape[axis]);
|
|
2168
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2169
|
+
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2170
|
+
let cum = 0;
|
|
2171
|
+
const src = [];
|
|
2172
|
+
for (let i = 0; i < exps.length; i++) {
|
|
2173
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2174
|
+
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2175
|
+
cum += sizes[i];
|
|
2176
|
+
}
|
|
2177
|
+
return { exp: [src.reduce(AluExp.add)] };
|
|
2178
|
+
},
|
|
2179
|
+
[Primitive.Split]([a], [as], { axis, sizes }) {
|
|
2180
|
+
const exp$2 = [];
|
|
2181
|
+
let start = 0;
|
|
2182
|
+
for (const size$1 of sizes) {
|
|
2183
|
+
const slice = range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
|
|
2184
|
+
exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
|
|
2185
|
+
start += size$1;
|
|
2186
|
+
}
|
|
2187
|
+
return { exp: exp$2 };
|
|
2188
|
+
},
|
|
2189
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2190
|
+
const keyShape = keyShapes[0].shape;
|
|
2191
|
+
const mapping = (st) => {
|
|
2192
|
+
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(st.shape.length, shape$1.length));
|
|
2193
|
+
};
|
|
2194
|
+
const k0 = reshapeViews(keys[0], mapping);
|
|
2195
|
+
const k1 = reshapeViews(keys[1], mapping);
|
|
2196
|
+
const c0 = AluExp.u32(0);
|
|
2197
|
+
const c1 = AluExp.mod(AluExp.cast(DType.Uint32, AluVar.gidx), AluExp.u32(Math.max(prod(shape$1.slice(keyShape.length)), 1)));
|
|
2198
|
+
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2199
|
+
return { exp: [exp$2] };
|
|
2200
|
+
},
|
|
1972
2201
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1973
2202
|
const axisSet = new Set(axis);
|
|
1974
2203
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
@@ -1982,10 +2211,25 @@ const jitRules = {
|
|
|
1982
2211
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
1983
2212
|
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1984
2213
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1985
|
-
return { exp: x.substitute({ gidx: index }) };
|
|
2214
|
+
return { exp: [x.substitute({ gidx: index })] };
|
|
1986
2215
|
},
|
|
1987
|
-
[Primitive.
|
|
1988
|
-
|
|
2216
|
+
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2217
|
+
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
2218
|
+
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
2219
|
+
[Primitive.Flip]: reshapeJit((st, { axis }) => {
|
|
2220
|
+
const arg = rep(st.shape.length, false);
|
|
2221
|
+
for (const ax of axis) arg[ax] = true;
|
|
2222
|
+
return st.flip(arg);
|
|
2223
|
+
}),
|
|
2224
|
+
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
2225
|
+
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
2226
|
+
[Primitive.Sort]: routineNoJit(),
|
|
2227
|
+
[Primitive.Argsort]: routineNoJit(),
|
|
2228
|
+
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2229
|
+
[Primitive.Cholesky]: routineNoJit(),
|
|
2230
|
+
[Primitive.LU]: routineNoJit(),
|
|
2231
|
+
[Primitive.Jit]() {
|
|
2232
|
+
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
1989
2233
|
}
|
|
1990
2234
|
};
|
|
1991
2235
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
@@ -2043,8 +2287,8 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2043
2287
|
case Primitive.Mul:
|
|
2044
2288
|
case Primitive.Idiv:
|
|
2045
2289
|
case Primitive.Mod:
|
|
2046
|
-
case Primitive.
|
|
2047
|
-
case Primitive.
|
|
2290
|
+
case Primitive.Min:
|
|
2291
|
+
case Primitive.Max: {
|
|
2048
2292
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2049
2293
|
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2050
2294
|
head = usages[0];
|
|
@@ -2064,11 +2308,11 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2064
2308
|
blackNodes.add(v);
|
|
2065
2309
|
p1NextBlack.set(v, v);
|
|
2066
2310
|
}
|
|
2067
|
-
const heterogeneousViewPrimitives = [Primitive.
|
|
2311
|
+
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2068
2312
|
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2069
2313
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2070
2314
|
const eqn = jaxpr.eqns[i];
|
|
2071
|
-
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2315
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2072
2316
|
for (const v of eqn.outBinders) {
|
|
2073
2317
|
blackNodes.add(v);
|
|
2074
2318
|
p1NextBlack.set(v, v);
|
|
@@ -2078,7 +2322,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2078
2322
|
const reach = /* @__PURE__ */ new Set();
|
|
2079
2323
|
let needsCleanOutput = false;
|
|
2080
2324
|
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2081
|
-
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
|
|
2325
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
|
|
2082
2326
|
needsCleanOutput = true;
|
|
2083
2327
|
break outer;
|
|
2084
2328
|
}
|
|
@@ -2102,7 +2346,6 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2102
2346
|
while (p2idx < jaxpr.eqns.length) {
|
|
2103
2347
|
const eqn = jaxpr.eqns[p2idx++];
|
|
2104
2348
|
const deps = [];
|
|
2105
|
-
if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
|
|
2106
2349
|
for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
|
|
2107
2350
|
else deps.push(p2Deps.get(input));
|
|
2108
2351
|
else deps.push(/* @__PURE__ */ new Set());
|
|
@@ -2125,7 +2368,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2125
2368
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2126
2369
|
const assocVar = eqn.inputs[assocInput];
|
|
2127
2370
|
p2idx = varToDefn.get(assocVar);
|
|
2128
|
-
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2371
|
+
for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
|
|
2129
2372
|
} else {
|
|
2130
2373
|
const s = new Set(depCounter.keys());
|
|
2131
2374
|
for (const out of eqn.outBinders) p2Deps.set(out, s);
|
|
@@ -2151,9 +2394,9 @@ var PendingExecute = class {
|
|
|
2151
2394
|
submitted = false;
|
|
2152
2395
|
#promise = null;
|
|
2153
2396
|
#rc = 1;
|
|
2154
|
-
constructor(backend,
|
|
2397
|
+
constructor(backend, source, inputs, outputs) {
|
|
2155
2398
|
this.backend = backend;
|
|
2156
|
-
this.
|
|
2399
|
+
this.source = source;
|
|
2157
2400
|
this.inputs = inputs;
|
|
2158
2401
|
this.outputs = outputs;
|
|
2159
2402
|
for (const slot of inputs) this.backend.incRef(slot);
|
|
@@ -2174,13 +2417,15 @@ var PendingExecute = class {
|
|
|
2174
2417
|
return;
|
|
2175
2418
|
}
|
|
2176
2419
|
this.#promise = (async () => {
|
|
2177
|
-
this.prepared = await this.backend.
|
|
2420
|
+
if (this.source instanceof Kernel) this.prepared = await this.backend.prepareKernel(this.source);
|
|
2421
|
+
else this.prepared = await this.backend.prepareRoutine(this.source);
|
|
2178
2422
|
})();
|
|
2179
2423
|
await this.#promise;
|
|
2180
2424
|
}
|
|
2181
2425
|
prepareSync() {
|
|
2182
2426
|
if (this.prepared) return;
|
|
2183
|
-
this.prepared = this.backend.
|
|
2427
|
+
if (this.source instanceof Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
|
|
2428
|
+
else this.prepared = this.backend.prepareRoutineSync(this.source);
|
|
2184
2429
|
}
|
|
2185
2430
|
submit() {
|
|
2186
2431
|
if (this.submitted) return;
|
|
@@ -2203,8 +2448,6 @@ var PendingExecute = class {
|
|
|
2203
2448
|
* "Array" type by name.
|
|
2204
2449
|
*/
|
|
2205
2450
|
var Array$1 = class Array$1 extends Tracer {
|
|
2206
|
-
static #nextId = 1001;
|
|
2207
|
-
id;
|
|
2208
2451
|
#dtype;
|
|
2209
2452
|
#weakType;
|
|
2210
2453
|
#source;
|
|
@@ -2221,7 +2464,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2221
2464
|
*/
|
|
2222
2465
|
constructor(args) {
|
|
2223
2466
|
super(baseArrayTrace);
|
|
2224
|
-
this.id = Array$1.#nextId++;
|
|
2225
2467
|
this.#dtype = args.dtype;
|
|
2226
2468
|
this.#weakType = args.weakType;
|
|
2227
2469
|
this.#source = args.source;
|
|
@@ -2264,6 +2506,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2264
2506
|
this.#rc++;
|
|
2265
2507
|
return this;
|
|
2266
2508
|
}
|
|
2509
|
+
/** Get the current reference count (for debugging memory management). */
|
|
2510
|
+
get refCount() {
|
|
2511
|
+
return this.#rc;
|
|
2512
|
+
}
|
|
2267
2513
|
dispose() {
|
|
2268
2514
|
this.#check();
|
|
2269
2515
|
if (--this.#rc === 0) {
|
|
@@ -2421,7 +2667,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2421
2667
|
} else if (castDtype === void 0) {
|
|
2422
2668
|
castDtype = arrays[i].#dtype;
|
|
2423
2669
|
castWeakType = arrays[i].#weakType;
|
|
2424
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType),
|
|
2670
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
|
|
2425
2671
|
const weakType = castWeakType && !strongTypeOutput;
|
|
2426
2672
|
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
2427
2673
|
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
@@ -2530,6 +2776,27 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2530
2776
|
pending
|
|
2531
2777
|
});
|
|
2532
2778
|
}
|
|
2779
|
+
/** Apply an operation with custom lowering to this array. */
|
|
2780
|
+
static #routine(routine, arrays, outputWeakType) {
|
|
2781
|
+
const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
|
|
2782
|
+
for (const ar of arrays) ar.#realize();
|
|
2783
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2784
|
+
const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
|
|
2785
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2786
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2787
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2788
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2789
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2790
|
+
return outputs.map((output, i) => new Array$1({
|
|
2791
|
+
source: output,
|
|
2792
|
+
st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
|
|
2793
|
+
dtype: routine.type.outputDtypes[i],
|
|
2794
|
+
weakType: outputWeakType[i],
|
|
2795
|
+
backend,
|
|
2796
|
+
committed,
|
|
2797
|
+
pending
|
|
2798
|
+
}));
|
|
2799
|
+
}
|
|
2533
2800
|
/**
|
|
2534
2801
|
* Normalizes this array into one backed by a `Slot`.
|
|
2535
2802
|
*
|
|
@@ -2690,6 +2957,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2690
2957
|
[Primitive.Mod]([x, y]) {
|
|
2691
2958
|
return [x.#binary(AluOp.Mod, y)];
|
|
2692
2959
|
},
|
|
2960
|
+
[Primitive.Min]([x, y]) {
|
|
2961
|
+
return [x.#binary(AluOp.Min, y)];
|
|
2962
|
+
},
|
|
2963
|
+
[Primitive.Max]([x, y]) {
|
|
2964
|
+
return [x.#binary(AluOp.Max, y)];
|
|
2965
|
+
},
|
|
2693
2966
|
[Primitive.Neg]([x]) {
|
|
2694
2967
|
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
2695
2968
|
},
|
|
@@ -2726,25 +2999,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2726
2999
|
return [y];
|
|
2727
3000
|
}
|
|
2728
3001
|
},
|
|
2729
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2730
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2731
|
-
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2732
|
-
const c0 = zeros(shape$1, {
|
|
2733
|
-
dtype: DType.Uint32,
|
|
2734
|
-
device: k0.device
|
|
2735
|
-
});
|
|
2736
|
-
const c1 = arange(0, prod(shape$1), 1, {
|
|
2737
|
-
dtype: DType.Uint32,
|
|
2738
|
-
device: k0.device
|
|
2739
|
-
}).reshape(shape$1);
|
|
2740
|
-
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2741
|
-
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2742
|
-
k0,
|
|
2743
|
-
k1,
|
|
2744
|
-
c0,
|
|
2745
|
-
c1
|
|
2746
|
-
])];
|
|
2747
|
-
},
|
|
2748
3002
|
[Primitive.Sin]([x]) {
|
|
2749
3003
|
return [x.#unary(AluOp.Sin)];
|
|
2750
3004
|
},
|
|
@@ -2772,12 +3026,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2772
3026
|
[Primitive.Sqrt]([x]) {
|
|
2773
3027
|
return [x.#unary(AluOp.Sqrt)];
|
|
2774
3028
|
},
|
|
2775
|
-
[Primitive.Min]([x, y]) {
|
|
2776
|
-
return [x.#binary(AluOp.Min, y)];
|
|
2777
|
-
},
|
|
2778
|
-
[Primitive.Max]([x, y]) {
|
|
2779
|
-
return [x.#binary(AluOp.Max, y)];
|
|
2780
|
-
},
|
|
2781
3029
|
[Primitive.Reduce]([x], { op, axis }) {
|
|
2782
3030
|
if (axis.length === 0) return [x];
|
|
2783
3031
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
@@ -2812,6 +3060,55 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2812
3060
|
y
|
|
2813
3061
|
], { dtypeOverride: [DType.Bool] })];
|
|
2814
3062
|
},
|
|
3063
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
3064
|
+
const ndim$2 = xs[0].ndim;
|
|
3065
|
+
const sizes = xs.map((x) => x.shape[axis]);
|
|
3066
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
3067
|
+
const makePadAxis = (start, end) => range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
3068
|
+
let cum = 0;
|
|
3069
|
+
const xsPadded = [];
|
|
3070
|
+
for (let i = 0; i < xs.length; i++) {
|
|
3071
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
3072
|
+
xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
|
|
3073
|
+
cum += sizes[i];
|
|
3074
|
+
}
|
|
3075
|
+
const custom = (exps) => exps.reduce(AluExp.add);
|
|
3076
|
+
return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
|
|
3077
|
+
},
|
|
3078
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
3079
|
+
const outputs = [];
|
|
3080
|
+
for (let i = 0, start = 0; i < sizes.length; i++) {
|
|
3081
|
+
const slice = range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
|
|
3082
|
+
outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
|
|
3083
|
+
start += sizes[i];
|
|
3084
|
+
}
|
|
3085
|
+
x.dispose();
|
|
3086
|
+
return outputs;
|
|
3087
|
+
},
|
|
3088
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
3089
|
+
const keyShape = k0.shape;
|
|
3090
|
+
const genShape = shape$1.slice(keyShape.length);
|
|
3091
|
+
const c0 = zeros(genShape, {
|
|
3092
|
+
dtype: DType.Uint32,
|
|
3093
|
+
device: k0.device
|
|
3094
|
+
});
|
|
3095
|
+
const c1 = arange(0, prod(genShape), 1, {
|
|
3096
|
+
dtype: DType.Uint32,
|
|
3097
|
+
device: k0.device
|
|
3098
|
+
}).reshape(genShape);
|
|
3099
|
+
k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
|
|
3100
|
+
k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(rep(genShape.length, 1))));
|
|
3101
|
+
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
3102
|
+
return [Array$1.#naryCustom("random_bits", custom, [
|
|
3103
|
+
k0,
|
|
3104
|
+
k1,
|
|
3105
|
+
c0,
|
|
3106
|
+
c1
|
|
3107
|
+
])];
|
|
3108
|
+
},
|
|
3109
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
3110
|
+
return [x.#gather(indices, axis, outDim)];
|
|
3111
|
+
},
|
|
2815
3112
|
[Primitive.Transpose]([x], { perm }) {
|
|
2816
3113
|
return [x.#transpose(perm)];
|
|
2817
3114
|
},
|
|
@@ -2832,17 +3129,71 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2832
3129
|
[Primitive.Pad]([x], { width }) {
|
|
2833
3130
|
return [x.#reshape(x.#st.pad(width))];
|
|
2834
3131
|
},
|
|
2835
|
-
[Primitive.
|
|
2836
|
-
|
|
3132
|
+
[Primitive.Sort]([x]) {
|
|
3133
|
+
const routine = new Routine(Routines.Sort, {
|
|
3134
|
+
inputShapes: [x.shape],
|
|
3135
|
+
inputDtypes: [x.dtype],
|
|
3136
|
+
outputShapes: [x.shape],
|
|
3137
|
+
outputDtypes: [x.dtype]
|
|
3138
|
+
});
|
|
3139
|
+
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3140
|
+
},
|
|
3141
|
+
[Primitive.Argsort]([x]) {
|
|
3142
|
+
const routine = new Routine(Routines.Argsort, {
|
|
3143
|
+
inputShapes: [x.shape],
|
|
3144
|
+
inputDtypes: [x.dtype],
|
|
3145
|
+
outputShapes: [x.shape, x.shape],
|
|
3146
|
+
outputDtypes: [x.dtype, DType.Int32]
|
|
3147
|
+
});
|
|
3148
|
+
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3149
|
+
},
|
|
3150
|
+
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3151
|
+
const routine = new Routine(Routines.TriangularSolve, {
|
|
3152
|
+
inputShapes: [a.shape, b.shape],
|
|
3153
|
+
inputDtypes: [a.dtype, b.dtype],
|
|
3154
|
+
outputShapes: [b.shape],
|
|
3155
|
+
outputDtypes: [b.dtype]
|
|
3156
|
+
}, { unitDiagonal });
|
|
3157
|
+
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
2837
3158
|
},
|
|
2838
|
-
[Primitive.
|
|
2839
|
-
|
|
2840
|
-
|
|
3159
|
+
[Primitive.Cholesky]([a]) {
|
|
3160
|
+
const routine = new Routine(Routines.Cholesky, {
|
|
3161
|
+
inputShapes: [a.shape],
|
|
3162
|
+
inputDtypes: [a.dtype],
|
|
3163
|
+
outputShapes: [a.shape],
|
|
3164
|
+
outputDtypes: [a.dtype]
|
|
3165
|
+
});
|
|
3166
|
+
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3167
|
+
},
|
|
3168
|
+
[Primitive.LU]([a]) {
|
|
3169
|
+
const batch = a.shape.slice(0, -2);
|
|
3170
|
+
const [m, n] = a.shape.slice(-2);
|
|
3171
|
+
const routine = new Routine(Routines.LU, {
|
|
3172
|
+
inputShapes: [a.shape],
|
|
3173
|
+
inputDtypes: [a.dtype],
|
|
3174
|
+
outputShapes: [
|
|
3175
|
+
a.shape,
|
|
3176
|
+
[...batch, Math.min(m, n)],
|
|
3177
|
+
[...batch, m]
|
|
3178
|
+
],
|
|
3179
|
+
outputDtypes: [
|
|
3180
|
+
a.dtype,
|
|
3181
|
+
DType.Int32,
|
|
3182
|
+
DType.Int32
|
|
3183
|
+
]
|
|
3184
|
+
});
|
|
3185
|
+
return Array$1.#routine(routine, [a], [
|
|
3186
|
+
a.#weakType,
|
|
3187
|
+
false,
|
|
3188
|
+
false
|
|
3189
|
+
]);
|
|
3190
|
+
},
|
|
3191
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
3192
|
+
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3193
|
+
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
2841
3194
|
args = args.map((ar) => ar._putSync(backend));
|
|
2842
|
-
const
|
|
2843
|
-
const
|
|
2844
|
-
const jp = jitCompile(backend, jaxpr, consts);
|
|
2845
|
-
const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
|
|
3195
|
+
const jp = jitCompile(backend, jaxpr);
|
|
3196
|
+
const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
|
|
2846
3197
|
for (const exe of pending) exe.updateRc(+outputs.length - 1);
|
|
2847
3198
|
const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
|
|
2848
3199
|
for (const exe of prevPending) exe.updateRc(+outputs.length);
|
|
@@ -2942,7 +3293,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2942
3293
|
device
|
|
2943
3294
|
});
|
|
2944
3295
|
} else {
|
|
2945
|
-
const weakType = dtype == void 0;
|
|
3296
|
+
const weakType = dtype == void 0 && shape$1.length === 0;
|
|
2946
3297
|
dtype = dtype ?? DType.Float32;
|
|
2947
3298
|
const data = dtypedJsArray(dtype, flat);
|
|
2948
3299
|
return arrayFromData(data, shape$1, {
|
|
@@ -3056,7 +3407,7 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
3056
3407
|
}
|
|
3057
3408
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3058
3409
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
3059
|
-
let weakType = dtype == void 0;
|
|
3410
|
+
let weakType = dtype == void 0 && shape$1.length === 0;
|
|
3060
3411
|
if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
|
|
3061
3412
|
else if (typeof fillValue === "boolean") {
|
|
3062
3413
|
dtype = dtype ?? DType.Bool;
|
|
@@ -3141,6 +3492,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
3141
3492
|
});
|
|
3142
3493
|
}
|
|
3143
3494
|
/**
|
|
3495
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
3496
|
+
*
|
|
3497
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
3498
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
3499
|
+
* `k>0` is above it.
|
|
3500
|
+
*/
|
|
3501
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
3502
|
+
m ??= n;
|
|
3503
|
+
dtype ??= DType.Float32;
|
|
3504
|
+
if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
|
|
3505
|
+
if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
|
|
3506
|
+
if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
|
|
3507
|
+
const rows = arange(k, n + k, 1, {
|
|
3508
|
+
dtype: DType.Int32,
|
|
3509
|
+
device
|
|
3510
|
+
});
|
|
3511
|
+
const cols = arange(0, m, 1, {
|
|
3512
|
+
dtype: DType.Int32,
|
|
3513
|
+
device
|
|
3514
|
+
});
|
|
3515
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
3516
|
+
}
|
|
3517
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
3518
|
+
function tril(a, k = 0) {
|
|
3519
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3520
|
+
a = fudgeArray(a);
|
|
3521
|
+
const [n, m] = a.shape.slice(-2);
|
|
3522
|
+
return where$1(tri(n, m, k, { dtype: DType.Bool }), a.ref, zerosLike$1(a));
|
|
3523
|
+
}
|
|
3524
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
3525
|
+
function triu(a, k = 0) {
|
|
3526
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3527
|
+
a = fudgeArray(a);
|
|
3528
|
+
const [n, m] = a.shape.slice(-2);
|
|
3529
|
+
return where$1(tri(n, m, k - 1, { dtype: DType.Bool }), zerosLike$1(a.ref), a);
|
|
3530
|
+
}
|
|
3531
|
+
/**
|
|
3144
3532
|
* Return evenly spaced numbers over a specified interval.
|
|
3145
3533
|
*
|
|
3146
3534
|
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
@@ -3177,6 +3565,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
3177
3565
|
committed: device != void 0
|
|
3178
3566
|
});
|
|
3179
3567
|
}
|
|
3568
|
+
/**
|
|
3569
|
+
* Return numbers spaced evenly on a log scale.
|
|
3570
|
+
*
|
|
3571
|
+
* In linear space, the sequence starts at `base ** start` and ends at
|
|
3572
|
+
* `base ** stop` (see `endpoint` below).
|
|
3573
|
+
*
|
|
3574
|
+
* @param start - `base ** start` is the starting value of the sequence.
|
|
3575
|
+
* @param stop - `base ** stop` is the final value of the sequence, unless `endpoint` is false.
|
|
3576
|
+
* @param num - Number of samples to generate. Default is 50.
|
|
3577
|
+
* @param endpoint - If true, `stop` is the last sample. Otherwise, it is not included. Default is true.
|
|
3578
|
+
* @param base - The base of the log space. Default is 10.
|
|
3579
|
+
* @returns Array of evenly spaced values on a log scale.
|
|
3580
|
+
*/
|
|
3581
|
+
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
|
|
3582
|
+
const y = linspace(start, stop, num, endpoint, {
|
|
3583
|
+
dtype,
|
|
3584
|
+
device
|
|
3585
|
+
});
|
|
3586
|
+
const logBase = Math.log(base);
|
|
3587
|
+
return exp$1(mul(y, logBase));
|
|
3588
|
+
}
|
|
3180
3589
|
function aluCompare(a, b, op) {
|
|
3181
3590
|
switch (op) {
|
|
3182
3591
|
case CompareOp.Less: return AluExp.cmplt(a, b);
|
|
@@ -3187,383 +3596,210 @@ function aluCompare(a, b, op) {
|
|
|
3187
3596
|
}
|
|
3188
3597
|
|
|
3189
3598
|
//#endregion
|
|
3190
|
-
//#region src/frontend/
|
|
3191
|
-
|
|
3192
|
-
|
|
3599
|
+
//#region src/frontend/vmap.ts
|
|
3600
|
+
function mappedAval(batchDim, aval) {
|
|
3601
|
+
const shape$1 = [...aval.shape];
|
|
3602
|
+
shape$1.splice(batchDim, 1);
|
|
3603
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3604
|
+
}
|
|
3605
|
+
/** Move one axis to a different index. */
|
|
3606
|
+
function moveaxis(x, src, dst) {
|
|
3607
|
+
const t = pureArray(x);
|
|
3608
|
+
src = checkAxis(src, t.ndim);
|
|
3609
|
+
dst = checkAxis(dst, t.ndim);
|
|
3610
|
+
if (src === dst) return t;
|
|
3611
|
+
const perm = range(t.ndim);
|
|
3612
|
+
perm.splice(src, 1);
|
|
3613
|
+
perm.splice(dst, 0, src);
|
|
3614
|
+
return transpose$1(t, perm);
|
|
3615
|
+
}
|
|
3616
|
+
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3617
|
+
if (src === null) {
|
|
3618
|
+
const targetShape = [...x.shape];
|
|
3619
|
+
targetShape.splice(dst, 0, axisSize);
|
|
3620
|
+
return broadcast(x, targetShape, [dst]);
|
|
3621
|
+
} else if (src === dst) return x;
|
|
3622
|
+
else return moveaxis(x, src, dst);
|
|
3623
|
+
}
|
|
3624
|
+
var BatchTracer = class extends Tracer {
|
|
3625
|
+
constructor(trace$1, val, batchDim) {
|
|
3193
3626
|
super(trace$1);
|
|
3194
|
-
this.
|
|
3195
|
-
this.
|
|
3627
|
+
this.val = val;
|
|
3628
|
+
this.batchDim = batchDim;
|
|
3196
3629
|
}
|
|
3197
3630
|
get aval() {
|
|
3198
|
-
return this.
|
|
3631
|
+
if (this.batchDim === null) return this.val.aval;
|
|
3632
|
+
else return mappedAval(this.batchDim, this.val.aval);
|
|
3199
3633
|
}
|
|
3200
3634
|
toString() {
|
|
3201
|
-
return `
|
|
3635
|
+
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3202
3636
|
}
|
|
3203
3637
|
get ref() {
|
|
3204
|
-
this.
|
|
3638
|
+
this.val.ref;
|
|
3205
3639
|
return this;
|
|
3206
3640
|
}
|
|
3207
3641
|
dispose() {
|
|
3208
|
-
this.
|
|
3209
|
-
|
|
3642
|
+
this.val.dispose();
|
|
3643
|
+
}
|
|
3644
|
+
fullLower() {
|
|
3645
|
+
if (this.batchDim === null) return this.val.fullLower();
|
|
3646
|
+
else return this;
|
|
3210
3647
|
}
|
|
3211
3648
|
};
|
|
3212
|
-
var
|
|
3649
|
+
var BatchTrace = class extends Trace {
|
|
3213
3650
|
pure(val) {
|
|
3214
3651
|
return this.lift(pureArray(val));
|
|
3215
3652
|
}
|
|
3216
3653
|
lift(val) {
|
|
3217
|
-
return new
|
|
3654
|
+
return new BatchTracer(this, val, null);
|
|
3218
3655
|
}
|
|
3219
3656
|
processPrimitive(primitive, tracers, params) {
|
|
3220
|
-
const [
|
|
3221
|
-
const
|
|
3222
|
-
if (
|
|
3223
|
-
|
|
3224
|
-
|
|
3657
|
+
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3658
|
+
const vmapRule = vmapRules[primitive];
|
|
3659
|
+
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3660
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3661
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3662
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3663
|
+
}
|
|
3664
|
+
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3665
|
+
if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
|
|
3666
|
+
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3667
|
+
}
|
|
3668
|
+
get axisSize() {
|
|
3669
|
+
return this.main.globalData;
|
|
3225
3670
|
}
|
|
3226
3671
|
};
|
|
3227
|
-
/**
|
|
3228
|
-
|
|
3229
|
-
|
|
3230
|
-
|
|
3231
|
-
|
|
3232
|
-
|
|
3672
|
+
/**
|
|
3673
|
+
* Process a primitive with built-in broadcasting.
|
|
3674
|
+
*
|
|
3675
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3676
|
+
*/
|
|
3677
|
+
function broadcastBatcher(prim) {
|
|
3678
|
+
return (axisSize, args, dims, params) => {
|
|
3679
|
+
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3680
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3681
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3682
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3683
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[bind1(prim, args, params)], [nd + firstBdim]];
|
|
3684
|
+
args = args.map((x, i) => {
|
|
3685
|
+
if (dims[i] === null) return x;
|
|
3686
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3687
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3688
|
+
x.shape[0],
|
|
3689
|
+
...rep(nd - x.ndim, 1),
|
|
3690
|
+
...x.shape.slice(1)
|
|
3691
|
+
]);
|
|
3692
|
+
return x;
|
|
3693
|
+
});
|
|
3694
|
+
return [[bind1(prim, args, params)], [0]];
|
|
3233
3695
|
};
|
|
3234
3696
|
}
|
|
3235
|
-
|
|
3236
|
-
|
|
3237
|
-
|
|
3238
|
-
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3239
|
-
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3240
|
-
return [[primal], [tangent]];
|
|
3697
|
+
function unopBatcher(prim) {
|
|
3698
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3699
|
+
return [[bind1(prim, [x], params)], [xBdim]];
|
|
3241
3700
|
};
|
|
3242
3701
|
}
|
|
3243
|
-
|
|
3244
|
-
|
|
3245
|
-
|
|
3246
|
-
|
|
3247
|
-
|
|
3248
|
-
return [
|
|
3702
|
+
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
|
|
3703
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3704
|
+
assertNonNull(xBdim);
|
|
3705
|
+
if (xBdim < x.ndim - inputDims) return [bind(prim, [x], params), rep(numOutputs, xBdim)];
|
|
3706
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3707
|
+
return [bind(prim, [x], params), rep(numOutputs, 0)];
|
|
3249
3708
|
};
|
|
3250
3709
|
}
|
|
3251
|
-
const
|
|
3252
|
-
[Primitive.Add]:
|
|
3253
|
-
[Primitive.Mul]:
|
|
3254
|
-
[Primitive.Idiv]:
|
|
3255
|
-
[Primitive.Mod](
|
|
3256
|
-
|
|
3257
|
-
|
|
3258
|
-
|
|
3259
|
-
|
|
3260
|
-
|
|
3261
|
-
|
|
3262
|
-
|
|
3263
|
-
|
|
3264
|
-
[Primitive.
|
|
3265
|
-
[Primitive.
|
|
3266
|
-
|
|
3267
|
-
|
|
3268
|
-
|
|
3269
|
-
[Primitive.
|
|
3270
|
-
[Primitive.
|
|
3271
|
-
[Primitive.
|
|
3272
|
-
[Primitive.
|
|
3273
|
-
|
|
3274
|
-
|
|
3275
|
-
|
|
3276
|
-
|
|
3277
|
-
|
|
3278
|
-
|
|
3279
|
-
},
|
|
3280
|
-
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3281
|
-
if (x.dtype === dtype) return [[x], [dx]];
|
|
3282
|
-
dx.dispose();
|
|
3283
|
-
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3284
|
-
},
|
|
3285
|
-
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3286
|
-
[Primitive.Sin]([x], [dx]) {
|
|
3287
|
-
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3710
|
+
const vmapRules = {
|
|
3711
|
+
[Primitive.Add]: broadcastBatcher(Primitive.Add),
|
|
3712
|
+
[Primitive.Mul]: broadcastBatcher(Primitive.Mul),
|
|
3713
|
+
[Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
|
|
3714
|
+
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3715
|
+
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3716
|
+
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3717
|
+
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3718
|
+
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3719
|
+
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
3720
|
+
[Primitive.Ceil]: unopBatcher(Primitive.Ceil),
|
|
3721
|
+
[Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
|
|
3722
|
+
[Primitive.Cast]: unopBatcher(Primitive.Cast),
|
|
3723
|
+
[Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
|
|
3724
|
+
[Primitive.Sin]: unopBatcher(Primitive.Sin),
|
|
3725
|
+
[Primitive.Cos]: unopBatcher(Primitive.Cos),
|
|
3726
|
+
[Primitive.Asin]: unopBatcher(Primitive.Asin),
|
|
3727
|
+
[Primitive.Atan]: unopBatcher(Primitive.Atan),
|
|
3728
|
+
[Primitive.Exp]: unopBatcher(Primitive.Exp),
|
|
3729
|
+
[Primitive.Log]: unopBatcher(Primitive.Log),
|
|
3730
|
+
[Primitive.Erf]: unopBatcher(Primitive.Erf),
|
|
3731
|
+
[Primitive.Erfc]: unopBatcher(Primitive.Erfc),
|
|
3732
|
+
[Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
|
|
3733
|
+
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3734
|
+
assertNonNull(xBdim);
|
|
3735
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3736
|
+
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3737
|
+
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3288
3738
|
},
|
|
3289
|
-
[Primitive.
|
|
3290
|
-
|
|
3739
|
+
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3740
|
+
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3741
|
+
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3742
|
+
const z = dot$2(x, y);
|
|
3743
|
+
return [[z], [z.ndim - 1]];
|
|
3291
3744
|
},
|
|
3292
|
-
[Primitive.
|
|
3293
|
-
|
|
3294
|
-
|
|
3745
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3746
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3747
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3748
|
+
const z = conv$1(x, y, {
|
|
3749
|
+
...params,
|
|
3750
|
+
vmapDims: params.vmapDims + 1
|
|
3751
|
+
});
|
|
3752
|
+
return [[z], [0]];
|
|
3295
3753
|
},
|
|
3296
|
-
[Primitive.
|
|
3297
|
-
|
|
3298
|
-
|
|
3754
|
+
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
|
|
3755
|
+
[Primitive.Where]: broadcastBatcher(Primitive.Where),
|
|
3756
|
+
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
|
|
3757
|
+
const minBdim = Math.min(...xBdims.filter((d) => d !== null));
|
|
3758
|
+
xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
|
|
3759
|
+
const newAxis = axis + (minBdim <= axis ? 1 : 0);
|
|
3760
|
+
return [[concatenate$1(xs, newAxis)], [minBdim]];
|
|
3299
3761
|
},
|
|
3300
|
-
[Primitive.
|
|
3301
|
-
|
|
3302
|
-
|
|
3762
|
+
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
|
|
3763
|
+
assertNonNull(xBdim);
|
|
3764
|
+
const newAxis = axis + (xBdim <= axis ? 1 : 0);
|
|
3765
|
+
const outs = split$2(x, newAxis, sizes);
|
|
3766
|
+
return [outs, rep(outs.length, xBdim)];
|
|
3303
3767
|
},
|
|
3304
|
-
[Primitive.
|
|
3305
|
-
|
|
3768
|
+
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
|
|
3769
|
+
k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
|
|
3770
|
+
k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
|
|
3771
|
+
return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
|
|
3306
3772
|
},
|
|
3307
|
-
[Primitive.
|
|
3308
|
-
|
|
3309
|
-
|
|
3310
|
-
|
|
3311
|
-
|
|
3312
|
-
|
|
3313
|
-
|
|
3314
|
-
|
|
3315
|
-
|
|
3316
|
-
},
|
|
3317
|
-
[Primitive.Sqrt]([x], [dx]) {
|
|
3318
|
-
const z = sqrt$1(x);
|
|
3319
|
-
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3320
|
-
},
|
|
3321
|
-
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3322
|
-
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3323
|
-
},
|
|
3324
|
-
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3325
|
-
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3326
|
-
},
|
|
3327
|
-
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3328
|
-
if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3329
|
-
else if (op === AluOp.Mul) {
|
|
3330
|
-
const primal = reduce(x.ref, op, axis);
|
|
3331
|
-
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3332
|
-
return [[primal], [tangent]];
|
|
3333
|
-
} else if (op === AluOp.Min || op === AluOp.Max) {
|
|
3334
|
-
const primal = reduce(x.ref, op, axis);
|
|
3335
|
-
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3336
|
-
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3337
|
-
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3338
|
-
return [[primal], [tangent]];
|
|
3339
|
-
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3340
|
-
},
|
|
3341
|
-
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3342
|
-
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3343
|
-
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3344
|
-
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3345
|
-
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3346
|
-
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3347
|
-
dcond.dispose();
|
|
3348
|
-
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3349
|
-
},
|
|
3350
|
-
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3351
|
-
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3352
|
-
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3353
|
-
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3354
|
-
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3355
|
-
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3356
|
-
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3357
|
-
const indicesRef = indices.map((t) => t.ref);
|
|
3358
|
-
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3359
|
-
},
|
|
3360
|
-
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3361
|
-
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3362
|
-
const outs = bind(Primitive.JitCall, [
|
|
3363
|
-
...newConsts.map((c) => c.ref),
|
|
3364
|
-
...primals,
|
|
3365
|
-
...tangents
|
|
3366
|
-
], {
|
|
3367
|
-
name: `${name}_jvp`,
|
|
3368
|
-
jaxpr: newJaxpr,
|
|
3369
|
-
numConsts: newConsts.length
|
|
3370
|
-
});
|
|
3371
|
-
const n = outs.length / 2;
|
|
3372
|
-
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3373
|
-
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3374
|
-
return [primalsOut, tangentsOut];
|
|
3375
|
-
}
|
|
3376
|
-
};
|
|
3377
|
-
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3378
|
-
function jvpJaxpr(jaxpr) {
|
|
3379
|
-
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3380
|
-
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3381
|
-
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3382
|
-
const result = {
|
|
3383
|
-
newJaxpr,
|
|
3384
|
-
newConsts
|
|
3385
|
-
};
|
|
3386
|
-
jvpJaxprCache.set(jaxpr, result);
|
|
3387
|
-
return result;
|
|
3388
|
-
}
|
|
3389
|
-
function jvpFlat(f, primals, tangents) {
|
|
3390
|
-
try {
|
|
3391
|
-
var _usingCtx$1 = _usingCtx();
|
|
3392
|
-
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3393
|
-
const trace$1 = new JVPTrace(main);
|
|
3394
|
-
const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3395
|
-
const outs = f(...tracersIn);
|
|
3396
|
-
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3397
|
-
return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3398
|
-
} catch (_) {
|
|
3399
|
-
_usingCtx$1.e = _;
|
|
3400
|
-
} finally {
|
|
3401
|
-
_usingCtx$1.d();
|
|
3402
|
-
}
|
|
3403
|
-
}
|
|
3404
|
-
function jvp$1(f, primals, tangents) {
|
|
3405
|
-
const [primalsFlat, inTree] = flatten(primals);
|
|
3406
|
-
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
3407
|
-
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
3408
|
-
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
3409
|
-
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
3410
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
3411
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3412
|
-
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
3413
|
-
return [primalsOut, tangentsOut];
|
|
3414
|
-
}
|
|
3415
|
-
|
|
3416
|
-
//#endregion
|
|
3417
|
-
//#region src/frontend/vmap.ts
|
|
3418
|
-
function mappedAval(batchDim, aval) {
|
|
3419
|
-
const shape$1 = [...aval.shape];
|
|
3420
|
-
shape$1.splice(batchDim, 1);
|
|
3421
|
-
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3422
|
-
}
|
|
3423
|
-
/** Move one axis to a different index. */
|
|
3424
|
-
function moveaxis(x, src, dst) {
|
|
3425
|
-
const t = pureArray(x);
|
|
3426
|
-
src = checkAxis(src, t.ndim);
|
|
3427
|
-
dst = checkAxis(dst, t.ndim);
|
|
3428
|
-
if (src === dst) return t;
|
|
3429
|
-
const perm = range(t.ndim);
|
|
3430
|
-
perm.splice(src, 1);
|
|
3431
|
-
perm.splice(dst, 0, src);
|
|
3432
|
-
return transpose$1(t, perm);
|
|
3433
|
-
}
|
|
3434
|
-
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3435
|
-
if (src === null) {
|
|
3436
|
-
const targetShape = [...x.shape];
|
|
3437
|
-
targetShape.splice(dst, 0, axisSize);
|
|
3438
|
-
return broadcast(x, targetShape, [dst]);
|
|
3439
|
-
} else if (src === dst) return x;
|
|
3440
|
-
else return moveaxis(x, src, dst);
|
|
3441
|
-
}
|
|
3442
|
-
var BatchTracer = class extends Tracer {
|
|
3443
|
-
constructor(trace$1, val, batchDim) {
|
|
3444
|
-
super(trace$1);
|
|
3445
|
-
this.val = val;
|
|
3446
|
-
this.batchDim = batchDim;
|
|
3447
|
-
}
|
|
3448
|
-
get aval() {
|
|
3449
|
-
if (this.batchDim === null) return this.val.aval;
|
|
3450
|
-
else return mappedAval(this.batchDim, this.val.aval);
|
|
3451
|
-
}
|
|
3452
|
-
toString() {
|
|
3453
|
-
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3454
|
-
}
|
|
3455
|
-
get ref() {
|
|
3456
|
-
this.val.ref;
|
|
3457
|
-
return this;
|
|
3458
|
-
}
|
|
3459
|
-
dispose() {
|
|
3460
|
-
this.val.dispose();
|
|
3461
|
-
}
|
|
3462
|
-
fullLower() {
|
|
3463
|
-
if (this.batchDim === null) return this.val.fullLower();
|
|
3464
|
-
else return this;
|
|
3465
|
-
}
|
|
3466
|
-
};
|
|
3467
|
-
var BatchTrace = class extends Trace {
|
|
3468
|
-
pure(val) {
|
|
3469
|
-
return this.lift(pureArray(val));
|
|
3470
|
-
}
|
|
3471
|
-
lift(val) {
|
|
3472
|
-
return new BatchTracer(this, val, null);
|
|
3473
|
-
}
|
|
3474
|
-
processPrimitive(primitive, tracers, params) {
|
|
3475
|
-
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3476
|
-
const vmapRule = vmapRules[primitive];
|
|
3477
|
-
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3478
|
-
if (bdimsIn.every((d) => d === null)) {
|
|
3479
|
-
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3480
|
-
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3773
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3774
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3775
|
+
assertNonNull(xBdim);
|
|
3776
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3777
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3778
|
+
let newOutDim = outDim;
|
|
3779
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3780
|
+
else newOutDim += 1;
|
|
3781
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3481
3782
|
}
|
|
3482
|
-
const
|
|
3483
|
-
|
|
3484
|
-
|
|
3485
|
-
|
|
3486
|
-
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
* Process a primitive with built-in broadcasting.
|
|
3491
|
-
*
|
|
3492
|
-
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3493
|
-
*/
|
|
3494
|
-
function broadcastBatcher(op) {
|
|
3495
|
-
return (axisSize, args, dims) => {
|
|
3496
|
-
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3497
|
-
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3498
|
-
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3499
|
-
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3500
|
-
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3501
|
-
args = args.map((x, i) => {
|
|
3502
|
-
if (dims[i] === null) return x;
|
|
3503
|
-
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3504
|
-
if (x.ndim < nd) x = x.reshape([
|
|
3505
|
-
x.shape[0],
|
|
3506
|
-
...rep(nd - x.ndim, 1),
|
|
3507
|
-
...x.shape.slice(1)
|
|
3783
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3784
|
+
indices = indices.map((m, i) => {
|
|
3785
|
+
if (indicesBdim[i] === null) return m;
|
|
3786
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3787
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3788
|
+
m.shape[0],
|
|
3789
|
+
...rep(nd - m.ndim, 1),
|
|
3790
|
+
...m.shape.slice(1)
|
|
3508
3791
|
]);
|
|
3509
|
-
return
|
|
3510
|
-
});
|
|
3511
|
-
return [[op(...args)], [0]];
|
|
3512
|
-
};
|
|
3513
|
-
}
|
|
3514
|
-
function unopBatcher(op) {
|
|
3515
|
-
return (axisSize, [x], [xBdim], params) => {
|
|
3516
|
-
return [[op(x, params)], [xBdim]];
|
|
3517
|
-
};
|
|
3518
|
-
}
|
|
3519
|
-
const vmapRules = {
|
|
3520
|
-
[Primitive.Add]: broadcastBatcher(add$1),
|
|
3521
|
-
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3522
|
-
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3523
|
-
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3524
|
-
[Primitive.Neg]: unopBatcher(neg),
|
|
3525
|
-
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3526
|
-
[Primitive.Floor]: unopBatcher(floor$1),
|
|
3527
|
-
[Primitive.Ceil]: unopBatcher(ceil$1),
|
|
3528
|
-
[Primitive.StopGradient]: unopBatcher(stopGradient),
|
|
3529
|
-
[Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
|
|
3530
|
-
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3531
|
-
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3532
|
-
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3533
|
-
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3534
|
-
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3535
|
-
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3536
|
-
[Primitive.Log]: unopBatcher(log$1),
|
|
3537
|
-
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3538
|
-
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3539
|
-
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3540
|
-
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3541
|
-
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3542
|
-
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3543
|
-
assertNonNull(xBdim);
|
|
3544
|
-
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3545
|
-
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3546
|
-
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3547
|
-
},
|
|
3548
|
-
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3549
|
-
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3550
|
-
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3551
|
-
const z = dot$2(x, y);
|
|
3552
|
-
return [[z], [z.ndim - 1]];
|
|
3553
|
-
},
|
|
3554
|
-
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3555
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3556
|
-
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3557
|
-
const z = conv$1(x, y, {
|
|
3558
|
-
...params,
|
|
3559
|
-
vmapDims: params.vmapDims + 1
|
|
3792
|
+
return m;
|
|
3560
3793
|
});
|
|
3561
|
-
return [[
|
|
3562
|
-
|
|
3563
|
-
|
|
3564
|
-
|
|
3794
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3795
|
+
else {
|
|
3796
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3797
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3798
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
|
|
3799
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3800
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3801
|
+
}
|
|
3565
3802
|
},
|
|
3566
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3567
3803
|
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3568
3804
|
assertNonNull(xBdim);
|
|
3569
3805
|
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
@@ -3595,42 +3831,39 @@ const vmapRules = {
|
|
|
3595
3831
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3596
3832
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3597
3833
|
},
|
|
3598
|
-
[Primitive.
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3602
|
-
|
|
3603
|
-
|
|
3604
|
-
|
|
3605
|
-
|
|
3606
|
-
|
|
3607
|
-
|
|
3608
|
-
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3609
|
-
indices = indices.map((m, i) => {
|
|
3610
|
-
if (indicesBdim[i] === null) return m;
|
|
3611
|
-
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3612
|
-
if (m.ndim < nd) m = m.reshape([
|
|
3613
|
-
m.shape[0],
|
|
3614
|
-
...rep(nd - m.ndim, 1),
|
|
3615
|
-
...m.shape.slice(1)
|
|
3834
|
+
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
|
|
3835
|
+
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
|
|
3836
|
+
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3837
|
+
if (aBdim === null) {
|
|
3838
|
+
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
3839
|
+
const [s, m, n] = b.shape.slice(-3);
|
|
3840
|
+
b = b.reshape([
|
|
3841
|
+
...b.shape.slice(0, -3),
|
|
3842
|
+
s * m,
|
|
3843
|
+
n
|
|
3616
3844
|
]);
|
|
3617
|
-
|
|
3618
|
-
|
|
3619
|
-
|
|
3620
|
-
|
|
3621
|
-
|
|
3622
|
-
|
|
3623
|
-
|
|
3624
|
-
|
|
3625
|
-
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3845
|
+
let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3846
|
+
x$1 = x$1.reshape([
|
|
3847
|
+
...b.shape.slice(0, -2),
|
|
3848
|
+
s,
|
|
3849
|
+
m,
|
|
3850
|
+
n
|
|
3851
|
+
]);
|
|
3852
|
+
return [[x$1], [x$1.ndim - 3]];
|
|
3626
3853
|
}
|
|
3854
|
+
a = moveBatchAxis(axisSize, aBdim, 0, a);
|
|
3855
|
+
b = moveBatchAxis(axisSize, bBdim, 0, b);
|
|
3856
|
+
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3857
|
+
return [[x], [0]];
|
|
3627
3858
|
},
|
|
3628
|
-
[Primitive.
|
|
3629
|
-
|
|
3630
|
-
|
|
3859
|
+
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
|
|
3860
|
+
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
|
|
3861
|
+
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3862
|
+
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3863
|
+
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
3631
3864
|
name: `${name}_vmap`,
|
|
3632
|
-
jaxpr: newJaxpr,
|
|
3633
|
-
numConsts:
|
|
3865
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3866
|
+
numConsts: newJaxpr.consts.length
|
|
3634
3867
|
});
|
|
3635
3868
|
return [outs, rep(outs.length, 0)];
|
|
3636
3869
|
}
|
|
@@ -3646,14 +3879,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3646
3879
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3647
3880
|
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3648
3881
|
});
|
|
3649
|
-
const { jaxpr: newJaxpr
|
|
3650
|
-
const result = {
|
|
3651
|
-
newJaxpr,
|
|
3652
|
-
newConsts
|
|
3653
|
-
};
|
|
3882
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3654
3883
|
if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
3655
|
-
vmapJaxprCache.get(jaxpr).set(cacheKey,
|
|
3656
|
-
return
|
|
3884
|
+
vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
3885
|
+
return newJaxpr;
|
|
3657
3886
|
}
|
|
3658
3887
|
function vmapFlat(f, inAxes, args) {
|
|
3659
3888
|
let axisSize = void 0;
|
|
@@ -3699,13 +3928,311 @@ function vmap$1(f, inAxes = 0) {
|
|
|
3699
3928
|
return unflatten(outTree.value, outsFlat);
|
|
3700
3929
|
};
|
|
3701
3930
|
}
|
|
3702
|
-
function jacfwd$1(f) {
|
|
3703
|
-
return function jacobianForward(x) {
|
|
3704
|
-
if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
|
|
3705
|
-
const [size$1] = x.shape;
|
|
3706
|
-
const pushfwd = (v) => jvp$1(f, [x], [v])[1];
|
|
3707
|
-
return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3708
|
-
};
|
|
3931
|
+
function jacfwd$1(f) {
|
|
3932
|
+
return function jacobianForward(x) {
|
|
3933
|
+
if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
|
|
3934
|
+
const [size$1] = x.shape;
|
|
3935
|
+
const pushfwd = (v) => jvp$1(f, [x], [v])[1];
|
|
3936
|
+
return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3937
|
+
};
|
|
3938
|
+
}
|
|
3939
|
+
|
|
3940
|
+
//#endregion
|
|
3941
|
+
//#region src/frontend/jvp.ts
|
|
3942
|
+
var JVPTracer = class extends Tracer {
|
|
3943
|
+
constructor(trace$1, primal, tangent) {
|
|
3944
|
+
super(trace$1);
|
|
3945
|
+
this.primal = primal;
|
|
3946
|
+
this.tangent = tangent;
|
|
3947
|
+
}
|
|
3948
|
+
get aval() {
|
|
3949
|
+
return this.primal.aval;
|
|
3950
|
+
}
|
|
3951
|
+
toString() {
|
|
3952
|
+
return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
|
|
3953
|
+
}
|
|
3954
|
+
get ref() {
|
|
3955
|
+
this.primal.ref, this.tangent.ref;
|
|
3956
|
+
return this;
|
|
3957
|
+
}
|
|
3958
|
+
dispose() {
|
|
3959
|
+
this.primal.dispose();
|
|
3960
|
+
this.tangent.dispose();
|
|
3961
|
+
}
|
|
3962
|
+
};
|
|
3963
|
+
var JVPTrace = class extends Trace {
|
|
3964
|
+
pure(val) {
|
|
3965
|
+
return this.lift(pureArray(val));
|
|
3966
|
+
}
|
|
3967
|
+
lift(val) {
|
|
3968
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
3969
|
+
}
|
|
3970
|
+
processPrimitive(primitive, tracers, params) {
|
|
3971
|
+
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
3972
|
+
const jvpRule = jvpRules[primitive];
|
|
3973
|
+
if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
|
|
3974
|
+
const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
|
|
3975
|
+
return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
|
|
3976
|
+
}
|
|
3977
|
+
};
|
|
3978
|
+
/** Rule that applies the same operation to primals and tangents. */
|
|
3979
|
+
function linearTangentsJvp(primitive) {
|
|
3980
|
+
return (primals, tangents, params) => {
|
|
3981
|
+
const ys = bind(primitive, primals, params);
|
|
3982
|
+
const dys = bind(primitive, tangents, params);
|
|
3983
|
+
return [ys, dys];
|
|
3984
|
+
};
|
|
3985
|
+
}
|
|
3986
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
3987
|
+
function bilinearTangentsJvp(primitive) {
|
|
3988
|
+
return ([x, y], [dx, dy], params) => {
|
|
3989
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3990
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3991
|
+
return [[primal], [tangent]];
|
|
3992
|
+
};
|
|
3993
|
+
}
|
|
3994
|
+
/** Rule that zeros out any tangents. */
|
|
3995
|
+
function zeroTangentsJvp(primitive) {
|
|
3996
|
+
return (primals, tangents, params) => {
|
|
3997
|
+
for (const t of tangents) t.dispose();
|
|
3998
|
+
const ys = bind(primitive, primals, params);
|
|
3999
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
4000
|
+
};
|
|
4001
|
+
}
|
|
4002
|
+
/** Compute `a @ b.T`, batched to last two axes. */
|
|
4003
|
+
function batchMatmulT(a, b) {
|
|
4004
|
+
return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
|
|
4005
|
+
}
|
|
4006
|
+
/** Batch matrix transpose. */
|
|
4007
|
+
function mT(a) {
|
|
4008
|
+
return moveaxis(a, -2, -1);
|
|
4009
|
+
}
|
|
4010
|
+
function sliceAxis(a, axis, p) {
|
|
4011
|
+
const slices = Array(a.shape.length).fill([]);
|
|
4012
|
+
slices[checkAxis(axis, a.ndim)] = p;
|
|
4013
|
+
return a.slice(...slices);
|
|
4014
|
+
}
|
|
4015
|
+
function padAxis(a, axis, p) {
|
|
4016
|
+
const pads = Array(a.shape.length).fill([0, 0]);
|
|
4017
|
+
pads[checkAxis(axis, a.ndim)] = p;
|
|
4018
|
+
return pad$1(a, pads);
|
|
4019
|
+
}
|
|
4020
|
+
const jvpRules = {
|
|
4021
|
+
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
4022
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
4023
|
+
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
4024
|
+
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
4025
|
+
if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
|
|
4026
|
+
dx.dispose();
|
|
4027
|
+
dy.dispose();
|
|
4028
|
+
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
4029
|
+
}
|
|
4030
|
+
const q = idiv(x.ref, y.ref);
|
|
4031
|
+
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
4032
|
+
},
|
|
4033
|
+
[Primitive.Min]([x, y], [dx, dy]) {
|
|
4034
|
+
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
4035
|
+
},
|
|
4036
|
+
[Primitive.Max]([x, y], [dx, dy]) {
|
|
4037
|
+
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
4038
|
+
},
|
|
4039
|
+
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
4040
|
+
[Primitive.Reciprocal]([x], [dx]) {
|
|
4041
|
+
const xRecip = reciprocal$1(x.ref);
|
|
4042
|
+
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
4043
|
+
},
|
|
4044
|
+
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
4045
|
+
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
4046
|
+
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
4047
|
+
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
4048
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
4049
|
+
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
4050
|
+
else {
|
|
4051
|
+
dx.dispose();
|
|
4052
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
4053
|
+
}
|
|
4054
|
+
},
|
|
4055
|
+
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
4056
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
4057
|
+
dx.dispose();
|
|
4058
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
4059
|
+
},
|
|
4060
|
+
[Primitive.Sin]([x], [dx]) {
|
|
4061
|
+
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
4062
|
+
},
|
|
4063
|
+
[Primitive.Cos]([x], [dx]) {
|
|
4064
|
+
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
4065
|
+
},
|
|
4066
|
+
[Primitive.Asin]([x], [dx]) {
|
|
4067
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
4068
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
4069
|
+
},
|
|
4070
|
+
[Primitive.Atan]([x], [dx]) {
|
|
4071
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
4072
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
4073
|
+
},
|
|
4074
|
+
[Primitive.Exp]([x], [dx]) {
|
|
4075
|
+
const z = exp$1(x);
|
|
4076
|
+
return [[z.ref], [z.mul(dx)]];
|
|
4077
|
+
},
|
|
4078
|
+
[Primitive.Log]([x], [dx]) {
|
|
4079
|
+
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
4080
|
+
},
|
|
4081
|
+
[Primitive.Erf]([x], [dx]) {
|
|
4082
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
4083
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
4084
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
4085
|
+
},
|
|
4086
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
4087
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
4088
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
4089
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
4090
|
+
},
|
|
4091
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
4092
|
+
const z = sqrt$1(x);
|
|
4093
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
4094
|
+
},
|
|
4095
|
+
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
4096
|
+
if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
4097
|
+
else if (op === AluOp.Mul) {
|
|
4098
|
+
const primal = reduce(x.ref, op, axis);
|
|
4099
|
+
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
4100
|
+
return [[primal], [tangent]];
|
|
4101
|
+
} else if (op === AluOp.Min || op === AluOp.Max) {
|
|
4102
|
+
const primal = reduce(x.ref, op, axis);
|
|
4103
|
+
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
4104
|
+
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
4105
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
4106
|
+
return [[primal], [tangent]];
|
|
4107
|
+
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
4108
|
+
},
|
|
4109
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
4110
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
4111
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
4112
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
4113
|
+
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
4114
|
+
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
4115
|
+
dcond.dispose();
|
|
4116
|
+
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
4117
|
+
},
|
|
4118
|
+
[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
|
|
4119
|
+
[Primitive.Split]: linearTangentsJvp(Primitive.Split),
|
|
4120
|
+
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
4121
|
+
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
4122
|
+
const indicesRef = indices.map((t) => t.ref);
|
|
4123
|
+
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
4124
|
+
},
|
|
4125
|
+
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
4126
|
+
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
4127
|
+
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
4128
|
+
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
4129
|
+
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
4130
|
+
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
4131
|
+
[Primitive.Sort]([x], [dx]) {
|
|
4132
|
+
const [y, idx] = argsort$1(x);
|
|
4133
|
+
return [[y], [gather(dx, [idx], [-1], -1)]];
|
|
4134
|
+
},
|
|
4135
|
+
[Primitive.Argsort]([x], [dx]) {
|
|
4136
|
+
const [y, idx] = argsort$1(x);
|
|
4137
|
+
return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
|
|
4138
|
+
},
|
|
4139
|
+
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
4140
|
+
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
4141
|
+
const dax = batchMatmulT(da, x.ref);
|
|
4142
|
+
const rhsT = db.sub(mT(dax));
|
|
4143
|
+
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
4144
|
+
return [[x], [dx]];
|
|
4145
|
+
},
|
|
4146
|
+
[Primitive.Cholesky]([a], [da]) {
|
|
4147
|
+
const L = cholesky$2(a.ref);
|
|
4148
|
+
da = da.ref.add(mT(da)).mul(.5);
|
|
4149
|
+
const W = triangularSolve$1(L.ref, da, { lower: true });
|
|
4150
|
+
const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
|
|
4151
|
+
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
4152
|
+
return [[L], [dL]];
|
|
4153
|
+
},
|
|
4154
|
+
[Primitive.LU]([a], [da]) {
|
|
4155
|
+
const [luMatrix, pivots, permutation] = lu$1(a);
|
|
4156
|
+
const [m, n] = a.shape.slice(-2);
|
|
4157
|
+
const k = Math.min(m, n);
|
|
4158
|
+
const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
|
|
4159
|
+
const lLower = tril(luSliceL, -1);
|
|
4160
|
+
const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
|
|
4161
|
+
const L = lPadded.add(eye(m));
|
|
4162
|
+
const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
|
|
4163
|
+
const uUpper = triu(luSliceU);
|
|
4164
|
+
const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
|
|
4165
|
+
const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
|
|
4166
|
+
const U = uPadded.add(uEye);
|
|
4167
|
+
const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
|
|
4168
|
+
const pda = batchMatmulT(P, mT(da));
|
|
4169
|
+
const la = mT(triangularSolve$1(L.ref, mT(pda), {
|
|
4170
|
+
lower: true,
|
|
4171
|
+
unitDiagonal: true
|
|
4172
|
+
}));
|
|
4173
|
+
const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
|
|
4174
|
+
const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
|
|
4175
|
+
const uDot = batchMatmulT(triu(lau), mT(U));
|
|
4176
|
+
return [[
|
|
4177
|
+
luMatrix,
|
|
4178
|
+
pivots,
|
|
4179
|
+
permutation
|
|
4180
|
+
], [
|
|
4181
|
+
lDot.add(uDot),
|
|
4182
|
+
zerosLike$1(pivots.ref),
|
|
4183
|
+
zerosLike$1(permutation.ref)
|
|
4184
|
+
]];
|
|
4185
|
+
},
|
|
4186
|
+
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
4187
|
+
const newJaxpr = jvpJaxpr(jaxpr);
|
|
4188
|
+
const outs = bind(Primitive.Jit, [
|
|
4189
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4190
|
+
...primals,
|
|
4191
|
+
...tangents
|
|
4192
|
+
], {
|
|
4193
|
+
name: `${name}_jvp`,
|
|
4194
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4195
|
+
numConsts: newJaxpr.consts.length
|
|
4196
|
+
});
|
|
4197
|
+
const n = outs.length / 2;
|
|
4198
|
+
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
4199
|
+
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
4200
|
+
return [primalsOut, tangentsOut];
|
|
4201
|
+
}
|
|
4202
|
+
};
|
|
4203
|
+
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
4204
|
+
function jvpJaxpr(jaxpr) {
|
|
4205
|
+
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
4206
|
+
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
4207
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
4208
|
+
jvpJaxprCache.set(jaxpr, newJaxpr);
|
|
4209
|
+
return newJaxpr;
|
|
4210
|
+
}
|
|
4211
|
+
function jvpFlat(f, primals, tangents) {
|
|
4212
|
+
try {
|
|
4213
|
+
var _usingCtx$1 = _usingCtx();
|
|
4214
|
+
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
4215
|
+
const trace$1 = new JVPTrace(main);
|
|
4216
|
+
const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
4217
|
+
const outs = f(...tracersIn);
|
|
4218
|
+
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
4219
|
+
return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
4220
|
+
} catch (_) {
|
|
4221
|
+
_usingCtx$1.e = _;
|
|
4222
|
+
} finally {
|
|
4223
|
+
_usingCtx$1.d();
|
|
4224
|
+
}
|
|
4225
|
+
}
|
|
4226
|
+
function jvp$1(f, primals, tangents) {
|
|
4227
|
+
const [primalsFlat, inTree] = flatten(primals);
|
|
4228
|
+
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4229
|
+
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4230
|
+
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
4231
|
+
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4232
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4233
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4234
|
+
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4235
|
+
return [primalsOut, tangentsOut];
|
|
3709
4236
|
}
|
|
3710
4237
|
|
|
3711
4238
|
//#endregion
|
|
@@ -3738,11 +4265,10 @@ function partialEvalFlat(f, pvalsIn) {
|
|
|
3738
4265
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3739
4266
|
const pvalsOut = tracersOut.map((t) => t.pval);
|
|
3740
4267
|
const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
|
|
3741
|
-
const
|
|
4268
|
+
const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
|
|
3742
4269
|
return {
|
|
3743
4270
|
jaxpr,
|
|
3744
|
-
pvalsOut
|
|
3745
|
-
consts
|
|
4271
|
+
pvalsOut
|
|
3746
4272
|
};
|
|
3747
4273
|
}
|
|
3748
4274
|
/**
|
|
@@ -3759,22 +4285,19 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3759
4285
|
const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
|
|
3760
4286
|
return [...primalsOut$1, ...tangentsOut];
|
|
3761
4287
|
};
|
|
3762
|
-
const { jaxpr, pvalsOut
|
|
4288
|
+
const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
|
|
3763
4289
|
const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
|
|
3764
4290
|
if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
|
|
3765
4291
|
const primalsOut = primalPvals.map((pval) => pval.val);
|
|
3766
4292
|
return {
|
|
3767
4293
|
primalsOut,
|
|
3768
|
-
jaxpr
|
|
3769
|
-
consts
|
|
4294
|
+
jaxpr
|
|
3770
4295
|
};
|
|
3771
4296
|
}
|
|
3772
4297
|
function linearizeFlat(f, primalsIn) {
|
|
3773
|
-
const { primalsOut, jaxpr
|
|
3774
|
-
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3775
|
-
const dispose$1 = () =>
|
|
3776
|
-
for (const c of consts) c.dispose();
|
|
3777
|
-
};
|
|
4298
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4299
|
+
const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
|
|
4300
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
3778
4301
|
return [
|
|
3779
4302
|
primalsOut,
|
|
3780
4303
|
fLin,
|
|
@@ -3858,7 +4381,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3858
4381
|
}
|
|
3859
4382
|
processPrimitive(primitive, tracers, params) {
|
|
3860
4383
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3861
|
-
if (primitive === Primitive.
|
|
4384
|
+
if (primitive === Primitive.Jit) {
|
|
3862
4385
|
const { name, jaxpr, numConsts } = params;
|
|
3863
4386
|
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3864
4387
|
}
|
|
@@ -3884,14 +4407,14 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3884
4407
|
* Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
|
|
3885
4408
|
* values as possible (with JIT) and forwarding the unknown ones.
|
|
3886
4409
|
*
|
|
3887
|
-
* Used when encountering a
|
|
4410
|
+
* Used when encountering a Jit rule during the trace.
|
|
3888
4411
|
*/
|
|
3889
4412
|
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3890
4413
|
jaxpr = jaxpr.flatten();
|
|
3891
4414
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3892
4415
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3893
4416
|
const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
|
|
3894
|
-
const outs1Res = bind(Primitive.
|
|
4417
|
+
const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3895
4418
|
name: `${name}_peval`,
|
|
3896
4419
|
jaxpr: jaxpr1,
|
|
3897
4420
|
numConsts: 0
|
|
@@ -3901,7 +4424,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3901
4424
|
const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
|
|
3902
4425
|
const recipe = {
|
|
3903
4426
|
type: "JaxprEqn",
|
|
3904
|
-
prim: Primitive.
|
|
4427
|
+
prim: Primitive.Jit,
|
|
3905
4428
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3906
4429
|
params: {
|
|
3907
4430
|
name: `${name}_resid`,
|
|
@@ -3930,7 +4453,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
|
|
|
3930
4453
|
const eqns1 = [];
|
|
3931
4454
|
const eqns2 = [];
|
|
3932
4455
|
for (const eqn of jaxpr.eqns) {
|
|
3933
|
-
if (eqn.primitive === Primitive.
|
|
4456
|
+
if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
|
|
3934
4457
|
const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
|
|
3935
4458
|
if (hasUnknowns) {
|
|
3936
4459
|
for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
|
|
@@ -4005,10 +4528,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
4005
4528
|
for (const t of tracersOut) t.dispose();
|
|
4006
4529
|
jaxpr = jaxpr.simplify();
|
|
4007
4530
|
if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
4008
|
-
return
|
|
4009
|
-
jaxpr,
|
|
4010
|
-
consts
|
|
4011
|
-
};
|
|
4531
|
+
return new ClosedJaxpr(jaxpr, consts);
|
|
4012
4532
|
}
|
|
4013
4533
|
/** Marker type for pullback, used by transpose rules. */
|
|
4014
4534
|
var UndefPrimal = class {
|
|
@@ -4200,317 +4720,151 @@ const transposeRules = {
|
|
|
4200
4720
|
cond.dispose();
|
|
4201
4721
|
return cts;
|
|
4202
4722
|
},
|
|
4203
|
-
[Primitive.
|
|
4204
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.
|
|
4205
|
-
|
|
4206
|
-
|
|
4207
|
-
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4208
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4209
|
-
return [reduce(ct, AluOp.Add, axis)];
|
|
4210
|
-
},
|
|
4211
|
-
[Primitive.Reshape]([ct], [x], _) {
|
|
4212
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4213
|
-
return [reshape$1(ct, x.aval.shape)];
|
|
4214
|
-
},
|
|
4215
|
-
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4216
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4217
|
-
return [flip$1(ct, axis)];
|
|
4218
|
-
},
|
|
4219
|
-
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4220
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4221
|
-
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4222
|
-
return [pad$1(ct, width)];
|
|
4723
|
+
[Primitive.Concatenate]([ct], inputs, { axis }) {
|
|
4724
|
+
if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
|
|
4725
|
+
const sizes = inputs.map((x) => x.aval.shape[axis]);
|
|
4726
|
+
return split$2(ct, axis, sizes);
|
|
4223
4727
|
},
|
|
4224
|
-
[Primitive.
|
|
4225
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.
|
|
4226
|
-
|
|
4227
|
-
return [shrink(ct, slice)];
|
|
4728
|
+
[Primitive.Split](cts, [x], { axis }) {
|
|
4729
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
|
|
4730
|
+
return [concatenate$1(cts, axis)];
|
|
4228
4731
|
},
|
|
4229
4732
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4230
4733
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4231
4734
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4232
|
-
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4233
|
-
},
|
|
4234
|
-
[Primitive.
|
|
4235
|
-
|
|
4236
|
-
|
|
4237
|
-
|
|
4238
|
-
|
|
4239
|
-
|
|
4240
|
-
|
|
4241
|
-
|
|
4242
|
-
|
|
4243
|
-
|
|
4244
|
-
|
|
4245
|
-
|
|
4246
|
-
|
|
4247
|
-
|
|
4248
|
-
return
|
|
4249
|
-
}
|
|
4250
|
-
}
|
|
4251
|
-
|
|
4252
|
-
|
|
4253
|
-
|
|
4254
|
-
|
|
4255
|
-
|
|
4256
|
-
|
|
4257
|
-
|
|
4258
|
-
|
|
4259
|
-
|
|
4260
|
-
|
|
4261
|
-
|
|
4262
|
-
|
|
4263
|
-
|
|
4264
|
-
|
|
4265
|
-
typecheckJaxpr(newJaxpr);
|
|
4266
|
-
const result = {
|
|
4267
|
-
newJaxpr,
|
|
4268
|
-
newConsts
|
|
4269
|
-
};
|
|
4270
|
-
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4271
|
-
transposeJaxprCache.get(jaxpr).set(cacheKey, result);
|
|
4272
|
-
return result;
|
|
4273
|
-
}
|
|
4274
|
-
function vjpFlat(f, primalsIn) {
|
|
4275
|
-
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
4276
|
-
const fVjp = (...cotangents) => {
|
|
4277
|
-
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4278
|
-
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
4279
|
-
};
|
|
4280
|
-
const dispose$1 = () => {
|
|
4281
|
-
for (const c of consts) c.dispose();
|
|
4282
|
-
};
|
|
4283
|
-
return [
|
|
4284
|
-
primalsOut,
|
|
4285
|
-
fVjp,
|
|
4286
|
-
dispose$1
|
|
4287
|
-
];
|
|
4288
|
-
}
|
|
4289
|
-
function vjp$1(f, ...primalsIn) {
|
|
4290
|
-
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4291
|
-
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4292
|
-
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4293
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4294
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4295
|
-
const fVjp = ((cotangentsOut) => {
|
|
4296
|
-
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4297
|
-
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4298
|
-
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4299
|
-
return unflatten(inTree, cotangentsInFlat);
|
|
4300
|
-
});
|
|
4301
|
-
fVjp.dispose = dispose$1;
|
|
4302
|
-
return [primalsOut, fVjp];
|
|
4303
|
-
}
|
|
4304
|
-
function grad$1(f) {
|
|
4305
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4306
|
-
return (...x) => {
|
|
4307
|
-
const [y, dx] = valueAndGradFn(...x);
|
|
4308
|
-
y.dispose();
|
|
4309
|
-
return dx;
|
|
4310
|
-
};
|
|
4311
|
-
}
|
|
4312
|
-
function valueAndGrad$1(f) {
|
|
4313
|
-
return (...x) => {
|
|
4314
|
-
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4315
|
-
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4316
|
-
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4317
|
-
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4318
|
-
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4319
|
-
for (const r of rest) dispose(r);
|
|
4320
|
-
fVjp.dispose();
|
|
4321
|
-
return [y, ct];
|
|
4322
|
-
};
|
|
4323
|
-
}
|
|
4324
|
-
function jacrev$1(f) {
|
|
4325
|
-
return function jacobianReverse(x) {
|
|
4326
|
-
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4327
|
-
const [size$1] = x.shape;
|
|
4328
|
-
const pullback = (ct) => {
|
|
4329
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4330
|
-
y.dispose();
|
|
4331
|
-
const [ret] = fVjp(ct);
|
|
4332
|
-
fVjp.dispose();
|
|
4333
|
-
return ret;
|
|
4334
|
-
};
|
|
4335
|
-
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4336
|
-
};
|
|
4337
|
-
}
|
|
4338
|
-
|
|
4339
|
-
//#endregion
|
|
4340
|
-
//#region src/library/lax.ts
|
|
4341
|
-
var lax_exports = {};
|
|
4342
|
-
__export(lax_exports, {
|
|
4343
|
-
conv: () => conv,
|
|
4344
|
-
convGeneralDilated: () => convGeneralDilated,
|
|
4345
|
-
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4346
|
-
dot: () => dot$1,
|
|
4347
|
-
erf: () => erf,
|
|
4348
|
-
erfc: () => erfc,
|
|
4349
|
-
reduceWindow: () => reduceWindow,
|
|
4350
|
-
stopGradient: () => stopGradient$1
|
|
4351
|
-
});
|
|
4352
|
-
/**
|
|
4353
|
-
* General dot product/contraction operator.
|
|
4354
|
-
*
|
|
4355
|
-
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
4356
|
-
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
4357
|
-
*/
|
|
4358
|
-
function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
4359
|
-
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
4360
|
-
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
4361
|
-
lc = lc.map((a) => checkAxis(a, lhs.ndim));
|
|
4362
|
-
rc = rc.map((a) => checkAxis(a, rhs.ndim));
|
|
4363
|
-
lb = lb.map((a) => checkAxis(a, lhs.ndim));
|
|
4364
|
-
rb = rb.map((a) => checkAxis(a, rhs.ndim));
|
|
4365
|
-
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
4366
|
-
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
4367
|
-
const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
4368
|
-
const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
4369
|
-
const lhs2 = lhs.transpose([
|
|
4370
|
-
...lb,
|
|
4371
|
-
...lf,
|
|
4372
|
-
...lc
|
|
4373
|
-
]);
|
|
4374
|
-
const rhs2 = rhs.transpose([
|
|
4375
|
-
...rb,
|
|
4376
|
-
...rf,
|
|
4377
|
-
...rc
|
|
4378
|
-
]);
|
|
4379
|
-
if (lc.length === 0) return mul(lhs2.reshape([
|
|
4380
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4381
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4382
|
-
...rep(rf.length, 1)
|
|
4383
|
-
]), rhs2.reshape([
|
|
4384
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4385
|
-
...rep(lf.length, 1),
|
|
4386
|
-
...rf.map((a) => rhs.shape[a])
|
|
4387
|
-
]));
|
|
4388
|
-
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
4389
|
-
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
4390
|
-
if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
4391
|
-
return dot$2(lhs2.reshape([
|
|
4392
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4393
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4394
|
-
...rep(rf.length, 1),
|
|
4395
|
-
prod(dotShapeX)
|
|
4396
|
-
]), rhs2.reshape([
|
|
4397
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4398
|
-
...rep(lf.length, 1),
|
|
4399
|
-
...rf.map((a) => rhs.shape[a]),
|
|
4400
|
-
prod(dotShapeY)
|
|
4401
|
-
]));
|
|
4402
|
-
}
|
|
4403
|
-
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4404
|
-
const padType = padding.toUpperCase();
|
|
4405
|
-
switch (padType) {
|
|
4406
|
-
case "VALID": return rep(inShape.length, [0, 0]);
|
|
4407
|
-
case "SAME":
|
|
4408
|
-
case "SAME_LOWER": {
|
|
4409
|
-
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
4410
|
-
const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
4411
|
-
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
4412
|
-
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
4413
|
-
}
|
|
4414
|
-
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
4415
|
-
}
|
|
4416
|
-
}
|
|
4417
|
-
/**
|
|
4418
|
-
* General n-dimensional convolution operator, with optional dilation.
|
|
4419
|
-
*
|
|
4420
|
-
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
4421
|
-
* function in JAX, which wraps XLA's general convolution operator.
|
|
4422
|
-
*
|
|
4423
|
-
* Grouped convolutions are not supported right now.
|
|
4424
|
-
*/
|
|
4425
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
4426
|
-
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4427
|
-
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4428
|
-
if (typeof padding === "string") {
|
|
4429
|
-
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4430
|
-
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
4431
|
-
}
|
|
4432
|
-
if (featureGroupCount !== 1) {
|
|
4433
|
-
const G = featureGroupCount;
|
|
4434
|
-
const [N, C_in, ...xs] = lhs.shape;
|
|
4435
|
-
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
4436
|
-
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
4437
|
-
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
4438
|
-
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
4439
|
-
const lhsGrouped = moveaxis(lhs.reshape([
|
|
4440
|
-
N,
|
|
4441
|
-
G,
|
|
4442
|
-
C_in / G,
|
|
4443
|
-
...xs
|
|
4444
|
-
]), 1, 0);
|
|
4445
|
-
const rhsGrouped = rhs.reshape([
|
|
4446
|
-
G,
|
|
4447
|
-
C_out / G,
|
|
4448
|
-
C_in_per_group,
|
|
4449
|
-
...ks
|
|
4450
|
-
]);
|
|
4451
|
-
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
4452
|
-
vmapDims: 1,
|
|
4453
|
-
strides: windowStrides,
|
|
4454
|
-
padding,
|
|
4455
|
-
lhsDilation,
|
|
4456
|
-
rhsDilation
|
|
4735
|
+
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4736
|
+
},
|
|
4737
|
+
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4738
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4739
|
+
return [transpose$1(ct, invertPermutation(perm))];
|
|
4740
|
+
},
|
|
4741
|
+
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4742
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4743
|
+
return [reduce(ct, AluOp.Add, axis)];
|
|
4744
|
+
},
|
|
4745
|
+
[Primitive.Reshape]([ct], [x], _) {
|
|
4746
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4747
|
+
return [reshape$1(ct, x.aval.shape)];
|
|
4748
|
+
},
|
|
4749
|
+
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4750
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4751
|
+
return [flip$1(ct, axis)];
|
|
4752
|
+
},
|
|
4753
|
+
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4754
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4755
|
+
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4756
|
+
return [pad$1(ct, width)];
|
|
4757
|
+
},
|
|
4758
|
+
[Primitive.Pad]([ct], [x], { width }) {
|
|
4759
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
|
|
4760
|
+
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4761
|
+
return [shrink(ct, slice)];
|
|
4762
|
+
},
|
|
4763
|
+
[Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
|
|
4764
|
+
if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
|
|
4765
|
+
const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
|
|
4766
|
+
lower: true,
|
|
4767
|
+
unitDiagonal
|
|
4457
4768
|
});
|
|
4458
|
-
|
|
4459
|
-
|
|
4460
|
-
|
|
4461
|
-
|
|
4462
|
-
|
|
4463
|
-
]);
|
|
4769
|
+
return [null, ctB];
|
|
4770
|
+
},
|
|
4771
|
+
[Primitive.Jit](cts, args, { name, jaxpr }) {
|
|
4772
|
+
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
4773
|
+
const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
|
|
4774
|
+
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
4775
|
+
const outs = bind(Primitive.Jit, [
|
|
4776
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4777
|
+
...residuals,
|
|
4778
|
+
...cts
|
|
4779
|
+
], {
|
|
4780
|
+
name: `${name}_t`,
|
|
4781
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4782
|
+
numConsts: newJaxpr.consts.length
|
|
4783
|
+
});
|
|
4784
|
+
let i = 0;
|
|
4785
|
+
return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
|
|
4464
4786
|
}
|
|
4465
|
-
|
|
4466
|
-
|
|
4467
|
-
|
|
4468
|
-
|
|
4469
|
-
|
|
4470
|
-
|
|
4471
|
-
}
|
|
4472
|
-
|
|
4473
|
-
|
|
4474
|
-
|
|
4475
|
-
|
|
4476
|
-
|
|
4477
|
-
|
|
4787
|
+
};
|
|
4788
|
+
const transposeJaxprCache = /* @__PURE__ */ new Map();
|
|
4789
|
+
function transposeJaxpr(jaxpr, undefPrimals) {
|
|
4790
|
+
const cacheKey = JSON.stringify(undefPrimals);
|
|
4791
|
+
const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
|
|
4792
|
+
if (prevResult) return prevResult;
|
|
4793
|
+
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
4794
|
+
const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
|
|
4795
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
|
|
4796
|
+
const args = [];
|
|
4797
|
+
let forwardInIdx = 0;
|
|
4798
|
+
for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
|
|
4799
|
+
else args.push(forwardIn[forwardInIdx++]);
|
|
4800
|
+
return evalJaxprTransposed(jaxpr, args, cotangents);
|
|
4801
|
+
})(forwardInTypes, outTypes);
|
|
4802
|
+
typecheckJaxpr(newJaxpr.jaxpr);
|
|
4803
|
+
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4804
|
+
transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
4805
|
+
return newJaxpr;
|
|
4478
4806
|
}
|
|
4479
|
-
|
|
4480
|
-
|
|
4481
|
-
|
|
4807
|
+
function vjpFlat(f, primalsIn) {
|
|
4808
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4809
|
+
const fVjp = (...cotangents) => {
|
|
4810
|
+
const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4811
|
+
return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
|
|
4812
|
+
};
|
|
4813
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
4814
|
+
return [
|
|
4815
|
+
primalsOut,
|
|
4816
|
+
fVjp,
|
|
4817
|
+
dispose$1
|
|
4818
|
+
];
|
|
4482
4819
|
}
|
|
4483
|
-
|
|
4484
|
-
|
|
4485
|
-
|
|
4486
|
-
|
|
4487
|
-
|
|
4488
|
-
|
|
4489
|
-
|
|
4490
|
-
|
|
4491
|
-
|
|
4820
|
+
function vjp$1(f, ...primalsIn) {
|
|
4821
|
+
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4822
|
+
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4823
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4824
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4825
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4826
|
+
const fVjp = ((cotangentsOut) => {
|
|
4827
|
+
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4828
|
+
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4829
|
+
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4830
|
+
return unflatten(inTree, cotangentsInFlat);
|
|
4831
|
+
});
|
|
4832
|
+
fVjp.dispose = dispose$1;
|
|
4833
|
+
return [primalsOut, fVjp];
|
|
4492
4834
|
}
|
|
4493
|
-
|
|
4494
|
-
|
|
4495
|
-
return
|
|
4835
|
+
function grad$1(f) {
|
|
4836
|
+
const valueAndGradFn = valueAndGrad$1(f);
|
|
4837
|
+
return (...x) => {
|
|
4838
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4839
|
+
y.dispose();
|
|
4840
|
+
return dx;
|
|
4841
|
+
};
|
|
4496
4842
|
}
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
|
|
4501
|
-
|
|
4502
|
-
|
|
4503
|
-
|
|
4504
|
-
|
|
4843
|
+
function valueAndGrad$1(f) {
|
|
4844
|
+
return (...x) => {
|
|
4845
|
+
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4846
|
+
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4847
|
+
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4848
|
+
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4849
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4850
|
+
for (const r of rest) dispose(r);
|
|
4851
|
+
fVjp.dispose();
|
|
4852
|
+
return [y, ct];
|
|
4853
|
+
};
|
|
4505
4854
|
}
|
|
4506
|
-
|
|
4507
|
-
|
|
4508
|
-
|
|
4509
|
-
|
|
4510
|
-
|
|
4511
|
-
|
|
4512
|
-
|
|
4513
|
-
|
|
4855
|
+
function jacrev$1(f) {
|
|
4856
|
+
return function jacobianReverse(x) {
|
|
4857
|
+
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4858
|
+
const [size$1] = x.shape;
|
|
4859
|
+
const pullback = (ct) => {
|
|
4860
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
4861
|
+
y.dispose();
|
|
4862
|
+
const [ret] = fVjp(ct);
|
|
4863
|
+
fVjp.dispose();
|
|
4864
|
+
return ret;
|
|
4865
|
+
};
|
|
4866
|
+
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4867
|
+
};
|
|
4514
4868
|
}
|
|
4515
4869
|
|
|
4516
4870
|
//#endregion
|
|
@@ -4650,8 +5004,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4650
5004
|
const idx = lhsIndex[j];
|
|
4651
5005
|
const dim = shape$1[j];
|
|
4652
5006
|
const existing = sizeMap.get(idx);
|
|
4653
|
-
if (existing === void 0) sizeMap.set(idx, dim);
|
|
4654
|
-
else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
5007
|
+
if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
|
|
5008
|
+
else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
4655
5009
|
}
|
|
4656
5010
|
}
|
|
4657
5011
|
for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
|
|
@@ -4659,52 +5013,410 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4659
5013
|
for (const idx of rhsIndex) if (!sizeMap.has(idx)) throw new Error(`Output index ${idx} not present in einsum inputs`);
|
|
4660
5014
|
return sizeMap;
|
|
4661
5015
|
}
|
|
4662
|
-
const einsumPathCache = /* @__PURE__ */ new Map();
|
|
4663
|
-
function computeEinsumPath(input, method) {
|
|
4664
|
-
if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
|
|
4665
|
-
return runWithCache(einsumPathCache, [input, method], () => {
|
|
4666
|
-
const sizeMap = computeSizeMap(input);
|
|
4667
|
-
if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
|
|
4668
|
-
switch (method) {
|
|
4669
|
-
case "naive": return computePathNaive(input, sizeMap);
|
|
4670
|
-
case "optimal": return computePathOptimal(input, sizeMap);
|
|
4671
|
-
default: throw new Error(`Unknown computePath method: ${method}`);
|
|
4672
|
-
}
|
|
4673
|
-
});
|
|
5016
|
+
const einsumPathCache = /* @__PURE__ */ new Map();
|
|
5017
|
+
function computeEinsumPath(input, method) {
|
|
5018
|
+
if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
|
|
5019
|
+
return runWithCache(einsumPathCache, [input, method], () => {
|
|
5020
|
+
const sizeMap = computeSizeMap(input);
|
|
5021
|
+
if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
|
|
5022
|
+
switch (method) {
|
|
5023
|
+
case "naive": return computePathNaive(input, sizeMap);
|
|
5024
|
+
case "optimal": return computePathOptimal(input, sizeMap);
|
|
5025
|
+
default: throw new Error(`Unknown computePath method: ${method}`);
|
|
5026
|
+
}
|
|
5027
|
+
});
|
|
5028
|
+
}
|
|
5029
|
+
function computePathNaive(input, sizeMap) {
|
|
5030
|
+
const n = input.shapes.length;
|
|
5031
|
+
const path = [];
|
|
5032
|
+
let lastTensorIndex = 0;
|
|
5033
|
+
for (let i = 1; i < n; i++) {
|
|
5034
|
+
path.push([lastTensorIndex, i]);
|
|
5035
|
+
lastTensorIndex = n + i - 1;
|
|
5036
|
+
}
|
|
5037
|
+
return new EinsumPath(input, sizeMap, path);
|
|
5038
|
+
}
|
|
5039
|
+
function computePathOptimal(input, sizeMap) {
|
|
5040
|
+
const n = input.shapes.length;
|
|
5041
|
+
let bestPath = null;
|
|
5042
|
+
let bestFlops = null;
|
|
5043
|
+
for (const path of allPaths(range(n), n)) {
|
|
5044
|
+
const flops = approximatePathFlops(input, sizeMap, path);
|
|
5045
|
+
if (bestFlops === null || flops < bestFlops) {
|
|
5046
|
+
bestPath = path;
|
|
5047
|
+
bestFlops = flops;
|
|
5048
|
+
}
|
|
5049
|
+
}
|
|
5050
|
+
return new EinsumPath(input, sizeMap, bestPath);
|
|
5051
|
+
}
|
|
5052
|
+
function* allPaths(tensors, next) {
|
|
5053
|
+
if (tensors.length === 2) {
|
|
5054
|
+
yield [[tensors[0], tensors[1]]];
|
|
5055
|
+
return;
|
|
5056
|
+
}
|
|
5057
|
+
for (let i = 0; i < tensors.length; i++) for (let j = i + 1; j < tensors.length; j++) {
|
|
5058
|
+
const pair = [tensors[i], tensors[j]];
|
|
5059
|
+
const newTensors = tensors.filter((t) => t !== pair[0] && t !== pair[1]);
|
|
5060
|
+
newTensors.push(next);
|
|
5061
|
+
for (const subpath of allPaths(newTensors, next + 1)) yield [pair, ...subpath];
|
|
5062
|
+
}
|
|
5063
|
+
}
|
|
5064
|
+
|
|
5065
|
+
//#endregion
|
|
5066
|
+
//#region src/library/numpy-fft.ts
|
|
5067
|
+
var numpy_fft_exports = {};
|
|
5068
|
+
__export(numpy_fft_exports, {
|
|
5069
|
+
fft: () => fft,
|
|
5070
|
+
ifft: () => ifft
|
|
5071
|
+
});
|
|
5072
|
+
function checkPairInput(name, a) {
|
|
5073
|
+
const fullName = `jax.numpy.fft.${name}`;
|
|
5074
|
+
if (!deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
|
|
5075
|
+
if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
|
|
5076
|
+
if (!isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
|
|
5077
|
+
}
|
|
5078
|
+
function checkPowerOfTwo(name, n) {
|
|
5079
|
+
if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
|
|
5080
|
+
}
|
|
5081
|
+
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
|
|
5082
|
+
const half = 2 ** i;
|
|
5083
|
+
real = real.reshape([-1, 2 * half]);
|
|
5084
|
+
imag = imag.reshape([-1, 2 * half]);
|
|
5085
|
+
const k = arange(0, half, 1, { dtype: real.dtype });
|
|
5086
|
+
const theta = k.mul(-Math.PI / half);
|
|
5087
|
+
const wr = cos(theta.ref);
|
|
5088
|
+
const wi = sin(theta);
|
|
5089
|
+
const ur = real.ref.slice([], [0, half]);
|
|
5090
|
+
const ui = imag.ref.slice([], [0, half]);
|
|
5091
|
+
const vr = real.slice([], [half, 2 * half]);
|
|
5092
|
+
const vi = imag.slice([], [half, 2 * half]);
|
|
5093
|
+
const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
|
|
5094
|
+
const ti = vr.mul(wi).add(vi.mul(wr));
|
|
5095
|
+
return {
|
|
5096
|
+
real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
|
|
5097
|
+
imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
|
|
5098
|
+
};
|
|
5099
|
+
}, { staticArgnums: [0] });
|
|
5100
|
+
/**
|
|
5101
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
5102
|
+
*
|
|
5103
|
+
* Currently, the size of the axis must be a power of two.
|
|
5104
|
+
*/
|
|
5105
|
+
function fft(a, axis = -1) {
|
|
5106
|
+
checkPairInput("fft", a);
|
|
5107
|
+
let { real, imag } = a;
|
|
5108
|
+
axis = checkAxis(axis, real.ndim);
|
|
5109
|
+
const n = real.shape[axis];
|
|
5110
|
+
checkPowerOfTwo("fft", n);
|
|
5111
|
+
const logN = Math.log2(n);
|
|
5112
|
+
let perm = null;
|
|
5113
|
+
if (axis !== real.ndim - 1) {
|
|
5114
|
+
perm = range(real.ndim);
|
|
5115
|
+
perm.splice(axis, 1);
|
|
5116
|
+
perm.push(axis);
|
|
5117
|
+
real = real.transpose(perm);
|
|
5118
|
+
imag = imag.transpose(perm);
|
|
5119
|
+
}
|
|
5120
|
+
const originalShape = real.shape;
|
|
5121
|
+
real = real.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
|
|
5122
|
+
imag = imag.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
|
|
5123
|
+
for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
|
|
5124
|
+
real,
|
|
5125
|
+
imag
|
|
5126
|
+
}));
|
|
5127
|
+
real = real.reshape(originalShape);
|
|
5128
|
+
imag = imag.reshape(originalShape);
|
|
5129
|
+
if (perm !== null) {
|
|
5130
|
+
real = real.transpose(invertPermutation(perm));
|
|
5131
|
+
imag = imag.transpose(invertPermutation(perm));
|
|
5132
|
+
}
|
|
5133
|
+
return {
|
|
5134
|
+
real,
|
|
5135
|
+
imag
|
|
5136
|
+
};
|
|
5137
|
+
}
|
|
5138
|
+
/**
|
|
5139
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
5140
|
+
*
|
|
5141
|
+
* Currently, the size of the axis must be a power of two.
|
|
5142
|
+
*/
|
|
5143
|
+
function ifft(a, axis = -1) {
|
|
5144
|
+
checkPairInput("ifft", a);
|
|
5145
|
+
let { real, imag } = a;
|
|
5146
|
+
axis = checkAxis(axis, real.ndim);
|
|
5147
|
+
const n = real.shape[axis];
|
|
5148
|
+
checkPowerOfTwo("ifft", n);
|
|
5149
|
+
imag = imag.mul(-1);
|
|
5150
|
+
const result = fft({
|
|
5151
|
+
real,
|
|
5152
|
+
imag
|
|
5153
|
+
}, axis);
|
|
5154
|
+
return {
|
|
5155
|
+
real: result.real.div(n),
|
|
5156
|
+
imag: result.imag.mul(-1).div(n)
|
|
5157
|
+
};
|
|
5158
|
+
}
|
|
5159
|
+
|
|
5160
|
+
//#endregion
|
|
5161
|
+
//#region src/library/numpy-linalg.ts
|
|
5162
|
+
var numpy_linalg_exports = {};
|
|
5163
|
+
__export(numpy_linalg_exports, {
|
|
5164
|
+
cholesky: () => cholesky,
|
|
5165
|
+
det: () => det,
|
|
5166
|
+
diagonal: () => diagonal,
|
|
5167
|
+
inv: () => inv,
|
|
5168
|
+
lstsq: () => lstsq,
|
|
5169
|
+
matmul: () => matmul,
|
|
5170
|
+
matrixPower: () => matrixPower,
|
|
5171
|
+
matrixTranspose: () => matrixTranspose,
|
|
5172
|
+
outer: () => outer,
|
|
5173
|
+
slogdet: () => slogdet,
|
|
5174
|
+
solve: () => solve,
|
|
5175
|
+
tensordot: () => tensordot,
|
|
5176
|
+
trace: () => trace,
|
|
5177
|
+
vecdot: () => vecdot
|
|
5178
|
+
});
|
|
5179
|
+
function checkSquare(name, a) {
|
|
5180
|
+
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
5181
|
+
return a.shape[a.ndim - 1];
|
|
5182
|
+
}
|
|
5183
|
+
/**
|
|
5184
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
5185
|
+
*
|
|
5186
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
5187
|
+
* the input matrix, which is on by default.
|
|
5188
|
+
*/
|
|
5189
|
+
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
5190
|
+
a = fudgeArray(a);
|
|
5191
|
+
checkSquare("cholesky", a);
|
|
5192
|
+
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5193
|
+
return cholesky$1(a, { upper });
|
|
5194
|
+
}
|
|
5195
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
5196
|
+
function det(a) {
|
|
5197
|
+
a = fudgeArray(a);
|
|
5198
|
+
const n = checkSquare("det", a);
|
|
5199
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5200
|
+
permutation.dispose();
|
|
5201
|
+
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5202
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5203
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5204
|
+
return prod$1(diag$1, -1).mul(sign$1);
|
|
5205
|
+
}
|
|
5206
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
5207
|
+
function inv(a) {
|
|
5208
|
+
a = fudgeArray(a);
|
|
5209
|
+
const n = checkSquare("inv", a);
|
|
5210
|
+
return solve(a, eye(n));
|
|
5211
|
+
}
|
|
5212
|
+
/**
|
|
5213
|
+
* Return the least-squares solution to a linear equation.
|
|
5214
|
+
*
|
|
5215
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
5216
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
5217
|
+
*
|
|
5218
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
5219
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
5220
|
+
*
|
|
5221
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
5222
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
5223
|
+
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
5224
|
+
*/
|
|
5225
|
+
function lstsq(a, b) {
|
|
5226
|
+
a = fudgeArray(a);
|
|
5227
|
+
b = fudgeArray(b);
|
|
5228
|
+
if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
|
|
5229
|
+
const [m, n] = a.shape;
|
|
5230
|
+
if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
|
|
5231
|
+
const at = matrixTranspose(a.ref);
|
|
5232
|
+
if (m <= n) {
|
|
5233
|
+
const aat = matmul(a, at.ref);
|
|
5234
|
+
const l = cholesky(aat, { symmetrizeInput: false });
|
|
5235
|
+
const lb = triangularSolve(l.ref, b, {
|
|
5236
|
+
leftSide: true,
|
|
5237
|
+
lower: true
|
|
5238
|
+
});
|
|
5239
|
+
const llb = triangularSolve(l, lb, {
|
|
5240
|
+
leftSide: true,
|
|
5241
|
+
transposeA: true
|
|
5242
|
+
});
|
|
5243
|
+
return matmul(at, llb.ref);
|
|
5244
|
+
} else {
|
|
5245
|
+
const ata = matmul(at.ref, a);
|
|
5246
|
+
const l = cholesky(ata, { symmetrizeInput: false });
|
|
5247
|
+
const atb = matmul(at, b);
|
|
5248
|
+
const lb = triangularSolve(l.ref, atb, {
|
|
5249
|
+
leftSide: true,
|
|
5250
|
+
lower: true
|
|
5251
|
+
});
|
|
5252
|
+
const llb = triangularSolve(l, lb, {
|
|
5253
|
+
leftSide: true,
|
|
5254
|
+
transposeA: true
|
|
5255
|
+
});
|
|
5256
|
+
return llb;
|
|
5257
|
+
}
|
|
5258
|
+
}
|
|
5259
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
5260
|
+
function matrixPower(a, n) {
|
|
5261
|
+
if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
|
|
5262
|
+
a = fudgeArray(a);
|
|
5263
|
+
const m = checkSquare("matrixPower", a);
|
|
5264
|
+
if (n === 0) {
|
|
5265
|
+
a.dispose();
|
|
5266
|
+
return broadcastTo(eye(m), a.shape);
|
|
5267
|
+
}
|
|
5268
|
+
if (n < 0) {
|
|
5269
|
+
a = inv(a);
|
|
5270
|
+
n = -n;
|
|
5271
|
+
}
|
|
5272
|
+
let result = null;
|
|
5273
|
+
let a2k = a;
|
|
5274
|
+
for (let k = 0; n; k++) {
|
|
5275
|
+
if (k > 0) a2k = matmul(a2k.ref, a2k);
|
|
5276
|
+
if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
|
|
5277
|
+
n = Math.floor(n / 2);
|
|
5278
|
+
}
|
|
5279
|
+
a2k.dispose();
|
|
5280
|
+
return result;
|
|
5281
|
+
}
|
|
5282
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
5283
|
+
function slogdet(a) {
|
|
5284
|
+
a = fudgeArray(a);
|
|
5285
|
+
const n = checkSquare("slogdet", a);
|
|
5286
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5287
|
+
permutation.dispose();
|
|
5288
|
+
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5289
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5290
|
+
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
5291
|
+
const logabsdet = log(absolute(diag$1)).sum(-1);
|
|
5292
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5293
|
+
return [sign$1, logabsdet];
|
|
4674
5294
|
}
|
|
4675
|
-
|
|
4676
|
-
|
|
4677
|
-
|
|
4678
|
-
|
|
4679
|
-
|
|
4680
|
-
|
|
4681
|
-
|
|
4682
|
-
|
|
4683
|
-
|
|
5295
|
+
/**
|
|
5296
|
+
* Solve a linear system of equations.
|
|
5297
|
+
*
|
|
5298
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
5299
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
5300
|
+
*
|
|
5301
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
5302
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
5303
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
5304
|
+
*/
|
|
5305
|
+
function solve(a, b) {
|
|
5306
|
+
a = fudgeArray(a);
|
|
5307
|
+
b = fudgeArray(b);
|
|
5308
|
+
const n = checkSquare("solve", a);
|
|
5309
|
+
if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
|
|
5310
|
+
const bIs1d = b.ndim === 1;
|
|
5311
|
+
if (bIs1d) b = b.reshape([...b.shape, 1]);
|
|
5312
|
+
if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
|
|
5313
|
+
const m = b.shape[b.ndim - 1];
|
|
5314
|
+
const batchDims = generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
|
|
5315
|
+
a = broadcastTo(a, [
|
|
5316
|
+
...batchDims,
|
|
5317
|
+
n,
|
|
5318
|
+
n
|
|
5319
|
+
]);
|
|
5320
|
+
b = broadcastTo(b, [
|
|
5321
|
+
...batchDims,
|
|
5322
|
+
n,
|
|
5323
|
+
m
|
|
5324
|
+
]);
|
|
5325
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5326
|
+
pivots.dispose();
|
|
5327
|
+
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5328
|
+
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5329
|
+
leftSide: true,
|
|
5330
|
+
lower: true,
|
|
5331
|
+
unitDiagonal: true
|
|
5332
|
+
});
|
|
5333
|
+
let x = triangularSolve(lu$2, LPb.ref, {
|
|
5334
|
+
leftSide: true,
|
|
5335
|
+
lower: false
|
|
5336
|
+
});
|
|
5337
|
+
if (bIs1d) x = squeeze(x, -1);
|
|
5338
|
+
return x;
|
|
4684
5339
|
}
|
|
4685
|
-
|
|
4686
|
-
|
|
4687
|
-
|
|
4688
|
-
|
|
4689
|
-
|
|
4690
|
-
|
|
4691
|
-
|
|
4692
|
-
|
|
4693
|
-
|
|
4694
|
-
|
|
5340
|
+
|
|
5341
|
+
//#endregion
|
|
5342
|
+
//#region src/library/numpy/dtype-info.ts
|
|
5343
|
+
/** Machine limits for floating-point types. */
|
|
5344
|
+
function finfo(dtype) {
|
|
5345
|
+
if (!isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
|
|
5346
|
+
switch (dtype) {
|
|
5347
|
+
case DType.Float16: return Object.freeze({
|
|
5348
|
+
bits: 16,
|
|
5349
|
+
dtype: DType.Float16,
|
|
5350
|
+
eps: 2 ** -10,
|
|
5351
|
+
epsneg: 2 ** -11,
|
|
5352
|
+
machep: -10,
|
|
5353
|
+
max: 65504,
|
|
5354
|
+
maxexp: 16,
|
|
5355
|
+
min: -65504,
|
|
5356
|
+
minexp: -14,
|
|
5357
|
+
negep: -24,
|
|
5358
|
+
nexp: 5,
|
|
5359
|
+
nmant: 10,
|
|
5360
|
+
precision: 3,
|
|
5361
|
+
resolution: .001,
|
|
5362
|
+
smallestNormal: 2 ** -14,
|
|
5363
|
+
smallestSubnormal: 2 ** -24
|
|
5364
|
+
});
|
|
5365
|
+
case DType.Float32: return Object.freeze({
|
|
5366
|
+
bits: 32,
|
|
5367
|
+
dtype: DType.Float32,
|
|
5368
|
+
eps: 2 ** -23,
|
|
5369
|
+
epsneg: 2 ** -24,
|
|
5370
|
+
machep: -23,
|
|
5371
|
+
max: 34028234663852886e22,
|
|
5372
|
+
maxexp: 128,
|
|
5373
|
+
min: -34028234663852886e22,
|
|
5374
|
+
minexp: -126,
|
|
5375
|
+
negep: -24,
|
|
5376
|
+
nexp: 8,
|
|
5377
|
+
nmant: 23,
|
|
5378
|
+
precision: 6,
|
|
5379
|
+
resolution: 1e-6,
|
|
5380
|
+
smallestNormal: 2 ** -126,
|
|
5381
|
+
smallestSubnormal: 2 ** -149
|
|
5382
|
+
});
|
|
5383
|
+
case DType.Float64: return Object.freeze({
|
|
5384
|
+
bits: 64,
|
|
5385
|
+
dtype: DType.Float64,
|
|
5386
|
+
eps: 2 ** -52,
|
|
5387
|
+
epsneg: 2 ** -53,
|
|
5388
|
+
machep: -52,
|
|
5389
|
+
max: Number.MAX_VALUE,
|
|
5390
|
+
maxexp: 1024,
|
|
5391
|
+
min: -Number.MAX_VALUE,
|
|
5392
|
+
minexp: -1022,
|
|
5393
|
+
negep: -53,
|
|
5394
|
+
nexp: 11,
|
|
5395
|
+
nmant: 52,
|
|
5396
|
+
precision: 15,
|
|
5397
|
+
resolution: 1e-15,
|
|
5398
|
+
smallestNormal: 2 ** -1022,
|
|
5399
|
+
smallestSubnormal: 2 ** -1074
|
|
5400
|
+
});
|
|
5401
|
+
default: throw new Error(`finfo: unsupported dtype ${dtype}`);
|
|
4695
5402
|
}
|
|
4696
|
-
return new EinsumPath(input, sizeMap, bestPath);
|
|
4697
5403
|
}
|
|
4698
|
-
|
|
4699
|
-
|
|
4700
|
-
|
|
4701
|
-
return
|
|
4702
|
-
|
|
4703
|
-
|
|
4704
|
-
|
|
4705
|
-
|
|
4706
|
-
|
|
4707
|
-
|
|
5404
|
+
/** Machine limits for integer types. */
|
|
5405
|
+
function iinfo(dtype) {
|
|
5406
|
+
switch (dtype) {
|
|
5407
|
+
case DType.Int32: return Object.freeze({
|
|
5408
|
+
bits: 32,
|
|
5409
|
+
dtype: DType.Int32,
|
|
5410
|
+
max: 2147483647,
|
|
5411
|
+
min: -2147483648
|
|
5412
|
+
});
|
|
5413
|
+
case DType.Uint32: return Object.freeze({
|
|
5414
|
+
bits: 32,
|
|
5415
|
+
dtype: DType.Uint32,
|
|
5416
|
+
max: 4294967295,
|
|
5417
|
+
min: 0
|
|
5418
|
+
});
|
|
5419
|
+
default: throw new Error(`iinfo: unsupported dtype ${dtype}`);
|
|
4708
5420
|
}
|
|
4709
5421
|
}
|
|
4710
5422
|
|
|
@@ -4714,28 +5426,32 @@ var numpy_exports = {};
|
|
|
4714
5426
|
__export(numpy_exports, {
|
|
4715
5427
|
Array: () => Array$1,
|
|
4716
5428
|
DType: () => DType,
|
|
4717
|
-
abs: () =>
|
|
5429
|
+
abs: () => absolute,
|
|
4718
5430
|
absolute: () => absolute,
|
|
4719
5431
|
acos: () => acos,
|
|
4720
|
-
acosh: () =>
|
|
5432
|
+
acosh: () => arccosh,
|
|
4721
5433
|
add: () => add,
|
|
5434
|
+
all: () => all,
|
|
4722
5435
|
allclose: () => allclose,
|
|
5436
|
+
any: () => any,
|
|
4723
5437
|
arange: () => arange,
|
|
4724
|
-
arccos: () =>
|
|
5438
|
+
arccos: () => acos,
|
|
4725
5439
|
arccosh: () => arccosh,
|
|
5440
|
+
arcsin: () => asin,
|
|
4726
5441
|
arcsinh: () => arcsinh,
|
|
4727
|
-
arctan: () =>
|
|
4728
|
-
arctan2: () =>
|
|
5442
|
+
arctan: () => atan,
|
|
5443
|
+
arctan2: () => atan2,
|
|
4729
5444
|
arctanh: () => arctanh,
|
|
4730
5445
|
argmax: () => argmax,
|
|
4731
5446
|
argmin: () => argmin,
|
|
5447
|
+
argsort: () => argsort,
|
|
4732
5448
|
array: () => array,
|
|
4733
5449
|
asin: () => asin,
|
|
4734
|
-
asinh: () =>
|
|
5450
|
+
asinh: () => arcsinh,
|
|
4735
5451
|
astype: () => astype,
|
|
4736
5452
|
atan: () => atan,
|
|
4737
5453
|
atan2: () => atan2,
|
|
4738
|
-
atanh: () =>
|
|
5454
|
+
atanh: () => arctanh,
|
|
4739
5455
|
bool: () => bool,
|
|
4740
5456
|
broadcastArrays: () => broadcastArrays,
|
|
4741
5457
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -4745,16 +5461,21 @@ __export(numpy_exports, {
|
|
|
4745
5461
|
clip: () => clip,
|
|
4746
5462
|
columnStack: () => columnStack,
|
|
4747
5463
|
concatenate: () => concatenate,
|
|
5464
|
+
convolve: () => convolve,
|
|
5465
|
+
corrcoef: () => corrcoef,
|
|
5466
|
+
correlate: () => correlate,
|
|
4748
5467
|
cos: () => cos,
|
|
4749
5468
|
cosh: () => cosh,
|
|
5469
|
+
cov: () => cov,
|
|
4750
5470
|
cumsum: () => cumsum,
|
|
4751
|
-
cumulativeSum: () =>
|
|
5471
|
+
cumulativeSum: () => cumsum,
|
|
4752
5472
|
deg2rad: () => deg2rad,
|
|
4753
5473
|
degrees: () => degrees,
|
|
4754
5474
|
diag: () => diag,
|
|
4755
5475
|
diagonal: () => diagonal,
|
|
4756
|
-
divide: () =>
|
|
4757
|
-
|
|
5476
|
+
divide: () => trueDivide,
|
|
5477
|
+
divmod: () => divmod,
|
|
5478
|
+
dot: () => dot$1,
|
|
4758
5479
|
dstack: () => dstack,
|
|
4759
5480
|
e: () => e,
|
|
4760
5481
|
einsum: () => einsum,
|
|
@@ -4762,8 +5483,11 @@ __export(numpy_exports, {
|
|
|
4762
5483
|
eulerGamma: () => eulerGamma,
|
|
4763
5484
|
exp: () => exp,
|
|
4764
5485
|
exp2: () => exp2,
|
|
5486
|
+
expandDims: () => expandDims,
|
|
4765
5487
|
expm1: () => expm1,
|
|
4766
5488
|
eye: () => eye,
|
|
5489
|
+
fft: () => numpy_fft_exports,
|
|
5490
|
+
finfo: () => finfo,
|
|
4767
5491
|
flip: () => flip,
|
|
4768
5492
|
fliplr: () => fliplr,
|
|
4769
5493
|
flipud: () => flipud,
|
|
@@ -4771,6 +5495,7 @@ __export(numpy_exports, {
|
|
|
4771
5495
|
float32: () => float32,
|
|
4772
5496
|
float64: () => float64,
|
|
4773
5497
|
floor: () => floor,
|
|
5498
|
+
floorDivide: () => floorDivide,
|
|
4774
5499
|
fmod: () => fmod,
|
|
4775
5500
|
frexp: () => frexp,
|
|
4776
5501
|
full: () => full,
|
|
@@ -4783,6 +5508,7 @@ __export(numpy_exports, {
|
|
|
4783
5508
|
hstack: () => hstack,
|
|
4784
5509
|
hypot: () => hypot,
|
|
4785
5510
|
identity: () => identity$1,
|
|
5511
|
+
iinfo: () => iinfo,
|
|
4786
5512
|
inf: () => inf,
|
|
4787
5513
|
inner: () => inner,
|
|
4788
5514
|
int32: () => int32,
|
|
@@ -4794,12 +5520,15 @@ __export(numpy_exports, {
|
|
|
4794
5520
|
ldexp: () => ldexp,
|
|
4795
5521
|
less: () => less,
|
|
4796
5522
|
lessEqual: () => lessEqual,
|
|
5523
|
+
linalg: () => numpy_linalg_exports,
|
|
4797
5524
|
linspace: () => linspace,
|
|
4798
5525
|
log: () => log,
|
|
4799
5526
|
log10: () => log10,
|
|
4800
5527
|
log1p: () => log1p,
|
|
4801
5528
|
log2: () => log2,
|
|
5529
|
+
logspace: () => logspace,
|
|
4802
5530
|
matmul: () => matmul,
|
|
5531
|
+
matrixTranspose: () => matrixTranspose,
|
|
4803
5532
|
max: () => max,
|
|
4804
5533
|
maximum: () => maximum,
|
|
4805
5534
|
mean: () => mean,
|
|
@@ -4816,10 +5545,10 @@ __export(numpy_exports, {
|
|
|
4816
5545
|
onesLike: () => onesLike,
|
|
4817
5546
|
outer: () => outer,
|
|
4818
5547
|
pad: () => pad,
|
|
4819
|
-
permuteDims: () =>
|
|
5548
|
+
permuteDims: () => transpose,
|
|
4820
5549
|
pi: () => pi,
|
|
4821
5550
|
positive: () => positive,
|
|
4822
|
-
pow: () =>
|
|
5551
|
+
pow: () => power,
|
|
4823
5552
|
power: () => power,
|
|
4824
5553
|
prod: () => prod$1,
|
|
4825
5554
|
promoteTypes: () => promoteTypes,
|
|
@@ -4834,8 +5563,11 @@ __export(numpy_exports, {
|
|
|
4834
5563
|
shape: () => shape,
|
|
4835
5564
|
sign: () => sign,
|
|
4836
5565
|
sin: () => sin,
|
|
5566
|
+
sinc: () => sinc,
|
|
4837
5567
|
sinh: () => sinh,
|
|
4838
5568
|
size: () => size,
|
|
5569
|
+
sort: () => sort,
|
|
5570
|
+
split: () => split$1,
|
|
4839
5571
|
sqrt: () => sqrt,
|
|
4840
5572
|
square: () => square,
|
|
4841
5573
|
squeeze: () => squeeze,
|
|
@@ -4843,6 +5575,7 @@ __export(numpy_exports, {
|
|
|
4843
5575
|
std: () => std,
|
|
4844
5576
|
subtract: () => subtract,
|
|
4845
5577
|
sum: () => sum,
|
|
5578
|
+
take: () => take,
|
|
4846
5579
|
tan: () => tan,
|
|
4847
5580
|
tanh: () => tanh,
|
|
4848
5581
|
tensordot: () => tensordot,
|
|
@@ -5000,6 +5733,26 @@ function min(a, axis = null, opts) {
|
|
|
5000
5733
|
function max(a, axis = null, opts) {
  // Max-reduction over `axis`, delegated to the shared `reduce` helper.
  const maxOp = AluOp.Max;
  return reduce(a, maxOp, axis, opts);
}
|
|
5736
|
+
/**
 * Test whether all array elements along a given axis evaluate to True.
 *
 * The input is cast to bool and min-reduced, so the result is a boolean array
 * shaped like `a` with the reduced axis removed (a scalar when `axis` is null).
 */
function all(a, axis = null, opts) {
  const truthy = fudgeArray(a).astype(DType.Bool);
  return min(truthy, axis, opts);
}
|
|
5746
|
+
/**
 * Test whether any array element along a given axis evaluates to True.
 *
 * The input is cast to bool and max-reduced, so the result is a boolean array
 * shaped like `a` with the reduced axis removed (a scalar when `axis` is null).
 */
function any(a, axis = null, opts) {
  const truthy = fudgeArray(a).astype(DType.Bool);
  return max(truthy, axis, opts);
}
|
|
5003
5756
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5004
5757
|
function ptp(a, axis = null, opts) {
|
|
5005
5758
|
a = fudgeArray(a);
|
|
@@ -5074,8 +5827,6 @@ function cumsum(a, axis) {
|
|
|
5074
5827
|
a = broadcast(a, a.shape.concat(n), [-2]);
|
|
5075
5828
|
return moveaxis$1(tril(a).sum(-1), -1, axis);
|
|
5076
5829
|
}
|
|
5077
|
-
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
5078
|
-
const cumulativeSum = cumsum;
|
|
5079
5830
|
/** Reverse the elements in an array along the given axes. */
|
|
5080
5831
|
function flip(x, axis = null) {
|
|
5081
5832
|
const nd = ndim(x);
|
|
@@ -5083,6 +5834,45 @@ function flip(x, axis = null) {
|
|
|
5083
5834
|
return flip$1(x, axis);
|
|
5084
5835
|
}
|
|
5085
5836
|
/**
 * Split an array into multiple sub-arrays along an axis.
 *
 * @param a - The input array to split.
 * @param indicesOrSections - If an integer, it indicates the number of equal
 *   sections to create along the specified axis. It must be a positive
 *   integer that divides the axis length. If a list of integers, it specifies
 *   the indices at which to split the array. An empty list returns the whole
 *   array as a single section.
 * @param axis - The axis along which to split the array. Default is 0.
 * @returns The sub-arrays, in order along `axis`.
 */
function split$1(a, indicesOrSections, axis = 0) {
  a = fudgeArray(a);
  axis = checkAxis(axis, a.ndim);
  const axisLen = a.shape[axis];
  let sizes;
  if (typeof indicesOrSections === "number") {
    const sections = indicesOrSections;
    if (!Number.isInteger(sections) || sections <= 0) throw new Error(`split: number of sections must be a positive integer, got ${sections}`);
    if (axisLen % sections !== 0) throw new Error(`Array of size ${axisLen} cannot be split into ${sections} equal parts`);
    sizes = rep(sections, axisLen / sections);
  } else if (indicesOrSections.length === 0) {
    // No split points: one section spanning the whole axis (as in NumPy).
    sizes = [axisLen];
  } else {
    const indices = indicesOrSections;
    // Convert split points into section lengths.
    sizes = [indices[0]];
    for (let i = 1; i < indices.length; i++) sizes.push(indices[i] - indices[i - 1]);
    sizes.push(axisLen - indices[indices.length - 1]);
  }
  // Each `split$2` call here emits at most 8 outputs (presumably a primitive
  // limit). While more than 8 sections remain, peel off 7 of them plus one
  // remainder piece, then keep splitting the remainder.
  const results = [];
  for (let i = 0; i < sizes.length; i += 7) {
    const remaining = sizes.slice(i);
    if (remaining.length <= 8) {
      results.push(...split$2(a, axis, remaining));
      break;
    }
    const head = remaining.slice(0, 7);
    const tail = remaining.slice(7).reduce((acc, s) => acc + s, 0);
    const outs = split$2(a, axis, [...head, tail]);
    results.push(...outs.slice(0, -1));
    a = outs[outs.length - 1];
  }
  return results;
}
|
|
5875
|
+
/**
|
|
5086
5876
|
* Join a sequence of arrays along an existing axis.
|
|
5087
5877
|
*
|
|
5088
5878
|
* The arrays must have the same shape, except in the dimension corresponding to
|
|
@@ -5094,13 +5884,11 @@ function concatenate(xs, axis = 0) {
|
|
|
5094
5884
|
if (xs.length === 0) throw new Error("Need at least one array to concatenate");
|
|
5095
5885
|
const shapes = xs.map(shape);
|
|
5096
5886
|
axis = checkAxis(axis, shapes[0].length);
|
|
5097
|
-
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays
|
|
5098
|
-
const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
|
|
5887
|
+
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
|
|
5099
5888
|
let result = xs[0];
|
|
5100
|
-
for (let i = 1; i < xs.length; i
|
|
5101
|
-
const
|
|
5102
|
-
|
|
5103
|
-
result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
|
|
5889
|
+
for (let i = 1; i < xs.length; i += 7) {
|
|
5890
|
+
const group = xs.slice(i, i + 7);
|
|
5891
|
+
result = concatenate$1([result, ...group], axis);
|
|
5104
5892
|
}
|
|
5105
5893
|
return result;
|
|
5106
5894
|
}
|
|
@@ -5185,8 +5973,11 @@ function flipud(x) {
|
|
|
5185
5973
|
function fliplr(x) {
  // Left/right flip: reverse the entries along axis 1.
  const horizontalAxis = 1;
  return flip(x, horizontalAxis);
}
|
|
5188
|
-
/**
|
|
5189
|
-
|
|
5976
|
+
/**
 * Transpose the last two dimensions of an array.
 *
 * Leading (batch) dimensions are left in place. Inputs with fewer than two
 * dimensions are rejected.
 */
function matrixTranspose(a) {
  const rank = ndim(a);
  if (rank < 2) {
    throw new Error(`matrixTranspose: input array must be at least 2D`);
  }
  return moveaxis$1(a, -1, -2);
}
|
|
5190
5981
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
5191
5982
|
function ravel(a) {
|
|
5192
5983
|
return fudgeArray(a).ravel();
|
|
@@ -5202,6 +5993,32 @@ function squeeze(a, axis = null) {
|
|
|
5202
5993
|
return reshape(a, newShape);
|
|
5203
5994
|
}
|
|
5204
5995
|
/**
 * Expand the shape of an array by inserting new axes of length 1.
 *
 * @param a - Input array.
 * @param axis - Position(s), counted in the *output* shape, where the new
 *   length-1 axes are placed. Either a single integer or an array of integers.
 * @returns The input reshaped with the extra axes inserted.
 *
 * @example
 * ```ts
 * const x = np.array([1, 2]);
 * np.expandDims(x, 0); // Shape [1, 2]
 * np.expandDims(x, 1); // Shape [2, 1]
 * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
 * ```
 */
function expandDims(a, axis) {
  const srcShape = shape(a);
  const requested = typeof axis === "number" ? [axis] : axis;
  const inserted = normalizeAxis(requested, srcShape.length + requested.length);
  const outRank = srcShape.length + inserted.length;
  const newShape = [];
  let next = 0;
  for (let d = 0; d < outRank; d++) {
    // A requested position gets a new unit axis; otherwise copy the next source dim.
    newShape.push(inserted.includes(d) ? 1 : srcShape[next++]);
  }
  return reshape(a, newShape);
}
|
|
6021
|
+
/**
|
|
5205
6022
|
* Repeat each element of an array after themselves.
|
|
5206
6023
|
*
|
|
5207
6024
|
* If no axis is provided, use the flattened input array, and return a flat
|
|
@@ -5289,7 +6106,7 @@ function diagonal(a, offset, axis1, axis2) {
|
|
|
5289
6106
|
*/
|
|
5290
6107
|
function diag(v, k = 0) {
|
|
5291
6108
|
const a = fudgeArray(v);
|
|
5292
|
-
if (!Number.isInteger(k)) throw new
|
|
6109
|
+
if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
|
|
5293
6110
|
if (a.ndim === 1) {
|
|
5294
6111
|
const n = a.shape[0];
|
|
5295
6112
|
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
@@ -5297,12 +6114,46 @@ function diag(v, k = 0) {
|
|
|
5297
6114
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
5298
6115
|
else return ret;
|
|
5299
6116
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
5300
|
-
else throw new
|
|
6117
|
+
else throw new Error("numpy.diag only supports 1D and 2D arrays");
|
|
5301
6118
|
}
|
|
5302
6119
|
/**
 * Calculate the sum of the diagonal of an array along the given axes.
 *
 * Extracts the `offset` diagonal over (`axis1`, `axis2`) and sums it along the
 * last axis of that diagonal.
 */
function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
  const diagonals = diagonal(a, offset, axis1, axis2);
  return diagonals.sum(-1);
}
|
|
6123
|
+
/**
 * Return a sorted copy of an array.
 *
 * Sorting happens along `axis` (the last one by default). The ordering of
 * equal elements is not guaranteed to be stable; the actual algorithm is
 * chosen by the device-specific implementation.
 */
function sort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.sort(axis);
}
|
|
6132
|
+
/**
 * Return the indices that would sort an array along `axis` (the last one by
 * default).
 *
 * The result is an array of `int32` indices. The sort may be unstable, so
 * tied elements can appear in any relative order.
 */
function argsort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.argsort(axis);
}
|
|
6143
|
+
/**
 * Take elements from an array along an axis.
 *
 * Equivalent to advanced indexing with integer `indices` on the chosen axis.
 * When `axis` is null, the input is flattened first and indexed along axis 0.
 */
function take(a, indices, axis = null) {
  let source = a;
  let takeAxis = axis;
  if (takeAxis === null) {
    source = ravel(source);
    takeAxis = 0;
  }
  takeAxis = checkAxis(takeAxis, ndim(source));
  return gather(source, [indices], [takeAxis], takeAxis);
}
|
|
5306
6157
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5307
6158
|
function allclose(actual, expected, options) {
|
|
5308
6159
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5319,11 +6170,11 @@ function allclose(actual, expected, options) {
|
|
|
5319
6170
|
}
|
|
5320
6171
|
/** Matrix product of two arrays. */
|
|
5321
6172
|
function matmul(x, y) {
|
|
5322
|
-
if (ndim(x) === 0 || ndim(y) === 0) throw new
|
|
6173
|
+
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
5323
6174
|
x = x, y = y;
|
|
5324
6175
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5325
6176
|
const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
|
|
5326
|
-
return dot
|
|
6177
|
+
return dot(x, y, {
|
|
5327
6178
|
lhsContractingDims: [-1],
|
|
5328
6179
|
rhsContractingDims: [-2],
|
|
5329
6180
|
lhsBatchDims: range(-2 - numBatchDims, -2),
|
|
@@ -5331,11 +6182,11 @@ function matmul(x, y) {
|
|
|
5331
6182
|
});
|
|
5332
6183
|
}
|
|
5333
6184
|
/** Dot product of two arrays. */
|
|
5334
|
-
function dot(x, y) {
|
|
6185
|
+
function dot$1(x, y) {
|
|
5335
6186
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
5336
6187
|
x = x, y = y;
|
|
5337
6188
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5338
|
-
return dot
|
|
6189
|
+
return dot(x, y, {
|
|
5339
6190
|
lhsContractingDims: [-1],
|
|
5340
6191
|
rhsContractingDims: [-2]
|
|
5341
6192
|
});
|
|
@@ -5351,7 +6202,7 @@ function tensordot(x, y, axes = 2) {
|
|
|
5351
6202
|
x = fudgeArray(x);
|
|
5352
6203
|
y = fudgeArray(y);
|
|
5353
6204
|
if (typeof axes === "number") axes = [range(-axes, 0), range(axes)];
|
|
5354
|
-
return dot
|
|
6205
|
+
return dot(x, y, {
|
|
5355
6206
|
lhsContractingDims: axes[0],
|
|
5356
6207
|
rhsContractingDims: axes[1]
|
|
5357
6208
|
});
|
|
@@ -5444,7 +6295,7 @@ function einsum(...args) {
|
|
|
5444
6295
|
const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
|
|
5445
6296
|
indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
|
|
5446
6297
|
const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
|
|
5447
|
-
const result = dot
|
|
6298
|
+
const result = dot(a, b, {
|
|
5448
6299
|
lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
|
|
5449
6300
|
rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
|
|
5450
6301
|
lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
|
|
@@ -5472,7 +6323,7 @@ function einsum(...args) {
|
|
|
5472
6323
|
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
5473
6324
|
*/
|
|
5474
6325
|
function inner(x, y) {
|
|
5475
|
-
return dot
|
|
6326
|
+
return dot(fudgeArray(x), fudgeArray(y), {
|
|
5476
6327
|
lhsContractingDims: [-1],
|
|
5477
6328
|
rhsContractingDims: [-1]
|
|
5478
6329
|
});
|
|
@@ -5505,6 +6356,30 @@ function vecdot(x, y, { axis } = {}) {
|
|
|
5505
6356
|
function vdot(x, y) {
  // Flatten both operands, then take their 1D inner product.
  const xFlat = ravel(x);
  const yFlat = ravel(y);
  return dot$2(xFlat, yFlat);
}
|
|
6359
|
+
/**
 * Shared implementation of `convolve()` and `correlate()` for 1D inputs.
 *
 * @param name - "convolve" or "correlate". It selects whether the kernel is
 *   reversed and prefixes error messages.
 * @param x - First operand; anything accepted by `fudgeArray`, must be 1D.
 * @param y - Second operand; anything accepted by `fudgeArray`, must be 1D.
 * @param mode - "full", "same", or "valid".
 * @returns The 1D convolution/correlation result.
 */
function _convImpl(name, x, y, mode) {
  // Normalize inputs like the other numpy-level functions, so plain JS arrays
  // and scalars get a real shape before they are validated.
  x = fudgeArray(x);
  y = fudgeArray(y);
  if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
  let flipOutput = false;
  // Put the longer input first so the shorter one acts as the sliding kernel.
  if (x.shape[0] < y.shape[0]) {
    [x, y] = [y, x];
    // Correlation is not symmetric: swapping its operands reverses the output.
    if (name === "correlate") flipOutput = true;
  }
  // Convolution is correlation with a reversed kernel.
  if (name === "convolve") y = flip(y);
  let padding;
  if (mode === "valid") padding = "VALID";
  else if (mode === "same") padding = "SAME_LOWER";
  else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
  else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
  // Add two leading unit axes for `conv`, then index them back away.
  const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
  return flipOutput ? flip(z) : z;
}
|
|
6375
|
+
/**
 * Convolution of two one-dimensional arrays.
 *
 * `mode` is "full" (default), "same", or "valid".
 */
function convolve(x, y, mode = "full") {
  const op = "convolve";
  return _convImpl(op, x, y, mode);
}
|
|
6379
|
+
/**
 * Correlation of two one-dimensional arrays.
 *
 * `mode` is "valid" (default), "same", or "full".
 */
function correlate(x, y, mode = "valid") {
  const op = "correlate";
  return _convImpl(op, x, y, mode);
}
|
|
5508
6383
|
/**
|
|
5509
6384
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
5510
6385
|
*
|
|
@@ -5513,7 +6388,7 @@ function vdot(x, y) {
|
|
|
5513
6388
|
*/
|
|
5514
6389
|
function meshgrid(xs, { indexing } = {}) {
|
|
5515
6390
|
indexing ??= "xy";
|
|
5516
|
-
for (const x of xs) if (x.ndim !== 1) throw new
|
|
6391
|
+
for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
|
|
5517
6392
|
if (xs.length <= 1) return xs;
|
|
5518
6393
|
if (indexing === "xy") {
|
|
5519
6394
|
const [a, b, ...rest] = xs;
|
|
@@ -5532,43 +6407,6 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
5532
6407
|
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
5533
6408
|
}
|
|
5534
6409
|
/**
|
|
5535
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
5536
|
-
*
|
|
5537
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
5538
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
5539
|
-
* `k>0` is above it.
|
|
5540
|
-
*/
|
|
5541
|
-
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
5542
|
-
m ??= n;
|
|
5543
|
-
dtype ??= DType.Float32;
|
|
5544
|
-
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
5545
|
-
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
5546
|
-
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
5547
|
-
const rows = arange(k, n + k, 1, {
|
|
5548
|
-
dtype: DType.Int32,
|
|
5549
|
-
device
|
|
5550
|
-
});
|
|
5551
|
-
const cols = arange(0, m, 1, {
|
|
5552
|
-
dtype: DType.Int32,
|
|
5553
|
-
device
|
|
5554
|
-
});
|
|
5555
|
-
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
5556
|
-
}
|
|
5557
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
5558
|
-
function tril(a, k = 0) {
|
|
5559
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5560
|
-
a = fudgeArray(a);
|
|
5561
|
-
const [n, m] = a.shape.slice(-2);
|
|
5562
|
-
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
5563
|
-
}
|
|
5564
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
5565
|
-
function triu(a, k = 0) {
|
|
5566
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5567
|
-
a = fudgeArray(a);
|
|
5568
|
-
const [n, m] = a.shape.slice(-2);
|
|
5569
|
-
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
5570
|
-
}
|
|
5571
|
-
/**
|
|
5572
6410
|
* Clip (limit) the values in an array.
|
|
5573
6411
|
*
|
|
5574
6412
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -5592,8 +6430,6 @@ function absolute(x) {
|
|
|
5592
6430
|
x = fudgeArray(x);
|
|
5593
6431
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
5594
6432
|
}
|
|
5595
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
5596
|
-
const abs = absolute;
|
|
5597
6433
|
/** Return an element-wise indication of sign of the input. */
|
|
5598
6434
|
function sign(x) {
|
|
5599
6435
|
x = fudgeArray(x);
|
|
@@ -5637,6 +6473,20 @@ function tan(x) {
|
|
|
5637
6473
|
x = fudgeArray(x);
|
|
5638
6474
|
return sin(x.ref).div(cos(x));
|
|
5639
6475
|
}
|
|
6476
|
+
/**
 * @function
 * Return the normalized sinc function.
 *
 * Computes `sin(πx) / (πx)` element-wise, substituting `1` wherever `x == 0`.
 * This is the normalized variant commonly used in signal processing.
 *
 * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
 * requires a custom JVP rule to handle properly (see JAX implementation).
 */
const sinc = jit$1(function sinc$1(x) {
  const scaled = x.ref.mul(Math.PI);
  const isZero = equal(x, 0);
  const ratio = sin(scaled.ref).div(scaled);
  return where(isZero, 1, ratio);
});
|
|
5640
6490
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
5641
6491
|
function acos(x) {
|
|
5642
6492
|
return subtract(pi / 2, asin(x));
|
|
@@ -5672,12 +6522,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
|
|
|
5672
6522
|
const denom = where(xNeg, y, r.add(x));
|
|
5673
6523
|
return atan(numer.div(denom)).mul(2);
|
|
5674
6524
|
});
|
|
5675
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
5676
|
-
const arccos = acos;
|
|
5677
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
5678
|
-
const arctan = atan;
|
|
5679
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
5680
|
-
const arctan2 = atan2;
|
|
5681
6525
|
/** Element-wise subtraction, with broadcasting. */
|
|
5682
6526
|
function subtract(x, y) {
|
|
5683
6527
|
x = fudgeArray(x);
|
|
@@ -5695,6 +6539,25 @@ function trueDivide(x, y) {
|
|
|
5695
6539
|
return x.div(y);
|
|
5696
6540
|
}
|
|
5697
6541
|
/**
 * Return the largest integer smaller or equal to the division of the inputs.
 *
 * The quotient is always rounded towards negative infinity.
 *
 * If either input has a floating-point dtype, this is `floor(x / y)`. For
 * integer inputs it computes `(x - remainder(x, y)) / y`, which handles
 * negative operands correctly (note: may overflow near int32 boundaries).
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Element-wise floor division of x by y.
 */
function floorDivide(x, y) {
  const lhs = fudgeArray(x);
  const rhs = fudgeArray(y);
  const floating = isFloatDtype(lhs.dtype) || isFloatDtype(rhs.dtype);
  if (floating) return floor(trueDivide(lhs, rhs));
  // x - remainder(x, y) is an exact multiple of y, so this division is exact.
  const adjusted = subtract(lhs, remainder(lhs.ref, rhs.ref));
  return adjusted.div(rhs);
}
|
|
6560
|
+
/**
|
|
5698
6561
|
* @function
|
|
5699
6562
|
* Calculate element-wise floating-point modulo operation.
|
|
5700
6563
|
*/
|
|
@@ -5708,8 +6571,20 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
5708
6571
|
const remainder = jit$1(function remainder$1(x, y) {
  // Shift the first remainder by y and reduce again. Assuming `mod` truncates
  // like C's `%` (NOTE(review): confirm), this yields a result with y's sign.
  const shifted = mod(x, y.ref).add(y.ref);
  return mod(shifted, y);
});
|
|
5711
|
-
/**
|
|
5712
|
-
|
|
6574
|
+
/**
 * Return element-wise quotient and remainder simultaneously.
 *
 * Same result as `[floorDivide(x, y), remainder(x, y)]`.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
  const dividend = fudgeArray(x);
  const divisor = fudgeArray(y);
  const quotient = floorDivide(dividend.ref, divisor.ref);
  const rest = remainder(dividend, divisor);
  return [quotient, rest];
}
|
|
5713
6588
|
/** Round input to the nearest integer towards zero. */
|
|
5714
6589
|
function trunc(x) {
|
|
5715
6590
|
return idiv(x, 1);
|
|
@@ -5731,9 +6606,9 @@ function ldexp(x1, x2) {
|
|
|
5731
6606
|
*/
|
|
5732
6607
|
function frexp(x) {
  x = fudgeArray(x);
  const magnitude = absolute(x.ref);
  // Exponent is floor(log2|x|) + 1, forced to 0 where x is exactly zero.
  const isZero = equal(x.ref, 0);
  const exponent = where(isZero, 0, floor(log2(magnitude)).add(1).astype(DType.Int32));
  // Mantissa is x scaled by 2^-exponent, computed in x's own dtype.
  const scale = exp2(exponent.ref.astype(x.dtype));
  const mantissa = x.div(scale);
  return [mantissa, exponent];
}
|
|
5739
6614
|
/** Calculate `2**p` for all p in the input array. */
|
|
@@ -5776,10 +6651,8 @@ const power = jit$1(function power$1(x1, x2) {
|
|
|
5776
6651
|
const x2i = trunc(x2.ref);
|
|
5777
6652
|
const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
|
|
5778
6653
|
const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
|
|
5779
|
-
return where(shouldBeNaN, nan, exp(log(
|
|
6654
|
+
return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
|
|
5780
6655
|
});
|
|
5781
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
5782
|
-
const pow = power;
|
|
5783
6656
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
5784
6657
|
const cbrt = jit$1(function cbrt$1(x) {
|
|
5785
6658
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
@@ -5845,69 +6718,360 @@ const arccosh = jit$1(function arccosh$1(x) {
|
|
|
5845
6718
|
const arctanh = jit$1(function arctanh$1(x) {
  // atanh(x) = 0.5 * log((1 + x) / (1 - x))
  const ratio = add(1, x.ref).div(subtract(1, x));
  return log(ratio).mul(.5);
});
|
|
5848
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
5849
|
-
const asinh = arcsinh;
|
|
5850
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
5851
|
-
const acosh = arccosh;
|
|
5852
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
5853
|
-
const atanh = arctanh;
|
|
5854
6721
|
/**
 * Compute the variance of an array.
 *
 * By default the variance is taken over the flattened array; otherwise over
 * the given axis or axes.
 *
 * With `correction`, the divisor becomes `N - correction`, where `N` is the
 * number of reduced elements (e.g. `correction: 1` for Bessel's correction).
 * A precomputed `mean` may be supplied, and `keepdims` keeps the reduced axes.
 */
function var_(x, axis = null, opts) {
  const arr = fudgeArray(x);
  const axes = normalizeAxis(axis, arr.ndim);
  let count = 1;
  for (const ax of axes) count *= arr.shape[ax];
  if (count === 0) throw new Error("var: cannot compute variance over zero-length axis");
  const mu = opts?.mean !== void 0 ? opts.mean : mean(arr.ref, axes, { keepdims: true });
  const correction = opts?.correction ?? 0;
  const sumSq = square(arr.sub(mu)).sum(axes, { keepdims: opts?.keepdims });
  return sumSq.mul(1 / (count - correction));
}
|
|
6738
|
+
/**
 * Compute the standard deviation of an array.
 *
 * This is the square root of `var_()` with the same arguments: taken over the
 * flattened array by default, otherwise over the given axis or axes.
 *
 * With `correction`, the divisor becomes `N - correction`, where `N` is the
 * number of reduced elements (e.g. for Bessel's correction).
 */
function std(x, axis = null, opts) {
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
|
|
6750
|
+
/**
 * Estimate the sample covariance of a set of variables.
 *
 * By default each row of `x` is a variable and each column an observation.
 * With `rowvar: false` each column is a variable instead. As in NumPy, a 1D
 * (or single-row) input is always treated as a single variable. When `y` is
 * given, its variables are appended after those of `x`.
 *
 * @returns An `[M, M]` covariance matrix over all M variables, normalized by
 *   `N - 1` for `N` observations.
 */
function cov(x, y = null, { rowvar = true } = {}) {
  // Bring one input into [variables, observations] layout. Transposing each
  // input separately (before stacking) keeps observations aligned when
  // rowvar is false; single-row inputs are never transposed, matching NumPy.
  const asVariableRows = (arr) => {
    arr = fudgeArray(arr);
    if (arr.ndim === 1) arr = arr.reshape([1, arr.shape[0]]);
    if (arr.ndim !== 2) throw new Error(`cov: inputs must be 1D or 2D, got ${arr.ndim}D`);
    return !rowvar && arr.shape[0] !== 1 ? arr.transpose() : arr;
  };
  x = asVariableRows(x);
  if (y !== null) x = vstack([x, asVariableRows(y)]);
  const N = x.shape[1];
  x = x.ref.sub(x.mean(1, { keepdims: true }));
  return dot$1(x.ref, x.transpose()).div(N - 1);
}
|
|
6764
|
+
/**
 * Compute the Pearson correlation coefficients (in range `[-1, 1]`).
 *
 * Accepts the same inputs as `cov()`, including the optional `rowvar` flag.
 * The covariance matrix is divided by the outer product of the per-variable
 * standard deviations.
 */
function corrcoef(x, y, { rowvar = true } = {}) {
  const c = cov(x, y, { rowvar });
  const variances = diag(c.ref);
  const norm = sqrt(outer(variances.ref, variances));
  return c.div(norm);
}
|
|
6771
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
function isinf(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  // Adding the two boolean masks combines them (presumably a logical OR for bool).
  const posInf = arr.ref.equal(Infinity);
  return posInf.add(arr.equal(-Infinity));
}
|
|
6776
|
+
/** Test element-wise for NaN (Not a Number). */
function isnan(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  // NaN is the only value that compares unequal to itself.
  return arr.ref.notEqual(arr);
}
|
|
6781
|
+
/** Test element-wise for negative infinity, return bool array. */
function isneginf(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  return arr.equal(-Infinity);
}
|
|
6786
|
+
/** Test element-wise for positive infinity, return bool array. */
function isposinf(x) {
  const arr = fudgeArray(x);
  if (!isFloatDtype(arr.dtype)) return fullLike$1(arr, false);
  return arr.equal(Infinity);
}
|
|
6791
|
+
/**
 * @function
 * Test element-wise for finite values (not infinity or NaN).
 *
 * Non-float dtypes are always finite, so they yield an all-true array.
 */
const isfinite = jit$1(function isfinite$1(x) {
  if (isFloatDtype(x.dtype)) {
    const nonFinite = isnan(x.ref).add(isinf(x));
    return nonFinite.notEqual(true);
  }
  return fullLike$1(x, true);
});
|
|
6799
|
+
|
|
6800
|
+
//#endregion
|
|
6801
|
+
//#region src/library/lax-linalg.ts
|
|
6802
|
+
// Namespace object for `lax.linalg`; each entry lazily returns its binding.
var lax_linalg_exports = {};
__export(lax_linalg_exports, {
  cholesky() {
    return cholesky$1;
  },
  lu() {
    return lu;
  },
  triangularSolve() {
    return triangularSolve;
  }
});
|
|
6808
|
+
/**
 * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
 *
 * For an input `A` this returns:
 *
 * - `L` with A = L @ L^T (lower-triangular; the default, `upper: false`)
 * - `U` with A = U^T @ U (upper-triangular; `upper: true`)
 *
 * The input must be symmetric and positive-definite.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const x = np.array([[2., 1.], [1., 2.]]);
 *
 * // Lower Cholesky factorization (default):
 * const L = lax.linalg.cholesky(x);
 * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
 *
 * // Upper Cholesky factorization:
 * const U = lax.linalg.cholesky(x, { upper: true });
 * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
 * ```
 */
function cholesky$1(a, { upper = false } = {}) {
  const lower = cholesky$2(a);
  if (!upper) return lower;
  // U = L^T: swap the two trailing (matrix) axes.
  return moveaxis$1(lower, -2, -1);
}
|
|
6838
|
+
/**
|
|
6839
|
+
* LU decomposition with partial pivoting.
|
|
6840
|
+
*
|
|
6841
|
+
* Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
|
|
6842
|
+
* permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
|
|
6843
|
+
* and `U` is upper-triangular.
|
|
6844
|
+
*
|
|
6845
|
+
* @param x - A batch of matrices with shape `[..., m, n]`.
|
|
6846
|
+
*
|
|
6847
|
+
* @returns A tuple `(lu, pivots, permutation)` where:
|
|
6848
|
+
* - `lu`: combined lower and upper triangular matrices.
|
|
6849
|
+
* - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
|
|
6850
|
+
* - `permutation`: the permutation generated by pivots with shape `[..., m]`.
|
|
6851
|
+
*
|
|
6852
|
+
* @example
|
|
6853
|
+
* ```ts
|
|
6854
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6855
|
+
*
|
|
6856
|
+
* const A = np.array([[4., 3.], [6., 3.]]);
|
|
6857
|
+
* const [lu, pivots, permutation] = lax.linalg.lu(A);
|
|
6858
|
+
* // lu ≈ [[6., 3.], [0.6666667, 1.0]]
|
|
6859
|
+
* // pivots = [1, 1]
|
|
6860
|
+
* // permutation = [1, 0]
|
|
6861
|
+
* ```
|
|
6862
|
+
*/
|
|
6863
|
+
function lu(x) {
|
|
6864
|
+
return lu$1(x);
|
|
6865
|
+
}
|
|
6866
|
+
/**
|
|
6867
|
+
* Solve a triangular linear system.
|
|
6868
|
+
*
|
|
6869
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
6870
|
+
* where `a` is a triangular matrix.
|
|
6871
|
+
*
|
|
6872
|
+
* @example
|
|
6873
|
+
* ```ts
|
|
6874
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6875
|
+
*
|
|
6876
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
6877
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
6878
|
+
*
|
|
6879
|
+
* // Solve L @ x = b
|
|
6880
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
6881
|
+
* // x = [[2.], [5./3.]]
|
|
6882
|
+
* ```
|
|
6883
|
+
*/
|
|
6884
|
+
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
|
|
6885
|
+
a = fudgeArray(a);
|
|
6886
|
+
b = fudgeArray(b);
|
|
6887
|
+
if (!leftSide) transposeA = !transposeA;
|
|
6888
|
+
else b = moveaxis$1(b, -2, -1);
|
|
6889
|
+
if (transposeA) a = moveaxis$1(a, -2, -1);
|
|
6890
|
+
let x = triangularSolve$1(a, b, {
|
|
6891
|
+
lower,
|
|
6892
|
+
unitDiagonal
|
|
6893
|
+
});
|
|
6894
|
+
if (leftSide) x = moveaxis$1(x, -2, -1);
|
|
6895
|
+
return x;
|
|
6896
|
+
}
|
|
6897
|
+
|
|
6898
|
+
//#endregion
|
|
6899
|
+
//#region src/library/lax.ts
|
|
6900
|
+
var lax_exports = {};
|
|
6901
|
+
__export(lax_exports, {
|
|
6902
|
+
conv: () => conv,
|
|
6903
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
6904
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6905
|
+
dot: () => dot,
|
|
6906
|
+
erf: () => erf,
|
|
6907
|
+
erfc: () => erfc,
|
|
6908
|
+
linalg: () => lax_linalg_exports,
|
|
6909
|
+
reduceWindow: () => reduceWindow,
|
|
6910
|
+
stopGradient: () => stopGradient$1
|
|
6911
|
+
});
|
|
6912
|
+
/**
|
|
6913
|
+
* General dot product/contraction operator.
|
|
5859
6914
|
*
|
|
5860
|
-
*
|
|
5861
|
-
*
|
|
6915
|
+
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
6916
|
+
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
5862
6917
|
*/
|
|
5863
|
-
function
|
|
5864
|
-
|
|
5865
|
-
|
|
5866
|
-
|
|
5867
|
-
|
|
5868
|
-
|
|
5869
|
-
|
|
6918
|
+
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
6919
|
+
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
6920
|
+
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
6921
|
+
lc = lc.map((a) => checkAxis(a, lhs.ndim));
|
|
6922
|
+
rc = rc.map((a) => checkAxis(a, rhs.ndim));
|
|
6923
|
+
lb = lb.map((a) => checkAxis(a, lhs.ndim));
|
|
6924
|
+
rb = rb.map((a) => checkAxis(a, rhs.ndim));
|
|
6925
|
+
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
6926
|
+
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
6927
|
+
const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
6928
|
+
const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
6929
|
+
const lhs2 = lhs.transpose([
|
|
6930
|
+
...lb,
|
|
6931
|
+
...lf,
|
|
6932
|
+
...lc
|
|
6933
|
+
]);
|
|
6934
|
+
const rhs2 = rhs.transpose([
|
|
6935
|
+
...rb,
|
|
6936
|
+
...rf,
|
|
6937
|
+
...rc
|
|
6938
|
+
]);
|
|
6939
|
+
if (lc.length === 0) return mul(lhs2.reshape([
|
|
6940
|
+
...lb.map((a) => lhs.shape[a]),
|
|
6941
|
+
...lf.map((a) => lhs.shape[a]),
|
|
6942
|
+
...rep(rf.length, 1)
|
|
6943
|
+
]), rhs2.reshape([
|
|
6944
|
+
...rb.map((a) => rhs.shape[a]),
|
|
6945
|
+
...rep(lf.length, 1),
|
|
6946
|
+
...rf.map((a) => rhs.shape[a])
|
|
6947
|
+
]));
|
|
6948
|
+
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
6949
|
+
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
6950
|
+
if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
6951
|
+
return dot$2(lhs2.reshape([
|
|
6952
|
+
...lb.map((a) => lhs.shape[a]),
|
|
6953
|
+
...lf.map((a) => lhs.shape[a]),
|
|
6954
|
+
...rep(rf.length, 1),
|
|
6955
|
+
prod(dotShapeX)
|
|
6956
|
+
]), rhs2.reshape([
|
|
6957
|
+
...rb.map((a) => rhs.shape[a]),
|
|
6958
|
+
...rep(lf.length, 1),
|
|
6959
|
+
...rf.map((a) => rhs.shape[a]),
|
|
6960
|
+
prod(dotShapeY)
|
|
6961
|
+
]));
|
|
6962
|
+
}
|
|
6963
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
6964
|
+
const padType = padding.toUpperCase();
|
|
6965
|
+
switch (padType) {
|
|
6966
|
+
case "VALID": return rep(inShape.length, [0, 0]);
|
|
6967
|
+
case "SAME":
|
|
6968
|
+
case "SAME_LOWER": {
|
|
6969
|
+
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
6970
|
+
const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
6971
|
+
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
6972
|
+
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
6973
|
+
}
|
|
6974
|
+
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
6975
|
+
}
|
|
5870
6976
|
}
|
|
5871
6977
|
/**
|
|
5872
|
-
*
|
|
6978
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
5873
6979
|
*
|
|
5874
|
-
* The
|
|
5875
|
-
*
|
|
6980
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6981
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
5876
6982
|
*
|
|
5877
|
-
*
|
|
5878
|
-
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
6983
|
+
* Grouped convolutions are not supported right now.
|
|
5879
6984
|
*/
|
|
5880
|
-
function
|
|
5881
|
-
|
|
6985
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6986
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
6987
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
6988
|
+
if (typeof padding === "string") {
|
|
6989
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
6990
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
6991
|
+
}
|
|
6992
|
+
if (featureGroupCount !== 1) {
|
|
6993
|
+
const G = featureGroupCount;
|
|
6994
|
+
const [N, C_in, ...xs] = lhs.shape;
|
|
6995
|
+
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
6996
|
+
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
6997
|
+
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
6998
|
+
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
6999
|
+
const lhsGrouped = moveaxis(lhs.reshape([
|
|
7000
|
+
N,
|
|
7001
|
+
G,
|
|
7002
|
+
C_in / G,
|
|
7003
|
+
...xs
|
|
7004
|
+
]), 1, 0);
|
|
7005
|
+
const rhsGrouped = rhs.reshape([
|
|
7006
|
+
G,
|
|
7007
|
+
C_out / G,
|
|
7008
|
+
C_in_per_group,
|
|
7009
|
+
...ks
|
|
7010
|
+
]);
|
|
7011
|
+
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
7012
|
+
vmapDims: 1,
|
|
7013
|
+
strides: windowStrides,
|
|
7014
|
+
padding,
|
|
7015
|
+
lhsDilation,
|
|
7016
|
+
rhsDilation
|
|
7017
|
+
});
|
|
7018
|
+
const ys = result.shape.slice(3);
|
|
7019
|
+
return moveaxis(result, 0, 1).reshape([
|
|
7020
|
+
N,
|
|
7021
|
+
C_out,
|
|
7022
|
+
...ys
|
|
7023
|
+
]);
|
|
7024
|
+
}
|
|
7025
|
+
return conv$1(lhs, rhs, {
|
|
7026
|
+
strides: windowStrides,
|
|
7027
|
+
padding,
|
|
7028
|
+
lhsDilation,
|
|
7029
|
+
rhsDilation
|
|
7030
|
+
});
|
|
5882
7031
|
}
|
|
5883
|
-
/**
|
|
5884
|
-
function
|
|
5885
|
-
|
|
5886
|
-
|
|
7032
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
7033
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
7034
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
7035
|
+
lhsDilation,
|
|
7036
|
+
rhsDilation
|
|
7037
|
+
});
|
|
5887
7038
|
}
|
|
5888
|
-
/**
|
|
5889
|
-
function
|
|
5890
|
-
|
|
5891
|
-
return isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
|
|
7039
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
7040
|
+
function conv(lhs, rhs, windowStrides, padding) {
|
|
7041
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
5892
7042
|
}
|
|
5893
|
-
/**
|
|
5894
|
-
function
|
|
5895
|
-
|
|
5896
|
-
|
|
7043
|
+
/** Reduce a computation over padded windows. */
|
|
7044
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
7045
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
7046
|
+
if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
|
|
7047
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
7048
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
7049
|
+
window: windowDimensions,
|
|
7050
|
+
strides: windowStrides
|
|
7051
|
+
}));
|
|
5897
7052
|
}
|
|
5898
|
-
/**
|
|
5899
|
-
function
|
|
5900
|
-
|
|
5901
|
-
return isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
7053
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
7054
|
+
function erf(x) {
|
|
7055
|
+
return erf$1(x);
|
|
5902
7056
|
}
|
|
5903
7057
|
/**
|
|
5904
|
-
*
|
|
5905
|
-
*
|
|
7058
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
7059
|
+
*
|
|
7060
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
7061
|
+
* where `erf(x)` is very close to 1.
|
|
5906
7062
|
*/
|
|
5907
|
-
|
|
5908
|
-
|
|
5909
|
-
|
|
5910
|
-
|
|
7063
|
+
function erfc(x) {
|
|
7064
|
+
return erfc$1(x);
|
|
7065
|
+
}
|
|
7066
|
+
/**
|
|
7067
|
+
* Stops gradient computation.
|
|
7068
|
+
*
|
|
7069
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
7070
|
+
* forward or reverse-mode automatic differentiation.
|
|
7071
|
+
*/
|
|
7072
|
+
function stopGradient$1(x) {
|
|
7073
|
+
return stopGradient(x);
|
|
7074
|
+
}
|
|
5911
7075
|
|
|
5912
7076
|
//#endregion
|
|
5913
7077
|
//#region src/library/nn.ts
|
|
@@ -5917,6 +7081,10 @@ __export(nn_exports, {
|
|
|
5917
7081
|
elu: () => elu,
|
|
5918
7082
|
gelu: () => gelu,
|
|
5919
7083
|
glu: () => glu,
|
|
7084
|
+
hardSigmoid: () => hardSigmoid,
|
|
7085
|
+
hardSilu: () => hardSilu,
|
|
7086
|
+
hardSwish: () => hardSilu,
|
|
7087
|
+
hardTanh: () => hardTanh,
|
|
5920
7088
|
identity: () => identity,
|
|
5921
7089
|
leakyRelu: () => leakyRelu,
|
|
5922
7090
|
logSigmoid: () => logSigmoid,
|
|
@@ -5927,14 +7095,17 @@ __export(nn_exports, {
|
|
|
5927
7095
|
oneHot: () => oneHot,
|
|
5928
7096
|
relu: () => relu,
|
|
5929
7097
|
relu6: () => relu6,
|
|
7098
|
+
selu: () => selu,
|
|
5930
7099
|
sigmoid: () => sigmoid,
|
|
5931
7100
|
silu: () => silu,
|
|
5932
7101
|
softSign: () => softSign,
|
|
5933
7102
|
softmax: () => softmax,
|
|
5934
7103
|
softplus: () => softplus,
|
|
7104
|
+
sparsePlus: () => sparsePlus,
|
|
7105
|
+
sparseSigmoid: () => sparseSigmoid,
|
|
5935
7106
|
squareplus: () => squareplus,
|
|
5936
7107
|
standardize: () => standardize,
|
|
5937
|
-
swish: () =>
|
|
7108
|
+
swish: () => silu
|
|
5938
7109
|
});
|
|
5939
7110
|
/**
|
|
5940
7111
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -5969,6 +7140,28 @@ function softplus(x) {
|
|
|
5969
7140
|
return log(exp(x).add(1));
|
|
5970
7141
|
}
|
|
5971
7142
|
/**
|
|
7143
|
+
* @function
|
|
7144
|
+
* Sparse plus function:
|
|
7145
|
+
*
|
|
7146
|
+
* - When `x <= -1`: `0`
|
|
7147
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
7148
|
+
* - When `x >= 1`: `x`
|
|
7149
|
+
*/
|
|
7150
|
+
const sparsePlus = jit$1((x) => {
|
|
7151
|
+
return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
|
|
7152
|
+
});
|
|
7153
|
+
/**
|
|
7154
|
+
* @function
|
|
7155
|
+
* Sparse sigmoid activation function.
|
|
7156
|
+
*
|
|
7157
|
+
* - When `x <= -1`: `0`
|
|
7158
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
7159
|
+
* - When `x >= 1`: `1`
|
|
7160
|
+
*/
|
|
7161
|
+
const sparseSigmoid = jit$1((x) => {
|
|
7162
|
+
return clip(x.add(1).mul(.5), 0, 1);
|
|
7163
|
+
});
|
|
7164
|
+
/**
|
|
5972
7165
|
* Soft-sign activation function, computed element-wise:
|
|
5973
7166
|
* `softsign(x) = x / (|x| + 1)`.
|
|
5974
7167
|
*/
|
|
@@ -5990,17 +7183,6 @@ const silu = jit$1(function silu$1(x) {
|
|
|
5990
7183
|
return x.ref.mul(sigmoid(x));
|
|
5991
7184
|
});
|
|
5992
7185
|
/**
|
|
5993
|
-
* @function
|
|
5994
|
-
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
5995
|
-
* Swish, computed element-wise:
|
|
5996
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
5997
|
-
*
|
|
5998
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
5999
|
-
*
|
|
6000
|
-
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
6001
|
-
*/
|
|
6002
|
-
const swish = silu;
|
|
6003
|
-
/**
|
|
6004
7186
|
* Log-sigmoid activation function, computed element-wise:
|
|
6005
7187
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
6006
7188
|
*/
|
|
@@ -6017,6 +7199,19 @@ function leakyRelu(x, negativeSlope = .01) {
|
|
|
6017
7199
|
x = fudgeArray(x);
|
|
6018
7200
|
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
6019
7201
|
}
|
|
7202
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
7203
|
+
function hardSigmoid(x) {
|
|
7204
|
+
return relu6(add(x, 3)).mul(1 / 6);
|
|
7205
|
+
}
|
|
7206
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
7207
|
+
function hardSilu(x) {
|
|
7208
|
+
x = fudgeArray(x);
|
|
7209
|
+
return x.ref.mul(hardSigmoid(x));
|
|
7210
|
+
}
|
|
7211
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
7212
|
+
function hardTanh(x) {
|
|
7213
|
+
return clip(x, -1, 1);
|
|
7214
|
+
}
|
|
6020
7215
|
/**
|
|
6021
7216
|
* Exponential linear unit activation function.
|
|
6022
7217
|
*
|
|
@@ -6039,6 +7234,20 @@ function celu(x, alpha = 1) {
|
|
|
6039
7234
|
}
|
|
6040
7235
|
/**
|
|
6041
7236
|
* @function
|
|
7237
|
+
* Scaled exponential linear unit activation.
|
|
7238
|
+
*
|
|
7239
|
+
* Computes the element-wise function:
|
|
7240
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
7241
|
+
*
|
|
7242
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
7243
|
+
*/
|
|
7244
|
+
const selu = jit$1(function selu$1(x) {
|
|
7245
|
+
const alpha = 1.6732632423543772;
|
|
7246
|
+
const lambda = 1.0507009873554805;
|
|
7247
|
+
return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
|
|
7248
|
+
});
|
|
7249
|
+
/**
|
|
7250
|
+
* @function
|
|
6042
7251
|
* Gaussion error linear unit (GELU) activation function.
|
|
6043
7252
|
*
|
|
6044
7253
|
* This is computed element-wise. There are two variants depending on whether
|
|
@@ -6192,35 +7401,46 @@ var random_exports = {};
|
|
|
6192
7401
|
__export(random_exports, {
|
|
6193
7402
|
bernoulli: () => bernoulli,
|
|
6194
7403
|
bits: () => bits,
|
|
7404
|
+
cauchy: () => cauchy,
|
|
6195
7405
|
exponential: () => exponential,
|
|
7406
|
+
gumbel: () => gumbel,
|
|
6196
7407
|
key: () => key,
|
|
7408
|
+
laplace: () => laplace,
|
|
7409
|
+
multivariateNormal: () => multivariateNormal,
|
|
6197
7410
|
normal: () => normal,
|
|
6198
7411
|
split: () => split,
|
|
6199
7412
|
uniform: () => uniform
|
|
6200
7413
|
});
|
|
6201
|
-
function validateKeyShape(key$1) {
|
|
7414
|
+
function validateKeyShape(key$1, scalar = false) {
|
|
6202
7415
|
if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
|
|
6203
7416
|
if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
|
|
7417
|
+
if (scalar && key$1.shape.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
|
|
6204
7418
|
return key$1.shape.slice(0, -1);
|
|
6205
7419
|
}
|
|
7420
|
+
function getK01(key$1) {
|
|
7421
|
+
const keyShape = validateKeyShape(key$1, true);
|
|
7422
|
+
let [k0, k1] = split$2(key$1, -1, [1, 1]);
|
|
7423
|
+
k0 = k0.reshape(keyShape);
|
|
7424
|
+
k1 = k1.reshape(keyShape);
|
|
7425
|
+
return [k0, k1];
|
|
7426
|
+
}
|
|
6206
7427
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
6207
7428
|
function key(seed) {
|
|
6208
|
-
seed = seed
|
|
6209
|
-
|
|
7429
|
+
seed = array(seed, { dtype: DType.Uint32 });
|
|
7430
|
+
if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
|
|
7431
|
+
return stack([0, seed]);
|
|
6210
7432
|
}
|
|
6211
7433
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
6212
7434
|
function split(key$1, num = 2) {
|
|
6213
7435
|
const shape$1 = typeof num === "number" ? [num] : num;
|
|
6214
7436
|
for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
|
|
6215
|
-
const
|
|
6216
|
-
const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
|
|
6217
|
-
const k1 = key$1.slice(...keyShape.map(() => null), 1);
|
|
7437
|
+
const [k0, k1] = getK01(key$1);
|
|
6218
7438
|
return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
|
|
6219
7439
|
}
|
|
6220
7440
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
6221
7441
|
function bits(key$1, shape$1 = []) {
|
|
6222
|
-
const
|
|
6223
|
-
return randomBits(
|
|
7442
|
+
const [k0, k1] = getK01(key$1);
|
|
7443
|
+
return randomBits(k0, k1, shape$1);
|
|
6224
7444
|
}
|
|
6225
7445
|
/**
|
|
6226
7446
|
* @function
|
|
@@ -6252,6 +7472,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
6252
7472
|
}
|
|
6253
7473
|
/**
|
|
6254
7474
|
* @function
|
|
7475
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7476
|
+
*
|
|
7477
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
7478
|
+
*/
|
|
7479
|
+
const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
7480
|
+
const u = uniform(key$1, shape$1);
|
|
7481
|
+
return tan(u.sub(.5).mul(Math.PI));
|
|
7482
|
+
}, { staticArgnums: [1] });
|
|
7483
|
+
/**
|
|
7484
|
+
* @function
|
|
6255
7485
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
6256
7486
|
*/
|
|
6257
7487
|
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
@@ -6260,6 +7490,56 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
6260
7490
|
}, { staticArgnums: [1] });
|
|
6261
7491
|
/**
|
|
6262
7492
|
* @function
|
|
7493
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
7494
|
+
*
|
|
7495
|
+
* Uses inverse transform sampling: `x = -log(-log(u))` where u ~ Uniform(0, 1).
|
|
7496
|
+
*/
|
|
7497
|
+
const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
|
|
7498
|
+
const u = uniform(key$1, shape$1);
|
|
7499
|
+
return negative(log(negative(log1p(negative(u)))));
|
|
7500
|
+
}, { staticArgnums: [1] });
|
|
7501
|
+
/**
|
|
7502
|
+
* @function
|
|
7503
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
7504
|
+
*
|
|
7505
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
7506
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
7507
|
+
*/
|
|
7508
|
+
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
7509
|
+
const u = uniform(key$1, shape$1);
|
|
7510
|
+
const centered = u.sub(.5);
|
|
7511
|
+
const s = sign(centered.ref);
|
|
7512
|
+
const absVal = absolute(centered);
|
|
7513
|
+
return s.mul(log1p(absVal.mul(-2)).mul(-1));
|
|
7514
|
+
}, { staticArgnums: [1] });
|
|
7515
|
+
/**
|
|
7516
|
+
* @function
|
|
7517
|
+
* Sample multivariate normal random values with given mean and covariance.
|
|
7518
|
+
*
|
|
7519
|
+
* The values are returned with the given shape, along with the final dimension
|
|
7520
|
+
* used to represent the n-dimensional multivariate normal factors.
|
|
7521
|
+
*
|
|
7522
|
+
* This uses Cholesky decomposition on the covariance matrix.
|
|
7523
|
+
*
|
|
7524
|
+
* - `key` - PRNG key
|
|
7525
|
+
* - `mean` - Mean vector of shape `[..., n]`
|
|
7526
|
+
* - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
|
|
7527
|
+
* - `shape` - Result batch shape, must be broadcastable with
|
|
7528
|
+
* `mean.shape[:-1]` and `cov.shape[:-2]`
|
|
7529
|
+
* @returns Random samples of shape `[...shape, n]`
|
|
7530
|
+
*/
|
|
7531
|
+
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
|
|
7532
|
+
mean$1 = fudgeArray(mean$1);
|
|
7533
|
+
cov$1 = fudgeArray(cov$1);
|
|
7534
|
+
const n = mean$1.shape[mean$1.ndim - 1];
|
|
7535
|
+
if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
|
|
7536
|
+
const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
|
|
7537
|
+
const L = cholesky(cov$1);
|
|
7538
|
+
const z = normal(key$1, outputShape);
|
|
7539
|
+
return einsum("...ij,...j->...i", L, z).add(mean$1);
|
|
7540
|
+
}, { staticArgnums: [3] });
|
|
7541
|
+
/**
|
|
7542
|
+
* @function
|
|
6263
7543
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6264
7544
|
*
|
|
6265
7545
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -6368,11 +7648,6 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
6368
7648
|
*/
|
|
6369
7649
|
const jacrev = jacrev$1;
|
|
6370
7650
|
/**
|
|
6371
|
-
* @function
|
|
6372
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
6373
|
-
*/
|
|
6374
|
-
const jacobian = jacrev;
|
|
6375
|
-
/**
|
|
6376
7651
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
6377
7652
|
*
|
|
6378
7653
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -6407,5 +7682,4 @@ async function devicePut(x, device) {
|
|
|
6407
7682
|
}
|
|
6408
7683
|
|
|
6409
7684
|
//#endregion
|
|
6410
|
-
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
6411
|
-
//# sourceMappingURL=index.js.map
|
|
7685
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|