@jax-js/jax 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -19
- package/dist/{backend-BqDtPGaR.js → backend-EBRGmEYw.js} +296 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-Ss1Mev_-.cjs} +315 -154
- package/dist/index.cjs +681 -157
- package/dist/index.d.cts +422 -76
- package/dist/index.d.ts +422 -76
- package/dist/index.js +677 -157
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-BVdMaO9T.cjs} +9 -3
- package/dist/{webgpu-CNg9JGva.js → webgpu-ow0Pn_6q.js} +9 -3
- package/package.json +15 -4
|
@@ -36,7 +36,22 @@ var PPrint = class PPrint {
|
|
|
36
36
|
//#endregion
|
|
37
37
|
//#region src/utils.ts
|
|
38
38
|
/** @file Generic programming utilities with no dependencies on library code. */
|
|
39
|
-
|
|
39
|
+
let DEBUG = 0;
|
|
40
|
+
/**
 * Choose how much diagnostic output the library produces.
 *
 * Documented levels:
 *
 * 1. JIT compile logs
 * 2. Shader code
 * 3. Expressions and metadata
 * 4. JIT programs, tuning details
 * 5. Most verbose operation traces
 *
 * Experimental: the behavior of this setting may change, so it should not be
 * relied on in production.
 *
 * @param level Value to store in the module-level `DEBUG` variable.
 */
function setDebug(level) {
  DEBUG = level;
}
|
|
40
55
|
function unzip2(pairs) {
|
|
41
56
|
const lst1 = [];
|
|
42
57
|
const lst2 = [];
|
|
@@ -110,9 +125,23 @@ function isNumberPair(x) {
|
|
|
110
125
|
}
|
|
111
126
|
/**
 * Check an axis against number of dimensions, and resolve negative axes.
 *
 * @param axis Axis index; negative values count back from the last dimension.
 * @param ndim Number of dimensions of the array.
 * @returns The equivalent non-negative axis.
 * @throws {Error} If `axis` is not an integer, or is outside `[-ndim, ndim)`.
 */
function checkAxis(axis, ndim) {
  // Every comparison with NaN is false and fractional values can sit inside
  // the range, so the bounds test alone would let both through.
  if (!Number.isInteger(axis)) throw new Error(`Axis ${axis} is not an integer`);
  if (axis < -ndim || axis >= ndim) throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
  return axis < 0 ? axis + ndim : axis;
}
|
|
131
|
+
/**
 * Normalize common axis argument for functions, defaulting to all axes.
 *
 * @param axis `null`/`undefined` for every axis, a single axis number, or an
 *   iterable of axis numbers. Negative axes are resolved via `checkAxis`.
 * @param ndim Number of dimensions of the array.
 * @returns Non-negative axes in ascending numeric order.
 * @throws {Error} If two entries resolve to the same axis, or if `checkAxis`
 *   rejects an entry.
 */
function normalizeAxis(axis, ndim) {
  // `== null` also covers an omitted (undefined) argument.
  if (axis == null) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = /* @__PURE__ */ new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // The default sort compares elements as strings ("10" < "2"), so a numeric
  // comparator is required once ndim exceeds 10.
  return [...seen].sort((lhs, rhs) => lhs - rhs);
}
|
|
116
145
|
function range(start, stop, step = 1) {
|
|
117
146
|
if (stop === void 0) {
|
|
118
147
|
stop = start;
|
|
@@ -187,6 +216,7 @@ function strip1(str) {
|
|
|
187
216
|
if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
|
|
188
217
|
return str;
|
|
189
218
|
}
|
|
219
|
+
const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(8));
|
|
190
220
|
/**
|
|
191
221
|
* Polynomial hashes modulo p are good at avoiding collisions in expectation.
|
|
192
222
|
* Probability-wise, it's good enough to be used for something like
|
|
@@ -201,22 +231,26 @@ var FpHash = class FpHash {
|
|
|
201
231
|
const modulus = 3189051996290219n;
|
|
202
232
|
this.value = (this.value * base + x) % modulus;
|
|
203
233
|
}
|
|
204
|
-
update(
|
|
205
|
-
|
|
206
|
-
|
|
234
|
+
update(x) {
|
|
235
|
+
if (typeof x === "string") {
|
|
236
|
+
this.#update(BigInt(x.length));
|
|
237
|
+
for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
|
|
238
|
+
} else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
|
|
207
239
|
else {
|
|
208
|
-
|
|
209
|
-
this.#update(
|
|
240
|
+
_stagingbuf.setFloat64(0, x, true);
|
|
241
|
+
this.#update(_stagingbuf.getBigUint64(0, true));
|
|
210
242
|
}
|
|
211
243
|
else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
|
|
212
244
|
else if (typeof x === "bigint") this.#update(x ^ 71657401n);
|
|
213
245
|
else if (x === null) this.#update(37832657n);
|
|
214
246
|
else if (x === void 0) this.#update(18145117n);
|
|
215
|
-
else
|
|
247
|
+
else x.hash(this);
|
|
216
248
|
return this;
|
|
217
249
|
}
|
|
218
250
|
static hash(...values) {
|
|
219
|
-
|
|
251
|
+
const h = new FpHash();
|
|
252
|
+
for (const x of values) h.update(x);
|
|
253
|
+
return h.value;
|
|
220
254
|
}
|
|
221
255
|
};
|
|
222
256
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -251,6 +285,41 @@ const byteWidth = (dtype) => {
|
|
|
251
285
|
}
|
|
252
286
|
};
|
|
253
287
|
const isFloatDtype = (dtype) => dtype === DType.Float32 || dtype === DType.Float16;
|
|
288
|
+
/**
 * Promote two dtypes to their join according to the type lattice.
 *
 * Operations that mix arrays of different dtypes need one common dtype able
 * to represent values from both operands. This follows JAX's type promotion
 * rules.
 *
 * **Type lattice:**
 * ```text
 * bool -> uint32 -> int32 -> float16 -> float32
 *                 weak f* --^
 * ```
 *
 * The asterisk f* is a weak type used for JS number constants. When creating
 * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
 * any array they are first combined with. This function itself only ranks the
 * five concrete dtypes listed in the lattice.
 *
 * **Examples:**
 * - `promoteTypes(bool, int32) → int32`
 * - `promoteTypes(uint32, int32) → int32`
 * - `promoteTypes(int32, float16) → float16`
 * - `promoteTypes(float16, float32) → float32`
 * - `promoteTypes(uint32, float32) → float32`
 *
 * @param dtype1 First dtype.
 * @param dtype2 Second dtype.
 * @returns Whichever of the two dtypes ranks higher in the lattice.
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  // Listed from lowest to highest; a dtype's position is its rank.
  const lattice = [DType.Bool, DType.Uint32, DType.Int32, DType.Float16, DType.Float32];
  const rank = Object.fromEntries(lattice.map((dtype, position) => [dtype, position]));
  if (rank[dtype1] > rank[dtype2]) return dtype1;
  return dtype2;
}
|
|
254
323
|
function dtypedArray(dtype, data) {
|
|
255
324
|
const { buffer, byteLength, byteOffset } = data;
|
|
256
325
|
const length = byteLength / byteWidth(dtype);
|
|
@@ -320,6 +389,12 @@ var AluExp = class AluExp {
|
|
|
320
389
|
static cos(a) {
|
|
321
390
|
return new AluExp(AluOp.Cos, a.dtype, [a]);
|
|
322
391
|
}
|
|
392
|
+
static asin(a) {
|
|
393
|
+
return new AluExp(AluOp.Asin, a.dtype, [a]);
|
|
394
|
+
}
|
|
395
|
+
static atan(a) {
|
|
396
|
+
return new AluExp(AluOp.Atan, a.dtype, [a]);
|
|
397
|
+
}
|
|
323
398
|
static exp(a) {
|
|
324
399
|
return new AluExp(AluOp.Exp, a.dtype, [a]);
|
|
325
400
|
}
|
|
@@ -403,8 +478,11 @@ var AluExp = class AluExp {
|
|
|
403
478
|
getHash() {
|
|
404
479
|
if (this.#hash !== void 0) return this.#hash;
|
|
405
480
|
const hasher = new FpHash();
|
|
406
|
-
hasher.update(this.op
|
|
407
|
-
hasher.update(this.
|
|
481
|
+
hasher.update(this.op);
|
|
482
|
+
hasher.update(this.dtype);
|
|
483
|
+
hasher.update(JSON.stringify(this.arg));
|
|
484
|
+
hasher.update(this.src.length);
|
|
485
|
+
for (const s of this.src) hasher.update(s);
|
|
408
486
|
this.#hash = hasher.value;
|
|
409
487
|
return this.#hash;
|
|
410
488
|
}
|
|
@@ -476,10 +554,16 @@ var AluExp = class AluExp {
|
|
|
476
554
|
ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
|
|
477
555
|
break;
|
|
478
556
|
case AluOp.Sin:
|
|
479
|
-
ret = [
|
|
557
|
+
ret = [-1, 1];
|
|
480
558
|
break;
|
|
481
559
|
case AluOp.Cos:
|
|
482
|
-
ret = [
|
|
560
|
+
ret = [-1, 1];
|
|
561
|
+
break;
|
|
562
|
+
case AluOp.Asin:
|
|
563
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
564
|
+
break;
|
|
565
|
+
case AluOp.Atan:
|
|
566
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
483
567
|
break;
|
|
484
568
|
case AluOp.Exp:
|
|
485
569
|
ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
|
|
@@ -595,11 +679,12 @@ var AluExp = class AluExp {
|
|
|
595
679
|
simplify(cache = /* @__PURE__ */ new Map()) {
|
|
596
680
|
if (this.#simplified !== void 0) return this.#simplified;
|
|
597
681
|
const hash = this.getHash();
|
|
598
|
-
|
|
682
|
+
const prevCachedValue = cache.get(hash);
|
|
683
|
+
if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
|
|
599
684
|
const simplified = this.#simplifyInner(cache);
|
|
600
685
|
const simplifiedHash = simplified.getHash();
|
|
601
|
-
|
|
602
|
-
|
|
686
|
+
const prevSimplified = cache.get(simplifiedHash);
|
|
687
|
+
if (prevSimplified !== void 0) {
|
|
603
688
|
cache.set(hash, prevSimplified);
|
|
604
689
|
this.#simplified = prevSimplified;
|
|
605
690
|
return prevSimplified;
|
|
@@ -803,6 +888,8 @@ var AluExp = class AluExp {
|
|
|
803
888
|
switch (this.op) {
|
|
804
889
|
case AluOp.Sin: return Math.sin(x);
|
|
805
890
|
case AluOp.Cos: return Math.cos(x);
|
|
891
|
+
case AluOp.Asin: return Math.asin(x);
|
|
892
|
+
case AluOp.Atan: return Math.atan(x);
|
|
806
893
|
case AluOp.Exp: return Math.exp(x);
|
|
807
894
|
case AluOp.Log: return Math.log(x);
|
|
808
895
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
@@ -982,6 +1069,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
982
1069
|
AluOp$1["Max"] = "Max";
|
|
983
1070
|
AluOp$1["Sin"] = "Sin";
|
|
984
1071
|
AluOp$1["Cos"] = "Cos";
|
|
1072
|
+
AluOp$1["Asin"] = "Asin";
|
|
1073
|
+
AluOp$1["Atan"] = "Atan";
|
|
985
1074
|
AluOp$1["Exp"] = "Exp";
|
|
986
1075
|
AluOp$1["Log"] = "Log";
|
|
987
1076
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
@@ -1012,6 +1101,8 @@ const AluGroup = {
|
|
|
1012
1101
|
Unary: new Set([
|
|
1013
1102
|
AluOp.Sin,
|
|
1014
1103
|
AluOp.Cos,
|
|
1104
|
+
AluOp.Asin,
|
|
1105
|
+
AluOp.Atan,
|
|
1015
1106
|
AluOp.Exp,
|
|
1016
1107
|
AluOp.Log,
|
|
1017
1108
|
AluOp.Sqrt,
|
|
@@ -1035,6 +1126,8 @@ const AluGroup = {
|
|
|
1035
1126
|
RequiredFloat: new Set([
|
|
1036
1127
|
AluOp.Sin,
|
|
1037
1128
|
AluOp.Cos,
|
|
1129
|
+
AluOp.Asin,
|
|
1130
|
+
AluOp.Atan,
|
|
1038
1131
|
AluOp.Exp,
|
|
1039
1132
|
AluOp.Log,
|
|
1040
1133
|
AluOp.Sqrt,
|
|
@@ -1066,7 +1159,7 @@ var Kernel = class {
|
|
|
1066
1159
|
this.exp = exp.simplify();
|
|
1067
1160
|
}
|
|
1068
1161
|
hash(state) {
|
|
1069
|
-
state.update(this.nargs
|
|
1162
|
+
state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
|
|
1070
1163
|
}
|
|
1071
1164
|
pprint() {
|
|
1072
1165
|
let details = PPrint.pp(`exp = ${this.exp}`);
|
|
@@ -1112,7 +1205,7 @@ var Reduction = class {
|
|
|
1112
1205
|
this.epilogue = epilogue.simplify();
|
|
1113
1206
|
}
|
|
1114
1207
|
hash(state) {
|
|
1115
|
-
state.update(this.dtype
|
|
1208
|
+
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
1116
1209
|
}
|
|
1117
1210
|
toString() {
|
|
1118
1211
|
return `${this.op}{${this.size}} -> ${this.epilogue}`;
|
|
@@ -2284,78 +2377,92 @@ function wasm_log(cg) {
|
|
|
2284
2377
|
});
|
|
2285
2378
|
}
|
|
2286
2379
|
/**
 * Common helper to approximate sin(x) and cos(x).
 *
 * Method: reduce to y in [-π, π], then quadrant via q = round(y/(π/2))
 *   z = y - q*(π/2); use one of two polynomials on z:
 *   sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
 *   cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
 *
 * Reads the input from local 0 and emits code that stores both polynomial
 * values into locals.
 *
 * @param cg Code generator the instructions are emitted into.
 * @returns `{ q, sz, cz }`: the locals holding the quadrant index, the sine
 *   polynomial value and the cosine polynomial value.
 */
function _sincos(cg) {
  const y = cg.local.declare(cg.f32);
  const qf = cg.local.declare(cg.f32);
  const q = cg.local.declare(cg.i32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const sz = cg.local.declare(cg.f32);
  const cz = cg.local.declare(cg.f32);

  /** Emit `round(top * scale)` and also keep the rounded value in `qf`. */
  const emitRoundedMultiple = (scale) => {
    cg.f32.const(scale);
    cg.f32.mul();
    cg.f32.nearest();
    cg.local.tee(qf);
  };

  /** Emit a Horner evaluation in `z2`; coefficients run from highest degree down. */
  const emitHornerInZ2 = ([leading, ...lower]) => {
    cg.f32.const(leading);
    for (const coefficient of lower) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(coefficient);
      cg.f32.add();
    }
  };

  // y = x - round(x / 2π) * 2π
  cg.local.get(0);
  cg.local.get(0);
  emitRoundedMultiple(1 / (2 * Math.PI));
  cg.f32.const(2 * Math.PI);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.set(y);

  // q = round(y / (π/2)), kept as float in qf and as integer in q
  cg.local.get(y);
  emitRoundedMultiple(2 / Math.PI);
  cg.i32.trunc_f32_s();
  cg.local.set(q);

  // z = y - qf * (π/2); z2 = z * z
  cg.local.get(y);
  cg.local.get(qf);
  cg.f32.const(Math.PI / 2);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.tee(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // sz = z * (1 + z2*(-1/6 + z2*(1/120 + z2*(-1/5040))))
  emitHornerInZ2([-1 / 5040, 1 / 120, -1 / 6, 1]);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(sz);

  // cz = 1 + z2*(-1/2 + z2*(1/24 + z2*(-1/720)))
  emitHornerInZ2([-1 / 720, 1 / 24, -1 / 2, 1]);
  cg.local.set(cz);

  return { q, sz, cz };
}
|
|
2457
|
+
/**
|
|
2458
|
+
* Approximate sin(x).
|
|
2459
|
+
*
|
|
2460
|
+
* Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
|
|
2292
2461
|
*/
|
|
2293
2462
|
function wasm_sin(cg) {
|
|
2294
2463
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2295
|
-
const
|
|
2296
|
-
const qf = cg.local.declare(cg.f32);
|
|
2297
|
-
const q = cg.local.declare(cg.i32);
|
|
2298
|
-
const z = cg.local.declare(cg.f32);
|
|
2299
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2300
|
-
const sz = cg.local.declare(cg.f32);
|
|
2301
|
-
const cz = cg.local.declare(cg.f32);
|
|
2464
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2302
2465
|
const mag = cg.local.declare(cg.f32);
|
|
2303
|
-
cg.local.get(0);
|
|
2304
|
-
cg.local.get(0);
|
|
2305
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2306
|
-
cg.f32.mul();
|
|
2307
|
-
cg.f32.nearest();
|
|
2308
|
-
cg.local.tee(qf);
|
|
2309
|
-
cg.f32.const(2 * Math.PI);
|
|
2310
|
-
cg.f32.mul();
|
|
2311
|
-
cg.f32.sub();
|
|
2312
|
-
cg.local.set(y);
|
|
2313
|
-
cg.local.get(y);
|
|
2314
|
-
cg.f32.const(2 / Math.PI);
|
|
2315
|
-
cg.f32.mul();
|
|
2316
|
-
cg.f32.nearest();
|
|
2317
|
-
cg.local.tee(qf);
|
|
2318
|
-
cg.i32.trunc_f32_s();
|
|
2319
|
-
cg.local.set(q);
|
|
2320
|
-
cg.local.get(y);
|
|
2321
|
-
cg.local.get(qf);
|
|
2322
|
-
cg.f32.const(Math.PI / 2);
|
|
2323
|
-
cg.f32.mul();
|
|
2324
|
-
cg.f32.sub();
|
|
2325
|
-
cg.local.tee(z);
|
|
2326
|
-
cg.local.get(z);
|
|
2327
|
-
cg.f32.mul();
|
|
2328
|
-
cg.local.set(z2);
|
|
2329
|
-
cg.f32.const(-1 / 5040);
|
|
2330
|
-
cg.local.get(z2);
|
|
2331
|
-
cg.f32.mul();
|
|
2332
|
-
cg.f32.const(1 / 120);
|
|
2333
|
-
cg.f32.add();
|
|
2334
|
-
cg.local.get(z2);
|
|
2335
|
-
cg.f32.mul();
|
|
2336
|
-
cg.f32.const(-1 / 6);
|
|
2337
|
-
cg.f32.add();
|
|
2338
|
-
cg.local.get(z2);
|
|
2339
|
-
cg.f32.mul();
|
|
2340
|
-
cg.f32.const(1);
|
|
2341
|
-
cg.f32.add();
|
|
2342
|
-
cg.local.get(z);
|
|
2343
|
-
cg.f32.mul();
|
|
2344
|
-
cg.local.set(sz);
|
|
2345
|
-
cg.f32.const(-1 / 720);
|
|
2346
|
-
cg.local.get(z2);
|
|
2347
|
-
cg.f32.mul();
|
|
2348
|
-
cg.f32.const(1 / 24);
|
|
2349
|
-
cg.f32.add();
|
|
2350
|
-
cg.local.get(z2);
|
|
2351
|
-
cg.f32.mul();
|
|
2352
|
-
cg.f32.const(-1 / 2);
|
|
2353
|
-
cg.f32.add();
|
|
2354
|
-
cg.local.get(z2);
|
|
2355
|
-
cg.f32.mul();
|
|
2356
|
-
cg.f32.const(1);
|
|
2357
|
-
cg.f32.add();
|
|
2358
|
-
cg.local.set(cz);
|
|
2359
2466
|
cg.local.get(cz);
|
|
2360
2467
|
cg.local.get(sz);
|
|
2361
2468
|
cg.local.get(q);
|
|
@@ -2374,75 +2481,12 @@ function wasm_sin(cg) {
|
|
|
2374
2481
|
/**
|
|
2375
2482
|
* Approximate cos(x).
|
|
2376
2483
|
*
|
|
2377
|
-
*
|
|
2378
|
-
* k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2484
|
+
* Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2379
2485
|
*/
|
|
2380
2486
|
function wasm_cos(cg) {
|
|
2381
2487
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2382
|
-
const
|
|
2383
|
-
const qf = cg.local.declare(cg.f32);
|
|
2384
|
-
const q = cg.local.declare(cg.i32);
|
|
2385
|
-
const z = cg.local.declare(cg.f32);
|
|
2386
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2387
|
-
const sz = cg.local.declare(cg.f32);
|
|
2388
|
-
const cz = cg.local.declare(cg.f32);
|
|
2488
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2389
2489
|
const mag = cg.local.declare(cg.f32);
|
|
2390
|
-
cg.local.get(0);
|
|
2391
|
-
cg.local.get(0);
|
|
2392
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2393
|
-
cg.f32.mul();
|
|
2394
|
-
cg.f32.nearest();
|
|
2395
|
-
cg.local.tee(qf);
|
|
2396
|
-
cg.f32.const(2 * Math.PI);
|
|
2397
|
-
cg.f32.mul();
|
|
2398
|
-
cg.f32.sub();
|
|
2399
|
-
cg.local.set(y);
|
|
2400
|
-
cg.local.get(y);
|
|
2401
|
-
cg.f32.const(2 / Math.PI);
|
|
2402
|
-
cg.f32.mul();
|
|
2403
|
-
cg.f32.nearest();
|
|
2404
|
-
cg.local.tee(qf);
|
|
2405
|
-
cg.i32.trunc_f32_s();
|
|
2406
|
-
cg.local.set(q);
|
|
2407
|
-
cg.local.get(y);
|
|
2408
|
-
cg.local.get(qf);
|
|
2409
|
-
cg.f32.const(Math.PI / 2);
|
|
2410
|
-
cg.f32.mul();
|
|
2411
|
-
cg.f32.sub();
|
|
2412
|
-
cg.local.tee(z);
|
|
2413
|
-
cg.local.get(z);
|
|
2414
|
-
cg.f32.mul();
|
|
2415
|
-
cg.local.set(z2);
|
|
2416
|
-
cg.f32.const(-1 / 5040);
|
|
2417
|
-
cg.local.get(z2);
|
|
2418
|
-
cg.f32.mul();
|
|
2419
|
-
cg.f32.const(1 / 120);
|
|
2420
|
-
cg.f32.add();
|
|
2421
|
-
cg.local.get(z2);
|
|
2422
|
-
cg.f32.mul();
|
|
2423
|
-
cg.f32.const(-1 / 6);
|
|
2424
|
-
cg.f32.add();
|
|
2425
|
-
cg.local.get(z2);
|
|
2426
|
-
cg.f32.mul();
|
|
2427
|
-
cg.f32.const(1);
|
|
2428
|
-
cg.f32.add();
|
|
2429
|
-
cg.local.get(z);
|
|
2430
|
-
cg.f32.mul();
|
|
2431
|
-
cg.local.set(sz);
|
|
2432
|
-
cg.f32.const(-1 / 720);
|
|
2433
|
-
cg.local.get(z2);
|
|
2434
|
-
cg.f32.mul();
|
|
2435
|
-
cg.f32.const(1 / 24);
|
|
2436
|
-
cg.f32.add();
|
|
2437
|
-
cg.local.get(z2);
|
|
2438
|
-
cg.f32.mul();
|
|
2439
|
-
cg.f32.const(-1 / 2);
|
|
2440
|
-
cg.f32.add();
|
|
2441
|
-
cg.local.get(z2);
|
|
2442
|
-
cg.f32.mul();
|
|
2443
|
-
cg.f32.const(1);
|
|
2444
|
-
cg.f32.add();
|
|
2445
|
-
cg.local.set(cz);
|
|
2446
2490
|
cg.local.get(sz);
|
|
2447
2491
|
cg.local.get(cz);
|
|
2448
2492
|
cg.local.get(q);
|
|
@@ -2460,6 +2504,100 @@ function wasm_cos(cg) {
|
|
|
2460
2504
|
cg.select();
|
|
2461
2505
|
});
|
|
2462
2506
|
}
|
|
2507
|
+
/**
 * Helper function for approximating arctan(x).
 *
 * Emits inline code that consumes the value already pushed by the caller
 * (the first instruction stores it into a local) and leaves the result as the
 * last emitted value. Scheme: with z = |x| when |x| < 1 and z = 1/|x|
 * otherwise, evaluate p = z * P(z^2) / Q(z^2), take π/2 - p instead when
 * |x| >= 1, and finally copy the sign of x onto the result.
 *
 * @param cg Code generator the instructions are emitted into.
 */
function _atan(cg) {
  const x = cg.local.declare(cg.f32);
  const abs_x = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);

  /** Emit the condition `|x| >= 1`. */
  const emitAbsAtLeastOne = () => {
    cg.local.get(abs_x);
    cg.f32.const(1);
    cg.f32.ge();
  };

  /** Emit a Horner evaluation in `z2`; coefficients run from highest degree down. */
  const emitHornerInZ2 = ([leading, ...lower]) => {
    cg.f32.const(leading);
    for (const coefficient of lower) {
      cg.local.get(z2);
      cg.f32.mul();
      cg.f32.const(coefficient);
      cg.f32.add();
    }
  };

  // Capture the argument and its magnitude.
  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(abs_x);

  // z = |x| >= 1 ? 1/|x| : |x|
  cg.f32.const(1);
  cg.local.get(abs_x);
  cg.f32.div();
  cg.local.get(abs_x);
  emitAbsAtLeastOne();
  cg.select();
  cg.local.set(z);

  // z2 = z * z
  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // p = z * P(z2) / Q(z2)
  emitHornerInZ2([.0415796528637, .661705427875, .999998614341]);
  emitHornerInZ2([.173698870181, .994987933645, 1]);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);

  // |x| >= 1 ? π/2 - p : p
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  emitAbsAtLeastOne();
  cg.select();

  // Restore the sign of the original argument.
  cg.local.get(x);
  cg.f32.copysign();
}
|
|
2564
|
+
/**
 * Approximate atan(x).
 *
 * Method: if |x| < 1, use rational approximation: atan(x) ≈ x * P(x^2) / Q(x^2)
 *   where P(u) = A0 + A1*u + A2*u^2 (degree 2)
 *         Q(u) = 1 + B1*u + B2*u^2 (degree 2)
 * if |x| >= 1, use: atan(x) = sign(x)*π/2 - atan(1/x)
 * (fitted coefficients, max error ~5e-7 on [0,1])
 *
 * @param cg Code generator that receives the new f32 -> f32 function.
 * @returns Whatever `cg.function` returns for the defined function.
 */
function wasm_atan(cg) {
  const emitBody = () => {
    // Push the argument, then let the shared helper emit the approximation.
    cg.local.get(0);
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2579
|
+
/**
 * Approximate asin(x).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2)))
 *
 * @param cg Code generator that receives the new f32 -> f32 function.
 * @returns Whatever `cg.function` returns for the defined function.
 */
function wasm_asin(cg) {
  const emitBody = () => {
    // Numerator: x
    cg.local.get(0);

    // Denominator: sqrt(1 - x*x) + 1
    cg.f32.const(1);
    cg.local.get(0);
    cg.local.get(0);
    cg.f32.mul();
    cg.f32.sub();
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();

    // atan(numerator / denominator), doubled
    cg.f32.div();
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  };
  return cg.function([cg.f32], [cg.f32], emitBody);
}
|
|
2463
2601
|
/**
|
|
2464
2602
|
* Threefry2x32 pseudorandom number generator.
|
|
2465
2603
|
*
|
|
@@ -3339,6 +3477,8 @@ function codegenWasm(kernel) {
|
|
|
3339
3477
|
const funcs = {};
|
|
3340
3478
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3341
3479
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3480
|
+
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3481
|
+
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3342
3482
|
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3343
3483
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3344
3484
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
@@ -3490,6 +3630,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3490
3630
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3491
3631
|
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
|
|
3492
3632
|
else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
|
|
3633
|
+
else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
|
|
3634
|
+
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3493
3635
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3494
3636
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3495
3637
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
@@ -3588,10 +3730,11 @@ let defaultBackend = "wasm";
|
|
|
3588
3730
|
const initializedBackends = /* @__PURE__ */ new Map();
|
|
3589
3731
|
initializedBackends.set("cpu", new CpuBackend());
|
|
3590
3732
|
initializedBackends.set("wasm", new WasmBackend());
|
|
3591
|
-
/**
|
|
3592
|
-
function
|
|
3593
|
-
if (initializedBackends.has(device)) defaultBackend = device;
|
|
3733
|
+
/**
 * Configure the default device for arrays.
 *
 * @param device Optional backend name. When omitted (`undefined`) the current
 *   default is left unchanged.
 * @returns The default backend name after any update.
 * @throws {Error} If `device` is given but is not a key of `initializedBackends`.
 */
function defaultDevice(device) {
  // Query-only call: nothing to change.
  if (device === void 0) return defaultBackend;
  if (!initializedBackends.has(device)) throw new Error(`Backend not initialized: ${device}`);
  defaultBackend = device;
  return defaultBackend;
}
|
|
3596
3739
|
/**
|
|
3597
3740
|
* Initialize `jax-js` library backends.
|
|
@@ -3618,7 +3761,7 @@ async function createBackend(device) {
|
|
|
3618
3761
|
if (!navigator.gpu) return null;
|
|
3619
3762
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3620
3763
|
if (!adapter) return null;
|
|
3621
|
-
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-
|
|
3764
|
+
const { WebGPUBackend } = await Promise.resolve().then(() => require("./webgpu-BVdMaO9T.cjs"));
|
|
3622
3765
|
const importantLimits = [
|
|
3623
3766
|
"maxBufferSize",
|
|
3624
3767
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3785,6 +3928,12 @@ Object.defineProperty(exports, 'deepEqual', {
|
|
|
3785
3928
|
return deepEqual;
|
|
3786
3929
|
}
|
|
3787
3930
|
});
|
|
3931
|
+
Object.defineProperty(exports, 'defaultDevice', {
|
|
3932
|
+
enumerable: true,
|
|
3933
|
+
get: function () {
|
|
3934
|
+
return defaultDevice;
|
|
3935
|
+
}
|
|
3936
|
+
});
|
|
3788
3937
|
Object.defineProperty(exports, 'devices', {
|
|
3789
3938
|
enumerable: true,
|
|
3790
3939
|
get: function () {
|
|
@@ -3845,6 +3994,12 @@ Object.defineProperty(exports, 'isPermutation', {
|
|
|
3845
3994
|
return isPermutation;
|
|
3846
3995
|
}
|
|
3847
3996
|
});
|
|
3997
|
+
Object.defineProperty(exports, 'normalizeAxis', {
|
|
3998
|
+
enumerable: true,
|
|
3999
|
+
get: function () {
|
|
4000
|
+
return normalizeAxis;
|
|
4001
|
+
}
|
|
4002
|
+
});
|
|
3848
4003
|
Object.defineProperty(exports, 'partitionList', {
|
|
3849
4004
|
enumerable: true,
|
|
3850
4005
|
get: function () {
|
|
@@ -3857,6 +4012,12 @@ Object.defineProperty(exports, 'prod', {
|
|
|
3857
4012
|
return prod;
|
|
3858
4013
|
}
|
|
3859
4014
|
});
|
|
4015
|
+
Object.defineProperty(exports, 'promoteTypes', {
|
|
4016
|
+
enumerable: true,
|
|
4017
|
+
get: function () {
|
|
4018
|
+
return promoteTypes;
|
|
4019
|
+
}
|
|
4020
|
+
});
|
|
3860
4021
|
Object.defineProperty(exports, 'range', {
|
|
3861
4022
|
enumerable: true,
|
|
3862
4023
|
get: function () {
|
|
@@ -3881,10 +4042,10 @@ Object.defineProperty(exports, 'runWithCache', {
|
|
|
3881
4042
|
return runWithCache;
|
|
3882
4043
|
}
|
|
3883
4044
|
});
|
|
3884
|
-
Object.defineProperty(exports, '
|
|
4045
|
+
Object.defineProperty(exports, 'setDebug', {
|
|
3885
4046
|
enumerable: true,
|
|
3886
4047
|
get: function () {
|
|
3887
|
-
return
|
|
4048
|
+
return setDebug;
|
|
3888
4049
|
}
|
|
3889
4050
|
});
|
|
3890
4051
|
Object.defineProperty(exports, 'strip1', {
|