@jax-js/jax 0.0.2 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +57 -25
- package/dist/backend-EBRGmEYw.js +3816 -0
- package/dist/{backend-BK21PBVP.cjs → backend-Ss1Mev_-.cjs} +2075 -107
- package/dist/index.cjs +1393 -250
- package/dist/index.d.cts +651 -102
- package/dist/index.d.ts +651 -102
- package/dist/index.js +1377 -245
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-BVdMaO9T.cjs} +62 -35
- package/dist/{webgpu-JVpVad6g.js → webgpu-ow0Pn_6q.js} +62 -35
- package/package.json +21 -9
- package/dist/backend-1eVbAoaV.js +0 -1890
package/dist/index.js
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache,
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-EBRGmEYw.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
6
6
|
__export(tree_exports, {
|
|
7
7
|
JsTreeDef: () => JsTreeDef,
|
|
8
8
|
NodeType: () => NodeType,
|
|
9
|
+
dispose: () => dispose,
|
|
9
10
|
flatten: () => flatten,
|
|
10
11
|
leaves: () => leaves,
|
|
11
12
|
map: () => map,
|
|
@@ -20,7 +21,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
|
20
21
|
NodeType$1["Leaf"] = "Leaf";
|
|
21
22
|
return NodeType$1;
|
|
22
23
|
}({});
|
|
23
|
-
/**
|
|
24
|
+
/** Represents the structure of a JsTree. */
|
|
24
25
|
var JsTreeDef = class JsTreeDef {
|
|
25
26
|
static leaf = new JsTreeDef(NodeType.Leaf, null, []);
|
|
26
27
|
constructor(nodeType, nodeMetadata, childTreedefs) {
|
|
@@ -108,6 +109,194 @@ function map(fn, tree, ...rest) {
|
|
|
108
109
|
/** Map every leaf of `tree` to its `.ref` property using `map`. */
function ref(tree) {
  const takeRef = (leaf) => leaf.ref;
  return map(takeRef, tree);
}
|
|
112
|
+
/**
 * Dispose every array in a tree.
 *
 * Does nothing when `tree` is falsy. Otherwise `dispose()` is called on each
 * leaf that `map` visits; the mapped result is discarded.
 */
function dispose(tree) {
  if (!tree) return;
  map((leaf) => leaf.dispose(), tree);
}
|
|
116
|
+
|
|
117
|
+
//#endregion
|
|
118
|
+
//#region src/frontend/convolution.ts
|
|
119
|
+
/**
 * Validate the shapes and parameters passed to a convolution.
 *
 * Both shapes must have the same rank; the `rank - 2` trailing entries are the
 * spatial dimensions. Each per-dimension parameter must have exactly one entry
 * per spatial dimension, and index 1 of both shapes must agree.
 *
 * @param lhsShape Shape of the left-hand operand.
 * @param rhsShape Shape of the right-hand operand (the kernel).
 * @param params Object with `strides`, `padding`, `lhsDilation`, `rhsDilation`.
 * @returns The output shape `[lhsShape[0], rhsShape[0], ...spatialSizes]`.
 * @throws {Error} When any of the checks fails.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  const lhsRank = lhsShape.length;
  const rhsRank = rhsShape.length;
  if (lhsRank !== rhsRank) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsRank} and ${rhsRank}`);
  }
  const spatialRank = lhsRank - 2;
  if (spatialRank < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== spatialRank) throw new Error("conv() strides != spatial dims");
  if (padding.length !== spatialRank) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== spatialRank) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== spatialRank) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) {
    throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
  }

  const isPositiveInt = (v) => v > 0 && Number.isInteger(v);
  const outShape = [lhsShape[0], rhsShape[0]];
  for (let dim = 0; dim < spatialRank; dim += 1) {
    const stride = strides[dim];
    const pads = padding[dim];
    const lhsDil = lhsDilation[dim];
    const rhsDil = rhsDilation[dim];
    if (!isPositiveInt(stride)) {
      throw new Error(`conv() strides[${dim}] must be a positive integer`);
    }
    if (pads.length !== 2 || pads.some((p) => !Number.isInteger(p))) {
      throw new Error(`conv() padding[${dim}] must be a 2-tuple of integers`);
    }
    if (!isPositiveInt(lhsDil)) {
      throw new Error(`conv() lhsDilation[${dim}] must be a positive integer`);
    }
    if (!isPositiveInt(rhsDil)) {
      throw new Error(`conv() rhsDilation[${dim}] must be a positive integer`);
    }

    const inputSize = lhsShape[dim + 2];
    const kernelDim = rhsShape[dim + 2];
    if (kernelDim <= 0) throw new Error("conv() kernel size must be positive");

    const [padLo, padHi] = pads;
    const floor = -inputSize;
    if (padLo < floor || padHi < floor || padLo + padHi < floor) {
      throw new Error(`conv() padding[${dim}]=(${padLo},${padHi}) is too negative for input size ${inputSize}`);
    }

    // Extent of the kernel after rhs dilation, and of the input after lhs
    // dilation plus (possibly negative) padding.
    const dilatedKernel = (kernelDim - 1) * rhsDil + 1;
    const paddedInput = Math.max((inputSize - 1) * lhsDil + 1, 0) + padLo + padHi;
    if (dilatedKernel > paddedInput) {
      throw new Error(`conv() kernel size ${dilatedKernel} > input size ${paddedInput} in dimension ${dim}`);
    }
    outShape.push(Math.ceil((paddedInput - dilatedKernel + 1) / stride));
  }
  return outShape;
}
|
|
150
|
+
/**
 * Validate the window and strides of a pooling operation.
 *
 * The window applies to the trailing `window.length` dimensions of `inShape`.
 *
 * @param inShape Shape of the input.
 * @param window Window size for each pooled dimension.
 * @param strides Stride for each pooled dimension.
 * @returns The leading (unpooled) dimensions, then one output size per pooled
 *   dimension, then the window sizes themselves.
 * @throws {Error} When the lengths disagree or an entry is invalid.
 */
function checkPoolShape(inShape, window, strides) {
  const pooledRank = window.length;
  if (strides.length !== pooledRank) throw new Error("pool() strides != window dims");
  if (pooledRank > inShape.length) throw new Error("pool() window has more dimensions than input");

  const leadingRank = inShape.length - pooledRank;
  const outShape = inShape.slice(0, leadingRank);
  for (const [i, k] of window.entries()) {
    const s = strides[i];
    const extent = inShape[leadingRank + i];
    if (!(k > 0 && Number.isInteger(k))) {
      throw new Error(`pool() window[${i}] must be a positive integer`);
    }
    if (k > extent) {
      throw new Error(`pool() window[${i}]=${k} > input size ${extent}`);
    }
    if (!(s > 0 && Number.isInteger(s))) {
      throw new Error(`pool() strides[${i}] must be a positive integer`);
    }
    outShape.push(Math.ceil((extent - k + 1) / s));
  }
  return [...outShape, ...window];
}
|
|
165
|
+
/**
 * Takes a shape tracker and a kernel size `ks`, then reshapes it so the last
 * `ks.length` dimensions become `2 * ks.length` dimensions by treating them as
 * spatial dimensions convolved with a kernel.
 *
 * The resulting array can be multiplied with a kernel of shape `ks`, then
 * reduced along the last `ks.length` dimensions for a convolution.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 *
 * @param st Shape tracker to transform; returned unchanged when `ks` is empty.
 * @param ks Kernel size for each trailing dimension.
 * @param strides One number applied to every dimension, or one per dimension.
 * @param dilation One number applied to every dimension, or one per dimension.
 * @throws {Error} If `ks` has more entries than `st` has dimensions, or any
 *   stride/dilation is not a positive integer.
 */
function pool(st, ks, strides = 1, dilation = 1) {
  const rank = ks.length;
  if (rank === 0) return st;
  if (st.shape.length < rank) throw new Error("pool() called with too many dimensions");
  if (typeof strides === "number") strides = rep(rank, strides);
  if (typeof dilation === "number") dilation = rep(rank, dilation);
  const notPositiveInt = (v) => v <= 0 || !Number.isInteger(v);
  if (strides.some(notPositiveInt)) throw new Error("pool() strides must be positive integers");
  if (dilation.some(notPositiveInt)) throw new Error("pool() dilation must be positive integers");

  const lead = st.shape.slice(0, -rank);
  const inSizes = st.shape.slice(-rank);
  // Full-extent ranges for the untouched leading dimensions (fresh each call).
  const leadRanges = () => lead.map((x) => [0, x]);

  const outSizes = zipn(inSizes, dilation, ks, strides).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const factors = zipn(outSizes, strides, inSizes, dilation, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = zipn(ks, inSizes, dilation, factors);
  const rows = kidf.map(([, i, d, f]) => i * f + d);

  // Repeat each pooled dimension enough times to cut `k` rows of `rows[j]`.
  const repeats = kidf.map(([k, i], j) => Math.ceil(k * rows[j] / i));
  st = st.repeat([...rep(lead.length, 1), ...repeats]);
  st = st.shrink([...leadRanges(), ...kidf.map(([k], j) => [0, k * rows[j]])]);
  st = st.reshape([...lead, ...kidf.flatMap(([k], j) => [k, rows[j]])]);

  // Keep `o * s` entries per row and split them into (o, s).
  const kos = zipn(ks, outSizes, strides);
  st = st.shrink([...leadRanges(), ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])]);
  st = st.reshape([...lead, ...kos.flat(1)]);

  // Keep only the first element of every stride group.
  st = st.shrink([...leadRanges(), ...kos.flatMap(([k, o]) => [[0, k], [0, o], [0, 1]])]);
  st = st.reshape([...lead, ...kos.flatMap(([k, o]) => [k, o])]);

  // Reorder (k0, o0, k1, o1, ...) into (o0, o1, ..., k0, k1, ...).
  const base = lead.length;
  return st.permute([
    ...range(base),
    ...ks.map((_, j) => base + 2 * j + 1),
    ...ks.map((_, j) => base + 2 * j),
  ]);
}
|
|
205
|
+
/**
 * Perform the transpose of pool, directly undo-ing a pool() operation.
 *
 * Note that since pool repeats the input, the transpose operation technically
 * should include a sum reduction. This function doesn't perform the reduction,
 * which should be done on the last `k` axes of the returned shape.
 *
 * @param st Shape tracker of the pooled array; returned unchanged when `ks`
 *   is empty.
 * @param inShape Shape of the array that was originally pooled.
 * @param ks Kernel size for each pooled dimension.
 * @param strides One number applied to every dimension, or one per dimension.
 * @param dilation One number applied to every dimension, or one per dimension.
 * @throws {Error} If the output sizes implied by `inShape` do not match the
 *   corresponding dimensions of `st`.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  const rank = ks.length;
  if (rank === 0) return st;
  if (typeof strides === "number") strides = rep(rank, strides);
  if (typeof dilation === "number") dilation = rep(rank, dilation);

  const lead = inShape.slice(0, -rank);
  const inSizes = inShape.slice(-rank);
  const base = lead.length;
  // Zero padding for the untouched leading dimensions (fresh each call).
  const leadPad = () => lead.map(() => [0, 0]);

  const outSizes = zipn(inSizes, dilation, ks, strides).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const pooledDims = st.shape.slice(base, base + rank);
  if (!deepEqual(outSizes, pooledDims)) throw new Error("poolTranspose() called with mismatched output shape");

  const factors = zipn(outSizes, strides, inSizes, dilation, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = zipn(ks, inSizes, dilation, factors);
  const kos = zipn(ks, outSizes, strides);
  const rows = kidf.map(([, i, d, f]) => i * f + d);
  const tiled = kidf.map(([k], j) => k * rows[j]);
  const repeats = kidf.map(([k, i], j) => Math.ceil(k * rows[j] / i));

  // Interleave into (k0, o0, k1, o1, ...).
  st = st.permute([...range(base), ...ks.flatMap((_, j) => [base + rank + j, base + j])]);

  // Re-insert the stride gaps: (k, o, 1) padded to (k, o, s).
  st = st.reshape([...lead, ...kos.flatMap(([k, o]) => [k, o, 1])]);
  st = st.pad([...leadPad(), ...strides.flatMap((s) => [[0, 0], [0, 0], [0, s - 1]])]);

  // Grow each row from `o * s` back to `rows[j]` entries.
  st = st.reshape([...lead, ...kos.flatMap(([k, o, s]) => [k, o * s])]);
  st = st.pad([...leadPad(), ...kidf.flatMap((_, j) => [[0, 0], [0, rows[j] - outSizes[j] * strides[j]]])]);

  // Flatten (k, row) and pad up to a whole number of input-sized repeats.
  st = st.reshape([...lead, ...tiled]);
  st = st.pad([...leadPad(), ...kidf.map(([, i], j) => [0, repeats[j] * i - tiled[j]])]);

  // Split into (repeat, i) pairs and move the input dimensions first.
  st = st.reshape([...lead, ...kidf.flatMap(([, i], j) => [repeats[j], i])]);
  return st.permute([
    ...range(base),
    ...ks.map((_, j) => base + 2 * j + 1),
    ...ks.map((_, j) => base + 2 * j),
  ]);
}
|
|
244
|
+
/**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * The first two dimensions of `st` are left alone. Each remaining dimension of
 * size `k` with factor `s` ends up with size `(k - 1) * s + 1`, via a
 * reshape / pad / reshape / shrink sequence.
 *
 * @param st Shape tracker to dilate; returned as-is when every factor is 1.
 * @param dilation One dilation factor per dimension after the first two.
 */
function applyDilation(st, dilation) {
  const isIdentity = dilation.every((factor) => factor === 1);
  if (isIdentity) return st;

  const [dim0, dim1, ...spatial] = st.shape;

  const withUnitAxes = [dim0, dim1];
  for (const k of spatial) withUnitAxes.push(k, 1);

  const gapPadding = [[0, 0], [0, 0]];
  for (const factor of dilation) gapPadding.push([0, 0], [0, factor - 1]);

  const merged = [dim0, dim1];
  const trimmed = [[0, dim0], [0, dim1]];
  spatial.forEach((k, i) => {
    merged.push(k * dilation[i]);
    // Drop the padding that follows the final real element.
    trimmed.push([0, (k - 1) * dilation[i] + 1]);
  });

  return st.reshape(withUnitAxes).pad(gapPadding).reshape(merged).shrink(trimmed);
}
|
|
271
|
+
/**
 * Prepare for a convolution between two arrays.
 *
 * This does not check the validity of the shapes, which should be checked
 * beforehand using `checkConvShape()`.
 *
 * @param stX Shape tracker of the left-hand operand.
 * @param stY Shape tracker of the right-hand operand (the kernel).
 * @param params Object with `lhsDilation`, `padding`, `strides`, `rhsDilation`.
 * @returns `[stX, stY]` reshaped so that both end in a single dimension of
 *   size `shape[1] * prod(kernel dims)`.
 */
function prepareConv(stX, stY, params) {
  const spatialRank = stX.shape.length - 2;
  const kernelDims = stY.shape.slice(2);
  const kernelElems = prod(kernelDims);

  const dilated = applyDilation(stX, params.lhsDilation);
  const padded = dilated.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  const windows = pool(padded, kernelDims, params.strides, params.rhsDilation);

  // The target shape is read from `windows`, i.e. before axis 1 is moved.
  const lhs = windows.moveaxis(1, spatialRank + 1).reshape([
    windows.shape[0],
    1,
    ...windows.shape.slice(2, spatialRank + 2),
    windows.shape[1] * kernelElems,
  ]);
  const rhs = stY.reshape([
    stY.shape[0],
    ...rep(spatialRank, 1),
    stY.shape[1] * kernelElems,
  ]);
  return [lhs, rhs];
}
|
|
111
300
|
|
|
112
301
|
//#endregion
|
|
113
302
|
//#region src/frontend/core.ts
|
|
@@ -134,12 +323,18 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
134
323
|
Primitive$1["RandomBits"] = "random_bits";
|
|
135
324
|
Primitive$1["Sin"] = "sin";
|
|
136
325
|
Primitive$1["Cos"] = "cos";
|
|
326
|
+
Primitive$1["Asin"] = "asin";
|
|
327
|
+
Primitive$1["Atan"] = "atan";
|
|
137
328
|
Primitive$1["Exp"] = "exp";
|
|
138
329
|
Primitive$1["Log"] = "log";
|
|
330
|
+
Primitive$1["Sqrt"] = "sqrt";
|
|
139
331
|
Primitive$1["Min"] = "min";
|
|
140
332
|
Primitive$1["Max"] = "max";
|
|
141
333
|
Primitive$1["Reduce"] = "reduce";
|
|
142
334
|
Primitive$1["Dot"] = "dot";
|
|
335
|
+
Primitive$1["Conv"] = "conv";
|
|
336
|
+
Primitive$1["Pool"] = "pool";
|
|
337
|
+
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
143
338
|
Primitive$1["Compare"] = "compare";
|
|
144
339
|
Primitive$1["Where"] = "where";
|
|
145
340
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -197,34 +392,52 @@ function sin$1(x) {
|
|
|
197
392
|
/** Bind `Primitive.Cos` to the single operand `x`. */
function cos$1(x) {
  const operands = [x];
  return bind1(Primitive.Cos, operands);
}
|
|
395
|
+
/** Bind `Primitive.Asin` to the single operand `x`. */
function asin$1(x) {
  const operands = [x];
  return bind1(Primitive.Asin, operands);
}
|
|
398
|
+
/** Bind `Primitive.Atan` to the single operand `x`. */
function atan$1(x) {
  const operands = [x];
  return bind1(Primitive.Atan, operands);
}
|
|
200
401
|
/** Bind `Primitive.Exp` to the single operand `x`. */
function exp$1(x) {
  const operands = [x];
  return bind1(Primitive.Exp, operands);
}
|
|
203
404
|
/** Bind `Primitive.Log` to the single operand `x`. */
function log$1(x) {
  const operands = [x];
  return bind1(Primitive.Log, operands);
}
|
|
407
|
+
/** Bind `Primitive.Sqrt` to the single operand `x`. */
function sqrt$1(x) {
  const operands = [x];
  return bind1(Primitive.Sqrt, operands);
}
|
|
206
410
|
/** Bind `Primitive.Min` to the operand pair `(x, y)`. */
function min$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Min, operands);
}
|
|
209
413
|
/** Bind `Primitive.Max` to the operand pair `(x, y)`. */
function max$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Max, operands);
}
|
|
212
|
-
/**
 * Bind `Primitive.Reduce` over the given axes of `x`.
 *
 * @param x Operand to reduce.
 * @param op Reduction operation; must be a member of `AluGroup.Reduce`.
 * @param axis Axis specification, passed through `normalizeAxis` together with
 *   the rank of `x`. Defaults to `null`.
 * @param opts Optional settings. With `opts.keepdims` truthy, the result is
 *   reshaped so every reduced axis is kept with size 1.
 * @throws {TypeError} If `op` is not a reduce operation.
 */
