@jax-js/jax 0.0.5 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-CdcTZEOF.js → backend-DwIAd0AG.js} +205 -112
- package/dist/{backend-yEU0L_ig.cjs → backend-FtkbO6pI.cjs} +217 -118
- package/dist/index.cjs +344 -67
- package/dist/index.d.cts +96 -18
- package/dist/index.d.ts +96 -18
- package/dist/index.js +337 -67
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-CM-xNYzW.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DwIAd0AG.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -29,6 +29,10 @@ var JsTreeDef = class JsTreeDef {
|
|
|
29
29
|
this.nodeMetadata = nodeMetadata;
|
|
30
30
|
this.childTreedefs = childTreedefs;
|
|
31
31
|
}
|
|
32
|
+
/** Get the total number of leaves in the tree. */
|
|
33
|
+
get size() {
|
|
34
|
+
return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
|
|
35
|
+
}
|
|
32
36
|
/** Returns a string representation of this tree definition. */
|
|
33
37
|
toString(root = true) {
|
|
34
38
|
if (root) return "JsTreeDef(" + this.toString(false) + ")";
|
|
@@ -184,6 +188,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
|
|
|
184
188
|
const s_ = strides;
|
|
185
189
|
const d_ = dilation;
|
|
186
190
|
const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
191
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
192
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
|
|
193
|
+
st = st.reshape([...noop, ...zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
|
|
194
|
+
st = st.permute([
|
|
195
|
+
...range(noop.length),
|
|
196
|
+
...ks.map((_, j) => noop.length + 2 * j),
|
|
197
|
+
...ks.map((_, j) => noop.length + 2 * j + 1)
|
|
198
|
+
]);
|
|
199
|
+
return st;
|
|
200
|
+
}
|
|
187
201
|
const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
188
202
|
const kidf = zipn(ks, i_, d_, f_);
|
|
189
203
|
st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
|
|
@@ -218,6 +232,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
218
232
|
const s_ = strides;
|
|
219
233
|
const d_ = dilation;
|
|
220
234
|
const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
235
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
236
|
+
st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
|
|
237
|
+
st = st.pad([...noop.map(() => [0, 0]), ...zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...zip(o_, s_).map(([o, s]) => o * s)]);
|
|
238
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
|
|
239
|
+
return st.reshape(st.shape.concat(rep(ks.length, 1)));
|
|
240
|
+
}
|
|
221
241
|
if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
|
|
222
242
|
const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
223
243
|
const kidf = zipn(ks, i_, d_, f_);
|
|
@@ -327,6 +347,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
327
347
|
Primitive$1["Atan"] = "atan";
|
|
328
348
|
Primitive$1["Exp"] = "exp";
|
|
329
349
|
Primitive$1["Log"] = "log";
|
|
350
|
+
Primitive$1["Erf"] = "erf";
|
|
351
|
+
Primitive$1["Erfc"] = "erfc";
|
|
330
352
|
Primitive$1["Sqrt"] = "sqrt";
|
|
331
353
|
Primitive$1["Min"] = "min";
|
|
332
354
|
Primitive$1["Max"] = "max";
|
|
@@ -404,6 +426,12 @@ function exp$1(x) {
|
|
|
404
426
|
function log$1(x) {
|
|
405
427
|
return bind1(Primitive.Log, [x]);
|
|
406
428
|
}
|
|
429
|
+
function erf$1(x) {
|
|
430
|
+
return bind1(Primitive.Erf, [x]);
|
|
431
|
+
}
|
|
432
|
+
function erfc$1(x) {
|
|
433
|
+
return bind1(Primitive.Erfc, [x]);
|
|
434
|
+
}
|
|
407
435
|
function sqrt$1(x) {
|
|
408
436
|
return bind1(Primitive.Sqrt, [x]);
|
|
409
437
|
}
|
|
@@ -1146,12 +1174,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1146
1174
|
} else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
1147
1175
|
});
|
|
1148
1176
|
}
|
|
1149
|
-
function broadcastedJit(fn) {
|
|
1177
|
+
function broadcastedJit(fn, opts) {
|
|
1150
1178
|
return (nargs, exps, avals, params) => {
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1179
|
+
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1180
|
+
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1181
|
+
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1182
|
+
exps = exps.map((exp$3, i) => {
|
|
1183
|
+
exp$3 = reshapeViews(exp$3, (st) => {
|
|
1184
|
+
if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
|
|
1185
|
+
});
|
|
1186
|
+
if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
|
|
1187
|
+
return exp$3;
|
|
1188
|
+
});
|
|
1155
1189
|
const exp$2 = fn(exps, params);
|
|
1156
1190
|
return new Kernel(nargs, prod(newShape), exp$2);
|
|
1157
1191
|
};
|
|
@@ -1194,6 +1228,8 @@ const jitRules = {
|
|
|
1194
1228
|
[Primitive.Atan]: unopJit(AluExp.atan),
|
|
1195
1229
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
1196
1230
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1231
|
+
[Primitive.Erf]: unopJit(AluExp.erf),
|
|
1232
|
+
[Primitive.Erfc]: unopJit(AluExp.erfc),
|
|
1197
1233
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1198
1234
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
1199
1235
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
@@ -1241,7 +1277,7 @@ const jitRules = {
|
|
|
1241
1277
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1242
1278
|
},
|
|
1243
1279
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1244
|
-
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
|
|
1280
|
+
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1245
1281
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
1246
1282
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
1247
1283
|
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
@@ -1412,7 +1448,7 @@ var PendingExecute = class {
|
|
|
1412
1448
|
/**
|
|
1413
1449
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1414
1450
|
*
|
|
1415
|
-
* This is the library's core data type. Equivalent to `
|
|
1451
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1416
1452
|
* `torch.Tensor`.
|
|
1417
1453
|
*
|
|
1418
1454
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -1427,6 +1463,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1427
1463
|
#source;
|
|
1428
1464
|
#st;
|
|
1429
1465
|
#backend;
|
|
1466
|
+
#committed;
|
|
1430
1467
|
#rc;
|
|
1431
1468
|
#pendingSet;
|
|
1432
1469
|
/**
|
|
@@ -1443,6 +1480,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1443
1480
|
this.#source = args.source;
|
|
1444
1481
|
this.#st = args.st;
|
|
1445
1482
|
this.#backend = args.backend;
|
|
1483
|
+
this.#committed = args.committed;
|
|
1446
1484
|
this.#rc = 1;
|
|
1447
1485
|
this.#pendingSet = new Set(args.pending);
|
|
1448
1486
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
@@ -1470,6 +1508,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1470
1508
|
dtype: args.dtype ?? this.#dtype,
|
|
1471
1509
|
weakType: this.#weakType,
|
|
1472
1510
|
backend: args.backend ?? this.#backend,
|
|
1511
|
+
committed: args.committed ?? this.#committed,
|
|
1473
1512
|
pending: args.pending ?? this.#pending ?? void 0
|
|
1474
1513
|
});
|
|
1475
1514
|
}
|
|
@@ -1525,9 +1564,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1525
1564
|
*/
|
|
1526
1565
|
#gather(indices, axis, outDim) {
|
|
1527
1566
|
this.#check();
|
|
1528
|
-
if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1529
1567
|
const axisSet = new Set(axis);
|
|
1530
1568
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
|
|
1569
|
+
if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1570
|
+
indices = indices.map((ar) => ar._putSync(this.#backend));
|
|
1531
1571
|
indices = Array$1.#broadcastArrays(indices);
|
|
1532
1572
|
const indexShape = indices[0].shape;
|
|
1533
1573
|
const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1596,6 +1636,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1596
1636
|
this.#check();
|
|
1597
1637
|
if (this.#source instanceof AluExp) {
|
|
1598
1638
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1639
|
+
this.dispose();
|
|
1599
1640
|
return this.#newArrayFrom({
|
|
1600
1641
|
source: exp$3.simplify(),
|
|
1601
1642
|
dtype: dtypeOutput,
|
|
@@ -1624,21 +1665,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1624
1665
|
}
|
|
1625
1666
|
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1626
1667
|
const n = arrays.length;
|
|
1627
|
-
const backend = arrays[0].#backend;
|
|
1628
1668
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1629
1669
|
for (const ar of arrays) ar.#check();
|
|
1630
1670
|
let castDtype;
|
|
1631
1671
|
let castWeakType = true;
|
|
1632
|
-
for (let i = 0; i < n; i++) {
|
|
1633
|
-
if (dtypeOverride
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1639
|
-
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1640
|
-
}
|
|
1672
|
+
for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
|
|
1673
|
+
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1674
|
+
} else if (castDtype === void 0) {
|
|
1675
|
+
castDtype = arrays[i].#dtype;
|
|
1676
|
+
castWeakType = arrays[i].#weakType;
|
|
1677
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1641
1678
|
const weakType = castWeakType && !strongTypeOutput;
|
|
1679
|
+
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
1680
|
+
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
1642
1681
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1643
1682
|
const newShape = [...arrays[0].shape];
|
|
1644
1683
|
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
@@ -1648,12 +1687,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1648
1687
|
});
|
|
1649
1688
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1650
1689
|
const exp$4 = custom(sources);
|
|
1690
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1651
1691
|
return new Array$1({
|
|
1652
1692
|
source: exp$4.simplify(),
|
|
1653
1693
|
st: arrays[0].#st,
|
|
1654
1694
|
dtype: exp$4.dtype,
|
|
1655
1695
|
weakType,
|
|
1656
|
-
backend
|
|
1696
|
+
backend,
|
|
1697
|
+
committed
|
|
1657
1698
|
});
|
|
1658
1699
|
}
|
|
1659
1700
|
const exp$3 = custom(arrays.map((ar, i) => {
|
|
@@ -1662,12 +1703,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1662
1703
|
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1663
1704
|
}));
|
|
1664
1705
|
const st = ShapeTracker.fromShape(newShape);
|
|
1706
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1665
1707
|
return new Array$1({
|
|
1666
1708
|
source: exp$3.simplify(),
|
|
1667
1709
|
st,
|
|
1668
1710
|
dtype: exp$3.dtype,
|
|
1669
1711
|
weakType,
|
|
1670
|
-
backend
|
|
1712
|
+
backend,
|
|
1713
|
+
committed
|
|
1671
1714
|
});
|
|
1672
1715
|
}
|
|
1673
1716
|
let indices;
|
|
@@ -1703,13 +1746,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1703
1746
|
const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
|
|
1704
1747
|
for (const exe of pending) exe.updateRc(1);
|
|
1705
1748
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1706
|
-
|
|
1749
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1707
1750
|
return new Array$1({
|
|
1708
1751
|
source: output,
|
|
1709
1752
|
st: ShapeTracker.fromShape(newShape),
|
|
1710
1753
|
dtype: kernel.dtype,
|
|
1711
1754
|
weakType,
|
|
1712
1755
|
backend,
|
|
1756
|
+
committed,
|
|
1713
1757
|
pending
|
|
1714
1758
|
});
|
|
1715
1759
|
}
|
|
@@ -1787,6 +1831,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1787
1831
|
return ar.#reshape(ar.#st.broadcast(newShape, range(newShape.length - ar.ndim)));
|
|
1788
1832
|
});
|
|
1789
1833
|
}
|
|
1834
|
+
static #computeBackend(name, arrays) {
|
|
1835
|
+
const committed = arrays.filter((ar) => ar.#committed);
|
|
1836
|
+
if (committed.length > 0) {
|
|
1837
|
+
const backend = committed[0].#backend;
|
|
1838
|
+
for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
|
|
1839
|
+
return {
|
|
1840
|
+
backend,
|
|
1841
|
+
committed: true
|
|
1842
|
+
};
|
|
1843
|
+
} else {
|
|
1844
|
+
const backend = arrays.length > 0 ? arrays[0].#backend : getBackend();
|
|
1845
|
+
return {
|
|
1846
|
+
backend,
|
|
1847
|
+
committed: false
|
|
1848
|
+
};
|
|
1849
|
+
}
|
|
1850
|
+
}
|
|
1790
1851
|
/** Realize the array and return it as data. */
|
|
1791
1852
|
async data() {
|
|
1792
1853
|
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
@@ -1946,6 +2007,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1946
2007
|
[Primitive.Log]([x]) {
|
|
1947
2008
|
return [x.#unary(AluOp.Log)];
|
|
1948
2009
|
},
|
|
2010
|
+
[Primitive.Erf]([x]) {
|
|
2011
|
+
return [x.#unary(AluOp.Erf)];
|
|
2012
|
+
},
|
|
2013
|
+
[Primitive.Erfc]([x]) {
|
|
2014
|
+
return [x.#unary(AluOp.Erfc)];
|
|
2015
|
+
},
|
|
1949
2016
|
[Primitive.Sqrt]([x]) {
|
|
1950
2017
|
return [x.#unary(AluOp.Sqrt)];
|
|
1951
2018
|
},
|
|
@@ -2014,7 +2081,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2014
2081
|
},
|
|
2015
2082
|
[Primitive.JitCall](args, { jaxpr, numConsts }) {
|
|
2016
2083
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
2017
|
-
const backend =
|
|
2084
|
+
const { backend, committed } = Array$1.#computeBackend("jit_call", args);
|
|
2085
|
+
args = args.map((ar) => ar._putSync(backend));
|
|
2018
2086
|
const consts = args.slice(0, numConsts);
|
|
2019
2087
|
const tracers = args.slice(numConsts);
|
|
2020
2088
|
const jp = jitCompile(backend, jaxpr, consts);
|
|
@@ -2031,16 +2099,54 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2031
2099
|
dtype: jaxpr.outs[i].aval.dtype,
|
|
2032
2100
|
weakType: jaxpr.outs[i].aval.weakType,
|
|
2033
2101
|
backend,
|
|
2102
|
+
committed,
|
|
2034
2103
|
pending
|
|
2035
2104
|
});
|
|
2036
2105
|
});
|
|
2037
2106
|
}
|
|
2038
2107
|
};
|
|
2039
2108
|
}
|
|
2109
|
+
/** @private */
|
|
2040
2110
|
_realizeSource() {
|
|
2041
2111
|
this.#realize();
|
|
2042
2112
|
return this.#source;
|
|
2043
2113
|
}
|
|
2114
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
2115
|
+
async _put(backend) {
|
|
2116
|
+
if (this.#backend === backend) return this;
|
|
2117
|
+
if (this.#source instanceof AluExp) {
|
|
2118
|
+
const ar = this.#newArrayFrom({
|
|
2119
|
+
backend,
|
|
2120
|
+
committed: true
|
|
2121
|
+
});
|
|
2122
|
+
this.dispose();
|
|
2123
|
+
return ar;
|
|
2124
|
+
} else {
|
|
2125
|
+
const data = await this.data();
|
|
2126
|
+
return arrayFromData(data, this.shape, {
|
|
2127
|
+
dtype: this.#dtype,
|
|
2128
|
+
device: backend.type
|
|
2129
|
+
}, this.#weakType);
|
|
2130
|
+
}
|
|
2131
|
+
}
|
|
2132
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
2133
|
+
_putSync(backend) {
|
|
2134
|
+
if (this.#backend === backend) return this;
|
|
2135
|
+
if (this.#source instanceof AluExp) {
|
|
2136
|
+
const ar = this.#newArrayFrom({
|
|
2137
|
+
backend,
|
|
2138
|
+
committed: true
|
|
2139
|
+
});
|
|
2140
|
+
this.dispose();
|
|
2141
|
+
return ar;
|
|
2142
|
+
} else {
|
|
2143
|
+
const data = this.dataSync();
|
|
2144
|
+
return arrayFromData(data, this.shape, {
|
|
2145
|
+
dtype: this.#dtype,
|
|
2146
|
+
device: backend.type
|
|
2147
|
+
}, this.#weakType);
|
|
2148
|
+
}
|
|
2149
|
+
}
|
|
2044
2150
|
};
|
|
2045
2151
|
/** Constructor for creating a new array from data. */
|
|
2046
2152
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
@@ -2123,7 +2229,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2123
2229
|
st: ShapeTracker.fromShape(shape$1),
|
|
2124
2230
|
dtype,
|
|
2125
2231
|
weakType,
|
|
2126
|
-
backend
|
|
2232
|
+
backend,
|
|
2233
|
+
committed: device != void 0
|
|
2127
2234
|
});
|
|
2128
2235
|
}
|
|
2129
2236
|
function dataToJs(dtype, data, shape$1) {
|
|
@@ -2157,7 +2264,8 @@ function fullInternal(aval, fillValue, device) {
|
|
|
2157
2264
|
st: ShapeTracker.fromShape(aval.shape),
|
|
2158
2265
|
dtype: aval.dtype,
|
|
2159
2266
|
weakType: aval.weakType,
|
|
2160
|
-
backend: getBackend(device)
|
|
2267
|
+
backend: getBackend(device),
|
|
2268
|
+
committed: device != void 0
|
|
2161
2269
|
});
|
|
2162
2270
|
}
|
|
2163
2271
|
function zerosLike$1(val, dtype) {
|
|
@@ -2225,7 +2333,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2225
2333
|
st: ShapeTracker.fromShape([numRows, numCols]),
|
|
2226
2334
|
dtype,
|
|
2227
2335
|
weakType,
|
|
2228
|
-
backend: getBackend(device)
|
|
2336
|
+
backend: getBackend(device),
|
|
2337
|
+
committed: device != void 0
|
|
2229
2338
|
});
|
|
2230
2339
|
}
|
|
2231
2340
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
@@ -2268,7 +2377,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2268
2377
|
st,
|
|
2269
2378
|
dtype,
|
|
2270
2379
|
weakType: false,
|
|
2271
|
-
backend: getBackend(device)
|
|
2380
|
+
backend: getBackend(device),
|
|
2381
|
+
committed: device != void 0
|
|
2272
2382
|
});
|
|
2273
2383
|
}
|
|
2274
2384
|
/**
|
|
@@ -2304,7 +2414,8 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2304
2414
|
st,
|
|
2305
2415
|
dtype,
|
|
2306
2416
|
weakType: false,
|
|
2307
|
-
backend: getBackend(device)
|
|
2417
|
+
backend: getBackend(device),
|
|
2418
|
+
committed: device != void 0
|
|
2308
2419
|
});
|
|
2309
2420
|
}
|
|
2310
2421
|
function aluCompare(a, b, op) {
|
|
@@ -2812,6 +2923,8 @@ const abstractEvalRules = {
|
|
|
2812
2923
|
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2813
2924
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2814
2925
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2926
|
+
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
2927
|
+
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
2815
2928
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2816
2929
|
[Primitive.Min]: binopAbstractEval,
|
|
2817
2930
|
[Primitive.Max]: binopAbstractEval,
|
|
@@ -3064,6 +3177,16 @@ const jvpRules = {
|
|
|
3064
3177
|
[Primitive.Log]([x], [dx]) {
|
|
3065
3178
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3066
3179
|
},
|
|
3180
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3181
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3182
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3183
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3184
|
+
},
|
|
3185
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3186
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3187
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3188
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3189
|
+
},
|
|
3067
3190
|
[Primitive.Sqrt]([x], [dx]) {
|
|
3068
3191
|
const z = sqrt$1(x);
|
|
3069
3192
|
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
@@ -3225,6 +3348,10 @@ var BatchTrace = class extends Trace {
|
|
|
3225
3348
|
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3226
3349
|
const vmapRule = vmapRules[primitive];
|
|
3227
3350
|
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3351
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3352
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3353
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3354
|
+
}
|
|
3228
3355
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3229
3356
|
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3230
3357
|
}
|
|
@@ -3232,24 +3359,28 @@ var BatchTrace = class extends Trace {
|
|
|
3232
3359
|
return this.main.globalData;
|
|
3233
3360
|
}
|
|
3234
3361
|
};
|
|
3235
|
-
|
|
3236
|
-
|
|
3237
|
-
|
|
3238
|
-
|
|
3239
|
-
|
|
3240
|
-
return broadcast(x, shape$1, axis);
|
|
3241
|
-
}
|
|
3242
|
-
}
|
|
3243
|
-
/** Process a primitive with built-in broadcasting. */
|
|
3362
|
+
/**
|
|
3363
|
+
* Process a primitive with built-in broadcasting.
|
|
3364
|
+
*
|
|
3365
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3366
|
+
*/
|
|
3244
3367
|
function broadcastBatcher(op) {
|
|
3245
3368
|
return (axisSize, args, dims) => {
|
|
3246
3369
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3247
|
-
const
|
|
3248
|
-
|
|
3249
|
-
|
|
3250
|
-
args
|
|
3251
|
-
|
|
3252
|
-
|
|
3370
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3371
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3372
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3373
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3374
|
+
args = args.map((x, i) => {
|
|
3375
|
+
if (dims[i] === null) return x;
|
|
3376
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3377
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3378
|
+
x.shape[0],
|
|
3379
|
+
...rep(nd - x.ndim, 1),
|
|
3380
|
+
...x.shape.slice(1)
|
|
3381
|
+
]);
|
|
3382
|
+
return x;
|
|
3383
|
+
});
|
|
3253
3384
|
return [[op(...args)], [0]];
|
|
3254
3385
|
};
|
|
3255
3386
|
}
|
|
@@ -3273,17 +3404,18 @@ const vmapRules = {
|
|
|
3273
3404
|
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3274
3405
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3275
3406
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3407
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3408
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3276
3409
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3277
3410
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3278
3411
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3279
3412
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3280
|
-
|
|
3413
|
+
assertNonNull(xBdim);
|
|
3281
3414
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3282
3415
|
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3283
3416
|
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3284
3417
|
},
|
|
3285
3418
|
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3286
|
-
if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
|
|
3287
3419
|
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3288
3420
|
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3289
3421
|
const z = dot$1(x, y);
|
|
@@ -3292,26 +3424,68 @@ const vmapRules = {
|
|
|
3292
3424
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3293
3425
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3294
3426
|
},
|
|
3427
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3428
|
+
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3429
|
+
assertNonNull(xBdim);
|
|
3430
|
+
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
3431
|
+
newPerm.splice(xBdim, 0, xBdim);
|
|
3432
|
+
return [[transpose$1(x, newPerm)], [xBdim]];
|
|
3433
|
+
},
|
|
3434
|
+
[Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
|
|
3435
|
+
assertNonNull(xBdim);
|
|
3436
|
+
const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
|
|
3437
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3438
|
+
return [[broadcast(x, newShape, newAxis)], [xBdim]];
|
|
3439
|
+
},
|
|
3295
3440
|
[Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
|
|
3296
|
-
if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
|
|
3297
3441
|
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3298
3442
|
return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
|
|
3299
3443
|
},
|
|
3300
3444
|
[Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
|
|
3301
|
-
|
|
3445
|
+
assertNonNull(xBdim);
|
|
3302
3446
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3303
3447
|
return [[flip$1(x, newAxis)], [xBdim]];
|
|
3304
3448
|
},
|
|
3305
3449
|
[Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
|
|
3306
|
-
|
|
3450
|
+
assertNonNull(xBdim);
|
|
3307
3451
|
const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
|
|
3308
3452
|
return [[shrink(x, newSlice)], [xBdim]];
|
|
3309
3453
|
},
|
|
3310
3454
|
[Primitive.Pad](axisSize, [x], [xBdim], { width }) {
|
|
3311
|
-
|
|
3455
|
+
assertNonNull(xBdim);
|
|
3312
3456
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3313
3457
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3314
3458
|
},
|
|
3459
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3460
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3461
|
+
assertNonNull(xBdim);
|
|
3462
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3463
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3464
|
+
let newOutDim = outDim;
|
|
3465
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3466
|
+
else newOutDim += 1;
|
|
3467
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3468
|
+
}
|
|
3469
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3470
|
+
indices = indices.map((m, i) => {
|
|
3471
|
+
if (indicesBdim[i] === null) return m;
|
|
3472
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3473
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3474
|
+
m.shape[0],
|
|
3475
|
+
...rep(nd - m.ndim, 1),
|
|
3476
|
+
...m.shape.slice(1)
|
|
3477
|
+
]);
|
|
3478
|
+
return m;
|
|
3479
|
+
});
|
|
3480
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3481
|
+
else {
|
|
3482
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3483
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3484
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
|
|
3485
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3486
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3487
|
+
}
|
|
3488
|
+
},
|
|
3315
3489
|
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3316
3490
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3317
3491
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
@@ -3371,12 +3545,14 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3371
3545
|
function vmap$1(f, inAxes = 0) {
|
|
3372
3546
|
return (...args) => {
|
|
3373
3547
|
const [argsFlat, inTree] = flatten(args);
|
|
3374
|
-
let inAxesFlat;
|
|
3548
|
+
let inAxesFlat = [];
|
|
3375
3549
|
if (typeof inAxes === "number") inAxesFlat = rep(argsFlat.length, inAxes);
|
|
3550
|
+
else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...rep(inTree.childTreedefs[i].size, null));
|
|
3551
|
+
else if (typeof inAxes[i] === "number") inAxesFlat.push(...rep(inTree.childTreedefs[i].size, inAxes[i]));
|
|
3376
3552
|
else {
|
|
3377
|
-
|
|
3378
|
-
[
|
|
3379
|
-
|
|
3553
|
+
const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
|
|
3554
|
+
if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
|
|
3555
|
+
inAxesFlat.push(...axesFlat);
|
|
3380
3556
|
}
|
|
3381
3557
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3382
3558
|
const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
|
|
@@ -3996,7 +4172,7 @@ function valueAndGrad$1(f) {
|
|
|
3996
4172
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3997
4173
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3998
4174
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3999
|
-
const [ct, ...rest] = fVjp(
|
|
4175
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4000
4176
|
for (const r of rest) dispose(r);
|
|
4001
4177
|
fVjp.dispose();
|
|
4002
4178
|
return [y, ct];
|
|
@@ -4024,7 +4200,10 @@ __export(lax_exports, {
|
|
|
4024
4200
|
conv: () => conv$1,
|
|
4025
4201
|
convGeneralDilated: () => convGeneralDilated,
|
|
4026
4202
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4027
|
-
|
|
4203
|
+
erf: () => erf,
|
|
4204
|
+
erfc: () => erfc,
|
|
4205
|
+
reduceWindow: () => reduceWindow,
|
|
4206
|
+
stopGradient: () => stopGradient$1
|
|
4028
4207
|
});
|
|
4029
4208
|
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4030
4209
|
const padType = padding.toUpperCase();
|
|
@@ -4083,6 +4262,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
|
4083
4262
|
strides: windowStrides
|
|
4084
4263
|
}));
|
|
4085
4264
|
}
|
|
4265
|
+
/**
 * Gauss error function, applied element-wise:
 *
 * `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`
 *
 * @param x - Input; forwarded unchanged to the internal `erf$1`.
 * @returns The element-wise error function of `x`.
 */
