@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
package/dist/index.cjs
CHANGED
|
@@ -8,9 +8,9 @@ var __hasOwnProp = Object.prototype.hasOwnProperty;
|
|
|
8
8
|
var __commonJS = (cb, mod$1) => function() {
|
|
9
9
|
return mod$1 || (0, cb[__getOwnPropNames(cb)[0]])((mod$1 = { exports: {} }).exports, mod$1), mod$1.exports;
|
|
10
10
|
};
|
|
11
|
-
var __export = (target, all) => {
|
|
12
|
-
for (var name in all) __defProp(target, name, {
|
|
13
|
-
get: all[name],
|
|
11
|
+
var __export = (target, all$1) => {
|
|
12
|
+
for (var name in all$1) __defProp(target, name, {
|
|
13
|
+
get: all$1[name],
|
|
14
14
|
enumerable: true
|
|
15
15
|
});
|
|
16
16
|
};
|
|
@@ -30,30 +30,38 @@ var __toESM = (mod$1, isNodeMode, target) => (target = mod$1 != null ? __create(
|
|
|
30
30
|
}) : target, mod$1));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-Bu9GY6sK.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/frontend/convolution.ts
|
|
36
36
|
/**
|
|
37
37
|
* Check that the shapes and parameters passed to convolution are valid.
|
|
38
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
39
|
+
*
|
|
40
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
41
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
38
42
|
*
|
|
39
43
|
* If the check succeeds, returns the output shape.
|
|
40
44
|
*/
|
|
41
|
-
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
|
|
45
|
+
function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
|
|
42
46
|
if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
|
|
43
|
-
const n = lhsShape.length - 2;
|
|
47
|
+
const n = lhsShape.length - 2 - vmapDims;
|
|
44
48
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
45
49
|
if (strides.length !== n) throw new Error("conv() strides != spatial dims");
|
|
46
50
|
if (padding.length !== n) throw new Error("conv() padding != spatial dims");
|
|
47
51
|
if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
|
|
48
52
|
if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
|
|
49
|
-
if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
50
|
-
const outShape = [
|
|
53
|
+
if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
54
|
+
const outShape = [
|
|
55
|
+
...require_backend.generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
|
|
56
|
+
lhsShape[vmapDims],
|
|
57
|
+
rhsShape[vmapDims]
|
|
58
|
+
];
|
|
51
59
|
for (let i = 0; i < n; i++) {
|
|
52
60
|
if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
|
|
53
61
|
if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
|
|
54
62
|
if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
|
|
55
63
|
if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
|
|
56
|
-
const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
|
|
64
|
+
const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
|
|
57
65
|
if (k <= 0) throw new Error("conv() kernel size must be positive");
|
|
58
66
|
const [pl, pr] = padding[i];
|
|
59
67
|
if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
|
|
@@ -178,27 +186,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
178
186
|
function applyDilation(st, dilation) {
|
|
179
187
|
if (dilation.every((s) => s === 1)) return st;
|
|
180
188
|
const s_ = dilation;
|
|
181
|
-
const
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
]);
|
|
187
|
-
st = st.
|
|
188
|
-
[0, 0],
|
|
189
|
-
[0, 0],
|
|
190
|
-
...s_.flatMap((s) => [[0, 0], [0, s - 1]])
|
|
191
|
-
]);
|
|
192
|
-
st = st.reshape([
|
|
193
|
-
a,
|
|
194
|
-
b,
|
|
195
|
-
...k_.map((k, i) => k * s_[i])
|
|
196
|
-
]);
|
|
197
|
-
st = st.shrink([
|
|
198
|
-
[0, a],
|
|
199
|
-
[0, b],
|
|
200
|
-
...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
|
|
201
|
-
]);
|
|
189
|
+
const n = s_.length;
|
|
190
|
+
const prefix = st.shape.slice(0, -n);
|
|
191
|
+
const k_ = st.shape.slice(-n);
|
|
192
|
+
st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
|
|
193
|
+
st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
|
|
194
|
+
st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
|
|
195
|
+
st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
|
|
202
196
|
return st;
|
|
203
197
|
}
|
|
204
198
|
/**
|
|
@@ -208,25 +202,26 @@ function applyDilation(st, dilation) {
|
|
|
208
202
|
* beforehand using `checkConvShape()`.
|
|
209
203
|
*/
|
|
210
204
|
function prepareConv(stX, stY, params) {
|
|
211
|
-
const
|
|
205
|
+
const v = params.vmapDims;
|
|
206
|
+
const n = stX.shape.length - 2 - v;
|
|
207
|
+
const vmapShape = stX.shape.slice(0, v);
|
|
212
208
|
stX = applyDilation(stX, params.lhsDilation);
|
|
213
|
-
const ks = stY.shape.slice(2);
|
|
214
|
-
stX = stX.padOrShrink([
|
|
215
|
-
[0, 0],
|
|
216
|
-
[0, 0],
|
|
217
|
-
...params.padding
|
|
218
|
-
]);
|
|
209
|
+
const ks = stY.shape.slice(v + 2);
|
|
210
|
+
stX = stX.padOrShrink([...require_backend.rep(v + 2, [0, 0]), ...params.padding]);
|
|
219
211
|
stX = pool(stX, ks, params.strides, params.rhsDilation);
|
|
220
|
-
stX = stX.moveaxis(1, n + 1).reshape([
|
|
221
|
-
|
|
212
|
+
stX = stX.moveaxis(v + 1, v + n + 1).reshape([
|
|
213
|
+
...vmapShape,
|
|
214
|
+
stX.shape[v],
|
|
222
215
|
1,
|
|
223
|
-
...stX.shape.slice(2, n + 2),
|
|
224
|
-
stX.shape[1] * require_backend.prod(ks)
|
|
216
|
+
...stX.shape.slice(v + 2, v + n + 2),
|
|
217
|
+
stX.shape[v + 1] * require_backend.prod(ks)
|
|
225
218
|
]);
|
|
226
219
|
stY = stY.reshape([
|
|
227
|
-
|
|
220
|
+
...vmapShape,
|
|
221
|
+
1,
|
|
222
|
+
stY.shape[v],
|
|
228
223
|
...require_backend.rep(n, 1),
|
|
229
|
-
stY.shape[1] * require_backend.prod(ks)
|
|
224
|
+
stY.shape[v + 1] * require_backend.prod(ks)
|
|
230
225
|
]);
|
|
231
226
|
return [stX, stY];
|
|
232
227
|
}
|
|
@@ -367,6 +362,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
367
362
|
Primitive$1["Mul"] = "mul";
|
|
368
363
|
Primitive$1["Idiv"] = "idiv";
|
|
369
364
|
Primitive$1["Mod"] = "mod";
|
|
365
|
+
Primitive$1["Min"] = "min";
|
|
366
|
+
Primitive$1["Max"] = "max";
|
|
370
367
|
Primitive$1["Neg"] = "neg";
|
|
371
368
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
372
369
|
Primitive$1["Floor"] = "floor";
|
|
@@ -374,7 +371,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
374
371
|
Primitive$1["StopGradient"] = "stop_gradient";
|
|
375
372
|
Primitive$1["Cast"] = "cast";
|
|
376
373
|
Primitive$1["Bitcast"] = "bitcast";
|
|
377
|
-
Primitive$1["RandomBits"] = "random_bits";
|
|
378
374
|
Primitive$1["Sin"] = "sin";
|
|
379
375
|
Primitive$1["Cos"] = "cos";
|
|
380
376
|
Primitive$1["Asin"] = "asin";
|
|
@@ -384,8 +380,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
384
380
|
Primitive$1["Erf"] = "erf";
|
|
385
381
|
Primitive$1["Erfc"] = "erfc";
|
|
386
382
|
Primitive$1["Sqrt"] = "sqrt";
|
|
387
|
-
Primitive$1["Min"] = "min";
|
|
388
|
-
Primitive$1["Max"] = "max";
|
|
389
383
|
Primitive$1["Reduce"] = "reduce";
|
|
390
384
|
Primitive$1["Dot"] = "dot";
|
|
391
385
|
Primitive$1["Conv"] = "conv";
|
|
@@ -393,14 +387,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
393
387
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
394
388
|
Primitive$1["Compare"] = "compare";
|
|
395
389
|
Primitive$1["Where"] = "where";
|
|
390
|
+
Primitive$1["RandomBits"] = "random_bits";
|
|
391
|
+
Primitive$1["Gather"] = "gather";
|
|
396
392
|
Primitive$1["Transpose"] = "transpose";
|
|
397
393
|
Primitive$1["Broadcast"] = "broadcast";
|
|
398
394
|
Primitive$1["Reshape"] = "reshape";
|
|
399
395
|
Primitive$1["Flip"] = "flip";
|
|
400
396
|
Primitive$1["Shrink"] = "shrink";
|
|
401
397
|
Primitive$1["Pad"] = "pad";
|
|
402
|
-
Primitive$1["
|
|
403
|
-
Primitive$1["
|
|
398
|
+
Primitive$1["Sort"] = "sort";
|
|
399
|
+
Primitive$1["Argsort"] = "argsort";
|
|
400
|
+
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
401
|
+
Primitive$1["Cholesky"] = "cholesky";
|
|
402
|
+
Primitive$1["Jit"] = "jit";
|
|
404
403
|
return Primitive$1;
|
|
405
404
|
}({});
|
|
406
405
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
@@ -422,6 +421,12 @@ function idiv(x, y) {
|
|
|
422
421
|
function mod(x, y) {
|
|
423
422
|
return bind1(Primitive.Mod, [x, y]);
|
|
424
423
|
}
|
|
424
|
+
function min$1(x, y) {
|
|
425
|
+
return bind1(Primitive.Min, [x, y]);
|
|
426
|
+
}
|
|
427
|
+
function max$1(x, y) {
|
|
428
|
+
return bind1(Primitive.Max, [x, y]);
|
|
429
|
+
}
|
|
425
430
|
function neg(x) {
|
|
426
431
|
return bind1(Primitive.Neg, [x]);
|
|
427
432
|
}
|
|
@@ -443,12 +448,6 @@ function cast(x, dtype) {
|
|
|
443
448
|
function bitcast(x, dtype) {
|
|
444
449
|
return bind1(Primitive.Bitcast, [x], { dtype });
|
|
445
450
|
}
|
|
446
|
-
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
447
|
-
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
448
|
-
shape: shape$1,
|
|
449
|
-
mode
|
|
450
|
-
});
|
|
451
|
-
}
|
|
452
451
|
function sin$1(x) {
|
|
453
452
|
return bind1(Primitive.Sin, [x]);
|
|
454
453
|
}
|
|
@@ -476,12 +475,6 @@ function erfc$1(x) {
|
|
|
476
475
|
function sqrt$1(x) {
|
|
477
476
|
return bind1(Primitive.Sqrt, [x]);
|
|
478
477
|
}
|
|
479
|
-
function min$1(x, y) {
|
|
480
|
-
return bind1(Primitive.Min, [x, y]);
|
|
481
|
-
}
|
|
482
|
-
function max$1(x, y) {
|
|
483
|
-
return bind1(Primitive.Max, [x, y]);
|
|
484
|
-
}
|
|
485
478
|
function reduce(x, op, axis = null, opts) {
|
|
486
479
|
if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
487
480
|
axis = require_backend.normalizeAxis(axis, ndim$1(x));
|
|
@@ -498,9 +491,11 @@ function dot$2(x, y) {
|
|
|
498
491
|
}
|
|
499
492
|
function conv$1(x, y, params = {}) {
|
|
500
493
|
if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
|
|
501
|
-
const
|
|
494
|
+
const vmapDims = params.vmapDims ?? 0;
|
|
495
|
+
const n = x.ndim - 2 - vmapDims;
|
|
502
496
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
503
497
|
return bind1(Primitive.Conv, [x, y], {
|
|
498
|
+
vmapDims,
|
|
504
499
|
strides: params.strides ?? require_backend.rep(n, 1),
|
|
505
500
|
padding: params.padding ?? require_backend.rep(n, [0, 0]),
|
|
506
501
|
lhsDilation: params.lhsDilation ?? require_backend.rep(n, 1),
|
|
@@ -535,6 +530,23 @@ function where$1(cond, x, y) {
|
|
|
535
530
|
y
|
|
536
531
|
]);
|
|
537
532
|
}
|
|
533
|
+
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
534
|
+
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
535
|
+
shape: shape$1,
|
|
536
|
+
mode
|
|
537
|
+
});
|
|
538
|
+
}
|
|
539
|
+
function gather(x, indices, axis, outDim) {
|
|
540
|
+
if (indices.length === 0) throw new Error("gather() requires at least one index");
|
|
541
|
+
if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
|
|
542
|
+
axis = axis.map((a) => require_backend.checkAxis(a, ndim$1(x)));
|
|
543
|
+
if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
|
|
544
|
+
outDim = require_backend.checkAxis(outDim, ndim$1(x) - axis.length + 1);
|
|
545
|
+
return bind1(Primitive.Gather, [x, ...indices], {
|
|
546
|
+
axis,
|
|
547
|
+
outDim
|
|
548
|
+
});
|
|
549
|
+
}
|
|
538
550
|
function transpose$1(x, perm) {
|
|
539
551
|
perm = perm ? perm.map((a) => require_backend.checkAxis(a, ndim$1(x))) : require_backend.range(ndim$1(x)).reverse();
|
|
540
552
|
if (!require_backend.isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
@@ -584,16 +596,27 @@ function pad$1(x, width) {
|
|
|
584
596
|
} else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
|
|
585
597
|
return bind1(Primitive.Pad, [x], { width });
|
|
586
598
|
}
|
|
587
|
-
function
|
|
588
|
-
if (
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
599
|
+
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
600
|
+
if (lower) {
|
|
601
|
+
a = flip$1(a, [-2, -1]);
|
|
602
|
+
b = flip$1(b, [-1]);
|
|
603
|
+
}
|
|
604
|
+
let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
605
|
+
if (lower) x = flip$1(x, [-1]);
|
|
606
|
+
return x;
|
|
607
|
+
}
|
|
608
|
+
function cholesky$2(x) {
|
|
609
|
+
return bind1(Primitive.Cholesky, [x]);
|
|
610
|
+
}
|
|
611
|
+
function sort$1(x) {
|
|
612
|
+
const nd = ndim$1(x);
|
|
613
|
+
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
614
|
+
return bind1(Primitive.Sort, [x]);
|
|
615
|
+
}
|
|
616
|
+
function argsort$1(x) {
|
|
617
|
+
const nd = ndim$1(x);
|
|
618
|
+
if (nd === 0) throw new Error("argsort: requires at least 1D input");
|
|
619
|
+
return bind(Primitive.Argsort, [x]);
|
|
597
620
|
}
|
|
598
621
|
function bind1(prim, args, params = {}) {
|
|
599
622
|
const [results] = bind(prim, args, params);
|
|
@@ -724,8 +747,10 @@ var Tracer = class Tracer {
|
|
|
724
747
|
axis = require_backend.normalizeAxis(axis, this.ndim);
|
|
725
748
|
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
726
749
|
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
727
|
-
const
|
|
728
|
-
|
|
750
|
+
const originalDtype = this.dtype;
|
|
751
|
+
const castDtype = require_backend.promoteTypes(originalDtype, require_backend.DType.Float32);
|
|
752
|
+
const result = reduce(this.astype(castDtype), require_backend.AluOp.Add, axis, opts);
|
|
753
|
+
return result.mul(1 / n).astype(originalDtype);
|
|
729
754
|
}
|
|
730
755
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
731
756
|
transpose(perm) {
|
|
@@ -754,7 +779,7 @@ var Tracer = class Tracer {
|
|
|
754
779
|
if (require_backend.isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
|
|
755
780
|
return idiv(this, other);
|
|
756
781
|
}
|
|
757
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
782
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
758
783
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
759
784
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
760
785
|
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
@@ -807,6 +832,34 @@ var Tracer = class Tracer {
|
|
|
807
832
|
this.dispose();
|
|
808
833
|
}
|
|
809
834
|
/**
|
|
835
|
+
* Return a sorted copy of an array in ascending order.
|
|
836
|
+
*
|
|
837
|
+
* See `jax.numpy.sort` for full docs.
|
|
838
|
+
*/
|
|
839
|
+
sort(axis = -1) {
|
|
840
|
+
axis = require_backend.checkAxis(axis, this.ndim);
|
|
841
|
+
if (this.shape[axis] <= 1) return this;
|
|
842
|
+
if (axis === this.ndim - 1) return sort$1(this);
|
|
843
|
+
const perm = require_backend.range(this.ndim);
|
|
844
|
+
perm.splice(axis, 1);
|
|
845
|
+
perm.push(axis);
|
|
846
|
+
return sort$1(this.transpose(perm)).transpose(require_backend.invertPermutation(perm));
|
|
847
|
+
}
|
|
848
|
+
/**
|
|
849
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
850
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
851
|
+
*
|
|
852
|
+
* See `jax.numpy.argsort` for full docs.
|
|
853
|
+
*/
|
|
854
|
+
argsort(axis = -1) {
|
|
855
|
+
axis = require_backend.checkAxis(axis, this.ndim);
|
|
856
|
+
if (axis === this.ndim - 1) return argsort$1(this)[1];
|
|
857
|
+
const perm = require_backend.range(this.ndim);
|
|
858
|
+
perm.splice(axis, 1);
|
|
859
|
+
perm.push(axis);
|
|
860
|
+
return argsort$1(this.transpose(perm))[1].transpose(require_backend.invertPermutation(perm));
|
|
861
|
+
}
|
|
862
|
+
/**
|
|
810
863
|
* Slice an array along one or more axes.
|
|
811
864
|
*
|
|
812
865
|
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
@@ -923,6 +976,9 @@ var ShapedArray = class ShapedArray {
|
|
|
923
976
|
get ndim() {
|
|
924
977
|
return this.shape.length;
|
|
925
978
|
}
|
|
979
|
+
get size() {
|
|
980
|
+
return require_backend.prod(this.shape);
|
|
981
|
+
}
|
|
926
982
|
toString() {
|
|
927
983
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
928
984
|
}
|
|
@@ -1205,7 +1261,7 @@ var Jaxpr = class Jaxpr {
|
|
|
1205
1261
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
1206
1262
|
const [a, b] = inputs;
|
|
1207
1263
|
const c = eqn.outBinders[0];
|
|
1208
|
-
if (atomIsLit(b, 1)) context.set(c, a);
|
|
1264
|
+
if (atomIsLit(b, 1) && !require_backend.isFloatDtype(a.aval.dtype)) context.set(c, a);
|
|
1209
1265
|
else newEqns.push(eqn);
|
|
1210
1266
|
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
1211
1267
|
else newEqns.push(eqn);
|
|
@@ -1222,13 +1278,13 @@ var Jaxpr = class Jaxpr {
|
|
|
1222
1278
|
}
|
|
1223
1279
|
return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
|
|
1224
1280
|
}
|
|
1225
|
-
/** Flattens nested
|
|
1281
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1226
1282
|
flatten() {
|
|
1227
|
-
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.
|
|
1283
|
+
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
|
|
1228
1284
|
const newEqns = [];
|
|
1229
1285
|
const varMap = /* @__PURE__ */ new Map();
|
|
1230
1286
|
const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
|
|
1231
|
-
for (const eqn of this.eqns) if (eqn.primitive === Primitive.
|
|
1287
|
+
for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
|
|
1232
1288
|
const jaxpr = eqn.params.jaxpr.flatten();
|
|
1233
1289
|
const translation = /* @__PURE__ */ new Map();
|
|
1234
1290
|
const translationF = (x) => x instanceof Var ? translation.get(x) : x;
|
|
@@ -1329,19 +1385,48 @@ function evalJaxpr(jaxpr, args) {
|
|
|
1329
1385
|
function jaxprAsFun(jaxpr) {
|
|
1330
1386
|
return (...args) => evalJaxpr(jaxpr, args);
|
|
1331
1387
|
}
|
|
1388
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1389
|
+
var ClosedJaxpr = class ClosedJaxpr {
|
|
1390
|
+
constructor(jaxpr, consts) {
|
|
1391
|
+
this.jaxpr = jaxpr;
|
|
1392
|
+
this.consts = consts;
|
|
1393
|
+
}
|
|
1394
|
+
/** String representation of this Jaxpr. */
|
|
1395
|
+
toString() {
|
|
1396
|
+
return this.jaxpr.toString();
|
|
1397
|
+
}
|
|
1398
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1399
|
+
mapJaxpr(f) {
|
|
1400
|
+
return new ClosedJaxpr(f(this.jaxpr), this.consts);
|
|
1401
|
+
}
|
|
1402
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1403
|
+
dispose() {
|
|
1404
|
+
for (const c of this.consts) c.dispose();
|
|
1405
|
+
}
|
|
1406
|
+
};
|
|
1332
1407
|
/** Tracer that records its operations to dynamically construct a Jaxpr. */
|
|
1333
1408
|
var JaxprTracer = class extends Tracer {
|
|
1409
|
+
#rc;
|
|
1334
1410
|
constructor(trace$1, aval) {
|
|
1335
1411
|
super(trace$1);
|
|
1336
1412
|
this.aval = aval;
|
|
1413
|
+
this.#rc = 1;
|
|
1337
1414
|
}
|
|
1338
1415
|
toString() {
|
|
1339
1416
|
return `JaxprTracer(${this.aval.toString()})`;
|
|
1340
1417
|
}
|
|
1341
1418
|
get ref() {
|
|
1419
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1420
|
+
this.#rc++;
|
|
1342
1421
|
return this;
|
|
1343
1422
|
}
|
|
1344
|
-
dispose() {
|
|
1423
|
+
dispose() {
|
|
1424
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1425
|
+
this.#rc--;
|
|
1426
|
+
}
|
|
1427
|
+
trackLiftedConstant() {
|
|
1428
|
+
this.#rc++;
|
|
1429
|
+
}
|
|
1345
1430
|
};
|
|
1346
1431
|
/** Analogous to the 'DynamicJaxprTrace' class in JAX. */
|
|
1347
1432
|
var JaxprTrace = class extends Trace {
|
|
@@ -1354,17 +1439,24 @@ var JaxprTrace = class extends Trace {
|
|
|
1354
1439
|
}
|
|
1355
1440
|
/** Register a constant / literal in this Jaxpr. */
|
|
1356
1441
|
getOrMakeConstTracer(val) {
|
|
1442
|
+
if (!(val instanceof Tracer)) val = pureArray(val);
|
|
1357
1443
|
let tracer = this.builder.constTracers.get(val);
|
|
1358
1444
|
if (tracer === void 0) {
|
|
1359
1445
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
1360
|
-
this.builder.addConst(tracer, val
|
|
1446
|
+
this.builder.addConst(tracer, val);
|
|
1447
|
+
} else {
|
|
1448
|
+
val.dispose();
|
|
1449
|
+
tracer.trackLiftedConstant();
|
|
1361
1450
|
}
|
|
1362
1451
|
return tracer;
|
|
1363
1452
|
}
|
|
1364
1453
|
pure = this.getOrMakeConstTracer;
|
|
1365
1454
|
lift = this.getOrMakeConstTracer;
|
|
1366
1455
|
processPrimitive(primitive, tracers, params) {
|
|
1367
|
-
const avalsIn = tracers.map((t) =>
|
|
1456
|
+
const avalsIn = tracers.map((t) => {
|
|
1457
|
+
t.dispose();
|
|
1458
|
+
return t.aval;
|
|
1459
|
+
});
|
|
1368
1460
|
const avalsOut = abstractEvalRules[primitive](avalsIn, params);
|
|
1369
1461
|
const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
|
|
1370
1462
|
this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
|
|
@@ -1407,20 +1499,17 @@ var JaxprBuilder = class {
|
|
|
1407
1499
|
return v;
|
|
1408
1500
|
}
|
|
1409
1501
|
build(inTracers, outTracers) {
|
|
1410
|
-
|
|
1502
|
+
const [constVars, consts] = require_backend.unzip2(this.constVals.entries());
|
|
1411
1503
|
const t2v = this.getVar.bind(this);
|
|
1412
1504
|
const inBinders = [...constVars, ...inTracers.map(t2v)];
|
|
1413
1505
|
const outVars = outTracers.map(t2v);
|
|
1414
|
-
|
|
1506
|
+
const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
|
|
1415
1507
|
typecheckJaxpr(jaxpr);
|
|
1416
|
-
|
|
1417
|
-
return
|
|
1418
|
-
jaxpr,
|
|
1419
|
-
consts
|
|
1420
|
-
};
|
|
1508
|
+
const cjaxpr = new ClosedJaxpr(jaxpr, consts);
|
|
1509
|
+
return _inlineLiterals(cjaxpr);
|
|
1421
1510
|
}
|
|
1422
1511
|
};
|
|
1423
|
-
function _inlineLiterals(jaxpr, consts) {
|
|
1512
|
+
function _inlineLiterals({ jaxpr, consts }) {
|
|
1424
1513
|
const literals = /* @__PURE__ */ new Map();
|
|
1425
1514
|
const constBinders = [];
|
|
1426
1515
|
const newConsts = [];
|
|
@@ -1435,7 +1524,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
1435
1524
|
const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
|
|
1436
1525
|
const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
|
|
1437
1526
|
typecheckJaxpr(newJaxpr);
|
|
1438
|
-
return
|
|
1527
|
+
return new ClosedJaxpr(newJaxpr, newConsts);
|
|
1439
1528
|
}
|
|
1440
1529
|
function binopAbstractEval([x, y]) {
|
|
1441
1530
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
@@ -1454,6 +1543,8 @@ const abstractEvalRules = {
|
|
|
1454
1543
|
[Primitive.Mul]: binopAbstractEval,
|
|
1455
1544
|
[Primitive.Idiv]: binopAbstractEval,
|
|
1456
1545
|
[Primitive.Mod]: binopAbstractEval,
|
|
1546
|
+
[Primitive.Min]: binopAbstractEval,
|
|
1547
|
+
[Primitive.Max]: binopAbstractEval,
|
|
1457
1548
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1458
1549
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1459
1550
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1467,12 +1558,6 @@ const abstractEvalRules = {
|
|
|
1467
1558
|
if (require_backend.byteWidth(x.dtype) !== require_backend.byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1468
1559
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1469
1560
|
},
|
|
1470
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1471
|
-
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1472
|
-
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1473
|
-
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1474
|
-
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1475
|
-
},
|
|
1476
1561
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
1477
1562
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
1478
1563
|
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
@@ -1482,8 +1567,6 @@ const abstractEvalRules = {
|
|
|
1482
1567
|
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
1483
1568
|
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
1484
1569
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
1485
|
-
[Primitive.Min]: binopAbstractEval,
|
|
1486
|
-
[Primitive.Max]: binopAbstractEval,
|
|
1487
1570
|
[Primitive.Reduce]([x], { axis }) {
|
|
1488
1571
|
const axisSet = new Set(axis);
|
|
1489
1572
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1516,6 +1599,25 @@ const abstractEvalRules = {
|
|
|
1516
1599
|
const shape$1 = require_backend.generalBroadcast(cond.shape, xy.shape);
|
|
1517
1600
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1518
1601
|
},
|
|
1602
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1603
|
+
if (k0.dtype !== require_backend.DType.Uint32 || k1.dtype !== require_backend.DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1604
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
1605
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1606
|
+
return [new ShapedArray(shape$1, require_backend.DType.Uint32, false)];
|
|
1607
|
+
},
|
|
1608
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
1609
|
+
for (const a of indices) if (a.dtype !== require_backend.DType.Int32 && a.dtype !== require_backend.DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
1610
|
+
if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
|
|
1611
|
+
if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
|
|
1612
|
+
if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
|
|
1613
|
+
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
1614
|
+
const axisSet = new Set(axis);
|
|
1615
|
+
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
1616
|
+
const gatherShape = indices.reduce((shape$1, a) => require_backend.generalBroadcast(shape$1, a.shape), []);
|
|
1617
|
+
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
1618
|
+
newShape.splice(outDim, 0, ...gatherShape);
|
|
1619
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1620
|
+
},
|
|
1519
1621
|
[Primitive.Transpose]([x], { perm }) {
|
|
1520
1622
|
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
1521
1623
|
},
|
|
@@ -1536,23 +1638,31 @@ const abstractEvalRules = {
|
|
|
1536
1638
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
1537
1639
|
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1538
1640
|
},
|
|
1539
|
-
[Primitive.
|
|
1540
|
-
|
|
1541
|
-
|
|
1542
|
-
|
|
1543
|
-
|
|
1544
|
-
if (
|
|
1545
|
-
|
|
1546
|
-
|
|
1547
|
-
|
|
1548
|
-
|
|
1549
|
-
|
|
1550
|
-
|
|
1641
|
+
[Primitive.Sort]([x]) {
|
|
1642
|
+
if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
|
|
1643
|
+
return [ShapedArray.fromAval(x)];
|
|
1644
|
+
},
|
|
1645
|
+
[Primitive.Argsort]([x]) {
|
|
1646
|
+
if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
|
|
1647
|
+
return [ShapedArray.fromAval(x), new ShapedArray(x.shape, require_backend.DType.Int32, false)];
|
|
1648
|
+
},
|
|
1649
|
+
[Primitive.TriangularSolve]([a, b]) {
|
|
1650
|
+
if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
|
|
1651
|
+
if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
|
|
1652
|
+
const [m, n] = a.shape.slice(-2);
|
|
1653
|
+
const [_batch, q] = b.shape.slice(-2);
|
|
1654
|
+
if (!require_backend.deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
|
|
1655
|
+
return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
|
|
1656
|
+
},
|
|
1657
|
+
[Primitive.Cholesky]([a]) {
|
|
1658
|
+
if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
|
|
1659
|
+
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1660
|
+
return [ShapedArray.fromAval(a)];
|
|
1551
1661
|
},
|
|
1552
|
-
[Primitive.
|
|
1662
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
1553
1663
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1554
|
-
if (args.length !== inTypes.length) throw new TypeError(`
|
|
1555
|
-
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`
|
|
1664
|
+
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
1665
|
+
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
|
|
1556
1666
|
return outTypes;
|
|
1557
1667
|
}
|
|
1558
1668
|
};
|
|
@@ -1588,11 +1698,10 @@ function makeJaxpr$1(f, opts) {
|
|
|
1588
1698
|
const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
1589
1699
|
const outs = fFlat(...tracersIn);
|
|
1590
1700
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
1591
|
-
const
|
|
1701
|
+
const jaxpr = builder.build(tracersIn, tracersOut);
|
|
1592
1702
|
if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
|
|
1593
1703
|
return {
|
|
1594
|
-
jaxpr: jaxpr.simplify(),
|
|
1595
|
-
consts,
|
|
1704
|
+
jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
|
|
1596
1705
|
treedef: outTree.value
|
|
1597
1706
|
};
|
|
1598
1707
|
} catch (_) {
|
|
@@ -1611,22 +1720,28 @@ function jit$1(f, opts) {
|
|
|
1611
1720
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
1612
1721
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
1613
1722
|
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
1614
|
-
const { jaxpr,
|
|
1615
|
-
const outs = bind(Primitive.
|
|
1723
|
+
const { jaxpr, treedef: outTree } = require_backend.runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
1724
|
+
const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
|
|
1616
1725
|
name: f.name || "closure",
|
|
1617
|
-
jaxpr,
|
|
1618
|
-
numConsts: consts.length
|
|
1726
|
+
jaxpr: jaxpr.jaxpr,
|
|
1727
|
+
numConsts: jaxpr.consts.length
|
|
1619
1728
|
});
|
|
1620
1729
|
return unflatten(outTree, outs);
|
|
1621
1730
|
});
|
|
1622
1731
|
result.dispose = () => {
|
|
1623
|
-
for (const {
|
|
1732
|
+
for (const { jaxpr } of cache.values()) jaxpr.dispose();
|
|
1624
1733
|
};
|
|
1625
1734
|
return result;
|
|
1626
1735
|
}
|
|
1627
1736
|
|
|
1628
1737
|
//#endregion
|
|
1629
1738
|
//#region src/frontend/jit.ts
|
|
1739
|
+
const routinePrimitives = new Map([
|
|
1740
|
+
[Primitive.Sort, require_backend.Routines.Sort],
|
|
1741
|
+
[Primitive.Argsort, require_backend.Routines.Argsort],
|
|
1742
|
+
[Primitive.TriangularSolve, require_backend.Routines.TriangularSolve],
|
|
1743
|
+
[Primitive.Cholesky, require_backend.Routines.Cholesky]
|
|
1744
|
+
]);
|
|
1630
1745
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1631
1746
|
var JitProgram = class {
|
|
1632
1747
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1641,9 +1756,14 @@ var JitProgram = class {
|
|
|
1641
1756
|
case "execute": {
|
|
1642
1757
|
const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
|
|
1643
1758
|
const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
|
|
1644
|
-
|
|
1759
|
+
const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
|
|
1760
|
+
if (step.source instanceof require_backend.Kernel) return require_backend.PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
|
|
1761
|
+
else if (step.source instanceof require_backend.Routine) return require_backend.PPrint.pp(`${executeText}, routine ${step.source.name}`);
|
|
1762
|
+
else {
|
|
1763
|
+
step.source;
|
|
1764
|
+
return require_backend.PPrint.pp(executeText);
|
|
1765
|
+
}
|
|
1645
1766
|
}
|
|
1646
|
-
case "const": return require_backend.PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
|
|
1647
1767
|
case "malloc": return require_backend.PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
|
|
1648
1768
|
case "incref": return require_backend.PPrint.pp(`incref ${step.input}`);
|
|
1649
1769
|
case "free": return require_backend.PPrint.pp(`free ${step.input}`);
|
|
@@ -1666,12 +1786,9 @@ var JitProgram = class {
|
|
|
1666
1786
|
const inputs$1 = step.inputs.map((id) => scope.get(id));
|
|
1667
1787
|
const outputs = step.outputs.map((id) => scope.get(id));
|
|
1668
1788
|
if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
|
|
1669
|
-
pending.push(new PendingExecute(this.backend, step.
|
|
1789
|
+
pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
|
|
1670
1790
|
break;
|
|
1671
1791
|
}
|
|
1672
|
-
case "const":
|
|
1673
|
-
scope.set(step.output, step.slot);
|
|
1674
|
-
break;
|
|
1675
1792
|
case "malloc": {
|
|
1676
1793
|
const slot = this.backend.malloc(step.size);
|
|
1677
1794
|
scope.set(step.output, slot);
|
|
@@ -1705,34 +1822,37 @@ var JitProgramBuilder = class {
|
|
|
1705
1822
|
this.#nextId = nargs;
|
|
1706
1823
|
this.steps = [];
|
|
1707
1824
|
}
|
|
1708
|
-
pushConst(slot) {
|
|
1709
|
-
const id = this.#nextId++;
|
|
1710
|
-
this.steps.push({
|
|
1711
|
-
type: "const",
|
|
1712
|
-
slot,
|
|
1713
|
-
output: id
|
|
1714
|
-
});
|
|
1715
|
-
return id;
|
|
1716
|
-
}
|
|
1717
1825
|
pushLit(lit) {
|
|
1718
|
-
const kernel = new require_backend.Kernel(0,
|
|
1826
|
+
const kernel = new require_backend.Kernel(0, lit.aval.size, require_backend.AluExp.const(lit.dtype, lit.value));
|
|
1719
1827
|
return this.pushKernel(kernel, []);
|
|
1720
1828
|
}
|
|
1721
|
-
|
|
1829
|
+
pushBuffer(size$1) {
|
|
1722
1830
|
const id = this.#nextId++;
|
|
1723
1831
|
this.steps.push({
|
|
1724
1832
|
type: "malloc",
|
|
1725
|
-
size:
|
|
1833
|
+
size: size$1,
|
|
1726
1834
|
output: id
|
|
1727
1835
|
});
|
|
1836
|
+
return id;
|
|
1837
|
+
}
|
|
1838
|
+
pushKernel(kernel, inputs) {
|
|
1839
|
+
const id = this.pushBuffer(kernel.bytes);
|
|
1728
1840
|
this.steps.push({
|
|
1729
1841
|
type: "execute",
|
|
1730
|
-
kernel,
|
|
1842
|
+
source: kernel,
|
|
1731
1843
|
inputs,
|
|
1732
1844
|
outputs: [id]
|
|
1733
1845
|
});
|
|
1734
1846
|
return id;
|
|
1735
1847
|
}
|
|
1848
|
+
pushRoutine(routine, inputs, outputs) {
|
|
1849
|
+
this.steps.push({
|
|
1850
|
+
type: "execute",
|
|
1851
|
+
source: routine,
|
|
1852
|
+
inputs,
|
|
1853
|
+
outputs
|
|
1854
|
+
});
|
|
1855
|
+
}
|
|
1736
1856
|
pushIncref(id) {
|
|
1737
1857
|
this.steps.push({
|
|
1738
1858
|
type: "incref",
|
|
@@ -1758,28 +1878,18 @@ var JitProgramBuilder = class {
|
|
|
1758
1878
|
}
|
|
1759
1879
|
};
|
|
1760
1880
|
const jitCompileCache = /* @__PURE__ */ new Map();
|
|
1761
|
-
function jitCompile(backend, jaxpr
|
|
1762
|
-
|
|
1763
|
-
for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
|
|
1764
|
-
const cacheKey = backend.type + require_backend.FpHash.hash(jaxpr, ...consts.map((c) => c.id));
|
|
1881
|
+
function jitCompile(backend, jaxpr) {
|
|
1882
|
+
const cacheKey = backend.type + "," + require_backend.FpHash.hash(jaxpr);
|
|
1765
1883
|
const cached = jitCompileCache.get(cacheKey);
|
|
1766
1884
|
if (cached) return cached;
|
|
1767
1885
|
if (require_backend.DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
|
|
1768
1886
|
jaxpr = jaxpr.flatten().simplify();
|
|
1769
|
-
const nargs = jaxpr.inBinders.length
|
|
1887
|
+
const nargs = jaxpr.inBinders.length;
|
|
1770
1888
|
const builder = new JitProgramBuilder(backend, nargs);
|
|
1771
1889
|
const blackNodes = splitGraphDataflow(backend, jaxpr);
|
|
1772
1890
|
const ctx = /* @__PURE__ */ new Map();
|
|
1773
|
-
for (let i = 0; i < consts.length; i++) {
|
|
1774
|
-
const v = jaxpr.inBinders[i];
|
|
1775
|
-
const slot = consts[i]._realizeSource();
|
|
1776
|
-
ctx.set(v, {
|
|
1777
|
-
type: "imm",
|
|
1778
|
-
arg: builder.pushConst(slot)
|
|
1779
|
-
});
|
|
1780
|
-
}
|
|
1781
1891
|
for (let i = 0; i < nargs; i++) {
|
|
1782
|
-
const v = jaxpr.inBinders[
|
|
1892
|
+
const v = jaxpr.inBinders[i];
|
|
1783
1893
|
ctx.set(v, {
|
|
1784
1894
|
type: "imm",
|
|
1785
1895
|
arg: i
|
|
@@ -1787,51 +1897,101 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1787
1897
|
}
|
|
1788
1898
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1789
1899
|
const eqn = jaxpr.eqns[i];
|
|
1900
|
+
if (routinePrimitives.has(eqn.primitive)) {
|
|
1901
|
+
const routine = new require_backend.Routine(routinePrimitives.get(eqn.primitive), {
|
|
1902
|
+
inputShapes: eqn.inputs.map((x) => x.aval.shape),
|
|
1903
|
+
inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
|
|
1904
|
+
outputShapes: eqn.outBinders.map((x) => x.aval.shape),
|
|
1905
|
+
outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
|
|
1906
|
+
}, eqn.params);
|
|
1907
|
+
const inputs = [];
|
|
1908
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1909
|
+
const jv = ctx.get(input);
|
|
1910
|
+
if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
|
|
1911
|
+
inputs.push(jv.arg);
|
|
1912
|
+
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1913
|
+
const outputs = [];
|
|
1914
|
+
for (const outVar$1 of eqn.outBinders) {
|
|
1915
|
+
const outId = builder.pushBuffer(outVar$1.aval.size * require_backend.byteWidth(outVar$1.aval.dtype));
|
|
1916
|
+
outputs.push(outId);
|
|
1917
|
+
ctx.set(outVar$1, {
|
|
1918
|
+
type: "imm",
|
|
1919
|
+
arg: outId
|
|
1920
|
+
});
|
|
1921
|
+
}
|
|
1922
|
+
builder.pushRoutine(routine, inputs, outputs);
|
|
1923
|
+
continue;
|
|
1924
|
+
}
|
|
1790
1925
|
const inputExps = [];
|
|
1791
1926
|
const inputAvals = [];
|
|
1792
1927
|
const inputArgs = [];
|
|
1793
|
-
|
|
1794
|
-
|
|
1795
|
-
|
|
1796
|
-
|
|
1797
|
-
|
|
1798
|
-
|
|
1799
|
-
|
|
1800
|
-
|
|
1801
|
-
inputArgs.push(jitId);
|
|
1802
|
-
}
|
|
1803
|
-
gidMap.set(gid, newGid);
|
|
1804
|
-
}
|
|
1805
|
-
inputExps.push(jitValue.exp.reindexGids(gidMap));
|
|
1806
|
-
} else if (jitValue.type === "imm") {
|
|
1807
|
-
let gid = inputArgs.indexOf(jitValue.arg);
|
|
1808
|
-
if (gid === -1) {
|
|
1809
|
-
gid = inputArgs.length;
|
|
1810
|
-
inputArgs.push(jitValue.arg);
|
|
1928
|
+
let inputReduction = null;
|
|
1929
|
+
const addArgs = (args) => {
|
|
1930
|
+
const newGids = [];
|
|
1931
|
+
for (const jitId of args) {
|
|
1932
|
+
let newGid = inputArgs.indexOf(jitId);
|
|
1933
|
+
if (newGid === -1) {
|
|
1934
|
+
newGid = inputArgs.length;
|
|
1935
|
+
inputArgs.push(jitId);
|
|
1811
1936
|
}
|
|
1937
|
+
newGids.push(newGid);
|
|
1938
|
+
}
|
|
1939
|
+
return newGids;
|
|
1940
|
+
};
|
|
1941
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1942
|
+
const jv = ctx.get(input);
|
|
1943
|
+
if (jv.type === "exp") {
|
|
1944
|
+
const newGids = addArgs(jv.args);
|
|
1945
|
+
inputExps.push(jv.exp.reindexGids(newGids));
|
|
1946
|
+
} else if (jv.type === "imm") {
|
|
1947
|
+
const [gid] = addArgs([jv.arg]);
|
|
1812
1948
|
const st = require_backend.ShapeTracker.fromShape(input.aval.shape);
|
|
1813
1949
|
const indices = require_backend.unravelAlu(st.shape, require_backend.AluVar.gidx);
|
|
1814
1950
|
inputExps.push(require_backend.AluExp.globalView(input.aval.dtype, gid, st, indices));
|
|
1951
|
+
} else if (jv.type === "red") {
|
|
1952
|
+
if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
|
|
1953
|
+
const newGids = addArgs(jv.args);
|
|
1954
|
+
inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
|
|
1955
|
+
inputReduction = jv;
|
|
1815
1956
|
}
|
|
1816
1957
|
inputAvals.push(input.aval);
|
|
1817
1958
|
} else if (input instanceof Lit) {
|
|
1818
1959
|
inputExps.push(require_backend.AluExp.const(input.dtype, input.value));
|
|
1819
1960
|
inputAvals.push(input.aval);
|
|
1820
1961
|
} else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
|
|
1821
|
-
const nargs$1 = inputArgs.length;
|
|
1822
1962
|
const rule = jitRules[eqn.primitive];
|
|
1823
1963
|
if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
|
|
1824
|
-
|
|
1964
|
+
let exp$2;
|
|
1965
|
+
let reduction;
|
|
1966
|
+
if (inputReduction) {
|
|
1967
|
+
const jv = inputReduction;
|
|
1968
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1969
|
+
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
1970
|
+
reduction = new require_backend.Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1971
|
+
} else {
|
|
1972
|
+
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1973
|
+
exp$2 = ruleOutput.exp;
|
|
1974
|
+
reduction = ruleOutput.reduction;
|
|
1975
|
+
}
|
|
1825
1976
|
const outVar = eqn.outBinders[0];
|
|
1826
|
-
if (
|
|
1977
|
+
if (blackNodes.has(outVar)) {
|
|
1978
|
+
const nargs$1 = inputArgs.length;
|
|
1979
|
+
const size$1 = outVar.aval.size;
|
|
1980
|
+
const kernel = new require_backend.Kernel(nargs$1, size$1, exp$2, reduction);
|
|
1827
1981
|
const outId = builder.pushKernel(kernel, inputArgs);
|
|
1828
1982
|
ctx.set(outVar, {
|
|
1829
1983
|
type: "imm",
|
|
1830
1984
|
arg: outId
|
|
1831
1985
|
});
|
|
1832
|
-
} else ctx.set(outVar, {
|
|
1986
|
+
} else if (reduction) ctx.set(outVar, {
|
|
1987
|
+
type: "red",
|
|
1988
|
+
exp: exp$2,
|
|
1989
|
+
reduction,
|
|
1990
|
+
args: inputArgs
|
|
1991
|
+
});
|
|
1992
|
+
else ctx.set(outVar, {
|
|
1833
1993
|
type: "exp",
|
|
1834
|
-
exp:
|
|
1994
|
+
exp: exp$2,
|
|
1835
1995
|
args: inputArgs
|
|
1836
1996
|
});
|
|
1837
1997
|
}
|
|
@@ -1841,7 +2001,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1841
2001
|
if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
|
|
1842
2002
|
outputIds.push(jitValue.arg);
|
|
1843
2003
|
} else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
|
|
1844
|
-
const outputNeedsRef = new Set(
|
|
2004
|
+
const outputNeedsRef = new Set(require_backend.range(nargs));
|
|
1845
2005
|
for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
|
|
1846
2006
|
else outputNeedsRef.add(outputId);
|
|
1847
2007
|
builder.insertFreeSteps(outputIds);
|
|
@@ -1863,31 +2023,33 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1863
2023
|
});
|
|
1864
2024
|
}
|
|
1865
2025
|
function broadcastedJit(fn, opts) {
|
|
1866
|
-
return (
|
|
2026
|
+
return (exps, avals, params) => {
|
|
1867
2027
|
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1868
2028
|
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1869
2029
|
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1870
|
-
exps = exps.map((exp$
|
|
1871
|
-
exp$
|
|
2030
|
+
exps = exps.map((exp$2, i) => {
|
|
2031
|
+
exp$2 = reshapeViews(exp$2, (st) => {
|
|
1872
2032
|
if (!require_backend.deepEqual(st.shape, newShape)) return st.broadcast(newShape, require_backend.range(newShape.length - st.shape.length));
|
|
1873
2033
|
});
|
|
1874
|
-
if (exp$
|
|
1875
|
-
return exp$
|
|
2034
|
+
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = require_backend.AluExp.cast(newDtype, exp$2);
|
|
2035
|
+
return exp$2;
|
|
1876
2036
|
});
|
|
1877
|
-
|
|
1878
|
-
return new require_backend.Kernel(nargs, require_backend.prod(newShape), exp$2);
|
|
2037
|
+
return { exp: fn(exps, params) };
|
|
1879
2038
|
};
|
|
1880
2039
|
}
|
|
1881
2040
|
function unopJit(fn) {
|
|
1882
|
-
return (
|
|
1883
|
-
return
|
|
2041
|
+
return ([a], [_as], params) => {
|
|
2042
|
+
return { exp: fn(a, params) };
|
|
1884
2043
|
};
|
|
1885
2044
|
}
|
|
1886
2045
|
function reshapeJit(fn) {
|
|
1887
|
-
return (
|
|
1888
|
-
|
|
1889
|
-
|
|
1890
|
-
|
|
2046
|
+
return ([a], [_as], params) => {
|
|
2047
|
+
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2048
|
+
};
|
|
2049
|
+
}
|
|
2050
|
+
function routineNoJit() {
|
|
2051
|
+
return () => {
|
|
2052
|
+
throw new Error("jit: rule is not implemented for routines");
|
|
1891
2053
|
};
|
|
1892
2054
|
}
|
|
1893
2055
|
const jitRules = {
|
|
@@ -1895,6 +2057,8 @@ const jitRules = {
|
|
|
1895
2057
|
[Primitive.Mul]: broadcastedJit(([a, b]) => require_backend.AluExp.mul(a, b)),
|
|
1896
2058
|
[Primitive.Idiv]: broadcastedJit(([a, b]) => require_backend.AluExp.idiv(a, b)),
|
|
1897
2059
|
[Primitive.Mod]: broadcastedJit(([a, b]) => require_backend.AluExp.mod(a, b)),
|
|
2060
|
+
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
2061
|
+
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1898
2062
|
[Primitive.Neg]: unopJit((a) => require_backend.AluExp.sub(require_backend.AluExp.const(a.dtype, 0), a)),
|
|
1899
2063
|
[Primitive.Reciprocal]: unopJit(require_backend.AluExp.reciprocal),
|
|
1900
2064
|
[Primitive.Floor]: unopJit(require_backend.AluExp.floor),
|
|
@@ -1902,17 +2066,6 @@ const jitRules = {
|
|
|
1902
2066
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1903
2067
|
[Primitive.Cast]: unopJit((a, { dtype }) => require_backend.AluExp.cast(dtype, a)),
|
|
1904
2068
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => require_backend.AluExp.bitcast(dtype, a)),
|
|
1905
|
-
[Primitive.RandomBits]: (nargs, keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1906
|
-
const mapping = (st) => {
|
|
1907
|
-
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
|
|
1908
|
-
};
|
|
1909
|
-
const k0 = reshapeViews(keys[0], mapping);
|
|
1910
|
-
const k1 = reshapeViews(keys[1], mapping);
|
|
1911
|
-
const c0 = require_backend.AluExp.u32(0);
|
|
1912
|
-
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
1913
|
-
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1914
|
-
return new require_backend.Kernel(nargs, require_backend.prod(shape$1), exp$2);
|
|
1915
|
-
},
|
|
1916
2069
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
1917
2070
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
1918
2071
|
[Primitive.Asin]: unopJit(require_backend.AluExp.asin),
|
|
@@ -1922,9 +2075,7 @@ const jitRules = {
|
|
|
1922
2075
|
[Primitive.Erf]: unopJit(require_backend.AluExp.erf),
|
|
1923
2076
|
[Primitive.Erfc]: unopJit(require_backend.AluExp.erfc),
|
|
1924
2077
|
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
1925
|
-
[Primitive.
|
|
1926
|
-
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
1927
|
-
[Primitive.Reduce](nargs, [a], [as], { op, axis }) {
|
|
2078
|
+
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1928
2079
|
const keptAxes = [];
|
|
1929
2080
|
const shiftedAxes = [];
|
|
1930
2081
|
const newShape = [];
|
|
@@ -1933,53 +2084,58 @@ const jitRules = {
|
|
|
1933
2084
|
keptAxes.push(i);
|
|
1934
2085
|
newShape.push(as.shape[i]);
|
|
1935
2086
|
}
|
|
1936
|
-
const size$1 = require_backend.prod(newShape);
|
|
1937
2087
|
const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
1938
2088
|
newShape.push(reductionSize);
|
|
1939
2089
|
const perm = keptAxes.concat(shiftedAxes);
|
|
1940
2090
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
1941
2091
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
1942
|
-
return
|
|
2092
|
+
return {
|
|
2093
|
+
exp: a,
|
|
2094
|
+
reduction
|
|
2095
|
+
};
|
|
1943
2096
|
},
|
|
1944
2097
|
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1945
|
-
[Primitive.PoolTranspose](
|
|
2098
|
+
[Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
|
|
1946
2099
|
let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1947
|
-
const size$1 = require_backend.prod(inShape);
|
|
1948
2100
|
stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
|
|
1949
2101
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1950
2102
|
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1951
|
-
return
|
|
2103
|
+
return {
|
|
2104
|
+
exp: a,
|
|
2105
|
+
reduction
|
|
2106
|
+
};
|
|
1952
2107
|
},
|
|
1953
|
-
[Primitive.Dot](
|
|
1954
|
-
const k1 = jitRules[Primitive.Mul](
|
|
2108
|
+
[Primitive.Dot]([a, b], [as, bs]) {
|
|
2109
|
+
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
1955
2110
|
const c = k1.exp;
|
|
1956
2111
|
const cs = promoteAvals(as, bs);
|
|
1957
|
-
return jitRules[Primitive.Reduce](
|
|
2112
|
+
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
1958
2113
|
op: require_backend.AluOp.Add,
|
|
1959
2114
|
axis: [cs.ndim - 1]
|
|
1960
2115
|
});
|
|
1961
2116
|
},
|
|
1962
|
-
[Primitive.Conv](
|
|
2117
|
+
[Primitive.Conv]([a, b], [as, bs], params) {
|
|
1963
2118
|
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1964
2119
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1965
2120
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1966
2121
|
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1967
2122
|
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1968
|
-
return jitRules[Primitive.Dot](
|
|
2123
|
+
return jitRules[Primitive.Dot]([a, b], [as, bs], {});
|
|
1969
2124
|
},
|
|
1970
2125
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1971
2126
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1972
|
-
[Primitive.
|
|
1973
|
-
|
|
1974
|
-
|
|
1975
|
-
|
|
1976
|
-
const
|
|
1977
|
-
|
|
1978
|
-
|
|
1979
|
-
|
|
1980
|
-
|
|
1981
|
-
|
|
1982
|
-
|
|
2127
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2128
|
+
const mapping = (st) => {
|
|
2129
|
+
if (!require_backend.deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, require_backend.range(shape$1.length - st.shape.length));
|
|
2130
|
+
};
|
|
2131
|
+
const k0 = reshapeViews(keys[0], mapping);
|
|
2132
|
+
const k1 = reshapeViews(keys[1], mapping);
|
|
2133
|
+
const c0 = require_backend.AluExp.u32(0);
|
|
2134
|
+
const c1 = require_backend.AluExp.cast(require_backend.DType.Uint32, require_backend.AluVar.gidx);
|
|
2135
|
+
const exp$2 = require_backend.AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2136
|
+
return { exp: exp$2 };
|
|
2137
|
+
},
|
|
2138
|
+
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1983
2139
|
const axisSet = new Set(axis);
|
|
1984
2140
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(require_backend.generalBroadcast);
|
|
1985
2141
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1992,24 +2148,38 @@ const jitRules = {
|
|
|
1992
2148
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
1993
2149
|
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1994
2150
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1995
|
-
return
|
|
2151
|
+
return { exp: x.substitute({ gidx: index }) };
|
|
1996
2152
|
},
|
|
1997
|
-
[Primitive.
|
|
1998
|
-
|
|
2153
|
+
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2154
|
+
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
2155
|
+
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
2156
|
+
[Primitive.Flip]: reshapeJit((st, { axis }) => {
|
|
2157
|
+
const arg = require_backend.rep(st.shape.length, false);
|
|
2158
|
+
for (const ax of axis) arg[ax] = true;
|
|
2159
|
+
return st.flip(arg);
|
|
2160
|
+
}),
|
|
2161
|
+
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
2162
|
+
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
2163
|
+
[Primitive.Sort]: routineNoJit(),
|
|
2164
|
+
[Primitive.Argsort]: routineNoJit(),
|
|
2165
|
+
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2166
|
+
[Primitive.Cholesky]: routineNoJit(),
|
|
2167
|
+
[Primitive.Jit]() {
|
|
2168
|
+
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
1999
2169
|
}
|
|
2000
2170
|
};
|
|
2001
2171
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
2002
2172
|
function splitGraphDataflow(backend, jaxpr) {
|
|
2003
|
-
const
|
|
2173
|
+
const varToDefn = /* @__PURE__ */ new Map();
|
|
2174
|
+
const varToUsages = /* @__PURE__ */ new Map();
|
|
2004
2175
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
2005
2176
|
const eqn = jaxpr.eqns[i];
|
|
2006
|
-
for (const v of eqn.outBinders) if (v instanceof Var)
|
|
2007
|
-
|
|
2008
|
-
|
|
2009
|
-
|
|
2010
|
-
|
|
2011
|
-
|
|
2012
|
-
p1NextBlack.set(v, v);
|
|
2177
|
+
for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
|
|
2178
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
2179
|
+
const usages = varToUsages.get(input);
|
|
2180
|
+
if (usages) usages.push(i);
|
|
2181
|
+
else varToUsages.set(input, [i]);
|
|
2182
|
+
}
|
|
2013
2183
|
}
|
|
2014
2184
|
const reducePrimitives = [
|
|
2015
2185
|
Primitive.Reduce,
|
|
@@ -2017,28 +2187,94 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2017
2187
|
Primitive.Conv,
|
|
2018
2188
|
Primitive.PoolTranspose
|
|
2019
2189
|
];
|
|
2020
|
-
const
|
|
2021
|
-
|
|
2190
|
+
const reductionEpilogueEqns = /* @__PURE__ */ new Set();
|
|
2191
|
+
const reductionEndpointEqns = /* @__PURE__ */ new Set();
|
|
2192
|
+
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
2022
2193
|
const eqn = jaxpr.eqns[i];
|
|
2023
|
-
if (reducePrimitives.includes(eqn.primitive)
|
|
2024
|
-
|
|
2025
|
-
|
|
2026
|
-
|
|
2194
|
+
if (reducePrimitives.includes(eqn.primitive)) {
|
|
2195
|
+
let head = i;
|
|
2196
|
+
while (true) {
|
|
2197
|
+
reductionEpilogueEqns.add(head);
|
|
2198
|
+
const outVar = jaxpr.eqns[head].outBinders[0];
|
|
2199
|
+
const usages = varToUsages.get(outVar) ?? [];
|
|
2200
|
+
if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
|
|
2201
|
+
if (reductionEpilogueEqns.has(usages[0])) break;
|
|
2202
|
+
const nextEqn = jaxpr.eqns[usages[0]];
|
|
2203
|
+
switch (nextEqn.primitive) {
|
|
2204
|
+
case Primitive.Neg:
|
|
2205
|
+
case Primitive.Reciprocal:
|
|
2206
|
+
case Primitive.Floor:
|
|
2207
|
+
case Primitive.Ceil:
|
|
2208
|
+
case Primitive.StopGradient:
|
|
2209
|
+
case Primitive.Cast:
|
|
2210
|
+
case Primitive.Bitcast:
|
|
2211
|
+
case Primitive.Sin:
|
|
2212
|
+
case Primitive.Cos:
|
|
2213
|
+
case Primitive.Asin:
|
|
2214
|
+
case Primitive.Atan:
|
|
2215
|
+
case Primitive.Exp:
|
|
2216
|
+
case Primitive.Log:
|
|
2217
|
+
case Primitive.Erf:
|
|
2218
|
+
case Primitive.Erfc:
|
|
2219
|
+
case Primitive.Sqrt:
|
|
2220
|
+
head = usages[0];
|
|
2221
|
+
continue;
|
|
2222
|
+
case Primitive.Add:
|
|
2223
|
+
case Primitive.Mul:
|
|
2224
|
+
case Primitive.Idiv:
|
|
2225
|
+
case Primitive.Mod:
|
|
2226
|
+
case Primitive.Min:
|
|
2227
|
+
case Primitive.Max: {
|
|
2228
|
+
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2229
|
+
if (otherInput instanceof Lit || require_backend.deepEqual(require_backend.generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2230
|
+
head = usages[0];
|
|
2231
|
+
continue;
|
|
2232
|
+
}
|
|
2233
|
+
break;
|
|
2234
|
+
}
|
|
2235
|
+
}
|
|
2236
|
+
break;
|
|
2027
2237
|
}
|
|
2028
|
-
|
|
2238
|
+
reductionEndpointEqns.add(head);
|
|
2029
2239
|
}
|
|
2030
|
-
|
|
2031
|
-
|
|
2032
|
-
|
|
2033
|
-
|
|
2034
|
-
|
|
2035
|
-
|
|
2036
|
-
|
|
2037
|
-
|
|
2038
|
-
|
|
2240
|
+
}
|
|
2241
|
+
const blackNodes = /* @__PURE__ */ new Set();
|
|
2242
|
+
const p1NextBlack = /* @__PURE__ */ new Map();
|
|
2243
|
+
for (const v of jaxpr.outs) if (v instanceof Var) {
|
|
2244
|
+
blackNodes.add(v);
|
|
2245
|
+
p1NextBlack.set(v, v);
|
|
2246
|
+
}
|
|
2247
|
+
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2248
|
+
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2249
|
+
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2250
|
+
const eqn = jaxpr.eqns[i];
|
|
2251
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2252
|
+
for (const v of eqn.outBinders) {
|
|
2253
|
+
blackNodes.add(v);
|
|
2254
|
+
p1NextBlack.set(v, v);
|
|
2255
|
+
}
|
|
2256
|
+
continue;
|
|
2257
|
+
}
|
|
2258
|
+
const reach = /* @__PURE__ */ new Set();
|
|
2259
|
+
let needsCleanOutput = false;
|
|
2260
|
+
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2261
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
|
|
2262
|
+
needsCleanOutput = true;
|
|
2263
|
+
break outer;
|
|
2264
|
+
}
|
|
2265
|
+
for (const o of jaxpr.eqns[j].outBinders) {
|
|
2266
|
+
const u = p1NextBlack.get(o);
|
|
2267
|
+
if (u) reach.add(u);
|
|
2268
|
+
}
|
|
2269
|
+
}
|
|
2270
|
+
if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
|
|
2039
2271
|
blackNodes.add(v);
|
|
2040
2272
|
p1NextBlack.set(v, v);
|
|
2041
2273
|
}
|
|
2274
|
+
else if (reach.size === 1) {
|
|
2275
|
+
const b = reach.values().next().value;
|
|
2276
|
+
for (const v of eqn.outBinders) p1NextBlack.set(v, b);
|
|
2277
|
+
}
|
|
2042
2278
|
}
|
|
2043
2279
|
const p2Deps = /* @__PURE__ */ new Map();
|
|
2044
2280
|
for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
|
|
@@ -2046,7 +2282,6 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2046
2282
|
while (p2idx < jaxpr.eqns.length) {
|
|
2047
2283
|
const eqn = jaxpr.eqns[p2idx++];
|
|
2048
2284
|
const deps = [];
|
|
2049
|
-
if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
|
|
2050
2285
|
for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
|
|
2051
2286
|
else deps.push(p2Deps.get(input));
|
|
2052
2287
|
else deps.push(/* @__PURE__ */ new Set());
|
|
@@ -2057,7 +2292,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2057
2292
|
let assocInput = -1;
|
|
2058
2293
|
for (let i = 0; i < eqn.inputs.length; i++) {
|
|
2059
2294
|
const input = eqn.inputs[i];
|
|
2060
|
-
if (input instanceof Var &&
|
|
2295
|
+
if (input instanceof Var && varToDefn.has(input)) {
|
|
2061
2296
|
let uniqueDeps = 0;
|
|
2062
2297
|
for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
|
|
2063
2298
|
if (uniqueDeps > maxUniqueDeps) {
|
|
@@ -2068,8 +2303,8 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2068
2303
|
}
|
|
2069
2304
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2070
2305
|
const assocVar = eqn.inputs[assocInput];
|
|
2071
|
-
p2idx =
|
|
2072
|
-
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2306
|
+
p2idx = varToDefn.get(assocVar);
|
|
2307
|
+
for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
|
|
2073
2308
|
} else {
|
|
2074
2309
|
const s = new Set(depCounter.keys());
|
|
2075
2310
|
for (const out of eqn.outBinders) p2Deps.set(out, s);
|
|
@@ -2095,9 +2330,9 @@ var PendingExecute = class {
|
|
|
2095
2330
|
submitted = false;
|
|
2096
2331
|
#promise = null;
|
|
2097
2332
|
#rc = 1;
|
|
2098
|
-
constructor(backend,
|
|
2333
|
+
constructor(backend, source, inputs, outputs) {
|
|
2099
2334
|
this.backend = backend;
|
|
2100
|
-
this.
|
|
2335
|
+
this.source = source;
|
|
2101
2336
|
this.inputs = inputs;
|
|
2102
2337
|
this.outputs = outputs;
|
|
2103
2338
|
for (const slot of inputs) this.backend.incRef(slot);
|
|
@@ -2118,13 +2353,15 @@ var PendingExecute = class {
|
|
|
2118
2353
|
return;
|
|
2119
2354
|
}
|
|
2120
2355
|
this.#promise = (async () => {
|
|
2121
|
-
this.prepared = await this.backend.
|
|
2356
|
+
if (this.source instanceof require_backend.Kernel) this.prepared = await this.backend.prepareKernel(this.source);
|
|
2357
|
+
else this.prepared = await this.backend.prepareRoutine(this.source);
|
|
2122
2358
|
})();
|
|
2123
2359
|
await this.#promise;
|
|
2124
2360
|
}
|
|
2125
2361
|
prepareSync() {
|
|
2126
2362
|
if (this.prepared) return;
|
|
2127
|
-
this.prepared = this.backend.
|
|
2363
|
+
if (this.source instanceof require_backend.Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
|
|
2364
|
+
else this.prepared = this.backend.prepareRoutineSync(this.source);
|
|
2128
2365
|
}
|
|
2129
2366
|
submit() {
|
|
2130
2367
|
if (this.submitted) return;
|
|
@@ -2147,8 +2384,6 @@ var PendingExecute = class {
|
|
|
2147
2384
|
* "Array" type by name.
|
|
2148
2385
|
*/
|
|
2149
2386
|
var Array$1 = class Array$1 extends Tracer {
|
|
2150
|
-
static #nextId = 1001;
|
|
2151
|
-
id;
|
|
2152
2387
|
#dtype;
|
|
2153
2388
|
#weakType;
|
|
2154
2389
|
#source;
|
|
@@ -2165,7 +2400,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2165
2400
|
*/
|
|
2166
2401
|
constructor(args) {
|
|
2167
2402
|
super(baseArrayTrace);
|
|
2168
|
-
this.id = Array$1.#nextId++;
|
|
2169
2403
|
this.#dtype = args.dtype;
|
|
2170
2404
|
this.#weakType = args.weakType;
|
|
2171
2405
|
this.#source = args.source;
|
|
@@ -2474,6 +2708,27 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2474
2708
|
pending
|
|
2475
2709
|
});
|
|
2476
2710
|
}
|
|
2711
|
+
/** Apply an operation with custom lowering to this array. */
|
|
2712
|
+
static #routine(routine, arrays, outputWeakType) {
|
|
2713
|
+
const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
|
|
2714
|
+
for (const ar of arrays) ar.#realize();
|
|
2715
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2716
|
+
const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(require_backend.byteWidth(dtype) * require_backend.prod(routine.type.outputShapes[i])));
|
|
2717
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2718
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2719
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2720
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2721
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2722
|
+
return outputs.map((output, i) => new Array$1({
|
|
2723
|
+
source: output,
|
|
2724
|
+
st: require_backend.ShapeTracker.fromShape(routine.type.outputShapes[i]),
|
|
2725
|
+
dtype: routine.type.outputDtypes[i],
|
|
2726
|
+
weakType: outputWeakType[i],
|
|
2727
|
+
backend,
|
|
2728
|
+
committed,
|
|
2729
|
+
pending
|
|
2730
|
+
}));
|
|
2731
|
+
}
|
|
2477
2732
|
/**
|
|
2478
2733
|
* Normalizes this array into one backed by a `Slot`.
|
|
2479
2734
|
*
|
|
@@ -2634,6 +2889,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2634
2889
|
[Primitive.Mod]([x, y]) {
|
|
2635
2890
|
return [x.#binary(require_backend.AluOp.Mod, y)];
|
|
2636
2891
|
},
|
|
2892
|
+
[Primitive.Min]([x, y]) {
|
|
2893
|
+
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
2894
|
+
},
|
|
2895
|
+
[Primitive.Max]([x, y]) {
|
|
2896
|
+
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
2897
|
+
},
|
|
2637
2898
|
[Primitive.Neg]([x]) {
|
|
2638
2899
|
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
2639
2900
|
},
|
|
@@ -2670,25 +2931,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2670
2931
|
return [y];
|
|
2671
2932
|
}
|
|
2672
2933
|
},
|
|
2673
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2674
|
-
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2675
|
-
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2676
|
-
const c0 = zeros(shape$1, {
|
|
2677
|
-
dtype: require_backend.DType.Uint32,
|
|
2678
|
-
device: k0.device
|
|
2679
|
-
});
|
|
2680
|
-
const c1 = arange(0, require_backend.prod(shape$1), 1, {
|
|
2681
|
-
dtype: require_backend.DType.Uint32,
|
|
2682
|
-
device: k0.device
|
|
2683
|
-
}).reshape(shape$1);
|
|
2684
|
-
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2685
|
-
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2686
|
-
k0,
|
|
2687
|
-
k1,
|
|
2688
|
-
c0,
|
|
2689
|
-
c1
|
|
2690
|
-
])];
|
|
2691
|
-
},
|
|
2692
2934
|
[Primitive.Sin]([x]) {
|
|
2693
2935
|
return [x.#unary(require_backend.AluOp.Sin)];
|
|
2694
2936
|
},
|
|
@@ -2716,12 +2958,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2716
2958
|
[Primitive.Sqrt]([x]) {
|
|
2717
2959
|
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
2718
2960
|
},
|
|
2719
|
-
[Primitive.Min]([x, y]) {
|
|
2720
|
-
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
2721
|
-
},
|
|
2722
|
-
[Primitive.Max]([x, y]) {
|
|
2723
|
-
return [x.#binary(require_backend.AluOp.Max, y)];
|
|
2724
|
-
},
|
|
2725
2961
|
[Primitive.Reduce]([x], { op, axis }) {
|
|
2726
2962
|
if (axis.length === 0) return [x];
|
|
2727
2963
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
@@ -2756,6 +2992,28 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2756
2992
|
y
|
|
2757
2993
|
], { dtypeOverride: [require_backend.DType.Bool] })];
|
|
2758
2994
|
},
|
|
2995
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2996
|
+
const keyShape = require_backend.generalBroadcast(k0.shape, k1.shape);
|
|
2997
|
+
if (!require_backend.deepEqual(require_backend.generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2998
|
+
const c0 = zeros(shape$1, {
|
|
2999
|
+
dtype: require_backend.DType.Uint32,
|
|
3000
|
+
device: k0.device
|
|
3001
|
+
});
|
|
3002
|
+
const c1 = arange(0, require_backend.prod(shape$1), 1, {
|
|
3003
|
+
dtype: require_backend.DType.Uint32,
|
|
3004
|
+
device: k0.device
|
|
3005
|
+
}).reshape(shape$1);
|
|
3006
|
+
const custom = ([k0$1, k1$1, c0$1, c1$1]) => require_backend.AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
3007
|
+
return [Array$1.#naryCustom("random_bits", custom, [
|
|
3008
|
+
k0,
|
|
3009
|
+
k1,
|
|
3010
|
+
c0,
|
|
3011
|
+
c1
|
|
3012
|
+
])];
|
|
3013
|
+
},
|
|
3014
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
3015
|
+
return [x.#gather(indices, axis, outDim)];
|
|
3016
|
+
},
|
|
2759
3017
|
[Primitive.Transpose]([x], { perm }) {
|
|
2760
3018
|
return [x.#transpose(perm)];
|
|
2761
3019
|
},
|
|
@@ -2776,17 +3034,48 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2776
3034
|
[Primitive.Pad]([x], { width }) {
|
|
2777
3035
|
return [x.#reshape(x.#st.pad(width))];
|
|
2778
3036
|
},
|
|
2779
|
-
[Primitive.
|
|
2780
|
-
|
|
3037
|
+
[Primitive.Sort]([x]) {
|
|
3038
|
+
const routine = new require_backend.Routine(require_backend.Routines.Sort, {
|
|
3039
|
+
inputShapes: [x.aval.shape],
|
|
3040
|
+
inputDtypes: [x.aval.dtype],
|
|
3041
|
+
outputShapes: [x.aval.shape],
|
|
3042
|
+
outputDtypes: [x.aval.dtype]
|
|
3043
|
+
});
|
|
3044
|
+
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3045
|
+
},
|
|
3046
|
+
[Primitive.Argsort]([x]) {
|
|
3047
|
+
const routine = new require_backend.Routine(require_backend.Routines.Argsort, {
|
|
3048
|
+
inputShapes: [x.aval.shape],
|
|
3049
|
+
inputDtypes: [x.aval.dtype],
|
|
3050
|
+
outputShapes: [x.aval.shape, x.aval.shape],
|
|
3051
|
+
outputDtypes: [x.aval.dtype, require_backend.DType.Int32]
|
|
3052
|
+
});
|
|
3053
|
+
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3054
|
+
},
|
|
3055
|
+
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3056
|
+
const routine = new require_backend.Routine(require_backend.Routines.TriangularSolve, {
|
|
3057
|
+
inputShapes: [a.aval.shape, b.aval.shape],
|
|
3058
|
+
inputDtypes: [a.aval.dtype, b.aval.dtype],
|
|
3059
|
+
outputShapes: [b.aval.shape],
|
|
3060
|
+
outputDtypes: [b.aval.dtype]
|
|
3061
|
+
}, { unitDiagonal });
|
|
3062
|
+
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
2781
3063
|
},
|
|
2782
|
-
[Primitive.
|
|
2783
|
-
|
|
2784
|
-
|
|
3064
|
+
[Primitive.Cholesky]([a]) {
|
|
3065
|
+
const routine = new require_backend.Routine(require_backend.Routines.Cholesky, {
|
|
3066
|
+
inputShapes: [a.aval.shape],
|
|
3067
|
+
inputDtypes: [a.aval.dtype],
|
|
3068
|
+
outputShapes: [a.aval.shape],
|
|
3069
|
+
outputDtypes: [a.aval.dtype]
|
|
3070
|
+
});
|
|
3071
|
+
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3072
|
+
},
|
|
3073
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
3074
|
+
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3075
|
+
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
2785
3076
|
args = args.map((ar) => ar._putSync(backend));
|
|
2786
|
-
const
|
|
2787
|
-
const
|
|
2788
|
-
const jp = jitCompile(backend, jaxpr, consts);
|
|
2789
|
-
const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
|
|
3077
|
+
const jp = jitCompile(backend, jaxpr);
|
|
3078
|
+
const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
|
|
2790
3079
|
for (const exe of pending) exe.updateRc(+outputs.length - 1);
|
|
2791
3080
|
const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
|
|
2792
3081
|
for (const exe of prevPending) exe.updateRc(+outputs.length);
|
|
@@ -3085,6 +3374,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
3085
3374
|
});
|
|
3086
3375
|
}
|
|
3087
3376
|
/**
|
|
3377
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
3378
|
+
*
|
|
3379
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
3380
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
3381
|
+
* `k>0` is above it.
|
|
3382
|
+
*/
|
|
3383
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
3384
|
+
m ??= n;
|
|
3385
|
+
dtype ??= require_backend.DType.Float32;
|
|
3386
|
+
if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
|
|
3387
|
+
if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
|
|
3388
|
+
if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
|
|
3389
|
+
const rows = arange(k, n + k, 1, {
|
|
3390
|
+
dtype: require_backend.DType.Int32,
|
|
3391
|
+
device
|
|
3392
|
+
});
|
|
3393
|
+
const cols = arange(0, m, 1, {
|
|
3394
|
+
dtype: require_backend.DType.Int32,
|
|
3395
|
+
device
|
|
3396
|
+
});
|
|
3397
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
3398
|
+
}
|
|
3399
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
3400
|
+
function tril(a, k = 0) {
|
|
3401
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3402
|
+
a = fudgeArray(a);
|
|
3403
|
+
const [n, m] = a.shape.slice(-2);
|
|
3404
|
+
return where$1(tri(n, m, k, { dtype: require_backend.DType.Bool }), a.ref, zerosLike$1(a));
|
|
3405
|
+
}
|
|
3406
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
3407
|
+
function triu(a, k = 0) {
|
|
3408
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3409
|
+
a = fudgeArray(a);
|
|
3410
|
+
const [n, m] = a.shape.slice(-2);
|
|
3411
|
+
return where$1(tri(n, m, k - 1, { dtype: require_backend.DType.Bool }), zerosLike$1(a.ref), a);
|
|
3412
|
+
}
|
|
3413
|
+
/**
|
|
3088
3414
|
* Return evenly spaced numbers over a specified interval.
|
|
3089
3415
|
*
|
|
3090
3416
|
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
@@ -3131,335 +3457,107 @@ function aluCompare(a, b, op) {
|
|
|
3131
3457
|
}
|
|
3132
3458
|
|
|
3133
3459
|
//#endregion
|
|
3134
|
-
//#region src/frontend/
|
|
3460
|
+
//#region src/frontend/vmap.ts
|
|
3135
3461
|
var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3136
|
-
|
|
3137
|
-
|
|
3462
|
+
function mappedAval(batchDim, aval) {
|
|
3463
|
+
const shape$1 = [...aval.shape];
|
|
3464
|
+
shape$1.splice(batchDim, 1);
|
|
3465
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3466
|
+
}
|
|
3467
|
+
/** Move one axis to a different index. */
|
|
3468
|
+
function moveaxis(x, src, dst) {
|
|
3469
|
+
const t = pureArray(x);
|
|
3470
|
+
src = require_backend.checkAxis(src, t.ndim);
|
|
3471
|
+
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3472
|
+
if (src === dst) return t;
|
|
3473
|
+
const perm = require_backend.range(t.ndim);
|
|
3474
|
+
perm.splice(src, 1);
|
|
3475
|
+
perm.splice(dst, 0, src);
|
|
3476
|
+
return transpose$1(t, perm);
|
|
3477
|
+
}
|
|
3478
|
+
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3479
|
+
if (src === null) {
|
|
3480
|
+
const targetShape = [...x.shape];
|
|
3481
|
+
targetShape.splice(dst, 0, axisSize);
|
|
3482
|
+
return broadcast(x, targetShape, [dst]);
|
|
3483
|
+
} else if (src === dst) return x;
|
|
3484
|
+
else return moveaxis(x, src, dst);
|
|
3485
|
+
}
|
|
3486
|
+
var BatchTracer = class extends Tracer {
|
|
3487
|
+
constructor(trace$1, val, batchDim) {
|
|
3138
3488
|
super(trace$1);
|
|
3139
|
-
this.
|
|
3140
|
-
this.
|
|
3489
|
+
this.val = val;
|
|
3490
|
+
this.batchDim = batchDim;
|
|
3141
3491
|
}
|
|
3142
3492
|
get aval() {
|
|
3143
|
-
return this.
|
|
3493
|
+
if (this.batchDim === null) return this.val.aval;
|
|
3494
|
+
else return mappedAval(this.batchDim, this.val.aval);
|
|
3144
3495
|
}
|
|
3145
3496
|
toString() {
|
|
3146
|
-
return `
|
|
3497
|
+
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3147
3498
|
}
|
|
3148
3499
|
get ref() {
|
|
3149
|
-
this.
|
|
3500
|
+
this.val.ref;
|
|
3150
3501
|
return this;
|
|
3151
3502
|
}
|
|
3152
3503
|
dispose() {
|
|
3153
|
-
this.
|
|
3154
|
-
|
|
3504
|
+
this.val.dispose();
|
|
3505
|
+
}
|
|
3506
|
+
fullLower() {
|
|
3507
|
+
if (this.batchDim === null) return this.val.fullLower();
|
|
3508
|
+
else return this;
|
|
3155
3509
|
}
|
|
3156
3510
|
};
|
|
3157
|
-
var
|
|
3511
|
+
var BatchTrace = class extends Trace {
|
|
3158
3512
|
pure(val) {
|
|
3159
3513
|
return this.lift(pureArray(val));
|
|
3160
3514
|
}
|
|
3161
3515
|
lift(val) {
|
|
3162
|
-
return new
|
|
3516
|
+
return new BatchTracer(this, val, null);
|
|
3163
3517
|
}
|
|
3164
3518
|
processPrimitive(primitive, tracers, params) {
|
|
3165
|
-
const [
|
|
3166
|
-
const
|
|
3167
|
-
if (
|
|
3168
|
-
|
|
3169
|
-
|
|
3519
|
+
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3520
|
+
const vmapRule = vmapRules[primitive];
|
|
3521
|
+
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3522
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3523
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3524
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3525
|
+
}
|
|
3526
|
+
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3527
|
+
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3528
|
+
}
|
|
3529
|
+
get axisSize() {
|
|
3530
|
+
return this.main.globalData;
|
|
3170
3531
|
}
|
|
3171
3532
|
};
|
|
3172
|
-
/**
|
|
3173
|
-
|
|
3174
|
-
|
|
3175
|
-
|
|
3176
|
-
|
|
3177
|
-
|
|
3178
|
-
|
|
3179
|
-
|
|
3180
|
-
|
|
3181
|
-
|
|
3182
|
-
|
|
3183
|
-
|
|
3184
|
-
|
|
3185
|
-
|
|
3533
|
+
/**
|
|
3534
|
+
* Process a primitive with built-in broadcasting.
|
|
3535
|
+
*
|
|
3536
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3537
|
+
*/
|
|
3538
|
+
function broadcastBatcher(op) {
|
|
3539
|
+
return (axisSize, args, dims) => {
|
|
3540
|
+
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3541
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3542
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3543
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3544
|
+
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3545
|
+
args = args.map((x, i) => {
|
|
3546
|
+
if (dims[i] === null) return x;
|
|
3547
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3548
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3549
|
+
x.shape[0],
|
|
3550
|
+
...require_backend.rep(nd - x.ndim, 1),
|
|
3551
|
+
...x.shape.slice(1)
|
|
3552
|
+
]);
|
|
3553
|
+
return x;
|
|
3554
|
+
});
|
|
3555
|
+
return [[op(...args)], [0]];
|
|
3186
3556
|
};
|
|
3187
3557
|
}
|
|
3188
|
-
|
|
3189
|
-
|
|
3190
|
-
|
|
3191
|
-
for (const t of tangents) t.dispose();
|
|
3192
|
-
const ys = bind(primitive, primals, params);
|
|
3193
|
-
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3194
|
-
};
|
|
3195
|
-
}
|
|
3196
|
-
const jvpRules = {
|
|
3197
|
-
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3198
|
-
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
3199
|
-
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
3200
|
-
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
3201
|
-
if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
|
|
3202
|
-
dx.dispose();
|
|
3203
|
-
dy.dispose();
|
|
3204
|
-
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
3205
|
-
}
|
|
3206
|
-
const q = idiv(x.ref, y.ref);
|
|
3207
|
-
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
3208
|
-
},
|
|
3209
|
-
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
3210
|
-
[Primitive.Reciprocal]([x], [dx]) {
|
|
3211
|
-
const xRecip = reciprocal$1(x.ref);
|
|
3212
|
-
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
3213
|
-
},
|
|
3214
|
-
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
3215
|
-
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
3216
|
-
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
3217
|
-
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
3218
|
-
if (x.dtype === dtype) return [[x], [dx]];
|
|
3219
|
-
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
3220
|
-
else {
|
|
3221
|
-
dx.dispose();
|
|
3222
|
-
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3223
|
-
}
|
|
3224
|
-
},
|
|
3225
|
-
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3226
|
-
if (x.dtype === dtype) return [[x], [dx]];
|
|
3227
|
-
dx.dispose();
|
|
3228
|
-
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3229
|
-
},
|
|
3230
|
-
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3231
|
-
[Primitive.Sin]([x], [dx]) {
|
|
3232
|
-
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3233
|
-
},
|
|
3234
|
-
[Primitive.Cos]([x], [dx]) {
|
|
3235
|
-
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
3236
|
-
},
|
|
3237
|
-
[Primitive.Asin]([x], [dx]) {
|
|
3238
|
-
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3239
|
-
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3240
|
-
},
|
|
3241
|
-
[Primitive.Atan]([x], [dx]) {
|
|
3242
|
-
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3243
|
-
return [[atan$1(x)], [dx.div(denom)]];
|
|
3244
|
-
},
|
|
3245
|
-
[Primitive.Exp]([x], [dx]) {
|
|
3246
|
-
const z = exp$1(x);
|
|
3247
|
-
return [[z.ref], [z.mul(dx)]];
|
|
3248
|
-
},
|
|
3249
|
-
[Primitive.Log]([x], [dx]) {
|
|
3250
|
-
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3251
|
-
},
|
|
3252
|
-
[Primitive.Erf]([x], [dx]) {
|
|
3253
|
-
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3254
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3255
|
-
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3256
|
-
},
|
|
3257
|
-
[Primitive.Erfc]([x], [dx]) {
|
|
3258
|
-
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3259
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3260
|
-
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3261
|
-
},
|
|
3262
|
-
[Primitive.Sqrt]([x], [dx]) {
|
|
3263
|
-
const z = sqrt$1(x);
|
|
3264
|
-
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3265
|
-
},
|
|
3266
|
-
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3267
|
-
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3268
|
-
},
|
|
3269
|
-
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3270
|
-
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3271
|
-
},
|
|
3272
|
-
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3273
|
-
if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3274
|
-
else if (op === require_backend.AluOp.Mul) {
|
|
3275
|
-
const primal = reduce(x.ref, op, axis);
|
|
3276
|
-
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3277
|
-
return [[primal], [tangent]];
|
|
3278
|
-
} else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
|
|
3279
|
-
const primal = reduce(x.ref, op, axis);
|
|
3280
|
-
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3281
|
-
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3282
|
-
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3283
|
-
return [[primal], [tangent]];
|
|
3284
|
-
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3285
|
-
},
|
|
3286
|
-
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3287
|
-
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3288
|
-
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3289
|
-
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3290
|
-
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3291
|
-
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3292
|
-
dcond.dispose();
|
|
3293
|
-
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3294
|
-
},
|
|
3295
|
-
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3296
|
-
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3297
|
-
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3298
|
-
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3299
|
-
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3300
|
-
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3301
|
-
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3302
|
-
const indicesRef = indices.map((t) => t.ref);
|
|
3303
|
-
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3304
|
-
},
|
|
3305
|
-
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3306
|
-
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3307
|
-
const outs = bind(Primitive.JitCall, [
|
|
3308
|
-
...newConsts.map((c) => c.ref),
|
|
3309
|
-
...primals,
|
|
3310
|
-
...tangents
|
|
3311
|
-
], {
|
|
3312
|
-
name: `${name}_jvp`,
|
|
3313
|
-
jaxpr: newJaxpr,
|
|
3314
|
-
numConsts: newConsts.length
|
|
3315
|
-
});
|
|
3316
|
-
const n = outs.length / 2;
|
|
3317
|
-
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3318
|
-
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3319
|
-
return [primalsOut, tangentsOut];
|
|
3320
|
-
}
|
|
3321
|
-
};
|
|
3322
|
-
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3323
|
-
function jvpJaxpr(jaxpr) {
|
|
3324
|
-
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3325
|
-
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3326
|
-
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3327
|
-
const result = {
|
|
3328
|
-
newJaxpr,
|
|
3329
|
-
newConsts
|
|
3330
|
-
};
|
|
3331
|
-
jvpJaxprCache.set(jaxpr, result);
|
|
3332
|
-
return result;
|
|
3333
|
-
}
|
|
3334
|
-
function jvpFlat(f, primals, tangents) {
|
|
3335
|
-
try {
|
|
3336
|
-
var _usingCtx$1 = (0, import_usingCtx$1.default)();
|
|
3337
|
-
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3338
|
-
const trace$1 = new JVPTrace(main);
|
|
3339
|
-
const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3340
|
-
const outs = f(...tracersIn);
|
|
3341
|
-
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3342
|
-
return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3343
|
-
} catch (_) {
|
|
3344
|
-
_usingCtx$1.e = _;
|
|
3345
|
-
} finally {
|
|
3346
|
-
_usingCtx$1.d();
|
|
3347
|
-
}
|
|
3348
|
-
}
|
|
3349
|
-
function jvp$1(f, primals, tangents) {
|
|
3350
|
-
const [primalsFlat, inTree] = flatten(primals);
|
|
3351
|
-
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
3352
|
-
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
3353
|
-
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
3354
|
-
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
3355
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
3356
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3357
|
-
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
3358
|
-
return [primalsOut, tangentsOut];
|
|
3359
|
-
}
|
|
3360
|
-
|
|
3361
|
-
//#endregion
|
|
3362
|
-
//#region src/frontend/vmap.ts
|
|
3363
|
-
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3364
|
-
function mappedAval(batchDim, aval) {
|
|
3365
|
-
const shape$1 = [...aval.shape];
|
|
3366
|
-
shape$1.splice(batchDim, 1);
|
|
3367
|
-
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3368
|
-
}
|
|
3369
|
-
/** Move one axis to a different index. */
|
|
3370
|
-
function moveaxis(x, src, dst) {
|
|
3371
|
-
const t = pureArray(x);
|
|
3372
|
-
src = require_backend.checkAxis(src, t.ndim);
|
|
3373
|
-
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3374
|
-
if (src === dst) return t;
|
|
3375
|
-
const perm = require_backend.range(t.ndim);
|
|
3376
|
-
perm.splice(src, 1);
|
|
3377
|
-
perm.splice(dst, 0, src);
|
|
3378
|
-
return transpose$1(t, perm);
|
|
3379
|
-
}
|
|
3380
|
-
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3381
|
-
if (src === null) {
|
|
3382
|
-
const targetShape = [...x.shape];
|
|
3383
|
-
targetShape.splice(dst, 0, axisSize);
|
|
3384
|
-
return broadcast(x, targetShape, [dst]);
|
|
3385
|
-
} else if (src === dst) return x;
|
|
3386
|
-
else return moveaxis(x, src, dst);
|
|
3387
|
-
}
|
|
3388
|
-
var BatchTracer = class extends Tracer {
|
|
3389
|
-
constructor(trace$1, val, batchDim) {
|
|
3390
|
-
super(trace$1);
|
|
3391
|
-
this.val = val;
|
|
3392
|
-
this.batchDim = batchDim;
|
|
3393
|
-
}
|
|
3394
|
-
get aval() {
|
|
3395
|
-
if (this.batchDim === null) return this.val.aval;
|
|
3396
|
-
else return mappedAval(this.batchDim, this.val.aval);
|
|
3397
|
-
}
|
|
3398
|
-
toString() {
|
|
3399
|
-
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3400
|
-
}
|
|
3401
|
-
get ref() {
|
|
3402
|
-
this.val.ref;
|
|
3403
|
-
return this;
|
|
3404
|
-
}
|
|
3405
|
-
dispose() {
|
|
3406
|
-
this.val.dispose();
|
|
3407
|
-
}
|
|
3408
|
-
fullLower() {
|
|
3409
|
-
if (this.batchDim === null) return this.val.fullLower();
|
|
3410
|
-
else return this;
|
|
3411
|
-
}
|
|
3412
|
-
};
|
|
3413
|
-
var BatchTrace = class extends Trace {
|
|
3414
|
-
pure(val) {
|
|
3415
|
-
return this.lift(pureArray(val));
|
|
3416
|
-
}
|
|
3417
|
-
lift(val) {
|
|
3418
|
-
return new BatchTracer(this, val, null);
|
|
3419
|
-
}
|
|
3420
|
-
processPrimitive(primitive, tracers, params) {
|
|
3421
|
-
const [valsIn, bdimsIn] = require_backend.unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3422
|
-
const vmapRule = vmapRules[primitive];
|
|
3423
|
-
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3424
|
-
if (bdimsIn.every((d) => d === null)) {
|
|
3425
|
-
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3426
|
-
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3427
|
-
}
|
|
3428
|
-
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3429
|
-
return require_backend.zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3430
|
-
}
|
|
3431
|
-
get axisSize() {
|
|
3432
|
-
return this.main.globalData;
|
|
3433
|
-
}
|
|
3434
|
-
};
|
|
3435
|
-
/**
|
|
3436
|
-
* Process a primitive with built-in broadcasting.
|
|
3437
|
-
*
|
|
3438
|
-
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3439
|
-
*/
|
|
3440
|
-
function broadcastBatcher(op) {
|
|
3441
|
-
return (axisSize, args, dims) => {
|
|
3442
|
-
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3443
|
-
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3444
|
-
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3445
|
-
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3446
|
-
if (require_backend.zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3447
|
-
args = args.map((x, i) => {
|
|
3448
|
-
if (dims[i] === null) return x;
|
|
3449
|
-
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3450
|
-
if (x.ndim < nd) x = x.reshape([
|
|
3451
|
-
x.shape[0],
|
|
3452
|
-
...require_backend.rep(nd - x.ndim, 1),
|
|
3453
|
-
...x.shape.slice(1)
|
|
3454
|
-
]);
|
|
3455
|
-
return x;
|
|
3456
|
-
});
|
|
3457
|
-
return [[op(...args)], [0]];
|
|
3458
|
-
};
|
|
3459
|
-
}
|
|
3460
|
-
function unopBatcher(op) {
|
|
3461
|
-
return (axisSize, [x], [xBdim], params) => {
|
|
3462
|
-
return [[op(x, params)], [xBdim]];
|
|
3558
|
+
function unopBatcher(op) {
|
|
3559
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3560
|
+
return [[op(x, params)], [xBdim]];
|
|
3463
3561
|
};
|
|
3464
3562
|
}
|
|
3465
3563
|
const vmapRules = {
|
|
@@ -3467,6 +3565,8 @@ const vmapRules = {
|
|
|
3467
3565
|
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3468
3566
|
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3469
3567
|
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3568
|
+
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3569
|
+
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3470
3570
|
[Primitive.Neg]: unopBatcher(neg),
|
|
3471
3571
|
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3472
3572
|
[Primitive.Floor]: unopBatcher(floor$1),
|
|
@@ -3483,8 +3583,6 @@ const vmapRules = {
|
|
|
3483
3583
|
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3484
3584
|
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3485
3585
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3486
|
-
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3487
|
-
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3488
3586
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3489
3587
|
require_backend.assertNonNull(xBdim);
|
|
3490
3588
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
@@ -3497,10 +3595,49 @@ const vmapRules = {
|
|
|
3497
3595
|
const z = dot$2(x, y);
|
|
3498
3596
|
return [[z], [z.ndim - 1]];
|
|
3499
3597
|
},
|
|
3598
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3599
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3600
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3601
|
+
const z = conv$1(x, y, {
|
|
3602
|
+
...params,
|
|
3603
|
+
vmapDims: params.vmapDims + 1
|
|
3604
|
+
});
|
|
3605
|
+
return [[z], [0]];
|
|
3606
|
+
},
|
|
3500
3607
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3501
3608
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3502
3609
|
},
|
|
3503
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3610
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3611
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3612
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3613
|
+
require_backend.assertNonNull(xBdim);
|
|
3614
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3615
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3616
|
+
let newOutDim = outDim;
|
|
3617
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3618
|
+
else newOutDim += 1;
|
|
3619
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3620
|
+
}
|
|
3621
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3622
|
+
indices = indices.map((m, i) => {
|
|
3623
|
+
if (indicesBdim[i] === null) return m;
|
|
3624
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3625
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3626
|
+
m.shape[0],
|
|
3627
|
+
...require_backend.rep(nd - m.ndim, 1),
|
|
3628
|
+
...m.shape.slice(1)
|
|
3629
|
+
]);
|
|
3630
|
+
return m;
|
|
3631
|
+
});
|
|
3632
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3633
|
+
else {
|
|
3634
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3635
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3636
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...require_backend.rep(nd - 1, 1)]);
|
|
3637
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3638
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3639
|
+
}
|
|
3640
|
+
},
|
|
3504
3641
|
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3505
3642
|
require_backend.assertNonNull(xBdim);
|
|
3506
3643
|
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
@@ -3532,42 +3669,53 @@ const vmapRules = {
|
|
|
3532
3669
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3533
3670
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3534
3671
|
},
|
|
3535
|
-
[Primitive.
|
|
3536
|
-
|
|
3537
|
-
|
|
3538
|
-
|
|
3539
|
-
|
|
3540
|
-
|
|
3541
|
-
|
|
3542
|
-
|
|
3543
|
-
|
|
3544
|
-
|
|
3545
|
-
|
|
3546
|
-
|
|
3547
|
-
|
|
3548
|
-
|
|
3549
|
-
|
|
3550
|
-
|
|
3551
|
-
|
|
3552
|
-
...
|
|
3672
|
+
[Primitive.Sort](axisSize, [x], [xBdim]) {
|
|
3673
|
+
require_backend.assertNonNull(xBdim);
|
|
3674
|
+
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3675
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3676
|
+
return [[sort$1(x)], [0]];
|
|
3677
|
+
},
|
|
3678
|
+
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3679
|
+
require_backend.assertNonNull(xBdim);
|
|
3680
|
+
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3681
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3682
|
+
return [argsort$1(x), [0, 0]];
|
|
3683
|
+
},
|
|
3684
|
+
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3685
|
+
if (aBdim === null) {
|
|
3686
|
+
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
3687
|
+
const [s, m, n] = b.shape.slice(-3);
|
|
3688
|
+
b = b.reshape([
|
|
3689
|
+
...b.shape.slice(0, -3),
|
|
3690
|
+
s * m,
|
|
3691
|
+
n
|
|
3553
3692
|
]);
|
|
3554
|
-
|
|
3555
|
-
|
|
3556
|
-
|
|
3557
|
-
|
|
3558
|
-
|
|
3559
|
-
|
|
3560
|
-
|
|
3561
|
-
|
|
3562
|
-
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3693
|
+
let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3694
|
+
x$1 = x$1.reshape([
|
|
3695
|
+
...b.shape.slice(0, -2),
|
|
3696
|
+
s,
|
|
3697
|
+
m,
|
|
3698
|
+
n
|
|
3699
|
+
]);
|
|
3700
|
+
return [[x$1], [x$1.ndim - 3]];
|
|
3563
3701
|
}
|
|
3702
|
+
a = moveBatchAxis(axisSize, aBdim, 0, a);
|
|
3703
|
+
b = moveBatchAxis(axisSize, bBdim, 0, b);
|
|
3704
|
+
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3705
|
+
return [[x], [0]];
|
|
3564
3706
|
},
|
|
3565
|
-
[Primitive.
|
|
3566
|
-
|
|
3567
|
-
|
|
3707
|
+
[Primitive.Cholesky](axisSize, [x], [xBdim]) {
|
|
3708
|
+
require_backend.assertNonNull(xBdim);
|
|
3709
|
+
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3710
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3711
|
+
return [[cholesky$2(x)], [0]];
|
|
3712
|
+
},
|
|
3713
|
+
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3714
|
+
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3715
|
+
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
3568
3716
|
name: `${name}_vmap`,
|
|
3569
|
-
jaxpr: newJaxpr,
|
|
3570
|
-
numConsts:
|
|
3717
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3718
|
+
numConsts: newJaxpr.consts.length
|
|
3571
3719
|
});
|
|
3572
3720
|
return [outs, require_backend.rep(outs.length, 0)];
|
|
3573
3721
|
}
|
|
@@ -3583,14 +3731,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3583
3731
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3584
3732
|
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3585
3733
|
});
|
|
3586
|
-
const { jaxpr: newJaxpr
|
|
3587
|
-
const result = {
|
|
3588
|
-
newJaxpr,
|
|
3589
|
-
newConsts
|
|
3590
|
-
};
|
|
3734
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3591
3735
|
if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
3592
|
-
vmapJaxprCache.get(jaxpr).set(cacheKey,
|
|
3593
|
-
return
|
|
3736
|
+
vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
3737
|
+
return newJaxpr;
|
|
3594
3738
|
}
|
|
3595
3739
|
function vmapFlat(f, inAxes, args) {
|
|
3596
3740
|
let axisSize = void 0;
|
|
@@ -3604,7 +3748,7 @@ function vmapFlat(f, inAxes, args) {
|
|
|
3604
3748
|
if (axisSize === void 0) throw new TypeError("vmap requires at least one mapped axis");
|
|
3605
3749
|
let valsOut, bdimsOut;
|
|
3606
3750
|
try {
|
|
3607
|
-
var _usingCtx$1 = (0, import_usingCtx.default)();
|
|
3751
|
+
var _usingCtx$1 = (0, import_usingCtx$1.default)();
|
|
3608
3752
|
const main = _usingCtx$1.u(newMain(BatchTrace, axisSize));
|
|
3609
3753
|
const trace$1 = new BatchTrace(main);
|
|
3610
3754
|
const tracersIn = args.map((x, i) => inAxes[i] === null ? pureArray(x) : new BatchTracer(trace$1, pureArray(x), inAxes[i]));
|
|
@@ -3645,6 +3789,261 @@ function jacfwd$1(f) {
|
|
|
3645
3789
|
};
|
|
3646
3790
|
}
|
|
3647
3791
|
|
|
3792
|
+
//#endregion
|
|
3793
|
+
//#region src/frontend/jvp.ts
|
|
3794
|
+
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
3795
|
+
var JVPTracer = class extends Tracer {
|
|
3796
|
+
constructor(trace$1, primal, tangent) {
|
|
3797
|
+
super(trace$1);
|
|
3798
|
+
this.primal = primal;
|
|
3799
|
+
this.tangent = tangent;
|
|
3800
|
+
}
|
|
3801
|
+
get aval() {
|
|
3802
|
+
return this.primal.aval;
|
|
3803
|
+
}
|
|
3804
|
+
toString() {
|
|
3805
|
+
return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
|
|
3806
|
+
}
|
|
3807
|
+
get ref() {
|
|
3808
|
+
this.primal.ref, this.tangent.ref;
|
|
3809
|
+
return this;
|
|
3810
|
+
}
|
|
3811
|
+
dispose() {
|
|
3812
|
+
this.primal.dispose();
|
|
3813
|
+
this.tangent.dispose();
|
|
3814
|
+
}
|
|
3815
|
+
};
|
|
3816
|
+
var JVPTrace = class extends Trace {
|
|
3817
|
+
pure(val) {
|
|
3818
|
+
return this.lift(pureArray(val));
|
|
3819
|
+
}
|
|
3820
|
+
lift(val) {
|
|
3821
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
3822
|
+
}
|
|
3823
|
+
processPrimitive(primitive, tracers, params) {
|
|
3824
|
+
const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
3825
|
+
const jvpRule = jvpRules[primitive];
|
|
3826
|
+
if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
|
|
3827
|
+
const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
|
|
3828
|
+
return require_backend.zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
|
|
3829
|
+
}
|
|
3830
|
+
};
|
|
3831
|
+
/** Rule that applies the same operation to primals and tangents. */
|
|
3832
|
+
function linearTangentsJvp(primitive) {
|
|
3833
|
+
return (primals, tangents, params) => {
|
|
3834
|
+
const ys = bind(primitive, primals, params);
|
|
3835
|
+
const dys = bind(primitive, tangents, params);
|
|
3836
|
+
return [ys, dys];
|
|
3837
|
+
};
|
|
3838
|
+
}
|
|
3839
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
3840
|
+
function bilinearTangentsJvp(primitive) {
|
|
3841
|
+
return ([x, y], [dx, dy], params) => {
|
|
3842
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3843
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3844
|
+
return [[primal], [tangent]];
|
|
3845
|
+
};
|
|
3846
|
+
}
|
|
3847
|
+
/** Rule that zeros out any tangents. */
|
|
3848
|
+
function zeroTangentsJvp(primitive) {
|
|
3849
|
+
return (primals, tangents, params) => {
|
|
3850
|
+
for (const t of tangents) t.dispose();
|
|
3851
|
+
const ys = bind(primitive, primals, params);
|
|
3852
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3853
|
+
};
|
|
3854
|
+
}
|
|
3855
|
+
/** Compute `a @ b.T`, batched to last two axes. */
|
|
3856
|
+
function batchMatmulT(a, b) {
|
|
3857
|
+
return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
|
|
3858
|
+
}
|
|
3859
|
+
/** Batch matrix transpose. */
|
|
3860
|
+
function mT(a) {
|
|
3861
|
+
return moveaxis(a, -2, -1);
|
|
3862
|
+
}
|
|
3863
|
+
const jvpRules = {
|
|
3864
|
+
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3865
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
3866
|
+
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
3867
|
+
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
3868
|
+
if (!require_backend.isFloatDtype(x.dtype) && !require_backend.isFloatDtype(y.dtype)) {
|
|
3869
|
+
dx.dispose();
|
|
3870
|
+
dy.dispose();
|
|
3871
|
+
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
3872
|
+
}
|
|
3873
|
+
const q = idiv(x.ref, y.ref);
|
|
3874
|
+
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
3875
|
+
},
|
|
3876
|
+
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3877
|
+
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3878
|
+
},
|
|
3879
|
+
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3880
|
+
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3881
|
+
},
|
|
3882
|
+
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
3883
|
+
[Primitive.Reciprocal]([x], [dx]) {
|
|
3884
|
+
const xRecip = reciprocal$1(x.ref);
|
|
3885
|
+
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
3886
|
+
},
|
|
3887
|
+
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
3888
|
+
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
3889
|
+
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
3890
|
+
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
3891
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3892
|
+
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
3893
|
+
else {
|
|
3894
|
+
dx.dispose();
|
|
3895
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3896
|
+
}
|
|
3897
|
+
},
|
|
3898
|
+
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3899
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3900
|
+
dx.dispose();
|
|
3901
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3902
|
+
},
|
|
3903
|
+
[Primitive.Sin]([x], [dx]) {
|
|
3904
|
+
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3905
|
+
},
|
|
3906
|
+
[Primitive.Cos]([x], [dx]) {
|
|
3907
|
+
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
3908
|
+
},
|
|
3909
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3910
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3911
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3912
|
+
},
|
|
3913
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3914
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3915
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3916
|
+
},
|
|
3917
|
+
[Primitive.Exp]([x], [dx]) {
|
|
3918
|
+
const z = exp$1(x);
|
|
3919
|
+
return [[z.ref], [z.mul(dx)]];
|
|
3920
|
+
},
|
|
3921
|
+
[Primitive.Log]([x], [dx]) {
|
|
3922
|
+
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3923
|
+
},
|
|
3924
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3925
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3926
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3927
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3928
|
+
},
|
|
3929
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3930
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3931
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3932
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3933
|
+
},
|
|
3934
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
3935
|
+
const z = sqrt$1(x);
|
|
3936
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3937
|
+
},
|
|
3938
|
+
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3939
|
+
if (op === require_backend.AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3940
|
+
else if (op === require_backend.AluOp.Mul) {
|
|
3941
|
+
const primal = reduce(x.ref, op, axis);
|
|
3942
|
+
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3943
|
+
return [[primal], [tangent]];
|
|
3944
|
+
} else if (op === require_backend.AluOp.Min || op === require_backend.AluOp.Max) {
|
|
3945
|
+
const primal = reduce(x.ref, op, axis);
|
|
3946
|
+
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3947
|
+
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3948
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3949
|
+
return [[primal], [tangent]];
|
|
3950
|
+
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3951
|
+
},
|
|
3952
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3953
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3954
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3955
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3956
|
+
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3957
|
+
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3958
|
+
dcond.dispose();
|
|
3959
|
+
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3960
|
+
},
|
|
3961
|
+
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3962
|
+
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3963
|
+
const indicesRef = indices.map((t) => t.ref);
|
|
3964
|
+
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3965
|
+
},
|
|
3966
|
+
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3967
|
+
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3968
|
+
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3969
|
+
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3970
|
+
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3971
|
+
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3972
|
+
[Primitive.Sort]([x], [dx]) {
|
|
3973
|
+
const [y, idx] = argsort$1(x);
|
|
3974
|
+
return [[y], [gather(dx, [idx], [-1], -1)]];
|
|
3975
|
+
},
|
|
3976
|
+
[Primitive.Argsort]([x], [dx]) {
|
|
3977
|
+
const [y, idx] = argsort$1(x);
|
|
3978
|
+
return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
|
|
3979
|
+
},
|
|
3980
|
+
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
3981
|
+
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
3982
|
+
const dax = batchMatmulT(da, x.ref);
|
|
3983
|
+
const rhsT = db.sub(mT(dax));
|
|
3984
|
+
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
3985
|
+
return [[x], [dx]];
|
|
3986
|
+
},
|
|
3987
|
+
[Primitive.Cholesky]([a], [da]) {
|
|
3988
|
+
const L = cholesky$2(a.ref);
|
|
3989
|
+
da = da.ref.add(mT(da)).mul(.5);
|
|
3990
|
+
const W = triangularSolve$1(L.ref, da, { lower: true });
|
|
3991
|
+
const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
|
|
3992
|
+
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3993
|
+
return [[L], [dL]];
|
|
3994
|
+
},
|
|
3995
|
+
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3996
|
+
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3997
|
+
const outs = bind(Primitive.Jit, [
|
|
3998
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
3999
|
+
...primals,
|
|
4000
|
+
...tangents
|
|
4001
|
+
], {
|
|
4002
|
+
name: `${name}_jvp`,
|
|
4003
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4004
|
+
numConsts: newJaxpr.consts.length
|
|
4005
|
+
});
|
|
4006
|
+
const n = outs.length / 2;
|
|
4007
|
+
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
4008
|
+
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
4009
|
+
return [primalsOut, tangentsOut];
|
|
4010
|
+
}
|
|
4011
|
+
};
|
|
4012
|
+
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
4013
|
+
function jvpJaxpr(jaxpr) {
|
|
4014
|
+
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
4015
|
+
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
4016
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
4017
|
+
jvpJaxprCache.set(jaxpr, newJaxpr);
|
|
4018
|
+
return newJaxpr;
|
|
4019
|
+
}
|
|
4020
|
+
function jvpFlat(f, primals, tangents) {
|
|
4021
|
+
try {
|
|
4022
|
+
var _usingCtx$1 = (0, import_usingCtx.default)();
|
|
4023
|
+
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
4024
|
+
const trace$1 = new JVPTrace(main);
|
|
4025
|
+
const tracersIn = require_backend.zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
4026
|
+
const outs = f(...tracersIn);
|
|
4027
|
+
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
4028
|
+
return require_backend.unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
4029
|
+
} catch (_) {
|
|
4030
|
+
_usingCtx$1.e = _;
|
|
4031
|
+
} finally {
|
|
4032
|
+
_usingCtx$1.d();
|
|
4033
|
+
}
|
|
4034
|
+
}
|
|
4035
|
+
function jvp$1(f, primals, tangents) {
|
|
4036
|
+
const [primalsFlat, inTree] = flatten(primals);
|
|
4037
|
+
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4038
|
+
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4039
|
+
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
4040
|
+
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4041
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4042
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4043
|
+
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4044
|
+
return [primalsOut, tangentsOut];
|
|
4045
|
+
}
|
|
4046
|
+
|
|
3648
4047
|
//#endregion
|
|
3649
4048
|
//#region src/frontend/linearize.ts
|
|
3650
4049
|
/** Array value that can either be known or unknown. */
|
|
@@ -3675,11 +4074,10 @@ function partialEvalFlat(f, pvalsIn) {
|
|
|
3675
4074
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3676
4075
|
const pvalsOut = tracersOut.map((t) => t.pval);
|
|
3677
4076
|
const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
|
|
3678
|
-
const
|
|
4077
|
+
const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
|
|
3679
4078
|
return {
|
|
3680
4079
|
jaxpr,
|
|
3681
|
-
pvalsOut
|
|
3682
|
-
consts
|
|
4080
|
+
pvalsOut
|
|
3683
4081
|
};
|
|
3684
4082
|
}
|
|
3685
4083
|
/**
|
|
@@ -3696,22 +4094,19 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3696
4094
|
const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
|
|
3697
4095
|
return [...primalsOut$1, ...tangentsOut];
|
|
3698
4096
|
};
|
|
3699
|
-
const { jaxpr, pvalsOut
|
|
4097
|
+
const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
|
|
3700
4098
|
const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
|
|
3701
4099
|
if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
|
|
3702
4100
|
const primalsOut = primalPvals.map((pval) => pval.val);
|
|
3703
4101
|
return {
|
|
3704
4102
|
primalsOut,
|
|
3705
|
-
jaxpr
|
|
3706
|
-
consts
|
|
4103
|
+
jaxpr
|
|
3707
4104
|
};
|
|
3708
4105
|
}
|
|
3709
4106
|
function linearizeFlat(f, primalsIn) {
|
|
3710
|
-
const { primalsOut, jaxpr
|
|
3711
|
-
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3712
|
-
const dispose$1 = () =>
|
|
3713
|
-
for (const c of consts) c.dispose();
|
|
3714
|
-
};
|
|
4107
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4108
|
+
const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
|
|
4109
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
3715
4110
|
return [
|
|
3716
4111
|
primalsOut,
|
|
3717
4112
|
fLin,
|
|
@@ -3795,7 +4190,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3795
4190
|
}
|
|
3796
4191
|
processPrimitive(primitive, tracers, params) {
|
|
3797
4192
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3798
|
-
if (primitive === Primitive.
|
|
4193
|
+
if (primitive === Primitive.Jit) {
|
|
3799
4194
|
const { name, jaxpr, numConsts } = params;
|
|
3800
4195
|
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3801
4196
|
}
|
|
@@ -3821,14 +4216,14 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3821
4216
|
* Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
|
|
3822
4217
|
* values as possible (with JIT) and forwarding the unknown ones.
|
|
3823
4218
|
*
|
|
3824
|
-
* Used when encountering a
|
|
4219
|
+
* Used when encountering a Jit rule during the trace.
|
|
3825
4220
|
*/
|
|
3826
4221
|
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3827
4222
|
jaxpr = jaxpr.flatten();
|
|
3828
4223
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3829
4224
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3830
4225
|
const [knownTracers, unknownTracers] = require_backend.partitionList(inUnknowns, tracers);
|
|
3831
|
-
const outs1Res = bind(Primitive.
|
|
4226
|
+
const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3832
4227
|
name: `${name}_peval`,
|
|
3833
4228
|
jaxpr: jaxpr1,
|
|
3834
4229
|
numConsts: 0
|
|
@@ -3838,7 +4233,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3838
4233
|
const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
|
|
3839
4234
|
const recipe = {
|
|
3840
4235
|
type: "JaxprEqn",
|
|
3841
|
-
prim: Primitive.
|
|
4236
|
+
prim: Primitive.Jit,
|
|
3842
4237
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3843
4238
|
params: {
|
|
3844
4239
|
name: `${name}_resid`,
|
|
@@ -3867,7 +4262,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
|
|
|
3867
4262
|
const eqns1 = [];
|
|
3868
4263
|
const eqns2 = [];
|
|
3869
4264
|
for (const eqn of jaxpr.eqns) {
|
|
3870
|
-
if (eqn.primitive === Primitive.
|
|
4265
|
+
if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
|
|
3871
4266
|
const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
|
|
3872
4267
|
if (hasUnknowns) {
|
|
3873
4268
|
for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
|
|
@@ -3941,11 +4336,8 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3941
4336
|
for (const t of tracersIn) t.dispose();
|
|
3942
4337
|
for (const t of tracersOut) t.dispose();
|
|
3943
4338
|
jaxpr = jaxpr.simplify();
|
|
3944
|
-
if (require_backend.DEBUG >= 5) console.
|
|
3945
|
-
return
|
|
3946
|
-
jaxpr,
|
|
3947
|
-
consts
|
|
3948
|
-
};
|
|
4339
|
+
if (require_backend.DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
4340
|
+
return new ClosedJaxpr(jaxpr, consts);
|
|
3949
4341
|
}
|
|
3950
4342
|
/** Marker type for pullback, used by transpose rules. */
|
|
3951
4343
|
var UndefPrimal = class {
|
|
@@ -4075,22 +4467,25 @@ const transposeRules = {
|
|
|
4075
4467
|
},
|
|
4076
4468
|
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
4077
4469
|
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
4470
|
+
const v = params.vmapDims;
|
|
4078
4471
|
const rev01 = [
|
|
4079
|
-
|
|
4080
|
-
|
|
4081
|
-
|
|
4472
|
+
...require_backend.range(v),
|
|
4473
|
+
v + 1,
|
|
4474
|
+
v,
|
|
4475
|
+
...require_backend.range(v + 2, ct.ndim)
|
|
4082
4476
|
];
|
|
4083
4477
|
if (lhs instanceof UndefPrimal) {
|
|
4084
4478
|
let kernel = rhs;
|
|
4085
4479
|
kernel = transpose$1(kernel, rev01);
|
|
4086
|
-
kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
|
|
4480
|
+
kernel = flip$1(kernel, require_backend.range(v + 2, kernel.ndim));
|
|
4087
4481
|
const result = conv$1(ct, kernel, {
|
|
4482
|
+
vmapDims: v,
|
|
4088
4483
|
strides: params.lhsDilation,
|
|
4089
4484
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4090
|
-
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4091
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4485
|
+
const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4486
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4092
4487
|
const padBefore = dilatedKernel - 1 - pl;
|
|
4093
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4488
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4094
4489
|
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
4095
4490
|
return [padBefore, padAfter];
|
|
4096
4491
|
}),
|
|
@@ -4102,11 +4497,12 @@ const transposeRules = {
|
|
|
4102
4497
|
const newLhs = transpose$1(lhs, rev01);
|
|
4103
4498
|
const newRhs = transpose$1(ct, rev01);
|
|
4104
4499
|
let result = conv$1(newLhs, newRhs, {
|
|
4500
|
+
vmapDims: v,
|
|
4105
4501
|
strides: params.rhsDilation,
|
|
4106
4502
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4107
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4108
|
-
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4109
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4503
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4504
|
+
const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4505
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4110
4506
|
const padFromLhs = dilatedCt - dilatedLhs;
|
|
4111
4507
|
const padFromRhs = dilatedKernel - pl - 1;
|
|
4112
4508
|
return [pl, padFromLhs + padFromRhs];
|
|
@@ -4133,6 +4529,11 @@ const transposeRules = {
|
|
|
4133
4529
|
cond.dispose();
|
|
4134
4530
|
return cts;
|
|
4135
4531
|
},
|
|
4532
|
+
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4533
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4534
|
+
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4535
|
+
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4536
|
+
},
|
|
4136
4537
|
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4137
4538
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4138
4539
|
return [transpose$1(ct, require_backend.invertPermutation(perm))];
|
|
@@ -4159,23 +4560,26 @@ const transposeRules = {
|
|
|
4159
4560
|
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4160
4561
|
return [shrink(ct, slice)];
|
|
4161
4562
|
},
|
|
4162
|
-
[Primitive.
|
|
4163
|
-
if (!(
|
|
4164
|
-
|
|
4165
|
-
|
|
4563
|
+
[Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
|
|
4564
|
+
if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
|
|
4565
|
+
const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
|
|
4566
|
+
lower: true,
|
|
4567
|
+
unitDiagonal
|
|
4568
|
+
});
|
|
4569
|
+
return [null, ctB];
|
|
4166
4570
|
},
|
|
4167
|
-
[Primitive.
|
|
4571
|
+
[Primitive.Jit](cts, args, { name, jaxpr }) {
|
|
4168
4572
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
4169
|
-
const
|
|
4573
|
+
const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
|
|
4170
4574
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
4171
|
-
const outs = bind(Primitive.
|
|
4172
|
-
...
|
|
4575
|
+
const outs = bind(Primitive.Jit, [
|
|
4576
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4173
4577
|
...residuals,
|
|
4174
4578
|
...cts
|
|
4175
4579
|
], {
|
|
4176
4580
|
name: `${name}_t`,
|
|
4177
|
-
jaxpr: newJaxpr,
|
|
4178
|
-
numConsts:
|
|
4581
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4582
|
+
numConsts: newJaxpr.consts.length
|
|
4179
4583
|
});
|
|
4180
4584
|
let i = 0;
|
|
4181
4585
|
return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
|
|
@@ -4188,31 +4592,25 @@ function transposeJaxpr(jaxpr, undefPrimals) {
|
|
|
4188
4592
|
if (prevResult) return prevResult;
|
|
4189
4593
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
4190
4594
|
const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
|
|
4191
|
-
const { jaxpr: newJaxpr
|
|
4595
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
|
|
4192
4596
|
const args = [];
|
|
4193
4597
|
let forwardInIdx = 0;
|
|
4194
4598
|
for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
|
|
4195
4599
|
else args.push(forwardIn[forwardInIdx++]);
|
|
4196
4600
|
return evalJaxprTransposed(jaxpr, args, cotangents);
|
|
4197
4601
|
})(forwardInTypes, outTypes);
|
|
4198
|
-
typecheckJaxpr(newJaxpr);
|
|
4199
|
-
const result = {
|
|
4200
|
-
newJaxpr,
|
|
4201
|
-
newConsts
|
|
4202
|
-
};
|
|
4602
|
+
typecheckJaxpr(newJaxpr.jaxpr);
|
|
4203
4603
|
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4204
|
-
transposeJaxprCache.get(jaxpr).set(cacheKey,
|
|
4205
|
-
return
|
|
4604
|
+
transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
4605
|
+
return newJaxpr;
|
|
4206
4606
|
}
|
|
4207
4607
|
function vjpFlat(f, primalsIn) {
|
|
4208
|
-
const { primalsOut, jaxpr
|
|
4608
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4209
4609
|
const fVjp = (...cotangents) => {
|
|
4210
|
-
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4211
|
-
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
4212
|
-
};
|
|
4213
|
-
const dispose$1 = () => {
|
|
4214
|
-
for (const c of consts) c.dispose();
|
|
4610
|
+
const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4611
|
+
return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
|
|
4215
4612
|
};
|
|
4613
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
4216
4614
|
return [
|
|
4217
4615
|
primalsOut,
|
|
4218
4616
|
fVjp,
|
|
@@ -4269,150 +4667,6 @@ function jacrev$1(f) {
|
|
|
4269
4667
|
};
|
|
4270
4668
|
}
|
|
4271
4669
|
|
|
4272
|
-
//#endregion
|
|
4273
|
-
//#region src/library/lax.ts
|
|
4274
|
-
var lax_exports = {};
|
|
4275
|
-
__export(lax_exports, {
|
|
4276
|
-
conv: () => conv,
|
|
4277
|
-
convGeneralDilated: () => convGeneralDilated,
|
|
4278
|
-
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4279
|
-
dot: () => dot$1,
|
|
4280
|
-
erf: () => erf,
|
|
4281
|
-
erfc: () => erfc,
|
|
4282
|
-
reduceWindow: () => reduceWindow,
|
|
4283
|
-
stopGradient: () => stopGradient$1
|
|
4284
|
-
});
|
|
4285
|
-
/**
|
|
4286
|
-
* General dot product/contraction operator.
|
|
4287
|
-
*
|
|
4288
|
-
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
4289
|
-
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
4290
|
-
*/
|
|
4291
|
-
function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
4292
|
-
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
4293
|
-
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
4294
|
-
lc = lc.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
4295
|
-
rc = rc.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
4296
|
-
lb = lb.map((a) => require_backend.checkAxis(a, lhs.ndim));
|
|
4297
|
-
rb = rb.map((a) => require_backend.checkAxis(a, rhs.ndim));
|
|
4298
|
-
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
4299
|
-
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
4300
|
-
const lf = require_backend.range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
4301
|
-
const rf = require_backend.range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
4302
|
-
const lhs2 = lhs.transpose([
|
|
4303
|
-
...lb,
|
|
4304
|
-
...lf,
|
|
4305
|
-
...lc
|
|
4306
|
-
]);
|
|
4307
|
-
const rhs2 = rhs.transpose([
|
|
4308
|
-
...rb,
|
|
4309
|
-
...rf,
|
|
4310
|
-
...rc
|
|
4311
|
-
]);
|
|
4312
|
-
if (lc.length === 0) return mul(lhs2.reshape([
|
|
4313
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4314
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4315
|
-
...require_backend.rep(rf.length, 1)
|
|
4316
|
-
]), rhs2.reshape([
|
|
4317
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4318
|
-
...require_backend.rep(lf.length, 1),
|
|
4319
|
-
...rf.map((a) => rhs.shape[a])
|
|
4320
|
-
]));
|
|
4321
|
-
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
4322
|
-
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
4323
|
-
if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
4324
|
-
return dot$2(lhs2.reshape([
|
|
4325
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4326
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4327
|
-
...require_backend.rep(rf.length, 1),
|
|
4328
|
-
require_backend.prod(dotShapeX)
|
|
4329
|
-
]), rhs2.reshape([
|
|
4330
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4331
|
-
...require_backend.rep(lf.length, 1),
|
|
4332
|
-
...rf.map((a) => rhs.shape[a]),
|
|
4333
|
-
require_backend.prod(dotShapeY)
|
|
4334
|
-
]));
|
|
4335
|
-
}
|
|
4336
|
-
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4337
|
-
const padType = padding.toUpperCase();
|
|
4338
|
-
switch (padType) {
|
|
4339
|
-
case "VALID": return require_backend.rep(inShape.length, [0, 0]);
|
|
4340
|
-
case "SAME":
|
|
4341
|
-
case "SAME_LOWER": {
|
|
4342
|
-
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
4343
|
-
const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
4344
|
-
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
4345
|
-
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
4346
|
-
}
|
|
4347
|
-
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
4348
|
-
}
|
|
4349
|
-
}
|
|
4350
|
-
/**
|
|
4351
|
-
* General n-dimensional convolution operator, with optional dilation.
|
|
4352
|
-
*
|
|
4353
|
-
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
4354
|
-
* function in JAX, which wraps XLA's general convolution operator.
|
|
4355
|
-
*
|
|
4356
|
-
* Grouped convolutions are not supported right now.
|
|
4357
|
-
*/
|
|
4358
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
4359
|
-
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4360
|
-
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4361
|
-
if (typeof padding === "string") {
|
|
4362
|
-
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4363
|
-
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
4364
|
-
}
|
|
4365
|
-
return conv$1(lhs, rhs, {
|
|
4366
|
-
strides: windowStrides,
|
|
4367
|
-
padding,
|
|
4368
|
-
lhsDilation,
|
|
4369
|
-
rhsDilation
|
|
4370
|
-
});
|
|
4371
|
-
}
|
|
4372
|
-
/** Convenience wrapper around `convGeneralDilated`. */
|
|
4373
|
-
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
4374
|
-
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
4375
|
-
lhsDilation,
|
|
4376
|
-
rhsDilation
|
|
4377
|
-
});
|
|
4378
|
-
}
|
|
4379
|
-
/** Convenience wrapper around `convGeneralDilated`. */
|
|
4380
|
-
function conv(lhs, rhs, windowStrides, padding) {
|
|
4381
|
-
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
4382
|
-
}
|
|
4383
|
-
/** Reduce a computation over padded windows. */
|
|
4384
|
-
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
4385
|
-
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
4386
|
-
if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
|
|
4387
|
-
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
4388
|
-
return computation(bind1(Primitive.Pool, [operand], {
|
|
4389
|
-
window: windowDimensions,
|
|
4390
|
-
strides: windowStrides
|
|
4391
|
-
}));
|
|
4392
|
-
}
|
|
4393
|
-
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
4394
|
-
function erf(x) {
|
|
4395
|
-
return erf$1(x);
|
|
4396
|
-
}
|
|
4397
|
-
/**
|
|
4398
|
-
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
4399
|
-
*
|
|
4400
|
-
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
4401
|
-
* where `erf(x)` is very close to 1.
|
|
4402
|
-
*/
|
|
4403
|
-
function erfc(x) {
|
|
4404
|
-
return erfc$1(x);
|
|
4405
|
-
}
|
|
4406
|
-
/**
|
|
4407
|
-
* Stops gradient computation.
|
|
4408
|
-
*
|
|
4409
|
-
* Behaves as the identity function but prevents the flow of gradients during
|
|
4410
|
-
* forward or reverse-mode automatic differentiation.
|
|
4411
|
-
*/
|
|
4412
|
-
function stopGradient$1(x) {
|
|
4413
|
-
return stopGradient(x);
|
|
4414
|
-
}
|
|
4415
|
-
|
|
4416
4670
|
//#endregion
|
|
4417
4671
|
//#region src/library/numpy/einsum.ts
|
|
4418
4672
|
/** Multiply all arguments together as BigInts; with no arguments the result is `1n`. */
const bprod = (...xs) => {
  let product = 1n;
  for (const x of xs) product *= BigInt(x);
  return product;
};
|
|
@@ -4608,34 +4862,207 @@ function* allPaths(tensors, next) {
|
|
|
4608
4862
|
}
|
|
4609
4863
|
}
|
|
4610
4864
|
|
|
4865
|
+
//#endregion
|
|
4866
|
+
//#region src/library/numpy-fft.ts
|
|
4867
|
+
// Namespace object for the FFT routines. `__export` installs each entry as an
// enumerable getter, so `fft`/`ifft` are looked up only when the property is read.
var numpy_fft_exports = {};
__export(numpy_fft_exports, {
  fft: () => fft,
  ifft: () => ifft
});
|
|
4872
|
+
/**
 * Validate a `{ real, imag }` pair handed to a `jax.numpy.fft` routine.
 *
 * Throws when the two parts differ in shape or dtype, or when the dtype is not
 * a float dtype. `name` is the routine name used as the error-message prefix.
 */
function checkPairInput(name, a) {
  const prefix = `jax.numpy.fft.${name}`;
  const re = a.real;
  const im = a.imag;
  const sameShape = require_backend.deepEqual(re.shape, im.shape);
  if (!sameShape) {
    throw new Error(`${prefix}: real and imaginary parts must have the same shape, got ${JSON.stringify(re.shape)} and ${JSON.stringify(im.shape)}`);
  }
  if (re.dtype !== im.dtype) {
    throw new Error(`${prefix}: real and imaginary parts must have the same dtype, got ${re.dtype} and ${im.dtype}`);
  }
  if (!require_backend.isFloatDtype(re.dtype)) {
    throw new Error(`${prefix}: input must have a float dtype, got ${re.dtype}`);
  }
}
|
|
4878
|
+
/**
 * Throw unless `n` is a positive integer power of two (1, 2, 4, 8, ...).
 *
 * The bit trick `(n & (n - 1)) === 0` alone also accepts 0 and, because bitwise
 * operators truncate, fractional values such as 2.5. Both are rejected up front
 * so that callers never compute `Math.log2(0)` or a fractional stage count.
 *
 * @param name routine name used in the error message (`fft` or `ifft`)
 * @param n size of the transformed axis
 * @throws Error if `n` is not a positive power of two
 */
function checkPowerOfTwo(name, n) {
  const isPowerOfTwo = Number.isInteger(n) && n >= 1 && (n & (n - 1)) === 0;
  if (!isPowerOfTwo) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
}
|
|
4881
|
+
/**
 * One butterfly stage of the iterative FFT driven by `fft`.
 *
 * `i` is the stage index (marked static through `staticArgnums: [0]`). The
 * stage views the data as rows of `2 ** (i + 1)` elements and combines the two
 * halves of every row. Takes and returns a `{ real, imag }` pair.
 *
 * NOTE(review): `.ref` is taken on every use of an array except its last one,
 * presumably the library's reference-counting convention. Keep the order of
 * uses intact when editing.
 */
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
  const half = 2 ** i;
  // Each row holds one block of `2 * half` elements: [first half | second half].
  real = real.reshape([-1, 2 * half]);
  imag = imag.reshape([-1, 2 * half]);
  // Twiddle factors w_k = exp(-i * pi * k / half) for k = 0 .. half - 1.
  const k = arange(0, half, 1, { dtype: real.dtype });
  const theta = k.mul(-Math.PI / half);
  const wr = cos(theta.ref);
  const wi = sin(theta);
  // u = first half of every row, v = second half.
  const ur = real.ref.slice([], [0, half]);
  const ui = imag.ref.slice([], [0, half]);
  const vr = real.slice([], [half, 2 * half]);
  const vi = imag.slice([], [half, 2 * half]);
  // t = v * w (complex multiplication written out in real arithmetic).
  const tr = vr.ref.mul(wr.ref).sub(vi.ref.mul(wi.ref));
  const ti = vr.mul(wi).add(vi.mul(wr));
  // Butterfly: the new row is [u + t | u - t].
  return {
    real: concatenate([ur.ref.add(tr.ref), ur.sub(tr)], -1),
    imag: concatenate([ui.ref.add(ti.ref), ui.sub(ti)], -1)
  };
}, { staticArgnums: [0] });
|
|
4900
|
+
/**
 * One-dimensional discrete Fourier transform of a `{ real, imag }` pair along
 * `axis` (the last axis by default).
 *
 * The length of the transformed axis has to be a power of two for now.
 */
function fft(a, axis = -1) {
  checkPairInput("fft", a);
  let { real, imag } = a;
  axis = require_backend.checkAxis(axis, real.ndim);
  const n = real.shape[axis];
  checkPowerOfTwo("fft", n);
  const logN = Math.log2(n);

  // Bring the transformed axis to the end unless it is already there.
  const needsMove = axis !== real.ndim - 1;
  let perm = null;
  if (needsMove) {
    perm = [...require_backend.range(real.ndim).filter((d) => d !== axis), axis];
    real = real.transpose(perm);
    imag = imag.transpose(perm);
  }
  const originalShape = real.shape;

  // Bit-reversal permutation: split the axis into `logN` axes of size 2,
  // reverse their order, and flatten again.
  const bitReverse = (x) => x
    .reshape([-1, ...require_backend.rep(logN, 2)])
    .transpose([0, ...require_backend.range(1, logN + 1).reverse()])
    .flatten();
  real = bitReverse(real);
  imag = bitReverse(imag);

  // Run the butterfly stages from the smallest block size upwards.
  let stage = 0;
  while (stage < logN) {
    ({ real, imag } = fftUpdate(stage, {
      real,
      imag
    }));
    stage++;
  }

  real = real.reshape(originalShape);
  imag = imag.reshape(originalShape);
  if (perm !== null) {
    // Undo the axis move from above.
    const restore = (x) => x.transpose(require_backend.invertPermutation(perm));
    real = restore(real);
    imag = restore(imag);
  }
  return {
    real,
    imag
  };
}
|
|
4938
|
+
/**
 * One-dimensional inverse discrete Fourier transform of a `{ real, imag }` pair
 * along `axis` (the last axis by default).
 *
 * The length of the transformed axis has to be a power of two for now.
 */
function ifft(a, axis = -1) {
  checkPairInput("ifft", a);
  const { real, imag } = a;
  axis = require_backend.checkAxis(axis, real.ndim);
  const n = real.shape[axis];
  checkPowerOfTwo("ifft", n);
  // ifft(z) = conj(fft(conj(z))) / n, so negate the imaginary part before and
  // after the forward transform.
  const conjugated = {
    real,
    imag: imag.mul(-1)
  };
  const forward = fft(conjugated, axis);
  const outReal = forward.real.div(n);
  const outImag = forward.imag.mul(-1).div(n);
  return {
    real: outReal,
    imag: outImag
  };
}
|
|
4959
|
+
|
|
4960
|
+
//#endregion
|
|
4961
|
+
//#region src/library/numpy-linalg.ts
|
|
4962
|
+
// Namespace object for the linear-algebra routines. `__export` installs each
// entry as an enumerable getter. The `cholesky` entry resolves to the
// `cholesky$1` wrapper, not to the lower-level `cholesky` that the wrapper calls.
var numpy_linalg_exports = {};
__export(numpy_linalg_exports, {
  cholesky: () => cholesky$1,
  diagonal: () => diagonal,
  lstsq: () => lstsq,
  matmul: () => matmul,
  matrixTranspose: () => matrixTranspose,
  outer: () => outer,
  tensordot: () => tensordot,
  trace: () => trace,
  vecdot: () => vecdot
});
|
|
4974
|
+
/**
 * Cholesky decomposition of a (batched) positive-definite matrix.
 *
 * Same as `jax.lax.linalg.cholesky()`, plus a `symmetrizeInput` option (on by
 * default) that replaces the input with `(a + aᵀ) / 2` first.
 */
function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
  a = fudgeArray(a);
  const rank = a.ndim;
  const isSquareMatrix = rank >= 2 && a.shape[rank - 1] === a.shape[rank - 2];
  if (!isSquareMatrix) throw new Error(`cholesky: input must be at least 2D square matrix, got ${a.aval}`);
  if (symmetrizeInput) {
    // Take the extra reference before `matrixTranspose` uses `a`.
    const kept = a.ref;
    const transposed = matrixTranspose(a);
    a = kept.add(transposed).mul(.5);
  }
  return cholesky(a, { upper });
}
|
|
4986
|
+
/**
 * Return the least-squares solution to a linear equation.
 *
 * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
 * For underdetermined systems, this finds the minimum-norm solution for `x`.
 *
 * This currently uses Cholesky decomposition to solve the normal equations
 * under the hood. The method is not as robust as QR or SVD.
 *
 * @param a coefficient matrix of shape `(M, N)`
 * @param b right-hand side of shape `(M,)` or `(M, K)`
 * @return least-squares solution of shape `(N,)` or `(N, K)`
 * @throws Error if `a` is not 2D or the leading dimension of `b` is not `M`
 */
function lstsq(a, b) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
  const [m, n] = a.shape;
  if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
  const at = matrixTranspose(a.ref);
  if (m <= n) {
    // Underdetermined or square: x = aᵀ (a aᵀ)⁻¹ b, with a aᵀ factored by Cholesky.
    const aat = matmul(a, at.ref);
    const l = cholesky$1(aat, { symmetrizeInput: false });
    const lb = triangularSolve(l.ref, b, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    // This is the last use of `llb`, so it is passed without `.ref`. An extra
    // `.ref` here would leave one reference that nothing ever releases.
    return matmul(at, llb);
  } else {
    // Overdetermined: solve the normal equations (aᵀ a) x = aᵀ b.
    const ata = matmul(at.ref, a);
    const l = cholesky$1(ata, { symmetrizeInput: false });
    const atb = matmul(at, b);
    const lb = triangularSolve(l.ref, atb, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    return llb;
  }
}
|
|
5033
|
+
|
|
4611
5034
|
//#endregion
|
|
4612
5035
|
//#region src/library/numpy.ts
|
|
4613
5036
|
var numpy_exports = {};
|
|
4614
5037
|
__export(numpy_exports, {
|
|
4615
5038
|
Array: () => Array$1,
|
|
4616
5039
|
DType: () => require_backend.DType,
|
|
4617
|
-
abs: () =>
|
|
5040
|
+
abs: () => absolute,
|
|
4618
5041
|
absolute: () => absolute,
|
|
4619
5042
|
acos: () => acos,
|
|
4620
|
-
acosh: () =>
|
|
5043
|
+
acosh: () => arccosh,
|
|
4621
5044
|
add: () => add,
|
|
5045
|
+
all: () => all,
|
|
4622
5046
|
allclose: () => allclose,
|
|
5047
|
+
any: () => any,
|
|
4623
5048
|
arange: () => arange,
|
|
4624
|
-
arccos: () =>
|
|
5049
|
+
arccos: () => acos,
|
|
4625
5050
|
arccosh: () => arccosh,
|
|
5051
|
+
arcsin: () => asin,
|
|
4626
5052
|
arcsinh: () => arcsinh,
|
|
4627
|
-
arctan: () =>
|
|
4628
|
-
arctan2: () =>
|
|
5053
|
+
arctan: () => atan,
|
|
5054
|
+
arctan2: () => atan2,
|
|
4629
5055
|
arctanh: () => arctanh,
|
|
4630
5056
|
argmax: () => argmax,
|
|
4631
5057
|
argmin: () => argmin,
|
|
5058
|
+
argsort: () => argsort,
|
|
4632
5059
|
array: () => array,
|
|
4633
5060
|
asin: () => asin,
|
|
4634
|
-
asinh: () =>
|
|
5061
|
+
asinh: () => arcsinh,
|
|
4635
5062
|
astype: () => astype,
|
|
4636
5063
|
atan: () => atan,
|
|
4637
5064
|
atan2: () => atan2,
|
|
4638
|
-
atanh: () =>
|
|
5065
|
+
atanh: () => arctanh,
|
|
4639
5066
|
bool: () => bool,
|
|
4640
5067
|
broadcastArrays: () => broadcastArrays,
|
|
4641
5068
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -4645,14 +5072,20 @@ __export(numpy_exports, {
|
|
|
4645
5072
|
clip: () => clip,
|
|
4646
5073
|
columnStack: () => columnStack,
|
|
4647
5074
|
concatenate: () => concatenate,
|
|
5075
|
+
convolve: () => convolve,
|
|
5076
|
+
corrcoef: () => corrcoef,
|
|
5077
|
+
correlate: () => correlate,
|
|
4648
5078
|
cos: () => cos,
|
|
4649
5079
|
cosh: () => cosh,
|
|
5080
|
+
cov: () => cov,
|
|
5081
|
+
cumsum: () => cumsum,
|
|
5082
|
+
cumulativeSum: () => cumsum,
|
|
4650
5083
|
deg2rad: () => deg2rad,
|
|
4651
5084
|
degrees: () => degrees,
|
|
4652
5085
|
diag: () => diag,
|
|
4653
5086
|
diagonal: () => diagonal,
|
|
4654
|
-
divide: () =>
|
|
4655
|
-
dot: () => dot,
|
|
5087
|
+
divide: () => trueDivide,
|
|
5088
|
+
dot: () => dot$1,
|
|
4656
5089
|
dstack: () => dstack,
|
|
4657
5090
|
e: () => e,
|
|
4658
5091
|
einsum: () => einsum,
|
|
@@ -4660,8 +5093,10 @@ __export(numpy_exports, {
|
|
|
4660
5093
|
eulerGamma: () => eulerGamma,
|
|
4661
5094
|
exp: () => exp,
|
|
4662
5095
|
exp2: () => exp2,
|
|
5096
|
+
expandDims: () => expandDims,
|
|
4663
5097
|
expm1: () => expm1,
|
|
4664
5098
|
eye: () => eye,
|
|
5099
|
+
fft: () => numpy_fft_exports,
|
|
4665
5100
|
flip: () => flip,
|
|
4666
5101
|
fliplr: () => fliplr,
|
|
4667
5102
|
flipud: () => flipud,
|
|
@@ -4692,12 +5127,14 @@ __export(numpy_exports, {
|
|
|
4692
5127
|
ldexp: () => ldexp,
|
|
4693
5128
|
less: () => less,
|
|
4694
5129
|
lessEqual: () => lessEqual,
|
|
5130
|
+
linalg: () => numpy_linalg_exports,
|
|
4695
5131
|
linspace: () => linspace,
|
|
4696
5132
|
log: () => log,
|
|
4697
5133
|
log10: () => log10,
|
|
4698
5134
|
log1p: () => log1p,
|
|
4699
5135
|
log2: () => log2,
|
|
4700
5136
|
matmul: () => matmul,
|
|
5137
|
+
matrixTranspose: () => matrixTranspose,
|
|
4701
5138
|
max: () => max,
|
|
4702
5139
|
maximum: () => maximum,
|
|
4703
5140
|
mean: () => mean,
|
|
@@ -4714,10 +5151,10 @@ __export(numpy_exports, {
|
|
|
4714
5151
|
onesLike: () => onesLike,
|
|
4715
5152
|
outer: () => outer,
|
|
4716
5153
|
pad: () => pad,
|
|
4717
|
-
permuteDims: () =>
|
|
5154
|
+
permuteDims: () => transpose,
|
|
4718
5155
|
pi: () => pi,
|
|
4719
5156
|
positive: () => positive,
|
|
4720
|
-
pow: () =>
|
|
5157
|
+
pow: () => power,
|
|
4721
5158
|
power: () => power,
|
|
4722
5159
|
prod: () => prod$1,
|
|
4723
5160
|
promoteTypes: () => require_backend.promoteTypes,
|
|
@@ -4734,6 +5171,7 @@ __export(numpy_exports, {
|
|
|
4734
5171
|
sin: () => sin,
|
|
4735
5172
|
sinh: () => sinh,
|
|
4736
5173
|
size: () => size,
|
|
5174
|
+
sort: () => sort,
|
|
4737
5175
|
sqrt: () => sqrt,
|
|
4738
5176
|
square: () => square,
|
|
4739
5177
|
squeeze: () => squeeze,
|
|
@@ -4898,6 +5336,26 @@ function min(a, axis = null, opts) {
|
|
|
4898
5336
|
function max(a, axis = null, opts) {
|
|
4899
5337
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
4900
5338
|
}
|
|
5339
|
+
/**
 * Check whether every element along the given axis is truthy.
 *
 * The input is cast to bool and reduced with `min`, so the reduced axis is
 * removed from the result; with `axis = null` the result is a scalar.
 */
function all(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(require_backend.DType.Bool);
  return min(asBool, axis, opts);
}
|
|
5349
|
+
/**
 * Check whether at least one element along the given axis is truthy.
 *
 * The input is cast to bool and reduced with `max`, so the reduced axis is
 * removed from the result; with `axis = null` the result is a scalar.
 */
function any(a, axis = null, opts) {
  const asBool = fudgeArray(a).astype(require_backend.DType.Bool);
  return max(asBool, axis, opts);
}
|
|
4901
5359
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
4902
5360
|
function ptp(a, axis = null, opts) {
|
|
4903
5361
|
a = fudgeArray(a);
|
|
@@ -4955,6 +5413,23 @@ function argmax(a, axis, opts) {
|
|
|
4955
5413
|
}).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
|
|
4956
5414
|
return length.sub(max(idx, axis, opts));
|
|
4957
5415
|
}
|
|
5416
|
+
/**
 * Running sum of the elements along `axis`. When no axis is given, the input
 * is flattened first.
 *
 * The current implementation is `O(n^2)`: it broadcasts the axis against
 * itself, keeps the lower triangle, and sums. A two-phase parallel reduction
 * is planned as a replacement.
 */
function cumsum(a, axis) {
  let arr = fudgeArray(a);
  let ax;
  if (axis !== void 0) {
    ax = require_backend.checkAxis(axis, arr.ndim);
  } else {
    arr = arr.ravel();
    ax = 0;
  }
  const n = arr.shape[ax];
  const moved = moveaxis$1(arr, ax, -1);
  // Insert a new second-to-last axis of length n, then mask with `tril` so that
  // row i only sums elements 0..i.
  const expanded = broadcast(moved, [...moved.shape, n], [-2]);
  const summed = tril(expanded).sum(-1);
  return moveaxis$1(summed, -1, ax);
}
|
|
4958
5433
|
/** Reverse the elements in an array along the given axes. */
|
|
4959
5434
|
function flip(x, axis = null) {
|
|
4960
5435
|
const nd = ndim(x);
|
|
@@ -5064,8 +5539,11 @@ function flipud(x) {
|
|
|
5064
5539
|
function fliplr(x) {
|
|
5065
5540
|
return flip(x, 1);
|
|
5066
5541
|
}
|
|
5067
|
-
/**
|
|
5068
|
-
|
|
5542
|
+
/** Swap the last two axes of an array; throws if the array has fewer than 2 dimensions. */
function matrixTranspose(a) {
  const rank = ndim(a);
  if (rank >= 2) return moveaxis$1(a, -1, -2);
  throw new Error(`matrixTranspose: input array must be at least 2D`);
}
|
|
5069
5547
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
5070
5548
|
function ravel(a) {
|
|
5071
5549
|
return fudgeArray(a).ravel();
|
|
@@ -5081,6 +5559,32 @@ function squeeze(a, axis = null) {
|
|
|
5081
5559
|
return reshape(a, newShape);
|
|
5082
5560
|
}
|
|
5083
5561
|
/**
 * Insert new axes of length 1 into the shape of an array.
 *
 * @param a - Input array.
 * @param axis - Position, or list of positions, that the new axes occupy in
 *   the expanded shape.
 * @returns The input reshaped to the expanded shape.
 *
 * @example
 * ```ts
 * const x = np.array([1, 2]);
 * np.expandDims(x, 0); // Shape [1, 2]
 * np.expandDims(x, 1); // Shape [2, 1]
 * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
 * ```
 */
function expandDims(a, axis) {
  const inShape = shape(a);
  const requested = typeof axis === "number" ? [axis] : axis;
  const newAxes = require_backend.normalizeAxis(requested, inShape.length + requested.length);
  const outRank = inShape.length + newAxes.length;
  const newShape = [];
  let srcIdx = 0;
  // Walk the output positions: new axes get size 1, the rest take the next
  // input dimension in order.
  for (let pos = 0; pos < outRank; pos++) {
    newShape.push(newAxes.includes(pos) ? 1 : inShape[srcIdx++]);
  }
  return reshape(a, newShape);
}
|
|
5587
|
+
/**
|
|
5084
5588
|
* Repeat each element of an array after themselves.
|
|
5085
5589
|
*
|
|
5086
5590
|
* If no axis is provided, use the flattened input array, and return a flat
|
|
@@ -5168,7 +5672,7 @@ function diagonal(a, offset, axis1, axis2) {
|
|
|
5168
5672
|
*/
|
|
5169
5673
|
function diag(v, k = 0) {
|
|
5170
5674
|
const a = fudgeArray(v);
|
|
5171
|
-
if (!Number.isInteger(k)) throw new
|
|
5675
|
+
if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
|
|
5172
5676
|
if (a.ndim === 1) {
|
|
5173
5677
|
const n = a.shape[0];
|
|
5174
5678
|
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
@@ -5176,12 +5680,32 @@ function diag(v, k = 0) {
|
|
|
5176
5680
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
5177
5681
|
else return ret;
|
|
5178
5682
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
5179
|
-
else throw new
|
|
5683
|
+
else throw new Error("numpy.diag only supports 1D and 2D arrays");
|
|
5180
5684
|
}
|
|
5181
5685
|
/** Sum the entries of the diagonal selected by `offset`, `axis1` and `axis2`. */
function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
  const diagEntries = diagonal(a, offset, axis1, axis2);
  return diagEntries.sum(-1);
}
|
|
5689
|
+
/**
 * Return a sorted copy of an array, sorted along `axis` (the last axis by
 * default).
 *
 * The sort is delegated to a device-specific implementation and may be
 * unstable.
 */
function sort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.sort(axis);
}
|
|
5698
|
+
/**
 * Return the `int32` indices that would sort an array along `axis` (the last
 * axis by default).
 *
 * The underlying sort may be unstable, so the relative order of indices of
 * equal elements is not guaranteed.
 */
function argsort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.argsort(axis);
}
|
|
5185
5709
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5186
5710
|
function allclose(actual, expected, options) {
|
|
5187
5711
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5190,16 +5714,19 @@ function allclose(actual, expected, options) {
|
|
|
5190
5714
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
5191
5715
|
const xData = x.dataSync();
|
|
5192
5716
|
const yData = y.dataSync();
|
|
5193
|
-
for (let i = 0; i < xData.length; i++)
|
|
5717
|
+
for (let i = 0; i < xData.length; i++) {
|
|
5718
|
+
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
5719
|
+
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
5720
|
+
}
|
|
5194
5721
|
return true;
|
|
5195
5722
|
}
|
|
5196
5723
|
/** Matrix product of two arrays. */
|
|
5197
5724
|
function matmul(x, y) {
|
|
5198
|
-
if (ndim(x) === 0 || ndim(y) === 0) throw new
|
|
5725
|
+
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
5199
5726
|
x = x, y = y;
|
|
5200
5727
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5201
5728
|
const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
|
|
5202
|
-
return dot
|
|
5729
|
+
return dot(x, y, {
|
|
5203
5730
|
lhsContractingDims: [-1],
|
|
5204
5731
|
rhsContractingDims: [-2],
|
|
5205
5732
|
lhsBatchDims: require_backend.range(-2 - numBatchDims, -2),
|
|
@@ -5207,11 +5734,11 @@ function matmul(x, y) {
|
|
|
5207
5734
|
});
|
|
5208
5735
|
}
|
|
5209
5736
|
/** Dot product of two arrays. */
|
|
5210
|
-
function dot(x, y) {
|
|
5737
|
+
function dot$1(x, y) {
|
|
5211
5738
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
5212
5739
|
x = x, y = y;
|
|
5213
5740
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5214
|
-
return dot
|
|
5741
|
+
return dot(x, y, {
|
|
5215
5742
|
lhsContractingDims: [-1],
|
|
5216
5743
|
rhsContractingDims: [-2]
|
|
5217
5744
|
});
|
|
@@ -5227,7 +5754,7 @@ function tensordot(x, y, axes = 2) {
|
|
|
5227
5754
|
x = fudgeArray(x);
|
|
5228
5755
|
y = fudgeArray(y);
|
|
5229
5756
|
if (typeof axes === "number") axes = [require_backend.range(-axes, 0), require_backend.range(axes)];
|
|
5230
|
-
return dot
|
|
5757
|
+
return dot(x, y, {
|
|
5231
5758
|
lhsContractingDims: axes[0],
|
|
5232
5759
|
rhsContractingDims: axes[1]
|
|
5233
5760
|
});
|
|
@@ -5320,7 +5847,7 @@ function einsum(...args) {
|
|
|
5320
5847
|
const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
|
|
5321
5848
|
indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
|
|
5322
5849
|
const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
|
|
5323
|
-
const result = dot
|
|
5850
|
+
const result = dot(a, b, {
|
|
5324
5851
|
lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
|
|
5325
5852
|
rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
|
|
5326
5853
|
lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
|
|
@@ -5348,7 +5875,7 @@ function einsum(...args) {
|
|
|
5348
5875
|
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
5349
5876
|
*/
|
|
5350
5877
|
function inner(x, y) {
|
|
5351
|
-
return dot
|
|
5878
|
+
return dot(fudgeArray(x), fudgeArray(y), {
|
|
5352
5879
|
lhsContractingDims: [-1],
|
|
5353
5880
|
rhsContractingDims: [-1]
|
|
5354
5881
|
});
|
|
@@ -5381,6 +5908,30 @@ function vecdot(x, y, { axis } = {}) {
|
|
|
5381
5908
|
function vdot(x, y) {
|
|
5382
5909
|
return dot$2(ravel(x), ravel(y));
|
|
5383
5910
|
}
|
|
5911
|
+
/**
 * Shared implementation of `convolve` and `correlate` for 1D inputs.
 *
 * The longer input is always used as the signal and the shorter one as the
 * kernel. `name` selects the operation ("convolve" flips the kernel) and is
 * used as the error-message prefix; `mode` is "full", "same" or "valid".
 */
function _convImpl(name, x, y, mode) {
  if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
  const swapped = x.shape[0] < y.shape[0];
  let signal = x;
  let kernel = y;
  if (swapped) {
    signal = y;
    kernel = x;
  }
  // Swapping the operands of a correlation reverses its output.
  const flipOutput = swapped && name === "correlate";
  if (name === "convolve") kernel = flip(kernel);
  let padding;
  switch (mode) {
    case "valid":
      padding = "VALID";
      break;
    case "same":
      padding = "SAME_LOWER";
      break;
    case "full": {
      const edge = kernel.shape[0] - 1;
      padding = [[edge, edge]];
      break;
    }
    default:
      throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
  }
  // Add two leading axes for `conv`, then index them away again.
  const z = conv(signal.slice(null, null), kernel.slice(null, null), [1], padding).slice(0, 0);
  if (flipOutput) return flip(z);
  return z;
}
|
|
5927
|
+
/** Convolve two 1D arrays; `mode` is "full" (default), "same" or "valid". */
function convolve(x, y, mode = "full") {
  const result = _convImpl("convolve", x, y, mode);
  return result;
}
|
|
5931
|
+
/** Correlate two 1D arrays; `mode` is "valid" (default), "same" or "full". */
function correlate(x, y, mode = "valid") {
  const result = _convImpl("correlate", x, y, mode);
  return result;
}
|
|
5384
5935
|
/**
|
|
5385
5936
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
5386
5937
|
*
|
|
@@ -5389,7 +5940,7 @@ function vdot(x, y) {
|
|
|
5389
5940
|
*/
|
|
5390
5941
|
function meshgrid(xs, { indexing } = {}) {
|
|
5391
5942
|
indexing ??= "xy";
|
|
5392
|
-
for (const x of xs) if (x.ndim !== 1) throw new
|
|
5943
|
+
for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
|
|
5393
5944
|
if (xs.length <= 1) return xs;
|
|
5394
5945
|
if (indexing === "xy") {
|
|
5395
5946
|
const [a, b, ...rest] = xs;
|
|
@@ -5408,43 +5959,6 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
5408
5959
|
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
5409
5960
|
}
|
|
5410
5961
|
/**
 * Return an array with ones on and below the diagonal and zeros elsewhere.
 *
 * `k` selects the diagonal on and below which the array is filled with ones:
 * `k=0` is the main diagonal, `k<0` is below it, and `k>0` is above it.
 *
 * @param n Number of rows; must be a non-negative integer.
 * @param m Number of columns; defaults to `n`.
 * @param k Diagonal offset; must be an integer.
 * @throws TypeError if `n`, `m` or `k` fail the checks above.
 */
function tri(n, m, k = 0, { dtype, device } = {}) {
  m ??= n;
  dtype ??= require_backend.DType.Float32;
  const isCount = (value) => Number.isInteger(value) && value >= 0;
  if (!isCount(n)) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
  if (!isCount(m)) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
  if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
  // A fresh options object per call, as in the two literal objects this replaces.
  const indexOpts = () => ({
    dtype: require_backend.DType.Int32,
    device
  });
  const rowIdx = arange(k, n + k, 1, indexOpts());
  const colIdx = arange(0, m, 1, indexOpts());
  // Entry (i, j) is set when `i + k >= j`.
  const mask = rowIdx.reshape([n, 1]).greaterEqual(colIdx);
  return mask.astype(dtype);
}
|
|
5433
|
-
/**
 * Return the lower triangle of an array. Must be of dimension >= 2.
 *
 * Entries outside the `tri(n, m, k)` mask over the last two axes are replaced by zeros.
 */
function tril(a, k = 0) {
  const rank = ndim(a);
  if (rank < 2) throw new TypeError(`tril: input array must be at least 2D, got ${rank}D`);
  a = fudgeArray(a);
  const [rows, cols] = a.shape.slice(-2);
  const keep = tri(rows, cols, k, { dtype: bool });
  return where(keep, a.ref, zerosLike(a));
}
|
|
5440
|
-
/**
 * Return the upper triangle of an array. Must be of dimension >= 2.
 *
 * @param a Input with at least two dimensions; the last two axes are treated as the matrix.
 * @param k Diagonal offset passed on (as `k - 1`) to `tri`.
 * @throws TypeError if `a` has fewer than two dimensions.
 */
function triu(a, k = 0) {
  // The message previously said "tril:" (copied from `tril`), which misreports the failing call.
  if (ndim(a) < 2) throw new TypeError(`triu: input array must be at least 2D, got ${ndim(a)}D`);
  a = fudgeArray(a);
  const [n, m] = a.shape.slice(-2);
  // `tri(n, m, k - 1)` marks the entries on and below diagonal `k - 1`; those become zero.
  return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
}
|
|
5447
|
-
/**
|
|
5448
5962
|
* Clip (limit) the values in an array.
|
|
5449
5963
|
*
|
|
5450
5964
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -5468,8 +5982,6 @@ function absolute(x) {
|
|
|
5468
5982
|
x = fudgeArray(x);
|
|
5469
5983
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
5470
5984
|
}
|
|
5471
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
5472
|
-
const abs = absolute;
|
|
5473
5985
|
/** Return an element-wise indication of sign of the input. */
|
|
5474
5986
|
function sign(x) {
|
|
5475
5987
|
x = fudgeArray(x);
|
|
@@ -5548,12 +6060,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
|
|
|
5548
6060
|
const denom = where(xNeg, y, r.add(x));
|
|
5549
6061
|
return atan(numer.div(denom)).mul(2);
|
|
5550
6062
|
});
|
|
5551
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
5552
|
-
const arccos = acos;
|
|
5553
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
5554
|
-
const arctan = atan;
|
|
5555
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
5556
|
-
const arctan2 = atan2;
|
|
5557
6063
|
/** Element-wise subtraction, with broadcasting. */
|
|
5558
6064
|
function subtract(x, y) {
|
|
5559
6065
|
x = fudgeArray(x);
|
|
@@ -5584,8 +6090,6 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
5584
6090
|
const remainder = jit$1(function remainder$1(x, y) {
  // mod(mod(x, y) + y, y): `y` is used three times, so the first two uses take `.ref`.
  const inner = mod(x, y.ref);
  const shifted = inner.add(y.ref);
  return mod(shifted, y);
});
|
|
5587
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
5588
|
-
const divide = trueDivide;
|
|
5589
6093
|
/** Round input to the nearest integer towards zero. */
|
|
5590
6094
|
function trunc(x) {
|
|
5591
6095
|
return idiv(x, 1);
|
|
@@ -5607,9 +6111,9 @@ function ldexp(x1, x2) {
|
|
|
5607
6111
|
*/
|
|
5608
6112
|
function frexp(x) {
  x = fudgeArray(x);
  const magnitude = absolute(x.ref);
  const isZero = equal(x.ref, 0);
  // floor(log2(|x|)) + 1 as Int32, replaced by 0 wherever x == 0.
  const rawExponent = floor(log2(magnitude)).add(1).astype(require_backend.DType.Int32);
  const exponent = where(isZero, 0, rawExponent);
  // mantissa = x / 2**exponent, computed in x's own dtype.
  const scale = exp2(exponent.ref.astype(x.dtype));
  const mantissa = x.div(scale);
  return [mantissa, exponent];
}
|
|
5615
6119
|
/** Calculate `2**p` for all p in the input array. */
|
|
@@ -5649,10 +6153,11 @@ const degrees = rad2deg;
|
|
|
5649
6153
|
* Computes first array raised to power of second array, element-wise.
|
|
5650
6154
|
*/
|
|
5651
6155
|
const power = jit$1(function power$1(x1, x2) {
  // Anything raised to the power 0 is 1 (NumPy / IEEE 754 `pow` convention). Without this
  // guard, `exp(log(|x1|) * x2)` evaluates `-Infinity * 0` or `Infinity * 0`, i.e. NaN,
  // when x2 == 0 and x1 is 0 or +/-Infinity.
  const zeroExponent = equal(x2.ref, 0);
  // Integer part of the exponent; used to detect fractional exponents and odd powers.
  const x2i = trunc(x2.ref);
  // A negative base with a non-integer exponent has no real result.
  const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
  // Odd integer exponents keep the sign of the base; everything else is non-negative.
  const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
  // |x1| ** x2 computed as exp(x2 * log|x1|).
  const magnitude = exp(log(absolute(x1)).mul(x2));
  return where(zeroExponent, 1, where(shouldBeNaN, nan, magnitude.mul(resultSign)));
});
|
|
5654
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
5655
|
-
const pow = power;
|
|
5656
6161
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
5657
6162
|
const cbrt = jit$1(function cbrt$1(x) {
|
|
5658
6163
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
@@ -5718,12 +6223,6 @@ const arccosh = jit$1(function arccosh$1(x) {
|
|
|
5718
6223
|
const arctanh = jit$1(function arctanh$1(x) {
  // 0.5 * log((1 + x) / (1 - x)); the numerator takes `x.ref`, the denominator consumes `x`.
  const numerator = add(1, x.ref);
  const denominator = subtract(1, x);
  const ratio = numerator.div(denominator);
  return log(ratio).mul(.5);
});
|
|
5721
|
-
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
5722
|
-
const asinh = arcsinh;
|
|
5723
|
-
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
5724
|
-
const acosh = arccosh;
|
|
5725
|
-
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
5726
|
-
const atanh = arctanh;
|
|
5727
6226
|
/**
|
|
5728
6227
|
* Compute the variance of an array.
|
|
5729
6228
|
*
|
|
@@ -5753,6 +6252,26 @@ function var_(x, axis = null, opts) {
|
|
|
5753
6252
|
function std(x, axis = null, opts) {
  // Square root of `var_` called with the same arguments.
  const variance = var_(x, axis, opts);
  return sqrt(variance);
}
|
|
6255
|
+
/**
 * Estimate the sample covariance of a set of variables.
 *
 * A 1D input is reshaped to a single row. When `y` is given it is stacked under `x`
 * with `vstack`. Each row is then centered by its mean along axis 1, and the product of
 * the centered matrix with its transpose is divided by `N - 1`, where `N = x.shape[1]`.
 */
function cov(x, y) {
  // Promote a 1D array to a single-row matrix.
  const asRows = (value) => {
    const arr = fudgeArray(value);
    return arr.ndim === 1 ? arr.reshape([1, arr.shape[0]]) : arr;
  };
  x = asRows(x);
  if (y !== void 0) x = vstack([x, asRows(y)]);
  const sampleCount = x.shape[1];
  // Take the extra reference before `mean` consumes `x`.
  const kept = x.ref;
  const rowMeans = x.mean(1, { keepdims: true });
  const centered = kept.sub(rowMeans);
  const scatter = dot$1(centered.ref, centered.transpose());
  return scatter.div(sampleCount - 1);
}
|
|
6268
|
+
/**
 * Compute the Pearson correlation coefficients (in range `[-1, 1]`).
 *
 * Divides `cov(x, y)` element-wise by `sqrt(outer(d, d))`, where `d` is the diagonal of
 * that covariance matrix.
 */
function corrcoef(x, y) {
  const covariance = cov(x, y);
  const diagonal = diag(covariance.ref);
  const scale = sqrt(outer(diagonal.ref, diagonal));
  return covariance.div(scale);
}
|
|
5756
6275
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
5757
6276
|
function isinf(x) {
|
|
5758
6277
|
x = fudgeArray(x);
|
|
@@ -5782,6 +6301,253 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
5782
6301
|
return isnan(x.ref).add(isinf(x)).notEqual(true);
|
|
5783
6302
|
});
|
|
5784
6303
|
|
|
6304
|
+
//#endregion
|
|
6305
|
+
//#region src/library/lax-linalg.ts
|
|
6306
|
+
var lax_linalg_exports = {};
|
|
6307
|
+
__export(lax_linalg_exports, {
|
|
6308
|
+
cholesky: () => cholesky,
|
|
6309
|
+
triangularSolve: () => triangularSolve
|
|
6310
|
+
});
|
|
6311
|
+
/**
 * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
 *
 * For a matrix `A` the decomposition is:
 *
 * - A = L @ L^T  (for upper=false, default)
 * - A = U^T @ U  (for upper=true)
 *
 * where `L` is lower-triangular and `U` is upper-triangular. The input matrix must be
 * symmetric and positive-definite. The upper factor is produced by swapping the last
 * two axes of the lower factor.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const x = np.array([[2., 1.], [1., 2.]]);
 *
 * // Lower Cholesky factorization (default):
 * const L = lax.linalg.cholesky(x);
 * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
 *
 * // Upper Cholesky factorization:
 * const U = lax.linalg.cholesky(x, { upper: true });
 * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
 * ```
 */
function cholesky(a, { upper = false } = {}) {
  const lowerFactor = cholesky$2(a);
  if (!upper) return lowerFactor;
  return moveaxis$1(lowerFactor, -2, -1);
}
|
|
6341
|
+
/**
 * Solve a triangular linear system.
 *
 * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
 * where `a` is a triangular matrix.
 *
 * With `leftSide=true`, `b` has its last two axes swapped before the solve and the
 * result is swapped back afterwards. With `leftSide=false`, the `transposeA` flag is
 * inverted instead. `lower` and `unitDiagonal` are forwarded unchanged.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const L = np.array([[2., 0.], [1., 3.]]);
 * const b = np.array([4., 7.]).reshape([2, 1]);
 *
 * // Solve L @ x = b
 * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
 * // x = [[2.], [5./3.]]
 * ```
 */
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  const swapLastTwo = (arr) => moveaxis$1(arr, -2, -1);
  const transposeMatrix = leftSide ? transposeA : !transposeA;
  if (leftSide) b = swapLastTwo(b);
  if (transposeMatrix) a = swapLastTwo(a);
  // NOTE(review): `lower` is forwarded as given even when `a` was just transposed, which
  // turns a lower-triangular matrix into an upper-triangular one. Confirm that the
  // `triangularSolve$1` primitive expects `lower` to describe the untransposed matrix.
  const solved = triangularSolve$1(a, b, {
    lower,
    unitDiagonal
  });
  return leftSide ? swapLastTwo(solved) : solved;
}
|
|
6372
|
+
|
|
6373
|
+
//#endregion
|
|
6374
|
+
//#region src/library/lax.ts
|
|
6375
|
+
var lax_exports = {};
|
|
6376
|
+
__export(lax_exports, {
|
|
6377
|
+
conv: () => conv,
|
|
6378
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
6379
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
6380
|
+
dot: () => dot,
|
|
6381
|
+
erf: () => erf,
|
|
6382
|
+
erfc: () => erfc,
|
|
6383
|
+
linalg: () => lax_linalg_exports,
|
|
6384
|
+
reduceWindow: () => reduceWindow,
|
|
6385
|
+
stopGradient: () => stopGradient$1
|
|
6386
|
+
});
|
|
6387
|
+
/**
 * General dot product/contraction operator.
 *
 * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
 * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
 *
 * Both operands are transposed to `[batch..., free..., contracting...]` and reshaped so
 * that the free axes of one side line up with unit axes on the other. With no
 * contracting dims the result is an element-wise `mul`; otherwise the contracting axes
 * are flattened into one trailing axis and passed to `dot$2`.
 *
 * @throws Error if the contracting or batch dim lists differ in length, if a contracting
 *   dim is also a batch dim, or if the contracted extents differ between the operands.
 */
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
  const show = (value) => JSON.stringify(value);
  if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${show(lc)} and ${show(rc)}`);
  if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${show(lb)} and ${show(rb)}`);
  const normalize = (dims, operand) => dims.map((a) => require_backend.checkAxis(a, operand.ndim));
  lc = normalize(lc, lhs);
  rc = normalize(rc, rhs);
  lb = normalize(lb, lhs);
  rb = normalize(rb, rhs);
  if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${show(lc)} overlap with batch dims ${show(lb)}`);
  if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${show(rc)} overlap with batch dims ${show(rb)}`);
  // Free axes: everything that is neither contracted nor batched.
  const freeAxes = (operand, contracting, batch) => require_backend.range(operand.ndim).filter((a) => !contracting.includes(a) && !batch.includes(a));
  const lf = freeAxes(lhs, lc, lb);
  const rf = freeAxes(rhs, rc, rb);
  const lhs2 = lhs.transpose([...lb, ...lf, ...lc]);
  const rhs2 = rhs.transpose([...rb, ...rf, ...rc]);
  const extents = (operand, axes) => axes.map((a) => operand.shape[a]);
  const ones = (count) => require_backend.rep(count, 1);
  // Leading (batch + free) shapes, padded with unit axes where the other side's free axes go.
  const lhsLead = [...extents(lhs, lb), ...extents(lhs, lf), ...ones(rf.length)];
  const rhsLead = [...extents(rhs, rb), ...ones(lf.length), ...extents(rhs, rf)];
  if (lc.length === 0) return mul(lhs2.reshape(lhsLead), rhs2.reshape(rhsLead));
  const dotShapeX = extents(lhs, lc);
  const dotShapeY = extents(rhs, rc);
  if (!require_backend.deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${show(dotShapeX)} != ${show(dotShapeY)}`);
  const lhsFlat = lhs2.reshape([...lhsLead, require_backend.prod(dotShapeX)]);
  const rhsFlat = rhs2.reshape([...rhsLead, require_backend.prod(dotShapeY)]);
  return dot$2(lhsFlat, rhsFlat);
}
|
|
6438
|
+
/**
 * Convert a padding name into explicit `[low, high]` pad pairs, one per spatial axis.
 *
 * - "VALID": no padding.
 * - "SAME": the total padding is split with the extra unit (if any) on the high side.
 * - "SAME_LOWER": the same total, with the extra unit on the low side.
 *
 * The name is matched case-insensitively.
 *
 * @throws Error for any other padding name.
 */
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
  const padType = padding.toUpperCase();
  if (padType === "VALID") return require_backend.rep(inShape.length, [0, 0]);
  if (padType !== "SAME" && padType !== "SAME_LOWER") throw new Error(`Unknown padding type: ${padType}`);
  const outShape = inShape.map((extent, axis) => Math.ceil(extent / strides[axis]));
  // Total padding needed so that `ceil(in / stride)` outputs fit, never negative.
  const totals = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([out, stride, kernel, dil, inp]) => {
    const needed = (out - 1) * stride + 1 + (kernel - 1) * dil - inp;
    return Math.max(0, needed);
  });
  const extraOnHighSide = padType === "SAME";
  return totals.map((total) => {
    const half = total >> 1;
    return extraOnHighSide ? [half, total - half] : [total - half, half];
  });
}
|
|
6452
|
+
/**
 * General n-dimensional convolution operator, with optional dilation.
 *
 * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
 * function in JAX, which wraps XLA's general convolution operator.
 *
 * A string `padding` is converted to explicit pads with `padtypeToPads`; this is rejected
 * when `lhsDilation` contains any value other than 1. When `featureGroupCount` is not 1,
 * the input and kernel are reshaped to expose a group axis and the convolution runs with
 * `vmapDims: 1`.
 *
 * @throws Error if either operand has fewer than 2 dimensions, if string padding is
 *   combined with lhs dilation, or if `featureGroupCount` does not divide the channel counts.
 */
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
  if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
  if (typeof padding === "string") {
    const hasLhsDilation = lhsDilation?.some((d) => d !== 1);
    if (hasLhsDilation) throw new Error("String padding is not supported for transposed convolutions");
    const kernelDilation = rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1);
    padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, kernelDilation, padding);
  }
  const convParams = {
    strides: windowStrides,
    padding,
    lhsDilation,
    rhsDilation
  };
  // Plain (ungrouped) convolution.
  if (featureGroupCount === 1) return conv$1(lhs, rhs, convParams);
  const groups = featureGroupCount;
  const [batch, inChannels, ...spatialIn] = lhs.shape;
  const [outChannels, inChannelsPerGroup, ...kernelDims] = rhs.shape;
  if (inChannels % groups !== 0) throw new Error(`featureGroupCount=${groups} must divide input channels=${inChannels}`);
  if (outChannels % groups !== 0) throw new Error(`featureGroupCount=${groups} must divide output channels=${outChannels}`);
  if (inChannels / groups !== inChannelsPerGroup) throw new Error(`rhs input channels=${inChannelsPerGroup} must equal lhs input channels / groups=${inChannels / groups}`);
  // lhs: [N, C_in, ...] -> [N, G, C_in/G, ...] -> group axis moved to the front.
  const lhsSplit = lhs.reshape([batch, groups, inChannels / groups, ...spatialIn]);
  const lhsGrouped = moveaxis(lhsSplit, 1, 0);
  // rhs: [C_out, C_in/G, ...] -> [G, C_out/G, C_in/G, ...].
  const rhsGrouped = rhs.reshape([groups, outChannels / groups, inChannelsPerGroup, ...kernelDims]);
  const result = conv$1(lhsGrouped, rhsGrouped, {
    vmapDims: 1,
    ...convParams
  });
  const spatialOut = result.shape.slice(3);
  // Move the group axis back next to the channels and merge them into C_out.
  return moveaxis(result, 0, 1).reshape([batch, outChannels, ...spatialOut]);
}
|
|
6507
|
+
/**
 * Convenience wrapper around `convGeneralDilated` that takes the two dilation arrays as
 * positional arguments instead of an options object.
 */
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
  const dilations = {
    lhsDilation,
    rhsDilation
  };
  return convGeneralDilated(lhs, rhs, windowStrides, padding, dilations);
}
|
|
6514
|
+
/**
 * Convenience wrapper around `convGeneralDilated` that passes no options, so the
 * defaults of that function apply.
 */
function conv(lhs, rhs, windowStrides, padding) {
  const result = convGeneralDilated(lhs, rhs, windowStrides, padding);
  return result;
}
|
|
6518
|
+
/**
 * Reduce a computation over padded windows.
 *
 * `computation` is wrapped in `vmap$1(..., 0)` once per dimension of `operand` and then
 * applied to the output of the `Pool` primitive. `windowStrides` defaults to all ones.
 *
 * @throws Error if `operand` has fewer dimensions than `windowDimensions` has entries.
 */
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
  const rank = operand.ndim;
  if (rank < windowDimensions.length) throw new Error(`Operand dimensions ${rank} < window ${windowDimensions.length}`);
  const strides = windowStrides || require_backend.rep(windowDimensions.length, 1);
  let mapped = computation;
  for (let level = 0; level < rank; level += 1) mapped = vmap$1(mapped, 0);
  const pooled = bind1(Primitive.Pool, [operand], {
    window: windowDimensions,
    strides
  });
  return mapped(pooled);
}
|
|
6528
|
+
/**
 * The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`.
 *
 * Forwards to `erf$1`.
 */
function erf(x) {
  const result = erf$1(x);
  return result;
}
|
|
6532
|
+
/**
 * The complementary error function: `erfc(x) = 1 - erf(x)`.
 *
 * This function is more accurate than `1 - erf(x)` for large values of `x`,
 * where `erf(x)` is very close to 1. Forwards to `erfc$1`.
 */
function erfc(x) {
  const result = erfc$1(x);
  return result;
}
|
|
6541
|
+
/**
 * Stops gradient computation.
 *
 * Behaves as the identity function but prevents the flow of gradients during
 * forward or reverse-mode automatic differentiation. Forwards to `stopGradient`.
 */
function stopGradient$1(x) {
  const result = stopGradient(x);
  return result;
}
|
|
6550
|
+
|
|
5785
6551
|
//#endregion
|
|
5786
6552
|
//#region src/library/nn.ts
|
|
5787
6553
|
var nn_exports = {};
|
|
@@ -5790,6 +6556,10 @@ __export(nn_exports, {
|
|
|
5790
6556
|
elu: () => elu,
|
|
5791
6557
|
gelu: () => gelu,
|
|
5792
6558
|
glu: () => glu,
|
|
6559
|
+
hardSigmoid: () => hardSigmoid,
|
|
6560
|
+
hardSilu: () => hardSilu,
|
|
6561
|
+
hardSwish: () => hardSilu,
|
|
6562
|
+
hardTanh: () => hardTanh,
|
|
5793
6563
|
identity: () => identity,
|
|
5794
6564
|
leakyRelu: () => leakyRelu,
|
|
5795
6565
|
logSigmoid: () => logSigmoid,
|
|
@@ -5800,14 +6570,17 @@ __export(nn_exports, {
|
|
|
5800
6570
|
oneHot: () => oneHot,
|
|
5801
6571
|
relu: () => relu,
|
|
5802
6572
|
relu6: () => relu6,
|
|
6573
|
+
selu: () => selu,
|
|
5803
6574
|
sigmoid: () => sigmoid,
|
|
5804
6575
|
silu: () => silu,
|
|
5805
6576
|
softSign: () => softSign,
|
|
5806
6577
|
softmax: () => softmax,
|
|
5807
6578
|
softplus: () => softplus,
|
|
6579
|
+
sparsePlus: () => sparsePlus,
|
|
6580
|
+
sparseSigmoid: () => sparseSigmoid,
|
|
5808
6581
|
squareplus: () => squareplus,
|
|
5809
6582
|
standardize: () => standardize,
|
|
5810
|
-
swish: () =>
|
|
6583
|
+
swish: () => silu
|
|
5811
6584
|
});
|
|
5812
6585
|
/**
|
|
5813
6586
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -5842,6 +6615,28 @@ function softplus(x) {
|
|
|
5842
6615
|
return log(exp(x).add(1));
|
|
5843
6616
|
}
|
|
5844
6617
|
/**
 * @function
 * Sparse plus function:
 *
 * - When `x <= -1`: `0`
 * - When `-1 < x < 1`: `(x+1)**2 / 4`
 * - When `x >= 1`: `x`
 */
const sparsePlus = jit$1((x) => {
  const atOrBelowMinusOne = x.ref.lessEqual(-1);
  const belowOne = x.ref.less(1);
  const quadratic = square(x.ref.add(1)).mul(.25);
  // For x > -1: the quadratic piece below 1, the identity from 1 upwards.
  const upperPieces = where(belowOne, quadratic, x);
  return where(atOrBelowMinusOne, 0, upperPieces);
});
|
|
6628
|
+
/**
 * @function
 * Sparse sigmoid activation function.
 *
 * - When `x <= -1`: `0`
 * - When `-1 < x < 1`: `(x + 1) / 2`
 * - When `x >= 1`: `1`
 *
 * Computed as `clip((x + 1) * 0.5, 0, 1)`.
 */
const sparseSigmoid = jit$1((x) => {
  const ramp = x.add(1).mul(.5);
  return clip(ramp, 0, 1);
});
|
|
6639
|
+
/**
|
|
5845
6640
|
* Soft-sign activation function, computed element-wise:
|
|
5846
6641
|
* `softsign(x) = x / (|x| + 1)`.
|
|
5847
6642
|
*/
|
|
@@ -5863,17 +6658,6 @@ const silu = jit$1(function silu$1(x) {
|
|
|
5863
6658
|
return x.ref.mul(sigmoid(x));
|
|
5864
6659
|
});
|
|
5865
6660
|
/**
|
|
5866
|
-
* @function
|
|
5867
|
-
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
5868
|
-
* Swish, computed element-wise:
|
|
5869
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
5870
|
-
*
|
|
5871
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
5872
|
-
*
|
|
5873
|
-
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
5874
|
-
*/
|
|
5875
|
-
const swish = silu;
|
|
5876
|
-
/**
|
|
5877
6661
|
* Log-sigmoid activation function, computed element-wise:
|
|
5878
6662
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
5879
6663
|
*/
|
|
@@ -5890,6 +6674,19 @@ function leakyRelu(x, negativeSlope = .01) {
|
|
|
5890
6674
|
x = fudgeArray(x);
|
|
5891
6675
|
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
5892
6676
|
}
|
|
6677
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
function hardSigmoid(x) {
  const shifted = add(x, 3);
  return relu6(shifted).mul(1 / 6);
}
|
|
6681
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
function hardSilu(x) {
  x = fudgeArray(x);
  // Take the extra reference first; `hardSigmoid` receives `x` itself.
  const input = x.ref;
  const gate = hardSigmoid(x);
  return input.mul(gate);
}
|
|
6686
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
function hardTanh(x) {
  const lowerBound = -1;
  const upperBound = 1;
  return clip(x, lowerBound, upperBound);
}
|
|
5893
6690
|
/**
|
|
5894
6691
|
* Exponential linear unit activation function.
|
|
5895
6692
|
*
|
|
@@ -5912,6 +6709,20 @@ function celu(x, alpha = 1) {
|
|
|
5912
6709
|
}
|
|
5913
6710
|
/**
 * @function
 * Scaled exponential linear unit activation.
 *
 * Computes the element-wise function:
 * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
 *
 * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
 * The exponential branch is evaluated with `expm1` and selected where `x < 0`.
 */
const selu = jit$1(function selu$1(x) {
  const alpha = 1.6732632423543772;
  const lambda = 1.0507009873554805;
  const isNegative = x.ref.less(0);
  const negativeBranch = expm1(x.ref).mul(alpha);
  const unscaled = where(isNegative, negativeBranch, x);
  return unscaled.mul(lambda);
});
|
|
6724
|
+
/**
|
|
6725
|
+
* @function
|
|
5915
6726
|
* Gaussian error linear unit (GELU) activation function.
|
|
5916
6727
|
*
|
|
5917
6728
|
* This is computed element-wise. There are two variants depending on whether
|
|
@@ -6005,22 +6816,22 @@ function logSoftmax(x, axis = -1) {
|
|
|
6005
6816
|
*
|
|
6006
6817
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
6007
6818
|
*/
|
|
6008
|
-
function logsumexp(x, axis = null, opts) {
  x = fudgeArray(x);
  axis = require_backend.normalizeAxis(axis, x.ndim);
  // Nothing to reduce over: return the input unchanged.
  if (axis.length === 0) return x;
  const rawMax = max(x.ref, axis, { keepdims: true });
  // Shift by the max only where it is finite. With a max of -Infinity (every entry is
  // -Infinity) or +Infinity, `x - max` would contain `Infinity - Infinity` = NaN; shifting
  // by 0 instead yields the correct -Infinity / +Infinity result.
  const xMax = stopGradient(where(isfinite(rawMax.ref), rawMax, 0));
  const shifted = x.sub(xMax.ref);
  // log(sum(exp(x))) = max + log(sum(exp(x - max))), computed with the reduced axes kept.
  const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
  // Drop the reduced axes unless the caller asked for `keepdims`.
  return opts?.keepdims ? result : squeeze(result, axis);
}
|
|
6017
6828
|
/**
 * Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`, where `n`
 * is the number of elements reduced over (the product of the extents of `axis`).
 */
function logmeanexp(x, axis = null, opts) {
  x = fudgeArray(x);
  axis = require_backend.normalizeAxis(axis, x.ndim);
  if (axis.length === 0) return x;
  let count = 1;
  for (const a of axis) count *= x.shape[a];
  const total = logsumexp(x, axis, opts);
  return total.sub(Math.log(count));
}
|
|
6025
6836
|
/**
|
|
6026
6837
|
* Standardizes input to zero mean and unit variance.
|
|
@@ -6065,8 +6876,11 @@ var random_exports = {};
|
|
|
6065
6876
|
__export(random_exports, {
|
|
6066
6877
|
bernoulli: () => bernoulli,
|
|
6067
6878
|
bits: () => bits,
|
|
6879
|
+
cauchy: () => cauchy,
|
|
6068
6880
|
exponential: () => exponential,
|
|
6881
|
+
gumbel: () => gumbel,
|
|
6069
6882
|
key: () => key,
|
|
6883
|
+
laplace: () => laplace,
|
|
6070
6884
|
normal: () => normal,
|
|
6071
6885
|
split: () => split,
|
|
6072
6886
|
uniform: () => uniform
|
|
@@ -6125,6 +6939,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
6125
6939
|
}
|
|
6126
6940
|
/**
 * @function
 * Sample from a Cauchy distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling: `x = tan(π * (u - 0.5))`, where `u` is drawn with
 * `uniform(key, shape)`. `shape` is a static argument of the jitted function.
 */
const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
  const u = uniform(key$1, shape$1);
  const angle = u.sub(.5).mul(Math.PI);
  return tan(angle);
}, { staticArgnums: [1] });
|
|
6950
|
+
/**
|
|
6951
|
+
* @function
|
|
6128
6952
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
6129
6953
|
*/
|
|
6130
6954
|
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
@@ -6133,6 +6957,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
6133
6957
|
}, { staticArgnums: [1] });
|
|
6134
6958
|
/**
 * @function
 * Sample from a Gumbel distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling. With `u` drawn from `uniform(key, shape)`, the code
 * evaluates `x = -log(-log1p(-u))`, i.e. `-log(-log(1 - u))`.
 * `shape` is a static argument of the jitted function.
 */
const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
  const u = uniform(key$1, shape$1);
  const logOneMinusU = log1p(negative(u));
  const inner = log(negative(logOneMinusU));
  return negative(inner);
}, { staticArgnums: [1] });
|
|
6968
|
+
/**
 * @function
 * Sample from a Laplace distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
 * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`, with the logarithm evaluated
 * through `log1p`. `shape` is a static argument of the jitted function.
 */
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
  const u = uniform(key$1, shape$1);
  const centered = u.sub(.5);
  // Take the sign from an extra reference before `absolute` receives `centered`.
  const direction = sign(centered.ref);
  const distance = absolute(centered);
  const magnitude = log1p(distance.mul(-2)).mul(-1);
  return direction.mul(magnitude);
}, { staticArgnums: [1] });
|
|
6982
|
+
/**
|
|
6983
|
+
* @function
|
|
6136
6984
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6137
6985
|
*
|
|
6138
6986
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -6241,11 +7089,6 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
6241
7089
|
*/
|
|
6242
7090
|
const jacrev = jacrev$1;
|
|
6243
7091
|
/**
|
|
6244
|
-
* @function
|
|
6245
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
6246
|
-
*/
|
|
6247
|
-
const jacobian = jacrev;
|
|
6248
|
-
/**
|
|
6249
7092
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
6250
7093
|
*
|
|
6251
7094
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -6281,6 +7124,7 @@ async function devicePut(x, device) {
|
|
|
6281
7124
|
|
|
6282
7125
|
//#endregion
|
|
6283
7126
|
exports.Array = Array$1;
|
|
7127
|
+
exports.ClosedJaxpr = ClosedJaxpr;
|
|
6284
7128
|
exports.DType = require_backend.DType;
|
|
6285
7129
|
exports.Jaxpr = Jaxpr;
|
|
6286
7130
|
exports.blockUntilReady = blockUntilReady;
|
|
@@ -6290,7 +7134,7 @@ exports.devices = require_backend.devices;
|
|
|
6290
7134
|
exports.grad = grad;
|
|
6291
7135
|
exports.init = require_backend.init;
|
|
6292
7136
|
exports.jacfwd = jacfwd;
|
|
6293
|
-
exports.jacobian =
|
|
7137
|
+
exports.jacobian = jacrev;
|
|
6294
7138
|
exports.jacrev = jacrev;
|
|
6295
7139
|
exports.jit = jit;
|
|
6296
7140
|
exports.jvp = jvp;
|
|
@@ -6335,5 +7179,4 @@ Object.defineProperty(exports, 'tree', {
|
|
|
6335
7179
|
});
|
|
6336
7180
|
exports.valueAndGrad = valueAndGrad;
|
|
6337
7181
|
exports.vjp = vjp;
|
|
6338
|
-
exports.vmap = vmap;
|
|
6339
|
-
//# sourceMappingURL=index.cjs.map
|
|
7182
|
+
exports.vmap = vmap;
|