function reduce(x, op, axis = null, opts) {
  if (!AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
  axis = normalizeAxis(axis, ndim$1(x));
  // The shape is captured before `x` is handed to bind1.
  const originalShape = getShape(x);
  const reduced = bind1(Primitive.Reduce, [x], { op, axis });
  if (!opts?.keepdims) return reduced;
  const keptShape = originalShape.map((dim, i) => (axis.includes(i) ? 1 : dim));
  return reduced.reshape(keptShape);
}
|
|
225
427
|
/** Bind `Primitive.Dot` to the operand pair `(x, y)`. */
function dot$1(x, y) {
  const operands = [x, y];
  return bind1(Primitive.Dot, operands);
}
|
|
430
|
+
/**
 * Bind `Primitive.Conv` to `x` and `y`, filling in missing parameters.
 *
 * With `n = x.ndim - 2`, the defaults are `n` strides of 1, `n` paddings of
 * `[0, 0]`, and `n` left/right dilations of 1. A default is used only when the
 * corresponding field of `params` is `null` or `undefined`.
 *
 * @throws {Error} If `x` and `y` differ in `ndim`, or have fewer than 2
 *   dimensions.
 */
function conv(x, y, params = {}) {
  if (x.ndim !== y.ndim) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
  }
  const spatialRank = x.ndim - 2;
  if (spatialRank < 0) throw new Error("conv() requires at least 2D inputs");

  const strides = params.strides ?? rep(spatialRank, 1);
  const padding = params.padding ?? rep(spatialRank, [0, 0]);
  const lhsDilation = params.lhsDilation ?? rep(spatialRank, 1);
  const rhsDilation = params.rhsDilation ?? rep(spatialRank, 1);
  return bind1(Primitive.Conv, [x, y], { strides, padding, lhsDilation, rhsDilation });
}
|
|
228
441
|
/** Bind `Primitive.Compare` to `(x, y)` with the comparison `op` as its parameter. */
function compare(x, y, op) {
  const params = { op };
  return bind1(Primitive.Compare, [x, y], params);
}
|
|
@@ -255,10 +468,11 @@ function where$1(cond, x, y) {
|
|
|
255
468
|
}
|
|
256
469
|
/**
 * Bind `Primitive.Transpose` to `x`.
 *
 * @param x Operand to transpose.
 * @param perm Axis order; each entry goes through `checkAxis`. When falsy, the
 *   axes are reversed.
 * @throws {Error} If the resulting `perm` fails `isPermutation` for the rank
 *   of `x`.
 */
