@jax-js/jax 0.0.1 → 0.0.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +14 -9
- package/dist/backend-1eVbAoaV.js +1890 -0
- package/dist/backend-BK21PBVP.cjs +2130 -0
- package/dist/chunk-Cl8Af3a2.js +11 -0
- package/dist/index.cjs +4076 -6059
- package/dist/index.d.cts +892 -650
- package/dist/index.d.ts +889 -650
- package/dist/index.js +3992 -3473
- package/dist/webgpu-JVpVad6g.js +591 -0
- package/dist/webgpu-c5Fe8nx8.cjs +591 -0
- package/package.json +29 -24
- package/dist/chunk-B2GFURUN.js +0 -1978
- package/dist/webgpu-QNXDOQZP.js +0 -559
|
@@ -0,0 +1,1890 @@
|
|
|
1
|
+
//#region src/pprint.ts
|
|
2
|
+
/**
 * Pretty-printing helper that models text as a list of lines, each paired
 * with its own indentation width (counted in spaces).
 */
var PPrint = class PPrint {
  constructor(indents, lines) {
    this.indents = indents;
    this.lines = lines;
  }
  /** Return a copy in which every line is shifted right by `spaces` columns. */
  indent(spaces) {
    const shifted = this.indents.map((width) => width + spaces);
    return new PPrint(shifted, this.lines);
  }
  /** Append the lines of each item underneath this block, in argument order. */
  concat(...items) {
    let indents = this.indents ?? [];
    let lines = this.lines ?? [];
    for (const item of items) {
      indents = indents.concat(item.indents);
      lines = lines.concat(item.lines);
    }
    return new PPrint(indents, lines);
  }
  /**
   * Place `other` to the right of this block: its first line continues this
   * block's last line, and its remaining lines are aligned under that point.
   */
  stack(other) {
    if (other.lines.length === 0) return this;
    if (this.lines.length === 0) return other;
    const lastIndent = this.indents[this.indents.length - 1];
    const lastLine = this.lines[this.lines.length - 1];
    const shifted = other.indent(lastIndent + lastLine.length);
    const joined = lastLine + " ".repeat(other.indents[0]) + other.lines[0];
    return new PPrint(
      [...this.indents, ...shifted.indents.slice(1)],
      [...this.lines.slice(0, -1), joined, ...shifted.lines.slice(1)]
    );
  }
  /** Render the block as a single newline-separated string. */
  toString() {
    const rendered = this.lines.map((text, row) => {
      const padding = " ".repeat(this.indents[row]);
      return padding + text;
    });
    return rendered.join("\n");
  }
  /** Build an unindented block from the string form of `s`, one entry per line. */
  static pp(s) {
    const lines = s.toString().split("\n");
    const indents = lines.map(() => 0);
    return new PPrint(indents, lines);
  }
};
|
|
34
|
+
|
|
35
|
+
//#endregion
|
|
36
|
+
//#region src/utils.ts
|
|
37
|
+
/** @file Generic programming utilities with no dependencies on library code. */
|
|
38
|
+
/** NOTE(review): presumably a debug verbosity level; its consumers are not visible in this chunk — confirm before changing. */
const DEBUG = 3;
|
|
39
|
+
/** Split an iterable of `[x, y]` pairs into two parallel arrays `[xs, ys]`. */
function unzip2(pairs) {
  const columns = [[], []];
  for (const [first, second] of pairs) {
    columns[0].push(first);
    columns[1].push(second);
  }
  return columns;
}
|
|
48
|
+
/** Pair each element of `xs` with the element of `ys` at the same index. */
function zip(xs, ys) {
  const pairWithY = (x, index) => [x, ys[index]];
  return xs.map(pairWithY);
}
|
|
51
|
+
/**
 * Build an array of `length` entries. When `value` is a function it is called
 * with each index to produce the entry; otherwise every entry is `value`.
 */
function rep(length, value) {
  const slots = new Array(length);
  if (!(value instanceof Function)) return slots.fill(value);
  return slots.fill(0).map((_, index) => value(index));
}
|
|
55
|
+
/** Multiply all entries of `arr` together; an empty array yields 1. */
function prod(arr) {
  let total = 1;
  arr.forEach((factor) => {
    total *= factor;
  });
  return total;
}
|
|
58
|
+
/** Floor division of `a` by `b` (rounds toward negative infinity, like Python's `//`). */
function intdiv(a, b) {
  const quotient = a / b;
  return Math.floor(quotient);
}
|
|
62
|
+
/** Restrict `x` to the closed interval `[min, max]`. */
function clamp(x, min, max) {
  const cappedAbove = Math.min(max, x);
  return Math.max(min, cappedAbove);
}
|
|
66
|
+
/**
 * Structural equality check for plain values, arrays, and objects.
 *
 * Primitives (and identical references) are compared with `===`. Two non-null
 * objects are equal when they have the same number of own enumerable keys,
 * every key of `a` is also an own key of `b`, and the values under each key
 * are themselves deeply equal.
 *
 * @param a First value.
 * @param b Second value.
 * @returns `true` if the values are deeply equal, `false` otherwise.
 */
function deepEqual(a, b) {
  if (a === b) return true;
  if (typeof a !== "object" || typeof b !== "object") return false;
  if (a === null || b === null) return false;
  const keysA = Object.keys(a);
  if (keysA.length !== Object.keys(b).length) return false;
  // The ownership check matters: without it `{x: undefined}` would match
  // `{y: 1}`, because the missing `b.x` also reads as `undefined`.
  return keysA.every((key) => Object.hasOwn(b, key) && deepEqual(a[key], b[key]));
}
|
|
75
|
+
/**
 * Partition `array` using the parallel flag list `which`.
 * Returns `[falseItems, trueItems]`, each preserving the original order.
 */
function partitionList(which, array) {
  const buckets = [[], []];
  for (let i = 0; i < which.length; i++) {
    const target = which[i] ? 1 : 0;
    buckets[target].push(array[i]);
  }
  return buckets;
}
|
|
83
|
+
/**
 * Lexicographic comparison of two numeric arrays.
 * Returns -1 or 1 at the first differing position; if one array is a prefix
 * of the other, returns the difference of their lengths.
 */
function lexCompare(a, b) {
  const shared = Math.min(a.length, b.length);
  let i = 0;
  while (i < shared && !(a[i] < b[i]) && !(a[i] > b[i])) i++;
  if (i === shared) return a.length - b.length;
  return a[i] < b[i] ? -1 : 1;
}
|
|
92
|
+
/** Return `true` only for an array of exactly two entries that are both numbers. */
function isNumberPair(x) {
  if (!Array.isArray(x) || x.length !== 2) return false;
  const first = x[0];
  const second = x[1];
  return typeof first === "number" && typeof second === "number";
}
|
|
96
|
+
/**
 * Validate `axis` for an array with `ndim` dimensions and normalize it.
 * Negative axes count from the end; throws if the axis is out of bounds.
 */
function checkAxis(axis, ndim) {
  const outOfBounds = axis < -ndim || axis >= ndim;
  if (outOfBounds) {
    throw new Error(`Invalid axis ${axis} for array of ${ndim} dimensions`);
  }
  if (axis < 0) return axis + ndim;
  return axis;
}
|
|
101
|
+
/**
 * Produce an arithmetic sequence as an array, following Python's `range`.
 *
 * - `range(stop)` yields `0, 1, ..., stop - 1`.
 * - `range(start, stop)` yields `start, start + 1, ...` up to but excluding `stop`.
 * - `range(start, stop, step)` advances by `step`; a negative `step` counts
 *   down while the value stays greater than `stop`.
 *
 * @param start First value, or the exclusive end when `stop` is omitted.
 * @param stop Exclusive end of the sequence.
 * @param step Increment between values (default 1).
 * @returns The generated values.
 * @throws {RangeError} If `step` is zero or NaN. A zero or wrong-signed step
 *   previously made the loop run forever whenever `start < stop`.
 */
function range(start, stop, step = 1) {
  const [first, end] = stop === void 0 ? [0, start] : [start, stop];
  if (step === 0 || Number.isNaN(step)) {
    throw new RangeError("range: step must be a nonzero number");
  }
  const result = [];
  if (step > 0) {
    for (let i = first; i < end; i += step) result.push(i);
  } else {
    for (let i = first; i > end; i += step) result.push(i);
  }
  return result;
}
|
|
110
|
+
/**
 * Check whether `axis` is a permutation of the integers `0 .. n - 1`.
 *
 * @param axis Candidate list of axis indices.
 * @param n Expected number of axes.
 * @returns `true` if `axis` has length `n` and contains each integer in
 *   `[0, n)` exactly once.
 */
function isPermutation(axis, n) {
  if (axis.length !== n) return false;
  const seen = /* @__PURE__ */ new Set();
  for (const x of axis) {
    // Non-integers must be rejected explicitly: `[0.5, 1.5]` is in range and
    // distinct, but is not a valid list of indices.
    if (!Number.isInteger(x) || x < 0 || x >= n) return false;
    seen.add(x);
  }
  return seen.size === n;
}
|
|
119
|
+
/**
 * Compute the inverse of a permutation: the result maps each value of `axis`
 * back to the position it occupied. Throws if `axis` is not a permutation.
 */
function invertPermutation(axis) {
  const n = axis.length;
  if (!isPermutation(axis, n)) {
    throw new Error("invertPermutation: axis is not a permutation");
  }
  const inverse = new Array(n);
  for (const [position, target] of axis.entries()) {
    inverse[target] = position;
  }
  return inverse;
}
|
|
126
|
+
/**
 * Topologically sort the part of a DAG reachable from `terminals`.
 *
 * @param terminals Output nodes to start from. Duplicates are ignored, and a
 *   terminal may itself be an ancestor of another terminal.
 * @param parents Function returning the parent nodes (inputs) of a node. It
 *   may be called more than once per node, so it should be free of side
 *   effects.
 * @returns Every reachable node exactly once, ordered so that each node
 *   appears after all of its parents.
 */
function toposort(terminals, parents) {
  const roots = [...new Set(terminals)];
  // Phase 1: count visits per node. Each node is expanded the first time it
  // is popped, so its count ends up as (number of child edges into it) plus
  // one extra for every root, which is subtracted again below.
  const childCounts = /* @__PURE__ */ new Map();
  const stack = [...roots];
  // Loop on length rather than on the popped value so falsy nodes such as `0`
  // or `""` do not end the traversal early.
  while (stack.length > 0) {
    const node = stack.pop();
    const visits = childCounts.get(node);
    if (visits !== void 0) {
      childCounts.set(node, visits + 1);
    } else {
      childCounts.set(node, 1);
      for (const parent of parents(node)) stack.push(parent);
    }
  }
  for (const node of roots) childCounts.set(node, childCounts.get(node) - 1);
  // Phase 2: repeatedly emit nodes with no unprocessed children.
  const order = [];
  const frontier = roots.filter((node) => childCounts.get(node) === 0);
  while (frontier.length > 0) {
    const node = frontier.pop();
    order.push(node);
    for (const parent of parents(node)) {
      const remaining = childCounts.get(parent) - 1;
      childCounts.set(parent, remaining);
      if (remaining === 0) frontier.push(parent);
    }
  }
  // Nodes were emitted children-first; flip so parents come first.
  return order.reverse();
}
|
|
154
|
+
/**
 * Pick a power of two guided by `hint` and bounded by `max`.
 *
 * Starting from 1, the value is doubled while it is still below `hint` and
 * the doubled value does not exceed `max`. The result is therefore the first
 * power of 2 that is greater than or equal to `hint`, capped at the largest
 * power of 2 less than or equal to `max`. A `hint` of 1 or less yields 1.
 *
 * Throws if `max` is less than 1.
 */
function findPow2(hint, max) {
  if (max < 1) throw new Error("max must be a positive integer");
  let candidate = 1;
  for (;;) {
    const doubled = 2 * candidate;
    if (!(candidate < hint && doubled <= max)) return candidate;
    candidate = doubled;
  }
}
|
|
166
|
+
/** Flatten arbitrarily nested arrays into one flat array; a non-array is wrapped as `[ar]`. */
function recursiveFlatten(ar) {
  return Array.isArray(ar) ? ar.flat(Infinity) : [ar];
}
|
|
170
|
+
/**
 * Drop the first and last character of `str` when they are "(" and ")"
 * respectively; any other string is returned unchanged.
 */
function strip1(str) {
  const wrapped = str[0] === "(" && str[str.length - 1] === ")";
  return wrapped ? str.slice(1, -1) : str;
}
|
|
175
|
+
/**
 * Incremental polynomial hash over a prime field.
 *
 * Each absorbed item updates the state as `state * base + item (mod p)`.
 * Collisions are unlikely for non-adversarial inputs, which is enough for
 * jobs like deduplicating compiler expressions that have already been seen.
 *
 * See https://en.wikipedia.org/wiki/Lagrange%27s_theorem_(number_theory)
 */
var FpHash = class FpHash {
  value = 8773157n;
  /** Fold one bigint into the running hash state. */
  #absorb(item) {
    const multiplier = 873192869n;
    const prime = 3189051996290219n;
    this.value = (this.value * multiplier + item) % prime;
  }
  /**
   * Absorb each value in order and return `this` for chaining.
   *
   * Strings, numbers, booleans, bigints, `null`, and `undefined` are hashed
   * directly. Objects exposing a `hash` member are asked to hash themselves
   * into this state. Anything else is ignored.
   */
  update(...values) {
    for (const x of values) {
      if (typeof x === "string") {
        for (const c of x) this.#absorb(BigInt(199 + c.charCodeAt(0)));
        continue;
      }
      if (typeof x === "number") {
        if (Number.isInteger(x)) {
          this.#absorb(68265653n ^ BigInt(x));
        } else {
          // Non-integers are hashed by their raw 64-bit float representation.
          const bits = new DataView(new Float64Array([x]).buffer);
          this.#absorb(bits.getBigUint64(0, true));
        }
        continue;
      }
      if (typeof x === "boolean") {
        this.#absorb(x ? 69069841n : 63640693n);
        continue;
      }
      if (typeof x === "bigint") {
        this.#absorb(x ^ 71657401n);
        continue;
      }
      if (x === null) {
        this.#absorb(37832657n);
        continue;
      }
      if (x === void 0) {
        this.#absorb(18145117n);
        continue;
      }
      if (typeof x === "object" && "hash" in x) x.hash(this);
    }
    return this;
  }
  /** One-shot helper: hash `values` with a fresh state and return the result. */
  static hash(...values) {
    const hasher = new FpHash();
    hasher.update(...values);
    return hasher.value;
  }
};
|
|
207
|
+
/**
 * Memoize `thunk` under `key` in the `Map` `cache`: return the stored value
 * if the key is present, otherwise call `thunk`, store its result, and return it.
 */
