@jax-js/jax 0.1.3 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +15 -9
- package/dist/{backend-BY8wlLEl.js → backend-DaqL-MNz.js} +240 -21
- package/dist/{backend-CmaidnkQ.cjs → backend-DziQSaoQ.cjs} +264 -21
- package/dist/index.cjs +2407 -1132
- package/dist/index.d.cts +596 -97
- package/dist/index.d.ts +596 -97
- package/dist/index.js +2400 -1126
- package/dist/webgl-ClIYb8jP.cjs +522 -0
- package/dist/webgl-RSuZKvgc.js +522 -0
- package/dist/webgpu-Db2JrNBr.cjs +1261 -0
- package/dist/webgpu-Dh7k9io0.js +1261 -0
- package/package.json +1 -1
- package/dist/webgpu-BVns4DbI.cjs +0 -663
- package/dist/webgpu-C9iAP5h5.js +0 -663
package/dist/index.cjs
CHANGED
|
@@ -8,9 +8,9 @@ var __hasOwnProp = Object.prototype.hasOwnProperty;
|
|
|
8
8
|
var __commonJS = (cb, mod$1) => function() {
|
|
9
9
|
return mod$1 || (0, cb[__getOwnPropNames(cb)[0]])((mod$1 = { exports: {} }).exports, mod$1), mod$1.exports;
|
|
10
10
|
};
|
|
11
|
-
var __export = (target, all) => {
|
|
12
|
-
for (var name in all) __defProp(target, name, {
|
|
13
|
-
get: all[name],
|
|
11
|
+
var __export = (target, all$1) => {
|
|
12
|
+
for (var name in all$1) __defProp(target, name, {
|
|
13
|
+
get: all$1[name],
|
|
14
14
|
enumerable: true
|
|
15
15
|
});
|
|
16
16
|
};
|
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-CmaidnkQ.cjs');
|
|
33
|
+
const require_backend = require('./backend-DziQSaoQ.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -362,6 +362,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
362
362
|
Primitive$1["Mul"] = "mul";
|
|
363
363
|
Primitive$1["Idiv"] = "idiv";
|
|
364
364
|
Primitive$1["Mod"] = "mod";
|
|
365
|
+
Primitive$1["Min"] = "min";
|
|
366
|
+
Primitive$1["Max"] = "max";
|
|
365
367
|
Primitive$1["Neg"] = "neg";
|
|
366
368
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
367
369
|
Primitive$1["Floor"] = "floor";
|
|
@@ -369,7 +371,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
369
371
|
Primitive$1["StopGradient"] = "stop_gradient";
|
|
370
372
|
Primitive$1["Cast"] = "cast";
|
|
371
373
|
Primitive$1["Bitcast"] = "bitcast";
|
|
372
|
-
Primitive$1["RandomBits"] = "random_bits";
|
|
373
374
|
Primitive$1["Sin"] = "sin";
|
|
374
375
|
Primitive$1["Cos"] = "cos";
|
|
375
376
|
Primitive$1["Asin"] = "asin";
|
|
@@ -379,8 +380,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
379
380
|
Primitive$1["Erf"] = "erf";
|
|
380
381
|
Primitive$1["Erfc"] = "erfc";
|
|
381
382
|
Primitive$1["Sqrt"] = "sqrt";
|
|
382
|
-
Primitive$1["Min"] = "min";
|
|
383
|
-
Primitive$1["Max"] = "max";
|
|
384
383
|
Primitive$1["Reduce"] = "reduce";
|
|
385
384
|
Primitive$1["Dot"] = "dot";
|
|
386
385
|
Primitive$1["Conv"] = "conv";
|
|
@@ -388,14 +387,22 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
388
387
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
389
388
|
Primitive$1["Compare"] = "compare";
|
|
390
389
|
Primitive$1["Where"] = "where";
|
|
390
|
+
Primitive$1["Concatenate"] = "concatenate";
|
|
391
|
+
Primitive$1["Split"] = "split";
|
|
392
|
+
Primitive$1["RandomBits"] = "random_bits";
|
|
393
|
+
Primitive$1["Gather"] = "gather";
|
|
391
394
|
Primitive$1["Transpose"] = "transpose";
|
|
392
395
|
Primitive$1["Broadcast"] = "broadcast";
|
|
393
396
|
Primitive$1["Reshape"] = "reshape";
|
|
394
397
|
Primitive$1["Flip"] = "flip";
|
|
395
398
|
Primitive$1["Shrink"] = "shrink";
|
|
396
399
|
Primitive$1["Pad"] = "pad";
|
|
397
|
-
Primitive$1["
|
|
398
|
-
Primitive$1["
|
|
400
|
+
Primitive$1["Sort"] = "sort";
|
|
401
|
+
Primitive$1["Argsort"] = "argsort";
|
|
402
|
+
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
403
|
+
Primitive$1["Cholesky"] = "cholesky";
|
|
404
|
+
Primitive$1["LU"] = "lu";
|
|
405
|
+
Primitive$1["Jit"] = "jit";
|
|
399
406
|
return Primitive$1;
|
|
400
407
|
}({});
|
|
401
408
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
@@ -417,6 +424,12 @@ function idiv(x, y) {
|
|
|
417
424
|
function mod(x, y) {
|
|
418
425
|
return bind1(Primitive.Mod, [x, y]);
|
|
419
426
|
}
|
|
427
|
+
function min$1(x, y) {
|
|
428
|
+
return bind1(Primitive.Min, [x, y]);
|
|
429
|
+
}
|
|
430
|
+
function max$1(x, y) {
|
|
431
|
+
return bind1(Primitive.Max, [x, y]);
|
|
432
|
+
}
|
|
420
433
|
function neg(x) {
|
|
421
434
|
return bind1(Primitive.Neg, [x]);
|
|
422
435
|
}
|
|
@@ -438,12 +451,6 @@ function cast(x, dtype) {
|
|
|
438
451
|
function bitcast(x, dtype) {
|
|
439
452
|
return bind1(Primitive.Bitcast, [x], { dtype });
|
|
440
453
|
}
|
|
441
|
-
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
442
|
-
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
443
|
-
shape: shape$1,
|
|
444
|
-
mode
|
|
445
|
-
});
|
|
446
|
-
}
|
|
447
454
|
function sin$1(x) {
|
|
448
455
|
return bind1(Primitive.Sin, [x]);
|
|
449
456
|
}
|
|
@@ -471,12 +478,6 @@ function erfc$1(x) {
|
|
|
471
478
|
function sqrt$1(x) {
|
|
472
479
|
return bind1(Primitive.Sqrt, [x]);
|
|
473
480
|
}
|
|
474
|
-
function min$1(x, y) {
|
|
475
|
-
return bind1(Primitive.Min, [x, y]);
|
|
476
|
-
}
|
|
477
|
-
function max$1(x, y) {
|
|
478
|
-
return bind1(Primitive.Max, [x, y]);
|
|
479
|
-
}
|
|
480
481
|
function reduce(x, op, axis = null, opts) {
|
|
481
482
|
if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
482
483
|
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
@@ -532,6 +533,41 @@ function where$1(cond, x, y) {
|
|
|
532
533
|
y
|
|
533
534
|
]);
|
|
534
535
|
}
|
|
536
|
+
function concatenate$1(xs, axis) {
|
|
537
|
+
if (xs.length === 0) throw new Error("concatenate requires at least one input");
|
|
538
|
+
const avals = xs.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
539
|
+
axis = require_backend.checkAxis(axis, avals[0].ndim);
|
|
540
|
+
for (const x of avals) if (x.ndim !== avals[0].ndim || !x.shape.every((s, i) => i === axis || s === avals[0].shape[i])) throw new Error(`Concatenate: inputs ${avals[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
541
|
+
return bind1(Primitive.Concatenate, xs, { axis });
|
|
542
|
+
}
|
|
543
|
+
function split$2(x, axis, sizes) {
|
|
544
|
+
axis = require_backend.checkAxis(axis, ndim$1(x));
|
|
545
|
+
if (sizes.some((s) => s < 0 || !Number.isInteger(s))) throw new Error(`split: sizes must be nonnegative integers, got ${JSON.stringify(sizes)}`);
|
|
546
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
547
|
+
if (totalSize !== getShape(x)[axis]) throw new Error(`split: sizes must sum to the size of the axis ${axis}, got ${totalSize}`);
|
|
548
|
+
return bind(Primitive.Split, [x], {
|
|
549
|
+
axis,
|
|
550
|
+
sizes
|
|
551
|
+
});
|
|
552
|
+
}
|
|
553
|
+
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
554
|
+
if (!require_backend.deepEqual(k0.shape, k1.shape) || k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new Error(`randomBits: key parts must be uint32 with the same shape, got ${ShapedArray.fromAval(k0.aval)} and ${ShapedArray.fromAval(k1.aval)}`);
|
|
555
|
+
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
556
|
+
shape: shape$1,
|
|
557
|
+
mode
|
|
558
|
+
});
|
|
559
|
+
}
|
|
560
|
+
function gather(x, indices, axis, outDim) {
|
|
561
|
+
if (indices.length === 0) throw new Error("gather() requires at least one index");
|
|
562
|
+
if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
|
|
563
|
+
axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
|
|
564
|
+
if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
|
|
565
|
+
outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
|
|
566
|
+
return bind1(Primitive.Gather, [x, ...indices], {
|
|
567
|
+
axis,
|
|
568
|
+
outDim
|
|
569
|
+
});
|
|
570
|
+
}
|
|
535
571
|
function transpose$1(x, perm) {
|
|
536
572
|
perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
|
|
537
573
|
if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
@@ -581,16 +617,39 @@ function pad$1(x, width) {
|
|
|
581
617
|
} else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
|
|
582
618
|
return bind1(Primitive.Pad, [x], { width });
|
|
583
619
|
}
|
|
584
|
-
function
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
}
|
|
620
|
+
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
621
|
+
const as = getShape(a);
|
|
622
|
+
const bs = getShape(b);
|
|
623
|
+
if (as.length < 2 || bs.length < 2) throw new Error(`triangular_solve: must be >=2D, got a=${as}, b=${bs}`);
|
|
624
|
+
const n = as[as.length - 2];
|
|
625
|
+
if (n !== as[as.length - 1] || n !== bs[bs.length - 1]) throw new Error(`triangular_solve: incompatible shapes a=${as}, b=${bs}`);
|
|
626
|
+
if (lower) {
|
|
627
|
+
a = flip$1(a, [-2, -1]);
|
|
628
|
+
b = flip$1(b, [-1]);
|
|
629
|
+
}
|
|
630
|
+
let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
631
|
+
if (lower) x = flip$1(x, [-1]);
|
|
632
|
+
return x;
|
|
633
|
+
}
|
|
634
|
+
function cholesky$2(x) {
|
|
635
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
636
|
+
if (aval.ndim < 2 || aval.shape[aval.ndim - 1] !== aval.shape[aval.ndim - 2]) throw new Error(`cholesky: expected batch of square matrices, got ${aval}`);
|
|
637
|
+
return bind1(Primitive.Cholesky, [x]);
|
|
638
|
+
}
|
|
639
|
+
function lu$1(x) {
|
|
640
|
+
const aval = ShapedArray.fromAval(getAval(x));
|
|
641
|
+
if (aval.ndim < 2) throw new Error(`lu: expected batch of matrices, got ${aval}`);
|
|
642
|
+
return bind(Primitive.LU, [x]);
|
|
643
|
+
}
|
|
644
|
+
function sort$1(x) {
|
|
645
|
+
const nd = ndim$1(x);
|
|
646
|
+
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
647
|
+
return bind1(Primitive.Sort, [x]);
|
|
648
|
+
}
|
|
649
|
+
function argsort$1(x) {
|
|
650
|
+
const nd = ndim$1(x);
|
|
651
|
+
if (nd === 0) throw new Error("argsort: requires at least 1D input");
|
|
652
|
+
return bind(Primitive.Argsort, [x]);
|
|
594
653
|
}
|
|
595
654
|
function bind1(prim, args, params = {}) {
|
|
596
655
|
const [results] = bind(prim, args, params);
|
|
@@ -690,6 +749,9 @@ var Tracer = class Tracer {
|
|
|
690
749
|
mul(other) {
|
|
691
750
|
return mul(this, other);
|
|
692
751
|
}
|
|
752
|
+
mod(other) {
|
|
753
|
+
return mod(this, other);
|
|
754
|
+
}
|
|
693
755
|
greater(other) {
|
|
694
756
|
return greater$1(this, other);
|
|
695
757
|
}
|
|
@@ -753,7 +815,7 @@ var Tracer = class Tracer {
|
|
|
753
815
|
if (require_backend.isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
|
|
754
816
|
return idiv(this, other);
|
|
755
817
|
}
|
|
756
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
818
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
757
819
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
758
820
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
759
821
|
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
@@ -802,8 +864,42 @@ var Tracer = class Tracer {
|
|
|
802
864
|
*/
|
|
803
865
|
*[Symbol.iterator]() {
|
|
804
866
|
if (this.ndim === 0) throw new Error("Cannot iterate over a scalar array");
|
|
805
|
-
|
|
806
|
-
this.
|
|
867
|
+
let residual = this;
|
|
868
|
+
const subarrayShape = this.shape.slice(1);
|
|
869
|
+
for (let i = 0; i < this.shape[0]; i++) {
|
|
870
|
+
const lr = split$2(residual, 0, [1, residual.shape[0] - 1]);
|
|
871
|
+
yield lr[0].reshape(subarrayShape);
|
|
872
|
+
residual = lr[1];
|
|
873
|
+
}
|
|
874
|
+
residual.dispose();
|
|
875
|
+
}
|
|
876
|
+
/**
|
|
877
|
+
* Return a sorted copy of an array in ascending order.
|
|
878
|
+
*
|
|
879
|
+
* See `jax.numpy.sort` for full docs.
|
|
880
|
+
*/
|
|
881
|
+
sort(axis = -1) {
|
|
882
|
+
axis = require_backend.checkAxis(axis, this.ndim);
|
|
883
|
+
if (this.shape[axis] <= 1) return this;
|
|
884
|
+
if (axis === this.ndim - 1) return sort$1(this);
|
|
885
|
+
const perm = require_backend.range(this.ndim);
|
|
886
|
+
perm.splice(axis, 1);
|
|
887
|
+
perm.push(axis);
|
|
888
|
+
return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
|
|
889
|
+
}
|
|
890
|
+
/**
|
|
891
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
892
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
893
|
+
*
|
|
894
|
+
* See `jax.numpy.argsort` for full docs.
|
|
895
|
+
*/
|
|
896
|
+
argsort(axis = -1) {
|
|
897
|
+
axis = require_backend.checkAxis(axis, this.ndim);
|
|
898
|
+
if (axis === this.ndim - 1) return argsort$1(this)[1];
|
|
899
|
+
const perm = require_backend.range(this.ndim);
|
|
900
|
+
perm.splice(axis, 1);
|
|
901
|
+
perm.push(axis);
|
|
902
|
+
return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
|
|
807
903
|
}
|
|
808
904
|
/**
|
|
809
905
|
* Slice an array along one or more axes.
|
|
@@ -922,6 +1018,12 @@ var ShapedArray = class ShapedArray {
|
|
|
922
1018
|
get ndim() {
|
|
923
1019
|
return this.shape.length;
|
|
924
1020
|
}
|
|
1021
|
+
get size() {
|
|
1022
|
+
return require_backend.prod(this.shape);
|
|
1023
|
+
}
|
|
1024
|
+
scalar() {
|
|
1025
|
+
return new ShapedArray([], this.dtype, this.weakType);
|
|
1026
|
+
}
|
|
925
1027
|
toString() {
|
|
926
1028
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
927
1029
|
}
|
|
@@ -1221,13 +1323,13 @@ var Jaxpr = class Jaxpr {
|
|
|
1221
1323
|
}
|
|
1222
1324
|
return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
|
|
1223
1325
|
}
|
|
1224
|
-
/** Flattens nested
|
|
1326
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1225
1327
|
flatten() {
|
|
1226
|
-
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.
|
|
1328
|
+
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
|
|
1227
1329
|
const newEqns = [];
|
|
1228
1330
|
const varMap = /* @__PURE__ */ new Map();
|
|
1229
1331
|
const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
|
|
1230
|
-
for (const eqn of this.eqns) if (eqn.primitive === Primitive.
|
|
1332
|
+
for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
|
|
1231
1333
|
const jaxpr = eqn.params.jaxpr.flatten();
|
|
1232
1334
|
const translation = /* @__PURE__ */ new Map();
|
|
1233
1335
|
const translationF = (x) => x instanceof Var ? translation.get(x) : x;
|
|
@@ -1328,19 +1430,48 @@ function evalJaxpr(jaxpr, args) {
|
|
|
1328
1430
|
function jaxprAsFun(jaxpr) {
|
|
1329
1431
|
return (...args) => evalJaxpr(jaxpr, args);
|
|
1330
1432
|
}
|
|
1433
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1434
|
+
var ClosedJaxpr = class ClosedJaxpr {
|
|
1435
|
+
constructor(jaxpr, consts) {
|
|
1436
|
+
this.jaxpr = jaxpr;
|
|
1437
|
+
this.consts = consts;
|
|
1438
|
+
}
|
|
1439
|
+
/** String representation of this Jaxpr. */
|
|
1440
|
+
toString() {
|
|
1441
|
+
return this.jaxpr.toString();
|
|
1442
|
+
}
|
|
1443
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1444
|
+
mapJaxpr(f) {
|
|
1445
|
+
return new ClosedJaxpr(f(this.jaxpr), this.consts);
|
|
1446
|
+
}
|
|
1447
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1448
|
+
dispose() {
|
|
1449
|
+
for (const c of this.consts) c.dispose();
|
|
1450
|
+
}
|
|
1451
|
+
};
|
|
1331
1452
|
/** Tracer that records its operations to dynamically construct a Jaxpr. */
|
|
1332
1453
|
var JaxprTracer = class extends Tracer {
|
|
1454
|
+
#rc;
|
|
1333
1455
|
constructor(trace$1, aval) {
|
|
1334
1456
|
super(trace$1);
|
|
1335
1457
|
this.aval = aval;
|
|
1458
|
+
this.#rc = 1;
|
|
1336
1459
|
}
|
|
1337
1460
|
toString() {
|
|
1338
1461
|
return `JaxprTracer(${this.aval.toString()})`;
|
|
1339
1462
|
}
|
|
1340
1463
|
get ref() {
|
|
1464
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1465
|
+
this.#rc++;
|
|
1341
1466
|
return this;
|
|
1342
1467
|
}
|
|
1343
|
-
dispose() {
|
|
1468
|
+
dispose() {
|
|
1469
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1470
|
+
this.#rc--;
|
|
1471
|
+
}
|
|
1472
|
+
trackLiftedConstant() {
|
|
1473
|
+
this.#rc++;
|
|
1474
|
+
}
|
|
1344
1475
|
};
|
|
1345
1476
|
/** Analogous to the 'DynamicJaxprTrace' class in JAX. */
|
|
1346
1477
|
var JaxprTrace = class extends Trace {
|
|
@@ -1353,17 +1484,24 @@ var JaxprTrace = class extends Trace {
|
|
|
1353
1484
|
}
|
|
1354
1485
|
/** Register a constant / literal in this Jaxpr. */
|
|
1355
1486
|
getOrMakeConstTracer(val) {
|
|
1487
|
+
if (!(val instanceof Tracer)) val = pureArray(val);
|
|
1356
1488
|
let tracer = this.builder.constTracers.get(val);
|
|
1357
1489
|
if (tracer === void 0) {
|
|
1358
1490
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
1359
|
-
this.builder.addConst(tracer, val
|
|
1491
|
+
this.builder.addConst(tracer, val);
|
|
1492
|
+
} else {
|
|
1493
|
+
val.dispose();
|
|
1494
|
+
tracer.trackLiftedConstant();
|
|
1360
1495
|
}
|
|
1361
1496
|
return tracer;
|
|
1362
1497
|
}
|
|
1363
1498
|
pure = this.getOrMakeConstTracer;
|
|
1364
1499
|
lift = this.getOrMakeConstTracer;
|
|
1365
1500
|
processPrimitive(primitive, tracers, params) {
|
|
1366
|
-
const avalsIn = tracers.map((t) =>
|
|
1501
|
+
const avalsIn = tracers.map((t) => {
|
|
1502
|
+
t.dispose();
|
|
1503
|
+
return t.aval;
|
|
1504
|
+
});
|
|
1367
1505
|
const avalsOut = abstractEvalRules[primitive](avalsIn, params);
|
|
1368
1506
|
const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
|
|
1369
1507
|
this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
|
|
@@ -1406,20 +1544,17 @@ var JaxprBuilder = class {
|
|
|
1406
1544
|
return v;
|
|
1407
1545
|
}
|
|
1408
1546
|
build(inTracers, outTracers) {
|
|
1409
|
-
|
|
1547
|
+
const [constVars, consts] = require_backend.unzip2(this.constVals.entries());
|
|
1410
1548
|
const t2v = this.getVar.bind(this);
|
|
1411
1549
|
const inBinders = [...constVars, ...inTracers.map(t2v)];
|
|
1412
1550
|
const outVars = outTracers.map(t2v);
|
|
1413
|
-
|
|
1551
|
+
const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
|
|
1414
1552
|
typecheckJaxpr(jaxpr);
|
|
1415
|
-
|
|
1416
|
-
return
|
|
1417
|
-
jaxpr,
|
|
1418
|
-
consts
|
|
1419
|
-
};
|
|
1553
|
+
const cjaxpr = new ClosedJaxpr(jaxpr, consts);
|
|
1554
|
+
return _inlineLiterals(cjaxpr);
|
|
1420
1555
|
}
|
|
1421
1556
|
};
|
|
1422
|
-
function _inlineLiterals(jaxpr, consts) {
|
|
1557
|
+
function _inlineLiterals({ jaxpr, consts }) {
|
|
1423
1558
|
const literals = /* @__PURE__ */ new Map();
|
|
1424
1559
|
const constBinders = [];
|
|
1425
1560
|
const newConsts = [];
|
|
@@ -1434,7 +1569,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
1434
1569
|
const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
|
|
1435
1570
|
const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
|
|
1436
1571
|
typecheckJaxpr(newJaxpr);
|
|
1437
|
-
return
|
|
1572
|
+
return new ClosedJaxpr(newJaxpr, newConsts);
|
|
1438
1573
|
}
|
|
1439
1574
|
function binopAbstractEval([x, y]) {
|
|
1440
1575
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
@@ -1453,6 +1588,8 @@ const abstractEvalRules = {
|
|
|
1453
1588
|
[Primitive.Mul]: binopAbstractEval,
|
|
1454
1589
|
[Primitive.Idiv]: binopAbstractEval,
|
|
1455
1590
|
[Primitive.Mod]: binopAbstractEval,
|
|
1591
|
+
[Primitive.Min]: binopAbstractEval,
|
|
1592
|
+
[Primitive.Max]: binopAbstractEval,
|
|
1456
1593
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1457
1594
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1458
1595
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1466,12 +1603,6 @@ const abstractEvalRules = {
|
|
|
1466
1603
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1467
1604
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1468
1605
|
},
|
|
1469
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1470
|
-
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1471
|
-
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1472
|
-
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1473
|
-
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1474
|
-
},
|
|
1475
1606
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
1476
1607
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
1477
1608
|
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
@@ -1481,8 +1612,6 @@ const abstractEvalRules = {
|
|
|
1481
1612
|
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
1482
1613
|
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
1483
1614
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
1484
|
-
[Primitive.Min]: binopAbstractEval,
|
|
1485
|
-
[Primitive.Max]: binopAbstractEval,
|
|
1486
1615
|
[Primitive.Reduce]([x], { axis }) {
|
|
1487
1616
|
const axisSet = new Set(axis);
|
|
1488
1617
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1504,7 +1633,7 @@ const abstractEvalRules = {
|
|
|
1504
1633
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1505
1634
|
},
|
|
1506
1635
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
1507
|
-
const { dtype, weakType } = promoteAvals(
|
|
1636
|
+
const { dtype, weakType } = promoteAvals(lhs.scalar(), rhs.scalar());
|
|
1508
1637
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
1509
1638
|
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1510
1639
|
},
|
|
@@ -1515,6 +1644,40 @@ const abstractEvalRules = {
|
|
|
1515
1644
|
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
1516
1645
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1517
1646
|
},
|
|
1647
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
1648
|
+
if (xs.length === 0) throw new TypeError("Concatenate requires at least one input");
|
|
1649
|
+
for (const x of xs) if (x.ndim !== xs[0].ndim || !x.shape.every((s, i) => i === axis || s === xs[0].shape[i])) throw new TypeError(`Concatenate: inputs ${xs[0]} and ${x} must match shapes except on axis ${axis}`);
|
|
1650
|
+
const shape$1 = xs[0].shape.slice();
|
|
1651
|
+
shape$1[axis] = xs.reduce((sum$1, x) => sum$1 + x.shape[axis], 0);
|
|
1652
|
+
const { dtype, weakType } = xs.map((x) => x.scalar()).reduce(promoteAvals);
|
|
1653
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
1654
|
+
},
|
|
1655
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
1656
|
+
const totalSize = sizes.reduce((a, b) => a + b, 0);
|
|
1657
|
+
if (x.shape[axis] !== totalSize) throw new TypeError(`Split: sizes ${sizes} do not sum to dimension ${x.shape[axis]} on axis ${axis}`);
|
|
1658
|
+
return sizes.map((size$1) => {
|
|
1659
|
+
return new ShapedArray(x.shape.toSpliced(axis, 1, size$1), x.dtype, x.weakType);
|
|
1660
|
+
});
|
|
1661
|
+
},
|
|
1662
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1663
|
+
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1664
|
+
if (!require_backend.deepEqual(k0.shape, k1.shape)) throw new TypeError(`RandomBits: Keys have different shapes ${k0.shape} and ${k1.shape}`);
|
|
1665
|
+
if (!require_backend.deepEqual(shape$1.slice(0, k0.ndim), k0.shape)) throw new TypeError(`RandomBits: generated shape ${shape$1} must match key shape ${k0.shape}`);
|
|
1666
|
+
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1667
|
+
},
|
|
1668
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
1669
|
+
for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
1670
|
+
if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
|
|
1671
|
+
if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
|
|
1672
|
+
if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
|
|
1673
|
+
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
1674
|
+
const axisSet = new Set(axis);
|
|
1675
|
+
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
1676
|
+
const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
|
|
1677
|
+
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
1678
|
+
newShape.splice(outDim, 0, ...gatherShape);
|
|
1679
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1680
|
+
},
|
|
1518
1681
|
[Primitive.Transpose]([x], { perm }) {
|
|
1519
1682
|
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
1520
1683
|
},
|
|
@@ -1535,23 +1698,41 @@ const abstractEvalRules = {
|
|
|
1535
1698
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
1536
1699
|
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1537
1700
|
},
|
|
1538
|
-
[Primitive.
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
if (
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1701
|
+
[Primitive.Sort]([x]) {
|
|
1702
|
+
if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
|
|
1703
|
+
return [ShapedArray.fromAval(x)];
|
|
1704
|
+
},
|
|
1705
|
+
[Primitive.Argsort]([x]) {
|
|
1706
|
+
if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
|
|
1707
|
+
return [ShapedArray.fromAval(x), new ShapedArray(x.shape, require_backend.DType.Int32, false)];
|
|
1708
|
+
},
|
|
1709
|
+
[Primitive.TriangularSolve]([a, b]) {
|
|
1710
|
+
if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
|
|
1711
|
+
if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
|
|
1712
|
+
const [m, n] = a.shape.slice(-2);
|
|
1713
|
+
const [_batch, q] = b.shape.slice(-2);
|
|
1714
|
+
if (!require_backend.deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
|
|
1715
|
+
return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
|
|
1716
|
+
},
|
|
1717
|
+
[Primitive.Cholesky]([a]) {
|
|
1718
|
+
if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
|
|
1719
|
+
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1720
|
+
return [ShapedArray.fromAval(a)];
|
|
1721
|
+
},
|
|
1722
|
+
[Primitive.LU]([a]) {
|
|
1723
|
+
if (a.ndim < 2) throw new TypeError(`lu: requires at least 2D input, got ${a}`);
|
|
1724
|
+
const batch = a.shape.slice(0, -2);
|
|
1725
|
+
const [m, n] = a.shape.slice(-2);
|
|
1726
|
+
return [
|
|
1727
|
+
ShapedArray.fromAval(a),
|
|
1728
|
+
new ShapedArray([...batch, Math.min(m, n)], require_backend.DType.Int32, false),
|
|
1729
|
+
new ShapedArray([...batch, m], require_backend.DType.Int32, false)
|
|
1730
|
+
];
|
|
1550
1731
|
},
|
|
1551
|
-
[Primitive.
|
|
1732
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
1552
1733
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1553
|
-
if (args.length !== inTypes.length) throw new TypeError(`
|
|
1554
|
-
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`
|
|
1734
|
+
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
1735
|
+
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
|
|
1555
1736
|
return outTypes;
|
|
1556
1737
|
}
|
|
1557
1738
|
};
|
|
@@ -1587,11 +1768,10 @@ function makeJaxpr$1(f, opts) {
|
|
|
1587
1768
|
const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
1588
1769
|
const outs = fFlat(...tracersIn);
|
|
1589
1770
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
1590
|
-
const
|
|
1771
|
+
const jaxpr = builder.build(tracersIn, tracersOut);
|
|
1591
1772
|
if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
|
|
1592
1773
|
return {
|
|
1593
|
-
jaxpr: jaxpr.simplify(),
|
|
1594
|
-
consts,
|
|
1774
|
+
jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
|
|
1595
1775
|
treedef: outTree.value
|
|
1596
1776
|
};
|
|
1597
1777
|
} catch (_) {
|
|
@@ -1610,22 +1790,29 @@ function jit$1(f, opts) {
|
|
|
1610
1790
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
1611
1791
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
1612
1792
|
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
1613
|
-
const { jaxpr,
|
|
1614
|
-
const outs = bind(Primitive.
|
|
1793
|
+
const { jaxpr, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
1794
|
+
const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
|
|
1615
1795
|
name: f.name || "closure",
|
|
1616
|
-
jaxpr,
|
|
1617
|
-
numConsts: consts.length
|
|
1796
|
+
jaxpr: jaxpr.jaxpr,
|
|
1797
|
+
numConsts: jaxpr.consts.length
|
|
1618
1798
|
});
|
|
1619
1799
|
return unflatten(outTree, outs);
|
|
1620
1800
|
});
|
|
1621
1801
|
result.dispose = () => {
|
|
1622
|
-
for (const {
|
|
1802
|
+
for (const { jaxpr } of cache.values()) jaxpr.dispose();
|
|
1623
1803
|
};
|
|
1624
1804
|
return result;
|
|
1625
1805
|
}
|
|
1626
1806
|
|
|
1627
1807
|
//#endregion
|
|
1628
1808
|
//#region src/frontend/jit.ts
|
|
1809
|
+
const routinePrimitives = new Map([
|
|
1810
|
+
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1811
|
+
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1812
|
+
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1813
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky],
|
|
1814
|
+
[Primitive.LU, require_backend.Routines.LU]
|
|
1815
|
+
]);
|
|
1629
1816
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1630
1817
|
var JitProgram = class {
|
|
1631
1818
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1640,9 +1827,14 @@ var JitProgram = class {
|
|
|
1640
1827
|
case "execute": {
|
|
1641
1828
|
const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
|
|
1642
1829
|
const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
|
|
1643
|
-
|
|
1830
|
+
const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
|
|
1831
|
+
if (step.source instanceof require_backend.Kernel) return require_backend.PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
|
|
1832
|
+
else if (step.source instanceof require_backend.Routine) return require_backend.PPrint.pp(`${executeText}, routine ${step.source.name}`);
|
|
1833
|
+
else {
|
|
1834
|
+
step.source;
|
|
1835
|
+
return require_backend.PPrint.pp(executeText);
|
|
1836
|
+
}
|
|
1644
1837
|
}
|
|
1645
|
-
case "const": return require_backend.PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
|
|
1646
1838
|
case "malloc": return require_backend.PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
|
|
1647
1839
|
case "incref": return require_backend.PPrint.pp(`incref ${step.input}`);
|
|
1648
1840
|
case "free": return require_backend.PPrint.pp(`free ${step.input}`);
|
|
@@ -1665,12 +1857,9 @@ var JitProgram = class {
|
|
|
1665
1857
|
const inputs$1 = step.inputs.map((id) => scope.get(id));
|
|
1666
1858
|
const outputs = step.outputs.map((id) => scope.get(id));
|
|
1667
1859
|
if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
|
|
1668
|
-
pending.push(new PendingExecute(this.backend, step.
|
|
1860
|
+
pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
|
|
1669
1861
|
break;
|
|
1670
1862
|
}
|
|
1671
|
-
case "const":
|
|
1672
|
-
scope.set(step.output, step.slot);
|
|
1673
|
-
break;
|
|
1674
1863
|
case "malloc": {
|
|
1675
1864
|
const slot = this.backend.malloc(step.size);
|
|
1676
1865
|
scope.set(step.output, slot);
|
|
@@ -1704,34 +1893,37 @@ var JitProgramBuilder = class {
|
|
|
1704
1893
|
this.#nextId = nargs;
|
|
1705
1894
|
this.steps = [];
|
|
1706
1895
|
}
|
|
1707
|
-
pushConst(slot) {
|
|
1708
|
-
const id = this.#nextId++;
|
|
1709
|
-
this.steps.push({
|
|
1710
|
-
type: "const",
|
|
1711
|
-
slot,
|
|
1712
|
-
output: id
|
|
1713
|
-
});
|
|
1714
|
-
return id;
|
|
1715
|
-
}
|
|
1716
1896
|
pushLit(lit) {
|
|
1717
|
-
const kernel = new require_backend.Kernel(0,
|
|
1897
|
+
const kernel = new require_backend.Kernel(0, lit.aval.size, require_backend.AluExp.const(lit.dtype, lit.value));
|
|
1718
1898
|
return this.pushKernel(kernel, []);
|
|
1719
1899
|
}
|
|
1720
|
-
|
|
1900
|
+
pushBuffer(size$1) {
|
|
1721
1901
|
const id = this.#nextId++;
|
|
1722
1902
|
this.steps.push({
|
|
1723
1903
|
type: "malloc",
|
|
1724
|
-
size:
|
|
1904
|
+
size: size$1,
|
|
1725
1905
|
output: id
|
|
1726
1906
|
});
|
|
1907
|
+
return id;
|
|
1908
|
+
}
|
|
1909
|
+
pushKernel(kernel, inputs) {
|
|
1910
|
+
const id = this.pushBuffer(kernel.bytes);
|
|
1727
1911
|
this.steps.push({
|
|
1728
1912
|
type: "execute",
|
|
1729
|
-
kernel,
|
|
1913
|
+
source: kernel,
|
|
1730
1914
|
inputs,
|
|
1731
1915
|
outputs: [id]
|
|
1732
1916
|
});
|
|
1733
1917
|
return id;
|
|
1734
1918
|
}
|
|
1919
|
+
pushRoutine(routine, inputs, outputs) {
|
|
1920
|
+
this.steps.push({
|
|
1921
|
+
type: "execute",
|
|
1922
|
+
source: routine,
|
|
1923
|
+
inputs,
|
|
1924
|
+
outputs
|
|
1925
|
+
});
|
|
1926
|
+
}
|
|
1735
1927
|
pushIncref(id) {
|
|
1736
1928
|
this.steps.push({
|
|
1737
1929
|
type: "incref",
|
|
@@ -1757,28 +1949,18 @@ var JitProgramBuilder = class {
|
|
|
1757
1949
|
}
|
|
1758
1950
|
};
|
|
1759
1951
|
const jitCompileCache = /* @__PURE__ */ new Map();
|
|
1760
|
-
function jitCompile(backend, jaxpr
|
|
1761
|
-
|
|
1762
|
-
for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
|
|
1763
|
-
const cacheKey = backend.type + require_backend.FpHash.hash(jaxpr, ...consts.map((c) => c.id));
|
|
1952
|
+
function jitCompile(backend, jaxpr) {
|
|
1953
|
+
const cacheKey = backend.type + "," + require_backend.FpHash.hash(jaxpr);
|
|
1764
1954
|
const cached = jitCompileCache.get(cacheKey);
|
|
1765
1955
|
if (cached) return cached;
|
|
1766
1956
|
if (require_backend.DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
|
|
1767
1957
|
jaxpr = jaxpr.flatten().simplify();
|
|
1768
|
-
const nargs = jaxpr.inBinders.length
|
|
1958
|
+
const nargs = jaxpr.inBinders.length;
|
|
1769
1959
|
const builder = new JitProgramBuilder(backend, nargs);
|
|
1770
1960
|
const blackNodes = splitGraphDataflow(backend, jaxpr);
|
|
1771
1961
|
const ctx = /* @__PURE__ */ new Map();
|
|
1772
|
-
for (let i = 0; i < consts.length; i++) {
|
|
1773
|
-
const v = jaxpr.inBinders[i];
|
|
1774
|
-
const slot = consts[i]._realizeSource();
|
|
1775
|
-
ctx.set(v, {
|
|
1776
|
-
type: "imm",
|
|
1777
|
-
arg: builder.pushConst(slot)
|
|
1778
|
-
});
|
|
1779
|
-
}
|
|
1780
1962
|
for (let i = 0; i < nargs; i++) {
|
|
1781
|
-
const v = jaxpr.inBinders[
|
|
1963
|
+
const v = jaxpr.inBinders[i];
|
|
1782
1964
|
ctx.set(v, {
|
|
1783
1965
|
type: "imm",
|
|
1784
1966
|
arg: i
|
|
@@ -1786,6 +1968,31 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1786
1968
|
}
|
|
1787
1969
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1788
1970
|
const eqn = jaxpr.eqns[i];
|
|
1971
|
+
if (routinePrimitives.has(eqn.primitive)) {
|
|
1972
|
+
const routine = new require_backend.Routine(routinePrimitives.get(eqn.primitive), {
|
|
1973
|
+
inputShapes: eqn.inputs.map((x) => x.aval.shape),
|
|
1974
|
+
inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
|
|
1975
|
+
outputShapes: eqn.outBinders.map((x) => x.aval.shape),
|
|
1976
|
+
outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
|
|
1977
|
+
}, eqn.params);
|
|
1978
|
+
const inputs = [];
|
|
1979
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1980
|
+
const jv = ctx.get(input);
|
|
1981
|
+
if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
|
|
1982
|
+
inputs.push(jv.arg);
|
|
1983
|
+
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1984
|
+
const outputs = [];
|
|
1985
|
+
for (const outVar of eqn.outBinders) {
|
|
1986
|
+
const outId = builder.pushBuffer(outVar.aval.size * require_backend.byteWidth(outVar.aval.dtype));
|
|
1987
|
+
outputs.push(outId);
|
|
1988
|
+
ctx.set(outVar, {
|
|
1989
|
+
type: "imm",
|
|
1990
|
+
arg: outId
|
|
1991
|
+
});
|
|
1992
|
+
}
|
|
1993
|
+
builder.pushRoutine(routine, inputs, outputs);
|
|
1994
|
+
continue;
|
|
1995
|
+
}
|
|
1789
1996
|
const inputExps = [];
|
|
1790
1997
|
const inputAvals = [];
|
|
1791
1998
|
const inputArgs = [];
|
|
@@ -1829,35 +2036,37 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1829
2036
|
let reduction;
|
|
1830
2037
|
if (inputReduction) {
|
|
1831
2038
|
const jv = inputReduction;
|
|
1832
|
-
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1833
|
-
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
2039
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp[0];
|
|
2040
|
+
exp$2 = [jv.exp.reindexGids(addArgs(jv.args))];
|
|
1834
2041
|
reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1835
2042
|
} else {
|
|
1836
2043
|
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1837
2044
|
exp$2 = ruleOutput.exp;
|
|
1838
2045
|
reduction = ruleOutput.reduction;
|
|
1839
2046
|
}
|
|
1840
|
-
|
|
1841
|
-
|
|
1842
|
-
|
|
1843
|
-
|
|
1844
|
-
|
|
1845
|
-
|
|
1846
|
-
|
|
1847
|
-
|
|
1848
|
-
|
|
2047
|
+
for (let i$1 = 0; i$1 < eqn.outBinders.length; i$1++) {
|
|
2048
|
+
const outVar = eqn.outBinders[i$1];
|
|
2049
|
+
if (blackNodes.has(outVar)) {
|
|
2050
|
+
const nargs$1 = inputArgs.length;
|
|
2051
|
+
const size$1 = outVar.aval.size;
|
|
2052
|
+
const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2[i$1], reduction);
|
|
2053
|
+
const outId = builder.pushKernel(kernel, inputArgs);
|
|
2054
|
+
ctx.set(outVar, {
|
|
2055
|
+
type: "imm",
|
|
2056
|
+
arg: outId
|
|
2057
|
+
});
|
|
2058
|
+
} else if (reduction) ctx.set(outVar, {
|
|
2059
|
+
type: "red",
|
|
2060
|
+
exp: exp$2[i$1],
|
|
2061
|
+
reduction,
|
|
2062
|
+
args: inputArgs
|
|
1849
2063
|
});
|
|
1850
|
-
|
|
1851
|
-
|
|
1852
|
-
|
|
1853
|
-
|
|
1854
|
-
|
|
1855
|
-
}
|
|
1856
|
-
else ctx.set(outVar, {
|
|
1857
|
-
type: "exp",
|
|
1858
|
-
exp: exp$2,
|
|
1859
|
-
args: inputArgs
|
|
1860
|
-
});
|
|
2064
|
+
else ctx.set(outVar, {
|
|
2065
|
+
type: "exp",
|
|
2066
|
+
exp: exp$2[i$1],
|
|
2067
|
+
args: inputArgs
|
|
2068
|
+
});
|
|
2069
|
+
}
|
|
1861
2070
|
}
|
|
1862
2071
|
const outputIds = [];
|
|
1863
2072
|
for (const out of jaxpr.outs) if (out instanceof Var) {
|
|
@@ -1865,7 +2074,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1865
2074
|
if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
|
|
1866
2075
|
outputIds.push(jitValue.arg);
|
|
1867
2076
|
} else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
|
|
1868
|
-
const outputNeedsRef = new Set(
|
|
2077
|
+
const outputNeedsRef = new Set(require_backend.range(nargs));
|
|
1869
2078
|
for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
|
|
1870
2079
|
else outputNeedsRef.add(outputId);
|
|
1871
2080
|
builder.insertFreeSteps(outputIds);
|
|
@@ -1898,17 +2107,22 @@ function broadcastedJit(fn, opts) {
|
|
|
1898
2107
|
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
|
|
1899
2108
|
return exp$2;
|
|
1900
2109
|
});
|
|
1901
|
-
return { exp: fn(exps, params) };
|
|
2110
|
+
return { exp: [fn(exps, params)] };
|
|
1902
2111
|
};
|
|
1903
2112
|
}
|
|
1904
2113
|
function unopJit(fn) {
|
|
1905
2114
|
return ([a], [_as], params) => {
|
|
1906
|
-
return { exp: fn(a, params) };
|
|
2115
|
+
return { exp: [fn(a, params)] };
|
|
1907
2116
|
};
|
|
1908
2117
|
}
|
|
1909
2118
|
function reshapeJit(fn) {
|
|
1910
2119
|
return ([a], [_as], params) => {
|
|
1911
|
-
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2120
|
+
return { exp: [reshapeViews(a, (st) => fn(st, params))] };
|
|
2121
|
+
};
|
|
2122
|
+
}
|
|
2123
|
+
function routineNoJit() {
|
|
2124
|
+
return () => {
|
|
2125
|
+
throw new Error("jit: rule is not implemented for routines");
|
|
1912
2126
|
};
|
|
1913
2127
|
}
|
|
1914
2128
|
const jitRules = {
|
|
@@ -1916,6 +2130,8 @@ const jitRules = {
|
|
|
1916
2130
|
[Primitive.Mul]: broadcastedJit(([a, b]) => require_backend.AluExp.mul(a, b)),
|
|
1917
2131
|
[Primitive.Idiv]: broadcastedJit(([a, b]) => require_backend.AluExp.idiv(a, b)),
|
|
1918
2132
|
[Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
|
|
2133
|
+
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
2134
|
+
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1919
2135
|
[Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
|
|
1920
2136
|
[Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
|
|
1921
2137
|
[Primitive.Floor]: unopJit(require_backend.AluExp.floor),
|
|
@@ -1923,17 +2139,6 @@ const jitRules = {
|
|
|
1923
2139
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1924
2140
|
[Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
|
|
1925
2141
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
|
|
1926
|
-
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1927
|
-
const mapping = (st) => {
|
|
1928
|
-
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
|
|
1929
|
-
};
|
|
1930
|
-
const k0 = reshapeViews(keys[0], mapping);
|
|
1931
|
-
const k1 = reshapeViews(keys[1], mapping);
|
|
1932
|
-
const c0 = require_backend.AluExp.u32(0);
|
|
1933
|
-
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1934
|
-
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1935
|
-
return { exp: exp$2 };
|
|
1936
|
-
},
|
|
1937
2142
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
1938
2143
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
1939
2144
|
[Primitive.Asin]: unopJit(require_backend.AluExp.asin),
|
|
@@ -1943,8 +2148,6 @@ const jitRules = {
|
|
|
1943
2148
|
[Primitive.Erf]: unopJit(require_backend.AluExp.erf),
|
|
1944
2149
|
[Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
|
|
1945
2150
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1946
|
-
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
1947
|
-
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1948
2151
|
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1949
2152
|
const keptAxes = [];
|
|
1950
2153
|
const shiftedAxes = [];
|
|
@@ -1960,7 +2163,7 @@ const jitRules = {
|
|
|
1960
2163
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
1961
2164
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
1962
2165
|
return {
|
|
1963
|
-
exp: a,
|
|
2166
|
+
exp: [a],
|
|
1964
2167
|
reduction
|
|
1965
2168
|
};
|
|
1966
2169
|
},
|
|
@@ -1971,13 +2174,13 @@ const jitRules = {
|
|
|
1971
2174
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1972
2175
|
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1973
2176
|
return {
|
|
1974
|
-
exp: a,
|
|
2177
|
+
exp: [a],
|
|
1975
2178
|
reduction
|
|
1976
2179
|
};
|
|
1977
2180
|
},
|
|
1978
2181
|
[Primitive.Dot]([a, b], [as, bs]) {
|
|
1979
2182
|
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
1980
|
-
const c = k1.exp;
|
|
2183
|
+
const [c] = k1.exp;
|
|
1981
2184
|
const cs = promoteAvals(as, bs);
|
|
1982
2185
|
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
1983
2186
|
op: require_backend.AluOp.Add,
|
|
@@ -1994,16 +2197,42 @@ const jitRules = {
|
|
|
1994
2197
|
},
|
|
1995
2198
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1996
2199
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1997
|
-
[Primitive.
|
|
1998
|
-
|
|
1999
|
-
|
|
2000
|
-
|
|
2001
|
-
const
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
2200
|
+
[Primitive.Concatenate](exps, avals, { axis }) {
|
|
2201
|
+
const ndim$2 = avals[0].ndim;
|
|
2202
|
+
const sizes = avals.map((x) => x.shape[axis]);
|
|
2203
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
2204
|
+
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
2205
|
+
let cum = 0;
|
|
2206
|
+
const src = [];
|
|
2207
|
+
for (let i = 0; i < exps.length; i++) {
|
|
2208
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
2209
|
+
src.push(reshapeViews(exps[i], (st) => st.pad(padding)));
|
|
2210
|
+
cum += sizes[i];
|
|
2211
|
+
}
|
|
2212
|
+
return { exp: [src.reduce(require_backend.AluExp.add)] };
|
|
2213
|
+
},
|
|
2214
|
+
[Primitive.Split]([a], [as], { axis, sizes }) {
|
|
2215
|
+
const exp$2 = [];
|
|
2216
|
+
let start = 0;
|
|
2217
|
+
for (const size$1 of sizes) {
|
|
2218
|
+
const slice = require_backend.range(as.ndim).map((d) => d === axis ? [start, start + size$1] : [0, as.shape[d]]);
|
|
2219
|
+
exp$2.push(reshapeViews(a, (st) => st.shrink(slice)));
|
|
2220
|
+
start += size$1;
|
|
2221
|
+
}
|
|
2222
|
+
return { exp: exp$2 };
|
|
2223
|
+
},
|
|
2224
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2225
|
+
const keyShape = keyShapes[0].shape;
|
|
2226
|
+
const mapping = (st) => {
|
|
2227
|
+
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(st.shape.length, shape$1.length));
|
|
2228
|
+
};
|
|
2229
|
+
const k0 = reshapeViews(keys[0], mapping);
|
|
2230
|
+
const k1 = reshapeViews(keys[1], mapping);
|
|
2231
|
+
const c0 = require_backend.AluExp.u32(0);
|
|
2232
|
+
const c1 = require_backend.AluExp.mod(require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx), require_backend.AluExp.u32(Math.max(require_backend.prod(shape$1.slice(keyShape.length)), 1)));
|
|
2233
|
+
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2234
|
+
return { exp: [exp$2] };
|
|
2235
|
+
},
|
|
2007
2236
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
2008
2237
|
const axisSet = new Set(axis);
|
|
2009
2238
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
@@ -2017,10 +2246,25 @@ const jitRules = {
|
|
|
2017
2246
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
2018
2247
|
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
2019
2248
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
2020
|
-
return { exp: x.substitute({ gidx: index }) };
|
|
2249
|
+
return { exp: [x.substitute({ gidx: index })] };
|
|
2021
2250
|
},
|
|
2022
|
-
[Primitive.
|
|
2023
|
-
|
|
2251
|
+
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2252
|
+
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
2253
|
+
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
2254
|
+
[Primitive.Flip]: reshapeJit((st, { axis }) => {
|
|
2255
|
+
const arg = require_backend.rep(st.shape.length, false);
|
|
2256
|
+
for (const ax of axis) arg[ax] = true;
|
|
2257
|
+
return st.flip(arg);
|
|
2258
|
+
}),
|
|
2259
|
+
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
2260
|
+
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
2261
|
+
[Primitive.Sort]: routineNoJit(),
|
|
2262
|
+
[Primitive.Argsort]: routineNoJit(),
|
|
2263
|
+
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2264
|
+
[Primitive.Cholesky]: routineNoJit(),
|
|
2265
|
+
[Primitive.LU]: routineNoJit(),
|
|
2266
|
+
[Primitive.Jit]() {
|
|
2267
|
+
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
2024
2268
|
}
|
|
2025
2269
|
};
|
|
2026
2270
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
@@ -2078,8 +2322,8 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2078
2322
|
case Primitive.Mul:
|
|
2079
2323
|
case Primitive.Idiv:
|
|
2080
2324
|
case Primitive.Mod:
|
|
2081
|
-
case Primitive.
|
|
2082
|
-
case Primitive.
|
|
2325
|
+
case Primitive.Min:
|
|
2326
|
+
case Primitive.Max: {
|
|
2083
2327
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2084
2328
|
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2085
2329
|
head = usages[0];
|
|
@@ -2099,11 +2343,11 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2099
2343
|
blackNodes.add(v);
|
|
2100
2344
|
p1NextBlack.set(v, v);
|
|
2101
2345
|
}
|
|
2102
|
-
const heterogeneousViewPrimitives = [Primitive.
|
|
2346
|
+
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2103
2347
|
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2104
2348
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2105
2349
|
const eqn = jaxpr.eqns[i];
|
|
2106
|
-
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2350
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2107
2351
|
for (const v of eqn.outBinders) {
|
|
2108
2352
|
blackNodes.add(v);
|
|
2109
2353
|
p1NextBlack.set(v, v);
|
|
@@ -2113,7 +2357,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2113
2357
|
const reach = /* @__PURE__ */ new Set();
|
|
2114
2358
|
let needsCleanOutput = false;
|
|
2115
2359
|
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2116
|
-
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
|
|
2360
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
|
|
2117
2361
|
needsCleanOutput = true;
|
|
2118
2362
|
break outer;
|
|
2119
2363
|
}
|
|
@@ -2137,7 +2381,6 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2137
2381
|
while (p2idx < jaxpr.eqns.length) {
|
|
2138
2382
|
const eqn = jaxpr.eqns[p2idx++];
|
|
2139
2383
|
const deps = [];
|
|
2140
|
-
if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
|
|
2141
2384
|
for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
|
|
2142
2385
|
else deps.push(p2Deps.get(input));
|
|
2143
2386
|
else deps.push(/* @__PURE__ */ new Set());
|
|
@@ -2160,7 +2403,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2160
2403
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2161
2404
|
const assocVar = eqn.inputs[assocInput];
|
|
2162
2405
|
p2idx = varToDefn.get(assocVar);
|
|
2163
|
-
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2406
|
+
for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
|
|
2164
2407
|
} else {
|
|
2165
2408
|
const s = new Set(depCounter.keys());
|
|
2166
2409
|
for (const out of eqn.outBinders) p2Deps.set(out, s);
|
|
@@ -2186,9 +2429,9 @@ var PendingExecute = class {
|
|
|
2186
2429
|
submitted = false;
|
|
2187
2430
|
#promise = null;
|
|
2188
2431
|
#rc = 1;
|
|
2189
|
-
constructor(backend,
|
|
2432
|
+
constructor(backend, source, inputs, outputs) {
|
|
2190
2433
|
this.backend = backend;
|
|
2191
|
-
this.
|
|
2434
|
+
this.source = source;
|
|
2192
2435
|
this.inputs = inputs;
|
|
2193
2436
|
this.outputs = outputs;
|
|
2194
2437
|
for (const slot of inputs) this.backend.incRef(slot);
|
|
@@ -2209,13 +2452,15 @@ var PendingExecute = class {
|
|
|
2209
2452
|
return;
|
|
2210
2453
|
}
|
|
2211
2454
|
this.#promise = (async () => {
|
|
2212
|
-
this.prepared = await this.backend.
|
|
2455
|
+
if (this.source instanceof require_backend.Kernel) this.prepared = await this.backend.prepareKernel(this.source);
|
|
2456
|
+
else this.prepared = await this.backend.prepareRoutine(this.source);
|
|
2213
2457
|
})();
|
|
2214
2458
|
await this.#promise;
|
|
2215
2459
|
}
|
|
2216
2460
|
prepareSync() {
|
|
2217
2461
|
if (this.prepared) return;
|
|
2218
|
-
this.prepared = this.backend.
|
|
2462
|
+
if (this.source instanceof require_backend.Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
|
|
2463
|
+
else this.prepared = this.backend.prepareRoutineSync(this.source);
|
|
2219
2464
|
}
|
|
2220
2465
|
submit() {
|
|
2221
2466
|
if (this.submitted) return;
|
|
@@ -2238,8 +2483,6 @@ var PendingExecute = class {
|
|
|
2238
2483
|
* "Array" type by name.
|
|
2239
2484
|
*/
|
|
2240
2485
|
var Array$1 = class Array$1 extends Tracer {
|
|
2241
|
-
static #nextId = 1001;
|
|
2242
|
-
id;
|
|
2243
2486
|
#dtype;
|
|
2244
2487
|
#weakType;
|
|
2245
2488
|
#source;
|
|
@@ -2256,7 +2499,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2256
2499
|
*/
|
|
2257
2500
|
constructor(args) {
|
|
2258
2501
|
super(baseArrayTrace);
|
|
2259
|
-
this.id = Array$1.#nextId++;
|
|
2260
2502
|
this.#dtype = args.dtype;
|
|
2261
2503
|
this.#weakType = args.weakType;
|
|
2262
2504
|
this.#source = args.source;
|
|
@@ -2299,6 +2541,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2299
2541
|
this.#rc++;
|
|
2300
2542
|
return this;
|
|
2301
2543
|
}
|
|
2544
|
+
/** Get the current reference count (for debugging memory management). */
|
|
2545
|
+
get refCount() {
|
|
2546
|
+
return this.#rc;
|
|
2547
|
+
}
|
|
2302
2548
|
dispose() {
|
|
2303
2549
|
this.#check();
|
|
2304
2550
|
if (--this.#rc === 0) {
|
|
@@ -2456,7 +2702,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2456
2702
|
} else if (castDtype === void 0) {
|
|
2457
2703
|
castDtype = arrays[i].#dtype;
|
|
2458
2704
|
castWeakType = arrays[i].#weakType;
|
|
2459
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType),
|
|
2705
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), arrays[i].aval.scalar()));
|
|
2460
2706
|
const weakType = castWeakType && !strongTypeOutput;
|
|
2461
2707
|
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
2462
2708
|
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
@@ -2565,6 +2811,27 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2565
2811
|
pending
|
|
2566
2812
|
});
|
|
2567
2813
|
}
|
|
2814
|
+
/** Apply an operation with custom lowering to this array. */
|
|
2815
|
+
static #routine(routine, arrays, outputWeakType) {
|
|
2816
|
+
const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
|
|
2817
|
+
for (const ar of arrays) ar.#realize();
|
|
2818
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2819
|
+
const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
|
|
2820
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2821
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2822
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2823
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2824
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2825
|
+
return outputs.map((output, i) => new Array$1({
|
|
2826
|
+
source: output,
|
|
2827
|
+
st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
|
|
2828
|
+
dtype: routine.type.outputDtypes[i],
|
|
2829
|
+
weakType: outputWeakType[i],
|
|
2830
|
+
backend,
|
|
2831
|
+
committed,
|
|
2832
|
+
pending
|
|
2833
|
+
}));
|
|
2834
|
+
}
|
|
2568
2835
|
/**
|
|
2569
2836
|
* Normalizes this array into one backed by a `Slot`.
|
|
2570
2837
|
*
|
|
@@ -2725,6 +2992,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2725
2992
|
[Primitive.Mod]([x, y]) {
|
|
2726
2993
|
return [x.#binary(require_backend.AluOp.Mod, y)];
|
|
2727
2994
|
},
|
|
2995
|
+
[Primitive.Min]([x, y]) {
|
|
2996
|
+
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
2997
|
+
},
|
|
2998
|
+
[Primitive.Max]([x, y]) {
|
|
2999
|
+
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
3000
|
+
},
|
|
2728
3001
|
[Primitive.Neg]([x]) {
|
|
2729
3002
|
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
2730
3003
|
},
|
|
@@ -2761,25 +3034,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2761
3034
|
return [y];
|
|
2762
3035
|
}
|
|
2763
3036
|
},
|
|
2764
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2765
|
-
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2766
|
-
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2767
|
-
const c0 = zeros(shape$1, {
|
|
2768
|
-
dtype: require_backend.DType.Uint32,
|
|
2769
|
-
device: k0.device
|
|
2770
|
-
});
|
|
2771
|
-
const c1 = arange(0, require_backend.prod(shape$1), 1, {
|
|
2772
|
-
dtype: require_backend.DType.Uint32,
|
|
2773
|
-
device: k0.device
|
|
2774
|
-
}).reshape(shape$1);
|
|
2775
|
-
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2776
|
-
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2777
|
-
k0,
|
|
2778
|
-
k1,
|
|
2779
|
-
c0,
|
|
2780
|
-
c1
|
|
2781
|
-
])];
|
|
2782
|
-
},
|
|
2783
3037
|
[Primitive.Sin]([x]) {
|
|
2784
3038
|
return [x.#unary(require_backend.AluOp.Sin)];
|
|
2785
3039
|
},
|
|
@@ -2807,12 +3061,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2807
3061
|
[Primitive.Sqrt]([x]) {
|
|
2808
3062
|
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
2809
3063
|
},
|
|
2810
|
-
[Primitive.Min]([x, y]) {
|
|
2811
|
-
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
2812
|
-
},
|
|
2813
|
-
[Primitive.Max]([x, y]) {
|
|
2814
|
-
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
2815
|
-
},
|
|
2816
3064
|
[Primitive.Reduce]([x], { op, axis }) {
|
|
2817
3065
|
if (axis.length === 0) return [x];
|
|
2818
3066
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
@@ -2847,6 +3095,55 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2847
3095
|
y
|
|
2848
3096
|
], { dtypeOverride: [require_backend.DType.Bool] })];
|
|
2849
3097
|
},
|
|
3098
|
+
[Primitive.Concatenate](xs, { axis }) {
|
|
3099
|
+
const ndim$2 = xs[0].ndim;
|
|
3100
|
+
const sizes = xs.map((x) => x.shape[axis]);
|
|
3101
|
+
const finalSize = sizes.reduce((a, b) => a + b, 0);
|
|
3102
|
+
const makePadAxis = (start, end) => require_backend.range(ndim$2).map((i) => i === axis ? [start, end] : [0, 0]);
|
|
3103
|
+
let cum = 0;
|
|
3104
|
+
const xsPadded = [];
|
|
3105
|
+
for (let i = 0; i < xs.length; i++) {
|
|
3106
|
+
const padding = makePadAxis(cum, finalSize - cum - sizes[i]);
|
|
3107
|
+
xsPadded.push(xs[i].#reshape(xs[i].#st.pad(padding)));
|
|
3108
|
+
cum += sizes[i];
|
|
3109
|
+
}
|
|
3110
|
+
const custom = (exps) => exps.reduce(require_backend.AluExp.add);
|
|
3111
|
+
return [Array$1.#naryCustom("concatenate", custom, xsPadded)];
|
|
3112
|
+
},
|
|
3113
|
+
[Primitive.Split]([x], { axis, sizes }) {
|
|
3114
|
+
const outputs = [];
|
|
3115
|
+
for (let i = 0, start = 0; i < sizes.length; i++) {
|
|
3116
|
+
const slice = require_backend.range(x.ndim).map((d) => d === axis ? [start, start + sizes[i]] : [0, x.shape[d]]);
|
|
3117
|
+
outputs.push(x.ref.#reshape(x.#st.shrink(slice)));
|
|
3118
|
+
start += sizes[i];
|
|
3119
|
+
}
|
|
3120
|
+
x.dispose();
|
|
3121
|
+
return outputs;
|
|
3122
|
+
},
|
|
3123
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
3124
|
+
const keyShape = k0.shape;
|
|
3125
|
+
const genShape = shape$1.slice(keyShape.length);
|
|
3126
|
+
const c0 = zeros(genShape, {
|
|
3127
|
+
dtype: require_backend.DType.Uint32,
|
|
3128
|
+
device: k0.device
|
|
3129
|
+
});
|
|
3130
|
+
const c1 = arange(0, require_backend.prod(genShape), 1, {
|
|
3131
|
+
dtype: require_backend.DType.Uint32,
|
|
3132
|
+
device: k0.device
|
|
3133
|
+
}).reshape(genShape);
|
|
3134
|
+
k0 = k0.#reshape(k0.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
|
|
3135
|
+
k1 = k1.#reshape(k1.#st.reshape(keyShape.concat(require_backend.rep(genShape.length, 1))));
|
|
3136
|
+
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
3137
|
+
return [Array$1.#naryCustom("random_bits", custom, [
|
|
3138
|
+
k0,
|
|
3139
|
+
k1,
|
|
3140
|
+
c0,
|
|
3141
|
+
c1
|
|
3142
|
+
])];
|
|
3143
|
+
},
|
|
3144
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
3145
|
+
return [x.#gather(indices, axis, outDim)];
|
|
3146
|
+
},
|
|
2850
3147
|
[Primitive.Transpose]([x], { perm }) {
|
|
2851
3148
|
return [x.#transpose(perm)];
|
|
2852
3149
|
},
|
|
@@ -2867,17 +3164,71 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2867
3164
|
[Primitive.Pad]([x], { width }) {
|
|
2868
3165
|
return [x.#reshape(x.#st.pad(width))];
|
|
2869
3166
|
},
|
|
2870
|
-
[Primitive.
|
|
2871
|
-
|
|
3167
|
+
[Primitive.Sort]([x]) {
|
|
3168
|
+
const routine = new require_backend.Routine(require_backend.Routines.Sort, {
|
|
3169
|
+
inputShapes: [x.shape],
|
|
3170
|
+
inputDtypes: [x.dtype],
|
|
3171
|
+
outputShapes: [x.shape],
|
|
3172
|
+
outputDtypes: [x.dtype]
|
|
3173
|
+
});
|
|
3174
|
+
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3175
|
+
},
|
|
3176
|
+
[Primitive.Argsort]([x]) {
|
|
3177
|
+
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3178
|
+
inputShapes: [x.shape],
|
|
3179
|
+
inputDtypes: [x.dtype],
|
|
3180
|
+
outputShapes: [x.shape, x.shape],
|
|
3181
|
+
outputDtypes: [x.dtype, require_backend.DType.Int32]
|
|
3182
|
+
});
|
|
3183
|
+
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3184
|
+
},
|
|
3185
|
+
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3186
|
+
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3187
|
+
inputShapes: [a.shape, b.shape],
|
|
3188
|
+
inputDtypes: [a.dtype, b.dtype],
|
|
3189
|
+
outputShapes: [b.shape],
|
|
3190
|
+
outputDtypes: [b.dtype]
|
|
3191
|
+
}, { unitDiagonal });
|
|
3192
|
+
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
2872
3193
|
},
|
|
2873
|
-
[Primitive.
|
|
2874
|
-
|
|
2875
|
-
|
|
3194
|
+
[Primitive.Cholesky]([a]) {
|
|
3195
|
+
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3196
|
+
inputShapes: [a.shape],
|
|
3197
|
+
inputDtypes: [a.dtype],
|
|
3198
|
+
outputShapes: [a.shape],
|
|
3199
|
+
outputDtypes: [a.dtype]
|
|
3200
|
+
});
|
|
3201
|
+
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3202
|
+
},
|
|
3203
|
+
[Primitive.LU]([a]) {
|
|
3204
|
+
const batch = a.shape.slice(0, -2);
|
|
3205
|
+
const [m, n] = a.shape.slice(-2);
|
|
3206
|
+
const routine = new require_backend.Routine(require_backend.Routines.LU, {
|
|
3207
|
+
inputShapes: [a.shape],
|
|
3208
|
+
inputDtypes: [a.dtype],
|
|
3209
|
+
outputShapes: [
|
|
3210
|
+
a.shape,
|
|
3211
|
+
[...batch, Math.min(m, n)],
|
|
3212
|
+
[...batch, m]
|
|
3213
|
+
],
|
|
3214
|
+
outputDtypes: [
|
|
3215
|
+
a.dtype,
|
|
3216
|
+
require_backend.DType.Int32,
|
|
3217
|
+
require_backend.DType.Int32
|
|
3218
|
+
]
|
|
3219
|
+
});
|
|
3220
|
+
return Array$1.#routine(routine, [a], [
|
|
3221
|
+
a.#weakType,
|
|
3222
|
+
false,
|
|
3223
|
+
false
|
|
3224
|
+
]);
|
|
3225
|
+
},
|
|
3226
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
3227
|
+
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3228
|
+
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
2876
3229
|
args = args.map((ar) => ar._putSync(backend));
|
|
2877
|
-
const
|
|
2878
|
-
const
|
|
2879
|
-
const jp = jitCompile(backend, jaxpr, consts);
|
|
2880
|
-
const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
|
|
3230
|
+
const jp = jitCompile(backend, jaxpr);
|
|
3231
|
+
const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
|
|
2881
3232
|
for (const exe of pending) exe.updateRc(+outputs.length - 1);
|
|
2882
3233
|
const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
|
|
2883
3234
|
for (const exe of prevPending) exe.updateRc(+outputs.length);
|
|
@@ -2977,7 +3328,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2977
3328
|
device
|
|
2978
3329
|
});
|
|
2979
3330
|
} else {
|
|
2980
|
-
const weakType = dtype == void 0;
|
|
3331
|
+
const weakType = dtype == void 0 && shape$1.length === 0;
|
|
2981
3332
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2982
3333
|
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
2983
3334
|
return arrayFromData(data, shape$1, {
|
|
@@ -3091,7 +3442,7 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
3091
3442
|
}
|
|
3092
3443
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
3093
3444
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
3094
|
-
let weakType = dtype == void 0;
|
|
3445
|
+
let weakType = dtype == void 0 && shape$1.length === 0;
|
|
3095
3446
|
if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
|
|
3096
3447
|
else if (typeof fillValue === "boolean") {
|
|
3097
3448
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
@@ -3176,6 +3527,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
3176
3527
|
});
|
|
3177
3528
|
}
|
|
3178
3529
|
/**
 * Build an `n` x `m` array with ones on and below a chosen diagonal and zeros
 * everywhere else.
 *
 * `k` picks the diagonal: `k = 0` is the main diagonal, `k < 0` lies below it
 * and `k > 0` lies above it. `m` defaults to `n` and `dtype` to Float32.
 *
 * Throws if `n` or `m` is not a non-negative integer, or `k` is not an integer.
 */
function tri(n, m, k = 0, { dtype, device } = {}) {
  const numCols = m ?? n;
  const outDtype = dtype ?? require_backend.DType.Float32;
  const isCount = (v) => Number.isInteger(v) && v >= 0;
  if (!isCount(n)) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
  if (!isCount(numCols)) throw new Error(`tri: m must be a non-negative integer, got ${numCols}`);
  if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
  const int32Range = (start, stop) => arange(start, stop, 1, {
    dtype: require_backend.DType.Int32,
    device,
  });
  // Row i carries the value i + k; position (i, j) is set when i + k >= j.
  const rowIds = int32Range(k, n + k);
  const colIds = int32Range(0, numCols);
  return rowIds.reshape([n, 1]).greaterEqual(colIds).astype(outDtype);
}
|
|
3552
|
+
/**
 * Return the lower triangle of an array: entries selected by the boolean mask
 * `tri(n, m, k)` over the last two dims are kept, all others become zero.
 * Throws if the input has fewer than two dimensions.
 */
function tril(a, k = 0) {
  const rank = ndim$1(a);
  if (rank < 2) throw new Error(`tril: input array must be at least 2D, got ${rank}D`);
  const arr = fudgeArray(a);
  const [rows, cols] = arr.shape.slice(-2);
  const keepMask = tri(rows, cols, k, { dtype: require_backend.DType.Bool });
  // `arr` is used twice (kept values and zerosLike), hence the extra `.ref`.
  return where$1(keepMask, arr.ref, zerosLike$1(arr));
}
|
|
3559
|
+
/**
 * Return the upper triangle of an array. Must be of dimension >= 2.
 *
 * Entries selected by the boolean mask `tri(n, m, k - 1)` over the last two
 * dims (those strictly below the `k`-th diagonal) are replaced with zeros; all
 * other entries are kept.
 *
 * @param a - Input with at least two dimensions.
 * @param k - Diagonal offset; defaults to 0 (the main diagonal).
 * @returns The result of `where(mask, zeros, a)`.
 * @throws Error if `a` has fewer than two dimensions.
 */
function triu(a, k = 0) {
  const rank = ndim$1(a);
  // The message must name this function, not `tril`.
  if (rank < 2) throw new Error(`triu: input array must be at least 2D, got ${rank}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  const belowDiagonal = tri(n, m, k - 1, { dtype: require_backend.DType.Bool });
  // `a` is used twice (zerosLike and kept values), hence the extra `.ref`.
  return where$1(belowDiagonal, zerosLike$1(a.ref), a);
}
|
|
3566
|
+
/**
|
|
3179
3567
|
* Return evenly spaced numbers over a specified interval.
|
|
3180
3568
|
*
|
|
3181
3569
|
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
@@ -3212,6 +3600,27 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
3212
3600
|
committed: device != void 0
|
|
3213
3601
|
});
|
|
3214
3602
|
}
|
|
3603
|
+
/**
 * Return numbers spaced evenly on a log scale.
 *
 * The exponents are produced by `linspace(start, stop, num, endpoint)` and the
 * result is `base` raised to those exponents, so the sequence runs from
 * `base ** start` towards `base ** stop`.
 *
 * @param start - Exponent of the first sample (`base ** start`).
 * @param stop - Exponent of the final sample (`base ** stop`), unless `endpoint` is false.
 * @param num - Number of samples to generate. Default is 50.
 * @param endpoint - If true, `base ** stop` is the last sample; otherwise it is excluded. Default is true.
 * @param base - The base of the log space. Default is 10.
 * @returns Array of values evenly spaced on a log scale.
 */
function logspace(start, stop, num = 50, endpoint = true, base = 10, { dtype, device } = {}) {
  const exponents = linspace(start, stop, num, endpoint, { dtype, device });
  // base ** y is computed as exp(y * ln(base)).
  return exp$1(mul(exponents, Math.log(base)));
}
|
|
3215
3624
|
function aluCompare(a, b, op) {
|
|
3216
3625
|
switch (op) {
|
|
3217
3626
|
case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
|
|
@@ -3222,385 +3631,211 @@ function aluCompare(a, b, op) {
|
|
|
3222
3631
|
}
|
|
3223
3632
|
|
|
3224
3633
|
//#endregion
|
|
3225
|
-
//#region src/frontend/
|
|
3634
|
+
//#region src/frontend/vmap.ts
|
|
3226
3635
|
var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3227
|
-
|
|
3228
|
-
|
|
3636
|
+
/**
 * Return a new `ShapedArray` equal to `aval` but with dimension `batchDim`
 * removed from its shape; dtype and weak-type flag are carried over.
 */
function mappedAval(batchDim, aval) {
  const unbatchedShape = Array.from(aval.shape);
  unbatchedShape.splice(batchDim, 1);
  const { dtype, weakType } = aval;
  return new ShapedArray(unbatchedShape, dtype, weakType);
}
|
|
3641
|
+
/**
 * Move axis `src` of `x` to position `dst`, keeping the other axes in order.
 * Both axes are validated/normalized with `checkAxis`; when they coincide the
 * array is returned without transposing.
 */
function moveaxis(x, src, dst) {
  const arr = pureArray(x);
  const from = require_backend.checkAxis(src, arr.ndim);
  const to = require_backend.checkAxis(dst, arr.ndim);
  if (from === to) return arr;
  // Drop the source axis from the identity permutation, then re-insert it at `to`.
  const perm = require_backend.range(arr.ndim).toSpliced(from, 1).toSpliced(to, 0, from);
  return transpose$1(arr, perm);
}
|
|
3652
|
+
/**
 * Place the batch axis of `x` at position `dst`.
 *
 * - `src === null`: `x` has no batch axis, so it is broadcast to a shape with a
 *   new axis of size `axisSize` inserted at `dst`.
 * - `src === dst`: `x` is returned unchanged.
 * - otherwise the existing axis is moved with `moveaxis`.
 */
function moveBatchAxis(axisSize, src, dst, x) {
  if (src !== null) return src === dst ? x : moveaxis(x, src, dst);
  const batchedShape = [...x.shape];
  batchedShape.splice(dst, 0, axisSize);
  return broadcast(x, batchedShape, [dst]);
}
|
|
3660
|
+
/**
 * Tracer that pairs an underlying value `val` with the index of its batch
 * dimension (`batchDim`), or `null` when the value has no batch dimension.
 */
var BatchTracer = class extends Tracer {
  constructor(trace$1, val, batchDim) {
    super(trace$1);
    this.val = val;
    this.batchDim = batchDim;
  }

  /** Abstract value as seen inside the mapped function: the batch dim is removed. */
  get aval() {
    const inner = this.val.aval;
    return this.batchDim === null ? inner : mappedAval(this.batchDim, inner);
  }

  toString() {
    return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
  }

  /** Touch `val.ref` (evaluated only for its effect on `val`) and return this tracer. */
  get ref() {
    this.val.ref;
    return this;
  }

  /** Dispose the wrapped value. */
  dispose() {
    this.val.dispose();
  }

  /** Unbatched tracers lower to their wrapped value; batched ones stay as they are. */
  fullLower() {
    return this.batchDim === null ? this.val.fullLower() : this;
  }
};
|
|
3248
|
-
var
|
|
3685
|
+
/**
 * Trace that evaluates primitives on `BatchTracer`s by looking up a batching
 * rule in `vmapRules`.
 */
var BatchTrace = class extends Trace {
  /** Convert a raw value to an array and lift it as an unbatched tracer. */
  pure(val) {
    return this.lift(pureArray(val));
  }

  /** Wrap `val` as a tracer with no batch dimension. */
  lift(val) {
    return new BatchTracer(this, val, null);
  }

  /**
   * Apply `primitive` to batched tracers.
   *
   * A missing rule is an error even when no input is batched. When no input is
   * batched the primitive is bound directly and every output is unbatched;
   * otherwise the rule supplies the output values and their batch dims.
   */
  processPrimitive(primitive, tracers, params) {
    const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
    const rule = vmapRules[primitive];
    if (rule === undefined) throw new Error(`No vmap rule for: ${primitive}`);

    const nothingBatched = bdimsIn.every((d) => d === null);
    if (nothingBatched) {
      const plainOuts = bind(primitive, valsIn, params);
      return plainOuts.map((out) => new BatchTracer(this, out, null));
    }

    const [valOuts, bdimOuts] = rule(this.axisSize, valsIn, bdimsIn, params);
    if (valOuts.length !== bdimOuts.length) throw new Error(`vmap rule for ${primitive} returned mismatched lengths: ${valOuts.length} vs ${bdimOuts.length}`);
    return require_backend.zip(valOuts, bdimOuts).map(([out, bdim]) => new BatchTracer(this, out, bdim));
  }

  /** Size of the mapped axis, stored as the main trace's `globalData`. */
  get axisSize() {
    return this.main.globalData;
  }
};
|
|
3263
|
-
/**
|
|
3264
|
-
|
|
3265
|
-
|
|
3266
|
-
|
|
3267
|
-
|
|
3268
|
-
|
|
3708
|
+
/**
 * Build a batching rule for a primitive with built-in broadcasting.
 *
 * Fast path: if every batched argument has its batch axis at the same offset
 * from the end, and every unbatched argument has too few dims to reach that
 * axis, the primitive is bound on the arguments as they are. Otherwise each
 * batched argument gets its batch axis moved to the front and is padded with
 * size-1 dims after it up to the common rank, and the result is batched at 0.
 *
 * Assumes at least one entry of `dims` is non-null (the caller handles the
 * all-unbatched case).
 *
 * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
 */
function broadcastBatcher(prim) {
  return (axisSize, args, dims, params) => {
    if (args.length === 0) throw new Error("Empty list in broadcastBatcher");

    // Rank each argument would have once it carries a batch axis.
    const batchedRanks = args.map((arg, i) => ndim$1(arg) + (dims[i] === null ? 1 : 0));
    const outRank = Math.max(...batchedRanks);

    // Batch-axis position of the first batched argument, counted from the end (negative).
    const anchor = dims.findIndex((d) => d !== null);
    const relBdim = dims[anchor] - args[anchor].ndim;

    const alreadyAligned = require_backend.zip(args, dims).every(([arg, d]) => {
      if (d === null) return ndim$1(arg) < -relBdim;
      return d - arg.ndim === relBdim;
    });
    if (alreadyAligned) return [[bind1(prim, args, params)], [outRank + relBdim]];

    const frontBatched = args.map((arg, i) => {
      const d = dims[i];
      if (d === null) return arg;
      const moved = moveBatchAxis(axisSize, d, 0, arg);
      if (moved.ndim >= outRank) return moved;
      return moved.reshape([
        moved.shape[0],
        ...require_backend.rep(outRank - moved.ndim, 1),
        ...moved.shape.slice(1),
      ]);
    });
    return [[bind1(prim, frontBatched, params)], [0]];
  };
}
|
|
3271
|
-
|
|
3272
|
-
|
|
3273
|
-
|
|
3274
|
-
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3275
|
-
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3276
|
-
return [[primal], [tangent]];
|
|
3733
|
+
/**
 * Build a batching rule for a single-input primitive: the primitive is bound on
 * the operand unchanged and the output keeps the input's batch dim.
 */
function unopBatcher(prim) {
  return (axisSize, [operand], [bdim], params) => [[bind1(prim, [operand], params)], [bdim]];
}
|
|
3279
|
-
|
|
3280
|
-
|
|
3281
|
-
|
|
3282
|
-
|
|
3283
|
-
|
|
3284
|
-
return [
|
|
3738
|
+
/**
 * Build a batching rule for a single-input primitive that operates on the last
 * `inputDims` dims of its operand and yields `numOutputs` outputs.
 *
 * If the batch axis already lies outside those trailing dims the operand is
 * used as is and every output keeps that batch dim; otherwise the batch axis is
 * moved to the front and every output is batched at 0.
 */
function lastDimsBatcher(prim, inputDims, numOutputs = 1) {
  return (axisSize, [x], [xBdim], params) => {
    require_backend.assertNonNull(xBdim);
    const batchOutsideCore = xBdim < x.ndim - inputDims;
    const operand = batchOutsideCore ? x : moveBatchAxis(axisSize, xBdim, 0, x);
    const outBdim = batchOutsideCore ? xBdim : 0;
    return [bind(prim, [operand], params), require_backend.rep(numOutputs, outBdim)];
  };
}
|
|
3287
|
-
const
|
|
3288
|
-
[Primitive.Add]:
|
|
3289
|
-
[Primitive.Mul]:
|
|
3290
|
-
[Primitive.Idiv]:
|
|
3291
|
-
[Primitive.Mod](
|
|
3292
|
-
|
|
3293
|
-
|
|
3294
|
-
|
|
3295
|
-
|
|
3296
|
-
|
|
3297
|
-
|
|
3298
|
-
|
|
3299
|
-
|
|
3300
|
-
[Primitive.
|
|
3301
|
-
[Primitive.
|
|
3302
|
-
|
|
3303
|
-
|
|
3304
|
-
|
|
3305
|
-
[Primitive.
|
|
3306
|
-
[Primitive.
|
|
3307
|
-
[Primitive.
|
|
3308
|
-
[Primitive.
|
|
3309
|
-
|
|
3310
|
-
|
|
3311
|
-
|
|
3312
|
-
|
|
3313
|
-
|
|
3314
|
-
|
|
3315
|
-
},
|
|
3316
|
-
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3317
|
-
if (x.dtype === dtype) return [[x], [dx]];
|
|
3318
|
-
dx.dispose();
|
|
3319
|
-
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3320
|
-
},
|
|
3321
|
-
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3322
|
-
[Primitive.Sin]([x], [dx]) {
|
|
3323
|
-
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3746
|
+
const vmapRules = {
|
|
3747
|
+
[Primitive.Add]: broadcastBatcher(Primitive.Add),
|
|
3748
|
+
[Primitive.Mul]: broadcastBatcher(Primitive.Mul),
|
|
3749
|
+
[Primitive.Idiv]: broadcastBatcher(Primitive.Idiv),
|
|
3750
|
+
[Primitive.Mod]: broadcastBatcher(Primitive.Mod),
|
|
3751
|
+
[Primitive.Min]: broadcastBatcher(Primitive.Min),
|
|
3752
|
+
[Primitive.Max]: broadcastBatcher(Primitive.Max),
|
|
3753
|
+
[Primitive.Neg]: unopBatcher(Primitive.Neg),
|
|
3754
|
+
[Primitive.Reciprocal]: unopBatcher(Primitive.Reciprocal),
|
|
3755
|
+
[Primitive.Floor]: unopBatcher(Primitive.Floor),
|
|
3756
|
+
[Primitive.Ceil]: unopBatcher(Primitive.Ceil),
|
|
3757
|
+
[Primitive.StopGradient]: unopBatcher(Primitive.StopGradient),
|
|
3758
|
+
[Primitive.Cast]: unopBatcher(Primitive.Cast),
|
|
3759
|
+
[Primitive.Bitcast]: unopBatcher(Primitive.Bitcast),
|
|
3760
|
+
[Primitive.Sin]: unopBatcher(Primitive.Sin),
|
|
3761
|
+
[Primitive.Cos]: unopBatcher(Primitive.Cos),
|
|
3762
|
+
[Primitive.Asin]: unopBatcher(Primitive.Asin),
|
|
3763
|
+
[Primitive.Atan]: unopBatcher(Primitive.Atan),
|
|
3764
|
+
[Primitive.Exp]: unopBatcher(Primitive.Exp),
|
|
3765
|
+
[Primitive.Log]: unopBatcher(Primitive.Log),
|
|
3766
|
+
[Primitive.Erf]: unopBatcher(Primitive.Erf),
|
|
3767
|
+
[Primitive.Erfc]: unopBatcher(Primitive.Erfc),
|
|
3768
|
+
[Primitive.Sqrt]: unopBatcher(Primitive.Sqrt),
|
|
3769
|
+
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3770
|
+
require_backend.assertNonNull(xBdim);
|
|
3771
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3772
|
+
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3773
|
+
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3324
3774
|
},
|
|
3325
|
-
[Primitive.
|
|
3326
|
-
|
|
3775
|
+
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3776
|
+
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3777
|
+
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3778
|
+
const z = dot$2(x, y);
|
|
3779
|
+
return [[z], [z.ndim - 1]];
|
|
3327
3780
|
},
|
|
3328
|
-
[Primitive.
|
|
3329
|
-
|
|
3330
|
-
|
|
3781
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3782
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3783
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3784
|
+
const z = conv$1(x, y, {
|
|
3785
|
+
...params,
|
|
3786
|
+
vmapDims: params.vmapDims + 1
|
|
3787
|
+
});
|
|
3788
|
+
return [[z], [0]];
|
|
3331
3789
|
},
|
|
3332
|
-
[Primitive.
|
|
3333
|
-
|
|
3334
|
-
|
|
3790
|
+
[Primitive.Compare]: broadcastBatcher(Primitive.Compare),
|
|
3791
|
+
[Primitive.Where]: broadcastBatcher(Primitive.Where),
|
|
3792
|
+
[Primitive.Concatenate](axisSize, xs, xBdims, { axis }) {
|
|
3793
|
+
const minBdim = Math.min(...xBdims.filter((d) => d !== null));
|
|
3794
|
+
xs = xs.map((x, i) => moveBatchAxis(axisSize, xBdims[i], minBdim, x));
|
|
3795
|
+
const newAxis = axis + (minBdim <= axis ? 1 : 0);
|
|
3796
|
+
return [[concatenate$1(xs, newAxis)], [minBdim]];
|
|
3335
3797
|
},
|
|
3336
|
-
[Primitive.
|
|
3337
|
-
|
|
3338
|
-
|
|
3798
|
+
[Primitive.Split](axisSize, [x], [xBdim], { axis, sizes }) {
|
|
3799
|
+
require_backend.assertNonNull(xBdim);
|
|
3800
|
+
const newAxis = axis + (xBdim <= axis ? 1 : 0);
|
|
3801
|
+
const outs = split$2(x, newAxis, sizes);
|
|
3802
|
+
return [outs, require_backend.rep(outs.length, xBdim)];
|
|
3339
3803
|
},
|
|
3340
|
-
[Primitive.
|
|
3341
|
-
|
|
3804
|
+
[Primitive.RandomBits](axisSize, [k0, k1], [bdim0, bdim1], { shape: shape$1, mode }) {
|
|
3805
|
+
k0 = moveBatchAxis(axisSize, bdim0, 0, k0);
|
|
3806
|
+
k1 = moveBatchAxis(axisSize, bdim1, 0, k1);
|
|
3807
|
+
return [[randomBits(k0, k1, [axisSize, ...shape$1], mode)], [0]];
|
|
3342
3808
|
},
|
|
3343
|
-
[Primitive.
|
|
3344
|
-
|
|
3345
|
-
|
|
3346
|
-
|
|
3347
|
-
|
|
3348
|
-
|
|
3349
|
-
|
|
3350
|
-
|
|
3351
|
-
|
|
3352
|
-
},
|
|
3353
|
-
[Primitive.Sqrt]([x], [dx]) {
|
|
3354
|
-
const z = sqrt$1(x);
|
|
3355
|
-
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3356
|
-
},
|
|
3357
|
-
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3358
|
-
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3359
|
-
},
|
|
3360
|
-
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3361
|
-
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3362
|
-
},
|
|
3363
|
-
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3364
|
-
if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3365
|
-
else if (op === require_backend.AluOp.Mul) {
|
|
3366
|
-
const primal = reduce(x.ref, op, axis);
|
|
3367
|
-
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3368
|
-
return [[primal], [tangent]];
|
|
3369
|
-
} else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
|
|
3370
|
-
const primal = reduce(x.ref, op, axis);
|
|
3371
|
-
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3372
|
-
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3373
|
-
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3374
|
-
return [[primal], [tangent]];
|
|
3375
|
-
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3376
|
-
},
|
|
3377
|
-
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3378
|
-
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3379
|
-
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3380
|
-
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3381
|
-
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3382
|
-
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3383
|
-
dcond.dispose();
|
|
3384
|
-
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3385
|
-
},
|
|
3386
|
-
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3387
|
-
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3388
|
-
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3389
|
-
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3390
|
-
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3391
|
-
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3392
|
-
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3393
|
-
const indicesRef = indices.map((t) => t.ref);
|
|
3394
|
-
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3395
|
-
},
|
|
3396
|
-
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3397
|
-
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3398
|
-
const outs = bind(Primitive.JitCall, [
|
|
3399
|
-
...newConsts.map((c) => c.ref),
|
|
3400
|
-
...primals,
|
|
3401
|
-
...tangents
|
|
3402
|
-
], {
|
|
3403
|
-
name: `${name}_jvp`,
|
|
3404
|
-
jaxpr: newJaxpr,
|
|
3405
|
-
numConsts: newConsts.length
|
|
3406
|
-
});
|
|
3407
|
-
const n = outs.length / 2;
|
|
3408
|
-
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3409
|
-
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3410
|
-
return [primalsOut, tangentsOut];
|
|
3411
|
-
}
|
|
3412
|
-
};
|
|
3413
|
-
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3414
|
-
function jvpJaxpr(jaxpr) {
|
|
3415
|
-
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3416
|
-
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3417
|
-
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3418
|
-
const result = {
|
|
3419
|
-
newJaxpr,
|
|
3420
|
-
newConsts
|
|
3421
|
-
};
|
|
3422
|
-
jvpJaxprCache.set(jaxpr, result);
|
|
3423
|
-
return result;
|
|
3424
|
-
}
|
|
3425
|
-
function jvpFlat(f, primals, tangents) {
|
|
3426
|
-
try {
|
|
3427
|
-
var _usingCtx$1 = (0, import_usingCtx$1.default)();
|
|
3428
|
-
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3429
|
-
const trace$1 = new JVPTrace(main);
|
|
3430
|
-
const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3431
|
-
const outs = f(...tracersIn);
|
|
3432
|
-
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3433
|
-
return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3434
|
-
} catch (_) {
|
|
3435
|
-
_usingCtx$1.e = _;
|
|
3436
|
-
} finally {
|
|
3437
|
-
_usingCtx$1.d();
|
|
3438
|
-
}
|
|
3439
|
-
}
|
|
3440
|
-
function jvp$1(f, primals, tangents) {
|
|
3441
|
-
const [primalsFlat, inTree] = flatten(primals);
|
|
3442
|
-
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
3443
|
-
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
3444
|
-
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
3445
|
-
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
3446
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
3447
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3448
|
-
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
3449
|
-
return [primalsOut, tangentsOut];
|
|
3450
|
-
}
|
|
3451
|
-
|
|
3452
|
-
//#endregion
|
|
3453
|
-
//#region src/frontend/vmap.ts
|
|
3454
|
-
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3455
|
-
function mappedAval(batchDim, aval) {
|
|
3456
|
-
const shape$1 = [...aval.shape];
|
|
3457
|
-
shape$1.splice(batchDim, 1);
|
|
3458
|
-
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3459
|
-
}
|
|
3460
|
-
/** Move one axis to a different index. */
|
|
3461
|
-
function moveaxis(x, src, dst) {
|
|
3462
|
-
const t = pureArray(x);
|
|
3463
|
-
src = require_backend.checkAxis(src, t.ndim);
|
|
3464
|
-
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3465
|
-
if (src === dst) return t;
|
|
3466
|
-
const perm = require_backend.range(t.ndim);
|
|
3467
|
-
perm.splice(src, 1);
|
|
3468
|
-
perm.splice(dst, 0, src);
|
|
3469
|
-
return transpose$1(t, perm);
|
|
3470
|
-
}
|
|
3471
|
-
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3472
|
-
if (src === null) {
|
|
3473
|
-
const targetShape = [...x.shape];
|
|
3474
|
-
targetShape.splice(dst, 0, axisSize);
|
|
3475
|
-
return broadcast(x, targetShape, [dst]);
|
|
3476
|
-
} else if (src === dst) return x;
|
|
3477
|
-
else return moveaxis(x, src, dst);
|
|
3478
|
-
}
|
|
3479
|
-
var BatchTracer = class extends Tracer {
|
|
3480
|
-
constructor(trace$1, val, batchDim) {
|
|
3481
|
-
super(trace$1);
|
|
3482
|
-
this.val = val;
|
|
3483
|
-
this.batchDim = batchDim;
|
|
3484
|
-
}
|
|
3485
|
-
get aval() {
|
|
3486
|
-
if (this.batchDim === null) return this.val.aval;
|
|
3487
|
-
else return mappedAval(this.batchDim, this.val.aval);
|
|
3488
|
-
}
|
|
3489
|
-
toString() {
|
|
3490
|
-
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3491
|
-
}
|
|
3492
|
-
get ref() {
|
|
3493
|
-
this.val.ref;
|
|
3494
|
-
return this;
|
|
3495
|
-
}
|
|
3496
|
-
dispose() {
|
|
3497
|
-
this.val.dispose();
|
|
3498
|
-
}
|
|
3499
|
-
fullLower() {
|
|
3500
|
-
if (this.batchDim === null) return this.val.fullLower();
|
|
3501
|
-
else return this;
|
|
3502
|
-
}
|
|
3503
|
-
};
|
|
3504
|
-
var BatchTrace = class extends Trace {
|
|
3505
|
-
pure(val) {
|
|
3506
|
-
return this.lift(pureArray(val));
|
|
3507
|
-
}
|
|
3508
|
-
lift(val) {
|
|
3509
|
-
return new BatchTracer(this, val, null);
|
|
3510
|
-
}
|
|
3511
|
-
processPrimitive(primitive, tracers, params) {
|
|
3512
|
-
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3513
|
-
const vmapRule = vmapRules[primitive];
|
|
3514
|
-
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3515
|
-
if (bdimsIn.every((d) => d === null)) {
|
|
3516
|
-
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3517
|
-
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3809
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3810
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3811
|
+
require_backend.assertNonNull(xBdim);
|
|
3812
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3813
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3814
|
+
let newOutDim = outDim;
|
|
3815
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3816
|
+
else newOutDim += 1;
|
|
3817
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3518
3818
|
}
|
|
3519
|
-
const
|
|
3520
|
-
|
|
3521
|
-
|
|
3522
|
-
|
|
3523
|
-
|
|
3524
|
-
|
|
3525
|
-
|
|
3526
|
-
|
|
3527
|
-
* Process a primitive with built-in broadcasting.
|
|
3528
|
-
*
|
|
3529
|
-
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3530
|
-
*/
|
|
3531
|
-
function broadcastBatcher(op) {
|
|
3532
|
-
return (axisSize, args, dims) => {
|
|
3533
|
-
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3534
|
-
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3535
|
-
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3536
|
-
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3537
|
-
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3538
|
-
args = args.map((x, i) => {
|
|
3539
|
-
if (dims[i] === null) return x;
|
|
3540
|
-
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3541
|
-
if (x.ndim < nd) x = x.reshape([
|
|
3542
|
-
x.shape[0],
|
|
3543
|
-
...require_backend.rep(nd - x.ndim, 1),
|
|
3544
|
-
...x.shape.slice(1)
|
|
3819
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3820
|
+
indices = indices.map((m, i) => {
|
|
3821
|
+
if (indicesBdim[i] === null) return m;
|
|
3822
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3823
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3824
|
+
m.shape[0],
|
|
3825
|
+
...require_backend.rep(nd - m.ndim, 1),
|
|
3826
|
+
...m.shape.slice(1)
|
|
3545
3827
|
]);
|
|
3546
|
-
return
|
|
3547
|
-
});
|
|
3548
|
-
return [[op(...args)], [0]];
|
|
3549
|
-
};
|
|
3550
|
-
}
|
|
3551
|
-
function unopBatcher(op) {
|
|
3552
|
-
return (axisSize, [x], [xBdim], params) => {
|
|
3553
|
-
return [[op(x, params)], [xBdim]];
|
|
3554
|
-
};
|
|
3555
|
-
}
|
|
3556
|
-
const vmapRules = {
|
|
3557
|
-
[Primitive.Add]: broadcastBatcher(add$1),
|
|
3558
|
-
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3559
|
-
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3560
|
-
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3561
|
-
[Primitive.Neg]: unopBatcher(neg),
|
|
3562
|
-
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3563
|
-
[Primitive.Floor]: unopBatcher(floor$1),
|
|
3564
|
-
[Primitive.Ceil]: unopBatcher(ceil$1),
|
|
3565
|
-
[Primitive.StopGradient]: unopBatcher(stopGradient),
|
|
3566
|
-
[Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
|
|
3567
|
-
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3568
|
-
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3569
|
-
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3570
|
-
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3571
|
-
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3572
|
-
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3573
|
-
[Primitive.Log]: unopBatcher(log$1),
|
|
3574
|
-
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3575
|
-
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3576
|
-
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3577
|
-
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3578
|
-
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3579
|
-
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3580
|
-
require_backend.assertNonNull(xBdim);
|
|
3581
|
-
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3582
|
-
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3583
|
-
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3584
|
-
},
|
|
3585
|
-
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3586
|
-
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3587
|
-
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3588
|
-
const z = dot$2(x, y);
|
|
3589
|
-
return [[z], [z.ndim - 1]];
|
|
3590
|
-
},
|
|
3591
|
-
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3592
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3593
|
-
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3594
|
-
const z = conv$1(x, y, {
|
|
3595
|
-
...params,
|
|
3596
|
-
vmapDims: params.vmapDims + 1
|
|
3828
|
+
return m;
|
|
3597
3829
|
});
|
|
3598
|
-
return [[
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3830
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3831
|
+
else {
|
|
3832
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3833
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3834
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
|
|
3835
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3836
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3837
|
+
}
|
|
3602
3838
|
},
|
|
3603
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3604
3839
|
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3605
3840
|
require_backend.assertNonNull(xBdim);
|
|
3606
3841
|
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
@@ -3632,42 +3867,39 @@ const vmapRules = {
|
|
|
3632
3867
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3633
3868
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3634
3869
|
},
|
|
3635
|
-
[Primitive.
|
|
3636
|
-
|
|
3637
|
-
|
|
3638
|
-
|
|
3639
|
-
|
|
3640
|
-
|
|
3641
|
-
|
|
3642
|
-
|
|
3643
|
-
|
|
3644
|
-
|
|
3645
|
-
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3646
|
-
indices = indices.map((m, i) => {
|
|
3647
|
-
if (indicesBdim[i] === null) return m;
|
|
3648
|
-
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3649
|
-
if (m.ndim < nd) m = m.reshape([
|
|
3650
|
-
m.shape[0],
|
|
3651
|
-
...require_backend.rep(nd - m.ndim, 1),
|
|
3652
|
-
...m.shape.slice(1)
|
|
3870
|
+
[Primitive.Sort]: lastDimsBatcher(Primitive.Sort, 1),
|
|
3871
|
+
[Primitive.Argsort]: lastDimsBatcher(Primitive.Argsort, 1, 2),
|
|
3872
|
+
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3873
|
+
if (aBdim === null) {
|
|
3874
|
+
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
3875
|
+
const [s, m, n] = b.shape.slice(-3);
|
|
3876
|
+
b = b.reshape([
|
|
3877
|
+
...b.shape.slice(0, -3),
|
|
3878
|
+
s * m,
|
|
3879
|
+
n
|
|
3653
3880
|
]);
|
|
3654
|
-
|
|
3655
|
-
|
|
3656
|
-
|
|
3657
|
-
|
|
3658
|
-
|
|
3659
|
-
|
|
3660
|
-
|
|
3661
|
-
|
|
3662
|
-
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3881
|
+
let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3882
|
+
x$1 = x$1.reshape([
|
|
3883
|
+
...b.shape.slice(0, -2),
|
|
3884
|
+
s,
|
|
3885
|
+
m,
|
|
3886
|
+
n
|
|
3887
|
+
]);
|
|
3888
|
+
return [[x$1], [x$1.ndim - 3]];
|
|
3663
3889
|
}
|
|
3890
|
+
a = moveBatchAxis(axisSize, aBdim, 0, a);
|
|
3891
|
+
b = moveBatchAxis(axisSize, bBdim, 0, b);
|
|
3892
|
+
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3893
|
+
return [[x], [0]];
|
|
3664
3894
|
},
|
|
3665
|
-
[Primitive.
|
|
3666
|
-
|
|
3667
|
-
|
|
3895
|
+
[Primitive.Cholesky]: lastDimsBatcher(Primitive.Cholesky, 2),
|
|
3896
|
+
[Primitive.LU]: lastDimsBatcher(Primitive.LU, 2, 3),
|
|
3897
|
+
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3898
|
+
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3899
|
+
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
3668
3900
|
name: `${name}_vmap`,
|
|
3669
|
-
jaxpr: newJaxpr,
|
|
3670
|
-
numConsts:
|
|
3901
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3902
|
+
numConsts: newJaxpr.consts.length
|
|
3671
3903
|
});
|
|
3672
3904
|
return [outs, require_backend.rep(outs.length, 0)];
|
|
3673
3905
|
}
|
|
@@ -3683,14 +3915,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3683
3915
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3684
3916
|
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3685
3917
|
});
|
|
3686
|
-
const { jaxpr: newJaxpr
|
|
3687
|
-
const result = {
|
|
3688
|
-
newJaxpr,
|
|
3689
|
-
newConsts
|
|
3690
|
-
};
|
|
3918
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3691
3919
|
if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
3692
|
-
vmapJaxprCache.get(jaxpr).set(cacheKey,
|
|
3693
|
-
return
|
|
3920
|
+
vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
3921
|
+
return newJaxpr;
|
|
3694
3922
|
}
|
|
3695
3923
|
function vmapFlat(f, inAxes, args) {
|
|
3696
3924
|
let axisSize = void 0;
|
|
@@ -3704,7 +3932,7 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3704
3932
|
if (axisSize === void 0) throw new TypeError("vmap requires at least one mapped axis");
|
|
3705
3933
|
let valsOut, bdimsOut;
|
|
3706
3934
|
try {
|
|
3707
|
-
var _usingCtx$1 = (0, import_usingCtx.default)();
|
|
3935
|
+
var _usingCtx$1 = (0, import_usingCtx$1.default)();
|
|
3708
3936
|
const main = _usingCtx$1.u(newMain(BatchTrace, axisSize));
|
|
3709
3937
|
const trace$1 = new BatchTrace(main);
|
|
3710
3938
|
const tracersIn = args.map((x, i) => inAxes[i] === null ? pureArray(x) : new BatchTracer(trace$1, pureArray(x), inAxes[i]));
|
|
@@ -3736,13 +3964,312 @@ function vmap$1(f, inAxes = 0) {
|
|
|
3736
3964
|
return unflatten(outTree.value, outsFlat);
|
|
3737
3965
|
};
|
|
3738
3966
|
}
|
|
3739
|
-
function jacfwd$1(f) {
|
|
3740
|
-
return function jacobianForward(x) {
|
|
3741
|
-
if (x.shape.length !== 1) throw new TypeError("jacfwd only supports 1D inputs");
|
|
3742
|
-
const [size$1] = x.shape;
|
|
3743
|
-
const pushfwd = (v) => jvp$1(f, [x], [v])[1];
|
|
3744
|
-
return vmap$1(pushfwd, [0])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3745
|
-
};
|
|
3967
|
+
/**
 * Build a forward-mode Jacobian function for `f`.
 *
 * The returned function only accepts 1D inputs. It feeds the rows of an
 * identity matrix (same dtype as `x`) through `jvp$1` as tangents, batching
 * over axis 0 with `vmap$1`.
 *
 * @param f Function to differentiate.
 * @returns A function mapping a 1D input `x` to the batched tangent outputs.
 * @throws {TypeError} If the input is not one-dimensional.
 */
function jacfwd$1(f) {
	return function jacobianForward(x) {
		const rank = x.shape.length;
		if (rank !== 1) throw new TypeError("jacfwd only supports 1D inputs");
		const size$1 = x.shape[0];
		// Element 1 of the jvp result is the tangent output.
		const pushfwd = (v) => jvp$1(f, [x], [v])[1];
		const batchedPushfwd = vmap$1(pushfwd, [0]);
		return batchedPushfwd(eye(size$1, void 0, { dtype: x.dtype }));
	};
}
|
|
3975
|
+
|
|
3976
|
+
//#endregion
|
|
3977
|
+
//#region src/frontend/jvp.ts
|
|
3978
|
+
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3979
|
+
/**
 * Tracer that pairs a primal value with its tangent.
 * The abstract value (`aval`) is taken from the primal.
 */
var JVPTracer = class extends Tracer {
	constructor(trace$1, primal, tangent) {
		super(trace$1);
		this.primal = primal;
		this.tangent = tangent;
	}
	/** Abstract value of this tracer, delegated to the primal. */
	get aval() {
		const { primal } = this;
		return primal.aval;
	}
	toString() {
		const primalText = this.primal.toString();
		const tangentText = this.tangent.toString();
		return `JVPTracer(${primalText}, ${tangentText})`;
	}
	/**
	 * Reads `.ref` on both components for its side effect and returns this
	 * tracer. NOTE(review): presumably each read bumps a reference count that a
	 * later `dispose()` releases — confirm against the Tracer base class.
	 */
	get ref() {
		void this.primal.ref;
		void this.tangent.ref;
		return this;
	}
	/** Dispose the primal first, then the tangent. */
	dispose() {
		const { primal, tangent } = this;
		primal.dispose();
		tangent.dispose();
	}
};
|
|
4000
|
+
/** Trace that evaluates primitives on (primal, tangent) pairs using `jvpRules`. */
var JVPTrace = class extends Trace {
	/** Convert a raw value with `pureArray`, then lift it into this trace. */
	pure(val) {
		const asArray = pureArray(val);
		return this.lift(asArray);
	}
	/** Wrap a value as a JVPTracer whose tangent is `zerosLike$1` of the value. */
	lift(val) {
		const zeroTangent = zerosLike$1(val.ref);
		return new JVPTracer(this, val, zeroTangent);
	}
	/**
	 * Apply the JVP rule registered for `primitive`.
	 *
	 * @throws {Error} If `jvpRules` has no entry for the primitive.
	 */
	processPrimitive(primitive, tracers, params) {
		const inputPairs = tracers.map((tracer) => [tracer.primal, tracer.tangent]);
		const [primalsIn, tangentsIn] = require_backend.unzip2(inputPairs);
		const jvpRule = jvpRules[primitive];
		if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
		const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
		const outputPairs = require_backend.zip(primalsOut, tangentsOut);
		return outputPairs.map(([primal, tangent]) => new JVPTracer(this, primal, tangent));
	}
};
|
|
4015
|
+
/**
 * JVP rule factory for primitives whose tangent is the primitive itself
 * applied to the tangents: binds `primitive` once on the primals and once on
 * the tangents, with the same params.
 */
function linearTangentsJvp(primitive) {
	return (primals, tangents, params) => [
		bind(primitive, primals, params),
		bind(primitive, tangents, params)
	];
}
|
|
4023
|
+
/**
 * JVP rule factory for two-argument primitives, using the product rule:
 * tangent = prim(x, dy) + prim(dx, y).
 */
function bilinearTangentsJvp(primitive) {
	return (primals, tangents, params) => {
		const [x, y] = primals;
		const [dx, dy] = tangents;
		// The primal output takes `.ref`s; the tangent terms below use the
		// plain `x` and `y`.
		const primal = bind1(primitive, [x.ref, y.ref], params);
		const wrtSecond = bind1(primitive, [x, dy], params);
		const wrtFirst = bind1(primitive, [dx, y], params);
		return [[primal], [wrtSecond.add(wrtFirst)]];
	};
}
|
|
4031
|
+
/**
 * JVP rule factory for primitives with no tangent: every incoming tangent is
 * disposed, and each output gets a `zerosLike$1` tangent.
 */
function zeroTangentsJvp(primitive) {
	return (primals, tangents, params) => {
		// The incoming tangents are not used, so dispose them before binding.
		for (const unusedTangent of tangents) {
			unusedTangent.dispose();
		}
		const primalsOut = bind(primitive, primals, params);
		const zeroTangents = primalsOut.map((out) => zerosLike$1(out.ref));
		return [primalsOut, zeroTangents];
	};
}
|
|
4039
|
+
/**
 * Compute `a @ b.T`, batched to last two axes.
 *
 * Implemented by inserting a size-1 axis before the last axis of `a` and
 * before the second-to-last axis of `b`, then calling `dot$2` on the reshaped
 * operands.
 */
function batchMatmulT(a, b) {
	const lhsShape = a.shape.toSpliced(-1, 0, 1);
	const lhs = a.reshape(lhsShape);
	const rhsShape = b.shape.toSpliced(-2, 0, 1);
	const rhs = b.reshape(rhsShape);
	return dot$2(lhs, rhs);
}
|
|
4043
|
+
/** Batch matrix transpose: moves axis -2 to position -1 via `moveaxis`. */
function mT(a) {
	const sourceAxis = -2;
	const destinationAxis = -1;
	return moveaxis(a, sourceAxis, destinationAxis);
}
|
|
4047
|
+
/**
 * Slice `a` along a single axis, leaving every other axis untouched.
 *
 * @param a Array-like object exposing `shape`, `ndim` and `slice(...specs)`.
 * @param axis Axis to slice; normalized with `require_backend.checkAxis`.
 * @param p Slice spec passed through for that axis.
 * @returns The result of `a.slice(...)` with `[]` for all other axes.
 */
function sliceAxis(a, axis, p) {
	// Build a distinct `[]` per axis. `Array(n).fill([])` would put one shared
	// array object in every slot, so a mutation of one spec would affect all.
	const slices = Array.from({ length: a.shape.length }, () => []);
	slices[require_backend.checkAxis(axis, a.ndim)] = p;
	return a.slice(...slices);
}
|
|
4052
|
+
/**
 * Pad `a` along a single axis, with zero padding on every other axis.
 *
 * @param a Array-like object exposing `shape` and `ndim`.
 * @param axis Axis to pad; normalized with `require_backend.checkAxis`.
 * @param p Pad widths passed through for that axis.
 * @returns The result of `pad$1(a, widths)`.
 */
function padAxis(a, axis, p) {
	// Build a distinct `[0, 0]` per axis. `Array(n).fill([0, 0])` would put one
	// shared array object in every slot, so a mutation of one would affect all.
	const pads = Array.from({ length: a.shape.length }, () => [0, 0]);
	pads[require_backend.checkAxis(axis, a.ndim)] = p;
	return pad$1(a, pads);
}
|
|
4057
|
+
/**
 * JVP rules, keyed by primitive.
 *
 * Each rule has the form `(primals, tangents, params) => [primalsOut, tangentsOut]`.
 *
 * The rules follow an ownership pattern visible throughout this table: every
 * input is used exactly once without `.ref`, any additional use goes through
 * `.ref`, and inputs that are not needed are `dispose()`d.
 * NOTE(review): this is inferred from the sibling rules (e.g. Sin, Log, Where)
 * — confirm against the Tracer reference-counting contract.
 */
const jvpRules = {
	[Primitive.Add]: linearTangentsJvp(Primitive.Add),
	[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
	[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
	[Primitive.Mod]([x, y], [dx, dy]) {
		if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
			// Non-float remainder: the tangent is zero. The primal output must be
			// the single value `mod(x, y)`, not the two inputs.
			dx.dispose();
			dy.dispose();
			const z = mod(x, y);
			return [[z.ref], [zerosLike$1(z)]];
		}
		// d(x mod y) = dx - dy * idiv(x, y)
		const q = idiv(x.ref, y.ref);
		return [[mod(x, y)], [dx.sub(dy.mul(q))]];
	},
	[Primitive.Min]([x, y], [dx, dy]) {
		// Select dy where y < x, otherwise dx.
		return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
	},
	[Primitive.Max]([x, y], [dx, dy]) {
		// Select dy where x < y, otherwise dx.
		return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
	},
	[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
	[Primitive.Reciprocal]([x], [dx]) {
		// d(1/x) = -(1/x)^2 * dx. `x` is consumed here (it has no other use),
		// so no extra `.ref` is taken.
		const xRecip = reciprocal$1(x);
		return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
	},
	[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
	[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
	[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
	[Primitive.Cast]([x], [dx], { dtype }) {
		if (x.dtype === dtype) return [[x], [dx]];
		if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
		else {
			// NOTE(review): the zero tangent is built from `x` (input dtype), not
			// from the cast output — confirm this is intended.
			dx.dispose();
			return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
		}
	},
	[Primitive.Bitcast]([x], [dx], { dtype }) {
		if (x.dtype === dtype) return [[x], [dx]];
		dx.dispose();
		return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
	},
	[Primitive.Sin]([x], [dx]) {
		return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
	},
	[Primitive.Cos]([x], [dx]) {
		return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
	},
	[Primitive.Asin]([x], [dx]) {
		// `denom` holds sqrt(1 / (1 - x^2)), which multiplies dx.
		const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
		return [[asin$1(x)], [denom.mul(dx)]];
	},
	[Primitive.Atan]([x], [dx]) {
		const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
		return [[atan$1(x)], [dx.div(denom)]];
	},
	[Primitive.Exp]([x], [dx]) {
		const z = exp$1(x);
		return [[z.ref], [z.mul(dx)]];
	},
	[Primitive.Log]([x], [dx]) {
		return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
	},
	[Primitive.Erf]([x], [dx]) {
		const coeff = 2 / Math.sqrt(Math.PI);
		const expTerm = exp$1(neg(x.ref.mul(x.ref)));
		return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
	},
	[Primitive.Erfc]([x], [dx]) {
		const coeff = -2 / Math.sqrt(Math.PI);
		const expTerm = exp$1(neg(x.ref.mul(x.ref)));
		return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
	},
	[Primitive.Sqrt]([x], [dx]) {
		// d(sqrt(x)) = dx / (2 * sqrt(x))
		const z = sqrt$1(x);
		return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
	},
	[Primitive.Reduce]([x], [dx], { op, axis }) {
		if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
		else if (op === require_backend.AluOp.Mul) {
			// Tangent is sum over the axis of (product / x) * dx.
			const primal = reduce(x.ref, op, axis);
			const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
			return [[primal], [tangent]];
		} else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
			// Average dx over the positions equal to the reduced value.
			const primal = reduce(x.ref, op, axis);
			const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
			const minCount = where$1(notMin.ref, 0, 1).sum(axis);
			const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
			return [[primal], [tangent]];
		} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
	},
	[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
	[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
	[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
	[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
	[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
	[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
		dcond.dispose();
		return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
	},
	[Primitive.Concatenate]: linearTangentsJvp(Primitive.Concatenate),
	[Primitive.Split]: linearTangentsJvp(Primitive.Split),
	[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
	[Primitive.Gather]([x, ...indices], [dx, ...dIndices], { axis, outDim }) {
		// The index tangents are not used; dispose them, as Where does for dcond.
		for (const dIndex of dIndices) dIndex.dispose();
		const indicesRef = indices.map((t) => t.ref);
		return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
	},
	[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
	[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
	[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
	[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
	[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
	[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
	[Primitive.Sort]([x], [dx]) {
		// Reorder dx along the last axis with the indices returned by argsort.
		const [y, idx] = argsort$1(x);
		return [[y], [gather(dx, [idx], [-1], -1)]];
	},
	[Primitive.Argsort]([x], [dx]) {
		const [y, idx] = argsort$1(x);
		return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
	},
	[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
		const x = triangularSolve$1(a.ref, b, { unitDiagonal });
		const dax = batchMatmulT(da, x.ref);
		const rhsT = db.sub(mT(dax));
		const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
		return [[x], [dx]];
	},
	[Primitive.Cholesky]([a], [da]) {
		// `a` is consumed here (it has no other use), so no extra `.ref` is taken.
		const L = cholesky$2(a);
		// Symmetrize the incoming tangent.
		da = da.ref.add(mT(da)).mul(.5);
		const W = triangularSolve$1(L.ref, da, { lower: true });
		const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
		const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
		return [[L], [dL]];
	},
	[Primitive.LU]([a], [da]) {
		const [luMatrix, pivots, permutation] = lu$1(a);
		const [m, n] = a.shape.slice(-2);
		const k = Math.min(m, n);
		// Rebuild L: strictly-lower part of the first k columns, padded to m
		// columns, plus the identity.
		const luSliceL = sliceAxis(luMatrix.ref, -1, [0, k]);
		const lLower = tril(luSliceL, -1);
		const lPadded = m > k ? padAxis(lLower, -1, [0, m - k]) : lLower;
		const L = lPadded.add(eye(m));
		// Rebuild U: upper part of the first k rows, padded to n rows, plus an
		// identity block in the padded corner when n > k.
		const luSliceU = sliceAxis(luMatrix.ref, -2, [0, k]);
		const uUpper = triu(luSliceU);
		const uPadded = n > k ? padAxis(uUpper, -2, [0, n - k]) : uUpper;
		const uEye = n > k ? padAxis(padAxis(eye(n - k), -1, [k, 0]), -2, [k, 0]) : zerosLike$1(uPadded.ref);
		const U = uPadded.add(uEye);
		// One-hot matrix built by comparing the permutation against arange(m).
		const P = permutation.ref.reshape([...permutation.shape, 1]).equal(arange(m)).astype(da.dtype);
		const pda = batchMatmulT(P, mT(da));
		const la = mT(triangularSolve$1(L.ref, mT(pda), {
			lower: true,
			unitDiagonal: true
		}));
		const lau = triangularSolve$1(mT(U.ref), la, { lower: true });
		const lDot = batchMatmulT(L, mT(tril(lau.ref, -1)));
		const uDot = batchMatmulT(triu(lau), mT(U));
		return [[
			luMatrix,
			pivots,
			permutation
		], [
			lDot.add(uDot),
			zerosLike$1(pivots.ref),
			zerosLike$1(permutation.ref)
		]];
	},
	[Primitive.Jit](primals, tangents, { name, jaxpr }) {
		const newJaxpr = jvpJaxpr(jaxpr);
		const outs = bind(Primitive.Jit, [
			...newJaxpr.consts.map((c) => c.ref),
			...primals,
			...tangents
		], {
			name: `${name}_jvp`,
			jaxpr: newJaxpr.jaxpr,
			numConsts: newJaxpr.consts.length
		});
		// The transformed jaxpr returns all primal outputs followed by all
		// tangent outputs, so the result must split evenly in two.
		const n = outs.length / 2;
		if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
		const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
		return [primalsOut, tangentsOut];
	}
};
|
|
4240
|
+
/** Cache of JVP-transformed jaxprs, keyed by the source jaxpr object. */
const jvpJaxprCache = /* @__PURE__ */ new Map();
/**
 * Return the JVP-transformed version of `jaxpr`, tracing it on first use and
 * serving later requests for the same jaxpr object from the cache.
 *
 * The traced function takes primals and tangents with the same abstract
 * values as the jaxpr's input binders.
 */
function jvpJaxpr(jaxpr) {
	if (!jvpJaxprCache.has(jaxpr)) {
		const inAvals = jaxpr.inBinders.map((binder) => binder.aval);
		const jvpOfJaxpr = (primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents);
		const traced = makeJaxpr$1(jvpOfJaxpr)(inAvals, inAvals);
		jvpJaxprCache.set(jaxpr, traced.jaxpr);
	}
	return jvpJaxprCache.get(jaxpr);
}
|
|
4248
|
+
/**
 * Run `f` on flat lists of primals and tangents under a fresh JVP trace.
 *
 * @param f Function taking the input tracers as positional arguments and
 *   returning a list of outputs.
 * @param primals Flat list of primal inputs.
 * @param tangents Flat list of tangent inputs, paired with `primals` by position.
 * @returns `[primalsOut, tangentsOut]`, produced by `require_backend.unzip2`.
 */
function jvpFlat(f, primals, tangents) {
	// This try/catch/finally is the compiled form of a `using` declaration: the
	// context registers the main trace, records any thrown error, and `d()` runs
	// on exit. NOTE(review): presumably `d()` disposes the main trace and
	// rethrows the recorded error — confirm against the usingCtx helper.
	try {
		var scope = (0, import_usingCtx.default)();
		const trace$1 = new JVPTrace(scope.u(newMain(JVPTrace)));
		const inputPairs = require_backend.zip(primals, tangents);
		const tracersIn = inputPairs.map(([primal, tangent]) => new JVPTracer(trace$1, pureArray(primal), pureArray(tangent)));
		const tracersOut = f(...tracersIn).map((out) => fullRaise(trace$1, out));
		const outputPairs = tracersOut.map((tracer) => [tracer.primal, tracer.tangent]);
		return require_backend.unzip2(outputPairs);
	} catch (err) {
		scope.e = err;
	} finally {
		scope.d();
	}
}
|
|
4263
|
+
/**
 * Jacobian-vector product of `f` over tree-structured inputs.
 *
 * Flattens `primals` and `tangents`, requires both to have the same tree
 * structure, evaluates the flattened function with `jvpFlat`, and rebuilds
 * both outputs with the output tree recorded by `flattenFun`.
 *
 * @returns `[primalsOut, tangentsOut]`.
 * @throws {TreeMismatchError} If primals and tangents differ in structure.
 * @throws {Error} If the output tree was not set while running `f`.
 */
function jvp$1(f, primals, tangents) {
	const [primalsFlat, inTree] = flatten(primals);
	const [tangentsFlat, tangentTree] = flatten(tangents);
	if (!inTree.equals(tangentTree)) throw new TreeMismatchError("jvp", inTree, tangentTree);
	const [flatFun, outTree] = flattenFun(f, inTree);
	const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
	// `outTree.value` is only checked after the flattened function has run.
	if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
	return [
		unflatten(outTree.value, primalsOutFlat),
		unflatten(outTree.value, tangentsOutFlat)
	];
}
|
|
3747
4274
|
|
|
3748
4275
|
//#endregion
|
|
@@ -3775,11 +4302,10 @@ function partialEvalFlat(f, pvalsIn) {
|
|
|
3775
4302
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3776
4303
|
const pvalsOut = tracersOut.map((t) => t.pval);
|
|
3777
4304
|
const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
|
|
3778
|
-
const
|
|
4305
|
+
const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
|
|
3779
4306
|
return {
|
|
3780
4307
|
jaxpr,
|
|
3781
|
-
pvalsOut
|
|
3782
|
-
consts
|
|
4308
|
+
pvalsOut
|
|
3783
4309
|
};
|
|
3784
4310
|
}
|
|
3785
4311
|
/**
|
|
@@ -3796,22 +4322,19 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3796
4322
|
const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
|
|
3797
4323
|
return [...primalsOut$1, ...tangentsOut];
|
|
3798
4324
|
};
|
|
3799
|
-
const { jaxpr, pvalsOut
|
|
4325
|
+
const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
|
|
3800
4326
|
const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
|
|
3801
4327
|
if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
|
|
3802
4328
|
const primalsOut = primalPvals.map((pval) => pval.val);
|
|
3803
4329
|
return {
|
|
3804
4330
|
primalsOut,
|
|
3805
|
-
jaxpr
|
|
3806
|
-
consts
|
|
4331
|
+
jaxpr
|
|
3807
4332
|
};
|
|
3808
4333
|
}
|
|
3809
4334
|
function linearizeFlat(f, primalsIn) {
|
|
3810
|
-
const { primalsOut, jaxpr
|
|
3811
|
-
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3812
|
-
const dispose$1 = () =>
|
|
3813
|
-
for (const c of consts) c.dispose();
|
|
3814
|
-
};
|
|
4335
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4336
|
+
const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
|
|
4337
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
3815
4338
|
return [
|
|
3816
4339
|
primalsOut,
|
|
3817
4340
|
fLin,
|
|
@@ -3895,7 +4418,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3895
4418
|
}
|
|
3896
4419
|
processPrimitive(primitive, tracers, params) {
|
|
3897
4420
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3898
|
-
if (primitive === Primitive.
|
|
4421
|
+
if (primitive === Primitive.Jit) {
|
|
3899
4422
|
const { name, jaxpr, numConsts } = params;
|
|
3900
4423
|
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3901
4424
|
}
|
|
@@ -3921,14 +4444,14 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3921
4444
|
* Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
|
|
3922
4445
|
* values as possible (with JIT) and forwarding the unknown ones.
|
|
3923
4446
|
*
|
|
3924
|
-
* Used when encountering a
|
|
4447
|
+
* Used when encountering a Jit rule during the trace.
|
|
3925
4448
|
*/
|
|
3926
4449
|
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3927
4450
|
jaxpr = jaxpr.flatten();
|
|
3928
4451
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3929
4452
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3930
4453
|
const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
|
|
3931
|
-
const outs1Res = bind(Primitive.
|
|
4454
|
+
const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3932
4455
|
name: `${name}_peval`,
|
|
3933
4456
|
jaxpr: jaxpr1,
|
|
3934
4457
|
numConsts: 0
|
|
@@ -3938,7 +4461,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3938
4461
|
const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
|
|
3939
4462
|
const recipe = {
|
|
3940
4463
|
type: "JaxprEqn",
|
|
3941
|
-
prim: Primitive.
|
|
4464
|
+
prim: Primitive.Jit,
|
|
3942
4465
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3943
4466
|
params: {
|
|
3944
4467
|
name: `${name}_resid`,
|
|
@@ -3967,7 +4490,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
|
|
|
3967
4490
|
const eqns1 = [];
|
|
3968
4491
|
const eqns2 = [];
|
|
3969
4492
|
for (const eqn of jaxpr.eqns) {
|
|
3970
|
-
if (eqn.primitive === Primitive.
|
|
4493
|
+
if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
|
|
3971
4494
|
const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
|
|
3972
4495
|
if (hasUnknowns) {
|
|
3973
4496
|
for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
|
|
@@ -4042,10 +4565,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
4042
4565
|
for (const t of tracersOut) t.dispose();
|
|
4043
4566
|
jaxpr = jaxpr.simplify();
|
|
4044
4567
|
if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
4045
|
-
return
|
|
4046
|
-
jaxpr,
|
|
4047
|
-
consts
|
|
4048
|
-
};
|
|
4568
|
+
return new ClosedJaxpr(jaxpr, consts);
|
|
4049
4569
|
}
|
|
4050
4570
|
/** Marker type for pullback, used by transpose rules. */
|
|
4051
4571
|
var UndefPrimal = class {
|
|
@@ -4237,317 +4757,151 @@ const transposeRules = {
|
|
|
4237
4757
|
cond.dispose();
|
|
4238
4758
|
return cts;
|
|
4239
4759
|
},
|
|
4240
|
-
[Primitive.
|
|
4241
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.
|
|
4242
|
-
|
|
4243
|
-
|
|
4244
|
-
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4245
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4246
|
-
return [reduce(ct, require_backend.AluOp.Add, axis)];
|
|
4247
|
-
},
|
|
4248
|
-
[Primitive.Reshape]([ct], [x], _) {
|
|
4249
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4250
|
-
return [reshape$1(ct, x.aval.shape)];
|
|
4251
|
-
},
|
|
4252
|
-
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4253
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4254
|
-
return [flip$1(ct, axis)];
|
|
4255
|
-
},
|
|
4256
|
-
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4257
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4258
|
-
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4259
|
-
return [pad$1(ct, width)];
|
|
4760
|
+
[Primitive.Concatenate]([ct], inputs, { axis }) {
|
|
4761
|
+
if (inputs.some((x) => !(x instanceof UndefPrimal))) throw new NonlinearError(Primitive.Concatenate);
|
|
4762
|
+
const sizes = inputs.map((x) => x.aval.shape[axis]);
|
|
4763
|
+
return split$2(ct, axis, sizes);
|
|
4260
4764
|
},
|
|
4261
|
-
[Primitive.
|
|
4262
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.
|
|
4263
|
-
|
|
4264
|
-
return [shrink(ct, slice)];
|
|
4765
|
+
[Primitive.Split](cts, [x], { axis }) {
|
|
4766
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Split);
|
|
4767
|
+
return [concatenate$1(cts, axis)];
|
|
4265
4768
|
},
|
|
4266
4769
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4267
4770
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4268
4771
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4269
|
-
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4270
|
-
},
|
|
4271
|
-
[Primitive.
|
|
4272
|
-
|
|
4273
|
-
|
|
4274
|
-
|
|
4275
|
-
|
|
4276
|
-
|
|
4277
|
-
|
|
4278
|
-
|
|
4279
|
-
|
|
4280
|
-
|
|
4281
|
-
|
|
4282
|
-
|
|
4283
|
-
|
|
4284
|
-
|
|
4285
|
-
return
|
|
4286
|
-
}
|
|
4287
|
-
}
|
|
4288
|
-
|
|
4289
|
-
|
|
4290
|
-
|
|
4291
|
-
|
|
4292
|
-
|
|
4293
|
-
|
|
4294
|
-
|
|
4295
|
-
|
|
4296
|
-
|
|
4297
|
-
|
|
4298
|
-
|
|
4299
|
-
|
|
4300
|
-
|
|
4301
|
-
|
|
4302
|
-
typecheckJaxpr(newJaxpr);
|
|
4303
|
-
const result = {
|
|
4304
|
-
newJaxpr,
|
|
4305
|
-
newConsts
|
|
4306
|
-
};
|
|
4307
|
-
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4308
|
-
transposeJaxprCache.get(jaxpr).set(cacheKey, result);
|
|
4309
|
-
return result;
|
|
4310
|
-
}
|
|
4311
|
-
function vjpFlat(f, primalsIn) {
|
|
4312
|
-
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
4313
|
-
const fVjp = (...cotangents) => {
|
|
4314
|
-
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4315
|
-
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
4316
|
-
};
|
|
4317
|
-
const dispose$1 = () => {
|
|
4318
|
-
for (const c of consts) c.dispose();
|
|
4319
|
-
};
|
|
4320
|
-
return [
|
|
4321
|
-
primalsOut,
|
|
4322
|
-
fVjp,
|
|
4323
|
-
dispose$1
|
|
4324
|
-
];
|
|
4325
|
-
}
|
|
4326
|
-
function vjp$1(f, ...primalsIn) {
|
|
4327
|
-
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4328
|
-
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4329
|
-
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4330
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4331
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4332
|
-
const fVjp = ((cotangentsOut) => {
|
|
4333
|
-
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4334
|
-
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4335
|
-
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4336
|
-
return unflatten(inTree, cotangentsInFlat);
|
|
4337
|
-
});
|
|
4338
|
-
fVjp.dispose = dispose$1;
|
|
4339
|
-
return [primalsOut, fVjp];
|
|
4340
|
-
}
|
|
4341
|
-
function grad$1(f) {
|
|
4342
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4343
|
-
return (...x) => {
|
|
4344
|
-
const [y, dx] = valueAndGradFn(...x);
|
|
4345
|
-
y.dispose();
|
|
4346
|
-
return dx;
|
|
4347
|
-
};
|
|
4348
|
-
}
|
|
4349
|
-
function valueAndGrad$1(f) {
|
|
4350
|
-
return (...x) => {
|
|
4351
|
-
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4352
|
-
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4353
|
-
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4354
|
-
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4355
|
-
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4356
|
-
for (const r of rest) dispose(r);
|
|
4357
|
-
fVjp.dispose();
|
|
4358
|
-
return [y, ct];
|
|
4359
|
-
};
|
|
4360
|
-
}
|
|
4361
|
-
function jacrev$1(f) {
|
|
4362
|
-
return function jacobianReverse(x) {
|
|
4363
|
-
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4364
|
-
const [size$1] = x.shape;
|
|
4365
|
-
const pullback = (ct) => {
|
|
4366
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4367
|
-
y.dispose();
|
|
4368
|
-
const [ret] = fVjp(ct);
|
|
4369
|
-
fVjp.dispose();
|
|
4370
|
-
return ret;
|
|
4371
|
-
};
|
|
4372
|
-
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4373
|
-
};
|
|
4374
|
-
}
|
|
4375
|
-
|
|
4376
|
-
//#endregion
|
|
4377
|
-
//#region src/library/lax.ts
|
|
4378
|
-
var lax_exports = {};
|
|
4379
|
-
__export(lax_exports, {
|
|
4380
|
-
conv: () => conv,
|
|
4381
|
-
convGeneralDilated: () => convGeneralDilated,
|
|
4382
|
-
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4383
|
-
dot: () => dot$1,
|
|
4384
|
-
erf: () => erf,
|
|
4385
|
-
erfc: () => erfc,
|
|
4386
|
-
reduceWindow: () => reduceWindow,
|
|
4387
|
-
stopGradient: () => stopGradient$1
|
|
4388
|
-
});
|
|
4389
|
-
/**
|
|
4390
|
-
* General dot product/contraction operator.
|
|
4391
|
-
*
|
|
4392
|
-
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
4393
|
-
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
4394
|
-
*/
|
|
4395
|
-
function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
4396
|
-
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
4397
|
-
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
4398
|
-
lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
4399
|
-
rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
4400
|
-
lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
4401
|
-
rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
4402
|
-
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
4403
|
-
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
4404
|
-
const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
4405
|
-
const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
4406
|
-
const lhs2 = lhs.transpose([
|
|
4407
|
-
...lb,
|
|
4408
|
-
...lf,
|
|
4409
|
-
...lc
|
|
4410
|
-
]);
|
|
4411
|
-
const rhs2 = rhs.transpose([
|
|
4412
|
-
...rb,
|
|
4413
|
-
...rf,
|
|
4414
|
-
...rc
|
|
4415
|
-
]);
|
|
4416
|
-
if (lc.length === 0) return mul(lhs2.reshape([
|
|
4417
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4418
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4419
|
-
...require_backend.rep(rf.length, 1)
|
|
4420
|
-
]), rhs2.reshape([
|
|
4421
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4422
|
-
...require_backend.rep(lf.length, 1),
|
|
4423
|
-
...rf.map((a) => rhs.shape[a])
|
|
4424
|
-
]));
|
|
4425
|
-
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
4426
|
-
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
4427
|
-
if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
4428
|
-
return dot$2(lhs2.reshape([
|
|
4429
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4430
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4431
|
-
...require_backend.rep(rf.length, 1),
|
|
4432
|
-
require_backend.prod(dotShapeX)
|
|
4433
|
-
]), rhs2.reshape([
|
|
4434
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4435
|
-
...require_backend.rep(lf.length, 1),
|
|
4436
|
-
...rf.map((a) => rhs.shape[a]),
|
|
4437
|
-
require_backend.prod(dotShapeY)
|
|
4438
|
-
]));
|
|
4439
|
-
}
|
|
4440
|
-
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4441
|
-
const padType = padding.toUpperCase();
|
|
4442
|
-
switch (padType) {
|
|
4443
|
-
case "VALID": return require_backend.rep(inShape.length, [0, 0]);
|
|
4444
|
-
case "SAME":
|
|
4445
|
-
case "SAME_LOWER": {
|
|
4446
|
-
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
4447
|
-
const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
4448
|
-
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
4449
|
-
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
4450
|
-
}
|
|
4451
|
-
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
4452
|
-
}
|
|
4453
|
-
}
|
|
4454
|
-
/**
|
|
4455
|
-
* General n-dimensional convolution operator, with optional dilation.
|
|
4456
|
-
*
|
|
4457
|
-
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
4458
|
-
* function in JAX, which wraps XLA's general convolution operator.
|
|
4459
|
-
*
|
|
4460
|
-
* Grouped convolutions are not supported right now.
|
|
4461
|
-
*/
|
|
4462
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
4463
|
-
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4464
|
-
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4465
|
-
if (typeof padding === "string") {
|
|
4466
|
-
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4467
|
-
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
4468
|
-
}
|
|
4469
|
-
if (featureGroupCount !== 1) {
|
|
4470
|
-
const G = featureGroupCount;
|
|
4471
|
-
const [N, C_in, ...xs] = lhs.shape;
|
|
4472
|
-
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
4473
|
-
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
4474
|
-
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
4475
|
-
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
4476
|
-
const lhsGrouped = moveaxis(lhs.reshape([
|
|
4477
|
-
N,
|
|
4478
|
-
G,
|
|
4479
|
-
C_in / G,
|
|
4480
|
-
...xs
|
|
4481
|
-
]), 1, 0);
|
|
4482
|
-
const rhsGrouped = rhs.reshape([
|
|
4483
|
-
G,
|
|
4484
|
-
C_out / G,
|
|
4485
|
-
C_in_per_group,
|
|
4486
|
-
...ks
|
|
4487
|
-
]);
|
|
4488
|
-
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
4489
|
-
vmapDims: 1,
|
|
4490
|
-
strides: windowStrides,
|
|
4491
|
-
padding,
|
|
4492
|
-
lhsDilation,
|
|
4493
|
-
rhsDilation
|
|
4772
|
+
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4773
|
+
},
|
|
4774
|
+
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4775
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4776
|
+
return [transpose$1(ct, require_backend.invertPermutation(perm))];
|
|
4777
|
+
},
|
|
4778
|
+
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4779
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4780
|
+
return [reduce(ct, require_backend.AluOp.Add, axis)];
|
|
4781
|
+
},
|
|
4782
|
+
[Primitive.Reshape]([ct], [x], _) {
|
|
4783
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4784
|
+
return [reshape$1(ct, x.aval.shape)];
|
|
4785
|
+
},
|
|
4786
|
+
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4787
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4788
|
+
return [flip$1(ct, axis)];
|
|
4789
|
+
},
|
|
4790
|
+
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4791
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4792
|
+
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4793
|
+
return [pad$1(ct, width)];
|
|
4794
|
+
},
|
|
4795
|
+
[Primitive.Pad]([ct], [x], { width }) {
|
|
4796
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
|
|
4797
|
+
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4798
|
+
return [shrink(ct, slice)];
|
|
4799
|
+
},
|
|
4800
|
+
[Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
|
|
4801
|
+
if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
|
|
4802
|
+
const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
|
|
4803
|
+
lower: true,
|
|
4804
|
+
unitDiagonal
|
|
4494
4805
|
});
|
|
4495
|
-
|
|
4496
|
-
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
]);
|
|
4806
|
+
return [null, ctB];
|
|
4807
|
+
},
|
|
4808
|
+
[Primitive.Jit](cts, args, { name, jaxpr }) {
|
|
4809
|
+
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
4810
|
+
const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
|
|
4811
|
+
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
4812
|
+
const outs = bind(Primitive.Jit, [
|
|
4813
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4814
|
+
...residuals,
|
|
4815
|
+
...cts
|
|
4816
|
+
], {
|
|
4817
|
+
name: `${name}_t`,
|
|
4818
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4819
|
+
numConsts: newJaxpr.consts.length
|
|
4820
|
+
});
|
|
4821
|
+
let i = 0;
|
|
4822
|
+
return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
|
|
4501
4823
|
}
|
|
4502
|
-
|
|
4503
|
-
|
|
4504
|
-
|
|
4505
|
-
|
|
4506
|
-
|
|
4507
|
-
|
|
4508
|
-
}
|
|
4509
|
-
|
|
4510
|
-
|
|
4511
|
-
|
|
4512
|
-
|
|
4513
|
-
|
|
4514
|
-
|
|
4824
|
+
};
|
|
4825
|
+
const transposeJaxprCache = /* @__PURE__ */ new Map();
|
|
4826
|
+
function transposeJaxpr(jaxpr, undefPrimals) {
|
|
4827
|
+
const cacheKey = JSON.stringify(undefPrimals);
|
|
4828
|
+
const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
|
|
4829
|
+
if (prevResult) return prevResult;
|
|
4830
|
+
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
4831
|
+
const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
|
|
4832
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
|
|
4833
|
+
const args = [];
|
|
4834
|
+
let forwardInIdx = 0;
|
|
4835
|
+
for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
|
|
4836
|
+
else args.push(forwardIn[forwardInIdx++]);
|
|
4837
|
+
return evalJaxprTransposed(jaxpr, args, cotangents);
|
|
4838
|
+
})(forwardInTypes, outTypes);
|
|
4839
|
+
typecheckJaxpr(newJaxpr.jaxpr);
|
|
4840
|
+
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4841
|
+
transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
4842
|
+
return newJaxpr;
|
|
4515
4843
|
}
|
|
4516
|
-
|
|
4517
|
-
|
|
4518
|
-
|
|
4844
|
+
function vjpFlat(f, primalsIn) {
|
|
4845
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4846
|
+
const fVjp = (...cotangents) => {
|
|
4847
|
+
const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4848
|
+
return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
|
|
4849
|
+
};
|
|
4850
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
4851
|
+
return [
|
|
4852
|
+
primalsOut,
|
|
4853
|
+
fVjp,
|
|
4854
|
+
dispose$1
|
|
4855
|
+
];
|
|
4519
4856
|
}
|
|
4520
|
-
|
|
4521
|
-
|
|
4522
|
-
|
|
4523
|
-
|
|
4524
|
-
|
|
4525
|
-
|
|
4526
|
-
|
|
4527
|
-
|
|
4528
|
-
|
|
4857
|
+
function vjp$1(f, ...primalsIn) {
|
|
4858
|
+
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4859
|
+
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4860
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4861
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4862
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4863
|
+
const fVjp = ((cotangentsOut) => {
|
|
4864
|
+
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4865
|
+
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4866
|
+
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4867
|
+
return unflatten(inTree, cotangentsInFlat);
|
|
4868
|
+
});
|
|
4869
|
+
fVjp.dispose = dispose$1;
|
|
4870
|
+
return [primalsOut, fVjp];
|
|
4529
4871
|
}
|
|
4530
|
-
|
|
4531
|
-
|
|
4532
|
-
return
|
|
4872
|
+
function grad$1(f) {
|
|
4873
|
+
const valueAndGradFn = valueAndGrad$1(f);
|
|
4874
|
+
return (...x) => {
|
|
4875
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4876
|
+
y.dispose();
|
|
4877
|
+
return dx;
|
|
4878
|
+
};
|
|
4533
4879
|
}
|
|
4534
|
-
|
|
4535
|
-
|
|
4536
|
-
|
|
4537
|
-
|
|
4538
|
-
|
|
4539
|
-
|
|
4540
|
-
|
|
4541
|
-
|
|
4880
|
+
function valueAndGrad$1(f) {
|
|
4881
|
+
return (...x) => {
|
|
4882
|
+
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4883
|
+
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4884
|
+
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4885
|
+
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4886
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4887
|
+
for (const r of rest) dispose(r);
|
|
4888
|
+
fVjp.dispose();
|
|
4889
|
+
return [y, ct];
|
|
4890
|
+
};
|
|
4542
4891
|
}
|
|
4543
|
-
|
|
4544
|
-
|
|
4545
|
-
|
|
4546
|
-
|
|
4547
|
-
|
|
4548
|
-
|
|
4549
|
-
|
|
4550
|
-
|
|
4892
|
+
function jacrev$1(f) {
|
|
4893
|
+
return function jacobianReverse(x) {
|
|
4894
|
+
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4895
|
+
const [size$1] = x.shape;
|
|
4896
|
+
const pullback = (ct) => {
|
|
4897
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
4898
|
+
y.dispose();
|
|
4899
|
+
const [ret] = fVjp(ct);
|
|
4900
|
+
fVjp.dispose();
|
|
4901
|
+
return ret;
|
|
4902
|
+
};
|
|
4903
|
+
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4904
|
+
};
|
|
4551
4905
|
}
|
|
4552
4906
|
|
|
4553
4907
|
//#endregion
|
|
@@ -4687,8 +5041,8 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4687
5041
|
const idx = lhsIndex[j];
|
|
4688
5042
|
const dim = shape$1[j];
|
|
4689
5043
|
const existing = sizeMap.get(idx);
|
|
4690
|
-
if (existing === void 0) sizeMap.set(idx, dim);
|
|
4691
|
-
else if (existing !== dim) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
5044
|
+
if (existing === void 0 || existing === 1) sizeMap.set(idx, dim);
|
|
5045
|
+
else if (existing !== dim && dim !== 1) throw new Error(`Inconsistent size for index ${idx} in einsum: ${existing} vs ${dim}`);
|
|
4692
5046
|
}
|
|
4693
5047
|
}
|
|
4694
5048
|
for (const [idx, size$1] of sizeMap) if (!Number.isInteger(idx) || idx < 0) throw new Error(`Invalid index ${idx} in einsum expression, must be non-negative integer`);
|
|
@@ -4696,52 +5050,410 @@ function computeSizeMap({ shapes, lhsIndices, rhsIndex }) {
|
|
|
4696
5050
|
for (const idx of rhsIndex) if (!sizeMap.has(idx)) throw new Error(`Output index ${idx} not present in einsum inputs`);
|
|
4697
5051
|
return sizeMap;
|
|
4698
5052
|
}
|
|
4699
|
-
const einsumPathCache = /* @__PURE__ */ new Map();
|
|
4700
|
-
function computeEinsumPath(input, method) {
|
|
4701
|
-
if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
|
|
4702
|
-
return require_backend.runWithCache(einsumPathCache, [input, method], () => {
|
|
4703
|
-
const sizeMap = computeSizeMap(input);
|
|
4704
|
-
if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
|
|
4705
|
-
switch (method) {
|
|
4706
|
-
case "naive": return computePathNaive(input, sizeMap);
|
|
4707
|
-
case "optimal": return computePathOptimal(input, sizeMap);
|
|
4708
|
-
default: throw new Error(`Unknown computePath method: ${method}`);
|
|
4709
|
-
}
|
|
4710
|
-
});
|
|
5053
|
+
const einsumPathCache = /* @__PURE__ */ new Map();
|
|
5054
|
+
function computeEinsumPath(input, method) {
|
|
5055
|
+
if (!method) method = input.shapes.length <= 5 ? "optimal" : "naive";
|
|
5056
|
+
return require_backend.runWithCache(einsumPathCache, [input, method], () => {
|
|
5057
|
+
const sizeMap = computeSizeMap(input);
|
|
5058
|
+
if (input.shapes.length === 1) return new EinsumPath(input, sizeMap, []);
|
|
5059
|
+
switch (method) {
|
|
5060
|
+
case "naive": return computePathNaive(input, sizeMap);
|
|
5061
|
+
case "optimal": return computePathOptimal(input, sizeMap);
|
|
5062
|
+
default: throw new Error(`Unknown computePath method: ${method}`);
|
|
5063
|
+
}
|
|
5064
|
+
});
|
|
5065
|
+
}
|
|
5066
|
+
function computePathNaive(input, sizeMap) {
|
|
5067
|
+
const n = input.shapes.length;
|
|
5068
|
+
const path = [];
|
|
5069
|
+
let lastTensorIndex = 0;
|
|
5070
|
+
for (let i = 1; i < n; i++) {
|
|
5071
|
+
path.push([lastTensorIndex, i]);
|
|
5072
|
+
lastTensorIndex = n + i - 1;
|
|
5073
|
+
}
|
|
5074
|
+
return new EinsumPath(input, sizeMap, path);
|
|
5075
|
+
}
|
|
5076
|
+
function computePathOptimal(input, sizeMap) {
|
|
5077
|
+
const n = input.shapes.length;
|
|
5078
|
+
let bestPath = null;
|
|
5079
|
+
let bestFlops = null;
|
|
5080
|
+
for (const path of allPaths(require_backend.range(n), n)) {
|
|
5081
|
+
const flops = approximatePathFlops(input, sizeMap, path);
|
|
5082
|
+
if (bestFlops === null || flops < bestFlops) {
|
|
5083
|
+
bestPath = path;
|
|
5084
|
+
bestFlops = flops;
|
|
5085
|
+
}
|
|
5086
|
+
}
|
|
5087
|
+
return new EinsumPath(input, sizeMap, bestPath);
|
|
5088
|
+
}
|
|
5089
|
+
function* allPaths(tensors, next) {
|
|
5090
|
+
if (tensors.length === 2) {
|
|
5091
|
+
yield [[tensors[0], tensors[1]]];
|
|
5092
|
+
return;
|
|
5093
|
+
}
|
|
5094
|
+
for (let i = 0; i < tensors.length; i++) for (let j = i + 1; j < tensors.length; j++) {
|
|
5095
|
+
const pair = [tensors[i], tensors[j]];
|
|
5096
|
+
const newTensors = tensors.filter((t) => t !== pair[0] && t !== pair[1]);
|
|
5097
|
+
newTensors.push(next);
|
|
5098
|
+
for (const subpath of allPaths(newTensors, next + 1)) yield [pair, ...subpath];
|
|
5099
|
+
}
|
|
5100
|
+
}
|
|
5101
|
+
|
|
5102
|
+
//#endregion
|
|
5103
|
+
//#region src/library/numpy-fft.ts
|
|
5104
|
+
var numpy_fft_exports = {};
|
|
5105
|
+
__export(numpy_fft_exports, {
|
|
5106
|
+
fft: () => fft,
|
|
5107
|
+
ifft: () => ifft
|
|
5108
|
+
});
|
|
5109
|
+
function checkPairInput(name, a) {
|
|
5110
|
+
const fullName = `jax.numpy.fft.${name}`;
|
|
5111
|
+
if (!require_backend.deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
|
|
5112
|
+
if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
|
|
5113
|
+
if (!require_backend.isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
|
|
5114
|
+
}
|
|
5115
|
+
function checkPowerOfTwo(name, n) {
|
|
5116
|
+
if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
|
|
5117
|
+
}
|
|
5118
|
+
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
|
|
5119
|
+
const half = 2 ** i;
|
|
5120
|
+
real = real.reshape([-1, 2 * half]);
|
|
5121
|
+
imag = imag.reshape([-1, 2 * half]);
|
|
5122
|
+
const k = arange(0, half, 1, { dtype: real.dtype });
|
|
5123
|
+
const theta = k.mul(-Math.PI / half);
|
|
5124
|
+
const wr = cos(theta.ref);
|
|
5125
|
+
const wi = sin(theta);
|
|
5126
|
+
const ur = real.ref.slice([], [0, half]);
|
|
5127
|
+
const ui = imag.ref.slice([], [0, half]);
|
|
5128
|
+
const vr = real.slice([], [half, 2 * half]);
|
|
5129
|
+
const vi = imag.slice([], [half, 2 * half]);
|
|
5130
|
+
const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
|
|
5131
|
+
const ti = vr.mul(wi).add(vi.mul(wr));
|
|
5132
|
+
return {
|
|
5133
|
+
real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
|
|
5134
|
+
imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
|
|
5135
|
+
};
|
|
5136
|
+
}, { staticArgnums: [0] });
|
|
5137
|
+
/**
|
|
5138
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
5139
|
+
*
|
|
5140
|
+
* Currently, the size of the axis must be a power of two.
|
|
5141
|
+
*/
|
|
5142
|
+
function fft(a, axis = -1) {
|
|
5143
|
+
checkPairInput("fft", a);
|
|
5144
|
+
let { real, imag } = a;
|
|
5145
|
+
axis = require_backend.checkAxis(axis, real.ndim);
|
|
5146
|
+
const n = real.shape[axis];
|
|
5147
|
+
checkPowerOfTwo("fft", n);
|
|
5148
|
+
const logN = Math.log2(n);
|
|
5149
|
+
let perm = null;
|
|
5150
|
+
if (axis !== real.ndim - 1) {
|
|
5151
|
+
perm = require_backend.range(real.ndim);
|
|
5152
|
+
perm.splice(axis, 1);
|
|
5153
|
+
perm.push(axis);
|
|
5154
|
+
real = real.transpose(perm);
|
|
5155
|
+
imag = imag.transpose(perm);
|
|
5156
|
+
}
|
|
5157
|
+
const originalShape = real.shape;
|
|
5158
|
+
real = real.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
|
|
5159
|
+
imag = imag.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
|
|
5160
|
+
for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
|
|
5161
|
+
real,
|
|
5162
|
+
imag
|
|
5163
|
+
}));
|
|
5164
|
+
real = real.reshape(originalShape);
|
|
5165
|
+
imag = imag.reshape(originalShape);
|
|
5166
|
+
if (perm !== null) {
|
|
5167
|
+
real = real.transpose(require_backend.invertPermutation(perm));
|
|
5168
|
+
imag = imag.transpose(require_backend.invertPermutation(perm));
|
|
5169
|
+
}
|
|
5170
|
+
return {
|
|
5171
|
+
real,
|
|
5172
|
+
imag
|
|
5173
|
+
};
|
|
5174
|
+
}
|
|
5175
|
+
/**
|
|
5176
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
5177
|
+
*
|
|
5178
|
+
* Currently, the size of the axis must be a power of two.
|
|
5179
|
+
*/
|
|
5180
|
+
function ifft(a, axis = -1) {
|
|
5181
|
+
checkPairInput("ifft", a);
|
|
5182
|
+
let { real, imag } = a;
|
|
5183
|
+
axis = require_backend.checkAxis(axis, real.ndim);
|
|
5184
|
+
const n = real.shape[axis];
|
|
5185
|
+
checkPowerOfTwo("ifft", n);
|
|
5186
|
+
imag = imag.mul(-1);
|
|
5187
|
+
const result = fft({
|
|
5188
|
+
real,
|
|
5189
|
+
imag
|
|
5190
|
+
}, axis);
|
|
5191
|
+
return {
|
|
5192
|
+
real: result.real.div(n),
|
|
5193
|
+
imag: result.imag.mul(-1).div(n)
|
|
5194
|
+
};
|
|
5195
|
+
}
|
|
5196
|
+
|
|
5197
|
+
//#endregion
|
|
5198
|
+
//#region src/library/numpy-linalg.ts
|
|
5199
|
+
var numpy_linalg_exports = {};
|
|
5200
|
+
__export(numpy_linalg_exports, {
|
|
5201
|
+
cholesky: () => cholesky,
|
|
5202
|
+
det: () => det,
|
|
5203
|
+
diagonal: () => diagonal,
|
|
5204
|
+
inv: () => inv,
|
|
5205
|
+
lstsq: () => lstsq,
|
|
5206
|
+
matmul: () => matmul,
|
|
5207
|
+
matrixPower: () => matrixPower,
|
|
5208
|
+
matrixTranspose: () => matrixTranspose,
|
|
5209
|
+
outer: () => outer,
|
|
5210
|
+
slogdet: () => slogdet,
|
|
5211
|
+
solve: () => solve,
|
|
5212
|
+
tensordot: () => tensordot,
|
|
5213
|
+
trace: () => trace,
|
|
5214
|
+
vecdot: () => vecdot
|
|
5215
|
+
});
|
|
5216
|
+
function checkSquare(name, a) {
|
|
5217
|
+
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`${name}: input must be at least 2D square matrix, got ${a.aval}`);
|
|
5218
|
+
return a.shape[a.ndim - 1];
|
|
5219
|
+
}
|
|
5220
|
+
/**
|
|
5221
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
5222
|
+
*
|
|
5223
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
5224
|
+
* the input matrix, which is on by default.
|
|
5225
|
+
*/
|
|
5226
|
+
function cholesky(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
5227
|
+
a = fudgeArray(a);
|
|
5228
|
+
checkSquare("cholesky", a);
|
|
5229
|
+
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
5230
|
+
return cholesky$1(a, { upper });
|
|
5231
|
+
}
|
|
5232
|
+
/** Compute the determinant of a square matrix (batched). */
|
|
5233
|
+
function det(a) {
|
|
5234
|
+
a = fudgeArray(a);
|
|
5235
|
+
const n = checkSquare("det", a);
|
|
5236
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5237
|
+
permutation.dispose();
|
|
5238
|
+
const parity = pivots.notEqual(arange(n)).astype(int32).sum(-1).mod(2);
|
|
5239
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5240
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5241
|
+
return prod$1(diag$1, -1).mul(sign$1);
|
|
5242
|
+
}
|
|
5243
|
+
/** Compute the inverse of a square matrix (batched). */
|
|
5244
|
+
function inv(a) {
|
|
5245
|
+
a = fudgeArray(a);
|
|
5246
|
+
const n = checkSquare("inv", a);
|
|
5247
|
+
return solve(a, eye(n));
|
|
5248
|
+
}
|
|
5249
|
+
/**
|
|
5250
|
+
* Return the least-squares solution to a linear equation.
|
|
5251
|
+
*
|
|
5252
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
5253
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
5254
|
+
*
|
|
5255
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
5256
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
5257
|
+
*
|
|
5258
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
5259
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
5260
|
+
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
5261
|
+
*/
|
|
5262
|
+
function lstsq(a, b) {
|
|
5263
|
+
a = fudgeArray(a);
|
|
5264
|
+
b = fudgeArray(b);
|
|
5265
|
+
if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
|
|
5266
|
+
const [m, n] = a.shape;
|
|
5267
|
+
if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
|
|
5268
|
+
const at = matrixTranspose(a.ref);
|
|
5269
|
+
if (m <= n) {
|
|
5270
|
+
const aat = matmul(a, at.ref);
|
|
5271
|
+
const l = cholesky(aat, { symmetrizeInput: false });
|
|
5272
|
+
const lb = triangularSolve(l.ref, b, {
|
|
5273
|
+
leftSide: true,
|
|
5274
|
+
lower: true
|
|
5275
|
+
});
|
|
5276
|
+
const llb = triangularSolve(l, lb, {
|
|
5277
|
+
leftSide: true,
|
|
5278
|
+
transposeA: true
|
|
5279
|
+
});
|
|
5280
|
+
return matmul(at, llb.ref);
|
|
5281
|
+
} else {
|
|
5282
|
+
const ata = matmul(at.ref, a);
|
|
5283
|
+
const l = cholesky(ata, { symmetrizeInput: false });
|
|
5284
|
+
const atb = matmul(at, b);
|
|
5285
|
+
const lb = triangularSolve(l.ref, atb, {
|
|
5286
|
+
leftSide: true,
|
|
5287
|
+
lower: true
|
|
5288
|
+
});
|
|
5289
|
+
const llb = triangularSolve(l, lb, {
|
|
5290
|
+
leftSide: true,
|
|
5291
|
+
transposeA: true
|
|
5292
|
+
});
|
|
5293
|
+
return llb;
|
|
5294
|
+
}
|
|
5295
|
+
}
|
|
5296
|
+
/** Raise a square matrix to an integer power, via repeated squarings. */
|
|
5297
|
+
function matrixPower(a, n) {
|
|
5298
|
+
if (!Number.isInteger(n)) throw new Error(`matrixPower: exponent must be an integer, got ${n}`);
|
|
5299
|
+
a = fudgeArray(a);
|
|
5300
|
+
const m = checkSquare("matrixPower", a);
|
|
5301
|
+
if (n === 0) {
|
|
5302
|
+
a.dispose();
|
|
5303
|
+
return broadcastTo(eye(m), a.shape);
|
|
5304
|
+
}
|
|
5305
|
+
if (n < 0) {
|
|
5306
|
+
a = inv(a);
|
|
5307
|
+
n = -n;
|
|
5308
|
+
}
|
|
5309
|
+
let result = null;
|
|
5310
|
+
let a2k = a;
|
|
5311
|
+
for (let k = 0; n; k++) {
|
|
5312
|
+
if (k > 0) a2k = matmul(a2k.ref, a2k);
|
|
5313
|
+
if (n % 2 === 1) result = result === null ? a2k.ref : matmul(result, a2k.ref);
|
|
5314
|
+
n = Math.floor(n / 2);
|
|
5315
|
+
}
|
|
5316
|
+
a2k.dispose();
|
|
5317
|
+
return result;
|
|
5318
|
+
}
|
|
5319
|
+
/** Return sign and natural logarithm of the determinant of `a`. */
|
|
5320
|
+
function slogdet(a) {
|
|
5321
|
+
a = fudgeArray(a);
|
|
5322
|
+
const n = checkSquare("slogdet", a);
|
|
5323
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5324
|
+
permutation.dispose();
|
|
5325
|
+
let parity = pivots.notEqual(arange(n)).astype(int32).sum(-1);
|
|
5326
|
+
const diag$1 = lu$2.diagonal(0, -1, -2);
|
|
5327
|
+
parity = parity.add(diag$1.ref.less(0).astype(int32).sum(-1)).mod(2);
|
|
5328
|
+
const logabsdet = log(absolute(diag$1)).sum(-1);
|
|
5329
|
+
const sign$1 = parity.mul(-2).add(1);
|
|
5330
|
+
return [sign$1, logabsdet];
|
|
4711
5331
|
}
|
|
4712
|
-
|
|
4713
|
-
|
|
4714
|
-
|
|
4715
|
-
|
|
4716
|
-
|
|
4717
|
-
|
|
4718
|
-
|
|
4719
|
-
|
|
4720
|
-
|
|
5332
|
+
/**
|
|
5333
|
+
* Solve a linear system of equations.
|
|
5334
|
+
*
|
|
5335
|
+
* This solves a (batched) linear system of equations `a @ x = b` for `x` given
|
|
5336
|
+
* `a` and `b`. If `a` is singular, this will return `nan` or `inf` values.
|
|
5337
|
+
*
|
|
5338
|
+
* @param a - Coefficient matrix of shape `(..., N, N)`.
|
|
5339
|
+
* @param b - Values of shape `(N,)` or `(..., N, M)`.
|
|
5340
|
+
* @returns Solution `x` of shape `(..., N)` or `(..., N, M)`.
|
|
5341
|
+
*/
|
|
5342
|
+
function solve(a, b) {
|
|
5343
|
+
a = fudgeArray(a);
|
|
5344
|
+
b = fudgeArray(b);
|
|
5345
|
+
const n = checkSquare("solve", a);
|
|
5346
|
+
if (b.ndim === 0) throw new Error(`solve: b cannot be scalar`);
|
|
5347
|
+
const bIs1d = b.ndim === 1;
|
|
5348
|
+
if (bIs1d) b = b.reshape([...b.shape, 1]);
|
|
5349
|
+
if (b.shape[b.ndim - 2] !== n) throw new Error(`solve: leading dimension of b must match size of a, got a=${a.aval}, b=${b.aval}`);
|
|
5350
|
+
const m = b.shape[b.ndim - 1];
|
|
5351
|
+
const batchDims = require_backend.generalBroadcast(a.shape.slice(0, -2), b.shape.slice(0, -2));
|
|
5352
|
+
a = broadcastTo(a, [
|
|
5353
|
+
...batchDims,
|
|
5354
|
+
n,
|
|
5355
|
+
n
|
|
5356
|
+
]);
|
|
5357
|
+
b = broadcastTo(b, [
|
|
5358
|
+
...batchDims,
|
|
5359
|
+
n,
|
|
5360
|
+
m
|
|
5361
|
+
]);
|
|
5362
|
+
const [lu$2, pivots, permutation] = lu(a);
|
|
5363
|
+
pivots.dispose();
|
|
5364
|
+
const P = arange(n).equal(permutation.reshape([...permutation.shape, 1])).astype(b.dtype);
|
|
5365
|
+
const LPb = triangularSolve(lu$2.ref, matmul(P, b), {
|
|
5366
|
+
leftSide: true,
|
|
5367
|
+
lower: true,
|
|
5368
|
+
unitDiagonal: true
|
|
5369
|
+
});
|
|
5370
|
+
let x = triangularSolve(lu$2, LPb.ref, {
|
|
5371
|
+
leftSide: true,
|
|
5372
|
+
lower: false
|
|
5373
|
+
});
|
|
5374
|
+
if (bIs1d) x = squeeze(x, -1);
|
|
5375
|
+
return x;
|
|
4721
5376
|
}
|
|
4722
|
-
|
|
4723
|
-
|
|
4724
|
-
|
|
4725
|
-
|
|
4726
|
-
|
|
4727
|
-
|
|
4728
|
-
|
|
4729
|
-
|
|
4730
|
-
|
|
4731
|
-
|
|
5377
|
+
|
|
5378
|
+
//#endregion
|
|
5379
|
+
//#region src/library/numpy/dtype-info.ts
|
|
5380
|
+
/**
 * Machine limits for floating-point types.
 *
 * @param dtype - One of the floating-point members of `DType`.
 * @returns A frozen record of limits for `dtype`. For every supported type
 *   `eps === 2 ** machep` and `epsneg === 2 ** negep`.
 * @throws Error if `dtype` is not a floating-point type, or is a
 *   floating-point type without an entry here.
 */
function finfo(dtype) {
  if (!require_backend.isFloatDtype(dtype)) throw new Error(`finfo: received ${dtype}, must be a floating-point type`);
  switch (dtype) {
    case require_backend.DType.Float16: return Object.freeze({
      bits: 16,
      dtype: require_backend.DType.Float16,
      eps: 2 ** -10,
      epsneg: 2 ** -11,
      machep: -10,
      max: 65504,
      maxexp: 16,
      min: -65504,
      minexp: -14,
      // Exponent of `epsneg` (2 ** -11), matching how the other dtypes pair
      // `negep` with `epsneg`.
      negep: -11,
      nexp: 5,
      nmant: 10,
      precision: 3,
      resolution: .001,
      smallestNormal: 2 ** -14,
      smallestSubnormal: 2 ** -24
    });
    case require_backend.DType.Float32: return Object.freeze({
      bits: 32,
      dtype: require_backend.DType.Float32,
      eps: 2 ** -23,
      epsneg: 2 ** -24,
      machep: -23,
      max: 34028234663852886e22,
      maxexp: 128,
      min: -34028234663852886e22,
      minexp: -126,
      negep: -24,
      nexp: 8,
      nmant: 23,
      precision: 6,
      resolution: 1e-6,
      smallestNormal: 2 ** -126,
      smallestSubnormal: 2 ** -149
    });
    case require_backend.DType.Float64: return Object.freeze({
      bits: 64,
      dtype: require_backend.DType.Float64,
      eps: 2 ** -52,
      epsneg: 2 ** -53,
      machep: -52,
      max: Number.MAX_VALUE,
      maxexp: 1024,
      min: -Number.MAX_VALUE,
      minexp: -1022,
      negep: -53,
      nexp: 11,
      nmant: 52,
      precision: 15,
      resolution: 1e-15,
      smallestNormal: 2 ** -1022,
      smallestSubnormal: 2 ** -1074
    });
    default: throw new Error(`finfo: unsupported dtype ${dtype}`);
  }
}
|
|
4735
|
-
|
|
4736
|
-
|
|
4737
|
-
|
|
4738
|
-
return
|
|
4739
|
-
|
|
4740
|
-
|
|
4741
|
-
|
|
4742
|
-
|
|
4743
|
-
|
|
4744
|
-
|
|
5441
|
+
/**
 * Machine limits for integer types.
 *
 * Returns a frozen `{ bits, dtype, max, min }` record for `Int32` or `Uint32`
 * and throws for any other dtype.
 */
function iinfo(dtype) {
  const { Int32, Uint32 } = require_backend.DType;
  if (dtype === Int32) {
    return Object.freeze({
      bits: 32,
      dtype: Int32,
      max: 2147483647,
      min: -2147483648
    });
  }
  if (dtype === Uint32) {
    return Object.freeze({
      bits: 32,
      dtype: Uint32,
      max: 4294967295,
      min: 0
    });
  }
  throw new Error(`iinfo: unsupported dtype ${dtype}`);
}
|
|
4747
5459
|
|
|
@@ -4751,28 +5463,32 @@ var numpy_exports = {};
|
|
|
4751
5463
|
__export(numpy_exports, {
|
|
4752
5464
|
Array: () => Array$1,
|
|
4753
5465
|
DType: () => require_backend.DType,
|
|
4754
|
-
abs: () =>
|
|
5466
|
+
abs: () => absolute,
|
|
4755
5467
|
absolute: () => absolute,
|
|
4756
5468
|
acos: () => acos,
|
|
4757
|
-
acosh: () =>
|
|
5469
|
+
acosh: () => arccosh,
|
|
4758
5470
|
add: () => add,
|
|
5471
|
+
all: () => all,
|
|
4759
5472
|
allclose: () => allclose,
|
|
5473
|
+
any: () => any,
|
|
4760
5474
|
arange: () => arange,
|
|
4761
|
-
arccos: () =>
|
|
5475
|
+
arccos: () => acos,
|
|
4762
5476
|
arccosh: () => arccosh,
|
|
5477
|
+
arcsin: () => asin,
|
|
4763
5478
|
arcsinh: () => arcsinh,
|
|
4764
|
-
arctan: () =>
|
|
4765
|
-
arctan2: () =>
|
|
5479
|
+
arctan: () => atan,
|
|
5480
|
+
arctan2: () => atan2,
|
|
4766
5481
|
arctanh: () => arctanh,
|
|
4767
5482
|
argmax: () => argmax,
|
|
4768
5483
|
argmin: () => argmin,
|
|
5484
|
+
argsort: () => argsort,
|
|
4769
5485
|
array: () => array,
|
|
4770
5486
|
asin: () => asin,
|
|
4771
|
-
asinh: () =>
|
|
5487
|
+
asinh: () => arcsinh,
|
|
4772
5488
|
astype: () => astype,
|
|
4773
5489
|
atan: () => atan,
|
|
4774
5490
|
atan2: () => atan2,
|
|
4775
|
-
atanh: () =>
|
|
5491
|
+
atanh: () => arctanh,
|
|
4776
5492
|
bool: () => bool,
|
|
4777
5493
|
broadcastArrays: () => broadcastArrays,
|
|
4778
5494
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -4782,16 +5498,21 @@ __export(numpy_exports, {
|
|
|
4782
5498
|
clip: () => clip,
|
|
4783
5499
|
columnStack: () => columnStack,
|
|
4784
5500
|
concatenate: () => concatenate,
|
|
5501
|
+
convolve: () => convolve,
|
|
5502
|
+
corrcoef: () => corrcoef,
|
|
5503
|
+
correlate: () => correlate,
|
|
4785
5504
|
cos: () => cos,
|
|
4786
5505
|
cosh: () => cosh,
|
|
5506
|
+
cov: () => cov,
|
|
4787
5507
|
cumsum: () => cumsum,
|
|
4788
|
-
cumulativeSum: () =>
|
|
5508
|
+
cumulativeSum: () => cumsum,
|
|
4789
5509
|
deg2rad: () => deg2rad,
|
|
4790
5510
|
degrees: () => degrees,
|
|
4791
5511
|
diag: () => diag,
|
|
4792
5512
|
diagonal: () => diagonal,
|
|
4793
|
-
divide: () =>
|
|
4794
|
-
|
|
5513
|
+
divide: () => trueDivide,
|
|
5514
|
+
divmod: () => divmod,
|
|
5515
|
+
dot: () => dot$1,
|
|
4795
5516
|
dstack: () => dstack,
|
|
4796
5517
|
e: () => e,
|
|
4797
5518
|
einsum: () => einsum,
|
|
@@ -4799,8 +5520,11 @@ __export(numpy_exports, {
|
|
|
4799
5520
|
eulerGamma: () => eulerGamma,
|
|
4800
5521
|
exp: () => exp,
|
|
4801
5522
|
exp2: () => exp2,
|
|
5523
|
+
expandDims: () => expandDims,
|
|
4802
5524
|
expm1: () => expm1,
|
|
4803
5525
|
eye: () => eye,
|
|
5526
|
+
fft: () => numpy_fft_exports,
|
|
5527
|
+
finfo: () => finfo,
|
|
4804
5528
|
flip: () => flip,
|
|
4805
5529
|
fliplr: () => fliplr,
|
|
4806
5530
|
flipud: () => flipud,
|
|
@@ -4808,6 +5532,7 @@ __export(numpy_exports, {
|
|
|
4808
5532
|
float32: () => float32,
|
|
4809
5533
|
float64: () => float64,
|
|
4810
5534
|
floor: () => floor,
|
|
5535
|
+
floorDivide: () => floorDivide,
|
|
4811
5536
|
fmod: () => fmod,
|
|
4812
5537
|
frexp: () => frexp,
|
|
4813
5538
|
full: () => full,
|
|
@@ -4820,6 +5545,7 @@ __export(numpy_exports, {
|
|
|
4820
5545
|
hstack: () => hstack,
|
|
4821
5546
|
hypot: () => hypot,
|
|
4822
5547
|
identity: () => identity$1,
|
|
5548
|
+
iinfo: () => iinfo,
|
|
4823
5549
|
inf: () => inf,
|
|
4824
5550
|
inner: () => inner,
|
|
4825
5551
|
int32: () => int32,
|
|
@@ -4831,12 +5557,15 @@ __export(numpy_exports, {
|
|
|
4831
5557
|
ldexp: () => ldexp,
|
|
4832
5558
|
less: () => less,
|
|
4833
5559
|
lessEqual: () => lessEqual,
|
|
5560
|
+
linalg: () => numpy_linalg_exports,
|
|
4834
5561
|
linspace: () => linspace,
|
|
4835
5562
|
log: () => log,
|
|
4836
5563
|
log10: () => log10,
|
|
4837
5564
|
log1p: () => log1p,
|
|
4838
5565
|
log2: () => log2,
|
|
5566
|
+
logspace: () => logspace,
|
|
4839
5567
|
matmul: () => matmul,
|
|
5568
|
+
matrixTranspose: () => matrixTranspose,
|
|
4840
5569
|
max: () => max,
|
|
4841
5570
|
maximum: () => maximum,
|
|
4842
5571
|
mean: () => mean,
|
|
@@ -4853,10 +5582,10 @@ __export(numpy_exports, {
|
|
|
4853
5582
|
onesLike: () => onesLike,
|
|
4854
5583
|
outer: () => outer,
|
|
4855
5584
|
pad: () => pad,
|
|
4856
|
-
permuteDims: () =>
|
|
5585
|
+
permuteDims: () => transpose,
|
|
4857
5586
|
pi: () => pi,
|
|
4858
5587
|
positive: () => positive,
|
|
4859
|
-
pow: () =>
|
|
5588
|
+
pow: () => power,
|
|
4860
5589
|
power: () => power,
|
|
4861
5590
|
prod: () => prod$1,
|
|
4862
5591
|
promoteTypes: () => require_backend.promoteTypes,
|
|
@@ -4871,8 +5600,11 @@ __export(numpy_exports, {
|
|
|
4871
5600
|
shape: () => shape,
|
|
4872
5601
|
sign: () => sign,
|
|
4873
5602
|
sin: () => sin,
|
|
5603
|
+
sinc: () => sinc,
|
|
4874
5604
|
sinh: () => sinh,
|
|
4875
5605
|
size: () => size,
|
|
5606
|
+
sort: () => sort,
|
|
5607
|
+
split: () => split$1,
|
|
4876
5608
|
sqrt: () => sqrt,
|
|
4877
5609
|
square: () => square,
|
|
4878
5610
|
squeeze: () => squeeze,
|
|
@@ -4880,6 +5612,7 @@ __export(numpy_exports, {
|
|
|
4880
5612
|
std: () => std,
|
|
4881
5613
|
subtract: () => subtract,
|
|
4882
5614
|
sum: () => sum,
|
|
5615
|
+
take: () => take,
|
|
4883
5616
|
tan: () => tan,
|
|
4884
5617
|
tanh: () => tanh,
|
|
4885
5618
|
tensordot: () => tensordot,
|
|
@@ -5037,6 +5770,26 @@ function min(a, axis = null, opts) {
|
|
|
5037
5770
|
function max(a, axis = null, opts) {
  // Delegates to the shared `reduce` helper with the backend's Max operation.
  const { Max } = require_backend.AluOp;
  return reduce(a, Max, axis, opts);
}
|
|
5773
|
+
/**
 * Test whether every element along the given axis evaluates to true.
 *
 * The input is cast to boolean and then reduced with `min`. The result is a
 * boolean array shaped like `a` with the reduced axis removed; with
 * `axis = null` (the default) the result is a scalar.
 */
function all(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(require_backend.DType.Bool);
  return min(asBool, axis, opts);
}
|
|
5783
|
+
/**
 * Test whether at least one element along the given axis evaluates to true.
 *
 * The input is cast to boolean and then reduced with `max`. The result is a
 * boolean array shaped like `a` with the reduced axis removed; with
 * `axis = null` (the default) the result is a scalar.
 */
function any(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(require_backend.DType.Bool);
  return max(asBool, axis, opts);
}
|
|
5040
5793
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5041
5794
|
function ptp(a, axis = null, opts) {
|
|
5042
5795
|
a = fudgeArray(a);
|
|
@@ -5111,8 +5864,6 @@ function cumsum(a, axis) {
|
|
|
5111
5864
|
a = broadcast(a, a.shape.concat(n), [-2]);
|
|
5112
5865
|
return moveaxis$1(tril(a).sum(-1), -1, axis);
|
|
5113
5866
|
}
|
|
5114
|
-
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
5115
|
-
const cumulativeSum = cumsum;
|
|
5116
5867
|
/** Reverse the elements in an array along the given axes. */
|
|
5117
5868
|
function flip(x, axis = null) {
|
|
5118
5869
|
const nd = ndim(x);
|
|
@@ -5120,6 +5871,45 @@ function flip(x, axis = null) {
|
|
|
5120
5871
|
return flip$1(x, axis);
|
|
5121
5872
|
}
|
|
5122
5873
|
/**
 * Split an array into multiple sub-arrays along an axis.
 *
 * @param a - The input array to split.
 * @param indicesOrSections - If an integer, the number of equal sections to
 *   create along the specified axis. If a list of integers, the indices at
 *   which to split the array; an empty list yields one piece holding the
 *   whole array.
 * @param axis - The axis along which to split the array. Default is 0.
 * @returns The sub-arrays, in order along `axis`.
 * @throws Error if an integer section count does not evenly divide the axis.
 */
function split$1(a, indicesOrSections, axis = 0) {
  a = fudgeArray(a);
  axis = require_backend.checkAxis(axis, a.ndim);
  const size$1 = a.shape[axis];
  let sizes;
  if (typeof indicesOrSections === "number") {
    if (size$1 % indicesOrSections !== 0) throw new Error(`Array of size ${size$1} cannot be split into ${indicesOrSections} equal parts`);
    sizes = require_backend.rep(indicesOrSections, size$1 / indicesOrSections);
  } else {
    // Turn split points into piece lengths: consecutive differences of
    // [0, ...indices, size]. An empty index list gives the single length [size].
    const bounds = [0, ...indicesOrSections, size$1];
    sizes = bounds.slice(1).map((end, k) => end - bounds[k]);
  }
  const results = [];
  // Each `split$2` call is given at most 8 output sizes. When more pieces
  // remain, 7 are peeled off and the rest is kept as one trailing piece that
  // the next iteration splits further.
  for (let i = 0; i < sizes.length; i += 7) {
    const remaining = sizes.length - i;
    if (remaining === 1) {
      // A single piece is the array itself; no split call is needed.
      results.push(a);
      break;
    }
    if (remaining <= 8) {
      results.push(...split$2(a, axis, sizes.slice(i)));
      break;
    }
    const tailSize = sizes.slice(i + 7).reduce((x, y) => x + y, 0);
    const outs = split$2(a, axis, [...sizes.slice(i, i + 7), tailSize]);
    results.push(...outs.slice(0, -1));
    a = outs[outs.length - 1];
  }
  return results;
}
|
|
5912
|
+
/**
|
|
5123
5913
|
* Join a sequence of arrays along an existing axis.
|
|
5124
5914
|
*
|
|
5125
5915
|
* The arrays must have the same shape, except in the dimension corresponding to
|
|
@@ -5131,13 +5921,11 @@ function concatenate(xs, axis = 0) {
|
|
|
5131
5921
|
if (xs.length === 0) throw new Error("Need at least one array to concatenate");
|
|
5132
5922
|
const shapes = xs.map(shape);
|
|
5133
5923
|
axis = require_backend.checkAxis(axis, shapes[0].length);
|
|
5134
|
-
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays
|
|
5135
|
-
const makePadAxis = (start, end) => shapes[0].map((_, i) => i === axis ? [start, end] : [0, 0]);
|
|
5924
|
+
for (let i = 1; i < shapes.length; i++) if (shapes[i].length !== shapes[0].length || !shapes[i].every((d, j) => j === axis || d === shapes[0][j])) throw new Error(`Cannot concatenate arrays ${xs[0].aval} and ${xs[i].aval} along axis ${axis}`);
|
|
5136
5925
|
let result = xs[0];
|
|
5137
|
-
for (let i = 1; i < xs.length; i
|
|
5138
|
-
const
|
|
5139
|
-
|
|
5140
|
-
result = pad(result, makePadAxis(0, len2)).add(pad(xs[i], makePadAxis(len1, 0)));
|
|
5926
|
+
for (let i = 1; i < xs.length; i += 7) {
|
|
5927
|
+
const group = xs.slice(i, i + 7);
|
|
5928
|
+
result = concatenate$1([result, ...group], axis);
|
|
5141
5929
|
}
|
|
5142
5930
|
return result;
|
|
5143
5931
|
}
|
|
@@ -5222,8 +6010,11 @@ function flipud(x) {
|
|
|
5222
6010
|
function fliplr(x) {
  // Reverse along axis 1.
  const targetAxis = 1;
  return flip(x, targetAxis);
}
|
|
5225
|
-
/**
|
|
5226
|
-
|
|
6013
|
+
/**
 * Swap the final two axes of an array.
 *
 * Throws if the input has fewer than two dimensions.
 */
function matrixTranspose(a) {
  const rank = ndim(a);
  if (rank < 2) {
    throw new Error(`matrixTranspose: input array must be at least 2D`);
  }
  return moveaxis$1(a, -1, -2);
}
|
|
5227
6018
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
5228
6019
|
function ravel(a) {
|
|
5229
6020
|
return fudgeArray(a).ravel();
|
|
@@ -5239,6 +6030,32 @@ function squeeze(a, axis = null) {
|
|
|
5239
6030
|
return reshape(a, newShape);
|
|
5240
6031
|
}
|
|
5241
6032
|
/**
 * Insert new length-1 axes into the shape of an array.
 *
 * @param a - Input array.
 * @param axis - Position(s), in the expanded shape, where the new axis (or
 *   axes) go. Either a single integer or an array of integers.
 * @returns The input reshaped with the extra unit dimensions.
 *
 * @example
 * ```ts
 * const x = np.array([1, 2]);
 * np.expandDims(x, 0); // Shape [1, 2]
 * np.expandDims(x, 1); // Shape [2, 1]
 * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
 * ```
 */
function expandDims(a, axis) {
  const oldShape = shape(a);
  const requested = typeof axis === "number" ? [axis] : axis;
  // Axes are normalized against the rank of the expanded result.
  const inserted = require_backend.normalizeAxis(requested, oldShape.length + requested.length);
  const outRank = oldShape.length + inserted.length;
  let nextSrc = 0;
  // Walk the output positions in order: inserted slots get size 1, every
  // other slot takes the next original dimension.
  const newShape = Array.from({ length: outRank }, (_, pos) => inserted.includes(pos) ? 1 : oldShape[nextSrc++]);
  return reshape(a, newShape);
}
|
|
6058
|
+
/**
|
|
5242
6059
|
* Repeat each element of an array after themselves.
|
|
5243
6060
|
*
|
|
5244
6061
|
* If no axis is provided, use the flattened input array, and return a flat
|
|
@@ -5326,7 +6143,7 @@ function diagonal(a, offset, axis1, axis2) {
|
|
|
5326
6143
|
*/
|
|
5327
6144
|
function diag(v, k = 0) {
|
|
5328
6145
|
const a = fudgeArray(v);
|
|
5329
|
-
if (!Number.isInteger(k)) throw new
|
|
6146
|
+
if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
|
|
5330
6147
|
if (a.ndim === 1) {
|
|
5331
6148
|
const n = a.shape[0];
|
|
5332
6149
|
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
@@ -5334,12 +6151,46 @@ function diag(v, k = 0) {
|
|
|
5334
6151
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
5335
6152
|
else return ret;
|
|
5336
6153
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
5337
|
-
else throw new
|
|
6154
|
+
else throw new Error("numpy.diag only supports 1D and 2D arrays");
|
|
5338
6155
|
}
|
|
5339
6156
|
/** Sum the entries of a diagonal of an array, taken over the two given axes. */
function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
  const diagValues = diagonal(a, offset, axis1, axis2);
  // `diagonal` output is summed over its last axis.
  return diagValues.sum(-1);
}
|
|
6160
|
+
/**
 * Return a sorted copy of an array.
 *
 * Sorting happens along one axis (the last axis unless `axis` is given). The
 * sort may be unstable, and it dispatches to a device-specific implementation.
 */
function sort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.sort(axis);
}
|
|
6169
|
+
/**
 * Return the `int32` indices that would sort an array.
 *
 * Sorting happens along one axis (the last axis unless `axis` is given). The
 * algorithm may be unstable: tied elements need not keep their index order.
 */
function argsort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.argsort(axis);
}
|
|
6180
|
+
/**
 * Take elements from an array along an axis.
 *
 * Equivalent to advanced indexing with integer indices over that axis. With
 * `axis = null` (the default) the array is flattened first and indexed along
 * axis 0.
 */
function take(a, indices, axis = null) {
  const useFlat = axis === null;
  const source = useFlat ? ravel(a) : a;
  const along = require_backend.checkAxis(useFlat ? 0 : axis, ndim(source));
  return gather(source, [indices], [along], along);
}
|
|
5343
6194
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5344
6195
|
function allclose(actual, expected, options) {
|
|
5345
6196
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5356,11 +6207,11 @@ function allclose(actual, expected, options) {
|
|
|
5356
6207
|
}
|
|
5357
6208
|
/** Matrix product of two arrays. */
|
|
5358
6209
|
function matmul(x, y) {
|
|
5359
|
-
if (ndim(x) === 0 || ndim(y) === 0) throw new
|
|
6210
|
+
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
5360
6211
|
x = x, y = y;
|
|
5361
6212
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5362
6213
|
const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
|
|
5363
|
-
return dot
|
|
6214
|
+
return dot(x, y, {
|
|
5364
6215
|
lhsContractingDims: [-1],
|
|
5365
6216
|
rhsContractingDims: [-2],
|
|
5366
6217
|
lhsBatchDims: require_backend.range(-2 - numBatchDims, -2),
|
|
@@ -5368,11 +6219,11 @@ function matmul(x, y) {
|
|
|
5368
6219
|
});
|
|
5369
6220
|
}
|
|
5370
6221
|
/** Dot product of two arrays. */
|
|
5371
|
-
function dot(x, y) {
|
|
6222
|
+
function dot$1(x, y) {
|
|
5372
6223
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
5373
6224
|
x = x, y = y;
|
|
5374
6225
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5375
|
-
return dot
|
|
6226
|
+
return dot(x, y, {
|
|
5376
6227
|
lhsContractingDims: [-1],
|
|
5377
6228
|
rhsContractingDims: [-2]
|
|
5378
6229
|
});
|
|
@@ -5388,7 +6239,7 @@ function tensordot(x, y, axes = 2) {
|
|
|
5388
6239
|
x = fudgeArray(x);
|
|
5389
6240
|
y = fudgeArray(y);
|
|
5390
6241
|
if (typeof axes === "number") axes = [require_backend.range(-axes, 0), require_backend.range(axes)];
|
|
5391
|
-
return dot
|
|
6242
|
+
return dot(x, y, {
|
|
5392
6243
|
lhsContractingDims: axes[0],
|
|
5393
6244
|
rhsContractingDims: axes[1]
|
|
5394
6245
|
});
|
|
@@ -5481,7 +6332,7 @@ function einsum(...args) {
|
|
|
5481
6332
|
const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
|
|
5482
6333
|
indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
|
|
5483
6334
|
const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
|
|
5484
|
-
const result = dot
|
|
6335
|
+
const result = dot(a, b, {
|
|
5485
6336
|
lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
|
|
5486
6337
|
rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
|
|
5487
6338
|
lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
|
|
@@ -5509,7 +6360,7 @@ function einsum(...args) {
|
|
|
5509
6360
|
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
5510
6361
|
*/
|
|
5511
6362
|
function inner(x, y) {
|
|
5512
|
-
return dot
|
|
6363
|
+
return dot(fudgeArray(x), fudgeArray(y), {
|
|
5513
6364
|
lhsContractingDims: [-1],
|
|
5514
6365
|
rhsContractingDims: [-1]
|
|
5515
6366
|
});
|
|
@@ -5542,6 +6393,30 @@ function vecdot(x, y, { axis } = {}) {
|
|
|
5542
6393
|
function vdot(x, y) {
  // Both operands are flattened before the dot product is taken.
  const flatX = ravel(x);
  const flatY = ravel(y);
  return dot$2(flatX, flatY);
}
|
|
6396
|
+
/**
 * Shared implementation behind `convolve` and `correlate` for 1-D inputs.
 *
 * `name` selects the behaviour ("convolve" reverses the kernel; "correlate"
 * reverses the output when the operands had to be swapped) and is also used
 * as the prefix of error messages. `mode` is "full", "same" or "valid".
 */
function _convImpl(name, x, y, mode) {
  if (!(x.ndim === 1 && y.ndim === 1)) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
  // The longer operand acts as the signal, the shorter one as the kernel.
  const swapped = x.shape[0] < y.shape[0];
  const signal = swapped ? y : x;
  let kernel = swapped ? x : y;
  const flipOutput = swapped && name === "correlate";
  if (name === "convolve") kernel = flip(kernel);
  let padding;
  switch (mode) {
    case "valid":
      padding = "VALID";
      break;
    case "same":
      padding = "SAME_LOWER";
      break;
    case "full": {
      const halo = kernel.shape[0] - 1;
      padding = [[halo, halo]];
      break;
    }
    default:
      throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
  }
  // `slice(null, null)` adds two leading axes for `conv`; `slice(0, 0)` drops
  // them again from the result.
  const out = conv(signal.slice(null, null), kernel.slice(null, null), [1], padding).slice(0, 0);
  return flipOutput ? flip(out) : out;
}
|
|
6412
|
+
/** Convolve two one-dimensional arrays (`mode` defaults to "full"). */
function convolve(x, y, mode = "full") {
  const opName = "convolve";
  return _convImpl(opName, x, y, mode);
}
|
|
6416
|
+
/** Correlate two one-dimensional arrays (`mode` defaults to "valid"). */
function correlate(x, y, mode = "valid") {
  const opName = "correlate";
  return _convImpl(opName, x, y, mode);
}
|
|
5545
6420
|
/**
|
|
5546
6421
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
5547
6422
|
*
|
|
@@ -5550,7 +6425,7 @@ function vdot(x, y) {
|
|
|
5550
6425
|
*/
|
|
5551
6426
|
function meshgrid(xs, { indexing } = {}) {
|
|
5552
6427
|
indexing ??= "xy";
|
|
5553
|
-
for (const x of xs) if (x.ndim !== 1) throw new
|
|
6428
|
+
for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
|
|
5554
6429
|
if (xs.length <= 1) return xs;
|
|
5555
6430
|
if (indexing === "xy") {
|
|
5556
6431
|
const [a, b, ...rest] = xs;
|
|
@@ -5569,43 +6444,6 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
5569
6444
|
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
5570
6445
|
}
|
|
5571
6446
|
/**
|
|
5572
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
5573
|
-
*
|
|
5574
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
5575
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
5576
|
-
* `k>0` is above it.
|
|
5577
|
-
*/
|
|
5578
|
-
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
5579
|
-
m ??= n;
|
|
5580
|
-
dtype ??= require_backend.DType.Float32;
|
|
5581
|
-
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
5582
|
-
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
5583
|
-
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
5584
|
-
const rows = arange(k, n + k, 1, {
|
|
5585
|
-
dtype: require_backend.DType.Int32,
|
|
5586
|
-
device
|
|
5587
|
-
});
|
|
5588
|
-
const cols = arange(0, m, 1, {
|
|
5589
|
-
dtype: require_backend.DType.Int32,
|
|
5590
|
-
device
|
|
5591
|
-
});
|
|
5592
|
-
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
5593
|
-
}
|
|
5594
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
5595
|
-
function tril(a, k = 0) {
|
|
5596
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5597
|
-
a = fudgeArray(a);
|
|
5598
|
-
const [n, m] = a.shape.slice(-2);
|
|
5599
|
-
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
5600
|
-
}
|
|
5601
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
5602
|
-
function triu(a, k = 0) {
|
|
5603
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5604
|
-
a = fudgeArray(a);
|
|
5605
|
-
const [n, m] = a.shape.slice(-2);
|
|
5606
|
-
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
5607
|
-
}
|
|
5608
|
-
/**
|
|
5609
6447
|
* Clip (limit) the values in an array.
|
|
5610
6448
|
*
|
|
5611
6449
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -5629,8 +6467,6 @@ function absolute(x) {
|
|
|
5629
6467
|
x = fudgeArray(x);
|
|
5630
6468
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
5631
6469
|
}
|
|
5632
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
5633
|
-
const abs = absolute;
|
|
5634
6470
|
/** Return an element-wise indication of sign of the input. */
|
|
5635
6471
|
function sign(x) {
|
|
5636
6472
|
x = fudgeArray(x);
|
|
@@ -5674,6 +6510,20 @@ function tan(x) {
|
|
|
5674
6510
|
x = fudgeArray(x);
|
|
5675
6511
|
return sin(x.ref).div(cos(x));
|
|
5676
6512
|
}
|
|
6513
|
+
/**
 * @function
 * Return the normalized sinc function.
 *
 * Defined as `sin(πx) / (πx)` for `x != 0` and `1` for `x = 0`; this is the
 * normalized form commonly used in signal processing.
 *
 * **Note:** JVP is not supported at x=0 due to discontinuous derivative. This
 * requires a custom JVP rule to handle properly (see JAX implementation).
 */
const sinc = jit$1(function sinc$1(x) {
  const scaled = x.ref.mul(Math.PI);
  const atOrigin = equal(x, 0);
  const ratio = sin(scaled.ref).div(scaled);
  // The x = 0 entries are replaced by 1.
  return where(atOrigin, 1, ratio);
});
|
|
5677
6527
|
/** Element-wise inverse cosine function (inverse of cos). */
|
|
5678
6528
|
function acos(x) {
|
|
5679
6529
|
return subtract(pi / 2, asin(x));
|
|
@@ -5709,12 +6559,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
|
|
|
5709
6559
|
const denom = where(xNeg, y, r.add(x));
|
|
5710
6560
|
return atan(numer.div(denom)).mul(2);
|
|
5711
6561
|
});
|
|
5712
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
5713
|
-
const arccos = acos;
|
|
5714
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
5715
|
-
const arctan = atan;
|
|
5716
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
5717
|
-
const arctan2 = atan2;
|
|
5718
6562
|
/** Element-wise subtraction, with broadcasting. */
|
|
5719
6563
|
function subtract(x, y) {
|
|
5720
6564
|
x = fudgeArray(x);
|
|
@@ -5732,6 +6576,25 @@ function trueDivide(x, y) {
|
|
|
5732
6576
|
return x.div(y);
|
|
5733
6577
|
}
|
|
5734
6578
|
/**
 * Return the largest integer smaller than or equal to the quotient of the inputs.
 *
 * The result is always rounded towards negative infinity.
 *
 * If either input has a floating-point dtype this is `floor(x / y)`.
 * Otherwise it is computed as `(x - remainder(x, y)) / y` to handle negative
 * values correctly (note: may overflow near int32 boundaries).
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Element-wise floor division of x by y.
 */
function floorDivide(x, y) {
  const dividend = fudgeArray(x);
  const divisor = fudgeArray(y);
  const anyFloat = require_backend.isFloatDtype(dividend.dtype) || require_backend.isFloatDtype(divisor.dtype);
  if (anyFloat) {
    const quotient = trueDivide(dividend, divisor);
    return floor(quotient);
  }
  const rem = remainder(dividend.ref, divisor.ref);
  return subtract(dividend, rem).div(divisor);
}
|
|
6597
|
+
/**
|
|
5735
6598
|
* @function
|
|
5736
6599
|
* Calculate element-wise floating-point modulo operation.
|
|
5737
6600
|
*/
|
|
@@ -5745,8 +6608,20 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
5745
6608
|
const remainder = jit$1(function remainder$1(x, y) {
  // Computes mod(mod(x, y) + y, y) in three steps.
  // NOTE(review): presumably the extra `+ y` and second `mod` give the result
  // the sign of the divisor — confirm against `mod`'s semantics.
  const inner = mod(x, y.ref);
  const shifted = inner.add(y.ref);
  return mod(shifted, y);
});
|
|
5748
|
-
/**
|
|
5749
|
-
|
|
6611
|
+
/**
 * Return the element-wise quotient and remainder together.
 *
 * Equivalent to `[floorDivide(x, y), remainder(x, y)]`.
 *
 * @param x - Dividend array.
 * @param y - Divisor array.
 * @returns Tuple of [quotient, remainder].
 */
function divmod(x, y) {
  const dividend = fudgeArray(x);
  const divisor = fudgeArray(y);
  // `.ref` is taken for the first use because both operands are used again
  // for the remainder.
  const quotient = floorDivide(dividend.ref, divisor.ref);
  const rest = remainder(dividend, divisor);
  return [quotient, rest];
}
|
|
5750
6625
|
/** Round input to the nearest integer towards zero. */
|
|
5751
6626
|
function trunc(x) {
|
|
5752
6627
|
return idiv(x, 1);
|
|
@@ -5768,9 +6643,9 @@ function ldexp(x1, x2) {
|
|
|
5768
6643
|
*/
|
|
5769
6644
|
function frexp(x) {
  const value = fudgeArray(x);
  const magnitude = absolute(value.ref);
  const isZero = equal(value.ref, 0);
  // exponent = floor(log2(|x|)) + 1 as int32, forced to 0 where x == 0.
  const rawExponent = floor(log2(magnitude)).add(1).astype(require_backend.DType.Int32);
  const exponent = where(isZero, 0, rawExponent);
  // mantissa = x / 2**exponent, computed in x's own dtype.
  const scale = exp2(exponent.ref.astype(value.dtype));
  const mantissa = value.div(scale);
  return [mantissa, exponent];
}
|
|
5776
6651
|
/** Calculate `2**p` for all p in the input array. */
|
|
@@ -5813,10 +6688,8 @@ const power = jit$1(function power$1(x1, x2) {
|
|
|
5813
6688
|
const x2i = trunc(x2.ref);
|
|
5814
6689
|
const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
|
|
5815
6690
|
const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
|
|
5816
|
-
return where(shouldBeNaN, nan, exp(log(
|
|
6691
|
+
return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
|
|
5817
6692
|
});
|
|
5818
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
5819
|
-
const pow = power;
|
|
5820
6693
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
5821
6694
|
const cbrt = jit$1(function cbrt$1(x) {
|
|
5822
6695
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
@@ -5882,69 +6755,360 @@ const arccosh = jit$1(function arccosh$1(x) {
|
|
|
5882
6755
|
const arctanh = jit$1(function arctanh$1(x) {
  // 0.5 * log((1 + x) / (1 - x))
  const numerator = add(1, x.ref);
  const denominator = subtract(1, x);
  return log(numerator.div(denominator)).mul(.5);
});
|
|
5885
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
5886
|
-
const asinh = arcsinh;
|
|
5887
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
5888
|
-
const acosh = arccosh;
|
|
5889
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
5890
|
-
const atanh = arctanh;
|
|
5891
6758
|
/**
 * Compute the variance of an array.
 *
 * By default the variance is taken over the flattened array; otherwise over
 * the specified axis.
 *
 * If `correction` is provided, the divisor used is `N - correction`, where
 * `N` is the number of elements (e.g., for Bessel's correction). A
 * precomputed `mean` may be passed in `opts`; it is used instead of
 * computing the mean here.
 */
function var_(x, axis = null, opts) {
  const arr = fudgeArray(x);
  const axes = require_backend.normalizeAxis(axis, arr.ndim);
  // Number of elements reduced over: product of the selected axis sizes.
  let count = 1;
  for (const ax of axes) count *= arr.shape[ax];
  if (count === 0) throw new Error("var: cannot compute variance over zero-length axis");
  let center = opts?.mean;
  if (center === void 0) center = mean(arr.ref, axes, { keepdims: true });
  const divisor = count - (opts?.correction ?? 0);
  const sumSquares = square(arr.sub(center)).sum(axes, { keepdims: opts?.keepdims });
  return sumSquares.mul(1 / divisor);
}
|
|
6775
|
+
/**
|
|
6776
|
+
* Compute the standard deviation of an array.
|
|
6777
|
+
*
|
|
6778
|
+
* The standard deviation is computed for the flattened array by default,
|
|
6779
|
+
* otherwise over the specified axis.
|
|
6780
|
+
*
|
|
6781
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
6782
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
6783
|
+
*/
|
|
6784
|
+
function std(x, axis = null, opts) {
|
|
6785
|
+
return sqrt(var_(x, axis, opts));
|
|
6786
|
+
}
|
|
6787
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
6788
|
+
function cov(x, y = null, { rowvar = true } = {}) {
|
|
6789
|
+
x = fudgeArray(x);
|
|
6790
|
+
if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
|
|
6791
|
+
if (y !== null) {
|
|
6792
|
+
y = fudgeArray(y);
|
|
6793
|
+
if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
|
|
6794
|
+
x = vstack([x, y]);
|
|
6795
|
+
}
|
|
6796
|
+
if (!rowvar) x = x.transpose();
|
|
6797
|
+
const [_M, N] = x.shape;
|
|
6798
|
+
x = x.ref.sub(x.mean(1, { keepdims: true }));
|
|
6799
|
+
return dot$1(x.ref, x.transpose()).div(N - 1);
|
|
6800
|
+
}
|
|
6801
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
6802
|
+
function corrcoef(x, y) {
|
|
6803
|
+
const c = cov(x, y);
|
|
6804
|
+
const variances = diag(c.ref);
|
|
6805
|
+
const norm = sqrt(outer(variances.ref, variances));
|
|
6806
|
+
return c.div(norm);
|
|
6807
|
+
}
|
|
6808
|
+
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
6809
|
+
function isinf(x) {
|
|
6810
|
+
x = fudgeArray(x);
|
|
6811
|
+
return require_backend.isFloatDtype(x.dtype) ? x.ref.equal(Infinity).add(x.equal(-Infinity)) : fullLike$1(x, false);
|
|
6812
|
+
}
|
|
6813
|
+
/** Test element-wise for NaN (Not a Number). */
|
|
6814
|
+
function isnan(x) {
|
|
6815
|
+
x = fudgeArray(x);
|
|
6816
|
+
return require_backend.isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
|
|
6817
|
+
}
|
|
6818
|
+
/** Test element-wise for negative infinity, return bool array. */
|
|
6819
|
+
function isneginf(x) {
|
|
6820
|
+
x = fudgeArray(x);
|
|
6821
|
+
return require_backend.isFloatDtype(x.dtype) ? x.equal(-Infinity) : fullLike$1(x, false);
|
|
6822
|
+
}
|
|
6823
|
+
/** Test element-wise for positive infinity, return bool array. */
|
|
6824
|
+
function isposinf(x) {
|
|
6825
|
+
x = fudgeArray(x);
|
|
6826
|
+
return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
6827
|
+
}
|
|
6828
|
+
/**
|
|
6829
|
+
* @function
|
|
6830
|
+
* Test element-wise for finite values (not infinity or NaN).
|
|
6831
|
+
*/
|
|
6832
|
+
const isfinite = jit$1(function isfinite$1(x) {
|
|
6833
|
+
if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, true);
|
|
6834
|
+
return isnan(x.ref).add(isinf(x)).notEqual(true);
|
|
6835
|
+
});
|
|
6836
|
+
|
|
6837
|
+
//#endregion
|
|
6838
|
+
//#region src/library/lax-linalg.ts
|
|
6839
|
+
var lax_linalg_exports = {};
|
|
6840
|
+
__export(lax_linalg_exports, {
|
|
6841
|
+
cholesky: () => cholesky$1,
|
|
6842
|
+
lu: () => lu,
|
|
6843
|
+
triangularSolve: () => triangularSolve
|
|
6844
|
+
});
|
|
6845
|
+
/**
|
|
6846
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
6847
|
+
*
|
|
6848
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
6849
|
+
*
|
|
6850
|
+
* - A = L @ L^T (for upper=false, default)
|
|
6851
|
+
* - A = U^T @ U (for upper=true)
|
|
6852
|
+
*
|
|
6853
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
6854
|
+
* The input matrix must be symmetric and positive-definite.
|
|
6855
|
+
*
|
|
6856
|
+
* @example
|
|
6857
|
+
* ```ts
|
|
6858
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6859
|
+
*
|
|
6860
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
6861
|
+
*
|
|
6862
|
+
* // Lower Cholesky factorization (default):
|
|
6863
|
+
* const L = lax.linalg.cholesky(x);
|
|
6864
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
6865
|
+
*
|
|
6866
|
+
* // Upper Cholesky factorization:
|
|
6867
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
6868
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6869
|
+
* ```
|
|
6870
|
+
*/
|
|
6871
|
+
function cholesky$1(a, { upper = false } = {}) {
|
|
6872
|
+
const L = cholesky$2(a);
|
|
6873
|
+
return upper ? moveaxis$1(L, -2, -1) : L;
|
|
6874
|
+
}
|
|
6875
|
+
/**
|
|
6876
|
+
* LU decomposition with partial pivoting.
|
|
6877
|
+
*
|
|
6878
|
+
* Computes the matrix decomposition: `P @ A = L @ U`, where `P` is a
|
|
6879
|
+
* permutation of the rows of `A`, `L` is lower-triangular with unit diagonal,
|
|
6880
|
+
* and `U` is upper-triangular.
|
|
6881
|
+
*
|
|
6882
|
+
* @param x - A batch of matrices with shape `[..., m, n]`.
|
|
6883
|
+
*
|
|
6884
|
+
* @returns A tuple `(lu, pivots, permutation)` where:
|
|
6885
|
+
* - `lu`: combined lower and upper triangular matrices.
|
|
6886
|
+
* - `pivots`: an array of pivot indices with shape `[..., min(m, n)]`.
|
|
6887
|
+
* - `permutation`: the permutation generated by pivots with shape `[..., m]`.
|
|
6888
|
+
*
|
|
6889
|
+
* @example
|
|
6890
|
+
* ```ts
|
|
6891
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6892
|
+
*
|
|
6893
|
+
* const A = np.array([[4., 3.], [6., 3.]]);
|
|
6894
|
+
* const [lu, pivots, permutation] = lax.linalg.lu(A);
|
|
6895
|
+
* // lu ≈ [[6., 3.], [0.6666667, 1.0]]
|
|
6896
|
+
* // pivots = [1, 1]
|
|
6897
|
+
* // permutation = [1, 0]
|
|
6898
|
+
* ```
|
|
6899
|
+
*/
|
|
6900
|
+
function lu(x) {
|
|
6901
|
+
return lu$1(x);
|
|
6902
|
+
}
|
|
6903
|
+
/**
|
|
6904
|
+
* Solve a triangular linear system.
|
|
6905
|
+
*
|
|
6906
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
6907
|
+
* where `a` is a triangular matrix.
|
|
6908
|
+
*
|
|
6909
|
+
* @example
|
|
6910
|
+
* ```ts
|
|
6911
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6912
|
+
*
|
|
6913
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
6914
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
6915
|
+
*
|
|
6916
|
+
* // Solve L @ x = b
|
|
6917
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
6918
|
+
* // x = [[2.], [5./3.]]
|
|
6919
|
+
* ```
|
|
6920
|
+
*/
|
|
6921
|
+
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
|
|
6922
|
+
a = fudgeArray(a);
|
|
6923
|
+
b = fudgeArray(b);
|
|
6924
|
+
if (!leftSide) transposeA = !transposeA;
|
|
6925
|
+
else b = moveaxis$1(b, -2, -1);
|
|
6926
|
+
if (transposeA) a = moveaxis$1(a, -2, -1);
|
|
6927
|
+
let x = triangularSolve$1(a, b, {
|
|
6928
|
+
lower,
|
|
6929
|
+
unitDiagonal
|
|
6930
|
+
});
|
|
6931
|
+
if (leftSide) x = moveaxis$1(x, -2, -1);
|
|
6932
|
+
return x;
|
|
6933
|
+
}
|
|
6934
|
+
|
|
6935
|
+
//#endregion
|
|
6936
|
+
//#region src/library/lax.ts
|
|
6937
|
+
var lax_exports = {};
|
|
6938
|
+
__export(lax_exports, {
|
|
6939
|
+
conv: () => conv,
|
|
6940
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
6941
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6942
|
+
dot: () => dot,
|
|
6943
|
+
erf: () => erf,
|
|
6944
|
+
erfc: () => erfc,
|
|
6945
|
+
linalg: () => lax_linalg_exports,
|
|
6946
|
+
reduceWindow: () => reduceWindow,
|
|
6947
|
+
stopGradient: () => stopGradient$1
|
|
6948
|
+
});
|
|
6949
|
+
/**
|
|
6950
|
+
* General dot product/contraction operator.
|
|
5896
6951
|
*
|
|
5897
|
-
*
|
|
5898
|
-
*
|
|
6952
|
+
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
6953
|
+
* `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
|
|
5899
6954
|
*/
|
|
5900
|
-
function
|
|
5901
|
-
|
|
5902
|
-
|
|
5903
|
-
|
|
5904
|
-
|
|
5905
|
-
|
|
5906
|
-
|
|
6955
|
+
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
6956
|
+
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
6957
|
+
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
6958
|
+
lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
6959
|
+
rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
6960
|
+
lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
6961
|
+
rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
6962
|
+
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
6963
|
+
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
6964
|
+
const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
6965
|
+
const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
6966
|
+
const lhs2 = lhs.transpose([
|
|
6967
|
+
...lb,
|
|
6968
|
+
...lf,
|
|
6969
|
+
...lc
|
|
6970
|
+
]);
|
|
6971
|
+
const rhs2 = rhs.transpose([
|
|
6972
|
+
...rb,
|
|
6973
|
+
...rf,
|
|
6974
|
+
...rc
|
|
6975
|
+
]);
|
|
6976
|
+
if (lc.length === 0) return mul(lhs2.reshape([
|
|
6977
|
+
...lb.map((a) => lhs.shape[a]),
|
|
6978
|
+
...lf.map((a) => lhs.shape[a]),
|
|
6979
|
+
...require_backend.rep(rf.length, 1)
|
|
6980
|
+
]), rhs2.reshape([
|
|
6981
|
+
...rb.map((a) => rhs.shape[a]),
|
|
6982
|
+
...require_backend.rep(lf.length, 1),
|
|
6983
|
+
...rf.map((a) => rhs.shape[a])
|
|
6984
|
+
]));
|
|
6985
|
+
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
6986
|
+
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
6987
|
+
if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
6988
|
+
return dot$2(lhs2.reshape([
|
|
6989
|
+
...lb.map((a) => lhs.shape[a]),
|
|
6990
|
+
...lf.map((a) => lhs.shape[a]),
|
|
6991
|
+
...require_backend.rep(rf.length, 1),
|
|
6992
|
+
require_backend.prod(dotShapeX)
|
|
6993
|
+
]), rhs2.reshape([
|
|
6994
|
+
...rb.map((a) => rhs.shape[a]),
|
|
6995
|
+
...require_backend.rep(lf.length, 1),
|
|
6996
|
+
...rf.map((a) => rhs.shape[a]),
|
|
6997
|
+
require_backend.prod(dotShapeY)
|
|
6998
|
+
]));
|
|
6999
|
+
}
|
|
7000
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
7001
|
+
const padType = padding.toUpperCase();
|
|
7002
|
+
switch (padType) {
|
|
7003
|
+
case "VALID": return require_backend.rep(inShape.length, [0, 0]);
|
|
7004
|
+
case "SAME":
|
|
7005
|
+
case "SAME_LOWER": {
|
|
7006
|
+
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
7007
|
+
const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
7008
|
+
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
7009
|
+
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
7010
|
+
}
|
|
7011
|
+
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
7012
|
+
}
|
|
5907
7013
|
}
|
|
5908
7014
|
/**
|
|
5909
|
-
*
|
|
7015
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
5910
7016
|
*
|
|
5911
|
-
* The
|
|
5912
|
-
*
|
|
7017
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
7018
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
5913
7019
|
*
|
|
5914
|
-
*
|
|
5915
|
-
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
7020
|
+
* Grouped convolutions are supported through `featureGroupCount`, which must divide both the input and output channel counts.
|
|
5916
7021
|
*/
|
|
5917
|
-
function
|
|
5918
|
-
|
|
7022
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
7023
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
7024
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
7025
|
+
if (typeof padding === "string") {
|
|
7026
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
7027
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
7028
|
+
}
|
|
7029
|
+
if (featureGroupCount !== 1) {
|
|
7030
|
+
const G = featureGroupCount;
|
|
7031
|
+
const [N, C_in, ...xs] = lhs.shape;
|
|
7032
|
+
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
7033
|
+
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
7034
|
+
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
7035
|
+
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
7036
|
+
const lhsGrouped = moveaxis(lhs.reshape([
|
|
7037
|
+
N,
|
|
7038
|
+
G,
|
|
7039
|
+
C_in / G,
|
|
7040
|
+
...xs
|
|
7041
|
+
]), 1, 0);
|
|
7042
|
+
const rhsGrouped = rhs.reshape([
|
|
7043
|
+
G,
|
|
7044
|
+
C_out / G,
|
|
7045
|
+
C_in_per_group,
|
|
7046
|
+
...ks
|
|
7047
|
+
]);
|
|
7048
|
+
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
7049
|
+
vmapDims: 1,
|
|
7050
|
+
strides: windowStrides,
|
|
7051
|
+
padding,
|
|
7052
|
+
lhsDilation,
|
|
7053
|
+
rhsDilation
|
|
7054
|
+
});
|
|
7055
|
+
const ys = result.shape.slice(3);
|
|
7056
|
+
return moveaxis(result, 0, 1).reshape([
|
|
7057
|
+
N,
|
|
7058
|
+
C_out,
|
|
7059
|
+
...ys
|
|
7060
|
+
]);
|
|
7061
|
+
}
|
|
7062
|
+
return conv$1(lhs, rhs, {
|
|
7063
|
+
strides: windowStrides,
|
|
7064
|
+
padding,
|
|
7065
|
+
lhsDilation,
|
|
7066
|
+
rhsDilation
|
|
7067
|
+
});
|
|
5919
7068
|
}
|
|
5920
|
-
/**
|
|
5921
|
-
function
|
|
5922
|
-
|
|
5923
|
-
|
|
7069
|
+
/** Convenience wrapper around `convGeneralDilated` that takes `lhsDilation` and `rhsDilation` as positional arguments. */
|
|
7070
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
7071
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
7072
|
+
lhsDilation,
|
|
7073
|
+
rhsDilation
|
|
7074
|
+
});
|
|
5924
7075
|
}
|
|
5925
|
-
/**
|
|
5926
|
-
function
|
|
5927
|
-
|
|
5928
|
-
return require_backend.isFloatDtype(x.dtype) ? x.ref.notEqual(x) : fullLike$1(x, false);
|
|
7076
|
+
/** Convenience wrapper around `convGeneralDilated` that uses the default options (no dilation arguments, `featureGroupCount = 1`). */
|
|
7077
|
+
function conv(lhs, rhs, windowStrides, padding) {
|
|
7078
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
5929
7079
|
}
|
|
5930
|
-
/**
|
|
5931
|
-
function
|
|
5932
|
-
|
|
5933
|
-
|
|
7080
|
+
/** Reduce a computation over padded windows. */
|
|
7081
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
7082
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
7083
|
+
if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
|
|
7084
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
7085
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
7086
|
+
window: windowDimensions,
|
|
7087
|
+
strides: windowStrides
|
|
7088
|
+
}));
|
|
5934
7089
|
}
|
|
5935
|
-
/**
|
|
5936
|
-
function
|
|
5937
|
-
|
|
5938
|
-
return require_backend.isFloatDtype(x.dtype) ? x.equal(Infinity) : fullLike$1(x, false);
|
|
7090
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
7091
|
+
function erf(x) {
|
|
7092
|
+
return erf$1(x);
|
|
5939
7093
|
}
|
|
5940
7094
|
/**
|
|
5941
|
-
*
|
|
5942
|
-
*
|
|
7095
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
7096
|
+
*
|
|
7097
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
7098
|
+
* where `erf(x)` is very close to 1.
|
|
5943
7099
|
*/
|
|
5944
|
-
|
|
5945
|
-
|
|
5946
|
-
|
|
5947
|
-
|
|
7100
|
+
function erfc(x) {
|
|
7101
|
+
return erfc$1(x);
|
|
7102
|
+
}
|
|
7103
|
+
/**
|
|
7104
|
+
* Stops gradient computation.
|
|
7105
|
+
*
|
|
7106
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
7107
|
+
* forward or reverse-mode automatic differentiation.
|
|
7108
|
+
*/
|
|
7109
|
+
function stopGradient$1(x) {
|
|
7110
|
+
return stopGradient(x);
|
|
7111
|
+
}
|
|
5948
7112
|
|
|
5949
7113
|
//#endregion
|
|
5950
7114
|
//#region src/library/nn.ts
|
|
@@ -5954,6 +7118,10 @@ __export(nn_exports, {
|
|
|
5954
7118
|
elu: () => elu,
|
|
5955
7119
|
gelu: () => gelu,
|
|
5956
7120
|
glu: () => glu,
|
|
7121
|
+
hardSigmoid: () => hardSigmoid,
|
|
7122
|
+
hardSilu: () => hardSilu,
|
|
7123
|
+
hardSwish: () => hardSilu,
|
|
7124
|
+
hardTanh: () => hardTanh,
|
|
5957
7125
|
identity: () => identity,
|
|
5958
7126
|
leakyRelu: () => leakyRelu,
|
|
5959
7127
|
logSigmoid: () => logSigmoid,
|
|
@@ -5964,14 +7132,17 @@ __export(nn_exports, {
|
|
|
5964
7132
|
oneHot: () => oneHot,
|
|
5965
7133
|
relu: () => relu,
|
|
5966
7134
|
relu6: () => relu6,
|
|
7135
|
+
selu: () => selu,
|
|
5967
7136
|
sigmoid: () => sigmoid,
|
|
5968
7137
|
silu: () => silu,
|
|
5969
7138
|
softSign: () => softSign,
|
|
5970
7139
|
softmax: () => softmax,
|
|
5971
7140
|
softplus: () => softplus,
|
|
7141
|
+
sparsePlus: () => sparsePlus,
|
|
7142
|
+
sparseSigmoid: () => sparseSigmoid,
|
|
5972
7143
|
squareplus: () => squareplus,
|
|
5973
7144
|
standardize: () => standardize,
|
|
5974
|
-
swish: () =>
|
|
7145
|
+
swish: () => silu
|
|
5975
7146
|
});
|
|
5976
7147
|
/**
|
|
5977
7148
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -6006,6 +7177,28 @@ function softplus(x) {
|
|
|
6006
7177
|
return log(exp(x).add(1));
|
|
6007
7178
|
}
|
|
6008
7179
|
/**
|
|
7180
|
+
* @function
|
|
7181
|
+
* Sparse plus function:
|
|
7182
|
+
*
|
|
7183
|
+
* - When `x <= -1`: `0`
|
|
7184
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
7185
|
+
* - When `x >= 1`: `x`
|
|
7186
|
+
*/
|
|
7187
|
+
const sparsePlus = jit$1((x) => {
|
|
7188
|
+
return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
|
|
7189
|
+
});
|
|
7190
|
+
/**
|
|
7191
|
+
* @function
|
|
7192
|
+
* Sparse sigmoid activation function.
|
|
7193
|
+
*
|
|
7194
|
+
* - When `x <= -1`: `0`
|
|
7195
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
7196
|
+
* - When `x >= 1`: `1`
|
|
7197
|
+
*/
|
|
7198
|
+
const sparseSigmoid = jit$1((x) => {
|
|
7199
|
+
return clip(x.add(1).mul(.5), 0, 1);
|
|
7200
|
+
});
|
|
7201
|
+
/**
|
|
6009
7202
|
* Soft-sign activation function, computed element-wise:
|
|
6010
7203
|
* `softsign(x) = x / (|x| + 1)`.
|
|
6011
7204
|
*/
|
|
@@ -6027,17 +7220,6 @@ const silu = jit$1(function silu$1(x) {
|
|
|
6027
7220
|
return x.ref.mul(sigmoid(x));
|
|
6028
7221
|
});
|
|
6029
7222
|
/**
|
|
6030
|
-
* @function
|
|
6031
|
-
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
6032
|
-
* Swish, computed element-wise:
|
|
6033
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
6034
|
-
*
|
|
6035
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
6036
|
-
*
|
|
6037
|
-
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
6038
|
-
*/
|
|
6039
|
-
const swish = silu;
|
|
6040
|
-
/**
|
|
6041
7223
|
* Log-sigmoid activation function, computed element-wise:
|
|
6042
7224
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
6043
7225
|
*/
|
|
@@ -6054,6 +7236,19 @@ function leakyRelu(x, negativeSlope = .01) {
|
|
|
6054
7236
|
x = fudgeArray(x);
|
|
6055
7237
|
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
6056
7238
|
}
|
|
7239
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
7240
|
+
function hardSigmoid(x) {
|
|
7241
|
+
return relu6(add(x, 3)).mul(1 / 6);
|
|
7242
|
+
}
|
|
7243
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
7244
|
+
function hardSilu(x) {
|
|
7245
|
+
x = fudgeArray(x);
|
|
7246
|
+
return x.ref.mul(hardSigmoid(x));
|
|
7247
|
+
}
|
|
7248
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
7249
|
+
function hardTanh(x) {
|
|
7250
|
+
return clip(x, -1, 1);
|
|
7251
|
+
}
|
|
6057
7252
|
/**
|
|
6058
7253
|
* Exponential linear unit activation function.
|
|
6059
7254
|
*
|
|
@@ -6076,6 +7271,20 @@ function celu(x, alpha = 1) {
|
|
|
6076
7271
|
}
|
|
6077
7272
|
/**
|
|
6078
7273
|
* @function
|
|
7274
|
+
* Scaled exponential linear unit activation.
|
|
7275
|
+
*
|
|
7276
|
+
* Computes the element-wise function:
|
|
7277
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
7278
|
+
*
|
|
7279
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
7280
|
+
*/
|
|
7281
|
+
const selu = jit$1(function selu$1(x) {
|
|
7282
|
+
const alpha = 1.6732632423543772;
|
|
7283
|
+
const lambda = 1.0507009873554805;
|
|
7284
|
+
return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
|
|
7285
|
+
});
|
|
7286
|
+
/**
|
|
7287
|
+
* @function
|
|
6079
7288
|
* Gaussian error linear unit (GELU) activation function.
|
|
6080
7289
|
*
|
|
6081
7290
|
* This is computed element-wise. There are two variants depending on whether
|
|
@@ -6229,35 +7438,46 @@ var random_exports = {};
|
|
|
6229
7438
|
__export(random_exports, {
|
|
6230
7439
|
bernoulli: () => bernoulli,
|
|
6231
7440
|
bits: () => bits,
|
|
7441
|
+
cauchy: () => cauchy,
|
|
6232
7442
|
exponential: () => exponential,
|
|
7443
|
+
gumbel: () => gumbel,
|
|
6233
7444
|
key: () => key,
|
|
7445
|
+
laplace: () => laplace,
|
|
7446
|
+
multivariateNormal: () => multivariateNormal,
|
|
6234
7447
|
normal: () => normal,
|
|
6235
7448
|
split: () => split,
|
|
6236
7449
|
uniform: () => uniform
|
|
6237
7450
|
});
|
|
6238
|
-
function validateKeyShape(key$1) {
|
|
7451
|
+
function validateKeyShape(key$1, scalar = false) {
|
|
6239
7452
|
if (key$1.ndim === 0) throw new Error("Key must have at least one dimension.");
|
|
6240
7453
|
if (key$1.shape[key$1.shape.length - 1] !== 2) throw new Error(`Invalid key shape: ${key$1.shape}. Expected last dimension to be 2.`);
|
|
7454
|
+
if (scalar && key$1.shape.length > 1) throw new Error(`Expected a single PRNG key, but got a batch of keys with shape ${JSON.stringify(key$1.shape)} - use jax.vmap for batching.`);
|
|
6241
7455
|
return key$1.shape.slice(0, -1);
|
|
6242
7456
|
}
|
|
7457
|
+
function getK01(key$1) {
|
|
7458
|
+
const keyShape = validateKeyShape(key$1, true);
|
|
7459
|
+
let [k0, k1] = split$2(key$1, -1, [1, 1]);
|
|
7460
|
+
k0 = k0.reshape(keyShape);
|
|
7461
|
+
k1 = k1.reshape(keyShape);
|
|
7462
|
+
return [k0, k1];
|
|
7463
|
+
}
|
|
6243
7464
|
/** Create a pseudo-random number generator (PRNG) key from 32-bit integer seed. */
|
|
6244
7465
|
function key(seed) {
|
|
6245
|
-
seed = seed
|
|
6246
|
-
|
|
7466
|
+
seed = array(seed, { dtype: require_backend.DType.Uint32 });
|
|
7467
|
+
if (seed.ndim !== 0) throw new Error(`key: seed must be a scalar integer, but got shape ${seed.shape} - use jax.vmap for batching.`);
|
|
7468
|
+
return stack([0, seed]);
|
|
6247
7469
|
}
|
|
6248
7470
|
/** Splits a PRNG key into `num` new keys by adding a leading axis. */
|
|
6249
7471
|
function split(key$1, num = 2) {
|
|
6250
7472
|
const shape$1 = typeof num === "number" ? [num] : num;
|
|
6251
7473
|
for (const len of shape$1) if (len <= 0 || !Number.isInteger(len)) throw new Error(`Invalid split length: ${len}. Must be a positive integer.`);
|
|
6252
|
-
const
|
|
6253
|
-
const k0 = key$1.ref.slice(...keyShape.map(() => null), 0);
|
|
6254
|
-
const k1 = key$1.slice(...keyShape.map(() => null), 1);
|
|
7474
|
+
const [k0, k1] = getK01(key$1);
|
|
6255
7475
|
return stack([randomBits(k0.ref, k1.ref, shape$1, 0), randomBits(k0, k1, shape$1, 1)], -1);
|
|
6256
7476
|
}
|
|
6257
7477
|
/** Sample uniform bits in the form of unsigned integers. */
|
|
6258
7478
|
function bits(key$1, shape$1 = []) {
|
|
6259
|
-
const
|
|
6260
|
-
return randomBits(
|
|
7479
|
+
const [k0, k1] = getK01(key$1);
|
|
7480
|
+
return randomBits(k0, k1, shape$1);
|
|
6261
7481
|
}
|
|
6262
7482
|
/**
|
|
6263
7483
|
* @function
|
|
@@ -6289,6 +7509,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
6289
7509
|
}
|
|
6290
7510
|
/**
|
|
6291
7511
|
* @function
|
|
7512
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
7513
|
+
*
|
|
7514
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
7515
|
+
*/
|
|
7516
|
+
const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
7517
|
+
const u = uniform(key$1, shape$1);
|
|
7518
|
+
return tan(u.sub(.5).mul(Math.PI));
|
|
7519
|
+
}, { staticArgnums: [1] });
|
|
7520
|
+
/**
|
|
7521
|
+
* @function
|
|
6292
7522
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
6293
7523
|
*/
|
|
6294
7524
|
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
@@ -6297,6 +7527,56 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
6297
7527
|
}, { staticArgnums: [1] });
|
|
6298
7528
|
/**
|
|
6299
7529
|
* @function
|
|
7530
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
7531
|
+
*
|
|
7532
|
+
* Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1), with `log(1 - u)` evaluated as `log1p(-u)`.
|
|
7533
|
+
*/
|
|
7534
|
+
const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
|
|
7535
|
+
const u = uniform(key$1, shape$1);
|
|
7536
|
+
return negative(log(negative(log1p(negative(u)))));
|
|
7537
|
+
}, { staticArgnums: [1] });
|
|
7538
|
+
/**
|
|
7539
|
+
* @function
|
|
7540
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
7541
|
+
*
|
|
7542
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
7543
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
7544
|
+
*/
|
|
7545
|
+
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
7546
|
+
const u = uniform(key$1, shape$1);
|
|
7547
|
+
const centered = u.sub(.5);
|
|
7548
|
+
const s = sign(centered.ref);
|
|
7549
|
+
const absVal = absolute(centered);
|
|
7550
|
+
return s.mul(log1p(absVal.mul(-2)).mul(-1));
|
|
7551
|
+
}, { staticArgnums: [1] });
|
|
7552
|
+
/**
|
|
7553
|
+
* @function
|
|
7554
|
+
* Sample multivariate normal random values with given mean and covariance.
|
|
7555
|
+
*
|
|
7556
|
+
* The values are returned with the given shape, along with the final dimension
|
|
7557
|
+
* used to represent the n-dimensional multivariate normal factors.
|
|
7558
|
+
*
|
|
7559
|
+
* This uses Cholesky decomposition on the covariance matrix.
|
|
7560
|
+
*
|
|
7561
|
+
* - `key` - PRNG key
|
|
7562
|
+
* - `mean` - Mean vector of shape `[..., n]`
|
|
7563
|
+
* - `cov` - Covariance of shape `[..., n, n]`, must be positive-definite
|
|
7564
|
+
* - `shape` - Result batch shape, must be broadcastable with
|
|
7565
|
+
* `mean.shape[:-1]` and `cov.shape[:-2]`
|
|
7566
|
+
* @returns Random samples of shape `[...shape, n]`
|
|
7567
|
+
*/
|
|
7568
|
+
const multivariateNormal = jit$1(function multivariateNormal$1(key$1, mean$1, cov$1, shape$1 = []) {
|
|
7569
|
+
mean$1 = fudgeArray(mean$1);
|
|
7570
|
+
cov$1 = fudgeArray(cov$1);
|
|
7571
|
+
const n = mean$1.shape[mean$1.ndim - 1];
|
|
7572
|
+
if (cov$1.shape[cov$1.ndim - 1] !== n || cov$1.shape[cov$1.ndim - 2] !== n) throw new Error(`Invalid covariance shape: ${cov$1.shape}. Expected last two dimensions to be [${n}, ${n}].`);
|
|
7573
|
+
const outputShape = broadcastShapes(shape$1, mean$1.shape.slice(0, -1), cov$1.shape.slice(0, -2)).concat(n);
|
|
7574
|
+
const L = cholesky(cov$1);
|
|
7575
|
+
const z = normal(key$1, outputShape);
|
|
7576
|
+
return einsum("...ij,...j->...i", L, z).add(mean$1);
|
|
7577
|
+
}, { staticArgnums: [3] });
|
|
7578
|
+
/**
|
|
7579
|
+
* @function
|
|
6300
7580
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6301
7581
|
*
|
|
6302
7582
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -6405,11 +7685,6 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
6405
7685
|
*/
|
|
6406
7686
|
const jacrev = jacrev$1;
|
|
6407
7687
|
/**
|
|
6408
|
-
* @function
|
|
6409
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
6410
|
-
*/
|
|
6411
|
-
const jacobian = jacrev;
|
|
6412
|
-
/**
|
|
6413
7688
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
6414
7689
|
*
|
|
6415
7690
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -6445,6 +7720,7 @@ async function devicePut(x, device) {
|
|
|
6445
7720
|
|
|
6446
7721
|
//#endregion
|
|
6447
7722
|
exports.Array = Array$1;
|
|
7723
|
+
exports.ClosedJaxpr = ClosedJaxpr;
|
|
6448
7724
|
exports.DType = require_backend.DType;
|
|
6449
7725
|
exports.Jaxpr = Jaxpr;
|
|
6450
7726
|
exports.blockUntilReady = blockUntilReady;
|
|
@@ -6454,7 +7730,7 @@ exports.devices = require_backend.devices;
|
|
|
6454
7730
|
exports.grad = grad;
|
|
6455
7731
|
exports.init = require_backend.init;
|
|
6456
7732
|
exports.jacfwd = jacfwd;
|
|
6457
|
-
exports.jacobian =
|
|
7733
|
+
exports.jacobian = jacrev;
|
|
6458
7734
|
exports.jacrev = jacrev;
|
|
6459
7735
|
exports.jit = jit;
|
|
6460
7736
|
exports.jvp = jvp;
|
|
@@ -6499,5 +7775,4 @@ Object.defineProperty(exports, 'tree', {
|
|
|
6499
7775
|
});
|
|
6500
7776
|
exports.valueAndGrad = valueAndGrad;
|
|
6501
7777
|
exports.vjp = vjp;
|
|
6502
|
-
exports.vmap = vmap;
|
|
6503
|
-
//# sourceMappingURL=index.cjs.map
|
|
7778
|
+
exports.vmap = vmap;
|