@jax-js/jax 0.0.3 → 0.0.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +50 -19
- package/dist/{backend-BqDtPGaR.js → backend-EBRGmEYw.js} +296 -153
- package/dist/{backend-D2C4MJRP.cjs → backend-Ss1Mev_-.cjs} +315 -154
- package/dist/index.cjs +681 -157
- package/dist/index.d.cts +422 -76
- package/dist/index.d.ts +422 -76
- package/dist/index.js +677 -157
- package/dist/{webgpu-fqhx41TC.cjs → webgpu-BVdMaO9T.cjs} +9 -3
- package/dist/{webgpu-CNg9JGva.js → webgpu-ow0Pn_6q.js} +9 -3
- package/package.json +15 -4
|
@@ -35,7 +35,22 @@ var PPrint = class PPrint {
|
|
|
35
35
|
//#endregion
|
|
36
36
|
//#region src/utils.ts
|
|
37
37
|
/** @file Generic programming utilities with no dependencies on library code. */
|
|
38
|
-
|
|
38
|
+
/**
 * Current debug verbosity level, initially 0. Exported from this chunk and
 * changed through `setDebug`. NOTE(review): 0 presumably means "no debug
 * output" — confirm against the logging call sites.
 */
let DEBUG = 0;
|
|
39
|
+
/**
 * Adjust the verbosity of jax-js debug logging.
 *
 * Levels, from least to most verbose:
 *   1. JIT compile logs
 *   2. Shader code
 *   3. Expressions and metadata
 *   4. JIT programs, tuning details
 *   5. Most verbose operation traces
 *
 * Experimental: the meaning of individual levels may change between releases,
 * so production code should not depend on it.
 *
 * @param {number} level - Value written to the exported `DEBUG` binding.
 */
function setDebug(level) {
  DEBUG = level;
}
|
|
39
54
|
function unzip2(pairs) {
|
|
40
55
|
const lst1 = [];
|
|
41
56
|
const lst2 = [];
|
|
@@ -109,9 +124,23 @@ function isNumberPair(x) {
|
|
|
109
124
|
}
|
|
110
125
|
/**
 * Validate `axis` for an array with `ndim` dimensions.
 *
 * @param {number} axis - Axis index; negative values count from the end.
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number} The equivalent non-negative axis.
 * @throws {Error} If `axis` lies outside `[-ndim, ndim)`.
 */
function checkAxis(axis, ndim) {
  const outOfBounds = axis < -ndim || axis >= ndim;
  if (outOfBounds) {
    throw new Error(`Axis ${axis} out of bounds for array of dimension ${ndim}`);
  }
  if (axis < 0) return axis + ndim;
  return axis;
}
|
|
130
|
+
/**
 * Normalize the common `axis` argument of functions, defaulting to all axes.
 *
 * - `null` or `undefined` selects every axis: `range(ndim)`.
 * - A single number is bounds-checked and resolved with `checkAxis`.
 * - An iterable of numbers is resolved element-wise; an axis that appears
 *   twice (after resolving negative indices) is an error.
 *
 * @param {number | Iterable<number> | null | undefined} axis
 * @param {number} ndim - Number of dimensions of the array.
 * @returns {number[]} Resolved non-negative axes (ascending for iterable input).
 * @throws {Error} On an out-of-bounds or duplicate axis.
 */
function normalizeAxis(axis, ndim) {
  if (axis == null) return range(ndim);
  if (typeof axis === "number") return [checkAxis(axis, ndim)];
  const seen = /* @__PURE__ */ new Set();
  for (const a of axis) {
    const ca = checkAxis(a, ndim);
    if (seen.has(ca)) throw new Error(`Duplicate axis ${ca} passed to function`);
    seen.add(ca);
  }
  // Numeric comparator: the default sort compares as strings, so [2, 10] would become [10, 2].
  return [...seen].sort((a, b) => a - b);
}
|
|
115
144
|
function range(start, stop, step = 1) {
|
|
116
145
|
if (stop === void 0) {
|
|
117
146
|
stop = start;
|
|
@@ -186,6 +215,7 @@ function strip1(str) {
|
|
|
186
215
|
if (str[0] === "(" && str[str.length - 1] === ")") return str.slice(1, -1);
|
|
187
216
|
return str;
|
|
188
217
|
}
|
|
218
|
+
/**
 * 8-byte scratch view reused by `FpHash.update` to reinterpret a non-integer
 * number's float64 bit pattern (little-endian) as an unsigned 64-bit BigInt.
 */
const _stagingbuf = /* @__PURE__ */ new DataView(/* @__PURE__ */ new ArrayBuffer(8));
|
|
189
219
|
/**
|
|
190
220
|
* Polynomial hashes modulo p are good at avoiding collisions in expectation.
|
|
191
221
|
* Probability-wise, it's good enough to be used for something like
|
|
@@ -200,22 +230,26 @@ var FpHash = class FpHash {
|
|
|
200
230
|
const modulus = 3189051996290219n;
|
|
201
231
|
this.value = (this.value * base + x) % modulus;
|
|
202
232
|
}
|
|
203
|
-
update(
|
|
204
|
-
|
|
205
|
-
|
|
233
|
+
update(x) {
|
|
234
|
+
if (typeof x === "string") {
|
|
235
|
+
this.#update(BigInt(x.length));
|
|
236
|
+
for (let i = 0; i < x.length; i++) this.#update(BigInt(199 + x.charCodeAt(i)));
|
|
237
|
+
} else if (typeof x === "number") if (Number.isInteger(x)) this.#update(68265653n ^ BigInt(x));
|
|
206
238
|
else {
|
|
207
|
-
|
|
208
|
-
this.#update(
|
|
239
|
+
_stagingbuf.setFloat64(0, x, true);
|
|
240
|
+
this.#update(_stagingbuf.getBigUint64(0, true));
|
|
209
241
|
}
|
|
210
242
|
else if (typeof x === "boolean") this.#update(x ? 69069841n : 63640693n);
|
|
211
243
|
else if (typeof x === "bigint") this.#update(x ^ 71657401n);
|
|
212
244
|
else if (x === null) this.#update(37832657n);
|
|
213
245
|
else if (x === void 0) this.#update(18145117n);
|
|
214
|
-
else
|
|
246
|
+
else x.hash(this);
|
|
215
247
|
return this;
|
|
216
248
|
}
|
|
217
249
|
static hash(...values) {
|
|
218
|
-
|
|
250
|
+
const h = new FpHash();
|
|
251
|
+
for (const x of values) h.update(x);
|
|
252
|
+
return h.value;
|
|
219
253
|
}
|
|
220
254
|
};
|
|
221
255
|
/** Run a function while caching it inline inside a `Map`. */
|
|
@@ -250,6 +284,41 @@ const byteWidth = (dtype) => {
|
|
|
250
284
|
}
|
|
251
285
|
};
|
|
252
286
|
/** Whether `dtype` is a floating-point dtype (`DType.Float32` or `DType.Float16`). */
const isFloatDtype = (dtype) => {
  switch (dtype) {
    case DType.Float32:
    case DType.Float16:
      return true;
    default:
      return false;
  }
};
|
|
287
|
+
/**
 * Promote two dtypes to their join according to the type lattice.
 *
 * When performing operations between arrays of different types, we need to
 * promote both operands to a common type that can represent values from both
 * input types. The lattice is modeled on JAX's type promotion rules.
 *
 * **Type lattice:**
 * ```text
 * bool -> uint32 -> int32 -> float16 -> float32
 * weak f* --^
 * ```
 *
 * The asterisk f* is a weak type used for JS number constants. When creating
 * arrays, JS numbers default to float32 but "weak" so they cast to the dtype of
 * any array they are first combined with. This function only ranks the five
 * concrete dtypes above. NOTE(review): weak-type handling presumably happens
 * in callers — confirm.
 *
 * **Examples:**
 * - `promoteTypes(bool, int32) → int32`
 * - `promoteTypes(uint32, int32) → int32`
 * - `promoteTypes(int32, float16) → float16`
 * - `promoteTypes(float16, float32) → float32`
 * - `promoteTypes(uint32, float32) → float32`
 *
 * @param dtype1 - First operand dtype.
 * @param dtype2 - Second operand dtype.
 * @returns The higher-ranked of the two dtypes (or either, when equal).
 * @throws {Error} If the dtypes differ and either is not one of the ranked
 *   dtypes. Such inputs are rejected rather than silently resolved to `dtype2`.
 */
function promoteTypes(dtype1, dtype2) {
  if (dtype1 === dtype2) return dtype1;
  const rank = {
    [DType.Bool]: 0,
    [DType.Uint32]: 1,
    [DType.Int32]: 2,
    [DType.Float16]: 3,
    [DType.Float32]: 4,
  };
  for (const dtype of [dtype1, dtype2]) {
    // Own-property check, so inherited keys such as "toString" are not treated as ranked dtypes.
    if (!Object.hasOwn(rank, dtype)) throw new Error(`Cannot promote unknown dtype: ${String(dtype)}`);
  }
  return rank[dtype1] > rank[dtype2] ? dtype1 : dtype2;
}
|
|
253
322
|
function dtypedArray(dtype, data) {
|
|
254
323
|
const { buffer, byteLength, byteOffset } = data;
|
|
255
324
|
const length = byteLength / byteWidth(dtype);
|
|
@@ -319,6 +388,12 @@ var AluExp = class AluExp {
|
|
|
319
388
|
static cos(a) {
|
|
320
389
|
return new AluExp(AluOp.Cos, a.dtype, [a]);
|
|
321
390
|
}
|
|
391
|
+
static asin(a) {
|
|
392
|
+
return new AluExp(AluOp.Asin, a.dtype, [a]);
|
|
393
|
+
}
|
|
394
|
+
static atan(a) {
|
|
395
|
+
return new AluExp(AluOp.Atan, a.dtype, [a]);
|
|
396
|
+
}
|
|
322
397
|
static exp(a) {
|
|
323
398
|
return new AluExp(AluOp.Exp, a.dtype, [a]);
|
|
324
399
|
}
|
|
@@ -402,8 +477,11 @@ var AluExp = class AluExp {
|
|
|
402
477
|
getHash() {
|
|
403
478
|
if (this.#hash !== void 0) return this.#hash;
|
|
404
479
|
const hasher = new FpHash();
|
|
405
|
-
hasher.update(this.op
|
|
406
|
-
hasher.update(this.
|
|
480
|
+
hasher.update(this.op);
|
|
481
|
+
hasher.update(this.dtype);
|
|
482
|
+
hasher.update(JSON.stringify(this.arg));
|
|
483
|
+
hasher.update(this.src.length);
|
|
484
|
+
for (const s of this.src) hasher.update(s);
|
|
407
485
|
this.#hash = hasher.value;
|
|
408
486
|
return this.#hash;
|
|
409
487
|
}
|
|
@@ -475,10 +553,16 @@ var AluExp = class AluExp {
|
|
|
475
553
|
ret = [Math.max(src[0].min, src[1].min), Math.max(src[0].max, src[1].max)];
|
|
476
554
|
break;
|
|
477
555
|
case AluOp.Sin:
|
|
478
|
-
ret = [
|
|
556
|
+
ret = [-1, 1];
|
|
479
557
|
break;
|
|
480
558
|
case AluOp.Cos:
|
|
481
|
-
ret = [
|
|
559
|
+
ret = [-1, 1];
|
|
560
|
+
break;
|
|
561
|
+
case AluOp.Asin:
|
|
562
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
563
|
+
break;
|
|
564
|
+
case AluOp.Atan:
|
|
565
|
+
ret = [-Math.PI / 2, Math.PI / 2];
|
|
482
566
|
break;
|
|
483
567
|
case AluOp.Exp:
|
|
484
568
|
ret = [Math.exp(src[0].min), Math.exp(src[0].max)];
|
|
@@ -594,11 +678,12 @@ var AluExp = class AluExp {
|
|
|
594
678
|
simplify(cache = /* @__PURE__ */ new Map()) {
|
|
595
679
|
if (this.#simplified !== void 0) return this.#simplified;
|
|
596
680
|
const hash = this.getHash();
|
|
597
|
-
|
|
681
|
+
const prevCachedValue = cache.get(hash);
|
|
682
|
+
if (prevCachedValue !== void 0) return this.#simplified = prevCachedValue;
|
|
598
683
|
const simplified = this.#simplifyInner(cache);
|
|
599
684
|
const simplifiedHash = simplified.getHash();
|
|
600
|
-
|
|
601
|
-
|
|
685
|
+
const prevSimplified = cache.get(simplifiedHash);
|
|
686
|
+
if (prevSimplified !== void 0) {
|
|
602
687
|
cache.set(hash, prevSimplified);
|
|
603
688
|
this.#simplified = prevSimplified;
|
|
604
689
|
return prevSimplified;
|
|
@@ -802,6 +887,8 @@ var AluExp = class AluExp {
|
|
|
802
887
|
switch (this.op) {
|
|
803
888
|
case AluOp.Sin: return Math.sin(x);
|
|
804
889
|
case AluOp.Cos: return Math.cos(x);
|
|
890
|
+
case AluOp.Asin: return Math.asin(x);
|
|
891
|
+
case AluOp.Atan: return Math.atan(x);
|
|
805
892
|
case AluOp.Exp: return Math.exp(x);
|
|
806
893
|
case AluOp.Log: return Math.log(x);
|
|
807
894
|
case AluOp.Sqrt: return Math.sqrt(x);
|
|
@@ -981,6 +1068,8 @@ let AluOp = /* @__PURE__ */ function(AluOp$1) {
|
|
|
981
1068
|
AluOp$1["Max"] = "Max";
|
|
982
1069
|
AluOp$1["Sin"] = "Sin";
|
|
983
1070
|
AluOp$1["Cos"] = "Cos";
|
|
1071
|
+
AluOp$1["Asin"] = "Asin";
|
|
1072
|
+
AluOp$1["Atan"] = "Atan";
|
|
984
1073
|
AluOp$1["Exp"] = "Exp";
|
|
985
1074
|
AluOp$1["Log"] = "Log";
|
|
986
1075
|
AluOp$1["Sqrt"] = "Sqrt";
|
|
@@ -1011,6 +1100,8 @@ const AluGroup = {
|
|
|
1011
1100
|
Unary: new Set([
|
|
1012
1101
|
AluOp.Sin,
|
|
1013
1102
|
AluOp.Cos,
|
|
1103
|
+
AluOp.Asin,
|
|
1104
|
+
AluOp.Atan,
|
|
1014
1105
|
AluOp.Exp,
|
|
1015
1106
|
AluOp.Log,
|
|
1016
1107
|
AluOp.Sqrt,
|
|
@@ -1034,6 +1125,8 @@ const AluGroup = {
|
|
|
1034
1125
|
RequiredFloat: new Set([
|
|
1035
1126
|
AluOp.Sin,
|
|
1036
1127
|
AluOp.Cos,
|
|
1128
|
+
AluOp.Asin,
|
|
1129
|
+
AluOp.Atan,
|
|
1037
1130
|
AluOp.Exp,
|
|
1038
1131
|
AluOp.Log,
|
|
1039
1132
|
AluOp.Sqrt,
|
|
@@ -1065,7 +1158,7 @@ var Kernel = class {
|
|
|
1065
1158
|
this.exp = exp.simplify();
|
|
1066
1159
|
}
|
|
1067
1160
|
hash(state) {
|
|
1068
|
-
state.update(this.nargs
|
|
1161
|
+
state.update(this.nargs).update(this.size).update(this.exp).update(this.reduction);
|
|
1069
1162
|
}
|
|
1070
1163
|
pprint() {
|
|
1071
1164
|
let details = PPrint.pp(`exp = ${this.exp}`);
|
|
@@ -1111,7 +1204,7 @@ var Reduction = class {
|
|
|
1111
1204
|
this.epilogue = epilogue.simplify();
|
|
1112
1205
|
}
|
|
1113
1206
|
hash(state) {
|
|
1114
|
-
state.update(this.dtype
|
|
1207
|
+
state.update(this.dtype).update(this.op).update(this.size).update(this.epilogue);
|
|
1115
1208
|
}
|
|
1116
1209
|
toString() {
|
|
1117
1210
|
return `${this.op}{${this.size}} -> ${this.epilogue}`;
|
|
@@ -2283,78 +2376,92 @@ function wasm_log(cg) {
|
|
|
2283
2376
|
});
|
|
2284
2377
|
}
|
|
2285
2378
|
/**
 * Shared core for the WASM sin(x) / cos(x) approximations.
 *
 * Reads the f32 argument from local 0 of the function being generated and
 * declares seven new locals. The argument is first reduced to
 * y = x - 2π·round(x / 2π), which lies in [-π, π]. Then q = round(y / (π/2))
 * selects the quadrant, and the remainder z = y - q·(π/2) is fed to two
 * truncated Taylor polynomials evaluated with Horner's rule:
 *   sin(z) ≈ z + z^3*(-1/6) + z^5*(1/120) + z^7*(-1/5040)
 *   cos(z) ≈ 1 + z^2*(-1/2) + z^4*(1/24) + z^6*(-1/720)
 * Nothing is left on the operand stack.
 *
 * @returns Local indices `{ q, sz, cz }`: the quadrant (i32), sin(z) and cos(z).
 */
function _sincos(cg) {
  // Declaration order determines local indices; keep it stable.
  const y = cg.local.declare(cg.f32);
  const qf = cg.local.declare(cg.f32);
  const q = cg.local.declare(cg.i32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const sz = cg.local.declare(cg.f32);
  const cz = cg.local.declare(cg.f32);

  // Horner evaluation in local `v`, highest-degree coefficient first; leaves the result on the stack.
  const horner = (coeffs, v) => {
    cg.f32.const(coeffs[0]);
    for (const c of coeffs.slice(1)) {
      cg.local.get(v);
      cg.f32.mul();
      cg.f32.const(c);
      cg.f32.add();
    }
  };

  // y = x - round(x / 2π) * 2π
  cg.local.get(0);
  cg.local.get(0);
  cg.f32.const(1 / (2 * Math.PI));
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  cg.f32.const(2 * Math.PI);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.set(y);

  // q = round(y / (π/2)), kept as f32 in qf and as i32 in q
  cg.local.get(y);
  cg.f32.const(2 / Math.PI);
  cg.f32.mul();
  cg.f32.nearest();
  cg.local.tee(qf);
  cg.i32.trunc_f32_s();
  cg.local.set(q);

  // z = y - qf * (π/2); z2 = z * z
  cg.local.get(y);
  cg.local.get(qf);
  cg.f32.const(Math.PI / 2);
  cg.f32.mul();
  cg.f32.sub();
  cg.local.tee(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // sz = z * Psin(z2)
  horner([-1 / 5040, 1 / 120, -1 / 6, 1], z2);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(sz);

  // cz = Pcos(z2)
  horner([-1 / 720, 1 / 24, -1 / 2, 1], z2);
  cg.local.set(cz);

  return { q, sz, cz };
}
|
|
2456
|
+
/**
|
|
2457
|
+
* Approximate sin(x).
|
|
2458
|
+
*
|
|
2459
|
+
* Quadrant mapping: k=q mod 4: 0: +sz, 1: +cz, 2: -sz, 3: -cz
|
|
2291
2460
|
*/
|
|
2292
2461
|
function wasm_sin(cg) {
|
|
2293
2462
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2294
|
-
const
|
|
2295
|
-
const qf = cg.local.declare(cg.f32);
|
|
2296
|
-
const q = cg.local.declare(cg.i32);
|
|
2297
|
-
const z = cg.local.declare(cg.f32);
|
|
2298
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2299
|
-
const sz = cg.local.declare(cg.f32);
|
|
2300
|
-
const cz = cg.local.declare(cg.f32);
|
|
2463
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2301
2464
|
const mag = cg.local.declare(cg.f32);
|
|
2302
|
-
cg.local.get(0);
|
|
2303
|
-
cg.local.get(0);
|
|
2304
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2305
|
-
cg.f32.mul();
|
|
2306
|
-
cg.f32.nearest();
|
|
2307
|
-
cg.local.tee(qf);
|
|
2308
|
-
cg.f32.const(2 * Math.PI);
|
|
2309
|
-
cg.f32.mul();
|
|
2310
|
-
cg.f32.sub();
|
|
2311
|
-
cg.local.set(y);
|
|
2312
|
-
cg.local.get(y);
|
|
2313
|
-
cg.f32.const(2 / Math.PI);
|
|
2314
|
-
cg.f32.mul();
|
|
2315
|
-
cg.f32.nearest();
|
|
2316
|
-
cg.local.tee(qf);
|
|
2317
|
-
cg.i32.trunc_f32_s();
|
|
2318
|
-
cg.local.set(q);
|
|
2319
|
-
cg.local.get(y);
|
|
2320
|
-
cg.local.get(qf);
|
|
2321
|
-
cg.f32.const(Math.PI / 2);
|
|
2322
|
-
cg.f32.mul();
|
|
2323
|
-
cg.f32.sub();
|
|
2324
|
-
cg.local.tee(z);
|
|
2325
|
-
cg.local.get(z);
|
|
2326
|
-
cg.f32.mul();
|
|
2327
|
-
cg.local.set(z2);
|
|
2328
|
-
cg.f32.const(-1 / 5040);
|
|
2329
|
-
cg.local.get(z2);
|
|
2330
|
-
cg.f32.mul();
|
|
2331
|
-
cg.f32.const(1 / 120);
|
|
2332
|
-
cg.f32.add();
|
|
2333
|
-
cg.local.get(z2);
|
|
2334
|
-
cg.f32.mul();
|
|
2335
|
-
cg.f32.const(-1 / 6);
|
|
2336
|
-
cg.f32.add();
|
|
2337
|
-
cg.local.get(z2);
|
|
2338
|
-
cg.f32.mul();
|
|
2339
|
-
cg.f32.const(1);
|
|
2340
|
-
cg.f32.add();
|
|
2341
|
-
cg.local.get(z);
|
|
2342
|
-
cg.f32.mul();
|
|
2343
|
-
cg.local.set(sz);
|
|
2344
|
-
cg.f32.const(-1 / 720);
|
|
2345
|
-
cg.local.get(z2);
|
|
2346
|
-
cg.f32.mul();
|
|
2347
|
-
cg.f32.const(1 / 24);
|
|
2348
|
-
cg.f32.add();
|
|
2349
|
-
cg.local.get(z2);
|
|
2350
|
-
cg.f32.mul();
|
|
2351
|
-
cg.f32.const(-1 / 2);
|
|
2352
|
-
cg.f32.add();
|
|
2353
|
-
cg.local.get(z2);
|
|
2354
|
-
cg.f32.mul();
|
|
2355
|
-
cg.f32.const(1);
|
|
2356
|
-
cg.f32.add();
|
|
2357
|
-
cg.local.set(cz);
|
|
2358
2465
|
cg.local.get(cz);
|
|
2359
2466
|
cg.local.get(sz);
|
|
2360
2467
|
cg.local.get(q);
|
|
@@ -2373,75 +2480,12 @@ function wasm_sin(cg) {
|
|
|
2373
2480
|
/**
|
|
2374
2481
|
* Approximate cos(x).
|
|
2375
2482
|
*
|
|
2376
|
-
*
|
|
2377
|
-
* k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2483
|
+
* Quadrant mapping: k=q mod 4: 0: +cz, 1: -sz, 2: -cz, 3: +sz
|
|
2378
2484
|
*/
|
|
2379
2485
|
function wasm_cos(cg) {
|
|
2380
2486
|
return cg.function([cg.f32], [cg.f32], () => {
|
|
2381
|
-
const
|
|
2382
|
-
const qf = cg.local.declare(cg.f32);
|
|
2383
|
-
const q = cg.local.declare(cg.i32);
|
|
2384
|
-
const z = cg.local.declare(cg.f32);
|
|
2385
|
-
const z2 = cg.local.declare(cg.f32);
|
|
2386
|
-
const sz = cg.local.declare(cg.f32);
|
|
2387
|
-
const cz = cg.local.declare(cg.f32);
|
|
2487
|
+
const { q, sz, cz } = _sincos(cg);
|
|
2388
2488
|
const mag = cg.local.declare(cg.f32);
|
|
2389
|
-
cg.local.get(0);
|
|
2390
|
-
cg.local.get(0);
|
|
2391
|
-
cg.f32.const(1 / (2 * Math.PI));
|
|
2392
|
-
cg.f32.mul();
|
|
2393
|
-
cg.f32.nearest();
|
|
2394
|
-
cg.local.tee(qf);
|
|
2395
|
-
cg.f32.const(2 * Math.PI);
|
|
2396
|
-
cg.f32.mul();
|
|
2397
|
-
cg.f32.sub();
|
|
2398
|
-
cg.local.set(y);
|
|
2399
|
-
cg.local.get(y);
|
|
2400
|
-
cg.f32.const(2 / Math.PI);
|
|
2401
|
-
cg.f32.mul();
|
|
2402
|
-
cg.f32.nearest();
|
|
2403
|
-
cg.local.tee(qf);
|
|
2404
|
-
cg.i32.trunc_f32_s();
|
|
2405
|
-
cg.local.set(q);
|
|
2406
|
-
cg.local.get(y);
|
|
2407
|
-
cg.local.get(qf);
|
|
2408
|
-
cg.f32.const(Math.PI / 2);
|
|
2409
|
-
cg.f32.mul();
|
|
2410
|
-
cg.f32.sub();
|
|
2411
|
-
cg.local.tee(z);
|
|
2412
|
-
cg.local.get(z);
|
|
2413
|
-
cg.f32.mul();
|
|
2414
|
-
cg.local.set(z2);
|
|
2415
|
-
cg.f32.const(-1 / 5040);
|
|
2416
|
-
cg.local.get(z2);
|
|
2417
|
-
cg.f32.mul();
|
|
2418
|
-
cg.f32.const(1 / 120);
|
|
2419
|
-
cg.f32.add();
|
|
2420
|
-
cg.local.get(z2);
|
|
2421
|
-
cg.f32.mul();
|
|
2422
|
-
cg.f32.const(-1 / 6);
|
|
2423
|
-
cg.f32.add();
|
|
2424
|
-
cg.local.get(z2);
|
|
2425
|
-
cg.f32.mul();
|
|
2426
|
-
cg.f32.const(1);
|
|
2427
|
-
cg.f32.add();
|
|
2428
|
-
cg.local.get(z);
|
|
2429
|
-
cg.f32.mul();
|
|
2430
|
-
cg.local.set(sz);
|
|
2431
|
-
cg.f32.const(-1 / 720);
|
|
2432
|
-
cg.local.get(z2);
|
|
2433
|
-
cg.f32.mul();
|
|
2434
|
-
cg.f32.const(1 / 24);
|
|
2435
|
-
cg.f32.add();
|
|
2436
|
-
cg.local.get(z2);
|
|
2437
|
-
cg.f32.mul();
|
|
2438
|
-
cg.f32.const(-1 / 2);
|
|
2439
|
-
cg.f32.add();
|
|
2440
|
-
cg.local.get(z2);
|
|
2441
|
-
cg.f32.mul();
|
|
2442
|
-
cg.f32.const(1);
|
|
2443
|
-
cg.f32.add();
|
|
2444
|
-
cg.local.set(cz);
|
|
2445
2489
|
cg.local.get(sz);
|
|
2446
2490
|
cg.local.get(cz);
|
|
2447
2491
|
cg.local.get(q);
|
|
@@ -2459,6 +2503,100 @@ function wasm_cos(cg) {
|
|
|
2459
2503
|
cg.select();
|
|
2460
2504
|
});
|
|
2461
2505
|
}
|
|
2506
|
+
/**
 * Emit code approximating atan of the f32 value on top of the operand stack,
 * replacing it with the result. Declares five new locals per call.
 *
 * With a = |x|, it uses z = a when a < 1 and z = 1/a when a >= 1. It then
 * computes p = z * P(z^2) / Q(z^2), a fitted rational approximation of atan(z)
 * on [0, 1]. The result is p, or π/2 - p when a >= 1. The sign of x is
 * restored with copysign.
 */
function _atan(cg) {
  const x = cg.local.declare(cg.f32);
  const absX = cg.local.declare(cg.f32);
  const z = cg.local.declare(cg.f32);
  const z2 = cg.local.declare(cg.f32);
  const p = cg.local.declare(cg.f32);

  // Horner evaluation in local `v`, highest-degree coefficient first; leaves the result on the stack.
  const horner = (coeffs, v) => {
    cg.f32.const(coeffs[0]);
    for (const c of coeffs.slice(1)) {
      cg.local.get(v);
      cg.f32.mul();
      cg.f32.const(c);
      cg.f32.add();
    }
  };
  // Pushes the i32 condition |x| >= 1, used as a `select` selector.
  const pushIsLarge = () => {
    cg.local.get(absX);
    cg.f32.const(1);
    cg.f32.ge();
  };

  cg.local.set(x);
  cg.local.get(x);
  cg.f32.abs();
  cg.local.set(absX);

  // z = |x| >= 1 ? 1/|x| : |x|
  cg.f32.const(1);
  cg.local.get(absX);
  cg.f32.div();
  cg.local.get(absX);
  pushIsLarge();
  cg.select();
  cg.local.set(z);
  cg.local.get(z);
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(z2);

  // p = z * P(z2) / Q(z2)
  horner([.0415796528637, .661705427875, .999998614341], z2);
  horner([.173698870181, .994987933645, 1], z2);
  cg.f32.div();
  cg.local.get(z);
  cg.f32.mul();
  cg.local.set(p);

  // result = copysign(|x| >= 1 ? π/2 - p : p, x)
  cg.f32.const(Math.PI / 2);
  cg.local.get(p);
  cg.f32.sub();
  cg.local.get(p);
  pushIsLarge();
  cg.select();
  cg.local.get(x);
  cg.f32.copysign();
}
|
|
2563
|
+
/**
 * Generate a WASM helper `(f32) -> f32` that approximates atan(x).
 *
 * Method: for |x| < 1, use the rational approximation atan(x) ≈ x * P(x^2) / Q(x^2)
 * where P(u) = A0 + A1*u + A2*u^2 and Q(u) = 1 + B1*u + B2*u^2 (both degree 2).
 * For |x| >= 1, use atan(x) = sign(x)*π/2 - atan(1/x).
 * The coefficients are fitted, with max error ~5e-7 on [0, 1].
 */
function wasm_atan(cg) {
  const body = () => {
    cg.local.get(0);
    _atan(cg);
  };
  return cg.function([cg.f32], [cg.f32], body);
}
|
|
2578
|
+
/**
 * Generate a WASM helper `(f32) -> f32` that approximates asin(x).
 *
 * Method: asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))).
 *
 * 1 - x^2 is evaluated as (1 - x) * (1 + x). Near |x| = 1 the direct form
 * 1 - x*x cancels catastrophically in f32: the rounding error of x*x becomes a
 * large relative error in the small difference. In the factored form, 1 - x is
 * exact there. For |x| > 1 the square root of a negative value yields NaN.
 */
function wasm_asin(cg) {
  return cg.function([cg.f32], [cg.f32], () => {
    cg.local.get(0); // numerator: x
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.sub(); // 1 - x
    cg.f32.const(1);
    cg.local.get(0);
    cg.f32.add(); // 1 + x
    cg.f32.mul(); // (1 - x)(1 + x) = 1 - x^2
    cg.f32.sqrt();
    cg.f32.const(1);
    cg.f32.add();
    cg.f32.div(); // x / (1 + sqrt(1 - x^2))
    _atan(cg);
    cg.f32.const(2);
    cg.f32.mul();
  });
}
|
|
2462
2600
|
/**
|
|
2463
2601
|
* Threefry2x32 pseudorandom number generator.
|
|
2464
2602
|
*
|
|
@@ -3338,6 +3476,8 @@ function codegenWasm(kernel) {
|
|
|
3338
3476
|
const funcs = {};
|
|
3339
3477
|
if (distinctOps.has(AluOp.Sin)) funcs.sin = wasm_sin(cg);
|
|
3340
3478
|
if (distinctOps.has(AluOp.Cos)) funcs.cos = wasm_cos(cg);
|
|
3479
|
+
if (distinctOps.has(AluOp.Asin)) funcs.asin = wasm_asin(cg);
|
|
3480
|
+
if (distinctOps.has(AluOp.Atan)) funcs.atan = wasm_atan(cg);
|
|
3341
3481
|
if (distinctOps.has(AluOp.Exp)) funcs.exp = wasm_exp(cg);
|
|
3342
3482
|
if (distinctOps.has(AluOp.Log)) funcs.log = wasm_log(cg);
|
|
3343
3483
|
if (distinctOps.has(AluOp.Threefry2x32)) funcs.threefry2x32 = wasm_threefry2x32(cg);
|
|
@@ -3489,6 +3629,8 @@ function translateExp(cg, funcs, exp, ctx) {
|
|
|
3489
3629
|
else throw new UnsupportedOpError(op, dtype, "wasm");
|
|
3490
3630
|
} else if (AluGroup.Unary.has(op)) if (op === AluOp.Sin) gen(src[0]), cg.call(funcs.sin);
|
|
3491
3631
|
else if (op === AluOp.Cos) gen(src[0]), cg.call(funcs.cos);
|
|
3632
|
+
else if (op === AluOp.Asin) gen(src[0]), cg.call(funcs.asin);
|
|
3633
|
+
else if (op === AluOp.Atan) gen(src[0]), cg.call(funcs.atan);
|
|
3492
3634
|
else if (op === AluOp.Exp) gen(src[0]), cg.call(funcs.exp);
|
|
3493
3635
|
else if (op === AluOp.Log) gen(src[0]), cg.call(funcs.log);
|
|
3494
3636
|
else if (op === AluOp.Sqrt) gen(src[0]), cg.f32.sqrt();
|
|
@@ -3587,10 +3729,11 @@ let defaultBackend = "wasm";
|
|
|
3587
3729
|
/**
 * Backends that are ready for use, keyed by device name. The "cpu" backend
 * and then the "wasm" backend are constructed eagerly when this module loads.
 */
const initializedBackends = new Map([
  ["cpu", new CpuBackend()],
  ["wasm", new WasmBackend()],
]);
|
|
3590
|
-
/**
|
|
3591
|
-
function
|
|
3592
|
-
if (initializedBackends.has(device)) defaultBackend = device;
|
|
3732
|
+
/**
 * Get or set the default device for arrays.
 *
 * @param {string} [device] - Backend to make the default; it must already be
 *   initialized. When omitted, the current default is left unchanged.
 * @returns {string} The default device after the call.
 * @throws {Error} If `device` is given but that backend is not initialized.
 */
function defaultDevice(device) {
  if (device === void 0) return defaultBackend;
  if (!initializedBackends.has(device)) {
    throw new Error(`Backend not initialized: ${device}`);
  }
  defaultBackend = device;
  return defaultBackend;
}
|
|
3595
3738
|
/**
|
|
3596
3739
|
* Initialize `jax-js` library backends.
|
|
@@ -3617,7 +3760,7 @@ async function createBackend(device) {
|
|
|
3617
3760
|
if (!navigator.gpu) return null;
|
|
3618
3761
|
const adapter = await navigator.gpu.requestAdapter({ powerPreference: "high-performance" });
|
|
3619
3762
|
if (!adapter) return null;
|
|
3620
|
-
const { WebGPUBackend } = await import("./webgpu-
|
|
3763
|
+
const { WebGPUBackend } = await import("./webgpu-ow0Pn_6q.js");
|
|
3621
3764
|
const importantLimits = [
|
|
3622
3765
|
"maxBufferSize",
|
|
3623
3766
|
"maxComputeInvocationsPerWorkgroup",
|
|
@@ -3670,4 +3813,4 @@ var UnsupportedOpError = class extends Error {
|
|
|
3670
3813
|
};
|
|
3671
3814
|
|
|
3672
3815
|
//#endregion
|
|
3673
|
-
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, partitionList, prod, range, recursiveFlatten, rep, runWithCache,
|
|
3816
|
+
// Named exports of this bundled chunk.
export { AluExp, AluGroup, AluOp, AluVar, DEBUG, DType, Executable, FpHash, Kernel, PPrint, Reduction, ShapeTracker, SlotError, UnsupportedOpError, accessorAluExp, accessorGlobal, byteWidth, checkAxis, deepEqual, defaultDevice, devices, dtypedArray, dtypedJsArray, findPow2, getBackend, init, invertPermutation, isFloatDtype, isNumberPair, isPermutation, normalizeAxis, partitionList, prod, promoteTypes, range, recursiveFlatten, rep, runWithCache, setDebug, strip1, toposort, tuneWebgpu, union, unravelAlu, unzip2, zip, zipn };
|