function runWithCache(cache, key, thunk) {
  if (cache.has(key)) return cache.get(key);
  const computed = thunk();
  cache.set(key, computed);
  return computed;
}
|
|
216
|
+
|
|
217
|
+
//#endregion
|
|
218
|
+
//#region src/alu.ts
|
|
219
|
+
/** Enumeration of supported element types, mapping member names to their string tags. */
let DType = {
  Float32: "float32",
  Int32: "int32",
  Uint32: "uint32",
  Bool: "bool",
  Complex64: "complex64"
};
|
|
227
|
+
/**
 * Size in bytes of one element of `dtype`: 8 for complex64 and 4 for every
 * other known dtype. Throws a `TypeError` for an unknown dtype.
 */
const byteWidth = (dtype) => {
  if (dtype === DType.Complex64) return 8;
  const fourByteTypes = [DType.Float32, DType.Int32, DType.Uint32, DType.Bool];
  if (fourByteTypes.includes(dtype)) return 4;
  throw new TypeError(`Unknown dtype: ${dtype}`);
};
|
|
237
|
+
/** Whether `dtype` is float32 or complex64. */
const isFloatDtype = (dtype) => {
  switch (dtype) {
    case DType.Float32:
    case DType.Complex64:
      return true;
    default:
      return false;
  }
};
|
|
238
|
+
/**
 * Construct the typed array used to hold `data` for `dtype`. Booleans share
 * the `Int32Array` representation with int32. Throws for any dtype without a
 * typed-array mapping here (including complex64).
 */
