@jax-js/jax 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -19
- package/dist/{backend-BqDtPGaR.js → backend-EBRGmEYw.js} +296 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-Ss1Mev_-.cjs} +315 -154
- package/dist/index.cjs +681 -157
- package/dist/index.d.cts +422 -76
- package/dist/index.d.ts +422 -76
- package/dist/index.js +677 -157
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-BVdMaO9T.cjs} +9 -3
- package/dist/{webgpu-CNg9JGva.js → webgpu-ow0Pn_6q.js} +9 -3
- package/package.json +15 -4
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-D2C4MJRP.cjs');
|
|
33
|
+
const require_backend = require('./backend-Ss1Mev_-.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
@@ -354,6 +354,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
354
354
|
Primitive$1["RandomBits"] = "random_bits";
|
|
355
355
|
Primitive$1["Sin"] = "sin";
|
|
356
356
|
Primitive$1["Cos"] = "cos";
|
|
357
|
+
Primitive$1["Asin"] = "asin";
|
|
358
|
+
Primitive$1["Atan"] = "atan";
|
|
357
359
|
Primitive$1["Exp"] = "exp";
|
|
358
360
|
Primitive$1["Log"] = "log";
|
|
359
361
|
Primitive$1["Sqrt"] = "sqrt";
|
|
@@ -421,6 +423,12 @@ function sin$1(x) {
|
|
|
421
423
|
function cos$1(x) {
|
|
422
424
|
return bind1(Primitive.Cos, [x]);
|
|
423
425
|
}
|
|
426
|
+
function asin$1(x) {
|
|
427
|
+
return bind1(Primitive.Asin, [x]);
|
|
428
|
+
}
|
|
429
|
+
function atan$1(x) {
|
|
430
|
+
return bind1(Primitive.Atan, [x]);
|
|
431
|
+
}
|
|
424
432
|
function exp$1(x) {
|
|
425
433
|
return bind1(Primitive.Exp, [x]);
|
|
426
434
|
}
|
|
@@ -436,18 +444,16 @@ function min$1(x, y) {
|
|
|
436
444
|
function max$1(x, y) {
|
|
437
445
|
return bind1(Primitive.Max, [x, y]);
|
|
438
446
|
}
|
|
439
|
-
function reduce(x, op, axis, opts) {
|
|
447
|
+
function reduce(x, op, axis = null, opts) {
|
|
440
448
|
if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
441
|
-
|
|
442
|
-
else axis = [];
|
|
443
|
-
else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, ndim$1(x))];
|
|
444
|
-
else axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
|
|
449
|
+
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
445
450
|
const originalShape = getShape(x);
|
|
446
|
-
|
|
451
|
+
let result = bind1(Primitive.Reduce, [x], {
|
|
447
452
|
op,
|
|
448
453
|
axis
|
|
449
454
|
});
|
|
450
|
-
|
|
455
|
+
if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
|
|
456
|
+
return result;
|
|
451
457
|
}
|
|
452
458
|
function dot$1(x, y) {
|
|
453
459
|
return bind1(Primitive.Dot, [x, y]);
|
|
@@ -493,10 +499,11 @@ function where$1(cond, x, y) {
|
|
|
493
499
|
}
|
|
494
500
|
function transpose$1(x, perm) {
|
|
495
501
|
perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
|
|
502
|
+
if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
496
503
|
return bind1(Primitive.Transpose, [x], { perm });
|
|
497
504
|
}
|
|
498
505
|
function broadcast(x, shape$1, axis) {
|
|
499
|
-
axis =
|
|
506
|
+
axis = require_backend.normalizeAxis(axis, shape$1.length);
|
|
500
507
|
return bind1(Primitive.Broadcast, [x], {
|
|
501
508
|
shape: shape$1,
|
|
502
509
|
axis
|
|
@@ -515,7 +522,7 @@ function reshape$1(x, shape$1) {
|
|
|
515
522
|
return bind1(Primitive.Reshape, [x], { shape: shape$1 });
|
|
516
523
|
}
|
|
517
524
|
function flip$1(x, axis) {
|
|
518
|
-
axis =
|
|
525
|
+
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
519
526
|
return bind1(Primitive.Flip, [x], { axis });
|
|
520
527
|
}
|
|
521
528
|
function shrink(x, slice) {
|
|
@@ -595,15 +602,19 @@ var Tracer = class Tracer {
|
|
|
595
602
|
constructor(trace) {
|
|
596
603
|
this._trace = trace;
|
|
597
604
|
}
|
|
605
|
+
/** The shape of the array. */
|
|
598
606
|
get shape() {
|
|
599
607
|
return this.aval.shape;
|
|
600
608
|
}
|
|
609
|
+
/** The total number of elements in the array. */
|
|
601
610
|
get size() {
|
|
602
611
|
return require_backend.prod(this.shape);
|
|
603
612
|
}
|
|
613
|
+
/** The dtype of the array. */
|
|
604
614
|
get dtype() {
|
|
605
615
|
return this.aval.dtype;
|
|
606
616
|
}
|
|
617
|
+
/** The number of dimensions of the array. */
|
|
607
618
|
get ndim() {
|
|
608
619
|
return this.shape.length;
|
|
609
620
|
}
|
|
@@ -639,22 +650,20 @@ var Tracer = class Tracer {
|
|
|
639
650
|
return lessEqual$1(this, other);
|
|
640
651
|
}
|
|
641
652
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
642
|
-
sum(axis, opts) {
|
|
653
|
+
sum(axis = null, opts) {
|
|
643
654
|
return reduce(this, require_backend.AluOp.Add, axis, opts);
|
|
644
655
|
}
|
|
645
656
|
/** Product of the array elements over a given axis. */
|
|
646
|
-
prod(axis, opts) {
|
|
657
|
+
prod(axis = null, opts) {
|
|
647
658
|
return reduce(this, require_backend.AluOp.Mul, axis, opts);
|
|
648
659
|
}
|
|
649
660
|
/** Compute the average of the array elements along the specified axis. */
|
|
650
|
-
mean(axis, opts) {
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
657
|
-
return result;
|
|
661
|
+
mean(axis = null, opts) {
|
|
662
|
+
axis = require_backend.normalizeAxis(axis, this.ndim);
|
|
663
|
+
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
664
|
+
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
665
|
+
const result = reduce(this, require_backend.AluOp.Add, axis, opts);
|
|
666
|
+
return result.mul(1 / n);
|
|
658
667
|
}
|
|
659
668
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
660
669
|
transpose(perm) {
|
|
@@ -1187,6 +1196,8 @@ const jitRules = {
|
|
|
1187
1196
|
},
|
|
1188
1197
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
1189
1198
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
1199
|
+
[Primitive.Asin]: unopJit(require_backend.AluExp.asin),
|
|
1200
|
+
[Primitive.Atan]: unopJit(require_backend.AluExp.atan),
|
|
1190
1201
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
1191
1202
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1192
1203
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
@@ -1428,7 +1439,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1428
1439
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1429
1440
|
* will be freed when the array is disposed.
|
|
1430
1441
|
*/
|
|
1431
|
-
constructor(source, st, dtype, backend, pending = null) {
|
|
1442
|
+
constructor(source, st, dtype, backend, { pending = null } = {}) {
|
|
1432
1443
|
super(baseArrayTrace);
|
|
1433
1444
|
this.id = Array$1.#nextId++;
|
|
1434
1445
|
this.#dtype = dtype;
|
|
@@ -1437,6 +1448,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1437
1448
|
this.#backend = backend;
|
|
1438
1449
|
this.#rc = 1;
|
|
1439
1450
|
this.#pendingSet = new Set(pending);
|
|
1451
|
+
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1452
|
+
else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1440
1453
|
}
|
|
1441
1454
|
/** @ignore */
|
|
1442
1455
|
get aval() {
|
|
@@ -1491,7 +1504,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1491
1504
|
const pending = this.#pending;
|
|
1492
1505
|
for (const exe of pending) exe.updateRc(1);
|
|
1493
1506
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1494
|
-
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
|
|
1507
|
+
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
|
|
1495
1508
|
this.dispose();
|
|
1496
1509
|
return ar;
|
|
1497
1510
|
}
|
|
@@ -1540,7 +1553,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1540
1553
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1541
1554
|
this.dispose();
|
|
1542
1555
|
for (const ar of indices) ar.dispose();
|
|
1543
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
|
|
1556
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
|
|
1544
1557
|
}
|
|
1545
1558
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1546
1559
|
#moveAxesDown(axis) {
|
|
@@ -1577,7 +1590,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1577
1590
|
for (const exe of pending) exe.updateRc(1);
|
|
1578
1591
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1579
1592
|
this.dispose();
|
|
1580
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
|
|
1593
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
|
|
1581
1594
|
}
|
|
1582
1595
|
#binary(op, other) {
|
|
1583
1596
|
const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
|
|
@@ -1642,7 +1655,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1642
1655
|
for (const exe of pending) exe.updateRc(1);
|
|
1643
1656
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1644
1657
|
for (const ar of arrays) ar.dispose();
|
|
1645
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
|
|
1658
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
|
|
1646
1659
|
}
|
|
1647
1660
|
/** Reduce the last dimension of the array by an operation. */
|
|
1648
1661
|
#reduce(op) {
|
|
@@ -1666,7 +1679,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1666
1679
|
for (const exe of pending) exe.updateRc(1);
|
|
1667
1680
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1668
1681
|
this.dispose();
|
|
1669
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
|
|
1682
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
|
|
1670
1683
|
}
|
|
1671
1684
|
/**
|
|
1672
1685
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1739,8 +1752,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1739
1752
|
*
|
|
1740
1753
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1741
1754
|
* dispatch of operations as well.
|
|
1755
|
+
*
|
|
1756
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1757
|
+
* asynchronously for multiple arrays.
|
|
1742
1758
|
*/
|
|
1743
|
-
async
|
|
1759
|
+
async blockUntilReady() {
|
|
1744
1760
|
this.#check();
|
|
1745
1761
|
if (this.#source instanceof require_backend.AluExp) return this;
|
|
1746
1762
|
const pending = this.#pending;
|
|
@@ -1806,7 +1822,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1806
1822
|
return [x.#binary(require_backend.AluOp.Idiv, y)];
|
|
1807
1823
|
},
|
|
1808
1824
|
[Primitive.Neg]([x]) {
|
|
1809
|
-
return [zerosLike(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
1825
|
+
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
1810
1826
|
},
|
|
1811
1827
|
[Primitive.Reciprocal]([x]) {
|
|
1812
1828
|
return [x.#unary(require_backend.AluOp.Reciprocal)];
|
|
@@ -1826,7 +1842,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1826
1842
|
x.#backend.incRef(x.#source);
|
|
1827
1843
|
const pending = x.#pending;
|
|
1828
1844
|
for (const exe of pending) exe.updateRc(1);
|
|
1829
|
-
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
|
|
1845
|
+
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
|
|
1830
1846
|
x.dispose();
|
|
1831
1847
|
return [y];
|
|
1832
1848
|
}
|
|
@@ -1856,6 +1872,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1856
1872
|
[Primitive.Cos]([x]) {
|
|
1857
1873
|
return [x.#unary(require_backend.AluOp.Cos)];
|
|
1858
1874
|
},
|
|
1875
|
+
[Primitive.Asin]([x]) {
|
|
1876
|
+
return [x.#unary(require_backend.AluOp.Asin)];
|
|
1877
|
+
},
|
|
1878
|
+
[Primitive.Atan]([x]) {
|
|
1879
|
+
return [x.#unary(require_backend.AluOp.Atan)];
|
|
1880
|
+
},
|
|
1859
1881
|
[Primitive.Exp]([x]) {
|
|
1860
1882
|
return [x.#unary(require_backend.AluOp.Exp)];
|
|
1861
1883
|
},
|
|
@@ -1941,7 +1963,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1941
1963
|
pending.splice(0, 0, ...prevPending);
|
|
1942
1964
|
args.forEach((x) => x.dispose());
|
|
1943
1965
|
return outputs.map((source, i) => {
|
|
1944
|
-
return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
|
|
1966
|
+
return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
|
|
1945
1967
|
});
|
|
1946
1968
|
}
|
|
1947
1969
|
};
|
|
@@ -2073,12 +2095,12 @@ var EvalTrace = class extends Trace {
|
|
|
2073
2095
|
};
|
|
2074
2096
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2075
2097
|
const implRules = Array$1._implRules();
|
|
2076
|
-
function zerosLike(val, dtype) {
|
|
2098
|
+
function zerosLike$1(val, dtype) {
|
|
2077
2099
|
const aval = getAval(val);
|
|
2078
2100
|
if (val instanceof Tracer) val.dispose();
|
|
2079
2101
|
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2080
2102
|
}
|
|
2081
|
-
function onesLike(val, dtype) {
|
|
2103
|
+
function onesLike$1(val, dtype) {
|
|
2082
2104
|
const aval = getAval(val);
|
|
2083
2105
|
if (val instanceof Tracer) val.dispose();
|
|
2084
2106
|
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
@@ -2141,7 +2163,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2141
2163
|
const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
|
|
2142
2164
|
return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
|
|
2143
2165
|
}
|
|
2144
|
-
/** Return the identity
|
|
2166
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2145
2167
|
function identity$1(n, { dtype, device } = {}) {
|
|
2146
2168
|
return eye(n, n, {
|
|
2147
2169
|
dtype,
|
|
@@ -2421,16 +2443,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2421
2443
|
varIds.set(v, require_backend.FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
|
|
2422
2444
|
return id;
|
|
2423
2445
|
};
|
|
2424
|
-
hasher.update(this.inBinders.length
|
|
2425
|
-
|
|
2426
|
-
|
|
2427
|
-
|
|
2428
|
-
|
|
2429
|
-
|
|
2430
|
-
eqn.
|
|
2431
|
-
|
|
2432
|
-
|
|
2433
|
-
|
|
2446
|
+
hasher.update(this.inBinders.length);
|
|
2447
|
+
for (const x of this.inBinders) hasher.update(vi(x));
|
|
2448
|
+
hasher.update(this.eqns.length);
|
|
2449
|
+
for (const eqn of this.eqns) {
|
|
2450
|
+
hasher.update(eqn.primitive);
|
|
2451
|
+
hasher.update(eqn.inputs.length);
|
|
2452
|
+
for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2453
|
+
hasher.update(JSON.stringify(eqn.params));
|
|
2454
|
+
hasher.update(eqn.outBinders.length);
|
|
2455
|
+
for (const x of eqn.outBinders) hasher.update(vi(x));
|
|
2456
|
+
}
|
|
2457
|
+
hasher.update(this.outs.length);
|
|
2458
|
+
for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2434
2459
|
return this.#hash = hasher.value;
|
|
2435
2460
|
}
|
|
2436
2461
|
hash(state) {
|
|
@@ -2467,7 +2492,7 @@ var Jaxpr = class Jaxpr {
|
|
|
2467
2492
|
const c = eqn.outBinders[0];
|
|
2468
2493
|
if (atomIsLit(b, 1)) context.set(c, a);
|
|
2469
2494
|
else newEqns.push(eqn);
|
|
2470
|
-
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2495
|
+
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2471
2496
|
else newEqns.push(eqn);
|
|
2472
2497
|
}
|
|
2473
2498
|
const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
|
|
@@ -2733,6 +2758,8 @@ const abstractEvalRules = {
|
|
|
2733
2758
|
},
|
|
2734
2759
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2735
2760
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2761
|
+
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
2762
|
+
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2736
2763
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2737
2764
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2738
2765
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
@@ -2860,7 +2887,7 @@ function makeJaxpr$1(f, opts) {
|
|
|
2860
2887
|
function jit$1(f, opts) {
|
|
2861
2888
|
const cache = /* @__PURE__ */ new Map();
|
|
2862
2889
|
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2863
|
-
|
|
2890
|
+
const result = ((...args) => {
|
|
2864
2891
|
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2865
2892
|
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2866
2893
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
@@ -2874,6 +2901,10 @@ function jit$1(f, opts) {
|
|
|
2874
2901
|
});
|
|
2875
2902
|
return unflatten(outTree, outs);
|
|
2876
2903
|
});
|
|
2904
|
+
result.dispose = () => {
|
|
2905
|
+
for (const { consts } of cache.values()) for (const c of consts) c.dispose();
|
|
2906
|
+
};
|
|
2907
|
+
return result;
|
|
2877
2908
|
}
|
|
2878
2909
|
|
|
2879
2910
|
//#endregion
|
|
@@ -2905,7 +2936,7 @@ var JVPTrace = class extends Trace {
|
|
|
2905
2936
|
return this.lift(pureArray(val));
|
|
2906
2937
|
}
|
|
2907
2938
|
lift(val) {
|
|
2908
|
-
return new JVPTracer(this, val, zerosLike(val.ref));
|
|
2939
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
2909
2940
|
}
|
|
2910
2941
|
processPrimitive(primitive, tracers, params) {
|
|
2911
2942
|
const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2936,7 +2967,7 @@ function zeroTangentsJvp(primitive) {
|
|
|
2936
2967
|
return (primals, tangents, params) => {
|
|
2937
2968
|
for (const t of tangents) t.dispose();
|
|
2938
2969
|
const ys = bind(primitive, primals, params);
|
|
2939
|
-
return [ys, ys.map((y) => zerosLike(y.ref))];
|
|
2970
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
2940
2971
|
};
|
|
2941
2972
|
}
|
|
2942
2973
|
const jvpRules = {
|
|
@@ -2954,13 +2985,13 @@ const jvpRules = {
|
|
|
2954
2985
|
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2955
2986
|
else {
|
|
2956
2987
|
dx.dispose();
|
|
2957
|
-
return [[cast(x.ref, dtype)], [zerosLike(x)]];
|
|
2988
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2958
2989
|
}
|
|
2959
2990
|
},
|
|
2960
2991
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2961
2992
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2962
2993
|
dx.dispose();
|
|
2963
|
-
return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
|
|
2994
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2964
2995
|
},
|
|
2965
2996
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2966
2997
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2969,6 +3000,14 @@ const jvpRules = {
|
|
|
2969
3000
|
[Primitive.Cos]([x], [dx]) {
|
|
2970
3001
|
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
2971
3002
|
},
|
|
3003
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3004
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3005
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3006
|
+
},
|
|
3007
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3008
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3009
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3010
|
+
},
|
|
2972
3011
|
[Primitive.Exp]([x], [dx]) {
|
|
2973
3012
|
const z = exp$1(x);
|
|
2974
3013
|
return [[z.ref], [z.mul(dx)]];
|
|
@@ -3085,7 +3124,10 @@ function mappedAval(batchDim, aval) {
|
|
|
3085
3124
|
/** Move one axis to a different index. */
|
|
3086
3125
|
function moveaxis$1(x, src, dst) {
|
|
3087
3126
|
const t = pureArray(x);
|
|
3088
|
-
|
|
3127
|
+
src = require_backend.checkAxis(src, t.ndim);
|
|
3128
|
+
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3129
|
+
if (src === dst) return t;
|
|
3130
|
+
const perm = require_backend.range(t.ndim);
|
|
3089
3131
|
perm.splice(src, 1);
|
|
3090
3132
|
perm.splice(dst, 0, src);
|
|
3091
3133
|
return transpose$1(t, perm);
|
|
@@ -3178,6 +3220,8 @@ const vmapRules = {
|
|
|
3178
3220
|
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3179
3221
|
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3180
3222
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3223
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3224
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3181
3225
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3182
3226
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3183
3227
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
@@ -3363,20 +3407,28 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3363
3407
|
function linearizeFlat(f, primalsIn) {
|
|
3364
3408
|
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
3365
3409
|
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3366
|
-
|
|
3410
|
+
const dispose$1 = () => {
|
|
3411
|
+
for (const c of consts) c.dispose();
|
|
3412
|
+
};
|
|
3413
|
+
return [
|
|
3414
|
+
primalsOut,
|
|
3415
|
+
fLin,
|
|
3416
|
+
dispose$1
|
|
3417
|
+
];
|
|
3367
3418
|
}
|
|
3368
3419
|
function linearize$1(f, ...primalsIn) {
|
|
3369
3420
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3370
3421
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3371
|
-
const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3422
|
+
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3372
3423
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
3373
3424
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3374
|
-
const fLin = (...tangentsIn) => {
|
|
3425
|
+
const fLin = ((...tangentsIn) => {
|
|
3375
3426
|
const [tangentsInFlat, inTree2] = flatten(tangentsIn);
|
|
3376
3427
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
|
|
3377
3428
|
const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
|
|
3378
3429
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
3379
|
-
};
|
|
3430
|
+
});
|
|
3431
|
+
fLin.dispose = dispose$1;
|
|
3380
3432
|
return [primalsOut, fLin];
|
|
3381
3433
|
}
|
|
3382
3434
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -3492,7 +3544,10 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3492
3544
|
avalsOut: jaxpr2.outs.map((x) => x.aval),
|
|
3493
3545
|
tracerRefsOut: []
|
|
3494
3546
|
};
|
|
3495
|
-
const outs2 = jaxpr2.outs.map((x) =>
|
|
3547
|
+
const outs2 = jaxpr2.outs.map((x, i$1) => {
|
|
3548
|
+
if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
|
|
3549
|
+
return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
|
|
3550
|
+
});
|
|
3496
3551
|
recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
|
|
3497
3552
|
let i = 0;
|
|
3498
3553
|
let j = 0;
|
|
@@ -3576,13 +3631,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3576
3631
|
const [consts, constvars] = require_backend.unzip2(constToVar.entries());
|
|
3577
3632
|
const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
|
|
3578
3633
|
const outVars = tracersOut.map((t) => tracerToVar.get(t));
|
|
3579
|
-
|
|
3634
|
+
let jaxpr = new Jaxpr(inBinders, eqns, outVars);
|
|
3580
3635
|
typecheckJaxpr(jaxpr);
|
|
3581
3636
|
for (const t of consts) t.ref;
|
|
3582
3637
|
for (const t of tracersIn) t.dispose();
|
|
3583
3638
|
for (const t of tracersOut) t.dispose();
|
|
3639
|
+
jaxpr = jaxpr.simplify();
|
|
3640
|
+
if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3584
3641
|
return {
|
|
3585
|
-
jaxpr
|
|
3642
|
+
jaxpr,
|
|
3586
3643
|
consts
|
|
3587
3644
|
};
|
|
3588
3645
|
}
|
|
@@ -3848,20 +3905,28 @@ function vjpFlat(f, primalsIn) {
|
|
|
3848
3905
|
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
3849
3906
|
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
3850
3907
|
};
|
|
3851
|
-
|
|
3908
|
+
const dispose$1 = () => {
|
|
3909
|
+
for (const c of consts) c.dispose();
|
|
3910
|
+
};
|
|
3911
|
+
return [
|
|
3912
|
+
primalsOut,
|
|
3913
|
+
fVjp,
|
|
3914
|
+
dispose$1
|
|
3915
|
+
];
|
|
3852
3916
|
}
|
|
3853
3917
|
function vjp$1(f, ...primalsIn) {
|
|
3854
3918
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3855
3919
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3856
|
-
const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3920
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3857
3921
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
3858
3922
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3859
|
-
const fVjp = (cotangentsOut) => {
|
|
3923
|
+
const fVjp = ((cotangentsOut) => {
|
|
3860
3924
|
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
3861
3925
|
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
3862
3926
|
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
3863
3927
|
return unflatten(inTree, cotangentsInFlat);
|
|
3864
|
-
};
|
|
3928
|
+
});
|
|
3929
|
+
fVjp.dispose = dispose$1;
|
|
3865
3930
|
return [primalsOut, fVjp];
|
|
3866
3931
|
}
|
|
3867
3932
|
function grad$1(f) {
|
|
@@ -3879,7 +3944,8 @@ function valueAndGrad$1(f) {
|
|
|
3879
3944
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3880
3945
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3881
3946
|
const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
|
|
3882
|
-
for (const r of rest)
|
|
3947
|
+
for (const r of rest) dispose(r);
|
|
3948
|
+
fVjp.dispose();
|
|
3883
3949
|
return [y, ct];
|
|
3884
3950
|
};
|
|
3885
3951
|
}
|
|
@@ -3887,7 +3953,13 @@ function jacrev$1(f) {
|
|
|
3887
3953
|
return function jacobianReverse(x) {
|
|
3888
3954
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
3889
3955
|
const [size$1] = x.shape;
|
|
3890
|
-
const pullback = (ct) =>
|
|
3956
|
+
const pullback = (ct) => {
|
|
3957
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
3958
|
+
y.dispose();
|
|
3959
|
+
const [ret] = fVjp(ct);
|
|
3960
|
+
fVjp.dispose();
|
|
3961
|
+
return ret;
|
|
3962
|
+
};
|
|
3891
3963
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3892
3964
|
};
|
|
3893
3965
|
}
|
|
@@ -3967,19 +4039,38 @@ __export(numpy_exports, {
|
|
|
3967
4039
|
DType: () => require_backend.DType,
|
|
3968
4040
|
abs: () => abs,
|
|
3969
4041
|
absolute: () => absolute,
|
|
4042
|
+
acos: () => acos,
|
|
4043
|
+
acosh: () => acosh,
|
|
3970
4044
|
add: () => add,
|
|
3971
4045
|
allclose: () => allclose,
|
|
3972
4046
|
arange: () => arange,
|
|
4047
|
+
arccos: () => arccos,
|
|
4048
|
+
arccosh: () => arccosh,
|
|
4049
|
+
arcsinh: () => arcsinh,
|
|
4050
|
+
arctan: () => arctan,
|
|
4051
|
+
arctan2: () => arctan2,
|
|
4052
|
+
arctanh: () => arctanh,
|
|
3973
4053
|
argmax: () => argmax,
|
|
3974
4054
|
argmin: () => argmin,
|
|
3975
4055
|
array: () => array,
|
|
4056
|
+
asin: () => asin,
|
|
4057
|
+
asinh: () => asinh,
|
|
3976
4058
|
astype: () => astype,
|
|
4059
|
+
atan: () => atan,
|
|
4060
|
+
atan2: () => atan2,
|
|
4061
|
+
atanh: () => atanh,
|
|
3977
4062
|
bool: () => bool,
|
|
4063
|
+
broadcastArrays: () => broadcastArrays,
|
|
4064
|
+
broadcastShapes: () => broadcastShapes,
|
|
4065
|
+
broadcastTo: () => broadcastTo,
|
|
4066
|
+
cbrt: () => cbrt,
|
|
3978
4067
|
clip: () => clip,
|
|
3979
4068
|
columnStack: () => columnStack,
|
|
3980
4069
|
concatenate: () => concatenate,
|
|
3981
4070
|
cos: () => cos,
|
|
3982
4071
|
cosh: () => cosh,
|
|
4072
|
+
deg2rad: () => deg2rad,
|
|
4073
|
+
degrees: () => degrees,
|
|
3983
4074
|
diag: () => diag,
|
|
3984
4075
|
diagonal: () => diagonal,
|
|
3985
4076
|
divide: () => divide,
|
|
@@ -3990,6 +4081,7 @@ __export(numpy_exports, {
|
|
|
3990
4081
|
eulerGamma: () => eulerGamma,
|
|
3991
4082
|
exp: () => exp,
|
|
3992
4083
|
exp2: () => exp2,
|
|
4084
|
+
expm1: () => expm1,
|
|
3993
4085
|
eye: () => eye,
|
|
3994
4086
|
flip: () => flip,
|
|
3995
4087
|
fliplr: () => fliplr,
|
|
@@ -4001,14 +4093,17 @@ __export(numpy_exports, {
|
|
|
4001
4093
|
greater: () => greater,
|
|
4002
4094
|
greaterEqual: () => greaterEqual,
|
|
4003
4095
|
hstack: () => hstack,
|
|
4096
|
+
hypot: () => hypot,
|
|
4004
4097
|
identity: () => identity$1,
|
|
4005
4098
|
inf: () => inf,
|
|
4099
|
+
inner: () => inner,
|
|
4006
4100
|
int32: () => int32,
|
|
4007
4101
|
less: () => less,
|
|
4008
4102
|
lessEqual: () => lessEqual,
|
|
4009
4103
|
linspace: () => linspace,
|
|
4010
4104
|
log: () => log,
|
|
4011
4105
|
log10: () => log10,
|
|
4106
|
+
log1p: () => log1p,
|
|
4012
4107
|
log2: () => log2,
|
|
4013
4108
|
matmul: () => matmul,
|
|
4014
4109
|
max: () => max,
|
|
@@ -4024,35 +4119,49 @@ __export(numpy_exports, {
|
|
|
4024
4119
|
negative: () => negative,
|
|
4025
4120
|
notEqual: () => notEqual,
|
|
4026
4121
|
ones: () => ones,
|
|
4027
|
-
onesLike: () => onesLike
|
|
4122
|
+
onesLike: () => onesLike,
|
|
4123
|
+
outer: () => outer,
|
|
4028
4124
|
pad: () => pad,
|
|
4029
4125
|
permuteDims: () => permuteDims,
|
|
4030
4126
|
pi: () => pi,
|
|
4127
|
+
pow: () => pow,
|
|
4128
|
+
power: () => power,
|
|
4031
4129
|
prod: () => prod$1,
|
|
4130
|
+
promoteTypes: () => require_backend.promoteTypes,
|
|
4131
|
+
rad2deg: () => rad2deg,
|
|
4132
|
+
radians: () => radians,
|
|
4032
4133
|
ravel: () => ravel,
|
|
4033
4134
|
reciprocal: () => reciprocal,
|
|
4135
|
+
repeat: () => repeat,
|
|
4034
4136
|
reshape: () => reshape,
|
|
4035
|
-
scalar: () => scalar,
|
|
4036
4137
|
shape: () => shape,
|
|
4138
|
+
sign: () => sign,
|
|
4037
4139
|
sin: () => sin,
|
|
4038
4140
|
sinh: () => sinh,
|
|
4039
4141
|
size: () => size,
|
|
4040
4142
|
sqrt: () => sqrt,
|
|
4041
4143
|
square: () => square,
|
|
4042
4144
|
stack: () => stack,
|
|
4145
|
+
std: () => std,
|
|
4146
|
+
subtract: () => subtract,
|
|
4043
4147
|
sum: () => sum,
|
|
4044
4148
|
tan: () => tan,
|
|
4045
4149
|
tanh: () => tanh,
|
|
4150
|
+
tile: () => tile,
|
|
4046
4151
|
transpose: () => transpose,
|
|
4152
|
+
tri: () => tri,
|
|
4153
|
+
tril: () => tril,
|
|
4154
|
+
triu: () => triu,
|
|
4047
4155
|
trueDivide: () => trueDivide,
|
|
4048
4156
|
trunc: () => trunc,
|
|
4049
4157
|
uint32: () => uint32,
|
|
4158
|
+
var_: () => var_,
|
|
4050
4159
|
vdot: () => vdot,
|
|
4051
4160
|
vecdot: () => vecdot,
|
|
4052
4161
|
vstack: () => vstack,
|
|
4053
4162
|
where: () => where,
|
|
4054
4163
|
zeros: () => zeros,
|
|
4055
|
-
zerosLike: () => zerosLike
|
|
4164
|
+
zerosLike: () => zerosLike
|
|
4056
4165
|
});
|
|
4057
4166
|
const float32 = require_backend.DType.Float32;
|
|
4058
4167
|
const int32 = require_backend.DType.Int32;
|
|
@@ -4069,54 +4178,66 @@ const inf = Number.POSITIVE_INFINITY;
|
|
|
4069
4178
|
const nan = NaN;
|
|
4070
4179
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
4071
4180
|
const pi = Math.PI;
|
|
4072
|
-
/** Element-wise addition, with broadcasting. */
|
|
4181
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
4073
4182
|
const add = add$1;
|
|
4074
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
4183
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
4075
4184
|
const multiply = mul;
|
|
4076
|
-
/** Numerical negative of every element of an array. */
|
|
4185
|
+
/** @function Numerical negative of every element of an array. */
|
|
4077
4186
|
const negative = neg;
|
|
4078
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4187
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4079
4188
|
const reciprocal = reciprocal$1;
|
|
4080
|
-
/** Element-wise sine function (takes radians). */
|
|
4189
|
+
/** @function Element-wise sine function (takes radians). */
|
|
4081
4190
|
const sin = sin$1;
|
|
4082
|
-
/** Element-wise cosine function (takes radians). */
|
|
4191
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
4083
4192
|
const cos = cos$1;
|
|
4084
|
-
/**
|
|
4193
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
4194
|
+
const asin = asin$1;
|
|
4195
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
4196
|
+
const atan = atan$1;
|
|
4197
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
4085
4198
|
const exp = exp$1;
|
|
4086
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
4199
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
4087
4200
|
const log = log$1;
|
|
4088
|
-
/** Calculate the square root of all elements in the input array. */
|
|
4201
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
4089
4202
|
const sqrt = sqrt$1;
|
|
4090
|
-
/** Return element-wise minimum of the input arrays. */
|
|
4203
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
4091
4204
|
const minimum = min$1;
|
|
4092
|
-
/** Return element-wise maximum of the input arrays. */
|
|
4205
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
4093
4206
|
const maximum = max$1;
|
|
4094
|
-
/** Compare two arrays element-wise. */
|
|
4207
|
+
/** @function Compare two arrays element-wise. */
|
|
4095
4208
|
const greater = greater$1;
|
|
4096
|
-
/** Compare two arrays element-wise. */
|
|
4209
|
+
/** @function Compare two arrays element-wise. */
|
|
4097
4210
|
const less = less$1;
|
|
4098
|
-
/** Compare two arrays element-wise. */
|
|
4211
|
+
/** @function Compare two arrays element-wise. */
|
|
4099
4212
|
const equal = equal$1;
|
|
4100
|
-
/** Compare two arrays element-wise. */
|
|
4213
|
+
/** @function Compare two arrays element-wise. */
|
|
4101
4214
|
const notEqual = notEqual$1;
|
|
4102
|
-
/** Compare two arrays element-wise. */
|
|
4215
|
+
/** @function Compare two arrays element-wise. */
|
|
4103
4216
|
const greaterEqual = greaterEqual$1;
|
|
4104
|
-
/** Compare two arrays element-wise. */
|
|
4217
|
+
/** @function Compare two arrays element-wise. */
|
|
4105
4218
|
const lessEqual = lessEqual$1;
|
|
4106
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4219
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4107
4220
|
const where = where$1;
|
|
4108
|
-
/**
|
|
4221
|
+
/**
|
|
4222
|
+
* @function
|
|
4223
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
4224
|
+
*/
|
|
4109
4225
|
const transpose = transpose$1;
|
|
4110
4226
|
/**
|
|
4227
|
+
* @function
|
|
4111
4228
|
* Give a new shape to an array without changing its data.
|
|
4112
4229
|
*
|
|
4113
4230
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
4114
4231
|
* length of the array and remaining dimensions.
|
|
4115
4232
|
*/
|
|
4116
4233
|
const reshape = reshape$1;
|
|
4117
|
-
/**
|
|
4234
|
+
/**
|
|
4235
|
+
* @function
|
|
4236
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
4237
|
+
*/
|
|
4118
4238
|
const moveaxis = moveaxis$1;
|
|
4119
4239
|
/**
|
|
4240
|
+
* @function
|
|
4120
4241
|
* Add padding (zeros) to an array.
|
|
4121
4242
|
*
|
|
4122
4243
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -4124,15 +4245,27 @@ const moveaxis = moveaxis$1;
|
|
|
4124
4245
|
* pair specifies the padding for its corresponding axis.
|
|
4125
4246
|
*/
|
|
4126
4247
|
const pad = pad$1;
|
|
4127
|
-
/**
|
|
4248
|
+
/**
|
|
4249
|
+
* @function
|
|
4250
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
4251
|
+
*/
|
|
4128
4252
|
const ndim = ndim$1;
|
|
4129
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
4253
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
4130
4254
|
const shape = getShape;
|
|
4131
|
-
/**
|
|
4132
|
-
|
|
4133
|
-
|
|
4134
|
-
|
|
4135
|
-
|
|
4255
|
+
/**
|
|
4256
|
+
* @function
|
|
4257
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
4258
|
+
*/
|
|
4259
|
+
const zerosLike = zerosLike$1;
|
|
4260
|
+
/**
|
|
4261
|
+
* @function
|
|
4262
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
4263
|
+
*/
|
|
4264
|
+
const onesLike = onesLike$1;
|
|
4265
|
+
/**
|
|
4266
|
+
* @function
|
|
4267
|
+
* Return a full array with the same shape and type as a given array.
|
|
4268
|
+
*/
|
|
4136
4269
|
const fullLike$1 = fullLike;
|
|
4137
4270
|
/**
|
|
4138
4271
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -4147,23 +4280,23 @@ function astype(a, dtype) {
|
|
|
4147
4280
|
return fudgeArray(a).astype(dtype);
|
|
4148
4281
|
}
|
|
4149
4282
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
4150
|
-
function sum(a, axis, opts) {
|
|
4283
|
+
function sum(a, axis = null, opts) {
|
|
4151
4284
|
return reduce(a, require_backend.AluOp.Add, axis, opts);
|
|
4152
4285
|
}
|
|
4153
4286
|
/** Product of the array elements over a given axis. */
|
|
4154
|
-
function prod$1(a, axis, opts) {
|
|
4287
|
+
function prod$1(a, axis = null, opts) {
|
|
4155
4288
|
return reduce(a, require_backend.AluOp.Mul, axis, opts);
|
|
4156
4289
|
}
|
|
4157
4290
|
/** Return the minimum of array elements along a given axis. */
|
|
4158
|
-
function min(a, axis, opts) {
|
|
4291
|
+
function min(a, axis = null, opts) {
|
|
4159
4292
|
return reduce(a, require_backend.AluOp.Min, axis, opts);
|
|
4160
4293
|
}
|
|
4161
4294
|
/** Return the maximum of array elements along a given axis. */
|
|
4162
|
-
function max(a, axis, opts) {
|
|
4295
|
+
function max(a, axis = null, opts) {
|
|
4163
4296
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
4164
4297
|
}
|
|
4165
4298
|
/** Compute the average of the array elements along the specified axis. */
|
|
4166
|
-
function mean(a, axis, opts) {
|
|
4299
|
+
function mean(a, axis = null, opts) {
|
|
4167
4300
|
return fudgeArray(a).mean(axis, opts);
|
|
4168
4301
|
}
|
|
4169
4302
|
/**
|
|
@@ -4179,7 +4312,7 @@ function argmin(a, axis, opts) {
|
|
|
4179
4312
|
axis = 0;
|
|
4180
4313
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4181
4314
|
const shape$1 = a.shape;
|
|
4182
|
-
const isMax = equal(a, min(a.ref, axis, {
|
|
4315
|
+
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4183
4316
|
const length = scalar(shape$1[axis], {
|
|
4184
4317
|
dtype: int32,
|
|
4185
4318
|
device: a.device
|
|
@@ -4203,7 +4336,7 @@ function argmax(a, axis, opts) {
|
|
|
4203
4336
|
axis = 0;
|
|
4204
4337
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4205
4338
|
const shape$1 = a.shape;
|
|
4206
|
-
const isMax = equal(a, max(a.ref, axis, {
|
|
4339
|
+
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4207
4340
|
const length = scalar(shape$1[axis], {
|
|
4208
4341
|
dtype: int32,
|
|
4209
4342
|
device: a.device
|
|
@@ -4215,17 +4348,9 @@ function argmax(a, axis, opts) {
|
|
|
4215
4348
|
return length.sub(max(idx, axis, opts));
|
|
4216
4349
|
}
|
|
4217
4350
|
/** Reverse the elements in an array along the given axes. */
|
|
4218
|
-
function flip(x, axis) {
|
|
4351
|
+
function flip(x, axis = null) {
|
|
4219
4352
|
const nd = ndim(x);
|
|
4220
|
-
|
|
4221
|
-
else if (typeof axis === "number") axis = [axis];
|
|
4222
|
-
const seen = /* @__PURE__ */ new Set();
|
|
4223
|
-
for (let i = 0; i < axis.length; i++) {
|
|
4224
|
-
if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
|
|
4225
|
-
if (axis[i] < 0) axis[i] += nd;
|
|
4226
|
-
if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
|
|
4227
|
-
seen.add(axis[i]);
|
|
4228
|
-
}
|
|
4353
|
+
axis = require_backend.normalizeAxis(axis, nd);
|
|
4229
4354
|
return flip$1(x, axis);
|
|
4230
4355
|
}
|
|
4231
4356
|
/**
|
|
@@ -4331,12 +4456,80 @@ function flipud(x) {
|
|
|
4331
4456
|
function fliplr(x) {
|
|
4332
4457
|
return flip(x, 1);
|
|
4333
4458
|
}
|
|
4459
|
+
/** @function Alternative name for `jax.numpy.transpose()`. */
|
|
4334
4460
|
const permuteDims = transpose;
|
|
4335
4461
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
4336
4462
|
function ravel(a) {
|
|
4337
4463
|
return fudgeArray(a).ravel();
|
|
4338
4464
|
}
|
|
4339
4465
|
/**
|
|
4466
|
+
* Repeat each element of an array `repeats` times, with the copies placed consecutively.
|
|
4467
|
+
*
|
|
4468
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
4469
|
+
* output array.
|
|
4470
|
+
*/
|
|
4471
|
+
function repeat(a, repeats, axis) {
|
|
4472
|
+
if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
|
|
4473
|
+
a = fudgeArray(a);
|
|
4474
|
+
if (axis === void 0) {
|
|
4475
|
+
a = ravel(a);
|
|
4476
|
+
axis = 0;
|
|
4477
|
+
}
|
|
4478
|
+
axis = require_backend.checkAxis(axis, a.ndim);
|
|
4479
|
+
if (repeats === 1) return a;
|
|
4480
|
+
const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
|
|
4481
|
+
const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
|
|
4482
|
+
return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
|
|
4483
|
+
}
|
|
4484
|
+
/**
|
|
4485
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
4486
|
+
*
|
|
4487
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
4488
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
4489
|
+
* reps[1] * d2, ..., reps[n-1] * dn)`, with `A` tiled along each dimension.
|
|
4490
|
+
*/
|
|
4491
|
+
function tile(a, reps) {
|
|
4492
|
+
a = fudgeArray(a);
|
|
4493
|
+
if (typeof reps === "number") reps = [reps];
|
|
4494
|
+
if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
|
|
4495
|
+
const ndiff = reps.length - a.ndim;
|
|
4496
|
+
if (ndiff > 0) a = a.reshape([...require_backend.rep(ndiff, 1), ...a.shape]);
|
|
4497
|
+
if (ndiff < 0) reps = [...require_backend.rep(-ndiff, 1), ...reps];
|
|
4498
|
+
const broadcastedShape = [];
|
|
4499
|
+
const broadcastAxes = [];
|
|
4500
|
+
for (let i = 0; i < a.ndim; i++) {
|
|
4501
|
+
if (reps[i] > 1) {
|
|
4502
|
+
broadcastedShape.push(reps[i]);
|
|
4503
|
+
broadcastAxes.push(broadcastedShape.length - 1);
|
|
4504
|
+
}
|
|
4505
|
+
broadcastedShape.push(a.shape[i]);
|
|
4506
|
+
}
|
|
4507
|
+
const finalShape = a.shape.map((d, i) => reps[i] * d);
|
|
4508
|
+
return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
|
|
4509
|
+
}
|
|
4510
|
+
/**
|
|
4511
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
4512
|
+
*
|
|
4513
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
4514
|
+
* dimensions where the shape is 1.
|
|
4515
|
+
*/
|
|
4516
|
+
function broadcastTo(a, shape$1) {
|
|
4517
|
+
const nd = ndim(a);
|
|
4518
|
+
if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
|
|
4519
|
+
return broadcast(a, shape$1, require_backend.range(shape$1.length - nd));
|
|
4520
|
+
}
|
|
4521
|
+
/** Broadcast input shapes to a common output shape. */
|
|
4522
|
+
function broadcastShapes(...shapes) {
|
|
4523
|
+
if (shapes.length === 0) return [];
|
|
4524
|
+
return shapes.reduce(generalBroadcast);
|
|
4525
|
+
}
|
|
4526
|
+
/** Broadcast arrays to a common shape. */
|
|
4527
|
+
function broadcastArrays(...arrays) {
|
|
4528
|
+
const shapes = arrays.map((a) => shape(a));
|
|
4529
|
+
const outShape = broadcastShapes(...shapes);
|
|
4530
|
+
return arrays.map((a) => broadcastTo(a, outShape));
|
|
4531
|
+
}
|
|
4532
|
+
/**
|
|
4340
4533
|
* Return specified diagonals.
|
|
4341
4534
|
*
|
|
4342
4535
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
@@ -4360,7 +4553,7 @@ function diag(v, k = 0) {
|
|
|
4360
4553
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
4361
4554
|
if (a.ndim === 1) {
|
|
4362
4555
|
const n = a.shape[0];
|
|
4363
|
-
const ret = where(eye(n).equal(1), a.ref, zerosLike
|
|
4556
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
4364
4557
|
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4365
4558
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4366
4559
|
else return ret;
|
|
@@ -4404,8 +4597,36 @@ function dot(x, y) {
|
|
|
4404
4597
|
]);
|
|
4405
4598
|
return dot$1(x, y);
|
|
4406
4599
|
}
|
|
4407
|
-
/**
|
|
4408
|
-
|
|
4600
|
+
/**
|
|
4601
|
+
* Compute the inner product of two arrays.
|
|
4602
|
+
*
|
|
4603
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
4604
|
+
* contraction on the last axis.
|
|
4605
|
+
*
|
|
4606
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
4607
|
+
*/
|
|
4608
|
+
function inner(x, y) {
|
|
4609
|
+
x = reshape(x, shape(x).toSpliced(-1, 0, ...require_backend.rep(ndim(y) - 1, 1)));
|
|
4610
|
+
return dot$1(x, y);
|
|
4611
|
+
}
|
|
4612
|
+
/**
|
|
4613
|
+
* Compute the outer product of two arrays.
|
|
4614
|
+
*
|
|
4615
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
4616
|
+
* be of shape `[x.size, y.size]`.
|
|
4617
|
+
*/
|
|
4618
|
+
function outer(x, y) {
|
|
4619
|
+
x = ravel(x);
|
|
4620
|
+
y = ravel(y);
|
|
4621
|
+
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
4622
|
+
}
|
|
4623
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
4624
|
+
function vecdot(x, y, { axis } = {}) {
|
|
4625
|
+
const xaxis = require_backend.checkAxis(axis ?? -1, ndim(x));
|
|
4626
|
+
const yaxis = require_backend.checkAxis(axis ?? -1, ndim(y));
|
|
4627
|
+
if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
|
|
4628
|
+
x = moveaxis(x, xaxis, -1);
|
|
4629
|
+
y = moveaxis(y, yaxis, -1);
|
|
4409
4630
|
return dot$1(x, y);
|
|
4410
4631
|
}
|
|
4411
4632
|
/**
|
|
@@ -4414,7 +4635,7 @@ function vecdot(x, y) {
|
|
|
4414
4635
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
4415
4636
|
*/
|
|
4416
4637
|
function vdot(x, y) {
|
|
4417
|
-
return
|
|
4638
|
+
return dot$1(ravel(x), ravel(y));
|
|
4418
4639
|
}
|
|
4419
4640
|
/**
|
|
4420
4641
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
@@ -4443,6 +4664,43 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
4443
4664
|
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
4444
4665
|
}
|
|
4445
4666
|
/**
|
|
4667
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
4668
|
+
*
|
|
4669
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
4670
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
4671
|
+
* `k>0` is above it.
|
|
4672
|
+
*/
|
|
4673
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
4674
|
+
m ??= n;
|
|
4675
|
+
dtype ??= require_backend.DType.Float32;
|
|
4676
|
+
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
4677
|
+
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
4678
|
+
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
4679
|
+
const rows = arange(k, n + k, 1, {
|
|
4680
|
+
dtype: require_backend.DType.Int32,
|
|
4681
|
+
device
|
|
4682
|
+
});
|
|
4683
|
+
const cols = arange(0, m, 1, {
|
|
4684
|
+
dtype: require_backend.DType.Int32,
|
|
4685
|
+
device
|
|
4686
|
+
});
|
|
4687
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
4688
|
+
}
|
|
4689
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
4690
|
+
function tril(a, k = 0) {
|
|
4691
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4692
|
+
a = fudgeArray(a);
|
|
4693
|
+
const [n, m] = a.shape.slice(-2);
|
|
4694
|
+
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
4695
|
+
}
|
|
4696
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
4697
|
+
function triu(a, k = 0) {
|
|
4698
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4699
|
+
a = fudgeArray(a);
|
|
4700
|
+
const [n, m] = a.shape.slice(-2);
|
|
4701
|
+
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
4702
|
+
}
|
|
4703
|
+
/**
|
|
4446
4704
|
* Clip (limit) the values in an array.
|
|
4447
4705
|
*
|
|
4448
4706
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -4466,18 +4724,70 @@ function absolute(x) {
|
|
|
4466
4724
|
x = fudgeArray(x);
|
|
4467
4725
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
4468
4726
|
}
|
|
4469
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
4727
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
4470
4728
|
const abs = absolute;
|
|
4729
|
+
/** Return an element-wise indication of sign of the input. */
|
|
4730
|
+
function sign(x) {
|
|
4731
|
+
x = fudgeArray(x);
|
|
4732
|
+
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4733
|
+
}
|
|
4471
4734
|
/** Calculate element-wise square of the input array. */
|
|
4472
4735
|
function square(x) {
|
|
4473
4736
|
x = fudgeArray(x);
|
|
4474
4737
|
return x.ref.mul(x);
|
|
4475
4738
|
}
|
|
4476
|
-
/**
|
|
4739
|
+
/** Element-wise tangent function (takes radians). */
|
|
4477
4740
|
function tan(x) {
|
|
4478
4741
|
x = fudgeArray(x);
|
|
4479
4742
|
return sin(x.ref).div(cos(x));
|
|
4480
4743
|
}
|
|
4744
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
4745
|
+
function acos(x) {
|
|
4746
|
+
return subtract(pi / 2, asin(x));
|
|
4747
|
+
}
|
|
4748
|
+
/**
|
|
4749
|
+
* @function
|
|
4750
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4751
|
+
*
|
|
4752
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4753
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4754
|
+
* improvements.
|
|
4755
|
+
*/
|
|
4756
|
+
const hypot = jit$1((x1, x2) => {
|
|
4757
|
+
return sqrt(square(x1).add(square(x2)));
|
|
4758
|
+
});
|
|
4759
|
+
/**
|
|
4760
|
+
* @function
|
|
4761
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
4762
|
+
*
|
|
4763
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
4764
|
+
* The result is in the range [-π, π].
|
|
4765
|
+
*
|
|
4766
|
+
* Uses numerically stable formulas:
|
|
4767
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
4768
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
4769
|
+
*
|
|
4770
|
+
* The output is ill-defined when both x and y are zero.
|
|
4771
|
+
*/
|
|
4772
|
+
const atan2 = jit$1((y, x) => {
|
|
4773
|
+
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4774
|
+
const xNeg = less(x.ref, 0);
|
|
4775
|
+
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
4776
|
+
const denom = where(xNeg, y, r.add(x));
|
|
4777
|
+
return atan(numer.div(denom)).mul(2);
|
|
4778
|
+
});
|
|
4779
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
4780
|
+
const arccos = acos;
|
|
4781
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
4782
|
+
const arctan = atan;
|
|
4783
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
4784
|
+
const arctan2 = atan2;
|
|
4785
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
4786
|
+
function subtract(x, y) {
|
|
4787
|
+
x = fudgeArray(x);
|
|
4788
|
+
y = fudgeArray(y);
|
|
4789
|
+
return x.sub(y);
|
|
4790
|
+
}
|
|
4481
4791
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
4482
4792
|
function trueDivide(x, y) {
|
|
4483
4793
|
x = fudgeArray(x);
|
|
@@ -4485,7 +4795,7 @@ function trueDivide(x, y) {
|
|
|
4485
4795
|
if (!require_backend.isFloatDtype(x.dtype) || !require_backend.isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
|
|
4486
4796
|
return x.div(y);
|
|
4487
4797
|
}
|
|
4488
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
4798
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
4489
4799
|
const divide = trueDivide;
|
|
4490
4800
|
/** Round input to the nearest integer towards zero. */
|
|
4491
4801
|
function trunc(x) {
|
|
@@ -4503,36 +4813,134 @@ function log2(x) {
|
|
|
4503
4813
|
function log10(x) {
|
|
4504
4814
|
return log(x).mul(Math.LOG10E);
|
|
4505
4815
|
}
|
|
4816
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
4817
|
+
function expm1(x) {
|
|
4818
|
+
return exp(x).sub(1);
|
|
4819
|
+
}
|
|
4820
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
4821
|
+
function log1p(x) {
|
|
4822
|
+
return log(add(1, x));
|
|
4823
|
+
}
|
|
4824
|
+
/** Convert angles from degrees to radians. */
|
|
4825
|
+
function deg2rad(x) {
|
|
4826
|
+
return multiply(x, pi / 180);
|
|
4827
|
+
}
|
|
4828
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
4829
|
+
const radians = deg2rad;
|
|
4830
|
+
/** Convert angles from radians to degrees. */
|
|
4831
|
+
function rad2deg(x) {
|
|
4832
|
+
return multiply(x, 180 / pi);
|
|
4833
|
+
}
|
|
4834
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
4835
|
+
const degrees = rad2deg;
|
|
4506
4836
|
/**
|
|
4837
|
+
* @function
|
|
4838
|
+
* Computes first array raised to power of second array, element-wise.
|
|
4839
|
+
*/
|
|
4840
|
+
const power = jit$1((x1, x2) => {
|
|
4841
|
+
return exp(log(x1).mul(x2));
|
|
4842
|
+
});
|
|
4843
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
4844
|
+
const pow = power;
|
|
4845
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
4846
|
+
const cbrt = jit$1((x) => {
|
|
4847
|
+
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4848
|
+
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4849
|
+
});
|
|
4850
|
+
/**
|
|
4851
|
+
* @function
|
|
4507
4852
|
* Calculate element-wise hyperbolic sine of input.
|
|
4508
4853
|
*
|
|
4509
4854
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4510
4855
|
*/
|
|
4511
|
-
|
|
4856
|
+
const sinh = jit$1((x) => {
|
|
4512
4857
|
const ex = exp(x);
|
|
4513
4858
|
const emx = reciprocal(ex.ref);
|
|
4514
4859
|
return ex.sub(emx).mul(.5);
|
|
4515
|
-
}
|
|
4860
|
+
});
|
|
4516
4861
|
/**
|
|
4862
|
+
* @function
|
|
4517
4863
|
* Calculate element-wise hyperbolic cosine of input.
|
|
4518
4864
|
*
|
|
4519
4865
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4520
4866
|
*/
|
|
4521
|
-
|
|
4867
|
+
const cosh = jit$1((x) => {
|
|
4522
4868
|
const ex = exp(x);
|
|
4523
4869
|
const emx = reciprocal(ex.ref);
|
|
4524
4870
|
return ex.add(emx).mul(.5);
|
|
4525
|
-
}
|
|
4871
|
+
});
|
|
4526
4872
|
/**
|
|
4873
|
+
* @function
|
|
4527
4874
|
* Calculate element-wise hyperbolic tangent of input.
|
|
4528
4875
|
*
|
|
4529
4876
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4530
4877
|
*/
|
|
4531
|
-
|
|
4532
|
-
x = fudgeArray(x);
|
|
4878
|
+
const tanh = jit$1((x) => {
|
|
4533
4879
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4534
4880
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4535
4881
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
4882
|
+
});
|
|
4883
|
+
/**
|
|
4884
|
+
* @function
|
|
4885
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
4886
|
+
*
|
|
4887
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4888
|
+
*/
|
|
4889
|
+
const arcsinh = jit$1((x) => {
|
|
4890
|
+
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4891
|
+
});
|
|
4892
|
+
/**
|
|
4893
|
+
* @function
|
|
4894
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
4895
|
+
*
|
|
4896
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4897
|
+
*/
|
|
4898
|
+
const arccosh = jit$1((x) => {
|
|
4899
|
+
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4900
|
+
});
|
|
4901
|
+
/**
|
|
4902
|
+
* @function
|
|
4903
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
4904
|
+
*
|
|
4905
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4906
|
+
*/
|
|
4907
|
+
const arctanh = jit$1((x) => {
|
|
4908
|
+
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4909
|
+
});
|
|
4910
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
4911
|
+
const asinh = arcsinh;
|
|
4912
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
4913
|
+
const acosh = arccosh;
|
|
4914
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
4915
|
+
const atanh = arctanh;
|
|
4916
|
+
/**
|
|
4917
|
+
* Compute the variance of an array.
|
|
4918
|
+
*
|
|
4919
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
4920
|
+
* the specified axis.
|
|
4921
|
+
*
|
|
4922
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4923
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4924
|
+
*/
|
|
4925
|
+
function var_(x, axis = null, opts) {
|
|
4926
|
+
x = fudgeArray(x);
|
|
4927
|
+
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
4928
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
4929
|
+
if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
|
|
4930
|
+
const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
|
|
4931
|
+
return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
|
|
4932
|
+
}
|
|
4933
|
+
/**
|
|
4934
|
+
* Compute the standard deviation of an array.
|
|
4935
|
+
*
|
|
4936
|
+
* The standard deviation is computed for the flattened array by default,
|
|
4937
|
+
* otherwise over the specified axis.
|
|
4938
|
+
*
|
|
4939
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4940
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4941
|
+
*/
|
|
4942
|
+
function std(x, axis = null, opts) {
|
|
4943
|
+
return sqrt(var_(x, axis, opts));
|
|
4536
4944
|
}
|
|
4537
4945
|
|
|
4538
4946
|
//#endregion
|
|
@@ -4547,6 +4955,7 @@ __export(nn_exports, {
|
|
|
4547
4955
|
leakyRelu: () => leakyRelu,
|
|
4548
4956
|
logSigmoid: () => logSigmoid,
|
|
4549
4957
|
logSoftmax: () => logSoftmax,
|
|
4958
|
+
logmeanexp: () => logmeanexp,
|
|
4550
4959
|
logsumexp: () => logsumexp,
|
|
4551
4960
|
mish: () => mish,
|
|
4552
4961
|
oneHot: () => oneHot,
|
|
@@ -4557,6 +4966,8 @@ __export(nn_exports, {
|
|
|
4557
4966
|
softSign: () => softSign,
|
|
4558
4967
|
softmax: () => softmax,
|
|
4559
4968
|
softplus: () => softplus,
|
|
4969
|
+
squareplus: () => squareplus,
|
|
4970
|
+
standardize: () => standardize,
|
|
4560
4971
|
swish: () => swish
|
|
4561
4972
|
});
|
|
4562
4973
|
/**
|
|
@@ -4600,6 +5011,7 @@ function softSign(x) {
|
|
|
4600
5011
|
return x.ref.div(absolute(x).add(1));
|
|
4601
5012
|
}
|
|
4602
5013
|
/**
|
|
5014
|
+
* @function
|
|
4603
5015
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4604
5016
|
* Swish, computed element-wise:
|
|
4605
5017
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4610,6 +5022,7 @@ function softSign(x) {
|
|
|
4610
5022
|
*/
|
|
4611
5023
|
const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
|
|
4612
5024
|
/**
|
|
5025
|
+
* @function
|
|
4613
5026
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4614
5027
|
* Swish, computed element-wise:
|
|
4615
5028
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4626,7 +5039,10 @@ const swish = silu;
|
|
|
4626
5039
|
function logSigmoid(x) {
|
|
4627
5040
|
return negative(softplus(negative(x)));
|
|
4628
5041
|
}
|
|
4629
|
-
/**
|
|
5042
|
+
/**
|
|
5043
|
+
* @function
|
|
5044
|
+
* Identity activation function. Returns the argument unmodified.
|
|
5045
|
+
*/
|
|
4630
5046
|
const identity = fudgeArray;
|
|
4631
5047
|
/** Leaky rectified linear (ReLU) activation function */
|
|
4632
5048
|
function leakyRelu(x, negativeSlope = .01) {
|
|
@@ -4654,6 +5070,7 @@ function celu(x, alpha = 1) {
|
|
|
4654
5070
|
return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
|
|
4655
5071
|
}
|
|
4656
5072
|
/**
|
|
5073
|
+
* @function
|
|
4657
5074
|
* Gaussion error linear unit (GELU) activation function.
|
|
4658
5075
|
*
|
|
4659
5076
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -4685,6 +5102,16 @@ function glu(x, axis = -1) {
|
|
|
4685
5102
|
return a.mul(sigmoid(b));
|
|
4686
5103
|
}
|
|
4687
5104
|
/**
|
|
5105
|
+
* Squareplus activation function.
|
|
5106
|
+
*
|
|
5107
|
+
* Computes the element-wise function:
|
|
5108
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
5109
|
+
*/
|
|
5110
|
+
function squareplus(x, b = 4) {
|
|
5111
|
+
x = fudgeArray(x);
|
|
5112
|
+
return x.ref.add(sqrt(square(x).add(b))).mul(.5);
|
|
5113
|
+
}
|
|
5114
|
+
/**
|
|
4688
5115
|
* Mish activation function.
|
|
4689
5116
|
*
|
|
4690
5117
|
* Computes the element-wise function:
|
|
@@ -4702,17 +5129,13 @@ function mish(x) {
|
|
|
4702
5129
|
*
|
|
4703
5130
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
4704
5131
|
*/
|
|
4705
|
-
function softmax(x, axis) {
|
|
5132
|
+
function softmax(x, axis = -1) {
|
|
4706
5133
|
x = fudgeArray(x);
|
|
4707
|
-
|
|
4708
|
-
|
|
4709
|
-
|
|
4710
|
-
x.dispose();
|
|
4711
|
-
return ones(x.shape);
|
|
4712
|
-
}
|
|
4713
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5134
|
+
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
5135
|
+
if (axis.length === 0) return onesLike(x);
|
|
5136
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4714
5137
|
const unnormalized = exp(x.sub(stopGradient(xMax)));
|
|
4715
|
-
return unnormalized.ref.div(unnormalized.sum(axis, {
|
|
5138
|
+
return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
|
|
4716
5139
|
}
|
|
4717
5140
|
/**
|
|
4718
5141
|
* Log-Softmax function.
|
|
@@ -4722,17 +5145,13 @@ function softmax(x, axis) {
|
|
|
4722
5145
|
*
|
|
4723
5146
|
* If `axis` is not specified, it defaults to the last axis.
|
|
4724
5147
|
*/
|
|
4725
|
-
function logSoftmax(x, axis) {
|
|
5148
|
+
function logSoftmax(x, axis = -1) {
|
|
4726
5149
|
x = fudgeArray(x);
|
|
4727
|
-
|
|
4728
|
-
|
|
4729
|
-
|
|
4730
|
-
x.dispose();
|
|
4731
|
-
return zeros(x.shape);
|
|
4732
|
-
}
|
|
4733
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5150
|
+
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
5151
|
+
if (axis.length === 0) return zerosLike(x);
|
|
5152
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4734
5153
|
const shifted = x.sub(stopGradient(xMax));
|
|
4735
|
-
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, {
|
|
5154
|
+
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
|
|
4736
5155
|
return shifted.sub(shiftedLogsumexp);
|
|
4737
5156
|
}
|
|
4738
5157
|
/**
|
|
@@ -4743,16 +5162,39 @@ function logSoftmax(x, axis) {
|
|
|
4743
5162
|
*
|
|
4744
5163
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
4745
5164
|
*/
|
|
4746
|
-
function logsumexp(x, axis) {
|
|
5165
|
+
function logsumexp(x, axis = null) {
|
|
4747
5166
|
x = fudgeArray(x);
|
|
4748
|
-
|
|
4749
|
-
else if (typeof axis === "number") axis = [axis];
|
|
5167
|
+
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
4750
5168
|
if (axis.length === 0) return x;
|
|
4751
5169
|
const xMax = stopGradient(max(x.ref, axis));
|
|
4752
5170
|
const xMaxDims = broadcast(xMax.ref, x.shape, axis);
|
|
4753
5171
|
const shifted = x.sub(xMaxDims);
|
|
4754
5172
|
return xMax.add(log(exp(shifted).sum(axis)));
|
|
4755
5173
|
}
|
|
5174
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
5175
|
+
function logmeanexp(x, axis = null) {
|
|
5176
|
+
x = fudgeArray(x);
|
|
5177
|
+
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
5178
|
+
if (axis.length === 0) return x;
|
|
5179
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
5180
|
+
return logsumexp(x, axis).sub(Math.log(n));
|
|
5181
|
+
}
|
|
5182
|
+
/**
|
|
5183
|
+
* Standardizes input to zero mean and unit variance.
|
|
5184
|
+
*
|
|
5185
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
5186
|
+
* axis, or `null` to standardize over all elements.
|
|
5187
|
+
*
|
|
5188
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
5189
|
+
*/
|
|
5190
|
+
function standardize(x, axis = -1, opts = {}) {
|
|
5191
|
+
x = fudgeArray(x);
|
|
5192
|
+
axis = require_backend.normalizeAxis(axis, x.ndim);
|
|
5193
|
+
if (axis.length === 0) return x;
|
|
5194
|
+
const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
|
|
5195
|
+
const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
|
|
5196
|
+
return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
|
|
5197
|
+
}
|
|
4756
5198
|
/**
|
|
4757
5199
|
* One-hot encodes the given indices.
|
|
4758
5200
|
*
|
|
@@ -4770,7 +5212,7 @@ function logsumexp(x, axis) {
|
|
|
4770
5212
|
* ```
|
|
4771
5213
|
*/
|
|
4772
5214
|
function oneHot(x, numClasses) {
|
|
4773
|
-
if (x.dtype
|
|
5215
|
+
if (require_backend.isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
4774
5216
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
4775
5217
|
}
|
|
4776
5218
|
|
|
@@ -4778,8 +5220,11 @@ function oneHot(x, numClasses) {
|
|
|
4778
5220
|
//#region src/random.ts
|
|
4779
5221
|
var random_exports = {};
|
|
4780
5222
|
__export(random_exports, {
|
|
5223
|
+
bernoulli: () => bernoulli,
|
|
4781
5224
|
bits: () => bits,
|
|
5225
|
+
exponential: () => exponential,
|
|
4782
5226
|
key: () => key,
|
|
5227
|
+
normal: () => normal,
|
|
4783
5228
|
split: () => split,
|
|
4784
5229
|
uniform: () => uniform
|
|
4785
5230
|
});
|
|
@@ -4810,11 +5255,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
4810
5255
|
/** Sample uniform random values in [minval, maxval) with given shape. */
|
|
4811
5256
|
function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
4812
5257
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
4813
|
-
const mantissa = bits(key$1, shape$1).div(
|
|
5258
|
+
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
4814
5259
|
dtype: require_backend.DType.Uint32,
|
|
4815
5260
|
device: key$1.device
|
|
4816
5261
|
}));
|
|
4817
|
-
const float12 = mantissa.add(
|
|
5262
|
+
const float12 = mantissa.add(array(1065353216, {
|
|
4818
5263
|
dtype: require_backend.DType.Uint32,
|
|
4819
5264
|
device: key$1.device
|
|
4820
5265
|
}));
|
|
@@ -4822,6 +5267,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
4822
5267
|
if (minval === 0 && maxval === 1) return rand;
|
|
4823
5268
|
else return rand.mul(maxval - minval).add(minval);
|
|
4824
5269
|
}
|
|
5270
|
+
/**
|
|
5271
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5272
|
+
*
|
|
5273
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
5274
|
+
* and must be broadcastable to `shape`.
|
|
5275
|
+
*/
|
|
5276
|
+
function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
5277
|
+
p = fudgeArray(p);
|
|
5278
|
+
return uniform(key$1, shape$1).less(p);
|
|
5279
|
+
}
|
|
5280
|
+
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
5281
|
+
function exponential(key$1, shape$1 = []) {
|
|
5282
|
+
const u = uniform(key$1, shape$1);
|
|
5283
|
+
return negative(log1p(negative(u)));
|
|
5284
|
+
}
|
|
5285
|
+
/**
|
|
5286
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5287
|
+
*
|
|
5288
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5289
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5290
|
+
* bitwise identical to JAX.
|
|
5291
|
+
*/
|
|
5292
|
+
function normal(key$1, shape$1 = []) {
|
|
5293
|
+
const [k1, k2] = split(key$1, 2);
|
|
5294
|
+
const u1 = uniform(k1, shape$1);
|
|
5295
|
+
const u2 = uniform(k2, shape$1);
|
|
5296
|
+
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5297
|
+
const theta = u2.mul(2 * Math.PI);
|
|
5298
|
+
return radius.mul(cos(theta));
|
|
5299
|
+
}
|
|
4825
5300
|
|
|
4826
5301
|
//#endregion
|
|
4827
5302
|
//#region src/polyfills.ts
|
|
@@ -4831,20 +5306,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
4831
5306
|
|
|
4832
5307
|
//#endregion
|
|
4833
5308
|
//#region src/index.ts
|
|
4834
|
-
/**
|
|
5309
|
+
/**
|
|
5310
|
+
* @function
|
|
5311
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
5312
|
+
*/
|
|
4835
5313
|
const jvp = jvp$1;
|
|
4836
|
-
/**
|
|
5314
|
+
/**
|
|
5315
|
+
* @function
|
|
5316
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
5317
|
+
*/
|
|
4837
5318
|
const vmap = vmap$1;
|
|
4838
|
-
/**
|
|
5319
|
+
/**
|
|
5320
|
+
* @function
|
|
5321
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
5322
|
+
*/
|
|
4839
5323
|
const jacfwd = jacfwd$1;
|
|
4840
|
-
/**
|
|
5324
|
+
/**
|
|
5325
|
+
* @function
|
|
5326
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
5327
|
+
*/
|
|
4841
5328
|
const makeJaxpr = makeJaxpr$1;
|
|
4842
5329
|
/**
|
|
5330
|
+
* @function
|
|
4843
5331
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
4844
5332
|
*
|
|
4845
5333
|
* The function will be compiled the first time it is called with a set of
|
|
4846
5334
|
* argument shapes.
|
|
4847
5335
|
*
|
|
5336
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
5337
|
+
* calls to free memory associated with array constants.
|
|
5338
|
+
*
|
|
4848
5339
|
* **Options:**
|
|
4849
5340
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
4850
5341
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -4854,26 +5345,59 @@ const makeJaxpr = makeJaxpr$1;
|
|
|
4854
5345
|
*/
|
|
4855
5346
|
const jit = jit$1;
|
|
4856
5347
|
/**
|
|
5348
|
+
* @function
|
|
4857
5349
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
4858
5350
|
* partial evaluation.
|
|
4859
5351
|
*/
|
|
4860
5352
|
const linearize = linearize$1;
|
|
4861
|
-
/**
|
|
5353
|
+
/**
|
|
5354
|
+
* @function
|
|
5355
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
5356
|
+
*/
|
|
4862
5357
|
const vjp = vjp$1;
|
|
4863
5358
|
/**
|
|
5359
|
+
* @function
|
|
4864
5360
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
4865
5361
|
* first argument.
|
|
4866
5362
|
*/
|
|
4867
5363
|
const grad = grad$1;
|
|
4868
|
-
/**
|
|
5364
|
+
/**
|
|
5365
|
+
* @function
|
|
5366
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
5367
|
+
*/
|
|
4869
5368
|
const valueAndGrad = valueAndGrad$1;
|
|
4870
|
-
/**
|
|
5369
|
+
/**
|
|
5370
|
+
* @function
|
|
5371
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
5372
|
+
*/
|
|
4871
5373
|
const jacrev = jacrev$1;
|
|
4872
|
-
/**
|
|
5374
|
+
/**
|
|
5375
|
+
* @function
|
|
5376
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
5377
|
+
*/
|
|
4873
5378
|
const jacobian = jacrev;
|
|
5379
|
+
/**
|
|
5380
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
5381
|
+
*
|
|
5382
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
5383
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
5384
|
+
* to avoid queueing up too many pending operations.
|
|
5385
|
+
*
|
|
5386
|
+
* Does not consume reference to the arrays.
|
|
5387
|
+
*/
|
|
5388
|
+
async function blockUntilReady(x) {
|
|
5389
|
+
const promises = [];
|
|
5390
|
+
for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
|
|
5391
|
+
await Promise.all(promises);
|
|
5392
|
+
return x;
|
|
5393
|
+
}
|
|
4874
5394
|
|
|
4875
5395
|
//#endregion
|
|
5396
|
+
exports.Array = Array$1;
|
|
4876
5397
|
exports.DType = require_backend.DType;
|
|
5398
|
+
exports.Jaxpr = Jaxpr;
|
|
5399
|
+
exports.blockUntilReady = blockUntilReady;
|
|
5400
|
+
exports.defaultDevice = require_backend.defaultDevice;
|
|
4877
5401
|
exports.devices = require_backend.devices;
|
|
4878
5402
|
exports.grad = grad;
|
|
4879
5403
|
exports.init = require_backend.init;
|
|
@@ -4908,7 +5432,7 @@ Object.defineProperty(exports, 'random', {
|
|
|
4908
5432
|
return random_exports;
|
|
4909
5433
|
}
|
|
4910
5434
|
});
|
|
4911
|
-
exports.
|
|
5435
|
+
exports.setDebug = require_backend.setDebug;
|
|
4912
5436
|
Object.defineProperty(exports, 'tree', {
|
|
4913
5437
|
enumerable: true,
|
|
4914
5438
|
get: function () {
|