@jax-js/jax 0.0.4 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +67 -24
- package/dist/{backend-EBRGmEYw.js → backend-CdcTZEOF.js} +35 -6
- package/dist/{backend-Ss1Mev_-.cjs → backend-yEU0L_ig.cjs} +40 -5
- package/dist/index.cjs +324 -225
- package/dist/index.d.cts +71 -26
- package/dist/index.d.ts +71 -26
- package/dist/index.js +314 -215
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-CM-xNYzW.js} +1 -1
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-CNOpiO5T.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-yEU0L_ig.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
@@ -596,6 +596,21 @@ var Trace = class {
|
|
|
596
596
|
this.main = main;
|
|
597
597
|
}
|
|
598
598
|
};
|
|
599
|
+
/**
|
|
600
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
601
|
+
*
|
|
602
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
603
|
+
* implemented in that function as `weakType` is not passed.
|
|
604
|
+
*/
|
|
605
|
+
function promoteAvals(a, b) {
|
|
606
|
+
const shape$1 = require_backend.generalBroadcast(a.shape, b.shape);
|
|
607
|
+
const weakType = a.weakType && b.weakType;
|
|
608
|
+
let dtype;
|
|
609
|
+
if (a.weakType === b.weakType) dtype = require_backend.promoteTypes(a.dtype, b.dtype);
|
|
610
|
+
else if (a.weakType) dtype = require_backend.promoteTypes(b.dtype, require_backend.DType.Uint32);
|
|
611
|
+
else dtype = require_backend.promoteTypes(a.dtype, require_backend.DType.Uint32);
|
|
612
|
+
return new ShapedArray(shape$1, dtype, weakType);
|
|
613
|
+
}
|
|
599
614
|
var Tracer = class Tracer {
|
|
600
615
|
/** @ignore */
|
|
601
616
|
_trace;
|
|
@@ -610,10 +625,19 @@ var Tracer = class Tracer {
|
|
|
610
625
|
get size() {
|
|
611
626
|
return require_backend.prod(this.shape);
|
|
612
627
|
}
|
|
613
|
-
/** The dtype of the array. */
|
|
628
|
+
/** The dtype of elements stored in the array. */
|
|
614
629
|
get dtype() {
|
|
615
630
|
return this.aval.dtype;
|
|
616
631
|
}
|
|
632
|
+
/**
|
|
633
|
+
* Whether the array is weakly typed.
|
|
634
|
+
*
|
|
635
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
636
|
+
* `promoteTypes()` for details.
|
|
637
|
+
*/
|
|
638
|
+
get weakType() {
|
|
639
|
+
return this.aval.weakType;
|
|
640
|
+
}
|
|
617
641
|
/** The number of dimensions of the array. */
|
|
618
642
|
get ndim() {
|
|
619
643
|
return this.shape.length;
|
|
@@ -850,12 +874,13 @@ function getShape(x) {
|
|
|
850
874
|
return x instanceof Tracer ? x.shape : [];
|
|
851
875
|
}
|
|
852
876
|
var ShapedArray = class ShapedArray {
|
|
853
|
-
constructor(shape$1, dtype) {
|
|
877
|
+
constructor(shape$1, dtype, weakType) {
|
|
854
878
|
this.shape = shape$1;
|
|
855
879
|
this.dtype = dtype;
|
|
880
|
+
this.weakType = weakType;
|
|
856
881
|
}
|
|
857
882
|
static fromAval(aval) {
|
|
858
|
-
return new ShapedArray(aval.shape, aval.dtype);
|
|
883
|
+
return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
|
|
859
884
|
}
|
|
860
885
|
get ndim() {
|
|
861
886
|
return this.shape.length;
|
|
@@ -869,7 +894,7 @@ var ShapedArray = class ShapedArray {
|
|
|
869
894
|
};
|
|
870
895
|
function getAval(x) {
|
|
871
896
|
if (x instanceof Tracer) return x.aval;
|
|
872
|
-
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32);
|
|
897
|
+
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32, typeof x === "boolean" ? false : true);
|
|
873
898
|
else throw new TypeError(`Unknown value: ${x}`);
|
|
874
899
|
}
|
|
875
900
|
function bind(prim, args, params = {}) {
|
|
@@ -1154,7 +1179,7 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1154
1179
|
}
|
|
1155
1180
|
function broadcastedJit(fn) {
|
|
1156
1181
|
return (nargs, exps, avals, params) => {
|
|
1157
|
-
const newShape = avals.map((aval) => aval.shape).reduce(generalBroadcast);
|
|
1182
|
+
const newShape = avals.map((aval) => aval.shape).reduce(require_backend.generalBroadcast);
|
|
1158
1183
|
exps = exps.map((exp$3) => reshapeViews(exp$3, (st) => {
|
|
1159
1184
|
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1160
1185
|
}));
|
|
@@ -1191,7 +1216,7 @@ const jitRules = {
|
|
|
1191
1216
|
const k1 = reshapeViews(keys[1], mapping);
|
|
1192
1217
|
const c0 = require_backend.AluExp.u32(0);
|
|
1193
1218
|
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1194
|
-
const exp$2 = require_backend.AluExp.threefry2x32(
|
|
1219
|
+
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1195
1220
|
return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
|
|
1196
1221
|
},
|
|
1197
1222
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
@@ -1232,7 +1257,7 @@ const jitRules = {
|
|
|
1232
1257
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
1233
1258
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
1234
1259
|
const c = k1.exp;
|
|
1235
|
-
const cs =
|
|
1260
|
+
const cs = promoteAvals(as, bs);
|
|
1236
1261
|
return jitRules[Primitive.Reduce](nargs, [c], [cs], {
|
|
1237
1262
|
op: require_backend.AluOp.Add,
|
|
1238
1263
|
axis: [cs.ndim - 1]
|
|
@@ -1242,8 +1267,8 @@ const jitRules = {
|
|
|
1242
1267
|
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1243
1268
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1244
1269
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1245
|
-
as = new ShapedArray(stX.shape, as.dtype);
|
|
1246
|
-
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1270
|
+
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1271
|
+
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1247
1272
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1248
1273
|
},
|
|
1249
1274
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
@@ -1260,7 +1285,7 @@ const jitRules = {
|
|
|
1260
1285
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1261
1286
|
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1262
1287
|
const axisSet = new Set(axis);
|
|
1263
|
-
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1288
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
1264
1289
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1265
1290
|
finalShape.splice(outDim, 0, ...indexShape);
|
|
1266
1291
|
const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
|
|
@@ -1296,9 +1321,10 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1296
1321
|
Primitive.Conv,
|
|
1297
1322
|
Primitive.PoolTranspose
|
|
1298
1323
|
];
|
|
1324
|
+
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
1299
1325
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1300
1326
|
const eqn = jaxpr.eqns[i];
|
|
1301
|
-
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive
|
|
1327
|
+
if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1302
1328
|
for (const v of eqn.outBinders) {
|
|
1303
1329
|
blackNodes.add(v);
|
|
1304
1330
|
p1NextBlack.set(v, v);
|
|
@@ -1428,6 +1454,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1428
1454
|
static #nextId = 1001;
|
|
1429
1455
|
id;
|
|
1430
1456
|
#dtype;
|
|
1457
|
+
#weakType;
|
|
1431
1458
|
#source;
|
|
1432
1459
|
#st;
|
|
1433
1460
|
#backend;
|
|
@@ -1439,21 +1466,22 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1439
1466
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1440
1467
|
* will be freed when the array is disposed.
|
|
1441
1468
|
*/
|
|
1442
|
-
constructor(
|
|
1469
|
+
constructor(args) {
|
|
1443
1470
|
super(baseArrayTrace);
|
|
1444
1471
|
this.id = Array$1.#nextId++;
|
|
1445
|
-
this.#dtype = dtype;
|
|
1446
|
-
this.#
|
|
1447
|
-
this.#
|
|
1448
|
-
this.#
|
|
1472
|
+
this.#dtype = args.dtype;
|
|
1473
|
+
this.#weakType = args.weakType;
|
|
1474
|
+
this.#source = args.source;
|
|
1475
|
+
this.#st = args.st;
|
|
1476
|
+
this.#backend = args.backend;
|
|
1449
1477
|
this.#rc = 1;
|
|
1450
|
-
this.#pendingSet = new Set(pending);
|
|
1478
|
+
this.#pendingSet = new Set(args.pending);
|
|
1451
1479
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1452
|
-
else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1480
|
+
else if (this.#source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1453
1481
|
}
|
|
1454
1482
|
/** @ignore */
|
|
1455
1483
|
get aval() {
|
|
1456
|
-
return new ShapedArray(this.#st.shape, this.#dtype);
|
|
1484
|
+
return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
|
|
1457
1485
|
}
|
|
1458
1486
|
/** Return a simple string representation of the array's dimensions. */
|
|
1459
1487
|
toString() {
|
|
@@ -1465,6 +1493,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1465
1493
|
#check() {
|
|
1466
1494
|
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1467
1495
|
}
|
|
1496
|
+
/** Construct an array, copying fields from `this`. */
|
|
1497
|
+
#newArrayFrom(args) {
|
|
1498
|
+
return new Array$1({
|
|
1499
|
+
source: args.source ?? this.#source,
|
|
1500
|
+
st: args.st ?? this.#st,
|
|
1501
|
+
dtype: args.dtype ?? this.#dtype,
|
|
1502
|
+
weakType: this.#weakType,
|
|
1503
|
+
backend: args.backend ?? this.#backend,
|
|
1504
|
+
pending: args.pending ?? this.#pending ?? void 0
|
|
1505
|
+
});
|
|
1506
|
+
}
|
|
1468
1507
|
get ref() {
|
|
1469
1508
|
this.#check();
|
|
1470
1509
|
this.#rc++;
|
|
@@ -1504,7 +1543,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1504
1543
|
const pending = this.#pending;
|
|
1505
1544
|
for (const exe of pending) exe.updateRc(1);
|
|
1506
1545
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1507
|
-
const ar =
|
|
1546
|
+
const ar = this.#newArrayFrom({
|
|
1547
|
+
st,
|
|
1548
|
+
pending
|
|
1549
|
+
});
|
|
1508
1550
|
this.dispose();
|
|
1509
1551
|
return ar;
|
|
1510
1552
|
}
|
|
@@ -1553,7 +1595,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1553
1595
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1554
1596
|
this.dispose();
|
|
1555
1597
|
for (const ar of indices) ar.dispose();
|
|
1556
|
-
return
|
|
1598
|
+
return this.#newArrayFrom({
|
|
1599
|
+
source: output,
|
|
1600
|
+
st: require_backend.ShapeTracker.fromShape(finalShape),
|
|
1601
|
+
pending
|
|
1602
|
+
});
|
|
1557
1603
|
}
|
|
1558
1604
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1559
1605
|
#moveAxesDown(axis) {
|
|
@@ -1576,11 +1622,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1576
1622
|
return this.#reshape(this.#st.permute(perm));
|
|
1577
1623
|
}
|
|
1578
1624
|
#unary(op, dtypeOutput) {
|
|
1625
|
+
const weakType = !dtypeOutput && this.#weakType;
|
|
1579
1626
|
dtypeOutput ??= this.#dtype;
|
|
1580
1627
|
this.#check();
|
|
1581
1628
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1582
1629
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1583
|
-
return
|
|
1630
|
+
return this.#newArrayFrom({
|
|
1631
|
+
source: exp$3.simplify(),
|
|
1632
|
+
dtype: dtypeOutput,
|
|
1633
|
+
weakType
|
|
1634
|
+
});
|
|
1584
1635
|
}
|
|
1585
1636
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1586
1637
|
const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1590,41 +1641,65 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1590
1641
|
for (const exe of pending) exe.updateRc(1);
|
|
1591
1642
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1592
1643
|
this.dispose();
|
|
1593
|
-
return
|
|
1644
|
+
return this.#newArrayFrom({
|
|
1645
|
+
source: output,
|
|
1646
|
+
st: require_backend.ShapeTracker.fromShape(this.shape),
|
|
1647
|
+
dtype: dtypeOutput,
|
|
1648
|
+
weakType,
|
|
1649
|
+
pending
|
|
1650
|
+
});
|
|
1594
1651
|
}
|
|
1595
1652
|
#binary(op, other) {
|
|
1596
|
-
const custom = (src) => new require_backend.AluExp(op,
|
|
1653
|
+
const custom = (src) => new require_backend.AluExp(op, src[0].dtype, src);
|
|
1597
1654
|
return Array$1.#naryCustom(op, custom, [this, other]);
|
|
1598
1655
|
}
|
|
1599
|
-
static #naryCustom(name, custom, arrays, { dtypeOverride,
|
|
1656
|
+
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1600
1657
|
const n = arrays.length;
|
|
1601
1658
|
const backend = arrays[0].#backend;
|
|
1602
1659
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1603
1660
|
for (const ar of arrays) ar.#check();
|
|
1604
|
-
let
|
|
1661
|
+
let castDtype;
|
|
1662
|
+
let castWeakType = true;
|
|
1605
1663
|
for (let i = 0; i < n; i++) {
|
|
1606
1664
|
if (dtypeOverride?.[i]) {
|
|
1607
1665
|
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1608
|
-
} else if (
|
|
1609
|
-
|
|
1666
|
+
} else if (castDtype === void 0) {
|
|
1667
|
+
castDtype = arrays[i].#dtype;
|
|
1668
|
+
castWeakType = arrays[i].#weakType;
|
|
1669
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1610
1670
|
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1611
1671
|
}
|
|
1612
|
-
|
|
1613
|
-
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1672
|
+
const weakType = castWeakType && !strongTypeOutput;
|
|
1614
1673
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1615
1674
|
const newShape = [...arrays[0].shape];
|
|
1616
1675
|
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
1676
|
+
const sources = arrays.map((ar, i) => {
|
|
1677
|
+
if (!dtypeOverride?.[i]) return require_backend.AluExp.cast(castDtype, ar.#source);
|
|
1678
|
+
else return ar.#source;
|
|
1679
|
+
});
|
|
1617
1680
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1618
|
-
const exp$4 = custom(
|
|
1619
|
-
return new Array$1(
|
|
1681
|
+
const exp$4 = custom(sources);
|
|
1682
|
+
return new Array$1({
|
|
1683
|
+
source: exp$4.simplify(),
|
|
1684
|
+
st: arrays[0].#st,
|
|
1685
|
+
dtype: exp$4.dtype,
|
|
1686
|
+
weakType,
|
|
1687
|
+
backend
|
|
1688
|
+
});
|
|
1620
1689
|
}
|
|
1621
|
-
const exp$3 = custom(arrays.map((ar) => {
|
|
1622
|
-
const src$1 =
|
|
1690
|
+
const exp$3 = custom(arrays.map((ar, i) => {
|
|
1691
|
+
const src$1 = sources[i];
|
|
1623
1692
|
if (ar.#st.contiguous) return src$1;
|
|
1624
1693
|
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1625
1694
|
}));
|
|
1626
1695
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1627
|
-
return new Array$1(
|
|
1696
|
+
return new Array$1({
|
|
1697
|
+
source: exp$3.simplify(),
|
|
1698
|
+
st,
|
|
1699
|
+
dtype: exp$3.dtype,
|
|
1700
|
+
weakType,
|
|
1701
|
+
backend
|
|
1702
|
+
});
|
|
1628
1703
|
}
|
|
1629
1704
|
let indices;
|
|
1630
1705
|
if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
|
|
@@ -1634,14 +1709,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1634
1709
|
}
|
|
1635
1710
|
const inputs = [];
|
|
1636
1711
|
const src = [];
|
|
1637
|
-
for (const ar of arrays
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
gid = inputs.
|
|
1642
|
-
|
|
1712
|
+
for (const [i, ar] of arrays.entries()) {
|
|
1713
|
+
let nextSrc;
|
|
1714
|
+
if (ar.#source instanceof require_backend.AluExp) nextSrc = require_backend.accessorAluExp(ar.#source, ar.#st, indices);
|
|
1715
|
+
else {
|
|
1716
|
+
let gid = inputs.indexOf(ar.#source);
|
|
1717
|
+
if (gid === -1) {
|
|
1718
|
+
gid = inputs.length;
|
|
1719
|
+
inputs.push(ar.#source);
|
|
1720
|
+
}
|
|
1721
|
+
nextSrc = require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
|
|
1643
1722
|
}
|
|
1644
|
-
|
|
1723
|
+
if (!dtypeOverride?.[i]) nextSrc = require_backend.AluExp.cast(castDtype, nextSrc);
|
|
1724
|
+
src.push(nextSrc);
|
|
1645
1725
|
}
|
|
1646
1726
|
const exp$2 = custom(src);
|
|
1647
1727
|
let re = void 0;
|
|
@@ -1655,12 +1735,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1655
1735
|
for (const exe of pending) exe.updateRc(1);
|
|
1656
1736
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1657
1737
|
for (const ar of arrays) ar.dispose();
|
|
1658
|
-
return new Array$1(
|
|
1738
|
+
return new Array$1({
|
|
1739
|
+
source: output,
|
|
1740
|
+
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1741
|
+
dtype: kernel.dtype,
|
|
1742
|
+
weakType,
|
|
1743
|
+
backend,
|
|
1744
|
+
pending
|
|
1745
|
+
});
|
|
1659
1746
|
}
|
|
1660
1747
|
/** Reduce the last dimension of the array by an operation. */
|
|
1661
1748
|
#reduce(op) {
|
|
1662
|
-
this.#check();
|
|
1663
|
-
if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
|
|
1664
1749
|
const shape$1 = this.shape;
|
|
1665
1750
|
const reduction = new require_backend.Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
|
|
1666
1751
|
const newShape = shape$1.slice(0, -1);
|
|
@@ -1679,7 +1764,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1679
1764
|
for (const exe of pending) exe.updateRc(1);
|
|
1680
1765
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1681
1766
|
this.dispose();
|
|
1682
|
-
return
|
|
1767
|
+
return this.#newArrayFrom({
|
|
1768
|
+
source: output,
|
|
1769
|
+
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1770
|
+
pending
|
|
1771
|
+
});
|
|
1683
1772
|
}
|
|
1684
1773
|
/**
|
|
1685
1774
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1715,15 +1804,15 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1715
1804
|
}
|
|
1716
1805
|
#dataInline() {
|
|
1717
1806
|
this.#check();
|
|
1718
|
-
|
|
1719
|
-
const ar =
|
|
1807
|
+
if (!(this.#source instanceof require_backend.AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
|
|
1808
|
+
const ar = this.#newArrayFrom({ backend: require_backend.getBackend("cpu") });
|
|
1720
1809
|
this.dispose();
|
|
1721
1810
|
return ar.dataSync();
|
|
1722
1811
|
}
|
|
1723
1812
|
static #broadcastArrays(arrays) {
|
|
1724
1813
|
if (arrays.length === 0) throw new Error("Need at least one array to broadcast");
|
|
1725
1814
|
if (arrays.length === 1) return arrays;
|
|
1726
|
-
const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
|
|
1815
|
+
const newShape = arrays.map((a) => a.shape).reduce(require_backend.generalBroadcast);
|
|
1727
1816
|
return arrays.map((ar) => {
|
|
1728
1817
|
if (require_backend.deepEqual(ar.shape, newShape)) return ar;
|
|
1729
1818
|
return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
|
|
@@ -1842,14 +1931,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1842
1931
|
x.#backend.incRef(x.#source);
|
|
1843
1932
|
const pending = x.#pending;
|
|
1844
1933
|
for (const exe of pending) exe.updateRc(1);
|
|
1845
|
-
const y =
|
|
1934
|
+
const y = x.#newArrayFrom({
|
|
1935
|
+
dtype,
|
|
1936
|
+
weakType: false,
|
|
1937
|
+
pending
|
|
1938
|
+
});
|
|
1846
1939
|
x.dispose();
|
|
1847
1940
|
return [y];
|
|
1848
1941
|
}
|
|
1849
1942
|
},
|
|
1850
1943
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
1851
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1852
|
-
if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1944
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1945
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1853
1946
|
const c0 = zeros(shape$1, {
|
|
1854
1947
|
dtype: require_backend.DType.Uint32,
|
|
1855
1948
|
device: k0.device
|
|
@@ -1917,7 +2010,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1917
2010
|
},
|
|
1918
2011
|
[Primitive.Compare]([x, y], { op }) {
|
|
1919
2012
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1920
|
-
return [Array$1.#naryCustom("compare", custom, [x, y], {
|
|
2013
|
+
return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
|
|
1921
2014
|
},
|
|
1922
2015
|
[Primitive.Where]([cond, x, y]) {
|
|
1923
2016
|
const custom = ([cond$1, x$1, y$1]) => require_backend.AluExp.where(cond$1, x$1, y$1);
|
|
@@ -1963,7 +2056,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1963
2056
|
pending.splice(0, 0, ...prevPending);
|
|
1964
2057
|
args.forEach((x) => x.dispose());
|
|
1965
2058
|
return outputs.map((source, i) => {
|
|
1966
|
-
return new Array$1(
|
|
2059
|
+
return new Array$1({
|
|
2060
|
+
source,
|
|
2061
|
+
st: require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
|
|
2062
|
+
dtype: jaxpr.outs[i].aval.dtype,
|
|
2063
|
+
weakType: jaxpr.outs[i].aval.weakType,
|
|
2064
|
+
backend,
|
|
2065
|
+
pending
|
|
2066
|
+
});
|
|
1967
2067
|
});
|
|
1968
2068
|
}
|
|
1969
2069
|
};
|
|
@@ -1973,33 +2073,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1973
2073
|
return this.#source;
|
|
1974
2074
|
}
|
|
1975
2075
|
};
|
|
1976
|
-
/** Construct an array from a single scalar constant. */
|
|
1977
|
-
function scalar(value, { dtype, device } = {}) {
|
|
1978
|
-
if (typeof value === "number") {
|
|
1979
|
-
dtype ??= require_backend.DType.Float32;
|
|
1980
|
-
if (![
|
|
1981
|
-
require_backend.DType.Float32,
|
|
1982
|
-
require_backend.DType.Float16,
|
|
1983
|
-
require_backend.DType.Int32,
|
|
1984
|
-
require_backend.DType.Uint32
|
|
1985
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1986
|
-
} else if (typeof value === "boolean") {
|
|
1987
|
-
dtype ??= require_backend.DType.Bool;
|
|
1988
|
-
if (![
|
|
1989
|
-
require_backend.DType.Float32,
|
|
1990
|
-
require_backend.DType.Float16,
|
|
1991
|
-
require_backend.DType.Int32,
|
|
1992
|
-
require_backend.DType.Uint32,
|
|
1993
|
-
require_backend.DType.Bool
|
|
1994
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1995
|
-
} else throw new TypeError(`Invalid type for scalar ${value}`);
|
|
1996
|
-
return new Array$1(require_backend.AluExp.const(dtype, value), require_backend.ShapeTracker.fromShape([]), dtype, require_backend.getBackend(device));
|
|
1997
|
-
}
|
|
1998
2076
|
/** Constructor for creating a new array from data. */
|
|
1999
2077
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
2000
2078
|
if (values instanceof Tracer) {
|
|
2001
2079
|
if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
2002
|
-
if (dtype && values.dtype !== dtype)
|
|
2080
|
+
if (dtype && values.dtype !== dtype) values = values.astype(dtype);
|
|
2003
2081
|
return values;
|
|
2004
2082
|
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
2005
2083
|
dtype,
|
|
@@ -2021,6 +2099,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2021
2099
|
dtype,
|
|
2022
2100
|
device
|
|
2023
2101
|
});
|
|
2102
|
+
if (size$1 === 1) return full(shape$1, flat[0], {
|
|
2103
|
+
dtype,
|
|
2104
|
+
device
|
|
2105
|
+
});
|
|
2024
2106
|
if (typeof flat[0] === "boolean") {
|
|
2025
2107
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
2026
2108
|
const data = new Int32Array(flat.map((x) => x ? 1 : 0));
|
|
@@ -2029,46 +2111,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2029
2111
|
device
|
|
2030
2112
|
});
|
|
2031
2113
|
} else {
|
|
2114
|
+
const weakType = dtype == void 0;
|
|
2032
2115
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2033
2116
|
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
2034
2117
|
return arrayFromData(data, shape$1, {
|
|
2035
2118
|
dtype,
|
|
2036
2119
|
device
|
|
2037
|
-
});
|
|
2120
|
+
}, weakType);
|
|
2038
2121
|
}
|
|
2039
2122
|
}
|
|
2040
2123
|
}
|
|
2041
|
-
function arrayFromData(data, shape$1, { dtype, device } =
|
|
2124
|
+
function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
2125
|
+
if (data instanceof Float32Array) {
|
|
2126
|
+
if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2127
|
+
dtype ??= require_backend.DType.Float32;
|
|
2128
|
+
} else if (data instanceof Int32Array) {
|
|
2129
|
+
if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2130
|
+
dtype ??= require_backend.DType.Int32;
|
|
2131
|
+
} else if (data instanceof Uint32Array) {
|
|
2132
|
+
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2133
|
+
dtype ??= require_backend.DType.Uint32;
|
|
2134
|
+
} else if (data instanceof Float16Array) {
|
|
2135
|
+
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2136
|
+
dtype ??= require_backend.DType.Float16;
|
|
2137
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2042
2138
|
if (data.length < inlineArrayLimit) {
|
|
2043
2139
|
let allEqual = true;
|
|
2044
2140
|
for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
|
|
2045
2141
|
allEqual = false;
|
|
2046
2142
|
break;
|
|
2047
2143
|
}
|
|
2048
|
-
if (allEqual)
|
|
2049
|
-
dtype,
|
|
2050
|
-
device
|
|
2051
|
-
}
|
|
2144
|
+
if (allEqual) {
|
|
2145
|
+
const sa = new ShapedArray(shape$1, dtype, weakType);
|
|
2146
|
+
return fullInternal(sa, data[0], device);
|
|
2147
|
+
}
|
|
2052
2148
|
}
|
|
2053
2149
|
const backend = require_backend.getBackend(device);
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2064
|
-
dtype ??= require_backend.DType.Uint32;
|
|
2065
|
-
} else if (data instanceof Float16Array) {
|
|
2066
|
-
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2067
|
-
dtype ??= require_backend.DType.Float16;
|
|
2068
|
-
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2069
|
-
const slot = backend.malloc(data.byteLength, buf);
|
|
2070
|
-
return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2071
|
-
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
2150
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2151
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2152
|
+
return new Array$1({
|
|
2153
|
+
source: slot,
|
|
2154
|
+
st: require_backend.ShapeTracker.fromShape(shape$1),
|
|
2155
|
+
dtype,
|
|
2156
|
+
weakType,
|
|
2157
|
+
backend
|
|
2158
|
+
});
|
|
2072
2159
|
}
|
|
2073
2160
|
function dataToJs(dtype, data, shape$1) {
|
|
2074
2161
|
if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -2084,7 +2171,7 @@ function dataToJs(dtype, data, shape$1) {
|
|
|
2084
2171
|
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
2085
2172
|
function pureArray(x) {
|
|
2086
2173
|
if (x instanceof Tracer) return x;
|
|
2087
|
-
else return
|
|
2174
|
+
else return array(x);
|
|
2088
2175
|
}
|
|
2089
2176
|
var EvalTrace = class extends Trace {
|
|
2090
2177
|
pure = (x) => pureArray(x);
|
|
@@ -2095,20 +2182,27 @@ var EvalTrace = class extends Trace {
|
|
|
2095
2182
|
};
|
|
2096
2183
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2097
2184
|
const implRules = Array$1._implRules();
|
|
2185
|
+
function fullInternal(aval, fillValue, device) {
|
|
2186
|
+
return new Array$1({
|
|
2187
|
+
source: require_backend.AluExp.const(aval.dtype, fillValue),
|
|
2188
|
+
st: require_backend.ShapeTracker.fromShape(aval.shape),
|
|
2189
|
+
dtype: aval.dtype,
|
|
2190
|
+
weakType: aval.weakType,
|
|
2191
|
+
backend: require_backend.getBackend(device)
|
|
2192
|
+
});
|
|
2193
|
+
}
|
|
2098
2194
|
function zerosLike$1(val, dtype) {
|
|
2099
|
-
|
|
2100
|
-
if (val instanceof Tracer) val.dispose();
|
|
2101
|
-
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2195
|
+
return fullLike(val, 0, dtype);
|
|
2102
2196
|
}
|
|
2103
2197
|
function onesLike$1(val, dtype) {
|
|
2104
|
-
|
|
2105
|
-
if (val instanceof Tracer) val.dispose();
|
|
2106
|
-
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2198
|
+
return fullLike(val, 1, dtype);
|
|
2107
2199
|
}
|
|
2108
2200
|
function fullLike(val, fillValue, dtype) {
|
|
2109
2201
|
const aval = getAval(val);
|
|
2110
2202
|
if (val instanceof Tracer) val.dispose();
|
|
2111
|
-
|
|
2203
|
+
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
2204
|
+
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
2205
|
+
return fullInternal(sa, fillValue);
|
|
2112
2206
|
}
|
|
2113
2207
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
2114
2208
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -2126,19 +2220,14 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
2126
2220
|
}
|
|
2127
2221
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
2128
2222
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
2129
|
-
let
|
|
2130
|
-
if (typeof fillValue === "number")
|
|
2131
|
-
|
|
2132
|
-
source = require_backend.AluExp.const(dtype, fillValue);
|
|
2133
|
-
} else if (typeof fillValue === "bigint") {
|
|
2134
|
-
dtype = dtype ?? require_backend.DType.Int32;
|
|
2135
|
-
source = require_backend.AluExp.const(dtype, Number(fillValue));
|
|
2136
|
-
} else if (typeof fillValue === "boolean") {
|
|
2223
|
+
let weakType = dtype == void 0;
|
|
2224
|
+
if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
|
|
2225
|
+
else if (typeof fillValue === "boolean") {
|
|
2137
2226
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
2138
|
-
|
|
2227
|
+
weakType = false;
|
|
2139
2228
|
} else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
|
|
2140
2229
|
else throw new TypeError(`Invalid type for full: ${fillValue}`);
|
|
2141
|
-
return new
|
|
2230
|
+
return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
|
|
2142
2231
|
}
|
|
2143
2232
|
/**
|
|
2144
2233
|
* Create an identity matrix.
|
|
@@ -2148,6 +2237,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
2148
2237
|
*/
|
|
2149
2238
|
function eye(numRows, numCols, { dtype, device } = {}) {
|
|
2150
2239
|
numCols = numCols ?? numRows;
|
|
2240
|
+
const weakType = dtype == void 0;
|
|
2151
2241
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2152
2242
|
if (numCols < numRows) {
|
|
2153
2243
|
const arr = eye(numCols, numRows, {
|
|
@@ -2161,7 +2251,13 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2161
2251
|
device
|
|
2162
2252
|
});
|
|
2163
2253
|
const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
|
|
2164
|
-
return new Array$1(
|
|
2254
|
+
return new Array$1({
|
|
2255
|
+
source: require_backend.AluExp.cast(dtype, exp$2),
|
|
2256
|
+
st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
|
|
2257
|
+
dtype,
|
|
2258
|
+
weakType,
|
|
2259
|
+
backend: require_backend.getBackend(device)
|
|
2260
|
+
});
|
|
2165
2261
|
}
|
|
2166
2262
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2167
2263
|
function identity$1(n, { dtype, device } = {}) {
|
|
@@ -2198,7 +2294,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2198
2294
|
});
|
|
2199
2295
|
const exp$2 = require_backend.AluExp.add(require_backend.AluExp.const(dtype, start), require_backend.AluExp.mul(require_backend.AluExp.cast(dtype, require_backend.AluVar.idx), require_backend.AluExp.const(dtype, step)));
|
|
2200
2296
|
const st = require_backend.ShapeTracker.fromShape([size$1]);
|
|
2201
|
-
return new Array$1(
|
|
2297
|
+
return new Array$1({
|
|
2298
|
+
source: exp$2,
|
|
2299
|
+
st,
|
|
2300
|
+
dtype,
|
|
2301
|
+
weakType: false,
|
|
2302
|
+
backend: require_backend.getBackend(device)
|
|
2303
|
+
});
|
|
2202
2304
|
}
|
|
2203
2305
|
/**
|
|
2204
2306
|
* Return evenly spaced numbers over a specified interval.
|
|
@@ -2216,10 +2318,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2216
2318
|
dtype,
|
|
2217
2319
|
device
|
|
2218
2320
|
});
|
|
2219
|
-
else if (num === 1) return
|
|
2321
|
+
else if (num === 1) return full([1], start, {
|
|
2220
2322
|
dtype,
|
|
2221
2323
|
device
|
|
2222
|
-
})
|
|
2324
|
+
});
|
|
2223
2325
|
else if (start === stop) return full([num], start, {
|
|
2224
2326
|
dtype,
|
|
2225
2327
|
device
|
|
@@ -2228,7 +2330,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2228
2330
|
const denom = endpoint ? num - 1 : num;
|
|
2229
2331
|
const exp$2 = require_backend.AluExp.cast(dtype, require_backend.AluExp.add(require_backend.AluExp.f32(start), require_backend.AluExp.mul(require_backend.AluExp.f32(delta / denom), require_backend.AluExp.cast(require_backend.DType.Float32, require_backend.AluVar.idx))));
|
|
2230
2332
|
const st = require_backend.ShapeTracker.fromShape([num]);
|
|
2231
|
-
return new Array$1(
|
|
2333
|
+
return new Array$1({
|
|
2334
|
+
source: exp$2,
|
|
2335
|
+
st,
|
|
2336
|
+
dtype,
|
|
2337
|
+
weakType: false,
|
|
2338
|
+
backend: require_backend.getBackend(device)
|
|
2339
|
+
});
|
|
2232
2340
|
}
|
|
2233
2341
|
function aluCompare(a, b, op) {
|
|
2234
2342
|
switch (op) {
|
|
@@ -2240,35 +2348,6 @@ function aluCompare(a, b, op) {
|
|
|
2240
2348
|
case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
|
|
2241
2349
|
}
|
|
2242
2350
|
}
|
|
2243
|
-
/**
|
|
2244
|
-
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
2245
|
-
*
|
|
2246
|
-
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
2247
|
-
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
2248
|
-
* Two dimensions are compatible when:
|
|
2249
|
-
* 1. they are equal, or
|
|
2250
|
-
* 2. one of them is 1."
|
|
2251
|
-
*
|
|
2252
|
-
* Throws a TypeError if the broadcast is not possible.
|
|
2253
|
-
*
|
|
2254
|
-
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
2255
|
-
*/
|
|
2256
|
-
function generalBroadcast(a, b) {
|
|
2257
|
-
const out = [];
|
|
2258
|
-
let i = a.length - 1;
|
|
2259
|
-
let j = b.length - 1;
|
|
2260
|
-
for (; i >= 0 && j >= 0; i--, j--) {
|
|
2261
|
-
const x = a[i];
|
|
2262
|
-
const y = b[j];
|
|
2263
|
-
if (x === y) out.push(x);
|
|
2264
|
-
else if (x === 1) out.push(y);
|
|
2265
|
-
else if (y === 1) out.push(x);
|
|
2266
|
-
else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
|
|
2267
|
-
}
|
|
2268
|
-
for (; i >= 0; i--) out.push(a[i]);
|
|
2269
|
-
for (; j >= 0; j--) out.push(b[j]);
|
|
2270
|
-
return out.reverse();
|
|
2271
|
-
}
|
|
2272
2351
|
|
|
2273
2352
|
//#endregion
|
|
2274
2353
|
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
|
|
@@ -2348,13 +2427,15 @@ var Var = class Var {
|
|
|
2348
2427
|
};
|
|
2349
2428
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
2350
2429
|
var Lit = class {
|
|
2351
|
-
dtype;
|
|
2352
2430
|
value;
|
|
2353
2431
|
aval;
|
|
2354
|
-
|
|
2355
|
-
this.dtype
|
|
2432
|
+
get dtype() {
|
|
2433
|
+
return this.aval.dtype;
|
|
2434
|
+
}
|
|
2435
|
+
constructor(aval, value) {
|
|
2436
|
+
if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
|
|
2356
2437
|
this.value = value;
|
|
2357
|
-
this.aval =
|
|
2438
|
+
this.aval = ShapedArray.fromAval(aval);
|
|
2358
2439
|
}
|
|
2359
2440
|
};
|
|
2360
2441
|
function atomIsLit(atom, literal) {
|
|
@@ -2478,14 +2559,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2478
2559
|
const c = eqn.outBinders[0];
|
|
2479
2560
|
if (atomIsLit(a, 0)) context.set(c, b);
|
|
2480
2561
|
else if (atomIsLit(b, 0)) context.set(c, a);
|
|
2481
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2562
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
|
|
2563
|
+
else newEqns.push(eqn);
|
|
2564
|
+
} else if (eqn.primitive === Primitive.Neg) {
|
|
2565
|
+
const [a] = inputs;
|
|
2566
|
+
const c = eqn.outBinders[0];
|
|
2567
|
+
if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
|
|
2482
2568
|
else newEqns.push(eqn);
|
|
2483
2569
|
} else if (eqn.primitive === Primitive.Mul) {
|
|
2484
2570
|
const [a, b] = inputs;
|
|
2485
2571
|
const c = eqn.outBinders[0];
|
|
2486
2572
|
if (atomIsLit(a, 1)) context.set(c, b);
|
|
2487
2573
|
else if (atomIsLit(b, 1)) context.set(c, a);
|
|
2488
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2574
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
|
|
2489
2575
|
else newEqns.push(eqn);
|
|
2490
2576
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
2491
2577
|
const [a, b] = inputs;
|
|
@@ -2583,7 +2669,7 @@ function evalJaxpr(jaxpr, args) {
|
|
|
2583
2669
|
if (x instanceof Var) {
|
|
2584
2670
|
remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
|
|
2585
2671
|
return env.get(x);
|
|
2586
|
-
} else return
|
|
2672
|
+
} else return array(x.value, { dtype: x.dtype });
|
|
2587
2673
|
};
|
|
2588
2674
|
const write = (v, val) => {
|
|
2589
2675
|
if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
|
|
@@ -2642,7 +2728,7 @@ var JaxprTrace = class extends Trace {
|
|
|
2642
2728
|
let tracer = this.builder.constTracers.get(val);
|
|
2643
2729
|
if (tracer === void 0) {
|
|
2644
2730
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
2645
|
-
this.builder.addConst(tracer, val instanceof Tracer ? val.ref :
|
|
2731
|
+
this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
|
|
2646
2732
|
}
|
|
2647
2733
|
return tracer;
|
|
2648
2734
|
}
|
|
@@ -2711,7 +2797,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2711
2797
|
const newConsts = [];
|
|
2712
2798
|
for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
|
|
2713
2799
|
const ar = consts[i];
|
|
2714
|
-
literals.set(jaxpr.inBinders[i], new Lit(ar.
|
|
2800
|
+
literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
|
|
2715
2801
|
} else {
|
|
2716
2802
|
constBinders.push(jaxpr.inBinders[i]);
|
|
2717
2803
|
newConsts.push(consts[i]);
|
|
@@ -2724,13 +2810,12 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2724
2810
|
}
|
|
2725
2811
|
function binopAbstractEval([x, y]) {
|
|
2726
2812
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
2727
|
-
|
|
2728
|
-
return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
|
|
2813
|
+
return [promoteAvals(x, y)];
|
|
2729
2814
|
}
|
|
2730
2815
|
function compareAbstractEval([x, y]) {
|
|
2731
2816
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
|
|
2732
|
-
|
|
2733
|
-
return [new ShapedArray(
|
|
2817
|
+
const aval = promoteAvals(x, y);
|
|
2818
|
+
return [new ShapedArray(aval.shape, require_backend.DType.Bool, false)];
|
|
2734
2819
|
}
|
|
2735
2820
|
function vectorizedUnopAbstractEval([x]) {
|
|
2736
2821
|
return [ShapedArray.fromAval(x)];
|
|
@@ -2743,18 +2828,18 @@ const abstractEvalRules = {
|
|
|
2743
2828
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
2744
2829
|
[Primitive.StopGradient]: vectorizedUnopAbstractEval,
|
|
2745
2830
|
[Primitive.Cast]([x], { dtype }) {
|
|
2746
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2831
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2747
2832
|
},
|
|
2748
2833
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
2749
2834
|
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
2750
2835
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
2751
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2836
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2752
2837
|
},
|
|
2753
2838
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
2754
2839
|
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
2755
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2756
|
-
if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2757
|
-
return [new ShapedArray(shape$1, require_backend.DType.Uint32)];
|
|
2840
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2841
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2842
|
+
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
2758
2843
|
},
|
|
2759
2844
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2760
2845
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
@@ -2768,55 +2853,54 @@ const abstractEvalRules = {
|
|
|
2768
2853
|
[Primitive.Reduce]([x], { axis }) {
|
|
2769
2854
|
const axisSet = new Set(axis);
|
|
2770
2855
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2771
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2856
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2772
2857
|
},
|
|
2773
2858
|
[Primitive.Pool]([x], { window, strides }) {
|
|
2774
2859
|
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2775
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2860
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2776
2861
|
},
|
|
2777
2862
|
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2778
2863
|
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2779
2864
|
if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2780
|
-
return [new ShapedArray(inShape, x.dtype)];
|
|
2865
|
+
return [new ShapedArray(inShape, x.dtype, x.weakType)];
|
|
2781
2866
|
},
|
|
2782
2867
|
[Primitive.Dot]([x, y]) {
|
|
2783
|
-
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2784
2868
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
2785
|
-
const shape$1 =
|
|
2869
|
+
const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
|
|
2786
2870
|
shape$1.splice(-1, 1);
|
|
2787
|
-
return [new ShapedArray(shape$1,
|
|
2871
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2788
2872
|
},
|
|
2789
2873
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
2790
|
-
|
|
2874
|
+
const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
|
|
2791
2875
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2792
|
-
return [new ShapedArray(shape$1,
|
|
2876
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2793
2877
|
},
|
|
2794
2878
|
[Primitive.Compare]: compareAbstractEval,
|
|
2795
2879
|
[Primitive.Where]([cond, x, y]) {
|
|
2796
2880
|
if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
2797
|
-
|
|
2798
|
-
const shape$1 = generalBroadcast(cond.shape,
|
|
2799
|
-
return [new ShapedArray(shape$1,
|
|
2881
|
+
const xy = promoteAvals(x, y);
|
|
2882
|
+
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
2883
|
+
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
2800
2884
|
},
|
|
2801
2885
|
[Primitive.Transpose]([x], { perm }) {
|
|
2802
|
-
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
|
|
2886
|
+
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
2803
2887
|
},
|
|
2804
2888
|
[Primitive.Broadcast]([x], { shape: shape$1 }) {
|
|
2805
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2889
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2806
2890
|
},
|
|
2807
2891
|
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2808
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2892
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2809
2893
|
},
|
|
2810
2894
|
[Primitive.Flip]([x], _) {
|
|
2811
|
-
return [
|
|
2895
|
+
return [ShapedArray.fromAval(x)];
|
|
2812
2896
|
},
|
|
2813
2897
|
[Primitive.Shrink]([x], { slice }) {
|
|
2814
2898
|
const newShape = slice.map((s) => s[1] - s[0]);
|
|
2815
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2899
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2816
2900
|
},
|
|
2817
2901
|
[Primitive.Pad]([x], { width }) {
|
|
2818
2902
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
2819
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2903
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2820
2904
|
},
|
|
2821
2905
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2822
2906
|
for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
@@ -2826,10 +2910,10 @@ const abstractEvalRules = {
|
|
|
2826
2910
|
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
2827
2911
|
const axisSet = new Set(axis);
|
|
2828
2912
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
2829
|
-
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
2913
|
+
const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
|
|
2830
2914
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2831
2915
|
newShape.splice(outDim, 0, ...gatherShape);
|
|
2832
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2916
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2833
2917
|
},
|
|
2834
2918
|
[Primitive.JitCall](args, { jaxpr }) {
|
|
2835
2919
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
@@ -2896,6 +2980,7 @@ function jit$1(f, opts) {
|
|
|
2896
2980
|
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2897
2981
|
const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2898
2982
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2983
|
+
name: f.name || "closure",
|
|
2899
2984
|
jaxpr,
|
|
2900
2985
|
numConsts: consts.length
|
|
2901
2986
|
});
|
|
@@ -3058,13 +3143,14 @@ const jvpRules = {
|
|
|
3058
3143
|
const indicesRef = indices.map((t) => t.ref);
|
|
3059
3144
|
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3060
3145
|
},
|
|
3061
|
-
[Primitive.JitCall](primals, tangents, { jaxpr }) {
|
|
3146
|
+
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3062
3147
|
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3063
3148
|
const outs = bind(Primitive.JitCall, [
|
|
3064
3149
|
...newConsts.map((c) => c.ref),
|
|
3065
3150
|
...primals,
|
|
3066
3151
|
...tangents
|
|
3067
3152
|
], {
|
|
3153
|
+
name: `${name}_jvp`,
|
|
3068
3154
|
jaxpr: newJaxpr,
|
|
3069
3155
|
numConsts: newConsts.length
|
|
3070
3156
|
});
|
|
@@ -3119,7 +3205,7 @@ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
|
3119
3205
|
function mappedAval(batchDim, aval) {
|
|
3120
3206
|
const shape$1 = [...aval.shape];
|
|
3121
3207
|
shape$1.splice(batchDim, 1);
|
|
3122
|
-
return new ShapedArray(shape$1, aval.dtype);
|
|
3208
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3123
3209
|
}
|
|
3124
3210
|
/** Move one axis to a different index. */
|
|
3125
3211
|
function moveaxis$1(x, src, dst) {
|
|
@@ -3263,9 +3349,10 @@ const vmapRules = {
|
|
|
3263
3349
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3264
3350
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3265
3351
|
},
|
|
3266
|
-
[Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
|
|
3352
|
+
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3267
3353
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3268
3354
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
3355
|
+
name: `${name}_vmap`,
|
|
3269
3356
|
jaxpr: newJaxpr,
|
|
3270
3357
|
numConsts: newConsts.length
|
|
3271
3358
|
});
|
|
@@ -3281,7 +3368,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3281
3368
|
if (dims[i] === null) return v.aval;
|
|
3282
3369
|
const shape$1 = [...v.aval.shape];
|
|
3283
3370
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3284
|
-
return new ShapedArray(shape$1, v.aval.dtype);
|
|
3371
|
+
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3285
3372
|
});
|
|
3286
3373
|
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3287
3374
|
const result = {
|
|
@@ -3494,8 +3581,8 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3494
3581
|
processPrimitive(primitive, tracers, params) {
|
|
3495
3582
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3496
3583
|
if (primitive === Primitive.JitCall) {
|
|
3497
|
-
const { jaxpr, numConsts } = params;
|
|
3498
|
-
return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
|
|
3584
|
+
const { name, jaxpr, numConsts } = params;
|
|
3585
|
+
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3499
3586
|
}
|
|
3500
3587
|
const tracersIn = tracers.map((t) => this.instantiateConst(t));
|
|
3501
3588
|
const avalsIn = tracersIn.map((t) => t.pval.aval);
|
|
@@ -3521,12 +3608,13 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3521
3608
|
*
|
|
3522
3609
|
* Used when encountering a JitCall rule during the trace.
|
|
3523
3610
|
*/
|
|
3524
|
-
#partialEvalJaxpr(jaxpr, numConsts, tracers) {
|
|
3611
|
+
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3525
3612
|
jaxpr = jaxpr.flatten();
|
|
3526
3613
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3527
3614
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3528
3615
|
const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
|
|
3529
3616
|
const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3617
|
+
name: `${name}_peval`,
|
|
3530
3618
|
jaxpr: jaxpr1,
|
|
3531
3619
|
numConsts: 0
|
|
3532
3620
|
});
|
|
@@ -3538,6 +3626,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3538
3626
|
prim: Primitive.JitCall,
|
|
3539
3627
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3540
3628
|
params: {
|
|
3629
|
+
name: `${name}_resid`,
|
|
3541
3630
|
jaxpr: jaxpr2,
|
|
3542
3631
|
numConsts: 0
|
|
3543
3632
|
},
|
|
@@ -3680,7 +3769,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
|
|
|
3680
3769
|
}
|
|
3681
3770
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
3682
3771
|
const eqn = jaxpr.eqns[i];
|
|
3683
|
-
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ?
|
|
3772
|
+
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
|
|
3684
3773
|
const cotangentsOut = eqn.outBinders.map(readCotangent);
|
|
3685
3774
|
const rule = transposeRules[eqn.primitive];
|
|
3686
3775
|
if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
|
|
@@ -3765,7 +3854,7 @@ const transposeRules = {
|
|
|
3765
3854
|
},
|
|
3766
3855
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3767
3856
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3768
|
-
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3857
|
+
const axisSize = require_backend.generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3769
3858
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3770
3859
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3771
3860
|
},
|
|
@@ -3860,7 +3949,7 @@ const transposeRules = {
|
|
|
3860
3949
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
3861
3950
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
3862
3951
|
},
|
|
3863
|
-
[Primitive.JitCall](cts, args, { jaxpr }) {
|
|
3952
|
+
[Primitive.JitCall](cts, args, { name, jaxpr }) {
|
|
3864
3953
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
3865
3954
|
const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
|
|
3866
3955
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
@@ -3869,6 +3958,7 @@ const transposeRules = {
|
|
|
3869
3958
|
...residuals,
|
|
3870
3959
|
...cts
|
|
3871
3960
|
], {
|
|
3961
|
+
name: `${name}_t`,
|
|
3872
3962
|
jaxpr: newJaxpr,
|
|
3873
3963
|
numConsts: newConsts.length
|
|
3874
3964
|
});
|
|
@@ -3943,7 +4033,7 @@ function valueAndGrad$1(f) {
|
|
|
3943
4033
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3944
4034
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3945
4035
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3946
|
-
const [ct, ...rest] = fVjp(
|
|
4036
|
+
const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
|
|
3947
4037
|
for (const r of rest) dispose(r);
|
|
3948
4038
|
fVjp.dispose();
|
|
3949
4039
|
return [y, ct];
|
|
@@ -4313,7 +4403,7 @@ function argmin(a, axis, opts) {
|
|
|
4313
4403
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4314
4404
|
const shape$1 = a.shape;
|
|
4315
4405
|
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4316
|
-
const length =
|
|
4406
|
+
const length = array(shape$1[axis], {
|
|
4317
4407
|
dtype: int32,
|
|
4318
4408
|
device: a.device
|
|
4319
4409
|
});
|
|
@@ -4337,7 +4427,7 @@ function argmax(a, axis, opts) {
|
|
|
4337
4427
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4338
4428
|
const shape$1 = a.shape;
|
|
4339
4429
|
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4340
|
-
const length =
|
|
4430
|
+
const length = array(shape$1[axis], {
|
|
4341
4431
|
dtype: int32,
|
|
4342
4432
|
device: a.device
|
|
4343
4433
|
});
|
|
@@ -4521,7 +4611,7 @@ function broadcastTo(a, shape$1) {
|
|
|
4521
4611
|
/** Broadcast input shapes to a common output shape. */
|
|
4522
4612
|
function broadcastShapes(...shapes) {
|
|
4523
4613
|
if (shapes.length === 0) return [];
|
|
4524
|
-
return shapes.reduce(generalBroadcast);
|
|
4614
|
+
return shapes.reduce(require_backend.generalBroadcast);
|
|
4525
4615
|
}
|
|
4526
4616
|
/** Broadcast arrays to a common shape. */
|
|
4527
4617
|
function broadcastArrays(...arrays) {
|
|
@@ -4753,7 +4843,7 @@ function acos(x) {
|
|
|
4753
4843
|
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4754
4844
|
* improvements.
|
|
4755
4845
|
*/
|
|
4756
|
-
const hypot = jit$1((x1, x2)
|
|
4846
|
+
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4757
4847
|
return sqrt(square(x1).add(square(x2)));
|
|
4758
4848
|
});
|
|
4759
4849
|
/**
|
|
@@ -4769,7 +4859,7 @@ const hypot = jit$1((x1, x2) => {
|
|
|
4769
4859
|
*
|
|
4770
4860
|
* The output is ill-defined when both x and y are zero.
|
|
4771
4861
|
*/
|
|
4772
|
-
const atan2 = jit$1((y, x)
|
|
4862
|
+
const atan2 = jit$1(function atan2$1(y, x) {
|
|
4773
4863
|
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4774
4864
|
const xNeg = less(x.ref, 0);
|
|
4775
4865
|
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
@@ -4837,13 +4927,13 @@ const degrees = rad2deg;
|
|
|
4837
4927
|
* @function
|
|
4838
4928
|
* Computes first array raised to power of second array, element-wise.
|
|
4839
4929
|
*/
|
|
4840
|
-
const power = jit$1((x1, x2)
|
|
4930
|
+
const power = jit$1(function power$1(x1, x2) {
|
|
4841
4931
|
return exp(log(x1).mul(x2));
|
|
4842
4932
|
});
|
|
4843
4933
|
/** @function Alias of `jax.numpy.power()`. */
|
|
4844
4934
|
const pow = power;
|
|
4845
4935
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
4846
|
-
const cbrt = jit$1((x)
|
|
4936
|
+
const cbrt = jit$1(function cbrt$1(x) {
|
|
4847
4937
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4848
4938
|
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4849
4939
|
});
|
|
@@ -4853,7 +4943,7 @@ const cbrt = jit$1((x) => {
|
|
|
4853
4943
|
*
|
|
4854
4944
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4855
4945
|
*/
|
|
4856
|
-
const sinh = jit$1((x)
|
|
4946
|
+
const sinh = jit$1(function sinh$1(x) {
|
|
4857
4947
|
const ex = exp(x);
|
|
4858
4948
|
const emx = reciprocal(ex.ref);
|
|
4859
4949
|
return ex.sub(emx).mul(.5);
|
|
@@ -4864,7 +4954,7 @@ const sinh = jit$1((x) => {
|
|
|
4864
4954
|
*
|
|
4865
4955
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4866
4956
|
*/
|
|
4867
|
-
const cosh = jit$1((x)
|
|
4957
|
+
const cosh = jit$1(function cosh$1(x) {
|
|
4868
4958
|
const ex = exp(x);
|
|
4869
4959
|
const emx = reciprocal(ex.ref);
|
|
4870
4960
|
return ex.add(emx).mul(.5);
|
|
@@ -4875,7 +4965,7 @@ const cosh = jit$1((x) => {
|
|
|
4875
4965
|
*
|
|
4876
4966
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4877
4967
|
*/
|
|
4878
|
-
const tanh = jit$1((x)
|
|
4968
|
+
const tanh = jit$1(function tanh$1(x) {
|
|
4879
4969
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4880
4970
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4881
4971
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
@@ -4886,7 +4976,7 @@ const tanh = jit$1((x) => {
|
|
|
4886
4976
|
*
|
|
4887
4977
|
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4888
4978
|
*/
|
|
4889
|
-
const arcsinh = jit$1((x)
|
|
4979
|
+
const arcsinh = jit$1(function arcsinh$1(x) {
|
|
4890
4980
|
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4891
4981
|
});
|
|
4892
4982
|
/**
|
|
@@ -4895,7 +4985,7 @@ const arcsinh = jit$1((x) => {
|
|
|
4895
4985
|
*
|
|
4896
4986
|
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4897
4987
|
*/
|
|
4898
|
-
const arccosh = jit$1((x)
|
|
4988
|
+
const arccosh = jit$1(function arccosh$1(x) {
|
|
4899
4989
|
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4900
4990
|
});
|
|
4901
4991
|
/**
|
|
@@ -4904,7 +4994,7 @@ const arccosh = jit$1((x) => {
|
|
|
4904
4994
|
*
|
|
4905
4995
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4906
4996
|
*/
|
|
4907
|
-
const arctanh = jit$1((x)
|
|
4997
|
+
const arctanh = jit$1(function arctanh$1(x) {
|
|
4908
4998
|
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4909
4999
|
});
|
|
4910
5000
|
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
@@ -5020,7 +5110,9 @@ function softSign(x) {
|
|
|
5020
5110
|
*
|
|
5021
5111
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
5022
5112
|
*/
|
|
5023
|
-
const silu = jit$1((x)
|
|
5113
|
+
const silu = jit$1(function silu$1(x) {
|
|
5114
|
+
return x.ref.mul(sigmoid(x));
|
|
5115
|
+
});
|
|
5024
5116
|
/**
|
|
5025
5117
|
* @function
|
|
5026
5118
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -5081,7 +5173,7 @@ function celu(x, alpha = 1) {
|
|
|
5081
5173
|
*
|
|
5082
5174
|
* This will be improved in the future.
|
|
5083
5175
|
*/
|
|
5084
|
-
const gelu = jit$1((x)
|
|
5176
|
+
const gelu = jit$1(function gelu$1(x) {
|
|
5085
5177
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
5086
5178
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
5087
5179
|
});
|
|
@@ -5252,8 +5344,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
5252
5344
|
const keyShape = validateKeyShape(key$1);
|
|
5253
5345
|
return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
|
|
5254
5346
|
}
|
|
5255
|
-
/**
|
|
5256
|
-
function
|
|
5347
|
+
/**
|
|
5348
|
+
* @function
|
|
5349
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
5350
|
+
*/
|
|
5351
|
+
const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
5257
5352
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
5258
5353
|
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
5259
5354
|
dtype: require_backend.DType.Uint32,
|
|
@@ -5266,7 +5361,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
5266
5361
|
const rand = bitcast(float12, require_backend.DType.Float32).sub(1);
|
|
5267
5362
|
if (minval === 0 && maxval === 1) return rand;
|
|
5268
5363
|
else return rand.mul(maxval - minval).add(minval);
|
|
5269
|
-
}
|
|
5364
|
+
}, { staticArgnums: [1, 2] });
|
|
5270
5365
|
/**
|
|
5271
5366
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5272
5367
|
*
|
|
@@ -5277,26 +5372,30 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
5277
5372
|
p = fudgeArray(p);
|
|
5278
5373
|
return uniform(key$1, shape$1).less(p);
|
|
5279
5374
|
}
|
|
5280
|
-
/**
|
|
5281
|
-
function
|
|
5375
|
+
/**
|
|
5376
|
+
* @function
|
|
5377
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
5378
|
+
*/
|
|
5379
|
+
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
5282
5380
|
const u = uniform(key$1, shape$1);
|
|
5283
5381
|
return negative(log1p(negative(u)));
|
|
5284
|
-
}
|
|
5382
|
+
}, { staticArgnums: [1] });
|
|
5285
5383
|
/**
|
|
5384
|
+
* @function
|
|
5286
5385
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5287
5386
|
*
|
|
5288
5387
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5289
5388
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5290
5389
|
* bitwise identical to JAX.
|
|
5291
5390
|
*/
|
|
5292
|
-
function normal(key$1, shape$1 = []) {
|
|
5391
|
+
const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
5293
5392
|
const [k1, k2] = split(key$1, 2);
|
|
5294
5393
|
const u1 = uniform(k1, shape$1);
|
|
5295
5394
|
const u2 = uniform(k2, shape$1);
|
|
5296
5395
|
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5297
5396
|
const theta = u2.mul(2 * Math.PI);
|
|
5298
5397
|
return radius.mul(cos(theta));
|
|
5299
|
-
}
|
|
5398
|
+
}, { staticArgnums: [1] });
|
|
5300
5399
|
|
|
5301
5400
|
//#endregion
|
|
5302
5401
|
//#region src/polyfills.ts
|