function dtypedArray(dtype, data) {
  if (dtype === DType.Float32) return new Float32Array(data);
  if (dtype === DType.Int32 || dtype === DType.Bool) return new Int32Array(data);
  if (dtype === DType.Uint32) return new Uint32Array(data);
  throw new Error(`Unimplemented dtype: ${dtype}`);
}
|
|
247
|
+
/**
|
|
248
|
+
* Mathematical expression on scalar values.
|
|
249
|
+
*
|
|
250
|
+
* This is similar to and based on tinygrad's UOp class, but it's more specific
|
|
251
|
+
* to just math on scalars. We're doing this to avoid the complexity of a full
|
|
252
|
+
* graph rewrite engine.
|
|
253
|
+
*/
|
|
254
|
+
var AluExp = class AluExp {
|
|
255
|
+
#hash;
|
|
256
|
+
#simplified;
|
|
257
|
+
#range;
|
|
258
|
+
constructor(op, dtype, src, arg = void 0) {
|
|
259
|
+
this.op = op;
|
|
260
|
+
this.dtype = dtype;
|
|
261
|
+
this.src = src;
|
|
262
|
+
this.arg = arg;
|
|
263
|
+
if (AluGroup.RequiredFloat.has(op) && !isFloatDtype(dtype)) throw new TypeError(`Unsupported dtype for ${op}: ${dtype}`);
|
|
264
|
+
if (op === AluOp.Bitcast && (dtype === DType.Bool || src[0].dtype === DType.Bool || byteWidth(dtype) !== byteWidth(src[0].dtype))) throw new TypeError(`Bitcast from ${src[0].dtype} -> ${dtype}`);
|
|
265
|
+
if (op === AluOp.Threefry2x32 && (dtype !== DType.Uint32 || src.some((x) => x.dtype !== DType.Uint32))) throw new TypeError("Threefry2x32 requires uint32 types");
|
|
266
|
+
}
|
|
267
|
+
static add(a, b) {
|
|
268
|
+
return new AluExp(AluOp.Add, a.dtype, [a, b]);
|
|
269
|
+
}
|
|
270
|
+
static sub(a, b) {
|
|
271
|
+
return new AluExp(AluOp.Sub, a.dtype, [a, b]);
|
|
272
|
+
}
|
|
273
|
+
static mul(a, b) {
|
|
274
|
+
return new AluExp(AluOp.Mul, a.dtype, [a, b]);
|
|
275
|
+
}
|
|
276
|
+
static idiv(a, b) {
|
|
277
|
+
return new AluExp(AluOp.Idiv, a.dtype, [a, b]);
|
|
278
|
+
}
|
|
279
|
+
static mod(a, b) {
|
|
280
|
+
return new AluExp(AluOp.Mod, a.dtype, [a, b]);
|
|
281
|
+
}
|
|
282
|
+
static min(a, b) {
|
|
283
|
+
return new AluExp(AluOp.Min, a.dtype, [a, b]);
|
|
284
|
+
}
|
|
285
|
+
static max(a, b) {
|
|
286
|
+
return new AluExp(AluOp.Max, a.dtype, [a, b]);
|
|
287
|
+
}
|
|
288
|
+
static sin(a) {
|
|
289
|
+
return new AluExp(AluOp.Sin, a.dtype, [a]);
|
|
290
|
+
}
|
|
291
|
+
static cos(a) {
|
|
292
|
+
return new AluExp(AluOp.Cos, a.dtype, [a]);
|
|
293
|
+
}
|
|
294
|
+
static exp(a) {
|
|
295
|
+
return new AluExp(AluOp.Exp, a.dtype, [a]);
|
|
296
|
+
}
|
|
297
|
+
static log(a) {
|
|
298
|
+
return new AluExp(AluOp.Log, a.dtype, [a]);
|
|
299
|
+
}
|
|
300
|
+
static reciprocal(a) {
|
|
301
|
+
return new AluExp(AluOp.Reciprocal, a.dtype, [a]);
|
|
302
|
+
}
|
|
303
|
+
static cast(dtype, a) {
|
|
304
|
+
if (a.dtype === dtype) return a;
|
|
305
|
+
return new AluExp(AluOp.Cast, dtype, [a]);
|
|
306
|
+
}
|
|
307
|
+
static bitcast(dtype, a) {
|
|
308
|
+
if (a.dtype === dtype) return a;
|
|
309
|
+
return new AluExp(AluOp.Bitcast, dtype, [a]);
|
|
310
|
+
}
|
|
311
|
+
static threefry2x32(k0, k1, c0, c1, mode = "xor") {
|
|
312
|
+
return new AluExp(AluOp.Threefry2x32, DType.Uint32, [
|
|
313
|
+
k0,
|
|
314
|
+
k1,
|
|
315
|
+
c0,
|
|
316
|
+
c1
|
|
317
|
+
], mode);
|
|
318
|
+
}
|
|
319
|
+
static cmplt(a, b) {
|
|
320
|
+
return new AluExp(AluOp.Cmplt, DType.Bool, [a, b]);
|
|
321
|
+
}
|
|
322
|
+
static cmpne(a, b) {
|
|
323
|
+
return new AluExp(AluOp.Cmpne, DType.Bool, [a, b]);
|
|
324
|
+
}
|
|
325
|
+
static where(cond, a, b) {
|
|
326
|
+
return new AluExp(AluOp.Where, a.dtype, [
|
|
327
|
+
cond,
|
|
328
|
+
a,
|
|
329
|
+
b
|
|
330
|
+
]);
|
|
331
|
+
}
|
|
332
|
+
static const(dtype, value) {
|
|
333
|
+
if (dtype === DType.Bool) value = Number(Boolean(value));
|
|
334
|
+
else if (dtype === DType.Int32) value = Math.trunc(value) | 0;
|
|
335
|
+
else if (dtype === DType.Uint32) value = Math.trunc(value) >>> 0;
|
|
336
|
+
if (typeof value !== "number") throw new TypeError(`Expected a number for constant, got ${typeof value}: ${value}`);
|
|
337
|
+
return new AluExp(AluOp.Const, dtype, [], value);
|
|
338
|
+
}
|
|
339
|
+
static special(dtype, name, n) {
|
|
340
|
+
return new AluExp(AluOp.Special, dtype, [], [name, n]);
|
|
341
|
+
}
|
|
342
|
+
static variable(dtype, name) {
|
|
343
|
+
return new AluExp(AluOp.Variable, dtype, [], name);
|
|
344
|
+
}
|
|
345
|
+
static globalIndex(dtype, gid, bufidx) {
|
|
346
|
+
return new AluExp(AluOp.GlobalIndex, dtype, [bufidx], gid);
|
|
347
|
+
}
|
|
348
|
+
static globalView(dtype, gid, st, indices) {
|
|
349
|
+
return new AluExp(AluOp.GlobalView, dtype, indices, [gid, st]);
|
|
350
|
+
}
|
|
351
|
+
static i32(value) {
|
|
352
|
+
return AluExp.const(DType.Int32, value);
|
|
353
|
+
}
|
|
354
|
+
static u32(value) {
|
|
355
|
+
return AluExp.const(DType.Uint32, value);
|
|
356
|
+
}
|
|
357
|
+
static f32(value) {
|
|
358
|
+
return AluExp.const(DType.Float32, value);
|
|
359
|
+
}
|
|
360
|
+
static bool(value) {
|
|
361
|
+
return AluExp.const(DType.Bool, Number(value));
|
|
362
|
+
}
|
|
363
|
+
not() {
|
|
364
|
+
if (this.dtype !== DType.Bool) throw new Error("not() can only be called on boolean expressions");
|
|
365
|
+
return AluExp.cmpne(this, AluExp.const(DType.Bool, true));
|
|
366
|
+
}
|
|
367
|
+
/** Compute a reasonable expression hash with low collision rate. */
|
|
368
|
+
getHash() {
|
|
369
|
+
if (this.#hash !== void 0) return this.#hash;
|
|
370
|
+
const hasher = new FpHash();
|
|
371
|
+
hasher.update(this.op, this.dtype, JSON.stringify(this.arg));
|
|
372
|
+
hasher.update(this.src.length, ...this.src);
|
|
373
|
+
this.#hash = hasher.value;
|
|
374
|
+
return this.#hash;
|
|
375
|
+
}
|
|
376
|
+
hash(state) {
|
|
377
|
+
state.update(this.getHash());
|
|
378
|
+
}
|
|
379
|
+
/** Substitute variables in this AluExp to values. */
|
|
380
|
+
substitute(variables) {
|
|
381
|
+
return this.rewrite((exp) => {
|
|
382
|
+
if (exp.op === AluOp.Variable && Object.hasOwn(variables, exp.arg)) {
|
|
383
|
+
if (exp.dtype !== variables[exp.arg].dtype) throw new Error(`Type mismatch: ${exp.dtype} vs ${variables[exp.arg].dtype}`);
|
|
384
|
+
return variables[exp.arg];
|
|
385
|
+
}
|
|
386
|
+
});
|
|
387
|
+
}
|
|
388
|
+
/** Reindex gid values in this expression as needed. */
|
|
389
|
+
reindexGids(gidMap) {
|
|
390
|
+
return this.rewrite((exp) => {
|
|
391
|
+
if (exp.op === AluOp.GlobalIndex) {
|
|
392
|
+
const gid = exp.arg;
|
|
393
|
+
const newGid = gidMap.get(gid);
|
|
394
|
+
if (newGid !== void 0 && newGid !== gid) return AluExp.globalIndex(exp.dtype, newGid, exp.src[0]);
|
|
395
|
+
} else if (exp.op === AluOp.GlobalView) {
|
|
396
|
+
const gid = exp.arg[0];
|
|
397
|
+
const newGid = gidMap.get(gid);
|
|
398
|
+
if (newGid !== void 0 && newGid !== gid) return AluExp.globalView(exp.dtype, newGid, exp.arg[1], exp.src);
|
|
399
|
+
}
|
|
400
|
+
});
|
|
401
|
+
}
|
|
402
|
+
#computeRange() {
|
|
403
|
+
if (this.#range !== void 0) return this.#range;
|
|
404
|
+
const src = this.src;
|
|
405
|
+
const minMax4 = (f) => {
|
|
406
|
+
const [r1, r2] = [src[0].#computeRange(), src[1].#computeRange()];
|
|
407
|
+
const values = [
|
|
408
|
+
f(r1[0], r2[0]),
|
|
409
|
+
f(r1[0], r2[1]),
|
|
410
|
+
f(r1[1], r2[0]),
|
|
411
|
+
f(r1[1], r2[1])
|
|
412
|
+
];
|
|
413
|
+
return [Math.min(...values), Math.max(...values)];
|
|
414
|
+
};
|
|
415
|
+
let ret;
|
|
416
|
+
switch (this.op) {
|
|
417
|
+
case AluOp.Add:
|
|
418
|
+
ret = [src[0].min + src[1].min, src[0].max + src[1].max];
|
|
419
|
+
break;
|
|
420
|
+
case AluOp.Sub:
|
|
421
|
+
ret = [src[0].min - src[1].max, src[0].max - src[1].min];
|
|
422
|
+
break;
|
|
423
|
+
case AluOp.Mul: {
|
|
424
|
+
ret = minMax4((a, b) => a * b);
|
|
425
|
+
break;
|
|
426
|
+
}
|
|
427
|
+
case AluOp.Idiv: {
|
|
428
|
+
ret = minMax4((a, b) => Math.floor(a / b));
|
|
429
|
+
break;
|
|
430
|
+
}
|
|
431
|
+
case AluOp.Mod: {
|
|
432
|
+
let divisorRange = src[1].#computeRange();
|
|
433
|
+
if (divisorRange[0] <= 0 && divisorRange[1] >= 0) divisorRange = [0, Math.max(-divisorRange[0], divisorRange[1])];
|
|
434
|
+
const maxDivisor = isFloatDtype(this.dtype) ? divisorRange[1] : divisorRange[1] - 1;
|
|
435
|
+
ret = [clamp(src[0].min, -maxDivisor, 0), clamp(src[0].max, 0, maxDivisor)];
|
|
436
|
+
break;
|
|
437
|
+
}
|
|
438
|
+
case AluOp.Min:
|
|
439
|
+
ret = [Math.min(src[0].min, src[1].min), Math.min(src[0].max, src[1].max)];
|
|
440
|
+
break;
|
|
441
|
+
case AluOp.Max:
|
|
442
|
+
ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
|
|
443
|
+
break;
|
|
444
|
+
case AluOp.Sin:
|
|
445
|
+
ret = [Math.sin(src[0].min), Math.sin(src[0].max)];
|
|
446
|
+
break;
|
|
447
|
+
case AluOp.Cos:
|
|
448
|
+
ret = [Math.cos(src[0].min), Math.cos(src[0].max)];
|
|
449
|
+
break;
|
|
450
|
+
case AluOp.Exp:
|
|
451
|
+
ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
|
|
452
|
+
break;
|
|
453
|
+
case AluOp.Log:
|
|
454
|
+
ret = [Math.log(src[0].min), Math.log(src[0].max)];
|
|
455
|
+
break;
|
|
456
|
+
case AluOp.Reciprocal:
|
|
457
|
+
if (src[0].min <= 0 && src[0].max >= 0) return [-Infinity, Infinity];
|
|
458
|
+
ret = [1 / src[0].max, 1 / src[0].min];
|
|
459
|
+
break;
|
|
460
|
+
case AluOp.Cast:
|
|
461
|
+
if (this.dtype === DType.Bool) {
|
|
462
|
+
const canBeZero = src[0].min <= 0 && src[0].max >= 0;
|
|
463
|
+
const mustBeZero = src[0].min === 0 && src[0].max === 0;
|
|
464
|
+
ret = mustBeZero ? [0, 0] : canBeZero ? [0, 1] : [1, 1];
|
|
465
|
+
} else if (this.dtype === DType.Int32) ret = [Math.trunc(src[0].min), Math.trunc(src[0].max)];
|
|
466
|
+
else if (this.dtype === DType.Uint32) {
|
|
467
|
+
const a = Math.trunc(src[0].min);
|
|
468
|
+
const b = Math.trunc(src[0].max);
|
|
469
|
+
if (Math.floor(a / 2 ** 32) !== Math.floor(b / 2 ** 32)) ret = [0, -1 >>> 0];
|
|
470
|
+
else ret = [a % 2 ** 32, b % 2 ** 32];
|
|
471
|
+
} else ret = [src[0].min, src[0].max];
|
|
472
|
+
break;
|
|
473
|
+
case AluOp.Cmplt:
|
|
474
|
+
ret = [0, 1];
|
|
475
|
+
break;
|
|
476
|
+
case AluOp.Cmpne:
|
|
477
|
+
ret = [0, 1];
|
|
478
|
+
break;
|
|
479
|
+
case AluOp.Where:
|
|
480
|
+
ret = [Math.min(src[1].min, src[2].min), Math.max(src[1].max, src[2].max)];
|
|
481
|
+
break;
|
|
482
|
+
case AluOp.Const:
|
|
483
|
+
ret = [this.arg, this.arg];
|
|
484
|
+
break;
|
|
485
|
+
case AluOp.Special:
|
|
486
|
+
ret = [0, this.arg[1] - 1];
|
|
487
|
+
break;
|
|
488
|
+
default: ret = [-Infinity, Infinity];
|
|
489
|
+
}
|
|
490
|
+
if (isNaN(ret[0]) || isNaN(ret[1])) ret = [-Infinity, Infinity];
|
|
491
|
+
if (this.dtype === DType.Bool) {
|
|
492
|
+
ret[0] = clamp(ret[0], 0, 1);
|
|
493
|
+
ret[1] = clamp(ret[1], 0, 1);
|
|
494
|
+
}
|
|
495
|
+
this.#range = ret;
|
|
496
|
+
return ret;
|
|
497
|
+
}
|
|
498
|
+
get min() {
|
|
499
|
+
return this.#computeRange()[0];
|
|
500
|
+
}
|
|
501
|
+
get max() {
|
|
502
|
+
return this.#computeRange()[1];
|
|
503
|
+
}
|
|
504
|
+
#isConstInt() {
|
|
505
|
+
return this.op === AluOp.Const && (this.dtype === DType.Int32 || this.dtype === DType.Uint32);
|
|
506
|
+
}
|
|
507
|
+
/**
|
|
508
|
+
* Simplify the expression by replacing any known patterns and deduping
|
|
509
|
+
* identical subexpressions.
|
|
510
|
+
*/
|
|
511
|
+
simplify(cache = /* @__PURE__ */ new Map()) {
|
|
512
|
+
if (this.#simplified !== void 0) return this.#simplified;
|
|
513
|
+
const hash = this.getHash();
|
|
514
|
+
if (cache.has(hash)) return this.#simplified = cache.get(hash);
|
|
515
|
+
const simplified = this.#simplifyInner(cache);
|
|
516
|
+
const simplifiedHash = simplified.getHash();
|
|
517
|
+
if (cache.has(simplifiedHash)) {
|
|
518
|
+
const prevSimplified = cache.get(simplifiedHash);
|
|
519
|
+
cache.set(hash, prevSimplified);
|
|
520
|
+
this.#simplified = prevSimplified;
|
|
521
|
+
return prevSimplified;
|
|
522
|
+
} else {
|
|
523
|
+
cache.set(hash, simplified);
|
|
524
|
+
cache.set(simplifiedHash, simplified);
|
|
525
|
+
this.#simplified = simplified;
|
|
526
|
+
return simplified;
|
|
527
|
+
}
|
|
528
|
+
}
|
|
529
|
+
#simplifyInner(cache) {
|
|
530
|
+
const src = this.src.map((x) => x.simplify(cache));
|
|
531
|
+
const { op } = this;
|
|
532
|
+
if (src.every((x) => x.op === AluOp.Const) && !AluGroup.Variable.has(op)) {
|
|
533
|
+
const newExp$1 = new AluExp(op, this.dtype, src, this.arg);
|
|
534
|
+
return AluExp.const(this.dtype, newExp$1.evaluate({}));
|
|
535
|
+
}
|
|
536
|
+
if (op !== AluOp.Const && this.min === this.max) return AluExp.const(this.dtype, this.min);
|
|
537
|
+
if (AluGroup.Binary.has(op)) for (let i = 0; i < 2; i++) {
|
|
538
|
+
if (src[i].op !== AluOp.Const) continue;
|
|
539
|
+
const x = src[i].arg;
|
|
540
|
+
if (op === AluOp.Add && x === 0) return src[1 - i];
|
|
541
|
+
if (op === AluOp.Sub && i === 1 && x === 0) return src[1 - i];
|
|
542
|
+
if (op === AluOp.Mul && x === 1) return src[1 - i];
|
|
543
|
+
if (op === AluOp.Mul && x === 0) return AluExp.const(this.dtype, 0);
|
|
544
|
+
if (op === AluOp.Idiv && i === 1 && x === 1) return src[1 - i];
|
|
545
|
+
}
|
|
546
|
+
if ((op === AluOp.Add || op === AluOp.Sub) && src[1].op === AluOp.Mul) {
|
|
547
|
+
const [a, b] = src[1].src;
|
|
548
|
+
const opNeg = op === AluOp.Add ? AluOp.Sub : AluOp.Add;
|
|
549
|
+
if (a.op === AluOp.Const && a.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], b]);
|
|
550
|
+
else if (b.op === AluOp.Const && b.arg === -1) return new AluExp(opNeg, this.dtype, [src[0], a]);
|
|
551
|
+
}
|
|
552
|
+
if (op === AluOp.Mod && src[1].op === AluOp.Const && src[0].min >= 0 && src[0].max < src[1].arg) return src[0];
|
|
553
|
+
if (op === AluOp.Add && src[0].op === AluOp.Mul && src[0].src[1].#isConstInt() && src[1].op === AluOp.Mod && src[1].src[1].#isConstInt() && src[0].src[1].arg === src[1].src[1].arg) {
|
|
554
|
+
const [mul, mod] = src;
|
|
555
|
+
const check = (exp) => {
|
|
556
|
+
return exp.op === AluOp.Idiv && exp.src[1].#isConstInt() && exp.src[1].arg === mod.src[1].arg && exp.src[0] === mod.src[0];
|
|
557
|
+
};
|
|
558
|
+
if (check(mul.src[0])) return mod.src[0];
|
|
559
|
+
if (mul.src[0].op === AluOp.Mod) {
|
|
560
|
+
const [x, y] = mul.src[0].src;
|
|
561
|
+
if (check(x)) return AluExp.mod(mod.src[0], AluExp.mul(mod.src[1], y)).simplify(cache);
|
|
562
|
+
}
|
|
563
|
+
}
|
|
564
|
+
if (op === AluOp.Idiv && src[1].#isConstInt()) {
|
|
565
|
+
const [numer, denom] = src;
|
|
566
|
+
const B = denom.arg;
|
|
567
|
+
for (let i = 0; i < 2; i++) {
|
|
568
|
+
if (numer.op === AluOp.Mul && numer.src[i].#isConstInt()) {
|
|
569
|
+
const A = numer.src[i].arg;
|
|
570
|
+
if (A % B === 0) {
|
|
571
|
+
let ret = numer.src[1 - i];
|
|
572
|
+
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.i32(A / B));
|
|
573
|
+
return ret.simplify(cache);
|
|
574
|
+
}
|
|
575
|
+
}
|
|
576
|
+
for (let j = 0; j < 2; j++) if (numer.op === AluOp.Add && numer.src[j].op === AluOp.Mul && numer.src[j].src[i].#isConstInt()) {
|
|
577
|
+
const A = numer.src[j].src[i].arg;
|
|
578
|
+
if (A % B === 0) {
|
|
579
|
+
let ret = numer.src[j].src[1 - i];
|
|
580
|
+
if (A / B !== 1) ret = AluExp.mul(ret, AluExp.i32(A / B));
|
|
581
|
+
ret = AluExp.add(ret, AluExp.idiv(numer.src[1 - j], B));
|
|
582
|
+
return ret.simplify(cache);
|
|
583
|
+
}
|
|
584
|
+
}
|
|
585
|
+
}
|
|
586
|
+
}
|
|
587
|
+
if (op === AluOp.Mod && src[1].#isConstInt() && src[1].arg > 0 && src[0].min >= 0) {
|
|
588
|
+
const [numer, denom] = src;
|
|
589
|
+
const B = denom.arg;
|
|
590
|
+
for (let i = 0; i < 2; i++) if (numer.op === AluOp.Add && numer.src[i].#isConstInt()) {
|
|
591
|
+
const A = numer.src[i].arg;
|
|
592
|
+
let ret = numer.src[1 - i];
|
|
593
|
+
if (A % B !== 0) ret = AluExp.add(ret, AluExp.i32(A % B));
|
|
594
|
+
return ret.simplify(cache);
|
|
595
|
+
}
|
|
596
|
+
}
|
|
597
|
+
if (op === AluOp.Cmplt) {
|
|
598
|
+
if (src[0].min >= src[1].max) return AluExp.const(DType.Bool, false);
|
|
599
|
+
if (src[0].max < src[1].min) return AluExp.const(DType.Bool, true);
|
|
600
|
+
}
|
|
601
|
+
if (op === AluOp.Cmpne) {
|
|
602
|
+
if (src[0].max < src[1].min || src[0].min > src[1].max) return AluExp.const(DType.Bool, true);
|
|
603
|
+
}
|
|
604
|
+
if (op === AluOp.Where) {
|
|
605
|
+
if (src[0].max === 0) return src[2];
|
|
606
|
+
if (src[0].min === 1) return src[1];
|
|
607
|
+
}
|
|
608
|
+
const newExp = src.every((s, i) => s === this.src[i]) ? this : new AluExp(op, this.dtype, src, this.arg);
|
|
609
|
+
return newExp;
|
|
610
|
+
}
|
|
611
|
+
/** Resolve this to a value, or `undefined` if not possible. */
|
|
612
|
+
resolve() {
|
|
613
|
+
const x = this.simplify();
|
|
614
|
+
if (x.op === AluOp.Const) return x.arg;
|
|
615
|
+
return void 0;
|
|
616
|
+
}
|
|
617
|
+
/**
|
|
618
|
+
* Evaluate the expression on CPU, returning the result.
|
|
619
|
+
*
|
|
620
|
+
* Typically you would compile the AluExp as a representation to a lower-level
|
|
621
|
+
* language. This is just to define the semantics and help debug.
|
|
622
|
+
*
|
|
623
|
+
* Note that the representation of Bool is as a number (0 or 1) here.
|
|
624
|
+
*/
|
|
625
|
+
evaluate(context, globals) {
|
|
626
|
+
if (AluGroup.Binary.has(this.op) || AluGroup.Compare.has(this.op)) {
|
|
627
|
+
const x = this.src[0].evaluate(context, globals);
|
|
628
|
+
const y = this.src[1].evaluate(context, globals);
|
|
629
|
+
switch (this.op) {
|
|
630
|
+
case AluOp.Add: return this.dtype === DType.Bool ? Number(x || y) : x + y;
|
|
631
|
+
case AluOp.Sub: return x - y;
|
|
632
|
+
case AluOp.Mul: return this.dtype === DType.Bool ? Number(x && y) : x * y;
|
|
633
|
+
case AluOp.Idiv: return Math.trunc(x / y);
|
|
634
|
+
case AluOp.Mod: return x % y;
|
|
635
|
+
case AluOp.Min: return Math.min(x, y);
|
|
636
|
+
case AluOp.Max: return Math.max(x, y);
|
|
637
|
+
case AluOp.Cmplt: return Number(x < y);
|
|
638
|
+
case AluOp.Cmpne: return Number(x != y);
|
|
639
|
+
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
640
|
+
}
|
|
641
|
+
}
|
|
642
|
+
if (AluGroup.Unary.has(this.op)) {
|
|
643
|
+
const x = this.src[0].evaluate(context, globals);
|
|
644
|
+
switch (this.op) {
|
|
645
|
+
case AluOp.Sin: return Math.sin(x);
|
|
646
|
+
case AluOp.Cos: return Math.cos(x);
|
|
647
|
+
case AluOp.Exp: return Math.exp(x);
|
|
648
|
+
case AluOp.Log: return Math.log(x);
|
|
649
|
+
case AluOp.Reciprocal: return 1 / x;
|
|
650
|
+
case AluOp.Cast: if (this.dtype === DType.Int32) return Math.trunc(x) | 0;
|
|
651
|
+
else if (this.dtype === DType.Uint32) return Math.trunc(x) >>> 0;
|
|
652
|
+
else if (this.dtype === DType.Float32) return x;
|
|
653
|
+
else if (this.dtype === DType.Bool) return Number(Boolean(x));
|
|
654
|
+
else throw new Error(`Unsupported cast to ${this.dtype}`);
|
|
655
|
+
case AluOp.Bitcast: {
|
|
656
|
+
const buf = new ArrayBuffer(byteWidth(this.dtype));
|
|
657
|
+
const view = new DataView(buf);
|
|
658
|
+
const fromType = this.src[0].dtype;
|
|
659
|
+
if (fromType === DType.Float32) view.setFloat32(0, x, true);
|
|
660
|
+
else if (fromType === DType.Int32) view.setInt32(0, x, true);
|
|
661
|
+
else if (fromType === DType.Uint32) view.setUint32(0, x, true);
|
|
662
|
+
else throw new Error(`Unsupported bitcast from ${fromType}`);
|
|
663
|
+
if (this.dtype === DType.Float32) return view.getFloat32(0, true);
|
|
664
|
+
else if (this.dtype === DType.Int32) return view.getInt32(0, true);
|
|
665
|
+
else if (this.dtype === DType.Uint32) return view.getUint32(0, true);
|
|
666
|
+
else throw new Error(`Unsupported bitcast to ${this.dtype}`);
|
|
667
|
+
}
|
|
668
|
+
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
669
|
+
}
|
|
670
|
+
}
|
|
671
|
+
switch (this.op) {
|
|
672
|
+
case AluOp.Where: return this.src[0].evaluate(context, globals) ? this.src[1].evaluate(context, globals) : this.src[2].evaluate(context, globals);
|
|
673
|
+
case AluOp.Threefry2x32: {
|
|
674
|
+
const [k0, k1, c0, c1] = this.src.map((x) => x.evaluate(context, globals));
|
|
675
|
+
const [x0, x1] = threefry2x32(k0, k1, c0, c1);
|
|
676
|
+
if (this.arg === "xor") return (x0 ^ x1) >>> 0;
|
|
677
|
+
else if (this.arg === 0) return x0;
|
|
678
|
+
else if (this.arg === 1) return x1;
|
|
679
|
+
else throw new Error(`Invalid Threefry2x32 mode: ${this.arg}`);
|
|
680
|
+
}
|
|
681
|
+
case AluOp.Const: return this.arg;
|
|
682
|
+
case AluOp.Special: {
|
|
683
|
+
const x = context[this.arg[0]];
|
|
684
|
+
if (x === void 0) throw new Error(`Missing special: ${this.arg[0]}`);
|
|
685
|
+
return x;
|
|
686
|
+
}
|
|
687
|
+
case AluOp.Variable: {
|
|
688
|
+
const x = context[this.arg];
|
|
689
|
+
if (x === void 0) throw new Error(`Missing variable: ${this.arg}`);
|
|
690
|
+
return x;
|
|
691
|
+
}
|
|
692
|
+
case AluOp.GlobalIndex: {
|
|
693
|
+
if (!globals) throw new Error("Missing globals function");
|
|
694
|
+
const gid = this.arg;
|
|
695
|
+
const bufidx = this.src[0].evaluate(context, globals);
|
|
696
|
+
return globals(gid, bufidx);
|
|
697
|
+
}
|
|
698
|
+
case AluOp.GlobalView: {
|
|
699
|
+
if (!globals) throw new Error("Missing globals function");
|
|
700
|
+
const gid = this.arg[0];
|
|
701
|
+
const st = this.arg[1];
|
|
702
|
+
const [iexpr, vexpr] = st.toAluExp(this.src);
|
|
703
|
+
if (vexpr.evaluate(context, globals)) {
|
|
704
|
+
const bufidx = iexpr.evaluate(context, globals);
|
|
705
|
+
return globals(gid, bufidx);
|
|
706
|
+
} else return 0;
|
|
707
|
+
}
|
|
708
|
+
default: throw new Error(`Missing implemementation for ${this.op}`);
|
|
709
|
+
}
|
|
710
|
+
}
|
|
711
|
+
/** Get this expression in debug format as a string. */
|
|
712
|
+
toString() {
|
|
713
|
+
const BIN_SYM = {
|
|
714
|
+
[AluOp.Add]: "+",
|
|
715
|
+
[AluOp.Sub]: "-",
|
|
716
|
+
[AluOp.Mul]: "*",
|
|
717
|
+
[AluOp.Idiv]: "/",
|
|
718
|
+
[AluOp.Mod]: "%"
|
|
719
|
+
};
|
|
720
|
+
const CMP_SYM = {
|
|
721
|
+
[AluOp.Cmplt]: "<",
|
|
722
|
+
[AluOp.Cmpne]: "!="
|
|
723
|
+
};
|
|
724
|
+
const UNARY_SYM = {
|
|
725
|
+
[AluOp.Sin]: "sin",
|
|
726
|
+
[AluOp.Cos]: "cos",
|
|
727
|
+
[AluOp.Exp]: "exp",
|
|
728
|
+
[AluOp.Log]: "log",
|
|
729
|
+
[AluOp.Reciprocal]: "1/"
|
|
730
|
+
};
|
|
731
|
+
return this.fold((node, parts) => {
|
|
732
|
+
switch (node.op) {
|
|
733
|
+
case AluOp.Const: return "" + (node.dtype === DType.Bool ? Boolean(node.arg) : node.arg);
|
|
734
|
+
case AluOp.Variable: return `$${node.arg}:${node.dtype}`;
|
|
735
|
+
case AluOp.Special: {
|
|
736
|
+
const [name, n] = node.arg;
|
|
737
|
+
return `#${name}{${n}}`;
|
|
738
|
+
}
|
|
739
|
+
case AluOp.GlobalIndex: return `G_${node.arg}<${node.dtype}>[${strip1(parts[0])}]`;
|
|
740
|
+
case AluOp.GlobalView: {
|
|
741
|
+
const [gid, st] = node.arg;
|
|
742
|
+
const shape = st.shape.join(",");
|
|
743
|
+
const lastStrides = st.lastStrides.join(",");
|
|
744
|
+
const cont = st.contiguous ? "c" : "nc";
|
|
745
|
+
return `GV_${gid}<${node.dtype}>{${shape}:${lastStrides}:${cont}}[${parts.map(strip1).join(", ")}]`;
|
|
746
|
+
}
|
|
747
|
+
}
|
|
748
|
+
if (BIN_SYM[node.op]) return `(${parts[0]} ${BIN_SYM[node.op]} ${parts[1]})`;
|
|
749
|
+
if (CMP_SYM[node.op]) return `(${parts[0]} ${CMP_SYM[node.op]} ${parts[1]})`;
|
|
750
|
+
if (UNARY_SYM[node.op]) return `${UNARY_SYM[node.op]}${parts[0]}`;
|
|
751
|
+
if (node.op === AluOp.Cast) return `Cast<${node.dtype}>(${strip1(parts[0])})`;
|
|
752
|
+
if (node.op === AluOp.Bitcast) return `Bitcast<${node.dtype}>(${strip1(parts[0])})`;
|
|
753
|
+
return `${node.op}(${parts.map(strip1).join(", ")})`;
|
|
754
|
+
});
|
|
755
|
+
}
|
|
756
|
+
/** Generic fold() operation with a reducer over the expression tree. */
|
|
757
|
+
fold(reducer) {
|
|
758
|
+
const visited = /* @__PURE__ */ new Map();
|
|
759
|
+
const recurse = (exp) => {
|
|
760
|
+
if (visited.has(exp)) return visited.get(exp);
|
|
761
|
+
const mappedSrc = exp.src.map((s) => recurse(s));
|
|
762
|
+
const result = reducer(exp, mappedSrc);
|
|
763
|
+
visited.set(exp, result);
|
|
764
|
+
return result;
|
|
765
|
+
};
|
|
766
|
+
return recurse(this);
|
|
767
|
+
}
|
|
768
|
+
/** Rewrite the expression recursively using a visitor. */
|
|
769
|
+
rewrite(visitor) {
|
|
770
|
+
return this.fold((exp, newSrc) => {
|
|
771
|
+
if (newSrc.length === exp.src.length && newSrc.every((s, i) => s === exp.src[i])) return visitor(exp) ?? exp;
|
|
772
|
+
else {
|
|
773
|
+
const newExp = new AluExp(exp.op, exp.dtype, newSrc, exp.arg);
|
|
774
|
+
return visitor(newExp) ?? newExp;
|
|
775
|
+
}
|
|
776
|
+
});
|
|
777
|
+
}
|
|
778
|
+
/** Collect all nodes that satisfy a predicate. */
|
|
779
|
+
collect(predicate) {
|
|
780
|
+
const result = [];
|
|
781
|
+
this.fold((exp) => {
|
|
782
|
+
if (predicate(exp)) result.push(exp);
|
|
783
|
+
});
|
|
784
|
+
return result;
|
|
785
|
+
}
|
|
786
|
+
};
|
|
787
|
+
/** Symbolic form for each mathematical operation. */
|
|
788
|
+
// String enum: every member maps its own name to itself (AluOp.Add === "Add").
// Key order matches declaration order.
let AluOp = /* @__PURE__ */ Object.fromEntries([
  "Add",
  "Sub",
  "Mul",
  "Idiv",
  "Mod",
  "Min",
  "Max",
  "Sin",
  "Cos",
  "Exp",
  "Log",
  "Reciprocal",
  "Cast",
  "Bitcast",
  "Cmplt",
  "Cmpne",
  "Where",
  "Threefry2x32",
  "Const",
  "Special",
  "Variable",
  "GlobalIndex",
  "GlobalView"
].map((name) => [name, name]));
|
|
814
|
+
// Membership sets used to classify operations by arity / role.
const AluGroup = (() => {
  const {
    Add, Sub, Mul, Idiv, Mod, Min, Max,
    Sin, Cos, Exp, Log, Reciprocal, Cast, Bitcast,
    Cmplt, Cmpne,
    Special, Variable, GlobalIndex, GlobalView
  } = AluOp;
  // Shared by `Unary` (as a prefix) and `RequiredFloat` (in full).
  const floatOnlyOps = [Sin, Cos, Exp, Log, Reciprocal];
  return {
    Binary: new Set([Add, Sub, Mul, Idiv, Mod, Min, Max]),
    Unary: new Set([...floatOnlyOps, Cast, Bitcast]),
    Compare: new Set([Cmplt, Cmpne]),
    Variable: new Set([Special, Variable, GlobalIndex, GlobalView]),
    Reduce: new Set([Add, Mul, Min, Max]),
    RequiredFloat: new Set(floatOnlyOps)
  };
})();
|
|
854
|
+
/** Common variables that can be substituted in expressions. */
|
|
855
|
+
// Named variables; all are Int32 except `acc`, whose dtype is chosen by the
// caller and which is therefore a factory rather than a constant.
const AluVar = (() => {
  const int32Var = (name) => AluExp.variable(DType.Int32, name);
  return {
    gidx: int32Var("gidx"),
    ridx: int32Var("ridx"),
    acc: (dtype) => AluExp.variable(dtype, "acc"),
    idx: int32Var("idx"),
    unroll: int32Var("unroll"),
    upcast: int32Var("upcast")
  };
})();
|
|
863
|
+
/**
|
|
864
|
+
* Description of a kernel to be compiled.
|
|
865
|
+
*
|
|
866
|
+
* Each of these can be processed by a backend into some lower-level
|
|
867
|
+
* representation. It consists of one or more fused operations, optionally
|
|
868
|
+
* indexing into a buffer.
|
|
869
|
+
*/
|
|
870
|
+
var Kernel = class {
  // nargs, size and reduction are stored as given; `exp` is stored in
  // simplified form (exp.simplify()). `reduction` may be omitted.
  constructor(nargs, size, exp, reduction) {
    this.nargs = nargs;
    this.size = size;
    this.exp = exp.simplify();
    this.reduction = reduction;
  }
  /** Feed every field of this kernel into the given hash state. */
  hash(state) {
    const { nargs, size, exp, reduction } = this;
    state.update(nargs, size, exp, reduction);
  }
  /** Pretty-print as `{ exp = ..., size = ..., [reduction = ...] }`. */
  pprint() {
    const rows = [`exp = ${this.exp}`, `size = ${this.size}`];
    if (this.reduction) rows.push(`reduction = ${this.reduction}`);
    const details = rows.map((row) => PPrint.pp(row)).reduce((block, next) => block.concat(next));
    return PPrint.pp("{ ").stack(details).stack(PPrint.pp(" }"));
  }
  toString() {
    return this.pprint().toString();
  }
  /** The dtype of the values output by this kernel. */
  get dtype() {
    // With a reduction, the output type is that of its fused epilogue.
    return this.reduction ? this.reduction.fusion.dtype : this.exp.dtype;
  }
  /** The number of bytes in the output array when evaluating this kernel. */
  get bytes() {
    return this.size * byteWidth(this.dtype);
  }
};
|
|
900
|
+
/**
|
|
901
|
+
* Description of a reduction.
|
|
902
|
+
*
|
|
903
|
+
* The strategy of jax-js backends is to either handle a standard operation that
|
|
904
|
+
* is dispatched in a vectorized way over an array, or to reduce over one axis
|
|
905
|
+
* of some computation. This is a description of the reduction.
|
|
906
|
+
*
|
|
907
|
+
* Reduction only supports a few operations, and only over one axis. Users can
|
|
908
|
+
* always `flatten()` the array before reducing if needed.
|
|
909
|
+
*
|
|
910
|
+
* The backend is responsible for implementing the reduction in a way that
|
|
911
|
+
* minimizes the number of global memory loads, for efficiency. This involves
|
|
912
|
+
* passing through some optimization strategy. But optimizations are not coded
|
|
913
|
+
* at this level since they depend on GPU, versus CPU or Wasm.
|
|
914
|
+
*/
|
|
915
|
+
var Reduction = class {
  // dtype:  element type being reduced.
  // op:     one of AluGroup.Reduce (Add, Mul, Min, Max); anything else throws TypeError.
  // size:   stored as given.
  // fusion: expression applied to the accumulator; defaults to the bare
  //         accumulator variable AluVar.acc(dtype).
  constructor(dtype, op, size, fusion = AluVar.acc(dtype)) {
    this.dtype = dtype;
    this.op = op;
    this.size = size;
    this.fusion = fusion;
    if (!AluGroup.Reduce.has(op)) throw new TypeError(`Unsupported reduction: ${op}`);
  }
  /** Feed every field of this reduction into the given hash state. */
  hash(state) {
    state.update(this.dtype, this.op, this.size, this.fusion);
  }
  toString() {
    return `${this.op}{${this.size}} -> ${this.fusion}`;
  }
  /**
   * Get the identity for this reduction operation, i.e. the value `e` such
   * that `op(e, x) === x` for every `x` of this dtype.
   *
   * Throws TypeError for an unsupported (op, dtype) combination.
   */
  get identity() {
    const { dtype, op } = this;
    if (dtype === DType.Bool) {
      // Add/Max act as logical OR, Mul/Min as logical AND.
      return op === AluOp.Add || op === AluOp.Max ? false : true;
    } else if (dtype === DType.Int32) {
      if (op === AluOp.Add) return 0;
      else if (op === AluOp.Mul) return 1;
      else if (op === AluOp.Min) return -1 >>> 1; // INT32_MAX
      else if (op === AluOp.Max) return 1 << 31; // INT32_MIN
    } else if (dtype === DType.Uint32) {
      if (op === AluOp.Add) return 0;
      else if (op === AluOp.Mul) return 1;
      else if (op === AluOp.Min) return -1 >>> 0; // UINT32_MAX
      else if (op === AluOp.Max) return 0;
    } else if (dtype === DType.Float32) {
      if (op === AluOp.Add) return 0;
      else if (op === AluOp.Mul) return 1;
      else if (op === AluOp.Min) return Infinity;
      else if (op === AluOp.Max) return -Infinity;
    }
    throw new TypeError(`Unsupported reduction: ${op} ${dtype}`);
  }
  /**
   * Evaluate this operation on CPU by folding `values` left to right,
   * starting from `identity` (so an empty argument list yields the identity).
   *
   * Integer arithmetic wraps to 32 bits. Throws TypeError for an unsupported
   * (op, dtype) combination.
   */
  evaluate(...values) {
    const { dtype, op } = this;
    // Seeding from `identity` keeps this method and the getter in agreement.
    const foldWith = (combine) => values.reduce((a, b) => combine(a, b), this.identity);
    if (dtype === DType.Bool) {
      if (op === AluOp.Add || op === AluOp.Max) return foldWith((a, b) => a || b);
      else if (op === AluOp.Mul || op === AluOp.Min) return foldWith((a, b) => a && b);
    } else if (dtype === DType.Int32) {
      if (op === AluOp.Add) return foldWith((a, b) => a + b | 0);
      // Math.imul is exact; `a * b | 0` loses low bits once the double
      // product exceeds 2^53.
      else if (op === AluOp.Mul) return foldWith((a, b) => Math.imul(a, b));
      else if (op === AluOp.Min) return foldWith((a, b) => Math.min(a, b));
      else if (op === AluOp.Max) return foldWith((a, b) => Math.max(a, b));
    } else if (dtype === DType.Uint32) {
      if (op === AluOp.Add) return foldWith((a, b) => a + b >>> 0);
      else if (op === AluOp.Mul) return foldWith((a, b) => Math.imul(a, b) >>> 0);
      else if (op === AluOp.Min) return foldWith((a, b) => Math.min(a, b));
      else if (op === AluOp.Max) return foldWith((a, b) => Math.max(a, b));
    } else if (dtype === DType.Float32) {
      if (op === AluOp.Add) return foldWith((a, b) => a + b);
      else if (op === AluOp.Mul) return foldWith((a, b) => a * b);
      else if (op === AluOp.Min) return foldWith((a, b) => Math.min(a, b));
      else if (op === AluOp.Max) return foldWith((a, b) => Math.max(a, b));
    }
    throw new TypeError(`Unsupported reduction: ${op} ${dtype}`);
  }
};
|
|
974
|
+
/** Expression for accessing `indices` in input array with the given shape. */
|
|
975
|
+
function accessorGlobal(dtype, gid, st, indices) {
  // `st.toAluExp` yields [buffer index, validity]; invalid positions read as 0.
  const [bufferIndex, inBounds] = st.toAluExp(indices);
  const load = AluExp.globalIndex(dtype, gid, bufferIndex);
  const zero = AluExp.const(dtype, 0);
  return AluExp.where(inBounds, load, zero);
}
|
|
979
|
+
/** Expression for accessing `indices` in an array recipe with variable "idx". */
|
|
980
|
+
function accessorAluExp(dtype, exp, st, indices) {
  // Rebind the recipe's "idx" variable to the view's index expression;
  // positions the view marks invalid read as 0.
  const [bufferIndex, inBounds] = st.toAluExp(indices);
  const rebound = exp.substitute({ idx: bufferIndex });
  const zero = AluExp.const(dtype, 0);
  return AluExp.where(inBounds, rebound, zero);
}
|
|
984
|
+
function threefry2x32(k0, k1, c0, c1) {
  // Mixes the counter words (c0, c1) under the key words (k0, k1) in five
  // groups of four add-rotate-xor rounds. After group g (1-based) the key
  // schedule is injected: x0 += ks[g % 3], x1 += ks[(g + 1) % 3] + g.
  // Returns [x0, x1] as unsigned 32-bit values.
  const rotl32 = (word, bits) => ((word << bits) | (word >>> (32 - bits))) >>> 0;
  // Rotation amounts alternate between these two sets, group by group.
  const ROTATIONS = [
    [13, 15, 26, 6],
    [17, 29, 16, 24]
  ];
  const key0 = k0 >>> 0;
  const key1 = k1 >>> 0;
  const schedule = [key0, key1, (key0 ^ key1 ^ 466688986) >>> 0];
  let x0 = (c0 + key0) >>> 0;
  let x1 = (c1 + key1) >>> 0;
  for (let group = 1; group <= 5; group++) {
    for (const bits of ROTATIONS[(group - 1) % 2]) {
      x0 = (x0 + x1) >>> 0;
      x1 = rotl32(x1, bits) ^ x0;
    }
    x0 = (x0 + schedule[group % 3]) >>> 0;
    x1 = (x1 + schedule[(group + 1) % 3] + group) >>> 0;
  }
  return [x0, x1];
}
|
|
1023
|
+
|
|
1024
|
+
//#endregion
|
|
1025
|
+
//#region src/shape.ts
|
|
1026
|
+
const jstr = JSON.stringify;
|
|
1027
|
+
/** Set the stride to 0 for every dimension of size 1; other strides are kept. */
|
|
1028
|
+
function canonicalizeStrides(shape, strides) {
  // A size-1 axis is only ever indexed at 0, so its stride is irrelevant;
  // normalize it to 0. Returns a new array; the inputs are not modified.
  return shape.map((dim, axis) => (dim === 1 ? 0 : strides[axis]));
}
|
|
1034
|
+
/** Get the strides for a shape in default row-major order. */
|
|
1035
|
+
function defaultStrides(shape) {
  // Row-major: the last axis has stride 1 and each earlier axis has the
  // product of all later extents. Size-1 axes are then normalized to 0.
  const rank = shape.length;
  if (rank === 0) return [];
  const strides = rep(rank, 1);
  let running = 1;
  for (let axis = rank - 1; axis >= 0; axis--) {
    strides[axis] = running;
    running *= shape[axis];
  }
  return canonicalizeStrides(shape, strides);
}
|
|
1041
|
+
/** Merge contiguous subparts or zero-strided dimensions in a view. */
|
|
1042
|
+
function mergeDims(shape, strides, mask) {
  // Returns a list of [mergedSize, stride, realSize] groups, outermost first.
  // Adjacent axes are fused when the previous group's stride equals
  // `dim * stride` of the current axis, or when the previous axis was
  // "collapsing" (masked to a single index, or of size 1 when no mask is given
  // for axis 0). Size-1 axes are skipped entirely.
  const rank = shape.length;
  if (rank === 0) return [];
  if (rank !== strides.length || mask && rank !== mask.length) throw new Error("internal: invalid args to mergeDims");
  // True when the mask pins `axis` to exactly one index.
  const collapses = (axis, whenUnmasked) => (mask ? mask[axis][1] - mask[axis][0] === 1 : whenUnmasked);
  const groups = [[
    shape[0],
    strides[0],
    strides[0] !== 0 ? shape[0] : 0
  ]];
  let absorbNext = collapses(0, shape[0] === 1);
  for (let axis = 1; axis < rank; axis++) {
    const dim = shape[axis];
    const stride = strides[axis];
    if (dim === 1) continue;
    const tail = groups.length - 1;
    const [prevSize, prevStride, prevRealSize] = groups[tail];
    if (absorbNext || prevStride === dim * stride) {
      groups[tail] = [prevSize * dim, stride, absorbNext ? dim : prevRealSize * dim];
    } else {
      groups.push([dim, stride, dim]);
    }
    absorbNext = collapses(axis, false);
  }
  return groups;
}
|
|
1069
|
+
/** Return the new mask if a reshape is possible, otherwise `null`. */
|
|
1070
|
+
function reshapeMask(maskInput, oldShape, newShape) {
  // Walks the old shape/mask and the new shape from the innermost axis
  // outwards, emitting one [begin, end) pair per new axis. Once an input list
  // is exhausted its cursor yields a neutral value (extent 1, mask [0, 1]).
  let maskPos = maskInput.length;
  let oldPos = oldShape.length;
  let newPos = newShape.length;
  const popMask = () => (maskPos ? maskInput[--maskPos] : [0, 1]);
  const popOld = () => (oldPos ? oldShape[--oldPos] : 1);
  const popNew = () => (newPos ? newShape[--newPos] : 1);

  const out = [];
  let stride = 1;
  let oldDim = popOld();
  let newDim = popNew();
  let mask = popMask();
  while (out.length < newShape.length) {
    const [lo, hi] = mask;
    const nextStride = newDim * stride;
    if (oldDim === nextStride) {
      // The new axis finishes off the current old axis.
      out.push([intdiv(lo, stride), intdiv(hi - 1, stride) + 1]);
      stride = 1;
      oldDim = popOld();
      newDim = popNew();
      mask = popMask();
    } else if (oldDim > nextStride) {
      // The old axis is being split across several new axes.
      if (oldDim % nextStride !== 0) return null;
      const misaligned = lo % nextStride !== 0 || hi % nextStride !== 0;
      if (misaligned && intdiv(lo, nextStride) !== intdiv(hi - 1, nextStride)) return null;
      out.push([intdiv(lo % nextStride, stride), intdiv((hi - 1) % nextStride, stride) + 1]);
      stride = nextStride;
      newDim = popNew();
    } else {
      // Several old axes are being merged into one new axis.
      const outer = popMask();
      if (!deepEqual(mask, [0, oldDim]) && lo !== hi && outer[1] - outer[0] !== 1) return null;
      mask = [outer[0] * oldDim + lo, (outer[1] - 1) * oldDim + hi];
      oldDim *= popOld();
    }
  }
  return out.reverse();
}
|
|
1109
|
+
/**
|
|
1110
|
+
* A multidimensional view into memory. An array can be thought of as the
|
|
1111
|
+
* combination of a linear buffer of memory, along with a `View`.
|
|
1112
|
+
*
|
|
1113
|
+
* Formula for getting a data point is basically:
|
|
1114
|
+
* 1. Check if ∀i. 0 <= dim[i] < shape[i], otherwise out of bounds.
|
|
1115
|
+
* 2. If mask exists, and ∃i. dim[i] ∉ mask[i], return 0.
|
|
1116
|
+
* 3. Otherwise, look at this memory address: offset + ∑(strides[i] * dim[i]).
|
|
1117
|
+
*/
|
|
1118
|
+
var View = class View {
  // Lazily computed caches backing the `size` and `contiguous` getters.
  #size;
  #contiguous;
  /**
   * Store the fields exactly as given, with no validation or normalization.
   * Every method in this class builds views through `View.create()` instead,
   * except the zero-size shortcut inside `create()` itself.
   */
  constructor(shape, strides, offset, mask) {
    this.shape = shape;
    this.strides = strides;
    this.offset = offset;
    this.mask = mask;
  }
  /**
   * Build a normalized view.
   *
   * - `strides` defaults to row-major strides; given strides are normalized so
   *   size-1 axes have stride 0.
   * - A shape containing 0 yields all-zero strides, offset 0 and no mask.
   * - A mask covering the full shape is dropped (`null`).
   * - Axes whose mask admits at most one index get stride 0, with the mask
   *   start folded into `offset`. If any mask range is empty, the view is
   *   reset to zero strides, zero offset and an all-`[0, 0]` mask.
   *
   * Throws if any entry of `shape` is negative.
   */
  static create(shape, strides, offset = 0, mask = null) {
    if (shape.some((s) => s < 0)) throw new Error("View shape must be non-negative");
    strides = strides ? canonicalizeStrides(shape, strides) : defaultStrides(shape);
    if (shape.includes(0)) return new View(shape, rep(shape.length, 0), 0, null);
    if (mask !== null && mask.every(([b, e], i) => b === 0 && e === shape[i])) mask = null;
    if (mask !== null) {
      // Axes whose mask range holds zero or one index.
      const elimDims = [];
      let hasNoData = false;
      for (let i = 0; i < shape.length; i++) {
        const [b, e] = mask[i];
        if (b + 1 >= e) elimDims.push(i);
        if (b >= e) hasNoData = true;
      }
      if (elimDims.length) {
        if (hasNoData) {
          strides = rep(shape.length, 0);
          offset = 0;
          mask = rep(shape.length, () => [0, 0]);
        }
        for (const i of elimDims) {
          // The only admissible index on this axis is mask[i][0].
          offset += strides[i] * mask[i][0];
          strides[i] = 0;
        }
      }
    }
    return new View(shape, strides, offset, mask);
  }
  /** Number of dimensions. */
  get ndim() {
    return this.shape.length;
  }
  /** Total number of elements (`prod(shape)`), computed once and cached. */
  get size() {
    if (this.#size === void 0) this.#size = prod(this.shape);
    return this.#size;
  }
  /** Whether this is a default, contiguous, unaltered view of the data (identity). */
  get contiguous() {
    // A zero-size view counts as contiguous. Otherwise: no offset, no mask,
    // and strides equal to the row-major defaults. Cached after first use.
    if (this.#contiguous === void 0) this.#contiguous = this.size === 0 || this.offset === 0 && this.mask === null && deepEqual(this.strides, defaultStrides(this.shape));
    return this.#contiguous;
  }
  /**
   * Produce an AluExp for evaluating this view at an index.
   *
   * `idxs` holds one index expression per axis. Returns `[iexpr, vexpr]`:
   * `iexpr` is `offset` plus `idx * stride` for each axis that has size != 1
   * and stride != 0; `vexpr` is the product of the mask bound checks (the
   * constant `true` when no check is needed).
   */
  toAluExp(idxs) {
    let iexpr = AluExp.i32(this.offset);
    let vexpr = AluExp.bool(true);
    for (let i = this.ndim - 1; i >= 0; i--) {
      const idx = idxs[i];
      if (this.shape[i] !== 1 && this.strides[i] !== 0) iexpr = AluExp.add(AluExp.mul(idx, AluExp.i32(this.strides[i])), iexpr);
      if (this.mask) {
        // Lower bound: NOT (idx < begin); upper bound: idx < end. A bound that
        // coincides with the edge of the shape needs no check.
        if (this.mask[i][0] !== 0) vexpr = AluExp.mul(AluExp.cmplt(idx, AluExp.i32(this.mask[i][0])).not(), vexpr);
        if (this.mask[i][1] !== this.shape[i]) vexpr = AluExp.mul(AluExp.cmplt(idx, AluExp.i32(this.mask[i][1])), vexpr);
      }
    }
    return [iexpr, vexpr];
  }
  /**
   * Try to compose this view with another one. `this` view is applied first,
   * followed by the argument. If this is not possible for the specific views,
   * return `null` instead.
   *
   * If composable, return a combined view with the same shape as `v1`.
   *
   * This is very tricky. The shapes of v1 and v2 may be different, and in that
   * case, we do some math to figure out whether they're compatible.
   */
  compose(v1) {
    const v2 = this;
    // Fast paths: an identity on either side.
    if (v2.contiguous) return v1;
    if (v1.contiguous) {
      if (deepEqual(v1.shape, v2.shape)) return v2;
      if (v1.size === v2.size) {
        const ret = v2.reshape(v1.shape);
        if (ret !== null) return ret;
      }
    }
    if (v1.mask !== null) {
      // Compose with the unmasked interior of v1, then pad the result back
      // out to v1's shape.
      const newV1 = v1.shrink(v1.mask);
      const merged = v2.compose(newV1);
      return merged ? merged.pad(zip(v1.mask, v1.shape).map(([m, s]) => [m[0], s - m[1]])) : null;
    }
    // Position of v1's offset within v2's shape.
    const origin = unravel(v2.shape, v1.offset);
    // terms[d2] lists [d1, step] pairs: one step along v1 axis d1 moves
    // `step` positions along v2 axis d2.
    const terms = rep(v2.ndim, () => []);
    const strides = rep(v1.ndim, 0);
    for (let d1 = 0; d1 < v1.strides.length; d1++) {
      const st = v1.strides[d1];
      if (st === 0) continue;
      const unravelOffset = unravel(v2.shape, v1.offset + st);
      for (let d2 = 0; d2 < v2.ndim; d2++) {
        const o = origin[d2];
        const diff = unravelOffset[d2] - o;
        if (diff === 0) continue;
        terms[d2].push([d1, diff]);
        strides[d1] += diff * v2.strides[d2];
      }
    }
    // Scan v2's axes from innermost to outermost, accumulating the range of
    // positions v1 can reach. Whenever the accumulated range fits within the
    // accumulated extent, record it as one extent and start over.
    let [mergedSize, mergedTermMin, mergedTermMax] = [
      1,
      0,
      0
    ];
    const extents = [];
    for (let i = v2.ndim - 1; i >= 0; i--) {
      const term = terms[i];
      const s = v2.shape[i];
      let [tmin, tmax] = [origin[i], origin[i]];
      for (const [d1, s1] of term) if (s1 > 0) tmax += (v1.shape[d1] - 1) * s1;
      else if (s1 < 0) tmin += (v1.shape[d1] - 1) * s1;
      mergedTermMin += tmin * mergedSize;
      mergedTermMax += tmax * mergedSize;
      mergedSize *= s;
      if (mergedTermMin >= 0 && mergedTermMax < mergedSize) {
        extents.push([
          mergedSize,
          mergedTermMin,
          mergedTermMax
        ]);
        [mergedSize, mergedTermMin, mergedTermMax] = [
          1,
          0,
          0
        ];
      }
    }
    // A leftover range means the reachable positions never fit: give up.
    if (mergedTermMin !== 0 || mergedTermMax !== 0) return null;
    extents.reverse();
    const v2Shape = extents.map(([s]) => s);
    if (!deepEqual(v2Shape, v2.shape)) {
      // Retry against v2 reshaped to the merged extents.
      const reshapedV2 = v2.reshape(v2Shape);
      if (reshapedV2 === null) return null;
      if (!deepEqual(reshapedV2.shape, v2.shape)) return reshapedV2.compose(v1);
    }
    if (v2.mask !== null) {
      // Translate v2's mask into bounds [newB, newE) on v1's axes.
      const newB = rep(v1.ndim, 0);
      const newE = v1.shape.slice();
      let bad = false;
      for (let d2 = 0; d2 < v2.ndim; d2++) {
        const [b, e] = v2.mask[d2];
        const o = origin[d2];
        const term = terms[d2];
        const [_, tmin, tmax] = extents[d2];
        // The reachable range already lies inside the mask: nothing to do.
        if (b <= tmin && tmax < e) continue;
        if (term.length !== 1) if (term.length === 0 && newE.length) newE[0] = 0;
        else bad = true;
        else {
          const [d1, s1] = term[0];
          newB[d1] = Math.max(newB[d1], Math.ceil((s1 > 0 ? b - o : e - o - 1) / s1));
          newE[d1] = Math.min(newE[d1], Math.floor((s1 < 0 ? b - o : e - o - 1) / s1) + 1);
        }
      }
      // If any bound tightened, retry with v1 carrying that mask (handled by
      // the `v1.mask !== null` branch above).
      for (let d1 = 0; d1 < v1.ndim; d1++) if (newB[d1] !== 0 || newE[d1] !== v1.shape[d1]) return v2.compose(View.create(v1.shape, v1.strides, v1.offset, zip(newB, newE)));
      if (bad) return null;
    }
    let finalOffset = v2.offset;
    for (let d2 = 0; d2 < v2.ndim; d2++) finalOffset += origin[d2] * v2.strides[d2];
    return View.create(v1.shape, strides, finalOffset, null);
  }
  /** Attempt to simplify this view into a smaller reshaped form. */
  minify() {
    // Falls back to `this` when the reshape to the merged shape is impossible.
    const minShape = mergeDims(this.shape, this.strides, this.mask).map((x) => x[0]);
    const nv = this.reshape(minShape);
    return nv ? nv : this;
  }
  /**
   * Pad the view with zeros on each dimension.
   *
   * `arg` is one `[before, after]` pair per axis; throws if the length does
   * not match `ndim` or any amount is negative. All-zero padding returns `this`.
   */
  pad(arg) {
    if (arg.length !== this.ndim || !arg.every(([b, e]) => b >= 0 && e >= 0)) throw new Error(`invalid pad ${jstr(arg)} for ${jstr(this.shape)}`);
    if (arg.every(([b, e]) => b === 0 && e === 0)) return this;
    // Resize to [-before, size + after) and mask off everything outside the
    // original data, which now sits at [before, before + size).
    const zvarg = arg.map(([b, e], i) => [-b, this.shape[i] + e]);
    const mask = arg.map(([b, _e], i) => [b, this.shape[i] + b]);
    return this.#unsafeResize(zvarg, mask);
  }
  /**
   * Shrink the view by taking a subarray.
   *
   * `arg` is one `[begin, end)` pair per axis; throws unless
   * `0 <= begin <= end <= shape[i]` on every axis.
   */
  shrink(arg) {
    if (arg.length !== this.ndim || !arg.every(([b, e], i) => 0 <= b && b <= e && e <= this.shape[i])) throw new Error(`invalid shrink ${jstr(arg)} for ${jstr(this.shape)}`);
    return this.#unsafeResize(arg);
  }
  // Shared implementation of pad() and shrink(): re-window each axis to
  // arg[i] = [begin, end) (begin may be negative) without validating `arg`.
  // The existing mask, if any, is shifted into the new window, clamped, and
  // intersected with the optional `mask` argument.
  #unsafeResize(arg, mask) {
    const offset = this.strides.map((s, i) => s * arg[i][0]).reduce((a, b) => a + b, 0);
    if (this.mask) {
      const nmask = this.mask.map(([mx, my], i) => [Math.max(0, Math.min(mx - arg[i][0], arg[i][1] - arg[i][0])), Math.max(0, Math.min(my - arg[i][0], arg[i][1] - arg[i][0]))]);
      mask = mask ? mask.map(([mx, my], i) => [Math.max(mx, nmask[i][0]), Math.min(my, nmask[i][1])]) : nmask;
    }
    return View.create(arg.map(([b, e]) => e - b), this.strides, this.offset + offset, mask);
  }
  /**
   * Expand one or more axes with length "1" by repeating the data.
   *
   * Throws if `newShape` has a different rank, or differs from the current
   * shape on an axis whose current length is not 1.
   */
  expand(newShape) {
    if (newShape.length !== this.ndim) throw new Error(`Can't expand ${jstr(this.shape)} into ${jstr(newShape)}`);
    for (let i = 0; i < this.ndim; i++) if (newShape[i] !== this.shape[i] && this.shape[i] !== 1) throw new Error(`Can't expand ${jstr(this.shape)} into ${jstr(newShape)}`);
    if (this.size === 0) return View.create(newShape);
    // On an expanded axis the old mask is either [0, 1] (its single element
    // is valid, so the whole new axis is) or anything else (so nothing is).
    const mask = this.mask ? this.mask.map((m, i) => this.shape[i] === newShape[i] ? m : m[0] === 0 && m[1] === 1 ? [0, newShape[i]] : [0, 0]) : null;
    return View.create(newShape, this.strides, this.offset, mask);
  }
  /**
   * Permute the axes of an array.
   *
   * New axis `i` is old axis `axis[i]`. Throws if `axis` is not a permutation
   * of length `ndim`.
   */
  permute(axis) {
    if (!isPermutation(axis, this.ndim)) throw new Error(`Invalid permutation ${jstr(axis)} of len ${this.ndim}`);
    const newShape = axis.map((a) => this.shape[a]);
    const newStrides = axis.map((a) => this.strides[a]);
    const newMask = this.mask ? axis.map((a) => this.mask[a]) : null;
    return View.create(newShape, newStrides, this.offset, newMask);
  }
  /**
   * Flip (reverse) one or more axes of the view.
   *
   * `arg` has one truthy/falsy entry per axis; throws on a length mismatch.
   */
  flip(arg) {
    if (arg.length !== this.ndim) throw new Error(`Invalid flip ${jstr(arg)} for ${jstr(this.shape)}`);
    const strides = this.strides.slice();
    let offset = this.offset;
    const mask = this.mask ? this.mask.slice() : null;
    for (let i = 0; i < this.ndim; i++) {
      const s = this.shape[i];
      if (arg[i]) {
        // Negate the stride and start from what used to be the last element;
        // the mask range is mirrored within the axis.
        strides[i] = -strides[i];
        offset += (s - 1) * this.strides[i];
        if (mask) mask[i] = [s - mask[i][1], s - mask[i][0]];
      }
    }
    return View.create(this.shape, strides, offset, mask);
  }
  /**
   * Reshape the view into a new shape.
   *
   * Returns `null` when the new shape cannot be expressed as a single view.
   * Throws if `newShape` has a negative entry or a different total size.
   */
  reshape(newShape) {
    if (deepEqual(this.shape, newShape)) return this;
    if (newShape.some((s) => s < 0)) throw new Error(`Reshape cannot have negative numbers ${jstr(newShape)}`);
    if (this.size !== prod(newShape)) throw new Error(`Reshape size ${jstr(this.shape)} -> ${jstr(newShape)}`);
    if (this.size === 0) return View.create(newShape);
    if (newShape.length === 0 && this.mask?.some(([b, e]) => b === e)) return null;
    if (this.contiguous) return View.create(newShape);
    // Assign strides to the new axes from innermost to outermost, one merged
    // group of old axes at a time.
    const rStrides = [];
    const merge = mergeDims(this.shape, this.strides, this.mask);
    let rShapeIdx = newShape.length;
    for (let i = merge.length - 1; i >= 0; i--) {
      let [mergedSize, newStride, realSize] = merge[i];
      let acc = 1;
      while (acc < mergedSize && rShapeIdx > 0) {
        const newDim = newShape[--rShapeIdx];
        rStrides.push(newStride * acc);
        acc *= newDim;
        // Once the group's real size is covered, remaining axes get stride 0.
        if (acc >= realSize) newStride = 0;
      }
      // The new axes must tile this merged group exactly.
      if (acc !== mergedSize) return null;
    }
    // Any new axes not consumed above get stride 0.
    const newStrides = rep(newShape.length - rStrides.length, 0).concat(rStrides.reverse());
    if (!this.mask) return View.create(newShape, newStrides, this.offset);
    const newMask = reshapeMask(this.mask, this.shape, newShape);
    if (!newMask) return null;
    // Move the offset from the old mask's starting corner to the new one's.
    let newOffset = this.offset;
    for (let i = 0; i < this.ndim; i++) newOffset += this.strides[i] * this.mask[i][0];
    for (let i = 0; i < newShape.length; i++) newOffset -= newStrides[i] * newMask[i][0];
    return View.create(newShape, newStrides, newOffset, newMask);
  }
};
|
|
1372
|
+
/**
 * Convert a flat `offset` into one coordinate per dimension of `shape`,
 * with the last dimension varying fastest. Behaves like
 * `numpy.unravel_index`.
 */
function unravel(shape, offset) {
  const coords = shape.map(() => 0);
  let stride = 1;
  // Fill from the innermost axis outwards, growing the stride as we go.
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    const extent = shape[axis];
    coords[axis] = Math.floor(offset / stride) % extent;
    stride *= extent;
  }
  return coords;
}
|
|
1386
|
+
/**
 * Symbolic counterpart of `unravel()`: builds one `AluExp` per dimension of
 * `shape` that extracts that dimension's coordinate from the `offset`
 * expression.
 */
