@jax-js/jax 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-yEU0L_ig.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
@@ -354,6 +354,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
354
354
|
Primitive$1["RandomBits"] = "random_bits";
|
|
355
355
|
Primitive$1["Sin"] = "sin";
|
|
356
356
|
Primitive$1["Cos"] = "cos";
|
|
357
|
+
Primitive$1["Asin"] = "asin";
|
|
358
|
+
Primitive$1["Atan"] = "atan";
|
|
357
359
|
Primitive$1["Exp"] = "exp";
|
|
358
360
|
Primitive$1["Log"] = "log";
|
|
359
361
|
Primitive$1["Sqrt"] = "sqrt";
|
|
@@ -421,6 +423,12 @@ function sin$1(x) {
|
|
|
421
423
|
function cos$1(x) {
|
|
422
424
|
return bind1(Primitive.Cos, [x]);
|
|
423
425
|
}
|
|
426
|
+
function asin$1(x) {
|
|
427
|
+
return bind1(Primitive.Asin, [x]);
|
|
428
|
+
}
|
|
429
|
+
function atan$1(x) {
|
|
430
|
+
return bind1(Primitive.Atan, [x]);
|
|
431
|
+
}
|
|
424
432
|
function exp$1(x) {
|
|
425
433
|
return bind1(Primitive.Exp, [x]);
|
|
426
434
|
}
|
|
@@ -436,18 +444,16 @@ function min$1(x, y) {
|
|
|
436
444
|
function max$1(x, y) {
|
|
437
445
|
return bind1(Primitive.Max, [x, y]);
|
|
438
446
|
}
|
|
439
|
-
function reduce(x, op, axis, opts) {
|
|
447
|
+
function reduce(x, op, axis = null, opts) {
|
|
440
448
|
if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
441
|
-
|
|
442
|
-
else axis = [];
|
|
443
|
-
else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, ndim$1(x))];
|
|
444
|
-
else axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
|
|
449
|
+
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
445
450
|
const originalShape = getShape(x);
|
|
446
|
-
|
|
451
|
+
let result = bind1(Primitive.Reduce, [x], {
|
|
447
452
|
op,
|
|
448
453
|
axis
|
|
449
454
|
});
|
|
450
|
-
|
|
455
|
+
if (opts?.keepdims) result = result.reshape(originalShape.map((dim, i) => axis.includes(i) ? 1 : dim));
|
|
456
|
+
return result;
|
|
451
457
|
}
|
|
452
458
|
function dot$1(x, y) {
|
|
453
459
|
return bind1(Primitive.Dot, [x, y]);
|
|
@@ -493,10 +499,11 @@ function where$1(cond, x, y) {
|
|
|
493
499
|
}
|
|
494
500
|
function transpose$1(x, perm) {
|
|
495
501
|
perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
|
|
502
|
+
if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
496
503
|
return bind1(Primitive.Transpose, [x], { perm });
|
|
497
504
|
}
|
|
498
505
|
function broadcast(x, shape$1, axis) {
|
|
499
|
-
axis =
|
|
506
|
+
axis = require_backend.normalizeAxis(axis, shape$1.length);
|
|
500
507
|
return bind1(Primitive.Broadcast, [x], {
|
|
501
508
|
shape: shape$1,
|
|
502
509
|
axis
|
|
@@ -515,7 +522,7 @@ function reshape$1(x, shape$1) {
|
|
|
515
522
|
return bind1(Primitive.Reshape, [x], { shape: shape$1 });
|
|
516
523
|
}
|
|
517
524
|
function flip$1(x, axis) {
|
|
518
|
-
axis =
|
|
525
|
+
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
519
526
|
return bind1(Primitive.Flip, [x], { axis });
|
|
520
527
|
}
|
|
521
528
|
function shrink(x, slice) {
|
|
@@ -589,21 +596,49 @@ var Trace = class {
|
|
|
589
596
|
this.main = main;
|
|
590
597
|
}
|
|
591
598
|
};
|
|
599
|
+
/**
|
|
600
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
601
|
+
*
|
|
602
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
603
|
+
* implemented in that function as `weakType` is not passed.
|
|
604
|
+
*/
|
|
605
|
+
function promoteAvals(a, b) {
|
|
606
|
+
const shape$1 = require_backend.generalBroadcast(a.shape, b.shape);
|
|
607
|
+
const weakType = a.weakType && b.weakType;
|
|
608
|
+
let dtype;
|
|
609
|
+
if (a.weakType === b.weakType) dtype = require_backend.promoteTypes(a.dtype, b.dtype);
|
|
610
|
+
else if (a.weakType) dtype = require_backend.promoteTypes(b.dtype, require_backend.DType.Uint32);
|
|
611
|
+
else dtype = require_backend.promoteTypes(a.dtype, require_backend.DType.Uint32);
|
|
612
|
+
return new ShapedArray(shape$1, dtype, weakType);
|
|
613
|
+
}
|
|
592
614
|
var Tracer = class Tracer {
|
|
593
615
|
/** @ignore */
|
|
594
616
|
_trace;
|
|
595
617
|
constructor(trace) {
|
|
596
618
|
this._trace = trace;
|
|
597
619
|
}
|
|
620
|
+
/** The shape of the array. */
|
|
598
621
|
get shape() {
|
|
599
622
|
return this.aval.shape;
|
|
600
623
|
}
|
|
624
|
+
/** The total number of elements in the array. */
|
|
601
625
|
get size() {
|
|
602
626
|
return require_backend.prod(this.shape);
|
|
603
627
|
}
|
|
628
|
+
/** The dtype of elements stored in the array. */
|
|
604
629
|
get dtype() {
|
|
605
630
|
return this.aval.dtype;
|
|
606
631
|
}
|
|
632
|
+
/**
|
|
633
|
+
* Whether the array is weakly typed.
|
|
634
|
+
*
|
|
635
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
636
|
+
* `promoteTypes()` for details.
|
|
637
|
+
*/
|
|
638
|
+
get weakType() {
|
|
639
|
+
return this.aval.weakType;
|
|
640
|
+
}
|
|
641
|
+
/** The number of dimensions of the array. */
|
|
607
642
|
get ndim() {
|
|
608
643
|
return this.shape.length;
|
|
609
644
|
}
|
|
@@ -639,22 +674,20 @@ var Tracer = class Tracer {
|
|
|
639
674
|
return lessEqual$1(this, other);
|
|
640
675
|
}
|
|
641
676
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
642
|
-
sum(axis, opts) {
|
|
677
|
+
sum(axis = null, opts) {
|
|
643
678
|
return reduce(this, require_backend.AluOp.Add, axis, opts);
|
|
644
679
|
}
|
|
645
680
|
/** Product of the array elements over a given axis. */
|
|
646
|
-
prod(axis, opts) {
|
|
681
|
+
prod(axis = null, opts) {
|
|
647
682
|
return reduce(this, require_backend.AluOp.Mul, axis, opts);
|
|
648
683
|
}
|
|
649
684
|
/** Compute the average of the array elements along the specified axis. */
|
|
650
|
-
mean(axis, opts) {
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
657
|
-
return result;
|
|
685
|
+
mean(axis = null, opts) {
|
|
686
|
+
axis = require_backend.normalizeAxis(axis, this.ndim);
|
|
687
|
+
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
688
|
+
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
689
|
+
const result = reduce(this, require_backend.AluOp.Add, axis, opts);
|
|
690
|
+
return result.mul(1 / n);
|
|
658
691
|
}
|
|
659
692
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
660
693
|
transpose(perm) {
|
|
@@ -841,12 +874,13 @@ function getShape(x) {
|
|
|
841
874
|
return x instanceof Tracer ? x.shape : [];
|
|
842
875
|
}
|
|
843
876
|
var ShapedArray = class ShapedArray {
|
|
844
|
-
constructor(shape$1, dtype) {
|
|
877
|
+
constructor(shape$1, dtype, weakType) {
|
|
845
878
|
this.shape = shape$1;
|
|
846
879
|
this.dtype = dtype;
|
|
880
|
+
this.weakType = weakType;
|
|
847
881
|
}
|
|
848
882
|
static fromAval(aval) {
|
|
849
|
-
return new ShapedArray(aval.shape, aval.dtype);
|
|
883
|
+
return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
|
|
850
884
|
}
|
|
851
885
|
get ndim() {
|
|
852
886
|
return this.shape.length;
|
|
@@ -860,7 +894,7 @@ var ShapedArray = class ShapedArray {
|
|
|
860
894
|
};
|
|
861
895
|
function getAval(x) {
|
|
862
896
|
if (x instanceof Tracer) return x.aval;
|
|
863
|
-
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32);
|
|
897
|
+
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32, typeof x === "boolean" ? false : true);
|
|
864
898
|
else throw new TypeError(`Unknown value: ${x}`);
|
|
865
899
|
}
|
|
866
900
|
function bind(prim, args, params = {}) {
|
|
@@ -1145,7 +1179,7 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1145
1179
|
}
|
|
1146
1180
|
function broadcastedJit(fn) {
|
|
1147
1181
|
return (nargs, exps, avals, params) => {
|
|
1148
|
-
const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
|
|
1182
|
+
const newShape = avals.map((aval) => aval.shape).reduce(require_backend.generalBroadcast);
|
|
1149
1183
|
exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
|
|
1150
1184
|
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1151
1185
|
}));
|
|
@@ -1182,11 +1216,13 @@ const jitRules = {
|
|
|
1182
1216
|
const k1 = reshapeViews(keys[1], mapping);
|
|
1183
1217
|
const c0 = require_backend.AluExp.u32(0);
|
|
1184
1218
|
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1185
|
-
const exp$2 = require_backend.AluExp.threefry2x32(
|
|
1219
|
+
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1186
1220
|
return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
|
|
1187
1221
|
},
|
|
1188
1222
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
1189
1223
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
1224
|
+
[Primitive.Asin]: unopJit(require_backend.AluExp.asin),
|
|
1225
|
+
[Primitive.Atan]: unopJit(require_backend.AluExp.atan),
|
|
1190
1226
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
1191
1227
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1192
1228
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
@@ -1221,7 +1257,7 @@ const jitRules = {
|
|
|
1221
1257
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
1222
1258
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
1223
1259
|
const c = k1.exp;
|
|
1224
|
-
const cs =
|
|
1260
|
+
const cs = promoteAvals(as, bs);
|
|
1225
1261
|
return jitRules[Primitive.Reduce](nargs, [c], [cs], {
|
|
1226
1262
|
op: require_backend.AluOp.Add,
|
|
1227
1263
|
axis: [cs.ndim - 1]
|
|
@@ -1231,8 +1267,8 @@ const jitRules = {
|
|
|
1231
1267
|
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1232
1268
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1233
1269
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1234
|
-
as = new ShapedArray(stX.shape, as.dtype);
|
|
1235
|
-
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1270
|
+
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1271
|
+
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1236
1272
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1237
1273
|
},
|
|
1238
1274
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
@@ -1249,7 +1285,7 @@ const jitRules = {
|
|
|
1249
1285
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1250
1286
|
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1251
1287
|
const axisSet = new Set(axis);
|
|
1252
|
-
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1288
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
1253
1289
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1254
1290
|
finalShape.splice(outDim, 0, ...indexShape);
|
|
1255
1291
|
const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
|
|
@@ -1285,9 +1321,10 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1285
1321
|
Primitive.Conv,
|
|
1286
1322
|
Primitive.PoolTranspose
|
|
1287
1323
|
];
|
|
1324
|
+
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
1288
1325
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1289
1326
|
const eqn = jaxpr.eqns[i];
|
|
1290
|
-
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive
|
|
1327
|
+
if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1291
1328
|
for (const v of eqn.outBinders) {
|
|
1292
1329
|
blackNodes.add(v);
|
|
1293
1330
|
p1NextBlack.set(v, v);
|
|
@@ -1417,6 +1454,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1417
1454
|
static #nextId = 1001;
|
|
1418
1455
|
id;
|
|
1419
1456
|
#dtype;
|
|
1457
|
+
#weakType;
|
|
1420
1458
|
#source;
|
|
1421
1459
|
#st;
|
|
1422
1460
|
#backend;
|
|
@@ -1428,19 +1466,22 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1428
1466
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1429
1467
|
* will be freed when the array is disposed.
|
|
1430
1468
|
*/
|
|
1431
|
-
constructor(
|
|
1469
|
+
constructor(args) {
|
|
1432
1470
|
super(baseArrayTrace);
|
|
1433
1471
|
this.id = Array$1.#nextId++;
|
|
1434
|
-
this.#dtype = dtype;
|
|
1435
|
-
this.#
|
|
1436
|
-
this.#
|
|
1437
|
-
this.#
|
|
1472
|
+
this.#dtype = args.dtype;
|
|
1473
|
+
this.#weakType = args.weakType;
|
|
1474
|
+
this.#source = args.source;
|
|
1475
|
+
this.#st = args.st;
|
|
1476
|
+
this.#backend = args.backend;
|
|
1438
1477
|
this.#rc = 1;
|
|
1439
|
-
this.#pendingSet = new Set(pending);
|
|
1478
|
+
this.#pendingSet = new Set(args.pending);
|
|
1479
|
+
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1480
|
+
else if (this.#source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1440
1481
|
}
|
|
1441
1482
|
/** @ignore */
|
|
1442
1483
|
get aval() {
|
|
1443
|
-
return new ShapedArray(this.#st.shape, this.#dtype);
|
|
1484
|
+
return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
|
|
1444
1485
|
}
|
|
1445
1486
|
/** Return a simple string representation of the array's dimensions. */
|
|
1446
1487
|
toString() {
|
|
@@ -1452,6 +1493,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1452
1493
|
#check() {
|
|
1453
1494
|
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1454
1495
|
}
|
|
1496
|
+
/** Construct an array, copying fields from `this`. */
|
|
1497
|
+
#newArrayFrom(args) {
|
|
1498
|
+
return new Array$1({
|
|
1499
|
+
source: args.source ?? this.#source,
|
|
1500
|
+
st: args.st ?? this.#st,
|
|
1501
|
+
dtype: args.dtype ?? this.#dtype,
|
|
1502
|
+
weakType: this.#weakType,
|
|
1503
|
+
backend: args.backend ?? this.#backend,
|
|
1504
|
+
pending: args.pending ?? this.#pending ?? void 0
|
|
1505
|
+
});
|
|
1506
|
+
}
|
|
1455
1507
|
get ref() {
|
|
1456
1508
|
this.#check();
|
|
1457
1509
|
this.#rc++;
|
|
@@ -1491,7 +1543,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1491
1543
|
const pending = this.#pending;
|
|
1492
1544
|
for (const exe of pending) exe.updateRc(1);
|
|
1493
1545
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1494
|
-
const ar =
|
|
1546
|
+
const ar = this.#newArrayFrom({
|
|
1547
|
+
st,
|
|
1548
|
+
pending
|
|
1549
|
+
});
|
|
1495
1550
|
this.dispose();
|
|
1496
1551
|
return ar;
|
|
1497
1552
|
}
|
|
@@ -1540,7 +1595,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1540
1595
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1541
1596
|
this.dispose();
|
|
1542
1597
|
for (const ar of indices) ar.dispose();
|
|
1543
|
-
return
|
|
1598
|
+
return this.#newArrayFrom({
|
|
1599
|
+
source: output,
|
|
1600
|
+
st: require_backend.ShapeTracker.fromShape(finalShape),
|
|
1601
|
+
pending
|
|
1602
|
+
});
|
|
1544
1603
|
}
|
|
1545
1604
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1546
1605
|
#moveAxesDown(axis) {
|
|
@@ -1563,11 +1622,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1563
1622
|
return this.#reshape(this.#st.permute(perm));
|
|
1564
1623
|
}
|
|
1565
1624
|
#unary(op, dtypeOutput) {
|
|
1625
|
+
const weakType = !dtypeOutput && this.#weakType;
|
|
1566
1626
|
dtypeOutput ??= this.#dtype;
|
|
1567
1627
|
this.#check();
|
|
1568
1628
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1569
1629
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1570
|
-
return
|
|
1630
|
+
return this.#newArrayFrom({
|
|
1631
|
+
source: exp$3.simplify(),
|
|
1632
|
+
dtype: dtypeOutput,
|
|
1633
|
+
weakType
|
|
1634
|
+
});
|
|
1571
1635
|
}
|
|
1572
1636
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1573
1637
|
const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1577,41 +1641,65 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1577
1641
|
for (const exe of pending) exe.updateRc(1);
|
|
1578
1642
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1579
1643
|
this.dispose();
|
|
1580
|
-
return
|
|
1644
|
+
return this.#newArrayFrom({
|
|
1645
|
+
source: output,
|
|
1646
|
+
st: require_backend.ShapeTracker.fromShape(this.shape),
|
|
1647
|
+
dtype: dtypeOutput,
|
|
1648
|
+
weakType,
|
|
1649
|
+
pending
|
|
1650
|
+
});
|
|
1581
1651
|
}
|
|
1582
1652
|
#binary(op, other) {
|
|
1583
|
-
const custom = (src) => new require_backend.AluExp(op,
|
|
1653
|
+
const custom = (src) => new require_backend.AluExp(op, src[0].dtype, src);
|
|
1584
1654
|
return Array$1.#naryCustom(op, custom, [this, other]);
|
|
1585
1655
|
}
|
|
1586
|
-
static #naryCustom(name, custom, arrays, { dtypeOverride,
|
|
1656
|
+
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1587
1657
|
const n = arrays.length;
|
|
1588
1658
|
const backend = arrays[0].#backend;
|
|
1589
1659
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1590
1660
|
for (const ar of arrays) ar.#check();
|
|
1591
|
-
let
|
|
1661
|
+
let castDtype;
|
|
1662
|
+
let castWeakType = true;
|
|
1592
1663
|
for (let i = 0; i < n; i++) {
|
|
1593
1664
|
if (dtypeOverride?.[i]) {
|
|
1594
1665
|
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1595
|
-
} else if (
|
|
1596
|
-
|
|
1666
|
+
} else if (castDtype === void 0) {
|
|
1667
|
+
castDtype = arrays[i].#dtype;
|
|
1668
|
+
castWeakType = arrays[i].#weakType;
|
|
1669
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1597
1670
|
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1598
1671
|
}
|
|
1599
|
-
|
|
1600
|
-
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1672
|
+
const weakType = castWeakType && !strongTypeOutput;
|
|
1601
1673
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1602
1674
|
const newShape = [...arrays[0].shape];
|
|
1603
1675
|
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
1676
|
+
const sources = arrays.map((ar, i) => {
|
|
1677
|
+
if (!dtypeOverride?.[i]) return require_backend.AluExp.cast(castDtype, ar.#source);
|
|
1678
|
+
else return ar.#source;
|
|
1679
|
+
});
|
|
1604
1680
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1605
|
-
const exp$4 = custom(
|
|
1606
|
-
return new Array$1(
|
|
1681
|
+
const exp$4 = custom(sources);
|
|
1682
|
+
return new Array$1({
|
|
1683
|
+
source: exp$4.simplify(),
|
|
1684
|
+
st: arrays[0].#st,
|
|
1685
|
+
dtype: exp$4.dtype,
|
|
1686
|
+
weakType,
|
|
1687
|
+
backend
|
|
1688
|
+
});
|
|
1607
1689
|
}
|
|
1608
|
-
const exp$3 = custom(arrays.map((ar) => {
|
|
1609
|
-
const src$1 =
|
|
1690
|
+
const exp$3 = custom(arrays.map((ar, i) => {
|
|
1691
|
+
const src$1 = sources[i];
|
|
1610
1692
|
if (ar.#st.contiguous) return src$1;
|
|
1611
1693
|
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1612
1694
|
}));
|
|
1613
1695
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1614
|
-
return new Array$1(
|
|
1696
|
+
return new Array$1({
|
|
1697
|
+
source: exp$3.simplify(),
|
|
1698
|
+
st,
|
|
1699
|
+
dtype: exp$3.dtype,
|
|
1700
|
+
weakType,
|
|
1701
|
+
backend
|
|
1702
|
+
});
|
|
1615
1703
|
}
|
|
1616
1704
|
let indices;
|
|
1617
1705
|
if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
|
|
@@ -1621,14 +1709,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1621
1709
|
}
|
|
1622
1710
|
const inputs = [];
|
|
1623
1711
|
const src = [];
|
|
1624
|
-
for (const ar of arrays
|
|
1625
|
-
|
|
1626
|
-
|
|
1627
|
-
|
|
1628
|
-
gid = inputs.
|
|
1629
|
-
|
|
1712
|
+
for (const [i, ar] of arrays.entries()) {
|
|
1713
|
+
let nextSrc;
|
|
1714
|
+
if (ar.#source instanceof require_backend.AluExp) nextSrc = require_backend.accessorAluExp(ar.#source, ar.#st, indices);
|
|
1715
|
+
else {
|
|
1716
|
+
let gid = inputs.indexOf(ar.#source);
|
|
1717
|
+
if (gid === -1) {
|
|
1718
|
+
gid = inputs.length;
|
|
1719
|
+
inputs.push(ar.#source);
|
|
1720
|
+
}
|
|
1721
|
+
nextSrc = require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
|
|
1630
1722
|
}
|
|
1631
|
-
|
|
1723
|
+
if (!dtypeOverride?.[i]) nextSrc = require_backend.AluExp.cast(castDtype, nextSrc);
|
|
1724
|
+
src.push(nextSrc);
|
|
1632
1725
|
}
|
|
1633
1726
|
const exp$2 = custom(src);
|
|
1634
1727
|
let re = void 0;
|
|
@@ -1642,12 +1735,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1642
1735
|
for (const exe of pending) exe.updateRc(1);
|
|
1643
1736
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1644
1737
|
for (const ar of arrays) ar.dispose();
|
|
1645
|
-
return new Array$1(
|
|
1738
|
+
return new Array$1({
|
|
1739
|
+
source: output,
|
|
1740
|
+
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1741
|
+
dtype: kernel.dtype,
|
|
1742
|
+
weakType,
|
|
1743
|
+
backend,
|
|
1744
|
+
pending
|
|
1745
|
+
});
|
|
1646
1746
|
}
|
|
1647
1747
|
/** Reduce the last dimension of the array by an operation. */
|
|
1648
1748
|
#reduce(op) {
|
|
1649
|
-
this.#check();
|
|
1650
|
-
if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
|
|
1651
1749
|
const shape$1 = this.shape;
|
|
1652
1750
|
const reduction = new require_backend.Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
|
|
1653
1751
|
const newShape = shape$1.slice(0, -1);
|
|
@@ -1666,7 +1764,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1666
1764
|
for (const exe of pending) exe.updateRc(1);
|
|
1667
1765
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1668
1766
|
this.dispose();
|
|
1669
|
-
return
|
|
1767
|
+
return this.#newArrayFrom({
|
|
1768
|
+
source: output,
|
|
1769
|
+
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1770
|
+
pending
|
|
1771
|
+
});
|
|
1670
1772
|
}
|
|
1671
1773
|
/**
|
|
1672
1774
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1702,15 +1804,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1702
1804
|
}
|
|
1703
1805
|
#dataInline() {
|
|
1704
1806
|
this.#check();
|
|
1705
|
-
|
|
1706
|
-
const ar =
|
|
1807
|
+
if (!(this.#source instanceof require_backend.AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
|
|
1808
|
+
const ar = this.#newArrayFrom({ backend: require_backend.getBackend("cpu") });
|
|
1707
1809
|
this.dispose();
|
|
1708
1810
|
return ar.dataSync();
|
|
1709
1811
|
}
|
|
1710
1812
|
static #broadcastArrays(arrays) {
|
|
1711
1813
|
if (arrays.length === 0) throw new Error("Need at least one array to broadcast");
|
|
1712
1814
|
if (arrays.length === 1) return arrays;
|
|
1713
|
-
const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
|
|
1815
|
+
const newShape = arrays.map((a) => a.shape).reduce(require_backend.generalBroadcast);
|
|
1714
1816
|
return arrays.map((ar) => {
|
|
1715
1817
|
if (require_backend.deepEqual(ar.shape, newShape)) return ar;
|
|
1716
1818
|
return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
|
|
@@ -1739,8 +1841,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1739
1841
|
*
|
|
1740
1842
|
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1741
1843
|
* dispatch of operations as well.
|
|
1844
|
+
*
|
|
1845
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1846
|
+
* asynchronously for multiple arrays.
|
|
1742
1847
|
*/
|
|
1743
|
-
async
|
|
1848
|
+
async blockUntilReady() {
|
|
1744
1849
|
this.#check();
|
|
1745
1850
|
if (this.#source instanceof require_backend.AluExp) return this;
|
|
1746
1851
|
const pending = this.#pending;
|
|
@@ -1806,7 +1911,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1806
1911
|
return [x.#binary(require_backend.AluOp.Idiv, y)];
|
|
1807
1912
|
},
|
|
1808
1913
|
[Primitive.Neg]([x]) {
|
|
1809
|
-
return [zerosLike(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
1914
|
+
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
1810
1915
|
},
|
|
1811
1916
|
[Primitive.Reciprocal]([x]) {
|
|
1812
1917
|
return [x.#unary(require_backend.AluOp.Reciprocal)];
|
|
@@ -1826,14 +1931,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1826
1931
|
x.#backend.incRef(x.#source);
|
|
1827
1932
|
const pending = x.#pending;
|
|
1828
1933
|
for (const exe of pending) exe.updateRc(1);
|
|
1829
|
-
const y =
|
|
1934
|
+
const y = x.#newArrayFrom({
|
|
1935
|
+
dtype,
|
|
1936
|
+
weakType: false,
|
|
1937
|
+
pending
|
|
1938
|
+
});
|
|
1830
1939
|
x.dispose();
|
|
1831
1940
|
return [y];
|
|
1832
1941
|
}
|
|
1833
1942
|
},
|
|
1834
1943
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
1835
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1836
|
-
if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1944
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1945
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1837
1946
|
const c0 = zeros(shape$1, {
|
|
1838
1947
|
dtype: require_backend.DType.Uint32,
|
|
1839
1948
|
device: k0.device
|
|
@@ -1856,6 +1965,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1856
1965
|
[Primitive.Cos]([x]) {
|
|
1857
1966
|
return [x.#unary(require_backend.AluOp.Cos)];
|
|
1858
1967
|
},
|
|
1968
|
+
[Primitive.Asin]([x]) {
|
|
1969
|
+
return [x.#unary(require_backend.AluOp.Asin)];
|
|
1970
|
+
},
|
|
1971
|
+
[Primitive.Atan]([x]) {
|
|
1972
|
+
return [x.#unary(require_backend.AluOp.Atan)];
|
|
1973
|
+
},
|
|
1859
1974
|
[Primitive.Exp]([x]) {
|
|
1860
1975
|
return [x.#unary(require_backend.AluOp.Exp)];
|
|
1861
1976
|
},
|
|
@@ -1895,7 +2010,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1895
2010
|
},
|
|
1896
2011
|
[Primitive.Compare]([x, y], { op }) {
|
|
1897
2012
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1898
|
-
return [Array$1.#naryCustom("compare", custom, [x, y], {
|
|
2013
|
+
return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
|
|
1899
2014
|
},
|
|
1900
2015
|
[Primitive.Where]([cond, x, y]) {
|
|
1901
2016
|
const custom = ([cond$1, x$1, y$1]) => require_backend.AluExp.where(cond$1, x$1, y$1);
|
|
@@ -1941,7 +2056,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1941
2056
|
pending.splice(0, 0, ...prevPending);
|
|
1942
2057
|
args.forEach((x) => x.dispose());
|
|
1943
2058
|
return outputs.map((source, i) => {
|
|
1944
|
-
return new Array$1(
|
|
2059
|
+
return new Array$1({
|
|
2060
|
+
source,
|
|
2061
|
+
st: require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
|
|
2062
|
+
dtype: jaxpr.outs[i].aval.dtype,
|
|
2063
|
+
weakType: jaxpr.outs[i].aval.weakType,
|
|
2064
|
+
backend,
|
|
2065
|
+
pending
|
|
2066
|
+
});
|
|
1945
2067
|
});
|
|
1946
2068
|
}
|
|
1947
2069
|
};
|
|
@@ -1951,33 +2073,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1951
2073
|
return this.#source;
|
|
1952
2074
|
}
|
|
1953
2075
|
};
|
|
1954
|
-
/** Construct an array from a single scalar constant. */
|
|
1955
|
-
function scalar(value, { dtype, device } = {}) {
|
|
1956
|
-
if (typeof value === "number") {
|
|
1957
|
-
dtype ??= require_backend.DType.Float32;
|
|
1958
|
-
if (![
|
|
1959
|
-
require_backend.DType.Float32,
|
|
1960
|
-
require_backend.DType.Float16,
|
|
1961
|
-
require_backend.DType.Int32,
|
|
1962
|
-
require_backend.DType.Uint32
|
|
1963
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1964
|
-
} else if (typeof value === "boolean") {
|
|
1965
|
-
dtype ??= require_backend.DType.Bool;
|
|
1966
|
-
if (![
|
|
1967
|
-
require_backend.DType.Float32,
|
|
1968
|
-
require_backend.DType.Float16,
|
|
1969
|
-
require_backend.DType.Int32,
|
|
1970
|
-
require_backend.DType.Uint32,
|
|
1971
|
-
require_backend.DType.Bool
|
|
1972
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1973
|
-
} else throw new TypeError(`Invalid type for scalar ${value}`);
|
|
1974
|
-
return new Array$1(require_backend.AluExp.const(dtype, value), require_backend.ShapeTracker.fromShape([]), dtype, require_backend.getBackend(device));
|
|
1975
|
-
}
|
|
1976
2076
|
/** Constructor for creating a new array from data. */
|
|
1977
2077
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
1978
2078
|
if (values instanceof Tracer) {
|
|
1979
2079
|
if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1980
|
-
if (dtype && values.dtype !== dtype)
|
|
2080
|
+
if (dtype && values.dtype !== dtype) values = values.astype(dtype);
|
|
1981
2081
|
return values;
|
|
1982
2082
|
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1983
2083
|
dtype,
|
|
@@ -1999,6 +2099,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1999
2099
|
dtype,
|
|
2000
2100
|
device
|
|
2001
2101
|
});
|
|
2102
|
+
if (size$1 === 1) return full(shape$1, flat[0], {
|
|
2103
|
+
dtype,
|
|
2104
|
+
device
|
|
2105
|
+
});
|
|
2002
2106
|
if (typeof flat[0] === "boolean") {
|
|
2003
2107
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
2004
2108
|
const data = new Int32Array(flat.map((x) => x ? 1 : 0));
|
|
@@ -2007,46 +2111,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2007
2111
|
device
|
|
2008
2112
|
});
|
|
2009
2113
|
} else {
|
|
2114
|
+
const weakType = dtype == void 0;
|
|
2010
2115
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2011
2116
|
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
2012
2117
|
return arrayFromData(data, shape$1, {
|
|
2013
2118
|
dtype,
|
|
2014
2119
|
device
|
|
2015
|
-
});
|
|
2120
|
+
}, weakType);
|
|
2016
2121
|
}
|
|
2017
2122
|
}
|
|
2018
2123
|
}
|
|
2019
|
-
function arrayFromData(data, shape$1, { dtype, device } =
|
|
2124
|
+
function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
2125
|
+
if (data instanceof Float32Array) {
|
|
2126
|
+
if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2127
|
+
dtype ??= require_backend.DType.Float32;
|
|
2128
|
+
} else if (data instanceof Int32Array) {
|
|
2129
|
+
if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2130
|
+
dtype ??= require_backend.DType.Int32;
|
|
2131
|
+
} else if (data instanceof Uint32Array) {
|
|
2132
|
+
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2133
|
+
dtype ??= require_backend.DType.Uint32;
|
|
2134
|
+
} else if (data instanceof Float16Array) {
|
|
2135
|
+
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2136
|
+
dtype ??= require_backend.DType.Float16;
|
|
2137
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2020
2138
|
if (data.length < inlineArrayLimit) {
|
|
2021
2139
|
let allEqual = true;
|
|
2022
2140
|
for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
|
|
2023
2141
|
allEqual = false;
|
|
2024
2142
|
break;
|
|
2025
2143
|
}
|
|
2026
|
-
if (allEqual)
|
|
2027
|
-
dtype,
|
|
2028
|
-
device
|
|
2029
|
-
}
|
|
2144
|
+
if (allEqual) {
|
|
2145
|
+
const sa = new ShapedArray(shape$1, dtype, weakType);
|
|
2146
|
+
return fullInternal(sa, data[0], device);
|
|
2147
|
+
}
|
|
2030
2148
|
}
|
|
2031
2149
|
const backend = require_backend.getBackend(device);
|
|
2032
|
-
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
|
|
2039
|
-
|
|
2040
|
-
|
|
2041
|
-
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2042
|
-
dtype ??= require_backend.DType.Uint32;
|
|
2043
|
-
} else if (data instanceof Float16Array) {
|
|
2044
|
-
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2045
|
-
dtype ??= require_backend.DType.Float16;
|
|
2046
|
-
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2047
|
-
const slot = backend.malloc(data.byteLength, buf);
|
|
2048
|
-
return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2049
|
-
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
2150
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2151
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2152
|
+
return new Array$1({
|
|
2153
|
+
source: slot,
|
|
2154
|
+
st: require_backend.ShapeTracker.fromShape(shape$1),
|
|
2155
|
+
dtype,
|
|
2156
|
+
weakType,
|
|
2157
|
+
backend
|
|
2158
|
+
});
|
|
2050
2159
|
}
|
|
2051
2160
|
function dataToJs(dtype, data, shape$1) {
|
|
2052
2161
|
if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -2062,7 +2171,7 @@ function dataToJs(dtype, data, shape$1) {
|
|
|
2062
2171
|
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
2063
2172
|
function pureArray(x) {
|
|
2064
2173
|
if (x instanceof Tracer) return x;
|
|
2065
|
-
else return
|
|
2174
|
+
else return array(x);
|
|
2066
2175
|
}
|
|
2067
2176
|
var EvalTrace = class extends Trace {
|
|
2068
2177
|
pure = (x) => pureArray(x);
|
|
@@ -2073,20 +2182,27 @@ var EvalTrace = class extends Trace {
|
|
|
2073
2182
|
};
|
|
2074
2183
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2075
2184
|
const implRules = Array$1._implRules();
|
|
2076
|
-
function
|
|
2077
|
-
|
|
2078
|
-
|
|
2079
|
-
|
|
2185
|
+
function fullInternal(aval, fillValue, device) {
|
|
2186
|
+
return new Array$1({
|
|
2187
|
+
source: require_backend.AluExp.const(aval.dtype, fillValue),
|
|
2188
|
+
st: require_backend.ShapeTracker.fromShape(aval.shape),
|
|
2189
|
+
dtype: aval.dtype,
|
|
2190
|
+
weakType: aval.weakType,
|
|
2191
|
+
backend: require_backend.getBackend(device)
|
|
2192
|
+
});
|
|
2080
2193
|
}
|
|
2081
|
-
function
|
|
2082
|
-
|
|
2083
|
-
|
|
2084
|
-
|
|
2194
|
+
function zerosLike$1(val, dtype) {
|
|
2195
|
+
return fullLike(val, 0, dtype);
|
|
2196
|
+
}
|
|
2197
|
+
function onesLike$1(val, dtype) {
|
|
2198
|
+
return fullLike(val, 1, dtype);
|
|
2085
2199
|
}
|
|
2086
2200
|
function fullLike(val, fillValue, dtype) {
|
|
2087
2201
|
const aval = getAval(val);
|
|
2088
2202
|
if (val instanceof Tracer) val.dispose();
|
|
2089
|
-
|
|
2203
|
+
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
2204
|
+
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
2205
|
+
return fullInternal(sa, fillValue);
|
|
2090
2206
|
}
|
|
2091
2207
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
2092
2208
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -2104,19 +2220,14 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
2104
2220
|
}
|
|
2105
2221
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
2106
2222
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
2107
|
-
let
|
|
2108
|
-
if (typeof fillValue === "number")
|
|
2109
|
-
|
|
2110
|
-
source = require_backend.AluExp.const(dtype, fillValue);
|
|
2111
|
-
} else if (typeof fillValue === "bigint") {
|
|
2112
|
-
dtype = dtype ?? require_backend.DType.Int32;
|
|
2113
|
-
source = require_backend.AluExp.const(dtype, Number(fillValue));
|
|
2114
|
-
} else if (typeof fillValue === "boolean") {
|
|
2223
|
+
let weakType = dtype == void 0;
|
|
2224
|
+
if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
|
|
2225
|
+
else if (typeof fillValue === "boolean") {
|
|
2115
2226
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
2116
|
-
|
|
2227
|
+
weakType = false;
|
|
2117
2228
|
} else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
|
|
2118
2229
|
else throw new TypeError(`Invalid type for full: ${fillValue}`);
|
|
2119
|
-
return new
|
|
2230
|
+
return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
|
|
2120
2231
|
}
|
|
2121
2232
|
/**
|
|
2122
2233
|
* Create an identity matrix.
|
|
@@ -2126,6 +2237,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
2126
2237
|
*/
|
|
2127
2238
|
function eye(numRows, numCols, { dtype, device } = {}) {
|
|
2128
2239
|
numCols = numCols ?? numRows;
|
|
2240
|
+
const weakType = dtype == void 0;
|
|
2129
2241
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2130
2242
|
if (numCols < numRows) {
|
|
2131
2243
|
const arr = eye(numCols, numRows, {
|
|
@@ -2139,9 +2251,15 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2139
2251
|
device
|
|
2140
2252
|
});
|
|
2141
2253
|
const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
|
|
2142
|
-
return new Array$1(
|
|
2254
|
+
return new Array$1({
|
|
2255
|
+
source: require_backend.AluExp.cast(dtype, exp$2),
|
|
2256
|
+
st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
|
|
2257
|
+
dtype,
|
|
2258
|
+
weakType,
|
|
2259
|
+
backend: require_backend.getBackend(device)
|
|
2260
|
+
});
|
|
2143
2261
|
}
|
|
2144
|
-
/** Return the identity
|
|
2262
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2145
2263
|
function identity$1(n, { dtype, device } = {}) {
|
|
2146
2264
|
return eye(n, n, {
|
|
2147
2265
|
dtype,
|
|
@@ -2176,7 +2294,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2176
2294
|
});
|
|
2177
2295
|
const exp$2 = require_backend.AluExp.add(require_backend.AluExp.const(dtype, start), require_backend.AluExp.mul(require_backend.AluExp.cast(dtype, require_backend.AluVar.idx), require_backend.AluExp.const(dtype, step)));
|
|
2178
2296
|
const st = require_backend.ShapeTracker.fromShape([size$1]);
|
|
2179
|
-
return new Array$1(
|
|
2297
|
+
return new Array$1({
|
|
2298
|
+
source: exp$2,
|
|
2299
|
+
st,
|
|
2300
|
+
dtype,
|
|
2301
|
+
weakType: false,
|
|
2302
|
+
backend: require_backend.getBackend(device)
|
|
2303
|
+
});
|
|
2180
2304
|
}
|
|
2181
2305
|
/**
|
|
2182
2306
|
* Return evenly spaced numbers over a specified interval.
|
|
@@ -2194,10 +2318,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2194
2318
|
dtype,
|
|
2195
2319
|
device
|
|
2196
2320
|
});
|
|
2197
|
-
else if (num === 1) return
|
|
2321
|
+
else if (num === 1) return full([1], start, {
|
|
2198
2322
|
dtype,
|
|
2199
2323
|
device
|
|
2200
|
-
})
|
|
2324
|
+
});
|
|
2201
2325
|
else if (start === stop) return full([num], start, {
|
|
2202
2326
|
dtype,
|
|
2203
2327
|
device
|
|
@@ -2206,7 +2330,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2206
2330
|
const denom = endpoint ? num - 1 : num;
|
|
2207
2331
|
const exp$2 = require_backend.AluExp.cast(dtype, require_backend.AluExp.add(require_backend.AluExp.f32(start), require_backend.AluExp.mul(require_backend.AluExp.f32(delta / denom), require_backend.AluExp.cast(require_backend.DType.Float32, require_backend.AluVar.idx))));
|
|
2208
2332
|
const st = require_backend.ShapeTracker.fromShape([num]);
|
|
2209
|
-
return new Array$1(
|
|
2333
|
+
return new Array$1({
|
|
2334
|
+
source: exp$2,
|
|
2335
|
+
st,
|
|
2336
|
+
dtype,
|
|
2337
|
+
weakType: false,
|
|
2338
|
+
backend: require_backend.getBackend(device)
|
|
2339
|
+
});
|
|
2210
2340
|
}
|
|
2211
2341
|
function aluCompare(a, b, op) {
|
|
2212
2342
|
switch (op) {
|
|
@@ -2218,35 +2348,6 @@ function aluCompare(a, b, op) {
|
|
|
2218
2348
|
case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
|
|
2219
2349
|
}
|
|
2220
2350
|
}
|
|
2221
|
-
/**
|
|
2222
|
-
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
2223
|
-
*
|
|
2224
|
-
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
2225
|
-
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
2226
|
-
* Two dimensions are compatible when:
|
|
2227
|
-
* 1. they are equal, or
|
|
2228
|
-
* 2. one of them is 1."
|
|
2229
|
-
*
|
|
2230
|
-
* Throws a TypeError if the broadcast is not possible.
|
|
2231
|
-
*
|
|
2232
|
-
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
2233
|
-
*/
|
|
2234
|
-
function generalBroadcast(a, b) {
|
|
2235
|
-
const out = [];
|
|
2236
|
-
let i = a.length - 1;
|
|
2237
|
-
let j = b.length - 1;
|
|
2238
|
-
for (; i >= 0 && j >= 0; i--, j--) {
|
|
2239
|
-
const x = a[i];
|
|
2240
|
-
const y = b[j];
|
|
2241
|
-
if (x === y) out.push(x);
|
|
2242
|
-
else if (x === 1) out.push(y);
|
|
2243
|
-
else if (y === 1) out.push(x);
|
|
2244
|
-
else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
|
|
2245
|
-
}
|
|
2246
|
-
for (; i >= 0; i--) out.push(a[i]);
|
|
2247
|
-
for (; j >= 0; j--) out.push(b[j]);
|
|
2248
|
-
return out.reverse();
|
|
2249
|
-
}
|
|
2250
2351
|
|
|
2251
2352
|
//#endregion
|
|
2252
2353
|
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
|
|
@@ -2326,13 +2427,15 @@ var Var = class Var {
|
|
|
2326
2427
|
};
|
|
2327
2428
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
2328
2429
|
var Lit = class {
|
|
2329
|
-
dtype;
|
|
2330
2430
|
value;
|
|
2331
2431
|
aval;
|
|
2332
|
-
|
|
2333
|
-
this.dtype
|
|
2432
|
+
get dtype() {
|
|
2433
|
+
return this.aval.dtype;
|
|
2434
|
+
}
|
|
2435
|
+
constructor(aval, value) {
|
|
2436
|
+
if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
|
|
2334
2437
|
this.value = value;
|
|
2335
|
-
this.aval =
|
|
2438
|
+
this.aval = ShapedArray.fromAval(aval);
|
|
2336
2439
|
}
|
|
2337
2440
|
};
|
|
2338
2441
|
function atomIsLit(atom, literal) {
|
|
@@ -2421,16 +2524,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2421
2524
|
varIds.set(v, require_backend.FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
|
|
2422
2525
|
return id;
|
|
2423
2526
|
};
|
|
2424
|
-
hasher.update(this.inBinders.length
|
|
2425
|
-
|
|
2426
|
-
|
|
2427
|
-
|
|
2428
|
-
|
|
2429
|
-
|
|
2430
|
-
eqn.
|
|
2431
|
-
|
|
2432
|
-
|
|
2433
|
-
|
|
2527
|
+
hasher.update(this.inBinders.length);
|
|
2528
|
+
for (const x of this.inBinders) hasher.update(vi(x));
|
|
2529
|
+
hasher.update(this.eqns.length);
|
|
2530
|
+
for (const eqn of this.eqns) {
|
|
2531
|
+
hasher.update(eqn.primitive);
|
|
2532
|
+
hasher.update(eqn.inputs.length);
|
|
2533
|
+
for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2534
|
+
hasher.update(JSON.stringify(eqn.params));
|
|
2535
|
+
hasher.update(eqn.outBinders.length);
|
|
2536
|
+
for (const x of eqn.outBinders) hasher.update(vi(x));
|
|
2537
|
+
}
|
|
2538
|
+
hasher.update(this.outs.length);
|
|
2539
|
+
for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2434
2540
|
return this.#hash = hasher.value;
|
|
2435
2541
|
}
|
|
2436
2542
|
hash(state) {
|
|
@@ -2453,21 +2559,26 @@ var Jaxpr = class Jaxpr {
|
|
|
2453
2559
|
const c = eqn.outBinders[0];
|
|
2454
2560
|
if (atomIsLit(a, 0)) context.set(c, b);
|
|
2455
2561
|
else if (atomIsLit(b, 0)) context.set(c, a);
|
|
2456
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2562
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
|
|
2563
|
+
else newEqns.push(eqn);
|
|
2564
|
+
} else if (eqn.primitive === Primitive.Neg) {
|
|
2565
|
+
const [a] = inputs;
|
|
2566
|
+
const c = eqn.outBinders[0];
|
|
2567
|
+
if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
|
|
2457
2568
|
else newEqns.push(eqn);
|
|
2458
2569
|
} else if (eqn.primitive === Primitive.Mul) {
|
|
2459
2570
|
const [a, b] = inputs;
|
|
2460
2571
|
const c = eqn.outBinders[0];
|
|
2461
2572
|
if (atomIsLit(a, 1)) context.set(c, b);
|
|
2462
2573
|
else if (atomIsLit(b, 1)) context.set(c, a);
|
|
2463
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2574
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
|
|
2464
2575
|
else newEqns.push(eqn);
|
|
2465
2576
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
2466
2577
|
const [a, b] = inputs;
|
|
2467
2578
|
const c = eqn.outBinders[0];
|
|
2468
2579
|
if (atomIsLit(b, 1)) context.set(c, a);
|
|
2469
2580
|
else newEqns.push(eqn);
|
|
2470
|
-
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2581
|
+
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2471
2582
|
else newEqns.push(eqn);
|
|
2472
2583
|
}
|
|
2473
2584
|
const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
|
|
@@ -2558,7 +2669,7 @@ function evalJaxpr(jaxpr, args) {
|
|
|
2558
2669
|
if (x instanceof Var) {
|
|
2559
2670
|
remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
|
|
2560
2671
|
return env.get(x);
|
|
2561
|
-
} else return
|
|
2672
|
+
} else return array(x.value, { dtype: x.dtype });
|
|
2562
2673
|
};
|
|
2563
2674
|
const write = (v, val) => {
|
|
2564
2675
|
if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
|
|
@@ -2617,7 +2728,7 @@ var JaxprTrace = class extends Trace {
|
|
|
2617
2728
|
let tracer = this.builder.constTracers.get(val);
|
|
2618
2729
|
if (tracer === void 0) {
|
|
2619
2730
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
2620
|
-
this.builder.addConst(tracer, val instanceof Tracer ? val.ref :
|
|
2731
|
+
this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
|
|
2621
2732
|
}
|
|
2622
2733
|
return tracer;
|
|
2623
2734
|
}
|
|
@@ -2686,7 +2797,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2686
2797
|
const newConsts = [];
|
|
2687
2798
|
for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
|
|
2688
2799
|
const ar = consts[i];
|
|
2689
|
-
literals.set(jaxpr.inBinders[i], new Lit(ar.
|
|
2800
|
+
literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
|
|
2690
2801
|
} else {
|
|
2691
2802
|
constBinders.push(jaxpr.inBinders[i]);
|
|
2692
2803
|
newConsts.push(consts[i]);
|
|
@@ -2699,13 +2810,12 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2699
2810
|
}
|
|
2700
2811
|
function binopAbstractEval([x, y]) {
|
|
2701
2812
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
2702
|
-
|
|
2703
|
-
return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
|
|
2813
|
+
return [promoteAvals(x, y)];
|
|
2704
2814
|
}
|
|
2705
2815
|
function compareAbstractEval([x, y]) {
|
|
2706
2816
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
|
|
2707
|
-
|
|
2708
|
-
return [new ShapedArray(
|
|
2817
|
+
const aval = promoteAvals(x, y);
|
|
2818
|
+
return [new ShapedArray(aval.shape, require_backend.DType.Bool, false)];
|
|
2709
2819
|
}
|
|
2710
2820
|
function vectorizedUnopAbstractEval([x]) {
|
|
2711
2821
|
return [ShapedArray.fromAval(x)];
|
|
@@ -2718,21 +2828,23 @@ const abstractEvalRules = {
|
|
|
2718
2828
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
2719
2829
|
[Primitive.StopGradient]: vectorizedUnopAbstractEval,
|
|
2720
2830
|
[Primitive.Cast]([x], { dtype }) {
|
|
2721
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2831
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2722
2832
|
},
|
|
2723
2833
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
2724
2834
|
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
2725
2835
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
2726
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2836
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2727
2837
|
},
|
|
2728
2838
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
2729
2839
|
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
2730
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2731
|
-
if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2732
|
-
return [new ShapedArray(shape$1, require_backend.DType.Uint32)];
|
|
2840
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2841
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2842
|
+
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
2733
2843
|
},
|
|
2734
2844
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2735
2845
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2846
|
+
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
2847
|
+
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2736
2848
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2737
2849
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2738
2850
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
@@ -2741,55 +2853,54 @@ const abstractEvalRules = {
|
|
|
2741
2853
|
[Primitive.Reduce]([x], { axis }) {
|
|
2742
2854
|
const axisSet = new Set(axis);
|
|
2743
2855
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2744
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2856
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2745
2857
|
},
|
|
2746
2858
|
[Primitive.Pool]([x], { window, strides }) {
|
|
2747
2859
|
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2748
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2860
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2749
2861
|
},
|
|
2750
2862
|
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2751
2863
|
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2752
2864
|
if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2753
|
-
return [new ShapedArray(inShape, x.dtype)];
|
|
2865
|
+
return [new ShapedArray(inShape, x.dtype, x.weakType)];
|
|
2754
2866
|
},
|
|
2755
2867
|
[Primitive.Dot]([x, y]) {
|
|
2756
|
-
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2757
2868
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
2758
|
-
const shape$1 =
|
|
2869
|
+
const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
|
|
2759
2870
|
shape$1.splice(-1, 1);
|
|
2760
|
-
return [new ShapedArray(shape$1,
|
|
2871
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2761
2872
|
},
|
|
2762
2873
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
2763
|
-
|
|
2874
|
+
const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
|
|
2764
2875
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2765
|
-
return [new ShapedArray(shape$1,
|
|
2876
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2766
2877
|
},
|
|
2767
2878
|
[Primitive.Compare]: compareAbstractEval,
|
|
2768
2879
|
[Primitive.Where]([cond, x, y]) {
|
|
2769
2880
|
if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
2770
|
-
|
|
2771
|
-
const shape$1 = generalBroadcast(cond.shape,
|
|
2772
|
-
return [new ShapedArray(shape$1,
|
|
2881
|
+
const xy = promoteAvals(x, y);
|
|
2882
|
+
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
2883
|
+
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
2773
2884
|
},
|
|
2774
2885
|
[Primitive.Transpose]([x], { perm }) {
|
|
2775
|
-
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
|
|
2886
|
+
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
2776
2887
|
},
|
|
2777
2888
|
[Primitive.Broadcast]([x], { shape: shape$1 }) {
|
|
2778
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2889
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2779
2890
|
},
|
|
2780
2891
|
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2781
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2892
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2782
2893
|
},
|
|
2783
2894
|
[Primitive.Flip]([x], _) {
|
|
2784
|
-
return [
|
|
2895
|
+
return [ShapedArray.fromAval(x)];
|
|
2785
2896
|
},
|
|
2786
2897
|
[Primitive.Shrink]([x], { slice }) {
|
|
2787
2898
|
const newShape = slice.map((s) => s[1] - s[0]);
|
|
2788
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2899
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2789
2900
|
},
|
|
2790
2901
|
[Primitive.Pad]([x], { width }) {
|
|
2791
2902
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
2792
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2903
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2793
2904
|
},
|
|
2794
2905
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2795
2906
|
for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
@@ -2799,10 +2910,10 @@ const abstractEvalRules = {
|
|
|
2799
2910
|
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
2800
2911
|
const axisSet = new Set(axis);
|
|
2801
2912
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
2802
|
-
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
2913
|
+
const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
|
|
2803
2914
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2804
2915
|
newShape.splice(outDim, 0, ...gatherShape);
|
|
2805
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2916
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2806
2917
|
},
|
|
2807
2918
|
[Primitive.JitCall](args, { jaxpr }) {
|
|
2808
2919
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
@@ -2860,7 +2971,7 @@ function makeJaxpr$1(f, opts) {
|
|
|
2860
2971
|
function jit$1(f, opts) {
|
|
2861
2972
|
const cache = /* @__PURE__ */ new Map();
|
|
2862
2973
|
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2863
|
-
|
|
2974
|
+
const result = ((...args) => {
|
|
2864
2975
|
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2865
2976
|
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2866
2977
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
@@ -2869,11 +2980,16 @@ function jit$1(f, opts) {
|
|
|
2869
2980
|
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2870
2981
|
const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2871
2982
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2983
|
+
name: f.name || "closure",
|
|
2872
2984
|
jaxpr,
|
|
2873
2985
|
numConsts: consts.length
|
|
2874
2986
|
});
|
|
2875
2987
|
return unflatten(outTree, outs);
|
|
2876
2988
|
});
|
|
2989
|
+
result.dispose = () => {
|
|
2990
|
+
for (const { consts } of cache.values()) for (const c of consts) c.dispose();
|
|
2991
|
+
};
|
|
2992
|
+
return result;
|
|
2877
2993
|
}
|
|
2878
2994
|
|
|
2879
2995
|
//#endregion
|
|
@@ -2905,7 +3021,7 @@ var JVPTrace = class extends Trace {
|
|
|
2905
3021
|
return this.lift(pureArray(val));
|
|
2906
3022
|
}
|
|
2907
3023
|
lift(val) {
|
|
2908
|
-
return new JVPTracer(this, val, zerosLike(val.ref));
|
|
3024
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
2909
3025
|
}
|
|
2910
3026
|
processPrimitive(primitive, tracers, params) {
|
|
2911
3027
|
const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2936,7 +3052,7 @@ function zeroTangentsJvp(primitive) {
|
|
|
2936
3052
|
return (primals, tangents, params) => {
|
|
2937
3053
|
for (const t of tangents) t.dispose();
|
|
2938
3054
|
const ys = bind(primitive, primals, params);
|
|
2939
|
-
return [ys, ys.map((y) => zerosLike(y.ref))];
|
|
3055
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
2940
3056
|
};
|
|
2941
3057
|
}
|
|
2942
3058
|
const jvpRules = {
|
|
@@ -2954,13 +3070,13 @@ const jvpRules = {
|
|
|
2954
3070
|
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2955
3071
|
else {
|
|
2956
3072
|
dx.dispose();
|
|
2957
|
-
return [[cast(x.ref, dtype)], [zerosLike(x)]];
|
|
3073
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2958
3074
|
}
|
|
2959
3075
|
},
|
|
2960
3076
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2961
3077
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2962
3078
|
dx.dispose();
|
|
2963
|
-
return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
|
|
3079
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2964
3080
|
},
|
|
2965
3081
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2966
3082
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2969,6 +3085,14 @@ const jvpRules = {
|
|
|
2969
3085
|
[Primitive.Cos]([x], [dx]) {
|
|
2970
3086
|
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
2971
3087
|
},
|
|
3088
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3089
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3090
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3091
|
+
},
|
|
3092
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3093
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3094
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3095
|
+
},
|
|
2972
3096
|
[Primitive.Exp]([x], [dx]) {
|
|
2973
3097
|
const z = exp$1(x);
|
|
2974
3098
|
return [[z.ref], [z.mul(dx)]];
|
|
@@ -3019,13 +3143,14 @@ const jvpRules = {
|
|
|
3019
3143
|
const indicesRef = indices.map((t) => t.ref);
|
|
3020
3144
|
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3021
3145
|
},
|
|
3022
|
-
[Primitive.JitCall](primals, tangents, { jaxpr }) {
|
|
3146
|
+
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3023
3147
|
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3024
3148
|
const outs = bind(Primitive.JitCall, [
|
|
3025
3149
|
...newConsts.map((c) => c.ref),
|
|
3026
3150
|
...primals,
|
|
3027
3151
|
...tangents
|
|
3028
3152
|
], {
|
|
3153
|
+
name: `${name}_jvp`,
|
|
3029
3154
|
jaxpr: newJaxpr,
|
|
3030
3155
|
numConsts: newConsts.length
|
|
3031
3156
|
});
|
|
@@ -3080,12 +3205,15 @@ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
|
3080
3205
|
function mappedAval(batchDim, aval) {
|
|
3081
3206
|
const shape$1 = [...aval.shape];
|
|
3082
3207
|
shape$1.splice(batchDim, 1);
|
|
3083
|
-
return new ShapedArray(shape$1, aval.dtype);
|
|
3208
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3084
3209
|
}
|
|
3085
3210
|
/** Move one axis to a different index. */
|
|
3086
3211
|
function moveaxis$1(x, src, dst) {
|
|
3087
3212
|
const t = pureArray(x);
|
|
3088
|
-
|
|
3213
|
+
src = require_backend.checkAxis(src, t.ndim);
|
|
3214
|
+
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3215
|
+
if (src === dst) return t;
|
|
3216
|
+
const perm = require_backend.range(t.ndim);
|
|
3089
3217
|
perm.splice(src, 1);
|
|
3090
3218
|
perm.splice(dst, 0, src);
|
|
3091
3219
|
return transpose$1(t, perm);
|
|
@@ -3178,6 +3306,8 @@ const vmapRules = {
|
|
|
3178
3306
|
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
3179
3307
|
[Primitive.Sin]: unopBatcher(sin$1),
|
|
3180
3308
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3309
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3310
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3181
3311
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3182
3312
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3183
3313
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
@@ -3219,9 +3349,10 @@ const vmapRules = {
|
|
|
3219
3349
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3220
3350
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3221
3351
|
},
|
|
3222
|
-
[Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
|
|
3352
|
+
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3223
3353
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3224
3354
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
3355
|
+
name: `${name}_vmap`,
|
|
3225
3356
|
jaxpr: newJaxpr,
|
|
3226
3357
|
numConsts: newConsts.length
|
|
3227
3358
|
});
|
|
@@ -3237,7 +3368,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3237
3368
|
if (dims[i] === null) return v.aval;
|
|
3238
3369
|
const shape$1 = [...v.aval.shape];
|
|
3239
3370
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3240
|
-
return new ShapedArray(shape$1, v.aval.dtype);
|
|
3371
|
+
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3241
3372
|
});
|
|
3242
3373
|
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3243
3374
|
const result = {
|
|
@@ -3363,20 +3494,28 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3363
3494
|
function linearizeFlat(f, primalsIn) {
|
|
3364
3495
|
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
3365
3496
|
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3366
|
-
|
|
3497
|
+
const dispose$1 = () => {
|
|
3498
|
+
for (const c of consts) c.dispose();
|
|
3499
|
+
};
|
|
3500
|
+
return [
|
|
3501
|
+
primalsOut,
|
|
3502
|
+
fLin,
|
|
3503
|
+
dispose$1
|
|
3504
|
+
];
|
|
3367
3505
|
}
|
|
3368
3506
|
function linearize$1(f, ...primalsIn) {
|
|
3369
3507
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3370
3508
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3371
|
-
const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3509
|
+
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3372
3510
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
3373
3511
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3374
|
-
const fLin = (...tangentsIn) => {
|
|
3512
|
+
const fLin = ((...tangentsIn) => {
|
|
3375
3513
|
const [tangentsInFlat, inTree2] = flatten(tangentsIn);
|
|
3376
3514
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
|
|
3377
3515
|
const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
|
|
3378
3516
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
3379
|
-
};
|
|
3517
|
+
});
|
|
3518
|
+
fLin.dispose = dispose$1;
|
|
3380
3519
|
return [primalsOut, fLin];
|
|
3381
3520
|
}
|
|
3382
3521
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -3442,8 +3581,8 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3442
3581
|
processPrimitive(primitive, tracers, params) {
|
|
3443
3582
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3444
3583
|
if (primitive === Primitive.JitCall) {
|
|
3445
|
-
const { jaxpr, numConsts } = params;
|
|
3446
|
-
return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
|
|
3584
|
+
const { name, jaxpr, numConsts } = params;
|
|
3585
|
+
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3447
3586
|
}
|
|
3448
3587
|
const tracersIn = tracers.map((t) => this.instantiateConst(t));
|
|
3449
3588
|
const avalsIn = tracersIn.map((t) => t.pval.aval);
|
|
@@ -3469,12 +3608,13 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3469
3608
|
*
|
|
3470
3609
|
* Used when encountering a JitCall rule during the trace.
|
|
3471
3610
|
*/
|
|
3472
|
-
#partialEvalJaxpr(jaxpr, numConsts, tracers) {
|
|
3611
|
+
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3473
3612
|
jaxpr = jaxpr.flatten();
|
|
3474
3613
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3475
3614
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3476
3615
|
const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
|
|
3477
3616
|
const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3617
|
+
name: `${name}_peval`,
|
|
3478
3618
|
jaxpr: jaxpr1,
|
|
3479
3619
|
numConsts: 0
|
|
3480
3620
|
});
|
|
@@ -3486,13 +3626,17 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3486
3626
|
prim: Primitive.JitCall,
|
|
3487
3627
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3488
3628
|
params: {
|
|
3629
|
+
name: `${name}_resid`,
|
|
3489
3630
|
jaxpr: jaxpr2,
|
|
3490
3631
|
numConsts: 0
|
|
3491
3632
|
},
|
|
3492
3633
|
avalsOut: jaxpr2.outs.map((x) => x.aval),
|
|
3493
3634
|
tracerRefsOut: []
|
|
3494
3635
|
};
|
|
3495
|
-
const outs2 = jaxpr2.outs.map((x) =>
|
|
3636
|
+
const outs2 = jaxpr2.outs.map((x, i$1) => {
|
|
3637
|
+
if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
|
|
3638
|
+
return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
|
|
3639
|
+
});
|
|
3496
3640
|
recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
|
|
3497
3641
|
let i = 0;
|
|
3498
3642
|
let j = 0;
|
|
@@ -3576,13 +3720,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3576
3720
|
const [consts, constvars] = require_backend.unzip2(constToVar.entries());
|
|
3577
3721
|
const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
|
|
3578
3722
|
const outVars = tracersOut.map((t) => tracerToVar.get(t));
|
|
3579
|
-
|
|
3723
|
+
let jaxpr = new Jaxpr(inBinders, eqns, outVars);
|
|
3580
3724
|
typecheckJaxpr(jaxpr);
|
|
3581
3725
|
for (const t of consts) t.ref;
|
|
3582
3726
|
for (const t of tracersIn) t.dispose();
|
|
3583
3727
|
for (const t of tracersOut) t.dispose();
|
|
3728
|
+
jaxpr = jaxpr.simplify();
|
|
3729
|
+
if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3584
3730
|
return {
|
|
3585
|
-
jaxpr
|
|
3731
|
+
jaxpr,
|
|
3586
3732
|
consts
|
|
3587
3733
|
};
|
|
3588
3734
|
}
|
|
@@ -3623,7 +3769,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
|
|
|
3623
3769
|
}
|
|
3624
3770
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
3625
3771
|
const eqn = jaxpr.eqns[i];
|
|
3626
|
-
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ?
|
|
3772
|
+
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
|
|
3627
3773
|
const cotangentsOut = eqn.outBinders.map(readCotangent);
|
|
3628
3774
|
const rule = transposeRules[eqn.primitive];
|
|
3629
3775
|
if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
|
|
@@ -3708,7 +3854,7 @@ const transposeRules = {
|
|
|
3708
3854
|
},
|
|
3709
3855
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3710
3856
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3711
|
-
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3857
|
+
const axisSize = require_backend.generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3712
3858
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3713
3859
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3714
3860
|
},
|
|
@@ -3803,7 +3949,7 @@ const transposeRules = {
|
|
|
3803
3949
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
3804
3950
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
3805
3951
|
},
|
|
3806
|
-
[Primitive.JitCall](cts, args, { jaxpr }) {
|
|
3952
|
+
[Primitive.JitCall](cts, args, { name, jaxpr }) {
|
|
3807
3953
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
3808
3954
|
const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
|
|
3809
3955
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
@@ -3812,6 +3958,7 @@ const transposeRules = {
|
|
|
3812
3958
|
...residuals,
|
|
3813
3959
|
...cts
|
|
3814
3960
|
], {
|
|
3961
|
+
name: `${name}_t`,
|
|
3815
3962
|
jaxpr: newJaxpr,
|
|
3816
3963
|
numConsts: newConsts.length
|
|
3817
3964
|
});
|
|
@@ -3848,20 +3995,28 @@ function vjpFlat(f, primalsIn) {
|
|
|
3848
3995
|
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
3849
3996
|
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
3850
3997
|
};
|
|
3851
|
-
|
|
3998
|
+
const dispose$1 = () => {
|
|
3999
|
+
for (const c of consts) c.dispose();
|
|
4000
|
+
};
|
|
4001
|
+
return [
|
|
4002
|
+
primalsOut,
|
|
4003
|
+
fVjp,
|
|
4004
|
+
dispose$1
|
|
4005
|
+
];
|
|
3852
4006
|
}
|
|
3853
4007
|
function vjp$1(f, ...primalsIn) {
|
|
3854
4008
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3855
4009
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3856
|
-
const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
4010
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3857
4011
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
3858
4012
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3859
|
-
const fVjp = (cotangentsOut) => {
|
|
4013
|
+
const fVjp = ((cotangentsOut) => {
|
|
3860
4014
|
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
3861
4015
|
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
3862
4016
|
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
3863
4017
|
return unflatten(inTree, cotangentsInFlat);
|
|
3864
|
-
};
|
|
4018
|
+
});
|
|
4019
|
+
fVjp.dispose = dispose$1;
|
|
3865
4020
|
return [primalsOut, fVjp];
|
|
3866
4021
|
}
|
|
3867
4022
|
function grad$1(f) {
|
|
@@ -3878,8 +4033,9 @@ function valueAndGrad$1(f) {
|
|
|
3878
4033
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3879
4034
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3880
4035
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3881
|
-
const [ct, ...rest] = fVjp(
|
|
3882
|
-
for (const r of rest)
|
|
4036
|
+
const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
|
|
4037
|
+
for (const r of rest) dispose(r);
|
|
4038
|
+
fVjp.dispose();
|
|
3883
4039
|
return [y, ct];
|
|
3884
4040
|
};
|
|
3885
4041
|
}
|
|
@@ -3887,7 +4043,13 @@ function jacrev$1(f) {
|
|
|
3887
4043
|
return function jacobianReverse(x) {
|
|
3888
4044
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
3889
4045
|
const [size$1] = x.shape;
|
|
3890
|
-
const pullback = (ct) =>
|
|
4046
|
+
const pullback = (ct) => {
|
|
4047
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
4048
|
+
y.dispose();
|
|
4049
|
+
const [ret] = fVjp(ct);
|
|
4050
|
+
fVjp.dispose();
|
|
4051
|
+
return ret;
|
|
4052
|
+
};
|
|
3891
4053
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3892
4054
|
};
|
|
3893
4055
|
}
|
|
@@ -3967,19 +4129,38 @@ __export(numpy_exports, {
|
|
|
3967
4129
|
DType: () => require_backend.DType,
|
|
3968
4130
|
abs: () => abs,
|
|
3969
4131
|
absolute: () => absolute,
|
|
4132
|
+
acos: () => acos,
|
|
4133
|
+
acosh: () => acosh,
|
|
3970
4134
|
add: () => add,
|
|
3971
4135
|
allclose: () => allclose,
|
|
3972
4136
|
arange: () => arange,
|
|
4137
|
+
arccos: () => arccos,
|
|
4138
|
+
arccosh: () => arccosh,
|
|
4139
|
+
arcsinh: () => arcsinh,
|
|
4140
|
+
arctan: () => arctan,
|
|
4141
|
+
arctan2: () => arctan2,
|
|
4142
|
+
arctanh: () => arctanh,
|
|
3973
4143
|
argmax: () => argmax,
|
|
3974
4144
|
argmin: () => argmin,
|
|
3975
4145
|
array: () => array,
|
|
4146
|
+
asin: () => asin,
|
|
4147
|
+
asinh: () => asinh,
|
|
3976
4148
|
astype: () => astype,
|
|
4149
|
+
atan: () => atan,
|
|
4150
|
+
atan2: () => atan2,
|
|
4151
|
+
atanh: () => atanh,
|
|
3977
4152
|
bool: () => bool,
|
|
4153
|
+
broadcastArrays: () => broadcastArrays,
|
|
4154
|
+
broadcastShapes: () => broadcastShapes,
|
|
4155
|
+
broadcastTo: () => broadcastTo,
|
|
4156
|
+
cbrt: () => cbrt,
|
|
3978
4157
|
clip: () => clip,
|
|
3979
4158
|
columnStack: () => columnStack,
|
|
3980
4159
|
concatenate: () => concatenate,
|
|
3981
4160
|
cos: () => cos,
|
|
3982
4161
|
cosh: () => cosh,
|
|
4162
|
+
deg2rad: () => deg2rad,
|
|
4163
|
+
degrees: () => degrees,
|
|
3983
4164
|
diag: () => diag,
|
|
3984
4165
|
diagonal: () => diagonal,
|
|
3985
4166
|
divide: () => divide,
|
|
@@ -3990,6 +4171,7 @@ __export(numpy_exports, {
|
|
|
3990
4171
|
eulerGamma: () => eulerGamma,
|
|
3991
4172
|
exp: () => exp,
|
|
3992
4173
|
exp2: () => exp2,
|
|
4174
|
+
expm1: () => expm1,
|
|
3993
4175
|
eye: () => eye,
|
|
3994
4176
|
flip: () => flip,
|
|
3995
4177
|
fliplr: () => fliplr,
|
|
@@ -4001,14 +4183,17 @@ __export(numpy_exports, {
|
|
|
4001
4183
|
greater: () => greater,
|
|
4002
4184
|
greaterEqual: () => greaterEqual,
|
|
4003
4185
|
hstack: () => hstack,
|
|
4186
|
+
hypot: () => hypot,
|
|
4004
4187
|
identity: () => identity$1,
|
|
4005
4188
|
inf: () => inf,
|
|
4189
|
+
inner: () => inner,
|
|
4006
4190
|
int32: () => int32,
|
|
4007
4191
|
less: () => less,
|
|
4008
4192
|
lessEqual: () => lessEqual,
|
|
4009
4193
|
linspace: () => linspace,
|
|
4010
4194
|
log: () => log,
|
|
4011
4195
|
log10: () => log10,
|
|
4196
|
+
log1p: () => log1p,
|
|
4012
4197
|
log2: () => log2,
|
|
4013
4198
|
matmul: () => matmul,
|
|
4014
4199
|
max: () => max,
|
|
@@ -4024,35 +4209,49 @@ __export(numpy_exports, {
|
|
|
4024
4209
|
negative: () => negative,
|
|
4025
4210
|
notEqual: () => notEqual,
|
|
4026
4211
|
ones: () => ones,
|
|
4027
|
-
onesLike: () => onesLike
|
|
4212
|
+
onesLike: () => onesLike,
|
|
4213
|
+
outer: () => outer,
|
|
4028
4214
|
pad: () => pad,
|
|
4029
4215
|
permuteDims: () => permuteDims,
|
|
4030
4216
|
pi: () => pi,
|
|
4217
|
+
pow: () => pow,
|
|
4218
|
+
power: () => power,
|
|
4031
4219
|
prod: () => prod$1,
|
|
4220
|
+
promoteTypes: () => require_backend.promoteTypes,
|
|
4221
|
+
rad2deg: () => rad2deg,
|
|
4222
|
+
radians: () => radians,
|
|
4032
4223
|
ravel: () => ravel,
|
|
4033
4224
|
reciprocal: () => reciprocal,
|
|
4225
|
+
repeat: () => repeat,
|
|
4034
4226
|
reshape: () => reshape,
|
|
4035
|
-
scalar: () => scalar,
|
|
4036
4227
|
shape: () => shape,
|
|
4228
|
+
sign: () => sign,
|
|
4037
4229
|
sin: () => sin,
|
|
4038
4230
|
sinh: () => sinh,
|
|
4039
4231
|
size: () => size,
|
|
4040
4232
|
sqrt: () => sqrt,
|
|
4041
4233
|
square: () => square,
|
|
4042
4234
|
stack: () => stack,
|
|
4235
|
+
std: () => std,
|
|
4236
|
+
subtract: () => subtract,
|
|
4043
4237
|
sum: () => sum,
|
|
4044
4238
|
tan: () => tan,
|
|
4045
4239
|
tanh: () => tanh,
|
|
4240
|
+
tile: () => tile,
|
|
4046
4241
|
transpose: () => transpose,
|
|
4242
|
+
tri: () => tri,
|
|
4243
|
+
tril: () => tril,
|
|
4244
|
+
triu: () => triu,
|
|
4047
4245
|
trueDivide: () => trueDivide,
|
|
4048
4246
|
trunc: () => trunc,
|
|
4049
4247
|
uint32: () => uint32,
|
|
4248
|
+
var_: () => var_,
|
|
4050
4249
|
vdot: () => vdot,
|
|
4051
4250
|
vecdot: () => vecdot,
|
|
4052
4251
|
vstack: () => vstack,
|
|
4053
4252
|
where: () => where,
|
|
4054
4253
|
zeros: () => zeros,
|
|
4055
|
-
zerosLike: () => zerosLike
|
|
4254
|
+
zerosLike: () => zerosLike
|
|
4056
4255
|
});
|
|
4057
4256
|
const float32 = require_backend.DType.Float32;
|
|
4058
4257
|
const int32 = require_backend.DType.Int32;
|
|
@@ -4069,54 +4268,66 @@ const inf = Number.POSITIVE_INFINITY;
|
|
|
4069
4268
|
const nan = NaN;
|
|
4070
4269
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
4071
4270
|
const pi = Math.PI;
|
|
4072
|
-
/** Element-wise addition, with broadcasting. */
|
|
4271
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
4073
4272
|
const add = add$1;
|
|
4074
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
4273
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
4075
4274
|
const multiply = mul;
|
|
4076
|
-
/** Numerical negative of every element of an array. */
|
|
4275
|
+
/** @function Numerical negative of every element of an array. */
|
|
4077
4276
|
const negative = neg;
|
|
4078
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4277
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4079
4278
|
const reciprocal = reciprocal$1;
|
|
4080
|
-
/** Element-wise sine function (takes radians). */
|
|
4279
|
+
/** @function Element-wise sine function (takes radians). */
|
|
4081
4280
|
const sin = sin$1;
|
|
4082
|
-
/** Element-wise cosine function (takes radians). */
|
|
4281
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
4083
4282
|
const cos = cos$1;
|
|
4084
|
-
/**
|
|
4283
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
4284
|
+
const asin = asin$1;
|
|
4285
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
4286
|
+
const atan = atan$1;
|
|
4287
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
4085
4288
|
const exp = exp$1;
|
|
4086
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
4289
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
4087
4290
|
const log = log$1;
|
|
4088
|
-
/** Calculate the square root of all elements in the input array. */
|
|
4291
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
4089
4292
|
const sqrt = sqrt$1;
|
|
4090
|
-
/** Return element-wise minimum of the input arrays. */
|
|
4293
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
4091
4294
|
const minimum = min$1;
|
|
4092
|
-
/** Return element-wise maximum of the input arrays. */
|
|
4295
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
4093
4296
|
const maximum = max$1;
|
|
4094
|
-
/** Compare two arrays element-wise. */
|
|
4297
|
+
/** @function Compare two arrays element-wise. */
|
|
4095
4298
|
const greater = greater$1;
|
|
4096
|
-
/** Compare two arrays element-wise. */
|
|
4299
|
+
/** @function Compare two arrays element-wise. */
|
|
4097
4300
|
const less = less$1;
|
|
4098
|
-
/** Compare two arrays element-wise. */
|
|
4301
|
+
/** @function Compare two arrays element-wise. */
|
|
4099
4302
|
const equal = equal$1;
|
|
4100
|
-
/** Compare two arrays element-wise. */
|
|
4303
|
+
/** @function Compare two arrays element-wise. */
|
|
4101
4304
|
const notEqual = notEqual$1;
|
|
4102
|
-
/** Compare two arrays element-wise. */
|
|
4305
|
+
/** @function Compare two arrays element-wise. */
|
|
4103
4306
|
const greaterEqual = greaterEqual$1;
|
|
4104
|
-
/** Compare two arrays element-wise. */
|
|
4307
|
+
/** @function Compare two arrays element-wise. */
|
|
4105
4308
|
const lessEqual = lessEqual$1;
|
|
4106
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4309
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4107
4310
|
const where = where$1;
|
|
4108
|
-
/**
|
|
4311
|
+
/**
|
|
4312
|
+
* @function
|
|
4313
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
4314
|
+
*/
|
|
4109
4315
|
const transpose = transpose$1;
|
|
4110
4316
|
/**
|
|
4317
|
+
* @function
|
|
4111
4318
|
* Give a new shape to an array without changing its data.
|
|
4112
4319
|
*
|
|
4113
4320
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
4114
4321
|
* length of the array and remaining dimensions.
|
|
4115
4322
|
*/
|
|
4116
4323
|
const reshape = reshape$1;
|
|
4117
|
-
/**
|
|
4324
|
+
/**
|
|
4325
|
+
* @function
|
|
4326
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
4327
|
+
*/
|
|
4118
4328
|
const moveaxis = moveaxis$1;
|
|
4119
4329
|
/**
|
|
4330
|
+
* @function
|
|
4120
4331
|
* Add padding (zeros) to an array.
|
|
4121
4332
|
*
|
|
4122
4333
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -4124,15 +4335,27 @@ const moveaxis = moveaxis$1;
|
|
|
4124
4335
|
* pair specifies the padding for its corresponding axis.
|
|
4125
4336
|
*/
|
|
4126
4337
|
const pad = pad$1;
|
|
4127
|
-
/**
|
|
4338
|
+
/**
|
|
4339
|
+
* @function
|
|
4340
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
4341
|
+
*/
|
|
4128
4342
|
const ndim = ndim$1;
|
|
4129
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
4343
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
4130
4344
|
const shape = getShape;
|
|
4131
|
-
/**
|
|
4132
|
-
|
|
4133
|
-
|
|
4134
|
-
|
|
4135
|
-
|
|
4345
|
+
/**
|
|
4346
|
+
* @function
|
|
4347
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
4348
|
+
*/
|
|
4349
|
+
const zerosLike = zerosLike$1;
|
|
4350
|
+
/**
|
|
4351
|
+
* @function
|
|
4352
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
4353
|
+
*/
|
|
4354
|
+
const onesLike = onesLike$1;
|
|
4355
|
+
/**
|
|
4356
|
+
* @function
|
|
4357
|
+
* Return a full array with the same shape and type as a given array.
|
|
4358
|
+
*/
|
|
4136
4359
|
const fullLike$1 = fullLike;
|
|
4137
4360
|
/**
|
|
4138
4361
|
* Return the number of elements in an array, optionally along an axis.
|
|
@@ -4147,23 +4370,23 @@ function astype(a, dtype) {
|
|
|
4147
4370
|
return fudgeArray(a).astype(dtype);
|
|
4148
4371
|
}
|
|
4149
4372
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
4150
|
-
function sum(a, axis, opts) {
|
|
4373
|
+
function sum(a, axis = null, opts) {
|
|
4151
4374
|
return reduce(a, require_backend.AluOp.Add, axis, opts);
|
|
4152
4375
|
}
|
|
4153
4376
|
/** Product of the array elements over a given axis. */
|
|
4154
|
-
function prod$1(a, axis, opts) {
|
|
4377
|
+
function prod$1(a, axis = null, opts) {
|
|
4155
4378
|
return reduce(a, require_backend.AluOp.Mul, axis, opts);
|
|
4156
4379
|
}
|
|
4157
4380
|
/** Return the minimum of array elements along a given axis. */
|
|
4158
|
-
function min(a, axis, opts) {
|
|
4381
|
+
function min(a, axis = null, opts) {
|
|
4159
4382
|
return reduce(a, require_backend.AluOp.Min, axis, opts);
|
|
4160
4383
|
}
|
|
4161
4384
|
/** Return the maximum of array elements along a given axis. */
|
|
4162
|
-
function max(a, axis, opts) {
|
|
4385
|
+
function max(a, axis = null, opts) {
|
|
4163
4386
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
4164
4387
|
}
|
|
4165
4388
|
/** Compute the average of the array elements along the specified axis. */
|
|
4166
|
-
function mean(a, axis, opts) {
|
|
4389
|
+
function mean(a, axis = null, opts) {
|
|
4167
4390
|
return fudgeArray(a).mean(axis, opts);
|
|
4168
4391
|
}
|
|
4169
4392
|
/**
|
|
@@ -4179,8 +4402,8 @@ function argmin(a, axis, opts) {
|
|
|
4179
4402
|
axis = 0;
|
|
4180
4403
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4181
4404
|
const shape$1 = a.shape;
|
|
4182
|
-
const isMax = equal(a, min(a.ref, axis, {
|
|
4183
|
-
const length =
|
|
4405
|
+
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4406
|
+
const length = array(shape$1[axis], {
|
|
4184
4407
|
dtype: int32,
|
|
4185
4408
|
device: a.device
|
|
4186
4409
|
});
|
|
@@ -4203,8 +4426,8 @@ function argmax(a, axis, opts) {
|
|
|
4203
4426
|
axis = 0;
|
|
4204
4427
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4205
4428
|
const shape$1 = a.shape;
|
|
4206
|
-
const isMax = equal(a, max(a.ref, axis, {
|
|
4207
|
-
const length =
|
|
4429
|
+
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4430
|
+
const length = array(shape$1[axis], {
|
|
4208
4431
|
dtype: int32,
|
|
4209
4432
|
device: a.device
|
|
4210
4433
|
});
|
|
@@ -4215,17 +4438,9 @@ function argmax(a, axis, opts) {
|
|
|
4215
4438
|
return length.sub(max(idx, axis, opts));
|
|
4216
4439
|
}
|
|
4217
4440
|
/** Reverse the elements in an array along the given axes. */
|
|
4218
|
-
function flip(x, axis) {
|
|
4441
|
+
function flip(x, axis = null) {
|
|
4219
4442
|
const nd = ndim(x);
|
|
4220
|
-
|
|
4221
|
-
else if (typeof axis === "number") axis = [axis];
|
|
4222
|
-
const seen = /* @__PURE__ */ new Set();
|
|
4223
|
-
for (let i = 0; i < axis.length; i++) {
|
|
4224
|
-
if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
|
|
4225
|
-
if (axis[i] < 0) axis[i] += nd;
|
|
4226
|
-
if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
|
|
4227
|
-
seen.add(axis[i]);
|
|
4228
|
-
}
|
|
4443
|
+
axis = require_backend.normalizeAxis(axis, nd);
|
|
4229
4444
|
return flip$1(x, axis);
|
|
4230
4445
|
}
|
|
4231
4446
|
/**
|
|
@@ -4331,12 +4546,80 @@ function flipud(x) {
|
|
|
4331
4546
|
function fliplr(x) {
|
|
4332
4547
|
return flip(x, 1);
|
|
4333
4548
|
}
|
|
4549
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
4334
4550
|
const permuteDims = transpose;
|
|
4335
4551
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
4336
4552
|
function ravel(a) {
|
|
4337
4553
|
return fudgeArray(a).ravel();
|
|
4338
4554
|
}
|
|
4339
4555
|
/**
|
|
4556
|
+
* Repeat each element of an array after themselves.
|
|
4557
|
+
*
|
|
4558
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
4559
|
+
* output array.
|
|
4560
|
+
*/
|
|
4561
|
+
function repeat(a, repeats, axis) {
|
|
4562
|
+
if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
|
|
4563
|
+
a = fudgeArray(a);
|
|
4564
|
+
if (axis === void 0) {
|
|
4565
|
+
a = ravel(a);
|
|
4566
|
+
axis = 0;
|
|
4567
|
+
}
|
|
4568
|
+
axis = require_backend.checkAxis(axis, a.ndim);
|
|
4569
|
+
if (repeats === 1) return a;
|
|
4570
|
+
const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
|
|
4571
|
+
const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
|
|
4572
|
+
return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
|
|
4573
|
+
}
|
|
4574
|
+
/**
|
|
4575
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
4576
|
+
*
|
|
4577
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
4578
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
4579
|
+
* reps[1] * d2, ..., reps[n] * dn)`, with `A` tiled along each dimension.
|
|
4580
|
+
*/
|
|
4581
|
+
function tile(a, reps) {
|
|
4582
|
+
a = fudgeArray(a);
|
|
4583
|
+
if (typeof reps === "number") reps = [reps];
|
|
4584
|
+
if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
|
|
4585
|
+
const ndiff = reps.length - a.ndim;
|
|
4586
|
+
if (ndiff > 0) a = a.reshape([...require_backend.rep(ndiff, 1), ...a.shape]);
|
|
4587
|
+
if (ndiff < 0) reps = [...require_backend.rep(-ndiff, 1), ...reps];
|
|
4588
|
+
const broadcastedShape = [];
|
|
4589
|
+
const broadcastAxes = [];
|
|
4590
|
+
for (let i = 0; i < a.ndim; i++) {
|
|
4591
|
+
if (reps[i] > 1) {
|
|
4592
|
+
broadcastedShape.push(reps[i]);
|
|
4593
|
+
broadcastAxes.push(broadcastedShape.length - 1);
|
|
4594
|
+
}
|
|
4595
|
+
broadcastedShape.push(a.shape[i]);
|
|
4596
|
+
}
|
|
4597
|
+
const finalShape = a.shape.map((d, i) => reps[i] * d);
|
|
4598
|
+
return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
|
|
4599
|
+
}
|
|
4600
|
+
/**
|
|
4601
|
+
* Broadcast an array to a shape, with NumPy-style broadcasing rules.
|
|
4602
|
+
*
|
|
4603
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
4604
|
+
* dimensions where the shape is 1.
|
|
4605
|
+
*/
|
|
4606
|
+
function broadcastTo(a, shape$1) {
|
|
4607
|
+
const nd = ndim(a);
|
|
4608
|
+
if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
|
|
4609
|
+
return broadcast(a, shape$1, require_backend.range(shape$1.length - nd));
|
|
4610
|
+
}
|
|
4611
|
+
/** Broadcast input shapes to a common output shape. */
|
|
4612
|
+
function broadcastShapes(...shapes) {
|
|
4613
|
+
if (shapes.length === 0) return [];
|
|
4614
|
+
return shapes.reduce(require_backend.generalBroadcast);
|
|
4615
|
+
}
|
|
4616
|
+
/** Broadcast arrays to a common shape. */
|
|
4617
|
+
function broadcastArrays(...arrays) {
|
|
4618
|
+
const shapes = arrays.map((a) => shape(a));
|
|
4619
|
+
const outShape = broadcastShapes(...shapes);
|
|
4620
|
+
return arrays.map((a) => broadcastTo(a, outShape));
|
|
4621
|
+
}
|
|
4622
|
+
/**
|
|
4340
4623
|
* Return specified diagonals.
|
|
4341
4624
|
*
|
|
4342
4625
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
@@ -4360,7 +4643,7 @@ function diag(v, k = 0) {
|
|
|
4360
4643
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
4361
4644
|
if (a.ndim === 1) {
|
|
4362
4645
|
const n = a.shape[0];
|
|
4363
|
-
const ret = where(eye(n).equal(1), a.ref, zerosLike
|
|
4646
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
4364
4647
|
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4365
4648
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4366
4649
|
else return ret;
|
|
@@ -4404,8 +4687,36 @@ function dot(x, y) {
|
|
|
4404
4687
|
]);
|
|
4405
4688
|
return dot$1(x, y);
|
|
4406
4689
|
}
|
|
4407
|
-
/**
|
|
4408
|
-
|
|
4690
|
+
/**
|
|
4691
|
+
* Compute the inner product of two arrays.
|
|
4692
|
+
*
|
|
4693
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
4694
|
+
* contraction on the last axis.
|
|
4695
|
+
*
|
|
4696
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
4697
|
+
*/
|
|
4698
|
+
function inner(x, y) {
|
|
4699
|
+
x = reshape(x, shape(x).toSpliced(-1, 0, ...require_backend.rep(ndim(y) - 1, 1)));
|
|
4700
|
+
return dot$1(x, y);
|
|
4701
|
+
}
|
|
4702
|
+
/**
|
|
4703
|
+
* Compute the outer product of two arrays.
|
|
4704
|
+
*
|
|
4705
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
4706
|
+
* be of shape `[x.size, y.size]`.
|
|
4707
|
+
*/
|
|
4708
|
+
function outer(x, y) {
|
|
4709
|
+
x = ravel(x);
|
|
4710
|
+
y = ravel(y);
|
|
4711
|
+
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
4712
|
+
}
|
|
4713
|
+
/**
 * Vector dot product of two arrays along a given axis.
 *
 * @param x - First operand.
 * @param y - Second operand.
 * @param opts.axis - Axis to contract over in both operands (default: -1).
 * @throws {Error} If the two operands have different lengths along that axis.
 */
function vecdot(x, y, { axis } = {}) {
	// Resolve the default once so the error message reports the axis actually used.
	const requestedAxis = axis ?? -1;
	const xaxis = require_backend.checkAxis(requestedAxis, ndim(x));
	const yaxis = require_backend.checkAxis(requestedAxis, ndim(y));
	if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${requestedAxis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
	// Move the contracted axis last in both operands, then contract the last axes.
	x = moveaxis(x, xaxis, -1);
	y = moveaxis(y, yaxis, -1);
	return dot$1(x, y);
}
|
|
4411
4722
|
/**
|
|
@@ -4414,7 +4725,7 @@ function vecdot(x, y) {
|
|
|
4414
4725
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
4415
4726
|
*/
|
|
4416
4727
|
function vdot(x, y) {
	// Flatten both operands to 1D, then take an ordinary dot product of them.
	const flatX = ravel(x);
	const flatY = ravel(y);
	return dot$1(flatX, flatY);
}
|
|
4419
4730
|
/**
|
|
4420
4731
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
@@ -4443,6 +4754,43 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
4443
4754
|
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
4444
4755
|
}
|
|
4445
4756
|
/**
 * Return an `n x m` array with ones on and below the `k`-th diagonal and
 * zeros elsewhere.
 *
 * `k = 0` is the main diagonal, `k < 0` lies below it, and `k > 0` lies
 * above it. `m` defaults to `n`, and `dtype` defaults to float32.
 *
 * @throws {TypeError} If `n`/`m` are not non-negative integers or `k` is not
 *   an integer.
 */
function tri(n, m, k = 0, { dtype, device } = {}) {
	m ??= n;
	dtype ??= require_backend.DType.Float32;
	const isCount = (value) => Number.isInteger(value) && value >= 0;
	if (!isCount(n)) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
	if (!isCount(m)) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
	if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
	// Entry (i, j) is one exactly when j <= i + k.
	const shiftedRows = arange(k, n + k, 1, {
		dtype: require_backend.DType.Int32,
		device
	}).reshape([n, 1]);
	const colIndices = arange(0, m, 1, {
		dtype: require_backend.DType.Int32,
		device
	});
	return shiftedRows.greaterEqual(colIndices).astype(dtype);
}
|
|
4779
|
+
/**
 * Return the lower triangle of an array. Must be of dimension >= 2.
 *
 * Entries of the last two axes above the `k`-th diagonal are set to zero;
 * `k` follows the same convention as `tri()`.
 */
function tril(a, k = 0) {
	if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
	const arr = fudgeArray(a);
	const [rows, cols] = arr.shape.slice(-2);
	const keepMask = tri(rows, cols, k, { dtype: bool });
	return where(keepMask, arr.ref, zerosLike(arr));
}
|
|
4786
|
+
/**
 * Return the upper triangle of an array. Must be of dimension >= 2.
 *
 * Entries of the last two axes below the `k`-th diagonal are set to zero;
 * `k` follows the same convention as `tri()`.
 *
 * @throws {TypeError} If the input has fewer than two dimensions.
 */
function triu(a, k = 0) {
	if (ndim(a) < 2) throw new TypeError(`triu: input array must be at least 2D, got ${ndim(a)}D`);
	a = fudgeArray(a);
	const [n, m] = a.shape.slice(-2);
	// tri(k - 1) marks everything strictly below the k-th diagonal; zero those out.
	return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
}
|
|
4793
|
+
/**
|
|
4446
4794
|
* Clip (limit) the values in an array.
|
|
4447
4795
|
*
|
|
4448
4796
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -4466,18 +4814,70 @@ function absolute(x) {
|
|
|
4466
4814
|
x = fudgeArray(x);
|
|
4467
4815
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
4468
4816
|
}
|
|
4469
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
4817
|
+
/**
 * @function
 * Alias of `jax.numpy.absolute()`: element-wise absolute value.
 */
const abs = absolute;
|
|
4819
|
+
/**
 * Return an element-wise indication of the sign of the input: -1 for negative
 * values, 0 for zeros, and 1 for positive values.
 *
 * NOTE(review): NaN inputs map to 1 here (NaN != 0 is true and NaN < 0 is
 * false), whereas NumPy/JAX return NaN.
 */
function sign(x) {
	x = fudgeArray(x);
	// Only the first use takes an extra reference. The second use consumes the
	// reference owned by this function, so nothing is left un-disposed.
	return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
}
|
|
4471
4824
|
/** Calculate the element-wise square of the input array, i.e. `x * x`. */
function square(x) {
	const arr = fudgeArray(x);
	// `arr` is used twice, so the first operand takes an extra reference.
	return arr.ref.mul(arr);
}
|
|
4476
|
-
/**
|
|
4829
|
+
/** Element-wise tangent of angles given in radians, computed as `sin(x) / cos(x)`. */
function tan(x) {
	const angle = fudgeArray(x);
	const numerator = sin(angle.ref);
	const denominator = cos(angle);
	return numerator.div(denominator);
}
|
|
4834
|
+
/**
 * Element-wise inverse cosine function (inverse of cos).
 *
 * Computed from the identity `acos(x) = π/2 - asin(x)`.
 */
function acos(x) {
	const halfPi = pi / 2;
	return subtract(halfPi, asin(x));
}
|
|
4838
|
+
/**
 * @function
 * Return the element-wise hypotenuse for the given legs of a right triangle,
 * i.e. `sqrt(x1**2 + x2**2)`.
 *
 * NumPy/JAX use a more numerically stable formulation. This version squares
 * the inputs directly, so very large or very small values may overflow or
 * underflow.
 */
const hypot = jit$1(function hypot$1(x1, x2) {
	const sumOfSquares = square(x1).add(square(x2));
	return sqrt(sumOfSquares);
});
|
|
4849
|
+
/**
 * @function
 * Element-wise arc tangent of `y / x`, with the quadrant chosen correctly.
 *
 * Returns the angle in radians between the positive x-axis and the point
 * (x, y), in the range [-π, π].
 *
 * Two half-angle formulas are used, each where it is numerically stable:
 * - x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
 * - x < 0:  atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
 *
 * The output is ill-defined when both x and y are zero.
 */
const atan2 = jit$1(function atan2$1(y, x) {
	const radius = sqrt(square(x.ref).add(square(y.ref)));
	const leftHalfPlane = less(x.ref, 0);
	// Per element, pick the numerator and denominator of tan(angle / 2).
	const top = where(leftHalfPlane.ref, radius.ref.sub(x.ref), y.ref);
	const bottom = where(leftHalfPlane, y, radius.add(x));
	const halfAngle = atan(top.div(bottom));
	return halfAngle.mul(2);
});
|
|
4869
|
+
/** @function Alias of `jax.numpy.acos()` (element-wise inverse cosine). */
const arccos = acos;
/** @function Alias of `jax.numpy.atan()` (element-wise inverse tangent). */
const arctan = atan;
/** @function Alias of `jax.numpy.atan2()` (quadrant-aware arc tangent of y / x). */
const arctan2 = atan2;
|
|
4875
|
+
/** Element-wise subtraction `x - y`, with broadcasting. */
function subtract(x, y) {
	const lhs = fudgeArray(x);
	const rhs = fudgeArray(y);
	return lhs.sub(rhs);
}
|
|
4481
4881
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
4482
4882
|
function trueDivide(x, y) {
|
|
4483
4883
|
x = fudgeArray(x);
|
|
@@ -4485,7 +4885,7 @@ function trueDivide(x, y) {
|
|
|
4485
4885
|
if (!require_backend.isFloatDtype(x.dtype) || !require_backend.isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
|
|
4486
4886
|
return x.div(y);
|
|
4487
4887
|
}
|
|
4488
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
4888
|
+
/**
 * @function
 * Alias of `jax.numpy.trueDivide()`: element-wise floating-point division.
 */
const divide = trueDivide;
|
|
4490
4890
|
/** Round input to the nearest integer towards zero. */
|
|
4491
4891
|
function trunc(x) {
|
|
@@ -4503,36 +4903,134 @@ function log2(x) {
|
|
|
4503
4903
|
function log10(x) {
	// Change of base: log10(x) = ln(x) * log10(e).
	const naturalLog = log(x);
	return naturalLog.mul(Math.LOG10E);
}
|
|
4906
|
+
/**
 * Calculate `exp(x) - 1` element-wise.
 *
 * The result is computed directly as `exp(x) - 1`, so it has no extra
 * precision for inputs near zero.
 */
function expm1(x) {
	const expX = exp(x);
	return expX.sub(1);
}
|
|
4910
|
+
/**
 * Calculate the natural logarithm of `1 + x` element-wise.
 *
 * The result is computed directly as `log(1 + x)`, so it has no extra
 * precision for inputs near zero.
 */
function log1p(x) {
	const onePlusX = add(1, x);
	return log(onePlusX);
}
|
|
4914
|
+
/** Convert angles from degrees to radians (multiplies by π/180). */
function deg2rad(x) {
	const radiansPerDegree = pi / 180;
	return multiply(x, radiansPerDegree);
}
/** @function Alias of `jax.numpy.deg2rad()`. */
const radians = deg2rad;
|
|
4920
|
+
/** Convert angles from radians to degrees (multiplies by 180/π). */
function rad2deg(x) {
	const degreesPerRadian = 180 / pi;
	return multiply(x, degreesPerRadian);
}
/** @function Alias of `jax.numpy.rad2deg()`. */
const degrees = rad2deg;
|
|
4926
|
+
/**
 * @function
 * Raise the first array to the power of the second array, element-wise.
 *
 * Evaluated as `exp(log(x1) * x2)`.
 *
 * NOTE(review): with IEEE `log`, this yields NaN for negative bases (even with
 * integer exponents such as `(-2) ** 2`) and for `0 ** 0`, where NumPy/JAX
 * return finite results.
 */
const power = jit$1(function power$1(x1, x2) {
	const logBase = log(x1);
	return exp(logBase.mul(x2));
});
/** @function Alias of `jax.numpy.power()`. */
const pow = power;
|
|
4935
|
+
/**
 * @function
 * Calculate the element-wise cube root of the input array.
 *
 * The cube root of `|x|` is computed as `exp(log(|x|) / 3)`, and the sign of
 * `x` is then restored.
 */
const cbrt = jit$1(function cbrt$1(x) {
	const polarity = where(less(x.ref, 0), -1, 1);
	// Multiplying by the polarity turns x into |x| before taking the log.
	const rootOfMagnitude = exp(log(x.mul(polarity.ref)).mul(1 / 3));
	return polarity.mul(rootOfMagnitude);
});
|
|
4506
4940
|
/**
 * @function
 * Calculate the element-wise hyperbolic sine of the input.
 *
 * `sinh(x) = (exp(x) - exp(-x)) / 2`, where `exp(-x)` is obtained as the
 * reciprocal of `exp(x)`.
 */
const sinh = jit$1(function sinh$1(x) {
	const expPos = exp(x);
	const expNeg = reciprocal(expPos.ref);
	return expPos.sub(expNeg).mul(.5);
});
|
|
4516
4951
|
/**
 * @function
 * Calculate the element-wise hyperbolic cosine of the input.
 *
 * `cosh(x) = (exp(x) + exp(-x)) / 2`, where `exp(-x)` is obtained as the
 * reciprocal of `exp(x)`.
 */
const cosh = jit$1(function cosh$1(x) {
	const expPos = exp(x);
	const expNeg = reciprocal(expPos.ref);
	return expPos.add(expNeg).mul(.5);
});
|
|
4526
4962
|
/**
 * @function
 * Calculate the element-wise hyperbolic tangent of the input.
 *
 * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
 *
 * The value is computed as `-s * (e - 1) / (e + 1)`, with `e = exp(-2|x|)` and
 * `s` the sign of `x`. The exponent is therefore never positive, which avoids
 * overflow in `exp`.
 */
const tanh = jit$1(function tanh$1(x) {
	// s = +1 for negative x, -1 otherwise, so x * s = -|x|.
	const s = where(less(x.ref, 0), 1, -1);
	const e = exp(x.mul(s.ref).mul(2));
	const numerator = e.ref.sub(1);
	const denominator = e.add(1);
	return numerator.div(denominator).mul(s);
});
|
|
4973
|
+
/**
 * @function
 * Calculate the element-wise inverse hyperbolic sine of the input.
 *
 * `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
 *
 * For negative x, `x + sqrt(x^2 + 1)` suffers catastrophic cancellation (in
 * float32 it reaches 0, giving -Infinity, by about x = -1e4). Because arcsinh
 * is odd, the formula is evaluated on |x| and the sign is restored afterwards.
 */
const arcsinh = jit$1(function arcsinh$1(x) {
	const isNegative = less(x.ref, 0);
	const magnitude = absolute(x);
	const result = log(magnitude.ref.add(sqrt(square(magnitude).add(1))));
	return where(isNegative, negative(result.ref), result);
});
|
|
4982
|
+
/**
 * @function
 * Calculate the element-wise inverse hyperbolic cosine of the input.
 *
 * `arccosh(x) = ln(x + sqrt(x^2 - 1))`
 */
const arccosh = jit$1(function arccosh$1(x) {
	const root = sqrt(square(x.ref).sub(1));
	return log(x.add(root));
});
|
|
4991
|
+
/**
 * @function
 * Calculate the element-wise inverse hyperbolic tangent of the input.
 *
 * `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
 */
const arctanh = jit$1(function arctanh$1(x) {
	const ratio = add(1, x.ref).div(subtract(1, x));
	return log(ratio).mul(.5);
});
|
|
5000
|
+
/** @function Alias of `jax.numpy.arcsinh()` (inverse hyperbolic sine). */
const asinh = arcsinh;
/** @function Alias of `jax.numpy.arccosh()` (inverse hyperbolic cosine). */
const acosh = arccosh;
/** @function Alias of `jax.numpy.arctanh()` (inverse hyperbolic tangent). */
const atanh = arctanh;
|
|
5006
|
+
/**
 * Compute the variance of an array.
 *
 * By default (`axis = null`) the variance is taken over all elements;
 * otherwise it is taken over the given axis or axes.
 *
 * Options:
 * - `mean`: a precomputed mean to use in place of computing one.
 * - `correction`: the divisor becomes `N - correction`, where `N` is the
 *   number of reduced elements (e.g. 1 for Bessel's correction).
 * - `keepdims`: keep the reduced axes with size 1.
 *
 * @throws {Error} If the reduction covers zero elements.
 */
function var_(x, axis = null, opts) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	let count = 1;
	for (const ax of axis) count *= x.shape[ax];
	if (count === 0) throw new Error("var: cannot compute variance over zero-length axis");
	const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
	const divisor = count - (opts?.correction ?? 0);
	const sumOfSquares = square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims });
	return sumOfSquares.mul(1 / divisor);
}
|
|
5023
|
+
/**
 * Compute the standard deviation of an array, i.e. the square root of
 * `var_(x, axis, opts)`.
 *
 * By default the whole array is reduced; otherwise the reduction is over the
 * given axis or axes. With `correction`, the divisor is `N - correction`, where
 * `N` is the number of reduced elements (e.g. 1 for Bessel's correction).
 */
function std(x, axis = null, opts) {
	const variance = var_(x, axis, opts);
	return sqrt(variance);
}
|
|
4537
5035
|
|
|
4538
5036
|
//#endregion
|
|
@@ -4547,6 +5045,7 @@ __export(nn_exports, {
|
|
|
4547
5045
|
leakyRelu: () => leakyRelu,
|
|
4548
5046
|
logSigmoid: () => logSigmoid,
|
|
4549
5047
|
logSoftmax: () => logSoftmax,
|
|
5048
|
+
logmeanexp: () => logmeanexp,
|
|
4550
5049
|
logsumexp: () => logsumexp,
|
|
4551
5050
|
mish: () => mish,
|
|
4552
5051
|
oneHot: () => oneHot,
|
|
@@ -4557,6 +5056,8 @@ __export(nn_exports, {
|
|
|
4557
5056
|
softSign: () => softSign,
|
|
4558
5057
|
softmax: () => softmax,
|
|
4559
5058
|
softplus: () => softplus,
|
|
5059
|
+
squareplus: () => squareplus,
|
|
5060
|
+
standardize: () => standardize,
|
|
4560
5061
|
swish: () => swish
|
|
4561
5062
|
});
|
|
4562
5063
|
/**
|
|
@@ -4600,6 +5101,7 @@ function softSign(x) {
|
|
|
4600
5101
|
return x.ref.div(absolute(x).add(1));
|
|
4601
5102
|
}
|
|
4602
5103
|
/**
|
|
5104
|
+
* @function
|
|
4603
5105
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4604
5106
|
* Swish, computed element-wise:
|
|
4605
5107
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4608,8 +5110,11 @@ function softSign(x) {
|
|
|
4608
5110
|
*
|
|
4609
5111
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4610
5112
|
*/
|
|
4611
|
-
const silu = jit$1((x)
|
|
5113
|
+
const silu = jit$1(function silu$1(x) {
	// x * sigmoid(x): the sigmoid acts as a smooth gate on x.
	const gate = sigmoid(x.ref);
	return x.mul(gate);
});
|
|
4612
5116
|
/**
|
|
5117
|
+
* @function
|
|
4613
5118
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4614
5119
|
* Swish, computed element-wise:
|
|
4615
5120
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4626,7 +5131,10 @@ const swish = silu;
|
|
|
4626
5131
|
function logSigmoid(x) {
	// log(sigmoid(x)) = -log(1 + exp(-x)) = -softplus(-x)
	const flipped = negative(x);
	return negative(softplus(flipped));
}
|
|
4629
|
-
/**
|
|
5134
|
+
/**
 * @function
 * Identity activation function. It is implemented as `fudgeArray`, and returns
 * its argument unmodified.
 */
const identity = fudgeArray;
|
|
4631
5139
|
/** Leaky rectified linear (ReLU) activation function */
|
|
4632
5140
|
function leakyRelu(x, negativeSlope = .01) {
|
|
@@ -4654,6 +5162,7 @@ function celu(x, alpha = 1) {
|
|
|
4654
5162
|
return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
|
|
4655
5163
|
}
|
|
4656
5164
|
/**
|
|
5165
|
+
* @function
|
|
4657
5166
|
* Gaussion error linear unit (GELU) activation function.
|
|
4658
5167
|
*
|
|
4659
5168
|
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
@@ -4664,7 +5173,7 @@ function celu(x, alpha = 1) {
|
|
|
4664
5173
|
*
|
|
4665
5174
|
* This will be improved in the future.
|
|
4666
5175
|
*/
|
|
4667
|
-
const gelu = jit$1((x)
|
|
5176
|
+
const gelu = jit$1(function gelu$1(x) {
	// tanh approximation: 0.5 * x * (1 + tanh(sqrt(2/π) * x * (1 + 0.044715 * x^2)))
	const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
	const cubicFactor = x.ref.mul(x.ref).mul(.044715).add(1);
	const tanhArg = x.ref.mul(cubicFactor).mul(SQRT_2_OVER_PI);
	return x.mul(.5).mul(tanh(tanhArg).add(1));
});
|
|
@@ -4685,6 +5194,16 @@ function glu(x, axis = -1) {
|
|
|
4685
5194
|
return a.mul(sigmoid(b));
|
|
4686
5195
|
}
|
|
4687
5196
|
/**
 * Squareplus activation function, a smooth approximation of ReLU.
 *
 * Element-wise: `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`, where `b`
 * defaults to 4.
 */
function squareplus(x, b = 4) {
	const arr = fudgeArray(x);
	const root = sqrt(square(arr.ref).add(b));
	return arr.add(root).mul(.5);
}
|
|
5206
|
+
/**
|
|
4688
5207
|
* Mish activation function.
|
|
4689
5208
|
*
|
|
4690
5209
|
* Computes the element-wise function:
|
|
@@ -4702,17 +5221,13 @@ function mish(x) {
|
|
|
4702
5221
|
*
|
|
4703
5222
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
4704
5223
|
*/
|
|
4705
|
-
function softmax(x, axis) {
|
|
5224
|
+
function softmax(x, axis = -1) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	// Reducing over no axes makes each element its own group, so every output is 1.
	if (axis.length === 0) return onesLike(x);
	// Subtract the (non-differentiated) max for numerical stability.
	const shift = stopGradient(max(x.ref, axis, { keepdims: true }));
	const unnormalized = exp(x.sub(shift));
	const total = unnormalized.ref.sum(axis, { keepdims: true });
	return unnormalized.div(total);
}
|
|
4717
5232
|
/**
|
|
4718
5233
|
* Log-Softmax function.
|
|
@@ -4722,17 +5237,13 @@ function softmax(x, axis) {
|
|
|
4722
5237
|
*
|
|
4723
5238
|
* If `axis` is not specified, it defaults to the last axis.
|
|
4724
5239
|
*/
|
|
4725
|
-
function logSoftmax(x, axis) {
|
|
5240
|
+
function logSoftmax(x, axis = -1) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	// Reducing over no axes makes each element its own group, so every output is log(1) = 0.
	if (axis.length === 0) return zerosLike(x);
	// Subtract the (non-differentiated) max for numerical stability.
	const shift = stopGradient(max(x.ref, axis, { keepdims: true }));
	const shifted = x.sub(shift);
	const logNormalizer = log(exp(shifted.ref).sum(axis, { keepdims: true }));
	return shifted.sub(logNormalizer);
}
|
|
4738
5249
|
/**
|
|
@@ -4743,16 +5254,39 @@ function logSoftmax(x, axis) {
|
|
|
4743
5254
|
*
|
|
4744
5255
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
4745
5256
|
*/
|
|
4746
|
-
function logsumexp(x, axis) {
|
|
5257
|
+
function logsumexp(x, axis = null) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	if (axis.length === 0) return x;
	// Shift by the max for numerical stability. When the max is infinite (for
	// example, every element along the axis is -Infinity), subtracting it would
	// give inf - inf = NaN. Shift by 0 in that case instead, so the result comes
	// out as -Infinity / +Infinity, matching jax.nn.logsumexp.
	const rawMax = max(x.ref, axis);
	const xMax = stopGradient(where(less(absolute(rawMax.ref), Infinity), rawMax, 0));
	const xMaxDims = broadcast(xMax.ref, x.shape, axis);
	const shifted = x.sub(xMaxDims);
	return xMax.add(log(exp(shifted).sum(axis)));
}
|
|
5266
|
+
/**
 * Log-mean-exp reduction: `logsumexp(x, axis) - log(n)`, where `n` is the
 * number of elements reduced over. With `axis = null`, all axes are reduced.
 */
function logmeanexp(x, axis = null) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	if (axis.length === 0) return x;
	let count = 1;
	for (const ax of axis) count *= x.shape[ax];
	const logCount = Math.log(count);
	return logsumexp(x, axis).sub(logCount);
}
|
|
5274
|
+
/**
 * Standardize the input to zero mean and unit variance.
 *
 * The statistics are computed over the last axis by default. Pass another
 * axis, or `null` to use all elements.
 *
 * Options:
 * - `mean` / `variance`: precomputed statistics to use in place of computing them.
 * - `epsilon`: added to the variance before the square root (default `1e-5`).
 *
 * A computed variance is `E[x^2] - mean^2`.
 */
function standardize(x, axis = -1, opts = {}) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	if (axis.length === 0) return x;
	const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
	let sigma2;
	if (opts.variance !== void 0) sigma2 = fudgeArray(opts.variance);
	else sigma2 = square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
	const centered = x.sub(mu);
	const scale = sqrt(sigma2.add(opts.epsilon ?? 1e-5));
	return centered.div(scale);
}
|
|
4756
5290
|
/**
|
|
4757
5291
|
* One-hot encodes the given indices.
|
|
4758
5292
|
*
|
|
@@ -4770,7 +5304,7 @@ function logsumexp(x, axis) {
|
|
|
4770
5304
|
* ```
|
|
4771
5305
|
*/
|
|
4772
5306
|
function oneHot(x, numClasses) {
	if (require_backend.isFloatDtype(x.dtype)) {
		throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
	}
	// Row i of the identity matrix is the one-hot vector for class i, so
	// gathering rows by the index array produces the encoding.
	const identityRows = eye(numClasses, void 0, { device: x.device });
	return identityRows.slice(x);
}
|
|
4776
5310
|
|
|
@@ -4778,8 +5312,11 @@ function oneHot(x, numClasses) {
|
|
|
4778
5312
|
//#region src/random.ts
|
|
4779
5313
|
var random_exports = {};
// Register the public `random` namespace members through the `__export`
// helper. Each entry is a thunk that returns the module-level binding.
__export(random_exports, {
	bernoulli: () => bernoulli,
	bits: () => bits,
	exponential: () => exponential,
	key: () => key,
	normal: () => normal,
	split: () => split,
	uniform: () => uniform
});
|
|
@@ -4807,21 +5344,58 @@ function bits(key$1, shape$1 = []) {
|
|
|
4807
5344
|
const keyShape = validateKeyShape(key$1);
|
|
4808
5345
|
return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
|
|
4809
5346
|
}
|
|
4810
|
-
/**
|
|
4811
|
-
function
|
|
5347
|
+
/**
 * @function
 * Sample uniform random values in [minval, maxval) with the given shape.
 *
 * The upper 23 of 32 random bits (dividing by 512 drops the low 9) become the
 * mantissa of a float32 whose exponent bits come from 1065353216
 * (0x3F800000, the bit pattern of 1.0). That gives a value in [1, 2), and
 * subtracting 1 maps it to [0, 1). `shape` and the options object are static
 * arguments of the JIT.
 *
 * @throws {Error} If `minval >= maxval`.
 */
const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
	if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
	const randomWords = bits(key$1, shape$1);
	const u32 = (value) => array(value, {
		dtype: require_backend.DType.Uint32,
		device: key$1.device
	});
	const oneToTwoBits = randomWords.div(u32(512)).add(u32(1065353216));
	const unit = bitcast(oneToTwoBits, require_backend.DType.Float32).sub(1);
	if (minval === 0 && maxval === 1) return unit;
	return unit.mul(maxval - minval).add(minval);
}, { staticArgnums: [1, 2] });
|
|
5365
|
+
/**
 * Sample Bernoulli random variables with mean `p` (values 0/1).
 *
 * Returns a Boolean array of the given shape that is true where a uniform
 * [0, 1) draw falls below `p`. `p` may be an array, and it must broadcast
 * against `shape`.
 */
function bernoulli(key$1, p = .5, shape$1 = []) {
	const threshold = fudgeArray(p);
	const draws = uniform(key$1, shape$1);
	return draws.less(threshold);
}
|
|
5375
|
+
/**
 * @function
 * Sample exponential random values with density `p(x) = exp(-x)`.
 *
 * Uses inverse-CDF sampling: if U ~ Uniform[0, 1), then -log(1 - U) ~ Exp(1).
 */
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
	const negatedUniform = negative(uniform(key$1, shape$1));
	return negative(log1p(negatedUniform));
}, { staticArgnums: [1] });
|
|
5383
|
+
/**
 * @function
 * Sample standard normal values with density `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
 *
 * Uses the Box-Muller transform on two independent uniform draws. JAX instead
 * inverts the CDF with the erf_inv primitive, which jax-js does not support
 * yet, so the outputs are not bitwise identical to JAX.
 */
const normal = jit$1(function normal$1(key$1, shape$1 = []) {
	const [radiusKey, angleKey] = split(key$1, 2);
	const uRadius = uniform(radiusKey, shape$1);
	const uAngle = uniform(angleKey, shape$1);
	// r = sqrt(-2 * log(1 - u1)), theta = 2π * u2; the sample is r * cos(theta).
	const radius = sqrt(log1p(negative(uRadius)).mul(-2));
	const theta = uAngle.mul(2 * Math.PI);
	return radius.mul(cos(theta));
}, { staticArgnums: [1] });
|
|
4825
5399
|
|
|
4826
5400
|
//#endregion
|
|
4827
5401
|
//#region src/polyfills.ts
|
|
@@ -4831,20 +5405,36 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
4831
5405
|
|
|
4832
5406
|
//#endregion
|
|
4833
5407
|
//#region src/index.ts
|
|
4834
|
-
/**
|
|
5408
|
+
/**
 * @function
 * Forward-mode automatic differentiation: evaluate a function together with
 * its Jacobian-vector product.
 */
const jvp = jvp$1;
/**
 * @function
 * Vectorize an operation over a batched axis of one or more inputs.
 */
const vmap = vmap$1;
/**
 * @function
 * Compute the Jacobian column by column using forward-mode AD.
 */
const jacfwd = jacfwd$1;
/**
 * @function
 * Build a Jaxpr by tracing a function with example inputs.
 */
const makeJaxpr = makeJaxpr$1;
|
|
4842
5428
|
/**
|
|
5429
|
+
* @function
|
|
4843
5430
|
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
4844
5431
|
*
|
|
4845
5432
|
* The function will be compiled the first time it is called with a set of
|
|
4846
5433
|
* argument shapes.
|
|
4847
5434
|
*
|
|
5435
|
+
* You can call `.dispose()` on the returned, JIT-compiled function after all
|
|
5436
|
+
* calls to free memory associated with array constants.
|
|
5437
|
+
*
|
|
4848
5438
|
* **Options:**
|
|
4849
5439
|
* - `staticArgnums`: An array of argument indices to treat as static
|
|
4850
5440
|
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
@@ -4854,26 +5444,59 @@ const makeJaxpr = makeJaxpr$1;
|
|
|
4854
5444
|
*/
|
|
4855
5445
|
const jit = jit$1;
|
|
4856
5446
|
/**
 * @function
 * Produce a local linear approximation of a function at a point, using jvp()
 * together with partial evaluation.
 */
const linearize = linearize$1;
/**
 * @function
 * Reverse-mode automatic differentiation: compute the vector-Jacobian product
 * of a function.
 */
const vjp = vjp$1;
/**
 * @function
 * Compute the gradient of a scalar-valued function `f` with respect to its
 * first argument.
 */
const grad = grad$1;
/**
 * @function
 * Build a function that returns both the value of `f` and its gradient.
 */
const valueAndGrad = valueAndGrad$1;
/**
 * @function
 * Compute the Jacobian row by row using reverse-mode AD.
 */
const jacrev = jacrev$1;
/**
 * @function
 * Compute the Jacobian with reverse-mode AD. Alias of `jacrev()`.
 */
const jacobian = jacrev;
|
|
5478
|
+
/**
 * Wait until every `Array` leaf of `x` is ready, by calling
 * `Array.blockUntilReady()` on each leaf.
 *
 * Use this to wait for intermediate results. Calling it periodically in an
 * iterative computation keeps too many operations from piling up in the queue.
 *
 * Does not consume references to the arrays; resolves to `x` itself.
 */
async function blockUntilReady(x) {
	const pending = [...leaves(x)].filter((leaf) => leaf instanceof Array$1).map((leaf) => leaf.blockUntilReady());
	await Promise.all(pending);
	return x;
}
|
|
4874
5493
|
|
|
4875
5494
|
//#endregion
|
|
5495
|
+
// Public API bindings re-exported on the CommonJS `exports` object.
exports.Array = Array$1;
exports.DType = require_backend.DType;
exports.Jaxpr = Jaxpr;
exports.blockUntilReady = blockUntilReady;
exports.defaultDevice = require_backend.defaultDevice;
exports.devices = require_backend.devices;
exports.grad = grad;
exports.init = require_backend.init;
|
|
@@ -4908,7 +5531,7 @@ Object.defineProperty(exports, 'random', {
|
|
|
4908
5531
|
return random_exports;
|
|
4909
5532
|
}
|
|
4910
5533
|
});
|
|
4911
|
-
exports.
|
|
5534
|
+
// Re-export the backend's debug toggle as part of the public API.
exports.setDebug = require_backend.setDebug;
|
|
4912
5535
|
Object.defineProperty(exports, 'tree', {
|
|
4913
5536
|
enumerable: true,
|
|
4914
5537
|
get: function () {
|