function transpose$1(x, perm) {
  const rank = ndim$1(x);
  if (perm) {
    perm = perm.map((a) => checkAxis(a, rank));
  } else {
    perm = range(rank).reverse();
  }
  if (!isPermutation(perm, rank)) {
    throw new Error(`Invalid transpose permutation for ${rank} axes: ${JSON.stringify(perm)}`);
  }
  return bind1(Primitive.Transpose, [x], { perm });
}
|
|
260
474
|
function broadcast(x, shape$1, axis) {
|
|
261
|
-
axis = axis
|
|
475
|
+
axis = normalizeAxis(axis, shape$1.length);
|
|
262
476
|
return bind1(Primitive.Broadcast, [x], {
|
|
263
477
|
shape: shape$1,
|
|
264
478
|
axis
|
|
@@ -277,7 +491,7 @@ function reshape$1(x, shape$1) {
|
|
|
277
491
|
return bind1(Primitive.Reshape, [x], { shape: shape$1 });
|
|
278
492
|
}
|
|
279
493
|
/**
 * Bind `Primitive.Flip` to `x`, with `axis` first passed through
 * `normalizeAxis` together with the rank of `x`.
 */
function flip$1(x, axis) {
  const axes = normalizeAxis(axis, ndim$1(x));
  return bind1(Primitive.Flip, [x], { axis: axes });
}
|
|
283
497
|
function shrink(x, slice) {
|
|
@@ -357,12 +571,19 @@ var Tracer = class Tracer {
|
|
|
357
571
|
constructor(trace) {
|
|
358
572
|
this._trace = trace;
|
|
359
573
|
}
|
|
574
|
+
/** The shape of the array. */
|
|
360
575
|
get shape() {
|
|
361
576
|
return this.aval.shape;
|
|
362
577
|
}
|
|
578
|
+
/** The total number of elements in the array. */
|
|
579
|
+
get size() {
|
|
580
|
+
return prod(this.shape);
|
|
581
|
+
}
|
|
582
|
+
/** The dtype of the array. */
|
|
363
583
|
get dtype() {
|
|
364
584
|
return this.aval.dtype;
|
|
365
585
|
}
|
|
586
|
+
/** The number of dimensions of the array. */
|
|
366
587
|
get ndim() {
|
|
367
588
|
return this.shape.length;
|
|
368
589
|
}
|
|
@@ -398,22 +619,20 @@ var Tracer = class Tracer {
|
|
|
398
619
|
return lessEqual$1(this, other);
|
|
399
620
|
}
|
|
400
621
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
401
|
-
sum(axis, opts) {
|
|
622
|
+
sum(axis = null, opts) {
|
|
402
623
|
return reduce(this, AluOp.Add, axis, opts);
|
|
403
624
|
}
|
|
404
625
|
/** Product of the array elements over a given axis. */
|
|
405
|
-
prod(axis, opts) {
|
|
626
|
+
prod(axis = null, opts) {
|
|
406
627
|
return reduce(this, AluOp.Mul, axis, opts);
|
|
407
628
|
}
|
|
408
629
|
/** Compute the average of the array elements along the specified axis. */
|
|
409
|
-
mean(axis, opts) {
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
416
|
-
return result;
|
|
630
|
+
mean(axis = null, opts) {
|
|
631
|
+
axis = normalizeAxis(axis, this.ndim);
|
|
632
|
+
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
633
|
+
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
634
|
+
const result = reduce(this, AluOp.Add, axis, opts);
|
|
635
|
+
return result.mul(1 / n);
|
|
417
636
|
}
|
|
418
637
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
419
638
|
transpose(perm) {
|
|
@@ -445,8 +664,29 @@ var Tracer = class Tracer {
|
|
|
445
664
|
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
446
665
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
447
666
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
667
|
+
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
668
|
+
axis1 = checkAxis(axis1, this.ndim);
|
|
669
|
+
axis2 = checkAxis(axis2, this.ndim);
|
|
448
670
|
if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
|
|
449
|
-
throw new Error("
|
|
671
|
+
if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
|
|
672
|
+
let ar = this;
|
|
673
|
+
if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
|
|
674
|
+
const perm = range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
|
|
675
|
+
ar = ar.transpose(perm);
|
|
676
|
+
}
|
|
677
|
+
const [n, m] = ar.shape.slice(-2);
|
|
678
|
+
const diagSize = Math.min(n, m - offset);
|
|
679
|
+
ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
|
|
680
|
+
const npad = diagSize * (m + 1) - n * m;
|
|
681
|
+
if (npad > 0) ar = pad$1(ar, [...rep(ar.ndim - 1, [0, 0]), [0, npad]]);
|
|
682
|
+
else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
|
|
683
|
+
ar = ar.reshape([
|
|
684
|
+
...ar.shape.slice(0, -1),
|
|
685
|
+
diagSize,
|
|
686
|
+
m + 1
|
|
687
|
+
]);
|
|
688
|
+
ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
|
|
689
|
+
return ar;
|
|
450
690
|
}
|
|
451
691
|
/** Flatten the array without changing its data. */
|
|
452
692
|
flatten() {
|
|
@@ -589,7 +829,7 @@ var ShapedArray = class ShapedArray {
|
|
|
589
829
|
get ndim() {
|
|
590
830
|
return this.shape.length;
|
|
591
831
|
}
|
|
592
|
-
|
|
832
|
+
toString() {
|
|
593
833
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
594
834
|
}
|
|
595
835
|
equals(other) {
|
|
@@ -620,7 +860,7 @@ function fullRaise(trace, val) {
|
|
|
620
860
|
if (Object.is(val._trace.main, trace.main)) return val;
|
|
621
861
|
else if (val._trace.main.level < level) return trace.lift(val);
|
|
622
862
|
else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
|
|
623
|
-
else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
|
|
863
|
+
else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
|
|
624
864
|
}
|
|
625
865
|
var TreeMismatchError = class extends TypeError {
|
|
626
866
|
constructor(where$2, left, right) {
|
|
@@ -869,16 +1109,16 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
869
1109
|
jitCompileCache.set(cacheKey, jp);
|
|
870
1110
|
return jp;
|
|
871
1111
|
}
|
|
872
|
-
/**
 * Rewrite every `GlobalView` node of `exp$2` with a transformed shape tracker.
 *
 * @param exp$2 Expression to rewrite (via its `rewrite` method).
 * @param mapping Called with each view's shape tracker. A falsy return leaves
 *   that node to `rewrite` without a replacement.
 * @param reduceAxis When true, the last dimension of the new tracker is
 *   indexed by `AluVar.ridx` and the others by unraveling `AluVar.gidx`;
 *   otherwise all dimensions are unraveled from `AluVar.gidx`.
 * @throws {Error} If a `GlobalIndex` node is encountered.
 */
function reshapeViews(exp$2, mapping, reduceAxis = false) {
  const rewriteNode = (node) => {
    if (node.op === AluOp.GlobalView) {
      const [gid, st] = node.arg;
      const newSt = mapping(st);
      if (!newSt) return undefined;
      let indices;
      if (reduceAxis) {
        indices = unravelAlu(newSt.shape.slice(0, -1), AluVar.gidx).concat(AluVar.ridx);
      } else {
        indices = unravelAlu(newSt.shape, AluVar.gidx);
      }
      return AluExp.globalView(node.dtype, gid, newSt, indices);
    }
    if (node.op === AluOp.GlobalIndex) {
      throw new Error("internal: reshapeViews() called with GlobalIndex op");
    }
    return undefined;
  };
  return exp$2.rewrite(rewriteNode);
}
|
|
884
1124
|
function broadcastedJit(fn) {
|
|
@@ -925,8 +1165,11 @@ const jitRules = {
|
|
|
925
1165
|
},
|
|
926
1166
|
[Primitive.Sin]: unopJit(AluExp.sin),
|
|
927
1167
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
1168
|
+
[Primitive.Asin]: unopJit(AluExp.asin),
|
|
1169
|
+
[Primitive.Atan]: unopJit(AluExp.atan),
|
|
928
1170
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
929
1171
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1172
|
+
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
930
1173
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
931
1174
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
932
1175
|
[Primitive.Reduce](nargs, [a], [as], { op, axis }) {
|
|
@@ -941,18 +1184,20 @@ const jitRules = {
|
|
|
941
1184
|
const size$1 = prod(newShape);
|
|
942
1185
|
const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
943
1186
|
newShape.push(reductionSize);
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
const [gid, st] = exp$2.arg;
|
|
947
|
-
const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
|
|
948
|
-
const indices = unravelAlu(newShape.slice(0, -1), AluVar.gidx);
|
|
949
|
-
indices.push(AluVar.ridx);
|
|
950
|
-
return AluExp.globalView(exp$2.dtype, gid, newSt, indices);
|
|
951
|
-
}
|
|
952
|
-
});
|
|
1187
|
+
const perm = keptAxes.concat(shiftedAxes);
|
|
1188
|
+
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
953
1189
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
954
1190
|
return new Kernel(nargs, size$1, a, reduction);
|
|
955
1191
|
},
|
|
1192
|
+
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1193
|
+
[Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
|
|
1194
|
+
let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1195
|
+
const size$1 = prod(inShape);
|
|
1196
|
+
stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
|
|
1197
|
+
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1198
|
+
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1199
|
+
return new Kernel(nargs, size$1, a, reduction);
|
|
1200
|
+
},
|
|
956
1201
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
957
1202
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
958
1203
|
const c = k1.exp;
|
|
@@ -962,6 +1207,14 @@ const jitRules = {
|
|
|
962
1207
|
axis: [cs.ndim - 1]
|
|
963
1208
|
});
|
|
964
1209
|
},
|
|
1210
|
+
[Primitive.Conv](nargs, [a, b], [as, bs], params) {
|
|
1211
|
+
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1212
|
+
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1213
|
+
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1214
|
+
as = new ShapedArray(stX.shape, as.dtype);
|
|
1215
|
+
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1216
|
+
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1217
|
+
},
|
|
965
1218
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
966
1219
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
|
|
967
1220
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
@@ -974,8 +1227,20 @@ const jitRules = {
|
|
|
974
1227
|
}),
|
|
975
1228
|
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
976
1229
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
977
|
-
[Primitive.Gather]() {
|
|
978
|
-
|
|
1230
|
+
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1231
|
+
const axisSet = new Set(axis);
|
|
1232
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1233
|
+
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1234
|
+
finalShape.splice(outDim, 0, ...indexShape);
|
|
1235
|
+
const idxAll = unravelAlu(finalShape, AluVar.gidx);
|
|
1236
|
+
const idxNonaxis = [...idxAll];
|
|
1237
|
+
idxNonaxis.splice(outDim, indexShape.length);
|
|
1238
|
+
const src = [...idxNonaxis];
|
|
1239
|
+
for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1240
|
+
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
1241
|
+
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1242
|
+
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1243
|
+
return new Kernel(nargs, prod(finalShape), x.substitute({ gidx: index }));
|
|
979
1244
|
},
|
|
980
1245
|
[Primitive.JitCall]() {
|
|
981
1246
|
throw new Error("internal: JitCall should have been flattened before JIT compilation");
|
|
@@ -994,9 +1259,15 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
994
1259
|
blackNodes.add(v);
|
|
995
1260
|
p1NextBlack.set(v, v);
|
|
996
1261
|
}
|
|
1262
|
+
const reducePrimitives = [
|
|
1263
|
+
Primitive.Reduce,
|
|
1264
|
+
Primitive.Dot,
|
|
1265
|
+
Primitive.Conv,
|
|
1266
|
+
Primitive.PoolTranspose
|
|
1267
|
+
];
|
|
997
1268
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
998
1269
|
const eqn = jaxpr.eqns[i];
|
|
999
|
-
if (eqn.primitive === Primitive.
|
|
1270
|
+
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1000
1271
|
for (const v of eqn.outBinders) {
|
|
1001
1272
|
blackNodes.add(v);
|
|
1002
1273
|
p1NextBlack.set(v, v);
|
|
@@ -1137,7 +1408,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1137
1408
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1138
1409
|
* will be freed when the array is disposed.
|
|
1139
1410
|
*/
|
|
1140
|
-
constructor(source, st, dtype, backend, pending = null) {
|
|
1411
|
+
constructor(source, st, dtype, backend, { pending = null } = {}) {
|
|
1141
1412
|
super(baseArrayTrace);
|
|
1142
1413
|
this.id = Array$1.#nextId++;
|
|
1143
1414
|
this.#dtype = dtype;
|
|
@@ -1146,6 +1417,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1146
1417
|
this.#backend = backend;
|
|
1147
1418
|
this.#rc = 1;
|
|
1148
1419
|
this.#pendingSet = new Set(pending);
|
|
1420
|
+
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1421
|
+
else if (source instanceof AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1149
1422
|
}
|
|
1150
1423
|
/** @ignore */
|
|
1151
1424
|
get aval() {
|
|
@@ -1200,7 +1473,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1200
1473
|
const pending = this.#pending;
|
|
1201
1474
|
for (const exe of pending) exe.updateRc(1);
|
|
1202
1475
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1203
|
-
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
|
|
1476
|
+
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
|
|
1204
1477
|
this.dispose();
|
|
1205
1478
|
return ar;
|
|
1206
1479
|
}
|
|
@@ -1223,7 +1496,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1223
1496
|
const inputs = [];
|
|
1224
1497
|
const src = [...idxNonaxis];
|
|
1225
1498
|
for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1226
|
-
for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#
|
|
1499
|
+
for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#source, ar.#st, idxAxis));
|
|
1227
1500
|
else {
|
|
1228
1501
|
let gid = inputs.indexOf(ar.#source);
|
|
1229
1502
|
if (gid === -1) {
|
|
@@ -1233,7 +1506,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1233
1506
|
src[axis[i]] = AluExp.cast(DType.Int32, AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
|
|
1234
1507
|
}
|
|
1235
1508
|
let exp$2;
|
|
1236
|
-
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#
|
|
1509
|
+
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, src);
|
|
1237
1510
|
else {
|
|
1238
1511
|
let gid = inputs.indexOf(this.#source);
|
|
1239
1512
|
if (gid === -1) {
|
|
@@ -1249,7 +1522,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1249
1522
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1250
1523
|
this.dispose();
|
|
1251
1524
|
for (const ar of indices) ar.dispose();
|
|
1252
|
-
return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
|
|
1525
|
+
return new Array$1(output, ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
|
|
1253
1526
|
}
|
|
1254
1527
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1255
1528
|
#moveAxesDown(axis) {
|
|
@@ -1276,7 +1549,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1276
1549
|
this.#check();
|
|
1277
1550
|
if (this.#source instanceof AluExp) {
|
|
1278
1551
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1279
|
-
return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
|
|
1552
|
+
return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
|
|
1280
1553
|
}
|
|
1281
1554
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1282
1555
|
const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1286,7 +1559,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1286
1559
|
for (const exe of pending) exe.updateRc(1);
|
|
1287
1560
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1288
1561
|
this.dispose();
|
|
1289
|
-
return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
|
|
1562
|
+
return new Array$1(output, ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
|
|
1290
1563
|
}
|
|
1291
1564
|
#binary(op, other) {
|
|
1292
1565
|
const custom = (src) => new AluExp(op, this.#dtype, src);
|
|
@@ -1309,18 +1582,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1309
1582
|
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1310
1583
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1311
1584
|
const newShape = [...arrays[0].shape];
|
|
1312
|
-
if (arrays.every((ar) => ar.#source instanceof AluExp) && reduceAxis
|
|
1585
|
+
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
1313
1586
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1314
1587
|
const exp$4 = custom(arrays.map((ar) => ar.#source));
|
|
1315
|
-
return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
|
|
1588
|
+
return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
|
|
1316
1589
|
}
|
|
1317
1590
|
const exp$3 = custom(arrays.map((ar) => {
|
|
1318
1591
|
const src$1 = ar.#source;
|
|
1319
1592
|
if (ar.#st.contiguous) return src$1;
|
|
1320
|
-
return accessorAluExp(
|
|
1593
|
+
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1321
1594
|
}));
|
|
1322
1595
|
const st = ShapeTracker.fromShape(newShape);
|
|
1323
|
-
return new Array$1(exp$3, st, exp$3.dtype, backend);
|
|
1596
|
+
return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
|
|
1324
1597
|
}
|
|
1325
1598
|
let indices;
|
|
1326
1599
|
if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
|
|
@@ -1330,7 +1603,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1330
1603
|
}
|
|
1331
1604
|
const inputs = [];
|
|
1332
1605
|
const src = [];
|
|
1333
|
-
for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#
|
|
1606
|
+
for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
|
|
1334
1607
|
else {
|
|
1335
1608
|
let gid = inputs.indexOf(ar.#source);
|
|
1336
1609
|
if (gid === -1) {
|
|
@@ -1351,7 +1624,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1351
1624
|
for (const exe of pending) exe.updateRc(1);
|
|
1352
1625
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1353
1626
|
for (const ar of arrays) ar.dispose();
|
|
1354
|
-
return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
|
|
1627
|
+
return new Array$1(output, ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
|
|
1355
1628
|
}
|
|
1356
1629
|
/** Reduce the last dimension of the array by an operation. */
|
|
1357
1630
|
#reduce(op) {
|
|
@@ -1364,7 +1637,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1364
1637
|
const indices = [...unravelAlu(newShape, AluVar.gidx), AluVar.ridx];
|
|
1365
1638
|
let exp$2;
|
|
1366
1639
|
const inputs = [];
|
|
1367
|
-
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#
|
|
1640
|
+
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, indices);
|
|
1368
1641
|
else {
|
|
1369
1642
|
inputs.push(this.#source);
|
|
1370
1643
|
exp$2 = accessorGlobal(this.#dtype, 0, this.#st, indices);
|
|
@@ -1375,7 +1648,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1375
1648
|
for (const exe of pending) exe.updateRc(1);
|
|
1376
1649
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1377
1650
|
this.dispose();
|
|
1378
|
-
return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
|
|
1651
|
+
return new Array$1(output, ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
|
|
1379
1652
|
}
|
|
1380
1653
|
/**
|
|
1381
1654
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1389,7 +1662,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1389
1662
|
this.#check();
|
|
1390
1663
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1391
1664
|
if (this.#source instanceof AluExp) {
|
|
1392
|
-
const exp$2 = accessorAluExp(this.#
|
|
1665
|
+
const exp$2 = accessorAluExp(this.#source, this.#st, indices);
|
|
1393
1666
|
const kernel = new Kernel(0, this.#st.size, exp$2);
|
|
1394
1667
|
const output = this.#backend.malloc(kernel.bytes);
|
|
1395
1668
|
const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
|
|
@@ -1427,42 +1700,54 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1427
1700
|
}
|
|
1428
1701
|
/** Realize the array and return it as data. */
|
|
1429
1702
|
async data() {
|
|
1430
|
-
if (this.#source instanceof AluExp &&
|
|
1703
|
+
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1431
1704
|
this.#realize();
|
|
1432
1705
|
const pending = this.#pending;
|
|
1433
1706
|
if (pending) {
|
|
1434
1707
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1435
1708
|
for (const p of pending) p.submit();
|
|
1436
1709
|
}
|
|
1437
|
-
const byteCount = byteWidth(this.#dtype) *
|
|
1710
|
+
const byteCount = byteWidth(this.#dtype) * this.size;
|
|
1438
1711
|
const buf = await this.#backend.read(this.#source, 0, byteCount);
|
|
1439
1712
|
this.dispose();
|
|
1440
1713
|
return dtypedArray(this.dtype, buf);
|
|
1441
1714
|
}
|
|
1442
|
-
/**
|
|
1443
|
-
|
|
1715
|
+
/**
|
|
1716
|
+
* Wait for this array to finish evaluation.
|
|
1717
|
+
*
|
|
1718
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1719
|
+
* that pending operations are dispatched and fully executed before it
|
|
1720
|
+
* returns.
|
|
1721
|
+
*
|
|
1722
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1723
|
+
* dispatch of operations as well.
|
|
1724
|
+
*
|
|
1725
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1726
|
+
* asynchronously for multiple arrays.
|
|
1727
|
+
*/
|
|
1728
|
+
async blockUntilReady() {
|
|
1444
1729
|
this.#check();
|
|
1445
|
-
if (this.#source instanceof AluExp) return;
|
|
1730
|
+
if (this.#source instanceof AluExp) return this;
|
|
1446
1731
|
const pending = this.#pending;
|
|
1447
1732
|
if (pending) {
|
|
1448
1733
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1449
1734
|
for (const p of pending) p.submit();
|
|
1450
1735
|
}
|
|
1451
1736
|
await this.#backend.read(this.#source, 0, 0);
|
|
1452
|
-
this
|
|
1737
|
+
return this;
|
|
1453
1738
|
}
|
|
1454
1739
|
/**
|
|
1455
1740
|
* Realize the array and return it as data. This is a sync variant and not
|
|
1456
1741
|
* recommended for performance reasons, as it will block rendering.
|
|
1457
1742
|
*/
|
|
1458
1743
|
dataSync() {
|
|
1459
|
-
if (this.#source instanceof AluExp &&
|
|
1744
|
+
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1460
1745
|
this.#realize();
|
|
1461
1746
|
for (const p of this.#pending) {
|
|
1462
1747
|
p.prepareSync();
|
|
1463
1748
|
p.submit();
|
|
1464
1749
|
}
|
|
1465
|
-
const byteCount = byteWidth(this.#dtype) *
|
|
1750
|
+
const byteCount = byteWidth(this.#dtype) * this.size;
|
|
1466
1751
|
const buf = this.#backend.readSync(this.#source, 0, byteCount);
|
|
1467
1752
|
this.dispose();
|
|
1468
1753
|
return dtypedArray(this.dtype, buf);
|
|
@@ -1483,6 +1768,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1483
1768
|
async jsAsync() {
|
|
1484
1769
|
return dataToJs(this.dtype, await this.data(), this.shape);
|
|
1485
1770
|
}
|
|
1771
|
+
/**
|
|
1772
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
1773
|
+
*
|
|
1774
|
+
* Throws an error if the array does not have a single element. The array must
|
|
1775
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
1776
|
+
*/
|
|
1777
|
+
item() {
|
|
1778
|
+
if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
|
|
1779
|
+
return this.dataSync()[0];
|
|
1780
|
+
}
|
|
1486
1781
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1487
1782
|
static _implRules() {
|
|
1488
1783
|
return {
|
|
@@ -1496,7 +1791,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1496
1791
|
return [x.#binary(AluOp.Idiv, y)];
|
|
1497
1792
|
},
|
|
1498
1793
|
[Primitive.Neg]([x]) {
|
|
1499
|
-
return [zerosLike(x).#binary(AluOp.Sub, x)];
|
|
1794
|
+
return [zerosLike$1(x.ref).#binary(AluOp.Sub, x)];
|
|
1500
1795
|
},
|
|
1501
1796
|
[Primitive.Reciprocal]([x]) {
|
|
1502
1797
|
return [x.#unary(AluOp.Reciprocal)];
|
|
@@ -1516,7 +1811,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1516
1811
|
x.#backend.incRef(x.#source);
|
|
1517
1812
|
const pending = x.#pending;
|
|
1518
1813
|
for (const exe of pending) exe.updateRc(1);
|
|
1519
|
-
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
|
|
1814
|
+
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
|
|
1520
1815
|
x.dispose();
|
|
1521
1816
|
return [y];
|
|
1522
1817
|
}
|
|
@@ -1546,12 +1841,21 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1546
1841
|
[Primitive.Cos]([x]) {
|
|
1547
1842
|
return [x.#unary(AluOp.Cos)];
|
|
1548
1843
|
},
|
|
1844
|
+
[Primitive.Asin]([x]) {
|
|
1845
|
+
return [x.#unary(AluOp.Asin)];
|
|
1846
|
+
},
|
|
1847
|
+
[Primitive.Atan]([x]) {
|
|
1848
|
+
return [x.#unary(AluOp.Atan)];
|
|
1849
|
+
},
|
|
1549
1850
|
[Primitive.Exp]([x]) {
|
|
1550
1851
|
return [x.#unary(AluOp.Exp)];
|
|
1551
1852
|
},
|
|
1552
1853
|
[Primitive.Log]([x]) {
|
|
1553
1854
|
return [x.#unary(AluOp.Log)];
|
|
1554
1855
|
},
|
|
1856
|
+
[Primitive.Sqrt]([x]) {
|
|
1857
|
+
return [x.#unary(AluOp.Sqrt)];
|
|
1858
|
+
},
|
|
1555
1859
|
[Primitive.Min]([x, y]) {
|
|
1556
1860
|
return [x.#binary(AluOp.Min, y)];
|
|
1557
1861
|
},
|
|
@@ -1562,9 +1866,24 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1562
1866
|
if (axis.length === 0) return [x];
|
|
1563
1867
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
1564
1868
|
},
|
|
1869
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
1870
|
+
const st = pool(x.#st, window, strides);
|
|
1871
|
+
return [x.#reshape(st)];
|
|
1872
|
+
},
|
|
1873
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
1874
|
+
const n = inShape.length;
|
|
1875
|
+
let st = poolTranspose(x.#st, inShape, window, strides);
|
|
1876
|
+
st = st.reshape([...st.shape.slice(0, n), prod(st.shape.slice(n))]);
|
|
1877
|
+
return [x.#reshape(st).#reduce(AluOp.Add)];
|
|
1878
|
+
},
|
|
1565
1879
|
[Primitive.Dot]([x, y]) {
|
|
1566
1880
|
return [Array$1.#naryCustom("dot", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
|
|
1567
1881
|
},
|
|
1882
|
+
[Primitive.Conv]([x, y], params) {
|
|
1883
|
+
checkConvShape(x.shape, y.shape, params);
|
|
1884
|
+
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
1885
|
+
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
1886
|
+
},
|
|
1568
1887
|
[Primitive.Compare]([x, y], { op }) {
|
|
1569
1888
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1570
1889
|
return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
|
|
@@ -1613,7 +1932,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1613
1932
|
pending.splice(0, 0, ...prevPending);
|
|
1614
1933
|
args.forEach((x) => x.dispose());
|
|
1615
1934
|
return outputs.map((source, i) => {
|
|
1616
|
-
return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
|
|
1935
|
+
return new Array$1(source, ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
|
|
1617
1936
|
});
|
|
1618
1937
|
}
|
|
1619
1938
|
};
|
|
@@ -1629,6 +1948,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1629
1948
|
dtype ??= DType.Float32;
|
|
1630
1949
|
if (![
|
|
1631
1950
|
DType.Float32,
|
|
1951
|
+
DType.Float16,
|
|
1632
1952
|
DType.Int32,
|
|
1633
1953
|
DType.Uint32
|
|
1634
1954
|
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
@@ -1636,6 +1956,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1636
1956
|
dtype ??= DType.Bool;
|
|
1637
1957
|
if (![
|
|
1638
1958
|
DType.Float32,
|
|
1959
|
+
DType.Float16,
|
|
1639
1960
|
DType.Int32,
|
|
1640
1961
|
DType.Uint32,
|
|
1641
1962
|
DType.Bool
|
|
@@ -1649,7 +1970,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1649
1970
|
if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1650
1971
|
if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
|
|
1651
1972
|
return values;
|
|
1652
|
-
} else if (values
|
|
1973
|
+
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1653
1974
|
dtype,
|
|
1654
1975
|
device
|
|
1655
1976
|
});
|
|
@@ -1678,7 +1999,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1678
1999
|
});
|
|
1679
2000
|
} else {
|
|
1680
2001
|
dtype = dtype ?? DType.Float32;
|
|
1681
|
-
const data =
|
|
2002
|
+
const data = dtypedJsArray(dtype, flat);
|
|
1682
2003
|
return arrayFromData(data, shape$1, {
|
|
1683
2004
|
dtype,
|
|
1684
2005
|
device
|
|
@@ -1699,19 +2020,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
|
|
|
1699
2020
|
});
|
|
1700
2021
|
}
|
|
1701
2022
|
const backend = getBackend(device);
|
|
1702
|
-
if (data
|
|
1703
|
-
|
|
1704
|
-
|
|
1705
|
-
|
|
1706
|
-
|
|
1707
|
-
if (
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
|
|
2023
|
+
if (ArrayBuffer.isView(data)) {
|
|
2024
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2025
|
+
if (data instanceof Float32Array) {
|
|
2026
|
+
if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2027
|
+
dtype ??= DType.Float32;
|
|
2028
|
+
} else if (data instanceof Int32Array) {
|
|
2029
|
+
if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2030
|
+
dtype ??= DType.Int32;
|
|
2031
|
+
} else if (data instanceof Uint32Array) {
|
|
2032
|
+
if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2033
|
+
dtype ??= DType.Uint32;
|
|
2034
|
+
} else if (data instanceof Float16Array) {
|
|
2035
|
+
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2036
|
+
dtype ??= DType.Float16;
|
|
2037
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2038
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2039
|
+
return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2040
|
+
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
1715
2041
|
}
|
|
1716
2042
|
function dataToJs(dtype, data, shape$1) {
|
|
1717
2043
|
if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -1738,9 +2064,20 @@ var EvalTrace = class extends Trace {
|
|
|
1738
2064
|
};
|
|
1739
2065
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
1740
2066
|
const implRules = Array$1._implRules();
|
|
1741
|
-
function zerosLike(val) {
|
|
2067
|
+
function zerosLike$1(val, dtype) {
|
|
2068
|
+
const aval = getAval(val);
|
|
2069
|
+
if (val instanceof Tracer) val.dispose();
|
|
2070
|
+
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2071
|
+
}
|
|
2072
|
+
function onesLike$1(val, dtype) {
|
|
2073
|
+
const aval = getAval(val);
|
|
2074
|
+
if (val instanceof Tracer) val.dispose();
|
|
2075
|
+
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2076
|
+
}
|
|
2077
|
+
function fullLike(val, fillValue, dtype) {
|
|
1742
2078
|
const aval = getAval(val);
|
|
1743
|
-
|
|
2079
|
+
if (val instanceof Tracer) val.dispose();
|
|
2080
|
+
return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
|
|
1744
2081
|
}
|
|
1745
2082
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1746
2083
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -1762,6 +2099,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
1762
2099
|
if (typeof fillValue === "number") {
|
|
1763
2100
|
dtype = dtype ?? DType.Float32;
|
|
1764
2101
|
source = AluExp.const(dtype, fillValue);
|
|
2102
|
+
} else if (typeof fillValue === "bigint") {
|
|
2103
|
+
dtype = dtype ?? DType.Int32;
|
|
2104
|
+
source = AluExp.const(dtype, Number(fillValue));
|
|
1765
2105
|
} else if (typeof fillValue === "boolean") {
|
|
1766
2106
|
dtype = dtype ?? DType.Bool;
|
|
1767
2107
|
source = AluExp.const(dtype, fillValue ? 1 : 0);
|
|
@@ -1792,7 +2132,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
1792
2132
|
const exp$2 = AluExp.cmplt(AluExp.mod(AluVar.idx, AluExp.i32(numCols + 1)), AluExp.i32(1));
|
|
1793
2133
|
return new Array$1(AluExp.cast(dtype, exp$2), ShapeTracker.fromShape([numRows, numCols]), dtype, getBackend(device));
|
|
1794
2134
|
}
|
|
1795
|
-
/** Return the identity
|
|
2135
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
1796
2136
|
function identity$1(n, { dtype, device } = {}) {
|
|
1797
2137
|
return eye(n, n, {
|
|
1798
2138
|
dtype,
|
|
@@ -1859,7 +2199,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
1859
2199
|
const st = ShapeTracker.fromShape([num]);
|
|
1860
2200
|
return new Array$1(exp$2, st, dtype, getBackend(device));
|
|
1861
2201
|
}
|
|
1862
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
1863
2202
|
function aluCompare(a, b, op) {
|
|
1864
2203
|
switch (op) {
|
|
1865
2204
|
case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
|
|
@@ -1901,7 +2240,7 @@ function generalBroadcast(a, b) {
|
|
|
1901
2240
|
}
|
|
1902
2241
|
|
|
1903
2242
|
//#endregion
|
|
1904
|
-
//#region node_modules/.pnpm/@oxc-project+runtime@0.
|
|
2243
|
+
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
|
|
1905
2244
|
function _usingCtx() {
|
|
1906
2245
|
var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
|
|
1907
2246
|
var n$1 = Error();
|
|
@@ -1969,7 +2308,7 @@ var Var = class Var {
|
|
|
1969
2308
|
this.aval = aval;
|
|
1970
2309
|
}
|
|
1971
2310
|
toString() {
|
|
1972
|
-
return `Var(${this.id}):${this.aval.
|
|
2311
|
+
return `Var(${this.id}):${this.aval.toString()}`;
|
|
1973
2312
|
}
|
|
1974
2313
|
};
|
|
1975
2314
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
@@ -2009,7 +2348,7 @@ var VarPrinter = class {
|
|
|
2009
2348
|
return name;
|
|
2010
2349
|
}
|
|
2011
2350
|
nameType(v) {
|
|
2012
|
-
return `${this.name(v)}:${v.aval.
|
|
2351
|
+
return `${this.name(v)}:${v.aval.toString()}`;
|
|
2013
2352
|
}
|
|
2014
2353
|
};
|
|
2015
2354
|
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
@@ -2069,16 +2408,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2069
2408
|
varIds.set(v, FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
|
|
2070
2409
|
return id;
|
|
2071
2410
|
};
|
|
2072
|
-
hasher.update(this.inBinders.length
|
|
2073
|
-
|
|
2074
|
-
|
|
2075
|
-
|
|
2076
|
-
|
|
2077
|
-
|
|
2078
|
-
eqn.
|
|
2079
|
-
|
|
2080
|
-
|
|
2081
|
-
|
|
2411
|
+
hasher.update(this.inBinders.length);
|
|
2412
|
+
for (const x of this.inBinders) hasher.update(vi(x));
|
|
2413
|
+
hasher.update(this.eqns.length);
|
|
2414
|
+
for (const eqn of this.eqns) {
|
|
2415
|
+
hasher.update(eqn.primitive);
|
|
2416
|
+
hasher.update(eqn.inputs.length);
|
|
2417
|
+
for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2418
|
+
hasher.update(JSON.stringify(eqn.params));
|
|
2419
|
+
hasher.update(eqn.outBinders.length);
|
|
2420
|
+
for (const x of eqn.outBinders) hasher.update(vi(x));
|
|
2421
|
+
}
|
|
2422
|
+
hasher.update(this.outs.length);
|
|
2423
|
+
for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2082
2424
|
return this.#hash = hasher.value;
|
|
2083
2425
|
}
|
|
2084
2426
|
hash(state) {
|
|
@@ -2115,7 +2457,7 @@ var Jaxpr = class Jaxpr {
|
|
|
2115
2457
|
const c = eqn.outBinders[0];
|
|
2116
2458
|
if (atomIsLit(b, 1)) context.set(c, a);
|
|
2117
2459
|
else newEqns.push(eqn);
|
|
2118
|
-
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2460
|
+
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2119
2461
|
else newEqns.push(eqn);
|
|
2120
2462
|
}
|
|
2121
2463
|
const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
|
|
@@ -2164,8 +2506,8 @@ var JaxprType = class {
|
|
|
2164
2506
|
this.outTypes = outTypes;
|
|
2165
2507
|
}
|
|
2166
2508
|
toString() {
|
|
2167
|
-
const inTypes = this.inTypes.map((aval) => aval.
|
|
2168
|
-
const outTypes = this.outTypes.map((aval) => aval.
|
|
2509
|
+
const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
|
|
2510
|
+
const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
|
|
2169
2511
|
return `(${inTypes}) -> (${outTypes})`;
|
|
2170
2512
|
}
|
|
2171
2513
|
};
|
|
@@ -2244,7 +2586,7 @@ var JaxprTracer = class extends Tracer {
|
|
|
2244
2586
|
this.aval = aval;
|
|
2245
2587
|
}
|
|
2246
2588
|
toString() {
|
|
2247
|
-
return `JaxprTracer(${this.aval.
|
|
2589
|
+
return `JaxprTracer(${this.aval.toString()})`;
|
|
2248
2590
|
}
|
|
2249
2591
|
get ref() {
|
|
2250
2592
|
return this;
|
|
@@ -2381,8 +2723,11 @@ const abstractEvalRules = {
|
|
|
2381
2723
|
},
|
|
2382
2724
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2383
2725
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2726
|
+
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
2727
|
+
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2384
2728
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2385
2729
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2730
|
+
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2386
2731
|
[Primitive.Min]: binopAbstractEval,
|
|
2387
2732
|
[Primitive.Max]: binopAbstractEval,
|
|
2388
2733
|
[Primitive.Reduce]([x], { axis }) {
|
|
@@ -2390,6 +2735,15 @@ const abstractEvalRules = {
|
|
|
2390
2735
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2391
2736
|
return [new ShapedArray(newShape, x.dtype)];
|
|
2392
2737
|
},
|
|
2738
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
2739
|
+
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2740
|
+
return [new ShapedArray(shape$1, x.dtype)];
|
|
2741
|
+
},
|
|
2742
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2743
|
+
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2744
|
+
if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2745
|
+
return [new ShapedArray(inShape, x.dtype)];
|
|
2746
|
+
},
|
|
2393
2747
|
[Primitive.Dot]([x, y]) {
|
|
2394
2748
|
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2395
2749
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
@@ -2397,6 +2751,11 @@ const abstractEvalRules = {
|
|
|
2397
2751
|
shape$1.splice(-1, 1);
|
|
2398
2752
|
return [new ShapedArray(shape$1, x.dtype)];
|
|
2399
2753
|
},
|
|
2754
|
+
[Primitive.Conv]([lhs, rhs], params) {
|
|
2755
|
+
if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
|
|
2756
|
+
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2757
|
+
return [new ShapedArray(shape$1, lhs.dtype)];
|
|
2758
|
+
},
|
|
2400
2759
|
[Primitive.Compare]: compareAbstractEval,
|
|
2401
2760
|
[Primitive.Where]([cond, x, y]) {
|
|
2402
2761
|
if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
@@ -2444,15 +2803,34 @@ const abstractEvalRules = {
|
|
|
2444
2803
|
return outTypes;
|
|
2445
2804
|
}
|
|
2446
2805
|
};
|
|
2447
|
-
function
|
|
2806
|
+
function splitIdx(values, argnums) {
|
|
2807
|
+
const a = [];
|
|
2808
|
+
const b = [];
|
|
2809
|
+
for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
|
|
2810
|
+
else b.push(values[i]);
|
|
2811
|
+
return [a, b];
|
|
2812
|
+
}
|
|
2813
|
+
function joinIdx(n, a, b, argnums) {
|
|
2814
|
+
const result = [];
|
|
2815
|
+
let ai = 0;
|
|
2816
|
+
let bi = 0;
|
|
2817
|
+
for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
|
|
2818
|
+
else result.push(b[bi++]);
|
|
2819
|
+
return result;
|
|
2820
|
+
}
|
|
2821
|
+
function makeJaxpr$1(f, opts) {
|
|
2448
2822
|
return (...argsIn) => {
|
|
2449
2823
|
try {
|
|
2450
2824
|
var _usingCtx$1 = _usingCtx();
|
|
2451
|
-
const
|
|
2452
|
-
const [
|
|
2825
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2826
|
+
const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
|
|
2827
|
+
const [avalsIn, inTree] = flatten(shapedArgs);
|
|
2828
|
+
const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
|
|
2829
|
+
return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
|
|
2830
|
+
}, inTree);
|
|
2453
2831
|
const builder = new JaxprBuilder();
|
|
2454
2832
|
const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
|
|
2455
|
-
|
|
2833
|
+
_usingCtx$1.u(newDynamic(main));
|
|
2456
2834
|
const trace = new JaxprTrace(main);
|
|
2457
2835
|
const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
2458
2836
|
const outs = fFlat(...tracersIn);
|
|
@@ -2471,20 +2849,27 @@ function makeJaxpr$1(f) {
|
|
|
2471
2849
|
}
|
|
2472
2850
|
};
|
|
2473
2851
|
}
|
|
2474
|
-
function jit$1(f) {
|
|
2852
|
+
function jit$1(f, opts) {
|
|
2475
2853
|
const cache = /* @__PURE__ */ new Map();
|
|
2476
|
-
|
|
2477
|
-
|
|
2854
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2855
|
+
const result = ((...args) => {
|
|
2856
|
+
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2857
|
+
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2478
2858
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
2479
2859
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
2480
|
-
const
|
|
2481
|
-
const
|
|
2860
|
+
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
2861
|
+
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2862
|
+
const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2482
2863
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2483
2864
|
jaxpr,
|
|
2484
2865
|
numConsts: consts.length
|
|
2485
2866
|
});
|
|
2486
2867
|
return unflatten(outTree, outs);
|
|
2487
2868
|
});
|
|
2869
|
+
result.dispose = () => {
|
|
2870
|
+
for (const { consts } of cache.values()) for (const c of consts) c.dispose();
|
|
2871
|
+
};
|
|
2872
|
+
return result;
|
|
2488
2873
|
}
|
|
2489
2874
|
|
|
2490
2875
|
//#endregion
|
|
@@ -2515,7 +2900,7 @@ var JVPTrace = class extends Trace {
|
|
|
2515
2900
|
return this.lift(pureArray(val));
|
|
2516
2901
|
}
|
|
2517
2902
|
lift(val) {
|
|
2518
|
-
return new JVPTracer(this, val, zerosLike(val));
|
|
2903
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
2519
2904
|
}
|
|
2520
2905
|
processPrimitive(primitive, tracers, params) {
|
|
2521
2906
|
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2533,19 +2918,25 @@ function linearTangentsJvp(primitive) {
|
|
|
2533
2918
|
return [ys, dys];
|
|
2534
2919
|
};
|
|
2535
2920
|
}
|
|
2921
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
2922
|
+
function bilinearTangentsJvp(primitive) {
|
|
2923
|
+
return ([x, y], [dx, dy], params) => {
|
|
2924
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
2925
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
2926
|
+
return [[primal], [tangent]];
|
|
2927
|
+
};
|
|
2928
|
+
}
|
|
2536
2929
|
/** Rule that zeros out any tangents. */
|
|
2537
2930
|
function zeroTangentsJvp(primitive) {
|
|
2538
2931
|
return (primals, tangents, params) => {
|
|
2539
2932
|
for (const t of tangents) t.dispose();
|
|
2540
2933
|
const ys = bind(primitive, primals, params);
|
|
2541
|
-
return [ys, ys.map((y) => zerosLike(y))];
|
|
2934
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
2542
2935
|
};
|
|
2543
2936
|
}
|
|
2544
2937
|
const jvpRules = {
|
|
2545
2938
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
2546
|
-
[Primitive.Mul](
|
|
2547
|
-
return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
|
|
2548
|
-
},
|
|
2939
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
2549
2940
|
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
2550
2941
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
2551
2942
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
@@ -2558,13 +2949,13 @@ const jvpRules = {
|
|
|
2558
2949
|
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2559
2950
|
else {
|
|
2560
2951
|
dx.dispose();
|
|
2561
|
-
return [[cast(x, dtype)], [zerosLike(x)]];
|
|
2952
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2562
2953
|
}
|
|
2563
2954
|
},
|
|
2564
2955
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2565
2956
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2566
2957
|
dx.dispose();
|
|
2567
|
-
return [[bitcast(x, dtype)], [zerosLike(x)]];
|
|
2958
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2568
2959
|
},
|
|
2569
2960
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2570
2961
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2573,6 +2964,14 @@ const jvpRules = {
|
|
|
2573
2964
|
[Primitive.Cos]([x], [dx]) {
|
|
2574
2965
|
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
2575
2966
|
},
|
|
2967
|
+
[Primitive.Asin]([x], [dx]) {
|
|
2968
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
2969
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
2970
|
+
},
|
|
2971
|
+
[Primitive.Atan]([x], [dx]) {
|
|
2972
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
2973
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
2974
|
+
},
|
|
2576
2975
|
[Primitive.Exp]([x], [dx]) {
|
|
2577
2976
|
const z = exp$1(x);
|
|
2578
2977
|
return [[z.ref], [z.mul(dx)]];
|
|
@@ -2580,6 +2979,10 @@ const jvpRules = {
|
|
|
2580
2979
|
[Primitive.Log]([x], [dx]) {
|
|
2581
2980
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
2582
2981
|
},
|
|
2982
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
2983
|
+
const z = sqrt$1(x);
|
|
2984
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
2985
|
+
},
|
|
2583
2986
|
[Primitive.Min]([x, y], [dx, dy]) {
|
|
2584
2987
|
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
2585
2988
|
},
|
|
@@ -2596,13 +2999,14 @@ const jvpRules = {
|
|
|
2596
2999
|
const primal = reduce(x.ref, op, axis);
|
|
2597
3000
|
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
2598
3001
|
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
2599
|
-
const tangent = where$1(notMin,
|
|
3002
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
2600
3003
|
return [[primal], [tangent]];
|
|
2601
3004
|
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
2602
3005
|
},
|
|
2603
|
-
[Primitive.
|
|
2604
|
-
|
|
2605
|
-
|
|
3006
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3007
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3008
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3009
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
2606
3010
|
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
2607
3011
|
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
2608
3012
|
dcond.dispose();
|
|
@@ -2683,7 +3087,10 @@ function mappedAval(batchDim, aval) {
|
|
|
2683
3087
|
/** Move one axis to a different index. */
|
|
2684
3088
|
function moveaxis$1(x, src, dst) {
|
|
2685
3089
|
const t = pureArray(x);
|
|
2686
|
-
|
|
3090
|
+
src = checkAxis(src, t.ndim);
|
|
3091
|
+
dst = checkAxis(dst, t.ndim);
|
|
3092
|
+
if (src === dst) return t;
|
|
3093
|
+
const perm = range(t.ndim);
|
|
2687
3094
|
perm.splice(src, 1);
|
|
2688
3095
|
perm.splice(dst, 0, src);
|
|
2689
3096
|
return transpose$1(t, perm);
|
|
@@ -2776,8 +3183,11 @@ const vmapRules = {
|
|
|
2776
3183
|
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
2777
3184
|
[Primitive.Sin]: unopBatcher(sin$1),
|
|
2778
3185
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3186
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3187
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
2779
3188
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
2780
3189
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3190
|
+
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
2781
3191
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
2782
3192
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
2783
3193
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
@@ -2914,7 +3324,7 @@ var PartialVal = class PartialVal {
|
|
|
2914
3324
|
return this.val !== null;
|
|
2915
3325
|
}
|
|
2916
3326
|
toString() {
|
|
2917
|
-
return this.val ? this.val.toString() : this.aval.
|
|
3327
|
+
return this.val ? this.val.toString() : this.aval.toString();
|
|
2918
3328
|
}
|
|
2919
3329
|
};
|
|
2920
3330
|
function partialEvalFlat(f, pvalsIn) {
|
|
@@ -2960,20 +3370,28 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
2960
3370
|
function linearizeFlat(f, primalsIn) {
|
|
2961
3371
|
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
2962
3372
|
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
2963
|
-
|
|
3373
|
+
const dispose$1 = () => {
|
|
3374
|
+
for (const c of consts) c.dispose();
|
|
3375
|
+
};
|
|
3376
|
+
return [
|
|
3377
|
+
primalsOut,
|
|
3378
|
+
fLin,
|
|
3379
|
+
dispose$1
|
|
3380
|
+
];
|
|
2964
3381
|
}
|
|
2965
3382
|
function linearize$1(f, ...primalsIn) {
|
|
2966
3383
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
2967
3384
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
2968
|
-
const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3385
|
+
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
2969
3386
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
2970
3387
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
2971
|
-
const fLin = (...tangentsIn) => {
|
|
3388
|
+
const fLin = ((...tangentsIn) => {
|
|
2972
3389
|
const [tangentsInFlat, inTree2] = flatten(tangentsIn);
|
|
2973
3390
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
|
|
2974
3391
|
const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
|
|
2975
3392
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
2976
|
-
};
|
|
3393
|
+
});
|
|
3394
|
+
fLin.dispose = dispose$1;
|
|
2977
3395
|
return [primalsOut, fLin];
|
|
2978
3396
|
}
|
|
2979
3397
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -3089,7 +3507,10 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3089
3507
|
avalsOut: jaxpr2.outs.map((x) => x.aval),
|
|
3090
3508
|
tracerRefsOut: []
|
|
3091
3509
|
};
|
|
3092
|
-
const outs2 = jaxpr2.outs.map((x) =>
|
|
3510
|
+
const outs2 = jaxpr2.outs.map((x, i$1) => {
|
|
3511
|
+
if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
|
|
3512
|
+
return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
|
|
3513
|
+
});
|
|
3093
3514
|
recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
|
|
3094
3515
|
let i = 0;
|
|
3095
3516
|
let j = 0;
|
|
@@ -3173,13 +3594,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3173
3594
|
const [consts, constvars] = unzip2(constToVar.entries());
|
|
3174
3595
|
const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
|
|
3175
3596
|
const outVars = tracersOut.map((t) => tracerToVar.get(t));
|
|
3176
|
-
|
|
3597
|
+
let jaxpr = new Jaxpr(inBinders, eqns, outVars);
|
|
3177
3598
|
typecheckJaxpr(jaxpr);
|
|
3178
3599
|
for (const t of consts) t.ref;
|
|
3179
3600
|
for (const t of tracersIn) t.dispose();
|
|
3180
3601
|
for (const t of tracersOut) t.dispose();
|
|
3602
|
+
jaxpr = jaxpr.simplify();
|
|
3603
|
+
if (DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3181
3604
|
return {
|
|
3182
|
-
jaxpr
|
|
3605
|
+
jaxpr,
|
|
3183
3606
|
consts
|
|
3184
3607
|
};
|
|
3185
3608
|
}
|
|
@@ -3288,12 +3711,72 @@ const transposeRules = {
|
|
|
3288
3711
|
if (op === AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
|
|
3289
3712
|
else throw new NonlinearError(Primitive.Reduce);
|
|
3290
3713
|
},
|
|
3714
|
+
[Primitive.Pool]([ct], [x], { window, strides }) {
|
|
3715
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
|
|
3716
|
+
return bind(Primitive.PoolTranspose, [ct], {
|
|
3717
|
+
inShape: x.aval.shape,
|
|
3718
|
+
window,
|
|
3719
|
+
strides
|
|
3720
|
+
});
|
|
3721
|
+
},
|
|
3722
|
+
[Primitive.PoolTranspose]([ct], [x], { window, strides }) {
|
|
3723
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
|
|
3724
|
+
return bind(Primitive.Pool, [ct], {
|
|
3725
|
+
window,
|
|
3726
|
+
strides
|
|
3727
|
+
});
|
|
3728
|
+
},
|
|
3291
3729
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3292
3730
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3293
3731
|
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3294
3732
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3295
3733
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3296
3734
|
},
|
|
3735
|
+
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
3736
|
+
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
3737
|
+
const rev01 = [
|
|
3738
|
+
1,
|
|
3739
|
+
0,
|
|
3740
|
+
...range(2, ct.ndim)
|
|
3741
|
+
];
|
|
3742
|
+
if (lhs instanceof UndefPrimal) {
|
|
3743
|
+
let kernel = rhs;
|
|
3744
|
+
kernel = transpose$1(kernel, rev01);
|
|
3745
|
+
kernel = flip$1(kernel, range(2, kernel.ndim));
|
|
3746
|
+
const result = conv(ct, kernel, {
|
|
3747
|
+
strides: params.lhsDilation,
|
|
3748
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3749
|
+
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3750
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3751
|
+
const padBefore = dilatedKernel - 1 - pl;
|
|
3752
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3753
|
+
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
3754
|
+
return [padBefore, padAfter];
|
|
3755
|
+
}),
|
|
3756
|
+
lhsDilation: params.strides,
|
|
3757
|
+
rhsDilation: params.rhsDilation
|
|
3758
|
+
});
|
|
3759
|
+
return [result, null];
|
|
3760
|
+
} else {
|
|
3761
|
+
const newLhs = transpose$1(lhs, rev01);
|
|
3762
|
+
const newRhs = transpose$1(ct, rev01);
|
|
3763
|
+
let result = conv(newLhs, newRhs, {
|
|
3764
|
+
strides: params.rhsDilation,
|
|
3765
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3766
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3767
|
+
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3768
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3769
|
+
const padFromLhs = dilatedCt - dilatedLhs;
|
|
3770
|
+
const padFromRhs = dilatedKernel - pl - 1;
|
|
3771
|
+
return [pl, padFromLhs + padFromRhs];
|
|
3772
|
+
}),
|
|
3773
|
+
lhsDilation: params.lhsDilation,
|
|
3774
|
+
rhsDilation: params.strides
|
|
3775
|
+
});
|
|
3776
|
+
result = transpose$1(result, rev01);
|
|
3777
|
+
return [null, result];
|
|
3778
|
+
}
|
|
3779
|
+
},
|
|
3297
3780
|
[Primitive.Where]([ct], [cond, x, y]) {
|
|
3298
3781
|
const cts = [
|
|
3299
3782
|
null,
|
|
@@ -3385,20 +3868,28 @@ function vjpFlat(f, primalsIn) {
|
|
|
3385
3868
|
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
3386
3869
|
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
3387
3870
|
};
|
|
3388
|
-
|
|
3871
|
+
const dispose$1 = () => {
|
|
3872
|
+
for (const c of consts) c.dispose();
|
|
3873
|
+
};
|
|
3874
|
+
return [
|
|
3875
|
+
primalsOut,
|
|
3876
|
+
fVjp,
|
|
3877
|
+
dispose$1
|
|
3878
|
+
];
|
|
3389
3879
|
}
|
|
3390
3880
|
function vjp$1(f, ...primalsIn) {
|
|
3391
3881
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3392
3882
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3393
|
-
const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3883
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3394
3884
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
3395
3885
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3396
|
-
const fVjp = (cotangentsOut) => {
|
|
3886
|
+
const fVjp = ((cotangentsOut) => {
|
|
3397
3887
|
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
3398
3888
|
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
3399
3889
|
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
3400
3890
|
return unflatten(inTree, cotangentsInFlat);
|
|
3401
|
-
};
|
|
3891
|
+
});
|
|
3892
|
+
fVjp.dispose = dispose$1;
|
|
3402
3893
|
return [primalsOut, fVjp];
|
|
3403
3894
|
}
|
|
3404
3895
|
function grad$1(f) {
|
|
@@ -3414,9 +3905,10 @@ function valueAndGrad$1(f) {
|
|
|
3414
3905
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
3415
3906
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3416
3907
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3417
|
-
if (y.dtype
|
|
3418
|
-
const [ct, ...rest] = fVjp(
|
|
3419
|
-
for (const r of rest)
|
|
3908
|
+
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3909
|
+
const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
|
|
3910
|
+
for (const r of rest) dispose(r);
|
|
3911
|
+
fVjp.dispose();
|
|
3420
3912
|
return [y, ct];
|
|
3421
3913
|
};
|
|
3422
3914
|
}
|
|
@@ -3424,11 +3916,84 @@ function jacrev$1(f) {
|
|
|
3424
3916
|
return function jacobianReverse(x) {
|
|
3425
3917
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
3426
3918
|
const [size$1] = x.shape;
|
|
3427
|
-
const pullback = (ct) =>
|
|
3919
|
+
const pullback = (ct) => {
|
|
3920
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
3921
|
+
y.dispose();
|
|
3922
|
+
const [ret] = fVjp(ct);
|
|
3923
|
+
fVjp.dispose();
|
|
3924
|
+
return ret;
|
|
3925
|
+
};
|
|
3428
3926
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3429
3927
|
};
|
|
3430
3928
|
}
|
|
3431
3929
|
|
|
3930
|
+
//#endregion
|
|
3931
|
+
//#region src/lax.ts
|
|
3932
|
+
var lax_exports = {};
|
|
3933
|
+
__export(lax_exports, {
|
|
3934
|
+
conv: () => conv$1,
|
|
3935
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
3936
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
3937
|
+
reduceWindow: () => reduceWindow
|
|
3938
|
+
});
|
|
3939
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
3940
|
+
const padType = padding.toUpperCase();
|
|
3941
|
+
switch (padType) {
|
|
3942
|
+
case "VALID": return rep(inShape.length, [0, 0]);
|
|
3943
|
+
case "SAME":
|
|
3944
|
+
case "SAME_LOWER": {
|
|
3945
|
+
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
3946
|
+
const padSizes = zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
3947
|
+
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
3948
|
+
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
3949
|
+
}
|
|
3950
|
+
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
3951
|
+
}
|
|
3952
|
+
}
|
|
3953
|
+
/**
|
|
3954
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
3955
|
+
*
|
|
3956
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
3957
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
3958
|
+
*
|
|
3959
|
+
* Grouped convolutions are not supported right now.
|
|
3960
|
+
*/
|
|
3961
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
3962
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
3963
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
3964
|
+
if (typeof padding === "string") {
|
|
3965
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
3966
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? rep(rhs.ndim - 2, 1), padding);
|
|
3967
|
+
}
|
|
3968
|
+
return conv(lhs, rhs, {
|
|
3969
|
+
strides: windowStrides,
|
|
3970
|
+
padding,
|
|
3971
|
+
lhsDilation,
|
|
3972
|
+
rhsDilation
|
|
3973
|
+
});
|
|
3974
|
+
}
|
|
3975
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
3976
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
3977
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
3978
|
+
lhsDilation,
|
|
3979
|
+
rhsDilation
|
|
3980
|
+
});
|
|
3981
|
+
}
|
|
3982
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
3983
|
+
function conv$1(lhs, rhs, windowStrides, padding) {
|
|
3984
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
3985
|
+
}
|
|
3986
|
+
/** Reduce a computation over padded windows. */
|
|
3987
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
3988
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
3989
|
+
if (!windowStrides) windowStrides = rep(windowDimensions.length, 1);
|
|
3990
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
3991
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
3992
|
+
window: windowDimensions,
|
|
3993
|
+
strides: windowStrides
|
|
3994
|
+
}));
|
|
3995
|
+
}
|
|
3996
|
+
|
|
3432
3997
|
//#endregion
|
|
3433
3998
|
//#region src/numpy.ts
|
|
3434
3999
|
var numpy_exports = {};
|
|
@@ -3437,19 +4002,38 @@ __export(numpy_exports, {
|
|
|
3437
4002
|
DType: () => DType,
|
|
3438
4003
|
abs: () => abs,
|
|
3439
4004
|
absolute: () => absolute,
|
|
4005
|
+
acos: () => acos,
|
|
4006
|
+
acosh: () => acosh,
|
|
3440
4007
|
add: () => add,
|
|
3441
4008
|
allclose: () => allclose,
|
|
3442
4009
|
arange: () => arange,
|
|
4010
|
+
arccos: () => arccos,
|
|
4011
|
+
arccosh: () => arccosh,
|
|
4012
|
+
arcsinh: () => arcsinh,
|
|
4013
|
+
arctan: () => arctan,
|
|
4014
|
+
arctan2: () => arctan2,
|
|
4015
|
+
arctanh: () => arctanh,
|
|
3443
4016
|
argmax: () => argmax,
|
|
3444
4017
|
argmin: () => argmin,
|
|
3445
4018
|
array: () => array,
|
|
4019
|
+
asin: () => asin,
|
|
4020
|
+
asinh: () => asinh,
|
|
3446
4021
|
astype: () => astype,
|
|
4022
|
+
atan: () => atan,
|
|
4023
|
+
atan2: () => atan2,
|
|
4024
|
+
atanh: () => atanh,
|
|
3447
4025
|
bool: () => bool,
|
|
4026
|
+
broadcastArrays: () => broadcastArrays,
|
|
4027
|
+
broadcastShapes: () => broadcastShapes,
|
|
4028
|
+
broadcastTo: () => broadcastTo,
|
|
4029
|
+
cbrt: () => cbrt,
|
|
3448
4030
|
clip: () => clip,
|
|
3449
4031
|
columnStack: () => columnStack,
|
|
3450
|
-
complex64: () => complex64,
|
|
3451
4032
|
concatenate: () => concatenate,
|
|
3452
4033
|
cos: () => cos,
|
|
4034
|
+
cosh: () => cosh,
|
|
4035
|
+
deg2rad: () => deg2rad,
|
|
4036
|
+
degrees: () => degrees,
|
|
3453
4037
|
diag: () => diag,
|
|
3454
4038
|
diagonal: () => diagonal,
|
|
3455
4039
|
divide: () => divide,
|
|
@@ -3460,23 +4044,29 @@ __export(numpy_exports, {
|
|
|
3460
4044
|
eulerGamma: () => eulerGamma,
|
|
3461
4045
|
exp: () => exp,
|
|
3462
4046
|
exp2: () => exp2,
|
|
4047
|
+
expm1: () => expm1,
|
|
3463
4048
|
eye: () => eye,
|
|
3464
4049
|
flip: () => flip,
|
|
3465
4050
|
fliplr: () => fliplr,
|
|
3466
4051
|
flipud: () => flipud,
|
|
4052
|
+
float16: () => float16,
|
|
3467
4053
|
float32: () => float32,
|
|
3468
4054
|
full: () => full,
|
|
4055
|
+
fullLike: () => fullLike$1,
|
|
3469
4056
|
greater: () => greater,
|
|
3470
4057
|
greaterEqual: () => greaterEqual,
|
|
3471
4058
|
hstack: () => hstack,
|
|
4059
|
+
hypot: () => hypot,
|
|
3472
4060
|
identity: () => identity$1,
|
|
3473
4061
|
inf: () => inf,
|
|
4062
|
+
inner: () => inner,
|
|
3474
4063
|
int32: () => int32,
|
|
3475
4064
|
less: () => less,
|
|
3476
4065
|
lessEqual: () => lessEqual,
|
|
3477
4066
|
linspace: () => linspace,
|
|
3478
4067
|
log: () => log,
|
|
3479
4068
|
log10: () => log10,
|
|
4069
|
+
log1p: () => log1p,
|
|
3480
4070
|
log2: () => log2,
|
|
3481
4071
|
matmul: () => matmul,
|
|
3482
4072
|
max: () => max,
|
|
@@ -3492,36 +4082,55 @@ __export(numpy_exports, {
|
|
|
3492
4082
|
negative: () => negative,
|
|
3493
4083
|
notEqual: () => notEqual,
|
|
3494
4084
|
ones: () => ones,
|
|
4085
|
+
onesLike: () => onesLike,
|
|
4086
|
+
outer: () => outer,
|
|
3495
4087
|
pad: () => pad,
|
|
3496
4088
|
permuteDims: () => permuteDims,
|
|
3497
4089
|
pi: () => pi,
|
|
4090
|
+
pow: () => pow,
|
|
4091
|
+
power: () => power,
|
|
3498
4092
|
prod: () => prod$1,
|
|
4093
|
+
promoteTypes: () => promoteTypes,
|
|
4094
|
+
rad2deg: () => rad2deg,
|
|
4095
|
+
radians: () => radians,
|
|
3499
4096
|
ravel: () => ravel,
|
|
3500
4097
|
reciprocal: () => reciprocal,
|
|
4098
|
+
repeat: () => repeat,
|
|
3501
4099
|
reshape: () => reshape,
|
|
3502
|
-
scalar: () => scalar,
|
|
3503
4100
|
shape: () => shape,
|
|
4101
|
+
sign: () => sign,
|
|
3504
4102
|
sin: () => sin,
|
|
4103
|
+
sinh: () => sinh,
|
|
3505
4104
|
size: () => size,
|
|
4105
|
+
sqrt: () => sqrt,
|
|
3506
4106
|
square: () => square,
|
|
3507
4107
|
stack: () => stack,
|
|
4108
|
+
std: () => std,
|
|
4109
|
+
subtract: () => subtract,
|
|
3508
4110
|
sum: () => sum,
|
|
3509
4111
|
tan: () => tan,
|
|
4112
|
+
tanh: () => tanh,
|
|
4113
|
+
tile: () => tile,
|
|
3510
4114
|
transpose: () => transpose,
|
|
4115
|
+
tri: () => tri,
|
|
4116
|
+
tril: () => tril,
|
|
4117
|
+
triu: () => triu,
|
|
3511
4118
|
trueDivide: () => trueDivide,
|
|
3512
4119
|
trunc: () => trunc,
|
|
3513
4120
|
uint32: () => uint32,
|
|
4121
|
+
var_: () => var_,
|
|
3514
4122
|
vdot: () => vdot,
|
|
3515
4123
|
vecdot: () => vecdot,
|
|
3516
4124
|
vstack: () => vstack,
|
|
3517
4125
|
where: () => where,
|
|
3518
|
-
zeros: () => zeros
|
|
4126
|
+
zeros: () => zeros,
|
|
4127
|
+
zerosLike: () => zerosLike
|
|
3519
4128
|
});
|
|
3520
4129
|
const float32 = DType.Float32;
|
|
3521
4130
|
const int32 = DType.Int32;
|
|
3522
4131
|
const uint32 = DType.Uint32;
|
|
3523
4132
|
const bool = DType.Bool;
|
|
3524
|
-
const
|
|
4133
|
+
const float16 = DType.Float16;
|
|
3525
4134
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
3526
4135
|
const e = Math.E;
|
|
3527
4136
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -3532,52 +4141,66 @@ const inf = Number.POSITIVE_INFINITY;
|
|
|
3532
4141
|
const nan = NaN;
|
|
3533
4142
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
3534
4143
|
const pi = Math.PI;
|
|
3535
|
-
/** Element-wise addition, with broadcasting. */
|
|
4144
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
3536
4145
|
const add = add$1;
|
|
3537
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
4146
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
3538
4147
|
const multiply = mul;
|
|
3539
|
-
/** Numerical negative of every element of an array. */
|
|
4148
|
+
/** @function Numerical negative of every element of an array. */
|
|
3540
4149
|
const negative = neg;
|
|
3541
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4150
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
3542
4151
|
const reciprocal = reciprocal$1;
|
|
3543
|
-
/** Element-wise sine function (takes radians). */
|
|
4152
|
+
/** @function Element-wise sine function (takes radians). */
|
|
3544
4153
|
const sin = sin$1;
|
|
3545
|
-
/** Element-wise cosine function (takes radians). */
|
|
4154
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
3546
4155
|
const cos = cos$1;
|
|
3547
|
-
/**
|
|
4156
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
4157
|
+
const asin = asin$1;
|
|
4158
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
4159
|
+
const atan = atan$1;
|
|
4160
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
3548
4161
|
const exp = exp$1;
|
|
3549
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
4162
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
3550
4163
|
const log = log$1;
|
|
3551
|
-
/**
|
|
4164
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
4165
|
+
const sqrt = sqrt$1;
|
|
4166
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
3552
4167
|
const minimum = min$1;
|
|
3553
|
-
/** Return element-wise maximum of the input arrays. */
|
|
4168
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
3554
4169
|
const maximum = max$1;
|
|
3555
|
-
/** Compare two arrays element-wise. */
|
|
4170
|
+
/** @function Compare two arrays element-wise. */
|
|
3556
4171
|
const greater = greater$1;
|
|
3557
|
-
/** Compare two arrays element-wise. */
|
|
4172
|
+
/** @function Compare two arrays element-wise. */
|
|
3558
4173
|
const less = less$1;
|
|
3559
|
-
/** Compare two arrays element-wise. */
|
|
4174
|
+
/** @function Compare two arrays element-wise. */
|
|
3560
4175
|
const equal = equal$1;
|
|
3561
|
-
/** Compare two arrays element-wise. */
|
|
4176
|
+
/** @function Compare two arrays element-wise. */
|
|
3562
4177
|
const notEqual = notEqual$1;
|
|
3563
|
-
/** Compare two arrays element-wise. */
|
|
4178
|
+
/** @function Compare two arrays element-wise. */
|
|
3564
4179
|
const greaterEqual = greaterEqual$1;
|
|
3565
|
-
/** Compare two arrays element-wise. */
|
|
4180
|
+
/** @function Compare two arrays element-wise. */
|
|
3566
4181
|
const lessEqual = lessEqual$1;
|
|
3567
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4182
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
3568
4183
|
const where = where$1;
|
|
3569
|
-
/**
|
|
4184
|
+
/**
|
|
4185
|
+
* @function
|
|
4186
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
4187
|
+
*/
|
|
3570
4188
|
const transpose = transpose$1;
|
|
3571
4189
|
/**
|
|
4190
|
+
* @function
|
|
3572
4191
|
* Give a new shape to an array without changing its data.
|
|
3573
4192
|
*
|
|
3574
4193
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
3575
4194
|
* length of the array and remaining dimensions.
|
|
3576
4195
|
*/
|
|
3577
4196
|
const reshape = reshape$1;
|
|
3578
|
-
/**
|
|
4197
|
+
/**
|
|
4198
|
+
* @function
|
|
4199
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
4200
|
+
*/
|
|
3579
4201
|
const moveaxis = moveaxis$1;
|
|
3580
4202
|
/**
|
|
4203
|
+
* @function
|
|
3581
4204
|
* Add padding (zeros) to an array.
|
|
3582
4205
|
*
|
|
3583
4206
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -3585,11 +4208,29 @@ const moveaxis = moveaxis$1;
|
|
|
3585
4208
|
* pair specifies the padding for its corresponding axis.
|
|
3586
4209
|
*/
|
|
3587
4210
|
const pad = pad$1;
|
|
3588
|
-
/**
|
|
4211
|
+
/**
|
|
4212
|
+
* @function
|
|
4213
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
4214
|
+
*/
|
|
3589
4215
|
const ndim = ndim$1;
|
|
3590
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
4216
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
3591
4217
|
const shape = getShape;
|
|
3592
4218
|
/**
|
|
4219
|
+
* @function
|
|
4220
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
4221
|
+
*/
|
|
4222
|
+
const zerosLike = zerosLike$1;
|
|
4223
|
+
/**
|
|
4224
|
+
* @function
|
|
4225
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
4226
|
+
*/
|
|
4227
|
+
const onesLike = onesLike$1;
|
|
4228
|
+
/**
|
|
4229
|
+
* @function
|
|
4230
|
+
* Return a full array with the same shape and type as a given array.
|
|
4231
|
+
*/
|
|
4232
|
+
const fullLike$1 = fullLike;
|
|
4233
|
+
/**
|
|
3593
4234
|
* Return the number of elements in an array, optionally along an axis.
|
|
3594
4235
|
* Does not consume array reference.
|
|
3595
4236
|
*/
|
|
@@ -3602,23 +4243,23 @@ function astype(a, dtype) {
|
|
|
3602
4243
|
return fudgeArray(a).astype(dtype);
|
|
3603
4244
|
}
|
|
3604
4245
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
3605
|
-
function sum(a, axis, opts) {
|
|
4246
|
+
function sum(a, axis = null, opts) {
|
|
3606
4247
|
return reduce(a, AluOp.Add, axis, opts);
|
|
3607
4248
|
}
|
|
3608
4249
|
/** Product of the array elements over a given axis. */
|
|
3609
|
-
function prod$1(a, axis, opts) {
|
|
4250
|
+
function prod$1(a, axis = null, opts) {
|
|
3610
4251
|
return reduce(a, AluOp.Mul, axis, opts);
|
|
3611
4252
|
}
|
|
3612
4253
|
/** Return the minimum of array elements along a given axis. */
|
|
3613
|
-
function min(a, axis, opts) {
|
|
4254
|
+
function min(a, axis = null, opts) {
|
|
3614
4255
|
return reduce(a, AluOp.Min, axis, opts);
|
|
3615
4256
|
}
|
|
3616
4257
|
/** Return the maximum of array elements along a given axis. */
|
|
3617
|
-
function max(a, axis, opts) {
|
|
4258
|
+
function max(a, axis = null, opts) {
|
|
3618
4259
|
return reduce(a, AluOp.Max, axis, opts);
|
|
3619
4260
|
}
|
|
3620
4261
|
/** Compute the average of the array elements along the specified axis. */
|
|
3621
|
-
function mean(a, axis, opts) {
|
|
4262
|
+
function mean(a, axis = null, opts) {
|
|
3622
4263
|
return fudgeArray(a).mean(axis, opts);
|
|
3623
4264
|
}
|
|
3624
4265
|
/**
|
|
@@ -3634,18 +4275,12 @@ function argmin(a, axis, opts) {
|
|
|
3634
4275
|
axis = 0;
|
|
3635
4276
|
} else axis = checkAxis(axis, a.ndim);
|
|
3636
4277
|
const shape$1 = a.shape;
|
|
3637
|
-
const isMax = equal(a, min(a.ref, axis, {
|
|
4278
|
+
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
3638
4279
|
const length = scalar(shape$1[axis], {
|
|
3639
4280
|
dtype: int32,
|
|
3640
4281
|
device: a.device
|
|
3641
4282
|
});
|
|
3642
|
-
const idx =
|
|
3643
|
-
dtype: int32,
|
|
3644
|
-
device: a.device
|
|
3645
|
-
}), scalar(0, {
|
|
3646
|
-
dtype: int32,
|
|
3647
|
-
device: a.device
|
|
3648
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4283
|
+
const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3649
4284
|
dtype: int32,
|
|
3650
4285
|
device: a.device
|
|
3651
4286
|
}).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
|
|
@@ -3664,35 +4299,21 @@ function argmax(a, axis, opts) {
|
|
|
3664
4299
|
axis = 0;
|
|
3665
4300
|
} else axis = checkAxis(axis, a.ndim);
|
|
3666
4301
|
const shape$1 = a.shape;
|
|
3667
|
-
const isMax = equal(a, max(a.ref, axis, {
|
|
4302
|
+
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
3668
4303
|
const length = scalar(shape$1[axis], {
|
|
3669
4304
|
dtype: int32,
|
|
3670
4305
|
device: a.device
|
|
3671
4306
|
});
|
|
3672
|
-
const idx =
|
|
3673
|
-
dtype: int32,
|
|
3674
|
-
device: a.device
|
|
3675
|
-
}), scalar(0, {
|
|
3676
|
-
dtype: int32,
|
|
3677
|
-
device: a.device
|
|
3678
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4307
|
+
const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3679
4308
|
dtype: int32,
|
|
3680
4309
|
device: a.device
|
|
3681
4310
|
}).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
|
|
3682
4311
|
return length.sub(max(idx, axis, opts));
|
|
3683
4312
|
}
|
|
3684
4313
|
/** Reverse the elements in an array along the given axes. */
|
|
3685
|
-
function flip(x, axis) {
|
|
4314
|
+
function flip(x, axis = null) {
|
|
3686
4315
|
const nd = ndim(x);
|
|
3687
|
-
|
|
3688
|
-
else if (typeof axis === "number") axis = [axis];
|
|
3689
|
-
const seen = /* @__PURE__ */ new Set();
|
|
3690
|
-
for (let i = 0; i < axis.length; i++) {
|
|
3691
|
-
if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
|
|
3692
|
-
if (axis[i] < 0) axis[i] += nd;
|
|
3693
|
-
if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
|
|
3694
|
-
seen.add(axis[i]);
|
|
3695
|
-
}
|
|
4316
|
+
axis = normalizeAxis(axis, nd);
|
|
3696
4317
|
return flip$1(x, axis);
|
|
3697
4318
|
}
|
|
3698
4319
|
/**
|
|
@@ -3798,18 +4419,88 @@ function flipud(x) {
|
|
|
3798
4419
|
function fliplr(x) {
|
|
3799
4420
|
return flip(x, 1);
|
|
3800
4421
|
}
|
|
4422
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
3801
4423
|
const permuteDims = transpose;
|
|
3802
4424
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
3803
4425
|
function ravel(a) {
|
|
3804
4426
|
return fudgeArray(a).ravel();
|
|
3805
4427
|
}
|
|
3806
4428
|
/**
|
|
4429
|
+
* Repeat each element of an array after itself.
|
|
4430
|
+
*
|
|
4431
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
4432
|
+
* output array.
|
|
4433
|
+
*/
|
|
4434
|
+
function repeat(a, repeats, axis) {
|
|
4435
|
+
if (!Number.isInteger(repeats) || repeats < 0) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
|
|
4436
|
+
a = fudgeArray(a);
|
|
4437
|
+
if (axis === void 0) {
|
|
4438
|
+
a = ravel(a);
|
|
4439
|
+
axis = 0;
|
|
4440
|
+
}
|
|
4441
|
+
axis = checkAxis(axis, a.ndim);
|
|
4442
|
+
if (repeats === 1) return a;
|
|
4443
|
+
const broadcastedShape = a.shape.toSpliced(axis + 1, 0, repeats);
|
|
4444
|
+
const finalShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
|
|
4445
|
+
return broadcast(a, broadcastedShape, [axis + 1]).reshape(finalShape);
|
|
4446
|
+
}
|
|
4447
|
+
/**
|
|
4448
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
4449
|
+
*
|
|
4450
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
4451
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
4452
|
+
* reps[1] * d2, ..., reps[n-1] * dn)`, with `A` tiled along each dimension.
|
|
4453
|
+
*/
|
|
4454
|
+
function tile(a, reps) {
|
|
4455
|
+
a = fudgeArray(a);
|
|
4456
|
+
if (typeof reps === "number") reps = [reps];
|
|
4457
|
+
if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
|
|
4458
|
+
const ndiff = reps.length - a.ndim;
|
|
4459
|
+
if (ndiff > 0) a = a.reshape([...rep(ndiff, 1), ...a.shape]);
|
|
4460
|
+
if (ndiff < 0) reps = [...rep(-ndiff, 1), ...reps];
|
|
4461
|
+
const broadcastedShape = [];
|
|
4462
|
+
const broadcastAxes = [];
|
|
4463
|
+
for (let i = 0; i < a.ndim; i++) {
|
|
4464
|
+
if (reps[i] > 1) {
|
|
4465
|
+
broadcastedShape.push(reps[i]);
|
|
4466
|
+
broadcastAxes.push(broadcastedShape.length - 1);
|
|
4467
|
+
}
|
|
4468
|
+
broadcastedShape.push(a.shape[i]);
|
|
4469
|
+
}
|
|
4470
|
+
const finalShape = a.shape.map((d, i) => reps[i] * d);
|
|
4471
|
+
return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
|
|
4472
|
+
}
|
|
4473
|
+
/**
|
|
4474
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
4475
|
+
*
|
|
4476
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
4477
|
+
* dimensions where the shape is 1.
|
|
4478
|
+
*/
|
|
4479
|
+
function broadcastTo(a, shape$1) {
|
|
4480
|
+
const nd = ndim(a);
|
|
4481
|
+
if (shape$1.length < nd) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${nd}`);
|
|
4482
|
+
return broadcast(a, shape$1, range(shape$1.length - nd));
|
|
4483
|
+
}
|
|
4484
|
+
/** Broadcast input shapes to a common output shape. */
|
|
4485
|
+
function broadcastShapes(...shapes) {
|
|
4486
|
+
if (shapes.length === 0) return [];
|
|
4487
|
+
return shapes.reduce(generalBroadcast);
|
|
4488
|
+
}
|
|
4489
|
+
/** Broadcast arrays to a common shape. */
|
|
4490
|
+
function broadcastArrays(...arrays) {
|
|
4491
|
+
const shapes = arrays.map((a) => shape(a));
|
|
4492
|
+
const outShape = broadcastShapes(...shapes);
|
|
4493
|
+
return arrays.map((a) => broadcastTo(a, outShape));
|
|
4494
|
+
}
|
|
4495
|
+
/**
|
|
3807
4496
|
* Return specified diagonals.
|
|
3808
4497
|
*
|
|
3809
4498
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
3810
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
4499
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
3811
4500
|
*
|
|
3812
|
-
* This returns a view over the existing array.
|
|
4501
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
4502
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
4503
|
+
* then appending a new axis on the right that holds the diagonals.
|
|
3813
4504
|
*/
|
|
3814
4505
|
function diagonal(a, offset, axis1, axis2) {
|
|
3815
4506
|
return fudgeArray(a).diagonal(offset, axis1, axis2);
|
|
@@ -3825,15 +4516,16 @@ function diag(v, k = 0) {
|
|
|
3825
4516
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
3826
4517
|
if (a.ndim === 1) {
|
|
3827
4518
|
const n = a.shape[0];
|
|
3828
|
-
const ret = where(eye(n).equal(1), a,
|
|
3829
|
-
if (k
|
|
3830
|
-
return ret;
|
|
4519
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
4520
|
+
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4521
|
+
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4522
|
+
else return ret;
|
|
3831
4523
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
3832
4524
|
else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
|
|
3833
4525
|
}
|
|
3834
4526
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
3835
4527
|
function allclose(actual, expected, options) {
|
|
3836
|
-
const { rtol = 1e-5, atol = 1e-
|
|
4528
|
+
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
3837
4529
|
const x = array(actual);
|
|
3838
4530
|
const y = array(expected);
|
|
3839
4531
|
if (!deepEqual(x.shape, y.shape)) return false;
|
|
@@ -3868,8 +4560,36 @@ function dot(x, y) {
|
|
|
3868
4560
|
]);
|
|
3869
4561
|
return dot$1(x, y);
|
|
3870
4562
|
}
|
|
3871
|
-
/**
|
|
3872
|
-
|
|
4563
|
+
/**
|
|
4564
|
+
* Compute the inner product of two arrays.
|
|
4565
|
+
*
|
|
4566
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
4567
|
+
* contraction on the last axis.
|
|
4568
|
+
*
|
|
4569
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
4570
|
+
*/
|
|
4571
|
+
function inner(x, y) {
|
|
4572
|
+
x = reshape(x, shape(x).toSpliced(-1, 0, ...rep(ndim(y) - 1, 1)));
|
|
4573
|
+
return dot$1(x, y);
|
|
4574
|
+
}
|
|
4575
|
+
/**
|
|
4576
|
+
* Compute the outer product of two arrays.
|
|
4577
|
+
*
|
|
4578
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
4579
|
+
* be of shape `[x.size, y.size]`.
|
|
4580
|
+
*/
|
|
4581
|
+
function outer(x, y) {
|
|
4582
|
+
x = ravel(x);
|
|
4583
|
+
y = ravel(y);
|
|
4584
|
+
return multiply(x.reshape([x.shape[0], 1]), y);
|
|
4585
|
+
}
|
|
4586
|
+
/** Vector dot product of two arrays along a given axis. */
|
|
4587
|
+
function vecdot(x, y, { axis } = {}) {
|
|
4588
|
+
const xaxis = checkAxis(axis ?? -1, ndim(x));
|
|
4589
|
+
const yaxis = checkAxis(axis ?? -1, ndim(y));
|
|
4590
|
+
if (shape(x)[xaxis] !== shape(y)[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(shape(x))} and ${JSON.stringify(shape(y))} not aligned along axis ${axis}: ${shape(x)[xaxis]} != ${shape(y)[yaxis]}`);
|
|
4591
|
+
x = moveaxis(x, xaxis, -1);
|
|
4592
|
+
y = moveaxis(y, yaxis, -1);
|
|
3873
4593
|
return dot$1(x, y);
|
|
3874
4594
|
}
|
|
3875
4595
|
/**
|
|
@@ -3878,7 +4598,7 @@ function vecdot(x, y) {
|
|
|
3878
4598
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
3879
4599
|
*/
|
|
3880
4600
|
function vdot(x, y) {
|
|
3881
|
-
return
|
|
4601
|
+
return dot$1(ravel(x), ravel(y));
|
|
3882
4602
|
}
|
|
3883
4603
|
/**
|
|
3884
4604
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
@@ -3907,6 +4627,43 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
3907
4627
|
return xs.map((x, i) => broadcast(x, shape$1, [...range(i), ...range(i + 1, xs.length)]));
|
|
3908
4628
|
}
|
|
3909
4629
|
/**
|
|
4630
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
4631
|
+
*
|
|
4632
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
4633
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
4634
|
+
* `k>0` is above it.
|
|
4635
|
+
*/
|
|
4636
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
|
|
4637
|
+
m ??= n;
|
|
4638
|
+
dtype ??= DType.Float32;
|
|
4639
|
+
if (!Number.isInteger(n) || n < 0) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
|
|
4640
|
+
if (!Number.isInteger(m) || m < 0) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
|
|
4641
|
+
if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
|
|
4642
|
+
const rows = arange(k, n + k, 1, {
|
|
4643
|
+
dtype: DType.Int32,
|
|
4644
|
+
device
|
|
4645
|
+
});
|
|
4646
|
+
const cols = arange(0, m, 1, {
|
|
4647
|
+
dtype: DType.Int32,
|
|
4648
|
+
device
|
|
4649
|
+
});
|
|
4650
|
+
return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
|
|
4651
|
+
}
|
|
4652
|
+
/** Return the lower triangle of an array. Must be of dimension >= 2. */
|
|
4653
|
+
function tril(a, k = 0) {
|
|
4654
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4655
|
+
a = fudgeArray(a);
|
|
4656
|
+
const [n, m] = a.shape.slice(-2);
|
|
4657
|
+
return where(tri(n, m, k, { dtype: bool }), a.ref, zerosLike(a));
|
|
4658
|
+
}
|
|
4659
|
+
/** Return the upper triangle of an array. Must be of dimension >= 2. */
|
|
4660
|
+
function triu(a, k = 0) {
|
|
4661
|
+
if (ndim(a) < 2) throw new TypeError(`tril: input array must be at least 2D, got ${ndim(a)}D`);
|
|
4662
|
+
a = fudgeArray(a);
|
|
4663
|
+
const [n, m] = a.shape.slice(-2);
|
|
4664
|
+
return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
|
|
4665
|
+
}
|
|
4666
|
+
/**
|
|
3910
4667
|
* Clip (limit) the values in an array.
|
|
3911
4668
|
*
|
|
3912
4669
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -3930,18 +4687,70 @@ function absolute(x) {
|
|
|
3930
4687
|
x = fudgeArray(x);
|
|
3931
4688
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
3932
4689
|
}
|
|
3933
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
4690
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
3934
4691
|
const abs = absolute;
|
|
4692
|
+
/** Return an element-wise indication of sign of the input. */
|
|
4693
|
+
function sign(x) {
|
|
4694
|
+
x = fudgeArray(x);
|
|
4695
|
+
return where(notEqual(x.ref, 0), where(less(x.ref, 0), -1, 1), 0);
|
|
4696
|
+
}
|
|
3935
4697
|
/** Calculate element-wise square of the input array. */
|
|
3936
4698
|
function square(x) {
|
|
3937
4699
|
x = fudgeArray(x);
|
|
3938
4700
|
return x.ref.mul(x);
|
|
3939
4701
|
}
|
|
3940
|
-
/**
|
|
4702
|
+
/** Element-wise tangent function (takes radians). */
|
|
3941
4703
|
function tan(x) {
|
|
3942
4704
|
x = fudgeArray(x);
|
|
3943
4705
|
return sin(x.ref).div(cos(x));
|
|
3944
4706
|
}
|
|
4707
|
+
/** Element-wise inverse cosine function (inverse of cos). */
|
|
4708
|
+
function acos(x) {
|
|
4709
|
+
return subtract(pi / 2, asin(x));
|
|
4710
|
+
}
|
|
4711
|
+
/**
|
|
4712
|
+
* @function
|
|
4713
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4714
|
+
*
|
|
4715
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4716
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4717
|
+
* improvements.
|
|
4718
|
+
*/
|
|
4719
|
+
const hypot = jit$1((x1, x2) => {
|
|
4720
|
+
return sqrt(square(x1).add(square(x2)));
|
|
4721
|
+
});
|
|
4722
|
+
/**
|
|
4723
|
+
* @function
|
|
4724
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
4725
|
+
*
|
|
4726
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
4727
|
+
* The result is in the range [-π, π].
|
|
4728
|
+
*
|
|
4729
|
+
* Uses numerically stable formulas:
|
|
4730
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
4731
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
4732
|
+
*
|
|
4733
|
+
* The output is ill-defined when both x and y are zero.
|
|
4734
|
+
*/
|
|
4735
|
+
const atan2 = jit$1((y, x) => {
|
|
4736
|
+
const r = sqrt(square(x.ref).add(square(y.ref)));
|
|
4737
|
+
const xNeg = less(x.ref, 0);
|
|
4738
|
+
const numer = where(xNeg.ref, r.ref.sub(x.ref), y.ref);
|
|
4739
|
+
const denom = where(xNeg, y, r.add(x));
|
|
4740
|
+
return atan(numer.div(denom)).mul(2);
|
|
4741
|
+
});
|
|
4742
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
4743
|
+
const arccos = acos;
|
|
4744
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
4745
|
+
const arctan = atan;
|
|
4746
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
4747
|
+
const arctan2 = atan2;
|
|
4748
|
+
/** Element-wise subtraction, with broadcasting. */
|
|
4749
|
+
function subtract(x, y) {
|
|
4750
|
+
x = fudgeArray(x);
|
|
4751
|
+
y = fudgeArray(y);
|
|
4752
|
+
return x.sub(y);
|
|
4753
|
+
}
|
|
3945
4754
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
3946
4755
|
function trueDivide(x, y) {
|
|
3947
4756
|
x = fudgeArray(x);
|
|
@@ -3949,7 +4758,7 @@ function trueDivide(x, y) {
|
|
|
3949
4758
|
if (!isFloatDtype(x.dtype) || !isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
|
|
3950
4759
|
return x.div(y);
|
|
3951
4760
|
}
|
|
3952
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
4761
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
3953
4762
|
const divide = trueDivide;
|
|
3954
4763
|
/** Round input to the nearest integer towards zero. */
|
|
3955
4764
|
function trunc(x) {
|
|
@@ -3967,15 +4776,151 @@ function log2(x) {
|
|
|
3967
4776
|
function log10(x) {
|
|
3968
4777
|
return log(x).mul(Math.LOG10E);
|
|
3969
4778
|
}
|
|
4779
|
+
/** Calculate `exp(x) - 1` element-wise. */
|
|
4780
|
+
function expm1(x) {
|
|
4781
|
+
return exp(x).sub(1);
|
|
4782
|
+
}
|
|
4783
|
+
/** Calculate the natural logarithm of `1 + x` element-wise. */
|
|
4784
|
+
function log1p(x) {
|
|
4785
|
+
return log(add(1, x));
|
|
4786
|
+
}
|
|
4787
|
+
/** Convert angles from degrees to radians. */
|
|
4788
|
+
function deg2rad(x) {
|
|
4789
|
+
return multiply(x, pi / 180);
|
|
4790
|
+
}
|
|
4791
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
4792
|
+
const radians = deg2rad;
|
|
4793
|
+
/** Convert angles from radians to degrees. */
|
|
4794
|
+
function rad2deg(x) {
|
|
4795
|
+
return multiply(x, 180 / pi);
|
|
4796
|
+
}
|
|
4797
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
4798
|
+
const degrees = rad2deg;
|
|
4799
|
+
/**
|
|
4800
|
+
* @function
|
|
4801
|
+
* Computes first array raised to power of second array, element-wise.
|
|
4802
|
+
*/
|
|
4803
|
+
const power = jit$1((x1, x2) => {
|
|
4804
|
+
return exp(log(x1).mul(x2));
|
|
4805
|
+
});
|
|
4806
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
4807
|
+
const pow = power;
|
|
4808
|
+
/** @function Calculate the element-wise cube root of the input array. */
|
|
4809
|
+
const cbrt = jit$1((x) => {
|
|
4810
|
+
const sgn = where(less(x.ref, 0), -1, 1);
|
|
4811
|
+
return sgn.ref.mul(exp(log(x.mul(sgn)).mul(1 / 3)));
|
|
4812
|
+
});
|
|
4813
|
+
/**
|
|
4814
|
+
* @function
|
|
4815
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
4816
|
+
*
|
|
4817
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4818
|
+
*/
|
|
4819
|
+
const sinh = jit$1((x) => {
|
|
4820
|
+
const ex = exp(x);
|
|
4821
|
+
const emx = reciprocal(ex.ref);
|
|
4822
|
+
return ex.sub(emx).mul(.5);
|
|
4823
|
+
});
|
|
4824
|
+
/**
|
|
4825
|
+
* @function
|
|
4826
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
4827
|
+
*
|
|
4828
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4829
|
+
*/
|
|
4830
|
+
const cosh = jit$1((x) => {
|
|
4831
|
+
const ex = exp(x);
|
|
4832
|
+
const emx = reciprocal(ex.ref);
|
|
4833
|
+
return ex.add(emx).mul(.5);
|
|
4834
|
+
});
|
|
4835
|
+
/**
|
|
4836
|
+
* @function
|
|
4837
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
4838
|
+
*
|
|
4839
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4840
|
+
*/
|
|
4841
|
+
const tanh = jit$1((x) => {
|
|
4842
|
+
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4843
|
+
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4844
|
+
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
4845
|
+
});
|
|
4846
|
+
/**
|
|
4847
|
+
* @function
|
|
4848
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
4849
|
+
*
|
|
4850
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4851
|
+
*/
|
|
4852
|
+
const arcsinh = jit$1((x) => {
|
|
4853
|
+
return log(x.ref.add(sqrt(square(x).add(1))));
|
|
4854
|
+
});
|
|
4855
|
+
/**
|
|
4856
|
+
* @function
|
|
4857
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
4858
|
+
*
|
|
4859
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4860
|
+
*/
|
|
4861
|
+
const arccosh = jit$1((x) => {
|
|
4862
|
+
return log(x.ref.add(sqrt(square(x).sub(1))));
|
|
4863
|
+
});
|
|
4864
|
+
/**
|
|
4865
|
+
* @function
|
|
4866
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
4867
|
+
*
|
|
4868
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4869
|
+
*/
|
|
4870
|
+
const arctanh = jit$1((x) => {
|
|
4871
|
+
return log(add(1, x.ref).div(subtract(1, x))).mul(.5);
|
|
4872
|
+
});
|
|
4873
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
4874
|
+
const asinh = arcsinh;
|
|
4875
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
4876
|
+
const acosh = arccosh;
|
|
4877
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
4878
|
+
const atanh = arctanh;
|
|
4879
|
+
/**
|
|
4880
|
+
* Compute the variance of an array.
|
|
4881
|
+
*
|
|
4882
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
4883
|
+
* the specified axis.
|
|
4884
|
+
*
|
|
4885
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4886
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4887
|
+
*/
|
|
4888
|
+
function var_(x, axis = null, opts) {
|
|
4889
|
+
x = fudgeArray(x);
|
|
4890
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
4891
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
4892
|
+
if (n === 0) throw new Error("var: cannot compute variance over zero-length axis");
|
|
4893
|
+
const mu = opts?.mean !== void 0 ? opts.mean : mean(x.ref, axis, { keepdims: true });
|
|
4894
|
+
return square(x.sub(mu)).sum(axis, { keepdims: opts?.keepdims }).mul(1 / (n - (opts?.correction ?? 0)));
|
|
4895
|
+
}
|
|
4896
|
+
/**
|
|
4897
|
+
* Compute the standard deviation of an array.
|
|
4898
|
+
*
|
|
4899
|
+
* The standard deviation is computed for the flattened array by default,
|
|
4900
|
+
* otherwise over the specified axis.
|
|
4901
|
+
*
|
|
4902
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4903
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4904
|
+
*/
|
|
4905
|
+
function std(x, axis = null, opts) {
|
|
4906
|
+
return sqrt(var_(x, axis, opts));
|
|
4907
|
+
}
|
|
3970
4908
|
|
|
3971
4909
|
//#endregion
|
|
3972
4910
|
//#region src/nn.ts
|
|
3973
4911
|
var nn_exports = {};
|
|
3974
4912
|
__export(nn_exports, {
|
|
4913
|
+
celu: () => celu,
|
|
4914
|
+
elu: () => elu,
|
|
4915
|
+
gelu: () => gelu,
|
|
4916
|
+
glu: () => glu,
|
|
3975
4917
|
identity: () => identity,
|
|
4918
|
+
leakyRelu: () => leakyRelu,
|
|
3976
4919
|
logSigmoid: () => logSigmoid,
|
|
3977
4920
|
logSoftmax: () => logSoftmax,
|
|
4921
|
+
logmeanexp: () => logmeanexp,
|
|
3978
4922
|
logsumexp: () => logsumexp,
|
|
4923
|
+
mish: () => mish,
|
|
3979
4924
|
oneHot: () => oneHot,
|
|
3980
4925
|
relu: () => relu,
|
|
3981
4926
|
relu6: () => relu6,
|
|
@@ -3984,6 +4929,8 @@ __export(nn_exports, {
|
|
|
3984
4929
|
softSign: () => softSign,
|
|
3985
4930
|
softmax: () => softmax,
|
|
3986
4931
|
softplus: () => softplus,
|
|
4932
|
+
squareplus: () => squareplus,
|
|
4933
|
+
standardize: () => standardize,
|
|
3987
4934
|
swish: () => swish
|
|
3988
4935
|
});
|
|
3989
4936
|
/**
|
|
@@ -4027,6 +4974,7 @@ function softSign(x) {
|
|
|
4027
4974
|
return x.ref.div(absolute(x).add(1));
|
|
4028
4975
|
}
|
|
4029
4976
|
/**
|
|
4977
|
+
* @function
|
|
4030
4978
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4031
4979
|
* Swish, computed element-wise:
|
|
4032
4980
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4035,11 +4983,9 @@ function softSign(x) {
|
|
|
4035
4983
|
*
|
|
4036
4984
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4037
4985
|
*/
|
|
4038
|
-
|
|
4039
|
-
x = fudgeArray(x);
|
|
4040
|
-
return x.ref.mul(sigmoid(x));
|
|
4041
|
-
}
|
|
4986
|
+
const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
|
|
4042
4987
|
/**
|
|
4988
|
+
* @function
|
|
4043
4989
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4044
4990
|
* Swish, computed element-wise:
|
|
4045
4991
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4056,8 +5002,88 @@ const swish = silu;
|
|
|
4056
5002
|
function logSigmoid(x) {
|
|
4057
5003
|
return negative(softplus(negative(x)));
|
|
4058
5004
|
}
|
|
4059
|
-
/**
|
|
5005
|
+
/**
|
|
5006
|
+
* @function
|
|
5007
|
+
* Identity activation function. Returns the argument unmodified.
|
|
5008
|
+
*/
|
|
4060
5009
|
const identity = fudgeArray;
|
|
5010
|
+
/** Leaky rectified linear unit (Leaky ReLU) activation function. */
|
|
5011
|
+
function leakyRelu(x, negativeSlope = .01) {
|
|
5012
|
+
x = fudgeArray(x);
|
|
5013
|
+
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
5014
|
+
}
|
|
5015
|
+
/**
|
|
5016
|
+
* Exponential linear unit activation function.
|
|
5017
|
+
*
|
|
5018
|
+
* Computes the element-wise function:
|
|
5019
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
5020
|
+
*/
|
|
5021
|
+
function elu(x, alpha = 1) {
|
|
5022
|
+
x = fudgeArray(x);
|
|
5023
|
+
return where(less(x.ref, 0), exp(x.ref).sub(1).mul(alpha), x);
|
|
5024
|
+
}
|
|
5025
|
+
/**
|
|
5026
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
5027
|
+
*
|
|
5028
|
+
* Computes the element-wise function:
|
|
5029
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
5030
|
+
*/
|
|
5031
|
+
function celu(x, alpha = 1) {
|
|
5032
|
+
x = fudgeArray(x);
|
|
5033
|
+
return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
|
|
5034
|
+
}
|
|
5035
|
+
/**
|
|
5036
|
+
* @function
|
|
5037
|
+
* Gaussian error linear unit (GELU) activation function.
|
|
5038
|
+
*
|
|
5039
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
5040
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
5041
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5042
|
+
*
|
|
5043
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5044
|
+
*
|
|
5045
|
+
* This will be improved in the future.
|
|
5046
|
+
*/
|
|
5047
|
+
const gelu = jit$1((x) => {
|
|
5048
|
+
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
5049
|
+
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
5050
|
+
});
|
|
5051
|
+
/**
|
|
5052
|
+
* Gated linear unit (GLU) activation function.
|
|
5053
|
+
*
|
|
5054
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
5055
|
+
* computes `a * sigmoid(b)`.
|
|
5056
|
+
*/
|
|
5057
|
+
function glu(x, axis = -1) {
|
|
5058
|
+
x = fudgeArray(x);
|
|
5059
|
+
axis = checkAxis(axis, x.ndim);
|
|
5060
|
+
const size$1 = x.shape[axis];
|
|
5061
|
+
if (size$1 % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
|
|
5062
|
+
const slice = x.shape.map((a$1) => [0, a$1]);
|
|
5063
|
+
const a = shrink(x.ref, slice.toSpliced(axis, 1, [0, size$1 / 2]));
|
|
5064
|
+
const b = shrink(x, slice.toSpliced(axis, 1, [size$1 / 2, size$1]));
|
|
5065
|
+
return a.mul(sigmoid(b));
|
|
5066
|
+
}
|
|
5067
|
+
/**
|
|
5068
|
+
* Squareplus activation function.
|
|
5069
|
+
*
|
|
5070
|
+
* Computes the element-wise function:
|
|
5071
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
5072
|
+
*/
|
|
5073
|
+
function squareplus(x, b = 4) {
|
|
5074
|
+
x = fudgeArray(x);
|
|
5075
|
+
return x.ref.add(sqrt(square(x).add(b))).mul(.5);
|
|
5076
|
+
}
|
|
5077
|
+
/**
|
|
5078
|
+
* Mish activation function.
|
|
5079
|
+
*
|
|
5080
|
+
* Computes the element-wise function:
|
|
5081
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
5082
|
+
*/
|
|
5083
|
+
function mish(x) {
|
|
5084
|
+
x = fudgeArray(x);
|
|
5085
|
+
return x.ref.mul(tanh(softplus(x)));
|
|
5086
|
+
}
|
|
4061
5087
|
/**
|
|
4062
5088
|
* Softmax function. Computes the function which rescales elements to the range
|
|
4063
5089
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -4066,17 +5092,13 @@ const identity = fudgeArray;
|
|
|
4066
5092
|
*
|
|
4067
5093
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
4068
5094
|
*/
|
|
4069
|
-
function softmax(x, axis) {
|
|
5095
|
+
function softmax(x, axis = -1) {
|
|
4070
5096
|
x = fudgeArray(x);
|
|
4071
|
-
|
|
4072
|
-
|
|
4073
|
-
|
|
4074
|
-
x.dispose();
|
|
4075
|
-
return ones(x.shape);
|
|
4076
|
-
}
|
|
4077
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5097
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5098
|
+
if (axis.length === 0) return onesLike(x);
|
|
5099
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4078
5100
|
const unnormalized = exp(x.sub(stopGradient(xMax)));
|
|
4079
|
-
return unnormalized.ref.div(unnormalized.sum(axis, {
|
|
5101
|
+
return unnormalized.ref.div(unnormalized.sum(axis, { keepdims: true }));
|
|
4080
5102
|
}
|
|
4081
5103
|
/**
|
|
4082
5104
|
* Log-Softmax function.
|
|
@@ -4086,17 +5108,13 @@ function softmax(x, axis) {
|
|
|
4086
5108
|
*
|
|
4087
5109
|
* If `axis` is not specified, it defaults to the last axis.
|
|
4088
5110
|
*/
|
|
4089
|
-
function logSoftmax(x, axis) {
|
|
5111
|
+
function logSoftmax(x, axis = -1) {
|
|
4090
5112
|
x = fudgeArray(x);
|
|
4091
|
-
|
|
4092
|
-
|
|
4093
|
-
|
|
4094
|
-
x.dispose();
|
|
4095
|
-
return zeros(x.shape);
|
|
4096
|
-
}
|
|
4097
|
-
const xMax = max(x.ref, axis, { keepDims: true });
|
|
5113
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5114
|
+
if (axis.length === 0) return zerosLike(x);
|
|
5115
|
+
const xMax = max(x.ref, axis, { keepdims: true });
|
|
4098
5116
|
const shifted = x.sub(stopGradient(xMax));
|
|
4099
|
-
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, {
|
|
5117
|
+
const shiftedLogsumexp = log(exp(shifted.ref).sum(axis, { keepdims: true }));
|
|
4100
5118
|
return shifted.sub(shiftedLogsumexp);
|
|
4101
5119
|
}
|
|
4102
5120
|
/**
|
|
@@ -4107,16 +5125,39 @@ function logSoftmax(x, axis) {
|
|
|
4107
5125
|
*
|
|
4108
5126
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
4109
5127
|
*/
|
|
4110
|
-
function logsumexp(x, axis) {
|
|
5128
|
+
function logsumexp(x, axis = null) {
|
|
4111
5129
|
x = fudgeArray(x);
|
|
4112
|
-
|
|
4113
|
-
else if (typeof axis === "number") axis = [axis];
|
|
5130
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
4114
5131
|
if (axis.length === 0) return x;
|
|
4115
5132
|
const xMax = stopGradient(max(x.ref, axis));
|
|
4116
5133
|
const xMaxDims = broadcast(xMax.ref, x.shape, axis);
|
|
4117
5134
|
const shifted = x.sub(xMaxDims);
|
|
4118
5135
|
return xMax.add(log(exp(shifted).sum(axis)));
|
|
4119
5136
|
}
|
|
5137
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
5138
|
+
function logmeanexp(x, axis = null) {
|
|
5139
|
+
x = fudgeArray(x);
|
|
5140
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5141
|
+
if (axis.length === 0) return x;
|
|
5142
|
+
const n = axis.reduce((acc, a) => acc * x.shape[a], 1);
|
|
5143
|
+
return logsumexp(x, axis).sub(Math.log(n));
|
|
5144
|
+
}
|
|
5145
|
+
/**
|
|
5146
|
+
* Standardizes input to zero mean and unit variance.
|
|
5147
|
+
*
|
|
5148
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
5149
|
+
* axis, or `null` to standardize over all elements.
|
|
5150
|
+
*
|
|
5151
|
+
* Epsilon is added to the variance in the denominator for numerical stability; it defaults to `1e-5`.
|
|
5152
|
+
*/
|
|
5153
|
+
function standardize(x, axis = -1, opts = {}) {
|
|
5154
|
+
x = fudgeArray(x);
|
|
5155
|
+
axis = normalizeAxis(axis, x.ndim);
|
|
5156
|
+
if (axis.length === 0) return x;
|
|
5157
|
+
const mu = opts.mean !== void 0 ? fudgeArray(opts.mean) : x.ref.mean(axis, { keepdims: true });
|
|
5158
|
+
const sigma2 = opts.variance !== void 0 ? fudgeArray(opts.variance) : square(x.ref).mean(axis, { keepdims: true }).sub(square(mu.ref));
|
|
5159
|
+
return x.sub(mu).div(sqrt(sigma2.add(opts.epsilon ?? 1e-5)));
|
|
5160
|
+
}
|
|
4120
5161
|
/**
|
|
4121
5162
|
* One-hot encodes the given indices.
|
|
4122
5163
|
*
|
|
@@ -4134,7 +5175,7 @@ function logsumexp(x, axis) {
|
|
|
4134
5175
|
* ```
|
|
4135
5176
|
*/
|
|
4136
5177
|
function oneHot(x, numClasses) {
|
|
4137
|
-
if (x.dtype
|
|
5178
|
+
if (isFloatDtype(x.dtype)) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
4138
5179
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
4139
5180
|
}
|
|
4140
5181
|
|
|
@@ -4142,8 +5183,11 @@ function oneHot(x, numClasses) {
|
|
|
4142
5183
|
//#region src/random.ts
|
|
4143
5184
|
var random_exports = {};
|
|
4144
5185
|
__export(random_exports, {
|
|
5186
|
+
bernoulli: () => bernoulli,
|
|
4145
5187
|
bits: () => bits,
|
|
5188
|
+
exponential: () => exponential,
|
|
4146
5189
|
key: () => key,
|
|
5190
|
+
normal: () => normal,
|
|
4147
5191
|
split: () => split,
|
|
4148
5192
|
uniform: () => uniform
|
|
4149
5193
|
});
|
|
@@ -4174,11 +5218,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
4174
5218
|
/** Sample uniform random values in [minval, maxval) with given shape. */
|
|
4175
5219
|
function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
4176
5220
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
4177
|
-
const mantissa = bits(key$1, shape$1).div(
|
|
5221
|
+
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
4178
5222
|
dtype: DType.Uint32,
|
|
4179
5223
|
device: key$1.device
|
|
4180
5224
|
}));
|
|
4181
|
-
const float12 = mantissa.add(
|
|
5225
|
+
const float12 = mantissa.add(array(1065353216, {
|
|
4182
5226
|
dtype: DType.Uint32,
|
|
4183
5227
|
device: key$1.device
|
|
4184
5228
|
}));
|
|
@@ -4186,6 +5230,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
4186
5230
|
if (minval === 0 && maxval === 1) return rand;
|
|
4187
5231
|
else return rand.mul(maxval - minval).add(minval);
|
|
4188
5232
|
}
|
|
5233
|
+
/**
 * Draw Bernoulli (0/1) samples whose mean is `p`.
 *
 * A uniform sample of the requested `shape` is drawn and compared against `p`
 * with `less`, giving a random Boolean array. `p` may be a scalar or an array,
 * and must be broadcastable to `shape`.
 */
function bernoulli(key$1, p = .5, shape$1 = []) {
	const threshold = fudgeArray(p);
	const draws = uniform(key$1, shape$1);
	return draws.less(threshold);
}
|
|
5243
|
+
/**
 * Sample exponential random values with density `p(x) = exp(-x)`.
 *
 * Computed as `-log1p(-u)` for a uniform sample `u` of the given shape.
 */
function exponential(key$1, shape$1 = []) {
	const uniformSample = uniform(key$1, shape$1);
	const logOneMinusU = log1p(negative(uniformSample));
	return negative(logOneMinusU);
}
|
|
5248
|
+
/**
 * Sample standard normal values, `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
 *
 * Implemented with the Box-Muller transform: the key is split in two, one
 * uniform draw gives the radius `sqrt(-2 * log1p(-u1))`, the other the angle
 * `2 * pi * u2`, and the result is `radius * cos(angle)`.
 *
 * JAX itself inverts the CDF through the erf_inv primitive, which is not
 * supported here yet, so outputs will not be bitwise identical to JAX.
 */
function normal(key$1, shape$1 = []) {
	const [radiusKey, angleKey] = split(key$1, 2);
	const uRadius = uniform(radiusKey, shape$1);
	const uAngle = uniform(angleKey, shape$1);
	const logTerm = log1p(negative(uRadius));
	const radius = sqrt(logTerm.mul(-2));
	const angle = uAngle.mul(2 * Math.PI);
	return radius.mul(cos(angle));
}
|
|
4189
5263
|
|
|
4190
5264
|
//#endregion
|
|
4191
5265
|
//#region src/polyfills.ts
|
|
@@ -4195,33 +5269,91 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
4195
5269
|
|
|
4196
5270
|
//#endregion
|
|
4197
5271
|
//#region src/index.ts
|
|
4198
|
-
/**
|
|
5272
|
+
// Public transformation entry points. Each constant below simply re-binds an
// internal `$1`-suffixed implementation under its public name; the `@function`
// tag marks the constant as a function for documentation tooling.
/**
 * @function
 * Compute the forward-mode Jacobian-vector product for a function.
 */
const jvp = jvp$1;
/**
 * @function
 * Vectorize an operation on a batched axis for one or more inputs.
 */
const vmap = vmap$1;
/**
 * @function
 * Compute the Jacobian evaluated column-by-column by forward-mode AD.
 */
const jacfwd = jacfwd$1;
/**
 * @function
 * Construct a Jaxpr by dynamically tracing a function with example inputs.
 */
const makeJaxpr = makeJaxpr$1;
/**
 * @function
 * Mark a function for automatic JIT compilation, with operator fusion.
 *
 * The function will be compiled the first time it is called with a set of
 * argument shapes.
 *
 * You can call `.dispose()` on the returned, JIT-compiled function after all
 * calls to free memory associated with array constants.
 *
 * **Options:**
 * - `staticArgnums`: An array of argument indices to treat as static
 *   (compile-time constant). These arguments must be hashable, won't be traced,
 *   and different values will trigger recompilation.
 * - `device`: The device to place the computation on. If not specified, the
 *   computation will be placed on the default device.
 */
const jit = jit$1;
/**
 * @function
 * Produce a local linear approximation to a function at a point using jvp() and
 * partial evaluation.
 */
const linearize = linearize$1;
/**
 * @function
 * Calculate the reverse-mode vector-Jacobian product for a function.
 */
const vjp = vjp$1;
/**
 * @function
 * Compute the gradient of a scalar-valued function `f` with respect to its
 * first argument.
 */
const grad = grad$1;
/**
 * @function
 * Create a function that evaluates both `f` and the gradient of `f`.
 */
const valueAndGrad = valueAndGrad$1;
/**
 * @function
 * Compute the Jacobian evaluated row-by-row by reverse-mode AD.
 */
const jacrev = jacrev$1;
/**
 * @function
 * Compute the Jacobian with reverse-mode AD. Alias for `jacrev()`.
 */
// Bound to the same function object as `jacrev`, not a wrapper around it.
const jacobian = jacrev;
|
|
5342
|
+
/**
 * Wait for every `Array` leaf of `x` to become ready.
 *
 * Calls `Array.blockUntilReady()` on each `Array` leaf, awaits all of the
 * resulting promises together, and resolves to `x` itself. Leaves that are not
 * `Array` instances are ignored.
 *
 * Useful for waiting on the results of an intermediate computation. Calling
 * this regularly inside an iterative computation is recommended, so that too
 * many pending operations do not queue up.
 *
 * Does not consume reference to the arrays.
 */
async function blockUntilReady(x) {
	const pending = [];
	for (const leaf of leaves(x)) {
		if (!(leaf instanceof Array$1)) continue;
		pending.push(leaf.blockUntilReady());
	}
	await Promise.all(pending);
	return x;
}
|
|
4225
5357
|
|
|
4226
5358
|
//#endregion
|
|
4227
|
-
export { devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random,
|
|
5359
|
+
export { Array$1 as Array, DType, Jaxpr, blockUntilReady, defaultDevice, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDebug, tree_exports as tree, valueAndGrad, vjp, vmap };
|