@jax-js/jax 0.0.2 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +57 -25
- package/dist/backend-EBRGmEYw.js +3816 -0
- package/dist/{backend-BK21PBVP.cjs → backend-Ss1Mev_-.cjs} +2075 -107
- package/dist/index.cjs +1393 -250
- package/dist/index.d.cts +651 -102
- package/dist/index.d.ts +651 -102
- package/dist/index.js +1377 -245
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-BVdMaO9T.cjs} +62 -35
- package/dist/{webgpu-JVpVad6g.js → webgpu-ow0Pn_6q.js} +62 -35
- package/package.json +21 -9
- package/dist/backend-1eVbAoaV.js +0 -1890
package/dist/index.cjs
CHANGED
|
@@ -30,13 +30,14 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-Ss1Mev_-.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
37
37
|
__export(tree_exports, {
|
|
38
38
|
JsTreeDef: () => JsTreeDef,
|
|
39
39
|
NodeType: () => NodeType,
|
|
40
|
+
dispose: () => dispose,
|
|
40
41
|
flatten: () => flatten,
|
|
41
42
|
leaves: () => leaves,
|
|
42
43
|
map: () => map,
|
|
@@ -51,7 +52,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
|
51
52
|
NodeType$1["Leaf"] = "Leaf";
|
|
52
53
|
return NodeType$1;
|
|
53
54
|
}({});
|
|
54
|
-
/**
|
|
55
|
+
/** Represents the structure of a JsTree. */
|
|
55
56
|
var JsTreeDef = class JsTreeDef {
|
|
56
57
|
static leaf = new JsTreeDef(NodeType.Leaf, null, []);
|
|
57
58
|
constructor(nodeType, nodeMetadata, childTreedefs) {
|
|
@@ -139,6 +140,194 @@ function map(fn, tree, ...rest) {
|
|
|
139
140
|
function ref(tree) {
|
|
140
141
|
return map((x) => x.ref, tree);
|
|
141
142
|
}
|
|
143
|
+
/** Dispose every array in a tree; a falsy tree is ignored. */
function dispose(tree) {
  if (!tree) return;
  // `map` serves purely as a traversal here; the mapped result is discarded.
  map((leaf) => leaf.dispose(), tree);
}
|
|
147
|
+
|
|
148
|
+
//#endregion
|
|
149
|
+
//#region src/frontend/convolution.ts
|
|
150
|
+
/**
 * Check that the shapes and parameters passed to convolution are valid.
 *
 * Both shapes must have the same rank (at least 2). Index 1 of each shape must
 * match (reported as "input channels"); every dimension from index 2 onward is
 * treated as spatial, and each per-dimension parameter needs one entry per
 * spatial dimension.
 *
 * If the check succeeds, returns the output shape:
 * `[lhsShape[0], rhsShape[0], ...spatialOutputSizes]`.
 *
 * @param lhsShape Shape of the left-hand (input) operand.
 * @param rhsShape Shape of the right-hand (kernel) operand.
 * @param params `strides`, `lhsDilation`, `rhsDilation` (positive integers)
 *   and `padding` (`[low, high]` integer pairs, possibly negative).
 * @throws {Error} On any rank, channel, parameter, or size mismatch.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  const n = lhsShape.length - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
  const outShape = [lhsShape[0], rhsShape[0]];
  for (let i = 0; i < n; i++) {
    if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
    if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
    if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
    if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
    const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
    if (k <= 0) throw new Error("conv() kernel size must be positive");
    const [pl, pr] = padding[i];
    // Padding is applied to the lhs-dilated input, so negative padding may
    // trim at most the dilated extent (equal to `x` when lhsDilation is 1).
    const dilatedSize = Math.max((x - 1) * lhsDilation[i] + 1, 0);
    if (pl < -dilatedSize || pr < -dilatedSize || pl + pr < -dilatedSize) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${dilatedSize}`);
    const kernelSize = (k - 1) * rhsDilation[i] + 1;
    const inSize = dilatedSize + pl + pr;
    if (kernelSize > inSize) throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${i}`);
    outShape.push(Math.ceil((inSize - kernelSize + 1) / strides[i]));
  }
  return outShape;
}
|
|
181
|
+
/**
 * Validate a pooling window against an input shape.
 *
 * The window covers the trailing `window.length` dimensions of `inShape`.
 * Returns the untouched leading dimensions, followed by the pooled output
 * size of each windowed dimension, followed by the window sizes themselves.
 *
 * @throws {Error} If ranks disagree, or a window/stride entry is not a
 *   positive integer, or a window exceeds its input dimension.
 */
function checkPoolShape(inShape, window, strides) {
  const rank = window.length;
  if (strides.length !== rank) throw new Error("pool() strides != window dims");
  if (rank > inShape.length) throw new Error("pool() window has more dimensions than input");
  const lead = inShape.length - rank;
  const pooled = [];
  window.forEach((k, i) => {
    const extent = inShape[lead + i];
    const step = strides[i];
    if (!Number.isInteger(k) || k <= 0) throw new Error(`pool() window[${i}] must be a positive integer`);
    if (k > extent) throw new Error(`pool() window[${i}]=${k} > input size ${extent}`);
    if (!Number.isInteger(step) || step <= 0) throw new Error(`pool() strides[${i}] must be a positive integer`);
    pooled.push(Math.ceil((extent - k + 1) / step));
  });
  return [...inShape.slice(0, lead), ...pooled, ...window];
}
|
|
196
|
+
/**
 * Reshape the trailing `ks.length` dimensions of a shape tracker into
 * `2 * ks.length` dimensions by viewing them as spatial dimensions swept by a
 * kernel of size `ks`.
 *
 * The result is laid out as `[...leading, ...outputSizes, ...ks]`, so it can
 * be multiplied with a kernel of shape `ks` and then reduced along the last
 * `ks.length` dimensions to form a convolution.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 *
 * @param st Shape tracker to window; returned as-is when `ks` is empty.
 * @param ks Kernel size for each windowed dimension.
 * @param strides Stride per dimension; a single number applies to all.
 * @param dilation Dilation per dimension; a single number applies to all.
 * @throws {Error} If `st` has fewer dimensions than `ks`, or any stride or
 *   dilation is not a positive integer.
 */
function pool(st, ks, strides = 1, dilation = 1) {
  const nk = ks.length;
  if (nk === 0) return st;
  if (st.shape.length < nk) throw new Error("pool() called with too many dimensions");

  const s_ = typeof strides === "number" ? require_backend.rep(nk, strides) : strides;
  const d_ = typeof dilation === "number" ? require_backend.rep(nk, dilation) : dilation;
  const notPositiveInt = (v) => v <= 0 || !Number.isInteger(v);
  if (s_.some((s) => notPositiveInt(s))) throw new Error("pool() strides must be positive integers");
  if (d_.some((d) => notPositiveInt(d))) throw new Error("pool() dilation must be positive integers");

  const noop = st.shape.slice(0, -nk);
  const i_ = st.shape.slice(-nk);
  // Full-range slices that leave the leading (non-windowed) dims untouched.
  const keepLeading = () => noop.map((x) => [0, x]);

  // o: output size per dim; f: 1 or 2, how many input copies a padded row spans.
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => (o * s > i - d * (k - 1) ? 2 : 1));
  const rows = require_backend.zipn(ks, i_, d_, f_).map(([k, i, d, f]) => ({ k, i, width: i * f + d }));

  // Tile each windowed dim, then cut it into `k` rows of `i * f + d` entries.
  let view = st.repeat([...require_backend.rep(noop.length, 1), ...rows.map(({ k, i, width }) => Math.ceil(k * width / i))]);
  view = view.shrink([...keepLeading(), ...rows.map(({ k, width }) => [0, k * width])]);
  view = view.reshape([...noop, ...rows.flatMap(({ k, width }) => [k, width])]);

  // Trim every row to `o * s` entries and split it into (o, s).
  const kos = require_backend.zipn(ks, o_, s_);
  view = view.shrink([...keepLeading(), ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])]);
  view = view.reshape([...noop, ...kos.flatMap(([k, o, s]) => [k, o, s])]);

  // Keep only the first element of each stride group, leaving (k, o) pairs.
  view = view.shrink([...keepLeading(), ...kos.flatMap(([k, o]) => [[0, k], [0, o], [0, 1]])]);
  view = view.reshape([...noop, ...kos.flatMap(([k, o]) => [k, o])]);

  // Reorder to [...leading, ...outputSizes, ...kernelSizes].
  const base = noop.length;
  return view.permute([
    ...require_backend.range(base),
    ...ks.map((_, j) => base + 2 * j + 1),
    ...ks.map((_, j) => base + 2 * j)
  ]);
}
|
|
236
|
+
/**
 * Perform the transpose of pool, directly undo-ing a pool() operation.
 *
 * `st` must be laid out as `[...leading, ...outputSizes, ...ks]`, i.e. the
 * layout pool() produces for an input of shape `inShape`. The returned tracker
 * is laid out as `[...inShape, ...repeatCounts]`.
 *
 * Note that since pool repeats the input, the transpose operation technically
 * should include a sum reduction. This function doesn't perform the reduction,
 * which should be done on the last `k` axes of the returned shape.
 *
 * @param st Shape tracker of the pooled array; returned as-is when `ks` is empty.
 * @param inShape Shape of the array that was originally pooled.
 * @param ks Kernel size for each windowed dimension.
 * @param strides Stride per dimension; a single number applies to all.
 * @param dilation Dilation per dimension; a single number applies to all.
 * @throws {Error} If `inShape` has fewer dimensions than `ks`, a stride or
 *   dilation is not a positive integer, or `st` does not carry the output
 *   sizes implied by `inShape`, `ks`, `strides` and `dilation`.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  if (ks.length === 0) return st;
  // Same argument validation as pool(), so bad parameters fail with a clear
  // message instead of producing NaN/Infinity shapes further down.
  if (inShape.length < ks.length) throw new Error("poolTranspose() called with too many dimensions");
  const s_ = typeof strides === "number" ? require_backend.rep(ks.length, strides) : strides;
  const d_ = typeof dilation === "number" ? require_backend.rep(ks.length, dilation) : dilation;
  if (s_.some((s) => s <= 0 || !Number.isInteger(s))) throw new Error("poolTranspose() strides must be positive integers");
  if (d_.some((d) => d <= 0 || !Number.isInteger(d))) throw new Error("poolTranspose() dilation must be positive integers");

  const noop = inShape.slice(0, -ks.length);
  const i_ = inShape.slice(-ks.length);
  const o_ = require_backend.zipn(i_, d_, ks, s_).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  if (!require_backend.deepEqual(o_, st.shape.slice(noop.length, noop.length + ks.length))) throw new Error("poolTranspose() called with mismatched output shape");
  const f_ = require_backend.zipn(o_, s_, i_, d_, ks).map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  const kidf = require_backend.zipn(ks, i_, d_, f_);
  const kos = require_backend.zipn(ks, o_, s_);

  // [...leading, o..., k...] -> [...leading, k0, o0, k1, o1, ...]
  st = st.permute([...require_backend.range(noop.length), ...ks.flatMap((_, j) => [noop.length + ks.length + j, noop.length + j])]);
  // Re-insert the stride axis: (k, o, 1) zero-padded to (k, o, s).
  st = st.reshape([...noop, ...kos.flatMap(([k, o]) => [k, o, 1])]).pad([...noop.map(() => [0, 0]), ...s_.flatMap((s) => [
    [0, 0],
    [0, 0],
    [0, s - 1]
  ])]);
  // Merge to (k, o * s) and pad every row out to `i * f + d` entries.
  st = st.reshape([...noop, ...kos.flatMap(([k, o, s]) => [k, o * s])]).pad([...noop.map(() => [0, 0]), ...kidf.flatMap(([_k, i, d, f], j) => [[0, 0], [0, i * f + d - o_[j] * s_[j]]])]);
  // Flatten the rows and pad up to a whole number of copies of the input dim.
  st = st.reshape([...noop, ...kidf.map(([k, i, d, f]) => k * (i * f + d))]).pad([...noop.map(() => [0, 0]), ...kidf.map(([k, i, d, f]) => [0, Math.ceil(k * (i * f + d) / i) * i - k * (i * f + d)])]);
  // Split into (copies, i) and move the input dims ahead of the copy counts.
  st = st.reshape([...noop, ...kidf.flatMap(([k, i, d, f]) => [Math.ceil(k * (i * f + d) / i), i])]).permute([
    ...require_backend.range(noop.length),
    ...ks.map((_, j) => noop.length + 2 * j + 1),
    ...ks.map((_, j) => noop.length + 2 * j)
  ]);
  return st;
}
|
|
275
|
+
/**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * The first two dimensions are left alone. Every later dimension of size `k`
 * with dilation `s` becomes size `(k - 1) * s + 1`, with padding inserted
 * after each original element. Returns `st` itself when every dilation is 1.
 */
function applyDilation(st, dilation) {
  const needsDilation = dilation.some((s) => s !== 1);
  if (!needsDilation) return st;
  const [a, b, ...spatial] = st.shape;
  // Give each spatial dim a trailing unit axis: (k) -> (k, 1).
  const split = st.reshape([a, b, ...spatial.flatMap((k) => [k, 1])]);
  // Grow the unit axis to `s`, i.e. add `s - 1` padded entries per element.
  const padded = split.pad([[0, 0], [0, 0], ...dilation.flatMap((s) => [[0, 0], [0, s - 1]])]);
  // Merge back to one axis of size `k * s` per spatial dim.
  const merged = padded.reshape([a, b, ...spatial.map((k, i) => k * dilation[i])]);
  // Drop the padding that trails the last real element.
  return merged.shrink([[0, a], [0, b], ...spatial.map((k, i) => [0, (k - 1) * dilation[i] + 1])]);
}
|
|
302
|
+
/**
 * Prepare for a convolution between two arrays.
 *
 * The input tracker is dilated, padded (or shrunk) and windowed with pool();
 * both trackers are then reshaped so their last axis has the same length,
 * `stY.shape[1] * prod(kernel sizes)`. Returns `[stX, stY]`.
 *
 * This does not check the validity of the shapes, which should be checked
 * beforehand using `checkConvShape()`.
 */
function prepareConv(stX, stY, params) {
  const n = stX.shape.length - 2;
  const ks = stY.shape.slice(2);

  const dilated = applyDilation(stX, params.lhsDilation);
  const padded = dilated.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  const windows = pool(padded, ks, params.strides, params.rhsDilation);

  // Sizes are read from the pooled tracker before axis 1 is moved next to
  // the kernel axes and folded together with them.
  const lhsShape = [
    windows.shape[0],
    1,
    ...windows.shape.slice(2, n + 2),
    windows.shape[1] * require_backend.prod(ks)
  ];
  const lhs = windows.moveaxis(1, n + 1).reshape(lhsShape);

  const rhsShape = [
    stY.shape[0],
    ...require_backend.rep(n, 1),
    stY.shape[1] * require_backend.prod(ks)
  ];
  const rhs = stY.reshape(rhsShape);
  return [lhs, rhs];
}
|
|
142
331
|
|
|
143
332
|
//#endregion
|
|
144
333
|
//#region src/frontend/core.ts
|
|
@@ -165,12 +354,18 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
165
354
|
Primitive$1["RandomBits"] = "random_bits";
|
|
166
355
|
Primitive$1["Sin"] = "sin";
|
|
167
356
|
Primitive$1["Cos"] = "cos";
|
|
357
|
+
Primitive$1["Asin"] = "asin";
|
|
358
|
+
Primitive$1["Atan"] = "atan";
|
|
168
359
|
Primitive$1["Exp"] = "exp";
|
|
169
360
|
Primitive$1["Log"] = "log";
|
|
361
|
+
Primitive$1["Sqrt"] = "sqrt";
|
|
170
362
|
Primitive$1["Min"] = "min";
|
|
171
363
|
Primitive$1["Max"] = "max";
|
|
172
364
|
Primitive$1["Reduce"] = "reduce";
|
|
173
365
|
Primitive$1["Dot"] = "dot";
|
|
366
|
+
Primitive$1["Conv"] = "conv";
|
|
367
|
+
Primitive$1["Pool"] = "pool";
|
|
368
|
+
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
174
369
|
Primitive$1["Compare"] = "compare";
|
|
175
370
|
Primitive$1["Where"] = "where";
|
|
176
371
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -228,34 +423,52 @@ function sin$1(x) {
|
|
|
228
423
|
function cos$1(x) {
|
|
229
424
|
return bind1(Primitive.Cos, [x]);
|
|
230
425
|
}
|
|
426
|
+
/** Binds `Primitive.Asin` to the single operand `x`. */
function asin$1(x) {
  const operands = [x];
  return bind1(Primitive.Asin, operands);
}
|
|
429
|
+
/** Binds `Primitive.Atan` to the single operand `x`. */
function atan$1(x) {
  const operands = [x];
  return bind1(Primitive.Atan, operands);
}
|
|
231
432
|
function exp$1(x) {
|
|
232
433
|
return bind1(Primitive.Exp, [x]);
|
|
233
434
|
}
|
|
234
435
|
function log$1(x) {
|
|
235
436
|
return bind1(Primitive.Log, [x]);
|
|
236
437
|
}
|
|
438
|
+
/** Binds `Primitive.Sqrt` to the single operand `x`. */
function sqrt$1(x) {
  const operands = [x];
  return bind1(Primitive.Sqrt, operands);
}
|
|
237
441
|
function min$1(x, y) {
|
|
238
442
|
return bind1(Primitive.Min, [x, y]);
|
|
239
443
|
}
|
|
240
444
|
function max$1(x, y) {
|
|
241
445
|
return bind1(Primitive.Max, [x, y]);
|
|
242
446
|
}
|
|
243
|
-
/**
 * Reduce `x` with the ALU operation `op` over `axis`.
 *
 * `axis` is normalized with `normalizeAxis` against the rank of `x`. When
 * `opts.keepdims` is truthy, the result is reshaped so that every reduced
 * axis is kept with size 1.
 *
 * @throws {TypeError} If `op` is not a member of `AluGroup.Reduce`.
 */
function reduce(x, op, axis = null, opts) {
  if (!require_backend.AluGroup.Reduce.has(op)) throw new TypeError(`Invalid reduce operation: ${op}`);
  const axes = require_backend.normalizeAxis(axis, ndim$1(x));
  // Capture the input shape before `x` is handed to the primitive.
  const originalShape = getShape(x);
  const reduced = bind1(Primitive.Reduce, [x], {
    op,
    axis: axes
  });
  if (!opts?.keepdims) return reduced;
  const keptShape = originalShape.map((dim, i) => (axes.includes(i) ? 1 : dim));
  return reduced.reshape(keptShape);
}
|
|
256
458
|
function dot$1(x, y) {
|
|
257
459
|
return bind1(Primitive.Dot, [x, y]);
|
|
258
460
|
}
|
|
461
|
+
/**
 * Bind the `Conv` primitive to `x` and `y`.
 *
 * With `n = x.ndim - 2`, any parameter that is missing (null or undefined)
 * is filled with an `n`-entry default: 1 for `strides`, `lhsDilation` and
 * `rhsDilation`, and `[0, 0]` for `padding`.
 *
 * @throws {Error} If the operands differ in rank or have fewer than 2 dimensions.
 */
function conv(x, y, params = {}) {
  if (x.ndim !== y.ndim) throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
  const n = x.ndim - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");
  const ones = () => require_backend.rep(n, 1);
  const resolved = {
    strides: params.strides ?? ones(),
    padding: params.padding ?? require_backend.rep(n, [0, 0]),
    lhsDilation: params.lhsDilation ?? ones(),
    rhsDilation: params.rhsDilation ?? ones()
  };
  return bind1(Primitive.Conv, [x, y], resolved);
}
|
|
259
472
|
function compare(x, y, op) {
|
|
260
473
|
return bind1(Primitive.Compare, [x, y], { op });
|
|
261
474
|
}
|
|
@@ -286,10 +499,11 @@ function where$1(cond, x, y) {
|
|
|
286
499
|
}
|
|
287
500
|
/**
 * Bind the `Transpose` primitive to `x`.
 *
 * Each entry of `perm` is passed through `checkAxis`; when `perm` is omitted
 * the axis order is reversed.
 *
 * @throws {Error} If the resulting axis list is not a permutation of the axes of `x`.
 */
function transpose$1(x, perm) {
  let order;
  if (perm) order = perm.map((a) => require_backend.checkAxis(a, ndim$1(x)));
  else order = require_backend.range(ndim$1(x)).reverse();
  const isValid = require_backend.isPermutation(order, ndim$1(x));
  if (!isValid) throw new Error(`Invalid transpose permutation for ${ndim$1(x)} axes: ${JSON.stringify(order)}`);
  return bind1(Primitive.Transpose, [x], { perm: order });
}
|
|
291
505
|
function broadcast(x, shape$1, axis) {
|
|
292
|
-
axis =
|
|
506
|
+
axis = require_backend.normalizeAxis(axis, shape$1.length);
|
|
293
507
|
return bind1(Primitive.Broadcast, [x], {
|
|
294
508
|
shape: shape$1,
|
|
295
509
|
axis
|
|
@@ -308,7 +522,7 @@ function reshape$1(x, shape$1) {
|
|
|
308
522
|
return bind1(Primitive.Reshape, [x], { shape: shape$1 });
|
|
309
523
|
}
|
|
310
524
|
/** Bind the `Flip` primitive to `x`, with `axis` normalized against the rank of `x`. */
function flip$1(x, axis) {
  const axes = require_backend.normalizeAxis(axis, ndim$1(x));
  return bind1(Primitive.Flip, [x], { axis: axes });
}
|
|
314
528
|
function shrink(x, slice) {
|
|
@@ -388,12 +602,19 @@ var Tracer = class Tracer {
|
|
|
388
602
|
constructor(trace) {
|
|
389
603
|
this._trace = trace;
|
|
390
604
|
}
|
|
605
|
+
/** The shape of the array. */
|
|
391
606
|
get shape() {
|
|
392
607
|
return this.aval.shape;
|
|
393
608
|
}
|
|
609
|
+
/** The total number of elements in the array. */
|
|
610
|
+
get size() {
|
|
611
|
+
return require_backend.prod(this.shape);
|
|
612
|
+
}
|
|
613
|
+
/** The dtype of the array. */
|
|
394
614
|
get dtype() {
|
|
395
615
|
return this.aval.dtype;
|
|
396
616
|
}
|
|
617
|
+
/** The number of dimensions of the array. */
|
|
397
618
|
get ndim() {
|
|
398
619
|
return this.shape.length;
|
|
399
620
|
}
|
|
@@ -429,22 +650,20 @@ var Tracer = class Tracer {
|
|
|
429
650
|
return lessEqual$1(this, other);
|
|
430
651
|
}
|
|
431
652
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
432
|
-
sum(axis, opts) {
|
|
653
|
+
sum(axis = null, opts) {
|
|
433
654
|
return reduce(this, require_backend.AluOp.Add, axis, opts);
|
|
434
655
|
}
|
|
435
656
|
/** Product of the array elements over a given axis. */
|
|
436
|
-
prod(axis, opts) {
|
|
657
|
+
prod(axis = null, opts) {
|
|
437
658
|
return reduce(this, require_backend.AluOp.Mul, axis, opts);
|
|
438
659
|
}
|
|
439
660
|
/** Compute the average of the array elements along the specified axis. */
|
|
440
|
-
mean(axis, opts) {
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
447
|
-
return result;
|
|
661
|
+
mean(axis = null, opts) {
|
|
662
|
+
axis = require_backend.normalizeAxis(axis, this.ndim);
|
|
663
|
+
const n = axis.reduce((acc, a) => acc * this.shape[a], 1);
|
|
664
|
+
if (n === 0) throw new Error("mean: cannot compute mean over zero-length axis");
|
|
665
|
+
const result = reduce(this, require_backend.AluOp.Add, axis, opts);
|
|
666
|
+
return result.mul(1 / n);
|
|
448
667
|
}
|
|
449
668
|
/** Permute the dimensions of an array. Defaults to reversing the axis order. */
|
|
450
669
|
transpose(perm) {
|
|
@@ -476,8 +695,29 @@ var Tracer = class Tracer {
|
|
|
476
695
|
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
477
696
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
478
697
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
698
|
+
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
699
|
+
axis1 = require_backend.checkAxis(axis1, this.ndim);
|
|
700
|
+
axis2 = require_backend.checkAxis(axis2, this.ndim);
|
|
479
701
|
if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
|
|
480
|
-
throw new Error("
|
|
702
|
+
if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
|
|
703
|
+
let ar = this;
|
|
704
|
+
if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
|
|
705
|
+
const perm = require_backend.range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
|
|
706
|
+
ar = ar.transpose(perm);
|
|
707
|
+
}
|
|
708
|
+
const [n, m] = ar.shape.slice(-2);
|
|
709
|
+
const diagSize = Math.min(n, m - offset);
|
|
710
|
+
ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
|
|
711
|
+
const npad = diagSize * (m + 1) - n * m;
|
|
712
|
+
if (npad > 0) ar = pad$1(ar, [...require_backend.rep(ar.ndim - 1, [0, 0]), [0, npad]]);
|
|
713
|
+
else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
|
|
714
|
+
ar = ar.reshape([
|
|
715
|
+
...ar.shape.slice(0, -1),
|
|
716
|
+
diagSize,
|
|
717
|
+
m + 1
|
|
718
|
+
]);
|
|
719
|
+
ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
|
|
720
|
+
return ar;
|
|
481
721
|
}
|
|
482
722
|
/** Flatten the array without changing its data. */
|
|
483
723
|
flatten() {
|
|
@@ -620,7 +860,7 @@ var ShapedArray = class ShapedArray {
|
|
|
620
860
|
get ndim() {
|
|
621
861
|
return this.shape.length;
|
|
622
862
|
}
|
|
623
|
-
|
|
863
|
+
toString() {
|
|
624
864
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
625
865
|
}
|
|
626
866
|
equals(other) {
|
|
@@ -651,7 +891,7 @@ function fullRaise(trace, val) {
|
|
|
651
891
|
if (Object.is(val._trace.main, trace.main)) return val;
|
|
652
892
|
else if (val._trace.main.level < level) return trace.lift(val);
|
|
653
893
|
else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
|
|
654
|
-
else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
|
|
894
|
+
else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
|
|
655
895
|
}
|
|
656
896
|
var TreeMismatchError = class extends TypeError {
|
|
657
897
|
constructor(where$2, left, right) {
|
|
@@ -900,16 +1140,16 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
900
1140
|
jitCompileCache.set(cacheKey, jp);
|
|
901
1141
|
return jp;
|
|
902
1142
|
}
|
|
903
|
-
/**
 * Rewrite every `GlobalView` node of `exp$2` so it reads through a new shape
 * tracker.
 *
 * `mapping` receives each view's shape tracker; a falsy result leaves that
 * node untouched. Indices are regenerated from `gidx` over the new shape.
 * With `reduceAxis` set, the last dimension is indexed by `ridx` and only the
 * preceding dimensions are unravelled from `gidx`.
 *
 * @throws {Error} If the expression contains a `GlobalIndex` node.
 */
function reshapeViews(exp$2, mapping, reduceAxis = false) {
  const rewriteNode = (node) => {
    if (node.op === require_backend.AluOp.GlobalView) {
      const [gid, st] = node.arg;
      const newSt = mapping(st);
      if (!newSt) return undefined;
      let indices;
      if (reduceAxis) {
        const outer = require_backend.unravelAlu(newSt.shape.slice(0, -1), require_backend.AluVar.gidx);
        indices = outer.concat(require_backend.AluVar.ridx);
      } else {
        indices = require_backend.unravelAlu(newSt.shape, require_backend.AluVar.gidx);
      }
      return require_backend.AluExp.globalView(node.dtype, gid, newSt, indices);
    }
    if (node.op === require_backend.AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
    return undefined;
  };
  return exp$2.rewrite(rewriteNode);
}
|
|
915
1155
|
function broadcastedJit(fn) {
|
|
@@ -956,8 +1196,11 @@ const jitRules = {
|
|
|
956
1196
|
},
|
|
957
1197
|
[Primitive.Sin]: unopJit(require_backend.AluExp.sin),
|
|
958
1198
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
1199
|
+
[Primitive.Asin]: unopJit(require_backend.AluExp.asin),
|
|
1200
|
+
[Primitive.Atan]: unopJit(require_backend.AluExp.atan),
|
|
959
1201
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
960
1202
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1203
|
+
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
961
1204
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
962
1205
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
963
1206
|
[Primitive.Reduce](nargs, [a], [as], { op, axis }) {
|
|
@@ -972,18 +1215,20 @@ const jitRules = {
|
|
|
972
1215
|
const size$1 = require_backend.prod(newShape);
|
|
973
1216
|
const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
974
1217
|
newShape.push(reductionSize);
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
const [gid, st] = exp$2.arg;
|
|
978
|
-
const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
|
|
979
|
-
const indices = require_backend.unravelAlu(newShape.slice(0, -1), require_backend.AluVar.gidx);
|
|
980
|
-
indices.push(require_backend.AluVar.ridx);
|
|
981
|
-
return require_backend.AluExp.globalView(exp$2.dtype, gid, newSt, indices);
|
|
982
|
-
}
|
|
983
|
-
});
|
|
1218
|
+
const perm = keptAxes.concat(shiftedAxes);
|
|
1219
|
+
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
984
1220
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
985
1221
|
return new require_backend.Kernel(nargs, size$1, a, reduction);
|
|
986
1222
|
},
|
|
1223
|
+
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1224
|
+
[Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
|
|
1225
|
+
let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1226
|
+
const size$1 = require_backend.prod(inShape);
|
|
1227
|
+
stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
|
|
1228
|
+
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1229
|
+
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1230
|
+
return new require_backend.Kernel(nargs, size$1, a, reduction);
|
|
1231
|
+
},
|
|
987
1232
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
988
1233
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
989
1234
|
const c = k1.exp;
|
|
@@ -993,6 +1238,14 @@ const jitRules = {
|
|
|
993
1238
|
axis: [cs.ndim - 1]
|
|
994
1239
|
});
|
|
995
1240
|
},
|
|
1241
|
+
[Primitive.Conv](nargs, [a, b], [as, bs], params) {
|
|
1242
|
+
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1243
|
+
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1244
|
+
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1245
|
+
as = new ShapedArray(stX.shape, as.dtype);
|
|
1246
|
+
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1247
|
+
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1248
|
+
},
|
|
996
1249
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
997
1250
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
|
|
998
1251
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
@@ -1005,8 +1258,20 @@ const jitRules = {
|
|
|
1005
1258
|
}),
|
|
1006
1259
|
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
1007
1260
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1008
|
-
[Primitive.Gather]() {
|
|
1009
|
-
|
|
1261
|
+
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1262
|
+
const axisSet = new Set(axis);
|
|
1263
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1264
|
+
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1265
|
+
finalShape.splice(outDim, 0, ...indexShape);
|
|
1266
|
+
const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
|
|
1267
|
+
const idxNonaxis = [...idxAll];
|
|
1268
|
+
idxNonaxis.splice(outDim, indexShape.length);
|
|
1269
|
+
const src = [...idxNonaxis];
|
|
1270
|
+
for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1271
|
+
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
1272
|
+
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1273
|
+
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1274
|
+
return new require_backend.Kernel(nargs, require_backend.prod(finalShape), x.substitute({ gidx: index }));
|
|
1010
1275
|
},
|
|
1011
1276
|
[Primitive.JitCall]() {
|
|
1012
1277
|
throw new Error("internal: JitCall should have been flattened before JIT compilation");
|
|
@@ -1025,9 +1290,15 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1025
1290
|
blackNodes.add(v);
|
|
1026
1291
|
p1NextBlack.set(v, v);
|
|
1027
1292
|
}
|
|
1293
|
+
const reducePrimitives = [
|
|
1294
|
+
Primitive.Reduce,
|
|
1295
|
+
Primitive.Dot,
|
|
1296
|
+
Primitive.Conv,
|
|
1297
|
+
Primitive.PoolTranspose
|
|
1298
|
+
];
|
|
1028
1299
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1029
1300
|
const eqn = jaxpr.eqns[i];
|
|
1030
|
-
if (eqn.primitive === Primitive.
|
|
1301
|
+
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1031
1302
|
for (const v of eqn.outBinders) {
|
|
1032
1303
|
blackNodes.add(v);
|
|
1033
1304
|
p1NextBlack.set(v, v);
|
|
@@ -1168,7 +1439,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1168
1439
|
* is a backend `Slot`, this constructor _takes ownership_ of the slot. It
|
|
1169
1440
|
* will be freed when the array is disposed.
|
|
1170
1441
|
*/
|
|
1171
|
-
constructor(source, st, dtype, backend, pending = null) {
|
|
1442
|
+
constructor(source, st, dtype, backend, { pending = null } = {}) {
|
|
1172
1443
|
super(baseArrayTrace);
|
|
1173
1444
|
this.id = Array$1.#nextId++;
|
|
1174
1445
|
this.#dtype = dtype;
|
|
@@ -1177,6 +1448,8 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1177
1448
|
this.#backend = backend;
|
|
1178
1449
|
this.#rc = 1;
|
|
1179
1450
|
this.#pendingSet = new Set(pending);
|
|
1451
|
+
if (this.#pendingSet.size === 0) this.#pendingSet = null;
|
|
1452
|
+
else if (source instanceof require_backend.AluExp) throw new Error("internal: AluExp source cannot have pending executes");
|
|
1180
1453
|
}
|
|
1181
1454
|
/** @ignore */
|
|
1182
1455
|
get aval() {
|
|
@@ -1231,7 +1504,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1231
1504
|
const pending = this.#pending;
|
|
1232
1505
|
for (const exe of pending) exe.updateRc(1);
|
|
1233
1506
|
if (typeof this.#source === "number") this.#backend.incRef(this.#source);
|
|
1234
|
-
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, pending);
|
|
1507
|
+
const ar = new Array$1(this.#source, st, this.#dtype, this.#backend, { pending });
|
|
1235
1508
|
this.dispose();
|
|
1236
1509
|
return ar;
|
|
1237
1510
|
}
|
|
@@ -1254,7 +1527,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1254
1527
|
const inputs = [];
|
|
1255
1528
|
const src = [...idxNonaxis];
|
|
1256
1529
|
for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1257
|
-
for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#
|
|
1530
|
+
for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#source, ar.#st, idxAxis));
|
|
1258
1531
|
else {
|
|
1259
1532
|
let gid = inputs.indexOf(ar.#source);
|
|
1260
1533
|
if (gid === -1) {
|
|
@@ -1264,7 +1537,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1264
1537
|
src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
|
|
1265
1538
|
}
|
|
1266
1539
|
let exp$2;
|
|
1267
|
-
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#
|
|
1540
|
+
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, src);
|
|
1268
1541
|
else {
|
|
1269
1542
|
let gid = inputs.indexOf(this.#source);
|
|
1270
1543
|
if (gid === -1) {
|
|
@@ -1280,7 +1553,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1280
1553
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1281
1554
|
this.dispose();
|
|
1282
1555
|
for (const ar of indices) ar.dispose();
|
|
1283
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, pending);
|
|
1556
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(finalShape), this.#dtype, this.#backend, { pending });
|
|
1284
1557
|
}
|
|
1285
1558
|
/** Move axes to the rightmost dimension of the shape. */
|
|
1286
1559
|
#moveAxesDown(axis) {
|
|
@@ -1307,7 +1580,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1307
1580
|
this.#check();
|
|
1308
1581
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1309
1582
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1310
|
-
return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
|
|
1583
|
+
return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
|
|
1311
1584
|
}
|
|
1312
1585
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1313
1586
|
const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1317,7 +1590,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1317
1590
|
for (const exe of pending) exe.updateRc(1);
|
|
1318
1591
|
pending.push(new PendingExecute(this.#backend, kernel, [this.#source], [output]));
|
|
1319
1592
|
this.dispose();
|
|
1320
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, pending);
|
|
1593
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(this.shape), dtypeOutput, this.#backend, { pending });
|
|
1321
1594
|
}
|
|
1322
1595
|
#binary(op, other) {
|
|
1323
1596
|
const custom = (src) => new require_backend.AluExp(op, this.#dtype, src);
|
|
@@ -1340,18 +1613,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1340
1613
|
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1341
1614
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1342
1615
|
const newShape = [...arrays[0].shape];
|
|
1343
|
-
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && reduceAxis
|
|
1616
|
+
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
1344
1617
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1345
1618
|
const exp$4 = custom(arrays.map((ar) => ar.#source));
|
|
1346
|
-
return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
|
|
1619
|
+
return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
|
|
1347
1620
|
}
|
|
1348
1621
|
const exp$3 = custom(arrays.map((ar) => {
|
|
1349
1622
|
const src$1 = ar.#source;
|
|
1350
1623
|
if (ar.#st.contiguous) return src$1;
|
|
1351
|
-
return require_backend.accessorAluExp(
|
|
1624
|
+
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1352
1625
|
}));
|
|
1353
1626
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1354
|
-
return new Array$1(exp$3, st, exp$3.dtype, backend);
|
|
1627
|
+
return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
|
|
1355
1628
|
}
|
|
1356
1629
|
let indices;
|
|
1357
1630
|
if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
|
|
@@ -1361,7 +1634,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1361
1634
|
}
|
|
1362
1635
|
const inputs = [];
|
|
1363
1636
|
const src = [];
|
|
1364
|
-
for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#
|
|
1637
|
+
for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
|
|
1365
1638
|
else {
|
|
1366
1639
|
let gid = inputs.indexOf(ar.#source);
|
|
1367
1640
|
if (gid === -1) {
|
|
@@ -1382,7 +1655,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1382
1655
|
for (const exe of pending) exe.updateRc(1);
|
|
1383
1656
|
pending.add(new PendingExecute(backend, kernel, inputs, [output]));
|
|
1384
1657
|
for (const ar of arrays) ar.dispose();
|
|
1385
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, pending);
|
|
1658
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), dtypeOutput, backend, { pending });
|
|
1386
1659
|
}
|
|
1387
1660
|
/** Reduce the last dimension of the array by an operation. */
|
|
1388
1661
|
#reduce(op) {
|
|
@@ -1395,7 +1668,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1395
1668
|
const indices = [...require_backend.unravelAlu(newShape, require_backend.AluVar.gidx), require_backend.AluVar.ridx];
|
|
1396
1669
|
let exp$2;
|
|
1397
1670
|
const inputs = [];
|
|
1398
|
-
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#
|
|
1671
|
+
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
|
|
1399
1672
|
else {
|
|
1400
1673
|
inputs.push(this.#source);
|
|
1401
1674
|
exp$2 = require_backend.accessorGlobal(this.#dtype, 0, this.#st, indices);
|
|
@@ -1406,7 +1679,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1406
1679
|
for (const exe of pending) exe.updateRc(1);
|
|
1407
1680
|
pending.push(new PendingExecute(this.#backend, kernel, inputs, [output]));
|
|
1408
1681
|
this.dispose();
|
|
1409
|
-
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, pending);
|
|
1682
|
+
return new Array$1(output, require_backend.ShapeTracker.fromShape(newShape), this.#dtype, this.#backend, { pending });
|
|
1410
1683
|
}
|
|
1411
1684
|
/**
|
|
1412
1685
|
* Normalizes this array into one backed by a `Slot`.
|
|
@@ -1420,7 +1693,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1420
1693
|
this.#check();
|
|
1421
1694
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1422
1695
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1423
|
-
const exp$2 = require_backend.accessorAluExp(this.#
|
|
1696
|
+
const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
|
|
1424
1697
|
const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
|
|
1425
1698
|
const output = this.#backend.malloc(kernel.bytes);
|
|
1426
1699
|
const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
|
|
@@ -1458,42 +1731,54 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1458
1731
|
}
|
|
1459
1732
|
/** Realize the array and return it as data. */
|
|
1460
1733
|
async data() {
|
|
1461
|
-
if (this.#source instanceof require_backend.AluExp &&
|
|
1734
|
+
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1462
1735
|
this.#realize();
|
|
1463
1736
|
const pending = this.#pending;
|
|
1464
1737
|
if (pending) {
|
|
1465
1738
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1466
1739
|
for (const p of pending) p.submit();
|
|
1467
1740
|
}
|
|
1468
|
-
const byteCount = require_backend.byteWidth(this.#dtype) *
|
|
1741
|
+
const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
|
|
1469
1742
|
const buf = await this.#backend.read(this.#source, 0, byteCount);
|
|
1470
1743
|
this.dispose();
|
|
1471
1744
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
1472
1745
|
}
|
|
1473
|
-
/**
|
|
1474
|
-
|
|
1746
|
+
/**
|
|
1747
|
+
* Wait for this array to finish evaluation.
|
|
1748
|
+
*
|
|
1749
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1750
|
+
* that pending operations are dispatched and fully executed before it
|
|
1751
|
+
* returns.
|
|
1752
|
+
*
|
|
1753
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1754
|
+
* dispatch of operations as well.
|
|
1755
|
+
*
|
|
1756
|
+
* **Note:** `jax.blockUntilReady()` is a higher-level API, it calls this
|
|
1757
|
+
* asynchronously for multiple arrays.
|
|
1758
|
+
*/
|
|
1759
|
+
async blockUntilReady() {
|
|
1475
1760
|
this.#check();
|
|
1476
|
-
if (this.#source instanceof require_backend.AluExp) return;
|
|
1761
|
+
if (this.#source instanceof require_backend.AluExp) return this;
|
|
1477
1762
|
const pending = this.#pending;
|
|
1478
1763
|
if (pending) {
|
|
1479
1764
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1480
1765
|
for (const p of pending) p.submit();
|
|
1481
1766
|
}
|
|
1482
1767
|
await this.#backend.read(this.#source, 0, 0);
|
|
1483
|
-
this
|
|
1768
|
+
return this;
|
|
1484
1769
|
}
|
|
1485
1770
|
/**
|
|
1486
1771
|
* Realize the array and return it as data. This is a sync variant and not
|
|
1487
1772
|
* recommended for performance reasons, as it will block rendering.
|
|
1488
1773
|
*/
|
|
1489
1774
|
dataSync() {
|
|
1490
|
-
if (this.#source instanceof require_backend.AluExp &&
|
|
1775
|
+
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1491
1776
|
this.#realize();
|
|
1492
1777
|
for (const p of this.#pending) {
|
|
1493
1778
|
p.prepareSync();
|
|
1494
1779
|
p.submit();
|
|
1495
1780
|
}
|
|
1496
|
-
const byteCount = require_backend.byteWidth(this.#dtype) *
|
|
1781
|
+
const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
|
|
1497
1782
|
const buf = this.#backend.readSync(this.#source, 0, byteCount);
|
|
1498
1783
|
this.dispose();
|
|
1499
1784
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
@@ -1514,6 +1799,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1514
1799
|
async jsAsync() {
|
|
1515
1800
|
return dataToJs(this.dtype, await this.data(), this.shape);
|
|
1516
1801
|
}
|
|
1802
|
+
/**
|
|
1803
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
1804
|
+
*
|
|
1805
|
+
* Throws an error if the array does not have a single element. The array must
|
|
1806
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
1807
|
+
*/
|
|
1808
|
+
item() {
|
|
1809
|
+
if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
|
|
1810
|
+
return this.dataSync()[0];
|
|
1811
|
+
}
|
|
1517
1812
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1518
1813
|
static _implRules() {
|
|
1519
1814
|
return {
|
|
@@ -1527,7 +1822,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1527
1822
|
return [x.#binary(require_backend.AluOp.Idiv, y)];
|
|
1528
1823
|
},
|
|
1529
1824
|
[Primitive.Neg]([x]) {
|
|
1530
|
-
return [zerosLike(x).#binary(require_backend.AluOp.Sub, x)];
|
|
1825
|
+
return [zerosLike$1(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
1531
1826
|
},
|
|
1532
1827
|
[Primitive.Reciprocal]([x]) {
|
|
1533
1828
|
return [x.#unary(require_backend.AluOp.Reciprocal)];
|
|
@@ -1547,7 +1842,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1547
1842
|
x.#backend.incRef(x.#source);
|
|
1548
1843
|
const pending = x.#pending;
|
|
1549
1844
|
for (const exe of pending) exe.updateRc(1);
|
|
1550
|
-
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, pending);
|
|
1845
|
+
const y = new Array$1(x.#source, x.#st, dtype, x.#backend, { pending });
|
|
1551
1846
|
x.dispose();
|
|
1552
1847
|
return [y];
|
|
1553
1848
|
}
|
|
@@ -1577,12 +1872,21 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1577
1872
|
[Primitive.Cos]([x]) {
|
|
1578
1873
|
return [x.#unary(require_backend.AluOp.Cos)];
|
|
1579
1874
|
},
|
|
1875
|
+
[Primitive.Asin]([x]) {
|
|
1876
|
+
return [x.#unary(require_backend.AluOp.Asin)];
|
|
1877
|
+
},
|
|
1878
|
+
[Primitive.Atan]([x]) {
|
|
1879
|
+
return [x.#unary(require_backend.AluOp.Atan)];
|
|
1880
|
+
},
|
|
1580
1881
|
[Primitive.Exp]([x]) {
|
|
1581
1882
|
return [x.#unary(require_backend.AluOp.Exp)];
|
|
1582
1883
|
},
|
|
1583
1884
|
[Primitive.Log]([x]) {
|
|
1584
1885
|
return [x.#unary(require_backend.AluOp.Log)];
|
|
1585
1886
|
},
|
|
1887
|
+
[Primitive.Sqrt]([x]) {
|
|
1888
|
+
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
1889
|
+
},
|
|
1586
1890
|
[Primitive.Min]([x, y]) {
|
|
1587
1891
|
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
1588
1892
|
},
|
|
@@ -1593,9 +1897,24 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1593
1897
|
if (axis.length === 0) return [x];
|
|
1594
1898
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
1595
1899
|
},
|
|
1900
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
1901
|
+
const st = pool(x.#st, window, strides);
|
|
1902
|
+
return [x.#reshape(st)];
|
|
1903
|
+
},
|
|
1904
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
1905
|
+
const n = inShape.length;
|
|
1906
|
+
let st = poolTranspose(x.#st, inShape, window, strides);
|
|
1907
|
+
st = st.reshape([...st.shape.slice(0, n), require_backend.prod(st.shape.slice(n))]);
|
|
1908
|
+
return [x.#reshape(st).#reduce(require_backend.AluOp.Add)];
|
|
1909
|
+
},
|
|
1596
1910
|
[Primitive.Dot]([x, y]) {
|
|
1597
1911
|
return [Array$1.#naryCustom("dot", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
|
|
1598
1912
|
},
|
|
1913
|
+
[Primitive.Conv]([x, y], params) {
|
|
1914
|
+
checkConvShape(x.shape, y.shape, params);
|
|
1915
|
+
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
1916
|
+
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
1917
|
+
},
|
|
1599
1918
|
[Primitive.Compare]([x, y], { op }) {
|
|
1600
1919
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1601
1920
|
return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
|
|
@@ -1644,7 +1963,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1644
1963
|
pending.splice(0, 0, ...prevPending);
|
|
1645
1964
|
args.forEach((x) => x.dispose());
|
|
1646
1965
|
return outputs.map((source, i) => {
|
|
1647
|
-
return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, pending);
|
|
1966
|
+
return new Array$1(source, require_backend.ShapeTracker.fromShape(jaxpr.outs[i].aval.shape), jaxpr.outs[i].aval.dtype, backend, { pending });
|
|
1648
1967
|
});
|
|
1649
1968
|
}
|
|
1650
1969
|
};
|
|
@@ -1660,6 +1979,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1660
1979
|
dtype ??= require_backend.DType.Float32;
|
|
1661
1980
|
if (![
|
|
1662
1981
|
require_backend.DType.Float32,
|
|
1982
|
+
require_backend.DType.Float16,
|
|
1663
1983
|
require_backend.DType.Int32,
|
|
1664
1984
|
require_backend.DType.Uint32
|
|
1665
1985
|
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
@@ -1667,6 +1987,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1667
1987
|
dtype ??= require_backend.DType.Bool;
|
|
1668
1988
|
if (![
|
|
1669
1989
|
require_backend.DType.Float32,
|
|
1990
|
+
require_backend.DType.Float16,
|
|
1670
1991
|
require_backend.DType.Int32,
|
|
1671
1992
|
require_backend.DType.Uint32,
|
|
1672
1993
|
require_backend.DType.Bool
|
|
@@ -1680,7 +2001,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1680
2001
|
if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1681
2002
|
if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
|
|
1682
2003
|
return values;
|
|
1683
|
-
} else if (values
|
|
2004
|
+
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1684
2005
|
dtype,
|
|
1685
2006
|
device
|
|
1686
2007
|
});
|
|
@@ -1709,7 +2030,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1709
2030
|
});
|
|
1710
2031
|
} else {
|
|
1711
2032
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
1712
|
-
const data = require_backend.
|
|
2033
|
+
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
1713
2034
|
return arrayFromData(data, shape$1, {
|
|
1714
2035
|
dtype,
|
|
1715
2036
|
device
|
|
@@ -1730,19 +2051,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
|
|
|
1730
2051
|
});
|
|
1731
2052
|
}
|
|
1732
2053
|
const backend = require_backend.getBackend(device);
|
|
1733
|
-
if (data
|
|
1734
|
-
|
|
1735
|
-
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
if (
|
|
1739
|
-
|
|
1740
|
-
|
|
1741
|
-
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
2054
|
+
if (ArrayBuffer.isView(data)) {
|
|
2055
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2056
|
+
if (data instanceof Float32Array) {
|
|
2057
|
+
if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2058
|
+
dtype ??= require_backend.DType.Float32;
|
|
2059
|
+
} else if (data instanceof Int32Array) {
|
|
2060
|
+
if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2061
|
+
dtype ??= require_backend.DType.Int32;
|
|
2062
|
+
} else if (data instanceof Uint32Array) {
|
|
2063
|
+
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2064
|
+
dtype ??= require_backend.DType.Uint32;
|
|
2065
|
+
} else if (data instanceof Float16Array) {
|
|
2066
|
+
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2067
|
+
dtype ??= require_backend.DType.Float16;
|
|
2068
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2069
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2070
|
+
return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2071
|
+
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
1746
2072
|
}
|
|
1747
2073
|
function dataToJs(dtype, data, shape$1) {
|
|
1748
2074
|
if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -1769,9 +2095,20 @@ var EvalTrace = class extends Trace {
|
|
|
1769
2095
|
};
|
|
1770
2096
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
1771
2097
|
const implRules = Array$1._implRules();
|
|
1772
|
-
function zerosLike(val) {
|
|
2098
|
+
function zerosLike$1(val, dtype) {
|
|
1773
2099
|
const aval = getAval(val);
|
|
1774
|
-
|
|
2100
|
+
if (val instanceof Tracer) val.dispose();
|
|
2101
|
+
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2102
|
+
}
|
|
2103
|
+
function onesLike$1(val, dtype) {
|
|
2104
|
+
const aval = getAval(val);
|
|
2105
|
+
if (val instanceof Tracer) val.dispose();
|
|
2106
|
+
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2107
|
+
}
|
|
2108
|
+
function fullLike(val, fillValue, dtype) {
|
|
2109
|
+
const aval = getAval(val);
|
|
2110
|
+
if (val instanceof Tracer) val.dispose();
|
|
2111
|
+
return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
|
|
1775
2112
|
}
|
|
1776
2113
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1777
2114
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -1793,6 +2130,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
1793
2130
|
if (typeof fillValue === "number") {
|
|
1794
2131
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
1795
2132
|
source = require_backend.AluExp.const(dtype, fillValue);
|
|
2133
|
+
} else if (typeof fillValue === "bigint") {
|
|
2134
|
+
dtype = dtype ?? require_backend.DType.Int32;
|
|
2135
|
+
source = require_backend.AluExp.const(dtype, Number(fillValue));
|
|
1796
2136
|
} else if (typeof fillValue === "boolean") {
|
|
1797
2137
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
1798
2138
|
source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
|
|
@@ -1823,7 +2163,7 @@ function eye(numRows, numCols, { dtype, device } = {}) {
|
|
|
1823
2163
|
const exp$2 = require_backend.AluExp.cmplt(require_backend.AluExp.mod(require_backend.AluVar.idx, require_backend.AluExp.i32(numCols + 1)), require_backend.AluExp.i32(1));
|
|
1824
2164
|
return new Array$1(require_backend.AluExp.cast(dtype, exp$2), require_backend.ShapeTracker.fromShape([numRows, numCols]), dtype, require_backend.getBackend(device));
|
|
1825
2165
|
}
|
|
1826
|
-
/** Return the identity
|
|
2166
|
+
/** Return the identity matrix, with ones on the main diagonal. */
|
|
1827
2167
|
function identity$1(n, { dtype, device } = {}) {
|
|
1828
2168
|
return eye(n, n, {
|
|
1829
2169
|
dtype,
|
|
@@ -1890,7 +2230,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
1890
2230
|
const st = require_backend.ShapeTracker.fromShape([num]);
|
|
1891
2231
|
return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
|
|
1892
2232
|
}
|
|
1893
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
1894
2233
|
function aluCompare(a, b, op) {
|
|
1895
2234
|
switch (op) {
|
|
1896
2235
|
case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
|
|
@@ -1932,8 +2271,8 @@ function generalBroadcast(a, b) {
|
|
|
1932
2271
|
}
|
|
1933
2272
|
|
|
1934
2273
|
//#endregion
|
|
1935
|
-
//#region node_modules/.pnpm/@oxc-project+runtime@0.
|
|
1936
|
-
var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.
|
|
2274
|
+
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
|
|
2275
|
+
var require_usingCtx = /* @__PURE__ */ __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js": ((exports, module) => {
|
|
1937
2276
|
function _usingCtx() {
|
|
1938
2277
|
var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
|
|
1939
2278
|
var n$1 = Error();
|
|
@@ -1989,11 +2328,11 @@ var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.7
|
|
|
1989
2328
|
};
|
|
1990
2329
|
}
|
|
1991
2330
|
module.exports = _usingCtx, module.exports.__esModule = true, module.exports["default"] = module.exports;
|
|
1992
|
-
} });
|
|
2331
|
+
}) });
|
|
1993
2332
|
|
|
1994
2333
|
//#endregion
|
|
1995
2334
|
//#region src/frontend/jaxpr.ts
|
|
1996
|
-
var import_usingCtx$2 = __toESM(require_usingCtx(), 1);
|
|
2335
|
+
var import_usingCtx$2 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
1997
2336
|
/** Variable in a Jaxpr expression. */
|
|
1998
2337
|
var Var = class Var {
|
|
1999
2338
|
static #nextId = 1;
|
|
@@ -2004,7 +2343,7 @@ var Var = class Var {
|
|
|
2004
2343
|
this.aval = aval;
|
|
2005
2344
|
}
|
|
2006
2345
|
toString() {
|
|
2007
|
-
return `Var(${this.id}):${this.aval.
|
|
2346
|
+
return `Var(${this.id}):${this.aval.toString()}`;
|
|
2008
2347
|
}
|
|
2009
2348
|
};
|
|
2010
2349
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
@@ -2044,7 +2383,7 @@ var VarPrinter = class {
|
|
|
2044
2383
|
return name;
|
|
2045
2384
|
}
|
|
2046
2385
|
nameType(v) {
|
|
2047
|
-
return `${this.name(v)}:${v.aval.
|
|
2386
|
+
return `${this.name(v)}:${v.aval.toString()}`;
|
|
2048
2387
|
}
|
|
2049
2388
|
};
|
|
2050
2389
|
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
@@ -2104,16 +2443,19 @@ var Jaxpr = class Jaxpr {
|
|
|
2104
2443
|
varIds.set(v, require_backend.FpHash.hash(id, v.aval.dtype, ...v.aval.shape));
|
|
2105
2444
|
return id;
|
|
2106
2445
|
};
|
|
2107
|
-
hasher.update(this.inBinders.length
|
|
2108
|
-
|
|
2109
|
-
|
|
2110
|
-
|
|
2111
|
-
|
|
2112
|
-
|
|
2113
|
-
eqn.
|
|
2114
|
-
|
|
2115
|
-
|
|
2116
|
-
|
|
2446
|
+
hasher.update(this.inBinders.length);
|
|
2447
|
+
for (const x of this.inBinders) hasher.update(vi(x));
|
|
2448
|
+
hasher.update(this.eqns.length);
|
|
2449
|
+
for (const eqn of this.eqns) {
|
|
2450
|
+
hasher.update(eqn.primitive);
|
|
2451
|
+
hasher.update(eqn.inputs.length);
|
|
2452
|
+
for (const x of eqn.inputs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2453
|
+
hasher.update(JSON.stringify(eqn.params));
|
|
2454
|
+
hasher.update(eqn.outBinders.length);
|
|
2455
|
+
for (const x of eqn.outBinders) hasher.update(vi(x));
|
|
2456
|
+
}
|
|
2457
|
+
hasher.update(this.outs.length);
|
|
2458
|
+
for (const x of this.outs) hasher.update(x instanceof Var ? vi(x) : x.value);
|
|
2117
2459
|
return this.#hash = hasher.value;
|
|
2118
2460
|
}
|
|
2119
2461
|
hash(state) {
|
|
@@ -2150,7 +2492,7 @@ var Jaxpr = class Jaxpr {
|
|
|
2150
2492
|
const c = eqn.outBinders[0];
|
|
2151
2493
|
if (atomIsLit(b, 1)) context.set(c, a);
|
|
2152
2494
|
else newEqns.push(eqn);
|
|
2153
|
-
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2495
|
+
} else if ((eqn.primitive === Primitive.Broadcast || eqn.primitive === Primitive.Reshape) && require_backend.deepEqual(eqn.params.shape, eqn.inputs[0].aval.shape) || eqn.primitive === Primitive.Transpose && eqn.params.perm.every((p, i) => p === i) || eqn.primitive === Primitive.Flip && eqn.params.axis.length === 0 || eqn.primitive === Primitive.Shrink && eqn.params.slice.every(([s, e$2], i) => s === 0 && e$2 === eqn.inputs[0].aval.shape[i]) || eqn.primitive === Primitive.Pad && eqn.params.width.every(([w0, w1]) => w0 === 0 && w1 === 0)) context.set(eqn.outBinders[0], eqn.inputs[0]);
|
|
2154
2496
|
else newEqns.push(eqn);
|
|
2155
2497
|
}
|
|
2156
2498
|
const outs = this.outs.map((x) => x instanceof Var ? context.get(x) ?? x : x);
|
|
@@ -2199,8 +2541,8 @@ var JaxprType = class {
|
|
|
2199
2541
|
this.outTypes = outTypes;
|
|
2200
2542
|
}
|
|
2201
2543
|
toString() {
|
|
2202
|
-
const inTypes = this.inTypes.map((aval) => aval.
|
|
2203
|
-
const outTypes = this.outTypes.map((aval) => aval.
|
|
2544
|
+
const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
|
|
2545
|
+
const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
|
|
2204
2546
|
return `(${inTypes}) -> (${outTypes})`;
|
|
2205
2547
|
}
|
|
2206
2548
|
};
|
|
@@ -2279,7 +2621,7 @@ var JaxprTracer = class extends Tracer {
|
|
|
2279
2621
|
this.aval = aval;
|
|
2280
2622
|
}
|
|
2281
2623
|
toString() {
|
|
2282
|
-
return `JaxprTracer(${this.aval.
|
|
2624
|
+
return `JaxprTracer(${this.aval.toString()})`;
|
|
2283
2625
|
}
|
|
2284
2626
|
get ref() {
|
|
2285
2627
|
return this;
|
|
@@ -2416,8 +2758,11 @@ const abstractEvalRules = {
|
|
|
2416
2758
|
},
|
|
2417
2759
|
[Primitive.Sin]: vectorizedUnopAbstractEval,
|
|
2418
2760
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2761
|
+
[Primitive.Asin]: vectorizedUnopAbstractEval,
|
|
2762
|
+
[Primitive.Atan]: vectorizedUnopAbstractEval,
|
|
2419
2763
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2420
2764
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2765
|
+
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2421
2766
|
[Primitive.Min]: binopAbstractEval,
|
|
2422
2767
|
[Primitive.Max]: binopAbstractEval,
|
|
2423
2768
|
[Primitive.Reduce]([x], { axis }) {
|
|
@@ -2425,6 +2770,15 @@ const abstractEvalRules = {
|
|
|
2425
2770
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2426
2771
|
return [new ShapedArray(newShape, x.dtype)];
|
|
2427
2772
|
},
|
|
2773
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
2774
|
+
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2775
|
+
return [new ShapedArray(shape$1, x.dtype)];
|
|
2776
|
+
},
|
|
2777
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2778
|
+
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2779
|
+
if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2780
|
+
return [new ShapedArray(inShape, x.dtype)];
|
|
2781
|
+
},
|
|
2428
2782
|
[Primitive.Dot]([x, y]) {
|
|
2429
2783
|
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2430
2784
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
@@ -2432,6 +2786,11 @@ const abstractEvalRules = {
|
|
|
2432
2786
|
shape$1.splice(-1, 1);
|
|
2433
2787
|
return [new ShapedArray(shape$1, x.dtype)];
|
|
2434
2788
|
},
|
|
2789
|
+
[Primitive.Conv]([lhs, rhs], params) {
|
|
2790
|
+
if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
|
|
2791
|
+
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2792
|
+
return [new ShapedArray(shape$1, lhs.dtype)];
|
|
2793
|
+
},
|
|
2435
2794
|
[Primitive.Compare]: compareAbstractEval,
|
|
2436
2795
|
[Primitive.Where]([cond, x, y]) {
|
|
2437
2796
|
if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
@@ -2479,15 +2838,34 @@ const abstractEvalRules = {
|
|
|
2479
2838
|
return outTypes;
|
|
2480
2839
|
}
|
|
2481
2840
|
};
|
|
2482
|
-
function
|
|
2841
|
+
function splitIdx(values, argnums) {
|
|
2842
|
+
const a = [];
|
|
2843
|
+
const b = [];
|
|
2844
|
+
for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
|
|
2845
|
+
else b.push(values[i]);
|
|
2846
|
+
return [a, b];
|
|
2847
|
+
}
|
|
2848
|
+
function joinIdx(n, a, b, argnums) {
|
|
2849
|
+
const result = [];
|
|
2850
|
+
let ai = 0;
|
|
2851
|
+
let bi = 0;
|
|
2852
|
+
for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
|
|
2853
|
+
else result.push(b[bi++]);
|
|
2854
|
+
return result;
|
|
2855
|
+
}
|
|
2856
|
+
function makeJaxpr$1(f, opts) {
|
|
2483
2857
|
return (...argsIn) => {
|
|
2484
2858
|
try {
|
|
2485
2859
|
var _usingCtx$1 = (0, import_usingCtx$2.default)();
|
|
2486
|
-
const
|
|
2487
|
-
const [
|
|
2860
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2861
|
+
const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
|
|
2862
|
+
const [avalsIn, inTree] = flatten(shapedArgs);
|
|
2863
|
+
const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
|
|
2864
|
+
return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
|
|
2865
|
+
}, inTree);
|
|
2488
2866
|
const builder = new JaxprBuilder();
|
|
2489
2867
|
const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
|
|
2490
|
-
|
|
2868
|
+
_usingCtx$1.u(newDynamic(main));
|
|
2491
2869
|
const trace = new JaxprTrace(main);
|
|
2492
2870
|
const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
2493
2871
|
const outs = fFlat(...tracersIn);
|
|
@@ -2506,25 +2884,32 @@ function makeJaxpr$1(f) {
|
|
|
2506
2884
|
}
|
|
2507
2885
|
};
|
|
2508
2886
|
}
|
|
2509
|
-
function jit$1(f) {
|
|
2887
|
+
function jit$1(f, opts) {
|
|
2510
2888
|
const cache = /* @__PURE__ */ new Map();
|
|
2511
|
-
|
|
2512
|
-
|
|
2889
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2890
|
+
const result = ((...args) => {
|
|
2891
|
+
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2892
|
+
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2513
2893
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
2514
2894
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
2515
|
-
const
|
|
2516
|
-
const
|
|
2895
|
+
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
2896
|
+
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2897
|
+
const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2517
2898
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2518
2899
|
jaxpr,
|
|
2519
2900
|
numConsts: consts.length
|
|
2520
2901
|
});
|
|
2521
2902
|
return unflatten(outTree, outs);
|
|
2522
2903
|
});
|
|
2904
|
+
result.dispose = () => {
|
|
2905
|
+
for (const { consts } of cache.values()) for (const c of consts) c.dispose();
|
|
2906
|
+
};
|
|
2907
|
+
return result;
|
|
2523
2908
|
}
|
|
2524
2909
|
|
|
2525
2910
|
//#endregion
|
|
2526
2911
|
//#region src/frontend/jvp.ts
|
|
2527
|
-
var import_usingCtx$1 = __toESM(require_usingCtx(), 1);
|
|
2912
|
+
var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
2528
2913
|
var JVPTracer = class extends Tracer {
|
|
2529
2914
|
constructor(trace, primal, tangent) {
|
|
2530
2915
|
super(trace);
|
|
@@ -2551,7 +2936,7 @@ var JVPTrace = class extends Trace {
|
|
|
2551
2936
|
return this.lift(pureArray(val));
|
|
2552
2937
|
}
|
|
2553
2938
|
lift(val) {
|
|
2554
|
-
return new JVPTracer(this, val, zerosLike(val));
|
|
2939
|
+
return new JVPTracer(this, val, zerosLike$1(val.ref));
|
|
2555
2940
|
}
|
|
2556
2941
|
processPrimitive(primitive, tracers, params) {
|
|
2557
2942
|
const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2569,19 +2954,25 @@ function linearTangentsJvp(primitive) {
|
|
|
2569
2954
|
return [ys, dys];
|
|
2570
2955
|
};
|
|
2571
2956
|
}
|
|
2957
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
2958
|
+
function bilinearTangentsJvp(primitive) {
|
|
2959
|
+
return ([x, y], [dx, dy], params) => {
|
|
2960
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
2961
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
2962
|
+
return [[primal], [tangent]];
|
|
2963
|
+
};
|
|
2964
|
+
}
|
|
2572
2965
|
/** Rule that zeros out any tangents. */
|
|
2573
2966
|
function zeroTangentsJvp(primitive) {
|
|
2574
2967
|
return (primals, tangents, params) => {
|
|
2575
2968
|
for (const t of tangents) t.dispose();
|
|
2576
2969
|
const ys = bind(primitive, primals, params);
|
|
2577
|
-
return [ys, ys.map((y) => zerosLike(y))];
|
|
2970
|
+
return [ys, ys.map((y) => zerosLike$1(y.ref))];
|
|
2578
2971
|
};
|
|
2579
2972
|
}
|
|
2580
2973
|
const jvpRules = {
|
|
2581
2974
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
2582
|
-
[Primitive.Mul](
|
|
2583
|
-
return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
|
|
2584
|
-
},
|
|
2975
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
2585
2976
|
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
2586
2977
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
2587
2978
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
@@ -2594,13 +2985,13 @@ const jvpRules = {
|
|
|
2594
2985
|
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2595
2986
|
else {
|
|
2596
2987
|
dx.dispose();
|
|
2597
|
-
return [[cast(x, dtype)], [zerosLike(x)]];
|
|
2988
|
+
return [[cast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2598
2989
|
}
|
|
2599
2990
|
},
|
|
2600
2991
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2601
2992
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2602
2993
|
dx.dispose();
|
|
2603
|
-
return [[bitcast(x, dtype)], [zerosLike(x)]];
|
|
2994
|
+
return [[bitcast(x.ref, dtype)], [zerosLike$1(x)]];
|
|
2604
2995
|
},
|
|
2605
2996
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2606
2997
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2609,6 +3000,14 @@ const jvpRules = {
|
|
|
2609
3000
|
[Primitive.Cos]([x], [dx]) {
|
|
2610
3001
|
return [[cos$1(x.ref)], [neg(sin$1(x)).mul(dx)]];
|
|
2611
3002
|
},
|
|
3003
|
+
[Primitive.Asin]([x], [dx]) {
|
|
3004
|
+
const denom = sqrt$1(reciprocal$1(cast(1, x.dtype).sub(x.ref.mul(x.ref))));
|
|
3005
|
+
return [[asin$1(x)], [denom.mul(dx)]];
|
|
3006
|
+
},
|
|
3007
|
+
[Primitive.Atan]([x], [dx]) {
|
|
3008
|
+
const denom = cast(1, x.dtype).add(x.ref.mul(x.ref));
|
|
3009
|
+
return [[atan$1(x)], [dx.div(denom)]];
|
|
3010
|
+
},
|
|
2612
3011
|
[Primitive.Exp]([x], [dx]) {
|
|
2613
3012
|
const z = exp$1(x);
|
|
2614
3013
|
return [[z.ref], [z.mul(dx)]];
|
|
@@ -2616,6 +3015,10 @@ const jvpRules = {
|
|
|
2616
3015
|
[Primitive.Log]([x], [dx]) {
|
|
2617
3016
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
2618
3017
|
},
|
|
3018
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
3019
|
+
const z = sqrt$1(x);
|
|
3020
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
3021
|
+
},
|
|
2619
3022
|
[Primitive.Min]([x, y], [dx, dy]) {
|
|
2620
3023
|
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
2621
3024
|
},
|
|
@@ -2632,13 +3035,14 @@ const jvpRules = {
|
|
|
2632
3035
|
const primal = reduce(x.ref, op, axis);
|
|
2633
3036
|
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
2634
3037
|
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
2635
|
-
const tangent = where$1(notMin,
|
|
3038
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
2636
3039
|
return [[primal], [tangent]];
|
|
2637
3040
|
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
2638
3041
|
},
|
|
2639
|
-
[Primitive.
|
|
2640
|
-
|
|
2641
|
-
|
|
3042
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3043
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3044
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3045
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
2642
3046
|
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
2643
3047
|
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
2644
3048
|
dcond.dispose();
|
|
@@ -2711,7 +3115,7 @@ function jvp$1(f, primals, tangents) {
|
|
|
2711
3115
|
|
|
2712
3116
|
//#endregion
|
|
2713
3117
|
//#region src/frontend/vmap.ts
|
|
2714
|
-
var import_usingCtx = __toESM(require_usingCtx(), 1);
|
|
3118
|
+
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
2715
3119
|
function mappedAval(batchDim, aval) {
|
|
2716
3120
|
const shape$1 = [...aval.shape];
|
|
2717
3121
|
shape$1.splice(batchDim, 1);
|
|
@@ -2720,7 +3124,10 @@ function mappedAval(batchDim, aval) {
|
|
|
2720
3124
|
/** Move one axis to a different index. */
|
|
2721
3125
|
function moveaxis$1(x, src, dst) {
|
|
2722
3126
|
const t = pureArray(x);
|
|
2723
|
-
|
|
3127
|
+
src = require_backend.checkAxis(src, t.ndim);
|
|
3128
|
+
dst = require_backend.checkAxis(dst, t.ndim);
|
|
3129
|
+
if (src === dst) return t;
|
|
3130
|
+
const perm = require_backend.range(t.ndim);
|
|
2724
3131
|
perm.splice(src, 1);
|
|
2725
3132
|
perm.splice(dst, 0, src);
|
|
2726
3133
|
return transpose$1(t, perm);
|
|
@@ -2813,8 +3220,11 @@ const vmapRules = {
|
|
|
2813
3220
|
[Primitive.Bitcast]: unopBatcher((x, { dtype }) => bitcast(x, dtype)),
|
|
2814
3221
|
[Primitive.Sin]: unopBatcher(sin$1),
|
|
2815
3222
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
3223
|
+
[Primitive.Asin]: unopBatcher(asin$1),
|
|
3224
|
+
[Primitive.Atan]: unopBatcher(atan$1),
|
|
2816
3225
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
2817
3226
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3227
|
+
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
2818
3228
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
2819
3229
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
2820
3230
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
@@ -2951,7 +3361,7 @@ var PartialVal = class PartialVal {
|
|
|
2951
3361
|
return this.val !== null;
|
|
2952
3362
|
}
|
|
2953
3363
|
toString() {
|
|
2954
|
-
return this.val ? this.val.toString() : this.aval.
|
|
3364
|
+
return this.val ? this.val.toString() : this.aval.toString();
|
|
2955
3365
|
}
|
|
2956
3366
|
};
|
|
2957
3367
|
function partialEvalFlat(f, pvalsIn) {
|
|
@@ -2997,20 +3407,28 @@ function linearizeFlatUtil(f, primalsIn) {
|
|
|
2997
3407
|
function linearizeFlat(f, primalsIn) {
|
|
2998
3408
|
const { primalsOut, jaxpr, consts } = linearizeFlatUtil(f, primalsIn);
|
|
2999
3409
|
const fLin = (...tangents) => evalJaxpr(jaxpr, [...consts.map((c) => c.ref), ...tangents]);
|
|
3000
|
-
|
|
3410
|
+
const dispose$1 = () => {
|
|
3411
|
+
for (const c of consts) c.dispose();
|
|
3412
|
+
};
|
|
3413
|
+
return [
|
|
3414
|
+
primalsOut,
|
|
3415
|
+
fLin,
|
|
3416
|
+
dispose$1
|
|
3417
|
+
];
|
|
3001
3418
|
}
|
|
3002
3419
|
function linearize$1(f, ...primalsIn) {
|
|
3003
3420
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3004
3421
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3005
|
-
const [primalsOutFlat, fLinFlat] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3422
|
+
const [primalsOutFlat, fLinFlat, dispose$1] = linearizeFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3006
3423
|
if (outTree.value === void 0) throw new Error("outTree was not set in linearize");
|
|
3007
3424
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3008
|
-
const fLin = (...tangentsIn) => {
|
|
3425
|
+
const fLin = ((...tangentsIn) => {
|
|
3009
3426
|
const [tangentsInFlat, inTree2] = flatten(tangentsIn);
|
|
3010
3427
|
if (!inTree.equals(inTree2)) throw new TreeMismatchError("linearize", inTree, inTree2);
|
|
3011
3428
|
const tangentsOutFlat = fLinFlat(...tangentsInFlat.map(pureArray));
|
|
3012
3429
|
return unflatten(outTree.value, tangentsOutFlat);
|
|
3013
|
-
};
|
|
3430
|
+
});
|
|
3431
|
+
fLin.dispose = dispose$1;
|
|
3014
3432
|
return [primalsOut, fLin];
|
|
3015
3433
|
}
|
|
3016
3434
|
var PartialEvalTracer = class extends Tracer {
|
|
@@ -3126,7 +3544,10 @@ var PartialEvalTrace = class extends Trace {
|
|
|
3126
3544
|
avalsOut: jaxpr2.outs.map((x) => x.aval),
|
|
3127
3545
|
tracerRefsOut: []
|
|
3128
3546
|
};
|
|
3129
|
-
const outs2 = jaxpr2.outs.map((x) =>
|
|
3547
|
+
const outs2 = jaxpr2.outs.map((x, i$1) => {
|
|
3548
|
+
if (i$1 > 0) recipe.tracersIn.forEach((t) => t.ref);
|
|
3549
|
+
return new PartialEvalTracer(this, PartialVal.unknown(x.aval), recipe);
|
|
3550
|
+
});
|
|
3130
3551
|
recipe.tracerRefsOut = outs2.map((t) => new WeakRef(t));
|
|
3131
3552
|
let i = 0;
|
|
3132
3553
|
let j = 0;
|
|
@@ -3210,13 +3631,15 @@ function partialEvalGraphToJaxpr(tracersIn, tracersOut) {
|
|
|
3210
3631
|
const [consts, constvars] = require_backend.unzip2(constToVar.entries());
|
|
3211
3632
|
const inBinders = [...constvars, ...tracersIn.map((t) => tracerToVar.get(t))];
|
|
3212
3633
|
const outVars = tracersOut.map((t) => tracerToVar.get(t));
|
|
3213
|
-
|
|
3634
|
+
let jaxpr = new Jaxpr(inBinders, eqns, outVars);
|
|
3214
3635
|
typecheckJaxpr(jaxpr);
|
|
3215
3636
|
for (const t of consts) t.ref;
|
|
3216
3637
|
for (const t of tracersIn) t.dispose();
|
|
3217
3638
|
for (const t of tracersOut) t.dispose();
|
|
3639
|
+
jaxpr = jaxpr.simplify();
|
|
3640
|
+
if (require_backend.DEBUG >= 5) console.log("jaxpr from partial evaluation:\n" + jaxpr.toString());
|
|
3218
3641
|
return {
|
|
3219
|
-
jaxpr
|
|
3642
|
+
jaxpr,
|
|
3220
3643
|
consts
|
|
3221
3644
|
};
|
|
3222
3645
|
}
|
|
@@ -3325,12 +3748,72 @@ const transposeRules = {
|
|
|
3325
3748
|
if (op === require_backend.AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
|
|
3326
3749
|
else throw new NonlinearError(Primitive.Reduce);
|
|
3327
3750
|
},
|
|
3751
|
+
[Primitive.Pool]([ct], [x], { window, strides }) {
|
|
3752
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
|
|
3753
|
+
return bind(Primitive.PoolTranspose, [ct], {
|
|
3754
|
+
inShape: x.aval.shape,
|
|
3755
|
+
window,
|
|
3756
|
+
strides
|
|
3757
|
+
});
|
|
3758
|
+
},
|
|
3759
|
+
[Primitive.PoolTranspose]([ct], [x], { window, strides }) {
|
|
3760
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
|
|
3761
|
+
return bind(Primitive.Pool, [ct], {
|
|
3762
|
+
window,
|
|
3763
|
+
strides
|
|
3764
|
+
});
|
|
3765
|
+
},
|
|
3328
3766
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3329
3767
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3330
3768
|
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3331
3769
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3332
3770
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3333
3771
|
},
|
|
3772
|
+
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
3773
|
+
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
3774
|
+
const rev01 = [
|
|
3775
|
+
1,
|
|
3776
|
+
0,
|
|
3777
|
+
...require_backend.range(2, ct.ndim)
|
|
3778
|
+
];
|
|
3779
|
+
if (lhs instanceof UndefPrimal) {
|
|
3780
|
+
let kernel = rhs;
|
|
3781
|
+
kernel = transpose$1(kernel, rev01);
|
|
3782
|
+
kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
|
|
3783
|
+
const result = conv(ct, kernel, {
|
|
3784
|
+
strides: params.lhsDilation,
|
|
3785
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3786
|
+
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3787
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3788
|
+
const padBefore = dilatedKernel - 1 - pl;
|
|
3789
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3790
|
+
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
3791
|
+
return [padBefore, padAfter];
|
|
3792
|
+
}),
|
|
3793
|
+
lhsDilation: params.strides,
|
|
3794
|
+
rhsDilation: params.rhsDilation
|
|
3795
|
+
});
|
|
3796
|
+
return [result, null];
|
|
3797
|
+
} else {
|
|
3798
|
+
const newLhs = transpose$1(lhs, rev01);
|
|
3799
|
+
const newRhs = transpose$1(ct, rev01);
|
|
3800
|
+
let result = conv(newLhs, newRhs, {
|
|
3801
|
+
strides: params.rhsDilation,
|
|
3802
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3803
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3804
|
+
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3805
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3806
|
+
const padFromLhs = dilatedCt - dilatedLhs;
|
|
3807
|
+
const padFromRhs = dilatedKernel - pl - 1;
|
|
3808
|
+
return [pl, padFromLhs + padFromRhs];
|
|
3809
|
+
}),
|
|
3810
|
+
lhsDilation: params.lhsDilation,
|
|
3811
|
+
rhsDilation: params.strides
|
|
3812
|
+
});
|
|
3813
|
+
result = transpose$1(result, rev01);
|
|
3814
|
+
return [null, result];
|
|
3815
|
+
}
|
|
3816
|
+
},
|
|
3334
3817
|
[Primitive.Where]([ct], [cond, x, y]) {
|
|
3335
3818
|
const cts = [
|
|
3336
3819
|
null,
|
|
@@ -3422,20 +3905,28 @@ function vjpFlat(f, primalsIn) {
|
|
|
3422
3905
|
const transposeInputs = [...consts.map((c) => c.ref), ...primalsIn.map((t) => new UndefPrimal(t.aval))];
|
|
3423
3906
|
return evalJaxprTransposed(jaxpr, transposeInputs, cotangents);
|
|
3424
3907
|
};
|
|
3425
|
-
|
|
3908
|
+
const dispose$1 = () => {
|
|
3909
|
+
for (const c of consts) c.dispose();
|
|
3910
|
+
};
|
|
3911
|
+
return [
|
|
3912
|
+
primalsOut,
|
|
3913
|
+
fVjp,
|
|
3914
|
+
dispose$1
|
|
3915
|
+
];
|
|
3426
3916
|
}
|
|
3427
3917
|
function vjp$1(f, ...primalsIn) {
|
|
3428
3918
|
const [primalsInFlat, inTree] = flatten(primalsIn);
|
|
3429
3919
|
const [fFlat, outTree] = flattenFun(f, inTree);
|
|
3430
|
-
const [primalsOutFlat, fVjpFlat] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3920
|
+
const [primalsOutFlat, fVjpFlat, dispose$1] = vjpFlat(fFlat, primalsInFlat.map(pureArray));
|
|
3431
3921
|
if (outTree.value === void 0) throw new Error("outTree was not set in vjp");
|
|
3432
3922
|
const primalsOut = unflatten(outTree.value, primalsOutFlat);
|
|
3433
|
-
const fVjp = (cotangentsOut) => {
|
|
3923
|
+
const fVjp = ((cotangentsOut) => {
|
|
3434
3924
|
const [cotangentsOutFlat, outTree2] = flatten(cotangentsOut);
|
|
3435
3925
|
if (!outTree.value.equals(outTree2)) throw new TreeMismatchError("vjp", outTree.value, outTree2);
|
|
3436
3926
|
const cotangentsInFlat = fVjpFlat(...cotangentsOutFlat.map(pureArray));
|
|
3437
3927
|
return unflatten(inTree, cotangentsInFlat);
|
|
3438
|
-
};
|
|
3928
|
+
});
|
|
3929
|
+
fVjp.dispose = dispose$1;
|
|
3439
3930
|
return [primalsOut, fVjp];
|
|
3440
3931
|
}
|
|
3441
3932
|
function grad$1(f) {
|
|
@@ -3451,9 +3942,10 @@ function valueAndGrad$1(f) {
|
|
|
3451
3942
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
3452
3943
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3453
3944
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3454
|
-
if (y.dtype
|
|
3455
|
-
const [ct, ...rest] = fVjp(
|
|
3456
|
-
for (const r of rest)
|
|
3945
|
+
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3946
|
+
const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
|
|
3947
|
+
for (const r of rest) dispose(r);
|
|
3948
|
+
fVjp.dispose();
|
|
3457
3949
|
return [y, ct];
|
|
3458
3950
|
};
|
|
3459
3951
|
}
|
|
@@ -3461,11 +3953,84 @@ function jacrev$1(f) {
|
|
|
3461
3953
|
return function jacobianReverse(x) {
|
|
3462
3954
|
if (x.shape.length !== 1) throw new TypeError("jacrev only supports 1D inputs");
|
|
3463
3955
|
const [size$1] = x.shape;
|
|
3464
|
-
const pullback = (ct) =>
|
|
3956
|
+
const pullback = (ct) => {
|
|
3957
|
+
const [y, fVjp] = vjp$1(f, x);
|
|
3958
|
+
y.dispose();
|
|
3959
|
+
const [ret] = fVjp(ct);
|
|
3960
|
+
fVjp.dispose();
|
|
3961
|
+
return ret;
|
|
3962
|
+
};
|
|
3465
3963
|
return vmap$1(pullback, [1])(eye(size$1, void 0, { dtype: x.dtype }));
|
|
3466
3964
|
};
|
|
3467
3965
|
}
|
|
3468
3966
|
|
|
3967
|
+
//#endregion
|
|
3968
|
+
//#region src/lax.ts
|
|
3969
|
+
var lax_exports = {};
|
|
3970
|
+
__export(lax_exports, {
|
|
3971
|
+
conv: () => conv$1,
|
|
3972
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
3973
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
3974
|
+
reduceWindow: () => reduceWindow
|
|
3975
|
+
});
|
|
3976
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
3977
|
+
const padType = padding.toUpperCase();
|
|
3978
|
+
switch (padType) {
|
|
3979
|
+
case "VALID": return require_backend.rep(inShape.length, [0, 0]);
|
|
3980
|
+
case "SAME":
|
|
3981
|
+
case "SAME_LOWER": {
|
|
3982
|
+
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
3983
|
+
const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
3984
|
+
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
3985
|
+
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
3986
|
+
}
|
|
3987
|
+
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
3988
|
+
}
|
|
3989
|
+
}
|
|
3990
|
+
/**
|
|
3991
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
3992
|
+
*
|
|
3993
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
3994
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
3995
|
+
*
|
|
3996
|
+
* Grouped convolutions are not supported right now.
|
|
3997
|
+
*/
|
|
3998
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
3999
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
4000
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
4001
|
+
if (typeof padding === "string") {
|
|
4002
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
4003
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
4004
|
+
}
|
|
4005
|
+
return conv(lhs, rhs, {
|
|
4006
|
+
strides: windowStrides,
|
|
4007
|
+
padding,
|
|
4008
|
+
lhsDilation,
|
|
4009
|
+
rhsDilation
|
|
4010
|
+
});
|
|
4011
|
+
}
|
|
4012
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
4013
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
4014
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
4015
|
+
lhsDilation,
|
|
4016
|
+
rhsDilation
|
|
4017
|
+
});
|
|
4018
|
+
}
|
|
4019
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
4020
|
+
function conv$1(lhs, rhs, windowStrides, padding) {
|
|
4021
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
4022
|
+
}
|
|
4023
|
+
/** Reduce a computation over padded windows. */
|
|
4024
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
4025
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
4026
|
+
if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
|
|
4027
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
4028
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
4029
|
+
window: windowDimensions,
|
|
4030
|
+
strides: windowStrides
|
|
4031
|
+
}));
|
|
4032
|
+
}
|
|
4033
|
+
|
|
3469
4034
|
//#endregion
|
|
3470
4035
|
//#region src/numpy.ts
|
|
3471
4036
|
var numpy_exports = {};
|
|
@@ -3474,19 +4039,38 @@ __export(numpy_exports, {
|
|
|
3474
4039
|
DType: () => require_backend.DType,
|
|
3475
4040
|
abs: () => abs,
|
|
3476
4041
|
absolute: () => absolute,
|
|
4042
|
+
acos: () => acos,
|
|
4043
|
+
acosh: () => acosh,
|
|
3477
4044
|
add: () => add,
|
|
3478
4045
|
allclose: () => allclose,
|
|
3479
4046
|
arange: () => arange,
|
|
4047
|
+
arccos: () => arccos,
|
|
4048
|
+
arccosh: () => arccosh,
|
|
4049
|
+
arcsinh: () => arcsinh,
|
|
4050
|
+
arctan: () => arctan,
|
|
4051
|
+
arctan2: () => arctan2,
|
|
4052
|
+
arctanh: () => arctanh,
|
|
3480
4053
|
argmax: () => argmax,
|
|
3481
4054
|
argmin: () => argmin,
|
|
3482
4055
|
array: () => array,
|
|
4056
|
+
asin: () => asin,
|
|
4057
|
+
asinh: () => asinh,
|
|
3483
4058
|
astype: () => astype,
|
|
4059
|
+
atan: () => atan,
|
|
4060
|
+
atan2: () => atan2,
|
|
4061
|
+
atanh: () => atanh,
|
|
3484
4062
|
bool: () => bool,
|
|
4063
|
+
broadcastArrays: () => broadcastArrays,
|
|
4064
|
+
broadcastShapes: () => broadcastShapes,
|
|
4065
|
+
broadcastTo: () => broadcastTo,
|
|
4066
|
+
cbrt: () => cbrt,
|
|
3485
4067
|
clip: () => clip,
|
|
3486
4068
|
columnStack: () => columnStack,
|
|
3487
|
-
complex64: () => complex64,
|
|
3488
4069
|
concatenate: () => concatenate,
|
|
3489
4070
|
cos: () => cos,
|
|
4071
|
+
cosh: () => cosh,
|
|
4072
|
+
deg2rad: () => deg2rad,
|
|
4073
|
+
degrees: () => degrees,
|
|
3490
4074
|
diag: () => diag,
|
|
3491
4075
|
diagonal: () => diagonal,
|
|
3492
4076
|
divide: () => divide,
|
|
@@ -3497,23 +4081,29 @@ __export(numpy_exports, {
|
|
|
3497
4081
|
eulerGamma: () => eulerGamma,
|
|
3498
4082
|
exp: () => exp,
|
|
3499
4083
|
exp2: () => exp2,
|
|
4084
|
+
expm1: () => expm1,
|
|
3500
4085
|
eye: () => eye,
|
|
3501
4086
|
flip: () => flip,
|
|
3502
4087
|
fliplr: () => fliplr,
|
|
3503
4088
|
flipud: () => flipud,
|
|
4089
|
+
float16: () => float16,
|
|
3504
4090
|
float32: () => float32,
|
|
3505
4091
|
full: () => full,
|
|
4092
|
+
fullLike: () => fullLike$1,
|
|
3506
4093
|
greater: () => greater,
|
|
3507
4094
|
greaterEqual: () => greaterEqual,
|
|
3508
4095
|
hstack: () => hstack,
|
|
4096
|
+
hypot: () => hypot,
|
|
3509
4097
|
identity: () => identity$1,
|
|
3510
4098
|
inf: () => inf,
|
|
4099
|
+
inner: () => inner,
|
|
3511
4100
|
int32: () => int32,
|
|
3512
4101
|
less: () => less,
|
|
3513
4102
|
lessEqual: () => lessEqual,
|
|
3514
4103
|
linspace: () => linspace,
|
|
3515
4104
|
log: () => log,
|
|
3516
4105
|
log10: () => log10,
|
|
4106
|
+
log1p: () => log1p,
|
|
3517
4107
|
log2: () => log2,
|
|
3518
4108
|
matmul: () => matmul,
|
|
3519
4109
|
max: () => max,
|
|
@@ -3529,36 +4119,55 @@ __export(numpy_exports, {
|
|
|
3529
4119
|
negative: () => negative,
|
|
3530
4120
|
notEqual: () => notEqual,
|
|
3531
4121
|
ones: () => ones,
|
|
4122
|
+
onesLike: () => onesLike,
|
|
4123
|
+
outer: () => outer,
|
|
3532
4124
|
pad: () => pad,
|
|
3533
4125
|
permuteDims: () => permuteDims,
|
|
3534
4126
|
pi: () => pi,
|
|
4127
|
+
pow: () => pow,
|
|
4128
|
+
power: () => power,
|
|
3535
4129
|
prod: () => prod$1,
|
|
4130
|
+
promoteTypes: () => require_backend.promoteTypes,
|
|
4131
|
+
rad2deg: () => rad2deg,
|
|
4132
|
+
radians: () => radians,
|
|
3536
4133
|
ravel: () => ravel,
|
|
3537
4134
|
reciprocal: () => reciprocal,
|
|
4135
|
+
repeat: () => repeat,
|
|
3538
4136
|
reshape: () => reshape,
|
|
3539
|
-
scalar: () => scalar,
|
|
3540
4137
|
shape: () => shape,
|
|
4138
|
+
sign: () => sign,
|
|
3541
4139
|
sin: () => sin,
|
|
4140
|
+
sinh: () => sinh,
|
|
3542
4141
|
size: () => size,
|
|
4142
|
+
sqrt: () => sqrt,
|
|
3543
4143
|
square: () => square,
|
|
3544
4144
|
stack: () => stack,
|
|
4145
|
+
std: () => std,
|
|
4146
|
+
subtract: () => subtract,
|
|
3545
4147
|
sum: () => sum,
|
|
3546
4148
|
tan: () => tan,
|
|
4149
|
+
tanh: () => tanh,
|
|
4150
|
+
tile: () => tile,
|
|
3547
4151
|
transpose: () => transpose,
|
|
4152
|
+
tri: () => tri,
|
|
4153
|
+
tril: () => tril,
|
|
4154
|
+
triu: () => triu,
|
|
3548
4155
|
trueDivide: () => trueDivide,
|
|
3549
4156
|
trunc: () => trunc,
|
|
3550
4157
|
uint32: () => uint32,
|
|
4158
|
+
var_: () => var_,
|
|
3551
4159
|
vdot: () => vdot,
|
|
3552
4160
|
vecdot: () => vecdot,
|
|
3553
4161
|
vstack: () => vstack,
|
|
3554
4162
|
where: () => where,
|
|
3555
|
-
zeros: () => zeros
|
|
4163
|
+
zeros: () => zeros,
|
|
4164
|
+
zerosLike: () => zerosLike
|
|
3556
4165
|
});
|
|
3557
4166
|
const float32 = require_backend.DType.Float32;
|
|
3558
4167
|
const int32 = require_backend.DType.Int32;
|
|
3559
4168
|
const uint32 = require_backend.DType.Uint32;
|
|
3560
4169
|
const bool = require_backend.DType.Bool;
|
|
3561
|
-
const
|
|
4170
|
+
const float16 = require_backend.DType.Float16;
|
|
3562
4171
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
3563
4172
|
const e = Math.E;
|
|
3564
4173
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -3569,52 +4178,66 @@ const inf = Number.POSITIVE_INFINITY;
|
|
|
3569
4178
|
const nan = NaN;
|
|
3570
4179
|
/** This is Pi, `π = 3.14159265358979...` */
|
|
3571
4180
|
const pi = Math.PI;
|
|
3572
|
-
/** Element-wise addition, with broadcasting. */
|
|
4181
|
+
/** @function Element-wise addition, with broadcasting. */
|
|
3573
4182
|
const add = add$1;
|
|
3574
|
-
/** Element-wise multiplication, with broadcasting. */
|
|
4183
|
+
/** @function Element-wise multiplication, with broadcasting. */
|
|
3575
4184
|
const multiply = mul;
|
|
3576
|
-
/** Numerical negative of every element of an array. */
|
|
4185
|
+
/** @function Numerical negative of every element of an array. */
|
|
3577
4186
|
const negative = neg;
|
|
3578
|
-
/** Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
4187
|
+
/** @function Calculate element-wise reciprocal of the input. This is `1/x`. */
|
|
3579
4188
|
const reciprocal = reciprocal$1;
|
|
3580
|
-
/** Element-wise sine function (takes radians). */
|
|
4189
|
+
/** @function Element-wise sine function (takes radians). */
|
|
3581
4190
|
const sin = sin$1;
|
|
3582
|
-
/** Element-wise cosine function (takes radians). */
|
|
4191
|
+
/** @function Element-wise cosine function (takes radians). */
|
|
3583
4192
|
const cos = cos$1;
|
|
3584
|
-
/**
|
|
4193
|
+
/** @function Element-wise inverse sine function (inverse of sin). */
|
|
4194
|
+
const asin = asin$1;
|
|
4195
|
+
/** @function Element-wise inverse tangent function (inverse of tan). */
|
|
4196
|
+
const atan = atan$1;
|
|
4197
|
+
/** @function Calculate the exponential of all elements in the input array. */
|
|
3585
4198
|
const exp = exp$1;
|
|
3586
|
-
/** Calculate the natural logarithm of all elements in the input array. */
|
|
4199
|
+
/** @function Calculate the natural logarithm of all elements in the input array. */
|
|
3587
4200
|
const log = log$1;
|
|
3588
|
-
/**
|
|
4201
|
+
/** @function Calculate the square root of all elements in the input array. */
|
|
4202
|
+
const sqrt = sqrt$1;
|
|
4203
|
+
/** @function Return element-wise minimum of the input arrays. */
|
|
3589
4204
|
const minimum = min$1;
|
|
3590
|
-
/** Return element-wise maximum of the input arrays. */
|
|
4205
|
+
/** @function Return element-wise maximum of the input arrays. */
|
|
3591
4206
|
const maximum = max$1;
|
|
3592
|
-
/** Compare two arrays element-wise. */
|
|
4207
|
+
/** @function Compare two arrays element-wise. */
|
|
3593
4208
|
const greater = greater$1;
|
|
3594
|
-
/** Compare two arrays element-wise. */
|
|
4209
|
+
/** @function Compare two arrays element-wise. */
|
|
3595
4210
|
const less = less$1;
|
|
3596
|
-
/** Compare two arrays element-wise. */
|
|
4211
|
+
/** @function Compare two arrays element-wise. */
|
|
3597
4212
|
const equal = equal$1;
|
|
3598
|
-
/** Compare two arrays element-wise. */
|
|
4213
|
+
/** @function Compare two arrays element-wise. */
|
|
3599
4214
|
const notEqual = notEqual$1;
|
|
3600
|
-
/** Compare two arrays element-wise. */
|
|
4215
|
+
/** @function Compare two arrays element-wise. */
|
|
3601
4216
|
const greaterEqual = greaterEqual$1;
|
|
3602
|
-
/** Compare two arrays element-wise. */
|
|
4217
|
+
/** @function Compare two arrays element-wise. */
|
|
3603
4218
|
const lessEqual = lessEqual$1;
|
|
3604
|
-
/** Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
4219
|
+
/** @function Element-wise ternary operator, evaluates to `x` if cond else `y`. */
|
|
3605
4220
|
const where = where$1;
|
|
3606
|
-
/**
|
|
4221
|
+
/**
|
|
4222
|
+
* @function
|
|
4223
|
+
* Permute the dimensions of an array. Defaults to reversing the axis order.
|
|
4224
|
+
*/
|
|
3607
4225
|
const transpose = transpose$1;
|
|
3608
4226
|
/**
|
|
4227
|
+
* @function
|
|
3609
4228
|
* Give a new shape to an array without changing its data.
|
|
3610
4229
|
*
|
|
3611
4230
|
* One shape dimension can be -1. In this case, the value is inferred from the
|
|
3612
4231
|
* length of the array and remaining dimensions.
|
|
3613
4232
|
*/
|
|
3614
4233
|
const reshape = reshape$1;
|
|
3615
|
-
/**
|
|
4234
|
+
/**
|
|
4235
|
+
* @function
|
|
4236
|
+
* Move axes of an array to new positions. Other axes retain original order.
|
|
4237
|
+
*/
|
|
3616
4238
|
const moveaxis = moveaxis$1;
|
|
3617
4239
|
/**
|
|
4240
|
+
* @function
|
|
3618
4241
|
* Add padding (zeros) to an array.
|
|
3619
4242
|
*
|
|
3620
4243
|
* The `width` argument is either an integer or pair of integers, in which case
|
|
@@ -3622,11 +4245,29 @@ const moveaxis = moveaxis$1;
|
|
|
3622
4245
|
* pair specifies the padding for its corresponding axis.
|
|
3623
4246
|
*/
|
|
3624
4247
|
const pad = pad$1;
|
|
3625
|
-
/**
|
|
4248
|
+
/**
|
|
4249
|
+
* @function
|
|
4250
|
+
* Return the number of dimensions of an array. Does not consume array reference.
|
|
4251
|
+
*/
|
|
3626
4252
|
const ndim = ndim$1;
|
|
3627
|
-
/** Return the shape of an array. Does not consume array reference. */
|
|
4253
|
+
/** @function Return the shape of an array. Does not consume array reference. */
|
|
3628
4254
|
const shape = getShape;
|
|
3629
4255
|
/**
|
|
4256
|
+
* @function
|
|
4257
|
+
* Return an array of zeros with the same shape and type as a given array.
|
|
4258
|
+
*/
|
|
4259
|
+
const zerosLike = zerosLike$1;
|
|
4260
|
+
/**
|
|
4261
|
+
* @function
|
|
4262
|
+
* Return an array of ones with the same shape and type as a given array.
|
|
4263
|
+
*/
|
|
4264
|
+
const onesLike = onesLike$1;
|
|
4265
|
+
/**
|
|
4266
|
+
* @function
|
|
4267
|
+
* Return a full array with the same shape and type as a given array.
|
|
4268
|
+
*/
|
|
4269
|
+
const fullLike$1 = fullLike;
|
|
4270
|
+
/**
|
|
3630
4271
|
* Return the number of elements in an array, optionally along an axis.
|
|
3631
4272
|
* Does not consume array reference.
|
|
3632
4273
|
*/
|
|
@@ -3639,23 +4280,23 @@ function astype(a, dtype) {
|
|
|
3639
4280
|
return fudgeArray(a).astype(dtype);
|
|
3640
4281
|
}
|
|
3641
4282
|
/** Sum of the elements of the array over a given axis, or axes. */
|
|
3642
|
-
function sum(a, axis, opts) {
|
|
4283
|
+
function sum(a, axis = null, opts) {
|
|
3643
4284
|
return reduce(a, require_backend.AluOp.Add, axis, opts);
|
|
3644
4285
|
}
|
|
3645
4286
|
/** Product of the array elements over a given axis. */
|
|
3646
|
-
function prod$1(a, axis, opts) {
|
|
4287
|
+
function prod$1(a, axis = null, opts) {
|
|
3647
4288
|
return reduce(a, require_backend.AluOp.Mul, axis, opts);
|
|
3648
4289
|
}
|
|
3649
4290
|
/** Return the minimum of array elements along a given axis. */
|
|
3650
|
-
function min(a, axis, opts) {
|
|
4291
|
+
function min(a, axis = null, opts) {
|
|
3651
4292
|
return reduce(a, require_backend.AluOp.Min, axis, opts);
|
|
3652
4293
|
}
|
|
3653
4294
|
/** Return the maximum of array elements along a given axis. */
|
|
3654
|
-
function max(a, axis, opts) {
|
|
4295
|
+
function max(a, axis = null, opts) {
|
|
3655
4296
|
return reduce(a, require_backend.AluOp.Max, axis, opts);
|
|
3656
4297
|
}
|
|
3657
4298
|
/** Compute the average of the array elements along the specified axis. */
|
|
3658
|
-
function mean(a, axis, opts) {
|
|
4299
|
+
function mean(a, axis = null, opts) {
|
|
3659
4300
|
return fudgeArray(a).mean(axis, opts);
|
|
3660
4301
|
}
|
|
3661
4302
|
/**
|
|
@@ -3671,18 +4312,12 @@ function argmin(a, axis, opts) {
|
|
|
3671
4312
|
axis = 0;
|
|
3672
4313
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
3673
4314
|
const shape$1 = a.shape;
|
|
3674
|
-
const isMax = equal(a, min(a.ref, axis, {
|
|
4315
|
+
const isMax = equal(a, min(a.ref, axis, { keepdims: true }));
|
|
3675
4316
|
const length = scalar(shape$1[axis], {
|
|
3676
4317
|
dtype: int32,
|
|
3677
4318
|
device: a.device
|
|
3678
4319
|
});
|
|
3679
|
-
const idx =
|
|
3680
|
-
dtype: int32,
|
|
3681
|
-
device: a.device
|
|
3682
|
-
}), scalar(0, {
|
|
3683
|
-
dtype: int32,
|
|
3684
|
-
device: a.device
|
|
3685
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4320
|
+
const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3686
4321
|
dtype: int32,
|
|
3687
4322
|
device: a.device
|
|
3688
4323
|
}).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
|
|
@@ -3701,35 +4336,21 @@ function argmax(a, axis, opts) {
|
|
|
3701
4336
|
axis = 0;
|
|
3702
4337
|
} else axis = require_backend.checkAxis(axis, a.ndim);
|
|
3703
4338
|
const shape$1 = a.shape;
|
|
3704
|
-
const isMax = equal(a, max(a.ref, axis, {
|
|
4339
|
+
const isMax = equal(a, max(a.ref, axis, { keepdims: true }));
|
|
3705
4340
|
const length = scalar(shape$1[axis], {
|
|
3706
4341
|
dtype: int32,
|
|
3707
4342
|
device: a.device
|
|
3708
4343
|
});
|
|
3709
|
-
const idx =
|
|
3710
|
-
dtype: int32,
|
|
3711
|
-
device: a.device
|
|
3712
|
-
}), scalar(0, {
|
|
3713
|
-
dtype: int32,
|
|
3714
|
-
device: a.device
|
|
3715
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4344
|
+
const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3716
4345
|
dtype: int32,
|
|
3717
4346
|
device: a.device
|
|
3718
4347
|
}).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
|
|
3719
4348
|
return length.sub(max(idx, axis, opts));
|
|
3720
4349
|
}
|
|
3721
4350
|
/** Reverse the elements in an array along the given axes. */
|
|
3722
|
-
function flip(x, axis) {
|
|
4351
|
+
function flip(x, axis = null) {
|
|
3723
4352
|
const nd = ndim(x);
|
|
3724
|
-
|
|
3725
|
-
else if (typeof axis === "number") axis = [axis];
|
|
3726
|
-
const seen = /* @__PURE__ */ new Set();
|
|
3727
|
-
for (let i = 0; i < axis.length; i++) {
|
|
3728
|
-
if (axis[i] >= nd || axis[i] < -nd) throw new Error(`flip: axis ${axis[i]} out of bounds for array of ${nd} dimensions`);
|
|
3729
|
-
if (axis[i] < 0) axis[i] += nd;
|
|
3730
|
-
if (seen.has(axis[i])) throw new Error(`flip: duplicate axis ${axis[i]} in axis list`);
|
|
3731
|
-
seen.add(axis[i]);
|
|
3732
|
-
}
|
|
4353
|
+
axis = require_backend.normalizeAxis(axis, nd);
|
|
3733
4354
|
return flip$1(x, axis);
|
|
3734
4355
|
}
|
|
3735
4356
|
/**
|
|
@@ -3835,18 +4456,88 @@ function flipud(x) {
|
|
|
3835
4456
|
function fliplr(x) {
|
|
3836
4457
|
return flip(x, 1);
|
|
3837
4458
|
}
|
|
4459
|
+
/** @function Alternative name for `numpy.transpose()`. */
|
|
3838
4460
|
const permuteDims = transpose;
|
|
3839
4461
|
/** Return a 1-D flattened array containing the elements of the input. */
|
|
3840
4462
|
function ravel(a) {
|
|
3841
4463
|
return fudgeArray(a).ravel();
|
|
3842
4464
|
}
|
|
3843
4465
|
/**
|
|
4466
|
+
* Repeat each element of an array immediately after itself.
|
|
4467
|
+
*
|
|
4468
|
+
* If no axis is provided, use the flattened input array, and return a flat
|
|
4469
|
+
* output array.
|
|
4470
|
+
*/
|
|
4471
|
+
function repeat(a, repeats, axis) {
	const isValidCount = Number.isInteger(repeats) && repeats >= 0;
	if (!isValidCount) throw new Error(`repeat: repeats must be a non-negative integer, got ${repeats}`);
	a = fudgeArray(a);
	// Without an explicit axis, work on the flattened array along its only axis.
	if (axis === void 0) {
		a = ravel(a);
		axis = 0;
	}
	axis = require_backend.checkAxis(axis, a.ndim);
	// Repeating once is the identity; hand the array straight back.
	if (repeats === 1) return a;
	// Insert a new axis of length `repeats` right after `axis`, then fold it
	// back into `axis` so every element ends up next to its copies.
	const expandedShape = a.shape.toSpliced(axis + 1, 0, repeats);
	const mergedShape = a.shape.toSpliced(axis, 1, a.shape[axis] * repeats);
	return broadcast(a, expandedShape, [axis + 1]).reshape(mergedShape);
}
|
|
4484
|
+
/**
|
|
4485
|
+
* Construct an array by repeating A the number of times given by reps.
|
|
4486
|
+
*
|
|
4487
|
+
* If `A` is an array of shape `(d1, d2, ..., dn)` and `reps` is a sequence of
|
|
4488
|
+
* integers, the resulting array will have a shape of `(reps[0] * d1,
|
|
4489
|
+
* reps[1] * d2, ..., reps[n - 1] * dn)`, with `A` tiled along each dimension.
|
|
4490
|
+
*/
|
|
4491
|
+
function tile(a, reps) {
	a = fudgeArray(a);
	// A single number tiles along the last axis only.
	if (typeof reps === "number") reps = [reps];
	if (!reps.every((r) => Number.isInteger(r) && r >= 0)) throw new Error(`tile: reps must be non-negative integers, got ${JSON.stringify(reps)}`);
	// Align ranks: pad the array shape or `reps` on the left with 1s.
	const ndiff = reps.length - a.ndim;
	if (ndiff > 0) a = a.reshape([...require_backend.rep(ndiff, 1), ...a.shape]);
	if (ndiff < 0) reps = [...require_backend.rep(-ndiff, 1), ...reps];
	const broadcastedShape = [];
	const broadcastAxes = [];
	for (let i = 0; i < a.ndim; i++) {
		// Every rep other than 1 needs its own broadcast axis in front of the
		// data axis. This includes 0 (accepted by the validation above): without
		// a zero-length axis here, the reshape below would be asked to turn a
		// non-empty array into a zero-size shape. `repeat()` treats a count of 0
		// the same way (assumes `broadcast` accepts a zero-length target axis,
		// as `repeat()` already relies on).
		if (reps[i] !== 1) {
			broadcastedShape.push(reps[i]);
			broadcastAxes.push(broadcastedShape.length - 1);
		}
		broadcastedShape.push(a.shape[i]);
	}
	// Merge each (rep, dim) pair back into a single axis of length rep * dim.
	const finalShape = a.shape.map((d, i) => reps[i] * d);
	return broadcast(a, broadcastedShape, broadcastAxes).reshape(finalShape);
}
|
|
4510
|
+
/**
|
|
4511
|
+
* Broadcast an array to a shape, with NumPy-style broadcasting rules.
|
|
4512
|
+
*
|
|
4513
|
+
* In other words, this lets you append axes to the left, and/or expand
|
|
4514
|
+
* dimensions where the shape is 1.
|
|
4515
|
+
*/
|
|
4516
|
+
function broadcastTo(a, shape$1) {
	const rank = ndim(a);
	// Number of leading axes the target shape has beyond the input's rank.
	const extraAxes = shape$1.length - rank;
	if (extraAxes < 0) throw new Error(`broadcastTo: target shape ${JSON.stringify(shape$1)} has fewer dimensions than input array: ${rank}`);
	return broadcast(a, shape$1, require_backend.range(extraAxes));
}
|
|
4521
|
+
/**
 * Combine any number of shapes into the single shape they broadcast to.
 *
 * With no arguments the result is the empty shape `[]`; otherwise the shapes
 * are folded pairwise, left to right, through `generalBroadcast`.
 */
function broadcastShapes(...shapes) {
	return shapes.length === 0 ? [] : shapes.reduce(generalBroadcast);
}
|
|
4526
|
+
/**
 * Broadcast every input array to the common shape of all the inputs.
 *
 * Returns a new list with one broadcast array per input, in the same order.
 */
function broadcastArrays(...arrays) {
	const commonShape = broadcastShapes(...arrays.map((arr) => shape(arr)));
	return arrays.map((arr) => broadcastTo(arr, commonShape));
}
|
|
4532
|
+
/**
|
|
3844
4533
|
* Return specified diagonals.
|
|
3845
4534
|
*
|
|
3846
4535
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
3847
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
4536
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
3848
4537
|
*
|
|
3849
|
-
* This returns a view over the existing array.
|
|
4538
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
4539
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
4540
|
+
* then appending a new axis on the right that holds the diagonals.
|
|
3850
4541
|
*/
|
|
3851
4542
|
function diagonal(a, offset, axis1, axis2) {
|
|
3852
4543
|
return fudgeArray(a).diagonal(offset, axis1, axis2);
|
|
@@ -3862,15 +4553,16 @@ function diag(v, k = 0) {
|
|
|
3862
4553
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
3863
4554
|
if (a.ndim === 1) {
|
|
3864
4555
|
const n = a.shape[0];
|
|
3865
|
-
const ret = where(eye(n).equal(1), a,
|
|
3866
|
-
if (k
|
|
3867
|
-
return ret;
|
|
4556
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike(a));
|
|
4557
|
+
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4558
|
+
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4559
|
+
else return ret;
|
|
3868
4560
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
3869
4561
|
else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
|
|
3870
4562
|
}
|
|
3871
4563
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
3872
4564
|
function allclose(actual, expected, options) {
|
|
3873
|
-
const { rtol = 1e-5, atol = 1e-
|
|
4565
|
+
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
3874
4566
|
const x = array(actual);
|
|
3875
4567
|
const y = array(expected);
|
|
3876
4568
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
@@ -3905,8 +4597,36 @@ function dot(x, y) {
|
|
|
3905
4597
|
]);
|
|
3906
4598
|
return dot$1(x, y);
|
|
3907
4599
|
}
|
|
3908
|
-
/**
|
|
3909
|
-
|
|
4600
|
+
/**
|
|
4601
|
+
* Compute the inner product of two arrays.
|
|
4602
|
+
*
|
|
4603
|
+
* Unlike `jax.numpy.matmul()` or `jax.numpy.dot()`, this always performs a
|
|
4604
|
+
* contraction on the last axis.
|
|
4605
|
+
*
|
|
4606
|
+
* Returned array has shape `[...x.shape[:-1], ...y.shape[:-1]]`.
|
|
4607
|
+
*/
|
|
4608
|
+
function inner(x, y) {
	const xShape = shape(x);
	// One singleton axis per leading axis of `y`, inserted just before the
	// last axis of `x`.
	const singletonAxes = require_backend.rep(ndim(y) - 1, 1);
	const paddedShape = xShape.toSpliced(-1, 0, ...singletonAxes);
	return dot$1(reshape(x, paddedShape), y);
}
|
|
4612
|
+
/**
|
|
4613
|
+
* Compute the outer product of two arrays.
|
|
4614
|
+
*
|
|
4615
|
+
* If the input arrays are not 1D, they will be flattened. Returned array will
|
|
4616
|
+
* be of shape `[x.size, y.size]`.
|
|
4617
|
+
*/
|
|
4618
|
+
function outer(x, y) {
	const column = ravel(x);
	const row = ravel(y);
	// Reshape the flattened `x` to [x.size, 1] so the product with the
	// flattened `y` has shape [x.size, y.size].
	return multiply(column.reshape([column.shape[0], 1]), row);
}
|
|
4623
|
+
/**
 * Vector dot product of two arrays along a given axis.
 *
 * `axis` defaults to the last axis (-1) and is resolved separately against
 * the rank of each input. Both inputs must have the same length along that
 * axis, otherwise an `Error` is thrown. The chosen axis is moved to the end
 * of each input before the contraction.
 */
function vecdot(x, y, { axis } = {}) {
	// Resolve the default once so the error message reports the axis that was
	// actually used, rather than "undefined" when the caller omitted it.
	const requestedAxis = axis ?? -1;
	const xaxis = require_backend.checkAxis(requestedAxis, ndim(x));
	const yaxis = require_backend.checkAxis(requestedAxis, ndim(y));
	const xshape = shape(x);
	const yshape = shape(y);
	if (xshape[xaxis] !== yshape[yaxis]) throw new Error(`vecdot: shapes ${JSON.stringify(xshape)} and ${JSON.stringify(yshape)} not aligned along axis ${requestedAxis}: ${xshape[xaxis]} != ${yshape[yaxis]}`);
	x = moveaxis(x, xaxis, -1);
	y = moveaxis(y, yaxis, -1);
	return dot$1(x, y);
}
|
|
3912
4632
|
/**
|
|
@@ -3915,7 +4635,7 @@ function vecdot(x, y) {
|
|
|
3915
4635
|
* Like vecdot() but flattens the arguments first into vectors.
|
|
3916
4636
|
*/
|
|
3917
4637
|
function vdot(x, y) {
|
|
3918
|
-
return
|
|
4638
|
+
return dot$1(ravel(x), ravel(y));
|
|
3919
4639
|
}
|
|
3920
4640
|
/**
|
|
3921
4641
|
* Return a tuple of coordinate matrices from coordinate vectors.
|
|
@@ -3944,6 +4664,43 @@ function meshgrid(xs, { indexing } = {}) {
|
|
|
3944
4664
|
return xs.map((x, i) => broadcast(x, shape$1, [...require_backend.range(i), ...require_backend.range(i + 1, xs.length)]));
|
|
3945
4665
|
}
|
|
3946
4666
|
/**
|
|
4667
|
+
* Return an array with ones on and below the diagonal and zeros elsewhere.
|
|
4668
|
+
*
|
|
4669
|
+
* If `k` is provided, it specifies the sub-diagonal on and below which the
|
|
4670
|
+
* array is filled with ones. `k=0` is the main diagonal, `k<0` is below it, and
|
|
4671
|
+
* `k>0` is above it.
|
|
4672
|
+
*/
|
|
4673
|
+
function tri(n, m, k = 0, { dtype, device } = {}) {
	m ??= n;
	dtype ??= require_backend.DType.Float32;
	const isCount = (value) => Number.isInteger(value) && value >= 0;
	if (!isCount(n)) throw new TypeError(`tri: n must be a non-negative integer, got ${n}`);
	if (!isCount(m)) throw new TypeError(`tri: m must be a non-negative integer, got ${m}`);
	if (!Number.isInteger(k)) throw new TypeError(`tri: k must be an integer, got ${k}`);
	// Int32 index ranges on the requested device.
	const indexRange = (start, stop) => arange(start, stop, 1, {
		dtype: require_backend.DType.Int32,
		device
	});
	// Row i carries the value i + k, so entry (i, j) is set when i + k >= j.
	const rows = indexRange(k, n + k);
	const cols = indexRange(0, m);
	return rows.reshape([n, 1]).greaterEqual(cols).astype(dtype);
}
|
|
4689
|
+
/**
 * Return the lower triangle of an array, taken over its last two axes.
 *
 * Entries above the `k`-th diagonal are replaced by zeros. Throws a
 * `TypeError` when the input has fewer than two dimensions.
 */
function tril(a, k = 0) {
	const rank = ndim(a);
	if (rank < 2) throw new TypeError(`tril: input array must be at least 2D, got ${rank}D`);
	a = fudgeArray(a);
	const [n, m] = a.shape.slice(-2);
	const keepMask = tri(n, m, k, { dtype: bool });
	return where(keepMask, a.ref, zerosLike(a));
}
|
|
4696
|
+
/**
 * Return the upper triangle of an array, taken over its last two axes.
 *
 * Entries below the `k`-th diagonal are replaced by zeros. Throws a
 * `TypeError` when the input has fewer than two dimensions.
 */
function triu(a, k = 0) {
	// The message is prefixed with this function's own name; it previously said
	// "tril:", which pointed callers at the wrong function.
	if (ndim(a) < 2) throw new TypeError(`triu: input array must be at least 2D, got ${ndim(a)}D`);
	a = fudgeArray(a);
	const [n, m] = a.shape.slice(-2);
	// tri(n, m, k - 1) marks everything strictly below the k-th diagonal;
	// those entries are zeroed and the rest of `a` is kept.
	return where(tri(n, m, k - 1, { dtype: bool }), zerosLike(a.ref), a);
}
|
|
4703
|
+
/**
|
|
3947
4704
|
* Clip (limit) the values in an array.
|
|
3948
4705
|
*
|
|
3949
4706
|
* Given an interval, values outside the interval are clipped to the interval
|
|
@@ -3967,18 +4724,70 @@ function absolute(x) {
|
|
|
3967
4724
|
x = fudgeArray(x);
|
|
3968
4725
|
return where(less(x.ref, 0), x.ref.mul(-1), x);
|
|
3969
4726
|
}
|
|
3970
|
-
/** Alias of `jax.numpy.absolute()`. */
|
|
4727
|
+
/** @function Alias of `jax.numpy.absolute()`. */
|
|
3971
4728
|
const abs = absolute;
|
|
4729
|
+
/**
 * Return an element-wise indication of the sign of the input:
 * -1 where `x < 0`, 1 where `x > 0`, and 0 where `x` equals 0.
 */
function sign(x) {
	x = fudgeArray(x);
	// The outer comparison is evaluated first and works on an extra reference
	// (`x.ref`); the inner comparison is the last use, so it takes `x` itself.
	// Previously both uses went through `.ref`, leaving the reference owned by
	// this function unreleased, unlike `absolute()`, `leakyRelu()` and the
	// other helpers here, whose final use is always the bare array.
	return where(notEqual(x.ref, 0), where(less(x, 0), -1, 1), 0);
}
|
|
3972
4734
|
/** Calculate element-wise square of the input array. */
|
|
3973
4735
|
function square(x) {
|
|
3974
4736
|
x = fudgeArray(x);
|
|
3975
4737
|
return x.ref.mul(x);
|
|
3976
4738
|
}
|
|
3977
|
-
/**
|
|
4739
|
+
/** Element-wise tangent function (takes radians). */
|
|
3978
4740
|
function tan(x) {
|
|
3979
4741
|
x = fudgeArray(x);
|
|
3980
4742
|
return sin(x.ref).div(cos(x));
|
|
3981
4743
|
}
|
|
4744
|
+
/**
 * Element-wise inverse cosine (inverse of cos), computed as
 * `pi / 2 - asin(x)`.
 */
function acos(x) {
	const halfPi = pi / 2;
	return subtract(halfPi, asin(x));
}
|
|
4748
|
+
/**
|
|
4749
|
+
* @function
|
|
4750
|
+
* Return element-wise hypotenuse for the given legs of a right triangle.
|
|
4751
|
+
*
|
|
4752
|
+
* In the original NumPy/JAX implementation, this function is more numerically
|
|
4753
|
+
* stable than sqrt(x1**2 + x2**2). We don't currently implement those stability
|
|
4754
|
+
* improvements.
|
|
4755
|
+
*/
|
|
4756
|
+
const hypot = jit$1((x1, x2) => {
	// Plain sqrt(x1^2 + x2^2); no extra numerical-stability handling.
	const sumOfSquares = square(x1).add(square(x2));
	return sqrt(sumOfSquares);
});
|
|
4759
|
+
/**
|
|
4760
|
+
* @function
|
|
4761
|
+
* Element-wise arc tangent of y/x with correct quadrant.
|
|
4762
|
+
*
|
|
4763
|
+
* Returns the angle in radians between the positive x-axis and the point (x, y).
|
|
4764
|
+
* The result is in the range [-π, π].
|
|
4765
|
+
*
|
|
4766
|
+
* Uses numerically stable formulas:
|
|
4767
|
+
* - When x >= 0: atan2(y, x) = 2 * atan(y / (sqrt(x^2 + y^2) + x))
|
|
4768
|
+
* - When x < 0: atan2(y, x) = 2 * atan((sqrt(x^2 + y^2) - x) / y)
|
|
4769
|
+
*
|
|
4770
|
+
* The output is ill-defined when both x and y are zero.
|
|
4771
|
+
*/
|
|
4772
|
+
const atan2 = jit$1((y, x) => {
	// Distance of the point (x, y) from the origin.
	const radius = sqrt(square(x.ref).add(square(y.ref)));
	const inLeftHalf = less(x.ref, 0);
	// Pick the half-angle quotient per element: (r - x) / y when x < 0,
	// otherwise y / (r + x).
	const top = where(inLeftHalf.ref, radius.ref.sub(x.ref), y.ref);
	const bottom = where(inLeftHalf, y, radius.add(x));
	const halfAngle = atan(top.div(bottom));
	return halfAngle.mul(2);
});
|
|
4779
|
+
/** @function Alias of `jax.numpy.acos()`. */
|
|
4780
|
+
const arccos = acos;
|
|
4781
|
+
/** @function Alias of `jax.numpy.atan()`. */
|
|
4782
|
+
const arctan = atan;
|
|
4783
|
+
/** @function Alias of `jax.numpy.atan2()`. */
|
|
4784
|
+
const arctan2 = atan2;
|
|
4785
|
+
/**
 * Element-wise subtraction `x - y`, with broadcasting.
 *
 * Both operands are passed through `fudgeArray` before subtracting.
 */
function subtract(x, y) {
	const minuend = fudgeArray(x);
	const subtrahend = fudgeArray(y);
	return minuend.sub(subtrahend);
}
|
|
3982
4791
|
/** Calculates the floating-point division of x by y element-wise. */
|
|
3983
4792
|
function trueDivide(x, y) {
|
|
3984
4793
|
x = fudgeArray(x);
|
|
@@ -3986,7 +4795,7 @@ function trueDivide(x, y) {
|
|
|
3986
4795
|
if (!require_backend.isFloatDtype(x.dtype) || !require_backend.isFloatDtype(y.dtype)) throw new TypeError(`trueDivide: x and y must be floating-point arrays, got ${x.dtype} and ${y.dtype}`);
|
|
3987
4796
|
return x.div(y);
|
|
3988
4797
|
}
|
|
3989
|
-
/** Alias of `jax.numpy.trueDivide()`. */
|
|
4798
|
+
/** @function Alias of `jax.numpy.trueDivide()`. */
|
|
3990
4799
|
const divide = trueDivide;
|
|
3991
4800
|
/** Round input to the nearest integer towards zero. */
|
|
3992
4801
|
function trunc(x) {
|
|
@@ -4004,15 +4813,151 @@ function log2(x) {
|
|
|
4004
4813
|
function log10(x) {
|
|
4005
4814
|
return log(x).mul(Math.LOG10E);
|
|
4006
4815
|
}
|
|
4816
|
+
/**
 * Calculate `exp(x) - 1` element-wise.
 *
 * Evaluated directly as `exp(x)` followed by subtracting 1.
 */
function expm1(x) {
	const exponential = exp(x);
	return exponential.sub(1);
}
|
|
4820
|
+
/**
 * Calculate the natural logarithm of `1 + x` element-wise.
 *
 * Evaluated directly as `log(1 + x)`.
 */
function log1p(x) {
	const onePlusX = add(1, x);
	return log(onePlusX);
}
|
|
4824
|
+
/** Convert angles from degrees to radians by scaling with `pi / 180`. */
function deg2rad(x) {
	const radiansPerDegree = pi / 180;
	return multiply(x, radiansPerDegree);
}
|
|
4828
|
+
/** @function Alias of `jax.numpy.deg2rad()`. */
|
|
4829
|
+
const radians = deg2rad;
|
|
4830
|
+
/** Convert angles from radians to degrees by scaling with `180 / pi`. */
function rad2deg(x) {
	const degreesPerRadian = 180 / pi;
	return multiply(x, degreesPerRadian);
}
|
|
4834
|
+
/** @function Alias of `jax.numpy.rad2deg()`. */
|
|
4835
|
+
const degrees = rad2deg;
|
|
4836
|
+
/**
|
|
4837
|
+
* @function
|
|
4838
|
+
* Computes first array raised to power of second array, element-wise.
|
|
4839
|
+
*/
|
|
4840
|
+
const power = jit$1((x1, x2) => {
	// x1 ** x2 evaluated as exp(x2 * log(x1)).
	const scaledLog = log(x1).mul(x2);
	return exp(scaledLog);
});
|
|
4843
|
+
/** @function Alias of `jax.numpy.power()`. */
|
|
4844
|
+
const pow = power;
|
|
4845
|
+
/**
 * @function
 * Calculate the element-wise cube root of the input array.
 *
 * The root is taken of `x * sgn` (where `sgn` is -1 for negative entries and
 * 1 otherwise) via `exp(log(.) / 3)`, and the sign is multiplied back in.
 */
const cbrt = jit$1((x) => {
	const sgn = where(less(x.ref, 0), -1, 1);
	const magnitude = x.mul(sgn.ref);
	const root = exp(log(magnitude).mul(1 / 3));
	return sgn.mul(root);
});
|
|
4850
|
+
/**
|
|
4851
|
+
* @function
|
|
4852
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
4853
|
+
*
|
|
4854
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4855
|
+
*/
|
|
4856
|
+
const sinh = jit$1((x) => {
	const growing = exp(x);
	// exp(-x) obtained as the reciprocal of exp(x).
	const decaying = reciprocal(growing.ref);
	const difference = growing.sub(decaying);
	return difference.mul(.5);
});
|
|
4861
|
+
/**
|
|
4862
|
+
* @function
|
|
4863
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
4864
|
+
*
|
|
4865
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4866
|
+
*/
|
|
4867
|
+
const cosh = jit$1((x) => {
	const growing = exp(x);
	// exp(-x) obtained as the reciprocal of exp(x).
	const decaying = reciprocal(growing.ref);
	const total = growing.add(decaying);
	return total.mul(.5);
});
|
|
4872
|
+
/**
|
|
4873
|
+
* @function
|
|
4874
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
4875
|
+
*
|
|
4876
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4877
|
+
*/
|
|
4878
|
+
const tanh = jit$1((x) => {
	// 1 for negative inputs and -1 otherwise, so x * flip is never positive.
	const flip = where(less(x.ref, 0), 1, -1);
	// exp(-2|x|): the exponent is never positive.
	const decay = exp(x.mul(flip.ref).mul(2));
	const numerator = decay.ref.sub(1);
	const denominator = decay.add(1);
	// The quotient is -tanh(|x|); multiplying by `flip` restores the sign.
	return numerator.div(denominator).mul(flip);
});
|
|
4883
|
+
/**
|
|
4884
|
+
* @function
|
|
4885
|
+
* Calculate element-wise inverse hyperbolic sine of input.
|
|
4886
|
+
*
|
|
4887
|
+
* `arcsinh(x) = ln(x + sqrt(x^2 + 1))`
|
|
4888
|
+
*/
|
|
4889
|
+
const arcsinh = jit$1((x) => {
	const xCopy = x.ref;
	const root = sqrt(square(x).add(1));
	return log(xCopy.add(root));
});
|
|
4892
|
+
/**
|
|
4893
|
+
* @function
|
|
4894
|
+
* Calculate element-wise inverse hyperbolic cosine of input.
|
|
4895
|
+
*
|
|
4896
|
+
* `arccosh(x) = ln(x + sqrt(x^2 - 1))`
|
|
4897
|
+
*/
|
|
4898
|
+
const arccosh = jit$1((x) => {
	const xCopy = x.ref;
	const root = sqrt(square(x).sub(1));
	return log(xCopy.add(root));
});
|
|
4901
|
+
/**
|
|
4902
|
+
* @function
|
|
4903
|
+
* Calculate element-wise inverse hyperbolic tangent of input.
|
|
4904
|
+
*
|
|
4905
|
+
* `arctanh(x) = 0.5 * ln((1 + x) / (1 - x))`
|
|
4906
|
+
*/
|
|
4907
|
+
const arctanh = jit$1((x) => {
	const onePlusX = add(1, x.ref);
	const oneMinusX = subtract(1, x);
	return log(onePlusX.div(oneMinusX)).mul(.5);
});
|
|
4910
|
+
/** @function Alias of `jax.numpy.arcsinh()`. */
|
|
4911
|
+
const asinh = arcsinh;
|
|
4912
|
+
/** @function Alias of `jax.numpy.arccosh()`. */
|
|
4913
|
+
const acosh = arccosh;
|
|
4914
|
+
/** @function Alias of `jax.numpy.arctanh()`. */
|
|
4915
|
+
const atanh = arctanh;
|
|
4916
|
+
/**
|
|
4917
|
+
* Compute the variance of an array.
|
|
4918
|
+
*
|
|
4919
|
+
* The variance is computed for the flattened array by default, otherwise over
|
|
4920
|
+
* the specified axis.
|
|
4921
|
+
*
|
|
4922
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4923
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4924
|
+
*/
|
|
4925
|
+
function var_(x, axis = null, opts) {
	x = fudgeArray(x);
	axis = require_backend.normalizeAxis(axis, x.ndim);
	// Number of elements reduced over: the product of the selected axis lengths.
	let count = 1;
	for (const ax of axis) count *= x.shape[ax];
	if (count === 0) throw new Error("var: cannot compute variance over zero-length axis");
	// Use the caller-supplied mean when given, otherwise compute it with kept
	// dimensions so it can be subtracted from `x`.
	let center;
	if (opts?.mean !== void 0) center = opts.mean;
	else center = mean(x.ref, axis, { keepdims: true });
	const squaredDeviations = square(x.sub(center));
	const total = squaredDeviations.sum(axis, { keepdims: opts?.keepdims });
	const divisor = count - (opts?.correction ?? 0);
	return total.mul(1 / divisor);
}
|
|
4933
|
+
/**
|
|
4934
|
+
* Compute the standard deviation of an array.
|
|
4935
|
+
*
|
|
4936
|
+
* The standard deviation is computed for the flattened array by default,
|
|
4937
|
+
* otherwise over the specified axis.
|
|
4938
|
+
*
|
|
4939
|
+
* If `correction` is provided, the divisor in calculation is `N - correction`,
|
|
4940
|
+
* where `N` represents the number of elements (e.g., for Bessel's correction).
|
|
4941
|
+
*/
|
|
4942
|
+
function std(x, axis = null, opts) {
	// The standard deviation is the square root of the variance.
	const variance = var_(x, axis, opts);
	return sqrt(variance);
}
|
|
4007
4945
|
|
|
4008
4946
|
//#endregion
|
|
4009
4947
|
//#region src/nn.ts
|
|
4010
4948
|
var nn_exports = {};
|
|
4011
4949
|
__export(nn_exports, {
|
|
4950
|
+
celu: () => celu,
|
|
4951
|
+
elu: () => elu,
|
|
4952
|
+
gelu: () => gelu,
|
|
4953
|
+
glu: () => glu,
|
|
4012
4954
|
identity: () => identity,
|
|
4955
|
+
leakyRelu: () => leakyRelu,
|
|
4013
4956
|
logSigmoid: () => logSigmoid,
|
|
4014
4957
|
logSoftmax: () => logSoftmax,
|
|
4958
|
+
logmeanexp: () => logmeanexp,
|
|
4015
4959
|
logsumexp: () => logsumexp,
|
|
4960
|
+
mish: () => mish,
|
|
4016
4961
|
oneHot: () => oneHot,
|
|
4017
4962
|
relu: () => relu,
|
|
4018
4963
|
relu6: () => relu6,
|
|
@@ -4021,6 +4966,8 @@ __export(nn_exports, {
|
|
|
4021
4966
|
softSign: () => softSign,
|
|
4022
4967
|
softmax: () => softmax,
|
|
4023
4968
|
softplus: () => softplus,
|
|
4969
|
+
squareplus: () => squareplus,
|
|
4970
|
+
standardize: () => standardize,
|
|
4024
4971
|
swish: () => swish
|
|
4025
4972
|
});
|
|
4026
4973
|
/**
|
|
@@ -4064,6 +5011,7 @@ function softSign(x) {
|
|
|
4064
5011
|
return x.ref.div(absolute(x).add(1));
|
|
4065
5012
|
}
|
|
4066
5013
|
/**
|
|
5014
|
+
* @function
|
|
4067
5015
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4068
5016
|
* Swish, computed element-wise:
|
|
4069
5017
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4072,11 +5020,9 @@ function softSign(x) {
|
|
|
4072
5020
|
*
|
|
4073
5021
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4074
5022
|
*/
|
|
4075
|
-
|
|
4076
|
-
x = fudgeArray(x);
|
|
4077
|
-
return x.ref.mul(sigmoid(x));
|
|
4078
|
-
}
|
|
5023
|
+
const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
|
|
4079
5024
|
/**
|
|
5025
|
+
* @function
|
|
4080
5026
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4081
5027
|
* Swish, computed element-wise:
|
|
4082
5028
|
* `silu(x) = x * sigmoid(x) = x / (1 + exp(-x))`.
|
|
@@ -4093,8 +5039,88 @@ const swish = silu;
|
|
|
4093
5039
|
function logSigmoid(x) {
|
|
4094
5040
|
return negative(softplus(negative(x)));
|
|
4095
5041
|
}
|
|
4096
|
-
/**
|
|
5042
|
+
/**
|
|
5043
|
+
* @function
|
|
5044
|
+
* Identity activation function. Returns the argument unmodified.
|
|
5045
|
+
*/
|
|
4097
5046
|
const identity = fudgeArray;
|
|
5047
|
+
/**
 * Leaky rectified linear unit (ReLU) activation function.
 *
 * Negative entries are scaled by `negativeSlope` (default 0.01); all other
 * entries pass through unchanged.
 */
function leakyRelu(x, negativeSlope = .01) {
	x = fudgeArray(x);
	const isNegative = less(x.ref, 0);
	const scaled = x.ref.mul(negativeSlope);
	return where(isNegative, scaled, x);
}
|
|
5052
|
+
/**
|
|
5053
|
+
* Exponential linear unit activation function.
|
|
5054
|
+
*
|
|
5055
|
+
* Computes the element-wise function:
|
|
5056
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
5057
|
+
*/
|
|
5058
|
+
function elu(x, alpha = 1) {
	x = fudgeArray(x);
	const isNegative = less(x.ref, 0);
	// alpha * (exp(x) - 1), selected only where x is negative.
	const saturating = exp(x.ref).sub(1).mul(alpha);
	return where(isNegative, saturating, x);
}
|
|
5062
|
+
/**
|
|
5063
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
5064
|
+
*
|
|
5065
|
+
* Computes the element-wise function:
|
|
5066
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
5067
|
+
*/
|
|
5068
|
+
function celu(x, alpha = 1) {
	x = fudgeArray(x);
	const isNegative = less(x.ref, 0);
	// alpha * (exp(x / alpha) - 1), selected only where x is negative.
	const saturating = exp(x.ref.div(alpha)).sub(1).mul(alpha);
	return where(isNegative, saturating, x);
}
|
|
5072
|
+
/**
|
|
5073
|
+
* @function
|
|
5074
|
+
* Gaussian error linear unit (GELU) activation function.
|
|
5075
|
+
*
|
|
5076
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
5077
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
5078
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
5079
|
+
*
|
|
5080
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
5081
|
+
*
|
|
5082
|
+
* This will be improved in the future.
|
|
5083
|
+
*/
|
|
5084
|
+
const gelu = jit$1((x) => {
	const sqrtTwoOverPi = Math.sqrt(2 / Math.PI);
	const halfX = x.ref.mul(.5);
	const xCopy = x.ref;
	// 1 + 0.044715 * x^2; multiplied by x below to give x + 0.044715 * x^3.
	const cubicFactor = x.ref.mul(x).mul(.044715).add(1);
	const tanhArg = xCopy.mul(cubicFactor).mul(sqrtTwoOverPi);
	return halfX.mul(tanh(tanhArg).add(1));
});
|
|
5088
|
+
/**
|
|
5089
|
+
* Gated linear unit (GLU) activation function.
|
|
5090
|
+
*
|
|
5091
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
5092
|
+
* computes `a * sigmoid(b)`.
|
|
5093
|
+
*/
|
|
5094
|
+
function glu(x, axis = -1) {
	x = fudgeArray(x);
	axis = require_backend.checkAxis(axis, x.ndim);
	const length = x.shape[axis];
	if (length % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
	const half = length / 2;
	// Full [start, end) range for every axis; only `axis` is narrowed below.
	const fullRanges = x.shape.map((extent) => [0, extent]);
	const valueHalf = shrink(x.ref, fullRanges.toSpliced(axis, 1, [0, half]));
	const gateHalf = shrink(x, fullRanges.toSpliced(axis, 1, [half, length]));
	return valueHalf.mul(sigmoid(gateHalf));
}
|
|
5104
|
+
/**
|
|
5105
|
+
* Squareplus activation function.
|
|
5106
|
+
*
|
|
5107
|
+
* Computes the element-wise function:
|
|
5108
|
+
* `squareplus(x) = 0.5 * (x + sqrt(x^2 + b))`
|
|
5109
|
+
*/
|
|
5110
|
+
function squareplus(x, b = 4) {
	x = fudgeArray(x);
	const xCopy = x.ref;
	const root = sqrt(square(x).add(b));
	return xCopy.add(root).mul(.5);
}
|
|
5114
|
+
/**
|
|
5115
|
+
* Mish activation function.
|
|
5116
|
+
*
|
|
5117
|
+
* Computes the element-wise function:
|
|
5118
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
5119
|
+
*/
|
|
5120
|
+
function mish(x) {
	x = fudgeArray(x);
	const xCopy = x.ref;
	const gate = tanh(softplus(x));
	return xCopy.mul(gate);
}
|
|
4098
5124
|
/**
|
|
4099
5125
|
* Softmax function. Computes the function which rescales elements to the range
|
|
4100
5126
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -4103,17 +5129,13 @@ const identity = fudgeArray;
|
|
|
4103
5129
|
*
|
|
4104
5130
|
* Reference: https://en.wikipedia.org/wiki/Softmax_function
|
|
4105
5131
|
*/
|
|
4106
|
-
function softmax(x, axis) {
|
|
5132
|
+
function softmax(x, axis = -1) {
  x = fudgeArray(x);
  const axes = require_backend.normalizeAxis(axis, x.ndim);
  // With no axes to reduce over, the result is all ones.
  if (axes.length === 0) {
    return onesLike(x);
  }
  // Subtract the (gradient-stopped) maximum along the axes before exponentiating.
  const peak = stopGradient(max(x.ref, axes, { keepdims: true }));
  const numerator = exp(x.sub(peak));
  const numeratorRef = numerator.ref;
  const total = numerator.sum(axes, { keepdims: true });
  return numeratorRef.div(total);
}
|
|
4118
5140
|
/**
|
|
4119
5141
|
* Log-Softmax function.
|
|
@@ -4123,17 +5145,13 @@ function softmax(x, axis) {
|
|
|
4123
5145
|
*
|
|
4124
5146
|
* If `axis` is not specified, it defaults to the last axis.
|
|
4125
5147
|
*/
|
|
4126
|
-
function logSoftmax(x, axis) {
|
|
5148
|
+
function logSoftmax(x, axis = -1) {
  x = fudgeArray(x);
  const axes = require_backend.normalizeAxis(axis, x.ndim);
  // With no axes to reduce over, the result is all zeros.
  if (axes.length === 0) {
    return zerosLike(x);
  }
  // Shift by the (gradient-stopped) maximum along the axes, then subtract log-sum-exp.
  const peak = stopGradient(max(x.ref, axes, { keepdims: true }));
  const centered = x.sub(peak);
  const sumOfExp = exp(centered.ref).sum(axes, { keepdims: true });
  return centered.sub(log(sumOfExp));
}
|
|
4139
5157
|
/**
|
|
@@ -4144,16 +5162,39 @@ function logSoftmax(x, axis) {
|
|
|
4144
5162
|
*
|
|
4145
5163
|
* Reference: https://en.wikipedia.org/wiki/LogSumExp
|
|
4146
5164
|
*/
|
|
4147
|
-
function logsumexp(x, axis) {
|
|
5165
|
+
function logsumexp(x, axis = null) {
  x = fudgeArray(x);
  const axes = require_backend.normalizeAxis(axis, x.ndim);
  // Nothing to reduce: hand the input back unchanged.
  if (axes.length === 0) {
    return x;
  }
  const peak = stopGradient(max(x.ref, axes));
  // Broadcast the reduced maximum back to the input shape before subtracting.
  const centered = x.sub(broadcast(peak.ref, x.shape, axes));
  const logOfSum = log(exp(centered).sum(axes));
  return peak.add(logOfSum);
}
|
|
5174
|
+
/** Log-mean-exp reduction, like `jax.nn.logsumexp()` but subtracts `log(n)`. */
|
|
5175
|
+
function logmeanexp(x, axis = null) {
  x = fudgeArray(x);
  const axes = require_backend.normalizeAxis(axis, x.ndim);
  // Nothing to reduce: hand the input back unchanged.
  if (axes.length === 0) {
    return x;
  }
  // Number of elements folded into each output entry.
  let count = 1;
  for (const a of axes) {
    count *= x.shape[a];
  }
  return logsumexp(x, axes).sub(Math.log(count));
}
|
|
5182
|
+
/**
|
|
5183
|
+
* Standardizes input to zero mean and unit variance.
|
|
5184
|
+
*
|
|
5185
|
+
* By default, this is computed over the last axis. You can pass in a different
|
|
5186
|
+
* axis, or `null` to standardize over all elements.
|
|
5187
|
+
*
|
|
5188
|
+
* `epsilon` is added to the variance inside the square root for numerical stability; it defaults to `1e-5`.
|
|
5189
|
+
*/
|
|
5190
|
+
function standardize(x, axis = -1, opts = {}) {
  x = fudgeArray(x);
  const axes = require_backend.normalizeAxis(axis, x.ndim);
  // With no axes to reduce over, the input is returned as-is (opts are not consulted).
  if (axes.length === 0) {
    return x;
  }

  // Mean: caller-supplied if given, otherwise computed along the axes.
  let mu;
  if (opts.mean !== void 0) {
    mu = fudgeArray(opts.mean);
  } else {
    mu = x.ref.mean(axes, { keepdims: true });
  }

  // Variance: caller-supplied if given, otherwise E[x^2] - mean^2.
  let sigma2;
  if (opts.variance !== void 0) {
    sigma2 = fudgeArray(opts.variance);
  } else {
    const meanOfSquares = square(x.ref).mean(axes, { keepdims: true });
    sigma2 = meanOfSquares.sub(square(mu.ref));
  }

  const centered = x.sub(mu);
  const scale = sqrt(sigma2.add(opts.epsilon ?? 1e-5));
  return centered.div(scale);
}
|
|
4157
5198
|
/**
|
|
4158
5199
|
* One-hot encodes the given indices.
|
|
4159
5200
|
*
|
|
@@ -4171,7 +5212,7 @@ function logsumexp(x, axis) {
|
|
|
4171
5212
|
* ```
|
|
4172
5213
|
*/
|
|
4173
5214
|
function oneHot(x, numClasses) {
  const dtype = x.dtype;
  // Floating-point indices are rejected.
  if (require_backend.isFloatDtype(dtype)) {
    throw new TypeError(`oneHot expects integers, got ${dtype}`);
  }
  // Index an identity matrix (created on the same device as `x`) with `x`.
  const identityMatrix = eye(numClasses, void 0, { device: x.device });
  return identityMatrix.slice(x);
}
|
|
4177
5218
|
|
|
@@ -4179,8 +5220,11 @@ function oneHot(x, numClasses) {
|
|
|
4179
5220
|
//#region src/random.ts
|
|
4180
5221
|
var random_exports = {};
|
|
4181
5222
|
__export(random_exports, {
|
|
5223
|
+
bernoulli: () => bernoulli,
|
|
4182
5224
|
bits: () => bits,
|
|
5225
|
+
exponential: () => exponential,
|
|
4183
5226
|
key: () => key,
|
|
5227
|
+
normal: () => normal,
|
|
4184
5228
|
split: () => split,
|
|
4185
5229
|
uniform: () => uniform
|
|
4186
5230
|
});
|
|
@@ -4211,11 +5255,11 @@ function bits(key$1, shape$1 = []) {
|
|
|
4211
5255
|
/** Sample uniform random values in [minval, maxval) with given shape. */
|
|
4212
5256
|
function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
4213
5257
|
if (minval >= maxval) throw new Error(`Invalid range: [${minval}, ${maxval}).`);
|
|
4214
|
-
const mantissa = bits(key$1, shape$1).div(
|
|
5258
|
+
const mantissa = bits(key$1, shape$1).div(array(512, {
|
|
4215
5259
|
dtype: require_backend.DType.Uint32,
|
|
4216
5260
|
device: key$1.device
|
|
4217
5261
|
}));
|
|
4218
|
-
const float12 = mantissa.add(
|
|
5262
|
+
const float12 = mantissa.add(array(1065353216, {
|
|
4219
5263
|
dtype: require_backend.DType.Uint32,
|
|
4220
5264
|
device: key$1.device
|
|
4221
5265
|
}));
|
|
@@ -4223,6 +5267,36 @@ function uniform(key$1, shape$1 = [], { minval = 0, maxval = 1 } = {}) {
|
|
|
4223
5267
|
if (minval === 0 && maxval === 1) return rand;
|
|
4224
5268
|
else return rand.mul(maxval - minval).add(minval);
|
|
4225
5269
|
}
|
|
5270
|
+
/**
|
|
5271
|
+
* Sample Bernoulli random variables with given mean (0,1 categorical).
|
|
5272
|
+
*
|
|
5273
|
+
* Returns a random Boolean array with the specified shape. `p` can be an array
|
|
5274
|
+
* and must be broadcastable to `shape`.
|
|
5275
|
+
*/
|
|
5276
|
+
function bernoulli(key$1, p = .5, shape$1 = []) {
  const threshold = fudgeArray(p);
  // A uniform sample falls below `p` with probability `p`.
  const samples = uniform(key$1, shape$1);
  return samples.less(threshold);
}
|
|
5280
|
+
/** Sample exponential random values according to `p(x) = exp(-x)`. */
|
|
5281
|
+
function exponential(key$1, shape$1 = []) {
  // Inverse-CDF sampling: -log(1 - u), written with log1p.
  const logOneMinusU = log1p(negative(uniform(key$1, shape$1)));
  return negative(logOneMinusU);
}
|
|
5285
|
+
/**
|
|
5286
|
+
* Sample random values according to `p(x) = 1/sqrt(2pi) * exp(-x^2/2)`.
|
|
5287
|
+
*
|
|
5288
|
+
* Unlike JAX, this uses the Box-Muller transform. JAX uses the erf_inv primitive instead and
|
|
5289
|
+
* directly inverts the CDF, but we don't have support for that yet. Outputs will not be
|
|
5290
|
+
* bitwise identical to JAX.
|
|
5291
|
+
*/
|
|
5292
|
+
function normal(key$1, shape$1 = []) {
  // Two independent uniform draws, one for the radius and one for the angle.
  const [radiusKey, angleKey] = split(key$1, 2);
  const radial = uniform(radiusKey, shape$1);
  const angular = uniform(angleKey, shape$1);
  // Box-Muller: sqrt(-2 * log(1 - u1)) * cos(2 * pi * u2).
  const radius = sqrt(log1p(negative(radial)).mul(-2));
  return radius.mul(cos(angular.mul(2 * Math.PI)));
}
|
|
4226
5300
|
|
|
4227
5301
|
//#endregion
|
|
4228
5302
|
//#region src/polyfills.ts
|
|
@@ -4232,35 +5306,98 @@ Symbol.asyncDispose ??= Symbol.for("Symbol.asyncDispose");
|
|
|
4232
5306
|
|
|
4233
5307
|
//#endregion
|
|
4234
5308
|
//#region src/index.ts
|
|
4235
|
-
/**
|
|
5309
|
+
/**
 * @function
 * Forward-mode Jacobian-vector product of a function.
 */
const jvp = jvp$1;
/**
 * @function
 * Vectorize an operation over a batched axis of one or more inputs.
 */
const vmap = vmap$1;
/**
 * @function
 * Jacobian computed column-by-column using forward-mode AD.
 */
const jacfwd = jacfwd$1;
/**
 * @function
 * Build a Jaxpr by dynamically tracing a function with example inputs.
 */
const makeJaxpr = makeJaxpr$1;
/**
 * @function
 * Mark a function for automatic JIT compilation, with operator fusion.
 *
 * Compilation happens the first time the function is called with a given set
 * of argument shapes.
 *
 * After all calls are done, you can call `.dispose()` on the returned
 * JIT-compiled function to free memory associated with array constants.
 *
 * **Options:**
 * - `staticArgnums`: Array of argument indices to treat as static
 *   (compile-time constant). These arguments must be hashable, are not traced,
 *   and different values trigger recompilation.
 * - `device`: The device to place the computation on. When omitted, the
 *   computation is placed on the default device.
 */
const jit = jit$1;
/**
 * @function
 * Produce a local linear approximation to a function at a point, using jvp()
 * and partial evaluation.
 */
const linearize = linearize$1;
/**
 * @function
 * Reverse-mode vector-Jacobian product of a function.
 */
const vjp = vjp$1;
/**
 * @function
 * Gradient of a scalar-valued function `f` with respect to its first
 * argument.
 */
const grad = grad$1;
/**
 * @function
 * Create a function that evaluates both `f` and the gradient of `f`.
 */
const valueAndGrad = valueAndGrad$1;
/**
 * @function
 * Jacobian computed row-by-row using reverse-mode AD.
 */
const jacrev = jacrev$1;
/**
 * @function
 * Jacobian computed with reverse-mode AD. Alias for `jacrev()`.
 */
const jacobian = jacrev;
|
|
5379
|
+
/**
|
|
5380
|
+
* Wait until all `Array` leaves are ready by calling `Array.blockUntilReady()`.
|
|
5381
|
+
*
|
|
5382
|
+
* This can be used to wait for the results of an intermediate computation to
|
|
5383
|
+
* finish. It's recommended to call this regularly in an iterative computation
|
|
5384
|
+
* to avoid queueing up too many pending operations.
|
|
5385
|
+
*
|
|
5386
|
+
* Does not consume reference to the arrays.
|
|
5387
|
+
*/
|
|
5388
|
+
async function blockUntilReady(x) {
  // Collect one readiness promise per `Array` leaf; other leaves are ignored.
  const pending = [];
  for (const leaf of leaves(x)) {
    if (!(leaf instanceof Array$1)) continue;
    pending.push(leaf.blockUntilReady());
  }
  await Promise.all(pending);
  // The same tree that was passed in is returned.
  return x;
}
|
|
4262
5394
|
|
|
4263
5395
|
//#endregion
|
|
5396
|
+
exports.Array = Array$1;
|
|
5397
|
+
exports.DType = require_backend.DType;
|
|
5398
|
+
exports.Jaxpr = Jaxpr;
|
|
5399
|
+
exports.blockUntilReady = blockUntilReady;
|
|
5400
|
+
exports.defaultDevice = require_backend.defaultDevice;
|
|
4264
5401
|
exports.devices = require_backend.devices;
|
|
4265
5402
|
exports.grad = grad;
|
|
4266
5403
|
exports.init = require_backend.init;
|
|
@@ -4269,6 +5406,12 @@ exports.jacobian = jacobian;
|
|
|
4269
5406
|
exports.jacrev = jacrev;
|
|
4270
5407
|
exports.jit = jit;
|
|
4271
5408
|
exports.jvp = jvp;
|
|
5409
|
+
Object.defineProperty(exports, 'lax', {
|
|
5410
|
+
enumerable: true,
|
|
5411
|
+
get: function () {
|
|
5412
|
+
return lax_exports;
|
|
5413
|
+
}
|
|
5414
|
+
});
|
|
4272
5415
|
exports.linearize = linearize;
|
|
4273
5416
|
exports.makeJaxpr = makeJaxpr;
|
|
4274
5417
|
Object.defineProperty(exports, 'nn', {
|
|
@@ -4289,7 +5432,7 @@ Object.defineProperty(exports, 'random', {
|
|
|
4289
5432
|
return random_exports;
|
|
4290
5433
|
}
|
|
4291
5434
|
});
|
|
4292
|
-
exports.
|
|
5435
|
+
exports.setDebug = require_backend.setDebug;
|
|
4293
5436
|
Object.defineProperty(exports, 'tree', {
|
|
4294
5437
|
enumerable: true,
|
|
4295
5438
|
get: function () {
|