@jax-js/jax 0.0.5 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +267 -92
- package/dist/{backend-yEU0L_ig.cjs → backend-BbrKEB18.cjs} +378 -183
- package/dist/{backend-CdcTZEOF.js → backend-CoVtc9dx.js} +366 -177
- package/dist/index.cjs +385 -74
- package/dist/index.d.cts +115 -23
- package/dist/index.d.ts +115 -23
- package/dist/index.js +378 -74
- package/dist/{webgpu-CM-xNYzW.js → webgpu-B3UVme6n.js} +188 -153
- package/dist/{webgpu-CNOpiO5T.cjs → webgpu-DGYNVHma.cjs} +188 -153
- package/package.json +25 -15
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CdcTZEOF.js";
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-CoVtc9dx.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -29,6 +29,10 @@ var JsTreeDef = class JsTreeDef {
|
|
|
29
29
|
this.nodeMetadata = nodeMetadata;
|
|
30
30
|
this.childTreedefs = childTreedefs;
|
|
31
31
|
}
|
|
32
|
+
/** Get the total number of leaves in the tree. */
|
|
33
|
+
get size() {
|
|
34
|
+
return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
|
|
35
|
+
}
|
|
32
36
|
/** Returns a string representation of this tree definition. */
|
|
33
37
|
toString(root = true) {
|
|
34
38
|
if (root) return "JsTreeDef(" + this.toString(false) + ")";
|
|
@@ -184,6 +188,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
|
|
|
184
188
|
const s_ = strides;
|
|
185
189
|
const d_ = dilation;
|
|
186
190
|
const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
191
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
192
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
|
|
193
|
+
st = st.reshape([...noop, ...zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
|
|
194
|
+
st = st.permute([
|
|
195
|
+
...range(noop.length),
|
|
196
|
+
...ks.map((_, j) => noop.length + 2 * j),
|
|
197
|
+
...ks.map((_, j) => noop.length + 2 * j + 1)
|
|
198
|
+
]);
|
|
199
|
+
return st;
|
|
200
|
+
}
|
|
187
201
|
const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
188
202
|
const kidf = zipn(ks, i_, d_, f_);
|
|
189
203
|
st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
|
|
@@ -218,6 +232,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
218
232
|
const s_ = strides;
|
|
219
233
|
const d_ = dilation;
|
|
220
234
|
const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
235
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
236
|
+
st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
|
|
237
|
+
st = st.pad([...noop.map(() => [0, 0]), ...zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...zip(o_, s_).map(([o, s]) => o * s)]);
|
|
238
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
|
|
239
|
+
return st.reshape(st.shape.concat(rep(ks.length, 1)));
|
|
240
|
+
}
|
|
221
241
|
if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
|
|
222
242
|
const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
223
243
|
const kidf = zipn(ks, i_, d_, f_);
|
|
@@ -327,6 +347,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
327
347
|
Primitive$1["Atan"] = "atan";
|
|
328
348
|
Primitive$1["Exp"] = "exp";
|
|
329
349
|
Primitive$1["Log"] = "log";
|
|
350
|
+
Primitive$1["Erf"] = "erf";
|
|
351
|
+
Primitive$1["Erfc"] = "erfc";
|
|
330
352
|
Primitive$1["Sqrt"] = "sqrt";
|
|
331
353
|
Primitive$1["Min"] = "min";
|
|
332
354
|
Primitive$1["Max"] = "max";
|
|
@@ -348,11 +370,9 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
348
370
|
return Primitive$1;
|
|
349
371
|
}({});
|
|
350
372
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
351
|
-
CompareOp$1["Greater"] = "greater";
|
|
352
373
|
CompareOp$1["Less"] = "less";
|
|
353
374
|
CompareOp$1["Equal"] = "equal";
|
|
354
375
|
CompareOp$1["NotEqual"] = "not_equal";
|
|
355
|
-
CompareOp$1["GreaterEqual"] = "greater_equal";
|
|
356
376
|
CompareOp$1["LessEqual"] = "less_equal";
|
|
357
377
|
return CompareOp$1;
|
|
358
378
|
}({});
|
|
@@ -404,6 +424,12 @@ function exp$1(x) {
|
|
|
404
424
|
function log$1(x) {
|
|
405
425
|
return bind1(Primitive.Log, [x]);
|
|
406
426
|
}
|
|
427
|
+
function erf$1(x) {
|
|
428
|
+
return bind1(Primitive.Erf, [x]);
|
|
429
|
+
}
|
|
430
|
+
function erfc$1(x) {
|
|
431
|
+
return bind1(Primitive.Erfc, [x]);
|
|
432
|
+
}
|
|
407
433
|
function sqrt$1(x) {
|
|
408
434
|
return bind1(Primitive.Sqrt, [x]);
|
|
409
435
|
}
|
|
@@ -442,7 +468,7 @@ function compare(x, y, op) {
|
|
|
442
468
|
return bind1(Primitive.Compare, [x, y], { op });
|
|
443
469
|
}
|
|
444
470
|
function greater$1(x, y) {
|
|
445
|
-
return compare(x, y, CompareOp.Greater);
|
|
471
|
+
return compare(y, x, CompareOp.Less);
|
|
446
472
|
}
|
|
447
473
|
function less$1(x, y) {
|
|
448
474
|
return compare(x, y, CompareOp.Less);
|
|
@@ -454,7 +480,7 @@ function notEqual$1(x, y) {
|
|
|
454
480
|
return compare(x, y, CompareOp.NotEqual);
|
|
455
481
|
}
|
|
456
482
|
function greaterEqual$1(x, y) {
|
|
457
|
-
return compare(x, y, CompareOp.GreaterEqual);
|
|
483
|
+
return compare(y, x, CompareOp.LessEqual);
|
|
458
484
|
}
|
|
459
485
|
function lessEqual$1(x, y) {
|
|
460
486
|
return compare(x, y, CompareOp.LessEqual);
|
|
@@ -1146,12 +1172,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1146
1172
|
} else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
1147
1173
|
});
|
|
1148
1174
|
}
|
|
1149
|
-
function broadcastedJit(fn) {
|
|
1175
|
+
function broadcastedJit(fn, opts) {
|
|
1150
1176
|
return (nargs, exps, avals, params) => {
|
|
1151
|
-
|
|
1152
|
-
|
|
1153
|
-
|
|
1154
|
-
|
|
1177
|
+
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1178
|
+
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1179
|
+
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1180
|
+
exps = exps.map((exp$3, i) => {
|
|
1181
|
+
exp$3 = reshapeViews(exp$3, (st) => {
|
|
1182
|
+
if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
|
|
1183
|
+
});
|
|
1184
|
+
if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
|
|
1185
|
+
return exp$3;
|
|
1186
|
+
});
|
|
1155
1187
|
const exp$2 = fn(exps, params);
|
|
1156
1188
|
return new Kernel(nargs, prod(newShape), exp$2);
|
|
1157
1189
|
};
|
|
@@ -1194,6 +1226,8 @@ const jitRules = {
|
|
|
1194
1226
|
[Primitive.Atan]: unopJit(AluExp.atan),
|
|
1195
1227
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
1196
1228
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1229
|
+
[Primitive.Erf]: unopJit(AluExp.erf),
|
|
1230
|
+
[Primitive.Erfc]: unopJit(AluExp.erfc),
|
|
1197
1231
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1198
1232
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
1199
1233
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
@@ -1241,7 +1275,7 @@ const jitRules = {
|
|
|
1241
1275
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1242
1276
|
},
|
|
1243
1277
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1244
|
-
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
|
|
1278
|
+
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1245
1279
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
1246
1280
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
1247
1281
|
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
@@ -1412,7 +1446,7 @@ var PendingExecute = class {
|
|
|
1412
1446
|
/**
|
|
1413
1447
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1414
1448
|
*
|
|
1415
|
-
* This is the library's core data type. Equivalent to `
|
|
1449
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1416
1450
|
* `torch.Tensor`.
|
|
1417
1451
|
*
|
|
1418
1452
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -1427,6 +1461,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1427
1461
|
#source;
|
|
1428
1462
|
#st;
|
|
1429
1463
|
#backend;
|
|
1464
|
+
#committed;
|
|
1430
1465
|
#rc;
|
|
1431
1466
|
#pendingSet;
|
|
1432
1467
|
/**
|
|
@@ -1443,6 +1478,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1443
1478
|
this.#source = args.source;
|
|
1444
1479
|
this.#st = args.st;
|
|
1445
1480
|
this.#backend = args.backend;
|
|
1481
|
+
this.#committed = args.committed;
|
|
1446
1482
|
this.#rc = 1;
|
|
1447
1483
|
this.#pendingSet = new Set(args.pending);
|
|
1448
1484
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
@@ -1470,6 +1506,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1470
1506
|
dtype: args.dtype ?? this.#dtype,
|
|
1471
1507
|
weakType: this.#weakType,
|
|
1472
1508
|
backend: args.backend ?? this.#backend,
|
|
1509
|
+
committed: args.committed ?? this.#committed,
|
|
1473
1510
|
pending: args.pending ?? this.#pending ?? void 0
|
|
1474
1511
|
});
|
|
1475
1512
|
}
|
|
@@ -1525,9 +1562,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1525
1562
|
*/
|
|
1526
1563
|
#gather(indices, axis, outDim) {
|
|
1527
1564
|
this.#check();
|
|
1528
|
-
if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1529
1565
|
const axisSet = new Set(axis);
|
|
1530
1566
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
|
|
1567
|
+
if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1568
|
+
indices = indices.map((ar) => ar._putSync(this.#backend));
|
|
1531
1569
|
indices = Array$1.#broadcastArrays(indices);
|
|
1532
1570
|
const indexShape = indices[0].shape;
|
|
1533
1571
|
const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1596,6 +1634,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1596
1634
|
this.#check();
|
|
1597
1635
|
if (this.#source instanceof AluExp) {
|
|
1598
1636
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1637
|
+
this.dispose();
|
|
1599
1638
|
return this.#newArrayFrom({
|
|
1600
1639
|
source: exp$3.simplify(),
|
|
1601
1640
|
dtype: dtypeOutput,
|
|
@@ -1624,21 +1663,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1624
1663
|
}
|
|
1625
1664
|
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1626
1665
|
const n = arrays.length;
|
|
1627
|
-
const backend = arrays[0].#backend;
|
|
1628
1666
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1629
1667
|
for (const ar of arrays) ar.#check();
|
|
1630
1668
|
let castDtype;
|
|
1631
1669
|
let castWeakType = true;
|
|
1632
|
-
for (let i = 0; i < n; i++) {
|
|
1633
|
-
if (dtypeOverride
|
|
1634
|
-
|
|
1635
|
-
|
|
1636
|
-
|
|
1637
|
-
|
|
1638
|
-
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1639
|
-
if (arrays[i].#backend !== backend) throw new TypeError(`Backend mismatch in ${name}: ${backend.type} vs ${arrays[i].#backend.type}`);
|
|
1640
|
-
}
|
|
1670
|
+
for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
|
|
1671
|
+
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1672
|
+
} else if (castDtype === void 0) {
|
|
1673
|
+
castDtype = arrays[i].#dtype;
|
|
1674
|
+
castWeakType = arrays[i].#weakType;
|
|
1675
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1641
1676
|
const weakType = castWeakType && !strongTypeOutput;
|
|
1677
|
+
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
1678
|
+
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
1642
1679
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1643
1680
|
const newShape = [...arrays[0].shape];
|
|
1644
1681
|
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
@@ -1648,12 +1685,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1648
1685
|
});
|
|
1649
1686
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1650
1687
|
const exp$4 = custom(sources);
|
|
1688
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1651
1689
|
return new Array$1({
|
|
1652
1690
|
source: exp$4.simplify(),
|
|
1653
1691
|
st: arrays[0].#st,
|
|
1654
1692
|
dtype: exp$4.dtype,
|
|
1655
1693
|
weakType,
|
|
1656
|
-
backend
|
|
1694
|
+
backend,
|
|
1695
|
+
committed
|
|
1657
1696
|
});
|
|
1658
1697
|
}
|
|
1659
1698
|
const exp$3 = custom(arrays.map((ar, i) => {
|
|
@@ -1662,12 +1701,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1662
1701
|
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1663
1702
|
}));
|
|
1664
1703
|
const st = ShapeTracker.fromShape(newShape);
|
|
1704
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1665
1705
|
return new Array$1({
|
|
1666
1706
|
source: exp$3.simplify(),
|
|
1667
1707
|
st,
|
|
1668
1708
|
dtype: exp$3.dtype,
|
|
1669
1709
|
weakType,
|
|
1670
|
-
backend
|
|
1710
|
+
backend,
|
|
1711
|
+
committed
|
|
1671
1712
|
});
|
|
1672
1713
|
}
|
|
1673
1714
|
let indices;
|
|
@@ -1703,13 +1744,14 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1703
1744
|
const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
|
|
1704
1745
|
for (const exe of pending) exe.updateRc(1);
|
|
1705
1746
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1706
|
-
|
|
1747
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1707
1748
|
return new Array$1({
|
|
1708
1749
|
source: output,
|
|
1709
1750
|
st: ShapeTracker.fromShape(newShape),
|
|
1710
1751
|
dtype: kernel.dtype,
|
|
1711
1752
|
weakType,
|
|
1712
1753
|
backend,
|
|
1754
|
+
committed,
|
|
1713
1755
|
pending
|
|
1714
1756
|
});
|
|
1715
1757
|
}
|
|
@@ -1787,6 +1829,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1787
1829
|
return ar.#reshape(ar.#st.broadcast(newShape, range(newShape.length - ar.ndim)));
|
|
1788
1830
|
});
|
|
1789
1831
|
}
|
|
1832
|
+
static #computeBackend(name, arrays) {
|
|
1833
|
+
const committed = arrays.filter((ar) => ar.#committed);
|
|
1834
|
+
if (committed.length > 0) {
|
|
1835
|
+
const backend = committed[0].#backend;
|
|
1836
|
+
for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
|
|
1837
|
+
return {
|
|
1838
|
+
backend,
|
|
1839
|
+
committed: true
|
|
1840
|
+
};
|
|
1841
|
+
} else {
|
|
1842
|
+
const backend = arrays.length > 0 ? arrays[0].#backend : getBackend();
|
|
1843
|
+
return {
|
|
1844
|
+
backend,
|
|
1845
|
+
committed: false
|
|
1846
|
+
};
|
|
1847
|
+
}
|
|
1848
|
+
}
|
|
1790
1849
|
/** Realize the array and return it as data. */
|
|
1791
1850
|
async data() {
|
|
1792
1851
|
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
@@ -1946,6 +2005,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1946
2005
|
[Primitive.Log]([x]) {
|
|
1947
2006
|
return [x.#unary(AluOp.Log)];
|
|
1948
2007
|
},
|
|
2008
|
+
[Primitive.Erf]([x]) {
|
|
2009
|
+
return [x.#unary(AluOp.Erf)];
|
|
2010
|
+
},
|
|
2011
|
+
[Primitive.Erfc]([x]) {
|
|
2012
|
+
return [x.#unary(AluOp.Erfc)];
|
|
2013
|
+
},
|
|
1949
2014
|
[Primitive.Sqrt]([x]) {
|
|
1950
2015
|
return [x.#unary(AluOp.Sqrt)];
|
|
1951
2016
|
},
|
|
@@ -2014,7 +2079,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2014
2079
|
},
|
|
2015
2080
|
[Primitive.JitCall](args, { jaxpr, numConsts }) {
|
|
2016
2081
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
2017
|
-
const backend =
|
|
2082
|
+
const { backend, committed } = Array$1.#computeBackend("jit_call", args);
|
|
2083
|
+
args = args.map((ar) => ar._putSync(backend));
|
|
2018
2084
|
const consts = args.slice(0, numConsts);
|
|
2019
2085
|
const tracers = args.slice(numConsts);
|
|
2020
2086
|
const jp = jitCompile(backend, jaxpr, consts);
|
|
@@ -2031,16 +2097,54 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2031
2097
|
dtype: jaxpr.outs[i].aval.dtype,
|
|
2032
2098
|
weakType: jaxpr.outs[i].aval.weakType,
|
|
2033
2099
|
backend,
|
|
2100
|
+
committed,
|
|
2034
2101
|
pending
|
|
2035
2102
|
});
|
|
2036
2103
|
});
|
|
2037
2104
|
}
|
|
2038
2105
|
};
|
|
2039
2106
|
}
|
|
2107
|
+
/** @private */
|
|
2040
2108
|
_realizeSource() {
|
|
2041
2109
|
this.#realize();
|
|
2042
2110
|
return this.#source;
|
|
2043
2111
|
}
|
|
2112
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
2113
|
+
async _put(backend) {
|
|
2114
|
+
if (this.#backend === backend) return this;
|
|
2115
|
+
if (this.#source instanceof AluExp) {
|
|
2116
|
+
const ar = this.#newArrayFrom({
|
|
2117
|
+
backend,
|
|
2118
|
+
committed: true
|
|
2119
|
+
});
|
|
2120
|
+
this.dispose();
|
|
2121
|
+
return ar;
|
|
2122
|
+
} else {
|
|
2123
|
+
const data = await this.data();
|
|
2124
|
+
return arrayFromData(data, this.shape, {
|
|
2125
|
+
dtype: this.#dtype,
|
|
2126
|
+
device: backend.type
|
|
2127
|
+
}, this.#weakType);
|
|
2128
|
+
}
|
|
2129
|
+
}
|
|
2130
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
2131
|
+
_putSync(backend) {
|
|
2132
|
+
if (this.#backend === backend) return this;
|
|
2133
|
+
if (this.#source instanceof AluExp) {
|
|
2134
|
+
const ar = this.#newArrayFrom({
|
|
2135
|
+
backend,
|
|
2136
|
+
committed: true
|
|
2137
|
+
});
|
|
2138
|
+
this.dispose();
|
|
2139
|
+
return ar;
|
|
2140
|
+
} else {
|
|
2141
|
+
const data = this.dataSync();
|
|
2142
|
+
return arrayFromData(data, this.shape, {
|
|
2143
|
+
dtype: this.#dtype,
|
|
2144
|
+
device: backend.type
|
|
2145
|
+
}, this.#weakType);
|
|
2146
|
+
}
|
|
2147
|
+
}
|
|
2044
2148
|
};
|
|
2045
2149
|
/** Constructor for creating a new array from data. */
|
|
2046
2150
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
@@ -2103,6 +2207,9 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2103
2207
|
} else if (data instanceof Float16Array) {
|
|
2104
2208
|
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2105
2209
|
dtype ??= DType.Float16;
|
|
2210
|
+
} else if (data instanceof Float64Array) {
|
|
2211
|
+
if (dtype && dtype !== DType.Float64) throw new Error("Float64Array must have float64 type");
|
|
2212
|
+
dtype ??= DType.Float64;
|
|
2106
2213
|
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2107
2214
|
if (data.length < inlineArrayLimit) {
|
|
2108
2215
|
let allEqual = true;
|
|
@@ -2123,7 +2230,8 @@ function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
|
2123
2230
|
st: ShapeTracker.fromShape(shape$1),
|
|
2124
2231
|
dtype,
|
|
2125
2232
|
weakType,
|
|
2126
|
-
backend
|
|
2233
|
+
backend,
|
|
2234
|
+
committed: device != void 0
|
|
2127
2235
|
});
|
|
2128
2236
|
}
|
|
2129
2237
|
function dataToJs(dtype, data, shape$1) {
|
|
@@ -2157,7 +2265,8 @@ function fullInternal(aval, fillValue, device) {
|
|
|
2157
2265
|
st: ShapeTracker.fromShape(aval.shape),
|
|
2158
2266
|
dtype: aval.dtype,
|
|
2159
2267
|
weakType: aval.weakType,
|
|
2160
|
-
backend: getBackend(device)
|
|
2268
|
+
backend: getBackend(device),
|
|
2269
|
+
committed: device != void 0
|
|
2161
2270
|
});
|
|
2162
2271
|
}
|
|
2163
2272
|
function zerosLike$1(val, dtype) {
|
|
@@ -2225,7 +2334,8 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2225
2334
|
st: ShapeTracker.fromShape([numRows, numCols]),
|
|
2226
2335
|
dtype,
|
|
2227
2336
|
weakType,
|
|
2228
|
-
backend: getBackend(device)
|
|
2337
|
+
backend: getBackend(device),
|
|
2338
|
+
committed: device != void 0
|
|
2229
2339
|
});
|
|
2230
2340
|
}
|
|
2231
2341
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
@@ -2268,7 +2378,8 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2268
2378
|
st,
|
|
2269
2379
|
dtype,
|
|
2270
2380
|
weakType: false,
|
|
2271
|
-
backend: getBackend(device)
|
|
2381
|
+
backend: getBackend(device),
|
|
2382
|
+
committed: device != void 0
|
|
2272
2383
|
});
|
|
2273
2384
|
}
|
|
2274
2385
|
/**
|
|
@@ -2304,16 +2415,15 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2304
2415
|
st,
|
|
2305
2416
|
dtype,
|
|
2306
2417
|
weakType: false,
|
|
2307
|
-
backend: getBackend(device)
|
|
2418
|
+
backend: getBackend(device),
|
|
2419
|
+
committed: device != void 0
|
|
2308
2420
|
});
|
|
2309
2421
|
}
|
|
2310
2422
|
function aluCompare(a, b, op) {
|
|
2311
2423
|
switch (op) {
|
|
2312
|
-
case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
|
|
2313
2424
|
case CompareOp.Less: return AluExp.cmplt(a, b);
|
|
2314
2425
|
case CompareOp.Equal: return AluExp.cmpne(a, b).not();
|
|
2315
2426
|
case CompareOp.NotEqual: return AluExp.cmpne(a, b);
|
|
2316
|
-
case CompareOp.GreaterEqual: return AluExp.cmplt(a, b).not();
|
|
2317
2427
|
case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
|
|
2318
2428
|
}
|
|
2319
2429
|
}
|
|
@@ -2446,7 +2556,7 @@ var JaxprEqn = class {
|
|
|
2446
2556
|
const paramsList = Object.entries(this.params).map(([k, v]) => PPrint.pp(`${k}=${v}`));
|
|
2447
2557
|
if (paramsList.length > 0) rhs = rhs.stack(PPrint.pp(" [ ")).stack(PPrint.prototype.concat(...paramsList)).stack(PPrint.pp(" ] "));
|
|
2448
2558
|
else rhs = rhs.stack(PPrint.pp(" "));
|
|
2449
|
-
rhs = rhs.stack(PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) :
|
|
2559
|
+
rhs = rhs.stack(PPrint.pp(this.inputs.map((x) => x instanceof Var ? vp.name(x) : String(x.value)).join(" ")));
|
|
2450
2560
|
return lhs.stack(PPrint.pp(" = ")).stack(rhs);
|
|
2451
2561
|
}
|
|
2452
2562
|
toString() {
|
|
@@ -2812,6 +2922,8 @@ const abstractEvalRules = {
|
|
|
2812
2922
|
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2813
2923
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2814
2924
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2925
|
+
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
2926
|
+
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
2815
2927
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2816
2928
|
[Primitive.Min]: binopAbstractEval,
|
|
2817
2929
|
[Primitive.Max]: binopAbstractEval,
|
|
@@ -3064,6 +3176,16 @@ const jvpRules = {
|
|
|
3064
3176
|
[Primitive.Log]([x], [dx]) {
|
|
3065
3177
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3066
3178
|
},
|
|
3179
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3180
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3181
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3182
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3183
|
+
},
|
|
3184
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3185
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3186
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3187
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3188
|
+
},
|
|
3067
3189
|
[Primitive.Sqrt]([x], [dx]) {
|
|
3068
3190
|
const z = sqrt$1(x);
|
|
3069
3191
|
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
@@ -3225,6 +3347,10 @@ var BatchTrace = class extends Trace {
|
|
|
3225
3347
|
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3226
3348
|
const vmapRule = vmapRules[primitive];
|
|
3227
3349
|
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3350
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3351
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3352
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3353
|
+
}
|
|
3228
3354
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3229
3355
|
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3230
3356
|
}
|
|
@@ -3232,24 +3358,28 @@ var BatchTrace = class extends Trace {
|
|
|
3232
3358
|
return this.main.globalData;
|
|
3233
3359
|
}
|
|
3234
3360
|
};
|
|
3235
|
-
|
|
3236
|
-
|
|
3237
|
-
|
|
3238
|
-
|
|
3239
|
-
|
|
3240
|
-
return broadcast(x, shape$1, axis);
|
|
3241
|
-
}
|
|
3242
|
-
}
|
|
3243
|
-
/** Process a primitive with built-in broadcasting. */
|
|
3361
|
+
/**
|
|
3362
|
+
* Process a primitive with built-in broadcasting.
|
|
3363
|
+
*
|
|
3364
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3365
|
+
*/
|
|
3244
3366
|
function broadcastBatcher(op) {
|
|
3245
3367
|
return (axisSize, args, dims) => {
|
|
3246
3368
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3247
|
-
const
|
|
3248
|
-
|
|
3249
|
-
|
|
3250
|
-
args
|
|
3251
|
-
|
|
3252
|
-
|
|
3369
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3370
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3371
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3372
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3373
|
+
args = args.map((x, i) => {
|
|
3374
|
+
if (dims[i] === null) return x;
|
|
3375
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3376
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3377
|
+
x.shape[0],
|
|
3378
|
+
...rep(nd - x.ndim, 1),
|
|
3379
|
+
...x.shape.slice(1)
|
|
3380
|
+
]);
|
|
3381
|
+
return x;
|
|
3382
|
+
});
|
|
3253
3383
|
return [[op(...args)], [0]];
|
|
3254
3384
|
};
|
|
3255
3385
|
}
|
|
@@ -3273,17 +3403,18 @@ const vmapRules = {
|
|
|
3273
3403
|
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3274
3404
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3275
3405
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3406
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3407
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3276
3408
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3277
3409
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3278
3410
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3279
3411
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3280
|
-
|
|
3412
|
+
assertNonNull(xBdim);
|
|
3281
3413
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3282
3414
|
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3283
3415
|
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3284
3416
|
},
|
|
3285
3417
|
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3286
|
-
if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
|
|
3287
3418
|
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3288
3419
|
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3289
3420
|
const z = dot$1(x, y);
|
|
@@ -3292,26 +3423,68 @@ const vmapRules = {
|
|
|
3292
3423
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3293
3424
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3294
3425
|
},
|
|
3426
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3427
|
+
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3428
|
+
assertNonNull(xBdim);
|
|
3429
|
+
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
3430
|
+
newPerm.splice(xBdim, 0, xBdim);
|
|
3431
|
+
return [[transpose$1(x, newPerm)], [xBdim]];
|
|
3432
|
+
},
|
|
3433
|
+
[Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
|
|
3434
|
+
assertNonNull(xBdim);
|
|
3435
|
+
const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
|
|
3436
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3437
|
+
return [[broadcast(x, newShape, newAxis)], [xBdim]];
|
|
3438
|
+
},
|
|
3295
3439
|
[Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
|
|
3296
|
-
if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
|
|
3297
3440
|
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3298
3441
|
return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
|
|
3299
3442
|
},
|
|
3300
3443
|
[Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
|
|
3301
|
-
|
|
3444
|
+
assertNonNull(xBdim);
|
|
3302
3445
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3303
3446
|
return [[flip$1(x, newAxis)], [xBdim]];
|
|
3304
3447
|
},
|
|
3305
3448
|
[Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
|
|
3306
|
-
|
|
3449
|
+
assertNonNull(xBdim);
|
|
3307
3450
|
const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
|
|
3308
3451
|
return [[shrink(x, newSlice)], [xBdim]];
|
|
3309
3452
|
},
|
|
3310
3453
|
[Primitive.Pad](axisSize, [x], [xBdim], { width }) {
|
|
3311
|
-
|
|
3454
|
+
assertNonNull(xBdim);
|
|
3312
3455
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3313
3456
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3314
3457
|
},
|
|
3458
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3459
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3460
|
+
assertNonNull(xBdim);
|
|
3461
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3462
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3463
|
+
let newOutDim = outDim;
|
|
3464
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3465
|
+
else newOutDim += 1;
|
|
3466
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3467
|
+
}
|
|
3468
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3469
|
+
indices = indices.map((m, i) => {
|
|
3470
|
+
if (indicesBdim[i] === null) return m;
|
|
3471
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3472
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3473
|
+
m.shape[0],
|
|
3474
|
+
...rep(nd - m.ndim, 1),
|
|
3475
|
+
...m.shape.slice(1)
|
|
3476
|
+
]);
|
|
3477
|
+
return m;
|
|
3478
|
+
});
|
|
3479
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3480
|
+
else {
|
|
3481
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3482
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3483
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
|
|
3484
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3485
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3486
|
+
}
|
|
3487
|
+
},
|
|
3315
3488
|
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3316
3489
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3317
3490
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
@@ -3371,12 +3544,14 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3371
3544
|
function vmap$1(f, inAxes = 0) {
|
|
3372
3545
|
return (...args) => {
|
|
3373
3546
|
const [argsFlat, inTree] = flatten(args);
|
|
3374
|
-
let inAxesFlat;
|
|
3547
|
+
let inAxesFlat = [];
|
|
3375
3548
|
if (typeof inAxes === "number") inAxesFlat = rep(argsFlat.length, inAxes);
|
|
3549
|
+
else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...rep(inTree.childTreedefs[i].size, null));
|
|
3550
|
+
else if (typeof inAxes[i] === "number") inAxesFlat.push(...rep(inTree.childTreedefs[i].size, inAxes[i]));
|
|
3376
3551
|
else {
|
|
3377
|
-
|
|
3378
|
-
[
|
|
3379
|
-
|
|
3552
|
+
const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
|
|
3553
|
+
if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
|
|
3554
|
+
inAxesFlat.push(...axesFlat);
|
|
3380
3555
|
}
|
|
3381
3556
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3382
3557
|
const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
|
|
@@ -3996,7 +4171,7 @@ function valueAndGrad$1(f) {
|
|
|
3996
4171
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3997
4172
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3998
4173
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3999
|
-
const [ct, ...rest] = fVjp(
|
|
4174
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
4000
4175
|
for (const r of rest) dispose(r);
|
|
4001
4176
|
fVjp.dispose();
|
|
4002
4177
|
return [y, ct];
|
|
@@ -4024,7 +4199,10 @@ __export(lax_exports, {
|
|
|
4024
4199
|
conv: () => conv$1,
|
|
4025
4200
|
convGeneralDilated: () => convGeneralDilated,
|
|
4026
4201
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4027
|
-
|
|
4202
|
+
erf: () => erf,
|
|
4203
|
+
erfc: () => erfc,
|
|
4204
|
+
reduceWindow: () => reduceWindow,
|
|
4205
|
+
stopGradient: () => stopGradient$1
|
|
4028
4206
|
});
|
|
4029
4207
|
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4030
4208
|
const padType = padding.toUpperCase();
|
|
@@ -4083,6 +4261,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
|
4083
4261
|
strides: windowStrides
|
|
4084
4262
|
}));
|
|
4085
4263
|
}
|
|
4264
|
+
/**
 * Error function, `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`.
 *
 * Public `lax` entry point that forwards to the core `erf$1` operation.
 */