function unravelAlu(shape, offset) {
  const exprs = shape.map(() => null);
  let stride = 1;
  for (let axis = shape.length - 1; axis >= 0; axis--) {
    const extent = shape[axis];
    const quotient = AluExp.idiv(offset, AluExp.i32(stride));
    exprs[axis] = AluExp.mod(quotient, AluExp.i32(extent));
    stride *= extent;
  }
  return exprs;
}
|
|
1397
|
+
/**
|
|
1398
|
+
* Array shape after applying movement operations, as a series of views.
|
|
1399
|
+
*
|
|
1400
|
+
* Each view is applied, then treated as if it were a contiguous array of its
|
|
1401
|
+
* shape, then used as the virtual buffer for the next view.
|
|
1402
|
+
*/
|
|
1403
|
+
var ShapeTracker = class ShapeTracker {
|
|
1404
|
+
constructor(views) {
|
|
1405
|
+
this.views = views;
|
|
1406
|
+
}
|
|
1407
|
+
/** Compose this shape tracker with another, applying it after this one. */
|
|
1408
|
+
compose(other) {
|
|
1409
|
+
if (this.contiguous) return other;
|
|
1410
|
+
let ret = this;
|
|
1411
|
+
for (const v of other.views) ret = new ShapeTracker(ret.views.concat(v)).simplify();
|
|
1412
|
+
return ret;
|
|
1413
|
+
}
|
|
1414
|
+
static fromShape(shape) {
|
|
1415
|
+
return new ShapeTracker([View.create(shape)]);
|
|
1416
|
+
}
|
|
1417
|
+
get contiguous() {
|
|
1418
|
+
return this.views.length === 1 && this.views[0].contiguous;
|
|
1419
|
+
}
|
|
1420
|
+
get consecutive() {
|
|
1421
|
+
return this.views.length === 1 && this.views[0].mask === null && deepEqual(this.views[0].strides, defaultStrides(this.views[0].shape));
|
|
1422
|
+
}
|
|
1423
|
+
get lastStrides() {
|
|
1424
|
+
return this.views[this.views.length - 1].strides;
|
|
1425
|
+
}
|
|
1426
|
+
get shape() {
|
|
1427
|
+
return this.views[this.views.length - 1].shape;
|
|
1428
|
+
}
|
|
1429
|
+
get size() {
|
|
1430
|
+
return this.views[this.views.length - 1].size;
|
|
1431
|
+
}
|
|
1432
|
+
toAluExp(idxs) {
|
|
1433
|
+
let [iexpr, vexpr] = this.views[this.views.length - 1].toAluExp(idxs);
|
|
1434
|
+
for (let i = this.views.length - 2; i >= 0; i--) {
|
|
1435
|
+
const view = this.views[i].minify();
|
|
1436
|
+
const exprs = view.toAluExp(unravelAlu(view.shape, iexpr));
|
|
1437
|
+
iexpr = exprs[0];
|
|
1438
|
+
vexpr = AluExp.mul(vexpr, exprs[1]);
|
|
1439
|
+
}
|
|
1440
|
+
return [iexpr.simplify(), vexpr.simplify()];
|
|
1441
|
+
}
|
|
1442
|
+
simplify() {
|
|
1443
|
+
const views = this.views.slice();
|
|
1444
|
+
while (views.length >= 2) {
|
|
1445
|
+
const newView = views[views.length - 2].compose(views[views.length - 1]);
|
|
1446
|
+
if (newView === null) break;
|
|
1447
|
+
views.splice(views.length - 2, 2, newView);
|
|
1448
|
+
}
|
|
1449
|
+
return new ShapeTracker(views);
|
|
1450
|
+
}
|
|
1451
|
+
pad(arg) {
|
|
1452
|
+
return new ShapeTracker(applyLast(this.views, (x) => x.pad(arg)));
|
|
1453
|
+
}
|
|
1454
|
+
shrink(arg) {
|
|
1455
|
+
return new ShapeTracker(applyLast(this.views, (x) => x.shrink(arg)));
|
|
1456
|
+
}
|
|
1457
|
+
expand(newShape) {
|
|
1458
|
+
return new ShapeTracker(applyLast(this.views, (x) => x.expand(newShape)));
|
|
1459
|
+
}
|
|
1460
|
+
permute(axis) {
|
|
1461
|
+
return new ShapeTracker(applyLast(this.views, (x) => x.permute(axis)));
|
|
1462
|
+
}
|
|
1463
|
+
flip(arg) {
|
|
1464
|
+
return new ShapeTracker(applyLast(this.views, (x) => x.flip(arg)));
|
|
1465
|
+
}
|
|
1466
|
+
reshape(newShape) {
|
|
1467
|
+
const newView = this.views[this.views.length - 1].reshape(newShape);
|
|
1468
|
+
return new ShapeTracker(newView === null ? this.views.concat(View.create(newShape)) : this.views.toSpliced(this.views.length - 1, 1, newView));
|
|
1469
|
+
}
|
|
1470
|
+
/** Broadcast along the given new axes, then expand the shape. */
|
|
1471
|
+
broadcast(newShape, axis) {
|
|
1472
|
+
let st = this;
|
|
1473
|
+
if (axis.length > 0) {
|
|
1474
|
+
const unsqueezed = [...st.shape];
|
|
1475
|
+
for (const i of axis.toSorted()) unsqueezed.splice(i, 0, 1);
|
|
1476
|
+
st = st.reshape(unsqueezed);
|
|
1477
|
+
}
|
|
1478
|
+
return st.expand(newShape);
|
|
1479
|
+
}
|
|
1480
|
+
};
|
|
1481
|
+
/** Return a copy of `ar` whose final element is replaced by `f(final element)`. */
function applyLast(ar, f) {
  const replacement = f(ar[ar.length - 1]);
  return [...ar.slice(0, -1), replacement];
}
|
|
1484
|
+
|
|
1485
|
+
//#endregion
|
|
1486
|
+
//#region src/tuner.ts
|
|
1487
|
+
/**
 * Stores dimensions of the kernel's applied shape. Globals start at 0.
 *
 * `groups`, `reduce`, `unroll` and `upcast` are boundary indices into
 * `st.shape`. `tuneWebgpu` indexes the sections between them as:
 * [0, groups) by "gidx", [groups, reduce) by "group", [reduce, unroll) by
 * "ridx", [unroll, upcast) by the unroll variable and [upcast, end) by the
 * upcast variable.
 */
