@jax-js/jax 0.0.4 → 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +296 -78
- package/dist/{backend-EBRGmEYw.js → backend-DwIAd0AG.js} +238 -116
- package/dist/{backend-Ss1Mev_-.cjs → backend-FtkbO6pI.cjs} +256 -122
- package/dist/index.cjs +653 -277
- package/dist/index.d.cts +167 -44
- package/dist/index.d.ts +167 -44
- package/dist/index.js +637 -268
- package/dist/{webgpu-BVdMaO9T.cjs → webgpu-BE7zA_01.cjs} +181 -151
- package/dist/{webgpu-ow0Pn_6q.js → webgpu-LGi2A3mS.js} +181 -151
- package/package.json +7 -5
package/dist/index.js
CHANGED
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-DwIAd0AG.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
@@ -29,6 +29,10 @@ var JsTreeDef = class JsTreeDef {
|
|
|
29
29
|
this.nodeMetadata = nodeMetadata;
|
|
30
30
|
this.childTreedefs = childTreedefs;
|
|
31
31
|
}
|
|
32
|
+
/** Get the total number of leaves in the tree. */
|
|
33
|
+
get size() {
|
|
34
|
+
return this.nodeType === NodeType.Leaf ? 1 : this.childTreedefs.reduce((a, b) => a + b.size, 0);
|
|
35
|
+
}
|
|
32
36
|
/** Returns a string representation of this tree definition. */
|
|
33
37
|
toString(root = true) {
|
|
34
38
|
if (root) return "JsTreeDef(" + this.toString(false) + ")";
|
|
@@ -184,6 +188,16 @@ function pool(st, ks, strides = 1, dilation = 1) {
|
|
|
184
188
|
const s_ = strides;
|
|
185
189
|
const d_ = dilation;
|
|
186
190
|
const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
191
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
192
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, o * s - i])]);
|
|
193
|
+
st = st.reshape([...noop, ...zip(o_, s_).flatMap(([o, s]) => [o, s])]).shrink([...noop.map((x) => [0, x]), ...zip(o_, ks).flatMap(([o, k]) => [[0, o], [0, k]])]);
|
|
194
|
+
st = st.permute([
|
|
195
|
+
...range(noop.length),
|
|
196
|
+
...ks.map((_, j) => noop.length + 2 * j),
|
|
197
|
+
...ks.map((_, j) => noop.length + 2 * j + 1)
|
|
198
|
+
]);
|
|
199
|
+
return st;
|
|
200
|
+
}
|
|
187
201
|
const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
188
202
|
const kidf = zipn(ks, i_, d_, f_);
|
|
189
203
|
st = st.repeat([...rep(noop.length, 1), ...kidf.map(([k, i, d, f]) => Math.ceil(k * (i * f + d) / i))]);
|
|
@@ -218,6 +232,12 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
218
232
|
const s_ = strides;
|
|
219
233
|
const d_ = dilation;
|
|
220
234
|
const o_ = zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
|
|
235
|
+
if (d_.every((d) => d === 1) && ks.every((k, j) => k <= s_[j])) {
|
|
236
|
+
st = st.permute([...range(noop.length), ...ks.flatMap((_, j) => [noop.length + j, noop.length + o_.length + j])]);
|
|
237
|
+
st = st.pad([...noop.map(() => [0, 0]), ...zip(s_, ks).flatMap(([s, k]) => [[0, 0], [0, s - k]])]).reshape([...noop, ...zip(o_, s_).map(([o, s]) => o * s)]);
|
|
238
|
+
st = st.padOrShrink([...noop.map(() => [0, 0]), ...zipn(i_, o_, s_).map(([i, o, s]) => [0, i - o * s])]);
|
|
239
|
+
return st.reshape(st.shape.concat(rep(ks.length, 1)));
|
|
240
|
+
}
|
|
221
241
|
if (!deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
|
|
222
242
|
const f_ = zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
|
|
223
243
|
const kidf = zipn(ks, i_, d_, f_);
|
|
@@ -327,6 +347,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
327
347
|
Primitive$1["Atan"] = "atan";
|
|
328
348
|
Primitive$1["Exp"] = "exp";
|
|
329
349
|
Primitive$1["Log"] = "log";
|
|
350
|
+
Primitive$1["Erf"] = "erf";
|
|
351
|
+
Primitive$1["Erfc"] = "erfc";
|
|
330
352
|
Primitive$1["Sqrt"] = "sqrt";
|
|
331
353
|
Primitive$1["Min"] = "min";
|
|
332
354
|
Primitive$1["Max"] = "max";
|
|
@@ -404,6 +426,12 @@ function exp$1(x) {
|
|
|
404
426
|
function log$1(x) {
|
|
405
427
|
return bind1(Primitive.Log, [x]);
|
|
406
428
|
}
|
|
429
|
+
function erf$1(x) {
|
|
430
|
+
return bind1(Primitive.Erf, [x]);
|
|
431
|
+
}
|
|
432
|
+
function erfc$1(x) {
|
|
433
|
+
return bind1(Primitive.Erfc, [x]);
|
|
434
|
+
}
|
|
407
435
|
function sqrt$1(x) {
|
|
408
436
|
return bind1(Primitive.Sqrt, [x]);
|
|
409
437
|
}
|
|
@@ -565,6 +593,21 @@ var Trace = class {
|
|
|
565
593
|
this.main = main;
|
|
566
594
|
}
|
|
567
595
|
};
|
|
596
|
+
/**
|
|
597
|
+
* Broadcast shapes and promote types with casting for two avals.
|
|
598
|
+
*
|
|
599
|
+
* This implements the weak type behavior described in `promoteTypes()`, but not
|
|
600
|
+
* implemented in that function as `weakType` is not passed.
|
|
601
|
+
*/
|
|
602
|
+
function promoteAvals(a, b) {
|
|
603
|
+
const shape$1 = generalBroadcast(a.shape, b.shape);
|
|
604
|
+
const weakType = a.weakType && b.weakType;
|
|
605
|
+
let dtype;
|
|
606
|
+
if (a.weakType === b.weakType) dtype = promoteTypes(a.dtype, b.dtype);
|
|
607
|
+
else if (a.weakType) dtype = promoteTypes(b.dtype, DType.Uint32);
|
|
608
|
+
else dtype = promoteTypes(a.dtype, DType.Uint32);
|
|
609
|
+
return new ShapedArray(shape$1, dtype, weakType);
|
|
610
|
+
}
|
|
568
611
|
var Tracer = class Tracer {
|
|
569
612
|
/** @ignore */
|
|
570
613
|
_trace;
|
|
@@ -579,10 +622,19 @@ var Tracer = class Tracer {
|
|
|
579
622
|
get size() {
|
|
580
623
|
return prod(this.shape);
|
|
581
624
|
}
|
|
582
|
-
/** The dtype of the array. */
|
|
625
|
+
/** The dtype of elements stored in the array. */
|
|
583
626
|
get dtype() {
|
|
584
627
|
return this.aval.dtype;
|
|
585
628
|
}
|
|
629
|
+
/**
|
|
630
|
+
* Whether the array is weakly typed.
|
|
631
|
+
*
|
|
632
|
+
* Weakly typed arrays will cast to the dtype of the other operand. See
|
|
633
|
+
* `promoteTypes()` for details.
|
|
634
|
+
*/
|
|
635
|
+
get weakType() {
|
|
636
|
+
return this.aval.weakType;
|
|
637
|
+
}
|
|
586
638
|
/** The number of dimensions of the array. */
|
|
587
639
|
get ndim() {
|
|
588
640
|
return this.shape.length;
|
|
@@ -819,12 +871,13 @@ function getShape(x) {
|
|
|
819
871
|
return x instanceof Tracer ? x.shape : [];
|
|
820
872
|
}
|
|
821
873
|
var ShapedArray = class ShapedArray {
|
|
822
|
-
constructor(shape$1, dtype) {
|
|
874
|
+
constructor(shape$1, dtype, weakType) {
|
|
823
875
|
this.shape = shape$1;
|
|
824
876
|
this.dtype = dtype;
|
|
877
|
+
this.weakType = weakType;
|
|
825
878
|
}
|
|
826
879
|
static fromAval(aval) {
|
|
827
|
-
return new ShapedArray(aval.shape, aval.dtype);
|
|
880
|
+
return new ShapedArray(aval.shape, aval.dtype, aval.weakType);
|
|
828
881
|
}
|
|
829
882
|
get ndim() {
|
|
830
883
|
return this.shape.length;
|
|
@@ -838,7 +891,7 @@ var ShapedArray = class ShapedArray {
|
|
|
838
891
|
};
|
|
839
892
|
function getAval(x) {
|
|
840
893
|
if (x instanceof Tracer) return x.aval;
|
|
841
|
-
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32);
|
|
894
|
+
else if (typeof x === "boolean" || typeof x === "number") return new ShapedArray([], typeof x === "boolean" ? DType.Bool : DType.Float32, typeof x === "boolean" ? false : true);
|
|
842
895
|
else throw new TypeError(`Unknown value: ${x}`);
|
|
843
896
|
}
|
|
844
897
|
function bind(prim, args, params = {}) {
|
|
@@ -1121,12 +1174,18 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1121
1174
|
} else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
1122
1175
|
});
|
|
1123
1176
|
}
|
|
1124
|
-
function broadcastedJit(fn) {
|
|
1177
|
+
function broadcastedJit(fn, opts) {
|
|
1125
1178
|
return (nargs, exps, avals, params) => {
|
|
1126
|
-
|
|
1127
|
-
|
|
1128
|
-
|
|
1129
|
-
|
|
1179
|
+
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1180
|
+
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1181
|
+
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1182
|
+
exps = exps.map((exp$3, i) => {
|
|
1183
|
+
exp$3 = reshapeViews(exp$3, (st) => {
|
|
1184
|
+
if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
|
|
1185
|
+
});
|
|
1186
|
+
if (exp$3.dtype !== newDtype && !skipCastIdx.includes(i)) exp$3 = AluExp.cast(newDtype, exp$3);
|
|
1187
|
+
return exp$3;
|
|
1188
|
+
});
|
|
1130
1189
|
const exp$2 = fn(exps, params);
|
|
1131
1190
|
return new Kernel(nargs, prod(newShape), exp$2);
|
|
1132
1191
|
};
|
|
@@ -1160,7 +1219,7 @@ const jitRules = {
|
|
|
1160
1219
|
const k1 = reshapeViews(keys[1], mapping);
|
|
1161
1220
|
const c0 = AluExp.u32(0);
|
|
1162
1221
|
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1163
|
-
const exp$2 = AluExp.threefry2x32(
|
|
1222
|
+
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1164
1223
|
return new Kernel(nargs, prod(shape$1), exp$2);
|
|
1165
1224
|
},
|
|
1166
1225
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
@@ -1169,6 +1228,8 @@ const jitRules = {
|
|
|
1169
1228
|
[Primitive.Atan]: unopJit(AluExp.atan),
|
|
1170
1229
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
1171
1230
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1231
|
+
[Primitive.Erf]: unopJit(AluExp.erf),
|
|
1232
|
+
[Primitive.Erfc]: unopJit(AluExp.erfc),
|
|
1172
1233
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1173
1234
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
1174
1235
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
@@ -1201,7 +1262,7 @@ const jitRules = {
|
|
|
1201
1262
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
1202
1263
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
1203
1264
|
const c = k1.exp;
|
|
1204
|
-
const cs =
|
|
1265
|
+
const cs = promoteAvals(as, bs);
|
|
1205
1266
|
return jitRules[Primitive.Reduce](nargs, [c], [cs], {
|
|
1206
1267
|
op: AluOp.Add,
|
|
1207
1268
|
axis: [cs.ndim - 1]
|
|
@@ -1211,12 +1272,12 @@ const jitRules = {
|
|
|
1211
1272
|
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1212
1273
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1213
1274
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1214
|
-
as = new ShapedArray(stX.shape, as.dtype);
|
|
1215
|
-
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1275
|
+
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1276
|
+
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1216
1277
|
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1217
1278
|
},
|
|
1218
1279
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1219
|
-
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
|
|
1280
|
+
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1220
1281
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
1221
1282
|
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
1222
1283
|
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
@@ -1265,9 +1326,10 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1265
1326
|
Primitive.Conv,
|
|
1266
1327
|
Primitive.PoolTranspose
|
|
1267
1328
|
];
|
|
1329
|
+
const heterogeneousViewPrimitives = [Primitive.Gather, Primitive.RandomBits];
|
|
1268
1330
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1269
1331
|
const eqn = jaxpr.eqns[i];
|
|
1270
|
-
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive
|
|
1332
|
+
if (reducePrimitives.includes(eqn.primitive) || heterogeneousViewPrimitives.includes(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1271
1333
|
for (const v of eqn.outBinders) {
|
|
1272
1334
|
blackNodes.add(v);
|
|
1273
1335
|
p1NextBlack.set(v, v);
|
|
@@ -1386,7 +1448,7 @@ var PendingExecute = class {
|
|
|
1386
1448
|
/**
|
|
1387
1449
|
* A multidimensional numeric array with data stored on CPU or GPU.
|
|
1388
1450
|
*
|
|
1389
|
-
* This is the library's core data type. Equivalent to `
|
|
1451
|
+
* This is the library's core data type. Equivalent to `jax.Array` from JAX, or
|
|
1390
1452
|
* `torch.Tensor`.
|
|
1391
1453
|
*
|
|
1392
1454
|
* Not to be confused with the JavaScript "Array" constructor. Avoid importing
|
|
@@ -1397,9 +1459,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1397
1459
|
static #nextId = 1001;
|
|
1398
1460
|
id;
|
|
1399
1461
|
#dtype;
|
|
1462
|
+
#weakType;
|
|
1400
1463
|
#source;
|
|
1401
1464
|
#st;
|
|
1402
1465
|
#backend;
|
|
1466
|
+
#committed;
|
|
1403
1467
|
#rc;
|
|
1404
1468
|
#pendingSet;
|
|
1405
1469
|
/**
|
|
@@ -1408,21 +1472,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1408
1472
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1409
1473
|
* will be freed when the array is disposed.
|
|
1410
1474
|
*/
|
|
1411
|
-
constructor(
|
|
1475
|
+
constructor(args) {
|
|
1412
1476
|
super(baseArrayTrace);
|
|
1413
1477
|
this.id = Array$1.#nextId++;
|
|
1414
|
-
this.#dtype = dtype;
|
|
1415
|
-
this.#
|
|
1416
|
-
this.#
|
|
1417
|
-
this.#
|
|
1478
|
+
this.#dtype = args.dtype;
|
|
1479
|
+
this.#weakType = args.weakType;
|
|
1480
|
+
this.#source = args.source;
|
|
1481
|
+
this.#st = args.st;
|
|
1482
|
+
this.#backend = args.backend;
|
|
1483
|
+
this.#committed = args.committed;
|
|
1418
1484
|
this.#rc = 1;
|
|
1419
|
-
this.#pendingSet = new Set(pending);
|
|
1485
|
+
this.#pendingSet = new Set(args.pending);
|
|
1420
1486
|
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1421
|
-
else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1487
|
+
else if (this.#source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1422
1488
|
}
|
|
1423
1489
|
/** @ignore */
|
|
1424
1490
|
get aval() {
|
|
1425
|
-
return new ShapedArray(this.#st.shape, this.#dtype);
|
|
1491
|
+
return new ShapedArray(this.#st.shape, this.#dtype, this.#weakType);
|
|
1426
1492
|
}
|
|
1427
1493
|
/** Return a simple string representation of the array's dimensions. */
|
|
1428
1494
|
toString() {
|
|
@@ -1434,6 +1500,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1434
1500
|
#check() {
|
|
1435
1501
|
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1436
1502
|
}
|
|
1503
|
+
/** Construct an array, copying fields from `this`. */
|
|
1504
|
+
#newArrayFrom(args) {
|
|
1505
|
+
return new Array$1({
|
|
1506
|
+
source: args.source ?? this.#source,
|
|
1507
|
+
st: args.st ?? this.#st,
|
|
1508
|
+
dtype: args.dtype ?? this.#dtype,
|
|
1509
|
+
weakType: this.#weakType,
|
|
1510
|
+
backend: args.backend ?? this.#backend,
|
|
1511
|
+
committed: args.committed ?? this.#committed,
|
|
1512
|
+
pending: args.pending ?? this.#pending ?? void 0
|
|
1513
|
+
});
|
|
1514
|
+
}
|
|
1437
1515
|
get ref() {
|
|
1438
1516
|
this.#check();
|
|
1439
1517
|
this.#rc++;
|
|
@@ -1473,7 +1551,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1473
1551
|
const pending = this.#pending;
|
|
1474
1552
|
for (const exe of pending) exe.updateRc(1);
|
|
1475
1553
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1476
|
-
const ar =
|
|
1554
|
+
const ar = this.#newArrayFrom({
|
|
1555
|
+
st,
|
|
1556
|
+
pending
|
|
1557
|
+
});
|
|
1477
1558
|
this.dispose();
|
|
1478
1559
|
return ar;
|
|
1479
1560
|
}
|
|
@@ -1483,9 +1564,10 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1483
1564
|
*/
|
|
1484
1565
|
#gather(indices, axis, outDim) {
|
|
1485
1566
|
this.#check();
|
|
1486
|
-
if (indices.some((a) => a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1487
1567
|
const axisSet = new Set(axis);
|
|
1488
1568
|
if (axisSet.size !== axis.length) throw new TypeError("Gather axis must not have duplicates");
|
|
1569
|
+
if (indices.some((a) => a.#committed && a.#backend !== this.#backend)) throw new TypeError(`Gather indices must have the same backend: ${this.#backend.type}`);
|
|
1570
|
+
indices = indices.map((ar) => ar._putSync(this.#backend));
|
|
1489
1571
|
indices = Array$1.#broadcastArrays(indices);
|
|
1490
1572
|
const indexShape = indices[0].shape;
|
|
1491
1573
|
const finalShape = this.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1522,7 +1604,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1522
1604
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1523
1605
|
this.dispose();
|
|
1524
1606
|
for (const ar of indices) ar.dispose();
|
|
1525
|
-
return
|
|
1607
|
+
return this.#newArrayFrom({
|
|
1608
|
+
source: output,
|
|
1609
|
+
st: ShapeTracker.fromShape(finalShape),
|
|
1610
|
+
pending
|
|
1611
|
+
});
|
|
1526
1612
|
}
|
|
1527
1613
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1528
1614
|
#moveAxesDown(axis) {
|
|
@@ -1545,11 +1631,17 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1545
1631
|
return this.#reshape(this.#st.permute(perm));
|
|
1546
1632
|
}
|
|
1547
1633
|
#unary(op, dtypeOutput) {
|
|
1634
|
+
const weakType = !dtypeOutput && this.#weakType;
|
|
1548
1635
|
dtypeOutput ??= this.#dtype;
|
|
1549
1636
|
this.#check();
|
|
1550
1637
|
if (this.#source instanceof AluExp) {
|
|
1551
1638
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1552
|
-
|
|
1639
|
+
this.dispose();
|
|
1640
|
+
return this.#newArrayFrom({
|
|
1641
|
+
source: exp$3.simplify(),
|
|
1642
|
+
dtype: dtypeOutput,
|
|
1643
|
+
weakType
|
|
1644
|
+
});
|
|
1553
1645
|
}
|
|
1554
1646
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1555
1647
|
const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1559,41 +1651,67 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1559
1651
|
for (const exe of pending) exe.updateRc(1);
|
|
1560
1652
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1561
1653
|
this.dispose();
|
|
1562
|
-
return
|
|
1654
|
+
return this.#newArrayFrom({
|
|
1655
|
+
source: output,
|
|
1656
|
+
st: ShapeTracker.fromShape(this.shape),
|
|
1657
|
+
dtype: dtypeOutput,
|
|
1658
|
+
weakType,
|
|
1659
|
+
pending
|
|
1660
|
+
});
|
|
1563
1661
|
}
|
|
1564
1662
|
#binary(op, other) {
|
|
1565
|
-
const custom = (src) => new AluExp(op,
|
|
1663
|
+
const custom = (src) => new AluExp(op, src[0].dtype, src);
|
|
1566
1664
|
return Array$1.#naryCustom(op, custom, [this, other]);
|
|
1567
1665
|
}
|
|
1568
|
-
static #naryCustom(name, custom, arrays, { dtypeOverride,
|
|
1666
|
+
static #naryCustom(name, custom, arrays, { dtypeOverride, strongTypeOutput, reduceAxis } = {}) {
|
|
1569
1667
|
const n = arrays.length;
|
|
1570
|
-
const backend = arrays[0].#backend;
|
|
1571
1668
|
if (n === 0) throw new TypeError(`No inputs for ${name}`);
|
|
1572
1669
|
for (const ar of arrays) ar.#check();
|
|
1573
|
-
let
|
|
1574
|
-
|
|
1575
|
-
|
|
1576
|
-
|
|
1577
|
-
|
|
1578
|
-
|
|
1579
|
-
|
|
1580
|
-
}
|
|
1581
|
-
|
|
1582
|
-
|
|
1670
|
+
let castDtype;
|
|
1671
|
+
let castWeakType = true;
|
|
1672
|
+
for (let i = 0; i < n; i++) if (dtypeOverride?.[i]) {
|
|
1673
|
+
if (arrays[i].#dtype !== dtypeOverride[i]) throw new TypeError(`Wrong dtype in ${name}: expected ${dtypeOverride[i]}, got ${arrays[i].#dtype}`);
|
|
1674
|
+
} else if (castDtype === void 0) {
|
|
1675
|
+
castDtype = arrays[i].#dtype;
|
|
1676
|
+
castWeakType = arrays[i].#weakType;
|
|
1677
|
+
} else ({dtype: castDtype, weakType: castWeakType} = promoteAvals(new ShapedArray([], castDtype, castWeakType), new ShapedArray([], arrays[i].#dtype, arrays[i].#weakType)));
|
|
1678
|
+
const weakType = castWeakType && !strongTypeOutput;
|
|
1679
|
+
const { backend, committed } = Array$1.#computeBackend(name, arrays);
|
|
1680
|
+
arrays = arrays.map((ar) => ar._putSync(backend));
|
|
1583
1681
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1584
1682
|
const newShape = [...arrays[0].shape];
|
|
1585
1683
|
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
1684
|
+
const sources = arrays.map((ar, i) => {
|
|
1685
|
+
if (!dtypeOverride?.[i]) return AluExp.cast(castDtype, ar.#source);
|
|
1686
|
+
else return ar.#source;
|
|
1687
|
+
});
|
|
1586
1688
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1587
|
-
const exp$4 = custom(
|
|
1588
|
-
|
|
1689
|
+
const exp$4 = custom(sources);
|
|
1690
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1691
|
+
return new Array$1({
|
|
1692
|
+
source: exp$4.simplify(),
|
|
1693
|
+
st: arrays[0].#st,
|
|
1694
|
+
dtype: exp$4.dtype,
|
|
1695
|
+
weakType,
|
|
1696
|
+
backend,
|
|
1697
|
+
committed
|
|
1698
|
+
});
|
|
1589
1699
|
}
|
|
1590
|
-
const exp$3 = custom(arrays.map((ar) => {
|
|
1591
|
-
const src$1 =
|
|
1700
|
+
const exp$3 = custom(arrays.map((ar, i) => {
|
|
1701
|
+
const src$1 = sources[i];
|
|
1592
1702
|
if (ar.#st.contiguous) return src$1;
|
|
1593
1703
|
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1594
1704
|
}));
|
|
1595
1705
|
const st = ShapeTracker.fromShape(newShape);
|
|
1596
|
-
|
|
1706
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1707
|
+
return new Array$1({
|
|
1708
|
+
source: exp$3.simplify(),
|
|
1709
|
+
st,
|
|
1710
|
+
dtype: exp$3.dtype,
|
|
1711
|
+
weakType,
|
|
1712
|
+
backend,
|
|
1713
|
+
committed
|
|
1714
|
+
});
|
|
1597
1715
|
}
|
|
1598
1716
|
let indices;
|
|
1599
1717
|
if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
|
|
@@ -1603,14 +1721,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1603
1721
|
}
|
|
1604
1722
|
const inputs = [];
|
|
1605
1723
|
const src = [];
|
|
1606
|
-
for (const ar of arrays
|
|
1607
|
-
|
|
1608
|
-
|
|
1609
|
-
|
|
1610
|
-
gid = inputs.
|
|
1611
|
-
|
|
1724
|
+
for (const [i, ar] of arrays.entries()) {
|
|
1725
|
+
let nextSrc;
|
|
1726
|
+
if (ar.#source instanceof AluExp) nextSrc = accessorAluExp(ar.#source, ar.#st, indices);
|
|
1727
|
+
else {
|
|
1728
|
+
let gid = inputs.indexOf(ar.#source);
|
|
1729
|
+
if (gid === -1) {
|
|
1730
|
+
gid = inputs.length;
|
|
1731
|
+
inputs.push(ar.#source);
|
|
1732
|
+
}
|
|
1733
|
+
nextSrc = AluExp.globalView(ar.#dtype, gid, ar.#st, indices);
|
|
1612
1734
|
}
|
|
1613
|
-
|
|
1735
|
+
if (!dtypeOverride?.[i]) nextSrc = AluExp.cast(castDtype, nextSrc);
|
|
1736
|
+
src.push(nextSrc);
|
|
1614
1737
|
}
|
|
1615
1738
|
const exp$2 = custom(src);
|
|
1616
1739
|
let re = void 0;
|
|
@@ -1623,13 +1746,19 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1623
1746
|
const pending = new Set([...arrays.flatMap((ar) => ar.#pending)]);
|
|
1624
1747
|
for (const exe of pending) exe.updateRc(1);
|
|
1625
1748
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1626
|
-
|
|
1627
|
-
return new Array$1(
|
|
1749
|
+
arrays.forEach((ar) => ar.dispose());
|
|
1750
|
+
return new Array$1({
|
|
1751
|
+
source: output,
|
|
1752
|
+
st: ShapeTracker.fromShape(newShape),
|
|
1753
|
+
dtype: kernel.dtype,
|
|
1754
|
+
weakType,
|
|
1755
|
+
backend,
|
|
1756
|
+
committed,
|
|
1757
|
+
pending
|
|
1758
|
+
});
|
|
1628
1759
|
}
|
|
1629
1760
|
/** Reduce the last dimension of the array by an operation. */
|
|
1630
1761
|
#reduce(op) {
|
|
1631
|
-
this.#check();
|
|
1632
|
-
if (this.ndim === 0) throw new Error("Cannot reduce a scalar");
|
|
1633
1762
|
const shape$1 = this.shape;
|
|
1634
1763
|
const reduction = new Reduction(this.#dtype, op, shape$1[shape$1.length - 1]);
|
|
1635
1764
|
const newShape = shape$1.slice(0, -1);
|
|
@@ -1648,7 +1777,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1648
1777
|
for (const exe of pending) exe.updateRc(1);
|
|
1649
1778
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1650
1779
|
this.dispose();
|
|
1651
|
-
return
|
|
1780
|
+
return this.#newArrayFrom({
|
|
1781
|
+
source: output,
|
|
1782
|
+
st: ShapeTracker.fromShape(newShape),
|
|
1783
|
+
pending
|
|
1784
|
+
});
|
|
1652
1785
|
}
|
|
1653
1786
|
/**
|
|
1654
1787
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1684,8 +1817,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1684
1817
|
}
|
|
1685
1818
|
#dataInline() {
|
|
1686
1819
|
this.#check();
|
|
1687
|
-
|
|
1688
|
-
const ar =
|
|
1820
|
+
if (!(this.#source instanceof AluExp)) throw new Error("internal: #dataInline called on non-AluExp source");
|
|
1821
|
+
const ar = this.#newArrayFrom({ backend: getBackend("cpu") });
|
|
1689
1822
|
this.dispose();
|
|
1690
1823
|
return ar.dataSync();
|
|
1691
1824
|
}
|
|
@@ -1698,6 +1831,23 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1698
1831
|
return ar.#reshape(ar.#st.broadcast(newShape, range(newShape.length - ar.ndim)));
|
|
1699
1832
|
});
|
|
1700
1833
|
}
|
|
1834
|
+
static #computeBackend(name, arrays) {
|
|
1835
|
+
const committed = arrays.filter((ar) => ar.#committed);
|
|
1836
|
+
if (committed.length > 0) {
|
|
1837
|
+
const backend = committed[0].#backend;
|
|
1838
|
+
for (const ar of committed) if (ar.#backend !== backend) throw new Error(`Device mismatch in ${name} between committed arrays on (${backend.type}, ${ar.#backend.type}), please move to the same device with devicePut()`);
|
|
1839
|
+
return {
|
|
1840
|
+
backend,
|
|
1841
|
+
committed: true
|
|
1842
|
+
};
|
|
1843
|
+
} else {
|
|
1844
|
+
const backend = arrays.length > 0 ? arrays[0].#backend : getBackend();
|
|
1845
|
+
return {
|
|
1846
|
+
backend,
|
|
1847
|
+
committed: false
|
|
1848
|
+
};
|
|
1849
|
+
}
|
|
1850
|
+
}
|
|
1701
1851
|
/** Realize the array and return it as data. */
|
|
1702
1852
|
async data() {
|
|
1703
1853
|
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
@@ -1811,7 +1961,11 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1811
1961
|
x.#backend.incRef(x.#source);
|
|
1812
1962
|
const pending = x.#pending;
|
|
1813
1963
|
for (const exe of pending) exe.updateRc(1);
|
|
1814
|
-
const y =
|
|
1964
|
+
const y = x.#newArrayFrom({
|
|
1965
|
+
dtype,
|
|
1966
|
+
weakType: false,
|
|
1967
|
+
pending
|
|
1968
|
+
});
|
|
1815
1969
|
x.dispose();
|
|
1816
1970
|
return [y];
|
|
1817
1971
|
}
|
|
@@ -1853,6 +2007,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1853
2007
|
[Primitive.Log]([x]) {
|
|
1854
2008
|
return [x.#unary(AluOp.Log)];
|
|
1855
2009
|
},
|
|
2010
|
+
[Primitive.Erf]([x]) {
|
|
2011
|
+
return [x.#unary(AluOp.Erf)];
|
|
2012
|
+
},
|
|
2013
|
+
[Primitive.Erfc]([x]) {
|
|
2014
|
+
return [x.#unary(AluOp.Erfc)];
|
|
2015
|
+
},
|
|
1856
2016
|
[Primitive.Sqrt]([x]) {
|
|
1857
2017
|
return [x.#unary(AluOp.Sqrt)];
|
|
1858
2018
|
},
|
|
@@ -1886,7 +2046,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1886
2046
|
},
|
|
1887
2047
|
[Primitive.Compare]([x, y], { op }) {
|
|
1888
2048
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1889
|
-
return [Array$1.#naryCustom("compare", custom, [x, y], {
|
|
2049
|
+
return [Array$1.#naryCustom("compare", custom, [x, y], { strongTypeOutput: true })];
|
|
1890
2050
|
},
|
|
1891
2051
|
[Primitive.Where]([cond, x, y]) {
|
|
1892
2052
|
const custom = ([cond$1, x$1, y$1]) => AluExp.where(cond$1, x$1, y$1);
|
|
@@ -1921,7 +2081,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1921
2081
|
},
|
|
1922
2082
|
[Primitive.JitCall](args, { jaxpr, numConsts }) {
|
|
1923
2083
|
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit_call expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
1924
|
-
const backend =
|
|
2084
|
+
const { backend, committed } = Array$1.#computeBackend("jit_call", args);
|
|
2085
|
+
args = args.map((ar) => ar._putSync(backend));
|
|
1925
2086
|
const consts = args.slice(0, numConsts);
|
|
1926
2087
|
const tracers = args.slice(numConsts);
|
|
1927
2088
|
const jp = jitCompile(backend, jaxpr, consts);
|
|
@@ -1932,43 +2093,66 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1932
2093
|
pending.splice(0, 0, ...prevPending);
|
|
1933
2094
|
args.forEach((x) => x.dispose());
|
|
1934
2095
|
return outputs.map((source, i) => {
|
|
1935
|
-
return new Array$1(
|
|
2096
|
+
return new Array$1({
|
|
2097
|
+
source,
|
|
2098
|
+
st: ShapeTracker.fromShape(jaxpr.outs[i].aval.shape),
|
|
2099
|
+
dtype: jaxpr.outs[i].aval.dtype,
|
|
2100
|
+
weakType: jaxpr.outs[i].aval.weakType,
|
|
2101
|
+
backend,
|
|
2102
|
+
committed,
|
|
2103
|
+
pending
|
|
2104
|
+
});
|
|
1936
2105
|
});
|
|
1937
2106
|
}
|
|
1938
2107
|
};
|
|
1939
2108
|
}
|
|
2109
|
+
/** @private */
|
|
1940
2110
|
_realizeSource() {
|
|
1941
2111
|
this.#realize();
|
|
1942
2112
|
return this.#source;
|
|
1943
2113
|
}
|
|
2114
|
+
/** @private Put this array on a new backend, asynchronously. */
|
|
2115
|
+
async _put(backend) {
|
|
2116
|
+
if (this.#backend === backend) return this;
|
|
2117
|
+
if (this.#source instanceof AluExp) {
|
|
2118
|
+
const ar = this.#newArrayFrom({
|
|
2119
|
+
backend,
|
|
2120
|
+
committed: true
|
|
2121
|
+
});
|
|
2122
|
+
this.dispose();
|
|
2123
|
+
return ar;
|
|
2124
|
+
} else {
|
|
2125
|
+
const data = await this.data();
|
|
2126
|
+
return arrayFromData(data, this.shape, {
|
|
2127
|
+
dtype: this.#dtype,
|
|
2128
|
+
device: backend.type
|
|
2129
|
+
}, this.#weakType);
|
|
2130
|
+
}
|
|
2131
|
+
}
|
|
2132
|
+
/** @private Put this array on a new backend, synchronously. */
|
|
2133
|
+
_putSync(backend) {
|
|
2134
|
+
if (this.#backend === backend) return this;
|
|
2135
|
+
if (this.#source instanceof AluExp) {
|
|
2136
|
+
const ar = this.#newArrayFrom({
|
|
2137
|
+
backend,
|
|
2138
|
+
committed: true
|
|
2139
|
+
});
|
|
2140
|
+
this.dispose();
|
|
2141
|
+
return ar;
|
|
2142
|
+
} else {
|
|
2143
|
+
const data = this.dataSync();
|
|
2144
|
+
return arrayFromData(data, this.shape, {
|
|
2145
|
+
dtype: this.#dtype,
|
|
2146
|
+
device: backend.type
|
|
2147
|
+
}, this.#weakType);
|
|
2148
|
+
}
|
|
2149
|
+
}
|
|
1944
2150
|
};
|
|
1945
|
-
/** Construct an array from a single scalar constant. */
|
|
1946
|
-
function scalar(value, { dtype, device } = {}) {
|
|
1947
|
-
if (typeof value === "number") {
|
|
1948
|
-
dtype ??= DType.Float32;
|
|
1949
|
-
if (![
|
|
1950
|
-
DType.Float32,
|
|
1951
|
-
DType.Float16,
|
|
1952
|
-
DType.Int32,
|
|
1953
|
-
DType.Uint32
|
|
1954
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1955
|
-
} else if (typeof value === "boolean") {
|
|
1956
|
-
dtype ??= DType.Bool;
|
|
1957
|
-
if (![
|
|
1958
|
-
DType.Float32,
|
|
1959
|
-
DType.Float16,
|
|
1960
|
-
DType.Int32,
|
|
1961
|
-
DType.Uint32,
|
|
1962
|
-
DType.Bool
|
|
1963
|
-
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
1964
|
-
} else throw new TypeError(`Invalid type for scalar ${value}`);
|
|
1965
|
-
return new Array$1(AluExp.const(dtype, value), ShapeTracker.fromShape([]), dtype, getBackend(device));
|
|
1966
|
-
}
|
|
1967
2151
|
/** Constructor for creating a new array from data. */
|
|
1968
2152
|
function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
1969
2153
|
if (values instanceof Tracer) {
|
|
1970
2154
|
if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1971
|
-
if (dtype && values.dtype !== dtype)
|
|
2155
|
+
if (dtype && values.dtype !== dtype) values = values.astype(dtype);
|
|
1972
2156
|
return values;
|
|
1973
2157
|
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1974
2158
|
dtype,
|
|
@@ -1990,6 +2174,10 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1990
2174
|
dtype,
|
|
1991
2175
|
device
|
|
1992
2176
|
});
|
|
2177
|
+
if (size$1 === 1) return full(shape$1, flat[0], {
|
|
2178
|
+
dtype,
|
|
2179
|
+
device
|
|
2180
|
+
});
|
|
1993
2181
|
if (typeof flat[0] === "boolean") {
|
|
1994
2182
|
dtype = dtype ?? DType.Bool;
|
|
1995
2183
|
const data = new Int32Array(flat.map((x) => x ? 1 : 0));
|
|
@@ -1998,46 +2186,52 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1998
2186
|
device
|
|
1999
2187
|
});
|
|
2000
2188
|
} else {
|
|
2189
|
+
const weakType = dtype == void 0;
|
|
2001
2190
|
dtype = dtype ?? DType.Float32;
|
|
2002
2191
|
const data = dtypedJsArray(dtype, flat);
|
|
2003
2192
|
return arrayFromData(data, shape$1, {
|
|
2004
2193
|
dtype,
|
|
2005
2194
|
device
|
|
2006
|
-
});
|
|
2195
|
+
}, weakType);
|
|
2007
2196
|
}
|
|
2008
2197
|
}
|
|
2009
2198
|
}
|
|
2010
|
-
function arrayFromData(data, shape$1, { dtype, device } =
|
|
2199
|
+
/**
 * Build an array from a typed array of raw element data.
 *
 * The typed-array class determines which dtypes are accepted and which dtype
 * is used when none is given:
 *   Float32Array -> float32, Int32Array -> int32 (bool also accepted),
 *   Uint32Array -> uint32, Float16Array -> float16.
 *
 * @param data Typed array holding the elements.
 * @param shape$1 Shape of the resulting array.
 * @param options `dtype` (optional, must agree with the class of `data`) and
 *   `device` (optional; when given, the result is marked as committed).
 * @param weakType Weak-type flag stored on the result (default `false`).
 * @returns A new `Array$1`.
 * @throws {Error} If `dtype` conflicts with the class of `data`, or `data` is
 *   not one of the supported typed-array classes.
 */
function arrayFromData(data, shape$1, { dtype, device }, weakType = false) {
  if (data instanceof Float32Array) {
    if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
    dtype ??= DType.Float32;
  } else if (data instanceof Int32Array) {
    if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
    dtype ??= DType.Int32;
  } else if (data instanceof Uint32Array) {
    if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
    dtype ??= DType.Uint32;
  } else if (typeof Float16Array !== "undefined" && data instanceof Float16Array) {
    // The `typeof` guard keeps runtimes without a Float16Array global from
    // raising a ReferenceError here; they fall through to the intended
    // "Unsupported data array type" error instead.
    if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
    dtype ??= DType.Float16;
  } else throw new Error("Unsupported data array type: " + data.constructor.name);

  // Small buffers whose elements are all identical are represented as a
  // constant fill (via fullInternal) instead of allocating backend memory.
  if (data.length < inlineArrayLimit) {
    let allEqual = true;
    for (let i = 1; i < data.length; i++) {
      if (data[i] !== data[0]) {
        allEqual = false;
        break;
      }
    }
    if (allEqual) {
      const sa = new ShapedArray(shape$1, dtype, weakType);
      return fullInternal(sa, data[0], device);
    }
  }

  const backend = getBackend(device);
  // Byte view over exactly the region covered by `data` (honours byteOffset).
  const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
  const slot = backend.malloc(data.byteLength, buf);
  return new Array$1({
    source: slot,
    st: ShapeTracker.fromShape(shape$1),
    dtype,
    weakType,
    backend,
    committed: device != void 0
  });
}
|
|
2042
2236
|
function dataToJs(dtype, data, shape$1) {
|
|
2043
2237
|
if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -2053,7 +2247,7 @@ function dataToJs(dtype, data, shape$1) {
|
|
|
2053
2247
|
/** If x is a value, lift it into an array, otherwise leave it be. */
|
|
2054
2248
|
function pureArray(x) {
|
|
2055
2249
|
if (x instanceof Tracer) return x;
|
|
2056
|
-
else return
|
|
2250
|
+
else return array(x);
|
|
2057
2251
|
}
|
|
2058
2252
|
var EvalTrace = class extends Trace {
|
|
2059
2253
|
pure = (x) => pureArray(x);
|
|
@@ -2064,20 +2258,28 @@ var EvalTrace = class extends Trace {
|
|
|
2064
2258
|
};
|
|
2065
2259
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
2066
2260
|
const implRules = Array$1._implRules();
|
|
2261
|
+
/**
 * Create an array whose source is the constant `fillValue`, taking shape,
 * dtype and weak-type flag from `aval`.
 *
 * The backend is looked up from `device`; the result is marked committed only
 * when a device was actually supplied.
 */
function fullInternal(aval, fillValue, device) {
  const { dtype, weakType, shape: dims } = aval;
  const source = AluExp.const(dtype, fillValue);
  const st = ShapeTracker.fromShape(dims);
  const backend = getBackend(device);
  const committed = device != null;
  return new Array$1({ source, st, dtype, weakType, backend, committed });
}
|
|
2067
2271
|
/** Zero-filled counterpart of `val`: delegates to `fullLike` with fill value 0. */
function zerosLike$1(val, dtype) {
  const fillValue = 0;
  return fullLike(val, fillValue, dtype);
}
|
|
2072
2274
|
/** One-filled counterpart of `val`: delegates to `fullLike` with fill value 1. */
function onesLike$1(val, dtype) {
  const fillValue = 1;
  return fullLike(val, fillValue, dtype);
}
|
|
2077
2277
|
/**
 * Return a constant-filled array with the same shape as `val`.
 *
 * @param val Value whose abstract value (shape / dtype / weak-type flag) is
 *   used as the template. If it is a Tracer it is disposed.
 * @param fillValue Scalar to fill with. Tracer fill values are not supported.
 * @param dtype Optional dtype override. When given, the result is strongly
 *   typed; when omitted, both dtype and weak-type flag are inherited from
 *   `val`. This matches `full` / `eye` / `array`, which only mark a result as
 *   weakly typed when no dtype was requested.
 * @throws {Error} If `fillValue` is a Tracer.
 */
function fullLike(val, fillValue, dtype) {
  const aval = getAval(val);
  if (val instanceof Tracer) val.dispose();
  if (fillValue instanceof Tracer) throw new Error("numpy.fullLike() with array argument not implemented yet");
  // An explicit dtype is a strong type request, so it must not stay weak.
  const weakType = dtype == null ? aval.weakType : false;
  const sa = new ShapedArray(aval.shape, dtype ?? aval.dtype, weakType);
  return fullInternal(sa, fillValue);
}
|
|
2082
2284
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
2083
2285
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -2095,19 +2297,14 @@ function ones(shape$1, { dtype, device } = {}) {
|
|
|
2095
2297
|
}
|
|
2096
2298
|
/**
 * Return a new array of given shape and type, filled with `fillValue`.
 *
 * - number fill: dtype defaults to float32; the result is weakly typed when
 *   no dtype was passed.
 * - boolean fill: dtype defaults to bool; the result is never weakly typed.
 * - Tracer fill: not implemented (throws Error).
 * - anything else: throws TypeError.
 */
function full(shape$1, fillValue, { dtype, device } = {}) {
  let weakType = dtype == void 0;
  switch (typeof fillValue) {
    case "number":
      dtype ??= DType.Float32;
      break;
    case "boolean":
      dtype ??= DType.Bool;
      weakType = false;
      break;
    default:
      if (fillValue instanceof Tracer) throw new Error("numpy.full() with array argument not implemented yet");
      throw new TypeError(`Invalid type for full: ${fillValue}`);
  }
  const aval = new ShapedArray(shape$1, dtype, weakType);
  return fullInternal(aval, fillValue, device);
}
|
|
2112
2309
|
/**
|
|
2113
2310
|
* Create an identity matrix.
|
|
@@ -2117,6 +2314,7 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
2117
2314
|
*/
|
|
2118
2315
|
function eye(numRows, numCols, { dtype, device } = {}) {
|
|
2119
2316
|
numCols = numCols ?? numRows;
|
|
2317
|
+
const weakType = dtype == void 0;
|
|
2120
2318
|
dtype = dtype ?? DType.Float32;
|
|
2121
2319
|
if (numCols < numRows) {
|
|
2122
2320
|
const arr = eye(numCols, numRows, {
|
|
@@ -2130,7 +2328,14 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
2130
2328
|
device
|
|
2131
2329
|
});
|
|
2132
2330
|
const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
|
|
2133
|
-
return new Array$1(
|
|
2331
|
+
return new Array$1({
|
|
2332
|
+
source: AluExp.cast(dtype, exp$2),
|
|
2333
|
+
st: ShapeTracker.fromShape([numRows, numCols]),
|
|
2334
|
+
dtype,
|
|
2335
|
+
weakType,
|
|
2336
|
+
backend: getBackend(device),
|
|
2337
|
+
committed: device != void 0
|
|
2338
|
+
});
|
|
2134
2339
|
}
|
|
2135
2340
|
/** Return the identity matrix, with ones on the main diagonal. */
|
|
2136
2341
|
function identity$1(n, { dtype, device } = {}) {
|
|
@@ -2167,7 +2372,14 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
2167
2372
|
});
|
|
2168
2373
|
const exp$2 = AluExp.add(AluExp.const(dtype, start), AluExp.mul(AluExp.cast(dtype, AluVar.idx), AluExp.const(dtype, step)));
|
|
2169
2374
|
const st = ShapeTracker.fromShape([size$1]);
|
|
2170
|
-
return new Array$1(
|
|
2375
|
+
return new Array$1({
|
|
2376
|
+
source: exp$2,
|
|
2377
|
+
st,
|
|
2378
|
+
dtype,
|
|
2379
|
+
weakType: false,
|
|
2380
|
+
backend: getBackend(device),
|
|
2381
|
+
committed: device != void 0
|
|
2382
|
+
});
|
|
2171
2383
|
}
|
|
2172
2384
|
/**
|
|
2173
2385
|
* Return evenly spaced numbers over a specified interval.
|
|
@@ -2185,10 +2397,10 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2185
2397
|
dtype,
|
|
2186
2398
|
device
|
|
2187
2399
|
});
|
|
2188
|
-
else if (num === 1) return
|
|
2400
|
+
else if (num === 1) return full([1], start, {
|
|
2189
2401
|
dtype,
|
|
2190
2402
|
device
|
|
2191
|
-
})
|
|
2403
|
+
});
|
|
2192
2404
|
else if (start === stop) return full([num], start, {
|
|
2193
2405
|
dtype,
|
|
2194
2406
|
device
|
|
@@ -2197,7 +2409,14 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
2197
2409
|
const denom = endpoint ? num - 1 : num;
|
|
2198
2410
|
const exp$2 = AluExp.cast(dtype, AluExp.add(AluExp.f32(start), AluExp.mul(AluExp.f32(delta / denom), AluExp.cast(DType.Float32, AluVar.idx))));
|
|
2199
2411
|
const st = ShapeTracker.fromShape([num]);
|
|
2200
|
-
return new Array$1(
|
|
2412
|
+
return new Array$1({
|
|
2413
|
+
source: exp$2,
|
|
2414
|
+
st,
|
|
2415
|
+
dtype,
|
|
2416
|
+
weakType: false,
|
|
2417
|
+
backend: getBackend(device),
|
|
2418
|
+
committed: device != void 0
|
|
2419
|
+
});
|
|
2201
2420
|
}
|
|
2202
2421
|
function aluCompare(a, b, op) {
|
|
2203
2422
|
switch (op) {
|
|
@@ -2209,35 +2428,6 @@ function aluCompare(a, b, op) {
|
|
|
2209
2428
|
case CompareOp.LessEqual: return AluExp.add(AluExp.cmplt(a, b), AluExp.cmpne(a, b).not());
|
|
2210
2429
|
}
|
|
2211
2430
|
}
|
|
2212
|
-
/**
|
|
2213
|
-
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
2214
|
-
*
|
|
2215
|
-
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
2216
|
-
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
2217
|
-
* Two dimensions are compatible when:
|
|
2218
|
-
* 1. they are equal, or
|
|
2219
|
-
* 2. one of them is 1."
|
|
2220
|
-
*
|
|
2221
|
-
* Throws a TypeError if the broadcast is not possible.
|
|
2222
|
-
*
|
|
2223
|
-
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
2224
|
-
*/
|
|
2225
|
-
function generalBroadcast(a, b) {
|
|
2226
|
-
const out = [];
|
|
2227
|
-
let i = a.length - 1;
|
|
2228
|
-
let j = b.length - 1;
|
|
2229
|
-
for (; i >= 0 && j >= 0; i--, j--) {
|
|
2230
|
-
const x = a[i];
|
|
2231
|
-
const y = b[j];
|
|
2232
|
-
if (x === y) out.push(x);
|
|
2233
|
-
else if (x === 1) out.push(y);
|
|
2234
|
-
else if (y === 1) out.push(x);
|
|
2235
|
-
else throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
|
|
2236
|
-
}
|
|
2237
|
-
for (; i >= 0; i--) out.push(a[i]);
|
|
2238
|
-
for (; j >= 0; j--) out.push(b[j]);
|
|
2239
|
-
return out.reverse();
|
|
2240
|
-
}
|
|
2241
2431
|
|
|
2242
2432
|
//#endregion
|
|
2243
2433
|
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
|
|
@@ -2313,13 +2503,15 @@ var Var = class Var {
|
|
|
2313
2503
|
};
|
|
2314
2504
|
/**
 * Literal in a Jaxpr expression. Currently, only scalars are supported.
 *
 * Holds the raw `value` plus a `ShapedArray` copy of the abstract value it
 * was created with; `dtype` is read through from that abstract value.
 */
var Lit = class {
  value;
  aval;
  constructor(aval, value) {
    const rank = aval.shape.length;
    if (rank !== 0) {
      throw new Error(`internal: Lit must be a scalar`);
    }
    this.value = value;
    this.aval = ShapedArray.fromAval(aval);
  }
  /** Dtype of the literal, taken from its abstract value. */
  get dtype() {
    return this.aval.dtype;
  }
};
|
|
2325
2517
|
function atomIsLit(atom, literal) {
|
|
@@ -2443,14 +2635,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2443
2635
|
const c = eqn.outBinders[0];
|
|
2444
2636
|
if (atomIsLit(a, 0)) context.set(c, b);
|
|
2445
2637
|
else if (atomIsLit(b, 0)) context.set(c, a);
|
|
2446
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2638
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.dtype === DType.Bool ? Math.min(a.value + b.value, 1) : a.value + b.value));
|
|
2639
|
+
else newEqns.push(eqn);
|
|
2640
|
+
} else if (eqn.primitive === Primitive.Neg) {
|
|
2641
|
+
const [a] = inputs;
|
|
2642
|
+
const c = eqn.outBinders[0];
|
|
2643
|
+
if (atomIsLit(a)) context.set(c, new Lit(a.aval, -a.value));
|
|
2447
2644
|
else newEqns.push(eqn);
|
|
2448
2645
|
} else if (eqn.primitive === Primitive.Mul) {
|
|
2449
2646
|
const [a, b] = inputs;
|
|
2450
2647
|
const c = eqn.outBinders[0];
|
|
2451
2648
|
if (atomIsLit(a, 1)) context.set(c, b);
|
|
2452
2649
|
else if (atomIsLit(b, 1)) context.set(c, a);
|
|
2453
|
-
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(a.
|
|
2650
|
+
else if (atomIsLit(a) && atomIsLit(b)) context.set(c, new Lit(promoteAvals(a.aval, b.aval), a.value * b.value));
|
|
2454
2651
|
else newEqns.push(eqn);
|
|
2455
2652
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
2456
2653
|
const [a, b] = inputs;
|
|
@@ -2548,7 +2745,7 @@ function evalJaxpr(jaxpr, args) {
|
|
|
2548
2745
|
if (x instanceof Var) {
|
|
2549
2746
|
remainingRefs.set(x, (remainingRefs.get(x) ?? 0) - 1);
|
|
2550
2747
|
return env.get(x);
|
|
2551
|
-
} else return
|
|
2748
|
+
} else return array(x.value, { dtype: x.dtype });
|
|
2552
2749
|
};
|
|
2553
2750
|
const write = (v, val) => {
|
|
2554
2751
|
if (env.has(v)) throw new Error(`Variable already bound: ${v}`);
|
|
@@ -2607,7 +2804,7 @@ var JaxprTrace = class extends Trace {
|
|
|
2607
2804
|
let tracer = this.builder.constTracers.get(val);
|
|
2608
2805
|
if (tracer === void 0) {
|
|
2609
2806
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
2610
|
-
this.builder.addConst(tracer, val instanceof Tracer ? val.ref :
|
|
2807
|
+
this.builder.addConst(tracer, val instanceof Tracer ? val.ref : array(val));
|
|
2611
2808
|
}
|
|
2612
2809
|
return tracer;
|
|
2613
2810
|
}
|
|
@@ -2676,7 +2873,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2676
2873
|
const newConsts = [];
|
|
2677
2874
|
for (let i = 0; i < consts.length; i++) if (ndim$1(consts[i]) === 0 && consts[i] instanceof Array$1) {
|
|
2678
2875
|
const ar = consts[i];
|
|
2679
|
-
literals.set(jaxpr.inBinders[i], new Lit(ar.
|
|
2876
|
+
literals.set(jaxpr.inBinders[i], new Lit(ar.aval, ar.dataSync()[0]));
|
|
2680
2877
|
} else {
|
|
2681
2878
|
constBinders.push(jaxpr.inBinders[i]);
|
|
2682
2879
|
newConsts.push(consts[i]);
|
|
@@ -2689,13 +2886,12 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
2689
2886
|
}
|
|
2690
2887
|
/**
 * Abstract-eval rule for binary ops: the single output aval is
 * `promoteAvals(x, y)`.
 *
 * @throws {TypeError} If either input is not a ShapedArray.
 */
function binopAbstractEval([x, y]) {
  const bothShaped = x instanceof ShapedArray && y instanceof ShapedArray;
  if (!bothShaped) {
    throw new TypeError("binopAbstractEval expects ShapedArray inputs");
  }
  const promoted = promoteAvals(x, y);
  return [promoted];
}
|
|
2695
2891
|
/**
 * Abstract-eval rule for comparisons: the output takes the shape of
 * `promoteAvals(x, y)` and is always a strongly typed bool.
 *
 * @throws {TypeError} If either input is not a ShapedArray.
 */
function compareAbstractEval([x, y]) {
  const bothShaped = x instanceof ShapedArray && y instanceof ShapedArray;
  if (!bothShaped) {
    throw new TypeError("compareAbstractEval expects ShapedArray inputs");
  }
  const { shape: outShape } = promoteAvals(x, y);
  return [new ShapedArray(outShape, DType.Bool, false)];
}
|
|
2700
2896
|
function vectorizedUnopAbstractEval([x]) {
|
|
2701
2897
|
return [ShapedArray.fromAval(x)];
|
|
@@ -2708,18 +2904,18 @@ const abstractEvalRules = {
|
|
|
2708
2904
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
2709
2905
|
[Primitive.StopGradient]: vectorizedUnopAbstractEval,
|
|
2710
2906
|
[Primitive.Cast]([x], { dtype }) {
|
|
2711
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2907
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2712
2908
|
},
|
|
2713
2909
|
[Primitive.Bitcast]([x], { dtype }) {
|
|
2714
2910
|
if (x.dtype === DType.Bool || dtype === DType.Bool) throw new TypeError("Bitcast to/from bool is not allowed");
|
|
2715
2911
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
2716
|
-
return [new ShapedArray(x.shape, dtype)];
|
|
2912
|
+
return [new ShapedArray(x.shape, dtype, false)];
|
|
2717
2913
|
},
|
|
2718
2914
|
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
2719
2915
|
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
2720
2916
|
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2721
2917
|
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2722
|
-
return [new ShapedArray(shape$1, DType.Uint32)];
|
|
2918
|
+
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
2723
2919
|
},
|
|
2724
2920
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2725
2921
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
@@ -2727,61 +2923,62 @@ const abstractEvalRules = {
|
|
|
2727
2923
|
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2728
2924
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2729
2925
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2926
|
+
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
2927
|
+
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
2730
2928
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2731
2929
|
[Primitive.Min]: binopAbstractEval,
|
|
2732
2930
|
[Primitive.Max]: binopAbstractEval,
|
|
2733
2931
|
[Primitive.Reduce]([x], { axis }) {
|
|
2734
2932
|
const axisSet = new Set(axis);
|
|
2735
2933
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2736
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2934
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2737
2935
|
},
|
|
2738
2936
|
[Primitive.Pool]([x], { window, strides }) {
|
|
2739
2937
|
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2740
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2938
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2741
2939
|
},
|
|
2742
2940
|
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2743
2941
|
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2744
2942
|
if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2745
|
-
return [new ShapedArray(inShape, x.dtype)];
|
|
2943
|
+
return [new ShapedArray(inShape, x.dtype, x.weakType)];
|
|
2746
2944
|
},
|
|
2747
2945
|
[Primitive.Dot]([x, y]) {
|
|
2748
|
-
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2749
2946
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
2750
|
-
const shape$1 =
|
|
2947
|
+
const { shape: shape$1, dtype, weakType } = promoteAvals(x, y);
|
|
2751
2948
|
shape$1.splice(-1, 1);
|
|
2752
|
-
return [new ShapedArray(shape$1,
|
|
2949
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2753
2950
|
},
|
|
2754
2951
|
[Primitive.Conv]([lhs, rhs], params) {
|
|
2755
|
-
|
|
2952
|
+
const { dtype, weakType } = promoteAvals(new ShapedArray([], lhs.dtype, lhs.weakType), new ShapedArray([], rhs.dtype, rhs.weakType));
|
|
2756
2953
|
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2757
|
-
return [new ShapedArray(shape$1,
|
|
2954
|
+
return [new ShapedArray(shape$1, dtype, weakType)];
|
|
2758
2955
|
},
|
|
2759
2956
|
[Primitive.Compare]: compareAbstractEval,
|
|
2760
2957
|
[Primitive.Where]([cond, x, y]) {
|
|
2761
2958
|
if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
2762
|
-
|
|
2763
|
-
const shape$1 = generalBroadcast(cond.shape,
|
|
2764
|
-
return [new ShapedArray(shape$1,
|
|
2959
|
+
const xy = promoteAvals(x, y);
|
|
2960
|
+
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
2961
|
+
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
2765
2962
|
},
|
|
2766
2963
|
[Primitive.Transpose]([x], { perm }) {
|
|
2767
|
-
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype)];
|
|
2964
|
+
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
2768
2965
|
},
|
|
2769
2966
|
[Primitive.Broadcast]([x], { shape: shape$1 }) {
|
|
2770
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2967
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2771
2968
|
},
|
|
2772
2969
|
[Primitive.Reshape]([x], { shape: shape$1 }) {
|
|
2773
|
-
return [new ShapedArray(shape$1, x.dtype)];
|
|
2970
|
+
return [new ShapedArray(shape$1, x.dtype, x.weakType)];
|
|
2774
2971
|
},
|
|
2775
2972
|
[Primitive.Flip]([x], _) {
|
|
2776
|
-
return [
|
|
2973
|
+
return [ShapedArray.fromAval(x)];
|
|
2777
2974
|
},
|
|
2778
2975
|
[Primitive.Shrink]([x], { slice }) {
|
|
2779
2976
|
const newShape = slice.map((s) => s[1] - s[0]);
|
|
2780
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2977
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2781
2978
|
},
|
|
2782
2979
|
[Primitive.Pad]([x], { width }) {
|
|
2783
2980
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
2784
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2981
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2785
2982
|
},
|
|
2786
2983
|
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2787
2984
|
for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
@@ -2794,7 +2991,7 @@ const abstractEvalRules = {
|
|
|
2794
2991
|
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
2795
2992
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2796
2993
|
newShape.splice(outDim, 0, ...gatherShape);
|
|
2797
|
-
return [new ShapedArray(newShape, x.dtype)];
|
|
2994
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
2798
2995
|
},
|
|
2799
2996
|
[Primitive.JitCall](args, { jaxpr }) {
|
|
2800
2997
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
@@ -2861,6 +3058,7 @@ function jit$1(f, opts) {
|
|
|
2861
3058
|
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2862
3059
|
const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2863
3060
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
3061
|
+
name: f.name || "closure",
|
|
2864
3062
|
jaxpr,
|
|
2865
3063
|
numConsts: consts.length
|
|
2866
3064
|
});
|
|
@@ -2979,6 +3177,16 @@ const jvpRules = {
|
|
|
2979
3177
|
[Primitive.Log]([x], [dx]) {
|
|
2980
3178
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
2981
3179
|
},
|
|
3180
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3181
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3182
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3183
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3184
|
+
},
|
|
3185
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3186
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3187
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3188
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3189
|
+
},
|
|
2982
3190
|
[Primitive.Sqrt]([x], [dx]) {
|
|
2983
3191
|
const z = sqrt$1(x);
|
|
2984
3192
|
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
@@ -3022,13 +3230,14 @@ const jvpRules = {
|
|
|
3022
3230
|
const indicesRef = indices.map((t) => t.ref);
|
|
3023
3231
|
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3024
3232
|
},
|
|
3025
|
-
[Primitive.JitCall](primals, tangents, { jaxpr }) {
|
|
3233
|
+
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3026
3234
|
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3027
3235
|
const outs = bind(Primitive.JitCall, [
|
|
3028
3236
|
...newConsts.map((c) => c.ref),
|
|
3029
3237
|
...primals,
|
|
3030
3238
|
...tangents
|
|
3031
3239
|
], {
|
|
3240
|
+
name: `${name}_jvp`,
|
|
3032
3241
|
jaxpr: newJaxpr,
|
|
3033
3242
|
numConsts: newConsts.length
|
|
3034
3243
|
});
|
|
@@ -3082,7 +3291,7 @@ function jvp$1(f, primals, tangents) {
|
|
|
3082
3291
|
/**
 * Abstract value of `aval` with the axis at `batchDim` removed; dtype and
 * weak-type flag are carried over unchanged.
 */
function mappedAval(batchDim, aval) {
  const { dtype, weakType } = aval;
  const unbatchedShape = Array.from(aval.shape);
  unbatchedShape.splice(batchDim, 1);
  return new ShapedArray(unbatchedShape, dtype, weakType);
}
|
|
3087
3296
|
/** Move one axis to a different index. */
|
|
3088
3297
|
function moveaxis$1(x, src, dst) {
|
|
@@ -3139,6 +3348,10 @@ var BatchTrace = class extends Trace {
|
|
|
3139
3348
|
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3140
3349
|
const vmapRule = vmapRules[primitive];
|
|
3141
3350
|
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3351
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3352
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3353
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3354
|
+
}
|
|
3142
3355
|
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3143
3356
|
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3144
3357
|
}
|
|
@@ -3146,24 +3359,28 @@ var BatchTrace = class extends Trace {
|
|
|
3146
3359
|
return this.main.globalData;
|
|
3147
3360
|
}
|
|
3148
3361
|
};
|
|
3149
|
-
|
|
3150
|
-
|
|
3151
|
-
|
|
3152
|
-
|
|
3153
|
-
|
|
3154
|
-
return broadcast(x, shape$1, axis);
|
|
3155
|
-
}
|
|
3156
|
-
}
|
|
3157
|
-
/** Process a primitive with built-in broadcasting. */
|
|
3362
|
+
/**
 * Process a primitive with built-in broadcasting.
 *
 * Returns a batching rule `(axisSize, args, dims) => [[out], [outBatchDim]]`
 * for `op`:
 *
 * 1. Fast path: if every batched argument has its batch axis at the same
 *    position counted from the end, and every unbatched argument has too few
 *    dimensions to reach that position, `op` is applied to the arguments
 *    as-is.
 * 2. Otherwise each batched argument has its batch axis moved to the front
 *    and, if needed, is reshaped with size-1 axes inserted right after the
 *    batch axis so it has `nd` dimensions; the output batch axis is then 0.
 *
 * Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
 */
function broadcastBatcher(op) {
  return (axisSize, args, dims) => {
    if (args.length === 0) throw new Error("Empty list in broadcastBatcher");

    // Rank of each argument, counting one extra axis for arguments that do
    // not carry a batch axis yet.
    const ranks = args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0));
    const nd = Math.max(...ranks);

    // Batch-axis position of the first batched argument, measured from the
    // end of its shape (so it is negative).
    const firstIdx = dims.findIndex((d) => d !== null);
    const firstBdim = dims[firstIdx] - args[firstIdx].ndim;

    const alreadyAligned = zip(args, dims).every(([x, d]) =>
      d === null ? ndim$1(x) < -firstBdim : d - x.ndim === firstBdim,
    );
    if (alreadyAligned) {
      return [[op(...args)], [nd + firstBdim]];
    }

    const aligned = args.map((x, i) => {
      const bdim = dims[i];
      if (bdim === null) return x;
      let moved = moveBatchAxis(axisSize, bdim, 0, x);
      if (moved.ndim < nd) {
        moved = moved.reshape([
          moved.shape[0],
          ...rep(nd - moved.ndim, 1),
          ...moved.shape.slice(1),
        ]);
      }
      return moved;
    });
    return [[op(...aligned)], [0]];
  };
}
|
|
@@ -3187,17 +3404,18 @@ const vmapRules = {
|
|
|
3187
3404
|
[Primitive.Atan]: unopBatcher(atan$1),
|
|
3188
3405
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
3189
3406
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3407
|
+
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3408
|
+
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3190
3409
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3191
3410
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3192
3411
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3193
3412
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3194
|
-
|
|
3413
|
+
assertNonNull(xBdim);
|
|
3195
3414
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3196
3415
|
const outBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3197
3416
|
return [[reduce(x, op, newAxis)], [outBdim]];
|
|
3198
3417
|
},
|
|
3199
3418
|
[Primitive.Dot](axisSize, [x, y], [xBdim, yBdim]) {
|
|
3200
|
-
if (xBdim === null && yBdim === null) return [[dot$1(x, y)], [null]];
|
|
3201
3419
|
x = moveBatchAxis(axisSize, xBdim, x.ndim - (xBdim === null ? 1 : 2), x);
|
|
3202
3420
|
y = moveBatchAxis(axisSize, yBdim, y.ndim - (yBdim === null ? 1 : 2), y);
|
|
3203
3421
|
const z = dot$1(x, y);
|
|
@@ -3206,29 +3424,72 @@ const vmapRules = {
|
|
|
3206
3424
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3207
3425
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3208
3426
|
},
|
|
3427
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3428
|
+
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3429
|
+
assertNonNull(xBdim);
|
|
3430
|
+
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
3431
|
+
newPerm.splice(xBdim, 0, xBdim);
|
|
3432
|
+
return [[transpose$1(x, newPerm)], [xBdim]];
|
|
3433
|
+
},
|
|
3434
|
+
[Primitive.Broadcast](axisSize, [x], [xBdim], { shape: shape$1, axis }) {
|
|
3435
|
+
assertNonNull(xBdim);
|
|
3436
|
+
const newShape = shape$1.toSpliced(xBdim, 0, axisSize);
|
|
3437
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3438
|
+
return [[broadcast(x, newShape, newAxis)], [xBdim]];
|
|
3439
|
+
},
|
|
3209
3440
|
[Primitive.Reshape](axisSize, [x], [xBdim], { shape: shape$1 }) {
|
|
3210
|
-
if (xBdim === null) return [[reshape$1(x, shape$1)], [null]];
|
|
3211
3441
|
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3212
3442
|
return [[reshape$1(x, [axisSize, ...shape$1])], [0]];
|
|
3213
3443
|
},
|
|
3214
3444
|
[Primitive.Flip](axisSize, [x], [xBdim], { axis }) {
|
|
3215
|
-
|
|
3445
|
+
assertNonNull(xBdim);
|
|
3216
3446
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3217
3447
|
return [[flip$1(x, newAxis)], [xBdim]];
|
|
3218
3448
|
},
|
|
3219
3449
|
[Primitive.Shrink](axisSize, [x], [xBdim], { slice }) {
|
|
3220
|
-
|
|
3450
|
+
assertNonNull(xBdim);
|
|
3221
3451
|
const newSlice = slice.toSpliced(xBdim, 0, [0, axisSize]);
|
|
3222
3452
|
return [[shrink(x, newSlice)], [xBdim]];
|
|
3223
3453
|
},
|
|
3224
3454
|
[Primitive.Pad](axisSize, [x], [xBdim], { width }) {
|
|
3225
|
-
|
|
3455
|
+
assertNonNull(xBdim);
|
|
3226
3456
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3227
3457
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3228
3458
|
},
|
|
3229
|
-
[Primitive.
|
|
3459
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3460
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3461
|
+
assertNonNull(xBdim);
|
|
3462
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3463
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3464
|
+
let newOutDim = outDim;
|
|
3465
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3466
|
+
else newOutDim += 1;
|
|
3467
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3468
|
+
}
|
|
3469
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3470
|
+
indices = indices.map((m, i) => {
|
|
3471
|
+
if (indicesBdim[i] === null) return m;
|
|
3472
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3473
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3474
|
+
m.shape[0],
|
|
3475
|
+
...rep(nd - m.ndim, 1),
|
|
3476
|
+
...m.shape.slice(1)
|
|
3477
|
+
]);
|
|
3478
|
+
return m;
|
|
3479
|
+
});
|
|
3480
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3481
|
+
else {
|
|
3482
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3483
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3484
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
|
|
3485
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3486
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3487
|
+
}
|
|
3488
|
+
},
|
|
3489
|
+
[Primitive.JitCall](axisSize, args, dims, { name, jaxpr }) {
|
|
3230
3490
|
const { newJaxpr, newConsts } = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3231
3491
|
const outs = bind(Primitive.JitCall, [...newConsts.map((c) => c.ref), ...args], {
|
|
3492
|
+
name: `${name}_vmap`,
|
|
3232
3493
|
jaxpr: newJaxpr,
|
|
3233
3494
|
numConsts: newConsts.length
|
|
3234
3495
|
});
|
|
@@ -3244,7 +3505,7 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3244
3505
|
if (dims[i] === null) return v.aval;
|
|
3245
3506
|
const shape$1 = [...v.aval.shape];
|
|
3246
3507
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3247
|
-
return new ShapedArray(shape$1, v.aval.dtype);
|
|
3508
|
+
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3248
3509
|
});
|
|
3249
3510
|
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3250
3511
|
const result = {
|
|
@@ -3284,12 +3545,14 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3284
3545
|
function vmap$1(f, inAxes = 0) {
|
|
3285
3546
|
return (...args) => {
|
|
3286
3547
|
const [argsFlat, inTree] = flatten(args);
|
|
3287
|
-
let inAxesFlat;
|
|
3548
|
+
let inAxesFlat = [];
|
|
3288
3549
|
if (typeof inAxes === "number") inAxesFlat = rep(argsFlat.length, inAxes);
|
|
3550
|
+
else for (let i = 0; i < args.length; i++) if (inAxes[i] == null) inAxesFlat.push(...rep(inTree.childTreedefs[i].size, null));
|
|
3551
|
+
else if (typeof inAxes[i] === "number") inAxesFlat.push(...rep(inTree.childTreedefs[i].size, inAxes[i]));
|
|
3289
3552
|
else {
|
|
3290
|
-
|
|
3291
|
-
[
|
|
3292
|
-
|
|
3553
|
+
const [axesFlat, axesTreeDef] = flatten(inAxes[i]);
|
|
3554
|
+
if (!inTree.childTreedefs[i].equals(axesTreeDef)) throw new TreeMismatchError("vmap", inTree.childTreedefs[i], axesTreeDef);
|
|
3555
|
+
inAxesFlat.push(...axesFlat);
|
|
3293
3556
|
}
|
|
3294
3557
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3295
3558
|
const outsFlat = vmapFlat(fFlat, inAxesFlat, argsFlat);
|
|
@@ -3457,8 +3720,8 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3457
3720
|
processPrimitive(primitive, tracers, params) {
|
|
3458
3721
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3459
3722
|
if (primitive === Primitive.JitCall) {
|
|
3460
|
-
const { jaxpr, numConsts } = params;
|
|
3461
|
-
return this.#partialEvalJaxpr(jaxpr, numConsts, tracers);
|
|
3723
|
+
const { name, jaxpr, numConsts } = params;
|
|
3724
|
+
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3462
3725
|
}
|
|
3463
3726
|
const tracersIn = tracers.map((t) => this.instantiateConst(t));
|
|
3464
3727
|
const avalsIn = tracersIn.map((t) => t.pval.aval);
|
|
@@ -3484,12 +3747,13 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3484
3747
|
*
|
|
3485
3748
|
* Used when encountering a JitCall rule during the trace.
|
|
3486
3749
|
*/
|
|
3487
|
-
#partialEvalJaxpr(jaxpr, numConsts, tracers) {
|
|
3750
|
+
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3488
3751
|
jaxpr = jaxpr.flatten();
|
|
3489
3752
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3490
3753
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3491
3754
|
const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
|
|
3492
3755
|
const outs1Res = bind(Primitive.JitCall, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3756
|
+
name: `${name}_peval`,
|
|
3493
3757
|
jaxpr: jaxpr1,
|
|
3494
3758
|
numConsts: 0
|
|
3495
3759
|
});
|
|
@@ -3501,6 +3765,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3501
3765
|
prim: Primitive.JitCall,
|
|
3502
3766
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3503
3767
|
params: {
|
|
3768
|
+
name: `${name}_resid`,
|
|
3504
3769
|
jaxpr: jaxpr2,
|
|
3505
3770
|
numConsts: 0
|
|
3506
3771
|
},
|
|
@@ -3643,7 +3908,7 @@ function evalJaxprTransposed(jaxpr, args, cotangents) {
|
|
|
3643
3908
|
}
|
|
3644
3909
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
3645
3910
|
const eqn = jaxpr.eqns[i];
|
|
3646
|
-
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ?
|
|
3911
|
+
const primalsIn = eqn.inputs.map((v) => v instanceof Lit ? array(v.value, { dtype: v.dtype }) : knownPrimals.has(v) ? knownPrimals.get(v).ref : new UndefPrimal(v.aval));
|
|
3647
3912
|
const cotangentsOut = eqn.outBinders.map(readCotangent);
|
|
3648
3913
|
const rule = transposeRules[eqn.primitive];
|
|
3649
3914
|
if (!rule) throw new TypeError(`Backward pass not implemented for ${eqn.primitive}`);
|
|
@@ -3823,7 +4088,7 @@ const transposeRules = {
|
|
|
3823
4088
|
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
3824
4089
|
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
3825
4090
|
},
|
|
3826
|
-
[Primitive.JitCall](cts, args, { jaxpr }) {
|
|
4091
|
+
[Primitive.JitCall](cts, args, { name, jaxpr }) {
|
|
3827
4092
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
3828
4093
|
const { newJaxpr, newConsts } = transposeJaxpr(jaxpr, undefPrimals);
|
|
3829
4094
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
@@ -3832,6 +4097,7 @@ const transposeRules = {
|
|
|
3832
4097
|
...residuals,
|
|
3833
4098
|
...cts
|
|
3834
4099
|
], {
|
|
4100
|
+
name: `${name}_t`,
|
|
3835
4101
|
jaxpr: newJaxpr,
|
|
3836
4102
|
numConsts: newConsts.length
|
|
3837
4103
|
});
|
|
@@ -3906,7 +4172,7 @@ function valueAndGrad$1(f) {
|
|
|
3906
4172
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3907
4173
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3908
4174
|
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3909
|
-
const [ct, ...rest] = fVjp(
|
|
4175
|
+
const [ct, ...rest] = fVjp(onesLike$1(y.ref));
|
|
3910
4176
|
for (const r of rest) dispose(r);
|
|
3911
4177
|
fVjp.dispose();
|
|
3912
4178
|
return [y, ct];
|
|
@@ -3934,7 +4200,10 @@ __export(lax_exports, {
|
|
|
3934
4200
|
conv: () => conv$1,
|
|
3935
4201
|
convGeneralDilated: () => convGeneralDilated,
|
|
3936
4202
|
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
3937
|
-
|
|
4203
|
+
erf: () => erf,
|
|
4204
|
+
erfc: () => erfc,
|
|
4205
|
+
reduceWindow: () => reduceWindow,
|
|
4206
|
+
stopGradient: () => stopGradient$1
|
|
3938
4207
|
});
|
|
3939
4208
|
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
3940
4209
|
const padType = padding.toUpperCase();
|
|
@@ -3993,6 +4262,28 @@ function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
|
3993
4262
|
strides: windowStrides
|
|
3994
4263
|
}));
|
|
3995
4264
|
}
|
|
4265
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
4266
|
+
function erf(x) {
|
|
4267
|
+
return erf$1(x);
|
|
4268
|
+
}
|
|
4269
|
+
/**
|
|
4270
|
+
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
4271
|
+
*
|
|
4272
|
+
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
4273
|
+
* where `erf(x)` is very close to 1.
|
|
4274
|
+
*/
|
|
4275
|
+
function erfc(x) {
|
|
4276
|
+
return erfc$1(x);
|
|
4277
|
+
}
|
|
4278
|
+
/**
|
|
4279
|
+
* Stops gradient computation.
|
|
4280
|
+
*
|
|
4281
|
+
* Behaves as the identity function but prevents the flow of gradients during
|
|
4282
|
+
* forward or reverse-mode automatic differentiation.
|
|
4283
|
+
*/
|
|
4284
|
+
function stopGradient$1(x) {
|
|
4285
|
+
return stopGradient(x);
|
|
4286
|
+
}
|
|
3996
4287
|
|
|
3997
4288
|
//#endregion
|
|
3998
4289
|
//#region src/numpy.ts
|
|
@@ -4055,6 +4346,9 @@ __export(numpy_exports, {
|
|
|
4055
4346
|
fullLike: () => fullLike$1,
|
|
4056
4347
|
greater: () => greater,
|
|
4057
4348
|
greaterEqual: () => greaterEqual,
|
|
4349
|
+
hamming: () => hamming,
|
|
4350
|
+
hann: () => hann,
|
|
4351
|
+
heaviside: () => heaviside,
|
|
4058
4352
|
hstack: () => hstack,
|
|
4059
4353
|
hypot: () => hypot,
|
|
4060
4354
|
identity: () => identity$1,
|
|
@@ -4276,7 +4570,7 @@ function argmin(a, axis, opts) {
|
|
|
4276
4570
|
} else axis = checkAxis(axis, a.ndim);
|
|
4277
4571
|
const shape$1 = a.shape;
|
|
4278
4572
|
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
4279
|
-
const length =
|
|
4573
|
+
const length = array(shape$1[axis], {
|
|
4280
4574
|
dtype: int32,
|
|
4281
4575
|
device: a.device
|
|
4282
4576
|
});
|
|
@@ -4300,7 +4594,7 @@ function argmax(a, axis, opts) {
|
|
|
4300
4594
|
} else axis = checkAxis(axis, a.ndim);
|
|
4301
4595
|
const shape$1 = a.shape;
|
|
4302
4596
|
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
4303
|
-
const length =
|
|
4597
|
+
const length = array(shape$1[axis], {
|
|
4304
4598
|
dtype: int32,
|
|
4305
4599
|
device: a.device
|
|
4306
4600
|
});
|
|
@@ -4694,6 +4988,32 @@ function sign(x) {
|
|
|
4694
4988
|
x = fudgeArray(x);
|
|
4695
4989
|
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4696
4990
|
}
|
|
4991
|
+
/**
|
|
4992
|
+
* Return the Hamming window of size M, a taper with a weighted cosine bell.
|
|
4993
|
+
*
|
|
4994
|
+
* `w(n) = 0.54 - 0.46 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
4995
|
+
*/
|
|
4996
|
+
function hamming(M) {
|
|
4997
|
+
return cos(linspace(0, 2 * Math.PI, M)).mul(-.46).add(.54);
|
|
4998
|
+
}
|
|
4999
|
+
/**
|
|
5000
|
+
* Return the Hann window of size M, a taper with a weighted cosine bell.
|
|
5001
|
+
*
|
|
5002
|
+
* `w(n) = 0.5 - 0.5 * cos(2πn/(M-1))` for `0 <= n <= M-1`.
|
|
5003
|
+
*/
|
|
5004
|
+
function hann(M) {
|
|
5005
|
+
return cos(linspace(0, 2 * Math.PI, M)).mul(-.5).add(.5);
|
|
5006
|
+
}
|
|
5007
|
+
/**
|
|
5008
|
+
* @function
|
|
5009
|
+
* Compute the Heaviside step function. It is defined piecewise:
|
|
5010
|
+
* - `heaviside(x1, x2) = 0` for `x1 < 0`,
|
|
5011
|
+
* - `heaviside(x1, x2) = x2` for `x1 == 0`,
|
|
5012
|
+
* - `heaviside(x1, x2) = 1` for `x1 > 0`.
|
|
5013
|
+
*/
|
|
5014
|
+
const heaviside = jit$1(function heaviside$1(x1, x2) {
|
|
5015
|
+
return where(less(x1.ref, 0), 0, where(equal(x1, 0), x2, 1));
|
|
5016
|
+
});
|
|
4697
5017
|
/** Calculate element-wise square of the input array. */
|
|
4698
5018
|
function square(x) {
|
|
4699
5019
|
x = fudgeArray(x);
|
|
@@ -4713,10 +5033,10 @@ function acos(x) {
|
|
|
4713
5033
|
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4714
5034
|
*
|
|
4715
5035
|
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4716
|
-
* stable than sqrt(x1**2 + x2**2)
|
|
4717
|
-
* improvements.
|
|
5036
|
+
* stable than `sqrt(x1**2 + x2**2)`. We don't currently implement those
|
|
5037
|
+
* stability improvements.
|
|
4718
5038
|
*/
|
|
4719
|
-
const hypot = jit$1((x1, x2)
|
|
5039
|
+
const hypot = jit$1(function hypot$1(x1, x2) {
|
|
4720
5040
|
return sqrt(square(x1).add(square(x2)));
|
|
4721
5041
|
});
|
|
4722
5042
|
/**
|
|
@@ -4732,7 +5052,7 @@ const hypot = jit$1((x1, x2) => {
|
|
|
4732
5052
|
*
|
|
4733
5053
|
* The output is ill-defined when both x and y are zero.
|
|
4734
5054
|
*/
|
|
4735
|
-
const atan2 = jit$1((y, x)
|
|
5055
|
+
const atan2 = jit$1(function atan2$1(y, x) {
|
|
4736
5056
|
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4737
5057
|
const xNeg = less(x.ref, 0);
|
|
4738
5058
|
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
@@ -4800,13 +5120,13 @@ const degrees = rad2deg;
|
|
|
4800
5120
|
* @function
|
|
4801
5121
|
* Computes first array raised to power of second array, element-wise.
|
|
4802
5122
|
*/
|
|
4803
|
-
const power = jit$1((x1, x2)
|
|
5123
|
+
const power = jit$1(function power$1(x1, x2) {
|
|
4804
5124
|
return exp(log(x1).mul(x2));
|
|
4805
5125
|
});
|
|
4806
5126
|
/** @function Alias of `jax.numpy.power()`. */
|
|
4807
5127
|
const pow = power;
|
|
4808
5128
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
4809
|
-
const cbrt = jit$1((x)
|
|
5129
|
+
const cbrt = jit$1(function cbrt$1(x) {
|
|
4810
5130
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4811
5131
|
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4812
5132
|
});
|
|
@@ -4816,7 +5136,7 @@ const cbrt = jit$1((x) => {
|
|
|
4816
5136
|
*
|
|
4817
5137
|
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4818
5138
|
*/
|
|
4819
|
-
const sinh = jit$1((x)
|
|
5139
|
+
const sinh = jit$1(function sinh$1(x) {
|
|
4820
5140
|
const ex = exp(x);
|
|
4821
5141
|
const emx = reciprocal(ex.ref);
|
|
4822
5142
|
return ex.sub(emx).mul(.5);
|
|
@@ -4827,7 +5147,7 @@ const sinh = jit$1((x) => {
|
|
|
4827
5147
|
*
|
|
4828
5148
|
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4829
5149
|
*/
|
|
4830
|
-
const cosh = jit$1((x)
|
|
5150
|
+
const cosh = jit$1(function cosh$1(x) {
|
|
4831
5151
|
const ex = exp(x);
|
|
4832
5152
|
const emx = reciprocal(ex.ref);
|
|
4833
5153
|
return ex.add(emx).mul(.5);
|
|
@@ -4838,7 +5158,7 @@ const cosh = jit$1((x) => {
|
|
|
4838
5158
|
*
|
|
4839
5159
|
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4840
5160
|
*/
|
|
4841
|
-
const tanh = jit$1((x)
|
|
5161
|
+
const tanh = jit$1(function tanh$1(x) {
|
|
4842
5162
|
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4843
5163
|
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4844
5164
|
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
@@ -4849,7 +5169,7 @@ const tanh = jit$1((x) => {
|
|
|
4849
5169
|
*
|
|
4850
5170
|
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4851
5171
|
*/
|
|
4852
|
-
const arcsinh = jit$1((x)
|
|
5172
|
+
const arcsinh = jit$1(function arcsinh$1(x) {
|
|
4853
5173
|
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4854
5174
|
});
|
|
4855
5175
|
/**
|
|
@@ -4858,7 +5178,7 @@ const arcsinh = jit$1((x) => {
|
|
|
4858
5178
|
*
|
|
4859
5179
|
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4860
5180
|
*/
|
|
4861
|
-
const arccosh = jit$1((x)
|
|
5181
|
+
const arccosh = jit$1(function arccosh$1(x) {
|
|
4862
5182
|
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4863
5183
|
});
|
|
4864
5184
|
/**
|
|
@@ -4867,7 +5187,7 @@ const arccosh = jit$1((x) => {
|
|
|
4867
5187
|
*
|
|
4868
5188
|
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4869
5189
|
*/
|
|
4870
|
-
const arctanh = jit$1((x)
|
|
5190
|
+
const arctanh = jit$1(function arctanh$1(x) {
|
|
4871
5191
|
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4872
5192
|
});
|
|
4873
5193
|
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
@@ -4983,7 +5303,9 @@ function softSign(x) {
|
|
|
4983
5303
|
*
|
|
4984
5304
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4985
5305
|
*/
|
|
4986
|
-
const silu = jit$1((x)
|
|
5306
|
+
const silu = jit$1(function silu$1(x) {
|
|
5307
|
+
return x.ref.mul(sigmoid(x));
|
|
5308
|
+
});
|
|
4987
5309
|
/**
|
|
4988
5310
|
* @function
|
|
4989
5311
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
@@ -5036,18 +5358,20 @@ function celu(x, alpha = 1) {
|
|
|
5036
5358
|
* @function
|
|
5037
5359
|
* Gaussion error linear unit (GELU) activation function.
|
|
5038
5360
|
*
|
|
5039
|
-
* This is computed element-wise.
|
|
5040
|
-
*
|
|
5041
|
-
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5361
|
+
* This is computed element-wise. There are two variants depending on whether
|
|
5362
|
+
* `approximate` is set (default true):
|
|
5042
5363
|
*
|
|
5043
|
-
*
|
|
5364
|
+
* - Approximate: `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`
|
|
5365
|
+
* - Exact: `gelu(x) = x * 0.5 * erfc(-x / sqrt(2))`
|
|
5044
5366
|
*
|
|
5045
|
-
*
|
|
5367
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5046
5368
|
*/
|
|
5047
|
-
const gelu = jit$1((x)
|
|
5048
|
-
|
|
5049
|
-
|
|
5050
|
-
|
|
5369
|
+
const gelu = jit$1(function gelu$1(x, opts) {
|
|
5370
|
+
if (opts?.approximate ?? true) {
|
|
5371
|
+
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
5372
|
+
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
5373
|
+
} else return x.ref.mul(.5).mul(erfc$1(negative(x.ref.mul(Math.SQRT1_2))));
|
|
5374
|
+
}, { staticArgnums: [1] });
|
|
5051
5375
|
/**
|
|
5052
5376
|
* Gated linear unit (GLU) activation function.
|
|
5053
5377
|
*
|
|
@@ -5215,8 +5539,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
5215
5539
|
const keyShape = validateKeyShape(key$1);
|
|
5216
5540
|
return randomBits(key$1.ref.slice(...keyShape.map(() => null), 0), key$1.slice(...keyShape.map(() => null), 1), shape$1);
|
|
5217
5541
|
}
|
|
5218
|
-
/**
|
|
5219
|
-
function
|
|
5542
|
+
/**
|
|
5543
|
+
* @function
|
|
5544
|
+
* Sample uniform random values in [minval, maxval) with given shape.
|
|
5545
|
+
*/
|
|
5546
|
+
const uniform = jit$1(function uniform$1(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
5220
5547
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
5221
5548
|
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
5222
5549
|
dtype: DType.Uint32,
|
|
@@ -5229,7 +5556,7 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
5229
5556
|
const rand = bitcast(float12, DType.Float32).sub(1);
|
|
5230
5557
|
if (minval === 0 && maxval === 1) return rand;
|
|
5231
5558
|
else return rand.mul(maxval - minval).add(minval);
|
|
5232
|
-
}
|
|
5559
|
+
}, { staticArgnums: [1, 2] });
|
|
5233
5560
|
/**
|
|
5234
5561
|
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5235
5562
|
*
|
|
@@ -5240,26 +5567,49 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
5240
5567
|
p = fudgeArray(p);
|
|
5241
5568
|
return uniform(key$1, shape$1).less(p);
|
|
5242
5569
|
}
|
|
5243
|
-
/**
|
|
5244
|
-
function
|
|
5570
|
+
/**
|
|
5571
|
+
* @function
|
|
5572
|
+
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
5573
|
+
*/
|
|
5574
|
+
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
5245
5575
|
const u = uniform(key$1, shape$1);
|
|
5246
5576
|
return negative(log1p(negative(u)));
|
|
5247
|
-
}
|
|
5577
|
+
}, { staticArgnums: [1] });
|
|
5248
5578
|
/**
|
|
5579
|
+
* @function
|
|
5249
5580
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5250
5581
|
*
|
|
5251
5582
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5252
5583
|
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5253
5584
|
* bitwise identical to JAX.
|
|
5254
5585
|
*/
|
|
5255
|
-
function normal(key$1, shape$1 = []) {
|
|
5586
|
+
const normal = jit$1(function normal$1(key$1, shape$1 = []) {
|
|
5256
5587
|
const [k1, k2] = split(key$1, 2);
|
|
5257
5588
|
const u1 = uniform(k1, shape$1);
|
|
5258
5589
|
const u2 = uniform(k2, shape$1);
|
|
5259
5590
|
const radius = sqrt(log1p(negative(u1)).mul(-2));
|
|
5260
5591
|
const theta = u2.mul(2 * Math.PI);
|
|
5261
5592
|
return radius.mul(cos(theta));
|
|
5262
|
-
}
|
|
5593
|
+
}, { staticArgnums: [1] });
|
|
5594
|
+
|
|
5595
|
+
//#endregion
|
|
5596
|
+
//#region src/scipy-special.ts
|
|
5597
|
+
var scipy_special_exports = {};
|
|
5598
|
+
__export(scipy_special_exports, {
|
|
5599
|
+
erf: () => erf,
|
|
5600
|
+
erfc: () => erfc,
|
|
5601
|
+
logSoftmax: () => logSoftmax,
|
|
5602
|
+
logit: () => logit,
|
|
5603
|
+
logsumexp: () => logsumexp,
|
|
5604
|
+
softmax: () => softmax
|
|
5605
|
+
});
|
|
5606
|
+
/**
|
|
5607
|
+
* @function
|
|
5608
|
+
* The logit function, `logit(p) = log(p / (1-p))`.
|
|
5609
|
+
*/
|
|
5610
|
+
const logit = jit$1(function logit$1(x) {
|
|
5611
|
+
return log(x.ref.div(subtract(1, x)));
|
|
5612
|
+
});
|
|
5263
5613
|
|
|
5264
5614
|
//#endregion
|
|
5265
5615
|
//#region src/polyfills.ts
|
|
@@ -5354,6 +5704,25 @@ async function blockUntilReady(x) {
|
|
|
5354
5704
|
await Promise.all(promises);
|
|
5355
5705
|
return x;
|
|
5356
5706
|
}
|
|
5707
|
+
/**
|
|
5708
|
+
* Transfer `x` to `device`.
|
|
5709
|
+
*
|
|
5710
|
+
* `x` may be a nested container of arrays or scalars. The resulting structure
|
|
5711
|
+
* is committed to the device.
|
|
5712
|
+
*
|
|
5713
|
+
* If `device` is not specified, this function behaves as identity if the input
|
|
5714
|
+
* is already an `Array`, otherwise it places the scalar uncommitted on the
|
|
5715
|
+
* default device.
|
|
5716
|
+
*/
|
|
5717
|
+
async function devicePut(x, device) {
|
|
5718
|
+
const [xflat, structure$1] = flatten(x);
|
|
5719
|
+
const yflat = await Promise.all(xflat.map((leaf) => {
|
|
5720
|
+
if (leaf instanceof Array$1) return device ? leaf._put(getBackend(device)) : Promise.resolve(leaf);
|
|
5721
|
+
else return Promise.resolve(array(leaf, { device }));
|
|
5722
|
+
}));
|
|
5723
|
+
return unflatten(structure$1, yflat);
|
|
5724
|
+
}
|
|
5357
5725
|
|
|
5358
5726
|
//#endregion
|
|
5359
|
-
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
5727
|
+
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
5728
|
+
//# sourceMappingURL=index.js.map
|