function erf(x) {
	const result = erf$1(x);
	return result;
}
|
|
4269
|
+
/**
 * Complementary error function, applied element-wise:
 * `erfc(x) = 1 - erf(x)`.
 *
 * Prefer this over computing `1 - erf(x)` directly: for large `x`, `erf(x)`
 * is very close to 1 and the subtraction loses accuracy, whereas this
 * function stays accurate there.
 *
 * @param x - Input; forwarded unchanged to the internal `erfc$1`.
 * @returns The element-wise complementary error function of `x`.
 */
function erfc(x) {
	const result = erfc$1(x);
	return result;
}
|
|
4278
|
+
/**
 * Gradient barrier.
 *
 * Returns a value equal to `x`, but no gradient flows through it during
 * forward- or reverse-mode automatic differentiation.
 *
 * @param x - Value to pass through; forwarded to the internal `stopGradient`.
 * @returns `x`, excluded from differentiation.
 */
function stopGradient$1(x) {
	const detached = stopGradient(x);
	return detached;
}
|
|
4086
4287
|
|
|
4087
4288
|
//#endregion
|
|
4088
4289
|
//#region src/numpy.ts
|
|
@@ -4145,6 +4346,9 @@ __export(numpy_exports, {
|
|
|
4145
4346
|
fullLike: () => fullLike$1,
|
|
4146
4347
|
greater: () => greater,
|
|
4147
4348
|
greaterEqual: () => greaterEqual,
|
|
4349
|
+
hamming: () => hamming,
|
|
4350
|
+
hann: () => hann,
|
|
4351
|
+
heaviside: () => heaviside,
|
|
4148
4352
|
hstack: () => hstack,
|
|
4149
4353
|
hypot: () => hypot,
|
|
4150
4354
|
identity: () => identity$1,
|
|
@@ -4784,6 +4988,32 @@ function sign(x) {
|
|
|
4784
4988
|
x = fudgeArray(x);
|
|
4785
4989
|
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4786
4990
|
}
|
|
4991
|
+
/**
 * Return the Hamming window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * For `M === 1` the formula divides by `M - 1 = 0`. Like NumPy and JAX, a
 * single-point window of `[1]` is returned in that case.
 *
 * @param M - Number of points in the output window.
 * @returns The window, a 1-D array of length `M`.
 */
function hamming(M) {
	// The sample grid 2πn/(M-1) is degenerate for a single point.
	if (M === 1) return array([1]);
	return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
}
|
|
4999
|
+
/**
 * Return the Hann window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * For `M === 1` the formula divides by `M - 1 = 0`. Like NumPy and JAX, a
 * single-point window of `[1]` is returned in that case.
 *
 * @param M - Number of points in the output window.
 * @returns The window, a 1-D array of length `M`.
 */
function hann(M) {
	// The sample grid 2πn/(M-1) is degenerate for a single point.
	if (M === 1) return array([1]);
	return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
}
|
|
5007
|
+
/**
 * @function
 * Compute the Heaviside step function, element-wise and piecewise:
 * - `heaviside(x1, x2) = 0` for `x1 < 0`,
 * - `heaviside(x1, x2) = x2` for `x1 == 0`,
 * - `heaviside(x1, x2) = 1` for `x1 > 0`.
 */
const heaviside = jit$1(function heaviside$1(x1, x2) {
	// `.ref` so that `x1` is still usable by the equality test below.
	const isNegative = less(x1.ref, 0);
	// Value for non-negative inputs: `x2` exactly at zero, otherwise 1.
	const nonNegativeValue = where(equal(x1, 0), x2, 1);
	return where(isNegative, 0, nonNegativeValue);
});
|
|
4787
5017
|
/** Calculate element-wise square of the input array. */
|
|
4788
5018
|
function square(x) {
|
|
4789
5019
|
x = fudgeArray(x);
|
|
@@ -4803,8 +5033,8 @@ function acos(x) {
|
|
|
4803
5033
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4804
5034
|
*
|
|
4805
5035
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4806
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
4807
|
-
* improvements.
|
|
5036
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
5037
|
+
* stability improvements.
|
|
4808
5038
|
*/
|
|
4809
5039
|
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4810
5040
|
return sqrt(square(x1).add(square(x2)));
|
|
@@ -5128,18 +5358,20 @@ function celu(x, alpha = 1) {
|
|
|
5128
5358
|
* @function
|
|
5129
5359
|
* Gaussian error linear unit (GELU) activation function.
|
|
5130
5360
|
*
|
|
5131
|
-
* This is computed element-wise.
|
|
5132
|
-
*
|
|
5133
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5361
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
5362
|
+
* `approximate` is set (default true):
|
|
5134
5363
|
*
|
|
5135
|
-
*
|
|
5364
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
5365
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
5136
5366
|
*
|
|
5137
|
-
*
|
|
5367
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5138
5368
|
*/
|
|
5139
|
-
const gelu = jit$1(function gelu$1(x, opts) {
	// `approximate` defaults to true when `opts` or the field is omitted.
	if (opts?.approximate ?? true) {
		// tanh form: 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2))).
		const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
		return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
	}
	// Exact form: 0.5 * x * erfc(-x / sqrt(2)).
	// The last use of `x` consumes it (no `.ref`), as in the branch above.
	// Taking a second `.ref` here would leave one reference to `x` unreleased.
	return x.ref.mul(.5).mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
}, { staticArgnums: [1] });
|
|
5143
5375
|
/**
|
|
5144
5376
|
* Gated linear unit (GLU) activation function.
|
|
5145
5377
|
*
|
|
@@ -5360,6 +5592,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
5360
5592
|
return radius.mul(cos(theta));
|
|
5361
5593
|
}, { staticArgnums: [1] });
|
|
5362
5594
|
|
|
5595
|
+
//#endregion
|
|
5596
|
+
//#region src/scipy-special.ts
|
|
5597
|
+
// Namespace object exposed publicly as `scipySpecial`. Each entry maps a
// public name to a thunk returning the corresponding module-level binding;
// `__export` (from the shared chunk) installs them on the object.
var scipy_special_exports = {};
__export(scipy_special_exports, {
	erf: () => erf,
	erfc: () => erfc,
	logSoftmax: () => logSoftmax,
	logit: () => logit,
	logsumexp: () => logsumexp,
	softmax: () => softmax
});
|
|
5606
|
+
/**
 * @function
 * The logit (log-odds) function, `logit(p) = log(p / (1-p))`, applied
 * element-wise.
 */
const logit = jit$1(function logit$1(x) {
	// `x.ref` keeps `x` available as the numerator after `subtract` uses it.
	const oneMinusP = subtract(1, x.ref);
	return log(x.div(oneMinusP));
});
|
|
5613
|
+
|
|
5363
5614
|
//#endregion
|
|
5364
5615
|
//#region src/polyfills.ts
|
|
5365
5616
|
/** @file Polyfills for using this library. */
|
|
@@ -5453,6 +5704,25 @@ async function blockUntilReady(x) {
|
|
|
5453
5704
|
await Promise.all(promises);
|
|
5454
5705
|
return x;
|
|
5455
5706
|
}
|
|
5707
|
+
/**
 * Transfer `x` to `device`.
 *
 * `x` may be a nested container of arrays or scalars. It is flattened, every
 * leaf is placed independently (all placements are awaited together with
 * `Promise.all`), and the original structure is rebuilt from the results. The
 * resulting structure is committed to the device.
 *
 * - `Array` leaves are moved onto the backend for `device` via `_put`. When
 *   `device` is not given, they are returned unchanged (identity).
 * - Any other leaf goes through `array(leaf, { device })`, which places a
 *   scalar uncommitted on the default device when `device` is not given.
 */
async function devicePut(x, device) {
	const [leaves, treedef] = flatten(x);
	const pending = leaves.map((leaf) => {
		if (!(leaf instanceof Array$1)) return Promise.resolve(array(leaf, { device }));
		if (!device) return Promise.resolve(leaf);
		return leaf._put(getBackend(device));
	});
	const placed = await Promise.all(pending);
	return unflatten(treedef, placed);
}
|
|
5456
5725
|
|
|
5457
5726
|
//#endregion
|
|
5458
|
-
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
5727
|
+
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
5728
|
+
//# sourceMappingURL=index.js.map
|