@jax-js/jax 0.0.2 → 0.0.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +9 -8
- package/dist/{backend-1eVbAoaV.js → backend-BqDtPGaR.js} +1869 -86
- package/dist/{backend-BK21PBVP.cjs → backend-D2C4MJRP.cjs} +1892 -85
- package/dist/index.cjs +737 -118
- package/dist/index.d.cts +247 -44
- package/dist/index.d.ts +247 -44
- package/dist/index.js +726 -114
- package/dist/{webgpu-JVpVad6g.js → webgpu-CNg9JGva.js} +54 -33
- package/dist/{webgpu-c5Fe8nx8.cjs → webgpu-fqhx41TC.cjs} +54 -33
- package/package.json +7 -6
package/dist/index.js
CHANGED
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
import { __export } from "./chunk-Cl8Af3a2.js";
|
|
2
|
-
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip } from "./backend-
|
|
2
|
+
import { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, FpHash, Kernel, PPrint, Reduction, ShapeTracker, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, toposort, unravelAlu, unzip2, zip, zipn } from "./backend-BqDtPGaR.js";
|
|
3
3
|
|
|
4
4
|
//#region src/tree.ts
|
|
5
5
|
var tree_exports = {};
|
|
6
6
|
__export(tree_exports, {
|
|
7
7
|
JsTreeDef: () => JsTreeDef,
|
|
8
8
|
NodeType: () => NodeType,
|
|
9
|
+
dispose: () => dispose,
|
|
9
10
|
flatten: () => flatten,
|
|
10
11
|
leaves: () => leaves,
|
|
11
12
|
map: () => map,
|
|
@@ -20,7 +21,7 @@ let NodeType = /* @__PURE__ */ function(NodeType$1) {
|
|
|
20
21
|
NodeType$1["Leaf"] = "Leaf";
|
|
21
22
|
return NodeType$1;
|
|
22
23
|
}({});
|
|
23
|
-
/**
|
|
24
|
+
/** Represents the structure of a JsTree. */
|
|
24
25
|
var JsTreeDef = class JsTreeDef {
|
|
25
26
|
static leaf = new JsTreeDef(NodeType.Leaf, null, []);
|
|
26
27
|
constructor(nodeType, nodeMetadata, childTreedefs) {
|
|
@@ -108,6 +109,194 @@ function map(fn, tree, ...rest) {
|
|
|
108
109
|
function ref(tree) {
|
|
109
110
|
return map((x) => x.ref, tree);
|
|
110
111
|
}
|
|
112
|
+
/**
 * Dispose every array in a tree.
 *
 * A falsy `tree` (for example `null` or `undefined`) is ignored. Otherwise
 * `dispose()` is invoked on each leaf that `map()` visits.
 */
function dispose(tree) {
  if (!tree) return;
  map((leaf) => leaf.dispose(), tree);
}
|
|
116
|
+
|
|
117
|
+
//#endregion
|
|
118
|
+
//#region src/frontend/convolution.ts
|
|
119
|
+
/**
 * Check that the shapes and parameters passed to convolution are valid.
 *
 * If the check succeeds, returns the output shape
 * `[lhsShape[0], rhsShape[0], ...spatialOutputSizes]`.
 *
 * Any of `strides`, `padding`, `lhsDilation` and `rhsDilation` may be omitted
 * (or the whole `params` object may be); omitted entries fall back to the same
 * defaults that `conv()` applies: stride 1, padding `[0, 0]` and dilation 1
 * for every spatial dimension.
 *
 * @param {number[]} lhsShape Shape of the left-hand input.
 * @param {number[]} rhsShape Shape of the kernel; must have the same rank as
 *   `lhsShape` and the same size at index 1.
 * @param {object} [params] Convolution parameters, one entry per spatial dim.
 * @returns {number[]} The output shape.
 * @throws {Error} If ranks, parameter lengths, or per-dimension values are
 *   invalid, or the dilated kernel does not fit in the padded input.
 */
function checkConvShape(lhsShape, rhsShape, params = {}) {
  if (lhsShape.length !== rhsShape.length) throw new Error(`conv() requires inputs with the same number of dimensions, got ${lhsShape.length} and ${rhsShape.length}`);
  const n = lhsShape.length - 2;
  if (n < 0) throw new Error("conv() requires at least 2D inputs");

  // Fill in omitted parameters only after `n` is known to be non-negative.
  const opts = params ?? {};
  const ones = () => Array.from({ length: n }, () => 1);
  const strides = opts.strides ?? ones();
  const padding = opts.padding ?? Array.from({ length: n }, () => [0, 0]);
  const lhsDilation = opts.lhsDilation ?? ones();
  const rhsDilation = opts.rhsDilation ?? ones();

  if (strides.length !== n) throw new Error("conv() strides != spatial dims");
  if (padding.length !== n) throw new Error("conv() padding != spatial dims");
  if (lhsDilation.length !== n) throw new Error("conv() lhsDilation != spatial dimensions");
  if (rhsDilation.length !== n) throw new Error("conv() rhsDilation != spatial dimensions");
  if (lhsShape[1] !== rhsShape[1]) throw new Error(`conv() input channels: ${lhsShape[1]} != ${rhsShape[1]}`);

  const outShape = [lhsShape[0], rhsShape[0]];
  for (let i = 0; i < n; i++) {
    if (strides[i] <= 0 || !Number.isInteger(strides[i])) throw new Error(`conv() strides[${i}] must be a positive integer`);
    if (padding[i].length !== 2 || !padding[i].every(Number.isInteger)) throw new Error(`conv() padding[${i}] must be a 2-tuple of integers`);
    if (lhsDilation[i] <= 0 || !Number.isInteger(lhsDilation[i])) throw new Error(`conv() lhsDilation[${i}] must be a positive integer`);
    if (rhsDilation[i] <= 0 || !Number.isInteger(rhsDilation[i])) throw new Error(`conv() rhsDilation[${i}] must be a positive integer`);
    const [x, k] = [lhsShape[i + 2], rhsShape[i + 2]];
    if (k <= 0) throw new Error("conv() kernel size must be positive");
    const [pl, pr] = padding[i];
    // Negative padding crops the input; it may not crop more than exists.
    if (pl < -x || pr < -x || pl + pr < -x) throw new Error(`conv() padding[${i}]=(${pl},${pr}) is too negative for input size ${x}`);
    // Effective extents once dilation (and, for the input, padding) is applied.
    const kernelSize = (k - 1) * rhsDilation[i] + 1;
    const inSize = Math.max((x - 1) * lhsDilation[i] + 1, 0) + pl + pr;
    if (kernelSize > inSize) throw new Error(`conv() kernel size ${kernelSize} > input size ${inSize} in dimension ${i}`);
    outShape.push(Math.ceil((inSize - kernelSize + 1) / strides[i]));
  }
  return outShape;
}
|
|
150
|
+
/**
 * Check that a pooling window and strides are valid for an input shape.
 *
 * The window applies to the last `window.length` dimensions of `inShape`.
 * On success, returns the leading (untouched) dimensions, followed by one
 * output size per windowed dimension, followed by the window sizes.
 *
 * @param {number[]} inShape Shape of the input.
 * @param {number[]} window Window size for each trailing dimension.
 * @param {number|number[]} [strides=1] Stride per windowed dimension; a single
 *   number is applied to every windowed dimension, matching `pool()`.
 * @returns {number[]} The pooled shape.
 * @throws {Error} If lengths disagree, a window or stride is not a positive
 *   integer, or a window is larger than its input dimension.
 */
function checkPoolShape(inShape, window, strides = 1) {
  // Expand a scalar stride so the rest of the check sees one stride per window dim.
  const strideList = typeof strides === "number" ? Array.from(window, () => strides) : strides;
  if (strideList.length !== window.length) throw new Error("pool() strides != window dims");
  if (window.length > inShape.length) throw new Error("pool() window has more dimensions than input");
  const firstPooled = inShape.length - window.length;
  const outShape = inShape.slice(0, firstPooled);
  for (let i = 0; i < window.length; i++) {
    const k = window[i];
    const s = strideList[i];
    const size = inShape[firstPooled + i];
    if (k <= 0 || !Number.isInteger(k)) throw new Error(`pool() window[${i}] must be a positive integer`);
    if (k > size) throw new Error(`pool() window[${i}]=${k} > input size ${size}`);
    if (s <= 0 || !Number.isInteger(s)) throw new Error(`pool() strides[${i}] must be a positive integer`);
    outShape.push(Math.ceil((size - k + 1) / s));
  }
  return outShape.concat(window);
}
|
|
165
|
+
/**
 * Takes a shape tracker and a kernel size `ks`, then reshapes it so the last
 * `ks.length` dimensions become `2 * ks.length` dimensions by treating them as
 * spatial dimensions convolved with a kernel.
 *
 * The resulting array can be multiplied with a kernel of shape `ks`, then
 * reduced along the last `ks.length` dimensions for a convolution.
 *
 * `strides` and `dilation` may each be a single number (applied to every
 * kernel dimension) or a list with one entry per kernel dimension. An empty
 * `ks` returns `st` unchanged.
 *
 * Reference: https://github.com/tinygrad/tinygrad/blob/v0.10.3/tinygrad/tensor.py#L2097
 */
function pool(st, ks, strides = 1, dilation = 1) {
  const nk = ks.length;
  if (nk === 0) return st;
  if (st.shape.length < nk) throw new Error("pool() called with too many dimensions");

  const strideList = typeof strides === "number" ? rep(nk, strides) : strides;
  const dilationList = typeof dilation === "number" ? rep(nk, dilation) : dilation;
  const isPositiveInt = (v) => v > 0 && Number.isInteger(v);
  if (!strideList.every(isPositiveInt)) throw new Error("pool() strides must be positive integers");
  if (!dilationList.every(isPositiveInt)) throw new Error("pool() dilation must be positive integers");

  // Dimensions before the pooled ones are carried through every step untouched.
  const leading = st.shape.slice(0, -nk);
  const spatial = st.shape.slice(-nk);
  const keepLeading = () => leading.map((dim) => [0, dim]);
  // Row length used when the repeated input is laid out one row per kernel tap.
  const span = (i, d, f) => i * f + d;

  const outSizes = zipn(spatial, dilationList, ks, strideList).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const factors = zipn(outSizes, strideList, spatial, dilationList, ks).map(([o, s, i, d, k]) => (o * s > i - d * (k - 1) ? 2 : 1));
  const kidf = zipn(ks, spatial, dilationList, factors);

  // Repeat each pooled dimension enough times to cover `k` rows of length `span`.
  const repeated = st.repeat([
    ...rep(leading.length, 1),
    ...kidf.map(([k, i, d, f]) => Math.ceil(k * span(i, d, f) / i)),
  ]);

  // Trim to exactly `k * span` elements and split into (k, span) per dimension.
  const rows = repeated
    .shrink([...keepLeading(), ...kidf.map(([k, i, d, f]) => [0, k * span(i, d, f)])])
    .reshape([...leading, ...kidf.flatMap(([k, i, d, f]) => [k, span(i, d, f)])]);

  // Keep `o * s` entries of each row and split them into (o, s).
  const kos = zipn(ks, outSizes, strideList);
  const strided = rows
    .shrink([...keepLeading(), ...kos.flatMap(([k, o, s]) => [[0, k], [0, o * s]])])
    .reshape([...leading, ...kos.flat(1)]);

  // Take the first element of every stride group, leaving (k, o) per dimension.
  const picked = strided
    .shrink([...keepLeading(), ...kos.flatMap(([k, o]) => [[0, k], [0, o], [0, 1]])])
    .reshape([...leading, ...kos.flatMap(([k, o]) => [k, o])]);

  // Reorder to [...leading, ...outputs, ...kernel].
  const base = leading.length;
  return picked.permute([
    ...range(base),
    ...ks.map((_, j) => base + 2 * j + 1),
    ...ks.map((_, j) => base + 2 * j),
  ]);
}
|
|
205
|
+
/**
 * Perform the transpose of pool, directly undo-ing a pool() operation.
 *
 * `inShape` is the shape that was originally passed through `pool()`; the
 * output sizes recomputed from it must match the corresponding dimensions of
 * `st`, otherwise an error is thrown. An empty `ks` returns `st` unchanged.
 *
 * Note that since pool repeats the input, the transpose operation technically
 * should include a sum reduction. This function doesn't perform the reduction,
 * which should be done on the last `k` axes of the returned shape.
 */
function poolTranspose(st, inShape, ks, strides = 1, dilation = 1) {
  const nk = ks.length;
  if (nk === 0) return st;

  const strideList = typeof strides === "number" ? rep(nk, strides) : strides;
  const dilationList = typeof dilation === "number" ? rep(nk, dilation) : dilation;

  const leading = inShape.slice(0, -nk);
  const spatial = inShape.slice(-nk);
  const base = leading.length;
  const noPad = () => leading.map(() => [0, 0]);
  const span = (i, d, f) => i * f + d;

  const outSizes = zipn(spatial, dilationList, ks, strideList).map(([i, d, k, s]) => Math.ceil((i - d * (k - 1)) / s));
  const givenOutSizes = st.shape.slice(base, base + nk);
  if (!deepEqual(outSizes, givenOutSizes)) throw new Error("poolTranspose() called with mismatched output shape");

  const factors = zipn(outSizes, strideList, spatial, dilationList, ks).map(([o, s, i, d, k]) => (o * s > i - d * (k - 1) ? 2 : 1));
  const kidf = zipn(ks, spatial, dilationList, factors);
  const kos = zipn(ks, outSizes, strideList);

  // Interleave the axes back into (k, o) pairs per pooled dimension.
  const paired = st.permute([
    ...range(base),
    ...ks.flatMap((_, j) => [base + nk + j, base + j]),
  ]);

  // Re-insert the stride gaps: (k, o, 1) padded to (k, o, s).
  const restrided = paired
    .reshape([...leading, ...kos.flatMap(([k, o]) => [k, o, 1])])
    .pad([...noPad(), ...strideList.flatMap((s) => [[0, 0], [0, 0], [0, s - 1]])]);

  // Merge to (k, o * s) and pad each row out to length `span`.
  const widened = restrided
    .reshape([...leading, ...kos.flatMap(([k, o, s]) => [k, o * s])])
    .pad([
      ...noPad(),
      ...kidf.flatMap(([, i, d, f], j) => [[0, 0], [0, span(i, d, f) - outSizes[j] * strideList[j]]]),
    ]);

  // Flatten each (k, span) pair and pad up to a whole multiple of `i`.
  const flattened = widened
    .reshape([...leading, ...kidf.map(([k, i, d, f]) => k * span(i, d, f))])
    .pad([
      ...noPad(),
      ...kidf.map(([k, i, d, f]) => [0, Math.ceil(k * span(i, d, f) / i) * i - k * span(i, d, f)]),
    ]);

  // Split into (repeats, i) pairs and move the `i` axes ahead of the repeat axes.
  return flattened
    .reshape([...leading, ...kidf.flatMap(([k, i, d, f]) => [Math.ceil(k * span(i, d, f) / i), i])])
    .permute([
      ...range(base),
      ...ks.map((_, j) => base + 2 * j + 1),
      ...ks.map((_, j) => base + 2 * j),
    ]);
}
|
|
244
|
+
/**
 * Applies dilation to an array directly, for transposed convolution.
 *
 * The first two dimensions of `st.shape` are left alone; every later dimension
 * of size `k` is stretched to `(k - 1) * s + 1`, where `s` is the matching
 * entry of `dilation`, by padding `s - 1` slots after each element and
 * trimming the trailing gap.
 *
 * @param st Shape tracker exposing `shape`, `reshape`, `pad` and `shrink`.
 * @param {number[]} dilation One dilation factor per dimension after the
 *   first two.
 * @returns `st` itself when every factor is 1, otherwise the dilated tracker.
 * @throws {Error} If a non-trivial `dilation` does not have exactly one entry
 *   per dimension after the first two.
 */
function applyDilation(st, dilation) {
  if (dilation.every((s) => s === 1)) return st;
  const [a, b, ...spatial] = st.shape;
  // A length mismatch would otherwise index `undefined` and yield NaN sizes.
  if (dilation.length !== spatial.length) throw new Error(`applyDilation() dilation has ${dilation.length} entries, expected ${spatial.length}`);

  // Give every spatial dimension a trailing unit axis to pad into.
  const split = st.reshape([a, b, ...spatial.flatMap((k) => [k, 1])]);
  // Grow each unit axis to size `s`, i.e. `s - 1` empty slots per element.
  const padded = split.pad([
    [0, 0],
    [0, 0],
    ...dilation.flatMap((s) => [[0, 0], [0, s - 1]]),
  ]);
  const merged = padded.reshape([a, b, ...spatial.map((k, i) => k * dilation[i])]);
  // Drop the gap that follows the final element of each dimension.
  return merged.shrink([
    [0, a],
    [0, b],
    ...spatial.map((k, i) => [0, (k - 1) * dilation[i] + 1]),
  ]);
}
|
|
271
|
+
/**
 * Prepare for a convolution between two arrays.
 *
 * The left-hand tracker is dilated by `params.lhsDilation`, padded or shrunk
 * by `params.padding` on every dimension after the first two, and pooled with
 * the kernel's spatial shape using `params.strides` and `params.rhsDilation`.
 * Both trackers are then reshaped so their last dimension has size
 * `shape[1] * prod(ks)`, and the pair `[stX, stY]` is returned.
 *
 * This does not check the validity of the shapes, which should be checked
 * beforehand using `checkConvShape()`.
 */
function prepareConv(stX, stY, params) {
  const spatialRank = stX.shape.length - 2;
  const dilated = applyDilation(stX, params.lhsDilation);
  const kernelShape = stY.shape.slice(2);

  const padded = dilated.padOrShrink([[0, 0], [0, 0], ...params.padding]);
  const windows = pool(padded, kernelShape, params.strides, params.rhsDilation);

  // Move axis 1 next to the kernel axes, then fold both into one trailing axis.
  const lhs = windows.moveaxis(1, spatialRank + 1).reshape([
    windows.shape[0],
    1,
    ...windows.shape.slice(2, spatialRank + 2),
    windows.shape[1] * prod(kernelShape),
  ]);

  // Flatten everything after the kernel's first axis; unit axes stand in for
  // the spatial output dimensions.
  const rhs = stY.reshape([
    stY.shape[0],
    ...rep(spatialRank, 1),
    stY.shape[1] * prod(kernelShape),
  ]);

  return [lhs, rhs];
}
|
|
111
300
|
|
|
112
301
|
//#endregion
|
|
113
302
|
//#region src/frontend/core.ts
|
|
@@ -136,10 +325,14 @@ let Primitive = /* @__PURE__ */ function(Primitive$1) {
|
|
|
136
325
|
Primitive$1["Cos"] = "cos";
|
|
137
326
|
Primitive$1["Exp"] = "exp";
|
|
138
327
|
Primitive$1["Log"] = "log";
|
|
328
|
+
Primitive$1["Sqrt"] = "sqrt";
|
|
139
329
|
Primitive$1["Min"] = "min";
|
|
140
330
|
Primitive$1["Max"] = "max";
|
|
141
331
|
Primitive$1["Reduce"] = "reduce";
|
|
142
332
|
Primitive$1["Dot"] = "dot";
|
|
333
|
+
Primitive$1["Conv"] = "conv";
|
|
334
|
+
Primitive$1["Pool"] = "pool";
|
|
335
|
+
Primitive$1["PoolTranspose"] = "pool_transpose";
|
|
143
336
|
Primitive$1["Compare"] = "compare";
|
|
144
337
|
Primitive$1["Where"] = "where";
|
|
145
338
|
Primitive$1["Transpose"] = "transpose";
|
|
@@ -203,6 +396,9 @@ function exp$1(x) {
|
|
|
203
396
|
function log$1(x) {
|
|
204
397
|
return bind1(Primitive.Log, [x]);
|
|
205
398
|
}
|
|
399
|
+
/** Binds `Primitive.Sqrt` to the single operand `x` and returns what `bind1` produces. */
function sqrt$1(x) {
  const operands = [x];
  return bind1(Primitive.Sqrt, operands);
}
|
|
206
402
|
function min$1(x, y) {
|
|
207
403
|
return bind1(Primitive.Min, [x, y]);
|
|
208
404
|
}
|
|
@@ -225,6 +421,17 @@ function reduce(x, op, axis, opts) {
|
|
|
225
421
|
function dot$1(x, y) {
|
|
226
422
|
return bind1(Primitive.Dot, [x, y]);
|
|
227
423
|
}
|
|
424
|
+
/**
 * Bind the `Conv` primitive to `x` and `y`.
 *
 * Both operands must have the same number of dimensions, and at least two.
 * Any of `strides`, `padding`, `lhsDilation` and `rhsDilation` left unset in
 * `params` is filled in with one entry per dimension beyond the first two:
 * stride 1, padding `[0, 0]` and dilation 1.
 */
function conv(x, y, params = {}) {
  if (x.ndim !== y.ndim) {
    throw new Error(`conv() requires inputs with the same number of dimensions, got ${x.ndim} and ${y.ndim}`);
  }
  const spatialRank = x.ndim - 2;
  if (spatialRank < 0) {
    throw new Error("conv() requires at least 2D inputs");
  }
  const convParams = {
    strides: params.strides ?? rep(spatialRank, 1),
    padding: params.padding ?? rep(spatialRank, [0, 0]),
    lhsDilation: params.lhsDilation ?? rep(spatialRank, 1),
    rhsDilation: params.rhsDilation ?? rep(spatialRank, 1),
  };
  return bind1(Primitive.Conv, [x, y], convParams);
}
|
|
228
435
|
function compare(x, y, op) {
|
|
229
436
|
return bind1(Primitive.Compare, [x, y], { op });
|
|
230
437
|
}
|
|
@@ -360,6 +567,9 @@ var Tracer = class Tracer {
|
|
|
360
567
|
get shape() {
|
|
361
568
|
return this.aval.shape;
|
|
362
569
|
}
|
|
570
|
+
get size() {
|
|
571
|
+
return prod(this.shape);
|
|
572
|
+
}
|
|
363
573
|
get dtype() {
|
|
364
574
|
return this.aval.dtype;
|
|
365
575
|
}
|
|
@@ -411,7 +621,7 @@ var Tracer = class Tracer {
|
|
|
411
621
|
else if (typeof axis === "number") axis = [checkAxis(axis, this.ndim)];
|
|
412
622
|
else axis = axis.map((a) => checkAxis(a, this.ndim));
|
|
413
623
|
let result = reduce(this, AluOp.Add, axis);
|
|
414
|
-
result = result.mul(
|
|
624
|
+
result = result.mul(result.size / this.size);
|
|
415
625
|
if (opts?.keepDims) result = broadcast(result, this.shape, axis);
|
|
416
626
|
return result;
|
|
417
627
|
}
|
|
@@ -445,8 +655,29 @@ var Tracer = class Tracer {
|
|
|
445
655
|
/** Return specified diagonals. See `numpy.diagonal` for full docs. */
|
|
446
656
|
diagonal(offset = 0, axis1 = 0, axis2 = 1) {
|
|
447
657
|
if (!Number.isInteger(offset)) throw new TypeError(`offset must be an integer, got ${offset}`);
|
|
658
|
+
if (offset < 0) return this.diagonal(-offset, axis2, axis1);
|
|
659
|
+
axis1 = checkAxis(axis1, this.ndim);
|
|
660
|
+
axis2 = checkAxis(axis2, this.ndim);
|
|
448
661
|
if (axis1 === axis2) throw new Error("axis1 and axis2 must not be equal");
|
|
449
|
-
throw new Error("
|
|
662
|
+
if (offset >= this.shape[axis2]) throw new Error("offset exceeds axis size");
|
|
663
|
+
let ar = this;
|
|
664
|
+
if (axis1 !== ar.ndim - 2 || axis2 !== ar.ndim - 1) {
|
|
665
|
+
const perm = range(ar.ndim).filter((i) => i !== axis1 && i !== axis2).concat(axis1, axis2);
|
|
666
|
+
ar = ar.transpose(perm);
|
|
667
|
+
}
|
|
668
|
+
const [n, m] = ar.shape.slice(-2);
|
|
669
|
+
const diagSize = Math.min(n, m - offset);
|
|
670
|
+
ar = ar.reshape([...ar.shape.slice(0, -2), n * m]);
|
|
671
|
+
const npad = diagSize * (m + 1) - n * m;
|
|
672
|
+
if (npad > 0) ar = pad$1(ar, [...rep(ar.ndim - 1, [0, 0]), [0, npad]]);
|
|
673
|
+
else if (npad < 0) ar = shrink(ar, [...ar.shape.slice(0, -1), n * m + npad].map((x) => [0, x]));
|
|
674
|
+
ar = ar.reshape([
|
|
675
|
+
...ar.shape.slice(0, -1),
|
|
676
|
+
diagSize,
|
|
677
|
+
m + 1
|
|
678
|
+
]);
|
|
679
|
+
ar = shrink(ar, [...ar.shape.slice(0, -1).map((x) => [0, x]), [offset, offset + 1]]).reshape(ar.shape.slice(0, -1));
|
|
680
|
+
return ar;
|
|
450
681
|
}
|
|
451
682
|
/** Flatten the array without changing its data. */
|
|
452
683
|
flatten() {
|
|
@@ -589,7 +820,7 @@ var ShapedArray = class ShapedArray {
|
|
|
589
820
|
get ndim() {
|
|
590
821
|
return this.shape.length;
|
|
591
822
|
}
|
|
592
|
-
|
|
823
|
+
toString() {
|
|
593
824
|
return `${this.dtype}[${this.shape.join(",")}]`;
|
|
594
825
|
}
|
|
595
826
|
equals(other) {
|
|
@@ -620,7 +851,7 @@ function fullRaise(trace, val) {
|
|
|
620
851
|
if (Object.is(val._trace.main, trace.main)) return val;
|
|
621
852
|
else if (val._trace.main.level < level) return trace.lift(val);
|
|
622
853
|
else if (val._trace.main.level > level) throw new Error(`Can't lift Tracer level ${val._trace.main.level} to level ${level}`);
|
|
623
|
-
else throw new Error(`Different traces at same level: ${val._trace}, ${trace}.`);
|
|
854
|
+
else throw new Error(`Different traces at same level: ${val._trace.constructor}, ${trace.constructor}.`);
|
|
624
855
|
}
|
|
625
856
|
var TreeMismatchError = class extends TypeError {
|
|
626
857
|
constructor(where$2, left, right) {
|
|
@@ -869,16 +1100,16 @@ function jitCompile(backend, jaxpr, consts) {
|
|
|
869
1100
|
jitCompileCache.set(cacheKey, jp);
|
|
870
1101
|
return jp;
|
|
871
1102
|
}
|
|
872
|
-
function reshapeViews(exp$2, mapping) {
|
|
1103
|
+
function reshapeViews(exp$2, mapping, reduceAxis = false) {
|
|
873
1104
|
return exp$2.rewrite((exp$3) => {
|
|
874
1105
|
if (exp$3.op === AluOp.GlobalView) {
|
|
875
1106
|
const [gid, st] = exp$3.arg;
|
|
876
1107
|
const newSt = mapping(st);
|
|
877
1108
|
if (newSt) {
|
|
878
|
-
const indices = unravelAlu(newSt.shape, AluVar.gidx);
|
|
1109
|
+
const indices = reduceAxis ? unravelAlu(newSt.shape.slice(0, -1), AluVar.gidx).concat(AluVar.ridx) : unravelAlu(newSt.shape, AluVar.gidx);
|
|
879
1110
|
return AluExp.globalView(exp$3.dtype, gid, newSt, indices);
|
|
880
1111
|
}
|
|
881
|
-
}
|
|
1112
|
+
} else if (exp$3.op === AluOp.GlobalIndex) throw new Error("internal: reshapeViews() called with GlobalIndex op");
|
|
882
1113
|
});
|
|
883
1114
|
}
|
|
884
1115
|
function broadcastedJit(fn) {
|
|
@@ -927,6 +1158,7 @@ const jitRules = {
|
|
|
927
1158
|
[Primitive.Cos]: unopJit(AluExp.cos),
|
|
928
1159
|
[Primitive.Exp]: unopJit(AluExp.exp),
|
|
929
1160
|
[Primitive.Log]: unopJit(AluExp.log),
|
|
1161
|
+
[Primitive.Sqrt]: unopJit(AluExp.sqrt),
|
|
930
1162
|
[Primitive.Min]: broadcastedJit(([a, b]) => AluExp.min(a, b)),
|
|
931
1163
|
[Primitive.Max]: broadcastedJit(([a, b]) => AluExp.max(a, b)),
|
|
932
1164
|
[Primitive.Reduce](nargs, [a], [as], { op, axis }) {
|
|
@@ -941,18 +1173,20 @@ const jitRules = {
|
|
|
941
1173
|
const size$1 = prod(newShape);
|
|
942
1174
|
const reductionSize = prod(shiftedAxes.map((ax) => as.shape[ax]));
|
|
943
1175
|
newShape.push(reductionSize);
|
|
944
|
-
|
|
945
|
-
|
|
946
|
-
const [gid, st] = exp$2.arg;
|
|
947
|
-
const newSt = st.permute(keptAxes.concat(shiftedAxes)).reshape(newShape);
|
|
948
|
-
const indices = unravelAlu(newShape.slice(0, -1), AluVar.gidx);
|
|
949
|
-
indices.push(AluVar.ridx);
|
|
950
|
-
return AluExp.globalView(exp$2.dtype, gid, newSt, indices);
|
|
951
|
-
}
|
|
952
|
-
});
|
|
1176
|
+
const perm = keptAxes.concat(shiftedAxes);
|
|
1177
|
+
a = reshapeViews(a, (st) => st.permute(perm).reshape(newShape), true);
|
|
953
1178
|
const reduction = new Reduction(a.dtype, op, reductionSize);
|
|
954
1179
|
return new Kernel(nargs, size$1, a, reduction);
|
|
955
1180
|
},
|
|
1181
|
+
[Primitive.Pool]: reshapeJit((st, { window, strides }) => pool(st, window, strides)),
|
|
1182
|
+
[Primitive.PoolTranspose](nargs, [a], [as], { inShape, window, strides }) {
|
|
1183
|
+
let stX = poolTranspose(ShapeTracker.fromShape(as.shape), inShape, window, strides);
|
|
1184
|
+
const size$1 = prod(inShape);
|
|
1185
|
+
stX = stX.reshape([...inShape, prod(stX.shape.slice(inShape.length))]);
|
|
1186
|
+
a = reshapeViews(a, (st) => st.compose(stX), true);
|
|
1187
|
+
const reduction = new Reduction(a.dtype, AluOp.Add, stX.shape[stX.shape.length - 1]);
|
|
1188
|
+
return new Kernel(nargs, size$1, a, reduction);
|
|
1189
|
+
},
|
|
956
1190
|
[Primitive.Dot](nargs, [a, b], [as, bs]) {
|
|
957
1191
|
const k1 = jitRules[Primitive.Mul](nargs, [a, b], [as, bs], {});
|
|
958
1192
|
const c = k1.exp;
|
|
@@ -962,6 +1196,14 @@ const jitRules = {
|
|
|
962
1196
|
axis: [cs.ndim - 1]
|
|
963
1197
|
});
|
|
964
1198
|
},
|
|
1199
|
+
[Primitive.Conv](nargs, [a, b], [as, bs], params) {
|
|
1200
|
+
const [stX, stY] = prepareConv(ShapeTracker.fromShape(as.shape), ShapeTracker.fromShape(bs.shape), params);
|
|
1201
|
+
a = reshapeViews(a, (st) => st.compose(stX));
|
|
1202
|
+
b = reshapeViews(b, (st) => st.compose(stY));
|
|
1203
|
+
as = new ShapedArray(stX.shape, as.dtype);
|
|
1204
|
+
bs = new ShapedArray(stY.shape, bs.dtype);
|
|
1205
|
+
return jitRules[Primitive.Dot](nargs, [a, b], [as, bs], {});
|
|
1206
|
+
},
|
|
965
1207
|
[Primitive.Compare]: broadcastedJit(([a, b], { op }) => aluCompare(a, b, op)),
|
|
966
1208
|
[Primitive.Where]: broadcastedJit(([cond, a, b]) => AluExp.where(cond, a, b)),
|
|
967
1209
|
[Primitive.Transpose]: reshapeJit((st, { perm }) => st.permute(perm)),
|
|
@@ -974,8 +1216,20 @@ const jitRules = {
|
|
|
974
1216
|
}),
|
|
975
1217
|
[Primitive.Shrink]: reshapeJit((st, { slice }) => st.shrink(slice)),
|
|
976
1218
|
[Primitive.Pad]: reshapeJit((st, { width }) => st.pad(width)),
|
|
977
|
-
[Primitive.Gather]() {
|
|
978
|
-
|
|
1219
|
+
[Primitive.Gather](nargs, [x, ...indices], [xs, ...indicesShapes], { axis, outDim }) {
|
|
1220
|
+
const axisSet = new Set(axis);
|
|
1221
|
+
const indexShape = indicesShapes.map((c) => c.shape).reduce(generalBroadcast);
|
|
1222
|
+
const finalShape = xs.shape.filter((_, i) => !axisSet.has(i));
|
|
1223
|
+
finalShape.splice(outDim, 0, ...indexShape);
|
|
1224
|
+
const idxAll = unravelAlu(finalShape, AluVar.gidx);
|
|
1225
|
+
const idxNonaxis = [...idxAll];
|
|
1226
|
+
idxNonaxis.splice(outDim, indexShape.length);
|
|
1227
|
+
const src = [...idxNonaxis];
|
|
1228
|
+
for (let i = 0; i < xs.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1229
|
+
for (const [i, iexp] of indices.entries()) src[axis[i]] = AluExp.cast(DType.Int32, reshapeViews(iexp, (st) => st.broadcast(finalShape, [...range(outDim + indexShape.length - st.shape.length), ...range(outDim + indexShape.length, finalShape.length)])));
|
|
1230
|
+
const [index, valid] = ShapeTracker.fromShape(xs.shape).toAluExp(src);
|
|
1231
|
+
if (!valid.resolve()) throw new Error("internal: expected full validity mask in Gather");
|
|
1232
|
+
return new Kernel(nargs, prod(finalShape), x.substitute({ gidx: index }));
|
|
979
1233
|
},
|
|
980
1234
|
[Primitive.JitCall]() {
|
|
981
1235
|
throw new Error("internal: JitCall should have been flattened before JIT compilation");
|
|
@@ -994,9 +1248,15 @@ function splitGraphDataflow(backend, jaxpr) {
|
|
|
994
1248
|
blackNodes.add(v);
|
|
995
1249
|
p1NextBlack.set(v, v);
|
|
996
1250
|
}
|
|
1251
|
+
const reducePrimitives = [
|
|
1252
|
+
Primitive.Reduce,
|
|
1253
|
+
Primitive.Dot,
|
|
1254
|
+
Primitive.Conv,
|
|
1255
|
+
Primitive.PoolTranspose
|
|
1256
|
+
];
|
|
997
1257
|
for (let i = jaxpr.eqns.length - 1; i >= 0; i--) {
|
|
998
1258
|
const eqn = jaxpr.eqns[i];
|
|
999
|
-
if (eqn.primitive === Primitive.
|
|
1259
|
+
if (reducePrimitives.includes(eqn.primitive) || eqn.primitive === Primitive.Gather || eqn.outBinders.some((v) => blackNodes.has(v))) {
|
|
1000
1260
|
for (const v of eqn.outBinders) {
|
|
1001
1261
|
blackNodes.add(v);
|
|
1002
1262
|
p1NextBlack.set(v, v);
|
|
@@ -1223,7 +1483,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1223
1483
|
const inputs = [];
|
|
1224
1484
|
const src = [...idxNonaxis];
|
|
1225
1485
|
for (let i = 0; i < this.shape.length; i++) if (axisSet.has(i)) src.splice(i, 0, null);
|
|
1226
|
-
for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#
|
|
1486
|
+
for (const [i, ar] of indices.entries()) if (ar.#source instanceof AluExp) src[axis[i]] = AluExp.cast(DType.Int32, accessorAluExp(ar.#source, ar.#st, idxAxis));
|
|
1227
1487
|
else {
|
|
1228
1488
|
let gid = inputs.indexOf(ar.#source);
|
|
1229
1489
|
if (gid === -1) {
|
|
@@ -1233,7 +1493,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1233
1493
|
src[axis[i]] = AluExp.cast(DType.Int32, AluExp.globalView(ar.#dtype, gid, ar.#st, idxAxis));
|
|
1234
1494
|
}
|
|
1235
1495
|
let exp$2;
|
|
1236
|
-
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#
|
|
1496
|
+
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, src);
|
|
1237
1497
|
else {
|
|
1238
1498
|
let gid = inputs.indexOf(this.#source);
|
|
1239
1499
|
if (gid === -1) {
|
|
@@ -1276,7 +1536,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1276
1536
|
this.#check();
|
|
1277
1537
|
if (this.#source instanceof AluExp) {
|
|
1278
1538
|
const exp$3 = new AluExp(op, dtypeOutput, [this.#source]);
|
|
1279
|
-
return new Array$1(exp$3, this.#st, dtypeOutput, this.#backend);
|
|
1539
|
+
return new Array$1(exp$3.simplify(), this.#st, dtypeOutput, this.#backend);
|
|
1280
1540
|
}
|
|
1281
1541
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1282
1542
|
const exp$2 = new AluExp(op, dtypeOutput, [AluExp.globalView(this.#dtype, 0, this.#st, indices)]);
|
|
@@ -1309,18 +1569,18 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1309
1569
|
if (!dtypeOutput) throw new TypeError("nary operation with no dtype");
|
|
1310
1570
|
arrays = Array$1.#broadcastArrays(arrays);
|
|
1311
1571
|
const newShape = [...arrays[0].shape];
|
|
1312
|
-
if (arrays.every((ar) => ar.#source instanceof AluExp) && reduceAxis
|
|
1572
|
+
if (arrays.every((ar) => ar.#source instanceof AluExp) && !reduceAxis) {
|
|
1313
1573
|
if (arrays.every((ar) => deepEqual(ar.#st, arrays[0].#st))) {
|
|
1314
1574
|
const exp$4 = custom(arrays.map((ar) => ar.#source));
|
|
1315
|
-
return new Array$1(exp$4, arrays[0].#st, exp$4.dtype, backend);
|
|
1575
|
+
return new Array$1(exp$4.simplify(), arrays[0].#st, exp$4.dtype, backend);
|
|
1316
1576
|
}
|
|
1317
1577
|
const exp$3 = custom(arrays.map((ar) => {
|
|
1318
1578
|
const src$1 = ar.#source;
|
|
1319
1579
|
if (ar.#st.contiguous) return src$1;
|
|
1320
|
-
return accessorAluExp(
|
|
1580
|
+
return accessorAluExp(src$1, ar.#st, unravelAlu(newShape, AluVar.idx));
|
|
1321
1581
|
}));
|
|
1322
1582
|
const st = ShapeTracker.fromShape(newShape);
|
|
1323
|
-
return new Array$1(exp$3, st, exp$3.dtype, backend);
|
|
1583
|
+
return new Array$1(exp$3.simplify(), st, exp$3.dtype, backend);
|
|
1324
1584
|
}
|
|
1325
1585
|
let indices;
|
|
1326
1586
|
if (!reduceAxis) indices = unravelAlu(newShape, AluVar.gidx);
|
|
@@ -1330,7 +1590,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1330
1590
|
}
|
|
1331
1591
|
const inputs = [];
|
|
1332
1592
|
const src = [];
|
|
1333
|
-
for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#
|
|
1593
|
+
for (const ar of arrays) if (ar.#source instanceof AluExp) src.push(accessorAluExp(ar.#source, ar.#st, indices));
|
|
1334
1594
|
else {
|
|
1335
1595
|
let gid = inputs.indexOf(ar.#source);
|
|
1336
1596
|
if (gid === -1) {
|
|
@@ -1364,7 +1624,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1364
1624
|
const indices = [...unravelAlu(newShape, AluVar.gidx), AluVar.ridx];
|
|
1365
1625
|
let exp$2;
|
|
1366
1626
|
const inputs = [];
|
|
1367
|
-
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#
|
|
1627
|
+
if (this.#source instanceof AluExp) exp$2 = accessorAluExp(this.#source, this.#st, indices);
|
|
1368
1628
|
else {
|
|
1369
1629
|
inputs.push(this.#source);
|
|
1370
1630
|
exp$2 = accessorGlobal(this.#dtype, 0, this.#st, indices);
|
|
@@ -1389,7 +1649,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1389
1649
|
this.#check();
|
|
1390
1650
|
const indices = unravelAlu(this.#st.shape, AluVar.gidx);
|
|
1391
1651
|
if (this.#source instanceof AluExp) {
|
|
1392
|
-
const exp$2 = accessorAluExp(this.#
|
|
1652
|
+
const exp$2 = accessorAluExp(this.#source, this.#st, indices);
|
|
1393
1653
|
const kernel = new Kernel(0, this.#st.size, exp$2);
|
|
1394
1654
|
const output = this.#backend.malloc(kernel.bytes);
|
|
1395
1655
|
const pendingItem = new PendingExecute(this.#backend, kernel, [], [output]);
|
|
@@ -1427,42 +1687,51 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1427
1687
|
}
|
|
1428
1688
|
/** Realize the array and return it as data. */
|
|
1429
1689
|
async data() {
|
|
1430
|
-
if (this.#source instanceof AluExp &&
|
|
1690
|
+
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1431
1691
|
this.#realize();
|
|
1432
1692
|
const pending = this.#pending;
|
|
1433
1693
|
if (pending) {
|
|
1434
1694
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1435
1695
|
for (const p of pending) p.submit();
|
|
1436
1696
|
}
|
|
1437
|
-
const byteCount = byteWidth(this.#dtype) *
|
|
1697
|
+
const byteCount = byteWidth(this.#dtype) * this.size;
|
|
1438
1698
|
const buf = await this.#backend.read(this.#source, 0, byteCount);
|
|
1439
1699
|
this.dispose();
|
|
1440
1700
|
return dtypedArray(this.dtype, buf);
|
|
1441
1701
|
}
|
|
1442
|
-
/**
|
|
1702
|
+
/**
|
|
1703
|
+
* Wait for this array to finish evaluation.
|
|
1704
|
+
*
|
|
1705
|
+
* Operations and data loading in jax-js are lazy, so this function ensures
|
|
1706
|
+
* that pending operations are dispatched and fully executed before it
|
|
1707
|
+
* returns.
|
|
1708
|
+
*
|
|
1709
|
+
* Reading the array with `data()` or `dataSync()` also triggers dispatch
|
|
1710
|
+
* of any pending operations.
|
|
1711
|
+
*/
|
|
1443
1712
|
async wait() {
|
|
1444
1713
|
this.#check();
|
|
1445
|
-
if (this.#source instanceof AluExp) return;
|
|
1714
|
+
if (this.#source instanceof AluExp) return this;
|
|
1446
1715
|
const pending = this.#pending;
|
|
1447
1716
|
if (pending) {
|
|
1448
1717
|
await Promise.all(pending.map((p) => p.prepare()));
|
|
1449
1718
|
for (const p of pending) p.submit();
|
|
1450
1719
|
}
|
|
1451
1720
|
await this.#backend.read(this.#source, 0, 0);
|
|
1452
|
-
this
|
|
1721
|
+
return this;
|
|
1453
1722
|
}
|
|
1454
1723
|
/**
|
|
1455
1724
|
* Realize the array and return it as data. This is a sync variant and not
|
|
1456
1725
|
* recommended for performance reasons, as it will block rendering.
|
|
1457
1726
|
*/
|
|
1458
1727
|
dataSync() {
|
|
1459
|
-
if (this.#source instanceof AluExp &&
|
|
1728
|
+
if (this.#source instanceof AluExp && this.size < inlineArrayLimit && this.device !== "cpu") return this.#dataInline();
|
|
1460
1729
|
this.#realize();
|
|
1461
1730
|
for (const p of this.#pending) {
|
|
1462
1731
|
p.prepareSync();
|
|
1463
1732
|
p.submit();
|
|
1464
1733
|
}
|
|
1465
|
-
const byteCount = byteWidth(this.#dtype) *
|
|
1734
|
+
const byteCount = byteWidth(this.#dtype) * this.size;
|
|
1466
1735
|
const buf = this.#backend.readSync(this.#source, 0, byteCount);
|
|
1467
1736
|
this.dispose();
|
|
1468
1737
|
return dtypedArray(this.dtype, buf);
|
|
@@ -1483,6 +1752,16 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1483
1752
|
async jsAsync() {
|
|
1484
1753
|
return dataToJs(this.dtype, await this.data(), this.shape);
|
|
1485
1754
|
}
|
|
1755
|
+
/**
|
|
1756
|
+
* Copy an element of an array to a numeric scalar and return it.
|
|
1757
|
+
*
|
|
1758
|
+
* Throws an error if the array does not have a single element. The array must
|
|
1759
|
+
* either be rank-0, or all dimensions of the shape are 1.
|
|
1760
|
+
*/
|
|
1761
|
+
item() {
|
|
1762
|
+
if (this.size !== 1) throw new Error(`item() can only be called on arrays of size 1`);
|
|
1763
|
+
return this.dataSync()[0];
|
|
1764
|
+
}
|
|
1486
1765
|
/** @private Internal plumbing method for Array / Tracer ops. */
|
|
1487
1766
|
static _implRules() {
|
|
1488
1767
|
return {
|
|
@@ -1496,7 +1775,7 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1496
1775
|
return [x.#binary(AluOp.Idiv, y)];
|
|
1497
1776
|
},
|
|
1498
1777
|
[Primitive.Neg]([x]) {
|
|
1499
|
-
return [zerosLike(x).#binary(AluOp.Sub, x)];
|
|
1778
|
+
return [zerosLike(x.ref).#binary(AluOp.Sub, x)];
|
|
1500
1779
|
},
|
|
1501
1780
|
[Primitive.Reciprocal]([x]) {
|
|
1502
1781
|
return [x.#unary(AluOp.Reciprocal)];
|
|
@@ -1552,6 +1831,9 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1552
1831
|
[Primitive.Log]([x]) {
|
|
1553
1832
|
return [x.#unary(AluOp.Log)];
|
|
1554
1833
|
},
|
|
1834
|
+
[Primitive.Sqrt]([x]) {
|
|
1835
|
+
return [x.#unary(AluOp.Sqrt)];
|
|
1836
|
+
},
|
|
1555
1837
|
[Primitive.Min]([x, y]) {
|
|
1556
1838
|
return [x.#binary(AluOp.Min, y)];
|
|
1557
1839
|
},
|
|
@@ -1562,9 +1844,24 @@ var Array$1 = class Array$1 extends Tracer {
|
|
|
1562
1844
|
if (axis.length === 0) return [x];
|
|
1563
1845
|
return [x.#moveAxesDown(axis).#reduce(op)];
|
|
1564
1846
|
},
|
|
1847
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
1848
|
+
const st = pool(x.#st, window, strides);
|
|
1849
|
+
return [x.#reshape(st)];
|
|
1850
|
+
},
|
|
1851
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
1852
|
+
const n = inShape.length;
|
|
1853
|
+
let st = poolTranspose(x.#st, inShape, window, strides);
|
|
1854
|
+
st = st.reshape([...st.shape.slice(0, n), prod(st.shape.slice(n))]);
|
|
1855
|
+
return [x.#reshape(st).#reduce(AluOp.Add)];
|
|
1856
|
+
},
|
|
1565
1857
|
[Primitive.Dot]([x, y]) {
|
|
1566
1858
|
return [Array$1.#naryCustom("dot", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x, y], { reduceAxis: true })];
|
|
1567
1859
|
},
|
|
1860
|
+
[Primitive.Conv]([x, y], params) {
|
|
1861
|
+
checkConvShape(x.shape, y.shape, params);
|
|
1862
|
+
const [stX, stY] = prepareConv(x.#st, y.#st, params);
|
|
1863
|
+
return [Array$1.#naryCustom("conv", ([x$1, y$1]) => AluExp.mul(x$1, y$1), [x.#reshape(stX), y.#reshape(stY)], { reduceAxis: true })];
|
|
1864
|
+
},
|
|
1568
1865
|
[Primitive.Compare]([x, y], { op }) {
|
|
1569
1866
|
const custom = ([x$1, y$1]) => aluCompare(x$1, y$1, op);
|
|
1570
1867
|
return [Array$1.#naryCustom("compare", custom, [x, y], { dtypeOutput: DType.Bool })];
|
|
@@ -1629,6 +1926,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1629
1926
|
dtype ??= DType.Float32;
|
|
1630
1927
|
if (![
|
|
1631
1928
|
DType.Float32,
|
|
1929
|
+
DType.Float16,
|
|
1632
1930
|
DType.Int32,
|
|
1633
1931
|
DType.Uint32
|
|
1634
1932
|
].includes(dtype)) throw new TypeError(`Mismatched dtype for scalar ${value}`);
|
|
@@ -1636,6 +1934,7 @@ function scalar(value, { dtype, device } = {}) {
|
|
|
1636
1934
|
dtype ??= DType.Bool;
|
|
1637
1935
|
if (![
|
|
1638
1936
|
DType.Float32,
|
|
1937
|
+
DType.Float16,
|
|
1639
1938
|
DType.Int32,
|
|
1640
1939
|
DType.Uint32,
|
|
1641
1940
|
DType.Bool
|
|
@@ -1649,7 +1948,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1649
1948
|
if (shape$1 && !deepEqual(values.shape, shape$1)) values = values.reshape(shape$1);
|
|
1650
1949
|
if (dtype && values.dtype !== dtype) throw new Error("array astype not implemented yet");
|
|
1651
1950
|
return values;
|
|
1652
|
-
} else if (values
|
|
1951
|
+
} else if (ArrayBuffer.isView(values)) return arrayFromData(values, shape$1 ?? [values.length], {
|
|
1653
1952
|
dtype,
|
|
1654
1953
|
device
|
|
1655
1954
|
});
|
|
@@ -1678,7 +1977,7 @@ function array(values, { shape: shape$1, dtype, device } = {}) {
|
|
|
1678
1977
|
});
|
|
1679
1978
|
} else {
|
|
1680
1979
|
dtype = dtype ?? DType.Float32;
|
|
1681
|
-
const data =
|
|
1980
|
+
const data = dtypedJsArray(dtype, flat);
|
|
1682
1981
|
return arrayFromData(data, shape$1, {
|
|
1683
1982
|
dtype,
|
|
1684
1983
|
device
|
|
@@ -1699,19 +1998,24 @@ function arrayFromData(data, shape$1, { dtype, device } = {}) {
|
|
|
1699
1998
|
});
|
|
1700
1999
|
}
|
|
1701
2000
|
const backend = getBackend(device);
|
|
1702
|
-
if (data
|
|
1703
|
-
|
|
1704
|
-
|
|
1705
|
-
|
|
1706
|
-
|
|
1707
|
-
if (
|
|
1708
|
-
|
|
1709
|
-
|
|
1710
|
-
|
|
1711
|
-
|
|
1712
|
-
|
|
1713
|
-
|
|
1714
|
-
|
|
2001
|
+
if (ArrayBuffer.isView(data)) {
|
|
2002
|
+
const buf = new Uint8Array(data.buffer, data.byteOffset, data.byteLength);
|
|
2003
|
+
if (data instanceof Float32Array) {
|
|
2004
|
+
if (dtype && dtype !== DType.Float32) throw new Error("Float32Array must have float32 type");
|
|
2005
|
+
dtype ??= DType.Float32;
|
|
2006
|
+
} else if (data instanceof Int32Array) {
|
|
2007
|
+
if (dtype && dtype !== DType.Int32 && dtype !== DType.Bool) throw new Error("Int32Array must have int32 or bool type");
|
|
2008
|
+
dtype ??= DType.Int32;
|
|
2009
|
+
} else if (data instanceof Uint32Array) {
|
|
2010
|
+
if (dtype && dtype !== DType.Uint32) throw new Error("Uint32Array must have uint32 type");
|
|
2011
|
+
dtype ??= DType.Uint32;
|
|
2012
|
+
} else if (data instanceof Float16Array) {
|
|
2013
|
+
if (dtype && dtype !== DType.Float16) throw new Error("Float16Array must have float16 type");
|
|
2014
|
+
dtype ??= DType.Float16;
|
|
2015
|
+
} else throw new Error("Unsupported data array type: " + data.constructor.name);
|
|
2016
|
+
const slot = backend.malloc(data.byteLength, buf);
|
|
2017
|
+
return new Array$1(slot, ShapeTracker.fromShape(shape$1), dtype, backend);
|
|
2018
|
+
} else throw new Error("Unsupported data type: " + data.constructor.name);
|
|
1715
2019
|
}
|
|
1716
2020
|
function dataToJs(dtype, data, shape$1) {
|
|
1717
2021
|
if (shape$1.length === 0) return dtype === DType.Bool ? Boolean(data[0]) : data[0];
|
|
@@ -1738,9 +2042,20 @@ var EvalTrace = class extends Trace {
|
|
|
1738
2042
|
};
|
|
1739
2043
|
const baseArrayTrace = new EvalTrace(newMain(EvalTrace, null));
|
|
1740
2044
|
const implRules = Array$1._implRules();
|
|
1741
|
-
function zerosLike(val) {
|
|
2045
|
+
/**
 * Build a zero-filled array with the shape of `val`.
 *
 * The dtype is `dtype` when given, otherwise the dtype of `val`'s abstract
 * value. If `val` is a `Tracer`, `dispose()` is called on it, so callers that
 * still need `val` pass `val.ref` (as other call sites in this file do).
 */
function zerosLike(val, dtype) {
  const template = getAval(val);
  if (val instanceof Tracer) val.dispose();
  const shape = template.shape;
  const resolvedDtype = dtype ?? template.dtype;
  return zeros(shape, { dtype: resolvedDtype });
}
|
|
2050
|
+
/**
 * Build a one-filled array with the shape of `val`.
 *
 * The dtype is `dtype` when given, otherwise the dtype of `val`'s abstract
 * value. If `val` is a `Tracer`, `dispose()` is called on it.
 */
function onesLike(val, dtype) {
  const template = getAval(val);
  if (val instanceof Tracer) val.dispose();
  const shape = template.shape;
  const resolvedDtype = dtype ?? template.dtype;
  return ones(shape, { dtype: resolvedDtype });
}
|
|
2055
|
+
/**
 * Build an array filled with `fillValue` that has the shape of `val`.
 *
 * The dtype is `dtype` when given, otherwise the dtype of `val`'s abstract
 * value. If `val` is a `Tracer`, `dispose()` is called on it.
 */
function fullLike(val, fillValue, dtype) {
  const template = getAval(val);
  if (val instanceof Tracer) val.dispose();
  const shape = template.shape;
  const resolvedDtype = dtype ?? template.dtype;
  return full(shape, fillValue, { dtype: resolvedDtype });
}
|
|
1745
2060
|
/** Return a new array of given shape and type, filled with zeros. */
|
|
1746
2061
|
function zeros(shape$1, { dtype, device } = {}) {
|
|
@@ -1762,6 +2077,9 @@ function full(shape$1, fillValue, { dtype, device } = {}) {
|
|
|
1762
2077
|
if (typeof fillValue === "number") {
|
|
1763
2078
|
dtype = dtype ?? DType.Float32;
|
|
1764
2079
|
source = AluExp.const(dtype, fillValue);
|
|
2080
|
+
} else if (typeof fillValue === "bigint") {
|
|
2081
|
+
dtype = dtype ?? DType.Int32;
|
|
2082
|
+
source = AluExp.const(dtype, Number(fillValue));
|
|
1765
2083
|
} else if (typeof fillValue === "boolean") {
|
|
1766
2084
|
dtype = dtype ?? DType.Bool;
|
|
1767
2085
|
source = AluExp.const(dtype, fillValue ? 1 : 0);
|
|
@@ -1859,7 +2177,6 @@ function linspace(start, stop, num = 50, endpoint = true, { dtype, device } = {}
|
|
|
1859
2177
|
const st = ShapeTracker.fromShape([num]);
|
|
1860
2178
|
return new Array$1(exp$2, st, dtype, getBackend(device));
|
|
1861
2179
|
}
|
|
1862
|
-
/** Translate a `CompareOp` into an `AluExp` on two sub-expressions. */
|
|
1863
2180
|
function aluCompare(a, b, op) {
|
|
1864
2181
|
switch (op) {
|
|
1865
2182
|
case CompareOp.Greater: return AluExp.mul(AluExp.cmpne(a, b), AluExp.cmplt(a, b).not());
|
|
@@ -1901,7 +2218,7 @@ function generalBroadcast(a, b) {
|
|
|
1901
2218
|
}
|
|
1902
2219
|
|
|
1903
2220
|
//#endregion
|
|
1904
|
-
//#region node_modules/.pnpm/@oxc-project+runtime@0.
|
|
2221
|
+
//#region node_modules/.pnpm/@oxc-project+runtime@0.78.0/node_modules/@oxc-project/runtime/src/helpers/esm/usingCtx.js
|
|
1905
2222
|
function _usingCtx() {
|
|
1906
2223
|
var r = "function" == typeof SuppressedError ? SuppressedError : function(r$1, e$2) {
|
|
1907
2224
|
var n$1 = Error();
|
|
@@ -1969,7 +2286,7 @@ var Var = class Var {
|
|
|
1969
2286
|
this.aval = aval;
|
|
1970
2287
|
}
|
|
1971
2288
|
toString() {
|
|
1972
|
-
return `Var(${this.id}):${this.aval.
|
|
2289
|
+
return `Var(${this.id}):${this.aval.toString()}`;
|
|
1973
2290
|
}
|
|
1974
2291
|
};
|
|
1975
2292
|
/** Literal in a Jaxpr expression. Currently, only scalars are supported. */
|
|
@@ -2009,7 +2326,7 @@ var VarPrinter = class {
|
|
|
2009
2326
|
return name;
|
|
2010
2327
|
}
|
|
2011
2328
|
nameType(v) {
|
|
2012
|
-
return `${this.name(v)}:${v.aval.
|
|
2329
|
+
return `${this.name(v)}:${v.aval.toString()}`;
|
|
2013
2330
|
}
|
|
2014
2331
|
};
|
|
2015
2332
|
/** A single statement / binding in a Jaxpr, in ANF form. */
|
|
@@ -2164,8 +2481,8 @@ var JaxprType = class {
|
|
|
2164
2481
|
this.outTypes = outTypes;
|
|
2165
2482
|
}
|
|
2166
2483
|
toString() {
|
|
2167
|
-
const inTypes = this.inTypes.map((aval) => aval.
|
|
2168
|
-
const outTypes = this.outTypes.map((aval) => aval.
|
|
2484
|
+
const inTypes = this.inTypes.map((aval) => aval.toString()).join(", ");
|
|
2485
|
+
const outTypes = this.outTypes.map((aval) => aval.toString()).join(", ");
|
|
2169
2486
|
return `(${inTypes}) -> (${outTypes})`;
|
|
2170
2487
|
}
|
|
2171
2488
|
};
|
|
@@ -2244,7 +2561,7 @@ var JaxprTracer = class extends Tracer {
|
|
|
2244
2561
|
this.aval = aval;
|
|
2245
2562
|
}
|
|
2246
2563
|
toString() {
|
|
2247
|
-
return `JaxprTracer(${this.aval.
|
|
2564
|
+
return `JaxprTracer(${this.aval.toString()})`;
|
|
2248
2565
|
}
|
|
2249
2566
|
get ref() {
|
|
2250
2567
|
return this;
|
|
@@ -2383,6 +2700,7 @@ const abstractEvalRules = {
|
|
|
2383
2700
|
[Primitive.Cos]: vectorizedUnopAbstractEval,
|
|
2384
2701
|
[Primitive.Exp]: vectorizedUnopAbstractEval,
|
|
2385
2702
|
[Primitive.Log]: vectorizedUnopAbstractEval,
|
|
2703
|
+
[Primitive.Sqrt]: vectorizedUnopAbstractEval,
|
|
2386
2704
|
[Primitive.Min]: binopAbstractEval,
|
|
2387
2705
|
[Primitive.Max]: binopAbstractEval,
|
|
2388
2706
|
[Primitive.Reduce]([x], { axis }) {
|
|
@@ -2390,6 +2708,15 @@ const abstractEvalRules = {
|
|
|
2390
2708
|
const newShape = x.shape.filter((_, i) => !axisSet.has(i));
|
|
2391
2709
|
return [new ShapedArray(newShape, x.dtype)];
|
|
2392
2710
|
},
|
|
2711
|
+
[Primitive.Pool]([x], { window, strides }) {
|
|
2712
|
+
const shape$1 = checkPoolShape(x.shape, window, strides);
|
|
2713
|
+
return [new ShapedArray(shape$1, x.dtype)];
|
|
2714
|
+
},
|
|
2715
|
+
[Primitive.PoolTranspose]([x], { inShape, window, strides }) {
|
|
2716
|
+
const shape$1 = checkPoolShape(inShape, window, strides);
|
|
2717
|
+
if (!deepEqual(shape$1, x.shape)) throw new TypeError(`PoolTranspose shape mismatch: expected ${JSON.stringify(shape$1)}, got ${JSON.stringify(x.shape)}`);
|
|
2718
|
+
return [new ShapedArray(inShape, x.dtype)];
|
|
2719
|
+
},
|
|
2393
2720
|
[Primitive.Dot]([x, y]) {
|
|
2394
2721
|
if (x.dtype !== y.dtype) throw new TypeError(`Dot dtype mismatch, got ${x.dtype} vs ${y.dtype}`);
|
|
2395
2722
|
if (x.ndim === 0 && y.ndim === 0) throw new TypeError("Dot requires at least 1D inputs");
|
|
@@ -2397,6 +2724,11 @@ const abstractEvalRules = {
|
|
|
2397
2724
|
shape$1.splice(-1, 1);
|
|
2398
2725
|
return [new ShapedArray(shape$1, x.dtype)];
|
|
2399
2726
|
},
|
|
2727
|
+
[Primitive.Conv]([lhs, rhs], params) {
|
|
2728
|
+
if (lhs.dtype !== rhs.dtype) throw new TypeError(`Conv dtype mismatch, got ${lhs.dtype} vs ${rhs.dtype}`);
|
|
2729
|
+
const shape$1 = checkConvShape(lhs.shape, rhs.shape, params);
|
|
2730
|
+
return [new ShapedArray(shape$1, lhs.dtype)];
|
|
2731
|
+
},
|
|
2400
2732
|
[Primitive.Compare]: compareAbstractEval,
|
|
2401
2733
|
[Primitive.Where]([cond, x, y]) {
|
|
2402
2734
|
if (cond.dtype !== DType.Bool) throw new TypeError(`Condition must be boolean, got ${cond.dtype}`);
|
|
@@ -2444,15 +2776,34 @@ const abstractEvalRules = {
|
|
|
2444
2776
|
return outTypes;
|
|
2445
2777
|
}
|
|
2446
2778
|
};
|
|
2447
|
-
function
|
|
2779
|
+
/**
 * Partition `values` by position.
 *
 * @param values Array to split.
 * @param argnums Set-like object; only `argnums.has(index)` is used.
 * @returns `[selected, rest]`: elements whose index is in `argnums`, then all
 *   remaining elements. Relative order is preserved within each list.
 */
function splitIdx(values, argnums) {
  const selected = [];
  const rest = [];
  for (const [index, value] of values.entries()) {
    const bucket = argnums.has(index) ? selected : rest;
    bucket.push(value);
  }
  return [selected, rest];
}
|
|
2786
|
+
/**
 * Inverse of `splitIdx`: interleave two lists back into one of length `n`.
 *
 * @param n Length of the result.
 * @param a Values for the positions contained in `argnums`, in order.
 * @param b Values for all other positions, in order.
 * @param argnums Set-like object; only `argnums.has(index)` is used.
 * @returns Array of length `n`. A position whose source list is exhausted is
 *   filled with `undefined`.
 */
function joinIdx(n, a, b, argnums) {
  const merged = [];
  const cursor = { a: 0, b: 0 };
  for (let position = 0; position < n; position += 1) {
    if (argnums.has(position)) {
      merged.push(a[cursor.a]);
      cursor.a += 1;
    } else {
      merged.push(b[cursor.b]);
      cursor.b += 1;
    }
  }
  return merged;
}
|
|
2794
|
+
function makeJaxpr$1(f, opts) {
|
|
2448
2795
|
return (...argsIn) => {
|
|
2449
2796
|
try {
|
|
2450
2797
|
var _usingCtx$1 = _usingCtx();
|
|
2451
|
-
const
|
|
2452
|
-
const [
|
|
2798
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2799
|
+
const [staticArgs, shapedArgs] = splitIdx(argsIn, staticArgnums);
|
|
2800
|
+
const [avalsIn, inTree] = flatten(shapedArgs);
|
|
2801
|
+
const [fFlat, outTree] = flattenFun((...dynamicArgs) => {
|
|
2802
|
+
return f(...joinIdx(argsIn.length, staticArgs, dynamicArgs, staticArgnums));
|
|
2803
|
+
}, inTree);
|
|
2453
2804
|
const builder = new JaxprBuilder();
|
|
2454
2805
|
const main = _usingCtx$1.u(newMain(JaxprTrace, builder));
|
|
2455
|
-
|
|
2806
|
+
_usingCtx$1.u(newDynamic(main));
|
|
2456
2807
|
const trace = new JaxprTrace(main);
|
|
2457
2808
|
const tracersIn = avalsIn.map((aval) => trace.newArg(typeof aval === "object" ? aval : pureArray(aval)));
|
|
2458
2809
|
const outs = fFlat(...tracersIn);
|
|
@@ -2471,14 +2822,17 @@ function makeJaxpr$1(f) {
|
|
|
2471
2822
|
}
|
|
2472
2823
|
};
|
|
2473
2824
|
}
|
|
2474
|
-
function jit$1(f) {
|
|
2825
|
+
function jit$1(f, opts) {
|
|
2475
2826
|
const cache = /* @__PURE__ */ new Map();
|
|
2827
|
+
const staticArgnums = new Set(opts?.staticArgnums ?? []);
|
|
2476
2828
|
return ((...args) => {
|
|
2477
|
-
const [
|
|
2829
|
+
const [staticArgs, dynamicArgs] = splitIdx(args, staticArgnums);
|
|
2830
|
+
const [argsFlat, inTree] = flatten(dynamicArgs);
|
|
2478
2831
|
const avalsInFlat = argsFlat.map((x) => ShapedArray.fromAval(getAval(x)));
|
|
2479
2832
|
const avalsIn = unflatten(inTree, avalsInFlat);
|
|
2480
|
-
const
|
|
2481
|
-
const
|
|
2833
|
+
const jaxprArgs = joinIdx(args.length, staticArgs, avalsIn, staticArgnums);
|
|
2834
|
+
const cacheKey = JSON.stringify(jaxprArgs);
|
|
2835
|
+
const { jaxpr, consts, treedef: outTree } = runWithCache(cache, cacheKey, () => makeJaxpr$1(f, opts)(...jaxprArgs));
|
|
2482
2836
|
const outs = bind(Primitive.JitCall, [...consts.map((c) => c.ref), ...argsFlat], {
|
|
2483
2837
|
jaxpr,
|
|
2484
2838
|
numConsts: consts.length
|
|
@@ -2515,7 +2869,7 @@ var JVPTrace = class extends Trace {
|
|
|
2515
2869
|
return this.lift(pureArray(val));
|
|
2516
2870
|
}
|
|
2517
2871
|
lift(val) {
|
|
2518
|
-
return new JVPTracer(this, val, zerosLike(val));
|
|
2872
|
+
return new JVPTracer(this, val, zerosLike(val.ref));
|
|
2519
2873
|
}
|
|
2520
2874
|
processPrimitive(primitive, tracers, params) {
|
|
2521
2875
|
const [primalsIn, tangentsIn] = unzip2(tracers.map((x) => [x.primal, x.tangent]));
|
|
@@ -2533,19 +2887,25 @@ function linearTangentsJvp(primitive) {
|
|
|
2533
2887
|
return [ys, dys];
|
|
2534
2888
|
};
|
|
2535
2889
|
}
|
|
2890
|
+
/**
 * JVP rule factory for bilinear primitives (product rule).
 *
 * For a two-input primitive `f`, the returned rule computes
 *   primal  = f(x, y)
 *   tangent = f(x, dy) + f(dx, y)
 * forwarding `params` unchanged to every application.
 *
 * The primal uses `x.ref` / `y.ref`, while `x` and `y` themselves are passed
 * to the tangent terms.
 */
function bilinearTangentsJvp(primitive) {
  return ([x, y], [dx, dy], params) => {
    const primalOut = bind1(primitive, [x.ref, y.ref], params);
    const wrtSecond = bind1(primitive, [x, dy], params);
    const wrtFirst = bind1(primitive, [dx, y], params);
    const tangentOut = wrtSecond.add(wrtFirst);
    return [[primalOut], [tangentOut]];
  };
}
|
|
2536
2898
|
/**
 * JVP rule factory for primitives whose outputs carry no tangent.
 *
 * The returned rule disposes every incoming tangent, applies `primitive` to
 * the primals, and pairs each output with a zero tangent built from a `.ref`
 * of that output.
 */
function zeroTangentsJvp(primitive) {
  return (primals, tangents, params) => {
    for (const tangent of tangents) tangent.dispose();
    const outputs = bind(primitive, primals, params);
    const zeroTangents = outputs.map((out) => zerosLike(out.ref));
    return [outputs, zeroTangents];
  };
}
|
|
2544
2906
|
const jvpRules = {
|
|
2545
2907
|
[Primitive.Add]: linearTangentsJvp(Primitive.Add),
|
|
2546
|
-
[Primitive.Mul](
|
|
2547
|
-
return [[x.ref.mul(y.ref)], [x.mul(dy).add(dx.mul(y))]];
|
|
2548
|
-
},
|
|
2908
|
+
[Primitive.Mul]: bilinearTangentsJvp(Primitive.Mul),
|
|
2549
2909
|
[Primitive.Idiv]: zeroTangentsJvp(Primitive.Idiv),
|
|
2550
2910
|
[Primitive.Neg]: linearTangentsJvp(Primitive.Neg),
|
|
2551
2911
|
[Primitive.Reciprocal]([x], [dx]) {
|
|
@@ -2558,13 +2918,13 @@ const jvpRules = {
|
|
|
2558
2918
|
if (isFloatDtype(dtype) && isFloatDtype(x.dtype)) return [[cast(x, dtype)], [cast(dx, dtype)]];
|
|
2559
2919
|
else {
|
|
2560
2920
|
dx.dispose();
|
|
2561
|
-
return [[cast(x, dtype)], [zerosLike(x)]];
|
|
2921
|
+
return [[cast(x.ref, dtype)], [zerosLike(x)]];
|
|
2562
2922
|
}
|
|
2563
2923
|
},
|
|
2564
2924
|
[Primitive.Bitcast]([x], [dx], { dtype }) {
|
|
2565
2925
|
if (x.dtype === dtype) return [[x], [dx]];
|
|
2566
2926
|
dx.dispose();
|
|
2567
|
-
return [[bitcast(x, dtype)], [zerosLike(x)]];
|
|
2927
|
+
return [[bitcast(x.ref, dtype)], [zerosLike(x)]];
|
|
2568
2928
|
},
|
|
2569
2929
|
[Primitive.RandomBits]: zeroTangentsJvp(Primitive.RandomBits),
|
|
2570
2930
|
[Primitive.Sin]([x], [dx]) {
|
|
@@ -2580,6 +2940,10 @@ const jvpRules = {
|
|
|
2580
2940
|
[Primitive.Log]([x], [dx]) {
|
|
2581
2941
|
return [[log$1(x.ref)], [reciprocal$1(x).mul(dx)]];
|
|
2582
2942
|
},
|
|
2943
|
+
[Primitive.Sqrt]([x], [dx]) {
|
|
2944
|
+
const z = sqrt$1(x);
|
|
2945
|
+
return [[z.ref], [reciprocal$1(z.mul(2)).mul(dx)]];
|
|
2946
|
+
},
|
|
2583
2947
|
[Primitive.Min]([x, y], [dx, dy]) {
|
|
2584
2948
|
return [[min$1(x.ref, y.ref)], [where$1(less$1(y, x), dy, dx)]];
|
|
2585
2949
|
},
|
|
@@ -2596,13 +2960,14 @@ const jvpRules = {
|
|
|
2596
2960
|
const primal = reduce(x.ref, op, axis);
|
|
2597
2961
|
const notMin = notEqual$1(x, broadcast(primal.ref, x.shape, axis));
|
|
2598
2962
|
const minCount = where$1(notMin.ref, 0, 1).sum(axis);
|
|
2599
|
-
const tangent = where$1(notMin,
|
|
2963
|
+
const tangent = where$1(notMin, 0, dx).sum(axis).div(minCount);
|
|
2600
2964
|
return [[primal], [tangent]];
|
|
2601
2965
|
} else throw new Error(`JVP rule not implemented for reduce op: ${op}`);
|
|
2602
2966
|
},
|
|
2603
|
-
[Primitive.
|
|
2604
|
-
|
|
2605
|
-
|
|
2967
|
+
[Primitive.Pool]: linearTangentsJvp(Primitive.Pool),
|
|
2968
|
+
[Primitive.PoolTranspose]: linearTangentsJvp(Primitive.PoolTranspose),
|
|
2969
|
+
[Primitive.Dot]: bilinearTangentsJvp(Primitive.Dot),
|
|
2970
|
+
[Primitive.Conv]: bilinearTangentsJvp(Primitive.Conv),
|
|
2606
2971
|
[Primitive.Compare]: zeroTangentsJvp(Primitive.Compare),
|
|
2607
2972
|
[Primitive.Where]([cond, x, y], [dcond, dx, dy]) {
|
|
2608
2973
|
dcond.dispose();
|
|
@@ -2778,6 +3143,7 @@ const vmapRules = {
|
|
|
2778
3143
|
[Primitive.Cos]: unopBatcher(cos$1),
|
|
2779
3144
|
[Primitive.Exp]: unopBatcher(exp$1),
|
|
2780
3145
|
[Primitive.Log]: unopBatcher(log$1),
|
|
3146
|
+
[Primitive.Sqrt]: unopBatcher(sqrt$1),
|
|
2781
3147
|
[Primitive.Min]: broadcastBatcher(min$1),
|
|
2782
3148
|
[Primitive.Max]: broadcastBatcher(max$1),
|
|
2783
3149
|
[Primitive.Reduce](axisSize, [x], [xBdim], { op, axis }) {
|
|
@@ -2914,7 +3280,7 @@ var PartialVal = class PartialVal {
|
|
|
2914
3280
|
return this.val !== null;
|
|
2915
3281
|
}
|
|
2916
3282
|
toString() {
|
|
2917
|
-
return this.val ? this.val.toString() : this.aval.
|
|
3283
|
+
return this.val ? this.val.toString() : this.aval.toString();
|
|
2918
3284
|
}
|
|
2919
3285
|
};
|
|
2920
3286
|
function partialEvalFlat(f, pvalsIn) {
|
|
@@ -3288,12 +3654,72 @@ const transposeRules = {
|
|
|
3288
3654
|
if (op === AluOp.Add) return [broadcast(ct, x.aval.shape, axis)];
|
|
3289
3655
|
else throw new NonlinearError(Primitive.Reduce);
|
|
3290
3656
|
},
|
|
3657
|
+
[Primitive.Pool]([ct], [x], { window, strides }) {
|
|
3658
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.Pool);
|
|
3659
|
+
return bind(Primitive.PoolTranspose, [ct], {
|
|
3660
|
+
inShape: x.aval.shape,
|
|
3661
|
+
window,
|
|
3662
|
+
strides
|
|
3663
|
+
});
|
|
3664
|
+
},
|
|
3665
|
+
[Primitive.PoolTranspose]([ct], [x], { window, strides }) {
|
|
3666
|
+
if (!(x instanceof UndefPrimal)) throw new NonlinearError(Primitive.PoolTranspose);
|
|
3667
|
+
return bind(Primitive.Pool, [ct], {
|
|
3668
|
+
window,
|
|
3669
|
+
strides
|
|
3670
|
+
});
|
|
3671
|
+
},
|
|
3291
3672
|
[Primitive.Dot]([ct], [x, y]) {
|
|
3292
3673
|
if (x instanceof UndefPrimal === y instanceof UndefPrimal) throw new NonlinearError(Primitive.Dot);
|
|
3293
3674
|
const axisSize = generalBroadcast(x.aval.shape, y.aval.shape).slice(-1)[0];
|
|
3294
3675
|
ct = broadcast(ct, ct.shape.concat(axisSize), [-1]);
|
|
3295
3676
|
return [x instanceof UndefPrimal ? unbroadcast(mul(ct, y), x) : null, y instanceof UndefPrimal ? unbroadcast(mul(x, ct), y) : null];
|
|
3296
3677
|
},
|
|
3678
|
+
[Primitive.Conv]([ct], [lhs, rhs], params) {
|
|
3679
|
+
if (lhs instanceof UndefPrimal === rhs instanceof UndefPrimal) throw new NonlinearError(Primitive.Conv);
|
|
3680
|
+
const rev01 = [
|
|
3681
|
+
1,
|
|
3682
|
+
0,
|
|
3683
|
+
...range(2, ct.ndim)
|
|
3684
|
+
];
|
|
3685
|
+
if (lhs instanceof UndefPrimal) {
|
|
3686
|
+
let kernel = rhs;
|
|
3687
|
+
kernel = transpose$1(kernel, rev01);
|
|
3688
|
+
kernel = flip$1(kernel, range(2, kernel.ndim));
|
|
3689
|
+
const result = conv(ct, kernel, {
|
|
3690
|
+
strides: params.lhsDilation,
|
|
3691
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3692
|
+
const dilatedKernel = (kernel.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3693
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3694
|
+
const padBefore = dilatedKernel - 1 - pl;
|
|
3695
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3696
|
+
const padAfter = dilatedLhs + dilatedKernel - 1 - dilatedCt - padBefore;
|
|
3697
|
+
return [padBefore, padAfter];
|
|
3698
|
+
}),
|
|
3699
|
+
lhsDilation: params.strides,
|
|
3700
|
+
rhsDilation: params.rhsDilation
|
|
3701
|
+
});
|
|
3702
|
+
return [result, null];
|
|
3703
|
+
} else {
|
|
3704
|
+
const newLhs = transpose$1(lhs, rev01);
|
|
3705
|
+
const newRhs = transpose$1(ct, rev01);
|
|
3706
|
+
let result = conv(newLhs, newRhs, {
|
|
3707
|
+
strides: params.rhsDilation,
|
|
3708
|
+
padding: params.padding.map(([pl, _pr], i) => {
|
|
3709
|
+
const dilatedLhs = (lhs.aval.shape[i + 2] - 1) * params.lhsDilation[i] + 1;
|
|
3710
|
+
const dilatedKernel = (rhs.aval.shape[i + 2] - 1) * params.rhsDilation[i] + 1;
|
|
3711
|
+
const dilatedCt = (ct.shape[i + 2] - 1) * params.strides[i] + 1;
|
|
3712
|
+
const padFromLhs = dilatedCt - dilatedLhs;
|
|
3713
|
+
const padFromRhs = dilatedKernel - pl - 1;
|
|
3714
|
+
return [pl, padFromLhs + padFromRhs];
|
|
3715
|
+
}),
|
|
3716
|
+
lhsDilation: params.lhsDilation,
|
|
3717
|
+
rhsDilation: params.strides
|
|
3718
|
+
});
|
|
3719
|
+
result = transpose$1(result, rev01);
|
|
3720
|
+
return [null, result];
|
|
3721
|
+
}
|
|
3722
|
+
},
|
|
3297
3723
|
[Primitive.Where]([ct], [cond, x, y]) {
|
|
3298
3724
|
const cts = [
|
|
3299
3725
|
null,
|
|
@@ -3414,8 +3840,8 @@ function valueAndGrad$1(f) {
|
|
|
3414
3840
|
if (x.length === 0) throw new Error("grad requires at least one argument to differentiate");
|
|
3415
3841
|
const [y, fVjp] = vjp$1(f, x[0], ...x.slice(1).map(stopGradient));
|
|
3416
3842
|
if (!(y instanceof Tracer) || ndim$1(y) !== 0) throw new TypeError("grad requires a scalar output");
|
|
3417
|
-
if (y.dtype
|
|
3418
|
-
const [ct, ...rest] = fVjp(
|
|
3843
|
+
if (!isFloatDtype(y.dtype)) throw new TypeError("grad only supports floating-point dtypes");
|
|
3844
|
+
const [ct, ...rest] = fVjp(scalar(1, { dtype: y.dtype }));
|
|
3419
3845
|
for (const r of rest) r.dispose();
|
|
3420
3846
|
return [y, ct];
|
|
3421
3847
|
};
|
|
@@ -3429,6 +3855,73 @@ function jacrev$1(f) {
|
|
|
3429
3855
|
};
|
|
3430
3856
|
}
|
|
3431
3857
|
|
|
3858
|
+
//#endregion
|
|
3859
|
+
//#region src/lax.ts
|
|
3860
|
+
var lax_exports = {};
|
|
3861
|
+
__export(lax_exports, {
|
|
3862
|
+
conv: () => conv$1,
|
|
3863
|
+
convGeneralDilated: () => convGeneralDilated,
|
|
3864
|
+
convWithGeneralPadding: () => convWithGeneralPadding,
|
|
3865
|
+
reduceWindow: () => reduceWindow
|
|
3866
|
+
});
|
|
3867
|
+
/**
 * Convert a string padding mode into explicit `[low, high]` padding pairs,
 * one pair per entry of `inShape`.
 *
 * Modes (case-insensitive):
 * - "VALID": no padding.
 * - "SAME": pad so that the output size is `ceil(size / stride)`; when the
 *   total padding is odd the extra element goes on the high side.
 * - "SAME_LOWER": like "SAME", but the extra element goes on the low side.
 *
 * @param inShape Sizes of the dimensions being padded.
 * @param filterShape Filter size per dimension.
 * @param strides Window stride per dimension.
 * @param dilation Filter dilation per dimension.
 * @param padding Padding mode string.
 * @returns Array of `[low, high]` pairs; each pair is a distinct array.
 * @throws {Error} On an unknown mode, or when a "SAME"-type mode is requested
 *   and `filterShape`, `strides` or `dilation` differ in length from `inShape`.
 */
function padtypeToPads(inShape, filterShape, strides, dilation, padding) {
  const padType = padding.toUpperCase();
  switch (padType) {
    case "VALID":
      // Fresh pair per dimension so callers never share one mutable [0, 0].
      return inShape.map(() => [0, 0]);
    case "SAME":
    case "SAME_LOWER": {
      const rank = inShape.length;
      // Only the SAME paths read these arrays, so only they are validated;
      // mismatched ranks would otherwise yield NaN padding silently.
      if (filterShape.length !== rank || strides.length !== rank || dilation.length !== rank) {
        throw new Error(`Padding rank mismatch: input ${rank}, filter ${filterShape.length}, strides ${strides.length}, dilation ${dilation.length}`);
      }
      return inShape.map((size$1, i) => {
        const outSize = Math.ceil(size$1 / strides[i]);
        const dilatedFilter = (filterShape[i] - 1) * dilation[i] + 1;
        const total = Math.max(0, (outSize - 1) * strides[i] + dilatedFilter - size$1);
        const half = total >> 1;
        return padType === "SAME" ? [half, total - half] : [total - half, half];
      });
    }
    default:
      throw new Error(`Unknown padding type: ${padType}`);
  }
}
|
|
3881
|
+
/**
 * General n-dimensional convolution operator, with optional dilation.
 *
 * The semantics mimic `jax.lax.conv_general_dilated` in JAX, which wraps
 * XLA's general convolution operator. Grouped convolutions are not supported
 * right now.
 *
 * @param lhs Input array; must have at least 2 dimensions.
 * @param rhs Kernel array; must have at least 2 dimensions.
 * @param windowStrides Stride per spatial dimension.
 * @param padding Either explicit padding, or a mode string that is resolved
 *   with `padtypeToPads` from the shapes after the first two axes.
 * @param options Optional `lhsDilation` / `rhsDilation`.
 * @throws {Error} If either operand has fewer than 2 dimensions, or if string
 *   padding is combined with an `lhsDilation` entry other than 1.
 */
function convGeneralDilated(lhs, rhs, windowStrides, padding, { lhsDilation, rhsDilation } = {}) {
  if (lhs.ndim < 2) throw new Error("lhs must have at least 2 dimensions");
  if (rhs.ndim < 2) throw new Error("rhs must have at least 2 dimensions");
  let resolvedPadding = padding;
  if (typeof padding === "string") {
    const hasLhsDilation = lhsDilation != null && lhsDilation.some((d) => d !== 1);
    if (hasLhsDilation) throw new Error("String padding is not supported for transposed convolutions");
    const inSpatial = lhs.shape.slice(2);
    const kernelSpatial = rhs.shape.slice(2);
    // Without an explicit kernel dilation, every spatial dimension uses 1.
    const kernelDilation = rhsDilation ?? rep(rhs.ndim - 2, 1);
    resolvedPadding = padtypeToPads(inSpatial, kernelSpatial, windowStrides, kernelDilation, padding);
  }
  return conv(lhs, rhs, {
    strides: windowStrides,
    padding: resolvedPadding,
    lhsDilation,
    rhsDilation
  });
}
|
|
3903
|
+
/**
 * Convenience wrapper around `convGeneralDilated` that takes the two dilation
 * arguments positionally instead of in an options object.
 */
function convWithGeneralPadding(lhs, rhs, windowStrides, padding, lhsDilation, rhsDilation) {
  const dilations = { lhsDilation, rhsDilation };
  return convGeneralDilated(lhs, rhs, windowStrides, padding, dilations);
}
|
|
3910
|
+
/**
 * Convenience wrapper around `convGeneralDilated` with no dilation options:
 * only strides and padding are forwarded.
 */
function conv$1(lhs, rhs, windowStrides, padding) {
  const result = convGeneralDilated(lhs, rhs, windowStrides, padding);
  return result;
}
|
|
3914
|
+
/**
 * Apply `computation` over windows of `operand`.
 *
 * The operand is passed through the `Pool` primitive with the given window
 * and strides (strides default to 1 for each window dimension). `computation`
 * is wrapped in `vmap` over axis 0 once per operand dimension and then applied
 * to the pooled array.
 * NOTE(review): presumably `Pool` appends the window axes after the operand's
 * own axes, so each mapped call sees a single window; confirm against `pool`.
 *
 * @throws {Error} If `operand` has fewer dimensions than `windowDimensions`.
 */
function reduceWindow(operand, computation, windowDimensions, windowStrides) {
  const rank = operand.ndim;
  const windowRank = windowDimensions.length;
  if (rank < windowRank) throw new Error(`Operand dimensions ${rank} < window ${windowRank}`);
  const strides = windowStrides || rep(windowRank, 1);
  let mapped = computation;
  for (let depth = 0; depth < rank; depth += 1) mapped = vmap$1(mapped, 0);
  const windows = bind1(Primitive.Pool, [operand], {
    window: windowDimensions,
    strides
  });
  return mapped(windows);
}
|
|
3924
|
+
|
|
3432
3925
|
//#endregion
|
|
3433
3926
|
//#region src/numpy.ts
|
|
3434
3927
|
var numpy_exports = {};
|
|
@@ -3447,9 +3940,9 @@ __export(numpy_exports, {
|
|
|
3447
3940
|
bool: () => bool,
|
|
3448
3941
|
clip: () => clip,
|
|
3449
3942
|
columnStack: () => columnStack,
|
|
3450
|
-
complex64: () => complex64,
|
|
3451
3943
|
concatenate: () => concatenate,
|
|
3452
3944
|
cos: () => cos,
|
|
3945
|
+
cosh: () => cosh,
|
|
3453
3946
|
diag: () => diag,
|
|
3454
3947
|
diagonal: () => diagonal,
|
|
3455
3948
|
divide: () => divide,
|
|
@@ -3464,8 +3957,10 @@ __export(numpy_exports, {
|
|
|
3464
3957
|
flip: () => flip,
|
|
3465
3958
|
fliplr: () => fliplr,
|
|
3466
3959
|
flipud: () => flipud,
|
|
3960
|
+
float16: () => float16,
|
|
3467
3961
|
float32: () => float32,
|
|
3468
3962
|
full: () => full,
|
|
3963
|
+
fullLike: () => fullLike$1,
|
|
3469
3964
|
greater: () => greater,
|
|
3470
3965
|
greaterEqual: () => greaterEqual,
|
|
3471
3966
|
hstack: () => hstack,
|
|
@@ -3492,6 +3987,7 @@ __export(numpy_exports, {
|
|
|
3492
3987
|
negative: () => negative,
|
|
3493
3988
|
notEqual: () => notEqual,
|
|
3494
3989
|
ones: () => ones,
|
|
3990
|
+
onesLike: () => onesLike$1,
|
|
3495
3991
|
pad: () => pad,
|
|
3496
3992
|
permuteDims: () => permuteDims,
|
|
3497
3993
|
pi: () => pi,
|
|
@@ -3502,11 +3998,14 @@ __export(numpy_exports, {
|
|
|
3502
3998
|
scalar: () => scalar,
|
|
3503
3999
|
shape: () => shape,
|
|
3504
4000
|
sin: () => sin,
|
|
4001
|
+
sinh: () => sinh,
|
|
3505
4002
|
size: () => size,
|
|
4003
|
+
sqrt: () => sqrt,
|
|
3506
4004
|
square: () => square,
|
|
3507
4005
|
stack: () => stack,
|
|
3508
4006
|
sum: () => sum,
|
|
3509
4007
|
tan: () => tan,
|
|
4008
|
+
tanh: () => tanh,
|
|
3510
4009
|
transpose: () => transpose,
|
|
3511
4010
|
trueDivide: () => trueDivide,
|
|
3512
4011
|
trunc: () => trunc,
|
|
@@ -3515,13 +4014,14 @@ __export(numpy_exports, {
|
|
|
3515
4014
|
vecdot: () => vecdot,
|
|
3516
4015
|
vstack: () => vstack,
|
|
3517
4016
|
where: () => where,
|
|
3518
|
-
zeros: () => zeros
|
|
4017
|
+
zeros: () => zeros,
|
|
4018
|
+
zerosLike: () => zerosLike$1
|
|
3519
4019
|
});
|
|
3520
4020
|
const float32 = DType.Float32;
|
|
3521
4021
|
const int32 = DType.Int32;
|
|
3522
4022
|
const uint32 = DType.Uint32;
|
|
3523
4023
|
const bool = DType.Bool;
|
|
3524
|
-
const
|
|
4024
|
+
const float16 = DType.Float16;
|
|
3525
4025
|
/** Euler's constant, `e = 2.7182818284590...` */
|
|
3526
4026
|
const e = Math.E;
|
|
3527
4027
|
/** Euler-Mascheroni constant, `γ = 0.5772156649...` */
|
|
@@ -3548,6 +4048,8 @@ const cos = cos$1;
|
|
|
3548
4048
|
const exp = exp$1;
|
|
3549
4049
|
/** Calculate the natural logarithm of all elements in the input array. */
|
|
3550
4050
|
const log = log$1;
|
|
4051
|
+
/** Calculate the square root of all elements in the input array. */
|
|
4052
|
+
const sqrt = sqrt$1;
|
|
3551
4053
|
/** Return element-wise minimum of the input arrays. */
|
|
3552
4054
|
const minimum = min$1;
|
|
3553
4055
|
/** Return element-wise maximum of the input arrays. */
|
|
@@ -3589,6 +4091,12 @@ const pad = pad$1;
|
|
|
3589
4091
|
const ndim = ndim$1;
|
|
3590
4092
|
/** Return the shape of an array. Does not consume array reference. */
|
|
3591
4093
|
const shape = getShape;
|
|
4094
|
+
/** Return an array of zeros with the same shape and type as a given array. */
|
|
4095
|
+
const zerosLike$1 = zerosLike;
|
|
4096
|
+
/** Return an array of ones with the same shape and type as a given array. */
|
|
4097
|
+
const onesLike$1 = onesLike;
|
|
4098
|
+
/** Return a full array with the same shape and type as a given array. */
|
|
4099
|
+
const fullLike$1 = fullLike;
|
|
3592
4100
|
/**
|
|
3593
4101
|
* Return the number of elements in an array, optionally along an axis.
|
|
3594
4102
|
* Does not consume array reference.
|
|
@@ -3639,13 +4147,7 @@ function argmin(a, axis, opts) {
|
|
|
3639
4147
|
dtype: int32,
|
|
3640
4148
|
device: a.device
|
|
3641
4149
|
});
|
|
3642
|
-
const idx =
|
|
3643
|
-
dtype: int32,
|
|
3644
|
-
device: a.device
|
|
3645
|
-
}), scalar(0, {
|
|
3646
|
-
dtype: int32,
|
|
3647
|
-
device: a.device
|
|
3648
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4150
|
+
const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3649
4151
|
dtype: int32,
|
|
3650
4152
|
device: a.device
|
|
3651
4153
|
}).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
|
|
@@ -3669,13 +4171,7 @@ function argmax(a, axis, opts) {
|
|
|
3669
4171
|
dtype: int32,
|
|
3670
4172
|
device: a.device
|
|
3671
4173
|
});
|
|
3672
|
-
const idx =
|
|
3673
|
-
dtype: int32,
|
|
3674
|
-
device: a.device
|
|
3675
|
-
}), scalar(0, {
|
|
3676
|
-
dtype: int32,
|
|
3677
|
-
device: a.device
|
|
3678
|
-
})).mul(arange(shape$1[axis], 0, -1, {
|
|
4174
|
+
const idx = isMax.astype(DType.Int32).mul(arange(shape$1[axis], 0, -1, {
|
|
3679
4175
|
dtype: int32,
|
|
3680
4176
|
device: a.device
|
|
3681
4177
|
}).reshape([shape$1[axis], ...rep(shape$1.length - axis - 1, 1)]));
|
|
@@ -3807,9 +4303,11 @@ function ravel(a) {
|
|
|
3807
4303
|
* Return specified diagonals.
|
|
3808
4304
|
*
|
|
3809
4305
|
* If a is 2D, return the diagonal of the array with the given offset. If a is
|
|
3810
|
-
* 3D or higher, compute diagonals along the two given axes.
|
|
4306
|
+
* 3D or higher, compute diagonals along the two given axes (default: 0, 1).
|
|
3811
4307
|
*
|
|
3812
|
-
* This returns a view over the existing array.
|
|
4308
|
+
* This returns a view over the existing array. The shape of the resulting array
|
|
4309
|
+
* is determined by removing the two axes along which the diagonal is taken,
|
|
4310
|
+
* then appending a new axis on the right that holds the diagonals.
|
|
3813
4311
|
*/
|
|
3814
4312
|
function diagonal(a, offset, axis1, axis2) {
|
|
3815
4313
|
return fudgeArray(a).diagonal(offset, axis1, axis2);
|
|
@@ -3825,15 +4323,16 @@ function diag(v, k = 0) {
|
|
|
3825
4323
|
if (!Number.isInteger(k)) throw new TypeError(`k must be an integer, got ${k}`);
|
|
3826
4324
|
if (a.ndim === 1) {
|
|
3827
4325
|
const n = a.shape[0];
|
|
3828
|
-
const ret = where(eye(n).equal(1), a,
|
|
3829
|
-
if (k
|
|
3830
|
-
return ret;
|
|
4326
|
+
const ret = where(eye(n).equal(1), a.ref, zerosLike$1(a));
|
|
4327
|
+
if (k > 0) return pad(ret, [[0, k], [k, 0]]);
|
|
4328
|
+
else if (k < 0) return pad(ret, [[-k, 0], [0, -k]]);
|
|
4329
|
+
else return ret;
|
|
3831
4330
|
} else if (a.ndim === 2) return diagonal(a, k);
|
|
3832
4331
|
else throw new TypeError("numpy.diag only supports 1D and 2D arrays");
|
|
3833
4332
|
}
|
|
3834
4333
|
/** Return if two arrays are element-wise equal within a tolerance. */
|
|
3835
4334
|
function allclose(actual, expected, options) {
|
|
3836
|
-
const { rtol = 1e-5, atol = 1e-
|
|
4335
|
+
const { rtol = 1e-5, atol = 1e-7 } = options ?? {};
|
|
3837
4336
|
const x = array(actual);
|
|
3838
4337
|
const y = array(expected);
|
|
3839
4338
|
if (!deepEqual(x.shape, y.shape)) return false;
|
|
@@ -3967,15 +4466,52 @@ function log2(x) {
|
|
|
3967
4466
|
function log10(x) {
  // Change of base: log10(x) = ln(x) * log10(e).
  const naturalLog = log(x);
  return naturalLog.mul(Math.LOG10E);
}
|
|
4469
|
+
/**
 * Element-wise hyperbolic sine.
 *
 * `sinh(x) = (exp(x) - exp(-x)) / 2`
 *
 * `exp(-x)` is obtained as the reciprocal of `exp(x)`, so only one
 * exponential is evaluated.
 */
function sinh(x) {
  const expPos = exp(x);
  // `expPos` is needed again below, so hand `reciprocal` an extra reference.
  const expNeg = reciprocal(expPos.ref);
  const difference = expPos.sub(expNeg);
  return difference.mul(0.5);
}
|
|
4479
|
+
/**
 * Element-wise hyperbolic cosine.
 *
 * `cosh(x) = (exp(x) + exp(-x)) / 2`
 *
 * `exp(-x)` is obtained as the reciprocal of `exp(x)`, so only one
 * exponential is evaluated.
 */
function cosh(x) {
  const expPos = exp(x);
  // `expPos` is needed again below, so hand `reciprocal` an extra reference.
  const expNeg = reciprocal(expPos.ref);
  const total = expPos.add(expNeg);
  return total.mul(0.5);
}
|
|
4489
|
+
/**
 * Element-wise hyperbolic tangent.
 *
 * `tanh(x) = sinh(x)/cosh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))`
 *
 * The implementation evaluates `(e - 1) / (e + 1)` with `e = exp(-2|x|)` and
 * then restores the sign, so the exponential is only ever taken of a
 * non-positive number.
 */
function tanh(x) {
  const input = fudgeArray(x);
  // +1 where x < 0, -1 otherwise: multiplying x by this gives -|x|.
  const flip = where(less(input.ref, 0), 1, -1);
  const decay = exp(input.mul(flip.ref).mul(2));
  const numerator = decay.ref.sub(1);
  const denominator = decay.add(1);
  // The ratio equals tanh(-|x|); multiplying by `flip` recovers tanh(x).
  return numerator.div(denominator).mul(flip);
}
|
|
3970
4500
|
|
|
3971
4501
|
//#endregion
|
|
3972
4502
|
//#region src/nn.ts
|
|
3973
4503
|
var nn_exports = {};
|
|
3974
4504
|
__export(nn_exports, {
|
|
4505
|
+
celu: () => celu,
|
|
4506
|
+
elu: () => elu,
|
|
4507
|
+
gelu: () => gelu,
|
|
4508
|
+
glu: () => glu,
|
|
3975
4509
|
identity: () => identity,
|
|
4510
|
+
leakyRelu: () => leakyRelu,
|
|
3976
4511
|
logSigmoid: () => logSigmoid,
|
|
3977
4512
|
logSoftmax: () => logSoftmax,
|
|
3978
4513
|
logsumexp: () => logsumexp,
|
|
4514
|
+
mish: () => mish,
|
|
3979
4515
|
oneHot: () => oneHot,
|
|
3980
4516
|
relu: () => relu,
|
|
3981
4517
|
relu6: () => relu6,
|
|
@@ -4035,10 +4571,7 @@ function softSign(x) {
|
|
|
4035
4571
|
*
|
|
4036
4572
|
* Reference: https://en.wikipedia.org/wiki/Swish_function
|
|
4037
4573
|
*/
|
|
4038
|
-
|
|
4039
|
-
x = fudgeArray(x);
|
|
4040
|
-
return x.ref.mul(sigmoid(x));
|
|
4041
|
-
}
|
|
4574
|
+
const silu = jit$1((x) => {
  // `x` is used twice (as the factor and as the sigmoid input), so take an
  // extra reference for the first use.
  const factor = x.ref;
  const gate = sigmoid(x);
  return factor.mul(gate);
});
|
|
4042
4575
|
/**
|
|
4043
4576
|
* Sigmoid-weighted Linear Unit (SiLU) activation function, also known as
|
|
4044
4577
|
* Swish, computed element-wise:
|
|
@@ -4058,6 +4591,72 @@ function logSigmoid(x) {
|
|
|
4058
4591
|
}
|
|
4059
4592
|
/** Identity activation function. Returns the argument unmodified. */
|
|
4060
4593
|
const identity = fudgeArray;
|
|
4594
|
+
/**
 * Leaky rectified linear unit (Leaky ReLU) activation function.
 *
 * Computes the element-wise function:
 * `leakyRelu(x) = x < 0 ? negativeSlope * x : x`
 */
function leakyRelu(x, negativeSlope = .01) {
  const input = fudgeArray(x);
  const isNegative = less(input.ref, 0);
  const scaled = input.ref.mul(negativeSlope);
  return where(isNegative, scaled, input);
}
|
|
4599
|
+
/**
 * Exponential linear unit activation function.
 *
 * Computes the element-wise function:
 * `elu(x) = x > 0 ? x : alpha * (exp(x) - 1)`
 *
 * The exponential is only evaluated on inputs clamped to `<= 0`. Evaluating
 * `exp(x)` on large positive `x` would overflow to `Infinity` in the branch
 * that `where` discards; the forward value would still be right, but
 * differentiating through the discarded branch would yield `0 * Infinity = NaN`.
 *
 * @param x - Input array (or array-like value).
 * @param alpha - Scale of the negative branch. Defaults to 1.
 * @returns Array with `elu` applied element-wise.
 */
function elu(x, alpha = 1) {
  const input = fudgeArray(x);
  const isNegative = less(input.ref, 0);
  // Replace non-negative entries by 0 so `exp` never sees a positive value.
  const safeInput = where(isNegative.ref, input.ref, zerosLike$1(input.ref));
  const negativeBranch = exp(safeInput).sub(1).mul(alpha);
  return where(isNegative, negativeBranch, input);
}
|
|
4609
|
+
/**
 * Continuously-differentiable exponential linear unit activation function.
 *
 * Computes the element-wise function:
 * `celu(x) = x > 0 ? x : alpha * (exp(x/alpha) - 1)`
 *
 * The exponential is only evaluated on inputs clamped to `<= 0`. Evaluating
 * `exp(x/alpha)` on large positive `x` would overflow to `Infinity` in the
 * branch that `where` discards; the forward value would still be right, but
 * differentiating through the discarded branch would yield `0 * Infinity = NaN`.
 *
 * @param x - Input array (or array-like value).
 * @param alpha - Scale and temperature of the negative branch. Defaults to 1.
 * @returns Array with `celu` applied element-wise.
 */
function celu(x, alpha = 1) {
  const input = fudgeArray(x);
  const isNegative = less(input.ref, 0);
  // Replace non-negative entries by 0 so `exp` never sees a positive value.
  const safeInput = where(isNegative.ref, input.ref, zerosLike$1(input.ref));
  const negativeBranch = exp(safeInput.div(alpha)).sub(1).mul(alpha);
  return where(isNegative, negativeBranch, input);
}
|
|
4619
|
+
/**
 * Gaussian error linear unit (GELU) activation function.
 *
 * This is computed element-wise. Currently jax-js does not support the erf() or
 * gelu() functions exactly as primitives, so the tanh approximation is used:
 * `gelu(x) ~= x * 0.5 * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))`.
 *
 * Reference: https://ml-explore.github.io/mlx/build/html/python/nn/_autosummary_functions/mlx.nn.gelu_approx.html
 *
 * This will be improved in the future.
 */
const gelu = jit$1((x) => {
  const SQRT_2_OVER_PI = Math.sqrt(2 / Math.PI);
  const halfX = x.ref.mul(.5);
  const linear = x.ref;
  // 1 + 0.044715 * x^2; multiplied by `linear` this gives x + 0.044715 * x^3.
  const cubicFactor = x.ref.mul(x).mul(.044715).add(1);
  const tanhArg = linear.mul(cubicFactor).mul(SQRT_2_OVER_PI);
  return halfX.mul(tanh(tanhArg).add(1));
});
|
|
4634
|
+
/**
 * Gated linear unit (GLU) activation function.
 *
 * The input is cut into two equal halves `a` and `b` along `axis`, and the
 * result is `a * sigmoid(b)`. Throws an `Error` when the length of `axis` is
 * odd.
 */
function glu(x, axis = -1) {
  x = fudgeArray(x);
  axis = checkAxis(axis, x.ndim);
  const extent = x.shape[axis];
  if (extent % 2 !== 0) throw new Error(`glu: axis ${axis} of shape (${x.shape}) does not have even length`);
  const half = extent / 2;
  // Full range on every dimension except `axis`, which is limited to [lo, hi).
  const boundsFor = (lo, hi) => x.shape.map((len, dim) => dim === axis ? [lo, hi] : [0, len]);
  const firstHalf = shrink(x.ref, boundsFor(0, half));
  const secondHalf = shrink(x, boundsFor(half, extent));
  return firstHalf.mul(sigmoid(secondHalf));
}
|
|
4650
|
+
/**
 * Mish activation function.
 *
 * Computes the element-wise function:
 * `mish(x) = x * tanh(softplus(x))`
 */
function mish(x) {
  const input = fudgeArray(x);
  // `input` is used twice, so take an extra reference for the outer factor.
  const factor = input.ref;
  const gate = tanh(softplus(input));
  return factor.mul(gate);
}
|
|
4061
4660
|
/**
|
|
4062
4661
|
* Softmax function. Computes the function which rescales elements to the range
|
|
4063
4662
|
* [0, 1] such that the elements along `axis` sum to 1.
|
|
@@ -4134,7 +4733,7 @@ function logsumexp(x, axis) {
|
|
|
4134
4733
|
* ```
|
|
4135
4734
|
*/
|
|
4136
4735
|
function oneHot(x, numClasses) {
  // Only Int32 index arrays are accepted.
  if (x.dtype !== DType.Int32) {
    throw new TypeError(`oneHot expects integers, got ${x.dtype}`);
  }
  // Row i of the identity matrix is the one-hot vector for class i.
  const classRows = eye(numClasses, void 0, { device: x.device });
  return classRows.slice(x);
}
|
|
4140
4739
|
|
|
@@ -4203,6 +4802,19 @@ const vmap = vmap$1;
|
|
|
4203
4802
|
const jacfwd = jacfwd$1;
|
|
4204
4803
|
/** Construct a Jaxpr by dynamically tracing a function with example inputs. */
|
|
4205
4804
|
const makeJaxpr = makeJaxpr$1;
|
|
4805
|
+
/**
|
|
4806
|
+
* Mark a function for automatic JIT compilation, with operator fusion.
|
|
4807
|
+
*
|
|
4808
|
+
* The function will be compiled the first time it is called with a set of
|
|
4809
|
+
* argument shapes.
|
|
4810
|
+
*
|
|
4811
|
+
* **Options:**
|
|
4812
|
+
* - `staticArgnums`: An array of argument indices to treat as static
|
|
4813
|
+
* (compile-time constant). These arguments must be hashable, won't be traced,
|
|
4814
|
+
* and different values will trigger recompilation.
|
|
4815
|
+
* - `device`: The device to place the computation on. If not specified, the
|
|
4816
|
+
* computation will be placed on the default device.
|
|
4817
|
+
*/
|
|
4206
4818
|
const jit = jit$1;
|
|
4207
4819
|
/**
|
|
4208
4820
|
* Produce a local linear approximation to a function at a point using jvp() and
|
|
@@ -4224,4 +4836,4 @@ const jacrev = jacrev$1;
|
|
|
4224
4836
|
const jacobian = jacrev;
|
|
4225
4837
|
|
|
4226
4838
|
//#endregion
|
|
4227
|
-
export { devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };
|
|
4839
|
+
export { DType, devices, grad, init, jacfwd, jacobian, jacrev, jit, jvp, lax_exports as lax, linearize, makeJaxpr, nn_exports as nn, numpy_exports as numpy, random_exports as random, setDevice, tree_exports as tree, valueAndGrad, vjp, vmap };
|