var TuneDims = class {
  // Tracker for the full applied shape (all dimensions, reduction included).
  st;
  // Tracker for the output shape (the applied shape without its last dimension).
  outputSt;
  groups;
  reduce;
  unroll;
  upcast;
  /** Number of dimensions in `st`, i.e. the end of the last section. */
  get end() {
    return this.st.shape.length;
  }
  /**
   * Start from `shape`, whose last dimension is the only one placed in the
   * [reduce, unroll) section; the group, unroll and upcast sections are empty.
   */
  constructor(shape) {
    this.st = ShapeTracker.fromShape(shape);
    this.outputSt = ShapeTracker.fromShape(shape.slice(0, -1));
    this.groups = this.st.shape.length - 1;
    this.reduce = this.st.shape.length - 1;
    this.unroll = this.st.shape.length;
    this.upcast = this.st.shape.length;
  }
  /**
   * Localize `amount` elements of `axis` (which must lie before `groups`).
   * Unless the axis length equals `amount`, the axis is first split into
   * (length / amount, amount) and every boundary is shifted by one. The
   * affected dimension is then permuted in both trackers.
   * NOTE(review): assumes `range(a, b)` is half-open, in which case the
   * dimension lands at index `groups - 1` — confirm against `range`.
   */
  applyLocal(axis, amount) {
    if (axis >= this.groups) throw new Error("Cannot localize reduction axis");
    const length = this.st.shape[axis];
    if (length % amount !== 0) throw new Error(`Localize by ${amount} on axis length ${length}`);
    if (length !== amount) {
      this.groups++, this.reduce++, this.unroll++, this.upcast++;
      this.st = this.st.reshape([
        ...this.st.shape.slice(0, axis),
        length / amount,
        amount,
        ...this.st.shape.slice(axis + 1)
      ]);
      this.outputSt = this.outputSt.reshape([
        ...this.outputSt.shape.slice(0, axis),
        length / amount,
        amount,
        ...this.outputSt.shape.slice(axis + 1)
      ]);
      // From here on `axis` refers to the inner, `amount`-sized dimension.
      axis++;
    }
    this.st = this.st.permute([
      ...range(axis),
      ...range(axis + 1, this.groups),
      axis,
      ...range(this.groups, this.st.shape.length)
    ]);
    this.outputSt = this.outputSt.permute([
      ...range(axis),
      ...range(axis + 1, this.groups),
      axis,
      ...range(this.groups, this.outputSt.shape.length)
    ]);
  }
  /**
   * Split `axis` (which must lie before `groups`) into
   * (length / amount, amount) and permute the new inner dimension to the end
   * of both trackers. No boundary index changes.
   */
  applyUpcast(axis, amount) {
    if (axis >= this.groups) throw new Error("Cannot upcast along reduction axis");
    const length = this.st.shape[axis];
    if (length % amount !== 0) throw new Error(`Upcast by ${amount} on axis length ${length}`);
    // `this.st.shape.length + 1` is evaluated on the pre-reshape tracker, so
    // it equals the dimension count after the split.
    this.st = this.st.reshape([
      ...this.st.shape.slice(0, axis),
      length / amount,
      amount,
      ...this.st.shape.slice(axis + 1)
    ]).permute([
      ...range(axis + 1),
      ...range(axis + 2, this.st.shape.length + 1),
      axis + 1
    ]);
    this.outputSt = this.outputSt.reshape([
      ...this.outputSt.shape.slice(0, axis),
      length / amount,
      amount,
      ...this.outputSt.shape.slice(axis + 1)
    ]).permute([
      ...range(axis + 1),
      ...range(axis + 2, this.outputSt.shape.length + 1),
      axis + 1
    ]);
  }
  /**
   * Unroll `amount` elements of a reduction axis (`groups <= axis < unroll`).
   * Only `st` is changed; `outputSt` is untouched.
   */
  applyUnroll(axis, amount) {
    if (axis < this.groups) throw new Error("Cannot unroll non-reduce axis");
    if (axis >= this.unroll) throw new Error("Axis already unrolled");
    const length = this.st.shape[axis];
    if (length % amount !== 0) throw new Error(`Unroll by ${amount} on axis length ${length}`);
    if (length === amount) {
      // Whole axis is unrolled: move it just before `upcast` and shrink the
      // sections that used to contain it.
      this.st = this.st.permute([
        ...range(axis),
        ...range(axis + 1, this.upcast),
        axis,
        ...range(this.upcast, this.st.shape.length)
      ]);
      if (axis < this.reduce) this.reduce--;
      this.unroll--;
    } else {
      // Partial unroll: split the axis and move the inner `amount`-sized
      // dimension in front of the upcast section, which then starts one later.
      this.st = this.st.reshape([
        ...this.st.shape.slice(0, axis),
        length / amount,
        amount,
        ...this.st.shape.slice(axis + 1)
      ]).permute([
        ...range(axis + 1),
        ...range(axis + 2, this.upcast + 1),
        axis + 1,
        ...range(this.upcast + 1, this.st.shape.length + 1)
      ]);
      this.upcast++;
    }
  }
};
|
|
1594
|
+
/**
 * Tuning step that applies no optimization: every `GlobalView` node is
 * lowered with `accessorGlobal` using its own tracker and sources, and the
 * kernel runs with `kernel.size` threads.
 */
