@jax-js/jax 0.0.4 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +296 -78
- package/dist/{backend-EBRGmEYw.js → backend-DwIAd0AG.js} +238 -116
- package/dist/{backend-Ss1Mev_-.cjs → backend-FtkbO6pI.cjs} +256 -122
- package/dist/index.cjs +653 -277
- package/dist/index.d.cts +167 -44
- package/dist/index.d.ts +167 -44
- package/dist/index.js +637 -268
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.cjs
CHANGED
|
@@ -30,7 +30,7 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-FtkbO6pI.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
@@ -60,6 +60,10 @@ var JsTreeDef = class JsTreeDef {
|
|
|
60
60
|
this.nodeMetadata = nodeMetadata;
|
|
61
61
|
this.childTreedefs = childTreedefs;
|
|
62
62
|
}
|
|
63
|
+
/** Get the total number of leaves in the tree. */
|
|
64
|
+
get size() {
|
|
65
|
+
return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
|
|
66
|
+
}
|
|
63
67
|
/** Returns a string representation of this tree definition. */
|
|
64
68
|
toString(root = true) {
|
|
65
69
|
if (root) return "JsTreeDef(" + this.toString(false) + ")";
|
|
@@ -215,6 +219,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
|
|
|
215
219
|
const s_ = strides;
|
|
216
220
|
const d_ = dilation;
|
|
217
221
|
const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
222
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
223
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
|
|
224
|
+
st = st.reshape([...noop, ...require_backend.zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...require_backend.zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
|
|
225
|
+
st = st.permute([
|
|
226
|
+
...require_backend.range(noop.length),
|
|
227
|
+
...ks.map((_, j) => noop.length + 2 * j),
|
|
228
|
+
...ks.map((_, j) => noop.length + 2 * j + 1)
|
|
229
|
+
]);
|
|
230
|
+
return st;
|
|
231
|
+
}
|
|
218
232
|
const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
219
233
|
const kidf = require_backend.zipn(ks, i_, d_, f_);
|
|
220
234
|
st = st.repeat([...require_backend.rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
|
|
@@ -249,6 +263,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
249
263
|
const s_ = strides;
|
|
250
264
|
const d_ = dilation;
|
|
251
265
|
const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
266
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
267
|
+
st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
|
|
268
|
+
st = st.pad([...noop.map(() => [0, 0]), ...require_backend.zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...require_backend.zip(o_, s_).map(([o, s]) => o * s)]);
|
|
269
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...require_backend.zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
|
|
270
|
+
return st.reshape(st.shape.concat(require_backend.rep(ks.length, 1)));
|
|
271
|
+
}
|
|
252
272
|
if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
|
|
253
273
|
const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
254
274
|
const kidf = require_backend.zipn(ks, i_, d_, f_);
|
|
@@ -358,6 +378,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
358
378
|
Primitive$1["Atan"] = "atan";
|
|
359
379
|
Primitive$1["Exp"] = "exp";
|
|
360
380
|
Primitive$1["Log"] = "log";
|
|
381
|
+
Primitive$1["Erf"] = "erf";
|
|
382
|
+
Primitive$1["Erfc"] = "erfc";
|
|
361
383
|
Primitive$1["Sqrt"] = "sqrt";
|
|
362
384
|
Primitive$1["Min"] = "min";
|
|
363
385
|
Primitive$1["Max"] = "max";
|
|
@@ -435,6 +457,12 @@ function exp$1(x) {
|
|
|
435
457
|
function log$1(x) {
|
|
436
458
|
return bind1(Primitive.Log, [x]);
|
|
437
459
|
}
|
|
460
|
+
function erf$1(x) {
|
|
461
|
+
return bind1(Primitive.Erf, [x]);
|
|
462
|
+
}
|
|
463
|
+
function erfc$1(x) {
|
|
464
|
+
return bind1(Primitive.Erfc, [x]);
|
|
465
|
+
}
|
|
438
466
|
function sqrt$1(x) {
|
|
439
467
|
return bind1(Primitive.Sqrt, [x]);
|
|
440
468
|
}
|
|
@@ -596,6 +624,21 @@ var Trace = class {
|
|
|
596
624
|
this.main = main;
|
|
597
625
|
}
|
|
598
626
|
};
|
|
627
|
+
/**
|
|
628
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
629
|
+
*
|
|
630
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
631
|
+
* implemented in that function as `weakType` is not passed.
|
|
632
|
+
*/
|
|
633
|
+
function promoteAvals(a, b) {
|
|
634
|
+
const shape$1 = require_backend.generalBroadcast(a.shape, b.shape);
|
|
635
|
+
const weakType = a.weakType && b.weakType;
|
|
636
|
+
let dtype;
|
|
637
|
+
if (a.weakType === b.weakType) dtype = require_backend.promoteTypes(a.dtype, b.dtype);
|
|
638
|
+
else if (a.weakType) dtype = require_backend.promoteTypes(b.dtype, require_backend.DType.Uint32);
|
|
639
|
+
else dtype = require_backend.promoteTypes(a.dtype, require_backend.DType.Uint32);
|
|
640
|
+
return new ShapedArray(shape$1, dtype, weakType);
|
|
641
|
+
}
|
|
599
642
|
var Tracer = class Tracer {
|
|
600
643
|
/** @ignore */
|
|
601
644
|
_trace;
|
|
@@ -610,10 +653,19 @@ var Tracer = class Tracer {
|
|
|
610
653
|
get size() {
|
|
611
654
|
return require_backend.prod(this.shape);
|
|
612
655
|
}
|
|
613
|
-
/** The dtype of the array. */
|
|
656
|
+
/** The dtype of elements stored in the array. */
|
|
614
657
|
get dtype() {
|
|
615
658
|
return this.aval.dtype;
|
|
616
659
|
}
|
|
660
|
+
/**
|
|
661
|
+
* Whether the array is weakly typed.
|
|
662
|
+
*
|
|
663
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
664
|
+
* `promoteTypes()` for details.
|
|
665
|
+
*/
|
|
666
|
+
get weakType() {
|
|
667
|
+
return this.aval.weakType;
|
|
668
|
+
}
|
|
617
669
|
/** The number of dimensions of the array. */
|
|
618
670
|
get ndim() {
|
|
619
671
|
return this.shape.length;
|
|
@@ -850,12 +902,13 @@ function getShape(x) {
|
|
|
850
902
|
return x instanceof Tracer ? x.shape : [];
|
|
851
903
|
}
|
|
852
904
|
var ShapedArray = class ShapedArray {
|
|
853
|
-
constructor(shape$1, dtype) {
|
|
905
|
+
constructor(shape$1, dtype, weakType) {
|
|
854
906
|
this.shape = shape$1;
|
|
855
907
|
this.dtype = dtype;
|
|
908
|
+
this.weakType = weakType;
|
|
856
909
|
}
|
|
857
910
|
static fromAval(aval) {
|
|
858
|
-
return new ShapedArray(aval.shape, aval.dtype);
|
|
911
|
+
return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
|
|
859
912
|
}
|
|
860
913
|
get ndim() {
|
|
861
914
|
return this.shape.length;
|
|
@@ -869,7 +922,7 @@ var ShapedArray = class ShapedArray {
|
|
|
869
922
|
};
|
|
870
923
|
function getAval(x) {
|
|
871
924
|
if (x instanceof Tracer) return x.aval;
|
|
872
|
-
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32);
|
|
925
|
+
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? require_backend.DType.Bool : require_backend.DType.Float32, typeof x === "boolean" ? false : true);
|
|
873
926
|
else throw new TypeError(`Unknown value: ${x}`);
|
|
874
927
|
}
|
|
875
928
|
function bind(prim, args, params = {}) {
|
|
@@ -1152,12 +1205,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1152
1205
|
} else if (exp$3.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
1153
1206
|
});
|
|
1154
1207
|
}
|
|
1155
|
-
function broadcastedJit(fn) {
|
|
1208
|
+
function broadcastedJit(fn, opts) {
|
|
1156
1209
|
return (nargs, exps, avals, params) => {
|
|
1157
|
-
|
|
1158
|
-
|
|
1159
|
-
|
|
1160
|
-
|
|
1210
|
+
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1211
|
+
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1212
|
+
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1213
|
+
exps = exps.map((exp$3, i) => {
|
|
1214
|
+
exp$3 = reshapeViews(exp$3, (st) => {
|
|
1215
|
+
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1216
|
+
});
|
|
1217
|
+
if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = require_backend.AluExp.cast(newDtype, exp$3);
|
|
1218
|
+
return exp$3;
|
|
1219
|
+
});
|
|
1161
1220
|
const exp$2 = fn(exps, params);
|
|
1162
1221
|
return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
|
|
1163
1222
|
};
|
|
@@ -1191,7 +1250,7 @@ const jitRules = {
|
|
|
1191
1250
|
const k1 = reshapeViews(keys[1], mapping);
|
|
1192
1251
|
const c0 = require_backend.AluExp.u32(0);
|
|
1193
1252
|
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1194
|
-
const exp$2 = require_backend.AluExp.threefry2x32(
|
|
1253
|
+
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1195
1254
|
return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
|
|
1196
1255
|
},
|
|
1197
1256
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
@@ -1200,6 +1259,8 @@ const jitRules = {
|
|
|
1200
1259
|
[Primitive.Atan]: unopJit(require_backend.AluExp.atan),
|
|
1201
1260
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
1202
1261
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1262
|
+
[Primitive.Erf]: unopJit(require_backend.AluExp.erf),
|
|
1263
|
+
[Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
|
|
1203
1264
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1204
1265
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
1205
1266
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
@@ -1232,7 +1293,7 @@ const jitRules = {
|
|
|
1232
1293
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
1233
1294
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
1234
1295
|
const c = k1.exp;
|
|
1235
|
-
const cs =
|
|
1296
|
+
const cs = promoteAvals(as, bs);
|
|
1236
1297
|
return jitRules[Primitive.Reduce](nargs, [c], [cs], {
|
|
1237
1298
|
op: require_backend.AluOp.Add,
|
|
1238
1299
|
axis: [cs.ndim - 1]
|
|
@@ -1242,12 +1303,12 @@ const jitRules = {
|
|
|
1242
1303
|
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1243
1304
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1244
1305
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1245
|
-
as = new ShapedArray(stX.shape, as.dtype);
|
|
1246
|
-
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1306
|
+
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1307
|
+
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1247
1308
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1248
1309
|
},
|
|
1249
1310
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1250
|
-
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
|
|
1311
|
+
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1251
1312
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
1252
1313
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
1253
1314
|
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
@@ -1260,7 +1321,7 @@ const jitRules = {
|
|
|
1260
1321
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1261
1322
|
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1262
1323
|
const axisSet = new Set(axis);
|
|
1263
|
-
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1324
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
1264
1325
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1265
1326
|
finalShape.splice(outDim, 0, ...indexShape);
|
|
1266
1327
|
const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
|
|
@@ -1296,9 +1357,10 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1296
1357
|
Primitive.Conv,
|
|
1297
1358
|
Primitive.PoolTranspose
|
|
1298
1359
|
];
|
|
1360
|
+
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
1299
1361
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1300
1362
|
const eqn = jaxpr.eqns[i];
|
|
1301
|
-
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive
|
|
1363
|
+
if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1302
1364
|
for (const v of eqn.outBinders) {
|
|
1303
1365
|
blackNodes.add(v);
|
|
1304
1366
|
p1NextBlack.set(v, v);
|
|
@@ -1417,7 +1479,7 @@ var PendingExecute = class {
|
|
|
1417
1479
|
/**
|
|
1418
1480
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1419
1481
|
*
|
|
1420
|
-
* This is the library's core data type. Equivalent to `
|
|
1482
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1421
1483
|
* `torch.Tensor`.
|
|
1422
1484
|
*
|
|
1423
1485
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -1428,9 +1490,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1428
1490
|
static #nextId = 1001;
|
|
1429
1491
|
id;
|
|
1430
1492
|
#dtype;
|
|
1493
|
+
#weakType;
|
|
1431
1494
|
#source;
|
|
1432
1495
|
#st;
|
|
1433
1496
|
#backend;
|
|
1497
|
+
#committed;
|
|
1434
1498
|
#rc;
|
|
1435
1499
|
#pendingSet;
|
|
1436
1500
|
/**
|
|
@@ -1439,21 +1503,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1439
1503
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1440
1504
|
* will be freed when the array is disposed.
|
|
1441
1505
|
*/
|
|
1442
|
-
constructor(
|
|
1506
|
+
constructor(args) {
|
|
1443
1507
|
super(baseArrayTrace);
|
|
1444
1508
|
this.id = Array$1.#nextId++;
|
|
1445
|
-
this.#dtype = dtype;
|
|
1446
|
-
this.#
|
|
1447
|
-
this.#
|
|
1448
|
-
this.#
|
|
1509
|
+
this.#dtype = args.dtype;
|
|
1510
|
+
this.#weakType = args.weakType;
|
|
1511
|
+
this.#source = args.source;
|
|
1512
|
+
this.#st = args.st;
|
|
1513
|
+
this.#backend = args.backend;
|
|
1514
|
+
this.#committed = args.committed;
|
|
1449
1515
|
this.#rc = 1;
|
|
1450
|
-
this.#pendingSet = new Set(pending);
|
|
1516
|
+
this.#pendingSet = new Set(args.pending);
|
|
1451
1517
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1452
|
-
else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1518
|
+
else if (this.#source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1453
1519
|
}
|
|
1454
1520
|
/** @ignore */
|
|
1455
1521
|
get aval() {
|
|
1456
|
-
return new ShapedArray(this.#st.shape, this.#dtype);
|
|
1522
|
+
return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
|
|
1457
1523
|
}
|
|
1458
1524
|
/** Return a simple string representation of the array's dimensions. */
|
|
1459
1525
|
toString() {
|
|
@@ -1465,6 +1531,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1465
1531
|
#check() {
|
|
1466
1532
|
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1467
1533
|
}
|
|
1534
|
+
/** Construct an array, copying fields from `this`. */
|
|
1535
|
+
#newArrayFrom(args) {
|
|
1536
|
+
return new Array$1({
|
|
1537
|
+
source: args.source ?? this.#source,
|
|
1538
|
+
st: args.st ?? this.#st,
|
|
1539
|
+
dtype: args.dtype ?? this.#dtype,
|
|
1540
|
+
weakType: this.#weakType,
|
|
1541
|
+
backend: args.backend ?? this.#backend,
|
|
1542
|
+
committed: args.committed ?? this.#committed,
|
|
1543
|
+
pending: args.pending ?? this.#pending ?? void 0
|
|
1544
|
+
});
|
|
1545
|
+
}
|
|
1468
1546
|
get ref() {
|
|
1469
1547
|
this.#check();
|
|
1470
1548
|
this.#rc++;
|
|
@@ -1504,7 +1582,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1504
1582
|
const pending = this.#pending;
|
|
1505
1583
|
for (const exe of pending) exe.updateRc(1);
|
|
1506
1584
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1507
|
-
const ar =
|
|
1585
|
+
const ar = this.#newArrayFrom({
|
|
1586
|
+
st,
|
|
1587
|
+
pending
|
|
1588
|
+
});
|
|
1508
1589
|
this.dispose();
|
|
1509
1590
|
return ar;
|
|
1510
1591
|
}
|
|
@@ -1514,9 +1595,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1514
1595
|
*/
|
|
1515
1596
|
#gather(indices, axis, outDim) {
|
|
1516
1597
|
this.#check();
|
|
1517
|
-
if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1518
1598
|
const axisSet = new Set(axis);
|
|
1519
1599
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
|
|
1600
|
+
if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1601
|
+
indices = indices.map((ar) => ar._putSync(this.#backend));
|
|
1520
1602
|
indices = Array$1.#broadcastArrays(indices);
|
|
1521
1603
|
const indexShape = indices[0].shape;
|
|
1522
1604
|
const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1553,7 +1635,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1553
1635
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1554
1636
|
this.dispose();
|
|
1555
1637
|
for (const ar of indices) ar.dispose();
|
|
1556
|
-
return
|
|
1638
|
+
return this.#newArrayFrom({
|
|
1639
|
+
source: output,
|
|
1640
|
+
st: require_backend.ShapeTracker.fromShape(finalShape),
|
|
1641
|
+
pending
|
|
1642
|
+
});
|
|
1557
1643
|
}
|
|
1558
1644
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1559
1645
|
#moveAxesDown(axis) {
|
|
@@ -1576,11 +1662,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1576
1662
|
return this.#reshape(this.#st.permute(perm));
|
|
1577
1663
|
}
|
|
1578
1664
|
#unary(op, dtypeOutput) {
|
|
1665
|
+
const weakType = !dtypeOutput && this.#weakType;
|
|
1579
1666
|
dtypeOutput ??= this.#dtype;
|
|
1580
1667
|
this.#check();
|
|
1581
1668
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1582
1669
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1583
|
-
|
|
1670
|
+
this.dispose();
|
|
1671
|
+
return this.#newArrayFrom({
|
|
1672
|
+
source: exp$3.simplify(),
|
|
1673
|
+
dtype: dtypeOutput,
|
|
1674
|
+
weakType
|
|
1675
|
+
});
|
|
1584
1676
|
}
|
|
1585
1677
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1586
1678
|
const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1590,41 +1682,67 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1590
1682
|
for (const exe of pending) exe.updateRc(1);
|
|
1591
1683
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1592
1684
|
this.dispose();
|
|
1593
|
-
return
|
|
1685
|
+
return this.#newArrayFrom({
|
|
1686
|
+
source: output,
|
|
1687
|
+
st: require_backend.ShapeTracker.fromShape(this.shape),
|
|
1688
|
+
dtype: dtypeOutput,
|
|
1689
|
+
weakType,
|
|
1690
|
+
pending
|
|
1691
|
+
});
|
|
1594
1692
|
}
|
|
1595
1693
|
#binary(op, other) {
|
|
1596
|
-
const custom = (src) => new require_backend.AluExp(op,
|
|
1694
|
+
const custom = (src) => new require_backend.AluExp(op, src[0].dtype, src);
|
|
1597
1695
|
return Array$1.#naryCustom(op, custom, [this, other]);
|
|
1598
1696
|
}
|
|
1599
|
-
static #naryCustom(name, custom, arrays, { dtypeOverride,
|
|
1697
|
+
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1600
1698
|
const n = arrays.length;
|
|
1601
|
-
const backend = arrays[0].#backend;
|
|
1602
1699
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1603
1700
|
for (const ar of arrays) ar.#check();
|
|
1604
|
-
let
|
|
1605
|
-
|
|
1606
|
-
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
|
|
1611
|
-
}
|
|
1612
|
-
|
|
1613
|
-
|
|
1701
|
+
let castDtype;
|
|
1702
|
+
let castWeakType = true;
|
|
1703
|
+
for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
|
|
1704
|
+
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1705
|
+
} else if (castDtype === void 0) {
|
|
1706
|
+
castDtype = arrays[i].#dtype;
|
|
1707
|
+
castWeakType = arrays[i].#weakType;
|
|
1708
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1709
|
+
const weakType = castWeakType && !strongTypeOutput;
|
|
1710
|
+
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
1711
|
+
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
1614
1712
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1615
1713
|
const newShape = [...arrays[0].shape];
|
|
1616
1714
|
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
1715
|
+
const sources = arrays.map((ar, i) => {
|
|
1716
|
+
if (!dtypeOverride?.[i]) return require_backend.AluExp.cast(castDtype, ar.#source);
|
|
1717
|
+
else return ar.#source;
|
|
1718
|
+
});
|
|
1617
1719
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1618
|
-
const exp$4 = custom(
|
|
1619
|
-
|
|
1720
|
+
const exp$4 = custom(sources);
|
|
1721
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1722
|
+
return new Array$1({
|
|
1723
|
+
source: exp$4.simplify(),
|
|
1724
|
+
st: arrays[0].#st,
|
|
1725
|
+
dtype: exp$4.dtype,
|
|
1726
|
+
weakType,
|
|
1727
|
+
backend,
|
|
1728
|
+
committed
|
|
1729
|
+
});
|
|
1620
1730
|
}
|
|
1621
|
-
const exp$3 = custom(arrays.map((ar) => {
|
|
1622
|
-
const src$1 =
|
|
1731
|
+
const exp$3 = custom(arrays.map((ar, i) => {
|
|
1732
|
+
const src$1 = sources[i];
|
|
1623
1733
|
if (ar.#st.contiguous) return src$1;
|
|
1624
1734
|
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1625
1735
|
}));
|
|
1626
1736
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1627
|
-
|
|
1737
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1738
|
+
return new Array$1({
|
|
1739
|
+
source: exp$3.simplify(),
|
|
1740
|
+
st,
|
|
1741
|
+
dtype: exp$3.dtype,
|
|
1742
|
+
weakType,
|
|
1743
|
+
backend,
|
|
1744
|
+
committed
|
|
1745
|
+
});
|
|
1628
1746
|
}
|
|
1629
1747
|
let indices;
|
|
1630
1748
|
if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
|
|
@@ -1634,14 +1752,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1634
1752
|
}
|
|
1635
1753
|
const inputs = [];
|
|
1636
1754
|
const src = [];
|
|
1637
|
-
for (const ar of arrays
|
|
1638
|
-
|
|
1639
|
-
|
|
1640
|
-
|
|
1641
|
-
gid = inputs.
|
|
1642
|
-
|
|
1755
|
+
for (const [i, ar] of arrays.entries()) {
|
|
1756
|
+
let nextSrc;
|
|
1757
|
+
if (ar.#source instanceof require_backend.AluExp) nextSrc = require_backend.accessorAluExp(ar.#source, ar.#st, indices);
|
|
1758
|
+
else {
|
|
1759
|
+
let gid = inputs.indexOf(ar.#source);
|
|
1760
|
+
if (gid === -1) {
|
|
1761
|
+
gid = inputs.length;
|
|
1762
|
+
inputs.push(ar.#source);
|
|
1763
|
+
}
|
|
1764
|
+
nextSrc = require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
|
|
1643
1765
|
}
|
|
1644
|
-
|
|
1766
|
+
if (!dtypeOverride?.[i]) nextSrc = require_backend.AluExp.cast(castDtype, nextSrc);
|
|
1767
|
+
src.push(nextSrc);
|
|
1645
1768
|
}
|
|
1646
1769
|
const exp$2 = custom(src);
|
|
1647
1770
|
let re = void 0;
|
|
@@ -1654,13 +1777,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1654
1777
|
const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
|
|
1655
1778
|
for (const exe of pending) exe.updateRc(1);
|
|
1656
1779
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1657
|
-
|
|
1658
|
-
return new Array$1(
|
|
1780
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1781
|
+
return new Array$1({
|
|
1782
|
+
source: output,
|
|
1783
|
+
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1784
|
+
dtype: kernel.dtype,
|
|
1785
|
+
weakType,
|
|
1786
|
+
backend,
|
|
1787
|
+
committed,
|
|
1788
|
+
pending
|
|
1789
|
+
});
|
|
1659
1790
|
}
|
|
1660
1791
|
/** Reduce the last dimension of the array by an operation. */
|
|
1661
1792
|
#reduce(op) {
|
|
1662
|
-
this.#check();
|
|
1663
|
-
if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
|
|
1664
1793
|
const shape$1 = this.shape;
|
|
1665
1794
|
const reduction = new require_backend.Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
|
|
1666
1795
|
const newShape = shape$1.slice(0, -1);
|
|
@@ -1679,7 +1808,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1679
1808
|
for (const exe of pending) exe.updateRc(1);
|
|
1680
1809
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1681
1810
|
this.dispose();
|
|
1682
|
-
return
|
|
1811
|
+
return this.#newArrayFrom({
|
|
1812
|
+
source: output,
|
|
1813
|
+
st: require_backend.ShapeTracker.fromShape(newShape),
|
|
1814
|
+
pending
|
|
1815
|
+
});
|
|
1683
1816
|
}
|
|
1684
1817
|
/**
|
|
1685
1818
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1715,20 +1848,37 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1715
1848
|
}
|
|
1716
1849
|
#dataInline() {
|
|
1717
1850
|
this.#check();
|
|
1718
|
-
|
|
1719
|
-
const ar =
|
|
1851
|
+
if (!(this.#source instanceof require_backend.AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
|
|
1852
|
+
const ar = this.#newArrayFrom({ backend: require_backend.getBackend("cpu") });
|
|
1720
1853
|
this.dispose();
|
|
1721
1854
|
return ar.dataSync();
|
|
1722
1855
|
}
|
|
1723
1856
|
static #broadcastArrays(arrays) {
|
|
1724
1857
|
if (arrays.length === 0) throw new Error("Need at least one array to broadcast");
|
|
1725
1858
|
if (arrays.length === 1) return arrays;
|
|
1726
|
-
const newShape = arrays.map((a) => a.shape).reduce(generalBroadcast);
|
|
1859
|
+
const newShape = arrays.map((a) => a.shape).reduce(require_backend.generalBroadcast);
|
|
1727
1860
|
return arrays.map((ar) => {
|
|
1728
1861
|
if (require_backend.deepEqual(ar.shape, newShape)) return ar;
|
|
1729
1862
|
return ar.#reshape(ar.#st.broadcast(newShape, require_backend.range(newShape.length - ar.ndim)));
|
|
1730
1863
|
});
|
|
1731
1864
|
}
|
|
1865
|
+
static #computeBackend(name, arrays) {
|
|
1866
|
+
const committed = arrays.filter((ar) => ar.#committed);
|
|
1867
|
+
if (committed.length > 0) {
|
|
1868
|
+
const backend = committed[0].#backend;
|
|
1869
|
+
for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
|
|
1870
|
+
return {
|
|
1871
|
+
backend,
|
|
1872
|
+
committed: true
|
|
1873
|
+
};
|
|
1874
|
+
} else {
|
|
1875
|
+
const backend = arrays.length > 0 ? arrays[0].#backend : require_backend.getBackend();
|
|
1876
|
+
return {
|
|
1877
|
+
backend,
|
|
1878
|
+
committed: false
|
|
1879
|
+
};
|
|
1880
|
+
}
|
|
1881
|
+
}
|
|
1732
1882
|
/** Realize the array and return it as data. */
|
|
1733
1883
|
async data() {
|
|
1734
1884
|
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
@@ -1842,14 +1992,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1842
1992
|
x.#backend.incRef(x.#source);
|
|
1843
1993
|
const pending = x.#pending;
|
|
1844
1994
|
for (const exe of pending) exe.updateRc(1);
|
|
1845
|
-
const y =
|
|
1995
|
+
const y = x.#newArrayFrom({
|
|
1996
|
+
dtype,
|
|
1997
|
+
weakType: false,
|
|
1998
|
+
pending
|
|
1999
|
+
});
|
|
1846
2000
|
x.dispose();
|
|
1847
2001
|
return [y];
|
|
1848
2002
|
}
|
|
1849
2003
|
},
|
|
1850
2004
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
1851
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1852
|
-
if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2005
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2006
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1853
2007
|
const c0 = zeros(shape$1, {
|
|
1854
2008
|
dtype: require_backend.DType.Uint32,
|
|
1855
2009
|
device: k0.device
|
|
@@ -1884,6 +2038,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1884
2038
|
[Primitive.Log]([x]) {
|
|
1885
2039
|
return [x.#unary(require_backend.AluOp.Log)];
|
|
1886
2040
|
},
|
|
2041
|
+
[Primitive.Erf]([x]) {
|
|
2042
|
+
return [x.#unary(require_backend.AluOp.Erf)];
|
|
2043
|
+
},
|
|
2044
|
+
[Primitive.Erfc]([x]) {
|
|
2045
|
+
return [x.#unary(require_backend.AluOp.Erfc)];
|
|
2046
|
+
},
|
|
1887
2047
|
[Primitive.Sqrt]([x]) {
|
|
1888
2048
|
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
1889
2049
|
},
|
|
@@ -1917,7 +2077,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1917
2077
|
},
|
|
1918
2078
|
[Primitive.Compare]([x, y], { op }) {
|
|
1919
2079
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1920
|
-
return [Array$1.#naryCustom("compare", custom, [x, y], {
|
|
2080
|
+
return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
|
|
1921
2081
|
},
|
|
1922
2082
|
[Primitive.Where]([cond, x, y]) {
|
|
1923
2083
|
const custom = ([cond$1, x$1, y$1]) => require_backend.AluExp.where(cond$1, x$1, y$1);
|
|
@@ -1952,7 +2112,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1952
2112
|
},
|
|
1953
2113
|
[Primitive.JitCall](args, { jaxpr, numConsts }) {
|
|
1954
2114
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
1955
|
-
const backend =
|
|
2115
|
+
const { backend, committed } = Array$1.#computeBackend("jit_call", args);
|
|
2116
|
+
args = args.map((ar) => ar._putSync(backend));
|
|
1956
2117
|
const consts = args.slice(0, numConsts);
|
|
1957
2118
|
const tracers = args.slice(numConsts);
|
|
1958
2119
|
const jp = jitCompile(backend, jaxpr, consts);
|
|
@@ -1963,43 +2124,66 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1963
2124
|
pending.splice(0, 0, ...prevPending);
|
|
1964
2125
|
args.forEach((x) => x.dispose());
|
|
1965
2126
|
return outputs.map((source, i) => {
|
|
1966
|
-
return new Array$1(
|
|
2127
|
+
return new Array$1({
|
|
2128
|
+
source,
|
|
2129
|
+
st: require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
|
|
2130
|
+
dtype: jaxpr.outs[i].aval.dtype,
|
|
2131
|
+
weakType: jaxpr.outs[i].aval.weakType,
|
|
2132
|
+
backend,
|
|
2133
|
+
committed,
|
|
2134
|
+
pending
|
|
2135
|
+
});
|
|
1967
2136
|
});
|
|
1968
2137
|
}
|
|
1969
2138
|
};
|
|
1970
2139
|
}
|
|
2140
|
+
/** @private */
|
|
1971
2141
|
_realizeSource() {
|
|
1972
2142
|
this.#realize();
|
|
1973
2143
|
return this.#source;
|
|
1974
2144
|
}
|
|
2145
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
2146
|
+
async _put(backend) {
|
|
2147
|
+
if (this.#backend === backend) return this;
|
|
2148
|
+
if (this.#source instanceof require_backend.AluExp) {
|
|
2149
|
+
const ar = this.#newArrayFrom({
|
|
2150
|
+
backend,
|
|
2151
|
+
committed: true
|
|
2152
|
+
});
|
|
2153
|
+
this.dispose();
|
|
2154
|
+
return ar;
|
|
2155
|
+
} else {
|
|
2156
|
+
const data = await this.data();
|
|
2157
|
+
return arrayFromData(data, this.shape, {
|
|
2158
|
+
dtype: this.#dtype,
|
|
2159
|
+
device: backend.type
|
|
2160
|
+
}, this.#weakType);
|
|
2161
|
+
}
|
|
2162
|
+
}
|
|
2163
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
2164
|
+
_putSync(backend) {
|
|
2165
|
+
if (this.#backend === backend) return this;
|
|
2166
|
+
if (this.#source instanceof require_backend.AluExp) {
|
|
2167
|
+
const ar = this.#newArrayFrom({
|
|
2168
|
+
backend,
|
|
2169
|
+
committed: true
|
|
2170
|
+
});
|
|
2171
|
+
this.dispose();
|
|
2172
|
+
return ar;
|
|
2173
|
+
} else {
|
|
2174
|
+
const data = this.dataSync();
|
|
2175
|
+
return arrayFromData(data, this.shape, {
|
|
2176
|
+
dtype: this.#dtype,
|
|
2177
|
+
device: backend.type
|
|
2178
|
+
}, this.#weakType);
|
|
2179
|
+
}
|
|
2180
|
+
}
|
|
1975
2181
|
};
|
|
1976
|
-
/** Construct an array from a single scalar constant. */
|
|
1977
|
-
function scalar(value, { dtype, device } = {}) {
|
|
1978
|
-
if (typeof value === "number") {
|
|
1979
|
-
dtype ??= require_backend.DType.Float32;
|
|
1980
|
-
if (![
|
|
1981
|
-
require_backend.DType.Float32,
|
|
1982
|
-
require_backend.DType.Float16,
|
|
1983
|
-
require_backend.DType.Int32,
|
|
1984
|
-
require_backend.DType.Uint32
|
|
1985
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1986
|
-
} else if (typeof value === "boolean") {
|
|
1987
|
-
dtype ??= require_backend.DType.Bool;
|
|
1988
|
-
if (![
|
|
1989
|
-
require_backend.DType.Float32,
|
|
1990
|
-
require_backend.DType.Float16,
|
|
1991
|
-
require_backend.DType.Int32,
|
|
1992
|
-
require_backend.DType.Uint32,
|
|
1993
|
-
require_backend.DType.Bool
|
|
1994
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1995
|
-
} else throw new TypeError(`Invalid type for scalar ${value}`);
|
|
1996
|
-
return new Array$1(require_backend.AluExp.const(dtype, value), require_backend.ShapeTracker.fromShape([]), dtype, require_backend.getBackend(device));
|
|
1997
|
-
}
|
|
1998
2182
|
/** Constructor for creating a new array from data. */
|
|
1999
2183
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
2000
2184
|
if (values instanceof Tracer) {
|
|
2001
2185
|
if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
2002
|
-
if (dtype && values.dtype !== dtype)
|
|
2186
|
+
if (dtype && values.dtype !== dtype) values = values.astype(dtype);
|
|
2003
2187
|
return values;
|
|
2004
2188
|
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
2005
2189
|
dtype,
|
|
@@ -2021,6 +2205,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2021
2205
|
dtype,
|
|
2022
2206
|
device
|
|
2023
2207
|
});
|
|
2208
|
+
if (size$1 === 1) return full(shape$1, flat[0], {
|
|
2209
|
+
dtype,
|
|
2210
|
+
device
|
|
2211
|
+
});
|
|
2024
2212
|
if (typeof flat[0] === "boolean") {
|
|
2025
2213
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
2026
2214
|
const data = new Int32Array(flat.map((x) => x ? 1 : 0));
|
|
@@ -2029,46 +2217,52 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
2029
2217
|
device
|
|
2030
2218
|
});
|
|
2031
2219
|
} else {
|
|
2220
|
+
const weakType = dtype == void 0;
|
|
2032
2221
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2033
2222
|
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
2034
2223
|
return arrayFromData(data, shape$1, {
|
|
2035
2224
|
dtype,
|
|
2036
2225
|
device
|
|
2037
|
-
});
|
|
2226
|
+
}, weakType);
|
|
2038
2227
|
}
|
|
2039
2228
|
}
|
|
2040
2229
|
}
|
|
2041
|
-
function arrayFromData(data, shape$1, { dtype, device } =
|
|
2230
|
+
function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
|
|
2231
|
+
if (data instanceof Float32Array) {
|
|
2232
|
+
if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2233
|
+
dtype ??= require_backend.DType.Float32;
|
|
2234
|
+
} else if (data instanceof Int32Array) {
|
|
2235
|
+
if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2236
|
+
dtype ??= require_backend.DType.Int32;
|
|
2237
|
+
} else if (data instanceof Uint32Array) {
|
|
2238
|
+
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2239
|
+
dtype ??= require_backend.DType.Uint32;
|
|
2240
|
+
} else if (data instanceof Float16Array) {
|
|
2241
|
+
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2242
|
+
dtype ??= require_backend.DType.Float16;
|
|
2243
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2042
2244
|
if (data.length < inlineArrayLimit) {
|
|
2043
2245
|
let allEqual = true;
|
|
2044
2246
|
for (let i = 1; i < data.length; i++) if (data[i] !== data[0]) {
|
|
2045
2247
|
allEqual = false;
|
|
2046
2248
|
break;
|
|
2047
2249
|
}
|
|
2048
|
-
if (allEqual)
|
|
2049
|
-
dtype,
|
|
2050
|
-
device
|
|
2051
|
-
}
|
|
2250
|
+
if (allEqual) {
|
|
2251
|
+
const sa = new ShapedArray(shape$1, dtype, weakType);
|
|
2252
|
+
return fullInternal(sa, data[0], device);
|
|
2253
|
+
}
|
|
2052
2254
|
}
|
|
2053
2255
|
const backend = require_backend.getBackend(device);
|
|
2054
|
-
|
|
2055
|
-
|
|
2056
|
-
|
|
2057
|
-
|
|
2058
|
-
|
|
2059
|
-
|
|
2060
|
-
|
|
2061
|
-
|
|
2062
|
-
|
|
2063
|
-
|
|
2064
|
-
dtype ??= require_backend.DType.Uint32;
|
|
2065
|
-
} else if (data instanceof Float16Array) {
|
|
2066
|
-
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2067
|
-
dtype ??= require_backend.DType.Float16;
|
|
2068
|
-
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2069
|
-
const slot = backend.malloc(data.byteLength, buf);
|
|
2070
|
-
return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2071
|
-
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
2256
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2257
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2258
|
+
return new Array$1({
|
|
2259
|
+
source: slot,
|
|
2260
|
+
st: require_backend.ShapeTracker.fromShape(shape$1),
|
|
2261
|
+
dtype,
|
|
2262
|
+
weakType,
|
|
2263
|
+
backend,
|
|
2264
|
+
committed: device != void 0
|
|
2265
|
+
});
|
|
2072
2266
|
}
|
|
2073
2267
|
function dataToJs(dtype, data, shape$1) {
|
|
2074
2268
|
if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -2084,7 +2278,7 @@ function dataToJs(dtype, data, shape$1) {
|
|
|
2084
2278
|
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
2085
2279
|
function pureArray(x) {
|
|
2086
2280
|
if (x instanceof Tracer) return x;
|
|
2087
|
-
else return
|
|
2281
|
+
else return array(x);
|
|
2088
2282
|
}
|
|
2089
2283
|
var EvalTrace = class extends Trace {
|
|
2090
2284
|
pure = (x) => pureArray(x);
|
|
@@ -2095,20 +2289,28 @@ var EvalTrace = class extends Trace {
|
|
|
2095
2289
|
};
|
|
2096
2290
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2097
2291
|
const implRules = Array$1._implRules();
|
|
2292
|
+
function fullInternal(aval, fillValue, device) {
|
|
2293
|
+
return new Array$1({
|
|
2294
|
+
source: require_backend.AluExp.const(aval.dtype, fillValue),
|
|
2295
|
+
st: require_backend.ShapeTracker.fromShape(aval.shape),
|
|
2296
|
+
dtype: aval.dtype,
|
|
2297
|
+
weakType: aval.weakType,
|
|
2298
|
+
backend: require_backend.getBackend(device),
|
|
2299
|
+
committed: device != void 0
|
|
2300
|
+
});
|
|
2301
|
+
}
|
|
2098
2302
|
function zerosLike$1(val, dtype) {
|
|
2099
|
-
|
|
2100
|
-
if (val instanceof Tracer) val.dispose();
|
|
2101
|
-
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2303
|
+
return fullLike(val, 0, dtype);
|
|
2102
2304
|
}
|
|
2103
2305
|
function onesLike$1(val, dtype) {
|
|
2104
|
-
|
|
2105
|
-
if (val instanceof Tracer) val.dispose();
|
|
2106
|
-
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2306
|
+
return fullLike(val, 1, dtype);
|
|
2107
2307
|
}
|
|
2108
2308
|
function fullLike(val, fillValue, dtype) {
|
|
2109
2309
|
const aval = getAval(val);
|
|
2110
2310
|
if (val instanceof Tracer) val.dispose();
|
|
2111
|
-
|
|
2311
|
+
if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
|
|
2312
|
+
const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, aval.weakType);
|
|
2313
|
+
return fullInternal(sa, fillValue);
|
|
2112
2314
|
}
|
|
2113
2315
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
2114
2316
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -2126,19 +2328,14 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
2126
2328
|
}
|
|
2127
2329
|
/** Return a new array of given shape and type, filled with `fill_value`. */
|
|
2128
2330
|
function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
2129
|
-
let
|
|
2130
|
-
if (typeof fillValue === "number")
|
|
2131
|
-
|
|
2132
|
-
source = require_backend.AluExp.const(dtype, fillValue);
|
|
2133
|
-
} else if (typeof fillValue === "bigint") {
|
|
2134
|
-
dtype = dtype ?? require_backend.DType.Int32;
|
|
2135
|
-
source = require_backend.AluExp.const(dtype, Number(fillValue));
|
|
2136
|
-
} else if (typeof fillValue === "boolean") {
|
|
2331
|
+
let weakType = dtype == void 0;
|
|
2332
|
+
if (typeof fillValue === "number") dtype = dtype ?? require_backend.DType.Float32;
|
|
2333
|
+
else if (typeof fillValue === "boolean") {
|
|
2137
2334
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
2138
|
-
|
|
2335
|
+
weakType = false;
|
|
2139
2336
|
} else if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
|
|
2140
2337
|
else throw new TypeError(`Invalid type for full: ${fillValue}`);
|
|
2141
|
-
return new
|
|
2338
|
+
return fullInternal(new ShapedArray(shape$1, dtype, weakType), fillValue, device);
|
|
2142
2339
|
}
|
|
2143
2340
|
/**
|
|
2144
2341
|
* Create an identity matrix.
|
|
@@ -2148,6 +2345,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
2148
2345
|
*/
|
|
2149
2346
|
function eye(numRows, numCols, { dtype, device } = {}) {
|
|
2150
2347
|
numCols = numCols ?? numRows;
|
|
2348
|
+
const weakType = dtype == void 0;
|
|
2151
2349
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
2152
2350
|
if (numCols < numRows) {
|
|
2153
2351
|
const arr = eye(numCols, numRows, {
|
|
@@ -2161,7 +2359,14 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2161
2359
|
device
|
|
2162
2360
|
});
|
|
2163
2361
|
const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
|
|
2164
|
-
return new Array$1(
|
|
2362
|
+
return new Array$1({
|
|
2363
|
+
source: require_backend.AluExp.cast(dtype, exp$2),
|
|
2364
|
+
st: require_backend.ShapeTracker.fromShape([numRows, numCols]),
|
|
2365
|
+
dtype,
|
|
2366
|
+
weakType,
|
|
2367
|
+
backend: require_backend.getBackend(device),
|
|
2368
|
+
committed: device != void 0
|
|
2369
|
+
});
|
|
2165
2370
|
}
|
|
2166
2371
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2167
2372
|
function identity$1(n, { dtype, device } = {}) {
|
|
@@ -2198,7 +2403,14 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2198
2403
|
});
|
|
2199
2404
|
const exp$2 = require_backend.AluExp.add(require_backend.AluExp.const(dtype, start), require_backend.AluExp.mul(require_backend.AluExp.cast(dtype, require_backend.AluVar.idx), require_backend.AluExp.const(dtype, step)));
|
|
2200
2405
|
const st = require_backend.ShapeTracker.fromShape([size$1]);
|
|
2201
|
-
return new Array$1(
|
|
2406
|
+
return new Array$1({
|
|
2407
|
+
source: exp$2,
|
|
2408
|
+
st,
|
|
2409
|
+
dtype,
|
|
2410
|
+
weakType: false,
|
|
2411
|
+
backend: require_backend.getBackend(device),
|
|
2412
|
+
committed: device != void 0
|
|
2413
|
+
});
|
|
2202
2414
|
}
|
|
2203
2415
|
/**
|
|
2204
2416
|
* Return evenly spaced numbers over a specified interval.
|
|
@@ -2216,10 +2428,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2216
2428
|
dtype,
|
|
2217
2429
|
device
|
|
2218
2430
|
});
|
|
2219
|
-
else if (num === 1) return
|
|
2431
|
+
else if (num === 1) return full([1], start, {
|
|
2220
2432
|
dtype,
|
|
2221
2433
|
device
|
|
2222
|
-
})
|
|
2434
|
+
});
|
|
2223
2435
|
else if (start === stop) return full([num], start, {
|
|
2224
2436
|
dtype,
|
|
2225
2437
|
device
|
|
@@ -2228,7 +2440,14 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2228
2440
|
const denom = endpoint ? num - 1 : num;
|
|
2229
2441
|
const exp$2 = require_backend.AluExp.cast(dtype, require_backend.AluExp.add(require_backend.AluExp.f32(start), require_backend.AluExp.mul(require_backend.AluExp.f32(delta / denom), require_backend.AluExp.cast(require_backend.DType.Float32, require_backend.AluVar.idx))));
|
|
2230
2442
|
const st = require_backend.ShapeTracker.fromShape([num]);
|
|
2231
|
-
return new Array$1(
|
|
2443
|
+
return new Array$1({
|
|
2444
|
+
source: exp$2,
|
|
2445
|
+
st,
|
|
2446
|
+
dtype,
|
|
2447
|
+
weakType: false,
|
|
2448
|
+
backend: require_backend.getBackend(device),
|
|
2449
|
+
committed: device != void 0
|
|
2450
|
+
});
|
|
2232
2451
|
}
|
|
2233
2452
|
function aluCompare(a, b, op) {
|
|
2234
2453
|
switch (op) {
|
|
@@ -2240,35 +2459,6 @@ function aluCompare(a, b, op) {
|
|
|
2240
2459
|
case CompareOp.LessEqual: return require_backend.AluExp.add(require_backend.AluExp.cmplt(a, b), require_backend.AluExp.cmpne(a, b).not());
|
|
2241
2460
|
}
|
|
2242
2461
|
}
|
|
2243
|
-
/**
|
|
2244
|
-
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
2245
|
-
*
|
|
2246
|
-
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
2247
|
-
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
2248
|
-
* Two dimensions are compatible when:
|
|
2249
|
-
* 1. they are equal, or
|
|
2250
|
-
* 2. one of them is 1."
|
|
2251
|
-
*
|
|
2252
|
-
* Throws a TypeError if the broadcast is not possible.
|
|
2253
|
-
*
|
|
2254
|
-
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
2255
|
-
*/
|
|
2256
|
-
function generalBroadcast(a, b) {
|
|
2257
|
-
const out = [];
|
|
2258
|
-
let i = a.length - 1;
|
|
2259
|
-
let j = b.length - 1;
|
|
2260
|
-
for (; i >= 0 && j >= 0; i--, j--) {
|
|
2261
|
-
const x = a[i];
|
|
2262
|
-
const y = b[j];
|
|
2263
|
-
if (x === y) out.push(x);
|
|
2264
|
-
else if (x === 1) out.push(y);
|
|
2265
|
-
else if (y === 1) out.push(x);
|
|
2266
|
-
else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
|
|
2267
|
-
}
|
|
2268
|
-
for (; i >= 0; i--) out.push(a[i]);
|
|
2269
|
-
for (; j >= 0; j--) out.push(b[j]);
|
|
2270
|
-
return out.reverse();
|
|
2271
|
-
}
|
|
2272
2462
|
|
|
2273
2463
|
//#endregion
|
|
2274
2464
|
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
|
|
@@ -2348,13 +2538,15 @@ var Var = class Var {
|
|
|
2348
2538
|
};
|
|
2349
2539
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
2350
2540
|
var Lit = class {
|
|
2351
|
-
dtype;
|
|
2352
2541
|
value;
|
|
2353
2542
|
aval;
|
|
2354
|
-
|
|
2355
|
-
this.dtype
|
|
2543
|
+
get dtype() {
|
|
2544
|
+
return this.aval.dtype;
|
|
2545
|
+
}
|
|
2546
|
+
constructor(aval, value) {
|
|
2547
|
+
if (aval.shape.length !== 0) throw new Error(`internal: Lit must be a scalar`);
|
|
2356
2548
|
this.value = value;
|
|
2357
|
-
this.aval =
|
|
2549
|
+
this.aval = ShapedArray.fromAval(aval);
|
|
2358
2550
|
}
|
|
2359
2551
|
};
|
|
2360
2552
|
function atomIsLit(atom, literal) {
|
|
@@ -2478,14 +2670,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2478
2670
|
const c = eqn.outBinders[0];
|
|
2479
2671
|
if (atomIsLit(a, 0)) context.set(c, b);
|
|
2480
2672
|
else if (atomIsLit(b, 0)) context.set(c, a);
|
|
2481
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2673
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === require_backend.DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
|
|
2674
|
+
else newEqns.push(eqn);
|
|
2675
|
+
} else if (eqn.primitive === Primitive.Neg) {
|
|
2676
|
+
const [a] = inputs;
|
|
2677
|
+
const c = eqn.outBinders[0];
|
|
2678
|
+
if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
|
|
2482
2679
|
else newEqns.push(eqn);
|
|
2483
2680
|
} else if (eqn.primitive === Primitive.Mul) {
|
|
2484
2681
|
const [a, b] = inputs;
|
|
2485
2682
|
const c = eqn.outBinders[0];
|
|
2486
2683
|
if (atomIsLit(a, 1)) context.set(c, b);
|
|
2487
2684
|
else if (atomIsLit(b, 1)) context.set(c, a);
|
|
2488
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2685
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
|
|
2489
2686
|
else newEqns.push(eqn);
|
|
2490
2687
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
2491
2688
|
const [a, b] = inputs;
|
|
@@ -2583,7 +2780,7 @@ function evalJaxpr(jaxpr, args) {
|
|
|
2583
2780
|
if (x instanceof Var) {
|
|
2584
2781
|
remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
|
|
2585
2782
|
return env.get(x);
|
|
2586
|
-
} else return
|
|
2783
|
+
} else return array(x.value, { dtype: x.dtype });
|
|
2587
2784
|
};
|
|
2588
2785
|
const write = (v, val) => {
|
|
2589
2786
|
if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
|
|
@@ -2642,7 +2839,7 @@ var JaxprTrace = class extends Trace {
|
|
|
2642
2839
|
let tracer = this.builder.constTracers.get(val);
|
|
2643
2840
|
if (tracer === void 0) {
|
|
2644
2841
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
2645
|
-
this.builder.addConst(tracer, val instanceof Tracer ? val.ref :
|
|
2842
|
+
this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
|
|
2646
2843
|
}
|
|
2647
2844
|
return tracer;
|
|
2648
2845
|
}
|
|
@@ -2711,7 +2908,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2711
2908
|
const newConsts = [];
|
|
2712
2909
|
for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
|
|
2713
2910
|
const ar = consts[i];
|
|
2714
|
-
literals.set(jaxpr.inBinders[i], new Lit(ar.
|
|
2911
|
+
literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
|
|
2715
2912
|
} else {
|
|
2716
2913
|
constBinders.push(jaxpr.inBinders[i]);
|
|
2717
2914
|
newConsts.push(consts[i]);
|
|
@@ -2724,13 +2921,12 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2724
2921
|
}
|
|
2725
2922
|
function binopAbstractEval([x, y]) {
|
|
2726
2923
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
2727
|
-
|
|
2728
|
-
return [new ShapedArray(generalBroadcast(x.shape, y.shape), x.dtype)];
|
|
2924
|
+
return [promoteAvals(x, y)];
|
|
2729
2925
|
}
|
|
2730
2926
|
function compareAbstractEval([x, y]) {
|
|
2731
2927
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("compareAbstractEval expects ShapedArray inputs");
|
|
2732
|
-
|
|
2733
|
-
return [new ShapedArray(
|
|
2928
|
+
const aval = promoteAvals(x, y);
|
|
2929
|
+
return [new ShapedArray(aval.shape, require_backend.DType.Bool, false)];
|
|
2734
2930
|
}
|
|
2735
2931
|
function vectorizedUnopAbstractEval([x]) {
|
|
2736
2932
|
return [ShapedArray.fromAval(x)];
|
|
@@ -2743,18 +2939,18 @@ const abstractEvalRules = {
|
|
|
2743
2939
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
2744
2940
|
[Primitive.StopGradient]: vectorizedUnopAbstractEval,
|
|
2745
2941
|
[Primitive.Cast]([x], { dtype }) {
|
|
2746
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2942
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2747
2943
|
},
|
|
2748
2944
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
2749
2945
|
if (x.dtype === require_backend.DType.Bool || dtype === require_backend.DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
2750
2946
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
2751
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2947
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2752
2948
|
},
|
|
2753
2949
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
2754
2950
|
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
2755
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2756
|
-
if (!require_backend.deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2757
|
-
return [new ShapedArray(shape$1, require_backend.DType.Uint32)];
|
|
2951
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2952
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2953
|
+
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
2758
2954
|
},
|
|
2759
2955
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2760
2956
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
@@ -2762,61 +2958,62 @@ const abstractEvalRules = {
|
|
|
2762
2958
|
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2763
2959
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2764
2960
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2961
|
+
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
2962
|
+
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
2765
2963
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2766
2964
|
[Primitive.Min]: binopAbstractEval,
|
|
2767
2965
|
[Primitive.Max]: binopAbstractEval,
|
|
2768
2966
|
[Primitive.Reduce]([x], { axis }) {
|
|
2769
2967
|
const axisSet = new Set(axis);
|
|
2770
2968
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2771
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2969
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2772
2970
|
},
|
|
2773
2971
|
[Primitive.Pool]([x], { window, strides }) {
|
|
2774
2972
|
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2775
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2973
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2776
2974
|
},
|
|
2777
2975
|
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2778
2976
|
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2779
2977
|
if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2780
|
-
return [new ShapedArray(inShape, x.dtype)];
|
|
2978
|
+
return [new ShapedArray(inShape, x.dtype, x.weakType)];
|
|
2781
2979
|
},
|
|
2782
2980
|
[Primitive.Dot]([x, y]) {
|
|
2783
|
-
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2784
2981
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
2785
|
-
const shape$1 =
|
|
2982
|
+
const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
|
|
2786
2983
|
shape$1.splice(-1, 1);
|
|
2787
|
-
return [new ShapedArray(shape$1,
|
|
2984
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2788
2985
|
},
|
|
2789
2986
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
2790
|
-
|
|
2987
|
+
const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
|
|
2791
2988
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2792
|
-
return [new ShapedArray(shape$1,
|
|
2989
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2793
2990
|
},
|
|
2794
2991
|
[Primitive.Compare]: compareAbstractEval,
|
|
2795
2992
|
[Primitive.Where]([cond, x, y]) {
|
|
2796
2993
|
if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
2797
|
-
|
|
2798
|
-
const shape$1 = generalBroadcast(cond.shape,
|
|
2799
|
-
return [new ShapedArray(shape$1,
|
|
2994
|
+
const xy = promoteAvals(x, y);
|
|
2995
|
+
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
2996
|
+
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
2800
2997
|
},
|
|
2801
2998
|
[Primitive.Transpose]([x], { perm }) {
|
|
2802
|
-
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
|
|
2999
|
+
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
2803
3000
|
},
|
|
2804
3001
|
[Primitive.Broadcast]([x], { shape: shape$1 }) {
|
|
2805
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
3002
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2806
3003
|
},
|
|
2807
3004
|
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2808
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
3005
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2809
3006
|
},
|
|
2810
3007
|
[Primitive.Flip]([x], _) {
|
|
2811
|
-
return [
|
|
3008
|
+
return [ShapedArray.fromAval(x)];
|
|
2812
3009
|
},
|
|
2813
3010
|
[Primitive.Shrink]([x], { slice }) {
|
|
2814
3011
|
const newShape = slice.map((s) => s[1] - s[0]);
|
|
2815
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
3012
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2816
3013
|
},
|
|
2817
3014
|
[Primitive.Pad]([x], { width }) {
|
|
2818
3015
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
2819
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
3016
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2820
3017
|
},
|
|
2821
3018
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2822
3019
|
for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
@@ -2826,10 +3023,10 @@ const abstractEvalRules = {
|
|
|
2826
3023
|
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
2827
3024
|
const axisSet = new Set(axis);
|
|
2828
3025
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
2829
|
-
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
3026
|
+
const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
|
|
2830
3027
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2831
3028
|
newShape.splice(outDim, 0, ...gatherShape);
|
|
2832
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
3029
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2833
3030
|
},
|
|
2834
3031
|
[Primitive.JitCall](args, { jaxpr }) {
|
|
2835
3032
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
@@ -2896,6 +3093,7 @@ function jit$1(f, opts) {
|
|
|
2896
3093
|
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2897
3094
|
const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2898
3095
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
3096
|
+
name: f.name || "closure",
|
|
2899
3097
|
jaxpr,
|
|
2900
3098
|
numConsts: consts.length
|
|
2901
3099
|
});
|
|
@@ -3015,6 +3213,16 @@ const jvpRules = {
|
|
|
3015
3213
|
[Primitive.Log]([x], [dx]) {
|
|
3016
3214
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3017
3215
|
},
|
|
3216
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3217
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3218
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3219
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3220
|
+
},
|
|
3221
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3222
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3223
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3224
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3225
|
+
},
|
|
3018
3226
|
[Primitive.Sqrt]([x], [dx]) {
|
|
3019
3227
|
const z = sqrt$1(x);
|
|
3020
3228
|
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
@@ -3058,13 +3266,14 @@ const jvpRules = {
|
|
|
3058
3266
|
const indicesRef = indices.map((t) => t.ref);
|
|
3059
3267
|
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3060
3268
|
},
|
|
3061
|
-
[Primitive.JitCall](primals, tangents, { jaxpr }) {
|
|
3269
|
+
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3062
3270
|
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3063
3271
|
const outs = bind(Primitive.JitCall, [
|
|
3064
3272
|
...newConsts.map((c) => c.ref),
|
|
3065
3273
|
...primals,
|
|
3066
3274
|
...tangents
|
|
3067
3275
|
], {
|
|
3276
|
+
name: `${name}_jvp`,
|
|
3068
3277
|
jaxpr: newJaxpr,
|
|
3069
3278
|
numConsts: newConsts.length
|
|
3070
3279
|
});
|
|
@@ -3119,7 +3328,7 @@ var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
|
3119
3328
|
function mappedAval(batchDim, aval) {
|
|
3120
3329
|
const shape$1 = [...aval.shape];
|
|
3121
3330
|
shape$1.splice(batchDim, 1);
|
|
3122
|
-
return new ShapedArray(shape$1, aval.dtype);
|
|
3331
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3123
3332
|
}
|
|
3124
3333
|
/** Move one axis to a different index. */
|
|
3125
3334
|
function moveaxis$1(x, src, dst) {
|
|
@@ -3176,6 +3385,10 @@ var BatchTrace = class extends Trace {
|
|
|
3176
3385
|
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3177
3386
|
const vmapRule = vmapRules[primitive];
|
|
3178
3387
|
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3388
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3389
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3390
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3391
|
+
}
|
|
3179
3392
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3180
3393
|
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3181
3394
|
}
|
|
@@ -3183,24 +3396,28 @@ var BatchTrace = class extends Trace {
|
|
|
3183
3396
|
return this.main.globalData;
|
|
3184
3397
|
}
|
|
3185
3398
|
};
|
|
3186
|
-
|
|
3187
|
-
|
|
3188
|
-
|
|
3189
|
-
|
|
3190
|
-
|
|
3191
|
-
return broadcast(x, shape$1, axis);
|
|
3192
|
-
}
|
|
3193
|
-
}
|
|
3194
|
-
/** Process a primitive with built-in broadcasting. */
|
|
3399
|
+
/**
|
|
3400
|
+
* Process a primitive with built-in broadcasting.
|
|
3401
|
+
*
|
|
3402
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3403
|
+
*/
|
|
3195
3404
|
function broadcastBatcher(op) {
|
|
3196
3405
|
return (axisSize, args, dims) => {
|
|
3197
3406
|
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3198
|
-
const
|
|
3199
|
-
|
|
3200
|
-
|
|
3201
|
-
args
|
|
3202
|
-
|
|
3203
|
-
|
|
3407
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3408
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3409
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3410
|
+
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3411
|
+
args = args.map((x, i) => {
|
|
3412
|
+
if (dims[i] === null) return x;
|
|
3413
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3414
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3415
|
+
x.shape[0],
|
|
3416
|
+
...require_backend.rep(nd - x.ndim, 1),
|
|
3417
|
+
...x.shape.slice(1)
|
|
3418
|
+
]);
|
|
3419
|
+
return x;
|
|
3420
|
+
});
|
|
3204
3421
|
return [[op(...args)], [0]];
|
|
3205
3422
|
};
|
|
3206
3423
|
}
|
|
@@ -3224,17 +3441,18 @@ const vmapRules = {
|
|
|
3224
3441
|
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3225
3442
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3226
3443
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3444
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3445
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3227
3446
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3228
3447
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3229
3448
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3230
3449
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3231
|
-
|
|
3450
|
+
require_backend.assertNonNull(xBdim);
|
|
3232
3451
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3233
3452
|
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3234
3453
|
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3235
3454
|
},
|
|
3236
3455
|
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3237
|
-
if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
|
|
3238
3456
|
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3239
3457
|
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3240
3458
|
const z = dot$1(x, y);
|
|
@@ -3243,29 +3461,72 @@ const vmapRules = {
|
|
|
3243
3461
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3244
3462
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3245
3463
|
},
|
|
3464
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3465
|
+
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3466
|
+
require_backend.assertNonNull(xBdim);
|
|
3467
|
+
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
3468
|
+
newPerm.splice(xBdim, 0, xBdim);
|
|
3469
|
+
return [[transpose$1(x, newPerm)], [xBdim]];
|
|
3470
|
+
},
|
|
3471
|
+
[Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
|
|
3472
|
+
require_backend.assertNonNull(xBdim);
|
|
3473
|
+
const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
|
|
3474
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3475
|
+
return [[broadcast(x, newShape, newAxis)], [xBdim]];
|
|
3476
|
+
},
|
|
3246
3477
|
[Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
|
|
3247
|
-
if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
|
|
3248
3478
|
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3249
3479
|
return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
|
|
3250
3480
|
},
|
|
3251
3481
|
[Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
|
|
3252
|
-
|
|
3482
|
+
require_backend.assertNonNull(xBdim);
|
|
3253
3483
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3254
3484
|
return [[flip$1(x, newAxis)], [xBdim]];
|
|
3255
3485
|
},
|
|
3256
3486
|
[Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
|
|
3257
|
-
|
|
3487
|
+
require_backend.assertNonNull(xBdim);
|
|
3258
3488
|
const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
|
|
3259
3489
|
return [[shrink(x, newSlice)], [xBdim]];
|
|
3260
3490
|
},
|
|
3261
3491
|
[Primitive.Pad](axisSize, [x], [xBdim], { width }) {
|
|
3262
|
-
|
|
3492
|
+
require_backend.assertNonNull(xBdim);
|
|
3263
3493
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3264
3494
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3265
3495
|
},
|
|
3266
|
-
[Primitive.
|
|
3496
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3497
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3498
|
+
require_backend.assertNonNull(xBdim);
|
|
3499
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3500
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3501
|
+
let newOutDim = outDim;
|
|
3502
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3503
|
+
else newOutDim += 1;
|
|
3504
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3505
|
+
}
|
|
3506
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3507
|
+
indices = indices.map((m, i) => {
|
|
3508
|
+
if (indicesBdim[i] === null) return m;
|
|
3509
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3510
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3511
|
+
m.shape[0],
|
|
3512
|
+
...require_backend.rep(nd - m.ndim, 1),
|
|
3513
|
+
...m.shape.slice(1)
|
|
3514
|
+
]);
|
|
3515
|
+
return m;
|
|
3516
|
+
});
|
|
3517
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3518
|
+
else {
|
|
3519
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3520
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3521
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
|
|
3522
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3523
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3524
|
+
}
|
|
3525
|
+
},
|
|
3526
|
+
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3267
3527
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3268
3528
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
3529
|
+
name: `${name}_vmap`,
|
|
3269
3530
|
jaxpr: newJaxpr,
|
|
3270
3531
|
numConsts: newConsts.length
|
|
3271
3532
|
});
|
|
@@ -3281,7 +3542,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3281
3542
|
if (dims[i] === null) return v.aval;
|
|
3282
3543
|
const shape$1 = [...v.aval.shape];
|
|
3283
3544
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3284
|
-
return new ShapedArray(shape$1, v.aval.dtype);
|
|
3545
|
+
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3285
3546
|
});
|
|
3286
3547
|
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3287
3548
|
const result = {
|
|
@@ -3321,12 +3582,14 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3321
3582
|
function vmap$1(f, inAxes = 0) {
|
|
3322
3583
|
return (...args) => {
|
|
3323
3584
|
const [argsFlat, inTree] = flatten(args);
|
|
3324
|
-
let inAxesFlat;
|
|
3585
|
+
let inAxesFlat = [];
|
|
3325
3586
|
if (typeof inAxes === "number") inAxesFlat = require_backend.rep(argsFlat.length, inAxes);
|
|
3587
|
+
else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, null));
|
|
3588
|
+
else if (typeof inAxes[i] === "number") inAxesFlat.push(...require_backend.rep(inTree.childTreedefs[i].size, inAxes[i]));
|
|
3326
3589
|
else {
|
|
3327
|
-
|
|
3328
|
-
[
|
|
3329
|
-
|
|
3590
|
+
const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
|
|
3591
|
+
if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
|
|
3592
|
+
inAxesFlat.push(...axesFlat);
|
|
3330
3593
|
}
|
|
3331
3594
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3332
3595
|
const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
|
|
@@ -3494,8 +3757,8 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3494
3757
|
processPrimitive(primitive, tracers, params) {
|
|
3495
3758
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3496
3759
|
if (primitive === Primitive.JitCall) {
|
|
3497
|
-
const { jaxpr, numConsts } = params;
|
|
3498
|
-
return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
|
|
3760
|
+
const { name, jaxpr, numConsts } = params;
|
|
3761
|
+
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3499
3762
|
}
|
|
3500
3763
|
const tracersIn = tracers.map((t) => this.instantiateConst(t));
|
|
3501
3764
|
const avalsIn = tracersIn.map((t) => t.pval.aval);
|
|
@@ -3521,12 +3784,13 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3521
3784
|
*
|
|
3522
3785
|
* Used when encountering a JitCall rule during the trace.
|
|
3523
3786
|
*/
|
|
3524
|
-
#partialEvalJaxpr(jaxpr, numConsts, tracers) {
|
|
3787
|
+
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3525
3788
|
jaxpr = jaxpr.flatten();
|
|
3526
3789
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3527
3790
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3528
3791
|
const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
|
|
3529
3792
|
const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3793
|
+
name: `${name}_peval`,
|
|
3530
3794
|
jaxpr: jaxpr1,
|
|
3531
3795
|
numConsts: 0
|
|
3532
3796
|
});
|
|
@@ -3538,6 +3802,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3538
3802
|
prim: Primitive.JitCall,
|
|
3539
3803
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3540
3804
|
params: {
|
|
3805
|
+
name: `${name}_resid`,
|
|
3541
3806
|
jaxpr: jaxpr2,
|
|
3542
3807
|
numConsts: 0
|
|
3543
3808
|
},
|
|
@@ -3680,7 +3945,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
|
|
|
3680
3945
|
}
|
|
3681
3946
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
3682
3947
|
const eqn = jaxpr.eqns[i];
|
|
3683
|
-
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ?
|
|
3948
|
+
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
|
|
3684
3949
|
const cotangentsOut = eqn.outBinders.map(readCotangent);
|
|
3685
3950
|
const rule = transposeRules[eqn.primitive];
|
|
3686
3951
|
if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
|
|
@@ -3765,7 +4030,7 @@ const transposeRules = {
|
|
|
3765
4030
|
},
|
|
3766
4031
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3767
4032
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3768
|
-
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
4033
|
+
const axisSize = require_backend.generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3769
4034
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3770
4035
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3771
4036
|
},
|
|
@@ -3860,7 +4125,7 @@ const transposeRules = {
|
|
|
3860
4125
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
3861
4126
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
3862
4127
|
},
|
|
3863
|
-
[Primitive.JitCall](cts, args, { jaxpr }) {
|
|
4128
|
+
[Primitive.JitCall](cts, args, { name, jaxpr }) {
|
|
3864
4129
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
3865
4130
|
const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
|
|
3866
4131
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
@@ -3869,6 +4134,7 @@ const transposeRules = {
|
|
|
3869
4134
|
...residuals,
|
|
3870
4135
|
...cts
|
|
3871
4136
|
], {
|
|
4137
|
+
name: `${name}_t`,
|
|
3872
4138
|
jaxpr: newJaxpr,
|
|
3873
4139
|
numConsts: newConsts.length
|
|
3874
4140
|
});
|
|
@@ -3943,7 +4209,7 @@ function valueAndGrad$1(f) {
|
|
|
3943
4209
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3944
4210
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3945
4211
|
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3946
|
-
const [ct, ...rest] = fVjp(
|
|
4212
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
3947
4213
|
for (const r of rest) dispose(r);
|
|
3948
4214
|
fVjp.dispose();
|
|
3949
4215
|
return [y, ct];
|
|
@@ -3971,7 +4237,10 @@ __export(lax_exports, {
|
|
|
3971
4237
|
conv: () => conv$1,
|
|
3972
4238
|
convGeneralDilated: () => convGeneralDilated,
|
|
3973
4239
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
3974
|
-
|
|
4240
|
+
erf: () => erf,
|
|
4241
|
+
erfc: () => erfc,
|
|
4242
|
+
reduceWindow: () => reduceWindow,
|
|
4243
|
+
stopGradient: () => stopGradient$1
|
|
3975
4244
|
});
|
|
3976
4245
|
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
3977
4246
|
const padType = padding.toUpperCase();
|
|
@@ -4030,6 +4299,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
|
4030
4299
|
strides: windowStrides
|
|
4031
4300
|
}));
|
|
4032
4301
|
}
|
|
4302
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
4303
|
+
function erf(x) {
|
|
4304
|
+
return erf$1(x);
|
|
4305
|
+
}
|
|
4306
|
+
/**
|
|
4307
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
4308
|
+
*
|
|
4309
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
4310
|
+
* where `erf(x)` is very close to 1.
|
|
4311
|
+
*/
|
|
4312
|
+
function erfc(x) {
|
|
4313
|
+
return erfc$1(x);
|
|
4314
|
+
}
|
|
4315
|
+
/**
|
|
4316
|
+
* Stops gradient computation.
|
|
4317
|
+
*
|
|
4318
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
4319
|
+
* forward or reverse-mode automatic differentiation.
|
|
4320
|
+
*/
|
|
4321
|
+
function stopGradient$1(x) {
|
|
4322
|
+
return stopGradient(x);
|
|
4323
|
+
}
|
|
4033
4324
|
|
|
4034
4325
|
//#endregion
|
|
4035
4326
|
//#region src/numpy.ts
|
|
@@ -4092,6 +4383,9 @@ __export(numpy_exports, {
|
|
|
4092
4383
|
fullLike: () => fullLike$1,
|
|
4093
4384
|
greater: () => greater,
|
|
4094
4385
|
greaterEqual: () => greaterEqual,
|
|
4386
|
+
hamming: () => hamming,
|
|
4387
|
+
hann: () => hann,
|
|
4388
|
+
heaviside: () => heaviside,
|
|
4095
4389
|
hstack: () => hstack,
|
|
4096
4390
|
hypot: () => hypot,
|
|
4097
4391
|
identity: () => identity$1,
|
|
@@ -4313,7 +4607,7 @@ function argmin(a, axis, opts) {
|
|
|
4313
4607
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4314
4608
|
const shape$1 = a.shape;
|
|
4315
4609
|
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4316
|
-
const length =
|
|
4610
|
+
const length = array(shape$1[axis], {
|
|
4317
4611
|
dtype: int32,
|
|
4318
4612
|
device: a.device
|
|
4319
4613
|
});
|
|
@@ -4337,7 +4631,7 @@ function argmax(a, axis, opts) {
|
|
|
4337
4631
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
4338
4632
|
const shape$1 = a.shape;
|
|
4339
4633
|
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4340
|
-
const length =
|
|
4634
|
+
const length = array(shape$1[axis], {
|
|
4341
4635
|
dtype: int32,
|
|
4342
4636
|
device: a.device
|
|
4343
4637
|
});
|
|
@@ -4521,7 +4815,7 @@ function broadcastTo(a, shape$1) {
|
|
|
4521
4815
|
/** Broadcast input shapes to a common output shape. */
|
|
4522
4816
|
function broadcastShapes(...shapes) {
|
|
4523
4817
|
if (shapes.length === 0) return [];
|
|
4524
|
-
return shapes.reduce(generalBroadcast);
|
|
4818
|
+
return shapes.reduce(require_backend.generalBroadcast);
|
|
4525
4819
|
}
|
|
4526
4820
|
/** Broadcast arrays to a common shape. */
|
|
4527
4821
|
function broadcastArrays(...arrays) {
|
|
@@ -4731,6 +5025,32 @@ function sign(x) {
|
|
|
4731
5025
|
x = fudgeArray(x);
|
|
4732
5026
|
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4733
5027
|
}
|
|
5028
|
+
/**
|
|
5029
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
5030
|
+
*
|
|
5031
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
5032
|
+
*/
|
|
5033
|
+
function hamming(M) {
|
|
5034
|
+
return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
|
|
5035
|
+
}
|
|
5036
|
+
/**
|
|
5037
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
5038
|
+
*
|
|
5039
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
5040
|
+
*/
|
|
5041
|
+
function hann(M) {
|
|
5042
|
+
return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
|
|
5043
|
+
}
|
|
5044
|
+
/**
|
|
5045
|
+
* @function
|
|
5046
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
5047
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
5048
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
5049
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
5050
|
+
*/
|
|
5051
|
+
const heaviside = jit$1(function heaviside$1(x1, x2) {
|
|
5052
|
+
return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
|
|
5053
|
+
});
|
|
4734
5054
|
/** Calculate element-wise square of the input array. */
|
|
4735
5055
|
function square(x) {
|
|
4736
5056
|
x = fudgeArray(x);
|
|
@@ -4750,10 +5070,10 @@ function acos(x) {
|
|
|
4750
5070
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4751
5071
|
*
|
|
4752
5072
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4753
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
4754
|
-
* improvements.
|
|
5073
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
5074
|
+
* stability improvements.
|
|
4755
5075
|
*/
|
|
4756
|
-
const hypot = jit$1((x1, x2)
|
|
5076
|
+
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4757
5077
|
return sqrt(square(x1).add(square(x2)));
|
|
4758
5078
|
});
|
|
4759
5079
|
/**
|
|
@@ -4769,7 +5089,7 @@ const hypot = jit$1((x1, x2) => {
|
|
|
4769
5089
|
*
|
|
4770
5090
|
* The output is ill-defined when both x and y are zero.
|
|
4771
5091
|
*/
|
|
4772
|
-
const atan2 = jit$1((y, x)
|
|
5092
|
+
const atan2 = jit$1(function atan2$1(y, x) {
|
|
4773
5093
|
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4774
5094
|
const xNeg = less(x.ref, 0);
|
|
4775
5095
|
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
@@ -4837,13 +5157,13 @@ const degrees = rad2deg;
|
|
|
4837
5157
|
* @function
|
|
4838
5158
|
* Computes first array raised to power of second array, element-wise.
|
|
4839
5159
|
*/
|
|
4840
|
-
const power = jit$1((x1, x2)
|
|
5160
|
+
const power = jit$1(function power$1(x1, x2) {
|
|
4841
5161
|
return exp(log(x1).mul(x2));
|
|
4842
5162
|
});
|
|
4843
5163
|
/** @function Alias of `jax.numpy.power()`. */
|
|
4844
5164
|
const pow = power;
|
|
4845
5165
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
4846
|
-
const cbrt = jit$1((x)
|
|
5166
|
+
const cbrt = jit$1(function cbrt$1(x) {
|
|
4847
5167
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4848
5168
|
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4849
5169
|
});
|
|
@@ -4853,7 +5173,7 @@ const cbrt = jit$1((x) => {
|
|
|
4853
5173
|
*
|
|
4854
5174
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4855
5175
|
*/
|
|
4856
|
-
const sinh = jit$1((x)
|
|
5176
|
+
const sinh = jit$1(function sinh$1(x) {
|
|
4857
5177
|
const ex = exp(x);
|
|
4858
5178
|
const emx = reciprocal(ex.ref);
|
|
4859
5179
|
return ex.sub(emx).mul(.5);
|
|
@@ -4864,7 +5184,7 @@ const sinh = jit$1((x) => {
|
|
|
4864
5184
|
*
|
|
4865
5185
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4866
5186
|
*/
|
|
4867
|
-
const cosh = jit$1((x)
|
|
5187
|
+
const cosh = jit$1(function cosh$1(x) {
|
|
4868
5188
|
const ex = exp(x);
|
|
4869
5189
|
const emx = reciprocal(ex.ref);
|
|
4870
5190
|
return ex.add(emx).mul(.5);
|
|
@@ -4875,7 +5195,7 @@ const cosh = jit$1((x) => {
|
|
|
4875
5195
|
*
|
|
4876
5196
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4877
5197
|
*/
|
|
4878
|
-
const tanh = jit$1((x)
|
|
5198
|
+
const tanh = jit$1(function tanh$1(x) {
|
|
4879
5199
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4880
5200
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4881
5201
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
@@ -4886,7 +5206,7 @@ const tanh = jit$1((x) => {
|
|
|
4886
5206
|
*
|
|
4887
5207
|
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4888
5208
|
*/
|
|
4889
|
-
const arcsinh = jit$1((x)
|
|
5209
|
+
const arcsinh = jit$1(function arcsinh$1(x) {
|
|
4890
5210
|
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4891
5211
|
});
|
|
4892
5212
|
/**
|
|
@@ -4895,7 +5215,7 @@ const arcsinh = jit$1((x) => {
|
|
|
4895
5215
|
*
|
|
4896
5216
|
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4897
5217
|
*/
|
|
4898
|
-
const arccosh = jit$1((x)
|
|
5218
|
+
const arccosh = jit$1(function arccosh$1(x) {
|
|
4899
5219
|
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4900
5220
|
});
|
|
4901
5221
|
/**
|
|
@@ -4904,7 +5224,7 @@ const arccosh = jit$1((x) => {
|
|
|
4904
5224
|
*
|
|
4905
5225
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4906
5226
|
*/
|
|
4907
|
-
const arctanh = jit$1((x)
|
|
5227
|
+
const arctanh = jit$1(function arctanh$1(x) {
|
|
4908
5228
|
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4909
5229
|
});
|
|
4910
5230
|
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
@@ -5020,7 +5340,9 @@ function softSign(x) {
|
|
|
5020
5340
|
*
|
|
5021
5341
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
5022
5342
|
*/
|
|
5023
|
-
const silu = jit$1((x)
|
|
5343
|
+
const silu = jit$1(function silu$1(x) {
|
|
5344
|
+
return x.ref.mul(sigmoid(x));
|
|
5345
|
+
});
|
|
5024
5346
|
/**
|
|
5025
5347
|
* @function
|
|
5026
5348
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -5073,18 +5395,20 @@ function celu(x, alpha = 1) {
|
|
|
5073
5395
|
* @function
|
|
5074
5396
|
* Gaussion error linear unit (GELU) activation function.
|
|
5075
5397
|
*
|
|
5076
|
-
* This is computed element-wise.
|
|
5077
|
-
*
|
|
5078
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5398
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
5399
|
+
* `approximate` is set (default true):
|
|
5079
5400
|
*
|
|
5080
|
-
*
|
|
5401
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
5402
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
5081
5403
|
*
|
|
5082
|
-
*
|
|
5404
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5083
5405
|
*/
|
|
5084
|
-
const gelu = jit$1((x)
|
|
5085
|
-
|
|
5086
|
-
|
|
5087
|
-
|
|
5406
|
+
const gelu = jit$1(function gelu$1(x, opts) {
|
|
5407
|
+
if (opts?.approximate ?? true) {
|
|
5408
|
+
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
5409
|
+
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
5410
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
|
|
5411
|
+
}, { staticArgnums: [1] });
|
|
5088
5412
|
/**
|
|
5089
5413
|
* Gated linear unit (GLU) activation function.
|
|
5090
5414
|
*
|
|
@@ -5252,8 +5576,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
5252
5576
|
const keyShape = validateKeyShape(key$1);
|
|
5253
5577
|
return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
|
|
5254
5578
|
}
|
|
5255
|
-
/**
|
|
5256
|
-
function
|
|
5579
|
+
/**
|
|
5580
|
+
* @function
|
|
5581
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
5582
|
+
*/
|
|
5583
|
+
const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
5257
5584
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
5258
5585
|
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
5259
5586
|
dtype: require_backend.DType.Uint32,
|
|
@@ -5266,7 +5593,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
5266
5593
|
const rand = bitcast(float12, require_backend.DType.Float32).sub(1);
|
|
5267
5594
|
if (minval === 0 && maxval === 1) return rand;
|
|
5268
5595
|
else return rand.mul(maxval - minval).add(minval);
|
|
5269
|
-
}
|
|
5596
|
+
}, { staticArgnums: [1, 2] });
|
|
5270
5597
|
/**
|
|
5271
5598
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5272
5599
|
*
|
|
@@ -5277,26 +5604,49 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
5277
5604
|
p = fudgeArray(p);
|
|
5278
5605
|
return uniform(key$1, shape$1).less(p);
|
|
5279
5606
|
}
|
|
5280
|
-
/**
|
|
5281
|
-
function
|
|
5607
|
+
/**
|
|
5608
|
+
* @function
|
|
5609
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
5610
|
+
*/
|
|
5611
|
+
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
5282
5612
|
const u = uniform(key$1, shape$1);
|
|
5283
5613
|
return negative(log1p(negative(u)));
|
|
5284
|
-
}
|
|
5614
|
+
}, { staticArgnums: [1] });
|
|
5285
5615
|
/**
|
|
5616
|
+
* @function
|
|
5286
5617
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5287
5618
|
*
|
|
5288
5619
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5289
5620
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5290
5621
|
* bitwise identical to JAX.
|
|
5291
5622
|
*/
|
|
5292
|
-
function normal(key$1, shape$1 = []) {
|
|
5623
|
+
const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
5293
5624
|
const [k1, k2] = split(key$1, 2);
|
|
5294
5625
|
const u1 = uniform(k1, shape$1);
|
|
5295
5626
|
const u2 = uniform(k2, shape$1);
|
|
5296
5627
|
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5297
5628
|
const theta = u2.mul(2 * Math.PI);
|
|
5298
5629
|
return radius.mul(cos(theta));
|
|
5299
|
-
}
|
|
5630
|
+
}, { staticArgnums: [1] });
|
|
5631
|
+
|
|
5632
|
+
//#endregion
|
|
5633
|
+
//#region src/scipy-special.ts
|
|
5634
|
+
var scipy_special_exports = {};
|
|
5635
|
+
__export(scipy_special_exports, {
|
|
5636
|
+
erf: () => erf,
|
|
5637
|
+
erfc: () => erfc,
|
|
5638
|
+
logSoftmax: () => logSoftmax,
|
|
5639
|
+
logit: () => logit,
|
|
5640
|
+
logsumexp: () => logsumexp,
|
|
5641
|
+
softmax: () => softmax
|
|
5642
|
+
});
|
|
5643
|
+
/**
|
|
5644
|
+
* @function
|
|
5645
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
5646
|
+
*/
|
|
5647
|
+
const logit = jit$1(function logit$1(x) {
|
|
5648
|
+
return log(x.ref.div(subtract(1, x)));
|
|
5649
|
+
});
|
|
5300
5650
|
|
|
5301
5651
|
//#endregion
|
|
5302
5652
|
//#region src/polyfills.ts
|
|
@@ -5391,6 +5741,24 @@ async function blockUntilReady(x) {
|
|
|
5391
5741
|
await Promise.all(promises);
|
|
5392
5742
|
return x;
|
|
5393
5743
|
}
|
|
5744
|
+
/**
|
|
5745
|
+
* Transfer `x` to `device`.
|
|
5746
|
+
*
|
|
5747
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
5748
|
+
* is committed to the device.
|
|
5749
|
+
*
|
|
5750
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
5751
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
5752
|
+
* default device.
|
|
5753
|
+
*/
|
|
5754
|
+
async function devicePut(x, device) {
|
|
5755
|
+
const [xflat, structure$1] = flatten(x);
|
|
5756
|
+
const yflat = await Promise.all(xflat.map((leaf) => {
|
|
5757
|
+
if (leaf instanceof Array$1) return device ? leaf._put(require_backend.getBackend(device)) : Promise.resolve(leaf);
|
|
5758
|
+
else return Promise.resolve(array(leaf, { device }));
|
|
5759
|
+
}));
|
|
5760
|
+
return unflatten(structure$1, yflat);
|
|
5761
|
+
}
|
|
5394
5762
|
|
|
5395
5763
|
//#endregion
|
|
5396
5764
|
exports.Array = Array$1;
|
|
@@ -5398,6 +5766,7 @@ exports.DType = require_backend.DType;
|
|
|
5398
5766
|
exports.Jaxpr = Jaxpr;
|
|
5399
5767
|
exports.blockUntilReady = blockUntilReady;
|
|
5400
5768
|
exports.defaultDevice = require_backend.defaultDevice;
|
|
5769
|
+
exports.devicePut = devicePut;
|
|
5401
5770
|
exports.devices = require_backend.devices;
|
|
5402
5771
|
exports.grad = grad;
|
|
5403
5772
|
exports.init = require_backend.init;
|
|
@@ -5432,6 +5801,12 @@ Object.defineProperty(exports, 'random', {
|
|
|
5432
5801
|
return random_exports;
|
|
5433
5802
|
}
|
|
5434
5803
|
});
|
|
5804
|
+
Object.defineProperty(exports, 'scipySpecial', {
|
|
5805
|
+
enumerable: true,
|
|
5806
|
+
get: function () {
|
|
5807
|
+
return scipy_special_exports;
|
|
5808
|
+
}
|
|
5809
|
+
});
|
|
5435
5810
|
exports.setDebug = require_backend.setDebug;
|
|
5436
5811
|
Object.defineProperty(exports, 'tree', {
|
|
5437
5812
|
enumerable: true,
|
|
@@ -5441,4 +5816,5 @@ Object.defineProperty(exports, 'tree', {
|
|
|
5441
5816
|
});
|
|
5442
5817
|
exports.valueAndGrad = valueAndGrad;
|
|
5443
5818
|
exports.vjp = vjp;
|
|
5444
|
-
exports.vmap = vmap;
|
|
5819
|
+
exports.vmap = vmap;
|
|
5820
|
+
//# sourceMappingURL=index.cjs.map
|