@jax-js/jax 0.0.5 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-yEU0L_ig.cjs → backend-BbrKEB18.cjs} +378 -183
- package/dist/{backend-CdcTZEOF.js → backend-CoVtc9dx.js} +366 -177
- package/dist/index.cjs +385 -74
- package/dist/index.d.cts +115 -23
- package/dist/index.d.ts +115 -23
- package/dist/index.js +378 -74
- package/dist/{webgpu-CM-xNYzW.js → webgpu-B3UVme6n.js} +188 -153
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-DGYNVHma.cjs} +188 -153
- package/package.json +25 -15
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-BbrKEB18.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
@@ -60,6 +60,10 @@ var JsTreeDef = class JsTreeDef {
|
|
|
60
60
|
this.nodeMetadata = nodeMetadata;
|
|
61
61
|
this.childTreedefs = childTreedefs;
|
|
62
62
|
}
|
|
63
|
+
/** Get the total number of leaves in the tree. */
|
|
64
|
+
get size() {
|
|
65
|
+
return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
|
|
66
|
+
}
|
|
63
67
|
/** Returns a string representation of this tree definition. */
|
|
64
68
|
toString(root = true) {
|
|
65
69
|
if (root) return "JsTreeDef(" + this.toString(false) + ")";
|
|
@@ -215,6 +219,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
|
|
|
215
219
|
const s_ = strides;
|
|
216
220
|
const d_ = dilation;
|
|
217
221
|
const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
222
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
223
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
|
|
224
|
+
st = st.reshape([...noop, ...require_backend.zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...require_backend.zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
|
|
225
|
+
st = st.permute([
|
|
226
|
+
...require_backend.range(noop.length),
|
|
227
|
+
...ks.map((_, j) => noop.length + 2 * j),
|
|
228
|
+
...ks.map((_, j) => noop.length + 2 * j + 1)
|
|
229
|
+
]);
|
|
230
|
+
return st;
|
|
231
|
+
}
|
|
218
232
|
const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
219
233
|
const kidf = require_backend.zipn(ks, i_, d_, f_);
|
|
220
234
|
st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
|
|
@@ -249,6 +263,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
249
263
|
const s_ = strides;
|
|
250
264
|
const d_ = dilation;
|
|
251
265
|
const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
266
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
267
|
+
st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
|
|
268
|
+
st = st.pad([...noop.map(() => [0, 0]), ...require_backend.zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...require_backend.zip(o_, s_).map(([o, s]) => o * s)]);
|
|
269
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
|
|
270
|
+
return st.reshape(st.shape.concat(require_backend.rep(ks.length, 1)));
|
|
271
|
+
}
|
|
252
272
|
if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
|
|
253
273
|
const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
254
274
|
const kidf = require_backend.zipn(ks, i_, d_, f_);
|
|
@@ -358,6 +378,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
358
378
|
Primitive$1["Atan"] = "atan";
|
|
359
379
|
Primitive$1["Exp"] = "exp";
|
|
360
380
|
Primitive$1["Log"] = "log";
|
|
381
|
+
Primitive$1["Erf"] = "erf";
|
|
382
|
+
Primitive$1["Erfc"] = "erfc";
|
|
361
383
|
Primitive$1["Sqrt"] = "sqrt";
|
|
362
384
|
Primitive$1["Min"] = "min";
|
|
363
385
|
Primitive$1["Max"] = "max";
|
|
@@ -379,11 +401,9 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
379
401
|
return Primitive$1;
|
|
380
402
|
}({});
|
|
381
403
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
382
|
-
CompareOp$1["Greater"] = "greater";
|
|
383
404
|
CompareOp$1["Less"] = "less";
|
|
384
405
|
CompareOp$1["Equal"] = "equal";
|
|
385
406
|
CompareOp$1["NotEqual"] = "not_equal";
|
|
386
|
-
CompareOp$1["GreaterEqual"] = "greater_equal";
|
|
387
407
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
388
408
|
return CompareOp$1;
|
|
389
409
|
}({});
|
|
@@ -435,6 +455,12 @@ function exp$1(x) {
|
|
|
435
455
|
function log$1(x) {
|
|
436
456
|
return bind1(Primitive.Log, [x]);
|
|
437
457
|
}
|
|
458
|
+
function erf$1(x) {
|
|
459
|
+
return bind1(Primitive.Erf, [x]);
|
|
460
|
+
}
|
|
461
|
+
function erfc$1(x) {
|
|
462
|
+
return bind1(Primitive.Erfc, [x]);
|
|
463
|
+
}
|
|
438
464
|
function sqrt$1(x) {
|
|
439
465
|
return bind1(Primitive.Sqrt, [x]);
|
|
440
466
|
}
|
|
@@ -473,7 +499,7 @@ function compare(x, y, op) {
|
|
|
473
499
|
return bind1(Primitive.Compare, [x, y], { op });
|
|
474
500
|
}
|
|
475
501
|
function greater$1(x, y) {
|
|
476
|
-
return compare(
|
|
502
|
+
return compare(y, x, CompareOp.Less);
|
|
477
503
|
}
|
|
478
504
|
function less$1(x, y) {
|
|
479
505
|
return compare(x, y, CompareOp.Less);
|
|
@@ -485,7 +511,7 @@ function notEqual$1(x, y) {
|
|
|
485
511
|
return compare(x, y, CompareOp.NotEqual);
|
|
486
512
|
}
|
|
487
513
|
function greaterEqual$1(x, y) {
|
|
488
|
-
return compare(
|
|
514
|
+
return compare(y, x, CompareOp.LessEqual);
|
|
489
515
|
}
|
|
490
516
|
function lessEqual$1(x, y) {
|
|
491
517
|
return compare(x, y, CompareOp.LessEqual);
|
|
@@ -1177,12 +1203,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1177
1203
|
} else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
1178
1204
|
});
|
|
1179
1205
|
}
|
|
1180
|
-
function broadcastedJit(fn) {
|
|
1206
|
+
function broadcastedJit(fn, opts) {
|
|
1181
1207
|
return (nargs, exps, avals, params) => {
|
|
1182
|
-
|
|
1183
|
-
|
|
1184
|
-
|
|
1185
|
-
|
|
1208
|
+
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1209
|
+
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1210
|
+
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1211
|
+
exps = exps.map((exp$3, i) => {
|
|
1212
|
+
exp$3 = reshapeViews(exp$3, (st) => {
|
|
1213
|
+
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1214
|
+
});
|
|
1215
|
+
if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
|
|
1216
|
+
return exp$3;
|
|
1217
|
+
});
|
|
1186
1218
|
const exp$2 = fn(exps, params);
|
|
1187
1219
|
return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
|
|
1188
1220
|
};
|
|
@@ -1225,6 +1257,8 @@ const jitRules = {
|
|
|
1225
1257
|
[Primitive.Atan]: unopJit(require_backend.AluExp.atan),
|
|
1226
1258
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
1227
1259
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1260
|
+
[Primitive.Erf]: unopJit(require_backend.AluExp.erf),
|
|
1261
|
+
[Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
|
|
1228
1262
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1229
1263
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
1230
1264
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
@@ -1272,7 +1306,7 @@ const jitRules = {
|
|
|
1272
1306
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1273
1307
|
},
|
|
1274
1308
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1275
|
-
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
|
|
1309
|
+
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1276
1310
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
1277
1311
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
1278
1312
|
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
@@ -1443,7 +1477,7 @@ var PendingExecute = class {
|
|
|
1443
1477
|
/**
|
|
1444
1478
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1445
1479
|
*
|
|
1446
|
-
* This is the library's core data type. Equivalent to `
|
|
1480
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1447
1481
|
* `torch.Tensor`.
|
|
1448
1482
|
*
|
|
1449
1483
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -1458,6 +1492,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1458
1492
|
#source;
|
|
1459
1493
|
#st;
|
|
1460
1494
|
#backend;
|
|
1495
|
+
#committed;
|
|
1461
1496
|
#rc;
|
|
1462
1497
|
#pendingSet;
|
|
1463
1498
|
/**
|
|
@@ -1474,6 +1509,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1474
1509
|
this.#source = args.source;
|
|
1475
1510
|
this.#st = args.st;
|
|
1476
1511
|
this.#backend = args.backend;
|
|
1512
|
+
this.#committed = args.committed;
|
|
1477
1513
|
this.#rc = 1;
|
|
1478
1514
|
this.#pendingSet = new Set(args.pending);
|
|
1479
1515
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
@@ -1501,6 +1537,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1501
1537
|
dtype: args.dtype ?? this.#dtype,
|
|
1502
1538
|
weakType: this.#weakType,
|
|
1503
1539
|
backend: args.backend ?? this.#backend,
|
|
1540
|
+
committed: args.committed ?? this.#committed,
|
|
1504
1541
|
pending: args.pending ?? this.#pending ?? void 0
|
|
1505
1542
|
});
|
|
1506
1543
|
}
|
|
@@ -1556,9 +1593,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1556
1593
|
*/
|
|
1557
1594
|
#gather(indices, axis, outDim) {
|
|
1558
1595
|
this.#check();
|
|
1559
|
-
if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1560
1596
|
const axisSet = new Set(axis);
|
|
1561
1597
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
|
|
1598
|
+
if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1599
|
+
indices = indices.map((ar) => ar._putSync(this.#backend));
|
|
1562
1600
|
indices = Array$1.#broadcastArrays(indices);
|
|
1563
1601
|
const indexShape = indices[0].shape;
|
|
1564
1602
|
const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1627,6 +1665,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1627
1665
|
this.#check();
|
|
1628
1666
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1629
1667
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1668
|
+
this.dispose();
|
|
1630
1669
|
return this.#newArrayFrom({
|
|
1631
1670
|
source: exp$3.simplify(),
|
|
1632
1671
|
dtype: dtypeOutput,
|
|
@@ -1655,21 +1694,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1655
1694
|
}
|
|
1656
1695
|
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1657
1696
|
const n = arrays.length;
|
|
1658
|
-
const backend = arrays[0].#backend;
|
|
1659
1697
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1660
1698
|
for (const ar of arrays) ar.#check();
|
|
1661
1699
|
let castDtype;
|
|
1662
1700
|
let castWeakType = true;
|
|
1663
|
-
for (let i = 0; i < n; i++) {
|
|
1664
|
-
if (dtypeOverride
|
|
1665
|
-
|
|
1666
|
-
|
|
1667
|
-
|
|
1668
|
-
|
|
1669
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1670
|
-
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1671
|
-
}
|
|
1701
|
+
for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
|
|
1702
|
+
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1703
|
+
} else if (castDtype === void 0) {
|
|
1704
|
+
castDtype = arrays[i].#dtype;
|
|
1705
|
+
castWeakType = arrays[i].#weakType;
|
|
1706
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1672
1707
|
const weakType = castWeakType && !strongTypeOutput;
|
|
1708
|
+
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
1709
|
+
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
1673
1710
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1674
1711
|
const newShape = [...arrays[0].shape];
|
|
1675
1712
|
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
@@ -1679,12 +1716,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1679
1716
|
});
|
|
1680
1717
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1681
1718
|
const exp$4 = custom(sources);
|
|
1719
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1682
1720
|
return new Array$1({
|
|
1683
1721
|
source: exp$4.simplify(),
|
|
1684
1722
|
st: arrays[0].#st,
|
|
1685
1723
|
dtype: exp$4.dtype,
|
|
1686
1724
|
weakType,
|
|
1687
|
-
backend
|
|
1725
|
+
backend,
|
|
1726
|
+
committed
|
|
1688
1727
|
});
|
|
1689
1728
|
}
|
|
1690
1729
|
const exp$3 = custom(arrays.map((ar, i) => {
|
|
@@ -1693,12 +1732,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1693
1732
|
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1694
1733
|
}));
|
|
1695
1734
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1735
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1696
1736
|
return new Array$1({
|
|
1697
1737
|
source: exp$3.simplify(),
|
|
1698
1738
|
st,
|
|
1699
1739
|
dtype: exp$3.dtype,
|
|
1700
1740
|
weakType,
|
|
1701
|
-
backend
|
|
1741
|
+
backend,
|
|
1742
|
+
committed
|
|
1702
1743
|
});
|
|
1703
1744
|
}
|
|
1704
1745
|
let indices;
|
|
@@ -1734,13 +1775,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1734
1775
|
const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
|
|
1735
1776
|
for (const exe of pending) exe.updateRc(1);
|
|
1736
1777
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1737
|
-
|
|
1778
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1738
1779
|
return new Array$1({
|
|
1739
1780
|
source: output,
|
|
1740
1781
|
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1741
1782
|
dtype: kernel.dtype,
|
|
1742
1783
|
weakType,
|
|
1743
1784
|
backend,
|
|
1785
|
+
committed,
|
|
1744
1786
|
pending
|
|
1745
1787
|
});
|
|
1746
1788
|
}
|
|
@@ -1818,6 +1860,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1818
1860
|
return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
|
|
1819
1861
|
});
|
|
1820
1862
|
}
|
|
1863
|
+
static #computeBackend(name, arrays) {
|
|
1864
|
+
const committed = arrays.filter((ar) => ar.#committed);
|
|
1865
|
+
if (committed.length > 0) {
|
|
1866
|
+
const backend = committed[0].#backend;
|
|
1867
|
+
for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
|
|
1868
|
+
return {
|
|
1869
|
+
backend,
|
|
1870
|
+
committed: true
|
|
1871
|
+
};
|
|
1872
|
+
} else {
|
|
1873
|
+
const backend = arrays.length > 0 ? arrays[0].#backend : require_backend.getBackend();
|
|
1874
|
+
return {
|
|
1875
|
+
backend,
|
|
1876
|
+
committed: false
|
|
1877
|
+
};
|
|
1878
|
+
}
|
|
1879
|
+
}
|
|
1821
1880
|
/** Realize the array and return it as data. */
|
|
1822
1881
|
async data() {
|
|
1823
1882
|
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
@@ -1977,6 +2036,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1977
2036
|
[Primitive.Log]([x]) {
|
|
1978
2037
|
return [x.#unary(require_backend.AluOp.Log)];
|
|
1979
2038
|
},
|
|
2039
|
+
[Primitive.Erf]([x]) {
|
|
2040
|
+
return [x.#unary(require_backend.AluOp.Erf)];
|
|
2041
|
+
},
|
|
2042
|
+
[Primitive.Erfc]([x]) {
|
|
2043
|
+
return [x.#unary(require_backend.AluOp.Erfc)];
|
|
2044
|
+
},
|
|
1980
2045
|
[Primitive.Sqrt]([x]) {
|
|
1981
2046
|
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
1982
2047
|
},
|
|
@@ -2045,7 +2110,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2045
2110
|
},
|
|
2046
2111
|
[Primitive.JitCall](args, { jaxpr, numConsts }) {
|
|
2047
2112
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
2048
|
-
const backend =
|
|
2113
|
+
const { backend, committed } = Array$1.#computeBackend("jit_call", args);
|
|
2114
|
+
args = args.map((ar) => ar._putSync(backend));
|
|
2049
2115
|
const consts = args.slice(0, numConsts);
|
|
2050
2116
|
const tracers = args.slice(numConsts);
|
|
2051
2117
|
const jp = jitCompile(backend, jaxpr, consts);
|
|
@@ -2062,16 +2128,54 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2062
2128
|
dtype: jaxpr.outs[i].aval.dtype,
|
|
2063
2129
|
weakType: jaxpr.outs[i].aval.weakType,
|
|
2064
2130
|
backend,
|
|
2131
|
+
committed,
|
|
2065
2132
|
pending
|
|
2066
2133
|
});
|
|
2067
2134
|
});
|
|
2068
2135
|
}
|
|
2069
2136
|
};
|
|
2070
2137
|
}
|
|
2138
|
+
/** @private */
|
|
2071
2139
|
_realizeSource() {
|
|
2072
2140
|
this.#realize();
|
|
2073
2141
|
return this.#source;
|
|
2074
2142
|
}
|
|
2143
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
2144
|
+
async _put(backend) {
|
|
2145
|
+
if (this.#backend === backend) return this;
|
|
2146
|
+
if (this.#source instanceof require_backend.AluExp) {
|
|
2147
|
+
const ar = this.#newArrayFrom({
|
|
2148
|
+
backend,
|
|
2149
|
+
committed: true
|
|
2150
|
+
});
|
|
2151
|
+
this.dispose();
|
|
2152
|
+
return ar;
|
|
2153
|
+
} else {
|
|
2154
|
+
const data = await this.data();
|
|
2155
|
+
return arrayFromData(data, this.shape, {
|
|
2156
|
+
dtype: this.#dtype,
|
|
2157
|
+
device: backend.type
|
|
2158
|
+
}, this.#weakType);
|
|
2159
|
+
}
|
|
2160
|
+
}
|
|
2161
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
2162
|
+
_putSync(backend) {
|
|
2163
|
+
if (this.#backend === backend) return this;
|
|
2164
|
+
if (this.#source instanceof require_backend.AluExp) {
|
|
2165
|
+
const ar = this.#newArrayFrom({
|
|
2166
|
+
backend,
|
|
2167
|
+
committed: true
|
|
2168
|
+
});
|
|
2169
|
+
this.dispose();
|
|
2170
|
+
return ar;
|
|
2171
|
+
} else {
|
|
2172
|
+
const data = this.dataSync();
|
|
2173
|
+
return arrayFromData(data, this.shape, {
|
|
2174
|
+
dtype: this.#dtype,
|
|
2175
|
+
device: backend.type
|
|
2176
|
+
}, this.#weakType);
|
|
2177
|
+
}
|
|
2178
|
+
}
|
|
2075
2179
|
};
|
|
2076
2180
|
/** Constructor for creating a new array from data. */
|
|
2077
2181
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
@@ -2134,6 +2238,9 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2134
2238
|
} else if (data instanceof Float16Array) {
|
|
2135
2239
|
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2136
2240
|
dtype ??= require_backend.DType.Float16;
|
|
2241
|
+
} else if (data instanceof Float64Array) {
|
|
2242
|
+
if (dtype && dtype !== require_backend.DType.Float64) throw new Error("Float64Array must have float64 type");
|
|
2243
|
+
dtype ??= require_backend.DType.Float64;
|
|
2137
2244
|
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2138
2245
|
if (data.length < inlineArrayLimit) {
|
|
2139
2246
|
let allEqual = true;
|
|
@@ -2154,7 +2261,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2154
2261
|
st: require_backend.ShapeTracker.fromShape(shape$1),
|
|
2155
2262
|
dtype,
|
|
2156
2263
|
weakType,
|
|
2157
|
-
backend
|
|
2264
|
+
backend,
|
|
2265
|
+
committed: device != void 0
|
|
2158
2266
|
});
|
|
2159
2267
|
}
|
|
2160
2268
|
function dataToJs(dtype, data, shape$1) {
|
|
@@ -2188,7 +2296,8 @@ function fullInternal(aval, fillValue, device) {
|
|
|
2188
2296
|
st: require_backend.ShapeTracker.fromShape(aval.shape),
|
|
2189
2297
|
dtype: aval.dtype,
|
|
2190
2298
|
weakType: aval.weakType,
|
|
2191
|
-
backend: require_backend.getBackend(device)
|
|
2299
|
+
backend: require_backend.getBackend(device),
|
|
2300
|
+
committed: device != void 0
|
|
2192
2301
|
});
|
|
2193
2302
|
}
|
|
2194
2303
|
function zerosLike$1(val, dtype) {
|
|
@@ -2256,7 +2365,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2256
2365
|
st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
|
|
2257
2366
|
dtype,
|
|
2258
2367
|
weakType,
|
|
2259
|
-
backend: require_backend.getBackend(device)
|
|
2368
|
+
backend: require_backend.getBackend(device),
|
|
2369
|
+
committed: device != void 0
|
|
2260
2370
|
});
|
|
2261
2371
|
}
|
|
2262
2372
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
@@ -2299,7 +2409,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2299
2409
|
st,
|
|
2300
2410
|
dtype,
|
|
2301
2411
|
weakType: false,
|
|
2302
|
-
backend: require_backend.getBackend(device)
|
|
2412
|
+
backend: require_backend.getBackend(device),
|
|
2413
|
+
committed: device != void 0
|
|
2303
2414
|
});
|
|
2304
2415
|
}
|
|
2305
2416
|
/**
|
|
@@ -2335,16 +2446,15 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2335
2446
|
st,
|
|
2336
2447
|
dtype,
|
|
2337
2448
|
weakType: false,
|
|
2338
|
-
backend: require_backend.getBackend(device)
|
|
2449
|
+
backend: require_backend.getBackend(device),
|
|
2450
|
+
committed: device != void 0
|
|
2339
2451
|
});
|
|
2340
2452
|
}
|
|
2341
2453
|
function aluCompare(a, b, op) {
|
|
2342
2454
|
switch (op) {
|
|
2343
|
-
case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
|
|
2344
2455
|
case CompareOp.Less: return require_backend.AluExp.cmplt(a, b);
|
|
2345
2456
|
case CompareOp.Equal: return require_backend.AluExp.cmpne(a, b).not();
|
|
2346
2457
|
case CompareOp.NotEqual: return require_backend.AluExp.cmpne(a, b);
|
|
2347
|
-
case CompareOp.GreaterEqual: return require_backend.AluExp.cmplt(a, b).not();
|
|
2348
2458
|
case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
|
|
2349
2459
|
}
|
|
2350
2460
|
}
|
|
@@ -2481,7 +2591,7 @@ var JaxprEqn = class {
|
|
|
2481
2591
|
const paramsList = Object.entries(this.params).map(([k, v]) => require_backend.PPrint.pp(`${k}=${v}`));
|
|
2482
2592
|
if (paramsList.length > 0) rhs = rhs.stack(require_backend.PPrint.pp(" [ ")).stack(require_backend.PPrint.prototype.concat(...paramsList)).stack(require_backend.PPrint.pp(" ] "));
|
|
2483
2593
|
else rhs = rhs.stack(require_backend.PPrint.pp(" "));
|
|
2484
|
-
rhs = rhs.stack(require_backend.PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) :
|
|
2594
|
+
rhs = rhs.stack(require_backend.PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : String(x.value)).join(" ")));
|
|
2485
2595
|
return lhs.stack(require_backend.PPrint.pp(" = ")).stack(rhs);
|
|
2486
2596
|
}
|
|
2487
2597
|
toString() {
|
|
@@ -2847,6 +2957,8 @@ const abstractEvalRules = {
|
|
|
2847
2957
|
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2848
2958
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2849
2959
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2960
|
+
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
2961
|
+
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
2850
2962
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2851
2963
|
[Primitive.Min]: binopAbstractEval,
|
|
2852
2964
|
[Primitive.Max]: binopAbstractEval,
|
|
@@ -3100,6 +3212,16 @@ const jvpRules = {
|
|
|
3100
3212
|
[Primitive.Log]([x], [dx]) {
|
|
3101
3213
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3102
3214
|
},
|
|
3215
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3216
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3217
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3218
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3219
|
+
},
|
|
3220
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3221
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3222
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3223
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3224
|
+
},
|
|
3103
3225
|
[Primitive.Sqrt]([x], [dx]) {
|
|
3104
3226
|
const z = sqrt$1(x);
|
|
3105
3227
|
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
@@ -3262,6 +3384,10 @@ var BatchTrace = class extends Trace {
|
|
|
3262
3384
|
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3263
3385
|
const vmapRule = vmapRules[primitive];
|
|
3264
3386
|
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3387
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3388
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3389
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3390
|
+
}
|
|
3265
3391
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3266
3392
|
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3267
3393
|
}
|
|
@@ -3269,24 +3395,28 @@ var BatchTrace = class extends Trace {
|
|
|
3269
3395
|
return this.main.globalData;
|
|
3270
3396
|
}
|
|
3271
3397
|
};
|
|
3272
|
-
|
|
3273
|
-
|
|
3274
|
-
|
|
3275
|
-
|
|
3276
|
-
|
|
3277
|
-
return broadcast(x, shape$1, axis);
|
|
3278
|
-
}
|
|
3279
|
-
}
|
|
3280
|
-
/** Process a primitive with built-in broadcasting. */
|
|
3398
|
+
/**
|
|
3399
|
+
* Process a primitive with built-in broadcasting.
|
|
3400
|
+
*
|
|
3401
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3402
|
+
*/
|
|
3281
3403
|
function broadcastBatcher(op) {
|
|
3282
3404
|
return (axisSize, args, dims) => {
|
|
3283
3405
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3284
|
-
const
|
|
3285
|
-
|
|
3286
|
-
|
|
3287
|
-
args
|
|
3288
|
-
|
|
3289
|
-
|
|
3406
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3407
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3408
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3409
|
+
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3410
|
+
args = args.map((x, i) => {
|
|
3411
|
+
if (dims[i] === null) return x;
|
|
3412
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3413
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3414
|
+
x.shape[0],
|
|
3415
|
+
...require_backend.rep(nd - x.ndim, 1),
|
|
3416
|
+
...x.shape.slice(1)
|
|
3417
|
+
]);
|
|
3418
|
+
return x;
|
|
3419
|
+
});
|
|
3290
3420
|
return [[op(...args)], [0]];
|
|
3291
3421
|
};
|
|
3292
3422
|
}
|
|
@@ -3310,17 +3440,18 @@ const vmapRules = {
|
|
|
3310
3440
|
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3311
3441
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3312
3442
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3443
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3444
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3313
3445
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3314
3446
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3315
3447
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3316
3448
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3317
|
-
|
|
3449
|
+
require_backend.assertNonNull(xBdim);
|
|
3318
3450
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3319
3451
|
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3320
3452
|
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3321
3453
|
},
|
|
3322
3454
|
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3323
|
-
if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
|
|
3324
3455
|
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3325
3456
|
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3326
3457
|
const z = dot$1(x, y);
|
|
@@ -3329,26 +3460,68 @@ const vmapRules = {
|
|
|
3329
3460
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3330
3461
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3331
3462
|
},
|
|
3463
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3464
|
+
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3465
|
+
require_backend.assertNonNull(xBdim);
|
|
3466
|
+
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
3467
|
+
newPerm.splice(xBdim, 0, xBdim);
|
|
3468
|
+
return [[transpose$1(x, newPerm)], [xBdim]];
|
|
3469
|
+
},
|
|
3470
|
+
[Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
|
|
3471
|
+
require_backend.assertNonNull(xBdim);
|
|
3472
|
+
const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
|
|
3473
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3474
|
+
return [[broadcast(x, newShape, newAxis)], [xBdim]];
|
|
3475
|
+
},
|
|
3332
3476
|
[Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
|
|
3333
|
-
if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
|
|
3334
3477
|
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3335
3478
|
return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
|
|
3336
3479
|
},
|
|
3337
3480
|
[Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
|
|
3338
|
-
|
|
3481
|
+
require_backend.assertNonNull(xBdim);
|
|
3339
3482
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3340
3483
|
return [[flip$1(x, newAxis)], [xBdim]];
|
|
3341
3484
|
},
|
|
3342
3485
|
[Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
|
|
3343
|
-
|
|
3486
|
+
require_backend.assertNonNull(xBdim);
|
|
3344
3487
|
const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
|
|
3345
3488
|
return [[shrink(x, newSlice)], [xBdim]];
|
|
3346
3489
|
},
|
|
3347
3490
|
[Primitive.Pad](axisSize, [x], [xBdim], { width }) {
|
|
3348
|
-
|
|
3491
|
+
require_backend.assertNonNull(xBdim);
|
|
3349
3492
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3350
3493
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3351
3494
|
},
|
|
3495
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3496
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3497
|
+
require_backend.assertNonNull(xBdim);
|
|
3498
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3499
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3500
|
+
let newOutDim = outDim;
|
|
3501
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3502
|
+
else newOutDim += 1;
|
|
3503
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3504
|
+
}
|
|
3505
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3506
|
+
indices = indices.map((m, i) => {
|
|
3507
|
+
if (indicesBdim[i] === null) return m;
|
|
3508
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3509
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3510
|
+
m.shape[0],
|
|
3511
|
+
...require_backend.rep(nd - m.ndim, 1),
|
|
3512
|
+
...m.shape.slice(1)
|
|
3513
|
+
]);
|
|
3514
|
+
return m;
|
|
3515
|
+
});
|
|
3516
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3517
|
+
else {
|
|
3518
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3519
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3520
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
|
|
3521
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3522
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3523
|
+
}
|
|
3524
|
+
},
|
|
3352
3525
|
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3353
3526
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3354
3527
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
@@ -3408,12 +3581,14 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3408
3581
|
function vmap$1(f, inAxes = 0) {
|
|
3409
3582
|
return (...args) => {
|
|
3410
3583
|
const [argsFlat, inTree] = flatten(args);
|
|
3411
|
-
let inAxesFlat;
|
|
3584
|
+
let inAxesFlat = [];
|
|
3412
3585
|
if (typeof inAxes === "number") inAxesFlat = require_backend.rep(argsFlat.length, inAxes);
|
|
3586
|
+
else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, null));
|
|
3587
|
+
else if (typeof inAxes[i] === "number") inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, inAxes[i]));
|
|
3413
3588
|
else {
|
|
3414
|
-
|
|
3415
|
-
[
|
|
3416
|
-
|
|
3589
|
+
const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
|
|
3590
|
+
if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
|
|
3591
|
+
inAxesFlat.push(...axesFlat);
|
|
3417
3592
|
}
|
|
3418
3593
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3419
3594
|
const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
|
|
@@ -4033,7 +4208,7 @@ function valueAndGrad$1(f) {
|
|
|
4033
4208
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
4034
4209
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
4035
4210
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
4036
|
-
const [ct, ...rest] = fVjp(
|
|
4211
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4037
4212
|
for (const r of rest) dispose(r);
|
|
4038
4213
|
fVjp.dispose();
|
|
4039
4214
|
return [y, ct];
|
|
@@ -4061,7 +4236,10 @@ __export(lax_exports, {
|
|
|
4061
4236
|
conv: () => conv$1,
|
|
4062
4237
|
convGeneralDilated: () => convGeneralDilated,
|
|
4063
4238
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4064
|
-
|
|
4239
|
+
erf: () => erf,
|
|
4240
|
+
erfc: () => erfc,
|
|
4241
|
+
reduceWindow: () => reduceWindow,
|
|
4242
|
+
stopGradient: () => stopGradient$1
|
|
4065
4243
|
});
|
|
4066
4244
|
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4067
4245
|
const padType = padding.toUpperCase();
|
|
@@ -4120,6 +4298,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
|
4120
4298
|
strides: windowStrides
|
|
4121
4299
|
}));
|
|
4122
4300
|
}
|
|
4301
|
+
/**
 * The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`.
 *
 * Public wrapper that forwards `x` unchanged to the internal `erf$1`.
 */