function tuneNullopt(kernel) {
  const bindings = { gidx: AluExp.special(DType.Int32, "gidx", kernel.size) };
  if (kernel.reduction) bindings.ridx = AluExp.special(DType.Int32, "ridx", kernel.reduction.size);
  const lowerView = (node) => {
    // Returning undefined leaves every other node to `rewrite`'s default handling.
    if (node.op !== AluOp.GlobalView) return void 0;
    const gid = node.arg[0];
    const tracker = node.arg[1];
    return accessorGlobal(node.dtype, gid, tracker, node.src);
  };
  const exp = kernel.exp.rewrite(lowerView).substitute(bindings).simplify();
  const outputIdxExp = AluExp.special(DType.Int32, "gidx", kernel.size);
  const reduce = kernel.reduction ? kernel.reduction.size : 0;
  return {
    exp,
    outputIdxExp,
    threadCount: kernel.size,
    size: { reduce }
  };
}
|
|
1612
|
+
/**
 * Tuning for WebGPU kernels.
 *
 * Falls back to `tuneNullopt` when the kernel has no reduction, contains
 * `GlobalIndex` ops, has no `GlobalView` ops, or when the `GlobalView`
 * sources do not match the plain (gidx..., ridx) indexing. Otherwise the
 * applied shape is restructured through `TuneDims` (upcast, unroll, local)
 * and every `GlobalView` is lowered against the restructured shape.
 *
 * @param kernel Kernel with `exp`, `size` and optional `reduction`.
 * @returns `{ exp, outputIdxExp, threadCount, size }` for the tuned kernel.
 * @throws Error if the views' shapes disagree with each other or with the
 *   reduction size.
 */
