@jax-js/jax 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +5 -2
- package/dist/{backend-CmaidnkQ.cjs → backend-Bu9GY6sK.cjs} +166 -18
- package/dist/{backend-BY8wlLEl.js → backend-tngXtWe4.js} +148 -18
- package/dist/index.cjs +1683 -1004
- package/dist/index.d.cts +365 -95
- package/dist/index.d.ts +365 -95
- package/dist/index.js +1675 -997
- package/dist/{webgpu-C9iAP5h5.js → webgpu-ChVgx3b6.js} +400 -95
- package/dist/{webgpu-BVns4DbI.cjs → webgpu-Oj3Kd-kd.cjs} +400 -95
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-tngXtWe4.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
@@ -331,6 +331,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
331
331
|
Primitive$1["Mul"] = "mul";
|
|
332
332
|
Primitive$1["Idiv"] = "idiv";
|
|
333
333
|
Primitive$1["Mod"] = "mod";
|
|
334
|
+
Primitive$1["Min"] = "min";
|
|
335
|
+
Primitive$1["Max"] = "max";
|
|
334
336
|
Primitive$1["Neg"] = "neg";
|
|
335
337
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
336
338
|
Primitive$1["Floor"] = "floor";
|
|
@@ -338,7 +340,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
338
340
|
Primitive$1["StopGradient"] = "stop_gradient";
|
|
339
341
|
Primitive$1["Cast"] = "cast";
|
|
340
342
|
Primitive$1["Bitcast"] = "bitcast";
|
|
341
|
-
Primitive$1["RandomBits"] = "random_bits";
|
|
342
343
|
Primitive$1["Sin"] = "sin";
|
|
343
344
|
Primitive$1["Cos"] = "cos";
|
|
344
345
|
Primitive$1["Asin"] = "asin";
|
|
@@ -348,8 +349,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
348
349
|
Primitive$1["Erf"] = "erf";
|
|
349
350
|
Primitive$1["Erfc"] = "erfc";
|
|
350
351
|
Primitive$1["Sqrt"] = "sqrt";
|
|
351
|
-
Primitive$1["Min"] = "min";
|
|
352
|
-
Primitive$1["Max"] = "max";
|
|
353
352
|
Primitive$1["Reduce"] = "reduce";
|
|
354
353
|
Primitive$1["Dot"] = "dot";
|
|
355
354
|
Primitive$1["Conv"] = "conv";
|
|
@@ -357,14 +356,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
357
356
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
358
357
|
Primitive$1["Compare"] = "compare";
|
|
359
358
|
Primitive$1["Where"] = "where";
|
|
359
|
+
Primitive$1["RandomBits"] = "random_bits";
|
|
360
|
+
Primitive$1["Gather"] = "gather";
|
|
360
361
|
Primitive$1["Transpose"] = "transpose";
|
|
361
362
|
Primitive$1["Broadcast"] = "broadcast";
|
|
362
363
|
Primitive$1["Reshape"] = "reshape";
|
|
363
364
|
Primitive$1["Flip"] = "flip";
|
|
364
365
|
Primitive$1["Shrink"] = "shrink";
|
|
365
366
|
Primitive$1["Pad"] = "pad";
|
|
366
|
-
Primitive$1["
|
|
367
|
-
Primitive$1["
|
|
367
|
+
Primitive$1["Sort"] = "sort";
|
|
368
|
+
Primitive$1["Argsort"] = "argsort";
|
|
369
|
+
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
370
|
+
Primitive$1["Cholesky"] = "cholesky";
|
|
371
|
+
Primitive$1["Jit"] = "jit";
|
|
368
372
|
return Primitive$1;
|
|
369
373
|
}({});
|
|
370
374
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
@@ -386,6 +390,12 @@ function idiv(x, y) {
|
|
|
386
390
|
function mod(x, y) {
|
|
387
391
|
return bind1(Primitive.Mod, [x, y]);
|
|
388
392
|
}
|
|
393
|
+
function min$1(x, y) {
|
|
394
|
+
return bind1(Primitive.Min, [x, y]);
|
|
395
|
+
}
|
|
396
|
+
function max$1(x, y) {
|
|
397
|
+
return bind1(Primitive.Max, [x, y]);
|
|
398
|
+
}
|
|
389
399
|
function neg(x) {
|
|
390
400
|
return bind1(Primitive.Neg, [x]);
|
|
391
401
|
}
|
|
@@ -407,12 +417,6 @@ function cast(x, dtype) {
|
|
|
407
417
|
function bitcast(x, dtype) {
|
|
408
418
|
return bind1(Primitive.Bitcast, [x], { dtype });
|
|
409
419
|
}
|
|
410
|
-
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
411
|
-
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
412
|
-
shape: shape$1,
|
|
413
|
-
mode
|
|
414
|
-
});
|
|
415
|
-
}
|
|
416
420
|
function sin$1(x) {
|
|
417
421
|
return bind1(Primitive.Sin, [x]);
|
|
418
422
|
}
|
|
@@ -440,12 +444,6 @@ function erfc$1(x) {
|
|
|
440
444
|
function sqrt$1(x) {
|
|
441
445
|
return bind1(Primitive.Sqrt, [x]);
|
|
442
446
|
}
|
|
443
|
-
function min$1(x, y) {
|
|
444
|
-
return bind1(Primitive.Min, [x, y]);
|
|
445
|
-
}
|
|
446
|
-
function max$1(x, y) {
|
|
447
|
-
return bind1(Primitive.Max, [x, y]);
|
|
448
|
-
}
|
|
449
447
|
function reduce(x, op, axis = null, opts) {
|
|
450
448
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
451
449
|
axis = normalizeAxis(axis, ndim$1(x));
|
|
@@ -501,6 +499,23 @@ function where$1(cond, x, y) {
|
|
|
501
499
|
y
|
|
502
500
|
]);
|
|
503
501
|
}
|
|
502
|
+
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
503
|
+
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
504
|
+
shape: shape$1,
|
|
505
|
+
mode
|
|
506
|
+
});
|
|
507
|
+
}
|
|
508
|
+
function gather(x, indices, axis, outDim) {
|
|
509
|
+
if (indices.length === 0) throw new Error("gather() requires at least one index");
|
|
510
|
+
if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
|
|
511
|
+
axis = axis.map((a) => checkAxis(a, ndim$1(x)));
|
|
512
|
+
if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
|
|
513
|
+
outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
|
|
514
|
+
return bind1(Primitive.Gather, [x, ...indices], {
|
|
515
|
+
axis,
|
|
516
|
+
outDim
|
|
517
|
+
});
|
|
518
|
+
}
|
|
504
519
|
function transpose$1(x, perm) {
|
|
505
520
|
perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
|
|
506
521
|
if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
@@ -550,16 +565,27 @@ function pad$1(x, width) {
|
|
|
550
565
|
} else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
|
|
551
566
|
return bind1(Primitive.Pad, [x], { width });
|
|
552
567
|
}
|
|
553
|
-
function
|
|
554
|
-
if (
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
568
|
+
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
569
|
+
if (lower) {
|
|
570
|
+
a = flip$1(a, [-2, -1]);
|
|
571
|
+
b = flip$1(b, [-1]);
|
|
572
|
+
}
|
|
573
|
+
let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
574
|
+
if (lower) x = flip$1(x, [-1]);
|
|
575
|
+
return x;
|
|
576
|
+
}
|
|
577
|
+
function cholesky$2(x) {
|
|
578
|
+
return bind1(Primitive.Cholesky, [x]);
|
|
579
|
+
}
|
|
580
|
+
function sort$1(x) {
|
|
581
|
+
const nd = ndim$1(x);
|
|
582
|
+
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
583
|
+
return bind1(Primitive.Sort, [x]);
|
|
584
|
+
}
|
|
585
|
+
function argsort$1(x) {
|
|
586
|
+
const nd = ndim$1(x);
|
|
587
|
+
if (nd === 0) throw new Error("argsort: requires at least 1D input");
|
|
588
|
+
return bind(Primitive.Argsort, [x]);
|
|
563
589
|
}
|
|
564
590
|
function bind1(prim, args, params = {}) {
|
|
565
591
|
const [results] = bind(prim, args, params);
|
|
@@ -722,7 +748,7 @@ var Tracer = class Tracer {
|
|
|
722
748
|
if (isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
|
|
723
749
|
return idiv(this, other);
|
|
724
750
|
}
|
|
725
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
751
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
726
752
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
727
753
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
728
754
|
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
@@ -775,6 +801,34 @@ var Tracer = class Tracer {
|
|
|
775
801
|
this.dispose();
|
|
776
802
|
}
|
|
777
803
|
/**
|
|
804
|
+
* Return a sorted copy of an array in ascending order.
|
|
805
|
+
*
|
|
806
|
+
* See `jax.numpy.sort` for full docs.
|
|
807
|
+
*/
|
|
808
|
+
sort(axis = -1) {
|
|
809
|
+
axis = checkAxis(axis, this.ndim);
|
|
810
|
+
if (this.shape[axis] <= 1) return this;
|
|
811
|
+
if (axis === this.ndim - 1) return sort$1(this);
|
|
812
|
+
const perm = range(this.ndim);
|
|
813
|
+
perm.splice(axis, 1);
|
|
814
|
+
perm.push(axis);
|
|
815
|
+
return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
|
|
816
|
+
}
|
|
817
|
+
/**
|
|
818
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
819
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
820
|
+
*
|
|
821
|
+
* See `jax.numpy.argsort` for full docs.
|
|
822
|
+
*/
|
|
823
|
+
argsort(axis = -1) {
|
|
824
|
+
axis = checkAxis(axis, this.ndim);
|
|
825
|
+
if (axis === this.ndim - 1) return argsort$1(this)[1];
|
|
826
|
+
const perm = range(this.ndim);
|
|
827
|
+
perm.splice(axis, 1);
|
|
828
|
+
perm.push(axis);
|
|
829
|
+
return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
|
|
830
|
+
}
|
|
831
|
+
/**
|
|
778
832
|
* Slice an array along one or more axes.
|
|
779
833
|
*
|
|
780
834
|
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
@@ -891,6 +945,9 @@ var ShapedArray = class ShapedArray {
|
|
|
891
945
|
get ndim() {
|
|
892
946
|
return this.shape.length;
|
|
893
947
|
}
|
|
948
|
+
get size() {
|
|
949
|
+
return prod(this.shape);
|
|
950
|
+
}
|
|
894
951
|
toString() {
|
|
895
952
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
896
953
|
}
|
|
@@ -1186,13 +1243,13 @@ var Jaxpr = class Jaxpr {
|
|
|
1186
1243
|
}
|
|
1187
1244
|
return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
|
|
1188
1245
|
}
|
|
1189
|
-
/** Flattens nested
|
|
1246
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1190
1247
|
flatten() {
|
|
1191
|
-
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.
|
|
1248
|
+
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
|
|
1192
1249
|
const newEqns = [];
|
|
1193
1250
|
const varMap = /* @__PURE__ */ new Map();
|
|
1194
1251
|
const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
|
|
1195
|
-
for (const eqn of this.eqns) if (eqn.primitive === Primitive.
|
|
1252
|
+
for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
|
|
1196
1253
|
const jaxpr = eqn.params.jaxpr.flatten();
|
|
1197
1254
|
const translation = /* @__PURE__ */ new Map();
|
|
1198
1255
|
const translationF = (x) => x instanceof Var ? translation.get(x) : x;
|
|
@@ -1293,19 +1350,48 @@ function evalJaxpr(jaxpr, args) {
|
|
|
1293
1350
|
function jaxprAsFun(jaxpr) {
|
|
1294
1351
|
return (...args) => evalJaxpr(jaxpr, args);
|
|
1295
1352
|
}
|
|
1353
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1354
|
+
var ClosedJaxpr = class ClosedJaxpr {
|
|
1355
|
+
constructor(jaxpr, consts) {
|
|
1356
|
+
this.jaxpr = jaxpr;
|
|
1357
|
+
this.consts = consts;
|
|
1358
|
+
}
|
|
1359
|
+
/** String representation of this Jaxpr. */
|
|
1360
|
+
toString() {
|
|
1361
|
+
return this.jaxpr.toString();
|
|
1362
|
+
}
|
|
1363
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1364
|
+
mapJaxpr(f) {
|
|
1365
|
+
return new ClosedJaxpr(f(this.jaxpr), this.consts);
|
|
1366
|
+
}
|
|
1367
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1368
|
+
dispose() {
|
|
1369
|
+
for (const c of this.consts) c.dispose();
|
|
1370
|
+
}
|
|
1371
|
+
};
|
|
1296
1372
|
/** Tracer that records its operations to dynamically construct a Jaxpr. */
|
|
1297
1373
|
var JaxprTracer = class extends Tracer {
|
|
1374
|
+
#rc;
|
|
1298
1375
|
constructor(trace$1, aval) {
|
|
1299
1376
|
super(trace$1);
|
|
1300
1377
|
this.aval = aval;
|
|
1378
|
+
this.#rc = 1;
|
|
1301
1379
|
}
|
|
1302
1380
|
toString() {
|
|
1303
1381
|
return `JaxprTracer(${this.aval.toString()})`;
|
|
1304
1382
|
}
|
|
1305
1383
|
get ref() {
|
|
1384
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1385
|
+
this.#rc++;
|
|
1306
1386
|
return this;
|
|
1307
1387
|
}
|
|
1308
|
-
dispose() {
|
|
1388
|
+
dispose() {
|
|
1389
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1390
|
+
this.#rc--;
|
|
1391
|
+
}
|
|
1392
|
+
trackLiftedConstant() {
|
|
1393
|
+
this.#rc++;
|
|
1394
|
+
}
|
|
1309
1395
|
};
|
|
1310
1396
|
/** Analogous to the 'DynamicJaxprTrace' class in JAX. */
|
|
1311
1397
|
var JaxprTrace = class extends Trace {
|
|
@@ -1318,17 +1404,24 @@ var JaxprTrace = class extends Trace {
|
|
|
1318
1404
|
}
|
|
1319
1405
|
/** Register a constant / literal in this Jaxpr. */
|
|
1320
1406
|
getOrMakeConstTracer(val) {
|
|
1407
|
+
if (!(val instanceof Tracer)) val = pureArray(val);
|
|
1321
1408
|
let tracer = this.builder.constTracers.get(val);
|
|
1322
1409
|
if (tracer === void 0) {
|
|
1323
1410
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
1324
|
-
this.builder.addConst(tracer, val
|
|
1411
|
+
this.builder.addConst(tracer, val);
|
|
1412
|
+
} else {
|
|
1413
|
+
val.dispose();
|
|
1414
|
+
tracer.trackLiftedConstant();
|
|
1325
1415
|
}
|
|
1326
1416
|
return tracer;
|
|
1327
1417
|
}
|
|
1328
1418
|
pure = this.getOrMakeConstTracer;
|
|
1329
1419
|
lift = this.getOrMakeConstTracer;
|
|
1330
1420
|
processPrimitive(primitive, tracers, params) {
|
|
1331
|
-
const avalsIn = tracers.map((t) =>
|
|
1421
|
+
const avalsIn = tracers.map((t) => {
|
|
1422
|
+
t.dispose();
|
|
1423
|
+
return t.aval;
|
|
1424
|
+
});
|
|
1332
1425
|
const avalsOut = abstractEvalRules[primitive](avalsIn, params);
|
|
1333
1426
|
const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
|
|
1334
1427
|
this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
|
|
@@ -1371,20 +1464,17 @@ var JaxprBuilder = class {
|
|
|
1371
1464
|
return v;
|
|
1372
1465
|
}
|
|
1373
1466
|
build(inTracers, outTracers) {
|
|
1374
|
-
|
|
1467
|
+
const [constVars, consts] = unzip2(this.constVals.entries());
|
|
1375
1468
|
const t2v = this.getVar.bind(this);
|
|
1376
1469
|
const inBinders = [...constVars, ...inTracers.map(t2v)];
|
|
1377
1470
|
const outVars = outTracers.map(t2v);
|
|
1378
|
-
|
|
1471
|
+
const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
|
|
1379
1472
|
typecheckJaxpr(jaxpr);
|
|
1380
|
-
|
|
1381
|
-
return
|
|
1382
|
-
jaxpr,
|
|
1383
|
-
consts
|
|
1384
|
-
};
|
|
1473
|
+
const cjaxpr = new ClosedJaxpr(jaxpr, consts);
|
|
1474
|
+
return _inlineLiterals(cjaxpr);
|
|
1385
1475
|
}
|
|
1386
1476
|
};
|
|
1387
|
-
function _inlineLiterals(jaxpr, consts) {
|
|
1477
|
+
function _inlineLiterals({ jaxpr, consts }) {
|
|
1388
1478
|
const literals = /* @__PURE__ */ new Map();
|
|
1389
1479
|
const constBinders = [];
|
|
1390
1480
|
const newConsts = [];
|
|
@@ -1399,7 +1489,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
1399
1489
|
const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
|
|
1400
1490
|
const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
|
|
1401
1491
|
typecheckJaxpr(newJaxpr);
|
|
1402
|
-
return
|
|
1492
|
+
return new ClosedJaxpr(newJaxpr, newConsts);
|
|
1403
1493
|
}
|
|
1404
1494
|
function binopAbstractEval([x, y]) {
|
|
1405
1495
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
@@ -1418,6 +1508,8 @@ const abstractEvalRules = {
|
|
|
1418
1508
|
[Primitive.Mul]: binopAbstractEval,
|
|
1419
1509
|
[Primitive.Idiv]: binopAbstractEval,
|
|
1420
1510
|
[Primitive.Mod]: binopAbstractEval,
|
|
1511
|
+
[Primitive.Min]: binopAbstractEval,
|
|
1512
|
+
[Primitive.Max]: binopAbstractEval,
|
|
1421
1513
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1422
1514
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1423
1515
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1431,12 +1523,6 @@ const abstractEvalRules = {
|
|
|
1431
1523
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1432
1524
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1433
1525
|
},
|
|
1434
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1435
|
-
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1436
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1437
|
-
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1438
|
-
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1439
|
-
},
|
|
1440
1526
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
1441
1527
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
1442
1528
|
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
@@ -1446,8 +1532,6 @@ const abstractEvalRules = {
|
|
|
1446
1532
|
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
1447
1533
|
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
1448
1534
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
1449
|
-
[Primitive.Min]: binopAbstractEval,
|
|
1450
|
-
[Primitive.Max]: binopAbstractEval,
|
|
1451
1535
|
[Primitive.Reduce]([x], { axis }) {
|
|
1452
1536
|
const axisSet = new Set(axis);
|
|
1453
1537
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1480,6 +1564,25 @@ const abstractEvalRules = {
|
|
|
1480
1564
|
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
1481
1565
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1482
1566
|
},
|
|
1567
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1568
|
+
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1569
|
+
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1570
|
+
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1571
|
+
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1572
|
+
},
|
|
1573
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
1574
|
+
for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
1575
|
+
if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
|
|
1576
|
+
if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
|
|
1577
|
+
if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
|
|
1578
|
+
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
1579
|
+
const axisSet = new Set(axis);
|
|
1580
|
+
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
1581
|
+
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
1582
|
+
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
1583
|
+
newShape.splice(outDim, 0, ...gatherShape);
|
|
1584
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1585
|
+
},
|
|
1483
1586
|
[Primitive.Transpose]([x], { perm }) {
|
|
1484
1587
|
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
1485
1588
|
},
|
|
@@ -1500,23 +1603,31 @@ const abstractEvalRules = {
|
|
|
1500
1603
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
1501
1604
|
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1502
1605
|
},
|
|
1503
|
-
[Primitive.
|
|
1504
|
-
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
if (
|
|
1509
|
-
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1606
|
+
[Primitive.Sort]([x]) {
|
|
1607
|
+
if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
|
|
1608
|
+
return [ShapedArray.fromAval(x)];
|
|
1609
|
+
},
|
|
1610
|
+
[Primitive.Argsort]([x]) {
|
|
1611
|
+
if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
|
|
1612
|
+
return [ShapedArray.fromAval(x), new ShapedArray(x.shape, DType.Int32, false)];
|
|
1613
|
+
},
|
|
1614
|
+
[Primitive.TriangularSolve]([a, b]) {
|
|
1615
|
+
if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
|
|
1616
|
+
if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
|
|
1617
|
+
const [m, n] = a.shape.slice(-2);
|
|
1618
|
+
const [_batch, q] = b.shape.slice(-2);
|
|
1619
|
+
if (!deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
|
|
1620
|
+
return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
|
|
1621
|
+
},
|
|
1622
|
+
[Primitive.Cholesky]([a]) {
|
|
1623
|
+
if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
|
|
1624
|
+
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1625
|
+
return [ShapedArray.fromAval(a)];
|
|
1515
1626
|
},
|
|
1516
|
-
[Primitive.
|
|
1627
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
1517
1628
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1518
|
-
if (args.length !== inTypes.length) throw new TypeError(`
|
|
1519
|
-
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`
|
|
1629
|
+
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
1630
|
+
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
|
|
1520
1631
|
return outTypes;
|
|
1521
1632
|
}
|
|
1522
1633
|
};
|
|
@@ -1552,11 +1663,10 @@ function makeJaxpr$1(f, opts) {
|
|
|
1552
1663
|
const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
1553
1664
|
const outs = fFlat(...tracersIn);
|
|
1554
1665
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
1555
|
-
const
|
|
1666
|
+
const jaxpr = builder.build(tracersIn, tracersOut);
|
|
1556
1667
|
if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
|
|
1557
1668
|
return {
|
|
1558
|
-
jaxpr: jaxpr.simplify(),
|
|
1559
|
-
consts,
|
|
1669
|
+
jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
|
|
1560
1670
|
treedef: outTree.value
|
|
1561
1671
|
};
|
|
1562
1672
|
} catch (_) {
|
|
@@ -1575,22 +1685,28 @@ function jit$1(f, opts) {
|
|
|
1575
1685
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
1576
1686
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
1577
1687
|
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
1578
|
-
const { jaxpr,
|
|
1579
|
-
const outs = bind(Primitive.
|
|
1688
|
+
const { jaxpr, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
1689
|
+
const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
|
|
1580
1690
|
name: f.name || "closure",
|
|
1581
|
-
jaxpr,
|
|
1582
|
-
numConsts: consts.length
|
|
1691
|
+
jaxpr: jaxpr.jaxpr,
|
|
1692
|
+
numConsts: jaxpr.consts.length
|
|
1583
1693
|
});
|
|
1584
1694
|
return unflatten(outTree, outs);
|
|
1585
1695
|
});
|
|
1586
1696
|
result.dispose = () => {
|
|
1587
|
-
for (const {
|
|
1697
|
+
for (const { jaxpr } of cache.values()) jaxpr.dispose();
|
|
1588
1698
|
};
|
|
1589
1699
|
return result;
|
|
1590
1700
|
}
|
|
1591
1701
|
|
|
1592
1702
|
//#endregion
|
|
1593
1703
|
//#region src/frontend/jit.ts
|
|
1704
|
+
const routinePrimitives = new Map([
|
|
1705
|
+
[Primitive.Sort, Routines.Sort],
|
|
1706
|
+
[Primitive.Argsort, Routines.Argsort],
|
|
1707
|
+
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1708
|
+
[Primitive.Cholesky, Routines.Cholesky]
|
|
1709
|
+
]);
|
|
1594
1710
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1595
1711
|
var JitProgram = class {
|
|
1596
1712
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1605,9 +1721,14 @@ var JitProgram = class {
|
|
|
1605
1721
|
case "execute": {
|
|
1606
1722
|
const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
|
|
1607
1723
|
const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
|
|
1608
|
-
|
|
1724
|
+
const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
|
|
1725
|
+
if (step.source instanceof Kernel) return PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
|
|
1726
|
+
else if (step.source instanceof Routine) return PPrint.pp(`${executeText}, routine ${step.source.name}`);
|
|
1727
|
+
else {
|
|
1728
|
+
step.source;
|
|
1729
|
+
return PPrint.pp(executeText);
|
|
1730
|
+
}
|
|
1609
1731
|
}
|
|
1610
|
-
case "const": return PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
|
|
1611
1732
|
case "malloc": return PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
|
|
1612
1733
|
case "incref": return PPrint.pp(`incref ${step.input}`);
|
|
1613
1734
|
case "free": return PPrint.pp(`free ${step.input}`);
|
|
@@ -1630,12 +1751,9 @@ var JitProgram = class {
|
|
|
1630
1751
|
const inputs$1 = step.inputs.map((id) => scope.get(id));
|
|
1631
1752
|
const outputs = step.outputs.map((id) => scope.get(id));
|
|
1632
1753
|
if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
|
|
1633
|
-
pending.push(new PendingExecute(this.backend, step.
|
|
1754
|
+
pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
|
|
1634
1755
|
break;
|
|
1635
1756
|
}
|
|
1636
|
-
case "const":
|
|
1637
|
-
scope.set(step.output, step.slot);
|
|
1638
|
-
break;
|
|
1639
1757
|
case "malloc": {
|
|
1640
1758
|
const slot = this.backend.malloc(step.size);
|
|
1641
1759
|
scope.set(step.output, slot);
|
|
@@ -1669,34 +1787,37 @@ var JitProgramBuilder = class {
|
|
|
1669
1787
|
this.#nextId = nargs;
|
|
1670
1788
|
this.steps = [];
|
|
1671
1789
|
}
|
|
1672
|
-
pushConst(slot) {
|
|
1673
|
-
const id = this.#nextId++;
|
|
1674
|
-
this.steps.push({
|
|
1675
|
-
type: "const",
|
|
1676
|
-
slot,
|
|
1677
|
-
output: id
|
|
1678
|
-
});
|
|
1679
|
-
return id;
|
|
1680
|
-
}
|
|
1681
1790
|
pushLit(lit) {
|
|
1682
|
-
const kernel = new Kernel(0,
|
|
1791
|
+
const kernel = new Kernel(0, lit.aval.size, AluExp.const(lit.dtype, lit.value));
|
|
1683
1792
|
return this.pushKernel(kernel, []);
|
|
1684
1793
|
}
|
|
1685
|
-
|
|
1794
|
+
pushBuffer(size$1) {
|
|
1686
1795
|
const id = this.#nextId++;
|
|
1687
1796
|
this.steps.push({
|
|
1688
1797
|
type: "malloc",
|
|
1689
|
-
size:
|
|
1798
|
+
size: size$1,
|
|
1690
1799
|
output: id
|
|
1691
1800
|
});
|
|
1801
|
+
return id;
|
|
1802
|
+
}
|
|
1803
|
+
pushKernel(kernel, inputs) {
|
|
1804
|
+
const id = this.pushBuffer(kernel.bytes);
|
|
1692
1805
|
this.steps.push({
|
|
1693
1806
|
type: "execute",
|
|
1694
|
-
kernel,
|
|
1807
|
+
source: kernel,
|
|
1695
1808
|
inputs,
|
|
1696
1809
|
outputs: [id]
|
|
1697
1810
|
});
|
|
1698
1811
|
return id;
|
|
1699
1812
|
}
|
|
1813
|
+
pushRoutine(routine, inputs, outputs) {
|
|
1814
|
+
this.steps.push({
|
|
1815
|
+
type: "execute",
|
|
1816
|
+
source: routine,
|
|
1817
|
+
inputs,
|
|
1818
|
+
outputs
|
|
1819
|
+
});
|
|
1820
|
+
}
|
|
1700
1821
|
pushIncref(id) {
|
|
1701
1822
|
this.steps.push({
|
|
1702
1823
|
type: "incref",
|
|
@@ -1722,28 +1843,18 @@ var JitProgramBuilder = class {
|
|
|
1722
1843
|
}
|
|
1723
1844
|
};
|
|
1724
1845
|
const jitCompileCache = /* @__PURE__ */ new Map();
|
|
1725
|
-
function jitCompile(backend, jaxpr
|
|
1726
|
-
|
|
1727
|
-
for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
|
|
1728
|
-
const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
|
|
1846
|
+
function jitCompile(backend, jaxpr) {
|
|
1847
|
+
const cacheKey = backend.type + "," + FpHash.hash(jaxpr);
|
|
1729
1848
|
const cached = jitCompileCache.get(cacheKey);
|
|
1730
1849
|
if (cached) return cached;
|
|
1731
1850
|
if (DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
|
|
1732
1851
|
jaxpr = jaxpr.flatten().simplify();
|
|
1733
|
-
const nargs = jaxpr.inBinders.length
|
|
1852
|
+
const nargs = jaxpr.inBinders.length;
|
|
1734
1853
|
const builder = new JitProgramBuilder(backend, nargs);
|
|
1735
1854
|
const blackNodes = splitGraphDataflow(backend, jaxpr);
|
|
1736
1855
|
const ctx = /* @__PURE__ */ new Map();
|
|
1737
|
-
for (let i = 0; i < consts.length; i++) {
|
|
1738
|
-
const v = jaxpr.inBinders[i];
|
|
1739
|
-
const slot = consts[i]._realizeSource();
|
|
1740
|
-
ctx.set(v, {
|
|
1741
|
-
type: "imm",
|
|
1742
|
-
arg: builder.pushConst(slot)
|
|
1743
|
-
});
|
|
1744
|
-
}
|
|
1745
1856
|
for (let i = 0; i < nargs; i++) {
|
|
1746
|
-
const v = jaxpr.inBinders[
|
|
1857
|
+
const v = jaxpr.inBinders[i];
|
|
1747
1858
|
ctx.set(v, {
|
|
1748
1859
|
type: "imm",
|
|
1749
1860
|
arg: i
|
|
@@ -1751,6 +1862,31 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1751
1862
|
}
|
|
1752
1863
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1753
1864
|
const eqn = jaxpr.eqns[i];
|
|
1865
|
+
if (routinePrimitives.has(eqn.primitive)) {
|
|
1866
|
+
const routine = new Routine(routinePrimitives.get(eqn.primitive), {
|
|
1867
|
+
inputShapes: eqn.inputs.map((x) => x.aval.shape),
|
|
1868
|
+
inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
|
|
1869
|
+
outputShapes: eqn.outBinders.map((x) => x.aval.shape),
|
|
1870
|
+
outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
|
|
1871
|
+
}, eqn.params);
|
|
1872
|
+
const inputs = [];
|
|
1873
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1874
|
+
const jv = ctx.get(input);
|
|
1875
|
+
if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
|
|
1876
|
+
inputs.push(jv.arg);
|
|
1877
|
+
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1878
|
+
const outputs = [];
|
|
1879
|
+
for (const outVar$1 of eqn.outBinders) {
|
|
1880
|
+
const outId = builder.pushBuffer(outVar$1.aval.size * byteWidth(outVar$1.aval.dtype));
|
|
1881
|
+
outputs.push(outId);
|
|
1882
|
+
ctx.set(outVar$1, {
|
|
1883
|
+
type: "imm",
|
|
1884
|
+
arg: outId
|
|
1885
|
+
});
|
|
1886
|
+
}
|
|
1887
|
+
builder.pushRoutine(routine, inputs, outputs);
|
|
1888
|
+
continue;
|
|
1889
|
+
}
|
|
1754
1890
|
const inputExps = [];
|
|
1755
1891
|
const inputAvals = [];
|
|
1756
1892
|
const inputArgs = [];
|
|
@@ -1805,7 +1941,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1805
1941
|
const outVar = eqn.outBinders[0];
|
|
1806
1942
|
if (blackNodes.has(outVar)) {
|
|
1807
1943
|
const nargs$1 = inputArgs.length;
|
|
1808
|
-
const size$1 =
|
|
1944
|
+
const size$1 = outVar.aval.size;
|
|
1809
1945
|
const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
|
|
1810
1946
|
const outId = builder.pushKernel(kernel, inputArgs);
|
|
1811
1947
|
ctx.set(outVar, {
|
|
@@ -1830,7 +1966,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1830
1966
|
if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
|
|
1831
1967
|
outputIds.push(jitValue.arg);
|
|
1832
1968
|
} else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
|
|
1833
|
-
const outputNeedsRef = new Set(
|
|
1969
|
+
const outputNeedsRef = new Set(range(nargs));
|
|
1834
1970
|
for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
|
|
1835
1971
|
else outputNeedsRef.add(outputId);
|
|
1836
1972
|
builder.insertFreeSteps(outputIds);
|
|
@@ -1876,11 +2012,18 @@ function reshapeJit(fn) {
|
|
|
1876
2012
|
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
1877
2013
|
};
|
|
1878
2014
|
}
|
|
2015
|
+
function routineNoJit() {
|
|
2016
|
+
return () => {
|
|
2017
|
+
throw new Error("jit: rule is not implemented for routines");
|
|
2018
|
+
};
|
|
2019
|
+
}
|
|
1879
2020
|
const jitRules = {
|
|
1880
2021
|
[Primitive.Add]: broadcastedJit(([a, b]) => AluExp.add(a, b)),
|
|
1881
2022
|
[Primitive.Mul]: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
|
|
1882
2023
|
[Primitive.Idiv]: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
|
|
1883
2024
|
[Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
|
|
2025
|
+
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
2026
|
+
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1884
2027
|
[Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
|
|
1885
2028
|
[Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
|
|
1886
2029
|
[Primitive.Floor]: unopJit(AluExp.floor),
|
|
@@ -1888,17 +2031,6 @@ const jitRules = {
|
|
|
1888
2031
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1889
2032
|
[Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
|
|
1890
2033
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
|
|
1891
|
-
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1892
|
-
const mapping = (st) => {
|
|
1893
|
-
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
|
|
1894
|
-
};
|
|
1895
|
-
const k0 = reshapeViews(keys[0], mapping);
|
|
1896
|
-
const k1 = reshapeViews(keys[1], mapping);
|
|
1897
|
-
const c0 = AluExp.u32(0);
|
|
1898
|
-
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1899
|
-
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1900
|
-
return { exp: exp$2 };
|
|
1901
|
-
},
|
|
1902
2034
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
1903
2035
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
1904
2036
|
[Primitive.Asin]: unopJit(AluExp.asin),
|
|
@@ -1908,8 +2040,6 @@ const jitRules = {
|
|
|
1908
2040
|
[Primitive.Erf]: unopJit(AluExp.erf),
|
|
1909
2041
|
[Primitive.Erfc]: unopJit(AluExp.erfc),
|
|
1910
2042
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1911
|
-
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
1912
|
-
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1913
2043
|
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1914
2044
|
const keptAxes = [];
|
|
1915
2045
|
const shiftedAxes = [];
|
|
@@ -1959,16 +2089,17 @@ const jitRules = {
|
|
|
1959
2089
|
},
|
|
1960
2090
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1961
2091
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1962
|
-
[Primitive.
|
|
1963
|
-
|
|
1964
|
-
|
|
1965
|
-
|
|
1966
|
-
const
|
|
1967
|
-
|
|
1968
|
-
|
|
1969
|
-
|
|
1970
|
-
|
|
1971
|
-
|
|
2092
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2093
|
+
const mapping = (st) => {
|
|
2094
|
+
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
|
|
2095
|
+
};
|
|
2096
|
+
const k0 = reshapeViews(keys[0], mapping);
|
|
2097
|
+
const k1 = reshapeViews(keys[1], mapping);
|
|
2098
|
+
const c0 = AluExp.u32(0);
|
|
2099
|
+
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
2100
|
+
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2101
|
+
return { exp: exp$2 };
|
|
2102
|
+
},
|
|
1972
2103
|
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1973
2104
|
const axisSet = new Set(axis);
|
|
1974
2105
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
@@ -1984,8 +2115,22 @@ const jitRules = {
|
|
|
1984
2115
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1985
2116
|
return { exp: x.substitute({ gidx: index }) };
|
|
1986
2117
|
},
|
|
1987
|
-
[Primitive.
|
|
1988
|
-
|
|
2118
|
+
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2119
|
+
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
2120
|
+
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
2121
|
+
[Primitive.Flip]: reshapeJit((st, { axis }) => {
|
|
2122
|
+
const arg = rep(st.shape.length, false);
|
|
2123
|
+
for (const ax of axis) arg[ax] = true;
|
|
2124
|
+
return st.flip(arg);
|
|
2125
|
+
}),
|
|
2126
|
+
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
2127
|
+
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
2128
|
+
[Primitive.Sort]: routineNoJit(),
|
|
2129
|
+
[Primitive.Argsort]: routineNoJit(),
|
|
2130
|
+
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2131
|
+
[Primitive.Cholesky]: routineNoJit(),
|
|
2132
|
+
[Primitive.Jit]() {
|
|
2133
|
+
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
1989
2134
|
}
|
|
1990
2135
|
};
|
|
1991
2136
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
@@ -2043,8 +2188,8 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2043
2188
|
case Primitive.Mul:
|
|
2044
2189
|
case Primitive.Idiv:
|
|
2045
2190
|
case Primitive.Mod:
|
|
2046
|
-
case Primitive.
|
|
2047
|
-
case Primitive.
|
|
2191
|
+
case Primitive.Min:
|
|
2192
|
+
case Primitive.Max: {
|
|
2048
2193
|
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2049
2194
|
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2050
2195
|
head = usages[0];
|
|
@@ -2064,11 +2209,11 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2064
2209
|
blackNodes.add(v);
|
|
2065
2210
|
p1NextBlack.set(v, v);
|
|
2066
2211
|
}
|
|
2067
|
-
const heterogeneousViewPrimitives = [Primitive.
|
|
2212
|
+
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2068
2213
|
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2069
2214
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2070
2215
|
const eqn = jaxpr.eqns[i];
|
|
2071
|
-
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2216
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2072
2217
|
for (const v of eqn.outBinders) {
|
|
2073
2218
|
blackNodes.add(v);
|
|
2074
2219
|
p1NextBlack.set(v, v);
|
|
@@ -2078,7 +2223,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2078
2223
|
const reach = /* @__PURE__ */ new Set();
|
|
2079
2224
|
let needsCleanOutput = false;
|
|
2080
2225
|
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2081
|
-
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive)) {
|
|
2226
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
|
|
2082
2227
|
needsCleanOutput = true;
|
|
2083
2228
|
break outer;
|
|
2084
2229
|
}
|
|
@@ -2102,7 +2247,6 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2102
2247
|
while (p2idx < jaxpr.eqns.length) {
|
|
2103
2248
|
const eqn = jaxpr.eqns[p2idx++];
|
|
2104
2249
|
const deps = [];
|
|
2105
|
-
if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
|
|
2106
2250
|
for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
|
|
2107
2251
|
else deps.push(p2Deps.get(input));
|
|
2108
2252
|
else deps.push(/* @__PURE__ */ new Set());
|
|
@@ -2125,7 +2269,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2125
2269
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2126
2270
|
const assocVar = eqn.inputs[assocInput];
|
|
2127
2271
|
p2idx = varToDefn.get(assocVar);
|
|
2128
|
-
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2272
|
+
for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
|
|
2129
2273
|
} else {
|
|
2130
2274
|
const s = new Set(depCounter.keys());
|
|
2131
2275
|
for (const out of eqn.outBinders) p2Deps.set(out, s);
|
|
@@ -2151,9 +2295,9 @@ var PendingExecute = class {
|
|
|
2151
2295
|
submitted = false;
|
|
2152
2296
|
#promise = null;
|
|
2153
2297
|
#rc = 1;
|
|
2154
|
-
constructor(backend,
|
|
2298
|
+
constructor(backend, source, inputs, outputs) {
|
|
2155
2299
|
this.backend = backend;
|
|
2156
|
-
this.
|
|
2300
|
+
this.source = source;
|
|
2157
2301
|
this.inputs = inputs;
|
|
2158
2302
|
this.outputs = outputs;
|
|
2159
2303
|
for (const slot of inputs) this.backend.incRef(slot);
|
|
@@ -2174,13 +2318,15 @@ var PendingExecute = class {
|
|
|
2174
2318
|
return;
|
|
2175
2319
|
}
|
|
2176
2320
|
this.#promise = (async () => {
|
|
2177
|
-
this.prepared = await this.backend.
|
|
2321
|
+
if (this.source instanceof Kernel) this.prepared = await this.backend.prepareKernel(this.source);
|
|
2322
|
+
else this.prepared = await this.backend.prepareRoutine(this.source);
|
|
2178
2323
|
})();
|
|
2179
2324
|
await this.#promise;
|
|
2180
2325
|
}
|
|
2181
2326
|
prepareSync() {
|
|
2182
2327
|
if (this.prepared) return;
|
|
2183
|
-
this.prepared = this.backend.
|
|
2328
|
+
if (this.source instanceof Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
|
|
2329
|
+
else this.prepared = this.backend.prepareRoutineSync(this.source);
|
|
2184
2330
|
}
|
|
2185
2331
|
submit() {
|
|
2186
2332
|
if (this.submitted) return;
|
|
@@ -2203,8 +2349,6 @@ var PendingExecute = class {
|
|
|
2203
2349
|
* "Array" type by name.
|
|
2204
2350
|
*/
|
|
2205
2351
|
var Array$1 = class Array$1 extends Tracer {
|
|
2206
|
-
static #nextId = 1001;
|
|
2207
|
-
id;
|
|
2208
2352
|
#dtype;
|
|
2209
2353
|
#weakType;
|
|
2210
2354
|
#source;
|
|
@@ -2221,7 +2365,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2221
2365
|
*/
|
|
2222
2366
|
constructor(args) {
|
|
2223
2367
|
super(baseArrayTrace);
|
|
2224
|
-
this.id = Array$1.#nextId++;
|
|
2225
2368
|
this.#dtype = args.dtype;
|
|
2226
2369
|
this.#weakType = args.weakType;
|
|
2227
2370
|
this.#source = args.source;
|
|
@@ -2530,6 +2673,27 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2530
2673
|
pending
|
|
2531
2674
|
});
|
|
2532
2675
|
}
|
|
2676
|
+
/** Apply an operation with custom lowering to this array. */
|
|
2677
|
+
static #routine(routine, arrays, outputWeakType) {
|
|
2678
|
+
const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
|
|
2679
|
+
for (const ar of arrays) ar.#realize();
|
|
2680
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2681
|
+
const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
|
|
2682
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2683
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2684
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2685
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2686
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2687
|
+
return outputs.map((output, i) => new Array$1({
|
|
2688
|
+
source: output,
|
|
2689
|
+
st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
|
|
2690
|
+
dtype: routine.type.outputDtypes[i],
|
|
2691
|
+
weakType: outputWeakType[i],
|
|
2692
|
+
backend,
|
|
2693
|
+
committed,
|
|
2694
|
+
pending
|
|
2695
|
+
}));
|
|
2696
|
+
}
|
|
2533
2697
|
/**
|
|
2534
2698
|
* Normalizes this array into one backed by a `Slot`.
|
|
2535
2699
|
*
|
|
@@ -2690,6 +2854,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2690
2854
|
[Primitive.Mod]([x, y]) {
|
|
2691
2855
|
return [x.#binary(AluOp.Mod, y)];
|
|
2692
2856
|
},
|
|
2857
|
+
[Primitive.Min]([x, y]) {
|
|
2858
|
+
return [x.#binary(AluOp.Min, y)];
|
|
2859
|
+
},
|
|
2860
|
+
[Primitive.Max]([x, y]) {
|
|
2861
|
+
return [x.#binary(AluOp.Max, y)];
|
|
2862
|
+
},
|
|
2693
2863
|
[Primitive.Neg]([x]) {
|
|
2694
2864
|
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
2695
2865
|
},
|
|
@@ -2726,25 +2896,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2726
2896
|
return [y];
|
|
2727
2897
|
}
|
|
2728
2898
|
},
|
|
2729
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2730
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2731
|
-
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2732
|
-
const c0 = zeros(shape$1, {
|
|
2733
|
-
dtype: DType.Uint32,
|
|
2734
|
-
device: k0.device
|
|
2735
|
-
});
|
|
2736
|
-
const c1 = arange(0, prod(shape$1), 1, {
|
|
2737
|
-
dtype: DType.Uint32,
|
|
2738
|
-
device: k0.device
|
|
2739
|
-
}).reshape(shape$1);
|
|
2740
|
-
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2741
|
-
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2742
|
-
k0,
|
|
2743
|
-
k1,
|
|
2744
|
-
c0,
|
|
2745
|
-
c1
|
|
2746
|
-
])];
|
|
2747
|
-
},
|
|
2748
2899
|
[Primitive.Sin]([x]) {
|
|
2749
2900
|
return [x.#unary(AluOp.Sin)];
|
|
2750
2901
|
},
|
|
@@ -2772,12 +2923,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2772
2923
|
[Primitive.Sqrt]([x]) {
|
|
2773
2924
|
return [x.#unary(AluOp.Sqrt)];
|
|
2774
2925
|
},
|
|
2775
|
-
[Primitive.Min]([x, y]) {
|
|
2776
|
-
return [x.#binary(AluOp.Min, y)];
|
|
2777
|
-
},
|
|
2778
|
-
[Primitive.Max]([x, y]) {
|
|
2779
|
-
return [x.#binary(AluOp.Max, y)];
|
|
2780
|
-
},
|
|
2781
2926
|
[Primitive.Reduce]([x], { op, axis }) {
|
|
2782
2927
|
if (axis.length === 0) return [x];
|
|
2783
2928
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
@@ -2812,13 +2957,35 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2812
2957
|
y
|
|
2813
2958
|
], { dtypeOverride: [DType.Bool] })];
|
|
2814
2959
|
},
|
|
2815
|
-
[Primitive.
|
|
2816
|
-
|
|
2817
|
-
|
|
2818
|
-
|
|
2819
|
-
|
|
2820
|
-
|
|
2821
|
-
|
|
2960
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2961
|
+
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2962
|
+
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2963
|
+
const c0 = zeros(shape$1, {
|
|
2964
|
+
dtype: DType.Uint32,
|
|
2965
|
+
device: k0.device
|
|
2966
|
+
});
|
|
2967
|
+
const c1 = arange(0, prod(shape$1), 1, {
|
|
2968
|
+
dtype: DType.Uint32,
|
|
2969
|
+
device: k0.device
|
|
2970
|
+
}).reshape(shape$1);
|
|
2971
|
+
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2972
|
+
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2973
|
+
k0,
|
|
2974
|
+
k1,
|
|
2975
|
+
c0,
|
|
2976
|
+
c1
|
|
2977
|
+
])];
|
|
2978
|
+
},
|
|
2979
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2980
|
+
return [x.#gather(indices, axis, outDim)];
|
|
2981
|
+
},
|
|
2982
|
+
[Primitive.Transpose]([x], { perm }) {
|
|
2983
|
+
return [x.#transpose(perm)];
|
|
2984
|
+
},
|
|
2985
|
+
[Primitive.Broadcast]([x], { shape: shape$1, axis }) {
|
|
2986
|
+
return [x.#reshape(x.#st.broadcast(shape$1, axis))];
|
|
2987
|
+
},
|
|
2988
|
+
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2822
2989
|
return [x.#reshape(x.#st.reshape(shape$1))];
|
|
2823
2990
|
},
|
|
2824
2991
|
[Primitive.Flip]([x], { axis }) {
|
|
@@ -2832,17 +2999,48 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2832
2999
|
[Primitive.Pad]([x], { width }) {
|
|
2833
3000
|
return [x.#reshape(x.#st.pad(width))];
|
|
2834
3001
|
},
|
|
2835
|
-
[Primitive.
|
|
2836
|
-
|
|
3002
|
+
[Primitive.Sort]([x]) {
|
|
3003
|
+
const routine = new Routine(Routines.Sort, {
|
|
3004
|
+
inputShapes: [x.aval.shape],
|
|
3005
|
+
inputDtypes: [x.aval.dtype],
|
|
3006
|
+
outputShapes: [x.aval.shape],
|
|
3007
|
+
outputDtypes: [x.aval.dtype]
|
|
3008
|
+
});
|
|
3009
|
+
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
2837
3010
|
},
|
|
2838
|
-
[Primitive.
|
|
2839
|
-
|
|
2840
|
-
|
|
3011
|
+
[Primitive.Argsort]([x]) {
|
|
3012
|
+
const routine = new Routine(Routines.Argsort, {
|
|
3013
|
+
inputShapes: [x.aval.shape],
|
|
3014
|
+
inputDtypes: [x.aval.dtype],
|
|
3015
|
+
outputShapes: [x.aval.shape, x.aval.shape],
|
|
3016
|
+
outputDtypes: [x.aval.dtype, DType.Int32]
|
|
3017
|
+
});
|
|
3018
|
+
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3019
|
+
},
|
|
3020
|
+
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3021
|
+
const routine = new Routine(Routines.TriangularSolve, {
|
|
3022
|
+
inputShapes: [a.aval.shape, b.aval.shape],
|
|
3023
|
+
inputDtypes: [a.aval.dtype, b.aval.dtype],
|
|
3024
|
+
outputShapes: [b.aval.shape],
|
|
3025
|
+
outputDtypes: [b.aval.dtype]
|
|
3026
|
+
}, { unitDiagonal });
|
|
3027
|
+
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
3028
|
+
},
|
|
3029
|
+
[Primitive.Cholesky]([a]) {
|
|
3030
|
+
const routine = new Routine(Routines.Cholesky, {
|
|
3031
|
+
inputShapes: [a.aval.shape],
|
|
3032
|
+
inputDtypes: [a.aval.dtype],
|
|
3033
|
+
outputShapes: [a.aval.shape],
|
|
3034
|
+
outputDtypes: [a.aval.dtype]
|
|
3035
|
+
});
|
|
3036
|
+
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3037
|
+
},
|
|
3038
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
3039
|
+
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3040
|
+
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
2841
3041
|
args = args.map((ar) => ar._putSync(backend));
|
|
2842
|
-
const
|
|
2843
|
-
const
|
|
2844
|
-
const jp = jitCompile(backend, jaxpr, consts);
|
|
2845
|
-
const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
|
|
3042
|
+
const jp = jitCompile(backend, jaxpr);
|
|
3043
|
+
const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
|
|
2846
3044
|
for (const exe of pending) exe.updateRc(+outputs.length - 1);
|
|
2847
3045
|
const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
|
|
2848
3046
|
for (const exe of prevPending) exe.updateRc(+outputs.length);
|
|
@@ -3141,6 +3339,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
3141
3339
|
});
|
|
3142
3340
|
}
|
|
3143
3341
|
/**
|
|
3342
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
3343
|
+
*
|
|
3344
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
3345
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
3346
|
+
* `k>0` is above it.
|
|
3347
|
+
*/
|
|
3348
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
3349
|
+
m ??= n;
|
|
3350
|
+
dtype ??= DType.Float32;
|
|
3351
|
+
if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
|
|
3352
|
+
if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
|
|
3353
|
+
if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
|
|
3354
|
+
const rows = arange(k, n + k, 1, {
|
|
3355
|
+
dtype: DType.Int32,
|
|
3356
|
+
device
|
|
3357
|
+
});
|
|
3358
|
+
const cols = arange(0, m, 1, {
|
|
3359
|
+
dtype: DType.Int32,
|
|
3360
|
+
device
|
|
3361
|
+
});
|
|
3362
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
3363
|
+
}
|
|
3364
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
3365
|
+
function tril(a, k = 0) {
|
|
3366
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3367
|
+
a = fudgeArray(a);
|
|
3368
|
+
const [n, m] = a.shape.slice(-2);
|
|
3369
|
+
return where$1(tri(n, m, k, { dtype: DType.Bool }), a.ref, zerosLike$1(a));
|
|
3370
|
+
}
|
|
3371
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
3372
|
+
function triu(a, k = 0) {
|
|
3373
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3374
|
+
a = fudgeArray(a);
|
|
3375
|
+
const [n, m] = a.shape.slice(-2);
|
|
3376
|
+
return where$1(tri(n, m, k - 1, { dtype: DType.Bool }), zerosLike$1(a.ref), a);
|
|
3377
|
+
}
|
|
3378
|
+
/**
|
|
3144
3379
|
* Return evenly spaced numbers over a specified interval.
|
|
3145
3380
|
*
|
|
3146
3381
|
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
@@ -3187,383 +3422,186 @@ function aluCompare(a, b, op) {
|
|
|
3187
3422
|
}
|
|
3188
3423
|
|
|
3189
3424
|
//#endregion
|
|
3190
|
-
//#region src/frontend/
|
|
3191
|
-
|
|
3192
|
-
|
|
3425
|
+
//#region src/frontend/vmap.ts
|
|
3426
|
+
function mappedAval(batchDim, aval) {
|
|
3427
|
+
const shape$1 = [...aval.shape];
|
|
3428
|
+
shape$1.splice(batchDim, 1);
|
|
3429
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3430
|
+
}
|
|
3431
|
+
/** Move one axis to a different index. */
|
|
3432
|
+
function moveaxis(x, src, dst) {
|
|
3433
|
+
const t = pureArray(x);
|
|
3434
|
+
src = checkAxis(src, t.ndim);
|
|
3435
|
+
dst = checkAxis(dst, t.ndim);
|
|
3436
|
+
if (src === dst) return t;
|
|
3437
|
+
const perm = range(t.ndim);
|
|
3438
|
+
perm.splice(src, 1);
|
|
3439
|
+
perm.splice(dst, 0, src);
|
|
3440
|
+
return transpose$1(t, perm);
|
|
3441
|
+
}
|
|
3442
|
+
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3443
|
+
if (src === null) {
|
|
3444
|
+
const targetShape = [...x.shape];
|
|
3445
|
+
targetShape.splice(dst, 0, axisSize);
|
|
3446
|
+
return broadcast(x, targetShape, [dst]);
|
|
3447
|
+
} else if (src === dst) return x;
|
|
3448
|
+
else return moveaxis(x, src, dst);
|
|
3449
|
+
}
|
|
3450
|
+
var BatchTracer = class extends Tracer {
|
|
3451
|
+
constructor(trace$1, val, batchDim) {
|
|
3193
3452
|
super(trace$1);
|
|
3194
|
-
this.
|
|
3195
|
-
this.
|
|
3453
|
+
this.val = val;
|
|
3454
|
+
this.batchDim = batchDim;
|
|
3196
3455
|
}
|
|
3197
3456
|
get aval() {
|
|
3198
|
-
return this.
|
|
3457
|
+
if (this.batchDim === null) return this.val.aval;
|
|
3458
|
+
else return mappedAval(this.batchDim, this.val.aval);
|
|
3199
3459
|
}
|
|
3200
3460
|
toString() {
|
|
3201
|
-
return `
|
|
3461
|
+
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3202
3462
|
}
|
|
3203
3463
|
get ref() {
|
|
3204
|
-
this.
|
|
3464
|
+
this.val.ref;
|
|
3205
3465
|
return this;
|
|
3206
3466
|
}
|
|
3207
3467
|
dispose() {
|
|
3208
|
-
this.
|
|
3209
|
-
|
|
3468
|
+
this.val.dispose();
|
|
3469
|
+
}
|
|
3470
|
+
fullLower() {
|
|
3471
|
+
if (this.batchDim === null) return this.val.fullLower();
|
|
3472
|
+
else return this;
|
|
3210
3473
|
}
|
|
3211
3474
|
};
|
|
3212
|
-
var
|
|
3475
|
+
var BatchTrace = class extends Trace {
|
|
3213
3476
|
pure(val) {
|
|
3214
3477
|
return this.lift(pureArray(val));
|
|
3215
3478
|
}
|
|
3216
3479
|
lift(val) {
|
|
3217
|
-
return new
|
|
3480
|
+
return new BatchTracer(this, val, null);
|
|
3218
3481
|
}
|
|
3219
3482
|
processPrimitive(primitive, tracers, params) {
|
|
3220
|
-
const [
|
|
3221
|
-
const
|
|
3222
|
-
if (
|
|
3223
|
-
|
|
3224
|
-
|
|
3483
|
+
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3484
|
+
const vmapRule = vmapRules[primitive];
|
|
3485
|
+
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3486
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3487
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3488
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3489
|
+
}
|
|
3490
|
+
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3491
|
+
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3492
|
+
}
|
|
3493
|
+
get axisSize() {
|
|
3494
|
+
return this.main.globalData;
|
|
3225
3495
|
}
|
|
3226
3496
|
};
|
|
3227
|
-
/**
|
|
3228
|
-
|
|
3229
|
-
|
|
3230
|
-
|
|
3231
|
-
|
|
3232
|
-
|
|
3233
|
-
|
|
3234
|
-
|
|
3235
|
-
|
|
3236
|
-
|
|
3237
|
-
|
|
3238
|
-
|
|
3239
|
-
|
|
3240
|
-
|
|
3497
|
+
/**
|
|
3498
|
+
* Process a primitive with built-in broadcasting.
|
|
3499
|
+
*
|
|
3500
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3501
|
+
*/
|
|
3502
|
+
function broadcastBatcher(op) {
|
|
3503
|
+
return (axisSize, args, dims) => {
|
|
3504
|
+
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3505
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3506
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3507
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3508
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3509
|
+
args = args.map((x, i) => {
|
|
3510
|
+
if (dims[i] === null) return x;
|
|
3511
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3512
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3513
|
+
x.shape[0],
|
|
3514
|
+
...rep(nd - x.ndim, 1),
|
|
3515
|
+
...x.shape.slice(1)
|
|
3516
|
+
]);
|
|
3517
|
+
return x;
|
|
3518
|
+
});
|
|
3519
|
+
return [[op(...args)], [0]];
|
|
3241
3520
|
};
|
|
3242
3521
|
}
|
|
3243
|
-
|
|
3244
|
-
|
|
3245
|
-
|
|
3246
|
-
for (const t of tangents) t.dispose();
|
|
3247
|
-
const ys = bind(primitive, primals, params);
|
|
3248
|
-
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3522
|
+
function unopBatcher(op) {
|
|
3523
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3524
|
+
return [[op(x, params)], [xBdim]];
|
|
3249
3525
|
};
|
|
3250
3526
|
}
|
|
3251
|
-
const
|
|
3252
|
-
[Primitive.Add]:
|
|
3253
|
-
[Primitive.Mul]:
|
|
3254
|
-
[Primitive.Idiv]:
|
|
3255
|
-
[Primitive.Mod](
|
|
3256
|
-
|
|
3257
|
-
|
|
3258
|
-
|
|
3259
|
-
|
|
3260
|
-
|
|
3261
|
-
|
|
3262
|
-
|
|
3527
|
+
const vmapRules = {
|
|
3528
|
+
[Primitive.Add]: broadcastBatcher(add$1),
|
|
3529
|
+
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3530
|
+
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3531
|
+
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3532
|
+
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3533
|
+
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3534
|
+
[Primitive.Neg]: unopBatcher(neg),
|
|
3535
|
+
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3536
|
+
[Primitive.Floor]: unopBatcher(floor$1),
|
|
3537
|
+
[Primitive.Ceil]: unopBatcher(ceil$1),
|
|
3538
|
+
[Primitive.StopGradient]: unopBatcher(stopGradient),
|
|
3539
|
+
[Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
|
|
3540
|
+
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3541
|
+
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3542
|
+
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3543
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3544
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3545
|
+
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3546
|
+
[Primitive.Log]: unopBatcher(log$1),
|
|
3547
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3548
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3549
|
+
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3550
|
+
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3551
|
+
assertNonNull(xBdim);
|
|
3552
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3553
|
+
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3554
|
+
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3263
3555
|
},
|
|
3264
|
-
[Primitive.
|
|
3265
|
-
|
|
3266
|
-
|
|
3267
|
-
|
|
3556
|
+
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3557
|
+
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3558
|
+
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3559
|
+
const z = dot$2(x, y);
|
|
3560
|
+
return [[z], [z.ndim - 1]];
|
|
3268
3561
|
},
|
|
3269
|
-
[Primitive.
|
|
3270
|
-
|
|
3271
|
-
|
|
3272
|
-
|
|
3273
|
-
|
|
3274
|
-
|
|
3275
|
-
|
|
3276
|
-
|
|
3277
|
-
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3278
|
-
}
|
|
3562
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3563
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3564
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3565
|
+
const z = conv$1(x, y, {
|
|
3566
|
+
...params,
|
|
3567
|
+
vmapDims: params.vmapDims + 1
|
|
3568
|
+
});
|
|
3569
|
+
return [[z], [0]];
|
|
3279
3570
|
},
|
|
3280
|
-
[Primitive.
|
|
3281
|
-
|
|
3282
|
-
dx.dispose();
|
|
3283
|
-
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3571
|
+
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3572
|
+
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3284
3573
|
},
|
|
3285
|
-
[Primitive.
|
|
3286
|
-
[Primitive.
|
|
3287
|
-
|
|
3288
|
-
|
|
3289
|
-
|
|
3290
|
-
|
|
3291
|
-
|
|
3292
|
-
|
|
3293
|
-
|
|
3294
|
-
|
|
3295
|
-
},
|
|
3296
|
-
[Primitive.Atan]([x], [dx]) {
|
|
3297
|
-
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3298
|
-
return [[atan$1(x)], [dx.div(denom)]];
|
|
3299
|
-
},
|
|
3300
|
-
[Primitive.Exp]([x], [dx]) {
|
|
3301
|
-
const z = exp$1(x);
|
|
3302
|
-
return [[z.ref], [z.mul(dx)]];
|
|
3303
|
-
},
|
|
3304
|
-
[Primitive.Log]([x], [dx]) {
|
|
3305
|
-
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3306
|
-
},
|
|
3307
|
-
[Primitive.Erf]([x], [dx]) {
|
|
3308
|
-
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3309
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3310
|
-
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3311
|
-
},
|
|
3312
|
-
[Primitive.Erfc]([x], [dx]) {
|
|
3313
|
-
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3314
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3315
|
-
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3316
|
-
},
|
|
3317
|
-
[Primitive.Sqrt]([x], [dx]) {
|
|
3318
|
-
const z = sqrt$1(x);
|
|
3319
|
-
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3320
|
-
},
|
|
3321
|
-
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3322
|
-
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3323
|
-
},
|
|
3324
|
-
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3325
|
-
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3326
|
-
},
|
|
3327
|
-
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3328
|
-
if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3329
|
-
else if (op === AluOp.Mul) {
|
|
3330
|
-
const primal = reduce(x.ref, op, axis);
|
|
3331
|
-
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3332
|
-
return [[primal], [tangent]];
|
|
3333
|
-
} else if (op === AluOp.Min || op === AluOp.Max) {
|
|
3334
|
-
const primal = reduce(x.ref, op, axis);
|
|
3335
|
-
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3336
|
-
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3337
|
-
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3338
|
-
return [[primal], [tangent]];
|
|
3339
|
-
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3340
|
-
},
|
|
3341
|
-
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3342
|
-
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3343
|
-
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3344
|
-
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3345
|
-
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3346
|
-
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3347
|
-
dcond.dispose();
|
|
3348
|
-
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3349
|
-
},
|
|
3350
|
-
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3351
|
-
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3352
|
-
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3353
|
-
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3354
|
-
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3355
|
-
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3356
|
-
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3357
|
-
const indicesRef = indices.map((t) => t.ref);
|
|
3358
|
-
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3359
|
-
},
|
|
3360
|
-
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3361
|
-
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3362
|
-
const outs = bind(Primitive.JitCall, [
|
|
3363
|
-
...newConsts.map((c) => c.ref),
|
|
3364
|
-
...primals,
|
|
3365
|
-
...tangents
|
|
3366
|
-
], {
|
|
3367
|
-
name: `${name}_jvp`,
|
|
3368
|
-
jaxpr: newJaxpr,
|
|
3369
|
-
numConsts: newConsts.length
|
|
3370
|
-
});
|
|
3371
|
-
const n = outs.length / 2;
|
|
3372
|
-
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3373
|
-
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3374
|
-
return [primalsOut, tangentsOut];
|
|
3375
|
-
}
|
|
3376
|
-
};
|
|
3377
|
-
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3378
|
-
function jvpJaxpr(jaxpr) {
|
|
3379
|
-
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3380
|
-
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3381
|
-
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3382
|
-
const result = {
|
|
3383
|
-
newJaxpr,
|
|
3384
|
-
newConsts
|
|
3385
|
-
};
|
|
3386
|
-
jvpJaxprCache.set(jaxpr, result);
|
|
3387
|
-
return result;
|
|
3388
|
-
}
|
|
3389
|
-
function jvpFlat(f, primals, tangents) {
|
|
3390
|
-
try {
|
|
3391
|
-
var _usingCtx$1 = _usingCtx();
|
|
3392
|
-
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3393
|
-
const trace$1 = new JVPTrace(main);
|
|
3394
|
-
const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3395
|
-
const outs = f(...tracersIn);
|
|
3396
|
-
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3397
|
-
return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3398
|
-
} catch (_) {
|
|
3399
|
-
_usingCtx$1.e = _;
|
|
3400
|
-
} finally {
|
|
3401
|
-
_usingCtx$1.d();
|
|
3402
|
-
}
|
|
3403
|
-
}
|
|
3404
|
-
function jvp$1(f, primals, tangents) {
|
|
3405
|
-
const [primalsFlat, inTree] = flatten(primals);
|
|
3406
|
-
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
3407
|
-
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
3408
|
-
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
3409
|
-
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
3410
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
3411
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3412
|
-
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
3413
|
-
return [primalsOut, tangentsOut];
|
|
3414
|
-
}
|
|
3415
|
-
|
|
3416
|
-
//#endregion
|
|
3417
|
-
//#region src/frontend/vmap.ts
|
|
3418
|
-
function mappedAval(batchDim, aval) {
|
|
3419
|
-
const shape$1 = [...aval.shape];
|
|
3420
|
-
shape$1.splice(batchDim, 1);
|
|
3421
|
-
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3422
|
-
}
|
|
3423
|
-
/** Move one axis to a different index. */
|
|
3424
|
-
function moveaxis(x, src, dst) {
|
|
3425
|
-
const t = pureArray(x);
|
|
3426
|
-
src = checkAxis(src, t.ndim);
|
|
3427
|
-
dst = checkAxis(dst, t.ndim);
|
|
3428
|
-
if (src === dst) return t;
|
|
3429
|
-
const perm = range(t.ndim);
|
|
3430
|
-
perm.splice(src, 1);
|
|
3431
|
-
perm.splice(dst, 0, src);
|
|
3432
|
-
return transpose$1(t, perm);
|
|
3433
|
-
}
|
|
3434
|
-
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3435
|
-
if (src === null) {
|
|
3436
|
-
const targetShape = [...x.shape];
|
|
3437
|
-
targetShape.splice(dst, 0, axisSize);
|
|
3438
|
-
return broadcast(x, targetShape, [dst]);
|
|
3439
|
-
} else if (src === dst) return x;
|
|
3440
|
-
else return moveaxis(x, src, dst);
|
|
3441
|
-
}
|
|
3442
|
-
var BatchTracer = class extends Tracer {
|
|
3443
|
-
constructor(trace$1, val, batchDim) {
|
|
3444
|
-
super(trace$1);
|
|
3445
|
-
this.val = val;
|
|
3446
|
-
this.batchDim = batchDim;
|
|
3447
|
-
}
|
|
3448
|
-
get aval() {
|
|
3449
|
-
if (this.batchDim === null) return this.val.aval;
|
|
3450
|
-
else return mappedAval(this.batchDim, this.val.aval);
|
|
3451
|
-
}
|
|
3452
|
-
toString() {
|
|
3453
|
-
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3454
|
-
}
|
|
3455
|
-
get ref() {
|
|
3456
|
-
this.val.ref;
|
|
3457
|
-
return this;
|
|
3458
|
-
}
|
|
3459
|
-
dispose() {
|
|
3460
|
-
this.val.dispose();
|
|
3461
|
-
}
|
|
3462
|
-
fullLower() {
|
|
3463
|
-
if (this.batchDim === null) return this.val.fullLower();
|
|
3464
|
-
else return this;
|
|
3465
|
-
}
|
|
3466
|
-
};
|
|
3467
|
-
var BatchTrace = class extends Trace {
|
|
3468
|
-
pure(val) {
|
|
3469
|
-
return this.lift(pureArray(val));
|
|
3470
|
-
}
|
|
3471
|
-
lift(val) {
|
|
3472
|
-
return new BatchTracer(this, val, null);
|
|
3473
|
-
}
|
|
3474
|
-
processPrimitive(primitive, tracers, params) {
|
|
3475
|
-
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3476
|
-
const vmapRule = vmapRules[primitive];
|
|
3477
|
-
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3478
|
-
if (bdimsIn.every((d) => d === null)) {
|
|
3479
|
-
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3480
|
-
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3574
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3575
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3576
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3577
|
+
assertNonNull(xBdim);
|
|
3578
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3579
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3580
|
+
let newOutDim = outDim;
|
|
3581
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3582
|
+
else newOutDim += 1;
|
|
3583
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3481
3584
|
}
|
|
3482
|
-
const
|
|
3483
|
-
|
|
3484
|
-
|
|
3485
|
-
|
|
3486
|
-
|
|
3487
|
-
|
|
3488
|
-
|
|
3489
|
-
|
|
3490
|
-
* Process a primitive with built-in broadcasting.
|
|
3491
|
-
*
|
|
3492
|
-
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3493
|
-
*/
|
|
3494
|
-
function broadcastBatcher(op) {
|
|
3495
|
-
return (axisSize, args, dims) => {
|
|
3496
|
-
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3497
|
-
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3498
|
-
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3499
|
-
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3500
|
-
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3501
|
-
args = args.map((x, i) => {
|
|
3502
|
-
if (dims[i] === null) return x;
|
|
3503
|
-
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3504
|
-
if (x.ndim < nd) x = x.reshape([
|
|
3505
|
-
x.shape[0],
|
|
3506
|
-
...rep(nd - x.ndim, 1),
|
|
3507
|
-
...x.shape.slice(1)
|
|
3585
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3586
|
+
indices = indices.map((m, i) => {
|
|
3587
|
+
if (indicesBdim[i] === null) return m;
|
|
3588
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3589
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3590
|
+
m.shape[0],
|
|
3591
|
+
...rep(nd - m.ndim, 1),
|
|
3592
|
+
...m.shape.slice(1)
|
|
3508
3593
|
]);
|
|
3509
|
-
return
|
|
3510
|
-
});
|
|
3511
|
-
return [[op(...args)], [0]];
|
|
3512
|
-
};
|
|
3513
|
-
}
|
|
3514
|
-
function unopBatcher(op) {
|
|
3515
|
-
return (axisSize, [x], [xBdim], params) => {
|
|
3516
|
-
return [[op(x, params)], [xBdim]];
|
|
3517
|
-
};
|
|
3518
|
-
}
|
|
3519
|
-
const vmapRules = {
|
|
3520
|
-
[Primitive.Add]: broadcastBatcher(add$1),
|
|
3521
|
-
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3522
|
-
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3523
|
-
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3524
|
-
[Primitive.Neg]: unopBatcher(neg),
|
|
3525
|
-
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3526
|
-
[Primitive.Floor]: unopBatcher(floor$1),
|
|
3527
|
-
[Primitive.Ceil]: unopBatcher(ceil$1),
|
|
3528
|
-
[Primitive.StopGradient]: unopBatcher(stopGradient),
|
|
3529
|
-
[Primitive.Cast]: unopBatcher((x, { dtype }) => cast(x, dtype)),
|
|
3530
|
-
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3531
|
-
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3532
|
-
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3533
|
-
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3534
|
-
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3535
|
-
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3536
|
-
[Primitive.Log]: unopBatcher(log$1),
|
|
3537
|
-
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3538
|
-
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3539
|
-
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3540
|
-
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3541
|
-
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3542
|
-
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3543
|
-
assertNonNull(xBdim);
|
|
3544
|
-
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3545
|
-
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3546
|
-
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3547
|
-
},
|
|
3548
|
-
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3549
|
-
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3550
|
-
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3551
|
-
const z = dot$2(x, y);
|
|
3552
|
-
return [[z], [z.ndim - 1]];
|
|
3553
|
-
},
|
|
3554
|
-
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3555
|
-
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3556
|
-
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3557
|
-
const z = conv$1(x, y, {
|
|
3558
|
-
...params,
|
|
3559
|
-
vmapDims: params.vmapDims + 1
|
|
3594
|
+
return m;
|
|
3560
3595
|
});
|
|
3561
|
-
return [[
|
|
3562
|
-
|
|
3563
|
-
|
|
3564
|
-
|
|
3596
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3597
|
+
else {
|
|
3598
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3599
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3600
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
|
|
3601
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3602
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3603
|
+
}
|
|
3565
3604
|
},
|
|
3566
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3567
3605
|
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3568
3606
|
assertNonNull(xBdim);
|
|
3569
3607
|
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
@@ -3595,42 +3633,53 @@ const vmapRules = {
|
|
|
3595
3633
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3596
3634
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3597
3635
|
},
|
|
3598
|
-
[Primitive.
|
|
3599
|
-
|
|
3600
|
-
|
|
3601
|
-
|
|
3602
|
-
|
|
3603
|
-
|
|
3604
|
-
|
|
3605
|
-
|
|
3606
|
-
|
|
3607
|
-
|
|
3608
|
-
|
|
3609
|
-
|
|
3610
|
-
|
|
3611
|
-
|
|
3612
|
-
|
|
3613
|
-
|
|
3614
|
-
|
|
3615
|
-
...
|
|
3636
|
+
[Primitive.Sort](axisSize, [x], [xBdim]) {
|
|
3637
|
+
assertNonNull(xBdim);
|
|
3638
|
+
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3639
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3640
|
+
return [[sort$1(x)], [0]];
|
|
3641
|
+
},
|
|
3642
|
+
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3643
|
+
assertNonNull(xBdim);
|
|
3644
|
+
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3645
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3646
|
+
return [argsort$1(x), [0, 0]];
|
|
3647
|
+
},
|
|
3648
|
+
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3649
|
+
if (aBdim === null) {
|
|
3650
|
+
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
3651
|
+
const [s, m, n] = b.shape.slice(-3);
|
|
3652
|
+
b = b.reshape([
|
|
3653
|
+
...b.shape.slice(0, -3),
|
|
3654
|
+
s * m,
|
|
3655
|
+
n
|
|
3616
3656
|
]);
|
|
3617
|
-
|
|
3618
|
-
|
|
3619
|
-
|
|
3620
|
-
|
|
3621
|
-
|
|
3622
|
-
|
|
3623
|
-
|
|
3624
|
-
|
|
3625
|
-
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3657
|
+
let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3658
|
+
x$1 = x$1.reshape([
|
|
3659
|
+
...b.shape.slice(0, -2),
|
|
3660
|
+
s,
|
|
3661
|
+
m,
|
|
3662
|
+
n
|
|
3663
|
+
]);
|
|
3664
|
+
return [[x$1], [x$1.ndim - 3]];
|
|
3626
3665
|
}
|
|
3666
|
+
a = moveBatchAxis(axisSize, aBdim, 0, a);
|
|
3667
|
+
b = moveBatchAxis(axisSize, bBdim, 0, b);
|
|
3668
|
+
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3669
|
+
return [[x], [0]];
|
|
3670
|
+
},
|
|
3671
|
+
[Primitive.Cholesky](axisSize, [x], [xBdim]) {
|
|
3672
|
+
assertNonNull(xBdim);
|
|
3673
|
+
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3674
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3675
|
+
return [[cholesky$2(x)], [0]];
|
|
3627
3676
|
},
|
|
3628
|
-
[Primitive.
|
|
3629
|
-
const
|
|
3630
|
-
const outs = bind(Primitive.
|
|
3677
|
+
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3678
|
+
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3679
|
+
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
3631
3680
|
name: `${name}_vmap`,
|
|
3632
|
-
jaxpr: newJaxpr,
|
|
3633
|
-
numConsts:
|
|
3681
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3682
|
+
numConsts: newJaxpr.consts.length
|
|
3634
3683
|
});
|
|
3635
3684
|
return [outs, rep(outs.length, 0)];
|
|
3636
3685
|
}
|
|
@@ -3646,14 +3695,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3646
3695
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3647
3696
|
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3648
3697
|
});
|
|
3649
|
-
const { jaxpr: newJaxpr
|
|
3650
|
-
const result = {
|
|
3651
|
-
newJaxpr,
|
|
3652
|
-
newConsts
|
|
3653
|
-
};
|
|
3698
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3654
3699
|
if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
3655
|
-
vmapJaxprCache.get(jaxpr).set(cacheKey,
|
|
3656
|
-
return
|
|
3700
|
+
vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
3701
|
+
return newJaxpr;
|
|
3657
3702
|
}
|
|
3658
3703
|
function vmapFlat(f, inAxes, args) {
|
|
3659
3704
|
let axisSize = void 0;
|
|
@@ -3708,6 +3753,260 @@ function jacfwd$1(f) {
|
|
|
3708
3753
|
};
|
|
3709
3754
|
}
|
|
3710
3755
|
|
|
3756
|
+
//#endregion
|
|
3757
|
+
//#region src/frontend/jvp.ts
|
|
3758
|
+
var JVPTracer = class extends Tracer {
|
|
3759
|
+
constructor(trace$1, primal, tangent) {
|
|
3760
|
+
super(trace$1);
|
|
3761
|
+
this.primal = primal;
|
|
3762
|
+
this.tangent = tangent;
|
|
3763
|
+
}
|
|
3764
|
+
get aval() {
|
|
3765
|
+
return this.primal.aval;
|
|
3766
|
+
}
|
|
3767
|
+
toString() {
|
|
3768
|
+
return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
|
|
3769
|
+
}
|
|
3770
|
+
get ref() {
|
|
3771
|
+
this.primal.ref, this.tangent.ref;
|
|
3772
|
+
return this;
|
|
3773
|
+
}
|
|
3774
|
+
dispose() {
|
|
3775
|
+
this.primal.dispose();
|
|
3776
|
+
this.tangent.dispose();
|
|
3777
|
+
}
|
|
3778
|
+
};
|
|
3779
|
+
var JVPTrace = class extends Trace {
|
|
3780
|
+
pure(val) {
|
|
3781
|
+
return this.lift(pureArray(val));
|
|
3782
|
+
}
|
|
3783
|
+
lift(val) {
|
|
3784
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
3785
|
+
}
|
|
3786
|
+
processPrimitive(primitive, tracers, params) {
|
|
3787
|
+
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
3788
|
+
const jvpRule = jvpRules[primitive];
|
|
3789
|
+
if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
|
|
3790
|
+
const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
|
|
3791
|
+
return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
|
|
3792
|
+
}
|
|
3793
|
+
};
|
|
3794
|
+
/** Rule that applies the same operation to primals and tangents. */
|
|
3795
|
+
function linearTangentsJvp(primitive) {
|
|
3796
|
+
return (primals, tangents, params) => {
|
|
3797
|
+
const ys = bind(primitive, primals, params);
|
|
3798
|
+
const dys = bind(primitive, tangents, params);
|
|
3799
|
+
return [ys, dys];
|
|
3800
|
+
};
|
|
3801
|
+
}
|
|
3802
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
3803
|
+
function bilinearTangentsJvp(primitive) {
|
|
3804
|
+
return ([x, y], [dx, dy], params) => {
|
|
3805
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3806
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3807
|
+
return [[primal], [tangent]];
|
|
3808
|
+
};
|
|
3809
|
+
}
|
|
3810
|
+
/** Rule that zeros out any tangents. */
|
|
3811
|
+
function zeroTangentsJvp(primitive) {
|
|
3812
|
+
return (primals, tangents, params) => {
|
|
3813
|
+
for (const t of tangents) t.dispose();
|
|
3814
|
+
const ys = bind(primitive, primals, params);
|
|
3815
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3816
|
+
};
|
|
3817
|
+
}
|
|
3818
|
+
/** Compute `a @ b.T`, batched to last two axes. */
|
|
3819
|
+
function batchMatmulT(a, b) {
|
|
3820
|
+
return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
|
|
3821
|
+
}
|
|
3822
|
+
/** Batch matrix transpose. */
|
|
3823
|
+
function mT(a) {
|
|
3824
|
+
return moveaxis(a, -2, -1);
|
|
3825
|
+
}
|
|
3826
|
+
const jvpRules = {
|
|
3827
|
+
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3828
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
3829
|
+
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
3830
|
+
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
3831
|
+
if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
|
|
3832
|
+
dx.dispose();
|
|
3833
|
+
dy.dispose();
|
|
3834
|
+
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
3835
|
+
}
|
|
3836
|
+
const q = idiv(x.ref, y.ref);
|
|
3837
|
+
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
3838
|
+
},
|
|
3839
|
+
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3840
|
+
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3841
|
+
},
|
|
3842
|
+
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3843
|
+
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3844
|
+
},
|
|
3845
|
+
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
3846
|
+
[Primitive.Reciprocal]([x], [dx]) {
|
|
3847
|
+
const xRecip = reciprocal$1(x.ref);
|
|
3848
|
+
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
3849
|
+
},
|
|
3850
|
+
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
3851
|
+
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
3852
|
+
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
3853
|
+
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
3854
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3855
|
+
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
3856
|
+
else {
|
|
3857
|
+
dx.dispose();
|
|
3858
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3859
|
+
}
|
|
3860
|
+
},
|
|
3861
|
+
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3862
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3863
|
+
dx.dispose();
|
|
3864
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3865
|
+
},
|
|
3866
|
+
[Primitive.Sin]([x], [dx]) {
|
|
3867
|
+
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3868
|
+
},
|
|
3869
|
+
[Primitive.Cos]([x], [dx]) {
|
|
3870
|
+
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
3871
|
+
},
|
|
3872
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3873
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3874
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3875
|
+
},
|
|
3876
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3877
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3878
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3879
|
+
},
|
|
3880
|
+
[Primitive.Exp]([x], [dx]) {
|
|
3881
|
+
const z = exp$1(x);
|
|
3882
|
+
return [[z.ref], [z.mul(dx)]];
|
|
3883
|
+
},
|
|
3884
|
+
[Primitive.Log]([x], [dx]) {
|
|
3885
|
+
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3886
|
+
},
|
|
3887
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3888
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3889
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3890
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3891
|
+
},
|
|
3892
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3893
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3894
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3895
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3896
|
+
},
|
|
3897
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
3898
|
+
const z = sqrt$1(x);
|
|
3899
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3900
|
+
},
|
|
3901
|
+
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3902
|
+
if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3903
|
+
else if (op === AluOp.Mul) {
|
|
3904
|
+
const primal = reduce(x.ref, op, axis);
|
|
3905
|
+
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3906
|
+
return [[primal], [tangent]];
|
|
3907
|
+
} else if (op === AluOp.Min || op === AluOp.Max) {
|
|
3908
|
+
const primal = reduce(x.ref, op, axis);
|
|
3909
|
+
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3910
|
+
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3911
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3912
|
+
return [[primal], [tangent]];
|
|
3913
|
+
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3914
|
+
},
|
|
3915
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3916
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3917
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3918
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3919
|
+
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3920
|
+
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3921
|
+
dcond.dispose();
|
|
3922
|
+
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3923
|
+
},
|
|
3924
|
+
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3925
|
+
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3926
|
+
const indicesRef = indices.map((t) => t.ref);
|
|
3927
|
+
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3928
|
+
},
|
|
3929
|
+
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3930
|
+
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3931
|
+
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3932
|
+
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3933
|
+
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3934
|
+
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3935
|
+
[Primitive.Sort]([x], [dx]) {
|
|
3936
|
+
const [y, idx] = argsort$1(x);
|
|
3937
|
+
return [[y], [gather(dx, [idx], [-1], -1)]];
|
|
3938
|
+
},
|
|
3939
|
+
[Primitive.Argsort]([x], [dx]) {
|
|
3940
|
+
const [y, idx] = argsort$1(x);
|
|
3941
|
+
return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
|
|
3942
|
+
},
|
|
3943
|
+
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
3944
|
+
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
3945
|
+
const dax = batchMatmulT(da, x.ref);
|
|
3946
|
+
const rhsT = db.sub(mT(dax));
|
|
3947
|
+
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
3948
|
+
return [[x], [dx]];
|
|
3949
|
+
},
|
|
3950
|
+
[Primitive.Cholesky]([a], [da]) {
|
|
3951
|
+
const L = cholesky$2(a.ref);
|
|
3952
|
+
da = da.ref.add(mT(da)).mul(.5);
|
|
3953
|
+
const W = triangularSolve$1(L.ref, da, { lower: true });
|
|
3954
|
+
const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
|
|
3955
|
+
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3956
|
+
return [[L], [dL]];
|
|
3957
|
+
},
|
|
3958
|
+
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3959
|
+
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3960
|
+
const outs = bind(Primitive.Jit, [
|
|
3961
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
3962
|
+
...primals,
|
|
3963
|
+
...tangents
|
|
3964
|
+
], {
|
|
3965
|
+
name: `${name}_jvp`,
|
|
3966
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3967
|
+
numConsts: newJaxpr.consts.length
|
|
3968
|
+
});
|
|
3969
|
+
const n = outs.length / 2;
|
|
3970
|
+
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3971
|
+
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3972
|
+
return [primalsOut, tangentsOut];
|
|
3973
|
+
}
|
|
3974
|
+
};
|
|
3975
|
+
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3976
|
+
function jvpJaxpr(jaxpr) {
|
|
3977
|
+
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3978
|
+
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3979
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3980
|
+
jvpJaxprCache.set(jaxpr, newJaxpr);
|
|
3981
|
+
return newJaxpr;
|
|
3982
|
+
}
|
|
3983
|
+
function jvpFlat(f, primals, tangents) {
|
|
3984
|
+
try {
|
|
3985
|
+
var _usingCtx$1 = _usingCtx();
|
|
3986
|
+
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3987
|
+
const trace$1 = new JVPTrace(main);
|
|
3988
|
+
const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3989
|
+
const outs = f(...tracersIn);
|
|
3990
|
+
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3991
|
+
return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3992
|
+
} catch (_) {
|
|
3993
|
+
_usingCtx$1.e = _;
|
|
3994
|
+
} finally {
|
|
3995
|
+
_usingCtx$1.d();
|
|
3996
|
+
}
|
|
3997
|
+
}
|
|
3998
|
+
function jvp$1(f, primals, tangents) {
|
|
3999
|
+
const [primalsFlat, inTree] = flatten(primals);
|
|
4000
|
+
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4001
|
+
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4002
|
+
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
4003
|
+
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4004
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4005
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4006
|
+
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4007
|
+
return [primalsOut, tangentsOut];
|
|
4008
|
+
}
|
|
4009
|
+
|
|
3711
4010
|
//#endregion
|
|
3712
4011
|
//#region src/frontend/linearize.ts
|
|
3713
4012
|
/** Array value that can either be known or unknown. */
|
|
@@ -3738,11 +4037,10 @@ function partialEvalFlat(f, pvalsIn) {
|
|
|
3738
4037
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3739
4038
|
const pvalsOut = tracersOut.map((t) => t.pval);
|
|
3740
4039
|
const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
|
|
3741
|
-
const
|
|
4040
|
+
const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
|
|
3742
4041
|
return {
|
|
3743
4042
|
jaxpr,
|
|
3744
|
-
pvalsOut
|
|
3745
|
-
consts
|
|
4043
|
+
pvalsOut
|
|
3746
4044
|
};
|
|
3747
4045
|
}
|
|
3748
4046
|
/**
|
|
@@ -3759,22 +4057,19 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3759
4057
|
const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
|
|
3760
4058
|
return [...primalsOut$1, ...tangentsOut];
|
|
3761
4059
|
};
|
|
3762
|
-
const { jaxpr, pvalsOut
|
|
4060
|
+
const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
|
|
3763
4061
|
const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
|
|
3764
4062
|
if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
|
|
3765
4063
|
const primalsOut = primalPvals.map((pval) => pval.val);
|
|
3766
4064
|
return {
|
|
3767
4065
|
primalsOut,
|
|
3768
|
-
jaxpr
|
|
3769
|
-
consts
|
|
4066
|
+
jaxpr
|
|
3770
4067
|
};
|
|
3771
4068
|
}
|
|
3772
4069
|
function linearizeFlat(f, primalsIn) {
|
|
3773
|
-
const { primalsOut, jaxpr
|
|
3774
|
-
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3775
|
-
const dispose$1 = () =>
|
|
3776
|
-
for (const c of consts) c.dispose();
|
|
3777
|
-
};
|
|
4070
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4071
|
+
const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
|
|
4072
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
3778
4073
|
return [
|
|
3779
4074
|
primalsOut,
|
|
3780
4075
|
fLin,
|
|
@@ -3858,7 +4153,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3858
4153
|
}
|
|
3859
4154
|
processPrimitive(primitive, tracers, params) {
|
|
3860
4155
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3861
|
-
if (primitive === Primitive.
|
|
4156
|
+
if (primitive === Primitive.Jit) {
|
|
3862
4157
|
const { name, jaxpr, numConsts } = params;
|
|
3863
4158
|
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3864
4159
|
}
|
|
@@ -3884,14 +4179,14 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3884
4179
|
* Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
|
|
3885
4180
|
* values as possible (with JIT) and forwarding the unknown ones.
|
|
3886
4181
|
*
|
|
3887
|
-
* Used when encountering a
|
|
4182
|
+
* Used when encountering a Jit rule during the trace.
|
|
3888
4183
|
*/
|
|
3889
4184
|
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3890
4185
|
jaxpr = jaxpr.flatten();
|
|
3891
4186
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3892
4187
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3893
4188
|
const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
|
|
3894
|
-
const outs1Res = bind(Primitive.
|
|
4189
|
+
const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3895
4190
|
name: `${name}_peval`,
|
|
3896
4191
|
jaxpr: jaxpr1,
|
|
3897
4192
|
numConsts: 0
|
|
@@ -3901,7 +4196,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3901
4196
|
const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
|
|
3902
4197
|
const recipe = {
|
|
3903
4198
|
type: "JaxprEqn",
|
|
3904
|
-
prim: Primitive.
|
|
4199
|
+
prim: Primitive.Jit,
|
|
3905
4200
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3906
4201
|
params: {
|
|
3907
4202
|
name: `${name}_resid`,
|
|
@@ -3930,7 +4225,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
|
|
|
3930
4225
|
const eqns1 = [];
|
|
3931
4226
|
const eqns2 = [];
|
|
3932
4227
|
for (const eqn of jaxpr.eqns) {
|
|
3933
|
-
if (eqn.primitive === Primitive.
|
|
4228
|
+
if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
|
|
3934
4229
|
const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
|
|
3935
4230
|
if (hasUnknowns) {
|
|
3936
4231
|
for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
|
|
@@ -4005,10 +4300,7 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
4005
4300
|
for (const t of tracersOut) t.dispose();
|
|
4006
4301
|
jaxpr = jaxpr.simplify();
|
|
4007
4302
|
if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
4008
|
-
return
|
|
4009
|
-
jaxpr,
|
|
4010
|
-
consts
|
|
4011
|
-
};
|
|
4303
|
+
return new ClosedJaxpr(jaxpr, consts);
|
|
4012
4304
|
}
|
|
4013
4305
|
/** Marker type for pullback, used by transpose rules. */
|
|
4014
4306
|
var UndefPrimal = class {
|
|
@@ -4200,317 +4492,142 @@ const transposeRules = {
|
|
|
4200
4492
|
cond.dispose();
|
|
4201
4493
|
return cts;
|
|
4202
4494
|
},
|
|
4203
|
-
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4204
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4205
|
-
return [transpose$1(ct, invertPermutation(perm))];
|
|
4206
|
-
},
|
|
4207
|
-
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4208
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4209
|
-
return [reduce(ct, AluOp.Add, axis)];
|
|
4210
|
-
},
|
|
4211
|
-
[Primitive.Reshape]([ct], [x], _) {
|
|
4212
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4213
|
-
return [reshape$1(ct, x.aval.shape)];
|
|
4214
|
-
},
|
|
4215
|
-
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4216
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4217
|
-
return [flip$1(ct, axis)];
|
|
4218
|
-
},
|
|
4219
|
-
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4220
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4221
|
-
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4222
|
-
return [pad$1(ct, width)];
|
|
4223
|
-
},
|
|
4224
|
-
[Primitive.Pad]([ct], [x], { width }) {
|
|
4225
|
-
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
|
|
4226
|
-
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4227
|
-
return [shrink(ct, slice)];
|
|
4228
|
-
},
|
|
4229
4495
|
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4230
4496
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4231
4497
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4232
4498
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4233
4499
|
},
|
|
4234
|
-
[Primitive.
|
|
4235
|
-
|
|
4236
|
-
|
|
4237
|
-
|
|
4238
|
-
|
|
4239
|
-
|
|
4240
|
-
|
|
4241
|
-
|
|
4242
|
-
|
|
4243
|
-
|
|
4244
|
-
|
|
4245
|
-
|
|
4246
|
-
|
|
4247
|
-
|
|
4248
|
-
return
|
|
4249
|
-
}
|
|
4250
|
-
}
|
|
4251
|
-
|
|
4252
|
-
|
|
4253
|
-
|
|
4254
|
-
|
|
4255
|
-
|
|
4256
|
-
|
|
4257
|
-
|
|
4258
|
-
|
|
4259
|
-
|
|
4260
|
-
|
|
4261
|
-
|
|
4262
|
-
|
|
4263
|
-
|
|
4264
|
-
|
|
4265
|
-
typecheckJaxpr(newJaxpr);
|
|
4266
|
-
const result = {
|
|
4267
|
-
newJaxpr,
|
|
4268
|
-
newConsts
|
|
4269
|
-
};
|
|
4270
|
-
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4271
|
-
transposeJaxprCache.get(jaxpr).set(cacheKey, result);
|
|
4272
|
-
return result;
|
|
4273
|
-
}
|
|
4274
|
-
function vjpFlat(f, primalsIn) {
|
|
4275
|
-
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
4276
|
-
const fVjp = (...cotangents) => {
|
|
4277
|
-
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4278
|
-
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
4279
|
-
};
|
|
4280
|
-
const dispose$1 = () => {
|
|
4281
|
-
for (const c of consts) c.dispose();
|
|
4282
|
-
};
|
|
4283
|
-
return [
|
|
4284
|
-
primalsOut,
|
|
4285
|
-
fVjp,
|
|
4286
|
-
dispose$1
|
|
4287
|
-
];
|
|
4288
|
-
}
|
|
4289
|
-
function vjp$1(f, ...primalsIn) {
|
|
4290
|
-
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4291
|
-
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4292
|
-
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4293
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4294
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4295
|
-
const fVjp = ((cotangentsOut) => {
|
|
4296
|
-
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4297
|
-
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4298
|
-
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4299
|
-
return unflatten(inTree, cotangentsInFlat);
|
|
4300
|
-
});
|
|
4301
|
-
fVjp.dispose = dispose$1;
|
|
4302
|
-
return [primalsOut, fVjp];
|
|
4303
|
-
}
|
|
4304
|
-
function grad$1(f) {
|
|
4305
|
-
const valueAndGradFn = valueAndGrad$1(f);
|
|
4306
|
-
return (...x) => {
|
|
4307
|
-
const [y, dx] = valueAndGradFn(...x);
|
|
4308
|
-
y.dispose();
|
|
4309
|
-
return dx;
|
|
4310
|
-
};
|
|
4311
|
-
}
|
|
4312
|
-
function valueAndGrad$1(f) {
|
|
4313
|
-
return (...x) => {
|
|
4314
|
-
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4315
|
-
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4316
|
-
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4317
|
-
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4318
|
-
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4319
|
-
for (const r of rest) dispose(r);
|
|
4320
|
-
fVjp.dispose();
|
|
4321
|
-
return [y, ct];
|
|
4322
|
-
};
|
|
4323
|
-
}
|
|
4324
|
-
function jacrev$1(f) {
|
|
4325
|
-
return function jacobianReverse(x) {
|
|
4326
|
-
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4327
|
-
const [size$1] = x.shape;
|
|
4328
|
-
const pullback = (ct) => {
|
|
4329
|
-
const [y, fVjp] = vjp$1(f, x);
|
|
4330
|
-
y.dispose();
|
|
4331
|
-
const [ret] = fVjp(ct);
|
|
4332
|
-
fVjp.dispose();
|
|
4333
|
-
return ret;
|
|
4334
|
-
};
|
|
4335
|
-
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4336
|
-
};
|
|
4337
|
-
}
|
|
4338
|
-
|
|
4339
|
-
//#endregion
|
|
4340
|
-
//#region src/library/lax.ts
|
|
4341
|
-
var lax_exports = {};
|
|
4342
|
-
__export(lax_exports, {
|
|
4343
|
-
conv: () => conv,
|
|
4344
|
-
convGeneralDilated: () => convGeneralDilated,
|
|
4345
|
-
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4346
|
-
dot: () => dot$1,
|
|
4347
|
-
erf: () => erf,
|
|
4348
|
-
erfc: () => erfc,
|
|
4349
|
-
reduceWindow: () => reduceWindow,
|
|
4350
|
-
stopGradient: () => stopGradient$1
|
|
4351
|
-
});
|
|
4352
|
-
/**
|
|
4353
|
-
* General dot product/contraction operator.
|
|
4354
|
-
*
|
|
4355
|
-
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
4356
|
-
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
4357
|
-
*/
|
|
4358
|
-
function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
4359
|
-
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
4360
|
-
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
4361
|
-
lc = lc.map((a) => checkAxis(a, lhs.ndim));
|
|
4362
|
-
rc = rc.map((a) => checkAxis(a, rhs.ndim));
|
|
4363
|
-
lb = lb.map((a) => checkAxis(a, lhs.ndim));
|
|
4364
|
-
rb = rb.map((a) => checkAxis(a, rhs.ndim));
|
|
4365
|
-
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
4366
|
-
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
4367
|
-
const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
4368
|
-
const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
4369
|
-
const lhs2 = lhs.transpose([
|
|
4370
|
-
...lb,
|
|
4371
|
-
...lf,
|
|
4372
|
-
...lc
|
|
4373
|
-
]);
|
|
4374
|
-
const rhs2 = rhs.transpose([
|
|
4375
|
-
...rb,
|
|
4376
|
-
...rf,
|
|
4377
|
-
...rc
|
|
4378
|
-
]);
|
|
4379
|
-
if (lc.length === 0) return mul(lhs2.reshape([
|
|
4380
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4381
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4382
|
-
...rep(rf.length, 1)
|
|
4383
|
-
]), rhs2.reshape([
|
|
4384
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4385
|
-
...rep(lf.length, 1),
|
|
4386
|
-
...rf.map((a) => rhs.shape[a])
|
|
4387
|
-
]));
|
|
4388
|
-
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
4389
|
-
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
4390
|
-
if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
4391
|
-
return dot$2(lhs2.reshape([
|
|
4392
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4393
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4394
|
-
...rep(rf.length, 1),
|
|
4395
|
-
prod(dotShapeX)
|
|
4396
|
-
]), rhs2.reshape([
|
|
4397
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4398
|
-
...rep(lf.length, 1),
|
|
4399
|
-
...rf.map((a) => rhs.shape[a]),
|
|
4400
|
-
prod(dotShapeY)
|
|
4401
|
-
]));
|
|
4402
|
-
}
|
|
4403
|
-
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4404
|
-
const padType = padding.toUpperCase();
|
|
4405
|
-
switch (padType) {
|
|
4406
|
-
case "VALID": return rep(inShape.length, [0, 0]);
|
|
4407
|
-
case "SAME":
|
|
4408
|
-
case "SAME_LOWER": {
|
|
4409
|
-
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
4410
|
-
const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
4411
|
-
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
4412
|
-
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
4413
|
-
}
|
|
4414
|
-
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
4415
|
-
}
|
|
4416
|
-
}
|
|
4417
|
-
/**
|
|
4418
|
-
* General n-dimensional convolution operator, with optional dilation.
|
|
4419
|
-
*
|
|
4420
|
-
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
4421
|
-
* function in JAX, which wraps XLA's general convolution operator.
|
|
4422
|
-
*
|
|
4423
|
-
* Grouped convolutions are not supported right now.
|
|
4424
|
-
*/
|
|
4425
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
4426
|
-
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4427
|
-
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4428
|
-
if (typeof padding === "string") {
|
|
4429
|
-
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4430
|
-
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
4431
|
-
}
|
|
4432
|
-
if (featureGroupCount !== 1) {
|
|
4433
|
-
const G = featureGroupCount;
|
|
4434
|
-
const [N, C_in, ...xs] = lhs.shape;
|
|
4435
|
-
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
4436
|
-
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
4437
|
-
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
4438
|
-
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
4439
|
-
const lhsGrouped = moveaxis(lhs.reshape([
|
|
4440
|
-
N,
|
|
4441
|
-
G,
|
|
4442
|
-
C_in / G,
|
|
4443
|
-
...xs
|
|
4444
|
-
]), 1, 0);
|
|
4445
|
-
const rhsGrouped = rhs.reshape([
|
|
4446
|
-
G,
|
|
4447
|
-
C_out / G,
|
|
4448
|
-
C_in_per_group,
|
|
4449
|
-
...ks
|
|
4450
|
-
]);
|
|
4451
|
-
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
4452
|
-
vmapDims: 1,
|
|
4453
|
-
strides: windowStrides,
|
|
4454
|
-
padding,
|
|
4455
|
-
lhsDilation,
|
|
4456
|
-
rhsDilation
|
|
4500
|
+
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4501
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4502
|
+
return [transpose$1(ct, invertPermutation(perm))];
|
|
4503
|
+
},
|
|
4504
|
+
[Primitive.Broadcast]([ct], [x], { axis }) {
|
|
4505
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Broadcast);
|
|
4506
|
+
return [reduce(ct, AluOp.Add, axis)];
|
|
4507
|
+
},
|
|
4508
|
+
[Primitive.Reshape]([ct], [x], _) {
|
|
4509
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Reshape);
|
|
4510
|
+
return [reshape$1(ct, x.aval.shape)];
|
|
4511
|
+
},
|
|
4512
|
+
[Primitive.Flip]([ct], [x], { axis }) {
|
|
4513
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Flip);
|
|
4514
|
+
return [flip$1(ct, axis)];
|
|
4515
|
+
},
|
|
4516
|
+
[Primitive.Shrink]([ct], [x], { slice }) {
|
|
4517
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Shrink);
|
|
4518
|
+
const width = slice.map(([s, e$1], i) => [s, x.aval.shape[i] - e$1]);
|
|
4519
|
+
return [pad$1(ct, width)];
|
|
4520
|
+
},
|
|
4521
|
+
[Primitive.Pad]([ct], [x], { width }) {
|
|
4522
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pad);
|
|
4523
|
+
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4524
|
+
return [shrink(ct, slice)];
|
|
4525
|
+
},
|
|
4526
|
+
[Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
|
|
4527
|
+
if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
|
|
4528
|
+
const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
|
|
4529
|
+
lower: true,
|
|
4530
|
+
unitDiagonal
|
|
4457
4531
|
});
|
|
4458
|
-
|
|
4459
|
-
|
|
4460
|
-
|
|
4461
|
-
|
|
4462
|
-
|
|
4463
|
-
]);
|
|
4532
|
+
return [null, ctB];
|
|
4533
|
+
},
|
|
4534
|
+
[Primitive.Jit](cts, args, { name, jaxpr }) {
|
|
4535
|
+
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
4536
|
+
const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
|
|
4537
|
+
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
4538
|
+
const outs = bind(Primitive.Jit, [
|
|
4539
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4540
|
+
...residuals,
|
|
4541
|
+
...cts
|
|
4542
|
+
], {
|
|
4543
|
+
name: `${name}_t`,
|
|
4544
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4545
|
+
numConsts: newJaxpr.consts.length
|
|
4546
|
+
});
|
|
4547
|
+
let i = 0;
|
|
4548
|
+
return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
|
|
4464
4549
|
}
|
|
4465
|
-
|
|
4466
|
-
|
|
4467
|
-
|
|
4468
|
-
|
|
4469
|
-
|
|
4470
|
-
|
|
4471
|
-
}
|
|
4472
|
-
|
|
4473
|
-
|
|
4474
|
-
|
|
4475
|
-
|
|
4476
|
-
|
|
4477
|
-
|
|
4550
|
+
};
|
|
4551
|
+
const transposeJaxprCache = /* @__PURE__ */ new Map();
|
|
4552
|
+
function transposeJaxpr(jaxpr, undefPrimals) {
|
|
4553
|
+
const cacheKey = JSON.stringify(undefPrimals);
|
|
4554
|
+
const prevResult = transposeJaxprCache.get(jaxpr)?.get(cacheKey);
|
|
4555
|
+
if (prevResult) return prevResult;
|
|
4556
|
+
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
4557
|
+
const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
|
|
4558
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
|
|
4559
|
+
const args = [];
|
|
4560
|
+
let forwardInIdx = 0;
|
|
4561
|
+
for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
|
|
4562
|
+
else args.push(forwardIn[forwardInIdx++]);
|
|
4563
|
+
return evalJaxprTransposed(jaxpr, args, cotangents);
|
|
4564
|
+
})(forwardInTypes, outTypes);
|
|
4565
|
+
typecheckJaxpr(newJaxpr.jaxpr);
|
|
4566
|
+
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4567
|
+
transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
4568
|
+
return newJaxpr;
|
|
4478
4569
|
}
|
|
4479
|
-
|
|
4480
|
-
|
|
4481
|
-
|
|
4570
|
+
function vjpFlat(f, primalsIn) {
|
|
4571
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4572
|
+
const fVjp = (...cotangents) => {
|
|
4573
|
+
const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4574
|
+
return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
|
|
4575
|
+
};
|
|
4576
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
4577
|
+
return [
|
|
4578
|
+
primalsOut,
|
|
4579
|
+
fVjp,
|
|
4580
|
+
dispose$1
|
|
4581
|
+
];
|
|
4482
4582
|
}
|
|
4483
|
-
|
|
4484
|
-
|
|
4485
|
-
|
|
4486
|
-
|
|
4487
|
-
|
|
4488
|
-
|
|
4489
|
-
|
|
4490
|
-
|
|
4491
|
-
|
|
4583
|
+
function vjp$1(f, ...primalsIn) {
|
|
4584
|
+
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
4585
|
+
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
4586
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4587
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
4588
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4589
|
+
const fVjp = ((cotangentsOut) => {
|
|
4590
|
+
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
4591
|
+
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
4592
|
+
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
4593
|
+
return unflatten(inTree, cotangentsInFlat);
|
|
4594
|
+
});
|
|
4595
|
+
fVjp.dispose = dispose$1;
|
|
4596
|
+
return [primalsOut, fVjp];
|
|
4492
4597
|
}
|
|
4493
|
-
|
|
4494
|
-
|
|
4495
|
-
return
|
|
4598
|
+
function grad$1(f) {
|
|
4599
|
+
const valueAndGradFn = valueAndGrad$1(f);
|
|
4600
|
+
return (...x) => {
|
|
4601
|
+
const [y, dx] = valueAndGradFn(...x);
|
|
4602
|
+
y.dispose();
|
|
4603
|
+
return dx;
|
|
4604
|
+
};
|
|
4496
4605
|
}
|
|
4497
|
-
|
|
4498
|
-
|
|
4499
|
-
|
|
4500
|
-
|
|
4501
|
-
|
|
4502
|
-
|
|
4503
|
-
|
|
4504
|
-
|
|
4606
|
+
function valueAndGrad$1(f) {
|
|
4607
|
+
return (...x) => {
|
|
4608
|
+
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
4609
|
+
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4610
|
+
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4611
|
+
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4612
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4613
|
+
for (const r of rest) dispose(r);
|
|
4614
|
+
fVjp.dispose();
|
|
4615
|
+
return [y, ct];
|
|
4616
|
+
};
|
|
4505
4617
|
}
|
|
4506
|
-
|
|
4507
|
-
|
|
4508
|
-
|
|
4509
|
-
|
|
4510
|
-
|
|
4511
|
-
|
|
4512
|
-
|
|
4513
|
-
|
|
4618
|
+
function jacrev$1(f) {
|
|
4619
|
+
return function jacobianReverse(x) {
|
|
4620
|
+
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
4621
|
+
const [size$1] = x.shape;
|
|
4622
|
+
const pullback = (ct) => {
|
|
4623
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
4624
|
+
y.dispose();
|
|
4625
|
+
const [ret] = fVjp(ct);
|
|
4626
|
+
fVjp.dispose();
|
|
4627
|
+
return ret;
|
|
4628
|
+
};
|
|
4629
|
+
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
4630
|
+
};
|
|
4514
4631
|
}
|
|
4515
4632
|
|
|
4516
4633
|
//#endregion
|
|
@@ -4708,34 +4825,207 @@ function* allPaths(tensors, next) {
|
|
|
4708
4825
|
}
|
|
4709
4826
|
}
|
|
4710
4827
|
|
|
4828
|
+
//#endregion
|
|
4829
|
+
//#region src/library/numpy-fft.ts
|
|
4830
|
+
var numpy_fft_exports = {};
|
|
4831
|
+
__export(numpy_fft_exports, {
|
|
4832
|
+
fft: () => fft,
|
|
4833
|
+
ifft: () => ifft
|
|
4834
|
+
});
|
|
4835
|
+
function checkPairInput(name, a) {
|
|
4836
|
+
const fullName = `jax.numpy.fft.${name}`;
|
|
4837
|
+
if (!deepEqual(a.real.shape, a.imag.shape)) throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(a.real.shape)} and ${JSON.stringify(a.imag.shape)}`);
|
|
4838
|
+
if (a.real.dtype !== a.imag.dtype) throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${a.real.dtype} and ${a.imag.dtype}`);
|
|
4839
|
+
if (!isFloatDtype(a.real.dtype)) throw new Error(`${fullName}: input must have a float dtype, got ${a.real.dtype}`);
|
|
4840
|
+
}
|
|
4841
|
+
function checkPowerOfTwo(name, n) {
|
|
4842
|
+
if ((n & n - 1) !== 0) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
|
|
4843
|
+
}
|
|
4844
|
+
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
|
|
4845
|
+
const half = 2 ** i;
|
|
4846
|
+
real = real.reshape([-1, 2 * half]);
|
|
4847
|
+
imag = imag.reshape([-1, 2 * half]);
|
|
4848
|
+
const k = arange(0, half, 1, { dtype: real.dtype });
|
|
4849
|
+
const theta = k.mul(-Math.PI / half);
|
|
4850
|
+
const wr = cos(theta.ref);
|
|
4851
|
+
const wi = sin(theta);
|
|
4852
|
+
const ur = real.ref.slice([], [0, half]);
|
|
4853
|
+
const ui = imag.ref.slice([], [0, half]);
|
|
4854
|
+
const vr = real.slice([], [half, 2 * half]);
|
|
4855
|
+
const vi = imag.slice([], [half, 2 * half]);
|
|
4856
|
+
const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
|
|
4857
|
+
const ti = vr.mul(wi).add(vi.mul(wr));
|
|
4858
|
+
return {
|
|
4859
|
+
real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
|
|
4860
|
+
imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
|
|
4861
|
+
};
|
|
4862
|
+
}, { staticArgnums: [0] });
|
|
4863
|
+
/**
|
|
4864
|
+
* Compute a one-dimensional discrete Fourier transform.
|
|
4865
|
+
*
|
|
4866
|
+
* Currently, the size of the axis must be a power of two.
|
|
4867
|
+
*/
|
|
4868
|
+
function fft(a, axis = -1) {
|
|
4869
|
+
checkPairInput("fft", a);
|
|
4870
|
+
let { real, imag } = a;
|
|
4871
|
+
axis = checkAxis(axis, real.ndim);
|
|
4872
|
+
const n = real.shape[axis];
|
|
4873
|
+
checkPowerOfTwo("fft", n);
|
|
4874
|
+
const logN = Math.log2(n);
|
|
4875
|
+
let perm = null;
|
|
4876
|
+
if (axis !== real.ndim - 1) {
|
|
4877
|
+
perm = range(real.ndim);
|
|
4878
|
+
perm.splice(axis, 1);
|
|
4879
|
+
perm.push(axis);
|
|
4880
|
+
real = real.transpose(perm);
|
|
4881
|
+
imag = imag.transpose(perm);
|
|
4882
|
+
}
|
|
4883
|
+
const originalShape = real.shape;
|
|
4884
|
+
real = real.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
|
|
4885
|
+
imag = imag.reshape([-1, ...rep(logN, 2)]).transpose([0, ...range(1, logN + 1).reverse()]).flatten();
|
|
4886
|
+
for (let i = 0; i < logN; i++) ({real, imag} = fftUpdate(i, {
|
|
4887
|
+
real,
|
|
4888
|
+
imag
|
|
4889
|
+
}));
|
|
4890
|
+
real = real.reshape(originalShape);
|
|
4891
|
+
imag = imag.reshape(originalShape);
|
|
4892
|
+
if (perm !== null) {
|
|
4893
|
+
real = real.transpose(invertPermutation(perm));
|
|
4894
|
+
imag = imag.transpose(invertPermutation(perm));
|
|
4895
|
+
}
|
|
4896
|
+
return {
|
|
4897
|
+
real,
|
|
4898
|
+
imag
|
|
4899
|
+
};
|
|
4900
|
+
}
|
|
4901
|
+
/**
|
|
4902
|
+
* Compute a one-dimensional inverse discrete Fourier transform.
|
|
4903
|
+
*
|
|
4904
|
+
* Currently, the size of the axis must be a power of two.
|
|
4905
|
+
*/
|
|
4906
|
+
function ifft(a, axis = -1) {
|
|
4907
|
+
checkPairInput("ifft", a);
|
|
4908
|
+
let { real, imag } = a;
|
|
4909
|
+
axis = checkAxis(axis, real.ndim);
|
|
4910
|
+
const n = real.shape[axis];
|
|
4911
|
+
checkPowerOfTwo("ifft", n);
|
|
4912
|
+
imag = imag.mul(-1);
|
|
4913
|
+
const result = fft({
|
|
4914
|
+
real,
|
|
4915
|
+
imag
|
|
4916
|
+
}, axis);
|
|
4917
|
+
return {
|
|
4918
|
+
real: result.real.div(n),
|
|
4919
|
+
imag: result.imag.mul(-1).div(n)
|
|
4920
|
+
};
|
|
4921
|
+
}
|
|
4922
|
+
|
|
4923
|
+
//#endregion
|
|
4924
|
+
//#region src/library/numpy-linalg.ts
|
|
4925
|
+
var numpy_linalg_exports = {};
|
|
4926
|
+
__export(numpy_linalg_exports, {
|
|
4927
|
+
cholesky: () => cholesky$1,
|
|
4928
|
+
diagonal: () => diagonal,
|
|
4929
|
+
lstsq: () => lstsq,
|
|
4930
|
+
matmul: () => matmul,
|
|
4931
|
+
matrixTranspose: () => matrixTranspose,
|
|
4932
|
+
outer: () => outer,
|
|
4933
|
+
tensordot: () => tensordot,
|
|
4934
|
+
trace: () => trace,
|
|
4935
|
+
vecdot: () => vecdot
|
|
4936
|
+
});
|
|
4937
|
+
/**
|
|
4938
|
+
* Compute the Cholesky decomposition of a (batched) positive-definite matrix.
|
|
4939
|
+
*
|
|
4940
|
+
* This is like `jax.lax.linalg.cholesky()`, except with an option to symmetrize
|
|
4941
|
+
* the input matrix, which is on by default.
|
|
4942
|
+
*/
|
|
4943
|
+
function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
|
|
4944
|
+
a = fudgeArray(a);
|
|
4945
|
+
if (a.ndim < 2 || a.shape[a.ndim - 1] !== a.shape[a.ndim - 2]) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
|
|
4946
|
+
if (symmetrizeInput) a = a.ref.add(matrixTranspose(a)).mul(.5);
|
|
4947
|
+
return cholesky(a, { upper });
|
|
4948
|
+
}
|
|
4949
|
+
/**
|
|
4950
|
+
* Return the least-squares solution to a linear equation.
|
|
4951
|
+
*
|
|
4952
|
+
* For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
|
|
4953
|
+
* For underdetermined systems, this finds the minimum-norm solution for `x`.
|
|
4954
|
+
*
|
|
4955
|
+
* This currently uses Cholesky decomposition to solve the normal equations,
|
|
4956
|
+
* under the hood. The method is not as robust as QR or SVD.
|
|
4957
|
+
*
|
|
4958
|
+
* @param a coefficient matrix of shape `(M, N)`
|
|
4959
|
+
* @param b right-hand side of shape `(M,)` or `(M, K)`
|
|
4960
|
+
* @return least-squares solution of shape `(N,)` or `(N, K)`
|
|
4961
|
+
*/
|
|
4962
|
+
function lstsq(a, b) {
|
|
4963
|
+
a = fudgeArray(a);
|
|
4964
|
+
b = fudgeArray(b);
|
|
4965
|
+
if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
|
|
4966
|
+
const [m, n] = a.shape;
|
|
4967
|
+
if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
|
|
4968
|
+
const at = matrixTranspose(a.ref);
|
|
4969
|
+
if (m <= n) {
|
|
4970
|
+
const aat = matmul(a, at.ref);
|
|
4971
|
+
const l = cholesky$1(aat, { symmetrizeInput: false });
|
|
4972
|
+
const lb = triangularSolve(l.ref, b, {
|
|
4973
|
+
leftSide: true,
|
|
4974
|
+
lower: true
|
|
4975
|
+
});
|
|
4976
|
+
const llb = triangularSolve(l, lb, {
|
|
4977
|
+
leftSide: true,
|
|
4978
|
+
transposeA: true
|
|
4979
|
+
});
|
|
4980
|
+
return matmul(at, llb.ref);
|
|
4981
|
+
} else {
|
|
4982
|
+
const ata = matmul(at.ref, a);
|
|
4983
|
+
const l = cholesky$1(ata, { symmetrizeInput: false });
|
|
4984
|
+
const atb = matmul(at, b);
|
|
4985
|
+
const lb = triangularSolve(l.ref, atb, {
|
|
4986
|
+
leftSide: true,
|
|
4987
|
+
lower: true
|
|
4988
|
+
});
|
|
4989
|
+
const llb = triangularSolve(l, lb, {
|
|
4990
|
+
leftSide: true,
|
|
4991
|
+
transposeA: true
|
|
4992
|
+
});
|
|
4993
|
+
return llb;
|
|
4994
|
+
}
|
|
4995
|
+
}
|
|
4996
|
+
|
|
4711
4997
|
//#endregion
|
|
4712
4998
|
//#region src/library/numpy.ts
|
|
4713
4999
|
var numpy_exports = {};
|
|
4714
5000
|
__export(numpy_exports, {
|
|
4715
5001
|
Array: () => Array$1,
|
|
4716
5002
|
DType: () => DType,
|
|
4717
|
-
abs: () =>
|
|
5003
|
+
abs: () => absolute,
|
|
4718
5004
|
absolute: () => absolute,
|
|
4719
5005
|
acos: () => acos,
|
|
4720
|
-
acosh: () =>
|
|
5006
|
+
acosh: () => arccosh,
|
|
4721
5007
|
add: () => add,
|
|
5008
|
+
all: () => all,
|
|
4722
5009
|
allclose: () => allclose,
|
|
5010
|
+
any: () => any,
|
|
4723
5011
|
arange: () => arange,
|
|
4724
|
-
arccos: () =>
|
|
5012
|
+
arccos: () => acos,
|
|
4725
5013
|
arccosh: () => arccosh,
|
|
5014
|
+
arcsin: () => asin,
|
|
4726
5015
|
arcsinh: () => arcsinh,
|
|
4727
|
-
arctan: () =>
|
|
4728
|
-
arctan2: () =>
|
|
5016
|
+
arctan: () => atan,
|
|
5017
|
+
arctan2: () => atan2,
|
|
4729
5018
|
arctanh: () => arctanh,
|
|
4730
5019
|
argmax: () => argmax,
|
|
4731
5020
|
argmin: () => argmin,
|
|
5021
|
+
argsort: () => argsort,
|
|
4732
5022
|
array: () => array,
|
|
4733
5023
|
asin: () => asin,
|
|
4734
|
-
asinh: () =>
|
|
5024
|
+
asinh: () => arcsinh,
|
|
4735
5025
|
astype: () => astype,
|
|
4736
5026
|
atan: () => atan,
|
|
4737
5027
|
atan2: () => atan2,
|
|
4738
|
-
atanh: () =>
|
|
5028
|
+
atanh: () => arctanh,
|
|
4739
5029
|
bool: () => bool,
|
|
4740
5030
|
broadcastArrays: () => broadcastArrays,
|
|
4741
5031
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -4745,16 +5035,20 @@ __export(numpy_exports, {
|
|
|
4745
5035
|
clip: () => clip,
|
|
4746
5036
|
columnStack: () => columnStack,
|
|
4747
5037
|
concatenate: () => concatenate,
|
|
5038
|
+
convolve: () => convolve,
|
|
5039
|
+
corrcoef: () => corrcoef,
|
|
5040
|
+
correlate: () => correlate,
|
|
4748
5041
|
cos: () => cos,
|
|
4749
5042
|
cosh: () => cosh,
|
|
5043
|
+
cov: () => cov,
|
|
4750
5044
|
cumsum: () => cumsum,
|
|
4751
|
-
cumulativeSum: () =>
|
|
5045
|
+
cumulativeSum: () => cumsum,
|
|
4752
5046
|
deg2rad: () => deg2rad,
|
|
4753
5047
|
degrees: () => degrees,
|
|
4754
5048
|
diag: () => diag,
|
|
4755
5049
|
diagonal: () => diagonal,
|
|
4756
|
-
divide: () =>
|
|
4757
|
-
dot: () => dot,
|
|
5050
|
+
divide: () => trueDivide,
|
|
5051
|
+
dot: () => dot$1,
|
|
4758
5052
|
dstack: () => dstack,
|
|
4759
5053
|
e: () => e,
|
|
4760
5054
|
einsum: () => einsum,
|
|
@@ -4762,8 +5056,10 @@ __export(numpy_exports, {
|
|
|
4762
5056
|
eulerGamma: () => eulerGamma,
|
|
4763
5057
|
exp: () => exp,
|
|
4764
5058
|
exp2: () => exp2,
|
|
5059
|
+
expandDims: () => expandDims,
|
|
4765
5060
|
expm1: () => expm1,
|
|
4766
5061
|
eye: () => eye,
|
|
5062
|
+
fft: () => numpy_fft_exports,
|
|
4767
5063
|
flip: () => flip,
|
|
4768
5064
|
fliplr: () => fliplr,
|
|
4769
5065
|
flipud: () => flipud,
|
|
@@ -4794,12 +5090,14 @@ __export(numpy_exports, {
|
|
|
4794
5090
|
ldexp: () => ldexp,
|
|
4795
5091
|
less: () => less,
|
|
4796
5092
|
lessEqual: () => lessEqual,
|
|
5093
|
+
linalg: () => numpy_linalg_exports,
|
|
4797
5094
|
linspace: () => linspace,
|
|
4798
5095
|
log: () => log,
|
|
4799
5096
|
log10: () => log10,
|
|
4800
5097
|
log1p: () => log1p,
|
|
4801
5098
|
log2: () => log2,
|
|
4802
5099
|
matmul: () => matmul,
|
|
5100
|
+
matrixTranspose: () => matrixTranspose,
|
|
4803
5101
|
max: () => max,
|
|
4804
5102
|
maximum: () => maximum,
|
|
4805
5103
|
mean: () => mean,
|
|
@@ -4816,10 +5114,10 @@ __export(numpy_exports, {
|
|
|
4816
5114
|
onesLike: () => onesLike,
|
|
4817
5115
|
outer: () => outer,
|
|
4818
5116
|
pad: () => pad,
|
|
4819
|
-
permuteDims: () =>
|
|
5117
|
+
permuteDims: () => transpose,
|
|
4820
5118
|
pi: () => pi,
|
|
4821
5119
|
positive: () => positive,
|
|
4822
|
-
pow: () =>
|
|
5120
|
+
pow: () => power,
|
|
4823
5121
|
power: () => power,
|
|
4824
5122
|
prod: () => prod$1,
|
|
4825
5123
|
promoteTypes: () => promoteTypes,
|
|
@@ -4836,6 +5134,7 @@ __export(numpy_exports, {
|
|
|
4836
5134
|
sin: () => sin,
|
|
4837
5135
|
sinh: () => sinh,
|
|
4838
5136
|
size: () => size,
|
|
5137
|
+
sort: () => sort,
|
|
4839
5138
|
sqrt: () => sqrt,
|
|
4840
5139
|
square: () => square,
|
|
4841
5140
|
squeeze: () => squeeze,
|
|
@@ -5000,6 +5299,26 @@ function min(a, axis = null, opts) {
|
|
|
5000
5299
|
function max(a, axis = null, opts) {
|
|
5001
5300
|
return reduce(a, AluOp.Max, axis, opts);
|
|
5002
5301
|
}
|
|
5302
|
+
/**
|
|
5303
|
+
* Test whether all array elements along a given axis evaluate to True.
|
|
5304
|
+
*
|
|
5305
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5306
|
+
* removed. If axis is None, returns a scalar.
|
|
5307
|
+
*/
|
|
5308
|
+
function all(a, axis = null, opts) {
|
|
5309
|
+
a = fudgeArray(a).astype(DType.Bool);
|
|
5310
|
+
return min(a, axis, opts);
|
|
5311
|
+
}
|
|
5312
|
+
/**
|
|
5313
|
+
* Test whether any array element along a given axis evaluates to True.
|
|
5314
|
+
*
|
|
5315
|
+
* Returns a boolean array with the same shape as `a` with the specified axis
|
|
5316
|
+
* removed. If axis is None, returns a scalar.
|
|
5317
|
+
*/
|
|
5318
|
+
function any(a, axis = null, opts) {
|
|
5319
|
+
a = fudgeArray(a).astype(DType.Bool);
|
|
5320
|
+
return max(a, axis, opts);
|
|
5321
|
+
}
|
|
5003
5322
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
5004
5323
|
function ptp(a, axis = null, opts) {
|
|
5005
5324
|
a = fudgeArray(a);
|
|
@@ -5074,8 +5393,6 @@ function cumsum(a, axis) {
|
|
|
5074
5393
|
a = broadcast(a, a.shape.concat(n), [-2]);
|
|
5075
5394
|
return moveaxis$1(tril(a).sum(-1), -1, axis);
|
|
5076
5395
|
}
|
|
5077
|
-
/** @function Alternative name for `jax.numpy.cumsum()`. */
|
|
5078
|
-
const cumulativeSum = cumsum;
|
|
5079
5396
|
/** Reverse the elements in an array along the given axes. */
|
|
5080
5397
|
function flip(x, axis = null) {
|
|
5081
5398
|
const nd = ndim(x);
|
|
@@ -5185,8 +5502,11 @@ function flipud(x) {
|
|
|
5185
5502
|
function fliplr(x) {
|
|
5186
5503
|
return flip(x, 1);
|
|
5187
5504
|
}
|
|
5188
|
-
/**
|
|
5189
|
-
|
|
5505
|
+
/** Transpose the last two dimensions of an array. */
|
|
5506
|
+
function matrixTranspose(a) {
|
|
5507
|
+
if (ndim(a) < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
|
|
5508
|
+
return moveaxis$1(a, -1, -2);
|
|
5509
|
+
}
|
|
5190
5510
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
5191
5511
|
function ravel(a) {
|
|
5192
5512
|
return fudgeArray(a).ravel();
|
|
@@ -5202,6 +5522,32 @@ function squeeze(a, axis = null) {
|
|
|
5202
5522
|
return reshape(a, newShape);
|
|
5203
5523
|
}
|
|
5204
5524
|
/**
|
|
5525
|
+
* Expand the shape of an array by inserting new axes of length 1.
|
|
5526
|
+
*
|
|
5527
|
+
* @param a - Input array.
|
|
5528
|
+
* @param axis - Position(s) in the expanded axes where the new axis (or axes)
|
|
5529
|
+
* is placed. Can be a single integer or an array of integers.
|
|
5530
|
+
* @returns Array with the number of dimensions increased.
|
|
5531
|
+
*
|
|
5532
|
+
* @example
|
|
5533
|
+
* ```ts
|
|
5534
|
+
* const x = np.array([1, 2]);
|
|
5535
|
+
* np.expandDims(x, 0); // Shape [1, 2]
|
|
5536
|
+
* np.expandDims(x, 1); // Shape [2, 1]
|
|
5537
|
+
* np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
|
|
5538
|
+
* ```
|
|
5539
|
+
*/
|
|
5540
|
+
function expandDims(a, axis) {
|
|
5541
|
+
const as = shape(a);
|
|
5542
|
+
axis = typeof axis === "number" ? [axis] : axis;
|
|
5543
|
+
axis = normalizeAxis(axis, as.length + axis.length);
|
|
5544
|
+
const newShape = [];
|
|
5545
|
+
let srcIdx = 0;
|
|
5546
|
+
for (let i = 0; i < as.length + axis.length; i++) if (axis.includes(i)) newShape.push(1);
|
|
5547
|
+
else newShape.push(as[srcIdx++]);
|
|
5548
|
+
return reshape(a, newShape);
|
|
5549
|
+
}
|
|
5550
|
+
/**
|
|
5205
5551
|
* Repeat each element of an array after themselves.
|
|
5206
5552
|
*
|
|
5207
5553
|
* If no axis is provided, use the flattened input array, and return a flat
|
|
@@ -5289,7 +5635,7 @@ function diagonal(a, offset, axis1, axis2) {
|
|
|
5289
5635
|
*/
|
|
5290
5636
|
function diag(v, k = 0) {
|
|
5291
5637
|
const a = fudgeArray(v);
|
|
5292
|
-
if (!Number.isInteger(k)) throw new
|
|
5638
|
+
if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
|
|
5293
5639
|
if (a.ndim === 1) {
|
|
5294
5640
|
const n = a.shape[0];
|
|
5295
5641
|
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
@@ -5297,12 +5643,32 @@ function diag(v, k = 0) {
|
|
|
5297
5643
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
5298
5644
|
else return ret;
|
|
5299
5645
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
5300
|
-
else throw new
|
|
5646
|
+
else throw new Error("numpy.diag only supports 1D and 2D arrays");
|
|
5301
5647
|
}
|
|
5302
5648
|
/** Calculate the sum of the diagonal of an array along the given axes. */
|
|
5303
5649
|
function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
|
|
5304
5650
|
return diagonal(a, offset, axis1, axis2).sum(-1);
|
|
5305
5651
|
}
|
|
5652
|
+
/**
|
|
5653
|
+
* Return a sorted copy of an array.
|
|
5654
|
+
*
|
|
5655
|
+
* The array is sorted along a specified axis (the last by default). This may be
|
|
5656
|
+
* an unstable sort, and it dispatches to device-specific implementation.
|
|
5657
|
+
*/
|
|
5658
|
+
function sort(a, axis = -1) {
|
|
5659
|
+
return fudgeArray(a).sort(axis);
|
|
5660
|
+
}
|
|
5661
|
+
/**
|
|
5662
|
+
* Return indices that would sort an array. This may be an unstable sorting
|
|
5663
|
+
* algorithm; it need not preserve order of indices in ties.
|
|
5664
|
+
*
|
|
5665
|
+
* Returns an array of `int32` indices.
|
|
5666
|
+
*
|
|
5667
|
+
* The array is sorted along a specified axis (the last by default).
|
|
5668
|
+
*/
|
|
5669
|
+
function argsort(a, axis = -1) {
|
|
5670
|
+
return fudgeArray(a).argsort(axis);
|
|
5671
|
+
}
|
|
5306
5672
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5307
5673
|
function allclose(actual, expected, options) {
|
|
5308
5674
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5319,11 +5685,11 @@ function allclose(actual, expected, options) {
|
|
|
5319
5685
|
}
|
|
5320
5686
|
/** Matrix product of two arrays. */
|
|
5321
5687
|
function matmul(x, y) {
|
|
5322
|
-
if (ndim(x) === 0 || ndim(y) === 0) throw new
|
|
5688
|
+
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
5323
5689
|
x = x, y = y;
|
|
5324
5690
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5325
5691
|
const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
|
|
5326
|
-
return dot
|
|
5692
|
+
return dot(x, y, {
|
|
5327
5693
|
lhsContractingDims: [-1],
|
|
5328
5694
|
rhsContractingDims: [-2],
|
|
5329
5695
|
lhsBatchDims: range(-2 - numBatchDims, -2),
|
|
@@ -5331,11 +5697,11 @@ function matmul(x, y) {
|
|
|
5331
5697
|
});
|
|
5332
5698
|
}
|
|
5333
5699
|
/** Dot product of two arrays. */
|
|
5334
|
-
function dot(x, y) {
|
|
5700
|
+
function dot$1(x, y) {
|
|
5335
5701
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
5336
5702
|
x = x, y = y;
|
|
5337
5703
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5338
|
-
return dot
|
|
5704
|
+
return dot(x, y, {
|
|
5339
5705
|
lhsContractingDims: [-1],
|
|
5340
5706
|
rhsContractingDims: [-2]
|
|
5341
5707
|
});
|
|
@@ -5351,7 +5717,7 @@ function tensordot(x, y, axes = 2) {
|
|
|
5351
5717
|
x = fudgeArray(x);
|
|
5352
5718
|
y = fudgeArray(y);
|
|
5353
5719
|
if (typeof axes === "number") axes = [range(-axes, 0), range(axes)];
|
|
5354
|
-
return dot
|
|
5720
|
+
return dot(x, y, {
|
|
5355
5721
|
lhsContractingDims: axes[0],
|
|
5356
5722
|
rhsContractingDims: axes[1]
|
|
5357
5723
|
});
|
|
@@ -5444,7 +5810,7 @@ function einsum(...args) {
|
|
|
5444
5810
|
const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
|
|
5445
5811
|
indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
|
|
5446
5812
|
const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
|
|
5447
|
-
const result = dot
|
|
5813
|
+
const result = dot(a, b, {
|
|
5448
5814
|
lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
|
|
5449
5815
|
rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
|
|
5450
5816
|
lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
|
|
@@ -5472,7 +5838,7 @@ function einsum(...args) {
|
|
|
5472
5838
|
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
5473
5839
|
*/
|
|
5474
5840
|
function inner(x, y) {
|
|
5475
|
-
return dot
|
|
5841
|
+
return dot(fudgeArray(x), fudgeArray(y), {
|
|
5476
5842
|
lhsContractingDims: [-1],
|
|
5477
5843
|
rhsContractingDims: [-1]
|
|
5478
5844
|
});
|
|
@@ -5505,6 +5871,30 @@ function vecdot(x, y, { axis } = {}) {
|
|
|
5505
5871
|
function vdot(x, y) {
|
|
5506
5872
|
return dot$2(ravel(x), ravel(y));
|
|
5507
5873
|
}
|
|
5874
|
+
function _convImpl(name, x, y, mode) {
|
|
5875
|
+
if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
|
|
5876
|
+
let flipOutput = false;
|
|
5877
|
+
if (x.shape[0] < y.shape[0]) {
|
|
5878
|
+
[x, y] = [y, x];
|
|
5879
|
+
if (name === "correlate") flipOutput = true;
|
|
5880
|
+
}
|
|
5881
|
+
if (name === "convolve") y = flip(y);
|
|
5882
|
+
let padding;
|
|
5883
|
+
if (mode === "valid") padding = "VALID";
|
|
5884
|
+
else if (mode === "same") padding = "SAME_LOWER";
|
|
5885
|
+
else if (mode === "full") padding = [[y.shape[0] - 1, y.shape[0] - 1]];
|
|
5886
|
+
else throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
|
|
5887
|
+
const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
|
|
5888
|
+
return flipOutput ? flip(z) : z;
|
|
5889
|
+
}
|
|
5890
|
+
/** Convolution of two one-dimensional arrays. */
|
|
5891
|
+
function convolve(x, y, mode = "full") {
|
|
5892
|
+
return _convImpl("convolve", x, y, mode);
|
|
5893
|
+
}
|
|
5894
|
+
/** Correlation of two one dimensional arrays. */
|
|
5895
|
+
function correlate(x, y, mode = "valid") {
|
|
5896
|
+
return _convImpl("correlate", x, y, mode);
|
|
5897
|
+
}
|
|
5508
5898
|
/**
|
|
5509
5899
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
5510
5900
|
*
|
|
@@ -5513,7 +5903,7 @@ function vdot(x, y) {
|
|
|
5513
5903
|
*/
|
|
5514
5904
|
function meshgrid(xs, { indexing } = {}) {
|
|
5515
5905
|
indexing ??= "xy";
|
|
5516
|
-
for (const x of xs) if (x.ndim !== 1) throw new
|
|
5906
|
+
for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
|
|
5517
5907
|
if (xs.length <= 1) return xs;
|
|
5518
5908
|
if (indexing === "xy") {
|
|
5519
5909
|
const [a, b, ...rest] = xs;
|
|
@@ -5529,44 +5919,7 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
5529
5919
|
];
|
|
5530
5920
|
}
|
|
5531
5921
|
const shape$1 = xs.map((x) => x.shape[0]);
|
|
5532
|
-
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
5533
|
-
}
|
|
5534
|
-
/**
|
|
5535
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
5536
|
-
*
|
|
5537
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
5538
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
5539
|
-
* `k>0` is above it.
|
|
5540
|
-
*/
|
|
5541
|
-
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
5542
|
-
m ??= n;
|
|
5543
|
-
dtype ??= DType.Float32;
|
|
5544
|
-
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
5545
|
-
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
5546
|
-
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
5547
|
-
const rows = arange(k, n + k, 1, {
|
|
5548
|
-
dtype: DType.Int32,
|
|
5549
|
-
device
|
|
5550
|
-
});
|
|
5551
|
-
const cols = arange(0, m, 1, {
|
|
5552
|
-
dtype: DType.Int32,
|
|
5553
|
-
device
|
|
5554
|
-
});
|
|
5555
|
-
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
5556
|
-
}
|
|
5557
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
5558
|
-
function tril(a, k = 0) {
|
|
5559
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5560
|
-
a = fudgeArray(a);
|
|
5561
|
-
const [n, m] = a.shape.slice(-2);
|
|
5562
|
-
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
5563
|
-
}
|
|
5564
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
5565
|
-
function triu(a, k = 0) {
|
|
5566
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5567
|
-
a = fudgeArray(a);
|
|
5568
|
-
const [n, m] = a.shape.slice(-2);
|
|
5569
|
-
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
5922
|
+
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
5570
5923
|
}
|
|
5571
5924
|
/**
|
|
5572
5925
|
* Clip (limit) the values in an array.
|
|
@@ -5592,8 +5945,6 @@ function absolute(x) {
|
|
|
5592
5945
|
x = fudgeArray(x);
|
|
5593
5946
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
5594
5947
|
}
|
|
5595
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
5596
|
-
const abs = absolute;
|
|
5597
5948
|
/** Return an element-wise indication of sign of the input. */
|
|
5598
5949
|
function sign(x) {
|
|
5599
5950
|
x = fudgeArray(x);
|
|
@@ -5672,12 +6023,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
|
|
|
5672
6023
|
const denom = where(xNeg, y, r.add(x));
|
|
5673
6024
|
return atan(numer.div(denom)).mul(2);
|
|
5674
6025
|
});
|
|
5675
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
5676
|
-
const arccos = acos;
|
|
5677
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
5678
|
-
const arctan = atan;
|
|
5679
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
5680
|
-
const arctan2 = atan2;
|
|
5681
6026
|
/** Element-wise subtraction, with broadcasting. */
|
|
5682
6027
|
function subtract(x, y) {
|
|
5683
6028
|
x = fudgeArray(x);
|
|
@@ -5708,8 +6053,6 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
5708
6053
|
const remainder = jit$1(function remainder$1(x, y) {
|
|
5709
6054
|
return mod(mod(x, y.ref).add(y.ref), y);
|
|
5710
6055
|
});
|
|
5711
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
5712
|
-
const divide = trueDivide;
|
|
5713
6056
|
/** Round input to the nearest integer towards zero. */
|
|
5714
6057
|
function trunc(x) {
|
|
5715
6058
|
return idiv(x, 1);
|
|
@@ -5731,9 +6074,9 @@ function ldexp(x1, x2) {
|
|
|
5731
6074
|
*/
|
|
5732
6075
|
function frexp(x) {
|
|
5733
6076
|
x = fudgeArray(x);
|
|
5734
|
-
const absx =
|
|
6077
|
+
const absx = absolute(x.ref);
|
|
5735
6078
|
const exponent = where(equal(x.ref, 0), 0, floor(log2(absx)).add(1).astype(DType.Int32));
|
|
5736
|
-
const mantissa =
|
|
6079
|
+
const mantissa = x.div(exp2(exponent.ref.astype(x.dtype)));
|
|
5737
6080
|
return [mantissa, exponent];
|
|
5738
6081
|
}
|
|
5739
6082
|
/** Calculate `2**p` for all p in the input array. */
|
|
@@ -5776,10 +6119,8 @@ const power = jit$1(function power$1(x1, x2) {
|
|
|
5776
6119
|
const x2i = trunc(x2.ref);
|
|
5777
6120
|
const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
|
|
5778
6121
|
const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
|
|
5779
|
-
return where(shouldBeNaN, nan, exp(log(
|
|
6122
|
+
return where(shouldBeNaN, nan, exp(log(absolute(x1)).mul(x2)).mul(resultSign));
|
|
5780
6123
|
});
|
|
5781
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
5782
|
-
const pow = power;
|
|
5783
6124
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
5784
6125
|
const cbrt = jit$1(function cbrt$1(x) {
|
|
5785
6126
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
@@ -5845,12 +6186,6 @@ const arccosh = jit$1(function arccosh$1(x) {
|
|
|
5845
6186
|
const arctanh = jit$1(function arctanh$1(x) {
|
|
5846
6187
|
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
5847
6188
|
});
|
|
5848
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
5849
|
-
const asinh = arcsinh;
|
|
5850
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
5851
|
-
const acosh = arccosh;
|
|
5852
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
5853
|
-
const atanh = arctanh;
|
|
5854
6189
|
/**
|
|
5855
6190
|
* Compute the variance of an array.
|
|
5856
6191
|
*
|
|
@@ -5880,6 +6215,26 @@ function var_(x, axis = null, opts) {
|
|
|
5880
6215
|
function std(x, axis = null, opts) {
|
|
5881
6216
|
return sqrt(var_(x, axis, opts));
|
|
5882
6217
|
}
|
|
6218
|
+
/** Estimate the sample covariance of a set of variables. */
|
|
6219
|
+
function cov(x, y) {
|
|
6220
|
+
x = fudgeArray(x);
|
|
6221
|
+
if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
|
|
6222
|
+
if (y !== void 0) {
|
|
6223
|
+
y = fudgeArray(y);
|
|
6224
|
+
if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
|
|
6225
|
+
x = vstack([x, y]);
|
|
6226
|
+
}
|
|
6227
|
+
const [_M, N] = x.shape;
|
|
6228
|
+
x = x.ref.sub(x.mean(1, { keepdims: true }));
|
|
6229
|
+
return dot$1(x.ref, x.transpose()).div(N - 1);
|
|
6230
|
+
}
|
|
6231
|
+
/** Compute the Pearson correlation coefficients (in range `[-1, 1]`). */
|
|
6232
|
+
function corrcoef(x, y) {
|
|
6233
|
+
const c = cov(x, y);
|
|
6234
|
+
const variances = diag(c.ref);
|
|
6235
|
+
const norm = sqrt(outer(variances.ref, variances));
|
|
6236
|
+
return c.div(norm);
|
|
6237
|
+
}
|
|
5883
6238
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
5884
6239
|
function isinf(x) {
|
|
5885
6240
|
x = fudgeArray(x);
|
|
@@ -5909,6 +6264,253 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
5909
6264
|
return isnan(x.ref).add(isinf(x)).notEqual(true);
|
|
5910
6265
|
});
|
|
5911
6266
|
|
|
6267
|
+
//#endregion
|
|
6268
|
+
//#region src/library/lax-linalg.ts
|
|
6269
|
+
var lax_linalg_exports = {};
|
|
6270
|
+
__export(lax_linalg_exports, {
|
|
6271
|
+
cholesky: () => cholesky,
|
|
6272
|
+
triangularSolve: () => triangularSolve
|
|
6273
|
+
});
|
|
6274
|
+
/**
|
|
6275
|
+
* Compute the Cholesky decomposition of a symmetric positive-definite matrix.
|
|
6276
|
+
*
|
|
6277
|
+
* The Cholesky decomposition of a matrix `A` is:
|
|
6278
|
+
*
|
|
6279
|
+
* - A = L @ L^T (for upper=false, default)
|
|
6280
|
+
* - A = U^T @ U (for upper=true)
|
|
6281
|
+
*
|
|
6282
|
+
* where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
|
|
6283
|
+
* The input matrix must be symmetric and positive-definite.
|
|
6284
|
+
*
|
|
6285
|
+
* @example
|
|
6286
|
+
* ```ts
|
|
6287
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6288
|
+
*
|
|
6289
|
+
* const x = np.array([[2., 1.], [1., 2.]]);
|
|
6290
|
+
*
|
|
6291
|
+
* // Lower Cholesky factorization (default):
|
|
6292
|
+
* const L = lax.linalg.cholesky(x);
|
|
6293
|
+
* // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
|
|
6294
|
+
*
|
|
6295
|
+
* // Upper Cholesky factorization:
|
|
6296
|
+
* const U = lax.linalg.cholesky(x, { upper: true });
|
|
6297
|
+
* // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
|
|
6298
|
+
* ```
|
|
6299
|
+
*/
|
|
6300
|
+
function cholesky(a, { upper = false } = {}) {
|
|
6301
|
+
const L = cholesky$2(a);
|
|
6302
|
+
return upper ? moveaxis$1(L, -2, -1) : L;
|
|
6303
|
+
}
|
|
6304
|
+
/**
|
|
6305
|
+
* Solve a triangular linear system.
|
|
6306
|
+
*
|
|
6307
|
+
* Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
|
|
6308
|
+
* where `a` is a triangular matrix.
|
|
6309
|
+
*
|
|
6310
|
+
* @example
|
|
6311
|
+
* ```ts
|
|
6312
|
+
* import { lax, numpy as np } from "@jax-js/jax";
|
|
6313
|
+
*
|
|
6314
|
+
* const L = np.array([[2., 0.], [1., 3.]]);
|
|
6315
|
+
* const b = np.array([4., 7.]).reshape([2, 1]);
|
|
6316
|
+
*
|
|
6317
|
+
* // Solve L @ x = b
|
|
6318
|
+
* const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
|
|
6319
|
+
* // x = [[2.], [5./3.]]
|
|
6320
|
+
* ```
|
|
6321
|
+
*/
|
|
6322
|
+
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
|
|
6323
|
+
a = fudgeArray(a);
|
|
6324
|
+
b = fudgeArray(b);
|
|
6325
|
+
if (!leftSide) transposeA = !transposeA;
|
|
6326
|
+
else b = moveaxis$1(b, -2, -1);
|
|
6327
|
+
if (transposeA) a = moveaxis$1(a, -2, -1);
|
|
6328
|
+
let x = triangularSolve$1(a, b, {
|
|
6329
|
+
lower,
|
|
6330
|
+
unitDiagonal
|
|
6331
|
+
});
|
|
6332
|
+
if (leftSide) x = moveaxis$1(x, -2, -1);
|
|
6333
|
+
return x;
|
|
6334
|
+
}
|
|
6335
|
+
|
|
6336
|
+
//#endregion
|
|
6337
|
+
//#region src/library/lax.ts
|
|
6338
|
+
var lax_exports = {};
|
|
6339
|
+
__export(lax_exports, {
|
|
6340
|
+
conv: () => conv,
|
|
6341
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
6342
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6343
|
+
dot: () => dot,
|
|
6344
|
+
erf: () => erf,
|
|
6345
|
+
erfc: () => erfc,
|
|
6346
|
+
linalg: () => lax_linalg_exports,
|
|
6347
|
+
reduceWindow: () => reduceWindow,
|
|
6348
|
+
stopGradient: () => stopGradient$1
|
|
6349
|
+
});
|
|
6350
|
+
/**
|
|
6351
|
+
* General dot product/contraction operator.
|
|
6352
|
+
*
|
|
6353
|
+
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
6354
|
+
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
6355
|
+
*/
|
|
6356
|
+
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
6357
|
+
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
6358
|
+
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
6359
|
+
lc = lc.map((a) => checkAxis(a, lhs.ndim));
|
|
6360
|
+
rc = rc.map((a) => checkAxis(a, rhs.ndim));
|
|
6361
|
+
lb = lb.map((a) => checkAxis(a, lhs.ndim));
|
|
6362
|
+
rb = rb.map((a) => checkAxis(a, rhs.ndim));
|
|
6363
|
+
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
6364
|
+
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
6365
|
+
const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
6366
|
+
const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
6367
|
+
const lhs2 = lhs.transpose([
|
|
6368
|
+
...lb,
|
|
6369
|
+
...lf,
|
|
6370
|
+
...lc
|
|
6371
|
+
]);
|
|
6372
|
+
const rhs2 = rhs.transpose([
|
|
6373
|
+
...rb,
|
|
6374
|
+
...rf,
|
|
6375
|
+
...rc
|
|
6376
|
+
]);
|
|
6377
|
+
if (lc.length === 0) return mul(lhs2.reshape([
|
|
6378
|
+
...lb.map((a) => lhs.shape[a]),
|
|
6379
|
+
...lf.map((a) => lhs.shape[a]),
|
|
6380
|
+
...rep(rf.length, 1)
|
|
6381
|
+
]), rhs2.reshape([
|
|
6382
|
+
...rb.map((a) => rhs.shape[a]),
|
|
6383
|
+
...rep(lf.length, 1),
|
|
6384
|
+
...rf.map((a) => rhs.shape[a])
|
|
6385
|
+
]));
|
|
6386
|
+
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
6387
|
+
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
6388
|
+
if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
6389
|
+
return dot$2(lhs2.reshape([
|
|
6390
|
+
...lb.map((a) => lhs.shape[a]),
|
|
6391
|
+
...lf.map((a) => lhs.shape[a]),
|
|
6392
|
+
...rep(rf.length, 1),
|
|
6393
|
+
prod(dotShapeX)
|
|
6394
|
+
]), rhs2.reshape([
|
|
6395
|
+
...rb.map((a) => rhs.shape[a]),
|
|
6396
|
+
...rep(lf.length, 1),
|
|
6397
|
+
...rf.map((a) => rhs.shape[a]),
|
|
6398
|
+
prod(dotShapeY)
|
|
6399
|
+
]));
|
|
6400
|
+
}
|
|
6401
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
6402
|
+
const padType = padding.toUpperCase();
|
|
6403
|
+
switch (padType) {
|
|
6404
|
+
case "VALID": return rep(inShape.length, [0, 0]);
|
|
6405
|
+
case "SAME":
|
|
6406
|
+
case "SAME_LOWER": {
|
|
6407
|
+
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
6408
|
+
const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
6409
|
+
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
6410
|
+
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
6411
|
+
}
|
|
6412
|
+
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
6413
|
+
}
|
|
6414
|
+
}
|
|
6415
|
+
/**
|
|
6416
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
6417
|
+
*
|
|
6418
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
6419
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
6420
|
+
*
|
|
6421
|
+
* Grouped convolutions are not supported right now.
|
|
6422
|
+
*/
|
|
6423
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
|
|
6424
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
6425
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
6426
|
+
if (typeof padding === "string") {
|
|
6427
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
6428
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
6429
|
+
}
|
|
6430
|
+
if (featureGroupCount !== 1) {
|
|
6431
|
+
const G = featureGroupCount;
|
|
6432
|
+
const [N, C_in, ...xs] = lhs.shape;
|
|
6433
|
+
const [C_out, C_in_per_group, ...ks] = rhs.shape;
|
|
6434
|
+
if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
|
|
6435
|
+
if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
|
|
6436
|
+
if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
|
|
6437
|
+
const lhsGrouped = moveaxis(lhs.reshape([
|
|
6438
|
+
N,
|
|
6439
|
+
G,
|
|
6440
|
+
C_in / G,
|
|
6441
|
+
...xs
|
|
6442
|
+
]), 1, 0);
|
|
6443
|
+
const rhsGrouped = rhs.reshape([
|
|
6444
|
+
G,
|
|
6445
|
+
C_out / G,
|
|
6446
|
+
C_in_per_group,
|
|
6447
|
+
...ks
|
|
6448
|
+
]);
|
|
6449
|
+
const result = conv$1(lhsGrouped, rhsGrouped, {
|
|
6450
|
+
vmapDims: 1,
|
|
6451
|
+
strides: windowStrides,
|
|
6452
|
+
padding,
|
|
6453
|
+
lhsDilation,
|
|
6454
|
+
rhsDilation
|
|
6455
|
+
});
|
|
6456
|
+
const ys = result.shape.slice(3);
|
|
6457
|
+
return moveaxis(result, 0, 1).reshape([
|
|
6458
|
+
N,
|
|
6459
|
+
C_out,
|
|
6460
|
+
...ys
|
|
6461
|
+
]);
|
|
6462
|
+
}
|
|
6463
|
+
return conv$1(lhs, rhs, {
|
|
6464
|
+
strides: windowStrides,
|
|
6465
|
+
padding,
|
|
6466
|
+
lhsDilation,
|
|
6467
|
+
rhsDilation
|
|
6468
|
+
});
|
|
6469
|
+
}
|
|
6470
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
6471
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
6472
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
6473
|
+
lhsDilation,
|
|
6474
|
+
rhsDilation
|
|
6475
|
+
});
|
|
6476
|
+
}
|
|
6477
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
6478
|
+
function conv(lhs, rhs, windowStrides, padding) {
|
|
6479
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
6480
|
+
}
|
|
6481
|
+
/** Reduce a computation over padded windows. */
|
|
6482
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
6483
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
6484
|
+
if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
|
|
6485
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
6486
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
6487
|
+
window: windowDimensions,
|
|
6488
|
+
strides: windowStrides
|
|
6489
|
+
}));
|
|
6490
|
+
}
|
|
6491
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
6492
|
+
function erf(x) {
|
|
6493
|
+
return erf$1(x);
|
|
6494
|
+
}
|
|
6495
|
+
/**
|
|
6496
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
6497
|
+
*
|
|
6498
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
6499
|
+
* where `erf(x)` is very close to 1.
|
|
6500
|
+
*/
|
|
6501
|
+
function erfc(x) {
|
|
6502
|
+
return erfc$1(x);
|
|
6503
|
+
}
|
|
6504
|
+
/**
 * Stops gradient computation.
 *
 * Acts as the identity on values but blocks gradients from flowing through
 * it in both forward-mode and reverse-mode automatic differentiation.
 * Delegates to the internal `stopGradient` implementation.
 */
function stopGradient$1(x) {
  const detached = stopGradient(x);
  return detached;
}
|
|
6513
|
+
|
|
5912
6514
|
//#endregion
|
|
5913
6515
|
//#region src/library/nn.ts
|
|
5914
6516
|
var nn_exports = {};
|
|
@@ -5917,6 +6519,10 @@ __export(nn_exports, {
|
|
|
5917
6519
|
elu: () => elu,
|
|
5918
6520
|
gelu: () => gelu,
|
|
5919
6521
|
glu: () => glu,
|
|
6522
|
+
hardSigmoid: () => hardSigmoid,
|
|
6523
|
+
hardSilu: () => hardSilu,
|
|
6524
|
+
hardSwish: () => hardSilu,
|
|
6525
|
+
hardTanh: () => hardTanh,
|
|
5920
6526
|
identity: () => identity,
|
|
5921
6527
|
leakyRelu: () => leakyRelu,
|
|
5922
6528
|
logSigmoid: () => logSigmoid,
|
|
@@ -5927,14 +6533,17 @@ __export(nn_exports, {
|
|
|
5927
6533
|
oneHot: () => oneHot,
|
|
5928
6534
|
relu: () => relu,
|
|
5929
6535
|
relu6: () => relu6,
|
|
6536
|
+
selu: () => selu,
|
|
5930
6537
|
sigmoid: () => sigmoid,
|
|
5931
6538
|
silu: () => silu,
|
|
5932
6539
|
softSign: () => softSign,
|
|
5933
6540
|
softmax: () => softmax,
|
|
5934
6541
|
softplus: () => softplus,
|
|
6542
|
+
sparsePlus: () => sparsePlus,
|
|
6543
|
+
sparseSigmoid: () => sparseSigmoid,
|
|
5935
6544
|
squareplus: () => squareplus,
|
|
5936
6545
|
standardize: () => standardize,
|
|
5937
|
-
swish: () =>
|
|
6546
|
+
swish: () => silu
|
|
5938
6547
|
});
|
|
5939
6548
|
/**
|
|
5940
6549
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -5969,6 +6578,28 @@ function softplus(x) {
|
|
|
5969
6578
|
return log(exp(x).add(1));
|
|
5970
6579
|
}
|
|
5971
6580
|
/**
 * @function
 * Sparse plus function, computed element-wise:
 *
 * - When `x <= -1`: `0`
 * - When `-1 < x < 1`: `(x+1)**2 / 4`
 * - When `x >= 1`: `x`
 */
const sparsePlus = jit$1((x) => {
  const isLow = x.ref.lessEqual(-1);
  const isMiddle = x.ref.less(1);
  const quadratic = square(x.ref.add(1)).mul(.25);
  // Innermost choice first: quadratic region vs. identity, then clamp the low end to 0.
  const upperPart = where(isMiddle, quadratic, x);
  return where(isLow, 0, upperPart);
});
|
|
6591
|
+
/**
 * @function
 * Sparse sigmoid activation function, computed element-wise:
 *
 * - When `x <= -1`: `0`
 * - When `-1 < x < 1`: `(x + 1) / 2`
 * - When `x >= 1`: `1`
 *
 * Implemented as the linear ramp `(x + 1) / 2` clipped to `[0, 1]`.
 */
const sparseSigmoid = jit$1((x) => {
  const ramp = x.add(1).mul(.5);
  return clip(ramp, 0, 1);
});
|
|
6602
|
+
/**
|
|
5972
6603
|
* Soft-sign activation function, computed element-wise:
|
|
5973
6604
|
* `softsign(x) = x / (|x| + 1)`.
|
|
5974
6605
|
*/
|
|
@@ -5990,17 +6621,6 @@ const silu = jit$1(function silu$1(x) {
|
|
|
5990
6621
|
return x.ref.mul(sigmoid(x));
|
|
5991
6622
|
});
|
|
5992
6623
|
/**
|
|
5993
|
-
* @function
|
|
5994
|
-
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
5995
|
-
* Swish, computed element-wise:
|
|
5996
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
5997
|
-
*
|
|
5998
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
5999
|
-
*
|
|
6000
|
-
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
6001
|
-
*/
|
|
6002
|
-
const swish = silu;
|
|
6003
|
-
/**
|
|
6004
6624
|
* Log-sigmoid activation function, computed element-wise:
|
|
6005
6625
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
6006
6626
|
*/
|
|
@@ -6017,6 +6637,19 @@ function leakyRelu(x, negativeSlope = .01) {
|
|
|
6017
6637
|
x = fudgeArray(x);
|
|
6018
6638
|
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
6019
6639
|
}
|
|
6640
|
+
/**
 * Hard sigmoid activation function: `relu6(x+3)/6`.
 *
 * A piecewise-linear approximation of the sigmoid: 0 for `x <= -3`,
 * 1 for `x >= 3`, and linear in between.
 */
function hardSigmoid(x) {
  const shifted = add(x, 3);
  const scale = 1 / 6;
  return relu6(shifted).mul(scale);
}
|
|
6644
|
+
/**
 * Hard SiLU (swish) activation function: `x * hardSigmoid(x)`.
 *
 * The input is normalized with `fudgeArray` first, so plain numbers are
 * accepted as well as arrays.
 */
function hardSilu(x) {
  const input = fudgeArray(x);
  // Take an extra reference: `input` is used once here and consumed by hardSigmoid.
  const kept = input.ref;
  return kept.mul(hardSigmoid(input));
}
|
|
6649
|
+
/**
 * Hard tanh activation function: `clip(x, -1, 1)`.
 *
 * Identity on `[-1, 1]`, saturating to -1 below and 1 above.
 */
function hardTanh(x) {
  const lower = -1;
  const upper = 1;
  return clip(x, lower, upper);
}
|
|
6020
6653
|
/**
|
|
6021
6654
|
* Exponential linear unit activation function.
|
|
6022
6655
|
*
|
|
@@ -6039,6 +6672,20 @@ function celu(x, alpha = 1) {
|
|
|
6039
6672
|
}
|
|
6040
6673
|
/**
 * @function
 * Scaled exponential linear unit activation.
 *
 * Computes the element-wise function:
 * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
 *
 * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
 *
 * The `expm1` branch is evaluated on `where(x < 0, x, 0)` instead of `x`, so it
 * stays finite for large positive inputs. Otherwise `exp(x)` can overflow to
 * Infinity there. Reverse-mode AD through `where` typically multiplies that
 * Infinity by a zero cotangent, which produces NaN gradients. This is the same
 * "safe where" trick JAX's elu/selu use.
 */
const selu = jit$1(function selu$1(x) {
  const alpha = 1.6732632423543772;
  const lambda = 1.0507009873554805;
  const isNegative = x.ref.less(0);
  const safeX = where(isNegative.ref, x.ref, 0);
  return where(isNegative, expm1(safeX).mul(alpha), x).mul(lambda);
});
|
|
6687
|
+
/**
|
|
6688
|
+
* @function
|
|
6042
6689
|
* Gaussian error linear unit (GELU) activation function.
|
|
6043
6690
|
*
|
|
6044
6691
|
* This is computed element-wise. There are two variants depending on whether
|
|
@@ -6192,8 +6839,11 @@ var random_exports = {};
|
|
|
6192
6839
|
__export(random_exports, {
|
|
6193
6840
|
bernoulli: () => bernoulli,
|
|
6194
6841
|
bits: () => bits,
|
|
6842
|
+
cauchy: () => cauchy,
|
|
6195
6843
|
exponential: () => exponential,
|
|
6844
|
+
gumbel: () => gumbel,
|
|
6196
6845
|
key: () => key,
|
|
6846
|
+
laplace: () => laplace,
|
|
6197
6847
|
normal: () => normal,
|
|
6198
6848
|
split: () => split,
|
|
6199
6849
|
uniform: () => uniform
|
|
@@ -6252,6 +6902,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
6252
6902
|
}
|
|
6253
6903
|
/**
 * @function
 * Sample from a Cauchy distribution with location 0 and scale 1.
 *
 * Inverse transform sampling: draw `u ~ Uniform(0, 1)` and return
 * `tan(π * (u - 0.5))`. `shape` is a static argument to `jit`.
 */
const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
  const angle = uniform(key$1, shape$1).sub(.5).mul(Math.PI);
  return tan(angle);
}, { staticArgnums: [1] });
|
|
6913
|
+
/**
|
|
6914
|
+
* @function
|
|
6255
6915
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
6256
6916
|
*/
|
|
6257
6917
|
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
@@ -6260,6 +6920,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
6260
6920
|
}, { staticArgnums: [1] });
|
|
6261
6921
|
/**
 * @function
 * Sample from a Gumbel distribution with location 0 and scale 1.
 *
 * Inverse transform sampling with `u ~ Uniform(0, 1)`:
 * `x = -log(-log1p(-u)) = -log(-log(1 - u))`. Since `1 - u` is also uniform,
 * this has the same distribution as the textbook form `-log(-log(u))`.
 * `shape` is a static argument to `jit`.
 *
 * NOTE(review): if `uniform` can return exactly 0, this yields +Infinity for
 * that sample (JAX avoids this by sampling from `[tiny, 1)`). Confirm
 * `uniform`'s range.
 */
const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
  const u = uniform(key$1, shape$1);
  // -log(1 - u) is a standard exponential sample.
  const expSample = negative(log1p(negative(u)));
  return negative(log(expSample));
}, { staticArgnums: [1] });
|
|
6931
|
+
/**
 * @function
 * Sample from a Laplace distribution with location 0 and scale 1.
 *
 * Inverse transform sampling. The CDF is
 * `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`. Inverting it for
 * `u ~ Uniform(0, 1)` gives `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`,
 * computed here with `log1p` for accuracy near `u = 0.5`. `shape` is a static
 * argument to `jit`.
 *
 * NOTE(review): if `uniform` can return exactly 0, `log1p(-1)` gives
 * -Infinity for that sample. Confirm `uniform`'s range.
 */
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
  const offset = uniform(key$1, shape$1).sub(.5);
  const direction = sign(offset.ref);
  const logTail = log1p(absolute(offset).mul(-2));
  return direction.mul(logTail.mul(-1));
}, { staticArgnums: [1] });
|
|
6945
|
+
/**
|
|
6946
|
+
* @function
|
|
6263
6947
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6264
6948
|
*
|
|
6265
6949
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -6368,11 +7052,6 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
6368
7052
|
*/
|
|
6369
7053
|
const jacrev = jacrev$1;
|
|
6370
7054
|
/**
|
|
6371
|
-
* @function
|
|
6372
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
6373
|
-
*/
|
|
6374
|
-
const jacobian = jacrev;
|
|
6375
|
-
/**
|
|
6376
7055
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
6377
7056
|
*
|
|
6378
7057
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -6407,5 +7086,4 @@ async function devicePut(x, device) {
|
|
|
6407
7086
|
}
|
|
6408
7087
|
|
|
6409
7088
|
//#endregion
|
|
6410
|
-
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
6411
|
-
//# sourceMappingURL=index.js.map
|
|
7089
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|