function erf(x) {
  const result = erf$1(x);
  return result;
}
|
|
4305
|
+
/**
 * The complementary error function: `erfc(x) = 1 - erf(x)`.
 *
 * Prefer this over computing `1 - erf(x)` directly: for large `x`, `erf(x)` is
 * very close to 1 and this function is more accurate.
 *
 * Public wrapper that forwards `x` unchanged to the internal `erfc$1`.
 */
function erfc(x) {
  const result = erfc$1(x);
  return result;
}
|
|
4314
|
+
/**
 * Stops gradient computation.
 *
 * Acts as the identity function on its input, but blocks gradients from
 * flowing through it in forward- or reverse-mode automatic differentiation.
 *
 * Public wrapper that forwards `x` unchanged to the internal `stopGradient`.
 */
function stopGradient$1(x) {
  const result = stopGradient(x);
  return result;
}
|
|
4123
4323
|
|
|
4124
4324
|
//#endregion
|
|
4125
4325
|
//#region src/numpy.ts
|
|
@@ -4178,16 +4378,25 @@ __export(numpy_exports, {
|
|
|
4178
4378
|
flipud: () => flipud,
|
|
4179
4379
|
float16: () => float16,
|
|
4180
4380
|
float32: () => float32,
|
|
4381
|
+
float64: () => float64,
|
|
4181
4382
|
full: () => full,
|
|
4182
4383
|
fullLike: () => fullLike$1,
|
|
4183
4384
|
greater: () => greater,
|
|
4184
4385
|
greaterEqual: () => greaterEqual,
|
|
4386
|
+
hamming: () => hamming,
|
|
4387
|
+
hann: () => hann,
|
|
4388
|
+
heaviside: () => heaviside,
|
|
4185
4389
|
hstack: () => hstack,
|
|
4186
4390
|
hypot: () => hypot,
|
|
4187
4391
|
identity: () => identity$1,
|
|
4188
4392
|
inf: () => inf,
|
|
4189
4393
|
inner: () => inner,
|
|
4190
4394
|
int32: () => int32,
|
|
4395
|
+
isfinite: () => isfinite,
|
|
4396
|
+
isinf: () => isinf,
|
|
4397
|
+
isnan: () => isnan,
|
|
4398
|
+
isneginf: () => isneginf,
|
|
4399
|
+
isposinf: () => isposinf,
|
|
4191
4400
|
less: () => less,
|
|
4192
4401
|
lessEqual: () => lessEqual,
|
|
4193
4402
|
linspace: () => linspace,
|
|
@@ -4258,6 +4467,7 @@ const int32 = require_backend.DType.Int32;
|
|
|
4258
4467
|
const uint32 = require_backend.DType.Uint32;
|
|
4259
4468
|
const bool = require_backend.DType.Bool;
|
|
4260
4469
|
const float16 = require_backend.DType.Float16;
|
|
4470
|
+
const float64 = require_backend.DType.Float64;
|
|
4261
4471
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
4262
4472
|
const e = Math.E;
|
|
4263
4473
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -4821,6 +5031,32 @@ function sign(x) {
|
|
|
4821
5031
|
x = fudgeArray(x);
|
|
4822
5032
|
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4823
5033
|
}
|
|
5034
|
+
/**
 * Return the Hamming window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * The formula is undefined for `M === 1` (the phase divides by `M - 1`).
 * NumPy and JAX define the one-point window as `[1]`, so that case is
 * handled explicitly instead of evaluating the cosine term.
 *
 * @param M Number of points in the window.
 * @returns The window samples, with the shape `linspace` produces for `M` points.
 */
function hamming(M) {
  // Assuming `linspace` yields its start value for a single point (as NumPy's
  // does), the unguarded formula would give 0.54 - 0.46 * cos(0) = 0.08.
  // Dropping the cosine weight and using offset 1 yields the expected [1].
  const [cosineWeight, offset] = M === 1 ? [0, 1] : [-.46, .54];
  return cos(linspace(0, 2 * Math.PI, M)).mul(cosineWeight).add(offset);
}
|
|
5042
|
+
/**
 * Return the Hann window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * The formula is undefined for `M === 1` (the phase divides by `M - 1`).
 * NumPy and JAX define the one-point window as `[1]`, so that case is
 * handled explicitly instead of evaluating the cosine term.
 *
 * @param M Number of points in the window.
 * @returns The window samples, with the shape `linspace` produces for `M` points.
 */
function hann(M) {
  // Assuming `linspace` yields its start value for a single point (as NumPy's
  // does), the unguarded formula would give 0.5 - 0.5 * cos(0) = 0.
  // Dropping the cosine weight and using offset 1 yields the expected [1].
  const [cosineWeight, offset] = M === 1 ? [0, 1] : [-.5, .5];
  return cos(linspace(0, 2 * Math.PI, M)).mul(cosineWeight).add(offset);
}
|
|
5050
|
+
/**
 * @function
 * Compute the Heaviside step function, defined piecewise as:
 * - `heaviside(x1, x2) = 0` where `x1 < 0`,
 * - `heaviside(x1, x2) = x2` where `x1 == 0`,
 * - `heaviside(x1, x2) = 1` where `x1 > 0`.
 *
 * The function is wrapped with `jit$1`.
 */
const heaviside = jit$1(function heaviside$1(x1, x2) {
  const isNegative = less(x1.ref, 0);
  // Value used wherever x1 is not negative: x2 at exactly zero, otherwise 1.
  const zeroOrPositive = where(equal(x1, 0), x2, 1);
  return where(isNegative, 0, zeroOrPositive);
});
|
|
4824
5060
|
/** Calculate element-wise square of the input array. */
|
|
4825
5061
|
function square(x) {
|
|
4826
5062
|
x = fudgeArray(x);
|
|
@@ -4840,8 +5076,8 @@ function acos(x) {
|
|
|
4840
5076
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4841
5077
|
*
|
|
4842
5078
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4843
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
4844
|
-
* improvements.
|
|
5079
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
5080
|
+
* stability improvements.
|
|
4845
5081
|
*/
|
|
4846
5082
|
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4847
5083
|
return sqrt(square(x1).add(square(x2)));
|
|
@@ -5032,6 +5268,34 @@ function var_(x, axis = null, opts) {
|
|
|
5032
5268
|
function std(x, axis = null, opts) {
  // Square root of the value returned by `var_` for the same arguments.
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
|
|
5271
|
+
/**
 * Test element-wise for positive or negative infinity, return bool array.
 *
 * For non-float dtypes the result is `fullLike$1(x, false)`. For float dtypes
 * the equality tests against `Infinity` and `-Infinity` are combined with
 * `.add`.
 */
function isinf(x) {
  x = fudgeArray(x);
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  const isPositiveInf = x.ref.equal(Infinity);
  const isNegativeInf = x.equal(-Infinity);
  return isPositiveInf.add(isNegativeInf);
}
|
|
5276
|
+
/**
 * Test element-wise for NaN (Not a Number).
 *
 * For non-float dtypes the result is `fullLike$1(x, false)`. For float dtypes
 * an element is flagged when it compares not-equal to itself.
 */
function isnan(x) {
  x = fudgeArray(x);
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  const self = x.ref;
  return self.notEqual(x);
}
|
|
5281
|
+
/**
 * Test element-wise for negative infinity, return bool array.
 *
 * For non-float dtypes the result is `fullLike$1(x, false)`.
 */
function isneginf(x) {
  x = fudgeArray(x);
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  return x.equal(-Infinity);
}
|
|
5286
|
+
/**
 * Test element-wise for positive infinity, return bool array.
 *
 * For non-float dtypes the result is `fullLike$1(x, false)`.
 */
function isposinf(x) {
  x = fudgeArray(x);
  if (!require_backend.isFloatDtype(x.dtype)) return fullLike$1(x, false);
  return x.equal(Infinity);
}
|
|
5291
|
+
/**
 * @function
 * Test element-wise for finite values (not infinity or NaN).
 *
 * For non-float dtypes the result is `fullLike$1(x, true)`. Otherwise the
 * `isnan` and `isinf` results are combined with `.add` and the combination is
 * compared not-equal to `true`.
 */
const isfinite = jit$1(function isfinite$1(x) {
  const isFloat = require_backend.isFloatDtype(x.dtype);
  if (!isFloat) return fullLike$1(x, true);
  const nanMask = isnan(x.ref);
  const infMask = isinf(x);
  return nanMask.add(infMask).notEqual(true);
});
|
|
5035
5299
|
|
|
5036
5300
|
//#endregion
|
|
5037
5301
|
//#region src/nn.ts
|
|
@@ -5165,18 +5429,20 @@ function celu(x, alpha = 1) {
|
|
|
5165
5429
|
* @function
|
|
5166
5430
|
* Gaussian error linear unit (GELU) activation function.
|
|
5167
5431
|
*
|
|
5168
|
-
* This is computed element-wise.
|
|
5169
|
-
*
|
|
5170
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5432
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
5433
|
+
* `approximate` is set (default true):
|
|
5171
5434
|
*
|
|
5172
|
-
*
|
|
5435
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
5436
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
5173
5437
|
*
|
|
5174
|
-
*
|
|
5438
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5175
5439
|
*/
|
|
5176
|
-
const gelu = jit$1(function gelu$1(x) {
|
|
5177
|
-
|
|
5178
|
-
|
|
5179
|
-
|
|
5440
|
+
const gelu = jit$1(function gelu$1(x, opts) {
  // `approximate` defaults to true when `opts` or the field is missing.
  const useApproximation = opts?.approximate ?? true;
  if (!useApproximation) {
    // Exact form: x * 0.5 * erfc(-x / sqrt(2)).
    // NOTE(review): this branch reads `x.ref` twice and never uses bare `x`,
    // unlike the approximate branch below. Confirm that is intended.
    const halfX = x.ref.mul(.5);
    const scaled = x.ref.mul(Math.SQRT1_2);
    return halfX.mul(erfc$1(negative(scaled)));
  }
  // Tanh form: x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))).
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
  const halfX = x.ref.mul(.5);
  // `.ref` is read here, ahead of the bare `x` operand below, matching the
  // left-to-right evaluation order of a single chained expression.
  const linearTerm = x.ref;
  // 1 + 0.044715 * x^2, so that linearTerm * cubicFactor = x + 0.044715 * x^3.
  const cubicFactor = x.ref.mul(x).mul(.044715).add(1);
  const tanhArg = linearTerm.mul(cubicFactor).mul(SQRT_2_OVER_PI);
  return halfX.mul(tanh(tanhArg).add(1));
}, { staticArgnums: [1] });
|
|
5180
5446
|
/**
|
|
5181
5447
|
* Gated linear unit (GLU) activation function.
|
|
5182
5448
|
*
|
|
@@ -5397,6 +5663,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
5397
5663
|
return radius.mul(cos(theta));
|
|
5398
5664
|
}, { staticArgnums: [1] });
|
|
5399
5665
|
|
|
5666
|
+
//#endregion
|
|
5667
|
+
//#region src/scipy-special.ts
|
|
5668
|
+
var scipy_special_exports = {};
|
|
5669
|
+
__export(scipy_special_exports, {
|
|
5670
|
+
erf: () => erf,
|
|
5671
|
+
erfc: () => erfc,
|
|
5672
|
+
logSoftmax: () => logSoftmax,
|
|
5673
|
+
logit: () => logit,
|
|
5674
|
+
logsumexp: () => logsumexp,
|
|
5675
|
+
softmax: () => softmax
|
|
5676
|
+
});
|
|
5677
|
+
/**
 * @function
 * The logit function, `logit(p) = log(p / (1-p))`.
 *
 * The function is wrapped with `jit$1`.
 */
const logit = jit$1(function logit$1(x) {
  const p = x.ref;
  const complement = subtract(1, x);
  return log(p.div(complement));
});
|
|
5684
|
+
|
|
5400
5685
|
//#endregion
|
|
5401
5686
|
//#region src/polyfills.ts
|
|
5402
5687
|
/** @file Polyfills for using this library. */
|
|
@@ -5490,6 +5775,24 @@ async function blockUntilReady(x) {
|
|
|
5490
5775
|
await Promise.all(promises);
|
|
5491
5776
|
return x;
|
|
5492
5777
|
}
|
|
5778
|
+
/**
 * Transfer `x` to `device`.
 *
 * `x` may be a nested container of arrays or scalars. It is flattened, each
 * leaf is placed individually, and the same structure is rebuilt from the
 * placed leaves.
 *
 * Per leaf:
 * - an `Array` leaf with a `device` given is passed to `leaf._put(...)` with
 *   the backend for that device;
 * - an `Array` leaf with no `device` is returned unchanged;
 * - any other leaf is converted with `array(leaf, { device })`.
 *
 * If `device` is not specified, this function therefore behaves as identity
 * for inputs that are already `Array`s, and otherwise places the scalar
 * uncommitted on the default device.
 *
 * @returns A promise for the rebuilt structure once every leaf has settled.
 */
async function devicePut(x, device) {
  const [leaves, treedef] = flatten(x);
  const placeLeaf = (leaf) => {
    if (!(leaf instanceof Array$1)) return Promise.resolve(array(leaf, { device }));
    if (device) return leaf._put(require_backend.getBackend(device));
    return Promise.resolve(leaf);
  };
  const placedLeaves = await Promise.all(leaves.map(placeLeaf));
  return unflatten(treedef, placedLeaves);
}
|
|
5493
5796
|
|
|
5494
5797
|
//#endregion
|
|
5495
5798
|
exports.Array = Array$1;
|
|
@@ -5497,6 +5800,7 @@ exports.DType = require_backend.DType;
|
|
|
5497
5800
|
exports.Jaxpr = Jaxpr;
|
|
5498
5801
|
exports.blockUntilReady = blockUntilReady;
|
|
5499
5802
|
exports.defaultDevice = require_backend.defaultDevice;
|
|
5803
|
+
exports.devicePut = devicePut;
|
|
5500
5804
|
exports.devices = require_backend.devices;
|
|
5501
5805
|
exports.grad = grad;
|
|
5502
5806
|
exports.init = require_backend.init;
|
|
@@ -5531,6 +5835,12 @@ Object.defineProperty(exports, 'random', {
|
|
|
5531
5835
|
return random_exports;
|
|
5532
5836
|
}
|
|
5533
5837
|
});
|
|
5838
|
+
Object.defineProperty(exports, 'scipySpecial', {
|
|
5839
|
+
enumerable: true,
|
|
5840
|
+
get: function () {
|
|
5841
|
+
return scipy_special_exports;
|
|
5842
|
+
}
|
|
5843
|
+
});
|
|
5534
5844
|
exports.setDebug = require_backend.setDebug;
|
|
5535
5845
|
Object.defineProperty(exports, 'tree', {
|
|
5536
5846
|
enumerable: true,
|
|
@@ -5540,4 +5850,5 @@ Object.defineProperty(exports, 'tree', {
|
|
|
5540
5850
|
});
|
|
5541
5851
|
exports.valueAndGrad = valueAndGrad;
|
|
5542
5852
|
exports.vjp = vjp;
|
|
5543
|
-
exports.vmap = vmap;
|
|
5853
|
+
exports.vmap = vmap;
|
|
5854
|
+
//# sourceMappingURL=index.cjs.map
|