function tuneWebgpu(kernel) {
  const { exp, reduction } = kernel;
  if (!reduction) return tuneNullopt(kernel);
  const globalIndexes = exp.collect((node) => node.op === AluOp.GlobalIndex);
  if (globalIndexes.length > 0) {
    if (DEBUG >= 4) console.info("Tuning: Found GlobalIndex ops, skipping opt.");
    return tuneNullopt(kernel);
  }
  const globalViews = exp.collect((node) => node.op === AluOp.GlobalView);
  if (globalViews.length === 0) {
    if (DEBUG >= 4) console.info("Tuning: No GlobalView ops found in kernel.");
    return tuneNullopt(kernel);
  }
  const shape = globalViews[0].arg[1].shape;
  // Every view must be indexed by the unraveled gidx followed by ridx.
  const expectedSrc = [...unravelAlu(shape.slice(0, -1), AluVar.gidx), AluVar.ridx].map((e) => e.simplify());
  for (const gv of globalViews) {
    if (!gv.src.length || !deepEqual(gv.src, expectedSrc)) {
      if (DEBUG >= 4) console.info("Tuning: GlobalView src[] not consistent with reduction.");
      return tuneNullopt(kernel);
    }
  }
  if (shape[shape.length - 1] !== reduction.size) throw new Error("Invariant violation: shape doesn't match reduction size.");
  const sts = globalViews.map((gv) => gv.arg[1]);
  for (const st of sts) {
    if (!deepEqual(st.shape, shape)) throw new Error("Invariant violation: GlobalView shape mismatch");
  }
  const dim = new TuneDims(shape);
  const upcastedAxis = /* @__PURE__ */ new Set();
  // Upcast while the global section still has at least 1024 elements.
  while (prod(dim.st.shape.slice(0, dim.groups)) >= 1024) {
    const choices = [];
    const composedSts = sts.map((st) => st.compose(dim.st));
    for (let axis = 0; axis < dim.groups; axis++) {
      for (const amount of [3, 4]) {
        // Candidate: a not-yet-upcast axis divisible by `amount`, for which
        // some view has stride 0 on the axis and only positive strides from
        // `dim.unroll` onwards.
        const eligible = !upcastedAxis.has(axis) && dim.st.shape[axis] % amount === 0 && composedSts.some((st) => st.lastStrides[axis] === 0 && st.lastStrides.slice(dim.unroll).every((stride) => stride > 0));
        if (!eligible) continue;
        let nonzeroStrides = 0;
        let totalStrides = 0;
        for (const st of composedSts) {
          nonzeroStrides += st.lastStrides[axis] > 0 ? 1 : 0;
          totalStrides += st.lastStrides[axis];
        }
        choices.push([
          nonzeroStrides,
          totalStrides,
          axis,
          amount
        ]);
      }
    }
    if (choices.length === 0) break;
    // Take the first candidate in `lexCompare` order.
    choices.sort(lexCompare);
    dim.applyUpcast(choices[0][2], choices[0][3]);
    upcastedAxis.add(choices[0][2]);
  }
  // Unrolling is only attempted on Chrome. `navigator` is checked first so
  // that environments without that global skip the step instead of throwing.
  const isChrome = typeof navigator !== "undefined" && /chrome/i.test(navigator.userAgent);
  if (isChrome && dim.reduce < dim.unroll && (prod(dim.st.shape.slice(dim.unroll)) <= 4 || dim.unroll === dim.upcast && prod(dim.st.shape.slice(dim.upcast)) < 64)) {
    const s = dim.st.shape[dim.unroll - 1];
    if (s <= 32) dim.applyUnroll(dim.reduce, s);
    else {
      for (const splits of [4]) {
        if (s % splits === 0) {
          dim.applyUnroll(dim.unroll - 1, splits);
          break;
        }
      }
    }
  }
  // Visit upcast axes in ascending numeric order. The default sort compares
  // numbers as strings, which would visit axis 10 before axis 2.
  for (const ax of Array.from(upcastedAxis).sort((a, b) => a - b)) {
    const s = dim.st.shape[ax];
    for (const amount of [8, 4]) {
      if (s % amount === 0) {
        dim.applyLocal(ax, amount);
        break;
      }
    }
  }
  // Build one index expression per dimension of the tuned shape, section by section.
  const indices = [];
  const addIndices = (s, sectionExp) => {
    if (s.length === 0) return;
    else if (s.length === 1) indices.push(sectionExp);
    else indices.push(...unravelAlu(s, sectionExp));
  };
  if (0 < dim.groups) {
    const s = dim.st.shape.slice(0, dim.groups);
    addIndices(s, AluExp.special(DType.Int32, "gidx", prod(s)));
  }
  if (dim.groups < dim.reduce) {
    const s = dim.st.shape.slice(dim.groups, dim.reduce);
    addIndices(s, AluExp.special(DType.Int32, "group", prod(s)));
  }
  if (dim.reduce <= dim.unroll) {
    const s = dim.st.shape.slice(dim.reduce, dim.unroll);
    addIndices(s, AluExp.special(DType.Int32, "ridx", prod(s)));
  }
  if (dim.unroll < dim.upcast) {
    const s = dim.st.shape.slice(dim.unroll, dim.upcast);
    addIndices(s, AluVar.unroll);
  }
  if (dim.upcast < dim.end) {
    const s = dim.st.shape.slice(dim.upcast);
    addIndices(s, AluVar.upcast);
  }
  const newExp = exp.rewrite((node) => {
    if (node.op === AluOp.GlobalView) {
      const gid = node.arg[0];
      const st = node.arg[1];
      return accessorGlobal(node.dtype, gid, st.compose(dim.st), indices);
    }
  });
  const outputGidx = dim.outputSt.shape.slice(0, dim.groups);
  const outputUpcast = dim.outputSt.shape.slice(dim.groups);
  // Only the first expression returned by `toAluExp` is needed for the output.
  const [outputIdxExp] = dim.outputSt.toAluExp([...unravelAlu(outputGidx, AluExp.special(DType.Int32, "gidx", prod(outputGidx))), ...unravelAlu(outputUpcast, AluVar.upcast)]);
  if (prod(dim.st.shape.slice(dim.groups, dim.upcast)) !== reduction.size) throw new Error(`Invariant violation: reduction size ${reduction.size} does not match tuned dims ${JSON.stringify(dim.st.shape.slice(dim.groups, dim.upcast))}`);
  const size = {
    groups: prod(dim.st.shape.slice(dim.groups, dim.reduce)),
    reduce: prod(dim.st.shape.slice(dim.reduce, dim.unroll)),
    unroll: prod(dim.st.shape.slice(dim.unroll, dim.upcast)),
    upcast: prod(dim.st.shape.slice(dim.upcast))
  };
  return {
    exp: newExp.simplify(),
    outputIdxExp: outputIdxExp.simplify(),
    threadCount: kernel.size / size.upcast * size.groups,
    size
  };
}
|
|
1725
|
+
|
|
1726
|
+
//#endregion
|
|
1727
|
+
//#region src/backend/cpu.ts
|
|
1728
|
+
/**
 * Most basic implementation of `Backend` for testing: buffers are plain
 * `ArrayBuffer`s tracked by reference count, and kernels are evaluated one
 * element at a time.
 */
