@jax-js/jax 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -2
- package/dist/{backend-CmaidnkQ.cjs → backend-Bu9GY6sK.cjs} +166 -18
- package/dist/{backend-BY8wlLEl.js → backend-tngXtWe4.js} +148 -18
- package/dist/index.cjs +1683 -1004
- package/dist/index.d.cts +365 -95
- package/dist/index.d.ts +365 -95
- package/dist/index.js +1675 -997
- package/dist/{webgpu-C9iAP5h5.js → webgpu-ChVgx3b6.js} +400 -95
- package/dist/{webgpu-BVns4DbI.cjs → webgpu-Oj3Kd-kd.cjs} +400 -95
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -8,9 +8,9 @@ var __hasOwnProp = Object.prototype.hasOwnProperty;
|
|
|
8
8
|
var __commonJS = (cb, mod$1) => function() {
|
|
9
9
|
return mod$1 || (0, cb[__getOwnPropNames(cb)[0]])((mod$1 = { exports: {} }).exports, mod$1), mod$1.exports;
|
|
10
10
|
};
|
|
11
|
-
var __export = (target, all) => {
|
|
12
|
-
for (var name in all) __defProp(target, name, {
|
|
13
|
-
get: all[name],
|
|
11
|
+
var __export = (target, all$1) => {
|
|
12
|
+
for (var name in all$1) __defProp(target, name, {
|
|
13
|
+
get: all$1[name],
|
|
14
14
|
enumerable: true
|
|
15
15
|
});
|
|
16
16
|
};
|
|
@@ -30,7 +30,7 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-Bu9GY6sK.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
@@ -362,6 +362,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
362
362
|
Primitive$1["Mul"] = "mul";
|
|
363
363
|
Primitive$1["Idiv"] = "idiv";
|
|
364
364
|
Primitive$1["Mod"] = "mod";
|
|
365
|
+
Primitive$1["Min"] = "min";
|
|
366
|
+
Primitive$1["Max"] = "max";
|
|
365
367
|
Primitive$1["Neg"] = "neg";
|
|
366
368
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
367
369
|
Primitive$1["Floor"] = "floor";
|
|
@@ -369,7 +371,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
369
371
|
Primitive$1["StopGradient"] = "stop_gradient";
|
|
370
372
|
Primitive$1["Cast"] = "cast";
|
|
371
373
|
Primitive$1["Bitcast"] = "bitcast";
|
|
372
|
-
Primitive$1["RandomBits"] = "random_bits";
|
|
373
374
|
Primitive$1["Sin"] = "sin";
|
|
374
375
|
Primitive$1["Cos"] = "cos";
|
|
375
376
|
Primitive$1["Asin"] = "asin";
|
|
@@ -379,8 +380,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
379
380
|
Primitive$1["Erf"] = "erf";
|
|
380
381
|
Primitive$1["Erfc"] = "erfc";
|
|
381
382
|
Primitive$1["Sqrt"] = "sqrt";
|
|
382
|
-
Primitive$1["Min"] = "min";
|
|
383
|
-
Primitive$1["Max"] = "max";
|
|
384
383
|
Primitive$1["Reduce"] = "reduce";
|
|
385
384
|
Primitive$1["Dot"] = "dot";
|
|
386
385
|
Primitive$1["Conv"] = "conv";
|
|
@@ -388,14 +387,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
388
387
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
389
388
|
Primitive$1["Compare"] = "compare";
|
|
390
389
|
Primitive$1["Where"] = "where";
|
|
390
|
+
Primitive$1["RandomBits"] = "random_bits";
|
|
391
|
+
Primitive$1["Gather"] = "gather";
|
|
391
392
|
Primitive$1["Transpose"] = "transpose";
|
|
392
393
|
Primitive$1["Broadcast"] = "broadcast";
|
|
393
394
|
Primitive$1["Reshape"] = "reshape";
|
|
394
395
|
Primitive$1["Flip"] = "flip";
|
|
395
396
|
Primitive$1["Shrink"] = "shrink";
|
|
396
397
|
Primitive$1["Pad"] = "pad";
|
|
397
|
-
Primitive$1["
|
|
398
|
-
Primitive$1["
|
|
398
|
+
Primitive$1["Sort"] = "sort";
|
|
399
|
+
Primitive$1["Argsort"] = "argsort";
|
|
400
|
+
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
401
|
+
Primitive$1["Cholesky"] = "cholesky";
|
|
402
|
+
Primitive$1["Jit"] = "jit";
|
|
399
403
|
return Primitive$1;
|
|
400
404
|
}({});
|
|
401
405
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
@@ -417,6 +421,12 @@ function idiv(x, y) {
|
|
|
417
421
|
function mod(x, y) {
|
|
418
422
|
return bind1(Primitive.Mod, [x, y]);
|
|
419
423
|
}
|
|
424
|
+
function min$1(x, y) {
|
|
425
|
+
return bind1(Primitive.Min, [x, y]);
|
|
426
|
+
}
|
|
427
|
+
function max$1(x, y) {
|
|
428
|
+
return bind1(Primitive.Max, [x, y]);
|
|
429
|
+
}
|
|
420
430
|
function neg(x) {
|
|
421
431
|
return bind1(Primitive.Neg, [x]);
|
|
422
432
|
}
|
|
@@ -438,12 +448,6 @@ function cast(x, dtype) {
|
|
|
438
448
|
function bitcast(x, dtype) {
|
|
439
449
|
return bind1(Primitive.Bitcast, [x], { dtype });
|
|
440
450
|
}
|
|
441
|
-
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
442
|
-
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
443
|
-
shape: shape$1,
|
|
444
|
-
mode
|
|
445
|
-
});
|
|
446
|
-
}
|
|
447
451
|
function sin$1(x) {
|
|
448
452
|
return bind1(Primitive.Sin, [x]);
|
|
449
453
|
}
|
|
@@ -471,12 +475,6 @@ function erfc$1(x) {
|
|
|
471
475
|
function sqrt$1(x) {
|
|
472
476
|
return bind1(Primitive.Sqrt, [x]);
|
|
473
477
|
}
|
|
474
|
-
function min$1(x, y) {
|
|
475
|
-
return bind1(Primitive.Min, [x, y]);
|
|
476
|
-
}
|
|
477
|
-
function max$1(x, y) {
|
|
478
|
-
return bind1(Primitive.Max, [x, y]);
|
|
479
|
-
}
|
|
480
478
|
function reduce(x, op, axis = null, opts) {
|
|
481
479
|
if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
482
480
|
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
@@ -532,6 +530,23 @@ function where$1(cond, x, y) {
|
|
|
532
530
|
y
|
|
533
531
|
]);
|
|
534
532
|
}
|
|
533
|
+
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
534
|
+
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
535
|
+
shape: shape$1,
|
|
536
|
+
mode
|
|
537
|
+
});
|
|
538
|
+
}
|
|
539
|
+
function gather(x, indices, axis, outDim) {
|
|
540
|
+
if (indices.length === 0) throw new Error("gather() requires at least one index");
|
|
541
|
+
if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
|
|
542
|
+
axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
|
|
543
|
+
if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
|
|
544
|
+
outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
|
|
545
|
+
return bind1(Primitive.Gather, [x, ...indices], {
|
|
546
|
+
axis,
|
|
547
|
+
outDim
|
|
548
|
+
});
|
|
549
|
+
}
|
|
535
550
|
function transpose$1(x, perm) {
|
|
536
551
|
perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
|
|
537
552
|
if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
@@ -581,16 +596,27 @@ function pad$1(x, width) {
|
|
|
581
596
|
} else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
|
|
582
597
|
return bind1(Primitive.Pad, [x], { width });
|
|
583
598
|
}
|
|
584
|
-
function
|
|
585
|
-
if (
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
599
|
+
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
600
|
+
if (lower) {
|
|
601
|
+
a = flip$1(a, [-2, -1]);
|
|
602
|
+
b = flip$1(b, [-1]);
|
|
603
|
+
}
|
|
604
|
+
let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
605
|
+
if (lower) x = flip$1(x, [-1]);
|
|
606
|
+
return x;
|
|
607
|
+
}
|
|
608
|
+
function cholesky$2(x) {
|
|
609
|
+
return bind1(Primitive.Cholesky, [x]);
|
|
610
|
+
}
|
|
611
|
+
function sort$1(x) {
|
|
612
|
+
const nd = ndim$1(x);
|
|
613
|
+
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
614
|
+
return bind1(Primitive.Sort, [x]);
|
|
615
|
+
}
|
|
616
|
+
function argsort$1(x) {
|
|
617
|
+
const nd = ndim$1(x);
|
|
618
|
+
if (nd === 0) throw new Error("argsort: requires at least 1D input");
|
|
619
|
+
return bind(Primitive.Argsort, [x]);
|
|
594
620
|
}
|
|
595
621
|
function bind1(prim, args, params = {}) {
|
|
596
622
|
const [results] = bind(prim, args, params);
|
|
@@ -753,7 +779,7 @@ var Tracer = class Tracer {
|
|
|
753
779
|
if (require_backend.isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
|
|
754
780
|
return idiv(this, other);
|
|
755
781
|
}
|
|
756
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
782
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
757
783
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
758
784
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
759
785
|
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
@@ -806,6 +832,34 @@ var Tracer = class Tracer {
|
|
|
806
832
|
this.dispose();
|
|
807
833
|
}
|
|
808
834
|
/**
|
|
835
|
+
* Return a sorted copy of an array in ascending order.
|
|
836
|
+
*
|
|
837
|
+
* See `jax.numpy.sort` for full docs.
|
|
838
|
+
*/
|
|
839
|
+
sort(axis = -1) {
|
|
840
|
+
axis = require_backend.checkAxis(axis, this.ndim);
|
|
841
|
+
if (this.shape[axis] <= 1) return this;
|
|
842
|
+
if (axis === this.ndim - 1) return sort$1(this);
|
|
843
|
+
const perm = require_backend.range(this.ndim);
|
|
844
|
+
perm.splice(axis, 1);
|
|
845
|
+
perm.push(axis);
|
|
846
|
+
return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
|
|
847
|
+
}
|
|
848
|
+
/**
|
|
849
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
850
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
851
|
+
*
|
|
852
|
+
* See `jax.numpy.argsort` for full docs.
|
|
853
|
+
*/
|
|
854
|
+
argsort(axis = -1) {
|
|
855
|
+
axis = require_backend.checkAxis(axis, this.ndim);
|
|
856
|
+
if (axis === this.ndim - 1) return argsort$1(this)[1];
|
|
857
|
+
const perm = require_backend.range(this.ndim);
|
|
858
|
+
perm.splice(axis, 1);
|
|
859
|
+
perm.push(axis);
|
|
860
|
+
return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
|
|
861
|
+
}
|
|
862
|
+
/**
|
|
809
863
|
* Slice an array along one or more axes.
|
|
810
864
|
*
|
|
811
865
|
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
@@ -922,6 +976,9 @@ var ShapedArray = class ShapedArray {
|
|
|
922
976
|
get ndim() {
|
|
923
977
|
return this.shape.length;
|
|
924
978
|
}
|
|
979
|
+
get size() {
|
|
980
|
+
return require_backend.prod(this.shape);
|
|
981
|
+
}
|
|
925
982
|
toString() {
|
|
926
983
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
927
984
|
}
|
|
@@ -1221,13 +1278,13 @@ var Jaxpr = class Jaxpr {
|
|
|
1221
1278
|
}
|
|
1222
1279
|
return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
|
|
1223
1280
|
}
|
|
1224
|
-
/** Flattens nested
|
|
1281
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1225
1282
|
flatten() {
|
|
1226
|
-
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.
|
|
1283
|
+
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
|
|
1227
1284
|
const newEqns = [];
|
|
1228
1285
|
const varMap = /* @__PURE__ */ new Map();
|
|
1229
1286
|
const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
|
|
1230
|
-
for (const eqn of this.eqns) if (eqn.primitive === Primitive.
|
|
1287
|
+
for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
|
|
1231
1288
|
const jaxpr = eqn.params.jaxpr.flatten();
|
|
1232
1289
|
const translation = /* @__PURE__ */ new Map();
|
|
1233
1290
|
const translationF = (x) => x instanceof Var ? translation.get(x) : x;
|
|
@@ -1328,19 +1385,48 @@ function evalJaxpr(jaxpr, args) {
|
|
|
1328
1385
|
function jaxprAsFun(jaxpr) {
|
|
1329
1386
|
return (...args) => evalJaxpr(jaxpr, args);
|
|
1330
1387
|
}
|
|
1388
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1389
|
+
var ClosedJaxpr = class ClosedJaxpr {
|
|
1390
|
+
constructor(jaxpr, consts) {
|
|
1391
|
+
this.jaxpr = jaxpr;
|
|
1392
|
+
this.consts = consts;
|
|
1393
|
+
}
|
|
1394
|
+
/** String representation of this Jaxpr. */
|
|
1395
|
+
toString() {
|
|
1396
|
+
return this.jaxpr.toString();
|
|
1397
|
+
}
|
|
1398
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1399
|
+
mapJaxpr(f) {
|
|
1400
|
+
return new ClosedJaxpr(f(this.jaxpr), this.consts);
|
|
1401
|
+
}
|
|
1402
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1403
|
+
dispose() {
|
|
1404
|
+
for (const c of this.consts) c.dispose();
|
|
1405
|
+
}
|
|
1406
|
+
};
|
|
1331
1407
|
/** Tracer that records its operations to dynamically construct a Jaxpr. */
|
|
1332
1408
|
var JaxprTracer = class extends Tracer {
|
|
1409
|
+
#rc;
|
|
1333
1410
|
constructor(trace$1, aval) {
|
|
1334
1411
|
super(trace$1);
|
|
1335
1412
|
this.aval = aval;
|
|
1413
|
+
this.#rc = 1;
|
|
1336
1414
|
}
|
|
1337
1415
|
toString() {
|
|
1338
1416
|
return `JaxprTracer(${this.aval.toString()})`;
|
|
1339
1417
|
}
|
|
1340
1418
|
get ref() {
|
|
1419
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1420
|
+
this.#rc++;
|
|
1341
1421
|
return this;
|
|
1342
1422
|
}
|
|
1343
|
-
dispose() {
|
|
1423
|
+
dispose() {
|
|
1424
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1425
|
+
this.#rc--;
|
|
1426
|
+
}
|
|
1427
|
+
trackLiftedConstant() {
|
|
1428
|
+
this.#rc++;
|
|
1429
|
+
}
|
|
1344
1430
|
};
|
|
1345
1431
|
/** Analogous to the 'DynamicJaxprTrace' class in JAX. */
|
|
1346
1432
|
var JaxprTrace = class extends Trace {
|
|
@@ -1353,17 +1439,24 @@ var JaxprTrace = class extends Trace {
|
|
|
1353
1439
|
}
|
|
1354
1440
|
/** Register a constant / literal in this Jaxpr. */
|
|
1355
1441
|
getOrMakeConstTracer(val) {
|
|
1442
|
+
if (!(val instanceof Tracer)) val = pureArray(val);
|
|
1356
1443
|
let tracer = this.builder.constTracers.get(val);
|
|
1357
1444
|
if (tracer === void 0) {
|
|
1358
1445
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
1359
|
-
this.builder.addConst(tracer, val
|
|
1446
|
+
this.builder.addConst(tracer, val);
|
|
1447
|
+
} else {
|
|
1448
|
+
val.dispose();
|
|
1449
|
+
tracer.trackLiftedConstant();
|
|
1360
1450
|
}
|
|
1361
1451
|
return tracer;
|
|
1362
1452
|
}
|
|
1363
1453
|
pure = this.getOrMakeConstTracer;
|
|
1364
1454
|
lift = this.getOrMakeConstTracer;
|
|
1365
1455
|
processPrimitive(primitive, tracers, params) {
|
|
1366
|
-
const avalsIn = tracers.map((t) =>
|
|
1456
|
+
const avalsIn = tracers.map((t) => {
|
|
1457
|
+
t.dispose();
|
|
1458
|
+
return t.aval;
|
|
1459
|
+
});
|
|
1367
1460
|
const avalsOut = abstractEvalRules[primitive](avalsIn, params);
|
|
1368
1461
|
const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
|
|
1369
1462
|
this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
|
|
@@ -1406,20 +1499,17 @@ var JaxprBuilder = class {
|
|
|
1406
1499
|
return v;
|
|
1407
1500
|
}
|
|
1408
1501
|
build(inTracers, outTracers) {
|
|
1409
|
-
|
|
1502
|
+
const [constVars, consts] = require_backend.unzip2(this.constVals.entries());
|
|
1410
1503
|
const t2v = this.getVar.bind(this);
|
|
1411
1504
|
const inBinders = [...constVars, ...inTracers.map(t2v)];
|
|
1412
1505
|
const outVars = outTracers.map(t2v);
|
|
1413
|
-
|
|
1506
|
+
const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
|
|
1414
1507
|
typecheckJaxpr(jaxpr);
|
|
1415
|
-
|
|
1416
|
-
return
|
|
1417
|
-
jaxpr,
|
|
1418
|
-
consts
|
|
1419
|
-
};
|
|
1508
|
+
const cjaxpr = new ClosedJaxpr(jaxpr, consts);
|
|
1509
|
+
return _inlineLiterals(cjaxpr);
|
|
1420
1510
|
}
|
|
1421
1511
|
};
|
|
1422
|
-
function _inlineLiterals(jaxpr, consts) {
|
|
1512
|
+
function _inlineLiterals({ jaxpr, consts }) {
|
|
1423
1513
|
const literals = /* @__PURE__ */ new Map();
|
|
1424
1514
|
const constBinders = [];
|
|
1425
1515
|
const newConsts = [];
|
|
@@ -1434,7 +1524,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
1434
1524
|
const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
|
|
1435
1525
|
const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
|
|
1436
1526
|
typecheckJaxpr(newJaxpr);
|
|
1437
|
-
return
|
|
1527
|
+
return new ClosedJaxpr(newJaxpr, newConsts);
|
|
1438
1528
|
}
|
|
1439
1529
|
function binopAbstractEval([x, y]) {
|
|
1440
1530
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
@@ -1453,6 +1543,8 @@ const abstractEvalRules = {
|
|
|
1453
1543
|
[Primitive.Mul]: binopAbstractEval,
|
|
1454
1544
|
[Primitive.Idiv]: binopAbstractEval,
|
|
1455
1545
|
[Primitive.Mod]: binopAbstractEval,
|
|
1546
|
+
[Primitive.Min]: binopAbstractEval,
|
|
1547
|
+
[Primitive.Max]: binopAbstractEval,
|
|
1456
1548
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1457
1549
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1458
1550
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1466,12 +1558,6 @@ const abstractEvalRules = {
|
|
|
1466
1558
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1467
1559
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1468
1560
|
},
|
|
1469
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1470
|
-
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1471
|
-
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1472
|
-
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1473
|
-
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1474
|
-
},
|
|
1475
1561
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
1476
1562
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
1477
1563
|
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
@@ -1481,8 +1567,6 @@ const abstractEvalRules = {
|
|
|
1481
1567
|
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
1482
1568
|
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
1483
1569
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
1484
|
-
[Primitive.Min]: binopAbstractEval,
|
|
1485
|
-
[Primitive.Max]: binopAbstractEval,
|
|
1486
1570
|
[Primitive.Reduce]([x], { axis }) {
|
|
1487
1571
|
const axisSet = new Set(axis);
|
|
1488
1572
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1515,6 +1599,25 @@ const abstractEvalRules = {
|
|
|
1515
1599
|
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
1516
1600
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1517
1601
|
},
|
|
1602
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1603
|
+
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1604
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1605
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1606
|
+
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1607
|
+
},
|
|
1608
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
1609
|
+
for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
1610
|
+
if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
|
|
1611
|
+
if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
|
|
1612
|
+
if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
|
|
1613
|
+
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
1614
|
+
const axisSet = new Set(axis);
|
|
1615
|
+
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
1616
|
+
const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
|
|
1617
|
+
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
1618
|
+
newShape.splice(outDim, 0, ...gatherShape);
|
|
1619
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1620
|
+
},
|
|
1518
1621
|
[Primitive.Transpose]([x], { perm }) {
|
|
1519
1622
|
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
1520
1623
|
},
|
|
@@ -1535,23 +1638,31 @@ const abstractEvalRules = {
|
|
|
1535
1638
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
1536
1639
|
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1537
1640
|
},
|
|
1538
|
-
[Primitive.
|
|
1539
|
-
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
if (
|
|
1544
|
-
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1641
|
+
[Primitive.Sort]([x]) {
|
|
1642
|
+
if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
|
|
1643
|
+
return [ShapedArray.fromAval(x)];
|
|
1644
|
+
},
|
|
1645
|
+
[Primitive.Argsort]([x]) {
|
|
1646
|
+
if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
|
|
1647
|
+
return [ShapedArray.fromAval(x), new ShapedArray(x.shape, require_backend.DType.Int32, false)];
|
|
1648
|
+
},
|
|
1649
|
+
[Primitive.TriangularSolve]([a, b]) {
|
|
1650
|
+
if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
|
|
1651
|
+
if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
|
|
1652
|
+
const [m, n] = a.shape.slice(-2);
|
|
1653
|
+
const [_batch, q] = b.shape.slice(-2);
|
|
1654
|
+
if (!require_backend.deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
|
|
1655
|
+
return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
|
|
1656
|
+
},
|
|
1657
|
+
[Primitive.Cholesky]([a]) {
|
|
1658
|
+
if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
|
|
1659
|
+
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1660
|
+
return [ShapedArray.fromAval(a)];
|
|
1550
1661
|
},
|
|
1551
|
-
[Primitive.
|
|
1662
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
1552
1663
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1553
|
-
if (args.length !== inTypes.length) throw new TypeError(`
|
|
1554
|
-
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`
|
|
1664
|
+
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
1665
|
+
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
|
|
1555
1666
|
return outTypes;
|
|
1556
1667
|
}
|
|
1557
1668
|
};
|
|
@@ -1587,11 +1698,10 @@ function makeJaxpr$1(f, opts) {
|
|
|
1587
1698
|
const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
1588
1699
|
const outs = fFlat(...tracersIn);
|
|
1589
1700
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
1590
|
-
const
|
|
1701
|
+
const jaxpr = builder.build(tracersIn, tracersOut);
|
|
1591
1702
|
if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
|
|
1592
1703
|
return {
|
|
1593
|
-
jaxpr: jaxpr.simplify(),
|
|
1594
|
-
consts,
|
|
1704
|
+
jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
|
|
1595
1705
|
treedef: outTree.value
|
|
1596
1706
|
};
|
|
1597
1707
|
} catch (_) {
|
|
@@ -1610,22 +1720,28 @@ function jit$1(f, opts) {
|
|
|
1610
1720
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
1611
1721
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
1612
1722
|
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
1613
|
-
const { jaxpr,
|
|
1614
|
-
const outs = bind(Primitive.
|
|
1723
|
+
const { jaxpr, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
1724
|
+
const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
|
|
1615
1725
|
name: f.name || "closure",
|
|
1616
|
-
jaxpr,
|
|
1617
|
-
numConsts: consts.length
|
|
1726
|
+
jaxpr: jaxpr.jaxpr,
|
|
1727
|
+
numConsts: jaxpr.consts.length
|
|
1618
1728
|
});
|
|
1619
1729
|
return unflatten(outTree, outs);
|
|
1620
1730
|
});
|
|
1621
1731
|
result.dispose = () => {
|
|
1622
|
-
for (const {
|
|
1732
|
+
for (const { jaxpr } of cache.values()) jaxpr.dispose();
|
|
1623
1733
|
};
|
|
1624
1734
|
return result;
|
|
1625
1735
|
}
|
|
1626
1736
|
|
|
1627
1737
|
//#endregion
|
|
1628
1738
|
//#region src/frontend/jit.ts
|
|
1739
|
+
const routinePrimitives = new Map([
|
|
1740
|
+
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1741
|
+
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1742
|
+
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1743
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky]
|
|
1744
|
+
]);
|
|
1629
1745
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1630
1746
|
var JitProgram = class {
|
|
1631
1747
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1640,9 +1756,14 @@ var JitProgram = class {
|
|
|
1640
1756
|
case "execute": {
|
|
1641
1757
|
const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
|
|
1642
1758
|
const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
|
|
1643
|
-
|
|
1759
|
+
const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
|
|
1760
|
+
if (step.source instanceof require_backend.Kernel) return require_backend.PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
|
|
1761
|
+
else if (step.source instanceof require_backend.Routine) return require_backend.PPrint.pp(`${executeText}, routine ${step.source.name}`);
|
|
1762
|
+
else {
|
|
1763
|
+
step.source;
|
|
1764
|
+
return require_backend.PPrint.pp(executeText);
|
|
1765
|
+
}
|
|
1644
1766
|
}
|
|
1645
|
-
case "const": return require_backend.PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
|
|
1646
1767
|
case "malloc": return require_backend.PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
|
|
1647
1768
|
case "incref": return require_backend.PPrint.pp(`incref ${step.input}`);
|
|
1648
1769
|
case "free": return require_backend.PPrint.pp(`free ${step.input}`);
|
|
@@ -1665,12 +1786,9 @@ var JitProgram = class {
|
|
|
1665
1786
|
const inputs$1 = step.inputs.map((id) => scope.get(id));
|
|
1666
1787
|
const outputs = step.outputs.map((id) => scope.get(id));
|
|
1667
1788
|
if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
|
|
1668
|
-
pending.push(new PendingExecute(this.backend, step.
|
|
1789
|
+
pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
|
|
1669
1790
|
break;
|
|
1670
1791
|
}
|
|
1671
|
-
case "const":
|
|
1672
|
-
scope.set(step.output, step.slot);
|
|
1673
|
-
break;
|
|
1674
1792
|
case "malloc": {
|
|
1675
1793
|
const slot = this.backend.malloc(step.size);
|
|
1676
1794
|
scope.set(step.output, slot);
|
|
@@ -1704,34 +1822,37 @@ var JitProgramBuilder = class {
|
|
|
1704
1822
|
this.#nextId = nargs;
|
|
1705
1823
|
this.steps = [];
|
|
1706
1824
|
}
|
|
1707
|
-
pushConst(slot) {
|
|
1708
|
-
const id = this.#nextId++;
|
|
1709
|
-
this.steps.push({
|
|
1710
|
-
type: "const",
|
|
1711
|
-
slot,
|
|
1712
|
-
output: id
|
|
1713
|
-
});
|
|
1714
|
-
return id;
|
|
1715
|
-
}
|
|
1716
1825
|
pushLit(lit) {
|
|
1717
|
-
const kernel = new require_backend.Kernel(0,
|
|
1826
|
+
const kernel = new require_backend.Kernel(0, lit.aval.size, require_backend.AluExp.const(lit.dtype, lit.value));
|
|
1718
1827
|
return this.pushKernel(kernel, []);
|
|
1719
1828
|
}
|
|
1720
|
-
|
|
1829
|
+
pushBuffer(size$1) {
|
|
1721
1830
|
const id = this.#nextId++;
|
|
1722
1831
|
this.steps.push({
|
|
1723
1832
|
type: "malloc",
|
|
1724
|
-
size:
|
|
1833
|
+
size: size$1,
|
|
1725
1834
|
output: id
|
|
1726
1835
|
});
|
|
1836
|
+
return id;
|
|
1837
|
+
}
|
|
1838
|
+
pushKernel(kernel, inputs) {
|
|
1839
|
+
const id = this.pushBuffer(kernel.bytes);
|
|
1727
1840
|
this.steps.push({
|
|
1728
1841
|
type: "execute",
|
|
1729
|
-
kernel,
|
|
1842
|
+
source: kernel,
|
|
1730
1843
|
inputs,
|
|
1731
1844
|
outputs: [id]
|
|
1732
1845
|
});
|
|
1733
1846
|
return id;
|
|
1734
1847
|
}
|
|
1848
|
+
pushRoutine(routine, inputs, outputs) {
|
|
1849
|
+
this.steps.push({
|
|
1850
|
+
type: "execute",
|
|
1851
|
+
source: routine,
|
|
1852
|
+
inputs,
|
|
1853
|
+
outputs
|
|
1854
|
+
});
|
|
1855
|
+
}
|
|
1735
1856
|
pushIncref(id) {
|
|
1736
1857
|
this.steps.push({
|
|
1737
1858
|
type: "incref",
|
|
@@ -1757,28 +1878,18 @@ var JitProgramBuilder = class {
|
|
|
1757
1878
|
}
|
|
1758
1879
|
};
|
|
1759
1880
|
const jitCompileCache = /* @__PURE__ */ new Map();
|
|
1760
|
-
function jitCompile(backend, jaxpr
|
|
1761
|
-
|
|
1762
|
-
for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
|
|
1763
|
-
const cacheKey = backend.type + require_backend.FpHash.hash(jaxpr, ...consts.map((c) => c.id));
|
|
1881
|
+
function jitCompile(backend, jaxpr) {
|
|
1882
|
+
const cacheKey = backend.type + "," + require_backend.FpHash.hash(jaxpr);
|
|
1764
1883
|
const cached = jitCompileCache.get(cacheKey);
|
|
1765
1884
|
if (cached) return cached;
|
|
1766
1885
|
if (require_backend.DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
|
|
1767
1886
|
jaxpr = jaxpr.flatten().simplify();
|
|
1768
|
-
const nargs = jaxpr.inBinders.length
|
|
1887
|
+
const nargs = jaxpr.inBinders.length;
|
|
1769
1888
|
const builder = new JitProgramBuilder(backend, nargs);
|
|
1770
1889
|
const blackNodes = splitGraphDataflow(backend, jaxpr);
|
|
1771
1890
|
const ctx = /* @__PURE__ */ new Map();
|
|
1772
|
-
for (let i = 0; i < consts.length; i++) {
|
|
1773
|
-
const v = jaxpr.inBinders[i];
|
|
1774
|
-
const slot = consts[i]._realizeSource();
|
|
1775
|
-
ctx.set(v, {
|
|
1776
|
-
type: "imm",
|
|
1777
|
-
arg: builder.pushConst(slot)
|
|
1778
|
-
});
|
|
1779
|
-
}
|
|
1780
1891
|
for (let i = 0; i < nargs; i++) {
|
|
1781
|
-
const v = jaxpr.inBinders[
|
|
1892
|
+
const v = jaxpr.inBinders[i];
|
|
1782
1893
|
ctx.set(v, {
|
|
1783
1894
|
type: "imm",
|
|
1784
1895
|
arg: i
|
|
@@ -1786,6 +1897,31 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1786
1897
|
}
|
|
1787
1898
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1788
1899
|
const eqn = jaxpr.eqns[i];
|
|
1900
|
+
if (routinePrimitives.has(eqn.primitive)) {
|
|
1901
|
+
const routine = new require_backend.Routine(routinePrimitives.get(eqn.primitive), {
|
|
1902
|
+
inputShapes: eqn.inputs.map((x) => x.aval.shape),
|
|
1903
|
+
inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
|
|
1904
|
+
outputShapes: eqn.outBinders.map((x) => x.aval.shape),
|
|
1905
|
+
outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
|
|
1906
|
+
}, eqn.params);
|
|
1907
|
+
const inputs = [];
|
|
1908
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1909
|
+
const jv = ctx.get(input);
|
|
1910
|
+
if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
|
|
1911
|
+
inputs.push(jv.arg);
|
|
1912
|
+
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1913
|
+
const outputs = [];
|
|
1914
|
+
for (const outVar$1 of eqn.outBinders) {
|
|
1915
|
+
const outId = builder.pushBuffer(outVar$1.aval.size * require_backend.byteWidth(outVar$1.aval.dtype));
|
|
1916
|
+
outputs.push(outId);
|
|
1917
|
+
ctx.set(outVar$1, {
|
|
1918
|
+
type: "imm",
|
|
1919
|
+
arg: outId
|
|
1920
|
+
});
|
|
1921
|
+
}
|
|
1922
|
+
builder.pushRoutine(routine, inputs, outputs);
|
|
1923
|
+
continue;
|
|
1924
|
+
}
|
|
1789
1925
|
const inputExps = [];
|
|
1790
1926
|
const inputAvals = [];
|
|
1791
1927
|
const inputArgs = [];
|
|
@@ -1840,7 +1976,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1840
1976
|
const outVar = eqn.outBinders[0];
|
|
1841
1977
|
if (blackNodes.has(outVar)) {
|
|
1842
1978
|
const nargs$1 = inputArgs.length;
|
|
1843
|
-
const size$1 =
|
|
1979
|
+
const size$1 = outVar.aval.size;
|
|
1844
1980
|
const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
|
|
1845
1981
|
const outId = builder.pushKernel(kernel, inputArgs);
|
|
1846
1982
|
ctx.set(outVar, {
|
|
@@ -1865,7 +2001,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1865
2001
|
if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
|
|
1866
2002
|
outputIds.push(jitValue.arg);
|
|
1867
2003
|
} else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
|
|
1868
|
-
const outputNeedsRef = new Set(
|
|
2004
|
+
const outputNeedsRef = new Set(require_backend.range(nargs));
|
|
1869
2005
|
for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
|
|
1870
2006
|
else outputNeedsRef.add(outputId);
|
|
1871
2007
|
builder.insertFreeSteps(outputIds);
|
|
@@ -1911,11 +2047,18 @@ function reshapeJit(fn) {
|
|
|
1911
2047
|
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
1912
2048
|
};
|
|
1913
2049
|
}
|
|
2050
|
+
/**
 * Create the placeholder JIT rule registered for routine-backed primitives
 * (Sort, Argsort, TriangularSolve, Cholesky in `jitRules`).
 *
 * The rule it returns never produces an expression: calling it always throws.
 *
 * @returns {() => never} A rule function that throws on every invocation.
 */
function routineNoJit() {
  const message = "jit: rule is not implemented for routines";
  return () => {
    throw new Error(message);
  };
}
|
|
1914
2055
|
const jitRules = {
|
|
1915
2056
|
[Primitive.Add]: broadcastedJit(([a, b]) => require_backend.AluExp.add(a, b)),
|
|
1916
2057
|
[Primitive.Mul]: broadcastedJit(([a, b]) => require_backend.AluExp.mul(a, b)),
|
|
1917
2058
|
[Primitive.Idiv]: broadcastedJit(([a, b]) => require_backend.AluExp.idiv(a, b)),
|
|
1918
2059
|
[Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
|
|
2060
|
+
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
2061
|
+
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1919
2062
|
[Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
|
|
1920
2063
|
[Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
|
|
1921
2064
|
[Primitive.Floor]: unopJit(require_backend.AluExp.floor),
|
|
@@ -1923,17 +2066,6 @@ const jitRules = {
|
|
|
1923
2066
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1924
2067
|
[Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
|
|
1925
2068
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
|
|
1926
|
-
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1927
|
-
const mapping = (st) => {
|
|
1928
|
-
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
|
|
1929
|
-
};
|
|
1930
|
-
const k0 = reshapeViews(keys[0], mapping);
|
|
1931
|
-
const k1 = reshapeViews(keys[1], mapping);
|
|
1932
|
-
const c0 = require_backend.AluExp.u32(0);
|
|
1933
|
-
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1934
|
-
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1935
|
-
return { exp: exp$2 };
|
|
1936
|
-
},
|
|
1937
2069
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
1938
2070
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
1939
2071
|
[Primitive.Asin]: unopJit(require_backend.AluExp.asin),
|
|
@@ -1943,8 +2075,6 @@ const jitRules = {
|
|
|
1943
2075
|
[Primitive.Erf]: unopJit(require_backend.AluExp.erf),
|
|
1944
2076
|
[Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
|
|
1945
2077
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1946
|
-
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
1947
|
-
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1948
2078
|
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1949
2079
|
const keptAxes = [];
|
|
1950
2080
|
const shiftedAxes = [];
|
|
@@ -1994,16 +2124,17 @@ const jitRules = {
|
|
|
1994
2124
|
},
|
|
1995
2125
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1996
2126
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1997
|
-
[Primitive.
|
|
1998
|
-
|
|
1999
|
-
|
|
2000
|
-
|
|
2001
|
-
const
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
2127
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2128
|
+
const mapping = (st) => {
|
|
2129
|
+
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
|
|
2130
|
+
};
|
|
2131
|
+
const k0 = reshapeViews(keys[0], mapping);
|
|
2132
|
+
const k1 = reshapeViews(keys[1], mapping);
|
|
2133
|
+
const c0 = require_backend.AluExp.u32(0);
|
|
2134
|
+
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
2135
|
+
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2136
|
+
return { exp: exp$2 };
|
|
2137
|
+
},
|
|
2007
2138
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
2008
2139
|
const axisSet = new Set(axis);
|
|
2009
2140
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
@@ -2019,8 +2150,22 @@ const jitRules = {
|
|
|
2019
2150
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
2020
2151
|
return { exp: x.substitute({ gidx: index }) };
|
|
2021
2152
|
},
|
|
2022
|
-
[Primitive.
|
|
2023
|
-
|
|
2153
|
+
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2154
|
+
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
2155
|
+
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
2156
|
+
[Primitive.Flip]: reshapeJit((st, { axis }) => {
|
|
2157
|
+
const arg = require_backend.rep(st.shape.length, false);
|
|
2158
|
+
for (const ax of axis) arg[ax] = true;
|
|
2159
|
+
return st.flip(arg);
|
|
2160
|
+
}),
|
|
2161
|
+
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
2162
|
+
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
2163
|
+
[Primitive.Sort]: routineNoJit(),
|
|
2164
|
+
[Primitive.Argsort]: routineNoJit(),
|
|
2165
|
+
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2166
|
+
[Primitive.Cholesky]: routineNoJit(),
|
|
2167
|
+
[Primitive.Jit]() {
|
|
2168
|
+
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
2024
2169
|
}
|
|
2025
2170
|
};
|
|
2026
2171
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
@@ -2078,8 +2223,8 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2078
2223
|
case Primitive.Mul:
|
|
2079
2224
|
case Primitive.Idiv:
|
|
2080
2225
|
case Primitive.Mod:
|
|
2081
|
-
case Primitive.
|
|
2082
|
-
case Primitive.
|
|
2226
|
+
case Primitive.Min:
|
|
2227
|
+
case Primitive.Max: {
|
|
2083
2228
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2084
2229
|
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2085
2230
|
head = usages[0];
|
|
@@ -2099,11 +2244,11 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2099
2244
|
blackNodes.add(v);
|
|
2100
2245
|
p1NextBlack.set(v, v);
|
|
2101
2246
|
}
|
|
2102
|
-
const heterogeneousViewPrimitives = [Primitive.
|
|
2247
|
+
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2103
2248
|
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2104
2249
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2105
2250
|
const eqn = jaxpr.eqns[i];
|
|
2106
|
-
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2251
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2107
2252
|
for (const v of eqn.outBinders) {
|
|
2108
2253
|
blackNodes.add(v);
|
|
2109
2254
|
p1NextBlack.set(v, v);
|
|
@@ -2113,7 +2258,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2113
2258
|
const reach = /* @__PURE__ */ new Set();
|
|
2114
2259
|
let needsCleanOutput = false;
|
|
2115
2260
|
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2116
|
-
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
|
|
2261
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
|
|
2117
2262
|
needsCleanOutput = true;
|
|
2118
2263
|
break outer;
|
|
2119
2264
|
}
|
|
@@ -2137,7 +2282,6 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2137
2282
|
while (p2idx < jaxpr.eqns.length) {
|
|
2138
2283
|
const eqn = jaxpr.eqns[p2idx++];
|
|
2139
2284
|
const deps = [];
|
|
2140
|
-
if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
|
|
2141
2285
|
for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
|
|
2142
2286
|
else deps.push(p2Deps.get(input));
|
|
2143
2287
|
else deps.push(/* @__PURE__ */ new Set());
|
|
@@ -2160,7 +2304,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2160
2304
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2161
2305
|
const assocVar = eqn.inputs[assocInput];
|
|
2162
2306
|
p2idx = varToDefn.get(assocVar);
|
|
2163
|
-
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2307
|
+
for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
|
|
2164
2308
|
} else {
|
|
2165
2309
|
const s = new Set(depCounter.keys());
|
|
2166
2310
|
for (const out of eqn.outBinders) p2Deps.set(out, s);
|
|
@@ -2186,9 +2330,9 @@ var PendingExecute = class {
|
|
|
2186
2330
|
submitted = false;
|
|
2187
2331
|
#promise = null;
|
|
2188
2332
|
#rc = 1;
|
|
2189
|
-
constructor(backend,
|
|
2333
|
+
constructor(backend, source, inputs, outputs) {
|
|
2190
2334
|
this.backend = backend;
|
|
2191
|
-
this.
|
|
2335
|
+
this.source = source;
|
|
2192
2336
|
this.inputs = inputs;
|
|
2193
2337
|
this.outputs = outputs;
|
|
2194
2338
|
for (const slot of inputs) this.backend.incRef(slot);
|
|
@@ -2209,13 +2353,15 @@ var PendingExecute = class {
|
|
|
2209
2353
|
return;
|
|
2210
2354
|
}
|
|
2211
2355
|
this.#promise = (async () => {
|
|
2212
|
-
this.prepared = await this.backend.
|
|
2356
|
+
if (this.source instanceof require_backend.Kernel) this.prepared = await this.backend.prepareKernel(this.source);
|
|
2357
|
+
else this.prepared = await this.backend.prepareRoutine(this.source);
|
|
2213
2358
|
})();
|
|
2214
2359
|
await this.#promise;
|
|
2215
2360
|
}
|
|
2216
2361
|
prepareSync() {
|
|
2217
2362
|
if (this.prepared) return;
|
|
2218
|
-
this.prepared = this.backend.
|
|
2363
|
+
if (this.source instanceof require_backend.Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
|
|
2364
|
+
else this.prepared = this.backend.prepareRoutineSync(this.source);
|
|
2219
2365
|
}
|
|
2220
2366
|
submit() {
|
|
2221
2367
|
if (this.submitted) return;
|
|
@@ -2238,8 +2384,6 @@ var PendingExecute = class {
|
|
|
2238
2384
|
* "Array" type by name.
|
|
2239
2385
|
*/
|
|
2240
2386
|
var Array$1 = class Array$1 extends Tracer {
|
|
2241
|
-
static #nextId = 1001;
|
|
2242
|
-
id;
|
|
2243
2387
|
#dtype;
|
|
2244
2388
|
#weakType;
|
|
2245
2389
|
#source;
|
|
@@ -2256,7 +2400,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2256
2400
|
*/
|
|
2257
2401
|
constructor(args) {
|
|
2258
2402
|
super(baseArrayTrace);
|
|
2259
|
-
this.id = Array$1.#nextId++;
|
|
2260
2403
|
this.#dtype = args.dtype;
|
|
2261
2404
|
this.#weakType = args.weakType;
|
|
2262
2405
|
this.#source = args.source;
|
|
@@ -2565,6 +2708,27 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2565
2708
|
pending
|
|
2566
2709
|
});
|
|
2567
2710
|
}
|
|
2711
|
+
/** Apply an operation with custom lowering to this array. */
|
|
2712
|
+
static #routine(routine, arrays, outputWeakType) {
|
|
2713
|
+
const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
|
|
2714
|
+
for (const ar of arrays) ar.#realize();
|
|
2715
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2716
|
+
const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
|
|
2717
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2718
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2719
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2720
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2721
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2722
|
+
return outputs.map((output, i) => new Array$1({
|
|
2723
|
+
source: output,
|
|
2724
|
+
st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
|
|
2725
|
+
dtype: routine.type.outputDtypes[i],
|
|
2726
|
+
weakType: outputWeakType[i],
|
|
2727
|
+
backend,
|
|
2728
|
+
committed,
|
|
2729
|
+
pending
|
|
2730
|
+
}));
|
|
2731
|
+
}
|
|
2568
2732
|
/**
|
|
2569
2733
|
* Normalizes this array into one backed by a `Slot`.
|
|
2570
2734
|
*
|
|
@@ -2725,6 +2889,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2725
2889
|
[Primitive.Mod]([x, y]) {
|
|
2726
2890
|
return [x.#binary(require_backend.AluOp.Mod, y)];
|
|
2727
2891
|
},
|
|
2892
|
+
[Primitive.Min]([x, y]) {
|
|
2893
|
+
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
2894
|
+
},
|
|
2895
|
+
[Primitive.Max]([x, y]) {
|
|
2896
|
+
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
2897
|
+
},
|
|
2728
2898
|
[Primitive.Neg]([x]) {
|
|
2729
2899
|
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
2730
2900
|
},
|
|
@@ -2761,25 +2931,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2761
2931
|
return [y];
|
|
2762
2932
|
}
|
|
2763
2933
|
},
|
|
2764
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2765
|
-
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2766
|
-
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2767
|
-
const c0 = zeros(shape$1, {
|
|
2768
|
-
dtype: require_backend.DType.Uint32,
|
|
2769
|
-
device: k0.device
|
|
2770
|
-
});
|
|
2771
|
-
const c1 = arange(0, require_backend.prod(shape$1), 1, {
|
|
2772
|
-
dtype: require_backend.DType.Uint32,
|
|
2773
|
-
device: k0.device
|
|
2774
|
-
}).reshape(shape$1);
|
|
2775
|
-
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2776
|
-
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2777
|
-
k0,
|
|
2778
|
-
k1,
|
|
2779
|
-
c0,
|
|
2780
|
-
c1
|
|
2781
|
-
])];
|
|
2782
|
-
},
|
|
2783
2934
|
[Primitive.Sin]([x]) {
|
|
2784
2935
|
return [x.#unary(require_backend.AluOp.Sin)];
|
|
2785
2936
|
},
|
|
@@ -2807,12 +2958,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2807
2958
|
[Primitive.Sqrt]([x]) {
|
|
2808
2959
|
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
2809
2960
|
},
|
|
2810
|
-
[Primitive.Min]([x, y]) {
|
|
2811
|
-
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
2812
|
-
},
|
|
2813
|
-
[Primitive.Max]([x, y]) {
|
|
2814
|
-
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
2815
|
-
},
|
|
2816
2961
|
[Primitive.Reduce]([x], { op, axis }) {
|
|
2817
2962
|
if (axis.length === 0) return [x];
|
|
2818
2963
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
@@ -2847,13 +2992,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2847
2992
|
y
|
|
2848
2993
|
], { dtypeOverride: [require_backend.DType.Bool] })];
|
|
2849
2994
|
},
|
|
2850
|
-
[Primitive.
|
|
2851
|
-
|
|
2852
|
-
|
|
2853
|
-
|
|
2854
|
-
|
|
2855
|
-
|
|
2856
|
-
|
|
2995
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2996
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2997
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2998
|
+
const c0 = zeros(shape$1, {
|
|
2999
|
+
dtype: require_backend.DType.Uint32,
|
|
3000
|
+
device: k0.device
|
|
3001
|
+
});
|
|
3002
|
+
const c1 = arange(0, require_backend.prod(shape$1), 1, {
|
|
3003
|
+
dtype: require_backend.DType.Uint32,
|
|
3004
|
+
device: k0.device
|
|
3005
|
+
}).reshape(shape$1);
|
|
3006
|
+
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
3007
|
+
return [Array$1.#naryCustom("random_bits", custom, [
|
|
3008
|
+
k0,
|
|
3009
|
+
k1,
|
|
3010
|
+
c0,
|
|
3011
|
+
c1
|
|
3012
|
+
])];
|
|
3013
|
+
},
|
|
3014
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
3015
|
+
return [x.#gather(indices, axis, outDim)];
|
|
3016
|
+
},
|
|
3017
|
+
[Primitive.Transpose]([x], { perm }) {
|
|
3018
|
+
return [x.#transpose(perm)];
|
|
3019
|
+
},
|
|
3020
|
+
[Primitive.Broadcast]([x], { shape: shape$1, axis }) {
|
|
3021
|
+
return [x.#reshape(x.#st.broadcast(shape$1, axis))];
|
|
3022
|
+
},
|
|
3023
|
+
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2857
3024
|
return [x.#reshape(x.#st.reshape(shape$1))];
|
|
2858
3025
|
},
|
|
2859
3026
|
[Primitive.Flip]([x], { axis }) {
|
|
@@ -2867,17 +3034,48 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2867
3034
|
[Primitive.Pad]([x], { width }) {
|
|
2868
3035
|
return [x.#reshape(x.#st.pad(width))];
|
|
2869
3036
|
},
|
|
2870
|
-
[Primitive.
|
|
2871
|
-
|
|
3037
|
+
[Primitive.Sort]([x]) {
|
|
3038
|
+
const routine = new require_backend.Routine(require_backend.Routines.Sort, {
|
|
3039
|
+
inputShapes: [x.aval.shape],
|
|
3040
|
+
inputDtypes: [x.aval.dtype],
|
|
3041
|
+
outputShapes: [x.aval.shape],
|
|
3042
|
+
outputDtypes: [x.aval.dtype]
|
|
3043
|
+
});
|
|
3044
|
+
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
2872
3045
|
},
|
|
2873
|
-
[Primitive.
|
|
2874
|
-
|
|
2875
|
-
|
|
3046
|
+
[Primitive.Argsort]([x]) {
|
|
3047
|
+
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3048
|
+
inputShapes: [x.aval.shape],
|
|
3049
|
+
inputDtypes: [x.aval.dtype],
|
|
3050
|
+
outputShapes: [x.aval.shape, x.aval.shape],
|
|
3051
|
+
outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
|
|
3052
|
+
});
|
|
3053
|
+
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3054
|
+
},
|
|
3055
|
+
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3056
|
+
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3057
|
+
inputShapes: [a.aval.shape, b.aval.shape],
|
|
3058
|
+
inputDtypes: [a.aval.dtype, b.aval.dtype],
|
|
3059
|
+
outputShapes: [b.aval.shape],
|
|
3060
|
+
outputDtypes: [b.aval.dtype]
|
|
3061
|
+
}, { unitDiagonal });
|
|
3062
|
+
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3063
|
+
},
|
|
3064
|
+
[Primitive.Cholesky]([a]) {
|
|
3065
|
+
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3066
|
+
inputShapes: [a.aval.shape],
|
|
3067
|
+
inputDtypes: [a.aval.dtype],
|
|
3068
|
+
outputShapes: [a.aval.shape],
|
|
3069
|
+
outputDtypes: [a.aval.dtype]
|
|
3070
|
+
});
|
|
3071
|
+
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3072
|
+
},
|
|
3073
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
3074
|
+
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3075
|
+
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
2876
3076
|
args = args.map((ar) => ar._putSync(backend));
|
|
2877
|
-
const
|
|
2878
|
-
const
|
|
2879
|
-
const jp = jitCompile(backend, jaxpr, consts);
|
|
2880
|
-
const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
|
|
3077
|
+
const jp = jitCompile(backend, jaxpr);
|
|
3078
|
+
const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
|
|
2881
3079
|
for (const exe of pending) exe.updateRc(+outputs.length - 1);
|
|
2882
3080
|
const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
|
|
2883
3081
|
for (const exe of prevPending) exe.updateRc(+outputs.length);
|
|
@@ -3176,6 +3374,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
3176
3374
|
});
|
|
3177
3375
|
}
|
|
3178
3376
|
/**
 * Build an `n` x `m` array holding ones on and below a chosen diagonal and
 * zeros everywhere else.
 *
 * `k` selects the diagonal: `k=0` is the main diagonal, `k<0` lies below it,
 * and `k>0` lies above it. Entry `(i, j)` is one exactly when `i + k >= j`.
 *
 * @param n Number of rows; must be a non-negative integer.
 * @param m Number of columns; falls back to `n` when null or undefined.
 * @param k Diagonal offset; must be an integer. Defaults to 0.
 * @param options `dtype` of the result (Float32 when omitted) and the
 *   `device` forwarded to `arange`.
 * @throws {Error} If `n`, `m`, or `k` fails its integer check.
 */
function tri(n, m, k = 0, { dtype, device } = {}) {
  const numCols = m ?? n;
  const outDtype = dtype ?? require_backend.DType.Float32;
  if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
  if (!Number.isInteger(numCols) || numCols < 0) throw new Error(`tri: m must be a non-negative integer, got ${numCols}`);
  if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
  const int32Range = (start, stop) => arange(start, stop, 1, {
    dtype: require_backend.DType.Int32,
    device
  });
  // Row indices are pre-shifted by k so one comparison handles every offset.
  const shiftedRows = int32Range(k, n + k);
  const colIndices = int32Range(0, numCols);
  const mask = shiftedRows.reshape([n, 1]).greaterEqual(colIndices);
  return mask.astype(outDtype);
}
|
|
3399
|
+
/**
 * Return the lower triangle of an array: entries above diagonal `k` are
 * replaced by zeros. The input must have at least 2 dimensions.
 *
 * @param a Input array (passed through `fudgeArray` before use).
 * @param k Diagonal offset handed to `tri`. Defaults to 0.
 * @throws {Error} If the input has fewer than 2 dimensions.
 */
function tril(a, k = 0) {
  const rank = ndim$1(a);
  if (rank < 2) throw new Error(`tril: input array must be at least 2D, got ${rank}D`);
  const arr = fudgeArray(a);
  // The mask is built from the two trailing dimensions only.
  const [n, m] = arr.shape.slice(-2);
  const keep = tri(n, m, k, { dtype: require_backend.DType.Bool });
  // `arr.ref` is taken before `zerosLike$1(arr)` is evaluated, as in the original call.
  return where$1(keep, arr.ref, zerosLike$1(arr));
}
|
|
3406
|
+
/**
 * Return the upper triangle of an array: entries below diagonal `k` are
 * replaced by zeros. The input must have at least 2 dimensions.
 *
 * @param a Input array (passed through `fudgeArray` before use).
 * @param k Diagonal offset. Defaults to 0; elements on and above diagonal `k`
 *   are kept.
 * @throws {Error} If the input has fewer than 2 dimensions.
 */
function triu(a, k = 0) {
  // The message names `triu` so failures are attributed to the right function.
  if (ndim$1(a) < 2) throw new Error(`triu: input array must be at least 2D, got ${ndim$1(a)}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // tri(n, m, k - 1) marks everything strictly below diagonal k; those
  // positions are zeroed and the rest of `a` is kept.
  return where$1(tri(n, m, k - 1, { dtype: require_backend.DType.Bool }), zerosLike$1(a.ref), a);
}
|
|
3413
|
+
/**
|
|
3179
3414
|
* Return evenly spaced numbers over a specified interval.
|
|
3180
3415
|
*
|
|
3181
3416
|
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
@@ -3222,385 +3457,187 @@ function aluCompare(a, b, op) {
|
|
|
3222
3457
|
}
|
|
3223
3458
|
|
|
3224
3459
|
//#endregion
|
|
3225
|
-
//#region src/frontend/
|
|
3460
|
+
//#region src/frontend/vmap.ts
|
|
3226
3461
|
var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3227
|
-
|
|
3228
|
-
|
|
3462
|
+
/**
 * Compute the abstract value seen inside a vmapped function: the same dtype
 * and weak-type flag as `aval`, with the batch dimension removed from the
 * shape.
 *
 * @param batchDim Index of the batch dimension within `aval.shape`.
 * @param aval Abstract value of the batched array.
 * @returns A new `ShapedArray` without the batch dimension.
 */
function mappedAval(batchDim, aval) {
  const { dtype, weakType } = aval;
  // Work on a copy so the caller's shape array is left untouched.
  const unbatchedShape = aval.shape.slice();
  unbatchedShape.splice(batchDim, 1);
  return new ShapedArray(unbatchedShape, dtype, weakType);
}
|
|
3467
|
+
/**
 * Move one axis of an array to a different index, keeping the relative order
 * of all other axes.
 *
 * @param x Input value; converted with `pureArray`.
 * @param src Axis to move (validated with `checkAxis` against the rank).
 * @param dst Destination index (validated the same way).
 * @returns The array itself when both axes coincide, otherwise a transpose.
 */
function moveaxis(x, src, dst) {
  const arr = pureArray(x);
  const from = require_backend.checkAxis(src, arr.ndim);
  const to = require_backend.checkAxis(dst, arr.ndim);
  if (from === to) return arr;
  // Start from the identity permutation, pull `from` out, re-insert it at `to`.
  const order = require_backend.range(arr.ndim);
  order.splice(from, 1);
  order.splice(to, 0, from);
  return transpose$1(arr, order);
}
|
|
3478
|
+
/**
 * Place the batch axis of `x` at index `dst`.
 *
 * @param axisSize Size of the batch axis, used when `x` is not yet batched.
 * @param src Current batch axis of `x`, or `null` when `x` is unbatched.
 * @param dst Index the batch axis should end up at.
 * @param x Array to adjust.
 * @returns `x` unchanged when the axis is already in place; otherwise a
 *   broadcast (unbatched input) or an axis move (batched input).
 */
function moveBatchAxis(axisSize, src, dst, x) {
  if (src !== null) return src === dst ? x : moveaxis(x, src, dst);
  // Unbatched input: materialize a new axis of length axisSize at dst.
  const targetShape = x.shape.slice();
  targetShape.splice(dst, 0, axisSize);
  return broadcast(x, targetShape, [dst]);
}
|
|
3486
|
+
var BatchTracer = class extends Tracer {
|
|
3487
|
+
constructor(trace$1, val, batchDim) {
|
|
3229
3488
|
super(trace$1);
|
|
3230
|
-
this.
|
|
3231
|
-
this.
|
|
3489
|
+
this.val = val;
|
|
3490
|
+
this.batchDim = batchDim;
|
|
3232
3491
|
}
|
|
3233
3492
|
get aval() {
|
|
3234
|
-
return this.
|
|
3493
|
+
if (this.batchDim === null) return this.val.aval;
|
|
3494
|
+
else return mappedAval(this.batchDim, this.val.aval);
|
|
3235
3495
|
}
|
|
3236
3496
|
toString() {
|
|
3237
|
-
return `
|
|
3497
|
+
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3238
3498
|
}
|
|
3239
3499
|
get ref() {
|
|
3240
|
-
this.
|
|
3500
|
+
this.val.ref;
|
|
3241
3501
|
return this;
|
|
3242
3502
|
}
|
|
3243
3503
|
dispose() {
|
|
3244
|
-
this.
|
|
3245
|
-
|
|
3504
|
+
this.val.dispose();
|
|
3505
|
+
}
|
|
3506
|
+
fullLower() {
|
|
3507
|
+
if (this.batchDim === null) return this.val.fullLower();
|
|
3508
|
+
else return this;
|
|
3246
3509
|
}
|
|
3247
3510
|
};
|
|
3248
|
-
var
|
|
3511
|
+
var BatchTrace = class extends Trace {
|
|
3249
3512
|
pure(val) {
|
|
3250
3513
|
return this.lift(pureArray(val));
|
|
3251
3514
|
}
|
|
3252
3515
|
lift(val) {
|
|
3253
|
-
return new
|
|
3516
|
+
return new BatchTracer(this, val, null);
|
|
3254
3517
|
}
|
|
3255
3518
|
processPrimitive(primitive, tracers, params) {
|
|
3256
|
-
const [
|
|
3257
|
-
const
|
|
3258
|
-
if (
|
|
3259
|
-
|
|
3260
|
-
|
|
3519
|
+
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3520
|
+
const vmapRule = vmapRules[primitive];
|
|
3521
|
+
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3522
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3523
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3524
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3525
|
+
}
|
|
3526
|
+
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3527
|
+
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3528
|
+
}
|
|
3529
|
+
get axisSize() {
|
|
3530
|
+
return this.main.globalData;
|
|
3261
3531
|
}
|
|
3262
3532
|
};
|
|
3263
|
-
/**
|
|
3264
|
-
|
|
3265
|
-
|
|
3266
|
-
|
|
3267
|
-
|
|
3268
|
-
|
|
3269
|
-
|
|
3270
|
-
|
|
3271
|
-
|
|
3272
|
-
|
|
3273
|
-
|
|
3274
|
-
|
|
3275
|
-
|
|
3276
|
-
|
|
3533
|
+
/**
 * Build a vmap rule for a primitive that broadcasts its operands itself.
 *
 * If every batched operand already has its batch axis at the same position
 * counted from the end, and every unbatched operand has too few dimensions
 * to reach that position, `op` is applied to the operands as they are.
 * Otherwise each batched operand gets its batch axis moved to the front and
 * is padded with size-1 axes up to the common rank before `op` is applied.
 *
 * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
 */
function broadcastBatcher(op) {
  return (axisSize, args, dims) => {
    if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
    // Rank each operand would have once batched (unbatched ones count +1).
    const batchedRanks = args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0));
    const nd = Math.max(...batchedRanks);
    const firstIdx = dims.findIndex((d) => d !== null);
    // Batch axis of the first batched operand, as a negative offset from the end.
    const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
    const alreadyAligned = require_backend.zip(args, dims).every(([x, d]) => {
      if (d === null) return ndim$1(x) < -firstBdim;
      return d - x.ndim === firstBdim;
    });
    if (alreadyAligned) return [[op(...args)], [nd + firstBdim]];
    const aligned = args.map((x, i) => {
      if (dims[i] === null) return x;
      const moved = moveBatchAxis(axisSize, dims[i], 0, x);
      if (moved.ndim >= nd) return moved;
      // Insert size-1 axes right after the leading batch axis.
      return moved.reshape([
        moved.shape[0],
        ...require_backend.rep(nd - moved.ndim, 1),
        ...moved.shape.slice(1)
      ]);
    });
    return [[op(...aligned)], [0]];
  };
}
|
|
3279
|
-
|
|
3280
|
-
|
|
3281
|
-
|
|
3282
|
-
for (const t of tangents) t.dispose();
|
|
3283
|
-
const ys = bind(primitive, primals, params);
|
|
3284
|
-
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3558
|
+
/**
 * Build a vmap rule for a single-operand primitive: `op` is applied to the
 * operand as-is and the batch dimension is passed through unchanged.
 *
 * @param op Function called as `op(x, params)`.
 * @returns A rule `(axisSize, [x], [xBdim], params) => [[result], [xBdim]]`.
 */
function unopBatcher(op) {
  return (axisSize, operands, batchDims, params) => {
    const [x] = operands;
    const [xBdim] = batchDims;
    const result = op(x, params);
    return [[result], [xBdim]];
  };
}
|
|
3287
|
-
const
|
|
3288
|
-
[Primitive.Add]:
|
|
3289
|
-
[Primitive.Mul]:
|
|
3290
|
-
[Primitive.Idiv]:
|
|
3291
|
-
[Primitive.Mod](
|
|
3292
|
-
|
|
3293
|
-
|
|
3294
|
-
|
|
3295
|
-
|
|
3296
|
-
|
|
3297
|
-
|
|
3298
|
-
|
|
3563
|
+
const vmapRules = {
|
|
3564
|
+
[Primitive.Add]: broadcastBatcher(add$1),
|
|
3565
|
+
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3566
|
+
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3567
|
+
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3568
|
+
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3569
|
+
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3570
|
+
[Primitive.Neg]: unopBatcher(neg),
|
|
3571
|
+
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3572
|
+
[Primitive.Floor]: unopBatcher(floor$1),
|
|
3573
|
+
[Primitive.Ceil]: unopBatcher(ceil$1),
|
|
3574
|
+
[Primitive.StopGradient]: unopBatcher(stopGradient),
|
|
3575
|
+
[Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
|
|
3576
|
+
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3577
|
+
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3578
|
+
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3579
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3580
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3581
|
+
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3582
|
+
[Primitive.Log]: unopBatcher(log$1),
|
|
3583
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3584
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3585
|
+
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3586
|
+
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3587
|
+
require_backend.assertNonNull(xBdim);
|
|
3588
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3589
|
+
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3590
|
+
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3299
3591
|
},
|
|
3300
|
-
[Primitive.
|
|
3301
|
-
|
|
3302
|
-
|
|
3303
|
-
|
|
3592
|
+
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3593
|
+
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3594
|
+
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3595
|
+
const z = dot$2(x, y);
|
|
3596
|
+
return [[z], [z.ndim - 1]];
|
|
3304
3597
|
},
|
|
3305
|
-
[Primitive.
|
|
3306
|
-
|
|
3307
|
-
|
|
3308
|
-
|
|
3309
|
-
|
|
3310
|
-
|
|
3311
|
-
|
|
3312
|
-
|
|
3313
|
-
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3314
|
-
}
|
|
3598
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3599
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3600
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3601
|
+
const z = conv$1(x, y, {
|
|
3602
|
+
...params,
|
|
3603
|
+
vmapDims: params.vmapDims + 1
|
|
3604
|
+
});
|
|
3605
|
+
return [[z], [0]];
|
|
3315
3606
|
},
|
|
3316
|
-
[Primitive.
|
|
3317
|
-
|
|
3318
|
-
dx.dispose();
|
|
3319
|
-
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3607
|
+
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3608
|
+
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3320
3609
|
},
|
|
3321
|
-
[Primitive.
|
|
3322
|
-
[Primitive.
|
|
3323
|
-
|
|
3324
|
-
|
|
3325
|
-
|
|
3326
|
-
|
|
3327
|
-
|
|
3328
|
-
|
|
3329
|
-
|
|
3330
|
-
|
|
3331
|
-
},
|
|
3332
|
-
[Primitive.Atan]([x], [dx]) {
|
|
3333
|
-
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3334
|
-
return [[atan$1(x)], [dx.div(denom)]];
|
|
3335
|
-
},
|
|
3336
|
-
[Primitive.Exp]([x], [dx]) {
|
|
3337
|
-
const z = exp$1(x);
|
|
3338
|
-
return [[z.ref], [z.mul(dx)]];
|
|
3339
|
-
},
|
|
3340
|
-
[Primitive.Log]([x], [dx]) {
|
|
3341
|
-
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3342
|
-
},
|
|
3343
|
-
[Primitive.Erf]([x], [dx]) {
|
|
3344
|
-
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3345
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3346
|
-
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3347
|
-
},
|
|
3348
|
-
[Primitive.Erfc]([x], [dx]) {
|
|
3349
|
-
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3350
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3351
|
-
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3352
|
-
},
|
|
3353
|
-
[Primitive.Sqrt]([x], [dx]) {
|
|
3354
|
-
const z = sqrt$1(x);
|
|
3355
|
-
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3356
|
-
},
|
|
3357
|
-
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3358
|
-
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3359
|
-
},
|
|
3360
|
-
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3361
|
-
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3362
|
-
},
|
|
3363
|
-
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3364
|
-
if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3365
|
-
else if (op === require_backend.AluOp.Mul) {
|
|
3366
|
-
const primal = reduce(x.ref, op, axis);
|
|
3367
|
-
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3368
|
-
return [[primal], [tangent]];
|
|
3369
|
-
} else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
|
|
3370
|
-
const primal = reduce(x.ref, op, axis);
|
|
3371
|
-
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3372
|
-
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3373
|
-
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3374
|
-
return [[primal], [tangent]];
|
|
3375
|
-
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3376
|
-
},
|
|
3377
|
-
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3378
|
-
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3379
|
-
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3380
|
-
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3381
|
-
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3382
|
-
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3383
|
-
dcond.dispose();
|
|
3384
|
-
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3385
|
-
},
|
|
3386
|
-
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3387
|
-
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3388
|
-
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3389
|
-
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3390
|
-
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3391
|
-
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3392
|
-
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3393
|
-
const indicesRef = indices.map((t) => t.ref);
|
|
3394
|
-
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3395
|
-
},
|
|
3396
|
-
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3397
|
-
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3398
|
-
const outs = bind(Primitive.JitCall, [
|
|
3399
|
-
...newConsts.map((c) => c.ref),
|
|
3400
|
-
...primals,
|
|
3401
|
-
...tangents
|
|
3402
|
-
], {
|
|
3403
|
-
name: `${name}_jvp`,
|
|
3404
|
-
jaxpr: newJaxpr,
|
|
3405
|
-
numConsts: newConsts.length
|
|
3406
|
-
});
|
|
3407
|
-
const n = outs.length / 2;
|
|
3408
|
-
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3409
|
-
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3410
|
-
return [primalsOut, tangentsOut];
|
|
3411
|
-
}
|
|
3412
|
-
};
|
|
3413
|
-
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3414
|
-
function jvpJaxpr(jaxpr) {
|
|
3415
|
-
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3416
|
-
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3417
|
-
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3418
|
-
const result = {
|
|
3419
|
-
newJaxpr,
|
|
3420
|
-
newConsts
|
|
3421
|
-
};
|
|
3422
|
-
jvpJaxprCache.set(jaxpr, result);
|
|
3423
|
-
return result;
|
|
3424
|
-
}
|
|
3425
|
-
function jvpFlat(f, primals, tangents) {
|
|
3426
|
-
try {
|
|
3427
|
-
var _usingCtx$1 = (0, import_usingCtx$1.default)();
|
|
3428
|
-
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3429
|
-
const trace$1 = new JVPTrace(main);
|
|
3430
|
-
const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3431
|
-
const outs = f(...tracersIn);
|
|
3432
|
-
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3433
|
-
return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3434
|
-
} catch (_) {
|
|
3435
|
-
_usingCtx$1.e = _;
|
|
3436
|
-
} finally {
|
|
3437
|
-
_usingCtx$1.d();
|
|
3438
|
-
}
|
|
3439
|
-
}
|
|
3440
|
-
function jvp$1(f, primals, tangents) {
|
|
3441
|
-
const [primalsFlat, inTree] = flatten(primals);
|
|
3442
|
-
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
3443
|
-
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
3444
|
-
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
3445
|
-
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
3446
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
3447
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3448
|
-
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
3449
|
-
return [primalsOut, tangentsOut];
|
|
3450
|
-
}
|
|
3451
|
-
|
|
3452
|
-
//#endregion
|
|
3453
|
-
//#region src/frontend/vmap.ts
|
|
3454
|
-
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3455
|
-
function mappedAval(batchDim, aval) {
|
|
3456
|
-
const shape$1 = [...aval.shape];
|
|
3457
|
-
shape$1.splice(batchDim, 1);
|
|
3458
|
-
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3459
|
-
}
|
|
3460
|
-
/** Move one axis to a different index. */
|
|
3461
|
-
function moveaxis(x, src, dst) {
|
|
3462
|
-
const t = pureArray(x);
|
|
3463
|
-
src = require_backend.checkAxis(src, t.ndim);
|
|
3464
|
-
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3465
|
-
if (src === dst) return t;
|
|
3466
|
-
const perm = require_backend.range(t.ndim);
|
|
3467
|
-
perm.splice(src, 1);
|
|
3468
|
-
perm.splice(dst, 0, src);
|
|
3469
|
-
return transpose$1(t, perm);
|
|
3470
|
-
}
|
|
3471
|
-
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3472
|
-
if (src === null) {
|
|
3473
|
-
const targetShape = [...x.shape];
|
|
3474
|
-
targetShape.splice(dst, 0, axisSize);
|
|
3475
|
-
return broadcast(x, targetShape, [dst]);
|
|
3476
|
-
} else if (src === dst) return x;
|
|
3477
|
-
else return moveaxis(x, src, dst);
|
|
3478
|
-
}
|
|
3479
|
-
var BatchTracer = class extends Tracer {
|
|
3480
|
-
constructor(trace$1, val, batchDim) {
|
|
3481
|
-
super(trace$1);
|
|
3482
|
-
this.val = val;
|
|
3483
|
-
this.batchDim = batchDim;
|
|
3484
|
-
}
|
|
3485
|
-
get aval() {
|
|
3486
|
-
if (this.batchDim === null) return this.val.aval;
|
|
3487
|
-
else return mappedAval(this.batchDim, this.val.aval);
|
|
3488
|
-
}
|
|
3489
|
-
toString() {
|
|
3490
|
-
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3491
|
-
}
|
|
3492
|
-
get ref() {
|
|
3493
|
-
this.val.ref;
|
|
3494
|
-
return this;
|
|
3495
|
-
}
|
|
3496
|
-
dispose() {
|
|
3497
|
-
this.val.dispose();
|
|
3498
|
-
}
|
|
3499
|
-
fullLower() {
|
|
3500
|
-
if (this.batchDim === null) return this.val.fullLower();
|
|
3501
|
-
else return this;
|
|
3502
|
-
}
|
|
3503
|
-
};
|
|
3504
|
-
var BatchTrace = class extends Trace {
|
|
3505
|
-
pure(val) {
|
|
3506
|
-
return this.lift(pureArray(val));
|
|
3507
|
-
}
|
|
3508
|
-
lift(val) {
|
|
3509
|
-
return new BatchTracer(this, val, null);
|
|
3510
|
-
}
|
|
3511
|
-
processPrimitive(primitive, tracers, params) {
|
|
3512
|
-
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3513
|
-
const vmapRule = vmapRules[primitive];
|
|
3514
|
-
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3515
|
-
if (bdimsIn.every((d) => d === null)) {
|
|
3516
|
-
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3517
|
-
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3610
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3611
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3612
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3613
|
+
require_backend.assertNonNull(xBdim);
|
|
3614
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3615
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3616
|
+
let newOutDim = outDim;
|
|
3617
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3618
|
+
else newOutDim += 1;
|
|
3619
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3518
3620
|
}
|
|
3519
|
-
const
|
|
3520
|
-
|
|
3521
|
-
|
|
3522
|
-
|
|
3523
|
-
|
|
3524
|
-
|
|
3525
|
-
|
|
3526
|
-
|
|
3527
|
-
* Process a primitive with built-in broadcasting.
|
|
3528
|
-
*
|
|
3529
|
-
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3530
|
-
*/
|
|
3531
|
-
function broadcastBatcher(op) {
|
|
3532
|
-
return (axisSize, args, dims) => {
|
|
3533
|
-
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3534
|
-
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3535
|
-
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3536
|
-
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3537
|
-
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3538
|
-
args = args.map((x, i) => {
|
|
3539
|
-
if (dims[i] === null) return x;
|
|
3540
|
-
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3541
|
-
if (x.ndim < nd) x = x.reshape([
|
|
3542
|
-
x.shape[0],
|
|
3543
|
-
...require_backend.rep(nd - x.ndim, 1),
|
|
3544
|
-
...x.shape.slice(1)
|
|
3621
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3622
|
+
indices = indices.map((m, i) => {
|
|
3623
|
+
if (indicesBdim[i] === null) return m;
|
|
3624
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3625
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3626
|
+
m.shape[0],
|
|
3627
|
+
...require_backend.rep(nd - m.ndim, 1),
|
|
3628
|
+
...m.shape.slice(1)
|
|
3545
3629
|
]);
|
|
3546
|
-
return
|
|
3547
|
-
});
|
|
3548
|
-
return [[
|
|
3549
|
-
|
|
3550
|
-
|
|
3551
|
-
|
|
3552
|
-
|
|
3553
|
-
|
|
3554
|
-
|
|
3555
|
-
}
|
|
3556
|
-
const vmapRules = {
|
|
3557
|
-
[Primitive.Add]: broadcastBatcher(add$1),
|
|
3558
|
-
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3559
|
-
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3560
|
-
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3561
|
-
[Primitive.Neg]: unopBatcher(neg),
|
|
3562
|
-
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3563
|
-
[Primitive.Floor]: unopBatcher(floor$1),
|
|
3564
|
-
[Primitive.Ceil]: unopBatcher(ceil$1),
|
|
3565
|
-
[Primitive.StopGradient]: unopBatcher(stopGradient),
|
|
3566
|
-
[Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
|
|
3567
|
-
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3568
|
-
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3569
|
-
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3570
|
-
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3571
|
-
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3572
|
-
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3573
|
-
[Primitive.Log]: unopBatcher(log$1),
|
|
3574
|
-
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3575
|
-
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3576
|
-
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3577
|
-
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3578
|
-
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3579
|
-
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3580
|
-
require_backend.assertNonNull(xBdim);
|
|
3581
|
-
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3582
|
-
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3583
|
-
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3584
|
-
},
|
|
3585
|
-
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3586
|
-
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3587
|
-
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3588
|
-
const z = dot$2(x, y);
|
|
3589
|
-
return [[z], [z.ndim - 1]];
|
|
3590
|
-
},
|
|
3591
|
-
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3592
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3593
|
-
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3594
|
-
const z = conv$1(x, y, {
|
|
3595
|
-
...params,
|
|
3596
|
-
vmapDims: params.vmapDims + 1
|
|
3597
|
-
});
|
|
3598
|
-
return [[z], [0]];
|
|
3599
|
-
},
|
|
3600
|
-
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3601
|
-
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3630
|
+
return m;
|
|
3631
|
+
});
|
|
3632
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3633
|
+
else {
|
|
3634
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3635
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3636
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
|
|
3637
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3638
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3639
|
+
}
|
|
3602
3640
|
},
|
|
3603
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3604
3641
|
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3605
3642
|
require_backend.assertNonNull(xBdim);
|
|
3606
3643
|
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
@@ -3632,42 +3669,53 @@ const vmapRules = {
|
|
|
3632
3669
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3633
3670
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3634
3671
|
},
|
|
3635
|
-
[Primitive.
|
|
3636
|
-
|
|
3637
|
-
|
|
3638
|
-
|
|
3639
|
-
|
|
3640
|
-
|
|
3641
|
-
|
|
3642
|
-
|
|
3643
|
-
|
|
3644
|
-
|
|
3645
|
-
|
|
3646
|
-
|
|
3647
|
-
|
|
3648
|
-
|
|
3649
|
-
|
|
3650
|
-
|
|
3651
|
-
|
|
3652
|
-
...
|
|
3672
|
+
[Primitive.Sort](axisSize, [x], [xBdim]) {
|
|
3673
|
+
require_backend.assertNonNull(xBdim);
|
|
3674
|
+
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3675
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3676
|
+
return [[sort$1(x)], [0]];
|
|
3677
|
+
},
|
|
3678
|
+
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3679
|
+
require_backend.assertNonNull(xBdim);
|
|
3680
|
+
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3681
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3682
|
+
return [argsort$1(x), [0, 0]];
|
|
3683
|
+
},
|
|
3684
|
+
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3685
|
+
if (aBdim === null) {
|
|
3686
|
+
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
3687
|
+
const [s, m, n] = b.shape.slice(-3);
|
|
3688
|
+
b = b.reshape([
|
|
3689
|
+
...b.shape.slice(0, -3),
|
|
3690
|
+
s * m,
|
|
3691
|
+
n
|
|
3653
3692
|
]);
|
|
3654
|
-
|
|
3655
|
-
|
|
3656
|
-
|
|
3657
|
-
|
|
3658
|
-
|
|
3659
|
-
|
|
3660
|
-
|
|
3661
|
-
|
|
3662
|
-
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3693
|
+
let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3694
|
+
x$1 = x$1.reshape([
|
|
3695
|
+
...b.shape.slice(0, -2),
|
|
3696
|
+
s,
|
|
3697
|
+
m,
|
|
3698
|
+
n
|
|
3699
|
+
]);
|
|
3700
|
+
return [[x$1], [x$1.ndim - 3]];
|
|
3663
3701
|
}
|
|
3702
|
+
a = moveBatchAxis(axisSize, aBdim, 0, a);
|
|
3703
|
+
b = moveBatchAxis(axisSize, bBdim, 0, b);
|
|
3704
|
+
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3705
|
+
return [[x], [0]];
|
|
3664
3706
|
},
|
|
3665
|
-
[Primitive.
|
|
3666
|
-
|
|
3667
|
-
|
|
3707
|
+
[Primitive.Cholesky](axisSize, [x], [xBdim]) {
|
|
3708
|
+
require_backend.assertNonNull(xBdim);
|
|
3709
|
+
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3710
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3711
|
+
return [[cholesky$2(x)], [0]];
|
|
3712
|
+
},
|
|
3713
|
+
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3714
|
+
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3715
|
+
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
3668
3716
|
name: `${name}_vmap`,
|
|
3669
|
-
jaxpr: newJaxpr,
|
|
3670
|
-
numConsts:
|
|
3717
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3718
|
+
numConsts: newJaxpr.consts.length
|
|
3671
3719
|
});
|
|
3672
3720
|
return [outs, require_backend.rep(outs.length, 0)];
|
|
3673
3721
|
}
|
|
@@ -3683,14 +3731,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3683
3731
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3684
3732
|
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3685
3733
|
});
|
|
3686
|
-
const { jaxpr: newJaxpr
|
|
3687
|
-
const result = {
|
|
3688
|
-
newJaxpr,
|
|
3689
|
-
newConsts
|
|
3690
|
-
};
|
|
3734
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3691
3735
|
if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
3692
|
-
vmapJaxprCache.get(jaxpr).set(cacheKey,
|
|
3693
|
-
return
|
|
3736
|
+
vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
3737
|
+
return newJaxpr;
|
|
3694
3738
|
}
|
|
3695
3739
|
function vmapFlat(f, inAxes, args) {
|
|
3696
3740
|
let axisSize = void 0;
|
|
@@ -3704,7 +3748,7 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3704
3748
|
if (axisSize === void 0) throw new TypeError("vmap requires at least one mapped axis");
|
|
3705
3749
|
let valsOut, bdimsOut;
|
|
3706
3750
|
try {
|
|
3707
|
-
var _usingCtx$1 = (0, import_usingCtx.default)();
|
|
3751
|
+
var _usingCtx$1 = (0, import_usingCtx$1.default)();
|
|
3708
3752
|
const main = _usingCtx$1.u(newMain(BatchTrace, axisSize));
|
|
3709
3753
|
const trace$1 = new BatchTrace(main);
|
|
3710
3754
|
const tracersIn = args.map((x, i) => inAxes[i] === null ? pureArray(x) : new BatchTracer(trace$1, pureArray(x), inAxes[i]));
|
|
@@ -3745,6 +3789,261 @@ function jacfwd$1(f) {
|
|
|
3745
3789
|
};
|
|
3746
3790
|
}
|
|
3747
3791
|
|
|
3792
|
+
//#endregion
|
|
3793
|
+
//#region src/frontend/jvp.ts
|
|
3794
|
+
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3795
|
+
var JVPTracer = class extends Tracer {
|
|
3796
|
+
constructor(trace$1, primal, tangent) {
|
|
3797
|
+
super(trace$1);
|
|
3798
|
+
this.primal = primal;
|
|
3799
|
+
this.tangent = tangent;
|
|
3800
|
+
}
|
|
3801
|
+
get aval() {
|
|
3802
|
+
return this.primal.aval;
|
|
3803
|
+
}
|
|
3804
|
+
toString() {
|
|
3805
|
+
return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
|
|
3806
|
+
}
|
|
3807
|
+
get ref() {
|
|
3808
|
+
this.primal.ref, this.tangent.ref;
|
|
3809
|
+
return this;
|
|
3810
|
+
}
|
|
3811
|
+
dispose() {
|
|
3812
|
+
this.primal.dispose();
|
|
3813
|
+
this.tangent.dispose();
|
|
3814
|
+
}
|
|
3815
|
+
};
|
|
3816
|
+
var JVPTrace = class extends Trace {
|
|
3817
|
+
pure(val) {
|
|
3818
|
+
return this.lift(pureArray(val));
|
|
3819
|
+
}
|
|
3820
|
+
lift(val) {
|
|
3821
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
3822
|
+
}
|
|
3823
|
+
processPrimitive(primitive, tracers, params) {
|
|
3824
|
+
const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
3825
|
+
const jvpRule = jvpRules[primitive];
|
|
3826
|
+
if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
|
|
3827
|
+
const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
|
|
3828
|
+
return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
|
|
3829
|
+
}
|
|
3830
|
+
};
|
|
3831
|
+
/** Rule that applies the same operation to primals and tangents. */
|
|
3832
|
+
function linearTangentsJvp(primitive) {
|
|
3833
|
+
return (primals, tangents, params) => {
|
|
3834
|
+
const ys = bind(primitive, primals, params);
|
|
3835
|
+
const dys = bind(primitive, tangents, params);
|
|
3836
|
+
return [ys, dys];
|
|
3837
|
+
};
|
|
3838
|
+
}
|
|
3839
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
3840
|
+
function bilinearTangentsJvp(primitive) {
|
|
3841
|
+
return ([x, y], [dx, dy], params) => {
|
|
3842
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3843
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3844
|
+
return [[primal], [tangent]];
|
|
3845
|
+
};
|
|
3846
|
+
}
|
|
3847
|
+
/** Rule that zeros out any tangents. */
|
|
3848
|
+
function zeroTangentsJvp(primitive) {
|
|
3849
|
+
return (primals, tangents, params) => {
|
|
3850
|
+
for (const t of tangents) t.dispose();
|
|
3851
|
+
const ys = bind(primitive, primals, params);
|
|
3852
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3853
|
+
};
|
|
3854
|
+
}
|
|
3855
|
+
/** Compute `a @ b.T`, batched to last two axes. */
|
|
3856
|
+
function batchMatmulT(a, b) {
|
|
3857
|
+
return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
|
|
3858
|
+
}
|
|
3859
|
+
/** Batch matrix transpose. */
|
|
3860
|
+
function mT(a) {
|
|
3861
|
+
return moveaxis(a, -2, -1);
|
|
3862
|
+
}
|
|
3863
|
+
const jvpRules = {
|
|
3864
|
+
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3865
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
3866
|
+
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
3867
|
+
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
3868
|
+
if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
|
|
3869
|
+
dx.dispose();
|
|
3870
|
+
dy.dispose();
|
|
3871
|
+
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
3872
|
+
}
|
|
3873
|
+
const q = idiv(x.ref, y.ref);
|
|
3874
|
+
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
3875
|
+
},
|
|
3876
|
+
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3877
|
+
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3878
|
+
},
|
|
3879
|
+
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3880
|
+
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3881
|
+
},
|
|
3882
|
+
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
3883
|
+
[Primitive.Reciprocal]([x], [dx]) {
|
|
3884
|
+
const xRecip = reciprocal$1(x.ref);
|
|
3885
|
+
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
3886
|
+
},
|
|
3887
|
+
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
3888
|
+
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
3889
|
+
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
3890
|
+
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
3891
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3892
|
+
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
3893
|
+
else {
|
|
3894
|
+
dx.dispose();
|
|
3895
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3896
|
+
}
|
|
3897
|
+
},
|
|
3898
|
+
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3899
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3900
|
+
dx.dispose();
|
|
3901
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3902
|
+
},
|
|
3903
|
+
[Primitive.Sin]([x], [dx]) {
|
|
3904
|
+
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3905
|
+
},
|
|
3906
|
+
[Primitive.Cos]([x], [dx]) {
|
|
3907
|
+
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
3908
|
+
},
|
|
3909
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3910
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3911
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3912
|
+
},
|
|
3913
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3914
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3915
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3916
|
+
},
|
|
3917
|
+
[Primitive.Exp]([x], [dx]) {
|
|
3918
|
+
const z = exp$1(x);
|
|
3919
|
+
return [[z.ref], [z.mul(dx)]];
|
|
3920
|
+
},
|
|
3921
|
+
[Primitive.Log]([x], [dx]) {
|
|
3922
|
+
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3923
|
+
},
|
|
3924
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3925
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3926
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3927
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3928
|
+
},
|
|
3929
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3930
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3931
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3932
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3933
|
+
},
|
|
3934
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
3935
|
+
const z = sqrt$1(x);
|
|
3936
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3937
|
+
},
|
|
3938
|
+
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3939
|
+
if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3940
|
+
else if (op === require_backend.AluOp.Mul) {
|
|
3941
|
+
const primal = reduce(x.ref, op, axis);
|
|
3942
|
+
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3943
|
+
return [[primal], [tangent]];
|
|
3944
|
+
} else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
|
|
3945
|
+
const primal = reduce(x.ref, op, axis);
|
|
3946
|
+
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3947
|
+
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3948
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3949
|
+
return [[primal], [tangent]];
|
|
3950
|
+
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3951
|
+
},
|
|
3952
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3953
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3954
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3955
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3956
|
+
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3957
|
+
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3958
|
+
dcond.dispose();
|
|
3959
|
+
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3960
|
+
},
|
|
3961
|
+
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3962
|
+
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3963
|
+
const indicesRef = indices.map((t) => t.ref);
|
|
3964
|
+
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3965
|
+
},
|
|
3966
|
+
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3967
|
+
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3968
|
+
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3969
|
+
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3970
|
+
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3971
|
+
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3972
|
+
[Primitive.Sort]([x], [dx]) {
|
|
3973
|
+
const [y, idx] = argsort$1(x);
|
|
3974
|
+
return [[y], [gather(dx, [idx], [-1], -1)]];
|
|
3975
|
+
},
|
|
3976
|
+
[Primitive.Argsort]([x], [dx]) {
|
|
3977
|
+
const [y, idx] = argsort$1(x);
|
|
3978
|
+
return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
|
|
3979
|
+
},
|
|
3980
|
+
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
3981
|
+
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
3982
|
+
const dax = batchMatmulT(da, x.ref);
|
|
3983
|
+
const rhsT = db.sub(mT(dax));
|
|
3984
|
+
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
3985
|
+
return [[x], [dx]];
|
|
3986
|
+
},
|
|
3987
|
+
[Primitive.Cholesky]([a], [da]) {
|
|
3988
|
+
const L = cholesky$2(a.ref);
|
|
3989
|
+
da = da.ref.add(mT(da)).mul(.5);
|
|
3990
|
+
const W = triangularSolve$1(L.ref, da, { lower: true });
|
|
3991
|
+
const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
|
|
3992
|
+
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3993
|
+
return [[L], [dL]];
|
|
3994
|
+
},
|
|
3995
|
+
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3996
|
+
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3997
|
+
const outs = bind(Primitive.Jit, [
|
|
3998
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
3999
|
+
...primals,
|
|
4000
|
+
...tangents
|
|
4001
|
+
], {
|
|
4002
|
+
name: `${name}_jvp`,
|
|
4003
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4004
|
+
numConsts: newJaxpr.consts.length
|
|
4005
|
+
});
|
|
4006
|
+
const n = outs.length / 2;
|
|
4007
|
+
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
4008
|
+
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
4009
|
+
return [primalsOut, tangentsOut];
|
|
4010
|
+
}
|
|
4011
|
+
};
|
|
4012
|
+
/** Memo of derived JVP jaxprs, keyed by the source jaxpr object. */
const jvpJaxprCache = /* @__PURE__ */ new Map();
/**
 * Return the jaxpr computing the JVP of `jaxpr`, tracing it on first use.
 *
 * The traced function takes (primals, tangents), both described by the
 * avals of `jaxpr.inBinders`, and the result is cached per `jaxpr`.
 */
function jvpJaxpr(jaxpr) {
  if (!jvpJaxprCache.has(jaxpr)) {
    const inAvals = jaxpr.inBinders.map((binder) => binder.aval);
    const jvpOfJaxpr = (primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents);
    const traced = makeJaxpr$1(jvpOfJaxpr)(inAvals, inAvals);
    jvpJaxprCache.set(jaxpr, traced.jaxpr);
  }
  return jvpJaxprCache.get(jaxpr);
}
|
|
4020
|
+
/**
 * Run `f` under a fresh JVP trace on flat lists of primals and tangents.
 *
 * Each (primal, tangent) pair becomes one JVPTracer. The outputs of `f` are
 * raised into the trace and split back into primal and tangent lists.
 *
 * The try/catch/finally shape presumably comes from a compiled `using`
 * declaration (TODO confirm): the error is stored on the context and `d()`
 * always runs.
 */
function jvpFlat(f, primals, tangents) {
  let disposalCtx;
  try {
    disposalCtx = (0, import_usingCtx.default)();
    const main = disposalCtx.u(newMain(JVPTrace));
    const jvpTrace = new JVPTrace(main);
    const pairs = require_backend.zip(primals, tangents);
    const tracersIn = pairs.map(([primal, tangent]) => new JVPTracer(jvpTrace, pureArray(primal), pureArray(tangent)));
    const raised = f(...tracersIn).map((out) => fullRaise(jvpTrace, out));
    return require_backend.unzip2(raised.map((tracer) => [tracer.primal, tracer.tangent]));
  } catch (_) {
    disposalCtx.e = _;
  } finally {
    disposalCtx.d();
  }
}
|
|
4035
|
+
/**
 * Tree-aware JVP: flattens `primals` and `tangents`, checks that they share
 * one tree structure, runs the flat JVP, and rebuilds both outputs with the
 * output tree of `f`.
 *
 * @throws {TreeMismatchError} when primals and tangents differ in structure
 * @returns `[primalsOut, tangentsOut]`
 */
function jvp$1(f, primals, tangents) {
  const [primalsFlat, inTree] = flatten(primals);
  const [tangentsFlat, tangentTree] = flatten(tangents);
  if (!inTree.equals(tangentTree)) throw new TreeMismatchError("jvp", inTree, tangentTree);
  const [flatFun, outTree] = flattenFun(f, inTree);
  const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
  if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
  const rebuild = (flat) => unflatten(outTree.value, flat);
  return [rebuild(primalsOutFlat), rebuild(tangentsOutFlat)];
}
|
|
4046
|
+
|
|
3748
4047
|
//#endregion
|
|
3749
4048
|
//#region src/frontend/linearize.ts
|
|
3750
4049
|
/** Array value that can either be known or unknown. */
|
|
@@ -3775,11 +4074,10 @@ function partialEvalFlat(f, pvalsIn) {
|
|
|
3775
4074
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3776
4075
|
const pvalsOut = tracersOut.map((t) => t.pval);
|
|
3777
4076
|
const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
|
|
3778
|
-
const
|
|
4077
|
+
const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
|
|
3779
4078
|
return {
|
|
3780
4079
|
jaxpr,
|
|
3781
|
-
pvalsOut
|
|
3782
|
-
consts
|
|
4080
|
+
pvalsOut
|
|
3783
4081
|
};
|
|
3784
4082
|
}
|
|
3785
4083
|
/**
|
|
@@ -3796,22 +4094,19 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3796
4094
|
const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
|
|
3797
4095
|
return [...primalsOut$1, ...tangentsOut];
|
|
3798
4096
|
};
|
|
3799
|
-
const { jaxpr, pvalsOut
|
|
4097
|
+
const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
|
|
3800
4098
|
const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
|
|
3801
4099
|
if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
|
|
3802
4100
|
const primalsOut = primalPvals.map((pval) => pval.val);
|
|
3803
4101
|
return {
|
|
3804
4102
|
primalsOut,
|
|
3805
|
-
jaxpr
|
|
3806
|
-
consts
|
|
4103
|
+
jaxpr
|
|
3807
4104
|
};
|
|
3808
4105
|
}
|
|
3809
4106
|
function linearizeFlat(f, primalsIn) {
|
|
3810
|
-
const { primalsOut, jaxpr
|
|
3811
|
-
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3812
|
-
const dispose$1 = () =>
|
|
3813
|
-
for (const c of consts) c.dispose();
|
|
3814
|
-
};
|
|
4107
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4108
|
+
const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
|
|
4109
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
3815
4110
|
return [
|
|
3816
4111
|
primalsOut,
|
|
3817
4112
|
fLin,
|
|
@@ -3895,7 +4190,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3895
4190
|
}
|
|
3896
4191
|
processPrimitive(primitive, tracers, params) {
|
|
3897
4192
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3898
|
-
if (primitive === Primitive.
|
|
4193
|
+
if (primitive === Primitive.Jit) {
|
|
3899
4194
|
const { name, jaxpr, numConsts } = params;
|
|
3900
4195
|
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3901
4196
|
}
|
|
@@ -3921,14 +4216,14 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3921
4216
|
* Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
|
|
3922
4217
|
* values as possible (with JIT) and forwarding the unknown ones.
|
|
3923
4218
|
*
|
|
3924
|
-
* Used when encountering a
|
|
4219
|
+
* Used when encountering a Jit rule during the trace.
|
|
3925
4220
|
*/
|
|
3926
4221
|
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3927
4222
|
jaxpr = jaxpr.flatten();
|
|
3928
4223
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3929
4224
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3930
4225
|
const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
|
|
3931
|
-
const outs1Res = bind(Primitive.
|
|
4226
|
+
const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3932
4227
|
name: `${name}_peval`,
|
|
3933
4228
|
jaxpr: jaxpr1,
|
|
3934
4229
|
numConsts: 0
|
|
@@ -3938,7 +4233,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3938
4233
|
const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
|
|
3939
4234
|
const recipe = {
|
|
3940
4235
|
type: "JaxprEqn",
|
|
3941
|
-
prim: Primitive.
|
|
4236
|
+
prim: Primitive.Jit,
|
|
3942
4237
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3943
4238
|
params: {
|
|
3944
4239
|
name: `${name}_resid`,
|
|
@@ -3967,7 +4262,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
|
|
|
3967
4262
|
const eqns1 = [];
|
|
3968
4263
|
const eqns2 = [];
|
|
3969
4264
|
for (const eqn of jaxpr.eqns) {
|
|
3970
|
-
if (eqn.primitive === Primitive.
|
|
4265
|
+
if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
|
|
3971
4266
|
const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
|
|
3972
4267
|
if (hasUnknowns) {
|
|
3973
4268
|
for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
|
|
@@ -4042,10 +4337,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
4042
4337
|
for (const t of tracersOut) t.dispose();
|
|
4043
4338
|
jaxpr = jaxpr.simplify();
|
|
4044
4339
|
if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
4045
|
-
return
|
|
4046
|
-
jaxpr,
|
|
4047
|
-
consts
|
|
4048
|
-
};
|
|
4340
|
+
return new ClosedJaxpr(jaxpr, consts);
|
|
4049
4341
|
}
|
|
4050
4342
|
/** Marker type for pullback, used by transpose rules. */
|
|
4051
4343
|
var UndefPrimal = class {
|
|
@@ -4237,317 +4529,142 @@ const transposeRules = {
|
|
|
4237
4529
|
cond.dispose();
|
|
4238
4530
|
return cts;
|
|
4239
4531
|
},
|
|
4240
|
-
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4241
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4242
|
-
return [transpose$1(ct, require_backend.invertPermutation(perm))];
|
|
4243
|
-
},
|
|
4244
|
-
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4245
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4246
|
-
return [reduce(ct, require_backend.AluOp.Add, axis)];
|
|
4247
|
-
},
|
|
4248
|
-
[Primitive.Reshape]([ct], [x], _) {
|
|
4249
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4250
|
-
return [reshape$1(ct, x.aval.shape)];
|
|
4251
|
-
},
|
|
4252
|
-
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4253
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4254
|
-
return [flip$1(ct, axis)];
|
|
4255
|
-
},
|
|
4256
|
-
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4257
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4258
|
-
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4259
|
-
return [pad$1(ct, width)];
|
|
4260
|
-
},
|
|
4261
|
-
[Primitive.Pad]([ct], [x], { width }) {
|
|
4262
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
|
|
4263
|
-
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4264
|
-
return [shrink(ct, slice)];
|
|
4265
|
-
},
|
|
4266
4532
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4267
4533
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4268
4534
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4269
4535
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4270
4536
|
},
|
|
4271
|
-
[Primitive.
|
|
4272
|
-
|
|
4273
|
-
|
|
4274
|
-
|
|
4275
|
-
|
|
4276
|
-
|
|
4277
|
-
|
|
4278
|
-
|
|
4279
|
-
|
|
4280
|
-
|
|
4281
|
-
|
|
4282
|
-
|
|
4283
|
-
|
|
4284
|
-
|
|
4285
|
-
return
|
|
4286
|
-
}
|
|
4287
|
-
}
|
|
4288
|
-
|
|
4289
|
-
|
|
4290
|
-
|
|
4291
|
-
|
|
4292
|
-
|
|
4293
|
-
|
|
4294
|
-
|
|
4295
|
-
|
|
4296
|
-
|
|
4297
|
-
|
|
4298
|
-
|
|
4299
|
-
|
|
4300
|
-
|
|
4301
|
-
|
|
4302
|
-
typecheckJaxpr(newJaxpr);
|
|
4303
|
-
const result = {
|
|
4304
|
-
newJaxpr,
|
|
4305
|
-
newConsts
|
|
4306
|
-
};
|
|
4307
|
-
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4308
|
-
transposeJaxprCache.get(jaxpr).set(cacheKey, result);
|
|
4309
|
-
return result;
|
|
4310
|
-
}
|
|
4311
|
-
function vjpFlat(f, primalsIn) {
|
|
4312
|
-
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
4313
|
-
const fVjp = (...cotangents) => {
|
|
4314
|
-
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4315
|
-
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
4316
|
-
};
|
|
4317
|
-
const dispose$1 = () => {
|
|
4318
|
-
for (const c of consts) c.dispose();
|
|
4319
|
-
};
|
|
4320
|
-
return [
|
|
4321
|
-
primalsOut,
|
|
4322
|
-
fVjp,
|
|
4323
|
-
dispose$1
|
|
4324
|
-
];
|
|
4325
|
-
}
|
|
4326
|
-
function vjp$1(f, ...primalsIn) {
|
|
4327
|
-
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4328
|
-
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4329
|
-
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4330
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4331
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4332
|
-
const fVjp = ((cotangentsOut) => {
|
|
4333
|
-
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4334
|
-
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4335
|
-
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4336
|
-
return unflatten(inTree, cotangentsInFlat);
|
|
4337
|
-
});
|
|
4338
|
-
fVjp.dispose = dispose$1;
|
|
4339
|
-
return [primalsOut, fVjp];
|
|
4340
|
-
}
|
|
4341
|
-
function grad$1(f) {
|
|
4342
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4343
|
-
return (...x) => {
|
|
4344
|
-
const [y, dx] = valueAndGradFn(...x);
|
|
4345
|
-
y.dispose();
|
|
4346
|
-
return dx;
|
|
4347
|
-
};
|
|
4348
|
-
}
|
|
4349
|
-
function valueAndGrad$1(f) {
|
|
4350
|
-
return (...x) => {
|
|
4351
|
-
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4352
|
-
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4353
|
-
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4354
|
-
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4355
|
-
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4356
|
-
for (const r of rest) dispose(r);
|
|
4357
|
-
fVjp.dispose();
|
|
4358
|
-
return [y, ct];
|
|
4359
|
-
};
|
|
4360
|
-
}
|
|
4361
|
-
function jacrev$1(f) {
|
|
4362
|
-
return function jacobianReverse(x) {
|
|
4363
|
-
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4364
|
-
const [size$1] = x.shape;
|
|
4365
|
-
const pullback = (ct) => {
|
|
4366
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4367
|
-
y.dispose();
|
|
4368
|
-
const [ret] = fVjp(ct);
|
|
4369
|
-
fVjp.dispose();
|
|
4370
|
-
return ret;
|
|
4371
|
-
};
|
|
4372
|
-
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4373
|
-
};
|
|
4374
|
-
}
|
|
4375
|
-
|
|
4376
|
-
//#endregion
|
|
4377
|
-
//#region src/library/lax.ts
|
|
4378
|
-
var lax_exports = {};
|
|
4379
|
-
__export(lax_exports, {
|
|
4380
|
-
conv: () => conv,
|
|
4381
|
-
convGeneralDilated: () => convGeneralDilated,
|
|
4382
|
-
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4383
|
-
dot: () => dot$1,
|
|
4384
|
-
erf: () => erf,
|
|
4385
|
-
erfc: () => erfc,
|
|
4386
|
-
reduceWindow: () => reduceWindow,
|
|
4387
|
-
stopGradient: () => stopGradient$1
|
|
4388
|
-
});
|
|
4389
|
-
/**
|
|
4390
|
-
* General dot product/contraction operator.
|
|
4391
|
-
*
|
|
4392
|
-
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
4393
|
-
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
4394
|
-
*/
|
|
4395
|
-
function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
4396
|
-
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
4397
|
-
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
4398
|
-
lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
4399
|
-
rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
4400
|
-
lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
4401
|
-
rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
4402
|
-
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
4403
|
-
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
4404
|
-
const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
4405
|
-
const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
4406
|
-
const lhs2 = lhs.transpose([
|
|
4407
|
-
...lb,
|
|
4408
|
-
...lf,
|
|
4409
|
-
...lc
|
|
4410
|
-
]);
|
|
4411
|
-
const rhs2 = rhs.transpose([
|
|
4412
|
-
...rb,
|
|
4413
|
-
...rf,
|
|
4414
|
-
...rc
|
|
4415
|
-
]);
|
|
4416
|
-
if (lc.length === 0) return mul(lhs2.reshape([
|
|
4417
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4418
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4419
|
-
...require_backend.rep(rf.length, 1)
|
|
4420
|
-
]), rhs2.reshape([
|
|
4421
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4422
|
-
...require_backend.rep(lf.length, 1),
|
|
4423
|
-
...rf.map((a) => rhs.shape[a])
|
|
4424
|
-
]));
|
|
4425
|
-
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
4426
|
-
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
4427
|
-
if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
4428
|
-
return dot$2(lhs2.reshape([
|
|
4429
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4430
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4431
|
-
...require_backend.rep(rf.length, 1),
|
|
4432
|
-
require_backend.prod(dotShapeX)
|
|
4433
|
-
]), rhs2.reshape([
|
|
4434
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4435
|
-
...require_backend.rep(lf.length, 1),
|
|
4436
|
-
...rf.map((a) => rhs.shape[a]),
|
|
4437
|
-
require_backend.prod(dotShapeY)
|
|
4438
|
-
]));
|
|
4439
|
-
}
|
|
4440
|
-
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4441
|
-
const padType = padding.toUpperCase();
|
|
4442
|
-
switch (padType) {
|
|
4443
|
-
case "VALID": return require_backend.rep(inShape.length, [0, 0]);
|
|
4444
|
-
case "SAME":
|
|
4445
|
-
case "SAME_LOWER": {
|
|
4446
|
-
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
4447
|
-
const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
4448
|
-
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
4449
|
-
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
4450
|
-
}
|
|
4451
|
-
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
4452
|
-
}
|
|
4453
|
-
}
|
|
4454
|
-
/**
|
|
4455
|
-
* General n-dimensional convolution operator, with optional dilation.
|
|
4456
|
-
*
|
|
4457
|
-
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
4458
|
-
* function in JAX, which wraps XLA's general convolution operator.
|
|
4459
|
-
*
|
|
4460
|
-
* Grouped convolutions are not supported right now.
|
|
4461
|
-
*/
|
|
4462
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
4463
|
-
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4464
|
-
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4465
|
-
if (typeof padding === "string") {
|
|
4466
|
-
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4467
|
-
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
4468
|
-
}
|
|
4469
|
-
if (featureGroupCount !== 1) {
|
|
4470
|
-
const G = featureGroupCount;
|
|
4471
|
-
const [N, C_in, ...xs] = lhs.shape;
|
|
4472
|
-
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
4473
|
-
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
4474
|
-
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
4475
|
-
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
4476
|
-
const lhsGrouped = moveaxis(lhs.reshape([
|
|
4477
|
-
N,
|
|
4478
|
-
G,
|
|
4479
|
-
C_in / G,
|
|
4480
|
-
...xs
|
|
4481
|
-
]), 1, 0);
|
|
4482
|
-
const rhsGrouped = rhs.reshape([
|
|
4483
|
-
G,
|
|
4484
|
-
C_out / G,
|
|
4485
|
-
C_in_per_group,
|
|
4486
|
-
...ks
|
|
4487
|
-
]);
|
|
4488
|
-
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
4489
|
-
vmapDims: 1,
|
|
4490
|
-
strides: windowStrides,
|
|
4491
|
-
padding,
|
|
4492
|
-
lhsDilation,
|
|
4493
|
-
rhsDilation
|
|
4537
|
+
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4538
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4539
|
+
return [transpose$1(ct, require_backend.invertPermutation(perm))];
|
|
4540
|
+
},
|
|
4541
|
+
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4542
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4543
|
+
return [reduce(ct, require_backend.AluOp.Add, axis)];
|
|
4544
|
+
},
|
|
4545
|
+
[Primitive.Reshape]([ct], [x], _) {
|
|
4546
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4547
|
+
return [reshape$1(ct, x.aval.shape)];
|
|
4548
|
+
},
|
|
4549
|
+
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4550
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4551
|
+
return [flip$1(ct, axis)];
|
|
4552
|
+
},
|
|
4553
|
+
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4554
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4555
|
+
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4556
|
+
return [pad$1(ct, width)];
|
|
4557
|
+
},
|
|
4558
|
+
[Primitive.Pad]([ct], [x], { width }) {
|
|
4559
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
|
|
4560
|
+
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4561
|
+
return [shrink(ct, slice)];
|
|
4562
|
+
},
|
|
4563
|
+
[Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
|
|
4564
|
+
if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
|
|
4565
|
+
const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
|
|
4566
|
+
lower: true,
|
|
4567
|
+
unitDiagonal
|
|
4494
4568
|
});
|
|
4495
|
-
|
|
4496
|
-
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
]);
|
|
4569
|
+
return [null, ctB];
|
|
4570
|
+
},
|
|
4571
|
+
[Primitive.Jit](cts, args, { name, jaxpr }) {
|
|
4572
|
+
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
4573
|
+
const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
|
|
4574
|
+
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
4575
|
+
const outs = bind(Primitive.Jit, [
|
|
4576
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4577
|
+
...residuals,
|
|
4578
|
+
...cts
|
|
4579
|
+
], {
|
|
4580
|
+
name: `${name}_t`,
|
|
4581
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4582
|
+
numConsts: newJaxpr.consts.length
|
|
4583
|
+
});
|
|
4584
|
+
let i = 0;
|
|
4585
|
+
return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
|
|
4501
4586
|
}
|
|
4502
|
-
|
|
4503
|
-
|
|
4504
|
-
|
|
4505
|
-
|
|
4506
|
-
|
|
4507
|
-
|
|
4508
|
-
}
|
|
4509
|
-
|
|
4510
|
-
|
|
4511
|
-
|
|
4512
|
-
|
|
4513
|
-
|
|
4514
|
-
|
|
4587
|
+
};
|
|
4588
|
+
/** Two-level memo: source jaxpr -> (JSON of the undef-primal mask -> transposed jaxpr). */
const transposeJaxprCache = /* @__PURE__ */ new Map();
/**
 * Trace (or fetch from cache) the transpose of `jaxpr` for a given mask.
 *
 * `undefPrimals[i]` marks input `i` as an undefined primal. The traced
 * function takes the remaining forward inputs plus the cotangents and
 * evaluates `jaxpr` transposed. The result is typechecked, then cached.
 */
function transposeJaxpr(jaxpr, undefPrimals) {
  const maskKey = JSON.stringify(undefPrimals);
  const cached = transposeJaxprCache.get(jaxpr)?.get(maskKey);
  if (cached) return cached;
  const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
  const forwardInTypes = inTypes.filter((_type, position) => !undefPrimals[position]);
  const transposedFun = (forwardIn, cotangents) => {
    let nextForward = 0;
    const args = [];
    undefPrimals.forEach((isUndef, position) => {
      args.push(isUndef ? new UndefPrimal(inTypes[position]) : forwardIn[nextForward++]);
    });
    return evalJaxprTransposed(jaxpr, args, cotangents);
  };
  const traced = makeJaxpr$1(transposedFun)(forwardInTypes, outTypes);
  const newJaxpr = traced.jaxpr;
  typecheckJaxpr(newJaxpr.jaxpr);
  if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
  transposeJaxprCache.get(jaxpr).set(maskKey, newJaxpr);
  return newJaxpr;
}
|
|
4516
|
-
|
|
4517
|
-
|
|
4518
|
-
|
|
4607
|
+
/**
 * Flat VJP: linearizes `f` at `primalsIn` and exposes the transposed linear map.
 *
 * @returns `[primalsOut, fVjp, dispose]`. `fVjp(...cotangents)` evaluates the
 *   linearized jaxpr transposed, with every primal input as an UndefPrimal.
 *   `dispose()` forwards to the linearized jaxpr's `dispose()`.
 */
function vjpFlat(f, primalsIn) {
  const { primalsOut, jaxpr: linearized } = linearizeFlatUtil(f, primalsIn);
  const pullback = (...cotangents) => {
    const constRefs = linearized.consts.map((c) => c.ref);
    const undefInputs = primalsIn.map((primal) => new UndefPrimal(primal.aval));
    return evalJaxprTransposed(linearized.jaxpr, [...constRefs, ...undefInputs], cotangents);
  };
  const release = () => linearized.dispose();
  return [
    primalsOut,
    pullback,
    release
  ];
}
|
|
4520
|
-
|
|
4521
|
-
|
|
4522
|
-
|
|
4523
|
-
|
|
4524
|
-
|
|
4525
|
-
|
|
4526
|
-
|
|
4527
|
-
|
|
4528
|
-
|
|
4620
|
+
/**
 * Tree-aware VJP of `f` at `primalsIn`.
 *
 * @returns `[primalsOut, fVjp]`. `fVjp(cotangentsOut)` expects a tree matching
 *   the output of `f` and returns cotangents shaped like the inputs.
 *   `fVjp.dispose` is the dispose handle returned by the flat VJP.
 * @throws {TreeMismatchError} when the cotangent tree differs from the output tree
 */
function vjp$1(f, ...primalsIn) {
  const [primalsInFlat, inTree] = flatten(primalsIn);
  const [fFlat, outTree] = flattenFun(f, inTree);
  const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
  if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
  const primalsOut = unflatten(outTree.value, primalsOutFlat);
  const pullback = (cotangentsOut) => {
    const [cotangentsOutFlat, cotangentTree] = flatten(cotangentsOut);
    if (!outTree.value.equals(cotangentTree)) throw new TreeMismatchError("vjp", outTree.value, cotangentTree);
    return unflatten(inTree, fVjpFlat(...cotangentsOutFlat.map(pureArray)));
  };
  pullback.dispose = dispose$1;
  return [primalsOut, pullback];
}
|
|
4530
|
-
|
|
4531
|
-
|
|
4532
|
-
return
|
|
4635
|
+
/**
 * Gradient of `f`: wraps `valueAndGrad$1(f)`, disposes the value, and
 * returns only the gradient.
 */
function grad$1(f) {
  const withValue = valueAndGrad$1(f);
  return (...x) => {
    const [value, gradient] = withValue(...x);
    value.dispose();
    return gradient;
  };
}
|
|
4534
|
-
|
|
4535
|
-
|
|
4536
|
-
|
|
4537
|
-
|
|
4538
|
-
|
|
4539
|
-
|
|
4540
|
-
|
|
4541
|
-
|
|
4643
|
+
/**
 * Build a function returning `[f(...x), gradient with respect to x[0]]`.
 *
 * All arguments after the first go through `stopGradient`. The output must
 * be a 0-dimensional Tracer with a floating-point dtype.
 *
 * @throws {Error} when called with no arguments
 * @throws {TypeError} when the output is not a scalar or not floating-point
 */
function valueAndGrad$1(f) {
  return (...x) => {
    if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
    const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
    const isScalarTracer = y instanceof Tracer && ndim$1(y) === 0;
    if (!isScalarTracer) throw new TypeError("grad requires a scalar output");
    if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
    // Seed the pullback with ones shaped like y; only the first cotangent is kept.
    const [ct, ...unused] = fVjp(onesLike$1(y.ref));
    for (const extra of unused) dispose(extra);
    fVjp.dispose();
    return [y, ct];
  };
}
|
|
4543
|
-
|
|
4544
|
-
|
|
4545
|
-
|
|
4546
|
-
|
|
4547
|
-
|
|
4548
|
-
|
|
4549
|
-
|
|
4550
|
-
|
|
4655
|
+
/**
 * Reverse-mode Jacobian of `f` for a 1D input.
 *
 * Maps the pullback of `f` at `x` (via `vmap$1` with in-axes `[1]`) over an
 * identity matrix of size `x.shape[0]` and dtype `x.dtype`.
 *
 * @throws {TypeError} when `x` is not one-dimensional
 */
function jacrev$1(f) {
  return function jacobianReverse(x) {
    if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
    const inputSize = x.shape[0];
    const pullback = (ct) => {
      const [primalOut, fVjp] = vjp$1(f, x);
      primalOut.dispose();
      const [cotangentIn] = fVjp(ct);
      fVjp.dispose();
      return cotangentIn;
    };
    const basis = eye(inputSize, void 0, { dtype: x.dtype });
    return vmap$1(pullback, [1])(basis);
  };
}
|
|
4552
4669
|
|
|
4553
4670
|
//#endregion
|
|
@@ -4745,34 +4862,207 @@ function* allPaths(tensors, next) {
|
|
|
4745
4862
|
}
|
|
4746
4863
|
}
|
|
4747
4864
|
|
|
4865
|
+
//#endregion
|
|
4866
|
+
//#region src/library/numpy-fft.ts
|
|
4867
|
+
var numpy_fft_exports = {};
|
|
4868
|
+
__export(numpy_fft_exports, {
|
|
4869
|
+
fft: () => fft,
|
|
4870
|
+
ifft: () => ifft
|
|
4871
|
+
});
|
|
4872
|
+
/**
 * Validate a `{ real, imag }` pair passed to `jax.numpy.fft.<name>`.
 *
 * @throws {Error} when the two parts differ in shape or dtype, or when the
 *   dtype is not a float dtype
 */
function checkPairInput(name, a) {
  const fullName = `jax.numpy.fft.${name}`;
  const { real, imag } = a;
  const sameShape = require_backend.deepEqual(real.shape, imag.shape);
  if (!sameShape) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(real.shape)} and ${JSON.stringify(imag.shape)}`);
  if (real.dtype !== imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${real.dtype} and ${imag.dtype}`);
  if (!require_backend.isFloatDtype(real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${real.dtype}`);
}
|
|
4878
|
+
/**
 * Assert that `n` is a positive integer power of two (1, 2, 4, ...).
 *
 * @param name short function name used in the error message (e.g. "fft")
 * @param n size of the transformed axis
 * @throws {Error} when `n` is zero, negative, non-integer, or not a power of two
 */
function checkPowerOfTwo(name, n) {
  // `n & (n - 1)` clears the lowest set bit, so it is zero for powers of two.
  // It is also zero for 0, and bitwise operators truncate fractions, so both
  // cases are rejected explicitly before the bit test.
  const isPowerOfTwo = Number.isInteger(n) && n > 0 && (n & n - 1) === 0;
  if (!isPowerOfTwo) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
}
|
|
4881
|
+
/**
 * One jitted butterfly stage of the FFT; `i` is a static argument.
 *
 * Both parts are reshaped to rows of width `2 * 2**i`. The second half of
 * each row is multiplied by cos/sin twiddles of angle `-pi * k / 2**i`, and
 * the first half plus and minus that product are concatenated on the last axis.
 */
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
  const half = 2 ** i;
  const rowWidth = 2 * half;
  const re = real.reshape([-1, rowWidth]);
  const im = imag.reshape([-1, rowWidth]);
  // Twiddle factors: angle -pi * k / half for k in [0, half).
  const k = arange(0, half, 1, { dtype: re.dtype });
  const angle = k.mul(-Math.PI / half);
  const twiddleRe = cos(angle.ref);
  const twiddleIm = sin(angle);
  // Statement order below mirrors the original: every `.ref` read happens
  // before the last plain use of the same array.
  const firstRe = re.ref.slice([], [0, half]);
  const firstIm = im.ref.slice([], [0, half]);
  const secondRe = re.slice([], [half, rowWidth]);
  const secondIm = im.slice([], [half, rowWidth]);
  // Complex product (second half) * twiddle.
  const prodRe = secondRe.ref.mul(twiddleRe.ref).sub(secondIm.ref.mul(twiddleIm.ref));
  const prodIm = secondRe.mul(twiddleIm).add(secondIm.mul(twiddleRe));
  const nextRe = concatenate([firstRe.ref.add(prodRe.ref), firstRe.sub(prodRe)], -1);
  const nextIm = concatenate([firstIm.ref.add(prodIm.ref), firstIm.sub(prodIm)], -1);
  return {
    real: nextRe,
    imag: nextIm
  };
}, { staticArgnums: [0] });
|
|
4900
|
+
/**
 * Compute a one-dimensional discrete Fourier transform.
 *
 * Currently, the size of the axis must be a power of two.
 *
 * The input is a `{ real, imag }` pair. If `axis` is not the last axis it is
 * moved there first and moved back at the end. The data is reordered through
 * a reshape to `logN` axes of size 2 with those axes reversed, then
 * `fftUpdate` runs once per stage.
 */
function fft(a, axis = -1) {
  checkPairInput("fft", a);
  let { real: re, imag: im } = a;
  const targetAxis = require_backend.checkAxis(axis, re.ndim);
  const n = re.shape[targetAxis];
  checkPowerOfTwo("fft", n);
  const logN = Math.log2(n);
  // Move the transformed axis to the end when it is not already last.
  let perm = null;
  if (targetAxis !== re.ndim - 1) {
    perm = require_backend.range(re.ndim);
    perm.splice(targetAxis, 1);
    perm.push(targetAxis);
    re = re.transpose(perm);
    im = im.transpose(perm);
  }
  const originalShape = re.shape;
  // Fresh shape/permutation arrays are built on every call, as in the original.
  const reorder = (part) => part.reshape([-1, ...require_backend.rep(logN, 2)]).transpose([0, ...require_backend.range(1, logN + 1).reverse()]).flatten();
  re = reorder(re);
  im = reorder(im);
  for (let stage = 0; stage < logN; stage++) {
    ({ real: re, imag: im } = fftUpdate(stage, {
      real: re,
      imag: im
    }));
  }
  re = re.reshape(originalShape);
  im = im.reshape(originalShape);
  if (perm !== null) {
    re = re.transpose(require_backend.invertPermutation(perm));
    im = im.transpose(require_backend.invertPermutation(perm));
  }
  return {
    real: re,
    imag: im
  };
}
|
|
4938
|
+
/**
 * Compute a one-dimensional inverse discrete Fourier transform.
 *
 * Currently, the size of the axis must be a power of two.
 *
 * Implemented with the forward transform: negate the imaginary part, run
 * `fft`, negate the imaginary part again, and divide both parts by the axis
 * size.
 */
function ifft(a, axis = -1) {
  checkPairInput("ifft", a);
  const { real, imag } = a;
  const targetAxis = require_backend.checkAxis(axis, real.ndim);
  const n = real.shape[targetAxis];
  checkPowerOfTwo("ifft", n);
  const conjugated = {
    real,
    imag: imag.mul(-1)
  };
  const forward = fft(conjugated, targetAxis);
  const realOut = forward.real.div(n);
  const imagOut = forward.imag.mul(-1).div(n);
  return {
    real: realOut,
    imag: imagOut
  };
}
|
|
4959
|
+
|
|
4960
|
+
//#endregion
|
|
4961
|
+
//#region src/library/numpy-linalg.ts
|
|
4962
|
+
var numpy_linalg_exports = {};
|
|
4963
|
+
__export(numpy_linalg_exports, {
|
|
4964
|
+
cholesky: () => cholesky$1,
|
|
4965
|
+
diagonal: () => diagonal,
|
|
4966
|
+
lstsq: () => lstsq,
|
|
4967
|
+
matmul: () => matmul,
|
|
4968
|
+
matrixTranspose: () => matrixTranspose,
|
|
4969
|
+
outer: () => outer,
|
|
4970
|
+
tensordot: () => tensordot,
|
|
4971
|
+
trace: () => trace,
|
|
4972
|
+
vecdot: () => vecdot
|
|
4973
|
+
});
|
|
4974
|
+
/**
 * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
 *
 * This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
 * the input matrix, which is on by default.
 *
 * Symmetrizing averages the input with its matrix transpose before the
 * factorization is called.
 *
 * @throws {Error} when the input has fewer than 2 dimensions or its last two
 *   dimensions differ
 */
function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
  const matrix = fudgeArray(a);
  const rank = matrix.ndim;
  const isSquare = rank >= 2 && matrix.shape[rank - 1] === matrix.shape[rank - 2];
  if (!isSquare) throw new Error(`cholesky: input must be at least 2D square matrix, got ${matrix.aval}`);
  const input = symmetrizeInput ? matrix.ref.add(matrixTranspose(matrix)).mul(.5) : matrix;
  return cholesky(input, { upper });
}
|
|
4986
|
+
/**
 * Return the least-squares solution to a linear equation.
 *
 * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
 * For underdetermined systems, this finds the minimum-norm solution for `x`.
 *
 * This currently uses Cholesky decomposition to solve the normal equations
 * under the hood. The method is not as robust as QR or SVD.
 *
 * @param a coefficient matrix of shape `(M, N)`
 * @param b right-hand side of shape `(M,)` or `(M, K)`
 * @returns least-squares solution of shape `(N,)` or `(N, K)`
 * @throws {Error} if `a` is not 2D, or if the leading dimension of `b` is not `M`
 */
function lstsq(a, b) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
  const [m, n] = a.shape;
  if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
  const at = matrixTranspose(a.ref);
  if (m <= n) {
    // Underdetermined or square: x = Aᵀ (A Aᵀ)⁻¹ b, with A Aᵀ = L Lᵀ.
    const aat = matmul(a, at.ref);
    const l = cholesky$1(aat, { symmetrizeInput: false });
    const lb = triangularSolve(l.ref, b, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    // `llb` is not used again, so it is handed over directly. Every other
    // intermediate here takes `.ref` only when the value is reused afterwards;
    // an extra `.ref` on `llb` would leave a reference that is never released.
    return matmul(at, llb);
  } else {
    // Overdetermined: x = (Aᵀ A)⁻¹ Aᵀ b, with Aᵀ A = L Lᵀ.
    const ata = matmul(at.ref, a);
    const l = cholesky$1(ata, { symmetrizeInput: false });
    const atb = matmul(at, b);
    const lb = triangularSolve(l.ref, atb, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    return llb;
  }
}
|
|
5033
|
+
|
|
4748
5034
|
//#endregion
|
|
4749
5035
|
//#region src/library/numpy.ts
|
|
4750
5036
|
var numpy_exports = {};
|
|
4751
5037
|
__export(numpy_exports, {
|
|
4752
5038
|
Array: () => Array$1,
|
|
4753
5039
|
DType: () => require_backend.DType,
|
|
4754
|
-
abs: () =>
|
|
5040
|
+
abs: () => absolute,
|
|
4755
5041
|
absolute: () => absolute,
|
|
4756
5042
|
acos: () => acos,
|
|
4757
|
-
acosh: () =>
|
|
5043
|
+
acosh: () => arccosh,
|
|
4758
5044
|
add: () => add,
|
|
5045
|
+
all: () => all,
|
|
4759
5046
|
allclose: () => allclose,
|
|
5047
|
+
any: () => any,
|
|
4760
5048
|
arange: () => arange,
|
|
4761
|
-
arccos: () =>
|
|
5049
|
+
arccos: () => acos,
|
|
4762
5050
|
arccosh: () => arccosh,
|
|
5051
|
+
arcsin: () => asin,
|
|
4763
5052
|
arcsinh: () => arcsinh,
|
|
4764
|
-
arctan: () =>
|
|
4765
|
-
arctan2: () =>
|
|
5053
|
+
arctan: () => atan,
|
|
5054
|
+
arctan2: () => atan2,
|
|
4766
5055
|
arctanh: () => arctanh,
|
|
4767
5056
|
argmax: () => argmax,
|
|
4768
5057
|
argmin: () => argmin,
|
|
5058
|
+
argsort: () => argsort,
|
|
4769
5059
|
array: () => array,
|
|
4770
5060
|
asin: () => asin,
|
|
4771
|
-
asinh: () =>
|
|
5061
|
+
asinh: () => arcsinh,
|
|
4772
5062
|
astype: () => astype,
|
|
4773
5063
|
atan: () => atan,
|
|
4774
5064
|
atan2: () => atan2,
|
|
4775
|
-
atanh: () =>
|
|
5065
|
+
atanh: () => arctanh,
|
|
4776
5066
|
bool: () => bool,
|
|
4777
5067
|
broadcastArrays: () => broadcastArrays,
|
|
4778
5068
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -4782,16 +5072,20 @@ __export(numpy_exports, {
|
|
|
4782
5072
|
clip: () => clip,
|
|
4783
5073
|
columnStack: () => columnStack,
|
|
4784
5074
|
concatenate: () => concatenate,
|
|
5075
|
+
convolve: () => convolve,
|
|
5076
|
+
corrcoef: () => corrcoef,
|
|
5077
|
+
correlate: () => correlate,
|
|
4785
5078
|
cos: () => cos,
|
|
4786
5079
|
cosh: () => cosh,
|
|
5080
|
+
cov: () => cov,
|
|
4787
5081
|
cumsum: () => cumsum,
|
|
4788
|
-
cumulativeSum: () =>
|
|
5082
|
+
cumulativeSum: () => cumsum,
|
|
4789
5083
|
deg2rad: () => deg2rad,
|
|
4790
5084
|
degrees: () => degrees,
|
|
4791
5085
|
diag: () => diag,
|
|
4792
5086
|
diagonal: () => diagonal,
|
|
4793
|
-
divide: () =>
|
|
4794
|
-
dot: () => dot,
|
|
5087
|
+
divide: () => trueDivide,
|
|
5088
|
+
dot: () => dot$1,
|
|
4795
5089
|
dstack: () => dstack,
|
|
4796
5090
|
e: () => e,
|
|
4797
5091
|
einsum: () => einsum,
|
|
@@ -4799,8 +5093,10 @@ __export(numpy_exports, {
|
|
|
4799
5093
|
eulerGamma: () => eulerGamma,
|
|
4800
5094
|
exp: () => exp,
|
|
4801
5095
|
exp2: () => exp2,
|
|
5096
|
+
expandDims: () => expandDims,
|
|
4802
5097
|
expm1: () => expm1,
|
|
4803
5098
|
eye: () => eye,
|
|
5099
|
+
fft: () => numpy_fft_exports,
|
|
4804
5100
|
flip: () => flip,
|
|
4805
5101
|
fliplr: () => fliplr,
|
|
4806
5102
|
flipud: () => flipud,
|
|
@@ -4831,12 +5127,14 @@ __export(numpy_exports, {
|
|
|
4831
5127
|
ldexp: () => ldexp,
|
|
4832
5128
|
less: () => less,
|
|
4833
5129
|
lessEqual: () => lessEqual,
|
|
5130
|
+
linalg: () => numpy_linalg_exports,
|
|
4834
5131
|
linspace: () => linspace,
|
|
4835
5132
|
log: () => log,
|
|
4836
5133
|
log10: () => log10,
|
|
4837
5134
|
log1p: () => log1p,
|
|
4838
5135
|
log2: () => log2,
|
|
4839
5136
|
matmul: () => matmul,
|
|
5137
|
+
matrixTranspose: () => matrixTranspose,
|
|
4840
5138
|
max: () => max,
|
|
4841
5139
|
maximum: () => maximum,
|
|
4842
5140
|
mean: () => mean,
|
|
@@ -4853,10 +5151,10 @@ __export(numpy_exports, {
|
|
|
4853
5151
|
onesLike: () => onesLike,
|
|
4854
5152
|
outer: () => outer,
|
|
4855
5153
|
pad: () => pad,
|
|
4856
|
-
permuteDims: () =>
|
|
5154
|
+
permuteDims: () => transpose,
|
|
4857
5155
|
pi: () => pi,
|
|
4858
5156
|
positive: () => positive,
|
|
4859
|
-
pow: () =>
|
|
5157
|
+
pow: () => power,
|
|
4860
5158
|
power: () => power,
|
|
4861
5159
|
prod: () => prod$1,
|
|
4862
5160
|
promoteTypes: () => require_backend.promoteTypes,
|
|
@@ -4873,6 +5171,7 @@ __export(numpy_exports, {
|
|
|
4873
5171
|
sin: () => sin,
|
|
4874
5172
|
sinh: () => sinh,
|
|
4875
5173
|
size: () => size,
|
|
5174
|
+
sort: () => sort,
|
|
4876
5175
|
sqrt: () => sqrt,
|
|
4877
5176
|
square: () => square,
|
|
4878
5177
|
squeeze: () => squeeze,
|
|
@@ -5037,6 +5336,26 @@ function min(a, axis = null, opts) {
|
|
|
5037
5336
|
function max(a, axis = null, opts) {
|
|
5038
5337
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
5039
5338
|
}
|
|
5339
|
+
/**
 * Test whether all array elements along a given axis evaluate to true.
 *
 * The input is cast to boolean and reduced with `min`. The result has the
 * shape of `a` with the specified axis removed; if `axis` is `null`, the
 * reduction runs over every element and yields a scalar.
 */
function all(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(require_backend.DType.Bool);
  return min(asBool, axis, opts);
}
|
|
5349
|
+
/**
 * Test whether any array element along a given axis evaluates to true.
 *
 * The input is cast to boolean and reduced with `max`. The result has the
 * shape of `a` with the specified axis removed; if `axis` is `null`, the
 * reduction runs over every element and yields a scalar.
 */
function any(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(require_backend.DType.Bool);
  return max(asBool, axis, opts);
}
|
|
5040
5359
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5041
5360
|
function ptp(a, axis = null, opts) {
|
|
5042
5361
|
a = fudgeArray(a);
|
|
@@ -5111,8 +5430,6 @@ function cumsum(a, axis) {
|
|
|
5111
5430
|
a = broadcast(a, a.shape.concat(n), [-2]);
|
|
5112
5431
|
return moveaxis$1(tril(a).sum(-1), -1, axis);
|
|
5113
5432
|
}
|
|
5114
|
-
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
5115
|
-
const cumulativeSum = cumsum;
|
|
5116
5433
|
/** Reverse the elements in an array along the given axes. */
|
|
5117
5434
|
function flip(x, axis = null) {
|
|
5118
5435
|
const nd = ndim(x);
|
|
@@ -5222,8 +5539,11 @@ function flipud(x) {
|
|
|
5222
5539
|
function fliplr(x) {
|
|
5223
5540
|
return flip(x, 1);
|
|
5224
5541
|
}
|
|
5225
|
-
/**
|
|
5226
|
-
|
|
5542
|
+
/**
 * Swap the last two axes of an array (batched matrix transpose).
 *
 * Throws if the input has fewer than two dimensions.
 */
function matrixTranspose(a) {
  const rank = ndim(a);
  if (rank < 2) {
    throw new Error(`matrixTranspose: input array must be at least 2D`);
  }
  return moveaxis$1(a, -1, -2);
}
|
|
5227
5547
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
5228
5548
|
function ravel(a) {
|
|
5229
5549
|
return fudgeArray(a).ravel();
|
|
@@ -5239,6 +5559,32 @@ function squeeze(a, axis = null) {
|
|
|
5239
5559
|
return reshape(a, newShape);
|
|
5240
5560
|
}
|
|
5241
5561
|
/**
 * Insert new axes of length 1 into the shape of an array.
 *
 * @param a - Input array.
 * @param axis - Position(s), in the expanded shape, at which the new axis (or
 *   axes) are placed. Either a single integer or an array of integers.
 * @returns Array reshaped so that its number of dimensions is increased.
 *
 * @example
 * ```ts
 * const x = np.array([1, 2]);
 * np.expandDims(x, 0); // Shape [1, 2]
 * np.expandDims(x, 1); // Shape [2, 1]
 * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
 * ```
 */
function expandDims(a, axis) {
  const inputShape = shape(a);
  const requested = typeof axis === "number" ? [axis] : axis;
  // Axes are normalized against the rank of the *output* shape.
  const inserted = require_backend.normalizeAxis(requested, inputShape.length + requested.length);
  const outRank = inputShape.length + inserted.length;
  let cursor = 0;
  // Walk the output positions: inserted axes get extent 1, the rest take the
  // next unused input extent in order.
  const newShape = Array.from({ length: outRank }, (_, position) => inserted.includes(position) ? 1 : inputShape[cursor++]);
  return reshape(a, newShape);
}
|
|
5587
|
+
/**
|
|
5242
5588
|
* Repeat each element of an array after themselves.
|
|
5243
5589
|
*
|
|
5244
5590
|
* If no axis is provided, use the flattened input array, and return a flat
|
|
@@ -5326,7 +5672,7 @@ function diagonal(a, offset, axis1, axis2) {
|
|
|
5326
5672
|
*/
|
|
5327
5673
|
function diag(v, k = 0) {
|
|
5328
5674
|
const a = fudgeArray(v);
|
|
5329
|
-
if (!Number.isInteger(k)) throw new
|
|
5675
|
+
if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
|
|
5330
5676
|
if (a.ndim === 1) {
|
|
5331
5677
|
const n = a.shape[0];
|
|
5332
5678
|
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
@@ -5334,12 +5680,32 @@ function diag(v, k = 0) {
|
|
|
5334
5680
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
5335
5681
|
else return ret;
|
|
5336
5682
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
5337
|
-
else throw new
|
|
5683
|
+
else throw new Error("numpy.diag only supports 1D and 2D arrays");
|
|
5338
5684
|
}
|
|
5339
5685
|
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
5340
5686
|
function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
|
|
5341
5687
|
return diagonal(a, offset, axis1, axis2).sum(-1);
|
|
5342
5688
|
}
|
|
5689
|
+
/**
 * Return a sorted copy of an array.
 *
 * Sorting happens along the given axis (the last one by default). The sort may
 * be unstable, and it dispatches to a device-specific implementation.
 */
function sort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.sort(axis);
}
|
|
5698
|
+
/**
 * Return the indices that would sort an array, as `int32`.
 *
 * Sorting happens along the given axis (the last one by default). The
 * algorithm may be unstable: the relative order of indices for tied values
 * need not be preserved.
 */
function argsort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.argsort(axis);
}
|
|
5343
5709
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5344
5710
|
function allclose(actual, expected, options) {
|
|
5345
5711
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5356,11 +5722,11 @@ function allclose(actual, expected, options) {
|
|
|
5356
5722
|
}
|
|
5357
5723
|
/** Matrix product of two arrays. */
|
|
5358
5724
|
function matmul(x, y) {
|
|
5359
|
-
if (ndim(x) === 0 || ndim(y) === 0) throw new
|
|
5725
|
+
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
5360
5726
|
x = x, y = y;
|
|
5361
5727
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5362
5728
|
const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
|
|
5363
|
-
return dot
|
|
5729
|
+
return dot(x, y, {
|
|
5364
5730
|
lhsContractingDims: [-1],
|
|
5365
5731
|
rhsContractingDims: [-2],
|
|
5366
5732
|
lhsBatchDims: require_backend.range(-2 - numBatchDims, -2),
|
|
@@ -5368,11 +5734,11 @@ function matmul(x, y) {
|
|
|
5368
5734
|
});
|
|
5369
5735
|
}
|
|
5370
5736
|
/** Dot product of two arrays. */
|
|
5371
|
-
function dot(x, y) {
|
|
5737
|
+
function dot$1(x, y) {
|
|
5372
5738
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
5373
5739
|
x = x, y = y;
|
|
5374
5740
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5375
|
-
return dot
|
|
5741
|
+
return dot(x, y, {
|
|
5376
5742
|
lhsContractingDims: [-1],
|
|
5377
5743
|
rhsContractingDims: [-2]
|
|
5378
5744
|
});
|
|
@@ -5388,7 +5754,7 @@ function tensordot(x, y, axes = 2) {
|
|
|
5388
5754
|
x = fudgeArray(x);
|
|
5389
5755
|
y = fudgeArray(y);
|
|
5390
5756
|
if (typeof axes === "number") axes = [require_backend.range(-axes, 0), require_backend.range(axes)];
|
|
5391
|
-
return dot
|
|
5757
|
+
return dot(x, y, {
|
|
5392
5758
|
lhsContractingDims: axes[0],
|
|
5393
5759
|
rhsContractingDims: axes[1]
|
|
5394
5760
|
});
|
|
@@ -5481,7 +5847,7 @@ function einsum(...args) {
|
|
|
5481
5847
|
const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
|
|
5482
5848
|
indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
|
|
5483
5849
|
const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
|
|
5484
|
-
const result = dot
|
|
5850
|
+
const result = dot(a, b, {
|
|
5485
5851
|
lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
|
|
5486
5852
|
rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
|
|
5487
5853
|
lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
|
|
@@ -5509,7 +5875,7 @@ function einsum(...args) {
|
|
|
5509
5875
|
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
5510
5876
|
*/
|
|
5511
5877
|
function inner(x, y) {
|
|
5512
|
-
return dot
|
|
5878
|
+
return dot(fudgeArray(x), fudgeArray(y), {
|
|
5513
5879
|
lhsContractingDims: [-1],
|
|
5514
5880
|
rhsContractingDims: [-1]
|
|
5515
5881
|
});
|
|
@@ -5542,6 +5908,30 @@ function vecdot(x, y, { axis } = {}) {
|
|
|
5542
5908
|
function vdot(x, y) {
|
|
5543
5909
|
return dot$2(ravel(x), ravel(y));
|
|
5544
5910
|
}
|
|
5911
|
+
/**
 * Shared implementation behind `convolve` and `correlate` for 1D inputs.
 *
 * The longer input is always used as the signal. When the inputs are swapped
 * for `"correlate"`, the output is flipped to compensate; for `"convolve"` the
 * kernel is flipped before calling `conv`.
 *
 * `mode` is one of `"full"`, `"same"` or `"valid"`; anything else throws.
 */
function _convImpl(name, x, y, mode) {
  if (!(x.ndim === 1 && y.ndim === 1)) {
    throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
  }
  let signal = x;
  let kernel = y;
  const swapped = signal.shape[0] < kernel.shape[0];
  if (swapped) {
    signal = y;
    kernel = x;
  }
  const flipOutput = swapped && name === "correlate";
  if (name === "convolve") kernel = flip(kernel);
  let padding;
  switch (mode) {
    case "valid":
      padding = "VALID";
      break;
    case "same":
      padding = "SAME_LOWER";
      break;
    case "full":
      // Pad both ends by (kernel length - 1) so every partial overlap is produced.
      padding = [[kernel.shape[0] - 1, kernel.shape[0] - 1]];
      break;
    default:
      throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
  }
  // NOTE(review): presumably each `null` index adds a unit axis so the 1D inputs
  // become the batched operands `conv` expects, and `.slice(0, 0)` indexes those
  // axes away again. Confirm against the Array `slice` implementation.
  const signalOperand = signal.slice(null, null);
  const kernelOperand = kernel.slice(null, null);
  const z = conv(signalOperand, kernelOperand, [1], padding).slice(0, 0);
  if (flipOutput) return flip(z);
  return z;
}
|
|
5927
|
+
/**
 * Convolution of two one-dimensional arrays.
 *
 * `mode` is `"full"` (default), `"same"` or `"valid"`.
 */
function convolve(x, y, mode = "full") {
  const result = _convImpl("convolve", x, y, mode);
  return result;
}
|
|
5931
|
+
/**
 * Correlation of two one-dimensional arrays.
 *
 * `mode` is `"valid"` (default), `"same"` or `"full"`.
 */
function correlate(x, y, mode = "valid") {
  const result = _convImpl("correlate", x, y, mode);
  return result;
}
|
|
5545
5935
|
/**
|
|
5546
5936
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
5547
5937
|
*
|
|
@@ -5550,7 +5940,7 @@ function vdot(x, y) {
|
|
|
5550
5940
|
*/
|
|
5551
5941
|
function meshgrid(xs, { indexing } = {}) {
|
|
5552
5942
|
indexing ??= "xy";
|
|
5553
|
-
for (const x of xs) if (x.ndim !== 1) throw new
|
|
5943
|
+
for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
|
|
5554
5944
|
if (xs.length <= 1) return xs;
|
|
5555
5945
|
if (indexing === "xy") {
|
|
5556
5946
|
const [a, b, ...rest] = xs;
|
|
@@ -5566,44 +5956,7 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
5566
5956
|
];
|
|
5567
5957
|
}
|
|
5568
5958
|
const shape$1 = xs.map((x) => x.shape[0]);
|
|
5569
|
-
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
5570
|
-
}
|
|
5571
|
-
/**
|
|
5572
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
5573
|
-
*
|
|
5574
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
5575
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
5576
|
-
* `k>0` is above it.
|
|
5577
|
-
*/
|
|
5578
|
-
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
5579
|
-
m ??= n;
|
|
5580
|
-
dtype ??= require_backend.DType.Float32;
|
|
5581
|
-
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
5582
|
-
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
5583
|
-
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
5584
|
-
const rows = arange(k, n + k, 1, {
|
|
5585
|
-
dtype: require_backend.DType.Int32,
|
|
5586
|
-
device
|
|
5587
|
-
});
|
|
5588
|
-
const cols = arange(0, m, 1, {
|
|
5589
|
-
dtype: require_backend.DType.Int32,
|
|
5590
|
-
device
|
|
5591
|
-
});
|
|
5592
|
-
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
5593
|
-
}
|
|
5594
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
5595
|
-
function tril(a, k = 0) {
|
|
5596
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5597
|
-
a = fudgeArray(a);
|
|
5598
|
-
const [n, m] = a.shape.slice(-2);
|
|
5599
|
-
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
5600
|
-
}
|
|
5601
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
5602
|
-
function triu(a, k = 0) {
|
|
5603
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5604
|
-
a = fudgeArray(a);
|
|
5605
|
-
const [n, m] = a.shape.slice(-2);
|
|
5606
|
-
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
5959
|
+
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
5607
5960
|
}
|
|
5608
5961
|
/**
|
|
5609
5962
|
* Clip (limit) the values in an array.
|
|
@@ -5629,8 +5982,6 @@ function absolute(x) {
|
|
|
5629
5982
|
x = fudgeArray(x);
|
|
5630
5983
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
5631
5984
|
}
|
|
5632
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
5633
|
-
const abs = absolute;
|
|
5634
5985
|
/** Return an element-wise indication of sign of the input. */
|
|
5635
5986
|
function sign(x) {
|
|
5636
5987
|
x = fudgeArray(x);
|
|
@@ -5709,12 +6060,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
|
|
|
5709
6060
|
const denom = where(xNeg, y, r.add(x));
|
|
5710
6061
|
return atan(numer.div(denom)).mul(2);
|
|
5711
6062
|
});
|
|
5712
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
5713
|
-
const arccos = acos;
|
|
5714
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
5715
|
-
const arctan = atan;
|
|
5716
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
5717
|
-
const arctan2 = atan2;
|
|
5718
6063
|
/** Element-wise subtraction, with broadcasting. */
|
|
5719
6064
|
function subtract(x, y) {
|
|
5720
6065
|
x = fudgeArray(x);
|
|
@@ -5745,8 +6090,6 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
5745
6090
|
const remainder = jit$1(function remainder$1(x, y) {
|
|
5746
6091
|
return mod(mod(x, y.ref).add(y.ref), y);
|
|
5747
6092
|
});
|
|
5748
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
5749
|
-
const divide = trueDivide;
|
|
5750
6093
|
/** Round input to the nearest integer towards zero. */
|
|
5751
6094
|
function trunc(x) {
|
|
5752
6095
|
return idiv(x, 1);
|
|
@@ -5768,9 +6111,9 @@ function ldexp(x1, x2) {
|
|
|
5768
6111
|
*/
|
|
5769
6112
|
function frexp(x) {
|
|
5770
6113
|
x = fudgeArray(x);
|
|
5771
|
-
const absx =
|
|
6114
|
+
const absx = absolute(x.ref);
|
|
5772
6115
|
const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(require_backend.DType.Int32));
|
|
5773
|
-
const mantissa =
|
|
6116
|
+
const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
|
|
5774
6117
|
return [mantissa, exponent];
|
|
5775
6118
|
}
|
|
5776
6119
|
/** Calculate `2**p` for all p in the input array. */
|
|
@@ -5813,10 +6156,8 @@ const power = jit$1(function power$1(x1, x2) {
|
|
|
5813
6156
|
const x2i = trunc(x2.ref);
|
|
5814
6157
|
const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
|
|
5815
6158
|
const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
|
|
5816
|
-
return where(shouldBeNaN, nan, exp(log(
|
|
6159
|
+
return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
|
|
5817
6160
|
});
|
|
5818
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
5819
|
-
const pow = power;
|
|
5820
6161
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
5821
6162
|
const cbrt = jit$1(function cbrt$1(x) {
|
|
5822
6163
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
@@ -5882,12 +6223,6 @@ const arccosh = jit$1(function arccosh$1(x) {
|
|
|
5882
6223
|
const arctanh = jit$1(function arctanh$1(x) {
|
|
5883
6224
|
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
5884
6225
|
});
|
|
5885
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
5886
|
-
const asinh = arcsinh;
|
|
5887
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
5888
|
-
const acosh = arccosh;
|
|
5889
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
5890
|
-
const atanh = arctanh;
|
|
5891
6226
|
/**
|
|
5892
6227
|
* Compute the variance of an array.
|
|
5893
6228
|
*
|
|
@@ -5917,6 +6252,26 @@ function var_(x, axis = null, opts) {
|
|
|
5917
6252
|
function std(x, axis = null, opts) {
|
|
5918
6253
|
return sqrt(var_(x, axis, opts));
|
|
5919
6254
|
}
|
|
6255
|
+
/**
 * Estimate the sample covariance of a set of variables.
 *
 * A 1D input is treated as a single row. If `y` is given it is stacked below
 * `x` with `vstack`. Each row is centered by its mean along axis 1, and the
 * product of the centered matrix with its transpose is divided by `N - 1`,
 * where `N` is the size of axis 1.
 */
function cov(x, y) {
  // Promote a 1D input to a single-row matrix.
  const toRowMatrix = (value) => {
    const arr = fudgeArray(value);
    return arr.ndim === 1 ? arr.reshape([1, arr.shape[0]]) : arr;
  };
  let observations = toRowMatrix(x);
  if (y !== undefined) {
    observations = vstack([observations, toRowMatrix(y)]);
  }
  const sampleCount = observations.shape[1];
  // `.ref` is read before the array is passed on to `mean`.
  const kept = observations.ref;
  const rowMeans = observations.mean(1, { keepdims: true });
  const centered = kept.sub(rowMeans);
  const lhs = centered.ref;
  const rhs = centered.transpose();
  return dot$1(lhs, rhs).div(sampleCount - 1);
}
|
|
6268
|
+
/**
 * Compute the Pearson correlation coefficients (in range `[-1, 1]`).
 *
 * The covariance matrix from `cov(x, y)` is divided element-wise by the square
 * root of the outer product of its diagonal with itself.
 */
function corrcoef(x, y) {
  const covariance = cov(x, y);
  const diagVariances = diag(covariance.ref);
  const pairwise = outer(diagVariances.ref, diagVariances);
  const scale = sqrt(pairwise);
  return covariance.div(scale);
}
|
|
5920
6275
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
5921
6276
|
function isinf(x) {
|
|
5922
6277
|
x = fudgeArray(x);
|
|
@@ -5946,6 +6301,253 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
5946
6301
|
return isnan(x.ref).add(isinf(x)).notEqual(true);
|
|
5947
6302
|
});
|
|
5948
6303
|
|
|
6304
|
+
//#endregion
|
|
6305
|
+
//#region src/library/lax-linalg.ts
|
|
6306
|
+
// Export namespace for `src/library/lax-linalg.ts`. `__export` registers each
// entry as an enumerable getter, so the referenced binding is only looked up
// when the property is read.
var lax_linalg_exports = {};
__export(lax_linalg_exports, {
  cholesky: () => cholesky,
  triangularSolve: () => triangularSolve
});
|
|
6311
|
+
/**
|
|
6312
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
6313
|
+
*
|
|
6314
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
6315
|
+
*
|
|
6316
|
+
* - A = L @ L^T (for upper=false, default)
|
|
6317
|
+
* - A = U^T @ U (for upper=true)
|
|
6318
|
+
*
|
|
6319
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
6320
|
+
* The input matrix must be symmetric and positive-definite.
|
|
6321
|
+
*
|
|
6322
|
+
* @example
|
|
6323
|
+
* ```ts
|
|
6324
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6325
|
+
*
|
|
6326
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
6327
|
+
*
|
|
6328
|
+
* // Lower Cholesky factorization (default):
|
|
6329
|
+
* const L = lax.linalg.cholesky(x);
|
|
6330
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
6331
|
+
*
|
|
6332
|
+
* // Upper Cholesky factorization:
|
|
6333
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
6334
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6335
|
+
* ```
|
|
6336
|
+
*/
|
|
6337
|
+
function cholesky(a, { upper = false } = {}) {
  // `cholesky$2` supplies the factor returned as-is for the default (lower)
  // case; the upper factor is that result with its last two axes swapped.
  const factor = cholesky$2(a);
  if (!upper) return factor;
  return moveaxis$1(factor, -2, -1);
}
|
|
6341
|
+
/**
|
|
6342
|
+
* Solve a triangular linear system.
|
|
6343
|
+
*
|
|
6344
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
6345
|
+
* where `a` is a triangular matrix.
|
|
6346
|
+
*
|
|
6347
|
+
* @example
|
|
6348
|
+
* ```ts
|
|
6349
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6350
|
+
*
|
|
6351
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
6352
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
6353
|
+
*
|
|
6354
|
+
* // Solve L @ x = b
|
|
6355
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
6356
|
+
* // x = [[2.], [5./3.]]
|
|
6357
|
+
* ```
|
|
6358
|
+
*/
|
|
6359
|
+
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
  let matrix = fudgeArray(a);
  let rhs = fudgeArray(b);
  // A right-side solve toggles the transpose flag; a left-side solve instead
  // swaps the last two axes of `b` going in and of the solution coming out.
  const effectiveTranspose = leftSide ? transposeA : !transposeA;
  if (leftSide) rhs = moveaxis$1(rhs, -2, -1);
  if (effectiveTranspose) matrix = moveaxis$1(matrix, -2, -1);
  // NOTE(review): `lower` is forwarded unchanged even when `matrix` has just
  // been transposed, so it appears to describe the matrix actually handed to
  // `triangularSolve$1` rather than the caller's `a`. Confirm intended semantics.
  const solved = triangularSolve$1(matrix, rhs, {
    lower,
    unitDiagonal
  });
  return leftSide ? moveaxis$1(solved, -2, -1) : solved;
}
|
|
6372
|
+
|
|
6373
|
+
//#endregion
|
|
6374
|
+
//#region src/library/lax.ts
|
|
6375
|
+
// Export namespace for `src/library/lax.ts`. `__export` registers each entry as
// an enumerable getter, so the referenced binding is only looked up when the
// property is read. `linalg` nests the `lax-linalg` namespace object.
var lax_exports = {};
__export(lax_exports, {
  conv: () => conv,
  convGeneralDilated: () => convGeneralDilated,
  convWithGeneralPadding: () => convWithGeneralPadding,
  dot: () => dot,
  erf: () => erf,
  erfc: () => erfc,
  linalg: () => lax_linalg_exports,
  reduceWindow: () => reduceWindow,
  stopGradient: () => stopGradient$1
});
|
|
6387
|
+
/**
|
|
6388
|
+
* General dot product/contraction operator.
|
|
6389
|
+
*
|
|
6390
|
+
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
6391
|
+
* `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
|
|
6392
|
+
*/
|
|
6393
|
+
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
  if (lc.length !== rc.length) {
    throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
  }
  if (lb.length !== rb.length) {
    throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
  }
  // Validate and normalize every axis list against its operand's rank.
  const normalize = (axes, operand) => axes.map((axis) => require_backend.checkAxis(axis, operand.ndim));
  const lhsContract = normalize(lc, lhs);
  const rhsContract = normalize(rc, rhs);
  const lhsBatch = normalize(lb, lhs);
  const rhsBatch = normalize(rb, rhs);
  if (lhsContract.some((axis) => lhsBatch.includes(axis))) {
    throw new Error(`dot: lhs contracting dims ${JSON.stringify(lhsContract)} overlap with batch dims ${JSON.stringify(lhsBatch)}`);
  }
  if (rhsContract.some((axis) => rhsBatch.includes(axis))) {
    throw new Error(`dot: rhs contracting dims ${JSON.stringify(rhsContract)} overlap with batch dims ${JSON.stringify(rhsBatch)}`);
  }
  // "Free" axes are those that are neither contracted nor batched.
  const freeAxes = (operand, contract, batch) => require_backend.range(operand.ndim).filter((axis) => !contract.includes(axis) && !batch.includes(axis));
  const lhsFree = freeAxes(lhs, lhsContract, lhsBatch);
  const rhsFree = freeAxes(rhs, rhsContract, rhsBatch);
  // Reorder both operands as [batch..., free..., contracting...].
  const lhs2 = lhs.transpose([...lhsBatch, ...lhsFree, ...lhsContract]);
  const rhs2 = rhs.transpose([...rhsBatch, ...rhsFree, ...rhsContract]);
  const extents = (operand, axes) => axes.map((axis) => operand.shape[axis]);
  if (lhsContract.length === 0) {
    // Nothing to contract: a broadcasted element-wise product is enough.
    return mul(lhs2.reshape([
      ...extents(lhs, lhsBatch),
      ...extents(lhs, lhsFree),
      ...require_backend.rep(rhsFree.length, 1)
    ]), rhs2.reshape([
      ...extents(rhs, rhsBatch),
      ...require_backend.rep(lhsFree.length, 1),
      ...extents(rhs, rhsFree)
    ]));
  }
  const dotShapeX = extents(lhs, lhsContract);
  const dotShapeY = extents(rhs, rhsContract);
  if (!require_backend.deepEqual(dotShapeX, dotShapeY)) {
    throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
  }
  // Unit axes line up the free dims of each side for broadcasting; all
  // contracting dims are flattened into a single trailing axis for `dot$2`.
  return dot$2(lhs2.reshape([
    ...extents(lhs, lhsBatch),
    ...extents(lhs, lhsFree),
    ...require_backend.rep(rhsFree.length, 1),
    require_backend.prod(dotShapeX)
  ]), rhs2.reshape([
    ...extents(rhs, rhsBatch),
    ...require_backend.rep(lhsFree.length, 1),
    ...extents(rhs, rhsFree),
    require_backend.prod(dotShapeY)
  ]));
}
|
|
6438
|
+
/**
 * Translate a string padding mode into explicit `[low, high]` pads per spatial
 * dimension.
 *
 * - `"VALID"`: no padding.
 * - `"SAME"`: total padding chosen so the output extent is `ceil(in / stride)`;
 *   the odd leftover element goes on the high side.
 * - `"SAME_LOWER"`: same total, with the odd leftover element on the low side.
 *
 * The mode is matched case-insensitively; unknown modes throw.
 */
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
  const padType = padding.toUpperCase();
  if (padType === "VALID") return require_backend.rep(inShape.length, [0, 0]);
  if (padType !== "SAME" && padType !== "SAME_LOWER") {
    throw new Error(`Unknown padding type: ${padType}`);
  }
  const outShape = inShape.map((extent, axis) => Math.ceil(extent / strides[axis]));
  // Total padding needed so the dilated window covers the last output position.
  const totals = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([out, stride, kernel, dil, extent]) => Math.max(0, (out - 1) * stride + 1 + (kernel - 1) * dil - extent));
  const extraOnHighSide = padType === "SAME";
  return totals.map((total) => {
    const smaller = total >> 1;
    const larger = total - smaller;
    return extraOnHighSide ? [smaller, larger] : [larger, smaller];
  });
}
|
|
6452
|
+
/**
|
|
6453
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
6454
|
+
*
|
|
6455
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6456
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
6457
|
+
*
|
|
6458
|
+
* Grouped convolutions are supported through the `featureGroupCount` option, which must evenly divide both the input and output channel counts.
|
|
6459
|
+
*/
|
|
6460
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6461
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
6462
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
6463
|
+
if (typeof padding === "string") {
|
|
6464
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
6465
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
6466
|
+
}
|
|
6467
|
+
if (featureGroupCount !== 1) {
|
|
6468
|
+
const G = featureGroupCount;
|
|
6469
|
+
const [N, C_in, ...xs] = lhs.shape;
|
|
6470
|
+
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
6471
|
+
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
6472
|
+
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
6473
|
+
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
6474
|
+
const lhsGrouped = moveaxis(lhs.reshape([
|
|
6475
|
+
N,
|
|
6476
|
+
G,
|
|
6477
|
+
C_in / G,
|
|
6478
|
+
...xs
|
|
6479
|
+
]), 1, 0);
|
|
6480
|
+
const rhsGrouped = rhs.reshape([
|
|
6481
|
+
G,
|
|
6482
|
+
C_out / G,
|
|
6483
|
+
C_in_per_group,
|
|
6484
|
+
...ks
|
|
6485
|
+
]);
|
|
6486
|
+
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
6487
|
+
vmapDims: 1,
|
|
6488
|
+
strides: windowStrides,
|
|
6489
|
+
padding,
|
|
6490
|
+
lhsDilation,
|
|
6491
|
+
rhsDilation
|
|
6492
|
+
});
|
|
6493
|
+
const ys = result.shape.slice(3);
|
|
6494
|
+
return moveaxis(result, 0, 1).reshape([
|
|
6495
|
+
N,
|
|
6496
|
+
C_out,
|
|
6497
|
+
...ys
|
|
6498
|
+
]);
|
|
6499
|
+
}
|
|
6500
|
+
return conv$1(lhs, rhs, {
|
|
6501
|
+
strides: windowStrides,
|
|
6502
|
+
padding,
|
|
6503
|
+
lhsDilation,
|
|
6504
|
+
rhsDilation
|
|
6505
|
+
});
|
|
6506
|
+
}
|
|
6507
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
6508
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
6509
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
6510
|
+
lhsDilation,
|
|
6511
|
+
rhsDilation
|
|
6512
|
+
});
|
|
6513
|
+
}
|
|
6514
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
6515
|
+
function conv(lhs, rhs, windowStrides, padding) {
|
|
6516
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
6517
|
+
}
|
|
6518
|
+
/** Reduce a computation over padded windows. */
|
|
6519
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
6520
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
6521
|
+
if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
|
|
6522
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
6523
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
6524
|
+
window: windowDimensions,
|
|
6525
|
+
strides: windowStrides
|
|
6526
|
+
}));
|
|
6527
|
+
}
|
|
6528
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
6529
|
+
function erf(x) {
|
|
6530
|
+
return erf$1(x);
|
|
6531
|
+
}
|
|
6532
|
+
/**
|
|
6533
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
6534
|
+
*
|
|
6535
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
6536
|
+
* where `erf(x)` is very close to 1.
|
|
6537
|
+
*/
|
|
6538
|
+
function erfc(x) {
|
|
6539
|
+
return erfc$1(x);
|
|
6540
|
+
}
|
|
6541
|
+
/**
|
|
6542
|
+
* Stops gradient computation.
|
|
6543
|
+
*
|
|
6544
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
6545
|
+
* forward or reverse-mode automatic differentiation.
|
|
6546
|
+
*/
|
|
6547
|
+
function stopGradient$1(x) {
|
|
6548
|
+
return stopGradient(x);
|
|
6549
|
+
}
|
|
6550
|
+
|
|
5949
6551
|
//#endregion
|
|
5950
6552
|
//#region src/library/nn.ts
|
|
5951
6553
|
var nn_exports = {};
|
|
@@ -5954,6 +6556,10 @@ __export(nn_exports, {
|
|
|
5954
6556
|
elu: () => elu,
|
|
5955
6557
|
gelu: () => gelu,
|
|
5956
6558
|
glu: () => glu,
|
|
6559
|
+
hardSigmoid: () => hardSigmoid,
|
|
6560
|
+
hardSilu: () => hardSilu,
|
|
6561
|
+
hardSwish: () => hardSilu,
|
|
6562
|
+
hardTanh: () => hardTanh,
|
|
5957
6563
|
identity: () => identity,
|
|
5958
6564
|
leakyRelu: () => leakyRelu,
|
|
5959
6565
|
logSigmoid: () => logSigmoid,
|
|
@@ -5964,14 +6570,17 @@ __export(nn_exports, {
|
|
|
5964
6570
|
oneHot: () => oneHot,
|
|
5965
6571
|
relu: () => relu,
|
|
5966
6572
|
relu6: () => relu6,
|
|
6573
|
+
selu: () => selu,
|
|
5967
6574
|
sigmoid: () => sigmoid,
|
|
5968
6575
|
silu: () => silu,
|
|
5969
6576
|
softSign: () => softSign,
|
|
5970
6577
|
softmax: () => softmax,
|
|
5971
6578
|
softplus: () => softplus,
|
|
6579
|
+
sparsePlus: () => sparsePlus,
|
|
6580
|
+
sparseSigmoid: () => sparseSigmoid,
|
|
5972
6581
|
squareplus: () => squareplus,
|
|
5973
6582
|
standardize: () => standardize,
|
|
5974
|
-
swish: () =>
|
|
6583
|
+
swish: () => silu
|
|
5975
6584
|
});
|
|
5976
6585
|
/**
|
|
5977
6586
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -6006,6 +6615,28 @@ function softplus(x) {
|
|
|
6006
6615
|
return log(exp(x).add(1));
|
|
6007
6616
|
}
|
|
6008
6617
|
/**
|
|
6618
|
+
* @function
|
|
6619
|
+
* Sparse plus function:
|
|
6620
|
+
*
|
|
6621
|
+
* - When `x <= -1`: `0`
|
|
6622
|
+
* - When `-1 < x < 1`: `(x+1)**2 / 4`
|
|
6623
|
+
* - When `x >= 1`: `x`
|
|
6624
|
+
*/
|
|
6625
|
+
const sparsePlus = jit$1((x) => {
|
|
6626
|
+
return where(x.ref.lessEqual(-1), 0, where(x.ref.less(1), square(x.ref.add(1)).mul(.25), x));
|
|
6627
|
+
});
|
|
6628
|
+
/**
|
|
6629
|
+
* @function
|
|
6630
|
+
* Sparse sigmoid activation function.
|
|
6631
|
+
*
|
|
6632
|
+
* - When `x <= -1`: `0`
|
|
6633
|
+
* - When `-1 < x < 1`: `(x + 1) / 2`
|
|
6634
|
+
* - When `x >= 1`: `1`
|
|
6635
|
+
*/
|
|
6636
|
+
const sparseSigmoid = jit$1((x) => {
|
|
6637
|
+
return clip(x.add(1).mul(.5), 0, 1);
|
|
6638
|
+
});
|
|
6639
|
+
/**
|
|
6009
6640
|
* Soft-sign activation function, computed element-wise:
|
|
6010
6641
|
* `softsign(x) = x / (|x| + 1)`.
|
|
6011
6642
|
*/
|
|
@@ -6027,17 +6658,6 @@ const silu = jit$1(function silu$1(x) {
|
|
|
6027
6658
|
return x.ref.mul(sigmoid(x));
|
|
6028
6659
|
});
|
|
6029
6660
|
/**
|
|
6030
|
-
* @function
|
|
6031
|
-
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
6032
|
-
* Swish, computed element-wise:
|
|
6033
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
6034
|
-
*
|
|
6035
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
6036
|
-
*
|
|
6037
|
-
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
6038
|
-
*/
|
|
6039
|
-
const swish = silu;
|
|
6040
|
-
/**
|
|
6041
6661
|
* Log-sigmoid activation function, computed element-wise:
|
|
6042
6662
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
6043
6663
|
*/
|
|
@@ -6054,6 +6674,19 @@ function leakyRelu(x, negativeSlope = .01) {
|
|
|
6054
6674
|
x = fudgeArray(x);
|
|
6055
6675
|
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
6056
6676
|
}
|
|
6677
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
|
|
6678
|
+
function hardSigmoid(x) {
|
|
6679
|
+
return relu6(add(x, 3)).mul(1 / 6);
|
|
6680
|
+
}
|
|
6681
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
|
|
6682
|
+
function hardSilu(x) {
|
|
6683
|
+
x = fudgeArray(x);
|
|
6684
|
+
return x.ref.mul(hardSigmoid(x));
|
|
6685
|
+
}
|
|
6686
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
|
|
6687
|
+
function hardTanh(x) {
|
|
6688
|
+
return clip(x, -1, 1);
|
|
6689
|
+
}
|
|
6057
6690
|
/**
|
|
6058
6691
|
* Exponential linear unit activation function.
|
|
6059
6692
|
*
|
|
@@ -6076,6 +6709,20 @@ function celu(x, alpha = 1) {
|
|
|
6076
6709
|
}
|
|
6077
6710
|
/**
|
|
6078
6711
|
* @function
|
|
6712
|
+
* Scaled exponential linear unit activation.
|
|
6713
|
+
*
|
|
6714
|
+
* Computes the element-wise function:
|
|
6715
|
+
* `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
|
|
6716
|
+
*
|
|
6717
|
+
* Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
|
|
6718
|
+
*/
|
|
6719
|
+
const selu = jit$1(function selu$1(x) {
|
|
6720
|
+
const alpha = 1.6732632423543772;
|
|
6721
|
+
const lambda = 1.0507009873554805;
|
|
6722
|
+
return where(x.ref.less(0), expm1(x.ref).mul(alpha), x).mul(lambda);
|
|
6723
|
+
});
|
|
6724
|
+
/**
|
|
6725
|
+
* @function
|
|
6079
6726
|
* Gaussian error linear unit (GELU) activation function.
|
|
6080
6727
|
*
|
|
6081
6728
|
* This is computed element-wise. There are two variants depending on whether
|
|
@@ -6229,8 +6876,11 @@ var random_exports = {};
|
|
|
6229
6876
|
__export(random_exports, {
|
|
6230
6877
|
bernoulli: () => bernoulli,
|
|
6231
6878
|
bits: () => bits,
|
|
6879
|
+
cauchy: () => cauchy,
|
|
6232
6880
|
exponential: () => exponential,
|
|
6881
|
+
gumbel: () => gumbel,
|
|
6233
6882
|
key: () => key,
|
|
6883
|
+
laplace: () => laplace,
|
|
6234
6884
|
normal: () => normal,
|
|
6235
6885
|
split: () => split,
|
|
6236
6886
|
uniform: () => uniform
|
|
@@ -6289,6 +6939,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
6289
6939
|
}
|
|
6290
6940
|
/**
|
|
6291
6941
|
* @function
|
|
6942
|
+
* Sample from a Cauchy distribution with location 0 and scale 1.
|
|
6943
|
+
*
|
|
6944
|
+
* Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
|
|
6945
|
+
*/
|
|
6946
|
+
const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
|
|
6947
|
+
const u = uniform(key$1, shape$1);
|
|
6948
|
+
return tan(u.sub(.5).mul(Math.PI));
|
|
6949
|
+
}, { staticArgnums: [1] });
|
|
6950
|
+
/**
|
|
6951
|
+
* @function
|
|
6292
6952
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
6293
6953
|
*/
|
|
6294
6954
|
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
@@ -6297,6 +6957,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
6297
6957
|
}, { staticArgnums: [1] });
|
|
6298
6958
|
/**
|
|
6299
6959
|
* @function
|
|
6960
|
+
* Sample from a Gumbel distribution with location 0 and scale 1.
|
|
6961
|
+
*
|
|
6962
|
+
* Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1), which has the same distribution as `-log(-log(u))`.
|
|
6963
|
+
*/
|
|
6964
|
+
const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
|
|
6965
|
+
const u = uniform(key$1, shape$1);
|
|
6966
|
+
return negative(log(negative(log1p(negative(u)))));
|
|
6967
|
+
}, { staticArgnums: [1] });
|
|
6968
|
+
/**
|
|
6969
|
+
* @function
|
|
6970
|
+
* Sample from a Laplace distribution with location 0 and scale 1.
|
|
6971
|
+
*
|
|
6972
|
+
* Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
|
|
6973
|
+
* Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
|
|
6974
|
+
*/
|
|
6975
|
+
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
|
|
6976
|
+
const u = uniform(key$1, shape$1);
|
|
6977
|
+
const centered = u.sub(.5);
|
|
6978
|
+
const s = sign(centered.ref);
|
|
6979
|
+
const absVal = absolute(centered);
|
|
6980
|
+
return s.mul(log1p(absVal.mul(-2)).mul(-1));
|
|
6981
|
+
}, { staticArgnums: [1] });
|
|
6982
|
+
/**
|
|
6983
|
+
* @function
|
|
6300
6984
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6301
6985
|
*
|
|
6302
6986
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -6405,11 +7089,6 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
6405
7089
|
*/
|
|
6406
7090
|
const jacrev = jacrev$1;
|
|
6407
7091
|
/**
|
|
6408
|
-
* @function
|
|
6409
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
6410
|
-
*/
|
|
6411
|
-
const jacobian = jacrev;
|
|
6412
|
-
/**
|
|
6413
7092
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
6414
7093
|
*
|
|
6415
7094
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -6445,6 +7124,7 @@ async function devicePut(x, device) {
|
|
|
6445
7124
|
|
|
6446
7125
|
//#endregion
|
|
6447
7126
|
exports.Array = Array$1;
|
|
7127
|
+
exports.ClosedJaxpr = ClosedJaxpr;
|
|
6448
7128
|
exports.DType = require_backend.DType;
|
|
6449
7129
|
exports.Jaxpr = Jaxpr;
|
|
6450
7130
|
exports.blockUntilReady = blockUntilReady;
|
|
@@ -6454,7 +7134,7 @@ exports.devices = require_backend.devices;
|
|
|
6454
7134
|
exports.grad = grad;
|
|
6455
7135
|
exports.init = require_backend.init;
|
|
6456
7136
|
exports.jacfwd = jacfwd;
|
|
6457
|
-
exports.jacobian =
|
|
7137
|
+
exports.jacobian = jacrev;
|
|
6458
7138
|
exports.jacrev = jacrev;
|
|
6459
7139
|
exports.jit = jit;
|
|
6460
7140
|
exports.jvp = jvp;
|
|
@@ -6499,5 +7179,4 @@ Object.defineProperty(exports, 'tree', {
|
|
|
6499
7179
|
});
|
|
6500
7180
|
exports.valueAndGrad = valueAndGrad;
|
|
6501
7181
|
exports.vjp = vjp;
|
|
6502
|
-
exports.vmap = vmap;
|
|
6503
|
-
//# sourceMappingURL=index.cjs.map
|
|
7182
|
+
exports.vmap = vmap;
|