@jax-js/jax 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -19
- package/dist/{backend-BqDtPGaR.js → backend-EBRGmEYw.js} +296 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-Ss1Mev_-.cjs} +315 -154
- package/dist/index.cjs +681 -157
- package/dist/index.d.cts +422 -76
- package/dist/index.d.ts +422 -76
- package/dist/index.js +677 -157
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-BVdMaO9T.cjs} +9 -3
- package/dist/{webgpu-CNg9JGva.js → webgpu-ow0Pn_6q.js} +9 -3
- package/package.json +15 -4
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache,
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-EBRGmEYw.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -323,6 +323,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
323
323
|
Primitive$1["RandomBits"] = "random_bits";
|
|
324
324
|
Primitive$1["Sin"] = "sin";
|
|
325
325
|
Primitive$1["Cos"] = "cos";
|
|
326
|
+
Primitive$1["Asin"] = "asin";
|
|
327
|
+
Primitive$1["Atan"] = "atan";
|
|
326
328
|
Primitive$1["Exp"] = "exp";
|
|
327
329
|
Primitive$1["Log"] = "log";
|
|
328
330
|
Primitive$1["Sqrt"] = "sqrt";
|
|
@@ -390,6 +392,12 @@ function sin$1(x) {
|
|
|
390
392
|
function cos$1(x) {
|
|
391
393
|
return bind1(Primitive.Cos, [x]);
|
|
392
394
|
}
|
|
395
|
+
function asin$1(x) {
|
|
396
|
+
return bind1(Primitive.Asin, [x]);
|
|
397
|
+
}
|
|
398
|
+
function atan$1(x) {
|
|
399
|
+
return bind1(Primitive.Atan, [x]);
|
|
400
|
+
}
|
|
393
401
|
function exp$1(x) {
|
|
394
402
|
return bind1(Primitive.Exp, [x]);
|
|
395
403
|
}
|
|
@@ -405,18 +413,16 @@ function min$1(x, y) {
|
|
|
405
413
|
function max$1(x, y) {
|
|
406
414
|
return bind1(Primitive.Max, [x, y]);
|
|
407
415
|
}
|
|
408
|
-
function reduce(x, op, axis, opts) {
|
|
416
|
+
function reduce(x, op, axis = null, opts) {
|
|
409
417
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
410
|
-
|
|
411
|
-
else axis = [];
|
|
412
|
-
else if (typeof axis === "number") axis = [checkAxis(axis, ndim$1(x))];
|
|
413
|
-
else axis = axis.map((a) => checkAxis(a, ndim$1(x)));
|
|
418
|
+
axis = normalizeAxis(axis, ndim$1(x));
|
|
414
419
|
const originalShape = getShape(x);
|
|
415
|
-
|
|
420
|
+
let result = bind1(Primitive.Reduce, [x], {
|
|
416
421
|
op,
|
|
417
422
|
axis
|
|
418
423
|
});
|
|
419
|
-
|
|
424
|
+
if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
|
|
425
|
+
return result;
|
|
420
426
|
}
|
|
421
427
|
function dot$1(x, y) {
|
|
422
428
|
return bind1(Primitive.Dot, [x, y]);
|
|
@@ -462,10 +468,11 @@ function where$1(cond, x, y) {
|
|
|
462
468
|
}
|
|
463
469
|
function transpose$1(x, perm) {
|
|
464
470
|
perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
|
|
471
|
+
if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
465
472
|
return bind1(Primitive.Transpose, [x], { perm });
|
|
466
473
|
}
|
|
467
474
|
function broadcast(x, shape$1, axis) {
|
|
468
|
-
axis = axis
|
|
475
|
+
axis = normalizeAxis(axis, shape$1.length);
|
|
469
476
|
return bind1(Primitive.Broadcast, [x], {
|
|
470
477
|
shape: shape$1,
|
|
471
478
|
axis
|
|
@@ -484,7 +491,7 @@ function reshape$1(x, shape$1) {
|
|
|
484
491
|
return bind1(Primitive.Reshape, [x], { shape: shape$1 });
|
|
485
492
|
}
|
|
486
493
|
function flip$1(x, axis) {
|
|
487
|
-
axis = axis
|
|
494
|
+
axis = normalizeAxis(axis, ndim$1(x));
|
|
488
495
|
return bind1(Primitive.Flip, [x], { axis });
|
|
489
496
|
}
|
|
490
497
|
function shrink(x, slice) {
|
|
@@ -564,15 +571,19 @@ var Tracer = class Tracer {
|
|
|
564
571
|
constructor(trace) {
|
|
565
572
|
this._trace = trace;
|
|
566
573
|
}
|
|
574
|
+
/** The shape of the array. */
|
|
567
575
|
get shape() {
|
|
568
576
|
return this.aval.shape;
|
|
569
577
|
}
|
|
578
|
+
/** The total number of elements in the array. */
|
|
570
579
|
get size() {
|
|
571
580
|
return prod(this.shape);
|
|
572
581
|
}
|
|
582
|
+
/** The dtype of the array. */
|
|
573
583
|
get dtype() {
|
|
574
584
|
return this.aval.dtype;
|
|
575
585
|
}
|
|
586
|
+
/** The number of dimensions of the array. */
|
|
576
587
|
get ndim() {
|
|
577
588
|
return this.shape.length;
|
|
578
589
|
}
|
|
@@ -608,22 +619,20 @@ var Tracer = class Tracer {
|
|
|
608
619
|
return lessEqual$1(this, other);
|
|
609
620
|
}
|
|
610
621
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
611
|
-
sum(axis, opts) {
|
|
622
|
+
sum(axis = null, opts) {
|
|
612
623
|
return reduce(this, AluOp.Add, axis, opts);
|
|
613
624
|
}
|
|
614
625
|
/** Product of the array elements over a given axis. */
|
|
615
|
-
prod(axis, opts) {
|
|
626
|
+
prod(axis = null, opts) {
|
|
616
627
|
return reduce(this, AluOp.Mul, axis, opts);
|
|
617
628
|
}
|
|
618
629
|
/** Compute the average of the array elements along the specified axis. */
|
|
619
|
-
mean(axis, opts) {
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
626
|
-
return result;
|
|
630
|
+
mean(axis = null, opts) {
|
|
631
|
+
axis = normalizeAxis(axis, this.ndim);
|
|
632
|
+
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
633
|
+
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
634
|
+
const result = reduce(this, AluOp.Add, axis, opts);
|
|
635
|
+
return result.mul(1 / n);
|
|
627
636
|
}
|
|
628
637
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
629
638
|
transpose(perm) {
|
|
@@ -1156,6 +1165,8 @@ const jitRules = {
|
|
|
1156
1165
|
},
|
|
1157
1166
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
1158
1167
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
1168
|
+
[Primitive.Asin]: unopJit(AluExp.asin),
|
|
1169
|
+
[Primitive.Atan]: unopJit(AluExp.atan),
|
|
1159
1170
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
1160
1171
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1161
1172
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
@@ -1397,7 +1408,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1397
1408
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1398
1409
|
* will be freed when the array is disposed.
|
|
1399
1410
|
*/
|
|
1400
|
-
constructor(source, st, dtype, backend, pending = null) {
|
|
1411
|
+
constructor(source, st, dtype, backend, { pending = null } = {}) {
|
|
1401
1412
|
super(baseArrayTrace);
|
|
1402
1413
|
this.id = Array$1.#nextId++;
|
|
1403
1414
|
this.#dtype = dtype;
|
|
@@ -1406,6 +1417,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1406
1417
|
this.#backend = backend;
|
|
1407
1418
|
this.#rc = 1;
|
|
1408
1419
|
this.#pendingSet = new Set(pending);
|
|
1420
|
+
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1421
|
+
else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1409
1422
|
}
|
|
1410
1423
|
/** @ignore */
|
|
1411
1424
|
get aval() {
|
|
@@ -1460,7 +1473,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1460
1473
|
const pending = this.#pending;
|
|
1461
1474
|
for (const exe of pending) exe.updateRc(1);
|
|
1462
1475
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1463
|
-
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
|
|
1476
|
+
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
|
|
1464
1477
|
this.dispose();
|
|
1465
1478
|
return ar;
|
|
1466
1479
|
}
|
|
@@ -1509,7 +1522,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1509
1522
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1510
1523
|
this.dispose();
|
|
1511
1524
|
for (const ar of indices) ar.dispose();
|
|
1512
|
-
return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
|
|
1525
|
+
return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
|
|
1513
1526
|
}
|
|
1514
1527
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1515
1528
|
#moveAxesDown(axis) {
|
|
@@ -1546,7 +1559,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1546
1559
|
for (const exe of pending) exe.updateRc(1);
|
|
1547
1560
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1548
1561
|
this.dispose();
|
|
1549
|
-
return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
|
|
1562
|
+
return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
|
|
1550
1563
|
}
|
|
1551
1564
|
#binary(op, other) {
|
|
1552
1565
|
const custom = (src) => new AluExp(op, this.#dtype, src);
|
|
@@ -1611,7 +1624,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1611
1624
|
for (const exe of pending) exe.updateRc(1);
|
|
1612
1625
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1613
1626
|
for (const ar of arrays) ar.dispose();
|
|
1614
|
-
return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
|
|
1627
|
+
return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
|
|
1615
1628
|
}
|
|
1616
1629
|
/** Reduce the last dimension of the array by an operation. */
|
|
1617
1630
|
#reduce(op) {
|
|
@@ -1635,7 +1648,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1635
1648
|
for (const exe of pending) exe.updateRc(1);
|
|
1636
1649
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1637
1650
|
this.dispose();
|
|
1638
|
-
return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
|
|
1651
|
+
return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
|
|
1639
1652
|
}
|
|
1640
1653
|
/**
|
|
1641
1654
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1708,8 +1721,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1708
1721
|
*
|
|
1709
1722
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1710
1723
|
* dispatch of operations as well.
|
|
1724
|
+
*
|
|
1725
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1726
|
+
* asynchronously for multiple arrays.
|
|
1711
1727
|
*/
|
|
1712
|
-
async
|
|
1728
|
+
async blockUntilReady() {
|
|
1713
1729
|
this.#check();
|
|
1714
1730
|
if (this.#source instanceof AluExp) return this;
|
|
1715
1731
|
const pending = this.#pending;
|
|
@@ -1775,7 +1791,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1775
1791
|
return [x.#binary(AluOp.Idiv, y)];
|
|
1776
1792
|
},
|
|
1777
1793
|
[Primitive.Neg]([x]) {
|
|
1778
|
-
return [zerosLike(x.ref).#binary(AluOp.Sub, x)];
|
|
1794
|
+
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
1779
1795
|
},
|
|
1780
1796
|
[Primitive.Reciprocal]([x]) {
|
|
1781
1797
|
return [x.#unary(AluOp.Reciprocal)];
|
|
@@ -1795,7 +1811,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1795
1811
|
x.#backend.incRef(x.#source);
|
|
1796
1812
|
const pending = x.#pending;
|
|
1797
1813
|
for (const exe of pending) exe.updateRc(1);
|
|
1798
|
-
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
|
|
1814
|
+
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
|
|
1799
1815
|
x.dispose();
|
|
1800
1816
|
return [y];
|
|
1801
1817
|
}
|
|
@@ -1825,6 +1841,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1825
1841
|
[Primitive.Cos]([x]) {
|
|
1826
1842
|
return [x.#unary(AluOp.Cos)];
|
|
1827
1843
|
},
|
|
1844
|
+
[Primitive.Asin]([x]) {
|
|
1845
|
+
return [x.#unary(AluOp.Asin)];
|
|
1846
|
+
},
|
|
1847
|
+
[Primitive.Atan]([x]) {
|
|
1848
|
+
return [x.#unary(AluOp.Atan)];
|
|
1849
|
+
},
|
|
1828
1850
|
[Primitive.Exp]([x]) {
|
|
1829
1851
|
return [x.#unary(AluOp.Exp)];
|
|
1830
1852
|
},
|
|
@@ -1910,7 +1932,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1910
1932
|
pending.splice(0, 0, ...prevPending);
|
|
1911
1933
|
args.forEach((x) => x.dispose());
|
|
1912
1934
|
return outputs.map((source, i) => {
|
|
1913
|
-
return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
|
|
1935
|
+
return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
|
|
1914
1936
|
});
|
|
1915
1937
|
}
|
|
1916
1938
|
};
|
|
@@ -2042,12 +2064,12 @@ var EvalTrace = class extends Trace {
|
|
|
2042
2064
|
};
|
|
2043
2065
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2044
2066
|
const implRules = Array$1._implRules();
|
|
2045
|
-
function zerosLike(val, dtype) {
|
|
2067
|
+
function zerosLike$1(val, dtype) {
|
|
2046
2068
|
const aval = getAval(val);
|
|
2047
2069
|
if (val instanceof Tracer) val.dispose();
|
|
2048
2070
|
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2049
2071
|
}
|
|
2050
|
-
function onesLike(val, dtype) {
|
|
2072
|
+
function onesLike$1(val, dtype) {
|
|
2051
2073
|
const aval = getAval(val);
|
|
2052
2074
|
if (val instanceof Tracer) val.dispose();
|
|
2053
2075
|
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
@@ -2110,7 +2132,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2110
2132
|
const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
|
|
2111
2133
|
return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
|
|
2112
2134
|
}
|
|
2113
|
-
/** Return the identity
|
|
2135
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2114
2136
|
function identity$1(n, { dtype, device } = {}) {
|
|
2115
2137
|
return eye(n, n, {
|
|
2116
2138
|
dtype,
|
|
@@ -2386,16 +2408,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2386
2408
|
varIds.set(v, FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
|
|
2387
2409
|
return id;
|
|
2388
2410
|
};
|
|
2389
|
-
hasher.update(this.inBinders.length
|
|
2390
|
-
|
|
2391
|
-
|
|
2392
|
-
|
|
2393
|
-
|
|
2394
|
-
|
|
2395
|
-
eqn.
|
|
2396
|
-
|
|
2397
|
-
|
|
2398
|
-
|
|
2411
|
+
hasher.update(this.inBinders.length);
|
|
2412
|
+
for (const x of this.inBinders) hasher.update(vi(x));
|
|
2413
|
+
hasher.update(this.eqns.length);
|
|
2414
|
+
for (const eqn of this.eqns) {
|
|
2415
|
+
hasher.update(eqn.primitive);
|
|
2416
|
+
hasher.update(eqn.inputs.length);
|
|
2417
|
+
for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2418
|
+
hasher.update(JSON.stringify(eqn.params));
|
|
2419
|
+
hasher.update(eqn.outBinders.length);
|
|
2420
|
+
for (const x of eqn.outBinders) hasher.update(vi(x));
|
|
2421
|
+
}
|
|
2422
|
+
hasher.update(this.outs.length);
|
|
2423
|
+
for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2399
2424
|
return this.#hash = hasher.value;
|
|
2400
2425
|
}
|
|
2401
2426
|
hash(state) {
|
|
@@ -2432,7 +2457,7 @@ var Jaxpr = class Jaxpr {
|
|
|
2432
2457
|
const c = eqn.outBinders[0];
|
|
2433
2458
|
if (atomIsLit(b, 1)) context.set(c, a);
|
|
2434
2459
|
else newEqns.push(eqn);
|
|
2435
|
-
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2460
|
+
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2436
2461
|
else newEqns.push(eqn);
|
|
2437
2462
|
}
|
|
2438
2463
|
const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
|
|
@@ -2698,6 +2723,8 @@ const abstractEvalRules = {
|
|
|
2698
2723
|
},
|
|
2699
2724
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2700
2725
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2726
|
+
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
2727
|
+
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2701
2728
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2702
2729
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2703
2730
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
@@ -2825,7 +2852,7 @@ function makeJaxpr$1(f, opts) {
|
|
|
2825
2852
|
function jit$1(f, opts) {
|
|
2826
2853
|
const cache = /* @__PURE__ */ new Map();
|
|
2827
2854
|
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2828
|
-
|
|
2855
|
+
const result = ((...args) => {
|
|
2829
2856
|
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2830
2857
|
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2831
2858
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
@@ -2839,6 +2866,10 @@ function jit$1(f, opts) {
|
|
|
2839
2866
|
});
|
|
2840
2867
|
return unflatten(outTree, outs);
|
|
2841
2868
|
});
|
|
2869
|
+
result.dispose = () => {
|
|
2870
|
+
for (const { consts } of cache.values()) for (const c of consts) c.dispose();
|
|
2871
|
+
};
|
|
2872
|
+
return result;
|
|
2842
2873
|
}
|
|
2843
2874
|
|
|
2844
2875
|
//#endregion
|
|
@@ -2869,7 +2900,7 @@ var JVPTrace = class extends Trace {
|
|
|
2869
2900
|
return this.lift(pureArray(val));
|
|
2870
2901
|
}
|
|
2871
2902
|
lift(val) {
|
|
2872
|
-
return new JVPTracer(this, val, zerosLike(val.ref));
|
|
2903
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
2873
2904
|
}
|
|
2874
2905
|
processPrimitive(primitive, tracers, params) {
|
|
2875
2906
|
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2900,7 +2931,7 @@ function zeroTangentsJvp(primitive) {
|
|
|
2900
2931
|
return (primals, tangents, params) => {
|
|
2901
2932
|
for (const t of tangents) t.dispose();
|
|
2902
2933
|
const ys = bind(primitive, primals, params);
|
|
2903
|
-
return [ys, ys.map((y) => zerosLike(y.ref))];
|
|
2934
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
2904
2935
|
};
|
|
2905
2936
|
}
|
|
2906
2937
|
const jvpRules = {
|
|
@@ -2918,13 +2949,13 @@ const jvpRules = {
|
|
|
2918
2949
|
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2919
2950
|
else {
|
|
2920
2951
|
dx.dispose();
|
|
2921
|
-
return [[cast(x.ref, dtype)], [zerosLike(x)]];
|
|
2952
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2922
2953
|
}
|
|
2923
2954
|
},
|
|
2924
2955
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2925
2956
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2926
2957
|
dx.dispose();
|
|
2927
|
-
return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
|
|
2958
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2928
2959
|
},
|
|
2929
2960
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2930
2961
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2933,6 +2964,14 @@ const jvpRules = {
|
|
|
2933
2964
|
[Primitive.Cos]([x], [dx]) {
|
|
2934
2965
|
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
2935
2966
|
},
|
|
2967
|
+
[Primitive.Asin]([x], [dx]) {
|
|
2968
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
2969
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
2970
|
+
},
|
|
2971
|
+
[Primitive.Atan]([x], [dx]) {
|
|
2972
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
2973
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
2974
|
+
},
|
|
2936
2975
|
[Primitive.Exp]([x], [dx]) {
|
|
2937
2976
|
const z = exp$1(x);
|
|
2938
2977
|
return [[z.ref], [z.mul(dx)]];
|
|
@@ -3048,7 +3087,10 @@ function mappedAval(batchDim, aval) {
|
|
|
3048
3087
|
/** Move one axis to a different index. */
|
|
3049
3088
|
function moveaxis$1(x, src, dst) {
|
|
3050
3089
|
const t = pureArray(x);
|
|
3051
|
-
|
|
3090
|
+
src = checkAxis(src, t.ndim);
|
|
3091
|
+
dst = checkAxis(dst, t.ndim);
|
|
3092
|
+
if (src === dst) return t;
|
|
3093
|
+
const perm = range(t.ndim);
|
|
3052
3094
|
perm.splice(src, 1);
|
|
3053
3095
|
perm.splice(dst, 0, src);
|
|
3054
3096
|
return transpose$1(t, perm);
|
|
@@ -3141,6 +3183,8 @@ const vmapRules = {
|
|
|
3141
3183
|
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3142
3184
|
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3143
3185
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3186
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3187
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3144
3188
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3145
3189
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3146
3190
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
@@ -3326,20 +3370,28 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3326
3370
|
function linearizeFlat(f, primalsIn) {
|
|
3327
3371
|
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
3328
3372
|
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3329
|
-
|
|
3373
|
+
const dispose$1 = () => {
|
|
3374
|
+
for (const c of consts) c.dispose();
|
|
3375
|
+
};
|
|
3376
|
+
return [
|
|
3377
|
+
primalsOut,
|
|
3378
|
+
fLin,
|
|
3379
|
+
dispose$1
|
|
3380
|
+
];
|
|
3330
3381
|
}
|
|
3331
3382
|
function linearize$1(f, ...primalsIn) {
|
|
3332
3383
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3333
3384
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3334
|
-
const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3385
|
+
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3335
3386
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
3336
3387
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3337
|
-
const fLin = (...tangentsIn) => {
|
|
3388
|
+
const fLin = ((...tangentsIn) => {
|
|
3338
3389
|
const [tangentsInFlat, inTree2] = flatten(tangentsIn);
|
|
3339
3390
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
|
|
3340
3391
|
const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
|
|
3341
3392
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
3342
|
-
};
|
|
3393
|
+
});
|
|
3394
|
+
fLin.dispose = dispose$1;
|
|
3343
3395
|
return [primalsOut, fLin];
|
|
3344
3396
|
}
|
|
3345
3397
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -3455,7 +3507,10 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3455
3507
|
avalsOut: jaxpr2.outs.map((x) => x.aval),
|
|
3456
3508
|
tracerRefsOut: []
|
|
3457
3509
|
};
|
|
3458
|
-
const outs2 = jaxpr2.outs.map((x) =>
|
|
3510
|
+
const outs2 = jaxpr2.outs.map((x, i$1) => {
|
|
3511
|
+
if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
|
|
3512
|
+
return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
|
|
3513
|
+
});
|
|
3459
3514
|
recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
|
|
3460
3515
|
let i = 0;
|
|
3461
3516
|
let j = 0;
|
|
@@ -3539,13 +3594,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3539
3594
|
const [consts, constvars] = unzip2(constToVar.entries());
|
|
3540
3595
|
const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
|
|
3541
3596
|
const outVars = tracersOut.map((t) => tracerToVar.get(t));
|
|
3542
|
-
|
|
3597
|
+
let jaxpr = new Jaxpr(inBinders, eqns, outVars);
|
|
3543
3598
|
typecheckJaxpr(jaxpr);
|
|
3544
3599
|
for (const t of consts) t.ref;
|
|
3545
3600
|
for (const t of tracersIn) t.dispose();
|
|
3546
3601
|
for (const t of tracersOut) t.dispose();
|
|
3602
|
+
jaxpr = jaxpr.simplify();
|
|
3603
|
+
if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3547
3604
|
return {
|
|
3548
|
-
jaxpr
|
|
3605
|
+
jaxpr,
|
|
3549
3606
|
consts
|
|
3550
3607
|
};
|
|
3551
3608
|
}
|
|
@@ -3811,20 +3868,28 @@ function vjpFlat(f, primalsIn) {
|
|
|
3811
3868
|
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
3812
3869
|
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
3813
3870
|
};
|
|
3814
|
-
|
|
3871
|
+
const dispose$1 = () => {
|
|
3872
|
+
for (const c of consts) c.dispose();
|
|
3873
|
+
};
|
|
3874
|
+
return [
|
|
3875
|
+
primalsOut,
|
|
3876
|
+
fVjp,
|
|
3877
|
+
dispose$1
|
|
3878
|
+
];
|
|
3815
3879
|
}
|
|
3816
3880
|
function vjp$1(f, ...primalsIn) {
|
|
3817
3881
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3818
3882
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3819
|
-
const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3883
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3820
3884
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
3821
3885
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3822
|
-
const fVjp = (cotangentsOut) => {
|
|
3886
|
+
const fVjp = ((cotangentsOut) => {
|
|
3823
3887
|
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
3824
3888
|
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
3825
3889
|
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
3826
3890
|
return unflatten(inTree, cotangentsInFlat);
|
|
3827
|
-
};
|
|
3891
|
+
});
|
|
3892
|
+
fVjp.dispose = dispose$1;
|
|
3828
3893
|
return [primalsOut, fVjp];
|
|
3829
3894
|
}
|
|
3830
3895
|
function grad$1(f) {
|
|
@@ -3842,7 +3907,8 @@ function valueAndGrad$1(f) {
|
|
|
3842
3907
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3843
3908
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3844
3909
|
const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
|
|
3845
|
-
for (const r of rest)
|
|
3910
|
+
for (const r of rest) dispose(r);
|
|
3911
|
+
fVjp.dispose();
|
|
3846
3912
|
return [y, ct];
|
|
3847
3913
|
};
|
|
3848
3914
|
}
|
|
@@ -3850,7 +3916,13 @@ function jacrev$1(f) {
|
|
|
3850
3916
|
return function jacobianReverse(x) {
|
|
3851
3917
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
3852
3918
|
const [size$1] = x.shape;
|
|
3853
|
-
const pullback = (ct) =>
|
|
3919
|
+
const pullback = (ct) => {
|
|
3920
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
3921
|
+
y.dispose();
|
|
3922
|
+
const [ret] = fVjp(ct);
|
|
3923
|
+
fVjp.dispose();
|
|
3924
|
+
return ret;
|
|
3925
|
+
};
|
|
3854
3926
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3855
3927
|
};
|
|
3856
3928
|
}
|
|
@@ -3930,19 +4002,38 @@ __export(numpy_exports, {
|
|
|
3930
4002
|
DType: () => DType,
|
|
3931
4003
|
abs: () => abs,
|
|
3932
4004
|
absolute: () => absolute,
|
|
4005
|
+
acos: () => acos,
|
|
4006
|
+
acosh: () => acosh,
|
|
3933
4007
|
add: () => add,
|
|
3934
4008
|
allclose: () => allclose,
|
|
3935
4009
|
arange: () => arange,
|
|
4010
|
+
arccos: () => arccos,
|
|
4011
|
+
arccosh: () => arccosh,
|
|
4012
|
+
arcsinh: () => arcsinh,
|
|
4013
|
+
arctan: () => arctan,
|
|
4014
|
+
arctan2: () => arctan2,
|
|
4015
|
+
arctanh: () => arctanh,
|
|
3936
4016
|
argmax: () => argmax,
|
|
3937
4017
|
argmin: () => argmin,
|
|
3938
4018
|
array: () => array,
|
|
4019
|
+
asin: () => asin,
|
|
4020
|
+
asinh: () => asinh,
|
|
3939
4021
|
astype: () => astype,
|
|
4022
|
+
atan: () => atan,
|
|
4023
|
+
atan2: () => atan2,
|
|
4024
|
+
atanh: () => atanh,
|
|
3940
4025
|
bool: () => bool,
|
|
4026
|
+
broadcastArrays: () => broadcastArrays,
|
|
4027
|
+
broadcastShapes: () => broadcastShapes,
|
|
4028
|
+
broadcastTo: () => broadcastTo,
|
|
4029
|
+
cbrt: () => cbrt,
|
|
3941
4030
|
clip: () => clip,
|
|
3942
4031
|
columnStack: () => columnStack,
|
|
3943
4032
|
concatenate: () => concatenate,
|
|
3944
4033
|
cos: () => cos,
|
|
3945
4034
|
cosh: () => cosh,
|
|
4035
|
+
deg2rad: () => deg2rad,
|
|
4036
|
+
degrees: () => degrees,
|
|
3946
4037
|
diag: () => diag,
|
|
3947
4038
|
diagonal: () => diagonal,
|
|
3948
4039
|
divide: () => divide,
|
|
@@ -3953,6 +4044,7 @@ __export(numpy_exports, {
|
|
|
3953
4044
|
eulerGamma: () => eulerGamma,
|
|
3954
4045
|
exp: () => exp,
|
|
3955
4046
|
exp2: () => exp2,
|
|
4047
|
+
expm1: () => expm1,
|
|
3956
4048
|
eye: () => eye,
|
|
3957
4049
|
flip: () => flip,
|
|
3958
4050
|
fliplr: () => fliplr,
|
|
@@ -3964,14 +4056,17 @@ __export(numpy_exports, {
|
|
|
3964
4056
|
greater: () => greater,
|
|
3965
4057
|
greaterEqual: () => greaterEqual,
|
|
3966
4058
|
hstack: () => hstack,
|
|
4059
|
+
hypot: () => hypot,
|
|
3967
4060
|
identity: () => identity$1,
|
|
3968
4061
|
inf: () => inf,
|
|
4062
|
+
inner: () => inner,
|
|
3969
4063
|
int32: () => int32,
|
|
3970
4064
|
less: () => less,
|
|
3971
4065
|
lessEqual: () => lessEqual,
|
|
3972
4066
|
linspace: () => linspace,
|
|
3973
4067
|
log: () => log,
|
|
3974
4068
|
log10: () => log10,
|
|
4069
|
+
log1p: () => log1p,
|
|
3975
4070
|
log2: () => log2,
|
|
3976
4071
|
matmul: () => matmul,
|
|
3977
4072
|
max: () => max,
|
|
@@ -3987,35 +4082,49 @@ __export(numpy_exports, {
|
|
|
3987
4082
|
negative: () => negative,
|
|
3988
4083
|
notEqual: () => notEqual,
|
|
3989
4084
|
ones: () => ones,
|
|
3990
|
-
onesLike: () => onesLike
|
|
4085
|
+
onesLike: () => onesLike,
|
|
4086
|
+
outer: () => outer,
|
|
3991
4087
|
pad: () => pad,
|
|
3992
4088
|
permuteDims: () => permuteDims,
|
|
3993
4089
|
pi: () => pi,
|
|
4090
|
+
pow: () => pow,
|
|
4091
|
+
power: () => power,
|
|
3994
4092
|
prod: () => prod$1,
|
|
4093
|
+
promoteTypes: () => promoteTypes,
|
|
4094
|
+
rad2deg: () => rad2deg,
|
|
4095
|
+
radians: () => radians,
|
|
3995
4096
|
ravel: () => ravel,
|
|
3996
4097
|
reciprocal: () => reciprocal,
|
|
4098
|
+
repeat: () => repeat,
|
|
3997
4099
|
reshape: () => reshape,
|
|
3998
|
-
scalar: () => scalar,
|
|
3999
4100
|
shape: () => shape,
|
|
4101
|
+
sign: () => sign,
|
|
4000
4102
|
sin: () => sin,
|
|
4001
4103
|
sinh: () => sinh,
|
|
4002
4104
|
size: () => size,
|
|
4003
4105
|
sqrt: () => sqrt,
|
|
4004
4106
|
square: () => square,
|
|
4005
4107
|
stack: () => stack,
|
|
4108
|
+
std: () => std,
|
|
4109
|
+
subtract: () => subtract,
|
|
4006
4110
|
sum: () => sum,
|
|
4007
4111
|
tan: () => tan,
|
|
4008
4112
|
tanh: () => tanh,
|
|
4113
|
+
tile: () => tile,
|
|
4009
4114
|
transpose: () => transpose,
|
|
4115
|
+
tri: () => tri,
|
|
4116
|
+
tril: () => tril,
|
|
4117
|
+
triu: () => triu,
|
|
4010
4118
|
trueDivide: () => trueDivide,
|
|
4011
4119
|
trunc: () => trunc,
|
|
4012
4120
|
uint32: () => uint32,
|
|
4121
|
+
var_: () => var_,
|
|
4013
4122
|
vdot: () => vdot,
|
|
4014
4123
|
vecdot: () => vecdot,
|
|
4015
4124
|
vstack: () => vstack,
|
|
4016
4125
|
where: () => where,
|
|
4017
4126
|
zeros: () => zeros,
|
|
4018
|
-
zerosLike: () => zerosLike
|
|
4127
|
+
zerosLike: () => zerosLike
|
|
4019
4128
|
});
|
|
4020
4129
|
const float32 = DType.Float32;
|
|
4021
4130
|
const int32 = DType.Int32;
|
|
@@ -4032,54 +4141,66 @@ const inf = Number.POSITIVE_INFINITY;
|
|
|
4032
4141
|
const nan = NaN;
|
|
4033
4142
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
4034
4143
|
const pi = Math.PI;
|
|
4035
|
-
/** Element-wise addition, with broadcasting. */
|
|
4144
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
4036
4145
|
const add = add$1;
|
|
4037
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
4146
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
4038
4147
|
const multiply = mul;
|
|
4039
|
-
/** Numerical negative of every element of an array. */
|
|
4148
|
+
/** @function Numerical negative of every element of an array. */
|
|
4040
4149
|
const negative = neg;
|
|
4041
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4150
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4042
4151
|
const reciprocal = reciprocal$1;
|
|
4043
|
-
/** Element-wise sine function (takes radians). */
|
|
4152
|
+
/** @function Element-wise sine function (takes radians). */
|
|
4044
4153
|
const sin = sin$1;
|
|
4045
|
-
/** Element-wise cosine function (takes radians). */
|
|
4154
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
4046
4155
|
const cos = cos$1;
|
|
4047
|
-
/**
|
|
4156
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
4157
|
+
const asin = asin$1;
|
|
4158
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
4159
|
+
const atan = atan$1;
|
|
4160
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
4048
4161
|
const exp = exp$1;
|
|
4049
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
4162
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
4050
4163
|
const log = log$1;
|
|
4051
|
-
/** Calculate the square root of all elements in the input array. */
|
|
4164
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
4052
4165
|
const sqrt = sqrt$1;
|
|
4053
|
-
/** Return element-wise minimum of the input arrays. */
|
|
4166
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
4054
4167
|
const minimum = min$1;
|
|
4055
|
-
/** Return element-wise maximum of the input arrays. */
|
|
4168
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
4056
4169
|
const maximum = max$1;
|
|
4057
|
-
/** Compare two arrays element-wise. */
|
|
4170
|
+
/** @function Compare two arrays element-wise. */
|
|
4058
4171
|
const greater = greater$1;
|
|
4059
|
-
/** Compare two arrays element-wise. */
|
|
4172
|
+
/** @function Compare two arrays element-wise. */
|
|
4060
4173
|
const less = less$1;
|
|
4061
|
-
/** Compare two arrays element-wise. */
|
|
4174
|
+
/** @function Compare two arrays element-wise. */
|
|
4062
4175
|
const equal = equal$1;
|
|
4063
|
-
/** Compare two arrays element-wise. */
|
|
4176
|
+
/** @function Compare two arrays element-wise. */
|
|
4064
4177
|
const notEqual = notEqual$1;
|
|
4065
|
-
/** Compare two arrays element-wise. */
|
|
4178
|
+
/** @function Compare two arrays element-wise. */
|
|
4066
4179
|
const greaterEqual = greaterEqual$1;
|
|
4067
|
-
/** Compare two arrays element-wise. */
|
|
4180
|
+
/** @function Compare two arrays element-wise. */
|
|
4068
4181
|
const lessEqual = lessEqual$1;
|
|
4069
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4182
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4070
4183
|
const where = where$1;
|
|
4071
|
-
/**
|
|
4184
|
+
/**
|
|
4185
|
+
* @function
|
|
4186
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
4187
|
+
*/
|
|
4072
4188
|
const transpose = transpose$1;
|
|
4073
4189
|
/**
|
|
4190
|
+
* @function
|
|
4074
4191
|
* Give a new shape to an array without changing its data.
|
|
4075
4192
|
*
|
|
4076
4193
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
4077
4194
|
* length of the array and remaining dimensions.
|
|
4078
4195
|
*/
|
|
4079
4196
|
const reshape = reshape$1;
|
|
4080
|
-
/**
|
|
4197
|
+
/**
|
|
4198
|
+
* @function
|
|
4199
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
4200
|
+
*/
|
|
4081
4201
|
const moveaxis = moveaxis$1;
|
|
4082
4202
|
/**
|
|
4203
|
+
* @function
|
|
4083
4204
|
* Add padding (zeros) to an array.
|
|
4084
4205
|
*
|
|
4085
4206
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -4087,15 +4208,27 @@ const moveaxis = moveaxis$1;
|
|
|
4087
4208
|
* pair specifies the padding for its corresponding axis.
|
|
4088
4209
|
*/
|
|
4089
4210
|
const pad = pad$1;
|
|
4090
|
-
/**
|
|
4211
|
+
/**
|
|
4212
|
+
* @function
|
|
4213
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
4214
|
+
*/
|
|
4091
4215
|
const ndim = ndim$1;
|
|
4092
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
4216
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
4093
4217
|
const shape = getShape;
|
|
4094
|
-
/**
|
|
4095
|
-
|
|
4096
|
-
|
|
4097
|
-
|
|
4098
|
-
|
|
4218
|
+
/**
|
|
4219
|
+
* @function
|
|
4220
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
4221
|
+
*/
|
|
4222
|
+
const zerosLike = zerosLike$1;
|
|
4223
|
+
/**
|
|
4224
|
+
* @function
|
|
4225
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
4226
|
+
*/
|
|
4227
|
+
const onesLike = onesLike$1;
|
|
4228
|
+
/**
|
|
4229
|
+
* @function
|
|
4230
|
+
* Return a full array with the same shape and type as a given array.
|
|
4231
|
+
*/
|
|
4099
4232
|
const fullLike$1 = fullLike;
|
|
4100
4233
|
/**
|
|
4101
4234
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -4110,23 +4243,23 @@ function astype(a, dtype) {
|
|
|
4110
4243
|
return fudgeArray(a).astype(dtype);
|
|
4111
4244
|
}
|
|
4112
4245
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
4113
|
-
function sum(a, axis, opts) {
|
|
4246
|
+
function sum(a, axis = null, opts) {
|
|
4114
4247
|
return reduce(a, AluOp.Add, axis, opts);
|
|
4115
4248
|
}
|
|
4116
4249
|
/** Product of the array elements over a given axis. */
|
|
4117
|
-
function prod$1(a, axis, opts) {
|
|
4250
|
+
function prod$1(a, axis = null, opts) {
|
|
4118
4251
|
return reduce(a, AluOp.Mul, axis, opts);
|
|
4119
4252
|
}
|
|
4120
4253
|
/** Return the minimum of array elements along a given axis. */
|
|
4121
|
-
function min(a, axis, opts) {
|
|
4254
|
+
function min(a, axis = null, opts) {
|
|
4122
4255
|
return reduce(a, AluOp.Min, axis, opts);
|
|
4123
4256
|
}
|
|
4124
4257
|
/** Return the maximum of array elements along a given axis. */
|
|
4125
|
-
function max(a, axis, opts) {
|
|
4258
|
+
function max(a, axis = null, opts) {
|
|
4126
4259
|
return reduce(a, AluOp.Max, axis, opts);
|
|
4127
4260
|
}
|
|
4128
4261
|
/** Compute the average of the array elements along the specified axis. */
|
|
4129
|
-
function mean(a, axis, opts) {
|
|
4262
|
+
function mean(a, axis = null, opts) {
|
|
4130
4263
|
return fudgeArray(a).mean(axis, opts);
|
|
4131
4264
|
}
|
|
4132
4265
|
/**
|
|
@@ -4142,7 +4275,7 @@ function argmin(a, axis, opts) {
|
|
|
4142
4275
|
axis = 0;
|
|
4143
4276
|
} else axis = checkAxis(axis, a.ndim);
|
|
4144
4277
|
const shape$1 = a.shape;
|
|
4145
|
-
const isMax = equal(a, min(a.ref, axis, {
|
|
4278
|
+
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4146
4279
|
const length = scalar(shape$1[axis], {
|
|
4147
4280
|
dtype: int32,
|
|
4148
4281
|
device: a.device
|
|
@@ -4166,7 +4299,7 @@ function argmax(a, axis, opts) {
|
|
|
4166
4299
|
axis = 0;
|
|
4167
4300
|
} else axis = checkAxis(axis, a.ndim);
|
|
4168
4301
|
const shape$1 = a.shape;
|
|
4169
|
-
const isMax = equal(a, max(a.ref, axis, {
|
|
4302
|
+
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4170
4303
|
const length = scalar(shape$1[axis], {
|
|
4171
4304
|
dtype: int32,
|
|
4172
4305
|
device: a.device
|
|
@@ -4178,17 +4311,9 @@ function argmax(a, axis, opts) {
|
|
|
4178
4311
|
return length.sub(max(idx, axis, opts));
|
|
4179
4312
|
}
|
|
4180
4313
|
/** Reverse the elements in an array along the given axes. */
|
|
4181
|
-
function flip(x, axis) {
|
|
4314
|
+
function flip(x, axis = null) {
|
|
4182
4315
|
const nd = ndim(x);
|
|
4183
|
-
|
|
4184
|
-
else if (typeof axis === "number") axis = [axis];
|
|
4185
|
-
const seen = /* @__PURE__ */ new Set();
|
|
4186
|
-
for (let i = 0; i < axis.length; i++) {
|
|
4187
|
-
if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
|
|
4188
|
-
if (axis[i] < 0) axis[i] += nd;
|
|
4189
|
-
if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
|
|
4190
|
-
seen.add(axis[i]);
|
|
4191
|
-
}
|
|
4316
|
+
axis = normalizeAxis(axis, nd);
|
|
4192
4317
|
return flip$1(x, axis);
|
|
4193
4318
|
}
|
|
4194
4319
|
/**
|
|
@@ -4294,12 +4419,80 @@ function flipud(x) {
|
|
|
4294
4419
|
function fliplr(x) {
|
|
4295
4420
|
return flip(x, 1);
|
|
4296
4421
|
}
|
|
4422
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
4297
4423
|
const permuteDims = transpose;
|
|
4298
4424
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
4299
4425
|
function ravel(a) {
|
|
4300
4426
|
return fudgeArray(a).ravel();
|
|
4301
4427
|
}
|
|
4302
4428
|
/**
|
|
4429
|
+
* Repeat each element of an array after themselves.
|
|
4430
|
+
*
|
|
4431
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
4432
|
+
* output array.
|
|
4433
|
+
*/
|
|
4434
|
+
function repeat(a, repeats, axis) {
|
|
4435
|
+
if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
|
|
4436
|
+
a = fudgeArray(a);
|
|
4437
|
+
if (axis === void 0) {
|
|
4438
|
+
a = ravel(a);
|
|
4439
|
+
axis = 0;
|
|
4440
|
+
}
|
|
4441
|
+
axis = checkAxis(axis, a.ndim);
|
|
4442
|
+
if (repeats === 1) return a;
|
|
4443
|
+
const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
|
|
4444
|
+
const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
|
|
4445
|
+
return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
|
|
4446
|
+
}
|
|
4447
|
+
/**
|
|
4448
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
4449
|
+
*
|
|
4450
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
4451
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
4452
|
+
* reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
|
|
4453
|
+
*/
|
|
4454
|
+
function tile(a, reps) {
|
|
4455
|
+
a = fudgeArray(a);
|
|
4456
|
+
if (typeof reps === "number") reps = [reps];
|
|
4457
|
+
if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
|
|
4458
|
+
const ndiff = reps.length - a.ndim;
|
|
4459
|
+
if (ndiff > 0) a = a.reshape([...rep(ndiff, 1), ...a.shape]);
|
|
4460
|
+
if (ndiff < 0) reps = [...rep(-ndiff, 1), ...reps];
|
|
4461
|
+
const broadcastedShape = [];
|
|
4462
|
+
const broadcastAxes = [];
|
|
4463
|
+
for (let i = 0; i < a.ndim; i++) {
|
|
4464
|
+
if (reps[i] > 1) {
|
|
4465
|
+
broadcastedShape.push(reps[i]);
|
|
4466
|
+
broadcastAxes.push(broadcastedShape.length - 1);
|
|
4467
|
+
}
|
|
4468
|
+
broadcastedShape.push(a.shape[i]);
|
|
4469
|
+
}
|
|
4470
|
+
const finalShape = a.shape.map((d, i) => reps[i] * d);
|
|
4471
|
+
return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
|
|
4472
|
+
}
|
|
4473
|
+
/**
|
|
4474
|
+
* Broadcast an array to a shape, with NumPy-style broadcasing rules.
|
|
4475
|
+
*
|
|
4476
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
4477
|
+
* dimensions where the shape is 1.
|
|
4478
|
+
*/
|
|
4479
|
+
function broadcastTo(a, shape$1) {
|
|
4480
|
+
const nd = ndim(a);
|
|
4481
|
+
if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
|
|
4482
|
+
return broadcast(a, shape$1, range(shape$1.length - nd));
|
|
4483
|
+
}
|
|
4484
|
+
/** Broadcast input shapes to a common output shape. */
|
|
4485
|
+
function broadcastShapes(...shapes) {
|
|
4486
|
+
if (shapes.length === 0) return [];
|
|
4487
|
+
return shapes.reduce(generalBroadcast);
|
|
4488
|
+
}
|
|
4489
|
+
/** Broadcast arrays to a common shape. */
|
|
4490
|
+
function broadcastArrays(...arrays) {
|
|
4491
|
+
const shapes = arrays.map((a) => shape(a));
|
|
4492
|
+
const outShape = broadcastShapes(...shapes);
|
|
4493
|
+
return arrays.map((a) => broadcastTo(a, outShape));
|
|
4494
|
+
}
|
|
4495
|
+
/**
|
|
4303
4496
|
* Return specified diagonals.
|
|
4304
4497
|
*
|
|
4305
4498
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
@@ -4323,7 +4516,7 @@ function diag(v, k = 0) {
|
|
|
4323
4516
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
4324
4517
|
if (a.ndim === 1) {
|
|
4325
4518
|
const n = a.shape[0];
|
|
4326
|
-
const ret = where(eye(n).equal(1), a.ref, zerosLike
|
|
4519
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
4327
4520
|
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4328
4521
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4329
4522
|
else return ret;
|
|
@@ -4367,8 +4560,36 @@ function dot(x, y) {
|
|
|
4367
4560
|
]);
|
|
4368
4561
|
return dot$1(x, y);
|
|
4369
4562
|
}
|
|
4370
|
-
/**
|
|
4371
|
-
|
|
4563
|
+
/**
|
|
4564
|
+
* Compute the inner product of two arrays.
|
|
4565
|
+
*
|
|
4566
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
4567
|
+
* contraction on the last axis.
|
|
4568
|
+
*
|
|
4569
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
4570
|
+
*/
|
|
4571
|
+
function inner(x, y) {
|
|
4572
|
+
x = reshape(x, shape(x).toSpliced(-1, 0, ...rep(ndim(y) - 1, 1)));
|
|
4573
|
+
return dot$1(x, y);
|
|
4574
|
+
}
|
|
4575
|
+
/**
|
|
4576
|
+
* Compute the outer product of two arrays.
|
|
4577
|
+
*
|
|
4578
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
4579
|
+
* be of shape `[x.size, y.size]`.
|
|
4580
|
+
*/
|
|
4581
|
+
function outer(x, y) {
|
|
4582
|
+
x = ravel(x);
|
|
4583
|
+
y = ravel(y);
|
|
4584
|
+
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
4585
|
+
}
|
|
4586
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
4587
|
+
function vecdot(x, y, { axis } = {}) {
|
|
4588
|
+
const xaxis = checkAxis(axis ?? -1, ndim(x));
|
|
4589
|
+
const yaxis = checkAxis(axis ?? -1, ndim(y));
|
|
4590
|
+
if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
|
|
4591
|
+
x = moveaxis(x, xaxis, -1);
|
|
4592
|
+
y = moveaxis(y, yaxis, -1);
|
|
4372
4593
|
return dot$1(x, y);
|
|
4373
4594
|
}
|
|
4374
4595
|
/**
|
|
@@ -4377,7 +4598,7 @@ function vecdot(x, y) {
|
|
|
4377
4598
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
4378
4599
|
*/
|
|
4379
4600
|
function vdot(x, y) {
|
|
4380
|
-
return
|
|
4601
|
+
return dot$1(ravel(x), ravel(y));
|
|
4381
4602
|
}
|
|
4382
4603
|
/**
|
|
4383
4604
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
@@ -4406,6 +4627,43 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
4406
4627
|
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
4407
4628
|
}
|
|
4408
4629
|
/**
|
|
4630
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
4631
|
+
*
|
|
4632
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
4633
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
4634
|
+
* `k>0` is above it.
|
|
4635
|
+
*/
|
|
4636
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
4637
|
+
m ??= n;
|
|
4638
|
+
dtype ??= DType.Float32;
|
|
4639
|
+
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
4640
|
+
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
4641
|
+
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
4642
|
+
const rows = arange(k, n + k, 1, {
|
|
4643
|
+
dtype: DType.Int32,
|
|
4644
|
+
device
|
|
4645
|
+
});
|
|
4646
|
+
const cols = arange(0, m, 1, {
|
|
4647
|
+
dtype: DType.Int32,
|
|
4648
|
+
device
|
|
4649
|
+
});
|
|
4650
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
4651
|
+
}
|
|
4652
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
4653
|
+
function tril(a, k = 0) {
|
|
4654
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4655
|
+
a = fudgeArray(a);
|
|
4656
|
+
const [n, m] = a.shape.slice(-2);
|
|
4657
|
+
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
4658
|
+
}
|
|
4659
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
4660
|
+
function triu(a, k = 0) {
|
|
4661
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4662
|
+
a = fudgeArray(a);
|
|
4663
|
+
const [n, m] = a.shape.slice(-2);
|
|
4664
|
+
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
4665
|
+
}
|
|
4666
|
+
/**
|
|
4409
4667
|
* Clip (limit) the values in an array.
|
|
4410
4668
|
*
|
|
4411
4669
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -4429,18 +4687,70 @@ function absolute(x) {
|
|
|
4429
4687
|
x = fudgeArray(x);
|
|
4430
4688
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
4431
4689
|
}
|
|
4432
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
4690
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
4433
4691
|
const abs = absolute;
|
|
4692
|
+
/** Return an element-wise indication of sign of the input. */
|
|
4693
|
+
function sign(x) {
|
|
4694
|
+
x = fudgeArray(x);
|
|
4695
|
+
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4696
|
+
}
|
|
4434
4697
|
/** Calculate element-wise square of the input array. */
|
|
4435
4698
|
function square(x) {
|
|
4436
4699
|
x = fudgeArray(x);
|
|
4437
4700
|
return x.ref.mul(x);
|
|
4438
4701
|
}
|
|
4439
|
-
/**
|
|
4702
|
+
/** Element-wise tangent function (takes radians). */
|
|
4440
4703
|
function tan(x) {
|
|
4441
4704
|
x = fudgeArray(x);
|
|
4442
4705
|
return sin(x.ref).div(cos(x));
|
|
4443
4706
|
}
|
|
4707
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
4708
|
+
function acos(x) {
|
|
4709
|
+
return subtract(pi / 2, asin(x));
|
|
4710
|
+
}
|
|
4711
|
+
/**
|
|
4712
|
+
* @function
|
|
4713
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4714
|
+
*
|
|
4715
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4716
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4717
|
+
* improvements.
|
|
4718
|
+
*/
|
|
4719
|
+
const hypot = jit$1((x1, x2) => {
|
|
4720
|
+
return sqrt(square(x1).add(square(x2)));
|
|
4721
|
+
});
|
|
4722
|
+
/**
|
|
4723
|
+
* @function
|
|
4724
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
4725
|
+
*
|
|
4726
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
4727
|
+
* The result is in the range [-π, π].
|
|
4728
|
+
*
|
|
4729
|
+
* Uses numerically stable formulas:
|
|
4730
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
4731
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
4732
|
+
*
|
|
4733
|
+
* The output is ill-defined when both x and y are zero.
|
|
4734
|
+
*/
|
|
4735
|
+
const atan2 = jit$1((y, x) => {
|
|
4736
|
+
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4737
|
+
const xNeg = less(x.ref, 0);
|
|
4738
|
+
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
4739
|
+
const denom = where(xNeg, y, r.add(x));
|
|
4740
|
+
return atan(numer.div(denom)).mul(2);
|
|
4741
|
+
});
|
|
4742
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
4743
|
+
const arccos = acos;
|
|
4744
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
4745
|
+
const arctan = atan;
|
|
4746
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
4747
|
+
const arctan2 = atan2;
|
|
4748
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
4749
|
+
function subtract(x, y) {
|
|
4750
|
+
x = fudgeArray(x);
|
|
4751
|
+
y = fudgeArray(y);
|
|
4752
|
+
return x.sub(y);
|
|
4753
|
+
}
|
|
4444
4754
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
4445
4755
|
function trueDivide(x, y) {
|
|
4446
4756
|
x = fudgeArray(x);
|
|
@@ -4448,7 +4758,7 @@ function trueDivide(x, y) {
|
|
|
4448
4758
|
if (!isFloatDtype(x.dtype) || !isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
|
|
4449
4759
|
return x.div(y);
|
|
4450
4760
|
}
|
|
4451
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
4761
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
4452
4762
|
const divide = trueDivide;
|
|
4453
4763
|
/** Round input to the nearest integer towards zero. */
|
|
4454
4764
|
function trunc(x) {
|
|
@@ -4466,36 +4776,134 @@ function log2(x) {
|
|
|
4466
4776
|
function log10(x) {
|
|
4467
4777
|
return log(x).mul(Math.LOG10E);
|
|
4468
4778
|
}
|
|
4779
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
4780
|
+
function expm1(x) {
|
|
4781
|
+
return exp(x).sub(1);
|
|
4782
|
+
}
|
|
4783
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
4784
|
+
function log1p(x) {
|
|
4785
|
+
return log(add(1, x));
|
|
4786
|
+
}
|
|
4787
|
+
/** Convert angles from degrees to radians. */
|
|
4788
|
+
function deg2rad(x) {
|
|
4789
|
+
return multiply(x, pi / 180);
|
|
4790
|
+
}
|
|
4791
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
4792
|
+
const radians = deg2rad;
|
|
4793
|
+
/** Convert angles from radians to degrees. */
|
|
4794
|
+
function rad2deg(x) {
|
|
4795
|
+
return multiply(x, 180 / pi);
|
|
4796
|
+
}
|
|
4797
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
4798
|
+
const degrees = rad2deg;
|
|
4469
4799
|
/**
|
|
4800
|
+
* @function
|
|
4801
|
+
* Computes first array raised to power of second array, element-wise.
|
|
4802
|
+
*/
|
|
4803
|
+
const power = jit$1((x1, x2) => {
|
|
4804
|
+
return exp(log(x1).mul(x2));
|
|
4805
|
+
});
|
|
4806
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
4807
|
+
const pow = power;
|
|
4808
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
4809
|
+
const cbrt = jit$1((x) => {
|
|
4810
|
+
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4811
|
+
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4812
|
+
});
|
|
4813
|
+
/**
|
|
4814
|
+
* @function
|
|
4470
4815
|
* Calculate element-wise hyperbolic sine of input.
|
|
4471
4816
|
*
|
|
4472
4817
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4473
4818
|
*/
|
|
4474
|
-
|
|
4819
|
+
const sinh = jit$1((x) => {
|
|
4475
4820
|
const ex = exp(x);
|
|
4476
4821
|
const emx = reciprocal(ex.ref);
|
|
4477
4822
|
return ex.sub(emx).mul(.5);
|
|
4478
|
-
}
|
|
4823
|
+
});
|
|
4479
4824
|
/**
|
|
4825
|
+
* @function
|
|
4480
4826
|
* Calculate element-wise hyperbolic cosine of input.
|
|
4481
4827
|
*
|
|
4482
4828
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4483
4829
|
*/
|
|
4484
|
-
|
|
4830
|
+
const cosh = jit$1((x) => {
|
|
4485
4831
|
const ex = exp(x);
|
|
4486
4832
|
const emx = reciprocal(ex.ref);
|
|
4487
4833
|
return ex.add(emx).mul(.5);
|
|
4488
|
-
}
|
|
4834
|
+
});
|
|
4489
4835
|
/**
|
|
4836
|
+
* @function
|
|
4490
4837
|
* Calculate element-wise hyperbolic tangent of input.
|
|
4491
4838
|
*
|
|
4492
4839
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4493
4840
|
*/
|
|
4494
|
-
|
|
4495
|
-
x = fudgeArray(x);
|
|
4841
|
+
const tanh = jit$1((x) => {
|
|
4496
4842
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4497
4843
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4498
4844
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
4845
|
+
});
|
|
4846
|
+
/**
|
|
4847
|
+
* @function
|
|
4848
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
4849
|
+
*
|
|
4850
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4851
|
+
*/
|
|
4852
|
+
const arcsinh = jit$1((x) => {
|
|
4853
|
+
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4854
|
+
});
|
|
4855
|
+
/**
|
|
4856
|
+
* @function
|
|
4857
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
4858
|
+
*
|
|
4859
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4860
|
+
*/
|
|
4861
|
+
const arccosh = jit$1((x) => {
|
|
4862
|
+
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4863
|
+
});
|
|
4864
|
+
/**
|
|
4865
|
+
* @function
|
|
4866
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
4867
|
+
*
|
|
4868
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4869
|
+
*/
|
|
4870
|
+
const arctanh = jit$1((x) => {
|
|
4871
|
+
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4872
|
+
});
|
|
4873
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
4874
|
+
const asinh = arcsinh;
|
|
4875
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
4876
|
+
const acosh = arccosh;
|
|
4877
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
4878
|
+
const atanh = arctanh;
|
|
4879
|
+
/**
|
|
4880
|
+
* Compute the variance of an array.
|
|
4881
|
+
*
|
|
4882
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
4883
|
+
* the specified axis.
|
|
4884
|
+
*
|
|
4885
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4886
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4887
|
+
*/
|
|
4888
|
+
function var_(x, axis = null, opts) {
|
|
4889
|
+
x = fudgeArray(x);
|
|
4890
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
4891
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
4892
|
+
if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
|
|
4893
|
+
const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
|
|
4894
|
+
return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
|
|
4895
|
+
}
|
|
4896
|
+
/**
|
|
4897
|
+
* Compute the standard deviation of an array.
|
|
4898
|
+
*
|
|
4899
|
+
* The standard deviation is computed for the flattened array by default,
|
|
4900
|
+
* otherwise over the specified axis.
|
|
4901
|
+
*
|
|
4902
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4903
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4904
|
+
*/
|
|
4905
|
+
function std(x, axis = null, opts) {
|
|
4906
|
+
return sqrt(var_(x, axis, opts));
|
|
4499
4907
|
}
|
|
4500
4908
|
|
|
4501
4909
|
//#endregion
|
|
@@ -4510,6 +4918,7 @@ __export(nn_exports, {
|
|
|
4510
4918
|
leakyRelu: () => leakyRelu,
|
|
4511
4919
|
logSigmoid: () => logSigmoid,
|
|
4512
4920
|
logSoftmax: () => logSoftmax,
|
|
4921
|
+
logmeanexp: () => logmeanexp,
|
|
4513
4922
|
logsumexp: () => logsumexp,
|
|
4514
4923
|
mish: () => mish,
|
|
4515
4924
|
oneHot: () => oneHot,
|
|
@@ -4520,6 +4929,8 @@ __export(nn_exports, {
|
|
|
4520
4929
|
softSign: () => softSign,
|
|
4521
4930
|
softmax: () => softmax,
|
|
4522
4931
|
softplus: () => softplus,
|
|
4932
|
+
squareplus: () => squareplus,
|
|
4933
|
+
standardize: () => standardize,
|
|
4523
4934
|
swish: () => swish
|
|
4524
4935
|
});
|
|
4525
4936
|
/**
|
|
@@ -4563,6 +4974,7 @@ function softSign(x) {
|
|
|
4563
4974
|
return x.ref.div(absolute(x).add(1));
|
|
4564
4975
|
}
|
|
4565
4976
|
/**
|
|
4977
|
+
* @function
|
|
4566
4978
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4567
4979
|
* Swish, computed element-wise:
|
|
4568
4980
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4573,6 +4985,7 @@ function softSign(x) {
|
|
|
4573
4985
|
*/
|
|
4574
4986
|
const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
|
|
4575
4987
|
/**
|
|
4988
|
+
* @function
|
|
4576
4989
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4577
4990
|
* Swish, computed element-wise:
|
|
4578
4991
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4589,7 +5002,10 @@ const swish = silu;
|
|
|
4589
5002
|
function logSigmoid(x) {
|
|
4590
5003
|
return negative(softplus(negative(x)));
|
|
4591
5004
|
}
|
|
4592
|
-
/**
|
|
5005
|
+
/**
|
|
5006
|
+
* @function
|
|
5007
|
+
* Identity activation function. Returns the argument unmodified.
|
|
5008
|
+
*/
|
|
4593
5009
|
const identity = fudgeArray;
|
|
4594
5010
|
/** Leaky rectified linear (ReLU) activation function */
|
|
4595
5011
|
function leakyRelu(x, negativeSlope = .01) {
|
|
@@ -4617,6 +5033,7 @@ function celu(x, alpha = 1) {
|
|
|
4617
5033
|
return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
|
|
4618
5034
|
}
|
|
4619
5035
|
/**
|
|
5036
|
+
* @function
|
|
4620
5037
|
* Gaussion error linear unit (GELU) activation function.
|
|
4621
5038
|
*
|
|
4622
5039
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -4648,6 +5065,16 @@ function glu(x, axis = -1) {
|
|
|
4648
5065
|
return a.mul(sigmoid(b));
|
|
4649
5066
|
}
|
|
4650
5067
|
/**
|
|
5068
|
+
* Squareplus activation function.
|
|
5069
|
+
*
|
|
5070
|
+
* Computes the element-wise function:
|
|
5071
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
5072
|
+
*/
|
|
5073
|
+
function squareplus(x, b = 4) {
|
|
5074
|
+
x = fudgeArray(x);
|
|
5075
|
+
return x.ref.add(sqrt(square(x).add(b))).mul(.5);
|
|
5076
|
+
}
|
|
5077
|
+
/**
|
|
4651
5078
|
* Mish activation function.
|
|
4652
5079
|
*
|
|
4653
5080
|
* Computes the element-wise function:
|
|
@@ -4665,17 +5092,13 @@ function mish(x) {
|
|
|
4665
5092
|
*
|
|
4666
5093
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
4667
5094
|
*/
|
|
4668
|
-
function softmax(x, axis) {
|
|
5095
|
+
function softmax(x, axis = -1) {
|
|
4669
5096
|
x = fudgeArray(x);
|
|
4670
|
-
|
|
4671
|
-
|
|
4672
|
-
|
|
4673
|
-
x.dispose();
|
|
4674
|
-
return ones(x.shape);
|
|
4675
|
-
}
|
|
4676
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5097
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5098
|
+
if (axis.length === 0) return onesLike(x);
|
|
5099
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4677
5100
|
const unnormalized = exp(x.sub(stopGradient(xMax)));
|
|
4678
|
-
return unnormalized.ref.div(unnormalized.sum(axis, {
|
|
5101
|
+
return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
|
|
4679
5102
|
}
|
|
4680
5103
|
/**
|
|
4681
5104
|
* Log-Softmax function.
|
|
@@ -4685,17 +5108,13 @@ function softmax(x, axis) {
|
|
|
4685
5108
|
*
|
|
4686
5109
|
* If `axis` is not specified, it defaults to the last axis.
|
|
4687
5110
|
*/
|
|
4688
|
-
function logSoftmax(x, axis) {
|
|
5111
|
+
function logSoftmax(x, axis = -1) {
|
|
4689
5112
|
x = fudgeArray(x);
|
|
4690
|
-
|
|
4691
|
-
|
|
4692
|
-
|
|
4693
|
-
x.dispose();
|
|
4694
|
-
return zeros(x.shape);
|
|
4695
|
-
}
|
|
4696
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5113
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5114
|
+
if (axis.length === 0) return zerosLike(x);
|
|
5115
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4697
5116
|
const shifted = x.sub(stopGradient(xMax));
|
|
4698
|
-
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, {
|
|
5117
|
+
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
|
|
4699
5118
|
return shifted.sub(shiftedLogsumexp);
|
|
4700
5119
|
}
|
|
4701
5120
|
/**
|
|
@@ -4706,16 +5125,39 @@ function logSoftmax(x, axis) {
|
|
|
4706
5125
|
*
|
|
4707
5126
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
4708
5127
|
*/
|
|
4709
|
-
function logsumexp(x, axis) {
|
|
5128
|
+
function logsumexp(x, axis = null) {
|
|
4710
5129
|
x = fudgeArray(x);
|
|
4711
|
-
|
|
4712
|
-
else if (typeof axis === "number") axis = [axis];
|
|
5130
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
4713
5131
|
if (axis.length === 0) return x;
|
|
4714
5132
|
const xMax = stopGradient(max(x.ref, axis));
|
|
4715
5133
|
const xMaxDims = broadcast(xMax.ref, x.shape, axis);
|
|
4716
5134
|
const shifted = x.sub(xMaxDims);
|
|
4717
5135
|
return xMax.add(log(exp(shifted).sum(axis)));
|
|
4718
5136
|
}
|
|
5137
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
5138
|
+
function logmeanexp(x, axis = null) {
|
|
5139
|
+
x = fudgeArray(x);
|
|
5140
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5141
|
+
if (axis.length === 0) return x;
|
|
5142
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
5143
|
+
return logsumexp(x, axis).sub(Math.log(n));
|
|
5144
|
+
}
|
|
5145
|
+
/**
|
|
5146
|
+
* Standardizes input to zero mean and unit variance.
|
|
5147
|
+
*
|
|
5148
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
5149
|
+
* axis, or `null` to standardize over all elements.
|
|
5150
|
+
*
|
|
5151
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
5152
|
+
*/
|
|
5153
|
+
function standardize(x, axis = -1, opts = {}) {
|
|
5154
|
+
x = fudgeArray(x);
|
|
5155
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5156
|
+
if (axis.length === 0) return x;
|
|
5157
|
+
const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
|
|
5158
|
+
const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
|
|
5159
|
+
return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
|
|
5160
|
+
}
|
|
4719
5161
|
/**
|
|
4720
5162
|
* One-hot encodes the given indices.
|
|
4721
5163
|
*
|
|
@@ -4733,7 +5175,7 @@ function logsumexp(x, axis) {
|
|
|
4733
5175
|
* ```
|
|
4734
5176
|
*/
|
|
4735
5177
|
function oneHot(x, numClasses) {
|
|
4736
|
-
if (x.dtype
|
|
5178
|
+
if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
4737
5179
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
4738
5180
|
}
|
|
4739
5181
|
|
|
@@ -4741,8 +5183,11 @@ function oneHot(x, numClasses) {
|
|
|
4741
5183
|
//#region src/random.ts
|
|
4742
5184
|
var random_exports = {};
|
|
4743
5185
|
__export(random_exports, {
|
|
5186
|
+
bernoulli: () => bernoulli,
|
|
4744
5187
|
bits: () => bits,
|
|
5188
|
+
exponential: () => exponential,
|
|
4745
5189
|
key: () => key,
|
|
5190
|
+
normal: () => normal,
|
|
4746
5191
|
split: () => split,
|
|
4747
5192
|
uniform: () => uniform
|
|
4748
5193
|
});
|
|
@@ -4773,11 +5218,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
4773
5218
|
/** Sample uniform random values in [minval, maxval) with given shape. */
|
|
4774
5219
|
function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
4775
5220
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
4776
|
-
const mantissa = bits(key$1, shape$1).div(
|
|
5221
|
+
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
4777
5222
|
dtype: DType.Uint32,
|
|
4778
5223
|
device: key$1.device
|
|
4779
5224
|
}));
|
|
4780
|
-
const float12 = mantissa.add(
|
|
5225
|
+
const float12 = mantissa.add(array(1065353216, {
|
|
4781
5226
|
dtype: DType.Uint32,
|
|
4782
5227
|
device: key$1.device
|
|
4783
5228
|
}));
|
|
@@ -4785,6 +5230,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
4785
5230
|
if (minval === 0 && maxval === 1) return rand;
|
|
4786
5231
|
else return rand.mul(maxval - minval).add(minval);
|
|
4787
5232
|
}
|
|
5233
|
+
/**
|
|
5234
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5235
|
+
*
|
|
5236
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
5237
|
+
* and must be broadcastable to `shape`.
|
|
5238
|
+
*/
|
|
5239
|
+
function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
5240
|
+
p = fudgeArray(p);
|
|
5241
|
+
return uniform(key$1, shape$1).less(p);
|
|
5242
|
+
}
|
|
5243
|
+
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
5244
|
+
function exponential(key$1, shape$1 = []) {
|
|
5245
|
+
const u = uniform(key$1, shape$1);
|
|
5246
|
+
return negative(log1p(negative(u)));
|
|
5247
|
+
}
|
|
5248
|
+
/**
|
|
5249
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5250
|
+
*
|
|
5251
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5252
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5253
|
+
* bitwise identical to JAX.
|
|
5254
|
+
*/
|
|
5255
|
+
function normal(key$1, shape$1 = []) {
|
|
5256
|
+
const [k1, k2] = split(key$1, 2);
|
|
5257
|
+
const u1 = uniform(k1, shape$1);
|
|
5258
|
+
const u2 = uniform(k2, shape$1);
|
|
5259
|
+
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5260
|
+
const theta = u2.mul(2 * Math.PI);
|
|
5261
|
+
return radius.mul(cos(theta));
|
|
5262
|
+
}
|
|
4788
5263
|
|
|
4789
5264
|
//#endregion
|
|
4790
5265
|
//#region src/polyfills.ts
|
|
@@ -4794,20 +5269,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
4794
5269
|
|
|
4795
5270
|
//#endregion
|
|
4796
5271
|
//#region src/index.ts
|
|
4797
|
-
/**
|
|
5272
|
+
/**
|
|
5273
|
+
* @function
|
|
5274
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
5275
|
+
*/
|
|
4798
5276
|
const jvp = jvp$1;
|
|
4799
|
-
/**
|
|
5277
|
+
/**
|
|
5278
|
+
* @function
|
|
5279
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
5280
|
+
*/
|
|
4800
5281
|
const vmap = vmap$1;
|
|
4801
|
-
/**
|
|
5282
|
+
/**
|
|
5283
|
+
* @function
|
|
5284
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
5285
|
+
*/
|
|
4802
5286
|
const jacfwd = jacfwd$1;
|
|
4803
|
-
/**
|
|
5287
|
+
/**
|
|
5288
|
+
* @function
|
|
5289
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
5290
|
+
*/
|
|
4804
5291
|
const makeJaxpr = makeJaxpr$1;
|
|
4805
5292
|
/**
|
|
5293
|
+
* @function
|
|
4806
5294
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
4807
5295
|
*
|
|
4808
5296
|
* The function will be compiled the first time it is called with a set of
|
|
4809
5297
|
* argument shapes.
|
|
4810
5298
|
*
|
|
5299
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
5300
|
+
* calls to free memory associated with array constants.
|
|
5301
|
+
*
|
|
4811
5302
|
* **Options:**
|
|
4812
5303
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
4813
5304
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -4817,23 +5308,52 @@ const makeJaxpr = makeJaxpr$1;
|
|
|
4817
5308
|
*/
|
|
4818
5309
|
const jit = jit$1;
|
|
4819
5310
|
/**
|
|
5311
|
+
* @function
|
|
4820
5312
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
4821
5313
|
* partial evaluation.
|
|
4822
5314
|
*/
|
|
4823
5315
|
const linearize = linearize$1;
|
|
4824
|
-
/**
|
|
5316
|
+
/**
|
|
5317
|
+
* @function
|
|
5318
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
5319
|
+
*/
|
|
4825
5320
|
const vjp = vjp$1;
|
|
4826
5321
|
/**
|
|
5322
|
+
* @function
|
|
4827
5323
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
4828
5324
|
* first argument.
|
|
4829
5325
|
*/
|
|
4830
5326
|
const grad = grad$1;
|
|
4831
|
-
/**
|
|
5327
|
+
/**
|
|
5328
|
+
* @function
|
|
5329
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
5330
|
+
*/
|
|
4832
5331
|
const valueAndGrad = valueAndGrad$1;
|
|
4833
|
-
/**
|
|
5332
|
+
/**
|
|
5333
|
+
* @function
|
|
5334
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
5335
|
+
*/
|
|
4834
5336
|
const jacrev = jacrev$1;
|
|
4835
|
-
/**
|
|
5337
|
+
/**
|
|
5338
|
+
* @function
|
|
5339
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
5340
|
+
*/
|
|
4836
5341
|
const jacobian = jacrev;
|
|
5342
|
+
/**
|
|
5343
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
5344
|
+
*
|
|
5345
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
5346
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
5347
|
+
* to avoid queueing up too many pending operations.
|
|
5348
|
+
*
|
|
5349
|
+
* Does not consume reference to the arrays.
|
|
5350
|
+
*/
|
|
5351
|
+
async function blockUntilReady(x) {
|
|
5352
|
+
const promises = [];
|
|
5353
|
+
for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
|
|
5354
|
+
await Promise.all(promises);
|
|
5355
|
+
return x;
|
|
5356
|
+
}
|
|
4837
5357
|
|
|
4838
5358
|
//#endregion
|
|
4839
|
-
export { DType, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random,
|
|
5359
|
+
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|