@jax-js/jax 0.1.2 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +16 -34
- package/dist/{backend-DeVfWEFS.cjs → backend-Bu9GY6sK.cjs} +222 -36
- package/dist/{backend-BqymqzuU.js → backend-tngXtWe4.js} +204 -36
- package/dist/index.cjs +1798 -955
- package/dist/index.d.cts +383 -97
- package/dist/index.d.ts +383 -97
- package/dist/index.js +1791 -949
- package/dist/{webgpu-BGuG58KZ.js → webgpu-ChVgx3b6.js} +410 -97
- package/dist/{webgpu-CcGP160M.cjs → webgpu-Oj3Kd-kd.cjs} +410 -97
- package/package.json +1 -1
package/dist/index.js
CHANGED
|
@@ -1,28 +1,36 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, Routine, Routines, ShapeTracker, accessorAluExp, accessorGlobal, assertNonNull, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, generalBroadcast, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-tngXtWe4.js";
|
|
3
3
|
|
|
4
4
|
//#region src/frontend/convolution.ts
|
|
5
5
|
/**
|
|
6
6
|
* Check that the shapes and parameters passed to convolution are valid.
|
|
7
|
+
* Expected shapes of the lhs and rhs of the convolution are:
|
|
8
|
+
*
|
|
9
|
+
* - `lhsShape = [*vmapDims, batchSize, inChannels, spatialDims...]`
|
|
10
|
+
* - `rhsShape = [*vmapDims, outChannels, inChannels, kernelSize...]`
|
|
7
11
|
*
|
|
8
12
|
* If the check succeeds, returns the output shape.
|
|
9
13
|
*/
|
|
10
|
-
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
|
|
14
|
+
function checkConvShape(lhsShape, rhsShape, { vmapDims, strides, padding, lhsDilation, rhsDilation }) {
|
|
11
15
|
if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
|
|
12
|
-
const n = lhsShape.length - 2;
|
|
16
|
+
const n = lhsShape.length - 2 - vmapDims;
|
|
13
17
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
14
18
|
if (strides.length !== n) throw new Error("conv() strides != spatial dims");
|
|
15
19
|
if (padding.length !== n) throw new Error("conv() padding != spatial dims");
|
|
16
20
|
if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
|
|
17
21
|
if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
|
|
18
|
-
if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
19
|
-
const outShape = [
|
|
22
|
+
if (lhsShape[vmapDims + 1] !== rhsShape[vmapDims + 1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
|
|
23
|
+
const outShape = [
|
|
24
|
+
...generalBroadcast(lhsShape.slice(0, vmapDims), rhsShape.slice(0, vmapDims)),
|
|
25
|
+
lhsShape[vmapDims],
|
|
26
|
+
rhsShape[vmapDims]
|
|
27
|
+
];
|
|
20
28
|
for (let i = 0; i < n; i++) {
|
|
21
29
|
if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
|
|
22
30
|
if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
|
|
23
31
|
if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
|
|
24
32
|
if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
|
|
25
|
-
const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
|
|
33
|
+
const [x, k] = [lhsShape[i + vmapDims + 2], rhsShape[i + vmapDims + 2]];
|
|
26
34
|
if (k <= 0) throw new Error("conv() kernel size must be positive");
|
|
27
35
|
const [pl, pr] = padding[i];
|
|
28
36
|
if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
|
|
@@ -147,27 +155,13 @@ function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
|
|
|
147
155
|
function applyDilation(st, dilation) {
|
|
148
156
|
if (dilation.every((s) => s === 1)) return st;
|
|
149
157
|
const s_ = dilation;
|
|
150
|
-
const
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
]);
|
|
156
|
-
st = st.
|
|
157
|
-
[0, 0],
|
|
158
|
-
[0, 0],
|
|
159
|
-
...s_.flatMap((s) => [[0, 0], [0, s - 1]])
|
|
160
|
-
]);
|
|
161
|
-
st = st.reshape([
|
|
162
|
-
a,
|
|
163
|
-
b,
|
|
164
|
-
...k_.map((k, i) => k * s_[i])
|
|
165
|
-
]);
|
|
166
|
-
st = st.shrink([
|
|
167
|
-
[0, a],
|
|
168
|
-
[0, b],
|
|
169
|
-
...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])
|
|
170
|
-
]);
|
|
158
|
+
const n = s_.length;
|
|
159
|
+
const prefix = st.shape.slice(0, -n);
|
|
160
|
+
const k_ = st.shape.slice(-n);
|
|
161
|
+
st = st.reshape([...prefix, ...k_.flatMap((k) => [k, 1])]);
|
|
162
|
+
st = st.pad([...prefix.map(() => [0, 0]), ...s_.flatMap((s) => [[0, 0], [0, s - 1]])]);
|
|
163
|
+
st = st.reshape([...prefix, ...k_.map((k, i) => k * s_[i])]);
|
|
164
|
+
st = st.shrink([...prefix.map((p) => [0, p]), ...k_.map((k, i) => [0, (k - 1) * s_[i] + 1])]);
|
|
171
165
|
return st;
|
|
172
166
|
}
|
|
173
167
|
/**
|
|
@@ -177,25 +171,26 @@ function applyDilation(st, dilation) {
|
|
|
177
171
|
* beforehand using `checkConvShape()`.
|
|
178
172
|
*/
|
|
179
173
|
function prepareConv(stX, stY, params) {
|
|
180
|
-
const
|
|
174
|
+
const v = params.vmapDims;
|
|
175
|
+
const n = stX.shape.length - 2 - v;
|
|
176
|
+
const vmapShape = stX.shape.slice(0, v);
|
|
181
177
|
stX = applyDilation(stX, params.lhsDilation);
|
|
182
|
-
const ks = stY.shape.slice(2);
|
|
183
|
-
stX = stX.padOrShrink([
|
|
184
|
-
[0, 0],
|
|
185
|
-
[0, 0],
|
|
186
|
-
...params.padding
|
|
187
|
-
]);
|
|
178
|
+
const ks = stY.shape.slice(v + 2);
|
|
179
|
+
stX = stX.padOrShrink([...rep(v + 2, [0, 0]), ...params.padding]);
|
|
188
180
|
stX = pool(stX, ks, params.strides, params.rhsDilation);
|
|
189
|
-
stX = stX.moveaxis(1, n + 1).reshape([
|
|
190
|
-
|
|
181
|
+
stX = stX.moveaxis(v + 1, v + n + 1).reshape([
|
|
182
|
+
...vmapShape,
|
|
183
|
+
stX.shape[v],
|
|
191
184
|
1,
|
|
192
|
-
...stX.shape.slice(2, n + 2),
|
|
193
|
-
stX.shape[1] * prod(ks)
|
|
185
|
+
...stX.shape.slice(v + 2, v + n + 2),
|
|
186
|
+
stX.shape[v + 1] * prod(ks)
|
|
194
187
|
]);
|
|
195
188
|
stY = stY.reshape([
|
|
196
|
-
|
|
189
|
+
...vmapShape,
|
|
190
|
+
1,
|
|
191
|
+
stY.shape[v],
|
|
197
192
|
...rep(n, 1),
|
|
198
|
-
stY.shape[1] * prod(ks)
|
|
193
|
+
stY.shape[v + 1] * prod(ks)
|
|
199
194
|
]);
|
|
200
195
|
return [stX, stY];
|
|
201
196
|
}
|
|
@@ -336,6 +331,8 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
336
331
|
Primitive$1["Mul"] = "mul";
|
|
337
332
|
Primitive$1["Idiv"] = "idiv";
|
|
338
333
|
Primitive$1["Mod"] = "mod";
|
|
334
|
+
Primitive$1["Min"] = "min";
|
|
335
|
+
Primitive$1["Max"] = "max";
|
|
339
336
|
Primitive$1["Neg"] = "neg";
|
|
340
337
|
Primitive$1["Reciprocal"] = "reciprocal";
|
|
341
338
|
Primitive$1["Floor"] = "floor";
|
|
@@ -343,7 +340,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
343
340
|
Primitive$1["StopGradient"] = "stop_gradient";
|
|
344
341
|
Primitive$1["Cast"] = "cast";
|
|
345
342
|
Primitive$1["Bitcast"] = "bitcast";
|
|
346
|
-
Primitive$1["RandomBits"] = "random_bits";
|
|
347
343
|
Primitive$1["Sin"] = "sin";
|
|
348
344
|
Primitive$1["Cos"] = "cos";
|
|
349
345
|
Primitive$1["Asin"] = "asin";
|
|
@@ -353,8 +349,6 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
353
349
|
Primitive$1["Erf"] = "erf";
|
|
354
350
|
Primitive$1["Erfc"] = "erfc";
|
|
355
351
|
Primitive$1["Sqrt"] = "sqrt";
|
|
356
|
-
Primitive$1["Min"] = "min";
|
|
357
|
-
Primitive$1["Max"] = "max";
|
|
358
352
|
Primitive$1["Reduce"] = "reduce";
|
|
359
353
|
Primitive$1["Dot"] = "dot";
|
|
360
354
|
Primitive$1["Conv"] = "conv";
|
|
@@ -362,14 +356,19 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
362
356
|
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
363
357
|
Primitive$1["Compare"] = "compare";
|
|
364
358
|
Primitive$1["Where"] = "where";
|
|
359
|
+
Primitive$1["RandomBits"] = "random_bits";
|
|
360
|
+
Primitive$1["Gather"] = "gather";
|
|
365
361
|
Primitive$1["Transpose"] = "transpose";
|
|
366
362
|
Primitive$1["Broadcast"] = "broadcast";
|
|
367
363
|
Primitive$1["Reshape"] = "reshape";
|
|
368
364
|
Primitive$1["Flip"] = "flip";
|
|
369
365
|
Primitive$1["Shrink"] = "shrink";
|
|
370
366
|
Primitive$1["Pad"] = "pad";
|
|
371
|
-
Primitive$1["
|
|
372
|
-
Primitive$1["
|
|
367
|
+
Primitive$1["Sort"] = "sort";
|
|
368
|
+
Primitive$1["Argsort"] = "argsort";
|
|
369
|
+
Primitive$1["TriangularSolve"] = "triangular_solve";
|
|
370
|
+
Primitive$1["Cholesky"] = "cholesky";
|
|
371
|
+
Primitive$1["Jit"] = "jit";
|
|
373
372
|
return Primitive$1;
|
|
374
373
|
}({});
|
|
375
374
|
let CompareOp = /* @__PURE__ */ function(CompareOp$1) {
|
|
@@ -391,6 +390,12 @@ function idiv(x, y) {
|
|
|
391
390
|
function mod(x, y) {
|
|
392
391
|
return bind1(Primitive.Mod, [x, y]);
|
|
393
392
|
}
|
|
393
|
+
function min$1(x, y) {
|
|
394
|
+
return bind1(Primitive.Min, [x, y]);
|
|
395
|
+
}
|
|
396
|
+
function max$1(x, y) {
|
|
397
|
+
return bind1(Primitive.Max, [x, y]);
|
|
398
|
+
}
|
|
394
399
|
function neg(x) {
|
|
395
400
|
return bind1(Primitive.Neg, [x]);
|
|
396
401
|
}
|
|
@@ -412,12 +417,6 @@ function cast(x, dtype) {
|
|
|
412
417
|
function bitcast(x, dtype) {
|
|
413
418
|
return bind1(Primitive.Bitcast, [x], { dtype });
|
|
414
419
|
}
|
|
415
|
-
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
416
|
-
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
417
|
-
shape: shape$1,
|
|
418
|
-
mode
|
|
419
|
-
});
|
|
420
|
-
}
|
|
421
420
|
function sin$1(x) {
|
|
422
421
|
return bind1(Primitive.Sin, [x]);
|
|
423
422
|
}
|
|
@@ -445,12 +444,6 @@ function erfc$1(x) {
|
|
|
445
444
|
function sqrt$1(x) {
|
|
446
445
|
return bind1(Primitive.Sqrt, [x]);
|
|
447
446
|
}
|
|
448
|
-
function min$1(x, y) {
|
|
449
|
-
return bind1(Primitive.Min, [x, y]);
|
|
450
|
-
}
|
|
451
|
-
function max$1(x, y) {
|
|
452
|
-
return bind1(Primitive.Max, [x, y]);
|
|
453
|
-
}
|
|
454
447
|
function reduce(x, op, axis = null, opts) {
|
|
455
448
|
if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
|
|
456
449
|
axis = normalizeAxis(axis, ndim$1(x));
|
|
@@ -467,9 +460,11 @@ function dot$2(x, y) {
|
|
|
467
460
|
}
|
|
468
461
|
function conv$1(x, y, params = {}) {
|
|
469
462
|
if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
|
|
470
|
-
const
|
|
463
|
+
const vmapDims = params.vmapDims ?? 0;
|
|
464
|
+
const n = x.ndim - 2 - vmapDims;
|
|
471
465
|
if (n < 0) throw new Error("conv() requires at least 2D inputs");
|
|
472
466
|
return bind1(Primitive.Conv, [x, y], {
|
|
467
|
+
vmapDims,
|
|
473
468
|
strides: params.strides ?? rep(n, 1),
|
|
474
469
|
padding: params.padding ?? rep(n, [0, 0]),
|
|
475
470
|
lhsDilation: params.lhsDilation ?? rep(n, 1),
|
|
@@ -504,6 +499,23 @@ function where$1(cond, x, y) {
|
|
|
504
499
|
y
|
|
505
500
|
]);
|
|
506
501
|
}
|
|
502
|
+
function randomBits(k0, k1, shape$1, mode = "xor") {
|
|
503
|
+
return bind1(Primitive.RandomBits, [k0, k1], {
|
|
504
|
+
shape: shape$1,
|
|
505
|
+
mode
|
|
506
|
+
});
|
|
507
|
+
}
|
|
508
|
+
function gather(x, indices, axis, outDim) {
|
|
509
|
+
if (indices.length === 0) throw new Error("gather() requires at least one index");
|
|
510
|
+
if (!Array.isArray(axis) || axis.length !== indices.length) throw new Error(`Invalid gather() axis: expected ${indices.length} axes, got ${JSON.stringify(axis)}`);
|
|
511
|
+
axis = axis.map((a) => checkAxis(a, ndim$1(x)));
|
|
512
|
+
if (new Set(axis).size !== axis.length) throw new Error(`Invalid gather() axis: duplicate axes ${JSON.stringify(axis)}`);
|
|
513
|
+
outDim = checkAxis(outDim, ndim$1(x) - axis.length + 1);
|
|
514
|
+
return bind1(Primitive.Gather, [x, ...indices], {
|
|
515
|
+
axis,
|
|
516
|
+
outDim
|
|
517
|
+
});
|
|
518
|
+
}
|
|
507
519
|
function transpose$1(x, perm) {
|
|
508
520
|
perm = perm ? perm.map((a) => checkAxis(a, ndim$1(x))) : range(ndim$1(x)).reverse();
|
|
509
521
|
if (!isPermutation(perm, ndim$1(x))) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(perm)}`);
|
|
@@ -553,16 +565,27 @@ function pad$1(x, width) {
|
|
|
553
565
|
} else if (width.length !== nd) throw new Error(`Invalid pad(): expected ${nd} axes, got ${width.length}`);
|
|
554
566
|
return bind1(Primitive.Pad, [x], { width });
|
|
555
567
|
}
|
|
556
|
-
function
|
|
557
|
-
if (
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
568
|
+
function triangularSolve$1(a, b, { lower = false, unitDiagonal = false } = {}) {
|
|
569
|
+
if (lower) {
|
|
570
|
+
a = flip$1(a, [-2, -1]);
|
|
571
|
+
b = flip$1(b, [-1]);
|
|
572
|
+
}
|
|
573
|
+
let x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
574
|
+
if (lower) x = flip$1(x, [-1]);
|
|
575
|
+
return x;
|
|
576
|
+
}
|
|
577
|
+
function cholesky$2(x) {
|
|
578
|
+
return bind1(Primitive.Cholesky, [x]);
|
|
579
|
+
}
|
|
580
|
+
function sort$1(x) {
|
|
581
|
+
const nd = ndim$1(x);
|
|
582
|
+
if (nd === 0) throw new Error("sort: requires at least 1D input");
|
|
583
|
+
return bind1(Primitive.Sort, [x]);
|
|
584
|
+
}
|
|
585
|
+
function argsort$1(x) {
|
|
586
|
+
const nd = ndim$1(x);
|
|
587
|
+
if (nd === 0) throw new Error("argsort: requires at least 1D input");
|
|
588
|
+
return bind(Primitive.Argsort, [x]);
|
|
566
589
|
}
|
|
567
590
|
function bind1(prim, args, params = {}) {
|
|
568
591
|
const [results] = bind(prim, args, params);
|
|
@@ -693,8 +716,10 @@ var Tracer = class Tracer {
|
|
|
693
716
|
axis = normalizeAxis(axis, this.ndim);
|
|
694
717
|
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
695
718
|
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
696
|
-
const
|
|
697
|
-
|
|
719
|
+
const originalDtype = this.dtype;
|
|
720
|
+
const castDtype = promoteTypes(originalDtype, DType.Float32);
|
|
721
|
+
const result = reduce(this.astype(castDtype), AluOp.Add, axis, opts);
|
|
722
|
+
return result.mul(1 / n).astype(originalDtype);
|
|
698
723
|
}
|
|
699
724
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
700
725
|
transpose(perm) {
|
|
@@ -723,7 +748,7 @@ var Tracer = class Tracer {
|
|
|
723
748
|
if (isFloatDtype(this.dtype)) return this.mul(reciprocal$1(other));
|
|
724
749
|
return idiv(this, other);
|
|
725
750
|
}
|
|
726
|
-
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
751
|
+
/** Return specified diagonals. See `jax.numpy.diagonal` for full docs. */
|
|
727
752
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
728
753
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
729
754
|
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
@@ -776,6 +801,34 @@ var Tracer = class Tracer {
|
|
|
776
801
|
this.dispose();
|
|
777
802
|
}
|
|
778
803
|
/**
|
|
804
|
+
* Return a sorted copy of an array in ascending order.
|
|
805
|
+
*
|
|
806
|
+
* See `jax.numpy.sort` for full docs.
|
|
807
|
+
*/
|
|
808
|
+
sort(axis = -1) {
|
|
809
|
+
axis = checkAxis(axis, this.ndim);
|
|
810
|
+
if (this.shape[axis] <= 1) return this;
|
|
811
|
+
if (axis === this.ndim - 1) return sort$1(this);
|
|
812
|
+
const perm = range(this.ndim);
|
|
813
|
+
perm.splice(axis, 1);
|
|
814
|
+
perm.push(axis);
|
|
815
|
+
return sort$1(this.transpose(perm)).transpose(invertPermutation(perm));
|
|
816
|
+
}
|
|
817
|
+
/**
|
|
818
|
+
* Return the indices that would sort an array. This may not be a stable
|
|
819
|
+
* sorting algorithm; it need not preserve order of indices in ties.
|
|
820
|
+
*
|
|
821
|
+
* See `jax.numpy.argsort` for full docs.
|
|
822
|
+
*/
|
|
823
|
+
argsort(axis = -1) {
|
|
824
|
+
axis = checkAxis(axis, this.ndim);
|
|
825
|
+
if (axis === this.ndim - 1) return argsort$1(this)[1];
|
|
826
|
+
const perm = range(this.ndim);
|
|
827
|
+
perm.splice(axis, 1);
|
|
828
|
+
perm.push(axis);
|
|
829
|
+
return argsort$1(this.transpose(perm))[1].transpose(invertPermutation(perm));
|
|
830
|
+
}
|
|
831
|
+
/**
|
|
779
832
|
* Slice an array along one or more axes.
|
|
780
833
|
*
|
|
781
834
|
* This is the equivalent of slicing in Python, e.g. `x[1:3, 2, :, None]`. To
|
|
@@ -892,6 +945,9 @@ var ShapedArray = class ShapedArray {
|
|
|
892
945
|
get ndim() {
|
|
893
946
|
return this.shape.length;
|
|
894
947
|
}
|
|
948
|
+
get size() {
|
|
949
|
+
return prod(this.shape);
|
|
950
|
+
}
|
|
895
951
|
toString() {
|
|
896
952
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
897
953
|
}
|
|
@@ -1170,7 +1226,7 @@ var Jaxpr = class Jaxpr {
|
|
|
1170
1226
|
} else if (eqn.primitive === Primitive.Idiv) {
|
|
1171
1227
|
const [a, b] = inputs;
|
|
1172
1228
|
const c = eqn.outBinders[0];
|
|
1173
|
-
if (atomIsLit(b, 1)) context.set(c, a);
|
|
1229
|
+
if (atomIsLit(b, 1) && !isFloatDtype(a.aval.dtype)) context.set(c, a);
|
|
1174
1230
|
else newEqns.push(eqn);
|
|
1175
1231
|
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
1176
1232
|
else newEqns.push(eqn);
|
|
@@ -1187,13 +1243,13 @@ var Jaxpr = class Jaxpr {
|
|
|
1187
1243
|
}
|
|
1188
1244
|
return new Jaxpr(this.inBinders, liveEqns.reverse(), outs);
|
|
1189
1245
|
}
|
|
1190
|
-
/** Flattens nested
|
|
1246
|
+
/** Flattens nested Jit in a Jaxpr. Useful for handling jit-of-jit. */
|
|
1191
1247
|
flatten() {
|
|
1192
|
-
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.
|
|
1248
|
+
if (!this.eqns.some((eqn) => eqn.primitive === Primitive.Jit)) return this;
|
|
1193
1249
|
const newEqns = [];
|
|
1194
1250
|
const varMap = /* @__PURE__ */ new Map();
|
|
1195
1251
|
const varMapF = (x) => x instanceof Var ? varMap.get(x) ?? x : x;
|
|
1196
|
-
for (const eqn of this.eqns) if (eqn.primitive === Primitive.
|
|
1252
|
+
for (const eqn of this.eqns) if (eqn.primitive === Primitive.Jit) {
|
|
1197
1253
|
const jaxpr = eqn.params.jaxpr.flatten();
|
|
1198
1254
|
const translation = /* @__PURE__ */ new Map();
|
|
1199
1255
|
const translationF = (x) => x instanceof Var ? translation.get(x) : x;
|
|
@@ -1294,19 +1350,48 @@ function evalJaxpr(jaxpr, args) {
|
|
|
1294
1350
|
function jaxprAsFun(jaxpr) {
|
|
1295
1351
|
return (...args) => evalJaxpr(jaxpr, args);
|
|
1296
1352
|
}
|
|
1353
|
+
/** Jaxpr with a collection of associated, traced constants. */
|
|
1354
|
+
var ClosedJaxpr = class ClosedJaxpr {
|
|
1355
|
+
constructor(jaxpr, consts) {
|
|
1356
|
+
this.jaxpr = jaxpr;
|
|
1357
|
+
this.consts = consts;
|
|
1358
|
+
}
|
|
1359
|
+
/** String representation of this Jaxpr. */
|
|
1360
|
+
toString() {
|
|
1361
|
+
return this.jaxpr.toString();
|
|
1362
|
+
}
|
|
1363
|
+
/** Apply a function to the underlying Jaxpr. */
|
|
1364
|
+
mapJaxpr(f) {
|
|
1365
|
+
return new ClosedJaxpr(f(this.jaxpr), this.consts);
|
|
1366
|
+
}
|
|
1367
|
+
/** Dispose of the constants in this Jaxpr. */
|
|
1368
|
+
dispose() {
|
|
1369
|
+
for (const c of this.consts) c.dispose();
|
|
1370
|
+
}
|
|
1371
|
+
};
|
|
1297
1372
|
/** Tracer that records its operations to dynamically construct a Jaxpr. */
|
|
1298
1373
|
var JaxprTracer = class extends Tracer {
|
|
1374
|
+
#rc;
|
|
1299
1375
|
constructor(trace$1, aval) {
|
|
1300
1376
|
super(trace$1);
|
|
1301
1377
|
this.aval = aval;
|
|
1378
|
+
this.#rc = 1;
|
|
1302
1379
|
}
|
|
1303
1380
|
toString() {
|
|
1304
1381
|
return `JaxprTracer(${this.aval.toString()})`;
|
|
1305
1382
|
}
|
|
1306
1383
|
get ref() {
|
|
1384
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1385
|
+
this.#rc++;
|
|
1307
1386
|
return this;
|
|
1308
1387
|
}
|
|
1309
|
-
dispose() {
|
|
1388
|
+
dispose() {
|
|
1389
|
+
if (this.#rc <= 0) throw new UseAfterFreeError(this);
|
|
1390
|
+
this.#rc--;
|
|
1391
|
+
}
|
|
1392
|
+
trackLiftedConstant() {
|
|
1393
|
+
this.#rc++;
|
|
1394
|
+
}
|
|
1310
1395
|
};
|
|
1311
1396
|
/** Analogous to the 'DynamicJaxprTrace' class in JAX. */
|
|
1312
1397
|
var JaxprTrace = class extends Trace {
|
|
@@ -1319,17 +1404,24 @@ var JaxprTrace = class extends Trace {
|
|
|
1319
1404
|
}
|
|
1320
1405
|
/** Register a constant / literal in this Jaxpr. */
|
|
1321
1406
|
getOrMakeConstTracer(val) {
|
|
1407
|
+
if (!(val instanceof Tracer)) val = pureArray(val);
|
|
1322
1408
|
let tracer = this.builder.constTracers.get(val);
|
|
1323
1409
|
if (tracer === void 0) {
|
|
1324
1410
|
tracer = this.builder.newTracer(this, ShapedArray.fromAval(getAval(val)));
|
|
1325
|
-
this.builder.addConst(tracer, val
|
|
1411
|
+
this.builder.addConst(tracer, val);
|
|
1412
|
+
} else {
|
|
1413
|
+
val.dispose();
|
|
1414
|
+
tracer.trackLiftedConstant();
|
|
1326
1415
|
}
|
|
1327
1416
|
return tracer;
|
|
1328
1417
|
}
|
|
1329
1418
|
pure = this.getOrMakeConstTracer;
|
|
1330
1419
|
lift = this.getOrMakeConstTracer;
|
|
1331
1420
|
processPrimitive(primitive, tracers, params) {
|
|
1332
|
-
const avalsIn = tracers.map((t) =>
|
|
1421
|
+
const avalsIn = tracers.map((t) => {
|
|
1422
|
+
t.dispose();
|
|
1423
|
+
return t.aval;
|
|
1424
|
+
});
|
|
1333
1425
|
const avalsOut = abstractEvalRules[primitive](avalsIn, params);
|
|
1334
1426
|
const outTracers = avalsOut.map((aval) => this.builder.newTracer(this, aval));
|
|
1335
1427
|
this.builder.addEqn(new JaxprEqn(primitive, tracers.map((t) => this.builder.getVar(t)), params, outTracers.map((t) => this.builder.addVar(t))));
|
|
@@ -1372,20 +1464,17 @@ var JaxprBuilder = class {
|
|
|
1372
1464
|
return v;
|
|
1373
1465
|
}
|
|
1374
1466
|
build(inTracers, outTracers) {
|
|
1375
|
-
|
|
1467
|
+
const [constVars, consts] = unzip2(this.constVals.entries());
|
|
1376
1468
|
const t2v = this.getVar.bind(this);
|
|
1377
1469
|
const inBinders = [...constVars, ...inTracers.map(t2v)];
|
|
1378
1470
|
const outVars = outTracers.map(t2v);
|
|
1379
|
-
|
|
1471
|
+
const jaxpr = new Jaxpr(inBinders, this.eqns, outVars);
|
|
1380
1472
|
typecheckJaxpr(jaxpr);
|
|
1381
|
-
|
|
1382
|
-
return
|
|
1383
|
-
jaxpr,
|
|
1384
|
-
consts
|
|
1385
|
-
};
|
|
1473
|
+
const cjaxpr = new ClosedJaxpr(jaxpr, consts);
|
|
1474
|
+
return _inlineLiterals(cjaxpr);
|
|
1386
1475
|
}
|
|
1387
1476
|
};
|
|
1388
|
-
function _inlineLiterals(jaxpr, consts) {
|
|
1477
|
+
function _inlineLiterals({ jaxpr, consts }) {
|
|
1389
1478
|
const literals = /* @__PURE__ */ new Map();
|
|
1390
1479
|
const constBinders = [];
|
|
1391
1480
|
const newConsts = [];
|
|
@@ -1400,7 +1489,7 @@ function _inlineLiterals(jaxpr, consts) {
|
|
|
1400
1489
|
const newOuts = jaxpr.outs.map((x) => literals.get(x) ?? x);
|
|
1401
1490
|
const newJaxpr = new Jaxpr([...constBinders, ...jaxpr.inBinders.slice(consts.length)], newEqns, newOuts);
|
|
1402
1491
|
typecheckJaxpr(newJaxpr);
|
|
1403
|
-
return
|
|
1492
|
+
return new ClosedJaxpr(newJaxpr, newConsts);
|
|
1404
1493
|
}
|
|
1405
1494
|
function binopAbstractEval([x, y]) {
|
|
1406
1495
|
if (!(x instanceof ShapedArray) || !(y instanceof ShapedArray)) throw new TypeError("binopAbstractEval expects ShapedArray inputs");
|
|
@@ -1419,6 +1508,8 @@ const abstractEvalRules = {
|
|
|
1419
1508
|
[Primitive.Mul]: binopAbstractEval,
|
|
1420
1509
|
[Primitive.Idiv]: binopAbstractEval,
|
|
1421
1510
|
[Primitive.Mod]: binopAbstractEval,
|
|
1511
|
+
[Primitive.Min]: binopAbstractEval,
|
|
1512
|
+
[Primitive.Max]: binopAbstractEval,
|
|
1422
1513
|
[Primitive.Neg]: vectorizedUnopAbstractEval,
|
|
1423
1514
|
[Primitive.Reciprocal]: vectorizedUnopAbstractEval,
|
|
1424
1515
|
[Primitive.Floor]: vectorizedUnopAbstractEval,
|
|
@@ -1432,12 +1523,6 @@ const abstractEvalRules = {
|
|
|
1432
1523
|
if (byteWidth(x.dtype) !== byteWidth(dtype)) throw new TypeError(`Bitcast from ${x.dtype} to ${dtype} with different byte width`);
|
|
1433
1524
|
return [new ShapedArray(x.shape, dtype, false)];
|
|
1434
1525
|
},
|
|
1435
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1436
|
-
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1437
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1438
|
-
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1439
|
-
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1440
|
-
},
|
|
1441
1526
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
1442
1527
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
1443
1528
|
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
@@ -1447,8 +1532,6 @@ const abstractEvalRules = {
|
|
|
1447
1532
|
[Primitive.Erf]: vectorizedUnopAbstractEval,
|
|
1448
1533
|
[Primitive.Erfc]: vectorizedUnopAbstractEval,
|
|
1449
1534
|
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
1450
|
-
[Primitive.Min]: binopAbstractEval,
|
|
1451
|
-
[Primitive.Max]: binopAbstractEval,
|
|
1452
1535
|
[Primitive.Reduce]([x], { axis }) {
|
|
1453
1536
|
const axisSet = new Set(axis);
|
|
1454
1537
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1481,6 +1564,25 @@ const abstractEvalRules = {
|
|
|
1481
1564
|
const shape$1 = generalBroadcast(cond.shape, xy.shape);
|
|
1482
1565
|
return [new ShapedArray(shape$1, xy.dtype, xy.weakType)];
|
|
1483
1566
|
},
|
|
1567
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1 }) {
|
|
1568
|
+
if (k0.dtype !== DType.Uint32 || k1.dtype !== DType.Uint32) throw new TypeError(`RandomBits requires uint32 keys, got ${k0.dtype} and ${k1.dtype}`);
|
|
1569
|
+
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
1570
|
+
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
1571
|
+
return [new ShapedArray(shape$1, DType.Uint32, false)];
|
|
1572
|
+
},
|
|
1573
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
1574
|
+
for (const a of indices) if (a.dtype !== DType.Int32 && a.dtype !== DType.Uint32) throw new TypeError(`Gather indices must be Int32 or Uint32, got ${a.dtype}`);
|
|
1575
|
+
if (axis.length !== indices.length) throw new TypeError(`Gather: ${axis} axes but ${indices.length} indices`);
|
|
1576
|
+
if (indices.length === 0) throw new TypeError("Gather must have 1+ indices with same shape");
|
|
1577
|
+
if (axis.some((a) => a < 0 || a >= x.shape.length)) throw new TypeError("Gather axis out of bounds");
|
|
1578
|
+
if (outDim < 0 || outDim > x.shape.length - axis.length) throw new TypeError("Gather outDim out of bounds");
|
|
1579
|
+
const axisSet = new Set(axis);
|
|
1580
|
+
if (axisSet.size !== axis.length) throw new TypeError("Gather axes are not unique");
|
|
1581
|
+
const gatherShape = indices.reduce((shape$1, a) => generalBroadcast(shape$1, a.shape), []);
|
|
1582
|
+
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
1583
|
+
newShape.splice(outDim, 0, ...gatherShape);
|
|
1584
|
+
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1585
|
+
},
|
|
1484
1586
|
[Primitive.Transpose]([x], { perm }) {
|
|
1485
1587
|
return [new ShapedArray(perm.map((i) => x.shape[i]), x.dtype, x.weakType)];
|
|
1486
1588
|
},
|
|
@@ -1501,23 +1603,31 @@ const abstractEvalRules = {
|
|
|
1501
1603
|
const newShape = x.shape.map((dim, i) => dim + width[i][0] + width[i][1]);
|
|
1502
1604
|
return [new ShapedArray(newShape, x.dtype, x.weakType)];
|
|
1503
1605
|
},
|
|
1504
|
-
[Primitive.
|
|
1505
|
-
|
|
1506
|
-
|
|
1507
|
-
|
|
1508
|
-
|
|
1509
|
-
if (
|
|
1510
|
-
|
|
1511
|
-
|
|
1512
|
-
|
|
1513
|
-
|
|
1514
|
-
|
|
1515
|
-
|
|
1606
|
+
[Primitive.Sort]([x]) {
|
|
1607
|
+
if (x.ndim === 0) throw new TypeError("sort: requires at least 1D input");
|
|
1608
|
+
return [ShapedArray.fromAval(x)];
|
|
1609
|
+
},
|
|
1610
|
+
[Primitive.Argsort]([x]) {
|
|
1611
|
+
if (x.ndim === 0) throw new TypeError("argsort: requires at least 1D input");
|
|
1612
|
+
return [ShapedArray.fromAval(x), new ShapedArray(x.shape, DType.Int32, false)];
|
|
1613
|
+
},
|
|
1614
|
+
[Primitive.TriangularSolve]([a, b]) {
|
|
1615
|
+
if (a.ndim < 2) throw new TypeError(`triangular_solve: a must be at least 2D, got ${a}`);
|
|
1616
|
+
if (b.ndim < 2) throw new TypeError(`triangular_solve: b must be at least 2D, got ${b}`);
|
|
1617
|
+
const [m, n] = a.shape.slice(-2);
|
|
1618
|
+
const [_batch, q] = b.shape.slice(-2);
|
|
1619
|
+
if (!deepEqual(a.shape.slice(0, -2), b.shape.slice(0, -2)) || a.dtype !== b.dtype || m !== n || n !== q) throw new TypeError(`triangular_solve: mismatch ${a} vs ${b}`);
|
|
1620
|
+
return [new ShapedArray(b.shape, b.dtype, a.weakType && b.weakType)];
|
|
1621
|
+
},
|
|
1622
|
+
[Primitive.Cholesky]([a]) {
|
|
1623
|
+
if (a.ndim < 2) throw new TypeError(`cholesky: requires at least 2D input, got ${a}`);
|
|
1624
|
+
if (a.shape[a.ndim - 2] !== a.shape[a.ndim - 1]) throw new TypeError(`cholesky: must be square, got ${a}`);
|
|
1625
|
+
return [ShapedArray.fromAval(a)];
|
|
1516
1626
|
},
|
|
1517
|
-
[Primitive.
|
|
1627
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
1518
1628
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
1519
|
-
if (args.length !== inTypes.length) throw new TypeError(`
|
|
1520
|
-
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`
|
|
1629
|
+
if (args.length !== inTypes.length) throw new TypeError(`jit expected ${inTypes.length} arguments, got ${args.length}`);
|
|
1630
|
+
for (let i = 0; i < inTypes.length; i++) if (!args[i].equals(inTypes[i])) throw new TypeError(`jit argument ${i} has type ${args[i]}, expected ${inTypes[i]}`);
|
|
1521
1631
|
return outTypes;
|
|
1522
1632
|
}
|
|
1523
1633
|
};
|
|
@@ -1553,11 +1663,10 @@ function makeJaxpr$1(f, opts) {
|
|
|
1553
1663
|
const tracersIn = avalsIn.map((aval) => trace$1.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
1554
1664
|
const outs = fFlat(...tracersIn);
|
|
1555
1665
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
1556
|
-
const
|
|
1666
|
+
const jaxpr = builder.build(tracersIn, tracersOut);
|
|
1557
1667
|
if (outTree.value === void 0) throw new Error("outTree was not set in makeJaxpr");
|
|
1558
1668
|
return {
|
|
1559
|
-
jaxpr: jaxpr.simplify(),
|
|
1560
|
-
consts,
|
|
1669
|
+
jaxpr: jaxpr.mapJaxpr((j) => j.simplify()),
|
|
1561
1670
|
treedef: outTree.value
|
|
1562
1671
|
};
|
|
1563
1672
|
} catch (_) {
|
|
@@ -1576,22 +1685,28 @@ function jit$1(f, opts) {
|
|
|
1576
1685
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
1577
1686
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
1578
1687
|
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
1579
|
-
const { jaxpr,
|
|
1580
|
-
const outs = bind(Primitive.
|
|
1688
|
+
const { jaxpr, treedef: outTree } = runWithCache(cache, jaxprArgs, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
1689
|
+
const outs = bind(Primitive.Jit, [...jaxpr.consts.map((c) => c.ref), ...argsFlat], {
|
|
1581
1690
|
name: f.name || "closure",
|
|
1582
|
-
jaxpr,
|
|
1583
|
-
numConsts: consts.length
|
|
1691
|
+
jaxpr: jaxpr.jaxpr,
|
|
1692
|
+
numConsts: jaxpr.consts.length
|
|
1584
1693
|
});
|
|
1585
1694
|
return unflatten(outTree, outs);
|
|
1586
1695
|
});
|
|
1587
1696
|
result.dispose = () => {
|
|
1588
|
-
for (const {
|
|
1697
|
+
for (const { jaxpr } of cache.values()) jaxpr.dispose();
|
|
1589
1698
|
};
|
|
1590
1699
|
return result;
|
|
1591
1700
|
}
|
|
1592
1701
|
|
|
1593
1702
|
//#endregion
|
|
1594
1703
|
//#region src/frontend/jit.ts
|
|
1704
|
+
const routinePrimitives = new Map([
|
|
1705
|
+
[Primitive.Sort, Routines.Sort],
|
|
1706
|
+
[Primitive.Argsort, Routines.Argsort],
|
|
1707
|
+
[Primitive.TriangularSolve, Routines.TriangularSolve],
|
|
1708
|
+
[Primitive.Cholesky, Routines.Cholesky]
|
|
1709
|
+
]);
|
|
1595
1710
|
/** Result of compiling a Jaxpr. Can be evaluated on a series of inputs. */
|
|
1596
1711
|
var JitProgram = class {
|
|
1597
1712
|
constructor(backend, steps, inputs, outputs) {
|
|
@@ -1606,9 +1721,14 @@ var JitProgram = class {
|
|
|
1606
1721
|
case "execute": {
|
|
1607
1722
|
const inputsNice = step.inputs.map((id, i) => `${i}: %${id}`).join(", ");
|
|
1608
1723
|
const outputsNice = step.outputs.map((id) => `%${id}`).join(", ");
|
|
1609
|
-
|
|
1724
|
+
const executeText = `execute (${inputsNice}) -> ${outputsNice}`;
|
|
1725
|
+
if (step.source instanceof Kernel) return PPrint.pp(`${executeText}, kernel`).concat(step.source.pprint().indent(2));
|
|
1726
|
+
else if (step.source instanceof Routine) return PPrint.pp(`${executeText}, routine ${step.source.name}`);
|
|
1727
|
+
else {
|
|
1728
|
+
step.source;
|
|
1729
|
+
return PPrint.pp(executeText);
|
|
1730
|
+
}
|
|
1610
1731
|
}
|
|
1611
|
-
case "const": return PPrint.pp(`%${step.output} = const <Slot ${step.slot}>`);
|
|
1612
1732
|
case "malloc": return PPrint.pp(`%${step.output} = malloc <${step.size} bytes>`);
|
|
1613
1733
|
case "incref": return PPrint.pp(`incref ${step.input}`);
|
|
1614
1734
|
case "free": return PPrint.pp(`free ${step.input}`);
|
|
@@ -1631,12 +1751,9 @@ var JitProgram = class {
|
|
|
1631
1751
|
const inputs$1 = step.inputs.map((id) => scope.get(id));
|
|
1632
1752
|
const outputs = step.outputs.map((id) => scope.get(id));
|
|
1633
1753
|
if (inputs$1.some((s) => s === void 0) || outputs.some((s) => s === void 0)) throw new Error(`internal: JitProgram scope undefined`);
|
|
1634
|
-
pending.push(new PendingExecute(this.backend, step.
|
|
1754
|
+
pending.push(new PendingExecute(this.backend, step.source, inputs$1, outputs));
|
|
1635
1755
|
break;
|
|
1636
1756
|
}
|
|
1637
|
-
case "const":
|
|
1638
|
-
scope.set(step.output, step.slot);
|
|
1639
|
-
break;
|
|
1640
1757
|
case "malloc": {
|
|
1641
1758
|
const slot = this.backend.malloc(step.size);
|
|
1642
1759
|
scope.set(step.output, slot);
|
|
@@ -1670,34 +1787,37 @@ var JitProgramBuilder = class {
|
|
|
1670
1787
|
this.#nextId = nargs;
|
|
1671
1788
|
this.steps = [];
|
|
1672
1789
|
}
|
|
1673
|
-
pushConst(slot) {
|
|
1674
|
-
const id = this.#nextId++;
|
|
1675
|
-
this.steps.push({
|
|
1676
|
-
type: "const",
|
|
1677
|
-
slot,
|
|
1678
|
-
output: id
|
|
1679
|
-
});
|
|
1680
|
-
return id;
|
|
1681
|
-
}
|
|
1682
1790
|
pushLit(lit) {
|
|
1683
|
-
const kernel = new Kernel(0,
|
|
1791
|
+
const kernel = new Kernel(0, lit.aval.size, AluExp.const(lit.dtype, lit.value));
|
|
1684
1792
|
return this.pushKernel(kernel, []);
|
|
1685
1793
|
}
|
|
1686
|
-
|
|
1794
|
+
pushBuffer(size$1) {
|
|
1687
1795
|
const id = this.#nextId++;
|
|
1688
1796
|
this.steps.push({
|
|
1689
1797
|
type: "malloc",
|
|
1690
|
-
size:
|
|
1798
|
+
size: size$1,
|
|
1691
1799
|
output: id
|
|
1692
1800
|
});
|
|
1801
|
+
return id;
|
|
1802
|
+
}
|
|
1803
|
+
pushKernel(kernel, inputs) {
|
|
1804
|
+
const id = this.pushBuffer(kernel.bytes);
|
|
1693
1805
|
this.steps.push({
|
|
1694
1806
|
type: "execute",
|
|
1695
|
-
kernel,
|
|
1807
|
+
source: kernel,
|
|
1696
1808
|
inputs,
|
|
1697
1809
|
outputs: [id]
|
|
1698
1810
|
});
|
|
1699
1811
|
return id;
|
|
1700
1812
|
}
|
|
1813
|
+
pushRoutine(routine, inputs, outputs) {
|
|
1814
|
+
this.steps.push({
|
|
1815
|
+
type: "execute",
|
|
1816
|
+
source: routine,
|
|
1817
|
+
inputs,
|
|
1818
|
+
outputs
|
|
1819
|
+
});
|
|
1820
|
+
}
|
|
1701
1821
|
pushIncref(id) {
|
|
1702
1822
|
this.steps.push({
|
|
1703
1823
|
type: "incref",
|
|
@@ -1723,28 +1843,18 @@ var JitProgramBuilder = class {
|
|
|
1723
1843
|
}
|
|
1724
1844
|
};
|
|
1725
1845
|
const jitCompileCache = /* @__PURE__ */ new Map();
|
|
1726
|
-
function jitCompile(backend, jaxpr
|
|
1727
|
-
|
|
1728
|
-
for (let i = 0; i < consts.length; i++) if (consts[i].device !== backend.type) throw new TypeError(`Const ${i} has device ${consts[i].device}, but expected ${backend.type}`);
|
|
1729
|
-
const cacheKey = backend.type + FpHash.hash(jaxpr, ...consts.map((c) => c.id));
|
|
1846
|
+
function jitCompile(backend, jaxpr) {
|
|
1847
|
+
const cacheKey = backend.type + "," + FpHash.hash(jaxpr);
|
|
1730
1848
|
const cached = jitCompileCache.get(cacheKey);
|
|
1731
1849
|
if (cached) return cached;
|
|
1732
1850
|
if (DEBUG >= 1) console.info("=========== JIT Compile ===========\n" + jaxpr.toString());
|
|
1733
1851
|
jaxpr = jaxpr.flatten().simplify();
|
|
1734
|
-
const nargs = jaxpr.inBinders.length
|
|
1852
|
+
const nargs = jaxpr.inBinders.length;
|
|
1735
1853
|
const builder = new JitProgramBuilder(backend, nargs);
|
|
1736
1854
|
const blackNodes = splitGraphDataflow(backend, jaxpr);
|
|
1737
1855
|
const ctx = /* @__PURE__ */ new Map();
|
|
1738
|
-
for (let i = 0; i < consts.length; i++) {
|
|
1739
|
-
const v = jaxpr.inBinders[i];
|
|
1740
|
-
const slot = consts[i]._realizeSource();
|
|
1741
|
-
ctx.set(v, {
|
|
1742
|
-
type: "imm",
|
|
1743
|
-
arg: builder.pushConst(slot)
|
|
1744
|
-
});
|
|
1745
|
-
}
|
|
1746
1856
|
for (let i = 0; i < nargs; i++) {
|
|
1747
|
-
const v = jaxpr.inBinders[
|
|
1857
|
+
const v = jaxpr.inBinders[i];
|
|
1748
1858
|
ctx.set(v, {
|
|
1749
1859
|
type: "imm",
|
|
1750
1860
|
arg: i
|
|
@@ -1752,51 +1862,101 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1752
1862
|
}
|
|
1753
1863
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1754
1864
|
const eqn = jaxpr.eqns[i];
|
|
1865
|
+
if (routinePrimitives.has(eqn.primitive)) {
|
|
1866
|
+
const routine = new Routine(routinePrimitives.get(eqn.primitive), {
|
|
1867
|
+
inputShapes: eqn.inputs.map((x) => x.aval.shape),
|
|
1868
|
+
inputDtypes: eqn.inputs.map((x) => x.aval.dtype),
|
|
1869
|
+
outputShapes: eqn.outBinders.map((x) => x.aval.shape),
|
|
1870
|
+
outputDtypes: eqn.outBinders.map((x) => x.aval.dtype)
|
|
1871
|
+
}, eqn.params);
|
|
1872
|
+
const inputs = [];
|
|
1873
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1874
|
+
const jv = ctx.get(input);
|
|
1875
|
+
if (jv.type !== "imm") throw new Error(`jit: routine primitive ${eqn.primitive} input is not imm`);
|
|
1876
|
+
inputs.push(jv.arg);
|
|
1877
|
+
} else if (input instanceof Lit) inputs.push(builder.pushLit(input));
|
|
1878
|
+
const outputs = [];
|
|
1879
|
+
for (const outVar$1 of eqn.outBinders) {
|
|
1880
|
+
const outId = builder.pushBuffer(outVar$1.aval.size * byteWidth(outVar$1.aval.dtype));
|
|
1881
|
+
outputs.push(outId);
|
|
1882
|
+
ctx.set(outVar$1, {
|
|
1883
|
+
type: "imm",
|
|
1884
|
+
arg: outId
|
|
1885
|
+
});
|
|
1886
|
+
}
|
|
1887
|
+
builder.pushRoutine(routine, inputs, outputs);
|
|
1888
|
+
continue;
|
|
1889
|
+
}
|
|
1755
1890
|
const inputExps = [];
|
|
1756
1891
|
const inputAvals = [];
|
|
1757
1892
|
const inputArgs = [];
|
|
1758
|
-
|
|
1759
|
-
|
|
1760
|
-
|
|
1761
|
-
|
|
1762
|
-
|
|
1763
|
-
|
|
1764
|
-
|
|
1765
|
-
|
|
1766
|
-
inputArgs.push(jitId);
|
|
1767
|
-
}
|
|
1768
|
-
gidMap.set(gid, newGid);
|
|
1769
|
-
}
|
|
1770
|
-
inputExps.push(jitValue.exp.reindexGids(gidMap));
|
|
1771
|
-
} else if (jitValue.type === "imm") {
|
|
1772
|
-
let gid = inputArgs.indexOf(jitValue.arg);
|
|
1773
|
-
if (gid === -1) {
|
|
1774
|
-
gid = inputArgs.length;
|
|
1775
|
-
inputArgs.push(jitValue.arg);
|
|
1893
|
+
let inputReduction = null;
|
|
1894
|
+
const addArgs = (args) => {
|
|
1895
|
+
const newGids = [];
|
|
1896
|
+
for (const jitId of args) {
|
|
1897
|
+
let newGid = inputArgs.indexOf(jitId);
|
|
1898
|
+
if (newGid === -1) {
|
|
1899
|
+
newGid = inputArgs.length;
|
|
1900
|
+
inputArgs.push(jitId);
|
|
1776
1901
|
}
|
|
1902
|
+
newGids.push(newGid);
|
|
1903
|
+
}
|
|
1904
|
+
return newGids;
|
|
1905
|
+
};
|
|
1906
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
1907
|
+
const jv = ctx.get(input);
|
|
1908
|
+
if (jv.type === "exp") {
|
|
1909
|
+
const newGids = addArgs(jv.args);
|
|
1910
|
+
inputExps.push(jv.exp.reindexGids(newGids));
|
|
1911
|
+
} else if (jv.type === "imm") {
|
|
1912
|
+
const [gid] = addArgs([jv.arg]);
|
|
1777
1913
|
const st = ShapeTracker.fromShape(input.aval.shape);
|
|
1778
1914
|
const indices = unravelAlu(st.shape, AluVar.gidx);
|
|
1779
1915
|
inputExps.push(AluExp.globalView(input.aval.dtype, gid, st, indices));
|
|
1916
|
+
} else if (jv.type === "red") {
|
|
1917
|
+
if (inputReduction) throw new Error("jit: unexpected, multiple red inputs");
|
|
1918
|
+
const newGids = addArgs(jv.args);
|
|
1919
|
+
inputExps.push(jv.reduction.epilogue.reindexGids(newGids));
|
|
1920
|
+
inputReduction = jv;
|
|
1780
1921
|
}
|
|
1781
1922
|
inputAvals.push(input.aval);
|
|
1782
1923
|
} else if (input instanceof Lit) {
|
|
1783
1924
|
inputExps.push(AluExp.const(input.dtype, input.value));
|
|
1784
1925
|
inputAvals.push(input.aval);
|
|
1785
1926
|
} else throw new TypeError(`Unexpected input in Jaxpr: ${input}`);
|
|
1786
|
-
const nargs$1 = inputArgs.length;
|
|
1787
1927
|
const rule = jitRules[eqn.primitive];
|
|
1788
1928
|
if (!rule) throw new TypeError(`JIT not implemented for primitive ${eqn.primitive}`);
|
|
1789
|
-
|
|
1929
|
+
let exp$2;
|
|
1930
|
+
let reduction;
|
|
1931
|
+
if (inputReduction) {
|
|
1932
|
+
const jv = inputReduction;
|
|
1933
|
+
const newEpilogue = rule(inputExps, inputAvals, eqn.params).exp;
|
|
1934
|
+
exp$2 = jv.exp.reindexGids(addArgs(jv.args));
|
|
1935
|
+
reduction = new Reduction(jv.reduction.dtype, jv.reduction.op, jv.reduction.size, newEpilogue);
|
|
1936
|
+
} else {
|
|
1937
|
+
const ruleOutput = rule(inputExps, inputAvals, eqn.params);
|
|
1938
|
+
exp$2 = ruleOutput.exp;
|
|
1939
|
+
reduction = ruleOutput.reduction;
|
|
1940
|
+
}
|
|
1790
1941
|
const outVar = eqn.outBinders[0];
|
|
1791
|
-
if (
|
|
1942
|
+
if (blackNodes.has(outVar)) {
|
|
1943
|
+
const nargs$1 = inputArgs.length;
|
|
1944
|
+
const size$1 = outVar.aval.size;
|
|
1945
|
+
const kernel = new Kernel(nargs$1, size$1, exp$2, reduction);
|
|
1792
1946
|
const outId = builder.pushKernel(kernel, inputArgs);
|
|
1793
1947
|
ctx.set(outVar, {
|
|
1794
1948
|
type: "imm",
|
|
1795
1949
|
arg: outId
|
|
1796
1950
|
});
|
|
1797
|
-
} else ctx.set(outVar, {
|
|
1951
|
+
} else if (reduction) ctx.set(outVar, {
|
|
1952
|
+
type: "red",
|
|
1953
|
+
exp: exp$2,
|
|
1954
|
+
reduction,
|
|
1955
|
+
args: inputArgs
|
|
1956
|
+
});
|
|
1957
|
+
else ctx.set(outVar, {
|
|
1798
1958
|
type: "exp",
|
|
1799
|
-
exp:
|
|
1959
|
+
exp: exp$2,
|
|
1800
1960
|
args: inputArgs
|
|
1801
1961
|
});
|
|
1802
1962
|
}
|
|
@@ -1806,7 +1966,7 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
1806
1966
|
if (jitValue.type !== "imm") throw new Error("internal: Expected imm, since outs are black nodes");
|
|
1807
1967
|
outputIds.push(jitValue.arg);
|
|
1808
1968
|
} else if (out instanceof Lit) outputIds.push(builder.pushLit(out));
|
|
1809
|
-
const outputNeedsRef = new Set(
|
|
1969
|
+
const outputNeedsRef = new Set(range(nargs));
|
|
1810
1970
|
for (const outputId of outputIds) if (outputNeedsRef.has(outputId)) builder.pushIncref(outputId);
|
|
1811
1971
|
else outputNeedsRef.add(outputId);
|
|
1812
1972
|
builder.insertFreeSteps(outputIds);
|
|
@@ -1828,31 +1988,33 @@ function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
|
1828
1988
|
});
|
|
1829
1989
|
}
|
|
1830
1990
|
function broadcastedJit(fn, opts) {
|
|
1831
|
-
return (
|
|
1991
|
+
return (exps, avals, params) => {
|
|
1832
1992
|
let { shape: newShape, dtype: newDtype } = avals.reduce(promoteAvals);
|
|
1833
1993
|
const skipCastIdx = opts?.skipCastIdx ?? [];
|
|
1834
1994
|
if (skipCastIdx.length) newDtype = avals.filter((_, i) => !skipCastIdx.includes(i)).reduce(promoteAvals).dtype;
|
|
1835
|
-
exps = exps.map((exp$
|
|
1836
|
-
exp$
|
|
1995
|
+
exps = exps.map((exp$2, i) => {
|
|
1996
|
+
exp$2 = reshapeViews(exp$2, (st) => {
|
|
1837
1997
|
if (!deepEqual(st.shape, newShape)) return st.broadcast(newShape, range(newShape.length - st.shape.length));
|
|
1838
1998
|
});
|
|
1839
|
-
if (exp$
|
|
1840
|
-
return exp$
|
|
1999
|
+
if (exp$2.dtype !== newDtype && !skipCastIdx.includes(i)) exp$2 = AluExp.cast(newDtype, exp$2);
|
|
2000
|
+
return exp$2;
|
|
1841
2001
|
});
|
|
1842
|
-
|
|
1843
|
-
return new Kernel(nargs, prod(newShape), exp$2);
|
|
2002
|
+
return { exp: fn(exps, params) };
|
|
1844
2003
|
};
|
|
1845
2004
|
}
|
|
1846
2005
|
function unopJit(fn) {
|
|
1847
|
-
return (
|
|
1848
|
-
return
|
|
2006
|
+
return ([a], [_as], params) => {
|
|
2007
|
+
return { exp: fn(a, params) };
|
|
1849
2008
|
};
|
|
1850
2009
|
}
|
|
1851
2010
|
function reshapeJit(fn) {
|
|
1852
|
-
return (
|
|
1853
|
-
|
|
1854
|
-
|
|
1855
|
-
|
|
2011
|
+
return ([a], [_as], params) => {
|
|
2012
|
+
return { exp: reshapeViews(a, (st) => fn(st, params)) };
|
|
2013
|
+
};
|
|
2014
|
+
}
|
|
2015
|
+
function routineNoJit() {
|
|
2016
|
+
return () => {
|
|
2017
|
+
throw new Error("jit: rule is not implemented for routines");
|
|
1856
2018
|
};
|
|
1857
2019
|
}
|
|
1858
2020
|
const jitRules = {
|
|
@@ -1860,6 +2022,8 @@ const jitRules = {
|
|
|
1860
2022
|
[Primitive.Mul]: broadcastedJit(([a, b]) => AluExp.mul(a, b)),
|
|
1861
2023
|
[Primitive.Idiv]: broadcastedJit(([a, b]) => AluExp.idiv(a, b)),
|
|
1862
2024
|
[Primitive.Mod]: broadcastedJit(([a, b]) => AluExp.mod(a, b)),
|
|
2025
|
+
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
2026
|
+
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1863
2027
|
[Primitive.Neg]: unopJit((a) => AluExp.sub(AluExp.const(a.dtype, 0), a)),
|
|
1864
2028
|
[Primitive.Reciprocal]: unopJit(AluExp.reciprocal),
|
|
1865
2029
|
[Primitive.Floor]: unopJit(AluExp.floor),
|
|
@@ -1867,17 +2031,6 @@ const jitRules = {
|
|
|
1867
2031
|
[Primitive.StopGradient]: unopJit((a) => a),
|
|
1868
2032
|
[Primitive.Cast]: unopJit((a, { dtype }) => AluExp.cast(dtype, a)),
|
|
1869
2033
|
[Primitive.Bitcast]: unopJit((a, { dtype }) => AluExp.bitcast(dtype, a)),
|
|
1870
|
-
[Primitive.RandomBits]: (nargs, keys, keyShapes, { shape: shape$1, mode }) => {
|
|
1871
|
-
const mapping = (st) => {
|
|
1872
|
-
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
|
|
1873
|
-
};
|
|
1874
|
-
const k0 = reshapeViews(keys[0], mapping);
|
|
1875
|
-
const k1 = reshapeViews(keys[1], mapping);
|
|
1876
|
-
const c0 = AluExp.u32(0);
|
|
1877
|
-
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
1878
|
-
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
1879
|
-
return new Kernel(nargs, prod(shape$1), exp$2);
|
|
1880
|
-
},
|
|
1881
2034
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
1882
2035
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
1883
2036
|
[Primitive.Asin]: unopJit(AluExp.asin),
|
|
@@ -1887,9 +2040,7 @@ const jitRules = {
|
|
|
1887
2040
|
[Primitive.Erf]: unopJit(AluExp.erf),
|
|
1888
2041
|
[Primitive.Erfc]: unopJit(AluExp.erfc),
|
|
1889
2042
|
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
1890
|
-
[Primitive.
|
|
1891
|
-
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
1892
|
-
[Primitive.Reduce](nargs, [a], [as], { op, axis }) {
|
|
2043
|
+
[Primitive.Reduce]([a], [as], { op, axis }) {
|
|
1893
2044
|
const keptAxes = [];
|
|
1894
2045
|
const shiftedAxes = [];
|
|
1895
2046
|
const newShape = [];
|
|
@@ -1898,53 +2049,58 @@ const jitRules = {
|
|
|
1898
2049
|
keptAxes.push(i);
|
|
1899
2050
|
newShape.push(as.shape[i]);
|
|
1900
2051
|
}
|
|
1901
|
-
const size$1 = prod(newShape);
|
|
1902
2052
|
const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
1903
2053
|
newShape.push(reductionSize);
|
|
1904
2054
|
const perm = keptAxes.concat(shiftedAxes);
|
|
1905
2055
|
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
1906
2056
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
1907
|
-
return
|
|
2057
|
+
return {
|
|
2058
|
+
exp: a,
|
|
2059
|
+
reduction
|
|
2060
|
+
};
|
|
1908
2061
|
},
|
|
1909
2062
|
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1910
|
-
[Primitive.PoolTranspose](
|
|
2063
|
+
[Primitive.PoolTranspose]([a], [as], { inShape, window, strides }) {
|
|
1911
2064
|
let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1912
|
-
const size$1 = prod(inShape);
|
|
1913
2065
|
stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
|
|
1914
2066
|
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1915
2067
|
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1916
|
-
return
|
|
2068
|
+
return {
|
|
2069
|
+
exp: a,
|
|
2070
|
+
reduction
|
|
2071
|
+
};
|
|
1917
2072
|
},
|
|
1918
|
-
[Primitive.Dot](
|
|
1919
|
-
const k1 = jitRules[Primitive.Mul](
|
|
2073
|
+
[Primitive.Dot]([a, b], [as, bs]) {
|
|
2074
|
+
const k1 = jitRules[Primitive.Mul]([a, b], [as, bs], {});
|
|
1920
2075
|
const c = k1.exp;
|
|
1921
2076
|
const cs = promoteAvals(as, bs);
|
|
1922
|
-
return jitRules[Primitive.Reduce](
|
|
2077
|
+
return jitRules[Primitive.Reduce]([c], [cs], {
|
|
1923
2078
|
op: AluOp.Add,
|
|
1924
2079
|
axis: [cs.ndim - 1]
|
|
1925
2080
|
});
|
|
1926
2081
|
},
|
|
1927
|
-
[Primitive.Conv](
|
|
2082
|
+
[Primitive.Conv]([a, b], [as, bs], params) {
|
|
1928
2083
|
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1929
2084
|
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1930
2085
|
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1931
2086
|
as = new ShapedArray(stX.shape, as.dtype, as.weakType);
|
|
1932
2087
|
bs = new ShapedArray(stY.shape, bs.dtype, bs.weakType);
|
|
1933
|
-
return jitRules[Primitive.Dot](
|
|
2088
|
+
return jitRules[Primitive.Dot]([a, b], [as, bs], {});
|
|
1934
2089
|
},
|
|
1935
2090
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
1936
2091
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b), { skipCastIdx: [0] }),
|
|
1937
|
-
[Primitive.
|
|
1938
|
-
|
|
1939
|
-
|
|
1940
|
-
|
|
1941
|
-
const
|
|
1942
|
-
|
|
1943
|
-
|
|
1944
|
-
|
|
1945
|
-
|
|
1946
|
-
|
|
1947
|
-
|
|
2092
|
+
[Primitive.RandomBits]: (keys, keyShapes, { shape: shape$1, mode }) => {
|
|
2093
|
+
const mapping = (st) => {
|
|
2094
|
+
if (!deepEqual(st.shape, shape$1)) return st.broadcast(shape$1, range(shape$1.length - st.shape.length));
|
|
2095
|
+
};
|
|
2096
|
+
const k0 = reshapeViews(keys[0], mapping);
|
|
2097
|
+
const k1 = reshapeViews(keys[1], mapping);
|
|
2098
|
+
const c0 = AluExp.u32(0);
|
|
2099
|
+
const c1 = AluExp.cast(DType.Uint32, AluVar.gidx);
|
|
2100
|
+
const exp$2 = AluExp.threefry2x32(k0, k1, c0, c1, mode);
|
|
2101
|
+
return { exp: exp$2 };
|
|
2102
|
+
},
|
|
2103
|
+
[Primitive.Gather]([x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1948
2104
|
const axisSet = new Set(axis);
|
|
1949
2105
|
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1950
2106
|
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
@@ -1957,24 +2113,38 @@ const jitRules = {
|
|
|
1957
2113
|
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
1958
2114
|
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1959
2115
|
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1960
|
-
return
|
|
2116
|
+
return { exp: x.substitute({ gidx: index }) };
|
|
1961
2117
|
},
|
|
1962
|
-
[Primitive.
|
|
1963
|
-
|
|
2118
|
+
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
2119
|
+
[Primitive.Broadcast]: reshapeJit((st, { shape: shape$1, axis }) => st.broadcast(shape$1, axis)),
|
|
2120
|
+
[Primitive.Reshape]: reshapeJit((st, { shape: shape$1 }) => st.reshape(shape$1)),
|
|
2121
|
+
[Primitive.Flip]: reshapeJit((st, { axis }) => {
|
|
2122
|
+
const arg = rep(st.shape.length, false);
|
|
2123
|
+
for (const ax of axis) arg[ax] = true;
|
|
2124
|
+
return st.flip(arg);
|
|
2125
|
+
}),
|
|
2126
|
+
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
2127
|
+
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
2128
|
+
[Primitive.Sort]: routineNoJit(),
|
|
2129
|
+
[Primitive.Argsort]: routineNoJit(),
|
|
2130
|
+
[Primitive.TriangularSolve]: routineNoJit(),
|
|
2131
|
+
[Primitive.Cholesky]: routineNoJit(),
|
|
2132
|
+
[Primitive.Jit]() {
|
|
2133
|
+
throw new Error("internal: Jit should have been flattened before JIT compilation");
|
|
1964
2134
|
}
|
|
1965
2135
|
};
|
|
1966
2136
|
/** Determines how to split the Jaxpr into kernels via dataflow analysis. */
|
|
1967
2137
|
function splitGraphDataflow(backend, jaxpr) {
|
|
1968
|
-
const
|
|
2138
|
+
const varToDefn = /* @__PURE__ */ new Map();
|
|
2139
|
+
const varToUsages = /* @__PURE__ */ new Map();
|
|
1969
2140
|
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1970
2141
|
const eqn = jaxpr.eqns[i];
|
|
1971
|
-
for (const v of eqn.outBinders) if (v instanceof Var)
|
|
1972
|
-
|
|
1973
|
-
|
|
1974
|
-
|
|
1975
|
-
|
|
1976
|
-
|
|
1977
|
-
p1NextBlack.set(v, v);
|
|
2142
|
+
for (const v of eqn.outBinders) if (v instanceof Var) varToDefn.set(v, i);
|
|
2143
|
+
for (const input of eqn.inputs) if (input instanceof Var) {
|
|
2144
|
+
const usages = varToUsages.get(input);
|
|
2145
|
+
if (usages) usages.push(i);
|
|
2146
|
+
else varToUsages.set(input, [i]);
|
|
2147
|
+
}
|
|
1978
2148
|
}
|
|
1979
2149
|
const reducePrimitives = [
|
|
1980
2150
|
Primitive.Reduce,
|
|
@@ -1982,28 +2152,94 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1982
2152
|
Primitive.Conv,
|
|
1983
2153
|
Primitive.PoolTranspose
|
|
1984
2154
|
];
|
|
1985
|
-
const
|
|
1986
|
-
|
|
2155
|
+
const reductionEpilogueEqns = /* @__PURE__ */ new Set();
|
|
2156
|
+
const reductionEndpointEqns = /* @__PURE__ */ new Set();
|
|
2157
|
+
for (let i = 0; i < jaxpr.eqns.length; i++) {
|
|
1987
2158
|
const eqn = jaxpr.eqns[i];
|
|
1988
|
-
if (reducePrimitives.includes(eqn.primitive)
|
|
1989
|
-
|
|
1990
|
-
|
|
1991
|
-
|
|
2159
|
+
if (reducePrimitives.includes(eqn.primitive)) {
|
|
2160
|
+
let head = i;
|
|
2161
|
+
while (true) {
|
|
2162
|
+
reductionEpilogueEqns.add(head);
|
|
2163
|
+
const outVar = jaxpr.eqns[head].outBinders[0];
|
|
2164
|
+
const usages = varToUsages.get(outVar) ?? [];
|
|
2165
|
+
if (jaxpr.outs.includes(outVar) || usages.length !== 1) break;
|
|
2166
|
+
if (reductionEpilogueEqns.has(usages[0])) break;
|
|
2167
|
+
const nextEqn = jaxpr.eqns[usages[0]];
|
|
2168
|
+
switch (nextEqn.primitive) {
|
|
2169
|
+
case Primitive.Neg:
|
|
2170
|
+
case Primitive.Reciprocal:
|
|
2171
|
+
case Primitive.Floor:
|
|
2172
|
+
case Primitive.Ceil:
|
|
2173
|
+
case Primitive.StopGradient:
|
|
2174
|
+
case Primitive.Cast:
|
|
2175
|
+
case Primitive.Bitcast:
|
|
2176
|
+
case Primitive.Sin:
|
|
2177
|
+
case Primitive.Cos:
|
|
2178
|
+
case Primitive.Asin:
|
|
2179
|
+
case Primitive.Atan:
|
|
2180
|
+
case Primitive.Exp:
|
|
2181
|
+
case Primitive.Log:
|
|
2182
|
+
case Primitive.Erf:
|
|
2183
|
+
case Primitive.Erfc:
|
|
2184
|
+
case Primitive.Sqrt:
|
|
2185
|
+
head = usages[0];
|
|
2186
|
+
continue;
|
|
2187
|
+
case Primitive.Add:
|
|
2188
|
+
case Primitive.Mul:
|
|
2189
|
+
case Primitive.Idiv:
|
|
2190
|
+
case Primitive.Mod:
|
|
2191
|
+
case Primitive.Min:
|
|
2192
|
+
case Primitive.Max: {
|
|
2193
|
+
const otherInput = nextEqn.inputs.find((v) => v !== outVar);
|
|
2194
|
+
if (otherInput instanceof Lit || deepEqual(generalBroadcast(otherInput.aval.shape, outVar.aval.shape), outVar.aval.shape)) {
|
|
2195
|
+
head = usages[0];
|
|
2196
|
+
continue;
|
|
2197
|
+
}
|
|
2198
|
+
break;
|
|
2199
|
+
}
|
|
2200
|
+
}
|
|
2201
|
+
break;
|
|
1992
2202
|
}
|
|
1993
|
-
|
|
2203
|
+
reductionEndpointEqns.add(head);
|
|
1994
2204
|
}
|
|
1995
|
-
|
|
1996
|
-
|
|
1997
|
-
|
|
1998
|
-
|
|
1999
|
-
|
|
2000
|
-
|
|
2001
|
-
|
|
2002
|
-
|
|
2003
|
-
|
|
2205
|
+
}
|
|
2206
|
+
const blackNodes = /* @__PURE__ */ new Set();
|
|
2207
|
+
const p1NextBlack = /* @__PURE__ */ new Map();
|
|
2208
|
+
for (const v of jaxpr.outs) if (v instanceof Var) {
|
|
2209
|
+
blackNodes.add(v);
|
|
2210
|
+
p1NextBlack.set(v, v);
|
|
2211
|
+
}
|
|
2212
|
+
const heterogeneousViewPrimitives = [Primitive.RandomBits, Primitive.Gather];
|
|
2213
|
+
const needsCleanShapePrimitives = [Primitive.Pad];
|
|
2214
|
+
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
2215
|
+
const eqn = jaxpr.eqns[i];
|
|
2216
|
+
if (reductionEndpointEqns.has(i) || heterogeneousViewPrimitives.includes(eqn.primitive) || routinePrimitives.has(eqn.primitive) || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
2217
|
+
for (const v of eqn.outBinders) {
|
|
2218
|
+
blackNodes.add(v);
|
|
2219
|
+
p1NextBlack.set(v, v);
|
|
2220
|
+
}
|
|
2221
|
+
continue;
|
|
2222
|
+
}
|
|
2223
|
+
const reach = /* @__PURE__ */ new Set();
|
|
2224
|
+
let needsCleanOutput = false;
|
|
2225
|
+
outer: for (const v of eqn.outBinders) for (const j of varToUsages.get(v) ?? []) {
|
|
2226
|
+
if (needsCleanShapePrimitives.includes(jaxpr.eqns[j].primitive) || routinePrimitives.has(jaxpr.eqns[j].primitive)) {
|
|
2227
|
+
needsCleanOutput = true;
|
|
2228
|
+
break outer;
|
|
2229
|
+
}
|
|
2230
|
+
for (const o of jaxpr.eqns[j].outBinders) {
|
|
2231
|
+
const u = p1NextBlack.get(o);
|
|
2232
|
+
if (u) reach.add(u);
|
|
2233
|
+
}
|
|
2234
|
+
}
|
|
2235
|
+
if (reach.size > 1 || needsCleanOutput) for (const v of eqn.outBinders) {
|
|
2004
2236
|
blackNodes.add(v);
|
|
2005
2237
|
p1NextBlack.set(v, v);
|
|
2006
2238
|
}
|
|
2239
|
+
else if (reach.size === 1) {
|
|
2240
|
+
const b = reach.values().next().value;
|
|
2241
|
+
for (const v of eqn.outBinders) p1NextBlack.set(v, b);
|
|
2242
|
+
}
|
|
2007
2243
|
}
|
|
2008
2244
|
const p2Deps = /* @__PURE__ */ new Map();
|
|
2009
2245
|
for (const v of jaxpr.inBinders) p2Deps.set(v, new Set([v]));
|
|
@@ -2011,7 +2247,6 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2011
2247
|
while (p2idx < jaxpr.eqns.length) {
|
|
2012
2248
|
const eqn = jaxpr.eqns[p2idx++];
|
|
2013
2249
|
const deps = [];
|
|
2014
|
-
if (eqn.outBinders.some((v) => blackNodes.has(v))) continue;
|
|
2015
2250
|
for (const input of eqn.inputs) if (input instanceof Var) if (blackNodes.has(input)) deps.push(new Set([input]));
|
|
2016
2251
|
else deps.push(p2Deps.get(input));
|
|
2017
2252
|
else deps.push(/* @__PURE__ */ new Set());
|
|
@@ -2022,7 +2257,7 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2022
2257
|
let assocInput = -1;
|
|
2023
2258
|
for (let i = 0; i < eqn.inputs.length; i++) {
|
|
2024
2259
|
const input = eqn.inputs[i];
|
|
2025
|
-
if (input instanceof Var &&
|
|
2260
|
+
if (input instanceof Var && varToDefn.has(input)) {
|
|
2026
2261
|
let uniqueDeps = 0;
|
|
2027
2262
|
for (const dep of deps[i]) if (depCounter.get(dep) === 1) uniqueDeps++;
|
|
2028
2263
|
if (uniqueDeps > maxUniqueDeps) {
|
|
@@ -2033,8 +2268,8 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
2033
2268
|
}
|
|
2034
2269
|
if (assocInput === -1) throw new Error(`internal: maxArgs, no input found to mark as black in Jaxpr equation ${eqn}`);
|
|
2035
2270
|
const assocVar = eqn.inputs[assocInput];
|
|
2036
|
-
p2idx =
|
|
2037
|
-
for (const out of jaxpr.eqns[p2idx].outBinders) blackNodes.add(out);
|
|
2271
|
+
p2idx = varToDefn.get(assocVar);
|
|
2272
|
+
for (const out of jaxpr.eqns[p2idx++].outBinders) blackNodes.add(out);
|
|
2038
2273
|
} else {
|
|
2039
2274
|
const s = new Set(depCounter.keys());
|
|
2040
2275
|
for (const out of eqn.outBinders) p2Deps.set(out, s);
|
|
@@ -2060,9 +2295,9 @@ var PendingExecute = class {
|
|
|
2060
2295
|
submitted = false;
|
|
2061
2296
|
#promise = null;
|
|
2062
2297
|
#rc = 1;
|
|
2063
|
-
constructor(backend,
|
|
2298
|
+
constructor(backend, source, inputs, outputs) {
|
|
2064
2299
|
this.backend = backend;
|
|
2065
|
-
this.
|
|
2300
|
+
this.source = source;
|
|
2066
2301
|
this.inputs = inputs;
|
|
2067
2302
|
this.outputs = outputs;
|
|
2068
2303
|
for (const slot of inputs) this.backend.incRef(slot);
|
|
@@ -2083,13 +2318,15 @@ var PendingExecute = class {
|
|
|
2083
2318
|
return;
|
|
2084
2319
|
}
|
|
2085
2320
|
this.#promise = (async () => {
|
|
2086
|
-
this.prepared = await this.backend.
|
|
2321
|
+
if (this.source instanceof Kernel) this.prepared = await this.backend.prepareKernel(this.source);
|
|
2322
|
+
else this.prepared = await this.backend.prepareRoutine(this.source);
|
|
2087
2323
|
})();
|
|
2088
2324
|
await this.#promise;
|
|
2089
2325
|
}
|
|
2090
2326
|
prepareSync() {
|
|
2091
2327
|
if (this.prepared) return;
|
|
2092
|
-
this.prepared = this.backend.
|
|
2328
|
+
if (this.source instanceof Kernel) this.prepared = this.backend.prepareKernelSync(this.source);
|
|
2329
|
+
else this.prepared = this.backend.prepareRoutineSync(this.source);
|
|
2093
2330
|
}
|
|
2094
2331
|
submit() {
|
|
2095
2332
|
if (this.submitted) return;
|
|
@@ -2112,8 +2349,6 @@ var PendingExecute = class {
|
|
|
2112
2349
|
* "Array" type by name.
|
|
2113
2350
|
*/
|
|
2114
2351
|
var Array$1 = class Array$1 extends Tracer {
|
|
2115
|
-
static #nextId = 1001;
|
|
2116
|
-
id;
|
|
2117
2352
|
#dtype;
|
|
2118
2353
|
#weakType;
|
|
2119
2354
|
#source;
|
|
@@ -2130,7 +2365,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2130
2365
|
*/
|
|
2131
2366
|
constructor(args) {
|
|
2132
2367
|
super(baseArrayTrace);
|
|
2133
|
-
this.id = Array$1.#nextId++;
|
|
2134
2368
|
this.#dtype = args.dtype;
|
|
2135
2369
|
this.#weakType = args.weakType;
|
|
2136
2370
|
this.#source = args.source;
|
|
@@ -2439,6 +2673,27 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2439
2673
|
pending
|
|
2440
2674
|
});
|
|
2441
2675
|
}
|
|
2676
|
+
/** Apply an operation with custom lowering to this array. */
|
|
2677
|
+
static #routine(routine, arrays, outputWeakType) {
|
|
2678
|
+
const { backend, committed } = Array$1.#computeBackend(routine.name, arrays);
|
|
2679
|
+
for (const ar of arrays) ar.#realize();
|
|
2680
|
+
const inputs = arrays.map((ar) => ar.#source);
|
|
2681
|
+
const outputs = routine.type.outputDtypes.map((dtype, i) => backend.malloc(byteWidth(dtype) * prod(routine.type.outputShapes[i])));
|
|
2682
|
+
const pending = arrays.flatMap((ar) => ar.#pending);
|
|
2683
|
+
for (const exe of pending) exe.updateRc(+outputs.length);
|
|
2684
|
+
pending.push(new PendingExecute(backend, routine, inputs, outputs));
|
|
2685
|
+
pending[pending.length - 1].updateRc(+outputs.length - 1);
|
|
2686
|
+
arrays.forEach((ar) => ar.dispose());
|
|
2687
|
+
return outputs.map((output, i) => new Array$1({
|
|
2688
|
+
source: output,
|
|
2689
|
+
st: ShapeTracker.fromShape(routine.type.outputShapes[i]),
|
|
2690
|
+
dtype: routine.type.outputDtypes[i],
|
|
2691
|
+
weakType: outputWeakType[i],
|
|
2692
|
+
backend,
|
|
2693
|
+
committed,
|
|
2694
|
+
pending
|
|
2695
|
+
}));
|
|
2696
|
+
}
|
|
2442
2697
|
/**
|
|
2443
2698
|
* Normalizes this array into one backed by a `Slot`.
|
|
2444
2699
|
*
|
|
@@ -2599,6 +2854,12 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2599
2854
|
[Primitive.Mod]([x, y]) {
|
|
2600
2855
|
return [x.#binary(AluOp.Mod, y)];
|
|
2601
2856
|
},
|
|
2857
|
+
[Primitive.Min]([x, y]) {
|
|
2858
|
+
return [x.#binary(AluOp.Min, y)];
|
|
2859
|
+
},
|
|
2860
|
+
[Primitive.Max]([x, y]) {
|
|
2861
|
+
return [x.#binary(AluOp.Max, y)];
|
|
2862
|
+
},
|
|
2602
2863
|
[Primitive.Neg]([x]) {
|
|
2603
2864
|
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
2604
2865
|
},
|
|
@@ -2635,25 +2896,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2635
2896
|
return [y];
|
|
2636
2897
|
}
|
|
2637
2898
|
},
|
|
2638
|
-
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2639
|
-
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2640
|
-
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2641
|
-
const c0 = zeros(shape$1, {
|
|
2642
|
-
dtype: DType.Uint32,
|
|
2643
|
-
device: k0.device
|
|
2644
|
-
});
|
|
2645
|
-
const c1 = arange(0, prod(shape$1), 1, {
|
|
2646
|
-
dtype: DType.Uint32,
|
|
2647
|
-
device: k0.device
|
|
2648
|
-
}).reshape(shape$1);
|
|
2649
|
-
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2650
|
-
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2651
|
-
k0,
|
|
2652
|
-
k1,
|
|
2653
|
-
c0,
|
|
2654
|
-
c1
|
|
2655
|
-
])];
|
|
2656
|
-
},
|
|
2657
2899
|
[Primitive.Sin]([x]) {
|
|
2658
2900
|
return [x.#unary(AluOp.Sin)];
|
|
2659
2901
|
},
|
|
@@ -2681,12 +2923,6 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2681
2923
|
[Primitive.Sqrt]([x]) {
|
|
2682
2924
|
return [x.#unary(AluOp.Sqrt)];
|
|
2683
2925
|
},
|
|
2684
|
-
[Primitive.Min]([x, y]) {
|
|
2685
|
-
return [x.#binary(AluOp.Min, y)];
|
|
2686
|
-
},
|
|
2687
|
-
[Primitive.Max]([x, y]) {
|
|
2688
|
-
return [x.#binary(AluOp.Max, y)];
|
|
2689
|
-
},
|
|
2690
2926
|
[Primitive.Reduce]([x], { op, axis }) {
|
|
2691
2927
|
if (axis.length === 0) return [x];
|
|
2692
2928
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
@@ -2721,6 +2957,28 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2721
2957
|
y
|
|
2722
2958
|
], { dtypeOverride: [DType.Bool] })];
|
|
2723
2959
|
},
|
|
2960
|
+
[Primitive.RandomBits]([k0, k1], { shape: shape$1, mode }) {
|
|
2961
|
+
const keyShape = generalBroadcast(k0.shape, k1.shape);
|
|
2962
|
+
if (!deepEqual(generalBroadcast(keyShape, shape$1), shape$1)) throw new TypeError(`Keys of shapes ${k0.shape} and ${k1.shape} cannot be broadcast to shape ${shape$1}`);
|
|
2963
|
+
const c0 = zeros(shape$1, {
|
|
2964
|
+
dtype: DType.Uint32,
|
|
2965
|
+
device: k0.device
|
|
2966
|
+
});
|
|
2967
|
+
const c1 = arange(0, prod(shape$1), 1, {
|
|
2968
|
+
dtype: DType.Uint32,
|
|
2969
|
+
device: k0.device
|
|
2970
|
+
}).reshape(shape$1);
|
|
2971
|
+
const custom = ([k0$1, k1$1, c0$1, c1$1]) => AluExp.threefry2x32(k0$1, k1$1, c0$1, c1$1, mode);
|
|
2972
|
+
return [Array$1.#naryCustom("random_bits", custom, [
|
|
2973
|
+
k0,
|
|
2974
|
+
k1,
|
|
2975
|
+
c0,
|
|
2976
|
+
c1
|
|
2977
|
+
])];
|
|
2978
|
+
},
|
|
2979
|
+
[Primitive.Gather]([x, ...indices], { axis, outDim }) {
|
|
2980
|
+
return [x.#gather(indices, axis, outDim)];
|
|
2981
|
+
},
|
|
2724
2982
|
[Primitive.Transpose]([x], { perm }) {
|
|
2725
2983
|
return [x.#transpose(perm)];
|
|
2726
2984
|
},
|
|
@@ -2741,17 +2999,48 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
2741
2999
|
[Primitive.Pad]([x], { width }) {
|
|
2742
3000
|
return [x.#reshape(x.#st.pad(width))];
|
|
2743
3001
|
},
|
|
2744
|
-
[Primitive.
|
|
2745
|
-
|
|
3002
|
+
[Primitive.Sort]([x]) {
|
|
3003
|
+
const routine = new Routine(Routines.Sort, {
|
|
3004
|
+
inputShapes: [x.aval.shape],
|
|
3005
|
+
inputDtypes: [x.aval.dtype],
|
|
3006
|
+
outputShapes: [x.aval.shape],
|
|
3007
|
+
outputDtypes: [x.aval.dtype]
|
|
3008
|
+
});
|
|
3009
|
+
return Array$1.#routine(routine, [x], [x.#weakType]);
|
|
3010
|
+
},
|
|
3011
|
+
[Primitive.Argsort]([x]) {
|
|
3012
|
+
const routine = new Routine(Routines.Argsort, {
|
|
3013
|
+
inputShapes: [x.aval.shape],
|
|
3014
|
+
inputDtypes: [x.aval.dtype],
|
|
3015
|
+
outputShapes: [x.aval.shape, x.aval.shape],
|
|
3016
|
+
outputDtypes: [x.aval.dtype, DType.Int32]
|
|
3017
|
+
});
|
|
3018
|
+
return Array$1.#routine(routine, [x], [x.#weakType, false]);
|
|
3019
|
+
},
|
|
3020
|
+
[Primitive.TriangularSolve]([a, b], { unitDiagonal }) {
|
|
3021
|
+
const routine = new Routine(Routines.TriangularSolve, {
|
|
3022
|
+
inputShapes: [a.aval.shape, b.aval.shape],
|
|
3023
|
+
inputDtypes: [a.aval.dtype, b.aval.dtype],
|
|
3024
|
+
outputShapes: [b.aval.shape],
|
|
3025
|
+
outputDtypes: [b.aval.dtype]
|
|
3026
|
+
}, { unitDiagonal });
|
|
3027
|
+
return Array$1.#routine(routine, [a, b], [a.#weakType && b.#weakType]);
|
|
2746
3028
|
},
|
|
2747
|
-
[Primitive.
|
|
2748
|
-
|
|
2749
|
-
|
|
3029
|
+
[Primitive.Cholesky]([a]) {
|
|
3030
|
+
const routine = new Routine(Routines.Cholesky, {
|
|
3031
|
+
inputShapes: [a.aval.shape],
|
|
3032
|
+
inputDtypes: [a.aval.dtype],
|
|
3033
|
+
outputShapes: [a.aval.shape],
|
|
3034
|
+
outputDtypes: [a.aval.dtype]
|
|
3035
|
+
});
|
|
3036
|
+
return Array$1.#routine(routine, [a], [a.#weakType]);
|
|
3037
|
+
},
|
|
3038
|
+
[Primitive.Jit](args, { jaxpr }) {
|
|
3039
|
+
if (jaxpr.inBinders.length !== args.length) throw new Error(`jit expects ${jaxpr.inBinders.length} args, got ${args.length}`);
|
|
3040
|
+
const { backend, committed } = Array$1.#computeBackend("jit", args);
|
|
2750
3041
|
args = args.map((ar) => ar._putSync(backend));
|
|
2751
|
-
const
|
|
2752
|
-
const
|
|
2753
|
-
const jp = jitCompile(backend, jaxpr, consts);
|
|
2754
|
-
const { outputs, pending } = jp.execute(tracers.map((x) => x._realizeSource()));
|
|
3042
|
+
const jp = jitCompile(backend, jaxpr);
|
|
3043
|
+
const { outputs, pending } = jp.execute(args.map((x) => x._realizeSource()));
|
|
2755
3044
|
for (const exe of pending) exe.updateRc(+outputs.length - 1);
|
|
2756
3045
|
const prevPending = [...new Set(args.flatMap((x) => x.#pending))];
|
|
2757
3046
|
for (const exe of prevPending) exe.updateRc(+outputs.length);
|
|
@@ -3050,6 +3339,43 @@ function arange(start, stop, step = 1, { dtype, device } = {}) {
|
|
|
3050
3339
|
});
|
|
3051
3340
|
}
|
|
3052
3341
|
/**
|
|
3342
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
3343
|
+
*
|
|
3344
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
3345
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
3346
|
+
* `k>0` is above it.
|
|
3347
|
+
*/
|
|
3348
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
3349
|
+
m ??= n;
|
|
3350
|
+
dtype ??= DType.Float32;
|
|
3351
|
+
if (!Number.isInteger(n) || n < 0) throw new Error(`tri: n must be a non-negative integer, got ${n}`);
|
|
3352
|
+
if (!Number.isInteger(m) || m < 0) throw new Error(`tri: m must be a non-negative integer, got ${m}`);
|
|
3353
|
+
if (!Number.isInteger(k)) throw new Error(`tri: k must be an integer, got ${k}`);
|
|
3354
|
+
const rows = arange(k, n + k, 1, {
|
|
3355
|
+
dtype: DType.Int32,
|
|
3356
|
+
device
|
|
3357
|
+
});
|
|
3358
|
+
const cols = arange(0, m, 1, {
|
|
3359
|
+
dtype: DType.Int32,
|
|
3360
|
+
device
|
|
3361
|
+
});
|
|
3362
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
3363
|
+
}
|
|
3364
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
3365
|
+
function tril(a, k = 0) {
|
|
3366
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3367
|
+
a = fudgeArray(a);
|
|
3368
|
+
const [n, m] = a.shape.slice(-2);
|
|
3369
|
+
return where$1(tri(n, m, k, { dtype: DType.Bool }), a.ref, zerosLike$1(a));
|
|
3370
|
+
}
|
|
3371
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
3372
|
+
function triu(a, k = 0) {
|
|
3373
|
+
if (ndim$1(a) < 2) throw new Error(`tril: input array must be at least 2D, got ${ndim$1(a)}D`);
|
|
3374
|
+
a = fudgeArray(a);
|
|
3375
|
+
const [n, m] = a.shape.slice(-2);
|
|
3376
|
+
return where$1(tri(n, m, k - 1, { dtype: DType.Bool }), zerosLike$1(a.ref), a);
|
|
3377
|
+
}
|
|
3378
|
+
/**
|
|
3053
3379
|
* Return evenly spaced numbers over a specified interval.
|
|
3054
3380
|
*
|
|
3055
3381
|
* Returns _num_ evenly spaced samples, calculated over the interval
|
|
@@ -3096,333 +3422,106 @@ function aluCompare(a, b, op) {
|
|
|
3096
3422
|
}
|
|
3097
3423
|
|
|
3098
3424
|
//#endregion
|
|
3099
|
-
//#region src/frontend/
|
|
3100
|
-
|
|
3101
|
-
|
|
3425
|
+
//#region src/frontend/vmap.ts
|
|
3426
|
+
function mappedAval(batchDim, aval) {
|
|
3427
|
+
const shape$1 = [...aval.shape];
|
|
3428
|
+
shape$1.splice(batchDim, 1);
|
|
3429
|
+
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3430
|
+
}
|
|
3431
|
+
/** Move one axis to a different index. */
|
|
3432
|
+
function moveaxis(x, src, dst) {
|
|
3433
|
+
const t = pureArray(x);
|
|
3434
|
+
src = checkAxis(src, t.ndim);
|
|
3435
|
+
dst = checkAxis(dst, t.ndim);
|
|
3436
|
+
if (src === dst) return t;
|
|
3437
|
+
const perm = range(t.ndim);
|
|
3438
|
+
perm.splice(src, 1);
|
|
3439
|
+
perm.splice(dst, 0, src);
|
|
3440
|
+
return transpose$1(t, perm);
|
|
3441
|
+
}
|
|
3442
|
+
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3443
|
+
if (src === null) {
|
|
3444
|
+
const targetShape = [...x.shape];
|
|
3445
|
+
targetShape.splice(dst, 0, axisSize);
|
|
3446
|
+
return broadcast(x, targetShape, [dst]);
|
|
3447
|
+
} else if (src === dst) return x;
|
|
3448
|
+
else return moveaxis(x, src, dst);
|
|
3449
|
+
}
|
|
3450
|
+
var BatchTracer = class extends Tracer {
|
|
3451
|
+
constructor(trace$1, val, batchDim) {
|
|
3102
3452
|
super(trace$1);
|
|
3103
|
-
this.
|
|
3104
|
-
this.
|
|
3453
|
+
this.val = val;
|
|
3454
|
+
this.batchDim = batchDim;
|
|
3105
3455
|
}
|
|
3106
3456
|
get aval() {
|
|
3107
|
-
return this.
|
|
3457
|
+
if (this.batchDim === null) return this.val.aval;
|
|
3458
|
+
else return mappedAval(this.batchDim, this.val.aval);
|
|
3108
3459
|
}
|
|
3109
3460
|
toString() {
|
|
3110
|
-
return `
|
|
3461
|
+
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3111
3462
|
}
|
|
3112
3463
|
get ref() {
|
|
3113
|
-
this.
|
|
3464
|
+
this.val.ref;
|
|
3114
3465
|
return this;
|
|
3115
3466
|
}
|
|
3116
3467
|
dispose() {
|
|
3117
|
-
this.
|
|
3118
|
-
|
|
3468
|
+
this.val.dispose();
|
|
3469
|
+
}
|
|
3470
|
+
fullLower() {
|
|
3471
|
+
if (this.batchDim === null) return this.val.fullLower();
|
|
3472
|
+
else return this;
|
|
3119
3473
|
}
|
|
3120
3474
|
};
|
|
3121
|
-
var
|
|
3475
|
+
var BatchTrace = class extends Trace {
|
|
3122
3476
|
pure(val) {
|
|
3123
3477
|
return this.lift(pureArray(val));
|
|
3124
3478
|
}
|
|
3125
3479
|
lift(val) {
|
|
3126
|
-
return new
|
|
3480
|
+
return new BatchTracer(this, val, null);
|
|
3127
3481
|
}
|
|
3128
3482
|
processPrimitive(primitive, tracers, params) {
|
|
3129
|
-
const [
|
|
3130
|
-
const
|
|
3131
|
-
if (
|
|
3132
|
-
|
|
3133
|
-
|
|
3483
|
+
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3484
|
+
const vmapRule = vmapRules[primitive];
|
|
3485
|
+
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3486
|
+
if (bdimsIn.every((d) => d === null)) {
|
|
3487
|
+
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3488
|
+
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3489
|
+
}
|
|
3490
|
+
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3491
|
+
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3492
|
+
}
|
|
3493
|
+
get axisSize() {
|
|
3494
|
+
return this.main.globalData;
|
|
3134
3495
|
}
|
|
3135
3496
|
};
|
|
3136
|
-
/**
|
|
3137
|
-
|
|
3138
|
-
|
|
3139
|
-
|
|
3140
|
-
|
|
3141
|
-
|
|
3142
|
-
|
|
3143
|
-
|
|
3144
|
-
|
|
3145
|
-
|
|
3146
|
-
|
|
3147
|
-
|
|
3148
|
-
|
|
3149
|
-
|
|
3497
|
+
/**
|
|
3498
|
+
* Process a primitive with built-in broadcasting.
|
|
3499
|
+
*
|
|
3500
|
+
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3501
|
+
*/
|
|
3502
|
+
function broadcastBatcher(op) {
|
|
3503
|
+
return (axisSize, args, dims) => {
|
|
3504
|
+
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3505
|
+
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3506
|
+
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3507
|
+
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3508
|
+
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3509
|
+
args = args.map((x, i) => {
|
|
3510
|
+
if (dims[i] === null) return x;
|
|
3511
|
+
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3512
|
+
if (x.ndim < nd) x = x.reshape([
|
|
3513
|
+
x.shape[0],
|
|
3514
|
+
...rep(nd - x.ndim, 1),
|
|
3515
|
+
...x.shape.slice(1)
|
|
3516
|
+
]);
|
|
3517
|
+
return x;
|
|
3518
|
+
});
|
|
3519
|
+
return [[op(...args)], [0]];
|
|
3150
3520
|
};
|
|
3151
3521
|
}
|
|
3152
|
-
|
|
3153
|
-
|
|
3154
|
-
|
|
3155
|
-
for (const t of tangents) t.dispose();
|
|
3156
|
-
const ys = bind(primitive, primals, params);
|
|
3157
|
-
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3158
|
-
};
|
|
3159
|
-
}
|
|
3160
|
-
const jvpRules = {
|
|
3161
|
-
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3162
|
-
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
3163
|
-
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
3164
|
-
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
3165
|
-
if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
|
|
3166
|
-
dx.dispose();
|
|
3167
|
-
dy.dispose();
|
|
3168
|
-
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
3169
|
-
}
|
|
3170
|
-
const q = idiv(x.ref, y.ref);
|
|
3171
|
-
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
3172
|
-
},
|
|
3173
|
-
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
3174
|
-
[Primitive.Reciprocal]([x], [dx]) {
|
|
3175
|
-
const xRecip = reciprocal$1(x.ref);
|
|
3176
|
-
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
3177
|
-
},
|
|
3178
|
-
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
3179
|
-
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
3180
|
-
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
3181
|
-
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
3182
|
-
if (x.dtype === dtype) return [[x], [dx]];
|
|
3183
|
-
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
3184
|
-
else {
|
|
3185
|
-
dx.dispose();
|
|
3186
|
-
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3187
|
-
}
|
|
3188
|
-
},
|
|
3189
|
-
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3190
|
-
if (x.dtype === dtype) return [[x], [dx]];
|
|
3191
|
-
dx.dispose();
|
|
3192
|
-
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3193
|
-
},
|
|
3194
|
-
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3195
|
-
[Primitive.Sin]([x], [dx]) {
|
|
3196
|
-
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3197
|
-
},
|
|
3198
|
-
[Primitive.Cos]([x], [dx]) {
|
|
3199
|
-
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
3200
|
-
},
|
|
3201
|
-
[Primitive.Asin]([x], [dx]) {
|
|
3202
|
-
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3203
|
-
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3204
|
-
},
|
|
3205
|
-
[Primitive.Atan]([x], [dx]) {
|
|
3206
|
-
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3207
|
-
return [[atan$1(x)], [dx.div(denom)]];
|
|
3208
|
-
},
|
|
3209
|
-
[Primitive.Exp]([x], [dx]) {
|
|
3210
|
-
const z = exp$1(x);
|
|
3211
|
-
return [[z.ref], [z.mul(dx)]];
|
|
3212
|
-
},
|
|
3213
|
-
[Primitive.Log]([x], [dx]) {
|
|
3214
|
-
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3215
|
-
},
|
|
3216
|
-
[Primitive.Erf]([x], [dx]) {
|
|
3217
|
-
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3218
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3219
|
-
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3220
|
-
},
|
|
3221
|
-
[Primitive.Erfc]([x], [dx]) {
|
|
3222
|
-
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3223
|
-
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3224
|
-
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3225
|
-
},
|
|
3226
|
-
[Primitive.Sqrt]([x], [dx]) {
|
|
3227
|
-
const z = sqrt$1(x);
|
|
3228
|
-
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3229
|
-
},
|
|
3230
|
-
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3231
|
-
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3232
|
-
},
|
|
3233
|
-
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3234
|
-
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3235
|
-
},
|
|
3236
|
-
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3237
|
-
if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3238
|
-
else if (op === AluOp.Mul) {
|
|
3239
|
-
const primal = reduce(x.ref, op, axis);
|
|
3240
|
-
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3241
|
-
return [[primal], [tangent]];
|
|
3242
|
-
} else if (op === AluOp.Min || op === AluOp.Max) {
|
|
3243
|
-
const primal = reduce(x.ref, op, axis);
|
|
3244
|
-
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3245
|
-
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3246
|
-
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3247
|
-
return [[primal], [tangent]];
|
|
3248
|
-
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3249
|
-
},
|
|
3250
|
-
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3251
|
-
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3252
|
-
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3253
|
-
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3254
|
-
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3255
|
-
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3256
|
-
dcond.dispose();
|
|
3257
|
-
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3258
|
-
},
|
|
3259
|
-
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3260
|
-
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3261
|
-
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3262
|
-
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3263
|
-
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3264
|
-
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3265
|
-
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3266
|
-
const indicesRef = indices.map((t) => t.ref);
|
|
3267
|
-
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3268
|
-
},
|
|
3269
|
-
[Primitive.JitCall](primals, tangents, { name, jaxpr }) {
|
|
3270
|
-
const { newJaxpr, newConsts } = jvpJaxpr(jaxpr);
|
|
3271
|
-
const outs = bind(Primitive.JitCall, [
|
|
3272
|
-
...newConsts.map((c) => c.ref),
|
|
3273
|
-
...primals,
|
|
3274
|
-
...tangents
|
|
3275
|
-
], {
|
|
3276
|
-
name: `${name}_jvp`,
|
|
3277
|
-
jaxpr: newJaxpr,
|
|
3278
|
-
numConsts: newConsts.length
|
|
3279
|
-
});
|
|
3280
|
-
const n = outs.length / 2;
|
|
3281
|
-
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3282
|
-
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3283
|
-
return [primalsOut, tangentsOut];
|
|
3284
|
-
}
|
|
3285
|
-
};
|
|
3286
|
-
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3287
|
-
function jvpJaxpr(jaxpr) {
|
|
3288
|
-
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3289
|
-
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3290
|
-
const { jaxpr: newJaxpr, consts: newConsts } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3291
|
-
const result = {
|
|
3292
|
-
newJaxpr,
|
|
3293
|
-
newConsts
|
|
3294
|
-
};
|
|
3295
|
-
jvpJaxprCache.set(jaxpr, result);
|
|
3296
|
-
return result;
|
|
3297
|
-
}
|
|
3298
|
-
function jvpFlat(f, primals, tangents) {
|
|
3299
|
-
try {
|
|
3300
|
-
var _usingCtx$1 = _usingCtx();
|
|
3301
|
-
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3302
|
-
const trace$1 = new JVPTrace(main);
|
|
3303
|
-
const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3304
|
-
const outs = f(...tracersIn);
|
|
3305
|
-
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3306
|
-
return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3307
|
-
} catch (_) {
|
|
3308
|
-
_usingCtx$1.e = _;
|
|
3309
|
-
} finally {
|
|
3310
|
-
_usingCtx$1.d();
|
|
3311
|
-
}
|
|
3312
|
-
}
|
|
3313
|
-
function jvp$1(f, primals, tangents) {
|
|
3314
|
-
const [primalsFlat, inTree] = flatten(primals);
|
|
3315
|
-
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
3316
|
-
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
3317
|
-
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
3318
|
-
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
3319
|
-
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
3320
|
-
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3321
|
-
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
3322
|
-
return [primalsOut, tangentsOut];
|
|
3323
|
-
}
|
|
3324
|
-
|
|
3325
|
-
//#endregion
|
|
3326
|
-
//#region src/frontend/vmap.ts
|
|
3327
|
-
function mappedAval(batchDim, aval) {
|
|
3328
|
-
const shape$1 = [...aval.shape];
|
|
3329
|
-
shape$1.splice(batchDim, 1);
|
|
3330
|
-
return new ShapedArray(shape$1, aval.dtype, aval.weakType);
|
|
3331
|
-
}
|
|
3332
|
-
/** Move one axis to a different index. */
|
|
3333
|
-
function moveaxis(x, src, dst) {
|
|
3334
|
-
const t = pureArray(x);
|
|
3335
|
-
src = checkAxis(src, t.ndim);
|
|
3336
|
-
dst = checkAxis(dst, t.ndim);
|
|
3337
|
-
if (src === dst) return t;
|
|
3338
|
-
const perm = range(t.ndim);
|
|
3339
|
-
perm.splice(src, 1);
|
|
3340
|
-
perm.splice(dst, 0, src);
|
|
3341
|
-
return transpose$1(t, perm);
|
|
3342
|
-
}
|
|
3343
|
-
function moveBatchAxis(axisSize, src, dst, x) {
|
|
3344
|
-
if (src === null) {
|
|
3345
|
-
const targetShape = [...x.shape];
|
|
3346
|
-
targetShape.splice(dst, 0, axisSize);
|
|
3347
|
-
return broadcast(x, targetShape, [dst]);
|
|
3348
|
-
} else if (src === dst) return x;
|
|
3349
|
-
else return moveaxis(x, src, dst);
|
|
3350
|
-
}
|
|
3351
|
-
var BatchTracer = class extends Tracer {
|
|
3352
|
-
constructor(trace$1, val, batchDim) {
|
|
3353
|
-
super(trace$1);
|
|
3354
|
-
this.val = val;
|
|
3355
|
-
this.batchDim = batchDim;
|
|
3356
|
-
}
|
|
3357
|
-
get aval() {
|
|
3358
|
-
if (this.batchDim === null) return this.val.aval;
|
|
3359
|
-
else return mappedAval(this.batchDim, this.val.aval);
|
|
3360
|
-
}
|
|
3361
|
-
toString() {
|
|
3362
|
-
return `BatchTracer(${this.val.toString()}, ${this.batchDim})`;
|
|
3363
|
-
}
|
|
3364
|
-
get ref() {
|
|
3365
|
-
this.val.ref;
|
|
3366
|
-
return this;
|
|
3367
|
-
}
|
|
3368
|
-
dispose() {
|
|
3369
|
-
this.val.dispose();
|
|
3370
|
-
}
|
|
3371
|
-
fullLower() {
|
|
3372
|
-
if (this.batchDim === null) return this.val.fullLower();
|
|
3373
|
-
else return this;
|
|
3374
|
-
}
|
|
3375
|
-
};
|
|
3376
|
-
var BatchTrace = class extends Trace {
|
|
3377
|
-
pure(val) {
|
|
3378
|
-
return this.lift(pureArray(val));
|
|
3379
|
-
}
|
|
3380
|
-
lift(val) {
|
|
3381
|
-
return new BatchTracer(this, val, null);
|
|
3382
|
-
}
|
|
3383
|
-
processPrimitive(primitive, tracers, params) {
|
|
3384
|
-
const [valsIn, bdimsIn] = unzip2(tracers.map((t) => [t.val, t.batchDim]));
|
|
3385
|
-
const vmapRule = vmapRules[primitive];
|
|
3386
|
-
if (vmapRule === void 0) throw new Error(`No vmap rule for: ${primitive}`);
|
|
3387
|
-
if (bdimsIn.every((d) => d === null)) {
|
|
3388
|
-
const valOuts$1 = bind(primitive, valsIn, params);
|
|
3389
|
-
return valOuts$1.map((x) => new BatchTracer(this, x, null));
|
|
3390
|
-
}
|
|
3391
|
-
const [valOuts, bdimOuts] = vmapRule(this.axisSize, valsIn, bdimsIn, params);
|
|
3392
|
-
return zip(valOuts, bdimOuts).map(([x, bd]) => new BatchTracer(this, x, bd));
|
|
3393
|
-
}
|
|
3394
|
-
get axisSize() {
|
|
3395
|
-
return this.main.globalData;
|
|
3396
|
-
}
|
|
3397
|
-
};
|
|
3398
|
-
/**
|
|
3399
|
-
* Process a primitive with built-in broadcasting.
|
|
3400
|
-
*
|
|
3401
|
-
* Reference: https://github.com/jax-ml/jax/blob/jax-v0.8.1/jax/_src/interpreters/batching.py#L1029
|
|
3402
|
-
*/
|
|
3403
|
-
function broadcastBatcher(op) {
|
|
3404
|
-
return (axisSize, args, dims) => {
|
|
3405
|
-
if (args.length === 0) throw new Error("Empty list in broadcastBatcher");
|
|
3406
|
-
const nd = Math.max(...args.map((x, i) => ndim$1(x) + (dims[i] === null ? 1 : 0)));
|
|
3407
|
-
const firstIdx = dims.findIndex((d) => d !== null);
|
|
3408
|
-
const firstBdim = dims[firstIdx] - args[firstIdx].ndim;
|
|
3409
|
-
if (zip(args, dims).every(([x, d]) => d === null && ndim$1(x) < -firstBdim || d !== null && d - x.ndim === firstBdim)) return [[op(...args)], [nd + firstBdim]];
|
|
3410
|
-
args = args.map((x, i) => {
|
|
3411
|
-
if (dims[i] === null) return x;
|
|
3412
|
-
x = moveBatchAxis(axisSize, dims[i], 0, x);
|
|
3413
|
-
if (x.ndim < nd) x = x.reshape([
|
|
3414
|
-
x.shape[0],
|
|
3415
|
-
...rep(nd - x.ndim, 1),
|
|
3416
|
-
...x.shape.slice(1)
|
|
3417
|
-
]);
|
|
3418
|
-
return x;
|
|
3419
|
-
});
|
|
3420
|
-
return [[op(...args)], [0]];
|
|
3421
|
-
};
|
|
3422
|
-
}
|
|
3423
|
-
function unopBatcher(op) {
|
|
3424
|
-
return (axisSize, [x], [xBdim], params) => {
|
|
3425
|
-
return [[op(x, params)], [xBdim]];
|
|
3522
|
+
function unopBatcher(op) {
|
|
3523
|
+
return (axisSize, [x], [xBdim], params) => {
|
|
3524
|
+
return [[op(x, params)], [xBdim]];
|
|
3426
3525
|
};
|
|
3427
3526
|
}
|
|
3428
3527
|
const vmapRules = {
|
|
@@ -3430,6 +3529,8 @@ const vmapRules = {
|
|
|
3430
3529
|
[Primitive.Mul]: broadcastBatcher(mul),
|
|
3431
3530
|
[Primitive.Idiv]: broadcastBatcher(idiv),
|
|
3432
3531
|
[Primitive.Mod]: broadcastBatcher(mod),
|
|
3532
|
+
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3533
|
+
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3433
3534
|
[Primitive.Neg]: unopBatcher(neg),
|
|
3434
3535
|
[Primitive.Reciprocal]: unopBatcher(reciprocal$1),
|
|
3435
3536
|
[Primitive.Floor]: unopBatcher(floor$1),
|
|
@@ -3446,8 +3547,6 @@ const vmapRules = {
|
|
|
3446
3547
|
[Primitive.Erf]: unopBatcher(erf$1),
|
|
3447
3548
|
[Primitive.Erfc]: unopBatcher(erfc$1),
|
|
3448
3549
|
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
3449
|
-
[Primitive.Min]: broadcastBatcher(min$1),
|
|
3450
|
-
[Primitive.Max]: broadcastBatcher(max$1),
|
|
3451
3550
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
3452
3551
|
assertNonNull(xBdim);
|
|
3453
3552
|
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
@@ -3460,10 +3559,49 @@ const vmapRules = {
|
|
|
3460
3559
|
const z = dot$2(x, y);
|
|
3461
3560
|
return [[z], [z.ndim - 1]];
|
|
3462
3561
|
},
|
|
3562
|
+
[Primitive.Conv](axisSize, [x, y], [xBdim, yBdim], params) {
|
|
3563
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3564
|
+
y = moveBatchAxis(axisSize, yBdim, 0, y);
|
|
3565
|
+
const z = conv$1(x, y, {
|
|
3566
|
+
...params,
|
|
3567
|
+
vmapDims: params.vmapDims + 1
|
|
3568
|
+
});
|
|
3569
|
+
return [[z], [0]];
|
|
3570
|
+
},
|
|
3463
3571
|
[Primitive.Compare](axisSize, args, dims, { op }) {
|
|
3464
3572
|
return broadcastBatcher((x, y) => compare(x, y, op))(axisSize, args, dims, {});
|
|
3465
3573
|
},
|
|
3466
|
-
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3574
|
+
[Primitive.Where]: broadcastBatcher(where$1),
|
|
3575
|
+
[Primitive.Gather](axisSize, [x, ...indices], [xBdim, ...indicesBdim], { axis, outDim }) {
|
|
3576
|
+
if (indicesBdim.every((d) => d === null)) {
|
|
3577
|
+
assertNonNull(xBdim);
|
|
3578
|
+
const newAxis = axis.map((ax) => ax + (xBdim <= ax ? 1 : 0));
|
|
3579
|
+
let newBdim = xBdim - axis.filter((ax) => ax < xBdim).length;
|
|
3580
|
+
let newOutDim = outDim;
|
|
3581
|
+
if (newOutDim < newBdim) newBdim += axis.length;
|
|
3582
|
+
else newOutDim += 1;
|
|
3583
|
+
return [[gather(x, indices, newAxis, newOutDim)], [newBdim]];
|
|
3584
|
+
}
|
|
3585
|
+
const nd = Math.max(...indices.map((m, i) => ndim$1(m) + (indicesBdim[i] === null ? 1 : 0)));
|
|
3586
|
+
indices = indices.map((m, i) => {
|
|
3587
|
+
if (indicesBdim[i] === null) return m;
|
|
3588
|
+
m = moveBatchAxis(axisSize, indicesBdim[i], 0, m);
|
|
3589
|
+
if (m.ndim < nd) m = m.reshape([
|
|
3590
|
+
m.shape[0],
|
|
3591
|
+
...rep(nd - m.ndim, 1),
|
|
3592
|
+
...m.shape.slice(1)
|
|
3593
|
+
]);
|
|
3594
|
+
return m;
|
|
3595
|
+
});
|
|
3596
|
+
if (xBdim === null) return [[gather(x, indices, axis, outDim)], [outDim]];
|
|
3597
|
+
else {
|
|
3598
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3599
|
+
const newAxis = [0, ...axis.map((ax) => ax + 1)];
|
|
3600
|
+
const extraBatchIndex = arange(axisSize).reshape([-1, ...rep(nd - 1, 1)]);
|
|
3601
|
+
indices.splice(0, 0, extraBatchIndex);
|
|
3602
|
+
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3603
|
+
}
|
|
3604
|
+
},
|
|
3467
3605
|
[Primitive.Transpose](axisSize, [x], [xBdim], { perm }) {
|
|
3468
3606
|
assertNonNull(xBdim);
|
|
3469
3607
|
const newPerm = perm.map((p) => p + (xBdim <= p ? 1 : 0));
|
|
@@ -3495,42 +3633,53 @@ const vmapRules = {
|
|
|
3495
3633
|
const newWidth = width.toSpliced(xBdim, 0, [0, 0]);
|
|
3496
3634
|
return [[pad$1(x, newWidth)], [xBdim]];
|
|
3497
3635
|
},
|
|
3498
|
-
[Primitive.
|
|
3499
|
-
|
|
3500
|
-
|
|
3501
|
-
|
|
3502
|
-
|
|
3503
|
-
|
|
3504
|
-
|
|
3505
|
-
|
|
3506
|
-
|
|
3507
|
-
|
|
3508
|
-
|
|
3509
|
-
|
|
3510
|
-
|
|
3511
|
-
|
|
3512
|
-
|
|
3513
|
-
|
|
3514
|
-
|
|
3515
|
-
...
|
|
3636
|
+
[Primitive.Sort](axisSize, [x], [xBdim]) {
|
|
3637
|
+
assertNonNull(xBdim);
|
|
3638
|
+
if (xBdim !== x.ndim - 1) return [[sort$1(x)], [xBdim]];
|
|
3639
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3640
|
+
return [[sort$1(x)], [0]];
|
|
3641
|
+
},
|
|
3642
|
+
[Primitive.Argsort](axisSize, [x], [xBdim]) {
|
|
3643
|
+
assertNonNull(xBdim);
|
|
3644
|
+
if (xBdim !== x.ndim - 1) return [argsort$1(x), [xBdim, xBdim]];
|
|
3645
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3646
|
+
return [argsort$1(x), [0, 0]];
|
|
3647
|
+
},
|
|
3648
|
+
[Primitive.TriangularSolve](axisSize, [a, b], [aBdim, bBdim], { unitDiagonal }) {
|
|
3649
|
+
if (aBdim === null) {
|
|
3650
|
+
b = moveBatchAxis(axisSize, bBdim, -3, b);
|
|
3651
|
+
const [s, m, n] = b.shape.slice(-3);
|
|
3652
|
+
b = b.reshape([
|
|
3653
|
+
...b.shape.slice(0, -3),
|
|
3654
|
+
s * m,
|
|
3655
|
+
n
|
|
3516
3656
|
]);
|
|
3517
|
-
|
|
3518
|
-
|
|
3519
|
-
|
|
3520
|
-
|
|
3521
|
-
|
|
3522
|
-
|
|
3523
|
-
|
|
3524
|
-
|
|
3525
|
-
return [[gather(x, indices, newAxis, outDim)], [outDim]];
|
|
3657
|
+
let x$1 = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3658
|
+
x$1 = x$1.reshape([
|
|
3659
|
+
...b.shape.slice(0, -2),
|
|
3660
|
+
s,
|
|
3661
|
+
m,
|
|
3662
|
+
n
|
|
3663
|
+
]);
|
|
3664
|
+
return [[x$1], [x$1.ndim - 3]];
|
|
3526
3665
|
}
|
|
3666
|
+
a = moveBatchAxis(axisSize, aBdim, 0, a);
|
|
3667
|
+
b = moveBatchAxis(axisSize, bBdim, 0, b);
|
|
3668
|
+
const x = bind1(Primitive.TriangularSolve, [a, b], { unitDiagonal });
|
|
3669
|
+
return [[x], [0]];
|
|
3670
|
+
},
|
|
3671
|
+
[Primitive.Cholesky](axisSize, [x], [xBdim]) {
|
|
3672
|
+
assertNonNull(xBdim);
|
|
3673
|
+
if (xBdim < x.ndim - 2) return [[cholesky$2(x)], [xBdim]];
|
|
3674
|
+
x = moveBatchAxis(axisSize, xBdim, 0, x);
|
|
3675
|
+
return [[cholesky$2(x)], [0]];
|
|
3527
3676
|
},
|
|
3528
|
-
[Primitive.
|
|
3529
|
-
const
|
|
3530
|
-
const outs = bind(Primitive.
|
|
3677
|
+
[Primitive.Jit](axisSize, args, dims, { name, jaxpr }) {
|
|
3678
|
+
const newJaxpr = vmapJaxpr(jaxpr, axisSize, dims);
|
|
3679
|
+
const outs = bind(Primitive.Jit, [...newJaxpr.consts.map((c) => c.ref), ...args], {
|
|
3531
3680
|
name: `${name}_vmap`,
|
|
3532
|
-
jaxpr: newJaxpr,
|
|
3533
|
-
numConsts:
|
|
3681
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3682
|
+
numConsts: newJaxpr.consts.length
|
|
3534
3683
|
});
|
|
3535
3684
|
return [outs, rep(outs.length, 0)];
|
|
3536
3685
|
}
|
|
@@ -3546,14 +3695,10 @@ function vmapJaxpr(jaxpr, axisSize, dims) {
|
|
|
3546
3695
|
shape$1.splice(dims[i], 0, axisSize);
|
|
3547
3696
|
return new ShapedArray(shape$1, v.aval.dtype, v.aval.weakType);
|
|
3548
3697
|
});
|
|
3549
|
-
const { jaxpr: newJaxpr
|
|
3550
|
-
const result = {
|
|
3551
|
-
newJaxpr,
|
|
3552
|
-
newConsts
|
|
3553
|
-
};
|
|
3698
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((args) => vmapFlat(jaxprAsFun(jaxpr), dims, args))(inAvals);
|
|
3554
3699
|
if (!vmapJaxprCache.has(jaxpr)) vmapJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
3555
|
-
vmapJaxprCache.get(jaxpr).set(cacheKey,
|
|
3556
|
-
return
|
|
3700
|
+
vmapJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
3701
|
+
return newJaxpr;
|
|
3557
3702
|
}
|
|
3558
3703
|
function vmapFlat(f, inAxes, args) {
|
|
3559
3704
|
let axisSize = void 0;
|
|
@@ -3608,6 +3753,260 @@ function jacfwd$1(f) {
|
|
|
3608
3753
|
};
|
|
3609
3754
|
}
|
|
3610
3755
|
|
|
3756
|
+
//#endregion
|
|
3757
|
+
//#region src/frontend/jvp.ts
|
|
3758
|
+
var JVPTracer = class extends Tracer {
|
|
3759
|
+
constructor(trace$1, primal, tangent) {
|
|
3760
|
+
super(trace$1);
|
|
3761
|
+
this.primal = primal;
|
|
3762
|
+
this.tangent = tangent;
|
|
3763
|
+
}
|
|
3764
|
+
get aval() {
|
|
3765
|
+
return this.primal.aval;
|
|
3766
|
+
}
|
|
3767
|
+
toString() {
|
|
3768
|
+
return `JVPTracer(${this.primal.toString()}, ${this.tangent.toString()})`;
|
|
3769
|
+
}
|
|
3770
|
+
get ref() {
|
|
3771
|
+
this.primal.ref, this.tangent.ref;
|
|
3772
|
+
return this;
|
|
3773
|
+
}
|
|
3774
|
+
dispose() {
|
|
3775
|
+
this.primal.dispose();
|
|
3776
|
+
this.tangent.dispose();
|
|
3777
|
+
}
|
|
3778
|
+
};
|
|
3779
|
+
var JVPTrace = class extends Trace {
|
|
3780
|
+
pure(val) {
|
|
3781
|
+
return this.lift(pureArray(val));
|
|
3782
|
+
}
|
|
3783
|
+
lift(val) {
|
|
3784
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
3785
|
+
}
|
|
3786
|
+
processPrimitive(primitive, tracers, params) {
|
|
3787
|
+
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
3788
|
+
const jvpRule = jvpRules[primitive];
|
|
3789
|
+
if (jvpRule === void 0) throw new Error(`No JVP rule for: ${primitive}`);
|
|
3790
|
+
const [primalsOut, tangentsOut] = jvpRule(primalsIn, tangentsIn, params);
|
|
3791
|
+
return zip(primalsOut, tangentsOut).map(([x, t]) => new JVPTracer(this, x, t));
|
|
3792
|
+
}
|
|
3793
|
+
};
|
|
3794
|
+
/** Rule that applies the same operation to primals and tangents. */
|
|
3795
|
+
function linearTangentsJvp(primitive) {
|
|
3796
|
+
return (primals, tangents, params) => {
|
|
3797
|
+
const ys = bind(primitive, primals, params);
|
|
3798
|
+
const dys = bind(primitive, tangents, params);
|
|
3799
|
+
return [ys, dys];
|
|
3800
|
+
};
|
|
3801
|
+
}
|
|
3802
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
3803
|
+
function bilinearTangentsJvp(primitive) {
|
|
3804
|
+
return ([x, y], [dx, dy], params) => {
|
|
3805
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
3806
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
3807
|
+
return [[primal], [tangent]];
|
|
3808
|
+
};
|
|
3809
|
+
}
|
|
3810
|
+
/** Rule that zeros out any tangents. */
|
|
3811
|
+
function zeroTangentsJvp(primitive) {
|
|
3812
|
+
return (primals, tangents, params) => {
|
|
3813
|
+
for (const t of tangents) t.dispose();
|
|
3814
|
+
const ys = bind(primitive, primals, params);
|
|
3815
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
3816
|
+
};
|
|
3817
|
+
}
|
|
3818
|
+
/** Compute `a @ b.T`, batched to last two axes. */
|
|
3819
|
+
function batchMatmulT(a, b) {
|
|
3820
|
+
return dot$2(a.reshape(a.shape.toSpliced(-1, 0, 1)), b.reshape(b.shape.toSpliced(-2, 0, 1)));
|
|
3821
|
+
}
|
|
3822
|
+
/** Batch matrix transpose. */
|
|
3823
|
+
function mT(a) {
|
|
3824
|
+
return moveaxis(a, -2, -1);
|
|
3825
|
+
}
|
|
3826
|
+
const jvpRules = {
|
|
3827
|
+
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
3828
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
3829
|
+
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
3830
|
+
[Primitive.Mod]([x, y], [dx, dy]) {
|
|
3831
|
+
if (!isFloatDtype(x.dtype) && !isFloatDtype(y.dtype)) {
|
|
3832
|
+
dx.dispose();
|
|
3833
|
+
dy.dispose();
|
|
3834
|
+
return [[x.ref, y.ref], [zerosLike$1(x), zerosLike$1(y)]];
|
|
3835
|
+
}
|
|
3836
|
+
const q = idiv(x.ref, y.ref);
|
|
3837
|
+
return [[mod(x, y)], [dx.sub(dy.mul(q))]];
|
|
3838
|
+
},
|
|
3839
|
+
[Primitive.Min]([x, y], [dx, dy]) {
|
|
3840
|
+
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
3841
|
+
},
|
|
3842
|
+
[Primitive.Max]([x, y], [dx, dy]) {
|
|
3843
|
+
return [[max$1(x.ref, y.ref)], [where$1(less$1(x, y), dy, dx)]];
|
|
3844
|
+
},
|
|
3845
|
+
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
3846
|
+
[Primitive.Reciprocal]([x], [dx]) {
|
|
3847
|
+
const xRecip = reciprocal$1(x.ref);
|
|
3848
|
+
return [[xRecip.ref], [neg(xRecip.ref.mul(xRecip)).mul(dx)]];
|
|
3849
|
+
},
|
|
3850
|
+
[Primitive.Floor]: zeroTangentsJvp(Primitive.Floor),
|
|
3851
|
+
[Primitive.Ceil]: zeroTangentsJvp(Primitive.Ceil),
|
|
3852
|
+
[Primitive.StopGradient]: zeroTangentsJvp(Primitive.StopGradient),
|
|
3853
|
+
[Primitive.Cast]([x], [dx], { dtype }) {
|
|
3854
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3855
|
+
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
3856
|
+
else {
|
|
3857
|
+
dx.dispose();
|
|
3858
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3859
|
+
}
|
|
3860
|
+
},
|
|
3861
|
+
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
3862
|
+
if (x.dtype === dtype) return [[x], [dx]];
|
|
3863
|
+
dx.dispose();
|
|
3864
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
3865
|
+
},
|
|
3866
|
+
[Primitive.Sin]([x], [dx]) {
|
|
3867
|
+
return [[sin$1(x.ref)], [cos$1(x).mul(dx)]];
|
|
3868
|
+
},
|
|
3869
|
+
[Primitive.Cos]([x], [dx]) {
|
|
3870
|
+
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
3871
|
+
},
|
|
3872
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3873
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3874
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3875
|
+
},
|
|
3876
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3877
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3878
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3879
|
+
},
|
|
3880
|
+
[Primitive.Exp]([x], [dx]) {
|
|
3881
|
+
const z = exp$1(x);
|
|
3882
|
+
return [[z.ref], [z.mul(dx)]];
|
|
3883
|
+
},
|
|
3884
|
+
[Primitive.Log]([x], [dx]) {
|
|
3885
|
+
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
3886
|
+
},
|
|
3887
|
+
[Primitive.Erf]([x], [dx]) {
|
|
3888
|
+
const coeff = 2 / Math.sqrt(Math.PI);
|
|
3889
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3890
|
+
return [[erf$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3891
|
+
},
|
|
3892
|
+
[Primitive.Erfc]([x], [dx]) {
|
|
3893
|
+
const coeff = -2 / Math.sqrt(Math.PI);
|
|
3894
|
+
const expTerm = exp$1(neg(x.ref.mul(x.ref)));
|
|
3895
|
+
return [[erfc$1(x)], [expTerm.mul(coeff).mul(dx)]];
|
|
3896
|
+
},
|
|
3897
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
3898
|
+
const z = sqrt$1(x);
|
|
3899
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3900
|
+
},
|
|
3901
|
+
[Primitive.Reduce]([x], [dx], { op, axis }) {
|
|
3902
|
+
if (op === AluOp.Add) return [[reduce(x, op, axis)], [reduce(dx, op, axis)]];
|
|
3903
|
+
else if (op === AluOp.Mul) {
|
|
3904
|
+
const primal = reduce(x.ref, op, axis);
|
|
3905
|
+
const tangent = broadcast(primal.ref, x.shape, axis).mul(reciprocal$1(x)).mul(dx).sum(axis);
|
|
3906
|
+
return [[primal], [tangent]];
|
|
3907
|
+
} else if (op === AluOp.Min || op === AluOp.Max) {
|
|
3908
|
+
const primal = reduce(x.ref, op, axis);
|
|
3909
|
+
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
3910
|
+
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
3911
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
3912
|
+
return [[primal], [tangent]];
|
|
3913
|
+
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
3914
|
+
},
|
|
3915
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3916
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3917
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3918
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
3919
|
+
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
3920
|
+
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
3921
|
+
dcond.dispose();
|
|
3922
|
+
return [[where$1(cond.ref, x, y)], [where$1(cond, dx, dy)]];
|
|
3923
|
+
},
|
|
3924
|
+
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
3925
|
+
[Primitive.Gather]([x, ...indices], [dx, ..._], { axis, outDim }) {
|
|
3926
|
+
const indicesRef = indices.map((t) => t.ref);
|
|
3927
|
+
return [[gather(x, indices, axis, outDim)], [gather(dx, indicesRef, axis, outDim)]];
|
|
3928
|
+
},
|
|
3929
|
+
[Primitive.Transpose]: linearTangentsJvp(Primitive.Transpose),
|
|
3930
|
+
[Primitive.Broadcast]: linearTangentsJvp(Primitive.Broadcast),
|
|
3931
|
+
[Primitive.Reshape]: linearTangentsJvp(Primitive.Reshape),
|
|
3932
|
+
[Primitive.Flip]: linearTangentsJvp(Primitive.Flip),
|
|
3933
|
+
[Primitive.Shrink]: linearTangentsJvp(Primitive.Shrink),
|
|
3934
|
+
[Primitive.Pad]: linearTangentsJvp(Primitive.Pad),
|
|
3935
|
+
[Primitive.Sort]([x], [dx]) {
|
|
3936
|
+
const [y, idx] = argsort$1(x);
|
|
3937
|
+
return [[y], [gather(dx, [idx], [-1], -1)]];
|
|
3938
|
+
},
|
|
3939
|
+
[Primitive.Argsort]([x], [dx]) {
|
|
3940
|
+
const [y, idx] = argsort$1(x);
|
|
3941
|
+
return [[y, idx.ref], [gather(dx, [idx.ref], [-1], -1), zerosLike$1(idx)]];
|
|
3942
|
+
},
|
|
3943
|
+
[Primitive.TriangularSolve]([a, b], [da, db], { unitDiagonal }) {
|
|
3944
|
+
const x = triangularSolve$1(a.ref, b, { unitDiagonal });
|
|
3945
|
+
const dax = batchMatmulT(da, x.ref);
|
|
3946
|
+
const rhsT = db.sub(mT(dax));
|
|
3947
|
+
const dx = triangularSolve$1(a, rhsT, { unitDiagonal });
|
|
3948
|
+
return [[x], [dx]];
|
|
3949
|
+
},
|
|
3950
|
+
[Primitive.Cholesky]([a], [da]) {
|
|
3951
|
+
const L = cholesky$2(a.ref);
|
|
3952
|
+
da = da.ref.add(mT(da)).mul(.5);
|
|
3953
|
+
const W = triangularSolve$1(L.ref, da, { lower: true });
|
|
3954
|
+
const ST = triangularSolve$1(L.ref, mT(W), { lower: true });
|
|
3955
|
+
const dL = batchMatmulT(L.ref, triu(ST.ref, 1).add(triu(ST)).mul(.5));
|
|
3956
|
+
return [[L], [dL]];
|
|
3957
|
+
},
|
|
3958
|
+
[Primitive.Jit](primals, tangents, { name, jaxpr }) {
|
|
3959
|
+
const newJaxpr = jvpJaxpr(jaxpr);
|
|
3960
|
+
const outs = bind(Primitive.Jit, [
|
|
3961
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
3962
|
+
...primals,
|
|
3963
|
+
...tangents
|
|
3964
|
+
], {
|
|
3965
|
+
name: `${name}_jvp`,
|
|
3966
|
+
jaxpr: newJaxpr.jaxpr,
|
|
3967
|
+
numConsts: newJaxpr.consts.length
|
|
3968
|
+
});
|
|
3969
|
+
const n = outs.length / 2;
|
|
3970
|
+
if (!Number.isInteger(n)) throw new Error("internal: JVP Jaxpr output length is not even");
|
|
3971
|
+
const [primalsOut, tangentsOut] = [outs.slice(0, n), outs.slice(n)];
|
|
3972
|
+
return [primalsOut, tangentsOut];
|
|
3973
|
+
}
|
|
3974
|
+
};
|
|
3975
|
+
const jvpJaxprCache = /* @__PURE__ */ new Map();
|
|
3976
|
+
function jvpJaxpr(jaxpr) {
|
|
3977
|
+
if (jvpJaxprCache.has(jaxpr)) return jvpJaxprCache.get(jaxpr);
|
|
3978
|
+
const inAvals = jaxpr.inBinders.map((v) => v.aval);
|
|
3979
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((primals, tangents) => jvpFlat(jaxprAsFun(jaxpr), primals, tangents))(inAvals, inAvals);
|
|
3980
|
+
jvpJaxprCache.set(jaxpr, newJaxpr);
|
|
3981
|
+
return newJaxpr;
|
|
3982
|
+
}
|
|
3983
|
+
function jvpFlat(f, primals, tangents) {
|
|
3984
|
+
try {
|
|
3985
|
+
var _usingCtx$1 = _usingCtx();
|
|
3986
|
+
const main = _usingCtx$1.u(newMain(JVPTrace));
|
|
3987
|
+
const trace$1 = new JVPTrace(main);
|
|
3988
|
+
const tracersIn = zip(primals, tangents).map(([x, t]) => new JVPTracer(trace$1, pureArray(x), pureArray(t)));
|
|
3989
|
+
const outs = f(...tracersIn);
|
|
3990
|
+
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3991
|
+
return unzip2(tracersOut.map((t) => [t.primal, t.tangent]));
|
|
3992
|
+
} catch (_) {
|
|
3993
|
+
_usingCtx$1.e = _;
|
|
3994
|
+
} finally {
|
|
3995
|
+
_usingCtx$1.d();
|
|
3996
|
+
}
|
|
3997
|
+
}
|
|
3998
|
+
function jvp$1(f, primals, tangents) {
|
|
3999
|
+
const [primalsFlat, inTree] = flatten(primals);
|
|
4000
|
+
const [tangentsFlat, inTree2] = flatten(tangents);
|
|
4001
|
+
if (!inTree.equals(inTree2)) throw new TreeMismatchError("jvp", inTree, inTree2);
|
|
4002
|
+
const [flatFun, outTree] = flattenFun(f, inTree);
|
|
4003
|
+
const [primalsOutFlat, tangentsOutFlat] = jvpFlat(flatFun, primalsFlat, tangentsFlat);
|
|
4004
|
+
if (outTree.value === void 0) throw new Error("outTree was not set in jvp");
|
|
4005
|
+
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
4006
|
+
const tangentsOut = unflatten(outTree.value, tangentsOutFlat);
|
|
4007
|
+
return [primalsOut, tangentsOut];
|
|
4008
|
+
}
|
|
4009
|
+
|
|
3611
4010
|
//#endregion
|
|
3612
4011
|
//#region src/frontend/linearize.ts
|
|
3613
4012
|
/** Array value that can either be known or unknown. */
|
|
@@ -3638,11 +4037,10 @@ function partialEvalFlat(f, pvalsIn) {
|
|
|
3638
4037
|
const tracersOut = outs.map((out) => fullRaise(trace$1, out));
|
|
3639
4038
|
const pvalsOut = tracersOut.map((t) => t.pval);
|
|
3640
4039
|
const unknownTracersOut = tracersOut.filter((t) => !t.pval.isKnown);
|
|
3641
|
-
const
|
|
4040
|
+
const jaxpr = partialEvalGraphToJaxpr(unknownTracersIn, unknownTracersOut);
|
|
3642
4041
|
return {
|
|
3643
4042
|
jaxpr,
|
|
3644
|
-
pvalsOut
|
|
3645
|
-
consts
|
|
4043
|
+
pvalsOut
|
|
3646
4044
|
};
|
|
3647
4045
|
}
|
|
3648
4046
|
/**
|
|
@@ -3659,22 +4057,19 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
3659
4057
|
const [primalsOut$1, tangentsOut] = jvp$1(f, x.slice(0, k), x.slice(k, 2 * k));
|
|
3660
4058
|
return [...primalsOut$1, ...tangentsOut];
|
|
3661
4059
|
};
|
|
3662
|
-
const { jaxpr, pvalsOut
|
|
4060
|
+
const { jaxpr, pvalsOut } = partialEvalFlat(fJvp, pvalsIn);
|
|
3663
4061
|
const primalPvals = pvalsOut.slice(0, pvalsOut.length / 2);
|
|
3664
4062
|
if (!primalPvals.every((pval) => pval.isKnown)) throw new Error("Not all primal values are known after partial evaluation");
|
|
3665
4063
|
const primalsOut = primalPvals.map((pval) => pval.val);
|
|
3666
4064
|
return {
|
|
3667
4065
|
primalsOut,
|
|
3668
|
-
jaxpr
|
|
3669
|
-
consts
|
|
4066
|
+
jaxpr
|
|
3670
4067
|
};
|
|
3671
4068
|
}
|
|
3672
4069
|
function linearizeFlat(f, primalsIn) {
|
|
3673
|
-
const { primalsOut, jaxpr
|
|
3674
|
-
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3675
|
-
const dispose$1 = () =>
|
|
3676
|
-
for (const c of consts) c.dispose();
|
|
3677
|
-
};
|
|
4070
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4071
|
+
const fLin = (...tangents) => evalJaxpr(jaxpr.jaxpr, [...jaxpr.consts.map((c) => c.ref), ...tangents]);
|
|
4072
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
3678
4073
|
return [
|
|
3679
4074
|
primalsOut,
|
|
3680
4075
|
fLin,
|
|
@@ -3758,7 +4153,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3758
4153
|
}
|
|
3759
4154
|
processPrimitive(primitive, tracers, params) {
|
|
3760
4155
|
if (tracers.every((t) => t.pval.isKnown)) return bind(primitive, tracers.map((t) => t.fullLower()), params);
|
|
3761
|
-
if (primitive === Primitive.
|
|
4156
|
+
if (primitive === Primitive.Jit) {
|
|
3762
4157
|
const { name, jaxpr, numConsts } = params;
|
|
3763
4158
|
return this.#partialEvalJaxpr(name, jaxpr, numConsts, tracers);
|
|
3764
4159
|
}
|
|
@@ -3784,14 +4179,14 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3784
4179
|
* Evaluate a Jaxpr on a set of PartialEvalTracers, computing as many known
|
|
3785
4180
|
* values as possible (with JIT) and forwarding the unknown ones.
|
|
3786
4181
|
*
|
|
3787
|
-
* Used when encountering a
|
|
4182
|
+
* Used when encountering a Jit rule during the trace.
|
|
3788
4183
|
*/
|
|
3789
4184
|
#partialEvalJaxpr(name, jaxpr, numConsts, tracers) {
|
|
3790
4185
|
jaxpr = jaxpr.flatten();
|
|
3791
4186
|
const inUnknowns = tracers.map((t) => !t.pval.isKnown);
|
|
3792
4187
|
const { jaxpr1, jaxpr2, outUnknowns, numRes } = partialEvalJaxpr(jaxpr, inUnknowns);
|
|
3793
4188
|
const [knownTracers, unknownTracers] = partitionList(inUnknowns, tracers);
|
|
3794
|
-
const outs1Res = bind(Primitive.
|
|
4189
|
+
const outs1Res = bind(Primitive.Jit, knownTracers.map((t) => t.ref.fullLower()), {
|
|
3795
4190
|
name: `${name}_peval`,
|
|
3796
4191
|
jaxpr: jaxpr1,
|
|
3797
4192
|
numConsts: 0
|
|
@@ -3801,7 +4196,7 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3801
4196
|
const resTracers = res.map((x) => this.instantiateConst(fullRaise(this, x)));
|
|
3802
4197
|
const recipe = {
|
|
3803
4198
|
type: "JaxprEqn",
|
|
3804
|
-
prim: Primitive.
|
|
4199
|
+
prim: Primitive.Jit,
|
|
3805
4200
|
tracersIn: resTracers.concat(unknownTracers),
|
|
3806
4201
|
params: {
|
|
3807
4202
|
name: `${name}_resid`,
|
|
@@ -3830,7 +4225,7 @@ function partialEvalJaxpr(jaxpr, inUnknowns, instantiate) {
|
|
|
3830
4225
|
const eqns1 = [];
|
|
3831
4226
|
const eqns2 = [];
|
|
3832
4227
|
for (const eqn of jaxpr.eqns) {
|
|
3833
|
-
if (eqn.primitive === Primitive.
|
|
4228
|
+
if (eqn.primitive === Primitive.Jit) throw new TypeError("partialEvalJaxpr requires flattened Jaxpr");
|
|
3834
4229
|
const hasUnknowns = eqn.inputs.some((x) => x instanceof Var && !knownVars.has(x));
|
|
3835
4230
|
if (hasUnknowns) {
|
|
3836
4231
|
for (const x of eqn.inputs) if (x instanceof Var && knownVars.has(x)) residuals.add(x);
|
|
@@ -3904,11 +4299,8 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3904
4299
|
for (const t of tracersIn) t.dispose();
|
|
3905
4300
|
for (const t of tracersOut) t.dispose();
|
|
3906
4301
|
jaxpr = jaxpr.simplify();
|
|
3907
|
-
if (DEBUG >= 5) console.
|
|
3908
|
-
return
|
|
3909
|
-
jaxpr,
|
|
3910
|
-
consts
|
|
3911
|
-
};
|
|
4302
|
+
if (DEBUG >= 5) console.info("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
4303
|
+
return new ClosedJaxpr(jaxpr, consts);
|
|
3912
4304
|
}
|
|
3913
4305
|
/** Marker type for pullback, used by transpose rules. */
|
|
3914
4306
|
var UndefPrimal = class {
|
|
@@ -4038,22 +4430,25 @@ const transposeRules = {
|
|
|
4038
4430
|
},
|
|
4039
4431
|
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
4040
4432
|
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
4433
|
+
const v = params.vmapDims;
|
|
4041
4434
|
const rev01 = [
|
|
4042
|
-
|
|
4043
|
-
|
|
4044
|
-
|
|
4435
|
+
...range(v),
|
|
4436
|
+
v + 1,
|
|
4437
|
+
v,
|
|
4438
|
+
...range(v + 2, ct.ndim)
|
|
4045
4439
|
];
|
|
4046
4440
|
if (lhs instanceof UndefPrimal) {
|
|
4047
4441
|
let kernel = rhs;
|
|
4048
4442
|
kernel = transpose$1(kernel, rev01);
|
|
4049
|
-
kernel = flip$1(kernel, range(2, kernel.ndim));
|
|
4443
|
+
kernel = flip$1(kernel, range(v + 2, kernel.ndim));
|
|
4050
4444
|
const result = conv$1(ct, kernel, {
|
|
4445
|
+
vmapDims: v,
|
|
4051
4446
|
strides: params.lhsDilation,
|
|
4052
4447
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4053
|
-
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4054
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4448
|
+
const dilatedKernel = (kernel.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4449
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4055
4450
|
const padBefore = dilatedKernel - 1 - pl;
|
|
4056
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4451
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4057
4452
|
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
4058
4453
|
return [padBefore, padAfter];
|
|
4059
4454
|
}),
|
|
@@ -4065,11 +4460,12 @@ const transposeRules = {
|
|
|
4065
4460
|
const newLhs = transpose$1(lhs, rev01);
|
|
4066
4461
|
const newRhs = transpose$1(ct, rev01);
|
|
4067
4462
|
let result = conv$1(newLhs, newRhs, {
|
|
4463
|
+
vmapDims: v,
|
|
4068
4464
|
strides: params.rhsDilation,
|
|
4069
4465
|
padding: params.padding.map(([pl, _pr], i) => {
|
|
4070
|
-
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4071
|
-
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4072
|
-
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
4466
|
+
const dilatedLhs = (lhs.aval.shape[i + v + 2] - 1) * params.lhsDilation[i] + 1;
|
|
4467
|
+
const dilatedKernel = (rhs.aval.shape[i + v + 2] - 1) * params.rhsDilation[i] + 1;
|
|
4468
|
+
const dilatedCt = (ct.shape[i + v + 2] - 1) * params.strides[i] + 1;
|
|
4073
4469
|
const padFromLhs = dilatedCt - dilatedLhs;
|
|
4074
4470
|
const padFromRhs = dilatedKernel - pl - 1;
|
|
4075
4471
|
return [pl, padFromLhs + padFromRhs];
|
|
@@ -4096,6 +4492,11 @@ const transposeRules = {
|
|
|
4096
4492
|
cond.dispose();
|
|
4097
4493
|
return cts;
|
|
4098
4494
|
},
|
|
4495
|
+
[Primitive.Gather]([ct], [x, ...indices], { axis, outDim }) {
|
|
4496
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4497
|
+
if (indices.some((i) => i instanceof UndefPrimal)) throw new NonlinearError(Primitive.Gather);
|
|
4498
|
+
throw new Error("Gather transpose rule is not yet implemented, requires complex Scatter sum operation");
|
|
4499
|
+
},
|
|
4099
4500
|
[Primitive.Transpose]([ct], [x], { perm }) {
|
|
4100
4501
|
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Transpose);
|
|
4101
4502
|
return [transpose$1(ct, invertPermutation(perm))];
|
|
@@ -4122,23 +4523,26 @@ const transposeRules = {
|
|
|
4122
4523
|
const slice = width.map(([s, _e], i) => [s, s + x.aval.shape[i]]);
|
|
4123
4524
|
return [shrink(ct, slice)];
|
|
4124
4525
|
},
|
|
4125
|
-
[Primitive.
|
|
4126
|
-
if (!(
|
|
4127
|
-
|
|
4128
|
-
|
|
4526
|
+
[Primitive.TriangularSolve]([ct], [a, b], { unitDiagonal }) {
|
|
4527
|
+
if (a instanceof UndefPrimal || !(b instanceof UndefPrimal)) throw new NonlinearError(Primitive.TriangularSolve);
|
|
4528
|
+
const ctB = triangularSolve$1(moveaxis(a, -2, -1), ct, {
|
|
4529
|
+
lower: true,
|
|
4530
|
+
unitDiagonal
|
|
4531
|
+
});
|
|
4532
|
+
return [null, ctB];
|
|
4129
4533
|
},
|
|
4130
|
-
[Primitive.
|
|
4534
|
+
[Primitive.Jit](cts, args, { name, jaxpr }) {
|
|
4131
4535
|
const undefPrimals = args.map((x) => x instanceof UndefPrimal);
|
|
4132
|
-
const
|
|
4536
|
+
const newJaxpr = transposeJaxpr(jaxpr, undefPrimals);
|
|
4133
4537
|
const residuals = args.filter((x, i$1) => !undefPrimals[i$1]);
|
|
4134
|
-
const outs = bind(Primitive.
|
|
4135
|
-
...
|
|
4538
|
+
const outs = bind(Primitive.Jit, [
|
|
4539
|
+
...newJaxpr.consts.map((c) => c.ref),
|
|
4136
4540
|
...residuals,
|
|
4137
4541
|
...cts
|
|
4138
4542
|
], {
|
|
4139
4543
|
name: `${name}_t`,
|
|
4140
|
-
jaxpr: newJaxpr,
|
|
4141
|
-
numConsts:
|
|
4544
|
+
jaxpr: newJaxpr.jaxpr,
|
|
4545
|
+
numConsts: newJaxpr.consts.length
|
|
4142
4546
|
});
|
|
4143
4547
|
let i = 0;
|
|
4144
4548
|
return undefPrimals.map((isUndef) => isUndef ? outs[i++] : null);
|
|
@@ -4151,31 +4555,25 @@ function transposeJaxpr(jaxpr, undefPrimals) {
|
|
|
4151
4555
|
if (prevResult) return prevResult;
|
|
4152
4556
|
const { inTypes, outTypes } = typecheckJaxpr(jaxpr);
|
|
4153
4557
|
const forwardInTypes = inTypes.filter((_, i) => !undefPrimals[i]);
|
|
4154
|
-
const { jaxpr: newJaxpr
|
|
4558
|
+
const { jaxpr: newJaxpr } = makeJaxpr$1((forwardIn, cotangents) => {
|
|
4155
4559
|
const args = [];
|
|
4156
4560
|
let forwardInIdx = 0;
|
|
4157
4561
|
for (let i = 0; i < undefPrimals.length; i++) if (undefPrimals[i]) args.push(new UndefPrimal(inTypes[i]));
|
|
4158
4562
|
else args.push(forwardIn[forwardInIdx++]);
|
|
4159
4563
|
return evalJaxprTransposed(jaxpr, args, cotangents);
|
|
4160
4564
|
})(forwardInTypes, outTypes);
|
|
4161
|
-
typecheckJaxpr(newJaxpr);
|
|
4162
|
-
const result = {
|
|
4163
|
-
newJaxpr,
|
|
4164
|
-
newConsts
|
|
4165
|
-
};
|
|
4565
|
+
typecheckJaxpr(newJaxpr.jaxpr);
|
|
4166
4566
|
if (!transposeJaxprCache.has(jaxpr)) transposeJaxprCache.set(jaxpr, /* @__PURE__ */ new Map());
|
|
4167
|
-
transposeJaxprCache.get(jaxpr).set(cacheKey,
|
|
4168
|
-
return
|
|
4567
|
+
transposeJaxprCache.get(jaxpr).set(cacheKey, newJaxpr);
|
|
4568
|
+
return newJaxpr;
|
|
4169
4569
|
}
|
|
4170
4570
|
function vjpFlat(f, primalsIn) {
|
|
4171
|
-
const { primalsOut, jaxpr
|
|
4571
|
+
const { primalsOut, jaxpr } = linearizeFlatUtil(f, primalsIn);
|
|
4172
4572
|
const fVjp = (...cotangents) => {
|
|
4173
|
-
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4174
|
-
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
4175
|
-
};
|
|
4176
|
-
const dispose$1 = () => {
|
|
4177
|
-
for (const c of consts) c.dispose();
|
|
4573
|
+
const transposeInputs = [...jaxpr.consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
4574
|
+
return evalJaxprTransposed(jaxpr.jaxpr, transposeInputs, cotangents);
|
|
4178
4575
|
};
|
|
4576
|
+
const dispose$1 = () => jaxpr.dispose();
|
|
4179
4577
|
return [
|
|
4180
4578
|
primalsOut,
|
|
4181
4579
|
fVjp,
|
|
@@ -4232,150 +4630,6 @@ function jacrev$1(f) {
|
|
|
4232
4630
|
};
|
|
4233
4631
|
}
|
|
4234
4632
|
|
|
4235
|
-
//#endregion
|
|
4236
|
-
//#region src/library/lax.ts
|
|
4237
|
-
var lax_exports = {};
|
|
4238
|
-
__export(lax_exports, {
|
|
4239
|
-
conv: () => conv,
|
|
4240
|
-
convGeneralDilated: () => convGeneralDilated,
|
|
4241
|
-
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
4242
|
-
dot: () => dot$1,
|
|
4243
|
-
erf: () => erf,
|
|
4244
|
-
erfc: () => erfc,
|
|
4245
|
-
reduceWindow: () => reduceWindow,
|
|
4246
|
-
stopGradient: () => stopGradient$1
|
|
4247
|
-
});
|
|
4248
|
-
/**
|
|
4249
|
-
* General dot product/contraction operator.
|
|
4250
|
-
*
|
|
4251
|
-
* Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
|
|
4252
|
-
* `jax.numpy.tensordot(), and `jax.numpy.einsum()` where possible.
|
|
4253
|
-
*/
|
|
4254
|
-
function dot$1(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
|
|
4255
|
-
if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
|
|
4256
|
-
else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
|
|
4257
|
-
lc = lc.map((a) => checkAxis(a, lhs.ndim));
|
|
4258
|
-
rc = rc.map((a) => checkAxis(a, rhs.ndim));
|
|
4259
|
-
lb = lb.map((a) => checkAxis(a, lhs.ndim));
|
|
4260
|
-
rb = rb.map((a) => checkAxis(a, rhs.ndim));
|
|
4261
|
-
if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
|
|
4262
|
-
else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
|
|
4263
|
-
const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
|
|
4264
|
-
const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
|
|
4265
|
-
const lhs2 = lhs.transpose([
|
|
4266
|
-
...lb,
|
|
4267
|
-
...lf,
|
|
4268
|
-
...lc
|
|
4269
|
-
]);
|
|
4270
|
-
const rhs2 = rhs.transpose([
|
|
4271
|
-
...rb,
|
|
4272
|
-
...rf,
|
|
4273
|
-
...rc
|
|
4274
|
-
]);
|
|
4275
|
-
if (lc.length === 0) return mul(lhs2.reshape([
|
|
4276
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4277
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4278
|
-
...rep(rf.length, 1)
|
|
4279
|
-
]), rhs2.reshape([
|
|
4280
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4281
|
-
...rep(lf.length, 1),
|
|
4282
|
-
...rf.map((a) => rhs.shape[a])
|
|
4283
|
-
]));
|
|
4284
|
-
const dotShapeX = lc.map((a) => lhs.shape[a]);
|
|
4285
|
-
const dotShapeY = rc.map((a) => rhs.shape[a]);
|
|
4286
|
-
if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
|
|
4287
|
-
return dot$2(lhs2.reshape([
|
|
4288
|
-
...lb.map((a) => lhs.shape[a]),
|
|
4289
|
-
...lf.map((a) => lhs.shape[a]),
|
|
4290
|
-
...rep(rf.length, 1),
|
|
4291
|
-
prod(dotShapeX)
|
|
4292
|
-
]), rhs2.reshape([
|
|
4293
|
-
...rb.map((a) => rhs.shape[a]),
|
|
4294
|
-
...rep(lf.length, 1),
|
|
4295
|
-
...rf.map((a) => rhs.shape[a]),
|
|
4296
|
-
prod(dotShapeY)
|
|
4297
|
-
]));
|
|
4298
|
-
}
|
|
4299
|
-
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
4300
|
-
const padType = padding.toUpperCase();
|
|
4301
|
-
switch (padType) {
|
|
4302
|
-
case "VALID": return rep(inShape.length, [0, 0]);
|
|
4303
|
-
case "SAME":
|
|
4304
|
-
case "SAME_LOWER": {
|
|
4305
|
-
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
4306
|
-
const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
4307
|
-
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
4308
|
-
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
4309
|
-
}
|
|
4310
|
-
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
4311
|
-
}
|
|
4312
|
-
}
|
|
4313
|
-
/**
|
|
4314
|
-
* General n-dimensional convolution operator, with optional dilation.
|
|
4315
|
-
*
|
|
4316
|
-
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
4317
|
-
* function in JAX, which wraps XLA's general convolution operator.
|
|
4318
|
-
*
|
|
4319
|
-
* Grouped convolutions are not supported right now.
|
|
4320
|
-
*/
|
|
4321
|
-
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
4322
|
-
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4323
|
-
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4324
|
-
if (typeof padding === "string") {
|
|
4325
|
-
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4326
|
-
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
4327
|
-
}
|
|
4328
|
-
return conv$1(lhs, rhs, {
|
|
4329
|
-
strides: windowStrides,
|
|
4330
|
-
padding,
|
|
4331
|
-
lhsDilation,
|
|
4332
|
-
rhsDilation
|
|
4333
|
-
});
|
|
4334
|
-
}
|
|
4335
|
-
/** Convenience wrapper around `convGeneralDilated`. */
|
|
4336
|
-
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
4337
|
-
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
4338
|
-
lhsDilation,
|
|
4339
|
-
rhsDilation
|
|
4340
|
-
});
|
|
4341
|
-
}
|
|
4342
|
-
/** Convenience wrapper around `convGeneralDilated`. */
|
|
4343
|
-
function conv(lhs, rhs, windowStrides, padding) {
|
|
4344
|
-
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
4345
|
-
}
|
|
4346
|
-
/** Reduce a computation over padded windows. */
|
|
4347
|
-
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
4348
|
-
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
4349
|
-
if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
|
|
4350
|
-
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
4351
|
-
return computation(bind1(Primitive.Pool, [operand], {
|
|
4352
|
-
window: windowDimensions,
|
|
4353
|
-
strides: windowStrides
|
|
4354
|
-
}));
|
|
4355
|
-
}
|
|
4356
|
-
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
|
|
4357
|
-
function erf(x) {
|
|
4358
|
-
return erf$1(x);
|
|
4359
|
-
}
|
|
4360
|
-
/**
|
|
4361
|
-
* The complementary error function: `erfc(x) = 1 - erf(x)`.
|
|
4362
|
-
*
|
|
4363
|
-
* This function is more accurate than `1 - erf(x)` for large values of `x`,
|
|
4364
|
-
* where `erf(x)` is very close to 1.
|
|
4365
|
-
*/
|
|
4366
|
-
function erfc(x) {
|
|
4367
|
-
return erfc$1(x);
|
|
4368
|
-
}
|
|
4369
|
-
/**
|
|
4370
|
-
* Stops gradient computation.
|
|
4371
|
-
*
|
|
4372
|
-
* Behaves as the identity function but prevents the flow of gradients during
|
|
4373
|
-
* forward or reverse-mode automatic differentiation.
|
|
4374
|
-
*/
|
|
4375
|
-
function stopGradient$1(x) {
|
|
4376
|
-
return stopGradient(x);
|
|
4377
|
-
}
|
|
4378
|
-
|
|
4379
4633
|
//#endregion
|
|
4380
4634
|
//#region src/library/numpy/einsum.ts
|
|
4381
4635
|
const bprod = (...xs) => xs.reduce((acc, x) => acc * BigInt(x), 1n);
|
|
@@ -4571,34 +4825,207 @@ function* allPaths(tensors, next) {
|
|
|
4571
4825
|
}
|
|
4572
4826
|
}
|
|
4573
4827
|
|
|
4828
|
+
//#endregion
|
|
4829
|
+
//#region src/library/numpy-fft.ts
|
|
4830
|
+
var numpy_fft_exports = {};
|
|
4831
|
+
__export(numpy_fft_exports, {
|
|
4832
|
+
fft: () => fft,
|
|
4833
|
+
ifft: () => ifft
|
|
4834
|
+
});
|
|
4835
|
+
/**
 * Validate a `{ real, imag }` pair passed to a `jax.numpy.fft` routine.
 *
 * Both parts must share shape and dtype, and that dtype must be a float type.
 * Throws an `Error` naming the public function on the first violation found.
 */
function checkPairInput(name, a) {
  const fullName = `jax.numpy.fft.${name}`;
  const { real, imag } = a;
  const sameShape = deepEqual(real.shape, imag.shape);
  if (!sameShape) {
    throw new Error(`${fullName}: real and imaginary parts must have the same shape, got ${JSON.stringify(real.shape)} and ${JSON.stringify(imag.shape)}`);
  }
  if (real.dtype !== imag.dtype) {
    throw new Error(`${fullName}: real and imaginary parts must have the same dtype, got ${real.dtype} and ${imag.dtype}`);
  }
  if (!isFloatDtype(real.dtype)) {
    throw new Error(`${fullName}: input must have a float dtype, got ${real.dtype}`);
  }
}
|
|
4841
|
+
/**
 * Throw unless `n` is a positive integer power of two (1, 2, 4, ...).
 *
 * The FFT routines compute `Math.log2(n)` butterfly stages, so `n` must be an
 * exact power of two. The previous bitwise test `(n & n - 1) === 0` also
 * accepted 0 (giving `log2(0) = -Infinity` downstream), non-integers (the
 * bitwise ops truncate), and multiples of 2**32 (ToInt32 wraps them to 0).
 */
function checkPowerOfTwo(name, n) {
  const isPowerOfTwo = Number.isInteger(n) && n > 0 && Number.isInteger(Math.log2(n));
  if (!isPowerOfTwo) throw new Error(`jax.numpy.fft.${name}: size must be a power of two, got ${n}`);
}
|
|
4844
|
+
/**
 * One radix-2 butterfly stage of the iterative FFT (jitted; `i` is static).
 *
 * The flattened input is viewed as rows of `2 * 2**i` values. Each row's first
 * half `u` and second half `v` are combined with twiddles `w_k = exp(-i*pi*k/half)`
 * into `[u + v*w, u - v*w]`. Every array is consumed once; `.ref` is taken only
 * where a value is used again.
 */
const fftUpdate = jit$1(function fftUpdate$1(i, { real, imag }) {
  const half = 2 ** i;
  const width = 2 * half;
  const re = real.reshape([-1, width]);
  const im = imag.reshape([-1, width]);
  const angle = arange(0, half, 1, { dtype: re.dtype }).mul(-Math.PI / half);
  const wr = cos(angle.ref);
  const wi = sin(angle);
  const evenRe = re.ref.slice([], [0, half]);
  const evenIm = im.ref.slice([], [0, half]);
  const oddRe = re.slice([], [half, width]);
  const oddIm = im.slice([], [half, width]);
  // Complex product t = odd * w.
  const tRe = oddRe.ref.mul(wr.ref).sub(oddIm.ref.mul(wi.ref));
  const tIm = oddRe.mul(wi).add(oddIm.mul(wr));
  return {
    real: concatenate([evenRe.ref.add(tRe.ref), evenRe.sub(tRe)], -1),
    imag: concatenate([evenIm.ref.add(tIm.ref), evenIm.sub(tIm)], -1)
  };
}, { staticArgnums: [0] });
|
|
4863
|
+
/**
 * Compute a one-dimensional discrete Fourier transform.
 *
 * Currently, the size of the axis must be a power of two.
 *
 * The input is a `{ real, imag }` pair. The transformed axis is moved last and
 * bit-reverse permuted, then `log2(n)` butterfly stages (`fftUpdate`) are applied.
 * Finally the original layout is restored.
 */
function fft(a, axis = -1) {
  checkPairInput("fft", a);
  let { real: re, imag: im } = a;
  const nd = re.ndim;
  axis = checkAxis(axis, nd);
  const n = re.shape[axis];
  checkPowerOfTwo("fft", n);
  const stages = Math.log2(n);

  // Move the transformed axis to the end so the butterflies act on the last dim.
  const moved = axis !== nd - 1;
  const perm = moved ? [...range(nd).filter((d) => d !== axis), axis] : null;
  if (moved) {
    re = re.transpose(perm);
    im = im.transpose(perm);
  }
  const shapeBefore = re.shape;

  // Bit-reversal: view the axis as `stages` binary digits and reverse their order.
  const bitReverse = (x) => x.reshape([-1, ...rep(stages, 2)]).transpose([0, ...range(1, stages + 1).reverse()]).flatten();
  re = bitReverse(re);
  im = bitReverse(im);

  for (let stage = 0; stage < stages; stage++) {
    ({ real: re, imag: im } = fftUpdate(stage, { real: re, imag: im }));
  }

  let real = re.reshape(shapeBefore);
  let imag = im.reshape(shapeBefore);
  if (moved) {
    const inverse = invertPermutation(perm);
    real = real.transpose(inverse);
    imag = imag.transpose(inverse);
  }
  return { real, imag };
}
|
|
4901
|
+
/**
 * Compute a one-dimensional inverse discrete Fourier transform.
 *
 * Currently, the size of the axis must be a power of two.
 *
 * Uses the identity `ifft(z) = conj(fft(conj(z))) / n`.
 */
function ifft(a, axis = -1) {
  checkPairInput("ifft", a);
  const { real } = a;
  const ax = checkAxis(axis, real.ndim);
  const n = real.shape[ax];
  checkPowerOfTwo("ifft", n);
  const { real: outRe, imag: outIm } = fft({ real, imag: a.imag.mul(-1) }, ax);
  return { real: outRe.div(n), imag: outIm.mul(-1).div(n) };
}
|
|
4922
|
+
|
|
4923
|
+
//#endregion
|
|
4924
|
+
//#region src/library/numpy-linalg.ts
|
|
4925
|
+
var numpy_linalg_exports = {};
|
|
4926
|
+
__export(numpy_linalg_exports, {
|
|
4927
|
+
cholesky: () => cholesky$1,
|
|
4928
|
+
diagonal: () => diagonal,
|
|
4929
|
+
lstsq: () => lstsq,
|
|
4930
|
+
matmul: () => matmul,
|
|
4931
|
+
matrixTranspose: () => matrixTranspose,
|
|
4932
|
+
outer: () => outer,
|
|
4933
|
+
tensordot: () => tensordot,
|
|
4934
|
+
trace: () => trace,
|
|
4935
|
+
vecdot: () => vecdot
|
|
4936
|
+
});
|
|
4937
|
+
/**
 * Compute the Cholesky decomposition of a (batched) positive-definite matrix.
 *
 * Like `jax.lax.linalg.cholesky()`, but by default the input is first
 * symmetrized as `(a + a^T) / 2` (disable with `symmetrizeInput: false`).
 * Set `upper` to request the upper-triangular factor.
 */
function cholesky$1(a, { upper = false, symmetrizeInput = true } = {}) {
  let mat = fudgeArray(a);
  const nd = mat.ndim;
  const isSquareStack = nd >= 2 && mat.shape[nd - 1] === mat.shape[nd - 2];
  if (!isSquareStack) throw new Error(`cholesky: input must be at least 2D square matrix, got ${mat.aval}`);
  if (symmetrizeInput) mat = mat.ref.add(matrixTranspose(mat)).mul(.5);
  return cholesky(mat, { upper });
}
|
|
4949
|
+
/**
 * Return the least-squares solution to a linear equation.
 *
 * For overdetermined systems, this finds the `x` that minimizes `norm(ax - b)`.
 * For underdetermined systems, this finds the minimum-norm solution for `x`.
 *
 * This currently uses Cholesky decomposition to solve the normal equations,
 * under the hood. The method is not as robust as QR or SVD.
 *
 * @param a coefficient matrix of shape `(M, N)`
 * @param b right-hand side of shape `(M,)` or `(M, K)`
 * @return least-squares solution of shape `(N,)` or `(N, K)`
 */
function lstsq(a, b) {
  a = fudgeArray(a);
  b = fudgeArray(b);
  if (a.ndim !== 2) throw new Error(`lstsq: 'a' must be a 2D array, got ${a.aval}`);
  const [m, n] = a.shape;
  if (b.shape[0] !== m) throw new Error(`lstsq: leading dimension of 'b' must match number of rows of 'a', got ${b.aval}`);
  // Operations consume their array arguments; `.ref` is taken only when a value
  // is needed again afterwards.
  const at = matrixTranspose(a.ref);
  if (m <= n) {
    // Underdetermined (or square): x = A^T (A A^T)^{-1} b.
    const aat = matmul(a, at.ref);
    const l = cholesky$1(aat, { symmetrizeInput: false });
    const lb = triangularSolve(l.ref, b, {
      leftSide: true,
      lower: true
    });
    const llb = triangularSolve(l, lb, {
      leftSide: true,
      transposeA: true
    });
    // `llb` is not used again: pass it directly (taking `.ref` here leaked it).
    return matmul(at, llb);
  }
  // Overdetermined: solve the normal equations (A^T A) x = A^T b.
  const ata = matmul(at.ref, a);
  const l = cholesky$1(ata, { symmetrizeInput: false });
  const atb = matmul(at, b);
  const lb = triangularSolve(l.ref, atb, {
    leftSide: true,
    lower: true
  });
  return triangularSolve(l, lb, {
    leftSide: true,
    transposeA: true
  });
}
|
|
4996
|
+
|
|
4574
4997
|
//#endregion
|
|
4575
4998
|
//#region src/library/numpy.ts
|
|
4576
4999
|
var numpy_exports = {};
|
|
4577
5000
|
__export(numpy_exports, {
|
|
4578
5001
|
Array: () => Array$1,
|
|
4579
5002
|
DType: () => DType,
|
|
4580
|
-
abs: () =>
|
|
5003
|
+
abs: () => absolute,
|
|
4581
5004
|
absolute: () => absolute,
|
|
4582
5005
|
acos: () => acos,
|
|
4583
|
-
acosh: () =>
|
|
5006
|
+
acosh: () => arccosh,
|
|
4584
5007
|
add: () => add,
|
|
5008
|
+
all: () => all,
|
|
4585
5009
|
allclose: () => allclose,
|
|
5010
|
+
any: () => any,
|
|
4586
5011
|
arange: () => arange,
|
|
4587
|
-
arccos: () =>
|
|
5012
|
+
arccos: () => acos,
|
|
4588
5013
|
arccosh: () => arccosh,
|
|
5014
|
+
arcsin: () => asin,
|
|
4589
5015
|
arcsinh: () => arcsinh,
|
|
4590
|
-
arctan: () =>
|
|
4591
|
-
arctan2: () =>
|
|
5016
|
+
arctan: () => atan,
|
|
5017
|
+
arctan2: () => atan2,
|
|
4592
5018
|
arctanh: () => arctanh,
|
|
4593
5019
|
argmax: () => argmax,
|
|
4594
5020
|
argmin: () => argmin,
|
|
5021
|
+
argsort: () => argsort,
|
|
4595
5022
|
array: () => array,
|
|
4596
5023
|
asin: () => asin,
|
|
4597
|
-
asinh: () =>
|
|
5024
|
+
asinh: () => arcsinh,
|
|
4598
5025
|
astype: () => astype,
|
|
4599
5026
|
atan: () => atan,
|
|
4600
5027
|
atan2: () => atan2,
|
|
4601
|
-
atanh: () =>
|
|
5028
|
+
atanh: () => arctanh,
|
|
4602
5029
|
bool: () => bool,
|
|
4603
5030
|
broadcastArrays: () => broadcastArrays,
|
|
4604
5031
|
broadcastShapes: () => broadcastShapes,
|
|
@@ -4608,14 +5035,20 @@ __export(numpy_exports, {
|
|
|
4608
5035
|
clip: () => clip,
|
|
4609
5036
|
columnStack: () => columnStack,
|
|
4610
5037
|
concatenate: () => concatenate,
|
|
5038
|
+
convolve: () => convolve,
|
|
5039
|
+
corrcoef: () => corrcoef,
|
|
5040
|
+
correlate: () => correlate,
|
|
4611
5041
|
cos: () => cos,
|
|
4612
5042
|
cosh: () => cosh,
|
|
5043
|
+
cov: () => cov,
|
|
5044
|
+
cumsum: () => cumsum,
|
|
5045
|
+
cumulativeSum: () => cumsum,
|
|
4613
5046
|
deg2rad: () => deg2rad,
|
|
4614
5047
|
degrees: () => degrees,
|
|
4615
5048
|
diag: () => diag,
|
|
4616
5049
|
diagonal: () => diagonal,
|
|
4617
|
-
divide: () =>
|
|
4618
|
-
dot: () => dot,
|
|
5050
|
+
divide: () => trueDivide,
|
|
5051
|
+
dot: () => dot$1,
|
|
4619
5052
|
dstack: () => dstack,
|
|
4620
5053
|
e: () => e,
|
|
4621
5054
|
einsum: () => einsum,
|
|
@@ -4623,8 +5056,10 @@ __export(numpy_exports, {
|
|
|
4623
5056
|
eulerGamma: () => eulerGamma,
|
|
4624
5057
|
exp: () => exp,
|
|
4625
5058
|
exp2: () => exp2,
|
|
5059
|
+
expandDims: () => expandDims,
|
|
4626
5060
|
expm1: () => expm1,
|
|
4627
5061
|
eye: () => eye,
|
|
5062
|
+
fft: () => numpy_fft_exports,
|
|
4628
5063
|
flip: () => flip,
|
|
4629
5064
|
fliplr: () => fliplr,
|
|
4630
5065
|
flipud: () => flipud,
|
|
@@ -4655,12 +5090,14 @@ __export(numpy_exports, {
|
|
|
4655
5090
|
ldexp: () => ldexp,
|
|
4656
5091
|
less: () => less,
|
|
4657
5092
|
lessEqual: () => lessEqual,
|
|
5093
|
+
linalg: () => numpy_linalg_exports,
|
|
4658
5094
|
linspace: () => linspace,
|
|
4659
5095
|
log: () => log,
|
|
4660
5096
|
log10: () => log10,
|
|
4661
5097
|
log1p: () => log1p,
|
|
4662
5098
|
log2: () => log2,
|
|
4663
5099
|
matmul: () => matmul,
|
|
5100
|
+
matrixTranspose: () => matrixTranspose,
|
|
4664
5101
|
max: () => max,
|
|
4665
5102
|
maximum: () => maximum,
|
|
4666
5103
|
mean: () => mean,
|
|
@@ -4677,10 +5114,10 @@ __export(numpy_exports, {
|
|
|
4677
5114
|
onesLike: () => onesLike,
|
|
4678
5115
|
outer: () => outer,
|
|
4679
5116
|
pad: () => pad,
|
|
4680
|
-
permuteDims: () =>
|
|
5117
|
+
permuteDims: () => transpose,
|
|
4681
5118
|
pi: () => pi,
|
|
4682
5119
|
positive: () => positive,
|
|
4683
|
-
pow: () =>
|
|
5120
|
+
pow: () => power,
|
|
4684
5121
|
power: () => power,
|
|
4685
5122
|
prod: () => prod$1,
|
|
4686
5123
|
promoteTypes: () => promoteTypes,
|
|
@@ -4697,6 +5134,7 @@ __export(numpy_exports, {
|
|
|
4697
5134
|
sin: () => sin,
|
|
4698
5135
|
sinh: () => sinh,
|
|
4699
5136
|
size: () => size,
|
|
5137
|
+
sort: () => sort,
|
|
4700
5138
|
sqrt: () => sqrt,
|
|
4701
5139
|
square: () => square,
|
|
4702
5140
|
squeeze: () => squeeze,
|
|
@@ -4861,6 +5299,26 @@ function min(a, axis = null, opts) {
|
|
|
4861
5299
|
function max(a, axis = null, opts) {
  // Reduce with the elementwise maximum operator.
  const op = AluOp.Max;
  return reduce(a, op, axis, opts);
}
|
|
5302
|
+
/**
 * Test whether all array elements along a given axis evaluate to true.
 *
 * Returns a boolean array shaped like `a` with the reduced axis removed, or a
 * scalar when `axis` is null.
 */
function all(a, axis = null, opts) {
  // After casting to bool, "every element is true" is "the minimum is true".
  const asBool = fudgeArray(a).astype(DType.Bool);
  return min(asBool, axis, opts);
}
|
|
5312
|
+
/**
 * Test whether any array element along a given axis evaluates to true.
 *
 * Returns a boolean array shaped like `a` with the reduced axis removed, or a
 * scalar when `axis` is null.
 */
function any(a, axis = null, opts) {
  // After casting to bool, "some element is true" is "the maximum is true".
  const asBool = fudgeArray(a).astype(DType.Bool);
  return max(asBool, axis, opts);
}
|
|
4864
5322
|
/** Return the peak-to-peak range along a given axis (`max - min`). */
|
|
4865
5323
|
function ptp(a, axis = null, opts) {
|
|
4866
5324
|
a = fudgeArray(a);
|
|
@@ -4918,6 +5376,23 @@ function argmax(a, axis, opts) {
|
|
|
4918
5376
|
}).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
|
|
4919
5377
|
return length.sub(max(idx, axis, opts));
|
|
4920
5378
|
}
|
|
5379
|
+
/**
 * Cumulative sum of elements along an axis.
 *
 * When `axis` is omitted, the input is flattened first. Currently this is
 * `O(n^2)` (a masked sum over a tiled copy); a two-phase parallel scan would
 * improve it.
 */
function cumsum(a, axis) {
  let arr = fudgeArray(a);
  let ax;
  if (axis === void 0) {
    arr = arr.ravel();
    ax = 0;
  } else {
    ax = checkAxis(axis, arr.ndim);
  }
  const n = arr.shape[ax];
  const last = moveaxis$1(arr, ax, -1);
  // tiled[..., i, j] = a[..., j]; zeroing j > i (tril) and summing over j gives prefix sums.
  const tiled = broadcast(last, last.shape.concat(n), [-2]);
  return moveaxis$1(tril(tiled).sum(-1), -1, ax);
}
|
|
4921
5396
|
/** Reverse the elements in an array along the given axes. */
|
|
4922
5397
|
function flip(x, axis = null) {
|
|
4923
5398
|
const nd = ndim(x);
|
|
@@ -5027,8 +5502,11 @@ function flipud(x) {
|
|
|
5027
5502
|
function fliplr(x) {
  // Left/right is axis 1 (columns).
  const columnAxis = 1;
  return flip(x, columnAxis);
}
|
|
5030
|
-
/**
|
|
5031
|
-
|
|
5505
|
+
/** Swap the final two axes of an array (a transpose of each matrix in a stack). */
function matrixTranspose(a) {
  const rank = ndim(a);
  if (rank < 2) throw new Error(`matrixTranspose: input array must be at least 2D`);
  return moveaxis$1(a, -1, -2);
}
|
|
5032
5510
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
5033
5511
|
function ravel(a) {
|
|
5034
5512
|
return fudgeArray(a).ravel();
|
|
@@ -5044,6 +5522,32 @@ function squeeze(a, axis = null) {
|
|
|
5044
5522
|
return reshape(a, newShape);
|
|
5045
5523
|
}
|
|
5046
5524
|
/**
 * Expand the shape of an array by inserting new axes of length 1.
 *
 * @param a - Input array.
 * @param axis - Position(s) in the expanded axes where the new axis (or axes)
 *   is placed. Can be a single integer or an array of integers; negative values
 *   count from the end of the expanded shape. Repeated positions are an error.
 * @returns Array with the number of dimensions increased.
 *
 * @example
 * ```ts
 * const x = np.array([1, 2]);
 * np.expandDims(x, 0); // Shape [1, 2]
 * np.expandDims(x, 1); // Shape [2, 1]
 * np.expandDims(x, [0, 2]); // Shape [1, 2, 1]
 * ```
 */
function expandDims(a, axis) {
  const as = shape(a);
  const axes = typeof axis === "number" ? [axis] : axis;
  const outNdim = as.length + axes.length;
  const normalized = normalizeAxis(axes, outNdim);
  const inserted = new Set(normalized);
  // A repeated axis would otherwise read past the end of `as` and yield an
  // `undefined` dimension.
  if (inserted.size !== normalized.length) throw new Error(`expandDims: repeated axis in ${JSON.stringify(axes)}`);
  const newShape = [];
  let srcIdx = 0;
  for (let i = 0; i < outNdim; i++) newShape.push(inserted.has(i) ? 1 : as[srcIdx++]);
  return reshape(a, newShape);
}
|
|
5550
|
+
/**
|
|
5047
5551
|
* Repeat each element of an array after themselves.
|
|
5048
5552
|
*
|
|
5049
5553
|
* If no axis is provided, use the flattened input array, and return a flat
|
|
@@ -5131,7 +5635,7 @@ function diagonal(a, offset, axis1, axis2) {
|
|
|
5131
5635
|
*/
|
|
5132
5636
|
function diag(v, k = 0) {
|
|
5133
5637
|
const a = fudgeArray(v);
|
|
5134
|
-
if (!Number.isInteger(k)) throw new
|
|
5638
|
+
if (!Number.isInteger(k)) throw new Error(`k must be an integer, got ${k}`);
|
|
5135
5639
|
if (a.ndim === 1) {
|
|
5136
5640
|
const n = a.shape[0];
|
|
5137
5641
|
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
@@ -5139,12 +5643,32 @@ function diag(v, k = 0) {
|
|
|
5139
5643
|
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
5140
5644
|
else return ret;
|
|
5141
5645
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
5142
|
-
else throw new
|
|
5646
|
+
else throw new Error("numpy.diag only supports 1D and 2D arrays");
|
|
5143
5647
|
}
|
|
5144
5648
|
/** Sum the `offset`-th diagonal of `a`, taken over axes `axis1` and `axis2`. */
function trace(a, offset = 0, axis1 = 0, axis2 = 1) {
  const d = diagonal(a, offset, axis1, axis2);
  return d.sum(-1);
}
|
|
5652
|
+
/**
 * Return a sorted copy of an array along `axis` (the last axis by default).
 *
 * Delegates to the array's `sort` method, which dispatches to a device-specific
 * implementation; the sort may be unstable.
 */
function sort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.sort(axis);
}
|
|
5661
|
+
/**
 * Return the `int32` indices that would sort an array along `axis` (the last
 * axis by default).
 *
 * The underlying sort may be unstable, so the order of indices among ties is
 * not guaranteed.
 */
function argsort(a, axis = -1) {
  const arr = fudgeArray(a);
  return arr.argsort(axis);
}
|
|
5148
5672
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
5149
5673
|
function allclose(actual, expected, options) {
|
|
5150
5674
|
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
@@ -5153,16 +5677,19 @@ function allclose(actual, expected, options) {
|
|
|
5153
5677
|
if (!deepEqual(x.shape, y.shape)) return false;
|
|
5154
5678
|
const xData = x.dataSync();
|
|
5155
5679
|
const yData = y.dataSync();
|
|
5156
|
-
for (let i = 0; i < xData.length; i++)
|
|
5680
|
+
for (let i = 0; i < xData.length; i++) {
|
|
5681
|
+
if (isNaN(xData[i]) !== isNaN(yData[i])) return false;
|
|
5682
|
+
if (Math.abs(xData[i] - yData[i]) > atol + rtol * Math.abs(yData[i])) return false;
|
|
5683
|
+
}
|
|
5157
5684
|
return true;
|
|
5158
5685
|
}
|
|
5159
5686
|
/** Matrix product of two arrays. */
|
|
5160
5687
|
function matmul(x, y) {
|
|
5161
|
-
if (ndim(x) === 0 || ndim(y) === 0) throw new
|
|
5688
|
+
if (ndim(x) === 0 || ndim(y) === 0) throw new Error("matmul: x and y must be at least 1D");
|
|
5162
5689
|
x = x, y = y;
|
|
5163
5690
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5164
5691
|
const numBatchDims = Math.min(Math.max(x.ndim, 2), y.ndim) - 2;
|
|
5165
|
-
return dot
|
|
5692
|
+
return dot(x, y, {
|
|
5166
5693
|
lhsContractingDims: [-1],
|
|
5167
5694
|
rhsContractingDims: [-2],
|
|
5168
5695
|
lhsBatchDims: range(-2 - numBatchDims, -2),
|
|
@@ -5170,11 +5697,11 @@ function matmul(x, y) {
|
|
|
5170
5697
|
});
|
|
5171
5698
|
}
|
|
5172
5699
|
/** Dot product of two arrays. */
|
|
5173
|
-
function dot(x, y) {
|
|
5700
|
+
function dot$1(x, y) {
|
|
5174
5701
|
if (ndim(x) === 0 || ndim(y) === 0) return multiply(x, y);
|
|
5175
5702
|
x = x, y = y;
|
|
5176
5703
|
if (y.ndim === 1) return dot$2(x, y);
|
|
5177
|
-
return dot
|
|
5704
|
+
return dot(x, y, {
|
|
5178
5705
|
lhsContractingDims: [-1],
|
|
5179
5706
|
rhsContractingDims: [-2]
|
|
5180
5707
|
});
|
|
@@ -5190,7 +5717,7 @@ function tensordot(x, y, axes = 2) {
|
|
|
5190
5717
|
x = fudgeArray(x);
|
|
5191
5718
|
y = fudgeArray(y);
|
|
5192
5719
|
if (typeof axes === "number") axes = [range(-axes, 0), range(axes)];
|
|
5193
|
-
return dot
|
|
5720
|
+
return dot(x, y, {
|
|
5194
5721
|
lhsContractingDims: axes[0],
|
|
5195
5722
|
rhsContractingDims: axes[1]
|
|
5196
5723
|
});
|
|
@@ -5283,7 +5810,7 @@ function einsum(...args) {
|
|
|
5283
5810
|
const [b, bidx] = processSingleTensor(operands[j], indices[j], indices[i]);
|
|
5284
5811
|
indexReduced = indexReduced.filter((idx) => aidx.includes(idx));
|
|
5285
5812
|
const indexBatch = aidx.filter((idx) => bidx.includes(idx) && !indexReduced.includes(idx));
|
|
5286
|
-
const result = dot
|
|
5813
|
+
const result = dot(a, b, {
|
|
5287
5814
|
lhsContractingDims: indexReduced.map((idx) => aidx.indexOf(idx)),
|
|
5288
5815
|
rhsContractingDims: indexReduced.map((idx) => bidx.indexOf(idx)),
|
|
5289
5816
|
lhsBatchDims: indexBatch.map((idx) => aidx.indexOf(idx)),
|
|
@@ -5311,7 +5838,7 @@ function einsum(...args) {
|
|
|
5311
5838
|
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
5312
5839
|
*/
|
|
5313
5840
|
function inner(x, y) {
|
|
5314
|
-
return dot
|
|
5841
|
+
return dot(fudgeArray(x), fudgeArray(y), {
|
|
5315
5842
|
lhsContractingDims: [-1],
|
|
5316
5843
|
rhsContractingDims: [-1]
|
|
5317
5844
|
});
|
|
@@ -5344,6 +5871,30 @@ function vecdot(x, y, { axis } = {}) {
|
|
|
5344
5871
|
function vdot(x, y) {
  // Both inputs are flattened to 1-D before the inner product.
  const flatX = ravel(x);
  const flatY = ravel(y);
  return dot$2(flatX, flatY);
}
|
|
5874
|
+
/**
 * Shared implementation of `convolve` and `correlate` for 1-D inputs.
 *
 * The longer input is always treated as the signal. For `correlate`, swapping
 * the inputs reverses the output, so the result is flipped back. `mode` selects
 * "full" (pad by `len(y) - 1` on both sides), "same" (SAME_LOWER padding), or
 * "valid" (no padding). Inputs are validated before any array is consumed.
 */
function _convImpl(name, x, y, mode) {
  // Accept array-likes like the other numpy wrappers; otherwise `x.ndim` is
  // undefined and the error message below is garbage.
  x = fudgeArray(x);
  y = fudgeArray(y);
  if (x.ndim !== 1 || y.ndim !== 1) throw new Error(`${name}: both inputs must be 1D arrays, got ${x.ndim}D and ${y.ndim}D`);
  // Empty inputs would yield negative "full" padding.
  if (x.shape[0] === 0 || y.shape[0] === 0) throw new Error(`${name}: inputs cannot be empty`);
  if (mode !== "full" && mode !== "same" && mode !== "valid") throw new Error(`${name}: invalid mode ${mode}, expected "full", "same", or "valid"`);
  let flipOutput = false;
  if (x.shape[0] < y.shape[0]) {
    [x, y] = [y, x];
    if (name === "correlate") flipOutput = true;
  }
  if (name === "convolve") y = flip(y);
  let padding;
  if (mode === "valid") padding = "VALID";
  else if (mode === "same") padding = "SAME_LOWER";
  else padding = [[y.shape[0] - 1, y.shape[0] - 1]];
  // Add batch and channel dims of size 1, convolve, then drop them again.
  const z = conv(x.slice(null, null), y.slice(null, null), [1], padding).slice(0, 0);
  return flipOutput ? flip(z) : z;
}
|
|
5890
|
+
/** Discrete linear convolution of two 1-D arrays (`mode` defaults to "full"). */
function convolve(x, y, mode = "full") {
  const op = "convolve";
  return _convImpl(op, x, y, mode);
}
|
|
5894
|
+
/** Cross-correlation of two 1-D arrays (`mode` defaults to "valid"). */
function correlate(x, y, mode = "valid") {
  const op = "correlate";
  return _convImpl(op, x, y, mode);
}
|
|
5347
5898
|
/**
|
|
5348
5899
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
5349
5900
|
*
|
|
@@ -5352,7 +5903,7 @@ function vdot(x, y) {
|
|
|
5352
5903
|
*/
|
|
5353
5904
|
function meshgrid(xs, { indexing } = {}) {
|
|
5354
5905
|
indexing ??= "xy";
|
|
5355
|
-
for (const x of xs) if (x.ndim !== 1) throw new
|
|
5906
|
+
for (const x of xs) if (x.ndim !== 1) throw new Error(`meshgrid: all inputs must be 1D arrays, got ${x.ndim}D array`);
|
|
5356
5907
|
if (xs.length <= 1) return xs;
|
|
5357
5908
|
if (indexing === "xy") {
|
|
5358
5909
|
const [a, b, ...rest] = xs;
|
|
@@ -5371,43 +5922,6 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
5371
5922
|
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
5372
5923
|
}
|
|
5373
5924
|
/**
|
|
5374
|
-
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
5375
|
-
*
|
|
5376
|
-
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
5377
|
-
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
5378
|
-
* `k>0` is above it.
|
|
5379
|
-
*/
|
|
5380
|
-
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
5381
|
-
m ??= n;
|
|
5382
|
-
dtype ??= DType.Float32;
|
|
5383
|
-
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
5384
|
-
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
5385
|
-
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
5386
|
-
const rows = arange(k, n + k, 1, {
|
|
5387
|
-
dtype: DType.Int32,
|
|
5388
|
-
device
|
|
5389
|
-
});
|
|
5390
|
-
const cols = arange(0, m, 1, {
|
|
5391
|
-
dtype: DType.Int32,
|
|
5392
|
-
device
|
|
5393
|
-
});
|
|
5394
|
-
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
5395
|
-
}
|
|
5396
|
-
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
5397
|
-
function tril(a, k = 0) {
|
|
5398
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5399
|
-
a = fudgeArray(a);
|
|
5400
|
-
const [n, m] = a.shape.slice(-2);
|
|
5401
|
-
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
5402
|
-
}
|
|
5403
|
-
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
5404
|
-
function triu(a, k = 0) {
|
|
5405
|
-
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
5406
|
-
a = fudgeArray(a);
|
|
5407
|
-
const [n, m] = a.shape.slice(-2);
|
|
5408
|
-
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
5409
|
-
}
|
|
5410
|
-
/**
|
|
5411
5925
|
* Clip (limit) the values in an array.
|
|
5412
5926
|
*
|
|
5413
5927
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -5431,8 +5945,6 @@ function absolute(x) {
|
|
|
5431
5945
|
x = fudgeArray(x);
|
|
5432
5946
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
5433
5947
|
}
|
|
5434
|
-
/** @function Alias of `jax.numpy.absolute()`. */
|
|
5435
|
-
const abs = absolute;
|
|
5436
5948
|
/** Return an element-wise indication of sign of the input. */
|
|
5437
5949
|
function sign(x) {
|
|
5438
5950
|
x = fudgeArray(x);
|
|
@@ -5511,12 +6023,6 @@ const atan2 = jit$1(function atan2$1(y, x) {
|
|
|
5511
6023
|
const denom = where(xNeg, y, r.add(x));
|
|
5512
6024
|
return atan(numer.div(denom)).mul(2);
|
|
5513
6025
|
});
|
|
5514
|
-
/** @function Alias of `jax.numpy.acos()`. */
|
|
5515
|
-
const arccos = acos;
|
|
5516
|
-
/** @function Alias of `jax.numpy.atan()`. */
|
|
5517
|
-
const arctan = atan;
|
|
5518
|
-
/** @function Alias of `jax.numpy.atan2()`. */
|
|
5519
|
-
const arctan2 = atan2;
|
|
5520
6026
|
/** Element-wise subtraction, with broadcasting. */
|
|
5521
6027
|
function subtract(x, y) {
|
|
5522
6028
|
x = fudgeArray(x);
|
|
@@ -5547,8 +6053,6 @@ const fmod = jit$1(function fmod$1(x, y) {
|
|
|
5547
6053
|
const remainder = jit$1(function remainder$1(x, y) {
  // Computes mod(mod(x, y) + y, y).
  const inner = mod(x, y.ref);
  const shifted = inner.add(y.ref);
  return mod(shifted, y);
});
|
|
5550
|
-
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
5551
|
-
const divide = trueDivide;
|
|
5552
6056
|
/** Round input to the nearest integer towards zero. */
|
|
5553
6057
|
function trunc(x) {
|
|
5554
6058
|
return idiv(x, 1);
|
|
@@ -5570,9 +6074,9 @@ function ldexp(x1, x2) {
|
|
|
5570
6074
|
*/
|
|
5571
6075
|
function frexp(x) {
  const arr = fudgeArray(x);
  // Exponent is floor(log2|x|) + 1 (as int32), or 0 where x == 0.
  const magnitude = absolute(arr.ref);
  const exponent = where(equal(arr.ref, 0), 0, floor(log2(magnitude)).add(1).astype(DType.Int32));
  // Mantissa is x / 2**exponent, computed in the input's dtype.
  const mantissa = arr.div(exp2(exponent.ref.astype(arr.dtype)));
  return [mantissa, exponent];
}
|
|
5578
6082
|
/** Calculate `2**p` for all p in the input array. */
|
|
@@ -5612,10 +6116,11 @@ const degrees = rad2deg;
|
|
|
5612
6116
|
* Computes first array raised to power of second array, element-wise.
|
|
5613
6117
|
*/
|
|
5614
6118
|
const power = jit$1(function power$1(x1, x2) {
  // Computed as sign * exp(x2 * log|x1|):
  // - a negative base with a non-integer exponent gives NaN;
  // - a negative base with an odd integer exponent gives a negative result.
  const x2i = trunc(x2.ref);
  const shouldBeNaN = multiply(x2.ref.notEqual(x2i.ref), x1.ref.less(0));
  const resultSign = where(mod(x2i, 2).notEqual(0), where(x1.ref.less(0), -1, 1), 1);
  // pow(x, 0) is 1 for every x. Without this, x = 0 or ±inf would compute
  // exp(±inf * 0) = exp(NaN) = NaN.
  const zeroExponent = x2.ref.equal(0);
  const magnitude = exp(log(absolute(x1)).mul(x2)).mul(resultSign);
  return where(zeroExponent, 1, where(shouldBeNaN, nan, magnitude));
});
|
|
5617
|
-
/** @function Alias of `jax.numpy.power()`. */
|
|
5618
|
-
const pow = power;
|
|
5619
6124
|
/** @function Calculate the element-wise cube root of the input array. */
|
|
5620
6125
|
const cbrt = jit$1(function cbrt$1(x) {
|
|
5621
6126
|
const sgn = where(less(x.ref, 0), -1, 1);
|
|
@@ -5681,12 +6186,6 @@ const arccosh = jit$1(function arccosh$1(x) {
|
|
|
5681
6186
|
const arctanh = jit$1(function arctanh$1(x) {
	// atanh(x) = 0.5 * log((1 + x) / (1 - x)) = 0.5 * (log1p(x) - log1p(-x)).
	// In the quotient form, 1 + x and 1 - x round to 1 for small |x| (e.g.
	// 1e-8 in float32), which returns 0 instead of ~x. The log1p form avoids
	// that cancellation, assuming `log1p` is evaluated accurately near 0.
	// Edge cases are unchanged: +-1 -> +-inf, |x| > 1 -> NaN.
	return log1p(x.ref).sub(log1p(negative(x))).mul(.5);
});
|
|
5690
6189
|
/**
|
|
5691
6190
|
* Compute the variance of an array.
|
|
5692
6191
|
*
|
|
@@ -5716,6 +6215,26 @@ function var_(x, axis = null, opts) {
|
|
|
5716
6215
|
function std(x, axis = null, opts) {
	// Square root of the variance computed over the same axes and options.
	const variance = var_(x, axis, opts);
	return sqrt(variance);
}
|
|
6218
|
+
/**
 * Estimate the sample covariance of a set of variables.
 *
 * Like `numpy.cov` with its defaults: each row of `x` is a variable and each
 * column an observation, and the result is normalized by `N - 1` where `N` is
 * the number of observations. A 1-D `x` is treated as a single variable. If
 * `y` is given, its rows are stacked below those of `x` as extra variables.
 *
 * @param x - 1-D or 2-D array of observations.
 * @param y - Optional 1-D or 2-D array of additional variables.
 * @returns An `[M, M]` covariance matrix, `M` being the total number of variables.
 * @throws If `x` or `y` has more than 2 dimensions.
 */
function cov(x, y) {
	x = fudgeArray(x);
	// Higher-rank inputs would otherwise be silently treated as if axis 1 held
	// the observations.
	if (x.ndim > 2) throw new Error(`cov: x must have at most 2 dimensions, got ${x.ndim}`);
	if (x.ndim === 1) x = x.reshape([1, x.shape[0]]);
	if (y !== void 0) {
		y = fudgeArray(y);
		if (y.ndim > 2) throw new Error(`cov: y must have at most 2 dimensions, got ${y.ndim}`);
		if (y.ndim === 1) y = y.reshape([1, y.shape[0]]);
		x = vstack([x, y]);
	}
	const numObservations = x.shape[1];
	const centered = x.ref.sub(x.mean(1, { keepdims: true }));
	return dot$1(centered.ref, centered.transpose()).div(numObservations - 1);
}
|
|
6231
|
+
/**
 * Compute the Pearson correlation coefficients (in range `[-1, 1]`, up to
 * floating-point rounding).
 *
 * Each entry of `cov(x, y)` is divided by `sqrt(var_i * var_j)`.
 */
function corrcoef(x, y) {
	const covariance = cov(x, y);
	// The outer product of the variances holds var_i * var_j for every (i, j).
	const variances = diag(covariance.ref);
	const denominator = sqrt(outer(variances.ref, variances));
	return covariance.div(denominator);
}
|
|
5719
6238
|
/** Test element-wise for positive or negative infinity, return bool array. */
|
|
5720
6239
|
function isinf(x) {
|
|
5721
6240
|
x = fudgeArray(x);
|
|
@@ -5745,6 +6264,253 @@ const isfinite = jit$1(function isfinite$1(x) {
|
|
|
5745
6264
|
return isnan(x.ref).add(isinf(x)).notEqual(true);
|
|
5746
6265
|
});
|
|
5747
6266
|
|
|
6267
|
+
//#endregion
//#region src/library/lax-linalg.ts
// Namespace object for the linear-algebra helpers (registered as `linalg` in the lax namespace).
var lax_linalg_exports = {};
__export(lax_linalg_exports, { cholesky: () => cholesky, triangularSolve: () => triangularSolve });
|
|
6274
|
+
/**
 * Compute the Cholesky decomposition of a symmetric positive-definite matrix.
 *
 * The Cholesky decomposition of a matrix `A` is:
 *
 * - A = L @ L^T (for upper=false, default)
 * - A = U^T @ U (for upper=true)
 *
 * where `L` is a lower-triangular matrix and `U` is an upper-triangular matrix.
 * The input matrix must be symmetric and positive-definite. Leading axes, if
 * any, are carried through; only the last two axes are transposed for `upper`.
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const x = np.array([[2., 1.], [1., 2.]]);
 *
 * // Lower Cholesky factorization (default):
 * const L = lax.linalg.cholesky(x);
 * // L ≈ [[1.4142135, 0], [0.70710677, 1.2247449]]
 *
 * // Upper Cholesky factorization:
 * const U = lax.linalg.cholesky(x, { upper: true });
 * // U ≈ [[1.4142135, 0.70710677], [0, 1.2247449]]
 * ```
 */
function cholesky(a, { upper = false } = {}) {
	const lowerFactor = cholesky$2(a);
	if (!upper) return lowerFactor;
	// U = L^T: swap the two trailing (matrix) axes.
	return moveaxis$1(lowerFactor, -2, -1);
}
|
|
6304
|
+
/**
 * Solve a triangular linear system.
 *
 * Solves `a @ x = b` (if leftSide=true) or `x @ a = b` (if leftSide=false)
 * where `a` is a triangular matrix.
 *
 * Options:
 * - `leftSide` (default `false`): whether `a` multiplies `x` from the left.
 * - `lower` (default `false`): whether `a` is lower-triangular; otherwise upper.
 * - `transposeA` (default `false`): solve with `a^T` in place of `a`.
 * - `unitDiagonal` (default `false`): forwarded unchanged to the triangular-solve
 *   primitive (presumably treats the diagonal of `a` as all ones — TODO confirm).
 *
 * @example
 * ```ts
 * import { lax, numpy as np } from "@jax-js/jax";
 *
 * const L = np.array([[2., 0.], [1., 3.]]);
 * const b = np.array([4., 7.]).reshape([2, 1]);
 *
 * // Solve L @ x = b
 * const x = lax.linalg.triangularSolve(L, b, { leftSide: true, lower: true });
 * // x = [[2.], [5./3.]]
 * ```
 */
function triangularSolve(a, b, { leftSide = false, lower = false, transposeA = false, unitDiagonal = false } = {}) {
	a = fudgeArray(a);
	b = fudgeArray(b);
	// The primitive appears to solve `a @ x_i = b_i` for each row `b_i` of `b`.
	// A left-side solve therefore moves the columns of `b` into rows (and back
	// afterwards). A right-side solve `x @ a = b` is `a^T @ x^T = b^T`, so it
	// needs `a` transposed instead.
	if (!leftSide) transposeA = !transposeA;
	else b = moveaxis$1(b, -2, -1);
	if (transposeA) {
		a = moveaxis$1(a, -2, -1);
		// Transposing moves the data to the other triangle, so the primitive
		// must read the opposite one.
		lower = !lower;
	}
	let x = triangularSolve$1(a, b, {
		lower,
		unitDiagonal
	});
	if (leftSide) x = moveaxis$1(x, -2, -1);
	return x;
}
|
|
6335
|
+
|
|
6336
|
+
//#endregion
//#region src/library/lax.ts
// Namespace object exported from the package root as `lax`.
var lax_exports = {};
__export(lax_exports, {
	conv: () => conv,
	convGeneralDilated: () => convGeneralDilated,
	convWithGeneralPadding: () => convWithGeneralPadding,
	dot: () => dot,
	erf: () => erf,
	erfc: () => erfc,
	linalg: () => lax_linalg_exports,
	reduceWindow: () => reduceWindow,
	stopGradient: () => stopGradient$1
});
|
|
6350
|
+
/**
 * General dot product/contraction operator.
 *
 * Prefer higher-level functions like `jax.numpy.dot()`, `jax.numpy.matmul()`,
 * `jax.numpy.tensordot()`, and `jax.numpy.einsum()` where possible.
 *
 * Axes listed in `lhsContractingDims`/`rhsContractingDims` are summed over
 * pairwise. Axes listed in `lhsBatchDims`/`rhsBatchDims` are matched
 * elementwise. The output axes are ordered as: batch dims, then the remaining
 * (free) dims of `lhs`, then the free dims of `rhs`.
 *
 * @throws If the dim lists have mismatched lengths, if contracting and batch
 *   dims overlap, or if contracting dims have different sizes.
 */
function dot(lhs, rhs, { lhsContractingDims: lc = [], rhsContractingDims: rc = [], lhsBatchDims: lb = [], rhsBatchDims: rb = [] } = {}) {
	if (lc.length !== rc.length) throw new Error(`dot: contracting dims lengths mismatch, got ${JSON.stringify(lc)} and ${JSON.stringify(rc)}`);
	else if (lb.length !== rb.length) throw new Error(`dot: batch dims lengths mismatch, got ${JSON.stringify(lb)} and ${JSON.stringify(rb)}`);
	lc = lc.map((a) => checkAxis(a, lhs.ndim));
	rc = rc.map((a) => checkAxis(a, rhs.ndim));
	lb = lb.map((a) => checkAxis(a, lhs.ndim));
	rb = rb.map((a) => checkAxis(a, rhs.ndim));
	if (lc.some((a) => lb.includes(a))) throw new Error(`dot: lhs contracting dims ${JSON.stringify(lc)} overlap with batch dims ${JSON.stringify(lb)}`);
	else if (rc.some((a) => rb.includes(a))) throw new Error(`dot: rhs contracting dims ${JSON.stringify(rc)} overlap with batch dims ${JSON.stringify(rb)}`);
	// Validate contracting sizes before any operation is applied to the inputs,
	// so a shape error is raised before `lhs`/`rhs` are transposed (consumed).
	const dotShapeX = lc.map((a) => lhs.shape[a]);
	const dotShapeY = rc.map((a) => rhs.shape[a]);
	if (!deepEqual(dotShapeX, dotShapeY)) throw new Error(`dot: shapes not aligned along contracting dims: ${JSON.stringify(dotShapeX)} != ${JSON.stringify(dotShapeY)}`);
	const lf = range(lhs.ndim).filter((a) => !lc.includes(a) && !lb.includes(a));
	const rf = range(rhs.ndim).filter((a) => !rc.includes(a) && !rb.includes(a));
	// Lay both operands out as [batch, lhs free, rhs free], using size-1
	// placeholders for the other operand's free dims so they broadcast.
	const lhsShape = [
		...lb.map((a) => lhs.shape[a]),
		...lf.map((a) => lhs.shape[a]),
		...rep(rf.length, 1)
	];
	const rhsShape = [
		...rb.map((a) => rhs.shape[a]),
		...rep(lf.length, 1),
		...rf.map((a) => rhs.shape[a])
	];
	const lhs2 = lhs.transpose([...lb, ...lf, ...lc]);
	const rhs2 = rhs.transpose([...rb, ...rf, ...rc]);
	// Nothing to contract: a broadcasted elementwise product is the result.
	if (lc.length === 0) return mul(lhs2.reshape(lhsShape), rhs2.reshape(rhsShape));
	// Flatten all contracting dims into one trailing axis for the dot primitive.
	return dot$2(lhs2.reshape([...lhsShape, prod(dotShapeX)]), rhs2.reshape([...rhsShape, prod(dotShapeY)]));
}
|
|
6401
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
	// Translate a named padding mode into explicit [low, high] pads per spatial axis.
	const padType = padding.toUpperCase();
	if (padType === "VALID") return rep(inShape.length, [0, 0]);
	if (padType !== "SAME" && padType !== "SAME_LOWER") throw new Error(`Unknown padding type: ${padType}`);
	// Total padding that makes each output size ceil(in / stride), given the
	// dilated filter extent (k - 1) * d + 1.
	const targetOut = inShape.map((inSize, i) => Math.ceil(inSize / strides[i]));
	const totals = zipn(targetOut, strides, filterShape, dilation, inShape).map(([outSize, stride, k, d, inSize]) => Math.max(0, (outSize - 1) * stride + 1 + (k - 1) * d - inSize));
	// "SAME" places the odd extra unit of padding at the end, "SAME_LOWER" at the start.
	const half = (total) => total >> 1;
	if (padType === "SAME") return totals.map((total) => [half(total), total - half(total)]);
	return totals.map((total) => [total - half(total), half(total)]);
}
|
|
6415
|
+
/**
 * General n-dimensional convolution operator, with optional dilation.
 *
 * The semantics of this operation mimic the `jax.lax.conv_general_dilated`
 * function in JAX, which wraps XLA's general convolution operator.
 *
 * `lhs` is laid out as `[N, C_in, ...spatial]` and `rhs` as
 * `[C_out, C_in / featureGroupCount, ...kernel]`. `padding` is either a list
 * of `[low, high]` pairs (one per spatial axis) or one of the strings
 * `"VALID"`, `"SAME"`, `"SAME_LOWER"`. String padding cannot be combined with
 * a non-trivial `lhsDilation` (transposed convolutions).
 *
 * Grouped convolutions are supported through `featureGroupCount`, which must be
 * a positive integer dividing both the input and the output channel counts.
 */
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation, featureGroupCount = 1 } = {}) {
	if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
	if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
	if (typeof padding === "string") {
		if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
		padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
	}
	if (featureGroupCount !== 1) {
		const G = featureGroupCount;
		// Without this check, e.g. a negative G can slip past the divisibility checks below.
		if (!Number.isInteger(G) || G < 1) throw new Error(`featureGroupCount must be a positive integer, got ${G}`);
		const [N, C_in, ...xs] = lhs.shape;
		const [C_out, C_in_per_group, ...ks] = rhs.shape;
		if (C_in % G !== 0) throw new Error(`featureGroupCount=${G} must divide input channels=${C_in}`);
		if (C_out % G !== 0) throw new Error(`featureGroupCount=${G} must divide output channels=${C_out}`);
		if (C_in / G !== C_in_per_group) throw new Error(`rhs input channels=${C_in_per_group} must equal lhs input channels / groups=${C_in / G}`);
		// Split channels into G groups on a new leading axis, then run the G
		// independent convolutions as one convolution vmapped over that axis.
		const lhsGrouped = moveaxis(lhs.reshape([
			N,
			G,
			C_in / G,
			...xs
		]), 1, 0);
		const rhsGrouped = rhs.reshape([
			G,
			C_out / G,
			C_in_per_group,
			...ks
		]);
		const result = conv$1(lhsGrouped, rhsGrouped, {
			vmapDims: 1,
			strides: windowStrides,
			padding,
			lhsDilation,
			rhsDilation
		});
		// result: [G, N, C_out / G, ...ys] -> [N, C_out, ...ys]
		const ys = result.shape.slice(3);
		return moveaxis(result, 0, 1).reshape([
			N,
			C_out,
			...ys
		]);
	}
	return conv$1(lhs, rhs, {
		strides: windowStrides,
		padding,
		lhsDilation,
		rhsDilation
	});
}
|
|
6470
|
+
/** Convenience wrapper around `convGeneralDilated` that takes the dilations positionally. */
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
	const dilations = { lhsDilation, rhsDilation };
	return convGeneralDilated(lhs, rhs, windowStrides, padding, dilations);
}
|
|
6477
|
+
/** Convenience wrapper around `convGeneralDilated` using its default options. */
function conv(lhs, rhs, windowStrides, padding) {
	const output = convGeneralDilated(lhs, rhs, windowStrides, padding);
	return output;
}
|
|
6481
|
+
/**
 * Reduce over sliding windows of `operand` using `computation`.
 *
 * Windows are extracted by the `Pool` primitive using `windowDimensions` and
 * `windowStrides` (stride 1 along every windowed axis when omitted). This
 * wrapper applies no padding. `computation` is wrapped in `vmap` once per
 * dimension of `operand` before it is applied to the pooled result —
 * presumably so it receives a single window at a time (TODO confirm against
 * the `Pool` output layout).
 */
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
	if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
	const strides = windowStrides || rep(windowDimensions.length, 1);
	let perWindow = computation;
	for (let axis = 0; axis < operand.ndim; axis++) perWindow = vmap$1(perWindow, 0);
	const windows = bind1(Primitive.Pool, [operand], {
		window: windowDimensions,
		strides
	});
	return perWindow(windows);
}
|
|
6491
|
+
/** The error function: `erf(x) = 2/sqrt(pi) * int[0..x] exp(-t^2) dt`. */
function erf(x) {
	// Public wrapper; the computation lives in the file-internal `erf$1`.
	const value = erf$1(x);
	return value;
}
|
|
6495
|
+
/**
 * The complementary error function: `erfc(x) = 1 - erf(x)`.
 *
 * This function is more accurate than `1 - erf(x)` for large values of `x`,
 * where `erf(x)` is very close to 1.
 */
function erfc(x) {
	// Public wrapper; the computation lives in the file-internal `erfc$1`.
	const complement = erfc$1(x);
	return complement;
}
|
|
6504
|
+
/**
 * Stops gradient computation.
 *
 * Behaves as the identity function but prevents the flow of gradients during
 * forward or reverse-mode automatic differentiation.
 */
function stopGradient$1(x) {
	// Exposed as `lax.stopGradient`; delegates to the file-internal `stopGradient`.
	const detached = stopGradient(x);
	return detached;
}
|
|
6513
|
+
|
|
5748
6514
|
//#endregion
//#region src/library/nn.ts
// Namespace object exported from the package root as `nn`; filled in by the `__export` call that follows.
var nn_exports = {};
|
|
@@ -5753,6 +6519,10 @@ __export(nn_exports, {
|
|
|
5753
6519
|
elu: () => elu,
|
|
5754
6520
|
gelu: () => gelu,
|
|
5755
6521
|
glu: () => glu,
|
|
6522
|
+
hardSigmoid: () => hardSigmoid,
|
|
6523
|
+
hardSilu: () => hardSilu,
|
|
6524
|
+
hardSwish: () => hardSilu,
|
|
6525
|
+
hardTanh: () => hardTanh,
|
|
5756
6526
|
identity: () => identity,
|
|
5757
6527
|
leakyRelu: () => leakyRelu,
|
|
5758
6528
|
logSigmoid: () => logSigmoid,
|
|
@@ -5763,14 +6533,17 @@ __export(nn_exports, {
|
|
|
5763
6533
|
oneHot: () => oneHot,
|
|
5764
6534
|
relu: () => relu,
|
|
5765
6535
|
relu6: () => relu6,
|
|
6536
|
+
selu: () => selu,
|
|
5766
6537
|
sigmoid: () => sigmoid,
|
|
5767
6538
|
silu: () => silu,
|
|
5768
6539
|
softSign: () => softSign,
|
|
5769
6540
|
softmax: () => softmax,
|
|
5770
6541
|
softplus: () => softplus,
|
|
6542
|
+
sparsePlus: () => sparsePlus,
|
|
6543
|
+
sparseSigmoid: () => sparseSigmoid,
|
|
5771
6544
|
squareplus: () => squareplus,
|
|
5772
6545
|
standardize: () => standardize,
|
|
5773
|
-
swish: () =>
|
|
6546
|
+
swish: () => silu
|
|
5774
6547
|
});
|
|
5775
6548
|
/**
|
|
5776
6549
|
* Rectified Linear Unit (ReLU) activation function:
|
|
@@ -5805,6 +6578,28 @@ function softplus(x) {
|
|
|
5805
6578
|
return log(exp(x).add(1));
|
|
5806
6579
|
}
|
|
5807
6580
|
/**
 * @function
 * Sparse plus function, a piecewise smooth approximation of softplus:
 *
 * - When `x <= -1`: `0`
 * - When `-1 < x < 1`: `(x+1)**2 / 4`
 * - When `x >= 1`: `x`
 */
const sparsePlus = jit$1((x) => {
	const atOrBelowMinusOne = x.ref.lessEqual(-1);
	const belowOne = x.ref.less(1);
	const quadratic = square(x.ref.add(1)).mul(.25);
	return where(atOrBelowMinusOne, 0, where(belowOne, quadratic, x));
});
|
|
6591
|
+
/**
 * @function
 * Sparse sigmoid activation function.
 *
 * - When `x <= -1`: `0`
 * - When `-1 < x < 1`: `(x + 1) / 2`
 * - When `x >= 1`: `1`
 */
const sparseSigmoid = jit$1((x) => {
	// The linear ramp (x + 1) / 2, saturated to [0, 1].
	const ramp = x.add(1).mul(.5);
	return clip(ramp, 0, 1);
});
|
|
6602
|
+
/**
|
|
5808
6603
|
* Soft-sign activation function, computed element-wise:
|
|
5809
6604
|
* `softsign(x) = x / (|x| + 1)`.
|
|
5810
6605
|
*/
|
|
@@ -5826,17 +6621,6 @@ const silu = jit$1(function silu$1(x) {
|
|
|
5826
6621
|
return x.ref.mul(sigmoid(x));
|
|
5827
6622
|
});
|
|
5828
6623
|
/**
|
|
5829
|
-
* @function
|
|
5830
|
-
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
5831
|
-
* Swish, computed element-wise:
|
|
5832
|
-
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
5833
|
-
*
|
|
5834
|
-
* `swish()` and `silu()` are both aliases for the same function.
|
|
5835
|
-
*
|
|
5836
|
-
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
5837
|
-
*/
|
|
5838
|
-
const swish = silu;
|
|
5839
|
-
/**
|
|
5840
6624
|
* Log-sigmoid activation function, computed element-wise:
|
|
5841
6625
|
* `log_sigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x))`.
|
|
5842
6626
|
*/
|
|
@@ -5853,6 +6637,19 @@ function leakyRelu(x, negativeSlope = .01) {
|
|
|
5853
6637
|
x = fudgeArray(x);
|
|
5854
6638
|
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
5855
6639
|
}
|
|
6640
|
+
/** Hard sigmoid activation function: `relu6(x+3)/6`. */
function hardSigmoid(x) {
	const shifted = add(x, 3);
	return relu6(shifted).mul(1 / 6);
}
|
|
6644
|
+
/** Hard SiLU (swish) activation function: `x * hardSigmoid(x)`. */
function hardSilu(x) {
	const input = fudgeArray(x);
	const gate = hardSigmoid(input.ref);
	return input.mul(gate);
}
|
|
6649
|
+
/** Hard tanh activation function: `clip(x, -1, 1)`. */
function hardTanh(x) {
	const lowerBound = -1;
	const upperBound = 1;
	return clip(x, lowerBound, upperBound);
}
|
|
5856
6653
|
/**
|
|
5857
6654
|
* Exponential linear unit activation function.
|
|
5858
6655
|
*
|
|
@@ -5875,6 +6672,20 @@ function celu(x, alpha = 1) {
|
|
|
5875
6672
|
}
|
|
5876
6673
|
/**
 * @function
 * Scaled exponential linear unit activation.
 *
 * Computes the element-wise function:
 * `selu(x) = lambda * (x > 0 ? x : alpha * (exp(x) - 1))`
 *
 * Where `alpha = 1.6732632423543772` and `lambda = 1.0507009873554805`.
 */
const selu = jit$1(function selu$1(x) {
	const alpha = 1.6732632423543772;
	const scale = 1.0507009873554805;
	// Both branches equal 0 at x == 0, so testing `x < 0` matches the formula above.
	const isNegative = x.ref.less(0);
	const negativeBranch = expm1(x.ref).mul(alpha);
	return where(isNegative, negativeBranch, x).mul(scale);
});
|
|
6687
|
+
/**
|
|
6688
|
+
* @function
|
|
5878
6689
|
* Gaussion error linear unit (GELU) activation function.
|
|
5879
6690
|
*
|
|
5880
6691
|
* This is computed element-wise. There are two variants depending on whether
|
|
@@ -5968,22 +6779,22 @@ function logSoftmax(x, axis = -1) {
|
|
|
5968
6779
|
*
|
|
5969
6780
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
5970
6781
|
*/
|
|
5971
|
-
function logsumexp(x, axis = null, opts) {
	x = fudgeArray(x);
	axis = normalizeAxis(axis, x.ndim);
	if (axis.length === 0) return x;
	// Shift by the per-slice max for numerical stability. If the max is
	// non-finite (e.g. every entry is -inf), shifting would compute
	// `inf - inf = NaN`, so shift by 0 instead. That correctly yields -inf for an
	// all -inf slice and +inf when any entry is +inf.
	const rawMax = max(x.ref, axis, { keepdims: true });
	const xMax = stopGradient(where(isfinite(rawMax.ref), rawMax, 0));
	const shifted = x.sub(xMax.ref);
	const result = xMax.add(log(exp(shifted).sum(axis, { keepdims: true })));
	return opts?.keepdims ? result : squeeze(result, axis);
}
|
|
5980
6791
|
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
function logmeanexp(x, axis = null, opts) {
	x = fudgeArray(x);
	axis = normalizeAxis(axis, x.ndim);
	if (axis.length === 0) return x;
	// Number of elements reduced per output entry; read before `x` is consumed.
	let count = 1;
	for (const a of axis) count *= x.shape[a];
	return logsumexp(x, axis, opts).sub(Math.log(count));
}
|
|
5988
6799
|
/**
|
|
5989
6800
|
* Standardizes input to zero mean and unit variance.
|
|
@@ -6028,8 +6839,11 @@ var random_exports = {};
|
|
|
6028
6839
|
__export(random_exports, {
|
|
6029
6840
|
bernoulli: () => bernoulli,
|
|
6030
6841
|
bits: () => bits,
|
|
6842
|
+
cauchy: () => cauchy,
|
|
6031
6843
|
exponential: () => exponential,
|
|
6844
|
+
gumbel: () => gumbel,
|
|
6032
6845
|
key: () => key,
|
|
6846
|
+
laplace: () => laplace,
|
|
6033
6847
|
normal: () => normal,
|
|
6034
6848
|
split: () => split,
|
|
6035
6849
|
uniform: () => uniform
|
|
@@ -6088,6 +6902,16 @@ function bernoulli(key$1, p = .5, shape$1 = []) {
|
|
|
6088
6902
|
}
|
|
6089
6903
|
/**
 * @function
 * Sample from a Cauchy distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling: `x = tan(π * (u - 0.5))` where u ~ Uniform(0, 1).
 */
const cauchy = jit$1(function cauchy$1(key$1, shape$1 = []) {
	// Map u to an angle in [-π/2, π/2), then take its tangent.
	const angle = uniform(key$1, shape$1).sub(.5).mul(Math.PI);
	return tan(angle);
}, { staticArgnums: [1] });
|
|
6913
|
+
/**
|
|
6914
|
+
* @function
|
|
6091
6915
|
* Sample exponential random values according to `p(x) = exp(-x)`.
|
|
6092
6916
|
*/
|
|
6093
6917
|
const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
@@ -6096,6 +6920,30 @@ const exponential = jit$1(function exponential$1(key$1, shape$1 = []) {
|
|
|
6096
6920
|
}, { staticArgnums: [1] });
|
|
6097
6921
|
/**
 * @function
 * Sample from a Gumbel distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling: `x = -log(-log(1 - u))` where u ~ Uniform(0, 1).
 * Because `1 - u` is uniformly distributed too, this has the same distribution
 * as the textbook form `-log(-log(u))`. `log(1 - u)` is evaluated as `log1p(-u)`.
 *
 * NOTE(review): a sample with `u` exactly 0 maps to `+inf` (`log1p(-0) = 0`);
 * confirm whether `uniform` can return 0 before relying on finite output.
 */
const gumbel = jit$1(function gumbel$1(key$1, shape$1 = []) {
	const u = uniform(key$1, shape$1);
	return negative(log(negative(log1p(negative(u)))));
}, { staticArgnums: [1] });
|
|
6931
|
+
/**
 * @function
 * Sample from a Laplace distribution with location 0 and scale 1.
 *
 * Uses inverse transform sampling: the CDF is `F(x) = 0.5 + 0.5 * sign(x) * (1 - exp(-|x|))`.
 * Inverting: `x = -sign(u - 0.5) * log(1 - 2 * |u - 0.5|)`.
 */
const laplace = jit$1(function laplace$1(key$1, shape$1 = []) {
	const centered = uniform(key$1, shape$1).sub(.5);
	const direction = sign(centered.ref);
	// log1p(-2|c|) = log(1 - 2|c|) <= 0; negating gives a non-negative magnitude.
	const magnitude = log1p(absolute(centered).mul(-2)).mul(-1);
	return direction.mul(magnitude);
}, { staticArgnums: [1] });
|
|
6945
|
+
/**
|
|
6946
|
+
* @function
|
|
6099
6947
|
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
6100
6948
|
*
|
|
6101
6949
|
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
@@ -6204,11 +7052,6 @@ const valueAndGrad = valueAndGrad$1;
|
|
|
6204
7052
|
*/
|
|
6205
7053
|
const jacrev = jacrev$1; // Exported from the package root both as `jacrev` and, via the export list, as `jacobian`.
|
|
6206
7054
|
/**
|
|
6207
|
-
* @function
|
|
6208
|
-
* Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
|
|
6209
|
-
*/
|
|
6210
|
-
const jacobian = jacrev;
|
|
6211
|
-
/**
|
|
6212
7055
|
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
6213
7056
|
*
|
|
6214
7057
|
* This can be used to wait for the results of an intermediate computation to
|
|
@@ -6243,5 +7086,4 @@ async function devicePut(x, device) {
|
|
|
6243
7086
|
}
|
|
6244
7087
|
|
|
6245
7088
|
//#endregion
|
|
6246
|
-
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
6247
|
-
//# sourceMappingURL=index.js.map
|
|
7089
|
+
export { Array$1 as Array, ClosedJaxpr, DType, Jaxpr, blockUntilReady, defaultDevice, devicePut, devices, grad, init, jacfwd, jacrev as jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, scipy_special_exports as scipySpecial, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|