@jax-js/jax 0.0.3 → 0.0.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +96 -22
- package/dist/{backend-BqDtPGaR.js → backend-CdcTZEOF.js} +325 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-yEU0L_ig.cjs} +350 -154
- package/dist/index.cjs +977 -354
- package/dist/index.d.cts +479 -88
- package/dist/index.d.ts +479 -88
- package/dist/index.js +964 -345
- package/dist/{webgpu-CNg9JGva.js → webgpu-CM-xNYzW.js} +9 -3
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-CNOpiO5T.cjs} +9 -3
- package/package.json +15 -4
|
@@ -36,7 +36,22 @@ var PPrint = class PPrint {
|
|
|
36
36
|
//#endregion
|
|
37
37
|
//#region src/utils.ts
|
|
38
38
|
/** @file Generic programming utilities with no dependencies on library code. */
|
|
39
|
-
|
|
39
|
+
let DEBUG = 0;
|
|
40
|
+
/**
|
|
41
|
+
* Set the debug level for verbose logging.
|
|
42
|
+
*
|
|
43
|
+
* 1. JIT compile logs
|
|
44
|
+
* 2. Shader code
|
|
45
|
+
* 3. Expressions and metadata
|
|
46
|
+
* 4. JIT programs, tuning details
|
|
47
|
+
* 5. Most verbose operation traces
|
|
48
|
+
*
|
|
49
|
+
* This is an experimental API and may change in behavior. Do not rely on this
|
|
50
|
+
* in production.
|
|
51
|
+
*/
|
|
52
|
+
function setDebug(level) {
|
|
53
|
+
DEBUG = level;
|
|
54
|
+
}
|
|
40
55
|
function unzip2(pairs) {
|
|
41
56
|
const lst1 = [];
|
|
42
57
|
const lst2 = [];
|
|
@@ -110,9 +125,23 @@ function isNumberPair(x) {
|
|
|
110
125
|
}
|
|
111
126
|
/**
 * Validate an axis index against an array rank and resolve negative axes.
 *
 * @param {number} axis - Axis index; negative values count back from the end.
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number} `axis + ndim` when `axis` is negative, otherwise `axis`.
 * @throws {Error} If `axis` is below `-ndim` or not below `ndim`.
 */
function checkAxis(axis, ndim) {
  const outOfBounds = axis < -ndim || axis >= ndim;
  if (outOfBounds) {
    throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
  }
  if (axis < 0) {
    return axis + ndim;
  }
  return axis;
}
|
|
131
|
+
/**
 * Normalize the common `axis` argument of a function into a list of axes.
 *
 * - `null` selects every axis, i.e. `range(ndim)`.
 * - A single number is validated with `checkAxis` and wrapped in an array.
 * - Any other value is iterated as a collection of axes; each entry is
 *   validated with `checkAxis`, and the resolved axes are returned in
 *   ascending numeric order.
 *
 * @param {number | Iterable<number> | null} axis - Axis specification.
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number[]} Resolved (non-negative) axes.
 * @throws {Error} If an axis is out of bounds (raised by `checkAxis`) or if
 *   two entries resolve to the same axis.
 */
function normalizeAxis(axis, ndim) {
  if (axis === null) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = /* @__PURE__ */ new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // Use a numeric comparator: the default sort compares string forms, which
  // would order axis 10 before axis 2 for arrays with more than 10 dimensions.
  return [...seen].sort((x, y) => x - y);
}
|
|
116
145
|
function range(start, stop, step = 1) {
|
|
117
146
|
if (stop === void 0) {
|
|
118
147
|
stop = start;
|
|
@@ -178,6 +207,35 @@ function findPow2(hint, max) {
|
|
|
178
207
|
while (ret < hint && 2 * ret <= max) ret *= 2;
|
|
179
208
|
return ret;
|
|
180
209
|
}
|
|
210
|
+
/**
|
|
211
|
+
* Implements a NumPy-style generalized broadcast rule on two array shapes.
|
|
212
|
+
*
|
|
213
|
+
* "When operating on two arrays, NumPy compares their shapes element-wise. It
|
|
214
|
+
* starts with the trailing (i.e. rightmost) dimension and works its way left.
|
|
215
|
+
* Two dimensions are compatible when:
|
|
216
|
+
* 1. they are equal, or
|
|
217
|
+
* 2. one of them is 1."
|
|
218
|
+
*
|
|
219
|
+
* Throws a TypeError if the broadcast is not possible.
|
|
220
|
+
*
|
|
221
|
+
* <https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules>
|
|
222
|
+
*/
|
|
223
|
+
function generalBroadcast(a, b) {
  // Right-align both shapes: position k of the result corresponds to index
  // k - padA in `a` and k - padB in `b` (a negative index means "missing").
  const rank = Math.max(a.length, b.length);
  const padA = rank - a.length;
  const padB = rank - b.length;
  const shape = new Array(rank);
  for (let k = rank - 1; k >= 0; k--) {
    const ia = k - padA;
    const ib = k - padB;
    if (ia < 0) {
      // Only `b` has this leading dimension.
      shape[k] = b[ib];
      continue;
    }
    if (ib < 0) {
      // Only `a` has this leading dimension.
      shape[k] = a[ia];
      continue;
    }
    const dimA = a[ia];
    const dimB = b[ib];
    if (dimA === dimB || dimB === 1) {
      shape[k] = dimA;
    } else if (dimA === 1) {
      shape[k] = dimB;
    } else {
      throw new TypeError(`Incompatible array broadcast shapes: ${a} vs ${b}`);
    }
  }
  return shape;
}
|
|
181
239
|
function recursiveFlatten(ar) {
|
|
182
240
|
if (!Array.isArray(ar)) return [ar];
|
|
183
241
|
return ar.flat(Infinity);
|
|
@@ -187,6 +245,7 @@ function strip1(str) {
|
|
|
187
245
|
if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
|
|
188
246
|
return str;
|
|
189
247
|
}
|
|
248
|
+
const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(8));
|
|
190
249
|
/**
|
|
191
250
|
* Polynomial hashes modulo p are good at avoiding collisions in expectation.
|
|
192
251
|
* Probability-wise, it's good enough to be used for something like
|
|
@@ -201,22 +260,26 @@ var FpHash = class FpHash {
|
|
|
201
260
|
const modulus = 3189051996290219n;
|
|
202
261
|
this.value = (this.value * base + x) % modulus;
|
|
203
262
|
}
|
|
204
|
-
update(
|
|
205
|
-
|
|
206
|
-
|
|
263
|
+
update(x) {
|
|
264
|
+
if (typeof x === "string") {
|
|
265
|
+
this.#update(BigInt(x.length));
|
|
266
|
+
for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
|
|
267
|
+
} else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
|
|
207
268
|
else {
|
|
208
|
-
|
|
209
|
-
this.#update(
|
|
269
|
+
_stagingbuf.setFloat64(0, x, true);
|
|
270
|
+
this.#update(_stagingbuf.getBigUint64(0, true));
|
|
210
271
|
}
|
|
211
272
|
else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
|
|
212
273
|
else if (typeof x === "bigint") this.#update(x ^ 71657401n);
|
|
213
274
|
else if (x === null) this.#update(37832657n);
|
|
214
275
|
else if (x === void 0) this.#update(18145117n);
|
|
215
|
-
else
|
|
276
|
+
else x.hash(this);
|
|
216
277
|
return this;
|
|
217
278
|
}
|
|
218
279
|
static hash(...values) {
|
|
219
|
-
|
|
280
|
+
const h = new FpHash();
|
|
281
|
+
for (const x of values) h.update(x);
|
|
282
|
+
return h.value;
|
|
220
283
|
}
|
|
221
284
|
};
|
|
222
285
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -251,6 +314,41 @@ const byteWidth = (dtype) => {
|
|
|
251
314
|
}
|
|
252
315
|
};
|
|
253
316
|
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
317
|
+
/**
|
|
318
|
+
* Promote two dtypes to their join according to the type lattice.
|
|
319
|
+
*
|
|
320
|
+
* When performing operations between arrays of different types, we need to
|
|
321
|
+
* promote both operands to a common type that can represent values from both
|
|
322
|
+
* input types. This follows JAX's type promotion rules.
|
|
323
|
+
*
|
|
324
|
+
* **Type lattice:**
|
|
325
|
+
* ```text
|
|
326
|
+
* bool -> uint32 -> int32 -> float16 -> float32
|
|
327
|
+
* weakType --^
|
|
328
|
+
* ```
|
|
329
|
+
*
|
|
330
|
+
* `weakType` represents weakly typed arrays. These are created for JS numbers,
|
|
331
|
+
* which default to float32 but "weak" so they cast to the dtype of any array
|
|
332
|
+
* they are first combined with, except `bool`.
|
|
333
|
+
*
|
|
334
|
+
* **Examples:**
|
|
335
|
+
* - `promoteTypes(bool, int32) → int32`
|
|
336
|
+
* - `promoteTypes(uint32, int32) → int32`
|
|
337
|
+
* - `promoteTypes(int32, float16) → float16`
|
|
338
|
+
* - `promoteTypes(float16, float32) → float32`
|
|
339
|
+
* - `promoteTypes(uint32, float32) → float32`
|
|
340
|
+
*/
|
|
341
|
+
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) {
    return dtype1;
  }
  // Dtypes listed from lowest to highest; the list index is the rank.
  const ascending = [DType.Bool, DType.Uint32, DType.Int32, DType.Float16, DType.Float32];
  const rank = Object.fromEntries(ascending.map((dtype, position) => [dtype, position]));
  // A dtype missing from the table has rank `undefined`, which makes the
  // comparison false, so `dtype2` is returned in that case.
  const firstIsHigher = rank[dtype1] > rank[dtype2];
  if (firstIsHigher) {
    return dtype1;
  }
  return dtype2;
}
|
|
254
352
|
function dtypedArray(dtype, data) {
|
|
255
353
|
const { buffer, byteLength, byteOffset } = data;
|
|
256
354
|
const length = byteLength / byteWidth(dtype);
|
|
@@ -320,6 +418,12 @@ var AluExp = class AluExp {
|
|
|
320
418
|
static cos(a) {
|
|
321
419
|
return new AluExp(AluOp.Cos, a.dtype, [a]);
|
|
322
420
|
}
|
|
421
|
+
static asin(a) {
|
|
422
|
+
return new AluExp(AluOp.Asin, a.dtype, [a]);
|
|
423
|
+
}
|
|
424
|
+
static atan(a) {
|
|
425
|
+
return new AluExp(AluOp.Atan, a.dtype, [a]);
|
|
426
|
+
}
|
|
323
427
|
static exp(a) {
|
|
324
428
|
return new AluExp(AluOp.Exp, a.dtype, [a]);
|
|
325
429
|
}
|
|
@@ -403,8 +507,11 @@ var AluExp = class AluExp {
|
|
|
403
507
|
getHash() {
|
|
404
508
|
if (this.#hash !== void 0) return this.#hash;
|
|
405
509
|
const hasher = new FpHash();
|
|
406
|
-
hasher.update(this.op
|
|
407
|
-
hasher.update(this.
|
|
510
|
+
hasher.update(this.op);
|
|
511
|
+
hasher.update(this.dtype);
|
|
512
|
+
hasher.update(JSON.stringify(this.arg));
|
|
513
|
+
hasher.update(this.src.length);
|
|
514
|
+
for (const s of this.src) hasher.update(s);
|
|
408
515
|
this.#hash = hasher.value;
|
|
409
516
|
return this.#hash;
|
|
410
517
|
}
|
|
@@ -476,10 +583,16 @@ var AluExp = class AluExp {
|
|
|
476
583
|
ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
|
|
477
584
|
break;
|
|
478
585
|
case AluOp.Sin:
|
|
479
|
-
ret = [
|
|
586
|
+
ret = [-1, 1];
|
|
480
587
|
break;
|
|
481
588
|
case AluOp.Cos:
|
|
482
|
-
ret = [
|
|
589
|
+
ret = [-1, 1];
|
|
590
|
+
break;
|
|
591
|
+
case AluOp.Asin:
|
|
592
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
593
|
+
break;
|
|
594
|
+
case AluOp.Atan:
|
|
595
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
483
596
|
break;
|
|
484
597
|
case AluOp.Exp:
|
|
485
598
|
ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
|
|
@@ -595,11 +708,12 @@ var AluExp = class AluExp {
|
|
|
595
708
|
simplify(cache = /* @__PURE__ */ new Map()) {
|
|
596
709
|
if (this.#simplified !== void 0) return this.#simplified;
|
|
597
710
|
const hash = this.getHash();
|
|
598
|
-
|
|
711
|
+
const prevCachedValue = cache.get(hash);
|
|
712
|
+
if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
|
|
599
713
|
const simplified = this.#simplifyInner(cache);
|
|
600
714
|
const simplifiedHash = simplified.getHash();
|
|
601
|
-
|
|
602
|
-
|
|
715
|
+
const prevSimplified = cache.get(simplifiedHash);
|
|
716
|
+
if (prevSimplified !== void 0) {
|
|
603
717
|
cache.set(hash, prevSimplified);
|
|
604
718
|
this.#simplified = prevSimplified;
|
|
605
719
|
return prevSimplified;
|
|
@@ -803,6 +917,8 @@ var AluExp = class AluExp {
|
|
|
803
917
|
switch (this.op) {
|
|
804
918
|
case AluOp.Sin: return Math.sin(x);
|
|
805
919
|
case AluOp.Cos: return Math.cos(x);
|
|
920
|
+
case AluOp.Asin: return Math.asin(x);
|
|
921
|
+
case AluOp.Atan: return Math.atan(x);
|
|
806
922
|
case AluOp.Exp: return Math.exp(x);
|
|
807
923
|
case AluOp.Log: return Math.log(x);
|
|
808
924
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
@@ -982,6 +1098,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
982
1098
|
AluOp$1["Max"] = "Max";
|
|
983
1099
|
AluOp$1["Sin"] = "Sin";
|
|
984
1100
|
AluOp$1["Cos"] = "Cos";
|
|
1101
|
+
AluOp$1["Asin"] = "Asin";
|
|
1102
|
+
AluOp$1["Atan"] = "Atan";
|
|
985
1103
|
AluOp$1["Exp"] = "Exp";
|
|
986
1104
|
AluOp$1["Log"] = "Log";
|
|
987
1105
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
@@ -1012,6 +1130,8 @@ const AluGroup = {
|
|
|
1012
1130
|
Unary: new Set([
|
|
1013
1131
|
AluOp.Sin,
|
|
1014
1132
|
AluOp.Cos,
|
|
1133
|
+
AluOp.Asin,
|
|
1134
|
+
AluOp.Atan,
|
|
1015
1135
|
AluOp.Exp,
|
|
1016
1136
|
AluOp.Log,
|
|
1017
1137
|
AluOp.Sqrt,
|
|
@@ -1035,6 +1155,8 @@ const AluGroup = {
|
|
|
1035
1155
|
RequiredFloat: new Set([
|
|
1036
1156
|
AluOp.Sin,
|
|
1037
1157
|
AluOp.Cos,
|
|
1158
|
+
AluOp.Asin,
|
|
1159
|
+
AluOp.Atan,
|
|
1038
1160
|
AluOp.Exp,
|
|
1039
1161
|
AluOp.Log,
|
|
1040
1162
|
AluOp.Sqrt,
|
|
@@ -1066,7 +1188,7 @@ var Kernel = class {
|
|
|
1066
1188
|
this.exp = exp.simplify();
|
|
1067
1189
|
}
|
|
1068
1190
|
hash(state) {
|
|
1069
|
-
state.update(this.nargs
|
|
1191
|
+
state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
|
|
1070
1192
|
}
|
|
1071
1193
|
pprint() {
|
|
1072
1194
|
let details = PPrint.pp(`exp = ${this.exp}`);
|
|
@@ -1112,7 +1234,7 @@ var Reduction = class {
|
|
|
1112
1234
|
this.epilogue = epilogue.simplify();
|
|
1113
1235
|
}
|
|
1114
1236
|
hash(state) {
|
|
1115
|
-
state.update(this.dtype
|
|
1237
|
+
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
1116
1238
|
}
|
|
1117
1239
|
toString() {
|
|
1118
1240
|
return `${this.op}{${this.size}} -> ${this.epilogue}`;
|
|
@@ -2284,78 +2406,92 @@ function wasm_log(cg) {
|
|
|
2284
2406
|
});
|
|
2285
2407
|
}
|
|
2286
2408
|
/**
|
|
2287
|
-
*
|
|
2409
|
+
* Common helper to approximate sin(x) and cos(x).
|
|
2288
2410
|
*
|
|
2289
2411
|
* Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
|
|
2290
|
-
* z = y - q*(π/2); use
|
|
2412
|
+
* z = y - q*(π/2); use one of two polynomials on z:
|
|
2291
2413
|
* sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
|
|
2414
|
+
* cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
|
|
2415
|
+
*/
|
|
2416
|
+
function _sincos(cg) {
  // Locals are declared in the same order as before, in case the code
  // generator assigns indices sequentially.
  const y = cg.local.declare(cg.f32);
  const qf = cg.local.declare(cg.f32);
  const q = cg.local.declare(cg.i32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const sz = cg.local.declare(cg.f32);
  const cz = cg.local.declare(cg.f32);

  const TWO_PI = 2 * Math.PI;
  const HALF_PI = Math.PI / 2;

  // Emit a Horner-form polynomial in z2; coefficients run from the highest
  // power down to the constant term.
  const emitPolyInZ2 = ([leading, ...lower]) => {
    cg.f32.const(leading);
    for (const coeff of lower) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(coeff);
      cg.f32.add();
    }
  };

  // y = x - nearest(x / 2π) * 2π   (qf is used as scratch here)
  cg.local.get(0);
  cg.local.get(0);
  cg.f32.const(1 / TWO_PI);
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  cg.f32.const(TWO_PI);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.set(y);

  // qf = nearest(y / (π/2)); q = qf truncated to i32
  cg.local.get(y);
  cg.f32.const(2 / Math.PI);
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  cg.i32.trunc_f32_s();
  cg.local.set(q);

  // z = y - qf * (π/2); z2 = z * z
  cg.local.get(y);
  cg.local.get(qf);
  cg.f32.const(HALF_PI);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.tee(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // sz = z * (1 - z2/6 + z2^2/120 - z2^3/5040)
  emitPolyInZ2([-1 / 5040, 1 / 120, -1 / 6, 1]);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(sz);

  // cz = 1 - z2/2 + z2^2/24 - z2^3/720
  emitPolyInZ2([-1 / 720, 1 / 24, -1 / 2, 1]);
  cg.local.set(cz);

  return { q, sz, cz };
}
|
|
2486
|
+
/**
|
|
2487
|
+
* Approximate sin(x).
|
|
2488
|
+
*
|
|
2489
|
+
* Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
|
|
2292
2490
|
*/
|
|
2293
2491
|
function wasm_sin(cg) {
|
|
2294
2492
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2295
|
-
const
|
|
2296
|
-
const qf = cg.local.declare(cg.f32);
|
|
2297
|
-
const q = cg.local.declare(cg.i32);
|
|
2298
|
-
const z = cg.local.declare(cg.f32);
|
|
2299
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2300
|
-
const sz = cg.local.declare(cg.f32);
|
|
2301
|
-
const cz = cg.local.declare(cg.f32);
|
|
2493
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2302
2494
|
const mag = cg.local.declare(cg.f32);
|
|
2303
|
-
cg.local.get(0);
|
|
2304
|
-
cg.local.get(0);
|
|
2305
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2306
|
-
cg.f32.mul();
|
|
2307
|
-
cg.f32.nearest();
|
|
2308
|
-
cg.local.tee(qf);
|
|
2309
|
-
cg.f32.const(2 * Math.PI);
|
|
2310
|
-
cg.f32.mul();
|
|
2311
|
-
cg.f32.sub();
|
|
2312
|
-
cg.local.set(y);
|
|
2313
|
-
cg.local.get(y);
|
|
2314
|
-
cg.f32.const(2 / Math.PI);
|
|
2315
|
-
cg.f32.mul();
|
|
2316
|
-
cg.f32.nearest();
|
|
2317
|
-
cg.local.tee(qf);
|
|
2318
|
-
cg.i32.trunc_f32_s();
|
|
2319
|
-
cg.local.set(q);
|
|
2320
|
-
cg.local.get(y);
|
|
2321
|
-
cg.local.get(qf);
|
|
2322
|
-
cg.f32.const(Math.PI / 2);
|
|
2323
|
-
cg.f32.mul();
|
|
2324
|
-
cg.f32.sub();
|
|
2325
|
-
cg.local.tee(z);
|
|
2326
|
-
cg.local.get(z);
|
|
2327
|
-
cg.f32.mul();
|
|
2328
|
-
cg.local.set(z2);
|
|
2329
|
-
cg.f32.const(-1 / 5040);
|
|
2330
|
-
cg.local.get(z2);
|
|
2331
|
-
cg.f32.mul();
|
|
2332
|
-
cg.f32.const(1 / 120);
|
|
2333
|
-
cg.f32.add();
|
|
2334
|
-
cg.local.get(z2);
|
|
2335
|
-
cg.f32.mul();
|
|
2336
|
-
cg.f32.const(-1 / 6);
|
|
2337
|
-
cg.f32.add();
|
|
2338
|
-
cg.local.get(z2);
|
|
2339
|
-
cg.f32.mul();
|
|
2340
|
-
cg.f32.const(1);
|
|
2341
|
-
cg.f32.add();
|
|
2342
|
-
cg.local.get(z);
|
|
2343
|
-
cg.f32.mul();
|
|
2344
|
-
cg.local.set(sz);
|
|
2345
|
-
cg.f32.const(-1 / 720);
|
|
2346
|
-
cg.local.get(z2);
|
|
2347
|
-
cg.f32.mul();
|
|
2348
|
-
cg.f32.const(1 / 24);
|
|
2349
|
-
cg.f32.add();
|
|
2350
|
-
cg.local.get(z2);
|
|
2351
|
-
cg.f32.mul();
|
|
2352
|
-
cg.f32.const(-1 / 2);
|
|
2353
|
-
cg.f32.add();
|
|
2354
|
-
cg.local.get(z2);
|
|
2355
|
-
cg.f32.mul();
|
|
2356
|
-
cg.f32.const(1);
|
|
2357
|
-
cg.f32.add();
|
|
2358
|
-
cg.local.set(cz);
|
|
2359
2495
|
cg.local.get(cz);
|
|
2360
2496
|
cg.local.get(sz);
|
|
2361
2497
|
cg.local.get(q);
|
|
@@ -2374,75 +2510,12 @@ function wasm_sin(cg) {
|
|
|
2374
2510
|
/**
|
|
2375
2511
|
* Approximate cos(x).
|
|
2376
2512
|
*
|
|
2377
|
-
*
|
|
2378
|
-
* k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2513
|
+
* Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2379
2514
|
*/
|
|
2380
2515
|
function wasm_cos(cg) {
|
|
2381
2516
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2382
|
-
const
|
|
2383
|
-
const qf = cg.local.declare(cg.f32);
|
|
2384
|
-
const q = cg.local.declare(cg.i32);
|
|
2385
|
-
const z = cg.local.declare(cg.f32);
|
|
2386
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2387
|
-
const sz = cg.local.declare(cg.f32);
|
|
2388
|
-
const cz = cg.local.declare(cg.f32);
|
|
2517
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2389
2518
|
const mag = cg.local.declare(cg.f32);
|
|
2390
|
-
cg.local.get(0);
|
|
2391
|
-
cg.local.get(0);
|
|
2392
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2393
|
-
cg.f32.mul();
|
|
2394
|
-
cg.f32.nearest();
|
|
2395
|
-
cg.local.tee(qf);
|
|
2396
|
-
cg.f32.const(2 * Math.PI);
|
|
2397
|
-
cg.f32.mul();
|
|
2398
|
-
cg.f32.sub();
|
|
2399
|
-
cg.local.set(y);
|
|
2400
|
-
cg.local.get(y);
|
|
2401
|
-
cg.f32.const(2 / Math.PI);
|
|
2402
|
-
cg.f32.mul();
|
|
2403
|
-
cg.f32.nearest();
|
|
2404
|
-
cg.local.tee(qf);
|
|
2405
|
-
cg.i32.trunc_f32_s();
|
|
2406
|
-
cg.local.set(q);
|
|
2407
|
-
cg.local.get(y);
|
|
2408
|
-
cg.local.get(qf);
|
|
2409
|
-
cg.f32.const(Math.PI / 2);
|
|
2410
|
-
cg.f32.mul();
|
|
2411
|
-
cg.f32.sub();
|
|
2412
|
-
cg.local.tee(z);
|
|
2413
|
-
cg.local.get(z);
|
|
2414
|
-
cg.f32.mul();
|
|
2415
|
-
cg.local.set(z2);
|
|
2416
|
-
cg.f32.const(-1 / 5040);
|
|
2417
|
-
cg.local.get(z2);
|
|
2418
|
-
cg.f32.mul();
|
|
2419
|
-
cg.f32.const(1 / 120);
|
|
2420
|
-
cg.f32.add();
|
|
2421
|
-
cg.local.get(z2);
|
|
2422
|
-
cg.f32.mul();
|
|
2423
|
-
cg.f32.const(-1 / 6);
|
|
2424
|
-
cg.f32.add();
|
|
2425
|
-
cg.local.get(z2);
|
|
2426
|
-
cg.f32.mul();
|
|
2427
|
-
cg.f32.const(1);
|
|
2428
|
-
cg.f32.add();
|
|
2429
|
-
cg.local.get(z);
|
|
2430
|
-
cg.f32.mul();
|
|
2431
|
-
cg.local.set(sz);
|
|
2432
|
-
cg.f32.const(-1 / 720);
|
|
2433
|
-
cg.local.get(z2);
|
|
2434
|
-
cg.f32.mul();
|
|
2435
|
-
cg.f32.const(1 / 24);
|
|
2436
|
-
cg.f32.add();
|
|
2437
|
-
cg.local.get(z2);
|
|
2438
|
-
cg.f32.mul();
|
|
2439
|
-
cg.f32.const(-1 / 2);
|
|
2440
|
-
cg.f32.add();
|
|
2441
|
-
cg.local.get(z2);
|
|
2442
|
-
cg.f32.mul();
|
|
2443
|
-
cg.f32.const(1);
|
|
2444
|
-
cg.f32.add();
|
|
2445
|
-
cg.local.set(cz);
|
|
2446
2519
|
cg.local.get(sz);
|
|
2447
2520
|
cg.local.get(cz);
|
|
2448
2521
|
cg.local.get(q);
|
|
@@ -2460,6 +2533,100 @@ function wasm_cos(cg) {
|
|
|
2460
2533
|
cg.select();
|
|
2461
2534
|
});
|
|
2462
2535
|
}
|
|
2536
|
+
/**
 * Helper that emits the instruction sequence approximating arctan(x).
 *
 * The first emitted instruction is `local.set`, so the caller must already
 * have emitted code that produces the argument (the callers in this file emit
 * `local.get(0)` first). The last emitted instruction is `f32.copysign`; the
 * function itself returns nothing.
 *
 * @param cg - Code generator used to declare locals and emit instructions.
 */
function _atan(cg) {
  // Locals are declared in the same order as before, in case the code
  // generator assigns indices sequentially.
  const x = cg.local.declare(cg.f32);
  const absX = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const ratio = cg.local.declare(cg.f32);

  // Emit a Horner-form polynomial in z2; coefficients run from the highest
  // power down to the constant term.
  const emitPolyInZ2 = ([leading, ...lower]) => {
    cg.f32.const(leading);
    for (const coeff of lower) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(coeff);
      cg.f32.add();
    }
  };

  // x = <argument>; absX = |x|
  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(absX);

  // z = select(1 / absX, absX, absX >= 1)
  cg.f32.const(1);
  cg.local.get(absX);
  cg.f32.div();
  cg.local.get(absX);
  cg.local.get(absX);
  cg.f32.const(1);
  cg.f32.ge();
  cg.select();
  cg.local.set(z);

  // z2 = z * z
  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // ratio = z * P(z2) / Q(z2): numerator polynomial first, then denominator.
  emitPolyInZ2([0.0415796528637, 0.661705427875, 0.999998614341]);
  emitPolyInZ2([0.173698870181, 0.994987933645, 1]);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(ratio);

  // select(π/2 - ratio, ratio, absX >= 1), then copy the sign of x onto it.
  cg.f32.const(Math.PI / 2);
  cg.local.get(ratio);
  cg.f32.sub();
  cg.local.get(ratio);
  cg.local.get(absX);
  cg.f32.const(1);
  cg.f32.ge();
  cg.select();
  cg.local.get(x);
  cg.f32.copysign();
}
|
|
2593
|
+
/**
|
|
2594
|
+
* Approximate atan(x).
|
|
2595
|
+
*
|
|
2596
|
+
* Method: if |x| < 1, use rational approximation: atan(x) ≈ x * P(x^2) / Q(x^2)
|
|
2597
|
+
* where P(u) = A0 + A1*u + A2*u^2 (degree 2)
|
|
2598
|
+
* Q(u) = 1 + B1*u + B2*u^2 (degree 2)
|
|
2599
|
+
* if |x| >= 1, use: atan(x) = sign(x)*π/2 - atan(1/x)
|
|
2600
|
+
* (fitted coefficients, max error ~5e-7 on [0,1])
|
|
2601
|
+
*/
|
|
2602
|
+
function wasm_atan(cg) {
  // Body of the generated function: load parameter 0, then let `_atan`
  // emit the approximation that consumes it.
  const emitBody = () => {
    cg.local.get(0);
    _atan(cg);
  };
  // One f32 parameter, one f32 result.
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2608
|
+
/**
|
|
2609
|
+
* Approximate asin(x).
|
|
2610
|
+
*
|
|
2611
|
+
* Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
|
|
2612
|
+
*/
|
|
2613
|
+
function wasm_asin(cg) {
  const emitBody = () => {
    // Numerator: x
    cg.local.get(0);

    // Denominator: sqrt(1 - x * x) + 1
    cg.f32.const(1);
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.mul();
    cg.f32.sub();
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();

    // x / denominator, passed through the shared atan helper, then doubled.
    cg.f32.div();
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  };
  // One f32 parameter, one f32 result.
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2463
2630
|
/**
|
|
2464
2631
|
* Threefry2x32 pseudorandom number generator.
|
|
2465
2632
|
*
|
|
@@ -3339,6 +3506,8 @@ function codegenWasm(kernel) {
|
|
|
3339
3506
|
const funcs = {};
|
|
3340
3507
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3341
3508
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3509
|
+
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3510
|
+
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3342
3511
|
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3343
3512
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3344
3513
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
@@ -3490,6 +3659,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3490
3659
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3491
3660
|
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
|
|
3492
3661
|
else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
|
|
3662
|
+
else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
|
|
3663
|
+
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3493
3664
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3494
3665
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3495
3666
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
@@ -3588,10 +3759,11 @@ let defaultBackend = "wasm";
|
|
|
3588
3759
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3589
3760
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3590
3761
|
initializedBackends.set("wasm", new WasmBackend());
|
|
3591
|
-
/**
|
|
3592
|
-
function
|
|
3593
|
-
if (initializedBackends.has(device)) defaultBackend = device;
|
|
3762
|
+
/**
 * Configure the default device for arrays.
 *
 * Called with no argument (or `undefined`), this only reports the current
 * default. Called with a device name, it makes that device the default,
 * provided the device is present in `initializedBackends`.
 *
 * @param device - Optional name of the device to make the default.
 * @returns The default device name after any update.
 * @throws {Error} If `device` is given but is not in `initializedBackends`.
 */
function defaultDevice(device) {
  if (device === void 0) {
    return defaultBackend;
  }
  if (!initializedBackends.has(device)) {
    throw new Error(`Backend not initialized: ${device}`);
  }
  defaultBackend = device;
  return defaultBackend;
}
|
|
3596
3768
|
/**
|
|
3597
3769
|
* Initialize `jax-js` library backends.
|
|
@@ -3618,7 +3790,7 @@ async function createBackend(device) {
|
|
|
3618
3790
|
if (!navigator.gpu) return null;
|
|
3619
3791
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3620
3792
|
if (!adapter) return null;
|
|
3621
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3793
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-CNOpiO5T.cjs"));
|
|
3622
3794
|
const importantLimits = [
|
|
3623
3795
|
"maxBufferSize",
|
|
3624
3796
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3785,6 +3957,12 @@ Object.defineProperty(exports, 'deepEqual', {
|
|
|
3785
3957
|
return deepEqual;
|
|
3786
3958
|
}
|
|
3787
3959
|
});
|
|
3960
|
+
Object.defineProperty(exports, 'defaultDevice', {
|
|
3961
|
+
enumerable: true,
|
|
3962
|
+
get: function () {
|
|
3963
|
+
return defaultDevice;
|
|
3964
|
+
}
|
|
3965
|
+
});
|
|
3788
3966
|
Object.defineProperty(exports, 'devices', {
|
|
3789
3967
|
enumerable: true,
|
|
3790
3968
|
get: function () {
|
|
@@ -3809,6 +3987,12 @@ Object.defineProperty(exports, 'findPow2', {
|
|
|
3809
3987
|
return findPow2;
|
|
3810
3988
|
}
|
|
3811
3989
|
});
|
|
3990
|
+
Object.defineProperty(exports, 'generalBroadcast', {
|
|
3991
|
+
enumerable: true,
|
|
3992
|
+
get: function () {
|
|
3993
|
+
return generalBroadcast;
|
|
3994
|
+
}
|
|
3995
|
+
});
|
|
3812
3996
|
Object.defineProperty(exports, 'getBackend', {
|
|
3813
3997
|
enumerable: true,
|
|
3814
3998
|
get: function () {
|
|
@@ -3845,6 +4029,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
3845
4029
|
return isPermutation;
|
|
3846
4030
|
}
|
|
3847
4031
|
});
|
|
4032
|
+
Object.defineProperty(exports, 'normalizeAxis', {
|
|
4033
|
+
enumerable: true,
|
|
4034
|
+
get: function () {
|
|
4035
|
+
return normalizeAxis;
|
|
4036
|
+
}
|
|
4037
|
+
});
|
|
3848
4038
|
Object.defineProperty(exports, 'partitionList', {
|
|
3849
4039
|
enumerable: true,
|
|
3850
4040
|
get: function () {
|
|
@@ -3857,6 +4047,12 @@ Object.defineProperty(exports, 'prod', {
|
|
|
3857
4047
|
return prod;
|
|
3858
4048
|
}
|
|
3859
4049
|
});
|
|
4050
|
+
Object.defineProperty(exports, 'promoteTypes', {
|
|
4051
|
+
enumerable: true,
|
|
4052
|
+
get: function () {
|
|
4053
|
+
return promoteTypes;
|
|
4054
|
+
}
|
|
4055
|
+
});
|
|
3860
4056
|
Object.defineProperty(exports, 'range', {
|
|
3861
4057
|
enumerable: true,
|
|
3862
4058
|
get: function () {
|
|
@@ -3881,10 +4077,10 @@ Object.defineProperty(exports, 'runWithCache', {
|
|
|
3881
4077
|
return runWithCache;
|
|
3882
4078
|
}
|
|
3883
4079
|
});
|
|
3884
|
-
Object.defineProperty(exports, '
|
|
4080
|
+
Object.defineProperty(exports, 'setDebug', {
|
|
3885
4081
|
enumerable: true,
|
|
3886
4082
|
get: function () {
|
|
3887
|
-
return
|
|
4083
|
+
return setDebug;
|
|
3888
4084
|
}
|
|
3889
4085
|
});
|
|
3890
4086
|
Object.defineProperty(exports, 'strip1', {
|