function erf(x) {
  const result = erf$1(x);
  return result;
}
|
|
4268
|
+
/**
 * Complementary error function, `erfc(x) = 1 - erf(x)`.
 *
 * Prefer this over computing `1 - erf(x)` by hand: for large `x`, `erf(x)` is
 * very close to 1 and this function is more accurate.
 */
function erfc(x) {
  const result = erfc$1(x);
  return result;
}
|
|
4277
|
+
/**
 * Stops gradient computation.
 *
 * Acts as the identity on its input, but blocks gradients from flowing
 * through it during forward- or reverse-mode automatic differentiation.
 */
function stopGradient$1(x) {
  const detached = stopGradient(x);
  return detached;
}
|
|
4086
4286
|
|
|
4087
4287
|
//#endregion
|
|
4088
4288
|
//#region src/numpy.ts
|
|
@@ -4141,16 +4341,25 @@ __export(numpy_exports, {
|
|
|
4141
4341
|
flipud: () => flipud,
|
|
4142
4342
|
float16: () => float16,
|
|
4143
4343
|
float32: () => float32,
|
|
4344
|
+
float64: () => float64,
|
|
4144
4345
|
full: () => full,
|
|
4145
4346
|
fullLike: () => fullLike$1,
|
|
4146
4347
|
greater: () => greater,
|
|
4147
4348
|
greaterEqual: () => greaterEqual,
|
|
4349
|
+
hamming: () => hamming,
|
|
4350
|
+
hann: () => hann,
|
|
4351
|
+
heaviside: () => heaviside,
|
|
4148
4352
|
hstack: () => hstack,
|
|
4149
4353
|
hypot: () => hypot,
|
|
4150
4354
|
identity: () => identity$1,
|
|
4151
4355
|
inf: () => inf,
|
|
4152
4356
|
inner: () => inner,
|
|
4153
4357
|
int32: () => int32,
|
|
4358
|
+
isfinite: () => isfinite,
|
|
4359
|
+
isinf: () => isinf,
|
|
4360
|
+
isnan: () => isnan,
|
|
4361
|
+
isneginf: () => isneginf,
|
|
4362
|
+
isposinf: () => isposinf,
|
|
4154
4363
|
less: () => less,
|
|
4155
4364
|
lessEqual: () => lessEqual,
|
|
4156
4365
|
linspace: () => linspace,
|
|
@@ -4221,6 +4430,7 @@ const int32 = DType.Int32;
|
|
|
4221
4430
|
const uint32 = DType.Uint32;
|
|
4222
4431
|
const bool = DType.Bool;
|
|
4223
4432
|
const float16 = DType.Float16;
|
|
4433
|
+
const float64 = DType.Float64;
|
|
4224
4434
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
4225
4435
|
const e = Math.E;
|
|
4226
4436
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -4784,6 +4994,32 @@ function sign(x) {
|
|
|
4784
4994
|
x = fudgeArray(x);
|
|
4785
4995
|
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4786
4996
|
}
|
|
4997
|
+
/**
 * Return the Hamming window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * NOTE(review): for `M = 1` the formula divides by zero; the result then
 * depends on what `linspace` returns for a single sample — confirm.
 */
function hamming(M) {
  // M evenly spaced angles covering one full period, endpoints included.
  const angles = linspace(0, 2 * Math.PI, M);
  const bell = cos(angles);
  return bell.mul(-.46).add(.54);
}
|
|
5005
|
+
/**
 * Return the Hann window of size M, a taper with a weighted cosine bell.
 *
 * `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
 *
 * NOTE(review): for `M = 1` the formula divides by zero; the result then
 * depends on what `linspace` returns for a single sample — confirm.
 */
function hann(M) {
  // M evenly spaced angles covering one full period, endpoints included.
  const angles = linspace(0, 2 * Math.PI, M);
  const bell = cos(angles);
  return bell.mul(-.5).add(.5);
}
|
|
5013
|
+
/**
 * @function
 * Compute the Heaviside step function, defined piecewise as:
 * - `heaviside(x1, x2) = 0` for `x1 < 0`,
 * - `heaviside(x1, x2) = x2` for `x1 == 0`,
 * - `heaviside(x1, x2) = 1` for `x1 > 0`.
 */
const heaviside = jit$1(function heaviside$1(x1, x2) {
  const isNegative = less(x1.ref, 0);
  // For non-negative inputs: `x2` exactly at zero, otherwise 1.
  const atOrAboveZero = where(equal(x1, 0), x2, 1);
  return where(isNegative, 0, atOrAboveZero);
});
|
|
4787
5023
|
/** Calculate element-wise square of the input array. */
|
|
4788
5024
|
function square(x) {
|
|
4789
5025
|
x = fudgeArray(x);
|
|
@@ -4803,8 +5039,8 @@ function acos(x) {
|
|
|
4803
5039
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4804
5040
|
*
|
|
4805
5041
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4806
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
4807
|
-
* improvements.
|
|
5042
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
5043
|
+
* stability improvements.
|
|
4808
5044
|
*/
|
|
4809
5045
|
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4810
5046
|
return sqrt(square(x1).add(square(x2)));
|
|
@@ -4995,6 +5231,34 @@ function var_(x, axis = null, opts) {
|
|
|
4995
5231
|
function std(x, axis = null, opts) {
  // The standard deviation is the square root of the variance.
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
|
|
5234
|
+
/**
 * Test element-wise for positive or negative infinity, return bool array.
 *
 * Non-floating-point inputs yield an all-`false` array.
 */
function isinf(x) {
  x = fudgeArray(x);
  if (!isFloatDtype(x.dtype)) return fullLike$1(x, false);
  const isPositive = x.ref.equal(Infinity);
  const isNegative = x.equal(-Infinity);
  // Summing the two masks combines them into a single mask.
  return isPositive.add(isNegative);
}
|
|
5239
|
+
/**
 * Test element-wise for NaN (Not a Number).
 *
 * Non-floating-point inputs yield an all-`false` array.
 */
function isnan(x) {
  x = fudgeArray(x);
  if (!isFloatDtype(x.dtype)) return fullLike$1(x, false);
  // NaN is the only value that does not compare equal to itself.
  return x.ref.notEqual(x);
}
|
|
5244
|
+
/**
 * Test element-wise for negative infinity, return bool array.
 *
 * Non-floating-point inputs yield an all-`false` array.
 */
function isneginf(x) {
  x = fudgeArray(x);
  if (!isFloatDtype(x.dtype)) return fullLike$1(x, false);
  return x.equal(-Infinity);
}
|
|
5249
|
+
/**
 * Test element-wise for positive infinity, return bool array.
 *
 * Non-floating-point inputs yield an all-`false` array.
 */
function isposinf(x) {
  x = fudgeArray(x);
  if (!isFloatDtype(x.dtype)) return fullLike$1(x, false);
  return x.equal(Infinity);
}
|
|
5254
|
+
/**
 * @function
 * Test element-wise for finite values (not infinity or NaN).
 *
 * Non-floating-point inputs yield an all-`true` array.
 */
const isfinite = jit$1(function isfinite$1(x) {
  if (isFloatDtype(x.dtype)) {
    const nanMask = isnan(x.ref);
    const infMask = isinf(x);
    // Finite means neither mask is set.
    return nanMask.add(infMask).notEqual(true);
  }
  return fullLike$1(x, true);
});
|
|
4998
5262
|
|
|
4999
5263
|
//#endregion
|
|
5000
5264
|
//#region src/nn.ts
|
|
@@ -5128,18 +5392,20 @@ function celu(x, alpha = 1) {
|
|
|
5128
5392
|
* @function
|
|
5129
5393
|
* Gaussian error linear unit (GELU) activation function.
|
|
5130
5394
|
*
|
|
5131
|
-
* This is computed element-wise.
|
|
5132
|
-
*
|
|
5133
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5395
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
5396
|
+
* `approximate` is set (default true):
|
|
5134
5397
|
*
|
|
5135
|
-
*
|
|
5398
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
5399
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
5136
5400
|
*
|
|
5137
|
-
*
|
|
5401
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5138
5402
|
*/
|
|
5139
|
-
const gelu = jit$1(function gelu$1(x) {
|
|
5140
|
-
|
|
5141
|
-
|
|
5142
|
-
|
|
5403
|
+
const gelu = jit$1(function gelu$1(x, opts) {
  // `opts.approximate` selects the tanh approximation; it defaults to true.
  if (opts?.approximate ?? true) {
    const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
    const half = x.ref.mul(.5);
    // x * (1 + 0.044715 * x^2) == x + 0.044715 * x^3
    const cubic = x.ref.mul(x.ref.mul(x).mul(.044715).add(1));
    return half.mul(tanh(cubic.mul(SQRT_2_OVER_PI)).add(1));
  }
  // Exact variant: x * 0.5 * erfc(-x / sqrt(2)).
  const half = x.ref.mul(.5);
  // Last use of `x`: pass it without `.ref`, as the approximate branch does,
  // so no extra reference to `x` is left outstanding.
  return half.mul(erfc$1(negative(x.mul(Math.SQRT1_2))));
}, { staticArgnums: [1] });
|
|
5143
5409
|
/**
|
|
5144
5410
|
* Gated linear unit (GLU) activation function.
|
|
5145
5411
|
*
|
|
@@ -5360,6 +5626,25 @@ const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
|
5360
5626
|
return radius.mul(cos(theta));
|
|
5361
5627
|
}, { staticArgnums: [1] });
|
|
5362
5628
|
|
|
5629
|
+
//#endregion
|
|
5630
|
+
//#region src/scipy-special.ts
|
|
5631
|
+
var scipy_special_exports = {};
|
|
5632
|
+
__export(scipy_special_exports, {
|
|
5633
|
+
erf: () => erf,
|
|
5634
|
+
erfc: () => erfc,
|
|
5635
|
+
logSoftmax: () => logSoftmax,
|
|
5636
|
+
logit: () => logit,
|
|
5637
|
+
logsumexp: () => logsumexp,
|
|
5638
|
+
softmax: () => softmax
|
|
5639
|
+
});
|
|
5640
|
+
/**
 * @function
 * The logit (log-odds) function, `logit(p) = log(p / (1-p))`.
 */
const logit = jit$1(function logit$1(x) {
  // Odds ratio p / (1 - p).
  const odds = x.ref.div(subtract(1, x));
  return log(odds);
});
|
|
5647
|
+
|
|
5363
5648
|
//#endregion
|
|
5364
5649
|
//#region src/polyfills.ts
|
|
5365
5650
|
/** @file Polyfills for using this library. */
|
|
@@ -5453,6 +5738,25 @@ async function blockUntilReady(x) {
|
|
|
5453
5738
|
await Promise.all(promises);
|
|
5454
5739
|
return x;
|
|
5455
5740
|
}
|
|
5741
|
+
/**
 * Transfer `x` to `device`.
 *
 * `x` may be a nested container of arrays or scalars; the returned structure
 * has the same layout and is committed to the device.
 *
 * When `device` is omitted, leaves that are already an `Array` are returned
 * as-is, and any other leaf is placed uncommitted on the default device.
 */
async function devicePut(x, device) {
  const [leaves, treedef] = flatten(x);
  const pending = leaves.map((leaf) => {
    // Non-array leaves are wrapped into a new array.
    if (!(leaf instanceof Array$1)) return Promise.resolve(array(leaf, { device }));
    // Existing arrays are moved only when a target device was requested.
    if (device) return leaf._put(getBackend(device));
    return Promise.resolve(leaf);
  });
  const placed = await Promise.all(pending);
  return unflatten(treedef, placed);
}
|
|
5456
5759
|
|
|
5457
5760
|
//#endregion
|
|
5458
|
-
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
5761
|
+
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
5762
|
+
//# sourceMappingURL=index.js.map
|