@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
package/dist/index.cjs
CHANGED
|
@@ -30,13 +30,14 @@ var __toESM = (mod, isNodeMode, target) => (target = mod != null ? __create(__ge
|
|
|
30
30
|
}) : target, mod));
|
|
31
31
|
|
|
32
32
|
//#endregion
|
|
33
|
-
const require_backend = require('./backend-
|
|
33
|
+
const require_backend = require('./backend-D2C4MJRP.cjs');
|
|
34
34
|
|
|
35
35
|
//#region src/tree.ts
|
|
36
36
|
var tree_exports = {};
|
|
37
37
|
__export(tree_exports, {
|
|
38
38
|
JsTreeDef: () => JsTreeDef,
|
|
39
39
|
NodeType: () => NodeType,
|
|
40
|
+
dispose: () => dispose,
|
|
40
41
|
flatten: () => flatten,
|
|
41
42
|
leaves: () => leaves,
|
|
42
43
|
map: () => map,
|
|
@@ -51,7 +52,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
|
51
52
|
NodeType$1["Leaf"] = "Leaf";
|
|
52
53
|
return NodeType$1;
|
|
53
54
|
}({});
|
|
54
|
-
/**
|
|
55
|
+
/** Represents the structure of a JsTree. */
|
|
55
56
|
var JsTreeDef = class JsTreeDef {
|
|
56
57
|
static leaf = new JsTreeDef(NodeType.Leaf, null, []);
|
|
57
58
|
constructor(nodeType, nodeMetadata, childTreedefs) {
|
|
@@ -139,6 +140,194 @@ function map(fn, tree, ...rest) {
|
|
|
139
140
|
function ref(tree) {
|
|
140
141
|
return map((x) => x.ref, tree);
|
|
141
142
|
}
|
|
143
|
+
/**
 * Dispose every array in a tree by calling `dispose()` on each element
 * visited by `map`.
 *
 * A falsy `tree` (for example `null` or `undefined`) is ignored.
 */
function dispose(tree) {
  if (!tree) return;
  map((leaf) => leaf.dispose(), tree);
}
|
|
147
|
+
|
|
148
|
+
//#endregion
|
|
149
|
+
//#region src/frontend/convolution.ts
|
|
150
|
+
/**
 * Validate the shapes and parameters passed to a convolution.
 *
 * Axes 0 and 1 of each shape are non-spatial (axis 1 of both operands must
 * agree and is reported as the input channel count). Every later axis is a
 * spatial dimension and needs one entry in each of `strides`, `padding`,
 * `lhsDilation` and `rhsDilation`.
 *
 * @param lhsShape Shape of the left-hand operand.
 * @param rhsShape Shape of the right-hand (kernel) operand.
 * @param params `strides`, `lhsDilation`, `rhsDilation` (one positive integer
 *   per spatial axis) and `padding` (one integer `[low, high]` pair per
 *   spatial axis; entries may be negative within the checked limits).
 * @returns The output shape `[lhsShape[0], rhsShape[0], ...spatialSizes]`.
 * @throws {Error} If any rank, length or value check fails.
 */
function checkConvShape(lhsShape, rhsShape, { strides, padding, lhsDilation, rhsDilation }) {
  if (lhsShape.length !== rhsShape.length) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  }
  const spatialRank = lhsShape.length - 2;
  if (spatialRank < 0) throw new Error("conv() requires at least 2D inputs");
  if (strides.length !== spatialRank) throw new Error("conv() strides != spatial dims");
  if (padding.length !== spatialRank) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== spatialRank) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== spatialRank) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) {
    throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);
  }

  const isPositiveInt = (value) => Number.isInteger(value) && value > 0;
  const spatialSizes = [];
  for (let axis = 0; axis < spatialRank; axis += 1) {
    if (!isPositiveInt(strides[axis])) {
      throw new Error(`conv() strides[${axis}] must be a positive integer`);
    }
    const pads = padding[axis];
    if (!(pads.length === 2 && pads.every((p) => Number.isInteger(p)))) {
      throw new Error(`conv() padding[${axis}] must be a 2-tuple of integers`);
    }
    if (!isPositiveInt(lhsDilation[axis])) {
      throw new Error(`conv() lhsDilation[${axis}] must be a positive integer`);
    }
    if (!isPositiveInt(rhsDilation[axis])) {
      throw new Error(`conv() rhsDilation[${axis}] must be a positive integer`);
    }

    const inputExtent = lhsShape[axis + 2];
    const kernelExtent = rhsShape[axis + 2];
    if (kernelExtent <= 0) throw new Error("conv() kernel size must be positive");

    const [padLow, padHigh] = pads;
    if (padLow < -inputExtent || padHigh < -inputExtent || padLow + padHigh < -inputExtent) {
      throw new Error(`conv() padding[${axis}]=(${padLow},${padHigh}) is too negative for input size ${inputExtent}`);
    }

    // Effective extents once dilation (and, for the input, padding) is applied.
    const kernelSize = (kernelExtent - 1) * rhsDilation[axis] + 1;
    const inSize = Math.max((inputExtent - 1) * lhsDilation[axis] + 1, 0) + padLow + padHigh;
    if (kernelSize > inSize) {
      throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${axis}`);
    }
    spatialSizes.push(Math.ceil((inSize - kernelSize + 1) / strides[axis]));
  }
  return [lhsShape[0], rhsShape[0], ...spatialSizes];
}
|
|
181
|
+
/**
 * Validate a pooling window against an input shape and compute the pooled shape.
 *
 * The window applies to the trailing `window.length` axes of `inShape`.
 *
 * @param inShape Shape of the input.
 * @param window Window size per pooled axis (positive integers, each no larger
 *   than the matching input axis).
 * @param strides Stride per pooled axis (positive integers).
 * @returns `[...leadingAxes, ...pooledSizes, ...window]`.
 * @throws {Error} If the lengths are inconsistent or a value is out of range.
 */
function checkPoolShape(inShape, window, strides) {
  const windowRank = window.length;
  if (strides.length !== windowRank) throw new Error("pool() strides != window dims");
  if (windowRank > inShape.length) throw new Error("pool() window has more dimensions than input");

  const firstPooledAxis = inShape.length - windowRank;
  const leadingAxes = inShape.slice(0, firstPooledAxis);
  const pooledSizes = [];
  for (let axis = 0; axis < windowRank; axis += 1) {
    const extent = window[axis];
    const step = strides[axis];
    const inputExtent = inShape[firstPooledAxis + axis];
    if (!(Number.isInteger(extent) && extent > 0)) {
      throw new Error(`pool() window[${axis}] must be a positive integer`);
    }
    if (extent > inputExtent) {
      throw new Error(`pool() window[${axis}]=${extent} > input size ${inputExtent}`);
    }
    if (!(Number.isInteger(step) && step > 0)) {
      throw new Error(`pool() strides[${axis}] must be a positive integer`);
    }
    pooledSizes.push(Math.ceil((inputExtent - extent + 1) / step));
  }
  return leadingAxes.concat(pooledSizes, window);
}
|
|
196
|
+
/**
 * Turn the last `ks.length` axes of a shape tracker into a sliding-window
 * view: each of those axes is split into an output-position axis and a
 * kernel-offset axis, so the tracker gains `ks.length` extra axes.
 *
 * The returned tracker has shape `[...leading, ...outputSizes, ...ks]`.
 * Multiplying it with a kernel of shape `ks` and reducing over the last
 * `ks.length` axes gives a convolution.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 *
 * @param st Shape tracker to reshape.
 * @param ks Kernel size per pooled axis; an empty list returns `st` unchanged.
 * @param strides A single number (used for every axis) or one value per axis.
 * @param dilation A single number (used for every axis) or one value per axis.
 * @throws {Error} If `ks` has more axes than `st`, or a stride/dilation is not
 *   a positive integer.
 */
function pool(st, ks, strides = 1, dilation = 1) {
  const rank = ks.length;
  if (rank === 0) return st;
  if (st.shape.length < rank) throw new Error("pool() called with too many dimensions");

  const strideList = typeof strides === "number" ? require_backend.rep(rank, strides) : strides;
  const dilationList = typeof dilation === "number" ? require_backend.rep(rank, dilation) : dilation;
  const isPositiveInt = (value) => Number.isInteger(value) && value > 0;
  if (!strideList.every((s) => isPositiveInt(s))) throw new Error("pool() strides must be positive integers");
  if (!dilationList.every((d) => isPositiveInt(d))) throw new Error("pool() dilation must be positive integers");

  const lead = st.shape.slice(0, -rank);
  const spatial = st.shape.slice(-rank);
  // Full-range slice for every leading (untouched) axis; rebuilt for each call.
  const keepLead = () => lead.map((extent) => [0, extent]);

  const outSizes = require_backend
    .zipn(spatial, dilationList, ks, strideList)
    .map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const factors = require_backend
    .zipn(outSizes, strideList, spatial, dilationList, ks)
    .map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  // Per pooled axis: kernel size `k`, input size `i`, and row length `i * f + d`.
  const rows = require_backend
    .zipn(ks, spatial, dilationList, factors)
    .map(([k, i, d, f]) => ({ k, i, row: i * f + d }));

  // Repeat each pooled axis until it holds at least `k` rows of length `row`.
  let view = st.repeat([
    ...require_backend.rep(lead.length, 1),
    ...rows.map(({ k, i, row }) => Math.ceil(k * row / i)),
  ]);
  view = view
    .shrink([...keepLead(), ...rows.map(({ k, row }) => [0, k * row])])
    .reshape([...lead, ...rows.flatMap(({ k, row }) => [k, row])]);

  // Trim every row to `o * s` entries, then expose the stride as its own axis.
  const kos = require_backend.zipn(ks, outSizes, strideList);
  view = view
    .shrink([...keepLead(), ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])])
    .reshape([...lead, ...kos.flatMap(([k, o, s]) => [k, o, s])]);

  // Keep only the first element along each stride axis and drop that axis.
  view = view
    .shrink([...keepLead(), ...kos.flatMap(([k, o]) => [[0, k], [0, o], [0, 1]])])
    .reshape([...lead, ...kos.flatMap(([k, o]) => [k, o])]);

  // Reorder the (k, o) pairs so that all output axes precede all kernel axes.
  return view.permute([
    ...require_backend.range(lead.length),
    ...ks.map((_, j) => lead.length + 2 * j + 1),
    ...ks.map((_, j) => lead.length + 2 * j),
  ]);
}
|
|
236
|
+
/**
 * Apply the transpose of `pool()`, undoing the view transformation step by
 * step (permute, then alternating reshape/pad, then a final permute).
 *
 * `pool()` repeats its input, so a true transpose also needs a sum
 * reduction. That reduction is NOT performed here; callers should reduce over
 * the last `ks.length` axes of the returned shape.
 *
 * @param st Tracker whose axes after the leading ones start with the pooled
 *   output sizes (checked against the sizes derived from `inShape`).
 * @param inShape Shape of the original, un-pooled input.
 * @param ks Kernel size per pooled axis; an empty list returns `st` unchanged.
 * @param strides A single number (used for every axis) or one value per axis.
 * @param dilation A single number (used for every axis) or one value per axis.
 * @throws {Error} If the output sizes implied by `inShape` do not match `st`.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  const rank = ks.length;
  if (rank === 0) return st;

  const strideList = typeof strides === "number" ? require_backend.rep(rank, strides) : strides;
  const dilationList = typeof dilation === "number" ? require_backend.rep(rank, dilation) : dilation;

  const lead = inShape.slice(0, -rank);
  const spatial = inShape.slice(-rank);
  // Zero padding for every leading (untouched) axis; rebuilt for each call.
  const noLeadPad = () => lead.map(() => [0, 0]);

  const outSizes = require_backend
    .zipn(spatial, dilationList, ks, strideList)
    .map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const actualOut = st.shape.slice(lead.length, lead.length + rank);
  if (!require_backend.deepEqual(outSizes, actualOut)) {
    throw new Error("poolTranspose() called with mismatched output shape");
  }
  const factors = require_backend
    .zipn(outSizes, strideList, spatial, dilationList, ks)
    .map(([o, s, i, d, k]) => 1 + Number(o * s > i - d * (k - 1)));
  // Per pooled axis: kernel size `k`, input size `i`, and row length `i * f + d`.
  const rows = require_backend
    .zipn(ks, spatial, dilationList, factors)
    .map(([k, i, d, f]) => ({ k, i, row: i * f + d }));
  const kos = require_backend.zipn(ks, outSizes, strideList);

  // Interleave the axes back into (k, o) pairs.
  let view = st.permute([
    ...require_backend.range(lead.length),
    ...ks.flatMap((_, j) => [lead.length + rank + j, lead.length + j]),
  ]);

  // Re-introduce a stride axis of length 1 and pad it out to `s`.
  view = view
    .reshape([...lead, ...kos.flatMap(([k, o]) => [k, o, 1])])
    .pad([...noLeadPad(), ...strideList.flatMap((s) => [[0, 0], [0, 0], [0, s - 1]])]);

  // Merge (o, s) into one axis and pad each row up to its full length.
  view = view
    .reshape([...lead, ...kos.flatMap(([k, o, s]) => [k, o * s])])
    .pad([
      ...noLeadPad(),
      ...rows.flatMap(({ row }, j) => [[0, 0], [0, row - outSizes[j] * strideList[j]]]),
    ]);

  // Flatten the rows and pad up to a whole number of input-sized repeats.
  view = view
    .reshape([...lead, ...rows.map(({ k, row }) => k * row)])
    .pad([
      ...noLeadPad(),
      ...rows.map(({ k, i, row }) => [0, Math.ceil(k * row / i) * i - k * row]),
    ]);

  // Split into (repeat, input) pairs and move the repeat axes to the end.
  return view
    .reshape([...lead, ...rows.flatMap(({ k, i, row }) => [Math.ceil(k * row / i), i])])
    .permute([
      ...require_backend.range(lead.length),
      ...ks.map((_, j) => lead.length + 2 * j + 1),
      ...ks.map((_, j) => lead.length + 2 * j),
    ]);
}
|
|
275
|
+
/**
 * Apply dilation directly to a shape tracker, for transposed convolution.
 *
 * The first two axes of `st` are left alone. Each remaining axis of size `k`
 * with dilation factor `s` grows to `(k - 1) * s + 1`: a unit axis is
 * inserted after it, padded at the end to length `s`, merged back in, and the
 * trailing excess is shrunk away. Returns `st` itself when every factor is 1.
 *
 * @param st Shape tracker to dilate.
 * @param dilation One dilation factor per axis after the first two.
 */
function applyDilation(st, dilation) {
  const isIdentity = dilation.every((factor) => factor === 1);
  if (isIdentity) return st;

  const [batch, channels, ...spatial] = st.shape;
  const withUnitAxes = st.reshape([batch, channels, ...spatial.flatMap((size) => [size, 1])]);
  const padded = withUnitAxes.pad([
    [0, 0],
    [0, 0],
    ...dilation.flatMap((factor) => [[0, 0], [0, factor - 1]]),
  ]);
  const merged = padded.reshape([batch, channels, ...spatial.map((size, i) => size * dilation[i])]);
  // Drop the padding that trails the final element of each dilated axis.
  return merged.shrink([
    [0, batch],
    [0, channels],
    ...spatial.map((size, i) => [0, (size - 1) * dilation[i] + 1]),
  ]);
}
|
|
302
|
+
/**
 * Build the pair of shape trackers used to evaluate a convolution.
 *
 * The left operand is dilated (`params.lhsDilation`), padded or shrunk on its
 * spatial axes (`params.padding`), and pooled with the kernel's spatial size
 * (`params.strides`, `params.rhsDilation`). Its axis 1 is then moved next to
 * the kernel axes and flattened together with them into one trailing axis.
 * The right operand keeps axis 0, gets one unit axis per spatial dimension,
 * and has the rest flattened into one trailing axis.
 *
 * Shapes are not validated here; check them first with `checkConvShape()`.
 *
 * @returns `[stX, stY]`, the transformed trackers for the two operands.
 */
function prepareConv(stX, stY, params) {
  const spatialRank = stX.shape.length - 2;
  const dilated = applyDilation(stX, params.lhsDilation);
  const ks = stY.shape.slice(2);
  const padded = dilated.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  const windows = pool(padded, ks, params.strides, params.rhsDilation);

  // The new shape is read from `windows`, i.e. before axis 1 is moved.
  const lhs = windows.moveaxis(1, spatialRank + 1).reshape([
    windows.shape[0],
    1,
    ...windows.shape.slice(2, spatialRank + 2),
    windows.shape[1] * require_backend.prod(ks),
  ]);
  const rhs = stY.reshape([
    stY.shape[0],
    ...require_backend.rep(spatialRank, 1),
    stY.shape[1] * require_backend.prod(ks),
  ]);
  return [lhs, rhs];
}
|
|
142
331
|
|
|
143
332
|
//#endregion
|
|
144
333
|
//#region src/frontend/core.ts
|
|
@@ -167,10 +356,14 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
167
356
|
Primitive$1["Cos"] = "cos";
|
|
168
357
|
Primitive$1["Exp"] = "exp";
|
|
169
358
|
Primitive$1["Log"] = "log";
|
|
359
|
+
Primitive$1["Sqrt"] = "sqrt";
|
|
170
360
|
Primitive$1["Min"] = "min";
|
|
171
361
|
Primitive$1["Max"] = "max";
|
|
172
362
|
Primitive$1["Reduce"] = "reduce";
|
|
173
363
|
Primitive$1["Dot"] = "dot";
|
|
364
|
+
Primitive$1["Conv"] = "conv";
|
|
365
|
+
Primitive$1["Pool"] = "pool";
|
|
366
|
+
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
174
367
|
Primitive$1["Compare"] = "compare";
|
|
175
368
|
Primitive$1["Where"] = "where";
|
|
176
369
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -234,6 +427,9 @@ function exp$1(x) {
|
|
|
234
427
|
function log$1(x) {
|
|
235
428
|
return bind1(Primitive.Log, [x]);
|
|
236
429
|
}
|
|
430
|
+
/** Bind the `Sqrt` primitive to the single operand `x`. */
function sqrt$1(x) {
  const operands = [x];
  return bind1(Primitive.Sqrt, operands);
}
|
|
237
433
|
function min$1(x, y) {
|
|
238
434
|
return bind1(Primitive.Min, [x, y]);
|
|
239
435
|
}
|
|
@@ -256,6 +452,17 @@ function reduce(x, op, axis, opts) {
|
|
|
256
452
|
function dot$1(x, y) {
|
|
257
453
|
return bind1(Primitive.Dot, [x, y]);
|
|
258
454
|
}
|
|
455
|
+
/**
 * Bind the `Conv` primitive to `x` and `y`.
 *
 * Any of `strides`, `padding`, `lhsDilation` and `rhsDilation` missing from
 * `params` is filled in per spatial axis (`x.ndim - 2` axes): 1 for strides
 * and dilations, `[0, 0]` for padding.
 *
 * @throws {Error} If the operands differ in rank or have fewer than 2 axes.
 */
function conv(x, y, params = {}) {
  if (x.ndim !== y.ndim) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
  }
  const spatialRank = x.ndim - 2;
  if (spatialRank < 0) throw new Error("conv() requires at least 2D inputs");

  const { strides, padding, lhsDilation, rhsDilation } = params;
  const resolved = {
    strides: strides ?? require_backend.rep(spatialRank, 1),
    padding: padding ?? require_backend.rep(spatialRank, [0, 0]),
    lhsDilation: lhsDilation ?? require_backend.rep(spatialRank, 1),
    rhsDilation: rhsDilation ?? require_backend.rep(spatialRank, 1),
  };
  return bind1(Primitive.Conv, [x, y], resolved);
}
|
|
259
466
|
function compare(x, y, op) {
|
|
260
467
|
return bind1(Primitive.Compare, [x, y], { op });
|
|
261
468
|
}
|
|
@@ -391,6 +598,9 @@ var Tracer = class Tracer {
|
|
|
391
598
|
get shape() {
|
|
392
599
|
return this.aval.shape;
|
|
393
600
|
}
|
|
601
|
+
get size() {
|
|
602
|
+
return require_backend.prod(this.shape);
|
|
603
|
+
}
|
|
394
604
|
get dtype() {
|
|
395
605
|
return this.aval.dtype;
|
|
396
606
|
}
|
|
@@ -442,7 +652,7 @@ var Tracer = class Tracer {
|
|
|
442
652
|
else if (typeof axis === "number") axis = [require_backend.checkAxis(axis, this.ndim)];
|
|
443
653
|
else axis = axis.map((a) => require_backend.checkAxis(a, this.ndim));
|
|
444
654
|
let result = reduce(this, require_backend.AluOp.Add, axis);
|
|
445
|
-
result = result.mul(
|
|
655
|
+
result = result.mul(result.size / this.size);
|
|
446
656
|
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
447
657
|
return result;
|
|
448
658
|
}
|
|
@@ -476,8 +686,29 @@ var Tracer = class Tracer {
|
|
|
476
686
|
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
477
687
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
478
688
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
689
|
+
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
690
|
+
axis1 = require_backend.checkAxis(axis1, this.ndim);
|
|
691
|
+
axis2 = require_backend.checkAxis(axis2, this.ndim);
|
|
479
692
|
if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
|
|
480
|
-
throw new Error("
|
|
693
|
+
if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
|
|
694
|
+
let ar = this;
|
|
695
|
+
if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
|
|
696
|
+
const perm = require_backend.range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
|
|
697
|
+
ar = ar.transpose(perm);
|
|
698
|
+
}
|
|
699
|
+
const [n, m] = ar.shape.slice(-2);
|
|
700
|
+
const diagSize = Math.min(n, m - offset);
|
|
701
|
+
ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
|
|
702
|
+
const npad = diagSize * (m + 1) - n * m;
|
|
703
|
+
if (npad > 0) ar = pad$1(ar, [...require_backend.rep(ar.ndim - 1, [0, 0]), [0, npad]]);
|
|
704
|
+
else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
|
|
705
|
+
ar = ar.reshape([
|
|
706
|
+
...ar.shape.slice(0, -1),
|
|
707
|
+
diagSize,
|
|
708
|
+
m + 1
|
|
709
|
+
]);
|
|
710
|
+
ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
|
|
711
|
+
return ar;
|
|
481
712
|
}
|
|
482
713
|
/** Flatten the array without changing its data. */
|
|
483
714
|
flatten() {
|
|
@@ -620,7 +851,7 @@ var ShapedArray = class ShapedArray {
|
|
|
620
851
|
get ndim() {
|
|
621
852
|
return this.shape.length;
|
|
622
853
|
}
|
|
623
|
-
|
|
854
|
+
toString() {
|
|
624
855
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
625
856
|
}
|
|
626
857
|
equals(other) {
|
|
@@ -651,7 +882,7 @@ function fullRaise(trace, val) {
|
|
|
651
882
|
if (Object.is(val._trace.main, trace.main)) return val;
|
|
652
883
|
else if (val._trace.main.level < level) return trace.lift(val);
|
|
653
884
|
else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
|
|
654
|
-
else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
|
|
885
|
+
else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
|
|
655
886
|
}
|
|
656
887
|
var TreeMismatchError = class extends TypeError {
|
|
657
888
|
constructor(where$2, left, right) {
|
|
@@ -900,16 +1131,16 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
900
1131
|
jitCompileCache.set(cacheKey, jp);
|
|
901
1132
|
return jp;
|
|
902
1133
|
}
|
|
903
|
-
function reshapeViews(exp$2, mapping) {
|
|
1134
|
+
/**
 * Rewrite every `GlobalView` node of an expression with a remapped shape tracker.
 *
 * For each `GlobalView` node, `mapping` receives the node's shape tracker. If
 * it returns a truthy tracker, the node is replaced by a new global view over
 * it; otherwise the rewrite callback returns `undefined` for that node.
 *
 * @param exp$2 Expression to rewrite (via its `rewrite()` method).
 * @param mapping Function from the old shape tracker to the new one.
 * @param reduceAxis When true, the last axis of the new tracker is indexed by
 *   `AluVar.ridx` and only the remaining axes are unravelled from `AluVar.gidx`.
 * @throws {Error} If the expression contains a `GlobalIndex` node.
 */
function reshapeViews(exp$2, mapping, reduceAxis = false) {
  const { AluOp, AluVar, AluExp, unravelAlu } = require_backend;
  return exp$2.rewrite((node) => {
    if (node.op !== AluOp.GlobalView) {
      if (node.op === AluOp.GlobalIndex) {
        throw new Error("internal: reshapeViews() called with GlobalIndex op");
      }
      return undefined;
    }
    const [gid, st] = node.arg;
    const newSt = mapping(st);
    if (!newSt) return undefined;

    let indices;
    if (reduceAxis) {
      indices = unravelAlu(newSt.shape.slice(0, -1), AluVar.gidx).concat(AluVar.ridx);
    } else {
      indices = unravelAlu(newSt.shape, AluVar.gidx);
    }
    return AluExp.globalView(node.dtype, gid, newSt, indices);
  });
}
|
|
915
1146
|
function broadcastedJit(fn) {
|
|
@@ -958,6 +1189,7 @@ const jitRules = {
|
|
|
958
1189
|
[Primitive.Cos]: unopJit(require_backend.AluExp.cos),
|
|
959
1190
|
[Primitive.Exp]: unopJit(require_backend.AluExp.exp),
|
|
960
1191
|
[Primitive.Log]: unopJit(require_backend.AluExp.log),
|
|
1192
|
+
[Primitive.Sqrt]: unopJit(require_backend.AluExp.sqrt),
|
|
961
1193
|
[Primitive.Min]: broadcastedJit(([a, b]) => require_backend.AluExp.min(a, b)),
|
|
962
1194
|
[Primitive.Max]: broadcastedJit(([a, b]) => require_backend.AluExp.max(a, b)),
|
|
963
1195
|
[Primitive.Reduce](nargs, [a], [as], { op, axis }) {
|
|
@@ -972,18 +1204,20 @@ const jitRules = {
|
|
|
972
1204
|
const size$1 = require_backend.prod(newShape);
|
|
973
1205
|
const reductionSize = require_backend.prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
974
1206
|
newShape.push(reductionSize);
|
|
975
|
-
|
|
976
|
-
|
|
977
|
-
const [gid, st] = exp$2.arg;
|
|
978
|
-
const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
|
|
979
|
-
const indices = require_backend.unravelAlu(newShape.slice(0, -1), require_backend.AluVar.gidx);
|
|
980
|
-
indices.push(require_backend.AluVar.ridx);
|
|
981
|
-
return require_backend.AluExp.globalView(exp$2.dtype, gid, newSt, indices);
|
|
982
|
-
}
|
|
983
|
-
});
|
|
1207
|
+
const perm = keptAxes.concat(shiftedAxes);
|
|
1208
|
+
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
984
1209
|
const reduction = new require_backend.Reduction(a.dtype, op, reductionSize);
|
|
985
1210
|
return new require_backend.Kernel(nargs, size$1, a, reduction);
|
|
986
1211
|
},
|
|
1212
|
+
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1213
|
+
[Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
|
|
1214
|
+
let stX = poolTranspose(require_backend.ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1215
|
+
const size$1 = require_backend.prod(inShape);
|
|
1216
|
+
stX = stX.reshape([...inShape, require_backend.prod(stX.shape.slice(inShape.length))]);
|
|
1217
|
+
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1218
|
+
const reduction = new require_backend.Reduction(a.dtype, require_backend.AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1219
|
+
return new require_backend.Kernel(nargs, size$1, a, reduction);
|
|
1220
|
+
},
|
|
987
1221
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
988
1222
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
989
1223
|
const c = k1.exp;
|
|
@@ -993,6 +1227,14 @@ const jitRules = {
|
|
|
993
1227
|
axis: [cs.ndim - 1]
|
|
994
1228
|
});
|
|
995
1229
|
},
|
|
1230
|
+
[Primitive.Conv](nargs, [a, b], [as, bs], params) {
|
|
1231
|
+
const [stX, stY] = prepareConv(require_backend.ShapeTracker.fromShape(as.shape), require_backend.ShapeTracker.fromShape(bs.shape), params);
|
|
1232
|
+
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1233
|
+
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1234
|
+
as = new ShapedArray(stX.shape, as.dtype);
|
|
1235
|
+
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1236
|
+
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1237
|
+
},
|
|
996
1238
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
997
1239
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => require_backend.AluExp.where(cond, a, b)),
|
|
998
1240
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
@@ -1005,8 +1247,20 @@ const jitRules = {
|
|
|
1005
1247
|
}),
|
|
1006
1248
|
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
1007
1249
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
1008
|
-
[Primitive.Gather]() {
|
|
1009
|
-
|
|
1250
|
+
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1251
|
+
const axisSet = new Set(axis);
|
|
1252
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1253
|
+
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1254
|
+
finalShape.splice(outDim, 0, ...indexShape);
|
|
1255
|
+
const idxAll = require_backend.unravelAlu(finalShape, require_backend.AluVar.gidx);
|
|
1256
|
+
const idxNonaxis = [...idxAll];
|
|
1257
|
+
idxNonaxis.splice(outDim, indexShape.length);
|
|
1258
|
+
const src = [...idxNonaxis];
|
|
1259
|
+
for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1260
|
+
for (const [i, iexp] of indices.entries()) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...require_backend.range(outDim + indexShape.length - st.shape.length), ...require_backend.range(outDim + indexShape.length, finalShape.length)])));
|
|
1261
|
+
const [index, valid] = require_backend.ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1262
|
+
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1263
|
+
return new require_backend.Kernel(nargs, require_backend.prod(finalShape), x.substitute({ gidx: index }));
|
|
1010
1264
|
},
|
|
1011
1265
|
[Primitive.JitCall]() {
|
|
1012
1266
|
throw new Error("internal: JitCall should have been flattened before JIT compilation");
|
|
@@ -1025,9 +1279,15 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
1025
1279
|
blackNodes.add(v);
|
|
1026
1280
|
p1NextBlack.set(v, v);
|
|
1027
1281
|
}
|
|
1282
|
+
const reducePrimitives = [
|
|
1283
|
+
Primitive.Reduce,
|
|
1284
|
+
Primitive.Dot,
|
|
1285
|
+
Primitive.Conv,
|
|
1286
|
+
Primitive.PoolTranspose
|
|
1287
|
+
];
|
|
1028
1288
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
1029
1289
|
const eqn = jaxpr.eqns[i];
|
|
1030
|
-
if (eqn.primitive === Primitive.
|
|
1290
|
+
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1031
1291
|
for (const v of eqn.outBinders) {
|
|
1032
1292
|
blackNodes.add(v);
|
|
1033
1293
|
p1NextBlack.set(v, v);
|
|
@@ -1254,7 +1514,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1254
1514
|
const inputs = [];
|
|
1255
1515
|
const src = [...idxNonaxis];
|
|
1256
1516
|
for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1257
|
-
for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#
|
|
1517
|
+
for (const [i, ar] of indices.entries()) if (ar.#source instanceof require_backend.AluExp) src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.accessorAluExp(ar.#source, ar.#st, idxAxis));
|
|
1258
1518
|
else {
|
|
1259
1519
|
let gid = inputs.indexOf(ar.#source);
|
|
1260
1520
|
if (gid === -1) {
|
|
@@ -1264,7 +1524,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1264
1524
|
src[axis[i]] = require_backend.AluExp.cast(require_backend.DType.Int32, require_backend.AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
|
|
1265
1525
|
}
|
|
1266
1526
|
let exp$2;
|
|
1267
|
-
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#
|
|
1527
|
+
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, src);
|
|
1268
1528
|
else {
|
|
1269
1529
|
let gid = inputs.indexOf(this.#source);
|
|
1270
1530
|
if (gid === -1) {
|
|
@@ -1307,7 +1567,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1307
1567
|
this.#check();
|
|
1308
1568
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1309
1569
|
const exp$3 = new require_backend.AluExp(op, dtypeOutput, [this.#source]);
|
|
1310
|
-
return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
|
|
1570
|
+
return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
|
|
1311
1571
|
}
|
|
1312
1572
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1313
1573
|
const exp$2 = new require_backend.AluExp(op, dtypeOutput, [require_backend.AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1340,18 +1600,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1340
1600
|
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1341
1601
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1342
1602
|
const newShape = [...arrays[0].shape];
|
|
1343
|
-
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && reduceAxis
|
|
1603
|
+
if (arrays.every((ar) => ar.#source instanceof require_backend.AluExp) && !reduceAxis) {
|
|
1344
1604
|
if (arrays.every((ar) => require_backend.deepEqual(ar.#st, arrays[0].#st))) {
|
|
1345
1605
|
const exp$4 = custom(arrays.map((ar) => ar.#source));
|
|
1346
|
-
return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
|
|
1606
|
+
return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
|
|
1347
1607
|
}
|
|
1348
1608
|
const exp$3 = custom(arrays.map((ar) => {
|
|
1349
1609
|
const src$1 = ar.#source;
|
|
1350
1610
|
if (ar.#st.contiguous) return src$1;
|
|
1351
|
-
return require_backend.accessorAluExp(
|
|
1611
|
+
return require_backend.accessorAluExp(src$1, ar.#st, require_backend.unravelAlu(newShape, require_backend.AluVar.idx));
|
|
1352
1612
|
}));
|
|
1353
1613
|
const st = require_backend.ShapeTracker.fromShape(newShape);
|
|
1354
|
-
return new Array$1(exp$3, st, exp$3.dtype, backend);
|
|
1614
|
+
return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
|
|
1355
1615
|
}
|
|
1356
1616
|
let indices;
|
|
1357
1617
|
if (!reduceAxis) indices = require_backend.unravelAlu(newShape, require_backend.AluVar.gidx);
|
|
@@ -1361,7 +1621,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1361
1621
|
}
|
|
1362
1622
|
const inputs = [];
|
|
1363
1623
|
const src = [];
|
|
1364
|
-
for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#
|
|
1624
|
+
for (const ar of arrays) if (ar.#source instanceof require_backend.AluExp) src.push(require_backend.accessorAluExp(ar.#source, ar.#st, indices));
|
|
1365
1625
|
else {
|
|
1366
1626
|
let gid = inputs.indexOf(ar.#source);
|
|
1367
1627
|
if (gid === -1) {
|
|
@@ -1395,7 +1655,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1395
1655
|
const indices = [...require_backend.unravelAlu(newShape, require_backend.AluVar.gidx), require_backend.AluVar.ridx];
|
|
1396
1656
|
let exp$2;
|
|
1397
1657
|
const inputs = [];
|
|
1398
|
-
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#
|
|
1658
|
+
if (this.#source instanceof require_backend.AluExp) exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
|
|
1399
1659
|
else {
|
|
1400
1660
|
inputs.push(this.#source);
|
|
1401
1661
|
exp$2 = require_backend.accessorGlobal(this.#dtype, 0, this.#st, indices);
|
|
@@ -1420,7 +1680,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1420
1680
|
this.#check();
|
|
1421
1681
|
const indices = require_backend.unravelAlu(this.#st.shape, require_backend.AluVar.gidx);
|
|
1422
1682
|
if (this.#source instanceof require_backend.AluExp) {
|
|
1423
|
-
const exp$2 = require_backend.accessorAluExp(this.#
|
|
1683
|
+
const exp$2 = require_backend.accessorAluExp(this.#source, this.#st, indices);
|
|
1424
1684
|
const kernel = new require_backend.Kernel(0, this.#st.size, exp$2);
|
|
1425
1685
|
const output = this.#backend.malloc(kernel.bytes);
|
|
1426
1686
|
const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
|
|
@@ -1458,42 +1718,51 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1458
1718
|
}
|
|
1459
1719
|
/** Realize the array and return it as data. */
|
|
1460
1720
|
async data() {
|
|
1461
|
-
if (this.#source instanceof require_backend.AluExp &&
|
|
1721
|
+
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1462
1722
|
this.#realize();
|
|
1463
1723
|
const pending = this.#pending;
|
|
1464
1724
|
if (pending) {
|
|
1465
1725
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1466
1726
|
for (const p of pending) p.submit();
|
|
1467
1727
|
}
|
|
1468
|
-
const byteCount = require_backend.byteWidth(this.#dtype) *
|
|
1728
|
+
const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
|
|
1469
1729
|
const buf = await this.#backend.read(this.#source, 0, byteCount);
|
|
1470
1730
|
this.dispose();
|
|
1471
1731
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
1472
1732
|
}
|
|
1473
|
-
/**
|
|
1733
|
+
/**
|
|
1734
|
+
* Wait for this array to finish evaluation.
|
|
1735
|
+
*
|
|
1736
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1737
|
+
* that pending operations are dispatched and fully executed before it
|
|
1738
|
+
* returns.
|
|
1739
|
+
*
|
|
1740
|
+
* If you are mapping from `data()` or `dataSync()`, it will also trigger
|
|
1741
|
+
* dispatch of operations as well.
|
|
1742
|
+
*/
|
|
1474
1743
|
async wait() {
|
|
1475
1744
|
this.#check();
|
|
1476
|
-
if (this.#source instanceof require_backend.AluExp) return;
|
|
1745
|
+
if (this.#source instanceof require_backend.AluExp) return this;
|
|
1477
1746
|
const pending = this.#pending;
|
|
1478
1747
|
if (pending) {
|
|
1479
1748
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1480
1749
|
for (const p of pending) p.submit();
|
|
1481
1750
|
}
|
|
1482
1751
|
await this.#backend.read(this.#source, 0, 0);
|
|
1483
|
-
this
|
|
1752
|
+
return this;
|
|
1484
1753
|
}
|
|
1485
1754
|
/**
|
|
1486
1755
|
* Realize the array and return it as data. This is a sync variant and not
|
|
1487
1756
|
* recommended for performance reasons, as it will block rendering.
|
|
1488
1757
|
*/
|
|
1489
1758
|
dataSync() {
|
|
1490
|
-
if (this.#source instanceof require_backend.AluExp &&
|
|
1759
|
+
if (this.#source instanceof require_backend.AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1491
1760
|
this.#realize();
|
|
1492
1761
|
for (const p of this.#pending) {
|
|
1493
1762
|
p.prepareSync();
|
|
1494
1763
|
p.submit();
|
|
1495
1764
|
}
|
|
1496
|
-
const byteCount = require_backend.byteWidth(this.#dtype) *
|
|
1765
|
+
const byteCount = require_backend.byteWidth(this.#dtype) * this.size;
|
|
1497
1766
|
const buf = this.#backend.readSync(this.#source, 0, byteCount);
|
|
1498
1767
|
this.dispose();
|
|
1499
1768
|
return require_backend.dtypedArray(this.dtype, buf);
|
|
@@ -1514,6 +1783,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1514
1783
|
async jsAsync() {
|
|
1515
1784
|
return dataToJs(this.dtype, await this.data(), this.shape);
|
|
1516
1785
|
}
|
|
1786
|
+
/**
|
|
1787
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
1788
|
+
*
|
|
1789
|
+
* Throws an error if the array does not have a single element. The array must
|
|
1790
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
1791
|
+
*/
|
|
1792
|
+
item() {
|
|
1793
|
+
if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
|
|
1794
|
+
return this.dataSync()[0];
|
|
1795
|
+
}
|
|
1517
1796
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1518
1797
|
static _implRules() {
|
|
1519
1798
|
return {
|
|
@@ -1527,7 +1806,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1527
1806
|
return [x.#binary(require_backend.AluOp.Idiv, y)];
|
|
1528
1807
|
},
|
|
1529
1808
|
[Primitive.Neg]([x]) {
|
|
1530
|
-
return [zerosLike(x).#binary(require_backend.AluOp.Sub, x)];
|
|
1809
|
+
return [zerosLike(x.ref).#binary(require_backend.AluOp.Sub, x)];
|
|
1531
1810
|
},
|
|
1532
1811
|
[Primitive.Reciprocal]([x]) {
|
|
1533
1812
|
return [x.#unary(require_backend.AluOp.Reciprocal)];
|
|
@@ -1583,6 +1862,9 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1583
1862
|
[Primitive.Log]([x]) {
|
|
1584
1863
|
return [x.#unary(require_backend.AluOp.Log)];
|
|
1585
1864
|
},
|
|
1865
|
+
[Primitive.Sqrt]([x]) {
|
|
1866
|
+
return [x.#unary(require_backend.AluOp.Sqrt)];
|
|
1867
|
+
},
|
|
1586
1868
|
[Primitive.Min]([x, y]) {
|
|
1587
1869
|
return [x.#binary(require_backend.AluOp.Min, y)];
|
|
1588
1870
|
},
|
|
@@ -1593,9 +1875,24 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1593
1875
|
if (axis.length === 0) return [x];
|
|
1594
1876
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
1595
1877
|
},
|
|
1878
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
1879
|
+
const st = pool(x.#st, window, strides);
|
|
1880
|
+
return [x.#reshape(st)];
|
|
1881
|
+
},
|
|
1882
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
1883
|
+
const n = inShape.length;
|
|
1884
|
+
let st = poolTranspose(x.#st, inShape, window, strides);
|
|
1885
|
+
st = st.reshape([...st.shape.slice(0, n), require_backend.prod(st.shape.slice(n))]);
|
|
1886
|
+
return [x.#reshape(st).#reduce(require_backend.AluOp.Add)];
|
|
1887
|
+
},
|
|
1596
1888
|
[Primitive.Dot]([x, y]) {
|
|
1597
1889
|
return [Array$1.#naryCustom("dot", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
|
|
1598
1890
|
},
|
|
1891
|
+
[Primitive.Conv]([x, y], params) {
|
|
1892
|
+
checkConvShape(x.shape, y.shape, params);
|
|
1893
|
+
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
1894
|
+
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => require_backend.AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
1895
|
+
},
|
|
1599
1896
|
[Primitive.Compare]([x, y], { op }) {
|
|
1600
1897
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1601
1898
|
return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: require_backend.DType.Bool })];
|
|
@@ -1660,6 +1957,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1660
1957
|
dtype ??= require_backend.DType.Float32;
|
|
1661
1958
|
if (![
|
|
1662
1959
|
require_backend.DType.Float32,
|
|
1960
|
+
require_backend.DType.Float16,
|
|
1663
1961
|
require_backend.DType.Int32,
|
|
1664
1962
|
require_backend.DType.Uint32
|
|
1665
1963
|
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
@@ -1667,6 +1965,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1667
1965
|
dtype ??= require_backend.DType.Bool;
|
|
1668
1966
|
if (![
|
|
1669
1967
|
require_backend.DType.Float32,
|
|
1968
|
+
require_backend.DType.Float16,
|
|
1670
1969
|
require_backend.DType.Int32,
|
|
1671
1970
|
require_backend.DType.Uint32,
|
|
1672
1971
|
require_backend.DType.Bool
|
|
@@ -1680,7 +1979,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1680
1979
|
if (shape$1 && !require_backend.deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1681
1980
|
if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
|
|
1682
1981
|
return values;
|
|
1683
|
-
} else if (values
|
|
1982
|
+
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1684
1983
|
dtype,
|
|
1685
1984
|
device
|
|
1686
1985
|
});
|
|
@@ -1709,7 +2008,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1709
2008
|
});
|
|
1710
2009
|
} else {
|
|
1711
2010
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
1712
|
-
const data = require_backend.
|
|
2011
|
+
const data = require_backend.dtypedJsArray(dtype, flat);
|
|
1713
2012
|
return arrayFromData(data, shape$1, {
|
|
1714
2013
|
dtype,
|
|
1715
2014
|
device
|
|
@@ -1730,19 +2029,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
|
|
|
1730
2029
|
});
|
|
1731
2030
|
}
|
|
1732
2031
|
const backend = require_backend.getBackend(device);
|
|
1733
|
-
if (data
|
|
1734
|
-
|
|
1735
|
-
|
|
1736
|
-
|
|
1737
|
-
|
|
1738
|
-
if (
|
|
1739
|
-
|
|
1740
|
-
|
|
1741
|
-
|
|
1742
|
-
|
|
1743
|
-
|
|
1744
|
-
|
|
1745
|
-
|
|
2032
|
+
if (ArrayBuffer.isView(data)) {
|
|
2033
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2034
|
+
if (data instanceof Float32Array) {
|
|
2035
|
+
if (dtype && dtype !== require_backend.DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2036
|
+
dtype ??= require_backend.DType.Float32;
|
|
2037
|
+
} else if (data instanceof Int32Array) {
|
|
2038
|
+
if (dtype && dtype !== require_backend.DType.Int32 && dtype !== require_backend.DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2039
|
+
dtype ??= require_backend.DType.Int32;
|
|
2040
|
+
} else if (data instanceof Uint32Array) {
|
|
2041
|
+
if (dtype && dtype !== require_backend.DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2042
|
+
dtype ??= require_backend.DType.Uint32;
|
|
2043
|
+
} else if (data instanceof Float16Array) {
|
|
2044
|
+
if (dtype && dtype !== require_backend.DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2045
|
+
dtype ??= require_backend.DType.Float16;
|
|
2046
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2047
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2048
|
+
return new Array$1(slot, require_backend.ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2049
|
+
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
1746
2050
|
}
|
|
1747
2051
|
function dataToJs(dtype, data, shape$1) {
|
|
1748
2052
|
if (shape$1.length === 0) return dtype === require_backend.DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -1769,9 +2073,20 @@ var EvalTrace = class extends Trace {
|
|
|
1769
2073
|
};
|
|
1770
2074
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
1771
2075
|
const implRules = Array$1._implRules();
|
|
1772
|
-
function zerosLike(val) {
|
|
2076
|
+
function zerosLike(val, dtype) {
|
|
1773
2077
|
const aval = getAval(val);
|
|
1774
|
-
|
|
2078
|
+
if (val instanceof Tracer) val.dispose();
|
|
2079
|
+
return zeros(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2080
|
+
}
|
|
2081
|
+
function onesLike(val, dtype) {
|
|
2082
|
+
const aval = getAval(val);
|
|
2083
|
+
if (val instanceof Tracer) val.dispose();
|
|
2084
|
+
return ones(aval.shape, { dtype: dtype ?? aval.dtype });
|
|
2085
|
+
}
|
|
2086
|
+
function fullLike(val, fillValue, dtype) {
|
|
2087
|
+
const aval = getAval(val);
|
|
2088
|
+
if (val instanceof Tracer) val.dispose();
|
|
2089
|
+
return full(aval.shape, fillValue, { dtype: dtype ?? aval.dtype });
|
|
1775
2090
|
}
|
|
1776
2091
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1777
2092
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -1793,6 +2108,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
1793
2108
|
if (typeof fillValue === "number") {
|
|
1794
2109
|
dtype = dtype ?? require_backend.DType.Float32;
|
|
1795
2110
|
source = require_backend.AluExp.const(dtype, fillValue);
|
|
2111
|
+
} else if (typeof fillValue === "bigint") {
|
|
2112
|
+
dtype = dtype ?? require_backend.DType.Int32;
|
|
2113
|
+
source = require_backend.AluExp.const(dtype, Number(fillValue));
|
|
1796
2114
|
} else if (typeof fillValue === "boolean") {
|
|
1797
2115
|
dtype = dtype ?? require_backend.DType.Bool;
|
|
1798
2116
|
source = require_backend.AluExp.const(dtype, fillValue ? 1 : 0);
|
|
@@ -1890,7 +2208,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
1890
2208
|
const st = require_backend.ShapeTracker.fromShape([num]);
|
|
1891
2209
|
return new Array$1(exp$2, st, dtype, require_backend.getBackend(device));
|
|
1892
2210
|
}
|
|
1893
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
1894
2211
|
function aluCompare(a, b, op) {
|
|
1895
2212
|
switch (op) {
|
|
1896
2213
|
case CompareOp.Greater: return require_backend.AluExp.mul(require_backend.AluExp.cmpne(a, b), require_backend.AluExp.cmplt(a, b).not());
|
|
@@ -1932,8 +2249,8 @@ function generalBroadcast(a, b) {
|
|
|
1932
2249
|
}
|
|
1933
2250
|
|
|
1934
2251
|
//#endregion
|
|
1935
|
-
//#region node_modules/.pnpm/@oxc-project+runtime@0.
|
|
1936
|
-
var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.
|
|
2252
|
+
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js
|
|
2253
|
+
var require_usingCtx = /* @__PURE__ */ __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/usingCtx.js": ((exports, module) => {
|
|
1937
2254
|
function _usingCtx() {
|
|
1938
2255
|
var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
|
|
1939
2256
|
var n$1 = Error();
|
|
@@ -1989,11 +2306,11 @@ var require_usingCtx = __commonJS({ "node_modules/.pnpm/@oxc-project+runtime@0.7
|
|
|
1989
2306
|
};
|
|
1990
2307
|
}
|
|
1991
2308
|
module.exports = _usingCtx, module.exports.__esModule = true, module.exports["default"] = module.exports;
|
|
1992
|
-
} });
|
|
2309
|
+
}) });
|
|
1993
2310
|
|
|
1994
2311
|
//#endregion
|
|
1995
2312
|
//#region src/frontend/jaxpr.ts
|
|
1996
|
-
var import_usingCtx$2 = __toESM(require_usingCtx(), 1);
|
|
2313
|
+
var import_usingCtx$2 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
1997
2314
|
/** Variable in a Jaxpr expression. */
|
|
1998
2315
|
var Var = class Var {
|
|
1999
2316
|
static #nextId = 1;
|
|
@@ -2004,7 +2321,7 @@ var Var = class Var {
|
|
|
2004
2321
|
this.aval = aval;
|
|
2005
2322
|
}
|
|
2006
2323
|
toString() {
|
|
2007
|
-
return `Var(${this.id}):${this.aval.
|
|
2324
|
+
return `Var(${this.id}):${this.aval.toString()}`;
|
|
2008
2325
|
}
|
|
2009
2326
|
};
|
|
2010
2327
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
@@ -2044,7 +2361,7 @@ var VarPrinter = class {
|
|
|
2044
2361
|
return name;
|
|
2045
2362
|
}
|
|
2046
2363
|
nameType(v) {
|
|
2047
|
-
return `${this.name(v)}:${v.aval.
|
|
2364
|
+
return `${this.name(v)}:${v.aval.toString()}`;
|
|
2048
2365
|
}
|
|
2049
2366
|
};
|
|
2050
2367
|
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
@@ -2199,8 +2516,8 @@ var JaxprType = class {
|
|
|
2199
2516
|
this.outTypes = outTypes;
|
|
2200
2517
|
}
|
|
2201
2518
|
toString() {
|
|
2202
|
-
const inTypes = this.inTypes.map((aval) => aval.
|
|
2203
|
-
const outTypes = this.outTypes.map((aval) => aval.
|
|
2519
|
+
const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
|
|
2520
|
+
const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
|
|
2204
2521
|
return `(${inTypes}) -> (${outTypes})`;
|
|
2205
2522
|
}
|
|
2206
2523
|
};
|
|
@@ -2279,7 +2596,7 @@ var JaxprTracer = class extends Tracer {
|
|
|
2279
2596
|
this.aval = aval;
|
|
2280
2597
|
}
|
|
2281
2598
|
toString() {
|
|
2282
|
-
return `JaxprTracer(${this.aval.
|
|
2599
|
+
return `JaxprTracer(${this.aval.toString()})`;
|
|
2283
2600
|
}
|
|
2284
2601
|
get ref() {
|
|
2285
2602
|
return this;
|
|
@@ -2418,6 +2735,7 @@ const abstractEvalRules = {
|
|
|
2418
2735
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2419
2736
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2420
2737
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2738
|
+
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2421
2739
|
[Primitive.Min]: binopAbstractEval,
|
|
2422
2740
|
[Primitive.Max]: binopAbstractEval,
|
|
2423
2741
|
[Primitive.Reduce]([x], { axis }) {
|
|
@@ -2425,6 +2743,15 @@ const abstractEvalRules = {
|
|
|
2425
2743
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2426
2744
|
return [new ShapedArray(newShape, x.dtype)];
|
|
2427
2745
|
},
|
|
2746
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
2747
|
+
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2748
|
+
return [new ShapedArray(shape$1, x.dtype)];
|
|
2749
|
+
},
|
|
2750
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2751
|
+
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2752
|
+
if (!require_backend.deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2753
|
+
return [new ShapedArray(inShape, x.dtype)];
|
|
2754
|
+
},
|
|
2428
2755
|
[Primitive.Dot]([x, y]) {
|
|
2429
2756
|
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2430
2757
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
@@ -2432,6 +2759,11 @@ const abstractEvalRules = {
|
|
|
2432
2759
|
shape$1.splice(-1, 1);
|
|
2433
2760
|
return [new ShapedArray(shape$1, x.dtype)];
|
|
2434
2761
|
},
|
|
2762
|
+
[Primitive.Conv]([lhs, rhs], params) {
|
|
2763
|
+
if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
|
|
2764
|
+
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2765
|
+
return [new ShapedArray(shape$1, lhs.dtype)];
|
|
2766
|
+
},
|
|
2435
2767
|
[Primitive.Compare]: compareAbstractEval,
|
|
2436
2768
|
[Primitive.Where]([cond, x, y]) {
|
|
2437
2769
|
if (cond.dtype !== require_backend.DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
@@ -2479,15 +2811,34 @@ const abstractEvalRules = {
|
|
|
2479
2811
|
return outTypes;
|
|
2480
2812
|
}
|
|
2481
2813
|
};
|
|
2482
|
-
function
|
|
2814
|
+
function splitIdx(values, argnums) {
|
|
2815
|
+
const a = [];
|
|
2816
|
+
const b = [];
|
|
2817
|
+
for (let i = 0; i < values.length; i++) if (argnums.has(i)) a.push(values[i]);
|
|
2818
|
+
else b.push(values[i]);
|
|
2819
|
+
return [a, b];
|
|
2820
|
+
}
|
|
2821
|
+
function joinIdx(n, a, b, argnums) {
|
|
2822
|
+
const result = [];
|
|
2823
|
+
let ai = 0;
|
|
2824
|
+
let bi = 0;
|
|
2825
|
+
for (let i = 0; i < n; i++) if (argnums.has(i)) result.push(a[ai++]);
|
|
2826
|
+
else result.push(b[bi++]);
|
|
2827
|
+
return result;
|
|
2828
|
+
}
|
|
2829
|
+
function makeJaxpr$1(f, opts) {
|
|
2483
2830
|
return (...argsIn) => {
|
|
2484
2831
|
try {
|
|
2485
2832
|
var _usingCtx$1 = (0, import_usingCtx$2.default)();
|
|
2486
|
-
const
|
|
2487
|
-
const [
|
|
2833
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2834
|
+
const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
|
|
2835
|
+
const [avalsIn, inTree] = flatten(shapedArgs);
|
|
2836
|
+
const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
|
|
2837
|
+
return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
|
|
2838
|
+
}, inTree);
|
|
2488
2839
|
const builder = new JaxprBuilder();
|
|
2489
2840
|
const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
|
|
2490
|
-
|
|
2841
|
+
_usingCtx$1.u(newDynamic(main));
|
|
2491
2842
|
const trace = new JaxprTrace(main);
|
|
2492
2843
|
const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
2493
2844
|
const outs = fFlat(...tracersIn);
|
|
@@ -2506,14 +2857,17 @@ function makeJaxpr$1(f) {
|
|
|
2506
2857
|
}
|
|
2507
2858
|
};
|
|
2508
2859
|
}
|
|
2509
|
-
function jit$1(f) {
|
|
2860
|
+
function jit$1(f, opts) {
|
|
2510
2861
|
const cache = /* @__PURE__ */ new Map();
|
|
2862
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2511
2863
|
return ((...args) => {
|
|
2512
|
-
const [
|
|
2864
|
+
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2865
|
+
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2513
2866
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
2514
2867
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
2515
|
-
const
|
|
2516
|
-
const
|
|
2868
|
+
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
2869
|
+
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2870
|
+
const { jaxpr, consts, treedef: outTree } = require_backend.runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2517
2871
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2518
2872
|
jaxpr,
|
|
2519
2873
|
numConsts: consts.length
|
|
@@ -2524,7 +2878,7 @@ function jit$1(f) {
|
|
|
2524
2878
|
|
|
2525
2879
|
//#endregion
|
|
2526
2880
|
//#region src/frontend/jvp.ts
|
|
2527
|
-
var import_usingCtx$1 = __toESM(require_usingCtx(), 1);
|
|
2881
|
+
var import_usingCtx$1 = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
2528
2882
|
var JVPTracer = class extends Tracer {
|
|
2529
2883
|
constructor(trace, primal, tangent) {
|
|
2530
2884
|
super(trace);
|
|
@@ -2551,7 +2905,7 @@ var JVPTrace = class extends Trace {
|
|
|
2551
2905
|
return this.lift(pureArray(val));
|
|
2552
2906
|
}
|
|
2553
2907
|
lift(val) {
|
|
2554
|
-
return new JVPTracer(this, val, zerosLike(val));
|
|
2908
|
+
return new JVPTracer(this, val, zerosLike(val.ref));
|
|
2555
2909
|
}
|
|
2556
2910
|
processPrimitive(primitive, tracers, params) {
|
|
2557
2911
|
const [primalsIn, tangentsIn] = require_backend.unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2569,19 +2923,25 @@ function linearTangentsJvp(primitive) {
|
|
|
2569
2923
|
return [ys, dys];
|
|
2570
2924
|
};
|
|
2571
2925
|
}
|
|
2926
|
+
/** Rule for product of gradients in bilinear operations. */
|
|
2927
|
+
function bilinearTangentsJvp(primitive) {
|
|
2928
|
+
return ([x, y], [dx, dy], params) => {
|
|
2929
|
+
const primal = bind1(primitive, [x.ref, y.ref], params);
|
|
2930
|
+
const tangent = bind1(primitive, [x, dy], params).add(bind1(primitive, [dx, y], params));
|
|
2931
|
+
return [[primal], [tangent]];
|
|
2932
|
+
};
|
|
2933
|
+
}
|
|
2572
2934
|
/** Rule that zeros out any tangents. */
|
|
2573
2935
|
function zeroTangentsJvp(primitive) {
|
|
2574
2936
|
return (primals, tangents, params) => {
|
|
2575
2937
|
for (const t of tangents) t.dispose();
|
|
2576
2938
|
const ys = bind(primitive, primals, params);
|
|
2577
|
-
return [ys, ys.map((y) => zerosLike(y))];
|
|
2939
|
+
return [ys, ys.map((y) => zerosLike(y.ref))];
|
|
2578
2940
|
};
|
|
2579
2941
|
}
|
|
2580
2942
|
const jvpRules = {
|
|
2581
2943
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
2582
|
-
[Primitive.Mul](
|
|
2583
|
-
return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
|
|
2584
|
-
},
|
|
2944
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
2585
2945
|
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
2586
2946
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
2587
2947
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
@@ -2594,13 +2954,13 @@ const jvpRules = {
|
|
|
2594
2954
|
if (require_backend.isFloatDtype(dtype) && require_backend.isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2595
2955
|
else {
|
|
2596
2956
|
dx.dispose();
|
|
2597
|
-
return [[cast(x, dtype)], [zerosLike(x)]];
|
|
2957
|
+
return [[cast(x.ref, dtype)], [zerosLike(x)]];
|
|
2598
2958
|
}
|
|
2599
2959
|
},
|
|
2600
2960
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2601
2961
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2602
2962
|
dx.dispose();
|
|
2603
|
-
return [[bitcast(x, dtype)], [zerosLike(x)]];
|
|
2963
|
+
return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
|
|
2604
2964
|
},
|
|
2605
2965
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2606
2966
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2616,6 +2976,10 @@ const jvpRules = {
|
|
|
2616
2976
|
[Primitive.Log]([x], [dx]) {
|
|
2617
2977
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
2618
2978
|
},
|
|
2979
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
2980
|
+
const z = sqrt$1(x);
|
|
2981
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
2982
|
+
},
|
|
2619
2983
|
[Primitive.Min]([x, y], [dx, dy]) {
|
|
2620
2984
|
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
2621
2985
|
},
|
|
@@ -2632,13 +2996,14 @@ const jvpRules = {
|
|
|
2632
2996
|
const primal = reduce(x.ref, op, axis);
|
|
2633
2997
|
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
2634
2998
|
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
2635
|
-
const tangent = where$1(notMin,
|
|
2999
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
2636
3000
|
return [[primal], [tangent]];
|
|
2637
3001
|
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
2638
3002
|
},
|
|
2639
|
-
[Primitive.
|
|
2640
|
-
|
|
2641
|
-
|
|
3003
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
3004
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
3005
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
3006
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
2642
3007
|
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
2643
3008
|
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
2644
3009
|
dcond.dispose();
|
|
@@ -2711,7 +3076,7 @@ function jvp$1(f, primals, tangents) {
|
|
|
2711
3076
|
|
|
2712
3077
|
//#endregion
|
|
2713
3078
|
//#region src/frontend/vmap.ts
|
|
2714
|
-
var import_usingCtx = __toESM(require_usingCtx(), 1);
|
|
3079
|
+
var import_usingCtx = /* @__PURE__ */ __toESM(require_usingCtx(), 1);
|
|
2715
3080
|
function mappedAval(batchDim, aval) {
|
|
2716
3081
|
const shape$1 = [...aval.shape];
|
|
2717
3082
|
shape$1.splice(batchDim, 1);
|
|
@@ -2815,6 +3180,7 @@ const vmapRules = {
|
|
|
2815
3180
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
2816
3181
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
2817
3182
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3183
|
+
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
2818
3184
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
2819
3185
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
2820
3186
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
@@ -2951,7 +3317,7 @@ var PartialVal = class PartialVal {
|
|
|
2951
3317
|
return this.val !== null;
|
|
2952
3318
|
}
|
|
2953
3319
|
toString() {
|
|
2954
|
-
return this.val ? this.val.toString() : this.aval.
|
|
3320
|
+
return this.val ? this.val.toString() : this.aval.toString();
|
|
2955
3321
|
}
|
|
2956
3322
|
};
|
|
2957
3323
|
function partialEvalFlat(f, pvalsIn) {
|
|
@@ -3325,12 +3691,72 @@ const transposeRules = {
|
|
|
3325
3691
|
if (op === require_backend.AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
|
|
3326
3692
|
else throw new NonlinearError(Primitive.Reduce);
|
|
3327
3693
|
},
|
|
3694
|
+
[Primitive.Pool]([ct], [x], { window, strides }) {
|
|
3695
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
|
|
3696
|
+
return bind(Primitive.PoolTranspose, [ct], {
|
|
3697
|
+
inShape: x.aval.shape,
|
|
3698
|
+
window,
|
|
3699
|
+
strides
|
|
3700
|
+
});
|
|
3701
|
+
},
|
|
3702
|
+
[Primitive.PoolTranspose]([ct], [x], { window, strides }) {
|
|
3703
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
|
|
3704
|
+
return bind(Primitive.Pool, [ct], {
|
|
3705
|
+
window,
|
|
3706
|
+
strides
|
|
3707
|
+
});
|
|
3708
|
+
},
|
|
3328
3709
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3329
3710
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3330
3711
|
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3331
3712
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3332
3713
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3333
3714
|
},
|
|
3715
|
+
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
3716
|
+
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
3717
|
+
const rev01 = [
|
|
3718
|
+
1,
|
|
3719
|
+
0,
|
|
3720
|
+
...require_backend.range(2, ct.ndim)
|
|
3721
|
+
];
|
|
3722
|
+
if (lhs instanceof UndefPrimal) {
|
|
3723
|
+
let kernel = rhs;
|
|
3724
|
+
kernel = transpose$1(kernel, rev01);
|
|
3725
|
+
kernel = flip$1(kernel, require_backend.range(2, kernel.ndim));
|
|
3726
|
+
const result = conv(ct, kernel, {
|
|
3727
|
+
strides: params.lhsDilation,
|
|
3728
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3729
|
+
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3730
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3731
|
+
const padBefore = dilatedKernel - 1 - pl;
|
|
3732
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3733
|
+
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
3734
|
+
return [padBefore, padAfter];
|
|
3735
|
+
}),
|
|
3736
|
+
lhsDilation: params.strides,
|
|
3737
|
+
rhsDilation: params.rhsDilation
|
|
3738
|
+
});
|
|
3739
|
+
return [result, null];
|
|
3740
|
+
} else {
|
|
3741
|
+
const newLhs = transpose$1(lhs, rev01);
|
|
3742
|
+
const newRhs = transpose$1(ct, rev01);
|
|
3743
|
+
let result = conv(newLhs, newRhs, {
|
|
3744
|
+
strides: params.rhsDilation,
|
|
3745
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3746
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3747
|
+
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3748
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3749
|
+
const padFromLhs = dilatedCt - dilatedLhs;
|
|
3750
|
+
const padFromRhs = dilatedKernel - pl - 1;
|
|
3751
|
+
return [pl, padFromLhs + padFromRhs];
|
|
3752
|
+
}),
|
|
3753
|
+
lhsDilation: params.lhsDilation,
|
|
3754
|
+
rhsDilation: params.strides
|
|
3755
|
+
});
|
|
3756
|
+
result = transpose$1(result, rev01);
|
|
3757
|
+
return [null, result];
|
|
3758
|
+
}
|
|
3759
|
+
},
|
|
3334
3760
|
[Primitive.Where]([ct], [cond, x, y]) {
|
|
3335
3761
|
const cts = [
|
|
3336
3762
|
null,
|
|
@@ -3451,8 +3877,8 @@ function valueAndGrad$1(f) {
|
|
|
3451
3877
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
3452
3878
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3453
3879
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3454
|
-
if (y.dtype
|
|
3455
|
-
const [ct, ...rest] = fVjp(
|
|
3880
|
+
if (!require_backend.isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3881
|
+
const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
|
|
3456
3882
|
for (const r of rest) r.dispose();
|
|
3457
3883
|
return [y, ct];
|
|
3458
3884
|
};
|
|
@@ -3466,6 +3892,73 @@ function jacrev$1(f) {
|
|
|
3466
3892
|
};
|
|
3467
3893
|
}
|
|
3468
3894
|
|
|
3895
|
+
//#endregion
|
|
3896
|
+
//#region src/lax.ts
|
|
3897
|
+
var lax_exports = {};
|
|
3898
|
+
__export(lax_exports, {
|
|
3899
|
+
conv: () => conv$1,
|
|
3900
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
3901
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
3902
|
+
reduceWindow: () => reduceWindow
|
|
3903
|
+
});
|
|
3904
|
+
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
|
|
3905
|
+
const padType = padding.toUpperCase();
|
|
3906
|
+
switch (padType) {
|
|
3907
|
+
case "VALID": return require_backend.rep(inShape.length, [0, 0]);
|
|
3908
|
+
case "SAME":
|
|
3909
|
+
case "SAME_LOWER": {
|
|
3910
|
+
const outShape = inShape.map((size$1, i) => Math.ceil(size$1 / strides[i]));
|
|
3911
|
+
const padSizes = require_backend.zipn(outShape, strides, filterShape, dilation, inShape).map(([o, s, k, d, i]) => Math.max(0, (o - 1) * s + 1 + (k - 1) * d - i));
|
|
3912
|
+
if (padType === "SAME") return padSizes.map((size$1) => [size$1 >> 1, size$1 - (size$1 >> 1)]);
|
|
3913
|
+
else return padSizes.map((size$1) => [size$1 - (size$1 >> 1), size$1 >> 1]);
|
|
3914
|
+
}
|
|
3915
|
+
default: throw new Error(`Unknown padding type: ${padType}`);
|
|
3916
|
+
}
|
|
3917
|
+
}
|
|
3918
|
+
/**
|
|
3919
|
+
* General n-dimensional convolution operator, with optional dilation.
|
|
3920
|
+
*
|
|
3921
|
+
* The semantics of this operation mimic the `jax.lax.conv_general_dilated`
|
|
3922
|
+
* function in JAX, which wraps XLA's general convolution operator.
|
|
3923
|
+
*
|
|
3924
|
+
* Grouped convolutions are not supported right now.
|
|
3925
|
+
*/
|
|
3926
|
+
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
|
|
3927
|
+
if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
|
|
3928
|
+
if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
|
|
3929
|
+
if (typeof padding === "string") {
|
|
3930
|
+
if (lhsDilation?.some((d) => d !== 1)) throw new Error("String padding is not supported for transposed convolutions");
|
|
3931
|
+
padding = padtypeToPads(lhs.shape.slice(2), rhs.shape.slice(2), windowStrides, rhsDilation ?? require_backend.rep(rhs.ndim - 2, 1), padding);
|
|
3932
|
+
}
|
|
3933
|
+
return conv(lhs, rhs, {
|
|
3934
|
+
strides: windowStrides,
|
|
3935
|
+
padding,
|
|
3936
|
+
lhsDilation,
|
|
3937
|
+
rhsDilation
|
|
3938
|
+
});
|
|
3939
|
+
}
|
|
3940
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
3941
|
+
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
|
|
3942
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding, {
|
|
3943
|
+
lhsDilation,
|
|
3944
|
+
rhsDilation
|
|
3945
|
+
});
|
|
3946
|
+
}
|
|
3947
|
+
/** Convenience wrapper around `convGeneralDilated`. */
|
|
3948
|
+
function conv$1(lhs, rhs, windowStrides, padding) {
|
|
3949
|
+
return convGeneralDilated(lhs, rhs, windowStrides, padding);
|
|
3950
|
+
}
|
|
3951
|
+
/** Reduce a computation over padded windows. */
|
|
3952
|
+
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
|
|
3953
|
+
if (operand.ndim < windowDimensions.length) throw new Error(`Operand dimensions ${operand.ndim} < window ${windowDimensions.length}`);
|
|
3954
|
+
if (!windowStrides) windowStrides = require_backend.rep(windowDimensions.length, 1);
|
|
3955
|
+
for (let i = 0; i < operand.ndim; i++) computation = vmap$1(computation, 0);
|
|
3956
|
+
return computation(bind1(Primitive.Pool, [operand], {
|
|
3957
|
+
window: windowDimensions,
|
|
3958
|
+
strides: windowStrides
|
|
3959
|
+
}));
|
|
3960
|
+
}
|
|
3961
|
+
|
|
3469
3962
|
//#endregion
|
|
3470
3963
|
//#region src/numpy.ts
|
|
3471
3964
|
var numpy_exports = {};
|
|
@@ -3484,9 +3977,9 @@ __export(numpy_exports, {
|
|
|
3484
3977
|
bool: () => bool,
|
|
3485
3978
|
clip: () => clip,
|
|
3486
3979
|
columnStack: () => columnStack,
|
|
3487
|
-
complex64: () => complex64,
|
|
3488
3980
|
concatenate: () => concatenate,
|
|
3489
3981
|
cos: () => cos,
|
|
3982
|
+
cosh: () => cosh,
|
|
3490
3983
|
diag: () => diag,
|
|
3491
3984
|
diagonal: () => diagonal,
|
|
3492
3985
|
divide: () => divide,
|
|
@@ -3501,8 +3994,10 @@ __export(numpy_exports, {
|
|
|
3501
3994
|
flip: () => flip,
|
|
3502
3995
|
fliplr: () => fliplr,
|
|
3503
3996
|
flipud: () => flipud,
|
|
3997
|
+
float16: () => float16,
|
|
3504
3998
|
float32: () => float32,
|
|
3505
3999
|
full: () => full,
|
|
4000
|
+
fullLike: () => fullLike$1,
|
|
3506
4001
|
greater: () => greater,
|
|
3507
4002
|
greaterEqual: () => greaterEqual,
|
|
3508
4003
|
hstack: () => hstack,
|
|
@@ -3529,6 +4024,7 @@ __export(numpy_exports, {
|
|
|
3529
4024
|
negative: () => negative,
|
|
3530
4025
|
notEqual: () => notEqual,
|
|
3531
4026
|
ones: () => ones,
|
|
4027
|
+
onesLike: () => onesLike$1,
|
|
3532
4028
|
pad: () => pad,
|
|
3533
4029
|
permuteDims: () => permuteDims,
|
|
3534
4030
|
pi: () => pi,
|
|
@@ -3539,11 +4035,14 @@ __export(numpy_exports, {
|
|
|
3539
4035
|
scalar: () => scalar,
|
|
3540
4036
|
shape: () => shape,
|
|
3541
4037
|
sin: () => sin,
|
|
4038
|
+
sinh: () => sinh,
|
|
3542
4039
|
size: () => size,
|
|
4040
|
+
sqrt: () => sqrt,
|
|
3543
4041
|
square: () => square,
|
|
3544
4042
|
stack: () => stack,
|
|
3545
4043
|
sum: () => sum,
|
|
3546
4044
|
tan: () => tan,
|
|
4045
|
+
tanh: () => tanh,
|
|
3547
4046
|
transpose: () => transpose,
|
|
3548
4047
|
trueDivide: () => trueDivide,
|
|
3549
4048
|
trunc: () => trunc,
|
|
@@ -3552,13 +4051,14 @@ __export(numpy_exports, {
|
|
|
3552
4051
|
vecdot: () => vecdot,
|
|
3553
4052
|
vstack: () => vstack,
|
|
3554
4053
|
where: () => where,
|
|
3555
|
-
zeros: () => zeros
|
|
4054
|
+
zeros: () => zeros,
|
|
4055
|
+
zerosLike: () => zerosLike$1
|
|
3556
4056
|
});
|
|
3557
4057
|
const float32 = require_backend.DType.Float32;
|
|
3558
4058
|
const int32 = require_backend.DType.Int32;
|
|
3559
4059
|
const uint32 = require_backend.DType.Uint32;
|
|
3560
4060
|
const bool = require_backend.DType.Bool;
|
|
3561
|
-
const
|
|
4061
|
+
const float16 = require_backend.DType.Float16;
|
|
3562
4062
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
3563
4063
|
const e = Math.E;
|
|
3564
4064
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -3585,6 +4085,8 @@ const cos = cos$1;
|
|
|
3585
4085
|
const exp = exp$1;
|
|
3586
4086
|
/** Calculate the natural logarithm of all elements in the input array. */
|
|
3587
4087
|
const log = log$1;
|
|
4088
|
+
/** Calculate the square root of all elements in the input array. */
|
|
4089
|
+
const sqrt = sqrt$1;
|
|
3588
4090
|
/** Return element-wise minimum of the input arrays. */
|
|
3589
4091
|
const minimum = min$1;
|
|
3590
4092
|
/** Return element-wise maximum of the input arrays. */
|
|
@@ -3626,6 +4128,12 @@ const pad = pad$1;
|
|
|
3626
4128
|
const ndim = ndim$1;
|
|
3627
4129
|
/** Return the shape of an array. Does not consume array reference. */
|
|
3628
4130
|
const shape = getShape;
|
|
4131
|
+
/** Return an array of zeros with the same shape and type as a given array. */
|
|
4132
|
+
const zerosLike$1 = zerosLike;
|
|
4133
|
+
/** Return an array of ones with the same shape and type as a given array. */
|
|
4134
|
+
const onesLike$1 = onesLike;
|
|
4135
|
+
/** Return a full array with the same shape and type as a given array. */
|
|
4136
|
+
const fullLike$1 = fullLike;
|
|
3629
4137
|
/**
|
|
3630
4138
|
* Return the number of elements in an array, optionally along an axis.
|
|
3631
4139
|
* Does not consume array reference.
|
|
@@ -3676,13 +4184,7 @@ function argmin(a, axis, opts) {
|
|
|
3676
4184
|
dtype: int32,
|
|
3677
4185
|
device: a.device
|
|
3678
4186
|
});
|
|
3679
|
-
const idx =
|
|
3680
|
-
dtype: int32,
|
|
3681
|
-
device: a.device
|
|
3682
|
-
}), scalar(0, {
|
|
3683
|
-
dtype: int32,
|
|
3684
|
-
device: a.device
|
|
3685
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4187
|
+
const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3686
4188
|
dtype: int32,
|
|
3687
4189
|
device: a.device
|
|
3688
4190
|
}).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
|
|
@@ -3706,13 +4208,7 @@ function argmax(a, axis, opts) {
|
|
|
3706
4208
|
dtype: int32,
|
|
3707
4209
|
device: a.device
|
|
3708
4210
|
});
|
|
3709
|
-
const idx =
|
|
3710
|
-
dtype: int32,
|
|
3711
|
-
device: a.device
|
|
3712
|
-
}), scalar(0, {
|
|
3713
|
-
dtype: int32,
|
|
3714
|
-
device: a.device
|
|
3715
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4211
|
+
const idx = isMax.astype(require_backend.DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3716
4212
|
dtype: int32,
|
|
3717
4213
|
device: a.device
|
|
3718
4214
|
}).reshape([shape$1[axis], ...require_backend.rep(shape$1.length - axis - 1, 1)]));
|
|
@@ -3844,9 +4340,11 @@ function ravel(a) {
|
|
|
3844
4340
|
* Return specified diagonals.
|
|
3845
4341
|
*
|
|
3846
4342
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
3847
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
4343
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
3848
4344
|
*
|
|
3849
|
-
* This returns a view over the existing array.
|
|
4345
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
4346
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
4347
|
+
* then appending a new axis on the right that holds the diagonals.
|
|
3850
4348
|
*/
|
|
3851
4349
|
function diagonal(a, offset, axis1, axis2) {
|
|
3852
4350
|
return fudgeArray(a).diagonal(offset, axis1, axis2);
|
|
@@ -3862,15 +4360,16 @@ function diag(v, k = 0) {
|
|
|
3862
4360
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
3863
4361
|
if (a.ndim === 1) {
|
|
3864
4362
|
const n = a.shape[0];
|
|
3865
|
-
const ret = where(eye(n).equal(1), a,
|
|
3866
|
-
if (k
|
|
3867
|
-
return ret;
|
|
4363
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
|
|
4364
|
+
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4365
|
+
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4366
|
+
else return ret;
|
|
3868
4367
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
3869
4368
|
else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
|
|
3870
4369
|
}
|
|
3871
4370
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
3872
4371
|
function allclose(actual, expected, options) {
|
|
3873
|
-
const { rtol = 1e-5, atol = 1e-
|
|
4372
|
+
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
3874
4373
|
const x = array(actual);
|
|
3875
4374
|
const y = array(expected);
|
|
3876
4375
|
if (!require_backend.deepEqual(x.shape, y.shape)) return false;
|
|
@@ -4004,15 +4503,52 @@ function log2(x) {
|
|
|
4004
4503
|
function log10(x) {
|
|
4005
4504
|
return log(x).mul(Math.LOG10E);
|
|
4006
4505
|
}
|
|
4506
|
+
/**
|
|
4507
|
+
* Calculate element-wise hyperbolic sine of input.
|
|
4508
|
+
*
|
|
4509
|
+
* `sinh(x) = (exp(x) - exp(-x)) / 2`
|
|
4510
|
+
*/
|
|
4511
|
+
function sinh(x) {
|
|
4512
|
+
const ex = exp(x);
|
|
4513
|
+
const emx = reciprocal(ex.ref);
|
|
4514
|
+
return ex.sub(emx).mul(.5);
|
|
4515
|
+
}
|
|
4516
|
+
/**
|
|
4517
|
+
* Calculate element-wise hyperbolic cosine of input.
|
|
4518
|
+
*
|
|
4519
|
+
* `cosh(x) = (exp(x) + exp(-x)) / 2`
|
|
4520
|
+
*/
|
|
4521
|
+
function cosh(x) {
|
|
4522
|
+
const ex = exp(x);
|
|
4523
|
+
const emx = reciprocal(ex.ref);
|
|
4524
|
+
return ex.add(emx).mul(.5);
|
|
4525
|
+
}
|
|
4526
|
+
/**
|
|
4527
|
+
* Calculate element-wise hyperbolic tangent of input.
|
|
4528
|
+
*
|
|
4529
|
+
* `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
|
|
4530
|
+
*/
|
|
4531
|
+
function tanh(x) {
|
|
4532
|
+
x = fudgeArray(x);
|
|
4533
|
+
const negsgn = where(less(x.ref, 0), 1, -1);
|
|
4534
|
+
const en2x = exp(x.mul(negsgn.ref).mul(2));
|
|
4535
|
+
return en2x.ref.sub(1).div(en2x.add(1)).mul(negsgn);
|
|
4536
|
+
}
|
|
4007
4537
|
|
|
4008
4538
|
//#endregion
|
|
4009
4539
|
//#region src/nn.ts
|
|
4010
4540
|
var nn_exports = {};
|
|
4011
4541
|
__export(nn_exports, {
|
|
4542
|
+
celu: () => celu,
|
|
4543
|
+
elu: () => elu,
|
|
4544
|
+
gelu: () => gelu,
|
|
4545
|
+
glu: () => glu,
|
|
4012
4546
|
identity: () => identity,
|
|
4547
|
+
leakyRelu: () => leakyRelu,
|
|
4013
4548
|
logSigmoid: () => logSigmoid,
|
|
4014
4549
|
logSoftmax: () => logSoftmax,
|
|
4015
4550
|
logsumexp: () => logsumexp,
|
|
4551
|
+
mish: () => mish,
|
|
4016
4552
|
oneHot: () => oneHot,
|
|
4017
4553
|
relu: () => relu,
|
|
4018
4554
|
relu6: () => relu6,
|
|
@@ -4072,10 +4608,7 @@ function softSign(x) {
|
|
|
4072
4608
|
*
|
|
4073
4609
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4074
4610
|
*/
|
|
4075
|
-
|
|
4076
|
-
x = fudgeArray(x);
|
|
4077
|
-
return x.ref.mul(sigmoid(x));
|
|
4078
|
-
}
|
|
4611
|
+
const silu = jit$1((x) => x.ref.mul(sigmoid(x)));
|
|
4079
4612
|
/**
|
|
4080
4613
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4081
4614
|
* Swish, computed element-wise:
|
|
@@ -4095,6 +4628,72 @@ function logSigmoid(x) {
|
|
|
4095
4628
|
}
|
|
4096
4629
|
/** Identity activation function. Returns the argument unmodified. */
|
|
4097
4630
|
const identity = fudgeArray;
|
|
4631
|
+
/** Leaky rectified linear unit (Leaky ReLU) activation function. */
|
|
4632
|
+
function leakyRelu(x, negativeSlope = .01) {
|
|
4633
|
+
x = fudgeArray(x);
|
|
4634
|
+
return where(less(x.ref, 0), x.ref.mul(negativeSlope), x);
|
|
4635
|
+
}
|
|
4636
|
+
/**
|
|
4637
|
+
* Exponential linear unit activation function.
|
|
4638
|
+
*
|
|
4639
|
+
* Computes the element-wise function:
|
|
4640
|
+
* `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
|
|
4641
|
+
*/
|
|
4642
|
+
function elu(x, alpha = 1) {
|
|
4643
|
+
x = fudgeArray(x);
|
|
4644
|
+
return where(less(x.ref, 0), exp(x.ref).sub(1).mul(alpha), x);
|
|
4645
|
+
}
|
|
4646
|
+
/**
|
|
4647
|
+
* Continuously-differentiable exponential linear unit activation function.
|
|
4648
|
+
*
|
|
4649
|
+
* Computes the element-wise function:
|
|
4650
|
+
* `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
|
|
4651
|
+
*/
|
|
4652
|
+
function celu(x, alpha = 1) {
|
|
4653
|
+
x = fudgeArray(x);
|
|
4654
|
+
return where(less(x.ref, 0), exp(x.ref.div(alpha)).sub(1).mul(alpha), x);
|
|
4655
|
+
}
|
|
4656
|
+
/**
|
|
4657
|
+
* Gaussian error linear unit (GELU) activation function.
|
|
4658
|
+
*
|
|
4659
|
+
* This is computed element-wise. Currently jax-js does not support the erf() or
|
|
4660
|
+
* gelu() functions exactly as primitives, so an approximation is used:
|
|
4661
|
+
* `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
|
|
4662
|
+
*
|
|
4663
|
+
* Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
|
|
4664
|
+
*
|
|
4665
|
+
* This will be improved in the future.
|
|
4666
|
+
*/
|
|
4667
|
+
const gelu = jit$1((x) => {
|
|
4668
|
+
const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
|
|
4669
|
+
return x.ref.mul(.5).mul(tanh(x.ref.mul(x.ref.mul(x).mul(.044715).add(1)).mul(SQRT_2_OVER_PI)).add(1));
|
|
4670
|
+
});
|
|
4671
|
+
/**
|
|
4672
|
+
* Gated linear unit (GLU) activation function.
|
|
4673
|
+
*
|
|
4674
|
+
* Splits the `axis` dimension of the input into two halves, a and b, then
|
|
4675
|
+
* computes `a * sigmoid(b)`.
|
|
4676
|
+
*/
|
|
4677
|
+
function glu(x, axis = -1) {
|
|
4678
|
+
x = fudgeArray(x);
|
|
4679
|
+
axis = require_backend.checkAxis(axis, x.ndim);
|
|
4680
|
+
const size$1 = x.shape[axis];
|
|
4681
|
+
if (size$1 % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
|
|
4682
|
+
const slice = x.shape.map((a$1) => [0, a$1]);
|
|
4683
|
+
const a = shrink(x.ref, slice.toSpliced(axis, 1, [0, size$1 / 2]));
|
|
4684
|
+
const b = shrink(x, slice.toSpliced(axis, 1, [size$1 / 2, size$1]));
|
|
4685
|
+
return a.mul(sigmoid(b));
|
|
4686
|
+
}
|
|
4687
|
+
/**
|
|
4688
|
+
* Mish activation function.
|
|
4689
|
+
*
|
|
4690
|
+
* Computes the element-wise function:
|
|
4691
|
+
* `mish(x) = x * tanh(softplus(x))`
|
|
4692
|
+
*/
|
|
4693
|
+
function mish(x) {
|
|
4694
|
+
x = fudgeArray(x);
|
|
4695
|
+
return x.ref.mul(tanh(softplus(x)));
|
|
4696
|
+
}
|
|
4098
4697
|
/**
|
|
4099
4698
|
* Softmax function. Computes the function which rescales elements to the range
|
|
4100
4699
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -4171,7 +4770,7 @@ function logsumexp(x, axis) {
|
|
|
4171
4770
|
* ```
|
|
4172
4771
|
*/
|
|
4173
4772
|
function oneHot(x, numClasses) {
|
|
4174
|
-
if (x.dtype !==
|
|
4773
|
+
if (x.dtype !== require_backend.DType.Int32) throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
|
|
4175
4774
|
return eye(numClasses, void 0, { device: x.device }).slice(x);
|
|
4176
4775
|
}
|
|
4177
4776
|
|
|
@@ -4240,6 +4839,19 @@ const vmap = vmap$1;
|
|
|
4240
4839
|
const jacfwd = jacfwd$1;
|
|
4241
4840
|
/** Construct a Jaxpr by dynamically tracing a function with example inputs. */
|
|
4242
4841
|
const makeJaxpr = makeJaxpr$1;
|
|
4842
|
+
/**
|
|
4843
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
4844
|
+
*
|
|
4845
|
+
* The function will be compiled the first time it is called with a set of
|
|
4846
|
+
* argument shapes.
|
|
4847
|
+
*
|
|
4848
|
+
* **Options:**
|
|
4849
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
4850
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
4851
|
+
* and different values will trigger recompilation.
|
|
4852
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
4853
|
+
* computation will be placed on the default device.
|
|
4854
|
+
*/
|
|
4243
4855
|
const jit = jit$1;
|
|
4244
4856
|
/**
|
|
4245
4857
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
@@ -4261,6 +4873,7 @@ const jacrev = jacrev$1;
|
|
|
4261
4873
|
const jacobian = jacrev;
|
|
4262
4874
|
|
|
4263
4875
|
//#endregion
|
|
4876
|
+
exports.DType = require_backend.DType;
|
|
4264
4877
|
exports.devices = require_backend.devices;
|
|
4265
4878
|
exports.grad = grad;
|
|
4266
4879
|
exports.init = require_backend.init;
|
|
@@ -4269,6 +4882,12 @@ exports.jacobian = jacobian;
|
|
|
4269
4882
|
exports.jacrev = jacrev;
|
|
4270
4883
|
exports.jit = jit;
|
|
4271
4884
|
exports.jvp = jvp;
|
|
4885
|
+
Object.defineProperty(exports, 'lax', {
|
|
4886
|
+
enumerable: true,
|
|
4887
|
+
get: function () {
|
|
4888
|
+
return lax_exports;
|
|
4889
|
+
}
|
|
4890
|
+
});
|
|
4272
4891
|
exports.linearize = linearize;
|
|
4273
4892
|
exports.makeJaxpr = makeJaxpr;
|
|
4274
4893
|
Object.defineProperty(exports, 'nn', {
|