@jax-js/jax 0.0.4 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +67 -24
- package/dist/{backend-EBRGmEYw.js → backend-CdcTZEOF.js} +35 -6
- package/dist/{backend-Ss1Mev_-.cjs → backend-yEU0L_ig.cjs} +40 -5
- package/dist/index.cjs +324 -225
- package/dist/index.d.cts +71 -26
- package/dist/index.d.ts +71 -26
- package/dist/index.js +314 -215
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-CM-xNYzW.js} +1 -1
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-CNOpiO5T.cjs} +1 -1
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -565,6 +565,21 @@ var Trace = class {
|
|
|
565
565
|
this.main = main;
|
|
566
566
|
}
|
|
567
567
|
};
|
|
568
|
+
/**
|
|
569
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
570
|
+
*
|
|
571
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
572
|
+
* implemented in that function as `weakType` is not passed.
|
|
573
|
+
*/
|
|
574
|
+
function promoteAvals(a, b) {
|
|
575
|
+
const shape$1 = generalBroadcast(a.shape, b.shape);
|
|
576
|
+
const weakType = a.weakType && b.weakType;
|
|
577
|
+
let dtype;
|
|
578
|
+
if (a.weakType === b.weakType) dtype = promoteTypes(a.dtype, b.dtype);
|
|
579
|
+
else if (a.weakType) dtype = promoteTypes(b.dtype, DType.Uint32);
|
|
580
|
+
else dtype = promoteTypes(a.dtype, DType.Uint32);
|
|
581
|
+
return new ShapedArray(shape$1, dtype, weakType);
|
|
582
|
+
}
|
|
568
583
|
var Tracer = class Tracer {
|
|
569
584
|
/** @ignore */
|
|
570
585
|
_trace;
|
|
@@ -579,10 +594,19 @@ var Tracer = class Tracer {
|
|
|
579
594
|
get size() {
|
|
580
595
|
return prod(this.shape);
|
|
581
596
|
}
|
|
582
|
-
/** The dtype of the array. */
|
|
597
|
+
/** The dtype of elements stored in the array. */
|
|
583
598
|
get dtype() {
|
|
584
599
|
return this.aval.dtype;
|
|
585
600
|
}
|
|
601
|
+
/**
|
|
602
|
+
* Whether the array is weakly typed.
|
|
603
|
+
*
|
|
604
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
605
|
+
* `promoteTypes()` for details.
|
|
606
|
+
*/
|
|
607
|
+
get weakType() {
|
|
608
|
+
return this.aval.weakType;
|
|
609
|
+
}
|
|
586
610
|
/** The number of dimensions of the array. */
|
|
587
611
|
get ndim() {
|
|
588
612
|
return this.shape.length;
|
|
@@ -819,12 +843,13 @@ function getShape(x) {
|
|
|
819
843
|
return x instanceof Tracer ? x.shape : [];
|
|
820
844
|
}
|
|
821
845
|
var ShapedArray = class ShapedArray {
|
|
822
|
-
constructor(shape$1, dtype) {
|
|
846
|
+
constructor(shape$1, dtype, weakType) {
|
|
823
847
|
this.shape = shape$1;
|
|
824
848
|
this.dtype = dtype;
|
|
849
|
+
this.weakType = weakType;
|
|
825
850
|
}
|
|
826
851
|
static fromAval(aval) {
|
|
827
|
-
return new ShapedArray(aval.shape, aval.dtype);
|
|
852
|
+
return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
|
|
828
853
|
}
|
|
829
854
|
get ndim() {
|
|
830
855
|
return this.shape.length;
|
|
@@ -838,7 +863,7 @@ var ShapedArray = class ShapedArray {
|
|
|
838
863
|
};
|
|
839
864
|
function getAval(x) {
|
|
840
865
|
if (x instanceof Tracer) return x.aval;
|
|
841
|
-
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32);
|
|
866
|
+
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32, typeof x === "boolean" ? false : true);
|
|
842
867
|
else throw new TypeError(`Unknown value: ${x}`);
|
|
843
868
|
}
|
|
844
869
|
function bind(prim, args, params = {}) {
|
|
@@ -1160,7 +1185,7 @@ const jitRules = {
|
|
|
1160
1185
|
const k1 = reshapeViews(keys[1], mapping);
|
|
1161
1186
|
const c0 = AluExp.u32(0);
|
|
1162
1187
|
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1163
|
-
const exp$2 = AluExp.threefry2x32(
|
|
1188
|
+
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1164
1189
|
return new Kernel(nargs, prod(shape$1), exp$2);
|
|
1165
1190
|
},
|
|
1166
1191
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
@@ -1201,7 +1226,7 @@ const jitRules = {
|
|
|
1201
1226
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
1202
1227
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
1203
1228
|
const c = k1.exp;
|
|
1204
|
-
const cs =
|
|
1229
|
+
const cs = promoteAvals(as, bs);
|
|
1205
1230
|
return jitRules[Primitive.Reduce](nargs, [c], [cs], {
|
|
1206
1231
|
op: AluOp.Add,
|
|
1207
1232
|
axis: [cs.ndim - 1]
|
|
@@ -1211,8 +1236,8 @@ const jitRules = {
|
|
|
1211
1236
|
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1212
1237
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1213
1238
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1214
|
-
as = new ShapedArray(stX.shape, as.dtype);
|
|
1215
|
-
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1239
|
+
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1240
|
+
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1216
1241
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1217
1242
|
},
|
|
1218
1243
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
@@ -1265,9 +1290,10 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1265
1290
|
Primitive.Conv,
|
|
1266
1291
|
Primitive.PoolTranspose
|
|
1267
1292
|
];
|
|
1293
|
+
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
1268
1294
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1269
1295
|
const eqn = jaxpr.eqns[i];
|
|
1270
|
-
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive
|
|
1296
|
+
if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1271
1297
|
for (const v of eqn.outBinders) {
|
|
1272
1298
|
blackNodes.add(v);
|
|
1273
1299
|
p1NextBlack.set(v, v);
|
|
@@ -1397,6 +1423,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1397
1423
|
static #nextId = 1001;
|
|
1398
1424
|
id;
|
|
1399
1425
|
#dtype;
|
|
1426
|
+
#weakType;
|
|
1400
1427
|
#source;
|
|
1401
1428
|
#st;
|
|
1402
1429
|
#backend;
|
|
@@ -1408,21 +1435,22 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1408
1435
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1409
1436
|
* will be freed when the array is disposed.
|
|
1410
1437
|
*/
|
|
1411
|
-
constructor(
|
|
1438
|
+
constructor(args) {
|
|
1412
1439
|
super(baseArrayTrace);
|
|
1413
1440
|
this.id = Array$1.#nextId++;
|
|
1414
|
-
this.#dtype = dtype;
|
|
1415
|
-
this.#
|
|
1416
|
-
this.#
|
|
1417
|
-
this.#
|
|
1441
|
+
this.#dtype = args.dtype;
|
|
1442
|
+
this.#weakType = args.weakType;
|
|
1443
|
+
this.#source = args.source;
|
|
1444
|
+
this.#st = args.st;
|
|
1445
|
+
this.#backend = args.backend;
|
|
1418
1446
|
this.#rc = 1;
|
|
1419
|
-
this.#pendingSet = new Set(pending);
|
|
1447
|
+
this.#pendingSet = new Set(args.pending);
|
|
1420
1448
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1421
|
-
else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1449
|
+
else if (this.#source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1422
1450
|
}
|
|
1423
1451
|
/** @ignore */
|
|
1424
1452
|
get aval() {
|
|
1425
|
-
return new ShapedArray(this.#st.shape, this.#dtype);
|
|
1453
|
+
return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
|
|
1426
1454
|
}
|
|
1427
1455
|
/** Return a simple string representation of the array's dimensions. */
|
|
1428
1456
|
toString() {
|
|
@@ -1434,6 +1462,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1434
1462
|
#check() {
|
|
1435
1463
|
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1436
1464
|
}
|
|
1465
|
+
/** Construct an array, copying fields from `this`. */
|
|
1466
|
+
#newArrayFrom(args) {
|
|
1467
|
+
return new Array$1({
|
|
1468
|
+
source: args.source ?? this.#source,
|
|
1469
|
+
st: args.st ?? this.#st,
|
|
1470
|
+
dtype: args.dtype ?? this.#dtype,
|
|
1471
|
+
weakType: this.#weakType,
|
|
1472
|
+
backend: args.backend ?? this.#backend,
|
|
1473
|
+
pending: args.pending ?? this.#pending ?? void 0
|
|
1474
|
+
});
|
|
1475
|
+
}
|
|
1437
1476
|
get ref() {
|
|
1438
1477
|
this.#check();
|
|
1439
1478
|
this.#rc++;
|
|
@@ -1473,7 +1512,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1473
1512
|
const pending = this.#pending;
|
|
1474
1513
|
for (const exe of pending) exe.updateRc(1);
|
|
1475
1514
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1476
|
-
const ar =
|
|
1515
|
+
const ar = this.#newArrayFrom({
|
|
1516
|
+
st,
|
|
1517
|
+
pending
|
|
1518
|
+
});
|
|
1477
1519
|
this.dispose();
|
|
1478
1520
|
return ar;
|
|
1479
1521
|
}
|
|
@@ -1522,7 +1564,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1522
1564
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1523
1565
|
this.dispose();
|
|
1524
1566
|
for (const ar of indices) ar.dispose();
|
|
1525
|
-
return
|
|
1567
|
+
return this.#newArrayFrom({
|
|
1568
|
+
source: output,
|
|
1569
|
+
st: ShapeTracker.fromShape(finalShape),
|
|
1570
|
+
pending
|
|
1571
|
+
});
|
|
1526
1572
|
}
|
|
1527
1573
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1528
1574
|
#moveAxesDown(axis) {
|
|
@@ -1545,11 +1591,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1545
1591
|
return this.#reshape(this.#st.permute(perm));
|
|
1546
1592
|
}
|
|
1547
1593
|
#unary(op, dtypeOutput) {
|
|
1594
|
+
const weakType = !dtypeOutput && this.#weakType;
|
|
1548
1595
|
dtypeOutput ??= this.#dtype;
|
|
1549
1596
|
this.#check();
|
|
1550
1597
|
if (this.#source instanceof AluExp) {
|
|
1551
1598
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1552
|
-
return
|
|
1599
|
+
return this.#newArrayFrom({
|
|
1600
|
+
source: exp$3.simplify(),
|
|
1601
|
+
dtype: dtypeOutput,
|
|
1602
|
+
weakType
|
|
1603
|
+
});
|
|
1553
1604
|
}
|
|
1554
1605
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1555
1606
|
const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1559,41 +1610,65 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1559
1610
|
for (const exe of pending) exe.updateRc(1);
|
|
1560
1611
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1561
1612
|
this.dispose();
|
|
1562
|
-
return
|
|
1613
|
+
return this.#newArrayFrom({
|
|
1614
|
+
source: output,
|
|
1615
|
+
st: ShapeTracker.fromShape(this.shape),
|
|
1616
|
+
dtype: dtypeOutput,
|
|
1617
|
+
weakType,
|
|
1618
|
+
pending
|
|
1619
|
+
});
|
|
1563
1620
|
}
|
|
1564
1621
|
#binary(op, other) {
|
|
1565
|
-
const custom = (src) => new AluExp(op,
|
|
1622
|
+
const custom = (src) => new AluExp(op, src[0].dtype, src);
|
|
1566
1623
|
return Array$1.#naryCustom(op, custom, [this, other]);
|
|
1567
1624
|
}
|
|
1568
|
-
static #naryCustom(name, custom, arrays, { dtypeOverride,
|
|
1625
|
+
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1569
1626
|
const n = arrays.length;
|
|
1570
1627
|
const backend = arrays[0].#backend;
|
|
1571
1628
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1572
1629
|
for (const ar of arrays) ar.#check();
|
|
1573
|
-
let
|
|
1630
|
+
let castDtype;
|
|
1631
|
+
let castWeakType = true;
|
|
1574
1632
|
for (let i = 0; i < n; i++) {
|
|
1575
1633
|
if (dtypeOverride?.[i]) {
|
|
1576
1634
|
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1577
|
-
} else if (
|
|
1578
|
-
|
|
1635
|
+
} else if (castDtype === void 0) {
|
|
1636
|
+
castDtype = arrays[i].#dtype;
|
|
1637
|
+
castWeakType = arrays[i].#weakType;
|
|
1638
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1579
1639
|
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1580
1640
|
}
|
|
1581
|
-
|
|
1582
|
-
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1641
|
+
const weakType = castWeakType && !strongTypeOutput;
|
|
1583
1642
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1584
1643
|
const newShape = [...arrays[0].shape];
|
|
1585
1644
|
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
1645
|
+
const sources = arrays.map((ar, i) => {
|
|
1646
|
+
if (!dtypeOverride?.[i]) return AluExp.cast(castDtype, ar.#source);
|
|
1647
|
+
else return ar.#source;
|
|
1648
|
+
});
|
|
1586
1649
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1587
|
-
const exp$4 = custom(
|
|
1588
|
-
return new Array$1(
|
|
1650
|
+
const exp$4 = custom(sources);
|
|
1651
|
+
return new Array$1({
|
|
1652
|
+
source: exp$4.simplify(),
|
|
1653
|
+
st: arrays[0].#st,
|
|
1654
|
+
dtype: exp$4.dtype,
|
|
1655
|
+
weakType,
|
|
1656
|
+
backend
|
|
1657
|
+
});
|
|
1589
1658
|
}
|
|
1590
|
-
const exp$3 = custom(arrays.map((ar) => {
|
|
1591
|
-
const src$1 =
|
|
1659
|
+
const exp$3 = custom(arrays.map((ar, i) => {
|
|
1660
|
+
const src$1 = sources[i];
|
|
1592
1661
|
if (ar.#st.contiguous) return src$1;
|
|
1593
1662
|
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1594
1663
|
}));
|
|
1595
1664
|
const st = ShapeTracker.fromShape(newShape);
|
|
1596
|
-
return new Array$1(
|
|
1665
|
+
return new Array$1({
|
|
1666
|
+
source: exp$3.simplify(),
|
|
1667
|
+
st,
|
|
1668
|
+
dtype: exp$3.dtype,
|
|
1669
|
+
weakType,
|
|
1670
|
+
backend
|
|
1671
|
+
});
|
|
1597
1672
|
}
|
|
1598
1673
|
let indices;
|
|
1599
1674
|
if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
|
|
@@ -1603,14 +1678,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1603
1678
|
}
|
|
1604
1679
|
const inputs = [];
|
|
1605
1680
|
const src = [];
|
|
1606
|
-
for (const ar of arrays
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
gid = inputs.
|
|
1611
|
-
|
|
1681
|
+
for (const [i, ar] of arrays.entries()) {
|
|
1682
|
+
let nextSrc;
|
|
1683
|
+
if (ar.#source instanceof AluExp) nextSrc = accessorAluExp(ar.#source, ar.#st, indices);
|
|
1684
|
+
else {
|
|
1685
|
+
let gid = inputs.indexOf(ar.#source);
|
|
1686
|
+
if (gid === -1) {
|
|
1687
|
+
gid = inputs.length;
|
|
1688
|
+
inputs.push(ar.#source);
|
|
1689
|
+
}
|
|
1690
|
+
nextSrc = AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
|
|
1612
1691
|
}
|
|
1613
|
-
|
|
1692
|
+
if (!dtypeOverride?.[i]) nextSrc = AluExp.cast(castDtype, nextSrc);
|
|
1693
|
+
src.push(nextSrc);
|
|
1614
1694
|
}
|
|
1615
1695
|
const exp$2 = custom(src);
|
|
1616
1696
|
let re = void 0;
|
|
@@ -1624,12 +1704,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1624
1704
|
for (const exe of pending) exe.updateRc(1);
|
|
1625
1705
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1626
1706
|
for (const ar of arrays) ar.dispose();
|
|
1627
|
-
return new Array$1(
|
|
1707
|
+
return new Array$1({
|
|
1708
|
+
source: output,
|
|
1709
|
+
st: ShapeTracker.fromShape(newShape),
|
|
1710
|
+
dtype: kernel.dtype,
|
|
1711
|
+
weakType,
|
|
1712
|
+
backend,
|
|
1713
|
+
pending
|
|
1714
|
+
});
|
|
1628
1715
|
}
|
|
1629
1716
|
/** Reduce the last dimension of the array by an operation. */
|
|
1630
1717
|
#reduce(op) {
|
|
1631
|
-
this.#check();
|
|
1632
|
-
if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
|
|
1633
1718
|
const shape$1 = this.shape;
|
|
1634
1719
|
const reduction = new Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
|
|
1635
1720
|
const newShape = shape$1.slice(0, -1);
|
|
@@ -1648,7 +1733,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1648
1733
|
for (const exe of pending) exe.updateRc(1);
|
|
1649
1734
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1650
1735
|
this.dispose();
|
|
1651
|
-
return
|
|
1736
|
+
return this.#newArrayFrom({
|
|
1737
|
+
source: output,
|
|
1738
|
+
st: ShapeTracker.fromShape(newShape),
|
|
1739
|
+
pending
|
|
1740
|
+
});
|
|
1652
1741
|
}
|
|
1653
1742
|
/**
|
|
1654
1743
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1684,8 +1773,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1684
1773
|
}
|
|
1685
1774
|
#dataInline() {
|
|
1686
1775
|
this.#check();
|
|
1687
|
-
|
|
1688
|
-
const ar =
|
|
1776
|
+
if (!(this.#source instanceof AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
|
|
1777
|
+
const ar = this.#newArrayFrom({ backend: getBackend("cpu") });
|
|
1689
1778
|
this.dispose();
|
|
1690
1779
|
return ar.dataSync();
|
|
1691
1780
|
}
|
|
@@ -1811,7 +1900,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1811
1900
|
x.#backend.incRef(x.#source);
|
|
1812
1901
|
const pending = x.#pending;
|
|
1813
1902
|
for (const exe of pending) exe.updateRc(1);
|
|
1814
|
-
const y =
|
|
1903
|
+
const y = x.#newArrayFrom({
|
|
1904
|
+
dtype,
|
|
1905
|
+
weakType: false,
|
|
1906
|
+
pending
|
|
1907
|
+
});
|
|
1815
1908
|
x.dispose();
|
|
1816
1909
|
return [y];
|
|
1817
1910
|
}
|
|
@@ -1886,7 +1979,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1886
1979
|
},
|
|
1887
1980
|
[Primitive.Compare]([x, y], { op }) {
|
|
1888
1981
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1889
|
-
return [Array$1.#naryCustom("compare", custom, [x, y], {
|
|
1982
|
+
return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
|
|
1890
1983
|
},
|
|
1891
1984
|
[Primitive.Where]([cond, x, y]) {
|
|
1892
1985
|
const custom = ([cond$1, x$1, y$1]) => AluExp.where(cond$1, x$1, y$1);
|
|
@@ -1932,7 +2025,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1932
2025
|
pending.splice(0, 0, ...prevPending);
|
|
1933
2026
|
args.forEach((x) => x.dispose());
|
|
1934
2027
|
return outputs.map((source, i) => {
|
|
1935
|
-
return new Array$1(
|
|
2028
|
+
return new Array$1({
|
|
2029
|
+
source,
|
|
2030
|
+
st: ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
|
|
2031
|
+
dtype: jaxpr.outs[i].aval.dtype,
|
|
2032
|
+
weakType: jaxpr.outs[i].aval.weakType,
|
|
2033
|
+
backend,
|
|
2034
|
+
pending
|
|
2035
|
+
});
|
|
1936
2036
|
});
|
|
1937
2037
|
}
|
|
1938
2038
|
};
|
|
@@ -1942,33 +2042,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1942
2042
|
return this.#source;
|
|
1943
2043
|
}
|
|
1944
2044
|
};
|
|
1945
|
-
/** Construct an array from a single scalar constant. */
|
|
1946
|
-
function scalar(value, { dtype, device } = {}) {
|
|
1947
|
-
if (typeof value === "number") {
|
|
1948
|
-
dtype ??= DType.Float32;
|
|
1949
|
-
if (![
|
|
1950
|
-
DType.Float32,
|
|
1951
|
-
DType.Float16,
|
|
1952
|
-
DType.Int32,
|
|
1953
|
-
DType.Uint32
|
|
1954
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1955
|
-
} else if (typeof value === "boolean") {
|
|
1956
|
-
dtype ??= DType.Bool;
|
|
1957
|
-
if (![
|
|
1958
|
-
DType.Float32,
|
|
1959
|
-
DType.Float16,
|
|
1960
|
-
DType.Int32,
|
|
1961
|
-
DType.Uint32,
|
|
1962
|
-
DType.Bool
|
|
1963
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1964
|
-
} else throw new TypeError(`Invalid type for scalar ${value}`);
|
|
1965
|
-
return new Array$1(AluExp.const(dtype, value), ShapeTracker.fromShape([]), dtype, getBackend(device));
|
|
1966
|
-
}
|
|
1967
2045
|
/** Constructor for creating a new array from data. */
|
|
1968
2046
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
1969
2047
|
if (values instanceof Tracer) {
|
|
1970
2048
|
if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1971
|
-
if (dtype && values.dtype !== dtype)
|
|
2049
|
+
if (dtype && values.dtype !== dtype) values = values.astype(dtype);
|
|
1972
2050
|
return values;
|
|
1973
2051
|
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1974
2052
|
dtype,
|
|
@@ -1990,6 +2068,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1990
2068
|
dtype,
|
|
1991
2069
|
device
|
|
1992
2070
|
});
|
|
2071
|
+
if (size$1 === 1) return full(shape$1, flat[0], {
|
|
2072
|
+
dtype,
|
|
2073
|
+
device
|
|
2074
|
+
});
|
|
1993
2075
|
if (typeof flat[0] === "boolean") {
|
|
1994
2076
|
dtype = dtype ?? DType.Bool;
|
|
1995
2077
|
const data = new Int32Array(flat.map((x) => x ? 1 : 0));
|
|
@@ -1998,46 +2080,51 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1998
2080
|
device
|
|
1999
2081
|
});
|
|
2000
2082
|
} else {
|
|
2083
|
+
const weakType = dtype == void 0;
|
|
2001
2084
|
dtype = dtype ?? DType.Float32;
|
|
2002
2085
|
const data = dtypedJsArray(dtype, flat);
|
|
2003
2086
|
return arrayFromData(data, shape$1, {
|
|
2004
2087
|
dtype,
|
|
2005
2088
|
device
|
|
2006
|
-
});
|
|
2089
|
+
}, weakType);
|
|
2007
2090
|
}
|
|
2008
2091
|
}
|
|
2009
2092
|
}
|
|
2010
|
-
function arrayFromData(data, shape$1, { dtype, device } =
|
|
2093
|
+
function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
2094
|
+
if (data instanceof Float32Array) {
|
|
2095
|
+
if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2096
|
+
dtype ??= DType.Float32;
|
|
2097
|
+
} else if (data instanceof Int32Array) {
|
|
2098
|
+
if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2099
|
+
dtype ??= DType.Int32;
|
|
2100
|
+
} else if (data instanceof Uint32Array) {
|
|
2101
|
+
if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2102
|
+
dtype ??= DType.Uint32;
|
|
2103
|
+
} else if (data instanceof Float16Array) {
|
|
2104
|
+
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2105
|
+
dtype ??= DType.Float16;
|
|
2106
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2011
2107
|
if (data.length < inlineArrayLimit) {
|
|
2012
2108
|
let allEqual = true;
|
|
2013
2109
|
for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
|
|
2014
2110
|
allEqual = false;
|
|
2015
2111
|
break;
|
|
2016
2112
|
}
|
|
2017
|
-
if (allEqual)
|
|
2018
|
-
dtype,
|
|
2019
|
-
device
|
|
2020
|
-
}
|
|
2113
|
+
if (allEqual) {
|
|
2114
|
+
const sa = new ShapedArray(shape$1, dtype, weakType);
|
|
2115
|
+
return fullInternal(sa, data[0], device);
|
|
2116
|
+
}
|
|
2021
2117
|
}
|
|
2022
2118
|
const backend = getBackend(device);
|
|
2023
|
-
|
|
2024
|
-
|
|
2025
|
-
|
|
2026
|
-
|
|
2027
|
-
|
|
2028
|
-
|
|
2029
|
-
|
|
2030
|
-
|
|
2031
|
-
|
|
2032
|
-
if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2033
|
-
dtype ??= DType.Uint32;
|
|
2034
|
-
} else if (data instanceof Float16Array) {
|
|
2035
|
-
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2036
|
-
dtype ??= DType.Float16;
|
|
2037
|
-
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2038
|
-
const slot = backend.malloc(data.byteLength, buf);
|
|
2039
|
-
return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2040
|
-
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
2119
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2120
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2121
|
+
return new Array$1({
|
|
2122
|
+
source: slot,
|
|
2123
|
+
st: ShapeTracker.fromShape(shape$1),
|
|
2124
|
+
dtype,
|
|
2125
|
+
weakType,
|
|
2126
|
+
backend
|
|
2127
|
+
});
|
|
2041
2128
|
}
|
|
2042
2129
|
function dataToJs(dtype, data, shape$1) {
|
|
2043
2130
|
if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -2053,7 +2140,7 @@ function dataToJs(dtype, data, shape$1) {
|
|
|
2053
2140
|
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
2054
2141
|
function pureArray(x) {
|
|
2055
2142
|
if (x instanceof Tracer) return x;
|
|
2056
|
-
else return
|
|
2143
|
+
else return array(x);
|
|
2057
2144
|
}
|
|
2058
2145
|
var EvalTrace = class extends Trace {
|
|
2059
2146
|
pure = (x) => pureArray(x);
|
|
@@ -2064,20 +2151,27 @@ var EvalTrace = class extends Trace {
|
|
|
2064
2151
|
};
|
|
2065
2152
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2066
2153
|
const implRules = Array$1._implRules();
|
|
2154
|
+
function fullInternal(aval, fillValue, device) {
|
|
2155
|
+
return new Array$1({
|
|
2156
|
+
source: AluExp.const(aval.dtype, fillValue),
|
|
2157
|
+
st: ShapeTracker.fromShape(aval.shape),
|
|
2158
|
+
dtype: aval.dtype,
|
|
2159
|
+
weakType: aval.weakType,
|
|
2160
|
+
backend: getBackend(device)
|
|
2161
|
+
});
|
|
2162
|
+
}
|
|
2067
2163
|
function zerosLike$1(val, dtype) {
|
|
2068
|
-
|
|
2069
|
-
if (val instanceof Tracer) val.dispose();
|
|
2070
|
-
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2164
|
+
return fullLike(val, 0, dtype);
|
|
2071
2165
|
}
|
|
2072
2166
|
function onesLike$1(val, dtype) {
|
|
2073
|
-
|
|
2074
|
-
if (val instanceof Tracer) val.dispose();
|
|
2075
|
-
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2167
|
+
return fullLike(val, 1, dtype);
|
|
2076
2168
|
}
|
|
2077
2169
|
function fullLike(val, fillValue, dtype) {
|
|
2078
2170
|
const aval = getAval(val);
|
|
2079
2171
|
if (val instanceof Tracer) val.dispose();
|
|
2080
|
-
|
|
2172
|
+
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
2173
|
+
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
2174
|
+
return fullInternal(sa, fillValue);
|
|
2081
2175
|
}
|
|
2082
2176
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
2083
2177
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -2095,19 +2189,14 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
2095
2189
|
}
|
|
2096
2190
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
2097
2191
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
2098
|
-
let
|
|
2099
|
-
if (typeof fillValue === "number")
|
|
2100
|
-
|
|
2101
|
-
source = AluExp.const(dtype, fillValue);
|
|
2102
|
-
} else if (typeof fillValue === "bigint") {
|
|
2103
|
-
dtype = dtype ?? DType.Int32;
|
|
2104
|
-
source = AluExp.const(dtype, Number(fillValue));
|
|
2105
|
-
} else if (typeof fillValue === "boolean") {
|
|
2192
|
+
let weakType = dtype == void 0;
|
|
2193
|
+
if (typeof fillValue === "number") dtype = dtype ?? DType.Float32;
|
|
2194
|
+
else if (typeof fillValue === "boolean") {
|
|
2106
2195
|
dtype = dtype ?? DType.Bool;
|
|
2107
|
-
|
|
2196
|
+
weakType = false;
|
|
2108
2197
|
} else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
|
|
2109
2198
|
else throw new TypeError(`Invalid type for full: ${fillValue}`);
|
|
2110
|
-
return new
|
|
2199
|
+
return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
|
|
2111
2200
|
}
|
|
2112
2201
|
/**
|
|
2113
2202
|
* Create an identity matrix.
|
|
@@ -2117,6 +2206,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
2117
2206
|
*/
|
|
2118
2207
|
function eye(numRows, numCols, { dtype, device } = {}) {
|
|
2119
2208
|
numCols = numCols ?? numRows;
|
|
2209
|
+
const weakType = dtype == void 0;
|
|
2120
2210
|
dtype = dtype ?? DType.Float32;
|
|
2121
2211
|
if (numCols < numRows) {
|
|
2122
2212
|
const arr = eye(numCols, numRows, {
|
|
@@ -2130,7 +2220,13 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2130
2220
|
device
|
|
2131
2221
|
});
|
|
2132
2222
|
const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
|
|
2133
|
-
return new Array$1(
|
|
2223
|
+
return new Array$1({
|
|
2224
|
+
source: AluExp.cast(dtype, exp$2),
|
|
2225
|
+
st: ShapeTracker.fromShape([numRows, numCols]),
|
|
2226
|
+
dtype,
|
|
2227
|
+
weakType,
|
|
2228
|
+
backend: getBackend(device)
|
|
2229
|
+
});
|
|
2134
2230
|
}
|
|
2135
2231
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2136
2232
|
function identity$1(n, { dtype, device } = {}) {
|
|
@@ -2167,7 +2263,13 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2167
2263
|
});
|
|
2168
2264
|
const exp$2 = AluExp.add(AluExp.const(dtype, start), AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step)));
|
|
2169
2265
|
const st = ShapeTracker.fromShape([size$1]);
|
|
2170
|
-
return new Array$1(
|
|
2266
|
+
return new Array$1({
|
|
2267
|
+
source: exp$2,
|
|
2268
|
+
st,
|
|
2269
|
+
dtype,
|
|
2270
|
+
weakType: false,
|
|
2271
|
+
backend: getBackend(device)
|
|
2272
|
+
});
|
|
2171
2273
|
}
|
|
2172
2274
|
/**
|
|
2173
2275
|
* Return evenly spaced numbers over a specified interval.
|
|
@@ -2185,10 +2287,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2185
2287
|
dtype,
|
|
2186
2288
|
device
|
|
2187
2289
|
});
|
|
2188
|
-
else if (num === 1) return
|
|
2290
|
+
else if (num === 1) return full([1], start, {
|
|
2189
2291
|
dtype,
|
|
2190
2292
|
device
|
|
2191
|
-
})
|
|
2293
|
+
});
|
|
2192
2294
|
else if (start === stop) return full([num], start, {
|
|
2193
2295
|
dtype,
|
|
2194
2296
|
device
|
|
@@ -2197,7 +2299,13 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2197
2299
|
const denom = endpoint ? num - 1 : num;
|
|
2198
2300
|
const exp$2 = AluExp.cast(dtype, AluExp.add(AluExp.f32(start), AluExp.mul(AluExp.f32(delta / denom), AluExp.cast(DType.Float32, AluVar.idx))));
|
|
2199
2301
|
const st = ShapeTracker.fromShape([num]);
|
|
2200
|
-
return new Array$1(
|
|
2302
|
+
return new Array$1({
|
|
2303
|
+
source: exp$2,
|
|
2304
|
+
st,
|
|
2305
|
+
dtype,
|
|
2306
|
+
weakType: false,
|
|
2307
|
+
backend: getBackend(device)
|
|
2308
|
+
});
|
|
2201
2309
|
}
|
|
2202
2310
|
function aluCompare(a, b, op) {
|
|
2203
2311
|
switch (op) {
|
|
@@ -2209,35 +2317,6 @@ function aluCompare(a, b, op) {
|
|
|
2209
2317
|
case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
|
|
2210
2318
|
}
|
|
2211
2319
|
}
|
|
2212
|
-
/**
|
|
2213
|
-
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
2214
|
-
*
|
|
2215
|
-
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
2216
|
-
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
2217
|
-
* Two dimensions are compatible when:
|
|
2218
|
-
* 1. they are equal, or
|
|
2219
|
-
* 2. one of them is 1."
|
|
2220
|
-
*
|
|
2221
|
-
* Throws a TypeError if the broadcast is not possible.
|
|
2222
|
-
*
|
|
2223
|
-
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
2224
|
-
*/
|
|
2225
|
-
function generalBroadcast(a, b) {
|
|
2226
|
-
const out = [];
|
|
2227
|
-
let i = a.length - 1;
|
|
2228
|
-
let j = b.length - 1;
|
|
2229
|
-
for (; i >= 0 && j >= 0; i--, j--) {
|
|
2230
|
-
const x = a[i];
|
|
2231
|
-
const y = b[j];
|
|
2232
|
-
if (x === y) out.push(x);
|
|
2233
|
-
else if (x === 1) out.push(y);
|
|
2234
|
-
else if (y === 1) out.push(x);
|
|
2235
|
-
else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
|
|
2236
|
-
}
|
|
2237
|
-
for (; i >= 0; i--) out.push(a[i]);
|
|
2238
|
-
for (; j >= 0; j--) out.push(b[j]);
|
|
2239
|
-
return out.reverse();
|
|
2240
|
-
}
|
|
2241
2320
|
|
|
2242
2321
|
//#endregion
|
|
2243
2322
|
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
|
|
@@ -2313,13 +2392,15 @@ var Var = class Var {
|
|
|
2313
2392
|
};
|
|
2314
2393
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
2315
2394
|
var Lit = class {
|
|
2316
|
-
dtype;
|
|
2317
2395
|
value;
|
|
2318
2396
|
aval;
|
|
2319
|
-
|
|
2320
|
-
this.dtype
|
|
2397
|
+
get dtype() {
|
|
2398
|
+
return this.aval.dtype;
|
|
2399
|
+
}
|
|
2400
|
+
constructor(aval, value) {
|
|
2401
|
+
if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
|
|
2321
2402
|
this.value = value;
|
|
2322
|
-
this.aval =
|
|
2403
|
+
this.aval = ShapedArray.fromAval(aval);
|
|
2323
2404
|
}
|
|
2324
2405
|
};
|
|
2325
2406
|
function atomIsLit(atom, literal) {
|
|
@@ -2443,14 +2524,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2443
2524
|
const c = eqn.outBinders[0];
|
|
2444
2525
|
if (atomIsLit(a, 0)) context.set(c, b);
|
|
2445
2526
|
else if (atomIsLit(b, 0)) context.set(c, a);
|
|
2446
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2527
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
|
|
2528
|
+
else newEqns.push(eqn);
|
|
2529
|
+
} else if (eqn.primitive === Primitive.Neg) {
|
|
2530
|
+
const [a] = inputs;
|
|
2531
|
+
const c = eqn.outBinders[0];
|
|
2532
|
+
if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
|
|
2447
2533
|
else newEqns.push(eqn);
|
|
2448
2534
|
} else if (eqn.primitive === Primitive.Mul) {
|
|
2449
2535
|
const [a, b] = inputs;
|
|
2450
2536
|
const c = eqn.outBinders[0];
|
|
2451
2537
|
if (atomIsLit(a, 1)) context.set(c, b);
|
|
2452
2538
|
else if (atomIsLit(b, 1)) context.set(c, a);
|
|
2453
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2539
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
|
|
2454
2540
|
else newEqns.push(eqn);
|
|
2455
2541
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
2456
2542
|
const [a, b] = inputs;
|
|
@@ -2548,7 +2634,7 @@ function evalJaxpr(jaxpr, args) {
|
|
|
2548
2634
|
if (x instanceof Var) {
|
|
2549
2635
|
remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
|
|
2550
2636
|
return env.get(x);
|
|
2551
|
-
} else return
|
|
2637
|
+
} else return array(x.value, { dtype: x.dtype });
|
|
2552
2638
|
};
|
|
2553
2639
|
const write = (v, val) => {
|
|
2554
2640
|
if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
|
|
@@ -2607,7 +2693,7 @@ var JaxprTrace = class extends Trace {
|
|
|
2607
2693
|
let tracer = this.builder.constTracers.get(val);
|
|
2608
2694
|
if (tracer === void 0) {
|
|
2609
2695
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
2610
|
-
this.builder.addConst(tracer, val instanceof Tracer ? val.ref :
|
|
2696
|
+
this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
|
|
2611
2697
|
}
|
|
2612
2698
|
return tracer;
|
|
2613
2699
|
}
|
|
@@ -2676,7 +2762,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2676
2762
|
const newConsts = [];
|
|
2677
2763
|
for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
|
|
2678
2764
|
const ar = consts[i];
|
|
2679
|
-
literals.set(jaxpr.inBinders[i], new Lit(ar.
|
|
2765
|
+
literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
|
|
2680
2766
|
} else {
|
|
2681
2767
|
constBinders.push(jaxpr.inBinders[i]);
|
|
2682
2768
|
newConsts.push(consts[i]);
|
|
@@ -2689,13 +2775,12 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2689
2775
|
}
|
|
2690
2776
|
function binopAbstractEval([x, y]) {
|
|
2691
2777
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
2692
|
-
|
|
2693
|
-
return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
|
|
2778
|
+
return [promoteAvals(x, y)];
|
|
2694
2779
|
}
|
|
2695
2780
|
function compareAbstractEval([x, y]) {
|
|
2696
2781
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
|
|
2697
|
-
|
|
2698
|
-
return [new ShapedArray(
|
|
2782
|
+
const aval = promoteAvals(x, y);
|
|
2783
|
+
return [new ShapedArray(aval.shape, DType.Bool, false)];
|
|
2699
2784
|
}
|
|
2700
2785
|
function vectorizedUnopAbstractEval([x]) {
|
|
2701
2786
|
return [ShapedArray.fromAval(x)];
|
|
@@ -2708,18 +2793,18 @@ const abstractEvalRules = {
|
|
|
2708
2793
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
2709
2794
|
[Primitive.StopGradient]: vectorizedUnopAbstractEval,
|
|
2710
2795
|
[Primitive.Cast]([x], { dtype }) {
|
|
2711
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2796
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2712
2797
|
},
|
|
2713
2798
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
2714
2799
|
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
2715
2800
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
2716
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2801
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2717
2802
|
},
|
|
2718
2803
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
2719
2804
|
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
2720
2805
|
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2721
2806
|
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2722
|
-
return [new ShapedArray(shape$1, DType.Uint32)];
|
|
2807
|
+
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
2723
2808
|
},
|
|
2724
2809
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2725
2810
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
@@ -2733,55 +2818,54 @@ const abstractEvalRules = {
|
|
|
2733
2818
|
[Primitive.Reduce]([x], { axis }) {
|
|
2734
2819
|
const axisSet = new Set(axis);
|
|
2735
2820
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2736
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2821
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2737
2822
|
},
|
|
2738
2823
|
[Primitive.Pool]([x], { window, strides }) {
|
|
2739
2824
|
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2740
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2825
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2741
2826
|
},
|
|
2742
2827
|
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2743
2828
|
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2744
2829
|
if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2745
|
-
return [new ShapedArray(inShape, x.dtype)];
|
|
2830
|
+
return [new ShapedArray(inShape, x.dtype, x.weakType)];
|
|
2746
2831
|
},
|
|
2747
2832
|
[Primitive.Dot]([x, y]) {
|
|
2748
|
-
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2749
2833
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
2750
|
-
const shape$1 =
|
|
2834
|
+
const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
|
|
2751
2835
|
shape$1.splice(-1, 1);
|
|
2752
|
-
return [new ShapedArray(shape$1,
|
|
2836
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2753
2837
|
},
|
|
2754
2838
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
2755
|
-
|
|
2839
|
+
const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
|
|
2756
2840
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2757
|
-
return [new ShapedArray(shape$1,
|
|
2841
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2758
2842
|
},
|
|
2759
2843
|
[Primitive.Compare]: compareAbstractEval,
|
|
2760
2844
|
[Primitive.Where]([cond, x, y]) {
|
|
2761
2845
|
if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
2762
|
-
|
|
2763
|
-
const shape$1 = generalBroadcast(cond.shape,
|
|
2764
|
-
return [new ShapedArray(shape$1,
|
|
2846
|
+
const xy = promoteAvals(x, y);
|
|
2847
|
+
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
2848
|
+
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
2765
2849
|
},
|
|
2766
2850
|
[Primitive.Transpose]([x], { perm }) {
|
|
2767
|
-
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
|
|
2851
|
+
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
2768
2852
|
},
|
|
2769
2853
|
[Primitive.Broadcast]([x], { shape: shape$1 }) {
|
|
2770
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2854
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2771
2855
|
},
|
|
2772
2856
|
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2773
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2857
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2774
2858
|
},
|
|
2775
2859
|
[Primitive.Flip]([x], _) {
|
|
2776
|
-
return [
|
|
2860
|
+
return [ShapedArray.fromAval(x)];
|
|
2777
2861
|
},
|
|
2778
2862
|
[Primitive.Shrink]([x], { slice }) {
|
|
2779
2863
|
const newShape = slice.map((s) => s[1] - s[0]);
|
|
2780
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2864
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2781
2865
|
},
|
|
2782
2866
|
[Primitive.Pad]([x], { width }) {
|
|
2783
2867
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
2784
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2868
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2785
2869
|
},
|
|
2786
2870
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2787
2871
|
for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
@@ -2794,7 +2878,7 @@ const abstractEvalRules = {
|
|
|
2794
2878
|
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
2795
2879
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2796
2880
|
newShape.splice(outDim, 0, ...gatherShape);
|
|
2797
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2881
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2798
2882
|
},
|
|
2799
2883
|
[Primitive.JitCall](args, { jaxpr }) {
|
|
2800
2884
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
@@ -2861,6 +2945,7 @@ function jit$1(f, opts) {
|
|
|
2861
2945
|
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2862
2946
|
const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2863
2947
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2948
|
+
name: f.name || "closure",
|
|
2864
2949
|
jaxpr,
|
|
2865
2950
|
numConsts: consts.length
|
|
2866
2951
|
});
|
|
@@ -3022,13 +3107,14 @@ const jvpRules = {
|
|
|
3022
3107
|
const indicesRef = indices.map((t) => t.ref);
|
|
3023
3108
|
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3024
3109
|
},
|
|
3025
|
-
[Primitive.JitCall](primals, tangents, { jaxpr }) {
|
|
3110
|
+
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3026
3111
|
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3027
3112
|
const outs = bind(Primitive.JitCall, [
|
|
3028
3113
|
...newConsts.map((c) => c.ref),
|
|
3029
3114
|
...primals,
|
|
3030
3115
|
...tangents
|
|
3031
3116
|
], {
|
|
3117
|
+
name: `${name}_jvp`,
|
|
3032
3118
|
jaxpr: newJaxpr,
|
|
3033
3119
|
numConsts: newConsts.length
|
|
3034
3120
|
});
|
|
@@ -3082,7 +3168,7 @@ function jvp$1(f, primals, tangents) {
|
|
|
3082
3168
|
function mappedAval(batchDim, aval) {
|
|
3083
3169
|
const shape$1 = [...aval.shape];
|
|
3084
3170
|
shape$1.splice(batchDim, 1);
|
|
3085
|
-
return new ShapedArray(shape$1, aval.dtype);
|
|
3171
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3086
3172
|
}
|
|
3087
3173
|
/** Move one axis to a different index. */
|
|
3088
3174
|
function moveaxis$1(x, src, dst) {
|
|
@@ -3226,9 +3312,10 @@ const vmapRules = {
|
|
|
3226
3312
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3227
3313
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3228
3314
|
},
|
|
3229
|
-
[Primitive.JitCall](axisSize, args, dims, { jaxpr }) {
|
|
3315
|
+
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3230
3316
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3231
3317
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
3318
|
+
name: `${name}_vmap`,
|
|
3232
3319
|
jaxpr: newJaxpr,
|
|
3233
3320
|
numConsts: newConsts.length
|
|
3234
3321
|
});
|
|
@@ -3244,7 +3331,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3244
3331
|
if (dims[i] === null) return v.aval;
|
|
3245
3332
|
const shape$1 = [...v.aval.shape];
|
|
3246
3333
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3247
|
-
return new ShapedArray(shape$1, v.aval.dtype);
|
|
3334
|
+
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3248
3335
|
});
|
|
3249
3336
|
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3250
3337
|
const result = {
|
|
@@ -3457,8 +3544,8 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3457
3544
|
processPrimitive(primitive, tracers, params) {
|
|
3458
3545
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3459
3546
|
if (primitive === Primitive.JitCall) {
|
|
3460
|
-
const { jaxpr, numConsts } = params;
|
|
3461
|
-
return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
|
|
3547
|
+
const { name, jaxpr, numConsts } = params;
|
|
3548
|
+
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3462
3549
|
}
|
|
3463
3550
|
const tracersIn = tracers.map((t) => this.instantiateConst(t));
|
|
3464
3551
|
const avalsIn = tracersIn.map((t) => t.pval.aval);
|
|
@@ -3484,12 +3571,13 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3484
3571
|
*
|
|
3485
3572
|
* Used when encountering a JitCall rule during the trace.
|
|
3486
3573
|
*/
|
|
3487
|
-
#partialEvalJaxpr(jaxpr, numConsts, tracers) {
|
|
3574
|
+
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3488
3575
|
jaxpr = jaxpr.flatten();
|
|
3489
3576
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3490
3577
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3491
3578
|
const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
|
|
3492
3579
|
const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3580
|
+
name: `${name}_peval`,
|
|
3493
3581
|
jaxpr: jaxpr1,
|
|
3494
3582
|
numConsts: 0
|
|
3495
3583
|
});
|
|
@@ -3501,6 +3589,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3501
3589
|
prim: Primitive.JitCall,
|
|
3502
3590
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3503
3591
|
params: {
|
|
3592
|
+
name: `${name}_resid`,
|
|
3504
3593
|
jaxpr: jaxpr2,
|
|
3505
3594
|
numConsts: 0
|
|
3506
3595
|
},
|
|
@@ -3643,7 +3732,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
|
|
|
3643
3732
|
}
|
|
3644
3733
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
3645
3734
|
const eqn = jaxpr.eqns[i];
|
|
3646
|
-
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ?
|
|
3735
|
+
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
|
|
3647
3736
|
const cotangentsOut = eqn.outBinders.map(readCotangent);
|
|
3648
3737
|
const rule = transposeRules[eqn.primitive];
|
|
3649
3738
|
if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
|
|
@@ -3823,7 +3912,7 @@ const transposeRules = {
|
|
|
3823
3912
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
3824
3913
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
3825
3914
|
},
|
|
3826
|
-
[Primitive.JitCall](cts, args, { jaxpr }) {
|
|
3915
|
+
[Primitive.JitCall](cts, args, { name, jaxpr }) {
|
|
3827
3916
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
3828
3917
|
const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
|
|
3829
3918
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
@@ -3832,6 +3921,7 @@ const transposeRules = {
|
|
|
3832
3921
|
...residuals,
|
|
3833
3922
|
...cts
|
|
3834
3923
|
], {
|
|
3924
|
+
name: `${name}_t`,
|
|
3835
3925
|
jaxpr: newJaxpr,
|
|
3836
3926
|
numConsts: newConsts.length
|
|
3837
3927
|
});
|
|
@@ -3906,7 +3996,7 @@ function valueAndGrad$1(f) {
|
|
|
3906
3996
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3907
3997
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3908
3998
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3909
|
-
const [ct, ...rest] = fVjp(
|
|
3999
|
+
const [ct, ...rest] = fVjp(array(1, { dtype: y.dtype }));
|
|
3910
4000
|
for (const r of rest) dispose(r);
|
|
3911
4001
|
fVjp.dispose();
|
|
3912
4002
|
return [y, ct];
|
|
@@ -4276,7 +4366,7 @@ function argmin(a, axis, opts) {
|
|
|
4276
4366
|
} else axis = checkAxis(axis, a.ndim);
|
|
4277
4367
|
const shape$1 = a.shape;
|
|
4278
4368
|
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4279
|
-
const length =
|
|
4369
|
+
const length = array(shape$1[axis], {
|
|
4280
4370
|
dtype: int32,
|
|
4281
4371
|
device: a.device
|
|
4282
4372
|
});
|
|
@@ -4300,7 +4390,7 @@ function argmax(a, axis, opts) {
|
|
|
4300
4390
|
} else axis = checkAxis(axis, a.ndim);
|
|
4301
4391
|
const shape$1 = a.shape;
|
|
4302
4392
|
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4303
|
-
const length =
|
|
4393
|
+
const length = array(shape$1[axis], {
|
|
4304
4394
|
dtype: int32,
|
|
4305
4395
|
device: a.device
|
|
4306
4396
|
});
|
|
@@ -4716,7 +4806,7 @@ function acos(x) {
|
|
|
4716
4806
|
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4717
4807
|
* improvements.
|
|
4718
4808
|
*/
|
|
4719
|
-
const hypot = jit$1((x1, x2)
|
|
4809
|
+
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4720
4810
|
return sqrt(square(x1).add(square(x2)));
|
|
4721
4811
|
});
|
|
4722
4812
|
/**
|
|
@@ -4732,7 +4822,7 @@ const hypot = jit$1((x1, x2) => {
|
|
|
4732
4822
|
*
|
|
4733
4823
|
* The output is ill-defined when both x and y are zero.
|
|
4734
4824
|
*/
|
|
4735
|
-
const atan2 = jit$1((y, x)
|
|
4825
|
+
const atan2 = jit$1(function atan2$1(y, x) {
|
|
4736
4826
|
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4737
4827
|
const xNeg = less(x.ref, 0);
|
|
4738
4828
|
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
@@ -4800,13 +4890,13 @@ const degrees = rad2deg;
|
|
|
4800
4890
|
* @function
|
|
4801
4891
|
* Computes first array raised to power of second array, element-wise.
|
|
4802
4892
|
*/
|
|
4803
|
-
const power = jit$1((x1, x2)
|
|
4893
|
+
const power = jit$1(function power$1(x1, x2) {
|
|
4804
4894
|
return exp(log(x1).mul(x2));
|
|
4805
4895
|
});
|
|
4806
4896
|
/** @function Alias of `jax.numpy.power()`. */
|
|
4807
4897
|
const pow = power;
|
|
4808
4898
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
4809
|
-
const cbrt = jit$1((x)
|
|
4899
|
+
const cbrt = jit$1(function cbrt$1(x) {
|
|
4810
4900
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4811
4901
|
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4812
4902
|
});
|
|
@@ -4816,7 +4906,7 @@ const cbrt = jit$1((x) => {
|
|
|
4816
4906
|
*
|
|
4817
4907
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4818
4908
|
*/
|
|
4819
|
-
const sinh = jit$1((x)
|
|
4909
|
+
const sinh = jit$1(function sinh$1(x) {
|
|
4820
4910
|
const ex = exp(x);
|
|
4821
4911
|
const emx = reciprocal(ex.ref);
|
|
4822
4912
|
return ex.sub(emx).mul(.5);
|
|
@@ -4827,7 +4917,7 @@ const sinh = jit$1((x) => {
|
|
|
4827
4917
|
*
|
|
4828
4918
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4829
4919
|
*/
|
|
4830
|
-
const cosh = jit$1((x)
|
|
4920
|
+
const cosh = jit$1(function cosh$1(x) {
|
|
4831
4921
|
const ex = exp(x);
|
|
4832
4922
|
const emx = reciprocal(ex.ref);
|
|
4833
4923
|
return ex.add(emx).mul(.5);
|
|
@@ -4838,7 +4928,7 @@ const cosh = jit$1((x) => {
|
|
|
4838
4928
|
*
|
|
4839
4929
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4840
4930
|
*/
|
|
4841
|
-
const tanh = jit$1((x)
|
|
4931
|
+
const tanh = jit$1(function tanh$1(x) {
|
|
4842
4932
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4843
4933
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4844
4934
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
@@ -4849,7 +4939,7 @@ const tanh = jit$1((x) => {
|
|
|
4849
4939
|
*
|
|
4850
4940
|
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4851
4941
|
*/
|
|
4852
|
-
const arcsinh = jit$1((x)
|
|
4942
|
+
const arcsinh = jit$1(function arcsinh$1(x) {
|
|
4853
4943
|
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4854
4944
|
});
|
|
4855
4945
|
/**
|
|
@@ -4858,7 +4948,7 @@ const arcsinh = jit$1((x) => {
|
|
|
4858
4948
|
*
|
|
4859
4949
|
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4860
4950
|
*/
|
|
4861
|
-
const arccosh = jit$1((x)
|
|
4951
|
+
const arccosh = jit$1(function arccosh$1(x) {
|
|
4862
4952
|
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4863
4953
|
});
|
|
4864
4954
|
/**
|
|
@@ -4867,7 +4957,7 @@ const arccosh = jit$1((x) => {
|
|
|
4867
4957
|
*
|
|
4868
4958
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4869
4959
|
*/
|
|
4870
|
-
const arctanh = jit$1((x)
|
|
4960
|
+
const arctanh = jit$1(function arctanh$1(x) {
|
|
4871
4961
|
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4872
4962
|
});
|
|
4873
4963
|
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
@@ -4983,7 +5073,9 @@ function softSign(x) {
|
|
|
4983
5073
|
*
|
|
4984
5074
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4985
5075
|
*/
|
|
4986
|
-
const silu = jit$1((x)
|
|
5076
|
+
const silu = jit$1(function silu$1(x) {
|
|
5077
|
+
return x.ref.mul(sigmoid(x));
|
|
5078
|
+
});
|
|
4987
5079
|
/**
|
|
4988
5080
|
* @function
|
|
4989
5081
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -5044,7 +5136,7 @@ function celu(x, alpha = 1) {
|
|
|
5044
5136
|
*
|
|
5045
5137
|
* This will be improved in the future.
|
|
5046
5138
|
*/
|
|
5047
|
-
const gelu = jit$1((x)
|
|
5139
|
+
const gelu = jit$1(function gelu$1(x) {
|
|
5048
5140
|
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
5049
5141
|
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
5050
5142
|
});
|
|
@@ -5215,8 +5307,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
5215
5307
|
const keyShape = validateKeyShape(key$1);
|
|
5216
5308
|
return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
|
|
5217
5309
|
}
|
|
5218
|
-
/**
|
|
5219
|
-
function
|
|
5310
|
+
/**
|
|
5311
|
+
* @function
|
|
5312
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
5313
|
+
*/
|
|
5314
|
+
const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
5220
5315
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
5221
5316
|
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
5222
5317
|
dtype: DType.Uint32,
|
|
@@ -5229,7 +5324,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
5229
5324
|
const rand = bitcast(float12, DType.Float32).sub(1);
|
|
5230
5325
|
if (minval === 0 && maxval === 1) return rand;
|
|
5231
5326
|
else return rand.mul(maxval - minval).add(minval);
|
|
5232
|
-
}
|
|
5327
|
+
}, { staticArgnums: [1, 2] });
|
|
5233
5328
|
/**
|
|
5234
5329
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5235
5330
|
*
|
|
@@ -5240,26 +5335,30 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
5240
5335
|
p = fudgeArray(p);
|
|
5241
5336
|
return uniform(key$1, shape$1).less(p);
|
|
5242
5337
|
}
|
|
5243
|
-
/**
|
|
5244
|
-
function
|
|
5338
|
+
/**
|
|
5339
|
+
* @function
|
|
5340
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
5341
|
+
*/
|
|
5342
|
+
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
5245
5343
|
const u = uniform(key$1, shape$1);
|
|
5246
5344
|
return negative(log1p(negative(u)));
|
|
5247
|
-
}
|
|
5345
|
+
}, { staticArgnums: [1] });
|
|
5248
5346
|
/**
|
|
5347
|
+
* @function
|
|
5249
5348
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5250
5349
|
*
|
|
5251
5350
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5252
5351
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5253
5352
|
* bitwise identical to JAX.
|
|
5254
5353
|
*/
|
|
5255
|
-
function normal(key$1, shape$1 = []) {
|
|
5354
|
+
const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
5256
5355
|
const [k1, k2] = split(key$1, 2);
|
|
5257
5356
|
const u1 = uniform(k1, shape$1);
|
|
5258
5357
|
const u2 = uniform(k2, shape$1);
|
|
5259
5358
|
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5260
5359
|
const theta = u2.mul(2 * Math.PI);
|
|
5261
5360
|
return radius.mul(cos(theta));
|
|
5262
|
-
}
|
|
5361
|
+
}, { staticArgnums: [1] });
|
|
5263
5362
|
|
|
5264
5363
|
//#endregion
|
|
5265
5364
|
//#region src/polyfills.ts
|