@jax-js/jax 0.0.5 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-CdcTZEOF.js → backend-DwIAd0AG.js} +205 -112
- package/dist/{backend-yEU0L_ig.cjs → backend-FtkbO6pI.cjs} +217 -118
- package/dist/index.cjs +344 -67
- package/dist/index.d.cts +96 -18
- package/dist/index.d.ts +96 -18
- package/dist/index.js +337 -67
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-CM-xNYzW.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-FtkbO6pI.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
@@ -60,6 +60,10 @@ var JsTreeDef = class JsTreeDef {
|
|
|
60
60
|
this.nodeMetadata = nodeMetadata;
|
|
61
61
|
this.childTreedefs = childTreedefs;
|
|
62
62
|
}
|
|
63
|
+
/** Get the total number of leaves in the tree. */
|
|
64
|
+
get size() {
|
|
65
|
+
return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
|
|
66
|
+
}
|
|
63
67
|
/** Returns a string representation of this tree definition. */
|
|
64
68
|
toString(root = true) {
|
|
65
69
|
if (root) return "JsTreeDef(" + this.toString(false) + ")";
|
|
@@ -215,6 +219,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
|
|
|
215
219
|
const s_ = strides;
|
|
216
220
|
const d_ = dilation;
|
|
217
221
|
const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
222
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
223
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
|
|
224
|
+
st = st.reshape([...noop, ...require_backend.zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...require_backend.zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
|
|
225
|
+
st = st.permute([
|
|
226
|
+
...require_backend.range(noop.length),
|
|
227
|
+
...ks.map((_, j) => noop.length + 2 * j),
|
|
228
|
+
...ks.map((_, j) => noop.length + 2 * j + 1)
|
|
229
|
+
]);
|
|
230
|
+
return st;
|
|
231
|
+
}
|
|
218
232
|
const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
219
233
|
const kidf = require_backend.zipn(ks, i_, d_, f_);
|
|
220
234
|
st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
|
|
@@ -249,6 +263,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
249
263
|
const s_ = strides;
|
|
250
264
|
const d_ = dilation;
|
|
251
265
|
const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
266
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
267
|
+
st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
|
|
268
|
+
st = st.pad([...noop.map(() => [0, 0]), ...require_backend.zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...require_backend.zip(o_, s_).map(([o, s]) => o * s)]);
|
|
269
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
|
|
270
|
+
return st.reshape(st.shape.concat(require_backend.rep(ks.length, 1)));
|
|
271
|
+
}
|
|
252
272
|
if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
|
|
253
273
|
const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
254
274
|
const kidf = require_backend.zipn(ks, i_, d_, f_);
|
|
@@ -358,6 +378,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
358
378
|
Primitive$1["Atan"] = "atan";
|
|
359
379
|
Primitive$1["Exp"] = "exp";
|
|
360
380
|
Primitive$1["Log"] = "log";
|
|
381
|
+
Primitive$1["Erf"] = "erf";
|
|
382
|
+
Primitive$1["Erfc"] = "erfc";
|
|
361
383
|
Primitive$1["Sqrt"] = "sqrt";
|
|
362
384
|
Primitive$1["Min"] = "min";
|
|
363
385
|
Primitive$1["Max"] = "max";
|
|
@@ -435,6 +457,12 @@ function exp$1(x) {
|
|
|
435
457
|
function log$1(x) {
|
|
436
458
|
return bind1(Primitive.Log, [x]);
|
|
437
459
|
}
|
|
460
|
+
function erf$1(x) {
|
|
461
|
+
return bind1(Primitive.Erf, [x]);
|
|
462
|
+
}
|
|
463
|
+
function erfc$1(x) {
|
|
464
|
+
return bind1(Primitive.Erfc, [x]);
|
|
465
|
+
}
|
|
438
466
|
function sqrt$1(x) {
|
|
439
467
|
return bind1(Primitive.Sqrt, [x]);
|
|
440
468
|
}
|
|
@@ -1177,12 +1205,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1177
1205
|
} else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
1178
1206
|
});
|
|
1179
1207
|
}
|
|
1180
|
-
function broadcastedJit(fn) {
|
|
1208
|
+
function broadcastedJit(fn, opts) {
|
|
1181
1209
|
return (nargs, exps, avals, params) => {
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1210
|
+
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1211
|
+
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1212
|
+
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1213
|
+
exps = exps.map((exp$3, i) => {
|
|
1214
|
+
exp$3 = reshapeViews(exp$3, (st) => {
|
|
1215
|
+
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1216
|
+
});
|
|
1217
|
+
if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
|
|
1218
|
+
return exp$3;
|
|
1219
|
+
});
|
|
1186
1220
|
const exp$2 = fn(exps, params);
|
|
1187
1221
|
return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
|
|
1188
1222
|
};
|
|
@@ -1225,6 +1259,8 @@ const jitRules = {
|
|
|
1225
1259
|
[Primitive.Atan]: unopJit(require_backend.AluExp.atan),
|
|
1226
1260
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
1227
1261
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1262
|
+
[Primitive.Erf]: unopJit(require_backend.AluExp.erf),
|
|
1263
|
+
[Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
|
|
1228
1264
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1229
1265
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
1230
1266
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
@@ -1272,7 +1308,7 @@ const jitRules = {
|
|
|
1272
1308
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1273
1309
|
},
|
|
1274
1310
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1275
|
-
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
|
|
1311
|
+
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1276
1312
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
1277
1313
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
1278
1314
|
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
@@ -1443,7 +1479,7 @@ var PendingExecute = class {
|
|
|
1443
1479
|
/**
|
|
1444
1480
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1445
1481
|
*
|
|
1446
|
-
* This is the library's core data type. Equivalent to `
|
|
1482
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1447
1483
|
* `torch.Tensor`.
|
|
1448
1484
|
*
|
|
1449
1485
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -1458,6 +1494,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1458
1494
|
#source;
|
|
1459
1495
|
#st;
|
|
1460
1496
|
#backend;
|
|
1497
|
+
#committed;
|
|
1461
1498
|
#rc;
|
|
1462
1499
|
#pendingSet;
|
|
1463
1500
|
/**
|
|
@@ -1474,6 +1511,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1474
1511
|
this.#source = args.source;
|
|
1475
1512
|
this.#st = args.st;
|
|
1476
1513
|
this.#backend = args.backend;
|
|
1514
|
+
this.#committed = args.committed;
|
|
1477
1515
|
this.#rc = 1;
|
|
1478
1516
|
this.#pendingSet = new Set(args.pending);
|
|
1479
1517
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
@@ -1501,6 +1539,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1501
1539
|
dtype: args.dtype ?? this.#dtype,
|
|
1502
1540
|
weakType: this.#weakType,
|
|
1503
1541
|
backend: args.backend ?? this.#backend,
|
|
1542
|
+
committed: args.committed ?? this.#committed,
|
|
1504
1543
|
pending: args.pending ?? this.#pending ?? void 0
|
|
1505
1544
|
});
|
|
1506
1545
|
}
|
|
@@ -1556,9 +1595,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1556
1595
|
*/
|
|
1557
1596
|
#gather(indices, axis, outDim) {
|
|
1558
1597
|
this.#check();
|
|
1559
|
-
if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1560
1598
|
const axisSet = new Set(axis);
|
|
1561
1599
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
|
|
1600
|
+
if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1601
|
+
indices = indices.map((ar) => ar._putSync(this.#backend));
|
|
1562
1602
|
indices = Array$1.#broadcastArrays(indices);
|
|
1563
1603
|
const indexShape = indices[0].shape;
|
|
1564
1604
|
const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1627,6 +1667,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1627
1667
|
this.#check();
|
|
1628
1668
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1629
1669
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1670
|
+
this.dispose();
|
|
1630
1671
|
return this.#newArrayFrom({
|
|
1631
1672
|
source: exp$3.simplify(),
|
|
1632
1673
|
dtype: dtypeOutput,
|
|
@@ -1655,21 +1696,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1655
1696
|
}
|
|
1656
1697
|
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1657
1698
|
const n = arrays.length;
|
|
1658
|
-
const backend = arrays[0].#backend;
|
|
1659
1699
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1660
1700
|
for (const ar of arrays) ar.#check();
|
|
1661
1701
|
let castDtype;
|
|
1662
1702
|
let castWeakType = true;
|
|
1663
|
-
for (let i = 0; i < n; i++) {
|
|
1664
|
-
if (dtypeOverride
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1670
|
-
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1671
|
-
}
|
|
1703
|
+
for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
|
|
1704
|
+
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1705
|
+
} else if (castDtype === void 0) {
|
|
1706
|
+
castDtype = arrays[i].#dtype;
|
|
1707
|
+
castWeakType = arrays[i].#weakType;
|
|
1708
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1672
1709
|
const weakType = castWeakType && !strongTypeOutput;
|
|
1710
|
+
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
1711
|
+
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
1673
1712
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1674
1713
|
const newShape = [...arrays[0].shape];
|
|
1675
1714
|
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
@@ -1679,12 +1718,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1679
1718
|
});
|
|
1680
1719
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1681
1720
|
const exp$4 = custom(sources);
|
|
1721
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1682
1722
|
return new Array$1({
|
|
1683
1723
|
source: exp$4.simplify(),
|
|
1684
1724
|
st: arrays[0].#st,
|
|
1685
1725
|
dtype: exp$4.dtype,
|
|
1686
1726
|
weakType,
|
|
1687
|
-
backend
|
|
1727
|
+
backend,
|
|
1728
|
+
committed
|
|
1688
1729
|
});
|
|
1689
1730
|
}
|
|
1690
1731
|
const exp$3 = custom(arrays.map((ar, i) => {
|
|
@@ -1693,12 +1734,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1693
1734
|
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1694
1735
|
}));
|
|
1695
1736
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1737
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1696
1738
|
return new Array$1({
|
|
1697
1739
|
source: exp$3.simplify(),
|
|
1698
1740
|
st,
|
|
1699
1741
|
dtype: exp$3.dtype,
|
|
1700
1742
|
weakType,
|
|
1701
|
-
backend
|
|
1743
|
+
backend,
|
|
1744
|
+
committed
|
|
1702
1745
|
});
|
|
1703
1746
|
}
|
|
1704
1747
|
let indices;
|
|
@@ -1734,13 +1777,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1734
1777
|
const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
|
|
1735
1778
|
for (const exe of pending) exe.updateRc(1);
|
|
1736
1779
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1737
|
-
|
|
1780
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1738
1781
|
return new Array$1({
|
|
1739
1782
|
source: output,
|
|
1740
1783
|
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1741
1784
|
dtype: kernel.dtype,
|
|
1742
1785
|
weakType,
|
|
1743
1786
|
backend,
|
|
1787
|
+
committed,
|
|
1744
1788
|
pending
|
|
1745
1789
|
});
|
|
1746
1790
|
}
|
|
@@ -1818,6 +1862,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1818
1862
|
return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
|
|
1819
1863
|
});
|
|
1820
1864
|
}
|
|
1865
|
+
static #computeBackend(name, arrays) {
|
|
1866
|
+
const committed = arrays.filter((ar) => ar.#committed);
|
|
1867
|
+
if (committed.length > 0) {
|
|
1868
|
+
const backend = committed[0].#backend;
|
|
1869
|
+
for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
|
|
1870
|
+
return {
|
|
1871
|
+
backend,
|
|
1872
|
+
committed: true
|
|
1873
|
+
};
|
|
1874
|
+
} else {
|
|
1875
|
+
const backend = arrays.length > 0 ? arrays[0].#backend : require_backend.getBackend();
|
|
1876
|
+
return {
|
|
1877
|
+
backend,
|
|
1878
|
+
committed: false
|
|
1879
|
+
};
|
|
1880
|
+
}
|
|
1881
|
+
}
|
|
1821
1882
|
/** Realize the array and return it as data. */
|
|
1822
1883
|
async data() {
|
|
1823
1884
|
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
@@ -1977,6 +2038,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1977
2038
|
[Primitive.Log]([x]) {
|
|
1978
2039
|
return [x.#unary(require_backend.AluOp.Log)];
|
|
1979
2040
|
},
|
|
2041
|
+
[Primitive.Erf]([x]) {
|
|
2042
|
+
return [x.#unary(require_backend.AluOp.Erf)];
|
|
2043
|
+
},
|
|
2044
|
+
[Primitive.Erfc]([x]) {
|
|
2045
|
+
return [x.#unary(require_backend.AluOp.Erfc)];
|
|
2046
|
+
},
|
|
1980
2047
|
[Primitive.Sqrt]([x]) {
|
|
1981
2048
|
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
1982
2049
|
},
|
|
@@ -2045,7 +2112,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2045
2112
|
},
|
|
2046
2113
|
[Primitive.JitCall](args, { jaxpr, numConsts }) {
|
|
2047
2114
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
2048
|
-
const backend =
|
|
2115
|
+
const { backend, committed } = Array$1.#computeBackend("jit_call", args);
|
|
2116
|
+
args = args.map((ar) => ar._putSync(backend));
|
|
2049
2117
|
const consts = args.slice(0, numConsts);
|
|
2050
2118
|
const tracers = args.slice(numConsts);
|
|
2051
2119
|
const jp = jitCompile(backend, jaxpr, consts);
|
|
@@ -2062,16 +2130,54 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2062
2130
|
dtype: jaxpr.outs[i].aval.dtype,
|
|
2063
2131
|
weakType: jaxpr.outs[i].aval.weakType,
|
|
2064
2132
|
backend,
|
|
2133
|
+
committed,
|
|
2065
2134
|
pending
|
|
2066
2135
|
});
|
|
2067
2136
|
});
|
|
2068
2137
|
}
|
|
2069
2138
|
};
|
|
2070
2139
|
}
|
|
2140
|
+
/** @private */
|
|
2071
2141
|
_realizeSource() {
|
|
2072
2142
|
this.#realize();
|
|
2073
2143
|
return this.#source;
|
|
2074
2144
|
}
|
|
2145
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
2146
|
+
async _put(backend) {
|
|
2147
|
+
if (this.#backend === backend) return this;
|
|
2148
|
+
if (this.#source instanceof require_backend.AluExp) {
|
|
2149
|
+
const ar = this.#newArrayFrom({
|
|
2150
|
+
backend,
|
|
2151
|
+
committed: true
|
|
2152
|
+
});
|
|
2153
|
+
this.dispose();
|
|
2154
|
+
return ar;
|
|
2155
|
+
} else {
|
|
2156
|
+
const data = await this.data();
|
|
2157
|
+
return arrayFromData(data, this.shape, {
|
|
2158
|
+
dtype: this.#dtype,
|
|
2159
|
+
device: backend.type
|
|
2160
|
+
}, this.#weakType);
|
|
2161
|
+
}
|
|
2162
|
+
}
|
|
2163
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
2164
|
+
_putSync(backend) {
|
|
2165
|
+
if (this.#backend === backend) return this;
|
|
2166
|
+
if (this.#source instanceof require_backend.AluExp) {
|
|
2167
|
+
const ar = this.#newArrayFrom({
|
|
2168
|
+
backend,
|
|
2169
|
+
committed: true
|
|
2170
|
+
});
|
|
2171
|
+
this.dispose();
|
|
2172
|
+
return ar;
|
|
2173
|
+
} else {
|
|
2174
|
+
const data = this.dataSync();
|
|
2175
|
+
return arrayFromData(data, this.shape, {
|
|
2176
|
+
dtype: this.#dtype,
|
|
2177
|
+
device: backend.type
|
|
2178
|
+
}, this.#weakType);
|
|
2179
|
+
}
|
|
2180
|
+
}
|
|
2075
2181
|
};
|
|
2076
2182
|
/** Constructor for creating a new array from data. */
|
|
2077
2183
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
@@ -2154,7 +2260,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2154
2260
|
st: require_backend.ShapeTracker.fromShape(shape$1),
|
|
2155
2261
|
dtype,
|
|
2156
2262
|
weakType,
|
|
2157
|
-
backend
|
|
2263
|
+
backend,
|
|
2264
|
+
committed: device != void 0
|
|
2158
2265
|
});
|
|
2159
2266
|
}
|
|
2160
2267
|
function dataToJs(dtype, data, shape$1) {
|
|
@@ -2188,7 +2295,8 @@ function fullInternal(aval, fillValue, device) {
|
|
|
2188
2295
|
st: require_backend.ShapeTracker.fromShape(aval.shape),
|
|
2189
2296
|
dtype: aval.dtype,
|
|
2190
2297
|
weakType: aval.weakType,
|
|
2191
|
-
backend: require_backend.getBackend(device)
|
|
2298
|
+
backend: require_backend.getBackend(device),
|
|
2299
|
+
committed: device != void 0
|
|
2192
2300
|
});
|
|
2193
2301
|
}
|
|
2194
2302
|
function zerosLike$1(val, dtype) {
|
|
@@ -2256,7 +2364,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2256
2364
|
st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
|
|
2257
2365
|
dtype,
|
|
2258
2366
|
weakType,
|
|
2259
|
-
backend: require_backend.getBackend(device)
|
|
2367
|
+
backend: require_backend.getBackend(device),
|
|
2368
|
+
committed: device != void 0
|
|
2260
2369
|
});
|
|
2261
2370
|
}
|
|
2262
2371
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
@@ -2299,7 +2408,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2299
2408
|
st,
|
|
2300
2409
|
dtype,
|
|
2301
2410
|
weakType: false,
|
|
2302
|
-
backend: require_backend.getBackend(device)
|
|
2411
|
+
backend: require_backend.getBackend(device),
|
|
2412
|
+
committed: device != void 0
|
|
2303
2413
|
});
|
|
2304
2414
|
}
|
|
2305
2415
|
/**
|
|
@@ -2335,7 +2445,8 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2335
2445
|
st,
|
|
2336
2446
|
dtype,
|
|
2337
2447
|
weakType: false,
|
|
2338
|
-
backend: require_backend.getBackend(device)
|
|
2448
|
+
backend: require_backend.getBackend(device),
|
|
2449
|
+
committed: device != void 0
|
|
2339
2450
|
});
|
|
2340
2451
|
}
|
|
2341
2452
|
function aluCompare(a, b, op) {
|
|
@@ -2847,6 +2958,8 @@ const abstractEvalRules = {
|
|
|
2847
2958
|
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2848
2959
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2849
2960
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2961
|
+
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
2962
|
+
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
2850
2963
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2851
2964
|
[Primitive.Min]: binopAbstractEval,
|
|
2852
2965
|
[Primitive.Max]: binopAbstractEval,
|
|
@@ -3100,6 +3213,16 @@ const jvpRules = {
|
|
|
3100
3213
|
[Primitive.Log]([x], [dx]) {
|
|
3101
3214
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3102
3215
|
},
|
|
3216
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3217
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3218
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3219
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3220
|
+
},
|
|
3221
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3222
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3223
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3224
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3225
|
+
},
|
|
3103
3226
|
[Primitive.Sqrt]([x], [dx]) {
|
|
3104
3227
|
const z = sqrt$1(x);
|
|
3105
3228
|
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
@@ -3262,6 +3385,10 @@ var BatchTrace = class extends Trace {
|
|
|
3262
3385
|
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3263
3386
|
const vmapRule = vmapRules[primitive];
|
|
3264
3387
|
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3388
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3389
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3390
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3391
|
+
}
|
|
3265
3392
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3266
3393
|
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3267
3394
|
}
|
|
@@ -3269,24 +3396,28 @@ var BatchTrace = class extends Trace {
|
|
|
3269
3396
|
return this.main.globalData;
|
|
3270
3397
|
}
|
|
3271
3398
|
};
|
|
3272
|
-
|
|
3273
|
-
|
|
3274
|
-
|
|
3275
|
-
|
|
3276
|
-
|
|
3277
|
-
return broadcast(x, shape$1, axis);
|
|
3278
|
-
}
|
|
3279
|
-
}
|
|
3280
|
-
/** Process a primitive with built-in broadcasting. */
|
|
3399
|
+
/**
|
|
3400
|
+
* Process a primitive with built-in broadcasting.
|
|
3401
|
+
*
|
|
3402
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3403
|
+
*/
|
|
3281
3404
|
function broadcastBatcher(op) {
|
|
3282
3405
|
return (axisSize, args, dims) => {
|
|
3283
3406
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3284
|
-
const
|
|
3285
|
-
|
|
3286
|
-
|
|
3287
|
-
args
|
|
3288
|
-
|
|
3289
|
-
|
|
3407
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3408
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3409
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3410
|
+
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3411
|
+
args = args.map((x, i) => {
|
|
3412
|
+
if (dims[i] === null) return x;
|
|
3413
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3414
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3415
|
+
x.shape[0],
|
|
3416
|
+
...require_backend.rep(nd - x.ndim, 1),
|
|
3417
|
+
...x.shape.slice(1)
|
|
3418
|
+
]);
|
|
3419
|
+
return x;
|
|
3420
|
+
});
|
|
3290
3421
|
return [[op(...args)], [0]];
|
|
3291
3422
|
};
|
|
3292
3423
|
}
|
|
@@ -3310,17 +3441,18 @@ const vmapRules = {
|
|
|
3310
3441
|
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3311
3442
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3312
3443
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3444
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3445
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3313
3446
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3314
3447
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3315
3448
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3316
3449
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3317
|
-
|
|
3450
|
+
require_backend.assertNonNull(xBdim);
|
|
3318
3451
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3319
3452
|
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3320
3453
|
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3321
3454
|
},
|
|
3322
3455
|
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3323
|
-
if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
|
|
3324
3456
|
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3325
3457
|
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3326
3458
|
const z = dot$1(x, y);
|
|
@@ -3329,26 +3461,68 @@ const vmapRules = {
|
|
|
3329
3461
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3330
3462
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3331
3463
|
},
|
|
3464
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3465
|
+
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3466
|
+
require_backend.assertNonNull(xBdim);
|
|
3467
|
+
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
3468
|
+
newPerm.splice(xBdim, 0, xBdim);
|
|
3469
|
+
return [[transpose$1(x, newPerm)], [xBdim]];
|
|
3470
|
+
},
|
|
3471
|
+
[Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
|
|
3472
|
+
require_backend.assertNonNull(xBdim);
|
|
3473
|
+
const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
|
|
3474
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3475
|
+
return [[broadcast(x, newShape, newAxis)], [xBdim]];
|
|
3476
|
+
},
|
|
3332
3477
|
[Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
|
|
3333
|
-
if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
|
|
3334
3478
|
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3335
3479
|
return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
|
|
3336
3480
|
},
|
|
3337
3481
|
[Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
|
|
3338
|
-
|
|
3482
|
+
require_backend.assertNonNull(xBdim);
|
|
3339
3483
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3340
3484
|
return [[flip$1(x, newAxis)], [xBdim]];
|
|
3341
3485
|
},
|
|
3342
3486
|
[Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
|
|
3343
|
-
|
|
3487
|
+
require_backend.assertNonNull(xBdim);
|
|
3344
3488
|
const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
|
|
3345
3489
|
return [[shrink(x, newSlice)], [xBdim]];
|
|
3346
3490
|
},
|
|
3347
3491
|
[Primitive.Pad](axisSize, [x], [xBdim], { width }) {
|
|
3348
|
-
|
|
3492
|
+
require_backend.assertNonNull(xBdim);
|
|
3349
3493
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3350
3494
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3351
3495
|
},
|
|
3496
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3497
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3498
|
+
require_backend.assertNonNull(xBdim);
|
|
3499
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3500
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3501
|
+
let newOutDim = outDim;
|
|
3502
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3503
|
+
else newOutDim += 1;
|
|
3504
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3505
|
+
}
|
|
3506
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3507
|
+
indices = indices.map((m, i) => {
|
|
3508
|
+
if (indicesBdim[i] === null) return m;
|
|
3509
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3510
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3511
|
+
m.shape[0],
|
|
3512
|
+
...require_backend.rep(nd - m.ndim, 1),
|
|
3513
|
+
...m.shape.slice(1)
|
|
3514
|
+
]);
|
|
3515
|
+
return m;
|
|
3516
|
+
});
|
|
3517
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3518
|
+
else {
|
|
3519
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3520
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3521
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
|
|
3522
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3523
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3524
|
+
}
|
|
3525
|
+
},
|
|
3352
3526
|
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3353
3527
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3354
3528
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
@@ -3408,12 +3582,14 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3408
3582
|
function vmap$1(f, inAxes = 0) {
|
|
3409
3583
|
return (...args) => {
|
|
3410
3584
|
const [argsFlat, inTree] = flatten(args);
|
|
3411
|
-
let inAxesFlat;
|
|
3585
|
+
let inAxesFlat = [];
|
|
3412
3586
|
if (typeof inAxes === "number") inAxesFlat = require_backend.rep(argsFlat.length, inAxes);
|
|
3587
|
+
else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, null));
|
|
3588
|
+
else if (typeof inAxes[i] === "number") inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, inAxes[i]));
|
|
3413
3589
|
else {
|
|
3414
|
-
|
|
3415
|
-
[
|
|
3416
|
-
|
|
3590
|
+
const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
|
|
3591
|
+
if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
|
|
3592
|
+
inAxesFlat.push(...axesFlat);
|
|
3417
3593
|
}
|
|
3418
3594
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3419
3595
|
const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
|
|
@@ -4033,7 +4209,7 @@ function valueAndGrad$1(f) {
|
|
|
4033
4209
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4034
4210
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4035
4211
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4036
|
-
const [ct, ...rest] = fVjp(
|
|
4212
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4037
4213
|
for (const r of rest) dispose(r);
|
|
4038
4214
|
fVjp.dispose();
|
|
4039
4215
|
return [y, ct];
|
|
@@ -4061,7 +4237,10 @@ __export(lax_exports, {
|
|
|
4061
4237
|
conv: () => conv$1,
|
|
4062
4238
|
convGeneralDilated: () => convGeneralDilated,
|
|
4063
4239
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4064
|
-
|
|
4240
|
+
erf: () => erf,
|
|
4241
|
+
erfc: () => erfc,
|
|
4242
|
+
reduceWindow: () => reduceWindow,
|
|
4243
|
+
stopGradient: () => stopGradient$1
|
|
4065
4244
|
});
|
|
4066
4245
|
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4067
4246
|
const padType = padding.toUpperCase();
|
|
@@ -4120,6 +4299,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
|
4120
4299
|
strides: windowStrides
|
|
4121
4300
|
}));
|
|
4122
4301
|
}
|
|
4302
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
4303
|
+
function erf(x) {
|
|
4304
|
+
return erf$1(x);
|
|
4305
|
+
}
|
|
4306
|
+
/**
|
|
4307
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
4308
|
+
*
|
|
4309
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
4310
|
+
* where `erf(x)` is very close to 1.
|
|
4311
|
+
*/
|
|
4312
|
+
function erfc(x) {
|
|
4313
|
+
return erfc$1(x);
|
|
4314
|
+
}
|
|
4315
|
+
/**
|
|
4316
|
+
* Stops gradient computation.
|
|
4317
|
+
*
|
|
4318
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
4319
|
+
* forward or reverse-mode automatic differentiation.
|
|
4320
|
+
*/
|
|
4321
|
+
function stopGradient$1(x) {
|
|
4322
|
+
return stopGradient(x);
|
|
4323
|
+
}
|
|
4123
4324
|
|
|
4124
4325
|
//#endregion
|
|
4125
4326
|
//#region src/numpy.ts
|
|
@@ -4182,6 +4383,9 @@ __export(numpy_exports, {
|
|
|
4182
4383
|
fullLike: () => fullLike$1,
|
|
4183
4384
|
greater: () => greater,
|
|
4184
4385
|
greaterEqual: () => greaterEqual,
|
|
4386
|
+
hamming: () => hamming,
|
|
4387
|
+
hann: () => hann,
|
|
4388
|
+
heaviside: () => heaviside,
|
|
4185
4389
|
hstack: () => hstack,
|
|
4186
4390
|
hypot: () => hypot,
|
|
4187
4391
|
identity: () => identity$1,
|
|
@@ -4821,6 +5025,32 @@ function sign(x) {
|
|
|
4821
5025
|
x = fudgeArray(x);
|
|
4822
5026
|
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4823
5027
|
}
|
|
5028
|
+
/**
|
|
5029
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
5030
|
+
*
|
|
5031
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
5032
|
+
*/
|
|
5033
|
+
function hamming(M) {
|
|
5034
|
+
return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
|
|
5035
|
+
}
|
|
5036
|
+
/**
|
|
5037
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
5038
|
+
*
|
|
5039
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
5040
|
+
*/
|
|
5041
|
+
function hann(M) {
|
|
5042
|
+
return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
|
|
5043
|
+
}
|
|
5044
|
+
/**
|
|
5045
|
+
* @function
|
|
5046
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
5047
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
5048
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
5049
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
5050
|
+
*/
|
|
5051
|
+
const heaviside = jit$1(function heaviside$1(x1, x2) {
|
|
5052
|
+
return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
|
|
5053
|
+
});
|
|
4824
5054
|
/** Calculate element-wise square of the input array. */
|
|
4825
5055
|
function square(x) {
|
|
4826
5056
|
x = fudgeArray(x);
|
|
@@ -4840,8 +5070,8 @@ function acos(x) {
|
|
|
4840
5070
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4841
5071
|
*
|
|
4842
5072
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4843
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
4844
|
-
* improvements.
|
|
5073
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
5074
|
+
* stability improvements.
|
|
4845
5075
|
*/
|
|
4846
5076
|
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4847
5077
|
return sqrt(square(x1).add(square(x2)));
|
|
@@ -5165,18 +5395,20 @@ function celu(x, alpha = 1) {
|
|
|
5165
5395
|
* @function
|
|
5166
5396
|
* Gaussion error linear unit (GELU) activation function.
|
|
5167
5397
|
*
|
|
5168
|
-
* This is computed element-wise.
|
|
5169
|
-
*
|
|
5170
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5398
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
5399
|
+
* `approximate` is set (default true):
|
|
5171
5400
|
*
|
|
5172
|
-
*
|
|
5401
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
5402
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
5173
5403
|
*
|
|
5174
|
-
*
|
|
5404
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5175
5405
|
*/
|
|
5176
|
-
const gelu = jit$1(function gelu$1(x) {
|
|
5177
|
-
|
|
5178
|
-
|
|
5179
|
-
|
|
5406
|
+
const gelu = jit$1(function gelu$1(x, opts) {
|
|
5407
|
+
if (opts?.approximate ?? true) {
|
|
5408
|
+
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
5409
|
+
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
5410
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
|
|
5411
|
+
}, { staticArgnums: [1] });
|
|
5180
5412
|
/**
|
|
5181
5413
|
* Gated linear unit (GLU) activation function.
|
|
5182
5414
|
*
|
|
@@ -5397,6 +5629,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
5397
5629
|
return radius.mul(cos(theta));
|
|
5398
5630
|
}, { staticArgnums: [1] });
|
|
5399
5631
|
|
|
5632
|
+
//#endregion
|
|
5633
|
+
//#region src/scipy-special.ts
|
|
5634
|
+
var scipy_special_exports = {};
|
|
5635
|
+
__export(scipy_special_exports, {
|
|
5636
|
+
erf: () => erf,
|
|
5637
|
+
erfc: () => erfc,
|
|
5638
|
+
logSoftmax: () => logSoftmax,
|
|
5639
|
+
logit: () => logit,
|
|
5640
|
+
logsumexp: () => logsumexp,
|
|
5641
|
+
softmax: () => softmax
|
|
5642
|
+
});
|
|
5643
|
+
/**
|
|
5644
|
+
* @function
|
|
5645
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
5646
|
+
*/
|
|
5647
|
+
const logit = jit$1(function logit$1(x) {
|
|
5648
|
+
return log(x.ref.div(subtract(1, x)));
|
|
5649
|
+
});
|
|
5650
|
+
|
|
5400
5651
|
//#endregion
|
|
5401
5652
|
//#region src/polyfills.ts
|
|
5402
5653
|
/** @file Polyfills for using this library. */
|
|
@@ -5490,6 +5741,24 @@ async function blockUntilReady(x) {
|
|
|
5490
5741
|
await Promise.all(promises);
|
|
5491
5742
|
return x;
|
|
5492
5743
|
}
|
|
5744
|
+
/**
|
|
5745
|
+
* Transfer `x` to `device`.
|
|
5746
|
+
*
|
|
5747
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
5748
|
+
* is committed to the device.
|
|
5749
|
+
*
|
|
5750
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
5751
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
5752
|
+
* default device.
|
|
5753
|
+
*/
|
|
5754
|
+
async function devicePut(x, device) {
|
|
5755
|
+
const [xflat, structure$1] = flatten(x);
|
|
5756
|
+
const yflat = await Promise.all(xflat.map((leaf) => {
|
|
5757
|
+
if (leaf instanceof Array$1) return device ? leaf._put(require_backend.getBackend(device)) : Promise.resolve(leaf);
|
|
5758
|
+
else return Promise.resolve(array(leaf, { device }));
|
|
5759
|
+
}));
|
|
5760
|
+
return unflatten(structure$1, yflat);
|
|
5761
|
+
}
|
|
5493
5762
|
|
|
5494
5763
|
//#endregion
|
|
5495
5764
|
exports.Array = Array$1;
|
|
@@ -5497,6 +5766,7 @@ exports.DType = require_backend.DType;
|
|
|
5497
5766
|
exports.Jaxpr = Jaxpr;
|
|
5498
5767
|
exports.blockUntilReady = blockUntilReady;
|
|
5499
5768
|
exports.defaultDevice = require_backend.defaultDevice;
|
|
5769
|
+
exports.devicePut = devicePut;
|
|
5500
5770
|
exports.devices = require_backend.devices;
|
|
5501
5771
|
exports.grad = grad;
|
|
5502
5772
|
exports.init = require_backend.init;
|
|
@@ -5531,6 +5801,12 @@ Object.defineProperty(exports, 'random', {
|
|
|
5531
5801
|
return random_exports;
|
|
5532
5802
|
}
|
|
5533
5803
|
});
|
|
5804
|
+
Object.defineProperty(exports, 'scipySpecial', {
|
|
5805
|
+
enumerable: true,
|
|
5806
|
+
get: function () {
|
|
5807
|
+
return scipy_special_exports;
|
|
5808
|
+
}
|
|
5809
|
+
});
|
|
5534
5810
|
exports.setDebug = require_backend.setDebug;
|
|
5535
5811
|
Object.defineProperty(exports, 'tree', {
|
|
5536
5812
|
enumerable: true,
|
|
@@ -5540,4 +5816,5 @@ Object.defineProperty(exports, 'tree', {
|
|
|
5540
5816
|
});
|
|
5541
5817
|
exports.valueAndGrad = valueAndGrad;
|
|
5542
5818
|
exports.vjp = vjp;
|
|
5543
|
-
exports.vmap = vmap;
|
|
5819
|
+
exports.vmap = vmap;
|
|
5820
|
+
//# sourceMappingURL=index.cjs.map
|