var CPUBackend = class {
  type = "cpu";
  maxArgs = Infinity;
  // slot number -> { buffer, ref }
  #buffers = /* @__PURE__ */ new Map();
  #nextSlot = 1;
  /** Allocate `size` bytes, optionally copying `initialData`; returns the new slot. */
  malloc(size, initialData) {
    const storage = new ArrayBuffer(size);
    if (initialData) {
      if (initialData.byteLength !== size) throw new Error("initialData size does not match buffer size");
      new Uint8Array(storage).set(new Uint8Array(initialData));
    }
    const slot = this.#nextSlot;
    this.#nextSlot += 1;
    this.#buffers.set(slot, { buffer: storage, ref: 1 });
    return slot;
  }
  /** Add one reference to `slot`. Throws `SlotError` for an unknown slot. */
  incRef(slot) {
    this.#entry(slot).ref += 1;
  }
  /** Drop one reference to `slot`, freeing the buffer when none remain. */
  decRef(slot) {
    const entry = this.#entry(slot);
    entry.ref -= 1;
    if (entry.ref === 0) this.#buffers.delete(slot);
  }
  /** Async form of `readSync`. */
  async read(slot, start, count) {
    return this.readSync(slot, start, count);
  }
  /** Copy `count` bytes starting at `start` (defaults: from 0 to the end of the buffer). */
  readSync(slot, start, count) {
    const storage = this.#getBuffer(slot);
    const from = start === void 0 ? 0 : start;
    const length = count === void 0 ? storage.byteLength - from : count;
    return storage.slice(from, from + length);
  }
  /** Async form of `prepareSync`. */
  async prepare(kernel) {
    return this.prepareSync(kernel);
  }
  /** Wrap the kernel in an `Executable` with no attached data. */
  prepareSync(kernel) {
    return new Executable(kernel, void 0);
  }
  /** Evaluate the kernel for every output index, writing into the first output buffer. */
  dispatch({ kernel }, inputs, outputs) {
    const { exp } = tuneNullopt(kernel);
    const inputBuffers = inputs.map((slot) => this.#getBuffer(slot));
    const outputBuffers = outputs.map((slot) => this.#getBuffer(slot));
    // dtype of each input that the expression actually reads, keyed by `arg`.
    const dtypeByArg = /* @__PURE__ */ new Map();
    for (const node of exp.collect((e) => e.op === AluOp.GlobalIndex)) dtypeByArg.set(node.arg, node.dtype);
    const inputArrays = inputBuffers.map((buf, i) => {
      const dtype = dtypeByArg.get(i);
      return dtype ? dtypedArray(dtype, buf) : null;
    });
    const outputArray = dtypedArray(kernel.dtype, outputBuffers[0]);
    const globals = (gid, bufidx) => {
      if (gid < 0 || gid >= inputArrays.length) throw new Error("gid out of bounds: " + gid);
      if (bufidx < 0 || bufidx >= inputArrays[gid].length) throw new Error("bufidx out of bounds: " + bufidx);
      return inputArrays[gid][bufidx];
    };
    if (!kernel.reduction) {
      for (let i = 0; i < kernel.size; i++) outputArray[i] = exp.evaluate({ gidx: i }, globals);
      return;
    }
    for (let i = 0; i < kernel.size; i++) {
      let acc = kernel.reduction.identity;
      for (let j = 0; j < kernel.reduction.size; j++) {
        const item = exp.evaluate({ gidx: i, ridx: j }, globals);
        acc = kernel.reduction.evaluate(acc, item);
      }
      outputArray[i] = kernel.reduction.fusion.evaluate({ acc });
    }
  }
  /** Look up the bookkeeping record for `slot`, or throw `SlotError`. */
  #entry(slot) {
    const entry = this.#buffers.get(slot);
    if (!entry) throw new SlotError(slot);
    return entry;
  }
  /** Return the `ArrayBuffer` stored in `slot`, or throw `SlotError`. */
  #getBuffer(slot) {
    return this.#entry(slot).buffer;
  }
};
|
|
1812
|
+
|
|
1813
|
+
//#endregion
|
|
1814
|
+
//#region src/backend.ts
|
|
1815
|
+
/** Names of the backends that `createBackend()` knows how to construct. */
const devices = ["cpu", "webgpu"];
/** Backend name used by `getBackend()` when no device is passed; changed via `setDevice()`. */
let defaultBackend = "cpu";
/** Backends that are ready for use, keyed by device name. */
const initializedBackends = /* @__PURE__ */ new Map();
// The CPU backend is registered at module load, so it is usable without calling `init()`.
initializedBackends.set("cpu", new CPUBackend());
|
|
1819
|
+
/**
 * Make `device` the default backend. The backend must already be
 * initialized, otherwise an error is thrown and the default is unchanged.
 */
function setDevice(device) {
  if (!initializedBackends.has(device)) throw new Error(`Backend not initialized: ${device}`);
  defaultBackend = device;
}
|
|
1824
|
+
/**
 * Initialize `jax-js` library backends.
 *
 * With no arguments, every known backend is attempted; otherwise only the
 * named ones. Backends that are already initialized are skipped, and the
 * rest are created concurrently. Resolves to the names of all backends that
 * are available afterwards.
 */
async function init(...devicesToInit) {
  const requested = devicesToInit.length === 0 ? devices : devicesToInit;
  const pending = [...new Set(requested)]
    .filter((name) => !initializedBackends.has(name))
    .map(async (name) => {
      const created = await createBackend(name);
      // `createBackend` yields a falsy value when the backend is unavailable.
      if (created) initializedBackends.set(name, created);
    });
  await Promise.all(pending);
  return [...initializedBackends.keys()];
}
|
|
1841
|
+
/**
 * Create a backend, if available. Internal function called by `init()`.
 *
 * @param device Backend name: "cpu" or "webgpu".
 * @returns A new backend instance, or `null` when WebGPU is unavailable (no
 *   `navigator.gpu`, no adapter, or the device request failed).
 * @throws Error if `device` is not a known backend name.
 */
async function createBackend(device) {
  if (device === "cpu") return new CPUBackend();
  if (device !== "webgpu") throw new Error(`Backend not found: ${device}`);
  // `navigator` is not defined in every JavaScript environment; treat its
  // absence as "WebGPU unavailable" rather than letting a ReferenceError
  // reject the whole `init()` call.
  if (typeof navigator === "undefined" || !navigator.gpu) return null;
  const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
  if (!adapter) return null;
  const { WebGPUBackend } = await import("./webgpu-JVpVad6g.js");
  // Limits requested at the values the adapter itself reports.
  const importantLimits = [
    "maxBufferSize",
    "maxComputeInvocationsPerWorkgroup",
    "maxComputeWorkgroupSizeX",
    "maxComputeWorkgroupSizeY",
    "maxComputeWorkgroupSizeZ",
    "maxComputeWorkgroupStorageSize",
    "maxComputeWorkgroupsPerDimension",
    "maxStorageBufferBindingSize",
    "maxStorageBuffersPerShaderStage",
    "maxStorageTexturesPerShaderStage"
  ];
  try {
    const requiredLimits = Object.fromEntries(importantLimits.map((limit) => [limit, adapter.limits[limit]]));
    const gpuDevice = await adapter.requestDevice({ requiredLimits });
    return new WebGPUBackend(gpuDevice);
  } catch (error) {
    // Deliberate best-effort: report the failure and mark WebGPU as unavailable.
    console.error("Unexpected error requesting WebGPU device:", error);
    return null;
  }
}
|
|
1870
|
+
/**
 * Look up an initialized backend by name, falling back to the default
 * backend when `device` is null or undefined. Throws if that backend has
 * not been initialized.
 */
function getBackend(device) {
  const name = device ?? defaultBackend;
  const found = initializedBackends.get(name);
  if (found) return found;
  throw new Error(`${name} backend not ready, call init() first`);
}
|
|
1877
|
+
/**
 * A kernel paired with whatever data was attached when it was prepared
 * (the CPU backend attaches none).
 */
var Executable = class {
  constructor(kernel, data) {
    Object.assign(this, { kernel, data });
  }
};
|
|
1883
|
+
/** Error thrown when a buffer slot is unknown to the backend or was already freed. */
var SlotError = class extends Error {
  constructor(slot) {
    const message = `Used a buffer that is invalid or already freed: ${slot}`;
    super(message);
  }
};
|
|
1888
|
+
|
|
1889
|
+
//#endregion
|
|
1890
|
+
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache, setDevice, strip1, toposort, tuneWebgpu, unravelAlu, unzip2, zip };
|