@jax-js/jax 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache,
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -323,6 +323,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
323
323
|
Primitive$1["RandomBits"] = "random_bits";
|
|
324
324
|
Primitive$1["Sin"] = "sin";
|
|
325
325
|
Primitive$1["Cos"] = "cos";
|
|
326
|
+
Primitive$1["Asin"] = "asin";
|
|
327
|
+
Primitive$1["Atan"] = "atan";
|
|
326
328
|
Primitive$1["Exp"] = "exp";
|
|
327
329
|
Primitive$1["Log"] = "log";
|
|
328
330
|
Primitive$1["Sqrt"] = "sqrt";
|
|
@@ -390,6 +392,12 @@ function sin$1(x) {
|
|
|
390
392
|
function cos$1(x) {
|
|
391
393
|
return bind1(Primitive.Cos, [x]);
|
|
392
394
|
}
|
|
395
|
+
function asin$1(x) {
|
|
396
|
+
return bind1(Primitive.Asin, [x]);
|
|
397
|
+
}
|
|
398
|
+
function atan$1(x) {
|
|
399
|
+
return bind1(Primitive.Atan, [x]);
|
|
400
|
+
}
|
|
393
401
|
function exp$1(x) {
|
|
394
402
|
return bind1(Primitive.Exp, [x]);
|
|
395
403
|
}
|
|
@@ -405,18 +413,16 @@ function min$1(x, y) {
|
|
|
405
413
|
function max$1(x, y) {
|
|
406
414
|
return bind1(Primitive.Max, [x, y]);
|
|
407
415
|
}
|
|
408
|
-
function reduce(x, op, axis, opts) {
|
|
416
|
+
function reduce(x, op, axis = null, opts) {
|
|
409
417
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
410
|
-
|
|
411
|
-
else axis = [];
|
|
412
|
-
else if (typeof axis === "number") axis = [checkAxis(axis, ndim$1(x))];
|
|
413
|
-
else axis = axis.map((a) => checkAxis(a, ndim$1(x)));
|
|
418
|
+
axis = normalizeAxis(axis, ndim$1(x));
|
|
414
419
|
const originalShape = getShape(x);
|
|
415
|
-
|
|
420
|
+
let result = bind1(Primitive.Reduce, [x], {
|
|
416
421
|
op,
|
|
417
422
|
axis
|
|
418
423
|
});
|
|
419
|
-
|
|
424
|
+
if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
|
|
425
|
+
return result;
|
|
420
426
|
}
|
|
421
427
|
function dot$1(x, y) {
|
|
422
428
|
return bind1(Primitive.Dot, [x, y]);
|
|
@@ -462,10 +468,11 @@ function where$1(cond, x, y) {
|
|
|
462
468
|
}
|
|
463
469
|
function transpose$1(x, perm) {
|
|
464
470
|
perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
|
|
471
|
+
if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
465
472
|
return bind1(Primitive.Transpose, [x], { perm });
|
|
466
473
|
}
|
|
467
474
|
function broadcast(x, shape$1, axis) {
|
|
468
|
-
axis = axis
|
|
475
|
+
axis = normalizeAxis(axis, shape$1.length);
|
|
469
476
|
return bind1(Primitive.Broadcast, [x], {
|
|
470
477
|
shape: shape$1,
|
|
471
478
|
axis
|
|
@@ -484,7 +491,7 @@ function reshape$1(x, shape$1) {
|
|
|
484
491
|
return bind1(Primitive.Reshape, [x], { shape: shape$1 });
|
|
485
492
|
}
|
|
486
493
|
function flip$1(x, axis) {
|
|
487
|
-
axis = axis
|
|
494
|
+
axis = normalizeAxis(axis, ndim$1(x));
|
|
488
495
|
return bind1(Primitive.Flip, [x], { axis });
|
|
489
496
|
}
|
|
490
497
|
function shrink(x, slice) {
|
|
@@ -558,21 +565,49 @@ var Trace = class {
|
|
|
558
565
|
this.main = main;
|
|
559
566
|
}
|
|
560
567
|
};
|
|
568
|
+
/**
|
|
569
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
570
|
+
*
|
|
571
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
572
|
+
* implemented in that function as `weakType` is not passed.
|
|
573
|
+
*/
|
|
574
|
+
function promoteAvals(a, b) {
|
|
575
|
+
const shape$1 = generalBroadcast(a.shape, b.shape);
|
|
576
|
+
const weakType = a.weakType && b.weakType;
|
|
577
|
+
let dtype;
|
|
578
|
+
if (a.weakType === b.weakType) dtype = promoteTypes(a.dtype, b.dtype);
|
|
579
|
+
else if (a.weakType) dtype = promoteTypes(b.dtype, DType.Uint32);
|
|
580
|
+
else dtype = promoteTypes(a.dtype, DType.Uint32);
|
|
581
|
+
return new ShapedArray(shape$1, dtype, weakType);
|
|
582
|
+
}
|
|
561
583
|
var Tracer = class Tracer {
|
|
562
584
|
/** @ignore */
|
|
563
585
|
_trace;
|
|
564
586
|
constructor(trace) {
|
|
565
587
|
this._trace = trace;
|
|
566
588
|
}
|
|
589
|
+
/** The shape of the array. */
|
|
567
590
|
get shape() {
|
|
568
591
|
return this.aval.shape;
|
|
569
592
|
}
|
|
593
|
+
/** The total number of elements in the array. */
|
|
570
594
|
get size() {
|
|
571
595
|
return prod(this.shape);
|
|
572
596
|
}
|
|
597
|
+
/** The dtype of elements stored in the array. */
|
|
573
598
|
get dtype() {
|
|
574
599
|
return this.aval.dtype;
|
|
575
600
|
}
|
|
601
|
+
/**
|
|
602
|
+
* Whether the array is weakly typed.
|
|
603
|
+
*
|
|
604
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
605
|
+
* `promoteTypes()` for details.
|
|
606
|
+
*/
|
|
607
|
+
get weakType() {
|
|
608
|
+
return this.aval.weakType;
|
|
609
|
+
}
|
|
610
|
+
/** The number of dimensions of the array. */
|
|
576
611
|
get ndim() {
|
|
577
612
|
return this.shape.length;
|
|
578
613
|
}
|
|
@@ -608,22 +643,20 @@ var Tracer = class Tracer {
|
|
|
608
643
|
return lessEqual$1(this, other);
|
|
609
644
|
}
|
|
610
645
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
611
|
-
sum(axis, opts) {
|
|
646
|
+
sum(axis = null, opts) {
|
|
612
647
|
return reduce(this, AluOp.Add, axis, opts);
|
|
613
648
|
}
|
|
614
649
|
/** Product of the array elements over a given axis. */
|
|
615
|
-
prod(axis, opts) {
|
|
650
|
+
prod(axis = null, opts) {
|
|
616
651
|
return reduce(this, AluOp.Mul, axis, opts);
|
|
617
652
|
}
|
|
618
653
|
/** Compute the average of the array elements along the specified axis. */
|
|
619
|
-
mean(axis, opts) {
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
626
|
-
return result;
|
|
654
|
+
mean(axis = null, opts) {
|
|
655
|
+
axis = normalizeAxis(axis, this.ndim);
|
|
656
|
+
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
657
|
+
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
658
|
+
const result = reduce(this, AluOp.Add, axis, opts);
|
|
659
|
+
return result.mul(1 / n);
|
|
627
660
|
}
|
|
628
661
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
629
662
|
transpose(perm) {
|
|
@@ -810,12 +843,13 @@ function getShape(x) {
|
|
|
810
843
|
return x instanceof Tracer ? x.shape : [];
|
|
811
844
|
}
|
|
812
845
|
var ShapedArray = class ShapedArray {
|
|
813
|
-
constructor(shape$1, dtype) {
|
|
846
|
+
constructor(shape$1, dtype, weakType) {
|
|
814
847
|
this.shape = shape$1;
|
|
815
848
|
this.dtype = dtype;
|
|
849
|
+
this.weakType = weakType;
|
|
816
850
|
}
|
|
817
851
|
static fromAval(aval) {
|
|
818
|
-
return new ShapedArray(aval.shape, aval.dtype);
|
|
852
|
+
return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
|
|
819
853
|
}
|
|
820
854
|
get ndim() {
|
|
821
855
|
return this.shape.length;
|
|
@@ -829,7 +863,7 @@ var ShapedArray = class ShapedArray {
|
|
|
829
863
|
};
|
|
830
864
|
function getAval(x) {
|
|
831
865
|
if (x instanceof Tracer) return x.aval;
|
|
832
|
-
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32);
|
|
866
|
+
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32, typeof x === "boolean" ? false : true);
|
|
833
867
|
else throw new TypeError(`Unknown value: ${x}`);
|
|
834
868
|
}
|
|
835
869
|
function bind(prim, args, params = {}) {
|
|
@@ -1151,11 +1185,13 @@ const jitRules = {
|
|
|
1151
1185
|
const k1 = reshapeViews(keys[1], mapping);
|
|
1152
1186
|
const c0 = AluExp.u32(0);
|
|
1153
1187
|
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1154
|
-
const exp$2 = AluExp.threefry2x32(
|
|
1188
|
+
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1155
1189
|
return new Kernel(nargs, prod(shape$1), exp$2);
|
|
1156
1190
|
},
|
|
1157
1191
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
1158
1192
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
1193
|
+
[Primitive.Asin]: unopJit(AluExp.asin),
|
|
1194
|
+
[Primitive.Atan]: unopJit(AluExp.atan),
|
|
1159
1195
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
1160
1196
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1161
1197
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
@@ -1190,7 +1226,7 @@ const jitRules = {
|
|
|
1190
1226
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
1191
1227
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
1192
1228
|
const c = k1.exp;
|
|
1193
|
-
const cs =
|
|
1229
|
+
const cs = promoteAvals(as, bs);
|
|
1194
1230
|
return jitRules[Primitive.Reduce](nargs, [c], [cs], {
|
|
1195
1231
|
op: AluOp.Add,
|
|
1196
1232
|
axis: [cs.ndim - 1]
|
|
@@ -1200,8 +1236,8 @@ const jitRules = {
|
|
|
1200
1236
|
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1201
1237
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1202
1238
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1203
|
-
as = new ShapedArray(stX.shape, as.dtype);
|
|
1204
|
-
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1239
|
+
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1240
|
+
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1205
1241
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1206
1242
|
},
|
|
1207
1243
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
@@ -1254,9 +1290,10 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1254
1290
|
Primitive.Conv,
|
|
1255
1291
|
Primitive.PoolTranspose
|
|
1256
1292
|
];
|
|
1293
|
+
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
1257
1294
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1258
1295
|
const eqn = jaxpr.eqns[i];
|
|
1259
|
-
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive
|
|
1296
|
+
if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1260
1297
|
for (const v of eqn.outBinders) {
|
|
1261
1298
|
blackNodes.add(v);
|
|
1262
1299
|
p1NextBlack.set(v, v);
|
|
@@ -1386,6 +1423,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1386
1423
|
static #nextId = 1001;
|
|
1387
1424
|
id;
|
|
1388
1425
|
#dtype;
|
|
1426
|
+
#weakType;
|
|
1389
1427
|
#source;
|
|
1390
1428
|
#st;
|
|
1391
1429
|
#backend;
|
|
@@ -1397,19 +1435,22 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1397
1435
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1398
1436
|
* will be freed when the array is disposed.
|
|
1399
1437
|
*/
|
|
1400
|
-
constructor(
|
|
1438
|
+
constructor(args) {
|
|
1401
1439
|
super(baseArrayTrace);
|
|
1402
1440
|
this.id = Array$1.#nextId++;
|
|
1403
|
-
this.#dtype = dtype;
|
|
1404
|
-
this.#
|
|
1405
|
-
this.#
|
|
1406
|
-
this.#
|
|
1441
|
+
this.#dtype = args.dtype;
|
|
1442
|
+
this.#weakType = args.weakType;
|
|
1443
|
+
this.#source = args.source;
|
|
1444
|
+
this.#st = args.st;
|
|
1445
|
+
this.#backend = args.backend;
|
|
1407
1446
|
this.#rc = 1;
|
|
1408
|
-
this.#pendingSet = new Set(pending);
|
|
1447
|
+
this.#pendingSet = new Set(args.pending);
|
|
1448
|
+
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1449
|
+
else if (this.#source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1409
1450
|
}
|
|
1410
1451
|
/** @ignore */
|
|
1411
1452
|
get aval() {
|
|
1412
|
-
return new ShapedArray(this.#st.shape, this.#dtype);
|
|
1453
|
+
return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
|
|
1413
1454
|
}
|
|
1414
1455
|
/** Return a simple string representation of the array's dimensions. */
|
|
1415
1456
|
toString() {
|
|
@@ -1421,6 +1462,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1421
1462
|
#check() {
|
|
1422
1463
|
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1423
1464
|
}
|
|
1465
|
+
/** Construct an array, copying fields from `this`. */
|
|
1466
|
+
#newArrayFrom(args) {
|
|
1467
|
+
return new Array$1({
|
|
1468
|
+
source: args.source ?? this.#source,
|
|
1469
|
+
st: args.st ?? this.#st,
|
|
1470
|
+
dtype: args.dtype ?? this.#dtype,
|
|
1471
|
+
weakType: this.#weakType,
|
|
1472
|
+
backend: args.backend ?? this.#backend,
|
|
1473
|
+
pending: args.pending ?? this.#pending ?? void 0
|
|
1474
|
+
});
|
|
1475
|
+
}
|
|
1424
1476
|
get ref() {
|
|
1425
1477
|
this.#check();
|
|
1426
1478
|
this.#rc++;
|
|
@@ -1460,7 +1512,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1460
1512
|
const pending = this.#pending;
|
|
1461
1513
|
for (const exe of pending) exe.updateRc(1);
|
|
1462
1514
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1463
|
-
const ar =
|
|
1515
|
+
const ar = this.#newArrayFrom({
|
|
1516
|
+
st,
|
|
1517
|
+
pending
|
|
1518
|
+
});
|
|
1464
1519
|
this.dispose();
|
|
1465
1520
|
return ar;
|
|
1466
1521
|
}
|
|
@@ -1509,7 +1564,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1509
1564
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1510
1565
|
this.dispose();
|
|
1511
1566
|
for (const ar of indices) ar.dispose();
|
|
1512
|
-
return
|
|
1567
|
+
return this.#newArrayFrom({
|
|
1568
|
+
source: output,
|
|
1569
|
+
st: ShapeTracker.fromShape(finalShape),
|
|
1570
|
+
pending
|
|
1571
|
+
});
|
|
1513
1572
|
}
|
|
1514
1573
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1515
1574
|
#moveAxesDown(axis) {
|
|
@@ -1532,11 +1591,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1532
1591
|
return this.#reshape(this.#st.permute(perm));
|
|
1533
1592
|
}
|
|
1534
1593
|
#unary(op, dtypeOutput) {
|
|
1594
|
+
const weakType = !dtypeOutput && this.#weakType;
|
|
1535
1595
|
dtypeOutput ??= this.#dtype;
|
|
1536
1596
|
this.#check();
|
|
1537
1597
|
if (this.#source instanceof AluExp) {
|
|
1538
1598
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1539
|
-
return
|
|
1599
|
+
return this.#newArrayFrom({
|
|
1600
|
+
source: exp$3.simplify(),
|
|
1601
|
+
dtype: dtypeOutput,
|
|
1602
|
+
weakType
|
|
1603
|
+
});
|
|
1540
1604
|
}
|
|
1541
1605
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1542
1606
|
const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1546,41 +1610,65 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1546
1610
|
for (const exe of pending) exe.updateRc(1);
|
|
1547
1611
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1548
1612
|
this.dispose();
|
|
1549
|
-
return
|
|
1613
|
+
return this.#newArrayFrom({
|
|
1614
|
+
source: output,
|
|
1615
|
+
st: ShapeTracker.fromShape(this.shape),
|
|
1616
|
+
dtype: dtypeOutput,
|
|
1617
|
+
weakType,
|
|
1618
|
+
pending
|
|
1619
|
+
});
|
|
1550
1620
|
}
|
|
1551
1621
|
#binary(op, other) {
|
|
1552
|
-
const custom = (src) => new AluExp(op,
|
|
1622
|
+
const custom = (src) => new AluExp(op, src[0].dtype, src);
|
|
1553
1623
|
return Array$1.#naryCustom(op, custom, [this, other]);
|
|
1554
1624
|
}
|
|
1555
|
-
static #naryCustom(name, custom, arrays, { dtypeOverride,
|
|
1625
|
+
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1556
1626
|
const n = arrays.length;
|
|
1557
1627
|
const backend = arrays[0].#backend;
|
|
1558
1628
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1559
1629
|
for (const ar of arrays) ar.#check();
|
|
1560
|
-
let
|
|
1630
|
+
let castDtype;
|
|
1631
|
+
let castWeakType = true;
|
|
1561
1632
|
for (let i = 0; i < n; i++) {
|
|
1562
1633
|
if (dtypeOverride?.[i]) {
|
|
1563
1634
|
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1564
|
-
} else if (
|
|
1565
|
-
|
|
1635
|
+
} else if (castDtype === void 0) {
|
|
1636
|
+
castDtype = arrays[i].#dtype;
|
|
1637
|
+
castWeakType = arrays[i].#weakType;
|
|
1638
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1566
1639
|
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1567
1640
|
}
|
|
1568
|
-
|
|
1569
|
-
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1641
|
+
const weakType = castWeakType && !strongTypeOutput;
|
|
1570
1642
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1571
1643
|
const newShape = [...arrays[0].shape];
|
|
1572
1644
|
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
1645
|
+
const sources = arrays.map((ar, i) => {
|
|
1646
|
+
if (!dtypeOverride?.[i]) return AluExp.cast(castDtype, ar.#source);
|
|
1647
|
+
else return ar.#source;
|
|
1648
|
+
});
|
|
1573
1649
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1574
|
-
const exp$4 = custom(
|
|
1575
|
-
return new Array$1(
|
|
1650
|
+
const exp$4 = custom(sources);
|
|
1651
|
+
return new Array$1({
|
|
1652
|
+
source: exp$4.simplify(),
|
|
1653
|
+
st: arrays[0].#st,
|
|
1654
|
+
dtype: exp$4.dtype,
|
|
1655
|
+
weakType,
|
|
1656
|
+
backend
|
|
1657
|
+
});
|
|
1576
1658
|
}
|
|
1577
|
-
const exp$3 = custom(arrays.map((ar) => {
|
|
1578
|
-
const src$1 =
|
|
1659
|
+
const exp$3 = custom(arrays.map((ar, i) => {
|
|
1660
|
+
const src$1 = sources[i];
|
|
1579
1661
|
if (ar.#st.contiguous) return src$1;
|
|
1580
1662
|
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1581
1663
|
}));
|
|
1582
1664
|
const st = ShapeTracker.fromShape(newShape);
|
|
1583
|
-
return new Array$1(
|
|
1665
|
+
return new Array$1({
|
|
1666
|
+
source: exp$3.simplify(),
|
|
1667
|
+
st,
|
|
1668
|
+
dtype: exp$3.dtype,
|
|
1669
|
+
weakType,
|
|
1670
|
+
backend
|
|
1671
|
+
});
|
|
1584
1672
|
}
|
|
1585
1673
|
let indices;
|
|
1586
1674
|
if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
|
|
@@ -1590,14 +1678,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1590
1678
|
}
|
|
1591
1679
|
const inputs = [];
|
|
1592
1680
|
const src = [];
|
|
1593
|
-
for (const ar of arrays
|
|
1594
|
-
|
|
1595
|
-
|
|
1596
|
-
|
|
1597
|
-
gid = inputs.
|
|
1598
|
-
|
|
1681
|
+
for (const [i, ar] of arrays.entries()) {
|
|
1682
|
+
let nextSrc;
|
|
1683
|
+
if (ar.#source instanceof AluExp) nextSrc = accessorAluExp(ar.#source, ar.#st, indices);
|
|
1684
|
+
else {
|
|
1685
|
+
let gid = inputs.indexOf(ar.#source);
|
|
1686
|
+
if (gid === -1) {
|
|
1687
|
+
gid = inputs.length;
|
|
1688
|
+
inputs.push(ar.#source);
|
|
1689
|
+
}
|
|
1690
|
+
nextSrc = AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
|
|
1599
1691
|
}
|
|
1600
|
-
|
|
1692
|
+
if (!dtypeOverride?.[i]) nextSrc = AluExp.cast(castDtype, nextSrc);
|
|
1693
|
+
src.push(nextSrc);
|
|
1601
1694
|
}
|
|
1602
1695
|
const exp$2 = custom(src);
|
|
1603
1696
|
let re = void 0;
|
|
@@ -1611,12 +1704,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1611
1704
|
for (const exe of pending) exe.updateRc(1);
|
|
1612
1705
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1613
1706
|
for (const ar of arrays) ar.dispose();
|
|
1614
|
-
return new Array$1(
|
|
1707
|
+
return new Array$1({
|
|
1708
|
+
source: output,
|
|
1709
|
+
st: ShapeTracker.fromShape(newShape),
|
|
1710
|
+
dtype: kernel.dtype,
|
|
1711
|
+
weakType,
|
|
1712
|
+
backend,
|
|
1713
|
+
pending
|
|
1714
|
+
});
|
|
1615
1715
|
}
|
|
1616
1716
|
/** Reduce the last dimension of the array by an operation. */
|
|
1617
1717
|
#reduce(op) {
|
|
1618
|
-
this.#check();
|
|
1619
|
-
if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
|
|
1620
1718
|
const shape$1 = this.shape;
|
|
1621
1719
|
const reduction = new Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
|
|
1622
1720
|
const newShape = shape$1.slice(0, -1);
|
|
@@ -1635,7 +1733,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1635
1733
|
for (const exe of pending) exe.updateRc(1);
|
|
1636
1734
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1637
1735
|
this.dispose();
|
|
1638
|
-
return
|
|
1736
|
+
return this.#newArrayFrom({
|
|
1737
|
+
source: output,
|
|
1738
|
+
st: ShapeTracker.fromShape(newShape),
|
|
1739
|
+
pending
|
|
1740
|
+
});
|
|
1639
1741
|
}
|
|
1640
1742
|
/**
|
|
1641
1743
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1671,8 +1773,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1671
1773
|
}
|
|
1672
1774
|
#dataInline() {
|
|
1673
1775
|
this.#check();
|
|
1674
|
-
|
|
1675
|
-
const ar =
|
|
1776
|
+
if (!(this.#source instanceof AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
|
|
1777
|
+
const ar = this.#newArrayFrom({ backend: getBackend("cpu") });
|
|
1676
1778
|
this.dispose();
|
|
1677
1779
|
return ar.dataSync();
|
|
1678
1780
|
}
|
|
@@ -1708,8 +1810,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1708
1810
|
*
|
|
1709
1811
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1710
1812
|
* dispatch of operations as well.
|
|
1813
|
+
*
|
|
1814
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1815
|
+
* asynchronously for multiple arrays.
|
|
1711
1816
|
*/
|
|
1712
|
-
async
|
|
1817
|
+
async blockUntilReady() {
|
|
1713
1818
|
this.#check();
|
|
1714
1819
|
if (this.#source instanceof AluExp) return this;
|
|
1715
1820
|
const pending = this.#pending;
|
|
@@ -1775,7 +1880,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1775
1880
|
return [x.#binary(AluOp.Idiv, y)];
|
|
1776
1881
|
},
|
|
1777
1882
|
[Primitive.Neg]([x]) {
|
|
1778
|
-
return [zerosLike(x.ref).#binary(AluOp.Sub, x)];
|
|
1883
|
+
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
1779
1884
|
},
|
|
1780
1885
|
[Primitive.Reciprocal]([x]) {
|
|
1781
1886
|
return [x.#unary(AluOp.Reciprocal)];
|
|
@@ -1795,7 +1900,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1795
1900
|
x.#backend.incRef(x.#source);
|
|
1796
1901
|
const pending = x.#pending;
|
|
1797
1902
|
for (const exe of pending) exe.updateRc(1);
|
|
1798
|
-
const y =
|
|
1903
|
+
const y = x.#newArrayFrom({
|
|
1904
|
+
dtype,
|
|
1905
|
+
weakType: false,
|
|
1906
|
+
pending
|
|
1907
|
+
});
|
|
1799
1908
|
x.dispose();
|
|
1800
1909
|
return [y];
|
|
1801
1910
|
}
|
|
@@ -1825,6 +1934,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1825
1934
|
[Primitive.Cos]([x]) {
|
|
1826
1935
|
return [x.#unary(AluOp.Cos)];
|
|
1827
1936
|
},
|
|
1937
|
+
[Primitive.Asin]([x]) {
|
|
1938
|
+
return [x.#unary(AluOp.Asin)];
|
|
1939
|
+
},
|
|
1940
|
+
[Primitive.Atan]([x]) {
|
|
1941
|
+
return [x.#unary(AluOp.Atan)];
|
|
1942
|
+
},
|
|
1828
1943
|
[Primitive.Exp]([x]) {
|
|
1829
1944
|
return [x.#unary(AluOp.Exp)];
|
|
1830
1945
|
},
|
|
@@ -1864,7 +1979,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1864
1979
|
},
|
|
1865
1980
|
[Primitive.Compare]([x, y], { op }) {
|
|
1866
1981
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1867
|
-
return [Array$1.#naryCustom("compare", custom, [x, y], {
|
|
1982
|
+
return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
|
|
1868
1983
|
},
|
|
1869
1984
|
[Primitive.Where]([cond, x, y]) {
|
|
1870
1985
|
const custom = ([cond$1, x$1, y$1]) => AluExp.where(cond$1, x$1, y$1);
|
|
@@ -1910,7 +2025,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1910
2025
|
pending.splice(0, 0, ...prevPending);
|
|
1911
2026
|
args.forEach((x) => x.dispose());
|
|
1912
2027
|
return outputs.map((source, i) => {
|
|
1913
|
-
return new Array$1(
|
|
2028
|
+
return new Array$1({
|
|
2029
|
+
source,
|
|
2030
|
+
st: ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
|
|
2031
|
+
dtype: jaxpr.outs[i].aval.dtype,
|
|
2032
|
+
weakType: jaxpr.outs[i].aval.weakType,
|
|
2033
|
+
backend,
|
|
2034
|
+
pending
|
|
2035
|
+
});
|
|
1914
2036
|
});
|
|
1915
2037
|
}
|
|
1916
2038
|
};
|
|
@@ -1920,33 +2042,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1920
2042
|
return this.#source;
|
|
1921
2043
|
}
|
|
1922
2044
|
};
|
|
1923
|
-
/** Construct an array from a single scalar constant. */
|
|
1924
|
-
function scalar(value, { dtype, device } = {}) {
|
|
1925
|
-
if (typeof value === "number") {
|
|
1926
|
-
dtype ??= DType.Float32;
|
|
1927
|
-
if (![
|
|
1928
|
-
DType.Float32,
|
|
1929
|
-
DType.Float16,
|
|
1930
|
-
DType.Int32,
|
|
1931
|
-
DType.Uint32
|
|
1932
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1933
|
-
} else if (typeof value === "boolean") {
|
|
1934
|
-
dtype ??= DType.Bool;
|
|
1935
|
-
if (![
|
|
1936
|
-
DType.Float32,
|
|
1937
|
-
DType.Float16,
|
|
1938
|
-
DType.Int32,
|
|
1939
|
-
DType.Uint32,
|
|
1940
|
-
DType.Bool
|
|
1941
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1942
|
-
} else throw new TypeError(`Invalid type for scalar ${value}`);
|
|
1943
|
-
return new Array$1(AluExp.const(dtype, value), ShapeTracker.fromShape([]), dtype, getBackend(device));
|
|
1944
|
-
}
|
|
1945
2045
|
/** Constructor for creating a new array from data. */
|
|
1946
2046
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
1947
2047
|
if (values instanceof Tracer) {
|
|
1948
2048
|
if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1949
|
-
if (dtype && values.dtype !== dtype)
|
|
2049
|
+
if (dtype && values.dtype !== dtype) values = values.astype(dtype);
|
|
1950
2050
|
return values;
|
|
1951
2051
|
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1952
2052
|
dtype,
|
|
@@ -1968,6 +2068,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1968
2068
|
dtype,
|
|
1969
2069
|
device
|
|
1970
2070
|
});
|
|
2071
|
+
if (size$1 === 1) return full(shape$1, flat[0], {
|
|
2072
|
+
dtype,
|
|
2073
|
+
device
|
|
2074
|
+
});
|
|
1971
2075
|
if (typeof flat[0] === "boolean") {
|
|
1972
2076
|
dtype = dtype ?? DType.Bool;
|
|
1973
2077
|
const data = new Int32Array(flat.map((x) => x ? 1 : 0));
|
|
@@ -1976,46 +2080,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1976
2080
|
device
|
|
1977
2081
|
});
|
|
1978
2082
|
} else {
|
|
2083
|
+
const weakType = dtype == void 0;
|
|
1979
2084
|
dtype = dtype ?? DType.Float32;
|
|
1980
2085
|
const data = dtypedJsArray(dtype, flat);
|
|
1981
2086
|
return arrayFromData(data, shape$1, {
|
|
1982
2087
|
dtype,
|
|
1983
2088
|
device
|
|
1984
|
-
});
|
|
2089
|
+
}, weakType);
|
|
1985
2090
|
}
|
|
1986
2091
|
}
|
|
1987
2092
|
}
|
|
1988
|
-
function arrayFromData(data, shape$1, { dtype, device } =
|
|
2093
|
+
function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
2094
|
+
if (data instanceof Float32Array) {
|
|
2095
|
+
if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2096
|
+
dtype ??= DType.Float32;
|
|
2097
|
+
} else if (data instanceof Int32Array) {
|
|
2098
|
+
if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2099
|
+
dtype ??= DType.Int32;
|
|
2100
|
+
} else if (data instanceof Uint32Array) {
|
|
2101
|
+
if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2102
|
+
dtype ??= DType.Uint32;
|
|
2103
|
+
} else if (data instanceof Float16Array) {
|
|
2104
|
+
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2105
|
+
dtype ??= DType.Float16;
|
|
2106
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
1989
2107
|
if (data.length < inlineArrayLimit) {
|
|
1990
2108
|
let allEqual = true;
|
|
1991
2109
|
for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
|
|
1992
2110
|
allEqual = false;
|
|
1993
2111
|
break;
|
|
1994
2112
|
}
|
|
1995
|
-
if (allEqual)
|
|
1996
|
-
dtype,
|
|
1997
|
-
device
|
|
1998
|
-
}
|
|
2113
|
+
if (allEqual) {
|
|
2114
|
+
const sa = new ShapedArray(shape$1, dtype, weakType);
|
|
2115
|
+
return fullInternal(sa, data[0], device);
|
|
2116
|
+
}
|
|
1999
2117
|
}
|
|
2000
2118
|
const backend = getBackend(device);
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2004
|
-
|
|
2005
|
-
|
|
2006
|
-
|
|
2007
|
-
|
|
2008
|
-
|
|
2009
|
-
|
|
2010
|
-
if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2011
|
-
dtype ??= DType.Uint32;
|
|
2012
|
-
} else if (data instanceof Float16Array) {
|
|
2013
|
-
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2014
|
-
dtype ??= DType.Float16;
|
|
2015
|
-
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2016
|
-
const slot = backend.malloc(data.byteLength, buf);
|
|
2017
|
-
return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2018
|
-
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
2119
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2120
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2121
|
+
return new Array$1({
|
|
2122
|
+
source: slot,
|
|
2123
|
+
st: ShapeTracker.fromShape(shape$1),
|
|
2124
|
+
dtype,
|
|
2125
|
+
weakType,
|
|
2126
|
+
backend
|
|
2127
|
+
});
|
|
2019
2128
|
}
|
|
2020
2129
|
function dataToJs(dtype, data, shape$1) {
|
|
2021
2130
|
if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -2031,7 +2140,7 @@ function dataToJs(dtype, data, shape$1) {
|
|
|
2031
2140
|
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
2032
2141
|
function pureArray(x) {
|
|
2033
2142
|
if (x instanceof Tracer) return x;
|
|
2034
|
-
else return
|
|
2143
|
+
else return array(x);
|
|
2035
2144
|
}
|
|
2036
2145
|
var EvalTrace = class extends Trace {
|
|
2037
2146
|
pure = (x) => pureArray(x);
|
|
@@ -2042,20 +2151,27 @@ var EvalTrace = class extends Trace {
|
|
|
2042
2151
|
};
|
|
2043
2152
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2044
2153
|
const implRules = Array$1._implRules();
|
|
2045
|
-
function
|
|
2046
|
-
|
|
2047
|
-
|
|
2048
|
-
|
|
2154
|
+
function fullInternal(aval, fillValue, device) {
|
|
2155
|
+
return new Array$1({
|
|
2156
|
+
source: AluExp.const(aval.dtype, fillValue),
|
|
2157
|
+
st: ShapeTracker.fromShape(aval.shape),
|
|
2158
|
+
dtype: aval.dtype,
|
|
2159
|
+
weakType: aval.weakType,
|
|
2160
|
+
backend: getBackend(device)
|
|
2161
|
+
});
|
|
2049
2162
|
}
|
|
2050
|
-
function
|
|
2051
|
-
|
|
2052
|
-
|
|
2053
|
-
|
|
2163
|
+
function zerosLike$1(val, dtype) {
|
|
2164
|
+
return fullLike(val, 0, dtype);
|
|
2165
|
+
}
|
|
2166
|
+
function onesLike$1(val, dtype) {
|
|
2167
|
+
return fullLike(val, 1, dtype);
|
|
2054
2168
|
}
|
|
2055
2169
|
function fullLike(val, fillValue, dtype) {
|
|
2056
2170
|
const aval = getAval(val);
|
|
2057
2171
|
if (val instanceof Tracer) val.dispose();
|
|
2058
|
-
|
|
2172
|
+
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
2173
|
+
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
2174
|
+
return fullInternal(sa, fillValue);
|
|
2059
2175
|
}
|
|
2060
2176
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
2061
2177
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -2073,19 +2189,14 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
2073
2189
|
}
|
|
2074
2190
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
2075
2191
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
2076
|
-
let
|
|
2077
|
-
if (typeof fillValue === "number")
|
|
2078
|
-
|
|
2079
|
-
source = AluExp.const(dtype, fillValue);
|
|
2080
|
-
} else if (typeof fillValue === "bigint") {
|
|
2081
|
-
dtype = dtype ?? DType.Int32;
|
|
2082
|
-
source = AluExp.const(dtype, Number(fillValue));
|
|
2083
|
-
} else if (typeof fillValue === "boolean") {
|
|
2192
|
+
let weakType = dtype == void 0;
|
|
2193
|
+
if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
|
|
2194
|
+
else if (typeof fillValue === "boolean") {
|
|
2084
2195
|
dtype = dtype ?? DType.Bool;
|
|
2085
|
-
|
|
2196
|
+
weakType = false;
|
|
2086
2197
|
} else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
|
|
2087
2198
|
else throw new TypeError(`Invalid type for full: ${fillValue}`);
|
|
2088
|
-
return new
|
|
2199
|
+
return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
|
|
2089
2200
|
}
|
|
2090
2201
|
/**
|
|
2091
2202
|
* Create an identity matrix.
|
|
@@ -2095,6 +2206,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
2095
2206
|
*/
|
|
2096
2207
|
function eye(numRows, numCols, { dtype, device } = {}) {
|
|
2097
2208
|
numCols = numCols ?? numRows;
|
|
2209
|
+
const weakType = dtype == void 0;
|
|
2098
2210
|
dtype = dtype ?? DType.Float32;
|
|
2099
2211
|
if (numCols < numRows) {
|
|
2100
2212
|
const arr = eye(numCols, numRows, {
|
|
@@ -2108,9 +2220,15 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2108
2220
|
device
|
|
2109
2221
|
});
|
|
2110
2222
|
const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
|
|
2111
|
-
return new Array$1(
|
|
2223
|
+
return new Array$1({
|
|
2224
|
+
source: AluExp.cast(dtype, exp$2),
|
|
2225
|
+
st: ShapeTracker.fromShape([numRows, numCols]),
|
|
2226
|
+
dtype,
|
|
2227
|
+
weakType,
|
|
2228
|
+
backend: getBackend(device)
|
|
2229
|
+
});
|
|
2112
2230
|
}
|
|
2113
|
-
/** Return the identity
|
|
2231
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2114
2232
|
function identity$1(n, { dtype, device } = {}) {
|
|
2115
2233
|
return eye(n, n, {
|
|
2116
2234
|
dtype,
|
|
@@ -2145,7 +2263,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2145
2263
|
});
|
|
2146
2264
|
const exp$2 = AluExp.add(AluExp.const(dtype, start), AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step)));
|
|
2147
2265
|
const st = ShapeTracker.fromShape([size$1]);
|
|
2148
|
-
return new Array$1(
|
|
2266
|
+
return new Array$1({
|
|
2267
|
+
source: exp$2,
|
|
2268
|
+
st,
|
|
2269
|
+
dtype,
|
|
2270
|
+
weakType: false,
|
|
2271
|
+
backend: getBackend(device)
|
|
2272
|
+
});
|
|
2149
2273
|
}
|
|
2150
2274
|
/**
|
|
2151
2275
|
* Return evenly spaced numbers over a specified interval.
|
|
@@ -2163,10 +2287,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2163
2287
|
dtype,
|
|
2164
2288
|
device
|
|
2165
2289
|
});
|
|
2166
|
-
else if (num === 1) return
|
|
2290
|
+
else if (num === 1) return full([1], start, {
|
|
2167
2291
|
dtype,
|
|
2168
2292
|
device
|
|
2169
|
-
})
|
|
2293
|
+
});
|
|
2170
2294
|
else if (start === stop) return full([num], start, {
|
|
2171
2295
|
dtype,
|
|
2172
2296
|
device
|
|
@@ -2175,7 +2299,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2175
2299
|
const denom = endpoint ? num - 1 : num;
|
|
2176
2300
|
const exp$2 = AluExp.cast(dtype, AluExp.add(AluExp.f32(start), AluExp.mul(AluExp.f32(delta / denom), AluExp.cast(DType.Float32, AluVar.idx))));
|
|
2177
2301
|
const st = ShapeTracker.fromShape([num]);
|
|
2178
|
-
return new Array$1(
|
|
2302
|
+
return new Array$1({
|
|
2303
|
+
source: exp$2,
|
|
2304
|
+
st,
|
|
2305
|
+
dtype,
|
|
2306
|
+
weakType: false,
|
|
2307
|
+
backend: getBackend(device)
|
|
2308
|
+
});
|
|
2179
2309
|
}
|
|
2180
2310
|
function aluCompare(a, b, op) {
|
|
2181
2311
|
switch (op) {
|
|
@@ -2187,35 +2317,6 @@ function aluCompare(a, b, op) {
|
|
|
2187
2317
|
case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
|
|
2188
2318
|
}
|
|
2189
2319
|
}
|
|
2190
|
-
/**
|
|
2191
|
-
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
2192
|
-
*
|
|
2193
|
-
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
2194
|
-
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
2195
|
-
* Two dimensions are compatible when:
|
|
2196
|
-
* 1. they are equal, or
|
|
2197
|
-
* 2. one of them is 1."
|
|
2198
|
-
*
|
|
2199
|
-
* Throws a TypeError if the broadcast is not possible.
|
|
2200
|
-
*
|
|
2201
|
-
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
2202
|
-
*/
|
|
2203
|
-
function generalBroadcast(a, b) {
|
|
2204
|
-
const out = [];
|
|
2205
|
-
let i = a.length - 1;
|
|
2206
|
-
let j = b.length - 1;
|
|
2207
|
-
for (; i >= 0 && j >= 0; i--, j--) {
|
|
2208
|
-
const x = a[i];
|
|
2209
|
-
const y = b[j];
|
|
2210
|
-
if (x === y) out.push(x);
|
|
2211
|
-
else if (x === 1) out.push(y);
|
|
2212
|
-
else if (y === 1) out.push(x);
|
|
2213
|
-
else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
|
|
2214
|
-
}
|
|
2215
|
-
for (; i >= 0; i--) out.push(a[i]);
|
|
2216
|
-
for (; j >= 0; j--) out.push(b[j]);
|
|
2217
|
-
return out.reverse();
|
|
2218
|
-
}
|
|
2219
2320
|
|
|
2220
2321
|
//#endregion
|
|
2221
2322
|
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
|
|
@@ -2291,13 +2392,15 @@ var Var = class Var {
|
|
|
2291
2392
|
};
|
|
2292
2393
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
2293
2394
|
var Lit = class {
|
|
2294
|
-
dtype;
|
|
2295
2395
|
value;
|
|
2296
2396
|
aval;
|
|
2297
|
-
|
|
2298
|
-
this.dtype
|
|
2397
|
+
get dtype() {
|
|
2398
|
+
return this.aval.dtype;
|
|
2399
|
+
}
|
|
2400
|
+
constructor(aval, value) {
|
|
2401
|
+
if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
|
|
2299
2402
|
this.value = value;
|
|
2300
|
-
this.aval =
|
|
2403
|
+
this.aval = ShapedArray.fromAval(aval);
|
|
2301
2404
|
}
|
|
2302
2405
|
};
|
|
2303
2406
|
function atomIsLit(atom, literal) {
|
|
@@ -2386,16 +2489,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2386
2489
|
varIds.set(v, FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
|
|
2387
2490
|
return id;
|
|
2388
2491
|
};
|
|
2389
|
-
hasher.update(this.inBinders.length
|
|
2390
|
-
|
|
2391
|
-
|
|
2392
|
-
|
|
2393
|
-
|
|
2394
|
-
|
|
2395
|
-
eqn.
|
|
2396
|
-
|
|
2397
|
-
|
|
2398
|
-
|
|
2492
|
+
hasher.update(this.inBinders.length);
|
|
2493
|
+
for (const x of this.inBinders) hasher.update(vi(x));
|
|
2494
|
+
hasher.update(this.eqns.length);
|
|
2495
|
+
for (const eqn of this.eqns) {
|
|
2496
|
+
hasher.update(eqn.primitive);
|
|
2497
|
+
hasher.update(eqn.inputs.length);
|
|
2498
|
+
for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2499
|
+
hasher.update(JSON.stringify(eqn.params));
|
|
2500
|
+
hasher.update(eqn.outBinders.length);
|
|
2501
|
+
for (const x of eqn.outBinders) hasher.update(vi(x));
|
|
2502
|
+
}
|
|
2503
|
+
hasher.update(this.outs.length);
|
|
2504
|
+
for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2399
2505
|
return this.#hash = hasher.value;
|
|
2400
2506
|
}
|
|
2401
2507
|
hash(state) {
|
|
@@ -2418,21 +2524,26 @@ var Jaxpr = class Jaxpr {
|
|
|
2418
2524
|
const c = eqn.outBinders[0];
|
|
2419
2525
|
if (atomIsLit(a, 0)) context.set(c, b);
|
|
2420
2526
|
else if (atomIsLit(b, 0)) context.set(c, a);
|
|
2421
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2527
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
|
|
2528
|
+
else newEqns.push(eqn);
|
|
2529
|
+
} else if (eqn.primitive === Primitive.Neg) {
|
|
2530
|
+
const [a] = inputs;
|
|
2531
|
+
const c = eqn.outBinders[0];
|
|
2532
|
+
if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
|
|
2422
2533
|
else newEqns.push(eqn);
|
|
2423
2534
|
} else if (eqn.primitive === Primitive.Mul) {
|
|
2424
2535
|
const [a, b] = inputs;
|
|
2425
2536
|
const c = eqn.outBinders[0];
|
|
2426
2537
|
if (atomIsLit(a, 1)) context.set(c, b);
|
|
2427
2538
|
else if (atomIsLit(b, 1)) context.set(c, a);
|
|
2428
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2539
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
|
|
2429
2540
|
else newEqns.push(eqn);
|
|
2430
2541
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
2431
2542
|
const [a, b] = inputs;
|
|
2432
2543
|
const c = eqn.outBinders[0];
|
|
2433
2544
|
if (atomIsLit(b, 1)) context.set(c, a);
|
|
2434
2545
|
else newEqns.push(eqn);
|
|
2435
|
-
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2546
|
+
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2436
2547
|
else newEqns.push(eqn);
|
|
2437
2548
|
}
|
|
2438
2549
|
const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
|
|
@@ -2523,7 +2634,7 @@ function evalJaxpr(jaxpr, args) {
|
|
|
2523
2634
|
if (x instanceof Var) {
|
|
2524
2635
|
remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
|
|
2525
2636
|
return env.get(x);
|
|
2526
|
-
} else return
|
|
2637
|
+
} else return array(x.value, { dtype: x.dtype });
|
|
2527
2638
|
};
|
|
2528
2639
|
const write = (v, val) => {
|
|
2529
2640
|
if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
|
|
@@ -2582,7 +2693,7 @@ var JaxprTrace = class extends Trace {
|
|
|
2582
2693
|
let tracer = this.builder.constTracers.get(val);
|
|
2583
2694
|
if (tracer === void 0) {
|
|
2584
2695
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
2585
|
-
this.builder.addConst(tracer, val instanceof Tracer ? val.ref :
|
|
2696
|
+
this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
|
|
2586
2697
|
}
|
|
2587
2698
|
return tracer;
|
|
2588
2699
|
}
|
|
@@ -2651,7 +2762,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2651
2762
|
const newConsts = [];
|
|
2652
2763
|
for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
|
|
2653
2764
|
const ar = consts[i];
|
|
2654
|
-
literals.set(jaxpr.inBinders[i], new Lit(ar.
|
|
2765
|
+
literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
|
|
2655
2766
|
} else {
|
|
2656
2767
|
constBinders.push(jaxpr.inBinders[i]);
|
|
2657
2768
|
newConsts.push(consts[i]);
|
|
@@ -2664,13 +2775,12 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2664
2775
|
}
|
|
2665
2776
|
function binopAbstractEval([x, y]) {
|
|
2666
2777
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
2667
|
-
|
|
2668
|
-
return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
|
|
2778
|
+
return [promoteAvals(x, y)];
|
|
2669
2779
|
}
|
|
2670
2780
|
function compareAbstractEval([x, y]) {
|
|
2671
2781
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
|
|
2672
|
-
|
|
2673
|
-
return [new ShapedArray(
|
|
2782
|
+
const aval = promoteAvals(x, y);
|
|
2783
|
+
return [new ShapedArray(aval.shape, DType.Bool, false)];
|
|
2674
2784
|
}
|
|
2675
2785
|
function vectorizedUnopAbstractEval([x]) {
|
|
2676
2786
|
return [ShapedArray.fromAval(x)];
|
|
@@ -2683,21 +2793,23 @@ const abstractEvalRules = {
|
|
|
2683
2793
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
2684
2794
|
[Primitive.StopGradient]: vectorizedUnopAbstractEval,
|
|
2685
2795
|
[Primitive.Cast]([x], { dtype }) {
|
|
2686
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2796
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2687
2797
|
},
|
|
2688
2798
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
2689
2799
|
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
2690
2800
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
2691
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2801
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2692
2802
|
},
|
|
2693
2803
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
2694
2804
|
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
2695
2805
|
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2696
2806
|
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2697
|
-
return [new ShapedArray(shape$1, DType.Uint32)];
|
|
2807
|
+
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
2698
2808
|
},
|
|
2699
2809
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2700
2810
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2811
|
+
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
2812
|
+
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2701
2813
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2702
2814
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2703
2815
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
@@ -2706,55 +2818,54 @@ const abstractEvalRules = {
|
|
|
2706
2818
|
[Primitive.Reduce]([x], { axis }) {
|
|
2707
2819
|
const axisSet = new Set(axis);
|
|
2708
2820
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2709
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2821
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2710
2822
|
},
|
|
2711
2823
|
[Primitive.Pool]([x], { window, strides }) {
|
|
2712
2824
|
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2713
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2825
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2714
2826
|
},
|
|
2715
2827
|
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2716
2828
|
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2717
2829
|
if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2718
|
-
return [new ShapedArray(inShape, x.dtype)];
|
|
2830
|
+
return [new ShapedArray(inShape, x.dtype, x.weakType)];
|
|
2719
2831
|
},
|
|
2720
2832
|
[Primitive.Dot]([x, y]) {
|
|
2721
|
-
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2722
2833
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
2723
|
-
const shape$1 =
|
|
2834
|
+
const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
|
|
2724
2835
|
shape$1.splice(-1, 1);
|
|
2725
|
-
return [new ShapedArray(shape$1,
|
|
2836
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2726
2837
|
},
|
|
2727
2838
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
2728
|
-
|
|
2839
|
+
const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
|
|
2729
2840
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2730
|
-
return [new ShapedArray(shape$1,
|
|
2841
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2731
2842
|
},
|
|
2732
2843
|
[Primitive.Compare]: compareAbstractEval,
|
|
2733
2844
|
[Primitive.Where]([cond, x, y]) {
|
|
2734
2845
|
if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
2735
|
-
|
|
2736
|
-
const shape$1 = generalBroadcast(cond.shape,
|
|
2737
|
-
return [new ShapedArray(shape$1,
|
|
2846
|
+
const xy = promoteAvals(x, y);
|
|
2847
|
+
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
2848
|
+
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
2738
2849
|
},
|
|
2739
2850
|
[Primitive.Transpose]([x], { perm }) {
|
|
2740
|
-
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
|
|
2851
|
+
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
2741
2852
|
},
|
|
2742
2853
|
[Primitive.Broadcast]([x], { shape: shape$1 }) {
|
|
2743
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2854
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2744
2855
|
},
|
|
2745
2856
|
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2746
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2857
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2747
2858
|
},
|
|
2748
2859
|
[Primitive.Flip]([x], _) {
|
|
2749
|
-
return [
|
|
2860
|
+
return [ShapedArray.fromAval(x)];
|
|
2750
2861
|
},
|
|
2751
2862
|
[Primitive.Shrink]([x], { slice }) {
|
|
2752
2863
|
const newShape = slice.map((s) => s[1] - s[0]);
|
|
2753
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2864
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2754
2865
|
},
|
|
2755
2866
|
[Primitive.Pad]([x], { width }) {
|
|
2756
2867
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
2757
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2868
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2758
2869
|
},
|
|
2759
2870
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2760
2871
|
for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
@@ -2767,7 +2878,7 @@ const abstractEvalRules = {
|
|
|
2767
2878
|
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
2768
2879
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2769
2880
|
newShape.splice(outDim, 0, ...gatherShape);
|
|
2770
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2881
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2771
2882
|
},
|
|
2772
2883
|
[Primitive.JitCall](args, { jaxpr }) {
|
|
2773
2884
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
@@ -2825,7 +2936,7 @@ function makeJaxpr$1(f, opts) {
|
|
|
2825
2936
|
function jit$1(f, opts) {
|
|
2826
2937
|
const cache = /* @__PURE__ */ new Map();
|
|
2827
2938
|
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2828
|
-
|
|
2939
|
+
const result = ((...args) => {
|
|
2829
2940
|
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2830
2941
|
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2831
2942
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
@@ -2834,11 +2945,16 @@ function jit$1(f, opts) {
|
|
|
2834
2945
|
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2835
2946
|
const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2836
2947
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2948
|
+
name: f.name || "closure",
|
|
2837
2949
|
jaxpr,
|
|
2838
2950
|
numConsts: consts.length
|
|
2839
2951
|
});
|
|
2840
2952
|
return unflatten(outTree, outs);
|
|
2841
2953
|
});
|
|
2954
|
+
result.dispose = () => {
|
|
2955
|
+
for (const { consts } of cache.values()) for (const c of consts) c.dispose();
|
|
2956
|
+
};
|
|
2957
|
+
return result;
|
|
2842
2958
|
}
|
|
2843
2959
|
|
|
2844
2960
|
//#endregion
|
|
@@ -2869,7 +2985,7 @@ var JVPTrace = class extends Trace {
|
|
|
2869
2985
|
return this.lift(pureArray(val));
|
|
2870
2986
|
}
|
|
2871
2987
|
lift(val) {
|
|
2872
|
-
return new JVPTracer(this, val, zerosLike(val.ref));
|
|
2988
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
2873
2989
|
}
|
|
2874
2990
|
processPrimitive(primitive, tracers, params) {
|
|
2875
2991
|
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2900,7 +3016,7 @@ function zeroTangentsJvp(primitive) {
|
|
|
2900
3016
|
return (primals, tangents, params) => {
|
|
2901
3017
|
for (const t of tangents) t.dispose();
|
|
2902
3018
|
const ys = bind(primitive, primals, params);
|
|
2903
|
-
return [ys, ys.map((y) => zerosLike(y.ref))];
|
|
3019
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
2904
3020
|
};
|
|
2905
3021
|
}
|
|
2906
3022
|
const jvpRules = {
|
|
@@ -2918,13 +3034,13 @@ const jvpRules = {
|
|
|
2918
3034
|
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2919
3035
|
else {
|
|
2920
3036
|
dx.dispose();
|
|
2921
|
-
return [[cast(x.ref, dtype)], [zerosLike(x)]];
|
|
3037
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2922
3038
|
}
|
|
2923
3039
|
},
|
|
2924
3040
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2925
3041
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2926
3042
|
dx.dispose();
|
|
2927
|
-
return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
|
|
3043
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2928
3044
|
},
|
|
2929
3045
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2930
3046
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2933,6 +3049,14 @@ const jvpRules = {
|
|
|
2933
3049
|
[Primitive.Cos]([x], [dx]) {
|
|
2934
3050
|
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
2935
3051
|
},
|
|
3052
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3053
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3054
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3055
|
+
},
|
|
3056
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3057
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3058
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3059
|
+
},
|
|
2936
3060
|
[Primitive.Exp]([x], [dx]) {
|
|
2937
3061
|
const z = exp$1(x);
|
|
2938
3062
|
return [[z.ref], [z.mul(dx)]];
|
|
@@ -2983,13 +3107,14 @@ const jvpRules = {
|
|
|
2983
3107
|
const indicesRef = indices.map((t) => t.ref);
|
|
2984
3108
|
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
2985
3109
|
},
|
|
2986
|
-
[Primitive.JitCall](primals, tangents, { jaxpr }) {
|
|
3110
|
+
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
2987
3111
|
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
2988
3112
|
const outs = bind(Primitive.JitCall, [
|
|
2989
3113
|
...newConsts.map((c) => c.ref),
|
|
2990
3114
|
...primals,
|
|
2991
3115
|
...tangents
|
|
2992
3116
|
], {
|
|
3117
|
+
name: `${name}_jvp`,
|
|
2993
3118
|
jaxpr: newJaxpr,
|
|
2994
3119
|
numConsts: newConsts.length
|
|
2995
3120
|
});
|
|
@@ -3043,12 +3168,15 @@ function jvp$1(f, primals, tangents) {
|
|
|
3043
3168
|
function mappedAval(batchDim, aval) {
|
|
3044
3169
|
const shape$1 = [...aval.shape];
|
|
3045
3170
|
shape$1.splice(batchDim, 1);
|
|
3046
|
-
return new ShapedArray(shape$1, aval.dtype);
|
|
3171
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3047
3172
|
}
|
|
3048
3173
|
/** Move one axis to a different index. */
|
|
3049
3174
|
function moveaxis$1(x, src, dst) {
|
|
3050
3175
|
const t = pureArray(x);
|
|
3051
|
-
|
|
3176
|
+
src = checkAxis(src, t.ndim);
|
|
3177
|
+
dst = checkAxis(dst, t.ndim);
|
|
3178
|
+
if (src === dst) return t;
|
|
3179
|
+
const perm = range(t.ndim);
|
|
3052
3180
|
perm.splice(src, 1);
|
|
3053
3181
|
perm.splice(dst, 0, src);
|
|
3054
3182
|
return transpose$1(t, perm);
|
|
@@ -3141,6 +3269,8 @@ const vmapRules = {
|
|
|
3141
3269
|
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3142
3270
|
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3143
3271
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3272
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3273
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3144
3274
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3145
3275
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3146
3276
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
@@ -3182,9 +3312,10 @@ const vmapRules = {
|
|
|
3182
3312
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3183
3313
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3184
3314
|
},
|
|
3185
|
-
[Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
|
|
3315
|
+
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3186
3316
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3187
3317
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
3318
|
+
name: `${name}_vmap`,
|
|
3188
3319
|
jaxpr: newJaxpr,
|
|
3189
3320
|
numConsts: newConsts.length
|
|
3190
3321
|
});
|
|
@@ -3200,7 +3331,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3200
3331
|
if (dims[i] === null) return v.aval;
|
|
3201
3332
|
const shape$1 = [...v.aval.shape];
|
|
3202
3333
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3203
|
-
return new ShapedArray(shape$1, v.aval.dtype);
|
|
3334
|
+
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3204
3335
|
});
|
|
3205
3336
|
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3206
3337
|
const result = {
|
|
@@ -3326,20 +3457,28 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3326
3457
|
function linearizeFlat(f, primalsIn) {
|
|
3327
3458
|
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
3328
3459
|
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3329
|
-
|
|
3460
|
+
const dispose$1 = () => {
|
|
3461
|
+
for (const c of consts) c.dispose();
|
|
3462
|
+
};
|
|
3463
|
+
return [
|
|
3464
|
+
primalsOut,
|
|
3465
|
+
fLin,
|
|
3466
|
+
dispose$1
|
|
3467
|
+
];
|
|
3330
3468
|
}
|
|
3331
3469
|
function linearize$1(f, ...primalsIn) {
|
|
3332
3470
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3333
3471
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3334
|
-
const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3472
|
+
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3335
3473
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
3336
3474
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3337
|
-
const fLin = (...tangentsIn) => {
|
|
3475
|
+
const fLin = ((...tangentsIn) => {
|
|
3338
3476
|
const [tangentsInFlat, inTree2] = flatten(tangentsIn);
|
|
3339
3477
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
|
|
3340
3478
|
const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
|
|
3341
3479
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
3342
|
-
};
|
|
3480
|
+
});
|
|
3481
|
+
fLin.dispose = dispose$1;
|
|
3343
3482
|
return [primalsOut, fLin];
|
|
3344
3483
|
}
|
|
3345
3484
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -3405,8 +3544,8 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3405
3544
|
processPrimitive(primitive, tracers, params) {
|
|
3406
3545
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3407
3546
|
if (primitive === Primitive.JitCall) {
|
|
3408
|
-
const { jaxpr, numConsts } = params;
|
|
3409
|
-
return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
|
|
3547
|
+
const { name, jaxpr, numConsts } = params;
|
|
3548
|
+
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3410
3549
|
}
|
|
3411
3550
|
const tracersIn = tracers.map((t) => this.instantiateConst(t));
|
|
3412
3551
|
const avalsIn = tracersIn.map((t) => t.pval.aval);
|
|
@@ -3432,12 +3571,13 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3432
3571
|
*
|
|
3433
3572
|
* Used when encountering a JitCall rule during the trace.
|
|
3434
3573
|
*/
|
|
3435
|
-
#partialEvalJaxpr(jaxpr, numConsts, tracers) {
|
|
3574
|
+
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3436
3575
|
jaxpr = jaxpr.flatten();
|
|
3437
3576
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3438
3577
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3439
3578
|
const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
|
|
3440
3579
|
const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3580
|
+
name: `${name}_peval`,
|
|
3441
3581
|
jaxpr: jaxpr1,
|
|
3442
3582
|
numConsts: 0
|
|
3443
3583
|
});
|
|
@@ -3449,13 +3589,17 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3449
3589
|
prim: Primitive.JitCall,
|
|
3450
3590
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3451
3591
|
params: {
|
|
3592
|
+
name: `${name}_resid`,
|
|
3452
3593
|
jaxpr: jaxpr2,
|
|
3453
3594
|
numConsts: 0
|
|
3454
3595
|
},
|
|
3455
3596
|
avalsOut: jaxpr2.outs.map((x) => x.aval),
|
|
3456
3597
|
tracerRefsOut: []
|
|
3457
3598
|
};
|
|
3458
|
-
const outs2 = jaxpr2.outs.map((x) =>
|
|
3599
|
+
const outs2 = jaxpr2.outs.map((x, i$1) => {
|
|
3600
|
+
if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
|
|
3601
|
+
return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
|
|
3602
|
+
});
|
|
3459
3603
|
recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
|
|
3460
3604
|
let i = 0;
|
|
3461
3605
|
let j = 0;
|
|
@@ -3539,13 +3683,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3539
3683
|
const [consts, constvars] = unzip2(constToVar.entries());
|
|
3540
3684
|
const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
|
|
3541
3685
|
const outVars = tracersOut.map((t) => tracerToVar.get(t));
|
|
3542
|
-
|
|
3686
|
+
let jaxpr = new Jaxpr(inBinders, eqns, outVars);
|
|
3543
3687
|
typecheckJaxpr(jaxpr);
|
|
3544
3688
|
for (const t of consts) t.ref;
|
|
3545
3689
|
for (const t of tracersIn) t.dispose();
|
|
3546
3690
|
for (const t of tracersOut) t.dispose();
|
|
3691
|
+
jaxpr = jaxpr.simplify();
|
|
3692
|
+
if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3547
3693
|
return {
|
|
3548
|
-
jaxpr
|
|
3694
|
+
jaxpr,
|
|
3549
3695
|
consts
|
|
3550
3696
|
};
|
|
3551
3697
|
}
|
|
@@ -3586,7 +3732,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
|
|
|
3586
3732
|
}
|
|
3587
3733
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
3588
3734
|
const eqn = jaxpr.eqns[i];
|
|
3589
|
-
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ?
|
|
3735
|
+
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
|
|
3590
3736
|
const cotangentsOut = eqn.outBinders.map(readCotangent);
|
|
3591
3737
|
const rule = transposeRules[eqn.primitive];
|
|
3592
3738
|
if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
|
|
@@ -3766,7 +3912,7 @@ const transposeRules = {
|
|
|
3766
3912
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
3767
3913
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
3768
3914
|
},
|
|
3769
|
-
[Primitive.JitCall](cts, args, { jaxpr }) {
|
|
3915
|
+
[Primitive.JitCall](cts, args, { name, jaxpr }) {
|
|
3770
3916
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
3771
3917
|
const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
|
|
3772
3918
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
@@ -3775,6 +3921,7 @@ const transposeRules = {
|
|
|
3775
3921
|
...residuals,
|
|
3776
3922
|
...cts
|
|
3777
3923
|
], {
|
|
3924
|
+
name: `${name}_t`,
|
|
3778
3925
|
jaxpr: newJaxpr,
|
|
3779
3926
|
numConsts: newConsts.length
|
|
3780
3927
|
});
|
|
@@ -3811,20 +3958,28 @@ function vjpFlat(f, primalsIn) {
|
|
|
3811
3958
|
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
3812
3959
|
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
3813
3960
|
};
|
|
3814
|
-
|
|
3961
|
+
const dispose$1 = () => {
|
|
3962
|
+
for (const c of consts) c.dispose();
|
|
3963
|
+
};
|
|
3964
|
+
return [
|
|
3965
|
+
primalsOut,
|
|
3966
|
+
fVjp,
|
|
3967
|
+
dispose$1
|
|
3968
|
+
];
|
|
3815
3969
|
}
|
|
3816
3970
|
function vjp$1(f, ...primalsIn) {
|
|
3817
3971
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3818
3972
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3819
|
-
const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3973
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3820
3974
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
3821
3975
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3822
|
-
const fVjp = (cotangentsOut) => {
|
|
3976
|
+
const fVjp = ((cotangentsOut) => {
|
|
3823
3977
|
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
3824
3978
|
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
3825
3979
|
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
3826
3980
|
return unflatten(inTree, cotangentsInFlat);
|
|
3827
|
-
};
|
|
3981
|
+
});
|
|
3982
|
+
fVjp.dispose = dispose$1;
|
|
3828
3983
|
return [primalsOut, fVjp];
|
|
3829
3984
|
}
|
|
3830
3985
|
function grad$1(f) {
|
|
@@ -3841,8 +3996,9 @@ function valueAndGrad$1(f) {
|
|
|
3841
3996
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3842
3997
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3843
3998
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3844
|
-
const [ct, ...rest] = fVjp(
|
|
3845
|
-
for (const r of rest)
|
|
3999
|
+
const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
|
|
4000
|
+
for (const r of rest) dispose(r);
|
|
4001
|
+
fVjp.dispose();
|
|
3846
4002
|
return [y, ct];
|
|
3847
4003
|
};
|
|
3848
4004
|
}
|
|
@@ -3850,7 +4006,13 @@ function jacrev$1(f) {
|
|
|
3850
4006
|
return function jacobianReverse(x) {
|
|
3851
4007
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
3852
4008
|
const [size$1] = x.shape;
|
|
3853
|
-
const pullback = (ct) =>
|
|
4009
|
+
const pullback = (ct) => {
|
|
4010
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
4011
|
+
y.dispose();
|
|
4012
|
+
const [ret] = fVjp(ct);
|
|
4013
|
+
fVjp.dispose();
|
|
4014
|
+
return ret;
|
|
4015
|
+
};
|
|
3854
4016
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3855
4017
|
};
|
|
3856
4018
|
}
|
|
@@ -3930,19 +4092,38 @@ __export(numpy_exports, {
|
|
|
3930
4092
|
DType: () => DType,
|
|
3931
4093
|
abs: () => abs,
|
|
3932
4094
|
absolute: () => absolute,
|
|
4095
|
+
acos: () => acos,
|
|
4096
|
+
acosh: () => acosh,
|
|
3933
4097
|
add: () => add,
|
|
3934
4098
|
allclose: () => allclose,
|
|
3935
4099
|
arange: () => arange,
|
|
4100
|
+
arccos: () => arccos,
|
|
4101
|
+
arccosh: () => arccosh,
|
|
4102
|
+
arcsinh: () => arcsinh,
|
|
4103
|
+
arctan: () => arctan,
|
|
4104
|
+
arctan2: () => arctan2,
|
|
4105
|
+
arctanh: () => arctanh,
|
|
3936
4106
|
argmax: () => argmax,
|
|
3937
4107
|
argmin: () => argmin,
|
|
3938
4108
|
array: () => array,
|
|
4109
|
+
asin: () => asin,
|
|
4110
|
+
asinh: () => asinh,
|
|
3939
4111
|
astype: () => astype,
|
|
4112
|
+
atan: () => atan,
|
|
4113
|
+
atan2: () => atan2,
|
|
4114
|
+
atanh: () => atanh,
|
|
3940
4115
|
bool: () => bool,
|
|
4116
|
+
broadcastArrays: () => broadcastArrays,
|
|
4117
|
+
broadcastShapes: () => broadcastShapes,
|
|
4118
|
+
broadcastTo: () => broadcastTo,
|
|
4119
|
+
cbrt: () => cbrt,
|
|
3941
4120
|
clip: () => clip,
|
|
3942
4121
|
columnStack: () => columnStack,
|
|
3943
4122
|
concatenate: () => concatenate,
|
|
3944
4123
|
cos: () => cos,
|
|
3945
4124
|
cosh: () => cosh,
|
|
4125
|
+
deg2rad: () => deg2rad,
|
|
4126
|
+
degrees: () => degrees,
|
|
3946
4127
|
diag: () => diag,
|
|
3947
4128
|
diagonal: () => diagonal,
|
|
3948
4129
|
divide: () => divide,
|
|
@@ -3953,6 +4134,7 @@ __export(numpy_exports, {
|
|
|
3953
4134
|
eulerGamma: () => eulerGamma,
|
|
3954
4135
|
exp: () => exp,
|
|
3955
4136
|
exp2: () => exp2,
|
|
4137
|
+
expm1: () => expm1,
|
|
3956
4138
|
eye: () => eye,
|
|
3957
4139
|
flip: () => flip,
|
|
3958
4140
|
fliplr: () => fliplr,
|
|
@@ -3964,14 +4146,17 @@ __export(numpy_exports, {
|
|
|
3964
4146
|
greater: () => greater,
|
|
3965
4147
|
greaterEqual: () => greaterEqual,
|
|
3966
4148
|
hstack: () => hstack,
|
|
4149
|
+
hypot: () => hypot,
|
|
3967
4150
|
identity: () => identity$1,
|
|
3968
4151
|
inf: () => inf,
|
|
4152
|
+
inner: () => inner,
|
|
3969
4153
|
int32: () => int32,
|
|
3970
4154
|
less: () => less,
|
|
3971
4155
|
lessEqual: () => lessEqual,
|
|
3972
4156
|
linspace: () => linspace,
|
|
3973
4157
|
log: () => log,
|
|
3974
4158
|
log10: () => log10,
|
|
4159
|
+
log1p: () => log1p,
|
|
3975
4160
|
log2: () => log2,
|
|
3976
4161
|
matmul: () => matmul,
|
|
3977
4162
|
max: () => max,
|
|
@@ -3987,35 +4172,49 @@ __export(numpy_exports, {
|
|
|
3987
4172
|
negative: () => negative,
|
|
3988
4173
|
notEqual: () => notEqual,
|
|
3989
4174
|
ones: () => ones,
|
|
3990
|
-
onesLike: () => onesLike
|
|
4175
|
+
onesLike: () => onesLike,
|
|
4176
|
+
outer: () => outer,
|
|
3991
4177
|
pad: () => pad,
|
|
3992
4178
|
permuteDims: () => permuteDims,
|
|
3993
4179
|
pi: () => pi,
|
|
4180
|
+
pow: () => pow,
|
|
4181
|
+
power: () => power,
|
|
3994
4182
|
prod: () => prod$1,
|
|
4183
|
+
promoteTypes: () => promoteTypes,
|
|
4184
|
+
rad2deg: () => rad2deg,
|
|
4185
|
+
radians: () => radians,
|
|
3995
4186
|
ravel: () => ravel,
|
|
3996
4187
|
reciprocal: () => reciprocal,
|
|
4188
|
+
repeat: () => repeat,
|
|
3997
4189
|
reshape: () => reshape,
|
|
3998
|
-
scalar: () => scalar,
|
|
3999
4190
|
shape: () => shape,
|
|
4191
|
+
sign: () => sign,
|
|
4000
4192
|
sin: () => sin,
|
|
4001
4193
|
sinh: () => sinh,
|
|
4002
4194
|
size: () => size,
|
|
4003
4195
|
sqrt: () => sqrt,
|
|
4004
4196
|
square: () => square,
|
|
4005
4197
|
stack: () => stack,
|
|
4198
|
+
std: () => std,
|
|
4199
|
+
subtract: () => subtract,
|
|
4006
4200
|
sum: () => sum,
|
|
4007
4201
|
tan: () => tan,
|
|
4008
4202
|
tanh: () => tanh,
|
|
4203
|
+
tile: () => tile,
|
|
4009
4204
|
transpose: () => transpose,
|
|
4205
|
+
tri: () => tri,
|
|
4206
|
+
tril: () => tril,
|
|
4207
|
+
triu: () => triu,
|
|
4010
4208
|
trueDivide: () => trueDivide,
|
|
4011
4209
|
trunc: () => trunc,
|
|
4012
4210
|
uint32: () => uint32,
|
|
4211
|
+
var_: () => var_,
|
|
4013
4212
|
vdot: () => vdot,
|
|
4014
4213
|
vecdot: () => vecdot,
|
|
4015
4214
|
vstack: () => vstack,
|
|
4016
4215
|
where: () => where,
|
|
4017
4216
|
zeros: () => zeros,
|
|
4018
|
-
zerosLike: () => zerosLike
|
|
4217
|
+
zerosLike: () => zerosLike
|
|
4019
4218
|
});
|
|
4020
4219
|
const float32 = DType.Float32;
|
|
4021
4220
|
const int32 = DType.Int32;
|
|
@@ -4032,54 +4231,66 @@ const inf = Number.POSITIVE_INFINITY;
|
|
|
4032
4231
|
const nan = NaN;
|
|
4033
4232
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
4034
4233
|
const pi = Math.PI;
|
|
4035
|
-
/** Element-wise addition, with broadcasting. */
|
|
4234
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
4036
4235
|
const add = add$1;
|
|
4037
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
4236
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
4038
4237
|
const multiply = mul;
|
|
4039
|
-
/** Numerical negative of every element of an array. */
|
|
4238
|
+
/** @function Numerical negative of every element of an array. */
|
|
4040
4239
|
const negative = neg;
|
|
4041
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4240
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4042
4241
|
const reciprocal = reciprocal$1;
|
|
4043
|
-
/** Element-wise sine function (takes radians). */
|
|
4242
|
+
/** @function Element-wise sine function (takes radians). */
|
|
4044
4243
|
const sin = sin$1;
|
|
4045
|
-
/** Element-wise cosine function (takes radians). */
|
|
4244
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
4046
4245
|
const cos = cos$1;
|
|
4047
|
-
/**
|
|
4246
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
4247
|
+
const asin = asin$1;
|
|
4248
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
4249
|
+
const atan = atan$1;
|
|
4250
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
4048
4251
|
const exp = exp$1;
|
|
4049
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
4252
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
4050
4253
|
const log = log$1;
|
|
4051
|
-
/** Calculate the square root of all elements in the input array. */
|
|
4254
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
4052
4255
|
const sqrt = sqrt$1;
|
|
4053
|
-
/** Return element-wise minimum of the input arrays. */
|
|
4256
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
4054
4257
|
const minimum = min$1;
|
|
4055
|
-
/** Return element-wise maximum of the input arrays. */
|
|
4258
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
4056
4259
|
const maximum = max$1;
|
|
4057
|
-
/** Compare two arrays element-wise. */
|
|
4260
|
+
/** @function Compare two arrays element-wise. */
|
|
4058
4261
|
const greater = greater$1;
|
|
4059
|
-
/** Compare two arrays element-wise. */
|
|
4262
|
+
/** @function Compare two arrays element-wise. */
|
|
4060
4263
|
const less = less$1;
|
|
4061
|
-
/** Compare two arrays element-wise. */
|
|
4264
|
+
/** @function Compare two arrays element-wise. */
|
|
4062
4265
|
const equal = equal$1;
|
|
4063
|
-
/** Compare two arrays element-wise. */
|
|
4266
|
+
/** @function Compare two arrays element-wise. */
|
|
4064
4267
|
const notEqual = notEqual$1;
|
|
4065
|
-
/** Compare two arrays element-wise. */
|
|
4268
|
+
/** @function Compare two arrays element-wise. */
|
|
4066
4269
|
const greaterEqual = greaterEqual$1;
|
|
4067
|
-
/** Compare two arrays element-wise. */
|
|
4270
|
+
/** @function Compare two arrays element-wise. */
|
|
4068
4271
|
const lessEqual = lessEqual$1;
|
|
4069
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4272
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4070
4273
|
const where = where$1;
|
|
4071
|
-
/**
|
|
4274
|
+
/**
|
|
4275
|
+
* @function
|
|
4276
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
4277
|
+
*/
|
|
4072
4278
|
const transpose = transpose$1;
|
|
4073
4279
|
/**
|
|
4280
|
+
* @function
|
|
4074
4281
|
* Give a new shape to an array without changing its data.
|
|
4075
4282
|
*
|
|
4076
4283
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
4077
4284
|
* length of the array and remaining dimensions.
|
|
4078
4285
|
*/
|
|
4079
4286
|
const reshape = reshape$1;
|
|
4080
|
-
/**
|
|
4287
|
+
/**
|
|
4288
|
+
* @function
|
|
4289
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
4290
|
+
*/
|
|
4081
4291
|
const moveaxis = moveaxis$1;
|
|
4082
4292
|
/**
|
|
4293
|
+
* @function
|
|
4083
4294
|
* Add padding (zeros) to an array.
|
|
4084
4295
|
*
|
|
4085
4296
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -4087,15 +4298,27 @@ const moveaxis = moveaxis$1;
|
|
|
4087
4298
|
* pair specifies the padding for its corresponding axis.
|
|
4088
4299
|
*/
|
|
4089
4300
|
const pad = pad$1;
|
|
4090
|
-
/**
|
|
4301
|
+
/**
|
|
4302
|
+
* @function
|
|
4303
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
4304
|
+
*/
|
|
4091
4305
|
const ndim = ndim$1;
|
|
4092
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
4306
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
4093
4307
|
const shape = getShape;
|
|
4094
|
-
/**
|
|
4095
|
-
|
|
4096
|
-
|
|
4097
|
-
|
|
4098
|
-
|
|
4308
|
+
/**
|
|
4309
|
+
* @function
|
|
4310
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
4311
|
+
*/
|
|
4312
|
+
const zerosLike = zerosLike$1;
|
|
4313
|
+
/**
|
|
4314
|
+
* @function
|
|
4315
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
4316
|
+
*/
|
|
4317
|
+
const onesLike = onesLike$1;
|
|
4318
|
+
/**
|
|
4319
|
+
* @function
|
|
4320
|
+
* Return a full array with the same shape and type as a given array.
|
|
4321
|
+
*/
|
|
4099
4322
|
const fullLike$1 = fullLike;
|
|
4100
4323
|
/**
|
|
4101
4324
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -4110,23 +4333,23 @@ function astype(a, dtype) {
|
|
|
4110
4333
|
return fudgeArray(a).astype(dtype);
|
|
4111
4334
|
}
|
|
4112
4335
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
4113
|
-
function sum(a, axis, opts) {
|
|
4336
|
+
function sum(a, axis = null, opts) {
|
|
4114
4337
|
return reduce(a, AluOp.Add, axis, opts);
|
|
4115
4338
|
}
|
|
4116
4339
|
/** Product of the array elements over a given axis. */
|
|
4117
|
-
function prod$1(a, axis, opts) {
|
|
4340
|
+
function prod$1(a, axis = null, opts) {
|
|
4118
4341
|
return reduce(a, AluOp.Mul, axis, opts);
|
|
4119
4342
|
}
|
|
4120
4343
|
/** Return the minimum of array elements along a given axis. */
|
|
4121
|
-
function min(a, axis, opts) {
|
|
4344
|
+
function min(a, axis = null, opts) {
|
|
4122
4345
|
return reduce(a, AluOp.Min, axis, opts);
|
|
4123
4346
|
}
|
|
4124
4347
|
/** Return the maximum of array elements along a given axis. */
|
|
4125
|
-
function max(a, axis, opts) {
|
|
4348
|
+
function max(a, axis = null, opts) {
|
|
4126
4349
|
return reduce(a, AluOp.Max, axis, opts);
|
|
4127
4350
|
}
|
|
4128
4351
|
/** Compute the average of the array elements along the specified axis. */
|
|
4129
|
-
function mean(a, axis, opts) {
|
|
4352
|
+
function mean(a, axis = null, opts) {
|
|
4130
4353
|
return fudgeArray(a).mean(axis, opts);
|
|
4131
4354
|
}
|
|
4132
4355
|
/**
|
|
@@ -4142,8 +4365,8 @@ function argmin(a, axis, opts) {
|
|
|
4142
4365
|
axis = 0;
|
|
4143
4366
|
} else axis = checkAxis(axis, a.ndim);
|
|
4144
4367
|
const shape$1 = a.shape;
|
|
4145
|
-
const isMax = equal(a, min(a.ref, axis, {
|
|
4146
|
-
const length =
|
|
4368
|
+
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4369
|
+
const length = array(shape$1[axis], {
|
|
4147
4370
|
dtype: int32,
|
|
4148
4371
|
device: a.device
|
|
4149
4372
|
});
|
|
@@ -4166,8 +4389,8 @@ function argmax(a, axis, opts) {
|
|
|
4166
4389
|
axis = 0;
|
|
4167
4390
|
} else axis = checkAxis(axis, a.ndim);
|
|
4168
4391
|
const shape$1 = a.shape;
|
|
4169
|
-
const isMax = equal(a, max(a.ref, axis, {
|
|
4170
|
-
const length =
|
|
4392
|
+
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4393
|
+
const length = array(shape$1[axis], {
|
|
4171
4394
|
dtype: int32,
|
|
4172
4395
|
device: a.device
|
|
4173
4396
|
});
|
|
@@ -4178,17 +4401,9 @@ function argmax(a, axis, opts) {
|
|
|
4178
4401
|
return length.sub(max(idx, axis, opts));
|
|
4179
4402
|
}
|
|
4180
4403
|
/** Reverse the elements in an array along the given axes. */
|
|
4181
|
-
function flip(x, axis) {
|
|
4404
|
+
function flip(x, axis = null) {
|
|
4182
4405
|
const nd = ndim(x);
|
|
4183
|
-
|
|
4184
|
-
else if (typeof axis === "number") axis = [axis];
|
|
4185
|
-
const seen = /* @__PURE__ */ new Set();
|
|
4186
|
-
for (let i = 0; i < axis.length; i++) {
|
|
4187
|
-
if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
|
|
4188
|
-
if (axis[i] < 0) axis[i] += nd;
|
|
4189
|
-
if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
|
|
4190
|
-
seen.add(axis[i]);
|
|
4191
|
-
}
|
|
4406
|
+
axis = normalizeAxis(axis, nd);
|
|
4192
4407
|
return flip$1(x, axis);
|
|
4193
4408
|
}
|
|
4194
4409
|
/**
|
|
@@ -4294,12 +4509,80 @@ function flipud(x) {
|
|
|
4294
4509
|
function fliplr(x) {
|
|
4295
4510
|
return flip(x, 1);
|
|
4296
4511
|
}
|
|
4512
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
4297
4513
|
const permuteDims = transpose;
|
|
4298
4514
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
4299
4515
|
function ravel(a) {
|
|
4300
4516
|
return fudgeArray(a).ravel();
|
|
4301
4517
|
}
|
|
4302
4518
|
/**
|
|
4519
|
+
* Repeat each element of an array after themselves.
|
|
4520
|
+
*
|
|
4521
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
4522
|
+
* output array.
|
|
4523
|
+
*/
|
|
4524
|
+
function repeat(a, repeats, axis) {
|
|
4525
|
+
if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
|
|
4526
|
+
a = fudgeArray(a);
|
|
4527
|
+
if (axis === void 0) {
|
|
4528
|
+
a = ravel(a);
|
|
4529
|
+
axis = 0;
|
|
4530
|
+
}
|
|
4531
|
+
axis = checkAxis(axis, a.ndim);
|
|
4532
|
+
if (repeats === 1) return a;
|
|
4533
|
+
const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
|
|
4534
|
+
const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
|
|
4535
|
+
return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
|
|
4536
|
+
}
|
|
4537
|
+
/**
|
|
4538
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
4539
|
+
*
|
|
4540
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
4541
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
4542
|
+
* reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
|
|
4543
|
+
*/
|
|
4544
|
+
function tile(a, reps) {
|
|
4545
|
+
a = fudgeArray(a);
|
|
4546
|
+
if (typeof reps === "number") reps = [reps];
|
|
4547
|
+
if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
|
|
4548
|
+
const ndiff = reps.length - a.ndim;
|
|
4549
|
+
if (ndiff > 0) a = a.reshape([...rep(ndiff, 1), ...a.shape]);
|
|
4550
|
+
if (ndiff < 0) reps = [...rep(-ndiff, 1), ...reps];
|
|
4551
|
+
const broadcastedShape = [];
|
|
4552
|
+
const broadcastAxes = [];
|
|
4553
|
+
for (let i = 0; i < a.ndim; i++) {
|
|
4554
|
+
if (reps[i] > 1) {
|
|
4555
|
+
broadcastedShape.push(reps[i]);
|
|
4556
|
+
broadcastAxes.push(broadcastedShape.length - 1);
|
|
4557
|
+
}
|
|
4558
|
+
broadcastedShape.push(a.shape[i]);
|
|
4559
|
+
}
|
|
4560
|
+
const finalShape = a.shape.map((d, i) => reps[i] * d);
|
|
4561
|
+
return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
|
|
4562
|
+
}
|
|
4563
|
+
/**
|
|
4564
|
+
* Broadcast an array to a shape, with NumPy-style broadcasing rules.
|
|
4565
|
+
*
|
|
4566
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
4567
|
+
* dimensions where the shape is 1.
|
|
4568
|
+
*/
|
|
4569
|
+
function broadcastTo(a, shape$1) {
|
|
4570
|
+
const nd = ndim(a);
|
|
4571
|
+
if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
|
|
4572
|
+
return broadcast(a, shape$1, range(shape$1.length - nd));
|
|
4573
|
+
}
|
|
4574
|
+
/** Broadcast input shapes to a common output shape. */
|
|
4575
|
+
function broadcastShapes(...shapes) {
|
|
4576
|
+
if (shapes.length === 0) return [];
|
|
4577
|
+
return shapes.reduce(generalBroadcast);
|
|
4578
|
+
}
|
|
4579
|
+
/** Broadcast arrays to a common shape. */
|
|
4580
|
+
function broadcastArrays(...arrays) {
|
|
4581
|
+
const shapes = arrays.map((a) => shape(a));
|
|
4582
|
+
const outShape = broadcastShapes(...shapes);
|
|
4583
|
+
return arrays.map((a) => broadcastTo(a, outShape));
|
|
4584
|
+
}
|
|
4585
|
+
/**
|
|
4303
4586
|
* Return specified diagonals.
|
|
4304
4587
|
*
|
|
4305
4588
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
@@ -4323,7 +4606,7 @@ function diag(v, k = 0) {
|
|
|
4323
4606
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
4324
4607
|
if (a.ndim === 1) {
|
|
4325
4608
|
const n = a.shape[0];
|
|
4326
|
-
const ret = where(eye(n).equal(1), a.ref, zerosLike
|
|
4609
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
4327
4610
|
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4328
4611
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4329
4612
|
else return ret;
|
|
@@ -4367,8 +4650,36 @@ function dot(x, y) {
|
|
|
4367
4650
|
]);
|
|
4368
4651
|
return dot$1(x, y);
|
|
4369
4652
|
}
|
|
4370
|
-
/**
|
|
4371
|
-
|
|
4653
|
+
/**
|
|
4654
|
+
* Compute the inner product of two arrays.
|
|
4655
|
+
*
|
|
4656
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
4657
|
+
* contraction on the last axis.
|
|
4658
|
+
*
|
|
4659
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
4660
|
+
*/
|
|
4661
|
+
function inner(x, y) {
|
|
4662
|
+
x = reshape(x, shape(x).toSpliced(-1, 0, ...rep(ndim(y) - 1, 1)));
|
|
4663
|
+
return dot$1(x, y);
|
|
4664
|
+
}
|
|
4665
|
+
/**
|
|
4666
|
+
* Compute the outer product of two arrays.
|
|
4667
|
+
*
|
|
4668
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
4669
|
+
* be of shape `[x.size, y.size]`.
|
|
4670
|
+
*/
|
|
4671
|
+
function outer(x, y) {
|
|
4672
|
+
x = ravel(x);
|
|
4673
|
+
y = ravel(y);
|
|
4674
|
+
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
4675
|
+
}
|
|
4676
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
4677
|
+
function vecdot(x, y, { axis } = {}) {
|
|
4678
|
+
const xaxis = checkAxis(axis ?? -1, ndim(x));
|
|
4679
|
+
const yaxis = checkAxis(axis ?? -1, ndim(y));
|
|
4680
|
+
if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
|
|
4681
|
+
x = moveaxis(x, xaxis, -1);
|
|
4682
|
+
y = moveaxis(y, yaxis, -1);
|
|
4372
4683
|
return dot$1(x, y);
|
|
4373
4684
|
}
|
|
4374
4685
|
/**
|
|
@@ -4377,7 +4688,7 @@ function vecdot(x, y) {
|
|
|
4377
4688
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
4378
4689
|
*/
|
|
4379
4690
|
function vdot(x, y) {
|
|
4380
|
-
return
|
|
4691
|
+
return dot$1(ravel(x), ravel(y));
|
|
4381
4692
|
}
|
|
4382
4693
|
/**
|
|
4383
4694
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
@@ -4406,6 +4717,43 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
4406
4717
|
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
4407
4718
|
}
|
|
4408
4719
|
/**
|
|
4720
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
4721
|
+
*
|
|
4722
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
4723
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
4724
|
+
* `k>0` is above it.
|
|
4725
|
+
*/
|
|
4726
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
4727
|
+
m ??= n;
|
|
4728
|
+
dtype ??= DType.Float32;
|
|
4729
|
+
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
4730
|
+
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
4731
|
+
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
4732
|
+
const rows = arange(k, n + k, 1, {
|
|
4733
|
+
dtype: DType.Int32,
|
|
4734
|
+
device
|
|
4735
|
+
});
|
|
4736
|
+
const cols = arange(0, m, 1, {
|
|
4737
|
+
dtype: DType.Int32,
|
|
4738
|
+
device
|
|
4739
|
+
});
|
|
4740
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
4741
|
+
}
|
|
4742
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
4743
|
+
function tril(a, k = 0) {
|
|
4744
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4745
|
+
a = fudgeArray(a);
|
|
4746
|
+
const [n, m] = a.shape.slice(-2);
|
|
4747
|
+
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
4748
|
+
}
|
|
4749
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
4750
|
+
function triu(a, k = 0) {
|
|
4751
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4752
|
+
a = fudgeArray(a);
|
|
4753
|
+
const [n, m] = a.shape.slice(-2);
|
|
4754
|
+
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
4755
|
+
}
|
|
4756
|
+
/**
|
|
4409
4757
|
* Clip (limit) the values in an array.
|
|
4410
4758
|
*
|
|
4411
4759
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -4429,18 +4777,70 @@ function absolute(x) {
|
|
|
4429
4777
|
x = fudgeArray(x);
|
|
4430
4778
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
4431
4779
|
}
|
|
4432
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
4780
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
4433
4781
|
const abs = absolute;
|
|
4782
|
+
/** Return an element-wise indication of sign of the input. */
|
|
4783
|
+
function sign(x) {
|
|
4784
|
+
x = fudgeArray(x);
|
|
4785
|
+
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4786
|
+
}
|
|
4434
4787
|
/** Calculate element-wise square of the input array. */
|
|
4435
4788
|
function square(x) {
|
|
4436
4789
|
x = fudgeArray(x);
|
|
4437
4790
|
return x.ref.mul(x);
|
|
4438
4791
|
}
|
|
4439
|
-
/**
|
|
4792
|
+
/** Element-wise tangent function (takes radians). */
|
|
4440
4793
|
function tan(x) {
|
|
4441
4794
|
x = fudgeArray(x);
|
|
4442
4795
|
return sin(x.ref).div(cos(x));
|
|
4443
4796
|
}
|
|
4797
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
4798
|
+
function acos(x) {
|
|
4799
|
+
return subtract(pi / 2, asin(x));
|
|
4800
|
+
}
|
|
4801
|
+
/**
|
|
4802
|
+
* @function
|
|
4803
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4804
|
+
*
|
|
4805
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4806
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4807
|
+
* improvements.
|
|
4808
|
+
*/
|
|
4809
|
+
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4810
|
+
return sqrt(square(x1).add(square(x2)));
|
|
4811
|
+
});
|
|
4812
|
+
/**
|
|
4813
|
+
* @function
|
|
4814
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
4815
|
+
*
|
|
4816
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
4817
|
+
* The result is in the range [-π, π].
|
|
4818
|
+
*
|
|
4819
|
+
* Uses numerically stable formulas:
|
|
4820
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
4821
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
4822
|
+
*
|
|
4823
|
+
* The output is ill-defined when both x and y are zero.
|
|
4824
|
+
*/
|
|
4825
|
+
const atan2 = jit$1(function atan2$1(y, x) {
|
|
4826
|
+
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4827
|
+
const xNeg = less(x.ref, 0);
|
|
4828
|
+
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
4829
|
+
const denom = where(xNeg, y, r.add(x));
|
|
4830
|
+
return atan(numer.div(denom)).mul(2);
|
|
4831
|
+
});
|
|
4832
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
4833
|
+
const arccos = acos;
|
|
4834
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
4835
|
+
const arctan = atan;
|
|
4836
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
4837
|
+
const arctan2 = atan2;
|
|
4838
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
4839
|
+
function subtract(x, y) {
|
|
4840
|
+
x = fudgeArray(x);
|
|
4841
|
+
y = fudgeArray(y);
|
|
4842
|
+
return x.sub(y);
|
|
4843
|
+
}
|
|
4444
4844
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
4445
4845
|
function trueDivide(x, y) {
|
|
4446
4846
|
x = fudgeArray(x);
|
|
@@ -4448,7 +4848,7 @@ function trueDivide(x, y) {
|
|
|
4448
4848
|
if (!isFloatDtype(x.dtype) || !isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
|
|
4449
4849
|
return x.div(y);
|
|
4450
4850
|
}
|
|
4451
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
4851
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
4452
4852
|
const divide = trueDivide;
|
|
4453
4853
|
/** Round input to the nearest integer towards zero. */
|
|
4454
4854
|
function trunc(x) {
|
|
@@ -4466,36 +4866,134 @@ function log2(x) {
|
|
|
4466
4866
|
function log10(x) {
|
|
4467
4867
|
return log(x).mul(Math.LOG10E);
|
|
4468
4868
|
}
|
|
4869
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
4870
|
+
function expm1(x) {
|
|
4871
|
+
return exp(x).sub(1);
|
|
4872
|
+
}
|
|
4873
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
4874
|
+
function log1p(x) {
|
|
4875
|
+
return log(add(1, x));
|
|
4876
|
+
}
|
|
4877
|
+
/** Convert angles from degrees to radians. */
|
|
4878
|
+
function deg2rad(x) {
|
|
4879
|
+
return multiply(x, pi / 180);
|
|
4880
|
+
}
|
|
4881
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
4882
|
+
const radians = deg2rad;
|
|
4883
|
+
/** Convert angles from radians to degrees. */
|
|
4884
|
+
function rad2deg(x) {
|
|
4885
|
+
return multiply(x, 180 / pi);
|
|
4886
|
+
}
|
|
4887
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
4888
|
+
const degrees = rad2deg;
|
|
4889
|
+
/**
|
|
4890
|
+
* @function
|
|
4891
|
+
* Computes first array raised to power of second array, element-wise.
|
|
4892
|
+
*/
|
|
4893
|
+
const power = jit$1(function power$1(x1, x2) {
|
|
4894
|
+
return exp(log(x1).mul(x2));
|
|
4895
|
+
});
|
|
4896
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
4897
|
+
const pow = power;
|
|
4898
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
4899
|
+
const cbrt = jit$1(function cbrt$1(x) {
|
|
4900
|
+
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4901
|
+
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4902
|
+
});
|
|
4469
4903
|
/**
|
|
4904
|
+
* @function
|
|
4470
4905
|
* Calculate element-wise hyperbolic sine of input.
|
|
4471
4906
|
*
|
|
4472
4907
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4473
4908
|
*/
|
|
4474
|
-
function sinh(x) {
|
|
4909
|
+
const sinh = jit$1(function sinh$1(x) {
|
|
4475
4910
|
const ex = exp(x);
|
|
4476
4911
|
const emx = reciprocal(ex.ref);
|
|
4477
4912
|
return ex.sub(emx).mul(.5);
|
|
4478
|
-
}
|
|
4913
|
+
});
|
|
4479
4914
|
/**
|
|
4915
|
+
* @function
|
|
4480
4916
|
* Calculate element-wise hyperbolic cosine of input.
|
|
4481
4917
|
*
|
|
4482
4918
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4483
4919
|
*/
|
|
4484
|
-
function cosh(x) {
|
|
4920
|
+
const cosh = jit$1(function cosh$1(x) {
|
|
4485
4921
|
const ex = exp(x);
|
|
4486
4922
|
const emx = reciprocal(ex.ref);
|
|
4487
4923
|
return ex.add(emx).mul(.5);
|
|
4488
|
-
}
|
|
4924
|
+
});
|
|
4489
4925
|
/**
|
|
4926
|
+
* @function
|
|
4490
4927
|
* Calculate element-wise hyperbolic tangent of input.
|
|
4491
4928
|
*
|
|
4492
4929
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4493
4930
|
*/
|
|
4494
|
-
function tanh(x) {
|
|
4495
|
-
x = fudgeArray(x);
|
|
4931
|
+
const tanh = jit$1(function tanh$1(x) {
|
|
4496
4932
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4497
4933
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4498
4934
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
4935
|
+
});
|
|
4936
|
+
/**
|
|
4937
|
+
* @function
|
|
4938
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
4939
|
+
*
|
|
4940
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4941
|
+
*/
|
|
4942
|
+
const arcsinh = jit$1(function arcsinh$1(x) {
|
|
4943
|
+
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4944
|
+
});
|
|
4945
|
+
/**
|
|
4946
|
+
* @function
|
|
4947
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
4948
|
+
*
|
|
4949
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4950
|
+
*/
|
|
4951
|
+
const arccosh = jit$1(function arccosh$1(x) {
|
|
4952
|
+
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4953
|
+
});
|
|
4954
|
+
/**
|
|
4955
|
+
* @function
|
|
4956
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
4957
|
+
*
|
|
4958
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4959
|
+
*/
|
|
4960
|
+
const arctanh = jit$1(function arctanh$1(x) {
|
|
4961
|
+
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4962
|
+
});
|
|
4963
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
4964
|
+
const asinh = arcsinh;
|
|
4965
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
4966
|
+
const acosh = arccosh;
|
|
4967
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
4968
|
+
const atanh = arctanh;
|
|
4969
|
+
/**
|
|
4970
|
+
* Compute the variance of an array.
|
|
4971
|
+
*
|
|
4972
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
4973
|
+
* the specified axis.
|
|
4974
|
+
*
|
|
4975
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4976
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4977
|
+
*/
|
|
4978
|
+
function var_(x, axis = null, opts) {
|
|
4979
|
+
x = fudgeArray(x);
|
|
4980
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
4981
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
4982
|
+
if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
|
|
4983
|
+
const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
|
|
4984
|
+
return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
|
|
4985
|
+
}
|
|
4986
|
+
/**
|
|
4987
|
+
* Compute the standard deviation of an array.
|
|
4988
|
+
*
|
|
4989
|
+
* The standard deviation is computed for the flattened array by default,
|
|
4990
|
+
* otherwise over the specified axis.
|
|
4991
|
+
*
|
|
4992
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4993
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4994
|
+
*/
|
|
4995
|
+
function std(x, axis = null, opts) {
|
|
4996
|
+
return sqrt(var_(x, axis, opts));
|
|
4499
4997
|
}
|
|
4500
4998
|
|
|
4501
4999
|
//#endregion
|
|
@@ -4510,6 +5008,7 @@ __export(nn_exports, {
|
|
|
4510
5008
|
leakyRelu: () => leakyRelu,
|
|
4511
5009
|
logSigmoid: () => logSigmoid,
|
|
4512
5010
|
logSoftmax: () => logSoftmax,
|
|
5011
|
+
logmeanexp: () => logmeanexp,
|
|
4513
5012
|
logsumexp: () => logsumexp,
|
|
4514
5013
|
mish: () => mish,
|
|
4515
5014
|
oneHot: () => oneHot,
|
|
@@ -4520,6 +5019,8 @@ __export(nn_exports, {
|
|
|
4520
5019
|
softSign: () => softSign,
|
|
4521
5020
|
softmax: () => softmax,
|
|
4522
5021
|
softplus: () => softplus,
|
|
5022
|
+
squareplus: () => squareplus,
|
|
5023
|
+
standardize: () => standardize,
|
|
4523
5024
|
swish: () => swish
|
|
4524
5025
|
});
|
|
4525
5026
|
/**
|
|
@@ -4563,6 +5064,7 @@ function softSign(x) {
|
|
|
4563
5064
|
return x.ref.div(absolute(x).add(1));
|
|
4564
5065
|
}
|
|
4565
5066
|
/**
|
|
5067
|
+
* @function
|
|
4566
5068
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4567
5069
|
* Swish, computed element-wise:
|
|
4568
5070
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4571,8 +5073,11 @@ function softSign(x) {
|
|
|
4571
5073
|
*
|
|
4572
5074
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4573
5075
|
*/
|
|
4574
|
-
const silu = jit$1((x)
|
|
5076
|
+
const silu = jit$1(function silu$1(x) {
|
|
5077
|
+
return x.ref.mul(sigmoid(x));
|
|
5078
|
+
});
|
|
4575
5079
|
/**
|
|
5080
|
+
* @function
|
|
4576
5081
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4577
5082
|
* Swish, computed element-wise:
|
|
4578
5083
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4589,7 +5094,10 @@ const swish = silu;
|
|
|
4589
5094
|
function logSigmoid(x) {
|
|
4590
5095
|
return negative(softplus(negative(x)));
|
|
4591
5096
|
}
|
|
4592
|
-
/**
|
|
5097
|
+
/**
|
|
5098
|
+
* @function
|
|
5099
|
+
* Identity activation function. Returns the argument unmodified.
|
|
5100
|
+
*/
|
|
4593
5101
|
const identity = fudgeArray;
|
|
4594
5102
|
/** Leaky rectified linear (ReLU) activation function */
|
|
4595
5103
|
function leakyRelu(x, negativeSlope = .01) {
|
|
@@ -4617,6 +5125,7 @@ function celu(x, alpha = 1) {
|
|
|
4617
5125
|
return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
|
|
4618
5126
|
}
|
|
4619
5127
|
/**
|
|
5128
|
+
* @function
|
|
4620
5129
|
* Gaussion error linear unit (GELU) activation function.
|
|
4621
5130
|
*
|
|
4622
5131
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -4627,7 +5136,7 @@ function celu(x, alpha = 1) {
|
|
|
4627
5136
|
*
|
|
4628
5137
|
* This will be improved in the future.
|
|
4629
5138
|
*/
|
|
4630
|
-
const gelu = jit$1((x)
|
|
5139
|
+
const gelu = jit$1(function gelu$1(x) {
|
|
4631
5140
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
4632
5141
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
4633
5142
|
});
|
|
@@ -4648,6 +5157,16 @@ function glu(x, axis = -1) {
|
|
|
4648
5157
|
return a.mul(sigmoid(b));
|
|
4649
5158
|
}
|
|
4650
5159
|
/**
|
|
5160
|
+
* Squareplus activation function.
|
|
5161
|
+
*
|
|
5162
|
+
* Computes the element-wise function:
|
|
5163
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
5164
|
+
*/
|
|
5165
|
+
function squareplus(x, b = 4) {
|
|
5166
|
+
x = fudgeArray(x);
|
|
5167
|
+
return x.ref.add(sqrt(square(x).add(b))).mul(.5);
|
|
5168
|
+
}
|
|
5169
|
+
/**
|
|
4651
5170
|
* Mish activation function.
|
|
4652
5171
|
*
|
|
4653
5172
|
* Computes the element-wise function:
|
|
@@ -4665,17 +5184,13 @@ function mish(x) {
|
|
|
4665
5184
|
*
|
|
4666
5185
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
4667
5186
|
*/
|
|
4668
|
-
function softmax(x, axis) {
|
|
5187
|
+
function softmax(x, axis = -1) {
|
|
4669
5188
|
x = fudgeArray(x);
|
|
4670
|
-
|
|
4671
|
-
|
|
4672
|
-
|
|
4673
|
-
x.dispose();
|
|
4674
|
-
return ones(x.shape);
|
|
4675
|
-
}
|
|
4676
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5189
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5190
|
+
if (axis.length === 0) return onesLike(x);
|
|
5191
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4677
5192
|
const unnormalized = exp(x.sub(stopGradient(xMax)));
|
|
4678
|
-
return unnormalized.ref.div(unnormalized.sum(axis, {
|
|
5193
|
+
return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
|
|
4679
5194
|
}
|
|
4680
5195
|
/**
|
|
4681
5196
|
* Log-Softmax function.
|
|
@@ -4685,17 +5200,13 @@ function softmax(x, axis) {
|
|
|
4685
5200
|
*
|
|
4686
5201
|
* If `axis` is not specified, it defaults to the last axis.
|
|
4687
5202
|
*/
|
|
4688
|
-
function logSoftmax(x, axis) {
|
|
5203
|
+
function logSoftmax(x, axis = -1) {
|
|
4689
5204
|
x = fudgeArray(x);
|
|
4690
|
-
|
|
4691
|
-
|
|
4692
|
-
|
|
4693
|
-
x.dispose();
|
|
4694
|
-
return zeros(x.shape);
|
|
4695
|
-
}
|
|
4696
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5205
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5206
|
+
if (axis.length === 0) return zerosLike(x);
|
|
5207
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4697
5208
|
const shifted = x.sub(stopGradient(xMax));
|
|
4698
|
-
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, {
|
|
5209
|
+
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
|
|
4699
5210
|
return shifted.sub(shiftedLogsumexp);
|
|
4700
5211
|
}
|
|
4701
5212
|
/**
|
|
@@ -4706,16 +5217,39 @@ function logSoftmax(x, axis) {
|
|
|
4706
5217
|
*
|
|
4707
5218
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
4708
5219
|
*/
|
|
4709
|
-
function logsumexp(x, axis) {
|
|
5220
|
+
function logsumexp(x, axis = null) {
|
|
4710
5221
|
x = fudgeArray(x);
|
|
4711
|
-
|
|
4712
|
-
else if (typeof axis === "number") axis = [axis];
|
|
5222
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
4713
5223
|
if (axis.length === 0) return x;
|
|
4714
5224
|
const xMax = stopGradient(max(x.ref, axis));
|
|
4715
5225
|
const xMaxDims = broadcast(xMax.ref, x.shape, axis);
|
|
4716
5226
|
const shifted = x.sub(xMaxDims);
|
|
4717
5227
|
return xMax.add(log(exp(shifted).sum(axis)));
|
|
4718
5228
|
}
|
|
5229
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
5230
|
+
function logmeanexp(x, axis = null) {
|
|
5231
|
+
x = fudgeArray(x);
|
|
5232
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5233
|
+
if (axis.length === 0) return x;
|
|
5234
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
5235
|
+
return logsumexp(x, axis).sub(Math.log(n));
|
|
5236
|
+
}
|
|
5237
|
+
/**
|
|
5238
|
+
* Standardizes input to zero mean and unit variance.
|
|
5239
|
+
*
|
|
5240
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
5241
|
+
* axis, or `null` to standardize over all elements.
|
|
5242
|
+
*
|
|
5243
|
+
* Epsilon is added to denominator, it defaults to `1e-5` for stability.
|
|
5244
|
+
*/
|
|
5245
|
+
function standardize(x, axis = -1, opts = {}) {
|
|
5246
|
+
x = fudgeArray(x);
|
|
5247
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5248
|
+
if (axis.length === 0) return x;
|
|
5249
|
+
const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
|
|
5250
|
+
const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
|
|
5251
|
+
return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
|
|
5252
|
+
}
|
|
4719
5253
|
/**
|
|
4720
5254
|
* One-hot encodes the given indices.
|
|
4721
5255
|
*
|
|
@@ -4733,7 +5267,7 @@ function logsumexp(x, axis) {
|
|
|
4733
5267
|
* ```
|
|
4734
5268
|
*/
|
|
4735
5269
|
function oneHot(x, numClasses) {
|
|
4736
|
-
if (x.dtype
|
|
5270
|
+
if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
4737
5271
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
4738
5272
|
}
|
|
4739
5273
|
|
|
@@ -4741,8 +5275,11 @@ function oneHot(x, numClasses) {
|
|
|
4741
5275
|
//#region src/random.ts
|
|
4742
5276
|
var random_exports = {};
|
|
4743
5277
|
__export(random_exports, {
|
|
5278
|
+
bernoulli: () => bernoulli,
|
|
4744
5279
|
bits: () => bits,
|
|
5280
|
+
exponential: () => exponential,
|
|
4745
5281
|
key: () => key,
|
|
5282
|
+
normal: () => normal,
|
|
4746
5283
|
split: () => split,
|
|
4747
5284
|
uniform: () => uniform
|
|
4748
5285
|
});
|
|
@@ -4770,21 +5307,58 @@ function bits(key$1, shape$1 = []) {
|
|
|
4770
5307
|
const keyShape = validateKeyShape(key$1);
|
|
4771
5308
|
return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
|
|
4772
5309
|
}
|
|
4773
|
-
/**
|
|
4774
|
-
function
|
|
5310
|
+
/**
|
|
5311
|
+
* @function
|
|
5312
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
5313
|
+
*/
|
|
5314
|
+
const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
4775
5315
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
4776
|
-
const mantissa = bits(key$1, shape$1).div(
|
|
5316
|
+
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
4777
5317
|
dtype: DType.Uint32,
|
|
4778
5318
|
device: key$1.device
|
|
4779
5319
|
}));
|
|
4780
|
-
const float12 = mantissa.add(
|
|
5320
|
+
const float12 = mantissa.add(array(1065353216, {
|
|
4781
5321
|
dtype: DType.Uint32,
|
|
4782
5322
|
device: key$1.device
|
|
4783
5323
|
}));
|
|
4784
5324
|
const rand = bitcast(float12, DType.Float32).sub(1);
|
|
4785
5325
|
if (minval === 0 && maxval === 1) return rand;
|
|
4786
5326
|
else return rand.mul(maxval - minval).add(minval);
|
|
5327
|
+
}, { staticArgnums: [1, 2] });
|
|
5328
|
+
/**
|
|
5329
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5330
|
+
*
|
|
5331
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
5332
|
+
* and must be broadcastable to `shape`.
|
|
5333
|
+
*/
|
|
5334
|
+
function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
5335
|
+
p = fudgeArray(p);
|
|
5336
|
+
return uniform(key$1, shape$1).less(p);
|
|
4787
5337
|
}
|
|
5338
|
+
/**
|
|
5339
|
+
* @function
|
|
5340
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
5341
|
+
*/
|
|
5342
|
+
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
5343
|
+
const u = uniform(key$1, shape$1);
|
|
5344
|
+
return negative(log1p(negative(u)));
|
|
5345
|
+
}, { staticArgnums: [1] });
|
|
5346
|
+
/**
|
|
5347
|
+
* @function
|
|
5348
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5349
|
+
*
|
|
5350
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5351
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5352
|
+
* bitwise identical to JAX.
|
|
5353
|
+
*/
|
|
5354
|
+
const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
5355
|
+
const [k1, k2] = split(key$1, 2);
|
|
5356
|
+
const u1 = uniform(k1, shape$1);
|
|
5357
|
+
const u2 = uniform(k2, shape$1);
|
|
5358
|
+
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5359
|
+
const theta = u2.mul(2 * Math.PI);
|
|
5360
|
+
return radius.mul(cos(theta));
|
|
5361
|
+
}, { staticArgnums: [1] });
|
|
4788
5362
|
|
|
4789
5363
|
//#endregion
|
|
4790
5364
|
//#region src/polyfills.ts
|
|
@@ -4794,20 +5368,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
4794
5368
|
|
|
4795
5369
|
//#endregion
|
|
4796
5370
|
//#region src/index.ts
|
|
4797
|
-
/**
|
|
5371
|
+
/**
|
|
5372
|
+
* @function
|
|
5373
|
+
* Compute the forward-mode Jacobian-vector product for a function.
|
|
5374
|
+
*/
|
|
4798
5375
|
const jvp = jvp$1;
|
|
4799
|
-
/**
|
|
5376
|
+
/**
|
|
5377
|
+
* @function
|
|
5378
|
+
* Vectorize an operation on a batched axis for one or more inputs.
|
|
5379
|
+
*/
|
|
4800
5380
|
const vmap = vmap$1;
|
|
4801
|
-
/**
|
|
5381
|
+
/**
|
|
5382
|
+
* @function
|
|
5383
|
+
* Compute the Jacobian evaluated column-by-column by forward-mode AD.
|
|
5384
|
+
*/
|
|
4802
5385
|
const jacfwd = jacfwd$1;
|
|
4803
|
-
/**
|
|
5386
|
+
/**
|
|
5387
|
+
* @function
|
|
5388
|
+
* Construct a Jaxpr by dynamically tracing a function with example inputs.
|
|
5389
|
+
*/
|
|
4804
5390
|
const makeJaxpr = makeJaxpr$1;
|
|
4805
5391
|
/**
|
|
5392
|
+
* @function
|
|
4806
5393
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
4807
5394
|
*
|
|
4808
5395
|
* The function will be compiled the first time it is called with a set of
|
|
4809
5396
|
* argument shapes.
|
|
4810
5397
|
*
|
|
5398
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
5399
|
+
* calls to free memory associated with array constants.
|
|
5400
|
+
*
|
|
4811
5401
|
* **Options:**
|
|
4812
5402
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
4813
5403
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -4817,23 +5407,52 @@ const makeJaxpr = makeJaxpr$1;
|
|
|
4817
5407
|
*/
|
|
4818
5408
|
const jit = jit$1;
|
|
4819
5409
|
/**
|
|
5410
|
+
* @function
|
|
4820
5411
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
4821
5412
|
* partial evaluation.
|
|
4822
5413
|
*/
|
|
4823
5414
|
const linearize = linearize$1;
|
|
4824
|
-
/**
|
|
5415
|
+
/**
|
|
5416
|
+
* @function
|
|
5417
|
+
* Calculate the reverse-mode vector-Jacobian product for a function.
|
|
5418
|
+
*/
|
|
4825
5419
|
const vjp = vjp$1;
|
|
4826
5420
|
/**
|
|
5421
|
+
* @function
|
|
4827
5422
|
* Compute the gradient of a scalar-valued function `f` with respect to its
|
|
4828
5423
|
* first argument.
|
|
4829
5424
|
*/
|
|
4830
5425
|
const grad = grad$1;
|
|
4831
|
-
/**
|
|
5426
|
+
/**
|
|
5427
|
+
* @function
|
|
5428
|
+
* Create a function that evaluates both `f` and the gradient of `f`.
|
|
5429
|
+
*/
|
|
4832
5430
|
const valueAndGrad = valueAndGrad$1;
|
|
4833
|
-
/**
|
|
5431
|
+
/**
|
|
5432
|
+
* @function
|
|
5433
|
+
* Compute the Jacobian evaluated row-by-row by reverse-mode AD.
|
|
5434
|
+
*/
|
|
4834
5435
|
const jacrev = jacrev$1;
|
|
4835
|
-
/**
|
|
5436
|
+
/**
|
|
5437
|
+
* @function
|
|
5438
|
+
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
5439
|
+
*/
|
|
4836
5440
|
const jacobian = jacrev;
|
|
5441
|
+
/**
|
|
5442
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
5443
|
+
*
|
|
5444
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
5445
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
5446
|
+
* to avoid queueing up too many pending operations.
|
|
5447
|
+
*
|
|
5448
|
+
* Does not consume reference to the arrays.
|
|
5449
|
+
*/
|
|
5450
|
+
async function blockUntilReady(x) {
|
|
5451
|
+
const promises = [];
|
|
5452
|
+
for (const leaf of leaves(x)) if (leaf instanceof Array$1) promises.push(leaf.blockUntilReady());
|
|
5453
|
+
await Promise.all(promises);
|
|
5454
|
+
return x;
|
|
5455
|
+
}
|
|
4837
5456
|
|
|
4838
5457
|
//#endregion
|
|
4839
|
-
export { DType, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random,
|
|
5458
|
+
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|